Skip to content

Commit 8032167

Browse files
committed
Support binary data transfer in COPY FROM
In my benchmark, transferring the integers from 0 to 1,000,000 both as an integer and as a string took about the same time as the old text-based transfer. I believe that the binary transfer will start to show significant benefits when transferring binary data and other fields that don't need to be represented as text. It also means that the user doesn't need to worry about escaping their data.
1 parent 1602d85 commit 8032167

File tree

3 files changed

+260
-0
lines changed

3 files changed

+260
-0
lines changed

Sources/PostgresNIO/Connection/PostgresConnection+CopyFrom.swift

Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,110 @@
1+
import NIO
2+
3+
#if compiler(>=6.0)
4+
/// Handle to send binary data for a `COPY ... FROM STDIN` query to the backend.
///
/// It takes care of serializing `PostgresEncodable` column types into the binary format that Postgres expects.
public struct PostgresBinaryCopyFromWriter: ~Copyable {
    /// Handle to serialize columns into a row that is being written by `PostgresBinaryCopyFromWriter`.
    public struct ColumnWriter: ~Copyable {
        /// The `PostgresBinaryCopyFromWriter` that is gathering the serialized data.
        ///
        /// We need to model this as `UnsafeMutablePointer` because we can't express in the Swift type system that
        /// `ColumnWriter` never exceeds the lifetime of `PostgresBinaryCopyFromWriter`.
        @usableFromInline
        let underlying: UnsafeMutablePointer<PostgresBinaryCopyFromWriter>

        /// The number of columns that have been written by this `ColumnWriter`.
        @usableFromInline
        var columns: UInt16 = 0

        @usableFromInline
        init(underlying: UnsafeMutablePointer<PostgresBinaryCopyFromWriter>) {
            self.underlying = underlying
        }

        /// Serialize a single column to a row.
        ///
        /// Pass `nil` to write a SQL `NULL` for the column.
        ///
        /// - Important: It is critical that the data type encoded here exactly matches the data type in the
        ///   database. For example, if the database stores a 4-byte integer the corresponding `writeColumn` must
        ///   be called with an `Int32`. Serializing an integer of a different width will cause a deserialization
        ///   failure in the backend.
        /// - Throws: Any error thrown while encoding `column`.
        @inlinable
        public mutating func writeColumn(_ column: (some PostgresEncodable)?) throws {
            columns += 1
            try underlying.pointee.writeColumn(column)
        }
    }

    /// The underlying `PostgresCopyFromWriter` that sends the serialized data to the backend.
    @usableFromInline let underlying: PostgresCopyFromWriter

    /// The buffer in which we accumulate binary data. Once this buffer exceeds `bufferSize`, we flush it to
    /// the backend.
    @usableFromInline var buffer = ByteBuffer()

    /// Once `buffer` exceeds this size, it gets flushed to the backend.
    @usableFromInline let bufferSize: Int

    init(underlying: PostgresCopyFromWriter, bufferSize: Int) {
        self.underlying = underlying
        // Allocate 10% more than the buffer size because we only flush the buffer once it has exceeded `bufferSize`
        buffer.reserveCapacity(bufferSize + bufferSize / 10)
        self.bufferSize = bufferSize
    }

    /// Serialize a single row to the backend. Call `writeColumn` on `columnWriter` for every column that should be
    /// included in the row.
    ///
    /// The row is accumulated in a local buffer, which is flushed to the backend once it exceeds `bufferSize`.
    ///
    /// - Parameter body: Closure that writes the columns of the row. If it throws, the partially written row is
    ///   discarded from the local buffer and the error is rethrown.
    /// - Throws: Any error thrown by `body` or by flushing the buffer to the backend.
    @inlinable
    public mutating func writeRow(_ body: (_ columnWriter: inout ColumnWriter) throws -> Void) async throws {
        // Write a placeholder for the number of columns. It is filled in once `body` has written all columns.
        let rowStartIndex = buffer.writerIndex
        buffer.writeInteger(UInt16(0))

        let columns: UInt16
        do {
            columns = try withUnsafeMutablePointer(to: &self) { pointerToSelf in
                // Important: We need to ensure that `pointerToSelf` (and thus `ColumnWriter`) does not exceed the
                // lifetime of `self` because it is holding an unsafe reference to it.
                //
                // We achieve this because `ColumnWriter` is non-Copyable and thus the client can't store a copy to
                // it. Furthermore `columnWriter` is destroyed before the end of `withUnsafeMutablePointer`, which
                // holds `self` alive.
                var columnWriter = ColumnWriter(underlying: pointerToSelf)

                try body(&columnWriter)

                return columnWriter.columns
            }
        } catch {
            // Drop the incomplete row so that a caller that recovers from the error and continues writing does
            // not send a truncated row to the backend.
            buffer.moveWriterIndex(to: rowStartIndex)
            throw error
        }

        // Fill in the number of columns
        buffer.setInteger(columns, at: rowStartIndex)

        if buffer.readableBytes > bufferSize {
            try await flush()
        }
    }

    /// Serialize a single column to the buffer. Should only be called by `ColumnWriter`.
    ///
    /// A non-nil column is written as a 32-bit byte length followed by the encoded value. A `nil` column is written
    /// as a length of `-1` with no value bytes.
    @inlinable
    mutating func writeColumn(_ column: (some PostgresEncodable)?) throws {
        if let column {
            // `setInteger(_:at:)` takes an absolute index into the buffer, so the position of the length prefix
            // must be tracked via `writerIndex` (not `readableBytes`, which is relative to `readerIndex`).
            let sizeIndex = buffer.writerIndex
            buffer.writeInteger(Int32(0))
            try column.encode(into: &buffer, context: .default)
            // The length excludes the 4 bytes of the length prefix itself.
            buffer.setInteger(Int32(buffer.writerIndex - sizeIndex - 4), at: sizeIndex)
        } else {
            buffer.writeInteger(Int32(-1))
        }
    }

    /// Flush any pending data in the buffer to the backend.
    @usableFromInline
    mutating func flush(isolation: (any Actor)? = #isolation) async throws {
        try await underlying.write(buffer)
        buffer.clear()
    }
}
106+
#endif
107+
1108
/// Handle to send data for a `COPY ... FROM STDIN` query to the backend.
2109
public struct PostgresCopyFromWriter: Sendable {
3110
private let channelHandler: NIOLoopBound<PostgresChannelHandler>
@@ -115,15 +222,25 @@ public struct PostgresCopyFromFormat: Sendable {
115222
public init() {}
116223
}
117224

225+
/// Options that can be used to modify the `binary` format of a COPY operation.
226+
public struct BinaryOptions: Sendable {
227+
public init() {}
228+
}
229+
118230
enum Format {
119231
case text(TextOptions)
232+
case binary(BinaryOptions)
120233
}
121234

122235
var format: Format
123236

124237
public static func text(_ options: TextOptions) -> PostgresCopyFromFormat {
125238
return PostgresCopyFromFormat(format: .text(options))
126239
}
240+
241+
/// Creates a `PostgresCopyFromFormat` that transfers the data of a COPY operation in the `binary` format.
///
/// - Parameter options: Options that modify the `binary` format.
public static func binary(_ options: BinaryOptions) -> PostgresCopyFromFormat {
    PostgresCopyFromFormat(format: .binary(options))
}
127244
}
128245

129246
#if compiler(>=6.0)
@@ -156,6 +273,8 @@ private func buildCopyFromQuery(
156273
// Set the delimiter as a Unicode code point. This avoids the possibility of SQL injection.
157274
queryOptions.append("DELIMITER U&'\\\(String(format: "%04x", delimiter.value))'")
158275
}
276+
case .binary:
277+
queryOptions.append("FORMAT binary")
159278
}
160279
precondition(!queryOptions.isEmpty)
161280
query += " WITH ("
@@ -165,6 +284,50 @@ private func buildCopyFromQuery(
165284
}
166285

167286
extension PostgresConnection {
287+
/// Copy data into a table using a `COPY <table name> FROM STDIN` query, transferring data in a binary format.
///
/// - Parameters:
///   - table: The name of the table into which to copy the data.
///   - columns: The name of the columns to copy. If an empty array is passed, all columns are assumed to be copied.
///   - options: Options that modify the `binary` format of the COPY operation.
///   - bufferSize: How many bytes to accumulate in a local buffer before flushing it to the database. Can affect
///     performance characteristics of the copy operation.
///   - logger: The logger that is passed on to the underlying `copyFrom` call.
///   - writeData: Closure that produces the data for the table, to be streamed to the backend. Call `writeRow` on
///     the writer provided by the closure for every row that should be sent to the backend and return from the
///     closure once all data is sent. Throw an error from the closure to fail the data transfer. The error thrown
///     by the closure will be rethrown by the `copyFromBinary` function.
///
/// - Important: The table and column names are inserted into the `COPY FROM` query as passed and might thus be
///   susceptible to SQL injection. Ensure no untrusted data is contained in these strings.
public func copyFromBinary(
    table: String,
    columns: [String] = [],
    options: PostgresCopyFromFormat.BinaryOptions = .init(),
    bufferSize: Int = 100_000,
    logger: Logger,
    isolation: isolated (any Actor)? = #isolation,
    file: String = #fileID,
    line: Int = #line,
    writeData: @escaping @Sendable (inout PostgresBinaryCopyFromWriter) async throws -> Void
) async throws {
    // NOTE(review): `file` and `line` are not forwarded to `copyFrom` here; confirm whether they should be.
    try await copyFrom(table: table, columns: columns, format: .binary(options), logger: logger) { writer in
        // The binary COPY stream starts with a fixed header: the 11-byte signature `PGCOPY\n\377\r\n\0`,
        // followed by a 32-bit flags field and the 32-bit length of the header extension area.
        var header = ByteBuffer()
        header.writeString("PGCOPY\n")
        header.writeInteger(UInt8(0xff))
        header.writeString("\r\n\0")

        // Flag fields
        header.writeInteger(UInt32(0))

        // Header extension area length
        header.writeInteger(UInt32(0))
        try await writer.write(header)

        var binaryWriter = PostgresBinaryCopyFromWriter(underlying: writer, bufferSize: bufferSize)
        try await writeData(&binaryWriter)
        // Send whatever is still buffered; rows are only flushed automatically once the buffer exceeds `bufferSize`.
        try await binaryWriter.flush()
    }
}
330+
168331
/// Copy data into a table using a `COPY <table name> FROM STDIN` query.
169332
///
170333
/// - Parameters:

Tests/IntegrationTests/PSQLIntegrationTests.swift

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -488,5 +488,40 @@ final class IntegrationTests: XCTestCase {
488488
XCTAssertEqual((error as? PSQLError)?.serverInfo?[.sqlState], "42601") // scanner_yyerror
489489
}
490490
}
491+
492+
/// Copies two rows into a freshly created table using the binary COPY format and verifies that they can be read
/// back from the database.
func testCopyFromBinary() async throws {
    let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 2)
    defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) }
    let eventLoop = eventLoopGroup.next()

    let conn = try await PostgresConnection.test(on: eventLoop).get()
    defer { XCTAssertNoThrow(try conn.close().wait()) }

    // The table might not exist yet, so a failure to drop it is deliberately ignored.
    _ = try? await conn.query("DROP TABLE copy_table", logger: .psqlTest).get()
    _ = try await conn.query("CREATE TABLE copy_table (id INT, name VARCHAR(100))", logger: .psqlTest).get()

    try await conn.copyFromBinary(table: "copy_table", columns: ["id", "name"], logger: .psqlTest) { writer in
        let records: [(id: Int, name: String)] = [
            (1, "Alice"),
            (42, "Bob")
        ]
        for record in records {
            try await writer.writeRow { columnWriter in
                // `id` is an `INT` column, so it must be serialized as a 4-byte integer.
                try columnWriter.writeColumn(Int32(record.id))
                try columnWriter.writeColumn(record.name)
            }
        }
    }
    // Order by `id` so that the positional assertions below do not depend on the storage order of the rows.
    let rows = try await conn.query("SELECT id, name FROM copy_table ORDER BY id").get().rows.map { try $0.decode((Int, String).self) }
    guard rows.count == 2 else {
        XCTFail("Expected 2 rows, received \(rows.count)")
        return
    }
    XCTAssertEqual(rows[0].0, 1)
    XCTAssertEqual(rows[0].1, "Alice")
    XCTAssertEqual(rows[1].0, 42)
    XCTAssertEqual(rows[1].1, "Bob")
}
525+
491526
#endif
492527
}

Tests/PostgresNIOTests/New/PostgresConnectionTests.swift

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -912,6 +912,68 @@ class PostgresConnectionTests: XCTestCase {
912912
try await connection.closeFuture.get()
913913
}
914914
}
915+
916+
/// Verifies the bytes that `copyFromBinary` sends to the backend: the binary COPY header followed by the
/// serialized rows.
func testCopyFromBinary() async throws {
    let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel()

    try await withThrowingTaskGroup(of: Void.self) { taskGroup async throws -> () in
        // Run the copy operation in a child task while this task plays the role of the backend.
        taskGroup.addTask {
            try await connection.copyFromBinary(table: "copy_table", logger: .psqlTest) { writer in
                try await writer.writeRow { columnWriter in
                    try columnWriter.writeColumn(Int32(1))
                    try columnWriter.writeColumn("Alice")
                }
                try await writer.writeRow { columnWriter in
                    try columnWriter.writeColumn(Int32(2))
                    try columnWriter.writeColumn("Bob")
                }
            }
        }

        let copyRequest = try await channel.waitForUnpreparedRequest()
        XCTAssertEqual(copyRequest.parse.query, #"COPY "copy_table" FROM STDIN WITH (FORMAT binary)"#)

        try await channel.sendUnpreparedRequestWithNoParametersBindResponse()
        try await channel.writeInbound(PostgresBackendMessage.copyInResponse(.init(format: .binary, columnFormats: [.binary, .binary])))

        let copyData = try await channel.waitForCopyData()
        XCTAssertEqual(copyData.result, .done)
        var data = copyData.data
        // Signature
        XCTAssertEqual(data.readString(length: 7), "PGCOPY\n")
        XCTAssertEqual(data.readInteger(as: UInt8.self), 0xff)
        XCTAssertEqual(data.readString(length: 3), "\r\n\0")
        // Flags
        XCTAssertEqual(data.readInteger(as: UInt32.self), 0)
        // Header extension area length
        XCTAssertEqual(data.readInteger(as: UInt32.self), 0)

        struct Row: Equatable {
            let id: Int32
            let name: String
        }
        var rows: [Row] = []
        // Everything after the header is a sequence of rows; parse until the data is exhausted.
        while data.readableBytes > 0 {
            // Number of columns
            XCTAssertEqual(data.readInteger(as: UInt16.self), 2)
            // 'id' column
            XCTAssertEqual(data.readInteger(as: UInt32.self), 4)
            let id = try XCTUnwrap(data.readInteger(as: Int32.self))
            // 'name' column length
            let nameLength = try XCTUnwrap(data.readInteger(as: UInt32.self))
            let name = try XCTUnwrap(data.readString(length: Int(nameLength)))
            rows.append(Row(id: id, name: name))
        }
        XCTAssertEqual(rows, [
            Row(id: 1, name: "Alice"),
            Row(id: 2, name: "Bob")
        ])
        // Two rows were sent, so the mocked backend reports two copied rows.
        try await channel.writeInbound(PostgresBackendMessage.commandComplete("COPY 2"))

        try await channel.waitForPostgresFrontendMessage(\.sync)
        try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle))
    }
}
915977
#endif
916978

917979
func makeTestConnectionWithAsyncTestingChannel() async throws -> (PostgresConnection, NIOAsyncTestingChannel) {

0 commit comments

Comments
 (0)