Skip to content

Commit b2d222a

Browse files
committed
additional type support
1 parent 97dbc92 commit b2d222a

File tree

6 files changed

+143
-45
lines changed

6 files changed

+143
-45
lines changed

Sources/PostgreSQL/Message/Messages/PostgreSQLRowDescription.swift

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ struct PostgreSQLRowDescriptionField: Decodable {
3232
var columnAttributeNumber: Int16
3333

3434
/// The object ID of the field's data type.
35-
var dataTypeObjectID: Int32
35+
var dataType: PostgreSQLDataType
3636

3737
/// The data type size (see pg_type.typlen). Note that negative values denote variable-width types.
3838
var dataTypeSize: Int16
@@ -52,9 +52,41 @@ struct PostgreSQLRowDescriptionField: Decodable {
5252
name = try single.decode(String.self)
5353
tableObjectID = try single.decode(Int32.self)
5454
columnAttributeNumber = try single.decode(Int16.self)
55-
dataTypeObjectID = try single.decode(Int32.self)
55+
dataType = try single.decode(PostgreSQLDataType.self)
5656
dataTypeSize = try single.decode(Int16.self)
5757
dataTypeModifier = try single.decode(Int32.self)
5858
formatCode = try single.decode(Int16.self)
5959
}
6060
}
61+
62+
/// The data type's raw object ID.
63+
/// Use `select * from pg_type where oid in (<idhere>);` to lookup more information.
64+
enum PostgreSQLDataType: Int32, Decodable, Equatable {
65+
case bool = 16
66+
case char = 18
67+
case name = 19
68+
case int2 = 21
69+
case int4 = 23
70+
case regproc = 24
71+
case text = 25
72+
case oid = 26
73+
case pg_node_tree = 194
74+
case _aclitem = 1034
75+
76+
/// See Decodable.decode
77+
init(from decoder: Decoder) throws {
78+
let single = try decoder.singleValueContainer()
79+
let objectID = try single.decode(Int32.self)
80+
guard let type = PostgreSQLDataType.make(objectID) else {
81+
throw PostgreSQLError(
82+
identifier: "unsupportedColumnType",
83+
reason: "Unsupported data type: \(objectID)"
84+
)
85+
}
86+
self = type
87+
}
88+
89+
private static func make(_ objectID: Int32) -> PostgreSQLDataType? {
90+
return PostgreSQLDataType(rawValue: objectID)
91+
}
92+
}

Sources/PostgreSQL/Message/PostgreSQLMessageDecoder.swift

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -90,19 +90,32 @@ internal final class _PostgreSQLMessageDecoder: Decoder, SingleValueDecodingCont
9090

9191
/// See SingleValueDecodingContainer.decode
9292
func decode(_ type: Int16.Type) throws -> Int16 {
93-
var int: Int16 = 0
94-
int += Int16(self.data.unsafePopFirst() << 8)
95-
int += Int16(self.data.unsafePopFirst())
96-
return int
93+
return try decode(fixedWidthInteger: Int16.self)
9794
}
9895

9996
/// See SingleValueDecodingContainer.decode
10097
func decode(_ type: Int32.Type) throws -> Int32 {
101-
var int: Int32 = 0
102-
int += Int32(self.data.unsafePopFirst() << 24)
103-
int += Int32(self.data.unsafePopFirst() << 16)
104-
int += Int32(self.data.unsafePopFirst() << 8)
105-
int += Int32(self.data.unsafePopFirst())
98+
return try decode(fixedWidthInteger: Int32.self)
99+
}
100+
101+
/// Decodes a fixed width integer.
102+
func decode<B>(fixedWidthInteger type: B.Type) throws -> B where B: FixedWidthInteger {
103+
guard data.count >= MemoryLayout<B>.size else {
104+
fatalError("Unexpected end of data while decoding \(B.self).")
105+
}
106+
107+
108+
let int: B = data.withUnsafeBytes { (pointer: UnsafePointer<UInt8>) -> B in
109+
return pointer.withMemoryRebound(to: B.self, capacity: 1) { (pointer: UnsafePointer<B>) -> B in
110+
return pointer.pointee.bigEndian
111+
}
112+
}
113+
114+
/// FIXME: performance
115+
for _ in 0..<MemoryLayout<B>.size {
116+
_ = data.unsafePopFirst()
117+
}
118+
106119
return int
107120
}
108121

Sources/PostgreSQL/Message/PostgreSQLMessageParser.swift

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@ final class PostgreSQLMessageParser: TranslatingStream {
3535

3636
/// Parses the data, setting `excess` or requesting more data if insufficient.
3737
func parse(data: Data) throws -> TranslatingStreamResult<PostgreSQLMessage> {
38-
// print("Parse: \(data.hexDebug)")
3938
let data = buffered + data
4039
guard let (message, remaining) = try PostgreSQLMessageDecoder().decode(data) else {
4140
buffered.append(data)

Sources/PostgreSQL/PostgreSQLClient.swift

Lines changed: 31 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -54,28 +54,44 @@ final class PostgreSQLClient {
5454
for (i, field) in row.fields.enumerated() {
5555
let col = data.columns[i]
5656
let data: PostgreSQLData
57-
switch field.dataTypeObjectID {
58-
case 25: // text
59-
data = try col.value.flatMap { data in
60-
guard let string = String(data: data, encoding: .utf8) else {
61-
throw PostgreSQLError(identifier: "utf8String", reason: "Unexpected non-UTF8 string.")
57+
switch field.formatCode {
58+
case 0: // text
59+
func makeString() throws -> String? {
60+
return try col.value.flatMap { data in
61+
guard let string = String(data: data, encoding: .utf8) else {
62+
throw PostgreSQLError(identifier: "utf8String", reason: "Unexpected non-UTF8 string.")
63+
}
64+
return string
6265
}
63-
return string
64-
}.flatMap { .string($0) } ?? .null
65-
default:
66-
throw PostgreSQLError(
67-
identifier: "unsupportedColumnType",
68-
reason: "Unsupported column type on field \(field.name): \(field.dataTypeObjectID)"
69-
)
66+
}
67+
68+
switch field.dataType {
69+
case .bool:
70+
data = try makeString().flatMap { $0 == "t" }.flatMap { .bool($0) } ?? .null
71+
case .text, .name:
72+
data = try makeString().flatMap { .string($0) } ?? .null
73+
case .oid, .regproc, .int4:
74+
data = try makeString().flatMap { Int32($0) }.flatMap { .int32($0) } ?? .null
75+
case .int2:
76+
data = try makeString().flatMap { Int16($0) }.flatMap { .int16($0) } ?? .null
77+
case .char:
78+
data = try makeString().flatMap { Character($0) }.flatMap { .character($0) } ?? .null
79+
case .pg_node_tree:
80+
print("\(field.name): is pg node tree")
81+
data = .null
82+
case ._aclitem:
83+
print("\(field.name): is acl item")
84+
data = .null
85+
}
86+
case 1: fatalError("Binary format code not supported.")
87+
default: fatalError("Unexpected format code: \(field.formatCode)")
7088
}
89+
7190
parsed[field.name] = data
7291
}
7392

7493
// append the result
7594
results.append(parsed)
76-
77-
// reset current row
78-
currentRow = nil
7995
default: break
8096
}
8197
}
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,38 @@
1+
/// Supported `PostgreSQLData` data types.
12
public enum PostgreSQLData {
3+
case character(Character)
24
case string(String)
5+
6+
case bool(Bool)
7+
8+
case int8(Int8)
9+
case int16(Int16)
10+
case int32(Int32)
11+
case int(Int)
12+
13+
case uint8(UInt8)
14+
case uint16(UInt16)
15+
case uint32(UInt32)
16+
case uint(UInt)
17+
318
case null
419
}
20+
21+
extension PostgreSQLData {
22+
/// Returns string value, `nil` if not a string.
23+
public var string: String? {
24+
switch self {
25+
case .string(let s): return s
26+
default: return nil
27+
}
28+
}
29+
30+
/// Returns int value, `nil` if not an int.
31+
public var int: Int? {
32+
switch self {
33+
case .int(let i): return i
34+
case .string(let s): return Int(s)
35+
default: return nil
36+
}
37+
}
38+
}

Tests/PostgreSQLTests/PostgreSQLClientTests.swift

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -5,27 +5,31 @@ import XCTest
55
import TCP
66

77
class PostgreSQLClientTests: XCTestCase {
8-
func testStreaming() throws {
9-
let eventLoop = try DefaultEventLoop(label: "codes.vapor.postgresql.client.test")
10-
let client = try PostgreSQLClient.connect(on: eventLoop)
11-
12-
let startup = PostgreSQLStartupMessage.versionThree(parameters: ["user": "postgres"])
13-
let startupRes = try client.send(.startupMessage(startup)).await(on: eventLoop)
14-
for log in startupRes {
15-
switch log {
16-
case .parameterStatus(let param):
17-
if param.parameter == "session_authorization" {
18-
XCTAssertEqual(param.value, "postgres")
19-
}
20-
default: break
21-
}
22-
}
23-
8+
func testVersion() throws {
9+
let (client, eventLoop) = try PostgreSQLClient.makeTest()
2410
let results = try client.query("SELECT version();").await(on: eventLoop)
25-
print(results[0]["version"]!)
11+
XCTAssert(results[0]["version"]?.string?.contains("10.1") == true)
12+
}
13+
14+
func testSelectTypes() throws {
15+
let (client, eventLoop) = try PostgreSQLClient.makeTest()
16+
let results = try client.query("select * from pg_type;").await(on: eventLoop)
17+
XCTAssert(results.count > 350)
2618
}
2719

2820
static var allTests = [
29-
("testStreaming", testStreaming),
21+
("testVersion", testVersion),
3022
]
3123
}
24+
25+
extension PostgreSQLClient {
26+
/// Creates a test event loop and psql client.
27+
static func makeTest() throws -> (PostgreSQLClient, EventLoop) {
28+
let eventLoop = try DefaultEventLoop(label: "codes.vapor.postgresql.client.test")
29+
let client = try PostgreSQLClient.connect(on: eventLoop)
30+
31+
let startup = PostgreSQLStartupMessage.versionThree(parameters: ["user": "postgres"])
32+
_ = try client.send(.startupMessage(startup)).await(on: eventLoop)
33+
return (client, eventLoop)
34+
}
35+
}

0 commit comments

Comments
 (0)