Skip to content

Commit f784169

Browse files
author
Andrew Theis
committed
Update based on PR feedback:
- Add PostgreSQLTransportConfig struct to wrap the possible transport methods - Use return worker.eventLoop.newSucceededFuture(result: connection) instead of Future.map
1 parent a45a570 commit f784169

File tree

5 files changed

+66
-17
lines changed

5 files changed

+66
-17
lines changed

Sources/PostgreSQL/Connection/PostgreSQLConnection+TCP.swift

+3-3
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ extension PostgreSQLConnection {
77
public static func connect(
88
hostname: String = "localhost",
99
port: Int = 5432,
10-
tlsConfiguration: TLSConfiguration? = nil,
10+
transportConfig: PostgreSQLTransportConfig = .cleartext,
1111
on worker: Worker,
1212
onError: @escaping (Error) -> ()
1313
) throws -> Future<PostgreSQLConnection> {
@@ -23,10 +23,10 @@ extension PostgreSQLConnection {
2323

2424
return bootstrap.connect(host: hostname, port: port).flatMap(to: PostgreSQLConnection.self) { channel in
2525
let connection = PostgreSQLConnection(queue: handler, channel: channel)
26-
if let tlsConfiguration = tlsConfiguration {
26+
if case .tls(let tlsConfiguration) = transportConfig.method {
2727
return connection.addSSLClientHandler(using: tlsConfiguration).transform(to: connection)
2828
}
29-
return Future.map(on: worker) { connection }
29+
return worker.eventLoop.newSucceededFuture(result: connection)
3030
}
3131
}
3232
}

Sources/PostgreSQL/Database/PostgreSQLDatabase.swift

+1-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ public final class PostgreSQLDatabase: Database, LogSupporting {
1919
public func newConnection(on worker: Worker) -> Future<PostgreSQLConnection> {
2020
let config = self.config
2121
return Future.flatMap(on: worker) {
22-
return try PostgreSQLConnection.connect(hostname: config.hostname, port: config.port, tlsConfiguration: config.tlsConfiguration, on: worker) { error in
22+
return try PostgreSQLConnection.connect(hostname: config.hostname, port: config.port, transportConfig: config.transportConfig, on: worker) { error in
2323
print("[PostgreSQL] \(error)")
2424
}.flatMap(to: PostgreSQLConnection.self) { client in
2525
return client.authenticate(

Sources/PostgreSQL/Database/PostgreSQLDatabaseConfig.swift

+7-7
Original file line numberDiff line numberDiff line change
@@ -24,22 +24,22 @@ public struct PostgreSQLDatabaseConfig {
2424
/// Optional password to use for authentication.
2525
public let password: String?
2626

27-
/// Optional TLSConfiguration. Set this if your PostgreSQL server requires an SSL connection
28-
/// For paid Heroku Postgres plans, set this to `.forClient(certificateVerification: .none)`
29-
public let tlsConfiguration: TLSConfiguration?
27+
/// Configures how data is transported to the server. Use this to enable SSL.
28+
/// See `PostgreSQLTransportConfig` for more info
29+
public let transportConfig: PostgreSQLTransportConfig
3030

3131
/// Creates a new `PostgreSQLDatabaseConfig`.
32-
public init(hostname: String, port: Int = 5432, username: String, database: String? = nil, password: String? = nil, tlsConfiguration: TLSConfiguration? = nil) {
32+
public init(hostname: String, port: Int = 5432, username: String, database: String? = nil, password: String? = nil, transportConfig: PostgreSQLTransportConfig = .cleartext) {
3333
self.hostname = hostname
3434
self.port = port
3535
self.username = username
3636
self.database = database
3737
self.password = password
38-
self.tlsConfiguration = tlsConfiguration
38+
self.transportConfig = transportConfig
3939
}
4040

4141
/// Creates a `PostgreSQLDatabaseConfig` frome a connection string.
42-
public init(url urlString: String, tlsConfiguration: TLSConfiguration? = nil) throws {
42+
public init(url urlString: String, transportConfig: PostgreSQLTransportConfig = .cleartext) throws {
4343
guard let url = URL(string: urlString),
4444
let hostname = url.host,
4545
let port = url.port,
@@ -64,6 +64,6 @@ public struct PostgreSQLDatabaseConfig {
6464
self.database = database
6565
}
6666
self.password = url.password
67-
self.tlsConfiguration = tlsConfiguration
67+
self.transportConfig = transportConfig
6868
}
6969
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import Foundation
2+
import NIOOpenSSL
3+
4+
5+
public struct PostgreSQLTransportConfig {
6+
/// Does not attempt to enable TLS (this is the default)
7+
public static var cleartext: PostgreSQLTransportConfig {
8+
return .init(method: .cleartext)
9+
}
10+
11+
/// Enables TLS requiring a minimum version of TLS v1.1 on the server, but disables certificate verification
12+
/// This is what you would commonly use for paid Heroku PostgreSQL plans
13+
public static var unverifiedTLS: PostgreSQLTransportConfig {
14+
return .init(method: .tls(.forClient(certificateVerification: .none)))
15+
}
16+
17+
/// Enables TLS requiring a minimum version of TLS v1.1 on the server
18+
public static var standardTLS: PostgreSQLTransportConfig {
19+
return .init(method: .tls(.forClient()))
20+
}
21+
22+
/// Enables TLS requiring a minimum version of TLS v1.2 on the server
23+
public static var modernTLS: PostgreSQLTransportConfig {
24+
return .init(method: .tls(.forClient(minimumTLSVersion: .tlsv12)))
25+
}
26+
27+
/// Enables TLS requiring a minimum version of TLS v1.3 on the server
28+
/// TLS v1.3 specification is still a draft and unlikely to be supported by most servers
29+
/// See https://tools.ietf.org/html/draft-ietf-tls-tls13-28 for more info
30+
public static var edgeTLS: PostgreSQLTransportConfig {
31+
return .init(method: .tls(.forClient(minimumTLSVersion: .tlsv13)))
32+
}
33+
34+
/// Enables TLS and allows you to use a set `TLSConfiguration`
35+
public static func customTLS(_ tlsConfiguration: TLSConfiguration)-> PostgreSQLTransportConfig {
36+
return .init(method: .tls(tlsConfiguration))
37+
}
38+
39+
internal enum Method {
40+
case cleartext
41+
case tls(TLSConfiguration)
42+
}
43+
44+
internal let method: Method
45+
46+
internal init(method: Method) {
47+
self.method = method
48+
}
49+
}

Tests/PostgreSQLTests/PostgreSQLConnectionTests.swift

+6-6
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@ class PostgreSQLConnectionTests: XCTestCase {
1313
try XCTAssert(results[0].firstValue(forColumn: "version")?.decode(String.self).contains("10.") == true)
1414
}
1515

16-
func testSSLConnection() throws {
17-
let client = try PostgreSQLConnection.makeTest(tlsConfiguration: .forClient(certificateVerification: .none))
16+
func testUnverifiedSSLConnection() throws {
17+
let client = try PostgreSQLConnection.makeTest(transportConfig: .unverifiedTLS)
1818
let results = try client.simpleQuery("SELECT version();").wait()
1919
try XCTAssert(results[0].firstValue(forColumn: "version")?.decode(String.self).contains("10.") == true)
2020
}
@@ -441,7 +441,7 @@ class PostgreSQLConnectionTests: XCTestCase {
441441
}
442442

443443
static var allTests = [
444-
("testSSLConnection", testSSLConnection),
444+
("testUnverifiedSSLConnection", testUnverifiedSSLConnection),
445445
("testVersion", testVersion),
446446
("testSelectTypes", testSelectTypes),
447447
("testParse", testParse),
@@ -459,7 +459,7 @@ class PostgreSQLConnectionTests: XCTestCase {
459459

460460
extension PostgreSQLConnection {
461461
/// Creates a test event loop and psql client.
462-
static func makeTest(tlsConfiguration: TLSConfiguration? = nil) throws -> PostgreSQLConnection {
462+
static func makeTest(transportConfig: PostgreSQLTransportConfig? = nil) throws -> PostgreSQLConnection {
463463
let hostname: String
464464
#if Xcode
465465
hostname = (try? Process.execute("docker-machine", "ip")) ?? "192.168.99.100"
@@ -469,8 +469,8 @@ extension PostgreSQLConnection {
469469
let group = MultiThreadedEventLoopGroup(numThreads: 1)
470470
var client: PostgreSQLConnection
471471

472-
if let tlsConfiguration = tlsConfiguration {
473-
client = try PostgreSQLConnection.connect(hostname: hostname, port: 5433, tlsConfiguration: tlsConfiguration, on: group) { error in
472+
if let transportConfig = transportConfig {
473+
client = try PostgreSQLConnection.connect(hostname: hostname, port: 5433, transportConfig: transportConfig, on: group) { error in
474474
XCTFail("\(error)")
475475
}.wait()
476476
} else {

0 commit comments

Comments
 (0)