Skip to content

Commit c1683ba

Browse files
fabianfettgwynne
andauthored
Make forward progress when Query is cancelled (vapor#261)
Co-authored-by: Gwynne Raskind <gwynne@darkrainfall.org>
1 parent ab624e4 commit c1683ba

File tree

9 files changed

+262
-14
lines changed

9 files changed

+262
-14
lines changed

Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -842,7 +842,15 @@ struct ConnectionStateMachine {
842842
// MARK: Consumer
843843

844844
mutating func cancelQueryStream() -> ConnectionAction {
845-
preconditionFailure("Unimplemented")
845+
guard case .extendedQuery(var queryState, let connectionContext) = self.state, !queryState.isComplete else {
846+
preconditionFailure("Tried to cancel stream without active query")
847+
}
848+
849+
return self.avoidingStateMachineCoW { machine -> ConnectionAction in
850+
let action = queryState.cancel()
851+
machine.state = .extendedQuery(queryState, connectionContext)
852+
return machine.modify(with: action)
853+
}
846854
}
847855

848856
mutating func requestQueryRows() -> ConnectionAction {
@@ -1074,6 +1082,8 @@ extension ConnectionStateMachine {
10741082
return true
10751083
case .failedToAddSSLHandler:
10761084
return true
1085+
case .queryCancelled:
1086+
return false
10771087
case .server(let message):
10781088
guard let sqlState = message.fields[.sqlState] else {
10791089
// any error message that doesn't have a sql state field, is unexpected by default.

Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift

Lines changed: 76 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ import NIOCore
22

33
struct ExtendedQueryStateMachine {
44

5-
enum State {
5+
private enum State {
66
case initialized(ExtendedQueryContext)
77
case parseDescribeBindExecuteSyncSent(ExtendedQueryContext)
88

@@ -15,6 +15,8 @@ struct ExtendedQueryStateMachine {
1515
/// used after receiving a `bindComplete` message
1616
case bindCompleteReceived(ExtendedQueryContext)
1717
case streaming([RowDescription.Column], RowStreamStateMachine)
18+
/// Indicates that the current query was cancelled and we want to drain rows from the connection ASAP
19+
case drain([RowDescription.Column])
1820

1921
case commandComplete(commandTag: String)
2022
case error(PSQLError)
@@ -41,9 +43,11 @@ struct ExtendedQueryStateMachine {
4143
case wait
4244
}
4345

44-
var state: State
46+
private var state: State
47+
private var isCancelled: Bool
4548

4649
init(queryContext: ExtendedQueryContext) {
50+
self.isCancelled = false
4751
self.state = .initialized(queryContext)
4852
}
4953

@@ -71,6 +75,44 @@ struct ExtendedQueryStateMachine {
7175
}
7276
}
7377
}
78+
79+
mutating func cancel() -> Action {
80+
switch self.state {
81+
case .initialized:
82+
preconditionFailure("Start must be called immediatly after the query was created")
83+
84+
case .parseDescribeBindExecuteSyncSent(let queryContext),
85+
.parseCompleteReceived(let queryContext),
86+
.parameterDescriptionReceived(let queryContext),
87+
.rowDescriptionReceived(let queryContext, _),
88+
.noDataMessageReceived(let queryContext),
89+
.bindCompleteReceived(let queryContext):
90+
guard !self.isCancelled else {
91+
return .wait
92+
}
93+
94+
self.isCancelled = true
95+
return .failQuery(queryContext, with: .queryCancelled)
96+
97+
case .streaming(let columns, var streamStateMachine):
98+
precondition(!self.isCancelled)
99+
self.isCancelled = true
100+
self.state = .drain(columns)
101+
switch streamStateMachine.fail() {
102+
case .wait:
103+
return .forwardStreamError(.queryCancelled, read: false)
104+
case .read:
105+
return .forwardStreamError(.queryCancelled, read: true)
106+
}
107+
108+
case .commandComplete, .error, .drain:
109+
// the stream has already finished.
110+
return .wait
111+
112+
case .modifying:
113+
preconditionFailure("Invalid state: \(self.state)")
114+
}
115+
}
74116

75117
mutating func parseCompletedReceived() -> Action {
76118
guard case .parseDescribeBindExecuteSyncSent(let queryContext) = self.state else {
@@ -147,9 +189,11 @@ struct ExtendedQueryStateMachine {
147189
.parameterDescriptionReceived,
148190
.bindCompleteReceived,
149191
.streaming,
192+
.drain,
150193
.commandComplete,
151194
.error:
152195
return self.setAndFireError(.unexpectedBackendMessage(.bindComplete))
196+
153197
case .modifying:
154198
preconditionFailure("Invalid state")
155199
}
@@ -169,6 +213,13 @@ struct ExtendedQueryStateMachine {
169213
state = .streaming(columns, demandStateMachine)
170214
return .wait
171215
}
216+
217+
case .drain(let columns):
218+
guard dataRow.columnCount == columns.count else {
219+
return self.setAndFireError(.unexpectedBackendMessage(.dataRow(dataRow)))
220+
}
221+
// we ignore all rows and wait for readyForQuery
222+
return .wait
172223

173224
case .initialized,
174225
.parseDescribeBindExecuteSyncSent,
@@ -198,6 +249,11 @@ struct ExtendedQueryStateMachine {
198249
state = .commandComplete(commandTag: commandTag)
199250
return .forwardStreamComplete(demandStateMachine.end(), commandTag: commandTag)
200251
}
252+
253+
case .drain:
254+
precondition(self.isCancelled)
255+
self.state = .commandComplete(commandTag: commandTag)
256+
return .wait
201257

202258
case .initialized,
203259
.parseDescribeBindExecuteSyncSent,
@@ -229,7 +285,7 @@ struct ExtendedQueryStateMachine {
229285
return self.setAndFireError(error)
230286
case .rowDescriptionReceived, .noDataMessageReceived:
231287
return self.setAndFireError(error)
232-
case .streaming:
288+
case .streaming, .drain:
233289
return self.setAndFireError(error)
234290
case .commandComplete:
235291
return self.setAndFireError(.unexpectedBackendMessage(.error(errorMessage)))
@@ -269,6 +325,9 @@ struct ExtendedQueryStateMachine {
269325
}
270326
}
271327

328+
case .drain:
329+
return .wait
330+
272331
case .initialized,
273332
.parseDescribeBindExecuteSyncSent,
274333
.parseCompleteReceived,
@@ -291,6 +350,7 @@ struct ExtendedQueryStateMachine {
291350
switch self.state {
292351
case .initialized,
293352
.commandComplete,
353+
.drain,
294354
.error,
295355
.parseDescribeBindExecuteSyncSent,
296356
.parseCompleteReceived,
@@ -327,6 +387,7 @@ struct ExtendedQueryStateMachine {
327387
.bindCompleteReceived:
328388
return .read
329389
case .streaming(let columns, var demandStateMachine):
390+
precondition(!self.isCancelled)
330391
return self.avoidingStateMachineCoW { state -> Action in
331392
let action = demandStateMachine.read()
332393
state = .streaming(columns, demandStateMachine)
@@ -339,6 +400,7 @@ struct ExtendedQueryStateMachine {
339400
}
340401
case .initialized,
341402
.commandComplete,
403+
.drain,
342404
.error:
343405
// we already have the complete stream received, now we are waiting for a
344406
// `readyForQuery` package. To receive this we need to read!
@@ -361,11 +423,20 @@ struct ExtendedQueryStateMachine {
361423
.bindCompleteReceived(let context):
362424
self.state = .error(error)
363425
return .failQuery(context, with: error)
364-
365-
case .streaming:
426+
427+
case .drain:
366428
self.state = .error(error)
367429
return .forwardStreamError(error, read: false)
368430

431+
case .streaming(_, var streamStateMachine):
432+
self.state = .error(error)
433+
switch streamStateMachine.fail() {
434+
case .wait:
435+
return .forwardStreamError(error, read: false)
436+
case .read:
437+
return .forwardStreamError(error, read: true)
438+
}
439+
369440
case .commandComplete, .error:
370441
preconditionFailure("""
371442
This state must not be reached. If the query `.isComplete`, the

Sources/PostgresNIO/New/Connection State Machine/RowStreamStateMachine.swift

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ struct RowStreamStateMachine {
2323
/// preserved for performance reasons.
2424
case waitingForDemand([DataRow])
2525

26+
case failed
27+
2628
case modifying
2729
}
2830

@@ -63,6 +65,11 @@ struct RowStreamStateMachine {
6365
buffer.append(newRow)
6466
self.state = .waitingForReadOrDemand(buffer)
6567

68+
case .failed:
69+
// Once the row stream state machine is marked as failed, no further events must be
70+
// forwarded to it.
71+
preconditionFailure("Invalid state: \(self.state)")
72+
6673
case .modifying:
6774
preconditionFailure("Invalid state: \(self.state)")
6875
}
@@ -86,6 +93,11 @@ struct RowStreamStateMachine {
8693
.waitingForReadOrDemand:
8794
preconditionFailure("How can we receive a body part, after a channelReadComplete, but no read has been forwarded yet. Invalid state: \(self.state)")
8895

96+
case .failed:
97+
// Once the row stream state machine is marked as failed, no further events must be
98+
// forwarded to it.
99+
preconditionFailure("Invalid state: \(self.state)")
100+
89101
case .modifying:
90102
preconditionFailure("Invalid state: \(self.state)")
91103
}
@@ -111,6 +123,11 @@ struct RowStreamStateMachine {
111123
// the next `channelReadComplete` we will forward all buffered data
112124
return .wait
113125

126+
case .failed:
127+
// Once the row stream state machine is marked as failed, no further events must be
128+
// forwarded to it.
129+
preconditionFailure("Invalid state: \(self.state)")
130+
114131
case .modifying:
115132
preconditionFailure("Invalid state: \(self.state)")
116133
}
@@ -136,6 +153,11 @@ struct RowStreamStateMachine {
136153
// from the consumer
137154
return .wait
138155

156+
case .failed:
157+
// Once the row stream state machine is marked as failed, no further events must be
158+
// forwarded to it.
159+
preconditionFailure("Invalid state: \(self.state)")
160+
139161
case .modifying:
140162
preconditionFailure("Invalid state: \(self.state)")
141163
}
@@ -158,6 +180,33 @@ struct RowStreamStateMachine {
158180
// receive a call to `end()`, when we don't expect it here.
159181
return buffer
160182

183+
case .failed:
184+
// Once the row stream state machine is marked as failed, no further events must be
185+
// forwarded to it.
186+
preconditionFailure("Invalid state: \(self.state)")
187+
188+
case .modifying:
189+
preconditionFailure("Invalid state: \(self.state)")
190+
}
191+
}
192+
193+
mutating func fail() -> Action {
194+
switch self.state {
195+
case .waitingForRows,
196+
.waitingForReadOrDemand,
197+
.waitingForRead:
198+
self.state = .failed
199+
return .wait
200+
201+
case .waitingForDemand:
202+
self.state = .failed
203+
return .read
204+
205+
case .failed:
206+
// Once the row stream state machine is marked as failed, no further events must be
207+
// forwarded to it.
208+
preconditionFailure("Invalid state: \(self.state)")
209+
161210
case .modifying:
162211
preconditionFailure("Invalid state: \(self.state)")
163212
}

Sources/PostgresNIO/New/PSQLError.swift

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@ struct PSQLError: Error {
1111
case unsupportedAuthMechanism(PSQLAuthScheme)
1212
case authMechanismRequiresPassword
1313
case saslError(underlyingError: Error)
14-
14+
15+
case queryCancelled
1516
case tooManyParameters
1617
case connectionQuiescing
1718
case connectionClosed
@@ -58,6 +59,10 @@ struct PSQLError: Error {
5859
static func sasl(underlying: Error) -> PSQLError {
5960
Self.init(.saslError(underlyingError: underlying))
6061
}
62+
63+
static var queryCancelled: PSQLError {
64+
Self.init(.queryCancelled)
65+
}
6166

6267
static var tooManyParameters: PSQLError {
6368
Self.init(.tooManyParameters)

Sources/PostgresNIO/New/PostgresChannelHandler.swift

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler {
1818
/// A `ChannelHandlerContext` to be used for non channel related events. (for example: More rows needed).
1919
///
2020
/// The context is captured in `handlerAdded` and released` in `handlerRemoved`
21-
private var handlerContext: ChannelHandlerContext!
21+
private var handlerContext: ChannelHandlerContext?
2222
private var rowStream: PSQLRowStream?
2323
private var decoder: NIOSingleStepByteToMessageProcessor<PostgresBackendMessageDecoder>
2424
private var encoder: BufferedMessageEncoder!
@@ -262,7 +262,8 @@ final class PostgresChannelHandler: ChannelDuplexHandler {
262262

263263
case .forwardStreamComplete(let buffer, let commandTag):
264264
guard let rowStream = self.rowStream else {
265-
preconditionFailure("Expected to have a row stream here.")
265+
// if the stream was cancelled we don't have it here anymore.
266+
return
266267
}
267268
self.rowStream = nil
268269
if buffer.count > 0 {
@@ -499,18 +500,20 @@ final class PostgresChannelHandler: ChannelDuplexHandler {
499500

500501
extension PostgresChannelHandler: PSQLRowsDataSource {
501502
func request(for stream: PSQLRowStream) {
502-
guard self.rowStream === stream else {
503+
guard self.rowStream === stream, let handlerContext = self.handlerContext else {
503504
return
504505
}
505506
let action = self.state.requestQueryRows()
506-
self.run(action, with: self.handlerContext!)
507+
self.run(action, with: handlerContext)
507508
}
508509

509510
func cancel(for stream: PSQLRowStream) {
510-
guard self.rowStream === stream else {
511+
guard self.rowStream === stream, let handlerContext = self.handlerContext else {
511512
return
512513
}
513514
// we ignore this right now :)
515+
let action = self.state.cancelQueryStream()
516+
self.run(action, with: handlerContext)
514517
}
515518
}
516519

@@ -519,7 +522,8 @@ extension PostgresConnection.Configuration.Authentication {
519522
AuthContext(
520523
username: self.username,
521524
password: self.password,
522-
database: self.database)
525+
database: self.database
526+
)
523527
}
524528
}
525529

@@ -529,7 +533,8 @@ extension AuthContext {
529533
user: self.username,
530534
database: self.database,
531535
options: nil,
532-
replication: .false)
536+
replication: .false
537+
)
533538
}
534539
}
535540

Sources/PostgresNIO/Postgres+PSQLCompat.swift

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ import NIOCore
33
extension PSQLError {
44
func toPostgresError() -> Error {
55
switch self.base {
6+
case .queryCancelled:
7+
return self
68
case .server(let errorMessage):
79
var fields = [PostgresMessage.Error.Field: String]()
810
fields.reserveCapacity(errorMessage.fields.count)

0 commit comments

Comments
 (0)