Skip to content

Commit 1cd8d36

Browse files
authored
Add support for Network.framework (vapor#253)
1 parent c7edb9b commit 1cd8d36

File tree

3 files changed

+50
-2
lines changed

3 files changed

+50
-2
lines changed

Diff for: Package.swift

+2
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ let package = Package(
1414
],
1515
dependencies: [
1616
.package(url: "https://github.com/apple/swift-nio.git", from: "2.35.0"),
17+
.package(url: "https://github.com/apple/swift-nio-transport-services.git", from: "1.11.4"),
1718
.package(url: "https://github.com/apple/swift-nio-ssl.git", from: "2.14.1"),
1819
.package(url: "https://github.com/apple/swift-crypto.git", "1.0.0" ..< "3.0.0"),
1920
.package(url: "https://github.com/apple/swift-metrics.git", from: "2.0.0"),
@@ -27,6 +28,7 @@ let package = Package(
2728
.product(name: "NIO", package: "swift-nio"),
2829
.product(name: "NIOCore", package: "swift-nio"),
2930
.product(name: "NIOPosix", package: "swift-nio"),
31+
.product(name: "NIOTransportServices", package: "swift-nio-transport-services"),
3032
.product(name: "NIOTLS", package: "swift-nio"),
3133
.product(name: "NIOSSL", package: "swift-nio-ssl"),
3234
.product(name: "NIOFoundationCompat", package: "swift-nio"),

Diff for: Sources/PostgresNIO/Connection/PostgresConnection.swift

+23-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
import NIOCore
22
import NIOConcurrencyHelpers
3+
#if canImport(Network)
4+
import NIOTransportServices
5+
#endif
36
import NIOSSL
47
import Logging
58
import NIOPosix
@@ -249,12 +252,13 @@ public final class PostgresConnection {
249252
// thread and the EventLoop.
250253
return eventLoop.flatSubmit { () -> EventLoopFuture<PostgresConnection> in
251254
let connectFuture: EventLoopFuture<Channel>
255+
let bootstrap = self.makeBootstrap(on: eventLoop, configuration: configuration)
252256

253257
switch configuration.connection {
254258
case .resolved(let address, _):
255-
connectFuture = ClientBootstrap(group: eventLoop).connect(to: address)
259+
connectFuture = bootstrap.connect(to: address)
256260
case .unresolved(let host, let port):
257-
connectFuture = ClientBootstrap(group: eventLoop).connect(host: host, port: port)
261+
connectFuture = bootstrap.connect(host: host, port: port)
258262
}
259263

260264
return connectFuture.flatMap { channel -> EventLoopFuture<PostgresConnection> in
@@ -271,6 +275,23 @@ public final class PostgresConnection {
271275
}
272276
}
273277

278+
static func makeBootstrap(
279+
on eventLoop: EventLoop,
280+
configuration: PostgresConnection.InternalConfiguration
281+
) -> NIOClientTCPBootstrapProtocol {
282+
#if canImport(Network)
283+
if let tsBootstrap = NIOTSConnectionBootstrap(validatingGroup: eventLoop) {
284+
return tsBootstrap
285+
}
286+
#endif
287+
288+
if let nioBootstrap = ClientBootstrap(validatingGroup: eventLoop) {
289+
return nioBootstrap
290+
}
291+
292+
fatalError("No matching bootstrap found")
293+
}
294+
274295
// MARK: Query
275296

276297
func query(_ query: PostgresQuery, logger: Logger) -> EventLoopFuture<PSQLRowStream> {

Diff for: Tests/IntegrationTests/AsyncTests.swift

+25
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
import Logging
22
import XCTest
33
import PostgresNIO
4+
#if canImport(Network)
5+
import NIOTransportServices
6+
#endif
47

58
#if swift(>=5.5.2)
69
final class AsyncPostgresConnectionTests: XCTestCase {
@@ -41,6 +44,28 @@ final class AsyncPostgresConnectionTests: XCTestCase {
4144
XCTAssertEqual(counter, end + 1)
4245
}
4346
}
47+
48+
#if canImport(Network)
49+
func testSelect10kRowsNetworkFramework() async throws {
50+
let eventLoopGroup = NIOTSEventLoopGroup()
51+
defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) }
52+
let eventLoop = eventLoopGroup.next()
53+
54+
let start = 1
55+
let end = 10000
56+
57+
try await withTestConnection(on: eventLoop) { connection in
58+
let rows = try await connection.query("SELECT generate_series(\(start), \(end));", logger: .psqlTest)
59+
var counter = 1
60+
for try await element in rows.decode(Int.self, context: .default) {
61+
XCTAssertEqual(element, counter)
62+
counter += 1
63+
}
64+
65+
XCTAssertEqual(counter, end + 1)
66+
}
67+
}
68+
#endif
4469
}
4570

4671
extension XCTestCase {

0 commit comments

Comments
 (0)