Skip to content

Commit 8ac3869

Browse files
author
Andrew Theis
committed
Move SSL handshake to PostgreSQLConnection.connect static method
1 parent 8bc1856 commit 8ac3869

File tree

3 files changed

+18
-10
lines changed

3 files changed

+18
-10
lines changed

Sources/PostgreSQL/Connection/PostgreSQLConnection+TCP.swift

+8-2
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
import Async
22
import NIO
3+
import NIOOpenSSL
34

45
extension PostgreSQLConnection {
56
/// Connects to a Redis server using a TCP socket.
67
public static func connect(
78
hostname: String = "localhost",
89
port: Int = 5432,
10+
tlsConfiguration: TLSConfiguration? = nil,
911
on worker: Worker,
1012
onError: @escaping (Error) -> ()
1113
) throws -> Future<PostgreSQLConnection> {
@@ -19,8 +21,12 @@ extension PostgreSQLConnection {
1921
}
2022
}
2123

22-
return bootstrap.connect(host: hostname, port: port).map(to: PostgreSQLConnection.self) { channel in
23-
return .init(queue: handler, channel: channel)
24+
return bootstrap.connect(host: hostname, port: port).flatMap(to: PostgreSQLConnection.self) { channel in
25+
let connection = PostgreSQLConnection(queue: handler, channel: channel)
26+
if let tlsConfiguration = tlsConfiguration {
27+
return connection.attemptSSLConnection(using: tlsConfiguration).transform(to: connection)
28+
}
29+
return Future.map(on: worker) { connection }
2430
}
2531
}
2632
}

Sources/PostgreSQL/Database/PostgreSQLDatabase.swift

+1-6
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,8 @@ public final class PostgreSQLDatabase: Database, LogSupporting {
1919
public func newConnection(on worker: Worker) -> Future<PostgreSQLConnection> {
2020
let config = self.config
2121
return Future.flatMap(on: worker) {
22-
return try PostgreSQLConnection.connect(hostname: config.hostname, port: config.port, on: worker) { error in
22+
return try PostgreSQLConnection.connect(hostname: config.hostname, port: config.port, tlsConfiguration: config.tlsConfiguration, on: worker) { error in
2323
print("[PostgreSQL] \(error)")
24-
}.flatMap(to: PostgreSQLConnection.self) { client in
25-
if let tlsConfiguration = config.tlsConfiguration {
26-
return client.attemptSSLConnection(using: tlsConfiguration).transform(to: client)
27-
}
28-
return Future.map(on: worker) { client }
2924
}.flatMap(to: PostgreSQLConnection.self) { client in
3025
return client.authenticate(
3126
username: config.username,

Tests/PostgreSQLTests/PostgreSQLConnectionTests.swift

+9-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import Foundation
22
import XCTest
33
import NIO
4+
import NIOOpenSSL
45
import PostgreSQL
56
import Core
67

@@ -11,6 +12,12 @@ class PostgreSQLConnectionTests: XCTestCase {
1112
let results = try client.simpleQuery("SELECT version();").wait()
1213
try XCTAssert(results[0].firstValue(forColumn: "version")?.decode(String.self).contains("10.") == true)
1314
}
15+
16+
func testSSLConnection() throws {
17+
let client = try PostgreSQLConnection.makeTest(tlsConfiguration: .forClient(certificateVerification: .none))
18+
let results = try client.simpleQuery("SELECT version();").wait()
19+
try XCTAssert(results[0].firstValue(forColumn: "version")?.decode(String.self).contains("10.") == true)
20+
}
1421

1522
func testSelectTypes() throws {
1623
let client = try PostgreSQLConnection.makeTest()
@@ -451,15 +458,15 @@ class PostgreSQLConnectionTests: XCTestCase {
451458

452459
extension PostgreSQLConnection {
453460
/// Creates a test event loop and psql client.
454-
static func makeTest() throws -> PostgreSQLConnection {
461+
static func makeTest(tlsConfiguration: TLSConfiguration? = nil) throws -> PostgreSQLConnection {
455462
let hostname: String
456463
#if Xcode
457464
hostname = (try? Process.execute("docker-machine", "ip")) ?? "192.168.99.100"
458465
#else
459466
hostname = "localhost"
460467
#endif
461468
let group = MultiThreadedEventLoopGroup(numThreads: 1)
462-
let client = try PostgreSQLConnection.connect(hostname: hostname, on: group) { error in
469+
let client = try PostgreSQLConnection.connect(hostname: hostname, tlsConfiguration: tlsConfiguration, on: group) { error in
463470
XCTFail("\(error)")
464471
}.wait()
465472
_ = try client.authenticate(username: "vapor_username", database: "vapor_database", password: nil).wait()

0 commit comments

Comments
 (0)