Skip to content

Commit 17b23b1

Browse files
authored
Adds prepared statement support to client (vapor#459)
1 parent 0679ede commit 17b23b1

File tree

3 files changed

+121
-4
lines changed

3 files changed

+121
-4
lines changed

Sources/PostgresNIO/Pool/PostgresClient.swift

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -342,6 +342,48 @@ public final class PostgresClient: Sendable {
342342
}
343343
}
344344

345+
/// Execute a prepared statement, taking care of the preparation when necessary
346+
public func execute<Statement: PostgresPreparedStatement, Row>(
347+
_ preparedStatement: Statement,
348+
logger: Logger,
349+
file: String = #fileID,
350+
line: Int = #line
351+
) async throws -> AsyncThrowingMapSequence<PostgresRowSequence, Row> where Row == Statement.Row {
352+
let bindings = try preparedStatement.makeBindings()
353+
354+
do {
355+
let connection = try await self.leaseConnection()
356+
357+
let promise = connection.channel.eventLoop.makePromise(of: PSQLRowStream.self)
358+
let task = HandlerTask.executePreparedStatement(.init(
359+
name: String(reflecting: Statement.self),
360+
sql: Statement.sql,
361+
bindings: bindings,
362+
bindingDataTypes: Statement.bindingDataTypes,
363+
logger: logger,
364+
promise: promise
365+
))
366+
connection.channel.write(task, promise: nil)
367+
368+
promise.futureResult.whenFailure { _ in
369+
self.pool.releaseConnection(connection)
370+
}
371+
372+
return try await promise.futureResult
373+
.map { $0.asyncSequence(onFinish: { self.pool.releaseConnection(connection) }) }
374+
.get()
375+
.map { try preparedStatement.decodeRow($0) }
376+
} catch var error as PSQLError {
377+
error.file = file
378+
error.line = line
379+
error.query = .init(
380+
unsafeSQL: Statement.sql,
381+
binds: bindings
382+
)
383+
throw error // rethrow with more metadata
384+
}
385+
}
386+
345387
/// The client's run method. Users must call this function in order to start the client's background task processing
346388
/// like creating and destroying connections and running timers.
347389
///

Tests/IntegrationTests/PostgresClientTests.swift

Lines changed: 78 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,16 +25,17 @@ final class PostgresClientTests: XCTestCase {
2525
await client.run()
2626
}
2727

28-
for i in 0..<10000 {
28+
let iterations = 1000
29+
30+
for i in 0..<iterations {
2931
taskGroup.addTask {
3032
try await client.withConnection() { connection in
3133
_ = try await connection.query("SELECT 1", logger: logger)
3234
}
33-
print("done: \(i)")
3435
}
3536
}
3637

37-
for _ in 0..<10000 {
38+
for _ in 0..<iterations {
3839
_ = await taskGroup.nextResult()!
3940
}
4041

@@ -78,6 +79,80 @@ final class PostgresClientTests: XCTestCase {
7879
}
7980
}
8081

82+
func testQueryTable() async throws {
83+
let tableName = "test_client_prepared_statement"
84+
85+
var mlogger = Logger(label: "test")
86+
mlogger.logLevel = .debug
87+
let logger = mlogger
88+
let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 8)
89+
self.addTeardownBlock {
90+
try await eventLoopGroup.shutdownGracefully()
91+
}
92+
93+
let clientConfig = PostgresClient.Configuration.makeTestConfiguration()
94+
let client = PostgresClient(configuration: clientConfig, eventLoopGroup: eventLoopGroup, backgroundLogger: logger)
95+
do {
96+
try await withThrowingTaskGroup(of: Void.self) { taskGroup in
97+
taskGroup.addTask {
98+
await client.run()
99+
}
100+
101+
try await client.query(
102+
"""
103+
CREATE TABLE IF NOT EXISTS "\(unescaped: tableName)" (
104+
id SERIAL PRIMARY KEY,
105+
uuid UUID NOT NULL
106+
);
107+
""",
108+
logger: logger
109+
)
110+
111+
for _ in 0..<1000 {
112+
try await client.query(
113+
"""
114+
INSERT INTO "\(unescaped: tableName)" (uuid) VALUES (\(UUID()));
115+
""",
116+
logger: logger
117+
)
118+
}
119+
120+
let rows = try await client.query(#"SELECT id, uuid FROM "\#(unescaped: tableName)";"#, logger: logger).decode((Int, UUID).self)
121+
for try await (id, uuid) in rows {
122+
logger.info("id: \(id), uuid: \(uuid.uuidString)")
123+
}
124+
125+
struct Example: PostgresPreparedStatement {
126+
static let sql = "SELECT id, uuid FROM test_client_prepared_statement WHERE id < $1"
127+
typealias Row = (Int, UUID)
128+
var id: Int
129+
func makeBindings() -> PostgresBindings {
130+
var bindings = PostgresBindings()
131+
bindings.append(self.id)
132+
return bindings
133+
}
134+
func decodeRow(_ row: PostgresNIO.PostgresRow) throws -> Row {
135+
try row.decode(Row.self)
136+
}
137+
}
138+
139+
for try await (id, uuid) in try await client.execute(Example(id: 200), logger: logger) {
140+
logger.info("id: \(id), uuid: \(uuid.uuidString)")
141+
}
142+
143+
try await client.query(
144+
"""
145+
DROP TABLE "\(unescaped: tableName)";
146+
""",
147+
logger: logger
148+
)
149+
150+
taskGroup.cancelAll()
151+
}
152+
} catch {
153+
XCTFail("Unexpected error: \(String(reflecting: error))")
154+
}
155+
}
81156
}
82157

83158
@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *)

Tests/PostgresNIOTests/New/PostgresConnectionTests.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ class PostgresConnectionTests: XCTestCase {
155155
_ = try await iterator.next()
156156
XCTFail("Did not expect to not throw")
157157
} catch {
158-
print(error)
158+
self.logger.error("error", metadata: ["error": "\(error)"])
159159
}
160160
}
161161

0 commit comments

Comments
 (0)