19
19
@preconcurrency import SystemPackage
20
20
#endif
21
21
22
+ import _SubprocessCShims
22
23
import Synchronization
23
24
internal import Dispatch
24
25
@preconcurrency import WinSDK
@@ -71,7 +72,11 @@ final class AsyncIO: @unchecked Sendable {
71
72
// Create monitor thread
72
73
let threadContext = MonitorThreadContext ( ioCompletionPort: port)
73
74
let threadContextPtr = Unmanaged . passRetained ( threadContext)
74
- let threadHandle = CreateThread ( nil , 0 , { args in
75
+ /// Microsoft documentation for `CreateThread` states:
76
+ /// > A thread in an executable that calls the C run-time library (CRT)
77
+ /// > should use the _beginthreadex and _endthreadex functions for
78
+ /// > thread management rather than CreateThread and ExitThread
79
+ let threadHandleValue = _beginthreadex ( nil , 0 , { args in
75
80
func reportError( _ error: SubprocessError ) {
76
81
_registration. withLock { store in
77
82
for continuation in store. values {
@@ -126,18 +131,23 @@ final class AsyncIO: @unchecked Sendable {
126
131
break
127
132
}
128
133
// Notify the continuations
129
- _registration. withLock { store in
134
+ let continuation = _registration. withLock { store -> SignalStream . Continuation ? in
130
135
if let continuation = store [ targetFileDescriptor] {
131
- continuation. yield ( bytesTransferred )
136
+ return continuation
132
137
}
138
+ return nil
133
139
}
140
+ continuation? . yield ( bytesTransferred)
134
141
}
135
142
return 0
136
143
} , threadContextPtr. toOpaque ( ) , 0 , nil )
137
- guard let threadHandle = threadHandle else {
144
+ guard threadHandleValue > 0 ,
145
+ let threadHandle = HANDLE ( bitPattern: threadHandleValue) else {
146
+ // _beginthreadex uses errno instead of GetLastError()
147
+ let capturedError = _subprocess_windows_get_errno ( )
138
148
let error = SubprocessError (
139
- code: . init( . asyncIOFailed( " CreateThread failed" ) ) ,
140
- underlyingError: . init( rawValue: GetLastError ( ) )
149
+ code: . init( . asyncIOFailed( " _beginthreadex failed" ) ) ,
150
+ underlyingError: . init( rawValue: capturedError )
141
151
)
142
152
self . monitorThread = . failure( error)
143
153
return
@@ -156,10 +166,10 @@ final class AsyncIO: @unchecked Sendable {
156
166
return
157
167
}
158
168
PostQueuedCompletionStatus (
159
- ioPort,
160
- 0 ,
161
- shutdownPort,
162
- nil
169
+ ioPort, // CompletionPort
170
+ 0 , // Number of bytes transferred.
171
+ shutdownPort, // Completion key to post status
172
+ nil // Overlapped
163
173
)
164
174
// Wait for monitor thread to exit
165
175
WaitForSingleObject ( monitorThreadHandle, INFINITE)
@@ -246,26 +256,24 @@ final class AsyncIO: @unchecked Sendable {
246
256
var signalStream = self . registerHandle ( handle) . makeAsyncIterator ( )
247
257
248
258
while true {
259
+ // We use an empty `_OVERLAPPED()` here because `ReadFile` below
260
+ // only reads non-seekable files, aka pipes.
249
261
var overlapped = _OVERLAPPED ( )
250
262
let succeed = try resultBuffer. withUnsafeMutableBufferPointer { bufferPointer in
251
263
// Get a pointer to the memory at the specified offset
252
264
// Windows ReadFile uses DWORD for target count, which means we can only
253
265
// read up to DWORD (aka UInt32) max.
254
- let targetCount : DWORD
255
- if MemoryLayout< Int> . size == MemoryLayout< Int32> . size {
256
- // On 32 bit systems we don't have to worry about overflowing
257
- targetCount = DWORD ( truncatingIfNeeded: bufferPointer. count - readLength)
258
- } else {
259
- // On 64 bit systems we need to cap the count at DWORD max
260
- targetCount = DWORD ( truncatingIfNeeded: min ( bufferPointer. count - readLength, Int ( UInt32 . max) ) )
261
- }
266
+ let targetCount : DWORD = self . calculateRemainingCount (
267
+ totalCount: bufferPointer. count,
268
+ readCount: readLength
269
+ )
262
270
263
271
let offsetAddress = bufferPointer. baseAddress!. advanced ( by: readLength)
264
272
// Read directly into the buffer at the offset
265
273
return ReadFile (
266
274
handle,
267
275
offsetAddress,
268
- DWORD ( truncatingIfNeeded : targetCount) ,
276
+ targetCount,
269
277
nil ,
270
278
& overlapped
271
279
)
@@ -300,7 +308,7 @@ final class AsyncIO: @unchecked Sendable {
300
308
return resultBuffer
301
309
} else {
302
310
// Read some data
303
- readLength += Int ( bytesRead)
311
+ readLength += Int ( truncatingIfNeeded : bytesRead)
304
312
if maxLength == . max {
305
313
// Grow resultBuffer if needed
306
314
guard Double ( readLength) > 0.8 * Double( resultBuffer. count) else {
@@ -333,24 +341,22 @@ final class AsyncIO: @unchecked Sendable {
333
341
var signalStream = self . registerHandle ( diskIO. channel) . makeAsyncIterator ( )
334
342
var writtenLength : Int = 0
335
343
while true {
344
+ // We use an empty `_OVERLAPPED()` here because `WriteFile` below
345
+ // only writes to non-seekable files, aka pipes.
336
346
var overlapped = _OVERLAPPED ( )
337
347
let succeed = try span. withUnsafeBytes { ptr in
338
348
// Windows WriteFile uses DWORD for target count
339
349
// which means we can only write up to DWORD max
340
- let remainingLength : DWORD
341
- if MemoryLayout< Int> . size == MemoryLayout< Int32> . size {
342
- // On 32 bit systems we don't have to worry about overflowing
343
- remainingLength = DWORD ( truncatingIfNeeded: ptr. count - writtenLength)
344
- } else {
345
- // On 64 bit systems we need to cap the count at DWORD max
346
- remainingLength = DWORD ( truncatingIfNeeded: min ( ptr. count - writtenLength, Int ( DWORD . max) ) )
347
- }
350
+ let remainingLength : DWORD = self . calculateRemainingCount (
351
+ totalCount: ptr. count,
352
+ readCount: writtenLength
353
+ )
348
354
349
355
let startPtr = ptr. baseAddress!. advanced ( by: writtenLength)
350
356
return WriteFile (
351
357
handle,
352
358
startPtr,
353
- DWORD ( truncatingIfNeeded : remainingLength) ,
359
+ remainingLength,
354
360
nil ,
355
361
& overlapped
356
362
)
@@ -371,7 +377,7 @@ final class AsyncIO: @unchecked Sendable {
371
377
// Now wait for read to finish
372
378
let bytesWritten : DWORD = try await signalStream. next ( ) ?? 0
373
379
374
- writtenLength += Int ( bytesWritten)
380
+ writtenLength += Int ( truncatingIfNeeded : bytesWritten)
375
381
if writtenLength >= span. byteCount {
376
382
return writtenLength
377
383
}
@@ -387,23 +393,21 @@ final class AsyncIO: @unchecked Sendable {
387
393
var signalStream = self . registerHandle ( diskIO. channel) . makeAsyncIterator ( )
388
394
var writtenLength : Int = 0
389
395
while true {
396
+ // We use an empty `_OVERLAPPED()` here because `WriteFile` below
397
+ // only writes to non-seekable files, aka pipes.
390
398
var overlapped = _OVERLAPPED ( )
391
399
let succeed = try bytes. withUnsafeBytes { ptr in
392
400
// Windows WriteFile uses DWORD for target count
393
401
// which means we can only write up to DWORD max
394
- let remainingLength : DWORD
395
- if MemoryLayout< Int> . size == MemoryLayout< Int32> . size {
396
- // On 32 bit systems we don't have to worry about overflowing
397
- remainingLength = DWORD ( truncatingIfNeeded: ptr. count - writtenLength)
398
- } else {
399
- // On 64 bit systems we need to cap the count at DWORD max
400
- remainingLength = DWORD ( truncatingIfNeeded: min ( ptr. count - writtenLength, Int ( DWORD . max) ) )
401
- }
402
+ let remainingLength : DWORD = self . calculateRemainingCount (
403
+ totalCount: ptr. count,
404
+ readCount: writtenLength
405
+ )
402
406
let startPtr = ptr. baseAddress!. advanced ( by: writtenLength)
403
407
return WriteFile (
404
408
handle,
405
409
startPtr,
406
- DWORD ( truncatingIfNeeded : remainingLength) ,
410
+ remainingLength,
407
411
nil ,
408
412
& overlapped
409
413
)
@@ -423,12 +427,25 @@ final class AsyncIO: @unchecked Sendable {
423
427
}
424
428
// Now wait for read to finish
425
429
let bytesWritten : DWORD = try await signalStream. next ( ) ?? 0
426
- writtenLength += Int ( bytesWritten)
430
+ writtenLength += Int ( truncatingIfNeeded : bytesWritten)
427
431
if writtenLength >= bytes. count {
428
432
return writtenLength
429
433
}
430
434
}
431
435
}
436
+
437
+ // Windows ReadFile uses DWORD for target count, which means we can only
438
+ // read up to DWORD (aka UInt32) max.
439
+ private func calculateRemainingCount( totalCount: Int , readCount: Int ) -> DWORD {
440
+ // We support both 32bit and 64bit systems for Windows
441
+ if MemoryLayout< Int> . size == MemoryLayout< Int32> . size {
442
+ // On 32 bit systems we don't have to worry about overflowing
443
+ return DWORD ( truncatingIfNeeded: totalCount - readCount)
444
+ } else {
445
+ // On 64 bit systems we need to cap the count at DWORD max
446
+ return DWORD ( truncatingIfNeeded: min ( totalCount - readCount, Int ( DWORD . max) ) )
447
+ }
448
+ }
432
449
}
433
450
434
451
extension Array : AsyncIO . _ContiguousBytes where Element == UInt8 { }
0 commit comments