diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTP1/HTTP1ClientChannelHandler.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTP1/HTTP1ClientChannelHandler.swift index 97f850c33..affe4770c 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTP1/HTTP1ClientChannelHandler.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTP1/HTTP1ClientChannelHandler.swift @@ -261,16 +261,22 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler { case .close: context.close(promise: nil) oldRequest.succeedRequest(buffer) - case .sendRequestEnd(let writePromise): + case .sendRequestEnd(let writePromise, let shouldClose): let writePromise = writePromise ?? context.eventLoop.makePromise(of: Void.self) // We need to defer succeeding the old request to avoid ordering issues - writePromise.futureResult.whenComplete { result in + writePromise.futureResult.hop(to: context.eventLoop).whenComplete { result in switch result { case .success: // If our final action was `sendRequestEnd`, that means we've already received // the complete response. As a result, once we've uploaded all the body parts - // we need to tell the pool that the connection is idle. - self.connection.taskCompleted() + // we need to tell the pool that the connection is idle or, if we were asked to + // close when we're done, send the close. Either way, we then succeed the request + if shouldClose { + context.close(promise: nil) + } else { + self.connection.taskCompleted() + } + oldRequest.succeedRequest(buffer) case .failure(let error): oldRequest.fail(error) diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTP1/HTTP1ConnectionStateMachine.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTP1/HTTP1ConnectionStateMachine.swift index f0aff762c..e7258611c 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTP1/HTTP1ConnectionStateMachine.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTP1/HTTP1ConnectionStateMachine.swift @@ -35,7 +35,10 @@ struct HTTP1ConnectionStateMachine { /// as soon as we wrote the request end onto the wire. /// /// The promise is an optional write promise. - case sendRequestEnd(EventLoopPromise?) + /// + /// `shouldClose` records whether we have attached a Connection: close header to this request, and so the connection should + /// be terminated + case sendRequestEnd(EventLoopPromise?, shouldClose: Bool) /// Inform an observer that the connection has become idle case informConnectionIsIdle } @@ -413,7 +416,7 @@ extension HTTP1ConnectionStateMachine.State { newFinalAction = .close case .sendRequestEnd(let writePromise): self = .idle - newFinalAction = .sendRequestEnd(writePromise) + newFinalAction = .sendRequestEnd(writePromise, shouldClose: close) case .none: self = .idle newFinalAction = close ? .close : .informConnectionIsIdle diff --git a/Tests/AsyncHTTPClientTests/HTTP1ConnectionStateMachineTests.swift b/Tests/AsyncHTTPClientTests/HTTP1ConnectionStateMachineTests.swift index 55014f8c6..fd771aca0 100644 --- a/Tests/AsyncHTTPClientTests/HTTP1ConnectionStateMachineTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTP1ConnectionStateMachineTests.swift @@ -338,8 +338,8 @@ extension HTTP1ConnectionStateMachine.Action.FinalSuccessfulStreamAction: Equata switch (lhs, rhs) { case (.close, .close): return true - case (sendRequestEnd(let lhsPromise), sendRequestEnd(let rhsPromise)): - return lhsPromise?.futureResult == rhsPromise?.futureResult + case (sendRequestEnd(let lhsPromise, let lhsShouldClose), sendRequestEnd(let rhsPromise, let rhsShouldClose)): + return lhsPromise?.futureResult == rhsPromise?.futureResult && lhsShouldClose == rhsShouldClose case (informConnectionIsIdle, informConnectionIsIdle): return true default: diff --git a/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift b/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift index f2cc7b1d8..c99facc3f 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift @@ -1017,6 +1017,32 @@ internal final class CloseWithoutClosingServerHandler: ChannelInboundHandler { } } +final class ExpectClosureServerHandler: ChannelInboundHandler { + typealias InboundIn = HTTPServerRequestPart + typealias OutboundOut = HTTPServerResponsePart + + private let onClosePromise: EventLoopPromise + + init(onClosePromise: EventLoopPromise) { + self.onClosePromise = onClosePromise + } + + func channelRead(context: ChannelHandlerContext, data: NIOAny) { + switch self.unwrapInboundIn(data) { + case .head: + let head = HTTPResponseHead(version: .http1_1, status: .ok, headers: ["Content-Length": "0"]) + context.write(self.wrapOutboundOut(.head(head)), promise: nil) + context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil) + case .body, .end: + () + } + } + + func channelInactive(context: ChannelHandlerContext) { + self.onClosePromise.succeed(()) + } +} + struct EventLoopFutureTimeoutError: Error {} extension EventLoopFuture { diff --git a/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift index a709cf2d6..603c1aa9c 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift @@ -131,6 +131,7 @@ extension HTTPClientTests { ("testBiDirectionalStreaming", testBiDirectionalStreaming), ("testBiDirectionalStreamingEarly200", testBiDirectionalStreamingEarly200), ("testBiDirectionalStreamingEarly200DoesntPreventUsFromSendingMoreRequests", testBiDirectionalStreamingEarly200DoesntPreventUsFromSendingMoreRequests), + ("testCloseConnectionAfterEarly2XXWhenStreaming", testCloseConnectionAfterEarly2XXWhenStreaming), ("testSynchronousHandshakeErrorReporting", testSynchronousHandshakeErrorReporting), ("testFileDownloadChunked", testFileDownloadChunked), ("testCloseWhileBackpressureIsExertedIsFine", testCloseWhileBackpressureIsExertedIsFine), diff --git a/Tests/AsyncHTTPClientTests/HTTPClientTests.swift b/Tests/AsyncHTTPClientTests/HTTPClientTests.swift index e5d935fb9..8f2c7c1aa 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientTests.swift @@ -3075,6 +3075,60 @@ class HTTPClientTests: XCTestCase { XCTAssertNoThrow(try future2.wait()) } + // This test validates that we correctly close the connection after our body completes when we've streamed a + // body and received the 2XX response _before_ we finished our stream. + func testCloseConnectionAfterEarly2XXWhenStreaming() { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 2) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + + let onClosePromise = eventLoopGroup.next().makePromise(of: Void.self) + let httpBin = HTTPBin(.http1_1(ssl: false, compress: false)) { _ in ExpectClosureServerHandler(onClosePromise: onClosePromise) } + defer { XCTAssertNoThrow(try httpBin.shutdown()) } + + let writeEL = eventLoopGroup.next() + + let httpClient = HTTPClient(eventLoopGroupProvider: .shared(eventLoopGroup)) + defer { XCTAssertNoThrow(try httpClient.syncShutdown()) } + + let body: HTTPClient.Body = .stream { writer in + let finalPromise = writeEL.makePromise(of: Void.self) + + func writeLoop(_ writer: HTTPClient.Body.StreamWriter, index: Int) { + // always invoke from the wrong el to test thread safety + writeEL.preconditionInEventLoop() + + if index >= 30 { + return finalPromise.succeed(()) + } + + let sent = ByteBuffer(integer: index) + writer.write(.byteBuffer(sent)).whenComplete { result in + switch result { + case .success: + writeEL.execute { + writeLoop(writer, index: index + 1) + } + + case .failure(let error): + finalPromise.fail(error) + } + } + } + + writeEL.execute { + writeLoop(writer, index: 0) + } + + return finalPromise.futureResult + } + + let headers = HTTPHeaders([("Connection", "close")]) + let request = try! HTTPClient.Request(url: "http://localhost:\(httpBin.port)", headers: headers, body: body) + let future = httpClient.execute(request: request) + XCTAssertNoThrow(try future.wait()) + XCTAssertNoThrow(try onClosePromise.futureResult.wait()) + } + func testSynchronousHandshakeErrorReporting() throws { // This only affects cases where we use NIOSSL. guard !isTestingNIOTS() else { return }