diff --git a/Sources/AsyncHTTPClient/HTTPHandler.swift b/Sources/AsyncHTTPClient/HTTPHandler.swift index 4da73531a..78100c6f5 100644 --- a/Sources/AsyncHTTPClient/HTTPHandler.swift +++ b/Sources/AsyncHTTPClient/HTTPHandler.swift @@ -108,63 +108,52 @@ extension HTTPClient { private static var hostRestrictedSchemes: Set = ["http", "https"] private static var allSupportedSchemes: Set = ["http", "https", "unix", "http+unix", "https+unix"] - init(forScheme scheme: String) throws { - switch scheme { - case "http", "https": self = .host - case "unix": self = .unixSocket(.baseURL) - case "http+unix": self = .unixSocket(.http_unix) - case "https+unix": self = .unixSocket(.https_unix) - default: - throw HTTPClientError.unsupportedScheme(scheme) - } - } + func supportsRedirects(to scheme: String?) -> Bool { + guard let scheme = scheme?.lowercased() else { return false } - func hostFromURL(_ url: URL) throws -> String { switch self { case .host: - guard let host = url.host else { - throw HTTPClientError.emptyHost - } - return host + return Kind.hostRestrictedSchemes.contains(scheme) case .unixSocket: - return "" + return Kind.allSupportedSchemes.contains(scheme) } } + } - func socketPathFromURL(_ url: URL) throws -> String { - switch self { - case .unixSocket(.baseURL): - return url.baseURL?.path ?? url.path - case .unixSocket: - guard let socketPath = url.host else { - throw HTTPClientError.missingSocketPath - } - return socketPath - case .host: - return "" - } - } + static func useTLS(_ scheme: String) -> Bool { + return scheme == "https" || scheme == "https+unix" + } - func uriFromURL(_ url: URL) -> String { - switch self { - case .host: - return url.uri - case .unixSocket(.baseURL): - return url.baseURL != nil ? url.uri : "/" - case .unixSocket: - return url.uri - } + static func deconstructURL( + _ url: URL + ) throws -> ( + kind: Kind, scheme: String, hostname: String, port: Int, socketPath: String, uri: String + ) { + guard let scheme = url.scheme?.lowercased() else { + throw HTTPClientError.emptyScheme } - - func supportsRedirects(to scheme: String?) -> Bool { - guard let scheme = scheme?.lowercased() else { return false } - - switch self { - case .host: - return Kind.hostRestrictedSchemes.contains(scheme) - case .unixSocket: - return Kind.allSupportedSchemes.contains(scheme) + switch scheme { + case "http", "https": + guard let host = url.host, !host.isEmpty else { + throw HTTPClientError.emptyHost + } + let defaultPort = self.useTLS(scheme) ? 443 : 80 + return (.host, scheme, host, url.port ?? defaultPort, "", url.uri) + case "http+unix", "https+unix": + guard let socketPath = url.host, !socketPath.isEmpty else { + throw HTTPClientError.missingSocketPath + } + let (kind, defaultPort) = self.useTLS(scheme) ? (Kind.UnixScheme.https_unix, 443) : (.http_unix, 80) + return (.unixSocket(kind), scheme, "", url.port ?? defaultPort, socketPath, url.uri) + case "unix": + let socketPath = url.baseURL?.path ?? url.path + let uri = url.baseURL != nil ? url.uri : "/" + guard !socketPath.isEmpty else { + throw HTTPClientError.missingSocketPath } + return (.unixSocket(.baseURL), scheme, "", url.port ?? 80, socketPath, uri) + default: + throw HTTPClientError.unsupportedScheme(url.scheme!) } } @@ -176,6 +165,8 @@ extension HTTPClient { public let scheme: String /// Remote host, resolved from `URL`. public let host: String + /// Resolved port. + public let port: Int /// Socket path, resolved from `URL`. let socketPath: String /// URI composed of the path and query, resolved from `URL`. @@ -264,19 +255,10 @@ extension HTTPClient { /// - `emptyHost` if URL does not contains a host. /// - `missingSocketPath` if URL does not contains a socketPath as an encoded host. public init(url: URL, method: HTTPMethod = .GET, headers: HTTPHeaders = HTTPHeaders(), body: Body? = nil, tlsConfiguration: TLSConfiguration?) throws { - guard let scheme = url.scheme?.lowercased() else { - throw HTTPClientError.emptyScheme - } - - self.kind = try Kind(forScheme: scheme) - self.host = try self.kind.hostFromURL(url) - self.socketPath = try self.kind.socketPathFromURL(url) - self.uri = self.kind.uriFromURL(url) - + (self.kind, self.scheme, self.host, self.port, self.socketPath, self.uri) = try Request.deconstructURL(url) self.redirectState = nil self.url = url self.method = method - self.scheme = scheme self.headers = headers self.body = body self.tlsConfiguration = tlsConfiguration @@ -284,12 +266,7 @@ extension HTTPClient { /// Whether request will be executed using secure socket. public var useTLS: Bool { - return self.scheme == "https" || self.scheme == "https+unix" - } - - /// Resolved port. - public var port: Int { - return self.url.port ?? (self.useTLS ? 443 : 80) + return Request.useTLS(self.scheme) } func createRequestHead() throws -> (HTTPRequestHead, RequestFramingMetadata) { diff --git a/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift index 5e762c3bb..22659d32c 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift @@ -29,6 +29,7 @@ extension HTTPClientTests { ("testBadRequestURI", testBadRequestURI), ("testSchemaCasing", testSchemaCasing), ("testURLSocketPathInitializers", testURLSocketPathInitializers), + ("testBadUnixWithBaseURL", testBadUnixWithBaseURL), ("testConvenienceExecuteMethods", testConvenienceExecuteMethods), ("testConvenienceExecuteMethodsOverSocket", testConvenienceExecuteMethodsOverSocket), ("testConvenienceExecuteMethodsOverSecureSocket", testConvenienceExecuteMethodsOverSecureSocket), diff --git a/Tests/AsyncHTTPClientTests/HTTPClientTests.swift b/Tests/AsyncHTTPClientTests/HTTPClientTests.swift index 16898de14..3c8b49e44 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientTests.swift @@ -215,6 +215,14 @@ class HTTPClientTests: XCTestCase { XCTAssertNil(url10) } + func testBadUnixWithBaseURL() { + let badUnixBaseURL = URL(string: "/foo", relativeTo: URL(string: "unix:")!)! + XCTAssertEqual(badUnixBaseURL.baseURL?.path, "") + XCTAssertThrowsError(try Request(url: badUnixBaseURL)) { error in + XCTAssertEqual(error as! HTTPClientError, HTTPClientError.missingSocketPath) + } + } + func testConvenienceExecuteMethods() throws { XCTAssertNoThrow(XCTAssertEqual(["GET"[...]], try self.defaultClient.get(url: self.defaultHTTPBinURLPrefix + "echo-method").wait().headers[canonicalForm: "X-Method-Used"]))