diff --git a/Sources/AsyncHTTPClient/RequestBag.swift b/Sources/AsyncHTTPClient/RequestBag.swift index 4ec7004c1..9c45728b7 100644 --- a/Sources/AsyncHTTPClient/RequestBag.swift +++ b/Sources/AsyncHTTPClient/RequestBag.swift @@ -19,6 +19,14 @@ import NIOHTTP1 import NIOSSL final class RequestBag { + /// Defends against the call stack getting too large when consuming body parts. + /// + /// If the response body comes in lots of tiny chunks, we'll deliver those tiny chunks to users + /// one at a time. + private static var maxConsumeBodyPartStackDepth: Int { + 50 + } + let task: HTTPClient.Task var eventLoop: EventLoop { self.task.eventLoop @@ -30,6 +38,9 @@ final class RequestBag { // the request state is synchronized on the task eventLoop private var state: StateMachine + // the consume body part stack depth is synchronized on the task event loop. + private var consumeBodyPartStackDepth: Int + // MARK: HTTPClientTask properties var logger: Logger { @@ -55,6 +66,7 @@ final class RequestBag { self.eventLoopPreference = eventLoopPreference self.task = task self.state = .init(redirectHandler: redirectHandler) + self.consumeBodyPartStackDepth = 0 self.request = request self.connectionDeadline = connectionDeadline self.requestOptions = requestOptions @@ -290,16 +302,39 @@ final class RequestBag { private func consumeMoreBodyData0(resultOfPreviousConsume result: Result) { self.task.eventLoop.assertInEventLoop() + // We get defensive here about the maximum stack depth. It's possible for the `didReceiveBodyPart` + // future to be returned to us completed. If it is, we will recurse back into this method. To + // break that recursion we have a max stack depth which we increment and decrement in this method: + // if it gets too large, instead of recurring we'll insert an `eventLoop.execute`, which will + // manually break the recursion and unwind the stack. + // + // Note that we don't bother starting this at the various other call sites that _begin_ stacks + // that risk ending up in this loop. That's because we don't need an accurate count: our limit is + // a best-effort target anyway, one stack frame here or there does not put us at risk. We're just + // trying to prevent ourselves looping out of control. + self.consumeBodyPartStackDepth += 1 + defer { + self.consumeBodyPartStackDepth -= 1 + assert(self.consumeBodyPartStackDepth >= 0) + } + let consumptionAction = self.state.consumeMoreBodyData(resultOfPreviousConsume: result) switch consumptionAction { case .consume(let byteBuffer): self.delegate.didReceiveBodyPart(task: self.task, byteBuffer) .hop(to: self.task.eventLoop) - .whenComplete { - switch $0 { + .whenComplete { result in + switch result { case .success: - self.consumeMoreBodyData0(resultOfPreviousConsume: $0) + if self.consumeBodyPartStackDepth < Self.maxConsumeBodyPartStackDepth { + self.consumeMoreBodyData0(resultOfPreviousConsume: result) + } else { + // We need to unwind the stack, let's take a break. + self.task.eventLoop.execute { + self.consumeMoreBodyData0(resultOfPreviousConsume: result) + } + } case .failure(let error): self.fail(error) } diff --git a/Tests/AsyncHTTPClientTests/HTTP2ClientTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTP2ClientTests+XCTest.swift index e7f399658..915791cdf 100644 --- a/Tests/AsyncHTTPClientTests/HTTP2ClientTests+XCTest.swift +++ b/Tests/AsyncHTTPClientTests/HTTP2ClientTests+XCTest.swift @@ -37,6 +37,7 @@ extension HTTP2ClientTests { ("testH2CanHandleRequestsThatHaveAlreadyHitTheDeadline", testH2CanHandleRequestsThatHaveAlreadyHitTheDeadline), ("testStressCancelingRunningRequestFromDifferentThreads", testStressCancelingRunningRequestFromDifferentThreads), ("testPlatformConnectErrorIsForwardedOnTimeout", testPlatformConnectErrorIsForwardedOnTimeout), + ("testMassiveDownload", testMassiveDownload), ] } } diff --git a/Tests/AsyncHTTPClientTests/HTTP2ClientTests.swift b/Tests/AsyncHTTPClientTests/HTTP2ClientTests.swift index eb1ac2ddc..7c0e1e56f 100644 --- a/Tests/AsyncHTTPClientTests/HTTP2ClientTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTP2ClientTests.swift @@ -432,6 +432,19 @@ class HTTP2ClientTests: XCTestCase { ) } } + + func testMassiveDownload() { + let bin = HTTPBin(.http2(compress: false)) + defer { XCTAssertNoThrow(try bin.shutdown()) } + let client = self.makeDefaultHTTPClient() + defer { XCTAssertNoThrow(try client.syncShutdown()) } + var response: HTTPClient.Response? + XCTAssertNoThrow(response = try client.get(url: "https://localhost:\(bin.port)/mega-chunked").wait()) + + XCTAssertEqual(.ok, response?.status) + XCTAssertEqual(response?.version, .http2) + XCTAssertEqual(response?.body?.readableBytes, 10_000) + } } private final class HeadReceivedCallback: HTTPClientResponseDelegate { diff --git a/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift b/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift index 59336d39f..8f7d4dfce 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift @@ -745,6 +745,22 @@ internal final class HTTPBinHandler: ChannelInboundHandler { context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil) } + func writeManyChunks(context: ChannelHandlerContext) { + // This tests receiving a lot of tiny chunks: they must all be sent in a single flush or the test doesn't work. + let headers = HTTPHeaders([("Transfer-Encoding", "chunked")]) + + context.write(self.wrapOutboundOut(.head(HTTPResponseHead(version: HTTPVersion(major: 1, minor: 1), status: .ok, headers: headers))), promise: nil) + let message = ByteBuffer(integer: UInt8(ascii: "a")) + + // This number (10k) is load-bearing and a bit magic: it has been experimentally verified as being sufficient to blow the stack + // in the old implementation on all testing platforms. Please don't change it without good reason. + for _ in 0..<10_000 { + context.write(wrapOutboundOut(.body(.byteBuffer(message))), promise: nil) + } + + context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil) + } + func channelRead(context: ChannelHandlerContext, data: NIOAny) { self.isServingRequest = true switch self.unwrapInboundIn(data) { @@ -863,6 +879,9 @@ internal final class HTTPBinHandler: ChannelInboundHandler { case "/chunked": self.writeChunked(context: context) return + case "/mega-chunked": + self.writeManyChunks(context: context) + return case "/close-on-response": var headers = self.responseHeaders headers.replaceOrAdd(name: "connection", value: "close") diff --git a/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift index 421060b2e..655e3acc5 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift @@ -142,6 +142,7 @@ extension HTTPClientTests { ("testRequestSpecificTLS", testRequestSpecificTLS), ("testConnectionPoolSizeConfigValueIsRespected", testConnectionPoolSizeConfigValueIsRespected), ("testRequestWithHeaderTransferEncodingIdentityDoesNotFail", testRequestWithHeaderTransferEncodingIdentityDoesNotFail), + ("testMassiveDownload", testMassiveDownload), ] } } diff --git a/Tests/AsyncHTTPClientTests/HTTPClientTests.swift b/Tests/AsyncHTTPClientTests/HTTPClientTests.swift index 8918ea042..e2e34cf00 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientTests.swift @@ -3454,4 +3454,13 @@ class HTTPClientTests: XCTestCase { XCTAssertNoThrow(try client.execute(request: request).wait()) } + + func testMassiveDownload() { + var response: HTTPClient.Response? + XCTAssertNoThrow(response = try self.defaultClient.get(url: "\(self.defaultHTTPBinURLPrefix)mega-chunked").wait()) + + XCTAssertEqual(.ok, response?.status) + XCTAssertEqual(response?.version, .http1_1) + XCTAssertEqual(response?.body?.readableBytes, 10_000) + } }