Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -842,7 +842,15 @@ struct ConnectionStateMachine {
// MARK: Consumer

mutating func cancelQueryStream() -> ConnectionAction {
preconditionFailure("Unimplemented")
guard case .extendedQuery(var queryState, let connectionContext) = self.state, !queryState.isComplete else {
preconditionFailure("Tried to cancel stream without active query")
}

return self.avoidingStateMachineCoW { machine -> ConnectionAction in
let action = queryState.cancel()
machine.state = .extendedQuery(queryState, connectionContext)
return machine.modify(with: action)
}
}

mutating func requestQueryRows() -> ConnectionAction {
Expand Down Expand Up @@ -1074,6 +1082,8 @@ extension ConnectionStateMachine {
return true
case .failedToAddSSLHandler:
return true
case .queryCancelled:
return false
case .server(let message):
guard let sqlState = message.fields[.sqlState] else {
// any error message that doesn't have a sql state field, is unexpected by default.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import NIOCore

struct ExtendedQueryStateMachine {

enum State {
private enum State {
case initialized(ExtendedQueryContext)
case parseDescribeBindExecuteSyncSent(ExtendedQueryContext)

Expand All @@ -15,6 +15,8 @@ struct ExtendedQueryStateMachine {
/// used after receiving a `bindComplete` message
case bindCompleteReceived(ExtendedQueryContext)
case streaming([RowDescription.Column], RowStreamStateMachine)
/// Indicates that the current query was cancelled and we want to drain rows from the connection ASAP
case drain([RowDescription.Column])

case commandComplete(commandTag: String)
case error(PSQLError)
Expand All @@ -41,9 +43,11 @@ struct ExtendedQueryStateMachine {
case wait
}

var state: State
private var state: State
private var isCancelled: Bool

init(queryContext: ExtendedQueryContext) {
self.isCancelled = false
self.state = .initialized(queryContext)
}

Expand Down Expand Up @@ -71,6 +75,44 @@ struct ExtendedQueryStateMachine {
}
}
}

mutating func cancel() -> Action {
switch self.state {
case .initialized:
preconditionFailure("Start must be called immediatly after the query was created")

case .parseDescribeBindExecuteSyncSent(let queryContext),
.parseCompleteReceived(let queryContext),
.parameterDescriptionReceived(let queryContext),
.rowDescriptionReceived(let queryContext, _),
.noDataMessageReceived(let queryContext),
.bindCompleteReceived(let queryContext):
guard !self.isCancelled else {
return .wait
}

self.isCancelled = true
return .failQuery(queryContext, with: .queryCancelled)

case .streaming(let columns, var streamStateMachine):
precondition(!self.isCancelled)
self.isCancelled = true
self.state = .drain(columns)
switch streamStateMachine.fail() {
case .wait:
return .forwardStreamError(.queryCancelled, read: false)
case .read:
return .forwardStreamError(.queryCancelled, read: true)
}

case .commandComplete, .error, .drain:
// the stream has already finished.
return .wait

case .modifying:
preconditionFailure("Invalid state: \(self.state)")
}
}

mutating func parseCompletedReceived() -> Action {
guard case .parseDescribeBindExecuteSyncSent(let queryContext) = self.state else {
Expand Down Expand Up @@ -147,9 +189,11 @@ struct ExtendedQueryStateMachine {
.parameterDescriptionReceived,
.bindCompleteReceived,
.streaming,
.drain,
.commandComplete,
.error:
return self.setAndFireError(.unexpectedBackendMessage(.bindComplete))

case .modifying:
preconditionFailure("Invalid state")
}
Expand All @@ -169,6 +213,13 @@ struct ExtendedQueryStateMachine {
state = .streaming(columns, demandStateMachine)
return .wait
}

case .drain(let columns):
guard dataRow.columnCount == columns.count else {
return self.setAndFireError(.unexpectedBackendMessage(.dataRow(dataRow)))
}
// we ignore all rows and wait for readyForQuery
return .wait

case .initialized,
.parseDescribeBindExecuteSyncSent,
Expand Down Expand Up @@ -198,6 +249,11 @@ struct ExtendedQueryStateMachine {
state = .commandComplete(commandTag: commandTag)
return .forwardStreamComplete(demandStateMachine.end(), commandTag: commandTag)
}

case .drain:
precondition(self.isCancelled)
self.state = .commandComplete(commandTag: commandTag)
return .wait

case .initialized,
.parseDescribeBindExecuteSyncSent,
Expand Down Expand Up @@ -229,7 +285,7 @@ struct ExtendedQueryStateMachine {
return self.setAndFireError(error)
case .rowDescriptionReceived, .noDataMessageReceived:
return self.setAndFireError(error)
case .streaming:
case .streaming, .drain:
return self.setAndFireError(error)
case .commandComplete:
return self.setAndFireError(.unexpectedBackendMessage(.error(errorMessage)))
Expand Down Expand Up @@ -269,6 +325,9 @@ struct ExtendedQueryStateMachine {
}
}

case .drain:
return .wait

case .initialized,
.parseDescribeBindExecuteSyncSent,
.parseCompleteReceived,
Expand All @@ -291,6 +350,7 @@ struct ExtendedQueryStateMachine {
switch self.state {
case .initialized,
.commandComplete,
.drain,
.error,
.parseDescribeBindExecuteSyncSent,
.parseCompleteReceived,
Expand Down Expand Up @@ -327,6 +387,7 @@ struct ExtendedQueryStateMachine {
.bindCompleteReceived:
return .read
case .streaming(let columns, var demandStateMachine):
precondition(!self.isCancelled)
return self.avoidingStateMachineCoW { state -> Action in
let action = demandStateMachine.read()
state = .streaming(columns, demandStateMachine)
Expand All @@ -339,6 +400,7 @@ struct ExtendedQueryStateMachine {
}
case .initialized,
.commandComplete,
.drain,
.error:
// we already have the complete stream received, now we are waiting for a
// `readyForQuery` package. To receive this we need to read!
Expand All @@ -361,11 +423,20 @@ struct ExtendedQueryStateMachine {
.bindCompleteReceived(let context):
self.state = .error(error)
return .failQuery(context, with: error)
case .streaming:

case .drain:
self.state = .error(error)
return .forwardStreamError(error, read: false)

case .streaming(_, var streamStateMachine):
self.state = .error(error)
switch streamStateMachine.fail() {
case .wait:
return .forwardStreamError(error, read: false)
case .read:
return .forwardStreamError(error, read: true)
}

case .commandComplete, .error:
preconditionFailure("""
This state must not be reached. If the query `.isComplete`, the
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ struct RowStreamStateMachine {
/// preserved for performance reasons.
case waitingForDemand([DataRow])

case failed

case modifying
}

Expand Down Expand Up @@ -63,6 +65,11 @@ struct RowStreamStateMachine {
buffer.append(newRow)
self.state = .waitingForReadOrDemand(buffer)

case .failed:
// Once the row stream state machine is marked as failed, no further events must be
// forwarded to it.
preconditionFailure("Invalid state: \(self.state)")

case .modifying:
preconditionFailure("Invalid state: \(self.state)")
}
Expand All @@ -86,6 +93,11 @@ struct RowStreamStateMachine {
.waitingForReadOrDemand:
preconditionFailure("How can we receive a body part, after a channelReadComplete, but no read has been forwarded yet. Invalid state: \(self.state)")

case .failed:
// Once the row stream state machine is marked as failed, no further events must be
// forwarded to it.
preconditionFailure("Invalid state: \(self.state)")

case .modifying:
preconditionFailure("Invalid state: \(self.state)")
}
Expand All @@ -111,6 +123,11 @@ struct RowStreamStateMachine {
// the next `channelReadComplete` we will forward all buffered data
return .wait

case .failed:
// Once the row stream state machine is marked as failed, no further events must be
// forwarded to it.
preconditionFailure("Invalid state: \(self.state)")

case .modifying:
preconditionFailure("Invalid state: \(self.state)")
}
Expand All @@ -136,6 +153,11 @@ struct RowStreamStateMachine {
// from the consumer
return .wait

case .failed:
// Once the row stream state machine is marked as failed, no further events must be
// forwarded to it.
preconditionFailure("Invalid state: \(self.state)")

case .modifying:
preconditionFailure("Invalid state: \(self.state)")
}
Expand All @@ -158,6 +180,33 @@ struct RowStreamStateMachine {
// receive a call to `end()`, when we don't expect it here.
return buffer

case .failed:
// Once the row stream state machine is marked as failed, no further events must be
// forwarded to it.
preconditionFailure("Invalid state: \(self.state)")

case .modifying:
preconditionFailure("Invalid state: \(self.state)")
}
}

mutating func fail() -> Action {
switch self.state {
case .waitingForRows,
.waitingForReadOrDemand,
.waitingForRead:
self.state = .failed
return .wait

case .waitingForDemand:
self.state = .failed
return .read

case .failed:
// Once the row stream state machine is marked as failed, no further events must be
// forwarded to it.
preconditionFailure("Invalid state: \(self.state)")

case .modifying:
preconditionFailure("Invalid state: \(self.state)")
}
Expand Down
7 changes: 6 additions & 1 deletion Sources/PostgresNIO/New/PSQLError.swift
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ struct PSQLError: Error {
case unsupportedAuthMechanism(PSQLAuthScheme)
case authMechanismRequiresPassword
case saslError(underlyingError: Error)


case queryCancelled
case tooManyParameters
case connectionQuiescing
case connectionClosed
Expand Down Expand Up @@ -58,6 +59,10 @@ struct PSQLError: Error {
static func sasl(underlying: Error) -> PSQLError {
Self.init(.saslError(underlyingError: underlying))
}

static var queryCancelled: PSQLError {
Self.init(.queryCancelled)
}

static var tooManyParameters: PSQLError {
Self.init(.tooManyParameters)
Expand Down
19 changes: 12 additions & 7 deletions Sources/PostgresNIO/New/PostgresChannelHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler {
/// A `ChannelHandlerContext` to be used for non channel related events. (for example: More rows needed).
///
/// The context is captured in `handlerAdded` and released` in `handlerRemoved`
private var handlerContext: ChannelHandlerContext!
private var handlerContext: ChannelHandlerContext?
private var rowStream: PSQLRowStream?
private var decoder: NIOSingleStepByteToMessageProcessor<PostgresBackendMessageDecoder>
private var encoder: BufferedMessageEncoder!
Expand Down Expand Up @@ -262,7 +262,8 @@ final class PostgresChannelHandler: ChannelDuplexHandler {

case .forwardStreamComplete(let buffer, let commandTag):
guard let rowStream = self.rowStream else {
preconditionFailure("Expected to have a row stream here.")
// if the stream was cancelled we don't have it here anymore.
return
}
self.rowStream = nil
if buffer.count > 0 {
Expand Down Expand Up @@ -499,18 +500,20 @@ final class PostgresChannelHandler: ChannelDuplexHandler {

extension PostgresChannelHandler: PSQLRowsDataSource {
func request(for stream: PSQLRowStream) {
guard self.rowStream === stream else {
guard self.rowStream === stream, let handlerContext = self.handlerContext else {
return
}
let action = self.state.requestQueryRows()
self.run(action, with: self.handlerContext!)
self.run(action, with: handlerContext)
}

func cancel(for stream: PSQLRowStream) {
guard self.rowStream === stream else {
guard self.rowStream === stream, let handlerContext = self.handlerContext else {
return
}
// we ignore this right now :)
let action = self.state.cancelQueryStream()
self.run(action, with: handlerContext)
}
}

Expand All @@ -519,7 +522,8 @@ extension PostgresConnection.Configuration.Authentication {
AuthContext(
username: self.username,
password: self.password,
database: self.database)
database: self.database
)
}
}

Expand All @@ -529,7 +533,8 @@ extension AuthContext {
user: self.username,
database: self.database,
options: nil,
replication: .false)
replication: .false
)
}
}

Expand Down
2 changes: 2 additions & 0 deletions Sources/PostgresNIO/Postgres+PSQLCompat.swift
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ import NIOCore
extension PSQLError {
func toPostgresError() -> Error {
switch self.base {
case .queryCancelled:
return self
case .server(let errorMessage):
var fields = [PostgresMessage.Error.Field: String]()
fields.reserveCapacity(errorMessage.fields.count)
Expand Down
Loading