Skip to content

Commit ec852e0

Browse files
committed
postgresql message parse refactor
1 parent 3474089 commit ec852e0

30 files changed

+764
-614
lines changed
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
import Crypto
2+
3+
extension PostgreSQLConnection {
4+
/// Authenticates the `PostgreSQLClient` using a username with no password.
5+
public func authenticate(username: String, database: String? = nil, password: String? = nil) -> Future<Void> {
6+
var authRequest: PostgreSQLMessage.AuthenticationRequest?
7+
return queue.enqueue([.startupMessage(.versionThree(parameters: [
8+
"user": username,
9+
"database": database ?? username
10+
]))]) { message in
11+
print("👈 \(message)")
12+
switch message {
13+
case .authenticationRequest(let a):
14+
authRequest = a
15+
return true
16+
case .error(let error): throw PostgreSQLError.errorResponse(error)
17+
default: throw PostgreSQLError(identifier: "auth", reason: "Unsupported message encountered during auth: \(message).")
18+
}
19+
}.flatMap {
20+
guard let auth = authRequest else {
21+
throw PostgreSQLError(identifier: "authRequest", reason: "No authorization request / status sent.")
22+
}
23+
24+
let input: [PostgreSQLMessage]
25+
switch auth {
26+
case .ok:
27+
guard password == nil else {
28+
throw PostgreSQLError(identifier: "trust", reason: "No password is required")
29+
}
30+
input = []
31+
case .plaintext:
32+
guard let password = password else {
33+
throw PostgreSQLError(identifier: "password", reason: "Password is required")
34+
}
35+
input = [.password(.init(password: password))]
36+
case .md5(let salt):
37+
guard let password = password else {
38+
throw PostgreSQLError(identifier: "password", reason: "Password is required")
39+
}
40+
guard let passwordData = password.data(using: .utf8) else {
41+
throw PostgreSQLError(identifier: "passwordUTF8", reason: "Could not convert password to UTF-8 encoded Data.")
42+
}
43+
44+
guard let usernameData = username.data(using: .utf8) else {
45+
throw PostgreSQLError(identifier: "usernameUTF8", reason: "Could not convert username to UTF-8 encoded Data.")
46+
}
47+
48+
// pwdhash = md5(password + username).hexdigest()
49+
let pwdhash = try MD5.hash(passwordData + usernameData).hexEncodedString()
50+
// hash = "md5" + md 5(pwdhash + salt).hexdigest()
51+
let hash = try "md5" + MD5.hash(Data(pwdhash.utf8) + salt).hexEncodedString()
52+
input = [.password(.init(password: hash))]
53+
}
54+
55+
return self.queue.enqueue(input) { message in
56+
print("👈 \(message)")
57+
switch message {
58+
case .error(let error): throw PostgreSQLError.errorResponse(error)
59+
case .readyForQuery: return true
60+
case .authenticationRequest: return false
61+
case .parameterStatus, .backendKeyData: return false
62+
default: throw PostgreSQLError(identifier: "authenticationMessage", reason: "Unexpected authentication message: \(message)")
63+
}
64+
}
65+
}
66+
}
67+
}
Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,13 @@
11
extension PostgreSQLConnection {
22
internal func listen(_ channelName: String, handler: @escaping (String) throws -> ()) throws -> Future<Void> {
33
closeHandlers.append({ conn in
4-
let query = PostgreSQLQuery(query: "UNLISTEN \"\(channelName)\";")
5-
return conn.send([.query(query)], onResponse: { _ in })
4+
return conn.send([.query(.init(query: "UNLISTEN \"\(channelName)\";"))], onResponse: { _ in })
65
})
76

87
notificationHandlers[channelName] = { message in
98
try handler(message)
109
}
11-
let query = PostgreSQLQuery(query: "LISTEN \"\(channelName)\";")
12-
return queue.enqueue([.query(query)], onInput: { message in
10+
return queue.enqueue([.query(.init(query: "LISTEN \"\(channelName)\";"))], onInput: { message in
1311
switch message {
1412
case let .notificationResponse(notification):
1513
try self.notificationHandlers[notification.channel]?(notification.message)
@@ -21,13 +19,11 @@ extension PostgreSQLConnection {
2119
}
2220

2321
internal func notify(_ channelName: String, message: String) throws -> Future<Void> {
24-
let query = PostgreSQLQuery(query: "NOTIFY \"\(channelName)\", '\(message)';")
25-
return send([.query(query)]).map(to: Void.self, { _ in })
22+
return send([.query(.init(query: "NOTIFY \"\(channelName)\", '\(message)';"))]).map(to: Void.self, { _ in })
2623
}
2724

2825
internal func unlisten(_ channelName: String, unlistenHandler: (() -> Void)? = nil) throws -> Future<Void> {
2926
notificationHandlers.removeValue(forKey: channelName)
30-
let query = PostgreSQLQuery(query: "UNLISTEN \"\(channelName)\";")
31-
return send([.query(query)], onResponse: { _ in unlistenHandler?() })
27+
return send([.query(.init(query: "UNLISTEN \"\(channelName)\";"))], onResponse: { _ in unlistenHandler?() })
3228
}
3329
}

Sources/PostgreSQL/Connection/PostgreSQLConnection+Query.swift

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -99,30 +99,30 @@ extension PostgreSQLConnection {
9999
let parameters = try parameters.map { try $0.convertToPostgreSQLData() }
100100
logger?.record(query: string, values: parameters.map { $0.description })
101101

102-
let parse = PostgreSQLParseRequest(statementName: "", query: string, parameterTypes: parameters.map { $0.type })
103-
let describe = PostgreSQLDescribeRequest(type: .statement, name: "")
104-
let bind = PostgreSQLMessage.BindRequest(
105-
portalName: "",
106-
statementName: "",
107-
parameterFormatCodes: parameters.map {
108-
switch $0.storage {
109-
case .text: return .text
110-
case .binary, .null: return .binary
111-
}
112-
},
113-
parameters: parameters.map {
114-
switch $0.storage {
115-
case .text(let string): return .init(data: Data(string.utf8))
116-
case .binary(let data): return .init(data: data)
117-
case .null: return .init(data: nil)
118-
}
119-
},
120-
resultFormatCodes: resultFormat.formatCodes
121-
)
122-
let execute = PostgreSQLExecuteRequest(portalName: "", maxRows: 0)
123-
var currentRow: PostgreSQLRowDescription?
102+
var currentRow: PostgreSQLMessage.RowDescription?
124103
return self.send([
125-
.parse(parse), .describe(describe), .bind(bind), .execute(execute), .sync
104+
.parse(.init(statementName: "", query: string, parameterTypes: parameters.map { $0.type })),
105+
.describe(.init(command: .statement, name: "")),
106+
.bind(.init(
107+
portalName: "",
108+
statementName: "",
109+
parameterFormatCodes: parameters.map {
110+
switch $0.storage {
111+
case .text: return .text
112+
case .binary, .null: return .binary
113+
}
114+
},
115+
parameters: parameters.map {
116+
switch $0.storage {
117+
case .text(let string): return .init(data: Data(string.utf8))
118+
case .binary(let data): return .init(data: data)
119+
case .null: return .init(data: nil)
120+
}
121+
},
122+
resultFormatCodes: resultFormat.formatCodes
123+
)),
124+
.execute(.init(portalName: "", maxRows: 0)),
125+
.sync
126126
]) { message in
127127
switch message {
128128
case .parseComplete: break

Sources/PostgreSQL/Connection/PostgreSQLConnection+SimpleQuery.swift

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,8 @@ extension PostgreSQLConnection {
3939
/// Non-operation bounded simple query.
4040
private func _simpleQuery(_ string: String, onRow: @escaping ([PostgreSQLColumn: PostgreSQLData]) -> ()) -> Future<Void> {
4141
logger?.record(query: string)
42-
var currentRow: PostgreSQLRowDescription?
43-
let query = PostgreSQLQuery(query: string)
44-
return send([.query(query)]) { message in
42+
var currentRow: PostgreSQLMessage.RowDescription?
43+
return send([.query(.init(query: string))]) { message in
4544
switch message {
4645
case .rowDescription(let row):
4746
currentRow = row

Sources/PostgreSQL/Connection/PostgreSQLConnection+TransportConfig.swift

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -62,18 +62,18 @@ extension PostgreSQLConnection {
6262
/// Ask the server if it supports SSL and adds a new OpenSSLClientHandler to pipeline if it does
6363
/// This will throw an error if the server does not support SSL
6464
internal func addSSLClientHandler(using tlsConfiguration: TLSConfiguration) -> Future<Void> {
65-
return queue.enqueue([.sslSupportRequest(PostgreSQLSSLSupportRequest())]) { message in
65+
return queue.enqueue([.sslSupportRequest(.init())]) { message in
6666
guard case .sslSupportResponse(let response) = message else {
6767
throw PostgreSQLError(identifier: "SSL support check", reason: "Unsupported message encountered during SSL support check: \(message).")
6868
}
6969
guard response == .supported else {
7070
throw PostgreSQLError(identifier: "SSL support check", reason: "tlsConfiguration given in PostgresSQLConfiguration, but SSL connection not supported by PostgreSQL server.")
7171
}
7272
return true
73-
}.flatMap {
74-
let sslContext = try SSLContext(configuration: tlsConfiguration)
75-
let handler = try OpenSSLClientHandler(context: sslContext)
76-
return self.channel.pipeline.add(handler: handler, first: true)
73+
}.flatMap {
74+
let sslContext = try SSLContext(configuration: tlsConfiguration)
75+
let handler = try OpenSSLClientHandler(context: sslContext)
76+
return self.channel.pipeline.add(handler: handler, first: true)
7777
}
7878
}
7979
}

Sources/PostgreSQL/Connection/PostgreSQLConnection.swift

Lines changed: 4 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
import Crypto
2-
31
/// A PostgreSQL frontend client.
42
public final class PostgreSQLConnection: DatabaseConnection, BasicWorker {
53
/// See `DatabaseConnection`.
@@ -86,7 +84,7 @@ public final class PostgreSQLConnection: DatabaseConnection, BasicWorker {
8684
case .readyForQuery:
8785
if let e = error { throw e }
8886
return true
89-
case .error(let e): error = e
87+
case .error(let e): error = PostgreSQLError.errorResponse(e)
9088
case .notice(let n): print(n)
9189
default: try onResponse(message)
9290
}
@@ -115,78 +113,12 @@ public final class PostgreSQLConnection: DatabaseConnection, BasicWorker {
115113
return new
116114
}
117115

118-
/// Authenticates the `PostgreSQLClient` using a username with no password.
119-
public func authenticate(username: String, database: String? = nil, password: String? = nil) -> Future<Void> {
120-
let startup = PostgreSQLStartupMessage.versionThree(parameters: [
121-
"user": username,
122-
"database": database ?? username
123-
])
124-
var authRequest: PostgreSQLMessage.AuthenticationRequest?
125-
return queue.enqueue([.startupMessage(startup)]) { message in
126-
switch message {
127-
case .authenticationRequest(let a):
128-
authRequest = a
129-
return true
130-
default: throw PostgreSQLError(identifier: "auth", reason: "Unsupported message encountered during auth: \(message).")
131-
}
132-
}.flatMap(to: Void.self) {
133-
guard let auth = authRequest else {
134-
throw PostgreSQLError(identifier: "authRequest", reason: "No authorization request / status sent.")
135-
}
136-
137-
let input: [PostgreSQLMessage]
138-
switch auth {
139-
case .ok:
140-
guard password == nil else {
141-
throw PostgreSQLError(identifier: "trust", reason: "No password is required")
142-
}
143-
input = []
144-
case .plaintext:
145-
guard let password = password else {
146-
throw PostgreSQLError(identifier: "password", reason: "Password is required")
147-
}
148-
let passwordMessage = PostgreSQLPasswordMessage(password: password)
149-
input = [.password(passwordMessage)]
150-
case .md5(let salt):
151-
guard let password = password else {
152-
throw PostgreSQLError(identifier: "password", reason: "Password is required")
153-
}
154-
guard let passwordData = password.data(using: .utf8) else {
155-
throw PostgreSQLError(identifier: "passwordUTF8", reason: "Could not convert password to UTF-8 encoded Data.")
156-
}
157-
158-
guard let usernameData = username.data(using: .utf8) else {
159-
throw PostgreSQLError(identifier: "usernameUTF8", reason: "Could not convert username to UTF-8 encoded Data.")
160-
}
161-
162-
// pwdhash = md5(password + username).hexdigest()
163-
let pwdhash = try MD5.hash(passwordData + usernameData).hexEncodedString()
164-
// hash = "md5" + md 5(pwdhash + salt).hexdigest()
165-
let hash = try "md5" + MD5.hash(Data(pwdhash.utf8) + salt).hexEncodedString()
166-
167-
let passwordMessage = PostgreSQLPasswordMessage(password: hash)
168-
input = [.password(passwordMessage)]
169-
}
170-
171-
return self.queue.enqueue(input) { message in
172-
switch message {
173-
case .error(let error): throw error
174-
case .readyForQuery: return true
175-
case .authenticationRequest: return false
176-
case .parameterStatus, .backendKeyData: return false
177-
default: throw PostgreSQLError(identifier: "authenticationMessage", reason: "Unexpected authentication message: \(message)")
178-
}
179-
}
180-
}
181-
}
182-
183-
184116
/// Closes this client.
185117
public func close() {
186118
_ = executeCloseHandlersThenClose()
187119
}
188120

189-
121+
/// Executes close handlers before closing.
190122
private func executeCloseHandlersThenClose() -> Future<Void> {
191123
if let beforeClose = closeHandlers.popLast() {
192124
return beforeClose(self).then { _ in
@@ -197,12 +129,13 @@ public final class PostgreSQLConnection: DatabaseConnection, BasicWorker {
197129
}
198130
}
199131

200-
201132
/// Called when this class deinitializes.
202133
deinit {
203134
close()
204135
}
205136

206137
}
207138

139+
// MARK: Private
140+
208141
private let closeError = PostgreSQLError(identifier: "closed", reason: "Connection is closed.")

Sources/PostgreSQL/Message/PostgreSQLMessage+0.swift

Lines changed: 18 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,7 @@
11
// note: Please list enum cases alphabetically.
22

33
/// A frontend or backend PostgreSQL message.
4-
enum PostgreSQLMessage {
5-
/// The format code being used for the field.
6-
/// Currently will be zero (text) or one (binary).
7-
/// In a RowDescription returned from the statement variant of Describe,
8-
/// the format code is not yet known and will always be zero.
9-
enum FormatCode: Int16, Codable {
10-
case text = 0
11-
case binary = 1
12-
}
13-
4+
enum PostgreSQLMessage {
145
/// One of the various authentication request message formats.
156
case authenticationRequest(AuthenticationRequest)
167

@@ -25,61 +16,61 @@ enum PostgreSQLMessage {
2516
case bindComplete
2617

2718
/// Identifies the message as a command-completed response.
28-
case close(PostgreSQLCloseResponse)
19+
case close(CloseResponse)
2920

3021
/// Identifies the message as a data row.
31-
case dataRow(PostgreSQLDataRow)
22+
case dataRow(DataRow)
3223

3324
/// Identifies the message as a Describe command.
34-
case describe(PostgreSQLDescribeRequest)
25+
case describe(DescribeRequest)
3526

3627
/// Identifies the message as an error.
37-
case error(PostgreSQLDiagnosticResponse)
28+
case error(ErrorResponse)
3829

3930
/// Identifies the message as an Execute command.
40-
case execute(PostgreSQLExecuteRequest)
31+
case execute(ExecuteRequest)
4132

4233
/// Identifies the message as a no-data indicator.
4334
case noData
4435

4536
/// Identifies the message as a notice.
46-
case notice(PostgreSQLDiagnosticResponse)
37+
case notice(ErrorResponse)
4738

4839
/// Identifies the message as a notification response.
49-
case notificationResponse(PostgreSQLNotificationResponse)
40+
case notificationResponse(NotificationResponse)
5041

5142
/// Identifies the message as a parameter description.
52-
case parameterDescription(PostgreSQLParameterDescription)
43+
case parameterDescription(ParameterDescription)
5344

5445
/// Identifies the message as a run-time parameter status report.
55-
case parameterStatus(PostgreSQLParameterStatus)
46+
case parameterStatus(ParameterStatus)
5647

5748
/// Identifies the message as a Parse command.
58-
case parse(PostgreSQLParseRequest)
49+
case parse(ParseRequest)
5950

6051
/// Identifies the message as a Parse-complete indicator.
6152
case parseComplete
6253

6354
/// Identifies the message as a password response.
64-
case password(PostgreSQLPasswordMessage)
55+
case password(PasswordMessage)
6556

6657
/// Identifies the message as a simple query.
67-
case query(PostgreSQLQuery)
58+
case query(Query)
6859

6960
/// Identifies the message type. ReadyForQuery is sent whenever the backend is ready for a new query cycle.
70-
case readyForQuery(PostgreSQLReadyForQuery)
61+
case readyForQuery(ReadyForQuery)
7162

7263
/// Identifies the message as a row description.
73-
case rowDescription(PostgreSQLRowDescription)
64+
case rowDescription(RowDescription)
7465

7566
/// Response after sending an sslSupportRequest message.
76-
case sslSupportResponse(PostgreSQLSSLSupportResponse)
67+
case sslSupportResponse(SupportResponse)
7768

7869
/// Asks the server if it supports SSL.
79-
case sslSupportRequest(PostgreSQLSSLSupportRequest)
70+
case sslSupportRequest(SSLSupportRequest)
8071

8172
/// Startup message
82-
case startupMessage(PostgreSQLStartupMessage)
73+
case startupMessage(StartupMessage)
8374

8475
/// Identifies the message as a Sync command.
8576
case sync

0 commit comments

Comments
 (0)