Skip to content

Commit 2905779

Browse files
authored
Land PostgresClient that is backed by a ConnectionPool as SPI (vapor#430)
1 parent add68a0 commit 2905779

File tree

11 files changed

+764
-14
lines changed

11 files changed

+764
-14
lines changed

Sources/ConnectionPoolModule/PoolStateMachine+ConnectionGroup.swift

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,7 @@ extension PoolStateMachine {
308308
}
309309

310310
@inlinable
311-
mutating func parkConnection(at index: Int) -> Max2Sequence<ConnectionTimer> {
311+
mutating func parkConnection(at index: Int, hasBecomeIdle newIdle: Bool) -> Max2Sequence<ConnectionTimer> {
312312
let scheduleIdleTimeoutTimer: Bool
313313
switch index {
314314
case 0..<self.minimumConcurrentConnections:
@@ -318,7 +318,7 @@ extension PoolStateMachine {
318318

319319
case self.minimumConcurrentConnections..<self.maximumConcurrentConnectionSoftLimit:
320320
// if a connection is a demand connection, we want a timeout timer
321-
scheduleIdleTimeoutTimer = true
321+
scheduleIdleTimeoutTimer = newIdle
322322

323323
case self.maximumConcurrentConnectionSoftLimit..<self.maximumConcurrentConnectionHardLimit:
324324
preconditionFailure("Overflow connections should never be parked.")
@@ -626,8 +626,11 @@ extension PoolStateMachine {
626626

627627
case self.minimumConcurrentConnections..<self.maximumConcurrentConnectionSoftLimit:
628628
// the connection to be removed is a demand connection
629+
self.connections.swapAt(indexToDelete, lastConnectedIndex)
630+
self.removeO1(lastConnectedIndex)
631+
629632
switch lastConnectedIndex {
630-
case self.minimumConcurrentConnections..<self.maximumConcurrentConnectionSoftLimit:
633+
case self.maximumConcurrentConnectionSoftLimit..<self.maximumConcurrentConnectionHardLimit:
631634
// an overflow connection was moved to a demand connection. It has to be currently leased
632635
precondition(self.connections[indexToDelete].isLeased)
633636
return nil

Sources/ConnectionPoolModule/PoolStateMachine.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -435,8 +435,8 @@ struct PoolStateMachine<
435435
case .leased:
436436
return .none()
437437

438-
case .idle:
439-
let timers = self.connections.parkConnection(at: index).map(self.mapTimers)
438+
case .idle(_, let newIdle):
439+
let timers = self.connections.parkConnection(at: index, hasBecomeIdle: newIdle).map(self.mapTimers)
440440

441441
return .init(
442442
request: .none,

Sources/PostgresNIO/Connection/PostgresConnection.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ public final class PostgresConnection: @unchecked Sendable {
4444
return !self.channel.isActive
4545
}
4646

47-
let id: ID
47+
public let id: ID
4848

4949
private var _logger: Logger
5050

@@ -391,7 +391,7 @@ extension PostgresConnection {
391391
self.channel.triggerUserOutboundEvent(PSQLOutgoingEvent.gracefulShutdown, promise: promise)
392392
return try await promise.futureResult.get()
393393
} onCancel: {
394-
_ = self.close()
394+
self.close()
395395
}
396396
}
397397

Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -985,7 +985,7 @@ extension ConnectionStateMachine {
985985
}
986986

987987
return false
988-
case .clientClosedConnection:
988+
case .clientClosedConnection, .poolClosed:
989989
preconditionFailure("A pure client error was thrown directly in PostgresConnection, this shouldn't happen")
990990
case .serverClosedConnection:
991991
return true

Sources/PostgresNIO/New/PSQLError.swift

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ public struct PSQLError: Error {
2525

2626
case listenFailed
2727
case unlistenFailed
28+
case poolClosed
2829
}
2930

3031
internal var base: Base
@@ -33,22 +34,25 @@ public struct PSQLError: Error {
3334
self.base = base
3435
}
3536

36-
public static let sslUnsupported = Self.init(.sslUnsupported)
37+
public static let sslUnsupported = Self(.sslUnsupported)
3738
public static let failedToAddSSLHandler = Self(.failedToAddSSLHandler)
3839
public static let receivedUnencryptedDataAfterSSLRequest = Self(.receivedUnencryptedDataAfterSSLRequest)
3940
public static let server = Self(.server)
4041
public static let messageDecodingFailure = Self(.messageDecodingFailure)
4142
public static let unexpectedBackendMessage = Self(.unexpectedBackendMessage)
4243
public static let unsupportedAuthMechanism = Self(.unsupportedAuthMechanism)
4344
public static let authMechanismRequiresPassword = Self(.authMechanismRequiresPassword)
44-
public static let saslError = Self.init(.saslError)
45+
public static let saslError = Self(.saslError)
4546
public static let invalidCommandTag = Self(.invalidCommandTag)
4647
public static let queryCancelled = Self(.queryCancelled)
4748
public static let tooManyParameters = Self(.tooManyParameters)
4849
public static let clientClosedConnection = Self(.clientClosedConnection)
4950
public static let serverClosedConnection = Self(.serverClosedConnection)
5051
public static let connectionError = Self(.connectionError)
51-
public static let uncleanShutdown = Self.init(.uncleanShutdown)
52+
53+
public static let uncleanShutdown = Self(.uncleanShutdown)
54+
public static let poolClosed = Self(.poolClosed)
55+
5256
public static let listenFailed = Self.init(.listenFailed)
5357
public static let unlistenFailed = Self.init(.unlistenFailed)
5458

@@ -92,6 +96,8 @@ public struct PSQLError: Error {
9296
return "connectionError"
9397
case .uncleanShutdown:
9498
return "uncleanShutdown"
99+
case .poolClosed:
100+
return "poolClosed"
95101
case .listenFailed:
96102
return "listenFailed"
97103
case .unlistenFailed:
@@ -457,6 +463,10 @@ public struct PSQLError: Error {
457463
case sspi
458464
case sasl(mechanisms: [String])
459465
}
466+
467+
static var poolClosed: PSQLError {
468+
Self.init(code: .poolClosed)
469+
}
460470
}
461471

462472
extension PSQLError: CustomStringConvertible {
Lines changed: 206 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,206 @@
1+
import Logging
2+
import NIOConcurrencyHelpers
3+
import NIOCore
4+
import NIOSSL
5+
6+
@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *)
7+
final class ConnectionFactory: Sendable {
8+
9+
struct ConfigCache: Sendable {
10+
var config: PostgresClient.Configuration
11+
}
12+
13+
let configBox: NIOLockedValueBox<ConfigCache>
14+
15+
struct SSLContextCache: Sendable {
16+
enum State {
17+
case none
18+
case producing(TLSConfiguration, [CheckedContinuation<NIOSSLContext, any Error>])
19+
case cached(TLSConfiguration, NIOSSLContext)
20+
case failed(TLSConfiguration, any Error)
21+
}
22+
23+
var state: State = .none
24+
}
25+
26+
let sslContextBox = NIOLockedValueBox(SSLContextCache())
27+
28+
let eventLoopGroup: any EventLoopGroup
29+
30+
let logger: Logger
31+
32+
init(config: PostgresClient.Configuration, eventLoopGroup: any EventLoopGroup, logger: Logger) {
33+
self.eventLoopGroup = eventLoopGroup
34+
self.configBox = NIOLockedValueBox(ConfigCache(config: config))
35+
self.logger = logger
36+
}
37+
38+
func makeConnection(_ connectionID: PostgresConnection.ID, pool: PostgresClient.Pool) async throws -> PostgresConnection {
39+
let config = try await self.makeConnectionConfig()
40+
41+
var connectionLogger = self.logger
42+
connectionLogger[postgresMetadataKey: .connectionID] = "\(connectionID)"
43+
44+
return try await PostgresConnection.connect(
45+
on: self.eventLoopGroup.any(),
46+
configuration: config,
47+
id: connectionID,
48+
logger: connectionLogger
49+
).get()
50+
}
51+
52+
func makeConnectionConfig() async throws -> PostgresConnection.Configuration {
53+
let config = self.configBox.withLockedValue { $0.config }
54+
55+
let tls: PostgresConnection.Configuration.TLS
56+
switch config.tls.base {
57+
case .prefer(let tlsConfiguration):
58+
let sslContext = try await self.getSSLContext(for: tlsConfiguration)
59+
tls = .prefer(sslContext)
60+
61+
case .require(let tlsConfiguration):
62+
let sslContext = try await self.getSSLContext(for: tlsConfiguration)
63+
tls = .require(sslContext)
64+
case .disable:
65+
tls = .disable
66+
}
67+
68+
var connectionConfig: PostgresConnection.Configuration
69+
switch config.endpointInfo {
70+
case .bindUnixDomainSocket(let path):
71+
connectionConfig = PostgresConnection.Configuration(
72+
unixSocketPath: path,
73+
username: config.username,
74+
password: config.password,
75+
database: config.database
76+
)
77+
78+
case .connectTCP(let host, let port):
79+
connectionConfig = PostgresConnection.Configuration(
80+
host: host,
81+
port: port,
82+
username: config.username,
83+
password: config.password,
84+
database: config.database,
85+
tls: tls
86+
)
87+
}
88+
89+
connectionConfig.options.connectTimeout = TimeAmount(config.options.connectTimeout)
90+
connectionConfig.options.tlsServerName = config.options.tlsServerName
91+
connectionConfig.options.requireBackendKeyData = config.options.requireBackendKeyData
92+
93+
return connectionConfig
94+
}
95+
96+
private func getSSLContext(for tlsConfiguration: TLSConfiguration) async throws -> NIOSSLContext {
97+
enum Action {
98+
case produce
99+
case succeed(NIOSSLContext)
100+
case fail(any Error)
101+
case wait
102+
}
103+
104+
return try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<NIOSSLContext, any Error>) in
105+
let action = self.sslContextBox.withLockedValue { cache -> Action in
106+
switch cache.state {
107+
case .none:
108+
cache.state = .producing(tlsConfiguration, [continuation])
109+
return .produce
110+
111+
case .cached(let cachedTLSConfiguration, let context):
112+
if cachedTLSConfiguration.bestEffortEquals(tlsConfiguration) {
113+
return .succeed(context)
114+
} else {
115+
cache.state = .producing(tlsConfiguration, [continuation])
116+
return .produce
117+
}
118+
119+
case .failed(let cachedTLSConfiguration, let error):
120+
if cachedTLSConfiguration.bestEffortEquals(tlsConfiguration) {
121+
return .fail(error)
122+
} else {
123+
cache.state = .producing(tlsConfiguration, [continuation])
124+
return .produce
125+
}
126+
127+
case .producing(let cachedTLSConfiguration, var continuations):
128+
continuations.append(continuation)
129+
if cachedTLSConfiguration.bestEffortEquals(tlsConfiguration) {
130+
cache.state = .producing(cachedTLSConfiguration, continuations)
131+
return .wait
132+
} else {
133+
cache.state = .producing(tlsConfiguration, continuations)
134+
return .produce
135+
}
136+
}
137+
}
138+
139+
switch action {
140+
case .wait:
141+
break
142+
143+
case .produce:
144+
// TBD: we might want to consider moving this off the concurrent executor
145+
self.reportProduceSSLContextResult(
146+
Result(catching: {try NIOSSLContext(configuration: tlsConfiguration)}),
147+
for: tlsConfiguration
148+
)
149+
150+
case .succeed(let context):
151+
continuation.resume(returning: context)
152+
153+
case .fail(let error):
154+
continuation.resume(throwing: error)
155+
}
156+
}
157+
}
158+
159+
private func reportProduceSSLContextResult(_ result: Result<NIOSSLContext, any Error>, for tlsConfiguration: TLSConfiguration) {
160+
enum Action {
161+
case fail(any Error, [CheckedContinuation<NIOSSLContext, any Error>])
162+
case succeed(NIOSSLContext, [CheckedContinuation<NIOSSLContext, any Error>])
163+
case none
164+
}
165+
166+
let action = self.sslContextBox.withLockedValue { cache -> Action in
167+
switch cache.state {
168+
case .none:
169+
preconditionFailure("Invalid state: \(cache.state)")
170+
171+
case .cached, .failed:
172+
return .none
173+
174+
case .producing(let cachedTLSConfiguration, let continuations):
175+
if cachedTLSConfiguration.bestEffortEquals(tlsConfiguration) {
176+
switch result {
177+
case .success(let context):
178+
cache.state = .cached(cachedTLSConfiguration, context)
179+
return .succeed(context, continuations)
180+
181+
case .failure(let failure):
182+
cache.state = .failed(cachedTLSConfiguration, failure)
183+
return .fail(failure, continuations)
184+
}
185+
} else {
186+
return .none
187+
}
188+
}
189+
}
190+
191+
switch action {
192+
case .none:
193+
break
194+
195+
case .succeed(let context, let continuations):
196+
for continuation in continuations {
197+
continuation.resume(returning: context)
198+
}
199+
200+
case .fail(let error, let continuations):
201+
for continuation in continuations {
202+
continuation.resume(throwing: error)
203+
}
204+
}
205+
}
206+
}

0 commit comments

Comments
 (0)