diff --git a/Sources/PostgresNIO/Connection/PostgresConnection+Database.swift b/Sources/PostgresNIO/Connection/PostgresConnection+Database.swift index 73d6a206..9c6ce553 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection+Database.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection+Database.swift @@ -29,7 +29,7 @@ final class PostgresRequestContext { } } -final class PostgresRequestHandler: ChannelDuplexHandler { +class PostgresRequestHandler: ChannelDuplexHandler { typealias InboundIn = PostgresMessage typealias OutboundIn = PostgresRequestContext typealias OutboundOut = PostgresMessage diff --git a/Sources/PostgresNIO/Connection/PostgresDatabase+Close.swift b/Sources/PostgresNIO/Connection/PostgresDatabase+Close.swift new file mode 100644 index 00000000..881f98c3 --- /dev/null +++ b/Sources/PostgresNIO/Connection/PostgresDatabase+Close.swift @@ -0,0 +1,34 @@ +import NIO + + +/// PostgreSQL request to close a prepared statement or portal. +final class CloseRequest: PostgresRequest { + + /// Name of the prepared statement or portal to close. + let name: String + + /// Close + let target: PostgresMessage.Close.Target + + init(name: String, closeType: PostgresMessage.Close.Target) { + self.name = name + self.target = closeType + } + + func respond(to message: PostgresMessage) throws -> [PostgresMessage]? { + if message.identifier != .closeComplete { + fatalError("Unexpected PostgreSQL message \(message)") + } + return nil + } + + func start() throws -> [PostgresMessage] { + let close = try PostgresMessage.Close(target: target, name: name).message() + let sync = try PostgresMessage.Sync().message() + return [close, sync] + } + + func log(to logger: Logger) { + logger.debug("Requesting Close of \(name)") + } +} diff --git a/Sources/PostgresNIO/Connection/PostgresDatabase+PreparedQuery.swift b/Sources/PostgresNIO/Connection/PostgresDatabase+PreparedQuery.swift new file mode 100644 index 00000000..6552e433 --- /dev/null +++ b/Sources/PostgresNIO/Connection/PostgresDatabase+PreparedQuery.swift @@ -0,0 +1,164 @@ +import Foundation + +extension PostgresDatabase { + public func prepare(query: String) -> EventLoopFuture { + let name = "nio-postgres-\(UUID().uuidString)" + let prepare = PrepareQueryRequest(query, as: name) + return self.send(prepare, logger: self.logger).map { () -> (PreparedQuery) in + let prepared = PreparedQuery(database: self, name: name, rowDescription: prepare.rowLookupTable!) + return prepared + } + } + + public func prepare(query: String, handler: @escaping (PreparedQuery) -> EventLoopFuture<[[PostgresRow]]>) -> EventLoopFuture<[[PostgresRow]]> { + prepare(query: query) + .flatMap { preparedQuery in + handler(preparedQuery) + .flatMap { results in + preparedQuery.deallocate().map { results } + } + } + } +} + + +public struct PreparedQuery { + let database: PostgresDatabase + let name: String + let rowLookupTable: PostgresRow.LookupTable + + init(database: PostgresDatabase, name: String, rowDescription: PostgresRow.LookupTable) { + self.database = database + self.name = name + self.rowLookupTable = rowDescription + } + + public func execute(_ binds: [PostgresData] = []) -> EventLoopFuture<[PostgresRow]> { + var rows: [PostgresRow] = [] + return self.execute(binds) { rows.append($0) }.map { rows } + } + + public func execute(_ binds: [PostgresData] = [], _ onRow: @escaping (PostgresRow) throws -> ()) -> EventLoopFuture { + let handler = ExecutePreparedQuery(query: self, binds: binds, onRow: onRow) + return database.send(handler, logger: database.logger) + } + + public func deallocate() -> EventLoopFuture { + database.send(CloseRequest(name: self.name, + closeType: .preparedStatement), + logger:database.logger) + + } +} + + +private final class PrepareQueryRequest: PostgresRequest { + let query: String + let name: String + var rowLookupTable: PostgresRow.LookupTable? + var resultFormatCodes: [PostgresFormatCode] + var logger: Logger? + + init(_ query: String, as name: String) { + self.query = query + self.name = name + self.resultFormatCodes = [.binary] + } + + func respond(to message: PostgresMessage) throws -> [PostgresMessage]? { + switch message.identifier { + case .rowDescription: + let row = try PostgresMessage.RowDescription(message: message) + self.rowLookupTable = PostgresRow.LookupTable( + rowDescription: row, + resultFormat: self.resultFormatCodes + ) + return [] + case .parseComplete, .parameterDescription: + return [] + case .readyForQuery: + return nil + default: + fatalError("Unexpected message: \(message)") + } + + } + + func start() throws -> [PostgresMessage] { + let parse = PostgresMessage.Parse( + statementName: self.name, + query: self.query, + parameterTypes: [] + ) + let describe = PostgresMessage.Describe( + command: .statement, + name: self.name + ) + return try [parse.message(), describe.message(), PostgresMessage.Sync().message()] + } + + + func log(to logger: Logger) { + self.logger = logger + logger.debug("\(self.query) prepared as \(self.name)") + } +} + + +private final class ExecutePreparedQuery: PostgresRequest { + let query: PreparedQuery + let binds: [PostgresData] + var onRow: (PostgresRow) throws -> () + var resultFormatCodes: [PostgresFormatCode] + var logger: Logger? + + init(query: PreparedQuery, binds: [PostgresData], onRow: @escaping (PostgresRow) throws -> ()) { + self.query = query + self.binds = binds + self.onRow = onRow + self.resultFormatCodes = [.binary] + } + + func respond(to message: PostgresMessage) throws -> [PostgresMessage]? { + switch message.identifier { + case .bindComplete: + return [] + case .dataRow: + let data = try PostgresMessage.DataRow(message: message) + let row = PostgresRow(dataRow: data, lookupTable: query.rowLookupTable) + try onRow(row) + return [] + case .noData: + return [] + case .commandComplete: + return [] + case .readyForQuery: + return nil + default: throw PostgresError.protocol("Unexpected message during query: \(message)") + } + } + + func start() throws -> [PostgresMessage] { + + let bind = PostgresMessage.Bind( + portalName: "", + statementName: query.name, + parameterFormatCodes: self.binds.map { $0.formatCode }, + parameters: self.binds.map { .init(value: $0.value) }, + resultFormatCodes: self.resultFormatCodes + ) + let execute = PostgresMessage.Execute( + portalName: "", + maxRows: 0 + ) + + let sync = PostgresMessage.Sync() + return try [bind.message(), execute.message(), sync.message()] + } + + func log(to logger: Logger) { + self.logger = logger + logger.debug("Execute Prepared Query: \(query.name)") + } + +} diff --git a/Sources/PostgresNIO/Message/PostgresMessage+Close.swift b/Sources/PostgresNIO/Message/PostgresMessage+Close.swift new file mode 100644 index 00000000..69871df6 --- /dev/null +++ b/Sources/PostgresNIO/Message/PostgresMessage+Close.swift @@ -0,0 +1,39 @@ +import NIO + +extension PostgresMessage { + /// Identifies the message as a Close Command + public struct Close: PostgresMessageType { + public static var identifier: PostgresMessage.Identifier { + return .close + } + + /// Close Target. Determines if the Close command should close a prepared statement + /// or portal. + public enum Target: Int8 { + case preparedStatement = 0x53 // 'S' - prepared statement + case portal = 0x50 // 'P' - portal + } + + /// Determines if the `name` identifes a portal or a prepared statement + public var target: Target + + /// The name of the prepared statement or portal to describe + /// (an empty string selects the unnamed prepared statement or portal). + public var name: String + + + /// See `CustomStringConvertible`. + public var description: String { + switch target { + case .preparedStatement: return "Statement(\(name))" + case .portal: return "Portal(\(name))" + } + } + + /// Serializes this message into a byte buffer. + public func serialize(into buffer: inout ByteBuffer) throws { + buffer.writeInteger(target.rawValue) + buffer.write(nullTerminated: name) + } + } +} diff --git a/Tests/PostgresNIOTests/PostgresNIOTests.swift b/Tests/PostgresNIOTests/PostgresNIOTests.swift index 1edf56ee..6161cef4 100644 --- a/Tests/PostgresNIOTests/PostgresNIOTests.swift +++ b/Tests/PostgresNIOTests/PostgresNIOTests.swift @@ -1075,6 +1075,38 @@ final class PostgresNIOTests: XCTestCase { } } + func testPreparedQuery() throws { + let conn = try PostgresConnection.test(on: eventLoop).wait() + + defer { try! conn.close().wait() } + let prepared = try conn.prepare(query: "SELECT 1 as one;").wait() + let rows = try prepared.execute().wait() + + + XCTAssertEqual(rows.count, 1) + let value = rows[0].column("one") + XCTAssertEqual(value?.int, 1) + } + + func testPrepareQueryClosure() throws { + let conn = try PostgresConnection.test(on: eventLoop).wait() + + defer { try! conn.close().wait() } + let x = conn.prepare(query: "SELECT $1::text as foo;", handler: { query in + let a = query.execute(["a"]) + let b = query.execute(["b"]) + let c = query.execute(["c"]) + return EventLoopFuture.whenAllSucceed([a, b, c], on: conn.eventLoop) + + }) + let rows = try x.wait() + XCTAssertEqual(rows.count, 3) + XCTAssertEqual(rows[0][0].column("foo")?.string, "a") + XCTAssertEqual(rows[1][0].column("foo")?.string, "b") + XCTAssertEqual(rows[2][0].column("foo")?.string, "c") + } + + // https://github.com/vapor/postgres-nio/issues/71 func testChar1Serialization() throws { let conn = try PostgresConnection.test(on: eventLoop).wait() @@ -1163,6 +1195,7 @@ private func prepareTableToMeasureSelectPerformance( } _ = try conn.query(insertQuery, batchedFixtureData).wait() } + }