Skip to content

Commit 0679ede

Browse files
authored
Fix prepared statements (vapor#455)
1 parent 85d189c commit 0679ede

11 files changed

+150
-34
lines changed

Diff for: Sources/PostgresNIO/Connection/PostgresConnection.swift

+6-3
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,7 @@ public final class PostgresConnection: @unchecked Sendable {
234234
let context = ExtendedQueryContext(
235235
name: name,
236236
query: query,
237+
bindingDataTypes: [],
237238
logger: logger,
238239
promise: promise
239240
)
@@ -472,9 +473,10 @@ extension PostgresConnection {
472473
let bindings = try preparedStatement.makeBindings()
473474
let promise = self.channel.eventLoop.makePromise(of: PSQLRowStream.self)
474475
let task = HandlerTask.executePreparedStatement(.init(
475-
name: String(reflecting: Statement.self),
476+
name: Statement.name,
476477
sql: Statement.sql,
477478
bindings: bindings,
479+
bindingDataTypes: Statement.bindingDataTypes,
478480
logger: logger,
479481
promise: promise
480482
))
@@ -493,10 +495,10 @@ extension PostgresConnection {
493495
)
494496
throw error // rethrow with more metadata
495497
}
496-
497498
}
498499

499500
/// Execute a prepared statement, taking care of the preparation when necessary
501+
@_disfavoredOverload
500502
public func execute<Statement: PostgresPreparedStatement>(
501503
_ preparedStatement: Statement,
502504
logger: Logger,
@@ -506,9 +508,10 @@ extension PostgresConnection {
506508
let bindings = try preparedStatement.makeBindings()
507509
let promise = self.channel.eventLoop.makePromise(of: PSQLRowStream.self)
508510
let task = HandlerTask.executePreparedStatement(.init(
509-
name: String(reflecting: Statement.self),
511+
name: Statement.name,
510512
sql: Statement.sql,
511513
bindings: bindings,
514+
bindingDataTypes: Statement.bindingDataTypes,
512515
logger: logger,
513516
promise: promise
514517
))

Diff for: Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift

+4-4
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ struct ConnectionStateMachine {
9797
case forwardStreamError(PSQLError, read: Bool, cleanupContext: CleanUpContext?)
9898

9999
// Prepare statement actions
100-
case sendParseDescribeSync(name: String, query: String)
100+
case sendParseDescribeSync(name: String, query: String, bindingDataTypes: [PostgresDataType])
101101
case succeedPreparedStatementCreation(EventLoopPromise<RowDescription?>, with: RowDescription?)
102102
case failPreparedStatementCreation(EventLoopPromise<RowDescription?>, with: PSQLError, cleanupContext: CleanUpContext?)
103103

@@ -587,7 +587,7 @@ struct ConnectionStateMachine {
587587
switch queryContext.query {
588588
case .executeStatement(_, let promise), .unnamed(_, let promise):
589589
return .failQuery(promise, with: psqlErrror, cleanupContext: nil)
590-
case .prepareStatement(_, _, let promise):
590+
case .prepareStatement(_, _, _, let promise):
591591
return .failPreparedStatementCreation(promise, with: psqlErrror, cleanupContext: nil)
592592
}
593593
case .closeCommand(let closeContext):
@@ -1057,8 +1057,8 @@ extension ConnectionStateMachine {
10571057
return .read
10581058
case .wait:
10591059
return .wait
1060-
case .sendParseDescribeSync(name: let name, query: let query):
1061-
return .sendParseDescribeSync(name: name, query: query)
1060+
case .sendParseDescribeSync(name: let name, query: let query, bindingDataTypes: let bindingDataTypes):
1061+
return .sendParseDescribeSync(name: name, query: query, bindingDataTypes: bindingDataTypes)
10621062
case .succeedPreparedStatementCreation(let promise, with: let rowDescription):
10631063
return .succeedPreparedStatementCreation(promise, with: rowDescription)
10641064
case .failPreparedStatementCreation(let promise, with: let error):

Diff for: Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift

+7-7
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ struct ExtendedQueryStateMachine {
2626

2727
enum Action {
2828
case sendParseDescribeBindExecuteSync(PostgresQuery)
29-
case sendParseDescribeSync(name: String, query: String)
29+
case sendParseDescribeSync(name: String, query: String, bindingDataTypes: [PostgresDataType])
3030
case sendBindExecuteSync(PSQLExecuteStatement)
3131

3232
// --- general actions
@@ -79,10 +79,10 @@ struct ExtendedQueryStateMachine {
7979
return .sendBindExecuteSync(prepared)
8080
}
8181

82-
case .prepareStatement(let name, let query, _):
82+
case .prepareStatement(let name, let query, let bindingDataTypes, _):
8383
return self.avoidingStateMachineCoW { state -> Action in
8484
state = .messagesSent(queryContext)
85-
return .sendParseDescribeSync(name: name, query: query)
85+
return .sendParseDescribeSync(name: name, query: query, bindingDataTypes: bindingDataTypes)
8686
}
8787
}
8888
}
@@ -107,7 +107,7 @@ struct ExtendedQueryStateMachine {
107107
case .unnamed(_, let eventLoopPromise), .executeStatement(_, let eventLoopPromise):
108108
return .failQuery(eventLoopPromise, with: .queryCancelled)
109109

110-
case .prepareStatement(_, _, let eventLoopPromise):
110+
case .prepareStatement(_, _, _, let eventLoopPromise):
111111
return .failPreparedStatementCreation(eventLoopPromise, with: .queryCancelled)
112112
}
113113

@@ -165,7 +165,7 @@ struct ExtendedQueryStateMachine {
165165
return .wait
166166
}
167167

168-
case .prepareStatement(_, _, let promise):
168+
case .prepareStatement(_, _, _, let promise):
169169
return self.avoidingStateMachineCoW { state -> Action in
170170
state = .noDataMessageReceived(queryContext)
171171
return .succeedPreparedStatementCreation(promise, with: nil)
@@ -200,7 +200,7 @@ struct ExtendedQueryStateMachine {
200200
case .unnamed, .executeStatement:
201201
return .wait
202202

203-
case .prepareStatement(_, _, let eventLoopPromise):
203+
case .prepareStatement(_, _, _, let eventLoopPromise):
204204
return .succeedPreparedStatementCreation(eventLoopPromise, with: rowDescription)
205205
}
206206
}
@@ -477,7 +477,7 @@ struct ExtendedQueryStateMachine {
477477
switch context.query {
478478
case .unnamed(_, let eventLoopPromise), .executeStatement(_, let eventLoopPromise):
479479
return .failQuery(eventLoopPromise, with: error)
480-
case .prepareStatement(_, _, let eventLoopPromise):
480+
case .prepareStatement(_, _, _, let eventLoopPromise):
481481
return .failPreparedStatementCreation(eventLoopPromise, with: error)
482482
}
483483
}

Diff for: Sources/PostgresNIO/New/PSQLTask.swift

+11-3
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ enum PSQLTask {
2121
eventLoopPromise.fail(error)
2222
case .executeStatement(_, let eventLoopPromise):
2323
eventLoopPromise.fail(error)
24-
case .prepareStatement(_, _, let eventLoopPromise):
24+
case .prepareStatement(_, _, _, let eventLoopPromise):
2525
eventLoopPromise.fail(error)
2626
}
2727

@@ -35,7 +35,7 @@ final class ExtendedQueryContext {
3535
enum Query {
3636
case unnamed(PostgresQuery, EventLoopPromise<PSQLRowStream>)
3737
case executeStatement(PSQLExecuteStatement, EventLoopPromise<PSQLRowStream>)
38-
case prepareStatement(name: String, query: String, EventLoopPromise<RowDescription?>)
38+
case prepareStatement(name: String, query: String, bindingDataTypes: [PostgresDataType], EventLoopPromise<RowDescription?>)
3939
}
4040

4141
let query: Query
@@ -62,17 +62,19 @@ final class ExtendedQueryContext {
6262
init(
6363
name: String,
6464
query: String,
65+
bindingDataTypes: [PostgresDataType],
6566
logger: Logger,
6667
promise: EventLoopPromise<RowDescription?>
6768
) {
68-
self.query = .prepareStatement(name: name, query: query, promise)
69+
self.query = .prepareStatement(name: name, query: query, bindingDataTypes: bindingDataTypes, promise)
6970
self.logger = logger
7071
}
7172
}
7273

7374
final class PreparedStatementContext: Sendable {
7475
let name: String
7576
let sql: String
77+
let bindingDataTypes: [PostgresDataType]
7678
let bindings: PostgresBindings
7779
let logger: Logger
7880
let promise: EventLoopPromise<PSQLRowStream>
@@ -81,12 +83,18 @@ final class PreparedStatementContext: Sendable {
8183
name: String,
8284
sql: String,
8385
bindings: PostgresBindings,
86+
bindingDataTypes: [PostgresDataType],
8487
logger: Logger,
8588
promise: EventLoopPromise<PSQLRowStream>
8689
) {
8790
self.name = name
8891
self.sql = sql
8992
self.bindings = bindings
93+
if bindingDataTypes.isEmpty {
94+
self.bindingDataTypes = bindings.metadata.map(\.dataType)
95+
} else {
96+
self.bindingDataTypes = bindingDataTypes
97+
}
9098
self.logger = logger
9199
self.promise = promise
92100
}

Diff for: Sources/PostgresNIO/New/PostgresChannelHandler.swift

+6-4
Original file line numberDiff line numberDiff line change
@@ -345,8 +345,8 @@ final class PostgresChannelHandler: ChannelDuplexHandler {
345345
self.closeConnectionAndCleanup(cleanupContext, context: context)
346346
case .fireChannelInactive:
347347
context.fireChannelInactive()
348-
case .sendParseDescribeSync(let name, let query):
349-
self.sendParseDecribeAndSyncMessage(statementName: name, query: query, context: context)
348+
case .sendParseDescribeSync(let name, let query, let bindingDataTypes):
349+
self.sendParseDescribeAndSyncMessage(statementName: name, query: query, bindingDataTypes: bindingDataTypes, context: context)
350350
case .sendBindExecuteSync(let executeStatement):
351351
self.sendBindExecuteAndSyncMessage(executeStatement: executeStatement, context: context)
352352
case .sendParseDescribeBindExecuteSync(let query):
@@ -489,13 +489,14 @@ final class PostgresChannelHandler: ChannelDuplexHandler {
489489
}
490490
}
491491

492-
private func sendParseDecribeAndSyncMessage(
492+
private func sendParseDescribeAndSyncMessage(
493493
statementName: String,
494494
query: String,
495+
bindingDataTypes: [PostgresDataType],
495496
context: ChannelHandlerContext
496497
) {
497498
precondition(self.rowStream == nil, "Expected to not have an open stream at this point")
498-
self.encoder.parse(preparedStatementName: statementName, query: query, parameters: [])
499+
self.encoder.parse(preparedStatementName: statementName, query: query, parameters: bindingDataTypes)
499500
self.encoder.describePreparedStatement(statementName)
500501
self.encoder.sync()
501502
context.writeAndFlush(self.wrapOutboundOut(self.encoder.flushBuffer()), promise: nil)
@@ -724,6 +725,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler {
724725
return .extendedQuery(.init(
725726
name: preparedStatement.name,
726727
query: preparedStatement.sql,
728+
bindingDataTypes: preparedStatement.bindingDataTypes,
727729
logger: preparedStatement.logger,
728730
promise: promise
729731
))

Diff for: Sources/PostgresNIO/New/PreparedStatement.swift

+22-1
Original file line numberDiff line numberDiff line change
@@ -26,15 +26,36 @@
2626
/// Structs conforming to this protocol can then be used with `PostgresConnection.execute(_ preparedStatement:, logger:)`,
2727
/// which will take care of preparing the statement on the server side and executing it.
2828
public protocol PostgresPreparedStatement: Sendable {
29+
/// The prepared statements name.
30+
///
31+
/// > Note: There is a default implementation that returns the implementor's name.
32+
static var name: String { get }
33+
2934
/// The type rows returned by the statement will be decoded into
3035
associatedtype Row
3136

3237
/// The SQL statement to prepare on the database server.
3338
static var sql: String { get }
3439

35-
/// Make the bindings to provided concrete values to use when executing the prepared SQL statement
40+
/// The postgres data types of the values that are bind when this statement is executed.
41+
///
42+
/// If an empty array is returned the datatypes are inferred from the ``PostgresBindings`` returned
43+
/// from ``PostgresPreparedStatement/makeBindings()``.
44+
///
45+
/// > Note: There is a default implementation that returns an empty array, which will lead to
46+
/// automatic inference.
47+
static var bindingDataTypes: [PostgresDataType] { get }
48+
49+
/// Make the bindings to provided concrete values to use when executing the prepared SQL statement.
50+
/// The order must match ``PostgresPreparedStatement/bindingDataTypes-4b6tx``.
3651
func makeBindings() throws -> PostgresBindings
3752

3853
/// Decode a row returned by the database into an instance of `Row`
3954
func decodeRow(_ row: PostgresRow) throws -> Row
4055
}
56+
57+
extension PostgresPreparedStatement {
58+
public static var name: String { String(reflecting: self) }
59+
60+
public static var bindingDataTypes: [PostgresDataType] { [] }
61+
}

Diff for: Tests/IntegrationTests/AsyncTests.swift

+81
Original file line numberDiff line numberDiff line change
@@ -358,6 +358,87 @@ final class AsyncPostgresConnectionTests: XCTestCase {
358358
}
359359
}
360360
}
361+
362+
static let preparedStatementTestTable = "AsyncTestPreparedStatementTestTable"
363+
func testPreparedStatementWithIntegerBinding() async throws {
364+
let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
365+
defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) }
366+
let eventLoop = eventLoopGroup.next()
367+
368+
struct InsertPreparedStatement: PostgresPreparedStatement {
369+
static let name = "INSERT-AsyncTestPreparedStatementTestTable"
370+
371+
static let sql = #"INSERT INTO "\#(AsyncPostgresConnectionTests.preparedStatementTestTable)" (uuid) VALUES ($1);"#
372+
typealias Row = ()
373+
374+
var uuid: UUID
375+
376+
func makeBindings() -> PostgresBindings {
377+
var bindings = PostgresBindings()
378+
bindings.append(self.uuid)
379+
return bindings
380+
}
381+
382+
func decodeRow(_ row: PostgresNIO.PostgresRow) throws -> Row {
383+
()
384+
}
385+
}
386+
387+
struct SelectPreparedStatement: PostgresPreparedStatement {
388+
static let name = "SELECT-AsyncTestPreparedStatementTestTable"
389+
390+
static let sql = #"SELECT id, uuid FROM "\#(AsyncPostgresConnectionTests.preparedStatementTestTable)" WHERE id <= $1;"#
391+
typealias Row = (Int, UUID)
392+
393+
var id: Int
394+
395+
func makeBindings() -> PostgresBindings {
396+
var bindings = PostgresBindings()
397+
bindings.append(self.id)
398+
return bindings
399+
}
400+
401+
func decodeRow(_ row: PostgresNIO.PostgresRow) throws -> Row {
402+
try row.decode((Int, UUID).self)
403+
}
404+
}
405+
406+
do {
407+
try await withTestConnection(on: eventLoop) { connection in
408+
try await connection.query("""
409+
CREATE TABLE IF NOT EXISTS "\(unescaped: Self.preparedStatementTestTable)" (
410+
id SERIAL PRIMARY KEY,
411+
uuid UUID NOT NULL
412+
)
413+
""",
414+
logger: .psqlTest
415+
)
416+
417+
_ = try await connection.execute(InsertPreparedStatement(uuid: .init()), logger: .psqlTest)
418+
_ = try await connection.execute(InsertPreparedStatement(uuid: .init()), logger: .psqlTest)
419+
_ = try await connection.execute(InsertPreparedStatement(uuid: .init()), logger: .psqlTest)
420+
_ = try await connection.execute(InsertPreparedStatement(uuid: .init()), logger: .psqlTest)
421+
_ = try await connection.execute(InsertPreparedStatement(uuid: .init()), logger: .psqlTest)
422+
423+
let rows = try await connection.execute(SelectPreparedStatement(id: 3), logger: .psqlTest)
424+
var counter = 0
425+
for try await (id, uuid) in rows {
426+
Logger.psqlTest.info("Received row", metadata: [
427+
"id": "\(id)", "uuid": "\(uuid)"
428+
])
429+
counter += 1
430+
}
431+
432+
try await connection.query("""
433+
DROP TABLE "\(unescaped: Self.preparedStatementTestTable)";
434+
""",
435+
logger: .psqlTest
436+
)
437+
}
438+
} catch {
439+
XCTFail("Unexpected error: \(String(describing: error))")
440+
}
441+
}
361442
}
362443

363444
extension XCTestCase {

Diff for: Tests/PostgresNIOTests/New/Connection State Machine/PrepareStatementStateMachineTests.swift

+6-6
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,11 @@ class PrepareStatementStateMachineTests: XCTestCase {
1212
let name = "haha"
1313
let query = #"SELECT id FROM users WHERE id = $1 "#
1414
let prepareStatementContext = ExtendedQueryContext(
15-
name: name, query: query, logger: .psqlTest, promise: promise
15+
name: name, query: query, bindingDataTypes: [], logger: .psqlTest, promise: promise
1616
)
1717

1818
XCTAssertEqual(state.enqueue(task: .extendedQuery(prepareStatementContext)),
19-
.sendParseDescribeSync(name: name, query: query))
19+
.sendParseDescribeSync(name: name, query: query, bindingDataTypes: []))
2020
XCTAssertEqual(state.parseCompleteReceived(), .wait)
2121
XCTAssertEqual(state.parameterDescriptionReceived(.init(dataTypes: [.int8])), .wait)
2222

@@ -38,11 +38,11 @@ class PrepareStatementStateMachineTests: XCTestCase {
3838
let name = "haha"
3939
let query = #"DELETE FROM users WHERE id = $1 "#
4040
let prepareStatementContext = ExtendedQueryContext(
41-
name: name, query: query, logger: .psqlTest, promise: promise
41+
name: name, query: query, bindingDataTypes: [], logger: .psqlTest, promise: promise
4242
)
4343

4444
XCTAssertEqual(state.enqueue(task: .extendedQuery(prepareStatementContext)),
45-
.sendParseDescribeSync(name: name, query: query))
45+
.sendParseDescribeSync(name: name, query: query, bindingDataTypes: []))
4646
XCTAssertEqual(state.parseCompleteReceived(), .wait)
4747
XCTAssertEqual(state.parameterDescriptionReceived(.init(dataTypes: [.int8])), .wait)
4848

@@ -60,11 +60,11 @@ class PrepareStatementStateMachineTests: XCTestCase {
6060
let name = "haha"
6161
let query = #"DELETE FROM users WHERE id = $1 "#
6262
let prepareStatementContext = ExtendedQueryContext(
63-
name: name, query: query, logger: .psqlTest, promise: promise
63+
name: name, query: query, bindingDataTypes: [], logger: .psqlTest, promise: promise
6464
)
6565

6666
XCTAssertEqual(state.enqueue(task: .extendedQuery(prepareStatementContext)),
67-
.sendParseDescribeSync(name: name, query: query))
67+
.sendParseDescribeSync(name: name, query: query, bindingDataTypes: []))
6868
XCTAssertEqual(state.parseCompleteReceived(), .wait)
6969
XCTAssertEqual(state.parameterDescriptionReceived(.init(dataTypes: [.int8])), .wait)
7070

0 commit comments

Comments
 (0)