Skip to content

Commit ccd2da4

Browse files
committed
Review fixes.
1 parent 2d78bdc commit ccd2da4

File tree

3 files changed

+59
-44
lines changed

3 files changed

+59
-44
lines changed

Sources/PostgreSQL/Connection/PostgreSQLConnection+Connect.swift

Lines changed: 40 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -3,39 +3,50 @@ import NIO
33
import NIOOpenSSL
44

55
extension PostgreSQLConnection {
6-
/// Connects to a PostgreSQL server using a TCP socket.
6+
@available(*, deprecated, message: "Use `.connect(to:...)` instead.")
77
public static func connect(
8-
to serverAddress: PostgreSQLDatabaseConfig.ServerAddress = .default,
8+
hostname: String = "localhost",
9+
port: Int = 5432,
910
transport: PostgreSQLTransportConfig = .cleartext,
1011
on worker: Worker,
1112
onError: @escaping (Error) -> ()
12-
) throws -> Future<PostgreSQLConnection> {
13-
let handler = QueueHandler<PostgreSQLMessage, PostgreSQLMessage>(on: worker, onError: onError)
14-
let bootstrap = ClientBootstrap(group: worker.eventLoop)
15-
// Enable SO_REUSEADDR.
16-
.channelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1)
17-
.channelInitializer { channel in
18-
return channel.pipeline.addPostgreSQLClientHandlers().then {
19-
channel.pipeline.add(handler: handler)
20-
}
21-
}
22-
23-
let connectedBootstrap: Future<Channel>
24-
switch serverAddress {
25-
case let .tcp(hostname, port):
26-
connectedBootstrap = bootstrap.connect(host: hostname, port: port)
27-
case let .unixSocket(socketPath):
28-
connectedBootstrap = bootstrap.connect(unixDomainSocketPath: socketPath)
29-
}
30-
31-
return connectedBootstrap.flatMap { channel in
32-
let connection = PostgreSQLConnection(queue: handler, channel: channel)
33-
if case .tls(let tlsConfiguration) = transport.method {
34-
return connection.addSSLClientHandler(using: tlsConfiguration).transform(to: connection)
35-
} else {
36-
return worker.eventLoop.newSucceededFuture(result: connection)
37-
}
38-
}
13+
) throws -> Future<PostgreSQLConnection> {
14+
return connect(to: .tcp(hostname: hostname, port: port), transport: transport, on: worker, onError: onError)
15+
}
16+
17+
/// Connects to a PostgreSQL server using a TCP socket.
18+
public static func connect(
19+
to serverAddress: PostgreSQLDatabaseConfig.ServerAddress = .default,
20+
transport: PostgreSQLTransportConfig = .cleartext,
21+
on worker: Worker,
22+
onError: @escaping (Error) -> ()
23+
) throws -> Future<PostgreSQLConnection> {
24+
let handler = QueueHandler<PostgreSQLMessage, PostgreSQLMessage>(on: worker, onError: onError)
25+
let bootstrap = ClientBootstrap(group: worker.eventLoop)
26+
// Enable SO_REUSEADDR.
27+
.channelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1)
28+
.channelInitializer { channel in
29+
return channel.pipeline.addPostgreSQLClientHandlers().then {
30+
channel.pipeline.add(handler: handler)
31+
}
32+
}
33+
34+
let connectedBootstrap: Future<Channel>
35+
switch serverAddress {
36+
case let .tcp(hostname, port):
37+
connectedBootstrap = bootstrap.connect(host: hostname, port: port)
38+
case let .unixSocket(socketPath):
39+
connectedBootstrap = bootstrap.connect(unixDomainSocketPath: socketPath)
40+
}
41+
42+
return connectedBootstrap.flatMap { channel in
43+
let connection = PostgreSQLConnection(queue: handler, channel: channel)
44+
if case .tls(let tlsConfiguration) = transport.method {
45+
return connection.addSSLClientHandler(using: tlsConfiguration).transform(to: connection)
46+
} else {
47+
return worker.eventLoop.newSucceededFuture(result: connection)
48+
}
49+
}
3950
}
4051
}
4152

Sources/PostgreSQL/Database/PostgreSQLDatabaseConfig.swift

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,20 @@ public struct PostgreSQLDatabaseConfig {
77
public static func `default`() -> PostgreSQLDatabaseConfig {
88
return .init(hostname: "localhost", port: 5432, username: "postgres")
99
}
10-
11-
public enum ServerAddress {
12-
case tcp(hostname: String, port: Int)
13-
case unixSocket(path: String)
14-
15-
public static let `default` = ServerAddress.tcp(hostname: "localhost", port: 5432)
16-
public static let defaultViaSocket = ServerAddress.unixSocket(path: "/tmp/.s.PGSQL.5432")
17-
}
18-
19-
public let serverAddress: ServerAddress
10+
11+
/// Specifies how to connect to a PostgreSQL server.
12+
public enum ServerAddress {
13+
/// Connect via TCP using the given hostname and port.
14+
case tcp(hostname: String, port: Int)
15+
/// Connect via a Unix domain socket file.
16+
case unixSocket(path: String)
17+
18+
public static let `default` = ServerAddress.tcp(hostname: "localhost", port: 5432)
19+
public static let socketDefault = ServerAddress.unixSocket(path: "/tmp/.s.PGSQL.5432")
20+
}
21+
22+
/// Which server to connect to.
23+
public let serverAddress: ServerAddress
2024

2125
/// Username to authenticate.
2226
public let username: String

Tests/PostgreSQLTests/PostgreSQLConnectionTests.swift

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -493,18 +493,18 @@ extension PostgreSQLConnection {
493493
/// Creates a test event loop and psql client over ssl.
494494
static func makeTest(transport: PostgreSQLTransportConfig) throws -> PostgreSQLConnection {
495495
#if os(macOS)
496-
return try _makeTest(serverAddress: .tcp(hostname: "192.168.99.100", port: transport.isTLS ? 5433 : 5432), password: "vapor_password", transport: transport)
496+
return try _makeTest(serverAddress: .tcp(hostname: "192.168.99.100", port: transport.isTLS ? 5433 : 5432), password: "vapor_password", transport: transport)
497497
#else
498-
return try _makeTest(serverAddress: .tcp(hostname: transport.isTLS ? "tls" : "cleartext", port: 5432), password: "vapor_password", transport: transport)
498+
return try _makeTest(serverAddress: .tcp(hostname: transport.isTLS ? "tls" : "cleartext", port: 5432), password: "vapor_password", transport: transport)
499499
#endif
500500
}
501501

502502
/// Creates a test connection.
503-
private static func _makeTest(serverAddress: PostgreSQLDatabaseConfig.ServerAddress, password: String? = nil, transport: PostgreSQLTransportConfig = .cleartext) throws -> PostgreSQLConnection {
503+
private static func _makeTest(serverAddress: PostgreSQLDatabaseConfig.ServerAddress, password: String? = nil, transport: PostgreSQLTransportConfig = .cleartext) throws -> PostgreSQLConnection {
504504
let group = MultiThreadedEventLoopGroup(numThreads: 1)
505-
let client = try PostgreSQLConnection.connect(to: serverAddress, transport: transport, on: group) { error in
505+
let client = try PostgreSQLConnection.connect(to: serverAddress, transport: transport, on: group) { error in
506506
XCTFail("\(error)")
507-
}.wait()
507+
}.wait()
508508
_ = try client.authenticate(username: "vapor_username", database: "vapor_database", password: password).wait()
509509
return client
510510
}

0 commit comments

Comments
 (0)