Skip to content

Commit cdb18d1

Browse files
fabianfettgwynne
andauthored
State machine (vapor#135)
* Adds PSQLFrontendMessage & PSQLBackendMessage * State machine * Removed Xcode headers. * Apply suggestions from code review Co-authored-by: Gwynne Raskind <gwynne@darkrainfall.org> * Code review * Apply suggestions from code review Co-authored-by: Gwynne Raskind <gwynne@darkrainfall.org> * Code review * Apply suggestions from code review Co-authored-by: Gwynne Raskind <gwynne@darkrainfall.org> * Code review * Add rudementary sasl support * Update Sources/PostgresNIO/Connection/PostgresConnection+Notifications.swift Co-authored-by: Gwynne Raskind <gwynne@darkrainfall.org> * Code review * Code review * A little more error handling * Error handling * Better logging * Fixes! * Some better state handling when closing * State machine tests * Better cleanup in error states * Cherry pick to be reverted. * PreparedStatementStateMachine tests * Code review * Enable trace logging to better find the flaky tests * PSQLChannelHandler logging + cleanup * PR review * Code review * Update Sources/PostgresNIO/New/Connection State Machine/AuthenticationStateMachine.swift Co-authored-by: Gwynne Raskind <gwynne@darkrainfall.org> * Last code comment * Last code comment fix Co-authored-by: Gwynne Raskind <gwynne@darkrainfall.org>
1 parent 5876fdf commit cdb18d1

File tree

125 files changed

+10157
-812
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

125 files changed

+10157
-812
lines changed

.github/workflows/test.yml

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -99,9 +99,10 @@ jobs:
9999
run: swift test --enable-test-discovery --sanitize=thread
100100
env:
101101
POSTGRES_HOSTNAME: psql
102-
POSTGRES_USERNAME: vapor_username
102+
POSTGRES_USER: vapor_username
103+
POSTGRES_DB: vapor_database
103104
POSTGRES_PASSWORD: vapor_password
104-
POSTGRES_DATABASE: vapor_database
105+
POSTGRES_HOST_AUTH_METHOD: ${{ matrix.dbauth }}
105106

106107
# Run package tests on macOS against supported PSQL versions
107108
macos:
@@ -138,6 +139,7 @@ jobs:
138139
run: swift test --enable-test-discovery --sanitize=thread
139140
env:
140141
POSTGRES_HOSTNAME: 127.0.0.1
141-
POSTGRES_USERNAME: vapor_username
142+
POSTGRES_USER: vapor_username
143+
POSTGRES_DB: postgres
142144
POSTGRES_PASSWORD: vapor_password
143-
POSTGRES_DATABASE: postgres
145+
POSTGRES_HOST_AUTH_METHOD: ${{ matrix.dbauth }}

Package.swift

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ let package = Package(
2222
.product(name: "Logging", package: "swift-log"),
2323
.product(name: "Metrics", package: "swift-metrics"),
2424
.product(name: "NIO", package: "swift-nio"),
25+
.product(name: "NIOTLS", package: "swift-nio"),
26+
.product(name: "NIOFoundationCompat", package: "swift-nio"),
2527
.product(name: "NIOSSL", package: "swift-nio-ssl"),
2628
]),
2729
.testTarget(name: "PostgresNIOTests", dependencies: [
Lines changed: 9 additions & 149 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
1-
import Crypto
21
import NIO
3-
import Logging
42

53
extension PostgresConnection {
64
public func authenticate(
@@ -9,155 +7,17 @@ extension PostgresConnection {
97
password: String? = nil,
108
logger: Logger = .init(label: "codes.vapor.postgres")
119
) -> EventLoopFuture<Void> {
12-
let auth = PostgresAuthenticationRequest(
10+
let authContext = AuthContext(
1311
username: username,
14-
database: database,
15-
password: password
16-
)
17-
return self.send(auth, logger: self.logger)
18-
}
19-
}
20-
21-
// MARK: Private
12+
password: password,
13+
database: database)
14+
let outgoing = PSQLOutgoingEvent.authenticate(authContext)
15+
self.underlying.channel.triggerUserOutboundEvent(outgoing, promise: nil)
2216

23-
private final class PostgresAuthenticationRequest: PostgresRequest {
24-
enum State {
25-
case ready
26-
case saslInitialSent(SASLAuthenticationManager<SASLMechanism.SCRAM.SHA256>)
27-
case saslChallengeResponse(SASLAuthenticationManager<SASLMechanism.SCRAM.SHA256>)
28-
case saslWaitOkay
29-
case done
30-
}
31-
32-
let username: String
33-
let database: String?
34-
let password: String?
35-
var state: State
36-
37-
init(username: String, database: String?, password: String?) {
38-
self.state = .ready
39-
self.username = username
40-
self.database = database
41-
self.password = password
42-
}
43-
44-
func log(to logger: Logger) {
45-
logger.debug("Logging into Postgres db \(self.database ?? "nil") as \(self.username)")
46-
}
47-
48-
func respond(to message: PostgresMessage) throws -> [PostgresMessage]? {
49-
if case .error = message.identifier {
50-
// terminate immediately on error
51-
return nil
52-
}
53-
54-
switch self.state {
55-
case .ready:
56-
switch message.identifier {
57-
case .authentication:
58-
let auth = try PostgresMessage.Authentication(message: message)
59-
switch auth {
60-
case .md5(let salt):
61-
let pwdhash = self.md5((self.password ?? "") + self.username).hexdigest()
62-
let hash = "md5" + self.md5(self.bytes(pwdhash) + salt).hexdigest()
63-
return try [PostgresMessage.Password(string: hash).message()]
64-
case .plaintext:
65-
return try [PostgresMessage.Password(string: self.password ?? "").message()]
66-
case .saslMechanisms(let saslMechanisms):
67-
if saslMechanisms.contains("SCRAM-SHA-256") && self.password != nil {
68-
let saslManager = SASLAuthenticationManager(asClientSpeaking:
69-
SASLMechanism.SCRAM.SHA256(username: self.username, password: { self.password! }))
70-
var message: PostgresMessage?
71-
72-
if (try saslManager.handle(message: nil, sender: { bytes in
73-
message = try PostgresMessage.SASLInitialResponse(mechanism: "SCRAM-SHA-256", initialData: bytes).message()
74-
})) {
75-
self.state = .saslWaitOkay
76-
} else {
77-
self.state = .saslInitialSent(saslManager)
78-
}
79-
return [message].compactMap { $0 }
80-
} else {
81-
throw PostgresError.protocol("Unable to authenticate with any available SASL mechanism: \(saslMechanisms)")
82-
}
83-
case .saslContinue, .saslFinal:
84-
throw PostgresError.protocol("Unexpected SASL response to start message: \(message)")
85-
case .ok:
86-
self.state = .done
87-
return []
88-
}
89-
default: throw PostgresError.protocol("Unexpected response to start message: \(message)")
90-
}
91-
case .saslInitialSent(let manager),
92-
.saslChallengeResponse(let manager):
93-
switch message.identifier {
94-
case .authentication:
95-
let auth = try PostgresMessage.Authentication(message: message)
96-
switch auth {
97-
case .saslContinue(let data), .saslFinal(let data):
98-
var message: PostgresMessage?
99-
if try manager.handle(message: data, sender: { bytes in
100-
message = try PostgresMessage.SASLResponse(responseData: bytes).message()
101-
}) {
102-
self.state = .saslWaitOkay
103-
} else {
104-
self.state = .saslChallengeResponse(manager)
105-
}
106-
return [message].compactMap { $0 }
107-
default: throw PostgresError.protocol("Unexpected response during SASL negotiation: \(message)")
108-
}
109-
default: throw PostgresError.protocol("Unexpected response during SASL negotiation: \(message)")
110-
}
111-
case .saslWaitOkay:
112-
switch message.identifier {
113-
case .authentication:
114-
let auth = try PostgresMessage.Authentication(message: message)
115-
switch auth {
116-
case .ok:
117-
self.state = .done
118-
return []
119-
default: throw PostgresError.protocol("Unexpected response while waiting for post-SASL ok: \(message)")
120-
}
121-
default: throw PostgresError.protocol("Unexpected response while waiting for post-SASL ok: \(message)")
122-
}
123-
case .done:
124-
switch message.identifier {
125-
case .parameterStatus:
126-
// self.status[status.parameter] = status.value
127-
return []
128-
case .backendKeyData:
129-
// self.processID = data.processID
130-
// self.secretKey = data.secretKey
131-
return []
132-
case .readyForQuery:
133-
return nil
134-
default: throw PostgresError.protocol("Unexpected response to password authentication: \(message)")
135-
}
17+
return self.underlying.channel.pipeline.handler(type: PSQLEventsHandler.self).flatMap { handler in
18+
handler.authenticateFuture
19+
}.flatMapErrorThrowing { error in
20+
throw error.asAppropriatePostgresError
13621
}
137-
138-
}
139-
140-
func start() throws -> [PostgresMessage] {
141-
return try [
142-
PostgresMessage.Startup.versionThree(parameters: [
143-
"user": self.username,
144-
"database": self.database ?? username
145-
]).message()
146-
]
147-
}
148-
149-
// MARK: Private
150-
151-
private func md5(_ string: String) -> [UInt8] {
152-
return md5(self.bytes(string))
153-
}
154-
155-
private func md5(_ message: [UInt8]) -> [UInt8] {
156-
let digest = Insecure.MD5.hash(data: message)
157-
return .init(digest)
158-
}
159-
160-
func bytes(_ string: String) -> [UInt8] {
161-
return Array(string.utf8)
16222
}
16323
}
Lines changed: 20 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import Logging
21
import NIO
32

43
extension PostgresConnection {
@@ -9,47 +8,26 @@ extension PostgresConnection {
98
logger: Logger = .init(label: "codes.vapor.postgres"),
109
on eventLoop: EventLoop
1110
) -> EventLoopFuture<PostgresConnection> {
12-
let bootstrap = ClientBootstrap(group: eventLoop)
13-
.channelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1)
14-
return bootstrap.connect(to: socketAddress).flatMap { channel in
15-
return channel.pipeline.addHandlers([
16-
ByteToMessageHandler(PostgresMessageDecoder(logger: logger)),
17-
MessageToByteHandler(PostgresMessageEncoder(logger: logger)),
18-
PostgresRequestHandler(logger: logger),
19-
PostgresErrorHandler(logger: logger)
20-
]).map {
21-
return PostgresConnection(channel: channel, logger: logger)
22-
}
23-
}.flatMap { (conn: PostgresConnection) in
24-
if let tlsConfiguration = tlsConfiguration {
25-
return conn.requestTLS(
26-
using: tlsConfiguration,
27-
serverHostname: serverHostname,
28-
logger: logger
29-
).flatMapError { error in
30-
conn.close().flatMapThrowing {
31-
throw error
32-
}
33-
}.map { conn }
34-
} else {
35-
return eventLoop.makeSucceededFuture(conn)
36-
}
11+
12+
let coders = PSQLConnection.Configuration.Coders(
13+
jsonEncoder: PostgresJSONEncoderWrapper(_defaultJSONEncoder),
14+
jsonDecoder: PostgresJSONDecoderWrapper(_defaultJSONDecoder)
15+
)
16+
17+
let configuration = PSQLConnection.Configuration(
18+
connection: .resolved(address: socketAddress, serverName: serverHostname),
19+
authentication: nil,
20+
tlsConfiguration: tlsConfiguration,
21+
coders: coders)
22+
23+
return PSQLConnection.connect(
24+
configuration: configuration,
25+
logger: logger,
26+
on: eventLoop
27+
).map { connection in
28+
PostgresConnection(underlying: connection, logger: logger)
29+
}.flatMapErrorThrowing { error in
30+
throw error.asAppropriatePostgresError
3731
}
3832
}
3933
}
40-
41-
42-
private final class PostgresErrorHandler: ChannelInboundHandler {
43-
typealias InboundIn = Never
44-
45-
let logger: Logger
46-
init(logger: Logger) {
47-
self.logger = logger
48-
}
49-
50-
func errorCaught(context: ChannelHandlerContext, error: Error) {
51-
self.logger.error("Uncaught error: \(error)")
52-
context.close(promise: nil)
53-
context.fireErrorCaught(error)
54-
}
55-
}

0 commit comments

Comments
 (0)