Skip to content

Commit dbf9c2e

Browse files
authored
Fix row stream cancel/error behavior (vapor#353)
1 parent c692eda commit dbf9c2e

File tree

6 files changed

+200
-8
lines changed

6 files changed

+200
-8
lines changed

Diff for: Package.swift

+1-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ let package = Package(
1414
],
1515
dependencies: [
1616
.package(url: "https://github.com/apple/swift-atomics.git", from: "1.1.0"),
17-
.package(url: "https://github.com/apple/swift-nio.git", from: "2.50.0"),
17+
.package(url: "https://github.com/apple/swift-nio.git", from: "2.51.1"),
1818
.package(url: "https://github.com/apple/swift-nio-transport-services.git", from: "1.16.0"),
1919
.package(url: "https://github.com/apple/swift-nio-ssl.git", from: "2.23.1"),
2020
.package(url: "https://github.com/apple/swift-crypto.git", "1.0.0" ..< "3.0.0"),

Diff for: Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift

+9-1
Original file line numberDiff line numberDiff line change
@@ -841,7 +841,7 @@ struct ConnectionStateMachine {
841841
// MARK: Consumer
842842

843843
mutating func cancelQueryStream() -> ConnectionAction {
844-
guard case .extendedQuery(var queryState, let connectionContext) = self.state, !queryState.isComplete else {
844+
guard case .extendedQuery(var queryState, let connectionContext) = self.state else {
845845
preconditionFailure("Tried to cancel stream without active query")
846846
}
847847

@@ -926,6 +926,8 @@ struct ConnectionStateMachine {
926926
.wait,
927927
.read:
928928
preconditionFailure("Expecting only failure actions if an error happened")
929+
case .evaluateErrorAtConnectionLevel:
930+
return .closeConnectionAndCleanup(cleanupContext)
929931
case .failQuery(let queryContext, with: let error):
930932
return .failQuery(queryContext, with: error, cleanupContext: cleanupContext)
931933
case .forwardStreamError(let error, let read):
@@ -1169,6 +1171,12 @@ extension ConnectionStateMachine {
11691171
case .forwardStreamError(let error, let read):
11701172
let cleanupContext = self.setErrorAndCreateCleanupContextIfNeeded(error)
11711173
return .forwardStreamError(error, read: read, cleanupContext: cleanupContext)
1174+
1175+
case .evaluateErrorAtConnectionLevel(let error):
1176+
if let cleanupContext = self.setErrorAndCreateCleanupContextIfNeeded(error) {
1177+
return .closeConnectionAndCleanup(cleanupContext)
1178+
}
1179+
return .wait
11721180
case .read:
11731181
return .read
11741182
case .wait:

Diff for: Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift

+9-3
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,9 @@ struct ExtendedQueryStateMachine {
3232
case failQuery(ExtendedQueryContext, with: PSQLError)
3333
case succeedQuery(ExtendedQueryContext, columns: [RowDescription.Column])
3434
case succeedQueryNoRowsComming(ExtendedQueryContext, commandTag: String)
35-
35+
36+
case evaluateErrorAtConnectionLevel(PSQLError)
37+
3638
// --- streaming actions
3739
// actions if query has requested next row but we are waiting for backend
3840
case forwardRows([DataRow])
@@ -422,11 +424,15 @@ struct ExtendedQueryStateMachine {
422424
.noDataMessageReceived(let context),
423425
.bindCompleteReceived(let context):
424426
self.state = .error(error)
425-
return .failQuery(context, with: error)
427+
if self.isCancelled {
428+
return .evaluateErrorAtConnectionLevel(error)
429+
} else {
430+
return .failQuery(context, with: error)
431+
}
426432

427433
case .drain:
428434
self.state = .error(error)
429-
return .forwardStreamError(error, read: false)
435+
return .evaluateErrorAtConnectionLevel(error)
430436

431437
case .streaming(_, var streamStateMachine):
432438
self.state = .error(error)

Diff for: Sources/PostgresNIO/New/PostgresChannelHandler.swift

+1-3
Original file line numberDiff line numberDiff line change
@@ -273,9 +273,8 @@ final class PostgresChannelHandler: ChannelDuplexHandler {
273273

274274

275275
case .forwardStreamError(let error, let read, let cleanupContext):
276-
let rowStream = self.rowStream!
276+
self.rowStream!.receive(completion: .failure(error))
277277
self.rowStream = nil
278-
rowStream.receive(completion: .failure(error))
279278
if let cleanupContext = cleanupContext {
280279
self.closeConnectionAndCleanup(cleanupContext, context: context)
281280
} else if read {
@@ -512,7 +511,6 @@ extension PostgresChannelHandler: PSQLRowsDataSource {
512511
guard self.rowStream === stream, let handlerContext = self.handlerContext else {
513512
return
514513
}
515-
// we ignore this right now :)
516514
let action = self.state.cancelQueryStream()
517515
self.run(action, with: handlerContext)
518516
}

Diff for: Tests/IntegrationTests/AsyncTests.swift

+84
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,43 @@ final class AsyncPostgresConnectionTests: XCTestCase {
7575
}
7676

7777
XCTAssertFalse(connection.isClosed, "Connection should survive!")
78+
79+
for num in 0..<10 {
80+
for try await decoded in try await connection.query("SELECT \(num);", logger: .psqlTest).decode(Int.self) {
81+
XCTAssertEqual(decoded, num)
82+
}
83+
}
84+
}
85+
}
86+
87+
func testConnectionSurvives1kQueriesWithATypo() async throws {
88+
let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
89+
defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) }
90+
let eventLoop = eventLoopGroup.next()
91+
92+
let start = 1
93+
let end = 10000
94+
95+
try await withTestConnection(on: eventLoop) { connection -> () in
96+
for _ in 0..<1000 {
97+
do {
98+
try await connection.query("SELECT generte_series(\(start), \(end));", logger: .psqlTest)
99+
XCTFail("Expected to throw from the request")
100+
} catch {
101+
guard let error = error as? PSQLError else { return XCTFail("Unexpected error type: \(error)") }
102+
103+
XCTAssertEqual(error.code, .server)
104+
XCTAssertEqual(error.serverInfo?[.severity], "ERROR")
105+
}
106+
}
107+
108+
// the connection survived all of this, we can still run normal queries:
109+
110+
for num in 0..<10 {
111+
for try await decoded in try await connection.query("SELECT \(num);", logger: .psqlTest).decode(Int.self) {
112+
XCTAssertEqual(decoded, num)
113+
}
114+
}
78115
}
79116
}
80117

@@ -172,6 +209,53 @@ final class AsyncPostgresConnectionTests: XCTestCase {
172209
}
173210
}
174211
#endif
212+
213+
func testCancelTaskThatIsVeryLongRunningWhichAlsoFailsWhileInStreamingMode() async throws {
214+
let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
215+
defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) }
216+
let eventLoop = eventLoopGroup.next()
217+
218+
// we cancel the query after 400ms.
219+
// the server times out the query after 1sec.
220+
221+
try await withTestConnection(on: eventLoop) { connection -> () in
222+
try await connection.query("SET statement_timeout=1000;", logger: .psqlTest) // 1000 milliseconds
223+
224+
try await withThrowingTaskGroup(of: Void.self) { group in
225+
group.addTask {
226+
let start = 1
227+
let end = 100_000_000
228+
229+
let rows = try await connection.query("SELECT generate_series(\(start), \(end));", logger: .psqlTest)
230+
var counter = 0
231+
do {
232+
for try await element in rows.decode(Int.self, context: .default) {
233+
XCTAssertEqual(element, counter + 1)
234+
counter += 1
235+
}
236+
XCTFail("Expected to get cancelled while reading the query")
237+
XCTAssertEqual(counter, end)
238+
} catch let error as CancellationError {
239+
XCTAssertGreaterThanOrEqual(counter, 1)
240+
// Expected
241+
print("\(error)")
242+
} catch {
243+
XCTFail("Unexpected error: \(error)")
244+
}
245+
246+
XCTAssertTrue(Task.isCancelled)
247+
XCTAssertFalse(connection.isClosed, "Connection should survive!")
248+
}
249+
250+
let delay: UInt64 = 400_000_000 // 400 milliseconds
251+
try await Task.sleep(nanoseconds: delay)
252+
253+
group.cancelAll()
254+
}
255+
256+
try await connection.query("SELECT 1;", logger: .psqlTest)
257+
}
258+
}
175259
}
176260

177261
extension XCTestCase {

Diff for: Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift

+96
Original file line numberDiff line numberDiff line change
@@ -181,4 +181,100 @@ class ExtendedQueryStateMachineTests: XCTestCase {
181181
XCTAssertEqual(state.commandCompletedReceived("SELECT 4"), .wait)
182182
XCTAssertEqual(state.readyForQueryReceived(.idle), .fireEventReadyForQuery)
183183
}
184+
185+
func testCancelQueryAfterServerError() {
186+
var state = ConnectionStateMachine.readyForQuery()
187+
188+
let logger = Logger.psqlTest
189+
let promise = EmbeddedEventLoop().makePromise(of: PSQLRowStream.self)
190+
promise.fail(PSQLError.uncleanShutdown) // we don't care about the error at all.
191+
let query: PostgresQuery = "SELECT version()"
192+
let queryContext = ExtendedQueryContext(query: query, logger: logger, promise: promise)
193+
194+
XCTAssertEqual(state.enqueue(task: .extendedQuery(queryContext)), .sendParseDescribeBindExecuteSync(query))
195+
XCTAssertEqual(state.parseCompleteReceived(), .wait)
196+
XCTAssertEqual(state.parameterDescriptionReceived(.init(dataTypes: [.int8])), .wait)
197+
198+
// We need to ensure that even though the row description from the wire says that we
199+
// will receive data in `.text` format, we will actually receive it in binary format,
200+
// since we requested it in binary with our bind message.
201+
let input: [RowDescription.Column] = [
202+
.init(name: "version", tableOID: 0, columnAttributeNumber: 0, dataType: .text, dataTypeSize: -1, dataTypeModifier: -1, format: .text)
203+
]
204+
let expected: [RowDescription.Column] = input.map {
205+
.init(name: $0.name, tableOID: $0.tableOID, columnAttributeNumber: $0.columnAttributeNumber, dataType: $0.dataType,
206+
dataTypeSize: $0.dataTypeSize, dataTypeModifier: $0.dataTypeModifier, format: .binary)
207+
}
208+
209+
XCTAssertEqual(state.rowDescriptionReceived(.init(columns: input)), .wait)
210+
XCTAssertEqual(state.bindCompleteReceived(), .succeedQuery(queryContext, columns: expected))
211+
let dataRows1: [DataRow] = [
212+
[ByteBuffer(string: "test1")],
213+
[ByteBuffer(string: "test2")],
214+
[ByteBuffer(string: "test3")]
215+
]
216+
for row in dataRows1 {
217+
XCTAssertEqual(state.dataRowReceived(row), .wait)
218+
}
219+
XCTAssertEqual(state.channelReadComplete(), .forwardRows(dataRows1))
220+
XCTAssertEqual(state.readEventCaught(), .wait)
221+
XCTAssertEqual(state.requestQueryRows(), .read)
222+
let dataRows2: [DataRow] = [
223+
[ByteBuffer(string: "test4")],
224+
[ByteBuffer(string: "test5")],
225+
[ByteBuffer(string: "test6")]
226+
]
227+
for row in dataRows2 {
228+
XCTAssertEqual(state.dataRowReceived(row), .wait)
229+
}
230+
let serverError = PostgresBackendMessage.ErrorResponse(fields: [.severity: "Error", .sqlState: "123"])
231+
XCTAssertEqual(state.errorReceived(serverError), .forwardStreamError(.server(serverError), read: false, cleanupContext: .none))
232+
233+
XCTAssertEqual(state.channelReadComplete(), .wait)
234+
XCTAssertEqual(state.readEventCaught(), .read)
235+
236+
XCTAssertEqual(state.readyForQueryReceived(.idle), .fireEventReadyForQuery)
237+
}
238+
239+
func testQueryErrorDoesNotKillConnection() {
240+
var state = ConnectionStateMachine.readyForQuery()
241+
242+
let logger = Logger.psqlTest
243+
let promise = EmbeddedEventLoop().makePromise(of: PSQLRowStream.self)
244+
promise.fail(PSQLError.uncleanShutdown) // we don't care about the error at all.
245+
let query: PostgresQuery = "SELECT version()"
246+
let queryContext = ExtendedQueryContext(query: query, logger: logger, promise: promise)
247+
248+
XCTAssertEqual(state.enqueue(task: .extendedQuery(queryContext)), .sendParseDescribeBindExecuteSync(query))
249+
XCTAssertEqual(state.parseCompleteReceived(), .wait)
250+
XCTAssertEqual(state.parameterDescriptionReceived(.init(dataTypes: [.int8])), .wait)
251+
252+
let serverError = PostgresBackendMessage.ErrorResponse(fields: [.severity: "Error", .sqlState: "123"])
253+
XCTAssertEqual(
254+
state.errorReceived(serverError), .failQuery(queryContext, with: .server(serverError), cleanupContext: .none)
255+
)
256+
257+
XCTAssertEqual(state.readyForQueryReceived(.idle), .fireEventReadyForQuery)
258+
}
259+
260+
func testQueryErrorAfterCancelDoesNotKillConnection() {
261+
var state = ConnectionStateMachine.readyForQuery()
262+
263+
let logger = Logger.psqlTest
264+
let promise = EmbeddedEventLoop().makePromise(of: PSQLRowStream.self)
265+
promise.fail(PSQLError.uncleanShutdown) // we don't care about the error at all.
266+
let query: PostgresQuery = "SELECT version()"
267+
let queryContext = ExtendedQueryContext(query: query, logger: logger, promise: promise)
268+
269+
XCTAssertEqual(state.enqueue(task: .extendedQuery(queryContext)), .sendParseDescribeBindExecuteSync(query))
270+
XCTAssertEqual(state.parseCompleteReceived(), .wait)
271+
XCTAssertEqual(state.parameterDescriptionReceived(.init(dataTypes: [.int8])), .wait)
272+
XCTAssertEqual(state.cancelQueryStream(), .failQuery(queryContext, with: .queryCancelled, cleanupContext: .none))
273+
274+
let serverError = PostgresBackendMessage.ErrorResponse(fields: [.severity: "Error", .sqlState: "123"])
275+
XCTAssertEqual(state.errorReceived(serverError), .wait)
276+
277+
XCTAssertEqual(state.readyForQueryReceived(.idle), .fireEventReadyForQuery)
278+
}
279+
184280
}

0 commit comments

Comments
 (0)