Skip to content

Commit add68a0

Browse files
authored
Ensure pool runs until all connections are closed (vapor#429)
- Ensure pool runs until all connections are closed - Fix an ordering issue in `RequestQueue` - Remove unused `closeConnection` in NewPoolActions
1 parent 468ae25 commit add68a0

File tree

5 files changed

+82
-20
lines changed

5 files changed

+82
-20
lines changed

Diff for: Sources/ConnectionPoolModule/ConnectionPool.swift

+9-6
Original file line numberDiff line numberDiff line change
@@ -306,7 +306,6 @@ public final class ConnectionPool<
306306
@usableFromInline
307307
enum NewPoolActions: Sendable {
308308
case makeConnection(StateMachine.ConnectionRequest)
309-
case closeConnection(Connection)
310309
case runKeepAlive(Connection)
311310

312311
case scheduleTimer(StateMachine.Timer)
@@ -342,9 +341,6 @@ public final class ConnectionPool<
342341
case .runKeepAlive(let connection):
343342
self.runKeepAlive(connection, in: &taskGroup)
344343

345-
case .closeConnection(let connection):
346-
self.closeConnection(connection)
347-
348344
case .scheduleTimer(let timer):
349345
self.runTimer(timer, in: &taskGroup)
350346
}
@@ -427,8 +423,15 @@ public final class ConnectionPool<
427423
do {
428424
let bundle = try await self.factory(request.connectionID, self)
429425
self.connectionEstablished(bundle)
430-
bundle.connection.onClose {
431-
self.connectionDidClose(bundle.connection, error: $0)
426+
427+
// after the connection has been established, we keep the task open. This ensures
428+
// that the pools run method can not be exited before all connections have been
429+
// closed.
430+
await withCheckedContinuation { (continuation: CheckedContinuation<Void, Never>) in
431+
bundle.connection.onClose {
432+
self.connectionDidClose(bundle.connection, error: $0)
433+
continuation.resume()
434+
}
432435
}
433436
} catch {
434437
self.connectionEstablishFailed(error, for: request)

Diff for: Sources/ConnectionPoolModule/PoolStateMachine+RequestQueue.swift

+1-1
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ extension PoolStateMachine {
4444
var result = TinyFastSequence<Request>()
4545
result.reserveCapacity(Int(max))
4646
var popped = 0
47-
while let requestID = self.queue.popFirst(), popped < max {
47+
while popped < max, let requestID = self.queue.popFirst() {
4848
if let requestIndex = self.requests.index(forKey: requestID) {
4949
popped += 1
5050
result.append(self.requests.remove(at: requestIndex).value)

Diff for: Sources/ConnectionPoolModule/PoolStateMachine.swift

+15-9
Original file line numberDiff line numberDiff line change
@@ -355,18 +355,24 @@ struct PoolStateMachine<
355355

356356
@inlinable
357357
mutating func connectionClosed(_ connection: Connection) -> Action {
358-
self.cacheNoMoreConnectionsAllowed = false
358+
switch self.poolState {
359+
case .running, .shuttingDown(graceful: true):
360+
self.cacheNoMoreConnectionsAllowed = false
359361

360-
let closedConnectionAction = self.connections.connectionClosed(connection.id)
362+
let closedConnectionAction = self.connections.connectionClosed(connection.id)
361363

362-
let connectionAction: ConnectionAction
363-
if let newRequest = closedConnectionAction.newConnectionRequest {
364-
connectionAction = .makeConnection(newRequest, closedConnectionAction.timersToCancel)
365-
} else {
366-
connectionAction = .cancelTimers(closedConnectionAction.timersToCancel)
367-
}
364+
let connectionAction: ConnectionAction
365+
if let newRequest = closedConnectionAction.newConnectionRequest {
366+
connectionAction = .makeConnection(newRequest, closedConnectionAction.timersToCancel)
367+
} else {
368+
connectionAction = .cancelTimers(closedConnectionAction.timersToCancel)
369+
}
370+
371+
return .init(request: .none, connection: connectionAction)
368372

369-
return .init(request: .none, connection: connectionAction)
373+
case .shuttingDown(graceful: false), .shutDown:
374+
return .none()
375+
}
370376
}
371377

372378
struct CleanupAction {

Diff for: Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift

+40-4
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
@testable import _ConnectionPoolModule
2+
import Atomics
23
import XCTest
34
import NIOEmbedded
45

@@ -52,7 +53,14 @@ final class ConnectionPoolTests: XCTestCase {
5253
}
5354

5455
taskGroup.cancelAll()
56+
57+
XCTAssertEqual(factory.pendingConnectionAttemptsCount, 0)
58+
for connection in factory.runningConnections {
59+
connection.closeIfClosing()
60+
}
5561
}
62+
63+
XCTAssertEqual(factory.runningConnections.count, 0)
5664
}
5765

5866
func testShutdownPoolWhileConnectionIsBeingCreated() async {
@@ -155,34 +163,62 @@ final class ConnectionPoolTests: XCTestCase {
155163
try await factory.makeConnection(id: $0, for: $1)
156164
}
157165

166+
let hasFinished = ManagedAtomic(false)
167+
let createdConnections = ManagedAtomic(0)
168+
let iterations = 10_000
169+
158170
// the same connection is reused 1000 times
159171

160-
await withThrowingTaskGroup(of: Void.self) { taskGroup in
172+
await withTaskGroup(of: Void.self) { taskGroup in
161173
taskGroup.addTask {
162174
await pool.run()
175+
XCTAssertFalse(hasFinished.compareExchange(expected: false, desired: true, ordering: .relaxed).original)
163176
}
164177

165178
taskGroup.addTask {
166179
var usedConnectionIDs = Set<Int>()
167180
for _ in 0..<config.maximumConnectionHardLimit {
168181
await factory.nextConnectAttempt { connectionID in
169182
XCTAssertTrue(usedConnectionIDs.insert(connectionID).inserted)
183+
createdConnections.wrappingIncrement(ordering: .relaxed)
170184
return 1
171185
}
172186
}
173187

188+
174189
XCTAssertEqual(factory.pendingConnectionAttemptsCount, 0)
175190
}
176191

177-
for _ in 0..<10_000 {
192+
let (stream, continuation) = AsyncStream.makeStream(of: Void.self)
193+
194+
for _ in 0..<iterations {
178195
taskGroup.addTask {
179-
let leasedConnection = try await pool.leaseConnection()
180-
pool.releaseConnection(leasedConnection)
196+
do {
197+
let leasedConnection = try await pool.leaseConnection()
198+
pool.releaseConnection(leasedConnection)
199+
} catch {
200+
XCTFail("Unexpected error: \(error)")
201+
}
202+
continuation.yield()
181203
}
182204
}
183205

206+
var leaseReleaseIterator = stream.makeAsyncIterator()
207+
for _ in 0..<iterations {
208+
_ = await leaseReleaseIterator.next()
209+
}
210+
184211
taskGroup.cancelAll()
212+
213+
XCTAssertFalse(hasFinished.load(ordering: .relaxed))
214+
for connection in factory.runningConnections {
215+
connection.closeIfClosing()
216+
}
185217
}
218+
219+
XCTAssertEqual(createdConnections.load(ordering: .relaxed), config.maximumConnectionHardLimit)
220+
XCTAssert(hasFinished.load(ordering: .relaxed))
221+
XCTAssertEqual(factory.runningConnections.count, 0)
186222
}
187223
}
188224

Diff for: Tests/ConnectionPoolModuleTests/Mocks/MockConnection.swift

+17
Original file line numberDiff line numberDiff line change
@@ -88,12 +88,18 @@ final class MockConnectionFactory<Clock: _Concurrency.Clock> where Clock.Duratio
8888
var attempts = Deque<(ConnectionID, CheckedContinuation<(MockConnection, UInt16), any Error>)>()
8989

9090
var waiter = Deque<CheckedContinuation<(ConnectionID, CheckedContinuation<(MockConnection, UInt16), any Error>), Never>>()
91+
92+
var runningConnections = [ConnectionID: Connection]()
9193
}
9294

9395
var pendingConnectionAttemptsCount: Int {
9496
self.stateBox.withLockedValue { $0.attempts.count }
9597
}
9698

99+
var runningConnections: [Connection] {
100+
self.stateBox.withLockedValue { Array($0.runningConnections.values) }
101+
}
102+
97103
func makeConnection(
98104
id: Int,
99105
for pool: ConnectionPool<MockConnection, Int, ConnectionIDGenerator, ConnectionRequest<MockConnection>, Int, MockPingPongBehavior, NoOpConnectionPoolMetrics<Int>, Clock>
@@ -137,6 +143,17 @@ final class MockConnectionFactory<Clock: _Concurrency.Clock> where Clock.Duratio
137143
do {
138144
let streamCount = try await closure(connectionID)
139145
let connection = MockConnection(id: connectionID)
146+
147+
connection.onClose { _ in
148+
self.stateBox.withLockedValue { state in
149+
_ = state.runningConnections.removeValue(forKey: connectionID)
150+
}
151+
}
152+
153+
self.stateBox.withLockedValue { state in
154+
_ = state.runningConnections[connectionID] = connection
155+
}
156+
140157
continuation.resume(returning: (connection, streamCount))
141158
return connection
142159
} catch {

0 commit comments

Comments
 (0)