Skip to content

Commit 779a51d

Browse files
committed
Use _beginthreadex instead of CreatThread on Windows for AsyncIO
1 parent 22e4b38 commit 779a51d

File tree

9 files changed

+124
-84
lines changed

9 files changed

+124
-84
lines changed

Sources/Subprocess/AsyncBufferSequence.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ extension AsyncBufferSequence {
153153
)
154154
}
155155
#else
156-
// Cast data to CodeUnitg type
156+
// Cast data to CodeUnit type
157157
let result = buffer.withUnsafeBytes { ptr in
158158
return ptr.withMemoryRebound(to: Encoding.CodeUnit.self) { codeUnitPtr in
159159
return Array(codeUnitPtr)

Sources/Subprocess/Configuration.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -602,7 +602,7 @@ internal func _safelyClose(_ target: _CloseTarget) throws {
602602
#if canImport(WinSDK)
603603
case .handle(let handle):
604604
/// Windows does not provide a “deregistration” API (the reverse of
605-
/// `CreateIoCompletionPort`) for handles and it it reuses HANDLE
605+
/// `CreateIoCompletionPort`) for handles and it reuses HANDLE
606606
/// values once they are closed. Since we rely on the handle value
607607
/// as the completion key for `CreateIoCompletionPort`, we should
608608
/// remove the registration when the handle is closed to allow
@@ -688,7 +688,7 @@ internal struct IODescriptor: ~Copyable {
688688
type: .stream,
689689
fileDescriptor: self.platformDescriptor(),
690690
queue: .global(),
691-
cleanupHandler: { error in
691+
cleanupHandler: { @Sendable error in
692692
// Close the file descriptor
693693
if shouldClose {
694694
try? closeFd.close()

Sources/Subprocess/IO/AsyncIO+Darwin.swift

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ final class AsyncIO: Sendable {
5757
)
5858
return
5959
}
60-
if let data = data {
60+
if let data {
6161
if buffer.isEmpty {
6262
buffer = data
6363
} else {
@@ -81,8 +81,8 @@ final class AsyncIO: Sendable {
8181
to diskIO: borrowing IOChannel
8282
) async throws -> Int {
8383
try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<Int, any Error>) in
84-
let dispatchData = span.withUnsafeBytes {
85-
return DispatchData(
84+
span.withUnsafeBytes {
85+
let dispatchData = DispatchData(
8686
bytesNoCopy: $0,
8787
deallocator: .custom(
8888
nil,
@@ -91,12 +91,13 @@ final class AsyncIO: Sendable {
9191
}
9292
)
9393
)
94-
}
95-
self.write(dispatchData, to: diskIO) { writtenLength, error in
96-
if let error = error {
97-
continuation.resume(throwing: error)
98-
} else {
99-
continuation.resume(returning: writtenLength)
94+
95+
self.write(dispatchData, to: diskIO) { writtenLength, error in
96+
if let error {
97+
continuation.resume(throwing: error)
98+
} else {
99+
continuation.resume(returning: writtenLength)
100+
}
100101
}
101102
}
102103
}
@@ -108,8 +109,8 @@ final class AsyncIO: Sendable {
108109
to diskIO: borrowing IOChannel
109110
) async throws -> Int {
110111
try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<Int, any Error>) in
111-
let dispatchData = array.withUnsafeBytes {
112-
return DispatchData(
112+
array.withUnsafeBytes {
113+
let dispatchData = DispatchData(
113114
bytesNoCopy: $0,
114115
deallocator: .custom(
115116
nil,
@@ -118,12 +119,13 @@ final class AsyncIO: Sendable {
118119
}
119120
)
120121
)
121-
}
122-
self.write(dispatchData, to: diskIO) { writtenLength, error in
123-
if let error = error {
124-
continuation.resume(throwing: error)
125-
} else {
126-
continuation.resume(returning: writtenLength)
122+
123+
self.write(dispatchData, to: diskIO) { writtenLength, error in
124+
if let error {
125+
continuation.resume(throwing: error)
126+
} else {
127+
continuation.resume(returning: writtenLength)
128+
}
127129
}
128130
}
129131
}

Sources/Subprocess/IO/AsyncIO+Linux.swift

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -118,11 +118,7 @@ final class AsyncIO: Sendable {
118118
shutdownFileDescriptor: shutdownFileDescriptor
119119
)
120120
let threadContext = Unmanaged.passRetained(context)
121-
#if os(FreeBSD) || os(OpenBSD)
122-
var thread: pthread_t? = nil
123-
#else
124121
var thread: pthread_t = pthread_t()
125-
#endif
126122
rc = pthread_create(&thread, nil, { args in
127123
func reportError(_ error: SubprocessError) {
128124
_registration.withLock { store in
@@ -175,11 +171,13 @@ final class AsyncIO: Sendable {
175171
}
176172

177173
// Notify the continuation
178-
_registration.withLock { store in
174+
let continuation = _registration.withLock { store -> SignalStream.Continuation? in
179175
if let continuation = store[targetFileDescriptor] {
180-
continuation.yield(true)
176+
return continuation
181177
}
178+
return nil
182179
}
180+
continuation?.yield(true)
183181
}
184182
}
185183

@@ -194,16 +192,10 @@ final class AsyncIO: Sendable {
194192
return
195193
}
196194

197-
#if os(FreeBSD) || os(OpenBSD)
198-
let monitorThread = thread!
199-
#else
200-
let monitorThread = thread
201-
#endif
202-
203195
let state = State(
204196
epollFileDescriptor: epollFileDescriptor,
205197
shutdownFileDescriptor: shutdownFileDescriptor,
206-
monitorThread: monitorThread
198+
monitorThread: thread
207199
)
208200
self.state = .success(state)
209201

@@ -222,6 +214,8 @@ final class AsyncIO: Sendable {
222214
_ = _SubprocessCShims.write(currentState.shutdownFileDescriptor, &one, MemoryLayout<UInt64>.stride)
223215
// Cleanup the monitor thread
224216
pthread_join(currentState.monitorThread, nil)
217+
close(currentState.epollFileDescriptor)
218+
close(currentState.shutdownFileDescriptor)
225219
}
226220

227221

@@ -394,7 +388,7 @@ extension AsyncIO {
394388
resultBuffer.removeLast(resultBuffer.count - readLength)
395389
return resultBuffer
396390
} else {
397-
if errno == EAGAIN || errno == EWOULDBLOCK {
391+
if self.shouldWaitForNextSignal(with: errno) {
398392
// No more data for now wait for the next signal
399393
break
400394
} else {
@@ -443,7 +437,7 @@ extension AsyncIO {
443437
return writtenLength
444438
}
445439
} else {
446-
if errno == EAGAIN || errno == EWOULDBLOCK {
440+
if self.shouldWaitForNextSignal(with: errno) {
447441
// No more data for now wait for the next signal
448442
break
449443
} else {
@@ -486,7 +480,7 @@ extension AsyncIO {
486480
return writtenLength
487481
}
488482
} else {
489-
if errno == EAGAIN || errno == EWOULDBLOCK {
483+
if self.shouldWaitForNextSignal(with: errno) {
490484
// No more data for now wait for the next signal
491485
break
492486
} else {
@@ -500,6 +494,11 @@ extension AsyncIO {
500494
return 0
501495
}
502496
#endif
497+
498+
@inline(__always)
499+
private func shouldWaitForNextSignal(with error: CInt) -> Bool {
500+
return error == EAGAIN || error == EWOULDBLOCK || error == EINTR
501+
}
503502
}
504503

505504
extension Array : AsyncIO._ContiguousBytes where Element == UInt8 {}

Sources/Subprocess/IO/AsyncIO+Windows.swift

Lines changed: 57 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
@preconcurrency import SystemPackage
2020
#endif
2121

22+
import _SubprocessCShims
2223
import Synchronization
2324
internal import Dispatch
2425
@preconcurrency import WinSDK
@@ -71,7 +72,11 @@ final class AsyncIO: @unchecked Sendable {
7172
// Create monitor thread
7273
let threadContext = MonitorThreadContext(ioCompletionPort: port)
7374
let threadContextPtr = Unmanaged.passRetained(threadContext)
74-
let threadHandle = CreateThread(nil, 0, { args in
75+
/// Microsoft documentation for `CreateThread` states:
76+
/// > A thread in an executable that calls the C run-time library (CRT)
77+
/// > should use the _beginthreadex and _endthreadex functions for
78+
/// > thread management rather than CreateThread and ExitThread
79+
let threadHandleValue = _beginthreadex(nil, 0, { args in
7580
func reportError(_ error: SubprocessError) {
7681
_registration.withLock { store in
7782
for continuation in store.values {
@@ -126,18 +131,23 @@ final class AsyncIO: @unchecked Sendable {
126131
break
127132
}
128133
// Notify the continuations
129-
_registration.withLock { store in
134+
let continuation = _registration.withLock { store -> SignalStream.Continuation? in
130135
if let continuation = store[targetFileDescriptor] {
131-
continuation.yield(bytesTransferred)
136+
return continuation
132137
}
138+
return nil
133139
}
140+
continuation?.yield(bytesTransferred)
134141
}
135142
return 0
136143
}, threadContextPtr.toOpaque(), 0, nil)
137-
guard let threadHandle = threadHandle else {
144+
guard threadHandleValue > 0,
145+
let threadHandle = HANDLE(bitPattern: threadHandleValue) else {
146+
// _beginthreadex uses errno instead of GetLastError()
147+
let capturedError = _subprocess_windows_get_errno()
138148
let error = SubprocessError(
139-
code: .init(.asyncIOFailed("CreateThread failed")),
140-
underlyingError: .init(rawValue: GetLastError())
149+
code: .init(.asyncIOFailed("_beginthreadex failed")),
150+
underlyingError: .init(rawValue: capturedError)
141151
)
142152
self.monitorThread = .failure(error)
143153
return
@@ -156,10 +166,10 @@ final class AsyncIO: @unchecked Sendable {
156166
return
157167
}
158168
PostQueuedCompletionStatus(
159-
ioPort,
160-
0,
161-
shutdownPort,
162-
nil
169+
ioPort, // CompletionPort
170+
0, // Number of bytes transferred.
171+
shutdownPort, // Completion key to post status
172+
nil // Overlapped
163173
)
164174
// Wait for monitor thread to exit
165175
WaitForSingleObject(monitorThreadHandle, INFINITE)
@@ -246,26 +256,24 @@ final class AsyncIO: @unchecked Sendable {
246256
var signalStream = self.registerHandle(handle).makeAsyncIterator()
247257

248258
while true {
259+
// We use an empty `_OVERLAPPED()` here because `ReadFile` below
260+
// only reads non-seekable files, aka pipes.
249261
var overlapped = _OVERLAPPED()
250262
let succeed = try resultBuffer.withUnsafeMutableBufferPointer { bufferPointer in
251263
// Get a pointer to the memory at the specified offset
252264
// Windows ReadFile uses DWORD for target count, which means we can only
253265
// read up to DWORD (aka UInt32) max.
254-
let targetCount: DWORD
255-
if MemoryLayout<Int>.size == MemoryLayout<Int32>.size {
256-
// On 32 bit systems we don't have to worry about overflowing
257-
targetCount = DWORD(truncatingIfNeeded: bufferPointer.count - readLength)
258-
} else {
259-
// On 64 bit systems we need to cap the count at DWORD max
260-
targetCount = DWORD(truncatingIfNeeded: min(bufferPointer.count - readLength, Int(UInt32.max)))
261-
}
266+
let targetCount: DWORD = self.calculateRemainingCount(
267+
totalCount: bufferPointer.count,
268+
readCount: readLength
269+
)
262270

263271
let offsetAddress = bufferPointer.baseAddress!.advanced(by: readLength)
264272
// Read directly into the buffer at the offset
265273
return ReadFile(
266274
handle,
267275
offsetAddress,
268-
DWORD(truncatingIfNeeded: targetCount),
276+
targetCount,
269277
nil,
270278
&overlapped
271279
)
@@ -300,7 +308,7 @@ final class AsyncIO: @unchecked Sendable {
300308
return resultBuffer
301309
} else {
302310
// Read some data
303-
readLength += Int(bytesRead)
311+
readLength += Int(truncatingIfNeeded: bytesRead)
304312
if maxLength == .max {
305313
// Grow resultBuffer if needed
306314
guard Double(readLength) > 0.8 * Double(resultBuffer.count) else {
@@ -333,24 +341,22 @@ final class AsyncIO: @unchecked Sendable {
333341
var signalStream = self.registerHandle(diskIO.channel).makeAsyncIterator()
334342
var writtenLength: Int = 0
335343
while true {
344+
// We use an empty `_OVERLAPPED()` here because `WriteFile` below
345+
// only writes to non-seekable files, aka pipes.
336346
var overlapped = _OVERLAPPED()
337347
let succeed = try span.withUnsafeBytes { ptr in
338348
// Windows WriteFile uses DWORD for target count
339349
// which means we can only write up to DWORD max
340-
let remainingLength: DWORD
341-
if MemoryLayout<Int>.size == MemoryLayout<Int32>.size {
342-
// On 32 bit systems we don't have to worry about overflowing
343-
remainingLength = DWORD(truncatingIfNeeded: ptr.count - writtenLength)
344-
} else {
345-
// On 64 bit systems we need to cap the count at DWORD max
346-
remainingLength = DWORD(truncatingIfNeeded: min(ptr.count - writtenLength, Int(DWORD.max)))
347-
}
350+
let remainingLength: DWORD = self.calculateRemainingCount(
351+
totalCount: ptr.count,
352+
readCount: writtenLength
353+
)
348354

349355
let startPtr = ptr.baseAddress!.advanced(by: writtenLength)
350356
return WriteFile(
351357
handle,
352358
startPtr,
353-
DWORD(truncatingIfNeeded: remainingLength),
359+
remainingLength,
354360
nil,
355361
&overlapped
356362
)
@@ -371,7 +377,7 @@ final class AsyncIO: @unchecked Sendable {
371377
// Now wait for read to finish
372378
let bytesWritten: DWORD = try await signalStream.next() ?? 0
373379

374-
writtenLength += Int(bytesWritten)
380+
writtenLength += Int(truncatingIfNeeded: bytesWritten)
375381
if writtenLength >= span.byteCount {
376382
return writtenLength
377383
}
@@ -387,23 +393,21 @@ final class AsyncIO: @unchecked Sendable {
387393
var signalStream = self.registerHandle(diskIO.channel).makeAsyncIterator()
388394
var writtenLength: Int = 0
389395
while true {
396+
// We use an empty `_OVERLAPPED()` here because `WriteFile` below
397+
// only writes to non-seekable files, aka pipes.
390398
var overlapped = _OVERLAPPED()
391399
let succeed = try bytes.withUnsafeBytes { ptr in
392400
// Windows WriteFile uses DWORD for target count
393401
// which means we can only write up to DWORD max
394-
let remainingLength: DWORD
395-
if MemoryLayout<Int>.size == MemoryLayout<Int32>.size {
396-
// On 32 bit systems we don't have to worry about overflowing
397-
remainingLength = DWORD(truncatingIfNeeded: ptr.count - writtenLength)
398-
} else {
399-
// On 64 bit systems we need to cap the count at DWORD max
400-
remainingLength = DWORD(truncatingIfNeeded: min(ptr.count - writtenLength, Int(DWORD.max)))
401-
}
402+
let remainingLength: DWORD = self.calculateRemainingCount(
403+
totalCount: ptr.count,
404+
readCount: writtenLength
405+
)
402406
let startPtr = ptr.baseAddress!.advanced(by: writtenLength)
403407
return WriteFile(
404408
handle,
405409
startPtr,
406-
DWORD(truncatingIfNeeded: remainingLength),
410+
remainingLength,
407411
nil,
408412
&overlapped
409413
)
@@ -423,12 +427,25 @@ final class AsyncIO: @unchecked Sendable {
423427
}
424428
// Now wait for read to finish
425429
let bytesWritten: DWORD = try await signalStream.next() ?? 0
426-
writtenLength += Int(bytesWritten)
430+
writtenLength += Int(truncatingIfNeeded: bytesWritten)
427431
if writtenLength >= bytes.count {
428432
return writtenLength
429433
}
430434
}
431435
}
436+
437+
// Windows ReadFile uses DWORD for target count, which means we can only
438+
// read up to DWORD (aka UInt32) max.
439+
private func calculateRemainingCount(totalCount: Int, readCount: Int) -> DWORD {
440+
// We support both 32bit and 64bit systems for Windows
441+
if MemoryLayout<Int>.size == MemoryLayout<Int32>.size {
442+
// On 32 bit systems we don't have to worry about overflowing
443+
return DWORD(truncatingIfNeeded: totalCount - readCount)
444+
} else {
445+
// On 64 bit systems we need to cap the count at DWORD max
446+
return DWORD(truncatingIfNeeded: min(totalCount - readCount, Int(DWORD.max)))
447+
}
448+
}
432449
}
433450

434451
extension Array : AsyncIO._ContiguousBytes where Element == UInt8 {}

0 commit comments

Comments
 (0)