Skip to content

Commit 7b4595c

Browse files
committed
parameterized queries
1 parent f121eb1 commit 7b4595c

9 files changed

+103
-13
lines changed

Sources/PostgreSQL/Message/Base/PostgreSQLMessage.swift

+4
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ enum PostgreSQLMessage {
1515
case parse(PostgreSQLParseRequest)
1616
/// Identifies the message as a Bind command.
1717
case bind(PostgreSQLBindRequest)
18+
/// Identifies the message as a Describe command.
19+
case describe(PostgreSQLDescribeRequest)
1820
/// Identifies the message as an Execute command.
1921
case execute(PostgreSQLExecuteRequest)
2022
/// Identifies the message as a Sync command.
@@ -23,4 +25,6 @@ enum PostgreSQLMessage {
2325
case parseComplete
2426
/// Identifies the message as a Bind-complete indicator.
2527
case bindComplete
28+
29+
2630
}

Sources/PostgreSQL/Message/Messages/PostgreSQLDataRow.swift

+29-1
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,35 @@ struct PostgreSQLDataRowColumn: Decodable {
4747

4848
/// Parses this column to the specified data type assuming binary format.
4949
func parseBinary(dataType: PostgreSQLDataType) throws -> PostgreSQLData {
50-
fatalError("Binary format code not supported.")
50+
switch dataType {
51+
case .name, .text:
52+
return try makeString().flatMap { .string($0) } ?? .null
53+
case .oid, .regproc, .int4:
54+
return makeFixedWidthInteger(Int32.self).flatMap { .int32($0) } ?? .null
55+
case .int2:
56+
return makeFixedWidthInteger(Int16.self).flatMap { .int16($0) } ?? .null
57+
case .bool:
58+
return makeFixedWidthInteger(Byte.self).flatMap { .bool($0 == 1) } ?? .null
59+
case .char:
60+
return makeFixedWidthInteger(Byte.self).flatMap { byte in
61+
let char = Character(Unicode.Scalar(byte))
62+
return .character(char)
63+
} ?? .null
64+
case .pg_node_tree:
65+
print("pg node tree")
66+
return .null
67+
case ._aclitem:
68+
print("acl item")
69+
return .null
70+
}
71+
}
72+
73+
func makeFixedWidthInteger<I>(_ type: I.Type = I.self) -> I? where I: FixedWidthInteger {
74+
return value.flatMap { data in
75+
return data.withUnsafeBytes { (pointer: UnsafePointer<I>) -> I in
76+
return pointer.pointee.bigEndian
77+
}
78+
}
5179
}
5280

5381
/// Parses this column to the specified data type assuming text format.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import Bits
2+
3+
/*
4+
Describe (F)
5+
Byte1('D')
6+
Identifies the message as a Describe command.
7+
8+
Int32
9+
Length of message contents in bytes, including self.
10+
11+
Byte1
12+
'S' to describe a prepared statement; or 'P' to describe a portal.
13+
14+
String
15+
The name of the prepared statement or portal to describe (an empty string selects the unnamed prepared statement or portal).
16+
17+
*/
18+
19+
/// Identifies the message as a Describe command.
20+
struct PostgreSQLDescribeRequest: Encodable {
21+
/// 'S' to describe a prepared statement; or 'P' to describe a portal.
22+
let type: PostgreSQLDescribeType
23+
24+
/// The name of the prepared statement or portal to describe
25+
/// (an empty string selects the unnamed prepared statement or portal).
26+
var name: String
27+
}
28+
29+
enum PostgreSQLDescribeType: Byte, Encodable {
30+
case statement = 0x53 // S
31+
case portal = 0x50 // P
32+
}

Sources/PostgreSQL/Message/PostgreSQLMessageEncoder.swift

+9-2
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@ final class PostgreSQLMessageEncoder {
2525
case .bind(let bind):
2626
identifier = .B
2727
try bind.encode(to: encoder)
28+
case .describe(let describe):
29+
identifier = .D
30+
try describe.encode(to: encoder)
2831
case .execute(let execute):
2932
identifier = .E
3033
try execute.encode(to: encoder)
@@ -84,6 +87,11 @@ internal final class _PostgreSQLMessageEncoder: Encoder, SingleValueEncodingCont
8487
self.data.append(numericCast(value))
8588
}
8689

90+
/// See SingleValueEncodingContainer.encode
91+
func encode(_ value: UInt8) throws {
92+
self.data.append(value)
93+
}
94+
8795
/// See SingleValueEncodingContainer.encode
8896
func encode(_ value: Int16) throws {
8997
var value = value.bigEndian
@@ -137,7 +145,6 @@ internal final class _PostgreSQLMessageEncoder: Encoder, SingleValueEncodingCont
137145

138146
func encode(_ value: Int) throws { fatalError("Unsupported type: \(type(of: value))") }
139147
func encode(_ value: UInt) throws { fatalError("Unsupported type: \(type(of: value))") }
140-
func encode(_ value: UInt8) throws { fatalError("Unsupported type: \(type(of: value))") }
141148
func encode(_ value: UInt16) throws { fatalError("Unsupported type: \(type(of: value))") }
142149
func encode(_ value: UInt32) throws { fatalError("Unsupported type: \(type(of: value))") }
143150
func encode(_ value: UInt64) throws { fatalError("Unsupported type: \(type(of: value))") }
@@ -214,7 +221,7 @@ fileprivate final class _PostgreSQLMessageUnkeyedEncoder: UnkeyedEncodingContain
214221
func nestedContainer<NestedKey>(keyedBy keyType: NestedKey.Type)
215222
-> KeyedEncodingContainer<NestedKey> where NestedKey: CodingKey { return encoder.container(keyedBy: NestedKey.self) }
216223
func nestedUnkeyedContainer() -> UnkeyedEncodingContainer { return encoder.unkeyedContainer() }
217-
func superEncoder() -> Encoder { print(#function); return encoder }
224+
func superEncoder() -> Encoder { return encoder }
218225

219226
deinit {
220227
let size = numericCast(count) as Int16

Sources/PostgreSQL/Message/PostgreSQLMessageParser.swift

-1
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@ final class PostgreSQLMessageParser: TranslatingStream {
3535

3636
/// Parses the data, setting `excess` or requesting more data if insufficient.
3737
func parse(data: Data) throws -> TranslatingStreamResult<PostgreSQLMessage> {
38-
print("Parse: \(data.hexDebug)")
3938
let data = buffered + data
4039
guard let (message, remaining) = try PostgreSQLMessageDecoder().decode(data) else {
4140
buffered.append(data)

Sources/PostgreSQL/Message/PostgreSQLMessageSerializer.swift

-1
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@ final class PostgreSQLMessageSerializer: TranslatingStream {
3939

4040
/// Serializes data, storing `excess` if it does not fit in the buffer.
4141
func serialize(data: Data) -> TranslatingStreamResult<ByteBuffer> {
42-
print("Serialize: \(data.hexDebug)")
4342
let count = data.copyBytes(to: buffer)
4443
let view = ByteBuffer(start: buffer.baseAddress, count: count)
4544
if data.count > count {

Sources/PostgreSQL/PostgreSQLClient.swift

+21-7
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ final class PostgreSQLClient {
2424
func send(_ message: PostgreSQLMessage) -> Future<[PostgreSQLMessage]> {
2525
var responses: [PostgreSQLMessage] = []
2626
return queueStream.enqueue([message]) { message in
27-
print(message)
2827
responses.append(message)
2928
switch message {
3029
case .readyForQuery: return true
@@ -36,6 +35,16 @@ final class PostgreSQLClient {
3635
}
3736
}
3837

38+
/// Sends a simple PostgreSQL query command, collecting the parsed results.
39+
func query(_ string: String) -> Future<[[String: PostgreSQLData]]> {
40+
var rows: [[String: PostgreSQLData]] = []
41+
return query(string) { row in
42+
rows.append(row)
43+
}.map(to: [[String: PostgreSQLData]].self) {
44+
return rows
45+
}
46+
}
47+
3948
/// Sends a simple PostgreSQL query command, returning the parsed results to
4049
/// the supplied closure.
4150
func query(_ string: String, onRow: @escaping ([String: PostgreSQLData]) -> ()) -> Future<Void> {
@@ -58,16 +67,21 @@ final class PostgreSQLClient {
5867
}
5968
}
6069

61-
/// Sends a simple PostgreSQL query command, collecting the parsed results.
62-
func query(_ string: String) -> Future<[[String: PostgreSQLData]]> {
70+
/// Sends a parameterized PostgreSQL query command, collecting the parsed results.
71+
func parameterizedQuery(
72+
_ string: String,
73+
_ parameters: [PostgreSQLData] = []
74+
) throws -> Future<[[String: PostgreSQLData]]> {
6375
var rows: [[String: PostgreSQLData]] = []
64-
return query(string) { row in
76+
return try parameterizedQuery(string, parameters) { row in
6577
rows.append(row)
6678
}.map(to: [[String: PostgreSQLData]].self) {
6779
return rows
6880
}
6981
}
7082

83+
/// Sends a parameterized PostgreSQL query command, returning the parsed results to
84+
/// the supplied closure.
7185
func parameterizedQuery(
7286
_ string: String,
7387
_ parameters: [PostgreSQLData] = [],
@@ -85,15 +99,15 @@ final class PostgreSQLClient {
8599
parameters: parameters.map { try .serialize(data: $0) },
86100
resultFormatCodes: [.binary]
87101
)
102+
let describe = PostgreSQLDescribeRequest(type: .portal, name: "")
88103
let execute = PostgreSQLExecuteRequest(
89104
portalName: "",
90105
maxRows: 0
91106
)
92107
var currentRow: PostgreSQLRowDescription?
93108
return queueStream.enqueue([
94-
.parse(parse), .bind(bind), .execute(execute), .sync
109+
.parse(parse), .bind(bind), .describe(describe), .execute(execute), .sync
95110
]) { message in
96-
print(message)
97111
switch message {
98112
case .errorResponse(let e): throw e
99113
case .parseComplete: return false
@@ -157,7 +171,7 @@ public final class AsymmetricQueueStream<I, O>: Stream, ConnectionContext {
157171
for o in output {
158172
self.queuedOutput.insert(o, at: 0)
159173
}
160-
upstream!.request(count: UInt(output.count))
174+
upstream!.request(count: 1)
161175
update()
162176
return input.promise.future
163177
}

Sources/PostgreSQL/PostgreSQLData.swift

+3
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@ extension PostgreSQLData {
3030
/// Returns int value, `nil` if not an int.
3131
public var int: Int? {
3232
switch self {
33+
case .int8(let i): return Int(i)
34+
case .int16(let i): return Int(i)
35+
case .int32(let i): return Int(i)
3336
case .int(let i): return i
3437
case .string(let s): return Int(s)
3538
default: return nil

Tests/PostgreSQLTests/PostgreSQLClientTests.swift

+5-1
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,14 @@ class PostgreSQLClientTests: XCTestCase {
2222
let query = """
2323
select * from "pg_type" where "typlen" = $1 or "typlen" = $2
2424
"""
25-
try client.parameterizedQuery(query, [
25+
let rows = try client.parameterizedQuery(query, [
2626
.int32(1),
2727
.int32(2),
2828
]).await(on: eventLoop)
29+
30+
for row in rows {
31+
XCTAssert(row["typlen"]?.int == 1 || row["typlen"]?.int == 2)
32+
}
2933
}
3034

3135
static var allTests = [

0 commit comments

Comments
 (0)