diff --git a/Tests/AsyncHTTPClientTests/HTTPClient+SOCKSTests.swift b/Tests/AsyncHTTPClientTests/HTTPClient+SOCKSTests.swift index 3479d86b9..2c65a7f3b 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClient+SOCKSTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClient+SOCKSTests.swift @@ -131,7 +131,10 @@ class HTTPClientSOCKSTests: XCTestCase { XCTAssertNoThrow(try socksBin.shutdown()) } - // the server will send a bogus message in response to the clients request - XCTAssertThrowsError(try localClient.get(url: "http://localhost/socks/test").wait()) + // the server will send a bogus message in response to the clients greeting + // this will be first picked up as an invalid protocol + XCTAssertThrowsError(try localClient.get(url: "http://localhost/socks/test").wait()) { e in + XCTAssertTrue(e is SOCKSError.InvalidProtocolVersion) + } } } diff --git a/Tests/AsyncHTTPClientTests/SOCKSTestUtils.swift b/Tests/AsyncHTTPClientTests/SOCKSTestUtils.swift index 38fa706df..d9b254698 100644 --- a/Tests/AsyncHTTPClientTests/SOCKSTestUtils.swift +++ b/Tests/AsyncHTTPClientTests/SOCKSTestUtils.swift @@ -22,21 +22,40 @@ struct MockSOCKSError: Error, Hashable { var description: String } +class TestSOCKSBadServerHandler: ChannelInboundHandler { + typealias InboundIn = ByteBuffer + + func channelRead(context: ChannelHandlerContext, data: NIOAny) { + // just write some nonsense bytes + let buffer = context.channel.allocator.buffer(bytes: [0xAA, 0xBB, 0xCC, 0xDD, 0xEE]) + context.writeAndFlush(.init(buffer), promise: nil) + } +} + class MockSOCKSServer { let channel: Channel init(expectedURL: String, expectedResponse: String, misbehave: Bool = false, file: String = #file, line: UInt = #line) throws { let elg = MultiThreadedEventLoopGroup(numberOfThreads: 1) - let bootstrap = ServerBootstrap(group: elg) - .serverChannelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1) - .childChannelInitializer { channel in - let handshakeHandler = SOCKSServerHandshakeHandler() - return channel.pipeline.addHandlers([ - handshakeHandler, - SOCKSTestHandler(handshakeHandler: handshakeHandler, misbehave: misbehave), - TestHTTPServer(expectedURL: expectedURL, expectedResponse: expectedResponse, file: file, line: line), - ]) - } + let bootstrap: ServerBootstrap + if misbehave { + bootstrap = ServerBootstrap(group: elg) + .serverChannelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1) + .childChannelInitializer { channel in + channel.pipeline.addHandler(TestSOCKSBadServerHandler()) + } + } else { + bootstrap = ServerBootstrap(group: elg) + .serverChannelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1) + .childChannelInitializer { channel in + let handshakeHandler = SOCKSServerHandshakeHandler() + return channel.pipeline.addHandlers([ + handshakeHandler, + SOCKSTestHandler(handshakeHandler: handshakeHandler), + TestHTTPServer(expectedURL: expectedURL, expectedResponse: expectedResponse, file: file, line: line), + ]) + } + } self.channel = try bootstrap.bind(host: "localhost", port: 1080).wait() } @@ -49,11 +68,9 @@ class SOCKSTestHandler: ChannelInboundHandler, RemovableChannelHandler { typealias InboundIn = ClientMessage let handshakeHandler: SOCKSServerHandshakeHandler - let misbehave: Bool - init(handshakeHandler: SOCKSServerHandshakeHandler, misbehave: Bool) { + init(handshakeHandler: SOCKSServerHandshakeHandler) { self.handshakeHandler = handshakeHandler - self.misbehave = misbehave } func channelRead(context: ChannelHandlerContext, data: NIOAny) { @@ -69,12 +86,6 @@ class SOCKSTestHandler: ChannelInboundHandler, RemovableChannelHandler { case .authenticationData: context.fireErrorCaught(MockSOCKSError(description: "Received authentication data but didn't receive any.")) case .request(let request): - guard !self.misbehave else { - context.writeAndFlush( - .init(ServerMessage.authenticationData(context.channel.allocator.buffer(string: "bad server!"), complete: true)), promise: nil - ) - return - } context.writeAndFlush(.init( ServerMessage.response(.init(reply: .succeeded, boundAddress: request.addressType))), promise: nil) context.channel.pipeline.addHandlers([