forked from vapor/postgres-nio
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathPostgresClient.swift
491 lines (419 loc) · 20.2 KB
/
PostgresClient.swift
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
import NIOCore
import NIOSSL
import Atomics
import Logging
import ServiceLifecycle
import _ConnectionPoolModule
/// A Postgres client that is backed by an underlying connection pool. Use ``Configuration`` to change the client's
/// behavior.
///
/// > Warning:
/// The client can only lease connections if the user is running the client's ``run()`` method in a long running task:
///
/// ```swift
/// let client = PostgresClient(configuration: configuration)
/// await withTaskGroup(of: Void.self) { taskGroup in
///     taskGroup.addTask {
///         await client.run() // !important
///     }
///
///     do {
///         let rows = try await client.query("SELECT userID, name, age FROM users;")
///         for try await (userID, name, age) in rows.decode((UUID, String, Int).self) {
///             // do something with the values
///         }
///     } catch {
///         // handle errors
///     }
///
///     // shutdown the client
///     taskGroup.cancelAll()
/// }
/// ```
@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *)
public final class PostgresClient: Sendable, ServiceLifecycle.Service {
public struct Configuration: Sendable {
/// Describes whether, and how strictly, the client negotiates TLS with the server.
///
/// Values are created through ``disable``, ``prefer(_:)`` and ``require(_:)``.
public struct TLS: Sendable {
    /// The underlying TLS mode together with the TLS configuration it needs, if any.
    enum Base {
        case disable
        case prefer(NIOSSL.TLSConfiguration)
        case require(NIOSSL.TLSConfiguration)
    }

    var base: Base

    private init(_ base: Base) {
        self.base = base
    }

    /// Do not try to create a TLS connection to the server.
    ///
    /// This is a get-only computed property on purpose: a stored `static var` would be publicly
    /// writable global state on a `Sendable` type, which is not concurrency-safe and would let any
    /// caller change what `.disable` means for everybody else.
    public static var disable: Self {
        Self(.disable)
    }

    /// Try to create a TLS connection to the server. If the server supports TLS, create a TLS connection.
    /// If the server does not support TLS, create an insecure connection.
    ///
    /// - Parameter sslContext: The TLS configuration to use when TLS is negotiated.
    public static func prefer(_ sslContext: NIOSSL.TLSConfiguration) -> Self {
        self.init(.prefer(sslContext))
    }

    /// Try to create a TLS connection to the server. If the server supports TLS, create a TLS connection.
    /// If the server does not support TLS, fail the connection creation.
    ///
    /// - Parameter sslContext: The TLS configuration to use for the TLS connection.
    public static func require(_ sslContext: NIOSSL.TLSConfiguration) -> Self {
        self.init(.require(sslContext))
    }
}
// MARK: Client options
/// Describes general client behavior options. These settings are considered advanced options.
public struct Options: Sendable {
/// A keep-alive behavior for Postgres connections. The ``frequency`` defines after which time an idle
/// connection shall run a keep-alive ``query``.
public struct KeepAliveBehavior: Sendable {
/// The amount of time that shall pass before an idle connection runs a keep-alive ``query``.
public var frequency: Duration
/// The ``query`` that is run on an idle connection after it has been idle for ``frequency``.
public var query: PostgresQuery
/// Create a new `KeepAliveBehavior`.
/// - Parameters:
/// - frequency: The amount of time that shall pass before an idle connection runs a keep-alive `query`.
/// Defaults to `30` seconds.
/// - query: The `query` that is run on an idle connection after it has been idle for `frequency`.
/// Defaults to `SELECT 1;`.
public init(frequency: Duration = .seconds(30), query: PostgresQuery = "SELECT 1;") {
self.frequency = frequency
self.query = query
}
}
/// A timeout for creating a TCP/Unix domain socket connection. Defaults to `10` seconds.
public var connectTimeout: Duration = .seconds(10)
/// The server name to use for certificate validation and SNI (Server Name Indication) when TLS is enabled.
/// Defaults to none (but see below).
///
/// > When set to `nil`:
/// If the connection is made to a server over TCP using
/// ``PostgresConnection/Configuration/init(host:port:username:password:database:tls:)``, the given `host`
/// is used, unless it was an IP address string. If it _was_ an IP, or the connection is made by any other
/// method, SNI is disabled.
public var tlsServerName: String? = nil
/// Whether the connection is required to provide backend key data (internal Postgres stuff).
///
/// This property is provided for compatibility with Amazon RDS Proxy, which requires it to be `false`.
/// If you are not using Amazon RDS Proxy, you should leave this set to `true` (the default).
public var requireBackendKeyData: Bool = true
/// The minimum number of connections that the client shall keep open at any time, even if there is no
/// demand. Defaults to `0`.
///
/// If the open connection count becomes less than ``minimumConnections``, new connections
/// are created immediately. Must be greater or equal to zero and less than ``maximumConnections``.
///
/// Idle connections are kept alive using the ``keepAliveBehavior``.
public var minimumConnections: Int = 0
/// The maximum number of connections that the client may open to the server at any time. Must be greater
/// than ``minimumConnections``. Defaults to `20` connections.
///
/// Connections that are created in response to demand are kept alive for the ``connectionIdleTimeout``
/// before they are dropped.
public var maximumConnections: Int = 20
/// The maximum amount of time that a connection that is not part of the ``minimumConnections`` is kept
/// open without being leased. Defaults to `60` seconds.
public var connectionIdleTimeout: Duration = .seconds(60)
/// The ``KeepAliveBehavior-swift.struct`` to ensure that the underlying TCP connection is still active
/// for idle connections. `nil` means that the client shall not run keep-alive queries to the server. Defaults to a
/// keep-alive query of `SELECT 1;` every `30` seconds.
public var keepAliveBehavior: KeepAliveBehavior? = KeepAliveBehavior()
/// Create an options structure with default values.
///
/// Most users should not need to adjust the defaults.
public init() {}
}
// MARK: - Accessors
/// The hostname this configuration connects to, if it describes a TCP endpoint.
///
/// `nil` for every other kind of endpoint.
public var host: String? {
    switch self.endpointInfo {
    case .connectTCP(let host, _):
        return host
    case .bindUnixDomainSocket:
        return nil
    }
}
/// The TCP port this configuration connects to, if it describes a TCP endpoint.
///
/// `nil` for every other kind of endpoint.
public var port: Int? {
    switch self.endpointInfo {
    case .connectTCP(_, let port):
        return port
    case .bindUnixDomainSocket:
        return nil
    }
}
/// The filesystem path of the Unix domain socket this configuration connects to.
///
/// `nil` when the configuration does not describe a Unix domain socket endpoint.
public var unixSocketPath: String? {
    guard case .bindUnixDomainSocket(let path) = self.endpointInfo else {
        return nil
    }
    return path
}
/// The TLS mode to use for the connection. Valid for all configurations.
///
/// The declared default is `.prefer` with a default client TLS configuration; every initializer of this
/// type assigns an explicit value on top of it.
///
/// See ``TLS-swift.struct``.
public var tls: TLS = .prefer(.makeClientConfiguration())
/// Options for handling the communication channel. Most users don't need to change these.
///
/// See ``Options-swift.struct``.
public var options: Options = .init()
/// The username to connect with.
public var username: String
/// The password, if any, for the user specified by ``username``.
///
/// - Warning: `nil` means "no password provided", whereas `""` (the empty string) is a password of zero
/// length; these are not the same thing.
public var password: String?
/// The name of the database to open.
///
/// - Note: If set to `nil` or an empty string, the provided ``username`` is used.
public var database: String?
// MARK: - Initializers
/// Create a configuration for connecting to a server over TCP, using a hostname and an optional port.
///
/// If you're unsure which kind of connection you want, you almost definitely want this one.
///
/// - Parameters:
///   - host: The hostname to connect to.
///   - port: The TCP port to connect to (defaults to 5432).
///   - username: The username to connect with.
///   - password: The password for `username`, or `nil` if no password is provided.
///   - database: The name of the database to open.
///   - tls: The TLS mode to use.
public init(host: String, port: Int = 5432, username: String, password: String?, database: String?, tls: TLS) {
    let endpoint = EndpointInfo.connectTCP(host: host, port: port)
    self.init(
        endpointInfo: endpoint,
        tls: tls,
        username: username,
        password: password,
        database: database
    )
}
/// Create a configuration for connecting to a server through a UNIX domain socket.
///
/// TLS is always set to ``TLS-swift.struct/disable`` by this initializer.
///
/// - Parameters:
///   - unixSocketPath: The filesystem path of the socket to connect to.
///   - username: The username to connect with.
///   - password: The password for `username`, or `nil` if no password is provided.
///   - database: The name of the database to open.
public init(unixSocketPath: String, username: String, password: String?, database: String?) {
    let endpoint = EndpointInfo.bindUnixDomainSocket(path: unixSocketPath)
    self.init(
        endpointInfo: endpoint,
        tls: .disable,
        username: username,
        password: password,
        database: database
    )
}
// MARK: - Implementation details
/// The kind of endpoint a configuration points at.
enum EndpointInfo {
/// A Unix domain socket located at `path`.
case bindUnixDomainSocket(path: String)
/// A TCP endpoint identified by `host` and `port`.
case connectTCP(host: String, port: Int)
}
/// The endpoint this configuration targets. Backs the ``host``, ``port`` and ``unixSocketPath`` accessors.
var endpointInfo: EndpointInfo
/// Memberwise-style initializer that the public initializers delegate to.
///
/// ``options`` is not a parameter here and keeps its declared default value.
init(endpointInfo: EndpointInfo, tls: TLS, username: String, password: String?, database: String?) {
    // Credentials and target database.
    self.username = username
    self.password = password
    self.database = database

    // Transport.
    self.endpointInfo = endpointInfo
    self.tls = tls
}
}
/// The concrete `ConnectionPool` specialization used by the client.
typealias Pool = ConnectionPool<
PostgresConnection,
PostgresConnection.ID,
ConnectionIDGenerator,
ConnectionRequest<PostgresConnection>,
ConnectionRequest.ID,
PostgresKeepAliveBehavor,
PostgresClientMetrics,
ContinuousClock
>
/// The pool that connections are leased from and released back to.
let pool: Pool
/// The factory used by the pool's connection-creation closure to make new connections.
let factory: ConnectionFactory
/// Flipped to `true` by ``run()``; read when leasing a connection to warn if the client is not running yet.
let runningAtomic = ManagedAtomic(false)
/// Logger for background messages. It is also handed to the connection factory, the keep-alive behavior
/// and the pool's observability delegate when the client is created.
let backgroundLogger: Logger
/// Creates a new ``PostgresClient`` that does not log any background information.
///
/// Don't forget to call ``run()`` on the client in a long running task.
///
/// - Parameters:
///   - configuration: The client's configuration. See ``Configuration`` for details.
///   - eventLoopGroup: The underlying NIO `EventLoopGroup`. Defaults to ``defaultEventLoopGroup``.
public convenience init(
    configuration: Configuration,
    eventLoopGroup: any EventLoopGroup = PostgresClient.defaultEventLoopGroup
) {
    // Delegate to the designated initializer with the no-op logger.
    self.init(
        configuration: configuration,
        eventLoopGroup: eventLoopGroup,
        backgroundLogger: Self.loggingDisabled
    )
}
/// Creates a new ``PostgresClient``.
///
/// Don't forget to call ``run()`` on the client in a long running task.
///
/// - Parameters:
///   - configuration: The client's configuration. See ``Configuration`` for details.
///   - eventLoopGroup: The underlying NIO `EventLoopGroup`. Defaults to ``defaultEventLoopGroup``.
///   - backgroundLogger: A `swift-log` `Logger` to log background messages to. The same logger is passed
///     to the connection factory, the keep-alive behavior and the pool's observability delegate.
public init(
    configuration: Configuration,
    eventLoopGroup: any EventLoopGroup = PostgresClient.defaultEventLoopGroup,
    backgroundLogger: Logger
) {
    // The factory is kept in a local so the pool's creation closure below can capture it
    // without touching `self` before all stored properties are initialized.
    let connectionFactory = ConnectionFactory(
        config: configuration,
        eventLoopGroup: eventLoopGroup,
        logger: backgroundLogger
    )

    self.backgroundLogger = backgroundLogger
    self.factory = connectionFactory
    self.pool = ConnectionPool(
        configuration: .init(configuration),
        idGenerator: ConnectionIDGenerator(),
        requestType: ConnectionRequest<PostgresConnection>.self,
        keepAliveBehavior: .init(configuration.options.keepAliveBehavior, logger: backgroundLogger),
        observabilityDelegate: .init(logger: backgroundLogger),
        clock: ContinuousClock()
    ) { (connectionID, pool) in
        // Each connection is reported to the pool with a stream capacity of one.
        let newConnection = try await connectionFactory.makeConnection(connectionID, pool: pool)
        return ConnectionAndMetadata(connection: newConnection, maximalStreamsOnConnection: 1)
    }
}
/// Lease a connection for the duration of the provided `closure`.
///
/// The connection is handed back to the pool once `closure` has returned or thrown.
///
/// - Parameter closure: A closure that uses the passed `PostgresConnection`. The closure **must not** capture
///   the provided `PostgresConnection`.
/// - Returns: The closure's return value.
/// - Throws: Any error thrown while leasing the connection, or the error thrown by `closure`.
public func withConnection<Result>(_ closure: (PostgresConnection) async throws -> Result) async throws -> Result {
    let leased = try await self.leaseConnection()

    let outcome: Result
    do {
        outcome = try await closure(leased)
    } catch {
        // Give the connection back before propagating the closure's error.
        self.pool.releaseConnection(leased)
        throw error
    }

    self.pool.releaseConnection(leased)
    return outcome
}
/// Run a query on the Postgres server the client is connected to.
///
/// A connection is leased from the pool for the query. It is released again when the query fails or
/// when the returned row sequence finishes.
///
/// - Parameters:
///   - query: The ``PostgresQuery`` to run.
///   - logger: The `Logger` to log into for the query. When `nil`, a no-op logger is used.
///   - file: The file the query was started in. Used for better error reporting.
///   - line: The line the query was started in. Used for better error reporting.
/// - Returns: A ``PostgresRowSequence`` containing the rows the server sent as the query result.
///   The sequence can be discarded.
/// - Throws: A `PSQLError` with code `tooManyParameters` if the query has more than `UInt16.max` binds.
///   Any `PSQLError` thrown here is annotated with `file`, `line` and `query` before it is rethrown.
@discardableResult
public func query(
    _ query: PostgresQuery,
    logger: Logger? = nil,
    file: String = #fileID,
    line: Int = #line
) async throws -> PostgresRowSequence {
    var queryLogger = logger ?? Self.loggingDisabled

    do {
        if query.binds.count > Int(UInt16.max) {
            throw PSQLError(code: .tooManyParameters, query: query, file: file, line: line)
        }

        let connection = try await self.leaseConnection()
        queryLogger[postgresMetadataKey: .connectionID] = "\(connection.id)"

        let streamPromise = connection.channel.eventLoop.makePromise(of: PSQLRowStream.self)
        let context = ExtendedQueryContext(query: query, logger: queryLogger, promise: streamPromise)
        connection.channel.write(HandlerTask.extendedQuery(context), promise: nil)

        let streamFuture = streamPromise.futureResult

        // Failure path: no row sequence will ever finish, so release the connection right away.
        streamFuture.whenFailure { _ in
            self.pool.releaseConnection(connection)
        }

        // Success path: the conversion to an async sequence stays inside the future's `map`,
        // and the connection is released once the sequence finishes.
        let rowsFuture = streamFuture.map { stream in
            stream.asyncSequence(onFinish: {
                self.pool.releaseConnection(connection)
            })
        }
        return try await rowsFuture.get()
    } catch var error as PSQLError {
        error.file = file
        error.line = line
        error.query = query
        throw error // rethrow with more metadata
    }
}
/// Execute a prepared statement, taking care of the preparation when necessary.
///
/// A connection is leased from the pool for the execution. It is released again when the execution
/// fails or when the underlying row sequence finishes.
///
/// - Parameters:
///   - preparedStatement: The statement to execute. Its bindings are created via `makeBindings()`.
///   - logger: The `Logger` to log into for the execution. When `nil`, a no-op logger is used.
///   - file: The file the execution was started in. Used for better error reporting.
///   - line: The line the execution was started in. Used for better error reporting.
/// - Returns: An async sequence of rows, each decoded through the statement's `decodeRow(_:)`.
/// - Throws: Any error thrown by `makeBindings()`. Any `PSQLError` thrown afterwards is annotated with
///   `file`, `line` and a query built from the statement's SQL and bindings before it is rethrown.
public func execute<Statement: PostgresPreparedStatement, Row>(
    _ preparedStatement: Statement,
    logger: Logger? = nil,
    file: String = #fileID,
    line: Int = #line
) async throws -> AsyncThrowingMapSequence<PostgresRowSequence, Row> where Row == Statement.Row {
    let bindings = try preparedStatement.makeBindings()
    let statementLogger = logger ?? Self.loggingDisabled

    do {
        let connection = try await self.leaseConnection()

        let streamPromise = connection.channel.eventLoop.makePromise(of: PSQLRowStream.self)
        let task = HandlerTask.executePreparedStatement(.init(
            name: String(reflecting: Statement.self),
            sql: Statement.sql,
            bindings: bindings,
            bindingDataTypes: Statement.bindingDataTypes,
            logger: statementLogger,
            promise: streamPromise
        ))
        connection.channel.write(task, promise: nil)

        let streamFuture = streamPromise.futureResult

        // Failure path: release the connection right away.
        streamFuture.whenFailure { _ in
            self.pool.releaseConnection(connection)
        }

        // Success path: the conversion to an async sequence stays inside the future's `map`,
        // and the connection is released once the sequence finishes.
        let rows = try await streamFuture
            .map { stream in
                stream.asyncSequence(onFinish: { self.pool.releaseConnection(connection) })
            }
            .get()

        return rows.map { try preparedStatement.decodeRow($0) }
    } catch var error as PSQLError {
        error.file = file
        error.line = line
        error.query = .init(
            unsafeSQL: Statement.sql,
            binds: bindings
        )
        throw error // rethrow with more metadata
    }
}
/// The client's run method. Users must call this function in order to start the client's background task
/// processing, like creating and destroying connections and running timers.
///
/// Leasing a connection emits a warning on the background logger if ``run()`` hasn't been called previously.
///
/// - Precondition: Must be called at most once per client; a second call traps.
public func run() async {
    // The exchange only succeeds for the very first caller, i.e. when the flag was still `false`.
    let (isFirstRun, _) = self.runningAtomic.compareExchange(
        expected: false,
        desired: true,
        ordering: .relaxed
    )
    precondition(isFirstRun, "PostgresClient.run() should just be called once!")

    await cancelOnGracefulShutdown {
        await self.pool.run()
    }
}
// MARK: - Private Methods -
/// Leases a connection from the pool.
///
/// Logs a warning on the background logger if ``run()`` has not been called yet; the lease request is
/// forwarded to the pool either way.
private func leaseConnection() async throws -> PostgresConnection {
    let isRunning = self.runningAtomic.load(ordering: .relaxed)
    if isRunning == false {
        self.backgroundLogger.warning("Trying to lease connection from `PostgresClient`, but `PostgresClient.run()` hasn't been called yet.")
    }

    let connection = try await self.pool.leaseConnection()
    return connection
}
/// The default `EventLoopGroup` used by clients that are created without an explicit group.
///
/// Forwards to `PostgresConnection.defaultEventLoopGroup`.
public static var defaultEventLoopGroup: EventLoopGroup {
    get {
        return PostgresConnection.defaultEventLoopGroup
    }
}

/// A logger backed by a no-op log handler; used whenever the caller does not provide a logger.
static let loggingDisabled = Logger(
    label: "Postgres-do-not-log",
    factory: { _ in SwiftLogNoOpLogHandler() }
)
}
/// Adapts the client's optional keep-alive configuration to the connection pool's
/// `ConnectionKeepAliveBehavior` requirements.
@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *)
struct PostgresKeepAliveBehavor: ConnectionKeepAliveBehavior {
    /// The user-provided keep-alive settings, or `nil` if keep-alive is turned off.
    let behavior: PostgresClient.Configuration.Options.KeepAliveBehavior?
    /// The logger that is passed along with every keep-alive query.
    let logger: Logger

    /// - Parameters:
    ///   - behavior: The keep-alive settings, or `nil` to disable keep-alive.
    ///   - logger: The logger to use for keep-alive queries.
    init(_ behavior: PostgresClient.Configuration.Options.KeepAliveBehavior?, logger: Logger) {
        self.behavior = behavior
        self.logger = logger
    }

    /// The configured keep-alive frequency, or `nil` when no keep-alive behavior is configured.
    var keepAliveFrequency: Duration? {
        self.behavior?.frequency
    }

    /// Runs the configured keep-alive query on `connection` and waits for it to complete.
    ///
    /// - Parameter connection: The connection to run the keep-alive query on.
    /// - Throws: The error the keep-alive query failed with.
    func runKeepAlive(for connection: PostgresConnection) async throws {
        // Without a configured behavior there is no query to run. Returning here, instead of
        // force-unwrapping, keeps an unexpected call from crashing the process.
        guard let keepAliveQuery = self.behavior?.query else {
            return
        }
        try await connection.query(keepAliveQuery, logger: self.logger).map { _ in }.get()
    }
}
@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *)
extension ConnectionPoolConfiguration {
    /// Builds a pool configuration from the pooling-related options of a client configuration.
    ///
    /// The client's `maximumConnections` is used for both the soft and the hard connection limit.
    ///
    /// - Parameter config: The client configuration to read the options from.
    init(_ config: PostgresClient.Configuration) {
        let options = config.options
        let connectionLimit = options.maximumConnections

        // Start from the pool's own defaults, then override what the client exposes.
        self.init()
        self.minimumConnectionCount = options.minimumConnections
        self.maximumConnectionSoftLimit = connectionLimit
        self.maximumConnectionHardLimit = connectionLimit
        self.idleTimeout = options.connectionIdleTimeout
    }
}
// MARK: - PooledConnection

extension PostgresConnection: PooledConnection {
    /// Closes the connection's channel in all directions without waiting for the result.
    public func close() {
        let channel = self.channel
        channel.close(mode: .all, promise: nil)
    }

    /// Registers a callback that is invoked once the connection's close future completes.
    ///
    /// - Parameter closure: Called on completion. It is always invoked with `nil`, regardless of
    ///   whether the close future succeeded or failed.
    public func onClose(_ closure: @escaping ((any Error)?) -> ()) {
        self.closeFuture.whenComplete { _ in
            closure(nil)
        }
    }
}
extension ConnectionPoolError {
    /// Translates pool errors into `PSQLError`s where a matching error exists.
    ///
    /// - Parameter lastConnectError: Not used by the current implementation.
    /// - Returns: `PSQLError.poolClosed` for `.poolShutdown` and `PSQLError.queryCancelled` for
    ///   `.requestCancelled`, each with this error set as `underlying`. Any other pool error is returned as is.
    func mapToPSQLError(lastConnectError: Error?) -> Error {
        switch self {
        case .poolShutdown:
            var mapped = PSQLError.poolClosed
            mapped.underlying = self
            return mapped

        case .requestCancelled:
            var mapped = PSQLError.queryCancelled
            mapped.underlying = self
            return mapped

        default:
            return self
        }
    }
}