From aa9273c06a0f42281635eaf0400aa024157c8fa9 Mon Sep 17 00:00:00 2001 From: Iceman Date: Thu, 20 Jul 2023 17:21:49 +0900 Subject: [PATCH 001/106] Use computed property to PostgresConnection.Configuration.TLS.disable for concurrency safe (#376) --- .../Connection/PostgresConnection+Configuration.swift | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Sources/PostgresNIO/Connection/PostgresConnection+Configuration.swift b/Sources/PostgresNIO/Connection/PostgresConnection+Configuration.swift index 54eefc90..bc9bcfc2 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection+Configuration.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection+Configuration.swift @@ -13,7 +13,7 @@ extension PostgresConnection { // MARK: Initializers /// Do not try to create a TLS connection to the server. - public static var disable: Self = .init(base: .disable) + public static var disable: Self { .init(base: .disable) } /// Try to create a TLS connection to the server. If the server supports TLS, create a TLS connection. /// If the server does not support TLS, create an insecure connection. From f3587a586dc5d33b016da6b30d01bbad343c10af Mon Sep 17 00:00:00 2001 From: Gwynne Raskind Date: Sat, 29 Jul 2023 04:01:07 -0500 Subject: [PATCH 002/106] Fix multiple warnings generated by the documentation build (#378) --- .github/workflows/api-docs.yml | 2 +- .github/workflows/test.yml | 18 ++++++++++-------- .../PostgresNIO/Data/PostgresDataType.swift | 2 +- Sources/PostgresNIO/Docs.docc/index.md | 4 ++-- Sources/PostgresNIO/Docs.docc/migrations.md | 2 +- Sources/PostgresNIO/New/PostgresQuery.swift | 8 ++++---- Sources/PostgresNIO/Utilities/Exports.swift | 2 +- 7 files changed, 20 insertions(+), 18 deletions(-) diff --git a/.github/workflows/api-docs.yml b/.github/workflows/api-docs.yml index 80291c6f..dc2e0634 100644 --- a/.github/workflows/api-docs.yml +++ b/.github/workflows/api-docs.yml @@ -11,4 +11,4 @@ jobs: with: package_name: postgres-nio modules: PostgresNIO - pathsToInvalidate: /postgresnio + pathsToInvalidate: /postgresnio/* diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 25374cf3..24821c77 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -26,13 +26,13 @@ jobs: container: ${{ matrix.container }} runs-on: ubuntu-latest steps: + - name: Note Swift version + if: ${{ contains(matrix.swiftver, 'nightly') }} + run: | + echo "SWIFT_PLATFORM=$(. /etc/os-release && echo "${ID}${VERSION_ID}")" >>"${GITHUB_ENV}" + echo "SWIFT_VERSION=$(cat /.swift_tag)" >>"${GITHUB_ENV}" - name: Display OS and Swift versions - shell: bash run: | - if [[ '${{ contains(matrix.container, 'nightly') }}' == 'true' ]]; then - SWIFT_PLATFORM="$(source /etc/os-release && echo "${ID}${VERSION_ID}")" SWIFT_VERSION="$(cat /.swift_tag)" - printf 'SWIFT_PLATFORM=%s\nSWIFT_VERSION=%s\n' "${SWIFT_PLATFORM}" "${SWIFT_VERSION}" >>"${GITHUB_ENV}" - fi printf 'OS: %s\nTag: %s\nVersion:\n' "${SWIFT_PLATFORM}-${RUNNER_ARCH}" "${SWIFT_VERSION}" && swift --version - name: Check out package uses: actions/checkout@v3 @@ -144,6 +144,7 @@ jobs: POSTGRES_DB: 'postgres' POSTGRES_HOST_AUTH_METHOD: ${{ matrix.dbauth }} POSTGRES_SOCKET: '/tmp/.s.PGSQL.5432' + POSTGRES_VERSION: ${{ matrix.dbimage }} steps: - name: Select latest available Xcode uses: maxim-lobanov/setup-xcode@v1 @@ -151,9 +152,9 @@ jobs: xcode-version: ${{ matrix.xcode }} - name: Install Postgres, setup DB and auth, and wait for server start run: | - export PATH="$(brew --prefix)/opt/${{ matrix.dbimage }}/bin:$PATH" PGDATA=/tmp/vapor-postgres-test - (brew unlink postgresql || true) && brew install '${{ matrix.dbimage }}' && brew link --force '${{ matrix.dbimage }}' - initdb --locale=C --auth-host '${{ matrix.dbauth }}' -U "${POSTGRES_USER}" --pwfile=<(echo "${POSTGRES_PASSWORD}") + export PATH="$(brew --prefix)/opt/${POSTGRES_VERSION}/bin:$PATH" PGDATA=/tmp/vapor-postgres-test + (brew unlink postgresql || true) && brew install "${POSTGRES_VERSION}" && brew link --force "${POSTGRES_VERSION}" + initdb --locale=C --auth-host "${POSTGRES_HOST_AUTH_METHOD}" -U "${POSTGRES_USER}" --pwfile=<(echo "${POSTGRES_PASSWORD}") pg_ctl start --wait timeout-minutes: 2 - name: Checkout code @@ -175,3 +176,4 @@ jobs: run: git config --global --add safe.directory "${GITHUB_WORKSPACE}" - name: API breaking changes run: swift package diagnose-api-breaking-changes origin/main + diff --git a/Sources/PostgresNIO/Data/PostgresDataType.swift b/Sources/PostgresNIO/Data/PostgresDataType.swift index ede60f47..f3ab4dca 100644 --- a/Sources/PostgresNIO/Data/PostgresDataType.swift +++ b/Sources/PostgresNIO/Data/PostgresDataType.swift @@ -763,7 +763,7 @@ public struct PostgresDataType: RawRepresentable, Sendable, Hashable, CustomStri } } - /// See ``Swift/CustomStringConvertible/description``. + // See `CustomStringConvertible.description`. public var description: String { return self.knownSQLName ?? "UNKNOWN \(self.rawValue)" } diff --git a/Sources/PostgresNIO/Docs.docc/index.md b/Sources/PostgresNIO/Docs.docc/index.md index e7363054..b4dc7e30 100644 --- a/Sources/PostgresNIO/Docs.docc/index.md +++ b/Sources/PostgresNIO/Docs.docc/index.md @@ -1,12 +1,12 @@ # ``PostgresNIO`` -🐘 Non-blocking, event-driven Swift client for PostgreSQL built on [SwiftNIO]. +🐘 Non-blocking, event-driven Swift client for PostgreSQL built on SwiftNIO. ## Overview Features: -- A ``PostgresConnection`` which allows you to connect to, authorize with, query, and retrieve results from a PostgreSQL server +- A ``PostgresConnection`` which allows you to connect to, authorize with, query, and retrieve results from a PostgreSQL server using [SwiftNIO]. - An async/await interface that supports backpressure - Automatic conversions between Swift primitive types and the Postgres wire format - Integrated with the Swift server ecosystem, including use of [SwiftLog]. diff --git a/Sources/PostgresNIO/Docs.docc/migrations.md b/Sources/PostgresNIO/Docs.docc/migrations.md index 33c8afd4..7185ba06 100644 --- a/Sources/PostgresNIO/Docs.docc/migrations.md +++ b/Sources/PostgresNIO/Docs.docc/migrations.md @@ -6,7 +6,7 @@ which use the ``PostgresRow/column(_:)`` API today. ## TLDR 1. Map your sequence of ``PostgresRow``s to ``PostgresRandomAccessRow``s. -2. Use the ``PostgresRandomAccessRow/subscript(name:)`` API to receive a ``PostgresCell`` +2. Use the ``PostgresRandomAccessRow/subscript(_:)-3facl`` API to receive a ``PostgresCell`` 3. Decode the ``PostgresCell`` into a Swift type using the ``PostgresCell/decode(_:file:line:)`` method. ```swift diff --git a/Sources/PostgresNIO/New/PostgresQuery.swift b/Sources/PostgresNIO/New/PostgresQuery.swift index 381370e9..2e06e1d9 100644 --- a/Sources/PostgresNIO/New/PostgresQuery.swift +++ b/Sources/PostgresNIO/New/PostgresQuery.swift @@ -96,14 +96,14 @@ extension PostgresQuery { } extension PostgresQuery: CustomStringConvertible { - /// See ``Swift/CustomStringConvertible/description``. + // See `CustomStringConvertible.description`. public var description: String { "\(self.sql) \(self.binds)" } } extension PostgresQuery: CustomDebugStringConvertible { - /// See ``Swift/CustomDebugStringConvertible/debugDescription``. + // See `CustomDebugStringConvertible.debugDescription`. public var debugDescription: String { "PostgresQuery(sql: \(String(describing: self.sql)), binds: \(String(reflecting: self.binds)))" } @@ -216,7 +216,7 @@ public struct PostgresBindings: Sendable, Hashable { } extension PostgresBindings: CustomStringConvertible, CustomDebugStringConvertible { - /// See ``Swift/CustomStringConvertible/description``. + // See `CustomStringConvertible.description`. public var description: String { """ [\(zip(self.metadata, BindingsReader(buffer: self.bytes)) @@ -225,7 +225,7 @@ extension PostgresBindings: CustomStringConvertible, CustomDebugStringConvertibl """ } - /// See ``Swift/CustomDebugStringConvertible/description``. + // See `CustomDebugStringConvertible.description`. public var debugDescription: String { """ [\(zip(self.metadata, BindingsReader(buffer: self.bytes)) diff --git a/Sources/PostgresNIO/Utilities/Exports.swift b/Sources/PostgresNIO/Utilities/Exports.swift index 204df50c..58e12891 100644 --- a/Sources/PostgresNIO/Utilities/Exports.swift +++ b/Sources/PostgresNIO/Utilities/Exports.swift @@ -1,4 +1,4 @@ -#if compiler(>=5.8) +#if swift(>=5.8) @_documentation(visibility: internal) @_exported import NIO @_documentation(visibility: internal) @_exported import NIOSSL From 718d154ad788b9e3fca73c83016a03d70d018dfb Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Thu, 3 Aug 2023 17:30:47 +0200 Subject: [PATCH 003/106] Crash fix: Multiple bad messages could trigger reentrancy issue (#379) If we receive multiple unexpected messages from the backend we can run into a reentrancy situation in which we still have unread messages in the incoming buffer after we have received `channelInactive`. This pr patches this crash. --- .../ConnectionStateMachine.swift | 26 ++--- .../New/PostgresChannelHandler.swift | 107 ++++++++++-------- .../New/PostgresChannelHandlerTests.swift | 39 ++++++- 3 files changed, 111 insertions(+), 61 deletions(-) diff --git a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift index 563bb026..ba1e3c1f 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift @@ -928,7 +928,7 @@ struct ConnectionStateMachine { .forwardStreamComplete, .wait, .read: - preconditionFailure("Expecting only failure actions if an error happened") + preconditionFailure("Invalid state: \(self.state)") case .evaluateErrorAtConnectionLevel: return .closeConnectionAndCleanup(cleanupContext) case .failQuery(let queryContext, with: let error): @@ -951,7 +951,7 @@ struct ConnectionStateMachine { .succeedPreparedStatementCreation, .read, .wait: - preconditionFailure("Expecting only failure actions if an error happened") + preconditionFailure("Invalid state: \(self.state)") case .failPreparedStatementCreation(let preparedStatementContext, with: let error): return .failPreparedStatementCreation(preparedStatementContext, with: error, cleanupContext: cleanupContext) } @@ -970,22 +970,20 @@ struct ConnectionStateMachine { .succeedClose, .read, .wait: - preconditionFailure("Expecting only failure actions if an error happened") + preconditionFailure("Invalid state: \(self.state)") case .failClose(let closeCommandContext, with: let error): return .failClose(closeCommandContext, with: error, cleanupContext: cleanupContext) } - case .error: - // TBD: this is an interesting case. why would this case happen? - let cleanupContext = self.setErrorAndCreateCleanupContext(error) - return .closeConnectionAndCleanup(cleanupContext) - - case .closing: - let cleanupContext = self.setErrorAndCreateCleanupContext(error) - return .closeConnectionAndCleanup(cleanupContext) - case .closed: - preconditionFailure("How can an error occur if the connection is already closed?") + case .error, .closing, .closed: + // We might run into this case because of reentrancy. For example: After we received an + // backend unexpected message, that we read of the wire, we bring this connection into + // the error state and will try to close the connection. However the server might have + // send further follow up messages. In those cases we will run into this method again + // and again. We should just ignore those events. + return .wait + case .modifying: - preconditionFailure("Invalid state") + preconditionFailure("Invalid state: \(self.state)") } } diff --git a/Sources/PostgresNIO/New/PostgresChannelHandler.swift b/Sources/PostgresNIO/New/PostgresChannelHandler.swift index 84f07d47..fdb6a443 100644 --- a/Sources/PostgresNIO/New/PostgresChannelHandler.swift +++ b/Sources/PostgresNIO/New/PostgresChannelHandler.swift @@ -84,6 +84,17 @@ final class PostgresChannelHandler: ChannelDuplexHandler { } func channelInactive(context: ChannelHandlerContext) { + do { + try self.decoder.finishProcessing(seenEOF: true) { message in + self.handleMessage(message, context: context) + } + } catch let error as PostgresMessageDecodingError { + let action = self.state.errorHappened(.messageDecodingFailure(error)) + self.run(action, with: context) + } catch { + preconditionFailure("Expected to only get PSQLDecodingErrors from the PSQLBackendMessageDecoder.") + } + self.logger.trace("Channel inactive.") let action = self.state.closed() self.run(action, with: context) @@ -100,51 +111,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler { do { try self.decoder.process(buffer: buffer) { message in - self.logger.trace("Backend message received", metadata: [.message: "\(message)"]) - let action: ConnectionStateMachine.ConnectionAction - - switch message { - case .authentication(let authentication): - action = self.state.authenticationMessageReceived(authentication) - case .backendKeyData(let keyData): - action = self.state.backendKeyDataReceived(keyData) - case .bindComplete: - action = self.state.bindCompleteReceived() - case .closeComplete: - action = self.state.closeCompletedReceived() - case .commandComplete(let commandTag): - action = self.state.commandCompletedReceived(commandTag) - case .dataRow(let dataRow): - action = self.state.dataRowReceived(dataRow) - case .emptyQueryResponse: - action = self.state.emptyQueryResponseReceived() - case .error(let errorResponse): - action = self.state.errorReceived(errorResponse) - case .noData: - action = self.state.noDataReceived() - case .notice(let noticeResponse): - action = self.state.noticeReceived(noticeResponse) - case .notification(let notification): - action = self.state.notificationReceived(notification) - case .parameterDescription(let parameterDescription): - action = self.state.parameterDescriptionReceived(parameterDescription) - case .parameterStatus(let parameterStatus): - action = self.state.parameterStatusReceived(parameterStatus) - case .parseComplete: - action = self.state.parseCompleteReceived() - case .portalSuspended: - action = self.state.portalSuspendedReceived() - case .readyForQuery(let transactionState): - action = self.state.readyForQueryReceived(transactionState) - case .rowDescription(let rowDescription): - action = self.state.rowDescriptionReceived(rowDescription) - case .sslSupported: - action = self.state.sslSupportedReceived(unprocessedBytes: self.decoder.unprocessedBytes) - case .sslUnsupported: - action = self.state.sslUnsupportedReceived() - } - - self.run(action, with: context) + self.handleMessage(message, context: context) } } catch let error as PostgresMessageDecodingError { let action = self.state.errorHappened(.messageDecodingFailure(error)) @@ -153,7 +120,55 @@ final class PostgresChannelHandler: ChannelDuplexHandler { preconditionFailure("Expected to only get PSQLDecodingErrors from the PSQLBackendMessageDecoder.") } } - + + private func handleMessage(_ message: PostgresBackendMessage, context: ChannelHandlerContext) { + self.logger.trace("Backend message received", metadata: [.message: "\(message)"]) + let action: ConnectionStateMachine.ConnectionAction + + switch message { + case .authentication(let authentication): + action = self.state.authenticationMessageReceived(authentication) + case .backendKeyData(let keyData): + action = self.state.backendKeyDataReceived(keyData) + case .bindComplete: + action = self.state.bindCompleteReceived() + case .closeComplete: + action = self.state.closeCompletedReceived() + case .commandComplete(let commandTag): + action = self.state.commandCompletedReceived(commandTag) + case .dataRow(let dataRow): + action = self.state.dataRowReceived(dataRow) + case .emptyQueryResponse: + action = self.state.emptyQueryResponseReceived() + case .error(let errorResponse): + action = self.state.errorReceived(errorResponse) + case .noData: + action = self.state.noDataReceived() + case .notice(let noticeResponse): + action = self.state.noticeReceived(noticeResponse) + case .notification(let notification): + action = self.state.notificationReceived(notification) + case .parameterDescription(let parameterDescription): + action = self.state.parameterDescriptionReceived(parameterDescription) + case .parameterStatus(let parameterStatus): + action = self.state.parameterStatusReceived(parameterStatus) + case .parseComplete: + action = self.state.parseCompleteReceived() + case .portalSuspended: + action = self.state.portalSuspendedReceived() + case .readyForQuery(let transactionState): + action = self.state.readyForQueryReceived(transactionState) + case .rowDescription(let rowDescription): + action = self.state.rowDescriptionReceived(rowDescription) + case .sslSupported: + action = self.state.sslSupportedReceived(unprocessedBytes: self.decoder.unprocessedBytes) + case .sslUnsupported: + action = self.state.sslUnsupportedReceived() + } + + self.run(action, with: context) + } + func channelReadComplete(context: ChannelHandlerContext) { let action = self.state.channelReadComplete() self.run(action, with: context) diff --git a/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift b/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift index 7ab0ce30..d76b8223 100644 --- a/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift @@ -198,7 +198,44 @@ class PostgresChannelHandlerTests: XCTestCase { XCTAssertEqual(message, .password(.init(value: password))) } - + + func testHandlerThatSendsMultipleWrongMessages() { + let config = self.testConnectionConfiguration() + let handler = PostgresChannelHandler(configuration: config, configureSSLCallback: nil) + let embedded = EmbeddedChannel(handlers: [ + ReverseByteToMessageHandler(PSQLFrontendMessageDecoder()), + handler + ]) + + var maybeMessage: PostgresFrontendMessage? + XCTAssertNoThrow(embedded.connect(to: try .init(ipAddress: "0.0.0.0", port: 5432), promise: nil)) + XCTAssertNoThrow(maybeMessage = try embedded.readOutbound(as: PostgresFrontendMessage.self)) + guard case .startup(let startup) = maybeMessage else { + return XCTFail("Unexpected message") + } + + XCTAssertEqual(startup.parameters.user, config.username) + XCTAssertEqual(startup.parameters.database, config.database) + XCTAssertEqual(startup.parameters.options, nil) + XCTAssertEqual(startup.parameters.replication, .false) + + var buffer = ByteBuffer() + buffer.writeMultipleIntegers(UInt8(ascii: "R"), UInt32(8), Int32(0)) + buffer.writeMultipleIntegers(UInt8(ascii: "K"), UInt32(12), Int32(1234), Int32(5678)) + buffer.writeMultipleIntegers(UInt8(ascii: "Z"), UInt32(5), UInt8(ascii: "I")) + XCTAssertNoThrow(try embedded.writeInbound(buffer)) + XCTAssertTrue(embedded.isActive) + + buffer.clear() + buffer.writeMultipleIntegers(UInt8(ascii: "Z"), UInt32(5), UInt8(ascii: "I")) + buffer.writeMultipleIntegers(UInt8(ascii: "Z"), UInt32(5), UInt8(ascii: "I")) + buffer.writeMultipleIntegers(UInt8(ascii: "Z"), UInt32(5), UInt8(ascii: "I")) + buffer.writeMultipleIntegers(UInt8(ascii: "Z"), UInt32(5), UInt8(ascii: "I")) + + XCTAssertThrowsError(try embedded.writeInbound(buffer)) + XCTAssertFalse(embedded.isActive) + } + // MARK: Helpers func testConnectionConfiguration( From 4fd297db09ea09c6007b4abdec056f5f5387bb27 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Fri, 4 Aug 2023 22:55:49 +0200 Subject: [PATCH 004/106] PostgresFrontendMessage: refactor encoding (#381) --- .../New/BufferedMessageEncoder.swift | 35 --- Sources/PostgresNIO/New/Messages/Bind.swift | 45 ---- Sources/PostgresNIO/New/Messages/Cancel.swift | 21 -- Sources/PostgresNIO/New/Messages/Close.swift | 20 -- .../PostgresNIO/New/Messages/Describe.swift | 21 -- .../PostgresNIO/New/Messages/Execute.swift | 23 -- Sources/PostgresNIO/New/Messages/Parse.swift | 26 --- .../PostgresNIO/New/Messages/Password.swift | 13 -- .../New/Messages/SASLInitialResponse.swift | 28 --- .../New/Messages/SASLResponse.swift | 19 -- .../PostgresNIO/New/Messages/SSLRequest.swift | 21 -- .../PostgresNIO/New/Messages/Startup.swift | 40 +--- .../New/PSQLFrontendMessageEncoder.swift | 85 -------- .../New/PostgresChannelHandler.swift | 113 +++++----- .../New/PostgresFrontendMessage.swift | 94 +++++++- .../New/PostgresFrontendMessageEncoder.swift | 205 ++++++++++++++++++ .../PSQLFrontendMessageDecoder.swift | 2 +- .../New/Messages/BindTests.swift | 12 +- .../New/Messages/CancelTests.swift | 15 +- .../New/Messages/CloseTests.swift | 20 +- .../New/Messages/DescribeTests.swift | 18 +- .../New/Messages/ExecuteTests.swift | 9 +- .../New/Messages/ParseTests.swift | 39 ++-- .../New/Messages/PasswordTests.swift | 8 +- .../Messages/SASLInitialResponseTests.swift | 37 ++-- .../New/Messages/SASLResponseTests.swift | 26 +-- .../New/Messages/SSLRequestTests.swift | 12 +- .../New/Messages/StartupTests.swift | 11 +- .../New/PSQLFrontendMessageTests.swift | 24 +- .../New/PostgresChannelHandlerTests.swift | 20 +- 30 files changed, 464 insertions(+), 598 deletions(-) delete mode 100644 Sources/PostgresNIO/New/BufferedMessageEncoder.swift delete mode 100644 Sources/PostgresNIO/New/Messages/Bind.swift delete mode 100644 Sources/PostgresNIO/New/Messages/Cancel.swift delete mode 100644 Sources/PostgresNIO/New/Messages/Close.swift delete mode 100644 Sources/PostgresNIO/New/Messages/Describe.swift delete mode 100644 Sources/PostgresNIO/New/Messages/Execute.swift delete mode 100644 Sources/PostgresNIO/New/Messages/Parse.swift delete mode 100644 Sources/PostgresNIO/New/Messages/Password.swift delete mode 100644 Sources/PostgresNIO/New/Messages/SASLInitialResponse.swift delete mode 100644 Sources/PostgresNIO/New/Messages/SASLResponse.swift delete mode 100644 Sources/PostgresNIO/New/Messages/SSLRequest.swift delete mode 100644 Sources/PostgresNIO/New/PSQLFrontendMessageEncoder.swift create mode 100644 Sources/PostgresNIO/New/PostgresFrontendMessageEncoder.swift diff --git a/Sources/PostgresNIO/New/BufferedMessageEncoder.swift b/Sources/PostgresNIO/New/BufferedMessageEncoder.swift deleted file mode 100644 index f202fcff..00000000 --- a/Sources/PostgresNIO/New/BufferedMessageEncoder.swift +++ /dev/null @@ -1,35 +0,0 @@ -import NIOCore - -struct BufferedMessageEncoder { - private enum State { - case flushed - case writable - } - - private var buffer: ByteBuffer - private var state: State = .writable - private var encoder: PSQLFrontendMessageEncoder - - init(buffer: ByteBuffer, encoder: PSQLFrontendMessageEncoder) { - self.buffer = buffer - self.encoder = encoder - } - - mutating func encode(_ message: PostgresFrontendMessage) { - switch self.state { - case .flushed: - self.state = .writable - self.buffer.clear() - - case .writable: - break - } - - self.encoder.encode(data: message, out: &self.buffer) - } - - mutating func flush() -> ByteBuffer { - self.state = .flushed - return self.buffer - } -} diff --git a/Sources/PostgresNIO/New/Messages/Bind.swift b/Sources/PostgresNIO/New/Messages/Bind.swift deleted file mode 100644 index 898018d4..00000000 --- a/Sources/PostgresNIO/New/Messages/Bind.swift +++ /dev/null @@ -1,45 +0,0 @@ -import NIOCore - -extension PostgresFrontendMessage { - - struct Bind: PSQLMessagePayloadEncodable, Equatable { - /// The name of the destination portal (an empty string selects the unnamed portal). - var portalName: String - - /// The name of the source prepared statement (an empty string selects the unnamed prepared statement). - var preparedStatementName: String - - /// The number of parameter values that follow (possibly zero). This must match the number of parameters needed by the query. - var bind: PostgresBindings - - func encode(into buffer: inout ByteBuffer) { - buffer.writeNullTerminatedString(self.portalName) - buffer.writeNullTerminatedString(self.preparedStatementName) - - // The number of parameter format codes that follow (denoted C below). This can be - // zero to indicate that there are no parameters or that the parameters all use the - // default format (text); or one, in which case the specified format code is applied - // to all parameters; or it can equal the actual number of parameters. - buffer.writeInteger(UInt16(self.bind.count)) - - // The parameter format codes. Each must presently be zero (text) or one (binary). - self.bind.metadata.forEach { - buffer.writeInteger($0.format.rawValue) - } - - buffer.writeInteger(UInt16(self.bind.count)) - - var parametersCopy = self.bind.bytes - buffer.writeBuffer(¶metersCopy) - - // The number of result-column format codes that follow (denoted R below). This can be - // zero to indicate that there are no result columns or that the result columns should - // all use the default format (text); or one, in which case the specified format code - // is applied to all result columns (if any); or it can equal the actual number of - // result columns of the query. - buffer.writeInteger(1, as: Int16.self) - // The result-column format codes. Each must presently be zero (text) or one (binary). - buffer.writeInteger(PostgresFormat.binary.rawValue, as: Int16.self) - } - } -} diff --git a/Sources/PostgresNIO/New/Messages/Cancel.swift b/Sources/PostgresNIO/New/Messages/Cancel.swift deleted file mode 100644 index 2f29d239..00000000 --- a/Sources/PostgresNIO/New/Messages/Cancel.swift +++ /dev/null @@ -1,21 +0,0 @@ -import NIOCore - -extension PostgresFrontendMessage { - - struct Cancel: PSQLMessagePayloadEncodable, Equatable { - /// The cancel request code. The value is chosen to contain 1234 in the most significant 16 bits, - /// and 5678 in the least significant 16 bits. (To avoid confusion, this code must not be the same - /// as any protocol version number.) - let cancelRequestCode: Int32 = 80877102 - - /// The process ID of the target backend. - let processID: Int32 - - /// The secret key for the target backend. - let secretKey: Int32 - - func encode(into buffer: inout ByteBuffer) { - buffer.writeMultipleIntegers(self.cancelRequestCode, self.processID, self.secretKey) - } - } -} diff --git a/Sources/PostgresNIO/New/Messages/Close.swift b/Sources/PostgresNIO/New/Messages/Close.swift deleted file mode 100644 index 7f038f94..00000000 --- a/Sources/PostgresNIO/New/Messages/Close.swift +++ /dev/null @@ -1,20 +0,0 @@ -import NIOCore - -extension PostgresFrontendMessage { - - enum Close: PSQLMessagePayloadEncodable, Equatable { - case preparedStatement(String) - case portal(String) - - func encode(into buffer: inout ByteBuffer) { - switch self { - case .preparedStatement(let name): - buffer.writeInteger(UInt8(ascii: "S")) - buffer.writeNullTerminatedString(name) - case .portal(let name): - buffer.writeInteger(UInt8(ascii: "P")) - buffer.writeNullTerminatedString(name) - } - } - } -} diff --git a/Sources/PostgresNIO/New/Messages/Describe.swift b/Sources/PostgresNIO/New/Messages/Describe.swift deleted file mode 100644 index 76167d32..00000000 --- a/Sources/PostgresNIO/New/Messages/Describe.swift +++ /dev/null @@ -1,21 +0,0 @@ -import NIOCore - -extension PostgresFrontendMessage { - - enum Describe: PSQLMessagePayloadEncodable, Equatable { - - case preparedStatement(String) - case portal(String) - - func encode(into buffer: inout ByteBuffer) { - switch self { - case .preparedStatement(let name): - buffer.writeInteger(UInt8(ascii: "S")) - buffer.writeNullTerminatedString(name) - case .portal(let name): - buffer.writeInteger(UInt8(ascii: "P")) - buffer.writeNullTerminatedString(name) - } - } - } -} diff --git a/Sources/PostgresNIO/New/Messages/Execute.swift b/Sources/PostgresNIO/New/Messages/Execute.swift deleted file mode 100644 index 17646484..00000000 --- a/Sources/PostgresNIO/New/Messages/Execute.swift +++ /dev/null @@ -1,23 +0,0 @@ -import NIOCore - -extension PostgresFrontendMessage { - - struct Execute: PSQLMessagePayloadEncodable, Equatable { - /// The name of the portal to execute (an empty string selects the unnamed portal). - let portalName: String - - /// Maximum number of rows to return, if portal contains a query that returns rows (ignored otherwise). Zero denotes “no limit”. - let maxNumberOfRows: Int32 - - init(portalName: String, maxNumberOfRows: Int32 = 0) { - self.portalName = portalName - self.maxNumberOfRows = maxNumberOfRows - } - - func encode(into buffer: inout ByteBuffer) { - buffer.writeNullTerminatedString(self.portalName) - buffer.writeInteger(self.maxNumberOfRows) - } - } - -} diff --git a/Sources/PostgresNIO/New/Messages/Parse.swift b/Sources/PostgresNIO/New/Messages/Parse.swift deleted file mode 100644 index 9d3cfa0b..00000000 --- a/Sources/PostgresNIO/New/Messages/Parse.swift +++ /dev/null @@ -1,26 +0,0 @@ -import NIOCore - -extension PostgresFrontendMessage { - - struct Parse: PSQLMessagePayloadEncodable, Equatable { - /// The name of the destination prepared statement (an empty string selects the unnamed prepared statement). - let preparedStatementName: String - - /// The query string to be parsed. - let query: String - - /// The number of parameter data types specified (can be zero). Note that this is not an indication of the number of parameters that might appear in the query string, only the number that the frontend wants to prespecify types for. - let parameters: [PostgresDataType] - - func encode(into buffer: inout ByteBuffer) { - buffer.writeNullTerminatedString(self.preparedStatementName) - buffer.writeNullTerminatedString(self.query) - buffer.writeInteger(UInt16(self.parameters.count)) - - self.parameters.forEach { dataType in - buffer.writeInteger(dataType.rawValue) - } - } - } - -} diff --git a/Sources/PostgresNIO/New/Messages/Password.swift b/Sources/PostgresNIO/New/Messages/Password.swift deleted file mode 100644 index 81d7ab30..00000000 --- a/Sources/PostgresNIO/New/Messages/Password.swift +++ /dev/null @@ -1,13 +0,0 @@ -import NIOCore - -extension PostgresFrontendMessage { - - struct Password: PSQLMessagePayloadEncodable, Equatable { - let value: String - - func encode(into buffer: inout ByteBuffer) { - buffer.writeNullTerminatedString(value) - } - } - -} diff --git a/Sources/PostgresNIO/New/Messages/SASLInitialResponse.swift b/Sources/PostgresNIO/New/Messages/SASLInitialResponse.swift deleted file mode 100644 index 73db9332..00000000 --- a/Sources/PostgresNIO/New/Messages/SASLInitialResponse.swift +++ /dev/null @@ -1,28 +0,0 @@ -import NIOCore - -extension PostgresFrontendMessage { - - struct SASLInitialResponse: PSQLMessagePayloadEncodable, Equatable { - - let saslMechanism: String - let initialData: [UInt8] - - /// Creates a new `SSLRequest`. - init(saslMechanism: String, initialData: [UInt8]) { - self.saslMechanism = saslMechanism - self.initialData = initialData - } - - /// Serializes this message into a byte buffer. - func encode(into buffer: inout ByteBuffer) { - buffer.writeNullTerminatedString(self.saslMechanism) - - if self.initialData.count > 0 { - buffer.writeInteger(Int32(self.initialData.count)) - buffer.writeBytes(self.initialData) - } else { - buffer.writeInteger(Int32(-1)) - } - } - } -} diff --git a/Sources/PostgresNIO/New/Messages/SASLResponse.swift b/Sources/PostgresNIO/New/Messages/SASLResponse.swift deleted file mode 100644 index a6709dcd..00000000 --- a/Sources/PostgresNIO/New/Messages/SASLResponse.swift +++ /dev/null @@ -1,19 +0,0 @@ -import NIOCore - -extension PostgresFrontendMessage { - - struct SASLResponse: PSQLMessagePayloadEncodable, Equatable { - - let data: [UInt8] - - /// Creates a new `SSLRequest`. - init(data: [UInt8]) { - self.data = data - } - - /// Serializes this message into a byte buffer. - func encode(into buffer: inout ByteBuffer) { - buffer.writeBytes(self.data) - } - } -} diff --git a/Sources/PostgresNIO/New/Messages/SSLRequest.swift b/Sources/PostgresNIO/New/Messages/SSLRequest.swift deleted file mode 100644 index 6f9c45a3..00000000 --- a/Sources/PostgresNIO/New/Messages/SSLRequest.swift +++ /dev/null @@ -1,21 +0,0 @@ -import NIOCore - -extension PostgresFrontendMessage { - /// A message asking the PostgreSQL server if TLS is supported - /// For more info, see https://www.postgresql.org/docs/10/static/protocol-flow.html#id-1.10.5.7.11 - struct SSLRequest: PSQLMessagePayloadEncodable, Equatable { - /// The SSL request code. The value is chosen to contain 1234 in the most significant 16 bits, - /// and 5679 in the least significant 16 bits. - let code: Int32 - - /// Creates a new `SSLRequest`. - init() { - self.code = 80877103 - } - - /// Serializes this message into a byte buffer. - func encode(into buffer: inout ByteBuffer) { - buffer.writeInteger(self.code) - } - } -} diff --git a/Sources/PostgresNIO/New/Messages/Startup.swift b/Sources/PostgresNIO/New/Messages/Startup.swift index f7da2127..16d23e09 100644 --- a/Sources/PostgresNIO/New/Messages/Startup.swift +++ b/Sources/PostgresNIO/New/Messages/Startup.swift @@ -1,13 +1,14 @@ import NIOCore extension PostgresFrontendMessage { - struct Startup: PSQLMessagePayloadEncodable, Equatable { + struct Startup: Hashable { + static let versionThree: Int32 = 0x00_03_00_00 /// Creates a `Startup` with "3.0" as the protocol version. static func versionThree(parameters: Parameters) -> Startup { - return .init(protocolVersion: 0x00_03_00_00, parameters: parameters) + return .init(protocolVersion: Self.versionThree, parameters: parameters) } - + /// The protocol version number. The most significant 16 bits are the major /// version number (3 for the protocol described here). The least significant /// 16 bits are the minor version number (0 for the protocol described here). @@ -16,7 +17,7 @@ extension PostgresFrontendMessage { /// The protocol version number is followed by one or more pairs of parameter /// name and value strings. A zero byte is required as a terminator after /// the last name/value pair. `user` is required, others are optional. - struct Parameters: Equatable { + struct Parameters: Hashable { enum Replication { case `true` case `false` @@ -47,36 +48,5 @@ extension PostgresFrontendMessage { self.protocolVersion = protocolVersion self.parameters = parameters } - - /// Serializes this message into a byte buffer. - func encode(into buffer: inout ByteBuffer) { - buffer.writeInteger(self.protocolVersion) - buffer.writeNullTerminatedString("user") - buffer.writeNullTerminatedString(self.parameters.user) - - if let database = self.parameters.database { - buffer.writeNullTerminatedString("database") - buffer.writeNullTerminatedString(database) - } - - if let options = self.parameters.options { - buffer.writeNullTerminatedString("options") - buffer.writeNullTerminatedString(options) - } - - switch self.parameters.replication { - case .database: - buffer.writeNullTerminatedString("replication") - buffer.writeNullTerminatedString("replication") - case .true: - buffer.writeNullTerminatedString("replication") - buffer.writeNullTerminatedString("true") - case .false: - break - } - - buffer.writeInteger(UInt8(0)) - } } - } diff --git a/Sources/PostgresNIO/New/PSQLFrontendMessageEncoder.swift b/Sources/PostgresNIO/New/PSQLFrontendMessageEncoder.swift deleted file mode 100644 index 24155d84..00000000 --- a/Sources/PostgresNIO/New/PSQLFrontendMessageEncoder.swift +++ /dev/null @@ -1,85 +0,0 @@ -import NIOCore - -struct PSQLFrontendMessageEncoder: MessageToByteEncoder { - typealias OutboundIn = PostgresFrontendMessage - - init() {} - - func encode(data message: PostgresFrontendMessage, out buffer: inout ByteBuffer) { - switch message { - case .bind(let bind): - buffer.writeInteger(message.id.rawValue) - let startIndex = buffer.writerIndex - buffer.writeInteger(Int32(0)) // placeholder for length - bind.encode(into: &buffer) - let length = Int32(buffer.writerIndex - startIndex) - buffer.setInteger(length, at: startIndex) - - case .cancel(let cancel): - // cancel requests don't have an identifier - self.encode(payload: cancel, into: &buffer) - - case .close(let close): - self.encode(messageID: message.id, payload: close, into: &buffer) - - case .describe(let describe): - self.encode(messageID: message.id, payload: describe, into: &buffer) - - case .execute(let execute): - self.encode(messageID: message.id, payload: execute, into: &buffer) - - case .flush: - self.encode(messageID: message.id, payload: EmptyPayload(), into: &buffer) - - case .parse(let parse): - self.encode(messageID: message.id, payload: parse, into: &buffer) - - case .password(let password): - self.encode(messageID: message.id, payload: password, into: &buffer) - - case .saslInitialResponse(let saslInitialResponse): - self.encode(messageID: message.id, payload: saslInitialResponse, into: &buffer) - - case .saslResponse(let saslResponse): - self.encode(messageID: message.id, payload: saslResponse, into: &buffer) - - case .sslRequest(let request): - // sslRequests don't have an identifier - self.encode(payload: request, into: &buffer) - - case .startup(let startup): - // startup requests don't have an identifier - self.encode(payload: startup, into: &buffer) - - case .sync: - self.encode(messageID: message.id, payload: EmptyPayload(), into: &buffer) - - case .terminate: - self.encode(messageID: message.id, payload: EmptyPayload(), into: &buffer) - } - } - - private struct EmptyPayload: PSQLMessagePayloadEncodable { - func encode(into buffer: inout ByteBuffer) {} - } - - private func encode( - messageID: PostgresFrontendMessage.ID, - payload: Payload, - into buffer: inout ByteBuffer) - { - buffer.psqlWriteFrontendMessageID(messageID) - self.encode(payload: payload, into: &buffer) - } - - private func encode( - payload: Payload, - into buffer: inout ByteBuffer) - { - let startIndex = buffer.writerIndex - buffer.writeInteger(Int32(0)) // placeholder for length - payload.encode(into: &buffer) - let length = Int32(buffer.writerIndex - startIndex) - buffer.setInteger(length, at: startIndex) - } -} diff --git a/Sources/PostgresNIO/New/PostgresChannelHandler.swift b/Sources/PostgresNIO/New/PostgresChannelHandler.swift index fdb6a443..09feb521 100644 --- a/Sources/PostgresNIO/New/PostgresChannelHandler.swift +++ b/Sources/PostgresNIO/New/PostgresChannelHandler.swift @@ -21,7 +21,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler { private var handlerContext: ChannelHandlerContext? private var rowStream: PSQLRowStream? private var decoder: NIOSingleStepByteToMessageProcessor - private var encoder: BufferedMessageEncoder! + private var encoder: PostgresFrontendMessageEncoder! private let configuration: PostgresConnection.InternalConfiguration private let configureSSLCallback: ((Channel) throws -> Void)? @@ -58,10 +58,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler { func handlerAdded(context: ChannelHandlerContext) { self.handlerContext = context - self.encoder = BufferedMessageEncoder( - buffer: context.channel.allocator.buffer(capacity: 256), - encoder: PSQLFrontendMessageEncoder() - ) + self.encoder = PostgresFrontendMessageEncoder(buffer: context.channel.allocator.buffer(capacity: 256)) if context.channel.isActive { self.connected(context: context) @@ -239,19 +236,19 @@ final class PostgresChannelHandler: ChannelDuplexHandler { case .wait: break case .sendStartupMessage(let authContext): - self.encoder.encode(.startup(.versionThree(parameters: authContext.toStartupParameters()))) - context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()), promise: nil) + self.encoder.startup(authContext.toStartupParameters()) + context.writeAndFlush(self.wrapOutboundOut(self.encoder.flushBuffer()), promise: nil) case .sendSSLRequest: - self.encoder.encode(.sslRequest(.init())) - context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()), promise: nil) + self.encoder.ssl() + context.writeAndFlush(self.wrapOutboundOut(self.encoder.flushBuffer()), promise: nil) case .sendPasswordMessage(let mode, let authContext): self.sendPasswordMessage(mode: mode, authContext: authContext, context: context) case .sendSaslInitialResponse(let name, let initialResponse): - self.encoder.encode(.saslInitialResponse(.init(saslMechanism: name, initialData: initialResponse))) - context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()), promise: nil) + self.encoder.saslInitialResponse(mechanism: name, bytes: initialResponse) + context.writeAndFlush(self.wrapOutboundOut(self.encoder.flushBuffer()), promise: nil) case .sendSaslResponse(let bytes): - self.encoder.encode(.saslResponse(.init(data: bytes))) - context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()), promise: nil) + self.encoder.saslResponse(bytes) + context.writeAndFlush(self.wrapOutboundOut(self.encoder.flushBuffer()), promise: nil) case .closeConnectionAndCleanup(let cleanupContext): self.closeConnectionAndCleanup(cleanupContext, context: context) case .fireChannelInactive: @@ -315,8 +312,8 @@ final class PostgresChannelHandler: ChannelDuplexHandler { // The normal, graceful termination procedure is that the frontend sends a Terminate // message and immediately closes the connection. On receipt of this message, the // backend closes the connection and terminates. - self.encoder.encode(.terminate) - context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()), promise: nil) + self.encoder.terminate() + context.writeAndFlush(self.wrapOutboundOut(self.encoder.flushBuffer()), promise: nil) } context.close(mode: .all, promise: promise) case .succeedPreparedStatementCreation(let preparedContext, with: let rowDescription): @@ -381,89 +378,79 @@ final class PostgresChannelHandler: ChannelDuplexHandler { hash2.append(salt.3) let hash = Insecure.MD5.hash(data: hash2).md5PrefixHexdigest() - self.encoder.encode(.password(.init(value: hash))) - context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()), promise: nil) + self.encoder.password(hash.utf8) + context.writeAndFlush(self.wrapOutboundOut(self.encoder.flushBuffer()), promise: nil) case .cleartext: - self.encoder.encode(.password(.init(value: authContext.password ?? ""))) - context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()), promise: nil) + self.encoder.password((authContext.password ?? "").utf8) + context.writeAndFlush(self.wrapOutboundOut(self.encoder.flushBuffer()), promise: nil) } } private func sendCloseAndSyncMessage(_ sendClose: CloseTarget, context: ChannelHandlerContext) { switch sendClose { case .preparedStatement(let name): - self.encoder.encode(.close(.preparedStatement(name))) - self.encoder.encode(.sync) - context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()), promise: nil) + self.encoder.closePreparedStatement(name) + self.encoder.sync() + context.writeAndFlush(self.wrapOutboundOut(self.encoder.flushBuffer()), promise: nil) case .portal(let name): - self.encoder.encode(.close(.portal(name))) - self.encoder.encode(.sync) - context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()), promise: nil) + self.encoder.closePortal(name) + self.encoder.sync() + context.writeAndFlush(self.wrapOutboundOut(self.encoder.flushBuffer()), promise: nil) } } private func sendParseDecribeAndSyncMessage( statementName: String, query: String, - context: ChannelHandlerContext) - { + context: ChannelHandlerContext + ) { precondition(self.rowStream == nil, "Expected to not have an open stream at this point") - let parse = PostgresFrontendMessage.Parse( - preparedStatementName: statementName, - query: query, - parameters: []) - self.encoder.encode(.parse(parse)) - self.encoder.encode(.describe(.preparedStatement(statementName))) - self.encoder.encode(.sync) - context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()), promise: nil) + self.encoder.parse(preparedStatementName: statementName, query: query, parameters: []) + self.encoder.describePreparedStatement(statementName) + self.encoder.sync() + context.writeAndFlush(self.wrapOutboundOut(self.encoder.flushBuffer()), promise: nil) } private func sendBindExecuteAndSyncMessage( executeStatement: PSQLExecuteStatement, context: ChannelHandlerContext ) { - let bind = PostgresFrontendMessage.Bind( + self.encoder.bind( portalName: "", preparedStatementName: executeStatement.name, - bind: executeStatement.binds) - - self.encoder.encode(.bind(bind)) - self.encoder.encode(.execute(.init(portalName: ""))) - self.encoder.encode(.sync) - context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()), promise: nil) + bind: executeStatement.binds + ) + self.encoder.execute(portalName: "") + self.encoder.sync() + context.writeAndFlush(self.wrapOutboundOut(self.encoder.flushBuffer()), promise: nil) } private func sendParseDescribeBindExecuteAndSyncMessage( query: PostgresQuery, - context: ChannelHandlerContext) - { + context: ChannelHandlerContext + ) { precondition(self.rowStream == nil, "Expected to not have an open stream at this point") let unnamedStatementName = "" - let parse = PostgresFrontendMessage.Parse( + self.encoder.parse( preparedStatementName: unnamedStatementName, query: query.sql, - parameters: query.binds.metadata.map(\.dataType)) - let bind = PostgresFrontendMessage.Bind( - portalName: "", - preparedStatementName: unnamedStatementName, - bind: query.binds) - - self.encoder.encode(.parse(parse)) - self.encoder.encode(.describe(.preparedStatement(""))) - self.encoder.encode(.bind(bind)) - self.encoder.encode(.execute(.init(portalName: ""))) - self.encoder.encode(.sync) - context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()), promise: nil) + parameters: query.binds.metadata.lazy.map(\.dataType) + ) + self.encoder.describePreparedStatement(unnamedStatementName) + self.encoder.bind(portalName: "", preparedStatementName: unnamedStatementName, bind: query.binds) + self.encoder.execute(portalName: "") + self.encoder.sync() + context.writeAndFlush(self.wrapOutboundOut(self.encoder.flushBuffer()), promise: nil) } private func succeedQueryWithRowStream( _ queryContext: ExtendedQueryContext, columns: [RowDescription.Column], - context: ChannelHandlerContext) - { + context: ChannelHandlerContext + ) { let rows = PSQLRowStream( rowDescription: columns, queryContext: queryContext, @@ -477,8 +464,8 @@ final class PostgresChannelHandler: ChannelDuplexHandler { private func succeedQueryWithoutRowStream( _ queryContext: ExtendedQueryContext, commandTag: String, - context: ChannelHandlerContext) - { + context: ChannelHandlerContext + ) { let rows = PSQLRowStream( rowDescription: [], queryContext: queryContext, @@ -490,8 +477,8 @@ final class PostgresChannelHandler: ChannelDuplexHandler { private func closeConnectionAndCleanup( _ cleanup: ConnectionStateMachine.ConnectionAction.CleanUpContext, - context: ChannelHandlerContext) - { + context: ChannelHandlerContext + ) { self.logger.debug("Cleaning up and closing connection.", metadata: [.error: "\(cleanup.error)"]) // 1. fail all tasks diff --git a/Sources/PostgresNIO/New/PostgresFrontendMessage.swift b/Sources/PostgresNIO/New/PostgresFrontendMessage.swift index 2017cd1a..3963bd62 100644 --- a/Sources/PostgresNIO/New/PostgresFrontendMessage.swift +++ b/Sources/PostgresNIO/New/PostgresFrontendMessage.swift @@ -5,6 +5,98 @@ import NIOCore /// All messages are defined in the official Postgres Documentation in the section /// [Frontend/Backend Protocol – Message Formats](https://www.postgresql.org/docs/13/protocol-message-formats.html) enum PostgresFrontendMessage: Equatable { + + struct Bind: Hashable { + /// The name of the destination portal (an empty string selects the unnamed portal). + var portalName: String + + /// The name of the source prepared statement (an empty string selects the unnamed prepared statement). + var preparedStatementName: String + + /// The number of parameter values that follow (possibly zero). This must match the number of parameters needed by the query. + var bind: PostgresBindings + } + + struct Cancel: Equatable { + /// The cancel request code. The value is chosen to contain 1234 in the most significant 16 bits, + /// and 5678 in the least significant 16 bits. (To avoid confusion, this code must not be the same + /// as any protocol version number.) + static let requestCode: Int32 = 80877102 + + /// The process ID of the target backend. + let processID: Int32 + + /// The secret key for the target backend. + let secretKey: Int32 + } + + enum Close: Hashable { + case preparedStatement(String) + case portal(String) + } + + enum Describe: Hashable { + case preparedStatement(String) + case portal(String) + } + + struct Execute: Hashable { + /// The name of the portal to execute (an empty string selects the unnamed portal). + let portalName: String + + /// Maximum number of rows to return, if portal contains a query that returns rows (ignored otherwise). Zero denotes “no limit”. + let maxNumberOfRows: Int32 + + init(portalName: String, maxNumberOfRows: Int32 = 0) { + self.portalName = portalName + self.maxNumberOfRows = maxNumberOfRows + } + } + + struct Parse: Hashable { + /// The name of the destination prepared statement (an empty string selects the unnamed prepared statement). + let preparedStatementName: String + + /// The query string to be parsed. + let query: String + + /// The number of parameter data types specified (can be zero). Note that this is not an indication of the number of parameters that might appear in the query string, only the number that the frontend wants to prespecify types for. + let parameters: [PostgresDataType] + } + + struct Password: Hashable { + let value: String + } + + struct SASLInitialResponse: Hashable { + + let saslMechanism: String + let initialData: [UInt8] + + /// Creates a new `SSLRequest`. + init(saslMechanism: String, initialData: [UInt8]) { + self.saslMechanism = saslMechanism + self.initialData = initialData + } + } + + struct SASLResponse: Hashable { + var data: [UInt8] + + /// Creates a new `SSLRequest`. + init(data: [UInt8]) { + self.data = data + } + } + + /// A message asking the PostgreSQL server if TLS is supported + /// For more info, see https://www.postgresql.org/docs/10/static/protocol-flow.html#id-1.10.5.7.11 + struct SSLRequest: Hashable { + /// The SSL request code. The value is chosen to contain 1234 in the most significant 16 bits, + /// and 5679 in the least significant 16 bits. + static let requestCode: Int32 = 80877103 + } + case bind(Bind) case cancel(Cancel) case close(Close) @@ -15,7 +107,7 @@ enum PostgresFrontendMessage: Equatable { case password(Password) case saslInitialResponse(SASLInitialResponse) case saslResponse(SASLResponse) - case sslRequest(SSLRequest) + case sslRequest case sync case startup(Startup) case terminate diff --git a/Sources/PostgresNIO/New/PostgresFrontendMessageEncoder.swift b/Sources/PostgresNIO/New/PostgresFrontendMessageEncoder.swift new file mode 100644 index 00000000..46dbba42 --- /dev/null +++ b/Sources/PostgresNIO/New/PostgresFrontendMessageEncoder.swift @@ -0,0 +1,205 @@ +import NIOCore + +struct PostgresFrontendMessageEncoder { + private enum State { + case flushed + case writable + } + + private var buffer: ByteBuffer + private var state: State = .writable + + init(buffer: ByteBuffer) { + self.buffer = buffer + } + + mutating func startup(_ parameters: PostgresFrontendMessage.Startup.Parameters) { + self.clearIfNeeded() + self.encodeLengthPrefixed { buffer in + buffer.writeInteger(PostgresFrontendMessage.Startup.versionThree) + buffer.writeNullTerminatedString("user") + buffer.writeNullTerminatedString(parameters.user) + + if let database = parameters.database { + buffer.writeNullTerminatedString("database") + buffer.writeNullTerminatedString(database) + } + + if let options = parameters.options { + buffer.writeNullTerminatedString("options") + buffer.writeNullTerminatedString(options) + } + + switch parameters.replication { + case .database: + buffer.writeNullTerminatedString("replication") + buffer.writeNullTerminatedString("replication") + case .true: + buffer.writeNullTerminatedString("replication") + buffer.writeNullTerminatedString("true") + case .false: + break + } + + buffer.writeInteger(UInt8(0)) + } + } + + mutating func bind(portalName: String, preparedStatementName: String, bind: PostgresBindings) { + self.clearIfNeeded() + self.buffer.psqlWriteFrontendMessageID(.bind) + self.encodeLengthPrefixed { buffer in + buffer.writeNullTerminatedString(portalName) + buffer.writeNullTerminatedString(preparedStatementName) + + // The number of parameter format codes that follow (denoted C below). This can be + // zero to indicate that there are no parameters or that the parameters all use the + // default format (text); or one, in which case the specified format code is applied + // to all parameters; or it can equal the actual number of parameters. + buffer.writeInteger(UInt16(bind.count)) + + // The parameter format codes. Each must presently be zero (text) or one (binary). + bind.metadata.forEach { + buffer.writeInteger($0.format.rawValue) + } + + buffer.writeInteger(UInt16(bind.count)) + + var parametersCopy = bind.bytes + buffer.writeBuffer(¶metersCopy) + + // The number of result-column format codes that follow (denoted R below). This can be + // zero to indicate that there are no result columns or that the result columns should + // all use the default format (text); or one, in which case the specified format code + // is applied to all result columns (if any); or it can equal the actual number of + // result columns of the query. + buffer.writeInteger(1, as: Int16.self) + // The result-column format codes. Each must presently be zero (text) or one (binary). + buffer.writeInteger(PostgresFormat.binary.rawValue, as: Int16.self) + } + } + + mutating func cancel(processID: Int32, secretKey: Int32) { + self.clearIfNeeded() + self.buffer.writeMultipleIntegers(UInt32(16), PostgresFrontendMessage.Cancel.requestCode, processID, secretKey) + } + + mutating func closePreparedStatement(_ preparedStatement: String) { + self.clearIfNeeded() + self.buffer.writeMultipleIntegers(PostgresFrontendMessage.ID.close.rawValue, UInt32(6 + preparedStatement.utf8.count), UInt8(ascii: "S")) + self.buffer.writeNullTerminatedString(preparedStatement) + } + + mutating func closePortal(_ portal: String) { + self.clearIfNeeded() + self.buffer.writeMultipleIntegers(PostgresFrontendMessage.ID.close.rawValue, UInt32(6 + portal.utf8.count), UInt8(ascii: "P")) + self.buffer.writeNullTerminatedString(portal) + } + + mutating func describePreparedStatement(_ preparedStatement: String) { + self.clearIfNeeded() + self.buffer.writeMultipleIntegers(PostgresFrontendMessage.ID.describe.rawValue, UInt32(6 + preparedStatement.utf8.count), UInt8(ascii: "S")) + self.buffer.writeNullTerminatedString(preparedStatement) + } + + mutating func describePortal(_ portal: String) { + self.clearIfNeeded() + self.buffer.writeMultipleIntegers(PostgresFrontendMessage.ID.describe.rawValue, UInt32(6 + portal.utf8.count), UInt8(ascii: "P")) + self.buffer.writeNullTerminatedString(portal) + } + + mutating func execute(portalName: String, maxNumberOfRows: Int32 = 0) { + self.clearIfNeeded() + self.buffer.writeMultipleIntegers(PostgresFrontendMessage.ID.execute.rawValue, UInt32(9 + portalName.utf8.count)) + self.buffer.writeNullTerminatedString(portalName) + self.buffer.writeInteger(maxNumberOfRows) + } + + mutating func parse(preparedStatementName: String, query: String, parameters: Parameters) where Parameters.Element == PostgresDataType { + self.clearIfNeeded() + self.buffer.writeMultipleIntegers( + PostgresFrontendMessage.ID.parse.rawValue, + UInt32(4 + preparedStatementName.utf8.count + 1 + query.utf8.count + 1 + 2 + MemoryLayout.size * parameters.count) + ) + self.buffer.writeNullTerminatedString(preparedStatementName) + self.buffer.writeNullTerminatedString(query) + self.buffer.writeInteger(UInt16(parameters.count)) + + for dataType in parameters { + self.buffer.writeInteger(dataType.rawValue) + } + } + + mutating func password(_ bytes: Bytes) where Bytes.Element == UInt8 { + self.clearIfNeeded() + self.buffer.writeMultipleIntegers(PostgresFrontendMessage.ID.password.rawValue, UInt32(5 + bytes.count)) + self.buffer.writeBytes(bytes) + self.buffer.writeInteger(UInt8(0)) + } + + mutating func flush() { + self.clearIfNeeded() + self.buffer.writeMultipleIntegers(PostgresFrontendMessage.ID.flush.rawValue, UInt32(4)) + } + + mutating func saslResponse(_ bytes: Bytes) where Bytes.Element == UInt8 { + self.clearIfNeeded() + self.buffer.writeMultipleIntegers(PostgresFrontendMessage.ID.saslResponse.rawValue, UInt32(4 + bytes.count)) + self.buffer.writeBytes(bytes) + } + + mutating func saslInitialResponse(mechanism: String, bytes: Bytes) where Bytes.Element == UInt8 { + self.clearIfNeeded() + self.buffer.writeMultipleIntegers( + PostgresFrontendMessage.ID.saslInitialResponse.rawValue, + UInt32(4 + mechanism.utf8.count + 1 + 4 + bytes.count) + ) + self.buffer.writeNullTerminatedString(mechanism) + if bytes.count > 0 { + self.buffer.writeInteger(Int32(bytes.count)) + self.buffer.writeBytes(bytes) + } else { + self.buffer.writeInteger(Int32(-1)) + } + } + + mutating func ssl() { + self.clearIfNeeded() + self.buffer.writeMultipleIntegers(UInt32(8), PostgresFrontendMessage.SSLRequest.requestCode) + } + + mutating func sync() { + self.clearIfNeeded() + self.buffer.writeMultipleIntegers(PostgresFrontendMessage.ID.sync.rawValue, UInt32(4)) + } + + mutating func terminate() { + self.clearIfNeeded() + self.buffer.writeMultipleIntegers(PostgresFrontendMessage.ID.terminate.rawValue, UInt32(4)) + } + + mutating func flushBuffer() -> ByteBuffer { + self.state = .flushed + return self.buffer + } + + private mutating func clearIfNeeded() { + switch self.state { + case .flushed: + self.state = .writable + self.buffer.clear() + + case .writable: + break + } + } + + private mutating func encodeLengthPrefixed(_ encode: (inout ByteBuffer) -> ()) { + let startIndex = self.buffer.writerIndex + self.buffer.writeInteger(UInt32(0)) // placeholder for length + encode(&self.buffer) + let length = UInt32(self.buffer.writerIndex - startIndex) + self.buffer.setInteger(length, at: startIndex) + } + +} diff --git a/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift b/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift index 311c41bd..342907ea 100644 --- a/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift +++ b/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift @@ -34,7 +34,7 @@ struct PSQLFrontendMessageDecoder: NIOSingleStepByteToMessageDecoder { switch code { case 80877103: self.isInStartup = true - return .sslRequest(.init()) + return .sslRequest case 196608: var user: String? diff --git a/Tests/PostgresNIOTests/New/Messages/BindTests.swift b/Tests/PostgresNIOTests/New/Messages/BindTests.swift index 85768b10..d5ec5b30 100644 --- a/Tests/PostgresNIOTests/New/Messages/BindTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/BindTests.swift @@ -5,15 +5,15 @@ import NIOCore class BindTests: XCTestCase { func testEncodeBind() { - let encoder = PSQLFrontendMessageEncoder() var bindings = PostgresBindings() bindings.append("Hello", context: .default) bindings.append("World", context: .default) - var byteBuffer = ByteBuffer() - let bind = PostgresFrontendMessage.Bind(portalName: "", preparedStatementName: "", bind: bindings) - let message = PostgresFrontendMessage.bind(bind) - encoder.encode(data: message, out: &byteBuffer) - + + var encoder = PostgresFrontendMessageEncoder(buffer: .init()) + + encoder.bind(portalName: "", preparedStatementName: "", bind: bindings) + var byteBuffer = encoder.flushBuffer() + XCTAssertEqual(byteBuffer.readableBytes, 37) XCTAssertEqual(PostgresFrontendMessage.ID.bind.rawValue, byteBuffer.readInteger(as: UInt8.self)) XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), 36) diff --git a/Tests/PostgresNIOTests/New/Messages/CancelTests.swift b/Tests/PostgresNIOTests/New/Messages/CancelTests.swift index c42f1999..5548aae3 100644 --- a/Tests/PostgresNIOTests/New/Messages/CancelTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/CancelTests.swift @@ -5,18 +5,17 @@ import NIOCore class CancelTests: XCTestCase { func testEncodeCancel() { - let encoder = PSQLFrontendMessageEncoder() - var byteBuffer = ByteBuffer() - let cancel = PostgresFrontendMessage.Cancel(processID: 1234, secretKey: 4567) - let message = PostgresFrontendMessage.cancel(cancel) - encoder.encode(data: message, out: &byteBuffer) + let processID: Int32 = 1234 + let secretKey: Int32 = 4567 + var encoder = PostgresFrontendMessageEncoder(buffer: .init()) + encoder.cancel(processID: processID, secretKey: secretKey) + var byteBuffer = encoder.flushBuffer() XCTAssertEqual(byteBuffer.readableBytes, 16) XCTAssertEqual(16, byteBuffer.readInteger(as: Int32.self)) // payload length XCTAssertEqual(80877102, byteBuffer.readInteger(as: Int32.self)) // cancel request code - XCTAssertEqual(cancel.processID, byteBuffer.readInteger(as: Int32.self)) - XCTAssertEqual(cancel.secretKey, byteBuffer.readInteger(as: Int32.self)) + XCTAssertEqual(processID, byteBuffer.readInteger(as: Int32.self)) + XCTAssertEqual(secretKey, byteBuffer.readInteger(as: Int32.self)) XCTAssertEqual(byteBuffer.readableBytes, 0) } - } diff --git a/Tests/PostgresNIOTests/New/Messages/CloseTests.swift b/Tests/PostgresNIOTests/New/Messages/CloseTests.swift index f6a0237b..a8e1cfeb 100644 --- a/Tests/PostgresNIOTests/New/Messages/CloseTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/CloseTests.swift @@ -3,13 +3,11 @@ import NIOCore @testable import PostgresNIO class CloseTests: XCTestCase { - func testEncodeClosePortal() { - let encoder = PSQLFrontendMessageEncoder() - var byteBuffer = ByteBuffer() - let message = PostgresFrontendMessage.close(.portal("Hello")) - encoder.encode(data: message, out: &byteBuffer) - + var encoder = PostgresFrontendMessageEncoder(buffer: .init()) + encoder.closePortal("Hello") + var byteBuffer = encoder.flushBuffer() + XCTAssertEqual(byteBuffer.readableBytes, 12) XCTAssertEqual(PostgresFrontendMessage.ID.close.rawValue, byteBuffer.readInteger(as: UInt8.self)) XCTAssertEqual(11, byteBuffer.readInteger(as: Int32.self)) @@ -19,11 +17,10 @@ class CloseTests: XCTestCase { } func testEncodeCloseUnnamedStatement() { - let encoder = PSQLFrontendMessageEncoder() - var byteBuffer = ByteBuffer() - let message = PostgresFrontendMessage.close(.preparedStatement("")) - encoder.encode(data: message, out: &byteBuffer) - + var encoder = PostgresFrontendMessageEncoder(buffer: .init()) + encoder.closePreparedStatement("") + var byteBuffer = encoder.flushBuffer() + XCTAssertEqual(byteBuffer.readableBytes, 7) XCTAssertEqual(PostgresFrontendMessage.ID.close.rawValue, byteBuffer.readInteger(as: UInt8.self)) XCTAssertEqual(6, byteBuffer.readInteger(as: Int32.self)) @@ -31,5 +28,4 @@ class CloseTests: XCTestCase { XCTAssertEqual("", byteBuffer.readNullTerminatedString()) XCTAssertEqual(byteBuffer.readableBytes, 0) } - } diff --git a/Tests/PostgresNIOTests/New/Messages/DescribeTests.swift b/Tests/PostgresNIOTests/New/Messages/DescribeTests.swift index df26f3d7..cb3c745b 100644 --- a/Tests/PostgresNIOTests/New/Messages/DescribeTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/DescribeTests.swift @@ -5,11 +5,10 @@ import NIOCore class DescribeTests: XCTestCase { func testEncodeDescribePortal() { - let encoder = PSQLFrontendMessageEncoder() - var byteBuffer = ByteBuffer() - let message = PostgresFrontendMessage.describe(.portal("Hello")) - encoder.encode(data: message, out: &byteBuffer) - + var encoder = PostgresFrontendMessageEncoder(buffer: .init()) + encoder.describePortal("Hello") + var byteBuffer = encoder.flushBuffer() + XCTAssertEqual(byteBuffer.readableBytes, 12) XCTAssertEqual(PostgresFrontendMessage.ID.describe.rawValue, byteBuffer.readInteger(as: UInt8.self)) XCTAssertEqual(11, byteBuffer.readInteger(as: Int32.self)) @@ -19,11 +18,10 @@ class DescribeTests: XCTestCase { } func testEncodeDescribeUnnamedStatement() { - let encoder = PSQLFrontendMessageEncoder() - var byteBuffer = ByteBuffer() - let message = PostgresFrontendMessage.describe(.preparedStatement("")) - encoder.encode(data: message, out: &byteBuffer) - + var encoder = PostgresFrontendMessageEncoder(buffer: .init()) + encoder.describePreparedStatement("") + var byteBuffer = encoder.flushBuffer() + XCTAssertEqual(byteBuffer.readableBytes, 7) XCTAssertEqual(PostgresFrontendMessage.ID.describe.rawValue, byteBuffer.readInteger(as: UInt8.self)) XCTAssertEqual(6, byteBuffer.readInteger(as: Int32.self)) diff --git a/Tests/PostgresNIOTests/New/Messages/ExecuteTests.swift b/Tests/PostgresNIOTests/New/Messages/ExecuteTests.swift index dc5e2767..834ad0dd 100644 --- a/Tests/PostgresNIOTests/New/Messages/ExecuteTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/ExecuteTests.swift @@ -5,11 +5,10 @@ import NIOCore class ExecuteTests: XCTestCase { func testEncodeExecute() { - let encoder = PSQLFrontendMessageEncoder() - var byteBuffer = ByteBuffer() - let message = PostgresFrontendMessage.execute(.init(portalName: "", maxNumberOfRows: 0)) - encoder.encode(data: message, out: &byteBuffer) - + var encoder = PostgresFrontendMessageEncoder(buffer: .init()) + encoder.execute(portalName: "", maxNumberOfRows: 0) + var byteBuffer = encoder.flushBuffer() + XCTAssertEqual(byteBuffer.readableBytes, 10) // 1 (id) + 4 (length) + 1 (empty null terminated string) + 4 (count) XCTAssertEqual(PostgresFrontendMessage.ID.execute.rawValue, byteBuffer.readInteger(as: UInt8.self)) XCTAssertEqual(9, byteBuffer.readInteger(as: Int32.self)) // length diff --git a/Tests/PostgresNIOTests/New/Messages/ParseTests.swift b/Tests/PostgresNIOTests/New/Messages/ParseTests.swift index 723ad1e6..9f81e4e4 100644 --- a/Tests/PostgresNIOTests/New/Messages/ParseTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/ParseTests.swift @@ -3,18 +3,19 @@ import NIOCore @testable import PostgresNIO class ParseTests: XCTestCase { - func testEncode() { - let encoder = PSQLFrontendMessageEncoder() - var byteBuffer = ByteBuffer() - let parse = PostgresFrontendMessage.Parse( - preparedStatementName: "test", - query: "SELECT version()", - parameters: [.bool, .int8, .bytea, .varchar, .text, .uuid, .json, .jsonbArray]) - let message = PostgresFrontendMessage.parse(parse) - encoder.encode(data: message, out: &byteBuffer) + let preparedStatementName = "test" + let query = "SELECT version()" + let parameters: [PostgresDataType] = [.bool, .int8, .bytea, .varchar, .text, .uuid, .json, .jsonbArray] + var encoder = PostgresFrontendMessageEncoder(buffer: .init()) + encoder.parse( + preparedStatementName: preparedStatementName, + query: query, + parameters: parameters + ) + var byteBuffer = encoder.flushBuffer() - let length: Int = 1 + 4 + (parse.preparedStatementName.count + 1) + (parse.query.count + 1) + 2 + parse.parameters.count * 4 + let length: Int = 1 + 4 + (preparedStatementName.count + 1) + (query.count + 1) + 2 + parameters.count * 4 // 1 id // + 4 length @@ -24,17 +25,11 @@ class ParseTests: XCTestCase { XCTAssertEqual(byteBuffer.readableBytes, length) XCTAssertEqual(byteBuffer.readInteger(as: UInt8.self), PostgresFrontendMessage.ID.parse.rawValue) XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), Int32(length - 1)) - XCTAssertEqual(byteBuffer.readNullTerminatedString(), parse.preparedStatementName) - XCTAssertEqual(byteBuffer.readNullTerminatedString(), parse.query) - XCTAssertEqual(byteBuffer.readInteger(as: UInt16.self), UInt16(parse.parameters.count)) - XCTAssertEqual(byteBuffer.readInteger(as: UInt32.self), PostgresDataType.bool.rawValue) - XCTAssertEqual(byteBuffer.readInteger(as: UInt32.self), PostgresDataType.int8.rawValue) - XCTAssertEqual(byteBuffer.readInteger(as: UInt32.self), PostgresDataType.bytea.rawValue) - XCTAssertEqual(byteBuffer.readInteger(as: UInt32.self), PostgresDataType.varchar.rawValue) - XCTAssertEqual(byteBuffer.readInteger(as: UInt32.self), PostgresDataType.text.rawValue) - XCTAssertEqual(byteBuffer.readInteger(as: UInt32.self), PostgresDataType.uuid.rawValue) - XCTAssertEqual(byteBuffer.readInteger(as: UInt32.self), PostgresDataType.json.rawValue) - XCTAssertEqual(byteBuffer.readInteger(as: UInt32.self), PostgresDataType.jsonbArray.rawValue) + XCTAssertEqual(byteBuffer.readNullTerminatedString(), preparedStatementName) + XCTAssertEqual(byteBuffer.readNullTerminatedString(), query) + XCTAssertEqual(byteBuffer.readInteger(as: UInt16.self), UInt16(parameters.count)) + for dataType in parameters { + XCTAssertEqual(byteBuffer.readInteger(as: UInt32.self), dataType.rawValue) + } } - } diff --git a/Tests/PostgresNIOTests/New/Messages/PasswordTests.swift b/Tests/PostgresNIOTests/New/Messages/PasswordTests.swift index 7572d382..4a4833d2 100644 --- a/Tests/PostgresNIOTests/New/Messages/PasswordTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/PasswordTests.swift @@ -5,11 +5,11 @@ import NIOCore class PasswordTests: XCTestCase { func testEncodePassword() { - let encoder = PSQLFrontendMessageEncoder() - var byteBuffer = ByteBuffer() + var encoder = PostgresFrontendMessageEncoder(buffer: .init()) // md522d085ed8dc3377968dc1c1a40519a2a = "abc123" with salt 1, 2, 3, 4 - let message = PostgresFrontendMessage.password(.init(value: "md522d085ed8dc3377968dc1c1a40519a2a")) - encoder.encode(data: message, out: &byteBuffer) + let password = "md522d085ed8dc3377968dc1c1a40519a2a" + encoder.password(password.utf8) + var byteBuffer = encoder.flushBuffer() let expectedLength = 41 // 1 (id) + 4 (length) + 35 (string) + 1 (null termination) diff --git a/Tests/PostgresNIOTests/New/Messages/SASLInitialResponseTests.swift b/Tests/PostgresNIOTests/New/Messages/SASLInitialResponseTests.swift index 08b3097d..90aa6b34 100644 --- a/Tests/PostgresNIOTests/New/Messages/SASLInitialResponseTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/SASLInitialResponseTests.swift @@ -4,15 +4,14 @@ import NIOCore class SASLInitialResponseTests: XCTestCase { - func testEncodeWithData() { - let encoder = PSQLFrontendMessageEncoder() - var byteBuffer = ByteBuffer() - let sasl = PostgresFrontendMessage.SASLInitialResponse( - saslMechanism: "hello", initialData: [0, 1, 2, 3, 4, 5, 6, 7]) - let message = PostgresFrontendMessage.saslInitialResponse(sasl) - encoder.encode(data: message, out: &byteBuffer) + func testEncode() { + var encoder = PostgresFrontendMessageEncoder(buffer: .init()) + let saslMechanism = "hello" + let initialData: [UInt8] = [0, 1, 2, 3, 4, 5, 6, 7] + encoder.saslInitialResponse(mechanism: saslMechanism, bytes: initialData) + var byteBuffer = encoder.flushBuffer() - let length: Int = 1 + 4 + (sasl.saslMechanism.count + 1) + 4 + sasl.initialData.count + let length: Int = 1 + 4 + (saslMechanism.count + 1) + 4 + initialData.count // 1 id // + 4 length @@ -23,21 +22,20 @@ class SASLInitialResponseTests: XCTestCase { XCTAssertEqual(byteBuffer.readableBytes, length) XCTAssertEqual(byteBuffer.readInteger(as: UInt8.self), PostgresFrontendMessage.ID.saslInitialResponse.rawValue) XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), Int32(length - 1)) - XCTAssertEqual(byteBuffer.readNullTerminatedString(), sasl.saslMechanism) - XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), Int32(sasl.initialData.count)) - XCTAssertEqual(byteBuffer.readBytes(length: sasl.initialData.count), sasl.initialData) + XCTAssertEqual(byteBuffer.readNullTerminatedString(), saslMechanism) + XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), Int32(initialData.count)) + XCTAssertEqual(byteBuffer.readBytes(length: initialData.count), initialData) XCTAssertEqual(byteBuffer.readableBytes, 0) } func testEncodeWithoutData() { - let encoder = PSQLFrontendMessageEncoder() - var byteBuffer = ByteBuffer() - let sasl = PostgresFrontendMessage.SASLInitialResponse( - saslMechanism: "hello", initialData: []) - let message = PostgresFrontendMessage.saslInitialResponse(sasl) - encoder.encode(data: message, out: &byteBuffer) + var encoder = PostgresFrontendMessageEncoder(buffer: .init()) + let saslMechanism = "hello" + let initialData: [UInt8] = [] + encoder.saslInitialResponse(mechanism: saslMechanism, bytes: initialData) + var byteBuffer = encoder.flushBuffer() - let length: Int = 1 + 4 + (sasl.saslMechanism.count + 1) + 4 + sasl.initialData.count + let length: Int = 1 + 4 + (saslMechanism.count + 1) + 4 + initialData.count // 1 id // + 4 length @@ -48,8 +46,9 @@ class SASLInitialResponseTests: XCTestCase { XCTAssertEqual(byteBuffer.readableBytes, length) XCTAssertEqual(byteBuffer.readInteger(as: UInt8.self), PostgresFrontendMessage.ID.saslInitialResponse.rawValue) XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), Int32(length - 1)) - XCTAssertEqual(byteBuffer.readNullTerminatedString(), sasl.saslMechanism) + XCTAssertEqual(byteBuffer.readNullTerminatedString(), saslMechanism) XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), Int32(-1)) + XCTAssertEqual(byteBuffer.readBytes(length: initialData.count), initialData) XCTAssertEqual(byteBuffer.readableBytes, 0) } } diff --git a/Tests/PostgresNIOTests/New/Messages/SASLResponseTests.swift b/Tests/PostgresNIOTests/New/Messages/SASLResponseTests.swift index e148420f..cdb0f10b 100644 --- a/Tests/PostgresNIOTests/New/Messages/SASLResponseTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/SASLResponseTests.swift @@ -5,28 +5,26 @@ import NIOCore class SASLResponseTests: XCTestCase { func testEncodeWithData() { - let encoder = PSQLFrontendMessageEncoder() - var byteBuffer = ByteBuffer() - let sasl = PostgresFrontendMessage.SASLResponse(data: [0, 1, 2, 3, 4, 5, 6, 7]) - let message = PostgresFrontendMessage.saslResponse(sasl) - encoder.encode(data: message, out: &byteBuffer) - - let length: Int = 1 + 4 + (sasl.data.count) + var encoder = PostgresFrontendMessageEncoder(buffer: .init()) + let data: [UInt8] = [0, 1, 2, 3, 4, 5, 6, 7] + encoder.saslResponse(data) + var byteBuffer = encoder.flushBuffer() + + let length: Int = 1 + 4 + (data.count) XCTAssertEqual(byteBuffer.readableBytes, length) XCTAssertEqual(byteBuffer.readInteger(as: UInt8.self), PostgresFrontendMessage.ID.saslResponse.rawValue) XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), Int32(length - 1)) - XCTAssertEqual(byteBuffer.readBytes(length: sasl.data.count), sasl.data) + XCTAssertEqual(byteBuffer.readBytes(length: data.count), data) XCTAssertEqual(byteBuffer.readableBytes, 0) } func testEncodeWithoutData() { - let encoder = PSQLFrontendMessageEncoder() - var byteBuffer = ByteBuffer() - let sasl = PostgresFrontendMessage.SASLResponse(data: []) - let message = PostgresFrontendMessage.saslResponse(sasl) - encoder.encode(data: message, out: &byteBuffer) - + var encoder = PostgresFrontendMessageEncoder(buffer: .init()) + let data: [UInt8] = [] + encoder.saslResponse(data) + var byteBuffer = encoder.flushBuffer() + let length: Int = 1 + 4 XCTAssertEqual(byteBuffer.readableBytes, length) diff --git a/Tests/PostgresNIOTests/New/Messages/SSLRequestTests.swift b/Tests/PostgresNIOTests/New/Messages/SSLRequestTests.swift index 9a973f2b..e9e6af81 100644 --- a/Tests/PostgresNIOTests/New/Messages/SSLRequestTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/SSLRequestTests.swift @@ -5,16 +5,14 @@ import NIOCore class SSLRequestTests: XCTestCase { func testSSLRequest() { - let encoder = PSQLFrontendMessageEncoder() - var byteBuffer = ByteBuffer() - let request = PostgresFrontendMessage.SSLRequest() - let message = PostgresFrontendMessage.sslRequest(request) - encoder.encode(data: message, out: &byteBuffer) + var encoder = PostgresFrontendMessageEncoder(buffer: .init()) + encoder.ssl() + var byteBuffer = encoder.flushBuffer() let byteBufferLength = Int32(byteBuffer.readableBytes) XCTAssertEqual(byteBufferLength, byteBuffer.readInteger()) - XCTAssertEqual(request.code, byteBuffer.readInteger()) - + XCTAssertEqual(PostgresFrontendMessage.SSLRequest.requestCode, byteBuffer.readInteger()) + XCTAssertEqual(byteBuffer.readableBytes, 0) } diff --git a/Tests/PostgresNIOTests/New/Messages/StartupTests.swift b/Tests/PostgresNIOTests/New/Messages/StartupTests.swift index 08a9ee21..e72f0f34 100644 --- a/Tests/PostgresNIOTests/New/Messages/StartupTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/StartupTests.swift @@ -5,7 +5,7 @@ import NIOCore class StartupTests: XCTestCase { func testStartupMessage() { - let encoder = PSQLFrontendMessageEncoder() + var encoder = PostgresFrontendMessageEncoder(buffer: .init()) var byteBuffer = ByteBuffer() let replicationValues: [PostgresFrontendMessage.Startup.Parameters.Replication] = [ @@ -22,13 +22,12 @@ class StartupTests: XCTestCase { replication: replication ) - let startup = PostgresFrontendMessage.Startup.versionThree(parameters: parameters) - let message = PostgresFrontendMessage.startup(startup) - encoder.encode(data: message, out: &byteBuffer) - + encoder.startup(parameters) + byteBuffer = encoder.flushBuffer() + let byteBufferLength = Int32(byteBuffer.readableBytes) XCTAssertEqual(byteBufferLength, byteBuffer.readInteger()) - XCTAssertEqual(startup.protocolVersion, byteBuffer.readInteger()) + XCTAssertEqual(PostgresFrontendMessage.Startup.versionThree, byteBuffer.readInteger()) XCTAssertEqual(byteBuffer.readNullTerminatedString(), "user") XCTAssertEqual(byteBuffer.readNullTerminatedString(), "test") XCTAssertEqual(byteBuffer.readNullTerminatedString(), "database") diff --git a/Tests/PostgresNIOTests/New/PSQLFrontendMessageTests.swift b/Tests/PostgresNIOTests/New/PSQLFrontendMessageTests.swift index 59b69bae..33afbe0d 100644 --- a/Tests/PostgresNIOTests/New/PSQLFrontendMessageTests.swift +++ b/Tests/PostgresNIOTests/New/PSQLFrontendMessageTests.swift @@ -23,30 +23,30 @@ class PSQLFrontendMessageTests: XCTestCase { // MARK: Encoder func testEncodeFlush() { - let encoder = PSQLFrontendMessageEncoder() - var byteBuffer = ByteBuffer() - encoder.encode(data: .flush, out: &byteBuffer) - + var encoder = PostgresFrontendMessageEncoder(buffer: .init()) + encoder.flush() + var byteBuffer = encoder.flushBuffer() + XCTAssertEqual(byteBuffer.readableBytes, 5) XCTAssertEqual(PostgresFrontendMessage.ID.flush.rawValue, byteBuffer.readInteger(as: UInt8.self)) XCTAssertEqual(4, byteBuffer.readInteger(as: Int32.self)) // payload length } func testEncodeSync() { - let encoder = PSQLFrontendMessageEncoder() - var byteBuffer = ByteBuffer() - encoder.encode(data: .sync, out: &byteBuffer) - + var encoder = PostgresFrontendMessageEncoder(buffer: .init()) + encoder.sync() + var byteBuffer = encoder.flushBuffer() + XCTAssertEqual(byteBuffer.readableBytes, 5) XCTAssertEqual(PostgresFrontendMessage.ID.sync.rawValue, byteBuffer.readInteger(as: UInt8.self)) XCTAssertEqual(4, byteBuffer.readInteger(as: Int32.self)) // payload length } func testEncodeTerminate() { - let encoder = PSQLFrontendMessageEncoder() - var byteBuffer = ByteBuffer() - encoder.encode(data: .terminate, out: &byteBuffer) - + var encoder = PostgresFrontendMessageEncoder(buffer: .init()) + encoder.terminate() + var byteBuffer = encoder.flushBuffer() + XCTAssertEqual(byteBuffer.readableBytes, 5) XCTAssertEqual(PostgresFrontendMessage.ID.terminate.rawValue, byteBuffer.readInteger(as: UInt8.self)) XCTAssertEqual(4, byteBuffer.readInteger(as: Int32.self)) // payload length diff --git a/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift b/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift index d76b8223..97ad892f 100644 --- a/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift @@ -49,15 +49,9 @@ class PostgresChannelHandlerTests: XCTestCase { handler ]) - var maybeMessage: PostgresFrontendMessage? XCTAssertNoThrow(embedded.connect(to: try .init(ipAddress: "0.0.0.0", port: 5432), promise: nil)) - XCTAssertNoThrow(maybeMessage = try embedded.readOutbound(as: PostgresFrontendMessage.self)) - guard case .sslRequest(let request) = maybeMessage else { - return XCTFail("Unexpected message") - } - - XCTAssertEqual(request.code, 80877103) - + XCTAssertEqual(.sslRequest, try embedded.readOutbound(as: PostgresFrontendMessage.self)) + XCTAssertNoThrow(try embedded.writeInbound(PostgresBackendMessage.sslSupported)) // a NIOSSLHandler has been added, after it SSL had been negotiated @@ -92,14 +86,8 @@ class PostgresChannelHandlerTests: XCTestCase { eventHandler ]) - var maybeMessage: PostgresFrontendMessage? XCTAssertNoThrow(embedded.connect(to: try .init(ipAddress: "0.0.0.0", port: 5432), promise: nil)) - XCTAssertNoThrow(maybeMessage = try embedded.readOutbound(as: PostgresFrontendMessage.self)) - guard case .sslRequest(let request) = maybeMessage else { - return XCTFail("Unexpected message") - } - - XCTAssertEqual(request.code, 80877103) + XCTAssertEqual(.sslRequest, try embedded.readOutbound(as: PostgresFrontendMessage.self)) var responseBuffer = ByteBuffer() responseBuffer.writeInteger(UInt8(ascii: "S")) @@ -134,7 +122,7 @@ class PostgresChannelHandlerTests: XCTestCase { XCTAssertTrue(embedded.isActive) // read the ssl request message - XCTAssertEqual(try embedded.readOutbound(as: PostgresFrontendMessage.self), .sslRequest(.init())) + XCTAssertEqual(try embedded.readOutbound(as: PostgresFrontendMessage.self), .sslRequest) try embedded.writeInbound(PostgresBackendMessage.sslUnsupported) // the event handler should have seen an error From 0c9391c68a38be8d9990688717fe26eaad41e395 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Sat, 5 Aug 2023 14:49:47 +0200 Subject: [PATCH 005/106] Add async listen; Refactor all listen code (#264) --- .../Connection/PostgresConnection.swift | 139 +++++----- .../ConnectionStateMachine.swift | 5 +- .../ListenStateMachine.swift | 247 ++++++++++++++++++ .../New/NotificationListener.swift | 157 +++++++++++ Sources/PostgresNIO/New/PSQLError.swift | 15 ++ Sources/PostgresNIO/New/PSQLTask.swift | 12 +- .../New/PostgresChannelHandler.swift | 241 ++++++++++++++--- .../New/PostgresFrontendMessage.swift | 7 +- .../New/PostgresNotificationSequence.swift | 22 ++ Sources/PostgresNIO/Postgres+PSQLCompat.swift | 4 +- Tests/IntegrationTests/AsyncTests.swift | 23 ++ .../PSQLFrontendMessageDecoder.swift | 81 +++++- .../New/PSQLConnectionTests.swift | 37 --- .../New/PostgresChannelHandlerTests.swift | 43 +-- .../New/PostgresConnectionTests.swift | 245 +++++++++++++++++ 15 files changed, 1104 insertions(+), 174 deletions(-) create mode 100644 Sources/PostgresNIO/New/Connection State Machine/ListenStateMachine.swift create mode 100644 Sources/PostgresNIO/New/NotificationListener.swift create mode 100644 Sources/PostgresNIO/New/PostgresNotificationSequence.swift delete mode 100644 Tests/PostgresNIOTests/New/PSQLConnectionTests.swift create mode 100644 Tests/PostgresNIOTests/New/PostgresConnectionTests.swift diff --git a/Sources/PostgresNIO/Connection/PostgresConnection.swift b/Sources/PostgresNIO/Connection/PostgresConnection.swift index c24041c9..d6420a6e 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection.swift @@ -38,15 +38,7 @@ public final class PostgresConnection: @unchecked Sendable { } } - /// A dictionary to store notification callbacks in - /// - /// Those are used when `PostgresConnection.addListener` is invoked. This only lives here since properties - /// can not be added in extensions. All relevant code lives in `PostgresConnection+Notifications` - var notificationListeners: [String: [(PostgresListenContext, (PostgresListenContext, PostgresMessage.NotificationResponse) -> Void)]] = [:] { - willSet { - self.channel.eventLoop.preconditionInEventLoop() - } - } + private let internalListenID = ManagedAtomic(0) public var isClosed: Bool { return !self.channel.isActive @@ -87,10 +79,10 @@ public final class PostgresConnection: @unchecked Sendable { let channelHandler = PostgresChannelHandler( configuration: configuration, + eventLoop: channel.eventLoop, logger: logger, configureSSLCallback: configureSSLCallback ) - channelHandler.notificationDelegate = self let eventHandler = PSQLEventsHandler(logger: logger) @@ -164,14 +156,16 @@ public final class PostgresConnection: @unchecked Sendable { // thread and the EventLoop. return eventLoop.flatSubmit { () -> EventLoopFuture in let connectFuture: EventLoopFuture - let bootstrap = self.makeBootstrap(on: eventLoop, configuration: configuration) switch configuration.connection { case .resolved(let address): + let bootstrap = self.makeBootstrap(on: eventLoop, configuration: configuration) connectFuture = bootstrap.connect(to: address) case .unresolvedTCP(let host, let port): + let bootstrap = self.makeBootstrap(on: eventLoop, configuration: configuration) connectFuture = bootstrap.connect(host: host, port: port) case .unresolvedUDS(let path): + let bootstrap = self.makeBootstrap(on: eventLoop, configuration: configuration) connectFuture = bootstrap.connect(unixDomainSocketPath: path) case .bootstrapped(let channel): guard channel.isActive else { @@ -224,9 +218,10 @@ public final class PostgresConnection: @unchecked Sendable { let context = ExtendedQueryContext( query: query, logger: logger, - promise: promise) + promise: promise + ) - self.channel.write(PSQLTask.extendedQuery(context), promise: nil) + self.channel.write(HandlerTask.extendedQuery(context), promise: nil) return promise.futureResult } @@ -241,7 +236,7 @@ public final class PostgresConnection: @unchecked Sendable { logger: logger, promise: promise) - self.channel.write(PSQLTask.preparedStatement(context), promise: nil) + self.channel.write(HandlerTask.preparedStatement(context), promise: nil) return promise.futureResult.map { rowDescription in PSQLPreparedStatement(name: name, query: query, connection: self, rowDescription: rowDescription) } @@ -257,7 +252,7 @@ public final class PostgresConnection: @unchecked Sendable { logger: logger, promise: promise) - self.channel.write(PSQLTask.extendedQuery(context), promise: nil) + self.channel.write(HandlerTask.extendedQuery(context), promise: nil) return promise.futureResult } @@ -265,7 +260,7 @@ public final class PostgresConnection: @unchecked Sendable { let promise = self.channel.eventLoop.makePromise(of: Void.self) let context = CloseCommandContext(target: target, logger: logger, promise: promise) - self.channel.write(PSQLTask.closeCommand(context), promise: nil) + self.channel.write(HandlerTask.closeCommand(context), promise: nil) return promise.futureResult } @@ -417,7 +412,7 @@ extension PostgresConnection { promise: promise ) - self.channel.write(PSQLTask.extendedQuery(context), promise: nil) + self.channel.write(HandlerTask.extendedQuery(context), promise: nil) do { return try await promise.futureResult.map({ $0.asyncSequence() }).get() @@ -428,6 +423,31 @@ extension PostgresConnection { throw error // rethrow with more metadata } } + + /// Start listening for a channel + public func listen(_ channel: String) async throws -> PostgresNotificationSequence { + let id = self.internalListenID.loadThenWrappingIncrement(ordering: .relaxed) + + return try await withTaskCancellationHandler { + try Task.checkCancellation() + + return try await withCheckedThrowingContinuation { continuation in + let listener = NotificationListener( + channel: channel, + id: id, + eventLoop: self.eventLoop, + checkedContinuation: continuation + ) + + let task = HandlerTask.startListening(listener) + + self.channel.write(task, promise: nil) + } + } onCancel: { + let task = HandlerTask.cancelListening(channel, id) + self.channel.write(task, promise: nil) + } + } } // MARK: EventLoopFuture interface @@ -569,73 +589,58 @@ internal enum PostgresCommands: PostgresRequest { // MARK: Notifications /// Context for receiving NotificationResponse messages on a connection, used for PostgreSQL's `LISTEN`/`NOTIFY` support. -public final class PostgresListenContext { - var stopper: (() -> Void)? +public final class PostgresListenContext: Sendable { + private let promise: EventLoopPromise + + var future: EventLoopFuture { + self.promise.futureResult + } + + init(promise: EventLoopPromise) { + self.promise = promise + } + + func cancel() { + self.promise.succeed() + } /// Detach this listener so it no longer receives notifications. Other listeners, including those for the same channel, are unaffected. `UNLISTEN` is not sent; you are responsible for issuing an `UNLISTEN` query yourself if it is appropriate for your application. public func stop() { - stopper?() - stopper = nil + self.promise.succeed() } } extension PostgresConnection { /// Add a handler for NotificationResponse messages on a certain channel. This is used in conjunction with PostgreSQL's `LISTEN`/`NOTIFY` support: to listen on a channel, you add a listener using this method to handle the NotificationResponse messages, then issue a `LISTEN` query to instruct PostgreSQL to begin sending NotificationResponse messages. @discardableResult - public func addListener(channel: String, handler notificationHandler: @escaping (PostgresListenContext, PostgresMessage.NotificationResponse) -> Void) -> PostgresListenContext { + @preconcurrency + public func addListener( + channel: String, + handler notificationHandler: @Sendable @escaping (PostgresListenContext, PostgresMessage.NotificationResponse) -> Void + ) -> PostgresListenContext { + let listenContext = PostgresListenContext(promise: self.eventLoop.makePromise(of: Void.self)) + let id = self.internalListenID.loadThenWrappingIncrement(ordering: .relaxed) + + let listener = NotificationListener( + channel: channel, + id: id, + eventLoop: self.eventLoop, + context: listenContext, + closure: notificationHandler + ) - let listenContext = PostgresListenContext() + let task = HandlerTask.startListening(listener) + self.channel.write(task, promise: nil) - self.channel.pipeline.handler(type: PostgresChannelHandler.self).whenSuccess { handler in - if self.notificationListeners[channel] != nil { - self.notificationListeners[channel]!.append((listenContext, notificationHandler)) - } - else { - self.notificationListeners[channel] = [(listenContext, notificationHandler)] - } - } - - listenContext.stopper = { [weak self, weak listenContext] in - // self is weak, since the connection can long be gone, when the listeners stop is - // triggered. listenContext must be weak to prevent a retain cycle - - self?.channel.eventLoop.execute { - guard - let self = self, // the connection is already gone - var listeners = self.notificationListeners[channel] // we don't have the listeners for this topic ¯\_(ツ)_/¯ - else { - return - } - - assert(listeners.filter { $0.0 === listenContext }.count <= 1, "Listeners can not appear twice in a channel!") - listeners.removeAll(where: { $0.0 === listenContext }) // just in case a listener shows up more than once in a release build, remove all, not just first - self.notificationListeners[channel] = listeners.isEmpty ? nil : listeners - } + listenContext.future.whenComplete { _ in + let task = HandlerTask.cancelListening(channel, id) + self.channel.write(task, promise: nil) } return listenContext } } -extension PostgresConnection: PSQLChannelHandlerNotificationDelegate { - func notificationReceived(_ notification: PostgresBackendMessage.NotificationResponse) { - self.eventLoop.assertInEventLoop() - - guard let listeners = self.notificationListeners[notification.channel] else { - return - } - - let postgresNotification = PostgresMessage.NotificationResponse( - backendPID: notification.backendPID, - channel: notification.channel, - payload: notification.payload) - - listeners.forEach { (listenContext, handler) in - handler(listenContext, postgresNotification) - } - } -} - enum CloseTarget { case preparedStatement(String) case portal(String) diff --git a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift index ba1e3c1f..761ba5f2 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift @@ -1091,11 +1091,12 @@ extension ConnectionStateMachine { .tooManyParameters, .invalidCommandTag, .connectionError, - .uncleanShutdown: + .uncleanShutdown, + .unlistenFailed: return true case .queryCancelled: return false - case .server: + case .server, .listenFailed: guard let sqlState = error.serverInfo?[.sqlState] else { // any error message that doesn't have a sql state field, is unexpected by default. return true diff --git a/Sources/PostgresNIO/New/Connection State Machine/ListenStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ListenStateMachine.swift new file mode 100644 index 00000000..c7f92428 --- /dev/null +++ b/Sources/PostgresNIO/New/Connection State Machine/ListenStateMachine.swift @@ -0,0 +1,247 @@ +import NIOCore + +struct ListenStateMachine { + var channels: [String: ChannelState] + + init() { + self.channels = [:] + } + + enum StartListeningAction { + case none + case startListening(String) + case succeedListenStart(NotificationListener) + } + + mutating func startListening(_ new: NotificationListener) -> StartListeningAction { + return self.channels[new.channel, default: .init()].start(new) + } + + enum StartListeningSuccessAction { + case stopListening + case activateListeners(Dictionary.Values) + } + + mutating func startListeningSucceeded(channel: String) -> StartListeningSuccessAction { + return self.channels[channel]!.startListeningSucceeded() + } + + mutating func startListeningFailed(channel: String, error: Error) -> Dictionary.Values { + return self.channels[channel]!.startListeningFailed(error) + } + + enum StopListeningSuccessAction { + case startListening + case none + } + + mutating func stopListeningSucceeded(channel: String) -> StopListeningSuccessAction { + return self.channels[channel, default: .init()].stopListeningSucceeded() + } + + enum CancelAction { + case stopListening(String, cancelListener: NotificationListener) + case cancelListener(NotificationListener) + case none + } + + mutating func cancelNotificationListener(channel: String, id: Int) -> CancelAction { + return self.channels[channel, default: .init()].cancelListening(id: id) + } + + mutating func fail(_ error: Error) -> [NotificationListener] { + var result = [NotificationListener]() + while var (_, channel) = self.channels.popFirst() { + switch channel.fail(error) { + case .none: + continue + + case .failListeners(let listeners): + result.append(contentsOf: listeners) + } + } + return result + } + + enum ReceivedAction { + case none + case notify(Dictionary.Values) + } + + func notificationReceived(channel: String) -> ReceivedAction { + // TODO: Do we want to close the connection, if we receive a notification on a channel that we don't listen to? + // We can only change this with the next major release, as it would break current functionality. + return self.channels[channel]?.notificationReceived() ?? .none + } +} + +extension ListenStateMachine { + struct ChannelState { + enum State { + case initialized + case starting([Int: NotificationListener]) + case listening([Int: NotificationListener]) + case stopping([Int: NotificationListener]) + case failed(Error) + } + + private var state: State + + init() { + self.state = .initialized + } + + mutating func start(_ new: NotificationListener) -> StartListeningAction { + switch self.state { + case .initialized: + self.state = .starting([new.id: new]) + return .startListening(new.channel) + + case .starting(var listeners): + listeners[new.id] = new + self.state = .starting(listeners) + return .none + + case .listening(var listeners): + listeners[new.id] = new + self.state = .listening(listeners) + return .succeedListenStart(new) + + case .stopping(var listeners): + listeners[new.id] = new + self.state = .stopping(listeners) + return .none + + case .failed: + fatalError("Invalid state: \(self.state)") + } + } + + mutating func startListeningSucceeded() -> StartListeningSuccessAction { + switch self.state { + case .initialized, .listening, .stopping: + fatalError("Invalid state: \(self.state)") + + case .starting(let listeners): + if listeners.isEmpty { + self.state = .stopping(listeners) + return .stopListening + } else { + self.state = .listening(listeners) + return .activateListeners(listeners.values) + } + + case .failed: + fatalError("Invalid state: \(self.state)") + } + } + + mutating func startListeningFailed(_ error: Error) -> Dictionary.Values { + switch self.state { + case .initialized, .listening, .stopping: + fatalError("Invalid state: \(self.state)") + + case .starting(let listeners): + self.state = .initialized + return listeners.values + + case .failed: + fatalError("Invalid state: \(self.state)") + } + } + + mutating func stopListeningSucceeded() -> StopListeningSuccessAction { + switch self.state { + case .initialized, .listening, .starting: + fatalError("Invalid state: \(self.state)") + + case .stopping(let listeners): + if listeners.isEmpty { + self.state = .initialized + return .none + } else { + self.state = .starting(listeners) + return .startListening + } + + case .failed: + return .none + } + } + + mutating func cancelListening(id: Int) -> CancelAction { + switch self.state { + case .initialized: + fatalError("Invalid state: \(self.state)") + + case .starting(var listeners): + let removed = listeners.removeValue(forKey: id) + self.state = .starting(listeners) + if let removed = removed { + return .cancelListener(removed) + } + return .none + + case .listening(var listeners): + precondition(!listeners.isEmpty) + let maybeLast = listeners.removeValue(forKey: id) + if let last = maybeLast, listeners.isEmpty { + self.state = .stopping(listeners) + return .stopListening(last.channel, cancelListener: last) + } else { + self.state = .listening(listeners) + if let notLast = maybeLast { + return .cancelListener(notLast) + } + return .none + } + + case .stopping(var listeners): + let removed = listeners.removeValue(forKey: id) + self.state = .stopping(listeners) + if let removed = removed { + return .cancelListener(removed) + } + return .none + + case .failed: + return .none + } + } + + enum FailAction { + case failListeners(Dictionary.Values) + case none + } + + mutating func fail(_ error: Error) -> FailAction { + switch self.state { + case .initialized: + fatalError("Invalid state: \(self.state)") + + case .starting(let listeners), .listening(let listeners), .stopping(let listeners): + self.state = .failed(error) + return .failListeners(listeners.values) + + case .failed: + return .none + } + } + + func notificationReceived() -> ReceivedAction { + switch self.state { + case .initialized, .starting: + fatalError("Invalid state: \(self.state)") + + case .listening(let listeners): + return .notify(listeners.values) + + case .stopping: + return .none + + default: + preconditionFailure("TODO: Implemented") + } + } + } +} diff --git a/Sources/PostgresNIO/New/NotificationListener.swift b/Sources/PostgresNIO/New/NotificationListener.swift new file mode 100644 index 00000000..5f4bc3de --- /dev/null +++ b/Sources/PostgresNIO/New/NotificationListener.swift @@ -0,0 +1,157 @@ +import NIOCore + +// This object is @unchecked Sendable, since we syncronize state on the EL +final class NotificationListener: @unchecked Sendable { + let eventLoop: EventLoop + + let channel: String + let id: Int + + private var state: State + + enum State { + case streamInitialized(CheckedContinuation) + case streamListening(AsyncThrowingStream.Continuation) + + case closure(PostgresListenContext, (PostgresListenContext, PostgresMessage.NotificationResponse) -> Void) + case done + } + + init( + channel: String, + id: Int, + eventLoop: EventLoop, + checkedContinuation: CheckedContinuation + ) { + self.channel = channel + self.id = id + self.eventLoop = eventLoop + self.state = .streamInitialized(checkedContinuation) + } + + init( + channel: String, + id: Int, + eventLoop: EventLoop, + context: PostgresListenContext, + closure: @Sendable @escaping (PostgresListenContext, PostgresMessage.NotificationResponse) -> Void + ) { + self.channel = channel + self.id = id + self.eventLoop = eventLoop + self.state = .closure(context, closure) + } + + func startListeningSucceeded(handler: PostgresChannelHandler) { + self.eventLoop.preconditionInEventLoop() + + switch self.state { + case .streamInitialized(let checkedContinuation): + let (stream, continuation) = AsyncThrowingStream.makeStream(of: PostgresNotification.self) + let eventLoop = self.eventLoop + let channel = self.channel + let listenerID = self.id + continuation.onTermination = { reason in + switch reason { + case .cancelled: + eventLoop.execute { + handler.cancelNotificationListener(channel: channel, id: listenerID) + } + + case .finished: + break + + @unknown default: + break + } + } + self.state = .streamListening(continuation) + + let notificationSequence = PostgresNotificationSequence(base: stream) + checkedContinuation.resume(returning: notificationSequence) + + case .streamListening, .done: + fatalError("Invalid state: \(self.state)") + + case .closure: + break // ignore + } + } + + func notificationReceived(_ backendMessage: PostgresBackendMessage.NotificationResponse) { + self.eventLoop.preconditionInEventLoop() + + switch self.state { + case .streamInitialized, .done: + fatalError("Invalid state: \(self.state)") + case .streamListening(let continuation): + continuation.yield(.init(payload: backendMessage.payload)) + + case .closure(let postgresListenContext, let closure): + let message = PostgresMessage.NotificationResponse( + backendPID: backendMessage.backendPID, + channel: backendMessage.channel, + payload: backendMessage.payload + ) + closure(postgresListenContext, message) + } + } + + func failed(_ error: Error) { + self.eventLoop.preconditionInEventLoop() + + switch self.state { + case .streamInitialized(let checkedContinuation): + self.state = .done + checkedContinuation.resume(throwing: error) + + case .streamListening(let continuation): + self.state = .done + continuation.finish(throwing: error) + + case .closure(let postgresListenContext, _): + self.state = .done + postgresListenContext.cancel() + + case .done: + break // ignore + } + } + + func cancelled() { + self.eventLoop.preconditionInEventLoop() + + switch self.state { + case .streamInitialized(let checkedContinuation): + self.state = .done + checkedContinuation.resume(throwing: PSQLError(code: .queryCancelled)) + + case .streamListening(let continuation): + self.state = .done + continuation.finish() + + case .closure(let postgresListenContext, _): + self.state = .done + postgresListenContext.cancel() + + case .done: + break // ignore + } + } +} + + +#if swift(<5.9) +// Async stream API backfill +extension AsyncThrowingStream { + static func makeStream( + of elementType: Element.Type = Element.self, + throwing failureType: Failure.Type = Failure.self, + bufferingPolicy limit: Continuation.BufferingPolicy = .unbounded + ) -> (stream: AsyncThrowingStream, continuation: AsyncThrowingStream.Continuation) where Failure == Error { + var continuation: AsyncThrowingStream.Continuation! + let stream = AsyncThrowingStream(bufferingPolicy: limit) { continuation = $0 } + return (stream: stream, continuation: continuation!) + } + } + #endif diff --git a/Sources/PostgresNIO/New/PSQLError.swift b/Sources/PostgresNIO/New/PSQLError.swift index df7dd7c1..a13d4209 100644 --- a/Sources/PostgresNIO/New/PSQLError.swift +++ b/Sources/PostgresNIO/New/PSQLError.swift @@ -22,6 +22,9 @@ public struct PSQLError: Error { case connectionClosed case connectionError case uncleanShutdown + + case listenFailed + case unlistenFailed } internal var base: Base @@ -46,6 +49,8 @@ public struct PSQLError: Error { public static let connectionClosed = Self(.connectionClosed) public static let connectionError = Self(.connectionError) public static let uncleanShutdown = Self.init(.uncleanShutdown) + public static let listenFailed = Self.init(.listenFailed) + public static let unlistenFailed = Self.init(.unlistenFailed) public var description: String { switch self.base { @@ -81,6 +86,10 @@ public struct PSQLError: Error { return "connectionError" case .uncleanShutdown: return "uncleanShutdown" + case .listenFailed: + return "listenFailed" + case .unlistenFailed: + return "unlistenFailed" } } } @@ -418,6 +427,12 @@ public struct PSQLError: Error { return error } + static func unlistenError(underlying: Error) -> PSQLError { + var error = PSQLError(code: .unlistenFailed) + error.underlying = underlying + return error + } + enum UnsupportedAuthScheme { case none case kerberosV5 diff --git a/Sources/PostgresNIO/New/PSQLTask.swift b/Sources/PostgresNIO/New/PSQLTask.swift index f9ca1232..26312c0c 100644 --- a/Sources/PostgresNIO/New/PSQLTask.swift +++ b/Sources/PostgresNIO/New/PSQLTask.swift @@ -1,17 +1,27 @@ import Logging import NIOCore +enum HandlerTask { + case extendedQuery(ExtendedQueryContext) + case preparedStatement(PrepareStatementContext) + case closeCommand(CloseCommandContext) + case startListening(NotificationListener) + case cancelListening(String, Int) +} + enum PSQLTask { case extendedQuery(ExtendedQueryContext) case preparedStatement(PrepareStatementContext) case closeCommand(CloseCommandContext) - + func failWithError(_ error: PSQLError) { switch self { case .extendedQuery(let extendedQueryContext): extendedQueryContext.promise.fail(error) + case .preparedStatement(let createPreparedStatementContext): createPreparedStatementContext.promise.fail(error) + case .closeCommand(let closeCommandContext): closeCommandContext.promise.fail(error) } diff --git a/Sources/PostgresNIO/New/PostgresChannelHandler.swift b/Sources/PostgresNIO/New/PostgresChannelHandler.swift index 09feb521..4470e802 100644 --- a/Sources/PostgresNIO/New/PostgresChannelHandler.swift +++ b/Sources/PostgresNIO/New/PostgresChannelHandler.swift @@ -3,16 +3,13 @@ import NIOTLS import Crypto import Logging -protocol PSQLChannelHandlerNotificationDelegate: AnyObject { - func notificationReceived(_: PostgresBackendMessage.NotificationResponse) -} - final class PostgresChannelHandler: ChannelDuplexHandler { - typealias OutboundIn = PSQLTask + typealias OutboundIn = HandlerTask typealias InboundIn = ByteBuffer typealias OutboundOut = ByteBuffer private let logger: Logger + private let eventLoop: EventLoop private var state: ConnectionStateMachine /// A `ChannelHandlerContext` to be used for non channel related events. (for example: More rows needed). @@ -24,15 +21,18 @@ final class PostgresChannelHandler: ChannelDuplexHandler { private var encoder: PostgresFrontendMessageEncoder! private let configuration: PostgresConnection.InternalConfiguration private let configureSSLCallback: ((Channel) throws -> Void)? - - /// this delegate should only be accessed on the connections `EventLoop` - weak var notificationDelegate: PSQLChannelHandlerNotificationDelegate? - - init(configuration: PostgresConnection.InternalConfiguration, - logger: Logger, - configureSSLCallback: ((Channel) throws -> Void)?) - { + + private var listenState: ListenStateMachine + + init( + configuration: PostgresConnection.InternalConfiguration, + eventLoop: EventLoop, + logger: Logger, + configureSSLCallback: ((Channel) throws -> Void)? + ) { self.state = ConnectionStateMachine(requireBackendKeyData: configuration.options.requireBackendKeyData) + self.eventLoop = eventLoop + self.listenState = ListenStateMachine() self.configuration = configuration self.configureSSLCallback = configureSSLCallback self.logger = logger @@ -41,12 +41,16 @@ final class PostgresChannelHandler: ChannelDuplexHandler { #if DEBUG /// for testing purposes only - init(configuration: PostgresConnection.InternalConfiguration, - state: ConnectionStateMachine = .init(.initialized), - logger: Logger = .psqlNoOpLogger, - configureSSLCallback: ((Channel) throws -> Void)?) - { + init( + configuration: PostgresConnection.InternalConfiguration, + eventLoop: EventLoop, + state: ConnectionStateMachine = .init(.initialized), + logger: Logger = .psqlNoOpLogger, + configureSSLCallback: ((Channel) throws -> Void)? + ) { self.state = state + self.eventLoop = eventLoop + self.listenState = ListenStateMachine() self.configuration = configuration self.configureSSLCallback = configureSSLCallback self.logger = logger @@ -194,8 +198,46 @@ final class PostgresChannelHandler: ChannelDuplexHandler { } func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise?) { - let task = self.unwrapOutboundIn(data) - let action = self.state.enqueue(task: task) + let handlerTask = self.unwrapOutboundIn(data) + let psqlTask: PSQLTask + + switch handlerTask { + case .closeCommand(let command): + psqlTask = .closeCommand(command) + case .extendedQuery(let query): + psqlTask = .extendedQuery(query) + case .preparedStatement(let statement): + psqlTask = .preparedStatement(statement) + + case .startListening(let listener): + switch self.listenState.startListening(listener) { + case .startListening(let channel): + psqlTask = self.makeStartListeningQuery(channel: channel, context: context) + + case .none: + return + + case .succeedListenStart(let listener): + listener.startListeningSucceeded(handler: self) + return + } + + case .cancelListening(let channel, let id): + switch self.listenState.cancelNotificationListener(channel: channel, id: id) { + case .none: + return + + case .stopListening(let channel, let listener): + psqlTask = self.makeUnlistenQuery(channel: channel, context: context) + listener.failed(CancellationError()) + + case .cancelListener(let listener): + listener.failed(CancellationError()) + return + } + } + + let action = self.state.enqueue(task: psqlTask) self.run(action, with: context) } @@ -223,9 +265,34 @@ final class PostgresChannelHandler: ChannelDuplexHandler { } } + // MARK: Listening + + func cancelNotificationListener(channel: String, id: Int) { + self.eventLoop.preconditionInEventLoop() + + switch self.listenState.cancelNotificationListener(channel: channel, id: id) { + case .cancelListener(let listener): + listener.cancelled() + + case .stopListening(let channel, cancelListener: let listener): + listener.cancelled() + + guard let context = self.handlerContext else { + return + } + + let query = self.makeUnlistenQuery(channel: channel, context: context) + let action = self.state.enqueue(task: query) + self.run(action, with: context) + + case .none: + break + } + } + // MARK: Channel handler actions - func run(_ action: ConnectionStateMachine.ConnectionAction, with context: ChannelHandlerContext) { + private func run(_ action: ConnectionStateMachine.ConnectionAction, with context: ChannelHandlerContext) { self.logger.trace("Run action", metadata: [.connectionAction: "\(action)"]) switch action { @@ -333,16 +400,14 @@ final class PostgresChannelHandler: ChannelDuplexHandler { self.closeConnectionAndCleanup(cleanupContext, context: context) } case .forwardNotificationToListeners(let notification): - self.notificationDelegate?.notificationReceived(notification) + self.forwardNotificationToListeners(notification, context: context) } } // MARK: - Private Methods - private func connected(context: ChannelHandlerContext) { - let action = self.state.connected(tls: .init(self.configuration.tls)) - self.run(action, with: context) } @@ -362,8 +427,8 @@ final class PostgresChannelHandler: ChannelDuplexHandler { private func sendPasswordMessage( mode: PasswordAuthencationMode, authContext: AuthContext, - context: ChannelHandlerContext) - { + context: ChannelHandlerContext + ) { switch mode { case .md5(let salt): let hash1 = (authContext.password ?? "") + authContext.username @@ -407,7 +472,6 @@ final class PostgresChannelHandler: ChannelDuplexHandler { context: ChannelHandlerContext ) { precondition(self.rowStream == nil, "Expected to not have an open stream at this point") - self.encoder.parse(preparedStatementName: statementName, query: query, parameters: []) self.encoder.describePreparedStatement(statementName) self.encoder.sync() @@ -485,11 +549,16 @@ final class PostgresChannelHandler: ChannelDuplexHandler { cleanup.tasks.forEach { task in task.failWithError(cleanup.error) } - - // 2. fire an error + + // 2. stop all listeners + for listener in self.listenState.fail(cleanup.error) { + listener.failed(cleanup.error) + } + + // 3. fire an error context.fireErrorCaught(cleanup.error) - // 3. close the connection or fire channel inactive + // 4. close the connection or fire channel inactive switch cleanup.action { case .close: context.close(mode: .all, promise: cleanup.closePromise) @@ -498,6 +567,105 @@ final class PostgresChannelHandler: ChannelDuplexHandler { context.fireChannelInactive() } } + + private func makeStartListeningQuery(channel: String, context: ChannelHandlerContext) -> PSQLTask { + let promise = context.eventLoop.makePromise(of: PSQLRowStream.self) + let query = ExtendedQueryContext( + query: PostgresQuery(unsafeSQL: "LISTEN \(channel);"), + logger: self.logger, + promise: promise + ) + promise.futureResult.whenComplete { result in + self.startListenCompleted(result, for: channel, context: context) + } + + return .extendedQuery(query) + } + + private func startListenCompleted(_ result: Result, for channel: String, context: ChannelHandlerContext) { + switch result { + case .success: + switch self.listenState.startListeningSucceeded(channel: channel) { + case .activateListeners(let listeners): + for list in listeners { + list.startListeningSucceeded(handler: self) + } + + case .stopListening: + let task = self.makeUnlistenQuery(channel: channel, context: context) + let action = self.state.enqueue(task: task) + self.run(action, with: context) + } + + case .failure(let error): + let finalError: PSQLError + if var psqlError = error as? PSQLError { + psqlError.code = .listenFailed + finalError = psqlError + } else { + var psqlError = PSQLError(code: .listenFailed) + psqlError.underlying = error + finalError = psqlError + } + let listeners = self.listenState.startListeningFailed(channel: channel, error: finalError) + for list in listeners { + list.failed(finalError) + } + } + } + + private func makeUnlistenQuery(channel: String, context: ChannelHandlerContext) -> PSQLTask { + let promise = context.eventLoop.makePromise(of: PSQLRowStream.self) + let query = ExtendedQueryContext( + query: PostgresQuery(unsafeSQL: "UNLISTEN \(channel);"), + logger: self.logger, + promise: promise + ) + promise.futureResult.whenComplete { result in + self.stopListenCompleted(result, for: channel, context: context) + } + + return .extendedQuery(query) + } + + private func stopListenCompleted( + _ result: Result, + for channel: String, + context: ChannelHandlerContext + ) { + switch result { + case .success: + switch self.listenState.stopListeningSucceeded(channel: channel) { + case .none: + break + + case .startListening: + let task = self.makeStartListeningQuery(channel: channel, context: context) + let action = self.state.enqueue(task: task) + self.run(action, with: context) + } + + case .failure(let error): + let action = self.state.errorHappened(.unlistenError(underlying: error)) + self.run(action, with: context) + } + } + + private func forwardNotificationToListeners( + _ notification: PostgresBackendMessage.NotificationResponse, + context: ChannelHandlerContext + ) { + switch self.listenState.notificationReceived(channel: notification.channel) { + case .none: + break + + case .notify(let listeners): + for listener in listeners { + listener.notificationReceived(notification) + } + } + } + } extension PostgresChannelHandler: PSQLRowsDataSource { @@ -578,16 +746,3 @@ extension ConnectionStateMachine.TLSConfiguration { } } } - -extension PostgresChannelHandler { - convenience init( - configuration: PostgresConnection.InternalConfiguration, - configureSSLCallback: ((Channel) throws -> Void)?) - { - self.init( - configuration: configuration, - logger: .psqlNoOpLogger, - configureSSLCallback: configureSSLCallback - ) - } -} diff --git a/Sources/PostgresNIO/New/PostgresFrontendMessage.swift b/Sources/PostgresNIO/New/PostgresFrontendMessage.swift index 3963bd62..2a7ec9f1 100644 --- a/Sources/PostgresNIO/New/PostgresFrontendMessage.swift +++ b/Sources/PostgresNIO/New/PostgresFrontendMessage.swift @@ -14,7 +14,12 @@ enum PostgresFrontendMessage: Equatable { var preparedStatementName: String /// The number of parameter values that follow (possibly zero). This must match the number of parameters needed by the query. - var bind: PostgresBindings + var parameterFormats: [PostgresFormat] + + /// The number of parameter values that follow (possibly zero). This must match the number of parameters needed by the query. + var parameters: [ByteBuffer?] + + var resultColumnFormats: [PostgresFormat] } struct Cancel: Equatable { diff --git a/Sources/PostgresNIO/New/PostgresNotificationSequence.swift b/Sources/PostgresNIO/New/PostgresNotificationSequence.swift new file mode 100644 index 00000000..735c01b0 --- /dev/null +++ b/Sources/PostgresNIO/New/PostgresNotificationSequence.swift @@ -0,0 +1,22 @@ + +public struct PostgresNotification: Sendable { + public let payload: String +} + +public struct PostgresNotificationSequence: AsyncSequence, Sendable { + public typealias Element = PostgresNotification + + let base: AsyncThrowingStream + + public func makeAsyncIterator() -> AsyncIterator { + AsyncIterator(base: self.base.makeAsyncIterator()) + } + + public struct AsyncIterator: AsyncIteratorProtocol { + var base: AsyncThrowingStream.AsyncIterator + + public mutating func next() async throws -> Element? { + try await self.base.next() + } + } +} diff --git a/Sources/PostgresNIO/Postgres+PSQLCompat.swift b/Sources/PostgresNIO/Postgres+PSQLCompat.swift index ff9773f5..10970b26 100644 --- a/Sources/PostgresNIO/Postgres+PSQLCompat.swift +++ b/Sources/PostgresNIO/Postgres+PSQLCompat.swift @@ -5,7 +5,7 @@ extension PSQLError { switch self.code.base { case .queryCancelled: return self - case .server: + case .server, .listenFailed: guard let serverInfo = self.serverInfo else { return self } @@ -43,6 +43,8 @@ extension PSQLError { return PostgresError.connectionClosed case .connectionError: return self.underlying ?? self + case .unlistenFailed: + return self.underlying ?? self case .uncleanShutdown: return PostgresError.protocol("Unexpected connection close") } diff --git a/Tests/IntegrationTests/AsyncTests.swift b/Tests/IntegrationTests/AsyncTests.swift index 7a45c5c0..f68ef1f3 100644 --- a/Tests/IntegrationTests/AsyncTests.swift +++ b/Tests/IntegrationTests/AsyncTests.swift @@ -224,6 +224,29 @@ final class AsyncPostgresConnectionTests: XCTestCase { } } + func testListenAndNotify() async throws { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + let eventLoop = eventLoopGroup.next() + + try await self.withTestConnection(on: eventLoop) { connection in + let stream = try await connection.listen("foo") + var iterator = stream.makeAsyncIterator() + + try await self.withTestConnection(on: eventLoop) { other in + try await other.query(#"NOTIFY foo, 'bar';"#, logger: .psqlTest) + + try await other.query(#"NOTIFY foo, 'foo';"#, logger: .psqlTest) + } + + let first = try await iterator.next() + XCTAssertEqual(first?.payload, "bar") + + let second = try await iterator.next() + XCTAssertEqual(second?.payload, "foo") + } + } + #if canImport(Network) func testSelect10kRowsNetworkFramework() async throws { let eventLoopGroup = NIOTSEventLoopGroup() diff --git a/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift b/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift index 342907ea..b9677000 100644 --- a/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift +++ b/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift @@ -125,17 +125,90 @@ extension PostgresFrontendMessage { static func decode(from buffer: inout ByteBuffer, for messageID: ID) throws -> PostgresFrontendMessage { switch messageID { case .bind: - preconditionFailure("TODO: Unimplemented") + guard let portalName = buffer.readNullTerminatedString() else { + throw PSQLPartialDecodingError.fieldNotDecodable(type: String.self) + } + guard let preparedStatementName = buffer.readNullTerminatedString() else { + throw PSQLPartialDecodingError.fieldNotDecodable(type: String.self) + } + guard let parameterFormatCount = buffer.readInteger(as: UInt16.self) else { + preconditionFailure("TODO: Unimplemented") + } + + let parameterFormats = (0.. ByteBuffer? in + let length = buffer.readInteger(as: UInt16.self) + switch length { + case .some(..<0): + return nil + case .some(0...): + return buffer.readSlice(length: Int(length!)) + default: + preconditionFailure("TODO: Unimplemented") + } + } + + guard let resultColumnFormatCount = buffer.readInteger(as: UInt16.self) else { + preconditionFailure("TODO: Unimplemented") + } + + let resultColumnFormats = (0.. (PostgresConnection, NIOAsyncTestingChannel) { + let eventLoop = NIOAsyncTestingEventLoop() + let channel = await NIOAsyncTestingChannel(handlers: [ + ReverseByteToMessageHandler(PSQLFrontendMessageDecoder()), + ReverseMessageToByteHandler(PSQLBackendMessageEncoder()), + ], loop: eventLoop) + try await channel.connect(to: .makeAddressResolvingHost("localhost", port: 5432)) + + let configuration = PostgresConnection.Configuration( + establishedChannel: channel, + username: "username", + password: "postgres", + database: "database" + ) + + async let connectionPromise = PostgresConnection.connect(on: eventLoop, configuration: configuration, id: 1, logger: self.logger) + let message = try await channel.waitForOutboundWrite(as: PostgresFrontendMessage.self) + XCTAssertEqual(message, .startup(.versionThree(parameters: .init(user: "username", database: "database", replication: .false)))) + try await channel.writeInbound(PostgresBackendMessage.authentication(.ok)) + try await channel.writeInbound(PostgresBackendMessage.backendKeyData(.init(processID: 1234, secretKey: 5678))) + try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) + + let connection = try await connectionPromise + + self.addTeardownBlock { + try await connection.close() + } + + return (connection, channel) + } +} + +extension NIOAsyncTestingChannel { + + func waitForUnpreparedRequest() async throws -> UnpreparedRequest { + let parse = try await self.waitForOutboundWrite(as: PostgresFrontendMessage.self) + let describe = try await self.waitForOutboundWrite(as: PostgresFrontendMessage.self) + let bind = try await self.waitForOutboundWrite(as: PostgresFrontendMessage.self) + let execute = try await self.waitForOutboundWrite(as: PostgresFrontendMessage.self) + let sync = try await self.waitForOutboundWrite(as: PostgresFrontendMessage.self) + + guard case .parse(let parse) = parse, + case .describe(let describe) = describe, + case .bind(let bind) = bind, + case .execute(let execute) = execute, + case .sync = sync + else { + fatalError() + } + + return UnpreparedRequest(parse: parse, describe: describe, bind: bind, execute: execute) + } +} + +struct UnpreparedRequest { + var parse: PostgresFrontendMessage.Parse + var describe: PostgresFrontendMessage.Describe + var bind: PostgresFrontendMessage.Bind + var execute: PostgresFrontendMessage.Execute +} From 5ffc8fc811f3e36317089031f80f15a4d31b5c44 Mon Sep 17 00:00:00 2001 From: Gwynne Raskind Date: Mon, 7 Aug 2023 05:03:31 -0500 Subject: [PATCH 006/106] Upgrade CI (#382) --- .github/workflows/test.yml | 100 +++++++++++------- .../PostgresNIO/Docs.docc/images/article.svg | 1 + .../Docs.docc/images/vapor-postgres-logo.svg | 36 +++++++ .../PostgresNIO/Docs.docc/theme-settings.json | 46 ++++++++ 4 files changed, 143 insertions(+), 40 deletions(-) create mode 100644 Sources/PostgresNIO/Docs.docc/images/article.svg create mode 100644 Sources/PostgresNIO/Docs.docc/images/vapor-postgres-logo.svg create mode 100644 Sources/PostgresNIO/Docs.docc/theme-settings.json diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 24821c77..2da05f81 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -17,49 +17,52 @@ jobs: strategy: fail-fast: false matrix: - container: + swift-image: - swift:5.6-focal - swift:5.7-jammy - swift:5.8-jammy - swiftlang/swift:nightly-5.9-jammy - swiftlang/swift:nightly-main-jammy - container: ${{ matrix.container }} + include: + - swift-image: swift:5.8-jammy + code-coverage: true + container: ${{ matrix.swift-image }} runs-on: ubuntu-latest steps: - - name: Note Swift version - if: ${{ contains(matrix.swiftver, 'nightly') }} - run: | - echo "SWIFT_PLATFORM=$(. /etc/os-release && echo "${ID}${VERSION_ID}")" >>"${GITHUB_ENV}" - echo "SWIFT_VERSION=$(cat /.swift_tag)" >>"${GITHUB_ENV}" - name: Display OS and Swift versions + shell: bash run: | - printf 'OS: %s\nTag: %s\nVersion:\n' "${SWIFT_PLATFORM}-${RUNNER_ARCH}" "${SWIFT_VERSION}" && swift --version + [[ -z "${SWIFT_PLATFORM}" ]] && SWIFT_PLATFORM="$(. /etc/os-release && echo "${ID}${VERSION_ID}")" + [[ -z "${SWIFT_VERSION}" ]] && SWIFT_VERSION="$(cat /.swift_tag 2>/dev/null || true)" + printf 'OS: %s\nTag: %s\nVersion:\n' "${SWIFT_PLATFORM}-${RUNNER_ARCH}" "${SWIFT_VERSION}" + swift --version - name: Check out package uses: actions/checkout@v3 - - name: Run unit tests with code coverage and Thread Sanitizer - run: swift test --filter=^PostgresNIOTests --sanitize=thread --enable-code-coverage - - name: Submit coverage report to Codecov.io - uses: vapor/swift-codecov-action@v0.2 - with: - cc_env_vars: 'SWIFT_VERSION,SWIFT_PLATFORM,RUNNER_OS,RUNNER_ARCH' - cc_fail_ci_if_error: false + - name: Run unit tests with Thread Sanitizer + env: + CODE_COVERAGE: ${{ matrix.code-coverage && '--enable-code-coverage' || '' }} + run: | + swift test --filter=^PostgresNIOTests --sanitize=thread ${CODE_COVERAGE} + - name: Submit code coverage + if: ${{ matrix.code-coverage }} + uses: vapor/swift-codecov-action@v0.2 linux-integration-and-dependencies: if: github.event_name == 'pull_request' strategy: fail-fast: false matrix: - dbimage: + postgres-image: - postgres:15 - postgres:13 - postgres:11 include: - - dbimage: postgres:15 - dbauth: scram-sha-256 - - dbimage: postgres:13 - dbauth: md5 - - dbimage: postgres:11 - dbauth: trust + - postgres-image: postgres:15 + postgres-auth: scram-sha-256 + - postgres-image: postgres:13 + postgres-auth: md5 + - postgres-image: postgres:11 + postgres-auth: trust container: image: swift:5.8-jammy volumes: [ 'pgrunshare:/var/run/postgresql' ] @@ -79,29 +82,31 @@ jobs: POSTGRES_HOSTNAME_A: 'psql-a' POSTGRES_HOSTNAME_B: 'psql-b' POSTGRES_SOCKET: '/var/run/postgresql/.s.PGSQL.5432' - POSTGRES_HOST_AUTH_METHOD: ${{ matrix.dbauth }} + POSTGRES_HOST_AUTH_METHOD: ${{ matrix.postgres-auth }} services: psql-a: - image: ${{ matrix.dbimage }} + image: ${{ matrix.postgres-image }} volumes: [ 'pgrunshare:/var/run/postgresql' ] env: POSTGRES_USER: 'test_username' POSTGRES_DB: 'test_database' POSTGRES_PASSWORD: 'test_password' - POSTGRES_HOST_AUTH_METHOD: ${{ matrix.dbauth }} - POSTGRES_INITDB_ARGS: --auth-host=${{ matrix.dbauth }} + POSTGRES_HOST_AUTH_METHOD: ${{ matrix.postgres-auth }} + POSTGRES_INITDB_ARGS: --auth-host=${{ matrix.postgres-auth }} psql-b: - image: ${{ matrix.dbimage }} + image: ${{ matrix.postgres-image }} volumes: [ 'pgrunshare:/var/run/postgresql' ] env: POSTGRES_USER: 'test_username' POSTGRES_DB: 'test_database' POSTGRES_PASSWORD: 'test_password' - POSTGRES_HOST_AUTH_METHOD: ${{ matrix.dbauth }} - POSTGRES_INITDB_ARGS: --auth-host=${{ matrix.dbauth }} + POSTGRES_HOST_AUTH_METHOD: ${{ matrix.postgres-auth }} + POSTGRES_INITDB_ARGS: --auth-host=${{ matrix.postgres-auth }} steps: - name: Display OS and Swift versions run: | + [[ -z "${SWIFT_PLATFORM}" ]] && SWIFT_PLATFORM="$(. /etc/os-release && echo "${ID}${VERSION_ID}")" + [[ -z "${SWIFT_VERSION}" ]] && SWIFT_VERSION="$(cat /.swift_tag 2>/dev/null || true)" printf 'OS: %s\nTag: %s\nVersion:\n' "${SWIFT_PLATFORM}-${RUNNER_ARCH}" "${SWIFT_VERSION}" && swift --version - name: Check out package uses: actions/checkout@v3 @@ -128,33 +133,34 @@ jobs: strategy: fail-fast: false matrix: - dbimage: + postgres-formula: # Only test one version on macOS, let Linux do the rest - postgresql@14 - dbauth: + postgres-auth: # Only test one auth method on macOS, Linux tests will cover the others - scram-sha-256 - xcode: - - latest-stable + xcode-version: + - '~14.3' + - '15.0-beta' runs-on: macos-13 env: POSTGRES_HOSTNAME: 127.0.0.1 POSTGRES_USER: 'test_username' POSTGRES_PASSWORD: 'test_password' POSTGRES_DB: 'postgres' - POSTGRES_HOST_AUTH_METHOD: ${{ matrix.dbauth }} + POSTGRES_AUTH_METHOD: ${{ matrix.postgres-auth }} POSTGRES_SOCKET: '/tmp/.s.PGSQL.5432' - POSTGRES_VERSION: ${{ matrix.dbimage }} + POSTGRES_FORMULA: ${{ matrix.postgres-formula }} steps: - name: Select latest available Xcode uses: maxim-lobanov/setup-xcode@v1 with: - xcode-version: ${{ matrix.xcode }} + xcode-version: ${{ matrix.xcode-version }} - name: Install Postgres, setup DB and auth, and wait for server start run: | - export PATH="$(brew --prefix)/opt/${POSTGRES_VERSION}/bin:$PATH" PGDATA=/tmp/vapor-postgres-test - (brew unlink postgresql || true) && brew install "${POSTGRES_VERSION}" && brew link --force "${POSTGRES_VERSION}" - initdb --locale=C --auth-host "${POSTGRES_HOST_AUTH_METHOD}" -U "${POSTGRES_USER}" --pwfile=<(echo "${POSTGRES_PASSWORD}") + export PATH="$(brew --prefix)/opt/${POSTGRES_FORMULA}/bin:$PATH" PGDATA=/tmp/vapor-postgres-test + (brew unlink postgresql || true) && brew install "${POSTGRES_FORMULA}" && brew link --force "${POSTGRES_FORMULA}" + initdb --locale=C --auth-host "${POSTGRES_AUTH_METHOD}" -U "${POSTGRES_USER}" --pwfile=<(echo "${POSTGRES_PASSWORD}") pg_ctl start --wait timeout-minutes: 2 - name: Checkout code @@ -165,7 +171,7 @@ jobs: api-breakage: if: github.event_name == 'pull_request' runs-on: ubuntu-latest - container: swift:5.8-jammy + container: swift:jammy steps: - name: Checkout uses: actions/checkout@v3 @@ -177,3 +183,17 @@ jobs: - name: API breaking changes run: swift package diagnose-api-breaking-changes origin/main + gh-codeql: + runs-on: ubuntu-latest + permissions: { security-events: write } + steps: + - name: Check out code + uses: actions/checkout@v3 + - name: Initialize CodeQL + uses: github/codeql-action/init@v2 + with: + languages: swift + - name: Perform build + run: swift build + - name: Run CodeQL analyze + uses: github/codeql-action/analyze@v2 diff --git a/Sources/PostgresNIO/Docs.docc/images/article.svg b/Sources/PostgresNIO/Docs.docc/images/article.svg new file mode 100644 index 00000000..3dc6a66c --- /dev/null +++ b/Sources/PostgresNIO/Docs.docc/images/article.svg @@ -0,0 +1 @@ + diff --git a/Sources/PostgresNIO/Docs.docc/images/vapor-postgres-logo.svg b/Sources/PostgresNIO/Docs.docc/images/vapor-postgres-logo.svg new file mode 100644 index 00000000..e1c1223b --- /dev/null +++ b/Sources/PostgresNIO/Docs.docc/images/vapor-postgres-logo.svg @@ -0,0 +1,36 @@ + + + + + + + + + + + + + + + diff --git a/Sources/PostgresNIO/Docs.docc/theme-settings.json b/Sources/PostgresNIO/Docs.docc/theme-settings.json new file mode 100644 index 00000000..c6ce054e --- /dev/null +++ b/Sources/PostgresNIO/Docs.docc/theme-settings.json @@ -0,0 +1,46 @@ +{ + "theme": { + "aside": { + "border-radius": "6px", + "border-style": "double", + "border-width": "3px" + }, + "border-radius": "0", + "button": { + "border-radius": "16px", + "border-width": "1px", + "border-style": "solid" + }, + "code": { + "border-radius": "16px", + "border-width": "1px", + "border-style": "solid" + }, + "color": { + "fill": { + "dark": "rgb(20, 20, 22)", + "light": "rgb(255, 255, 255)" + }, + "psql-blue": "#336791", + "documentation-intro-fill": "radial-gradient(circle at top, var(--color-documentation-intro-accent) 30%, #1f1d1f 100%)", + "documentation-intro-accent": "var(--color-psql-blue)", + "documentation-intro-accent-outer": { + "dark": "rgb(255, 255, 255)", + "light": "rgb(51, 51, 51)" + }, + "documentation-intro-accent-inner": { + "dark": "rgb(51, 51, 51)", + "light": "rgb(255, 255, 255)" + } + }, + "icons": { + "technology": "/postgresnio/images/vapor-postgres-logo.svg", + "article": "/postgresnio/images/article.svg" + } + }, + "features": { + "quickNavigation": { + "enable": true + } + } +} From 329ce83ee4d45c063b908f3f66efb49c930ac5f6 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Mon, 7 Aug 2023 12:16:52 +0200 Subject: [PATCH 007/106] Cleanup PostgresBackendMessage (#384) --- .../ConnectionStateMachine.swift | 13 +---- .../New/Messages/Authentication.swift | 37 ++------------ .../New/Messages/BackendKeyData.swift | 2 +- .../PostgresNIO/New/Messages/DataRow.swift | 2 +- .../New/Messages/ErrorResponse.swift | 4 +- .../New/Messages/NotificationResponse.swift | 2 +- .../New/Messages/ParameterDescription.swift | 2 +- .../New/Messages/ParameterStatus.swift | 2 +- .../New/Messages/ReadyForQuery.swift | 34 ++----------- .../New/Messages/RowDescription.swift | 4 +- .../New/PostgresBackendMessage.swift | 2 +- .../New/PostgresChannelHandler.swift | 8 +-- .../AuthenticationStateMachineTests.swift | 16 +++--- .../ConnectionStateMachineTests.swift | 4 +- .../PSQLBackendMessage+Equatable.swift | 49 ------------------- .../PSQLBackendMessageEncoder.swift | 8 +-- .../New/Messages/AuthenticationTests.swift | 22 +++++---- .../New/PSQLBackendMessageTests.swift | 5 +- .../New/PostgresChannelHandlerTests.swift | 20 +++----- 19 files changed, 57 insertions(+), 179 deletions(-) delete mode 100644 Tests/PostgresNIOTests/New/Extensions/PSQLBackendMessage+Equatable.swift diff --git a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift index 761ba5f2..93312c86 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift @@ -1258,18 +1258,7 @@ struct AuthContext: Equatable, CustomDebugStringConvertible { enum PasswordAuthencationMode: Equatable { case cleartext - case md5(salt: (UInt8, UInt8, UInt8, UInt8)) - - static func ==(lhs: Self, rhs: Self) -> Bool { - switch (lhs, rhs) { - case (.cleartext, .cleartext): - return true - case (.md5(let lhs), .md5(let rhs)): - return lhs == rhs - default: - return false - } - } + case md5(salt: UInt32) } extension ConnectionStateMachine.State: CustomDebugStringConvertible { diff --git a/Sources/PostgresNIO/New/Messages/Authentication.swift b/Sources/PostgresNIO/New/Messages/Authentication.swift index bd0d2e57..eff62e91 100644 --- a/Sources/PostgresNIO/New/Messages/Authentication.swift +++ b/Sources/PostgresNIO/New/Messages/Authentication.swift @@ -2,10 +2,10 @@ import NIOCore extension PostgresBackendMessage { - enum Authentication: PayloadDecodable { + enum Authentication: PayloadDecodable, Hashable { case ok case kerberosV5 - case md5(salt: (UInt8, UInt8, UInt8, UInt8)) + case md5(salt: UInt32) case plaintext case scmCredential case gss @@ -26,7 +26,7 @@ extension PostgresBackendMessage { case 3: return .plaintext case 5: - guard let salt = buffer.readMultipleIntegers(endianness: .big, as: (UInt8, UInt8, UInt8, UInt8).self) else { + guard let salt = buffer.readInteger(as: UInt32.self) else { throw PSQLPartialDecodingError.expectedAtLeastNRemainingBytes(4, actual: buffer.readableBytes) } return .md5(salt: salt) @@ -61,37 +61,6 @@ extension PostgresBackendMessage { } } -extension PostgresBackendMessage.Authentication: Equatable { - static func ==(lhs: Self, rhs: Self) -> Bool { - switch (lhs, rhs) { - case (.ok, .ok): - return true - case (.kerberosV5, .kerberosV5): - return true - case (.md5(let lhs), .md5(let rhs)): - return lhs == rhs - case (.plaintext, .plaintext): - return true - case (.scmCredential, .scmCredential): - return true - case (.gss, .gss): - return true - case (.sspi, .sspi): - return true - case (.gssContinue(let lhs), .gssContinue(let rhs)): - return lhs == rhs - case (.sasl(let lhs), .sasl(let rhs)): - return lhs == rhs - case (.saslContinue(let lhs), .saslContinue(let rhs)): - return lhs == rhs - case (.saslFinal(let lhs), .saslFinal(let rhs)): - return lhs == rhs - default: - return false - } - } -} - extension PostgresBackendMessage.Authentication: CustomDebugStringConvertible { var debugDescription: String { switch self { diff --git a/Sources/PostgresNIO/New/Messages/BackendKeyData.swift b/Sources/PostgresNIO/New/Messages/BackendKeyData.swift index 498c5110..31a676d2 100644 --- a/Sources/PostgresNIO/New/Messages/BackendKeyData.swift +++ b/Sources/PostgresNIO/New/Messages/BackendKeyData.swift @@ -2,7 +2,7 @@ import NIOCore extension PostgresBackendMessage { - struct BackendKeyData: PayloadDecodable, Equatable { + struct BackendKeyData: PayloadDecodable, Hashable { let processID: Int32 let secretKey: Int32 diff --git a/Sources/PostgresNIO/New/Messages/DataRow.swift b/Sources/PostgresNIO/New/Messages/DataRow.swift index b181e600..491e10dc 100644 --- a/Sources/PostgresNIO/New/Messages/DataRow.swift +++ b/Sources/PostgresNIO/New/Messages/DataRow.swift @@ -9,7 +9,7 @@ import NIOCore /// Not putting `DataRow` in ``PSQLBackendMessage`` is our way to trick /// the Swift compiler @usableFromInline -struct DataRow: Sendable, PostgresBackendMessage.PayloadDecodable, Equatable { +struct DataRow: Sendable, PostgresBackendMessage.PayloadDecodable, Hashable { @usableFromInline var columnCount: Int16 @usableFromInline diff --git a/Sources/PostgresNIO/New/Messages/ErrorResponse.swift b/Sources/PostgresNIO/New/Messages/ErrorResponse.swift index 818c1ebf..d0bb6044 100644 --- a/Sources/PostgresNIO/New/Messages/ErrorResponse.swift +++ b/Sources/PostgresNIO/New/Messages/ErrorResponse.swift @@ -80,7 +80,7 @@ extension PostgresBackendMessage { case routine = 0x52 /// R } - struct ErrorResponse: PSQLMessageNotice, PayloadDecodable, Equatable { + struct ErrorResponse: PSQLMessageNotice, PayloadDecodable, Hashable { let fields: [PostgresBackendMessage.Field: String] init(fields: [PostgresBackendMessage.Field: String]) { @@ -88,7 +88,7 @@ extension PostgresBackendMessage { } } - struct NoticeResponse: PSQLMessageNotice, PayloadDecodable, Equatable { + struct NoticeResponse: PSQLMessageNotice, PayloadDecodable, Hashable { let fields: [PostgresBackendMessage.Field: String] init(fields: [PostgresBackendMessage.Field: String]) { diff --git a/Sources/PostgresNIO/New/Messages/NotificationResponse.swift b/Sources/PostgresNIO/New/Messages/NotificationResponse.swift index 5cd9422e..01b9ab4a 100644 --- a/Sources/PostgresNIO/New/Messages/NotificationResponse.swift +++ b/Sources/PostgresNIO/New/Messages/NotificationResponse.swift @@ -2,7 +2,7 @@ import NIOCore extension PostgresBackendMessage { - struct NotificationResponse: PayloadDecodable, Equatable { + struct NotificationResponse: PayloadDecodable, Hashable { let backendPID: Int32 let channel: String let payload: String diff --git a/Sources/PostgresNIO/New/Messages/ParameterDescription.swift b/Sources/PostgresNIO/New/Messages/ParameterDescription.swift index 1ccc91e5..4d12b1b6 100644 --- a/Sources/PostgresNIO/New/Messages/ParameterDescription.swift +++ b/Sources/PostgresNIO/New/Messages/ParameterDescription.swift @@ -2,7 +2,7 @@ import NIOCore extension PostgresBackendMessage { - struct ParameterDescription: PayloadDecodable, Equatable { + struct ParameterDescription: PayloadDecodable, Hashable { /// Specifies the object ID of the parameter data type. var dataTypes: [PostgresDataType] diff --git a/Sources/PostgresNIO/New/Messages/ParameterStatus.swift b/Sources/PostgresNIO/New/Messages/ParameterStatus.swift index 4ffcbe12..52d07e01 100644 --- a/Sources/PostgresNIO/New/Messages/ParameterStatus.swift +++ b/Sources/PostgresNIO/New/Messages/ParameterStatus.swift @@ -2,7 +2,7 @@ import NIOCore extension PostgresBackendMessage { - struct ParameterStatus: PayloadDecodable, Equatable { + struct ParameterStatus: PayloadDecodable, Hashable { /// The name of the run-time parameter being reported. var parameter: String diff --git a/Sources/PostgresNIO/New/Messages/ReadyForQuery.swift b/Sources/PostgresNIO/New/Messages/ReadyForQuery.swift index a300f714..41af1b60 100644 --- a/Sources/PostgresNIO/New/Messages/ReadyForQuery.swift +++ b/Sources/PostgresNIO/New/Messages/ReadyForQuery.swift @@ -1,37 +1,11 @@ import NIOCore extension PostgresBackendMessage { - enum TransactionState: PayloadDecodable, RawRepresentable { - typealias RawValue = UInt8 - - case idle - case inTransaction - case inFailedTransaction - - init?(rawValue: UInt8) { - switch rawValue { - case UInt8(ascii: "I"): - self = .idle - case UInt8(ascii: "T"): - self = .inTransaction - case UInt8(ascii: "E"): - self = .inFailedTransaction - default: - return nil - } - } + enum TransactionState: UInt8, PayloadDecodable, Hashable { + case idle = 73 // ascii: I + case inTransaction = 84 // ascii: T + case inFailedTransaction = 69 // ascii: E - var rawValue: Self.RawValue { - switch self { - case .idle: - return UInt8(ascii: "I") - case .inTransaction: - return UInt8(ascii: "T") - case .inFailedTransaction: - return UInt8(ascii: "E") - } - } - static func decode(from buffer: inout ByteBuffer) throws -> Self { let value = try buffer.throwingReadInteger(as: UInt8.self) guard let state = Self.init(rawValue: value) else { diff --git a/Sources/PostgresNIO/New/Messages/RowDescription.swift b/Sources/PostgresNIO/New/Messages/RowDescription.swift index 66c71215..766d06e9 100644 --- a/Sources/PostgresNIO/New/Messages/RowDescription.swift +++ b/Sources/PostgresNIO/New/Messages/RowDescription.swift @@ -9,13 +9,13 @@ import NIOCore /// Not putting `DataRow` in ``PSQLBackendMessage`` is our way to trick /// the Swift compiler. @usableFromInline -struct RowDescription: PostgresBackendMessage.PayloadDecodable, Sendable, Equatable { +struct RowDescription: PostgresBackendMessage.PayloadDecodable, Sendable, Hashable { /// Specifies the object ID of the parameter data type. @usableFromInline var columns: [Column] @usableFromInline - struct Column: Equatable, Sendable { + struct Column: Hashable, Sendable { /// The field name. @usableFromInline var name: String diff --git a/Sources/PostgresNIO/New/PostgresBackendMessage.swift b/Sources/PostgresNIO/New/PostgresBackendMessage.swift index ecccd1e9..71c3cacd 100644 --- a/Sources/PostgresNIO/New/PostgresBackendMessage.swift +++ b/Sources/PostgresNIO/New/PostgresBackendMessage.swift @@ -20,7 +20,7 @@ protocol PSQLMessagePayloadDecodable { /// /// All messages are defined in the official Postgres Documentation in the section /// [Frontend/Backend Protocol – Message Formats](https://www.postgresql.org/docs/13/protocol-message-formats.html) -enum PostgresBackendMessage { +enum PostgresBackendMessage: Hashable { typealias PayloadDecodable = PSQLMessagePayloadDecodable diff --git a/Sources/PostgresNIO/New/PostgresChannelHandler.swift b/Sources/PostgresNIO/New/PostgresChannelHandler.swift index 4470e802..32c35927 100644 --- a/Sources/PostgresNIO/New/PostgresChannelHandler.swift +++ b/Sources/PostgresNIO/New/PostgresChannelHandler.swift @@ -437,10 +437,10 @@ final class PostgresChannelHandler: ChannelDuplexHandler { var hash2 = [UInt8]() hash2.reserveCapacity(pwdhash.count + 4) hash2.append(contentsOf: pwdhash) - hash2.append(salt.0) - hash2.append(salt.1) - hash2.append(salt.2) - hash2.append(salt.3) + var saltNetworkOrder = salt.bigEndian + withUnsafeBytes(of: &saltNetworkOrder) { ptr in + hash2.append(contentsOf: ptr) + } let hash = Insecure.MD5.hash(data: hash2).md5PrefixHexdigest() self.encoder.password(hash.utf8) diff --git a/Tests/PostgresNIOTests/New/Connection State Machine/AuthenticationStateMachineTests.swift b/Tests/PostgresNIOTests/New/Connection State Machine/AuthenticationStateMachineTests.swift index 87478e63..b06b69ab 100644 --- a/Tests/PostgresNIOTests/New/Connection State Machine/AuthenticationStateMachineTests.swift +++ b/Tests/PostgresNIOTests/New/Connection State Machine/AuthenticationStateMachineTests.swift @@ -19,8 +19,8 @@ class AuthenticationStateMachineTests: XCTestCase { let authContext = AuthContext(username: "test", password: "abc123", database: "test") var state = ConnectionStateMachine(requireBackendKeyData: true) XCTAssertEqual(state.connected(tls: .disable), .provideAuthenticationContext) - let salt: (UInt8, UInt8, UInt8, UInt8) = (0, 1, 2, 3) - + let salt: UInt32 = 0x00_01_02_03 + XCTAssertEqual(state.provideAuthenticationContext(authContext), .sendStartupMessage(authContext)) XCTAssertEqual(state.authenticationMessageReceived(.md5(salt: salt)), .sendPasswordMessage(.md5(salt: salt), authContext)) XCTAssertEqual(state.authenticationMessageReceived(.ok), .wait) @@ -30,8 +30,8 @@ class AuthenticationStateMachineTests: XCTestCase { let authContext = AuthContext(username: "test", password: nil, database: "test") var state = ConnectionStateMachine(requireBackendKeyData: true) XCTAssertEqual(state.connected(tls: .disable), .provideAuthenticationContext) - let salt: (UInt8, UInt8, UInt8, UInt8) = (0, 1, 2, 3) - + let salt: UInt32 = 0x00_01_02_03 + XCTAssertEqual(state.provideAuthenticationContext(authContext), .sendStartupMessage(authContext)) XCTAssertEqual(state.authenticationMessageReceived(.md5(salt: salt)), .closeConnectionAndCleanup(.init(action: .close, tasks: [], error: .authMechanismRequiresPassword, closePromise: nil))) @@ -49,8 +49,8 @@ class AuthenticationStateMachineTests: XCTestCase { let authContext = AuthContext(username: "test", password: "abc123", database: "test") var state = ConnectionStateMachine(requireBackendKeyData: true) XCTAssertEqual(state.connected(tls: .disable), .provideAuthenticationContext) - let salt: (UInt8, UInt8, UInt8, UInt8) = (0, 1, 2, 3) - + let salt: UInt32 = 0x00_01_02_03 + XCTAssertEqual(state.provideAuthenticationContext(authContext), .sendStartupMessage(authContext)) XCTAssertEqual(state.authenticationMessageReceived(.md5(salt: salt)), .sendPasswordMessage(.md5(salt: salt), authContext)) let fields: [PostgresBackendMessage.Field: String] = [ @@ -107,12 +107,12 @@ class AuthenticationStateMachineTests: XCTestCase { } func testUnexpectedMessagesAfterPasswordSent() { - let salt: (UInt8, UInt8, UInt8, UInt8) = (0, 1, 2, 3) + let salt: UInt32 = 0x00_01_02_03 var buffer = ByteBuffer() buffer.writeBytes([0, 1, 2, 3, 4, 5, 6, 7, 8]) let unexpected: [PostgresBackendMessage.Authentication] = [ .kerberosV5, - .md5(salt: (0, 1, 2, 3)), + .md5(salt: salt), .plaintext, .scmCredential, .gss, diff --git a/Tests/PostgresNIOTests/New/Connection State Machine/ConnectionStateMachineTests.swift b/Tests/PostgresNIOTests/New/Connection State Machine/ConnectionStateMachineTests.swift index 289665fb..d5d4ecb1 100644 --- a/Tests/PostgresNIOTests/New/Connection State Machine/ConnectionStateMachineTests.swift +++ b/Tests/PostgresNIOTests/New/Connection State Machine/ConnectionStateMachineTests.swift @@ -23,7 +23,7 @@ class ConnectionStateMachineTests: XCTestCase { XCTAssertEqual(state.sslHandlerAdded(), .wait) XCTAssertEqual(state.sslEstablished(), .provideAuthenticationContext) XCTAssertEqual(state.provideAuthenticationContext(authContext), .sendStartupMessage(authContext)) - let salt: (UInt8, UInt8, UInt8, UInt8) = (0,1,2,3) + let salt: UInt32 = 0x00_01_02_03 XCTAssertEqual(state.authenticationMessageReceived(.md5(salt: salt)), .sendPasswordMessage(.md5(salt: salt), authContext)) } @@ -154,7 +154,7 @@ class ConnectionStateMachineTests: XCTestCase { defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } let authContext = AuthContext(username: "test", password: "abc123", database: "test") - let salt: (UInt8, UInt8, UInt8, UInt8) = (0, 1, 2, 3) + let salt: UInt32 = 0x00_01_02_03 let queryPromise = eventLoopGroup.next().makePromise(of: PSQLRowStream.self) diff --git a/Tests/PostgresNIOTests/New/Extensions/PSQLBackendMessage+Equatable.swift b/Tests/PostgresNIOTests/New/Extensions/PSQLBackendMessage+Equatable.swift deleted file mode 100644 index c459ffeb..00000000 --- a/Tests/PostgresNIOTests/New/Extensions/PSQLBackendMessage+Equatable.swift +++ /dev/null @@ -1,49 +0,0 @@ -@testable import PostgresNIO - -extension PostgresBackendMessage: Equatable { - - public static func ==(lhs: Self, rhs: Self) -> Bool { - switch (lhs, rhs) { - case (.authentication(let lhs), .authentication(let rhs)): - return lhs == rhs - case (.backendKeyData(let lhs), .backendKeyData(let rhs)): - return lhs == rhs - case (.bindComplete, bindComplete): - return true - case (.closeComplete, closeComplete): - return true - case (.commandComplete(let lhs), commandComplete(let rhs)): - return lhs == rhs - case (.dataRow(let lhs), dataRow(let rhs)): - return lhs == rhs - case (.emptyQueryResponse, emptyQueryResponse): - return true - case (.error(let lhs), error(let rhs)): - return lhs == rhs - case (.noData, noData): - return true - case (.notice(let lhs), notice(let rhs)): - return lhs == rhs - case (.notification(let lhs), .notification(let rhs)): - return lhs == rhs - case (.parameterDescription(let lhs), parameterDescription(let rhs)): - return lhs == rhs - case (.parameterStatus(let lhs), parameterStatus(let rhs)): - return lhs == rhs - case (.parseComplete, parseComplete): - return true - case (.portalSuspended, portalSuspended): - return true - case (.readyForQuery(let lhs), readyForQuery(let rhs)): - return lhs == rhs - case (.rowDescription(let lhs), rowDescription(let rhs)): - return lhs == rhs - case (.sslSupported, sslSupported): - return true - case (.sslUnsupported, sslUnsupported): - return true - default: - return false - } - } -} diff --git a/Tests/PostgresNIOTests/New/Extensions/PSQLBackendMessageEncoder.swift b/Tests/PostgresNIOTests/New/Extensions/PSQLBackendMessageEncoder.swift index eea7dec3..e51c14f9 100644 --- a/Tests/PostgresNIOTests/New/Extensions/PSQLBackendMessageEncoder.swift +++ b/Tests/PostgresNIOTests/New/Extensions/PSQLBackendMessageEncoder.swift @@ -9,7 +9,7 @@ struct PSQLBackendMessageEncoder: MessageToByteEncoder { /// - parameters: /// - data: The data to encode into a `ByteBuffer`. /// - out: The `ByteBuffer` into which we want to encode. - func encode(data message: PostgresBackendMessage, out buffer: inout ByteBuffer) throws { + func encode(data message: PostgresBackendMessage, out buffer: inout ByteBuffer) { switch message { case .authentication(let authentication): self.encode(messageID: message.id, payload: authentication, into: &buffer) @@ -144,11 +144,7 @@ extension PostgresBackendMessage.Authentication: PSQLMessagePayloadEncodable { buffer.writeInteger(Int32(3)) case .md5(salt: let salt): - buffer.writeInteger(Int32(5)) - buffer.writeInteger(salt.0) - buffer.writeInteger(salt.1) - buffer.writeInteger(salt.2) - buffer.writeInteger(salt.3) + buffer.writeMultipleIntegers(Int32(5), salt) case .scmCredential: buffer.writeInteger(Int32(6)) diff --git a/Tests/PostgresNIOTests/New/Messages/AuthenticationTests.swift b/Tests/PostgresNIOTests/New/Messages/AuthenticationTests.swift index 31a21a91..06e39aae 100644 --- a/Tests/PostgresNIOTests/New/Messages/AuthenticationTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/AuthenticationTests.swift @@ -11,35 +11,37 @@ class AuthenticationTests: XCTestCase { let encoder = PSQLBackendMessageEncoder() // add ok - XCTAssertNoThrow(try encoder.encode(data: .authentication(.ok), out: &buffer)) + encoder.encode(data: .authentication(.ok), out: &buffer) expected.append(.authentication(.ok)) // add kerberos - XCTAssertNoThrow(try encoder.encode(data: .authentication(.kerberosV5), out: &buffer)) + encoder.encode(data: .authentication(.kerberosV5), out: &buffer) expected.append(.authentication(.kerberosV5)) // add plaintext - XCTAssertNoThrow(try encoder.encode(data: .authentication(.plaintext), out: &buffer)) + encoder.encode(data: .authentication(.plaintext), out: &buffer) expected.append(.authentication(.plaintext)) // add md5 - XCTAssertNoThrow(try encoder.encode(data: .authentication(.md5(salt: (1, 2, 3, 4))), out: &buffer)) - expected.append(.authentication(.md5(salt: (1, 2, 3, 4)))) - + let salt: UInt32 = 0x01_02_03_04 + encoder.encode(data: .authentication(.md5(salt: salt)), out: &buffer) + expected.append(.authentication(.md5(salt: salt))) + // add scm credential - XCTAssertNoThrow(try encoder.encode(data: .authentication(.scmCredential), out: &buffer)) + encoder.encode(data: .authentication(.scmCredential), out: &buffer) expected.append(.authentication(.scmCredential)) // add gss - XCTAssertNoThrow(try encoder.encode(data: .authentication(.gss), out: &buffer)) + encoder.encode(data: .authentication(.gss), out: &buffer) expected.append(.authentication(.gss)) // add sspi - XCTAssertNoThrow(try encoder.encode(data: .authentication(.sspi), out: &buffer)) + encoder.encode(data: .authentication(.sspi), out: &buffer) expected.append(.authentication(.sspi)) XCTAssertNoThrow(try ByteToMessageDecoderVerifier.verifyDecoder( inputOutputPairs: [(buffer, expected)], - decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: false) })) + decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: false) } + )) } } diff --git a/Tests/PostgresNIOTests/New/PSQLBackendMessageTests.swift b/Tests/PostgresNIOTests/New/PSQLBackendMessageTests.swift index 10e8503a..195c7fb4 100644 --- a/Tests/PostgresNIOTests/New/PSQLBackendMessageTests.swift +++ b/Tests/PostgresNIOTests/New/PSQLBackendMessageTests.swift @@ -256,11 +256,12 @@ class PSQLBackendMessageTests: XCTestCase { } func testDebugDescription() { + let salt: UInt32 = 0x00_01_02_03 XCTAssertEqual("\(PostgresBackendMessage.authentication(.ok))", ".authentication(.ok)") XCTAssertEqual("\(PostgresBackendMessage.authentication(.kerberosV5))", ".authentication(.kerberosV5)") - XCTAssertEqual("\(PostgresBackendMessage.authentication(.md5(salt: (0, 1, 2, 3))))", - ".authentication(.md5(salt: (0, 1, 2, 3)))") + XCTAssertEqual("\(PostgresBackendMessage.authentication(.md5(salt: salt)))", + ".authentication(.md5(salt: \(salt)))") XCTAssertEqual("\(PostgresBackendMessage.authentication(.plaintext))", ".authentication(.plaintext)") XCTAssertEqual("\(PostgresBackendMessage.authentication(.scmCredential))", diff --git a/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift b/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift index 4484d6a4..5388e8b5 100644 --- a/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift @@ -152,19 +152,19 @@ class PostgresChannelHandlerTests: XCTestCase { let handler = PostgresChannelHandler(configuration: config, eventLoop: self.eventLoop, state: state, configureSSLCallback: nil) let embedded = EmbeddedChannel(handlers: [ ReverseByteToMessageHandler(PSQLFrontendMessageDecoder()), - ReverseMessageToByteHandler(PSQLBackendMessageEncoder()), handler ], loop: self.eventLoop) embedded.triggerUserOutboundEvent(PSQLOutgoingEvent.authenticate(authContext), promise: nil) XCTAssertEqual(try embedded.readOutbound(as: PostgresFrontendMessage.self), .startup(.versionThree(parameters: authContext.toStartupParameters()))) + let salt: UInt32 = 0x00_01_02_03 + + let encoder = PSQLBackendMessageEncoder() + var byteBuffer = ByteBuffer() + encoder.encode(data: .authentication(.md5(salt: salt)), out: &byteBuffer) + XCTAssertNoThrow(try embedded.writeInbound(byteBuffer)) - XCTAssertNoThrow(try embedded.writeInbound(PostgresBackendMessage.authentication(.md5(salt: (0,1,2,3))))) - - var message: PostgresFrontendMessage? - XCTAssertNoThrow(message = try embedded.readOutbound(as: PostgresFrontendMessage.self)) - - XCTAssertEqual(message, .password(.init(value: "md522d085ed8dc3377968dc1c1a40519a2a"))) + XCTAssertEqual(try embedded.readOutbound(as: PostgresFrontendMessage.self), .password(.init(value: "md522d085ed8dc3377968dc1c1a40519a2a"))) } func testRunAuthenticateCleartext() { @@ -187,11 +187,7 @@ class PostgresChannelHandlerTests: XCTestCase { XCTAssertEqual(try embedded.readOutbound(as: PostgresFrontendMessage.self), .startup(.versionThree(parameters: authContext.toStartupParameters()))) XCTAssertNoThrow(try embedded.writeInbound(PostgresBackendMessage.authentication(.plaintext))) - - var message: PostgresFrontendMessage? - XCTAssertNoThrow(message = try embedded.readOutbound(as: PostgresFrontendMessage.self)) - - XCTAssertEqual(message, .password(.init(value: password))) + XCTAssertEqual(try embedded.readOutbound(as: PostgresFrontendMessage.self), .password(.init(value: password))) } func testHandlerThatSendsMultipleWrongMessages() { From 0a1c54e38961a8989d37bb8ee75da38c3f7232aa Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Mon, 7 Aug 2023 16:18:42 +0200 Subject: [PATCH 008/106] PostgresBackendMessage.ID should be backed by UInt8 directly (#386) --- .../New/PostgresBackendMessage.swift | 160 +++--------------- .../New/PostgresBackendMessageDecoder.swift | 12 +- 2 files changed, 31 insertions(+), 141 deletions(-) diff --git a/Sources/PostgresNIO/New/PostgresBackendMessage.swift b/Sources/PostgresNIO/New/PostgresBackendMessage.swift index 71c3cacd..792beec3 100644 --- a/Sources/PostgresNIO/New/PostgresBackendMessage.swift +++ b/Sources/PostgresNIO/New/PostgresBackendMessage.swift @@ -46,141 +46,31 @@ enum PostgresBackendMessage: Hashable { } extension PostgresBackendMessage { - enum ID: RawRepresentable, Equatable { - typealias RawValue = UInt8 - - case authentication - case backendKeyData - case bindComplete - case closeComplete - case commandComplete - case copyData - case copyDone - case copyInResponse - case copyOutResponse - case copyBothResponse - case dataRow - case emptyQueryResponse - case error - case functionCallResponse - case negotiateProtocolVersion - case noData - case noticeResponse - case notificationResponse - case parameterDescription - case parameterStatus - case parseComplete - case portalSuspended - case readyForQuery - case rowDescription - - init?(rawValue: UInt8) { - switch rawValue { - case UInt8(ascii: "R"): - self = .authentication - case UInt8(ascii: "K"): - self = .backendKeyData - case UInt8(ascii: "2"): - self = .bindComplete - case UInt8(ascii: "3"): - self = .closeComplete - case UInt8(ascii: "C"): - self = .commandComplete - case UInt8(ascii: "d"): - self = .copyData - case UInt8(ascii: "c"): - self = .copyDone - case UInt8(ascii: "G"): - self = .copyInResponse - case UInt8(ascii: "H"): - self = .copyOutResponse - case UInt8(ascii: "W"): - self = .copyBothResponse - case UInt8(ascii: "D"): - self = .dataRow - case UInt8(ascii: "I"): - self = .emptyQueryResponse - case UInt8(ascii: "E"): - self = .error - case UInt8(ascii: "V"): - self = .functionCallResponse - case UInt8(ascii: "v"): - self = .negotiateProtocolVersion - case UInt8(ascii: "n"): - self = .noData - case UInt8(ascii: "N"): - self = .noticeResponse - case UInt8(ascii: "A"): - self = .notificationResponse - case UInt8(ascii: "t"): - self = .parameterDescription - case UInt8(ascii: "S"): - self = .parameterStatus - case UInt8(ascii: "1"): - self = .parseComplete - case UInt8(ascii: "s"): - self = .portalSuspended - case UInt8(ascii: "Z"): - self = .readyForQuery - case UInt8(ascii: "T"): - self = .rowDescription - default: - return nil - } - } - - var rawValue: UInt8 { - switch self { - case .authentication: - return UInt8(ascii: "R") - case .backendKeyData: - return UInt8(ascii: "K") - case .bindComplete: - return UInt8(ascii: "2") - case .closeComplete: - return UInt8(ascii: "3") - case .commandComplete: - return UInt8(ascii: "C") - case .copyData: - return UInt8(ascii: "d") - case .copyDone: - return UInt8(ascii: "c") - case .copyInResponse: - return UInt8(ascii: "G") - case .copyOutResponse: - return UInt8(ascii: "H") - case .copyBothResponse: - return UInt8(ascii: "W") - case .dataRow: - return UInt8(ascii: "D") - case .emptyQueryResponse: - return UInt8(ascii: "I") - case .error: - return UInt8(ascii: "E") - case .functionCallResponse: - return UInt8(ascii: "V") - case .negotiateProtocolVersion: - return UInt8(ascii: "v") - case .noData: - return UInt8(ascii: "n") - case .noticeResponse: - return UInt8(ascii: "N") - case .notificationResponse: - return UInt8(ascii: "A") - case .parameterDescription: - return UInt8(ascii: "t") - case .parameterStatus: - return UInt8(ascii: "S") - case .parseComplete: - return UInt8(ascii: "1") - case .portalSuspended: - return UInt8(ascii: "s") - case .readyForQuery: - return UInt8(ascii: "Z") - case .rowDescription: - return UInt8(ascii: "T") - } - } + enum ID: UInt8, Hashable { + case authentication = 82 // ascii: R + case backendKeyData = 75 // ascii: K + case bindComplete = 50 // ascii: 2 + case closeComplete = 51 // ascii: 3 + case commandComplete = 67 // ascii: C + case copyData = 100 // ascii: d + case copyDone = 99 // ascii: c + case copyInResponse = 71 // ascii: G + case copyOutResponse = 72 // ascii: H + case copyBothResponse = 87 // ascii: W + case dataRow = 68 // ascii: D + case emptyQueryResponse = 73 // ascii: I + case error = 69 // ascii: E + case functionCallResponse = 86 // ascii: V + case negotiateProtocolVersion = 118 // ascii: v + case noData = 110 // ascii: n + case noticeResponse = 78 // ascii: N + case notificationResponse = 65 // ascii: A + case parameterDescription = 116 // ascii: t + case parameterStatus = 83 // ascii: S + case parseComplete = 49 // ascii: 1 + case portalSuspended = 115 // ascii: s + case readyForQuery = 90 // ascii: Z + case rowDescription = 84 // ascii: T } } diff --git a/Sources/PostgresNIO/New/PostgresBackendMessageDecoder.swift b/Sources/PostgresNIO/New/PostgresBackendMessageDecoder.swift index ee7e1b84..6f6be7ec 100644 --- a/Sources/PostgresNIO/New/PostgresBackendMessageDecoder.swift +++ b/Sources/PostgresNIO/New/PostgresBackendMessageDecoder.swift @@ -107,8 +107,8 @@ struct PostgresMessageDecodingError: Error { static func withPartialError( _ partialError: PSQLPartialDecodingError, messageID: UInt8, - messageBytes: ByteBuffer) -> Self - { + messageBytes: ByteBuffer + ) -> Self { var byteBuffer = messageBytes let data = byteBuffer.readData(length: byteBuffer.readableBytes)! @@ -124,8 +124,8 @@ struct PostgresMessageDecodingError: Error { messageID: UInt8, messageBytes: ByteBuffer, file: String = #fileID, - line: Int = #line) -> Self - { + line: Int = #line + ) -> Self { var byteBuffer = messageBytes let data = byteBuffer.readData(length: byteBuffer.readableBytes)! @@ -153,8 +153,8 @@ struct PSQLPartialDecodingError: Error { value: Target.RawValue, asType: Target.Type, file: String = #fileID, - line: Int = #line) -> Self - { + line: Int = #line + ) -> Self { return PSQLPartialDecodingError( description: "Can not represent '\(value)' with type '\(asType)'.", file: file, line: line) From 220eb501f336ec3e22605e9c16dc7d8ce4251e6b Mon Sep 17 00:00:00 2001 From: Gwynne Raskind Date: Mon, 7 Aug 2023 16:18:57 -0500 Subject: [PATCH 009/106] Typo fix: Storiage -> Storage (#387) --- Sources/PostgresNIO/New/PSQLError.swift | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/Sources/PostgresNIO/New/PSQLError.swift b/Sources/PostgresNIO/New/PSQLError.swift index a13d4209..5d9e534c 100644 --- a/Sources/PostgresNIO/New/PSQLError.swift +++ b/Sources/PostgresNIO/New/PSQLError.swift @@ -96,7 +96,7 @@ public struct PSQLError: Error { private var backing: Backing - private mutating func copyBackingStoriageIfNecessary() { + private mutating func copyBackingStorageIfNecessary() { if !isKnownUniquelyReferenced(&self.backing) { self.backing = self.backing.copy() } @@ -106,7 +106,7 @@ public struct PSQLError: Error { public internal(set) var code: Code { get { self.backing.code } set { - self.copyBackingStoriageIfNecessary() + self.copyBackingStorageIfNecessary() self.backing.code = newValue } } @@ -115,7 +115,7 @@ public struct PSQLError: Error { public internal(set) var serverInfo: ServerInfo? { get { self.backing.serverInfo } set { - self.copyBackingStoriageIfNecessary() + self.copyBackingStorageIfNecessary() self.backing.serverInfo = newValue } } @@ -124,7 +124,7 @@ public struct PSQLError: Error { public internal(set) var underlying: Error? { get { self.backing.underlying } set { - self.copyBackingStoriageIfNecessary() + self.copyBackingStorageIfNecessary() self.backing.underlying = newValue } } @@ -133,7 +133,7 @@ public struct PSQLError: Error { public internal(set) var file: String? { get { self.backing.file } set { - self.copyBackingStoriageIfNecessary() + self.copyBackingStorageIfNecessary() self.backing.file = newValue } } @@ -142,7 +142,7 @@ public struct PSQLError: Error { public internal(set) var line: Int? { get { self.backing.line } set { - self.copyBackingStoriageIfNecessary() + self.copyBackingStorageIfNecessary() self.backing.line = newValue } } @@ -151,7 +151,7 @@ public struct PSQLError: Error { public internal(set) var query: PostgresQuery? { get { self.backing.query } set { - self.copyBackingStoriageIfNecessary() + self.copyBackingStorageIfNecessary() self.backing.query = newValue } } @@ -161,7 +161,7 @@ public struct PSQLError: Error { var backendMessage: PostgresBackendMessage? { get { self.backing.backendMessage } set { - self.copyBackingStoriageIfNecessary() + self.copyBackingStorageIfNecessary() self.backing.backendMessage = newValue } } @@ -171,7 +171,7 @@ public struct PSQLError: Error { var unsupportedAuthScheme: UnsupportedAuthScheme? { get { self.backing.unsupportedAuthScheme } set { - self.copyBackingStoriageIfNecessary() + self.copyBackingStorageIfNecessary() self.backing.unsupportedAuthScheme = newValue } } @@ -181,7 +181,7 @@ public struct PSQLError: Error { var invalidCommandTag: String? { get { self.backing.invalidCommandTag } set { - self.copyBackingStoriageIfNecessary() + self.copyBackingStorageIfNecessary() self.backing.invalidCommandTag = newValue } } From c5737e8a54c59da09bb1e699ab1c4e4b4fd99844 Mon Sep 17 00:00:00 2001 From: Gwynne Raskind Date: Mon, 7 Aug 2023 23:35:42 -0500 Subject: [PATCH 010/106] [no ci] Fix missing docs attribute --- Sources/PostgresNIO/Docs.docc/index.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/Sources/PostgresNIO/Docs.docc/index.md b/Sources/PostgresNIO/Docs.docc/index.md index b4dc7e30..ebe27cd0 100644 --- a/Sources/PostgresNIO/Docs.docc/index.md +++ b/Sources/PostgresNIO/Docs.docc/index.md @@ -1,5 +1,9 @@ # ``PostgresNIO`` +@Metadata { + @TitleHeading(Package) +} + 🐘 Non-blocking, event-driven Swift client for PostgreSQL built on SwiftNIO. ## Overview From b6597f7c419a70a31b08b0dcafafe052c58b1d86 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Wed, 9 Aug 2023 23:07:39 +0200 Subject: [PATCH 011/106] Remove PrepareStatementStateMachine (#391) Preparing a statement is a substep of running an extended query. For this reason we should reuse the `ExtendedQueryStateMachine` as much as we can. This patch removes the `PrepareStatementStateMachine` and uses the `ExtendedQueryStateMachine`. As a result of this we can simplify our code in lots of other places. --- .../Connection/PostgresConnection.swift | 7 +- .../ConnectionStateMachine.swift | 165 +++---------- .../ExtendedQueryStateMachine.swift | 157 +++++++++---- .../PrepareStatementStateMachine.swift | 147 ------------ Sources/PostgresNIO/New/PSQLRowStream.swift | 36 ++- Sources/PostgresNIO/New/PSQLTask.swift | 71 +++--- .../New/PostgresChannelHandler.swift | 67 +++--- .../ConnectionStateMachineTests.swift | 6 +- .../ExtendedQueryStateMachineTests.swift | 16 +- .../PrepareStatementStateMachineTests.swift | 47 ++-- .../ConnectionAction+TestUtils.swift | 17 +- .../New/PSQLRowStreamTests.swift | 140 ++++------- .../New/PostgresRowSequenceTests.swift | 220 ++++++++---------- 13 files changed, 421 insertions(+), 675 deletions(-) delete mode 100644 Sources/PostgresNIO/New/Connection State Machine/PrepareStatementStateMachine.swift diff --git a/Sources/PostgresNIO/Connection/PostgresConnection.swift b/Sources/PostgresNIO/Connection/PostgresConnection.swift index d6420a6e..6f849bdd 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection.swift @@ -230,13 +230,14 @@ public final class PostgresConnection: @unchecked Sendable { func prepareStatement(_ query: String, with name: String, logger: Logger) -> EventLoopFuture { let promise = self.channel.eventLoop.makePromise(of: RowDescription?.self) - let context = PrepareStatementContext( + let context = ExtendedQueryContext( name: name, query: query, logger: logger, - promise: promise) + promise: promise + ) - self.channel.write(HandlerTask.preparedStatement(context), promise: nil) + self.channel.write(HandlerTask.extendedQuery(context), promise: nil) return promise.futureResult.map { rowDescription in PSQLPreparedStatement(name: name, query: query, connection: self, rowDescription: rowDescription) } diff --git a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift index 93312c86..0f3e96c9 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift @@ -31,7 +31,6 @@ struct ConnectionStateMachine { case readyForQuery(ConnectionContext) case extendedQuery(ExtendedQueryStateMachine, ConnectionContext) - case prepareStatement(PrepareStatementStateMachine, ConnectionContext) case closeCommand(CloseStateMachine, ConnectionContext) case error(PSQLError) @@ -89,10 +88,9 @@ struct ConnectionStateMachine { // --- general actions case sendParseDescribeBindExecuteSync(PostgresQuery) case sendBindExecuteSync(PSQLExecuteStatement) - case failQuery(ExtendedQueryContext, with: PSQLError, cleanupContext: CleanUpContext?) - case succeedQuery(ExtendedQueryContext, columns: [RowDescription.Column]) - case succeedQueryNoRowsComming(ExtendedQueryContext, commandTag: String) - + case failQuery(EventLoopPromise, with: PSQLError, cleanupContext: CleanUpContext?) + case succeedQuery(EventLoopPromise, with: QueryResult) + // --- streaming actions // actions if query has requested next row but we are waiting for backend case forwardRows([DataRow]) @@ -101,9 +99,9 @@ struct ConnectionStateMachine { // Prepare statement actions case sendParseDescribeSync(name: String, query: String) - case succeedPreparedStatementCreation(PrepareStatementContext, with: RowDescription?) - case failPreparedStatementCreation(PrepareStatementContext, with: PSQLError, cleanupContext: CleanUpContext?) - + case succeedPreparedStatementCreation(EventLoopPromise, with: RowDescription?) + case failPreparedStatementCreation(EventLoopPromise, with: PSQLError, cleanupContext: CleanUpContext?) + // Close actions case sendCloseSync(CloseTarget) case succeedClose(CloseCommandContext) @@ -159,7 +157,6 @@ struct ConnectionStateMachine { .authenticated, .readyForQuery, .extendedQuery, - .prepareStatement, .closeCommand, .error, .closing, @@ -214,7 +211,6 @@ struct ConnectionStateMachine { .authenticating, .readyForQuery, .extendedQuery, - .prepareStatement, .closeCommand: return self.errorHappened(.uncleanShutdown) @@ -245,7 +241,6 @@ struct ConnectionStateMachine { .authenticated, .readyForQuery, .extendedQuery, - .prepareStatement, .closeCommand, .error, .closing, @@ -274,7 +269,6 @@ struct ConnectionStateMachine { .authenticated, .readyForQuery, .extendedQuery, - .prepareStatement, .closeCommand, .error, .closing, @@ -296,7 +290,6 @@ struct ConnectionStateMachine { .authenticated, .readyForQuery, .extendedQuery, - .prepareStatement, .closeCommand, .error, .closing, @@ -322,7 +315,6 @@ struct ConnectionStateMachine { .authenticated, .readyForQuery, .extendedQuery, - .prepareStatement, .closeCommand, .error, .closing, @@ -391,12 +383,6 @@ struct ConnectionStateMachine { machine.state = .extendedQuery(query, connectionContext) return .wait } - case .prepareStatement(let prepareState, var connectionContext): - return self.avoidingStateMachineCoW { machine in - connectionContext.parameters[status.parameter] = status.value - machine.state = .prepareStatement(prepareState, connectionContext) - return .wait - } case .closeCommand(let closeState, var connectionContext): return self.avoidingStateMachineCoW { machine in connectionContext.parameters[status.parameter] = status.value @@ -450,15 +436,6 @@ struct ConnectionStateMachine { machine.state = .extendedQuery(extendedQueryState, connectionContext) return machine.modify(with: action) } - case .prepareStatement(var preparedState, let connectionContext): - if preparedState.isComplete { - return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.error(errorMessage))) - } - return self.avoidingStateMachineCoW { machine -> ConnectionAction in - let action = preparedState.errorReceived(errorMessage) - machine.state = .prepareStatement(preparedState, connectionContext) - return machine.modify(with: action) - } case .closing: // If the state machine is in state `.closing`, the connection shutdown was initiated // by the client. This means a `TERMINATE` message has already been sent and the @@ -493,13 +470,6 @@ struct ConnectionStateMachine { let action = queryState.errorHappened(error) return self.modify(with: action) } - case .prepareStatement(var prepareState, _): - if prepareState.isComplete { - return self.closeConnectionAndCleanup(error) - } else { - let action = prepareState.errorHappened(error) - return self.modify(with: action) - } case .closeCommand(var closeState, _): if closeState.isComplete { return self.closeConnectionAndCleanup(error) @@ -567,16 +537,6 @@ struct ConnectionStateMachine { self.state = .readyForQuery(connectionContext) return self.executeNextQueryFromQueue() - case .prepareStatement(let preparedStateMachine, var connectionContext): - guard preparedStateMachine.isComplete else { - return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.readyForQuery(transactionState))) - } - - connectionContext.transactionState = transactionState - - self.state = .readyForQuery(connectionContext) - return self.executeNextQueryFromQueue() - case .closeCommand(let closeStateMachine, var connectionContext): guard closeStateMachine.isComplete else { return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.readyForQuery(transactionState))) @@ -597,9 +557,13 @@ struct ConnectionStateMachine { if case .quiescing = self.quiescingState { switch task { case .extendedQuery(let queryContext): - return .failQuery(queryContext, with: .connectionQuiescing, cleanupContext: nil) - case .preparedStatement(let prepareContext): - return .failPreparedStatementCreation(prepareContext, with: .connectionQuiescing, cleanupContext: nil) + switch queryContext.query { + case .unnamed(_, let eventLoopPromise), .executeStatement(_, let eventLoopPromise): + return .failQuery(eventLoopPromise, with: .connectionQuiescing, cleanupContext: nil) + case .prepareStatement(_, _, let eventLoopPromise): + return .failPreparedStatementCreation(eventLoopPromise, with: .connectionQuiescing, cleanupContext: nil) + } + case .closeCommand(let closeContext): return .failClose(closeContext, with: .connectionQuiescing, cleanupContext: nil) } @@ -611,9 +575,12 @@ struct ConnectionStateMachine { case .closed: switch task { case .extendedQuery(let queryContext): - return .failQuery(queryContext, with: .connectionClosed, cleanupContext: nil) - case .preparedStatement(let prepareContext): - return .failPreparedStatementCreation(prepareContext, with: .connectionClosed, cleanupContext: nil) + switch queryContext.query { + case .unnamed(_, let eventLoopPromise), .executeStatement(_, let eventLoopPromise): + return .failQuery(eventLoopPromise, with: .connectionClosed, cleanupContext: nil) + case .prepareStatement(_, _, let eventLoopPromise): + return .failPreparedStatementCreation(eventLoopPromise, with: .connectionClosed, cleanupContext: nil) + } case .closeCommand(let closeContext): return .failClose(closeContext, with: .connectionClosed, cleanupContext: nil) } @@ -633,7 +600,6 @@ struct ConnectionStateMachine { .authenticating, .authenticated, .readyForQuery, - .prepareStatement, .closeCommand, .error, .closing, @@ -676,12 +642,6 @@ struct ConnectionStateMachine { machine.state = .extendedQuery(extendedQuery, connectionContext) return machine.modify(with: action) } - case .prepareStatement(var preparedStatement, let connectionContext): - return self.avoidingStateMachineCoW { machine in - let action = preparedStatement.readEventCaught() - machine.state = .prepareStatement(preparedStatement, connectionContext) - return machine.modify(with: action) - } case .closeCommand(var closeState, let connectionContext): return self.avoidingStateMachineCoW { machine in let action = closeState.readEventCaught() @@ -709,12 +669,6 @@ struct ConnectionStateMachine { machine.state = .extendedQuery(queryState, connectionContext) return machine.modify(with: action) } - case .prepareStatement(var preparedState, let connectionContext) where !preparedState.isComplete: - return self.avoidingStateMachineCoW { machine -> ConnectionAction in - let action = preparedState.parseCompletedReceived() - machine.state = .prepareStatement(preparedState, connectionContext) - return machine.modify(with: action) - } default: return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.parseComplete)) } @@ -740,12 +694,6 @@ struct ConnectionStateMachine { machine.state = .extendedQuery(queryState, connectionContext) return machine.modify(with: action) } - case .prepareStatement(var preparedState, let connectionContext) where !preparedState.isComplete: - return self.avoidingStateMachineCoW { machine -> ConnectionAction in - let action = preparedState.parameterDescriptionReceived(description) - machine.state = .prepareStatement(preparedState, connectionContext) - return machine.modify(with: action) - } default: return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.parameterDescription(description))) } @@ -759,12 +707,6 @@ struct ConnectionStateMachine { machine.state = .extendedQuery(queryState, connectionContext) return machine.modify(with: action) } - case .prepareStatement(var preparedState, let connectionContext) where !preparedState.isComplete: - return self.avoidingStateMachineCoW { machine -> ConnectionAction in - let action = preparedState.rowDescriptionReceived(description) - machine.state = .prepareStatement(preparedState, connectionContext) - return machine.modify(with: action) - } default: return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.rowDescription(description))) } @@ -778,12 +720,6 @@ struct ConnectionStateMachine { machine.state = .extendedQuery(queryState, connectionContext) return machine.modify(with: action) } - case .prepareStatement(var preparedState, let connectionContext) where !preparedState.isComplete: - return self.avoidingStateMachineCoW { machine -> ConnectionAction in - let action = preparedState.noDataReceived() - machine.state = .prepareStatement(preparedState, connectionContext) - return machine.modify(with: action) - } default: return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.noData)) } @@ -909,6 +845,7 @@ struct ConnectionStateMachine { preconditionFailure("Expect to fail auth") } return .closeConnectionAndCleanup(cleanupContext) + case .extendedQuery(var queryStateMachine, _): let cleanupContext = self.setErrorAndCreateCleanupContext(error) @@ -921,9 +858,10 @@ struct ConnectionStateMachine { switch queryStateMachine.errorHappened(error) { case .sendParseDescribeBindExecuteSync, + .sendParseDescribeSync, .sendBindExecuteSync, .succeedQuery, - .succeedQueryNoRowsComming, + .succeedPreparedStatementCreation, .forwardRows, .forwardStreamComplete, .wait, @@ -935,26 +873,10 @@ struct ConnectionStateMachine { return .failQuery(queryContext, with: error, cleanupContext: cleanupContext) case .forwardStreamError(let error, let read): return .forwardStreamError(error, read: read, cleanupContext: cleanupContext) + case .failPreparedStatementCreation(let promise, with: let error): + return .failPreparedStatementCreation(promise, with: error, cleanupContext: cleanupContext) } - case .prepareStatement(var prepareStateMachine, _): - let cleanupContext = self.setErrorAndCreateCleanupContext(error) - - if prepareStateMachine.isComplete { - // in case the prepare state machine is complete all necessary actions have already - // been forwarded to the consumer. We can close and cleanup without caring about the - // substate machine. - return .closeConnectionAndCleanup(cleanupContext) - } - - switch prepareStateMachine.errorHappened(error) { - case .sendParseDescribeSync, - .succeedPreparedStatementCreation, - .read, - .wait: - preconditionFailure("Invalid state: \(self.state)") - case .failPreparedStatementCreation(let preparedStatementContext, with: let error): - return .failPreparedStatementCreation(preparedStatementContext, with: error, cleanupContext: cleanupContext) - } + case .closeCommand(var closeStateMachine, _): let cleanupContext = self.setErrorAndCreateCleanupContext(error) @@ -974,6 +896,7 @@ struct ConnectionStateMachine { case .failClose(let closeCommandContext, with: let error): return .failClose(closeCommandContext, with: error, cleanupContext: cleanupContext) } + case .error, .closing, .closed: // We might run into this case because of reentrancy. For example: After we received an // backend unexpected message, that we read of the wire, we bring this connection into @@ -1018,13 +941,6 @@ struct ConnectionStateMachine { machine.state = .extendedQuery(extendedQuery, connectionContext) return machine.modify(with: action) } - case .preparedStatement(let prepareContext): - return self.avoidingStateMachineCoW { machine -> ConnectionAction in - var prepareStatement = PrepareStatementStateMachine(createContext: prepareContext) - let action = prepareStatement.start() - machine.state = .prepareStatement(prepareStatement, connectionContext) - return machine.modify(with: action) - } case .closeCommand(let closeContext): return self.avoidingStateMachineCoW { machine -> ConnectionAction in var closeStateMachine = CloseStateMachine(closeContext: closeContext) @@ -1153,10 +1069,8 @@ extension ConnectionStateMachine { case .failQuery(let requestContext, with: let error): let cleanupContext = self.setErrorAndCreateCleanupContextIfNeeded(error) return .failQuery(requestContext, with: error, cleanupContext: cleanupContext) - case .succeedQuery(let requestContext, columns: let columns): - return .succeedQuery(requestContext, columns: columns) - case .succeedQueryNoRowsComming(let requestContext, let commandTag): - return .succeedQueryNoRowsComming(requestContext, commandTag: commandTag) + case .succeedQuery(let requestContext, with: let result): + return .succeedQuery(requestContext, with: result) case .forwardRows(let buffer): return .forwardRows(buffer) case .forwardStreamComplete(let buffer, let commandTag): @@ -1174,24 +1088,13 @@ extension ConnectionStateMachine { return .read case .wait: return .wait - } - } -} - -extension ConnectionStateMachine { - mutating func modify(with action: PrepareStatementStateMachine.Action) -> ConnectionStateMachine.ConnectionAction { - switch action { - case .sendParseDescribeSync(let name, let query): + case .sendParseDescribeSync(name: let name, query: let query): return .sendParseDescribeSync(name: name, query: query) - case .succeedPreparedStatementCreation(let prepareContext, with: let rowDescription): - return .succeedPreparedStatementCreation(prepareContext, with: rowDescription) - case .failPreparedStatementCreation(let prepareContext, with: let error): + case .succeedPreparedStatementCreation(let promise, with: let rowDescription): + return .succeedPreparedStatementCreation(promise, with: rowDescription) + case .failPreparedStatementCreation(let promise, with: let error): let cleanupContext = self.setErrorAndCreateCleanupContextIfNeeded(error) - return .failPreparedStatementCreation(prepareContext, with: error, cleanupContext: cleanupContext) - case .read: - return .read - case .wait: - return .wait + return .failPreparedStatementCreation(promise, with: error, cleanupContext: cleanupContext) } } } @@ -1282,8 +1185,6 @@ extension ConnectionStateMachine.State: CustomDebugStringConvertible { return ".readyForQuery(connectionContext: \(String(reflecting: connectionContext)))" case .extendedQuery(let subStateMachine, let connectionContext): return ".extendedQuery(\(String(reflecting: subStateMachine)), connectionContext: \(String(reflecting: connectionContext)))" - case .prepareStatement(let subStateMachine, let connectionContext): - return ".prepareStatement(\(String(reflecting: subStateMachine)), connectionContext: \(String(reflecting: connectionContext)))" case .closeCommand(let subStateMachine, let connectionContext): return ".closeCommand(\(String(reflecting: subStateMachine)), connectionContext: \(String(reflecting: connectionContext)))" case .error(let error): diff --git a/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift index 8b46fd0b..3a84031b 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift @@ -4,7 +4,7 @@ struct ExtendedQueryStateMachine { private enum State { case initialized(ExtendedQueryContext) - case parseDescribeBindExecuteSyncSent(ExtendedQueryContext) + case messagesSent(ExtendedQueryContext) case parseCompleteReceived(ExtendedQueryContext) case parameterDescriptionReceived(ExtendedQueryContext) @@ -26,15 +26,18 @@ struct ExtendedQueryStateMachine { enum Action { case sendParseDescribeBindExecuteSync(PostgresQuery) + case sendParseDescribeSync(name: String, query: String) case sendBindExecuteSync(PSQLExecuteStatement) // --- general actions - case failQuery(ExtendedQueryContext, with: PSQLError) - case succeedQuery(ExtendedQueryContext, columns: [RowDescription.Column]) - case succeedQueryNoRowsComming(ExtendedQueryContext, commandTag: String) + case failQuery(EventLoopPromise, with: PSQLError) + case succeedQuery(EventLoopPromise, with: QueryResult) case evaluateErrorAtConnectionLevel(PSQLError) + case succeedPreparedStatementCreation(EventLoopPromise, with: RowDescription?) + case failPreparedStatementCreation(EventLoopPromise, with: PSQLError) + // --- streaming actions // actions if query has requested next row but we are waiting for backend case forwardRows([DataRow]) @@ -59,13 +62,13 @@ struct ExtendedQueryStateMachine { } switch queryContext.query { - case .unnamed(let query): + case .unnamed(let query, _): return self.avoidingStateMachineCoW { state -> Action in - state = .parseDescribeBindExecuteSyncSent(queryContext) + state = .messagesSent(queryContext) return .sendParseDescribeBindExecuteSync(query) } - case .preparedStatement(let prepared): + case .executeStatement(let prepared, _): return self.avoidingStateMachineCoW { state -> Action in switch prepared.rowDescription { case .some(let rowDescription): @@ -75,6 +78,12 @@ struct ExtendedQueryStateMachine { } return .sendBindExecuteSync(prepared) } + + case .prepareStatement(let name, let query, _): + return self.avoidingStateMachineCoW { state -> Action in + state = .messagesSent(queryContext) + return .sendParseDescribeSync(name: name, query: query) + } } } @@ -83,7 +92,7 @@ struct ExtendedQueryStateMachine { case .initialized: preconditionFailure("Start must be called immediatly after the query was created") - case .parseDescribeBindExecuteSyncSent(let queryContext), + case .messagesSent(let queryContext), .parseCompleteReceived(let queryContext), .parameterDescriptionReceived(let queryContext), .rowDescriptionReceived(let queryContext, _), @@ -94,7 +103,13 @@ struct ExtendedQueryStateMachine { } self.isCancelled = true - return .failQuery(queryContext, with: .queryCancelled) + switch queryContext.query { + case .unnamed(_, let eventLoopPromise), .executeStatement(_, let eventLoopPromise): + return .failQuery(eventLoopPromise, with: .queryCancelled) + + case .prepareStatement(_, _, let eventLoopPromise): + return .failPreparedStatementCreation(eventLoopPromise, with: .queryCancelled) + } case .streaming(let columns, var streamStateMachine): precondition(!self.isCancelled) @@ -117,7 +132,7 @@ struct ExtendedQueryStateMachine { } mutating func parseCompletedReceived() -> Action { - guard case .parseDescribeBindExecuteSyncSent(let queryContext) = self.state else { + guard case .messagesSent(let queryContext) = self.state else { return self.setAndFireError(.unexpectedBackendMessage(.parseComplete)) } @@ -143,9 +158,18 @@ struct ExtendedQueryStateMachine { return self.setAndFireError(.unexpectedBackendMessage(.noData)) } - return self.avoidingStateMachineCoW { state -> Action in - state = .noDataMessageReceived(queryContext) - return .wait + switch queryContext.query { + case .unnamed, .executeStatement: + return self.avoidingStateMachineCoW { state -> Action in + state = .noDataMessageReceived(queryContext) + return .wait + } + + case .prepareStatement(_, _, let promise): + return self.avoidingStateMachineCoW { state -> Action in + state = .noDataMessageReceived(queryContext) + return .succeedPreparedStatementCreation(promise, with: nil) + } } } @@ -153,40 +177,56 @@ struct ExtendedQueryStateMachine { guard case .parameterDescriptionReceived(let queryContext) = self.state else { return self.setAndFireError(.unexpectedBackendMessage(.rowDescription(rowDescription))) } - - return self.avoidingStateMachineCoW { state -> Action in - // In Postgres extended queries we receive the `rowDescription` before we send the - // `Bind` message. Well actually it's vice versa, but this is only true since we do - // pipelining during a query. - // - // In the actual protocol description we receive a rowDescription before the Bind - - // In Postgres extended queries we always request the response rows to be returned in - // `.binary` format. - let columns = rowDescription.columns.map { column -> RowDescription.Column in - var column = column - column.format = .binary - return column - } + + // In Postgres extended queries we receive the `rowDescription` before we send the + // `Bind` message. Well actually it's vice versa, but this is only true since we do + // pipelining during a query. + // + // In the actual protocol description we receive a rowDescription before the Bind + + // In Postgres extended queries we always request the response rows to be returned in + // `.binary` format. + let columns = rowDescription.columns.map { column -> RowDescription.Column in + var column = column + column.format = .binary + return column + } + + self.avoidingStateMachineCoW { state in state = .rowDescriptionReceived(queryContext, columns) + } + + switch queryContext.query { + case .unnamed, .executeStatement: return .wait + + case .prepareStatement(_, _, let eventLoopPromise): + return .succeedPreparedStatementCreation(eventLoopPromise, with: rowDescription) } } mutating func bindCompleteReceived() -> Action { switch self.state { - case .rowDescriptionReceived(let context, let columns): - return self.avoidingStateMachineCoW { state -> Action in - state = .streaming(columns, .init()) - return .succeedQuery(context, columns: columns) + case .rowDescriptionReceived(let queryContext, let columns): + switch queryContext.query { + case .unnamed(_, let eventLoopPromise), .executeStatement(_, let eventLoopPromise): + return self.avoidingStateMachineCoW { state -> Action in + state = .streaming(columns, .init()) + let result = QueryResult(value: .rowDescription(columns), logger: queryContext.logger) + return .succeedQuery(eventLoopPromise, with: result) + } + + case .prepareStatement: + return .evaluateErrorAtConnectionLevel(.unexpectedBackendMessage(.bindComplete)) } + case .noDataMessageReceived(let queryContext): return self.avoidingStateMachineCoW { state -> Action in state = .bindCompleteReceived(queryContext) return .wait } case .initialized, - .parseDescribeBindExecuteSyncSent, + .messagesSent, .parseCompleteReceived, .parameterDescriptionReceived, .bindCompleteReceived, @@ -224,7 +264,7 @@ struct ExtendedQueryStateMachine { return .wait case .initialized, - .parseDescribeBindExecuteSyncSent, + .messagesSent, .parseCompleteReceived, .parameterDescriptionReceived, .noDataMessageReceived, @@ -241,9 +281,16 @@ struct ExtendedQueryStateMachine { mutating func commandCompletedReceived(_ commandTag: String) -> Action { switch self.state { case .bindCompleteReceived(let context): - return self.avoidingStateMachineCoW { state -> Action in - state = .commandComplete(commandTag: commandTag) - return .succeedQueryNoRowsComming(context, commandTag: commandTag) + switch context.query { + case .unnamed(_, let eventLoopPromise), .executeStatement(_, let eventLoopPromise): + return self.avoidingStateMachineCoW { state -> Action in + state = .commandComplete(commandTag: commandTag) + let result = QueryResult(value: .noRows(commandTag), logger: context.logger) + return .succeedQuery(eventLoopPromise, with: result) + } + + case .prepareStatement: + preconditionFailure("Invalid state: \(self.state)") } case .streaming(_, var demandStateMachine): @@ -258,7 +305,7 @@ struct ExtendedQueryStateMachine { return .wait case .initialized, - .parseDescribeBindExecuteSyncSent, + .messagesSent, .parseCompleteReceived, .parameterDescriptionReceived, .noDataMessageReceived, @@ -280,7 +327,7 @@ struct ExtendedQueryStateMachine { switch self.state { case .initialized: return self.setAndFireError(.unexpectedBackendMessage(.error(errorMessage))) - case .parseDescribeBindExecuteSyncSent, + case .messagesSent, .parseCompleteReceived, .parameterDescriptionReceived, .bindCompleteReceived: @@ -331,7 +378,7 @@ struct ExtendedQueryStateMachine { return .wait case .initialized, - .parseDescribeBindExecuteSyncSent, + .messagesSent, .parseCompleteReceived, .parameterDescriptionReceived, .noDataMessageReceived, @@ -354,7 +401,7 @@ struct ExtendedQueryStateMachine { .commandComplete, .drain, .error, - .parseDescribeBindExecuteSyncSent, + .messagesSent, .parseCompleteReceived, .parameterDescriptionReceived, .noDataMessageReceived, @@ -381,7 +428,7 @@ struct ExtendedQueryStateMachine { mutating func readEventCaught() -> Action { switch self.state { - case .parseDescribeBindExecuteSyncSent, + case .messagesSent, .parseCompleteReceived, .parameterDescriptionReceived, .noDataMessageReceived, @@ -417,7 +464,7 @@ struct ExtendedQueryStateMachine { private mutating func setAndFireError(_ error: PSQLError) -> Action { switch self.state { case .initialized(let context), - .parseDescribeBindExecuteSyncSent(let context), + .messagesSent(let context), .parseCompleteReceived(let context), .parameterDescriptionReceived(let context), .rowDescriptionReceived(let context, _), @@ -427,7 +474,12 @@ struct ExtendedQueryStateMachine { if self.isCancelled { return .evaluateErrorAtConnectionLevel(error) } else { - return .failQuery(context, with: error) + switch context.query { + case .unnamed(_, let eventLoopPromise), .executeStatement(_, let eventLoopPromise): + return .failQuery(eventLoopPromise, with: error) + case .prepareStatement(_, _, let eventLoopPromise): + return .failPreparedStatementCreation(eventLoopPromise, with: error) + } } case .drain: @@ -455,11 +507,22 @@ struct ExtendedQueryStateMachine { var isComplete: Bool { switch self.state { - case .commandComplete, - .error: + case .commandComplete, .error: return true - default: + + case .noDataMessageReceived(let context), .rowDescriptionReceived(let context, _): + switch context.query { + case .prepareStatement: + return true + case .unnamed, .executeStatement: + return false + } + + case .initialized, .messagesSent, .parseCompleteReceived, .parameterDescriptionReceived, .bindCompleteReceived, .streaming, .drain: return false + + case .modifying: + preconditionFailure("Invalid state: \(self.state)") } } } diff --git a/Sources/PostgresNIO/New/Connection State Machine/PrepareStatementStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/PrepareStatementStateMachine.swift deleted file mode 100644 index 5b65fc90..00000000 --- a/Sources/PostgresNIO/New/Connection State Machine/PrepareStatementStateMachine.swift +++ /dev/null @@ -1,147 +0,0 @@ - -struct PrepareStatementStateMachine { - - enum State { - case initialized(PrepareStatementContext) - case parseDescribeSent(PrepareStatementContext) - - case parseCompleteReceived(PrepareStatementContext) - case parameterDescriptionReceived(PrepareStatementContext) - case rowDescriptionReceived - case noDataMessageReceived - - case error(PSQLError) - } - - enum Action { - case sendParseDescribeSync(name: String, query: String) - case succeedPreparedStatementCreation(PrepareStatementContext, with: RowDescription?) - case failPreparedStatementCreation(PrepareStatementContext, with: PSQLError) - - case read - case wait - } - - var state: State - - init(createContext: PrepareStatementContext) { - self.state = .initialized(createContext) - } - - #if DEBUG - /// for testing purposes only - init(_ state: State) { - self.state = state - } - #endif - - mutating func start() -> Action { - guard case .initialized(let createContext) = self.state else { - preconditionFailure("Start must only be called after the query has been initialized") - } - - self.state = .parseDescribeSent(createContext) - - return .sendParseDescribeSync(name: createContext.name, query: createContext.query) - } - - mutating func parseCompletedReceived() -> Action { - guard case .parseDescribeSent(let createContext) = self.state else { - return self.setAndFireError(.unexpectedBackendMessage(.parseComplete)) - } - - self.state = .parseCompleteReceived(createContext) - return .wait - } - - mutating func parameterDescriptionReceived(_ parameterDescription: PostgresBackendMessage.ParameterDescription) -> Action { - guard case .parseCompleteReceived(let createContext) = self.state else { - return self.setAndFireError(.unexpectedBackendMessage(.parameterDescription(parameterDescription))) - } - - self.state = .parameterDescriptionReceived(createContext) - return .wait - } - - mutating func noDataReceived() -> Action { - guard case .parameterDescriptionReceived(let queryContext) = self.state else { - return self.setAndFireError(.unexpectedBackendMessage(.noData)) - } - - self.state = .noDataMessageReceived - return .succeedPreparedStatementCreation(queryContext, with: nil) - } - - mutating func rowDescriptionReceived(_ rowDescription: RowDescription) -> Action { - guard case .parameterDescriptionReceived(let queryContext) = self.state else { - return self.setAndFireError(.unexpectedBackendMessage(.rowDescription(rowDescription))) - } - - self.state = .rowDescriptionReceived - return .succeedPreparedStatementCreation(queryContext, with: rowDescription) - } - - mutating func errorReceived(_ errorMessage: PostgresBackendMessage.ErrorResponse) -> Action { - let error = PSQLError.server(errorMessage) - switch self.state { - case .initialized: - return self.setAndFireError(.unexpectedBackendMessage(.error(errorMessage))) - - case .parseDescribeSent, - .parseCompleteReceived, - .parameterDescriptionReceived: - return self.setAndFireError(error) - - case .rowDescriptionReceived, - .noDataMessageReceived, - .error: - preconditionFailure(""" - This state must not be reached. If the prepared statement `.isComplete`, the - ConnectionStateMachine must not send any further events to the substate machine. - """) - } - } - - mutating func errorHappened(_ error: PSQLError) -> Action { - return self.setAndFireError(error) - } - - private mutating func setAndFireError(_ error: PSQLError) -> Action { - switch self.state { - case .initialized(let context), - .parseDescribeSent(let context), - .parseCompleteReceived(let context), - .parameterDescriptionReceived(let context): - self.state = .error(error) - return .failPreparedStatementCreation(context, with: error) - case .rowDescriptionReceived, - .noDataMessageReceived, - .error: - preconditionFailure(""" - This state must not be reached. If the prepared statement `.isComplete`, the - ConnectionStateMachine must not send any further events to the substate machine. - """) - } - } - - // MARK: Channel actions - - mutating func readEventCaught() -> Action { - return .read - } - - var isComplete: Bool { - switch self.state { - case .rowDescriptionReceived, - .noDataMessageReceived, - .error: - return true - case .initialized, - .parseDescribeSent, - .parseCompleteReceived, - .parameterDescriptionReceived: - return false - } - } - -} diff --git a/Sources/PostgresNIO/New/PSQLRowStream.swift b/Sources/PostgresNIO/New/PSQLRowStream.swift index 4c842275..b008d185 100644 --- a/Sources/PostgresNIO/New/PSQLRowStream.swift +++ b/Sources/PostgresNIO/New/PSQLRowStream.swift @@ -1,12 +1,23 @@ import NIOCore import Logging +struct QueryResult { + enum Value: Equatable { + case noRows(String) + case rowDescription([RowDescription.Column]) + } + + var value: Value + + var logger: Logger +} + // Thread safety is guaranteed in the RowStream through dispatching onto the NIO EventLoop. final class PSQLRowStream: @unchecked Sendable { private typealias AsyncSequenceSource = NIOThrowingAsyncSequenceProducer.Source - enum RowSource { - case stream(PSQLRowsDataSource) + enum Source { + case stream([RowDescription.Column], PSQLRowsDataSource) case noRows(Result) } @@ -31,27 +42,28 @@ final class PSQLRowStream: @unchecked Sendable { private let lookupTable: [String: Int] private var downstreamState: DownstreamState - init(rowDescription: [RowDescription.Column], - queryContext: ExtendedQueryContext, - eventLoop: EventLoop, - rowSource: RowSource) - { + init( + source: Source, + eventLoop: EventLoop, + logger: Logger + ) { let bufferState: BufferState - switch rowSource { - case .stream(let dataSource): + switch source { + case .stream(let rowDescription, let dataSource): + self.rowDescription = rowDescription bufferState = .streaming(buffer: .init(), dataSource: dataSource) case .noRows(.success(let commandTag)): + self.rowDescription = [] bufferState = .finished(buffer: .init(), commandTag: commandTag) case .noRows(.failure(let error)): + self.rowDescription = [] bufferState = .failure(error) } self.downstreamState = .waitingForConsumer(bufferState) self.eventLoop = eventLoop - self.logger = queryContext.logger - - self.rowDescription = rowDescription + self.logger = logger var lookup = [String: Int]() lookup.reserveCapacity(rowDescription.count) diff --git a/Sources/PostgresNIO/New/PSQLTask.swift b/Sources/PostgresNIO/New/PSQLTask.swift index 26312c0c..f5de6561 100644 --- a/Sources/PostgresNIO/New/PSQLTask.swift +++ b/Sources/PostgresNIO/New/PSQLTask.swift @@ -3,7 +3,6 @@ import NIOCore enum HandlerTask { case extendedQuery(ExtendedQueryContext) - case preparedStatement(PrepareStatementContext) case closeCommand(CloseCommandContext) case startListening(NotificationListener) case cancelListening(String, Int) @@ -11,16 +10,19 @@ enum HandlerTask { enum PSQLTask { case extendedQuery(ExtendedQueryContext) - case preparedStatement(PrepareStatementContext) case closeCommand(CloseCommandContext) func failWithError(_ error: PSQLError) { switch self { case .extendedQuery(let extendedQueryContext): - extendedQueryContext.promise.fail(error) - - case .preparedStatement(let createPreparedStatementContext): - createPreparedStatementContext.promise.fail(error) + switch extendedQueryContext.query { + case .unnamed(_, let eventLoopPromise): + eventLoopPromise.fail(error) + case .executeStatement(_, let eventLoopPromise): + eventLoopPromise.fail(error) + case .prepareStatement(_, _, let eventLoopPromise): + eventLoopPromise.fail(error) + } case .closeCommand(let closeCommandContext): closeCommandContext.promise.fail(error) @@ -30,49 +32,40 @@ enum PSQLTask { final class ExtendedQueryContext { enum Query { - case unnamed(PostgresQuery) - case preparedStatement(PSQLExecuteStatement) + case unnamed(PostgresQuery, EventLoopPromise) + case executeStatement(PSQLExecuteStatement, EventLoopPromise) + case prepareStatement(name: String, query: String, EventLoopPromise) } let query: Query let logger: Logger - - let promise: EventLoopPromise - init(query: PostgresQuery, - logger: Logger, - promise: EventLoopPromise) - { - self.query = .unnamed(query) + init( + query: PostgresQuery, + logger: Logger, + promise: EventLoopPromise + ) { + self.query = .unnamed(query, promise) self.logger = logger - self.promise = promise } - init(executeStatement: PSQLExecuteStatement, - logger: Logger, - promise: EventLoopPromise) - { - self.query = .preparedStatement(executeStatement) + init( + executeStatement: PSQLExecuteStatement, + logger: Logger, + promise: EventLoopPromise + ) { + self.query = .executeStatement(executeStatement, promise) self.logger = logger - self.promise = promise } -} -final class PrepareStatementContext { - let name: String - let query: String - let logger: Logger - let promise: EventLoopPromise - - init(name: String, - query: String, - logger: Logger, - promise: EventLoopPromise) - { - self.name = name - self.query = query + init( + name: String, + query: String, + logger: Logger, + promise: EventLoopPromise + ) { + self.query = .prepareStatement(name: name, query: query, promise) self.logger = logger - self.promise = promise } } @@ -83,8 +76,8 @@ final class CloseCommandContext { init(target: CloseTarget, logger: Logger, - promise: EventLoopPromise) - { + promise: EventLoopPromise + ) { self.target = target self.logger = logger self.promise = promise diff --git a/Sources/PostgresNIO/New/PostgresChannelHandler.swift b/Sources/PostgresNIO/New/PostgresChannelHandler.swift index 32c35927..abfa5aeb 100644 --- a/Sources/PostgresNIO/New/PostgresChannelHandler.swift +++ b/Sources/PostgresNIO/New/PostgresChannelHandler.swift @@ -206,8 +206,6 @@ final class PostgresChannelHandler: ChannelDuplexHandler { psqlTask = .closeCommand(command) case .extendedQuery(let query): psqlTask = .extendedQuery(query) - case .preparedStatement(let statement): - psqlTask = .preparedStatement(statement) case .startListening(let listener): switch self.listenState.startListening(listener) { @@ -326,12 +324,10 @@ final class PostgresChannelHandler: ChannelDuplexHandler { self.sendBindExecuteAndSyncMessage(executeStatement: executeStatement, context: context) case .sendParseDescribeBindExecuteSync(let query): self.sendParseDescribeBindExecuteAndSyncMessage(query: query, context: context) - case .succeedQuery(let queryContext, columns: let columns): - self.succeedQueryWithRowStream(queryContext, columns: columns, context: context) - case .succeedQueryNoRowsComming(let queryContext, let commandTag): - self.succeedQueryWithoutRowStream(queryContext, commandTag: commandTag, context: context) - case .failQuery(let queryContext, with: let error, let cleanupContext): - queryContext.promise.fail(error) + case .succeedQuery(let promise, with: let result): + self.succeedQuery(promise, result: result, context: context) + case .failQuery(let promise, with: let error, let cleanupContext): + promise.fail(error) if let cleanupContext = cleanupContext { self.closeConnectionAndCleanup(cleanupContext, context: context) } @@ -383,10 +379,10 @@ final class PostgresChannelHandler: ChannelDuplexHandler { context.writeAndFlush(self.wrapOutboundOut(self.encoder.flushBuffer()), promise: nil) } context.close(mode: .all, promise: promise) - case .succeedPreparedStatementCreation(let preparedContext, with: let rowDescription): - preparedContext.promise.succeed(rowDescription) - case .failPreparedStatementCreation(let preparedContext, with: let error, let cleanupContext): - preparedContext.promise.fail(error) + case .succeedPreparedStatementCreation(let promise, with: let rowDescription): + promise.succeed(rowDescription) + case .failPreparedStatementCreation(let promise, with: let error, let cleanupContext): + promise.fail(error) if let cleanupContext = cleanupContext { self.closeConnectionAndCleanup(cleanupContext, context: context) } @@ -510,33 +506,30 @@ final class PostgresChannelHandler: ChannelDuplexHandler { context.writeAndFlush(self.wrapOutboundOut(self.encoder.flushBuffer()), promise: nil) } - private func succeedQueryWithRowStream( - _ queryContext: ExtendedQueryContext, - columns: [RowDescription.Column], + private func succeedQuery( + _ promise: EventLoopPromise, + result: QueryResult, context: ChannelHandlerContext ) { - let rows = PSQLRowStream( - rowDescription: columns, - queryContext: queryContext, - eventLoop: context.channel.eventLoop, - rowSource: .stream(self)) - - self.rowStream = rows - queryContext.promise.succeed(rows) - } - - private func succeedQueryWithoutRowStream( - _ queryContext: ExtendedQueryContext, - commandTag: String, - context: ChannelHandlerContext - ) { - let rows = PSQLRowStream( - rowDescription: [], - queryContext: queryContext, - eventLoop: context.channel.eventLoop, - rowSource: .noRows(.success(commandTag)) - ) - queryContext.promise.succeed(rows) + let rows: PSQLRowStream + switch result.value { + case .rowDescription(let columns): + rows = PSQLRowStream( + source: .stream(columns, self), + eventLoop: context.channel.eventLoop, + logger: result.logger + ) + self.rowStream = rows + + case .noRows(let commandTag): + rows = PSQLRowStream( + source: .noRows(.success(commandTag)), + eventLoop: context.channel.eventLoop, + logger: result.logger + ) + } + + promise.succeed(rows) } private func closeConnectionAndCleanup( diff --git a/Tests/PostgresNIOTests/New/Connection State Machine/ConnectionStateMachineTests.swift b/Tests/PostgresNIOTests/New/Connection State Machine/ConnectionStateMachineTests.swift index d5d4ecb1..5fd3bc20 100644 --- a/Tests/PostgresNIOTests/New/Connection State Machine/ConnectionStateMachineTests.swift +++ b/Tests/PostgresNIOTests/New/Connection State Machine/ConnectionStateMachineTests.swift @@ -180,9 +180,9 @@ class ConnectionStateMachineTests: XCTestCase { XCTAssertEqual(state.errorReceived(.init(fields: fields)), .closeConnectionAndCleanup(.init(action: .close, tasks: [.extendedQuery(extendedQueryContext)], error: .server(.init(fields: fields)), closePromise: nil))) - XCTAssertNil(extendedQueryContext.promise.futureResult._value) - + XCTAssertNil(queryPromise.futureResult._value) + // make sure we don't crash - extendedQueryContext.promise.fail(PSQLError.server(.init(fields: fields))) + queryPromise.fail(PSQLError.server(.init(fields: fields))) } } diff --git a/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift b/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift index eac46e5f..40e32468 100644 --- a/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift +++ b/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift @@ -20,7 +20,7 @@ class ExtendedQueryStateMachineTests: XCTestCase { XCTAssertEqual(state.parameterDescriptionReceived(.init(dataTypes: [.int8])), .wait) XCTAssertEqual(state.noDataReceived(), .wait) XCTAssertEqual(state.bindCompleteReceived(), .wait) - XCTAssertEqual(state.commandCompletedReceived("DELETE 1"), .succeedQueryNoRowsComming(queryContext, commandTag: "DELETE 1")) + XCTAssertEqual(state.commandCompletedReceived("DELETE 1"), .succeedQuery(promise, with: .init(value: .noRows("DELETE 1"), logger: logger))) XCTAssertEqual(state.readyForQueryReceived(.idle), .fireEventReadyForQuery) } @@ -49,7 +49,7 @@ class ExtendedQueryStateMachineTests: XCTestCase { } XCTAssertEqual(state.rowDescriptionReceived(.init(columns: input)), .wait) - XCTAssertEqual(state.bindCompleteReceived(), .succeedQuery(queryContext, columns: expected)) + XCTAssertEqual(state.bindCompleteReceived(), .succeedQuery(promise, with: .init(value: .rowDescription(expected), logger: logger))) let row1: DataRow = [ByteBuffer(string: "test1")] XCTAssertEqual(state.dataRowReceived(row1), .wait) XCTAssertEqual(state.channelReadComplete(), .forwardRows([row1])) @@ -93,7 +93,7 @@ class ExtendedQueryStateMachineTests: XCTestCase { let psqlError = PSQLError.unexpectedBackendMessage(.authentication(.ok)) XCTAssertEqual(state.authenticationMessageReceived(.ok), - .failQuery(queryContext, with: psqlError, cleanupContext: .init(action: .close, tasks: [], error: psqlError, closePromise: nil))) + .failQuery(promise, with: psqlError, cleanupContext: .init(action: .close, tasks: [], error: psqlError, closePromise: nil))) } func testExtendedQueryIsCancelledImmediatly() { @@ -121,7 +121,7 @@ class ExtendedQueryStateMachineTests: XCTestCase { } XCTAssertEqual(state.rowDescriptionReceived(.init(columns: input)), .wait) - XCTAssertEqual(state.bindCompleteReceived(), .succeedQuery(queryContext, columns: expected)) + XCTAssertEqual(state.bindCompleteReceived(), .succeedQuery(promise, with: .init(value: .rowDescription(expected), logger: logger))) XCTAssertEqual(state.cancelQueryStream(), .forwardStreamError(.queryCancelled, read: false, cleanupContext: nil)) XCTAssertEqual(state.dataRowReceived([ByteBuffer(string: "test1")]), .wait) XCTAssertEqual(state.channelReadComplete(), .wait) @@ -165,7 +165,7 @@ class ExtendedQueryStateMachineTests: XCTestCase { } XCTAssertEqual(state.rowDescriptionReceived(.init(columns: input)), .wait) - XCTAssertEqual(state.bindCompleteReceived(), .succeedQuery(queryContext, columns: expected)) + XCTAssertEqual(state.bindCompleteReceived(), .succeedQuery(promise, with: .init(value: .rowDescription(expected), logger: logger))) let row1: DataRow = [ByteBuffer(string: "test1")] XCTAssertEqual(state.dataRowReceived(row1), .wait) XCTAssertEqual(state.channelReadComplete(), .forwardRows([row1])) @@ -207,7 +207,7 @@ class ExtendedQueryStateMachineTests: XCTestCase { } XCTAssertEqual(state.rowDescriptionReceived(.init(columns: input)), .wait) - XCTAssertEqual(state.bindCompleteReceived(), .succeedQuery(queryContext, columns: expected)) + XCTAssertEqual(state.bindCompleteReceived(), .succeedQuery(promise, with: .init(value: .rowDescription(expected), logger: logger))) let dataRows1: [DataRow] = [ [ByteBuffer(string: "test1")], [ByteBuffer(string: "test2")], @@ -251,7 +251,7 @@ class ExtendedQueryStateMachineTests: XCTestCase { let serverError = PostgresBackendMessage.ErrorResponse(fields: [.severity: "Error", .sqlState: "123"]) XCTAssertEqual( - state.errorReceived(serverError), .failQuery(queryContext, with: .server(serverError), cleanupContext: .none) + state.errorReceived(serverError), .failQuery(promise, with: .server(serverError), cleanupContext: .none) ) XCTAssertEqual(state.readyForQueryReceived(.idle), .fireEventReadyForQuery) @@ -269,7 +269,7 @@ class ExtendedQueryStateMachineTests: XCTestCase { XCTAssertEqual(state.enqueue(task: .extendedQuery(queryContext)), .sendParseDescribeBindExecuteSync(query)) XCTAssertEqual(state.parseCompleteReceived(), .wait) XCTAssertEqual(state.parameterDescriptionReceived(.init(dataTypes: [.int8])), .wait) - XCTAssertEqual(state.cancelQueryStream(), .failQuery(queryContext, with: .queryCancelled, cleanupContext: .none)) + XCTAssertEqual(state.cancelQueryStream(), .failQuery(promise, with: .queryCancelled, cleanupContext: .none)) let serverError = PostgresBackendMessage.ErrorResponse(fields: [.severity: "Error", .sqlState: "123"]) XCTAssertEqual(state.errorReceived(serverError), .wait) diff --git a/Tests/PostgresNIOTests/New/Connection State Machine/PrepareStatementStateMachineTests.swift b/Tests/PostgresNIOTests/New/Connection State Machine/PrepareStatementStateMachineTests.swift index 6cff280e..6a08afeb 100644 --- a/Tests/PostgresNIOTests/New/Connection State Machine/PrepareStatementStateMachineTests.swift +++ b/Tests/PostgresNIOTests/New/Connection State Machine/PrepareStatementStateMachineTests.swift @@ -3,7 +3,6 @@ import NIOEmbedded @testable import PostgresNIO class PrepareStatementStateMachineTests: XCTestCase { - func testCreatePreparedStatementReturningRowDescription() { var state = ConnectionStateMachine.readyForQuery() @@ -12,10 +11,11 @@ class PrepareStatementStateMachineTests: XCTestCase { let name = "haha" let query = #"SELECT id FROM users WHERE id = $1 "# - let prepareStatementContext = PrepareStatementContext( - name: name, query: query, logger: .psqlTest, promise: promise) - - XCTAssertEqual(state.enqueue(task: .preparedStatement(prepareStatementContext)), + let prepareStatementContext = ExtendedQueryContext( + name: name, query: query, logger: .psqlTest, promise: promise + ) + + XCTAssertEqual(state.enqueue(task: .extendedQuery(prepareStatementContext)), .sendParseDescribeSync(name: name, query: query)) XCTAssertEqual(state.parseCompleteReceived(), .wait) XCTAssertEqual(state.parameterDescriptionReceived(.init(dataTypes: [.int8])), .wait) @@ -25,7 +25,7 @@ class PrepareStatementStateMachineTests: XCTestCase { ] XCTAssertEqual(state.rowDescriptionReceived(.init(columns: columns)), - .succeedPreparedStatementCreation(prepareStatementContext, with: .init(columns: columns))) + .succeedPreparedStatementCreation(promise, with: .init(columns: columns))) XCTAssertEqual(state.readyForQueryReceived(.idle), .fireEventReadyForQuery) } @@ -37,25 +37,42 @@ class PrepareStatementStateMachineTests: XCTestCase { let name = "haha" let query = #"DELETE FROM users WHERE id = $1 "# - let prepareStatementContext = PrepareStatementContext( - name: name, query: query, logger: .psqlTest, promise: promise) - - XCTAssertEqual(state.enqueue(task: .preparedStatement(prepareStatementContext)), + let prepareStatementContext = ExtendedQueryContext( + name: name, query: query, logger: .psqlTest, promise: promise + ) + + XCTAssertEqual(state.enqueue(task: .extendedQuery(prepareStatementContext)), .sendParseDescribeSync(name: name, query: query)) XCTAssertEqual(state.parseCompleteReceived(), .wait) XCTAssertEqual(state.parameterDescriptionReceived(.init(dataTypes: [.int8])), .wait) XCTAssertEqual(state.noDataReceived(), - .succeedPreparedStatementCreation(prepareStatementContext, with: nil)) + .succeedPreparedStatementCreation(promise, with: nil)) XCTAssertEqual(state.readyForQueryReceived(.idle), .fireEventReadyForQuery) } func testErrorReceivedAfter() { - let connectionContext = ConnectionStateMachine.createConnectionContext() - var state = ConnectionStateMachine(.prepareStatement(.init(.noDataMessageReceived), connectionContext)) - + var state = ConnectionStateMachine.readyForQuery() + + let promise = EmbeddedEventLoop().makePromise(of: RowDescription?.self) + promise.fail(PSQLError.uncleanShutdown) // we don't care about the error at all. + + let name = "haha" + let query = #"DELETE FROM users WHERE id = $1 "# + let prepareStatementContext = ExtendedQueryContext( + name: name, query: query, logger: .psqlTest, promise: promise + ) + + XCTAssertEqual(state.enqueue(task: .extendedQuery(prepareStatementContext)), + .sendParseDescribeSync(name: name, query: query)) + XCTAssertEqual(state.parseCompleteReceived(), .wait) + XCTAssertEqual(state.parameterDescriptionReceived(.init(dataTypes: [.int8])), .wait) + + XCTAssertEqual(state.noDataReceived(), + .succeedPreparedStatementCreation(promise, with: nil)) + XCTAssertEqual(state.readyForQueryReceived(.idle), .fireEventReadyForQuery) + XCTAssertEqual(state.authenticationMessageReceived(.ok), .closeConnectionAndCleanup(.init(action: .close, tasks: [], error: .unexpectedBackendMessage(.authentication(.ok)), closePromise: nil))) } - } diff --git a/Tests/PostgresNIOTests/New/Extensions/ConnectionAction+TestUtils.swift b/Tests/PostgresNIOTests/New/Extensions/ConnectionAction+TestUtils.swift index 72420798..febeee37 100644 --- a/Tests/PostgresNIOTests/New/Extensions/ConnectionAction+TestUtils.swift +++ b/Tests/PostgresNIOTests/New/Extensions/ConnectionAction+TestUtils.swift @@ -25,13 +25,10 @@ extension ConnectionStateMachine.ConnectionAction: Equatable { return lquery == rquery case (.fireEventReadyForQuery, .fireEventReadyForQuery): return true - - case (.succeedQueryNoRowsComming(let lhsContext, let lhsCommandTag), .succeedQueryNoRowsComming(let rhsContext, let rhsCommandTag)): - return lhsContext === rhsContext && lhsCommandTag == rhsCommandTag - case (.succeedQuery(let lhsContext, let lhsRowDescription), .succeedQuery(let rhsContext, let rhsRowDescription)): - return lhsContext === rhsContext && lhsRowDescription == rhsRowDescription - case (.failQuery(let lhsContext, let lhsError, let lhsCleanupContext), .failQuery(let rhsContext, let rhsError, let rhsCleanupContext)): - return lhsContext === rhsContext && lhsError == rhsError && lhsCleanupContext == rhsCleanupContext + case (.succeedQuery(let lhsPromise, let lhsResult), .succeedQuery(let rhsPromise, let rhsResult)): + return lhsPromise.futureResult === rhsPromise.futureResult && lhsResult.value == rhsResult.value + case (.failQuery(let lhsPromise, let lhsError, let lhsCleanupContext), .failQuery(let rhsPromise, let rhsError, let rhsCleanupContext)): + return lhsPromise.futureResult === rhsPromise.futureResult && lhsError == rhsError && lhsCleanupContext == rhsCleanupContext case (.forwardRows(let lhsRows), .forwardRows(let rhsRows)): return lhsRows == rhsRows case (.forwardStreamComplete(let lhsBuffer, let lhsCommandTag), .forwardStreamComplete(let rhsBuffer, let rhsCommandTag)): @@ -40,8 +37,8 @@ extension ConnectionStateMachine.ConnectionAction: Equatable { return lhsError == rhsError && lhsRead == rhsRead && lhsCleanupContext == rhsCleanupContext case (.sendParseDescribeSync(let lhsName, let lhsQuery), .sendParseDescribeSync(let rhsName, let rhsQuery)): return lhsName == rhsName && lhsQuery == rhsQuery - case (.succeedPreparedStatementCreation(let lhsContext, let lhsRowDescription), .succeedPreparedStatementCreation(let rhsContext, let rhsRowDescription)): - return lhsContext === rhsContext && lhsRowDescription == rhsRowDescription + case (.succeedPreparedStatementCreation(let lhsPromise, let lhsRowDescription), .succeedPreparedStatementCreation(let rhsPromise, let rhsRowDescription)): + return lhsPromise.futureResult === rhsPromise.futureResult && lhsRowDescription == rhsRowDescription case (.fireChannelInactive, .fireChannelInactive): return true default: @@ -110,8 +107,6 @@ extension PSQLTask: Equatable { switch (lhs, rhs) { case (.extendedQuery(let lhs), .extendedQuery(let rhs)): return lhs === rhs - case (.preparedStatement(let lhs), .preparedStatement(let rhs)): - return lhs === rhs case (.closeCommand(let lhs), .closeCommand(let rhs)): return lhs === rhs default: diff --git a/Tests/PostgresNIOTests/New/PSQLRowStreamTests.swift b/Tests/PostgresNIOTests/New/PSQLRowStreamTests.swift index f27ff060..1af35fac 100644 --- a/Tests/PostgresNIOTests/New/PSQLRowStreamTests.swift +++ b/Tests/PostgresNIOTests/New/PSQLRowStreamTests.swift @@ -5,44 +5,27 @@ import XCTest import NIOCore import NIOEmbedded -class PSQLRowStreamTests: XCTestCase { +final class PSQLRowStreamTests: XCTestCase { + let logger = Logger(label: "PSQLRowStreamTests") + let eventLoop = EmbeddedEventLoop() + func testEmptyStream() { - let logger = Logger(label: "test") - let eventLoop = EmbeddedEventLoop() - let promise = eventLoop.makePromise(of: PSQLRowStream.self) - - let queryContext = ExtendedQueryContext( - query: "INSERT INTO foo bar;", logger: logger, promise: promise - ) - let stream = PSQLRowStream( - rowDescription: [], - queryContext: queryContext, - eventLoop: eventLoop, - rowSource: .noRows(.success("INSERT 0 1")) + source: .noRows(.success("INSERT 0 1")), + eventLoop: self.eventLoop, + logger: self.logger ) - promise.succeed(stream) XCTAssertEqual(try stream.all().wait(), []) XCTAssertEqual(stream.commandTag, "INSERT 0 1") } func testFailedStream() { - let logger = Logger(label: "test") - let eventLoop = EmbeddedEventLoop() - let promise = eventLoop.makePromise(of: PSQLRowStream.self) - - let queryContext = ExtendedQueryContext( - query: "SELECT * FROM test;", logger: logger, promise: promise - ) - let stream = PSQLRowStream( - rowDescription: [], - queryContext: queryContext, - eventLoop: eventLoop, - rowSource: .noRows(.failure(PSQLError.connectionClosed)) + source: .noRows(.failure(PSQLError.connectionClosed)), + eventLoop: self.eventLoop, + logger: self.logger ) - promise.succeed(stream) XCTAssertThrowsError(try stream.all().wait()) { XCTAssertEqual($0 as? PSQLError, .connectionClosed) @@ -50,24 +33,15 @@ class PSQLRowStreamTests: XCTestCase { } func testGetArrayAfterStreamHasFinished() { - let logger = Logger(label: "test") - let eventLoop = EmbeddedEventLoop() - let promise = eventLoop.makePromise(of: PSQLRowStream.self) - - let queryContext = ExtendedQueryContext( - query: "SELECT * FROM test;", logger: logger, promise: promise - ) - let dataSource = CountingDataSource() let stream = PSQLRowStream( - rowDescription: [ - self.makeColumnDescription(name: "foo", dataType: .text, format: .binary) - ], - queryContext: queryContext, - eventLoop: eventLoop, - rowSource: .stream(dataSource) + source: .stream( + [self.makeColumnDescription(name: "foo", dataType: .text, format: .binary)], + dataSource + ), + eventLoop: self.eventLoop, + logger: self.logger ) - promise.succeed(stream) XCTAssertEqual(dataSource.hitDemand, 0) XCTAssertEqual(dataSource.hitCancel, 0) @@ -89,22 +63,15 @@ class PSQLRowStreamTests: XCTestCase { } func testGetArrayBeforeStreamHasFinished() { - let logger = Logger(label: "test") - let eventLoop = EmbeddedEventLoop() - let promise = eventLoop.makePromise(of: PSQLRowStream.self) - - let queryContext = ExtendedQueryContext( - query: "SELECT * FROM test;", logger: logger, promise: promise) let dataSource = CountingDataSource() let stream = PSQLRowStream( - rowDescription: [ - self.makeColumnDescription(name: "foo", dataType: .text, format: .binary) - ], - queryContext: queryContext, - eventLoop: eventLoop, - rowSource: .stream(dataSource) + source: .stream( + [self.makeColumnDescription(name: "foo", dataType: .text, format: .binary)], + dataSource + ), + eventLoop: self.eventLoop, + logger: self.logger ) - promise.succeed(stream) XCTAssertEqual(dataSource.hitDemand, 0) XCTAssertEqual(dataSource.hitCancel, 0) @@ -139,24 +106,15 @@ class PSQLRowStreamTests: XCTestCase { } func testOnRowAfterStreamHasFinished() { - let logger = Logger(label: "test") - let eventLoop = EmbeddedEventLoop() - let promise = eventLoop.makePromise(of: PSQLRowStream.self) - - let queryContext = ExtendedQueryContext( - query: "SELECT * FROM test;", logger: logger, promise: promise - ) - let dataSource = CountingDataSource() let stream = PSQLRowStream( - rowDescription: [ - self.makeColumnDescription(name: "foo", dataType: .text, format: .binary) - ], - queryContext: queryContext, - eventLoop: eventLoop, - rowSource: .stream(dataSource) + source: .stream( + [self.makeColumnDescription(name: "foo", dataType: .text, format: .binary)], + dataSource + ), + eventLoop: self.eventLoop, + logger: self.logger ) - promise.succeed(stream) XCTAssertEqual(dataSource.hitDemand, 0) XCTAssertEqual(dataSource.hitCancel, 0) @@ -183,24 +141,15 @@ class PSQLRowStreamTests: XCTestCase { } func testOnRowThrowsErrorOnInitialBatch() { - let logger = Logger(label: "test") - let eventLoop = EmbeddedEventLoop() - let promise = eventLoop.makePromise(of: PSQLRowStream.self) - - let queryContext = ExtendedQueryContext( - query: "SELECT * FROM test;", logger: logger, promise: promise - ) - let dataSource = CountingDataSource() let stream = PSQLRowStream( - rowDescription: [ - self.makeColumnDescription(name: "foo", dataType: .text, format: .binary) - ], - queryContext: queryContext, - eventLoop: eventLoop, - rowSource: .stream(dataSource) + source: .stream( + [self.makeColumnDescription(name: "foo", dataType: .text, format: .binary)], + dataSource + ), + eventLoop: self.eventLoop, + logger: self.logger ) - promise.succeed(stream) XCTAssertEqual(dataSource.hitDemand, 0) XCTAssertEqual(dataSource.hitCancel, 0) @@ -232,24 +181,15 @@ class PSQLRowStreamTests: XCTestCase { func testOnRowBeforeStreamHasFinished() { - let logger = Logger(label: "test") - let eventLoop = EmbeddedEventLoop() - let promise = eventLoop.makePromise(of: PSQLRowStream.self) - - let queryContext = ExtendedQueryContext( - query: "SELECT * FROM test;", logger: logger, promise: promise - ) - let dataSource = CountingDataSource() let stream = PSQLRowStream( - rowDescription: [ - self.makeColumnDescription(name: "foo", dataType: .text, format: .binary) - ], - queryContext: queryContext, - eventLoop: eventLoop, - rowSource: .stream(dataSource) + source: .stream( + [self.makeColumnDescription(name: "foo", dataType: .text, format: .binary)], + dataSource + ), + eventLoop: self.eventLoop, + logger: self.logger ) - promise.succeed(stream) XCTAssertEqual(dataSource.hitDemand, 0) XCTAssertEqual(dataSource.hitCancel, 0) diff --git a/Tests/PostgresNIOTests/New/PostgresRowSequenceTests.swift b/Tests/PostgresNIOTests/New/PostgresRowSequenceTests.swift index e1fdad11..fc589c0b 100644 --- a/Tests/PostgresNIOTests/New/PostgresRowSequenceTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresRowSequenceTests.swift @@ -7,21 +7,21 @@ import NIOCore import Logging final class PostgresRowSequenceTests: XCTestCase { + let logger = Logger(label: "PSQLRowStreamTests") + let eventLoop = EmbeddedEventLoop() func testBackpressureWorks() async throws { - let eventLoop = EmbeddedEventLoop() - let promise = eventLoop.makePromise(of: PSQLRowStream.self) - let logger = Logger(label: "test") let dataSource = MockRowDataSource() let stream = PSQLRowStream( - rowDescription: [ - .init(name: "test", tableOID: 0, columnAttributeNumber: 0, dataType: .int8, dataTypeSize: 8, dataTypeModifier: 0, format: .binary) - ], - queryContext: .init(query: "SELECT * FROM foo", logger: logger, promise: promise), - eventLoop: eventLoop, - rowSource: .stream(dataSource) + source: .stream( + [ + .init(name: "test", tableOID: 0, columnAttributeNumber: 0, dataType: .int8, dataTypeSize: 8, dataTypeModifier: 0, format: .binary) + ], + dataSource + ), + eventLoop: self.eventLoop, + logger: self.logger ) - promise.succeed(stream) let rowSequence = stream.asyncSequence() XCTAssertEqual(dataSource.requestCount, 0) @@ -38,20 +38,19 @@ final class PostgresRowSequenceTests: XCTestCase { XCTAssertNil(empty) } + func testCancellationWorksWhileIterating() async throws { - let eventLoop = EmbeddedEventLoop() - let promise = eventLoop.makePromise(of: PSQLRowStream.self) - let logger = Logger(label: "test") let dataSource = MockRowDataSource() let stream = PSQLRowStream( - rowDescription: [ - .init(name: "test", tableOID: 0, columnAttributeNumber: 0, dataType: .int8, dataTypeSize: 8, dataTypeModifier: 0, format: .binary) - ], - queryContext: .init(query: "SELECT * FROM foo", logger: logger, promise: promise), - eventLoop: eventLoop, - rowSource: .stream(dataSource) + source: .stream( + [ + .init(name: "test", tableOID: 0, columnAttributeNumber: 0, dataType: .int8, dataTypeSize: 8, dataTypeModifier: 0, format: .binary) + ], + dataSource + ), + eventLoop: self.eventLoop, + logger: self.logger ) - promise.succeed(stream) let rowSequence = stream.asyncSequence() XCTAssertEqual(dataSource.requestCount, 0) @@ -72,19 +71,17 @@ final class PostgresRowSequenceTests: XCTestCase { } func testCancellationWorksBeforeIterating() async throws { - let eventLoop = EmbeddedEventLoop() - let promise = eventLoop.makePromise(of: PSQLRowStream.self) - let logger = Logger(label: "test") let dataSource = MockRowDataSource() let stream = PSQLRowStream( - rowDescription: [ - .init(name: "test", tableOID: 0, columnAttributeNumber: 0, dataType: .int8, dataTypeSize: 8, dataTypeModifier: 0, format: .binary) - ], - queryContext: .init(query: "SELECT * FROM foo", logger: logger, promise: promise), - eventLoop: eventLoop, - rowSource: .stream(dataSource) + source: .stream( + [ + .init(name: "test", tableOID: 0, columnAttributeNumber: 0, dataType: .int8, dataTypeSize: 8, dataTypeModifier: 0, format: .binary) + ], + dataSource + ), + eventLoop: self.eventLoop, + logger: self.logger ) - promise.succeed(stream) let rowSequence = stream.asyncSequence() XCTAssertEqual(dataSource.requestCount, 0) @@ -99,19 +96,17 @@ final class PostgresRowSequenceTests: XCTestCase { } func testDroppingTheSequenceCancelsTheSource() async throws { - let eventLoop = EmbeddedEventLoop() - let promise = eventLoop.makePromise(of: PSQLRowStream.self) - let logger = Logger(label: "test") let dataSource = MockRowDataSource() let stream = PSQLRowStream( - rowDescription: [ - .init(name: "test", tableOID: 0, columnAttributeNumber: 0, dataType: .int8, dataTypeSize: 8, dataTypeModifier: 0, format: .binary) - ], - queryContext: .init(query: "SELECT * FROM foo", logger: logger, promise: promise), - eventLoop: eventLoop, - rowSource: .stream(dataSource) + source: .stream( + [ + .init(name: "test", tableOID: 0, columnAttributeNumber: 0, dataType: .int8, dataTypeSize: 8, dataTypeModifier: 0, format: .binary) + ], + dataSource + ), + eventLoop: self.eventLoop, + logger: self.logger ) - promise.succeed(stream) var rowSequence: PostgresRowSequence? = stream.asyncSequence() rowSequence = nil @@ -121,19 +116,17 @@ final class PostgresRowSequenceTests: XCTestCase { } func testStreamBasedOnCompletedQuery() async throws { - let eventLoop = EmbeddedEventLoop() - let promise = eventLoop.makePromise(of: PSQLRowStream.self) - let logger = Logger(label: "test") let dataSource = MockRowDataSource() let stream = PSQLRowStream( - rowDescription: [ - .init(name: "test", tableOID: 0, columnAttributeNumber: 0, dataType: .int8, dataTypeSize: 8, dataTypeModifier: 0, format: .binary) - ], - queryContext: .init(query: "SELECT * FROM foo", logger: logger, promise: promise), - eventLoop: eventLoop, - rowSource: .stream(dataSource) + source: .stream( + [ + .init(name: "test", tableOID: 0, columnAttributeNumber: 0, dataType: .int8, dataTypeSize: 8, dataTypeModifier: 0, format: .binary) + ], + dataSource + ), + eventLoop: self.eventLoop, + logger: self.logger ) - promise.succeed(stream) let rowSequence = stream.asyncSequence() let dataRows: [DataRow] = (0..<128).map { [ByteBuffer(integer: Int64($0))] } @@ -150,19 +143,17 @@ final class PostgresRowSequenceTests: XCTestCase { } func testStreamIfInitializedWithAllData() async throws { - let eventLoop = EmbeddedEventLoop() - let promise = eventLoop.makePromise(of: PSQLRowStream.self) - let logger = Logger(label: "test") let dataSource = MockRowDataSource() let stream = PSQLRowStream( - rowDescription: [ - .init(name: "test", tableOID: 0, columnAttributeNumber: 0, dataType: .int8, dataTypeSize: 8, dataTypeModifier: 0, format: .binary) - ], - queryContext: .init(query: "SELECT * FROM foo", logger: logger, promise: promise), - eventLoop: eventLoop, - rowSource: .stream(dataSource) + source: .stream( + [ + .init(name: "test", tableOID: 0, columnAttributeNumber: 0, dataType: .int8, dataTypeSize: 8, dataTypeModifier: 0, format: .binary) + ], + dataSource + ), + eventLoop: self.eventLoop, + logger: self.logger ) - promise.succeed(stream) let dataRows: [DataRow] = (0..<128).map { [ByteBuffer(integer: Int64($0))] } stream.receive(dataRows) @@ -180,19 +171,17 @@ final class PostgresRowSequenceTests: XCTestCase { } func testStreamIfInitializedWithError() async throws { - let eventLoop = EmbeddedEventLoop() - let promise = eventLoop.makePromise(of: PSQLRowStream.self) - let logger = Logger(label: "test") let dataSource = MockRowDataSource() let stream = PSQLRowStream( - rowDescription: [ - .init(name: "test", tableOID: 0, columnAttributeNumber: 0, dataType: .int8, dataTypeSize: 8, dataTypeModifier: 0, format: .binary) - ], - queryContext: .init(query: "SELECT * FROM foo", logger: logger, promise: promise), - eventLoop: eventLoop, - rowSource: .stream(dataSource) + source: .stream( + [ + .init(name: "test", tableOID: 0, columnAttributeNumber: 0, dataType: .int8, dataTypeSize: 8, dataTypeModifier: 0, format: .binary) + ], + dataSource + ), + eventLoop: self.eventLoop, + logger: self.logger ) - promise.succeed(stream) stream.receive(completion: .failure(PSQLError.connectionClosed)) @@ -210,19 +199,17 @@ final class PostgresRowSequenceTests: XCTestCase { } func testSucceedingRowContinuationsWorks() async throws { - let eventLoop = EmbeddedEventLoop() - let promise = eventLoop.makePromise(of: PSQLRowStream.self) - let logger = Logger(label: "test") let dataSource = MockRowDataSource() let stream = PSQLRowStream( - rowDescription: [ - .init(name: "test", tableOID: 0, columnAttributeNumber: 0, dataType: .int8, dataTypeSize: 8, dataTypeModifier: 0, format: .binary) - ], - queryContext: .init(query: "SELECT * FROM foo", logger: logger, promise: promise), - eventLoop: eventLoop, - rowSource: .stream(dataSource) + source: .stream( + [ + .init(name: "test", tableOID: 0, columnAttributeNumber: 0, dataType: .int8, dataTypeSize: 8, dataTypeModifier: 0, format: .binary) + ], + dataSource + ), + eventLoop: self.eventLoop, + logger: self.logger ) - promise.succeed(stream) let rowSequence = stream.asyncSequence() var rowIterator = rowSequence.makeAsyncIterator() @@ -244,19 +231,17 @@ final class PostgresRowSequenceTests: XCTestCase { } func testFailingRowContinuationsWorks() async throws { - let eventLoop = EmbeddedEventLoop() - let promise = eventLoop.makePromise(of: PSQLRowStream.self) - let logger = Logger(label: "test") let dataSource = MockRowDataSource() let stream = PSQLRowStream( - rowDescription: [ - .init(name: "test", tableOID: 0, columnAttributeNumber: 0, dataType: .int8, dataTypeSize: 8, dataTypeModifier: 0, format: .binary) - ], - queryContext: .init(query: "SELECT * FROM foo", logger: logger, promise: promise), - eventLoop: eventLoop, - rowSource: .stream(dataSource) + source: .stream( + [ + .init(name: "test", tableOID: 0, columnAttributeNumber: 0, dataType: .int8, dataTypeSize: 8, dataTypeModifier: 0, format: .binary) + ], + dataSource + ), + eventLoop: self.eventLoop, + logger: self.logger ) - promise.succeed(stream) let rowSequence = stream.asyncSequence() var rowIterator = rowSequence.makeAsyncIterator() @@ -282,19 +267,17 @@ final class PostgresRowSequenceTests: XCTestCase { } func testAdaptiveRowBufferShrinksAndGrows() async throws { - let eventLoop = EmbeddedEventLoop() - let promise = eventLoop.makePromise(of: PSQLRowStream.self) - let logger = Logger(label: "test") let dataSource = MockRowDataSource() let stream = PSQLRowStream( - rowDescription: [ - .init(name: "test", tableOID: 0, columnAttributeNumber: 0, dataType: .int8, dataTypeSize: 8, dataTypeModifier: 0, format: .binary) - ], - queryContext: .init(query: "SELECT * FROM foo", logger: logger, promise: promise), - eventLoop: eventLoop, - rowSource: .stream(dataSource) + source: .stream( + [ + .init(name: "test", tableOID: 0, columnAttributeNumber: 0, dataType: .int8, dataTypeSize: 8, dataTypeModifier: 0, format: .binary) + ], + dataSource + ), + eventLoop: self.eventLoop, + logger: self.logger ) - promise.succeed(stream) let initialDataRows: [DataRow] = (0.. Date: Wed, 9 Aug 2023 23:11:53 +0200 Subject: [PATCH 012/106] PostgresNotificationSequence is not Sendable in 5.6 (#392) `AsyncThrowingStream` is not `Sendable` in Swift 5.6. Because of this `PostgresNotificationSequence` can not be `Sendable` in 5.6. --- Sources/PostgresNIO/New/PostgresNotificationSequence.swift | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/Sources/PostgresNIO/New/PostgresNotificationSequence.swift b/Sources/PostgresNIO/New/PostgresNotificationSequence.swift index 735c01b0..55fb0670 100644 --- a/Sources/PostgresNIO/New/PostgresNotificationSequence.swift +++ b/Sources/PostgresNIO/New/PostgresNotificationSequence.swift @@ -3,7 +3,7 @@ public struct PostgresNotification: Sendable { public let payload: String } -public struct PostgresNotificationSequence: AsyncSequence, Sendable { +public struct PostgresNotificationSequence: AsyncSequence { public typealias Element = PostgresNotification let base: AsyncThrowingStream @@ -20,3 +20,8 @@ public struct PostgresNotificationSequence: AsyncSequence, Sendable { } } } + +#if swift(>=5.7) +// AsyncThrowingStream is marked as Sendable in Swift 5.6 +extension PostgresNotificationSequence: Sendable {} +#endif From a5758b0c1bcbf3f0a27335d60813509a93027dc5 Mon Sep 17 00:00:00 2001 From: Thomas Krajacic Date: Wed, 9 Aug 2023 23:17:01 +0200 Subject: [PATCH 013/106] Use EventLoop provided by SwiftNIO's MultiThreadedEventLoopGroup.singleton (#389) Co-authored-by: Fabian Fett --- Package.swift | 4 ++-- README.md | 20 +----------------- .../Connection/PostgresConnection.swift | 21 +++++++++++++++++-- 3 files changed, 22 insertions(+), 23 deletions(-) diff --git a/Package.swift b/Package.swift index c1cb4bda..a45925ed 100644 --- a/Package.swift +++ b/Package.swift @@ -14,8 +14,8 @@ let package = Package( ], dependencies: [ .package(url: "https://github.com/apple/swift-atomics.git", from: "1.1.0"), - .package(url: "https://github.com/apple/swift-nio.git", from: "2.52.0"), - .package(url: "https://github.com/apple/swift-nio-transport-services.git", from: "1.16.0"), + .package(url: "https://github.com/apple/swift-nio.git", from: "2.58.0"), + .package(url: "https://github.com/apple/swift-nio-transport-services.git", from: "1.18.0"), .package(url: "https://github.com/apple/swift-nio-ssl.git", from: "2.23.1"), .package(url: "https://github.com/apple/swift-crypto.git", "1.0.0" ..< "3.0.0"), .package(url: "https://github.com/apple/swift-metrics.git", from: "2.0.0"), diff --git a/README.md b/README.md index 51e0b8c5..441a41e3 100644 --- a/README.md +++ b/README.md @@ -67,19 +67,7 @@ let config = PostgresConnection.Configuration( ) ``` -A connection must be created on a SwiftNIO `EventLoop`. In most server use cases, an -`EventLoopGroup` is created at app startup and closed during app shutdown. - -```swift -import NIOPosix - -let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) - -// Much later -try await eventLoopGroup.shutdownGracefully() -``` - -A [`Logger`] is also required. +To create a connection we need a [`Logger`], that is used to log connection background events. ```swift import Logging @@ -91,10 +79,8 @@ Now we can put it together: ```swift import PostgresNIO -import NIOPosix import Logging -let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) let logger = Logger(label: "postgres-logger") let config = PostgresConnection.Configuration( @@ -107,7 +93,6 @@ let config = PostgresConnection.Configuration( ) let connection = try await PostgresConnection.connect( - on: eventLoopGroup.next(), configuration: config, id: 1, logger: logger @@ -115,9 +100,6 @@ let connection = try await PostgresConnection.connect( // Close your connection once done try await connection.close() - -// Shutdown the EventLoopGroup, once all connections are closed. -try await eventLoopGroup.shutdownGracefully() ``` #### Querying diff --git a/Sources/PostgresNIO/Connection/PostgresConnection.swift b/Sources/PostgresNIO/Connection/PostgresConnection.swift index 6f849bdd..f8a9709e 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection.swift @@ -360,13 +360,13 @@ extension PostgresConnection { /// Creates a new connection to a Postgres server. /// /// - Parameters: - /// - eventLoop: The `EventLoop` the request shall be created on + /// - eventLoop: The `EventLoop` the connection shall be created on. /// - configuration: A ``Configuration`` that shall be used for the connection /// - connectionID: An `Int` id, used for metadata logging /// - logger: A logger to log background events into /// - Returns: An established ``PostgresConnection`` asynchronously that can be used to run queries. public static func connect( - on eventLoop: EventLoop, + on eventLoop: EventLoop = PostgresConnection.defaultEventLoopGroup.any(), configuration: PostgresConnection.Configuration, id connectionID: ID, logger: Logger @@ -661,3 +661,20 @@ extension EventLoopFuture { } } } + +extension PostgresConnection { + /// Returns the default `EventLoopGroup` singleton, automatically selecting the best for the platform. + /// + /// This will select the concrete `EventLoopGroup` depending which platform this is running on. + public static var defaultEventLoopGroup: EventLoopGroup { +#if canImport(Network) + if #available(OSX 10.14, iOS 12.0, tvOS 12.0, watchOS 6.0, *) { + return NIOTSEventLoopGroup.singleton + } else { + return MultiThreadedEventLoopGroup.singleton + } +#else + return MultiThreadedEventLoopGroup.singleton +#endif + } +} From 52d5636edd2da896d1669dfd7fd4f83de94686c4 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Thu, 10 Aug 2023 08:03:36 +0200 Subject: [PATCH 014/106] `close()` closes immediately; Add new `closeGracefully()` (#383) Fixes #370. This patch changes the behavior of `PostgresConnection.close()`. Currently `close()` terminates the connection only after all queued queries have been successfully processed by the server. This however leads to an unwanted dependency on the Postgres server to close a connection. If a server stops responding, the client is currently unable to close its connection. Because of this, this patch changes the behavior of `close()`. `close()` now terminates a connection immediately and fails all running or queued queries. To allow users to continue to use the existing behavior we introduce a `closeGracefully()` that now has the same behavior as close had previously. Since we never documented the old close behavior and we consider it dangerous in certain situations we are fine with changing the behavior without tagging a major version. --- .../Connection/PostgresConnection.swift | 11 ++ .../ConnectionStateMachine.swift | 164 ++++++++++-------- .../ListenStateMachine.swift | 11 +- Sources/PostgresNIO/New/PSQLError.swift | 54 ++++-- .../PostgresNIO/New/PSQLEventsHandler.swift | 2 + .../New/PostgresChannelHandler.swift | 7 +- Sources/PostgresNIO/Postgres+PSQLCompat.swift | 6 +- .../ConnectionStateMachineTests.swift | 6 +- .../New/PSQLRowStreamTests.swift | 4 +- .../New/PostgresChannelHandlerTests.swift | 7 +- .../New/PostgresConnectionTests.swift | 92 ++++++++++ .../New/PostgresRowSequenceTests.swift | 8 +- 12 files changed, 263 insertions(+), 109 deletions(-) diff --git a/Sources/PostgresNIO/Connection/PostgresConnection.swift b/Sources/PostgresNIO/Connection/PostgresConnection.swift index f8a9709e..7ac8ec57 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection.swift @@ -384,6 +384,17 @@ extension PostgresConnection { try await self.close().get() } + /// Closes the connection to the server, _after all queries_ that have been created on this connection have been run. + public func closeGracefully() async throws { + try await withTaskCancellationHandler { () async throws -> () in + let promise = self.eventLoop.makePromise(of: Void.self) + self.channel.triggerUserOutboundEvent(PSQLOutgoingEvent.gracefulShutdown, promise: promise) + return try await promise.futureResult.get() + } onCancel: { + _ = self.close() + } + } + /// Run a query on the Postgres server the connection is connected to. /// /// - Parameters: diff --git a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift index 0f3e96c9..bbfa0faa 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift @@ -32,11 +32,10 @@ struct ConnectionStateMachine { case readyForQuery(ConnectionContext) case extendedQuery(ExtendedQueryStateMachine, ConnectionContext) case closeCommand(CloseStateMachine, ConnectionContext) - - case error(PSQLError) - case closing - case closed - + + case closing(PSQLError?) + case closed(clientInitiated: Bool, error: PSQLError?) + case modifying } @@ -158,7 +157,6 @@ struct ConnectionStateMachine { .readyForQuery, .extendedQuery, .closeCommand, - .error, .closing, .closed, .modifying: @@ -170,9 +168,9 @@ struct ConnectionStateMachine { self.startAuthentication(authContext) } - mutating func close(_ promise: EventLoopPromise?) -> ConnectionAction { + mutating func gracefulClose(_ promise: EventLoopPromise?) -> ConnectionAction { switch self.state { - case .closing, .closed, .error: + case .closing, .closed: // we are already closed, but sometimes an upstream handler might want to close the // connection, though it has already been closed by the remote. Typical race condition. return .closeConnection(promise) @@ -180,7 +178,7 @@ struct ConnectionStateMachine { precondition(self.taskQueue.isEmpty, """ The state should only be .readyForQuery if there are no more tasks in the queue """) - self.state = .closing + self.state = .closing(nil) return .closeConnection(promise) default: switch self.quiescingState { @@ -194,7 +192,11 @@ struct ConnectionStateMachine { return .wait } } - + + mutating func close(promise: EventLoopPromise?) -> ConnectionAction { + return self.closeConnectionAndCleanup(.clientClosedConnection(underlying: nil), closePromise: promise) + } + mutating func closed() -> ConnectionAction { switch self.state { case .initialized: @@ -214,8 +216,8 @@ struct ConnectionStateMachine { .closeCommand: return self.errorHappened(.uncleanShutdown) - case .error, .closing: - self.state = .closed + case .closing(let error): + self.state = .closed(clientInitiated: true, error: error) self.quiescingState = .notQuiescing return .fireChannelInactive @@ -242,7 +244,6 @@ struct ConnectionStateMachine { .readyForQuery, .extendedQuery, .closeCommand, - .error, .closing, .closed: return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.sslSupported)) @@ -270,7 +271,6 @@ struct ConnectionStateMachine { .readyForQuery, .extendedQuery, .closeCommand, - .error, .closing, .closed: return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.sslSupported)) @@ -291,7 +291,6 @@ struct ConnectionStateMachine { .readyForQuery, .extendedQuery, .closeCommand, - .error, .closing, .closed: preconditionFailure("Can only add a ssl handler after negotiation: \(self.state)") @@ -316,7 +315,6 @@ struct ConnectionStateMachine { .readyForQuery, .extendedQuery, .closeCommand, - .error, .closing, .closed: preconditionFailure("Can only establish a ssl connection after adding a ssl handler: \(self.state)") @@ -363,8 +361,7 @@ struct ConnectionStateMachine { .waitingToStartAuthentication, .authenticating, .closing: - self.state = .error(.unexpectedBackendMessage(.parameterStatus(status))) - return .wait + return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.parameterStatus(status))) case .authenticated(let keyData, var parameters): return self.avoidingStateMachineCoW { machine in parameters[status.parameter] = status.value @@ -389,8 +386,6 @@ struct ConnectionStateMachine { machine.state = .closeCommand(closeState, connectionContext) return .wait } - case .error(_): - return .wait case .initialized, .closed: preconditionFailure("We shouldn't receive messages if we are not connected") @@ -406,8 +401,7 @@ struct ConnectionStateMachine { .sslHandlerAdded, .waitingToStartAuthentication, .authenticated, - .readyForQuery, - .error: + .readyForQuery: return self.closeConnectionAndCleanup(.server(errorMessage)) case .authenticating(var authState): if authState.isComplete { @@ -477,8 +471,6 @@ struct ConnectionStateMachine { let action = closeState.errorHappened(error) return self.modify(with: action) } - case .error: - return .wait case .closing: // If the state machine is in state `.closing`, the connection shutdown was initiated // by the client. This means a `TERMINATE` message has already been sent and the @@ -553,40 +545,54 @@ struct ConnectionStateMachine { } mutating func enqueue(task: PSQLTask) -> ConnectionAction { + let psqlErrror: PSQLError + // check if we are quiescing. if so fail task immidiatly - if case .quiescing = self.quiescingState { - switch task { - case .extendedQuery(let queryContext): - switch queryContext.query { - case .unnamed(_, let eventLoopPromise), .executeStatement(_, let eventLoopPromise): - return .failQuery(eventLoopPromise, with: .connectionQuiescing, cleanupContext: nil) - case .prepareStatement(_, _, let eventLoopPromise): - return .failPreparedStatementCreation(eventLoopPromise, with: .connectionQuiescing, cleanupContext: nil) - } + switch self.quiescingState { + case .quiescing: + psqlErrror = PSQLError.clientClosesConnection(underlying: nil) + + case .notQuiescing: + switch self.state { + case .initialized, + .authenticated, + .authenticating, + .closeCommand, + .extendedQuery, + .sslNegotiated, + .sslHandlerAdded, + .sslRequestSent, + .waitingToStartAuthentication: + self.taskQueue.append(task) + return .wait + + case .readyForQuery: + return self.executeTask(task) + + case .closing(let error): + psqlErrror = PSQLError.clientClosesConnection(underlying: error) + + case .closed(clientInitiated: true, error: let error): + psqlErrror = PSQLError.clientClosedConnection(underlying: error) - case .closeCommand(let closeContext): - return .failClose(closeContext, with: .connectionQuiescing, cleanupContext: nil) + case .closed(clientInitiated: false, error: let error): + psqlErrror = PSQLError.serverClosedConnection(underlying: error) + + case .modifying: + preconditionFailure("Invalid state: \(self.state)") } } - switch self.state { - case .readyForQuery: - return self.executeTask(task) - case .closed: - switch task { - case .extendedQuery(let queryContext): - switch queryContext.query { - case .unnamed(_, let eventLoopPromise), .executeStatement(_, let eventLoopPromise): - return .failQuery(eventLoopPromise, with: .connectionClosed, cleanupContext: nil) - case .prepareStatement(_, _, let eventLoopPromise): - return .failPreparedStatementCreation(eventLoopPromise, with: .connectionClosed, cleanupContext: nil) - } - case .closeCommand(let closeContext): - return .failClose(closeContext, with: .connectionClosed, cleanupContext: nil) + switch task { + case .extendedQuery(let queryContext): + switch queryContext.query { + case .executeStatement(_, let promise), .unnamed(_, let promise): + return .failQuery(promise, with: psqlErrror, cleanupContext: nil) + case .prepareStatement(_, _, let promise): + return .failPreparedStatementCreation(promise, with: psqlErrror, cleanupContext: nil) } - default: - self.taskQueue.append(task) - return .wait + case .closeCommand(let closeContext): + return .failClose(closeContext, with: psqlErrror, cleanupContext: nil) } } @@ -601,7 +607,6 @@ struct ConnectionStateMachine { .authenticated, .readyForQuery, .closeCommand, - .error, .closing, .closed: return .wait @@ -648,8 +653,6 @@ struct ConnectionStateMachine { machine.state = .closeCommand(closeState, connectionContext) return machine.modify(with: action) } - case .error: - return .read case .closing: return .read case .closed: @@ -818,7 +821,7 @@ struct ConnectionStateMachine { } } - private mutating func closeConnectionAndCleanup(_ error: PSQLError) -> ConnectionAction { + private mutating func closeConnectionAndCleanup(_ error: PSQLError, closePromise: EventLoopPromise? = nil) -> ConnectionAction { switch self.state { case .initialized, .sslRequestSent, @@ -827,12 +830,12 @@ struct ConnectionStateMachine { .waitingToStartAuthentication, .authenticated, .readyForQuery: - let cleanupContext = self.setErrorAndCreateCleanupContext(error) + let cleanupContext = self.setErrorAndCreateCleanupContext(error, closePromise: closePromise) return .closeConnectionAndCleanup(cleanupContext) case .authenticating(var authState): - let cleanupContext = self.setErrorAndCreateCleanupContext(error) - + let cleanupContext = self.setErrorAndCreateCleanupContext(error, closePromise: closePromise) + if authState.isComplete { // in case the auth state machine is complete all necessary actions have already // been forwarded to the consumer. We can close and cleanup without caring about the @@ -847,8 +850,8 @@ struct ConnectionStateMachine { return .closeConnectionAndCleanup(cleanupContext) case .extendedQuery(var queryStateMachine, _): - let cleanupContext = self.setErrorAndCreateCleanupContext(error) - + let cleanupContext = self.setErrorAndCreateCleanupContext(error, closePromise: closePromise) + if queryStateMachine.isComplete { // in case the query state machine is complete all necessary actions have already // been forwarded to the consumer. We can close and cleanup without caring about the @@ -867,19 +870,23 @@ struct ConnectionStateMachine { .wait, .read: preconditionFailure("Invalid state: \(self.state)") + case .evaluateErrorAtConnectionLevel: return .closeConnectionAndCleanup(cleanupContext) + case .failQuery(let queryContext, with: let error): return .failQuery(queryContext, with: error, cleanupContext: cleanupContext) + case .forwardStreamError(let error, let read): return .forwardStreamError(error, read: read, cleanupContext: cleanupContext) + case .failPreparedStatementCreation(let promise, with: let error): return .failPreparedStatementCreation(promise, with: error, cleanupContext: cleanupContext) } case .closeCommand(var closeStateMachine, _): - let cleanupContext = self.setErrorAndCreateCleanupContext(error) - + let cleanupContext = self.setErrorAndCreateCleanupContext(error, closePromise: closePromise) + if closeStateMachine.isComplete { // in case the close state machine is complete all necessary actions have already // been forwarded to the consumer. We can close and cleanup without caring about the @@ -897,7 +904,7 @@ struct ConnectionStateMachine { return .failClose(closeCommandContext, with: error, cleanupContext: cleanupContext) } - case .error, .closing, .closed: + case .closing, .closed: // We might run into this case because of reentrancy. For example: After we received an // backend unexpected message, that we read of the wire, we bring this connection into // the error state and will try to close the connection. However the server might have @@ -921,7 +928,7 @@ struct ConnectionStateMachine { // if we don't have anything left to do and we are quiescing, next we should close if case .quiescing(let promise) = self.quiescingState { - self.state = .closing + self.state = .closing(nil) return .closeConnection(promise) } @@ -1024,9 +1031,9 @@ extension ConnectionStateMachine { } return false - case .connectionQuiescing: + case .clientClosesConnection, .clientClosedConnection: preconditionFailure("Pure client error, that is thrown directly in PostgresConnection") - case .connectionClosed: + case .serverClosedConnection: preconditionFailure("Pure client error, that is thrown directly and should never ") } } @@ -1039,23 +1046,28 @@ extension ConnectionStateMachine { return self.setErrorAndCreateCleanupContext(error) } - mutating func setErrorAndCreateCleanupContext(_ error: PSQLError) -> ConnectionAction.CleanUpContext { + mutating func setErrorAndCreateCleanupContext(_ error: PSQLError, closePromise: EventLoopPromise? = nil) -> ConnectionAction.CleanUpContext { let tasks = Array(self.taskQueue) self.taskQueue.removeAll() - var closePromise: EventLoopPromise? = nil - if case .quiescing(let promise) = self.quiescingState { - closePromise = promise + var forwardedPromise: EventLoopPromise? = nil + if case .quiescing(.some(let quiescePromise)) = self.quiescingState, let closePromise = closePromise { + quiescePromise.futureResult.cascade(to: closePromise) + forwardedPromise = quiescePromise + } else if case .quiescing(.some(let quiescePromise)) = self.quiescingState { + forwardedPromise = quiescePromise + } else { + forwardedPromise = closePromise } - - self.state = .error(error) - + + self.state = .closing(error) + var action = ConnectionAction.CleanUpContext.Action.close if case .uncleanShutdown = error.code.base { action = .fireChannelInactive } - return .init(action: action, tasks: tasks, error: error, closePromise: closePromise) + return .init(action: action, tasks: tasks, error: error, closePromise: forwardedPromise) } } @@ -1187,8 +1199,6 @@ extension ConnectionStateMachine.State: CustomDebugStringConvertible { return ".extendedQuery(\(String(reflecting: subStateMachine)), connectionContext: \(String(reflecting: connectionContext)))" case .closeCommand(let subStateMachine, let connectionContext): return ".closeCommand(\(String(reflecting: subStateMachine)), connectionContext: \(String(reflecting: connectionContext)))" - case .error(let error): - return ".error(\(String(reflecting: error)))" case .closing: return ".closing" case .closed: diff --git a/Sources/PostgresNIO/New/Connection State Machine/ListenStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ListenStateMachine.swift index c7f92428..89f40469 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/ListenStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/ListenStateMachine.swift @@ -36,7 +36,14 @@ struct ListenStateMachine { } mutating func stopListeningSucceeded(channel: String) -> StopListeningSuccessAction { - return self.channels[channel, default: .init()].stopListeningSucceeded() + switch self.channels[channel]!.stopListeningSucceeded() { + case .none: + self.channels.removeValue(forKey: channel) + return .none + + case .startListening: + return .startListening + } } enum CancelAction { @@ -46,7 +53,7 @@ struct ListenStateMachine { } mutating func cancelNotificationListener(channel: String, id: Int) -> CancelAction { - return self.channels[channel, default: .init()].cancelListening(id: id) + return self.channels[channel]?.cancelListening(id: id) ?? .none } mutating func fail(_ error: Error) -> [NotificationListener] { diff --git a/Sources/PostgresNIO/New/PSQLError.swift b/Sources/PostgresNIO/New/PSQLError.swift index 5d9e534c..1fec59b1 100644 --- a/Sources/PostgresNIO/New/PSQLError.swift +++ b/Sources/PostgresNIO/New/PSQLError.swift @@ -18,8 +18,9 @@ public struct PSQLError: Error { case queryCancelled case tooManyParameters - case connectionQuiescing - case connectionClosed + case clientClosesConnection + case clientClosedConnection + case serverClosedConnection case connectionError case uncleanShutdown @@ -45,13 +46,20 @@ public struct PSQLError: Error { public static let invalidCommandTag = Self(.invalidCommandTag) public static let queryCancelled = Self(.queryCancelled) public static let tooManyParameters = Self(.tooManyParameters) - public static let connectionQuiescing = Self(.connectionQuiescing) - public static let connectionClosed = Self(.connectionClosed) + public static let clientClosesConnection = Self(.clientClosesConnection) + public static let clientClosedConnection = Self(.clientClosedConnection) + public static let serverClosedConnection = Self(.serverClosedConnection) public static let connectionError = Self(.connectionError) public static let uncleanShutdown = Self.init(.uncleanShutdown) public static let listenFailed = Self.init(.listenFailed) public static let unlistenFailed = Self.init(.unlistenFailed) + @available(*, deprecated, renamed: "clientClosesConnection") + public static let connectionQuiescing = Self.clientClosesConnection + + @available(*, deprecated, message: "Use the more specific `serverClosedConnection` or `clientClosedConnection` instead") + public static let connectionClosed = Self.serverClosedConnection + public var description: String { switch self.base { case .sslUnsupported: @@ -78,10 +86,12 @@ public struct PSQLError: Error { return "queryCancelled" case .tooManyParameters: return "tooManyParameters" - case .connectionQuiescing: - return "connectionQuiescing" - case .connectionClosed: - return "connectionClosed" + case .clientClosesConnection: + return "clientClosesConnection" + case .clientClosedConnection: + return "clientClosedConnection" + case .serverClosedConnection: + return "serverClosedConnection" case .connectionError: return "connectionError" case .uncleanShutdown: @@ -377,19 +387,33 @@ public struct PSQLError: Error { return new } - static var connectionQuiescing: PSQLError { PSQLError(code: .connectionQuiescing) } + static func clientClosesConnection(underlying: Error?) -> PSQLError { + var error = PSQLError(code: .clientClosesConnection) + error.underlying = underlying + return error + } + + static func clientClosedConnection(underlying: Error?) -> PSQLError { + var error = PSQLError(code: .clientClosedConnection) + error.underlying = underlying + return error + } - static var connectionClosed: PSQLError { PSQLError(code: .connectionClosed) } + static func serverClosedConnection(underlying: Error?) -> PSQLError { + var error = PSQLError(code: .serverClosedConnection) + error.underlying = underlying + return error + } - static var authMechanismRequiresPassword: PSQLError { PSQLError(code: .authMechanismRequiresPassword) } + static let authMechanismRequiresPassword = PSQLError(code: .authMechanismRequiresPassword) - static var sslUnsupported: PSQLError { PSQLError(code: .sslUnsupported) } + static let sslUnsupported = PSQLError(code: .sslUnsupported) - static var queryCancelled: PSQLError { PSQLError(code: .queryCancelled) } + static let queryCancelled = PSQLError(code: .queryCancelled) - static var uncleanShutdown: PSQLError { PSQLError(code: .uncleanShutdown) } + static let uncleanShutdown = PSQLError(code: .uncleanShutdown) - static var receivedUnencryptedDataAfterSSLRequest: PSQLError { PSQLError(code: .receivedUnencryptedDataAfterSSLRequest) } + static let receivedUnencryptedDataAfterSSLRequest = PSQLError(code: .receivedUnencryptedDataAfterSSLRequest) static func server(_ response: PostgresBackendMessage.ErrorResponse) -> PSQLError { var error = PSQLError(code: .server) diff --git a/Sources/PostgresNIO/New/PSQLEventsHandler.swift b/Sources/PostgresNIO/New/PSQLEventsHandler.swift index 3233fb77..2bf0d6d8 100644 --- a/Sources/PostgresNIO/New/PSQLEventsHandler.swift +++ b/Sources/PostgresNIO/New/PSQLEventsHandler.swift @@ -7,6 +7,8 @@ enum PSQLOutgoingEvent { /// /// this shall be removed with the next breaking change and always supplied with `PSQLConnection.Configuration` case authenticate(AuthContext) + + case gracefulShutdown } enum PSQLEvent { diff --git a/Sources/PostgresNIO/New/PostgresChannelHandler.swift b/Sources/PostgresNIO/New/PostgresChannelHandler.swift index abfa5aeb..7801d4d6 100644 --- a/Sources/PostgresNIO/New/PostgresChannelHandler.swift +++ b/Sources/PostgresNIO/New/PostgresChannelHandler.swift @@ -247,7 +247,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler { return } - let action = self.state.close(promise) + let action = self.state.close(promise: promise) self.run(action, with: context) } @@ -258,6 +258,11 @@ final class PostgresChannelHandler: ChannelDuplexHandler { case PSQLOutgoingEvent.authenticate(let authContext): let action = self.state.provideAuthenticationContext(authContext) self.run(action, with: context) + + case PSQLOutgoingEvent.gracefulShutdown: + let action = self.state.gracefulClose(promise) + self.run(action, with: context) + default: context.triggerUserOutboundEvent(event, promise: promise) } diff --git a/Sources/PostgresNIO/Postgres+PSQLCompat.swift b/Sources/PostgresNIO/Postgres+PSQLCompat.swift index 10970b26..1989e5bc 100644 --- a/Sources/PostgresNIO/Postgres+PSQLCompat.swift +++ b/Sources/PostgresNIO/Postgres+PSQLCompat.swift @@ -37,9 +37,9 @@ extension PSQLError { return self.underlying ?? self case .tooManyParameters, .invalidCommandTag: return self - case .connectionQuiescing: - return PostgresError.connectionClosed - case .connectionClosed: + case .clientClosesConnection, + .clientClosedConnection, + .serverClosedConnection: return PostgresError.connectionClosed case .connectionError: return self.underlying ?? self diff --git a/Tests/PostgresNIOTests/New/Connection State Machine/ConnectionStateMachineTests.swift b/Tests/PostgresNIOTests/New/Connection State Machine/ConnectionStateMachineTests.swift index 5fd3bc20..f3d72a5e 100644 --- a/Tests/PostgresNIOTests/New/Connection State Machine/ConnectionStateMachineTests.swift +++ b/Tests/PostgresNIOTests/New/Connection State Machine/ConnectionStateMachineTests.swift @@ -137,14 +137,14 @@ class ConnectionStateMachineTests: XCTestCase { func testErrorIsIgnoredWhenClosingConnection() { // test ignore unclean shutdown when closing connection - var stateIgnoreChannelError = ConnectionStateMachine(.closing) - + var stateIgnoreChannelError = ConnectionStateMachine(.closing(nil)) + XCTAssertEqual(stateIgnoreChannelError.errorHappened(.connectionError(underlying: NIOSSLError.uncleanShutdown)), .wait) XCTAssertEqual(stateIgnoreChannelError.closed(), .fireChannelInactive) // test ignore any other error when closing connection - var stateIgnoreErrorMessage = ConnectionStateMachine(.closing) + var stateIgnoreErrorMessage = ConnectionStateMachine(.closing(nil)) XCTAssertEqual(stateIgnoreErrorMessage.errorReceived(.init(fields: [:])), .wait) XCTAssertEqual(stateIgnoreErrorMessage.closed(), .fireChannelInactive) } diff --git a/Tests/PostgresNIOTests/New/PSQLRowStreamTests.swift b/Tests/PostgresNIOTests/New/PSQLRowStreamTests.swift index 1af35fac..d6d03107 100644 --- a/Tests/PostgresNIOTests/New/PSQLRowStreamTests.swift +++ b/Tests/PostgresNIOTests/New/PSQLRowStreamTests.swift @@ -22,13 +22,13 @@ final class PSQLRowStreamTests: XCTestCase { func testFailedStream() { let stream = PSQLRowStream( - source: .noRows(.failure(PSQLError.connectionClosed)), + source: .noRows(.failure(PSQLError.serverClosedConnection(underlying: nil))), eventLoop: self.eventLoop, logger: self.logger ) XCTAssertThrowsError(try stream.all().wait()) { - XCTAssertEqual($0 as? PSQLError, .connectionClosed) + XCTAssertEqual($0 as? PSQLError, .serverClosedConnection(underlying: nil)) } } diff --git a/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift b/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift index 5388e8b5..eed5ada7 100644 --- a/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift @@ -24,8 +24,11 @@ class PostgresChannelHandlerTests: XCTestCase { ReverseMessageToByteHandler(PSQLBackendMessageEncoder()), handler ], loop: self.eventLoop) - defer { XCTAssertNoThrow(try embedded.finish()) } - + defer { + do { try embedded.finish() } + catch { print("\(String(reflecting: error))") } + } + var maybeMessage: PostgresFrontendMessage? XCTAssertNoThrow(embedded.connect(to: try .init(ipAddress: "0.0.0.0", port: 5432), promise: nil)) XCTAssertNoThrow(maybeMessage = try embedded.readOutbound(as: PostgresFrontendMessage.self)) diff --git a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift index 0622d51e..46f864ce 100644 --- a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift @@ -182,6 +182,98 @@ class PostgresConnectionTests: XCTestCase { } } + func testCloseGracefullyClosesWhenInternalQueueIsEmpty() async throws { + let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() + try await withThrowingTaskGroup(of: Void.self) { taskGroup async throws -> () in + for _ in 1...2 { + taskGroup.addTask { + let rows = try await connection.query("SELECT 1;", logger: self.logger) + var iterator = rows.decode(Int.self).makeAsyncIterator() + let first = try await iterator.next() + XCTAssertEqual(first, 1) + let second = try await iterator.next() + XCTAssertNil(second) + } + } + + for i in 0...1 { + let listenMessage = try await channel.waitForUnpreparedRequest() + XCTAssertEqual(listenMessage.parse.query, "SELECT 1;") + + if i == 0 { + taskGroup.addTask { + try await connection.closeGracefully() + } + } + + try await channel.writeInbound(PostgresBackendMessage.parseComplete) + try await channel.writeInbound(PostgresBackendMessage.parameterDescription(.init(dataTypes: []))) + let intDescription = RowDescription.Column( + name: "", + tableOID: 0, + columnAttributeNumber: 0, + dataType: .int8, dataTypeSize: 8, dataTypeModifier: 0, format: .binary + ) + try await channel.writeInbound(PostgresBackendMessage.rowDescription(.init(columns: [intDescription]))) + try await channel.testingEventLoop.executeInContext { channel.read() } + try await channel.writeInbound(PostgresBackendMessage.bindComplete) + try await channel.testingEventLoop.executeInContext { channel.read() } + try await channel.writeInbound(PostgresBackendMessage.dataRow([Int(1)])) + try await channel.testingEventLoop.executeInContext { channel.read() } + try await channel.writeInbound(PostgresBackendMessage.commandComplete("SELECT 1 1")) + try await channel.testingEventLoop.executeInContext { channel.read() } + try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) + } + + let terminate = try await channel.waitForOutboundWrite(as: PostgresFrontendMessage.self) + XCTAssertEqual(terminate, .terminate) + try await channel.closeFuture.get() + XCTAssertEqual(channel.isActive, false) + + while let taskResult = await taskGroup.nextResult() { + switch taskResult { + case .success: + break + case .failure(let failure): + XCTFail("Unexpected error: \(failure)") + } + } + } + } + + func testCloseClosesImmediatly() async throws { + let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() + + try await withThrowingTaskGroup(of: Void.self) { taskGroup async throws -> () in + for _ in 1...2 { + taskGroup.addTask { + try await connection.query("SELECT 1;", logger: self.logger) + } + } + + let listenMessage = try await channel.waitForUnpreparedRequest() + XCTAssertEqual(listenMessage.parse.query, "SELECT 1;") + + async let close: () = connection.close() + + try await channel.closeFuture.get() + XCTAssertEqual(channel.isActive, false) + + try await close + + while let taskResult = await taskGroup.nextResult() { + switch taskResult { + case .success: + XCTFail("Expected queries to fail") + case .failure(let failure): + guard let error = failure as? PSQLError else { + return XCTFail("Unexpected error type: \(failure)") + } + XCTAssertEqual(error.code, .clientClosedConnection) + } + } + } + } func makeTestConnectionWithAsyncTestingChannel() async throws -> (PostgresConnection, NIOAsyncTestingChannel) { let eventLoop = NIOAsyncTestingEventLoop() diff --git a/Tests/PostgresNIOTests/New/PostgresRowSequenceTests.swift b/Tests/PostgresNIOTests/New/PostgresRowSequenceTests.swift index fc589c0b..872c098d 100644 --- a/Tests/PostgresNIOTests/New/PostgresRowSequenceTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresRowSequenceTests.swift @@ -183,7 +183,7 @@ final class PostgresRowSequenceTests: XCTestCase { logger: self.logger ) - stream.receive(completion: .failure(PSQLError.connectionClosed)) + stream.receive(completion: .failure(PSQLError.serverClosedConnection(underlying: nil))) let rowSequence = stream.asyncSequence() @@ -194,7 +194,7 @@ final class PostgresRowSequenceTests: XCTestCase { } XCTFail("Expected that an error was thrown before.") } catch { - XCTAssertEqual(error as? PSQLError, .connectionClosed) + XCTAssertEqual(error as? PSQLError, .serverClosedConnection(underlying: nil)) } } @@ -255,14 +255,14 @@ final class PostgresRowSequenceTests: XCTestCase { XCTAssertEqual(try row1?.decode(Int.self, context: .default), 0) DispatchQueue.main.asyncAfter(deadline: .now() + .seconds(1)) { - stream.receive(completion: .failure(PSQLError.connectionClosed)) + stream.receive(completion: .failure(PSQLError.serverClosedConnection(underlying: nil))) } do { _ = try await rowIterator.next() XCTFail("Expected that an error was thrown before.") } catch { - XCTAssertEqual(error as? PSQLError, .connectionClosed) + XCTAssertEqual(error as? PSQLError, .serverClosedConnection(underlying: nil)) } } From 5217ba7557f8aa292fcf5f0440bfc2bed7862efb Mon Sep 17 00:00:00 2001 From: Gwynne Raskind Date: Thu, 10 Aug 2023 06:16:17 -0500 Subject: [PATCH 015/106] Use README header image compatible with light/dark mode (#393) --- README.md | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index 441a41e3..b4f8f70e 100644 --- a/README.md +++ b/README.md @@ -1,14 +1,18 @@ -PostgresNIO - -[![SSWG Incubating Badge](https://img.shields.io/badge/sswg-incubating-green.svg)][SSWG Incubation] -[![Documentation](http://img.shields.io/badge/read_the-docs-2196f3.svg)][Documentation] -[![Team Chat](https://img.shields.io/discord/431917998102675485.svg)][Team Chat] -[![MIT License](http://img.shields.io/badge/license-MIT-brightgreen.svg)][MIT License] -[![Continuous Integration](https://github.com/vapor/postgres-nio/actions/workflows/test.yml/badge.svg)][Continuous Integration] -[![Swift 5.6](http://img.shields.io/badge/swift-5.6-brightgreen.svg)][Swift 5.6] +

+ + + + PostgresNIO +

- +SSWG Incubation +Documentation +MIT License +Continuous Integration +Swift 5.6 +

+
🐘 Non-blocking, event-driven Swift client for PostgreSQL built on [SwiftNIO]. Features: From d5c52584cb3f19b3166040e05271f7581b0befa3 Mon Sep 17 00:00:00 2001 From: Mario Sangiorgio Date: Fri, 18 Aug 2023 11:12:18 +0100 Subject: [PATCH 016/106] async/await prepared statement API (#390) This patch adds a new `PreparedStatement` protocol to represent prepared statements and an `execute` function on `PostgresConnection` to prepare and execute statements. To implement the features the patch also introduces a `PreparedStatementStateMachine` that keeps track of the state of a prepared statement at the connection level. This ensures that, for each connection, each statement is prepared once at time of first use and then subsequent uses are going to skip the preparation step and just execute it. ## Example usage First define the struct to represent the prepared statement: ```swift struct ExamplePreparedStatement: PreparedStatement { static let sql = "SELECT pid, datname FROM pg_stat_activity WHERE state = $1" typealias Row = (Int, String) var state: String func makeBindings() -> PostgresBindings { var bindings = PostgresBindings() bindings.append(self.state) return bindings } func decodeRow(_ row: PostgresNIO.PostgresRow) throws -> Row { try row.decode(Row.self) } } ``` then, assuming you already have a `PostgresConnection` you can execute it: ```swift let preparedStatement = ExamplePreparedStatement(state: "active") let results = try await connection.execute(preparedStatement, logger: logger) for (pid, database) in results { print("PID: \(pid), database: \(database)") } ``` --------- Co-authored-by: Fabian Fett --- .../Connection/PostgresConnection.swift | 66 ++++ .../PreparedStatementStateMachine.swift | 93 +++++ Sources/PostgresNIO/New/PSQLTask.swift | 23 ++ .../New/PostgresChannelHandler.swift | 115 +++++- Sources/PostgresNIO/New/PostgresQuery.swift | 10 + .../PostgresNIO/New/PreparedStatement.swift | 40 ++ Tests/IntegrationTests/AsyncTests.swift | 42 +++ .../PreparedStatementStateMachineTests.swift | 159 ++++++++ .../PSQLFrontendMessageDecoder.swift | 2 +- .../New/PostgresConnectionTests.swift | 352 ++++++++++++++++++ 10 files changed, 898 insertions(+), 4 deletions(-) create mode 100644 Sources/PostgresNIO/New/Connection State Machine/PreparedStatementStateMachine.swift create mode 100644 Sources/PostgresNIO/New/PreparedStatement.swift create mode 100644 Tests/PostgresNIOTests/New/Connection State Machine/PreparedStatementStateMachineTests.swift diff --git a/Sources/PostgresNIO/Connection/PostgresConnection.swift b/Sources/PostgresNIO/Connection/PostgresConnection.swift index 7ac8ec57..d3f51ca9 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection.swift @@ -460,6 +460,72 @@ extension PostgresConnection { self.channel.write(task, promise: nil) } } + + /// Execute a prepared statement, taking care of the preparation when necessary + public func execute( + _ preparedStatement: Statement, + logger: Logger, + file: String = #fileID, + line: Int = #line + ) async throws -> AsyncThrowingMapSequence where Row == Statement.Row { + let bindings = try preparedStatement.makeBindings() + let promise = self.channel.eventLoop.makePromise(of: PSQLRowStream.self) + let task = HandlerTask.executePreparedStatement(.init( + name: String(reflecting: Statement.self), + sql: Statement.sql, + bindings: bindings, + logger: logger, + promise: promise + )) + self.channel.write(task, promise: nil) + do { + return try await promise.futureResult + .map { $0.asyncSequence() } + .get() + .map { try preparedStatement.decodeRow($0) } + } catch var error as PSQLError { + error.file = file + error.line = line + error.query = .init( + unsafeSQL: Statement.sql, + binds: bindings + ) + throw error // rethrow with more metadata + } + + } + + /// Execute a prepared statement, taking care of the preparation when necessary + public func execute( + _ preparedStatement: Statement, + logger: Logger, + file: String = #fileID, + line: Int = #line + ) async throws -> String where Statement.Row == () { + let bindings = try preparedStatement.makeBindings() + let promise = self.channel.eventLoop.makePromise(of: PSQLRowStream.self) + let task = HandlerTask.executePreparedStatement(.init( + name: String(reflecting: Statement.self), + sql: Statement.sql, + bindings: bindings, + logger: logger, + promise: promise + )) + self.channel.write(task, promise: nil) + do { + return try await promise.futureResult + .map { $0.commandTag } + .get() + } catch var error as PSQLError { + error.file = file + error.line = line + error.query = .init( + unsafeSQL: Statement.sql, + binds: bindings + ) + throw error // rethrow with more metadata + } + } } // MARK: EventLoopFuture interface diff --git a/Sources/PostgresNIO/New/Connection State Machine/PreparedStatementStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/PreparedStatementStateMachine.swift new file mode 100644 index 00000000..5afa4d0b --- /dev/null +++ b/Sources/PostgresNIO/New/Connection State Machine/PreparedStatementStateMachine.swift @@ -0,0 +1,93 @@ +import NIOCore + +struct PreparedStatementStateMachine { + enum State { + case preparing([PreparedStatementContext]) + case prepared(RowDescription?) + case error(PSQLError) + } + + var preparedStatements: [String: State] = [:] + + enum LookupAction { + case prepareStatement + case waitForAlreadyInFlightPreparation + case executeStatement(RowDescription?) + case returnError(PSQLError) + } + + mutating func lookup(preparedStatement: PreparedStatementContext) -> LookupAction { + if let state = self.preparedStatements[preparedStatement.name] { + switch state { + case .preparing(var statements): + statements.append(preparedStatement) + self.preparedStatements[preparedStatement.name] = .preparing(statements) + return .waitForAlreadyInFlightPreparation + case .prepared(let rowDescription): + return .executeStatement(rowDescription) + case .error(let error): + return .returnError(error) + } + } else { + self.preparedStatements[preparedStatement.name] = .preparing([preparedStatement]) + return .prepareStatement + } + } + + struct PreparationCompleteAction { + var statements: [PreparedStatementContext] + var rowDescription: RowDescription? + } + + mutating func preparationComplete( + name: String, + rowDescription: RowDescription? + ) -> PreparationCompleteAction { + guard let state = self.preparedStatements[name] else { + fatalError("Unknown prepared statement \(name)") + } + switch state { + case .preparing(let statements): + // When sending the bindings we are going to ask for binary data. + if var rowDescription = rowDescription { + for i in 0.. ErrorHappenedAction { + guard let state = self.preparedStatements[name] else { + fatalError("Unknown prepared statement \(name)") + } + switch state { + case .preparing(let statements): + self.preparedStatements[name] = .error(error) + return ErrorHappenedAction( + statements: statements, + error: error + ) + case .prepared, .error: + preconditionFailure("Error happened in an unexpected state \(state)") + } + } +} diff --git a/Sources/PostgresNIO/New/PSQLTask.swift b/Sources/PostgresNIO/New/PSQLTask.swift index f5de6561..9425c12b 100644 --- a/Sources/PostgresNIO/New/PSQLTask.swift +++ b/Sources/PostgresNIO/New/PSQLTask.swift @@ -6,6 +6,7 @@ enum HandlerTask { case closeCommand(CloseCommandContext) case startListening(NotificationListener) case cancelListening(String, Int) + case executePreparedStatement(PreparedStatementContext) } enum PSQLTask { @@ -69,6 +70,28 @@ final class ExtendedQueryContext { } } +final class PreparedStatementContext{ + let name: String + let sql: String + let bindings: PostgresBindings + let logger: Logger + let promise: EventLoopPromise + + init( + name: String, + sql: String, + bindings: PostgresBindings, + logger: Logger, + promise: EventLoopPromise + ) { + self.name = name + self.sql = sql + self.bindings = bindings + self.logger = logger + self.promise = promise + } +} + final class CloseCommandContext { let target: CloseTarget let logger: Logger diff --git a/Sources/PostgresNIO/New/PostgresChannelHandler.swift b/Sources/PostgresNIO/New/PostgresChannelHandler.swift index 7801d4d6..bf56d6d1 100644 --- a/Sources/PostgresNIO/New/PostgresChannelHandler.swift +++ b/Sources/PostgresNIO/New/PostgresChannelHandler.swift @@ -22,7 +22,8 @@ final class PostgresChannelHandler: ChannelDuplexHandler { private let configuration: PostgresConnection.InternalConfiguration private let configureSSLCallback: ((Channel) throws -> Void)? - private var listenState: ListenStateMachine + private var listenState = ListenStateMachine() + private var preparedStatementState = PreparedStatementStateMachine() init( configuration: PostgresConnection.InternalConfiguration, @@ -32,7 +33,6 @@ final class PostgresChannelHandler: ChannelDuplexHandler { ) { self.state = ConnectionStateMachine(requireBackendKeyData: configuration.options.requireBackendKeyData) self.eventLoop = eventLoop - self.listenState = ListenStateMachine() self.configuration = configuration self.configureSSLCallback = configureSSLCallback self.logger = logger @@ -50,7 +50,6 @@ final class PostgresChannelHandler: ChannelDuplexHandler { ) { self.state = state self.eventLoop = eventLoop - self.listenState = ListenStateMachine() self.configuration = configuration self.configureSSLCallback = configureSSLCallback self.logger = logger @@ -233,6 +232,29 @@ final class PostgresChannelHandler: ChannelDuplexHandler { listener.failed(CancellationError()) return } + case .executePreparedStatement(let preparedStatement): + let action = self.preparedStatementState.lookup( + preparedStatement: preparedStatement + ) + switch action { + case .prepareStatement: + psqlTask = self.makePrepareStatementTask( + preparedStatement: preparedStatement, + context: context + ) + case .waitForAlreadyInFlightPreparation: + // The state machine already keeps track of this + // and will execute the statement as soon as it's prepared + return + case .executeStatement(let rowDescription): + psqlTask = self.makeExecutePreparedStatementTask( + preparedStatement: preparedStatement, + rowDescription: rowDescription + ) + case .returnError(let error): + preparedStatement.promise.fail(error) + return + } } let action = self.state.enqueue(task: psqlTask) @@ -664,6 +686,93 @@ final class PostgresChannelHandler: ChannelDuplexHandler { } } + private func makePrepareStatementTask( + preparedStatement: PreparedStatementContext, + context: ChannelHandlerContext + ) -> PSQLTask { + let promise = self.eventLoop.makePromise(of: RowDescription?.self) + promise.futureResult.whenComplete { result in + switch result { + case .success(let rowDescription): + self.prepareStatementComplete( + name: preparedStatement.name, + rowDescription: rowDescription, + context: context + ) + case .failure(let error): + let psqlError: PSQLError + if let error = error as? PSQLError { + psqlError = error + } else { + psqlError = .connectionError(underlying: error) + } + self.prepareStatementFailed( + name: preparedStatement.name, + error: psqlError, + context: context + ) + } + } + return .extendedQuery(.init( + name: preparedStatement.name, + query: preparedStatement.sql, + logger: preparedStatement.logger, + promise: promise + )) + } + + private func makeExecutePreparedStatementTask( + preparedStatement: PreparedStatementContext, + rowDescription: RowDescription? + ) -> PSQLTask { + return .extendedQuery(.init( + executeStatement: .init( + name: preparedStatement.name, + binds: preparedStatement.bindings, + rowDescription: rowDescription + ), + logger: preparedStatement.logger, + promise: preparedStatement.promise + )) + } + + private func prepareStatementComplete( + name: String, + rowDescription: RowDescription?, + context: ChannelHandlerContext + ) { + let action = self.preparedStatementState.preparationComplete( + name: name, + rowDescription: rowDescription + ) + for preparedStatement in action.statements { + let action = self.state.enqueue(task: .extendedQuery(.init( + executeStatement: .init( + name: preparedStatement.name, + binds: preparedStatement.bindings, + rowDescription: action.rowDescription + ), + logger: preparedStatement.logger, + promise: preparedStatement.promise + )) + ) + self.run(action, with: context) + } + } + + private func prepareStatementFailed( + name: String, + error: PSQLError, + context: ChannelHandlerContext + ) { + let action = self.preparedStatementState.errorHappened( + name: name, + error: error + ) + for statement in action.statements { + statement.promise.fail(action.error) + } + } } extension PostgresChannelHandler: PSQLRowsDataSource { diff --git a/Sources/PostgresNIO/New/PostgresQuery.swift b/Sources/PostgresNIO/New/PostgresQuery.swift index 2e06e1d9..4ca1e454 100644 --- a/Sources/PostgresNIO/New/PostgresQuery.swift +++ b/Sources/PostgresNIO/New/PostgresQuery.swift @@ -167,6 +167,11 @@ public struct PostgresBindings: Sendable, Hashable { self.metadata.append(.init(dataType: .null, format: .binary, protected: true)) } + @inlinable + public mutating func append(_ value: Value) throws { + try self.append(value, context: .default) + } + @inlinable public mutating func append( _ value: Value, @@ -176,6 +181,11 @@ public struct PostgresBindings: Sendable, Hashable { self.metadata.append(.init(value: value, protected: true)) } + @inlinable + public mutating func append(_ value: Value) { + self.append(value, context: .default) + } + @inlinable public mutating func append( _ value: Value, diff --git a/Sources/PostgresNIO/New/PreparedStatement.swift b/Sources/PostgresNIO/New/PreparedStatement.swift new file mode 100644 index 00000000..1e0b5d5a --- /dev/null +++ b/Sources/PostgresNIO/New/PreparedStatement.swift @@ -0,0 +1,40 @@ +/// A prepared statement. +/// +/// Structs conforming to this protocol will need to provide the SQL statement to +/// send to the server and a way of creating bindings are decoding the result. +/// +/// As an example, consider this struct: +/// ```swift +/// struct Example: PostgresPreparedStatement { +/// static let sql = "SELECT pid, datname FROM pg_stat_activity WHERE state = $1" +/// typealias Row = (Int, String) +/// +/// var state: String +/// +/// func makeBindings() -> PostgresBindings { +/// var bindings = PostgresBindings() +/// bindings.append(self.state) +/// return bindings +/// } +/// +/// func decodeRow(_ row: PostgresNIO.PostgresRow) throws -> Row { +/// try row.decode(Row.self) +/// } +/// } +/// ``` +/// +/// Structs conforming to this protocol can then be used with `PostgresConnection.execute(_ preparedStatement:, logger:)`, +/// which will take care of preparing the statement on the server side and executing it. +public protocol PostgresPreparedStatement: Sendable { + /// The type rows returned by the statement will be decoded into + associatedtype Row + + /// The SQL statement to prepare on the database server. + static var sql: String { get } + + /// Make the bindings to provided concrete values to use when executing the prepared SQL statement + func makeBindings() throws -> PostgresBindings + + /// Decode a row returned by the database into an instance of `Row` + func decodeRow(_ row: PostgresRow) throws -> Row +} diff --git a/Tests/IntegrationTests/AsyncTests.swift b/Tests/IntegrationTests/AsyncTests.swift index f68ef1f3..bf945a67 100644 --- a/Tests/IntegrationTests/AsyncTests.swift +++ b/Tests/IntegrationTests/AsyncTests.swift @@ -315,6 +315,48 @@ final class AsyncPostgresConnectionTests: XCTestCase { try await connection.query("SELECT 1;", logger: .psqlTest) } } + + func testPreparedStatement() async throws { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + let eventLoop = eventLoopGroup.next() + + struct TestPreparedStatement: PostgresPreparedStatement { + static var sql = "SELECT pid, datname FROM pg_stat_activity WHERE state = $1" + typealias Row = (Int, String) + + var state: String + + func makeBindings() -> PostgresBindings { + var bindings = PostgresBindings() + bindings.append(self.state) + return bindings + } + + func decodeRow(_ row: PostgresNIO.PostgresRow) throws -> Row { + try row.decode(Row.self) + } + } + let preparedStatement = TestPreparedStatement(state: "active") + try await withTestConnection(on: eventLoop) { connection in + var results = try await connection.execute(preparedStatement, logger: .psqlTest) + var counter = 0 + + for try await element in results { + XCTAssertEqual(element.1, env("POSTGRES_DB") ?? "test_database") + counter += 1 + } + + XCTAssertGreaterThanOrEqual(counter, 1) + + // Second execution, which reuses the existing prepared statement + results = try await connection.execute(preparedStatement, logger: .psqlTest) + for try await element in results { + XCTAssertEqual(element.1, env("POSTGRES_DB") ?? "test_database") + counter += 1 + } + } + } } extension XCTestCase { diff --git a/Tests/PostgresNIOTests/New/Connection State Machine/PreparedStatementStateMachineTests.swift b/Tests/PostgresNIOTests/New/Connection State Machine/PreparedStatementStateMachineTests.swift new file mode 100644 index 00000000..ab77a57c --- /dev/null +++ b/Tests/PostgresNIOTests/New/Connection State Machine/PreparedStatementStateMachineTests.swift @@ -0,0 +1,159 @@ +import XCTest +import NIOEmbedded +@testable import PostgresNIO + +class PreparedStatementStateMachineTests: XCTestCase { + func testPrepareAndExecuteStatement() { + let eventLoop = EmbeddedEventLoop() + var stateMachine = PreparedStatementStateMachine() + + let firstPreparedStatement = self.makePreparedStatementContext(eventLoop: eventLoop) + // Initial lookup, the statement hasn't been prepared yet + let lookupAction = stateMachine.lookup(preparedStatement: firstPreparedStatement) + guard case .preparing = stateMachine.preparedStatements["test"] else { + XCTFail("State machine in the wrong state") + return + } + guard case .prepareStatement = lookupAction else { + XCTFail("State machine returned the wrong action") + return + } + + // Once preparation is complete we transition to a prepared state + let preparationCompleteAction = stateMachine.preparationComplete(name: "test", rowDescription: nil) + guard case .prepared(nil) = stateMachine.preparedStatements["test"] else { + XCTFail("State machine in the wrong state") + return + } + XCTAssertEqual(preparationCompleteAction.statements.count, 1) + XCTAssertNil(preparationCompleteAction.rowDescription) + firstPreparedStatement.promise.succeed(PSQLRowStream( + source: .noRows(.success("tag")), + eventLoop: eventLoop, + logger: .psqlTest + )) + + // Create a new prepared statement + let secondPreparedStatement = self.makePreparedStatementContext(eventLoop: eventLoop) + // The statement is already preparead, lookups tell us to execute it + let secondLookupAction = stateMachine.lookup(preparedStatement: secondPreparedStatement) + guard case .prepared(nil) = stateMachine.preparedStatements["test"] else { + XCTFail("State machine in the wrong state") + return + } + guard case .executeStatement(nil) = secondLookupAction else { + XCTFail("State machine returned the wrong action") + return + } + secondPreparedStatement.promise.succeed(PSQLRowStream( + source: .noRows(.success("tag")), + eventLoop: eventLoop, + logger: .psqlTest + )) + } + + func testPrepareAndExecuteStatementWithError() { + let eventLoop = EmbeddedEventLoop() + var stateMachine = PreparedStatementStateMachine() + + let firstPreparedStatement = self.makePreparedStatementContext(eventLoop: eventLoop) + // Initial lookup, the statement hasn't been prepared yet + let lookupAction = stateMachine.lookup(preparedStatement: firstPreparedStatement) + guard case .preparing = stateMachine.preparedStatements["test"] else { + XCTFail("State machine in the wrong state") + return + } + guard case .prepareStatement = lookupAction else { + XCTFail("State machine returned the wrong action") + return + } + + // Simulate an error occurring during preparation + let error = PSQLError(code: .server) + let preparationCompleteAction = stateMachine.errorHappened( + name: "test", + error: error + ) + guard case .error = stateMachine.preparedStatements["test"] else { + XCTFail("State machine in the wrong state") + return + } + XCTAssertEqual(preparationCompleteAction.statements.count, 1) + firstPreparedStatement.promise.fail(error) + + // Create a new prepared statement + let secondPreparedStatement = self.makePreparedStatementContext(eventLoop: eventLoop) + // Ensure that we don't try again to prepare a statement we know will fail + let secondLookupAction = stateMachine.lookup(preparedStatement: secondPreparedStatement) + guard case .error = stateMachine.preparedStatements["test"] else { + XCTFail("State machine in the wrong state") + return + } + guard case .returnError = secondLookupAction else { + XCTFail("State machine returned the wrong action") + return + } + secondPreparedStatement.promise.fail(error) + } + + func testBatchStatementPreparation() { + let eventLoop = EmbeddedEventLoop() + var stateMachine = PreparedStatementStateMachine() + + let firstPreparedStatement = self.makePreparedStatementContext(eventLoop: eventLoop) + // Initial lookup, the statement hasn't been prepared yet + let lookupAction = stateMachine.lookup(preparedStatement: firstPreparedStatement) + guard case .preparing = stateMachine.preparedStatements["test"] else { + XCTFail("State machine in the wrong state") + return + } + guard case .prepareStatement = lookupAction else { + XCTFail("State machine returned the wrong action") + return + } + + // A new request comes in before the statement completes + let secondPreparedStatement = self.makePreparedStatementContext(eventLoop: eventLoop) + let secondLookupAction = stateMachine.lookup(preparedStatement: secondPreparedStatement) + guard case .preparing = stateMachine.preparedStatements["test"] else { + XCTFail("State machine in the wrong state") + return + } + guard case .waitForAlreadyInFlightPreparation = secondLookupAction else { + XCTFail("State machine returned the wrong action") + return + } + + // Once preparation is complete we transition to a prepared state. + // The action tells us to execute both the pending statements. + let preparationCompleteAction = stateMachine.preparationComplete(name: "test", rowDescription: nil) + guard case .prepared(nil) = stateMachine.preparedStatements["test"] else { + XCTFail("State machine in the wrong state") + return + } + XCTAssertEqual(preparationCompleteAction.statements.count, 2) + XCTAssertNil(preparationCompleteAction.rowDescription) + + firstPreparedStatement.promise.succeed(PSQLRowStream( + source: .noRows(.success("tag")), + eventLoop: eventLoop, + logger: .psqlTest + )) + secondPreparedStatement.promise.succeed(PSQLRowStream( + source: .noRows(.success("tag")), + eventLoop: eventLoop, + logger: .psqlTest + )) + } + + private func makePreparedStatementContext(eventLoop: EmbeddedEventLoop) -> PreparedStatementContext { + let promise = eventLoop.makePromise(of: PSQLRowStream.self) + return PreparedStatementContext( + name: "test", + sql: "INSERT INTO test_table (column1) VALUES (1)", + bindings: PostgresBindings(), + logger: .psqlTest, + promise: promise + ) + } +} diff --git a/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift b/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift index b9677000..46c043b1 100644 --- a/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift +++ b/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift @@ -142,7 +142,7 @@ extension PostgresFrontendMessage { } let parameters = (0.. ByteBuffer? in - let length = buffer.readInteger(as: UInt16.self) + let length = buffer.readInteger(as: UInt32.self) switch length { case .some(..<0): return nil diff --git a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift index 46f864ce..9c4dc5cb 100644 --- a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift @@ -275,6 +275,288 @@ class PostgresConnectionTests: XCTestCase { } } + struct TestPrepareStatement: PostgresPreparedStatement { + static var sql = "SELECT datname FROM pg_stat_activity WHERE state = $1" + typealias Row = String + + var state: String + + func makeBindings() -> PostgresBindings { + var bindings = PostgresBindings() + bindings.append(.init(string: self.state)) + return bindings + } + + func decodeRow(_ row: PostgresNIO.PostgresRow) throws -> Row { + try row.decode(Row.self) + } + } + + func testPreparedStatement() async throws { + let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() + + try await withThrowingTaskGroup(of: Void.self) { taskGroup async throws -> () in + taskGroup.addTask { + let preparedStatement = TestPrepareStatement(state: "active") + let result = try await connection.execute(preparedStatement, logger: .psqlTest) + var rows = 0 + for try await database in result { + rows += 1 + XCTAssertEqual("test_database", database) + } + XCTAssertEqual(rows, 1) + } + + let prepareRequest = try await channel.waitForPrepareRequest() + XCTAssertEqual(prepareRequest.parse.query, "SELECT datname FROM pg_stat_activity WHERE state = $1") + XCTAssertEqual(prepareRequest.parse.parameters.count, 0) + guard case .preparedStatement(let name) = prepareRequest.describe else { + fatalError("Describe should contain a prepared statement") + } + XCTAssertEqual(name, String(reflecting: TestPrepareStatement.self)) + + try await channel.sendPrepareResponse( + parameterDescription: .init(dataTypes: [ + PostgresDataType.text + ]), + rowDescription: .init(columns: [ + .init( + name: "datname", + tableOID: 12222, + columnAttributeNumber: 2, + dataType: .name, + dataTypeSize: 64, + dataTypeModifier: -1, + format: .text + ) + ]) + ) + + let preparedRequest = try await channel.waitForPreparedRequest() + XCTAssertEqual(preparedRequest.bind.preparedStatementName, String(reflecting: TestPrepareStatement.self)) + XCTAssertEqual(preparedRequest.bind.parameters.count, 1) + XCTAssertEqual(preparedRequest.bind.resultColumnFormats, [.binary]) + + try await channel.sendPreparedResponse( + dataRows: [ + ["test_database"] + ], + commandTag: TestPrepareStatement.sql + ) + } + } + + func testSerialExecutionOfSamePreparedStatement() async throws { + let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() + + try await withThrowingTaskGroup(of: Void.self) { taskGroup async throws -> () in + // Send the same prepared statement twice, but with different parameters. + // Send one first and wait to send the other request until preparation is complete + taskGroup.addTask { + let preparedStatement = TestPrepareStatement(state: "active") + let result = try await connection.execute(preparedStatement, logger: .psqlTest) + var rows = 0 + for try await database in result { + rows += 1 + XCTAssertEqual("test_database", database) + } + XCTAssertEqual(rows, 1) + } + + let prepareRequest = try await channel.waitForPrepareRequest() + XCTAssertEqual(prepareRequest.parse.query, "SELECT datname FROM pg_stat_activity WHERE state = $1") + XCTAssertEqual(prepareRequest.parse.parameters.count, 0) + guard case .preparedStatement(let name) = prepareRequest.describe else { + fatalError("Describe should contain a prepared statement") + } + XCTAssertEqual(name, String(reflecting: TestPrepareStatement.self)) + + try await channel.sendPrepareResponse( + parameterDescription: .init(dataTypes: [ + PostgresDataType.text + ]), + rowDescription: .init(columns: [ + .init( + name: "datname", + tableOID: 12222, + columnAttributeNumber: 2, + dataType: .name, + dataTypeSize: 64, + dataTypeModifier: -1, + format: .text + ) + ]) + ) + + let preparedRequest1 = try await channel.waitForPreparedRequest() + var buffer = preparedRequest1.bind.parameters[0]! + let parameter1 = buffer.readString(length: buffer.readableBytes)! + XCTAssertEqual(parameter1, "active") + try await channel.sendPreparedResponse( + dataRows: [ + ["test_database"] + ], + commandTag: TestPrepareStatement.sql + ) + + // Now that the statement has been prepared and executed, send another request that will only get executed + // without preparation + taskGroup.addTask { + let preparedStatement = TestPrepareStatement(state: "idle") + let result = try await connection.execute(preparedStatement, logger: .psqlTest) + var rows = 0 + for try await database in result { + rows += 1 + XCTAssertEqual("test_database", database) + } + XCTAssertEqual(rows, 1) + } + + let preparedRequest2 = try await channel.waitForPreparedRequest() + buffer = preparedRequest2.bind.parameters[0]! + let parameter2 = buffer.readString(length: buffer.readableBytes)! + XCTAssertEqual(parameter2, "idle") + try await channel.sendPreparedResponse( + dataRows: [ + ["test_database"] + ], + commandTag: TestPrepareStatement.sql + ) + // Ensure we received and responded to both the requests + let parameters = [parameter1, parameter2] + XCTAssert(parameters.contains("active")) + XCTAssert(parameters.contains("idle")) + } + } + + func testStatementPreparationOnlyHappensOnceWithConcurrentRequests() async throws { + let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() + + try await withThrowingTaskGroup(of: Void.self) { taskGroup async throws -> () in + // Send the same prepared statement twice, but with different parameters. + // Let them race to tests that requests and responses aren't mixed up + taskGroup.addTask { + let preparedStatement = TestPrepareStatement(state: "active") + let result = try await connection.execute(preparedStatement, logger: .psqlTest) + var rows = 0 + for try await database in result { + rows += 1 + XCTAssertEqual("test_database_active", database) + } + XCTAssertEqual(rows, 1) + } + taskGroup.addTask { + let preparedStatement = TestPrepareStatement(state: "idle") + let result = try await connection.execute(preparedStatement, logger: .psqlTest) + var rows = 0 + for try await database in result { + rows += 1 + XCTAssertEqual("test_database_idle", database) + } + XCTAssertEqual(rows, 1) + } + + // The channel deduplicates prepare requests, we're going to see only one of them + let prepareRequest = try await channel.waitForPrepareRequest() + XCTAssertEqual(prepareRequest.parse.query, "SELECT datname FROM pg_stat_activity WHERE state = $1") + XCTAssertEqual(prepareRequest.parse.parameters.count, 0) + guard case .preparedStatement(let name) = prepareRequest.describe else { + fatalError("Describe should contain a prepared statement") + } + XCTAssertEqual(name, String(reflecting: TestPrepareStatement.self)) + + try await channel.sendPrepareResponse( + parameterDescription: .init(dataTypes: [ + PostgresDataType.text + ]), + rowDescription: .init(columns: [ + .init( + name: "datname", + tableOID: 12222, + columnAttributeNumber: 2, + dataType: .name, + dataTypeSize: 64, + dataTypeModifier: -1, + format: .text + ) + ]) + ) + + // Now both the tasks have their statements prepared. + // We should see both of their execute requests coming in, the order is nondeterministic + let preparedRequest1 = try await channel.waitForPreparedRequest() + var buffer = preparedRequest1.bind.parameters[0]! + let parameter1 = buffer.readString(length: buffer.readableBytes)! + try await channel.sendPreparedResponse( + dataRows: [ + ["test_database_\(parameter1)"] + ], + commandTag: TestPrepareStatement.sql + ) + let preparedRequest2 = try await channel.waitForPreparedRequest() + buffer = preparedRequest2.bind.parameters[0]! + let parameter2 = buffer.readString(length: buffer.readableBytes)! + try await channel.sendPreparedResponse( + dataRows: [ + ["test_database_\(parameter2)"] + ], + commandTag: TestPrepareStatement.sql + ) + // Ensure we received and responded to both the requests + let parameters = [parameter1, parameter2] + XCTAssert(parameters.contains("active")) + XCTAssert(parameters.contains("idle")) + } + } + + func testStatementPreparationFailure() async throws { + let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() + + try await withThrowingTaskGroup(of: Void.self) { taskGroup async throws -> () in + // Send the same prepared statement twice, but with different parameters. + // Send one first and wait to send the other request until preparation is complete + taskGroup.addTask { + let preparedStatement = TestPrepareStatement(state: "active") + do { + _ = try await connection.execute(preparedStatement, logger: .psqlTest) + XCTFail("Was supposed to fail") + } catch { + XCTAssert(error is PSQLError) + } + } + + let prepareRequest = try await channel.waitForPrepareRequest() + XCTAssertEqual(prepareRequest.parse.query, "SELECT datname FROM pg_stat_activity WHERE state = $1") + XCTAssertEqual(prepareRequest.parse.parameters.count, 0) + guard case .preparedStatement(let name) = prepareRequest.describe else { + fatalError("Describe should contain a prepared statement") + } + XCTAssertEqual(name, String(reflecting: TestPrepareStatement.self)) + + // Respond with an error taking care to return a SQLSTATE that isn't + // going to get the connection closed. + try await channel.writeInbound(PostgresBackendMessage.error(.init(fields: [ + .sqlState : "26000" // invalid_sql_statement_name + ]))) + try await channel.testingEventLoop.executeInContext { channel.read() } + try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) + try await channel.testingEventLoop.executeInContext { channel.read() } + + + // Send another requests with the same prepared statement, which should fail straight + // away without any interaction with the server + taskGroup.addTask { + let preparedStatement = TestPrepareStatement(state: "idle") + do { + _ = try await connection.execute(preparedStatement, logger: .psqlTest) + XCTFail("Was supposed to fail") + } catch { + XCTAssert(error is PSQLError) + } + } + } + } + func makeTestConnectionWithAsyncTestingChannel() async throws -> (PostgresConnection, NIOAsyncTestingChannel) { let eventLoop = NIOAsyncTestingEventLoop() let channel = await NIOAsyncTestingChannel(handlers: [ @@ -327,6 +609,66 @@ extension NIOAsyncTestingChannel { return UnpreparedRequest(parse: parse, describe: describe, bind: bind, execute: execute) } + + func waitForPrepareRequest() async throws -> PrepareRequest { + let parse = try await self.waitForOutboundWrite(as: PostgresFrontendMessage.self) + let describe = try await self.waitForOutboundWrite(as: PostgresFrontendMessage.self) + let sync = try await self.waitForOutboundWrite(as: PostgresFrontendMessage.self) + + guard case .parse(let parse) = parse, + case .describe(let describe) = describe, + case .sync = sync + else { + fatalError("Unexpected message") + } + + return PrepareRequest(parse: parse, describe: describe) + } + + func sendPrepareResponse( + parameterDescription: PostgresBackendMessage.ParameterDescription, + rowDescription: RowDescription + ) async throws { + try await self.writeInbound(PostgresBackendMessage.parseComplete) + try await self.testingEventLoop.executeInContext { self.read() } + try await self.writeInbound(PostgresBackendMessage.parameterDescription(parameterDescription)) + try await self.testingEventLoop.executeInContext { self.read() } + try await self.writeInbound(PostgresBackendMessage.rowDescription(rowDescription)) + try await self.testingEventLoop.executeInContext { self.read() } + try await self.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) + try await self.testingEventLoop.executeInContext { self.read() } + } + + func waitForPreparedRequest() async throws -> PreparedRequest { + let bind = try await self.waitForOutboundWrite(as: PostgresFrontendMessage.self) + let execute = try await self.waitForOutboundWrite(as: PostgresFrontendMessage.self) + let sync = try await self.waitForOutboundWrite(as: PostgresFrontendMessage.self) + + guard case .bind(let bind) = bind, + case .execute(let execute) = execute, + case .sync = sync + else { + fatalError() + } + + return PreparedRequest(bind: bind, execute: execute) + } + + func sendPreparedResponse( + dataRows: [DataRow], + commandTag: String + ) async throws { + try await self.writeInbound(PostgresBackendMessage.bindComplete) + try await self.testingEventLoop.executeInContext { self.read() } + for dataRow in dataRows { + try await self.writeInbound(PostgresBackendMessage.dataRow(dataRow)) + } + try await self.testingEventLoop.executeInContext { self.read() } + try await self.writeInbound(PostgresBackendMessage.commandComplete(commandTag)) + try await self.testingEventLoop.executeInContext { self.read() } + try await self.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) + try await self.testingEventLoop.executeInContext { self.read() } + } } struct UnpreparedRequest { @@ -335,3 +677,13 @@ struct UnpreparedRequest { var bind: PostgresFrontendMessage.Bind var execute: PostgresFrontendMessage.Execute } + +struct PrepareRequest { + var parse: PostgresFrontendMessage.Parse + var describe: PostgresFrontendMessage.Describe +} + +struct PreparedRequest { + var bind: PostgresFrontendMessage.Bind + var execute: PostgresFrontendMessage.Execute +} From ef3a00f9dfd79ad5cd40a0a9fa242e8d3169cf2f Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Fri, 18 Aug 2023 13:39:47 +0200 Subject: [PATCH 017/106] Cleanup encoding Startup message (#395) Further cleanup of message encoding: - Move Startup struct into PostgresFrontendMessageEncoder - Move PSQLMessagePayloadEncodable into tests, since it isn't used in PostgresNIO anymore - Only support the parameters that are actually used in encoding startup messages --- .../PostgresNIO/New/Messages/Startup.swift | 52 ------------ .../New/PostgresChannelHandler.swift | 13 +-- .../New/PostgresFrontendMessage.swift | 48 ++++++++++- .../New/PostgresFrontendMessageEncoder.swift | 22 +---- .../PSQLBackendMessageEncoder.swift | 4 + .../New/Messages/StartupTests.swift | 82 ++++++++----------- .../New/PostgresChannelHandlerTests.swift | 11 +++ 7 files changed, 98 insertions(+), 134 deletions(-) delete mode 100644 Sources/PostgresNIO/New/Messages/Startup.swift diff --git a/Sources/PostgresNIO/New/Messages/Startup.swift b/Sources/PostgresNIO/New/Messages/Startup.swift deleted file mode 100644 index 16d23e09..00000000 --- a/Sources/PostgresNIO/New/Messages/Startup.swift +++ /dev/null @@ -1,52 +0,0 @@ -import NIOCore - -extension PostgresFrontendMessage { - struct Startup: Hashable { - static let versionThree: Int32 = 0x00_03_00_00 - - /// Creates a `Startup` with "3.0" as the protocol version. - static func versionThree(parameters: Parameters) -> Startup { - return .init(protocolVersion: Self.versionThree, parameters: parameters) - } - - /// The protocol version number. The most significant 16 bits are the major - /// version number (3 for the protocol described here). The least significant - /// 16 bits are the minor version number (0 for the protocol described here). - var protocolVersion: Int32 - - /// The protocol version number is followed by one or more pairs of parameter - /// name and value strings. A zero byte is required as a terminator after - /// the last name/value pair. `user` is required, others are optional. - struct Parameters: Hashable { - enum Replication { - case `true` - case `false` - case database - } - - /// The database user name to connect as. Required; there is no default. - var user: String - - /// The database to connect to. Defaults to the user name. - var database: String? - - /// Command-line arguments for the backend. (This is deprecated in favor - /// of setting individual run-time parameters.) Spaces within this string are - /// considered to separate arguments, unless escaped with a - /// backslash (\); write \\ to represent a literal backslash. - var options: String? - - /// Used to connect in streaming replication mode, where a small set of - /// replication commands can be issued instead of SQL statements. Value - /// can be true, false, or database, and the default is false. - var replication: Replication - } - var parameters: Parameters - - /// Creates a new `PostgreSQLStartupMessage`. - init(protocolVersion: Int32, parameters: Parameters) { - self.protocolVersion = protocolVersion - self.parameters = parameters - } - } -} diff --git a/Sources/PostgresNIO/New/PostgresChannelHandler.swift b/Sources/PostgresNIO/New/PostgresChannelHandler.swift index bf56d6d1..7b31a776 100644 --- a/Sources/PostgresNIO/New/PostgresChannelHandler.swift +++ b/Sources/PostgresNIO/New/PostgresChannelHandler.swift @@ -328,7 +328,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler { case .wait: break case .sendStartupMessage(let authContext): - self.encoder.startup(authContext.toStartupParameters()) + self.encoder.startup(user: authContext.username, database: authContext.database) context.writeAndFlush(self.wrapOutboundOut(self.encoder.flushBuffer()), promise: nil) case .sendSSLRequest: self.encoder.ssl() @@ -793,17 +793,6 @@ extension PostgresChannelHandler: PSQLRowsDataSource { } } -extension AuthContext { - func toStartupParameters() -> PostgresFrontendMessage.Startup.Parameters { - PostgresFrontendMessage.Startup.Parameters( - user: self.username, - database: self.database, - options: nil, - replication: .false - ) - } -} - private extension Insecure.MD5.Digest { private static let lowercaseLookup: [UInt8] = [ diff --git a/Sources/PostgresNIO/New/PostgresFrontendMessage.swift b/Sources/PostgresNIO/New/PostgresFrontendMessage.swift index 2a7ec9f1..ef7ce8f8 100644 --- a/Sources/PostgresNIO/New/PostgresFrontendMessage.swift +++ b/Sources/PostgresNIO/New/PostgresFrontendMessage.swift @@ -102,6 +102,50 @@ enum PostgresFrontendMessage: Equatable { static let requestCode: Int32 = 80877103 } + struct Startup: Hashable { + static let versionThree: Int32 = 0x00_03_00_00 + + /// Creates a `Startup` with "3.0" as the protocol version. + static func versionThree(parameters: Parameters) -> Startup { + return .init(protocolVersion: Self.versionThree, parameters: parameters) + } + + /// The protocol version number. The most significant 16 bits are the major + /// version number (3 for the protocol described here). The least significant + /// 16 bits are the minor version number (0 for the protocol described here). + var protocolVersion: Int32 + + /// The protocol version number is followed by one or more pairs of parameter + /// name and value strings. A zero byte is required as a terminator after + /// the last name/value pair. `user` is required, others are optional. + struct Parameters: Hashable { + enum Replication { + case `true` + case `false` + case database + } + + /// The database user name to connect as. Required; there is no default. + var user: String + + /// The database to connect to. Defaults to the user name. + var database: String? + + /// Command-line arguments for the backend. (This is deprecated in favor + /// of setting individual run-time parameters.) Spaces within this string are + /// considered to separate arguments, unless escaped with a + /// backslash (\); write \\ to represent a literal backslash. + var options: String? + + /// Used to connect in streaming replication mode, where a small set of + /// replication commands can be issued instead of SQL statements. Value + /// can be true, false, or database, and the default is false. + var replication: Replication + } + + var parameters: Parameters + } + case bind(Bind) case cancel(Cancel) case close(Close) @@ -225,7 +269,3 @@ extension PostgresFrontendMessage { } } } - -protocol PSQLMessagePayloadEncodable { - func encode(into buffer: inout ByteBuffer) -} diff --git a/Sources/PostgresNIO/New/PostgresFrontendMessageEncoder.swift b/Sources/PostgresNIO/New/PostgresFrontendMessageEncoder.swift index 46dbba42..d4747163 100644 --- a/Sources/PostgresNIO/New/PostgresFrontendMessageEncoder.swift +++ b/Sources/PostgresNIO/New/PostgresFrontendMessageEncoder.swift @@ -13,34 +13,18 @@ struct PostgresFrontendMessageEncoder { self.buffer = buffer } - mutating func startup(_ parameters: PostgresFrontendMessage.Startup.Parameters) { + mutating func startup(user: String, database: String?) { self.clearIfNeeded() self.encodeLengthPrefixed { buffer in buffer.writeInteger(PostgresFrontendMessage.Startup.versionThree) buffer.writeNullTerminatedString("user") - buffer.writeNullTerminatedString(parameters.user) + buffer.writeNullTerminatedString(user) - if let database = parameters.database { + if let database = database { buffer.writeNullTerminatedString("database") buffer.writeNullTerminatedString(database) } - if let options = parameters.options { - buffer.writeNullTerminatedString("options") - buffer.writeNullTerminatedString(options) - } - - switch parameters.replication { - case .database: - buffer.writeNullTerminatedString("replication") - buffer.writeNullTerminatedString("replication") - case .true: - buffer.writeNullTerminatedString("replication") - buffer.writeNullTerminatedString("true") - case .false: - break - } - buffer.writeInteger(UInt8(0)) } } diff --git a/Tests/PostgresNIOTests/New/Extensions/PSQLBackendMessageEncoder.swift b/Tests/PostgresNIOTests/New/Extensions/PSQLBackendMessageEncoder.swift index e51c14f9..9614bf1e 100644 --- a/Tests/PostgresNIOTests/New/Extensions/PSQLBackendMessageEncoder.swift +++ b/Tests/PostgresNIOTests/New/Extensions/PSQLBackendMessageEncoder.swift @@ -257,3 +257,7 @@ extension RowDescription: PSQLMessagePayloadEncodable { } } } + +protocol PSQLMessagePayloadEncodable { + func encode(into buffer: inout ByteBuffer) +} diff --git a/Tests/PostgresNIOTests/New/Messages/StartupTests.swift b/Tests/PostgresNIOTests/New/Messages/StartupTests.swift index e72f0f34..39e9bb42 100644 --- a/Tests/PostgresNIOTests/New/Messages/StartupTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/StartupTests.swift @@ -4,56 +4,44 @@ import NIOCore class StartupTests: XCTestCase { - func testStartupMessage() { + func testStartupMessageWithDatabase() { var encoder = PostgresFrontendMessageEncoder(buffer: .init()) var byteBuffer = ByteBuffer() - - let replicationValues: [PostgresFrontendMessage.Startup.Parameters.Replication] = [ - .`true`, - .`false`, - .database - ] - - for replication in replicationValues { - let parameters = PostgresFrontendMessage.Startup.Parameters( - user: "test", - database: "abc123", - options: "some options", - replication: replication - ) - - encoder.startup(parameters) - byteBuffer = encoder.flushBuffer() - - let byteBufferLength = Int32(byteBuffer.readableBytes) - XCTAssertEqual(byteBufferLength, byteBuffer.readInteger()) - XCTAssertEqual(PostgresFrontendMessage.Startup.versionThree, byteBuffer.readInteger()) - XCTAssertEqual(byteBuffer.readNullTerminatedString(), "user") - XCTAssertEqual(byteBuffer.readNullTerminatedString(), "test") - XCTAssertEqual(byteBuffer.readNullTerminatedString(), "database") - XCTAssertEqual(byteBuffer.readNullTerminatedString(), "abc123") - XCTAssertEqual(byteBuffer.readNullTerminatedString(), "options") - XCTAssertEqual(byteBuffer.readNullTerminatedString(), "some options") - if replication != .false { - XCTAssertEqual(byteBuffer.readNullTerminatedString(), "replication") - XCTAssertEqual(byteBuffer.readNullTerminatedString(), replication.stringValue) - } - XCTAssertEqual(byteBuffer.readInteger(), UInt8(0)) - - XCTAssertEqual(byteBuffer.readableBytes, 0) - } + + let user = "test" + let database = "abc123" + + encoder.startup(user: user, database: database) + byteBuffer = encoder.flushBuffer() + + let byteBufferLength = Int32(byteBuffer.readableBytes) + XCTAssertEqual(byteBufferLength, byteBuffer.readInteger()) + XCTAssertEqual(PostgresFrontendMessage.Startup.versionThree, byteBuffer.readInteger()) + XCTAssertEqual(byteBuffer.readNullTerminatedString(), "user") + XCTAssertEqual(byteBuffer.readNullTerminatedString(), "test") + XCTAssertEqual(byteBuffer.readNullTerminatedString(), "database") + XCTAssertEqual(byteBuffer.readNullTerminatedString(), "abc123") + XCTAssertEqual(byteBuffer.readInteger(), UInt8(0)) + + XCTAssertEqual(byteBuffer.readableBytes, 0) } -} -extension PostgresFrontendMessage.Startup.Parameters.Replication { - var stringValue: String { - switch self { - case .true: - return "true" - case .false: - return "false" - case .database: - return "replication" - } + func testStartupMessageWithoutDatabase() { + var encoder = PostgresFrontendMessageEncoder(buffer: .init()) + var byteBuffer = ByteBuffer() + + let user = "test" + + encoder.startup(user: user, database: nil) + byteBuffer = encoder.flushBuffer() + + let byteBufferLength = Int32(byteBuffer.readableBytes) + XCTAssertEqual(byteBufferLength, byteBuffer.readInteger()) + XCTAssertEqual(PostgresFrontendMessage.Startup.versionThree, byteBuffer.readInteger()) + XCTAssertEqual(byteBuffer.readNullTerminatedString(), "user") + XCTAssertEqual(byteBuffer.readNullTerminatedString(), "test") + XCTAssertEqual(byteBuffer.readInteger(), UInt8(0)) + + XCTAssertEqual(byteBuffer.readableBytes, 0) } } diff --git a/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift b/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift index eed5ada7..b047cd72 100644 --- a/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift @@ -277,3 +277,14 @@ class TestEventHandler: ChannelInboundHandler { self.events.append(psqlEvent) } } + +extension AuthContext { + func toStartupParameters() -> PostgresFrontendMessage.Startup.Parameters { + PostgresFrontendMessage.Startup.Parameters( + user: self.username, + database: self.database, + options: nil, + replication: .false + ) + } +} From c1de89a187eca87eafb1ca398645845e4ed8af23 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Fri, 18 Aug 2023 16:37:59 +0200 Subject: [PATCH 018/106] Make sure correct error is thrown, if server closes connection (#397) --- .../ConnectionStateMachine.swift | 28 ++++++++++--------- .../New/PostgresConnectionTests.swift | 28 +++++++++++++++++++ 2 files changed, 43 insertions(+), 13 deletions(-) diff --git a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift index bbfa0faa..b7ecc461 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift @@ -203,7 +203,7 @@ struct ConnectionStateMachine { preconditionFailure("How can a connection be closed, if it was never connected.") case .closed: - preconditionFailure("How can a connection be closed, if it is already closed.") + return .wait case .authenticated, .sslRequestSent, @@ -214,8 +214,8 @@ struct ConnectionStateMachine { .readyForQuery, .extendedQuery, .closeCommand: - return self.errorHappened(.uncleanShutdown) - + return self.errorHappened(.serverClosedConnection(underlying: nil)) + case .closing(let error): self.state = .closed(clientInitiated: true, error: error) self.quiescingState = .notQuiescing @@ -910,7 +910,7 @@ struct ConnectionStateMachine { // the error state and will try to close the connection. However the server might have // send further follow up messages. In those cases we will run into this method again // and again. We should just ignore those events. - return .wait + return .closeConnection(closePromise) case .modifying: preconditionFailure("Invalid state: \(self.state)") @@ -1034,16 +1034,16 @@ extension ConnectionStateMachine { case .clientClosesConnection, .clientClosedConnection: preconditionFailure("Pure client error, that is thrown directly in PostgresConnection") case .serverClosedConnection: - preconditionFailure("Pure client error, that is thrown directly and should never ") + return true } } mutating func setErrorAndCreateCleanupContextIfNeeded(_ error: PSQLError) -> ConnectionAction.CleanUpContext? { - guard self.shouldCloseConnection(reason: error) else { - return nil + if self.shouldCloseConnection(reason: error) { + return self.setErrorAndCreateCleanupContext(error) } - return self.setErrorAndCreateCleanupContext(error) + return nil } mutating func setErrorAndCreateCleanupContext(_ error: PSQLError, closePromise: EventLoopPromise? = nil) -> ConnectionAction.CleanUpContext { @@ -1060,13 +1060,15 @@ extension ConnectionStateMachine { forwardedPromise = closePromise } - self.state = .closing(error) - - var action = ConnectionAction.CleanUpContext.Action.close - if case .uncleanShutdown = error.code.base { + let action: ConnectionAction.CleanUpContext.Action + if case .serverClosedConnection = error.code.base { + self.state = .closed(clientInitiated: false, error: error) action = .fireChannelInactive + } else { + self.state = .closing(error) + action = .close } - + return .init(action: action, tasks: tasks, error: error, closePromise: forwardedPromise) } } diff --git a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift index 9c4dc5cb..59917c40 100644 --- a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift @@ -275,6 +275,34 @@ class PostgresConnectionTests: XCTestCase { } } + func testIfServerJustClosesTheErrorReflectsThat() async throws { + let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() + + async let response = try await connection.query("SELECT 1;", logger: self.logger) + + let listenMessage = try await channel.waitForUnpreparedRequest() + XCTAssertEqual(listenMessage.parse.query, "SELECT 1;") + + try await channel.testingEventLoop.executeInContext { channel.pipeline.fireChannelInactive() } + try await channel.testingEventLoop.executeInContext { channel.pipeline.fireChannelUnregistered() } + + do { + _ = try await response + XCTFail("Expected to throw") + } catch { + XCTAssertEqual((error as? PSQLError)?.code, .serverClosedConnection) + } + + // retry on same connection + + do { + _ = try await connection.query("SELECT 1;", logger: self.logger) + XCTFail("Expected to throw") + } catch { + XCTAssertEqual((error as? PSQLError)?.code, .serverClosedConnection) + } + } + struct TestPrepareStatement: PostgresPreparedStatement { static var sql = "SELECT datname FROM pg_stat_activity WHERE state = $1" typealias Row = String From 12584c6666bd0b197e8063ef2415a7c9281152fb Mon Sep 17 00:00:00 2001 From: Gwynne Raskind Date: Fri, 18 Aug 2023 13:54:32 -0500 Subject: [PATCH 019/106] Fix a few inaccurate or confusing precondition failure messages (#398) --- .../ConnectionStateMachine.swift | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift index b7ecc461..22c4087e 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift @@ -858,8 +858,9 @@ struct ConnectionStateMachine { // substate machine. return .closeConnectionAndCleanup(cleanupContext) } - - switch queryStateMachine.errorHappened(error) { + + let action = queryStateMachine.errorHappened(error) + switch action { case .sendParseDescribeBindExecuteSync, .sendParseDescribeSync, .sendBindExecuteSync, @@ -869,7 +870,7 @@ struct ConnectionStateMachine { .forwardStreamComplete, .wait, .read: - preconditionFailure("Invalid state: \(self.state)") + preconditionFailure("Invalid query state machine action in state: \(self.state), action: \(action)") case .evaluateErrorAtConnectionLevel: return .closeConnectionAndCleanup(cleanupContext) @@ -894,12 +895,13 @@ struct ConnectionStateMachine { return .closeConnectionAndCleanup(cleanupContext) } - switch closeStateMachine.errorHappened(error) { + let action = closeStateMachine.errorHappened(error) + switch action { case .sendCloseSync, .succeedClose, .read, .wait: - preconditionFailure("Invalid state: \(self.state)") + preconditionFailure("Invalid close state machine action in state: \(self.state), action: \(action)") case .failClose(let closeCommandContext, with: let error): return .failClose(closeCommandContext, with: error, cleanupContext: cleanupContext) } @@ -1032,7 +1034,7 @@ extension ConnectionStateMachine { return false case .clientClosesConnection, .clientClosedConnection: - preconditionFailure("Pure client error, that is thrown directly in PostgresConnection") + preconditionFailure("A pure client error was thrown directly in PostgresConnection, this shouldn't happen") case .serverClosedConnection: return true } From 9a02d740a0fdb6fa52818c91d27875deb05add24 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Sat, 19 Aug 2023 11:10:55 +0200 Subject: [PATCH 020/106] Move PostgresFrontendMessage to tests (#399) --- .../New/Extensions/ByteBuffer+PSQL.swift | 8 -- .../New/PostgresFrontendMessageEncoder.swift | 95 +++++++++++++------ .../New/Extensions/ByteBuffer+Utils.swift | 5 +- .../Extensions}/PostgresFrontendMessage.swift | 1 + 4 files changed, 71 insertions(+), 38 deletions(-) rename {Sources/PostgresNIO/New => Tests/PostgresNIOTests/New/Extensions}/PostgresFrontendMessage.swift (99%) diff --git a/Sources/PostgresNIO/New/Extensions/ByteBuffer+PSQL.swift b/Sources/PostgresNIO/New/Extensions/ByteBuffer+PSQL.swift index 6d632b6f..838e624d 100644 --- a/Sources/PostgresNIO/New/Extensions/ByteBuffer+PSQL.swift +++ b/Sources/PostgresNIO/New/Extensions/ByteBuffer+PSQL.swift @@ -2,14 +2,6 @@ import NIOCore internal extension ByteBuffer { - mutating func psqlWriteBackendMessageID(_ messageID: PostgresBackendMessage.ID) { - self.writeInteger(messageID.rawValue) - } - - mutating func psqlWriteFrontendMessageID(_ messageID: PostgresFrontendMessage.ID) { - self.writeInteger(messageID.rawValue) - } - @usableFromInline mutating func psqlReadFloat() -> Float? { return self.readInteger(as: UInt32.self).map { Float(bitPattern: $0) } diff --git a/Sources/PostgresNIO/New/PostgresFrontendMessageEncoder.swift b/Sources/PostgresNIO/New/PostgresFrontendMessageEncoder.swift index d4747163..e98ab1f1 100644 --- a/Sources/PostgresNIO/New/PostgresFrontendMessageEncoder.swift +++ b/Sources/PostgresNIO/New/PostgresFrontendMessageEncoder.swift @@ -1,6 +1,18 @@ import NIOCore struct PostgresFrontendMessageEncoder { + + /// The SSL request code. The value is chosen to contain 1234 in the most significant 16 bits, + /// and 5679 in the least significant 16 bits. + static let sslRequestCode: Int32 = 80877103 + + /// The cancel request code. The value is chosen to contain 1234 in the most significant 16 bits, + /// and 5678 in the least significant 16 bits. (To avoid confusion, this code must not be the same + /// as any protocol version number.) + static let cancelRequestCode: Int32 = 80877102 + + static let startupVersionThree: Int32 = 0x00_03_00_00 + private enum State { case flushed case writable @@ -15,8 +27,8 @@ struct PostgresFrontendMessageEncoder { mutating func startup(user: String, database: String?) { self.clearIfNeeded() - self.encodeLengthPrefixed { buffer in - buffer.writeInteger(PostgresFrontendMessage.Startup.versionThree) + self.buffer.psqlLengthPrefixed { buffer in + buffer.writeInteger(Self.startupVersionThree) buffer.writeNullTerminatedString("user") buffer.writeNullTerminatedString(user) @@ -31,8 +43,7 @@ struct PostgresFrontendMessageEncoder { mutating func bind(portalName: String, preparedStatementName: String, bind: PostgresBindings) { self.clearIfNeeded() - self.buffer.psqlWriteFrontendMessageID(.bind) - self.encodeLengthPrefixed { buffer in + self.buffer.psqlLengthPrefixed(id: .bind) { buffer in buffer.writeNullTerminatedString(portalName) buffer.writeNullTerminatedString(preparedStatementName) @@ -65,45 +76,45 @@ struct PostgresFrontendMessageEncoder { mutating func cancel(processID: Int32, secretKey: Int32) { self.clearIfNeeded() - self.buffer.writeMultipleIntegers(UInt32(16), PostgresFrontendMessage.Cancel.requestCode, processID, secretKey) + self.buffer.writeMultipleIntegers(UInt32(16), Self.cancelRequestCode, processID, secretKey) } mutating func closePreparedStatement(_ preparedStatement: String) { self.clearIfNeeded() - self.buffer.writeMultipleIntegers(PostgresFrontendMessage.ID.close.rawValue, UInt32(6 + preparedStatement.utf8.count), UInt8(ascii: "S")) + self.buffer.psqlWriteMultipleIntegers(id: .close, length: UInt32(2 + preparedStatement.utf8.count), UInt8(ascii: "S")) self.buffer.writeNullTerminatedString(preparedStatement) } mutating func closePortal(_ portal: String) { self.clearIfNeeded() - self.buffer.writeMultipleIntegers(PostgresFrontendMessage.ID.close.rawValue, UInt32(6 + portal.utf8.count), UInt8(ascii: "P")) + self.buffer.psqlWriteMultipleIntegers(id: .close, length: UInt32(2 + portal.utf8.count), UInt8(ascii: "P")) self.buffer.writeNullTerminatedString(portal) } mutating func describePreparedStatement(_ preparedStatement: String) { self.clearIfNeeded() - self.buffer.writeMultipleIntegers(PostgresFrontendMessage.ID.describe.rawValue, UInt32(6 + preparedStatement.utf8.count), UInt8(ascii: "S")) + self.buffer.psqlWriteMultipleIntegers(id: .describe, length: UInt32(2 + preparedStatement.utf8.count), UInt8(ascii: "S")) self.buffer.writeNullTerminatedString(preparedStatement) } mutating func describePortal(_ portal: String) { self.clearIfNeeded() - self.buffer.writeMultipleIntegers(PostgresFrontendMessage.ID.describe.rawValue, UInt32(6 + portal.utf8.count), UInt8(ascii: "P")) + self.buffer.psqlWriteMultipleIntegers(id: .describe, length: UInt32(2 + portal.utf8.count), UInt8(ascii: "P")) self.buffer.writeNullTerminatedString(portal) } mutating func execute(portalName: String, maxNumberOfRows: Int32 = 0) { self.clearIfNeeded() - self.buffer.writeMultipleIntegers(PostgresFrontendMessage.ID.execute.rawValue, UInt32(9 + portalName.utf8.count)) + self.buffer.psqlWriteMultipleIntegers(id: .execute, length: UInt32(5 + portalName.utf8.count)) self.buffer.writeNullTerminatedString(portalName) self.buffer.writeInteger(maxNumberOfRows) } mutating func parse(preparedStatementName: String, query: String, parameters: Parameters) where Parameters.Element == PostgresDataType { self.clearIfNeeded() - self.buffer.writeMultipleIntegers( - PostgresFrontendMessage.ID.parse.rawValue, - UInt32(4 + preparedStatementName.utf8.count + 1 + query.utf8.count + 1 + 2 + MemoryLayout.size * parameters.count) + self.buffer.psqlWriteMultipleIntegers( + id: .parse, + length: UInt32(preparedStatementName.utf8.count + 1 + query.utf8.count + 1 + 2 + MemoryLayout.size * parameters.count) ) self.buffer.writeNullTerminatedString(preparedStatementName) self.buffer.writeNullTerminatedString(query) @@ -116,28 +127,25 @@ struct PostgresFrontendMessageEncoder { mutating func password(_ bytes: Bytes) where Bytes.Element == UInt8 { self.clearIfNeeded() - self.buffer.writeMultipleIntegers(PostgresFrontendMessage.ID.password.rawValue, UInt32(5 + bytes.count)) + self.buffer.psqlWriteMultipleIntegers(id: .password, length: UInt32(bytes.count) + 1) self.buffer.writeBytes(bytes) self.buffer.writeInteger(UInt8(0)) } mutating func flush() { self.clearIfNeeded() - self.buffer.writeMultipleIntegers(PostgresFrontendMessage.ID.flush.rawValue, UInt32(4)) + self.buffer.psqlWriteMultipleIntegers(id: .flush, length: 0) } mutating func saslResponse(_ bytes: Bytes) where Bytes.Element == UInt8 { self.clearIfNeeded() - self.buffer.writeMultipleIntegers(PostgresFrontendMessage.ID.saslResponse.rawValue, UInt32(4 + bytes.count)) + self.buffer.psqlWriteMultipleIntegers(id: .password, length: UInt32(bytes.count)) self.buffer.writeBytes(bytes) } mutating func saslInitialResponse(mechanism: String, bytes: Bytes) where Bytes.Element == UInt8 { self.clearIfNeeded() - self.buffer.writeMultipleIntegers( - PostgresFrontendMessage.ID.saslInitialResponse.rawValue, - UInt32(4 + mechanism.utf8.count + 1 + 4 + bytes.count) - ) + self.buffer.psqlWriteMultipleIntegers(id: .password, length: UInt32(mechanism.utf8.count + 1 + 4 + bytes.count)) self.buffer.writeNullTerminatedString(mechanism) if bytes.count > 0 { self.buffer.writeInteger(Int32(bytes.count)) @@ -149,17 +157,17 @@ struct PostgresFrontendMessageEncoder { mutating func ssl() { self.clearIfNeeded() - self.buffer.writeMultipleIntegers(UInt32(8), PostgresFrontendMessage.SSLRequest.requestCode) + self.buffer.writeMultipleIntegers(UInt32(8), Self.sslRequestCode) } mutating func sync() { self.clearIfNeeded() - self.buffer.writeMultipleIntegers(PostgresFrontendMessage.ID.sync.rawValue, UInt32(4)) + self.buffer.psqlWriteMultipleIntegers(id: .sync, length: 0) } mutating func terminate() { self.clearIfNeeded() - self.buffer.writeMultipleIntegers(PostgresFrontendMessage.ID.terminate.rawValue, UInt32(4)) + self.buffer.psqlWriteMultipleIntegers(id: .terminate, length: 0) } mutating func flushBuffer() -> ByteBuffer { @@ -177,13 +185,42 @@ struct PostgresFrontendMessageEncoder { break } } +} - private mutating func encodeLengthPrefixed(_ encode: (inout ByteBuffer) -> ()) { - let startIndex = self.buffer.writerIndex - self.buffer.writeInteger(UInt32(0)) // placeholder for length - encode(&self.buffer) - let length = UInt32(self.buffer.writerIndex - startIndex) - self.buffer.setInteger(length, at: startIndex) +private enum FrontendMessageID: UInt8, Hashable, Sendable { + case bind = 66 // B + case close = 67 // C + case describe = 68 // D + case execute = 69 // E + case flush = 72 // H + case parse = 80 // P + case password = 112 // p - also both sasl values + case sync = 83 // S + case terminate = 88 // X +} + +extension ByteBuffer { + mutating fileprivate func psqlWriteMultipleIntegers(id: FrontendMessageID, length: UInt32) { + self.writeMultipleIntegers(id.rawValue, 4 + length) + } + + mutating fileprivate func psqlWriteMultipleIntegers(id: FrontendMessageID, length: UInt32, _ t1: T1) { + self.writeMultipleIntegers(id.rawValue, 4 + length, t1) } + mutating fileprivate func psqlLengthPrefixed(id: FrontendMessageID, _ encode: (inout ByteBuffer) -> ()) { + let lengthIndex = self.writerIndex + 1 + self.psqlWriteMultipleIntegers(id: id, length: 0) + encode(&self) + let length = UInt32(self.writerIndex - lengthIndex) + self.setInteger(length, at: lengthIndex) + } + + mutating fileprivate func psqlLengthPrefixed(_ encode: (inout ByteBuffer) -> ()) { + let lengthIndex = self.writerIndex + self.writeInteger(UInt32(0)) // placeholder + encode(&self) + let length = UInt32(self.writerIndex - lengthIndex) + self.setInteger(length, at: lengthIndex) + } } diff --git a/Tests/PostgresNIOTests/New/Extensions/ByteBuffer+Utils.swift b/Tests/PostgresNIOTests/New/Extensions/ByteBuffer+Utils.swift index 71994596..7d073873 100644 --- a/Tests/PostgresNIOTests/New/Extensions/ByteBuffer+Utils.swift +++ b/Tests/PostgresNIOTests/New/Extensions/ByteBuffer+Utils.swift @@ -2,7 +2,10 @@ import NIOCore @testable import PostgresNIO extension ByteBuffer { - + mutating func psqlWriteBackendMessageID(_ messageID: PostgresBackendMessage.ID) { + self.writeInteger(messageID.rawValue) + } + static func backendMessage(id: PostgresBackendMessage.ID, _ payload: (inout ByteBuffer) throws -> ()) rethrows -> ByteBuffer { var byteBuffer = ByteBuffer() try byteBuffer.writeBackendMessage(id: id, payload) diff --git a/Sources/PostgresNIO/New/PostgresFrontendMessage.swift b/Tests/PostgresNIOTests/New/Extensions/PostgresFrontendMessage.swift similarity index 99% rename from Sources/PostgresNIO/New/PostgresFrontendMessage.swift rename to Tests/PostgresNIOTests/New/Extensions/PostgresFrontendMessage.swift index ef7ce8f8..010667dc 100644 --- a/Sources/PostgresNIO/New/PostgresFrontendMessage.swift +++ b/Tests/PostgresNIOTests/New/Extensions/PostgresFrontendMessage.swift @@ -1,4 +1,5 @@ import NIOCore +import PostgresNIO /// A wire message that is created by a Postgres client to be consumed by Postgres server. /// From 8f8557bfe6a3ca379da2cf84059acbdba1c3958f Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Sun, 20 Aug 2023 17:46:21 +0200 Subject: [PATCH 021/106] Remove PSQLError.Code.clientClosesConnection (#400) --- .../ConnectionStateMachine.swift | 6 +++--- Sources/PostgresNIO/New/PSQLError.swift | 14 ++------------ .../PostgresNIO/New/PostgresChannelHandler.swift | 6 ++++-- Sources/PostgresNIO/Postgres+PSQLCompat.swift | 3 +-- .../New/PostgresChannelHandlerTests.swift | 3 +-- 5 files changed, 11 insertions(+), 21 deletions(-) diff --git a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift index 22c4087e..eca251ff 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift @@ -550,7 +550,7 @@ struct ConnectionStateMachine { // check if we are quiescing. if so fail task immidiatly switch self.quiescingState { case .quiescing: - psqlErrror = PSQLError.clientClosesConnection(underlying: nil) + psqlErrror = PSQLError.clientClosedConnection(underlying: nil) case .notQuiescing: switch self.state { @@ -570,7 +570,7 @@ struct ConnectionStateMachine { return self.executeTask(task) case .closing(let error): - psqlErrror = PSQLError.clientClosesConnection(underlying: error) + psqlErrror = PSQLError.clientClosedConnection(underlying: error) case .closed(clientInitiated: true, error: let error): psqlErrror = PSQLError.clientClosedConnection(underlying: error) @@ -1033,7 +1033,7 @@ extension ConnectionStateMachine { } return false - case .clientClosesConnection, .clientClosedConnection: + case .clientClosedConnection: preconditionFailure("A pure client error was thrown directly in PostgresConnection, this shouldn't happen") case .serverClosedConnection: return true diff --git a/Sources/PostgresNIO/New/PSQLError.swift b/Sources/PostgresNIO/New/PSQLError.swift index 1fec59b1..7060a690 100644 --- a/Sources/PostgresNIO/New/PSQLError.swift +++ b/Sources/PostgresNIO/New/PSQLError.swift @@ -18,7 +18,6 @@ public struct PSQLError: Error { case queryCancelled case tooManyParameters - case clientClosesConnection case clientClosedConnection case serverClosedConnection case connectionError @@ -46,7 +45,6 @@ public struct PSQLError: Error { public static let invalidCommandTag = Self(.invalidCommandTag) public static let queryCancelled = Self(.queryCancelled) public static let tooManyParameters = Self(.tooManyParameters) - public static let clientClosesConnection = Self(.clientClosesConnection) public static let clientClosedConnection = Self(.clientClosedConnection) public static let serverClosedConnection = Self(.serverClosedConnection) public static let connectionError = Self(.connectionError) @@ -54,8 +52,8 @@ public struct PSQLError: Error { public static let listenFailed = Self.init(.listenFailed) public static let unlistenFailed = Self.init(.unlistenFailed) - @available(*, deprecated, renamed: "clientClosesConnection") - public static let connectionQuiescing = Self.clientClosesConnection + @available(*, deprecated, renamed: "clientClosedConnection") + public static let connectionQuiescing = Self.clientClosedConnection @available(*, deprecated, message: "Use the more specific `serverClosedConnection` or `clientClosedConnection` instead") public static let connectionClosed = Self.serverClosedConnection @@ -86,8 +84,6 @@ public struct PSQLError: Error { return "queryCancelled" case .tooManyParameters: return "tooManyParameters" - case .clientClosesConnection: - return "clientClosesConnection" case .clientClosedConnection: return "clientClosedConnection" case .serverClosedConnection: @@ -387,12 +383,6 @@ public struct PSQLError: Error { return new } - static func clientClosesConnection(underlying: Error?) -> PSQLError { - var error = PSQLError(code: .clientClosesConnection) - error.underlying = underlying - return error - } - static func clientClosedConnection(underlying: Error?) -> PSQLError { var error = PSQLError(code: .clientClosedConnection) error.underlying = underlying diff --git a/Sources/PostgresNIO/New/PostgresChannelHandler.swift b/Sources/PostgresNIO/New/PostgresChannelHandler.swift index 7b31a776..6d9d08b3 100644 --- a/Sources/PostgresNIO/New/PostgresChannelHandler.swift +++ b/Sources/PostgresNIO/New/PostgresChannelHandler.swift @@ -576,8 +576,10 @@ final class PostgresChannelHandler: ChannelDuplexHandler { } // 3. fire an error - context.fireErrorCaught(cleanup.error) - + if cleanup.error.code != .clientClosedConnection { + context.fireErrorCaught(cleanup.error) + } + // 4. close the connection or fire channel inactive switch cleanup.action { case .close: diff --git a/Sources/PostgresNIO/Postgres+PSQLCompat.swift b/Sources/PostgresNIO/Postgres+PSQLCompat.swift index 1989e5bc..c4f30624 100644 --- a/Sources/PostgresNIO/Postgres+PSQLCompat.swift +++ b/Sources/PostgresNIO/Postgres+PSQLCompat.swift @@ -37,8 +37,7 @@ extension PSQLError { return self.underlying ?? self case .tooManyParameters, .invalidCommandTag: return self - case .clientClosesConnection, - .clientClosedConnection, + case .clientClosedConnection, .serverClosedConnection: return PostgresError.connectionClosed case .connectionError: diff --git a/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift b/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift index b047cd72..b81d0899 100644 --- a/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift @@ -25,8 +25,7 @@ class PostgresChannelHandlerTests: XCTestCase { handler ], loop: self.eventLoop) defer { - do { try embedded.finish() } - catch { print("\(String(reflecting: error))") } + XCTAssertNoThrow({ try embedded.finish() }) } var maybeMessage: PostgresFrontendMessage? From 689e4aabd783df4d8fb0eedee0787014a141f9e8 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Fri, 25 Aug 2023 17:12:10 +0200 Subject: [PATCH 022/106] Use variadic generics to decode rows in Swift 5.9 (#341) --- .../New/PostgresRow-multi-decode.swift | 2 + .../PostgresRowSequence-multi-decode.swift | 2 +- .../PostgresNIO/New/VariadicGenerics.swift | 174 ++++++++++++++++++ Tests/IntegrationTests/AsyncTests.swift | 7 +- Tests/IntegrationTests/PostgresNIOTests.swift | 8 +- .../New/PostgresRowSequenceTests.swift | 12 +- 6 files changed, 191 insertions(+), 14 deletions(-) create mode 100644 Sources/PostgresNIO/New/VariadicGenerics.swift diff --git a/Sources/PostgresNIO/New/PostgresRow-multi-decode.swift b/Sources/PostgresNIO/New/PostgresRow-multi-decode.swift index cb62c325..71aa04dc 100644 --- a/Sources/PostgresNIO/New/PostgresRow-multi-decode.swift +++ b/Sources/PostgresNIO/New/PostgresRow-multi-decode.swift @@ -1,5 +1,6 @@ /// NOTE: THIS FILE IS AUTO-GENERATED BY dev/generate-postgresrow-multi-decode.sh +#if compiler(<5.9) extension PostgresRow { @inlinable @_alwaysEmitIntoClient @@ -1171,3 +1172,4 @@ extension PostgresRow { try self.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14).self, context: .default, file: file, line: line) } } +#endif diff --git a/Sources/PostgresNIO/New/PostgresRowSequence-multi-decode.swift b/Sources/PostgresNIO/New/PostgresRowSequence-multi-decode.swift index 53d9a7ea..f45357d8 100644 --- a/Sources/PostgresNIO/New/PostgresRowSequence-multi-decode.swift +++ b/Sources/PostgresNIO/New/PostgresRowSequence-multi-decode.swift @@ -1,6 +1,6 @@ /// NOTE: THIS FILE IS AUTO-GENERATED BY dev/generate-postgresrowsequence-multi-decode.sh -#if canImport(_Concurrency) +#if compiler(<5.9) extension AsyncSequence where Element == PostgresRow { @inlinable @_alwaysEmitIntoClient diff --git a/Sources/PostgresNIO/New/VariadicGenerics.swift b/Sources/PostgresNIO/New/VariadicGenerics.swift new file mode 100644 index 00000000..312d36dc --- /dev/null +++ b/Sources/PostgresNIO/New/VariadicGenerics.swift @@ -0,0 +1,174 @@ +#if compiler(>=5.9) +extension PostgresRow { + // --- snip TODO: Remove once bug is fixed, that disallows tuples of one + @inlinable + public func decode( + _: Column.Type, + file: String = #fileID, + line: Int = #line + ) throws -> (Column) { + try self.decode(Column.self, context: .default, file: file, line: line) + } + + @inlinable + public func decode( + _: Column.Type, + context: PostgresDecodingContext, + file: String = #fileID, + line: Int = #line + ) throws -> (Column) { + precondition(self.columns.count >= 1) + let columnIndex = 0 + var cellIterator = self.data.makeIterator() + var cellData = cellIterator.next().unsafelyUnwrapped + var columnIterator = self.columns.makeIterator() + let column = columnIterator.next().unsafelyUnwrapped + let swiftTargetType: Any.Type = Column.self + + do { + let r0 = try Column._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + return (r0) + } catch let code as PostgresDecodingError.Code { + throw PostgresDecodingError( + code: code, + columnName: column.name, + columnIndex: columnIndex, + targetType: swiftTargetType, + postgresType: column.dataType, + postgresFormat: column.format, + postgresData: cellData, + file: file, + line: line + ) + } + } + // --- snap TODO: Remove once bug is fixed, that disallows tuples of one + + @inlinable + public func decode( + _ columnType: (repeat each Column).Type, + context: PostgresDecodingContext, + file: String = #fileID, + line: Int = #line + ) throws -> (repeat each Column) { + let packCount = ComputeParameterPackLength.count(ofPack: repeat (each Column).self) + precondition(self.columns.count >= packCount) + + var columnIndex = 0 + var cellIterator = self.data.makeIterator() + var columnIterator = self.columns.makeIterator() + + return ( + repeat try Self.decodeNextColumn( + (each Column).self, + cellIterator: &cellIterator, + columnIterator: &columnIterator, + columnIndex: &columnIndex, + context: context, + file: file, + line: line + ) + ) + } + + @inlinable + static func decodeNextColumn( + _ columnType: Column.Type, + cellIterator: inout IndexingIterator, + columnIterator: inout IndexingIterator<[RowDescription.Column]>, + columnIndex: inout Int, + context: PostgresDecodingContext, + file: String, + line: Int + ) throws -> Column { + defer { columnIndex += 1 } + + let column = columnIterator.next().unsafelyUnwrapped + var cellData = cellIterator.next().unsafelyUnwrapped + do { + return try Column._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + } catch let code as PostgresDecodingError.Code { + throw PostgresDecodingError( + code: code, + columnName: column.name, + columnIndex: columnIndex, + targetType: Column.self, + postgresType: column.dataType, + postgresFormat: column.format, + postgresData: cellData, + file: file, + line: line + ) + } + } + + @inlinable + public func decode( + _ columnType: (repeat each Column).Type, + file: String = #fileID, + line: Int = #line + ) throws -> (repeat each Column) { + try self.decode(columnType, context: .default, file: file, line: line) + } +} + +extension AsyncSequence where Element == PostgresRow { + // --- snip TODO: Remove once bug is fixed, that disallows tuples of one + @inlinable + public func decode( + _: Column.Type, + context: PostgresDecodingContext, + file: String = #fileID, + line: Int = #line + ) -> AsyncThrowingMapSequence { + self.map { row in + try row.decode(Column.self, context: context, file: file, line: line) + } + } + + @inlinable + public func decode( + _: Column.Type, + file: String = #fileID, + line: Int = #line + ) -> AsyncThrowingMapSequence { + self.decode(Column.self, context: .default, file: file, line: line) + } + // --- snap TODO: Remove once bug is fixed, that disallows tuples of one + + public func decode( + _ columnType: (repeat each Column).Type, + context: PostgresDecodingContext, + file: String = #fileID, + line: Int = #line + ) -> AsyncThrowingMapSequence { + self.map { row in + try row.decode(columnType, context: context, file: file, line: line) + } + } + + public func decode( + _ columnType: (repeat each Column).Type, + file: String = #fileID, + line: Int = #line + ) -> AsyncThrowingMapSequence { + self.decode(columnType, context: .default, file: file, line: line) + } +} + +@usableFromInline +enum ComputeParameterPackLength { + @usableFromInline + enum BoolConverter { + @usableFromInline + typealias Bool = Swift.Bool + } + + @inlinable + static func count(ofPack t: repeat each T) -> Int { + MemoryLayout<(repeat BoolConverter.Bool)>.size / MemoryLayout.stride + } +} +#endif // compiler(>=5.9) + diff --git a/Tests/IntegrationTests/AsyncTests.swift b/Tests/IntegrationTests/AsyncTests.swift index bf945a67..5c77ba29 100644 --- a/Tests/IntegrationTests/AsyncTests.swift +++ b/Tests/IntegrationTests/AsyncTests.swift @@ -8,7 +8,6 @@ import NIOPosix import NIOCore final class AsyncPostgresConnectionTests: XCTestCase { - func test1kRoundTrips() async throws { let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } @@ -37,7 +36,8 @@ final class AsyncPostgresConnectionTests: XCTestCase { try await withTestConnection(on: eventLoop) { connection in let rows = try await connection.query("SELECT generate_series(\(start), \(end));", logger: .psqlTest) var counter = 0 - for try await element in rows.decode(Int.self, context: .default) { + for try await row in rows { + let element = try row.decode(Int.self) XCTAssertEqual(element, counter + 1) counter += 1 } @@ -259,7 +259,8 @@ final class AsyncPostgresConnectionTests: XCTestCase { try await withTestConnection(on: eventLoop) { connection in let rows = try await connection.query("SELECT generate_series(\(start), \(end));", logger: .psqlTest) var counter = 1 - for try await element in rows.decode(Int.self, context: .default) { + for try await row in rows { + let element = try row.decode(Int.self, context: .default) XCTAssertEqual(element, counter) counter += 1 } diff --git a/Tests/IntegrationTests/PostgresNIOTests.swift b/Tests/IntegrationTests/PostgresNIOTests.swift index 19c4e167..ea4d8d05 100644 --- a/Tests/IntegrationTests/PostgresNIOTests.swift +++ b/Tests/IntegrationTests/PostgresNIOTests.swift @@ -1246,10 +1246,10 @@ final class PostgresNIOTests: XCTestCase { return EventLoopFuture.whenAllSucceed([a, b, c], on: self.eventLoop) }).wait()) XCTAssertEqual(queries?.count, 3) - var resutIterator = queries?.makeIterator() - XCTAssertEqual(try resutIterator?.next()?.first?.decode(String.self, context: .default), "a") - XCTAssertEqual(try resutIterator?.next()?.first?.decode(String.self, context: .default), "b") - XCTAssertEqual(try resutIterator?.next()?.first?.decode(String.self, context: .default), "c") + var resultIterator = queries?.makeIterator() + XCTAssertEqual(try resultIterator?.next()?.first?.decode(String.self, context: .default), "a") + XCTAssertEqual(try resultIterator?.next()?.first?.decode(String.self, context: .default), "b") + XCTAssertEqual(try resultIterator?.next()?.first?.decode(String.self, context: .default), "c") } // https://github.com/vapor/postgres-nio/issues/122 diff --git a/Tests/PostgresNIOTests/New/PostgresRowSequenceTests.swift b/Tests/PostgresNIOTests/New/PostgresRowSequenceTests.swift index 872c098d..816daf04 100644 --- a/Tests/PostgresNIOTests/New/PostgresRowSequenceTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresRowSequenceTests.swift @@ -59,7 +59,7 @@ final class PostgresRowSequenceTests: XCTestCase { var counter = 0 for try await row in rowSequence { - XCTAssertEqual(try row.decode(Int.self, context: .default), counter) + XCTAssertEqual(try row.decode(Int.self), counter) counter += 1 if counter == 64 { @@ -135,7 +135,7 @@ final class PostgresRowSequenceTests: XCTestCase { var counter = 0 for try await row in rowSequence { - XCTAssertEqual(try row.decode(Int.self, context: .default), counter) + XCTAssertEqual(try row.decode(Int.self), counter) counter += 1 } @@ -163,7 +163,7 @@ final class PostgresRowSequenceTests: XCTestCase { var counter = 0 for try await row in rowSequence { - XCTAssertEqual(try row.decode(Int.self, context: .default), counter) + XCTAssertEqual(try row.decode(Int.self), counter) counter += 1 } @@ -220,7 +220,7 @@ final class PostgresRowSequenceTests: XCTestCase { } let row1 = try await rowIterator.next() - XCTAssertEqual(try row1?.decode(Int.self, context: .default), 0) + XCTAssertEqual(try row1?.decode(Int.self), 0) DispatchQueue.main.asyncAfter(deadline: .now() + .seconds(1)) { stream.receive(completion: .success("SELECT 1")) @@ -252,7 +252,7 @@ final class PostgresRowSequenceTests: XCTestCase { } let row1 = try await rowIterator.next() - XCTAssertEqual(try row1?.decode(Int.self, context: .default), 0) + XCTAssertEqual(try row1?.decode(Int.self), 0) DispatchQueue.main.asyncAfter(deadline: .now() + .seconds(1)) { stream.receive(completion: .failure(PSQLError.serverClosedConnection(underlying: nil))) @@ -415,7 +415,7 @@ final class PostgresRowSequenceTests: XCTestCase { var counter = 1 for _ in 0..<(2 * messagePerChunk - 1) { let row = try await rowIterator.next() - XCTAssertEqual(try row?.decode(Int.self, context: .default), counter) + XCTAssertEqual(try row?.decode(Int.self), counter) counter += 1 } From 0d9f13be024047397c0f1bf72edf7ffd36cac67a Mon Sep 17 00:00:00 2001 From: Marius Seufzer <44228394+marius-se@users.noreply.github.com> Date: Sun, 27 Aug 2023 23:42:19 +1200 Subject: [PATCH 023/106] Add `PostgresDynamicTypeThrowingEncodable` and `PostgresDynamicTypeEncodable` (#365) --- .../New/Data/Array+PostgresCodable.swift | 7 ++ .../New/Data/Range+PostgresCodable.swift | 22 ++++++- Sources/PostgresNIO/New/PostgresCodable.swift | 65 +++++++++++++++---- Sources/PostgresNIO/New/PostgresQuery.swift | 22 +++---- .../New/PostgresQueryTests.swift | 37 +++++++++++ 5 files changed, 128 insertions(+), 25 deletions(-) diff --git a/Sources/PostgresNIO/New/Data/Array+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/Array+PostgresCodable.swift index fb2b62e3..d605a6c1 100644 --- a/Sources/PostgresNIO/New/Data/Array+PostgresCodable.swift +++ b/Sources/PostgresNIO/New/Data/Array+PostgresCodable.swift @@ -136,6 +136,10 @@ extension Array: PostgresEncodable where Element: PostgresArrayEncodable { } } +// explicitly conforming to PostgresThrowingDynamicTypeEncodable because of: +// https://github.com/apple/swift/issues/54132 +extension Array: PostgresThrowingDynamicTypeEncodable where Element: PostgresArrayEncodable {} + extension Array: PostgresNonThrowingEncodable where Element: PostgresArrayEncodable & PostgresNonThrowingEncodable { public static var psqlType: PostgresDataType { Element.psqlArrayType @@ -173,6 +177,9 @@ extension Array: PostgresNonThrowingEncodable where Element: PostgresArrayEncoda } } +// explicitly conforming to PostgresDynamicTypeEncodable because of: +// https://github.com/apple/swift/issues/54132 +extension Array: PostgresDynamicTypeEncodable where Element: PostgresArrayEncodable & PostgresNonThrowingEncodable {} extension Array: PostgresDecodable where Element: PostgresArrayDecodable, Element == Element._DecodableType { public init( diff --git a/Sources/PostgresNIO/New/Data/Range+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/Range+PostgresCodable.swift index e5a3e60e..6279cf4b 100644 --- a/Sources/PostgresNIO/New/Data/Range+PostgresCodable.swift +++ b/Sources/PostgresNIO/New/Data/Range+PostgresCodable.swift @@ -191,6 +191,11 @@ extension PostgresRange: PostgresEncodable & PostgresNonThrowingEncodable where } } +// explicitly conforming to PostgresDynamicTypeEncodable and PostgresThrowingDynamicTypeEncodable because of: +// https://github.com/apple/swift/issues/54132 +extension PostgresRange: PostgresThrowingDynamicTypeEncodable & PostgresDynamicTypeEncodable + where Bound: PostgresRangeEncodable {} + extension PostgresRange where Bound: Comparable { @inlinable init(range: Range) { @@ -227,6 +232,11 @@ extension Range: PostgresEncodable where Bound: PostgresRangeEncodable { extension Range: PostgresNonThrowingEncodable where Bound: PostgresRangeEncodable {} +// explicitly conforming to PostgresDynamicTypeEncodable and PostgresThrowingDynamicTypeEncodable because of: +// https://github.com/apple/swift/issues/54132 +extension Range: PostgresDynamicTypeEncodable & PostgresThrowingDynamicTypeEncodable + where Bound: PostgresRangeEncodable {} + extension Range: PostgresDecodable where Bound: PostgresRangeDecodable { @inlinable public init( @@ -249,7 +259,7 @@ extension Range: PostgresDecodable where Bound: PostgresRangeDecodable { else { throw PostgresDecodingError.Code.failure } - + self = lowerBound..( @@ -301,7 +319,7 @@ extension ClosedRange: PostgresDecodable where Bound: PostgresRangeDecodable { if lowerBound > upperBound { throw PostgresDecodingError.Code.failure } - + self = lowerBound...upperBound } } diff --git a/Sources/PostgresNIO/New/PostgresCodable.swift b/Sources/PostgresNIO/New/PostgresCodable.swift index 36937de4..53dbd708 100644 --- a/Sources/PostgresNIO/New/PostgresCodable.swift +++ b/Sources/PostgresNIO/New/PostgresCodable.swift @@ -2,29 +2,62 @@ import NIOCore import class Foundation.JSONEncoder import class Foundation.JSONDecoder +/// A type that can encode itself to a Postgres wire binary representation. +/// Dynamic types are types that don't have a well-known Postgres type OID at compile time. +/// For example, custom types created at runtime, such as enums, or extension types whose OID is not stable between +/// databases. +public protocol PostgresThrowingDynamicTypeEncodable { + /// The data type encoded into the `byteBuffer` in ``encode(into:context:)`` + var psqlType: PostgresDataType { get } + + /// The Postgres encoding format used to encode the value into `byteBuffer` in ``encode(into:context:)``. + var psqlFormat: PostgresFormat { get } + + /// Encode the entity into ``byteBuffer`` in the format specified by ``psqlFormat``, + /// using the provided ``context`` as needed, without setting the byte count. + /// + /// This method is called by ``PostgresBindings``. + func encode( + into byteBuffer: inout ByteBuffer, + context: PostgresEncodingContext + ) throws +} + +/// A type that can encode itself to a Postgres wire binary representation. +/// Dynamic types are types that don't have a well-known Postgres type OID at compile time. +/// For example, custom types created at runtime, such as enums, or extension types whose OID is not stable between +/// databases. +/// +/// This is the non-throwing alternative to ``PostgresThrowingDynamicTypeEncodable``. It allows users +/// to create ``PostgresQuery``s via `ExpressibleByStringInterpolation` without having to spell `try`. +public protocol PostgresDynamicTypeEncodable: PostgresThrowingDynamicTypeEncodable { + /// Encode the entity into ``byteBuffer`` in the format specified by ``psqlFormat``, + /// using the provided ``context`` as needed, without setting the byte count. + /// + /// This method is called by ``PostgresBindings``. + func encode( + into byteBuffer: inout ByteBuffer, + context: PostgresEncodingContext + ) +} + /// A type that can encode itself to a postgres wire binary representation. -public protocol PostgresEncodable { +public protocol PostgresEncodable: PostgresThrowingDynamicTypeEncodable { // TODO: Rename to `PostgresThrowingEncodable` with next major release - /// identifies the data type that we will encode into `byteBuffer` in `encode` + /// The data type encoded into the `byteBuffer` in ``encode(into:context:)``. static var psqlType: PostgresDataType { get } - /// identifies the postgres format that is used to encode the value into `byteBuffer` in `encode` + /// The Postgres encoding format used to encode the value into `byteBuffer` in ``encode(into:context:)``. static var psqlFormat: PostgresFormat { get } - - /// Encode the entity into the `byteBuffer` in Postgres binary format, without setting - /// the byte count. This method is called from the ``PostgresBindings``. - func encode(into byteBuffer: inout ByteBuffer, context: PostgresEncodingContext) throws } /// A type that can encode itself to a postgres wire binary representation. It enforces that the /// ``PostgresEncodable/encode(into:context:)-1jkcp`` does not throw. This allows users -/// to create ``PostgresQuery``s using the `ExpressibleByStringInterpolation` without +/// to create ``PostgresQuery``s via `ExpressibleByStringInterpolation` without /// having to spell `try`. -public protocol PostgresNonThrowingEncodable: PostgresEncodable { +public protocol PostgresNonThrowingEncodable: PostgresEncodable, PostgresDynamicTypeEncodable { // TODO: Rename to `PostgresEncodable` with next major release - - func encode(into byteBuffer: inout ByteBuffer, context: PostgresEncodingContext) } /// A type that can decode itself from a postgres wire binary representation. @@ -84,6 +117,14 @@ extension PostgresDecodable { public typealias PostgresCodable = PostgresEncodable & PostgresDecodable extension PostgresEncodable { + @inlinable + public var psqlType: PostgresDataType { Self.psqlType } + + @inlinable + public var psqlFormat: PostgresFormat { Self.psqlFormat } +} + +extension PostgresThrowingDynamicTypeEncodable { @inlinable func encodeRaw( into buffer: inout ByteBuffer, @@ -103,7 +144,7 @@ extension PostgresEncodable { } } -extension PostgresNonThrowingEncodable { +extension PostgresDynamicTypeEncodable { @inlinable func encodeRaw( into buffer: inout ByteBuffer, diff --git a/Sources/PostgresNIO/New/PostgresQuery.swift b/Sources/PostgresNIO/New/PostgresQuery.swift index 4ca1e454..1cfcf2dc 100644 --- a/Sources/PostgresNIO/New/PostgresQuery.swift +++ b/Sources/PostgresNIO/New/PostgresQuery.swift @@ -44,13 +44,13 @@ extension PostgresQuery { } @inlinable - public mutating func appendInterpolation(_ value: Value) throws { + public mutating func appendInterpolation(_ value: Value) throws { try self.binds.append(value, context: .default) self.sql.append(contentsOf: "$\(self.binds.count)") } @inlinable - public mutating func appendInterpolation(_ value: Optional) throws { + public mutating func appendInterpolation(_ value: Optional) throws { switch value { case .none: self.binds.appendNull() @@ -62,13 +62,13 @@ extension PostgresQuery { } @inlinable - public mutating func appendInterpolation(_ value: Value) { + public mutating func appendInterpolation(_ value: Value) { self.binds.append(value, context: .default) self.sql.append(contentsOf: "$\(self.binds.count)") } @inlinable - public mutating func appendInterpolation(_ value: Optional) { + public mutating func appendInterpolation(_ value: Optional) { switch value { case .none: self.binds.appendNull() @@ -80,7 +80,7 @@ extension PostgresQuery { } @inlinable - public mutating func appendInterpolation( + public mutating func appendInterpolation( _ value: Value, context: PostgresEncodingContext ) throws { @@ -136,8 +136,8 @@ public struct PostgresBindings: Sendable, Hashable { } @inlinable - init(value: Value, protected: Bool) { - self.init(dataType: Value.psqlType, format: Value.psqlFormat, protected: protected) + init(value: Value, protected: Bool) { + self.init(dataType: value.psqlType, format: value.psqlFormat, protected: protected) } } @@ -168,12 +168,12 @@ public struct PostgresBindings: Sendable, Hashable { } @inlinable - public mutating func append(_ value: Value) throws { + public mutating func append(_ value: Value) throws { try self.append(value, context: .default) } @inlinable - public mutating func append( + public mutating func append( _ value: Value, context: PostgresEncodingContext ) throws { @@ -182,12 +182,12 @@ public struct PostgresBindings: Sendable, Hashable { } @inlinable - public mutating func append(_ value: Value) { + public mutating func append(_ value: Value) { self.append(value, context: .default) } @inlinable - public mutating func append( + public mutating func append( _ value: Value, context: PostgresEncodingContext ) { diff --git a/Tests/PostgresNIOTests/New/PostgresQueryTests.swift b/Tests/PostgresNIOTests/New/PostgresQueryTests.swift index f50d414a..4930f0c4 100644 --- a/Tests/PostgresNIOTests/New/PostgresQueryTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresQueryTests.swift @@ -31,6 +31,27 @@ final class PostgresQueryTests: XCTestCase { XCTAssertEqual(query.binds.bytes, expected) } + func testStringInterpolationWithDynamicType() { + let type = PostgresDataType(16435) + let format = PostgresFormat.binary + let dynamicString = DynamicString(value: "Hello world", psqlType: type, psqlFormat: format) + + let query: PostgresQuery = """ + INSERT INTO foo (dynamicType) SET (\(dynamicString)); + """ + + XCTAssertEqual(query.sql, "INSERT INTO foo (dynamicType) SET ($1);") + + var expectedBindsBytes = ByteBuffer() + expectedBindsBytes.writeInteger(Int32(dynamicString.value.utf8.count)) + expectedBindsBytes.writeString(dynamicString.value) + + let expectedMetadata: [PostgresBindings.Metadata] = [.init(dataType: type, format: format, protected: true)] + + XCTAssertEqual(query.binds.bytes, expectedBindsBytes) + XCTAssertEqual(query.binds.metadata, expectedMetadata) + } + func testStringInterpolationWithCustomJSONEncoder() { struct Foo: Codable, PostgresCodable { var helloWorld: String @@ -89,3 +110,19 @@ final class PostgresQueryTests: XCTestCase { XCTAssertEqual(query.binds.bytes, expected) } } + +extension PostgresQueryTests { + struct DynamicString: PostgresDynamicTypeEncodable { + let value: String + + var psqlType: PostgresDataType + var psqlFormat: PostgresFormat + + func encode( + into byteBuffer: inout ByteBuffer, + context: PostgresNIO.PostgresEncodingContext + ) where JSONEncoder: PostgresJSONEncoder { + byteBuffer.writeString(value) + } + } +} From d89a72304d2cf847f115773467432ce955e43981 Mon Sep 17 00:00:00 2001 From: Gwynne Raskind Date: Mon, 28 Aug 2023 03:38:11 -0500 Subject: [PATCH 024/106] Improve the logo image used by the DocC catalog (#404) --- .../Docs.docc/images/vapor-postgres-logo.svg | 37 +++++++++++-------- .../PostgresNIO/Docs.docc/theme-settings.json | 6 +-- 2 files changed, 24 insertions(+), 19 deletions(-) diff --git a/Sources/PostgresNIO/Docs.docc/images/vapor-postgres-logo.svg b/Sources/PostgresNIO/Docs.docc/images/vapor-postgres-logo.svg index e1c1223b..d118faab 100644 --- a/Sources/PostgresNIO/Docs.docc/images/vapor-postgres-logo.svg +++ b/Sources/PostgresNIO/Docs.docc/images/vapor-postgres-logo.svg @@ -2,35 +2,40 @@ - - - - - - - - - - - - + PostgresNIO + + + + + + + + + + + + + + + + diff --git a/Sources/PostgresNIO/Docs.docc/theme-settings.json b/Sources/PostgresNIO/Docs.docc/theme-settings.json index c6ce054e..e9fc3d9d 100644 --- a/Sources/PostgresNIO/Docs.docc/theme-settings.json +++ b/Sources/PostgresNIO/Docs.docc/theme-settings.json @@ -22,14 +22,14 @@ "light": "rgb(255, 255, 255)" }, "psql-blue": "#336791", - "documentation-intro-fill": "radial-gradient(circle at top, var(--color-documentation-intro-accent) 30%, #1f1d1f 100%)", + "documentation-intro-fill": "radial-gradient(circle at top, var(--color-documentation-intro-accent) 30%, #000 100%)", "documentation-intro-accent": "var(--color-psql-blue)", "documentation-intro-accent-outer": { "dark": "rgb(255, 255, 255)", - "light": "rgb(51, 51, 51)" + "light": "rgb(0, 0, 0)" }, "documentation-intro-accent-inner": { - "dark": "rgb(51, 51, 51)", + "dark": "rgb(0, 0, 0)", "light": "rgb(255, 255, 255)" } }, From abca6b390235ae337999d367c40cc40c99629385 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Tue, 29 Aug 2023 18:09:29 +0200 Subject: [PATCH 025/106] Fix Segmentation faults in Swift 5.8 (#406) --- .../ConnectionStateMachine.swift | 292 ++++++++---------- 1 file changed, 122 insertions(+), 170 deletions(-) diff --git a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift index eca251ff..125d26bb 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift @@ -333,11 +333,10 @@ struct ConnectionStateMachine { return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.authentication(message))) } - return self.avoidingStateMachineCoW { machine in - let action = authState.authenticationMessageReceived(message) - machine.state = .authenticating(authState) - return machine.modify(with: action) - } + self.state = .modifying // avoid CoW + let action = authState.authenticationMessageReceived(message) + self.state = .authenticating(authState) + return self.modify(with: action) } mutating func backendKeyDataReceived(_ keyData: PostgresBackendMessage.BackendKeyData) -> ConnectionAction { @@ -363,29 +362,29 @@ struct ConnectionStateMachine { .closing: return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.parameterStatus(status))) case .authenticated(let keyData, var parameters): - return self.avoidingStateMachineCoW { machine in - parameters[status.parameter] = status.value - machine.state = .authenticated(keyData, parameters) - return .wait - } + self.state = .modifying // avoid CoW + parameters[status.parameter] = status.value + self.state = .authenticated(keyData, parameters) + return .wait + case .readyForQuery(var connectionContext): - return self.avoidingStateMachineCoW { machine in - connectionContext.parameters[status.parameter] = status.value - machine.state = .readyForQuery(connectionContext) - return .wait - } + self.state = .modifying // avoid CoW + connectionContext.parameters[status.parameter] = status.value + self.state = .readyForQuery(connectionContext) + return .wait + case .extendedQuery(let query, var connectionContext): - return self.avoidingStateMachineCoW { machine in - connectionContext.parameters[status.parameter] = status.value - machine.state = .extendedQuery(query, connectionContext) - return .wait - } + self.state = .modifying // avoid CoW + connectionContext.parameters[status.parameter] = status.value + self.state = .extendedQuery(query, connectionContext) + return .wait + case .closeCommand(let closeState, var connectionContext): - return self.avoidingStateMachineCoW { machine in - connectionContext.parameters[status.parameter] = status.value - machine.state = .closeCommand(closeState, connectionContext) - return .wait - } + self.state = .modifying // avoid CoW + connectionContext.parameters[status.parameter] = status.value + self.state = .closeCommand(closeState, connectionContext) + return .wait + case .initialized, .closed: preconditionFailure("We shouldn't receive messages if we are not connected") @@ -407,29 +406,29 @@ struct ConnectionStateMachine { if authState.isComplete { return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.error(errorMessage))) } - return self.avoidingStateMachineCoW { machine -> ConnectionAction in - let action = authState.errorReceived(errorMessage) - machine.state = .authenticating(authState) - return machine.modify(with: action) - } + self.state = .modifying // avoid CoW + let action = authState.errorReceived(errorMessage) + self.state = .authenticating(authState) + return self.modify(with: action) + case .closeCommand(var closeStateMachine, let connectionContext): if closeStateMachine.isComplete { return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.error(errorMessage))) } - return self.avoidingStateMachineCoW { machine -> ConnectionAction in - let action = closeStateMachine.errorReceived(errorMessage) - machine.state = .closeCommand(closeStateMachine, connectionContext) - return machine.modify(with: action) - } + self.state = .modifying // avoid CoW + let action = closeStateMachine.errorReceived(errorMessage) + self.state = .closeCommand(closeStateMachine, connectionContext) + return self.modify(with: action) + case .extendedQuery(var extendedQueryState, let connectionContext): if extendedQueryState.isComplete { return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.error(errorMessage))) } - return self.avoidingStateMachineCoW { machine -> ConnectionAction in - let action = extendedQueryState.errorReceived(errorMessage) - machine.state = .extendedQuery(extendedQueryState, connectionContext) - return machine.modify(with: action) - } + self.state = .modifying // avoid CoW + let action = extendedQueryState.errorReceived(errorMessage) + self.state = .extendedQuery(extendedQueryState, connectionContext) + return self.modify(with: action) + case .closing: // If the state machine is in state `.closing`, the connection shutdown was initiated // by the client. This means a `TERMINATE` message has already been sent and the @@ -492,11 +491,11 @@ struct ConnectionStateMachine { mutating func noticeReceived(_ notice: PostgresBackendMessage.NoticeResponse) -> ConnectionAction { switch self.state { case .extendedQuery(var extendedQuery, let connectionContext): - return self.avoidingStateMachineCoW { machine -> ConnectionAction in - let action = extendedQuery.noticeReceived(notice) - machine.state = .extendedQuery(extendedQuery, connectionContext) - return machine.modify(with: action) - } + self.state = .modifying // avoid CoW + let action = extendedQuery.noticeReceived(notice) + self.state = .extendedQuery(extendedQuery, connectionContext) + return self.modify(with: action) + default: return .wait } @@ -612,11 +611,10 @@ struct ConnectionStateMachine { return .wait case .extendedQuery(var extendedQuery, let connectionContext): - return self.avoidingStateMachineCoW { machine in - let action = extendedQuery.channelReadComplete() - machine.state = .extendedQuery(extendedQuery, connectionContext) - return machine.modify(with: action) - } + self.state = .modifying // avoid CoW + let action = extendedQuery.channelReadComplete() + self.state = .extendedQuery(extendedQuery, connectionContext) + return self.modify(with: action) case .modifying: preconditionFailure("Invalid state") @@ -642,17 +640,17 @@ struct ConnectionStateMachine { case .readyForQuery: return .read case .extendedQuery(var extendedQuery, let connectionContext): - return self.avoidingStateMachineCoW { machine in - let action = extendedQuery.readEventCaught() - machine.state = .extendedQuery(extendedQuery, connectionContext) - return machine.modify(with: action) - } + self.state = .modifying // avoid CoW + let action = extendedQuery.readEventCaught() + self.state = .extendedQuery(extendedQuery, connectionContext) + return self.modify(with: action) + case .closeCommand(var closeState, let connectionContext): - return self.avoidingStateMachineCoW { machine in - let action = closeState.readEventCaught() - machine.state = .closeCommand(closeState, connectionContext) - return machine.modify(with: action) - } + self.state = .modifying // avoid CoW + let action = closeState.readEventCaught() + self.state = .closeCommand(closeState, connectionContext) + return self.modify(with: action) + case .closing: return .read case .closed: @@ -667,11 +665,11 @@ struct ConnectionStateMachine { mutating func parseCompleteReceived() -> ConnectionAction { switch self.state { case .extendedQuery(var queryState, let connectionContext) where !queryState.isComplete: - return self.avoidingStateMachineCoW { machine -> ConnectionAction in - let action = queryState.parseCompletedReceived() - machine.state = .extendedQuery(queryState, connectionContext) - return machine.modify(with: action) - } + self.state = .modifying // avoid CoW + let action = queryState.parseCompletedReceived() + self.state = .extendedQuery(queryState, connectionContext) + return self.modify(with: action) + default: return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.parseComplete)) } @@ -682,21 +680,20 @@ struct ConnectionStateMachine { return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.bindComplete)) } - return self.avoidingStateMachineCoW { machine -> ConnectionAction in - let action = queryState.bindCompleteReceived() - machine.state = .extendedQuery(queryState, connectionContext) - return machine.modify(with: action) - } + self.state = .modifying // avoid CoW + let action = queryState.bindCompleteReceived() + self.state = .extendedQuery(queryState, connectionContext) + return self.modify(with: action) } mutating func parameterDescriptionReceived(_ description: PostgresBackendMessage.ParameterDescription) -> ConnectionAction { switch self.state { case .extendedQuery(var queryState, let connectionContext) where !queryState.isComplete: - return self.avoidingStateMachineCoW { machine -> ConnectionAction in - let action = queryState.parameterDescriptionReceived(description) - machine.state = .extendedQuery(queryState, connectionContext) - return machine.modify(with: action) - } + self.state = .modifying // avoid CoW + let action = queryState.parameterDescriptionReceived(description) + self.state = .extendedQuery(queryState, connectionContext) + return self.modify(with: action) + default: return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.parameterDescription(description))) } @@ -705,11 +702,11 @@ struct ConnectionStateMachine { mutating func rowDescriptionReceived(_ description: RowDescription) -> ConnectionAction { switch self.state { case .extendedQuery(var queryState, let connectionContext) where !queryState.isComplete: - return self.avoidingStateMachineCoW { machine -> ConnectionAction in - let action = queryState.rowDescriptionReceived(description) - machine.state = .extendedQuery(queryState, connectionContext) - return machine.modify(with: action) - } + self.state = .modifying // avoid CoW + let action = queryState.rowDescriptionReceived(description) + self.state = .extendedQuery(queryState, connectionContext) + return self.modify(with: action) + default: return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.rowDescription(description))) } @@ -718,11 +715,11 @@ struct ConnectionStateMachine { mutating func noDataReceived() -> ConnectionAction { switch self.state { case .extendedQuery(var queryState, let connectionContext) where !queryState.isComplete: - return self.avoidingStateMachineCoW { machine -> ConnectionAction in - let action = queryState.noDataReceived() - machine.state = .extendedQuery(queryState, connectionContext) - return machine.modify(with: action) - } + self.state = .modifying // avoid CoW + let action = queryState.noDataReceived() + self.state = .extendedQuery(queryState, connectionContext) + return self.modify(with: action) + default: return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.noData)) } @@ -737,11 +734,10 @@ struct ConnectionStateMachine { return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.closeComplete)) } - return self.avoidingStateMachineCoW { machine -> ConnectionAction in - let action = closeState.closeCompletedReceived() - machine.state = .closeCommand(closeState, connectionContext) - return machine.modify(with: action) - } + self.state = .modifying // avoid CoW + let action = closeState.closeCompletedReceived() + self.state = .closeCommand(closeState, connectionContext) + return self.modify(with: action) } mutating func commandCompletedReceived(_ commandTag: String) -> ConnectionAction { @@ -749,11 +745,10 @@ struct ConnectionStateMachine { return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.commandComplete(commandTag))) } - return self.avoidingStateMachineCoW { machine -> ConnectionAction in - let action = queryState.commandCompletedReceived(commandTag) - machine.state = .extendedQuery(queryState, connectionContext) - return machine.modify(with: action) - } + self.state = .modifying // avoid CoW + let action = queryState.commandCompletedReceived(commandTag) + self.state = .extendedQuery(queryState, connectionContext) + return self.modify(with: action) } mutating func emptyQueryResponseReceived() -> ConnectionAction { @@ -761,11 +756,10 @@ struct ConnectionStateMachine { return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.emptyQueryResponse)) } - return self.avoidingStateMachineCoW { machine -> ConnectionAction in - let action = queryState.emptyQueryResponseReceived() - machine.state = .extendedQuery(queryState, connectionContext) - return machine.modify(with: action) - } + self.state = .modifying // avoid CoW + let action = queryState.emptyQueryResponseReceived() + self.state = .extendedQuery(queryState, connectionContext) + return self.modify(with: action) } mutating func dataRowReceived(_ dataRow: DataRow) -> ConnectionAction { @@ -773,11 +767,10 @@ struct ConnectionStateMachine { return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.dataRow(dataRow))) } - return self.avoidingStateMachineCoW { machine -> ConnectionAction in - let action = queryState.dataRowReceived(dataRow) - machine.state = .extendedQuery(queryState, connectionContext) - return machine.modify(with: action) - } + self.state = .modifying // avoid CoW + let action = queryState.dataRowReceived(dataRow) + self.state = .extendedQuery(queryState, connectionContext) + return self.modify(with: action) } // MARK: Consumer @@ -787,11 +780,10 @@ struct ConnectionStateMachine { preconditionFailure("Tried to cancel stream without active query") } - return self.avoidingStateMachineCoW { machine -> ConnectionAction in - let action = queryState.cancel() - machine.state = .extendedQuery(queryState, connectionContext) - return machine.modify(with: action) - } + self.state = .modifying // avoid CoW + let action = queryState.cancel() + self.state = .extendedQuery(queryState, connectionContext) + return self.modify(with: action) } mutating func requestQueryRows() -> ConnectionAction { @@ -799,11 +791,10 @@ struct ConnectionStateMachine { preconditionFailure("Tried to consume next row, without active query") } - return self.avoidingStateMachineCoW { machine -> ConnectionAction in - let action = queryState.requestQueryRows() - machine.state = .extendedQuery(queryState, connectionContext) - return machine.modify(with: action) - } + self.state = .modifying // avoid CoW + let action = queryState.requestQueryRows() + self.state = .extendedQuery(queryState, connectionContext) + return self.modify(with: action) } // MARK: - Private Methods - @@ -813,12 +804,11 @@ struct ConnectionStateMachine { preconditionFailure("Can only start authentication after connect or ssl establish") } - return self.avoidingStateMachineCoW { machine in - var authState = AuthenticationStateMachine(authContext: authContext) - let action = authState.start() - machine.state = .authenticating(authState) - return machine.modify(with: action) - } + self.state = .modifying // avoid CoW + var authState = AuthenticationStateMachine(authContext: authContext) + let action = authState.start() + self.state = .authenticating(authState) + return self.modify(with: action) } private mutating func closeConnectionAndCleanup(_ error: PSQLError, closePromise: EventLoopPromise? = nil) -> ConnectionAction { @@ -944,19 +934,18 @@ struct ConnectionStateMachine { switch task { case .extendedQuery(let queryContext): - return self.avoidingStateMachineCoW { machine -> ConnectionAction in - var extendedQuery = ExtendedQueryStateMachine(queryContext: queryContext) - let action = extendedQuery.start() - machine.state = .extendedQuery(extendedQuery, connectionContext) - return machine.modify(with: action) - } + self.state = .modifying // avoid CoW + var extendedQuery = ExtendedQueryStateMachine(queryContext: queryContext) + let action = extendedQuery.start() + self.state = .extendedQuery(extendedQuery, connectionContext) + return self.modify(with: action) + case .closeCommand(let closeContext): - return self.avoidingStateMachineCoW { machine -> ConnectionAction in - var closeStateMachine = CloseStateMachine(closeContext: closeContext) - let action = closeStateMachine.start() - machine.state = .closeCommand(closeStateMachine, connectionContext) - return machine.modify(with: action) - } + self.state = .modifying // avoid CoW + var closeStateMachine = CloseStateMachine(closeContext: closeContext) + let action = closeStateMachine.start() + self.state = .closeCommand(closeStateMachine, connectionContext) + return self.modify(with: action) } } @@ -965,43 +954,6 @@ struct ConnectionStateMachine { } } -// MARK: CoW helpers - -extension ConnectionStateMachine { - /// So, uh...this function needs some explaining. - /// - /// While the state machine logic above is great, there is a downside to having all of the state machine data in - /// associated data on enumerations: any modification of that data will trigger copy on write for heap-allocated - /// data. That means that for _every operation on the state machine_ we will CoW our underlying state, which is - /// not good. - /// - /// The way we can avoid this is by using this helper function. It will temporarily set state to a value with no - /// associated data, before attempting the body of the function. It will also verify that the state machine never - /// remains in this bad state. - /// - /// A key note here is that all callers must ensure that they return to a good state before they exit. - /// - /// Sadly, because it's generic and has a closure, we need to force it to be inlined at all call sites, which is - /// not ideal. - @inline(__always) - private mutating func avoidingStateMachineCoW(_ body: (inout ConnectionStateMachine) -> ReturnType) -> ReturnType { - self.state = .modifying - defer { - assert(!self.isModifying) - } - - return body(&self) - } - - private var isModifying: Bool { - if case .modifying = self.state { - return true - } else { - return false - } - } -} - extension ConnectionStateMachine { func shouldCloseConnection(reason error: PSQLError) -> Bool { switch error.code.base { From 92ee156a649b88f8926bcad6056cf77126b90405 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Mon, 18 Sep 2023 21:44:17 +0200 Subject: [PATCH 026/106] Update SSWG Graduation Level (#409) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index b4f8f70e..2123262f 100644 --- a/README.md +++ b/README.md @@ -6,7 +6,7 @@

-SSWG Incubation +SSWG Incubation Level: Graduated Documentation MIT License Continuous Integration From 4ab6d0aa7ac71f74f9d69094786a6d9e447b5722 Mon Sep 17 00:00:00 2001 From: Gwynne Raskind Date: Thu, 12 Oct 2023 15:21:42 -0500 Subject: [PATCH 027/106] Update minimum Swift requirement to 5.7 (#414) Bump required Swift to 5.7, update dependency version requirements, update CI for Swift and Postgres versions, do some interesting things with the API docs and README. --- .github/workflows/test.yml | 49 +++++++++--------- Package.swift | 16 +++--- README.md | 22 +++++--- .../Docs.docc/images/vapor-postgres-logo.svg | 51 +++++++++++++------ .../PostgresNIO/Docs.docc/theme-settings.json | 2 +- docker-compose.yml | 3 ++ 6 files changed, 89 insertions(+), 54 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 2da05f81..91895532 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -18,13 +18,13 @@ jobs: fail-fast: false matrix: swift-image: - - swift:5.6-focal - swift:5.7-jammy - swift:5.8-jammy - - swiftlang/swift:nightly-5.9-jammy + - swift:5.9-jammy + - swiftlang/swift:nightly-5.10-jammy - swiftlang/swift:nightly-main-jammy include: - - swift-image: swift:5.8-jammy + - swift-image: swift:5.9-jammy code-coverage: true container: ${{ matrix.swift-image }} runs-on: ubuntu-latest @@ -37,7 +37,7 @@ jobs: printf 'OS: %s\nTag: %s\nVersion:\n' "${SWIFT_PLATFORM}-${RUNNER_ARCH}" "${SWIFT_VERSION}" swift --version - name: Check out package - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Run unit tests with Thread Sanitizer env: CODE_COVERAGE: ${{ matrix.code-coverage && '--enable-code-coverage' || '' }} @@ -53,18 +53,18 @@ jobs: fail-fast: false matrix: postgres-image: - - postgres:15 - - postgres:13 - - postgres:11 + - postgres:16 + - postgres:14 + - postgres:12 include: - - postgres-image: postgres:15 + - postgres-image: postgres:16 postgres-auth: scram-sha-256 - - postgres-image: postgres:13 + - postgres-image: postgres:14 postgres-auth: md5 - - postgres-image: postgres:11 + - postgres-image: postgres:12 postgres-auth: trust container: - image: swift:5.8-jammy + image: swift:5.9-jammy volumes: [ 'pgrunshare:/var/run/postgresql' ] runs-on: ubuntu-latest env: @@ -109,15 +109,15 @@ jobs: [[ -z "${SWIFT_VERSION}" ]] && SWIFT_VERSION="$(cat /.swift_tag 2>/dev/null || true)" printf 'OS: %s\nTag: %s\nVersion:\n' "${SWIFT_PLATFORM}-${RUNNER_ARCH}" "${SWIFT_VERSION}" && swift --version - name: Check out package - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: { path: 'postgres-nio' } - name: Run integration tests run: swift test --package-path postgres-nio --filter=^IntegrationTests - name: Check out postgres-kit dependent - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: { repository: 'vapor/postgres-kit', path: 'postgres-kit' } - name: Check out fluent-postgres-driver dependent - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: { repository: 'vapor/fluent-postgres-driver', path: 'fluent-postgres-driver' } - name: Use local package in dependents run: | @@ -135,13 +135,13 @@ jobs: matrix: postgres-formula: # Only test one version on macOS, let Linux do the rest - - postgresql@14 + - postgresql@15 postgres-auth: # Only test one auth method on macOS, Linux tests will cover the others - scram-sha-256 xcode-version: - '~14.3' - - '15.0-beta' + - '~15.0' runs-on: macos-13 env: POSTGRES_HOSTNAME: 127.0.0.1 @@ -164,7 +164,7 @@ jobs: pg_ctl start --wait timeout-minutes: 2 - name: Checkout code - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Run all tests run: swift test @@ -174,21 +174,24 @@ jobs: container: swift:jammy steps: - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: fetch-depth: 0 # https://github.com/actions/checkout/issues/766 - - name: Mark the workspace as safe - run: git config --global --add safe.directory "${GITHUB_WORKSPACE}" - name: API breaking changes - run: swift package diagnose-api-breaking-changes origin/main + run: | + git config --global --add safe.directory "${GITHUB_WORKSPACE}" + swift package diagnose-api-breaking-changes origin/main gh-codeql: runs-on: ubuntu-latest - permissions: { security-events: write } + container: swift:5.8-jammy # CodeQL currently broken with 5.9 + permissions: { actions: write, contents: read, security-events: write } steps: - name: Check out code - uses: actions/checkout@v3 + uses: actions/checkout@v4 + - name: Mark repo safe in non-fake global config + run: git config --global --add safe.directory "${GITHUB_WORKSPACE}" - name: Initialize CodeQL uses: github/codeql-action/init@v2 with: diff --git a/Package.swift b/Package.swift index a45925ed..b3ff085c 100644 --- a/Package.swift +++ b/Package.swift @@ -1,4 +1,4 @@ -// swift-tools-version:5.6 +// swift-tools-version:5.7 import PackageDescription let package = Package( @@ -13,13 +13,13 @@ let package = Package( .library(name: "PostgresNIO", targets: ["PostgresNIO"]), ], dependencies: [ - .package(url: "https://github.com/apple/swift-atomics.git", from: "1.1.0"), - .package(url: "https://github.com/apple/swift-nio.git", from: "2.58.0"), - .package(url: "https://github.com/apple/swift-nio-transport-services.git", from: "1.18.0"), - .package(url: "https://github.com/apple/swift-nio-ssl.git", from: "2.23.1"), - .package(url: "https://github.com/apple/swift-crypto.git", "1.0.0" ..< "3.0.0"), - .package(url: "https://github.com/apple/swift-metrics.git", from: "2.0.0"), - .package(url: "https://github.com/apple/swift-log.git", from: "1.5.2"), + .package(url: "https://github.com/apple/swift-atomics.git", from: "1.2.0"), + .package(url: "https://github.com/apple/swift-nio.git", from: "2.59.0"), + .package(url: "https://github.com/apple/swift-nio-transport-services.git", from: "1.19.0"), + .package(url: "https://github.com/apple/swift-nio-ssl.git", from: "2.25.0"), + .package(url: "https://github.com/apple/swift-crypto.git", "2.0.0" ..< "4.0.0"), + .package(url: "https://github.com/apple/swift-metrics.git", from: "2.4.1"), + .package(url: "https://github.com/apple/swift-log.git", from: "1.5.3"), ], targets: [ .target( diff --git a/README.md b/README.md index 2123262f..bca6e82a 100644 --- a/README.md +++ b/README.md @@ -6,11 +6,21 @@

-SSWG Incubation Level: Graduated -Documentation -MIT License -Continuous Integration -Swift 5.6 + + Documentation + + + MIT License + + + Continuous Integration + + + Swift 5.7 - 5.9 + + + SSWG Incubation Level: Graduated +


🐘 Non-blocking, event-driven Swift client for PostgreSQL built on [SwiftNIO]. @@ -170,7 +180,7 @@ Please see [SECURITY.md] for details on the security process. [Team Chat]: https://discord.gg/vapor [MIT License]: LICENSE [Continuous Integration]: https://github.com/vapor/postgres-nio/actions -[Swift 5.6]: https://swift.org +[Swift 5.7]: https://swift.org [Security.md]: https://github.com/vapor/.github/blob/main/SECURITY.md [`PostgresConnection`]: https://swiftpackageindex.com/vapor/postgres-nio/documentation/postgresnio/postgresconnection/ diff --git a/Sources/PostgresNIO/Docs.docc/images/vapor-postgres-logo.svg b/Sources/PostgresNIO/Docs.docc/images/vapor-postgres-logo.svg index d118faab..2b3fe0b1 100644 --- a/Sources/PostgresNIO/Docs.docc/images/vapor-postgres-logo.svg +++ b/Sources/PostgresNIO/Docs.docc/images/vapor-postgres-logo.svg @@ -22,20 +22,39 @@ } PostgresNIO - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + diff --git a/Sources/PostgresNIO/Docs.docc/theme-settings.json b/Sources/PostgresNIO/Docs.docc/theme-settings.json index e9fc3d9d..a8042a54 100644 --- a/Sources/PostgresNIO/Docs.docc/theme-settings.json +++ b/Sources/PostgresNIO/Docs.docc/theme-settings.json @@ -18,7 +18,7 @@ }, "color": { "fill": { - "dark": "rgb(20, 20, 22)", + "dark": "rgb(0, 0, 0)", "light": "rgb(255, 255, 255)" }, "psql-blue": "#336791", diff --git a/docker-compose.yml b/docker-compose.yml index 68797651..3eff4249 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -10,6 +10,9 @@ x-shared-config: &shared_config - 5432:5432 services: + psql-16: + image: postgres:16 + <<: *shared_config psql-15: image: postgres:15 <<: *shared_config From 1a76cdc6dc9ba9a967b79a3593ec30ce34669f29 Mon Sep 17 00:00:00 2001 From: Gwynne Raskind Date: Thu, 12 Oct 2023 17:42:24 -0500 Subject: [PATCH 028/106] [skip ci] Fix up README --- README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index bca6e82a..489d0e29 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,7 @@

- Documentation + Documentation MIT License @@ -16,10 +16,10 @@ Continuous Integration - Swift 5.7 - 5.9 + Swift 5.7 - 5.9 - SSWG Incubation Level: Graduated + SSWG Incubation Level: Graduated


From d4d7bed0fde77934a829daed5113f95ceaa7aba0 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Fri, 13 Oct 2023 09:02:18 +0200 Subject: [PATCH 029/106] Add target `ConnectionPoolModule` (#412) Add `ConnectionPoolModule` We want to land a new ConnectionPool into PostgresNIO in the comming weeks. Since this pool is abstract, let's create a target and product for it. The target and product are both underscored, to signal that we don't make any API stability guarantees. --- Package.swift | 21 +++++++++++++++++++ Sources/ConnectionPoolModule/gitkeep.swift | 1 + Tests/ConnectionPoolModuleTests/gitkeep.swift | 1 + 3 files changed, 23 insertions(+) create mode 100644 Sources/ConnectionPoolModule/gitkeep.swift create mode 100644 Tests/ConnectionPoolModuleTests/gitkeep.swift diff --git a/Package.swift b/Package.swift index b3ff085c..814335bd 100644 --- a/Package.swift +++ b/Package.swift @@ -11,9 +11,11 @@ let package = Package( ], products: [ .library(name: "PostgresNIO", targets: ["PostgresNIO"]), + .library(name: "_ConnectionPoolModule", targets: ["_ConnectionPoolModule"]), ], dependencies: [ .package(url: "https://github.com/apple/swift-atomics.git", from: "1.2.0"), + .package(url: "https://github.com/apple/swift-collections.git", from: "1.0.4"), .package(url: "https://github.com/apple/swift-nio.git", from: "2.59.0"), .package(url: "https://github.com/apple/swift-nio-transport-services.git", from: "1.19.0"), .package(url: "https://github.com/apple/swift-nio-ssl.git", from: "2.25.0"), @@ -25,6 +27,7 @@ let package = Package( .target( name: "PostgresNIO", dependencies: [ + .target(name: "_ConnectionPoolModule"), .product(name: "Atomics", package: "swift-atomics"), .product(name: "Crypto", package: "swift-crypto"), .product(name: "Logging", package: "swift-log"), @@ -38,6 +41,14 @@ let package = Package( .product(name: "NIOFoundationCompat", package: "swift-nio"), ] ), + .target( + name: "_ConnectionPoolModule", + dependencies: [ + .product(name: "Atomics", package: "swift-atomics"), + .product(name: "DequeModule", package: "swift-collections"), + ], + path: "Sources/ConnectionPoolModule" + ), .testTarget( name: "PostgresNIOTests", dependencies: [ @@ -46,6 +57,16 @@ let package = Package( .product(name: "NIOTestUtils", package: "swift-nio"), ] ), + .testTarget( + name: "ConnectionPoolModuleTests", + dependencies: [ + .target(name: "_ConnectionPoolModule"), + .product(name: "DequeModule", package: "swift-collections"), + .product(name: "NIOCore", package: "swift-nio"), + .product(name: "NIOConcurrencyHelpers", package: "swift-nio"), + .product(name: "NIOEmbedded", package: "swift-nio"), + ] + ), .testTarget( name: "IntegrationTests", dependencies: [ diff --git a/Sources/ConnectionPoolModule/gitkeep.swift b/Sources/ConnectionPoolModule/gitkeep.swift new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/Sources/ConnectionPoolModule/gitkeep.swift @@ -0,0 +1 @@ + diff --git a/Tests/ConnectionPoolModuleTests/gitkeep.swift b/Tests/ConnectionPoolModuleTests/gitkeep.swift new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/Tests/ConnectionPoolModuleTests/gitkeep.swift @@ -0,0 +1 @@ + From d6d3510c7053246de7a673d999bd0ed6f23fe468 Mon Sep 17 00:00:00 2001 From: Gwynne Raskind Date: Fri, 13 Oct 2023 07:50:00 -0500 Subject: [PATCH 030/106] Fix test filter in CI --- .github/workflows/test.yml | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 91895532..cc34ddcd 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -42,13 +42,12 @@ jobs: env: CODE_COVERAGE: ${{ matrix.code-coverage && '--enable-code-coverage' || '' }} run: | - swift test --filter=^PostgresNIOTests --sanitize=thread ${CODE_COVERAGE} + swift test --filter='^(PostgresNIOTests|ConnectionPoolModuleTests)' --sanitize=thread ${CODE_COVERAGE} - name: Submit code coverage if: ${{ matrix.code-coverage }} uses: vapor/swift-codecov-action@v0.2 linux-integration-and-dependencies: - if: github.event_name == 'pull_request' strategy: fail-fast: false matrix: @@ -129,7 +128,6 @@ jobs: run: swift test --package-path fluent-postgres-driver macos-all: - if: github.event_name == 'pull_request' strategy: fail-fast: false matrix: From c6c28a6df558dabc338aa1c42a77de28a40d43b7 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Fri, 13 Oct 2023 15:58:01 +0200 Subject: [PATCH 031/106] Vendor SwiftNIO NIOLock into the new `ConnectionPoolModule` target (#416) The new `ConnectionPoolModule` shall be dependency free. But we need a lock. Let's vendor NIOLock from SwiftNIO. --- NOTICE.txt | 9 +- Sources/ConnectionPoolModule/NIOLock.swift | 268 +++++++++++++++++++++ Sources/ConnectionPoolModule/gitkeep.swift | 1 - 3 files changed, 276 insertions(+), 2 deletions(-) create mode 100644 Sources/ConnectionPoolModule/NIOLock.swift delete mode 100644 Sources/ConnectionPoolModule/gitkeep.swift diff --git a/NOTICE.txt b/NOTICE.txt index 9547a780..e704f7e6 100644 --- a/NOTICE.txt +++ b/NOTICE.txt @@ -2,7 +2,7 @@ // // This source file is part of the Vapor open source project // -// Copyright (c) 2017-2021 Vapor project authors +// Copyright (c) 2017-2023 Vapor project authors // Licensed under MIT // // See LICENSE for license information @@ -11,3 +11,10 @@ // //===----------------------------------------------------------------------===// +This product contains a derivation of the NIOLock implementation +from Swift NIO. + + * LICENSE (Apache License 2.0): + * https://www.apache.org/licenses/LICENSE-2.0 + * HOMEPAGE: + * https://github.com/apple/swift-nio diff --git a/Sources/ConnectionPoolModule/NIOLock.swift b/Sources/ConnectionPoolModule/NIOLock.swift new file mode 100644 index 00000000..dbc7dbe9 --- /dev/null +++ b/Sources/ConnectionPoolModule/NIOLock.swift @@ -0,0 +1,268 @@ +// Implementation vendored from SwiftNIO: +// https://github.com/apple/swift-nio + +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftNIO open source project +// +// Copyright (c) 2017-2022 Apple Inc. and the SwiftNIO project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftNIO project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +#if canImport(Darwin) +import Darwin +#elseif os(Windows) +import ucrt +import WinSDK +#elseif canImport(Glibc) +import Glibc +#elseif canImport(Musl) +import Musl +#else +#error("The concurrency NIOLock module was unable to identify your C library.") +#endif + +#if os(Windows) +@usableFromInline +typealias LockPrimitive = SRWLOCK +#else +@usableFromInline +typealias LockPrimitive = pthread_mutex_t +#endif + +@usableFromInline +enum LockOperations { } + +extension LockOperations { + @inlinable + static func create(_ mutex: UnsafeMutablePointer) { + mutex.assertValidAlignment() + +#if os(Windows) + InitializeSRWLock(mutex) +#else + var attr = pthread_mutexattr_t() + pthread_mutexattr_init(&attr) + debugOnly { + pthread_mutexattr_settype(&attr, .init(PTHREAD_MUTEX_ERRORCHECK)) + } + + let err = pthread_mutex_init(mutex, &attr) + precondition(err == 0, "\(#function) failed in pthread_mutex with error \(err)") +#endif + } + + @inlinable + static func destroy(_ mutex: UnsafeMutablePointer) { + mutex.assertValidAlignment() + +#if os(Windows) + // SRWLOCK does not need to be free'd +#else + let err = pthread_mutex_destroy(mutex) + precondition(err == 0, "\(#function) failed in pthread_mutex with error \(err)") +#endif + } + + @inlinable + static func lock(_ mutex: UnsafeMutablePointer) { + mutex.assertValidAlignment() + +#if os(Windows) + AcquireSRWLockExclusive(mutex) +#else + let err = pthread_mutex_lock(mutex) + precondition(err == 0, "\(#function) failed in pthread_mutex with error \(err)") +#endif + } + + @inlinable + static func unlock(_ mutex: UnsafeMutablePointer) { + mutex.assertValidAlignment() + +#if os(Windows) + ReleaseSRWLockExclusive(mutex) +#else + let err = pthread_mutex_unlock(mutex) + precondition(err == 0, "\(#function) failed in pthread_mutex with error \(err)") +#endif + } +} + +// Tail allocate both the mutex and a generic value using ManagedBuffer. +// Both the header pointer and the elements pointer are stable for +// the class's entire lifetime. +// +// However, for safety reasons, we elect to place the lock in the "elements" +// section of the buffer instead of the head. The reasoning here is subtle, +// so buckle in. +// +// _As a practical matter_, the implementation of ManagedBuffer ensures that +// the pointer to the header is stable across the lifetime of the class, and so +// each time you call `withUnsafeMutablePointers` or `withUnsafeMutablePointerToHeader` +// the value of the header pointer will be the same. This is because ManagedBuffer uses +// `Builtin.addressOf` to load the value of the header, and that does ~magic~ to ensure +// that it does not invoke any weird Swift accessors that might copy the value. +// +// _However_, the header is also available via the `.header` field on the ManagedBuffer. +// This presents a problem! The reason there's an issue is that `Builtin.addressOf` and friends +// do not interact with Swift's exclusivity model. That is, the various `with` functions do not +// conceptually trigger a mutating access to `.header`. For elements this isn't a concern because +// there's literally no other way to perform the access, but for `.header` it's entirely possible +// to accidentally recursively read it. +// +// Our implementation is free from these issues, so we don't _really_ need to worry about it. +// However, out of an abundance of caution, we store the Value in the header, and the LockPrimitive +// in the trailing elements. We still don't use `.header`, but it's better to be safe than sorry, +// and future maintainers will be happier that we were cautious. +// +// See also: https://github.com/apple/swift/pull/40000 +@usableFromInline +final class LockStorage: ManagedBuffer { + + @inlinable + static func create(value: Value) -> Self { + let buffer = Self.create(minimumCapacity: 1) { _ in + return value + } + let storage = unsafeDowncast(buffer, to: Self.self) + + storage.withUnsafeMutablePointers { _, lockPtr in + LockOperations.create(lockPtr) + } + + return storage + } + + @inlinable + func lock() { + self.withUnsafeMutablePointerToElements { lockPtr in + LockOperations.lock(lockPtr) + } + } + + @inlinable + func unlock() { + self.withUnsafeMutablePointerToElements { lockPtr in + LockOperations.unlock(lockPtr) + } + } + + @inlinable + deinit { + self.withUnsafeMutablePointerToElements { lockPtr in + LockOperations.destroy(lockPtr) + } + } + + @inlinable + func withLockPrimitive(_ body: (UnsafeMutablePointer) throws -> T) rethrows -> T { + try self.withUnsafeMutablePointerToElements { lockPtr in + return try body(lockPtr) + } + } + + @inlinable + func withLockedValue(_ mutate: (inout Value) throws -> T) rethrows -> T { + try self.withUnsafeMutablePointers { valuePtr, lockPtr in + LockOperations.lock(lockPtr) + defer { LockOperations.unlock(lockPtr) } + return try mutate(&valuePtr.pointee) + } + } +} + +extension LockStorage: @unchecked Sendable { } + +/// A threading lock based on `libpthread` instead of `libdispatch`. +/// +/// - note: ``NIOLock`` has reference semantics. +/// +/// This object provides a lock on top of a single `pthread_mutex_t`. This kind +/// of lock is safe to use with `libpthread`-based threading models, such as the +/// one used by NIO. On Windows, the lock is based on the substantially similar +/// `SRWLOCK` type. +@usableFromInline +struct NIOLock { + @usableFromInline + internal let _storage: LockStorage + + /// Create a new lock. + @inlinable + init() { + self._storage = .create(value: ()) + } + + /// Acquire the lock. + /// + /// Whenever possible, consider using `withLock` instead of this method and + /// `unlock`, to simplify lock handling. + @inlinable + func lock() { + self._storage.lock() + } + + /// Release the lock. + /// + /// Whenever possible, consider using `withLock` instead of this method and + /// `lock`, to simplify lock handling. + @inlinable + func unlock() { + self._storage.unlock() + } + + @inlinable + internal func withLockPrimitive(_ body: (UnsafeMutablePointer) throws -> T) rethrows -> T { + return try self._storage.withLockPrimitive(body) + } +} + +extension NIOLock { + /// Acquire the lock for the duration of the given block. + /// + /// This convenience method should be preferred to `lock` and `unlock` in + /// most situations, as it ensures that the lock will be released regardless + /// of how `body` exits. + /// + /// - Parameter body: The block to execute while holding the lock. + /// - Returns: The value returned by the block. + @inlinable + func withLock(_ body: () throws -> T) rethrows -> T { + self.lock() + defer { + self.unlock() + } + return try body() + } + + @inlinable + func withLockVoid(_ body: () throws -> Void) rethrows -> Void { + try self.withLock(body) + } +} + +extension NIOLock: Sendable {} + +extension UnsafeMutablePointer { + @inlinable + func assertValidAlignment() { + assert(UInt(bitPattern: self) % UInt(MemoryLayout.alignment) == 0) + } +} + +/// A utility function that runs the body code only in debug builds, without +/// emitting compiler warnings. +/// +/// This is currently the only way to do this in Swift: see +/// https://forums.swift.org/t/support-debug-only-code/11037 for a discussion. +@inlinable +internal func debugOnly(_ body: () -> Void) { + // FIXME: duplicated with NIO. + assert({ body(); return true }()) +} diff --git a/Sources/ConnectionPoolModule/gitkeep.swift b/Sources/ConnectionPoolModule/gitkeep.swift deleted file mode 100644 index 8b137891..00000000 --- a/Sources/ConnectionPoolModule/gitkeep.swift +++ /dev/null @@ -1 +0,0 @@ - From 8fbf8ff7309921ebe73a9500b6d6a8bca161861b Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Fri, 13 Oct 2023 16:38:59 +0200 Subject: [PATCH 032/106] Add `PooledConnection` protocol (#417) --- .../ConnectionPoolModule/ConnectionPool.swift | 58 +++++++++++++++++++ 1 file changed, 58 insertions(+) create mode 100644 Sources/ConnectionPoolModule/ConnectionPool.swift diff --git a/Sources/ConnectionPoolModule/ConnectionPool.swift b/Sources/ConnectionPoolModule/ConnectionPool.swift new file mode 100644 index 00000000..290e0679 --- /dev/null +++ b/Sources/ConnectionPoolModule/ConnectionPool.swift @@ -0,0 +1,58 @@ +/// A connection that can be pooled in a ``ConnectionPool`` +public protocol PooledConnection: AnyObject, Sendable { + /// The connections identifier type. + associatedtype ID: Hashable & Sendable + + /// The connections identifier. The identifier is passed to + /// the connection factory method and must stay attached to + /// the connection at all times. It must not change during + /// the connections lifetime. + var id: ID { get } + + /// A method to register closures that are invoked when the + /// connection is closed. If the connection closed unexpectedly + /// the closure shall be called with the underlying error. + /// In most NIO clients this can be easily implemented by + /// attaching to the `channel.closeFuture`: + /// ``` + /// func onClose( + /// _ closure: @escaping @Sendable ((any Error)?) -> () + /// ) { + /// channel.closeFuture.whenComplete { _ in + /// closure(previousError) + /// } + /// } + /// ``` + func onClose(_ closure: @escaping @Sendable ((any Error)?) -> ()) + + /// Close the running connection. Once the close has completed + /// closures that were registered in `onClose` must be + /// invoked. + func close() +} + +/// A connection id generator. Its returned connection IDs will +/// be used when creating new ``PooledConnection``s +public protocol ConnectionIDGeneratorProtocol: Sendable { + /// The connections identifier type. + associatedtype ID: Hashable & Sendable + + /// The next connection ID that shall be used. + func next() -> ID +} + +/// A keep alive behavior for connections maintained by the pool +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +public protocol ConnectionKeepAliveBehavior: Sendable { + /// the connection type + associatedtype Connection: PooledConnection + + /// The time after which a keep-alive shall + /// be triggered. + /// If nil is returned, keep-alive is deactivated + var keepAliveFrequency: Duration? { get } + + /// This method is invoked when the keep-alive shall be + /// run. + func runKeepAlive(for connection: Connection) async throws +} From 358fa598ae6fc2fc1cde213a0d2e8bd1eaf5b2eb Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Mon, 16 Oct 2023 13:20:09 +0200 Subject: [PATCH 033/106] Add `ConnectionIDGenerator` and `NoOpKeepAliveBehavior` (#418) --- .../ConnectionIDGenerator.swift | 15 ++++ .../NoKeepAliveBehavior.swift | 8 ++ .../ConnectionIDGeneratorTests.swift | 22 ++++++ .../Mocks/MockConnection.swift | 74 +++++++++++++++++++ .../NoKeepAliveBehaviorTests.swift | 10 +++ Tests/ConnectionPoolModuleTests/gitkeep.swift | 1 - 6 files changed, 129 insertions(+), 1 deletion(-) create mode 100644 Sources/ConnectionPoolModule/ConnectionIDGenerator.swift create mode 100644 Sources/ConnectionPoolModule/NoKeepAliveBehavior.swift create mode 100644 Tests/ConnectionPoolModuleTests/ConnectionIDGeneratorTests.swift create mode 100644 Tests/ConnectionPoolModuleTests/Mocks/MockConnection.swift create mode 100644 Tests/ConnectionPoolModuleTests/NoKeepAliveBehaviorTests.swift delete mode 100644 Tests/ConnectionPoolModuleTests/gitkeep.swift diff --git a/Sources/ConnectionPoolModule/ConnectionIDGenerator.swift b/Sources/ConnectionPoolModule/ConnectionIDGenerator.swift new file mode 100644 index 00000000..b428d805 --- /dev/null +++ b/Sources/ConnectionPoolModule/ConnectionIDGenerator.swift @@ -0,0 +1,15 @@ +import Atomics + +public struct ConnectionIDGenerator: ConnectionIDGeneratorProtocol { + static let globalGenerator = ConnectionIDGenerator() + + private let atomic: ManagedAtomic + + public init() { + self.atomic = .init(0) + } + + public func next() -> Int { + return self.atomic.loadThenWrappingIncrement(ordering: .relaxed) + } +} diff --git a/Sources/ConnectionPoolModule/NoKeepAliveBehavior.swift b/Sources/ConnectionPoolModule/NoKeepAliveBehavior.swift new file mode 100644 index 00000000..0a7b2dee --- /dev/null +++ b/Sources/ConnectionPoolModule/NoKeepAliveBehavior.swift @@ -0,0 +1,8 @@ +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +public struct NoOpKeepAliveBehavior: ConnectionKeepAliveBehavior { + public var keepAliveFrequency: Duration? { nil } + + public func runKeepAlive(for connection: Connection) async throws {} + + public init(connectionType: Connection.Type) {} +} diff --git a/Tests/ConnectionPoolModuleTests/ConnectionIDGeneratorTests.swift b/Tests/ConnectionPoolModuleTests/ConnectionIDGeneratorTests.swift new file mode 100644 index 00000000..fb0bfce1 --- /dev/null +++ b/Tests/ConnectionPoolModuleTests/ConnectionIDGeneratorTests.swift @@ -0,0 +1,22 @@ +import _ConnectionPoolModule +import XCTest + +final class ConnectionIDGeneratorTests: XCTestCase { + func testGenerateConnectionIDs() async { + let idGenerator = ConnectionIDGenerator() + + XCTAssertEqual(idGenerator.next(), 0) + XCTAssertEqual(idGenerator.next(), 1) + XCTAssertEqual(idGenerator.next(), 2) + + await withTaskGroup(of: Void.self) { taskGroup in + for _ in 0..<1000 { + taskGroup.addTask { + _ = idGenerator.next() + } + } + } + + XCTAssertEqual(idGenerator.next(), 1003) + } +} diff --git a/Tests/ConnectionPoolModuleTests/Mocks/MockConnection.swift b/Tests/ConnectionPoolModuleTests/Mocks/MockConnection.swift new file mode 100644 index 00000000..6a8ed297 --- /dev/null +++ b/Tests/ConnectionPoolModuleTests/Mocks/MockConnection.swift @@ -0,0 +1,74 @@ +import DequeModule +@testable import _ConnectionPoolModule + +// Sendability enforced through the lock +final class MockConnection: PooledConnection, @unchecked Sendable { + typealias ID = Int + + let id: ID + + private enum State { + case running([@Sendable ((any Error)?) -> ()]) + case closing([@Sendable ((any Error)?) -> ()]) + case closed + } + + private let lock = NIOLock() + private var _state = State.running([]) + + init(id: Int) { + self.id = id + } + + func onClose(_ closure: @escaping @Sendable ((any Error)?) -> ()) { + let enqueued = self.lock.withLock { () -> Bool in + switch self._state { + case .closed: + return false + + case .running(var callbacks): + callbacks.append(closure) + self._state = .running(callbacks) + return true + + case .closing(var callbacks): + callbacks.append(closure) + self._state = .closing(callbacks) + return true + } + } + + if !enqueued { + closure(nil) + } + } + + func close() { + self.lock.withLock { + switch self._state { + case .running(let callbacks): + self._state = .closing(callbacks) + + case .closing, .closed: + break + } + } + } + + func closeIfClosing() { + let callbacks = self.lock.withLock { () -> [@Sendable ((any Error)?) -> ()] in + switch self._state { + case .running, .closed: + return [] + + case .closing(let callbacks): + self._state = .closed + return callbacks + } + } + + for callback in callbacks { + callback(nil) + } + } +} diff --git a/Tests/ConnectionPoolModuleTests/NoKeepAliveBehaviorTests.swift b/Tests/ConnectionPoolModuleTests/NoKeepAliveBehaviorTests.swift new file mode 100644 index 00000000..b817ce19 --- /dev/null +++ b/Tests/ConnectionPoolModuleTests/NoKeepAliveBehaviorTests.swift @@ -0,0 +1,10 @@ +import _ConnectionPoolModule +import XCTest + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +final class NoKeepAliveBehaviorTests: XCTestCase { + func testNoKeepAlive() { + let keepAliveBehavior = NoOpKeepAliveBehavior(connectionType: MockConnection.self) + XCTAssertNil(keepAliveBehavior.keepAliveFrequency) + } +} diff --git a/Tests/ConnectionPoolModuleTests/gitkeep.swift b/Tests/ConnectionPoolModuleTests/gitkeep.swift deleted file mode 100644 index 8b137891..00000000 --- a/Tests/ConnectionPoolModuleTests/gitkeep.swift +++ /dev/null @@ -1 +0,0 @@ - From f5a04aab09b382e30129b8a86d7284412b549435 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Mon, 16 Oct 2023 15:33:30 +0200 Subject: [PATCH 034/106] Add `OneElementFastSequence` to be used in `ConnectionPool` (#420) --- .../OneElementFastSequence.swift | 151 ++++++++++++++++++ .../OneElementFastSequence.swift | 70 ++++++++ 2 files changed, 221 insertions(+) create mode 100644 Sources/ConnectionPoolModule/OneElementFastSequence.swift create mode 100644 Tests/ConnectionPoolModuleTests/OneElementFastSequence.swift diff --git a/Sources/ConnectionPoolModule/OneElementFastSequence.swift b/Sources/ConnectionPoolModule/OneElementFastSequence.swift new file mode 100644 index 00000000..1bb3b8e4 --- /dev/null +++ b/Sources/ConnectionPoolModule/OneElementFastSequence.swift @@ -0,0 +1,151 @@ +/// A `Sequence` that does not heap allocate, if it only carries a single element +@usableFromInline +struct OneElementFastSequence: Sequence { + @usableFromInline + enum Base { + case none(reserveCapacity: Int) + case one(Element, reserveCapacity: Int) + case n([Element]) + } + + @usableFromInline + private(set) var base: Base + + @inlinable + init() { + self.base = .none(reserveCapacity: 0) + } + + @inlinable + init(_ element: Element) { + self.base = .one(element, reserveCapacity: 1) + } + + @inlinable + init(_ collection: some Collection) { + switch collection.count { + case 0: + self.base = .none(reserveCapacity: 0) + case 1: + self.base = .one(collection.first!, reserveCapacity: 0) + default: + if let collection = collection as? Array { + self.base = .n(collection) + } else { + self.base = .n(Array(collection)) + } + } + } + + @usableFromInline + var count: Int { + switch self.base { + case .none: + return 0 + case .one: + return 1 + case .n(let array): + return array.count + } + } + + @inlinable + var first: Element? { + switch self.base { + case .none: + return nil + case .one(let element, _): + return element + case .n(let array): + return array.first + } + } + + @usableFromInline + var isEmpty: Bool { + switch self.base { + case .none: + return true + case .one, .n: + return false + } + } + + @inlinable + mutating func reserveCapacity(_ minimumCapacity: Int) { + switch self.base { + case .none(let reservedCapacity): + self.base = .none(reserveCapacity: Swift.max(reservedCapacity, minimumCapacity)) + case .one(let element, let reservedCapacity): + self.base = .one(element, reserveCapacity: Swift.max(reservedCapacity, minimumCapacity)) + case .n(var array): + self.base = .none(reserveCapacity: 0) // prevent CoW + array.reserveCapacity(minimumCapacity) + self.base = .n(array) + } + } + + @inlinable + mutating func append(_ element: Element) { + switch self.base { + case .none(let reserveCapacity): + self.base = .one(element, reserveCapacity: reserveCapacity) + case .one(let existing, let reserveCapacity): + var new = [Element]() + new.reserveCapacity(reserveCapacity) + new.append(existing) + new.append(element) + self.base = .n(new) + case .n(var existing): + self.base = .none(reserveCapacity: 0) // prevent CoW + existing.append(element) + self.base = .n(existing) + } + } + + @inlinable + func makeIterator() -> Iterator { + Iterator(self) + } + + @usableFromInline + struct Iterator: IteratorProtocol { + @usableFromInline private(set) var index: Int = 0 + @usableFromInline private(set) var backing: OneElementFastSequence + + @inlinable + init(_ backing: OneElementFastSequence) { + self.backing = backing + } + + @inlinable + mutating func next() -> Element? { + switch self.backing.base { + case .none: + return nil + case .one(let element, _): + if self.index == 0 { + self.index += 1 + return element + } + return nil + + case .n(let array): + if self.index < array.endIndex { + defer { self.index += 1} + return array[self.index] + } + return nil + } + } + } +} + +extension OneElementFastSequence: Equatable where Element: Equatable {} +extension OneElementFastSequence.Base: Equatable where Element: Equatable {} + +extension OneElementFastSequence: Hashable where Element: Hashable {} +extension OneElementFastSequence.Base: Hashable where Element: Hashable {} + +extension OneElementFastSequence: Sendable where Element: Sendable {} +extension OneElementFastSequence.Base: Sendable where Element: Sendable {} diff --git a/Tests/ConnectionPoolModuleTests/OneElementFastSequence.swift b/Tests/ConnectionPoolModuleTests/OneElementFastSequence.swift new file mode 100644 index 00000000..8098438f --- /dev/null +++ b/Tests/ConnectionPoolModuleTests/OneElementFastSequence.swift @@ -0,0 +1,70 @@ +@testable import _ConnectionPoolModule +import XCTest + +final class OneElementFastSequenceTests: XCTestCase { + func testCountIsEmptyAndIterator() async { + var sequence = OneElementFastSequence() + XCTAssertEqual(sequence.count, 0) + XCTAssertEqual(sequence.isEmpty, true) + XCTAssertEqual(sequence.first, nil) + XCTAssertEqual(Array(sequence), []) + sequence.append(1) + XCTAssertEqual(sequence.count, 1) + XCTAssertEqual(sequence.isEmpty, false) + XCTAssertEqual(sequence.first, 1) + XCTAssertEqual(Array(sequence), [1]) + sequence.append(2) + XCTAssertEqual(sequence.count, 2) + XCTAssertEqual(sequence.isEmpty, false) + XCTAssertEqual(sequence.first, 1) + XCTAssertEqual(Array(sequence), [1, 2]) + sequence.append(3) + XCTAssertEqual(sequence.count, 3) + XCTAssertEqual(sequence.isEmpty, false) + XCTAssertEqual(sequence.first, 1) + XCTAssertEqual(Array(sequence), [1, 2, 3]) + } + + func testReserveCapacityIsForwarded() { + var emptySequence = OneElementFastSequence() + emptySequence.reserveCapacity(8) + emptySequence.append(1) + emptySequence.append(2) + guard case .n(let array) = emptySequence.base else { + return XCTFail("Expected sequence to be backed by an array") + } + XCTAssertEqual(array.capacity, 8) + + var oneElemSequence = OneElementFastSequence(1) + oneElemSequence.reserveCapacity(8) + oneElemSequence.append(2) + guard case .n(let array) = oneElemSequence.base else { + return XCTFail("Expected sequence to be backed by an array") + } + XCTAssertEqual(array.capacity, 8) + + var twoElemSequence = OneElementFastSequence([1, 2]) + twoElemSequence.reserveCapacity(8) + guard case .n(let array) = twoElemSequence.base else { + return XCTFail("Expected sequence to be backed by an array") + } + XCTAssertEqual(array.capacity, 8) + } + + func testNewSequenceSlowPath() { + let sequence = OneElementFastSequence("AB".utf8) + XCTAssertEqual(Array(sequence), [UInt8(ascii: "A"), UInt8(ascii: "B")]) + } + + func testSingleItem() { + let sequence = OneElementFastSequence("A".utf8) + XCTAssertEqual(Array(sequence), [UInt8(ascii: "A")]) + } + + func testEmptyCollection() { + let sequence = OneElementFastSequence("".utf8) + XCTAssertTrue(sequence.isEmpty) + XCTAssertEqual(sequence.count, 0) + XCTAssertEqual(Array(sequence), []) + } +} From 5e75c9e24db385870e19578404635891490314bf Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Mon, 16 Oct 2023 15:40:42 +0200 Subject: [PATCH 035/106] Add `Max2Sequence` to be used in `ConnectionPool` (#419) --- .../ConnectionPoolModule/Max2Sequence.swift | 95 +++++++++++++++++++ .../Max2SequenceTests.swift | 60 ++++++++++++ 2 files changed, 155 insertions(+) create mode 100644 Sources/ConnectionPoolModule/Max2Sequence.swift create mode 100644 Tests/ConnectionPoolModuleTests/Max2SequenceTests.swift diff --git a/Sources/ConnectionPoolModule/Max2Sequence.swift b/Sources/ConnectionPoolModule/Max2Sequence.swift new file mode 100644 index 00000000..6c330067 --- /dev/null +++ b/Sources/ConnectionPoolModule/Max2Sequence.swift @@ -0,0 +1,95 @@ +// A `Sequence` that can contain at most two elements. However it does not heap allocate. +@usableFromInline +struct Max2Sequence: Sequence { + @usableFromInline + private(set) var first: Element? + @usableFromInline + private(set) var second: Element? + + @inlinable + var count: Int { + if self.first == nil { return 0 } + if self.second == nil { return 1 } + return 2 + } + + @inlinable + var isEmpty: Bool { + self.first == nil + } + + @inlinable + init(_ first: Element?, _ second: Element? = nil) { + if let first = first { + self.first = first + self.second = second + } else { + self.first = second + self.second = nil + } + } + + @inlinable + init() { + self.first = nil + self.second = nil + } + + @inlinable + func makeIterator() -> Iterator { + Iterator(first: self.first, second: self.second) + } + + @usableFromInline + struct Iterator: IteratorProtocol { + @usableFromInline + let first: Element? + @usableFromInline + let second: Element? + + @usableFromInline + private(set) var index: UInt8 = 0 + + @inlinable + init(first: Element?, second: Element?) { + self.first = first + self.second = second + self.index = 0 + } + + @inlinable + mutating func next() -> Element? { + switch self.index { + case 0: + self.index += 1 + return self.first + case 1: + self.index += 1 + return self.second + default: + return nil + } + } + } + + @inlinable + mutating func append(_ element: Element) { + precondition(self.second == nil) + if self.first == nil { + self.first = element + } else if self.second == nil { + self.second = element + } else { + fatalError("Max2Sequence can only hold two Elements.") + } + } + + @inlinable + func map(_ transform: (Element) throws -> (NewElement)) rethrows -> Max2Sequence { + try Max2Sequence(self.first.flatMap(transform), self.second.flatMap(transform)) + } +} + +extension Max2Sequence: Equatable where Element: Equatable {} +extension Max2Sequence: Hashable where Element: Hashable {} +extension Max2Sequence: Sendable where Element: Sendable {} diff --git a/Tests/ConnectionPoolModuleTests/Max2SequenceTests.swift b/Tests/ConnectionPoolModuleTests/Max2SequenceTests.swift new file mode 100644 index 00000000..081e867b --- /dev/null +++ b/Tests/ConnectionPoolModuleTests/Max2SequenceTests.swift @@ -0,0 +1,60 @@ +@testable import _ConnectionPoolModule +import XCTest + +final class Max2SequenceTests: XCTestCase { + func testCountAndIsEmpty() async { + var sequence = Max2Sequence() + XCTAssertEqual(sequence.count, 0) + XCTAssertEqual(sequence.isEmpty, true) + sequence.append(1) + XCTAssertEqual(sequence.count, 1) + XCTAssertEqual(sequence.isEmpty, false) + sequence.append(2) + XCTAssertEqual(sequence.count, 2) + XCTAssertEqual(sequence.isEmpty, false) + } + + func testOptionalInitializer() { + let emptySequence = Max2Sequence(nil, nil) + XCTAssertEqual(emptySequence.count, 0) + XCTAssertEqual(emptySequence.isEmpty, true) + var emptySequenceIterator = emptySequence.makeIterator() + XCTAssertNil(emptySequenceIterator.next()) + XCTAssertNil(emptySequenceIterator.next()) + XCTAssertNil(emptySequenceIterator.next()) + + let oneElemSequence1 = Max2Sequence(1, nil) + XCTAssertEqual(oneElemSequence1.count, 1) + XCTAssertEqual(oneElemSequence1.isEmpty, false) + var oneElemSequence1Iterator = oneElemSequence1.makeIterator() + XCTAssertEqual(oneElemSequence1Iterator.next(), 1) + XCTAssertNil(oneElemSequence1Iterator.next()) + XCTAssertNil(oneElemSequence1Iterator.next()) + + let oneElemSequence2 = Max2Sequence(nil, 2) + XCTAssertEqual(oneElemSequence2.count, 1) + XCTAssertEqual(oneElemSequence2.isEmpty, false) + var oneElemSequence2Iterator = oneElemSequence2.makeIterator() + XCTAssertEqual(oneElemSequence2Iterator.next(), 2) + XCTAssertNil(oneElemSequence2Iterator.next()) + XCTAssertNil(oneElemSequence2Iterator.next()) + + let twoElemSequence = Max2Sequence(1, 2) + XCTAssertEqual(twoElemSequence.count, 2) + XCTAssertEqual(twoElemSequence.isEmpty, false) + var twoElemSequenceIterator = twoElemSequence.makeIterator() + XCTAssertEqual(twoElemSequenceIterator.next(), 1) + XCTAssertEqual(twoElemSequenceIterator.next(), 2) + XCTAssertNil(twoElemSequenceIterator.next()) + } + + func testMap() { + let twoElemSequence = Max2Sequence(1, 2).map({ "\($0)" }) + XCTAssertEqual(twoElemSequence.count, 2) + XCTAssertEqual(twoElemSequence.isEmpty, false) + var twoElemSequenceIterator = twoElemSequence.makeIterator() + XCTAssertEqual(twoElemSequenceIterator.next(), "1") + XCTAssertEqual(twoElemSequenceIterator.next(), "2") + XCTAssertNil(twoElemSequenceIterator.next()) + } +} From a57baa7f7233646449f1fde2d3fd5670de7df870 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Tue, 17 Oct 2023 12:59:39 +0200 Subject: [PATCH 036/106] Add `ConnectionRequestProtocol`, `ConnectionPoolError` and `ConnectionPoolConfiguration` (#421) --- .../ConnectionPoolModule/ConnectionPool.swift | 58 +++++++++++++++++++ .../ConnectionPoolError.swift | 16 +++++ .../ConnectionRequest.swift | 20 +++++++ 3 files changed, 94 insertions(+) create mode 100644 Sources/ConnectionPoolModule/ConnectionPoolError.swift create mode 100644 Sources/ConnectionPoolModule/ConnectionRequest.swift diff --git a/Sources/ConnectionPoolModule/ConnectionPool.swift b/Sources/ConnectionPoolModule/ConnectionPool.swift index 290e0679..825c3ab3 100644 --- a/Sources/ConnectionPoolModule/ConnectionPool.swift +++ b/Sources/ConnectionPoolModule/ConnectionPool.swift @@ -56,3 +56,61 @@ public protocol ConnectionKeepAliveBehavior: Sendable { /// run. func runKeepAlive(for connection: Connection) async throws } + +/// A request to get a connection from the `ConnectionPool` +public protocol ConnectionRequestProtocol: Sendable { + /// A connection lease request ID type. + associatedtype ID: Hashable & Sendable + /// The leased connection type + associatedtype Connection: PooledConnection + + /// A connection lease request ID. This ID must be generated + /// by users of the `ConnectionPool` outside the + /// `ConnectionPool`. It is not generated inside the pool like + /// the `ConnectionID`s. The lease request ID must be unique + /// and must not change, if your implementing type is a + /// reference type. + var id: ID { get } + + /// A function that is called with a connection or a + /// `PoolError`. + func complete(with: Result) +} + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +public struct ConnectionPoolConfiguration { + /// The minimum number of connections to preserve in the pool. + /// + /// If the pool is mostly idle and the remote servers closes + /// idle connections, + /// the `ConnectionPool` will initiate new outbound + /// connections proactively to avoid the number of available + /// connections dropping below this number. + public var minimumConnectionCount: Int + + /// Between the `minimumConnectionCount` and + /// `maximumConnectionSoftLimit` the connection pool creates + /// _preserved_ connections. Preserved connections are closed + /// if they have been idle for ``idleTimeout``. + public var maximumConnectionSoftLimit: Int + + /// The maximum number of connections for this pool, that can + /// exist at any point in time. The pool can create _overflow_ + /// connections, if all connections are leased, and the + /// `maximumConnectionHardLimit` > `maximumConnectionSoftLimit ` + /// Overflow connections are closed immediately as soon as they + /// become idle. + public var maximumConnectionHardLimit: Int + + /// The time that a _preserved_ idle connection stays in the + /// pool before it is closed. + public var idleTimeout: Duration + + /// initializer + public init() { + self.minimumConnectionCount = 0 + self.maximumConnectionSoftLimit = 16 + self.maximumConnectionHardLimit = 16 + self.idleTimeout = .seconds(60) + } +} diff --git a/Sources/ConnectionPoolModule/ConnectionPoolError.swift b/Sources/ConnectionPoolModule/ConnectionPoolError.swift new file mode 100644 index 00000000..1f1e1d2c --- /dev/null +++ b/Sources/ConnectionPoolModule/ConnectionPoolError.swift @@ -0,0 +1,16 @@ + +public struct ConnectionPoolError: Error, Hashable { + enum Base: Error, Hashable { + case requestCancelled + case poolShutdown + } + + private let base: Base + + init(_ base: Base) { self.base = base } + + /// The connection requests got cancelled + public static let requestCancelled = ConnectionPoolError(.requestCancelled) + /// The connection requests can't be fulfilled as the pool has already been shutdown + public static let poolShutdown = ConnectionPoolError(.poolShutdown) +} diff --git a/Sources/ConnectionPoolModule/ConnectionRequest.swift b/Sources/ConnectionPoolModule/ConnectionRequest.swift new file mode 100644 index 00000000..34b77084 --- /dev/null +++ b/Sources/ConnectionPoolModule/ConnectionRequest.swift @@ -0,0 +1,20 @@ + +public struct ConnectionRequest: ConnectionRequestProtocol { + public typealias ID = Int + + public var id: ID + + private var continuation: CheckedContinuation + + init( + id: Int, + continuation: CheckedContinuation + ) { + self.id = id + self.continuation = continuation + } + + public func complete(with result: Result) { + self.continuation.resume(with: result) + } +} From c80a9347024892434d7c214eab8d194ee3a71bc0 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Tue, 17 Oct 2023 20:39:37 +0200 Subject: [PATCH 037/106] Add `ConnectionPoolObservabilityDelegate` (#422) --- .../ConnectionPoolObservabilityDelegate.swift | 62 +++++++++++++++++++ 1 file changed, 62 insertions(+) create mode 100644 Sources/ConnectionPoolModule/ConnectionPoolObservabilityDelegate.swift diff --git a/Sources/ConnectionPoolModule/ConnectionPoolObservabilityDelegate.swift b/Sources/ConnectionPoolModule/ConnectionPoolObservabilityDelegate.swift new file mode 100644 index 00000000..35f30dcb --- /dev/null +++ b/Sources/ConnectionPoolModule/ConnectionPoolObservabilityDelegate.swift @@ -0,0 +1,62 @@ + +public protocol ConnectionPoolObservabilityDelegate: Sendable { + associatedtype ConnectionID: Hashable & Sendable + + /// The connection with the given ID has started trying to establish a connection. The outcome + /// of the connection will be reported as either ``connectSucceeded(id:streamCapacity:)`` or + /// ``connectFailed(id:error:)``. + func startedConnecting(id: ConnectionID) + + /// A connection attempt failed with the given error. After some period of + /// time ``startedConnecting(id:)`` may be called again. + func connectFailed(id: ConnectionID, error: Error) + + /// A connection was established on the connection with the given ID. `streamCapacity` streams are + /// available to use on the connection. The maximum number of available streams may change over + /// time and is reported via ````. The + func connectSucceeded(id: ConnectionID, streamCapacity: UInt16) + + /// The utlization of the connection changed; a stream may have been used, returned or the + /// maximum number of concurrent streams available on the connection changed. + func connectionUtilizationChanged(id:ConnectionID, streamsUsed: UInt16, streamCapacity: UInt16) + + func keepAliveTriggered(id: ConnectionID) + + func keepAliveSucceeded(id: ConnectionID) + + func keepAliveFailed(id: ConnectionID, error: Error) + + /// The remote peer is quiescing the connection: no new streams will be created on it. The + /// connection will eventually be closed and removed from the pool. + func connectionClosing(id: ConnectionID) + + /// The connection was closed. The connection may be established again in the future (notified + /// via ``startedConnecting(id:)``). + func connectionClosed(id: ConnectionID, error: Error?) + + func requestQueueDepthChanged(_ newDepth: Int) +} + +public struct NoOpConnectionPoolMetrics: ConnectionPoolObservabilityDelegate { + public init(connectionIDType: ConnectionID.Type) {} + + public func startedConnecting(id: ConnectionID) {} + + public func connectFailed(id: ConnectionID, error: Error) {} + + public func connectSucceeded(id: ConnectionID, streamCapacity: UInt16) {} + + public func connectionUtilizationChanged(id: ConnectionID, streamsUsed: UInt16, streamCapacity: UInt16) {} + + public func keepAliveTriggered(id: ConnectionID) {} + + public func keepAliveSucceeded(id: ConnectionID) {} + + public func keepAliveFailed(id: ConnectionID, error: Error) {} + + public func connectionClosing(id: ConnectionID) {} + + public func connectionClosed(id: ConnectionID, error: Error?) {} + + public func requestQueueDepthChanged(_ newDepth: Int) {} +} From 8babbcff00e879173779f0d59b3fa413af4282c9 Mon Sep 17 00:00:00 2001 From: Mahdi Bahrami Date: Wed, 18 Oct 2023 16:08:07 +0330 Subject: [PATCH 038/106] Fix `PostgresDecodable` inference for `RawRepresentable` enums (#423) --- .../RawRepresentable+PostgresCodable.swift | 2 +- .../PSQLIntegrationTests.swift | 39 ++++++++++++++++++- 2 files changed, 38 insertions(+), 3 deletions(-) diff --git a/Sources/PostgresNIO/New/Data/RawRepresentable+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/RawRepresentable+PostgresCodable.swift index 4d6c20c4..ea097963 100644 --- a/Sources/PostgresNIO/New/Data/RawRepresentable+PostgresCodable.swift +++ b/Sources/PostgresNIO/New/Data/RawRepresentable+PostgresCodable.swift @@ -19,7 +19,7 @@ extension PostgresEncodable where Self: RawRepresentable, RawValue: PostgresEnco } extension PostgresDecodable where Self: RawRepresentable, RawValue: PostgresDecodable, RawValue._DecodableType == RawValue { - init( + public init( from buffer: inout ByteBuffer, type: PostgresDataType, format: PostgresFormat, diff --git a/Tests/IntegrationTests/PSQLIntegrationTests.swift b/Tests/IntegrationTests/PSQLIntegrationTests.swift index 4b2b9950..0550dc77 100644 --- a/Tests/IntegrationTests/PSQLIntegrationTests.swift +++ b/Tests/IntegrationTests/PSQLIntegrationTests.swift @@ -1,6 +1,6 @@ import XCTest import Logging -@testable import PostgresNIO +import PostgresNIO import NIOCore import NIOPosix import NIOTestUtils @@ -252,7 +252,7 @@ final class IntegrationTests: XCTestCase { XCTAssertNoThrow(result = try conn?.query(""" SELECT \(Decimal(string: "123456.789123")!)::numeric as numeric, - \(Decimal(string: "-123456.789123")!)::numeric as numeric_negative + \(Decimal(string: "-123456.789123")!)::numeric as numeric_negative """, logger: .psqlTest).wait()) XCTAssertEqual(result?.rows.count, 1) @@ -263,6 +263,41 @@ final class IntegrationTests: XCTestCase { XCTAssertEqual(cells?.1, Decimal(string: "-123456.789123")) } + func testDecodeRawRepresentables() { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + let eventLoop = eventLoopGroup.next() + + var conn: PostgresConnection? + XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) + defer { XCTAssertNoThrow(try conn?.close().wait()) } + + enum StringRR: String, PostgresDecodable { + case a + } + + enum IntRR: Int, PostgresDecodable { + case b + } + + let stringValue = StringRR.a + let intValue = IntRR.b + + var result: PostgresQueryResult? + XCTAssertNoThrow(result = try conn?.query(""" + SELECT + \(stringValue.rawValue)::varchar as string, + \(intValue.rawValue)::int8 as int + """, logger: .psqlTest).wait()) + XCTAssertEqual(result?.rows.count, 1) + + var cells: (StringRR, IntRR)? + XCTAssertNoThrow(cells = try result?.rows.first?.decode((StringRR, IntRR).self, context: .default)) + + XCTAssertEqual(cells?.0, stringValue) + XCTAssertEqual(cells?.1, intValue) + } + func testRoundTripUUID() { let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } From 56419669833c265c4096df5341ae22f5753849cd Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Wed, 18 Oct 2023 15:46:13 +0200 Subject: [PATCH 039/106] Add `PoolStateMachine.RequestQueue` (#424) --- .../ConnectionRequest.swift | 6 +- .../OneElementFastSequence.swift | 2 +- .../PoolStateMachine+RequestQueue.swift | 71 +++++++++ .../PoolStateMachine.swift | 74 +++++++++ .../ConnectionRequestTests.swift | 27 ++++ .../Mocks/MockRequest.swift | 28 ++++ .../Mocks/MockTimerCancellationToken.swift | 16 ++ .../OneElementFastSequence.swift | 2 +- .../PoolStateMachine+RequestQueueTests.swift | 147 ++++++++++++++++++ .../PoolStateMachineTests.swift | 14 ++ 10 files changed, 383 insertions(+), 4 deletions(-) create mode 100644 Sources/ConnectionPoolModule/PoolStateMachine+RequestQueue.swift create mode 100644 Sources/ConnectionPoolModule/PoolStateMachine.swift create mode 100644 Tests/ConnectionPoolModuleTests/ConnectionRequestTests.swift create mode 100644 Tests/ConnectionPoolModuleTests/Mocks/MockRequest.swift create mode 100644 Tests/ConnectionPoolModuleTests/Mocks/MockTimerCancellationToken.swift create mode 100644 Tests/ConnectionPoolModuleTests/PoolStateMachine+RequestQueueTests.swift create mode 100644 Tests/ConnectionPoolModuleTests/PoolStateMachineTests.swift diff --git a/Sources/ConnectionPoolModule/ConnectionRequest.swift b/Sources/ConnectionPoolModule/ConnectionRequest.swift index 34b77084..fd01bb76 100644 --- a/Sources/ConnectionPoolModule/ConnectionRequest.swift +++ b/Sources/ConnectionPoolModule/ConnectionRequest.swift @@ -4,11 +4,13 @@ public struct ConnectionRequest: ConnectionRequest public var id: ID - private var continuation: CheckedContinuation + @usableFromInline + private(set) var continuation: CheckedContinuation + @inlinable init( id: Int, - continuation: CheckedContinuation + continuation: CheckedContinuation ) { self.id = id self.continuation = continuation diff --git a/Sources/ConnectionPoolModule/OneElementFastSequence.swift b/Sources/ConnectionPoolModule/OneElementFastSequence.swift index 1bb3b8e4..3c3bfaa0 100644 --- a/Sources/ConnectionPoolModule/OneElementFastSequence.swift +++ b/Sources/ConnectionPoolModule/OneElementFastSequence.swift @@ -17,7 +17,7 @@ struct OneElementFastSequence: Sequence { } @inlinable - init(_ element: Element) { + init(element: Element) { self.base = .one(element, reserveCapacity: 1) } diff --git a/Sources/ConnectionPoolModule/PoolStateMachine+RequestQueue.swift b/Sources/ConnectionPoolModule/PoolStateMachine+RequestQueue.swift new file mode 100644 index 00000000..7e3c6607 --- /dev/null +++ b/Sources/ConnectionPoolModule/PoolStateMachine+RequestQueue.swift @@ -0,0 +1,71 @@ +import DequeModule + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +extension PoolStateMachine { + + /// A request queue, which can enqueue requests in O(1), dequeue requests in O(1) and even cancel requests in O(1). + /// + /// While enqueueing and dequeueing on O(1) is trivial, cancellation is hard, as it normally requires a removal within the + /// underlying Deque. However thanks to having an additional `requests` dictionary, we can remove the cancelled + /// request from the dictionary and keep it inside the queue. Whenever we pop a request from the deque, we validate + /// that it hasn't been cancelled in the meantime by checking if the popped request is still in the `requests` dictionary. + @usableFromInline + struct RequestQueue { + @usableFromInline + private(set) var queue: Deque + + @usableFromInline + private(set) var requests: [RequestID: Request] + + @inlinable + var count: Int { + self.requests.count + } + + @inlinable + var isEmpty: Bool { + self.count == 0 + } + + @usableFromInline + init() { + self.queue = .init(minimumCapacity: 256) + self.requests = .init(minimumCapacity: 256) + } + + @inlinable + mutating func queue(_ request: Request) { + self.requests[request.id] = request + self.queue.append(request.id) + } + + @inlinable + mutating func pop(max: UInt16) -> OneElementFastSequence { + var result = OneElementFastSequence() + result.reserveCapacity(Int(max)) + var popped = 0 + while let requestID = self.queue.popFirst(), popped < max { + if let requestIndex = self.requests.index(forKey: requestID) { + popped += 1 + result.append(self.requests.remove(at: requestIndex).value) + } + } + + assert(result.count <= max) + return result + } + + @inlinable + mutating func remove(_ requestID: RequestID) -> Request? { + self.requests.removeValue(forKey: requestID) + } + + @inlinable + mutating func removeAll() -> OneElementFastSequence { + let result = OneElementFastSequence(self.requests.values) + self.requests.removeAll() + self.queue.removeAll() + return result + } + } +} diff --git a/Sources/ConnectionPoolModule/PoolStateMachine.swift b/Sources/ConnectionPoolModule/PoolStateMachine.swift new file mode 100644 index 00000000..a3962790 --- /dev/null +++ b/Sources/ConnectionPoolModule/PoolStateMachine.swift @@ -0,0 +1,74 @@ +#if canImport(Darwin) +import Darwin +#else +import Glibc +#endif + +@usableFromInline +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +struct PoolConfiguration { + /// The minimum number of connections to preserve in the pool. + /// + /// If the pool is mostly idle and the remote servers closes idle connections, + /// the `ConnectionPool` will initiate new outbound connections proactively + /// to avoid the number of available connections dropping below this number. + @usableFromInline + var minimumConnectionCount: Int = 0 + + /// The maximum number of connections to for this pool, to be preserved. + @usableFromInline + var maximumConnectionSoftLimit: Int = 10 + + @usableFromInline + var maximumConnectionHardLimit: Int = 10 + + @usableFromInline + var keepAliveDuration: Duration? + + @usableFromInline + var idleTimeoutDuration: Duration = .seconds(30) +} + +@usableFromInline +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +struct PoolStateMachine< + Connection: PooledConnection, + ConnectionIDGenerator: ConnectionIDGeneratorProtocol, + ConnectionID: Hashable & Sendable, + Request: ConnectionRequestProtocol, + RequestID, + TimerCancellationToken +> where Connection.ID == ConnectionID, ConnectionIDGenerator.ID == ConnectionID, RequestID == Request.ID { + + @usableFromInline + struct Timer: Hashable, Sendable { + @usableFromInline + enum Usecase: Sendable { + case backoff + case idleTimeout + case keepAlive + } + + @usableFromInline + var connectionID: ConnectionID + + @usableFromInline + var timerID: Int + + @usableFromInline + var duration: Duration + + @usableFromInline + var usecase: Usecase + + @inlinable + init(connectionID: ConnectionID, timerID: Int, duration: Duration, usecase: Usecase) { + self.connectionID = connectionID + self.timerID = timerID + self.duration = duration + self.usecase = usecase + } + } + + +} diff --git a/Tests/ConnectionPoolModuleTests/ConnectionRequestTests.swift b/Tests/ConnectionPoolModuleTests/ConnectionRequestTests.swift new file mode 100644 index 00000000..5845267f --- /dev/null +++ b/Tests/ConnectionPoolModuleTests/ConnectionRequestTests.swift @@ -0,0 +1,27 @@ +@testable import _ConnectionPoolModule +import XCTest + +final class ConnectionRequestTests: XCTestCase { + + func testHappyPath() async throws { + let mockConnection = MockConnection(id: 1) + let connection = try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in + let request = ConnectionRequest(id: 42, continuation: continuation) + XCTAssertEqual(request.id, 42) + continuation.resume(with: .success(mockConnection)) + } + + XCTAssert(connection === mockConnection) + } + + func testSadPath() async throws { + do { + _ = try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in + continuation.resume(with: .failure(ConnectionPoolError.requestCancelled)) + } + XCTFail("This point should not be reached") + } catch { + XCTAssertEqual(error as? ConnectionPoolError, .requestCancelled) + } + } +} diff --git a/Tests/ConnectionPoolModuleTests/Mocks/MockRequest.swift b/Tests/ConnectionPoolModuleTests/Mocks/MockRequest.swift new file mode 100644 index 00000000..6aaa9c91 --- /dev/null +++ b/Tests/ConnectionPoolModuleTests/Mocks/MockRequest.swift @@ -0,0 +1,28 @@ +import _ConnectionPoolModule + +final class MockRequest: ConnectionRequestProtocol, Hashable, Sendable { + typealias Connection = MockConnection + + struct ID: Hashable { + var objectID: ObjectIdentifier + + init(_ request: MockRequest) { + self.objectID = ObjectIdentifier(request) + } + } + + var id: ID { ID(self) } + + + static func ==(lhs: MockRequest, rhs: MockRequest) -> Bool { + lhs.id == rhs.id + } + + func hash(into hasher: inout Hasher) { + hasher.combine(self.id) + } + + func complete(with: Result) { + + } +} diff --git a/Tests/ConnectionPoolModuleTests/Mocks/MockTimerCancellationToken.swift b/Tests/ConnectionPoolModuleTests/Mocks/MockTimerCancellationToken.swift new file mode 100644 index 00000000..20434450 --- /dev/null +++ b/Tests/ConnectionPoolModuleTests/Mocks/MockTimerCancellationToken.swift @@ -0,0 +1,16 @@ +@testable import _ConnectionPoolModule + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +struct MockTimerCancellationToken: Hashable, Sendable { + var connectionID: MockConnection.ID + var timerID: Int + var duration: Duration + var usecase: TestPoolStateMachine.Timer.Usecase + + init(_ timer: TestPoolStateMachine.Timer) { + self.connectionID = timer.connectionID + self.timerID = timer.timerID + self.duration = timer.duration + self.usecase = timer.usecase + } +} diff --git a/Tests/ConnectionPoolModuleTests/OneElementFastSequence.swift b/Tests/ConnectionPoolModuleTests/OneElementFastSequence.swift index 8098438f..a086341e 100644 --- a/Tests/ConnectionPoolModuleTests/OneElementFastSequence.swift +++ b/Tests/ConnectionPoolModuleTests/OneElementFastSequence.swift @@ -35,7 +35,7 @@ final class OneElementFastSequenceTests: XCTestCase { } XCTAssertEqual(array.capacity, 8) - var oneElemSequence = OneElementFastSequence(1) + var oneElemSequence = OneElementFastSequence(element: 1) oneElemSequence.reserveCapacity(8) oneElemSequence.append(2) guard case .n(let array) = oneElemSequence.base else { diff --git a/Tests/ConnectionPoolModuleTests/PoolStateMachine+RequestQueueTests.swift b/Tests/ConnectionPoolModuleTests/PoolStateMachine+RequestQueueTests.swift new file mode 100644 index 00000000..0231da51 --- /dev/null +++ b/Tests/ConnectionPoolModuleTests/PoolStateMachine+RequestQueueTests.swift @@ -0,0 +1,147 @@ +@testable import _ConnectionPoolModule +import XCTest + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +final class PoolStateMachine_RequestQueueTests: XCTestCase { + + typealias TestQueue = TestPoolStateMachine.RequestQueue + + func testHappyPath() { + var queue = TestQueue() + XCTAssert(queue.isEmpty) + + let request1 = MockRequest() + queue.queue(request1) + XCTAssertEqual(queue.count, 1) + XCTAssertFalse(queue.isEmpty) + let popResult = queue.pop(max: 3) + XCTAssert(popResult.elementsEqual([request1])) + XCTAssert(queue.isEmpty) + XCTAssertEqual(queue.count, 0) + } + + func testEnqueueAndPopMultipleRequests() { + var queue = TestQueue() + XCTAssert(queue.isEmpty) + + var request1 = MockRequest() + queue.queue(request1) + var request2 = MockRequest() + queue.queue(request2) + var request3 = MockRequest() + queue.queue(request3) + + do { + XCTAssertEqual(queue.count, 3) + XCTAssertFalse(queue.isEmpty) + let popResult = queue.pop(max: 3) + XCTAssert(popResult.elementsEqual([request1, request2, request3])) + XCTAssert(queue.isEmpty) + XCTAssertEqual(queue.count, 0) + } + XCTAssert(isKnownUniquelyReferenced(&request1)) + XCTAssert(isKnownUniquelyReferenced(&request2)) + XCTAssert(isKnownUniquelyReferenced(&request3)) + } + + func testEnqueueAndPopOnlyOne() { + var queue = TestQueue() + XCTAssert(queue.isEmpty) + + var request1 = MockRequest() + queue.queue(request1) + var request2 = MockRequest() + queue.queue(request2) + var request3 = MockRequest() + queue.queue(request3) + + do { + XCTAssertEqual(queue.count, 3) + XCTAssertFalse(queue.isEmpty) + let popResult = queue.pop(max: 1) + XCTAssert(popResult.elementsEqual([request1])) + XCTAssertFalse(queue.isEmpty) + XCTAssertEqual(queue.count, 2) + + let removeAllResult = queue.removeAll() + XCTAssert(Set(removeAllResult) == [request2, request3]) + } + XCTAssert(isKnownUniquelyReferenced(&request1)) + XCTAssert(isKnownUniquelyReferenced(&request2)) + XCTAssert(isKnownUniquelyReferenced(&request3)) + } + + func testCancellation() { + var queue = TestQueue() + XCTAssert(queue.isEmpty) + + var request1 = MockRequest() + queue.queue(request1) + var request2 = MockRequest() + queue.queue(request2) + var request3 = MockRequest() + queue.queue(request3) + + do { + XCTAssertEqual(queue.count, 3) + let returnedRequest2 = queue.remove(request2.id) + XCTAssert(returnedRequest2 === request2) + XCTAssertEqual(queue.count, 2) + XCTAssertFalse(queue.isEmpty) + } + + // still retained by the deque inside the queue + XCTAssertEqual(queue.requests.count, 2) + XCTAssertEqual(queue.queue.count, 3) + + do { + XCTAssertEqual(queue.count, 2) + XCTAssertFalse(queue.isEmpty) + let popResult = queue.pop(max: 3) + XCTAssert(popResult.elementsEqual([request1, request3])) + XCTAssert(queue.isEmpty) + XCTAssertEqual(queue.count, 0) + } + + XCTAssert(isKnownUniquelyReferenced(&request1)) + XCTAssert(isKnownUniquelyReferenced(&request2)) + XCTAssert(isKnownUniquelyReferenced(&request3)) + } + + func testRemoveAllAfterCancellation() { + var queue = TestQueue() + XCTAssert(queue.isEmpty) + + var request1 = MockRequest() + queue.queue(request1) + var request2 = MockRequest() + queue.queue(request2) + var request3 = MockRequest() + queue.queue(request3) + + do { + XCTAssertEqual(queue.count, 3) + let returnedRequest2 = queue.remove(request2.id) + XCTAssert(returnedRequest2 === request2) + XCTAssertEqual(queue.count, 2) + XCTAssertFalse(queue.isEmpty) + } + + // still retained by the deque inside the queue + XCTAssertEqual(queue.requests.count, 2) + XCTAssertEqual(queue.queue.count, 3) + + do { + XCTAssertEqual(queue.count, 2) + XCTAssertFalse(queue.isEmpty) + let removeAllResult = queue.removeAll() + XCTAssert(Set(removeAllResult) == [request1, request3]) + XCTAssert(queue.isEmpty) + XCTAssertEqual(queue.count, 0) + } + + XCTAssert(isKnownUniquelyReferenced(&request1)) + XCTAssert(isKnownUniquelyReferenced(&request2)) + XCTAssert(isKnownUniquelyReferenced(&request3)) + } +} diff --git a/Tests/ConnectionPoolModuleTests/PoolStateMachineTests.swift b/Tests/ConnectionPoolModuleTests/PoolStateMachineTests.swift new file mode 100644 index 00000000..ee8cfdc6 --- /dev/null +++ b/Tests/ConnectionPoolModuleTests/PoolStateMachineTests.swift @@ -0,0 +1,14 @@ +import NIOCore +import NIOEmbedded +import XCTest +@testable import _ConnectionPoolModule + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +typealias TestPoolStateMachine = PoolStateMachine< + MockConnection, + ConnectionIDGenerator, + MockConnection.ID, + MockRequest, + MockRequest.ID, + MockTimerCancellationToken +> From 20a8c340ed4984b6c85aabd27a38fa5b2d780ee0 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Wed, 18 Oct 2023 22:37:54 +0200 Subject: [PATCH 040/106] Add `PoolStateMachine.ConnectionState` (#425) --- .../ConnectionPoolModule/Max2Sequence.swift | 10 + .../PoolStateMachine+ConnectionState.swift | 584 ++++++++++++++++++ .../PoolStateMachine.swift | 53 +- .../Mocks/MockTimerCancellationToken.swift | 18 +- ...oolStateMachine+ConnectionStateTests.swift | 264 ++++++++ 5 files changed, 904 insertions(+), 25 deletions(-) create mode 100644 Sources/ConnectionPoolModule/PoolStateMachine+ConnectionState.swift create mode 100644 Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionStateTests.swift diff --git a/Sources/ConnectionPoolModule/Max2Sequence.swift b/Sources/ConnectionPoolModule/Max2Sequence.swift index 6c330067..0feccd68 100644 --- a/Sources/ConnectionPoolModule/Max2Sequence.swift +++ b/Sources/ConnectionPoolModule/Max2Sequence.swift @@ -90,6 +90,16 @@ struct Max2Sequence: Sequence { } } +extension Max2Sequence: ExpressibleByArrayLiteral { + @inlinable + init(arrayLiteral elements: Element...) { + precondition(elements.count <= 2) + var iterator = elements.makeIterator() + self.first = iterator.next() + self.second = iterator.next() + } +} + extension Max2Sequence: Equatable where Element: Equatable {} extension Max2Sequence: Hashable where Element: Hashable {} extension Max2Sequence: Sendable where Element: Sendable {} diff --git a/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionState.swift b/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionState.swift new file mode 100644 index 00000000..51ab5323 --- /dev/null +++ b/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionState.swift @@ -0,0 +1,584 @@ +import Atomics + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +extension PoolStateMachine { + + @usableFromInline + struct KeepAliveAction { + @usableFromInline + var connection: Connection + @usableFromInline + var keepAliveTimerCancellationContinuation: TimerCancellationToken? + + @inlinable + init(connection: Connection, keepAliveTimerCancellationContinuation: TimerCancellationToken? = nil) { + self.connection = connection + self.keepAliveTimerCancellationContinuation = keepAliveTimerCancellationContinuation + } + } + + @usableFromInline + struct ConnectionTimer: Hashable, Sendable { + @usableFromInline + enum Usecase: Hashable, Sendable { + case backoff + case keepAlive + case idleTimeout + } + + @usableFromInline + var timerID: Int + + @usableFromInline + var connectionID: Connection.ID + + @usableFromInline + var usecase: Usecase + + @inlinable + init(timerID: Int, connectionID: Connection.ID, usecase: Usecase) { + self.timerID = timerID + self.connectionID = connectionID + self.usecase = usecase + } + } + + @usableFromInline + /// An connection state machine about the pool's view on the connection. + struct ConnectionState { + @usableFromInline + enum State { + @usableFromInline + enum KeepAlive { + case notScheduled + case scheduled(Timer) + case running(_ consumingStream: Bool) + + @inlinable + var usedStreams: UInt16 { + switch self { + case .notScheduled, .scheduled, .running(false): + return 0 + case .running(true): + return 1 + } + } + + @inlinable + var isRunning: Bool { + switch self { + case .running: + return true + case .notScheduled, .scheduled: + return false + } + } + + @inlinable + mutating func cancelTimerIfScheduled() -> TimerCancellationToken? { + switch self { + case .scheduled(let timer): + self = .notScheduled + return timer.cancellationContinuation + case .running, .notScheduled: + return nil + } + } + } + + @usableFromInline + struct Timer { + @usableFromInline + let timerID: Int + + @usableFromInline + private(set) var cancellationContinuation: TimerCancellationToken? + + @inlinable + init(id: Int) { + self.timerID = id + self.cancellationContinuation = nil + } + + @inlinable + mutating func registerCancellationContinuation(_ continuation: TimerCancellationToken) { + precondition(self.cancellationContinuation == nil) + self.cancellationContinuation = continuation + } + } + + /// The pool is creating a connection. Valid transitions are to: `.backingOff`, `.idle`, and `.closed` + case starting + /// The pool is waiting to retry establishing a connection. Valid transitions are to: `.closed`. + /// This means, the connection can be removed from the connections without cancelling external + /// state. The connection state can then be replaced by a new one. + case backingOff(Timer) + /// The connection is `idle` and ready to execute a new query. Valid transitions to: `.pingpong`, `.leased`, + /// `.closing` and `.closed` + case idle(Connection, maxStreams: UInt16, keepAlive: KeepAlive, idleTimer: Timer?) + /// The connection is leased and executing a query. Valid transitions to: `.idle` and `.closed` + case leased(Connection, usedStreams: UInt16, maxStreams: UInt16, keepAlive: KeepAlive) + /// The connection is closing. Valid transitions to: `.closed` + case closing(Connection) + /// The connection is closed. Final state. + case closed + } + + @usableFromInline + let id: Connection.ID + + @usableFromInline + private(set) var state: State = .starting + + @usableFromInline + private(set) var nextTimerID: Int = 0 + + @inlinable + init(id: Connection.ID) { + self.id = id + } + + @inlinable + var isIdle: Bool { + switch self.state { + case .idle(_, _, .notScheduled, _), .idle(_, _, .scheduled, _): + return true + case .idle(_, _, .running, _): + return false + case .backingOff, .starting, .closed, .closing, .leased: + return false + } + } + + @inlinable + var isAvailable: Bool { + switch self.state { + case .idle(_, let maxStreams, .running(true), _): + return maxStreams > 1 + case .idle(_, let maxStreams, let keepAlive, _): + return keepAlive.usedStreams < maxStreams + case .leased(_, let usedStreams, let maxStreams, let keepAlive): + return usedStreams + keepAlive.usedStreams < maxStreams + case .backingOff, .starting, .closed, .closing: + return false + } + } + + @usableFromInline + var isLeased: Bool { + switch self.state { + case .leased: + return true + case .backingOff, .starting, .closed, .closing, .idle: + return false + } + } + + @usableFromInline + var isIdleOrRunningKeepAlive: Bool { + switch self.state { + case .idle: + return true + case .backingOff, .starting, .closed, .closing, .leased: + return false + } + } + + @usableFromInline + var isConnected: Bool { + switch self.state { + case .idle, .leased: + return true + case .backingOff, .starting, .closed, .closing: + return false + } + } + + @inlinable + mutating func connected(_ connection: Connection, maxStreams: UInt16) -> ConnectionAvailableInfo { + switch self.state { + case .starting: + self.state = .idle(connection, maxStreams: maxStreams, keepAlive: .notScheduled, idleTimer: nil) + return .idle(availableStreams: maxStreams, newIdle: true) + case .backingOff, .idle, .leased, .closing, .closed: + preconditionFailure("Invalid state: \(self.state)") + } + } + + @inlinable + mutating func parkConnection(scheduleKeepAliveTimer: Bool, scheduleIdleTimeoutTimer: Bool) -> Max2Sequence { + var keepAliveTimer: ConnectionTimer? + var keepAliveTimerState: State.Timer? + var idleTimer: ConnectionTimer? + var idleTimerState: State.Timer? + + switch self.state { + case .backingOff, .starting, .leased, .closing, .closed: + preconditionFailure("Invalid state: \(self.state)") + + case .idle(let connection, let maxStreams, .notScheduled, .none): + let keepAlive: State.KeepAlive + if scheduleKeepAliveTimer { + keepAliveTimerState = self.nextTimer() + keepAliveTimer = ConnectionTimer(timerID: keepAliveTimerState!.timerID, connectionID: self.id, usecase: .keepAlive) + keepAlive = .scheduled(keepAliveTimerState!) + } else { + keepAlive = .notScheduled + } + if scheduleIdleTimeoutTimer { + idleTimerState = self.nextTimer() + idleTimer = ConnectionTimer(timerID: idleTimerState!.timerID, connectionID: self.id, usecase: .idleTimeout) + } + self.state = .idle(connection, maxStreams: maxStreams, keepAlive: keepAlive, idleTimer: idleTimerState) + return Max2Sequence(keepAliveTimer, idleTimer) + + case .idle(_, _, .scheduled, .some): + precondition(!scheduleKeepAliveTimer) + precondition(!scheduleIdleTimeoutTimer) + return Max2Sequence() + + case .idle(let connection, let maxStreams, .notScheduled, let idleTimer): + precondition(!scheduleIdleTimeoutTimer) + let keepAlive: State.KeepAlive + if scheduleKeepAliveTimer { + keepAliveTimerState = self.nextTimer() + keepAliveTimer = ConnectionTimer(timerID: keepAliveTimerState!.timerID, connectionID: self.id, usecase: .keepAlive) + keepAlive = .scheduled(keepAliveTimerState!) + } else { + keepAlive = .notScheduled + } + self.state = .idle(connection, maxStreams: maxStreams, keepAlive: keepAlive, idleTimer: idleTimer) + return Max2Sequence(keepAliveTimer) + + case .idle(let connection, let maxStreams, .scheduled(let keepAliveTimer), .none): + precondition(!scheduleKeepAliveTimer) + + if scheduleIdleTimeoutTimer { + idleTimerState = self.nextTimer() + idleTimer = ConnectionTimer(timerID: idleTimerState!.timerID, connectionID: self.id, usecase: .keepAlive) + } + self.state = .idle(connection, maxStreams: maxStreams, keepAlive: .scheduled(keepAliveTimer), idleTimer: idleTimerState) + return Max2Sequence(idleTimer, nil) + + case .idle(let connection, let maxStreams, keepAlive: .running(let usingStream), idleTimer: .none): + if scheduleIdleTimeoutTimer { + idleTimerState = self.nextTimer() + idleTimer = ConnectionTimer(timerID: idleTimerState!.timerID, connectionID: self.id, usecase: .keepAlive) + } + self.state = .idle(connection, maxStreams: maxStreams, keepAlive: .running(usingStream), idleTimer: idleTimerState) + return Max2Sequence(keepAliveTimer, idleTimer) + + case .idle(_, _, keepAlive: .running(_), idleTimer: .some): + precondition(!scheduleKeepAliveTimer) + precondition(!scheduleIdleTimeoutTimer) + return Max2Sequence() + } + } + + @inlinable + mutating func nextTimer() -> State.Timer { + defer { self.nextTimerID += 1 } + return State.Timer(id: self.nextTimerID) + } + + /// The connection failed to start + @inlinable + mutating func failedToConnect() -> ConnectionTimer { + switch self.state { + case .starting: + let backoffTimerState = self.nextTimer() + self.state = .backingOff(backoffTimerState) + return ConnectionTimer(timerID: backoffTimerState.timerID, connectionID: self.id, usecase: .backoff) + + case .backingOff, .idle, .leased, .closing, .closed: + preconditionFailure("Invalid state: \(self.state)") + } + } + + /// Moves a connection, that has previously ``failedToConnect()`` back into the connecting state. + /// + /// - Returns: A ``TimerCancellationToken`` that was previously registered with the state machine + /// for the ``ConnectionTimer`` returned in ``failedToConnect()``. If no token was registered + /// nil is returned. + @inlinable + mutating func retryConnect() -> TimerCancellationToken? { + switch self.state { + case .backingOff(let timer): + self.state = .starting + return timer.cancellationContinuation + case .starting, .idle, .leased, .closing, .closed: + preconditionFailure("Invalid state: \(self.state)") + } + } + + @usableFromInline + struct LeaseAction { + @usableFromInline + var connection: Connection + @usableFromInline + var timersToCancel: Max2Sequence + @usableFromInline + var wasIdle: Bool + + @inlinable + init(connection: Connection, timersToCancel: Max2Sequence, wasIdle: Bool) { + self.connection = connection + self.timersToCancel = timersToCancel + self.wasIdle = wasIdle + } + } + + @inlinable + mutating func lease(streams newLeasedStreams: UInt16 = 1) -> LeaseAction { + switch self.state { + case .idle(let connection, let maxStreams, var keepAlive, let idleTimer): + var cancel = Max2Sequence() + if let token = idleTimer?.cancellationContinuation { + cancel.append(token) + } + if let token = keepAlive.cancelTimerIfScheduled() { + cancel.append(token) + } + precondition(maxStreams >= newLeasedStreams + keepAlive.usedStreams, "Invalid state: \(self.state)") + self.state = .leased(connection, usedStreams: newLeasedStreams, maxStreams: maxStreams, keepAlive: keepAlive) + return LeaseAction(connection: connection, timersToCancel: cancel, wasIdle: true) + + case .leased(let connection, let usedStreams, let maxStreams, let keepAlive): + precondition(maxStreams >= usedStreams + newLeasedStreams + keepAlive.usedStreams, "Invalid state: \(self.state)") + self.state = .leased(connection, usedStreams: usedStreams + newLeasedStreams, maxStreams: maxStreams, keepAlive: keepAlive) + return LeaseAction(connection: connection, timersToCancel: .init(), wasIdle: false) + + case .backingOff, .starting, .closing, .closed: + preconditionFailure("Invalid state: \(self.state)") + } + } + + @inlinable + mutating func release(streams returnedStreams: UInt16) -> ConnectionAvailableInfo { + switch self.state { + case .leased(let connection, let usedStreams, let maxStreams, let keepAlive): + precondition(usedStreams >= returnedStreams) + let newUsedStreams = usedStreams - returnedStreams + let availableStreams = maxStreams - (newUsedStreams + keepAlive.usedStreams) + if newUsedStreams == 0 { + self.state = .idle(connection, maxStreams: maxStreams, keepAlive: keepAlive, idleTimer: nil) + return .idle(availableStreams: availableStreams, newIdle: true) + } else { + self.state = .leased(connection, usedStreams: newUsedStreams, maxStreams: maxStreams, keepAlive: keepAlive) + return .leased(availableStreams: availableStreams) + } + case .backingOff, .starting, .idle, .closing, .closed: + preconditionFailure("Invalid state: \(self.state)") + } + } + + @inlinable + mutating func runKeepAliveIfIdle(reducesAvailableStreams: Bool) -> KeepAliveAction? { + switch self.state { + case .idle(let connection, let maxStreams, .scheduled(let timer), let idleTimer): + self.state = .idle(connection, maxStreams: maxStreams, keepAlive: .running(reducesAvailableStreams), idleTimer: idleTimer) + return KeepAliveAction( + connection: connection, + keepAliveTimerCancellationContinuation: timer.cancellationContinuation + ) + + case .leased, .closed, .closing: + return nil + + case .backingOff, .starting, .idle(_, _, .running, _), .idle(_, _, .notScheduled, _): + preconditionFailure("Invalid state: \(self.state)") + } + } + + @inlinable + mutating func keepAliveSucceeded() -> ConnectionAvailableInfo? { + switch self.state { + case .idle(let connection, let maxStreams, .running, let idleTimer): + self.state = .idle(connection, maxStreams: maxStreams, keepAlive: .notScheduled, idleTimer: idleTimer) + return .idle(availableStreams: maxStreams, newIdle: false) + + case .leased(let connection, let usedStreams, let maxStreams, .running): + self.state = .leased(connection, usedStreams: usedStreams, maxStreams: maxStreams, keepAlive: .notScheduled) + return .leased(availableStreams: maxStreams - usedStreams) + + case .closed, .closing: + return nil + + case .backingOff, .starting, + .leased(_, _, _, .notScheduled), + .leased(_, _, _, .scheduled), + .idle(_, _, .notScheduled, _), + .idle(_, _, .scheduled, _): + preconditionFailure("Invalid state: \(self.state)") + } + } + + @inlinable + mutating func timerScheduled( + _ timer: ConnectionTimer, + cancelContinuation: TimerCancellationToken + ) -> TimerCancellationToken? { + switch timer.usecase { + case .backoff: + switch self.state { + case .backingOff(var timerState): + if timerState.timerID == timer.timerID { + timerState.registerCancellationContinuation(cancelContinuation) + self.state = .backingOff(timerState) + return nil + } else { + return cancelContinuation + } + + case .starting, .idle, .leased, .closing, .closed: + return cancelContinuation + } + + case .idleTimeout: + switch self.state { + case .idle(let connection, let maxStreams, let keepAlive, let idleTimerState): + if var idleTimerState = idleTimerState, idleTimerState.timerID == timer.timerID { + idleTimerState.registerCancellationContinuation(cancelContinuation) + self.state = .idle(connection, maxStreams: maxStreams, keepAlive: keepAlive, idleTimer: idleTimerState) + return nil + } else { + return cancelContinuation + } + + case .starting, .backingOff, .leased, .closing, .closed: + return cancelContinuation + } + + case .keepAlive: + switch self.state { + case .idle(let connection, let maxStreams, .scheduled(var keepAliveTimerState), let idleTimerState): + if keepAliveTimerState.timerID == timer.timerID { + keepAliveTimerState.registerCancellationContinuation(cancelContinuation) + self.state = .idle(connection, maxStreams: maxStreams, keepAlive: .scheduled(keepAliveTimerState), idleTimer: idleTimerState) + return nil + } else { + return cancelContinuation + } + + case .starting, .backingOff, .leased, .closing, .closed, + .idle(_, _, .running, _), + .idle(_, _, .notScheduled, _): + return cancelContinuation + } + } + } + + @usableFromInline + struct CloseAction { + @usableFromInline + var connection: Connection + @usableFromInline + var cancelTimers: Max2Sequence + @usableFromInline + var maxStreams: UInt16 + + @inlinable + init(connection: Connection, cancelTimers: Max2Sequence, maxStreams: UInt16) { + self.connection = connection + self.cancelTimers = cancelTimers + self.maxStreams = maxStreams + } + } + + @inlinable + mutating func close() -> CloseAction { + switch self.state { + case .idle(let connection, let maxStreams, var keepAlive, let idleTimerState): + self.state = .closing(connection) + return CloseAction( + connection: connection, + cancelTimers: Max2Sequence( + keepAlive.cancelTimerIfScheduled(), + idleTimerState?.cancellationContinuation + ), + maxStreams: maxStreams + ) + + case .backingOff, .starting, .leased, .closing, .closed: + preconditionFailure("Invalid state: \(self.state)") + } + } + + @inlinable + mutating func closeIfIdle() -> CloseAction? { + switch self.state { + case .idle: + return self.close() + case .leased, .closed: + return nil + case .backingOff, .starting, .closing: + preconditionFailure("Invalid state: \(self.state)") + } + } + + @usableFromInline + struct ShutdownAction { + @usableFromInline + var connection: Connection? + @usableFromInline + var timersToCancel: Max2Sequence + @usableFromInline + var maxStreams: UInt16 + @usableFromInline + var usedStreams: UInt16 + + @inlinable + init( + connection: Connection? = nil, + timersToCancel: Max2Sequence = .init(), + maxStreams: UInt16 = 0, + usedStreams: UInt16 = 0 + ) { + self.connection = connection + self.timersToCancel = timersToCancel + self.maxStreams = maxStreams + self.usedStreams = usedStreams + } + } + } + + @usableFromInline + enum ConnectionAvailableInfo: Equatable { + case leased(availableStreams: UInt16) + case idle(availableStreams: UInt16, newIdle: Bool) + + @usableFromInline + var availableStreams: UInt16 { + switch self { + case .leased(let availableStreams): + return availableStreams + case .idle(let availableStreams, newIdle: _): + return availableStreams + } + } + } +} + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +extension PoolStateMachine.KeepAliveAction: Equatable where TimerCancellationToken: Equatable { + @inlinable + static func == (lhs: Self, rhs: Self) -> Bool { + lhs.connection === rhs.connection && lhs.keepAliveTimerCancellationContinuation == rhs.keepAliveTimerCancellationContinuation + } +} + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +extension PoolStateMachine.ConnectionState.LeaseAction: Equatable where TimerCancellationToken: Equatable { + @inlinable + static func == (lhs: Self, rhs: Self) -> Bool { + lhs.wasIdle == rhs.wasIdle && lhs.connection === rhs.connection && lhs.timersToCancel == rhs.timersToCancel + } +} + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +extension PoolStateMachine.ConnectionState.CloseAction: Equatable where TimerCancellationToken: Equatable { + @inlinable + static func == (lhs: Self, rhs: Self) -> Bool { + lhs.cancelTimers == rhs.cancelTimers && lhs.connection === rhs.connection && lhs.maxStreams == rhs.maxStreams + } +} diff --git a/Sources/ConnectionPoolModule/PoolStateMachine.swift b/Sources/ConnectionPoolModule/PoolStateMachine.swift index a3962790..dc18784f 100644 --- a/Sources/ConnectionPoolModule/PoolStateMachine.swift +++ b/Sources/ConnectionPoolModule/PoolStateMachine.swift @@ -39,36 +39,55 @@ struct PoolStateMachine< RequestID, TimerCancellationToken > where Connection.ID == ConnectionID, ConnectionIDGenerator.ID == ConnectionID, RequestID == Request.ID { + + @usableFromInline + struct ConnectionRequest: Equatable { + @usableFromInline var connectionID: ConnectionID + + @inlinable + init(connectionID: ConnectionID) { + self.connectionID = connectionID + } + } @usableFromInline - struct Timer: Hashable, Sendable { + enum ConnectionAction { @usableFromInline - enum Usecase: Sendable { - case backoff - case idleTimeout - case keepAlive + struct Shutdown { + @usableFromInline + var connections: [Connection] + @usableFromInline + var timersToCancel: [TimerCancellationToken] + + @inlinable + init() { + self.connections = [] + self.timersToCancel = [] + } } - @usableFromInline - var connectionID: ConnectionID + case scheduleTimers(Max2Sequence) + case makeConnection(ConnectionRequest, TimerCancellationToken?) + case runKeepAlive(Connection, TimerCancellationToken?) + case cancelTimers(Max2Sequence) + case closeConnection(Connection) + case shutdown(Shutdown) - @usableFromInline - var timerID: Int + case none + } + @usableFromInline + struct Timer: Hashable, Sendable { @usableFromInline - var duration: Duration + var underlying: ConnectionTimer @usableFromInline - var usecase: Usecase + var duration: Duration @inlinable - init(connectionID: ConnectionID, timerID: Int, duration: Duration, usecase: Usecase) { - self.connectionID = connectionID - self.timerID = timerID + init(_ connectionTimer: ConnectionTimer, duration: Duration) { + self.underlying = connectionTimer self.duration = duration - self.usecase = usecase } } - - } diff --git a/Tests/ConnectionPoolModuleTests/Mocks/MockTimerCancellationToken.swift b/Tests/ConnectionPoolModuleTests/Mocks/MockTimerCancellationToken.swift index 20434450..27035ee9 100644 --- a/Tests/ConnectionPoolModuleTests/Mocks/MockTimerCancellationToken.swift +++ b/Tests/ConnectionPoolModuleTests/Mocks/MockTimerCancellationToken.swift @@ -2,15 +2,17 @@ @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) struct MockTimerCancellationToken: Hashable, Sendable { - var connectionID: MockConnection.ID - var timerID: Int - var duration: Duration - var usecase: TestPoolStateMachine.Timer.Usecase + enum Backing: Hashable, Sendable { + case timer(TestPoolStateMachine.Timer) + case connectionTimer(TestPoolStateMachine.ConnectionTimer) + } + var backing: Backing init(_ timer: TestPoolStateMachine.Timer) { - self.connectionID = timer.connectionID - self.timerID = timer.timerID - self.duration = timer.duration - self.usecase = timer.usecase + self.backing = .timer(timer) + } + + init(_ timer: TestPoolStateMachine.ConnectionTimer) { + self.backing = .connectionTimer(timer) } } diff --git a/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionStateTests.swift b/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionStateTests.swift new file mode 100644 index 00000000..b1622d0d --- /dev/null +++ b/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionStateTests.swift @@ -0,0 +1,264 @@ +@testable import _ConnectionPoolModule +import XCTest + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +final class PoolStateMachine_ConnectionStateTests: XCTestCase { + + typealias TestConnectionState = TestPoolStateMachine.ConnectionState + + func testStartupLeaseReleaseParkLease() { + let connectionID = 1 + var state = TestConnectionState(id: connectionID) + XCTAssertEqual(state.id, connectionID) + XCTAssertEqual(state.isIdleOrRunningKeepAlive, false) + XCTAssertEqual(state.isAvailable, false) + XCTAssertEqual(state.isConnected, false) + XCTAssertEqual(state.isLeased, false) + let connection = MockConnection(id: connectionID) + XCTAssertEqual(state.connected(connection, maxStreams: 1), .idle(availableStreams: 1, newIdle: true)) + XCTAssertEqual(state.isIdleOrRunningKeepAlive, true) + XCTAssertEqual(state.isAvailable, true) + XCTAssertEqual(state.isConnected, true) + XCTAssertEqual(state.isLeased, false) + XCTAssertEqual(state.lease(streams: 1), .init(connection: connection, timersToCancel: .init(), wasIdle: true)) + + XCTAssertEqual(state.isIdleOrRunningKeepAlive, false) + XCTAssertEqual(state.isAvailable, false) + XCTAssertEqual(state.isConnected, true) + XCTAssertEqual(state.isLeased, true) + + XCTAssertEqual(state.release(streams: 1), .idle(availableStreams: 1, newIdle: true)) + let parkResult = state.parkConnection(scheduleKeepAliveTimer: true, scheduleIdleTimeoutTimer: true) + XCTAssert( + parkResult.elementsEqual([ + .init(timerID: 0, connectionID: connectionID, usecase: .keepAlive), + .init(timerID: 1, connectionID: connectionID, usecase: .idleTimeout) + ]) + ) + + guard let keepAliveTimer = parkResult.first, let idleTimer = parkResult.second else { + return XCTFail("Expected to get two timers") + } + + let keepAliveTimerCancellationToken = MockTimerCancellationToken(keepAliveTimer) + let idleTimerCancellationToken = MockTimerCancellationToken(idleTimer) + + XCTAssertNil(state.timerScheduled(keepAliveTimer, cancelContinuation: keepAliveTimerCancellationToken)) + XCTAssertNil(state.timerScheduled(idleTimer, cancelContinuation: idleTimerCancellationToken)) + + let expectLeaseAction = TestConnectionState.LeaseAction( + connection: connection, + timersToCancel: [idleTimerCancellationToken, keepAliveTimerCancellationToken], + wasIdle: true + ) + XCTAssertEqual(state.lease(streams: 1), expectLeaseAction) + } + + func testStartupParkLeaseBeforeTimersRegistered() { + let connectionID = 1 + var state = TestConnectionState(id: connectionID) + let connection = MockConnection(id: connectionID) + XCTAssertEqual(state.connected(connection, maxStreams: 1), .idle(availableStreams: 1, newIdle: true)) + let parkResult = state.parkConnection(scheduleKeepAliveTimer: true, scheduleIdleTimeoutTimer: true) + XCTAssertEqual( + parkResult, + [ + .init(timerID: 0, connectionID: connectionID, usecase: .keepAlive), + .init(timerID: 1, connectionID: connectionID, usecase: .idleTimeout) + ] + ) + + guard let keepAliveTimer = parkResult.first, let idleTimer = parkResult.second else { + return XCTFail("Expected to get two timers") + } + + let keepAliveTimerCancellationToken = MockTimerCancellationToken(keepAliveTimer) + let idleTimerCancellationToken = MockTimerCancellationToken(idleTimer) + XCTAssertEqual(state.lease(streams: 1), .init(connection: connection, timersToCancel: .init(), wasIdle: true)) + + XCTAssertEqual(state.timerScheduled(keepAliveTimer, cancelContinuation: keepAliveTimerCancellationToken), keepAliveTimerCancellationToken) + XCTAssertEqual(state.timerScheduled(idleTimer, cancelContinuation: idleTimerCancellationToken), idleTimerCancellationToken) + } + + func testStartupParkLeasePark() { + let connectionID = 1 + var state = TestConnectionState(id: connectionID) + let connection = MockConnection(id: connectionID) + XCTAssertEqual(state.connected(connection, maxStreams: 1), .idle(availableStreams: 1, newIdle: true)) + let parkResult = state.parkConnection(scheduleKeepAliveTimer: true, scheduleIdleTimeoutTimer: true) + XCTAssert( + parkResult.elementsEqual([ + .init(timerID: 0, connectionID: connectionID, usecase: .keepAlive), + .init(timerID: 1, connectionID: connectionID, usecase: .idleTimeout) + ]) + ) + + guard let keepAliveTimer = parkResult.first, let idleTimer = parkResult.second else { + return XCTFail("Expected to get two timers") + } + + let initialKeepAliveTimerCancellationToken = MockTimerCancellationToken(keepAliveTimer) + let initialIdleTimerCancellationToken = MockTimerCancellationToken(idleTimer) + XCTAssertEqual(state.lease(streams: 1), .init(connection: connection, timersToCancel: .init(), wasIdle: true)) + + XCTAssertEqual(state.release(streams: 1), .idle(availableStreams: 1, newIdle: true)) + XCTAssertEqual( + state.parkConnection(scheduleKeepAliveTimer: true, scheduleIdleTimeoutTimer: true), + [ + .init(timerID: 2, connectionID: connectionID, usecase: .keepAlive), + .init(timerID: 3, connectionID: connectionID, usecase: .idleTimeout) + ] + ) + + XCTAssertEqual(state.timerScheduled(keepAliveTimer, cancelContinuation: initialKeepAliveTimerCancellationToken), initialKeepAliveTimerCancellationToken) + XCTAssertEqual(state.timerScheduled(idleTimer, cancelContinuation: initialIdleTimerCancellationToken), initialIdleTimerCancellationToken) + } + + func testStartupFailed() { + let connectionID = 1 + var state = TestConnectionState(id: connectionID) + let firstBackoffTimer = state.failedToConnect() + let firstBackoffTimerCancellationToken = MockTimerCancellationToken(firstBackoffTimer) + XCTAssertNil(state.timerScheduled(firstBackoffTimer, cancelContinuation: firstBackoffTimerCancellationToken)) + XCTAssertEqual(state.retryConnect(), firstBackoffTimerCancellationToken) + + let secondBackoffTimer = state.failedToConnect() + let secondBackoffTimerCancellationToken = MockTimerCancellationToken(secondBackoffTimer) + XCTAssertNil(state.retryConnect()) + XCTAssertEqual( + state.timerScheduled(secondBackoffTimer, cancelContinuation: secondBackoffTimerCancellationToken), + secondBackoffTimerCancellationToken + ) + + let thirdBackoffTimer = state.failedToConnect() + let thirdBackoffTimerCancellationToken = MockTimerCancellationToken(thirdBackoffTimer) + XCTAssertNil(state.retryConnect()) + let forthBackoffTimer = state.failedToConnect() + let forthBackoffTimerCancellationToken = MockTimerCancellationToken(forthBackoffTimer) + XCTAssertEqual( + state.timerScheduled(thirdBackoffTimer, cancelContinuation: thirdBackoffTimerCancellationToken), + thirdBackoffTimerCancellationToken + ) + XCTAssertNil( + state.timerScheduled(forthBackoffTimer, cancelContinuation: forthBackoffTimerCancellationToken) + ) + XCTAssertEqual(state.retryConnect(), forthBackoffTimerCancellationToken) + + let connection = MockConnection(id: connectionID) + XCTAssertEqual(state.connected(connection, maxStreams: 1), .idle(availableStreams: 1, newIdle: true)) + } + + func testLeaseMultipleStreams() { + let connectionID = 1 + var state = TestConnectionState(id: connectionID) + let connection = MockConnection(id: connectionID) + XCTAssertEqual(state.connected(connection, maxStreams: 100), .idle(availableStreams: 100, newIdle: true)) + let timers = state.parkConnection(scheduleKeepAliveTimer: true, scheduleIdleTimeoutTimer: false) + guard let keepAliveTimer = timers.first else { return XCTFail("Expected to get a keepAliveTimer") } + + let keepAliveTimerCancellationToken = MockTimerCancellationToken(keepAliveTimer) + XCTAssertNil(state.timerScheduled(keepAliveTimer, cancelContinuation: keepAliveTimerCancellationToken)) + + XCTAssertEqual( + state.lease(streams: 30), + TestConnectionState.LeaseAction(connection: connection, timersToCancel: [keepAliveTimerCancellationToken], wasIdle: true) + ) + + XCTAssertEqual(state.release(streams: 10), .leased(availableStreams: 80)) + + XCTAssertEqual( + state.lease(streams: 40), + TestConnectionState.LeaseAction(connection: connection, timersToCancel: [], wasIdle: false) + ) + + XCTAssertEqual( + state.lease(streams: 40), + TestConnectionState.LeaseAction(connection: connection, timersToCancel: [], wasIdle: false) + ) + + XCTAssertEqual(state.release(streams: 1), .leased(availableStreams: 1)) + XCTAssertEqual(state.release(streams: 98), .leased(availableStreams: 99)) + XCTAssertEqual(state.release(streams: 1), .idle(availableStreams: 100, newIdle: true)) + } + + func testRunningKeepAliveReducesAvailableStreams() { + let connectionID = 1 + var state = TestConnectionState(id: connectionID) + let connection = MockConnection(id: connectionID) + XCTAssertEqual(state.connected(connection, maxStreams: 100), .idle(availableStreams: 100, newIdle: true)) + let timers = state.parkConnection(scheduleKeepAliveTimer: true, scheduleIdleTimeoutTimer: false) + guard let keepAliveTimer = timers.first else { return XCTFail("Expected to get a keepAliveTimer") } + + let keepAliveTimerCancellationToken = MockTimerCancellationToken(keepAliveTimer) + XCTAssertNil(state.timerScheduled(keepAliveTimer, cancelContinuation: keepAliveTimerCancellationToken)) + + XCTAssertEqual( + state.runKeepAliveIfIdle(reducesAvailableStreams: true), + .init(connection: connection, keepAliveTimerCancellationContinuation: keepAliveTimerCancellationToken) + ) + + XCTAssertEqual( + state.lease(streams: 30), + TestConnectionState.LeaseAction(connection: connection, timersToCancel: [], wasIdle: true) + ) + + XCTAssertEqual(state.release(streams: 10), .leased(availableStreams: 79)) + XCTAssertEqual(state.isAvailable, true) + XCTAssertEqual( + state.lease(streams: 79), + TestConnectionState.LeaseAction(connection: connection, timersToCancel: [], wasIdle: false) + ) + XCTAssertEqual(state.isAvailable, false) + XCTAssertEqual(state.keepAliveSucceeded(), .leased(availableStreams: 1)) + XCTAssertEqual(state.isAvailable, true) + } + + func testRunningKeepAliveDoesNotReduceAvailableStreams() { + let connectionID = 1 + var state = TestConnectionState(id: connectionID) + let connection = MockConnection(id: connectionID) + XCTAssertEqual(state.connected(connection, maxStreams: 100), .idle(availableStreams: 100, newIdle: true)) + let timers = state.parkConnection(scheduleKeepAliveTimer: true, scheduleIdleTimeoutTimer: false) + guard let keepAliveTimer = timers.first else { return XCTFail("Expected to get a keepAliveTimer") } + + let keepAliveTimerCancellationToken = MockTimerCancellationToken(keepAliveTimer) + XCTAssertNil(state.timerScheduled(keepAliveTimer, cancelContinuation: keepAliveTimerCancellationToken)) + + XCTAssertEqual( + state.runKeepAliveIfIdle(reducesAvailableStreams: false), + .init(connection: connection, keepAliveTimerCancellationContinuation: keepAliveTimerCancellationToken) + ) + + XCTAssertEqual( + state.lease(streams: 30), + TestConnectionState.LeaseAction(connection: connection, timersToCancel: [], wasIdle: true) + ) + + XCTAssertEqual(state.release(streams: 10), .leased(availableStreams: 80)) + XCTAssertEqual(state.keepAliveSucceeded(), .leased(availableStreams: 80)) + } + + func testRunKeepAliveRacesAgainstIdleClose() { + let connectionID = 1 + var state = TestConnectionState(id: connectionID) + let connection = MockConnection(id: connectionID) + XCTAssertEqual(state.connected(connection, maxStreams: 1), .idle(availableStreams: 1, newIdle: true)) + let parkResult = state.parkConnection(scheduleKeepAliveTimer: true, scheduleIdleTimeoutTimer: true) + guard let keepAliveTimer = parkResult.first, let idleTimer = parkResult.second else { + return XCTFail("Expected to get two timers") + } + + XCTAssertEqual(keepAliveTimer, .init(timerID: 0, connectionID: connectionID, usecase: .keepAlive)) + XCTAssertEqual(idleTimer, .init(timerID: 1, connectionID: connectionID, usecase: .idleTimeout)) + + let keepAliveTimerCancellationToken = MockTimerCancellationToken(keepAliveTimer) + let idleTimerCancellationToken = MockTimerCancellationToken(idleTimer) + + XCTAssertNil(state.timerScheduled(keepAliveTimer, cancelContinuation: keepAliveTimerCancellationToken)) + XCTAssertNil(state.timerScheduled(idleTimer, cancelContinuation: idleTimerCancellationToken)) + + XCTAssertEqual(state.closeIfIdle(), .init(connection: connection, cancelTimers: [keepAliveTimerCancellationToken, idleTimerCancellationToken], maxStreams: 1)) + XCTAssertEqual(state.runKeepAliveIfIdle(reducesAvailableStreams: true), .none) + + } +} From 17d3c80e7739c781254c1883bd9e8fd6c113b1c1 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Mon, 23 Oct 2023 11:20:32 +0200 Subject: [PATCH 041/106] Add `PoolStateMachine.ConnectionGroup` (#425) (#426) --- .../ConnectionPoolModule/Max2Sequence.swift | 3 +- .../PoolStateMachine+ConnectionGroup.swift | 640 ++++++++++++++++++ .../PoolStateMachine+ConnectionState.swift | 218 ++++-- .../PoolStateMachine.swift | 2 +- ...oolStateMachine+ConnectionGroupTests.swift | 294 ++++++++ ...oolStateMachine+ConnectionStateTests.swift | 8 +- 6 files changed, 1113 insertions(+), 52 deletions(-) create mode 100644 Sources/ConnectionPoolModule/PoolStateMachine+ConnectionGroup.swift create mode 100644 Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionGroupTests.swift diff --git a/Sources/ConnectionPoolModule/Max2Sequence.swift b/Sources/ConnectionPoolModule/Max2Sequence.swift index 0feccd68..9b7d972b 100644 --- a/Sources/ConnectionPoolModule/Max2Sequence.swift +++ b/Sources/ConnectionPoolModule/Max2Sequence.swift @@ -95,8 +95,7 @@ extension Max2Sequence: ExpressibleByArrayLiteral { init(arrayLiteral elements: Element...) { precondition(elements.count <= 2) var iterator = elements.makeIterator() - self.first = iterator.next() - self.second = iterator.next() + self.init(iterator.next(), iterator.next()) } } diff --git a/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionGroup.swift b/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionGroup.swift new file mode 100644 index 00000000..8ec99c7d --- /dev/null +++ b/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionGroup.swift @@ -0,0 +1,640 @@ +import Atomics + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +extension PoolStateMachine { + + @usableFromInline + struct LeaseResult { + @usableFromInline + var connection: Connection + @usableFromInline + var timersToCancel: Max2Sequence + @usableFromInline + var wasIdle: Bool + @usableFromInline + var use: ConnectionGroup.ConnectionUse + + @inlinable + init( + connection: Connection, + timersToCancel: Max2Sequence, + wasIdle: Bool, + use: ConnectionGroup.ConnectionUse + ) { + self.connection = connection + self.timersToCancel = timersToCancel + self.wasIdle = wasIdle + self.use = use + } + } + + @usableFromInline + struct ConnectionGroup: Sendable { + @usableFromInline + struct Stats: Hashable, Sendable { + @usableFromInline var connecting: UInt16 = 0 + @usableFromInline var backingOff: UInt16 = 0 + @usableFromInline var idle: UInt16 = 0 + @usableFromInline var leased: UInt16 = 0 + @usableFromInline var runningKeepAlive: UInt16 = 0 + @usableFromInline var closing: UInt16 = 0 + + @usableFromInline var availableStreams: UInt16 = 0 + @usableFromInline var leasedStreams: UInt16 = 0 + + @usableFromInline var soonAvailable: UInt16 { + self.connecting + self.backingOff + self.runningKeepAlive + } + + @usableFromInline var active: UInt16 { + self.idle + self.leased + self.connecting + self.backingOff + } + } + + /// The minimum number of connections + @usableFromInline + let minimumConcurrentConnections: Int + + /// The maximum number of preserved connections + @usableFromInline + let maximumConcurrentConnectionSoftLimit: Int + + /// The absolute maximum number of connections + @usableFromInline + let maximumConcurrentConnectionHardLimit: Int + + @usableFromInline + let keepAlive: Bool + + @usableFromInline + let keepAliveReducesAvailableStreams: Bool + + /// A connectionID generator. + @usableFromInline + let generator: ConnectionIDGenerator + + /// The connections states + @usableFromInline + private(set) var connections: [ConnectionState] + + @usableFromInline + private(set) var stats = Stats() + + @inlinable + init( + generator: ConnectionIDGenerator, + minimumConcurrentConnections: Int, + maximumConcurrentConnectionSoftLimit: Int, + maximumConcurrentConnectionHardLimit: Int, + keepAlive: Bool, + keepAliveReducesAvailableStreams: Bool + ) { + self.generator = generator + self.connections = [] + self.minimumConcurrentConnections = minimumConcurrentConnections + self.maximumConcurrentConnectionSoftLimit = maximumConcurrentConnectionSoftLimit + self.maximumConcurrentConnectionHardLimit = maximumConcurrentConnectionHardLimit + self.keepAlive = keepAlive + self.keepAliveReducesAvailableStreams = keepAliveReducesAvailableStreams + } + + var isEmpty: Bool { + self.connections.isEmpty + } + + @usableFromInline + var canGrow: Bool { + self.stats.active < self.maximumConcurrentConnectionHardLimit + } + + @usableFromInline + var soonAvailableConnections: UInt16 { + self.stats.soonAvailable + } + + // MARK: - Mutations - + + /// A connection's use. Is it persisted or an overflow connection? + @usableFromInline + enum ConnectionUse: Equatable { + case persisted + case demand + case overflow + } + + /// Information around an idle connection. + @usableFromInline + struct AvailableConnectionContext { + /// The connection's use. Either general purpose or for requests with `EventLoop` + /// requirements. + @usableFromInline + var use: ConnectionUse + + @usableFromInline + var info: ConnectionAvailableInfo + } + + /// Information around the failed/closed connection. + @usableFromInline + struct FailedConnectionContext { + /// Connections that are currently starting + @usableFromInline + var connectionsStarting: Int + + @inlinable + init(connectionsStarting: Int) { + self.connectionsStarting = connectionsStarting + } + } + + mutating func refillConnections() -> [ConnectionRequest] { + let existingConnections = self.stats.active + let missingConnection = self.minimumConcurrentConnections - Int(existingConnections) + guard missingConnection > 0 else { + return [] + } + + var requests = [ConnectionRequest]() + requests.reserveCapacity(missingConnection) + + for _ in 0.. ConnectionRequest? { + precondition(self.minimumConcurrentConnections <= self.stats.active) + guard self.maximumConcurrentConnectionSoftLimit > self.stats.active else { + return nil + } + return self.createNewConnection() + } + + @inlinable + mutating func createNewOverflowConnectionIfPossible() -> ConnectionRequest? { + precondition(self.maximumConcurrentConnectionSoftLimit <= self.stats.active) + guard self.maximumConcurrentConnectionHardLimit > self.stats.active else { + return nil + } + return self.createNewConnection() + } + + @inlinable + /*private*/ mutating func createNewConnection() -> ConnectionRequest { + precondition(self.canGrow) + self.stats.connecting += 1 + let connectionID = self.generator.next() + let connection = ConnectionState(id: connectionID) + self.connections.append(connection) + return ConnectionRequest(connectionID: connectionID) + } + + /// A new ``Connection`` was established. + /// + /// This will put the connection into the idle state. + /// + /// - Parameter connection: The new established connection. + /// - Returns: An index and an IdleConnectionContext to determine the next action for the now idle connection. + /// Call ``parkConnection(at:)``, ``leaseConnection(at:)`` or ``closeConnection(at:)`` + /// with the supplied index after this. + @inlinable + mutating func newConnectionEstablished(_ connection: Connection, maxStreams: UInt16) -> (Int, AvailableConnectionContext) { + guard let index = self.connections.firstIndex(where: { $0.id == connection.id }) else { + preconditionFailure("There is a new connection that we didn't request!") + } + self.stats.connecting -= 1 + self.stats.idle += 1 + self.stats.availableStreams += maxStreams + let connectionInfo = self.connections[index].connected(connection, maxStreams: maxStreams) + // TODO: If this is an overflow connection, but we are currently also creating a + // persisted connection, we might want to swap those. + let context = self.makeAvailableConnectionContextForConnection(at: index, info: connectionInfo) + return (index, context) + } + + @inlinable + mutating func backoffNextConnectionAttempt(_ connectionID: Connection.ID) -> ConnectionTimer { + guard let index = self.connections.firstIndex(where: { $0.id == connectionID }) else { + preconditionFailure("We tried to create a new connection that we know nothing about?") + } + + self.stats.connecting -= 1 + self.stats.backingOff += 1 + + return self.connections[index].failedToConnect() + } + + @usableFromInline + enum BackoffDoneAction { + case createConnection(ConnectionRequest, TimerCancellationToken?) + case cancelTimers(Max2Sequence) + } + + @inlinable + mutating func backoffDone(_ connectionID: Connection.ID, retry: Bool) -> BackoffDoneAction { + guard let index = self.connections.firstIndex(where: { $0.id == connectionID }) else { + preconditionFailure("We tried to create a new connection that we know nothing about?") + } + + self.stats.backingOff -= 1 + + if retry || self.stats.active < self.minimumConcurrentConnections { + self.stats.connecting += 1 + let backoffTimerCancellation = self.connections[index].retryConnect() + return .createConnection(.init(connectionID: connectionID), backoffTimerCancellation) + } + + let backoffTimerCancellation = self.connections[index].destroyBackingOffConnection() + var timerCancellations = Max2Sequence(backoffTimerCancellation) + + if let timerCancellationToken = self.swapForDeletion(index: index) { + timerCancellations.append(timerCancellationToken) + } + return .cancelTimers(timerCancellations) + } + + @inlinable + mutating func timerScheduled( + _ timer: ConnectionTimer, + cancelContinuation: TimerCancellationToken + ) -> TimerCancellationToken? { + guard let index = self.connections.firstIndex(where: { $0.id == timer.connectionID }) else { + return cancelContinuation + } + + return self.connections[index].timerScheduled(timer, cancelContinuation: cancelContinuation) + } + + // MARK: Leasing and releasing + + /// Lease a connection, if an idle connection is available. + /// + /// - Returns: A connection to execute a request on. + @inlinable + mutating func leaseConnection() -> LeaseResult? { + if self.stats.availableStreams == 0 { + return nil + } + + guard let index = self.findAvailableConnection() else { + preconditionFailure("Stats and actual count are of.") + } + + return self.leaseConnection(at: index, streams: 1) + } + + @usableFromInline + enum LeasedConnectionOrStartingCount { + case leasedConnection(LeaseResult) + case startingCount(UInt16) + } + + @inlinable + mutating func leaseConnectionOrSoonAvailableConnectionCount() -> LeasedConnectionOrStartingCount { + if let result = self.leaseConnection() { + return .leasedConnection(result) + } + return .startingCount(self.stats.soonAvailable) + } + + @inlinable + mutating func leaseConnection(at index: Int, streams: UInt16) -> LeaseResult { + let leaseResult = self.connections[index].lease(streams: streams) + let use = self.getConnectionUse(index: index) + + if leaseResult.wasIdle { + self.stats.idle -= 1 + self.stats.leased += 1 + } + self.stats.leasedStreams += streams + self.stats.availableStreams -= streams + return LeaseResult( + connection: leaseResult.connection, + timersToCancel: leaseResult.timersToCancel, + wasIdle: leaseResult.wasIdle, + use: use + ) + } + + @inlinable + mutating func parkConnection(at index: Int) -> Max2Sequence { + let scheduleIdleTimeoutTimer: Bool + switch index { + case 0.. (Int, AvailableConnectionContext) { + guard let index = self.connections.firstIndex(where: { $0.id == connectionID }) else { + preconditionFailure("A connection that we don't know was released? Something is very wrong...") + } + + let connectionInfo = self.connections[index].release(streams: streams) + self.stats.availableStreams += streams + self.stats.leasedStreams -= streams + switch connectionInfo { + case .idle: + self.stats.idle += 1 + self.stats.leased -= 1 + case .leased: + break + } + + let context = self.makeAvailableConnectionContextForConnection(at: index, info: connectionInfo) + return (index, context) + } + + @inlinable + mutating func keepAliveIfIdle(_ connectionID: Connection.ID) -> KeepAliveAction? { + guard let index = self.connections.firstIndex(where: { $0.id == connectionID }) else { + // because of a race this connection (connection close runs against trigger of ping pong) + // was already removed from the state machine. + return nil + } + + guard let action = self.connections[index].runKeepAliveIfIdle(reducesAvailableStreams: self.keepAliveReducesAvailableStreams) else { + return nil + } + + self.stats.runningKeepAlive += 1 + if self.keepAliveReducesAvailableStreams { + self.stats.availableStreams -= 1 + } + + return action + } + + @inlinable + mutating func keepAliveSucceeded(_ connectionID: Connection.ID) -> (Int, AvailableConnectionContext)? { + guard let index = self.connections.firstIndex(where: { $0.id == connectionID }) else { + preconditionFailure("A connection that we don't know was released? Something is very wrong...") + } + + guard let connectionInfo = self.connections[index].keepAliveSucceeded() else { + // if we don't get connection info here this means, that the connection already was + // transitioned to closing. when we did this we already decremented the + // runningKeepAlive timer. + return nil + } + + self.stats.runningKeepAlive -= 1 + if self.keepAliveReducesAvailableStreams { + self.stats.availableStreams += 1 + } + + let context = self.makeAvailableConnectionContextForConnection(at: index, info: connectionInfo) + return (index, context) + } + + // MARK: Connection close/removal + + @usableFromInline + struct CloseAction { + @usableFromInline + private(set) var connection: Connection + + @usableFromInline + private(set) var timersToCancel: Max2Sequence + + @inlinable + init(connection: Connection, timersToCancel: Max2Sequence) { + self.connection = connection + self.timersToCancel = timersToCancel + } + } + + /// Closes the connection at the given index. + @inlinable + mutating func closeConnectionIfIdle(at index: Int) -> CloseAction { + guard let closeAction = self.connections[index].closeIfIdle() else { + preconditionFailure("Invalid state: \(self)") + } + + self.stats.idle -= 1 + self.stats.closing += 1 + +// if idleState.runningKeepAlive { +// self.stats.runningKeepAlive -= 1 +// if self.keepAliveReducesAvailableStreams { +// self.stats.availableStreams += 1 +// } +// } + + self.stats.availableStreams -= closeAction.maxStreams + + return CloseAction( + connection: closeAction.connection!, + timersToCancel: closeAction.cancelTimers + ) + } + + @inlinable + mutating func closeConnectionIfIdle(_ connectionID: Connection.ID) -> CloseAction? { + guard let index = self.connections.firstIndex(where: { $0.id == connectionID }) else { + // because of a race this connection (connection close runs against trigger of timeout) + // was already removed from the state machine. + return nil + } + + if index < self.minimumConcurrentConnections { + // because of a race a connection might receive a idle timeout after it was moved into + // the persisted connections. If a connection is now persisted, we now need to ignore + // the trigger + return nil + } + + return self.closeConnectionIfIdle(at: index) + } + + /// Connection closed. Call this method, if a connection is closed. + /// + /// This will put the position into the closed state. + /// + /// - Parameter connectionID: The failed connection's id. + /// - Returns: An optional index and an IdleConnectionContext to determine the next action for the closed connection. + /// You must call ``removeConnection(at:)`` or ``replaceConnection(at:)`` with the + /// supplied index after this. If nil is returned the connection was closed by the state machine and was + /// therefore already removed. + @inlinable + mutating func connectionClosed(_ connectionID: Connection.ID) -> FailedConnectionContext? { + guard let index = self.connections.firstIndex(where: { $0.id == connectionID }) else { + return nil + } + + let closedAction = self.connections[index].closed() + + if closedAction.wasRunningKeepAlive { + self.stats.runningKeepAlive -= 1 + } + self.stats.leasedStreams -= closedAction.usedStreams + self.stats.availableStreams -= closedAction.maxStreams - closedAction.usedStreams + + switch closedAction.previousConnectionState { + case .idle: + self.stats.idle -= 1 + + case .leased: + self.stats.leased -= 1 + + case .closing: + self.stats.closing -= 1 + } + + let lastIndex = self.connections.index(before: self.connections.endIndex) + + if index == lastIndex { + self.connections.remove(at: index) + } else { + self.connections.swapAt(index, lastIndex) + self.connections.remove(at: lastIndex) + } + + return FailedConnectionContext(connectionsStarting: 0) + } + + // MARK: Shutdown + + mutating func triggerForceShutdown(_ cleanup: inout ConnectionAction.Shutdown) { + for var connectionState in self.connections { + guard let closeAction = connectionState.close() else { + continue + } + + if let connection = closeAction.connection { + cleanup.connections.append(connection) + } + cleanup.timersToCancel.append(contentsOf: closeAction.cancelTimers) + } + + self.connections = [] + } + + // MARK: - Private functions - + + @usableFromInline + /*private*/ func getConnectionUse(index: Int) -> ConnectionUse { + switch index { + case 0.. AvailableConnectionContext { + precondition(self.connections[index].isAvailable) + let use = self.getConnectionUse(index: index) + return AvailableConnectionContext(use: use, info: info) + } + + @inlinable + /*private*/ func findAvailableConnection() -> Int? { + return self.connections.firstIndex(where: { $0.isAvailable }) + } + + @inlinable + /*private*/ mutating func swapForDeletion(index indexToDelete: Int) -> TimerCancellationToken? { + let maybeLastConnectedIndex = self.connections.lastIndex(where: { $0.isConnected }) + + if maybeLastConnectedIndex == nil || maybeLastConnectedIndex! < indexToDelete { + self.removeO1(indexToDelete) + return nil + } + + // if maybeLastConnectedIndex == nil, we return early in the above if case. + let lastConnectedIndex = maybeLastConnectedIndex! + + switch indexToDelete { + case 0.. State.Timer { - defer { self.nextTimerID += 1 } - return State.Timer(id: self.nextTimerID) - } - /// The connection failed to start @inlinable mutating func failedToConnect() -> ConnectionTimer { switch self.state { case .starting: - let backoffTimerState = self.nextTimer() + let backoffTimerState = self._nextTimer() self.state = .backingOff(backoffTimerState) return ConnectionTimer(timerID: backoffTimerState.timerID, connectionID: self.id, usecase: .backoff) @@ -311,6 +295,17 @@ extension PoolStateMachine { } } + @inlinable + mutating func destroyBackingOffConnection() -> TimerCancellationToken? { + switch self.state { + case .backingOff(let timer): + self.state = .closed + return timer.cancellationContinuation + case .starting, .idle, .leased, .closing, .closed: + preconditionFailure("Invalid state: \(self.state)") + } + } + @usableFromInline struct LeaseAction { @usableFromInline @@ -468,78 +463,211 @@ extension PoolStateMachine { } } + @inlinable + mutating func cancelIdleTimer() -> TimerCancellationToken? { + switch self.state { + case .starting, .backingOff, .leased, .closing, .closed: + return nil + + case .idle(let connection, let maxStreams, let keepAlive, let idleTimer): + self.state = .idle(connection, maxStreams: maxStreams, keepAlive: keepAlive, idleTimer: nil) + return idleTimer?.cancellationContinuation + } + } + @usableFromInline struct CloseAction { + @usableFromInline - var connection: Connection + enum PreviousConnectionState { + case idle + case leased + case closing + case backingOff + } + + @usableFromInline + var connection: Connection? + @usableFromInline + var previousConnectionState: PreviousConnectionState @usableFromInline var cancelTimers: Max2Sequence @usableFromInline + var usedStreams: UInt16 + @usableFromInline var maxStreams: UInt16 @inlinable - init(connection: Connection, cancelTimers: Max2Sequence, maxStreams: UInt16) { + init( + connection: Connection?, + previousConnectionState: PreviousConnectionState, + cancelTimers: Max2Sequence, + usedStreams: UInt16, + maxStreams: UInt16 + ) { self.connection = connection + self.previousConnectionState = previousConnectionState self.cancelTimers = cancelTimers + self.usedStreams = usedStreams self.maxStreams = maxStreams } } @inlinable - mutating func close() -> CloseAction { + mutating func closeIfIdle() -> CloseAction? { switch self.state { case .idle(let connection, let maxStreams, var keepAlive, let idleTimerState): self.state = .closing(connection) return CloseAction( connection: connection, + previousConnectionState: .idle, cancelTimers: Max2Sequence( keepAlive.cancelTimerIfScheduled(), idleTimerState?.cancellationContinuation ), + usedStreams: keepAlive.usedStreams, maxStreams: maxStreams ) - case .backingOff, .starting, .leased, .closing, .closed: + case .leased, .closed: + return nil + + case .backingOff, .starting, .closing: preconditionFailure("Invalid state: \(self.state)") } } @inlinable - mutating func closeIfIdle() -> CloseAction? { + mutating func close() -> CloseAction? { switch self.state { - case .idle: - return self.close() - case .leased, .closed: + case .starting: + // If we are currently starting, there is nothing we can do about it right now. + // Only once the connection has come up, or failed, we can actually act. return nil - case .backingOff, .starting, .closing: - preconditionFailure("Invalid state: \(self.state)") + + case .closing, .closed: + // If we are already closing, we can't do anything else. + return nil + + case .idle(let connection, let maxStreams, var keepAlive, let idleTimerState): + self.state = .closing(connection) + return CloseAction( + connection: connection, + previousConnectionState: .idle, + cancelTimers: Max2Sequence( + keepAlive.cancelTimerIfScheduled(), + idleTimerState?.cancellationContinuation + ), + usedStreams: keepAlive.usedStreams, + maxStreams: maxStreams + ) + + case .leased(let connection, usedStreams: let usedStreams, maxStreams: let maxStreams, var keepAlive): + self.state = .closing(connection) + return CloseAction( + connection: connection, + previousConnectionState: .leased, + cancelTimers: Max2Sequence( + keepAlive.cancelTimerIfScheduled() + ), + usedStreams: keepAlive.usedStreams + usedStreams, + maxStreams: maxStreams + ) + + case .backingOff(let timer): + self.state = .closed + return CloseAction( + connection: nil, + previousConnectionState: .backingOff, + cancelTimers: Max2Sequence(timer.cancellationContinuation), + usedStreams: 0, + maxStreams: 0 + ) } } @usableFromInline - struct ShutdownAction { + struct ClosedAction { + @usableFromInline - var connection: Connection? + enum PreviousConnectionState { + case idle + case leased + case closing + } + @usableFromInline - var timersToCancel: Max2Sequence + var previousConnectionState: PreviousConnectionState + @usableFromInline + var cancelTimers: Max2Sequence @usableFromInline var maxStreams: UInt16 @usableFromInline var usedStreams: UInt16 + @usableFromInline + var wasRunningKeepAlive: Bool @inlinable init( - connection: Connection? = nil, - timersToCancel: Max2Sequence = .init(), - maxStreams: UInt16 = 0, - usedStreams: UInt16 = 0 + previousConnectionState: PreviousConnectionState, + cancelTimers: Max2Sequence, + maxStreams: UInt16, + usedStreams: UInt16, + wasRunningKeepAlive: Bool ) { - self.connection = connection - self.timersToCancel = timersToCancel + self.previousConnectionState = previousConnectionState + self.cancelTimers = cancelTimers self.maxStreams = maxStreams self.usedStreams = usedStreams + self.wasRunningKeepAlive = wasRunningKeepAlive + } + } + + @inlinable + mutating func closed() -> ClosedAction { + switch self.state { + case .starting, .backingOff, .closed: + preconditionFailure("Invalid state: \(self.state)") + + case .idle(_, let maxStreams, var keepAlive, let idleTimer): + self.state = .closed + return ClosedAction( + previousConnectionState: .idle, + cancelTimers: .init(keepAlive.cancelTimerIfScheduled(), idleTimer?.cancellationContinuation), + maxStreams: maxStreams, + usedStreams: keepAlive.usedStreams, + wasRunningKeepAlive: keepAlive.isRunning + ) + + case .leased(_, let usedStreams, let maxStreams, let keepAlive): + self.state = .closed + return ClosedAction( + previousConnectionState: .leased, + cancelTimers: .init(), + maxStreams: maxStreams, + usedStreams: usedStreams + keepAlive.usedStreams, + wasRunningKeepAlive: keepAlive.isRunning + ) + + case .closing: + self.state = .closed + return ClosedAction( + previousConnectionState: .closing, + cancelTimers: .init(), + maxStreams: 0, + usedStreams: 0, + wasRunningKeepAlive: false + ) } } + + // MARK: - Private Methods - + + @inlinable + mutating /*private*/ func _nextTimer() -> State.Timer { + defer { self.nextTimerID += 1 } + return State.Timer(id: self.nextTimerID) + } } @usableFromInline diff --git a/Sources/ConnectionPoolModule/PoolStateMachine.swift b/Sources/ConnectionPoolModule/PoolStateMachine.swift index dc18784f..29349e56 100644 --- a/Sources/ConnectionPoolModule/PoolStateMachine.swift +++ b/Sources/ConnectionPoolModule/PoolStateMachine.swift @@ -37,7 +37,7 @@ struct PoolStateMachine< ConnectionID: Hashable & Sendable, Request: ConnectionRequestProtocol, RequestID, - TimerCancellationToken + TimerCancellationToken: Sendable > where Connection.ID == ConnectionID, ConnectionIDGenerator.ID == ConnectionID, RequestID == Request.ID { @usableFromInline diff --git a/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionGroupTests.swift b/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionGroupTests.swift new file mode 100644 index 00000000..4e3a1647 --- /dev/null +++ b/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionGroupTests.swift @@ -0,0 +1,294 @@ +import XCTest +@testable import _ConnectionPoolModule + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +final class PoolStateMachine_ConnectionGroupTests: XCTestCase { + var idGenerator: ConnectionIDGenerator! + + override func setUp() { + self.idGenerator = ConnectionIDGenerator() + super.setUp() + } + + override func tearDown() { + self.idGenerator = nil + super.tearDown() + } + + func testRefillConnections() { + var connections = TestPoolStateMachine.ConnectionGroup( + generator: self.idGenerator, + minimumConcurrentConnections: 4, + maximumConcurrentConnectionSoftLimit: 4, + maximumConcurrentConnectionHardLimit: 4, + keepAlive: true, + keepAliveReducesAvailableStreams: true + ) + + XCTAssertTrue(connections.isEmpty) + let requests = connections.refillConnections() + XCTAssertFalse(connections.isEmpty) + + XCTAssertEqual(requests.count, 4) + XCTAssertNil(connections.createNewDemandConnectionIfPossible()) + XCTAssertNil(connections.createNewOverflowConnectionIfPossible()) + XCTAssertEqual(connections.stats, .init(connecting: 4)) + XCTAssertEqual(connections.soonAvailableConnections, 4) + + let requests2 = connections.refillConnections() + XCTAssertTrue(requests2.isEmpty) + + var connected: UInt16 = 0 + for request in requests { + let newConnection = MockConnection(id: request.connectionID) + let (_, context) = connections.newConnectionEstablished(newConnection, maxStreams: 1) + XCTAssertEqual(context.info, .idle(availableStreams: 1, newIdle: true)) + XCTAssertEqual(context.use, .persisted) + connected += 1 + XCTAssertEqual(connections.stats, .init(connecting: 4 - connected, idle: connected, availableStreams: connected)) + XCTAssertEqual(connections.soonAvailableConnections, 4 - connected) + } + + let requests3 = connections.refillConnections() + XCTAssertTrue(requests3.isEmpty) + } + + func testMakeConnectionLeaseItAndDropItHappyPath() { + var connections = TestPoolStateMachine.ConnectionGroup( + generator: self.idGenerator, + minimumConcurrentConnections: 0, + maximumConcurrentConnectionSoftLimit: 4, + maximumConcurrentConnectionHardLimit: 4, + keepAlive: true, + keepAliveReducesAvailableStreams: true + ) + + let requests = connections.refillConnections() + XCTAssertTrue(connections.isEmpty) + XCTAssertTrue(requests.isEmpty) + + guard let request = connections.createNewDemandConnectionIfPossible() else { + return XCTFail("Expected to receive a connection request") + } + XCTAssertEqual(request, .init(connectionID: 0)) + XCTAssertFalse(connections.isEmpty) + XCTAssertEqual(connections.soonAvailableConnections, 1) + XCTAssertEqual(connections.stats, .init(connecting: 1)) + + let newConnection = MockConnection(id: request.connectionID) + let (_, establishedContext) = connections.newConnectionEstablished(newConnection, maxStreams: 1) + XCTAssertEqual(establishedContext.info, .idle(availableStreams: 1, newIdle: true)) + XCTAssertEqual(establishedContext.use, .demand) + XCTAssertEqual(connections.stats, .init(idle: 1, availableStreams: 1)) + XCTAssertEqual(connections.soonAvailableConnections, 0) + + guard case .leasedConnection(let leaseResult) = connections.leaseConnectionOrSoonAvailableConnectionCount() else { + return XCTFail("Expected to lease a connection") + } + XCTAssert(newConnection === leaseResult.connection) + XCTAssertEqual(connections.stats, .init(leased: 1, leasedStreams: 1)) + + let (index, releasedContext) = connections.releaseConnection(leaseResult.connection.id, streams: 1) + XCTAssertEqual(releasedContext.info, .idle(availableStreams: 1, newIdle: true)) + XCTAssertEqual(releasedContext.use, .demand) + XCTAssertEqual(connections.stats, .init(idle: 1, availableStreams: 1)) + + let parkTimers = connections.parkConnection(at: index) + XCTAssertEqual(parkTimers, [ + .init(timerID: 0, connectionID: newConnection.id, usecase: .keepAlive), + .init(timerID: 1, connectionID: newConnection.id, usecase: .idleTimeout), + ]) + + guard let keepAliveAction = connections.keepAliveIfIdle(newConnection.id) else { + return XCTFail("Expected to get a connection for ping pong") + } + XCTAssert(newConnection === keepAliveAction.connection) + XCTAssertEqual(connections.stats, .init(idle: 1, runningKeepAlive: 1, availableStreams: 0)) + + guard let (_, pingPongContext) = connections.keepAliveSucceeded(newConnection.id) else { + return XCTFail("Expected to get an AvailableContext") + } + XCTAssertEqual(pingPongContext.info, .idle(availableStreams: 1, newIdle: false)) + XCTAssertEqual(releasedContext.use, .demand) + XCTAssertEqual(connections.stats, .init(idle: 1, availableStreams: 1)) + + guard let closeAction = connections.closeConnectionIfIdle(newConnection.id) else { + return XCTFail("Expected to get a connection for ping pong") + } + XCTAssertEqual(closeAction.timersToCancel, []) + XCTAssert(closeAction.connection === newConnection) + XCTAssertEqual(connections.stats, .init(closing: 1, availableStreams: 0)) + + let closeContext = connections.connectionClosed(newConnection.id) + XCTAssertEqual(closeContext?.connectionsStarting, 0) + XCTAssertTrue(connections.isEmpty) + XCTAssertEqual(connections.stats, .init()) + } + + func testBackoffDoneCreatesANewConnectionToReachMinimumConnectionsEvenThoughRetryIsSetToFalse() { + var connections = TestPoolStateMachine.ConnectionGroup( + generator: self.idGenerator, + minimumConcurrentConnections: 1, + maximumConcurrentConnectionSoftLimit: 4, + maximumConcurrentConnectionHardLimit: 4, + keepAlive: true, + keepAliveReducesAvailableStreams: true + ) + + let requests = connections.refillConnections() + XCTAssertEqual(connections.stats, .init(connecting: 1)) + XCTAssertEqual(connections.soonAvailableConnections, 1) + XCTAssertFalse(connections.isEmpty) + XCTAssertEqual(requests.count, 1) + + guard let request = requests.first else { return XCTFail("Expected to receive a connection request") } + XCTAssertEqual(request, .init(connectionID: 0)) + + let backoffTimer = connections.backoffNextConnectionAttempt(request.connectionID) + XCTAssertEqual(connections.stats, .init(backingOff: 1)) + let backoffTimerCancellationToken = MockTimerCancellationToken(backoffTimer) + XCTAssertNil(connections.timerScheduled(backoffTimer, cancelContinuation: backoffTimerCancellationToken)) + + let backoffDoneAction = connections.backoffDone(request.connectionID, retry: false) + XCTAssertEqual(backoffDoneAction, .createConnection(.init(connectionID: 0), backoffTimerCancellationToken)) + + XCTAssertEqual(connections.stats, .init(connecting: 1)) + } + + func testBackoffDoneCancelsIdleTimerIfAPersistedConnectionIsNotRetried() { + var connections = TestPoolStateMachine.ConnectionGroup( + generator: self.idGenerator, + minimumConcurrentConnections: 2, + maximumConcurrentConnectionSoftLimit: 4, + maximumConcurrentConnectionHardLimit: 4, + keepAlive: true, + keepAliveReducesAvailableStreams: true + ) + + let requests = connections.refillConnections() + XCTAssertEqual(connections.stats, .init(connecting: 2)) + XCTAssertEqual(connections.soonAvailableConnections, 2) + XCTAssertFalse(connections.isEmpty) + XCTAssertEqual(requests.count, 2) + + var requestIterator = requests.makeIterator() + guard let firstRequest = requestIterator.next(), let secondRequest = requestIterator.next() else { + return XCTFail("Expected to get two requests") + } + + guard let thirdRequest = connections.createNewDemandConnectionIfPossible() else { + return XCTFail("Expected to get another request") + } + XCTAssertEqual(connections.stats, .init(connecting: 3)) + + let newSecondConnection = MockConnection(id: secondRequest.connectionID) + let (_, establishedSecondConnectionContext) = connections.newConnectionEstablished(newSecondConnection, maxStreams: 1) + XCTAssertEqual(establishedSecondConnectionContext.info, .idle(availableStreams: 1, newIdle: true)) + XCTAssertEqual(establishedSecondConnectionContext.use, .persisted) + XCTAssertEqual(connections.stats, .init(connecting: 2, idle: 1, availableStreams: 1)) + XCTAssertEqual(connections.soonAvailableConnections, 2) + + let newThirdConnection = MockConnection(id: thirdRequest.connectionID) + let (thirdConnectionIndex, establishedThirdConnectionContext) = connections.newConnectionEstablished(newThirdConnection, maxStreams: 1) + XCTAssertEqual(establishedThirdConnectionContext.info, .idle(availableStreams: 1, newIdle: true)) + XCTAssertEqual(establishedThirdConnectionContext.use, .demand) + XCTAssertEqual(connections.stats, .init(connecting: 1, idle: 2, availableStreams: 2)) + XCTAssertEqual(connections.soonAvailableConnections, 1) + let thirdConnKeepTimer = TestPoolStateMachine.ConnectionTimer(timerID: 0, connectionID: thirdRequest.connectionID, usecase: .keepAlive) + let thirdConnIdleTimer = TestPoolStateMachine.ConnectionTimer(timerID: 1, connectionID: thirdRequest.connectionID, usecase: .idleTimeout) + let thirdConnIdleTimerCancellationToken = MockTimerCancellationToken(thirdConnIdleTimer) + XCTAssertEqual(connections.parkConnection(at: thirdConnectionIndex), [thirdConnKeepTimer, thirdConnIdleTimer]) + + XCTAssertNil(connections.timerScheduled(thirdConnKeepTimer, cancelContinuation: .init(thirdConnKeepTimer))) + XCTAssertNil(connections.timerScheduled(thirdConnIdleTimer, cancelContinuation: thirdConnIdleTimerCancellationToken)) + + let backoffTimer = connections.backoffNextConnectionAttempt(firstRequest.connectionID) + XCTAssertEqual(connections.stats, .init(backingOff: 1, idle: 2, availableStreams: 2)) + let backoffTimerCancellationToken = MockTimerCancellationToken(backoffTimer) + XCTAssertNil(connections.timerScheduled(backoffTimer, cancelContinuation: backoffTimerCancellationToken)) + XCTAssertEqual(connections.stats, .init(backingOff: 1, idle: 2, availableStreams: 2)) + + // connection three should be moved to connection one and for this reason become permanent + + XCTAssertEqual(connections.backoffDone(firstRequest.connectionID, retry: false), .cancelTimers([backoffTimerCancellationToken, thirdConnIdleTimerCancellationToken])) + XCTAssertEqual(connections.stats, .init(idle: 2, availableStreams: 2)) + + XCTAssertNil(connections.closeConnectionIfIdle(newThirdConnection.id)) + } + + func testBackoffDoneReturnsNilIfOverflowConnection() { + var connections = TestPoolStateMachine.ConnectionGroup( + generator: self.idGenerator, + minimumConcurrentConnections: 0, + maximumConcurrentConnectionSoftLimit: 4, + maximumConcurrentConnectionHardLimit: 4, + keepAlive: true, + keepAliveReducesAvailableStreams: true + ) + + guard let firstRequest = connections.createNewDemandConnectionIfPossible() else { + return XCTFail("Expected to get two requests") + } + + guard let secondRequest = connections.createNewDemandConnectionIfPossible() else { + return XCTFail("Expected to get another request") + } + XCTAssertEqual(connections.stats, .init(connecting: 2)) + + let newFirstConnection = MockConnection(id: firstRequest.connectionID) + let (_, establishedFirstConnectionContext) = connections.newConnectionEstablished(newFirstConnection, maxStreams: 1) + XCTAssertEqual(establishedFirstConnectionContext.info, .idle(availableStreams: 1, newIdle: true)) + XCTAssertEqual(establishedFirstConnectionContext.use, .demand) + XCTAssertEqual(connections.stats, .init(connecting: 1, idle: 1, availableStreams: 1)) + XCTAssertEqual(connections.soonAvailableConnections, 1) + + let backoffTimer = connections.backoffNextConnectionAttempt(secondRequest.connectionID) + let backoffTimerCancellationToken = MockTimerCancellationToken(backoffTimer) + XCTAssertEqual(connections.stats, .init(backingOff: 1, idle: 1, availableStreams: 1)) + XCTAssertNil(connections.timerScheduled(backoffTimer, cancelContinuation: backoffTimerCancellationToken)) + + XCTAssertEqual(connections.backoffDone(secondRequest.connectionID, retry: false), .cancelTimers([backoffTimerCancellationToken])) + XCTAssertEqual(connections.stats, .init(idle: 1, availableStreams: 1)) + + XCTAssertNotNil(connections.closeConnectionIfIdle(newFirstConnection.id)) + } + + func testPingPong() { + var connections = TestPoolStateMachine.ConnectionGroup( + generator: self.idGenerator, + minimumConcurrentConnections: 1, + maximumConcurrentConnectionSoftLimit: 4, + maximumConcurrentConnectionHardLimit: 4, + keepAlive: true, + keepAliveReducesAvailableStreams: true + ) + + let requests = connections.refillConnections() + XCTAssertFalse(connections.isEmpty) + XCTAssertEqual(connections.stats, .init(connecting: 1)) + + XCTAssertEqual(requests.count, 1) + guard let firstRequest = requests.first else { return XCTFail("Expected to have a request here") } + + let newConnection = MockConnection(id: firstRequest.connectionID) + let (connectionIndex, establishedConnectionContext) = connections.newConnectionEstablished(newConnection, maxStreams: 1) + XCTAssertEqual(establishedConnectionContext.info, .idle(availableStreams: 1, newIdle: true)) + XCTAssertEqual(establishedConnectionContext.use, .persisted) + XCTAssertEqual(connections.stats, .init(idle: 1, availableStreams: 1)) + let timers = connections.parkConnection(at: connectionIndex) + let keepAliveTimer = TestPoolStateMachine.ConnectionTimer(timerID: 0, connectionID: firstRequest.connectionID, usecase: .keepAlive) + let keepAliveTimerCancellationToken = MockTimerCancellationToken(keepAliveTimer) + XCTAssertEqual(timers, [keepAliveTimer]) + XCTAssertNil(connections.timerScheduled(keepAliveTimer, cancelContinuation: keepAliveTimerCancellationToken)) + let keepAliveAction = connections.keepAliveIfIdle(newConnection.id) + XCTAssertEqual(keepAliveAction, .init(connection: newConnection, keepAliveTimerCancellationContinuation: keepAliveTimerCancellationToken)) + XCTAssertEqual(connections.stats, .init(idle: 1, runningKeepAlive: 1, availableStreams: 0)) + + guard let (_, afterPingIdleContext) = connections.keepAliveSucceeded(newConnection.id) else { + return XCTFail("Expected to receive an AvailableContext") + } + XCTAssertEqual(afterPingIdleContext.info, .idle(availableStreams: 1, newIdle: false)) + XCTAssertEqual(afterPingIdleContext.use, .persisted) + XCTAssertEqual(connections.stats, .init(idle: 1, availableStreams: 1)) + } +} diff --git a/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionStateTests.swift b/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionStateTests.swift index b1622d0d..7751837e 100644 --- a/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionStateTests.swift +++ b/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionStateTests.swift @@ -10,19 +10,19 @@ final class PoolStateMachine_ConnectionStateTests: XCTestCase { let connectionID = 1 var state = TestConnectionState(id: connectionID) XCTAssertEqual(state.id, connectionID) - XCTAssertEqual(state.isIdleOrRunningKeepAlive, false) + XCTAssertEqual(state.isIdle, false) XCTAssertEqual(state.isAvailable, false) XCTAssertEqual(state.isConnected, false) XCTAssertEqual(state.isLeased, false) let connection = MockConnection(id: connectionID) XCTAssertEqual(state.connected(connection, maxStreams: 1), .idle(availableStreams: 1, newIdle: true)) - XCTAssertEqual(state.isIdleOrRunningKeepAlive, true) + XCTAssertEqual(state.isIdle, true) XCTAssertEqual(state.isAvailable, true) XCTAssertEqual(state.isConnected, true) XCTAssertEqual(state.isLeased, false) XCTAssertEqual(state.lease(streams: 1), .init(connection: connection, timersToCancel: .init(), wasIdle: true)) - XCTAssertEqual(state.isIdleOrRunningKeepAlive, false) + XCTAssertEqual(state.isIdle, false) XCTAssertEqual(state.isAvailable, false) XCTAssertEqual(state.isConnected, true) XCTAssertEqual(state.isLeased, true) @@ -257,7 +257,7 @@ final class PoolStateMachine_ConnectionStateTests: XCTestCase { XCTAssertNil(state.timerScheduled(keepAliveTimer, cancelContinuation: keepAliveTimerCancellationToken)) XCTAssertNil(state.timerScheduled(idleTimer, cancelContinuation: idleTimerCancellationToken)) - XCTAssertEqual(state.closeIfIdle(), .init(connection: connection, cancelTimers: [keepAliveTimerCancellationToken, idleTimerCancellationToken], maxStreams: 1)) + XCTAssertEqual(state.closeIfIdle(), .init(connection: connection, previousConnectionState: .idle, cancelTimers: [keepAliveTimerCancellationToken, idleTimerCancellationToken], usedStreams: 0, maxStreams: 1)) XCTAssertEqual(state.runKeepAliveIfIdle(reducesAvailableStreams: true), .none) } From 472ff4ae68bd9b8d59d978137812137ee8162f4a Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Wed, 25 Oct 2023 22:44:46 +0200 Subject: [PATCH 042/106] Add `PoolStateMachine` (#427) --- .../PoolStateMachine+ConnectionGroup.swift | 61 ++- .../PoolStateMachine+RequestQueue.swift | 10 +- .../PoolStateMachine.swift | 484 +++++++++++++++++- ...tSequence.swift => TinyFastSequence.swift} | 80 ++- ...oolStateMachine+ConnectionGroupTests.swift | 2 +- .../PoolStateMachineTests.swift | 217 +++++++- ...tSequence.swift => TinyFastSequence.swift} | 16 +- 7 files changed, 814 insertions(+), 56 deletions(-) rename Sources/ConnectionPoolModule/{OneElementFastSequence.swift => TinyFastSequence.swift} (58%) rename Tests/ConnectionPoolModuleTests/{OneElementFastSequence.swift => TinyFastSequence.swift} (82%) diff --git a/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionGroup.swift b/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionGroup.swift index 8ec99c7d..16970599 100644 --- a/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionGroup.swift +++ b/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionGroup.swift @@ -134,19 +134,6 @@ extension PoolStateMachine { var info: ConnectionAvailableInfo } - /// Information around the failed/closed connection. - @usableFromInline - struct FailedConnectionContext { - /// Connections that are currently starting - @usableFromInline - var connectionsStarting: Int - - @inlinable - init(connectionsStarting: Int) { - self.connectionsStarting = connectionsStarting - } - } - mutating func refillConnections() -> [ConnectionRequest] { let existingConnections = self.stats.active let missingConnection = self.minimumConcurrentConnections - Int(existingConnections) @@ -477,6 +464,31 @@ extension PoolStateMachine { return self.closeConnectionIfIdle(at: index) } + /// Information around the failed/closed connection. + @usableFromInline + struct ClosedAction { + /// Connections that are currently starting + @usableFromInline + var connectionsStarting: Int + + @usableFromInline + var timersToCancel: TinyFastSequence + + @usableFromInline + var newConnectionRequest: ConnectionRequest? + + @inlinable + init( + connectionsStarting: Int, + timersToCancel: TinyFastSequence, + newConnectionRequest: ConnectionRequest? = nil + ) { + self.connectionsStarting = connectionsStarting + self.timersToCancel = timersToCancel + self.newConnectionRequest = newConnectionRequest + } + } + /// Connection closed. Call this method, if a connection is closed. /// /// This will put the position into the closed state. @@ -487,12 +499,13 @@ extension PoolStateMachine { /// supplied index after this. If nil is returned the connection was closed by the state machine and was /// therefore already removed. @inlinable - mutating func connectionClosed(_ connectionID: Connection.ID) -> FailedConnectionContext? { + mutating func connectionClosed(_ connectionID: Connection.ID) -> ClosedAction { guard let index = self.connections.firstIndex(where: { $0.id == connectionID }) else { - return nil + preconditionFailure("All connections that have been created should say goodbye exactly once!") } let closedAction = self.connections[index].closed() + var timersToCancel = TinyFastSequence(closedAction.cancelTimers) if closedAction.wasRunningKeepAlive { self.stats.runningKeepAlive -= 1 @@ -511,16 +524,22 @@ extension PoolStateMachine { self.stats.closing -= 1 } - let lastIndex = self.connections.index(before: self.connections.endIndex) + if let cancellationTimer = self.swapForDeletion(index: index) { + timersToCancel.append(cancellationTimer) + } - if index == lastIndex { - self.connections.remove(at: index) + let newConnectionRequest: ConnectionRequest? + if self.connections.count < self.minimumConcurrentConnections { + newConnectionRequest = .init(connectionID: self.generator.next()) } else { - self.connections.swapAt(index, lastIndex) - self.connections.remove(at: lastIndex) + newConnectionRequest = .none } - return FailedConnectionContext(connectionsStarting: 0) + return ClosedAction( + connectionsStarting: 0, + timersToCancel: timersToCancel, + newConnectionRequest: newConnectionRequest + ) } // MARK: Shutdown diff --git a/Sources/ConnectionPoolModule/PoolStateMachine+RequestQueue.swift b/Sources/ConnectionPoolModule/PoolStateMachine+RequestQueue.swift index 7e3c6607..f1d6f4e4 100644 --- a/Sources/ConnectionPoolModule/PoolStateMachine+RequestQueue.swift +++ b/Sources/ConnectionPoolModule/PoolStateMachine+RequestQueue.swift @@ -10,7 +10,7 @@ extension PoolStateMachine { /// request from the dictionary and keep it inside the queue. Whenever we pop a request from the deque, we validate /// that it hasn't been cancelled in the meantime by checking if the popped request is still in the `requests` dictionary. @usableFromInline - struct RequestQueue { + struct RequestQueue: Sendable { @usableFromInline private(set) var queue: Deque @@ -40,8 +40,8 @@ extension PoolStateMachine { } @inlinable - mutating func pop(max: UInt16) -> OneElementFastSequence { - var result = OneElementFastSequence() + mutating func pop(max: UInt16) -> TinyFastSequence { + var result = TinyFastSequence() result.reserveCapacity(Int(max)) var popped = 0 while let requestID = self.queue.popFirst(), popped < max { @@ -61,8 +61,8 @@ extension PoolStateMachine { } @inlinable - mutating func removeAll() -> OneElementFastSequence { - let result = OneElementFastSequence(self.requests.values) + mutating func removeAll() -> TinyFastSequence { + let result = TinyFastSequence(self.requests.values) self.requests.removeAll() self.queue.removeAll() return result diff --git a/Sources/ConnectionPoolModule/PoolStateMachine.swift b/Sources/ConnectionPoolModule/PoolStateMachine.swift index 29349e56..aa62d749 100644 --- a/Sources/ConnectionPoolModule/PoolStateMachine.swift +++ b/Sources/ConnectionPoolModule/PoolStateMachine.swift @@ -6,7 +6,7 @@ import Glibc @usableFromInline @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) -struct PoolConfiguration { +struct PoolConfiguration: Sendable { /// The minimum number of connections to preserve in the pool. /// /// If the pool is mostly idle and the remote servers closes idle connections, @@ -38,10 +38,10 @@ struct PoolStateMachine< Request: ConnectionRequestProtocol, RequestID, TimerCancellationToken: Sendable -> where Connection.ID == ConnectionID, ConnectionIDGenerator.ID == ConnectionID, RequestID == Request.ID { - +>: Sendable where Connection.ID == ConnectionID, ConnectionIDGenerator.ID == ConnectionID, RequestID == Request.ID { + @usableFromInline - struct ConnectionRequest: Equatable { + struct ConnectionRequest: Hashable, Sendable { @usableFromInline var connectionID: ConnectionID @inlinable @@ -50,6 +50,21 @@ struct PoolStateMachine< } } + @usableFromInline + struct Action { + @usableFromInline let request: RequestAction + @usableFromInline let connection: ConnectionAction + + @inlinable + init(request: RequestAction, connection: ConnectionAction) { + self.request = request + self.connection = connection + } + + @inlinable + static func none() -> Action { Action(request: .none, connection: .none) } + } + @usableFromInline enum ConnectionAction { @usableFromInline @@ -67,15 +82,32 @@ struct PoolStateMachine< } case scheduleTimers(Max2Sequence) - case makeConnection(ConnectionRequest, TimerCancellationToken?) + case makeConnection(ConnectionRequest, TinyFastSequence) case runKeepAlive(Connection, TimerCancellationToken?) - case cancelTimers(Max2Sequence) - case closeConnection(Connection) + case cancelTimers(TinyFastSequence) + case closeConnection(Connection, Max2Sequence) case shutdown(Shutdown) case none } + @usableFromInline + enum RequestAction { + case leaseConnection(TinyFastSequence, Connection) + + case failRequest(Request, ConnectionPoolError) + case failRequests(TinyFastSequence, ConnectionPoolError) + + case none + } + + @usableFromInline + enum PoolState: Sendable { + case running + case shuttingDown(graceful: Bool) + case shutDown + } + @usableFromInline struct Timer: Hashable, Sendable { @usableFromInline @@ -84,10 +116,448 @@ struct PoolStateMachine< @usableFromInline var duration: Duration + @inlinable + var connectionID: ConnectionID { + self.underlying.connectionID + } + @inlinable init(_ connectionTimer: ConnectionTimer, duration: Duration) { self.underlying = connectionTimer self.duration = duration } } + + @usableFromInline let configuration: PoolConfiguration + @usableFromInline let generator: ConnectionIDGenerator + + @usableFromInline + private(set) var connections: ConnectionGroup + @usableFromInline + private(set) var requestQueue: RequestQueue + @usableFromInline + private(set) var poolState: PoolState = .running + @usableFromInline + private(set) var cacheNoMoreConnectionsAllowed: Bool = false + + @usableFromInline + private(set) var failedConsecutiveConnectionAttempts: Int = 0 + + @inlinable + init( + configuration: PoolConfiguration, + generator: ConnectionIDGenerator, + timerCancellationTokenType: TimerCancellationToken.Type + ) { + self.configuration = configuration + self.generator = generator + self.connections = ConnectionGroup( + generator: generator, + minimumConcurrentConnections: configuration.minimumConnectionCount, + maximumConcurrentConnectionSoftLimit: configuration.maximumConnectionSoftLimit, + maximumConcurrentConnectionHardLimit: configuration.maximumConnectionHardLimit, + keepAlive: configuration.keepAliveDuration != nil, + keepAliveReducesAvailableStreams: true + ) + self.requestQueue = RequestQueue() + } + + mutating func refillConnections() -> [ConnectionRequest] { + return self.connections.refillConnections() + } + + @inlinable + mutating func leaseConnection(_ request: Request) -> Action { + switch self.poolState { + case .running: + break + + case .shuttingDown, .shutDown: + return .init( + request: .failRequest(request, ConnectionPoolError.poolShutdown), + connection: .none + ) + } + + if !self.requestQueue.isEmpty && self.cacheNoMoreConnectionsAllowed { + self.requestQueue.queue(request) + return .none() + } + + var soonAvailable: UInt16 = 0 + + // check if any other EL has an idle connection + switch self.connections.leaseConnectionOrSoonAvailableConnectionCount() { + case .leasedConnection(let leaseResult): + return .init( + request: .leaseConnection(TinyFastSequence(element: request), leaseResult.connection), + connection: .cancelTimers(.init(leaseResult.timersToCancel)) + ) + + case .startingCount(let count): + soonAvailable += count + } + + // we tried everything. there is no connection available. now we must check, if and where we + // can create further connections. but first we must enqueue the new request + + self.requestQueue.queue(request) + + let requestAction = RequestAction.none + + if soonAvailable >= self.requestQueue.count { + // if more connections will be soon available then we have waiters, we don't need to + // create further new connections. + return .init( + request: requestAction, + connection: .none + ) + } else if let request = self.connections.createNewDemandConnectionIfPossible() { + // Can we create a demand connection + return .init( + request: requestAction, + connection: .makeConnection(request, .init()) + ) + } else if let request = self.connections.createNewOverflowConnectionIfPossible() { + // Can we create an overflow connection + return .init( + request: requestAction, + connection: .makeConnection(request, .init()) + ) + } else { + self.cacheNoMoreConnectionsAllowed = true + + // no new connections allowed: + return .init(request: requestAction, connection: .none) + } + } + + @inlinable + mutating func releaseConnection(_ connection: Connection, streams: UInt16) -> Action { + let (index, context) = self.connections.releaseConnection(connection.id, streams: streams) + return self.handleAvailableConnection(index: index, availableContext: context) + } + + mutating func cancelRequest(id: RequestID) -> Action { + guard let request = self.requestQueue.remove(id) else { + return .none() + } + + return .init( + request: .failRequest(request, ConnectionPoolError.requestCancelled), + connection: .none + ) + } + + @inlinable + mutating func connectionEstablished(_ connection: Connection, maxStreams: UInt16) -> Action { + let (index, context) = self.connections.newConnectionEstablished(connection, maxStreams: maxStreams) + return self.handleAvailableConnection(index: index, availableContext: context) + } + + @inlinable + mutating func timerScheduled(_ timer: Timer, cancelContinuation: TimerCancellationToken) -> TimerCancellationToken? { + self.connections.timerScheduled(timer.underlying, cancelContinuation: cancelContinuation) + } + + @inlinable + mutating func timerTriggered(_ timer: Timer) -> Action { + switch timer.underlying.usecase { + case .backoff: + return self.connectionCreationBackoffDone(timer.connectionID) + case .keepAlive: + return self.connectionKeepAliveTimerTriggered(timer.connectionID) + case .idleTimeout: + return self.connectionIdleTimerTriggered(timer.connectionID) + } + } + + @inlinable + mutating func connectionEstablishFailed(_ error: Error, for request: ConnectionRequest) -> Action { + self.failedConsecutiveConnectionAttempts += 1 + + let connectionTimer = self.connections.backoffNextConnectionAttempt(request.connectionID) + let backoff = Self.calculateBackoff(failedAttempt: self.failedConsecutiveConnectionAttempts) + let timer = Timer(connectionTimer, duration: backoff) + return .init(request: .none, connection: .scheduleTimers(.init(timer))) + } + + @inlinable + mutating func connectionCreationBackoffDone(_ connectionID: ConnectionID) -> Action { + let soonAvailable = self.connections.soonAvailableConnections + let retry = (soonAvailable - 1) < self.requestQueue.count + + switch self.connections.backoffDone(connectionID, retry: retry) { + case .createConnection(let request, let continuation): + let timers: TinyFastSequence + if let continuation { + timers = .init(element: continuation) + } else { + timers = .init() + } + return .init(request: .none, connection: .makeConnection(request, timers)) + + case .cancelTimers(let timers): + return .init(request: .none, connection: .cancelTimers(.init(timers))) + } + } + + @inlinable + mutating func connectionKeepAliveTimerTriggered(_ connectionID: ConnectionID) -> Action { + precondition(self.configuration.keepAliveDuration != nil) + precondition(self.requestQueue.isEmpty) + + guard let keepAliveAction = self.connections.keepAliveIfIdle(connectionID) else { + return .none() + } + return .init(request: .none, connection: .runKeepAlive(keepAliveAction.connection, keepAliveAction.keepAliveTimerCancellationContinuation)) + } + + @inlinable + mutating func connectionKeepAliveDone(_ connection: Connection) -> Action { + precondition(self.configuration.keepAliveDuration != nil) + guard let (index, context) = self.connections.keepAliveSucceeded(connection.id) else { + return .none() + } + return self.handleAvailableConnection(index: index, availableContext: context) + } + + @inlinable + mutating func connectionIdleTimerTriggered(_ connectionID: ConnectionID) -> Action { + precondition(self.requestQueue.isEmpty) + + guard let closeAction = self.connections.closeConnectionIfIdle(connectionID) else { + return .none() + } + + self.cacheNoMoreConnectionsAllowed = false + return .init(request: .none, connection: .closeConnection(closeAction.connection, closeAction.timersToCancel)) + } + + @inlinable + mutating func connectionClosed(_ connection: Connection) -> Action { + self.cacheNoMoreConnectionsAllowed = false + + let closedConnectionAction = self.connections.connectionClosed(connection.id) + + let connectionAction: ConnectionAction + if let newRequest = closedConnectionAction.newConnectionRequest { + connectionAction = .makeConnection(newRequest, closedConnectionAction.timersToCancel) + } else { + connectionAction = .cancelTimers(closedConnectionAction.timersToCancel) + } + + return .init(request: .none, connection: connectionAction) + } + + struct CleanupAction { + struct ConnectionToDrop { + var connection: Connection + var keepAliveTimer: Bool + var idleTimer: Bool + } + + var connections: [ConnectionToDrop] + var requests: [Request] + } + + mutating func triggerGracefulShutdown() -> Action { + fatalError("Unimplemented") + } + + mutating func triggerForceShutdown() -> Action { + switch self.poolState { + case .running: + self.poolState = .shuttingDown(graceful: false) + var shutdown = ConnectionAction.Shutdown() + self.connections.triggerForceShutdown(&shutdown) + + if shutdown.connections.isEmpty { + self.poolState = .shutDown + } + + return .init( + request: .failRequests(self.requestQueue.removeAll(), ConnectionPoolError.poolShutdown), + connection: .shutdown(shutdown) + ) + + case .shuttingDown: + return .none() + + case .shutDown: + return .init(request: .none, connection: .none) + } + } + + @inlinable + /*private*/ mutating func handleAvailableConnection( + index: Int, + availableContext: ConnectionGroup.AvailableConnectionContext + ) -> Action { + // this connection was busy before + let requests = self.requestQueue.pop(max: availableContext.info.availableStreams) + if !requests.isEmpty { + let leaseResult = self.connections.leaseConnection(at: index, streams: UInt16(requests.count)) + return .init( + request: .leaseConnection(requests, leaseResult.connection), + connection: .cancelTimers(.init(leaseResult.timersToCancel)) + ) + } + + switch availableContext.use { + case .persisted, .demand: + switch availableContext.info { + case .leased: + return .none() + + case .idle: + let timers = self.connections.parkConnection(at: index).map(self.mapTimers) + + return .init( + request: .none, + connection: .scheduleTimers(timers) + ) + } + + case .overflow: + let closeAction = self.connections.closeConnectionIfIdle(at: index) + return .init( + request: .none, + connection: .closeConnection(closeAction.connection, closeAction.timersToCancel) + ) + } + + } + + @inlinable + /* private */ func mapTimers(_ connectionTimer: ConnectionTimer) -> Timer { + switch connectionTimer.usecase { + case .backoff: + return Timer( + connectionTimer, + duration: Self.calculateBackoff(failedAttempt: self.failedConsecutiveConnectionAttempts) + ) + + case .keepAlive: + return Timer(connectionTimer, duration: self.configuration.keepAliveDuration!) + + case .idleTimeout: + return Timer(connectionTimer, duration: self.configuration.idleTimeoutDuration) + + } + } +} + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +extension PoolStateMachine { + /// Calculates the delay for the next connection attempt after the given number of failed `attempts`. + /// + /// Our backoff formula is: 100ms * 1.25^(attempts - 1) with 3% jitter that is capped of at 1 minute. + /// This means for: + /// - 1 failed attempt : 100ms + /// - 5 failed attempts: ~300ms + /// - 10 failed attempts: ~930ms + /// - 15 failed attempts: ~2.84s + /// - 20 failed attempts: ~8.67s + /// - 25 failed attempts: ~26s + /// - 29 failed attempts: ~60s (max out) + /// + /// - Parameter attempts: number of failed attempts in a row + /// - Returns: time to wait until trying to establishing a new connection + @usableFromInline + static func calculateBackoff(failedAttempt attempts: Int) -> Duration { + // Our backoff formula is: 100ms * 1.25^(attempts - 1) that is capped of at 1minute + // This means for: + // - 1 failed attempt : 100ms + // - 5 failed attempts: ~300ms + // - 10 failed attempts: ~930ms + // - 15 failed attempts: ~2.84s + // - 20 failed attempts: ~8.67s + // - 25 failed attempts: ~26s + // - 29 failed attempts: ~60s (max out) + + let start = Double(100_000_000) + let backoffNanosecondsDouble = start * pow(1.25, Double(attempts - 1)) + + // Cap to 60s _before_ we convert to Int64, to avoid trapping in the Int64 initializer. + let backoffNanoseconds = Int64(min(backoffNanosecondsDouble, Double(60_000_000_000))) + + let backoff = Duration.nanoseconds(backoffNanoseconds) + + // Calculate a 3% jitter range + let jitterRange = (backoffNanoseconds / 100) * 3 + // Pick a random element from the range +/- jitter range. + let jitter: Duration = .nanoseconds((-jitterRange...jitterRange).randomElement()!) + let jitteredBackoff = backoff + jitter + return jitteredBackoff + } +} + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +extension PoolStateMachine.Action: Equatable where TimerCancellationToken: Equatable, Request: Equatable {} + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +extension PoolStateMachine.ConnectionAction: Equatable where TimerCancellationToken: Equatable { + @usableFromInline + static func ==(lhs: Self, rhs: Self) -> Bool { + switch (lhs, rhs) { + case (.scheduleTimers(let lhs), .scheduleTimers(let rhs)): + return lhs == rhs + case (.makeConnection(let lhsRequest, let lhsToken), .makeConnection(let rhsRequest, let rhsToken)): + return lhsRequest == rhsRequest && lhsToken == rhsToken + case (.runKeepAlive(let lhsConn, let lhsToken), .runKeepAlive(let rhsConn, let rhsToken)): + return lhsConn === rhsConn && lhsToken == rhsToken + case (.closeConnection(let lhsConn, let lhsTimers), .closeConnection(let rhsConn, let rhsTimers)): + return lhsConn === rhsConn && lhsTimers == rhsTimers + case (.shutdown(let lhs), .shutdown(let rhs)): + return lhs == rhs + case (.cancelTimers(let lhs), .cancelTimers(let rhs)): + return lhs == rhs + case (.none, .none), + (.cancelTimers([]), .none), (.none, .cancelTimers([])), + (.scheduleTimers([]), .none), (.none, .scheduleTimers([])): + return true + default: + return false + } + } +} + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +extension PoolStateMachine.ConnectionAction.Shutdown: Equatable where TimerCancellationToken: Equatable { + @usableFromInline + static func ==(lhs: Self, rhs: Self) -> Bool { + Set(lhs.connections.lazy.map(\.id)) == Set(rhs.connections.lazy.map(\.id)) && lhs.timersToCancel == rhs.timersToCancel + } +} + + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +extension PoolStateMachine.RequestAction: Equatable where Request: Equatable { + + @usableFromInline + static func ==(lhs: Self, rhs: Self) -> Bool { + switch (lhs, rhs) { + case (.leaseConnection(let lhsRequests, let lhsConn), .leaseConnection(let rhsRequests, let rhsConn)): + guard lhsRequests.count == rhsRequests.count else { return false } + var lhsIterator = lhsRequests.makeIterator() + var rhsIterator = rhsRequests.makeIterator() + while let lhsNext = lhsIterator.next(), let rhsNext = rhsIterator.next() { + guard lhsNext.id == rhsNext.id else { return false } + } + return lhsConn === rhsConn + + case (.failRequest(let lhsRequest, let lhsError), .failRequest(let rhsRequest, let rhsError)): + return lhsRequest.id == rhsRequest.id && lhsError == rhsError + + case (.failRequests(let lhsRequests, let lhsError), .failRequests(let rhsRequests, let rhsError)): + return Set(lhsRequests.lazy.map(\.id)) == Set(rhsRequests.lazy.map(\.id)) && lhsError == rhsError + + case (.none, .none): + return true + + default: + return false + } + } } diff --git a/Sources/ConnectionPoolModule/OneElementFastSequence.swift b/Sources/ConnectionPoolModule/TinyFastSequence.swift similarity index 58% rename from Sources/ConnectionPoolModule/OneElementFastSequence.swift rename to Sources/ConnectionPoolModule/TinyFastSequence.swift index 3c3bfaa0..dff8a30b 100644 --- a/Sources/ConnectionPoolModule/OneElementFastSequence.swift +++ b/Sources/ConnectionPoolModule/TinyFastSequence.swift @@ -1,10 +1,11 @@ /// A `Sequence` that does not heap allocate, if it only carries a single element @usableFromInline -struct OneElementFastSequence: Sequence { +struct TinyFastSequence: Sequence { @usableFromInline enum Base { case none(reserveCapacity: Int) case one(Element, reserveCapacity: Int) + case two(Element, Element, reserveCapacity: Int) case n([Element]) } @@ -37,6 +38,20 @@ struct OneElementFastSequence: Sequence { } } + @inlinable + init(_ max2Sequence: Max2Sequence) { + switch max2Sequence.count { + case 0: + self.base = .none(reserveCapacity: 0) + case 1: + self.base = .one(max2Sequence.first!, reserveCapacity: 0) + case 2: + self.base = .n(Array(max2Sequence)) + default: + fatalError() + } + } + @usableFromInline var count: Int { switch self.base { @@ -44,6 +59,8 @@ struct OneElementFastSequence: Sequence { return 0 case .one: return 1 + case .two: + return 2 case .n(let array): return array.count } @@ -56,6 +73,8 @@ struct OneElementFastSequence: Sequence { return nil case .one(let element, _): return element + case .two(let first, _, _): + return first case .n(let array): return array.first } @@ -66,7 +85,7 @@ struct OneElementFastSequence: Sequence { switch self.base { case .none: return true - case .one, .n: + case .one, .two, .n: return false } } @@ -78,6 +97,8 @@ struct OneElementFastSequence: Sequence { self.base = .none(reserveCapacity: Swift.max(reservedCapacity, minimumCapacity)) case .one(let element, let reservedCapacity): self.base = .one(element, reserveCapacity: Swift.max(reservedCapacity, minimumCapacity)) + case .two(let first, let second, let reservedCapacity): + self.base = .two(first, second, reserveCapacity: Swift.max(reservedCapacity, minimumCapacity)) case .n(var array): self.base = .none(reserveCapacity: 0) // prevent CoW array.reserveCapacity(minimumCapacity) @@ -90,12 +111,17 @@ struct OneElementFastSequence: Sequence { switch self.base { case .none(let reserveCapacity): self.base = .one(element, reserveCapacity: reserveCapacity) - case .one(let existing, let reserveCapacity): + case .one(let first, let reserveCapacity): + self.base = .two(first, element, reserveCapacity: reserveCapacity) + + case .two(let first, let second, let reserveCapacity): var new = [Element]() - new.reserveCapacity(reserveCapacity) - new.append(existing) + new.reserveCapacity(Swift.max(4, reserveCapacity)) + new.append(first) + new.append(second) new.append(element) self.base = .n(new) + case .n(var existing): self.base = .none(reserveCapacity: 0) // prevent CoW existing.append(element) @@ -111,10 +137,10 @@ struct OneElementFastSequence: Sequence { @usableFromInline struct Iterator: IteratorProtocol { @usableFromInline private(set) var index: Int = 0 - @usableFromInline private(set) var backing: OneElementFastSequence + @usableFromInline private(set) var backing: TinyFastSequence @inlinable - init(_ backing: OneElementFastSequence) { + init(_ backing: TinyFastSequence) { self.backing = backing } @@ -130,6 +156,17 @@ struct OneElementFastSequence: Sequence { } return nil + case .two(let first, let second, _): + defer { self.index += 1 } + switch self.index { + case 0: + return first + case 1: + return second + default: + return nil + } + case .n(let array): if self.index < array.endIndex { defer { self.index += 1} @@ -141,11 +178,28 @@ struct OneElementFastSequence: Sequence { } } -extension OneElementFastSequence: Equatable where Element: Equatable {} -extension OneElementFastSequence.Base: Equatable where Element: Equatable {} +extension TinyFastSequence: Equatable where Element: Equatable {} +extension TinyFastSequence.Base: Equatable where Element: Equatable {} + +extension TinyFastSequence: Hashable where Element: Hashable {} +extension TinyFastSequence.Base: Hashable where Element: Hashable {} -extension OneElementFastSequence: Hashable where Element: Hashable {} -extension OneElementFastSequence.Base: Hashable where Element: Hashable {} +extension TinyFastSequence: Sendable where Element: Sendable {} +extension TinyFastSequence.Base: Sendable where Element: Sendable {} -extension OneElementFastSequence: Sendable where Element: Sendable {} -extension OneElementFastSequence.Base: Sendable where Element: Sendable {} +extension TinyFastSequence: ExpressibleByArrayLiteral { + @inlinable + init(arrayLiteral elements: Element...) { + var iterator = elements.makeIterator() + switch elements.count { + case 0: + self.base = .none(reserveCapacity: 0) + case 1: + self.base = .one(iterator.next()!, reserveCapacity: 0) + case 2: + self.base = .two(iterator.next()!, iterator.next()!, reserveCapacity: 0) + default: + self.base = .n(elements) + } + } +} diff --git a/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionGroupTests.swift b/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionGroupTests.swift index 4e3a1647..bf385918 100644 --- a/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionGroupTests.swift +++ b/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionGroupTests.swift @@ -120,7 +120,7 @@ final class PoolStateMachine_ConnectionGroupTests: XCTestCase { XCTAssertEqual(connections.stats, .init(closing: 1, availableStreams: 0)) let closeContext = connections.connectionClosed(newConnection.id) - XCTAssertEqual(closeContext?.connectionsStarting, 0) + XCTAssertEqual(closeContext.connectionsStarting, 0) XCTAssertTrue(connections.isEmpty) XCTAssertEqual(connections.stats, .init()) } diff --git a/Tests/ConnectionPoolModuleTests/PoolStateMachineTests.swift b/Tests/ConnectionPoolModuleTests/PoolStateMachineTests.swift index ee8cfdc6..0f3af728 100644 --- a/Tests/ConnectionPoolModuleTests/PoolStateMachineTests.swift +++ b/Tests/ConnectionPoolModuleTests/PoolStateMachineTests.swift @@ -1,5 +1,3 @@ -import NIOCore -import NIOEmbedded import XCTest @testable import _ConnectionPoolModule @@ -12,3 +10,218 @@ typealias TestPoolStateMachine = PoolStateMachine< MockRequest.ID, MockTimerCancellationToken > + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +final class PoolStateMachineTests: XCTestCase { + + func testConnectionsAreCreatedAndParkedOnStartup() { + var configuration = PoolConfiguration() + configuration.minimumConnectionCount = 2 + configuration.maximumConnectionSoftLimit = 4 + configuration.maximumConnectionHardLimit = 6 + configuration.keepAliveDuration = .seconds(10) + + var stateMachine = TestPoolStateMachine( + configuration: configuration, + generator: .init(), + timerCancellationTokenType: MockTimerCancellationToken.self + ) + + let connection1 = MockConnection(id: 0) + let connection2 = MockConnection(id: 1) + + do { + let requests = stateMachine.refillConnections() + XCTAssertEqual(requests.count, 2) + let createdAction1 = stateMachine.connectionEstablished(connection1, maxStreams: 1) + let connection1KeepAliveTimer = TestPoolStateMachine.Timer(.init(timerID: 0, connectionID: 0, usecase: .keepAlive), duration: .seconds(10)) + let connection1KeepAliveTimerCancellationToken = MockTimerCancellationToken(connection1KeepAliveTimer) + XCTAssertEqual(createdAction1.request, .none) + XCTAssertEqual(createdAction1.connection, .scheduleTimers([connection1KeepAliveTimer])) + + XCTAssertEqual(stateMachine.timerScheduled(connection1KeepAliveTimer, cancelContinuation: connection1KeepAliveTimerCancellationToken), .none) + + let createdAction2 = stateMachine.connectionEstablished(connection2, maxStreams: 1) + let connection2KeepAliveTimer = TestPoolStateMachine.Timer(.init(timerID: 0, connectionID: 1, usecase: .keepAlive), duration: .seconds(10)) + let connection2KeepAliveTimerCancellationToken = MockTimerCancellationToken(connection2KeepAliveTimer) + XCTAssertEqual(createdAction2.request, .none) + XCTAssertEqual(createdAction2.connection, .scheduleTimers([connection2KeepAliveTimer])) + XCTAssertEqual(stateMachine.timerScheduled(connection2KeepAliveTimer, cancelContinuation: connection2KeepAliveTimerCancellationToken), .none) + } + } + + func testConnectionsNoKeepAliveRun() { + var configuration = PoolConfiguration() + configuration.minimumConnectionCount = 1 + configuration.maximumConnectionSoftLimit = 4 + configuration.maximumConnectionHardLimit = 6 + configuration.keepAliveDuration = nil + configuration.idleTimeoutDuration = .seconds(5) + + var stateMachine = TestPoolStateMachine( + configuration: configuration, + generator: .init(), + timerCancellationTokenType: MockTimerCancellationToken.self + ) + + let connection1 = MockConnection(id: 0) + + // refill pool to at least one connection + let requests = stateMachine.refillConnections() + XCTAssertEqual(requests.count, 1) + let createdAction1 = stateMachine.connectionEstablished(connection1, maxStreams: 1) + XCTAssertEqual(createdAction1.request, .none) + XCTAssertEqual(createdAction1.connection, .scheduleTimers([])) + + // lease connection 1 + let request1 = MockRequest() + let leaseRequest1 = stateMachine.leaseConnection(request1) + XCTAssertEqual(leaseRequest1.connection, .cancelTimers([])) + XCTAssertEqual(leaseRequest1.request, .leaseConnection(.init(element: request1), connection1)) + + // release connection 1 + XCTAssertEqual(stateMachine.releaseConnection(connection1, streams: 1), .none()) + + // lease connection 1 + let request2 = MockRequest() + let leaseRequest2 = stateMachine.leaseConnection(request2) + XCTAssertEqual(leaseRequest2.connection, .cancelTimers([])) + XCTAssertEqual(leaseRequest2.request, .leaseConnection(.init(element: request2), connection1)) + + // request connection while none is available + let request3 = MockRequest() + let leaseRequest3 = stateMachine.leaseConnection(request3) + XCTAssertEqual(leaseRequest3.connection, .makeConnection(.init(connectionID: 1), [])) + XCTAssertEqual(leaseRequest3.request, .none) + + // make connection 2 and lease immediately + let connection2 = MockConnection(id: 1) + let createdAction2 = stateMachine.connectionEstablished(connection2, maxStreams: 1) + XCTAssertEqual(createdAction2.request, .leaseConnection(.init(element: request3), connection2)) + XCTAssertEqual(createdAction2.connection, .none) + + // release connection 2 + let connection2IdleTimer = TestPoolStateMachine.Timer(.init(timerID: 0, connectionID: 1, usecase: .idleTimeout), duration: configuration.idleTimeoutDuration) + let connection2IdleTimerCancellationToken = MockTimerCancellationToken(connection2IdleTimer) + XCTAssertEqual( + stateMachine.releaseConnection(connection2, streams: 1), + .init(request: .none, connection: .scheduleTimers([connection2IdleTimer])) + ) + + XCTAssertEqual(stateMachine.timerScheduled(connection2IdleTimer, cancelContinuation: connection2IdleTimerCancellationToken), .none) + XCTAssertEqual(stateMachine.timerTriggered(connection2IdleTimer), .init(request: .none, connection: .closeConnection(connection2, [connection2IdleTimerCancellationToken]))) + } + + func testOnlyOverflowConnections() { + var configuration = PoolConfiguration() + configuration.minimumConnectionCount = 0 + configuration.maximumConnectionSoftLimit = 0 + configuration.maximumConnectionHardLimit = 6 + configuration.keepAliveDuration = nil + configuration.idleTimeoutDuration = .seconds(3) + + var stateMachine = TestPoolStateMachine( + configuration: configuration, + generator: .init(), + timerCancellationTokenType: MockTimerCancellationToken.self + ) + + // don't refill pool + let requests = stateMachine.refillConnections() + XCTAssertEqual(requests.count, 0) + + // request connection while none exists + let request1 = MockRequest() + let leaseRequest1 = stateMachine.leaseConnection(request1) + XCTAssertEqual(leaseRequest1.connection, .makeConnection(.init(connectionID: 0), [])) + XCTAssertEqual(leaseRequest1.request, .none) + + // make connection 1 and lease immediately + let connection1 = MockConnection(id: 0) + let createdAction1 = stateMachine.connectionEstablished(connection1, maxStreams: 1) + XCTAssertEqual(createdAction1.request, .leaseConnection(.init(element: request1), connection1)) + XCTAssertEqual(createdAction1.connection, .none) + + // request connection while none is available + let request2 = MockRequest() + let leaseRequest2 = stateMachine.leaseConnection(request2) + XCTAssertEqual(leaseRequest2.connection, .makeConnection(.init(connectionID: 1), [])) + XCTAssertEqual(leaseRequest2.request, .none) + + // release connection 1 should be leased again immediately + let releaseRequest1 = stateMachine.releaseConnection(connection1, streams: 1) + XCTAssertEqual(releaseRequest1.request, .leaseConnection(.init(element: request2), connection1)) + XCTAssertEqual(releaseRequest1.connection, .none) + + // connection 2 comes up and should be closed right away + let connection2 = MockConnection(id: 1) + let createdAction2 = stateMachine.connectionEstablished(connection2, maxStreams: 1) + XCTAssertEqual(createdAction2.request, .none) + XCTAssertEqual(createdAction2.connection, .closeConnection(connection2, [])) + XCTAssertEqual(stateMachine.connectionClosed(connection2), .none()) + + // release connection 1 should be closed as well + let releaseRequest2 = stateMachine.releaseConnection(connection1, streams: 1) + XCTAssertEqual(releaseRequest2.request, .none) + XCTAssertEqual(releaseRequest2.connection, .closeConnection(connection1, [])) + + let shutdownAction = stateMachine.triggerForceShutdown() + XCTAssertEqual(shutdownAction.request, .failRequests(.init(), .poolShutdown)) + XCTAssertEqual(shutdownAction.connection, .shutdown(.init())) + } + + func testDemandConnectionIsMadePermanentIfPermanentIsClose() { + var configuration = PoolConfiguration() + configuration.minimumConnectionCount = 1 + configuration.maximumConnectionSoftLimit = 2 + configuration.maximumConnectionHardLimit = 6 + configuration.keepAliveDuration = nil + configuration.idleTimeoutDuration = .seconds(3) + + var stateMachine = TestPoolStateMachine( + configuration: configuration, + generator: .init(), + timerCancellationTokenType: MockTimerCancellationToken.self + ) + + let connection1 = MockConnection(id: 0) + + // refill pool to at least one connection + let requests = stateMachine.refillConnections() + XCTAssertEqual(requests.count, 1) + let createdAction1 = stateMachine.connectionEstablished(connection1, maxStreams: 1) + XCTAssertEqual(createdAction1.request, .none) + XCTAssertEqual(createdAction1.connection, .scheduleTimers([])) + + // lease connection 1 + let request1 = MockRequest() + let leaseRequest1 = stateMachine.leaseConnection(request1) + XCTAssertEqual(leaseRequest1.connection, .cancelTimers([])) + XCTAssertEqual(leaseRequest1.request, .leaseConnection(.init(element: request1), connection1)) + + // request connection while none is available + let request2 = MockRequest() + let leaseRequest2 = stateMachine.leaseConnection(request2) + XCTAssertEqual(leaseRequest2.connection, .makeConnection(.init(connectionID: 1), [])) + XCTAssertEqual(leaseRequest2.request, .none) + + // make connection 2 and lease immediately + let connection2 = MockConnection(id: 1) + let createdAction2 = stateMachine.connectionEstablished(connection2, maxStreams: 1) + XCTAssertEqual(createdAction2.request, .leaseConnection(.init(element: request2), connection2)) + XCTAssertEqual(createdAction2.connection, .none) + + // release connection 2 + let connection2IdleTimer = TestPoolStateMachine.Timer(.init(timerID: 0, connectionID: 1, usecase: .idleTimeout), duration: configuration.idleTimeoutDuration) + let connection2IdleTimerCancellationToken = MockTimerCancellationToken(connection2IdleTimer) + XCTAssertEqual( + stateMachine.releaseConnection(connection2, streams: 1), + .init(request: .none, connection: .scheduleTimers([connection2IdleTimer])) + ) + + XCTAssertEqual(stateMachine.timerScheduled(connection2IdleTimer, cancelContinuation: connection2IdleTimerCancellationToken), .none) + + // connection 1 is dropped + XCTAssertEqual(stateMachine.connectionClosed(connection1), .init(request: .none, connection: .cancelTimers([connection2IdleTimerCancellationToken]))) + } +} diff --git a/Tests/ConnectionPoolModuleTests/OneElementFastSequence.swift b/Tests/ConnectionPoolModuleTests/TinyFastSequence.swift similarity index 82% rename from Tests/ConnectionPoolModuleTests/OneElementFastSequence.swift rename to Tests/ConnectionPoolModuleTests/TinyFastSequence.swift index a086341e..b3f8179d 100644 --- a/Tests/ConnectionPoolModuleTests/OneElementFastSequence.swift +++ b/Tests/ConnectionPoolModuleTests/TinyFastSequence.swift @@ -3,7 +3,7 @@ import XCTest final class OneElementFastSequenceTests: XCTestCase { func testCountIsEmptyAndIterator() async { - var sequence = OneElementFastSequence() + var sequence = TinyFastSequence() XCTAssertEqual(sequence.count, 0) XCTAssertEqual(sequence.isEmpty, true) XCTAssertEqual(sequence.first, nil) @@ -26,24 +26,26 @@ final class OneElementFastSequenceTests: XCTestCase { } func testReserveCapacityIsForwarded() { - var emptySequence = OneElementFastSequence() + var emptySequence = TinyFastSequence() emptySequence.reserveCapacity(8) emptySequence.append(1) emptySequence.append(2) + emptySequence.append(3) guard case .n(let array) = emptySequence.base else { return XCTFail("Expected sequence to be backed by an array") } XCTAssertEqual(array.capacity, 8) - var oneElemSequence = OneElementFastSequence(element: 1) + var oneElemSequence = TinyFastSequence(element: 1) oneElemSequence.reserveCapacity(8) oneElemSequence.append(2) + oneElemSequence.append(3) guard case .n(let array) = oneElemSequence.base else { return XCTFail("Expected sequence to be backed by an array") } XCTAssertEqual(array.capacity, 8) - var twoElemSequence = OneElementFastSequence([1, 2]) + var twoElemSequence = TinyFastSequence([1, 2]) twoElemSequence.reserveCapacity(8) guard case .n(let array) = twoElemSequence.base else { return XCTFail("Expected sequence to be backed by an array") @@ -52,17 +54,17 @@ final class OneElementFastSequenceTests: XCTestCase { } func testNewSequenceSlowPath() { - let sequence = OneElementFastSequence("AB".utf8) + let sequence = TinyFastSequence("AB".utf8) XCTAssertEqual(Array(sequence), [UInt8(ascii: "A"), UInt8(ascii: "B")]) } func testSingleItem() { - let sequence = OneElementFastSequence("A".utf8) + let sequence = TinyFastSequence("A".utf8) XCTAssertEqual(Array(sequence), [UInt8(ascii: "A")]) } func testEmptyCollection() { - let sequence = OneElementFastSequence("".utf8) + let sequence = TinyFastSequence("".utf8) XCTAssertTrue(sequence.isEmpty) XCTAssertEqual(sequence.count, 0) XCTAssertEqual(Array(sequence), []) From 468ae25f310e877b6613058e8ad2750cfe11f5d8 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Fri, 27 Oct 2023 08:16:30 +0200 Subject: [PATCH 043/106] Land ConnectionPool (#428) --- .../ConnectionPoolModule/ConnectionPool.swift | 484 +++++++++++++++++- .../ConnectionRequest.swift | 53 ++ .../NIOLockedValueBox.swift | 46 ++ .../PoolStateMachine+ConnectionGroup.swift | 7 +- .../PoolStateMachine.swift | 61 ++- .../ConnectionPoolTests.swift | 189 +++++++ .../Mocks/MockClock.swift | 186 +++++++ .../Mocks/MockConnection.swift | 73 +++ .../Mocks/MockPingPongBehaviour.swift | 14 + ...oolStateMachine+ConnectionGroupTests.swift | 4 +- .../PoolStateMachineTests.swift | 42 ++ ...ence.swift => TinyFastSequenceTests.swift} | 2 +- 12 files changed, 1135 insertions(+), 26 deletions(-) create mode 100644 Sources/ConnectionPoolModule/NIOLockedValueBox.swift create mode 100644 Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift create mode 100644 Tests/ConnectionPoolModuleTests/Mocks/MockClock.swift create mode 100644 Tests/ConnectionPoolModuleTests/Mocks/MockPingPongBehaviour.swift rename Tests/ConnectionPoolModuleTests/{TinyFastSequence.swift => TinyFastSequenceTests.swift} (97%) diff --git a/Sources/ConnectionPoolModule/ConnectionPool.swift b/Sources/ConnectionPoolModule/ConnectionPool.swift index 825c3ab3..5571e617 100644 --- a/Sources/ConnectionPoolModule/ConnectionPool.swift +++ b/Sources/ConnectionPoolModule/ConnectionPool.swift @@ -1,3 +1,17 @@ + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +public struct ConnectionAndMetadata { + + public var connection: Connection + + public var maximalStreamsOnConnection: UInt16 + + public init(connection: Connection, maximalStreamsOnConnection: UInt16) { + self.connection = connection + self.maximalStreamsOnConnection = maximalStreamsOnConnection + } +} + /// A connection that can be pooled in a ``ConnectionPool`` public protocol PooledConnection: AnyObject, Sendable { /// The connections identifier type. @@ -78,7 +92,7 @@ public protocol ConnectionRequestProtocol: Sendable { } @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) -public struct ConnectionPoolConfiguration { +public struct ConnectionPoolConfiguration: Sendable { /// The minimum number of connections to preserve in the pool. /// /// If the pool is mostly idle and the remote servers closes @@ -114,3 +128,471 @@ public struct ConnectionPoolConfiguration { self.idleTimeout = .seconds(60) } } + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +public final class ConnectionPool< + Connection: PooledConnection, + ConnectionID: Hashable & Sendable, + ConnectionIDGenerator: ConnectionIDGeneratorProtocol, + Request: ConnectionRequestProtocol, + RequestID: Hashable & Sendable, + KeepAliveBehavior: ConnectionKeepAliveBehavior, + ObservabilityDelegate: ConnectionPoolObservabilityDelegate, + Clock: _Concurrency.Clock +>: Sendable where + Connection.ID == ConnectionID, + ConnectionIDGenerator.ID == ConnectionID, + Request.Connection == Connection, + Request.ID == RequestID, + KeepAliveBehavior.Connection == Connection, + ObservabilityDelegate.ConnectionID == ConnectionID, + Clock.Duration == Duration +{ + public typealias ConnectionFactory = @Sendable (ConnectionID, ConnectionPool) async throws -> ConnectionAndMetadata + + @usableFromInline + typealias StateMachine = PoolStateMachine> + + @usableFromInline + let factory: ConnectionFactory + + @usableFromInline + let keepAliveBehavior: KeepAliveBehavior + + @usableFromInline + let observabilityDelegate: ObservabilityDelegate + + @usableFromInline + let clock: Clock + + @usableFromInline + let configuration: ConnectionPoolConfiguration + + @usableFromInline + struct State: Sendable { + @usableFromInline + var stateMachine: StateMachine + @usableFromInline + var lastConnectError: (any Error)? + } + + @usableFromInline let stateBox: NIOLockedValueBox + + private let requestIDGenerator = _ConnectionPoolModule.ConnectionIDGenerator() + + @usableFromInline + let eventStream: AsyncStream + + @usableFromInline + let eventContinuation: AsyncStream.Continuation + + public init( + configuration: ConnectionPoolConfiguration, + idGenerator: ConnectionIDGenerator, + requestType: Request.Type, + keepAliveBehavior: KeepAliveBehavior, + observabilityDelegate: ObservabilityDelegate, + clock: Clock, + connectionFactory: @escaping ConnectionFactory + ) { + self.clock = clock + self.factory = connectionFactory + self.keepAliveBehavior = keepAliveBehavior + self.observabilityDelegate = observabilityDelegate + self.configuration = configuration + var stateMachine = StateMachine( + configuration: .init(configuration, keepAliveBehavior: keepAliveBehavior), + generator: idGenerator, + timerCancellationTokenType: CheckedContinuation.self + ) + + let (stream, continuation) = AsyncStream.makeStream(of: NewPoolActions.self) + self.eventStream = stream + self.eventContinuation = continuation + + let connectionRequests = stateMachine.refillConnections() + + self.stateBox = NIOLockedValueBox(.init(stateMachine: stateMachine)) + + for request in connectionRequests { + self.eventContinuation.yield(.makeConnection(request)) + } + } + + @inlinable + public func releaseConnection(_ connection: Connection, streams: UInt16 = 1) { + self.modifyStateAndRunActions { state in + state.stateMachine.releaseConnection(connection, streams: streams) + } + } + + @inlinable + public func leaseConnection(_ request: Request) { + self.modifyStateAndRunActions { state in + state.stateMachine.leaseConnection(request) + } + } + + @inlinable + public func leaseConnections(_ requests: some Collection) { + let actions = self.stateBox.withLockedValue { state in + var actions = [StateMachine.Action]() + actions.reserveCapacity(requests.count) + + for request in requests { + let stateMachineAction = state.stateMachine.leaseConnection(request) + actions.append(stateMachineAction) + } + + return actions + } + + for action in actions { + self.runRequestAction(action.request) + self.runConnectionAction(action.connection) + } + } + + public func cancelLeaseConnection(_ requestID: RequestID) { + self.modifyStateAndRunActions { state in + state.stateMachine.cancelRequest(id: requestID) + } + } + + /// Mark a connection as going away. Connection implementors have to call this method if the connection + /// has received a close intent from the server. For example: an HTTP/2 GOWAY frame. + public func connectionWillClose(_ connection: Connection) { + + } + + public func connection(_ connection: Connection, didReceiveNewMaxStreamSetting: UInt16) { + + } + + public func run() async { + await withTaskCancellationHandler { + #if swift(>=5.8) && os(Linux) || swift(>=5.9) + if #available(macOS 14.0, iOS 17.0, tvOS 17.0, watchOS 10.0, *) { + return await withDiscardingTaskGroup() { taskGroup in + await self.run(in: &taskGroup) + } + } + #endif + return await withTaskGroup(of: Void.self) { taskGroup in + await self.run(in: &taskGroup) + } + } onCancel: { + let actions = self.stateBox.withLockedValue { state in + state.stateMachine.triggerForceShutdown() + } + + self.runStateMachineActions(actions) + } + } + + // MARK: - Private Methods - + + @inlinable + func connectionDidClose(_ connection: Connection, error: (any Error)?) { + self.observabilityDelegate.connectionClosed(id: connection.id, error: error) + + self.modifyStateAndRunActions { state in + state.stateMachine.connectionClosed(connection) + } + } + + // MARK: Events + + @usableFromInline + enum NewPoolActions: Sendable { + case makeConnection(StateMachine.ConnectionRequest) + case closeConnection(Connection) + case runKeepAlive(Connection) + + case scheduleTimer(StateMachine.Timer) + } + + #if swift(>=5.8) && os(Linux) || swift(>=5.9) + @available(macOS 14.0, iOS 17.0, tvOS 17.0, watchOS 10.0, *) + private func run(in taskGroup: inout DiscardingTaskGroup) async { + for await event in self.eventStream { + self.runEvent(event, in: &taskGroup) + } + } + #endif + + private func run(in taskGroup: inout TaskGroup) async { + var running = 0 + for await event in self.eventStream { + running += 1 + self.runEvent(event, in: &taskGroup) + + if running == 100 { + _ = await taskGroup.next() + running -= 1 + } + } + } + + private func runEvent(_ event: NewPoolActions, in taskGroup: inout some TaskGroupProtocol) { + switch event { + case .makeConnection(let request): + self.makeConnection(for: request, in: &taskGroup) + + case .runKeepAlive(let connection): + self.runKeepAlive(connection, in: &taskGroup) + + case .closeConnection(let connection): + self.closeConnection(connection) + + case .scheduleTimer(let timer): + self.runTimer(timer, in: &taskGroup) + } + } + + // MARK: Run actions + + @inlinable + /*private*/ func modifyStateAndRunActions(_ closure: (inout State) -> StateMachine.Action) { + let actions = self.stateBox.withLockedValue { state -> StateMachine.Action in + closure(&state) + } + self.runStateMachineActions(actions) + } + + @inlinable + /*private*/ func runStateMachineActions(_ actions: StateMachine.Action) { + self.runConnectionAction(actions.connection) + self.runRequestAction(actions.request) + } + + @inlinable + /*private*/ func runConnectionAction(_ action: StateMachine.ConnectionAction) { + switch action { + case .makeConnection(let request, let timers): + self.cancelTimers(timers) + self.eventContinuation.yield(.makeConnection(request)) + + case .runKeepAlive(let connection, let cancelContinuation): + cancelContinuation?.resume(returning: ()) + self.eventContinuation.yield(.runKeepAlive(connection)) + + case .scheduleTimers(let timers): + for timer in timers { + self.eventContinuation.yield(.scheduleTimer(timer)) + } + + case .cancelTimers(let timers): + self.cancelTimers(timers) + + case .closeConnection(let connection, let timers): + self.closeConnection(connection) + self.cancelTimers(timers) + + case .shutdown(let cleanup): + for connection in cleanup.connections { + self.closeConnection(connection) + } + self.cancelTimers(cleanup.timersToCancel) + + case .none: + break + } + } + + @inlinable + /*private*/ func runRequestAction(_ action: StateMachine.RequestAction) { + switch action { + case .leaseConnection(let requests, let connection): + for request in requests { + request.complete(with: .success(connection)) + } + + case .failRequest(let request, let error): + request.complete(with: .failure(error)) + + case .failRequests(let requests, let error): + for request in requests { request.complete(with: .failure(error)) } + + case .none: + break + } + } + + @inlinable + /*private*/ func makeConnection(for request: StateMachine.ConnectionRequest, in taskGroup: inout some TaskGroupProtocol) { + taskGroup.addTask { + self.observabilityDelegate.startedConnecting(id: request.connectionID) + + do { + let bundle = try await self.factory(request.connectionID, self) + self.connectionEstablished(bundle) + bundle.connection.onClose { + self.connectionDidClose(bundle.connection, error: $0) + } + } catch { + self.connectionEstablishFailed(error, for: request) + } + } + } + + @inlinable + /*private*/ func connectionEstablished(_ connectionBundle: ConnectionAndMetadata) { + self.observabilityDelegate.connectSucceeded(id: connectionBundle.connection.id, streamCapacity: connectionBundle.maximalStreamsOnConnection) + + self.modifyStateAndRunActions { state in + state.lastConnectError = nil + return state.stateMachine.connectionEstablished( + connectionBundle.connection, + maxStreams: connectionBundle.maximalStreamsOnConnection + ) + } + } + + @inlinable + /*private*/ func connectionEstablishFailed(_ error: Error, for request: StateMachine.ConnectionRequest) { + self.observabilityDelegate.connectFailed(id: request.connectionID, error: error) + + self.modifyStateAndRunActions { state in + state.lastConnectError = error + return state.stateMachine.connectionEstablishFailed(error, for: request) + } + } + + @inlinable + /*private*/ func runKeepAlive(_ connection: Connection, in taskGroup: inout some TaskGroupProtocol) { + self.observabilityDelegate.keepAliveTriggered(id: connection.id) + + taskGroup.addTask { + do { + try await self.keepAliveBehavior.runKeepAlive(for: connection) + + self.observabilityDelegate.keepAliveSucceeded(id: connection.id) + + self.modifyStateAndRunActions { state in + state.stateMachine.connectionKeepAliveDone(connection) + } + } catch { + self.observabilityDelegate.keepAliveFailed(id: connection.id, error: error) + + self.modifyStateAndRunActions { state in + state.stateMachine.connectionClosed(connection) + } + } + } + } + + @inlinable + /*private*/ func closeConnection(_ connection: Connection) { + self.observabilityDelegate.connectionClosing(id: connection.id) + + connection.close() + } + + @usableFromInline + enum TimerRunResult { + case timerTriggered + case timerCancelled + case cancellationContinuationFinished + } + + @inlinable + /*private*/ func runTimer(_ timer: StateMachine.Timer, in poolGroup: inout some TaskGroupProtocol) { + poolGroup.addTask { () async -> () in + await withTaskGroup(of: TimerRunResult.self, returning: Void.self) { taskGroup in + taskGroup.addTask { + do { + #if swift(>=5.8) && os(Linux) || swift(>=5.9) + try await self.clock.sleep(for: timer.duration) + #else + try await self.clock.sleep(until: self.clock.now.advanced(by: timer.duration), tolerance: nil) + #endif + return .timerTriggered + } catch { + return .timerCancelled + } + } + + taskGroup.addTask { + await withCheckedContinuation { (continuation: CheckedContinuation) in + let continuation = self.stateBox.withLockedValue { state in + state.stateMachine.timerScheduled(timer, cancelContinuation: continuation) + } + + continuation?.resume(returning: ()) + } + + return .cancellationContinuationFinished + } + + switch await taskGroup.next()! { + case .cancellationContinuationFinished: + taskGroup.cancelAll() + + case .timerTriggered: + let action = self.stateBox.withLockedValue { state in + state.stateMachine.timerTriggered(timer) + } + + self.runStateMachineActions(action) + + case .timerCancelled: + // the only way to reach this, is if the state machine decided to cancel the + // timer. therefore we don't need to report it back! + break + } + + return + } + } + } + + @inlinable + /*private*/ func cancelTimers(_ cancellationTokens: some Sequence>) { + for token in cancellationTokens { + token.resume() + } + } +} + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +extension PoolConfiguration { + init(_ configuration: ConnectionPoolConfiguration, keepAliveBehavior: KeepAliveBehavior) { + self.minimumConnectionCount = configuration.minimumConnectionCount + self.maximumConnectionSoftLimit = configuration.maximumConnectionSoftLimit + self.maximumConnectionHardLimit = configuration.maximumConnectionHardLimit + self.keepAliveDuration = keepAliveBehavior.keepAliveFrequency + self.idleTimeoutDuration = configuration.idleTimeout + } +} + +#if swift(<5.9) +// This should be removed once we support Swift 5.9+ only +extension AsyncStream { + static func makeStream( + of elementType: Element.Type = Element.self, + bufferingPolicy limit: Continuation.BufferingPolicy = .unbounded + ) -> (stream: AsyncStream, continuation: AsyncStream.Continuation) { + var continuation: AsyncStream.Continuation! + let stream = AsyncStream(bufferingPolicy: limit) { continuation = $0 } + return (stream: stream, continuation: continuation!) + } +} +#endif + +@usableFromInline +protocol TaskGroupProtocol { + mutating func addTask(operation: @escaping @Sendable () async -> Void) +} + +#if swift(>=5.8) && os(Linux) || swift(>=5.9) +@available(macOS 14.0, iOS 17.0, tvOS 17.0, watchOS 9.0, *) +extension DiscardingTaskGroup: TaskGroupProtocol {} +#endif + +extension TaskGroup: TaskGroupProtocol { + @inlinable + mutating func addTask(operation: @escaping @Sendable () async -> Void) { + self.addTask(priority: nil, operation: operation) + } +} diff --git a/Sources/ConnectionPoolModule/ConnectionRequest.swift b/Sources/ConnectionPoolModule/ConnectionRequest.swift index fd01bb76..19ed9bd2 100644 --- a/Sources/ConnectionPoolModule/ConnectionRequest.swift +++ b/Sources/ConnectionPoolModule/ConnectionRequest.swift @@ -20,3 +20,56 @@ public struct ConnectionRequest: ConnectionRequest self.continuation.resume(with: result) } } + +fileprivate let requestIDGenerator = _ConnectionPoolModule.ConnectionIDGenerator() + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +extension ConnectionPool where Request == ConnectionRequest { + public convenience init( + configuration: ConnectionPoolConfiguration, + idGenerator: ConnectionIDGenerator = _ConnectionPoolModule.ConnectionIDGenerator(), + keepAliveBehavior: KeepAliveBehavior, + observabilityDelegate: ObservabilityDelegate, + clock: Clock = ContinuousClock(), + connectionFactory: @escaping ConnectionFactory + ) { + self.init( + configuration: configuration, + idGenerator: idGenerator, + requestType: ConnectionRequest.self, + keepAliveBehavior: keepAliveBehavior, + observabilityDelegate: observabilityDelegate, + clock: clock, + connectionFactory: connectionFactory + ) + } + + public func leaseConnection() async throws -> Connection { + let requestID = requestIDGenerator.next() + + let connection = try await withTaskCancellationHandler { + if Task.isCancelled { + throw CancellationError() + } + + return try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in + let request = Request( + id: requestID, + continuation: continuation + ) + + self.leaseConnection(request) + } + } onCancel: { + self.cancelLeaseConnection(requestID) + } + + return connection + } + + public func withConnection(_ closure: (Connection) async throws -> Result) async throws -> Result { + let connection = try await self.leaseConnection() + defer { self.releaseConnection(connection) } + return try await closure(connection) + } +} diff --git a/Sources/ConnectionPoolModule/NIOLockedValueBox.swift b/Sources/ConnectionPoolModule/NIOLockedValueBox.swift new file mode 100644 index 00000000..e5a3e6a2 --- /dev/null +++ b/Sources/ConnectionPoolModule/NIOLockedValueBox.swift @@ -0,0 +1,46 @@ +// Implementation vendored from SwiftNIO: +// https://github.com/apple/swift-nio + +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftNIO open source project +// +// Copyright (c) 2022 Apple Inc. and the SwiftNIO project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftNIO project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +/// Provides locked access to `Value`. +/// +/// - note: ``NIOLockedValueBox`` has reference semantics and holds the `Value` +/// alongside a lock behind a reference. +/// +/// This is no different than creating a ``Lock`` and protecting all +/// accesses to a value using the lock. But it's easy to forget to actually +/// acquire/release the lock in the correct place. ``NIOLockedValueBox`` makes +/// that much easier. +@usableFromInline +struct NIOLockedValueBox { + + @usableFromInline + internal let _storage: LockStorage + + /// Initialize the `Value`. + @inlinable + init(_ value: Value) { + self._storage = .create(value: value) + } + + /// Access the `Value`, allowing mutation of it. + @inlinable + func withLockedValue(_ mutate: (inout Value) throws -> T) rethrows -> T { + return try self._storage.withLockedValue(mutate) + } +} + +extension NIOLockedValueBox: Sendable where Value: Sendable {} diff --git a/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionGroup.swift b/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionGroup.swift index 16970599..e735d277 100644 --- a/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionGroup.swift +++ b/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionGroup.swift @@ -342,9 +342,9 @@ extension PoolStateMachine { /// Call ``leaseConnection(at:)`` or ``closeConnection(at:)`` with the supplied index after /// this. If you want to park the connection no further call is required. @inlinable - mutating func releaseConnection(_ connectionID: Connection.ID, streams: UInt16) -> (Int, AvailableConnectionContext) { + mutating func releaseConnection(_ connectionID: Connection.ID, streams: UInt16) -> (Int, AvailableConnectionContext)? { guard let index = self.connections.firstIndex(where: { $0.id == connectionID }) else { - preconditionFailure("A connection that we don't know was released? Something is very wrong...") + return nil } let connectionInfo = self.connections[index].release(streams: streams) @@ -657,3 +657,6 @@ extension PoolStateMachine { @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) extension PoolStateMachine.ConnectionGroup.BackoffDoneAction: Equatable where TimerCancellationToken: Equatable {} + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +extension PoolStateMachine.ConnectionGroup.ClosedAction: Equatable where TimerCancellationToken: Equatable {} diff --git a/Sources/ConnectionPoolModule/PoolStateMachine.swift b/Sources/ConnectionPoolModule/PoolStateMachine.swift index aa62d749..4cd78c0e 100644 --- a/Sources/ConnectionPoolModule/PoolStateMachine.swift +++ b/Sources/ConnectionPoolModule/PoolStateMachine.swift @@ -234,7 +234,9 @@ struct PoolStateMachine< @inlinable mutating func releaseConnection(_ connection: Connection, streams: UInt16) -> Action { - let (index, context) = self.connections.releaseConnection(connection.id, streams: streams) + guard let (index, context) = self.connections.releaseConnection(connection.id, streams: streams) else { + return .none() + } return self.handleAvailableConnection(index: index, availableContext: context) } @@ -251,8 +253,13 @@ struct PoolStateMachine< @inlinable mutating func connectionEstablished(_ connection: Connection, maxStreams: UInt16) -> Action { - let (index, context) = self.connections.newConnectionEstablished(connection, maxStreams: maxStreams) - return self.handleAvailableConnection(index: index, availableContext: context) + switch self.poolState { + case .running, .shuttingDown(graceful: true): + let (index, context) = self.connections.newConnectionEstablished(connection, maxStreams: maxStreams) + return self.handleAvailableConnection(index: index, availableContext: context) + case .shuttingDown(graceful: false), .shutDown: + return .init(request: .none, connection: .closeConnection(connection, [])) + } } @inlinable @@ -274,31 +281,43 @@ struct PoolStateMachine< @inlinable mutating func connectionEstablishFailed(_ error: Error, for request: ConnectionRequest) -> Action { - self.failedConsecutiveConnectionAttempts += 1 + switch self.poolState { + case .running, .shuttingDown(graceful: true): + self.failedConsecutiveConnectionAttempts += 1 - let connectionTimer = self.connections.backoffNextConnectionAttempt(request.connectionID) - let backoff = Self.calculateBackoff(failedAttempt: self.failedConsecutiveConnectionAttempts) - let timer = Timer(connectionTimer, duration: backoff) - return .init(request: .none, connection: .scheduleTimers(.init(timer))) + let connectionTimer = self.connections.backoffNextConnectionAttempt(request.connectionID) + let backoff = Self.calculateBackoff(failedAttempt: self.failedConsecutiveConnectionAttempts) + let timer = Timer(connectionTimer, duration: backoff) + return .init(request: .none, connection: .scheduleTimers(.init(timer))) + + case .shuttingDown(graceful: false), .shutDown: + return .none() + } } @inlinable mutating func connectionCreationBackoffDone(_ connectionID: ConnectionID) -> Action { - let soonAvailable = self.connections.soonAvailableConnections - let retry = (soonAvailable - 1) < self.requestQueue.count - - switch self.connections.backoffDone(connectionID, retry: retry) { - case .createConnection(let request, let continuation): - let timers: TinyFastSequence - if let continuation { - timers = .init(element: continuation) - } else { - timers = .init() + switch self.poolState { + case .running, .shuttingDown(graceful: true): + let soonAvailable = self.connections.soonAvailableConnections + let retry = (soonAvailable - 1) < self.requestQueue.count + + switch self.connections.backoffDone(connectionID, retry: retry) { + case .createConnection(let request, let continuation): + let timers: TinyFastSequence + if let continuation { + timers = .init(element: continuation) + } else { + timers = .init() + } + return .init(request: .none, connection: .makeConnection(request, timers)) + + case .cancelTimers(let timers): + return .init(request: .none, connection: .cancelTimers(.init(timers))) } - return .init(request: .none, connection: .makeConnection(request, timers)) - case .cancelTimers(let timers): - return .init(request: .none, connection: .cancelTimers(.init(timers))) + case .shuttingDown(graceful: false), .shutDown: + return .none() } } diff --git a/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift b/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift new file mode 100644 index 00000000..b27fff37 --- /dev/null +++ b/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift @@ -0,0 +1,189 @@ +@testable import _ConnectionPoolModule +import XCTest +import NIOEmbedded + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +final class ConnectionPoolTests: XCTestCase { + + func test1000ConsecutiveRequestsOnSingleConnection() async { + let factory = MockConnectionFactory() + + var config = ConnectionPoolConfiguration() + config.minimumConnectionCount = 1 + + let pool = ConnectionPool( + configuration: config, + idGenerator: ConnectionIDGenerator(), + requestType: ConnectionRequest.self, + keepAliveBehavior: MockPingPongBehavior(keepAliveFrequency: nil), + observabilityDelegate: NoOpConnectionPoolMetrics(connectionIDType: MockConnection.ID.self), + clock: ContinuousClock() + ) { + try await factory.makeConnection(id: $0, for: $1) + } + + // the same connection is reused 1000 times + + await withTaskGroup(of: Void.self) { taskGroup in + taskGroup.addTask { + await pool.run() + } + + let createdConnection = await factory.nextConnectAttempt { _ in + return 1 + } + XCTAssertNotNil(createdConnection) + + do { + for _ in 0..<1000 { + async let connectionFuture = try await pool.leaseConnection() + var leasedConnection: MockConnection? + XCTAssertEqual(factory.pendingConnectionAttemptsCount, 0) + leasedConnection = try await connectionFuture + XCTAssertNotNil(leasedConnection) + XCTAssert(createdConnection === leasedConnection) + + if let leasedConnection { + pool.releaseConnection(leasedConnection) + } + } + } catch { + XCTFail("Unexpected error: \(error)") + } + + taskGroup.cancelAll() + } + } + + func testShutdownPoolWhileConnectionIsBeingCreated() async { + let clock = MockClock() + let factory = MockConnectionFactory() + + var config = ConnectionPoolConfiguration() + config.minimumConnectionCount = 1 + + let pool = ConnectionPool( + configuration: config, + idGenerator: ConnectionIDGenerator(), + requestType: ConnectionRequest.self, + keepAliveBehavior: MockPingPongBehavior(keepAliveFrequency: nil), + observabilityDelegate: NoOpConnectionPoolMetrics(connectionIDType: MockConnection.ID.self), + clock: clock + ) { + try await factory.makeConnection(id: $0, for: $1) + } + + await withTaskGroup(of: Void.self) { taskGroup in + taskGroup.addTask { + await pool.run() + } + + let (blockCancelStream, blockCancelContinuation) = AsyncStream.makeStream(of: Void.self) + let (blockConnCreationStream, blockConnCreationContinuation) = AsyncStream.makeStream(of: Void.self) + + taskGroup.addTask { + _ = try? await factory.nextConnectAttempt { _ in + blockCancelContinuation.yield() + var iterator = blockConnCreationStream.makeAsyncIterator() + await iterator.next() + throw ConnectionCreationError() + } + } + + var iterator = blockCancelStream.makeAsyncIterator() + await iterator.next() + + taskGroup.cancelAll() + blockConnCreationContinuation.yield() + } + + struct ConnectionCreationError: Error {} + } + + func testShutdownPoolWhileConnectionIsBackingOff() async { + let clock = MockClock() + let factory = MockConnectionFactory() + + var config = ConnectionPoolConfiguration() + config.minimumConnectionCount = 1 + + let pool = ConnectionPool( + configuration: config, + idGenerator: ConnectionIDGenerator(), + requestType: ConnectionRequest.self, + keepAliveBehavior: MockPingPongBehavior(keepAliveFrequency: nil), + observabilityDelegate: NoOpConnectionPoolMetrics(connectionIDType: MockConnection.ID.self), + clock: clock + ) { + try await factory.makeConnection(id: $0, for: $1) + } + + await withTaskGroup(of: Void.self) { taskGroup in + taskGroup.addTask { + await pool.run() + } + + _ = try? await factory.nextConnectAttempt { _ in + throw ConnectionCreationError() + } + + await clock.timerScheduled() + + taskGroup.cancelAll() + } + + struct ConnectionCreationError: Error {} + } + + func testConnectionHardLimitIsRespected() async { + let factory = MockConnectionFactory() + + var mutableConfig = ConnectionPoolConfiguration() + mutableConfig.minimumConnectionCount = 0 + mutableConfig.maximumConnectionSoftLimit = 4 + mutableConfig.maximumConnectionHardLimit = 8 + let config = mutableConfig + + let pool = ConnectionPool( + configuration: config, + idGenerator: ConnectionIDGenerator(), + requestType: ConnectionRequest.self, + keepAliveBehavior: MockPingPongBehavior(keepAliveFrequency: nil), + observabilityDelegate: NoOpConnectionPoolMetrics(connectionIDType: MockConnection.ID.self), + clock: ContinuousClock() + ) { + try await factory.makeConnection(id: $0, for: $1) + } + + // the same connection is reused 1000 times + + await withThrowingTaskGroup(of: Void.self) { taskGroup in + taskGroup.addTask { + await pool.run() + } + + taskGroup.addTask { + var usedConnectionIDs = Set() + for _ in 0.. Self { + .init(self.base + duration) + } + + func duration(to other: Self) -> Self.Duration { + self.base - other.base + } + + private var base: Swift.Duration + + init(_ base: Duration) { + self.base = base + } + + static func < (lhs: Self, rhs: Self) -> Bool { + lhs.base < rhs.base + } + + static func == (lhs: Self, rhs: Self) -> Bool { + lhs.base == rhs.base + } + } + + private struct State: Sendable { + var now: Instant + + var sleepersHeap: Array + + var waitersHeap: Array + + init() { + self.now = .init(.seconds(0)) + self.sleepersHeap = Array() + self.waitersHeap = Array() + } + } + + private struct Waiter { + var expectedSleepers: Int + + var continuation: CheckedContinuation + } + + private struct Sleeper { + var id: Int + + var deadline: Instant + + var continuation: CheckedContinuation + } + + typealias Duration = Swift.Duration + + var minimumResolution: Duration { .nanoseconds(1) } + + var now: Instant { self.stateBox.withLockedValue { $0.now } } + + private let stateBox = NIOLockedValueBox(State()) + private let waiterIDGenerator = ManagedAtomic(0) + + func sleep(until deadline: Instant, tolerance: Duration?) async throws { + let waiterID = self.waiterIDGenerator.loadThenWrappingIncrement(ordering: .relaxed) + + return try await withTaskCancellationHandler { + try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in + enum SleepAction { + case none + case resume + case cancel + } + + let action = self.stateBox.withLockedValue { state -> (SleepAction, ArraySlice) in + state.waitersHeap = state.waitersHeap.map { waiter in + var waiter = waiter; waiter.expectedSleepers -= 1; return waiter + } + let slice: ArraySlice + let lastRemainingIndex = state.waitersHeap.firstIndex(where: { $0.expectedSleepers > 0 }) + if let lastRemainingIndex { + slice = state.waitersHeap[0..= deadline { + return (.resume, slice) + } + + let newWaiter = Sleeper(id: waiterID, deadline: deadline, continuation: continuation) + + if let index = state.sleepersHeap.lastIndex(where: { $0.deadline < deadline }) { + state.sleepersHeap.insert(newWaiter, at: index + 1) + } else { + state.sleepersHeap.append(newWaiter) + } + + return (.none, slice) + } + + switch action.0 { + case .cancel: + continuation.resume(throwing: CancellationError()) + case .resume: + continuation.resume() + case .none: + break + } + + for waiter in action.1 { + waiter.continuation.resume() + } + } + } onCancel: { + let continuation = self.stateBox.withLockedValue { state -> CheckedContinuation? in + if let index = state.sleepersHeap.firstIndex(where: { $0.id == waiterID }) { + return state.sleepersHeap.remove(at: index).continuation + } + return nil + } + continuation?.resume(throwing: CancellationError()) + } + } + + func timerScheduled(n: Int = 1) async { + precondition(n >= 1, "At least one new sleep must be awaited") + await withCheckedContinuation { (continuation: CheckedContinuation<(), Never>) in + let result = self.stateBox.withLockedValue { state -> Bool in + let n = n - state.sleepersHeap.count + + if n <= 0 { + return true + } + + let waiter = Waiter(expectedSleepers: n, continuation: continuation) + + if let index = state.waitersHeap.firstIndex(where: { $0.expectedSleepers > n }) { + state.waitersHeap.insert(waiter, at: index) + } else { + state.waitersHeap.append(waiter) + } + return false + } + + if result { + continuation.resume() + } + } + } + + func advance(to deadline: Instant) { + let waiters = self.stateBox.withLockedValue { state -> ArraySlice in + precondition(deadline > state.now, "Time can only move forward") + state.now = deadline + + if let newFirstIndex = state.sleepersHeap.firstIndex(where: { $0.deadline > deadline }) { + defer { state.sleepersHeap.removeFirst(newFirstIndex) } + return state.sleepersHeap[0.. where Clock.Duration == Duration { + typealias ConnectionIDGenerator = _ConnectionPoolModule.ConnectionIDGenerator + typealias Request = ConnectionRequest + typealias KeepAliveBehavior = MockPingPongBehavior + typealias MetricsDelegate = NoOpConnectionPoolMetrics + typealias ConnectionID = Int + typealias Connection = MockConnection + + let stateBox = NIOLockedValueBox(State()) + + struct State { + var attempts = Deque<(ConnectionID, CheckedContinuation<(MockConnection, UInt16), any Error>)>() + + var waiter = Deque), Never>>() + } + + var pendingConnectionAttemptsCount: Int { + self.stateBox.withLockedValue { $0.attempts.count } + } + + func makeConnection( + id: Int, + for pool: ConnectionPool, Int, MockPingPongBehavior, NoOpConnectionPoolMetrics, Clock> + ) async throws -> ConnectionAndMetadata { + // we currently don't support cancellation when creating a connection + let result = try await withCheckedThrowingContinuation { (checkedContinuation: CheckedContinuation<(MockConnection, UInt16), any Error>) in + let waiter = self.stateBox.withLockedValue { state -> (CheckedContinuation<(ConnectionID, CheckedContinuation<(MockConnection, UInt16), any Error>), Never>)? in + if let waiter = state.waiter.popFirst() { + return waiter + } else { + state.attempts.append((id, checkedContinuation)) + return nil + } + } + + if let waiter { + waiter.resume(returning: (id, checkedContinuation)) + } + } + + return .init(connection: result.0, maximalStreamsOnConnection: result.1) + } + + @discardableResult + func nextConnectAttempt(_ closure: (ConnectionID) async throws -> UInt16) async rethrows -> Connection { + let (connectionID, continuation) = await withCheckedContinuation { (continuation: CheckedContinuation<(ConnectionID, CheckedContinuation<(MockConnection, UInt16), any Error>), Never>) in + let attempt = self.stateBox.withLockedValue { state -> (ConnectionID, CheckedContinuation<(MockConnection, UInt16), any Error>)? in + if let attempt = state.attempts.popFirst() { + return attempt + } else { + state.waiter.append(continuation) + return nil + } + } + + if let attempt { + continuation.resume(returning: attempt) + } + } + + do { + let streamCount = try await closure(connectionID) + let connection = MockConnection(id: connectionID) + continuation.resume(returning: (connection, streamCount)) + return connection + } catch { + continuation.resume(throwing: error) + throw error + } + } +} diff --git a/Tests/ConnectionPoolModuleTests/Mocks/MockPingPongBehaviour.swift b/Tests/ConnectionPoolModuleTests/Mocks/MockPingPongBehaviour.swift new file mode 100644 index 00000000..2ee9b7a0 --- /dev/null +++ b/Tests/ConnectionPoolModuleTests/Mocks/MockPingPongBehaviour.swift @@ -0,0 +1,14 @@ +import _ConnectionPoolModule + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +struct MockPingPongBehavior: ConnectionKeepAliveBehavior { + let keepAliveFrequency: Duration? + + init(keepAliveFrequency: Duration?) { + self.keepAliveFrequency = keepAliveFrequency + } + + func runKeepAlive(for connection: MockConnection) async throws { + preconditionFailure() + } +} diff --git a/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionGroupTests.swift b/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionGroupTests.swift index bf385918..99b73fd0 100644 --- a/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionGroupTests.swift +++ b/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionGroupTests.swift @@ -88,7 +88,9 @@ final class PoolStateMachine_ConnectionGroupTests: XCTestCase { XCTAssert(newConnection === leaseResult.connection) XCTAssertEqual(connections.stats, .init(leased: 1, leasedStreams: 1)) - let (index, releasedContext) = connections.releaseConnection(leaseResult.connection.id, streams: 1) + guard let (index, releasedContext) = connections.releaseConnection(leaseResult.connection.id, streams: 1) else { + return XCTFail("Expected that this connection is still active") + } XCTAssertEqual(releasedContext.info, .idle(availableStreams: 1, newIdle: true)) XCTAssertEqual(releasedContext.use, .demand) XCTAssertEqual(connections.stats, .init(idle: 1, availableStreams: 1)) diff --git a/Tests/ConnectionPoolModuleTests/PoolStateMachineTests.swift b/Tests/ConnectionPoolModuleTests/PoolStateMachineTests.swift index 0f3af728..a19d2326 100644 --- a/Tests/ConnectionPoolModuleTests/PoolStateMachineTests.swift +++ b/Tests/ConnectionPoolModuleTests/PoolStateMachineTests.swift @@ -224,4 +224,46 @@ final class PoolStateMachineTests: XCTestCase { // connection 1 is dropped XCTAssertEqual(stateMachine.connectionClosed(connection1), .init(request: .none, connection: .cancelTimers([connection2IdleTimerCancellationToken]))) } + + func testReleaseLoosesRaceAgainstClosed() { + var configuration = PoolConfiguration() + configuration.minimumConnectionCount = 0 + configuration.maximumConnectionSoftLimit = 2 + configuration.maximumConnectionHardLimit = 2 + configuration.keepAliveDuration = nil + configuration.idleTimeoutDuration = .seconds(3) + + var stateMachine = TestPoolStateMachine( + configuration: configuration, + generator: .init(), + timerCancellationTokenType: MockTimerCancellationToken.self + ) + + // don't refill pool + let requests = stateMachine.refillConnections() + XCTAssertEqual(requests.count, 0) + + // request connection while none exists + let request1 = MockRequest() + let leaseRequest1 = stateMachine.leaseConnection(request1) + XCTAssertEqual(leaseRequest1.connection, .makeConnection(.init(connectionID: 0), [])) + XCTAssertEqual(leaseRequest1.request, .none) + + // make connection 1 and lease immediately + let connection1 = MockConnection(id: 0) + let createdAction1 = stateMachine.connectionEstablished(connection1, maxStreams: 1) + XCTAssertEqual(createdAction1.request, .leaseConnection(.init(element: request1), connection1)) + XCTAssertEqual(createdAction1.connection, .none) + + // connection got closed + let closedAction = stateMachine.connectionClosed(connection1) + XCTAssertEqual(closedAction.connection, .none) + XCTAssertEqual(closedAction.request, .none) + + // release connection 1 should be leased again immediately + let releaseRequest1 = stateMachine.releaseConnection(connection1, streams: 1) + XCTAssertEqual(releaseRequest1.request, .none) + XCTAssertEqual(releaseRequest1.connection, .none) + } + } diff --git a/Tests/ConnectionPoolModuleTests/TinyFastSequence.swift b/Tests/ConnectionPoolModuleTests/TinyFastSequenceTests.swift similarity index 97% rename from Tests/ConnectionPoolModuleTests/TinyFastSequence.swift rename to Tests/ConnectionPoolModuleTests/TinyFastSequenceTests.swift index b3f8179d..1a2836b9 100644 --- a/Tests/ConnectionPoolModuleTests/TinyFastSequence.swift +++ b/Tests/ConnectionPoolModuleTests/TinyFastSequenceTests.swift @@ -1,7 +1,7 @@ @testable import _ConnectionPoolModule import XCTest -final class OneElementFastSequenceTests: XCTestCase { +final class TinyFastSequenceTests: XCTestCase { func testCountIsEmptyAndIterator() async { var sequence = TinyFastSequence() XCTAssertEqual(sequence.count, 0) From add68a0aed8d794a5608318452495621d038b255 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Sat, 28 Oct 2023 15:23:47 +0200 Subject: [PATCH 044/106] Ensure pool runs until all connections are closed (#429) - Ensure pool runs until all connections are closed - Fix an ordering issue in `RequestQueue` - Remove unused `closeConnection` in NewPoolActions --- .../ConnectionPoolModule/ConnectionPool.swift | 15 ++++--- .../PoolStateMachine+RequestQueue.swift | 2 +- .../PoolStateMachine.swift | 24 ++++++---- .../ConnectionPoolTests.swift | 44 +++++++++++++++++-- .../Mocks/MockConnection.swift | 17 +++++++ 5 files changed, 82 insertions(+), 20 deletions(-) diff --git a/Sources/ConnectionPoolModule/ConnectionPool.swift b/Sources/ConnectionPoolModule/ConnectionPool.swift index 5571e617..e9c9c4c9 100644 --- a/Sources/ConnectionPoolModule/ConnectionPool.swift +++ b/Sources/ConnectionPoolModule/ConnectionPool.swift @@ -306,7 +306,6 @@ public final class ConnectionPool< @usableFromInline enum NewPoolActions: Sendable { case makeConnection(StateMachine.ConnectionRequest) - case closeConnection(Connection) case runKeepAlive(Connection) case scheduleTimer(StateMachine.Timer) @@ -342,9 +341,6 @@ public final class ConnectionPool< case .runKeepAlive(let connection): self.runKeepAlive(connection, in: &taskGroup) - case .closeConnection(let connection): - self.closeConnection(connection) - case .scheduleTimer(let timer): self.runTimer(timer, in: &taskGroup) } @@ -427,8 +423,15 @@ public final class ConnectionPool< do { let bundle = try await self.factory(request.connectionID, self) self.connectionEstablished(bundle) - bundle.connection.onClose { - self.connectionDidClose(bundle.connection, error: $0) + + // after the connection has been established, we keep the task open. This ensures + // that the pools run method can not be exited before all connections have been + // closed. + await withCheckedContinuation { (continuation: CheckedContinuation) in + bundle.connection.onClose { + self.connectionDidClose(bundle.connection, error: $0) + continuation.resume() + } } } catch { self.connectionEstablishFailed(error, for: request) diff --git a/Sources/ConnectionPoolModule/PoolStateMachine+RequestQueue.swift b/Sources/ConnectionPoolModule/PoolStateMachine+RequestQueue.swift index f1d6f4e4..99ec4896 100644 --- a/Sources/ConnectionPoolModule/PoolStateMachine+RequestQueue.swift +++ b/Sources/ConnectionPoolModule/PoolStateMachine+RequestQueue.swift @@ -44,7 +44,7 @@ extension PoolStateMachine { var result = TinyFastSequence() result.reserveCapacity(Int(max)) var popped = 0 - while let requestID = self.queue.popFirst(), popped < max { + while popped < max, let requestID = self.queue.popFirst() { if let requestIndex = self.requests.index(forKey: requestID) { popped += 1 result.append(self.requests.remove(at: requestIndex).value) diff --git a/Sources/ConnectionPoolModule/PoolStateMachine.swift b/Sources/ConnectionPoolModule/PoolStateMachine.swift index 4cd78c0e..4b3680a1 100644 --- a/Sources/ConnectionPoolModule/PoolStateMachine.swift +++ b/Sources/ConnectionPoolModule/PoolStateMachine.swift @@ -355,18 +355,24 @@ struct PoolStateMachine< @inlinable mutating func connectionClosed(_ connection: Connection) -> Action { - self.cacheNoMoreConnectionsAllowed = false + switch self.poolState { + case .running, .shuttingDown(graceful: true): + self.cacheNoMoreConnectionsAllowed = false - let closedConnectionAction = self.connections.connectionClosed(connection.id) + let closedConnectionAction = self.connections.connectionClosed(connection.id) - let connectionAction: ConnectionAction - if let newRequest = closedConnectionAction.newConnectionRequest { - connectionAction = .makeConnection(newRequest, closedConnectionAction.timersToCancel) - } else { - connectionAction = .cancelTimers(closedConnectionAction.timersToCancel) - } + let connectionAction: ConnectionAction + if let newRequest = closedConnectionAction.newConnectionRequest { + connectionAction = .makeConnection(newRequest, closedConnectionAction.timersToCancel) + } else { + connectionAction = .cancelTimers(closedConnectionAction.timersToCancel) + } + + return .init(request: .none, connection: connectionAction) - return .init(request: .none, connection: connectionAction) + case .shuttingDown(graceful: false), .shutDown: + return .none() + } } struct CleanupAction { diff --git a/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift b/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift index b27fff37..5be12a1c 100644 --- a/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift +++ b/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift @@ -1,4 +1,5 @@ @testable import _ConnectionPoolModule +import Atomics import XCTest import NIOEmbedded @@ -52,7 +53,14 @@ final class ConnectionPoolTests: XCTestCase { } taskGroup.cancelAll() + + XCTAssertEqual(factory.pendingConnectionAttemptsCount, 0) + for connection in factory.runningConnections { + connection.closeIfClosing() + } } + + XCTAssertEqual(factory.runningConnections.count, 0) } func testShutdownPoolWhileConnectionIsBeingCreated() async { @@ -155,11 +163,16 @@ final class ConnectionPoolTests: XCTestCase { try await factory.makeConnection(id: $0, for: $1) } + let hasFinished = ManagedAtomic(false) + let createdConnections = ManagedAtomic(0) + let iterations = 10_000 + // the same connection is reused 1000 times - await withThrowingTaskGroup(of: Void.self) { taskGroup in + await withTaskGroup(of: Void.self) { taskGroup in taskGroup.addTask { await pool.run() + XCTAssertFalse(hasFinished.compareExchange(expected: false, desired: true, ordering: .relaxed).original) } taskGroup.addTask { @@ -167,22 +180,45 @@ final class ConnectionPoolTests: XCTestCase { for _ in 0.. where Clock.Duratio var attempts = Deque<(ConnectionID, CheckedContinuation<(MockConnection, UInt16), any Error>)>() var waiter = Deque), Never>>() + + var runningConnections = [ConnectionID: Connection]() } var pendingConnectionAttemptsCount: Int { self.stateBox.withLockedValue { $0.attempts.count } } + var runningConnections: [Connection] { + self.stateBox.withLockedValue { Array($0.runningConnections.values) } + } + func makeConnection( id: Int, for pool: ConnectionPool, Int, MockPingPongBehavior, NoOpConnectionPoolMetrics, Clock> @@ -137,6 +143,17 @@ final class MockConnectionFactory where Clock.Duratio do { let streamCount = try await closure(connectionID) let connection = MockConnection(id: connectionID) + + connection.onClose { _ in + self.stateBox.withLockedValue { state in + _ = state.runningConnections.removeValue(forKey: connectionID) + } + } + + self.stateBox.withLockedValue { state in + _ = state.runningConnections[connectionID] = connection + } + continuation.resume(returning: (connection, streamCount)) return connection } catch { From 2905779f4a0ccf7fa59e1e8e951b7a1c31e689e3 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Mon, 30 Oct 2023 11:01:48 +0100 Subject: [PATCH 045/106] Land PostgresClient that is backed by a ConnectionPool as SPI (#430) --- .../PoolStateMachine+ConnectionGroup.swift | 9 +- .../PoolStateMachine.swift | 4 +- .../Connection/PostgresConnection.swift | 4 +- .../ConnectionStateMachine.swift | 2 +- Sources/PostgresNIO/New/PSQLError.swift | 16 +- .../PostgresNIO/Pool/ConnectionFactory.swift | 206 ++++++++++ Sources/PostgresNIO/Pool/PostgresClient.swift | 378 ++++++++++++++++++ .../Pool/PostgresClientMetrics.swift | 85 ++++ Sources/PostgresNIO/Postgres+PSQLCompat.swift | 2 + ...oolStateMachine+ConnectionGroupTests.swift | 6 +- .../PostgresClientTests.swift | 66 +++ 11 files changed, 764 insertions(+), 14 deletions(-) create mode 100644 Sources/PostgresNIO/Pool/ConnectionFactory.swift create mode 100644 Sources/PostgresNIO/Pool/PostgresClient.swift create mode 100644 Sources/PostgresNIO/Pool/PostgresClientMetrics.swift create mode 100644 Tests/IntegrationTests/PostgresClientTests.swift diff --git a/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionGroup.swift b/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionGroup.swift index e735d277..b53f8d68 100644 --- a/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionGroup.swift +++ b/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionGroup.swift @@ -308,7 +308,7 @@ extension PoolStateMachine { } @inlinable - mutating func parkConnection(at index: Int) -> Max2Sequence { + mutating func parkConnection(at index: Int, hasBecomeIdle newIdle: Bool) -> Max2Sequence { let scheduleIdleTimeoutTimer: Bool switch index { case 0.. + + struct SSLContextCache: Sendable { + enum State { + case none + case producing(TLSConfiguration, [CheckedContinuation]) + case cached(TLSConfiguration, NIOSSLContext) + case failed(TLSConfiguration, any Error) + } + + var state: State = .none + } + + let sslContextBox = NIOLockedValueBox(SSLContextCache()) + + let eventLoopGroup: any EventLoopGroup + + let logger: Logger + + init(config: PostgresClient.Configuration, eventLoopGroup: any EventLoopGroup, logger: Logger) { + self.eventLoopGroup = eventLoopGroup + self.configBox = NIOLockedValueBox(ConfigCache(config: config)) + self.logger = logger + } + + func makeConnection(_ connectionID: PostgresConnection.ID, pool: PostgresClient.Pool) async throws -> PostgresConnection { + let config = try await self.makeConnectionConfig() + + var connectionLogger = self.logger + connectionLogger[postgresMetadataKey: .connectionID] = "\(connectionID)" + + return try await PostgresConnection.connect( + on: self.eventLoopGroup.any(), + configuration: config, + id: connectionID, + logger: connectionLogger + ).get() + } + + func makeConnectionConfig() async throws -> PostgresConnection.Configuration { + let config = self.configBox.withLockedValue { $0.config } + + let tls: PostgresConnection.Configuration.TLS + switch config.tls.base { + case .prefer(let tlsConfiguration): + let sslContext = try await self.getSSLContext(for: tlsConfiguration) + tls = .prefer(sslContext) + + case .require(let tlsConfiguration): + let sslContext = try await self.getSSLContext(for: tlsConfiguration) + tls = .require(sslContext) + case .disable: + tls = .disable + } + + var connectionConfig: PostgresConnection.Configuration + switch config.endpointInfo { + case .bindUnixDomainSocket(let path): + connectionConfig = PostgresConnection.Configuration( + unixSocketPath: path, + username: config.username, + password: config.password, + database: config.database + ) + + case .connectTCP(let host, let port): + connectionConfig = PostgresConnection.Configuration( + host: host, + port: port, + username: config.username, + password: config.password, + database: config.database, + tls: tls + ) + } + + connectionConfig.options.connectTimeout = TimeAmount(config.options.connectTimeout) + connectionConfig.options.tlsServerName = config.options.tlsServerName + connectionConfig.options.requireBackendKeyData = config.options.requireBackendKeyData + + return connectionConfig + } + + private func getSSLContext(for tlsConfiguration: TLSConfiguration) async throws -> NIOSSLContext { + enum Action { + case produce + case succeed(NIOSSLContext) + case fail(any Error) + case wait + } + + return try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in + let action = self.sslContextBox.withLockedValue { cache -> Action in + switch cache.state { + case .none: + cache.state = .producing(tlsConfiguration, [continuation]) + return .produce + + case .cached(let cachedTLSConfiguration, let context): + if cachedTLSConfiguration.bestEffortEquals(tlsConfiguration) { + return .succeed(context) + } else { + cache.state = .producing(tlsConfiguration, [continuation]) + return .produce + } + + case .failed(let cachedTLSConfiguration, let error): + if cachedTLSConfiguration.bestEffortEquals(tlsConfiguration) { + return .fail(error) + } else { + cache.state = .producing(tlsConfiguration, [continuation]) + return .produce + } + + case .producing(let cachedTLSConfiguration, var continuations): + continuations.append(continuation) + if cachedTLSConfiguration.bestEffortEquals(tlsConfiguration) { + cache.state = .producing(cachedTLSConfiguration, continuations) + return .wait + } else { + cache.state = .producing(tlsConfiguration, continuations) + return .produce + } + } + } + + switch action { + case .wait: + break + + case .produce: + // TBD: we might want to consider moving this off the concurrent executor + self.reportProduceSSLContextResult( + Result(catching: {try NIOSSLContext(configuration: tlsConfiguration)}), + for: tlsConfiguration + ) + + case .succeed(let context): + continuation.resume(returning: context) + + case .fail(let error): + continuation.resume(throwing: error) + } + } + } + + private func reportProduceSSLContextResult(_ result: Result, for tlsConfiguration: TLSConfiguration) { + enum Action { + case fail(any Error, [CheckedContinuation]) + case succeed(NIOSSLContext, [CheckedContinuation]) + case none + } + + let action = self.sslContextBox.withLockedValue { cache -> Action in + switch cache.state { + case .none: + preconditionFailure("Invalid state: \(cache.state)") + + case .cached, .failed: + return .none + + case .producing(let cachedTLSConfiguration, let continuations): + if cachedTLSConfiguration.bestEffortEquals(tlsConfiguration) { + switch result { + case .success(let context): + cache.state = .cached(cachedTLSConfiguration, context) + return .succeed(context, continuations) + + case .failure(let failure): + cache.state = .failed(cachedTLSConfiguration, failure) + return .fail(failure, continuations) + } + } else { + return .none + } + } + } + + switch action { + case .none: + break + + case .succeed(let context, let continuations): + for continuation in continuations { + continuation.resume(returning: context) + } + + case .fail(let error, let continuations): + for continuation in continuations { + continuation.resume(throwing: error) + } + } + } +} diff --git a/Sources/PostgresNIO/Pool/PostgresClient.swift b/Sources/PostgresNIO/Pool/PostgresClient.swift new file mode 100644 index 00000000..fc5a5b00 --- /dev/null +++ b/Sources/PostgresNIO/Pool/PostgresClient.swift @@ -0,0 +1,378 @@ +import NIOCore +import NIOSSL +import Atomics +import Logging +import _ConnectionPoolModule + +/// A Postgres client that is backed by an underlying connection pool. Use ``Configuration`` to change the client's +/// behavior. +/// +/// > Important: +/// The client can only lease connections if the user is running the client's ``run()`` method in a long running task: +/// +/// ```swift +/// let client = PostgresClient(configuration: configuration, logger: logger) +/// await withTaskGroup(of: Void.self) { +/// taskGroup.addTask { +/// client.run() // !important +/// } +/// +/// taskGroup.addTask { +/// client.withConnection { connection in +/// do { +/// let rows = try await connection.query("SELECT userID, name, age FROM users;") +/// for try await (userID, name, age) in rows.decode((UUID, String, Int).self) { +/// // do something with the values +/// } +/// } catch { +/// // handle errors +/// } +/// } +/// } +/// } +/// ``` +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +@_spi(ConnectionPool) +public final class PostgresClient: Sendable { + public struct Configuration: Sendable { + public struct TLS: Sendable { + enum Base { + case disable + case prefer(NIOSSL.TLSConfiguration) + case require(NIOSSL.TLSConfiguration) + } + + var base: Base + + private init(_ base: Base) { + self.base = base + } + + /// Do not try to create a TLS connection to the server. + public static var disable: Self = Self.init(.disable) + + /// Try to create a TLS connection to the server. If the server supports TLS, create a TLS connection. + /// If the server does not support TLS, create an insecure connection. + public static func prefer(_ sslContext: NIOSSL.TLSConfiguration) -> Self { + self.init(.prefer(sslContext)) + } + + /// Try to create a TLS connection to the server. If the server supports TLS, create a TLS connection. + /// If the server does not support TLS, fail the connection creation. + public static func require(_ sslContext: NIOSSL.TLSConfiguration) -> Self { + self.init(.require(sslContext)) + } + } + + // MARK: Client options + + /// Describes general client behavior options. Those settings are considered advanced options. + public struct Options: Sendable { + /// A keep-alive behavior for Postgres connections. The ``frequency`` defines after which time an idle + /// connection shall run a keep-alive ``query``. + public struct KeepAliveBehavior: Sendable { + /// The amount of time that shall pass before an idle connection runs a keep-alive ``query``. + public var frequency: Duration + + /// The ``query`` that is run on an idle connection after it has been idle for ``frequency``. + public var query: PostgresQuery + + /// Create a new `KeepAliveBehavior`. + /// - Parameters: + /// - frequency: The amount of time that shall pass before an idle connection runs a keep-alive `query`. + /// Defaults to `30` seconds. + /// - query: The `query` that is run on an idle connection after it has been idle for `frequency`. + /// Defaults to `SELECT 1;`. + public init(frequency: Duration = .seconds(30), query: PostgresQuery = "SELECT 1;") { + self.frequency = frequency + self.query = query + } + } + + /// A timeout for creating a TCP/Unix domain socket connection. Defaults to `10` seconds. + public var connectTimeout: Duration = .seconds(10) + + /// The server name to use for certificate validation and SNI (Server Name Indication) when TLS is enabled. + /// Defaults to none (but see below). + /// + /// > When set to `nil`: + /// If the connection is made to a server over TCP using + /// ``PostgresConnection/Configuration/init(host:port:username:password:database:tls:)``, the given `host` + /// is used, unless it was an IP address string. If it _was_ an IP, or the connection is made by any other + /// method, SNI is disabled. + public var tlsServerName: String? = nil + + /// Whether the connection is required to provide backend key data (internal Postgres stuff). + /// + /// This property is provided for compatibility with Amazon RDS Proxy, which requires it to be `false`. + /// If you are not using Amazon RDS Proxy, you should leave this set to `true` (the default). + public var requireBackendKeyData: Bool = true + + /// The minimum number of connections that the client shall keep open at any time, even if there is no + /// demand. Default to `0`. + /// + /// If the open connection count becomes less than ``minimumConnections`` new connections + /// are created immidiatly. Must be greater or equal to zero and less than ``maximumConnections``. + /// + /// Idle connections are kept alive using the ``keepAliveBehavior``. + public var minimumConnections: Int = 0 + + /// The maximum number of connections that the client may open to the server at any time. Must be greater + /// than ``minimumConnections``. Defaults to `20` connections. + /// + /// Connections, that are created in response to demand are kept alive for the ``connectionIdleTimeout`` + /// before they are dropped. + public var maximumConnections: Int = 20 + + /// The maximum amount time that a connection that is not part of the ``minimumConnections`` is kept + /// open without being leased. Defaults to `60` seconds. + public var connectionIdleTimeout: Duration = .seconds(60) + + /// The ``KeepAliveBehavior-swift.struct`` to ensure that the underlying tcp-connection is still active + /// for idle connections. `Nil` means that the client shall not run keep alive queries to the server. Defaults to a + /// keep alive query of `SELECT 1;` every `30` seconds. + public var keepAliveBehavior: KeepAliveBehavior? = KeepAliveBehavior() + + /// Create an options structure with default values. + /// + /// Most users should not need to adjust the defaults. + public init() {} + } + + // MARK: - Accessors + + /// The hostname to connect to for TCP configurations. + /// + /// Always `nil` for other configurations. + public var host: String? { + if case let .connectTCP(host, _) = self.endpointInfo { return host } + else { return nil } + } + + /// The port to connect to for TCP configurations. + /// + /// Always `nil` for other configurations. + public var port: Int? { + if case let .connectTCP(_, port) = self.endpointInfo { return port } + else { return nil } + } + + /// The socket path to connect to for Unix domain socket connections. + /// + /// Always `nil` for other configurations. + public var unixSocketPath: String? { + if case let .bindUnixDomainSocket(path) = self.endpointInfo { return path } + else { return nil } + } + + /// The TLS mode to use for the connection. Valid for all configurations. + /// + /// See ``TLS-swift.struct``. + public var tls: TLS = .prefer(.makeClientConfiguration()) + + /// Options for handling the communication channel. Most users don't need to change these. + /// + /// See ``Options-swift.struct``. + public var options: Options = .init() + + /// The username to connect with. + public var username: String + + /// The password, if any, for the user specified by ``username``. + /// + /// - Warning: `nil` means "no password provided", whereas `""` (the empty string) is a password of zero + /// length; these are not the same thing. + public var password: String? + + /// The name of the database to open. + /// + /// - Note: If set to `nil` or an empty string, the provided ``username`` is used. + public var database: String? + + // MARK: - Initializers + + /// Create a configuration for connecting to a server with a hostname and optional port. + /// + /// This specifies a TCP connection. If you're unsure which kind of connection you want, you almost + /// definitely want this one. + /// + /// - Parameters: + /// - host: The hostname to connect to. + /// - port: The TCP port to connect to (defaults to 5432). + /// - tls: The TLS mode to use. + public init(host: String, port: Int = 5432, username: String, password: String?, database: String?, tls: TLS) { + self.init(endpointInfo: .connectTCP(host: host, port: port), tls: tls, username: username, password: password, database: database) + } + + /// Create a configuration for connecting to a server through a UNIX domain socket. + /// + /// - Parameters: + /// - path: The filesystem path of the socket to connect to. + /// - tls: The TLS mode to use. Defaults to ``TLS-swift.struct/disable``. + public init(unixSocketPath: String, username: String, password: String?, database: String?) { + self.init(endpointInfo: .bindUnixDomainSocket(path: unixSocketPath), tls: .disable, username: username, password: password, database: database) + } + + // MARK: - Implementation details + + enum EndpointInfo { + case bindUnixDomainSocket(path: String) + case connectTCP(host: String, port: Int) + } + + var endpointInfo: EndpointInfo + + init(endpointInfo: EndpointInfo, tls: TLS, username: String, password: String?, database: String?) { + self.endpointInfo = endpointInfo + self.tls = tls + self.username = username + self.password = password + self.database = database + } + } + + typealias Pool = ConnectionPool< + PostgresConnection, + PostgresConnection.ID, + ConnectionIDGenerator, + ConnectionRequest, + ConnectionRequest.ID, + PostgresKeepAliveBehavor, + PostgresClientMetrics, + ContinuousClock + > + + let pool: Pool + let factory: ConnectionFactory + let runningAtomic = ManagedAtomic(false) + let backgroundLogger: Logger + + /// Creates a new ``PostgresClient``. Don't forget to run ``run()`` the client in a long running task. + /// - Parameters: + /// - configuration: The client's configuration. See ``Configuration`` for details. + /// - eventLoopGroup: The underlying NIO `EventLoopGroup`. Defaults to ``defaultEventLoopGroup``. + /// - backgroundLogger: A `swift-log` `Logger` to log background messages to. A copy of this logger is also + /// forwarded to the created connections as a background logger. + public init( + configuration: Configuration, + eventLoopGroup: any EventLoopGroup = PostgresClient.defaultEventLoopGroup, + backgroundLogger: Logger + ) { + let factory = ConnectionFactory(config: configuration, eventLoopGroup: eventLoopGroup, logger: backgroundLogger) + self.factory = factory + self.backgroundLogger = backgroundLogger + + self.pool = ConnectionPool( + configuration: .init(configuration), + idGenerator: ConnectionIDGenerator(), + requestType: ConnectionRequest.self, + keepAliveBehavior: .init(configuration.options.keepAliveBehavior, logger: backgroundLogger), + observabilityDelegate: .init(logger: backgroundLogger), + clock: ContinuousClock() + ) { (connectionID, pool) in + let connection = try await factory.makeConnection(connectionID, pool: pool) + + return ConnectionAndMetadata(connection: connection, maximalStreamsOnConnection: 1) + } + } + + + /// Lease a connection for the provided `closure`'s lifetime. + /// + /// - Parameter closure: A closure that uses the passed `PostgresConnection`. The closure **must not** capture + /// the provided `PostgresConnection`. + /// - Returns: The closure's return value. + public func withConnection(_ closure: (PostgresConnection) async throws -> Result) async throws -> Result { + let connection = try await self.leaseConnection() + + defer { self.pool.releaseConnection(connection) } + + return try await closure(connection) + } + + /// The client's run method. Users must call this function in order to start the client's background task processing + /// like creating and destroying connections and running timers. + /// + /// Calls to ``withConnection(_:)`` will emit a `logger` warning, if ``run()`` hasn't been called previously. + public func run() async { + let atomicOp = self.runningAtomic.compareExchange(expected: false, desired: true, ordering: .relaxed) + precondition(!atomicOp.original, "PostgresClient.run() should just be called once!") + await self.pool.run() + } + + // MARK: - Private Methods - + + private func leaseConnection() async throws -> PostgresConnection { + if !self.runningAtomic.load(ordering: .relaxed) { + self.backgroundLogger.warning("Trying to lease connection from `PostgresClient`, but `PostgresClient.run()` hasn't been called yet.") + } + return try await self.pool.leaseConnection() + } + + /// Returns the default `EventLoopGroup` singleton, automatically selecting the best for the platform. + /// + /// This will select the concrete `EventLoopGroup` depending which platform this is running on. + public static var defaultEventLoopGroup: EventLoopGroup { + PostgresConnection.defaultEventLoopGroup + } +} + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +struct PostgresKeepAliveBehavor: ConnectionKeepAliveBehavior { + let behavior: PostgresClient.Configuration.Options.KeepAliveBehavior? + let logger: Logger + + init(_ behavior: PostgresClient.Configuration.Options.KeepAliveBehavior?, logger: Logger) { + self.behavior = behavior + self.logger = logger + } + + var keepAliveFrequency: Duration? { + self.behavior?.frequency + } + + func runKeepAlive(for connection: PostgresConnection) async throws { + try await connection.query(self.behavior!.query, logger: self.logger).map { _ in }.get() + } +} + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +extension ConnectionPoolConfiguration { + init(_ config: PostgresClient.Configuration) { + self = ConnectionPoolConfiguration() + self.minimumConnectionCount = config.options.minimumConnections + self.maximumConnectionSoftLimit = config.options.maximumConnections + self.maximumConnectionHardLimit = config.options.maximumConnections + self.idleTimeout = config.options.connectionIdleTimeout + } +} + +@_spi(ConnectionPool) +extension PostgresConnection: PooledConnection { + public func close() { + self.channel.close(mode: .all, promise: nil) + } + + public func onClose(_ closure: @escaping ((any Error)?) -> ()) { + self.closeFuture.whenComplete { _ in closure(nil) } + } +} + +extension ConnectionPoolError { + func mapToPSQLError(lastConnectError: Error?) -> Error { + var psqlError: PSQLError + switch self { + case .poolShutdown: + psqlError = PSQLError.poolClosed + psqlError.underlying = self + + case .requestCancelled: + psqlError = PSQLError.queryCancelled + psqlError.underlying = self + + default: + return self + } + return psqlError + } +} diff --git a/Sources/PostgresNIO/Pool/PostgresClientMetrics.swift b/Sources/PostgresNIO/Pool/PostgresClientMetrics.swift new file mode 100644 index 00000000..aa8215db --- /dev/null +++ b/Sources/PostgresNIO/Pool/PostgresClientMetrics.swift @@ -0,0 +1,85 @@ +import _ConnectionPoolModule +import Logging + +final class PostgresClientMetrics: ConnectionPoolObservabilityDelegate { + typealias ConnectionID = PostgresConnection.ID + + let logger: Logger + + init(logger: Logger) { + self.logger = logger + } + + func startedConnecting(id: ConnectionID) { + self.logger.debug("Creating new connection", metadata: [ + .connectionID: "\(id)", + ]) + } + + /// A connection attempt failed with the given error. After some period of + /// time ``startedConnecting(id:)`` may be called again. + func connectFailed(id: ConnectionID, error: Error) { + self.logger.debug("Connection creation failed", metadata: [ + .connectionID: "\(id)", + .error: "\(String(reflecting: error))" + ]) + } + + func connectSucceeded(id: ConnectionID) { + self.logger.debug("Connection established", metadata: [ + .connectionID: "\(id)" + ]) + } + + /// The utlization of the connection changed; a stream may have been used, returned or the + /// maximum number of concurrent streams available on the connection changed. + func connectionLeased(id: ConnectionID) { + self.logger.debug("Connection leased", metadata: [ + .connectionID: "\(id)" + ]) + } + + func connectionReleased(id: ConnectionID) { + self.logger.debug("Connection released", metadata: [ + .connectionID: "\(id)" + ]) + } + + func keepAliveTriggered(id: ConnectionID) { + self.logger.debug("run ping pong", metadata: [ + .connectionID: "\(id)", + ]) + } + + func keepAliveSucceeded(id: ConnectionID) {} + + func keepAliveFailed(id: PostgresConnection.ID, error: Error) {} + + /// The remote peer is quiescing the connection: no new streams will be created on it. The + /// connection will eventually be closed and removed from the pool. + func connectionClosing(id: ConnectionID) { + self.logger.debug("Close connection", metadata: [ + .connectionID: "\(id)" + ]) + } + + /// The connection was closed. The connection may be established again in the future (notified + /// via ``startedConnecting(id:)``). + func connectionClosed(id: ConnectionID, error: Error?) { + self.logger.debug("Connection closed", metadata: [ + .connectionID: "\(id)" + ]) + } + + func requestQueueDepthChanged(_ newDepth: Int) { + + } + + func connectSucceeded(id: PostgresConnection.ID, streamCapacity: UInt16) { + + } + + func connectionUtilizationChanged(id: PostgresConnection.ID, streamsUsed: UInt16, streamCapacity: UInt16) { + + } +} diff --git a/Sources/PostgresNIO/Postgres+PSQLCompat.swift b/Sources/PostgresNIO/Postgres+PSQLCompat.swift index c4f30624..7d464c2b 100644 --- a/Sources/PostgresNIO/Postgres+PSQLCompat.swift +++ b/Sources/PostgresNIO/Postgres+PSQLCompat.swift @@ -46,6 +46,8 @@ extension PSQLError { return self.underlying ?? self case .uncleanShutdown: return PostgresError.protocol("Unexpected connection close") + case .poolClosed: + return self } } } diff --git a/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionGroupTests.swift b/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionGroupTests.swift index 99b73fd0..ac0f96f4 100644 --- a/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionGroupTests.swift +++ b/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionGroupTests.swift @@ -95,7 +95,7 @@ final class PoolStateMachine_ConnectionGroupTests: XCTestCase { XCTAssertEqual(releasedContext.use, .demand) XCTAssertEqual(connections.stats, .init(idle: 1, availableStreams: 1)) - let parkTimers = connections.parkConnection(at: index) + let parkTimers = connections.parkConnection(at: index, hasBecomeIdle: true) XCTAssertEqual(parkTimers, [ .init(timerID: 0, connectionID: newConnection.id, usecase: .keepAlive), .init(timerID: 1, connectionID: newConnection.id, usecase: .idleTimeout), @@ -199,7 +199,7 @@ final class PoolStateMachine_ConnectionGroupTests: XCTestCase { let thirdConnKeepTimer = TestPoolStateMachine.ConnectionTimer(timerID: 0, connectionID: thirdRequest.connectionID, usecase: .keepAlive) let thirdConnIdleTimer = TestPoolStateMachine.ConnectionTimer(timerID: 1, connectionID: thirdRequest.connectionID, usecase: .idleTimeout) let thirdConnIdleTimerCancellationToken = MockTimerCancellationToken(thirdConnIdleTimer) - XCTAssertEqual(connections.parkConnection(at: thirdConnectionIndex), [thirdConnKeepTimer, thirdConnIdleTimer]) + XCTAssertEqual(connections.parkConnection(at: thirdConnectionIndex, hasBecomeIdle: true), [thirdConnKeepTimer, thirdConnIdleTimer]) XCTAssertNil(connections.timerScheduled(thirdConnKeepTimer, cancelContinuation: .init(thirdConnKeepTimer))) XCTAssertNil(connections.timerScheduled(thirdConnIdleTimer, cancelContinuation: thirdConnIdleTimerCancellationToken)) @@ -277,7 +277,7 @@ final class PoolStateMachine_ConnectionGroupTests: XCTestCase { XCTAssertEqual(establishedConnectionContext.info, .idle(availableStreams: 1, newIdle: true)) XCTAssertEqual(establishedConnectionContext.use, .persisted) XCTAssertEqual(connections.stats, .init(idle: 1, availableStreams: 1)) - let timers = connections.parkConnection(at: connectionIndex) + let timers = connections.parkConnection(at: connectionIndex, hasBecomeIdle: true) let keepAliveTimer = TestPoolStateMachine.ConnectionTimer(timerID: 0, connectionID: firstRequest.connectionID, usecase: .keepAlive) let keepAliveTimerCancellationToken = MockTimerCancellationToken(keepAliveTimer) XCTAssertEqual(timers, [keepAliveTimer]) diff --git a/Tests/IntegrationTests/PostgresClientTests.swift b/Tests/IntegrationTests/PostgresClientTests.swift new file mode 100644 index 00000000..b1e7f9a8 --- /dev/null +++ b/Tests/IntegrationTests/PostgresClientTests.swift @@ -0,0 +1,66 @@ +@_spi(ConnectionPool) import PostgresNIO +import XCTest +import NIOPosix +import NIOSSL +import Logging +import Atomics + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +final class PostgresClientTests: XCTestCase { + + func testGetConnection() async throws { + var mlogger = Logger(label: "test") + mlogger.logLevel = .debug + let logger = mlogger + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 8) + self.addTeardownBlock { + try await eventLoopGroup.shutdownGracefully() + } + + let clientConfig = PostgresClient.Configuration.makeTestConfiguration() + let client = PostgresClient(configuration: clientConfig, eventLoopGroup: eventLoopGroup, backgroundLogger: logger) + + await withThrowingTaskGroup(of: Void.self) { taskGroup in + taskGroup.addTask { + await client.run() + } + + for i in 0..<10000 { + taskGroup.addTask { + try await client.withConnection() { connection in + _ = try await connection.query("SELECT 1", logger: logger) + } + print("done: \(i)") + } + } + + for _ in 0..<10000 { + _ = await taskGroup.nextResult()! + } + + taskGroup.cancelAll() + } + } +} + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +extension PostgresClient.Configuration { + static func makeTestConfiguration() -> PostgresClient.Configuration { + var tlsConfiguration = TLSConfiguration.makeClientConfiguration() + tlsConfiguration.certificateVerification = .none + var clientConfig = PostgresClient.Configuration( + host: env("POSTGRES_HOSTNAME") ?? "localhost", + port: env("POSTGRES_PORT").flatMap({ Int($0) }) ?? 5432, + username: env("POSTGRES_USER") ?? "test_username", + password: env("POSTGRES_PASSWORD") ?? "test_password", + database: env("POSTGRES_DB") ?? "test_database", + tls: .prefer(tlsConfiguration) + ) + clientConfig.options.minimumConnections = 0 + clientConfig.options.maximumConnections = 12*4 + clientConfig.options.keepAliveBehavior = .init(frequency: .seconds(5)) + clientConfig.options.connectionIdleTimeout = .seconds(15) + + return clientConfig + } +} From 21473f547ab195da56dca4bd203d7d2f150c48c1 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Mon, 30 Oct 2023 14:32:31 +0100 Subject: [PATCH 046/106] Remove warn-concurrency warnings (#408) --- .../PostgresConnection+Configuration.swift | 10 ++--- .../Connection/PostgresConnection.swift | 17 +++++---- .../PostgresDatabase+PreparedQuery.swift | 35 ++++++++++++----- .../Message/PostgresMessage+Error.swift | 4 +- .../New/NotificationListener.swift | 3 +- Sources/PostgresNIO/New/PSQLError.swift | 3 +- Sources/PostgresNIO/New/PSQLRowStream.swift | 8 ++-- Sources/PostgresNIO/New/PSQLTask.swift | 2 +- .../New/PostgresChannelHandler.swift | 14 +++++-- Sources/PostgresNIO/New/PostgresCodable.swift | 2 +- .../PostgresNIO/PostgresDatabase+Query.swift | 28 +++++++++----- .../PostgresDatabase+SimpleQuery.swift | 12 ++++-- Sources/PostgresNIO/PostgresDatabase.swift | 5 ++- .../Utilities/PostgresJSONDecoder.swift | 16 +++++++- .../Utilities/PostgresJSONEncoder.swift | 16 +++++++- Tests/IntegrationTests/AsyncTests.swift | 2 +- .../PSQLIntegrationTests.swift | 11 +++--- .../New/Data/JSON+PSQLCodableTests.swift | 9 +++-- .../New/PSQLRowStreamTests.swift | 38 ++++++++++--------- .../New/PostgresConnectionTests.swift | 2 +- .../Utilities/PostgresJSONCodingTests.swift | 21 ++++++---- 21 files changed, 164 insertions(+), 94 deletions(-) diff --git a/Sources/PostgresNIO/Connection/PostgresConnection+Configuration.swift b/Sources/PostgresNIO/Connection/PostgresConnection+Configuration.swift index bc9bcfc2..22c59d8a 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection+Configuration.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection+Configuration.swift @@ -4,12 +4,12 @@ import NIOSSL extension PostgresConnection { /// A configuration object for a connection - public struct Configuration { - + public struct Configuration: Sendable { + // MARK: - TLS /// The possible modes of operation for TLS encapsulation of a connection. - public struct TLS { + public struct TLS: Sendable { // MARK: Initializers /// Do not try to create a TLS connection to the server. @@ -63,7 +63,7 @@ extension PostgresConnection { // MARK: - Connection options /// Describes options affecting how the underlying connection is made. - public struct Options { + public struct Options: Sendable { /// A timeout for connection attempts. Defaults to ten seconds. /// /// Ignored when using a preexisting communcation channel. (See @@ -219,7 +219,7 @@ extension PostgresConnection { /// the deprecated configuration. /// /// TODO: Drop with next major release - struct InternalConfiguration { + struct InternalConfiguration: Sendable { enum Connection { case unresolvedTCP(host: String, port: Int) case unresolvedUDS(path: String) diff --git a/Sources/PostgresNIO/Connection/PostgresConnection.swift b/Sources/PostgresNIO/Connection/PostgresConnection.swift index 9994ec42..f79a5555 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection.swift @@ -144,8 +144,9 @@ public final class PostgresConnection: @unchecked Sendable { on eventLoop: EventLoop ) -> EventLoopFuture { - var logger = logger - logger[postgresMetadataKey: .connectionID] = "\(connectionID)" + var mlogger = logger + mlogger[postgresMetadataKey: .connectionID] = "\(connectionID)" + let logger = mlogger // Here we dispatch to the `eventLoop` first before we setup the EventLoopFuture chain, to // ensure all `flatMap`s are executed on the EventLoop (this means the enqueuing of the @@ -567,12 +568,13 @@ extension PostgresConnection { /// - line: The line, the query was started in. Used for better error reporting. /// - onRow: A closure that is invoked for every row. /// - Returns: An EventLoopFuture, that allows access to the future ``PostgresQueryMetadata``. + @preconcurrency public func query( _ query: PostgresQuery, logger: Logger, file: String = #fileID, line: Int = #line, - _ onRow: @escaping (PostgresRow) throws -> () + _ onRow: @escaping @Sendable (PostgresRow) throws -> () ) -> EventLoopFuture { self.queryStream(query, logger: logger).flatMap { rowStream in rowStream.onRow(onRow).flatMapThrowing { () -> PostgresQueryMetadata in @@ -638,6 +640,7 @@ extension PostgresConnection: PostgresDatabase { } } + @preconcurrency public func withConnection(_ closure: (PostgresConnection) -> EventLoopFuture) -> EventLoopFuture { closure(self) } @@ -645,11 +648,11 @@ extension PostgresConnection: PostgresDatabase { internal enum PostgresCommands: PostgresRequest { case query(PostgresQuery, - onMetadata: (PostgresQueryMetadata) -> () = { _ in }, - onRow: (PostgresRow) throws -> ()) - case queryAll(PostgresQuery, onResult: (PostgresQueryResult) -> ()) + onMetadata: @Sendable (PostgresQueryMetadata) -> () = { _ in }, + onRow: @Sendable (PostgresRow) throws -> ()) + case queryAll(PostgresQuery, onResult: @Sendable (PostgresQueryResult) -> ()) case prepareQuery(request: PrepareQueryRequest) - case executePreparedStatement(query: PreparedQuery, binds: [PostgresData], onRow: (PostgresRow) throws -> ()) + case executePreparedStatement(query: PreparedQuery, binds: [PostgresData], onRow: @Sendable (PostgresRow) throws -> ()) func respond(to message: PostgresMessage) throws -> [PostgresMessage]? { fatalError("This function must not be called") diff --git a/Sources/PostgresNIO/Connection/PostgresDatabase+PreparedQuery.swift b/Sources/PostgresNIO/Connection/PostgresDatabase+PreparedQuery.swift index 074ba6de..56496172 100644 --- a/Sources/PostgresNIO/Connection/PostgresDatabase+PreparedQuery.swift +++ b/Sources/PostgresNIO/Connection/PostgresDatabase+PreparedQuery.swift @@ -1,4 +1,5 @@ import NIOCore +import NIOConcurrencyHelpers import struct Foundation.UUID extension PostgresDatabase { @@ -14,7 +15,8 @@ extension PostgresDatabase { } } - public func prepare(query: String, handler: @escaping (PreparedQuery) -> EventLoopFuture<[[PostgresRow]]>) -> EventLoopFuture<[[PostgresRow]]> { + @preconcurrency + public func prepare(query: String, handler: @Sendable @escaping (PreparedQuery) -> EventLoopFuture<[[PostgresRow]]>) -> EventLoopFuture<[[PostgresRow]]> { prepare(query: query) .flatMap { preparedQuery in handler(preparedQuery) @@ -26,7 +28,7 @@ extension PostgresDatabase { } -public struct PreparedQuery { +public struct PreparedQuery: Sendable { let underlying: PSQLPreparedStatement let database: PostgresDatabase @@ -36,11 +38,16 @@ public struct PreparedQuery { } public func execute(_ binds: [PostgresData] = []) -> EventLoopFuture<[PostgresRow]> { - var rows: [PostgresRow] = [] - return self.execute(binds) { rows.append($0) }.map { rows } + let rowsBoxed = NIOLockedValueBox([PostgresRow]()) + return self.execute(binds) { row in + rowsBoxed.withLockedValue { + $0.append(row) + } + }.map { rowsBoxed.withLockedValue { $0 } } } - public func execute(_ binds: [PostgresData] = [], _ onRow: @escaping (PostgresRow) throws -> ()) -> EventLoopFuture { + @preconcurrency + public func execute(_ binds: [PostgresData] = [], _ onRow: @Sendable @escaping (PostgresRow) throws -> ()) -> EventLoopFuture { let command = PostgresCommands.executePreparedStatement(query: self, binds: binds, onRow: onRow) return self.database.send(command, logger: self.database.logger) } @@ -50,15 +57,23 @@ public struct PreparedQuery { } } -final class PrepareQueryRequest { +final class PrepareQueryRequest: Sendable { let query: String let name: String - var prepared: PreparedQuery? = nil - - + var prepared: PreparedQuery? { + get { + self._prepared.withLockedValue { $0 } + } + set { + self._prepared.withLockedValue { + $0 = newValue + } + } + } + let _prepared: NIOLockedValueBox = .init(nil) + init(_ query: String, as name: String) { self.query = query self.name = name } - } diff --git a/Sources/PostgresNIO/Message/PostgresMessage+Error.swift b/Sources/PostgresNIO/Message/PostgresMessage+Error.swift index 44f9e6bf..45cda21f 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+Error.swift +++ b/Sources/PostgresNIO/Message/PostgresMessage+Error.swift @@ -2,8 +2,8 @@ import NIOCore extension PostgresMessage { /// First message sent from the frontend during startup. - public struct Error: CustomStringConvertible { - public enum Field: UInt8, Hashable { + public struct Error: CustomStringConvertible, Sendable { + public enum Field: UInt8, Hashable, Sendable { /// Severity: the field contents are ERROR, FATAL, or PANIC (in an error message), /// or WARNING, NOTICE, DEBUG, INFO, or LOG (in a notice message), or a //// localized translation of one of these. Always present. diff --git a/Sources/PostgresNIO/New/NotificationListener.swift b/Sources/PostgresNIO/New/NotificationListener.swift index 5f4bc3de..9e47ff34 100644 --- a/Sources/PostgresNIO/New/NotificationListener.swift +++ b/Sources/PostgresNIO/New/NotificationListener.swift @@ -44,6 +44,7 @@ final class NotificationListener: @unchecked Sendable { func startListeningSucceeded(handler: PostgresChannelHandler) { self.eventLoop.preconditionInEventLoop() + let handlerLoopBound = NIOLoopBound(handler, eventLoop: self.eventLoop) switch self.state { case .streamInitialized(let checkedContinuation): @@ -55,7 +56,7 @@ final class NotificationListener: @unchecked Sendable { switch reason { case .cancelled: eventLoop.execute { - handler.cancelNotificationListener(channel: channel, id: listenerID) + handlerLoopBound.value.cancelNotificationListener(channel: channel, id: listenerID) } case .finished: diff --git a/Sources/PostgresNIO/New/PSQLError.swift b/Sources/PostgresNIO/New/PSQLError.swift index 81099043..4a9f9216 100644 --- a/Sources/PostgresNIO/New/PSQLError.swift +++ b/Sources/PostgresNIO/New/PSQLError.swift @@ -1,7 +1,8 @@ import NIOCore /// An error that is thrown from the PostgresClient. -public struct PSQLError: Error { +/// Sendability enforced through Copy on Write semantics +public struct PSQLError: Error, @unchecked Sendable { public struct Code: Sendable, Hashable, CustomStringConvertible { enum Base: Sendable, Hashable { diff --git a/Sources/PostgresNIO/New/PSQLRowStream.swift b/Sources/PostgresNIO/New/PSQLRowStream.swift index b008d185..b3dfea30 100644 --- a/Sources/PostgresNIO/New/PSQLRowStream.swift +++ b/Sources/PostgresNIO/New/PSQLRowStream.swift @@ -96,10 +96,8 @@ final class PSQLRowStream: @unchecked Sendable { let yieldResult = source.yield(contentsOf: bufferedRows) self.downstreamState = .asyncSequence(source, dataSource) - self.eventLoop.execute { - self.executeActionBasedOnYieldResult(yieldResult, source: dataSource) - } - + self.executeActionBasedOnYieldResult(yieldResult, source: dataSource) + case .finished(let buffer, let commandTag): _ = source.yield(contentsOf: buffer) source.finish() @@ -206,7 +204,7 @@ final class PSQLRowStream: @unchecked Sendable { // MARK: Consume on EventLoop - func onRow(_ onRow: @escaping (PostgresRow) throws -> ()) -> EventLoopFuture { + func onRow(_ onRow: @Sendable @escaping (PostgresRow) throws -> ()) -> EventLoopFuture { if self.eventLoop.inEventLoop { return self.onRow0(onRow) } else { diff --git a/Sources/PostgresNIO/New/PSQLTask.swift b/Sources/PostgresNIO/New/PSQLTask.swift index 9425c12b..6308a5b3 100644 --- a/Sources/PostgresNIO/New/PSQLTask.swift +++ b/Sources/PostgresNIO/New/PSQLTask.swift @@ -70,7 +70,7 @@ final class ExtendedQueryContext { } } -final class PreparedStatementContext{ +final class PreparedStatementContext: Sendable { let name: String let sql: String let bindings: PostgresBindings diff --git a/Sources/PostgresNIO/New/PostgresChannelHandler.swift b/Sources/PostgresNIO/New/PostgresChannelHandler.swift index 6d9d08b3..9d0ef2a5 100644 --- a/Sources/PostgresNIO/New/PostgresChannelHandler.swift +++ b/Sources/PostgresNIO/New/PostgresChannelHandler.swift @@ -597,8 +597,10 @@ final class PostgresChannelHandler: ChannelDuplexHandler { logger: self.logger, promise: promise ) + let loopBound = NIOLoopBound((self, context), eventLoop: self.eventLoop) promise.futureResult.whenComplete { result in - self.startListenCompleted(result, for: channel, context: context) + let (selfTransferred, context) = loopBound.value + selfTransferred.startListenCompleted(result, for: channel, context: context) } return .extendedQuery(query) @@ -643,8 +645,10 @@ final class PostgresChannelHandler: ChannelDuplexHandler { logger: self.logger, promise: promise ) + let loopBound = NIOLoopBound((self, context), eventLoop: self.eventLoop) promise.futureResult.whenComplete { result in - self.stopListenCompleted(result, for: channel, context: context) + let (selfTransferred, context) = loopBound.value + selfTransferred.stopListenCompleted(result, for: channel, context: context) } return .extendedQuery(query) @@ -693,10 +697,12 @@ final class PostgresChannelHandler: ChannelDuplexHandler { context: ChannelHandlerContext ) -> PSQLTask { let promise = self.eventLoop.makePromise(of: RowDescription?.self) + let loopBound = NIOLoopBound((self, context), eventLoop: self.eventLoop) promise.futureResult.whenComplete { result in + let (selfTransferred, context) = loopBound.value switch result { case .success(let rowDescription): - self.prepareStatementComplete( + selfTransferred.prepareStatementComplete( name: preparedStatement.name, rowDescription: rowDescription, context: context @@ -708,7 +714,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler { } else { psqlError = .connectionError(underlying: error) } - self.prepareStatementFailed( + selfTransferred.prepareStatementFailed( name: preparedStatement.name, error: psqlError, context: context diff --git a/Sources/PostgresNIO/New/PostgresCodable.swift b/Sources/PostgresNIO/New/PostgresCodable.swift index 53dbd708..71c689bf 100644 --- a/Sources/PostgresNIO/New/PostgresCodable.swift +++ b/Sources/PostgresNIO/New/PostgresCodable.swift @@ -188,7 +188,7 @@ extension PostgresEncodingContext where JSONEncoder == Foundation.JSONEncoder { /// A context that is passed to Swift objects that are decoded from the Postgres wire format. Used /// to pass further information to the decoding method. -public struct PostgresDecodingContext { +public struct PostgresDecodingContext: Sendable { /// A ``PostgresJSONDecoder`` used to decode the object from json. public var jsonDecoder: JSONDecoder diff --git a/Sources/PostgresNIO/PostgresDatabase+Query.swift b/Sources/PostgresNIO/PostgresDatabase+Query.swift index 95abb6fc..01a7e61f 100644 --- a/Sources/PostgresNIO/PostgresDatabase+Query.swift +++ b/Sources/PostgresNIO/PostgresDatabase+Query.swift @@ -1,27 +1,35 @@ import NIOCore import Logging +import NIOConcurrencyHelpers extension PostgresDatabase { public func query( _ string: String, _ binds: [PostgresData] = [] ) -> EventLoopFuture { - var rows: [PostgresRow] = [] - var metadata: PostgresQueryMetadata? - return self.query(string, binds, onMetadata: { - metadata = $0 - }) { - rows.append($0) + let box = NIOLockedValueBox((metadata: PostgresQueryMetadata?.none, rows: [PostgresRow]())) + + return self.query(string, binds, onMetadata: { metadata in + box.withLockedValue { + $0.metadata = metadata + } + }) { row in + box.withLockedValue { + $0.rows.append(row) + } }.map { - .init(metadata: metadata!, rows: rows) + box.withLockedValue { + PostgresQueryResult(metadata: $0.metadata!, rows: $0.rows) + } } } + @preconcurrency public func query( _ string: String, _ binds: [PostgresData] = [], - onMetadata: @escaping (PostgresQueryMetadata) -> () = { _ in }, - onRow: @escaping (PostgresRow) throws -> () + onMetadata: @Sendable @escaping (PostgresQueryMetadata) -> () = { _ in }, + onRow: @Sendable @escaping (PostgresRow) throws -> () ) -> EventLoopFuture { var bindings = PostgresBindings(capacity: binds.count) binds.forEach { bindings.append($0) } @@ -58,7 +66,7 @@ extension PostgresQueryResult: Collection { } } -public struct PostgresQueryMetadata { +public struct PostgresQueryMetadata: Sendable { public let command: String public var oid: Int? public var rows: Int? diff --git a/Sources/PostgresNIO/PostgresDatabase+SimpleQuery.swift b/Sources/PostgresNIO/PostgresDatabase+SimpleQuery.swift index 77f3d034..5cf2d7a4 100644 --- a/Sources/PostgresNIO/PostgresDatabase+SimpleQuery.swift +++ b/Sources/PostgresNIO/PostgresDatabase+SimpleQuery.swift @@ -1,13 +1,19 @@ import NIOCore +import NIOConcurrencyHelpers import Logging extension PostgresDatabase { public func simpleQuery(_ string: String) -> EventLoopFuture<[PostgresRow]> { - var rows: [PostgresRow] = [] - return simpleQuery(string) { rows.append($0) }.map { rows } + let rowsBoxed = NIOLockedValueBox([PostgresRow]()) + return self.simpleQuery(string) { row in + rowsBoxed.withLockedValue { + $0.append(row) + } + }.map { rowsBoxed.withLockedValue { $0 } } } - public func simpleQuery(_ string: String, _ onRow: @escaping (PostgresRow) throws -> ()) -> EventLoopFuture { + @preconcurrency + public func simpleQuery(_ string: String, _ onRow: @Sendable @escaping (PostgresRow) throws -> ()) -> EventLoopFuture { self.query(string, onRow: onRow) } } diff --git a/Sources/PostgresNIO/PostgresDatabase.swift b/Sources/PostgresNIO/PostgresDatabase.swift index 64e44abb..fcd1afc7 100644 --- a/Sources/PostgresNIO/PostgresDatabase.swift +++ b/Sources/PostgresNIO/PostgresDatabase.swift @@ -1,14 +1,15 @@ import NIOCore import Logging -public protocol PostgresDatabase { +@preconcurrency +public protocol PostgresDatabase: Sendable { var logger: Logger { get } var eventLoop: EventLoop { get } func send( _ request: PostgresRequest, logger: Logger ) -> EventLoopFuture - + func withConnection(_ closure: @escaping (PostgresConnection) -> EventLoopFuture) -> EventLoopFuture } diff --git a/Sources/PostgresNIO/Utilities/PostgresJSONDecoder.swift b/Sources/PostgresNIO/Utilities/PostgresJSONDecoder.swift index fb7b4e8d..ba57ee9b 100644 --- a/Sources/PostgresNIO/Utilities/PostgresJSONDecoder.swift +++ b/Sources/PostgresNIO/Utilities/PostgresJSONDecoder.swift @@ -2,11 +2,13 @@ import class Foundation.JSONDecoder import struct Foundation.Data import NIOFoundationCompat import NIOCore +import NIOConcurrencyHelpers /// A protocol that mimicks the Foundation `JSONDecoder.decode(_:from:)` function. /// Conform a non-Foundation JSON decoder to this protocol if you want PostgresNIO to be /// able to use it when decoding JSON & JSONB values (see `PostgresNIO._defaultJSONDecoder`) -public protocol PostgresJSONDecoder { +@preconcurrency +public protocol PostgresJSONDecoder: Sendable { func decode(_ type: T.Type, from data: Data) throws -> T where T : Decodable func decode(_ type: T.Type, from buffer: ByteBuffer) throws -> T @@ -20,10 +22,20 @@ extension PostgresJSONDecoder { } } +//@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) extension JSONDecoder: PostgresJSONDecoder {} +private let jsonDecoderLocked: NIOLockedValueBox = NIOLockedValueBox(JSONDecoder()) + /// The default JSON decoder used by PostgresNIO when decoding JSON & JSONB values. /// As `_defaultJSONDecoder` will be reused for decoding all JSON & JSONB values /// from potentially multiple threads at once, you must ensure your custom JSON decoder is /// thread safe internally like `Foundation.JSONDecoder`. -public var _defaultJSONDecoder: PostgresJSONDecoder = JSONDecoder() +public var _defaultJSONDecoder: PostgresJSONDecoder { + set { + jsonDecoderLocked.withLockedValue { $0 = newValue } + } + get { + jsonDecoderLocked.withLockedValue { $0 } + } +} diff --git a/Sources/PostgresNIO/Utilities/PostgresJSONEncoder.swift b/Sources/PostgresNIO/Utilities/PostgresJSONEncoder.swift index 735e4b14..9585f20b 100644 --- a/Sources/PostgresNIO/Utilities/PostgresJSONEncoder.swift +++ b/Sources/PostgresNIO/Utilities/PostgresJSONEncoder.swift @@ -1,11 +1,13 @@ import Foundation import NIOFoundationCompat import NIOCore +import NIOConcurrencyHelpers /// A protocol that mimicks the Foundation `JSONEncoder.encode(_:)` function. /// Conform a non-Foundation JSON encoder to this protocol if you want PostgresNIO to be /// able to use it when encoding JSON & JSONB values (see `PostgresNIO._defaultJSONEncoder`) -public protocol PostgresJSONEncoder { +@preconcurrency +public protocol PostgresJSONEncoder: Sendable { func encode(_ value: T) throws -> Data where T : Encodable func encode(_ value: T, into buffer: inout ByteBuffer) throws @@ -20,8 +22,18 @@ extension PostgresJSONEncoder { extension JSONEncoder: PostgresJSONEncoder {} +private let jsonEncoderLocked: NIOLockedValueBox = NIOLockedValueBox(JSONEncoder()) + /// The default JSON encoder used by PostgresNIO when encoding JSON & JSONB values. /// As `_defaultJSONEncoder` will be reused for encoding all JSON & JSONB values /// from potentially multiple threads at once, you must ensure your custom JSON encoder is /// thread safe internally like `Foundation.JSONEncoder`. -public var _defaultJSONEncoder: PostgresJSONEncoder = JSONEncoder() +public var _defaultJSONEncoder: PostgresJSONEncoder { + set { + jsonEncoderLocked.withLockedValue { $0 = newValue } + } + get { + jsonEncoderLocked.withLockedValue { $0 } + } +} + diff --git a/Tests/IntegrationTests/AsyncTests.swift b/Tests/IntegrationTests/AsyncTests.swift index 5c77ba29..91b5656c 100644 --- a/Tests/IntegrationTests/AsyncTests.swift +++ b/Tests/IntegrationTests/AsyncTests.swift @@ -323,7 +323,7 @@ final class AsyncPostgresConnectionTests: XCTestCase { let eventLoop = eventLoopGroup.next() struct TestPreparedStatement: PostgresPreparedStatement { - static var sql = "SELECT pid, datname FROM pg_stat_activity WHERE state = $1" + static let sql = "SELECT pid, datname FROM pg_stat_activity WHERE state = $1" typealias Row = (Int, String) var state: String diff --git a/Tests/IntegrationTests/PSQLIntegrationTests.swift b/Tests/IntegrationTests/PSQLIntegrationTests.swift index 0550dc77..57939c06 100644 --- a/Tests/IntegrationTests/PSQLIntegrationTests.swift +++ b/Tests/IntegrationTests/PSQLIntegrationTests.swift @@ -1,3 +1,4 @@ +import Atomics import XCTest import Logging import PostgresNIO @@ -73,19 +74,17 @@ final class IntegrationTests: XCTestCase { defer { XCTAssertNoThrow(try conn?.close().wait()) } var metadata: PostgresQueryMetadata? - var received: Int64 = 0 + let received = ManagedAtomic(0) XCTAssertNoThrow(metadata = try conn?.query("SELECT generate_series(1, 10000);", logger: .psqlTest) { row in func workaround() { - var number: Int64? - XCTAssertNoThrow(number = try row.decode(Int64.self, context: .default)) - received += 1 - XCTAssertEqual(number, received) + let expected = received.wrappingIncrementThenLoad(ordering: .relaxed) + XCTAssertEqual(expected, try row.decode(Int64.self, context: .default)) } workaround() }.wait()) - XCTAssertEqual(received, 10000) + XCTAssertEqual(received.load(ordering: .relaxed), 10000) XCTAssertEqual(metadata?.command, "SELECT") XCTAssertEqual(metadata?.rows, 10000) } diff --git a/Tests/PostgresNIOTests/New/Data/JSON+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/JSON+PSQLCodableTests.swift index 858b6ede..52dead6a 100644 --- a/Tests/PostgresNIOTests/New/Data/JSON+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/JSON+PSQLCodableTests.swift @@ -1,4 +1,5 @@ import XCTest +import Atomics import NIOCore @testable import PostgresNIO @@ -69,11 +70,11 @@ class JSON_PSQLCodableTests: XCTestCase { } func testCustomEncoderIsUsed() { - class TestEncoder: PostgresJSONEncoder { - var encodeHits = 0 + final class TestEncoder: PostgresJSONEncoder { + let encodeHits = ManagedAtomic(0) func encode(_ value: T, into buffer: inout ByteBuffer) throws where T : Encodable { - self.encodeHits += 1 + self.encodeHits.wrappingIncrement(ordering: .relaxed) } func encode(_ value: T) throws -> Data where T : Encodable { @@ -85,6 +86,6 @@ class JSON_PSQLCodableTests: XCTestCase { let encoder = TestEncoder() var buffer = ByteBuffer() XCTAssertNoThrow(try hello.encode(into: &buffer, context: .init(jsonEncoder: encoder))) - XCTAssertEqual(encoder.encodeHits, 1) + XCTAssertEqual(encoder.encodeHits.load(ordering: .relaxed), 1) } } diff --git a/Tests/PostgresNIOTests/New/PSQLRowStreamTests.swift b/Tests/PostgresNIOTests/New/PSQLRowStreamTests.swift index d6d03107..9a1e9e41 100644 --- a/Tests/PostgresNIOTests/New/PSQLRowStreamTests.swift +++ b/Tests/PostgresNIOTests/New/PSQLRowStreamTests.swift @@ -1,3 +1,4 @@ +import Atomics import NIOCore import Logging import XCTest @@ -128,12 +129,12 @@ final class PSQLRowStreamTests: XCTestCase { XCTAssertEqual(dataSource.hitDemand, 0) // attach consumer - var counter = 0 + let counter = ManagedAtomic(0) let future = stream.onRow { row in - XCTAssertEqual(try row.decode(String.self, context: .default), "\(counter)") - counter += 1 + let expected = counter.loadThenWrappingIncrement(ordering: .relaxed) + XCTAssertEqual(try row.decode(String.self, context: .default), "\(expected)") } - XCTAssertEqual(counter, 2) + XCTAssertEqual(counter.load(ordering: .relaxed), 2) XCTAssertEqual(dataSource.hitDemand, 0) XCTAssertNoThrow(try future.wait()) @@ -155,7 +156,9 @@ final class PSQLRowStreamTests: XCTestCase { stream.receive([ [ByteBuffer(string: "0")], - [ByteBuffer(string: "1")] + [ByteBuffer(string: "1")], + [ByteBuffer(string: "2")], + [ByteBuffer(string: "3")], ]) stream.receive(completion: .success("SELECT 2")) @@ -163,15 +166,15 @@ final class PSQLRowStreamTests: XCTestCase { XCTAssertEqual(dataSource.hitDemand, 0) // attach consumer - var counter = 0 + let counter = ManagedAtomic(0) let future = stream.onRow { row in - XCTAssertEqual(try row.decode(String.self, context: .default), "\(counter)") - if counter == 1 { - throw OnRowError(row: counter) + let expected = counter.loadThenWrappingIncrement(ordering: .relaxed) + XCTAssertEqual(try row.decode(String.self, context: .default), "\(expected)") + if expected == 1 { + throw OnRowError(row: expected) } - counter += 1 } - XCTAssertEqual(counter, 1) + XCTAssertEqual(counter.load(ordering: .relaxed), 2) // one more than where we excited, because we already incremented XCTAssertEqual(dataSource.hitDemand, 0) XCTAssertThrowsError(try future.wait()) { @@ -179,7 +182,6 @@ final class PSQLRowStreamTests: XCTestCase { } } - func testOnRowBeforeStreamHasFinished() { let dataSource = CountingDataSource() let stream = PSQLRowStream( @@ -201,26 +203,26 @@ final class PSQLRowStreamTests: XCTestCase { XCTAssertEqual(dataSource.hitDemand, 0, "Before we have a consumer demand is not signaled") // attach consumer - var counter = 0 + let counter = ManagedAtomic(0) let future = stream.onRow { row in - XCTAssertEqual(try row.decode(String.self, context: .default), "\(counter)") - counter += 1 + let expected = counter.loadThenWrappingIncrement(ordering: .relaxed) + XCTAssertEqual(try row.decode(String.self, context: .default), "\(expected)") } - XCTAssertEqual(counter, 2) + XCTAssertEqual(counter.load(ordering: .relaxed), 2) XCTAssertEqual(dataSource.hitDemand, 1) stream.receive([ [ByteBuffer(string: "2")], [ByteBuffer(string: "3")] ]) - XCTAssertEqual(counter, 4) + XCTAssertEqual(counter.load(ordering: .relaxed), 4) XCTAssertEqual(dataSource.hitDemand, 2) stream.receive([ [ByteBuffer(string: "4")], [ByteBuffer(string: "5")] ]) - XCTAssertEqual(counter, 6) + XCTAssertEqual(counter.load(ordering: .relaxed), 6) XCTAssertEqual(dataSource.hitDemand, 3) stream.receive(completion: .success("SELECT 6")) diff --git a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift index 59917c40..3b1a8ca9 100644 --- a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift @@ -304,7 +304,7 @@ class PostgresConnectionTests: XCTestCase { } struct TestPrepareStatement: PostgresPreparedStatement { - static var sql = "SELECT datname FROM pg_stat_activity WHERE state = $1" + static let sql = "SELECT datname FROM pg_stat_activity WHERE state = $1" typealias Row = String var state: String diff --git a/Tests/PostgresNIOTests/Utilities/PostgresJSONCodingTests.swift b/Tests/PostgresNIOTests/Utilities/PostgresJSONCodingTests.swift index 2aad52b6..c6f876f2 100644 --- a/Tests/PostgresNIOTests/Utilities/PostgresJSONCodingTests.swift +++ b/Tests/PostgresNIOTests/Utilities/PostgresJSONCodingTests.swift @@ -1,3 +1,4 @@ +import Atomics import NIOCore import XCTest import PostgresNIO @@ -10,9 +11,9 @@ class PostgresJSONCodingTests: XCTestCase { PostgresNIO._defaultJSONEncoder = previousDefaultJSONEncoder } final class CustomJSONEncoder: PostgresJSONEncoder { - var didEncode = false + let counter = ManagedAtomic(0) func encode(_ value: T) throws -> Data where T : Encodable { - self.didEncode = true + self.counter.wrappingIncrement(ordering: .relaxed) return try JSONEncoder().encode(value) } } @@ -21,14 +22,16 @@ class PostgresJSONCodingTests: XCTestCase { var bar: Int } let customJSONEncoder = CustomJSONEncoder() + XCTAssertEqual(customJSONEncoder.counter.load(ordering: .relaxed), 0) PostgresNIO._defaultJSONEncoder = customJSONEncoder XCTAssertNoThrow(try PostgresData(json: Object(foo: 1, bar: 2))) - XCTAssert(customJSONEncoder.didEncode) + XCTAssertEqual(customJSONEncoder.counter.load(ordering: .relaxed), 1) let customJSONBEncoder = CustomJSONEncoder() + XCTAssertEqual(customJSONBEncoder.counter.load(ordering: .relaxed), 0) PostgresNIO._defaultJSONEncoder = customJSONBEncoder XCTAssertNoThrow(try PostgresData(json: Object(foo: 1, bar: 2))) - XCTAssert(customJSONBEncoder.didEncode) + XCTAssertEqual(customJSONBEncoder.counter.load(ordering: .relaxed), 1) } // https://github.com/vapor/postgres-nio/issues/126 @@ -38,9 +41,9 @@ class PostgresJSONCodingTests: XCTestCase { PostgresNIO._defaultJSONDecoder = previousDefaultJSONDecoder } final class CustomJSONDecoder: PostgresJSONDecoder { - var didDecode = false + let counter = ManagedAtomic(0) func decode(_ type: T.Type, from data: Data) throws -> T where T : Decodable { - self.didDecode = true + self.counter.wrappingIncrement(ordering: .relaxed) return try JSONDecoder().decode(type, from: data) } } @@ -49,13 +52,15 @@ class PostgresJSONCodingTests: XCTestCase { var bar: Int } let customJSONDecoder = CustomJSONDecoder() + XCTAssertEqual(customJSONDecoder.counter.load(ordering: .relaxed), 0) PostgresNIO._defaultJSONDecoder = customJSONDecoder XCTAssertNoThrow(try PostgresData(json: Object(foo: 1, bar: 2)).json(as: Object.self)) - XCTAssert(customJSONDecoder.didDecode) + XCTAssertEqual(customJSONDecoder.counter.load(ordering: .relaxed), 1) let customJSONBDecoder = CustomJSONDecoder() + XCTAssertEqual(customJSONBDecoder.counter.load(ordering: .relaxed), 0) PostgresNIO._defaultJSONDecoder = customJSONBDecoder XCTAssertNoThrow(try PostgresData(json: Object(foo: 1, bar: 2)).json(as: Object.self)) - XCTAssert(customJSONBDecoder.didDecode) + XCTAssertEqual(customJSONBDecoder.counter.load(ordering: .relaxed), 1) } } From c8269926eb3b705b70aff1975860e357760123c8 Mon Sep 17 00:00:00 2001 From: Tim Condon <0xTim@users.noreply.github.com> Date: Thu, 2 Nov 2023 12:48:52 +0000 Subject: [PATCH 047/106] Update README.md (#434) Point documentation links to our docs as that's where we host them now --- README.md | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index 489d0e29..6f289673 100644 --- a/README.md +++ b/README.md @@ -176,20 +176,20 @@ Some queries do not receive any rows from the server (most often `INSERT`, `UPDA Please see [SECURITY.md] for details on the security process. [SSWG Incubation]: https://github.com/swift-server/sswg/blob/main/process/incubation.md#graduated-level -[Documentation]: https://swiftpackageindex.com/vapor/postgres-nio/documentation +[Documentation]: https://api.vapor.codes/postgresnio/documentation/postgresnio [Team Chat]: https://discord.gg/vapor [MIT License]: LICENSE [Continuous Integration]: https://github.com/vapor/postgres-nio/actions [Swift 5.7]: https://swift.org [Security.md]: https://github.com/vapor/.github/blob/main/SECURITY.md -[`PostgresConnection`]: https://swiftpackageindex.com/vapor/postgres-nio/documentation/postgresnio/postgresconnection/ -[`query(_:logger:)`]: https://swiftpackageindex.com/vapor/postgres-nio/documentation/postgresnio/postgresconnection/query(_:logger:file:line:)-9mkfn -[`PostgresQuery`]: https://swiftpackageindex.com/vapor/postgres-nio/documentation/postgresnio/postgresquery/ -[`PostgresRow`]: https://swiftpackageindex.com/vapor/postgres-nio/documentation/postgresnio/postgresrow/ -[`PostgresRowSequence`]: https://swiftpackageindex.com/vapor/postgres-nio/documentation/postgresnio/postgresrowsequence/ -[`PostgresDecodable`]: https://swiftpackageindex.com/vapor/postgres-nio/documentation/postgresnio/postgresdecodable/ -[`PostgresEncodable`]: https://swiftpackageindex.com/vapor/postgres-nio/documentation/postgresnio/postgresencodable/ +[`PostgresConnection`]: https://api.vapor.codes/postgresnio/documentation/postgresnio/postgresconnection +[`query(_:logger:)`]: https://api.vapor.codes/postgresnio/documentation/postgresnio/postgresconnection/query(_:logger:file:line:)-9mkfn +[`PostgresQuery`]: https://api.vapor.codes/postgresnio/documentation/postgresnio/postgresquery +[`PostgresRow`]: https://api.vapor.codes/postgresnio/documentation/postgresnio/postgresrow +[`PostgresRowSequence`]: https://api.vapor.codes/postgresnio/documentation/postgresnio/postgresrowsequence +[`PostgresDecodable`]: https://api.vapor.codes/postgresnio/documentation/postgresnio/postgresdecodable +[`PostgresEncodable`]: https://api.vapor.codes/postgresnio/documentation/postgresnio/postgresencodable [PostgresKit]: https://github.com/vapor/postgres-kit From 036931d968aab819f5e380a932237118ac4e87ba Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Fri, 10 Nov 2023 18:15:46 +0100 Subject: [PATCH 048/106] Fixes Crash in ConnectionPoolStateMachine (#438) - Correctly handle Connection closes while running a keep alive (fix: #436) - Add further keep alive tests - Restructure MockClock quite a bit --- .../PoolStateMachine+ConnectionGroup.swift | 14 +- .../PoolStateMachine+ConnectionState.swift | 19 ++- .../ConnectionPoolTests.swift | 158 +++++++++++++++++- .../Mocks/MockClock.swift | 77 ++++----- .../Mocks/MockConnection.swift | 89 ---------- .../Mocks/MockConnectionFactory.swift | 92 ++++++++++ .../Mocks/MockPingPongBehaviour.swift | 65 ++++++- ...oolStateMachine+ConnectionStateTests.swift | 2 +- 8 files changed, 356 insertions(+), 160 deletions(-) create mode 100644 Tests/ConnectionPoolModuleTests/Mocks/MockConnectionFactory.swift diff --git a/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionGroup.swift b/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionGroup.swift index b53f8d68..fabc3009 100644 --- a/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionGroup.swift +++ b/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionGroup.swift @@ -385,7 +385,8 @@ extension PoolStateMachine { @inlinable mutating func keepAliveSucceeded(_ connectionID: Connection.ID) -> (Int, AvailableConnectionContext)? { guard let index = self.connections.firstIndex(where: { $0.id == connectionID }) else { - preconditionFailure("A connection that we don't know was released? Something is very wrong...") + // keepAliveSucceeded can race against, closeIfIdle, shutdowns or connection errors + return nil } guard let connectionInfo = self.connections[index].keepAliveSucceeded() else { @@ -430,15 +431,8 @@ extension PoolStateMachine { self.stats.idle -= 1 self.stats.closing += 1 - -// if idleState.runningKeepAlive { -// self.stats.runningKeepAlive -= 1 -// if self.keepAliveReducesAvailableStreams { -// self.stats.availableStreams += 1 -// } -// } - - self.stats.availableStreams -= closeAction.maxStreams + self.stats.runningKeepAlive -= closeAction.runningKeepAlive ? 1 : 0 + self.stats.availableStreams -= closeAction.maxStreams - closeAction.usedStreams return CloseAction( connection: closeAction.connection!, diff --git a/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionState.swift b/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionState.swift index a56b87da..94196a09 100644 --- a/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionState.swift +++ b/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionState.swift @@ -496,6 +496,9 @@ extension PoolStateMachine { var usedStreams: UInt16 @usableFromInline var maxStreams: UInt16 + @usableFromInline + var runningKeepAlive: Bool + @inlinable init( @@ -503,13 +506,15 @@ extension PoolStateMachine { previousConnectionState: PreviousConnectionState, cancelTimers: Max2Sequence, usedStreams: UInt16, - maxStreams: UInt16 + maxStreams: UInt16, + runningKeepAlive: Bool ) { self.connection = connection self.previousConnectionState = previousConnectionState self.cancelTimers = cancelTimers self.usedStreams = usedStreams self.maxStreams = maxStreams + self.runningKeepAlive = runningKeepAlive } } @@ -526,7 +531,8 @@ extension PoolStateMachine { idleTimerState?.cancellationContinuation ), usedStreams: keepAlive.usedStreams, - maxStreams: maxStreams + maxStreams: maxStreams, + runningKeepAlive: keepAlive.isRunning ) case .leased, .closed: @@ -559,7 +565,8 @@ extension PoolStateMachine { idleTimerState?.cancellationContinuation ), usedStreams: keepAlive.usedStreams, - maxStreams: maxStreams + maxStreams: maxStreams, + runningKeepAlive: keepAlive.isRunning ) case .leased(let connection, usedStreams: let usedStreams, maxStreams: let maxStreams, var keepAlive): @@ -571,7 +578,8 @@ extension PoolStateMachine { keepAlive.cancelTimerIfScheduled() ), usedStreams: keepAlive.usedStreams + usedStreams, - maxStreams: maxStreams + maxStreams: maxStreams, + runningKeepAlive: keepAlive.isRunning ) case .backingOff(let timer): @@ -581,7 +589,8 @@ extension PoolStateMachine { previousConnectionState: .backingOff, cancelTimers: Max2Sequence(timer.cancellationContinuation), usedStreams: 0, - maxStreams: 0 + maxStreams: 0, + runningKeepAlive: false ) } } diff --git a/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift b/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift index 5be12a1c..57980711 100644 --- a/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift +++ b/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift @@ -16,7 +16,7 @@ final class ConnectionPoolTests: XCTestCase { configuration: config, idGenerator: ConnectionIDGenerator(), requestType: ConnectionRequest.self, - keepAliveBehavior: MockPingPongBehavior(keepAliveFrequency: nil), + keepAliveBehavior: MockPingPongBehavior(keepAliveFrequency: nil, connectionType: MockConnection.self), observabilityDelegate: NoOpConnectionPoolMetrics(connectionIDType: MockConnection.ID.self), clock: ContinuousClock() ) { @@ -74,7 +74,7 @@ final class ConnectionPoolTests: XCTestCase { configuration: config, idGenerator: ConnectionIDGenerator(), requestType: ConnectionRequest.self, - keepAliveBehavior: MockPingPongBehavior(keepAliveFrequency: nil), + keepAliveBehavior: MockPingPongBehavior(keepAliveFrequency: nil, connectionType: MockConnection.self), observabilityDelegate: NoOpConnectionPoolMetrics(connectionIDType: MockConnection.ID.self), clock: clock ) { @@ -119,7 +119,7 @@ final class ConnectionPoolTests: XCTestCase { configuration: config, idGenerator: ConnectionIDGenerator(), requestType: ConnectionRequest.self, - keepAliveBehavior: MockPingPongBehavior(keepAliveFrequency: nil), + keepAliveBehavior: MockPingPongBehavior(keepAliveFrequency: nil, connectionType: MockConnection.self), observabilityDelegate: NoOpConnectionPoolMetrics(connectionIDType: MockConnection.ID.self), clock: clock ) { @@ -135,7 +135,7 @@ final class ConnectionPoolTests: XCTestCase { throw ConnectionCreationError() } - await clock.timerScheduled() + await clock.nextTimerScheduled() taskGroup.cancelAll() } @@ -156,7 +156,7 @@ final class ConnectionPoolTests: XCTestCase { configuration: config, idGenerator: ConnectionIDGenerator(), requestType: ConnectionRequest.self, - keepAliveBehavior: MockPingPongBehavior(keepAliveFrequency: nil), + keepAliveBehavior: MockPingPongBehavior(keepAliveFrequency: nil, connectionType: MockConnection.self), observabilityDelegate: NoOpConnectionPoolMetrics(connectionIDType: MockConnection.ID.self), clock: ContinuousClock() ) { @@ -220,6 +220,154 @@ final class ConnectionPoolTests: XCTestCase { XCTAssert(hasFinished.load(ordering: .relaxed)) XCTAssertEqual(factory.runningConnections.count, 0) } + + func testKeepAliveWorks() async throws { + let clock = MockClock() + let factory = MockConnectionFactory() + let keepAliveDuration = Duration.seconds(30) + let keepAlive = MockPingPongBehavior(keepAliveFrequency: keepAliveDuration, connectionType: MockConnection.self) + + var mutableConfig = ConnectionPoolConfiguration() + mutableConfig.minimumConnectionCount = 0 + mutableConfig.maximumConnectionSoftLimit = 1 + mutableConfig.maximumConnectionHardLimit = 1 + let config = mutableConfig + + let pool = ConnectionPool( + configuration: config, + idGenerator: ConnectionIDGenerator(), + requestType: ConnectionRequest.self, + keepAliveBehavior: keepAlive, + observabilityDelegate: NoOpConnectionPoolMetrics(connectionIDType: MockConnection.ID.self), + clock: clock + ) { + try await factory.makeConnection(id: $0, for: $1) + } + + try await withThrowingTaskGroup(of: Void.self) { taskGroup in + taskGroup.addTask { + await pool.run() + } + + async let lease1ConnectionAsync = pool.leaseConnection() + + let connection = await factory.nextConnectAttempt { connectionID in + return 1 + } + + let lease1Connection = try await lease1ConnectionAsync + XCTAssert(connection === lease1Connection) + + pool.releaseConnection(lease1Connection) + + // keep alive 1 + + // validate that a keep alive timer and an idle timeout timer is scheduled + var expectedInstants: Set = [.init(keepAliveDuration), .init(config.idleTimeout)] + let deadline1 = await clock.nextTimerScheduled() + print(deadline1) + XCTAssertNotNil(expectedInstants.remove(deadline1)) + let deadline2 = await clock.nextTimerScheduled() + print(deadline2) + XCTAssertNotNil(expectedInstants.remove(deadline2)) + XCTAssert(expectedInstants.isEmpty) + + // move clock forward to keep alive + let newTime = clock.now.advanced(by: keepAliveDuration) + clock.advance(to: newTime) + print("clock advanced to: \(newTime)") + + await keepAlive.nextKeepAlive { keepAliveConnection in + defer { print("keep alive 1 has run") } + XCTAssertTrue(keepAliveConnection === lease1Connection) + return true + } + + // keep alive 2 + + let deadline3 = await clock.nextTimerScheduled() + XCTAssertEqual(deadline3, clock.now.advanced(by: keepAliveDuration)) + print(deadline3) + + // race keep alive vs timeout + clock.advance(to: clock.now.advanced(by: keepAliveDuration)) + + taskGroup.cancelAll() + + for connection in factory.runningConnections { + connection.closeIfClosing() + } + } + } + + func testKeepAliveWorksRacesAgainstShutdown() async throws { + let clock = MockClock() + let factory = MockConnectionFactory() + let keepAliveDuration = Duration.seconds(30) + let keepAlive = MockPingPongBehavior(keepAliveFrequency: keepAliveDuration, connectionType: MockConnection.self) + + var mutableConfig = ConnectionPoolConfiguration() + mutableConfig.minimumConnectionCount = 0 + mutableConfig.maximumConnectionSoftLimit = 1 + mutableConfig.maximumConnectionHardLimit = 1 + let config = mutableConfig + + let pool = ConnectionPool( + configuration: config, + idGenerator: ConnectionIDGenerator(), + requestType: ConnectionRequest.self, + keepAliveBehavior: keepAlive, + observabilityDelegate: NoOpConnectionPoolMetrics(connectionIDType: MockConnection.ID.self), + clock: clock + ) { + try await factory.makeConnection(id: $0, for: $1) + } + + try await withThrowingTaskGroup(of: Void.self) { taskGroup in + taskGroup.addTask { + await pool.run() + } + + async let lease1ConnectionAsync = pool.leaseConnection() + + let connection = await factory.nextConnectAttempt { connectionID in + return 1 + } + + let lease1Connection = try await lease1ConnectionAsync + XCTAssert(connection === lease1Connection) + + pool.releaseConnection(lease1Connection) + + // keep alive 1 + + // validate that a keep alive timer and an idle timeout timer is scheduled + var expectedInstants: Set = [.init(keepAliveDuration), .init(config.idleTimeout)] + let deadline1 = await clock.nextTimerScheduled() + print(deadline1) + XCTAssertNotNil(expectedInstants.remove(deadline1)) + let deadline2 = await clock.nextTimerScheduled() + print(deadline2) + XCTAssertNotNil(expectedInstants.remove(deadline2)) + XCTAssert(expectedInstants.isEmpty) + + clock.advance(to: clock.now.advanced(by: keepAliveDuration)) + + await keepAlive.nextKeepAlive { keepAliveConnection in + defer { print("keep alive 1 has run") } + XCTAssertTrue(keepAliveConnection === lease1Connection) + return true + } + + taskGroup.cancelAll() + print("cancelled") + + for connection in factory.runningConnections { + connection.closeIfClosing() + } + } + } + } diff --git a/Tests/ConnectionPoolModuleTests/Mocks/MockClock.swift b/Tests/ConnectionPoolModuleTests/Mocks/MockClock.swift index 573ff073..cd08d54e 100644 --- a/Tests/ConnectionPoolModuleTests/Mocks/MockClock.swift +++ b/Tests/ConnectionPoolModuleTests/Mocks/MockClock.swift @@ -1,5 +1,6 @@ @testable import _ConnectionPoolModule import Atomics +import DequeModule @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) final class MockClock: Clock { @@ -34,19 +35,19 @@ final class MockClock: Clock { var sleepersHeap: Array - var waitersHeap: Array + var waiters: Deque + var nextDeadlines: Deque init() { self.now = .init(.seconds(0)) self.sleepersHeap = Array() - self.waitersHeap = Array() + self.waiters = Deque() + self.nextDeadlines = Deque() } } private struct Waiter { - var expectedSleepers: Int - - var continuation: CheckedContinuation + var continuation: CheckedContinuation } private struct Sleeper { @@ -77,39 +78,34 @@ final class MockClock: Clock { case cancel } - let action = self.stateBox.withLockedValue { state -> (SleepAction, ArraySlice) in - state.waitersHeap = state.waitersHeap.map { waiter in - var waiter = waiter; waiter.expectedSleepers -= 1; return waiter - } - let slice: ArraySlice - let lastRemainingIndex = state.waitersHeap.firstIndex(where: { $0.expectedSleepers > 0 }) - if let lastRemainingIndex { - slice = state.waitersHeap[0.. (SleepAction, Waiter?) in + let waiter: Waiter? + if let next = state.waiters.popFirst() { + waiter = next } else { - slice = [] + state.nextDeadlines.append(deadline) + waiter = nil } if Task.isCancelled { - return (.cancel, slice) + return (.cancel, waiter) } if state.now >= deadline { - return (.resume, slice) + return (.resume, waiter) } - let newWaiter = Sleeper(id: waiterID, deadline: deadline, continuation: continuation) + let newSleeper = Sleeper(id: waiterID, deadline: deadline, continuation: continuation) if let index = state.sleepersHeap.lastIndex(where: { $0.deadline < deadline }) { - state.sleepersHeap.insert(newWaiter, at: index + 1) + state.sleepersHeap.insert(newSleeper, at: index + 1) + } else if let first = state.sleepersHeap.first, first.deadline > deadline { + state.sleepersHeap.insert(newSleeper, at: 0) } else { - state.sleepersHeap.append(newWaiter) + state.sleepersHeap.append(newSleeper) } - return (.none, slice) + return (.none, waiter) } switch action.0 { @@ -121,9 +117,7 @@ final class MockClock: Clock { break } - for waiter in action.1 { - waiter.continuation.resume() - } + action.1?.continuation.resume(returning: deadline) } } onCancel: { let continuation = self.stateBox.withLockedValue { state -> CheckedContinuation? in @@ -136,28 +130,21 @@ final class MockClock: Clock { } } - func timerScheduled(n: Int = 1) async { - precondition(n >= 1, "At least one new sleep must be awaited") - await withCheckedContinuation { (continuation: CheckedContinuation<(), Never>) in - let result = self.stateBox.withLockedValue { state -> Bool in - let n = n - state.sleepersHeap.count - - if n <= 0 { - return true - } - - let waiter = Waiter(expectedSleepers: n, continuation: continuation) - - if let index = state.waitersHeap.firstIndex(where: { $0.expectedSleepers > n }) { - state.waitersHeap.insert(waiter, at: index) + @discardableResult + func nextTimerScheduled() async -> Instant { + await withCheckedContinuation { (continuation: CheckedContinuation) in + let instant = self.stateBox.withLockedValue { state -> Instant? in + if let scheduled = state.nextDeadlines.popFirst() { + return scheduled } else { - state.waitersHeap.append(waiter) + let waiter = Waiter(continuation: continuation) + state.waiters.append(waiter) + return nil } - return false } - if result { - continuation.resume() + if let instant { + continuation.resume(returning: instant) } } } diff --git a/Tests/ConnectionPoolModuleTests/Mocks/MockConnection.swift b/Tests/ConnectionPoolModuleTests/Mocks/MockConnection.swift index 0fa382f7..49bcc23a 100644 --- a/Tests/ConnectionPoolModuleTests/Mocks/MockConnection.swift +++ b/Tests/ConnectionPoolModuleTests/Mocks/MockConnection.swift @@ -73,92 +73,3 @@ final class MockConnection: PooledConnection, @unchecked Sendable { } } -@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) -final class MockConnectionFactory where Clock.Duration == Duration { - typealias ConnectionIDGenerator = _ConnectionPoolModule.ConnectionIDGenerator - typealias Request = ConnectionRequest - typealias KeepAliveBehavior = MockPingPongBehavior - typealias MetricsDelegate = NoOpConnectionPoolMetrics - typealias ConnectionID = Int - typealias Connection = MockConnection - - let stateBox = NIOLockedValueBox(State()) - - struct State { - var attempts = Deque<(ConnectionID, CheckedContinuation<(MockConnection, UInt16), any Error>)>() - - var waiter = Deque), Never>>() - - var runningConnections = [ConnectionID: Connection]() - } - - var pendingConnectionAttemptsCount: Int { - self.stateBox.withLockedValue { $0.attempts.count } - } - - var runningConnections: [Connection] { - self.stateBox.withLockedValue { Array($0.runningConnections.values) } - } - - func makeConnection( - id: Int, - for pool: ConnectionPool, Int, MockPingPongBehavior, NoOpConnectionPoolMetrics, Clock> - ) async throws -> ConnectionAndMetadata { - // we currently don't support cancellation when creating a connection - let result = try await withCheckedThrowingContinuation { (checkedContinuation: CheckedContinuation<(MockConnection, UInt16), any Error>) in - let waiter = self.stateBox.withLockedValue { state -> (CheckedContinuation<(ConnectionID, CheckedContinuation<(MockConnection, UInt16), any Error>), Never>)? in - if let waiter = state.waiter.popFirst() { - return waiter - } else { - state.attempts.append((id, checkedContinuation)) - return nil - } - } - - if let waiter { - waiter.resume(returning: (id, checkedContinuation)) - } - } - - return .init(connection: result.0, maximalStreamsOnConnection: result.1) - } - - @discardableResult - func nextConnectAttempt(_ closure: (ConnectionID) async throws -> UInt16) async rethrows -> Connection { - let (connectionID, continuation) = await withCheckedContinuation { (continuation: CheckedContinuation<(ConnectionID, CheckedContinuation<(MockConnection, UInt16), any Error>), Never>) in - let attempt = self.stateBox.withLockedValue { state -> (ConnectionID, CheckedContinuation<(MockConnection, UInt16), any Error>)? in - if let attempt = state.attempts.popFirst() { - return attempt - } else { - state.waiter.append(continuation) - return nil - } - } - - if let attempt { - continuation.resume(returning: attempt) - } - } - - do { - let streamCount = try await closure(connectionID) - let connection = MockConnection(id: connectionID) - - connection.onClose { _ in - self.stateBox.withLockedValue { state in - _ = state.runningConnections.removeValue(forKey: connectionID) - } - } - - self.stateBox.withLockedValue { state in - _ = state.runningConnections[connectionID] = connection - } - - continuation.resume(returning: (connection, streamCount)) - return connection - } catch { - continuation.resume(throwing: error) - throw error - } - } -} diff --git a/Tests/ConnectionPoolModuleTests/Mocks/MockConnectionFactory.swift b/Tests/ConnectionPoolModuleTests/Mocks/MockConnectionFactory.swift new file mode 100644 index 00000000..b0c94467 --- /dev/null +++ b/Tests/ConnectionPoolModuleTests/Mocks/MockConnectionFactory.swift @@ -0,0 +1,92 @@ +@testable import _ConnectionPoolModule +import DequeModule + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +final class MockConnectionFactory where Clock.Duration == Duration { + typealias ConnectionIDGenerator = _ConnectionPoolModule.ConnectionIDGenerator + typealias Request = ConnectionRequest + typealias KeepAliveBehavior = MockPingPongBehavior + typealias MetricsDelegate = NoOpConnectionPoolMetrics + typealias ConnectionID = Int + typealias Connection = MockConnection + + let stateBox = NIOLockedValueBox(State()) + + struct State { + var attempts = Deque<(ConnectionID, CheckedContinuation<(MockConnection, UInt16), any Error>)>() + + var waiter = Deque), Never>>() + + var runningConnections = [ConnectionID: Connection]() + } + + var pendingConnectionAttemptsCount: Int { + self.stateBox.withLockedValue { $0.attempts.count } + } + + var runningConnections: [Connection] { + self.stateBox.withLockedValue { Array($0.runningConnections.values) } + } + + func makeConnection( + id: Int, + for pool: ConnectionPool, Int, MockPingPongBehavior, NoOpConnectionPoolMetrics, Clock> + ) async throws -> ConnectionAndMetadata { + // we currently don't support cancellation when creating a connection + let result = try await withCheckedThrowingContinuation { (checkedContinuation: CheckedContinuation<(MockConnection, UInt16), any Error>) in + let waiter = self.stateBox.withLockedValue { state -> (CheckedContinuation<(ConnectionID, CheckedContinuation<(MockConnection, UInt16), any Error>), Never>)? in + if let waiter = state.waiter.popFirst() { + return waiter + } else { + state.attempts.append((id, checkedContinuation)) + return nil + } + } + + if let waiter { + waiter.resume(returning: (id, checkedContinuation)) + } + } + + return .init(connection: result.0, maximalStreamsOnConnection: result.1) + } + + @discardableResult + func nextConnectAttempt(_ closure: (ConnectionID) async throws -> UInt16) async rethrows -> Connection { + let (connectionID, continuation) = await withCheckedContinuation { (continuation: CheckedContinuation<(ConnectionID, CheckedContinuation<(MockConnection, UInt16), any Error>), Never>) in + let attempt = self.stateBox.withLockedValue { state -> (ConnectionID, CheckedContinuation<(MockConnection, UInt16), any Error>)? in + if let attempt = state.attempts.popFirst() { + return attempt + } else { + state.waiter.append(continuation) + return nil + } + } + + if let attempt { + continuation.resume(returning: attempt) + } + } + + do { + let streamCount = try await closure(connectionID) + let connection = MockConnection(id: connectionID) + + connection.onClose { _ in + self.stateBox.withLockedValue { state in + _ = state.runningConnections.removeValue(forKey: connectionID) + } + } + + self.stateBox.withLockedValue { state in + _ = state.runningConnections[connectionID] = connection + } + + continuation.resume(returning: (connection, streamCount)) + return connection + } catch { + continuation.resume(throwing: error) + throw error + } + } +} diff --git a/Tests/ConnectionPoolModuleTests/Mocks/MockPingPongBehaviour.swift b/Tests/ConnectionPoolModuleTests/Mocks/MockPingPongBehaviour.swift index 2ee9b7a0..637f096c 100644 --- a/Tests/ConnectionPoolModuleTests/Mocks/MockPingPongBehaviour.swift +++ b/Tests/ConnectionPoolModuleTests/Mocks/MockPingPongBehaviour.swift @@ -1,14 +1,69 @@ -import _ConnectionPoolModule +@testable import _ConnectionPoolModule +import DequeModule @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) -struct MockPingPongBehavior: ConnectionKeepAliveBehavior { +final class MockPingPongBehavior: ConnectionKeepAliveBehavior { let keepAliveFrequency: Duration? - init(keepAliveFrequency: Duration?) { + let stateBox = NIOLockedValueBox(State()) + + struct State { + var runs = Deque<(Connection, CheckedContinuation)>() + + var waiter = Deque), Never>>() + } + + init(keepAliveFrequency: Duration?, connectionType: Connection.Type) { self.keepAliveFrequency = keepAliveFrequency } - func runKeepAlive(for connection: MockConnection) async throws { - preconditionFailure() + func runKeepAlive(for connection: Connection) async throws { + precondition(self.keepAliveFrequency != nil) + + // we currently don't support cancellation when creating a connection + let success = try await withCheckedThrowingContinuation { (checkedContinuation: CheckedContinuation) -> () in + let waiter = self.stateBox.withLockedValue { state -> (CheckedContinuation<(Connection, CheckedContinuation), Never>)? in + if let waiter = state.waiter.popFirst() { + return waiter + } else { + state.runs.append((connection, checkedContinuation)) + return nil + } + } + + if let waiter { + waiter.resume(returning: (connection, checkedContinuation)) + } + } + + precondition(success) + } + + @discardableResult + func nextKeepAlive(_ closure: (Connection) async throws -> Bool) async rethrows -> Connection { + let (connection, continuation) = await withCheckedContinuation { (continuation: CheckedContinuation<(Connection, CheckedContinuation), Never>) in + let run = self.stateBox.withLockedValue { state -> (Connection, CheckedContinuation)? in + if let run = state.runs.popFirst() { + return run + } else { + state.waiter.append(continuation) + return nil + } + } + + if let run { + continuation.resume(returning: run) + } + } + + do { + let success = try await closure(connection) + + continuation.resume(returning: success) + return connection + } catch { + continuation.resume(throwing: error) + throw error + } } } diff --git a/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionStateTests.swift b/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionStateTests.swift index 7751837e..bc4c2c4b 100644 --- a/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionStateTests.swift +++ b/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionStateTests.swift @@ -257,7 +257,7 @@ final class PoolStateMachine_ConnectionStateTests: XCTestCase { XCTAssertNil(state.timerScheduled(keepAliveTimer, cancelContinuation: keepAliveTimerCancellationToken)) XCTAssertNil(state.timerScheduled(idleTimer, cancelContinuation: idleTimerCancellationToken)) - XCTAssertEqual(state.closeIfIdle(), .init(connection: connection, previousConnectionState: .idle, cancelTimers: [keepAliveTimerCancellationToken, idleTimerCancellationToken], usedStreams: 0, maxStreams: 1)) + XCTAssertEqual(state.closeIfIdle(), .init(connection: connection, previousConnectionState: .idle, cancelTimers: [keepAliveTimerCancellationToken, idleTimerCancellationToken], usedStreams: 0, maxStreams: 1, runningKeepAlive: false)) XCTAssertEqual(state.runKeepAliveIfIdle(reducesAvailableStreams: true), .none) } From c41f7e217e09c51a4453019b2875ecb82b69df3d Mon Sep 17 00:00:00 2001 From: Gwynne Raskind Date: Fri, 10 Nov 2023 12:33:30 -0600 Subject: [PATCH 049/106] Update README.md --- README.md | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index 6f289673..ef1dc4ec 100644 --- a/README.md +++ b/README.md @@ -7,22 +7,22 @@

- Documentation + Documentation - MIT License + MIT License - Continuous Integration + Continuous Integration - Swift 5.7 - 5.9 + Swift 5.7 + - SSWG Incubation Level: Graduated + SSWG Incubation Level: Graduated

-
+ 🐘 Non-blocking, event-driven Swift client for PostgreSQL built on [SwiftNIO]. Features: @@ -190,9 +190,7 @@ Please see [SECURITY.md] for details on the security process. [`PostgresRowSequence`]: https://api.vapor.codes/postgresnio/documentation/postgresnio/postgresrowsequence [`PostgresDecodable`]: https://api.vapor.codes/postgresnio/documentation/postgresnio/postgresdecodable [`PostgresEncodable`]: https://api.vapor.codes/postgresnio/documentation/postgresnio/postgresencodable - -[PostgresKit]: https://github.com/vapor/postgres-kit - [SwiftNIO]: https://github.com/apple/swift-nio +[PostgresKit]: https://github.com/vapor/postgres-kit [SwiftLog]: https://github.com/apple/swift-log [`Logger`]: https://apple.github.io/swift-log/docs/current/Logging/Structs/Logger.html From f0bfba793eb626cda98e456a7f1f2c1ef13a983a Mon Sep 17 00:00:00 2001 From: Gwynne Raskind Date: Fri, 10 Nov 2023 12:34:36 -0600 Subject: [PATCH 050/106] Temporarily disable nightly/main CI --- .github/workflows/test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index cc34ddcd..fe4aa185 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -22,7 +22,7 @@ jobs: - swift:5.8-jammy - swift:5.9-jammy - swiftlang/swift:nightly-5.10-jammy - - swiftlang/swift:nightly-main-jammy + #- swiftlang/swift:nightly-main-jammy include: - swift-image: swift:5.9-jammy code-coverage: true From d5d16e3230cc1d86dde3fd9e8266422d27a440b6 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Sun, 12 Nov 2023 12:17:09 +0100 Subject: [PATCH 051/106] Test cancel connection request (#439) --- .../ConnectionPoolTests.swift | 60 +++++++++- .../Utils/Waiter.swift | 109 ++++++++++++++++++ 2 files changed, 168 insertions(+), 1 deletion(-) create mode 100644 Tests/ConnectionPoolModuleTests/Utils/Waiter.swift diff --git a/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift b/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift index 57980711..4d4cac95 100644 --- a/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift +++ b/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift @@ -368,6 +368,64 @@ final class ConnectionPoolTests: XCTestCase { } } -} + func testCancelConnectionRequestWorks() async throws { + let clock = MockClock() + let factory = MockConnectionFactory() + let keepAliveDuration = Duration.seconds(30) + let keepAlive = MockPingPongBehavior(keepAliveFrequency: keepAliveDuration, connectionType: MockConnection.self) + + var mutableConfig = ConnectionPoolConfiguration() + mutableConfig.minimumConnectionCount = 0 + mutableConfig.maximumConnectionSoftLimit = 4 + mutableConfig.maximumConnectionHardLimit = 4 + mutableConfig.idleTimeout = .seconds(10) + let config = mutableConfig + + let pool = ConnectionPool( + configuration: config, + idGenerator: ConnectionIDGenerator(), + requestType: ConnectionRequest.self, + keepAliveBehavior: keepAlive, + observabilityDelegate: NoOpConnectionPoolMetrics(connectionIDType: MockConnection.ID.self), + clock: clock + ) { + try await factory.makeConnection(id: $0, for: $1) + } + try await withThrowingTaskGroup(of: Void.self) { taskGroup in + taskGroup.addTask { + await pool.run() + } + + let leaseTask = Task { + _ = try await pool.leaseConnection() + } + + let connectionAttemptWaiter = Waiter(of: Void.self) + + taskGroup.addTask { + try await factory.nextConnectAttempt { connectionID in + connectionAttemptWaiter.yield(value: ()) + throw CancellationError() + } + } + + try await connectionAttemptWaiter.result + leaseTask.cancel() + + let taskResult = await leaseTask.result + switch taskResult { + case .success: + XCTFail("Expected task failure") + case .failure(let failure): + XCTAssertEqual(failure as? ConnectionPoolError, .requestCancelled) + } + + taskGroup.cancelAll() + for connection in factory.runningConnections { + connection.closeIfClosing() + } + } + } +} diff --git a/Tests/ConnectionPoolModuleTests/Utils/Waiter.swift b/Tests/ConnectionPoolModuleTests/Utils/Waiter.swift new file mode 100644 index 00000000..12cf90cc --- /dev/null +++ b/Tests/ConnectionPoolModuleTests/Utils/Waiter.swift @@ -0,0 +1,109 @@ +import Atomics +@testable import _ConnectionPoolModule + +final class Waiter: Sendable { + struct State: Sendable { + + var result: Swift.Result? = nil + var continuations: [(Int, CheckedContinuation)] = [] + + } + + let waiterID = ManagedAtomic(0) + let stateBox: NIOLockedValueBox = NIOLockedValueBox(State()) + + init(of: Result.Type) {} + + enum GetAction { + case fail(any Error) + case succeed(Result) + case none + } + + var result: Result { + get async throws { + let waiterID = self.waiterID.loadThenWrappingIncrement(ordering: .relaxed) + + return try await withTaskCancellationHandler { + return try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in + let action = self.stateBox.withLockedValue { state -> GetAction in + if Task.isCancelled { + return .fail(CancellationError()) + } + + switch state.result { + case .none: + state.continuations.append((waiterID, continuation)) + return .none + + case .success(let result): + return .succeed(result) + + case .failure(let error): + return .fail(error) + } + } + + switch action { + case .fail(let error): + continuation.resume(throwing: error) + + case .succeed(let result): + continuation.resume(returning: result) + + case .none: + break + } + } + } onCancel: { + let cont = self.stateBox.withLockedValue { state -> CheckedContinuation? in + guard state.result == nil else { return nil } + + guard let contIndex = state.continuations.firstIndex(where: { $0.0 == waiterID }) else { + return nil + } + let (_, continuation) = state.continuations.remove(at: contIndex) + return continuation + } + + cont?.resume(throwing: CancellationError()) + } + } + } + + func yield(value: Result) { + let continuations = self.stateBox.withLockedValue { state in + guard state.result == nil else { + return [(Int, CheckedContinuation)]().lazy.map(\.1) + } + state.result = .success(value) + + let continuations = state.continuations + state.continuations = [] + + return continuations.lazy.map(\.1) + } + + for continuation in continuations { + continuation.resume(returning: value) + } + } + + func yield(error: any Error) { + let continuations = self.stateBox.withLockedValue { state in + guard state.result == nil else { + return [(Int, CheckedContinuation)]().lazy.map(\.1) + } + state.result = .failure(error) + + let continuations = state.continuations + state.continuations = [] + + return continuations.lazy.map(\.1) + } + + for continuation in continuations { + continuation.resume(throwing: error) + } + } +} From e1781633a8a843b8901ab8b71cdfdf80fad690af Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Mon, 13 Nov 2023 11:13:29 +0100 Subject: [PATCH 052/106] Add test to lease multiple connections at once (#440) - Add test to lease multiple connections at once - Rename `Waiter` to `Future` - Rename `Waiter.Result` to `Future.Success` --- .../ConnectionPoolTests.swift | 86 ++++++++++++++++++- .../Mocks/MockConnectionFactory.swift | 2 +- .../Utils/{Waiter.swift => Future.swift} | 25 +++--- 3 files changed, 99 insertions(+), 14 deletions(-) rename Tests/ConnectionPoolModuleTests/Utils/{Waiter.swift => Future.swift} (77%) diff --git a/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift b/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift index 4d4cac95..a4c2cde7 100644 --- a/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift +++ b/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift @@ -401,7 +401,7 @@ final class ConnectionPoolTests: XCTestCase { _ = try await pool.leaseConnection() } - let connectionAttemptWaiter = Waiter(of: Void.self) + let connectionAttemptWaiter = Future(of: Void.self) taskGroup.addTask { try await factory.nextConnectAttempt { connectionID in @@ -410,7 +410,7 @@ final class ConnectionPoolTests: XCTestCase { } } - try await connectionAttemptWaiter.result + try await connectionAttemptWaiter.success leaseTask.cancel() let taskResult = await leaseTask.result @@ -427,5 +427,87 @@ final class ConnectionPoolTests: XCTestCase { } } } + + func testLeasingMultipleConnectionsAtOnceWorks() async throws { + let clock = MockClock() + let factory = MockConnectionFactory() + let keepAliveDuration = Duration.seconds(30) + let keepAlive = MockPingPongBehavior(keepAliveFrequency: keepAliveDuration, connectionType: MockConnection.self) + + var mutableConfig = ConnectionPoolConfiguration() + mutableConfig.minimumConnectionCount = 4 + mutableConfig.maximumConnectionSoftLimit = 4 + mutableConfig.maximumConnectionHardLimit = 4 + mutableConfig.idleTimeout = .seconds(10) + let config = mutableConfig + + let pool = ConnectionPool( + configuration: config, + idGenerator: ConnectionIDGenerator(), + requestType: ConnectionFuture.self, + keepAliveBehavior: keepAlive, + observabilityDelegate: NoOpConnectionPoolMetrics(connectionIDType: MockConnection.ID.self), + clock: clock + ) { + try await factory.makeConnection(id: $0, for: $1) + } + + try await withThrowingTaskGroup(of: Void.self) { taskGroup in + taskGroup.addTask { + await pool.run() + } + + // create 4 persisted connections + for _ in 0..<4 { + await factory.nextConnectAttempt { connectionID in + return 1 + } + } + + // create 4 connection requests + let requests = (0..<4).map { ConnectionFuture(id: $0) } + + // lease 4 connections at once + pool.leaseConnections(requests) + var connections = [MockConnection]() + + for request in requests { + let connection = try await request.future.success + connections.append(connection) + } + + // Ensure that we got 4 distinct connections + XCTAssertEqual(Set(connections.lazy.map(\.id)).count, 4) + + // release all 4 leased connections + for connection in connections { + pool.releaseConnection(connection) + } + + // shutdown + taskGroup.cancelAll() + for connection in factory.runningConnections { + connection.closeIfClosing() + } + } + } } +struct ConnectionFuture: ConnectionRequestProtocol { + let id: Int + let future: Future + + init(id: Int) { + self.id = id + self.future = Future(of: MockConnection.self) + } + + func complete(with result: Result) { + switch result { + case .success(let success): + self.future.yield(value: success) + case .failure(let failure): + self.future.yield(error: failure) + } + } +} diff --git a/Tests/ConnectionPoolModuleTests/Mocks/MockConnectionFactory.swift b/Tests/ConnectionPoolModuleTests/Mocks/MockConnectionFactory.swift index b0c94467..eec2e7c3 100644 --- a/Tests/ConnectionPoolModuleTests/Mocks/MockConnectionFactory.swift +++ b/Tests/ConnectionPoolModuleTests/Mocks/MockConnectionFactory.swift @@ -30,7 +30,7 @@ final class MockConnectionFactory where Clock.Duratio func makeConnection( id: Int, - for pool: ConnectionPool, Int, MockPingPongBehavior, NoOpConnectionPoolMetrics, Clock> + for pool: ConnectionPool, NoOpConnectionPoolMetrics, Clock> ) async throws -> ConnectionAndMetadata { // we currently don't support cancellation when creating a connection let result = try await withCheckedThrowingContinuation { (checkedContinuation: CheckedContinuation<(MockConnection, UInt16), any Error>) in diff --git a/Tests/ConnectionPoolModuleTests/Utils/Waiter.swift b/Tests/ConnectionPoolModuleTests/Utils/Future.swift similarity index 77% rename from Tests/ConnectionPoolModuleTests/Utils/Waiter.swift rename to Tests/ConnectionPoolModuleTests/Utils/Future.swift index 12cf90cc..2bee3216 100644 --- a/Tests/ConnectionPoolModuleTests/Utils/Waiter.swift +++ b/Tests/ConnectionPoolModuleTests/Utils/Future.swift @@ -1,31 +1,34 @@ import Atomics @testable import _ConnectionPoolModule -final class Waiter: Sendable { +/// This is a `Future` type that shall make writing tests a bit simpler. I'm well aware, that this is a pattern +/// that should not be embraced with structured concurrency. However writing all tests in full structured +/// concurrency is an effort, that isn't worth the endgoals in my view. +final class Future: Sendable { struct State: Sendable { - var result: Swift.Result? = nil - var continuations: [(Int, CheckedContinuation)] = [] + var result: Swift.Result? = nil + var continuations: [(Int, CheckedContinuation)] = [] } let waiterID = ManagedAtomic(0) let stateBox: NIOLockedValueBox = NIOLockedValueBox(State()) - init(of: Result.Type) {} + init(of: Success.Type) {} enum GetAction { case fail(any Error) - case succeed(Result) + case succeed(Success) case none } - var result: Result { + var success: Success { get async throws { let waiterID = self.waiterID.loadThenWrappingIncrement(ordering: .relaxed) return try await withTaskCancellationHandler { - return try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in + return try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in let action = self.stateBox.withLockedValue { state -> GetAction in if Task.isCancelled { return .fail(CancellationError()) @@ -56,7 +59,7 @@ final class Waiter: Sendable { } } } onCancel: { - let cont = self.stateBox.withLockedValue { state -> CheckedContinuation? in + let cont = self.stateBox.withLockedValue { state -> CheckedContinuation? in guard state.result == nil else { return nil } guard let contIndex = state.continuations.firstIndex(where: { $0.0 == waiterID }) else { @@ -71,10 +74,10 @@ final class Waiter: Sendable { } } - func yield(value: Result) { + func yield(value: Success) { let continuations = self.stateBox.withLockedValue { state in guard state.result == nil else { - return [(Int, CheckedContinuation)]().lazy.map(\.1) + return [(Int, CheckedContinuation)]().lazy.map(\.1) } state.result = .success(value) @@ -92,7 +95,7 @@ final class Waiter: Sendable { func yield(error: any Error) { let continuations = self.stateBox.withLockedValue { state in guard state.result == nil else { - return [(Int, CheckedContinuation)]().lazy.map(\.1) + return [(Int, CheckedContinuation)]().lazy.map(\.1) } state.result = .failure(error) From dc94503944f5f0a6b244efacd0ceb92d1e52cdb8 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Tue, 14 Nov 2023 10:12:42 +0100 Subject: [PATCH 053/106] Add Test: Lease connection after shutdown has started fails (#441) --- .../ConnectionPoolTests.swift | 116 ++++++++++++++++++ .../Mocks/MockConnection.swift | 66 +++++++--- 2 files changed, 165 insertions(+), 17 deletions(-) diff --git a/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift b/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift index a4c2cde7..d4388893 100644 --- a/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift +++ b/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift @@ -491,6 +491,122 @@ final class ConnectionPoolTests: XCTestCase { } } } + + func testLeasingConnectionAfterShutdownIsInvokedFails() async throws { + let clock = MockClock() + let factory = MockConnectionFactory() + let keepAliveDuration = Duration.seconds(30) + let keepAlive = MockPingPongBehavior(keepAliveFrequency: keepAliveDuration, connectionType: MockConnection.self) + + var mutableConfig = ConnectionPoolConfiguration() + mutableConfig.minimumConnectionCount = 4 + mutableConfig.maximumConnectionSoftLimit = 4 + mutableConfig.maximumConnectionHardLimit = 4 + mutableConfig.idleTimeout = .seconds(10) + let config = mutableConfig + + let pool = ConnectionPool( + configuration: config, + idGenerator: ConnectionIDGenerator(), + requestType: ConnectionRequest.self, + keepAliveBehavior: keepAlive, + observabilityDelegate: NoOpConnectionPoolMetrics(connectionIDType: MockConnection.ID.self), + clock: clock + ) { + try await factory.makeConnection(id: $0, for: $1) + } + + try await withThrowingTaskGroup(of: Void.self) { taskGroup in + taskGroup.addTask { + await pool.run() + } + + // create 4 persisted connections + for _ in 0..<4 { + await factory.nextConnectAttempt { connectionID in + return 1 + } + } + + // shutdown + taskGroup.cancelAll() + + do { + _ = try await pool.leaseConnection() + XCTFail("Expected a failure") + } catch { + print("failed") + XCTAssertEqual(error as? ConnectionPoolError, .poolShutdown) + } + + print("will close connections: \(factory.runningConnections)") + for connection in factory.runningConnections { + try await connection.signalToClose + connection.closeIfClosing() + } + } + } + + func testLeasingConnectionsAfterShutdownIsInvokedFails() async throws { + let clock = MockClock() + let factory = MockConnectionFactory() + let keepAliveDuration = Duration.seconds(30) + let keepAlive = MockPingPongBehavior(keepAliveFrequency: keepAliveDuration, connectionType: MockConnection.self) + + var mutableConfig = ConnectionPoolConfiguration() + mutableConfig.minimumConnectionCount = 4 + mutableConfig.maximumConnectionSoftLimit = 4 + mutableConfig.maximumConnectionHardLimit = 4 + mutableConfig.idleTimeout = .seconds(10) + let config = mutableConfig + + let pool = ConnectionPool( + configuration: config, + idGenerator: ConnectionIDGenerator(), + requestType: ConnectionFuture.self, + keepAliveBehavior: keepAlive, + observabilityDelegate: NoOpConnectionPoolMetrics(connectionIDType: MockConnection.ID.self), + clock: clock + ) { + try await factory.makeConnection(id: $0, for: $1) + } + + try await withThrowingTaskGroup(of: Void.self) { taskGroup in + taskGroup.addTask { + await pool.run() + } + + // create 4 persisted connections + for _ in 0..<4 { + await factory.nextConnectAttempt { connectionID in + return 1 + } + } + + // shutdown + taskGroup.cancelAll() + + // create 4 connection requests + let requests = (0..<4).map { ConnectionFuture(id: $0) } + + // lease 4 connections at once + pool.leaseConnections(requests) + + for request in requests { + do { + _ = try await request.future.success + XCTFail("Expected a failure") + } catch { + XCTAssertEqual(error as? ConnectionPoolError, .poolShutdown) + } + } + + for connection in factory.runningConnections { + try await connection.signalToClose + connection.closeIfClosing() + } + } + } } struct ConnectionFuture: ConnectionRequestProtocol { diff --git a/Tests/ConnectionPoolModuleTests/Mocks/MockConnection.swift b/Tests/ConnectionPoolModuleTests/Mocks/MockConnection.swift index 49bcc23a..f826ea04 100644 --- a/Tests/ConnectionPoolModuleTests/Mocks/MockConnection.swift +++ b/Tests/ConnectionPoolModuleTests/Mocks/MockConnection.swift @@ -2,38 +2,59 @@ import DequeModule @testable import _ConnectionPoolModule // Sendability enforced through the lock -final class MockConnection: PooledConnection, @unchecked Sendable { +final class MockConnection: PooledConnection, Sendable { typealias ID = Int let id: ID private enum State { - case running([@Sendable ((any Error)?) -> ()]) + case running([CheckedContinuation], [@Sendable ((any Error)?) -> ()]) case closing([@Sendable ((any Error)?) -> ()]) case closed } - private let lock = NIOLock() - private var _state = State.running([]) + private let lock: NIOLockedValueBox = NIOLockedValueBox(.running([], [])) init(id: Int) { self.id = id } + var signalToClose: Void { + get async throws { + try await withCheckedThrowingContinuation { continuation in + let runRightAway = self.lock.withLockedValue { state -> Bool in + switch state { + case .running(var continuations, let callbacks): + continuations.append(continuation) + state = .running(continuations, callbacks) + return false + + case .closing, .closed: + return true + } + } + + if runRightAway { + continuation.resume() + } + } + } + } + func onClose(_ closure: @escaping @Sendable ((any Error)?) -> ()) { - let enqueued = self.lock.withLock { () -> Bool in - switch self._state { + let enqueued = self.lock.withLockedValue { state -> Bool in + switch state { case .closed: return false - case .running(var callbacks): + case .running(let continuations, var callbacks): callbacks.append(closure) - self._state = .running(callbacks) + state = .running(continuations, callbacks) return true case .closing(var callbacks): callbacks.append(closure) - self._state = .closing(callbacks) + state = .closing(callbacks) return true } } @@ -44,25 +65,30 @@ final class MockConnection: PooledConnection, @unchecked Sendable { } func close() { - self.lock.withLock { - switch self._state { - case .running(let callbacks): - self._state = .closing(callbacks) + let continuations = self.lock.withLockedValue { state -> [CheckedContinuation] in + switch state { + case .running(let continuations, let callbacks): + state = .closing(callbacks) + return continuations case .closing, .closed: - break + return [] } } + + for continuation in continuations { + continuation.resume() + } } func closeIfClosing() { - let callbacks = self.lock.withLock { () -> [@Sendable ((any Error)?) -> ()] in - switch self._state { + let callbacks = self.lock.withLockedValue { state -> [@Sendable ((any Error)?) -> ()] in + switch state { case .running, .closed: return [] case .closing(let callbacks): - self._state = .closed + state = .closed return callbacks } } @@ -73,3 +99,9 @@ final class MockConnection: PooledConnection, @unchecked Sendable { } } +extension MockConnection: CustomStringConvertible { + var description: String { + let state = self.lock.withLockedValue { $0 } + return "MockConnection(id: \(self.id), state: \(state))" + } +} From 54f491c9b9a1d0a4f099d21a473b630bcc89d551 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Tue, 14 Nov 2023 15:02:23 +0100 Subject: [PATCH 054/106] Add support for multiple streams (#442) --- .../ConnectionPoolModule/ConnectionPool.swift | 6 +- .../PoolStateMachine+ConnectionGroup.swift | 48 +++++- .../PoolStateMachine+ConnectionState.swift | 47 ++++++ .../PoolStateMachine.swift | 46 +++++- .../ConnectionPoolTests.swift | 142 ++++++++++++++++++ 5 files changed, 280 insertions(+), 9 deletions(-) diff --git a/Sources/ConnectionPoolModule/ConnectionPool.swift b/Sources/ConnectionPoolModule/ConnectionPool.swift index e9c9c4c9..ec865979 100644 --- a/Sources/ConnectionPoolModule/ConnectionPool.swift +++ b/Sources/ConnectionPoolModule/ConnectionPool.swift @@ -265,8 +265,10 @@ public final class ConnectionPool< } - public func connection(_ connection: Connection, didReceiveNewMaxStreamSetting: UInt16) { - + public func connectionReceivedNewMaxStreamSetting(_ connection: Connection, newMaxStreamSetting maxStreams: UInt16) { + self.modifyStateAndRunActions { state in + state.stateMachine.connectionReceivedNewMaxStreamSetting(connection.id, newMaxStreamSetting: maxStreams) + } } public func run() async { diff --git a/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionGroup.swift b/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionGroup.swift index fabc3009..0dbca86f 100644 --- a/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionGroup.swift +++ b/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionGroup.swift @@ -256,6 +256,50 @@ extension PoolStateMachine { return self.connections[index].timerScheduled(timer, cancelContinuation: cancelContinuation) } + // MARK: Changes at runtime + + @usableFromInline + struct NewMaxStreamInfo { + + @usableFromInline + var index: Int + + @usableFromInline + var newMaxStreams: UInt16 + + @usableFromInline + var oldMaxStreams: UInt16 + + @usableFromInline + var usedStreams: UInt16 + + @inlinable + init(index: Int, info: ConnectionState.NewMaxStreamInfo) { + self.index = index + self.newMaxStreams = info.newMaxStreams + self.oldMaxStreams = info.oldMaxStreams + self.usedStreams = info.usedStreams + } + } + + @inlinable + mutating func connectionReceivedNewMaxStreamSetting( + _ connectionID: ConnectionID, + newMaxStreamSetting maxStreams: UInt16 + ) -> NewMaxStreamInfo? { + guard let index = self.connections.firstIndex(where: { $0.id == connectionID }) else { + return nil + } + + guard let info = self.connections[index].newMaxStreamSetting(maxStreams) else { + return nil + } + + self.stats.availableStreams += maxStreams - info.oldMaxStreams + + return NewMaxStreamInfo(index: index, info: info) + } + // MARK: Leasing and releasing /// Lease a connection, if an idle connection is available. @@ -424,9 +468,9 @@ extension PoolStateMachine { /// Closes the connection at the given index. @inlinable - mutating func closeConnectionIfIdle(at index: Int) -> CloseAction { + mutating func closeConnectionIfIdle(at index: Int) -> CloseAction? { guard let closeAction = self.connections[index].closeIfIdle() else { - preconditionFailure("Invalid state: \(self)") + return nil // apparently the connection isn't idle } self.stats.idle -= 1 diff --git a/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionState.swift b/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionState.swift index 94196a09..98755ff9 100644 --- a/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionState.swift +++ b/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionState.swift @@ -195,6 +195,53 @@ extension PoolStateMachine { } } + @usableFromInline + struct NewMaxStreamInfo { + @usableFromInline + var newMaxStreams: UInt16 + + @usableFromInline + var oldMaxStreams: UInt16 + + @usableFromInline + var usedStreams: UInt16 + + @inlinable + init(newMaxStreams: UInt16, oldMaxStreams: UInt16, usedStreams: UInt16) { + self.newMaxStreams = newMaxStreams + self.oldMaxStreams = oldMaxStreams + self.usedStreams = usedStreams + } + } + + @inlinable + mutating func newMaxStreamSetting(_ newMaxStreams: UInt16) -> NewMaxStreamInfo? { + switch self.state { + case .starting, .backingOff: + preconditionFailure("Invalid state: \(self.state)") + + case .idle(let connection, let oldMaxStreams, let keepAlive, idleTimer: let idleTimer): + self.state = .idle(connection, maxStreams: newMaxStreams, keepAlive: keepAlive, idleTimer: idleTimer) + return NewMaxStreamInfo( + newMaxStreams: newMaxStreams, + oldMaxStreams: oldMaxStreams, + usedStreams: keepAlive.usedStreams + ) + + case .leased(let connection, let usedStreams, let oldMaxStreams, let keepAlive): + self.state = .leased(connection, usedStreams: usedStreams, maxStreams: newMaxStreams, keepAlive: keepAlive) + return NewMaxStreamInfo( + newMaxStreams: newMaxStreams, + oldMaxStreams: oldMaxStreams, + usedStreams: usedStreams + keepAlive.usedStreams + ) + + case .closing, .closed: + return nil + } + } + + @inlinable mutating func parkConnection(scheduleKeepAliveTimer: Bool, scheduleIdleTimeoutTimer: Bool) -> Max2Sequence { var keepAliveTimer: ConnectionTimer? diff --git a/Sources/ConnectionPoolModule/PoolStateMachine.swift b/Sources/ConnectionPoolModule/PoolStateMachine.swift index 4484e405..6671460a 100644 --- a/Sources/ConnectionPoolModule/PoolStateMachine.swift +++ b/Sources/ConnectionPoolModule/PoolStateMachine.swift @@ -262,6 +262,39 @@ struct PoolStateMachine< } } + @inlinable + mutating func connectionReceivedNewMaxStreamSetting( + _ connection: ConnectionID, + newMaxStreamSetting maxStreams: UInt16 + ) -> Action { + guard let info = self.connections.connectionReceivedNewMaxStreamSetting(connection, newMaxStreamSetting: maxStreams) else { + return .none() + } + + let waitingRequests = self.requestQueue.count + + guard waitingRequests > 0 else { + return .none() + } + + // the only thing we can do if we receive a new max stream setting is check if the new stream + // setting is higher and then dequeue some waiting requests + + guard info.newMaxStreams > info.oldMaxStreams && info.newMaxStreams > info.usedStreams else { + return .none() + } + + let leaseStreams = min(info.newMaxStreams - info.oldMaxStreams, info.newMaxStreams - info.usedStreams, UInt16(clamping: waitingRequests)) + let requests = self.requestQueue.pop(max: leaseStreams) + precondition(Int(leaseStreams) == requests.count) + let leaseResult = self.connections.leaseConnection(at: info.index, streams: leaseStreams) + + return .init( + request: .leaseConnection(requests, leaseResult.connection), + connection: .cancelTimers(.init(leaseResult.timersToCancel)) + ) + } + @inlinable mutating func timerScheduled(_ timer: Timer, cancelContinuation: TimerCancellationToken) -> TimerCancellationToken? { self.connections.timerScheduled(timer.underlying, cancelContinuation: cancelContinuation) @@ -445,11 +478,14 @@ struct PoolStateMachine< } case .overflow: - let closeAction = self.connections.closeConnectionIfIdle(at: index) - return .init( - request: .none, - connection: .closeConnection(closeAction.connection, closeAction.timersToCancel) - ) + if let closeAction = self.connections.closeConnectionIfIdle(at: index) { + return .init( + request: .none, + connection: .closeConnection(closeAction.connection, closeAction.timersToCancel) + ) + } else { + return .none() + } } } diff --git a/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift b/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift index d4388893..0ff2bdf7 100644 --- a/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift +++ b/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift @@ -607,6 +607,148 @@ final class ConnectionPoolTests: XCTestCase { } } } + + func testLeasingMultipleStreamsFromOneConnectionWorks() async throws { + let clock = MockClock() + let factory = MockConnectionFactory() + let keepAliveDuration = Duration.seconds(30) + let keepAlive = MockPingPongBehavior(keepAliveFrequency: keepAliveDuration, connectionType: MockConnection.self) + + var mutableConfig = ConnectionPoolConfiguration() + mutableConfig.minimumConnectionCount = 0 + mutableConfig.maximumConnectionSoftLimit = 1 + mutableConfig.maximumConnectionHardLimit = 10 + mutableConfig.idleTimeout = .seconds(10) + let config = mutableConfig + + let pool = ConnectionPool( + configuration: config, + idGenerator: ConnectionIDGenerator(), + requestType: ConnectionFuture.self, + keepAliveBehavior: keepAlive, + observabilityDelegate: NoOpConnectionPoolMetrics(connectionIDType: MockConnection.ID.self), + clock: clock + ) { + try await factory.makeConnection(id: $0, for: $1) + } + + try await withThrowingTaskGroup(of: Void.self) { taskGroup in + taskGroup.addTask { + await pool.run() + } + + // create 4 connection requests + let requests = (0..<10).map { ConnectionFuture(id: $0) } + pool.leaseConnections(requests) + var connections = [MockConnection]() + + await factory.nextConnectAttempt { connectionID in + return 10 + } + + for request in requests { + let connection = try await request.future.success + connections.append(connection) + } + + // Ensure that all requests got the same connection + XCTAssertEqual(Set(connections.lazy.map(\.id)).count, 1) + + // release all 10 leased streams + for connection in connections { + pool.releaseConnection(connection) + } + + for _ in 0..<9 { + _ = try? await factory.nextConnectAttempt { connectionID in + throw CancellationError() + } + } + + // shutdown + taskGroup.cancelAll() + for connection in factory.runningConnections { + connection.closeIfClosing() + } + } + } + + func testIncreasingAvailableStreamsWorks() async throws { + let clock = MockClock() + let factory = MockConnectionFactory() + let keepAliveDuration = Duration.seconds(30) + let keepAlive = MockPingPongBehavior(keepAliveFrequency: keepAliveDuration, connectionType: MockConnection.self) + + var mutableConfig = ConnectionPoolConfiguration() + mutableConfig.minimumConnectionCount = 0 + mutableConfig.maximumConnectionSoftLimit = 1 + mutableConfig.maximumConnectionHardLimit = 1 + mutableConfig.idleTimeout = .seconds(10) + let config = mutableConfig + + let pool = ConnectionPool( + configuration: config, + idGenerator: ConnectionIDGenerator(), + requestType: ConnectionFuture.self, + keepAliveBehavior: keepAlive, + observabilityDelegate: NoOpConnectionPoolMetrics(connectionIDType: MockConnection.ID.self), + clock: clock + ) { + try await factory.makeConnection(id: $0, for: $1) + } + + try await withThrowingTaskGroup(of: Void.self) { taskGroup in + taskGroup.addTask { + await pool.run() + } + + // create 4 connection requests + var requests = (0..<21).map { ConnectionFuture(id: $0) } + pool.leaseConnections(requests) + var connections = [MockConnection]() + + await factory.nextConnectAttempt { connectionID in + return 1 + } + + let connection = try await requests.first!.future.success + connections.append(connection) + requests.removeFirst() + + pool.connectionReceivedNewMaxStreamSetting(connection, newMaxStreamSetting: 21) + + for (index, request) in requests.enumerated() { + let connection = try await request.future.success + connections.append(connection) + } + + // Ensure that all requests got the same connection + XCTAssertEqual(Set(connections.lazy.map(\.id)).count, 1) + + requests = (22..<42).map { ConnectionFuture(id: $0) } + pool.leaseConnections(requests) + + // release all 21 leased streams in a single call + pool.releaseConnection(connection, streams: 21) + + // ensure all 20 new requests got fulfilled + for request in requests { + let connection = try await request.future.success + connections.append(connection) + } + + // release all 20 leased streams one by one + for _ in requests { + pool.releaseConnection(connection, streams: 1) + } + + // shutdown + taskGroup.cancelAll() + for connection in factory.runningConnections { + connection.closeIfClosing() + } + } + } } struct ConnectionFuture: ConnectionRequestProtocol { From e60e49507411fbf187fcf9f74a4596d68f3651c9 Mon Sep 17 00:00:00 2001 From: Timo <38291523+lovetodream@users.noreply.github.com> Date: Tue, 12 Dec 2023 16:28:16 +0100 Subject: [PATCH 055/106] Fix crash in PoolStateMachine+ConnectionGroup when closing connection while keepAlive is running (#444) Fixes #443. Co-authored-by: Gwynne Raskind Co-authored-by: Fabian Fett --- .github/workflows/test.yml | 16 ++- .../ConnectionPoolModule/ConnectionPool.swift | 2 +- .../PoolStateMachine+ConnectionGroup.swift | 24 ++++ .../PoolStateMachine+ConnectionState.swift | 5 + .../PoolStateMachine.swift | 9 ++ .../ConnectionPoolTests.swift | 86 ++++++++++++++ ...oolStateMachine+ConnectionGroupTests.swift | 31 +++++ .../PoolStateMachineTests.swift | 111 ++++++++++++++++++ 8 files changed, 278 insertions(+), 6 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index fe4aa185..3d1f44a4 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -22,7 +22,7 @@ jobs: - swift:5.8-jammy - swift:5.9-jammy - swiftlang/swift:nightly-5.10-jammy - #- swiftlang/swift:nightly-main-jammy + - swiftlang/swift:nightly-main-jammy include: - swift-image: swift:5.9-jammy code-coverage: true @@ -133,7 +133,7 @@ jobs: matrix: postgres-formula: # Only test one version on macOS, let Linux do the rest - - postgresql@15 + - postgresql@16 postgres-auth: # Only test one auth method on macOS, Linux tests will cover the others - scram-sha-256 @@ -157,10 +157,16 @@ jobs: - name: Install Postgres, setup DB and auth, and wait for server start run: | export PATH="$(brew --prefix)/opt/${POSTGRES_FORMULA}/bin:$PATH" PGDATA=/tmp/vapor-postgres-test - (brew unlink postgresql || true) && brew install "${POSTGRES_FORMULA}" && brew link --force "${POSTGRES_FORMULA}" + # ** BEGIN ** Work around bug in both Homebrew and GHA + (brew upgrade python@3.11 || true) && (brew link --force --overwrite python@3.11 || true) + (brew upgrade python@3.12 || true) && (brew link --force --overwrite python@3.12 || true) + brew upgrade + # ** END ** Work around bug in both Homebrew and GHA + brew install --overwrite "${POSTGRES_FORMULA}" + brew link --overwrite --force "${POSTGRES_FORMULA}" initdb --locale=C --auth-host "${POSTGRES_AUTH_METHOD}" -U "${POSTGRES_USER}" --pwfile=<(echo "${POSTGRES_PASSWORD}") pg_ctl start --wait - timeout-minutes: 2 + timeout-minutes: 15 - name: Checkout code uses: actions/checkout@v4 - name: Run all tests @@ -183,7 +189,7 @@ jobs: gh-codeql: runs-on: ubuntu-latest - container: swift:5.8-jammy # CodeQL currently broken with 5.9 + container: swift:5.9-jammy permissions: { actions: write, contents: read, security-events: write } steps: - name: Check out code diff --git a/Sources/ConnectionPoolModule/ConnectionPool.swift b/Sources/ConnectionPoolModule/ConnectionPool.swift index ec865979..c20fa59e 100644 --- a/Sources/ConnectionPoolModule/ConnectionPool.swift +++ b/Sources/ConnectionPoolModule/ConnectionPool.swift @@ -481,7 +481,7 @@ public final class ConnectionPool< self.observabilityDelegate.keepAliveFailed(id: connection.id, error: error) self.modifyStateAndRunActions { state in - state.stateMachine.connectionClosed(connection) + state.stateMachine.connectionKeepAliveFailed(connection.id) } } } diff --git a/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionGroup.swift b/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionGroup.swift index 0dbca86f..833365fa 100644 --- a/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionGroup.swift +++ b/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionGroup.swift @@ -449,6 +449,30 @@ extension PoolStateMachine { return (index, context) } + @inlinable + mutating func keepAliveFailed(_ connectionID: Connection.ID) -> CloseAction? { + guard let index = self.connections.firstIndex(where: { $0.id == connectionID }) else { + // Connection has already been closed + return nil + } + + guard let closeAction = self.connections[index].keepAliveFailed() else { + return nil + } + + self.stats.idle -= 1 + self.stats.closing += 1 + self.stats.runningKeepAlive -= closeAction.runningKeepAlive ? 1 : 0 + self.stats.availableStreams -= closeAction.maxStreams - closeAction.usedStreams + + // force unwrapping the connection is fine, because a close action due to failed + // keepAlive cannot happen without a connection + return CloseAction( + connection: closeAction.connection!, + timersToCancel: closeAction.cancelTimers + ) + } + // MARK: Connection close/removal @usableFromInline diff --git a/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionState.swift b/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionState.swift index 98755ff9..2fb68a2d 100644 --- a/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionState.swift +++ b/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionState.swift @@ -455,6 +455,11 @@ extension PoolStateMachine { } } + @inlinable + mutating func keepAliveFailed() -> CloseAction? { + return self.close() + } + @inlinable mutating func timerScheduled( _ timer: ConnectionTimer, diff --git a/Sources/ConnectionPoolModule/PoolStateMachine.swift b/Sources/ConnectionPoolModule/PoolStateMachine.swift index 6671460a..3b996033 100644 --- a/Sources/ConnectionPoolModule/PoolStateMachine.swift +++ b/Sources/ConnectionPoolModule/PoolStateMachine.swift @@ -374,6 +374,15 @@ struct PoolStateMachine< return self.handleAvailableConnection(index: index, availableContext: context) } + @inlinable + mutating func connectionKeepAliveFailed(_ connectionID: ConnectionID) -> Action { + guard let closeAction = self.connections.keepAliveFailed(connectionID) else { + return .none() + } + + return .init(request: .none, connection: .closeConnection(closeAction.connection, closeAction.timersToCancel)) + } + @inlinable mutating func connectionIdleTimerTriggered(_ connectionID: ConnectionID) -> Action { precondition(self.requestQueue.isEmpty) diff --git a/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift b/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift index 0ff2bdf7..ba3c6a3f 100644 --- a/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift +++ b/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift @@ -300,6 +300,92 @@ final class ConnectionPoolTests: XCTestCase { } } + func testKeepAliveOnClose() async throws { + let clock = MockClock() + let factory = MockConnectionFactory() + let keepAliveDuration = Duration.seconds(20) + let keepAlive = MockPingPongBehavior(keepAliveFrequency: keepAliveDuration, connectionType: MockConnection.self) + + var mutableConfig = ConnectionPoolConfiguration() + mutableConfig.minimumConnectionCount = 0 + mutableConfig.maximumConnectionSoftLimit = 1 + mutableConfig.maximumConnectionHardLimit = 1 + let config = mutableConfig + + let pool = ConnectionPool( + configuration: config, + idGenerator: ConnectionIDGenerator(), + requestType: ConnectionRequest.self, + keepAliveBehavior: keepAlive, + observabilityDelegate: NoOpConnectionPoolMetrics(connectionIDType: MockConnection.ID.self), + clock: clock + ) { + try await factory.makeConnection(id: $0, for: $1) + } + + try await withThrowingTaskGroup(of: Void.self) { taskGroup in + taskGroup.addTask { + await pool.run() + } + + async let lease1ConnectionAsync = pool.leaseConnection() + + let connection = await factory.nextConnectAttempt { connectionID in + return 1 + } + + let lease1Connection = try await lease1ConnectionAsync + XCTAssert(connection === lease1Connection) + + pool.releaseConnection(lease1Connection) + + // keep alive 1 + + // validate that a keep alive timer and an idle timeout timer is scheduled + var expectedInstants: Set = [.init(keepAliveDuration), .init(config.idleTimeout)] + let deadline1 = await clock.nextTimerScheduled() + print(deadline1) + XCTAssertNotNil(expectedInstants.remove(deadline1)) + let deadline2 = await clock.nextTimerScheduled() + print(deadline2) + XCTAssertNotNil(expectedInstants.remove(deadline2)) + XCTAssert(expectedInstants.isEmpty) + + // move clock forward to keep alive + let newTime = clock.now.advanced(by: keepAliveDuration) + clock.advance(to: newTime) + + await keepAlive.nextKeepAlive { keepAliveConnection in + XCTAssertTrue(keepAliveConnection === lease1Connection) + return true + } + + // keep alive 2 + let deadline3 = await clock.nextTimerScheduled() + XCTAssertEqual(deadline3, clock.now.advanced(by: keepAliveDuration)) + clock.advance(to: clock.now.advanced(by: keepAliveDuration)) + + let failingKeepAliveDidRun = ManagedAtomic(false) + // the following keep alive should not cause a crash + _ = try? await keepAlive.nextKeepAlive { keepAliveConnection in + defer { + XCTAssertFalse(failingKeepAliveDidRun + .compareExchange(expected: false, desired: true, ordering: .relaxed).original) + } + XCTAssertTrue(keepAliveConnection === lease1Connection) + keepAliveConnection.close() + throw CancellationError() // any error + } // will fail and it's expected + XCTAssertTrue(failingKeepAliveDidRun.load(ordering: .relaxed)) + + taskGroup.cancelAll() + + for connection in factory.runningConnections { + connection.closeIfClosing() + } + } + } + func testKeepAliveWorksRacesAgainstShutdown() async throws { let clock = MockClock() let factory = MockConnectionFactory() diff --git a/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionGroupTests.swift b/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionGroupTests.swift index ac0f96f4..6b8d6c6e 100644 --- a/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionGroupTests.swift +++ b/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionGroupTests.swift @@ -293,4 +293,35 @@ final class PoolStateMachine_ConnectionGroupTests: XCTestCase { XCTAssertEqual(afterPingIdleContext.use, .persisted) XCTAssertEqual(connections.stats, .init(idle: 1, availableStreams: 1)) } + + func testKeepAliveShouldNotIndicateCloseConnectionAfterClosed() { + var connections = TestPoolStateMachine.ConnectionGroup( + generator: self.idGenerator, + minimumConcurrentConnections: 0, + maximumConcurrentConnectionSoftLimit: 2, + maximumConcurrentConnectionHardLimit: 2, + keepAlive: true, + keepAliveReducesAvailableStreams: true + ) + + guard let firstRequest = connections.createNewDemandConnectionIfPossible() else { return XCTFail("Expected to have a request here") } + + let newConnection = MockConnection(id: firstRequest.connectionID) + let (connectionIndex, establishedConnectionContext) = connections.newConnectionEstablished(newConnection, maxStreams: 1) + XCTAssertEqual(establishedConnectionContext.info, .idle(availableStreams: 1, newIdle: true)) + XCTAssertEqual(connections.stats, .init(idle: 1, availableStreams: 1)) + _ = connections.parkConnection(at: connectionIndex, hasBecomeIdle: true) + let keepAliveTimer = TestPoolStateMachine.ConnectionTimer(timerID: 0, connectionID: firstRequest.connectionID, usecase: .keepAlive) + let keepAliveTimerCancellationToken = MockTimerCancellationToken(keepAliveTimer) + XCTAssertNil(connections.timerScheduled(keepAliveTimer, cancelContinuation: keepAliveTimerCancellationToken)) + let keepAliveAction = connections.keepAliveIfIdle(newConnection.id) + XCTAssertEqual(keepAliveAction, .init(connection: newConnection, keepAliveTimerCancellationContinuation: keepAliveTimerCancellationToken)) + XCTAssertEqual(connections.stats, .init(idle: 1, runningKeepAlive: 1, availableStreams: 0)) + + _ = connections.closeConnectionIfIdle(newConnection.id) + guard connections.keepAliveFailed(newConnection.id) == nil else { + return XCTFail("Expected keepAliveFailed not to cause close again") + } + XCTAssertEqual(connections.stats, .init(closing: 1)) + } } diff --git a/Tests/ConnectionPoolModuleTests/PoolStateMachineTests.swift b/Tests/ConnectionPoolModuleTests/PoolStateMachineTests.swift index a19d2326..f5ada14f 100644 --- a/Tests/ConnectionPoolModuleTests/PoolStateMachineTests.swift +++ b/Tests/ConnectionPoolModuleTests/PoolStateMachineTests.swift @@ -266,4 +266,115 @@ final class PoolStateMachineTests: XCTestCase { XCTAssertEqual(releaseRequest1.connection, .none) } + func testKeepAliveOnClosingConnection() { + var configuration = PoolConfiguration() + configuration.minimumConnectionCount = 0 + configuration.maximumConnectionSoftLimit = 2 + configuration.maximumConnectionHardLimit = 2 + configuration.keepAliveDuration = .seconds(2) + configuration.idleTimeoutDuration = .seconds(4) + + + var stateMachine = TestPoolStateMachine( + configuration: configuration, + generator: .init(), + timerCancellationTokenType: MockTimerCancellationToken.self + ) + + // don't refill pool + let requests = stateMachine.refillConnections() + XCTAssertEqual(requests.count, 0) + + // request connection while none exists + let request1 = MockRequest() + let leaseRequest1 = stateMachine.leaseConnection(request1) + XCTAssertEqual(leaseRequest1.connection, .makeConnection(.init(connectionID: 0), [])) + XCTAssertEqual(leaseRequest1.request, .none) + + // make connection 1 + let connection1 = MockConnection(id: 0) + let createdAction1 = stateMachine.connectionEstablished(connection1, maxStreams: 1) + XCTAssertEqual(createdAction1.request, .leaseConnection(.init(element: request1), connection1)) + XCTAssertEqual(createdAction1.connection, .none) + _ = stateMachine.releaseConnection(connection1, streams: 1) + + // trigger keep alive + let keepAliveAction1 = stateMachine.connectionKeepAliveTimerTriggered(connection1.id) + XCTAssertEqual(keepAliveAction1.connection, .runKeepAlive(connection1, nil)) + + // fail keep alive and cause closed + let keepAliveFailed1 = stateMachine.connectionKeepAliveFailed(connection1.id) + XCTAssertEqual(keepAliveFailed1.connection, .closeConnection(connection1, [])) + connection1.closeIfClosing() + + // request connection while none exists anymore + let request2 = MockRequest() + let leaseRequest2 = stateMachine.leaseConnection(request2) + XCTAssertEqual(leaseRequest2.connection, .makeConnection(.init(connectionID: 1), [])) + XCTAssertEqual(leaseRequest2.request, .none) + + // make connection 2 + let connection2 = MockConnection(id: 1) + let createdAction2 = stateMachine.connectionEstablished(connection2, maxStreams: 1) + XCTAssertEqual(createdAction2.request, .leaseConnection(.init(element: request2), connection2)) + XCTAssertEqual(createdAction2.connection, .none) + _ = stateMachine.releaseConnection(connection2, streams: 1) + + // trigger keep alive while connection is still open + let keepAliveAction2 = stateMachine.connectionKeepAliveTimerTriggered(connection2.id) + XCTAssertEqual(keepAliveAction2.connection, .runKeepAlive(connection2, nil)) + + // close connection in the middle of keep alive + connection2.close() + connection2.closeIfClosing() + + // fail keep alive and cause closed + let keepAliveFailed2 = stateMachine.connectionKeepAliveFailed(connection2.id) + XCTAssertEqual(keepAliveFailed2.connection, .closeConnection(connection2, [])) + } + + func testConnectionIsEstablishedAfterFailedKeepAliveIfNotEnoughConnectionsLeft() { + var configuration = PoolConfiguration() + configuration.minimumConnectionCount = 1 + configuration.maximumConnectionSoftLimit = 2 + configuration.maximumConnectionHardLimit = 2 + configuration.keepAliveDuration = .seconds(2) + configuration.idleTimeoutDuration = .seconds(4) + + + var stateMachine = TestPoolStateMachine( + configuration: configuration, + generator: .init(), + timerCancellationTokenType: MockTimerCancellationToken.self + ) + + // refill pool + let requests = stateMachine.refillConnections() + XCTAssertEqual(requests.count, 1) + + // one connection should exist + let request = MockRequest() + let leaseRequest = stateMachine.leaseConnection(request) + XCTAssertEqual(leaseRequest.connection, .none) + XCTAssertEqual(leaseRequest.request, .none) + + // make connection 1 + let connection = MockConnection(id: 0) + let createdAction = stateMachine.connectionEstablished(connection, maxStreams: 1) + XCTAssertEqual(createdAction.request, .leaseConnection(.init(element: request), connection)) + XCTAssertEqual(createdAction.connection, .none) + _ = stateMachine.releaseConnection(connection, streams: 1) + + // trigger keep alive + let keepAliveAction = stateMachine.connectionKeepAliveTimerTriggered(connection.id) + XCTAssertEqual(keepAliveAction.connection, .runKeepAlive(connection, nil)) + + // fail keep alive, cause closed and make new connection + let keepAliveFailed = stateMachine.connectionKeepAliveFailed(connection.id) + XCTAssertEqual(keepAliveFailed.connection, .closeConnection(connection, [])) + let connectionClosed = stateMachine.connectionClosed(connection) + XCTAssertEqual(connectionClosed.connection, .makeConnection(.init(connectionID: 1), [])) + connection.closeIfClosing() + } + } From fa3137d39bca84843739db1c5a3db2d7f4ae65e6 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Tue, 12 Dec 2023 17:01:12 +0100 Subject: [PATCH 056/106] Support additional connection parameters (#361) --- .../PostgresConnection+Configuration.swift | 7 +++- .../ConnectionStateMachine.swift | 34 ++++++++++++--- .../New/PostgresChannelHandler.swift | 2 +- .../New/PostgresFrontendMessageEncoder.swift | 9 +++- .../PSQLFrontendMessageDecoder.swift | 11 +++-- .../Extensions/PostgresFrontendMessage.swift | 27 ++++++++++-- .../New/Messages/StartupTests.swift | 41 ++++++++++++++++++- .../New/PostgresChannelHandlerTests.swift | 9 ++-- .../New/PostgresConnectionTests.swift | 2 +- 9 files changed, 117 insertions(+), 25 deletions(-) diff --git a/Sources/PostgresNIO/Connection/PostgresConnection+Configuration.swift b/Sources/PostgresNIO/Connection/PostgresConnection+Configuration.swift index 22c59d8a..dd0f5404 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection+Configuration.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection+Configuration.swift @@ -85,7 +85,11 @@ extension PostgresConnection { /// This property is provided for compatibility with Amazon RDS Proxy, which requires it to be `false`. /// If you are not using Amazon RDS Proxy, you should leave this set to `true` (the default). public var requireBackendKeyData: Bool - + + /// Additional parameters to send to the server on startup. The name value pairs are added to the initial + /// startup message that the client sends to the server. + public var additionalStartupParameters: [(String, String)] + /// Create an options structure with default values. /// /// Most users should not need to adjust the defaults. @@ -93,6 +97,7 @@ extension PostgresConnection { self.connectTimeout = .seconds(10) self.tlsServerName = nil self.requireBackendKeyData = true + self.additionalStartupParameters = [] } } diff --git a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift index 9cde0cf3..d7a609a6 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift @@ -1113,11 +1113,19 @@ struct SendPrepareStatement { let query: String } -struct AuthContext: Equatable, CustomDebugStringConvertible { - let username: String - let password: String? - let database: String? - +struct AuthContext: CustomDebugStringConvertible { + var username: String + var password: String? + var database: String? + var additionalParameters: [(String, String)] + + init(username: String, password: String? = nil, database: String? = nil, additionalParameters: [(String, String)] = []) { + self.username = username + self.password = password + self.database = database + self.additionalParameters = additionalParameters + } + var debugDescription: String { """ AuthContext(username: \(String(reflecting: self.username)), \ @@ -1127,6 +1135,22 @@ struct AuthContext: Equatable, CustomDebugStringConvertible { } } +extension AuthContext: Equatable { + static func ==(lhs: Self, rhs: Self) -> Bool { + guard lhs.username == rhs.username + && lhs.password == rhs.password + && lhs.database == rhs.database + && lhs.additionalParameters.count == rhs.additionalParameters.count + else { + return false + } + + return lhs.additionalParameters.elementsEqual(rhs.additionalParameters) { lhs, rhs in + lhs.0 == rhs.0 && lhs.1 == rhs.1 + } + } +} + enum PasswordAuthencationMode: Equatable { case cleartext case md5(salt: UInt32) diff --git a/Sources/PostgresNIO/New/PostgresChannelHandler.swift b/Sources/PostgresNIO/New/PostgresChannelHandler.swift index 9d0ef2a5..54ae0fc9 100644 --- a/Sources/PostgresNIO/New/PostgresChannelHandler.swift +++ b/Sources/PostgresNIO/New/PostgresChannelHandler.swift @@ -328,7 +328,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler { case .wait: break case .sendStartupMessage(let authContext): - self.encoder.startup(user: authContext.username, database: authContext.database) + self.encoder.startup(user: authContext.username, database: authContext.database, options: authContext.additionalParameters) context.writeAndFlush(self.wrapOutboundOut(self.encoder.flushBuffer()), promise: nil) case .sendSSLRequest: self.encoder.ssl() diff --git a/Sources/PostgresNIO/New/PostgresFrontendMessageEncoder.swift b/Sources/PostgresNIO/New/PostgresFrontendMessageEncoder.swift index e98ab1f1..97805418 100644 --- a/Sources/PostgresNIO/New/PostgresFrontendMessageEncoder.swift +++ b/Sources/PostgresNIO/New/PostgresFrontendMessageEncoder.swift @@ -25,7 +25,7 @@ struct PostgresFrontendMessageEncoder { self.buffer = buffer } - mutating func startup(user: String, database: String?) { + mutating func startup(user: String, database: String?, options: [(String, String)]) { self.clearIfNeeded() self.buffer.psqlLengthPrefixed { buffer in buffer.writeInteger(Self.startupVersionThree) @@ -37,6 +37,13 @@ struct PostgresFrontendMessageEncoder { buffer.writeNullTerminatedString(database) } + // we don't send replication parameters, as the default is false and this is what we + // need for a client + for (key, value) in options { + buffer.writeNullTerminatedString(key) + buffer.writeNullTerminatedString(value) + } + buffer.writeInteger(UInt8(0)) } } diff --git a/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift b/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift index 46c043b1..55ccd0a9 100644 --- a/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift +++ b/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift @@ -39,8 +39,8 @@ struct PSQLFrontendMessageDecoder: NIOSingleStepByteToMessageDecoder { case 196608: var user: String? var database: String? - var options: String? - + var options = [(String, String)]() + while let name = messageSlice.readNullTerminatedString(), messageSlice.readerIndex < finalIndex { let value = messageSlice.readNullTerminatedString() @@ -51,11 +51,10 @@ struct PSQLFrontendMessageDecoder: NIOSingleStepByteToMessageDecoder { case "database": database = value - case "options": - options = value - default: - break + if let value = value { + options.append((name, value)) + } } } diff --git a/Tests/PostgresNIOTests/New/Extensions/PostgresFrontendMessage.swift b/Tests/PostgresNIOTests/New/Extensions/PostgresFrontendMessage.swift index 010667dc..2532959a 100644 --- a/Tests/PostgresNIOTests/New/Extensions/PostgresFrontendMessage.swift +++ b/Tests/PostgresNIOTests/New/Extensions/PostgresFrontendMessage.swift @@ -103,7 +103,7 @@ enum PostgresFrontendMessage: Equatable { static let requestCode: Int32 = 80877103 } - struct Startup: Hashable { + struct Startup: Equatable { static let versionThree: Int32 = 0x00_03_00_00 /// Creates a `Startup` with "3.0" as the protocol version. @@ -119,7 +119,7 @@ enum PostgresFrontendMessage: Equatable { /// The protocol version number is followed by one or more pairs of parameter /// name and value strings. A zero byte is required as a terminator after /// the last name/value pair. `user` is required, others are optional. - struct Parameters: Hashable { + struct Parameters: Equatable { enum Replication { case `true` case `false` @@ -136,12 +136,33 @@ enum PostgresFrontendMessage: Equatable { /// of setting individual run-time parameters.) Spaces within this string are /// considered to separate arguments, unless escaped with a /// backslash (\); write \\ to represent a literal backslash. - var options: String? + var options: [(String, String)] /// Used to connect in streaming replication mode, where a small set of /// replication commands can be issued instead of SQL statements. Value /// can be true, false, or database, and the default is false. var replication: Replication + + static func ==(lhs: Self, rhs: Self) -> Bool { + guard lhs.user == rhs.user + && lhs.database == rhs.database + && lhs.replication == rhs.replication + && lhs.options.count == rhs.options.count + else { + return false + } + + var lhsIterator = lhs.options.makeIterator() + var rhsIterator = rhs.options.makeIterator() + + while let lhsNext = lhsIterator.next(), let rhsNext = rhsIterator.next() { + guard lhsNext.0 == rhsNext.0 && lhsNext.1 == rhsNext.1 else { + return false + } + } + return true + } + } var parameters: Parameters diff --git a/Tests/PostgresNIOTests/New/Messages/StartupTests.swift b/Tests/PostgresNIOTests/New/Messages/StartupTests.swift index 39e9bb42..5af3bf34 100644 --- a/Tests/PostgresNIOTests/New/Messages/StartupTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/StartupTests.swift @@ -11,7 +11,7 @@ class StartupTests: XCTestCase { let user = "test" let database = "abc123" - encoder.startup(user: user, database: database) + encoder.startup(user: user, database: database, options: []) byteBuffer = encoder.flushBuffer() let byteBufferLength = Int32(byteBuffer.readableBytes) @@ -32,7 +32,7 @@ class StartupTests: XCTestCase { let user = "test" - encoder.startup(user: user, database: nil) + encoder.startup(user: user, database: nil, options: []) byteBuffer = encoder.flushBuffer() let byteBufferLength = Int32(byteBuffer.readableBytes) @@ -44,4 +44,41 @@ class StartupTests: XCTestCase { XCTAssertEqual(byteBuffer.readableBytes, 0) } + + func testStartupMessageWithAdditionalOptions() { + var encoder = PostgresFrontendMessageEncoder(buffer: .init()) + var byteBuffer = ByteBuffer() + + let user = "test" + let database = "abc123" + + encoder.startup(user: user, database: database, options: [("some", "options")]) + byteBuffer = encoder.flushBuffer() + + let byteBufferLength = Int32(byteBuffer.readableBytes) + XCTAssertEqual(byteBufferLength, byteBuffer.readInteger()) + XCTAssertEqual(PostgresFrontendMessage.Startup.versionThree, byteBuffer.readInteger()) + XCTAssertEqual(byteBuffer.readNullTerminatedString(), "user") + XCTAssertEqual(byteBuffer.readNullTerminatedString(), "test") + XCTAssertEqual(byteBuffer.readNullTerminatedString(), "database") + XCTAssertEqual(byteBuffer.readNullTerminatedString(), "abc123") + XCTAssertEqual(byteBuffer.readNullTerminatedString(), "some") + XCTAssertEqual(byteBuffer.readNullTerminatedString(), "options") + XCTAssertEqual(byteBuffer.readInteger(), UInt8(0)) + + XCTAssertEqual(byteBuffer.readableBytes, 0) + } +} + +extension PostgresFrontendMessage.Startup.Parameters.Replication { + var stringValue: String { + switch self { + case .true: + return "true" + case .false: + return "false" + case .database: + return "replication" + } + } } diff --git a/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift b/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift index b81d0899..dfdcc53e 100644 --- a/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift @@ -37,9 +37,8 @@ class PostgresChannelHandlerTests: XCTestCase { XCTAssertEqual(startup.parameters.user, config.username) XCTAssertEqual(startup.parameters.database, config.database) - XCTAssertEqual(startup.parameters.options, nil) - XCTAssertEqual(startup.parameters.replication, .false) - + XCTAssert(startup.parameters.options.isEmpty) + XCTAssertNoThrow(try embedded.writeInbound(PostgresBackendMessage.authentication(.ok))) XCTAssertNoThrow(try embedded.writeInbound(PostgresBackendMessage.backendKeyData(.init(processID: 1234, secretKey: 5678)))) XCTAssertNoThrow(try embedded.writeInbound(PostgresBackendMessage.readyForQuery(.idle))) @@ -209,7 +208,7 @@ class PostgresChannelHandlerTests: XCTestCase { XCTAssertEqual(startup.parameters.user, config.username) XCTAssertEqual(startup.parameters.database, config.database) - XCTAssertEqual(startup.parameters.options, nil) + XCTAssert(startup.parameters.options.isEmpty) XCTAssertEqual(startup.parameters.replication, .false) var buffer = ByteBuffer() @@ -282,7 +281,7 @@ extension AuthContext { PostgresFrontendMessage.Startup.Parameters( user: self.username, database: self.database, - options: nil, + options: self.additionalParameters, replication: .false ) } diff --git a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift index 3b1a8ca9..82baf914 100644 --- a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift @@ -602,7 +602,7 @@ class PostgresConnectionTests: XCTestCase { async let connectionPromise = PostgresConnection.connect(on: eventLoop, configuration: configuration, id: 1, logger: self.logger) let message = try await channel.waitForOutboundWrite(as: PostgresFrontendMessage.self) - XCTAssertEqual(message, .startup(.versionThree(parameters: .init(user: "username", database: "database", replication: .false)))) + XCTAssertEqual(message, .startup(.versionThree(parameters: .init(user: "username", database: "database", options: [], replication: .false)))) try await channel.writeInbound(PostgresBackendMessage.authentication(.ok)) try await channel.writeInbound(PostgresBackendMessage.backendKeyData(.init(processID: 1234, secretKey: 5678))) try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) From ea0800d12bbf70a3968b6ccd0cf17bb5d861530f Mon Sep 17 00:00:00 2001 From: Timo <38291523+lovetodream@users.noreply.github.com> Date: Tue, 9 Jan 2024 12:53:32 +0100 Subject: [PATCH 057/106] Fix Availability for DiscardingTaskGroup on watchOS (#448) --- Sources/ConnectionPoolModule/ConnectionPool.swift | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Sources/ConnectionPoolModule/ConnectionPool.swift b/Sources/ConnectionPoolModule/ConnectionPool.swift index c20fa59e..9f25e82c 100644 --- a/Sources/ConnectionPoolModule/ConnectionPool.swift +++ b/Sources/ConnectionPoolModule/ConnectionPool.swift @@ -591,7 +591,7 @@ protocol TaskGroupProtocol { } #if swift(>=5.8) && os(Linux) || swift(>=5.9) -@available(macOS 14.0, iOS 17.0, tvOS 17.0, watchOS 9.0, *) +@available(macOS 14.0, iOS 17.0, tvOS 17.0, watchOS 10.0, *) extension DiscardingTaskGroup: TaskGroupProtocol {} #endif From 6ce96ab041ee055d6da97717fafa742b0f5915c9 Mon Sep 17 00:00:00 2001 From: Gwynne Raskind Date: Tue, 30 Jan 2024 04:13:09 -0600 Subject: [PATCH 058/106] Add `Sendable` conformance to `PostgresEncodingContext` (#450) --- Sources/PostgresNIO/New/PostgresCodable.swift | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/Sources/PostgresNIO/New/PostgresCodable.swift b/Sources/PostgresNIO/New/PostgresCodable.swift index 71c689bf..fd82c8ea 100644 --- a/Sources/PostgresNIO/New/PostgresCodable.swift +++ b/Sources/PostgresNIO/New/PostgresCodable.swift @@ -166,11 +166,10 @@ extension PostgresDynamicTypeEncodable { /// A context that is passed to Swift objects that are encoded into the Postgres wire format. Used /// to pass further information to the encoding method. -public struct PostgresEncodingContext { +public struct PostgresEncodingContext: Sendable { /// A ``PostgresJSONEncoder`` used to encode the object to json. public var jsonEncoder: JSONEncoder - /// Creates a ``PostgresEncodingContext`` with the given ``PostgresJSONEncoder``. In case you want /// to use the a ``PostgresEncodingContext`` with an unconfigured Foundation `JSONEncoder` /// you can use the ``default`` context instead. From e9b90b2189b6c64d41522d87616b04f6d978bb06 Mon Sep 17 00:00:00 2001 From: Gwynne Raskind Date: Tue, 30 Jan 2024 08:28:08 -0600 Subject: [PATCH 059/106] Fix mishandling of SASL attribute parsing (#451) --- .../SASLAuthentication+SCRAM-SHA256.swift | 7 +++--- .../AuthenticationStateMachineTests.swift | 24 +++++++++++++++++++ 2 files changed, 27 insertions(+), 4 deletions(-) diff --git a/Sources/PostgresNIO/Utilities/SASLAuthentication+SCRAM-SHA256.swift b/Sources/PostgresNIO/Utilities/SASLAuthentication+SCRAM-SHA256.swift index f2fd8e1a..ac1d9ead 100644 --- a/Sources/PostgresNIO/Utilities/SASLAuthentication+SCRAM-SHA256.swift +++ b/Sources/PostgresNIO/Utilities/SASLAuthentication+SCRAM-SHA256.swift @@ -209,14 +209,13 @@ fileprivate struct SCRAMMessageParser { } static func parse(raw: [UInt8], isGS2Header: Bool = false) -> [SCRAMAttribute]? { - // There are two ways to implement this parse: // 1. All-at-once: Split on comma, split each on equals, validate // each results in a valid attribute. // 2. Sequential: State machine lookahead parse. // The former is simpler. The latter provides better validation. - let likelyAttributeSets = raw.split(separator: .comma, maxSplits: isGS2Header ? 3 : Int.max, omittingEmptySubsequences: false) - let likelyAttributePairs = likelyAttributeSets.map { $0.split(separator: .equals, maxSplits: 2, omittingEmptySubsequences: false) } + let likelyAttributeSets = raw.split(separator: .comma, maxSplits: isGS2Header ? 2 : Int.max, omittingEmptySubsequences: false) + let likelyAttributePairs = likelyAttributeSets.map { $0.split(separator: .equals, maxSplits: 1, omittingEmptySubsequences: false) } let results = likelyAttributePairs.map { parseAttributePair(name: Array($0[0]), value: $0.dropFirst().first.map { Array($0) } ?? [], isGS2Header: isGS2Header) } let validResults = results.compactMap { $0 } @@ -369,7 +368,7 @@ internal struct SHA256_PLUS: SASLAuthenticationMechanism { } // enum SCRAM } // enum SASLMechanism -/// Common impplementation of SCRAM-SHA-256 and SCRAM-SHA-256-PLUS +/// Common implementation of SCRAM-SHA-256 and SCRAM-SHA-256-PLUS fileprivate final class SASLMechanism_SCRAM_SHA256_Common { /// Initialized with initial client state diff --git a/Tests/PostgresNIOTests/New/Connection State Machine/AuthenticationStateMachineTests.swift b/Tests/PostgresNIOTests/New/Connection State Machine/AuthenticationStateMachineTests.swift index b06b69ab..df881f90 100644 --- a/Tests/PostgresNIOTests/New/Connection State Machine/AuthenticationStateMachineTests.swift +++ b/Tests/PostgresNIOTests/New/Connection State Machine/AuthenticationStateMachineTests.swift @@ -45,6 +45,30 @@ class AuthenticationStateMachineTests: XCTestCase { XCTAssertEqual(state.authenticationMessageReceived(.ok), .wait) } + func testAuthenticateSCRAMSHA256WithAtypicalEncoding() { + let authContext = AuthContext(username: "test", password: "abc123", database: "test") + var state = ConnectionStateMachine(requireBackendKeyData: true) + XCTAssertEqual(state.connected(tls: .disable), .provideAuthenticationContext) + XCTAssertEqual(state.provideAuthenticationContext(authContext), .sendStartupMessage(authContext)) + + let saslResponse = state.authenticationMessageReceived(.sasl(names: ["SCRAM-SHA-256"])) + guard case .sendSaslInitialResponse(name: let name, initialResponse: let responseData) = saslResponse else { + return XCTFail("\(saslResponse) is not .sendSaslInitialResponse") + } + let responseString = String(decoding: responseData, as: UTF8.self) + XCTAssertEqual(name, "SCRAM-SHA-256") + XCTAssert(responseString.starts(with: "n,,n=test,r=")) + + let saslContinueResponse = state.authenticationMessageReceived(.saslContinue(data: .init(bytes: + "r=\(responseString.dropFirst(12))RUJSZHhkeUVFNzRLNERKMkxmU05ITU1NZWcxaQ==,s=ijgUVaWgCDLRJyF963BKNA==,i=4096".utf8 + ))) + guard case .sendSaslResponse(let responseData2) = saslContinueResponse else { + return XCTFail("\(saslContinueResponse) is not .sendSaslResponse") + } + let response2String = String(decoding: responseData2, as: UTF8.self) + XCTAssertEqual(response2String.prefix(76), "c=biws,r=\(responseString.dropFirst(12))RUJSZHhkeUVFNzRLNERKMkxmU05ITU1NZWcxaQ==,p=") + } + func testAuthenticationFailure() { let authContext = AuthContext(username: "test", password: "abc123", database: "test") var state = ConnectionStateMachine(requireBackendKeyData: true) From 69ccfdf4c80144d845e3b439961b7ec6cd7ae33f Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Wed, 31 Jan 2024 16:23:36 +0100 Subject: [PATCH 060/106] Be resilient about a read after connection closed (#452) fixes #449 --- .../ConnectionStateMachine.swift | 37 ++++++++++--------- 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift index d7a609a6..8c3252de 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift @@ -624,21 +624,19 @@ struct ConnectionStateMachine { mutating func readEventCaught() -> ConnectionAction { switch self.state { case .initialized: - preconditionFailure("Received a read event on a connection that was never opened.") - case .sslRequestSent: - return .read - case .sslNegotiated: - return .read - case .sslHandlerAdded: - return .read - case .waitingToStartAuthentication: - return .read - case .authenticating: - return .read - case .authenticated: - return .read - case .readyForQuery: + preconditionFailure("Invalid state: \(self.state). Read event before connection established?") + + case .sslRequestSent, + .sslNegotiated, + .sslHandlerAdded, + .waitingToStartAuthentication, + .authenticating, + .authenticated, + .readyForQuery, + .closing: + // all states in which we definitely want to make further forward progress... return .read + case .extendedQuery(var extendedQuery, let connectionContext): self.state = .modifying // avoid CoW let action = extendedQuery.readEventCaught() @@ -651,12 +649,15 @@ struct ConnectionStateMachine { self.state = .closeCommand(closeState, connectionContext) return self.modify(with: action) - case .closing: - return .read case .closed: - preconditionFailure("How can we receive a read, if the connection is closed") + // Generally we shouldn't see this event (read after connection closed?!). + // But truth is, adopters run into this, again and again. So preconditioning here leads + // to unnecessary crashes. So let's be resilient and just make more forward progress. + // If we really care, we probably need to dive deep into PostgresNIO and SwiftNIO. + return .read + case .modifying: - preconditionFailure("Invalid state") + preconditionFailure("Invalid state: \(self.state)") } } From 6433f6d87b0fa7daf9aaeb742bd3c8fd1f16ec26 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Wed, 21 Feb 2024 17:09:16 +0100 Subject: [PATCH 061/106] Fix warnings (#454) --- Sources/PostgresNIO/New/PSQLRowStream.swift | 1 + .../SASLAuthentication+SCRAM-SHA256.swift | 111 +++++++++--------- .../ConnectionPoolTests.swift | 2 +- .../ConnectionAction+TestUtils.swift | 12 +- .../New/Messages/DataRowTests.swift | 2 +- 5 files changed, 65 insertions(+), 63 deletions(-) diff --git a/Sources/PostgresNIO/New/PSQLRowStream.swift b/Sources/PostgresNIO/New/PSQLRowStream.swift index b3dfea30..0255e462 100644 --- a/Sources/PostgresNIO/New/PSQLRowStream.swift +++ b/Sources/PostgresNIO/New/PSQLRowStream.swift @@ -86,6 +86,7 @@ final class PSQLRowStream: @unchecked Sendable { elementType: DataRow.self, failureType: Error.self, backPressureStrategy: AdaptiveRowBuffer(), + finishOnDeinit: false, delegate: self ) diff --git a/Sources/PostgresNIO/Utilities/SASLAuthentication+SCRAM-SHA256.swift b/Sources/PostgresNIO/Utilities/SASLAuthentication+SCRAM-SHA256.swift index ac1d9ead..2a717b6b 100644 --- a/Sources/PostgresNIO/Utilities/SASLAuthentication+SCRAM-SHA256.swift +++ b/Sources/PostgresNIO/Utilities/SASLAuthentication+SCRAM-SHA256.swift @@ -1,13 +1,10 @@ import Crypto import Foundation -extension UInt8: ExpressibleByUnicodeScalarLiteral { +extension UInt8 { fileprivate static var NUL: UInt8 { return 0x00 /* yeah, just U+0000 man */ } fileprivate static var comma: UInt8 { return 0x2c /* .init(ascii: ",") */ } fileprivate static var equals: UInt8 { return 0x3d /* .init(ascii: "=") */ } - public init(unicodeScalarLiteral value: Unicode.Scalar) { - self.init(ascii: value) - } } fileprivate extension String { @@ -87,7 +84,7 @@ fileprivate extension Array where Element == UInt8 { */ var isValidScramValue: Bool { // TODO: FInd a better way than doing a whole construction of String... - return self.count > 0 && !(String(bytes: self, encoding: .utf8)?.contains(",") ?? true) + return self.count > 0 && !(String(decoding: self, as: Unicode.UTF8.self).contains(",")) } } @@ -171,40 +168,40 @@ fileprivate struct SCRAMMessageParser { static func parseAttributePair(name: [UInt8], value: [UInt8], isGS2Header: Bool = false) -> SCRAMAttribute? { guard name.count == 1 || isGS2Header else { return nil } switch name.first { - case "m" where !isGS2Header: return .m(value) - case "r" where !isGS2Header: return String(printableAscii: value).map { .r($0) } - case "c" where !isGS2Header: - guard let parsedAttrs = value.decodingBase64().flatMap({ parse(raw: $0, isGS2Header: true) }) else { return nil } - guard (1...3).contains(parsedAttrs.count) else { return nil } - switch (parsedAttrs.first, parsedAttrs.dropFirst(1).first, parsedAttrs.dropFirst(2).first) { - case let (.gp(.bind(name, .none)), .a(ident), .gm(data)): return .c(binding: .bind(name, data), authIdentity: ident) - case let (.gp(.bind(name, .none)), .gm(data), .none): return .c(binding: .bind(name, data)) - case let (.gp(bind), .a(ident), .none): return .c(binding: bind, authIdentity: ident) - case let (.gp(bind), .none, .none): return .c(binding: bind) - default: return nil - } - case "n" where !isGS2Header: return String(bytes: value, encoding: .utf8)?.decodedAsSaslName.map { .n($0) } - case "s" where !isGS2Header: return value.decodingBase64().map { .s($0) } - case "i" where !isGS2Header: return String(printableAscii: value).flatMap { UInt32.init($0) }.map { .i($0) } - case "p" where !isGS2Header: return value.decodingBase64().map { .p($0) } - case "v" where !isGS2Header: return value.decodingBase64().map { .v($0) } - case "e" where !isGS2Header: // TODO: actually map the specific enum string values - guard value.isValidScramValue else { return nil } - return String(bytes: value, encoding: .utf8).flatMap { SCRAMServerError(rawValue: $0) }.map { .e($0) } - - case "y" where isGS2Header && value.count == 0: return .gp(.unused) - case "n" where isGS2Header && value.count == 0: return .gp(.unsupported) - case "p" where isGS2Header: return String(asciiAlphanumericMorse: value).map { .gp(.bind($0, nil)) } - case "a" where isGS2Header: return String(bytes: value, encoding: .utf8)?.decodedAsSaslName.map { .a($0) } - case .none where isGS2Header: return .a(nil) + case UInt8(ascii: "m") where !isGS2Header: return .m(value) + case UInt8(ascii: "r") where !isGS2Header: return String(printableAscii: value).map { .r($0) } + case UInt8(ascii: "c") where !isGS2Header: + guard let parsedAttrs = value.decodingBase64().flatMap({ parse(raw: $0, isGS2Header: true) }) else { return nil } + guard (1...3).contains(parsedAttrs.count) else { return nil } + switch (parsedAttrs.first, parsedAttrs.dropFirst(1).first, parsedAttrs.dropFirst(2).first) { + case let (.gp(.bind(name, .none)), .a(ident), .gm(data)): return .c(binding: .bind(name, data), authIdentity: ident) + case let (.gp(.bind(name, .none)), .gm(data), .none): return .c(binding: .bind(name, data)) + case let (.gp(bind), .a(ident), .none): return .c(binding: bind, authIdentity: ident) + case let (.gp(bind), .none, .none): return .c(binding: bind) + default: return nil + } + case UInt8(ascii: "n") where !isGS2Header: return String(decoding: value, as: Unicode.UTF8.self).decodedAsSaslName.map { .n($0) } + case UInt8(ascii: "s") where !isGS2Header: return value.decodingBase64().map { .s($0) } + case UInt8(ascii: "i") where !isGS2Header: return String(printableAscii: value).flatMap { UInt32.init($0) }.map { .i($0) } + case UInt8(ascii: "p") where !isGS2Header: return value.decodingBase64().map { .p($0) } + case UInt8(ascii: "v") where !isGS2Header: return value.decodingBase64().map { .v($0) } + case UInt8(ascii: "e") where !isGS2Header: // TODO: actually map the specific enum string values + guard value.isValidScramValue else { return nil } + return SCRAMServerError(rawValue: String(decoding: value, as: Unicode.UTF8.self)).flatMap { .e($0) } - default: - if isGS2Header { - return .gm(name + value) - } else { - guard value.count > 0, value.isValidScramValue else { return nil } - return .optional(name: CChar(name[0]), value: value) - } + case UInt8(ascii: "y") where isGS2Header && value.count == 0: return .gp(.unused) + case UInt8(ascii: "n") where isGS2Header && value.count == 0: return .gp(.unsupported) + case UInt8(ascii: "p") where isGS2Header: return String(asciiAlphanumericMorse: value).map { .gp(.bind($0, nil)) } + case UInt8(ascii: "a") where isGS2Header: return String(decoding: value, as: Unicode.UTF8.self).decodedAsSaslName.map { .a($0) } + case .none where isGS2Header: return .a(nil) + + default: + if isGS2Header { + return .gm(name + value) + } else { + guard value.count > 0, value.isValidScramValue else { return nil } + return .optional(name: CChar(name[0]), value: value) + } } } @@ -230,45 +227,45 @@ fileprivate struct SCRAMMessageParser { for attribute in attributes { switch attribute { case .m(let value): - result.append("m"); result.append("="); result.append(contentsOf: value) + result.append(UInt8(ascii: "m")); result.append(.equals); result.append(contentsOf: value) case .r(let nonce): - result.append("r"); result.append("="); result.append(contentsOf: nonce.utf8.map { UInt8($0) }) + result.append(UInt8(ascii: "r")); result.append(.equals); result.append(contentsOf: nonce.utf8.map { UInt8($0) }) case .n(let name): - result.append("n"); result.append("="); result.append(contentsOf: name.encodedAsSaslName.utf8.map { UInt8($0) }) + result.append(UInt8(ascii: "n")); result.append(.equals); result.append(contentsOf: name.encodedAsSaslName.utf8.map { UInt8($0) }) case .s(let salt): - result.append("s"); result.append("="); result.append(contentsOf: salt.encodingBase64()) + result.append(UInt8(ascii: "s")); result.append(.equals); result.append(contentsOf: salt.encodingBase64()) case .i(let count): - result.append("i"); result.append("="); result.append(contentsOf: "\(count)".utf8.map { UInt8($0) }) + result.append(UInt8(ascii: "i")); result.append(.equals); result.append(contentsOf: "\(count)".utf8.map { UInt8($0) }) case .p(let proof): - result.append("p"); result.append("="); result.append(contentsOf: proof.encodingBase64()) + result.append(UInt8(ascii: "p")); result.append(.equals); result.append(contentsOf: proof.encodingBase64()) case .v(let signature): - result.append("v"); result.append("="); result.append(contentsOf: signature.encodingBase64()) + result.append(UInt8(ascii: "v")); result.append(.equals); result.append(contentsOf: signature.encodingBase64()) case .e(let error): - result.append("e"); result.append("="); result.append(contentsOf: error.rawValue.utf8.map { UInt8($0) }) + result.append(UInt8(ascii: "e")); result.append(.equals); result.append(contentsOf: error.rawValue.utf8.map { UInt8($0) }) case .c(let binding, let identity): if isInitialGS2Header { switch binding { - case .unsupported: result.append("n") - case .unused: result.append("y") - case .bind(let name, _): result.append("p"); result.append("="); result.append(contentsOf: name.utf8.map { UInt8($0) }) + case .unsupported: result.append(UInt8(ascii: "n")) + case .unused: result.append(UInt8(ascii: "y")) + case .bind(let name, _): result.append(UInt8(ascii: "p")); result.append(.equals); result.append(contentsOf: name.utf8.map { UInt8($0) }) } - result.append(",") + result.append(.comma) if let identity = identity { - result.append("a"); result.append("="); result.append(contentsOf: identity.encodedAsSaslName.utf8.map { UInt8($0) }) + result.append(UInt8(ascii: "a")); result.append(.equals); result.append(contentsOf: identity.encodedAsSaslName.utf8.map { UInt8($0) }) } - result.append(",") + result.append(.comma) } else { guard var partial = serialize([attribute], isInitialGS2Header: true) else { return nil } if case let .bind(_, data) = binding { guard let data = data else { return nil } partial.append(contentsOf: data) } - result.append("c"); result.append("="); result.append(contentsOf: partial.encodingBase64()) + result.append(UInt8(ascii: "c")); result.append(.equals); result.append(contentsOf: partial.encodingBase64()) } default: return nil } - result.append(",") + result.append(.comma) } return result.dropLast() } @@ -472,7 +469,7 @@ fileprivate final class SASLMechanism_SCRAM_SHA256_Common { let saltedPassword = Hi(string: password, salt: serverSalt, iterations: serverIterations) let clientKey = HMAC.authenticationCode(for: "Client Key".data(using: .utf8)!, using: .init(data: saltedPassword)) let storedKey = SHA256.hash(data: Data(clientKey)) - var authMessage = firstMessageBare; authMessage.append(","); authMessage.append(contentsOf: message); authMessage.append(","); authMessage.append(contentsOf: clientFinalNoProof) + var authMessage = firstMessageBare; authMessage.append(.comma); authMessage.append(contentsOf: message); authMessage.append(.comma); authMessage.append(contentsOf: clientFinalNoProof) let clientSignature = HMAC.authenticationCode(for: authMessage, using: .init(data: storedKey)) var clientProof = Array(clientKey) @@ -485,7 +482,7 @@ fileprivate final class SASLMechanism_SCRAM_SHA256_Common { } // Generate a `client-final-message` - var clientFinalMessage = clientFinalNoProof; clientFinalMessage.append(",") + var clientFinalMessage = clientFinalNoProof; clientFinalMessage.append(.comma) guard let proofPart = SCRAMMessageParser.serialize([.p(Array(clientProof))]) else { throw SASLAuthenticationError.genericAuthenticationFailure } clientFinalMessage.append(contentsOf: proofPart) @@ -590,7 +587,7 @@ fileprivate final class SASLMechanism_SCRAM_SHA256_Common { // Compute client signature let clientKey = HMAC.authenticationCode(for: "Client Key".data(using: .utf8)!, using: .init(data: saltedPassword)) let storedKey = SHA256.hash(data: Data(clientKey)) - var authMessage = clientBareFirstMessage; authMessage.append(","); authMessage.append(contentsOf: serverFirstMessage); authMessage.append(","); authMessage.append(contentsOf: message.dropLast(proof.count + 3)) + var authMessage = clientBareFirstMessage; authMessage.append(.comma); authMessage.append(contentsOf: serverFirstMessage); authMessage.append(.comma); authMessage.append(contentsOf: message.dropLast(proof.count + 3)) let clientSignature = HMAC.authenticationCode(for: authMessage, using: .init(data: storedKey)) // Recompute client key from signature and proof, verify match diff --git a/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift b/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift index ba3c6a3f..3e3c9d65 100644 --- a/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift +++ b/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift @@ -803,7 +803,7 @@ final class ConnectionPoolTests: XCTestCase { pool.connectionReceivedNewMaxStreamSetting(connection, newMaxStreamSetting: 21) - for (index, request) in requests.enumerated() { + for (_, request) in requests.enumerated() { let connection = try await request.future.success connections.append(connection) } diff --git a/Tests/PostgresNIOTests/New/Extensions/ConnectionAction+TestUtils.swift b/Tests/PostgresNIOTests/New/Extensions/ConnectionAction+TestUtils.swift index febeee37..d20032a8 100644 --- a/Tests/PostgresNIOTests/New/Extensions/ConnectionAction+TestUtils.swift +++ b/Tests/PostgresNIOTests/New/Extensions/ConnectionAction+TestUtils.swift @@ -2,7 +2,8 @@ import class Foundation.JSONEncoder import NIOCore @testable import PostgresNIO -extension ConnectionStateMachine.ConnectionAction: Equatable { +// fully-qualifying all types in the extension has the same effect as adding a `@retroactive` before the protocol +extension PostgresNIO.ConnectionStateMachine.ConnectionAction: Swift.Equatable { public static func == (lhs: Self, rhs: Self) -> Bool { switch (lhs, rhs) { case (.read, read): @@ -47,7 +48,8 @@ extension ConnectionStateMachine.ConnectionAction: Equatable { } } -extension ConnectionStateMachine.ConnectionAction.CleanUpContext: Equatable { +// fully-qualifying all types in the extension has the same effect as adding a `@retroactive` before the protocol' +extension PostgresNIO.ConnectionStateMachine.ConnectionAction.CleanUpContext: Swift.Equatable { public static func == (lhs: Self, rhs: Self) -> Bool { guard lhs.closePromise?.futureResult === rhs.closePromise?.futureResult else { return false @@ -96,13 +98,15 @@ extension ConnectionStateMachine { } } -extension PSQLError: Equatable { +// fully-qualifying all types in the extension has the same effect as adding a `@retroactive` before the protocol +extension PostgresNIO.PSQLError: Swift.Equatable { public static func == (lhs: PSQLError, rhs: PSQLError) -> Bool { return true } } -extension PSQLTask: Equatable { +// fully-qualifying all types in the extension has the same effect as adding a `@retroactive` before the protocol +extension PostgresNIO.PSQLTask: Swift.Equatable { public static func == (lhs: PSQLTask, rhs: PSQLTask) -> Bool { switch (lhs, rhs) { case (.extendedQuery(let lhs), .extendedQuery(let rhs)): diff --git a/Tests/PostgresNIOTests/New/Messages/DataRowTests.swift b/Tests/PostgresNIOTests/New/Messages/DataRowTests.swift index db31b98a..a90d1e93 100644 --- a/Tests/PostgresNIOTests/New/Messages/DataRowTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/DataRowTests.swift @@ -113,7 +113,7 @@ class DataRowTests: XCTestCase { } } -extension DataRow: ExpressibleByArrayLiteral { +extension PostgresNIO.DataRow: Swift.ExpressibleByArrayLiteral { public typealias ArrayLiteralElement = PostgresEncodable public init(arrayLiteral elements: PostgresEncodable...) { From 85d189c461b96a73f42df7b61c9d16dd06f74bfa Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Wed, 21 Feb 2024 17:24:55 +0100 Subject: [PATCH 062/106] Run queries directly on PostgresClient (#456) --- Sources/PostgresNIO/New/PSQLRowStream.swift | 27 +++++----- Sources/PostgresNIO/Pool/PostgresClient.swift | 52 +++++++++++++++++++ .../PostgresClientTests.swift | 37 +++++++++++++ 3 files changed, 104 insertions(+), 12 deletions(-) diff --git a/Sources/PostgresNIO/New/PSQLRowStream.swift b/Sources/PostgresNIO/New/PSQLRowStream.swift index 0255e462..b7f2d4fb 100644 --- a/Sources/PostgresNIO/New/PSQLRowStream.swift +++ b/Sources/PostgresNIO/New/PSQLRowStream.swift @@ -35,7 +35,7 @@ final class PSQLRowStream: @unchecked Sendable { case iteratingRows(onRow: (PostgresRow) throws -> (), EventLoopPromise, PSQLRowsDataSource) case waitingForAll([PostgresRow], EventLoopPromise<[PostgresRow]>, PSQLRowsDataSource) case consumed(Result) - case asyncSequence(AsyncSequenceSource, PSQLRowsDataSource) + case asyncSequence(AsyncSequenceSource, PSQLRowsDataSource, onFinish: @Sendable () -> ()) } internal let rowDescription: [RowDescription.Column] @@ -75,7 +75,7 @@ final class PSQLRowStream: @unchecked Sendable { // MARK: Async Sequence - func asyncSequence() -> PostgresRowSequence { + func asyncSequence(onFinish: @escaping @Sendable () -> () = {}) -> PostgresRowSequence { self.eventLoop.preconditionInEventLoop() guard case .waitingForConsumer(let bufferState) = self.downstreamState else { @@ -95,13 +95,13 @@ final class PSQLRowStream: @unchecked Sendable { switch bufferState { case .streaming(let bufferedRows, let dataSource): let yieldResult = source.yield(contentsOf: bufferedRows) - self.downstreamState = .asyncSequence(source, dataSource) - + self.downstreamState = .asyncSequence(source, dataSource, onFinish: onFinish) self.executeActionBasedOnYieldResult(yieldResult, source: dataSource) case .finished(let buffer, let commandTag): _ = source.yield(contentsOf: buffer) source.finish() + onFinish() self.downstreamState = .consumed(.success(commandTag)) case .failure(let error): @@ -130,7 +130,7 @@ final class PSQLRowStream: @unchecked Sendable { case .consumed: break - case .asyncSequence(_, let dataSource): + case .asyncSequence(_, let dataSource, _): dataSource.request(for: self) } } @@ -147,9 +147,10 @@ final class PSQLRowStream: @unchecked Sendable { private func cancel0() { switch self.downstreamState { - case .asyncSequence(_, let dataSource): + case .asyncSequence(_, let dataSource, let onFinish): self.downstreamState = .consumed(.failure(CancellationError())) dataSource.cancel(for: self) + onFinish() case .consumed: return @@ -320,7 +321,7 @@ final class PSQLRowStream: @unchecked Sendable { // immediately request more dataSource.request(for: self) - case .asyncSequence(let consumer, let source): + case .asyncSequence(let consumer, let source, _): let yieldResult = consumer.yield(contentsOf: newRows) self.executeActionBasedOnYieldResult(yieldResult, source: source) @@ -359,10 +360,11 @@ final class PSQLRowStream: @unchecked Sendable { self.downstreamState = .consumed(.success(commandTag)) promise.succeed(rows) - case .asyncSequence(let source, _): - source.finish() + case .asyncSequence(let source, _, let onFinish): self.downstreamState = .consumed(.success(commandTag)) - + source.finish() + onFinish() + case .consumed: break } @@ -384,9 +386,10 @@ final class PSQLRowStream: @unchecked Sendable { self.downstreamState = .consumed(.failure(error)) promise.fail(error) - case .asyncSequence(let consumer, _): - consumer.finish(error) + case .asyncSequence(let consumer, _, let onFinish): self.downstreamState = .consumed(.failure(error)) + consumer.finish(error) + onFinish() case .consumed: break diff --git a/Sources/PostgresNIO/Pool/PostgresClient.swift b/Sources/PostgresNIO/Pool/PostgresClient.swift index fc5a5b00..5b1bfa38 100644 --- a/Sources/PostgresNIO/Pool/PostgresClient.swift +++ b/Sources/PostgresNIO/Pool/PostgresClient.swift @@ -290,6 +290,58 @@ public final class PostgresClient: Sendable { return try await closure(connection) } + /// Run a query on the Postgres server the client is connected to. + /// + /// - Parameters: + /// - query: The ``PostgresQuery`` to run + /// - logger: The `Logger` to log into for the query + /// - file: The file, the query was started in. Used for better error reporting. + /// - line: The line, the query was started in. Used for better error reporting. + /// - Returns: A ``PostgresRowSequence`` containing the rows the server sent as the query result. + /// The sequence be discarded. + @discardableResult + public func query( + _ query: PostgresQuery, + logger: Logger, + file: String = #fileID, + line: Int = #line + ) async throws -> PostgresRowSequence { + do { + guard query.binds.count <= Int(UInt16.max) else { + throw PSQLError(code: .tooManyParameters, query: query, file: file, line: line) + } + + let connection = try await self.leaseConnection() + + var logger = logger + logger[postgresMetadataKey: .connectionID] = "\(connection.id)" + + let promise = connection.channel.eventLoop.makePromise(of: PSQLRowStream.self) + let context = ExtendedQueryContext( + query: query, + logger: logger, + promise: promise + ) + + connection.channel.write(HandlerTask.extendedQuery(context), promise: nil) + + promise.futureResult.whenFailure { _ in + self.pool.releaseConnection(connection) + } + + return try await promise.futureResult.map { + $0.asyncSequence(onFinish: { + self.pool.releaseConnection(connection) + }) + }.get() + } catch var error as PSQLError { + error.file = file + error.line = line + error.query = query + throw error // rethrow with more metadata + } + } + /// The client's run method. Users must call this function in order to start the client's background task processing /// like creating and destroying connections and running timers. /// diff --git a/Tests/IntegrationTests/PostgresClientTests.swift b/Tests/IntegrationTests/PostgresClientTests.swift index b1e7f9a8..4f22517e 100644 --- a/Tests/IntegrationTests/PostgresClientTests.swift +++ b/Tests/IntegrationTests/PostgresClientTests.swift @@ -41,6 +41,43 @@ final class PostgresClientTests: XCTestCase { taskGroup.cancelAll() } } + + func testQueryDirectly() async throws { + var mlogger = Logger(label: "test") + mlogger.logLevel = .debug + let logger = mlogger + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 8) + self.addTeardownBlock { + try await eventLoopGroup.shutdownGracefully() + } + + let clientConfig = PostgresClient.Configuration.makeTestConfiguration() + let client = PostgresClient(configuration: clientConfig, eventLoopGroup: eventLoopGroup, backgroundLogger: logger) + + await withThrowingTaskGroup(of: Void.self) { taskGroup in + taskGroup.addTask { + await client.run() + } + + for i in 0..<10000 { + taskGroup.addTask { + do { + try await client.query("SELECT 1", logger: logger) + logger.info("Success", metadata: ["run": "\(i)"]) + } catch { + XCTFail("Unexpected error: \(error)") + } + } + } + + for _ in 0..<10000 { + _ = await taskGroup.nextResult()! + } + + taskGroup.cancelAll() + } + } + } @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) From 0679ede84f4c628f4d60810c32a33ced02e178ea Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Wed, 21 Feb 2024 17:50:15 +0100 Subject: [PATCH 063/106] Fix prepared statements (#455) --- .../Connection/PostgresConnection.swift | 9 ++- .../ConnectionStateMachine.swift | 8 +- .../ExtendedQueryStateMachine.swift | 14 ++-- Sources/PostgresNIO/New/PSQLTask.swift | 14 +++- .../New/PostgresChannelHandler.swift | 10 ++- .../PostgresNIO/New/PreparedStatement.swift | 23 +++++- Tests/IntegrationTests/AsyncTests.swift | 81 +++++++++++++++++++ .../PrepareStatementStateMachineTests.swift | 12 +-- .../PreparedStatementStateMachineTests.swift | 1 + .../ConnectionAction+TestUtils.swift | 4 +- .../New/PostgresConnectionTests.swift | 8 +- 11 files changed, 150 insertions(+), 34 deletions(-) diff --git a/Sources/PostgresNIO/Connection/PostgresConnection.swift b/Sources/PostgresNIO/Connection/PostgresConnection.swift index f79a5555..eb9dc791 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection.swift @@ -234,6 +234,7 @@ public final class PostgresConnection: @unchecked Sendable { let context = ExtendedQueryContext( name: name, query: query, + bindingDataTypes: [], logger: logger, promise: promise ) @@ -472,9 +473,10 @@ extension PostgresConnection { let bindings = try preparedStatement.makeBindings() let promise = self.channel.eventLoop.makePromise(of: PSQLRowStream.self) let task = HandlerTask.executePreparedStatement(.init( - name: String(reflecting: Statement.self), + name: Statement.name, sql: Statement.sql, bindings: bindings, + bindingDataTypes: Statement.bindingDataTypes, logger: logger, promise: promise )) @@ -493,10 +495,10 @@ extension PostgresConnection { ) throw error // rethrow with more metadata } - } /// Execute a prepared statement, taking care of the preparation when necessary + @_disfavoredOverload public func execute( _ preparedStatement: Statement, logger: Logger, @@ -506,9 +508,10 @@ extension PostgresConnection { let bindings = try preparedStatement.makeBindings() let promise = self.channel.eventLoop.makePromise(of: PSQLRowStream.self) let task = HandlerTask.executePreparedStatement(.init( - name: String(reflecting: Statement.self), + name: Statement.name, sql: Statement.sql, bindings: bindings, + bindingDataTypes: Statement.bindingDataTypes, logger: logger, promise: promise )) diff --git a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift index 8c3252de..9d264bcc 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift @@ -97,7 +97,7 @@ struct ConnectionStateMachine { case forwardStreamError(PSQLError, read: Bool, cleanupContext: CleanUpContext?) // Prepare statement actions - case sendParseDescribeSync(name: String, query: String) + case sendParseDescribeSync(name: String, query: String, bindingDataTypes: [PostgresDataType]) case succeedPreparedStatementCreation(EventLoopPromise, with: RowDescription?) case failPreparedStatementCreation(EventLoopPromise, with: PSQLError, cleanupContext: CleanUpContext?) @@ -587,7 +587,7 @@ struct ConnectionStateMachine { switch queryContext.query { case .executeStatement(_, let promise), .unnamed(_, let promise): return .failQuery(promise, with: psqlErrror, cleanupContext: nil) - case .prepareStatement(_, _, let promise): + case .prepareStatement(_, _, _, let promise): return .failPreparedStatementCreation(promise, with: psqlErrror, cleanupContext: nil) } case .closeCommand(let closeContext): @@ -1057,8 +1057,8 @@ extension ConnectionStateMachine { return .read case .wait: return .wait - case .sendParseDescribeSync(name: let name, query: let query): - return .sendParseDescribeSync(name: name, query: query) + case .sendParseDescribeSync(name: let name, query: let query, bindingDataTypes: let bindingDataTypes): + return .sendParseDescribeSync(name: name, query: query, bindingDataTypes: bindingDataTypes) case .succeedPreparedStatementCreation(let promise, with: let rowDescription): return .succeedPreparedStatementCreation(promise, with: rowDescription) case .failPreparedStatementCreation(let promise, with: let error): diff --git a/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift index 3a84031b..78f0d202 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift @@ -26,7 +26,7 @@ struct ExtendedQueryStateMachine { enum Action { case sendParseDescribeBindExecuteSync(PostgresQuery) - case sendParseDescribeSync(name: String, query: String) + case sendParseDescribeSync(name: String, query: String, bindingDataTypes: [PostgresDataType]) case sendBindExecuteSync(PSQLExecuteStatement) // --- general actions @@ -79,10 +79,10 @@ struct ExtendedQueryStateMachine { return .sendBindExecuteSync(prepared) } - case .prepareStatement(let name, let query, _): + case .prepareStatement(let name, let query, let bindingDataTypes, _): return self.avoidingStateMachineCoW { state -> Action in state = .messagesSent(queryContext) - return .sendParseDescribeSync(name: name, query: query) + return .sendParseDescribeSync(name: name, query: query, bindingDataTypes: bindingDataTypes) } } } @@ -107,7 +107,7 @@ struct ExtendedQueryStateMachine { case .unnamed(_, let eventLoopPromise), .executeStatement(_, let eventLoopPromise): return .failQuery(eventLoopPromise, with: .queryCancelled) - case .prepareStatement(_, _, let eventLoopPromise): + case .prepareStatement(_, _, _, let eventLoopPromise): return .failPreparedStatementCreation(eventLoopPromise, with: .queryCancelled) } @@ -165,7 +165,7 @@ struct ExtendedQueryStateMachine { return .wait } - case .prepareStatement(_, _, let promise): + case .prepareStatement(_, _, _, let promise): return self.avoidingStateMachineCoW { state -> Action in state = .noDataMessageReceived(queryContext) return .succeedPreparedStatementCreation(promise, with: nil) @@ -200,7 +200,7 @@ struct ExtendedQueryStateMachine { case .unnamed, .executeStatement: return .wait - case .prepareStatement(_, _, let eventLoopPromise): + case .prepareStatement(_, _, _, let eventLoopPromise): return .succeedPreparedStatementCreation(eventLoopPromise, with: rowDescription) } } @@ -477,7 +477,7 @@ struct ExtendedQueryStateMachine { switch context.query { case .unnamed(_, let eventLoopPromise), .executeStatement(_, let eventLoopPromise): return .failQuery(eventLoopPromise, with: error) - case .prepareStatement(_, _, let eventLoopPromise): + case .prepareStatement(_, _, _, let eventLoopPromise): return .failPreparedStatementCreation(eventLoopPromise, with: error) } } diff --git a/Sources/PostgresNIO/New/PSQLTask.swift b/Sources/PostgresNIO/New/PSQLTask.swift index 6308a5b3..363f9394 100644 --- a/Sources/PostgresNIO/New/PSQLTask.swift +++ b/Sources/PostgresNIO/New/PSQLTask.swift @@ -21,7 +21,7 @@ enum PSQLTask { eventLoopPromise.fail(error) case .executeStatement(_, let eventLoopPromise): eventLoopPromise.fail(error) - case .prepareStatement(_, _, let eventLoopPromise): + case .prepareStatement(_, _, _, let eventLoopPromise): eventLoopPromise.fail(error) } @@ -35,7 +35,7 @@ final class ExtendedQueryContext { enum Query { case unnamed(PostgresQuery, EventLoopPromise) case executeStatement(PSQLExecuteStatement, EventLoopPromise) - case prepareStatement(name: String, query: String, EventLoopPromise) + case prepareStatement(name: String, query: String, bindingDataTypes: [PostgresDataType], EventLoopPromise) } let query: Query @@ -62,10 +62,11 @@ final class ExtendedQueryContext { init( name: String, query: String, + bindingDataTypes: [PostgresDataType], logger: Logger, promise: EventLoopPromise ) { - self.query = .prepareStatement(name: name, query: query, promise) + self.query = .prepareStatement(name: name, query: query, bindingDataTypes: bindingDataTypes, promise) self.logger = logger } } @@ -73,6 +74,7 @@ final class ExtendedQueryContext { final class PreparedStatementContext: Sendable { let name: String let sql: String + let bindingDataTypes: [PostgresDataType] let bindings: PostgresBindings let logger: Logger let promise: EventLoopPromise @@ -81,12 +83,18 @@ final class PreparedStatementContext: Sendable { name: String, sql: String, bindings: PostgresBindings, + bindingDataTypes: [PostgresDataType], logger: Logger, promise: EventLoopPromise ) { self.name = name self.sql = sql self.bindings = bindings + if bindingDataTypes.isEmpty { + self.bindingDataTypes = bindings.metadata.map(\.dataType) + } else { + self.bindingDataTypes = bindingDataTypes + } self.logger = logger self.promise = promise } diff --git a/Sources/PostgresNIO/New/PostgresChannelHandler.swift b/Sources/PostgresNIO/New/PostgresChannelHandler.swift index 54ae0fc9..32dea4a5 100644 --- a/Sources/PostgresNIO/New/PostgresChannelHandler.swift +++ b/Sources/PostgresNIO/New/PostgresChannelHandler.swift @@ -345,8 +345,8 @@ final class PostgresChannelHandler: ChannelDuplexHandler { self.closeConnectionAndCleanup(cleanupContext, context: context) case .fireChannelInactive: context.fireChannelInactive() - case .sendParseDescribeSync(let name, let query): - self.sendParseDecribeAndSyncMessage(statementName: name, query: query, context: context) + case .sendParseDescribeSync(let name, let query, let bindingDataTypes): + self.sendParseDescribeAndSyncMessage(statementName: name, query: query, bindingDataTypes: bindingDataTypes, context: context) case .sendBindExecuteSync(let executeStatement): self.sendBindExecuteAndSyncMessage(executeStatement: executeStatement, context: context) case .sendParseDescribeBindExecuteSync(let query): @@ -489,13 +489,14 @@ final class PostgresChannelHandler: ChannelDuplexHandler { } } - private func sendParseDecribeAndSyncMessage( + private func sendParseDescribeAndSyncMessage( statementName: String, query: String, + bindingDataTypes: [PostgresDataType], context: ChannelHandlerContext ) { precondition(self.rowStream == nil, "Expected to not have an open stream at this point") - self.encoder.parse(preparedStatementName: statementName, query: query, parameters: []) + self.encoder.parse(preparedStatementName: statementName, query: query, parameters: bindingDataTypes) self.encoder.describePreparedStatement(statementName) self.encoder.sync() context.writeAndFlush(self.wrapOutboundOut(self.encoder.flushBuffer()), promise: nil) @@ -724,6 +725,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler { return .extendedQuery(.init( name: preparedStatement.name, query: preparedStatement.sql, + bindingDataTypes: preparedStatement.bindingDataTypes, logger: preparedStatement.logger, promise: promise )) diff --git a/Sources/PostgresNIO/New/PreparedStatement.swift b/Sources/PostgresNIO/New/PreparedStatement.swift index 1e0b5d5a..21165388 100644 --- a/Sources/PostgresNIO/New/PreparedStatement.swift +++ b/Sources/PostgresNIO/New/PreparedStatement.swift @@ -26,15 +26,36 @@ /// Structs conforming to this protocol can then be used with `PostgresConnection.execute(_ preparedStatement:, logger:)`, /// which will take care of preparing the statement on the server side and executing it. public protocol PostgresPreparedStatement: Sendable { + /// The prepared statements name. + /// + /// > Note: There is a default implementation that returns the implementor's name. + static var name: String { get } + /// The type rows returned by the statement will be decoded into associatedtype Row /// The SQL statement to prepare on the database server. static var sql: String { get } - /// Make the bindings to provided concrete values to use when executing the prepared SQL statement + /// The postgres data types of the values that are bind when this statement is executed. + /// + /// If an empty array is returned the datatypes are inferred from the ``PostgresBindings`` returned + /// from ``PostgresPreparedStatement/makeBindings()``. + /// + /// > Note: There is a default implementation that returns an empty array, which will lead to + /// automatic inference. + static var bindingDataTypes: [PostgresDataType] { get } + + /// Make the bindings to provided concrete values to use when executing the prepared SQL statement. + /// The order must match ``PostgresPreparedStatement/bindingDataTypes-4b6tx``. func makeBindings() throws -> PostgresBindings /// Decode a row returned by the database into an instance of `Row` func decodeRow(_ row: PostgresRow) throws -> Row } + +extension PostgresPreparedStatement { + public static var name: String { String(reflecting: self) } + + public static var bindingDataTypes: [PostgresDataType] { [] } +} diff --git a/Tests/IntegrationTests/AsyncTests.swift b/Tests/IntegrationTests/AsyncTests.swift index 91b5656c..75e5b6ba 100644 --- a/Tests/IntegrationTests/AsyncTests.swift +++ b/Tests/IntegrationTests/AsyncTests.swift @@ -358,6 +358,87 @@ final class AsyncPostgresConnectionTests: XCTestCase { } } } + + static let preparedStatementTestTable = "AsyncTestPreparedStatementTestTable" + func testPreparedStatementWithIntegerBinding() async throws { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + let eventLoop = eventLoopGroup.next() + + struct InsertPreparedStatement: PostgresPreparedStatement { + static let name = "INSERT-AsyncTestPreparedStatementTestTable" + + static let sql = #"INSERT INTO "\#(AsyncPostgresConnectionTests.preparedStatementTestTable)" (uuid) VALUES ($1);"# + typealias Row = () + + var uuid: UUID + + func makeBindings() -> PostgresBindings { + var bindings = PostgresBindings() + bindings.append(self.uuid) + return bindings + } + + func decodeRow(_ row: PostgresNIO.PostgresRow) throws -> Row { + () + } + } + + struct SelectPreparedStatement: PostgresPreparedStatement { + static let name = "SELECT-AsyncTestPreparedStatementTestTable" + + static let sql = #"SELECT id, uuid FROM "\#(AsyncPostgresConnectionTests.preparedStatementTestTable)" WHERE id <= $1;"# + typealias Row = (Int, UUID) + + var id: Int + + func makeBindings() -> PostgresBindings { + var bindings = PostgresBindings() + bindings.append(self.id) + return bindings + } + + func decodeRow(_ row: PostgresNIO.PostgresRow) throws -> Row { + try row.decode((Int, UUID).self) + } + } + + do { + try await withTestConnection(on: eventLoop) { connection in + try await connection.query(""" + CREATE TABLE IF NOT EXISTS "\(unescaped: Self.preparedStatementTestTable)" ( + id SERIAL PRIMARY KEY, + uuid UUID NOT NULL + ) + """, + logger: .psqlTest + ) + + _ = try await connection.execute(InsertPreparedStatement(uuid: .init()), logger: .psqlTest) + _ = try await connection.execute(InsertPreparedStatement(uuid: .init()), logger: .psqlTest) + _ = try await connection.execute(InsertPreparedStatement(uuid: .init()), logger: .psqlTest) + _ = try await connection.execute(InsertPreparedStatement(uuid: .init()), logger: .psqlTest) + _ = try await connection.execute(InsertPreparedStatement(uuid: .init()), logger: .psqlTest) + + let rows = try await connection.execute(SelectPreparedStatement(id: 3), logger: .psqlTest) + var counter = 0 + for try await (id, uuid) in rows { + Logger.psqlTest.info("Received row", metadata: [ + "id": "\(id)", "uuid": "\(uuid)" + ]) + counter += 1 + } + + try await connection.query(""" + DROP TABLE "\(unescaped: Self.preparedStatementTestTable)"; + """, + logger: .psqlTest + ) + } + } catch { + XCTFail("Unexpected error: \(String(describing: error))") + } + } } extension XCTestCase { diff --git a/Tests/PostgresNIOTests/New/Connection State Machine/PrepareStatementStateMachineTests.swift b/Tests/PostgresNIOTests/New/Connection State Machine/PrepareStatementStateMachineTests.swift index 6a08afeb..547f5cdf 100644 --- a/Tests/PostgresNIOTests/New/Connection State Machine/PrepareStatementStateMachineTests.swift +++ b/Tests/PostgresNIOTests/New/Connection State Machine/PrepareStatementStateMachineTests.swift @@ -12,11 +12,11 @@ class PrepareStatementStateMachineTests: XCTestCase { let name = "haha" let query = #"SELECT id FROM users WHERE id = $1 "# let prepareStatementContext = ExtendedQueryContext( - name: name, query: query, logger: .psqlTest, promise: promise + name: name, query: query, bindingDataTypes: [], logger: .psqlTest, promise: promise ) XCTAssertEqual(state.enqueue(task: .extendedQuery(prepareStatementContext)), - .sendParseDescribeSync(name: name, query: query)) + .sendParseDescribeSync(name: name, query: query, bindingDataTypes: [])) XCTAssertEqual(state.parseCompleteReceived(), .wait) XCTAssertEqual(state.parameterDescriptionReceived(.init(dataTypes: [.int8])), .wait) @@ -38,11 +38,11 @@ class PrepareStatementStateMachineTests: XCTestCase { let name = "haha" let query = #"DELETE FROM users WHERE id = $1 "# let prepareStatementContext = ExtendedQueryContext( - name: name, query: query, logger: .psqlTest, promise: promise + name: name, query: query, bindingDataTypes: [], logger: .psqlTest, promise: promise ) XCTAssertEqual(state.enqueue(task: .extendedQuery(prepareStatementContext)), - .sendParseDescribeSync(name: name, query: query)) + .sendParseDescribeSync(name: name, query: query, bindingDataTypes: [])) XCTAssertEqual(state.parseCompleteReceived(), .wait) XCTAssertEqual(state.parameterDescriptionReceived(.init(dataTypes: [.int8])), .wait) @@ -60,11 +60,11 @@ class PrepareStatementStateMachineTests: XCTestCase { let name = "haha" let query = #"DELETE FROM users WHERE id = $1 "# let prepareStatementContext = ExtendedQueryContext( - name: name, query: query, logger: .psqlTest, promise: promise + name: name, query: query, bindingDataTypes: [], logger: .psqlTest, promise: promise ) XCTAssertEqual(state.enqueue(task: .extendedQuery(prepareStatementContext)), - .sendParseDescribeSync(name: name, query: query)) + .sendParseDescribeSync(name: name, query: query, bindingDataTypes: [])) XCTAssertEqual(state.parseCompleteReceived(), .wait) XCTAssertEqual(state.parameterDescriptionReceived(.init(dataTypes: [.int8])), .wait) diff --git a/Tests/PostgresNIOTests/New/Connection State Machine/PreparedStatementStateMachineTests.swift b/Tests/PostgresNIOTests/New/Connection State Machine/PreparedStatementStateMachineTests.swift index ab77a57c..f6c1ddf7 100644 --- a/Tests/PostgresNIOTests/New/Connection State Machine/PreparedStatementStateMachineTests.swift +++ b/Tests/PostgresNIOTests/New/Connection State Machine/PreparedStatementStateMachineTests.swift @@ -152,6 +152,7 @@ class PreparedStatementStateMachineTests: XCTestCase { name: "test", sql: "INSERT INTO test_table (column1) VALUES (1)", bindings: PostgresBindings(), + bindingDataTypes: [], logger: .psqlTest, promise: promise ) diff --git a/Tests/PostgresNIOTests/New/Extensions/ConnectionAction+TestUtils.swift b/Tests/PostgresNIOTests/New/Extensions/ConnectionAction+TestUtils.swift index d20032a8..9a1224d8 100644 --- a/Tests/PostgresNIOTests/New/Extensions/ConnectionAction+TestUtils.swift +++ b/Tests/PostgresNIOTests/New/Extensions/ConnectionAction+TestUtils.swift @@ -36,8 +36,8 @@ extension PostgresNIO.ConnectionStateMachine.ConnectionAction: Swift.Equatable { return lhsBuffer == rhsBuffer && lhsCommandTag == rhsCommandTag case (.forwardStreamError(let lhsError, let lhsRead, let lhsCleanupContext), .forwardStreamError(let rhsError , let rhsRead, let rhsCleanupContext)): return lhsError == rhsError && lhsRead == rhsRead && lhsCleanupContext == rhsCleanupContext - case (.sendParseDescribeSync(let lhsName, let lhsQuery), .sendParseDescribeSync(let rhsName, let rhsQuery)): - return lhsName == rhsName && lhsQuery == rhsQuery + case (.sendParseDescribeSync(let lhsName, let lhsQuery, let lhsDataTypes), .sendParseDescribeSync(let rhsName, let rhsQuery, let rhsDataTypes)): + return lhsName == rhsName && lhsQuery == rhsQuery && lhsDataTypes == rhsDataTypes case (.succeedPreparedStatementCreation(let lhsPromise, let lhsRowDescription), .succeedPreparedStatementCreation(let rhsPromise, let rhsRowDescription)): return lhsPromise.futureResult === rhsPromise.futureResult && lhsRowDescription == rhsRowDescription case (.fireChannelInactive, .fireChannelInactive): diff --git a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift index 82baf914..a773cf2c 100644 --- a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift @@ -337,7 +337,7 @@ class PostgresConnectionTests: XCTestCase { let prepareRequest = try await channel.waitForPrepareRequest() XCTAssertEqual(prepareRequest.parse.query, "SELECT datname FROM pg_stat_activity WHERE state = $1") - XCTAssertEqual(prepareRequest.parse.parameters.count, 0) + XCTAssertEqual(prepareRequest.parse.parameters.first, .text) guard case .preparedStatement(let name) = prepareRequest.describe else { fatalError("Describe should contain a prepared statement") } @@ -393,7 +393,7 @@ class PostgresConnectionTests: XCTestCase { let prepareRequest = try await channel.waitForPrepareRequest() XCTAssertEqual(prepareRequest.parse.query, "SELECT datname FROM pg_stat_activity WHERE state = $1") - XCTAssertEqual(prepareRequest.parse.parameters.count, 0) + XCTAssertEqual(prepareRequest.parse.parameters.first, .text) guard case .preparedStatement(let name) = prepareRequest.describe else { fatalError("Describe should contain a prepared statement") } @@ -487,7 +487,7 @@ class PostgresConnectionTests: XCTestCase { // The channel deduplicates prepare requests, we're going to see only one of them let prepareRequest = try await channel.waitForPrepareRequest() XCTAssertEqual(prepareRequest.parse.query, "SELECT datname FROM pg_stat_activity WHERE state = $1") - XCTAssertEqual(prepareRequest.parse.parameters.count, 0) + XCTAssertEqual(prepareRequest.parse.parameters.first, .text) guard case .preparedStatement(let name) = prepareRequest.describe else { fatalError("Describe should contain a prepared statement") } @@ -555,7 +555,7 @@ class PostgresConnectionTests: XCTestCase { let prepareRequest = try await channel.waitForPrepareRequest() XCTAssertEqual(prepareRequest.parse.query, "SELECT datname FROM pg_stat_activity WHERE state = $1") - XCTAssertEqual(prepareRequest.parse.parameters.count, 0) + XCTAssertEqual(prepareRequest.parse.parameters.first, .text) guard case .preparedStatement(let name) = prepareRequest.describe else { fatalError("Describe should contain a prepared statement") } From 17b23b1a24f0e7b451be6ae27d30f29a4c29099f Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Wed, 21 Feb 2024 19:26:49 +0100 Subject: [PATCH 064/106] Adds prepared statement support to client (#459) --- Sources/PostgresNIO/Pool/PostgresClient.swift | 42 ++++++++++ .../PostgresClientTests.swift | 81 ++++++++++++++++++- .../New/PostgresConnectionTests.swift | 2 +- 3 files changed, 121 insertions(+), 4 deletions(-) diff --git a/Sources/PostgresNIO/Pool/PostgresClient.swift b/Sources/PostgresNIO/Pool/PostgresClient.swift index 5b1bfa38..4a576085 100644 --- a/Sources/PostgresNIO/Pool/PostgresClient.swift +++ b/Sources/PostgresNIO/Pool/PostgresClient.swift @@ -342,6 +342,48 @@ public final class PostgresClient: Sendable { } } + /// Execute a prepared statement, taking care of the preparation when necessary + public func execute( + _ preparedStatement: Statement, + logger: Logger, + file: String = #fileID, + line: Int = #line + ) async throws -> AsyncThrowingMapSequence where Row == Statement.Row { + let bindings = try preparedStatement.makeBindings() + + do { + let connection = try await self.leaseConnection() + + let promise = connection.channel.eventLoop.makePromise(of: PSQLRowStream.self) + let task = HandlerTask.executePreparedStatement(.init( + name: String(reflecting: Statement.self), + sql: Statement.sql, + bindings: bindings, + bindingDataTypes: Statement.bindingDataTypes, + logger: logger, + promise: promise + )) + connection.channel.write(task, promise: nil) + + promise.futureResult.whenFailure { _ in + self.pool.releaseConnection(connection) + } + + return try await promise.futureResult + .map { $0.asyncSequence(onFinish: { self.pool.releaseConnection(connection) }) } + .get() + .map { try preparedStatement.decodeRow($0) } + } catch var error as PSQLError { + error.file = file + error.line = line + error.query = .init( + unsafeSQL: Statement.sql, + binds: bindings + ) + throw error // rethrow with more metadata + } + } + /// The client's run method. Users must call this function in order to start the client's background task processing /// like creating and destroying connections and running timers. /// diff --git a/Tests/IntegrationTests/PostgresClientTests.swift b/Tests/IntegrationTests/PostgresClientTests.swift index 4f22517e..9115dc82 100644 --- a/Tests/IntegrationTests/PostgresClientTests.swift +++ b/Tests/IntegrationTests/PostgresClientTests.swift @@ -25,16 +25,17 @@ final class PostgresClientTests: XCTestCase { await client.run() } - for i in 0..<10000 { + let iterations = 1000 + + for i in 0.. PostgresBindings { + var bindings = PostgresBindings() + bindings.append(self.id) + return bindings + } + func decodeRow(_ row: PostgresNIO.PostgresRow) throws -> Row { + try row.decode(Row.self) + } + } + + for try await (id, uuid) in try await client.execute(Example(id: 200), logger: logger) { + logger.info("id: \(id), uuid: \(uuid.uuidString)") + } + + try await client.query( + """ + DROP TABLE "\(unescaped: tableName)"; + """, + logger: logger + ) + + taskGroup.cancelAll() + } + } catch { + XCTFail("Unexpected error: \(String(reflecting: error))") + } + } } @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) diff --git a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift index a773cf2c..f2cd96f8 100644 --- a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift @@ -155,7 +155,7 @@ class PostgresConnectionTests: XCTestCase { _ = try await iterator.next() XCTFail("Did not expect to not throw") } catch { - print(error) + self.logger.error("error", metadata: ["error": "\(error)"]) } } From c75349fadbffaba06dedaf6c0eb936a4edff5dc5 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Thu, 22 Feb 2024 14:00:01 +0100 Subject: [PATCH 065/106] PostgresClient implements ServiceLifecycle's Service (#457) --- Package.swift | 2 ++ Sources/PostgresNIO/Pool/PostgresClient.swift | 27 ++++++++++--------- 2 files changed, 17 insertions(+), 12 deletions(-) diff --git a/Package.swift b/Package.swift index 814335bd..4d008371 100644 --- a/Package.swift +++ b/Package.swift @@ -22,6 +22,7 @@ let package = Package( .package(url: "https://github.com/apple/swift-crypto.git", "2.0.0" ..< "4.0.0"), .package(url: "https://github.com/apple/swift-metrics.git", from: "2.4.1"), .package(url: "https://github.com/apple/swift-log.git", from: "1.5.3"), + .package(url: "https://github.com/swift-server/swift-service-lifecycle.git", from: "2.4.1"), ], targets: [ .target( @@ -39,6 +40,7 @@ let package = Package( .product(name: "NIOTLS", package: "swift-nio"), .product(name: "NIOSSL", package: "swift-nio-ssl"), .product(name: "NIOFoundationCompat", package: "swift-nio"), + .product(name: "ServiceLifecycle", package: "swift-service-lifecycle"), ] ), .target( diff --git a/Sources/PostgresNIO/Pool/PostgresClient.swift b/Sources/PostgresNIO/Pool/PostgresClient.swift index 4a576085..2c21cce7 100644 --- a/Sources/PostgresNIO/Pool/PostgresClient.swift +++ b/Sources/PostgresNIO/Pool/PostgresClient.swift @@ -2,6 +2,7 @@ import NIOCore import NIOSSL import Atomics import Logging +import ServiceLifecycle import _ConnectionPoolModule /// A Postgres client that is backed by an underlying connection pool. Use ``Configuration`` to change the client's @@ -17,23 +18,22 @@ import _ConnectionPoolModule /// client.run() // !important /// } /// -/// taskGroup.addTask { -/// client.withConnection { connection in -/// do { -/// let rows = try await connection.query("SELECT userID, name, age FROM users;") -/// for try await (userID, name, age) in rows.decode((UUID, String, Int).self) { -/// // do something with the values -/// } -/// } catch { -/// // handle errors -/// } +/// do { +/// let rows = try await connection.query("SELECT userID, name, age FROM users;") +/// for try await (userID, name, age) in rows.decode((UUID, String, Int).self) { +/// // do something with the values /// } +/// } catch { +/// // handle errors /// } +/// +/// // shutdown the client +/// taskGroup.cancelAll() /// } /// ``` @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) @_spi(ConnectionPool) -public final class PostgresClient: Sendable { +public final class PostgresClient: Sendable, ServiceLifecycle.Service { public struct Configuration: Sendable { public struct TLS: Sendable { enum Base { @@ -391,7 +391,10 @@ public final class PostgresClient: Sendable { public func run() async { let atomicOp = self.runningAtomic.compareExchange(expected: false, desired: true, ordering: .relaxed) precondition(!atomicOp.original, "PostgresClient.run() should just be called once!") - await self.pool.run() + + await cancelOnGracefulShutdown { + await self.pool.run() + } } // MARK: - Private Methods - From 7632411e5964f0fb8ffa92acd5cd7b6be46625a6 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Thu, 22 Feb 2024 15:43:51 +0100 Subject: [PATCH 066/106] Make PostgresClient API (#460) --- Sources/PostgresNIO/Pool/PostgresClient.swift | 30 ++++++++++++++----- .../PostgresClientTests.swift | 2 +- 2 files changed, 24 insertions(+), 8 deletions(-) diff --git a/Sources/PostgresNIO/Pool/PostgresClient.swift b/Sources/PostgresNIO/Pool/PostgresClient.swift index 2c21cce7..865dafc8 100644 --- a/Sources/PostgresNIO/Pool/PostgresClient.swift +++ b/Sources/PostgresNIO/Pool/PostgresClient.swift @@ -8,11 +8,11 @@ import _ConnectionPoolModule /// A Postgres client that is backed by an underlying connection pool. Use ``Configuration`` to change the client's /// behavior. /// -/// > Important: +/// > Warning: /// The client can only lease connections if the user is running the client's ``run()`` method in a long running task: /// /// ```swift -/// let client = PostgresClient(configuration: configuration, logger: logger) +/// let client = PostgresClient(configuration: configuration) /// await withTaskGroup(of: Void.self) { /// taskGroup.addTask { /// client.run() // !important @@ -32,7 +32,6 @@ import _ConnectionPoolModule /// } /// ``` @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) -@_spi(ConnectionPool) public final class PostgresClient: Sendable, ServiceLifecycle.Service { public struct Configuration: Sendable { public struct TLS: Sendable { @@ -246,8 +245,22 @@ public final class PostgresClient: Sendable, ServiceLifecycle.Service { let factory: ConnectionFactory let runningAtomic = ManagedAtomic(false) let backgroundLogger: Logger - + + /// Creates a new ``PostgresClient``, that does not log any background information. + /// Don't forget to run ``run()`` the client in a long running task. + /// + /// - Parameters: + /// - configuration: The client's configuration. See ``Configuration`` for details. + /// - eventLoopGroup: The underlying NIO `EventLoopGroup`. Defaults to ``defaultEventLoopGroup``. + public convenience init( + configuration: Configuration, + eventLoopGroup: any EventLoopGroup = PostgresClient.defaultEventLoopGroup + ) { + self.init(configuration: configuration, eventLoopGroup: eventLoopGroup, backgroundLogger: Self.loggingDisabled) + } + /// Creates a new ``PostgresClient``. Don't forget to run ``run()`` the client in a long running task. + /// /// - Parameters: /// - configuration: The client's configuration. See ``Configuration`` for details. /// - eventLoopGroup: The underlying NIO `EventLoopGroup`. Defaults to ``defaultEventLoopGroup``. @@ -302,10 +315,11 @@ public final class PostgresClient: Sendable, ServiceLifecycle.Service { @discardableResult public func query( _ query: PostgresQuery, - logger: Logger, + logger: Logger? = nil, file: String = #fileID, line: Int = #line ) async throws -> PostgresRowSequence { + let logger = logger ?? Self.loggingDisabled do { guard query.binds.count <= Int(UInt16.max) else { throw PSQLError(code: .tooManyParameters, query: query, file: file, line: line) @@ -345,11 +359,12 @@ public final class PostgresClient: Sendable, ServiceLifecycle.Service { /// Execute a prepared statement, taking care of the preparation when necessary public func execute( _ preparedStatement: Statement, - logger: Logger, + logger: Logger? = nil, file: String = #fileID, line: Int = #line ) async throws -> AsyncThrowingMapSequence where Row == Statement.Row { let bindings = try preparedStatement.makeBindings() + let logger = logger ?? Self.loggingDisabled do { let connection = try await self.leaseConnection() @@ -412,6 +427,8 @@ public final class PostgresClient: Sendable, ServiceLifecycle.Service { public static var defaultEventLoopGroup: EventLoopGroup { PostgresConnection.defaultEventLoopGroup } + + static let loggingDisabled = Logger(label: "Postgres-do-not-log", factory: { _ in SwiftLogNoOpLogHandler() }) } @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) @@ -444,7 +461,6 @@ extension ConnectionPoolConfiguration { } } -@_spi(ConnectionPool) extension PostgresConnection: PooledConnection { public func close() { self.channel.close(mode: .all, promise: nil) diff --git a/Tests/IntegrationTests/PostgresClientTests.swift b/Tests/IntegrationTests/PostgresClientTests.swift index 9115dc82..d6d89dc3 100644 --- a/Tests/IntegrationTests/PostgresClientTests.swift +++ b/Tests/IntegrationTests/PostgresClientTests.swift @@ -27,7 +27,7 @@ final class PostgresClientTests: XCTestCase { let iterations = 1000 - for i in 0.. Date: Thu, 22 Feb 2024 16:14:04 +0100 Subject: [PATCH 067/106] Improve docs before releasing PostgresClient (#461) --- README.md | 66 ++++++------- Snippets/Birthdays.swift | 74 +++++++++++++++ Snippets/PostgresClient.swift | 40 ++++++++ Sources/PostgresNIO/Docs.docc/coding.md | 39 ++++++++ Sources/PostgresNIO/Docs.docc/deprecated.md | 43 +++++++++ Sources/PostgresNIO/Docs.docc/index.md | 93 +++++++------------ Sources/PostgresNIO/Docs.docc/listen.md | 9 ++ Sources/PostgresNIO/Docs.docc/migrations.md | 12 --- .../Docs.docc/prepared-statement.md | 7 ++ .../PostgresNIO/Docs.docc/running-queries.md | 27 ++++++ Sources/PostgresNIO/Pool/PostgresClient.swift | 62 ++++++++----- 11 files changed, 336 insertions(+), 136 deletions(-) create mode 100644 Snippets/Birthdays.swift create mode 100644 Snippets/PostgresClient.swift create mode 100644 Sources/PostgresNIO/Docs.docc/coding.md create mode 100644 Sources/PostgresNIO/Docs.docc/deprecated.md create mode 100644 Sources/PostgresNIO/Docs.docc/listen.md create mode 100644 Sources/PostgresNIO/Docs.docc/prepared-statement.md create mode 100644 Sources/PostgresNIO/Docs.docc/running-queries.md diff --git a/README.md b/README.md index ef1dc4ec..c2dc545e 100644 --- a/README.md +++ b/README.md @@ -28,15 +28,14 @@ Features: - A [`PostgresConnection`] which allows you to connect to, authorize with, query, and retrieve results from a PostgreSQL server +- A [`PostgresClient`] which pools and manages connections - An async/await interface that supports backpressure - Automatic conversions between Swift primitive types and the Postgres wire format -- Integrated with the Swift server ecosystem, including use of [SwiftLog]. +- Integrated with the Swift server ecosystem, including use of [SwiftLog] and [ServiceLifecycle]. - Designed to run efficiently on all supported platforms (tested extensively on Linux and Darwin systems) - Support for `Network.framework` when available (e.g. on Apple platforms) - Supports running on Unix Domain Sockets -PostgresNIO does not provide a `ConnectionPool` as of today, but this is a [feature high on our list](https://github.com/vapor/postgres-nio/issues/256). If you need a `ConnectionPool` today, please have a look at Vapor's [PostgresKit]. - ## API Docs Check out the [PostgresNIO API docs][Documentation] for a @@ -44,13 +43,16 @@ detailed look at all of the classes, structs, protocols, and more. ## Getting started +Interested in an example? We prepared a simple [Birthday example](/vapor/postgres-nio/tree/main/Snippets/Birthdays.swift) +in the Snippets folder. + #### Adding the dependency Add `PostgresNIO` as dependency to your `Package.swift`: ```swift dependencies: [ - .package(url: "https://github.com/vapor/postgres-nio.git", from: "1.14.0"), + .package(url: "https://github.com/vapor/postgres-nio.git", from: "1.21.0"), ... ] ``` @@ -64,14 +66,14 @@ Add `PostgresNIO` to the target you want to use it in: ] ``` -#### Creating a connection +#### Creating a client -To create a connection, first create a connection configuration object: +To create a [`PostgresClient`], which pools connections for you, first create a configuration object: ```swift import PostgresNIO -let config = PostgresConnection.Configuration( +let config = PostgresClient.Configuration( host: "localhost", port: 5432, username: "my_username", @@ -81,50 +83,35 @@ let config = PostgresConnection.Configuration( ) ``` -To create a connection we need a [`Logger`], that is used to log connection background events. - +Next you can create you client with it: ```swift -import Logging - -let logger = Logger(label: "postgres-logger") +let client = PostgresClient(configuration: config) ``` -Now we can put it together: - +Once you have create your client, you must [`run()`] it: ```swift -import PostgresNIO -import Logging - -let logger = Logger(label: "postgres-logger") - -let config = PostgresConnection.Configuration( - host: "localhost", - port: 5432, - username: "my_username", - password: "my_password", - database: "my_database", - tls: .disable -) +await withTaskGroup(of: Void.self) { taskGroup in + taskGroup.addTask { + await client.run() // !important + } -let connection = try await PostgresConnection.connect( - configuration: config, - id: 1, - logger: logger -) + // You can use the client while the `client.run()` method is not cancelled. -// Close your connection once done -try await connection.close() + // To shutdown the client, cancel its run method, by cancelling the taskGroup. + taskGroup.cancelAll() +} ``` #### Querying -Once a connection is established, queries can be sent to the server. This is very straightforward: +Once a client is running, queries can be sent to the server. This is straightforward: ```swift -let rows = try await connection.query("SELECT id, username, birthday FROM users", logger: logger) +let rows = try await client.query("SELECT id, username, birthday FROM users") ``` -The query will return a [`PostgresRowSequence`], which is an AsyncSequence of [`PostgresRow`]s. The rows can be iterated one-by-one: +The query will return a [`PostgresRowSequence`], which is an AsyncSequence of [`PostgresRow`]s. +The rows can be iterated one-by-one: ```swift for try await row in rows { @@ -160,7 +147,7 @@ Sending parameterized queries to the database is also supported (in the coolest let id = 1 let username = "fancyuser" let birthday = Date() -try await connection.query(""" +try await client.query(""" INSERT INTO users (id, username, birthday) VALUES (\(id), \(username), \(birthday)) """, logger: logger @@ -184,6 +171,8 @@ Please see [SECURITY.md] for details on the security process. [Security.md]: https://github.com/vapor/.github/blob/main/SECURITY.md [`PostgresConnection`]: https://api.vapor.codes/postgresnio/documentation/postgresnio/postgresconnection +[`PostgresClient`]: https://api.vapor.codes/postgresnio/documentation/postgresnio/postgresclient +[`run()`]: https://api.vapor.codes/postgresnio/documentation/postgresnio/postgresclient/run() [`query(_:logger:)`]: https://api.vapor.codes/postgresnio/documentation/postgresnio/postgresconnection/query(_:logger:file:line:)-9mkfn [`PostgresQuery`]: https://api.vapor.codes/postgresnio/documentation/postgresnio/postgresquery [`PostgresRow`]: https://api.vapor.codes/postgresnio/documentation/postgresnio/postgresrow @@ -193,4 +182,5 @@ Please see [SECURITY.md] for details on the security process. [SwiftNIO]: https://github.com/apple/swift-nio [PostgresKit]: https://github.com/vapor/postgres-kit [SwiftLog]: https://github.com/apple/swift-log +[ServiceLifecycle]: https://github.com/swift-server/swift-service-lifecycle [`Logger`]: https://apple.github.io/swift-log/docs/current/Logging/Structs/Logger.html diff --git a/Snippets/Birthdays.swift b/Snippets/Birthdays.swift new file mode 100644 index 00000000..60516aa1 --- /dev/null +++ b/Snippets/Birthdays.swift @@ -0,0 +1,74 @@ +import PostgresNIO +import Foundation + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +enum Birthday { + static func main() async throws { + // 1. Create a configuration to match server's parameters + let config = PostgresClient.Configuration( + host: "localhost", + port: 5432, + username: "test_username", + password: "test_password", + database: "test_database", + tls: .disable + ) + + // 2. Create a client + let client = PostgresClient(configuration: config) + + // 3. Run the client + try await withThrowingTaskGroup(of: Void.self) { taskGroup in + taskGroup.addTask { + await client.run() // !important + } + + // 4. Create a friends table to store data into + try await client.query(""" + CREATE TABLE IF NOT EXISTS "friends" ( + id SERIAL PRIMARY KEY, + given_name TEXT, + last_name TEXT, + birthday TIMESTAMP WITH TIME ZONE + ) + """ + ) + + // 5. Create a Swift friend representation + struct Friend { + var firstName: String + var lastName: String + var birthday: Date + } + + // 6. Create John Appleseed with special birthday + let dateFormatter = DateFormatter() + dateFormatter.dateFormat = "yyyy-MM-dd" + let johnsBirthday = dateFormatter.date(from: "1960-09-26")! + let friend = Friend(firstName: "Hans", lastName: "Müller", birthday: johnsBirthday) + + // 7. Store friend into the database + try await client.query(""" + INSERT INTO "friends" (given_name, last_name, birthday) + VALUES + (\(friend.firstName), \(friend.lastName), \(friend.birthday)); + """ + ) + + // 8. Query database for the friend we just inserted + let rows = try await client.query(""" + SELECT id, given_name, last_name, birthday FROM "friends" WHERE given_name = \(friend.firstName) + """ + ) + + // 9. Iterate the returned rows, decoding the rows into Swift primitives + for try await (id, firstName, lastName, birthday) in rows.decode((Int, String, String, Date).self) { + print("\(id) | \(firstName) \(lastName), \(birthday)") + } + + // 10. Shutdown the client, by cancelling its run method, through cancelling the taskGroup. + taskGroup.cancelAll() + } + } +} + diff --git a/Snippets/PostgresClient.swift b/Snippets/PostgresClient.swift new file mode 100644 index 00000000..9bfacc28 --- /dev/null +++ b/Snippets/PostgresClient.swift @@ -0,0 +1,40 @@ +import PostgresNIO +import struct Foundation.UUID + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +enum Runner { + static func main() async throws { + +// snippet.configuration +let config = PostgresClient.Configuration( + host: "localhost", + port: 5432, + username: "my_username", + password: "my_password", + database: "my_database", + tls: .disable +) +// snippet.end + +// snippet.makeClient +let client = PostgresClient(configuration: config) +// snippet.end + + } + + static func runAndCancel(client: PostgresClient) async { +// snippet.run +await withTaskGroup(of: Void.self) { taskGroup in + taskGroup.addTask { + await client.run() // !important + } + + // You can use the client while the `client.run()` method is not cancelled. + + // To shutdown the client, cancel its run method, by cancelling the taskGroup. + taskGroup.cancelAll() +} +// snippet.end + } +} + diff --git a/Sources/PostgresNIO/Docs.docc/coding.md b/Sources/PostgresNIO/Docs.docc/coding.md new file mode 100644 index 00000000..3bcc4a7e --- /dev/null +++ b/Sources/PostgresNIO/Docs.docc/coding.md @@ -0,0 +1,39 @@ +# PostgreSQL data types + +Translate Swift data types to Postgres data types and vica versa. Learn how to write translations +for your own custom Swift types. + +## Topics + +### Essentials + +- ``PostgresCodable`` +- ``PostgresDataType`` +- ``PostgresFormat`` +- ``PostgresNumeric`` + +### Encoding + +- ``PostgresEncodable`` +- ``PostgresNonThrowingEncodable`` +- ``PostgresDynamicTypeEncodable`` +- ``PostgresThrowingDynamicTypeEncodable`` +- ``PostgresArrayEncodable`` +- ``PostgresRangeEncodable`` +- ``PostgresRangeArrayEncodable`` +- ``PostgresEncodingContext`` + +### Decoding + +- ``PostgresDecodable`` +- ``PostgresArrayDecodable`` +- ``PostgresRangeDecodable`` +- ``PostgresRangeArrayDecodable`` +- ``PostgresDecodingContext`` + +### JSON + +- ``PostgresJSONEncoder`` +- ``PostgresJSONDecoder`` + + diff --git a/Sources/PostgresNIO/Docs.docc/deprecated.md b/Sources/PostgresNIO/Docs.docc/deprecated.md new file mode 100644 index 00000000..a29465f6 --- /dev/null +++ b/Sources/PostgresNIO/Docs.docc/deprecated.md @@ -0,0 +1,43 @@ +# Deprecations + +`PostgresNIO` follows SemVer 2.0.0. Learn which APIs are considered deprecated and how to migrate to +their replacements. + +``PostgresNIO`` reached 1.0 in April 2020. Since then the maintainers have been hard at work to +guarantee API stability. However as the Swift and Swift on server ecosystem have matured approaches +have changed. The introduction of structured concurrency changed what developers expect from a +modern Swift library. Because of this ``PostgresNIO`` added various APIs that embrace the new Swift +patterns. This means however, that PostgresNIO still offers APIs that have fallen out of favor. +Those are documented here. All those APIs will be removed once the maintainers release the next +major version. The maintainers recommend all adopters to move of those APIs sooner rather than +later. + +## Topics + +### Migrate of deprecated APIs + +- + +### Deprecated APIs + +These types are already deprecated or will be deprecated in the near future. All of them will be +removed from the public API with the next major release. + +- ``PostgresDatabase`` +- ``PostgresData`` +- ``PostgresDataConvertible`` +- ``PostgresQueryResult`` +- ``PostgresJSONCodable`` +- ``PostgresJSONBCodable`` +- ``PostgresMessageEncoder`` +- ``PostgresMessageDecoder`` +- ``PostgresRequest`` +- ``PostgresMessage`` +- ``PostgresMessageType`` +- ``PostgresFormatCode`` +- ``PostgresListenContext`` +- ``PreparedQuery`` +- ``SASLAuthenticationManager`` +- ``SASLAuthenticationMechanism`` +- ``SASLAuthenticationError`` +- ``SASLAuthenticationStepResult`` diff --git a/Sources/PostgresNIO/Docs.docc/index.md b/Sources/PostgresNIO/Docs.docc/index.md index ebe27cd0..6355a7a4 100644 --- a/Sources/PostgresNIO/Docs.docc/index.md +++ b/Sources/PostgresNIO/Docs.docc/index.md @@ -8,80 +8,51 @@ ## Overview -Features: - -- A ``PostgresConnection`` which allows you to connect to, authorize with, query, and retrieve results from a PostgreSQL server using [SwiftNIO]. -- An async/await interface that supports backpressure -- Automatic conversions between Swift primitive types and the Postgres wire format -- Integrated with the Swift server ecosystem, including use of [SwiftLog]. -- Designed to run efficiently on all supported platforms (tested extensively on Linux and Darwin systems) -- Support for `Network.framework` when available (e.g. on Apple platforms) +``PostgresNIO`` allows you to connect to, authorize with, query, and retrieve results from a +PostgreSQL server. PostgreSQL is an open source relational database. + +Use a ``PostgresConnection`` to create a connection to the PostgreSQL server. You can then use it to +run queries and prepared statements against the server. ``PostgresConnection`` also supports +PostgreSQL's Listen & Notify API. + +Developers, who don't want to manage connections themselves, can use the ``PostgresClient``, which +offers the same functionality as ``PostgresConnection``. ``PostgresClient`` +pools connections for rapid connection reuse and hides the complexities of connection +management from the user, allowing developers to focus on their SQL queries. ``PostgresClient`` +implements the `Service` protocol from Service Lifecycle allowing an easy adoption in Swift server +applications. + +``PostgresNIO`` embraces Swift structured concurrency, offering async/await APIs which handle +task cancellation. The query interface makes use of backpressure to ensure that memory can not grow +unbounded for queries that return thousands of rows. + +``PostgresNIO`` runs efficiently on Linux and Apple platforms. On Apple platforms developers can +configure ``PostgresConnection`` to use `Network.framework` as the underlying transport framework. ## Topics -### Articles - -- - -### Connections +### Essentials +- ``PostgresClient`` +- ``PostgresClient/Configuration`` - ``PostgresConnection`` +- -### Querying - -- ``PostgresQuery`` -- ``PostgresBindings`` -- ``PostgresRow`` -- ``PostgresRowSequence`` -- ``PostgresRandomAccessRow`` -- ``PostgresCell`` -- ``PreparedQuery`` -- ``PostgresQueryMetadata`` - -### Encoding and Decoding +### Advanced -- ``PostgresEncodable`` -- ``PostgresEncodingContext`` -- ``PostgresDecodable`` -- ``PostgresDecodingContext`` -- ``PostgresArrayEncodable`` -- ``PostgresArrayDecodable`` -- ``PostgresJSONEncoder`` -- ``PostgresJSONDecoder`` -- ``PostgresDataType`` -- ``PostgresFormat`` -- ``PostgresNumeric`` - -### Notifications - -- ``PostgresListenContext`` +- +- +- ### Errors - ``PostgresError`` - ``PostgresDecodingError`` +- ``PSQLError`` + +### Deprecations -### Deprecated - -These types are already deprecated or will be deprecated in the near future. All of them will be -removed from the public API with the next major release. - -- ``PostgresDatabase`` -- ``PostgresData`` -- ``PostgresDataConvertible`` -- ``PostgresQueryResult`` -- ``PostgresJSONCodable`` -- ``PostgresJSONBCodable`` -- ``PostgresMessageEncoder`` -- ``PostgresMessageDecoder`` -- ``PostgresRequest`` -- ``PostgresMessage`` -- ``PostgresMessageType`` -- ``PostgresFormatCode`` -- ``SASLAuthenticationManager`` -- ``SASLAuthenticationMechanism`` -- ``SASLAuthenticationError`` -- ``SASLAuthenticationStepResult`` +- [SwiftNIO]: https://github.com/apple/swift-nio [SwiftLog]: https://github.com/apple/swift-log diff --git a/Sources/PostgresNIO/Docs.docc/listen.md b/Sources/PostgresNIO/Docs.docc/listen.md new file mode 100644 index 00000000..10c5d8bf --- /dev/null +++ b/Sources/PostgresNIO/Docs.docc/listen.md @@ -0,0 +1,9 @@ +# Listen & Notify + +``PostgresNIO`` supports PostgreSQL's listen and notify API. Learn how to listen for changes and +notify other listeners. + +## Topics + +- ``PostgresNotification`` +- ``PostgresNotificationSequence`` diff --git a/Sources/PostgresNIO/Docs.docc/migrations.md b/Sources/PostgresNIO/Docs.docc/migrations.md index 7185ba06..3a7c634a 100644 --- a/Sources/PostgresNIO/Docs.docc/migrations.md +++ b/Sources/PostgresNIO/Docs.docc/migrations.md @@ -87,16 +87,4 @@ connection.query("SELECT id, name, email, age FROM users").whenComplete { } ``` -## Topics - -### Relevant types - -- ``PostgresConnection`` -- ``PostgresQuery`` -- ``PostgresBindings`` -- ``PostgresRow`` -- ``PostgresRandomAccessRow`` -- ``PostgresEncodable`` -- ``PostgresDecodable`` - [`1.9.0`]: https://github.com/vapor/postgres-nio/releases/tag/1.9.0 diff --git a/Sources/PostgresNIO/Docs.docc/prepared-statement.md b/Sources/PostgresNIO/Docs.docc/prepared-statement.md new file mode 100644 index 00000000..ff4b1c62 --- /dev/null +++ b/Sources/PostgresNIO/Docs.docc/prepared-statement.md @@ -0,0 +1,7 @@ +# Boosting Performance with Prepared Statements + +Improve performance by leveraging PostgreSQL's prepared statements. + +## Topics + +- ``PostgresPreparedStatement`` diff --git a/Sources/PostgresNIO/Docs.docc/running-queries.md b/Sources/PostgresNIO/Docs.docc/running-queries.md new file mode 100644 index 00000000..b2c4586f --- /dev/null +++ b/Sources/PostgresNIO/Docs.docc/running-queries.md @@ -0,0 +1,27 @@ +# Running Queries + +Interact with the PostgreSQL database by running Queries. + +## Overview + + + +You interact with the Postgres database by running SQL [Queries]. + + + +``PostgresQuery`` conforms to + + +## Topics + +- ``PostgresQuery`` +- ``PostgresBindings`` +- ``PostgresRow`` +- ``PostgresRowSequence`` +- ``PostgresRandomAccessRow`` +- ``PostgresCell`` +- ``PostgresQueryMetadata`` + +[Queries]: doc:PostgresQuery +[`ExpressibleByStringInterpolation`]: https://developer.apple.com/documentation/swift/expressiblebystringinterpolation diff --git a/Sources/PostgresNIO/Pool/PostgresClient.swift b/Sources/PostgresNIO/Pool/PostgresClient.swift index 865dafc8..9383ffcd 100644 --- a/Sources/PostgresNIO/Pool/PostgresClient.swift +++ b/Sources/PostgresNIO/Pool/PostgresClient.swift @@ -8,29 +8,28 @@ import _ConnectionPoolModule /// A Postgres client that is backed by an underlying connection pool. Use ``Configuration`` to change the client's /// behavior. /// -/// > Warning: -/// The client can only lease connections if the user is running the client's ``run()`` method in a long running task: +/// ## Creating a client /// -/// ```swift -/// let client = PostgresClient(configuration: configuration) -/// await withTaskGroup(of: Void.self) { -/// taskGroup.addTask { -/// client.run() // !important -/// } +/// You create a ``PostgresClient`` by first creating a ``PostgresClient/Configuration`` struct that you can +/// use to modify the client's behavior. /// -/// do { -/// let rows = try await connection.query("SELECT userID, name, age FROM users;") -/// for try await (userID, name, age) in rows.decode((UUID, String, Int).self) { -/// // do something with the values -/// } -/// } catch { -/// // handle errors -/// } -/// -/// // shutdown the client -/// taskGroup.cancelAll() -/// } -/// ``` +/// @Snippet(path: "postgres-nio/Snippets/PostgresClient", slice: "configuration") +/// +/// Now you can create a client with your configuration object: +/// +/// @Snippet(path: "postgres-nio/Snippets/PostgresClient", slice: "makeClient") +/// +/// ## Running a client +/// +/// ``PostgresClient`` relies on structured concurrency. Because of this it needs a task in which it can schedule all the +/// background work that it needs to do in order to manage connections on the users behave. For this reason, developers +/// must provide a task to the client by scheduling the client's run method in a long running task: +/// +/// @Snippet(path: "postgres-nio/Snippets/PostgresClient", slice: "run") +/// +/// ``PostgresClient`` can not lease connections, if its ``run()`` method isn't active. Cancelling the ``run()`` method +/// is equivalent to closing the client. Once a client's ``run()`` method has been cancelled, executing queries or prepared +/// statements will fail. @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) public final class PostgresClient: Sendable, ServiceLifecycle.Service { public struct Configuration: Sendable { @@ -247,7 +246,9 @@ public final class PostgresClient: Sendable, ServiceLifecycle.Service { let backgroundLogger: Logger /// Creates a new ``PostgresClient``, that does not log any background information. - /// Don't forget to run ``run()`` the client in a long running task. + /// + /// > Warning: + /// The client can only lease connections if the user is running the client's ``run()`` method in a long running task. /// /// - Parameters: /// - configuration: The client's configuration. See ``Configuration`` for details. @@ -399,10 +400,21 @@ public final class PostgresClient: Sendable, ServiceLifecycle.Service { } } - /// The client's run method. Users must call this function in order to start the client's background task processing - /// like creating and destroying connections and running timers. + /// The structured root task for the client's background work. + /// + /// > Warning: + /// Users must call this function in order to allow the client to process any background work. Executing queries, + /// prepared statements or leasing connections will hang until the developer executes the client's ``run()`` + /// method. + /// + /// Cancelling the task which executes the ``run()`` method, is equivalent to closing the client. Once the task + /// has been cancelled the client is not able to process any new queries or prepared statements. + /// + /// @Snippet(path: "postgres-nio/Snippets/PostgresClient", slice: "run") /// - /// Calls to ``withConnection(_:)`` will emit a `logger` warning, if ``run()`` hasn't been called previously. + /// > Note: + /// ``PostgresClient`` implements [ServiceLifecycle](https://github.com/swift-server/swift-service-lifecycle)'s `Service` protocol. Because of this + /// ``PostgresClient`` can be passed to a `ServiceGroup` for easier lifecycle management. public func run() async { let atomicOp = self.runningAtomic.compareExchange(expected: false, desired: true, ordering: .relaxed) precondition(!atomicOp.original, "PostgresClient.run() should just be called once!") From b6496eb211a0d5c225bcc6d3ff4f26c2dd4238de Mon Sep 17 00:00:00 2001 From: Gwynne Raskind Date: Fri, 8 Mar 2024 11:52:04 -0600 Subject: [PATCH 068/106] Fix multiple array type mapping mistakes and add missing date and time array types (#463) * Add missing definitions for Postgres type OIDs 1182 and 1183 (_date and _time), fix typos in the `macaddr8Array` and `datemultirange` types, and add missing array mappings for `timestamp` and `tstzrange`. * Add PostgresArrayCodable conformance for Date * Add tests for date arrays. * Fix test to account for rounding error in conversion to days during Postgres encoding --- .../PostgresNIO/Data/PostgresDataType.swift | 30 +++++++++++---- .../New/Data/Array+PostgresCodable.swift | 7 ++++ Tests/IntegrationTests/PostgresNIOTests.swift | 38 +++++++++++++++++++ .../New/Data/Array+PSQLCodableTests.swift | 4 ++ 4 files changed, 72 insertions(+), 7 deletions(-) diff --git a/Sources/PostgresNIO/Data/PostgresDataType.swift b/Sources/PostgresNIO/Data/PostgresDataType.swift index f3ab4dca..c3e4e747 100644 --- a/Sources/PostgresNIO/Data/PostgresDataType.swift +++ b/Sources/PostgresNIO/Data/PostgresDataType.swift @@ -113,12 +113,14 @@ public struct PostgresDataType: RawRepresentable, Sendable, Hashable, CustomStri /// `774` public static let macaddr8 = PostgresDataType(774) /// `775` - public static let macaddr8Aray = PostgresDataType(775) + @available(*, deprecated, renamed: "macaddr8Array") + public static let macaddr8Aray = Self.macaddr8Array + public static let macaddr8Array = PostgresDataType(775) /// `790` public static let money = PostgresDataType(790) /// `791` @available(*, deprecated, renamed: "moneyArray") - public static let _money = PostgresDataType(791) + public static let _money = Self.moneyArray public static let moneyArray = PostgresDataType(791) /// `829` public static let macaddr = PostgresDataType(829) @@ -192,6 +194,10 @@ public struct PostgresDataType: RawRepresentable, Sendable, Hashable, CustomStri public static let timestamp = PostgresDataType(1114) /// `1115` _timestamp public static let timestampArray = PostgresDataType(1115) + /// `1182` + public static let dateArray = PostgresDataType(1182) + /// `1183` + public static let timeArray = PostgresDataType(1183) /// `1184` public static let timestamptz = PostgresDataType(1184) /// `1185` @@ -446,7 +452,7 @@ public struct PostgresDataType: RawRepresentable, Sendable, Hashable, CustomStri case .circle: return "CIRCLE" case .circleArray: return "CIRCLE[]" case .macaddr8: return "MACADDR8" - case .macaddr8Aray: return "MACADDR8[]" + case .macaddr8Array: return "MACADDR8[]" case .money: return "MONEY" case .moneyArray: return "MONEY[]" case .macaddr: return "MACADDR" @@ -485,6 +491,8 @@ public struct PostgresDataType: RawRepresentable, Sendable, Hashable, CustomStri case .time: return "TIME" case .timestamp: return "TIMESTAMP" case .timestampArray: return "TIMESTAMP[]" + case .dateArray: return "DATE[]" + case .timeArray: return "TIME[]" case .timestamptz: return "TIMESTAMPTZ" case .timestamptzArray: return "TIMESTAMPTZ[]" case .interval: return "INTERVAL" @@ -596,7 +604,7 @@ public struct PostgresDataType: RawRepresentable, Sendable, Hashable, CustomStri case .line: return .lineArray case .cidr: return .cidrArray case .circle: return .circleArray - case .macaddr8Aray: return .macaddr8 + case .macaddr8: return .macaddr8Array case .money: return .moneyArray case .int2vector: return .int2vectorArray case .regproc: return .regprocArray @@ -613,6 +621,9 @@ public struct PostgresDataType: RawRepresentable, Sendable, Hashable, CustomStri case .aclitem: return .aclitemArray case .macaddr: return .macaddrArray case .inet: return .inetArray + case .timestamp: return .timestampArray + case .date: return .dateArray + case .time: return .timeArray case .timestamptz: return .timestamptzArray case .interval: return .intervalArray case .numeric: return .numericArray @@ -635,6 +646,7 @@ public struct PostgresDataType: RawRepresentable, Sendable, Hashable, CustomStri case .regdictionary: return .regdictionaryArray case .numrange: return .numrangeArray case .tsrange: return .tsrangeArray + case .tstzrange: return .tstzrangeArray case .daterange: return .daterangeArray case .jsonpath: return .jsonpathArray case .regnamespace: return .regnamespaceArray @@ -643,7 +655,7 @@ public struct PostgresDataType: RawRepresentable, Sendable, Hashable, CustomStri case .int4multirange: return .int4multirangeArray case .tsmultirange: return .tsmultirangeArray case .tstzmultirange: return .tstzmultirangeArray - case .datemultirange: return .datemultirange + case .datemultirange: return .datemultirangeArray case .int8multirange: return .int8multirangeArray case .bool: return .boolArray case .bytea: return .byteaArray @@ -677,7 +689,7 @@ public struct PostgresDataType: RawRepresentable, Sendable, Hashable, CustomStri case .lineArray: return .line case .cidrArray: return .cidr case .circleArray: return .circle - case .macaddr8: return .macaddr8Aray + case .macaddr8Array: return .macaddr8 case .moneyArray: return .money case .int2vectorArray: return .int2vector case .regprocArray: return .regproc @@ -694,6 +706,9 @@ public struct PostgresDataType: RawRepresentable, Sendable, Hashable, CustomStri case .aclitemArray: return .aclitem case .macaddrArray: return .macaddr case .inetArray: return .inet + case .timestampArray: return .timestamp + case .dateArray: return .date + case .timeArray: return .time case .timestamptzArray: return .timestamptz case .intervalArray: return .interval case .numericArray: return .numeric @@ -716,6 +731,7 @@ public struct PostgresDataType: RawRepresentable, Sendable, Hashable, CustomStri case .regdictionaryArray: return .regdictionary case .numrangeArray: return .numrange case .tsrangeArray: return .tsrange + case .tstzrangeArray: return .tstzrange case .daterangeArray: return .daterange case .jsonpathArray: return .jsonpath case .regnamespaceArray: return .regnamespace @@ -724,7 +740,7 @@ public struct PostgresDataType: RawRepresentable, Sendable, Hashable, CustomStri case .int4multirangeArray: return .int4multirange case .tsmultirangeArray: return .tsmultirange case .tstzmultirangeArray: return .tstzmultirange - case .datemultirange: return .datemultirange + case .datemultirangeArray: return .datemultirange case .int8multirangeArray: return .int8multirange case .boolArray: return .bool case .byteaArray: return .bytea diff --git a/Sources/PostgresNIO/New/Data/Array+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/Array+PostgresCodable.swift index d605a6c1..ddab0fff 100644 --- a/Sources/PostgresNIO/New/Data/Array+PostgresCodable.swift +++ b/Sources/PostgresNIO/New/Data/Array+PostgresCodable.swift @@ -1,4 +1,5 @@ import NIOCore +import struct Foundation.Date import struct Foundation.UUID // MARK: Protocols @@ -85,6 +86,12 @@ extension UUID: PostgresArrayEncodable { public static var psqlArrayType: PostgresDataType { .uuidArray } } +extension Date: PostgresArrayDecodable {} + +extension Date: PostgresArrayEncodable { + public static var psqlArrayType: PostgresDataType { .timestamptzArray } +} + extension Range: PostgresArrayDecodable where Bound: PostgresRangeArrayDecodable {} extension Range: PostgresArrayEncodable where Bound: PostgresRangeArrayEncodable { diff --git a/Tests/IntegrationTests/PostgresNIOTests.swift b/Tests/IntegrationTests/PostgresNIOTests.swift index ea4d8d05..de6aaf73 100644 --- a/Tests/IntegrationTests/PostgresNIOTests.swift +++ b/Tests/IntegrationTests/PostgresNIOTests.swift @@ -783,6 +783,44 @@ final class PostgresNIOTests: XCTestCase { XCTAssertEqual(row?[data: "array"].array(of: Int64?.self), [1, nil, 3]) } + @available(*, deprecated, message: "Testing deprecated functionality") + func testDateArraySerialize() { + var conn: PostgresConnection? + XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) + defer { XCTAssertNoThrow( try conn?.close().wait() ) } + let date1 = Date(timeIntervalSince1970: 1704088800), + date2 = Date(timeIntervalSince1970: 1706767200), + date3 = Date(timeIntervalSince1970: 1709272800) + var rows: PostgresQueryResult? + XCTAssertNoThrow(rows = try conn?.query(""" + select + $1::timestamptz[] as array + """, [ + PostgresData(array: [date1, date2, date3]) + ]).wait()) + let row = rows?.first?.makeRandomAccess() + XCTAssertEqual(row?[data: "array"].array(of: Date.self), [date1, date2, date3]) + } + + @available(*, deprecated, message: "Testing deprecated functionality") + func testDateArraySerializeAsPostgresDate() { + var conn: PostgresConnection? + XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) + defer { XCTAssertNoThrow(try conn?.close().wait()) } + let date1 = Date(timeIntervalSince1970: 1704088800),//8766 + date2 = Date(timeIntervalSince1970: 1706767200),//8797 + date3 = Date(timeIntervalSince1970: 1709272800) //8826 + var data = PostgresData(array: [date1, date2, date3].map { Int32(($0.timeIntervalSince1970 - 946_684_800) / 86_400).postgresData }, elementType: .date) + data.type = .dateArray // N.B.: `.date` format is an Int32 count of days since psqlStartDate + var rows: PostgresQueryResult? + XCTAssertNoThrow(rows = try conn?.query("select $1::date[] as array", [data]).wait()) + let row = rows?.first?.makeRandomAccess() + XCTAssertEqual( + row?[data: "array"].array(of: Date.self)?.map { Int32((($0.timeIntervalSince1970 - 946_684_800) / 86_400).rounded(.toNearestOrAwayFromZero)) }, + [date1, date2, date3].map { Int32((($0.timeIntervalSince1970 - 946_684_800) / 86_400).rounded(.toNearestOrAwayFromZero)) } + ) + } + // https://github.com/vapor/postgres-nio/issues/143 func testEmptyStringFromNonNullColumn() { var conn: PostgresConnection? diff --git a/Tests/PostgresNIOTests/New/Data/Array+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/Array+PSQLCodableTests.swift index 79d47c30..bfffef52 100644 --- a/Tests/PostgresNIOTests/New/Data/Array+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/Array+PSQLCodableTests.swift @@ -56,6 +56,10 @@ class Array_PSQLCodableTests: XCTestCase { XCTAssertEqual(UUID.psqlType, .uuid) XCTAssertEqual([UUID].psqlType, .uuidArray) + XCTAssertEqual(Date.psqlArrayType, .timestamptzArray) + XCTAssertEqual(Date.psqlType, .timestamptz) + XCTAssertEqual([Date].psqlType, .timestamptzArray) + XCTAssertEqual(Range.psqlArrayType, .int4RangeArray) XCTAssertEqual(Range.psqlType, .int4Range) XCTAssertEqual([Range].psqlType, .int4RangeArray) From 43929b0fa76dae1c3679ea6bea49737b1c94cf40 Mon Sep 17 00:00:00 2001 From: Gwynne Raskind Date: Fri, 8 Mar 2024 11:59:32 -0600 Subject: [PATCH 069/106] Minor package cleanup (#464) * Disable CodeQL CI, since GitHub seems disinclined to fix their mistakes. * Fix a few very minor issues in the API docs and README. * Make LOG_LEVEL env actually work in tests * Update CI for Swift 5.10 release * We only need two macOS tests, not four --- .github/workflows/test.yml | 20 +++--- README.md | 6 +- .../PostgresNIO/Docs.docc/images/article.svg | 1 - .../Docs.docc/images/vapor-postgres-logo.svg | 60 ------------------ .../images/vapor-postgresnio-logo.svg | 21 +++++++ .../PostgresNIO/Docs.docc/theme-settings.json | 61 ++++++------------- Tests/IntegrationTests/PostgresNIOTests.swift | 8 ++- Tests/IntegrationTests/Utilities.swift | 22 +++---- 8 files changed, 67 insertions(+), 132 deletions(-) delete mode 100644 Sources/PostgresNIO/Docs.docc/images/article.svg delete mode 100644 Sources/PostgresNIO/Docs.docc/images/vapor-postgres-logo.svg create mode 100644 Sources/PostgresNIO/Docs.docc/images/vapor-postgresnio-logo.svg diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 3d1f44a4..49d2cef1 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -21,10 +21,10 @@ jobs: - swift:5.7-jammy - swift:5.8-jammy - swift:5.9-jammy - - swiftlang/swift:nightly-5.10-jammy + - swift:5.10-jammy - swiftlang/swift:nightly-main-jammy include: - - swift-image: swift:5.9-jammy + - swift-image: swift:5.10-jammy code-coverage: true container: ${{ matrix.swift-image }} runs-on: ubuntu-latest @@ -63,7 +63,7 @@ jobs: - postgres-image: postgres:12 postgres-auth: trust container: - image: swift:5.9-jammy + image: swift:5.10-jammy volumes: [ 'pgrunshare:/var/run/postgresql' ] runs-on: ubuntu-latest env: @@ -140,7 +140,12 @@ jobs: xcode-version: - '~14.3' - '~15.0' - runs-on: macos-13 + include: + - xcode-version: '~14.3' + macos-version: 'macos-13' + - xcode-version: '~15.0' + macos-version: 'macos-14' + runs-on: ${{ matrix.macos-version }} env: POSTGRES_HOSTNAME: 127.0.0.1 POSTGRES_USER: 'test_username' @@ -188,8 +193,9 @@ jobs: swift package diagnose-api-breaking-changes origin/main gh-codeql: + if: ${{ false }} runs-on: ubuntu-latest - container: swift:5.9-jammy + container: swift:jammy permissions: { actions: write, contents: read, security-events: write } steps: - name: Check out code @@ -197,10 +203,10 @@ jobs: - name: Mark repo safe in non-fake global config run: git config --global --add safe.directory "${GITHUB_WORKSPACE}" - name: Initialize CodeQL - uses: github/codeql-action/init@v2 + uses: github/codeql-action/init@v3 with: languages: swift - name: Perform build run: swift build - name: Run CodeQL analyze - uses: github/codeql-action/analyze@v2 + uses: github/codeql-action/analyze@v3 diff --git a/README.md b/README.md index c2dc545e..9e7d4e3b 100644 --- a/README.md +++ b/README.md @@ -16,10 +16,10 @@ Continuous Integration - Swift 5.7 + + Swift 5.7+ - - SSWG Incubation Level: Graduated + + SSWG Incubation Level: Graduated

diff --git a/Sources/PostgresNIO/Docs.docc/images/article.svg b/Sources/PostgresNIO/Docs.docc/images/article.svg deleted file mode 100644 index 3dc6a66c..00000000 --- a/Sources/PostgresNIO/Docs.docc/images/article.svg +++ /dev/null @@ -1 +0,0 @@ - diff --git a/Sources/PostgresNIO/Docs.docc/images/vapor-postgres-logo.svg b/Sources/PostgresNIO/Docs.docc/images/vapor-postgres-logo.svg deleted file mode 100644 index 2b3fe0b1..00000000 --- a/Sources/PostgresNIO/Docs.docc/images/vapor-postgres-logo.svg +++ /dev/null @@ -1,60 +0,0 @@ - - - PostgresNIO - - - - - - - - - - - - - - - - - - diff --git a/Sources/PostgresNIO/Docs.docc/images/vapor-postgresnio-logo.svg b/Sources/PostgresNIO/Docs.docc/images/vapor-postgresnio-logo.svg new file mode 100644 index 00000000..a831189d --- /dev/null +++ b/Sources/PostgresNIO/Docs.docc/images/vapor-postgresnio-logo.svg @@ -0,0 +1,21 @@ + + + + + + + + + + + + + + + + + + diff --git a/Sources/PostgresNIO/Docs.docc/theme-settings.json b/Sources/PostgresNIO/Docs.docc/theme-settings.json index a8042a54..dda76197 100644 --- a/Sources/PostgresNIO/Docs.docc/theme-settings.json +++ b/Sources/PostgresNIO/Docs.docc/theme-settings.json @@ -1,46 +1,21 @@ { - "theme": { - "aside": { - "border-radius": "6px", - "border-style": "double", - "border-width": "3px" - }, - "border-radius": "0", - "button": { - "border-radius": "16px", - "border-width": "1px", - "border-style": "solid" - }, - "code": { - "border-radius": "16px", - "border-width": "1px", - "border-style": "solid" - }, - "color": { - "fill": { - "dark": "rgb(0, 0, 0)", - "light": "rgb(255, 255, 255)" - }, - "psql-blue": "#336791", - "documentation-intro-fill": "radial-gradient(circle at top, var(--color-documentation-intro-accent) 30%, #000 100%)", - "documentation-intro-accent": "var(--color-psql-blue)", - "documentation-intro-accent-outer": { - "dark": "rgb(255, 255, 255)", - "light": "rgb(0, 0, 0)" - }, - "documentation-intro-accent-inner": { - "dark": "rgb(0, 0, 0)", - "light": "rgb(255, 255, 255)" - } - }, - "icons": { - "technology": "/postgresnio/images/vapor-postgres-logo.svg", - "article": "/postgresnio/images/article.svg" - } + "theme": { + "aside": { "border-radius": "6px", "border-style": "double", "border-width": "3px" }, + "border-radius": "0", + "button": { "border-radius": "16px", "border-width": "1px", "border-style": "solid" }, + "code": { "border-radius": "16px", "border-width": "1px", "border-style": "solid" }, + "color": { + "psqlnio": "#336791", + "documentation-intro-fill": "radial-gradient(circle at top, var(--color-psqlnio) 30%, #000 100%)", + "documentation-intro-accent": "var(--color-psqlnio)", + "logo-base": { "dark": "#fff", "light": "#000" }, + "logo-shape": { "dark": "#000", "light": "#fff" }, + "fill": { "dark": "#000", "light": "#fff" } }, - "features": { - "quickNavigation": { - "enable": true - } - } + "icons": { "technology": "/postgresnio/images/vapor-postgresnio-logo.svg" } + }, + "features": { + "quickNavigation": { "enable": true }, + "i18n": { "enable": true } + } } diff --git a/Tests/IntegrationTests/PostgresNIOTests.swift b/Tests/IntegrationTests/PostgresNIOTests.swift index de6aaf73..88df2519 100644 --- a/Tests/IntegrationTests/PostgresNIOTests.swift +++ b/Tests/IntegrationTests/PostgresNIOTests.swift @@ -9,12 +9,14 @@ import NIOSSL final class PostgresNIOTests: XCTestCase { private var group: EventLoopGroup! - private var eventLoop: EventLoop { self.group.next() } + override class func setUp() { + XCTAssertTrue(isLoggingConfigured) + } + override func setUpWithError() throws { try super.setUpWithError() - XCTAssertTrue(isLoggingConfigured) self.group = MultiThreadedEventLoopGroup(numberOfThreads: 1) } @@ -1475,7 +1477,7 @@ final class PostgresNIOTests: XCTestCase { let isLoggingConfigured: Bool = { LoggingSystem.bootstrap { label in var handler = StreamLogHandler.standardOutput(label: label) - handler.logLevel = env("LOG_LEVEL").flatMap { Logger.Level(rawValue: $0) } ?? .debug + handler.logLevel = env("LOG_LEVEL").flatMap { .init(rawValue: $0) } ?? .info return handler } return true diff --git a/Tests/IntegrationTests/Utilities.swift b/Tests/IntegrationTests/Utilities.swift index b1788110..001d9ee4 100644 --- a/Tests/IntegrationTests/Utilities.swift +++ b/Tests/IntegrationTests/Utilities.swift @@ -24,10 +24,8 @@ extension PostgresConnection { } } - static func test(on eventLoop: EventLoop, logLevel: Logger.Level = .info) -> EventLoopFuture { - var logger = Logger(label: "postgres.connection.test") - logger.logLevel = logLevel - + static func test(on eventLoop: EventLoop) -> EventLoopFuture { + let logger = Logger(label: "postgres.connection.test") let config = PostgresConnection.Configuration( host: env("POSTGRES_HOSTNAME") ?? "localhost", port: env("POSTGRES_PORT").flatMap(Int.init(_:)) ?? 5432, @@ -40,10 +38,8 @@ extension PostgresConnection { return PostgresConnection.connect(on: eventLoop, configuration: config, id: 0, logger: logger) } - static func testUDS(on eventLoop: EventLoop, logLevel: Logger.Level = .info) -> EventLoopFuture { - var logger = Logger(label: "postgres.connection.test") - logger.logLevel = logLevel - + static func testUDS(on eventLoop: EventLoop) -> EventLoopFuture { + let logger = Logger(label: "postgres.connection.test") let config = PostgresConnection.Configuration( unixSocketPath: env("POSTGRES_SOCKET") ?? "/tmp/.s.PGSQL.\(env("POSTGRES_PORT").flatMap(Int.init(_:)) ?? 5432)", username: env("POSTGRES_USER") ?? "test_username", @@ -54,10 +50,8 @@ extension PostgresConnection { return PostgresConnection.connect(on: eventLoop, configuration: config, id: 0, logger: logger) } - static func testChannel(_ channel: Channel, on eventLoop: EventLoop, logLevel: Logger.Level = .info) -> EventLoopFuture { - var logger = Logger(label: "postgres.connection.test") - logger.logLevel = logLevel - + static func testChannel(_ channel: Channel, on eventLoop: EventLoop) -> EventLoopFuture { + let logger = Logger(label: "postgres.connection.test") let config = PostgresConnection.Configuration( establishedChannel: channel, username: env("POSTGRES_USER") ?? "test_username", @@ -71,9 +65,7 @@ extension PostgresConnection { extension Logger { static var psqlTest: Logger { - var logger = Logger(label: "psql.test") - logger.logLevel = .info - return logger + .init(label: "psql.test") } } From 6f0fc054babeed13850f9014e03ced7a1d714868 Mon Sep 17 00:00:00 2001 From: Jia-Han Wu Date: Sat, 9 Mar 2024 04:27:46 +0800 Subject: [PATCH 070/106] Fix `reverseChunked(by:)` Method Implementation (#465) --- Sources/PostgresNIO/Data/PostgresData+Numeric.swift | 10 ++-------- Tests/IntegrationTests/PostgresNIOTests.swift | 8 ++++++-- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/Sources/PostgresNIO/Data/PostgresData+Numeric.swift b/Sources/PostgresNIO/Data/PostgresData+Numeric.swift index 5e564d6d..e736a61c 100644 --- a/Sources/PostgresNIO/Data/PostgresData+Numeric.swift +++ b/Sources/PostgresNIO/Data/PostgresData+Numeric.swift @@ -268,16 +268,10 @@ private extension Collection { // splits the collection into chunks of the supplied size // if the collection is not evenly divisible, the first chunk will be smaller func reverseChunked(by maxSize: Int) -> [SubSequence] { - var lastDistance = 0 var chunkStartIndex = self.startIndex return stride(from: 0, to: self.count, by: maxSize).reversed().map { current in - let distance = (self.count - current) - lastDistance - lastDistance = distance - let chunkEndOffset = Swift.min( - self.distance(from: chunkStartIndex, to: self.endIndex), - distance - ) - let chunkEndIndex = self.index(chunkStartIndex, offsetBy: chunkEndOffset) + let distance = self.count - current + let chunkEndIndex = self.index(self.startIndex, offsetBy: distance) defer { chunkStartIndex = chunkEndIndex } return self[chunkStartIndex.. Date: Tue, 19 Mar 2024 02:27:23 -0500 Subject: [PATCH 071/106] Temporarily disable Thread Sanitizer in CI --- .github/workflows/test.yml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 49d2cef1..8c6c3897 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -38,11 +38,11 @@ jobs: swift --version - name: Check out package uses: actions/checkout@v4 - - name: Run unit tests with Thread Sanitizer + - name: Run unit tests env: CODE_COVERAGE: ${{ matrix.code-coverage && '--enable-code-coverage' || '' }} run: | - swift test --filter='^(PostgresNIOTests|ConnectionPoolModuleTests)' --sanitize=thread ${CODE_COVERAGE} + swift test --filter='^(PostgresNIOTests|ConnectionPoolModuleTests)' ${CODE_COVERAGE} - name: Submit code coverage if: ${{ matrix.code-coverage }} uses: vapor/swift-codecov-action@v0.2 @@ -139,11 +139,11 @@ jobs: - scram-sha-256 xcode-version: - '~14.3' - - '~15.0' + - '~15' include: - xcode-version: '~14.3' macos-version: 'macos-13' - - xcode-version: '~15.0' + - xcode-version: '~15' macos-version: 'macos-14' runs-on: ${{ matrix.macos-version }} env: @@ -175,7 +175,7 @@ jobs: - name: Checkout code uses: actions/checkout@v4 - name: Run all tests - run: swift test + run: swift test --sanitize=thread api-breakage: if: github.event_name == 'pull_request' From 8f8724e496a8f26c0c13ceaa347647ac7248d6fd Mon Sep 17 00:00:00 2001 From: Gwynne Raskind Date: Tue, 26 Mar 2024 04:55:55 -0500 Subject: [PATCH 072/106] Turn Thread Sanitizer back on in CI (Github-side issue has been fixed) --- .github/workflows/test.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 8c6c3897..7373e17d 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -38,11 +38,11 @@ jobs: swift --version - name: Check out package uses: actions/checkout@v4 - - name: Run unit tests + - name: Run unit tests with Thread Sanitizer env: CODE_COVERAGE: ${{ matrix.code-coverage && '--enable-code-coverage' || '' }} run: | - swift test --filter='^(PostgresNIOTests|ConnectionPoolModuleTests)' ${CODE_COVERAGE} + swift test --filter='^(PostgresNIOTests|ConnectionPoolModuleTests)' --sanitize=thread ${CODE_COVERAGE} - name: Submit code coverage if: ${{ matrix.code-coverage }} uses: vapor/swift-codecov-action@v0.2 @@ -175,7 +175,7 @@ jobs: - name: Checkout code uses: actions/checkout@v4 - name: Run all tests - run: swift test --sanitize=thread + run: swift test api-breakage: if: github.event_name == 'pull_request' From 35587e988316ee42924d7bb72e5cb14735c75470 Mon Sep 17 00:00:00 2001 From: Jay Herron <30518755+NeedleInAJayStack@users.noreply.github.com> Date: Tue, 26 Mar 2024 04:35:26 -0700 Subject: [PATCH 073/106] Fixes `LISTEN` to quote channel name (#466) Co-authored-by: Fabian Fett --- .../New/PostgresChannelHandler.swift | 4 +-- Tests/IntegrationTests/AsyncTests.swift | 29 ++++++++++++------- .../New/PostgresConnectionTests.swift | 10 +++---- 3 files changed, 25 insertions(+), 18 deletions(-) diff --git a/Sources/PostgresNIO/New/PostgresChannelHandler.swift b/Sources/PostgresNIO/New/PostgresChannelHandler.swift index 32dea4a5..53dbd8c9 100644 --- a/Sources/PostgresNIO/New/PostgresChannelHandler.swift +++ b/Sources/PostgresNIO/New/PostgresChannelHandler.swift @@ -594,7 +594,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler { private func makeStartListeningQuery(channel: String, context: ChannelHandlerContext) -> PSQLTask { let promise = context.eventLoop.makePromise(of: PSQLRowStream.self) let query = ExtendedQueryContext( - query: PostgresQuery(unsafeSQL: "LISTEN \(channel);"), + query: PostgresQuery(unsafeSQL: #"LISTEN "\#(channel)";"#), logger: self.logger, promise: promise ) @@ -642,7 +642,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler { private func makeUnlistenQuery(channel: String, context: ChannelHandlerContext) -> PSQLTask { let promise = context.eventLoop.makePromise(of: PSQLRowStream.self) let query = ExtendedQueryContext( - query: PostgresQuery(unsafeSQL: "UNLISTEN \(channel);"), + query: PostgresQuery(unsafeSQL: #"UNLISTEN "\#(channel)";"#), logger: self.logger, promise: promise ) diff --git a/Tests/IntegrationTests/AsyncTests.swift b/Tests/IntegrationTests/AsyncTests.swift index 75e5b6ba..ce6fe027 100644 --- a/Tests/IntegrationTests/AsyncTests.swift +++ b/Tests/IntegrationTests/AsyncTests.swift @@ -225,25 +225,32 @@ final class AsyncPostgresConnectionTests: XCTestCase { } func testListenAndNotify() async throws { + let channelNames = [ + "foo", + "default" + ] + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } let eventLoop = eventLoopGroup.next() - try await self.withTestConnection(on: eventLoop) { connection in - let stream = try await connection.listen("foo") - var iterator = stream.makeAsyncIterator() + for channelName in channelNames { + try await self.withTestConnection(on: eventLoop) { connection in + let stream = try await connection.listen(channelName) + var iterator = stream.makeAsyncIterator() - try await self.withTestConnection(on: eventLoop) { other in - try await other.query(#"NOTIFY foo, 'bar';"#, logger: .psqlTest) + try await self.withTestConnection(on: eventLoop) { other in + try await other.query(#"NOTIFY "\#(unescaped: channelName)", 'bar';"#, logger: .psqlTest) - try await other.query(#"NOTIFY foo, 'foo';"#, logger: .psqlTest) - } + try await other.query(#"NOTIFY "\#(unescaped: channelName)", 'foo';"#, logger: .psqlTest) + } - let first = try await iterator.next() - XCTAssertEqual(first?.payload, "bar") + let first = try await iterator.next() + XCTAssertEqual(first?.payload, "bar") - let second = try await iterator.next() - XCTAssertEqual(second?.payload, "foo") + let second = try await iterator.next() + XCTAssertEqual(second?.payload, "foo") + } } } diff --git a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift index f2cd96f8..fe94633a 100644 --- a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift @@ -51,7 +51,7 @@ class PostgresConnectionTests: XCTestCase { } let listenMessage = try await channel.waitForUnpreparedRequest() - XCTAssertEqual(listenMessage.parse.query, "LISTEN foo;") + XCTAssertEqual(listenMessage.parse.query, #"LISTEN "foo";"#) try await channel.writeInbound(PostgresBackendMessage.parseComplete) try await channel.writeInbound(PostgresBackendMessage.parameterDescription(.init(dataTypes: []))) @@ -63,7 +63,7 @@ class PostgresConnectionTests: XCTestCase { try await channel.writeInbound(PostgresBackendMessage.notification(.init(backendPID: 12, channel: "foo", payload: "wooohooo"))) let unlistenMessage = try await channel.waitForUnpreparedRequest() - XCTAssertEqual(unlistenMessage.parse.query, "UNLISTEN foo;") + XCTAssertEqual(unlistenMessage.parse.query, #"UNLISTEN "foo";"#) try await channel.writeInbound(PostgresBackendMessage.parseComplete) try await channel.writeInbound(PostgresBackendMessage.parameterDescription(.init(dataTypes: []))) @@ -111,7 +111,7 @@ class PostgresConnectionTests: XCTestCase { } let listenMessage = try await channel.waitForUnpreparedRequest() - XCTAssertEqual(listenMessage.parse.query, "LISTEN foo;") + XCTAssertEqual(listenMessage.parse.query, #"LISTEN "foo";"#) try await channel.writeInbound(PostgresBackendMessage.parseComplete) try await channel.writeInbound(PostgresBackendMessage.parameterDescription(.init(dataTypes: []))) @@ -124,7 +124,7 @@ class PostgresConnectionTests: XCTestCase { try await channel.writeInbound(PostgresBackendMessage.notification(.init(backendPID: 12, channel: "foo", payload: "wooohooo2"))) let unlistenMessage = try await channel.waitForUnpreparedRequest() - XCTAssertEqual(unlistenMessage.parse.query, "UNLISTEN foo;") + XCTAssertEqual(unlistenMessage.parse.query, #"UNLISTEN "foo";"#) try await channel.writeInbound(PostgresBackendMessage.parseComplete) try await channel.writeInbound(PostgresBackendMessage.parameterDescription(.init(dataTypes: []))) @@ -160,7 +160,7 @@ class PostgresConnectionTests: XCTestCase { } let listenMessage = try await channel.waitForUnpreparedRequest() - XCTAssertEqual(listenMessage.parse.query, "LISTEN foo;") + XCTAssertEqual(listenMessage.parse.query, #"LISTEN "foo";"#) try await channel.writeInbound(PostgresBackendMessage.parseComplete) try await channel.writeInbound(PostgresBackendMessage.parameterDescription(.init(dataTypes: []))) From e345cbb9cf6052b37b27c0c4f976134fc01dbe15 Mon Sep 17 00:00:00 2001 From: Jia-Han Wu Date: Tue, 26 Mar 2024 19:38:41 +0800 Subject: [PATCH 074/106] Fix broken link in README.md (#467) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 9e7d4e3b..b6cecc2d 100644 --- a/README.md +++ b/README.md @@ -43,7 +43,7 @@ detailed look at all of the classes, structs, protocols, and more. ## Getting started -Interested in an example? We prepared a simple [Birthday example](/vapor/postgres-nio/tree/main/Snippets/Birthdays.swift) +Interested in an example? We prepared a simple [Birthday example](https://github.com/vapor/postgres-nio/blob/main/Snippets/Birthdays.swift) in the Snippets folder. #### Adding the dependency From ee5d5e159c9892df957e06ac9f1f357502270487 Mon Sep 17 00:00:00 2001 From: Franz Busch Date: Wed, 1 May 2024 09:23:50 +0100 Subject: [PATCH 075/106] Make `TLS.disable` a let instead of a var (#471) This currently emits a Sendable warning since a global var isn't Sendable safe. --- Sources/PostgresNIO/Pool/PostgresClient.swift | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Sources/PostgresNIO/Pool/PostgresClient.swift b/Sources/PostgresNIO/Pool/PostgresClient.swift index 9383ffcd..2116a51d 100644 --- a/Sources/PostgresNIO/Pool/PostgresClient.swift +++ b/Sources/PostgresNIO/Pool/PostgresClient.swift @@ -47,7 +47,7 @@ public final class PostgresClient: Sendable, ServiceLifecycle.Service { } /// Do not try to create a TLS connection to the server. - public static var disable: Self = Self.init(.disable) + public static let disable: Self = Self.init(.disable) /// Try to create a TLS connection to the server. If the server supports TLS, create a TLS connection. /// If the server does not support TLS, create an insecure connection. From a48eebc4f9c83de18e608f5a096769427e1177b9 Mon Sep 17 00:00:00 2001 From: Mahdi Bahrami Date: Fri, 10 May 2024 01:33:20 +0330 Subject: [PATCH 076/106] Actually use additional connection parameters (#473) --- .../New/PostgresChannelHandler.swift | 3 +- Tests/IntegrationTests/AsyncTests.swift | 33 ++++++++++++++- Tests/IntegrationTests/Utilities.swift | 9 ++-- .../New/PostgresConnectionTests.swift | 42 +++++++++++++++++++ 4 files changed, 82 insertions(+), 5 deletions(-) diff --git a/Sources/PostgresNIO/New/PostgresChannelHandler.swift b/Sources/PostgresNIO/New/PostgresChannelHandler.swift index 53dbd8c9..a3190aa7 100644 --- a/Sources/PostgresNIO/New/PostgresChannelHandler.swift +++ b/Sources/PostgresNIO/New/PostgresChannelHandler.swift @@ -390,7 +390,8 @@ final class PostgresChannelHandler: ChannelDuplexHandler { let authContext = AuthContext( username: username, password: self.configuration.password, - database: self.configuration.database + database: self.configuration.database, + additionalParameters: self.configuration.options.additionalStartupParameters ) let action = self.state.provideAuthenticationContext(authContext) return self.run(action, with: context) diff --git a/Tests/IntegrationTests/AsyncTests.swift b/Tests/IntegrationTests/AsyncTests.swift index ce6fe027..513157fd 100644 --- a/Tests/IntegrationTests/AsyncTests.swift +++ b/Tests/IntegrationTests/AsyncTests.swift @@ -84,6 +84,36 @@ final class AsyncPostgresConnectionTests: XCTestCase { } } + func testAdditionalParametersTakeEffect() async throws { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + let eventLoop = eventLoopGroup.next() + + let query: PostgresQuery = """ + SELECT + current_setting('application_name'); + """ + + let applicationName = "postgres-nio-test" + var options = PostgresConnection.Configuration.Options() + options.additionalStartupParameters = [ + ("application_name", applicationName) + ] + + try await withTestConnection(on: eventLoop, options: options) { connection in + let rows = try await connection.query(query, logger: .psqlTest) + var counter = 0 + + for try await element in rows.decode(String.self) { + XCTAssertEqual(element, applicationName) + + counter += 1 + } + + XCTAssertGreaterThanOrEqual(counter, 1) + } + } + func testSelectTimeoutWhileLongRunningQuery() async throws { let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } @@ -452,11 +482,12 @@ extension XCTestCase { func withTestConnection( on eventLoop: EventLoop, + options: PostgresConnection.Configuration.Options? = nil, file: StaticString = #filePath, line: UInt = #line, _ closure: (PostgresConnection) async throws -> Result ) async throws -> Result { - let connection = try await PostgresConnection.test(on: eventLoop).get() + let connection = try await PostgresConnection.test(on: eventLoop, options: options).get() do { let result = try await closure(connection) diff --git a/Tests/IntegrationTests/Utilities.swift b/Tests/IntegrationTests/Utilities.swift index 001d9ee4..91dbb62e 100644 --- a/Tests/IntegrationTests/Utilities.swift +++ b/Tests/IntegrationTests/Utilities.swift @@ -24,9 +24,9 @@ extension PostgresConnection { } } - static func test(on eventLoop: EventLoop) -> EventLoopFuture { + static func test(on eventLoop: EventLoop, options: Configuration.Options? = nil) -> EventLoopFuture { let logger = Logger(label: "postgres.connection.test") - let config = PostgresConnection.Configuration( + var config = PostgresConnection.Configuration( host: env("POSTGRES_HOSTNAME") ?? "localhost", port: env("POSTGRES_PORT").flatMap(Int.init(_:)) ?? 5432, username: env("POSTGRES_USER") ?? "test_username", @@ -34,7 +34,10 @@ extension PostgresConnection { database: env("POSTGRES_DB") ?? "test_database", tls: .disable ) - + if let options { + config.options = options + } + return PostgresConnection.connect(on: eventLoop, configuration: config, id: 0, logger: logger) } diff --git a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift index fe94633a..34528f7e 100644 --- a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift @@ -38,6 +38,48 @@ class PostgresConnectionTests: XCTestCase { } } + func testOptionsAreSentOnTheWire() async throws { + let eventLoop = NIOAsyncTestingEventLoop() + let channel = await NIOAsyncTestingChannel(handlers: [ + ReverseByteToMessageHandler(PSQLFrontendMessageDecoder()), + ReverseMessageToByteHandler(PSQLBackendMessageEncoder()), + ], loop: eventLoop) + try await channel.connect(to: .makeAddressResolvingHost("localhost", port: 5432)) + + let configuration = { + var config = PostgresConnection.Configuration( + establishedChannel: channel, + username: "username", + password: "postgres", + database: "database" + ) + config.options.additionalStartupParameters = [ + ("DateStyle", "ISO, MDY"), + ("application_name", "postgres-nio-test"), + ("server_encoding", "UTF8"), + ("integer_datetimes", "on"), + ("client_encoding", "UTF8"), + ("TimeZone", "Etc/UTC"), + ("is_superuser", "on"), + ("server_version", "13.1 (Debian 13.1-1.pgdg100+1)"), + ("session_authorization", "postgres"), + ("IntervalStyle", "postgres"), + ("standard_conforming_strings", "on") + ] + return config + }() + + async let connectionPromise = PostgresConnection.connect(on: eventLoop, configuration: configuration, id: 1, logger: .psqlTest) + let message = try await channel.waitForOutboundWrite(as: PostgresFrontendMessage.self) + XCTAssertEqual(message, .startup(.versionThree(parameters: .init(user: "username", database: "database", options: configuration.options.additionalStartupParameters, replication: .false)))) + try await channel.writeInbound(PostgresBackendMessage.authentication(.ok)) + try await channel.writeInbound(PostgresBackendMessage.backendKeyData(.init(processID: 1234, secretKey: 5678))) + try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) + + let connection = try await connectionPromise + try await connection.close() + } + func testSimpleListen() async throws { let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() From e62cc88d244a075e0263b33edb54ef793cd5a1f8 Mon Sep 17 00:00:00 2001 From: Gwynne Raskind Date: Tue, 28 May 2024 01:22:08 -0500 Subject: [PATCH 077/106] [CI] Update code coverage action, attempt fix for Homebrew nonsense (#476) --- .github/workflows/test.yml | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 7373e17d..808718fb 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -22,6 +22,7 @@ jobs: - swift:5.8-jammy - swift:5.9-jammy - swift:5.10-jammy + - swiftlang/swift:nightly-6.0-jammy - swiftlang/swift:nightly-main-jammy include: - swift-image: swift:5.10-jammy @@ -45,7 +46,9 @@ jobs: swift test --filter='^(PostgresNIOTests|ConnectionPoolModuleTests)' --sanitize=thread ${CODE_COVERAGE} - name: Submit code coverage if: ${{ matrix.code-coverage }} - uses: vapor/swift-codecov-action@v0.2 + uses: vapor/swift-codecov-action@v0.3 + with: + codecov_token: ${{ secrets.CODECOV_TOKEN }} linux-integration-and-dependencies: strategy: @@ -165,7 +168,7 @@ jobs: # ** BEGIN ** Work around bug in both Homebrew and GHA (brew upgrade python@3.11 || true) && (brew link --force --overwrite python@3.11 || true) (brew upgrade python@3.12 || true) && (brew link --force --overwrite python@3.12 || true) - brew upgrade + (brew upgrade || true) # ** END ** Work around bug in both Homebrew and GHA brew install --overwrite "${POSTGRES_FORMULA}" brew link --overwrite --force "${POSTGRES_FORMULA}" From d3795844d488210b65ace34c5f003e47d812d999 Mon Sep 17 00:00:00 2001 From: Johannes Weiss Date: Wed, 29 May 2024 15:48:59 +0100 Subject: [PATCH 078/106] Workaround DiscardingTaskGroup non-conformance with nightly compilers (#478) --- .../ConnectionPoolModule/ConnectionPool.swift | 20 +++++++++++++------ .../ConnectionPoolTests.swift | 14 ++++++------- 2 files changed, 21 insertions(+), 13 deletions(-) diff --git a/Sources/ConnectionPoolModule/ConnectionPool.swift b/Sources/ConnectionPoolModule/ConnectionPool.swift index 9f25e82c..8ba0e7be 100644 --- a/Sources/ConnectionPoolModule/ConnectionPool.swift +++ b/Sources/ConnectionPoolModule/ConnectionPool.swift @@ -419,7 +419,7 @@ public final class ConnectionPool< @inlinable /*private*/ func makeConnection(for request: StateMachine.ConnectionRequest, in taskGroup: inout some TaskGroupProtocol) { - taskGroup.addTask { + taskGroup.addTask_ { self.observabilityDelegate.startedConnecting(id: request.connectionID) do { @@ -468,7 +468,7 @@ public final class ConnectionPool< /*private*/ func runKeepAlive(_ connection: Connection, in taskGroup: inout some TaskGroupProtocol) { self.observabilityDelegate.keepAliveTriggered(id: connection.id) - taskGroup.addTask { + taskGroup.addTask_ { do { try await self.keepAliveBehavior.runKeepAlive(for: connection) @@ -503,7 +503,7 @@ public final class ConnectionPool< @inlinable /*private*/ func runTimer(_ timer: StateMachine.Timer, in poolGroup: inout some TaskGroupProtocol) { - poolGroup.addTask { () async -> () in + poolGroup.addTask_ { () async -> () in await withTaskGroup(of: TimerRunResult.self, returning: Void.self) { taskGroup in taskGroup.addTask { do { @@ -587,17 +587,25 @@ extension AsyncStream { @usableFromInline protocol TaskGroupProtocol { - mutating func addTask(operation: @escaping @Sendable () async -> Void) + // We need to call this `addTask_` because some Swift versions define this + // under exactly this name and others have different attributes. So let's pick + // a name that doesn't clash anywhere and implement it using the standard `addTask`. + mutating func addTask_(operation: @escaping @Sendable () async -> Void) } #if swift(>=5.8) && os(Linux) || swift(>=5.9) @available(macOS 14.0, iOS 17.0, tvOS 17.0, watchOS 10.0, *) -extension DiscardingTaskGroup: TaskGroupProtocol {} +extension DiscardingTaskGroup: TaskGroupProtocol { + @inlinable + mutating func addTask_(operation: @escaping @Sendable () async -> Void) { + self.addTask(priority: nil, operation: operation) + } +} #endif extension TaskGroup: TaskGroupProtocol { @inlinable - mutating func addTask(operation: @escaping @Sendable () async -> Void) { + mutating func addTask_(operation: @escaping @Sendable () async -> Void) { self.addTask(priority: nil, operation: operation) } } diff --git a/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift b/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift index 3e3c9d65..3c0e7a6b 100644 --- a/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift +++ b/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift @@ -26,7 +26,7 @@ final class ConnectionPoolTests: XCTestCase { // the same connection is reused 1000 times await withTaskGroup(of: Void.self) { taskGroup in - taskGroup.addTask { + taskGroup.addTask_ { await pool.run() } @@ -82,14 +82,14 @@ final class ConnectionPoolTests: XCTestCase { } await withTaskGroup(of: Void.self) { taskGroup in - taskGroup.addTask { + taskGroup.addTask_ { await pool.run() } let (blockCancelStream, blockCancelContinuation) = AsyncStream.makeStream(of: Void.self) let (blockConnCreationStream, blockConnCreationContinuation) = AsyncStream.makeStream(of: Void.self) - taskGroup.addTask { + taskGroup.addTask_ { _ = try? await factory.nextConnectAttempt { _ in blockCancelContinuation.yield() var iterator = blockConnCreationStream.makeAsyncIterator() @@ -127,7 +127,7 @@ final class ConnectionPoolTests: XCTestCase { } await withTaskGroup(of: Void.self) { taskGroup in - taskGroup.addTask { + taskGroup.addTask_ { await pool.run() } @@ -170,12 +170,12 @@ final class ConnectionPoolTests: XCTestCase { // the same connection is reused 1000 times await withTaskGroup(of: Void.self) { taskGroup in - taskGroup.addTask { + taskGroup.addTask_ { await pool.run() XCTAssertFalse(hasFinished.compareExchange(expected: false, desired: true, ordering: .relaxed).original) } - taskGroup.addTask { + taskGroup.addTask_ { var usedConnectionIDs = Set() for _ in 0.. Date: Thu, 30 May 2024 13:31:54 +0200 Subject: [PATCH 079/106] Fix crash when recreating minimal connections (#480) --- .../PoolStateMachine+ConnectionGroup.swift | 2 +- Tests/ConnectionPoolModuleTests/PoolStateMachineTests.swift | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionGroup.swift b/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionGroup.swift index 833365fa..f26f244d 100644 --- a/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionGroup.swift +++ b/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionGroup.swift @@ -592,7 +592,7 @@ extension PoolStateMachine { let newConnectionRequest: ConnectionRequest? if self.connections.count < self.minimumConcurrentConnections { - newConnectionRequest = .init(connectionID: self.generator.next()) + newConnectionRequest = self.createNewConnection() } else { newConnectionRequest = .none } diff --git a/Tests/ConnectionPoolModuleTests/PoolStateMachineTests.swift b/Tests/ConnectionPoolModuleTests/PoolStateMachineTests.swift index f5ada14f..2f3ae617 100644 --- a/Tests/ConnectionPoolModuleTests/PoolStateMachineTests.swift +++ b/Tests/ConnectionPoolModuleTests/PoolStateMachineTests.swift @@ -375,6 +375,10 @@ final class PoolStateMachineTests: XCTestCase { let connectionClosed = stateMachine.connectionClosed(connection) XCTAssertEqual(connectionClosed.connection, .makeConnection(.init(connectionID: 1), [])) connection.closeIfClosing() + let establishAction = stateMachine.connectionEstablished(.init(id: 1), maxStreams: 1) + XCTAssertEqual(establishAction.request, .none) + guard case .scheduleTimers(let timers) = establishAction.connection else { return XCTFail("Unexpected connection action") } + XCTAssertEqual(timers, [.init(.init(timerID: 0, connectionID: 1, usecase: .keepAlive), duration: configuration.keepAliveDuration!)]) } } From 5c268768890b062803a49f1358becc478f954265 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Thu, 13 Jun 2024 18:54:18 +0200 Subject: [PATCH 080/106] Fix totally unnecessary `preconditionFailure` in `PSQLEventsHandler` (#481) --- Sources/PostgresNIO/New/PSQLEventsHandler.swift | 4 +--- Tests/PostgresNIOTests/New/PostgresConnectionTests.swift | 9 +++++++++ 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/Sources/PostgresNIO/New/PSQLEventsHandler.swift b/Sources/PostgresNIO/New/PSQLEventsHandler.swift index 2bf0d6d8..0f426f20 100644 --- a/Sources/PostgresNIO/New/PSQLEventsHandler.swift +++ b/Sources/PostgresNIO/New/PSQLEventsHandler.swift @@ -68,10 +68,8 @@ final class PSQLEventsHandler: ChannelInboundHandler { case .authenticated: break } - case TLSUserEvent.shutdownCompleted: - break default: - preconditionFailure() + context.fireUserInboundEventTriggered(event) } } diff --git a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift index 34528f7e..209522dd 100644 --- a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift @@ -416,6 +416,15 @@ class PostgresConnectionTests: XCTestCase { } } + func testWeDontCrashOnUnexpectedChannelEvents() async throws { + let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() + + enum MyEvent { + case pleaseDontCrash + } + channel.pipeline.fireUserInboundEventTriggered(MyEvent.pleaseDontCrash) + } + func testSerialExecutionOfSamePreparedStatement() async throws { let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() From e7b9a08a11c0a4eedafb8032f13cfa764ae45b13 Mon Sep 17 00:00:00 2001 From: Gwynne Raskind Date: Thu, 13 Jun 2024 12:18:17 -0500 Subject: [PATCH 081/106] [CI] Use Ubuntu 24.04 image, more code coverage, disable CodeQL completely (#482) * [CI] Use Ubuntu 24.04 image for Swift 5.10, upload code coverage more often, completely disable CodeQL * Add CODEOWNERS --- .github/CODEOWNERS | 1 + .github/workflows/test.yml | 50 +++++++++++++++++--------------------- 2 files changed, 23 insertions(+), 28 deletions(-) create mode 100644 .github/CODEOWNERS diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS new file mode 100644 index 00000000..6413432f --- /dev/null +++ b/.github/CODEOWNERS @@ -0,0 +1 @@ +* @fabianfett @gwynne diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 808718fb..f74427c3 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -21,12 +21,9 @@ jobs: - swift:5.7-jammy - swift:5.8-jammy - swift:5.9-jammy - - swift:5.10-jammy + - swift:5.10-noble - swiftlang/swift:nightly-6.0-jammy - swiftlang/swift:nightly-main-jammy - include: - - swift-image: swift:5.10-jammy - code-coverage: true container: ${{ matrix.swift-image }} runs-on: ubuntu-latest steps: @@ -40,12 +37,9 @@ jobs: - name: Check out package uses: actions/checkout@v4 - name: Run unit tests with Thread Sanitizer - env: - CODE_COVERAGE: ${{ matrix.code-coverage && '--enable-code-coverage' || '' }} run: | - swift test --filter='^(PostgresNIOTests|ConnectionPoolModuleTests)' --sanitize=thread ${CODE_COVERAGE} + swift test --filter='^(PostgresNIOTests|ConnectionPoolModuleTests)' --sanitize=thread --enable-code-coverage - name: Submit code coverage - if: ${{ matrix.code-coverage }} uses: vapor/swift-codecov-action@v0.3 with: codecov_token: ${{ secrets.CODECOV_TOKEN }} @@ -66,7 +60,7 @@ jobs: - postgres-image: postgres:12 postgres-auth: trust container: - image: swift:5.10-jammy + image: swift:5.10-noble volumes: [ 'pgrunshare:/var/run/postgresql' ] runs-on: ubuntu-latest env: @@ -183,7 +177,7 @@ jobs: api-breakage: if: github.event_name == 'pull_request' runs-on: ubuntu-latest - container: swift:jammy + container: swift:noble steps: - name: Checkout uses: actions/checkout@v4 @@ -195,21 +189,21 @@ jobs: git config --global --add safe.directory "${GITHUB_WORKSPACE}" swift package diagnose-api-breaking-changes origin/main - gh-codeql: - if: ${{ false }} - runs-on: ubuntu-latest - container: swift:jammy - permissions: { actions: write, contents: read, security-events: write } - steps: - - name: Check out code - uses: actions/checkout@v4 - - name: Mark repo safe in non-fake global config - run: git config --global --add safe.directory "${GITHUB_WORKSPACE}" - - name: Initialize CodeQL - uses: github/codeql-action/init@v3 - with: - languages: swift - - name: Perform build - run: swift build - - name: Run CodeQL analyze - uses: github/codeql-action/analyze@v3 +# gh-codeql: +# if: ${{ false }} +# runs-on: ubuntu-latest +# container: swift:noble +# permissions: { actions: write, contents: read, security-events: write } +# steps: +# - name: Check out code +# uses: actions/checkout@v4 +# - name: Mark repo safe in non-fake global config +# run: git config --global --add safe.directory "${GITHUB_WORKSPACE}" +# - name: Initialize CodeQL +# uses: github/codeql-action/init@v3 +# with: +# languages: swift +# - name: Perform build +# run: swift build +# - name: Run CodeQL analyze +# uses: github/codeql-action/analyze@v3 From 5f541d05970a4fad5accb54365191f1f8e91ea3e Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Fri, 14 Jun 2024 11:49:36 +0200 Subject: [PATCH 082/106] Drop support for Swift 5.7 (#485) --- .github/workflows/test.yml | 1 - Package.swift | 2 +- README.md | 4 ++-- Sources/ConnectionPoolModule/ConnectionPool.swift | 8 ++++---- Sources/PostgresNIO/New/NotificationListener.swift | 2 +- .../New/PostgresNotificationSequence.swift | 7 +------ Sources/PostgresNIO/Utilities/Exports.swift | 11 ----------- 7 files changed, 9 insertions(+), 26 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index f74427c3..1761880d 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -18,7 +18,6 @@ jobs: fail-fast: false matrix: swift-image: - - swift:5.7-jammy - swift:5.8-jammy - swift:5.9-jammy - swift:5.10-noble diff --git a/Package.swift b/Package.swift index 4d008371..79c740f9 100644 --- a/Package.swift +++ b/Package.swift @@ -1,4 +1,4 @@ -// swift-tools-version:5.7 +// swift-tools-version:5.8 import PackageDescription let package = Package( diff --git a/README.md b/README.md index b6cecc2d..bc56953b 100644 --- a/README.md +++ b/README.md @@ -16,7 +16,7 @@ Continuous Integration - Swift 5.7+ + Swift 5.8+ SSWG Incubation Level: Graduated @@ -167,7 +167,7 @@ Please see [SECURITY.md] for details on the security process. [Team Chat]: https://discord.gg/vapor [MIT License]: LICENSE [Continuous Integration]: https://github.com/vapor/postgres-nio/actions -[Swift 5.7]: https://swift.org +[Swift 5.8]: https://swift.org [Security.md]: https://github.com/vapor/.github/blob/main/SECURITY.md [`PostgresConnection`]: https://api.vapor.codes/postgresnio/documentation/postgresnio/postgresconnection diff --git a/Sources/ConnectionPoolModule/ConnectionPool.swift b/Sources/ConnectionPoolModule/ConnectionPool.swift index 8ba0e7be..3231cc06 100644 --- a/Sources/ConnectionPoolModule/ConnectionPool.swift +++ b/Sources/ConnectionPoolModule/ConnectionPool.swift @@ -273,7 +273,7 @@ public final class ConnectionPool< public func run() async { await withTaskCancellationHandler { - #if swift(>=5.8) && os(Linux) || swift(>=5.9) + #if os(Linux) || compiler(>=5.9) if #available(macOS 14.0, iOS 17.0, tvOS 17.0, watchOS 10.0, *) { return await withDiscardingTaskGroup() { taskGroup in await self.run(in: &taskGroup) @@ -313,7 +313,7 @@ public final class ConnectionPool< case scheduleTimer(StateMachine.Timer) } - #if swift(>=5.8) && os(Linux) || swift(>=5.9) + #if os(Linux) || compiler(>=5.9) @available(macOS 14.0, iOS 17.0, tvOS 17.0, watchOS 10.0, *) private func run(in taskGroup: inout DiscardingTaskGroup) async { for await event in self.eventStream { @@ -507,7 +507,7 @@ public final class ConnectionPool< await withTaskGroup(of: TimerRunResult.self, returning: Void.self) { taskGroup in taskGroup.addTask { do { - #if swift(>=5.8) && os(Linux) || swift(>=5.9) + #if os(Linux) || compiler(>=5.9) try await self.clock.sleep(for: timer.duration) #else try await self.clock.sleep(until: self.clock.now.advanced(by: timer.duration), tolerance: nil) @@ -593,7 +593,7 @@ protocol TaskGroupProtocol { mutating func addTask_(operation: @escaping @Sendable () async -> Void) } -#if swift(>=5.8) && os(Linux) || swift(>=5.9) +#if os(Linux) || swift(>=5.9) @available(macOS 14.0, iOS 17.0, tvOS 17.0, watchOS 10.0, *) extension DiscardingTaskGroup: TaskGroupProtocol { @inlinable diff --git a/Sources/PostgresNIO/New/NotificationListener.swift b/Sources/PostgresNIO/New/NotificationListener.swift index 9e47ff34..4982b8ad 100644 --- a/Sources/PostgresNIO/New/NotificationListener.swift +++ b/Sources/PostgresNIO/New/NotificationListener.swift @@ -142,7 +142,7 @@ final class NotificationListener: @unchecked Sendable { } -#if swift(<5.9) +#if compiler(<5.9) // Async stream API backfill extension AsyncThrowingStream { static func makeStream( diff --git a/Sources/PostgresNIO/New/PostgresNotificationSequence.swift b/Sources/PostgresNIO/New/PostgresNotificationSequence.swift index 55fb0670..735c01b0 100644 --- a/Sources/PostgresNIO/New/PostgresNotificationSequence.swift +++ b/Sources/PostgresNIO/New/PostgresNotificationSequence.swift @@ -3,7 +3,7 @@ public struct PostgresNotification: Sendable { public let payload: String } -public struct PostgresNotificationSequence: AsyncSequence { +public struct PostgresNotificationSequence: AsyncSequence, Sendable { public typealias Element = PostgresNotification let base: AsyncThrowingStream @@ -20,8 +20,3 @@ public struct PostgresNotificationSequence: AsyncSequence { } } } - -#if swift(>=5.7) -// AsyncThrowingStream is marked as Sendable in Swift 5.6 -extension PostgresNotificationSequence: Sendable {} -#endif diff --git a/Sources/PostgresNIO/Utilities/Exports.swift b/Sources/PostgresNIO/Utilities/Exports.swift index 58e12891..144ff3c9 100644 --- a/Sources/PostgresNIO/Utilities/Exports.swift +++ b/Sources/PostgresNIO/Utilities/Exports.swift @@ -1,14 +1,3 @@ -#if swift(>=5.8) - @_documentation(visibility: internal) @_exported import NIO @_documentation(visibility: internal) @_exported import NIOSSL @_documentation(visibility: internal) @_exported import struct Logging.Logger - -#else - -// TODO: Remove this with the next major release! -@_exported import NIO -@_exported import NIOSSL -@_exported import struct Logging.Logger - -#endif From 6c3d0a938d248965da42d451f619cf74f0fff882 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Fri, 14 Jun 2024 11:53:39 +0200 Subject: [PATCH 083/106] Update ServiceLifecycle to 2.5.0 (#484) --- Package.swift | 2 +- Sources/PostgresNIO/Pool/PostgresClient.swift | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Package.swift b/Package.swift index 79c740f9..d24ee979 100644 --- a/Package.swift +++ b/Package.swift @@ -22,7 +22,7 @@ let package = Package( .package(url: "https://github.com/apple/swift-crypto.git", "2.0.0" ..< "4.0.0"), .package(url: "https://github.com/apple/swift-metrics.git", from: "2.4.1"), .package(url: "https://github.com/apple/swift-log.git", from: "1.5.3"), - .package(url: "https://github.com/swift-server/swift-service-lifecycle.git", from: "2.4.1"), + .package(url: "https://github.com/swift-server/swift-service-lifecycle.git", from: "2.5.0"), ], targets: [ .target( diff --git a/Sources/PostgresNIO/Pool/PostgresClient.swift b/Sources/PostgresNIO/Pool/PostgresClient.swift index 2116a51d..2e1b7e11 100644 --- a/Sources/PostgresNIO/Pool/PostgresClient.swift +++ b/Sources/PostgresNIO/Pool/PostgresClient.swift @@ -419,7 +419,7 @@ public final class PostgresClient: Sendable, ServiceLifecycle.Service { let atomicOp = self.runningAtomic.compareExchange(expected: false, desired: true, ordering: .relaxed) precondition(!atomicOp.original, "PostgresClient.run() should just be called once!") - await cancelOnGracefulShutdown { + await cancelWhenGracefulShutdown { await self.pool.run() } } From 7b621c16f6a0a8a0af8badd56b6f980457a1b7c5 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Fri, 14 Jun 2024 13:32:56 +0200 Subject: [PATCH 084/106] Enable StrictConcurrency checking (#483) --- Package.swift | 19 ++++-- .../ConnectionPoolModule/ConnectionPool.swift | 4 +- .../ConnectionPoolObservabilityDelegate.swift | 2 +- .../Message/PostgresMessage+Identifier.swift | 2 +- Sources/PostgresNIO/Pool/PostgresClient.swift | 2 +- .../Utilities/PostgresError+Code.swift | 2 +- .../Mocks/MockConnectionFactory.swift | 2 +- Tests/IntegrationTests/PostgresNIOTests.swift | 61 ++++++++++--------- .../New/PostgresConnectionTests.swift | 19 +++--- 9 files changed, 63 insertions(+), 50 deletions(-) diff --git a/Package.swift b/Package.swift index d24ee979..0683dbe9 100644 --- a/Package.swift +++ b/Package.swift @@ -1,6 +1,10 @@ // swift-tools-version:5.8 import PackageDescription +let swiftSettings: [SwiftSetting] = [ + .enableUpcomingFeature("StrictConcurrency") +] + let package = Package( name: "postgres-nio", platforms: [ @@ -41,7 +45,8 @@ let package = Package( .product(name: "NIOSSL", package: "swift-nio-ssl"), .product(name: "NIOFoundationCompat", package: "swift-nio"), .product(name: "ServiceLifecycle", package: "swift-service-lifecycle"), - ] + ], + swiftSettings: swiftSettings ), .target( name: "_ConnectionPoolModule", @@ -49,7 +54,8 @@ let package = Package( .product(name: "Atomics", package: "swift-atomics"), .product(name: "DequeModule", package: "swift-collections"), ], - path: "Sources/ConnectionPoolModule" + path: "Sources/ConnectionPoolModule", + swiftSettings: swiftSettings ), .testTarget( name: "PostgresNIOTests", @@ -57,7 +63,8 @@ let package = Package( .target(name: "PostgresNIO"), .product(name: "NIOEmbedded", package: "swift-nio"), .product(name: "NIOTestUtils", package: "swift-nio"), - ] + ], + swiftSettings: swiftSettings ), .testTarget( name: "ConnectionPoolModuleTests", @@ -67,14 +74,16 @@ let package = Package( .product(name: "NIOCore", package: "swift-nio"), .product(name: "NIOConcurrencyHelpers", package: "swift-nio"), .product(name: "NIOEmbedded", package: "swift-nio"), - ] + ], + swiftSettings: swiftSettings ), .testTarget( name: "IntegrationTests", dependencies: [ .target(name: "PostgresNIO"), .product(name: "NIOTestUtils", package: "swift-nio"), - ] + ], + swiftSettings: swiftSettings ), ] ) diff --git a/Sources/ConnectionPoolModule/ConnectionPool.swift b/Sources/ConnectionPoolModule/ConnectionPool.swift index 3231cc06..03c269ee 100644 --- a/Sources/ConnectionPoolModule/ConnectionPool.swift +++ b/Sources/ConnectionPoolModule/ConnectionPool.swift @@ -1,6 +1,6 @@ @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) -public struct ConnectionAndMetadata { +public struct ConnectionAndMetadata: Sendable { public var connection: Connection @@ -495,7 +495,7 @@ public final class ConnectionPool< } @usableFromInline - enum TimerRunResult { + enum TimerRunResult: Sendable { case timerTriggered case timerCancelled case cancellationContinuationFinished diff --git a/Sources/ConnectionPoolModule/ConnectionPoolObservabilityDelegate.swift b/Sources/ConnectionPoolModule/ConnectionPoolObservabilityDelegate.swift index 35f30dcb..fc1e300c 100644 --- a/Sources/ConnectionPoolModule/ConnectionPoolObservabilityDelegate.swift +++ b/Sources/ConnectionPoolModule/ConnectionPoolObservabilityDelegate.swift @@ -37,7 +37,7 @@ public protocol ConnectionPoolObservabilityDelegate: Sendable { func requestQueueDepthChanged(_ newDepth: Int) } -public struct NoOpConnectionPoolMetrics: ConnectionPoolObservabilityDelegate { +public struct NoOpConnectionPoolMetrics: ConnectionPoolObservabilityDelegate { public init(connectionIDType: ConnectionID.Type) {} public func startedConnecting(id: ConnectionID) {} diff --git a/Sources/PostgresNIO/Message/PostgresMessage+Identifier.swift b/Sources/PostgresNIO/Message/PostgresMessage+Identifier.swift index 786b91ef..5d111e3b 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+Identifier.swift +++ b/Sources/PostgresNIO/Message/PostgresMessage+Identifier.swift @@ -4,7 +4,7 @@ extension PostgresMessage { /// Identifies an incoming or outgoing postgres message. Sent as the first byte, before the message size. /// Values are not unique across all identifiers, meaning some messages will require keeping state to identify. @available(*, deprecated, message: "Will be removed from public API.") - public struct Identifier: ExpressibleByIntegerLiteral, Equatable, CustomStringConvertible { + public struct Identifier: Sendable, ExpressibleByIntegerLiteral, Equatable, CustomStringConvertible { // special public static let none: Identifier = 0x00 // special diff --git a/Sources/PostgresNIO/Pool/PostgresClient.swift b/Sources/PostgresNIO/Pool/PostgresClient.swift index 2e1b7e11..0907f1f8 100644 --- a/Sources/PostgresNIO/Pool/PostgresClient.swift +++ b/Sources/PostgresNIO/Pool/PostgresClient.swift @@ -478,7 +478,7 @@ extension PostgresConnection: PooledConnection { self.channel.close(mode: .all, promise: nil) } - public func onClose(_ closure: @escaping ((any Error)?) -> ()) { + public func onClose(_ closure: @escaping @Sendable ((any Error)?) -> ()) { self.closeFuture.whenComplete { _ in closure(nil) } } } diff --git a/Sources/PostgresNIO/Utilities/PostgresError+Code.swift b/Sources/PostgresNIO/Utilities/PostgresError+Code.swift index 11224f4b..fae903fe 100644 --- a/Sources/PostgresNIO/Utilities/PostgresError+Code.swift +++ b/Sources/PostgresNIO/Utilities/PostgresError+Code.swift @@ -1,5 +1,5 @@ extension PostgresError { - public struct Code: ExpressibleByStringLiteral, Equatable { + public struct Code: Sendable, ExpressibleByStringLiteral, Equatable { // Class 00 — Successful Completion public static let successfulCompletion: Code = "00000" diff --git a/Tests/ConnectionPoolModuleTests/Mocks/MockConnectionFactory.swift b/Tests/ConnectionPoolModuleTests/Mocks/MockConnectionFactory.swift index eec2e7c3..1c9bfff8 100644 --- a/Tests/ConnectionPoolModuleTests/Mocks/MockConnectionFactory.swift +++ b/Tests/ConnectionPoolModuleTests/Mocks/MockConnectionFactory.swift @@ -2,7 +2,7 @@ import DequeModule @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) -final class MockConnectionFactory where Clock.Duration == Duration { +final class MockConnectionFactory: Sendable where Clock.Duration == Duration { typealias ConnectionIDGenerator = _ConnectionPoolModule.ConnectionIDGenerator typealias Request = ConnectionRequest typealias KeepAliveBehavior = MockPingPongBehavior diff --git a/Tests/IntegrationTests/PostgresNIOTests.swift b/Tests/IntegrationTests/PostgresNIOTests.swift index 4d06c13e..ff59209b 100644 --- a/Tests/IntegrationTests/PostgresNIOTests.swift +++ b/Tests/IntegrationTests/PostgresNIOTests.swift @@ -1,5 +1,6 @@ import Logging @testable import PostgresNIO +import Atomics import XCTest import NIOCore import NIOPosix @@ -112,59 +113,59 @@ final class PostgresNIOTests: XCTestCase { XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) defer { XCTAssertNoThrow( try conn?.close().wait() ) } - var receivedNotifications: [PostgresMessage.NotificationResponse] = [] + let receivedNotifications = ManagedAtomic(0) conn?.addListener(channel: "example") { context, notification in - receivedNotifications.append(notification) + receivedNotifications.wrappingIncrement(ordering: .relaxed) + XCTAssertEqual(notification.channel, "example") + XCTAssertEqual(notification.payload, "") } XCTAssertNoThrow(_ = try conn?.simpleQuery("LISTEN example").wait()) XCTAssertNoThrow(_ = try conn?.simpleQuery("NOTIFY example").wait()) // Notifications are asynchronous, so we should run at least one more query to make sure we'll have received the notification response by then XCTAssertNoThrow(_ = try conn?.simpleQuery("SELECT 1").wait()) - XCTAssertEqual(receivedNotifications.count, 1) - XCTAssertEqual(receivedNotifications.first?.channel, "example") - XCTAssertEqual(receivedNotifications.first?.payload, "") + XCTAssertEqual(receivedNotifications.load(ordering: .relaxed), 1) } func testNotificationsNonEmptyPayload() { var conn: PostgresConnection? XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) defer { XCTAssertNoThrow( try conn?.close().wait() ) } - var receivedNotifications: [PostgresMessage.NotificationResponse] = [] + let receivedNotifications = ManagedAtomic(0) conn?.addListener(channel: "example") { context, notification in - receivedNotifications.append(notification) + receivedNotifications.wrappingIncrement(ordering: .relaxed) + XCTAssertEqual(notification.channel, "example") + XCTAssertEqual(notification.payload, "Notification payload example") } XCTAssertNoThrow(_ = try conn?.simpleQuery("LISTEN example").wait()) XCTAssertNoThrow(_ = try conn?.simpleQuery("NOTIFY example, 'Notification payload example'").wait()) // Notifications are asynchronous, so we should run at least one more query to make sure we'll have received the notification response by then XCTAssertNoThrow(_ = try conn?.simpleQuery("SELECT 1").wait()) - XCTAssertEqual(receivedNotifications.count, 1) - XCTAssertEqual(receivedNotifications.first?.channel, "example") - XCTAssertEqual(receivedNotifications.first?.payload, "Notification payload example") + XCTAssertEqual(receivedNotifications.load(ordering: .relaxed), 1) } func testNotificationsRemoveHandlerWithinHandler() { var conn: PostgresConnection? XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) defer { XCTAssertNoThrow( try conn?.close().wait() ) } - var receivedNotifications = 0 + let receivedNotifications = ManagedAtomic(0) conn?.addListener(channel: "example") { context, notification in - receivedNotifications += 1 + receivedNotifications.wrappingIncrement(ordering: .relaxed) context.stop() } XCTAssertNoThrow(_ = try conn?.simpleQuery("LISTEN example").wait()) XCTAssertNoThrow(_ = try conn?.simpleQuery("NOTIFY example").wait()) XCTAssertNoThrow(_ = try conn?.simpleQuery("NOTIFY example").wait()) XCTAssertNoThrow(_ = try conn?.simpleQuery("SELECT 1").wait()) - XCTAssertEqual(receivedNotifications, 1) + XCTAssertEqual(receivedNotifications.load(ordering: .relaxed), 1) } func testNotificationsRemoveHandlerOutsideHandler() { var conn: PostgresConnection? XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) defer { XCTAssertNoThrow( try conn?.close().wait() ) } - var receivedNotifications = 0 + let receivedNotifications = ManagedAtomic(0) let context = conn?.addListener(channel: "example") { context, notification in - receivedNotifications += 1 + receivedNotifications.wrappingIncrement(ordering: .relaxed) } XCTAssertNotNil(context) XCTAssertNoThrow(_ = try conn?.simpleQuery("LISTEN example").wait()) @@ -173,47 +174,47 @@ final class PostgresNIOTests: XCTestCase { context?.stop() XCTAssertNoThrow(_ = try conn?.simpleQuery("NOTIFY example").wait()) XCTAssertNoThrow(_ = try conn?.simpleQuery("SELECT 1").wait()) - XCTAssertEqual(receivedNotifications, 1) + XCTAssertEqual(receivedNotifications.load(ordering: .relaxed), 1) } func testNotificationsMultipleRegisteredHandlers() { var conn: PostgresConnection? XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) defer { XCTAssertNoThrow( try conn?.close().wait() ) } - var receivedNotifications1 = 0 + let receivedNotifications1 = ManagedAtomic(0) conn?.addListener(channel: "example") { context, notification in - receivedNotifications1 += 1 + receivedNotifications1.wrappingIncrement(ordering: .relaxed) } - var receivedNotifications2 = 0 + let receivedNotifications2 = ManagedAtomic(0) conn?.addListener(channel: "example") { context, notification in - receivedNotifications2 += 1 + receivedNotifications2.wrappingIncrement(ordering: .relaxed) } XCTAssertNoThrow(_ = try conn?.simpleQuery("LISTEN example").wait()) XCTAssertNoThrow(_ = try conn?.simpleQuery("NOTIFY example").wait()) XCTAssertNoThrow(_ = try conn?.simpleQuery("SELECT 1").wait()) - XCTAssertEqual(receivedNotifications1, 1) - XCTAssertEqual(receivedNotifications2, 1) + XCTAssertEqual(receivedNotifications1.load(ordering: .relaxed), 1) + XCTAssertEqual(receivedNotifications2.load(ordering: .relaxed), 1) } func testNotificationsMultipleRegisteredHandlersRemoval() throws { var conn: PostgresConnection? XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) defer { XCTAssertNoThrow( try conn?.close().wait() ) } - var receivedNotifications1 = 0 + let receivedNotifications1 = ManagedAtomic(0) XCTAssertNotNil(conn?.addListener(channel: "example") { context, notification in - receivedNotifications1 += 1 + receivedNotifications1.wrappingIncrement(ordering: .relaxed) context.stop() }) - var receivedNotifications2 = 0 + let receivedNotifications2 = ManagedAtomic(0) XCTAssertNotNil(conn?.addListener(channel: "example") { context, notification in - receivedNotifications2 += 1 + receivedNotifications2.wrappingIncrement(ordering: .relaxed) }) XCTAssertNoThrow(_ = try conn?.simpleQuery("LISTEN example").wait()) XCTAssertNoThrow(_ = try conn?.simpleQuery("NOTIFY example").wait()) XCTAssertNoThrow(_ = try conn?.simpleQuery("NOTIFY example").wait()) XCTAssertNoThrow(_ = try conn?.simpleQuery("SELECT 1").wait()) - XCTAssertEqual(receivedNotifications1, 1) - XCTAssertEqual(receivedNotifications2, 2) + XCTAssertEqual(receivedNotifications1.load(ordering: .relaxed), 1) + XCTAssertEqual(receivedNotifications2.load(ordering: .relaxed), 2) } func testNotificationHandlerFiltersOnChannel() { @@ -1283,11 +1284,11 @@ final class PostgresNIOTests: XCTestCase { XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) defer { XCTAssertNoThrow( try conn?.close().wait() ) } var queries: [[PostgresRow]]? - XCTAssertNoThrow(queries = try conn?.prepare(query: "SELECT $1::text as foo;", handler: { query in + XCTAssertNoThrow(queries = try conn?.prepare(query: "SELECT $1::text as foo;", handler: { [eventLoop] query in let a = query.execute(["a"]) let b = query.execute(["b"]) let c = query.execute(["c"]) - return EventLoopFuture.whenAllSucceed([a, b, c], on: self.eventLoop) + return EventLoopFuture.whenAllSucceed([a, b, c], on: eventLoop) }).wait()) XCTAssertEqual(queries?.count, 3) var resultIterator = queries?.makeIterator() diff --git a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift index 209522dd..0bc61efd 100644 --- a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift @@ -187,7 +187,7 @@ class PostgresConnectionTests: XCTestCase { func testSimpleListenConnectionDrops() async throws { let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() - try await withThrowingTaskGroup(of: Void.self) { taskGroup in + try await withThrowingTaskGroup(of: Void.self) { [logger] taskGroup in taskGroup.addTask { let events = try await connection.listen("foo") var iterator = events.makeAsyncIterator() @@ -197,7 +197,7 @@ class PostgresConnectionTests: XCTestCase { _ = try await iterator.next() XCTFail("Did not expect to not throw") } catch { - self.logger.error("error", metadata: ["error": "\(error)"]) + logger.error("error", metadata: ["error": "\(error)"]) } } @@ -226,10 +226,10 @@ class PostgresConnectionTests: XCTestCase { func testCloseGracefullyClosesWhenInternalQueueIsEmpty() async throws { let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() - try await withThrowingTaskGroup(of: Void.self) { taskGroup async throws -> () in + try await withThrowingTaskGroup(of: Void.self) { [logger] taskGroup async throws -> () in for _ in 1...2 { taskGroup.addTask { - let rows = try await connection.query("SELECT 1;", logger: self.logger) + let rows = try await connection.query("SELECT 1;", logger: logger) var iterator = rows.decode(Int.self).makeAsyncIterator() let first = try await iterator.next() XCTAssertEqual(first, 1) @@ -286,10 +286,10 @@ class PostgresConnectionTests: XCTestCase { func testCloseClosesImmediatly() async throws { let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() - try await withThrowingTaskGroup(of: Void.self) { taskGroup async throws -> () in + try await withThrowingTaskGroup(of: Void.self) { [logger] taskGroup async throws -> () in for _ in 1...2 { taskGroup.addTask { - try await connection.query("SELECT 1;", logger: self.logger) + try await connection.query("SELECT 1;", logger: logger) } } @@ -319,8 +319,9 @@ class PostgresConnectionTests: XCTestCase { func testIfServerJustClosesTheErrorReflectsThat() async throws { let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() + let logger = self.logger - async let response = try await connection.query("SELECT 1;", logger: self.logger) + async let response = try await connection.query("SELECT 1;", logger: logger) let listenMessage = try await channel.waitForUnpreparedRequest() XCTAssertEqual(listenMessage.parse.query, "SELECT 1;") @@ -423,6 +424,7 @@ class PostgresConnectionTests: XCTestCase { case pleaseDontCrash } channel.pipeline.fireUserInboundEventTriggered(MyEvent.pleaseDontCrash) + try await connection.close() } func testSerialExecutionOfSamePreparedStatement() async throws { @@ -651,7 +653,8 @@ class PostgresConnectionTests: XCTestCase { database: "database" ) - async let connectionPromise = PostgresConnection.connect(on: eventLoop, configuration: configuration, id: 1, logger: self.logger) + let logger = self.logger + async let connectionPromise = PostgresConnection.connect(on: eventLoop, configuration: configuration, id: 1, logger: logger) let message = try await channel.waitForOutboundWrite(as: PostgresFrontendMessage.self) XCTAssertEqual(message, .startup(.versionThree(parameters: .init(user: "username", database: "database", options: [], replication: .false)))) try await channel.writeInbound(PostgresBackendMessage.authentication(.ok)) From f55caa7745a43357f7af7dfdd0300955dbd8c6a3 Mon Sep 17 00:00:00 2001 From: Mahdi Bahrami Date: Mon, 24 Jun 2024 15:15:22 +0330 Subject: [PATCH 085/106] [Fix] Query Hangs if Connection is Closed (#487) --- .../Connection/PostgresConnection.swift | 39 ++-- .../PSQLIntegrationTests.swift | 1 - .../New/PostgresConnectionTests.swift | 169 ++++++++++++++++++ 3 files changed, 197 insertions(+), 12 deletions(-) diff --git a/Sources/PostgresNIO/Connection/PostgresConnection.swift b/Sources/PostgresNIO/Connection/PostgresConnection.swift index eb9dc791..a6efcfdf 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection.swift @@ -222,7 +222,7 @@ public final class PostgresConnection: @unchecked Sendable { promise: promise ) - self.channel.write(HandlerTask.extendedQuery(context), promise: nil) + self.write(.extendedQuery(context), cascadingFailureTo: promise) return promise.futureResult } @@ -239,7 +239,8 @@ public final class PostgresConnection: @unchecked Sendable { promise: promise ) - self.channel.write(HandlerTask.extendedQuery(context), promise: nil) + self.write(.extendedQuery(context), cascadingFailureTo: promise) + return promise.futureResult.map { rowDescription in PSQLPreparedStatement(name: name, query: query, connection: self, rowDescription: rowDescription) } @@ -255,7 +256,8 @@ public final class PostgresConnection: @unchecked Sendable { logger: logger, promise: promise) - self.channel.write(HandlerTask.extendedQuery(context), promise: nil) + self.write(.extendedQuery(context), cascadingFailureTo: promise) + return promise.futureResult } @@ -263,7 +265,8 @@ public final class PostgresConnection: @unchecked Sendable { let promise = self.channel.eventLoop.makePromise(of: Void.self) let context = CloseCommandContext(target: target, logger: logger, promise: promise) - self.channel.write(HandlerTask.closeCommand(context), promise: nil) + self.write(.closeCommand(context), cascadingFailureTo: promise) + return promise.futureResult } @@ -426,7 +429,7 @@ extension PostgresConnection { promise: promise ) - self.channel.write(HandlerTask.extendedQuery(context), promise: nil) + self.write(.extendedQuery(context), cascadingFailureTo: promise) do { return try await promise.futureResult.map({ $0.asyncSequence() }).get() @@ -455,7 +458,11 @@ extension PostgresConnection { let task = HandlerTask.startListening(listener) - self.channel.write(task, promise: nil) + let writePromise = self.channel.eventLoop.makePromise(of: Void.self) + self.channel.write(task, promise: writePromise) + writePromise.futureResult.whenFailure { error in + listener.failed(error) + } } } onCancel: { let task = HandlerTask.cancelListening(channel, id) @@ -480,7 +487,9 @@ extension PostgresConnection { logger: logger, promise: promise )) - self.channel.write(task, promise: nil) + + self.write(task, cascadingFailureTo: promise) + do { return try await promise.futureResult .map { $0.asyncSequence() } @@ -515,7 +524,9 @@ extension PostgresConnection { logger: logger, promise: promise )) - self.channel.write(task, promise: nil) + + self.write(task, cascadingFailureTo: promise) + do { return try await promise.futureResult .map { $0.commandTag } @@ -530,6 +541,12 @@ extension PostgresConnection { throw error // rethrow with more metadata } } + + private func write(_ task: HandlerTask, cascadingFailureTo promise: EventLoopPromise) { + let writePromise = self.channel.eventLoop.makePromise(of: Void.self) + self.channel.write(task, promise: writePromise) + writePromise.futureResult.cascadeFailure(to: promise) + } } // MARK: EventLoopFuture interface @@ -674,7 +691,7 @@ internal enum PostgresCommands: PostgresRequest { /// Context for receiving NotificationResponse messages on a connection, used for PostgreSQL's `LISTEN`/`NOTIFY` support. public final class PostgresListenContext: Sendable { - private let promise: EventLoopPromise + let promise: EventLoopPromise var future: EventLoopFuture { self.promise.futureResult @@ -713,8 +730,7 @@ extension PostgresConnection { closure: notificationHandler ) - let task = HandlerTask.startListening(listener) - self.channel.write(task, promise: nil) + self.write(.startListening(listener), cascadingFailureTo: listenContext.promise) listenContext.future.whenComplete { _ in let task = HandlerTask.cancelListening(channel, id) @@ -761,3 +777,4 @@ extension PostgresConnection { #endif } } + diff --git a/Tests/IntegrationTests/PSQLIntegrationTests.swift b/Tests/IntegrationTests/PSQLIntegrationTests.swift index 57939c06..913d91b2 100644 --- a/Tests/IntegrationTests/PSQLIntegrationTests.swift +++ b/Tests/IntegrationTests/PSQLIntegrationTests.swift @@ -359,5 +359,4 @@ final class IntegrationTests: XCTestCase { XCTAssertEqual(obj?.bar, 2) } } - } diff --git a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift index 0bc61efd..5c7d4c83 100644 --- a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift @@ -224,6 +224,63 @@ class PostgresConnectionTests: XCTestCase { } } + func testSimpleListenFailsIfConnectionIsClosed() async throws { + let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() + + try await connection.closeGracefully() + + XCTAssertEqual(channel.isActive, false) + + do { + _ = try await connection.listen("test_channel") + XCTFail("Expected to fail") + } catch let error as ChannelError { + XCTAssertEqual(error, .ioOnClosedChannel) + } + } + + func testSimpleListenFailsIfConnectionIsClosedWhileListening() async throws { + let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() + + try await withThrowingTaskGroup(of: Void.self) { taskGroup in + taskGroup.addTask { + let events = try await connection.listen("foo") + var iterator = events.makeAsyncIterator() + let first = try await iterator.next() + XCTAssertEqual(first?.payload, "wooohooo") + do { + _ = try await iterator.next() + XCTFail("Did not expect to not throw") + } catch let error as PSQLError { + XCTAssertEqual(error.code, .clientClosedConnection) + } + } + + let listenMessage = try await channel.waitForUnpreparedRequest() + XCTAssertEqual(listenMessage.parse.query, #"LISTEN "foo";"#) + + try await channel.writeInbound(PostgresBackendMessage.parseComplete) + try await channel.writeInbound(PostgresBackendMessage.parameterDescription(.init(dataTypes: []))) + try await channel.writeInbound(PostgresBackendMessage.noData) + try await channel.writeInbound(PostgresBackendMessage.bindComplete) + try await channel.writeInbound(PostgresBackendMessage.commandComplete("LISTEN")) + try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) + + try await channel.writeInbound(PostgresBackendMessage.notification(.init(backendPID: 12, channel: "foo", payload: "wooohooo"))) + + try await connection.close() + + XCTAssertEqual(channel.isActive, false) + + switch await taskGroup.nextResult()! { + case .success: + break + case .failure(let failure): + XCTFail("Unexpected error: \(failure)") + } + } + } + func testCloseGracefullyClosesWhenInternalQueueIsEmpty() async throws { let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() try await withThrowingTaskGroup(of: Void.self) { [logger] taskGroup async throws -> () in @@ -638,6 +695,118 @@ class PostgresConnectionTests: XCTestCase { } } + func testQueryFailsIfConnectionIsClosed() async throws { + let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() + + try await connection.closeGracefully() + + XCTAssertEqual(channel.isActive, false) + + do { + _ = try await connection.query("SELECT version;", logger: self.logger) + XCTFail("Expected to fail") + } catch let error as ChannelError { + XCTAssertEqual(error, .ioOnClosedChannel) + } + } + + func testPrepareStatementFailsIfConnectionIsClosed() async throws { + let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() + + try await connection.closeGracefully() + + XCTAssertEqual(channel.isActive, false) + + do { + _ = try await connection.prepareStatement("SELECT version;", with: "test_query", logger: .psqlTest).get() + XCTFail("Expected to fail") + } catch let error as ChannelError { + XCTAssertEqual(error, .ioOnClosedChannel) + } + } + + func testExecuteFailsIfConnectionIsClosed() async throws { + let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() + + try await connection.closeGracefully() + + XCTAssertEqual(channel.isActive, false) + + do { + let statement = PSQLExecuteStatement(name: "SELECT version;", binds: .init(), rowDescription: nil) + _ = try await connection.execute(statement, logger: .psqlTest).get() + XCTFail("Expected to fail") + } catch let error as ChannelError { + XCTAssertEqual(error, .ioOnClosedChannel) + } + } + + func testExecutePreparedStatementFailsIfConnectionIsClosed() async throws { + let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() + + try await connection.closeGracefully() + + XCTAssertEqual(channel.isActive, false) + + struct TestPreparedStatement: PostgresPreparedStatement { + static let sql = "SELECT pid, datname FROM pg_stat_activity WHERE state = $1" + typealias Row = (Int, String) + + var state: String + + func makeBindings() -> PostgresBindings { + var bindings = PostgresBindings() + bindings.append(self.state) + return bindings + } + + func decodeRow(_ row: PostgresNIO.PostgresRow) throws -> Row { + try row.decode(Row.self) + } + } + + do { + let preparedStatement = TestPreparedStatement(state: "active") + _ = try await connection.execute(preparedStatement, logger: .psqlTest) + XCTFail("Expected to fail") + } catch let error as ChannelError { + XCTAssertEqual(error, .ioOnClosedChannel) + } + } + + func testExecutePreparedStatementWithVoidRowFailsIfConnectionIsClosed() async throws { + let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() + + try await connection.closeGracefully() + + XCTAssertEqual(channel.isActive, false) + + struct TestPreparedStatement: PostgresPreparedStatement { + static let sql = "SELECT * FROM pg_stat_activity WHERE state = $1" + typealias Row = () + + var state: String + + func makeBindings() -> PostgresBindings { + var bindings = PostgresBindings() + bindings.append(self.state) + return bindings + } + + func decodeRow(_ row: PostgresNIO.PostgresRow) throws -> Row { + () + } + } + + do { + let preparedStatement = TestPreparedStatement(state: "active") + _ = try await connection.execute(preparedStatement, logger: .psqlTest) + XCTFail("Expected to fail") + } catch let error as ChannelError { + XCTAssertEqual(error, .ioOnClosedChannel) + } + } + func makeTestConnectionWithAsyncTestingChannel() async throws -> (PostgresConnection, NIOAsyncTestingChannel) { let eventLoop = NIOAsyncTestingEventLoop() let channel = await NIOAsyncTestingChannel(handlers: [ From 200a94a13381f2cbc2c4f5303da777997a80937d Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Thu, 27 Jun 2024 17:59:03 +0200 Subject: [PATCH 086/106] Explicitly mark the AsyncSequence iterators as non Sendable (#490) --- Package.swift | 2 +- Sources/PostgresNIO/New/PostgresNotificationSequence.swift | 3 +++ Sources/PostgresNIO/New/PostgresQuery.swift | 2 +- Sources/PostgresNIO/New/PostgresRowSequence.swift | 5 ++++- 4 files changed, 9 insertions(+), 3 deletions(-) diff --git a/Package.swift b/Package.swift index 0683dbe9..5c83eded 100644 --- a/Package.swift +++ b/Package.swift @@ -2,7 +2,7 @@ import PackageDescription let swiftSettings: [SwiftSetting] = [ - .enableUpcomingFeature("StrictConcurrency") + .enableUpcomingFeature("StrictConcurrency"), ] let package = Package( diff --git a/Sources/PostgresNIO/New/PostgresNotificationSequence.swift b/Sources/PostgresNIO/New/PostgresNotificationSequence.swift index 735c01b0..d8f525eb 100644 --- a/Sources/PostgresNIO/New/PostgresNotificationSequence.swift +++ b/Sources/PostgresNIO/New/PostgresNotificationSequence.swift @@ -20,3 +20,6 @@ public struct PostgresNotificationSequence: AsyncSequence, Sendable { } } } + +@available(*, unavailable) +extension PostgresNotificationSequence.AsyncIterator: Sendable {} diff --git a/Sources/PostgresNIO/New/PostgresQuery.swift b/Sources/PostgresNIO/New/PostgresQuery.swift index 1cfcf2dc..b695dcfe 100644 --- a/Sources/PostgresNIO/New/PostgresQuery.swift +++ b/Sources/PostgresNIO/New/PostgresQuery.swift @@ -26,7 +26,7 @@ extension PostgresQuery: ExpressibleByStringInterpolation { } extension PostgresQuery { - public struct StringInterpolation: StringInterpolationProtocol { + public struct StringInterpolation: StringInterpolationProtocol, Sendable { public typealias StringLiteralType = String @usableFromInline diff --git a/Sources/PostgresNIO/New/PostgresRowSequence.swift b/Sources/PostgresNIO/New/PostgresRowSequence.swift index ccf4f69c..3936b51e 100644 --- a/Sources/PostgresNIO/New/PostgresRowSequence.swift +++ b/Sources/PostgresNIO/New/PostgresRowSequence.swift @@ -4,7 +4,7 @@ import NIOConcurrencyHelpers /// An async sequence of ``PostgresRow``s. /// /// - Note: This is a struct to allow us to move to a move only type easily once they become available. -public struct PostgresRowSequence: AsyncSequence { +public struct PostgresRowSequence: AsyncSequence, Sendable { public typealias Element = PostgresRow typealias BackingSequence = NIOThrowingAsyncSequenceProducer @@ -56,6 +56,9 @@ extension PostgresRowSequence { } } +@available(*, unavailable) +extension PostgresRowSequence.AsyncIterator: Sendable {} + extension PostgresRowSequence { public func collect() async throws -> [PostgresRow] { var result = [PostgresRow]() From d18b137640222fe29a22568077c4799d213fdf96 Mon Sep 17 00:00:00 2001 From: George Barnett Date: Thu, 25 Jul 2024 09:56:51 +0100 Subject: [PATCH 087/106] Change 'unsafeDowncast' to 'as!' (#495) Motivation: The 'unsafeDowncast' can cause a miscompile leading to unexpected runtime behaviour. Modifications: - Use 'as!' instead Result: No miscompiles on 5.10 --- Sources/ConnectionPoolModule/NIOLock.swift | 29 +++++++++++----------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/Sources/ConnectionPoolModule/NIOLock.swift b/Sources/ConnectionPoolModule/NIOLock.swift index dbc7dbe9..13a9df4a 100644 --- a/Sources/ConnectionPoolModule/NIOLock.swift +++ b/Sources/ConnectionPoolModule/NIOLock.swift @@ -52,12 +52,12 @@ extension LockOperations { debugOnly { pthread_mutexattr_settype(&attr, .init(PTHREAD_MUTEX_ERRORCHECK)) } - + let err = pthread_mutex_init(mutex, &attr) precondition(err == 0, "\(#function) failed in pthread_mutex with error \(err)") #endif } - + @inlinable static func destroy(_ mutex: UnsafeMutablePointer) { mutex.assertValidAlignment() @@ -69,7 +69,7 @@ extension LockOperations { precondition(err == 0, "\(#function) failed in pthread_mutex with error \(err)") #endif } - + @inlinable static func lock(_ mutex: UnsafeMutablePointer) { mutex.assertValidAlignment() @@ -81,7 +81,7 @@ extension LockOperations { precondition(err == 0, "\(#function) failed in pthread_mutex with error \(err)") #endif } - + @inlinable static func unlock(_ mutex: UnsafeMutablePointer) { mutex.assertValidAlignment() @@ -125,49 +125,50 @@ extension LockOperations { // See also: https://github.com/apple/swift/pull/40000 @usableFromInline final class LockStorage: ManagedBuffer { - + @inlinable static func create(value: Value) -> Self { let buffer = Self.create(minimumCapacity: 1) { _ in return value } - let storage = unsafeDowncast(buffer, to: Self.self) - + // Avoid 'unsafeDowncast' as there is a miscompilation on 5.10. + let storage = buffer as! Self + storage.withUnsafeMutablePointers { _, lockPtr in LockOperations.create(lockPtr) } - + return storage } - + @inlinable func lock() { self.withUnsafeMutablePointerToElements { lockPtr in LockOperations.lock(lockPtr) } } - + @inlinable func unlock() { self.withUnsafeMutablePointerToElements { lockPtr in LockOperations.unlock(lockPtr) } } - + @inlinable deinit { self.withUnsafeMutablePointerToElements { lockPtr in LockOperations.destroy(lockPtr) } } - + @inlinable func withLockPrimitive(_ body: (UnsafeMutablePointer) throws -> T) rethrows -> T { try self.withUnsafeMutablePointerToElements { lockPtr in return try body(lockPtr) } } - + @inlinable func withLockedValue(_ mutate: (inout Value) throws -> T) rethrows -> T { try self.withUnsafeMutablePointers { valuePtr, lockPtr in @@ -192,7 +193,7 @@ extension LockStorage: @unchecked Sendable { } struct NIOLock { @usableFromInline internal let _storage: LockStorage - + /// Create a new lock. @inlinable init() { From cd5318a01a1efcb1e0b3c82a0ce5c9fefaf1cb2d Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Tue, 20 Aug 2024 14:04:28 +0200 Subject: [PATCH 088/106] Revert "[Fix] Query Hangs if Connection is Closed (#487)" (#501) This reverts commit f55caa7745a43357f7af7dfdd0300955dbd8c6a3. --- .../Connection/PostgresConnection.swift | 39 ++-- .../PSQLIntegrationTests.swift | 1 + .../New/PostgresConnectionTests.swift | 169 ------------------ 3 files changed, 12 insertions(+), 197 deletions(-) diff --git a/Sources/PostgresNIO/Connection/PostgresConnection.swift b/Sources/PostgresNIO/Connection/PostgresConnection.swift index a6efcfdf..eb9dc791 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection.swift @@ -222,7 +222,7 @@ public final class PostgresConnection: @unchecked Sendable { promise: promise ) - self.write(.extendedQuery(context), cascadingFailureTo: promise) + self.channel.write(HandlerTask.extendedQuery(context), promise: nil) return promise.futureResult } @@ -239,8 +239,7 @@ public final class PostgresConnection: @unchecked Sendable { promise: promise ) - self.write(.extendedQuery(context), cascadingFailureTo: promise) - + self.channel.write(HandlerTask.extendedQuery(context), promise: nil) return promise.futureResult.map { rowDescription in PSQLPreparedStatement(name: name, query: query, connection: self, rowDescription: rowDescription) } @@ -256,8 +255,7 @@ public final class PostgresConnection: @unchecked Sendable { logger: logger, promise: promise) - self.write(.extendedQuery(context), cascadingFailureTo: promise) - + self.channel.write(HandlerTask.extendedQuery(context), promise: nil) return promise.futureResult } @@ -265,8 +263,7 @@ public final class PostgresConnection: @unchecked Sendable { let promise = self.channel.eventLoop.makePromise(of: Void.self) let context = CloseCommandContext(target: target, logger: logger, promise: promise) - self.write(.closeCommand(context), cascadingFailureTo: promise) - + self.channel.write(HandlerTask.closeCommand(context), promise: nil) return promise.futureResult } @@ -429,7 +426,7 @@ extension PostgresConnection { promise: promise ) - self.write(.extendedQuery(context), cascadingFailureTo: promise) + self.channel.write(HandlerTask.extendedQuery(context), promise: nil) do { return try await promise.futureResult.map({ $0.asyncSequence() }).get() @@ -458,11 +455,7 @@ extension PostgresConnection { let task = HandlerTask.startListening(listener) - let writePromise = self.channel.eventLoop.makePromise(of: Void.self) - self.channel.write(task, promise: writePromise) - writePromise.futureResult.whenFailure { error in - listener.failed(error) - } + self.channel.write(task, promise: nil) } } onCancel: { let task = HandlerTask.cancelListening(channel, id) @@ -487,9 +480,7 @@ extension PostgresConnection { logger: logger, promise: promise )) - - self.write(task, cascadingFailureTo: promise) - + self.channel.write(task, promise: nil) do { return try await promise.futureResult .map { $0.asyncSequence() } @@ -524,9 +515,7 @@ extension PostgresConnection { logger: logger, promise: promise )) - - self.write(task, cascadingFailureTo: promise) - + self.channel.write(task, promise: nil) do { return try await promise.futureResult .map { $0.commandTag } @@ -541,12 +530,6 @@ extension PostgresConnection { throw error // rethrow with more metadata } } - - private func write(_ task: HandlerTask, cascadingFailureTo promise: EventLoopPromise) { - let writePromise = self.channel.eventLoop.makePromise(of: Void.self) - self.channel.write(task, promise: writePromise) - writePromise.futureResult.cascadeFailure(to: promise) - } } // MARK: EventLoopFuture interface @@ -691,7 +674,7 @@ internal enum PostgresCommands: PostgresRequest { /// Context for receiving NotificationResponse messages on a connection, used for PostgreSQL's `LISTEN`/`NOTIFY` support. public final class PostgresListenContext: Sendable { - let promise: EventLoopPromise + private let promise: EventLoopPromise var future: EventLoopFuture { self.promise.futureResult @@ -730,7 +713,8 @@ extension PostgresConnection { closure: notificationHandler ) - self.write(.startListening(listener), cascadingFailureTo: listenContext.promise) + let task = HandlerTask.startListening(listener) + self.channel.write(task, promise: nil) listenContext.future.whenComplete { _ in let task = HandlerTask.cancelListening(channel, id) @@ -777,4 +761,3 @@ extension PostgresConnection { #endif } } - diff --git a/Tests/IntegrationTests/PSQLIntegrationTests.swift b/Tests/IntegrationTests/PSQLIntegrationTests.swift index 913d91b2..57939c06 100644 --- a/Tests/IntegrationTests/PSQLIntegrationTests.swift +++ b/Tests/IntegrationTests/PSQLIntegrationTests.swift @@ -359,4 +359,5 @@ final class IntegrationTests: XCTestCase { XCTAssertEqual(obj?.bar, 2) } } + } diff --git a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift index 5c7d4c83..0bc61efd 100644 --- a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift @@ -224,63 +224,6 @@ class PostgresConnectionTests: XCTestCase { } } - func testSimpleListenFailsIfConnectionIsClosed() async throws { - let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() - - try await connection.closeGracefully() - - XCTAssertEqual(channel.isActive, false) - - do { - _ = try await connection.listen("test_channel") - XCTFail("Expected to fail") - } catch let error as ChannelError { - XCTAssertEqual(error, .ioOnClosedChannel) - } - } - - func testSimpleListenFailsIfConnectionIsClosedWhileListening() async throws { - let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() - - try await withThrowingTaskGroup(of: Void.self) { taskGroup in - taskGroup.addTask { - let events = try await connection.listen("foo") - var iterator = events.makeAsyncIterator() - let first = try await iterator.next() - XCTAssertEqual(first?.payload, "wooohooo") - do { - _ = try await iterator.next() - XCTFail("Did not expect to not throw") - } catch let error as PSQLError { - XCTAssertEqual(error.code, .clientClosedConnection) - } - } - - let listenMessage = try await channel.waitForUnpreparedRequest() - XCTAssertEqual(listenMessage.parse.query, #"LISTEN "foo";"#) - - try await channel.writeInbound(PostgresBackendMessage.parseComplete) - try await channel.writeInbound(PostgresBackendMessage.parameterDescription(.init(dataTypes: []))) - try await channel.writeInbound(PostgresBackendMessage.noData) - try await channel.writeInbound(PostgresBackendMessage.bindComplete) - try await channel.writeInbound(PostgresBackendMessage.commandComplete("LISTEN")) - try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) - - try await channel.writeInbound(PostgresBackendMessage.notification(.init(backendPID: 12, channel: "foo", payload: "wooohooo"))) - - try await connection.close() - - XCTAssertEqual(channel.isActive, false) - - switch await taskGroup.nextResult()! { - case .success: - break - case .failure(let failure): - XCTFail("Unexpected error: \(failure)") - } - } - } - func testCloseGracefullyClosesWhenInternalQueueIsEmpty() async throws { let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() try await withThrowingTaskGroup(of: Void.self) { [logger] taskGroup async throws -> () in @@ -695,118 +638,6 @@ class PostgresConnectionTests: XCTestCase { } } - func testQueryFailsIfConnectionIsClosed() async throws { - let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() - - try await connection.closeGracefully() - - XCTAssertEqual(channel.isActive, false) - - do { - _ = try await connection.query("SELECT version;", logger: self.logger) - XCTFail("Expected to fail") - } catch let error as ChannelError { - XCTAssertEqual(error, .ioOnClosedChannel) - } - } - - func testPrepareStatementFailsIfConnectionIsClosed() async throws { - let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() - - try await connection.closeGracefully() - - XCTAssertEqual(channel.isActive, false) - - do { - _ = try await connection.prepareStatement("SELECT version;", with: "test_query", logger: .psqlTest).get() - XCTFail("Expected to fail") - } catch let error as ChannelError { - XCTAssertEqual(error, .ioOnClosedChannel) - } - } - - func testExecuteFailsIfConnectionIsClosed() async throws { - let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() - - try await connection.closeGracefully() - - XCTAssertEqual(channel.isActive, false) - - do { - let statement = PSQLExecuteStatement(name: "SELECT version;", binds: .init(), rowDescription: nil) - _ = try await connection.execute(statement, logger: .psqlTest).get() - XCTFail("Expected to fail") - } catch let error as ChannelError { - XCTAssertEqual(error, .ioOnClosedChannel) - } - } - - func testExecutePreparedStatementFailsIfConnectionIsClosed() async throws { - let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() - - try await connection.closeGracefully() - - XCTAssertEqual(channel.isActive, false) - - struct TestPreparedStatement: PostgresPreparedStatement { - static let sql = "SELECT pid, datname FROM pg_stat_activity WHERE state = $1" - typealias Row = (Int, String) - - var state: String - - func makeBindings() -> PostgresBindings { - var bindings = PostgresBindings() - bindings.append(self.state) - return bindings - } - - func decodeRow(_ row: PostgresNIO.PostgresRow) throws -> Row { - try row.decode(Row.self) - } - } - - do { - let preparedStatement = TestPreparedStatement(state: "active") - _ = try await connection.execute(preparedStatement, logger: .psqlTest) - XCTFail("Expected to fail") - } catch let error as ChannelError { - XCTAssertEqual(error, .ioOnClosedChannel) - } - } - - func testExecutePreparedStatementWithVoidRowFailsIfConnectionIsClosed() async throws { - let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() - - try await connection.closeGracefully() - - XCTAssertEqual(channel.isActive, false) - - struct TestPreparedStatement: PostgresPreparedStatement { - static let sql = "SELECT * FROM pg_stat_activity WHERE state = $1" - typealias Row = () - - var state: String - - func makeBindings() -> PostgresBindings { - var bindings = PostgresBindings() - bindings.append(self.state) - return bindings - } - - func decodeRow(_ row: PostgresNIO.PostgresRow) throws -> Row { - () - } - } - - do { - let preparedStatement = TestPreparedStatement(state: "active") - _ = try await connection.execute(preparedStatement, logger: .psqlTest) - XCTFail("Expected to fail") - } catch let error as ChannelError { - XCTAssertEqual(error, .ioOnClosedChannel) - } - } - func makeTestConnectionWithAsyncTestingChannel() async throws -> (PostgresConnection, NIOAsyncTestingChannel) { let eventLoop = NIOAsyncTestingEventLoop() let channel = await NIOAsyncTestingChannel(handlers: [ From 3de37e6438d018159a9c3ef1ea0ca154039ce480 Mon Sep 17 00:00:00 2001 From: Mahdi Bahrami Date: Wed, 21 Aug 2024 16:15:31 +0330 Subject: [PATCH 089/106] Handle `EmptyQueryResponse` (#500) --- .../ExtendedQueryStateMachine.swift | 36 ++++++++-- Sources/PostgresNIO/New/PSQLRowStream.swift | 66 +++++++++++-------- .../New/PostgresChannelHandler.swift | 4 +- .../PostgresNIO/PostgresDatabase+Query.swift | 5 +- .../PSQLIntegrationTests.swift | 19 ++++++ .../ExtendedQueryStateMachineTests.swift | 22 ++++++- .../PreparedStatementStateMachineTests.swift | 8 +-- .../New/PSQLRowStreamTests.swift | 2 +- 8 files changed, 114 insertions(+), 48 deletions(-) diff --git a/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift index 78f0d202..087a6c24 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift @@ -10,7 +10,8 @@ struct ExtendedQueryStateMachine { case parameterDescriptionReceived(ExtendedQueryContext) case rowDescriptionReceived(ExtendedQueryContext, [RowDescription.Column]) case noDataMessageReceived(ExtendedQueryContext) - + case emptyQueryResponseReceived + /// A state that is used if a noData message was received before. If a row description was received `bufferingRows` is /// used after receiving a `bindComplete` message case bindCompleteReceived(ExtendedQueryContext) @@ -122,7 +123,7 @@ struct ExtendedQueryStateMachine { return .forwardStreamError(.queryCancelled, read: true) } - case .commandComplete, .error, .drain: + case .commandComplete, .emptyQueryResponseReceived, .error, .drain: // the stream has already finished. return .wait @@ -229,6 +230,7 @@ struct ExtendedQueryStateMachine { .messagesSent, .parseCompleteReceived, .parameterDescriptionReceived, + .emptyQueryResponseReceived, .bindCompleteReceived, .streaming, .drain, @@ -268,6 +270,7 @@ struct ExtendedQueryStateMachine { .parseCompleteReceived, .parameterDescriptionReceived, .noDataMessageReceived, + .emptyQueryResponseReceived, .rowDescriptionReceived, .bindCompleteReceived, .commandComplete, @@ -285,7 +288,7 @@ struct ExtendedQueryStateMachine { case .unnamed(_, let eventLoopPromise), .executeStatement(_, let eventLoopPromise): return self.avoidingStateMachineCoW { state -> Action in state = .commandComplete(commandTag: commandTag) - let result = QueryResult(value: .noRows(commandTag), logger: context.logger) + let result = QueryResult(value: .noRows(.tag(commandTag)), logger: context.logger) return .succeedQuery(eventLoopPromise, with: result) } @@ -309,6 +312,7 @@ struct ExtendedQueryStateMachine { .parseCompleteReceived, .parameterDescriptionReceived, .noDataMessageReceived, + .emptyQueryResponseReceived, .rowDescriptionReceived, .commandComplete, .error: @@ -319,7 +323,22 @@ struct ExtendedQueryStateMachine { } mutating func emptyQueryResponseReceived() -> Action { - preconditionFailure("Unimplemented") + guard case .bindCompleteReceived(let queryContext) = self.state else { + return self.setAndFireError(.unexpectedBackendMessage(.emptyQueryResponse)) + } + + switch queryContext.query { + case .unnamed(_, let eventLoopPromise), + .executeStatement(_, let eventLoopPromise): + return self.avoidingStateMachineCoW { state -> Action in + state = .emptyQueryResponseReceived + let result = QueryResult(value: .noRows(.emptyResponse), logger: queryContext.logger) + return .succeedQuery(eventLoopPromise, with: result) + } + + case .prepareStatement(_, _, _, _): + return self.setAndFireError(.unexpectedBackendMessage(.emptyQueryResponse)) + } } mutating func errorReceived(_ errorMessage: PostgresBackendMessage.ErrorResponse) -> Action { @@ -336,7 +355,7 @@ struct ExtendedQueryStateMachine { return self.setAndFireError(error) case .streaming, .drain: return self.setAndFireError(error) - case .commandComplete: + case .commandComplete, .emptyQueryResponseReceived: return self.setAndFireError(.unexpectedBackendMessage(.error(errorMessage))) case .error: preconditionFailure(""" @@ -382,6 +401,7 @@ struct ExtendedQueryStateMachine { .parseCompleteReceived, .parameterDescriptionReceived, .noDataMessageReceived, + .emptyQueryResponseReceived, .rowDescriptionReceived, .bindCompleteReceived: preconditionFailure("Requested to consume next row without anything going on.") @@ -405,6 +425,7 @@ struct ExtendedQueryStateMachine { .parseCompleteReceived, .parameterDescriptionReceived, .noDataMessageReceived, + .emptyQueryResponseReceived, .rowDescriptionReceived, .bindCompleteReceived: return .wait @@ -449,6 +470,7 @@ struct ExtendedQueryStateMachine { } case .initialized, .commandComplete, + .emptyQueryResponseReceived, .drain, .error: // we already have the complete stream received, now we are waiting for a @@ -495,7 +517,7 @@ struct ExtendedQueryStateMachine { return .forwardStreamError(error, read: true) } - case .commandComplete, .error: + case .commandComplete, .emptyQueryResponseReceived, .error: preconditionFailure(""" This state must not be reached. If the query `.isComplete`, the ConnectionStateMachine must not send any further events to the substate machine. @@ -507,7 +529,7 @@ struct ExtendedQueryStateMachine { var isComplete: Bool { switch self.state { - case .commandComplete, .error: + case .commandComplete, .emptyQueryResponseReceived, .error: return true case .noDataMessageReceived(let context), .rowDescriptionReceived(let context, _): diff --git a/Sources/PostgresNIO/New/PSQLRowStream.swift b/Sources/PostgresNIO/New/PSQLRowStream.swift index b7f2d4fb..ee925d0e 100644 --- a/Sources/PostgresNIO/New/PSQLRowStream.swift +++ b/Sources/PostgresNIO/New/PSQLRowStream.swift @@ -3,7 +3,7 @@ import Logging struct QueryResult { enum Value: Equatable { - case noRows(String) + case noRows(PSQLRowStream.StatementSummary) case rowDescription([RowDescription.Column]) } @@ -16,25 +16,30 @@ struct QueryResult { final class PSQLRowStream: @unchecked Sendable { private typealias AsyncSequenceSource = NIOThrowingAsyncSequenceProducer.Source + enum StatementSummary: Equatable { + case tag(String) + case emptyResponse + } + enum Source { case stream([RowDescription.Column], PSQLRowsDataSource) - case noRows(Result) + case noRows(Result) } let eventLoop: EventLoop let logger: Logger - + private enum BufferState { case streaming(buffer: CircularBuffer, dataSource: PSQLRowsDataSource) - case finished(buffer: CircularBuffer, commandTag: String) + case finished(buffer: CircularBuffer, summary: StatementSummary) case failure(Error) } - + private enum DownstreamState { case waitingForConsumer(BufferState) case iteratingRows(onRow: (PostgresRow) throws -> (), EventLoopPromise, PSQLRowsDataSource) case waitingForAll([PostgresRow], EventLoopPromise<[PostgresRow]>, PSQLRowsDataSource) - case consumed(Result) + case consumed(Result) case asyncSequence(AsyncSequenceSource, PSQLRowsDataSource, onFinish: @Sendable () -> ()) } @@ -52,9 +57,9 @@ final class PSQLRowStream: @unchecked Sendable { case .stream(let rowDescription, let dataSource): self.rowDescription = rowDescription bufferState = .streaming(buffer: .init(), dataSource: dataSource) - case .noRows(.success(let commandTag)): + case .noRows(.success(let summary)): self.rowDescription = [] - bufferState = .finished(buffer: .init(), commandTag: commandTag) + bufferState = .finished(buffer: .init(), summary: summary) case .noRows(.failure(let error)): self.rowDescription = [] bufferState = .failure(error) @@ -98,12 +103,12 @@ final class PSQLRowStream: @unchecked Sendable { self.downstreamState = .asyncSequence(source, dataSource, onFinish: onFinish) self.executeActionBasedOnYieldResult(yieldResult, source: dataSource) - case .finished(let buffer, let commandTag): + case .finished(let buffer, let summary): _ = source.yield(contentsOf: buffer) source.finish() onFinish() - self.downstreamState = .consumed(.success(commandTag)) - + self.downstreamState = .consumed(.success(summary)) + case .failure(let error): source.finish(error) self.downstreamState = .consumed(.failure(error)) @@ -190,12 +195,12 @@ final class PSQLRowStream: @unchecked Sendable { dataSource.request(for: self) return promise.futureResult - case .finished(let buffer, let commandTag): + case .finished(let buffer, let summary): let rows = buffer.map { PostgresRow(data: $0, lookupTable: self.lookupTable, columns: self.rowDescription) } - self.downstreamState = .consumed(.success(commandTag)) + self.downstreamState = .consumed(.success(summary)) return self.eventLoop.makeSucceededFuture(rows) case .failure(let error): @@ -247,8 +252,8 @@ final class PSQLRowStream: @unchecked Sendable { } return promise.futureResult - - case .finished(let buffer, let commandTag): + + case .finished(let buffer, let summary): do { for data in buffer { let row = PostgresRow( @@ -259,7 +264,7 @@ final class PSQLRowStream: @unchecked Sendable { try onRow(row) } - self.downstreamState = .consumed(.success(commandTag)) + self.downstreamState = .consumed(.success(summary)) return self.eventLoop.makeSucceededVoidFuture() } catch { self.downstreamState = .consumed(.failure(error)) @@ -292,7 +297,7 @@ final class PSQLRowStream: @unchecked Sendable { case .waitingForConsumer(.finished), .waitingForConsumer(.failure): preconditionFailure("How can new rows be received, if an end was already signalled?") - + case .iteratingRows(let onRow, let promise, let dataSource): do { for data in newRows { @@ -347,25 +352,25 @@ final class PSQLRowStream: @unchecked Sendable { private func receiveEnd(_ commandTag: String) { switch self.downstreamState { case .waitingForConsumer(.streaming(buffer: let buffer, _)): - self.downstreamState = .waitingForConsumer(.finished(buffer: buffer, commandTag: commandTag)) - - case .waitingForConsumer(.finished), .waitingForConsumer(.failure): + self.downstreamState = .waitingForConsumer(.finished(buffer: buffer, summary: .tag(commandTag))) + + case .waitingForConsumer(.finished), .waitingForConsumer(.failure), .consumed(.success(.emptyResponse)): preconditionFailure("How can we get another end, if an end was already signalled?") case .iteratingRows(_, let promise, _): - self.downstreamState = .consumed(.success(commandTag)) + self.downstreamState = .consumed(.success(.tag(commandTag))) promise.succeed(()) case .waitingForAll(let rows, let promise, _): - self.downstreamState = .consumed(.success(commandTag)) + self.downstreamState = .consumed(.success(.tag(commandTag))) promise.succeed(rows) case .asyncSequence(let source, _, let onFinish): - self.downstreamState = .consumed(.success(commandTag)) + self.downstreamState = .consumed(.success(.tag(commandTag))) source.finish() onFinish() - case .consumed: + case .consumed(.success(.tag)), .consumed(.failure): break } } @@ -375,7 +380,7 @@ final class PSQLRowStream: @unchecked Sendable { case .waitingForConsumer(.streaming): self.downstreamState = .waitingForConsumer(.failure(error)) - case .waitingForConsumer(.finished), .waitingForConsumer(.failure): + case .waitingForConsumer(.finished), .waitingForConsumer(.failure), .consumed(.success(.emptyResponse)): preconditionFailure("How can we get another end, if an end was already signalled?") case .iteratingRows(_, let promise, _): @@ -391,7 +396,7 @@ final class PSQLRowStream: @unchecked Sendable { consumer.finish(error) onFinish() - case .consumed: + case .consumed(.success(.tag)), .consumed(.failure): break } } @@ -413,10 +418,15 @@ final class PSQLRowStream: @unchecked Sendable { } var commandTag: String { - guard case .consumed(.success(let commandTag)) = self.downstreamState else { + guard case .consumed(.success(let consumed)) = self.downstreamState else { preconditionFailure("commandTag may only be called if all rows have been consumed") } - return commandTag + switch consumed { + case .tag(let tag): + return tag + case .emptyResponse: + return "" + } } } diff --git a/Sources/PostgresNIO/New/PostgresChannelHandler.swift b/Sources/PostgresNIO/New/PostgresChannelHandler.swift index a3190aa7..ee2af0fe 100644 --- a/Sources/PostgresNIO/New/PostgresChannelHandler.swift +++ b/Sources/PostgresNIO/New/PostgresChannelHandler.swift @@ -550,9 +550,9 @@ final class PostgresChannelHandler: ChannelDuplexHandler { ) self.rowStream = rows - case .noRows(let commandTag): + case .noRows(let summary): rows = PSQLRowStream( - source: .noRows(.success(commandTag)), + source: .noRows(.success(summary)), eventLoop: context.channel.eventLoop, logger: result.logger ) diff --git a/Sources/PostgresNIO/PostgresDatabase+Query.swift b/Sources/PostgresNIO/PostgresDatabase+Query.swift index 01a7e61f..483d5a7b 100644 --- a/Sources/PostgresNIO/PostgresDatabase+Query.swift +++ b/Sources/PostgresNIO/PostgresDatabase+Query.swift @@ -73,10 +73,7 @@ public struct PostgresQueryMetadata: Sendable { init?(string: String) { let parts = string.split(separator: " ") - guard parts.count >= 1 else { - return nil - } - switch parts[0] { + switch parts.first { case "INSERT": // INSERT oid rows guard parts.count == 3 else { diff --git a/Tests/IntegrationTests/PSQLIntegrationTests.swift b/Tests/IntegrationTests/PSQLIntegrationTests.swift index 57939c06..d541899b 100644 --- a/Tests/IntegrationTests/PSQLIntegrationTests.swift +++ b/Tests/IntegrationTests/PSQLIntegrationTests.swift @@ -123,6 +123,25 @@ final class IntegrationTests: XCTestCase { XCTAssertEqual(foo, "hello") } + func testQueryNothing() throws { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + let eventLoop = eventLoopGroup.next() + + var conn: PostgresConnection? + XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) + defer { XCTAssertNoThrow(try conn?.close().wait()) } + + var _result: PostgresQueryResult? + XCTAssertNoThrow(_result = try conn?.query(""" + -- Some comments + """, logger: .psqlTest).wait()) + + let result = try XCTUnwrap(_result) + XCTAssertEqual(result.rows, []) + XCTAssertEqual(result.metadata.command, "") + } + func testDecodeIntegers() { let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } diff --git a/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift b/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift index 40e32468..ae484acc 100644 --- a/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift +++ b/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift @@ -20,7 +20,7 @@ class ExtendedQueryStateMachineTests: XCTestCase { XCTAssertEqual(state.parameterDescriptionReceived(.init(dataTypes: [.int8])), .wait) XCTAssertEqual(state.noDataReceived(), .wait) XCTAssertEqual(state.bindCompleteReceived(), .wait) - XCTAssertEqual(state.commandCompletedReceived("DELETE 1"), .succeedQuery(promise, with: .init(value: .noRows("DELETE 1"), logger: logger))) + XCTAssertEqual(state.commandCompletedReceived("DELETE 1"), .succeedQuery(promise, with: .init(value: .noRows(.tag("DELETE 1")), logger: logger))) XCTAssertEqual(state.readyForQueryReceived(.idle), .fireEventReadyForQuery) } @@ -77,7 +77,25 @@ class ExtendedQueryStateMachineTests: XCTestCase { XCTAssertEqual(state.commandCompletedReceived("SELECT 2"), .forwardStreamComplete([row5, row6], commandTag: "SELECT 2")) XCTAssertEqual(state.readyForQueryReceived(.idle), .fireEventReadyForQuery) } - + + func testExtendedQueryWithNoQuery() { + var state = ConnectionStateMachine.readyForQuery() + + let logger = Logger.psqlTest + let promise = EmbeddedEventLoop().makePromise(of: PSQLRowStream.self) + promise.fail(PSQLError.uncleanShutdown) // we don't care about the error at all. + let query: PostgresQuery = "-- some comments" + let queryContext = ExtendedQueryContext(query: query, logger: logger, promise: promise) + + XCTAssertEqual(state.enqueue(task: .extendedQuery(queryContext)), .sendParseDescribeBindExecuteSync(query)) + XCTAssertEqual(state.parseCompleteReceived(), .wait) + XCTAssertEqual(state.parameterDescriptionReceived(.init(dataTypes: [.int8])), .wait) + XCTAssertEqual(state.noDataReceived(), .wait) + XCTAssertEqual(state.bindCompleteReceived(), .wait) + XCTAssertEqual(state.emptyQueryResponseReceived(), .succeedQuery(promise, with: .init(value: .noRows(.emptyResponse), logger: logger))) + XCTAssertEqual(state.readyForQueryReceived(.idle), .fireEventReadyForQuery) + } + func testReceiveTotallyUnexpectedMessageInQuery() { var state = ConnectionStateMachine.readyForQuery() diff --git a/Tests/PostgresNIOTests/New/Connection State Machine/PreparedStatementStateMachineTests.swift b/Tests/PostgresNIOTests/New/Connection State Machine/PreparedStatementStateMachineTests.swift index f6c1ddf7..e35e93f7 100644 --- a/Tests/PostgresNIOTests/New/Connection State Machine/PreparedStatementStateMachineTests.swift +++ b/Tests/PostgresNIOTests/New/Connection State Machine/PreparedStatementStateMachineTests.swift @@ -28,7 +28,7 @@ class PreparedStatementStateMachineTests: XCTestCase { XCTAssertEqual(preparationCompleteAction.statements.count, 1) XCTAssertNil(preparationCompleteAction.rowDescription) firstPreparedStatement.promise.succeed(PSQLRowStream( - source: .noRows(.success("tag")), + source: .noRows(.success(.tag("tag"))), eventLoop: eventLoop, logger: .psqlTest )) @@ -46,7 +46,7 @@ class PreparedStatementStateMachineTests: XCTestCase { return } secondPreparedStatement.promise.succeed(PSQLRowStream( - source: .noRows(.success("tag")), + source: .noRows(.success(.tag("tag"))), eventLoop: eventLoop, logger: .psqlTest )) @@ -135,12 +135,12 @@ class PreparedStatementStateMachineTests: XCTestCase { XCTAssertNil(preparationCompleteAction.rowDescription) firstPreparedStatement.promise.succeed(PSQLRowStream( - source: .noRows(.success("tag")), + source: .noRows(.success(.tag("tag"))), eventLoop: eventLoop, logger: .psqlTest )) secondPreparedStatement.promise.succeed(PSQLRowStream( - source: .noRows(.success("tag")), + source: .noRows(.success(.tag("tag"))), eventLoop: eventLoop, logger: .psqlTest )) diff --git a/Tests/PostgresNIOTests/New/PSQLRowStreamTests.swift b/Tests/PostgresNIOTests/New/PSQLRowStreamTests.swift index 9a1e9e41..65ca26c3 100644 --- a/Tests/PostgresNIOTests/New/PSQLRowStreamTests.swift +++ b/Tests/PostgresNIOTests/New/PSQLRowStreamTests.swift @@ -12,7 +12,7 @@ final class PSQLRowStreamTests: XCTestCase { func testEmptyStream() { let stream = PSQLRowStream( - source: .noRows(.success("INSERT 0 1")), + source: .noRows(.success(.tag("INSERT 0 1"))), eventLoop: self.eventLoop, logger: self.logger ) From 9f84290f4f7ba3b3edb749d196243fc2df6b82e6 Mon Sep 17 00:00:00 2001 From: Mahdi Bahrami Date: Thu, 22 Aug 2024 22:22:00 +0330 Subject: [PATCH 090/106] Fix Flaky Nightly Tests (#503) --- Tests/PostgresNIOTests/New/Data/Date+PSQLCodableTests.swift | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Tests/PostgresNIOTests/New/Data/Date+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/Date+PSQLCodableTests.swift index 769bde4b..3f406598 100644 --- a/Tests/PostgresNIOTests/New/Data/Date+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/Date+PSQLCodableTests.swift @@ -14,7 +14,7 @@ class Date_PSQLCodableTests: XCTestCase { var result: Date? XCTAssertNoThrow(result = try Date(from: &buffer, type: .timestamptz, format: .binary, context: .default)) - XCTAssertEqual(value, result) + XCTAssertEqual(value.timeIntervalSince1970, result?.timeIntervalSince1970 ?? 0, accuracy: 0.001) } func testDecodeRandomDate() { From 8f7e9002462c1a625e590e568fe31251a2429c8a Mon Sep 17 00:00:00 2001 From: Lei Nelissen Date: Wed, 25 Sep 2024 16:48:33 +0200 Subject: [PATCH 091/106] Fix cross-compilation to the static Linux SDK (#510) --- Sources/ConnectionPoolModule/PoolStateMachine.swift | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/Sources/ConnectionPoolModule/PoolStateMachine.swift b/Sources/ConnectionPoolModule/PoolStateMachine.swift index 3b996033..6e41f730 100644 --- a/Sources/ConnectionPoolModule/PoolStateMachine.swift +++ b/Sources/ConnectionPoolModule/PoolStateMachine.swift @@ -1,7 +1,9 @@ #if canImport(Darwin) import Darwin -#else +#elseif canImport(Glibc) import Glibc +#elseif canImport(Musl) +import Musl #endif @usableFromInline From c13a11a97b9878cdc1366b4adf03c03cea0b6163 Mon Sep 17 00:00:00 2001 From: Gwynne Raskind Date: Wed, 9 Oct 2024 03:33:36 -0500 Subject: [PATCH 092/106] Drop Swift 5.8 support and update CI (#515) --- .github/workflows/test.yml | 16 ++++++---------- Package.swift | 2 +- 2 files changed, 7 insertions(+), 11 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 1761880d..8364e8ae 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -18,10 +18,9 @@ jobs: fail-fast: false matrix: swift-image: - - swift:5.8-jammy - swift:5.9-jammy - swift:5.10-noble - - swiftlang/swift:nightly-6.0-jammy + - swift:6.0-noble - swiftlang/swift:nightly-main-jammy container: ${{ matrix.swift-image }} runs-on: ubuntu-latest @@ -48,13 +47,13 @@ jobs: fail-fast: false matrix: postgres-image: - - postgres:16 - - postgres:14 + - postgres:17 + - postgres:15 - postgres:12 include: - - postgres-image: postgres:16 + - postgres-image: postgres:17 postgres-auth: scram-sha-256 - - postgres-image: postgres:14 + - postgres-image: postgres:15 postgres-auth: md5 - postgres-image: postgres:12 postgres-auth: trust @@ -134,11 +133,8 @@ jobs: # Only test one auth method on macOS, Linux tests will cover the others - scram-sha-256 xcode-version: - - '~14.3' - '~15' include: - - xcode-version: '~14.3' - macos-version: 'macos-13' - xcode-version: '~15' macos-version: 'macos-14' runs-on: ${{ matrix.macos-version }} @@ -172,7 +168,7 @@ jobs: uses: actions/checkout@v4 - name: Run all tests run: swift test - + api-breakage: if: github.event_name == 'pull_request' runs-on: ubuntu-latest diff --git a/Package.swift b/Package.swift index 5c83eded..5f6562f6 100644 --- a/Package.swift +++ b/Package.swift @@ -1,4 +1,4 @@ -// swift-tools-version:5.8 +// swift-tools-version:5.9 import PackageDescription let swiftSettings: [SwiftSetting] = [ From 225c5c4adaf48e69fec20321187843c75dada65d Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Wed, 9 Oct 2024 10:43:37 +0200 Subject: [PATCH 093/106] Remove all code that solely existed to support Swift 5.8 (#516) --- .../ConnectionPoolModule/ConnectionPool.swift | 14 - .../New/NotificationListener.swift | 16 - .../New/PostgresRow-multi-decode.swift | 1175 ----------------- .../PostgresRowSequence-multi-decode.swift | 215 --- .../PostgresNIO/New/VariadicGenerics.swift | 4 +- 5 files changed, 1 insertion(+), 1423 deletions(-) delete mode 100644 Sources/PostgresNIO/New/PostgresRow-multi-decode.swift delete mode 100644 Sources/PostgresNIO/New/PostgresRowSequence-multi-decode.swift diff --git a/Sources/ConnectionPoolModule/ConnectionPool.swift b/Sources/ConnectionPoolModule/ConnectionPool.swift index 03c269ee..5cdb980d 100644 --- a/Sources/ConnectionPoolModule/ConnectionPool.swift +++ b/Sources/ConnectionPoolModule/ConnectionPool.swift @@ -571,20 +571,6 @@ extension PoolConfiguration { } } -#if swift(<5.9) -// This should be removed once we support Swift 5.9+ only -extension AsyncStream { - static func makeStream( - of elementType: Element.Type = Element.self, - bufferingPolicy limit: Continuation.BufferingPolicy = .unbounded - ) -> (stream: AsyncStream, continuation: AsyncStream.Continuation) { - var continuation: AsyncStream.Continuation! - let stream = AsyncStream(bufferingPolicy: limit) { continuation = $0 } - return (stream: stream, continuation: continuation!) - } -} -#endif - @usableFromInline protocol TaskGroupProtocol { // We need to call this `addTask_` because some Swift versions define this diff --git a/Sources/PostgresNIO/New/NotificationListener.swift b/Sources/PostgresNIO/New/NotificationListener.swift index 4982b8ad..2f784e33 100644 --- a/Sources/PostgresNIO/New/NotificationListener.swift +++ b/Sources/PostgresNIO/New/NotificationListener.swift @@ -140,19 +140,3 @@ final class NotificationListener: @unchecked Sendable { } } } - - -#if compiler(<5.9) -// Async stream API backfill -extension AsyncThrowingStream { - static func makeStream( - of elementType: Element.Type = Element.self, - throwing failureType: Failure.Type = Failure.self, - bufferingPolicy limit: Continuation.BufferingPolicy = .unbounded - ) -> (stream: AsyncThrowingStream, continuation: AsyncThrowingStream.Continuation) where Failure == Error { - var continuation: AsyncThrowingStream.Continuation! - let stream = AsyncThrowingStream(bufferingPolicy: limit) { continuation = $0 } - return (stream: stream, continuation: continuation!) - } - } - #endif diff --git a/Sources/PostgresNIO/New/PostgresRow-multi-decode.swift b/Sources/PostgresNIO/New/PostgresRow-multi-decode.swift deleted file mode 100644 index 71aa04dc..00000000 --- a/Sources/PostgresNIO/New/PostgresRow-multi-decode.swift +++ /dev/null @@ -1,1175 +0,0 @@ -/// NOTE: THIS FILE IS AUTO-GENERATED BY dev/generate-postgresrow-multi-decode.sh - -#if compiler(<5.9) -extension PostgresRow { - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) throws -> (T0) { - precondition(self.columns.count >= 1) - let columnIndex = 0 - var cellIterator = self.data.makeIterator() - var cellData = cellIterator.next().unsafelyUnwrapped - var columnIterator = self.columns.makeIterator() - let column = columnIterator.next().unsafelyUnwrapped - let swiftTargetType: Any.Type = T0.self - - do { - let r0 = try T0._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - return (r0) - } catch let code as PostgresDecodingError.Code { - throw PostgresDecodingError( - code: code, - columnName: column.name, - columnIndex: columnIndex, - targetType: swiftTargetType, - postgresType: column.dataType, - postgresFormat: column.format, - postgresData: cellData, - file: file, - line: line - ) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0).Type, file: String = #fileID, line: Int = #line) throws -> (T0) { - try self.decode(T0.self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) throws -> (T0, T1) { - precondition(self.columns.count >= 2) - var columnIndex = 0 - var cellIterator = self.data.makeIterator() - var cellData = cellIterator.next().unsafelyUnwrapped - var columnIterator = self.columns.makeIterator() - var column = columnIterator.next().unsafelyUnwrapped - var swiftTargetType: Any.Type = T0.self - - do { - let r0 = try T0._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 1 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T1.self - let r1 = try T1._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - return (r0, r1) - } catch let code as PostgresDecodingError.Code { - throw PostgresDecodingError( - code: code, - columnName: column.name, - columnIndex: columnIndex, - targetType: swiftTargetType, - postgresType: column.dataType, - postgresFormat: column.format, - postgresData: cellData, - file: file, - line: line - ) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1).Type, file: String = #fileID, line: Int = #line) throws -> (T0, T1) { - try self.decode((T0, T1).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2) { - precondition(self.columns.count >= 3) - var columnIndex = 0 - var cellIterator = self.data.makeIterator() - var cellData = cellIterator.next().unsafelyUnwrapped - var columnIterator = self.columns.makeIterator() - var column = columnIterator.next().unsafelyUnwrapped - var swiftTargetType: Any.Type = T0.self - - do { - let r0 = try T0._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 1 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T1.self - let r1 = try T1._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 2 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T2.self - let r2 = try T2._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - return (r0, r1, r2) - } catch let code as PostgresDecodingError.Code { - throw PostgresDecodingError( - code: code, - columnName: column.name, - columnIndex: columnIndex, - targetType: swiftTargetType, - postgresType: column.dataType, - postgresFormat: column.format, - postgresData: cellData, - file: file, - line: line - ) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2).Type, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2) { - try self.decode((T0, T1, T2).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3) { - precondition(self.columns.count >= 4) - var columnIndex = 0 - var cellIterator = self.data.makeIterator() - var cellData = cellIterator.next().unsafelyUnwrapped - var columnIterator = self.columns.makeIterator() - var column = columnIterator.next().unsafelyUnwrapped - var swiftTargetType: Any.Type = T0.self - - do { - let r0 = try T0._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 1 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T1.self - let r1 = try T1._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 2 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T2.self - let r2 = try T2._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 3 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T3.self - let r3 = try T3._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - return (r0, r1, r2, r3) - } catch let code as PostgresDecodingError.Code { - throw PostgresDecodingError( - code: code, - columnName: column.name, - columnIndex: columnIndex, - targetType: swiftTargetType, - postgresType: column.dataType, - postgresFormat: column.format, - postgresData: cellData, - file: file, - line: line - ) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3).Type, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3) { - try self.decode((T0, T1, T2, T3).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4) { - precondition(self.columns.count >= 5) - var columnIndex = 0 - var cellIterator = self.data.makeIterator() - var cellData = cellIterator.next().unsafelyUnwrapped - var columnIterator = self.columns.makeIterator() - var column = columnIterator.next().unsafelyUnwrapped - var swiftTargetType: Any.Type = T0.self - - do { - let r0 = try T0._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 1 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T1.self - let r1 = try T1._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 2 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T2.self - let r2 = try T2._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 3 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T3.self - let r3 = try T3._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 4 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T4.self - let r4 = try T4._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - return (r0, r1, r2, r3, r4) - } catch let code as PostgresDecodingError.Code { - throw PostgresDecodingError( - code: code, - columnName: column.name, - columnIndex: columnIndex, - targetType: swiftTargetType, - postgresType: column.dataType, - postgresFormat: column.format, - postgresData: cellData, - file: file, - line: line - ) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4).Type, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4) { - try self.decode((T0, T1, T2, T3, T4).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5) { - precondition(self.columns.count >= 6) - var columnIndex = 0 - var cellIterator = self.data.makeIterator() - var cellData = cellIterator.next().unsafelyUnwrapped - var columnIterator = self.columns.makeIterator() - var column = columnIterator.next().unsafelyUnwrapped - var swiftTargetType: Any.Type = T0.self - - do { - let r0 = try T0._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 1 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T1.self - let r1 = try T1._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 2 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T2.self - let r2 = try T2._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 3 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T3.self - let r3 = try T3._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 4 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T4.self - let r4 = try T4._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 5 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T5.self - let r5 = try T5._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - return (r0, r1, r2, r3, r4, r5) - } catch let code as PostgresDecodingError.Code { - throw PostgresDecodingError( - code: code, - columnName: column.name, - columnIndex: columnIndex, - targetType: swiftTargetType, - postgresType: column.dataType, - postgresFormat: column.format, - postgresData: cellData, - file: file, - line: line - ) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5).Type, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5) { - try self.decode((T0, T1, T2, T3, T4, T5).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6) { - precondition(self.columns.count >= 7) - var columnIndex = 0 - var cellIterator = self.data.makeIterator() - var cellData = cellIterator.next().unsafelyUnwrapped - var columnIterator = self.columns.makeIterator() - var column = columnIterator.next().unsafelyUnwrapped - var swiftTargetType: Any.Type = T0.self - - do { - let r0 = try T0._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 1 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T1.self - let r1 = try T1._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 2 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T2.self - let r2 = try T2._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 3 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T3.self - let r3 = try T3._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 4 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T4.self - let r4 = try T4._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 5 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T5.self - let r5 = try T5._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 6 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T6.self - let r6 = try T6._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - return (r0, r1, r2, r3, r4, r5, r6) - } catch let code as PostgresDecodingError.Code { - throw PostgresDecodingError( - code: code, - columnName: column.name, - columnIndex: columnIndex, - targetType: swiftTargetType, - postgresType: column.dataType, - postgresFormat: column.format, - postgresData: cellData, - file: file, - line: line - ) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6).Type, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6) { - try self.decode((T0, T1, T2, T3, T4, T5, T6).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7) { - precondition(self.columns.count >= 8) - var columnIndex = 0 - var cellIterator = self.data.makeIterator() - var cellData = cellIterator.next().unsafelyUnwrapped - var columnIterator = self.columns.makeIterator() - var column = columnIterator.next().unsafelyUnwrapped - var swiftTargetType: Any.Type = T0.self - - do { - let r0 = try T0._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 1 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T1.self - let r1 = try T1._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 2 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T2.self - let r2 = try T2._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 3 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T3.self - let r3 = try T3._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 4 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T4.self - let r4 = try T4._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 5 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T5.self - let r5 = try T5._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 6 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T6.self - let r6 = try T6._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 7 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T7.self - let r7 = try T7._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - return (r0, r1, r2, r3, r4, r5, r6, r7) - } catch let code as PostgresDecodingError.Code { - throw PostgresDecodingError( - code: code, - columnName: column.name, - columnIndex: columnIndex, - targetType: swiftTargetType, - postgresType: column.dataType, - postgresFormat: column.format, - postgresData: cellData, - file: file, - line: line - ) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7).Type, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7) { - try self.decode((T0, T1, T2, T3, T4, T5, T6, T7).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8) { - precondition(self.columns.count >= 9) - var columnIndex = 0 - var cellIterator = self.data.makeIterator() - var cellData = cellIterator.next().unsafelyUnwrapped - var columnIterator = self.columns.makeIterator() - var column = columnIterator.next().unsafelyUnwrapped - var swiftTargetType: Any.Type = T0.self - - do { - let r0 = try T0._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 1 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T1.self - let r1 = try T1._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 2 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T2.self - let r2 = try T2._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 3 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T3.self - let r3 = try T3._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 4 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T4.self - let r4 = try T4._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 5 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T5.self - let r5 = try T5._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 6 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T6.self - let r6 = try T6._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 7 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T7.self - let r7 = try T7._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 8 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T8.self - let r8 = try T8._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - return (r0, r1, r2, r3, r4, r5, r6, r7, r8) - } catch let code as PostgresDecodingError.Code { - throw PostgresDecodingError( - code: code, - columnName: column.name, - columnIndex: columnIndex, - targetType: swiftTargetType, - postgresType: column.dataType, - postgresFormat: column.format, - postgresData: cellData, - file: file, - line: line - ) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8).Type, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8) { - try self.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9) { - precondition(self.columns.count >= 10) - var columnIndex = 0 - var cellIterator = self.data.makeIterator() - var cellData = cellIterator.next().unsafelyUnwrapped - var columnIterator = self.columns.makeIterator() - var column = columnIterator.next().unsafelyUnwrapped - var swiftTargetType: Any.Type = T0.self - - do { - let r0 = try T0._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 1 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T1.self - let r1 = try T1._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 2 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T2.self - let r2 = try T2._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 3 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T3.self - let r3 = try T3._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 4 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T4.self - let r4 = try T4._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 5 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T5.self - let r5 = try T5._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 6 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T6.self - let r6 = try T6._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 7 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T7.self - let r7 = try T7._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 8 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T8.self - let r8 = try T8._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 9 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T9.self - let r9 = try T9._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - return (r0, r1, r2, r3, r4, r5, r6, r7, r8, r9) - } catch let code as PostgresDecodingError.Code { - throw PostgresDecodingError( - code: code, - columnName: column.name, - columnIndex: columnIndex, - targetType: swiftTargetType, - postgresType: column.dataType, - postgresFormat: column.format, - postgresData: cellData, - file: file, - line: line - ) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9).Type, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9) { - try self.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10) { - precondition(self.columns.count >= 11) - var columnIndex = 0 - var cellIterator = self.data.makeIterator() - var cellData = cellIterator.next().unsafelyUnwrapped - var columnIterator = self.columns.makeIterator() - var column = columnIterator.next().unsafelyUnwrapped - var swiftTargetType: Any.Type = T0.self - - do { - let r0 = try T0._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 1 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T1.self - let r1 = try T1._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 2 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T2.self - let r2 = try T2._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 3 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T3.self - let r3 = try T3._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 4 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T4.self - let r4 = try T4._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 5 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T5.self - let r5 = try T5._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 6 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T6.self - let r6 = try T6._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 7 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T7.self - let r7 = try T7._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 8 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T8.self - let r8 = try T8._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 9 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T9.self - let r9 = try T9._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 10 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T10.self - let r10 = try T10._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - return (r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, r10) - } catch let code as PostgresDecodingError.Code { - throw PostgresDecodingError( - code: code, - columnName: column.name, - columnIndex: columnIndex, - targetType: swiftTargetType, - postgresType: column.dataType, - postgresFormat: column.format, - postgresData: cellData, - file: file, - line: line - ) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10).Type, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10) { - try self.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11) { - precondition(self.columns.count >= 12) - var columnIndex = 0 - var cellIterator = self.data.makeIterator() - var cellData = cellIterator.next().unsafelyUnwrapped - var columnIterator = self.columns.makeIterator() - var column = columnIterator.next().unsafelyUnwrapped - var swiftTargetType: Any.Type = T0.self - - do { - let r0 = try T0._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 1 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T1.self - let r1 = try T1._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 2 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T2.self - let r2 = try T2._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 3 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T3.self - let r3 = try T3._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 4 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T4.self - let r4 = try T4._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 5 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T5.self - let r5 = try T5._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 6 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T6.self - let r6 = try T6._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 7 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T7.self - let r7 = try T7._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 8 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T8.self - let r8 = try T8._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 9 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T9.self - let r9 = try T9._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 10 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T10.self - let r10 = try T10._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 11 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T11.self - let r11 = try T11._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - return (r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, r10, r11) - } catch let code as PostgresDecodingError.Code { - throw PostgresDecodingError( - code: code, - columnName: column.name, - columnIndex: columnIndex, - targetType: swiftTargetType, - postgresType: column.dataType, - postgresFormat: column.format, - postgresData: cellData, - file: file, - line: line - ) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11).Type, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11) { - try self.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12) { - precondition(self.columns.count >= 13) - var columnIndex = 0 - var cellIterator = self.data.makeIterator() - var cellData = cellIterator.next().unsafelyUnwrapped - var columnIterator = self.columns.makeIterator() - var column = columnIterator.next().unsafelyUnwrapped - var swiftTargetType: Any.Type = T0.self - - do { - let r0 = try T0._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 1 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T1.self - let r1 = try T1._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 2 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T2.self - let r2 = try T2._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 3 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T3.self - let r3 = try T3._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 4 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T4.self - let r4 = try T4._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 5 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T5.self - let r5 = try T5._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 6 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T6.self - let r6 = try T6._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 7 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T7.self - let r7 = try T7._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 8 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T8.self - let r8 = try T8._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 9 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T9.self - let r9 = try T9._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 10 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T10.self - let r10 = try T10._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 11 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T11.self - let r11 = try T11._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 12 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T12.self - let r12 = try T12._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - return (r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, r10, r11, r12) - } catch let code as PostgresDecodingError.Code { - throw PostgresDecodingError( - code: code, - columnName: column.name, - columnIndex: columnIndex, - targetType: swiftTargetType, - postgresType: column.dataType, - postgresFormat: column.format, - postgresData: cellData, - file: file, - line: line - ) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12).Type, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12) { - try self.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13) { - precondition(self.columns.count >= 14) - var columnIndex = 0 - var cellIterator = self.data.makeIterator() - var cellData = cellIterator.next().unsafelyUnwrapped - var columnIterator = self.columns.makeIterator() - var column = columnIterator.next().unsafelyUnwrapped - var swiftTargetType: Any.Type = T0.self - - do { - let r0 = try T0._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 1 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T1.self - let r1 = try T1._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 2 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T2.self - let r2 = try T2._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 3 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T3.self - let r3 = try T3._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 4 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T4.self - let r4 = try T4._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 5 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T5.self - let r5 = try T5._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 6 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T6.self - let r6 = try T6._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 7 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T7.self - let r7 = try T7._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 8 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T8.self - let r8 = try T8._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 9 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T9.self - let r9 = try T9._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 10 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T10.self - let r10 = try T10._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 11 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T11.self - let r11 = try T11._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 12 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T12.self - let r12 = try T12._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 13 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T13.self - let r13 = try T13._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - return (r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, r10, r11, r12, r13) - } catch let code as PostgresDecodingError.Code { - throw PostgresDecodingError( - code: code, - columnName: column.name, - columnIndex: columnIndex, - targetType: swiftTargetType, - postgresType: column.dataType, - postgresFormat: column.format, - postgresData: cellData, - file: file, - line: line - ) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13).Type, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13) { - try self.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14) { - precondition(self.columns.count >= 15) - var columnIndex = 0 - var cellIterator = self.data.makeIterator() - var cellData = cellIterator.next().unsafelyUnwrapped - var columnIterator = self.columns.makeIterator() - var column = columnIterator.next().unsafelyUnwrapped - var swiftTargetType: Any.Type = T0.self - - do { - let r0 = try T0._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 1 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T1.self - let r1 = try T1._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 2 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T2.self - let r2 = try T2._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 3 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T3.self - let r3 = try T3._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 4 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T4.self - let r4 = try T4._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 5 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T5.self - let r5 = try T5._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 6 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T6.self - let r6 = try T6._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 7 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T7.self - let r7 = try T7._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 8 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T8.self - let r8 = try T8._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 9 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T9.self - let r9 = try T9._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 10 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T10.self - let r10 = try T10._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 11 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T11.self - let r11 = try T11._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 12 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T12.self - let r12 = try T12._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 13 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T13.self - let r13 = try T13._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 14 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T14.self - let r14 = try T14._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - return (r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, r10, r11, r12, r13, r14) - } catch let code as PostgresDecodingError.Code { - throw PostgresDecodingError( - code: code, - columnName: column.name, - columnIndex: columnIndex, - targetType: swiftTargetType, - postgresType: column.dataType, - postgresFormat: column.format, - postgresData: cellData, - file: file, - line: line - ) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14).Type, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14) { - try self.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14).self, context: .default, file: file, line: line) - } -} -#endif diff --git a/Sources/PostgresNIO/New/PostgresRowSequence-multi-decode.swift b/Sources/PostgresNIO/New/PostgresRowSequence-multi-decode.swift deleted file mode 100644 index f45357d8..00000000 --- a/Sources/PostgresNIO/New/PostgresRowSequence-multi-decode.swift +++ /dev/null @@ -1,215 +0,0 @@ -/// NOTE: THIS FILE IS AUTO-GENERATED BY dev/generate-postgresrowsequence-multi-decode.sh - -#if compiler(<5.9) -extension AsyncSequence where Element == PostgresRow { - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.map { row in - try row.decode(T0.self, context: context, file: file, line: line) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0).Type, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.decode(T0.self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.map { row in - try row.decode((T0, T1).self, context: context, file: file, line: line) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1).Type, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.decode((T0, T1).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.map { row in - try row.decode((T0, T1, T2).self, context: context, file: file, line: line) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2).Type, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.decode((T0, T1, T2).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.map { row in - try row.decode((T0, T1, T2, T3).self, context: context, file: file, line: line) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3).Type, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.decode((T0, T1, T2, T3).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.map { row in - try row.decode((T0, T1, T2, T3, T4).self, context: context, file: file, line: line) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4).Type, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.decode((T0, T1, T2, T3, T4).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.map { row in - try row.decode((T0, T1, T2, T3, T4, T5).self, context: context, file: file, line: line) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5).Type, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.decode((T0, T1, T2, T3, T4, T5).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.map { row in - try row.decode((T0, T1, T2, T3, T4, T5, T6).self, context: context, file: file, line: line) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6).Type, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.decode((T0, T1, T2, T3, T4, T5, T6).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.map { row in - try row.decode((T0, T1, T2, T3, T4, T5, T6, T7).self, context: context, file: file, line: line) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7).Type, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.decode((T0, T1, T2, T3, T4, T5, T6, T7).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.map { row in - try row.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8).self, context: context, file: file, line: line) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8).Type, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.map { row in - try row.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9).self, context: context, file: file, line: line) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9).Type, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.map { row in - try row.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10).self, context: context, file: file, line: line) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10).Type, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.map { row in - try row.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11).self, context: context, file: file, line: line) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11).Type, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.map { row in - try row.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12).self, context: context, file: file, line: line) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12).Type, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.map { row in - try row.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13).self, context: context, file: file, line: line) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13).Type, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.map { row in - try row.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14).self, context: context, file: file, line: line) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14).Type, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14).self, context: .default, file: file, line: line) - } -} -#endif diff --git a/Sources/PostgresNIO/New/VariadicGenerics.swift b/Sources/PostgresNIO/New/VariadicGenerics.swift index 312d36dc..7931c90c 100644 --- a/Sources/PostgresNIO/New/VariadicGenerics.swift +++ b/Sources/PostgresNIO/New/VariadicGenerics.swift @@ -1,4 +1,4 @@ -#if compiler(>=5.9) + extension PostgresRow { // --- snip TODO: Remove once bug is fixed, that disallows tuples of one @inlinable @@ -170,5 +170,3 @@ enum ComputeParameterPackLength { MemoryLayout<(repeat BoolConverter.Bool)>.size / MemoryLayout.stride } } -#endif // compiler(>=5.9) - From d4c2f38ff5b5bdce6fd952ee75670631c4c8b5a4 Mon Sep 17 00:00:00 2001 From: Robert Cottrell Date: Mon, 21 Oct 2024 08:19:08 +0100 Subject: [PATCH 094/106] Allow bindings with optional values in PostgresBindings (#520) --- Sources/PostgresNIO/New/PostgresQuery.swift | 46 ++++++++++++ Tests/IntegrationTests/AsyncTests.swift | 81 +++++++++++++++++++++ 2 files changed, 127 insertions(+) diff --git a/Sources/PostgresNIO/New/PostgresQuery.swift b/Sources/PostgresNIO/New/PostgresQuery.swift index b695dcfe..6449ab29 100644 --- a/Sources/PostgresNIO/New/PostgresQuery.swift +++ b/Sources/PostgresNIO/New/PostgresQuery.swift @@ -172,6 +172,16 @@ public struct PostgresBindings: Sendable, Hashable { try self.append(value, context: .default) } + @inlinable + public mutating func append(_ value: Optional) throws { + switch value { + case .none: + self.appendNull() + case let .some(value): + try self.append(value) + } + } + @inlinable public mutating func append( _ value: Value, @@ -181,11 +191,34 @@ public struct PostgresBindings: Sendable, Hashable { self.metadata.append(.init(value: value, protected: true)) } + @inlinable + public mutating func append( + _ value: Optional, + context: PostgresEncodingContext + ) throws { + switch value { + case .none: + self.appendNull() + case let .some(value): + try self.append(value, context: context) + } + } + @inlinable public mutating func append(_ value: Value) { self.append(value, context: .default) } + @inlinable + public mutating func append(_ value: Optional) { + switch value { + case .none: + self.appendNull() + case let .some(value): + self.append(value) + } + } + @inlinable public mutating func append( _ value: Value, @@ -195,6 +228,19 @@ public struct PostgresBindings: Sendable, Hashable { self.metadata.append(.init(value: value, protected: true)) } + @inlinable + public mutating func append( + _ value: Optional, + context: PostgresEncodingContext + ) { + switch value { + case .none: + self.appendNull() + case let .some(value): + self.append(value, context: context) + } + } + @inlinable mutating func appendUnprotected( _ value: Value, diff --git a/Tests/IntegrationTests/AsyncTests.swift b/Tests/IntegrationTests/AsyncTests.swift index 513157fd..b4c8e93f 100644 --- a/Tests/IntegrationTests/AsyncTests.swift +++ b/Tests/IntegrationTests/AsyncTests.swift @@ -476,6 +476,87 @@ final class AsyncPostgresConnectionTests: XCTestCase { XCTFail("Unexpected error: \(String(describing: error))") } } + + static let preparedStatementWithOptionalTestTable = "AsyncTestPreparedStatementWithOptionalTestTable" + func testPreparedStatementWithOptionalBinding() async throws { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + let eventLoop = eventLoopGroup.next() + + struct InsertPreparedStatement: PostgresPreparedStatement { + static let name = "INSERT-AsyncTestPreparedStatementWithOptionalTestTable" + + static let sql = #"INSERT INTO "\#(AsyncPostgresConnectionTests.preparedStatementWithOptionalTestTable)" (uuid) VALUES ($1);"# + typealias Row = () + + var uuid: UUID? + + func makeBindings() -> PostgresBindings { + var bindings = PostgresBindings() + bindings.append(self.uuid) + return bindings + } + + func decodeRow(_ row: PostgresNIO.PostgresRow) throws -> Row { + () + } + } + + struct SelectPreparedStatement: PostgresPreparedStatement { + static let name = "SELECT-AsyncTestPreparedStatementWithOptionalTestTable" + + static let sql = #"SELECT id, uuid FROM "\#(AsyncPostgresConnectionTests.preparedStatementWithOptionalTestTable)" WHERE id <= $1;"# + typealias Row = (Int, UUID?) + + var id: Int + + func makeBindings() -> PostgresBindings { + var bindings = PostgresBindings() + bindings.append(self.id) + return bindings + } + + func decodeRow(_ row: PostgresNIO.PostgresRow) throws -> Row { + try row.decode((Int, UUID?).self) + } + } + + do { + try await withTestConnection(on: eventLoop) { connection in + try await connection.query(""" + CREATE TABLE IF NOT EXISTS "\(unescaped: Self.preparedStatementWithOptionalTestTable)" ( + id SERIAL PRIMARY KEY, + uuid UUID + ) + """, + logger: .psqlTest + ) + + _ = try await connection.execute(InsertPreparedStatement(uuid: nil), logger: .psqlTest) + _ = try await connection.execute(InsertPreparedStatement(uuid: .init()), logger: .psqlTest) + _ = try await connection.execute(InsertPreparedStatement(uuid: nil), logger: .psqlTest) + _ = try await connection.execute(InsertPreparedStatement(uuid: .init()), logger: .psqlTest) + _ = try await connection.execute(InsertPreparedStatement(uuid: nil), logger: .psqlTest) + + let rows = try await connection.execute(SelectPreparedStatement(id: 3), logger: .psqlTest) + var counter = 0 + for try await (id, uuid) in rows { + Logger.psqlTest.info("Received row", metadata: [ + "id": "\(id)", "uuid": "\(String(describing: uuid))" + ]) + counter += 1 + } + + try await connection.query(""" + DROP TABLE "\(unescaped: Self.preparedStatementWithOptionalTestTable)"; + """, + logger: .psqlTest + ) + } + } catch { + XCTFail("Unexpected error: \(String(describing: error))") + } + } } extension XCTestCase { From f2a6394a2e7157d547727b975fc0328b92f89fb1 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Mon, 21 Oct 2024 10:57:36 +0200 Subject: [PATCH 095/106] Support `additionalStartupParameters` in PostgresClient (#521) --- .../PostgresNIO/Pool/ConnectionFactory.swift | 1 + Sources/PostgresNIO/Pool/PostgresClient.swift | 4 +++ .../PostgresClientTests.swift | 35 +++++++++++++++++++ 3 files changed, 40 insertions(+) diff --git a/Sources/PostgresNIO/Pool/ConnectionFactory.swift b/Sources/PostgresNIO/Pool/ConnectionFactory.swift index 77a0c047..319b86c4 100644 --- a/Sources/PostgresNIO/Pool/ConnectionFactory.swift +++ b/Sources/PostgresNIO/Pool/ConnectionFactory.swift @@ -89,6 +89,7 @@ final class ConnectionFactory: Sendable { connectionConfig.options.connectTimeout = TimeAmount(config.options.connectTimeout) connectionConfig.options.tlsServerName = config.options.tlsServerName connectionConfig.options.requireBackendKeyData = config.options.requireBackendKeyData + connectionConfig.options.additionalStartupParameters = config.options.additionalStartupParameters return connectionConfig } diff --git a/Sources/PostgresNIO/Pool/PostgresClient.swift b/Sources/PostgresNIO/Pool/PostgresClient.swift index 0907f1f8..ad8a4bf1 100644 --- a/Sources/PostgresNIO/Pool/PostgresClient.swift +++ b/Sources/PostgresNIO/Pool/PostgresClient.swift @@ -106,6 +106,10 @@ public final class PostgresClient: Sendable, ServiceLifecycle.Service { /// If you are not using Amazon RDS Proxy, you should leave this set to `true` (the default). public var requireBackendKeyData: Bool = true + /// Additional parameters to send to the server on startup. The name value pairs are added to the initial + /// startup message that the client sends to the server. + public var additionalStartupParameters: [(String, String)] = [] + /// The minimum number of connections that the client shall keep open at any time, even if there is no /// demand. Default to `0`. /// diff --git a/Tests/IntegrationTests/PostgresClientTests.swift b/Tests/IntegrationTests/PostgresClientTests.swift index d6d89dc3..579c92cd 100644 --- a/Tests/IntegrationTests/PostgresClientTests.swift +++ b/Tests/IntegrationTests/PostgresClientTests.swift @@ -43,6 +43,41 @@ final class PostgresClientTests: XCTestCase { } } + func testApplicationNameIsForwardedCorrectly() async throws { + var mlogger = Logger(label: "test") + mlogger.logLevel = .debug + let logger = mlogger + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 8) + self.addTeardownBlock { + try await eventLoopGroup.shutdownGracefully() + } + + var clientConfig = PostgresClient.Configuration.makeTestConfiguration() + let applicationName = "postgres_nio_test_run" + clientConfig.options.additionalStartupParameters = [("application_name", applicationName)] + let client = PostgresClient(configuration: clientConfig, eventLoopGroup: eventLoopGroup, backgroundLogger: logger) + + try await withThrowingTaskGroup(of: Void.self) { taskGroup in + taskGroup.addTask { + await client.run() + } + + let rows = try await client.query("select * from pg_stat_activity;"); + var applicationNameFound = 0 + for try await row in rows { + let randomAccessRow = row.makeRandomAccess() + if try randomAccessRow["application_name"].decode(String?.self) == applicationName { + applicationNameFound += 1 + } + } + + XCTAssertGreaterThanOrEqual(applicationNameFound, 1) + + taskGroup.cancelAll() + } + } + + func testQueryDirectly() async throws { var mlogger = Logger(label: "test") mlogger.logLevel = .debug From 96ed89ff0dc457a2533bed80d4cf2a87976bc296 Mon Sep 17 00:00:00 2001 From: Thomas Krajacic Date: Sun, 8 Dec 2024 23:04:18 +0100 Subject: [PATCH 096/106] Correctly place the SSL channel handler in front of the PostgresChannelHandler (#527) --- Sources/PostgresNIO/Connection/PostgresConnection.swift | 6 +++--- Sources/PostgresNIO/New/PostgresChannelHandler.swift | 8 ++++---- .../New/PostgresChannelHandlerTests.swift | 6 +++--- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/Sources/PostgresNIO/Connection/PostgresConnection.swift b/Sources/PostgresNIO/Connection/PostgresConnection.swift index eb9dc791..229cd647 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection.swift @@ -60,18 +60,18 @@ public final class PostgresConnection: @unchecked Sendable { func start(configuration: InternalConfiguration) -> EventLoopFuture { // 1. configure handlers - let configureSSLCallback: ((Channel) throws -> ())? + let configureSSLCallback: ((Channel, PostgresChannelHandler) throws -> ())? switch configuration.tls.base { case .prefer(let context), .require(let context): - configureSSLCallback = { channel in + configureSSLCallback = { channel, postgresChannelHandler in channel.eventLoop.assertInEventLoop() let sslHandler = try NIOSSLClientHandler( context: context, serverHostname: configuration.serverNameForTLS ) - try channel.pipeline.syncOperations.addHandler(sslHandler, position: .first) + try channel.pipeline.syncOperations.addHandler(sslHandler, position: .before(postgresChannelHandler)) } case .disable: configureSSLCallback = nil diff --git a/Sources/PostgresNIO/New/PostgresChannelHandler.swift b/Sources/PostgresNIO/New/PostgresChannelHandler.swift index ee2af0fe..0a14849a 100644 --- a/Sources/PostgresNIO/New/PostgresChannelHandler.swift +++ b/Sources/PostgresNIO/New/PostgresChannelHandler.swift @@ -20,7 +20,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler { private var decoder: NIOSingleStepByteToMessageProcessor private var encoder: PostgresFrontendMessageEncoder! private let configuration: PostgresConnection.InternalConfiguration - private let configureSSLCallback: ((Channel) throws -> Void)? + private let configureSSLCallback: ((Channel, PostgresChannelHandler) throws -> Void)? private var listenState = ListenStateMachine() private var preparedStatementState = PreparedStatementStateMachine() @@ -29,7 +29,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler { configuration: PostgresConnection.InternalConfiguration, eventLoop: EventLoop, logger: Logger, - configureSSLCallback: ((Channel) throws -> Void)? + configureSSLCallback: ((Channel, PostgresChannelHandler) throws -> Void)? ) { self.state = ConnectionStateMachine(requireBackendKeyData: configuration.options.requireBackendKeyData) self.eventLoop = eventLoop @@ -46,7 +46,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler { eventLoop: EventLoop, state: ConnectionStateMachine = .init(.initialized), logger: Logger = .psqlNoOpLogger, - configureSSLCallback: ((Channel) throws -> Void)? + configureSSLCallback: ((Channel, PostgresChannelHandler) throws -> Void)? ) { self.state = state self.eventLoop = eventLoop @@ -439,7 +439,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler { // This method must only be called, if we signalized the StateMachine before that we are // able to setup a SSL connection. do { - try self.configureSSLCallback!(context.channel) + try self.configureSSLCallback!(context.channel, self) let action = self.state.sslHandlerAdded() self.run(action, with: context) } catch { diff --git a/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift b/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift index dfdcc53e..a2c90969 100644 --- a/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift @@ -48,7 +48,7 @@ class PostgresChannelHandlerTests: XCTestCase { var config = self.testConnectionConfiguration() XCTAssertNoThrow(config.tls = .require(try NIOSSLContext(configuration: .makeClientConfiguration()))) var addSSLCallbackIsHit = false - let handler = PostgresChannelHandler(configuration: config, eventLoop: self.eventLoop) { channel in + let handler = PostgresChannelHandler(configuration: config, eventLoop: self.eventLoop) { channel, _ in addSSLCallbackIsHit = true } let embedded = EmbeddedChannel(handlers: [ @@ -84,7 +84,7 @@ class PostgresChannelHandlerTests: XCTestCase { var config = self.testConnectionConfiguration() XCTAssertNoThrow(config.tls = .require(try NIOSSLContext(configuration: .makeClientConfiguration()))) var addSSLCallbackIsHit = false - let handler = PostgresChannelHandler(configuration: config, eventLoop: self.eventLoop) { channel in + let handler = PostgresChannelHandler(configuration: config, eventLoop: self.eventLoop) { channel, _ in addSSLCallbackIsHit = true } let eventHandler = TestEventHandler() @@ -114,7 +114,7 @@ class PostgresChannelHandlerTests: XCTestCase { func testSSLUnsupportedClosesConnection() throws { let config = self.testConnectionConfiguration(tls: .require(try NIOSSLContext(configuration: .makeClientConfiguration()))) - let handler = PostgresChannelHandler(configuration: config, eventLoop: self.eventLoop) { channel in + let handler = PostgresChannelHandler(configuration: config, eventLoop: self.eventLoop) { channel, _ in XCTFail("This callback should never be exectuded") throw PSQLError.sslUnsupported } From fd0e415a705c490499f983639b04f491a2ed9d99 Mon Sep 17 00:00:00 2001 From: Thomas Krajacic Date: Tue, 10 Dec 2024 10:11:53 +0100 Subject: [PATCH 097/106] Allow TLS enabled connections when providing an established channel (#526) Co-authored-by: Fabian Fett --- .../PostgresConnection+Configuration.swift | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/Sources/PostgresNIO/Connection/PostgresConnection+Configuration.swift b/Sources/PostgresNIO/Connection/PostgresConnection+Configuration.swift index dd0f5404..b260723a 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection+Configuration.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection+Configuration.swift @@ -192,9 +192,22 @@ extension PostgresConnection { /// - Parameters: /// - channel: The `NIOCore/Channel` to use. The channel must already be active and connected to an /// endpoint (i.e. `NIOCore/Channel/isActive` must be `true`). - /// - tls: The TLS mode to use. Defaults to ``TLS-swift.struct/disable``. + /// - tls: The TLS mode to use. + public init(establishedChannel channel: Channel, tls: PostgresConnection.Configuration.TLS, username: String, password: String?, database: String?) { + self.init(endpointInfo: .configureChannel(channel), tls: tls, username: username, password: password, database: database) + } + + /// Create a configuration for establishing a connection to a Postgres server over a preestablished + /// `NIOCore/Channel`. + /// + /// This is provided for calling code which wants to manage the underlying connection transport on its + /// own, such as when tunneling a connection through SSH. + /// + /// - Parameters: + /// - channel: The `NIOCore/Channel` to use. The channel must already be active and connected to an + /// endpoint (i.e. `NIOCore/Channel/isActive` must be `true`). public init(establishedChannel channel: Channel, username: String, password: String?, database: String?) { - self.init(endpointInfo: .configureChannel(channel), tls: .disable, username: username, password: password, database: database) + self.init(establishedChannel: channel, tls: .disable, username: username, password: password, database: database) } // MARK: - Implementation details From 045cc49fbe224093cc1d77e79065e9e00081d119 Mon Sep 17 00:00:00 2001 From: Gwynne Raskind Date: Sat, 21 Dec 2024 04:57:06 -0600 Subject: [PATCH 098/106] Update DocC settings to latest version of Vapor theme (#529) Update DocC settings to latest version of Vapor theme, for compatibility with Swift 6's DocC changes --- Sources/PostgresNIO/Docs.docc/theme-settings.json | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/Sources/PostgresNIO/Docs.docc/theme-settings.json b/Sources/PostgresNIO/Docs.docc/theme-settings.json index dda76197..911cc1bc 100644 --- a/Sources/PostgresNIO/Docs.docc/theme-settings.json +++ b/Sources/PostgresNIO/Docs.docc/theme-settings.json @@ -1,16 +1,19 @@ { "theme": { - "aside": { "border-radius": "6px", "border-style": "double", "border-width": "3px" }, + "aside": { "border-radius": "16px", "border-style": "double", "border-width": "3px" }, "border-radius": "0", "button": { "border-radius": "16px", "border-width": "1px", "border-style": "solid" }, "code": { "border-radius": "16px", "border-width": "1px", "border-style": "solid" }, "color": { + "fill": { "dark": "#000", "light": "#fff" } "psqlnio": "#336791", "documentation-intro-fill": "radial-gradient(circle at top, var(--color-psqlnio) 30%, #000 100%)", "documentation-intro-accent": "var(--color-psqlnio)", + "documentation-intro-eyebrow": "white", + "documentation-intro-figure": "white", + "documentation-intro-title": "white", "logo-base": { "dark": "#fff", "light": "#000" }, "logo-shape": { "dark": "#000", "light": "#fff" }, - "fill": { "dark": "#000", "light": "#fff" } }, "icons": { "technology": "/postgresnio/images/vapor-postgresnio-logo.svg" } }, From 7c29718fe5631462417ed3350ccc1e131678bf13 Mon Sep 17 00:00:00 2001 From: Gwynne Raskind Date: Sat, 21 Dec 2024 05:06:54 -0600 Subject: [PATCH 099/106] Fix malformed JSON in theme settings (#530) Fix malformed JSON in theme settings due to comma misplacement --- Sources/PostgresNIO/Docs.docc/theme-settings.json | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Sources/PostgresNIO/Docs.docc/theme-settings.json b/Sources/PostgresNIO/Docs.docc/theme-settings.json index 911cc1bc..38914a04 100644 --- a/Sources/PostgresNIO/Docs.docc/theme-settings.json +++ b/Sources/PostgresNIO/Docs.docc/theme-settings.json @@ -5,7 +5,7 @@ "button": { "border-radius": "16px", "border-width": "1px", "border-style": "solid" }, "code": { "border-radius": "16px", "border-width": "1px", "border-style": "solid" }, "color": { - "fill": { "dark": "#000", "light": "#fff" } + "fill": { "dark": "#000", "light": "#fff" }, "psqlnio": "#336791", "documentation-intro-fill": "radial-gradient(circle at top, var(--color-psqlnio) 30%, #000 100%)", "documentation-intro-accent": "var(--color-psqlnio)", @@ -13,7 +13,7 @@ "documentation-intro-figure": "white", "documentation-intro-title": "white", "logo-base": { "dark": "#fff", "light": "#000" }, - "logo-shape": { "dark": "#000", "light": "#fff" }, + "logo-shape": { "dark": "#000", "light": "#fff" } }, "icons": { "technology": "/postgresnio/images/vapor-postgresnio-logo.svg" } }, From d6b6487c967a04000db58e622e78cff91fd5bc26 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Mon, 27 Jan 2025 18:29:56 +0100 Subject: [PATCH 100/106] Fix sendable warnings (#533) --- Sources/ConnectionPoolModule/NIOLock.swift | 60 +++++++++++-------- .../NIOLockedValueBox.swift | 46 +++++++++++++- Sources/PostgresNIO/New/PSQLTask.swift | 13 ++-- .../PostgresNIO/PostgresDatabase+Query.swift | 2 +- 4 files changed, 86 insertions(+), 35 deletions(-) diff --git a/Sources/ConnectionPoolModule/NIOLock.swift b/Sources/ConnectionPoolModule/NIOLock.swift index 13a9df4a..b6cd7164 100644 --- a/Sources/ConnectionPoolModule/NIOLock.swift +++ b/Sources/ConnectionPoolModule/NIOLock.swift @@ -24,6 +24,13 @@ import WinSDK import Glibc #elseif canImport(Musl) import Musl +#elseif canImport(Bionic) +import Bionic +#elseif canImport(WASILibc) +import WASILibc +#if canImport(wasi_pthread) +import wasi_pthread +#endif #else #error("The concurrency NIOLock module was unable to identify your C library.") #endif @@ -37,16 +44,16 @@ typealias LockPrimitive = pthread_mutex_t #endif @usableFromInline -enum LockOperations { } +enum LockOperations {} extension LockOperations { @inlinable static func create(_ mutex: UnsafeMutablePointer) { mutex.assertValidAlignment() -#if os(Windows) + #if os(Windows) InitializeSRWLock(mutex) -#else + #elseif (compiler(<6.1) && !os(WASI)) || (compiler(>=6.1) && _runtime(_multithreaded)) var attr = pthread_mutexattr_t() pthread_mutexattr_init(&attr) debugOnly { @@ -55,43 +62,43 @@ extension LockOperations { let err = pthread_mutex_init(mutex, &attr) precondition(err == 0, "\(#function) failed in pthread_mutex with error \(err)") -#endif + #endif } @inlinable static func destroy(_ mutex: UnsafeMutablePointer) { mutex.assertValidAlignment() -#if os(Windows) + #if os(Windows) // SRWLOCK does not need to be free'd -#else + #elseif (compiler(<6.1) && !os(WASI)) || (compiler(>=6.1) && _runtime(_multithreaded)) let err = pthread_mutex_destroy(mutex) precondition(err == 0, "\(#function) failed in pthread_mutex with error \(err)") -#endif + #endif } @inlinable static func lock(_ mutex: UnsafeMutablePointer) { mutex.assertValidAlignment() -#if os(Windows) + #if os(Windows) AcquireSRWLockExclusive(mutex) -#else + #elseif (compiler(<6.1) && !os(WASI)) || (compiler(>=6.1) && _runtime(_multithreaded)) let err = pthread_mutex_lock(mutex) precondition(err == 0, "\(#function) failed in pthread_mutex with error \(err)") -#endif + #endif } @inlinable static func unlock(_ mutex: UnsafeMutablePointer) { mutex.assertValidAlignment() -#if os(Windows) + #if os(Windows) ReleaseSRWLockExclusive(mutex) -#else + #elseif (compiler(<6.1) && !os(WASI)) || (compiler(>=6.1) && _runtime(_multithreaded)) let err = pthread_mutex_unlock(mutex) precondition(err == 0, "\(#function) failed in pthread_mutex with error \(err)") -#endif + #endif } } @@ -129,9 +136,11 @@ final class LockStorage: ManagedBuffer { @inlinable static func create(value: Value) -> Self { let buffer = Self.create(minimumCapacity: 1) { _ in - return value + value } - // Avoid 'unsafeDowncast' as there is a miscompilation on 5.10. + // Intentionally using a force cast here to avoid a miss compiliation in 5.10. + // This is as fast as an unsafeDownCast since ManagedBuffer is inlined and the optimizer + // can eliminate the upcast/downcast pair let storage = buffer as! Self storage.withUnsafeMutablePointers { _, lockPtr in @@ -165,7 +174,7 @@ final class LockStorage: ManagedBuffer { @inlinable func withLockPrimitive(_ body: (UnsafeMutablePointer) throws -> T) rethrows -> T { try self.withUnsafeMutablePointerToElements { lockPtr in - return try body(lockPtr) + try body(lockPtr) } } @@ -179,17 +188,14 @@ final class LockStorage: ManagedBuffer { } } -extension LockStorage: @unchecked Sendable { } - /// A threading lock based on `libpthread` instead of `libdispatch`. /// -/// - note: ``NIOLock`` has reference semantics. +/// - Note: ``NIOLock`` has reference semantics. /// /// This object provides a lock on top of a single `pthread_mutex_t`. This kind /// of lock is safe to use with `libpthread`-based threading models, such as the /// one used by NIO. On Windows, the lock is based on the substantially similar /// `SRWLOCK` type. -@usableFromInline struct NIOLock { @usableFromInline internal let _storage: LockStorage @@ -220,7 +226,7 @@ struct NIOLock { @inlinable internal func withLockPrimitive(_ body: (UnsafeMutablePointer) throws -> T) rethrows -> T { - return try self._storage.withLockPrimitive(body) + try self._storage.withLockPrimitive(body) } } @@ -243,12 +249,12 @@ extension NIOLock { } @inlinable - func withLockVoid(_ body: () throws -> Void) rethrows -> Void { + func withLockVoid(_ body: () throws -> Void) rethrows { try self.withLock(body) } } -extension NIOLock: Sendable {} +extension NIOLock: @unchecked Sendable {} extension UnsafeMutablePointer { @inlinable @@ -264,6 +270,10 @@ extension UnsafeMutablePointer { /// https://forums.swift.org/t/support-debug-only-code/11037 for a discussion. @inlinable internal func debugOnly(_ body: () -> Void) { - // FIXME: duplicated with NIO. - assert({ body(); return true }()) + assert( + { + body() + return true + }() + ) } diff --git a/Sources/ConnectionPoolModule/NIOLockedValueBox.swift b/Sources/ConnectionPoolModule/NIOLockedValueBox.swift index e5a3e6a2..c9cd89e0 100644 --- a/Sources/ConnectionPoolModule/NIOLockedValueBox.swift +++ b/Sources/ConnectionPoolModule/NIOLockedValueBox.swift @@ -17,7 +17,7 @@ /// Provides locked access to `Value`. /// -/// - note: ``NIOLockedValueBox`` has reference semantics and holds the `Value` +/// - Note: ``NIOLockedValueBox`` has reference semantics and holds the `Value` /// alongside a lock behind a reference. /// /// This is no different than creating a ``Lock`` and protecting all @@ -39,8 +39,48 @@ struct NIOLockedValueBox { /// Access the `Value`, allowing mutation of it. @inlinable func withLockedValue(_ mutate: (inout Value) throws -> T) rethrows -> T { - return try self._storage.withLockedValue(mutate) + try self._storage.withLockedValue(mutate) + } + + /// Provides an unsafe view over the lock and its value. + /// + /// This can be beneficial when you require fine grained control over the lock in some + /// situations but don't want lose the benefits of ``withLockedValue(_:)`` in others by + /// switching to ``NIOLock``. + var unsafe: Unsafe { + Unsafe(_storage: self._storage) + } + + /// Provides an unsafe view over the lock and its value. + struct Unsafe { + @usableFromInline + let _storage: LockStorage + + /// Manually acquire the lock. + @inlinable + func lock() { + self._storage.lock() + } + + /// Manually release the lock. + @inlinable + func unlock() { + self._storage.unlock() + } + + /// Mutate the value, assuming the lock has been acquired manually. + /// + /// - Parameter mutate: A closure with scoped access to the value. + /// - Returns: The result of the `mutate` closure. + @inlinable + func withValueAssumingLockIsAcquired( + _ mutate: (_ value: inout Value) throws -> Result + ) rethrows -> Result { + try self._storage.withUnsafeMutablePointerToHeader { value in + try mutate(&value.pointee) + } + } } } -extension NIOLockedValueBox: Sendable where Value: Sendable {} +extension NIOLockedValueBox: @unchecked Sendable where Value: Sendable {} diff --git a/Sources/PostgresNIO/New/PSQLTask.swift b/Sources/PostgresNIO/New/PSQLTask.swift index 363f9394..6106fd21 100644 --- a/Sources/PostgresNIO/New/PSQLTask.swift +++ b/Sources/PostgresNIO/New/PSQLTask.swift @@ -1,7 +1,7 @@ import Logging import NIOCore -enum HandlerTask { +enum HandlerTask: Sendable { case extendedQuery(ExtendedQueryContext) case closeCommand(CloseCommandContext) case startListening(NotificationListener) @@ -31,7 +31,7 @@ enum PSQLTask { } } -final class ExtendedQueryContext { +final class ExtendedQueryContext: Sendable { enum Query { case unnamed(PostgresQuery, EventLoopPromise) case executeStatement(PSQLExecuteStatement, EventLoopPromise) @@ -100,14 +100,15 @@ final class PreparedStatementContext: Sendable { } } -final class CloseCommandContext { +final class CloseCommandContext: Sendable { let target: CloseTarget let logger: Logger let promise: EventLoopPromise - init(target: CloseTarget, - logger: Logger, - promise: EventLoopPromise + init( + target: CloseTarget, + logger: Logger, + promise: EventLoopPromise ) { self.target = target self.logger = logger diff --git a/Sources/PostgresNIO/PostgresDatabase+Query.swift b/Sources/PostgresNIO/PostgresDatabase+Query.swift index 483d5a7b..8de93814 100644 --- a/Sources/PostgresNIO/PostgresDatabase+Query.swift +++ b/Sources/PostgresNIO/PostgresDatabase+Query.swift @@ -40,7 +40,7 @@ extension PostgresDatabase { } } -public struct PostgresQueryResult { +public struct PostgresQueryResult: Sendable { public let metadata: PostgresQueryMetadata public let rows: [PostgresRow] } From 8d07f2049531a60c08b8dda7011a3ad8ac3c989b Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Mon, 10 Feb 2025 16:00:18 +0100 Subject: [PATCH 101/106] Fix Sendable warnings (#536) --- Package.swift | 2 +- Tests/IntegrationTests/PostgresNIOTests.swift | 22 -------- .../New/PostgresChannelHandlerTests.swift | 2 +- .../New/PostgresConnectionTests.swift | 16 +++--- .../New/PostgresRowSequenceTests.swift | 51 +++++++++++-------- 5 files changed, 41 insertions(+), 52 deletions(-) diff --git a/Package.swift b/Package.swift index 5f6562f6..3dd21c3c 100644 --- a/Package.swift +++ b/Package.swift @@ -20,7 +20,7 @@ let package = Package( dependencies: [ .package(url: "https://github.com/apple/swift-atomics.git", from: "1.2.0"), .package(url: "https://github.com/apple/swift-collections.git", from: "1.0.4"), - .package(url: "https://github.com/apple/swift-nio.git", from: "2.59.0"), + .package(url: "https://github.com/apple/swift-nio.git", from: "2.81.0"), .package(url: "https://github.com/apple/swift-nio-transport-services.git", from: "1.19.0"), .package(url: "https://github.com/apple/swift-nio-ssl.git", from: "2.25.0"), .package(url: "https://github.com/apple/swift-crypto.git", "2.0.0" ..< "4.0.0"), diff --git a/Tests/IntegrationTests/PostgresNIOTests.swift b/Tests/IntegrationTests/PostgresNIOTests.swift index ff59209b..9a58f050 100644 --- a/Tests/IntegrationTests/PostgresNIOTests.swift +++ b/Tests/IntegrationTests/PostgresNIOTests.swift @@ -1032,28 +1032,6 @@ final class PostgresNIOTests: XCTestCase { } } - func testRemoteTLSServer() { - // postgres://uymgphwj:7_tHbREdRwkqAdu4KoIS7hQnNxr8J1LA@elmer.db.elephantsql.com:5432/uymgphwj - var conn: PostgresConnection? - let logger = Logger(label: "test") - let sslContext = try! NIOSSLContext(configuration: .makeClientConfiguration()) - let config = PostgresConnection.Configuration( - host: "elmer.db.elephantsql.com", - port: 5432, - username: "uymgphwj", - password: "7_tHbREdRwkqAdu4KoIS7hQnNxr8J1LA", - database: "uymgphwj", - tls: .require(sslContext) - ) - XCTAssertNoThrow(conn = try PostgresConnection.connect(on: eventLoop, configuration: config, id: 0, logger: logger).wait()) - defer { XCTAssertNoThrow( try conn?.close().wait() ) } - var rows: [PostgresRow]? - XCTAssertNoThrow(rows = try conn?.simpleQuery("SELECT version()").wait()) - XCTAssertEqual(rows?.count, 1) - let row = rows?.first?.makeRandomAccess() - XCTAssertEqual(row?[data: "version"].string?.contains("PostgreSQL"), true) - } - @available(*, deprecated, message: "Test deprecated functionality") func testFailingTLSConnectionClosesConnection() { // There was a bug (https://github.com/vapor/postgres-nio/issues/133) where we would hit diff --git a/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift b/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift index a2c90969..206f38a3 100644 --- a/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift @@ -124,7 +124,7 @@ class PostgresChannelHandlerTests: XCTestCase { handler ], loop: self.eventLoop) let eventHandler = TestEventHandler() - try embedded.pipeline.addHandler(eventHandler, position: .last).wait() + try embedded.pipeline.syncOperations.addHandler(eventHandler, position: .last) embedded.connect(to: try .init(ipAddress: "0.0.0.0", port: 5432), promise: nil) XCTAssertTrue(embedded.isActive) diff --git a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift index 0bc61efd..d0f8e2b0 100644 --- a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift @@ -40,10 +40,10 @@ class PostgresConnectionTests: XCTestCase { func testOptionsAreSentOnTheWire() async throws { let eventLoop = NIOAsyncTestingEventLoop() - let channel = await NIOAsyncTestingChannel(handlers: [ - ReverseByteToMessageHandler(PSQLFrontendMessageDecoder()), - ReverseMessageToByteHandler(PSQLBackendMessageEncoder()), - ], loop: eventLoop) + let channel = try await NIOAsyncTestingChannel(loop: eventLoop) { channel in + try channel.pipeline.syncOperations.addHandlers(ReverseByteToMessageHandler(PSQLFrontendMessageDecoder())) + try channel.pipeline.syncOperations.addHandlers(ReverseMessageToByteHandler(PSQLBackendMessageEncoder())) + } try await channel.connect(to: .makeAddressResolvingHost("localhost", port: 5432)) let configuration = { @@ -640,10 +640,10 @@ class PostgresConnectionTests: XCTestCase { func makeTestConnectionWithAsyncTestingChannel() async throws -> (PostgresConnection, NIOAsyncTestingChannel) { let eventLoop = NIOAsyncTestingEventLoop() - let channel = await NIOAsyncTestingChannel(handlers: [ - ReverseByteToMessageHandler(PSQLFrontendMessageDecoder()), - ReverseMessageToByteHandler(PSQLBackendMessageEncoder()), - ], loop: eventLoop) + let channel = try await NIOAsyncTestingChannel(loop: eventLoop) { channel in + try channel.pipeline.syncOperations.addHandlers(ReverseByteToMessageHandler(PSQLFrontendMessageDecoder())) + try channel.pipeline.syncOperations.addHandlers(ReverseMessageToByteHandler(PSQLBackendMessageEncoder())) + } try await channel.connect(to: .makeAddressResolvingHost("localhost", port: 5432)) let configuration = PostgresConnection.Configuration( diff --git a/Tests/PostgresNIOTests/New/PostgresRowSequenceTests.swift b/Tests/PostgresNIOTests/New/PostgresRowSequenceTests.swift index 816daf04..9d662252 100644 --- a/Tests/PostgresNIOTests/New/PostgresRowSequenceTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresRowSequenceTests.swift @@ -1,6 +1,6 @@ import Atomics import NIOEmbedded -import Dispatch +import NIOPosix import XCTest @testable import PostgresNIO import NIOCore @@ -8,10 +8,10 @@ import Logging final class PostgresRowSequenceTests: XCTestCase { let logger = Logger(label: "PSQLRowStreamTests") - let eventLoop = EmbeddedEventLoop() func testBackpressureWorks() async throws { let dataSource = MockRowDataSource() + let embeddedEventLoop = EmbeddedEventLoop() let stream = PSQLRowStream( source: .stream( [ @@ -19,7 +19,7 @@ final class PostgresRowSequenceTests: XCTestCase { ], dataSource ), - eventLoop: self.eventLoop, + eventLoop: embeddedEventLoop, logger: self.logger ) @@ -41,6 +41,7 @@ final class PostgresRowSequenceTests: XCTestCase { func testCancellationWorksWhileIterating() async throws { let dataSource = MockRowDataSource() + let embeddedEventLoop = EmbeddedEventLoop() let stream = PSQLRowStream( source: .stream( [ @@ -48,7 +49,7 @@ final class PostgresRowSequenceTests: XCTestCase { ], dataSource ), - eventLoop: self.eventLoop, + eventLoop: embeddedEventLoop, logger: self.logger ) @@ -72,6 +73,7 @@ final class PostgresRowSequenceTests: XCTestCase { func testCancellationWorksBeforeIterating() async throws { let dataSource = MockRowDataSource() + let embeddedEventLoop = EmbeddedEventLoop() let stream = PSQLRowStream( source: .stream( [ @@ -79,7 +81,7 @@ final class PostgresRowSequenceTests: XCTestCase { ], dataSource ), - eventLoop: self.eventLoop, + eventLoop: embeddedEventLoop, logger: self.logger ) @@ -97,6 +99,7 @@ final class PostgresRowSequenceTests: XCTestCase { func testDroppingTheSequenceCancelsTheSource() async throws { let dataSource = MockRowDataSource() + let embeddedEventLoop = EmbeddedEventLoop() let stream = PSQLRowStream( source: .stream( [ @@ -104,7 +107,7 @@ final class PostgresRowSequenceTests: XCTestCase { ], dataSource ), - eventLoop: self.eventLoop, + eventLoop: embeddedEventLoop, logger: self.logger ) @@ -117,6 +120,7 @@ final class PostgresRowSequenceTests: XCTestCase { func testStreamBasedOnCompletedQuery() async throws { let dataSource = MockRowDataSource() + let embeddedEventLoop = EmbeddedEventLoop() let stream = PSQLRowStream( source: .stream( [ @@ -124,7 +128,7 @@ final class PostgresRowSequenceTests: XCTestCase { ], dataSource ), - eventLoop: self.eventLoop, + eventLoop: embeddedEventLoop, logger: self.logger ) @@ -144,6 +148,7 @@ final class PostgresRowSequenceTests: XCTestCase { func testStreamIfInitializedWithAllData() async throws { let dataSource = MockRowDataSource() + let embeddedEventLoop = EmbeddedEventLoop() let stream = PSQLRowStream( source: .stream( [ @@ -151,7 +156,7 @@ final class PostgresRowSequenceTests: XCTestCase { ], dataSource ), - eventLoop: self.eventLoop, + eventLoop: embeddedEventLoop, logger: self.logger ) @@ -172,6 +177,7 @@ final class PostgresRowSequenceTests: XCTestCase { func testStreamIfInitializedWithError() async throws { let dataSource = MockRowDataSource() + let embeddedEventLoop = EmbeddedEventLoop() let stream = PSQLRowStream( source: .stream( [ @@ -179,7 +185,7 @@ final class PostgresRowSequenceTests: XCTestCase { ], dataSource ), - eventLoop: self.eventLoop, + eventLoop: embeddedEventLoop, logger: self.logger ) @@ -200,6 +206,7 @@ final class PostgresRowSequenceTests: XCTestCase { func testSucceedingRowContinuationsWorks() async throws { let dataSource = MockRowDataSource() + let eventLoop = NIOSingletons.posixEventLoopGroup.next() let stream = PSQLRowStream( source: .stream( [ @@ -207,14 +214,14 @@ final class PostgresRowSequenceTests: XCTestCase { ], dataSource ), - eventLoop: self.eventLoop, + eventLoop: eventLoop, logger: self.logger ) - let rowSequence = stream.asyncSequence() + let rowSequence = try await eventLoop.submit { stream.asyncSequence() }.get() var rowIterator = rowSequence.makeAsyncIterator() - DispatchQueue.main.asyncAfter(deadline: .now() + .seconds(1)) { + eventLoop.scheduleTask(in: .seconds(1)) { let dataRows: [DataRow] = (0..<1).map { [ByteBuffer(integer: Int64($0))] } stream.receive(dataRows) } @@ -222,7 +229,7 @@ final class PostgresRowSequenceTests: XCTestCase { let row1 = try await rowIterator.next() XCTAssertEqual(try row1?.decode(Int.self), 0) - DispatchQueue.main.asyncAfter(deadline: .now() + .seconds(1)) { + eventLoop.scheduleTask(in: .seconds(1)) { stream.receive(completion: .success("SELECT 1")) } @@ -232,6 +239,7 @@ final class PostgresRowSequenceTests: XCTestCase { func testFailingRowContinuationsWorks() async throws { let dataSource = MockRowDataSource() + let eventLoop = NIOSingletons.posixEventLoopGroup.next() let stream = PSQLRowStream( source: .stream( [ @@ -239,14 +247,14 @@ final class PostgresRowSequenceTests: XCTestCase { ], dataSource ), - eventLoop: self.eventLoop, + eventLoop: eventLoop, logger: self.logger ) - let rowSequence = stream.asyncSequence() + let rowSequence = try await eventLoop.submit { stream.asyncSequence() }.get() var rowIterator = rowSequence.makeAsyncIterator() - DispatchQueue.main.asyncAfter(deadline: .now() + .seconds(1)) { + eventLoop.scheduleTask(in: .seconds(1)) { let dataRows: [DataRow] = (0..<1).map { [ByteBuffer(integer: Int64($0))] } stream.receive(dataRows) } @@ -254,7 +262,7 @@ final class PostgresRowSequenceTests: XCTestCase { let row1 = try await rowIterator.next() XCTAssertEqual(try row1?.decode(Int.self), 0) - DispatchQueue.main.asyncAfter(deadline: .now() + .seconds(1)) { + eventLoop.scheduleTask(in: .seconds(1)) { stream.receive(completion: .failure(PSQLError.serverClosedConnection(underlying: nil))) } @@ -268,6 +276,7 @@ final class PostgresRowSequenceTests: XCTestCase { func testAdaptiveRowBufferShrinksAndGrows() async throws { let dataSource = MockRowDataSource() + let embeddedEventLoop = EmbeddedEventLoop() let stream = PSQLRowStream( source: .stream( [ @@ -275,7 +284,7 @@ final class PostgresRowSequenceTests: XCTestCase { ], dataSource ), - eventLoop: self.eventLoop, + eventLoop: embeddedEventLoop, logger: self.logger ) @@ -332,6 +341,7 @@ final class PostgresRowSequenceTests: XCTestCase { func testAdaptiveRowShrinksToMin() async throws { let dataSource = MockRowDataSource() + let embeddedEventLoop = EmbeddedEventLoop() let stream = PSQLRowStream( source: .stream( [ @@ -339,7 +349,7 @@ final class PostgresRowSequenceTests: XCTestCase { ], dataSource ), - eventLoop: self.eventLoop, + eventLoop: embeddedEventLoop, logger: self.logger ) @@ -386,6 +396,7 @@ final class PostgresRowSequenceTests: XCTestCase { func testStreamBufferAcceptsNewRowsEventhoughItDidntAskForIt() async throws { let dataSource = MockRowDataSource() + let embeddedEventLoop = EmbeddedEventLoop() let stream = PSQLRowStream( source: .stream( [ @@ -393,7 +404,7 @@ final class PostgresRowSequenceTests: XCTestCase { ], dataSource ), - eventLoop: self.eventLoop, + eventLoop: embeddedEventLoop, logger: self.logger ) From 712740b1f528210a3ce05618336f5c7dd2470bb9 Mon Sep 17 00:00:00 2001 From: Stevenson Michel <130018170+thoven87@users.noreply.github.com> Date: Tue, 11 Feb 2025 05:29:25 -0500 Subject: [PATCH 102/106] Add `withTransaction` API (#519) Co-authored-by: Fabian Fett --- Sources/PostgresNIO/Pool/PostgresClient.swift | 22 ++++ .../PostgresClientTests.swift | 104 ++++++++++++++++++ 2 files changed, 126 insertions(+) diff --git a/Sources/PostgresNIO/Pool/PostgresClient.swift b/Sources/PostgresNIO/Pool/PostgresClient.swift index ad8a4bf1..e9e947ef 100644 --- a/Sources/PostgresNIO/Pool/PostgresClient.swift +++ b/Sources/PostgresNIO/Pool/PostgresClient.swift @@ -307,6 +307,28 @@ public final class PostgresClient: Sendable, ServiceLifecycle.Service { return try await closure(connection) } + + /// Lease a connection for the provided `closure`'s lifetime. + /// A transation starts with call to withConnection + /// A transaction should end with a call to COMMIT or ROLLBACK + /// COMMIT is called upon successful completion and ROLLBACK is called should any steps fail + /// + /// - Parameter closure: A closure that uses the passed `PostgresConnection`. The closure **must not** capture + /// the provided `PostgresConnection`. + /// - Returns: The closure's return value. + public func withTransaction(_ process: (PostgresConnection) async throws -> Result) async throws -> Result { + try await withConnection { connection in + try await connection.query("BEGIN;", logger: self.backgroundLogger) + do { + let value = try await process(connection) + try await connection.query("COMMIT;", logger: self.backgroundLogger) + return value + } catch { + try await connection.query("ROLLBACK;", logger: self.backgroundLogger) + throw error + } + } + } /// Run a query on the Postgres server the client is connected to. /// diff --git a/Tests/IntegrationTests/PostgresClientTests.swift b/Tests/IntegrationTests/PostgresClientTests.swift index 579c92cd..167ba298 100644 --- a/Tests/IntegrationTests/PostgresClientTests.swift +++ b/Tests/IntegrationTests/PostgresClientTests.swift @@ -42,6 +42,110 @@ final class PostgresClientTests: XCTestCase { taskGroup.cancelAll() } } + + func testTransaction() async throws { + var mlogger = Logger(label: "test") + mlogger.logLevel = .debug + let logger = mlogger + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 8) + self.addTeardownBlock { + try await eventLoopGroup.shutdownGracefully() + } + + let tableName = "test_client_transactions" + + let clientConfig = PostgresClient.Configuration.makeTestConfiguration() + let client = PostgresClient(configuration: clientConfig, eventLoopGroup: eventLoopGroup, backgroundLogger: logger) + + do { + try await withThrowingTaskGroup(of: Void.self) { taskGroup in + taskGroup.addTask { + await client.run() + } + + try await client.query( + """ + CREATE TABLE IF NOT EXISTS "\(unescaped: tableName)" ( + id INT PRIMARY KEY GENERATED ALWAYS AS IDENTITY, + uuid UUID NOT NULL + ); + """, + logger: logger + ) + + let iterations = 1000 + + for _ in 0.. Date: Thu, 13 Feb 2025 12:32:47 +0100 Subject: [PATCH 103/106] Improve transaction handling (#538) --- .../Connection/PostgresConnection.swift | 104 ++++++++++++++++++ .../New/PostgresTransactionError.swift | 21 ++++ Sources/PostgresNIO/Pool/PostgresClient.swift | 84 +++++++++++--- .../PostgresClientTests.swift | 12 +- 4 files changed, 199 insertions(+), 22 deletions(-) create mode 100644 Sources/PostgresNIO/New/PostgresTransactionError.swift diff --git a/Sources/PostgresNIO/Connection/PostgresConnection.swift b/Sources/PostgresNIO/Connection/PostgresConnection.swift index 229cd647..e267d8f9 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection.swift @@ -530,6 +530,110 @@ extension PostgresConnection { throw error // rethrow with more metadata } } + + #if compiler(>=6.0) + /// Puts the connection into an open transaction state, for the provided `closure`'s lifetime. + /// + /// The function starts a transaction by running a `BEGIN` query on the connection against the database. It then + /// lends the connection to the user provided closure. The user can then modify the database as they wish. If the user + /// provided closure returns successfully, the function will attempt to commit the changes by running a `COMMIT` + /// query against the database. If the user provided closure throws an error, the function will attempt to rollback the + /// changes made within the closure. + /// + /// - Parameters: + /// - logger: The `Logger` to log into for the transaction. + /// - file: The file, the transaction was started in. Used for better error reporting. + /// - line: The line, the transaction was started in. Used for better error reporting. + /// - closure: The user provided code to modify the database. Use the provided connection to run queries. + /// The connection must stay in the transaction mode. Otherwise this method will throw! + /// - Returns: The closure's return value. + public func withTransaction( + logger: Logger, + file: String = #file, + line: Int = #line, + isolation: isolated (any Actor)? = #isolation, + // DO NOT FIX THE WHITESPACE IN THE NEXT LINE UNTIL 5.10 IS UNSUPPORTED + // https://github.com/swiftlang/swift/issues/79285 + _ process: (PostgresConnection) async throws -> sending Result) async throws -> sending Result { + do { + try await self.query("BEGIN;", logger: logger) + } catch { + throw PostgresTransactionError(file: file, line: line, beginError: error) + } + + var closureHasFinished: Bool = false + do { + let value = try await process(self) + closureHasFinished = true + try await self.query("COMMIT;", logger: logger) + return value + } catch { + var transactionError = PostgresTransactionError(file: file, line: line) + if !closureHasFinished { + transactionError.closureError = error + do { + try await self.query("ROLLBACK;", logger: logger) + } catch { + transactionError.rollbackError = error + } + } else { + transactionError.commitError = error + } + + throw transactionError + } + } + #else + /// Puts the connection into an open transaction state, for the provided `closure`'s lifetime. + /// + /// The function starts a transaction by running a `BEGIN` query on the connection against the database. It then + /// lends the connection to the user provided closure. The user can then modify the database as they wish. If the user + /// provided closure returns successfully, the function will attempt to commit the changes by running a `COMMIT` + /// query against the database. If the user provided closure throws an error, the function will attempt to rollback the + /// changes made within the closure. + /// + /// - Parameters: + /// - logger: The `Logger` to log into for the transaction. + /// - file: The file, the transaction was started in. Used for better error reporting. + /// - line: The line, the transaction was started in. Used for better error reporting. + /// - closure: The user provided code to modify the database. Use the provided connection to run queries. + /// The connection must stay in the transaction mode. Otherwise this method will throw! + /// - Returns: The closure's return value. + public func withTransaction( + logger: Logger, + file: String = #file, + line: Int = #line, + _ process: (PostgresConnection) async throws -> Result + ) async throws -> Result { + do { + try await self.query("BEGIN;", logger: logger) + } catch { + throw PostgresTransactionError(file: file, line: line, beginError: error) + } + + var closureHasFinished: Bool = false + do { + let value = try await process(self) + closureHasFinished = true + try await self.query("COMMIT;", logger: logger) + return value + } catch { + var transactionError = PostgresTransactionError(file: file, line: line) + if !closureHasFinished { + transactionError.closureError = error + do { + try await self.query("ROLLBACK;", logger: logger) + } catch { + transactionError.rollbackError = error + } + } else { + transactionError.commitError = error + } + + throw transactionError + } + } + #endif } // MARK: EventLoopFuture interface diff --git a/Sources/PostgresNIO/New/PostgresTransactionError.swift b/Sources/PostgresNIO/New/PostgresTransactionError.swift new file mode 100644 index 00000000..35038446 --- /dev/null +++ b/Sources/PostgresNIO/New/PostgresTransactionError.swift @@ -0,0 +1,21 @@ +/// A wrapper around the errors that can occur during a transaction. +public struct PostgresTransactionError: Error { + + /// The file in which the transaction was started + public var file: String + /// The line in which the transaction was started + public var line: Int + + /// The error thrown when running the `BEGIN` query + public var beginError: Error? + /// The error thrown in the transaction closure + public var closureError: Error? + + /// The error thrown while rolling the transaction back. If the ``closureError`` is set, + /// but the ``rollbackError`` is empty, the rollback was successful. If the ``rollbackError`` + /// is set, the rollback failed. + public var rollbackError: Error? + + /// The error thrown while commiting the transaction. + public var commitError: Error? +} diff --git a/Sources/PostgresNIO/Pool/PostgresClient.swift b/Sources/PostgresNIO/Pool/PostgresClient.swift index e9e947ef..d54e34eb 100644 --- a/Sources/PostgresNIO/Pool/PostgresClient.swift +++ b/Sources/PostgresNIO/Pool/PostgresClient.swift @@ -293,13 +293,13 @@ public final class PostgresClient: Sendable, ServiceLifecycle.Service { return ConnectionAndMetadata(connection: connection, maximalStreamsOnConnection: 1) } } - /// Lease a connection for the provided `closure`'s lifetime. /// /// - Parameter closure: A closure that uses the passed `PostgresConnection`. The closure **must not** capture /// the provided `PostgresConnection`. /// - Returns: The closure's return value. + @_disfavoredOverload public func withConnection(_ closure: (PostgresConnection) async throws -> Result) async throws -> Result { let connection = try await self.leaseConnection() @@ -307,28 +307,80 @@ public final class PostgresClient: Sendable, ServiceLifecycle.Service { return try await closure(connection) } - + + #if compiler(>=6.0) /// Lease a connection for the provided `closure`'s lifetime. - /// A transation starts with call to withConnection - /// A transaction should end with a call to COMMIT or ROLLBACK - /// COMMIT is called upon successful completion and ROLLBACK is called should any steps fail /// /// - Parameter closure: A closure that uses the passed `PostgresConnection`. The closure **must not** capture /// the provided `PostgresConnection`. /// - Returns: The closure's return value. - public func withTransaction(_ process: (PostgresConnection) async throws -> Result) async throws -> Result { - try await withConnection { connection in - try await connection.query("BEGIN;", logger: self.backgroundLogger) - do { - let value = try await process(connection) - try await connection.query("COMMIT;", logger: self.backgroundLogger) - return value - } catch { - try await connection.query("ROLLBACK;", logger: self.backgroundLogger) - throw error - } + public func withConnection( + isolation: isolated (any Actor)? = #isolation, + // DO NOT FIX THE WHITESPACE IN THE NEXT LINE UNTIL 5.10 IS UNSUPPORTED + // https://github.com/swiftlang/swift/issues/79285 + _ closure: (PostgresConnection) async throws -> sending Result) async throws -> sending Result { + let connection = try await self.leaseConnection() + + defer { self.pool.releaseConnection(connection) } + + return try await closure(connection) + } + + /// Lease a connection, which is in an open transaction state, for the provided `closure`'s lifetime. + /// + /// The function leases a connection from the underlying connection pool and starts a transaction by running a `BEGIN` + /// query on the leased connection against the database. It then lends the connection to the user provided closure. + /// The user can then modify the database as they wish. If the user provided closure returns successfully, the function + /// will attempt to commit the changes by running a `COMMIT` query against the database. If the user provided closure + /// throws an error, the function will attempt to rollback the changes made within the closure. + /// + /// - Parameters: + /// - logger: The `Logger` to log into for the transaction. + /// - file: The file, the transaction was started in. Used for better error reporting. + /// - line: The line, the transaction was started in. Used for better error reporting. + /// - closure: The user provided code to modify the database. Use the provided connection to run queries. + /// The connection must stay in the transaction mode. Otherwise this method will throw! + /// - Returns: The closure's return value. + public func withTransaction( + logger: Logger, + file: String = #file, + line: Int = #line, + isolation: isolated (any Actor)? = #isolation, + // DO NOT FIX THE WHITESPACE IN THE NEXT LINE UNTIL 5.10 IS UNSUPPORTED + // https://github.com/swiftlang/swift/issues/79285 + _ closure: (PostgresConnection) async throws -> sending Result) async throws -> sending Result { + try await self.withConnection { connection in + try await connection.withTransaction(logger: logger, file: file, line: line, closure) + } + } + #else + + /// Lease a connection, which is in an open transaction state, for the provided `closure`'s lifetime. + /// + /// The function leases a connection from the underlying connection pool and starts a transaction by running a `BEGIN` + /// query on the leased connection against the database. It then lends the connection to the user provided closure. + /// The user can then modify the database as they wish. If the user provided closure returns successfully, the function + /// will attempt to commit the changes by running a `COMMIT` query against the database. If the user provided closure + /// throws an error, the function will attempt to rollback the changes made within the closure. + /// + /// - Parameters: + /// - logger: The `Logger` to log into for the transaction. + /// - file: The file, the transaction was started in. Used for better error reporting. + /// - line: The line, the transaction was started in. Used for better error reporting. + /// - closure: The user provided code to modify the database. Use the provided connection to run queries. + /// The connection must stay in the transaction mode. Otherwise this method will throw! + /// - Returns: The closure's return value. + public func withTransaction( + logger: Logger, + file: String = #file, + line: Int = #line, + _ closure: (PostgresConnection) async throws -> Result + ) async throws -> Result { + try await self.withConnection { connection in + try await connection.withTransaction(logger: logger, file: file, line: line, closure) } } + #endif /// Run a query on the Postgres server the client is connected to. /// diff --git a/Tests/IntegrationTests/PostgresClientTests.swift b/Tests/IntegrationTests/PostgresClientTests.swift index 167ba298..34a8ad2a 100644 --- a/Tests/IntegrationTests/PostgresClientTests.swift +++ b/Tests/IntegrationTests/PostgresClientTests.swift @@ -77,7 +77,7 @@ final class PostgresClientTests: XCTestCase { for _ in 0.. Date: Wed, 2 Apr 2025 13:20:07 +0200 Subject: [PATCH 104/106] Move ConnectionPool test-utils into separate target (#544) --- Package.swift | 8 +++++ .../ConnectionPoolTestUtils}/MockClock.swift | 33 ++++++++++--------- .../MockConnection.swift | 22 ++++++------- .../MockConnectionFactory.swift | 27 ++++++++------- .../MockPingPongBehaviour.swift | 10 +++--- .../ConnectionPoolTestUtils/MockRequest.swift | 29 ++++++++++++++++ .../ConnectionPoolTests.swift | 3 +- .../ConnectionRequestTests.swift | 1 + .../Mocks/MockRequest.swift | 28 ---------------- .../NoKeepAliveBehaviorTests.swift | 1 + ...oolStateMachine+ConnectionGroupTests.swift | 3 +- ...oolStateMachine+ConnectionStateTests.swift | 1 + .../PoolStateMachine+RequestQueueTests.swift | 1 + .../PoolStateMachineTests.swift | 1 + 14 files changed, 95 insertions(+), 73 deletions(-) rename {Tests/ConnectionPoolModuleTests/Mocks => Sources/ConnectionPoolTestUtils}/MockClock.swift (84%) rename {Tests/ConnectionPoolModuleTests/Mocks => Sources/ConnectionPoolTestUtils}/MockConnection.swift (86%) rename {Tests/ConnectionPoolModuleTests/Mocks => Sources/ConnectionPoolTestUtils}/MockConnectionFactory.swift (79%) rename {Tests/ConnectionPoolModuleTests/Mocks => Sources/ConnectionPoolTestUtils}/MockPingPongBehaviour.swift (84%) create mode 100644 Sources/ConnectionPoolTestUtils/MockRequest.swift delete mode 100644 Tests/ConnectionPoolModuleTests/Mocks/MockRequest.swift diff --git a/Package.swift b/Package.swift index 3dd21c3c..ff071f88 100644 --- a/Package.swift +++ b/Package.swift @@ -57,6 +57,13 @@ let package = Package( path: "Sources/ConnectionPoolModule", swiftSettings: swiftSettings ), + .target( + name: "ConnectionPoolTestUtils", + dependencies: [ + "_ConnectionPoolModule", + .product(name: "NIOConcurrencyHelpers", package: "swift-nio"), + ] + ), .testTarget( name: "PostgresNIOTests", dependencies: [ @@ -70,6 +77,7 @@ let package = Package( name: "ConnectionPoolModuleTests", dependencies: [ .target(name: "_ConnectionPoolModule"), + .target(name: "ConnectionPoolTestUtils"), .product(name: "DequeModule", package: "swift-collections"), .product(name: "NIOCore", package: "swift-nio"), .product(name: "NIOConcurrencyHelpers", package: "swift-nio"), diff --git a/Tests/ConnectionPoolModuleTests/Mocks/MockClock.swift b/Sources/ConnectionPoolTestUtils/MockClock.swift similarity index 84% rename from Tests/ConnectionPoolModuleTests/Mocks/MockClock.swift rename to Sources/ConnectionPoolTestUtils/MockClock.swift index cd08d54e..34bf17e3 100644 --- a/Tests/ConnectionPoolModuleTests/Mocks/MockClock.swift +++ b/Sources/ConnectionPoolTestUtils/MockClock.swift @@ -1,31 +1,32 @@ -@testable import _ConnectionPoolModule +import _ConnectionPoolModule import Atomics import DequeModule +import NIOConcurrencyHelpers @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) -final class MockClock: Clock { - struct Instant: InstantProtocol, Comparable { - typealias Duration = Swift.Duration +public final class MockClock: Clock { + public struct Instant: InstantProtocol, Comparable { + public typealias Duration = Swift.Duration - func advanced(by duration: Self.Duration) -> Self { + public func advanced(by duration: Self.Duration) -> Self { .init(self.base + duration) } - func duration(to other: Self) -> Self.Duration { + public func duration(to other: Self) -> Self.Duration { self.base - other.base } private var base: Swift.Duration - init(_ base: Duration) { + public init(_ base: Duration) { self.base = base } - static func < (lhs: Self, rhs: Self) -> Bool { + public static func < (lhs: Self, rhs: Self) -> Bool { lhs.base < rhs.base } - static func == (lhs: Self, rhs: Self) -> Bool { + public static func == (lhs: Self, rhs: Self) -> Bool { lhs.base == rhs.base } } @@ -58,16 +59,18 @@ final class MockClock: Clock { var continuation: CheckedContinuation } - typealias Duration = Swift.Duration + public typealias Duration = Swift.Duration - var minimumResolution: Duration { .nanoseconds(1) } + public var minimumResolution: Duration { .nanoseconds(1) } - var now: Instant { self.stateBox.withLockedValue { $0.now } } + public var now: Instant { self.stateBox.withLockedValue { $0.now } } private let stateBox = NIOLockedValueBox(State()) private let waiterIDGenerator = ManagedAtomic(0) - func sleep(until deadline: Instant, tolerance: Duration?) async throws { + public init() {} + + public func sleep(until deadline: Instant, tolerance: Duration?) async throws { let waiterID = self.waiterIDGenerator.loadThenWrappingIncrement(ordering: .relaxed) return try await withTaskCancellationHandler { @@ -131,7 +134,7 @@ final class MockClock: Clock { } @discardableResult - func nextTimerScheduled() async -> Instant { + public func nextTimerScheduled() async -> Instant { await withCheckedContinuation { (continuation: CheckedContinuation) in let instant = self.stateBox.withLockedValue { state -> Instant? in if let scheduled = state.nextDeadlines.popFirst() { @@ -149,7 +152,7 @@ final class MockClock: Clock { } } - func advance(to deadline: Instant) { + public func advance(to deadline: Instant) { let waiters = self.stateBox.withLockedValue { state -> ArraySlice in precondition(deadline > state.now, "Time can only move forward") state.now = deadline diff --git a/Tests/ConnectionPoolModuleTests/Mocks/MockConnection.swift b/Sources/ConnectionPoolTestUtils/MockConnection.swift similarity index 86% rename from Tests/ConnectionPoolModuleTests/Mocks/MockConnection.swift rename to Sources/ConnectionPoolTestUtils/MockConnection.swift index f826ea04..db5c3ef7 100644 --- a/Tests/ConnectionPoolModuleTests/Mocks/MockConnection.swift +++ b/Sources/ConnectionPoolTestUtils/MockConnection.swift @@ -1,11 +1,11 @@ +import _ConnectionPoolModule import DequeModule -@testable import _ConnectionPoolModule +import NIOConcurrencyHelpers -// Sendability enforced through the lock -final class MockConnection: PooledConnection, Sendable { - typealias ID = Int +public final class MockConnection: PooledConnection, Sendable { + public typealias ID = Int - let id: ID + public let id: ID private enum State { case running([CheckedContinuation], [@Sendable ((any Error)?) -> ()]) @@ -15,11 +15,11 @@ final class MockConnection: PooledConnection, Sendable { private let lock: NIOLockedValueBox = NIOLockedValueBox(.running([], [])) - init(id: Int) { + public init(id: Int) { self.id = id } - var signalToClose: Void { + public var signalToClose: Void { get async throws { try await withCheckedThrowingContinuation { continuation in let runRightAway = self.lock.withLockedValue { state -> Bool in @@ -41,7 +41,7 @@ final class MockConnection: PooledConnection, Sendable { } } - func onClose(_ closure: @escaping @Sendable ((any Error)?) -> ()) { + public func onClose(_ closure: @escaping @Sendable ((any Error)?) -> ()) { let enqueued = self.lock.withLockedValue { state -> Bool in switch state { case .closed: @@ -64,7 +64,7 @@ final class MockConnection: PooledConnection, Sendable { } } - func close() { + public func close() { let continuations = self.lock.withLockedValue { state -> [CheckedContinuation] in switch state { case .running(let continuations, let callbacks): @@ -81,7 +81,7 @@ final class MockConnection: PooledConnection, Sendable { } } - func closeIfClosing() { + public func closeIfClosing() { let callbacks = self.lock.withLockedValue { state -> [@Sendable ((any Error)?) -> ()] in switch state { case .running, .closed: @@ -100,7 +100,7 @@ final class MockConnection: PooledConnection, Sendable { } extension MockConnection: CustomStringConvertible { - var description: String { + public var description: String { let state = self.lock.withLockedValue { $0 } return "MockConnection(id: \(self.id), state: \(state))" } diff --git a/Tests/ConnectionPoolModuleTests/Mocks/MockConnectionFactory.swift b/Sources/ConnectionPoolTestUtils/MockConnectionFactory.swift similarity index 79% rename from Tests/ConnectionPoolModuleTests/Mocks/MockConnectionFactory.swift rename to Sources/ConnectionPoolTestUtils/MockConnectionFactory.swift index 1c9bfff8..59552d30 100644 --- a/Tests/ConnectionPoolModuleTests/Mocks/MockConnectionFactory.swift +++ b/Sources/ConnectionPoolTestUtils/MockConnectionFactory.swift @@ -1,14 +1,15 @@ -@testable import _ConnectionPoolModule +import _ConnectionPoolModule import DequeModule +import NIOConcurrencyHelpers @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) -final class MockConnectionFactory: Sendable where Clock.Duration == Duration { - typealias ConnectionIDGenerator = _ConnectionPoolModule.ConnectionIDGenerator - typealias Request = ConnectionRequest - typealias KeepAliveBehavior = MockPingPongBehavior - typealias MetricsDelegate = NoOpConnectionPoolMetrics - typealias ConnectionID = Int - typealias Connection = MockConnection +public final class MockConnectionFactory: Sendable where Clock.Duration == Duration { + public typealias ConnectionIDGenerator = _ConnectionPoolModule.ConnectionIDGenerator + public typealias Request = ConnectionRequest + public typealias KeepAliveBehavior = MockPingPongBehavior + public typealias MetricsDelegate = NoOpConnectionPoolMetrics + public typealias ConnectionID = Int + public typealias Connection = MockConnection let stateBox = NIOLockedValueBox(State()) @@ -20,15 +21,17 @@ final class MockConnectionFactory: Sendable where Clo var runningConnections = [ConnectionID: Connection]() } - var pendingConnectionAttemptsCount: Int { + public init() {} + + public var pendingConnectionAttemptsCount: Int { self.stateBox.withLockedValue { $0.attempts.count } } - var runningConnections: [Connection] { + public var runningConnections: [Connection] { self.stateBox.withLockedValue { Array($0.runningConnections.values) } } - func makeConnection( + public func makeConnection( id: Int, for pool: ConnectionPool, NoOpConnectionPoolMetrics, Clock> ) async throws -> ConnectionAndMetadata { @@ -52,7 +55,7 @@ final class MockConnectionFactory: Sendable where Clo } @discardableResult - func nextConnectAttempt(_ closure: (ConnectionID) async throws -> UInt16) async rethrows -> Connection { + public func nextConnectAttempt(_ closure: (ConnectionID) async throws -> UInt16) async rethrows -> Connection { let (connectionID, continuation) = await withCheckedContinuation { (continuation: CheckedContinuation<(ConnectionID, CheckedContinuation<(MockConnection, UInt16), any Error>), Never>) in let attempt = self.stateBox.withLockedValue { state -> (ConnectionID, CheckedContinuation<(MockConnection, UInt16), any Error>)? in if let attempt = state.attempts.popFirst() { diff --git a/Tests/ConnectionPoolModuleTests/Mocks/MockPingPongBehaviour.swift b/Sources/ConnectionPoolTestUtils/MockPingPongBehaviour.swift similarity index 84% rename from Tests/ConnectionPoolModuleTests/Mocks/MockPingPongBehaviour.swift rename to Sources/ConnectionPoolTestUtils/MockPingPongBehaviour.swift index 637f096c..5a274079 100644 --- a/Tests/ConnectionPoolModuleTests/Mocks/MockPingPongBehaviour.swift +++ b/Sources/ConnectionPoolTestUtils/MockPingPongBehaviour.swift @@ -2,8 +2,8 @@ import DequeModule @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) -final class MockPingPongBehavior: ConnectionKeepAliveBehavior { - let keepAliveFrequency: Duration? +public final class MockPingPongBehavior: ConnectionKeepAliveBehavior { + public let keepAliveFrequency: Duration? let stateBox = NIOLockedValueBox(State()) @@ -13,11 +13,11 @@ final class MockPingPongBehavior: ConnectionKeepAl var waiter = Deque), Never>>() } - init(keepAliveFrequency: Duration?, connectionType: Connection.Type) { + public init(keepAliveFrequency: Duration?, connectionType: Connection.Type) { self.keepAliveFrequency = keepAliveFrequency } - func runKeepAlive(for connection: Connection) async throws { + public func runKeepAlive(for connection: Connection) async throws { precondition(self.keepAliveFrequency != nil) // we currently don't support cancellation when creating a connection @@ -40,7 +40,7 @@ final class MockPingPongBehavior: ConnectionKeepAl } @discardableResult - func nextKeepAlive(_ closure: (Connection) async throws -> Bool) async rethrows -> Connection { + public func nextKeepAlive(_ closure: (Connection) async throws -> Bool) async rethrows -> Connection { let (connection, continuation) = await withCheckedContinuation { (continuation: CheckedContinuation<(Connection, CheckedContinuation), Never>) in let run = self.stateBox.withLockedValue { state -> (Connection, CheckedContinuation)? in if let run = state.runs.popFirst() { diff --git a/Sources/ConnectionPoolTestUtils/MockRequest.swift b/Sources/ConnectionPoolTestUtils/MockRequest.swift new file mode 100644 index 00000000..06fc49bc --- /dev/null +++ b/Sources/ConnectionPoolTestUtils/MockRequest.swift @@ -0,0 +1,29 @@ +import _ConnectionPoolModule + +public final class MockRequest: ConnectionRequestProtocol, Hashable, Sendable { + public typealias Connection = MockConnection + + public struct ID: Hashable { + var objectID: ObjectIdentifier + + init(_ request: MockRequest) { + self.objectID = ObjectIdentifier(request) + } + } + + public init() {} + + public var id: ID { ID(self) } + + public static func ==(lhs: MockRequest, rhs: MockRequest) -> Bool { + lhs.id == rhs.id + } + + public func hash(into hasher: inout Hasher) { + hasher.combine(self.id) + } + + public func complete(with: Result) { + + } +} diff --git a/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift b/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift index 3c0e7a6b..9b3d5871 100644 --- a/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift +++ b/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift @@ -1,7 +1,8 @@ @testable import _ConnectionPoolModule import Atomics -import XCTest +import ConnectionPoolTestUtils import NIOEmbedded +import XCTest @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) final class ConnectionPoolTests: XCTestCase { diff --git a/Tests/ConnectionPoolModuleTests/ConnectionRequestTests.swift b/Tests/ConnectionPoolModuleTests/ConnectionRequestTests.swift index 5845267f..cbdc4f65 100644 --- a/Tests/ConnectionPoolModuleTests/ConnectionRequestTests.swift +++ b/Tests/ConnectionPoolModuleTests/ConnectionRequestTests.swift @@ -1,4 +1,5 @@ @testable import _ConnectionPoolModule +import ConnectionPoolTestUtils import XCTest final class ConnectionRequestTests: XCTestCase { diff --git a/Tests/ConnectionPoolModuleTests/Mocks/MockRequest.swift b/Tests/ConnectionPoolModuleTests/Mocks/MockRequest.swift deleted file mode 100644 index 6aaa9c91..00000000 --- a/Tests/ConnectionPoolModuleTests/Mocks/MockRequest.swift +++ /dev/null @@ -1,28 +0,0 @@ -import _ConnectionPoolModule - -final class MockRequest: ConnectionRequestProtocol, Hashable, Sendable { - typealias Connection = MockConnection - - struct ID: Hashable { - var objectID: ObjectIdentifier - - init(_ request: MockRequest) { - self.objectID = ObjectIdentifier(request) - } - } - - var id: ID { ID(self) } - - - static func ==(lhs: MockRequest, rhs: MockRequest) -> Bool { - lhs.id == rhs.id - } - - func hash(into hasher: inout Hasher) { - hasher.combine(self.id) - } - - func complete(with: Result) { - - } -} diff --git a/Tests/ConnectionPoolModuleTests/NoKeepAliveBehaviorTests.swift b/Tests/ConnectionPoolModuleTests/NoKeepAliveBehaviorTests.swift index b817ce19..4ddad00d 100644 --- a/Tests/ConnectionPoolModuleTests/NoKeepAliveBehaviorTests.swift +++ b/Tests/ConnectionPoolModuleTests/NoKeepAliveBehaviorTests.swift @@ -1,4 +1,5 @@ import _ConnectionPoolModule +import ConnectionPoolTestUtils import XCTest @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) diff --git a/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionGroupTests.swift b/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionGroupTests.swift index 6b8d6c6e..3ec7dc80 100644 --- a/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionGroupTests.swift +++ b/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionGroupTests.swift @@ -1,5 +1,6 @@ -import XCTest @testable import _ConnectionPoolModule +import ConnectionPoolTestUtils +import XCTest @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) final class PoolStateMachine_ConnectionGroupTests: XCTestCase { diff --git a/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionStateTests.swift b/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionStateTests.swift index bc4c2c4b..77ad713d 100644 --- a/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionStateTests.swift +++ b/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionStateTests.swift @@ -1,4 +1,5 @@ @testable import _ConnectionPoolModule +import ConnectionPoolTestUtils import XCTest @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) diff --git a/Tests/ConnectionPoolModuleTests/PoolStateMachine+RequestQueueTests.swift b/Tests/ConnectionPoolModuleTests/PoolStateMachine+RequestQueueTests.swift index 0231da51..2ec450a6 100644 --- a/Tests/ConnectionPoolModuleTests/PoolStateMachine+RequestQueueTests.swift +++ b/Tests/ConnectionPoolModuleTests/PoolStateMachine+RequestQueueTests.swift @@ -1,4 +1,5 @@ @testable import _ConnectionPoolModule +import ConnectionPoolTestUtils import XCTest @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) diff --git a/Tests/ConnectionPoolModuleTests/PoolStateMachineTests.swift b/Tests/ConnectionPoolModuleTests/PoolStateMachineTests.swift index 2f3ae617..ca5cb54d 100644 --- a/Tests/ConnectionPoolModuleTests/PoolStateMachineTests.swift +++ b/Tests/ConnectionPoolModuleTests/PoolStateMachineTests.swift @@ -1,3 +1,4 @@ +import ConnectionPoolTestUtils import XCTest @testable import _ConnectionPoolModule From b775835ff0dbef8db8af178fb9eff400bbad1582 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Wed, 2 Apr 2025 16:32:34 +0200 Subject: [PATCH 105/106] Add Benchmarks for ConnectionPool (#545) --- Benchmarks/.gitignore | 1 + .../ConnectionPoolBenchmarks.swift | 51 +++++++++++++++++++ Benchmarks/Package.swift | 28 ++++++++++ Package.swift | 9 ++-- .../MockConnectionFactory.swift | 15 +++++- .../MockPingPongBehaviour.swift | 3 +- .../ConnectionPoolTests.swift | 2 +- .../ConnectionRequestTests.swift | 2 +- .../NoKeepAliveBehaviorTests.swift | 2 +- ...oolStateMachine+ConnectionGroupTests.swift | 2 +- ...oolStateMachine+ConnectionStateTests.swift | 2 +- .../PoolStateMachine+RequestQueueTests.swift | 2 +- .../PoolStateMachineTests.swift | 4 +- 13 files changed, 110 insertions(+), 13 deletions(-) create mode 100644 Benchmarks/.gitignore create mode 100644 Benchmarks/Benchmarks/ConnectionPoolBenchmarks/ConnectionPoolBenchmarks.swift create mode 100644 Benchmarks/Package.swift diff --git a/Benchmarks/.gitignore b/Benchmarks/.gitignore new file mode 100644 index 00000000..24e5b0a1 --- /dev/null +++ b/Benchmarks/.gitignore @@ -0,0 +1 @@ +.build diff --git a/Benchmarks/Benchmarks/ConnectionPoolBenchmarks/ConnectionPoolBenchmarks.swift b/Benchmarks/Benchmarks/ConnectionPoolBenchmarks/ConnectionPoolBenchmarks.swift new file mode 100644 index 00000000..98f21f62 --- /dev/null +++ b/Benchmarks/Benchmarks/ConnectionPoolBenchmarks/ConnectionPoolBenchmarks.swift @@ -0,0 +1,51 @@ +import _ConnectionPoolModule +import _ConnectionPoolTestUtils +import Benchmark + +let benchmarks: @Sendable () -> Void = { + Benchmark("Minimal benchmark", configuration: .init(scalingFactor: .kilo)) { benchmark in + let clock = MockClock() + let factory = MockConnectionFactory(autoMaxStreams: 1) + var configuration = ConnectionPoolConfiguration() + configuration.maximumConnectionSoftLimit = 50 + configuration.maximumConnectionHardLimit = 50 + + let pool = ConnectionPool( + configuration: configuration, + idGenerator: ConnectionIDGenerator(), + keepAliveBehavior: MockPingPongBehavior(keepAliveFrequency: nil, connectionType: MockConnection.self), + observabilityDelegate: NoOpConnectionPoolMetrics(connectionIDType: MockConnection.ID.self), + clock: clock + ) { + try await factory.makeConnection(id: $0, for: $1) + } + + await withTaskGroup { taskGroup in + + taskGroup.addTask { + await pool.run() + } + + let sequential = benchmark.scaledIterations.upperBound / configuration.maximumConnectionSoftLimit + + for parallel in 0..: Sendable wh var runningConnections = [ConnectionID: Connection]() } - public init() {} + let autoMaxStreams: UInt16? + + public init(autoMaxStreams: UInt16? = nil) { + self.autoMaxStreams = autoMaxStreams + } public var pendingConnectionAttemptsCount: Int { self.stateBox.withLockedValue { $0.attempts.count } @@ -35,6 +39,15 @@ public final class MockConnectionFactory: Sendable wh id: Int, for pool: ConnectionPool, NoOpConnectionPoolMetrics, Clock> ) async throws -> ConnectionAndMetadata { + if let autoMaxStreams = self.autoMaxStreams { + let connection = MockConnection(id: id) + Task { + try? await connection.signalToClose + connection.closeIfClosing() + } + return .init(connection: connection, maximalStreamsOnConnection: autoMaxStreams) + } + // we currently don't support cancellation when creating a connection let result = try await withCheckedThrowingContinuation { (checkedContinuation: CheckedContinuation<(MockConnection, UInt16), any Error>) in let waiter = self.stateBox.withLockedValue { state -> (CheckedContinuation<(ConnectionID, CheckedContinuation<(MockConnection, UInt16), any Error>), Never>)? in diff --git a/Sources/ConnectionPoolTestUtils/MockPingPongBehaviour.swift b/Sources/ConnectionPoolTestUtils/MockPingPongBehaviour.swift index 5a274079..de1a7275 100644 --- a/Sources/ConnectionPoolTestUtils/MockPingPongBehaviour.swift +++ b/Sources/ConnectionPoolTestUtils/MockPingPongBehaviour.swift @@ -1,5 +1,6 @@ -@testable import _ConnectionPoolModule +import _ConnectionPoolModule import DequeModule +import NIOConcurrencyHelpers @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) public final class MockPingPongBehavior: ConnectionKeepAliveBehavior { diff --git a/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift b/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift index 9b3d5871..c745b4a0 100644 --- a/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift +++ b/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift @@ -1,6 +1,6 @@ @testable import _ConnectionPoolModule +import _ConnectionPoolTestUtils import Atomics -import ConnectionPoolTestUtils import NIOEmbedded import XCTest diff --git a/Tests/ConnectionPoolModuleTests/ConnectionRequestTests.swift b/Tests/ConnectionPoolModuleTests/ConnectionRequestTests.swift index cbdc4f65..537efbd9 100644 --- a/Tests/ConnectionPoolModuleTests/ConnectionRequestTests.swift +++ b/Tests/ConnectionPoolModuleTests/ConnectionRequestTests.swift @@ -1,5 +1,5 @@ @testable import _ConnectionPoolModule -import ConnectionPoolTestUtils +import _ConnectionPoolTestUtils import XCTest final class ConnectionRequestTests: XCTestCase { diff --git a/Tests/ConnectionPoolModuleTests/NoKeepAliveBehaviorTests.swift b/Tests/ConnectionPoolModuleTests/NoKeepAliveBehaviorTests.swift index 4ddad00d..b1b54023 100644 --- a/Tests/ConnectionPoolModuleTests/NoKeepAliveBehaviorTests.swift +++ b/Tests/ConnectionPoolModuleTests/NoKeepAliveBehaviorTests.swift @@ -1,5 +1,5 @@ import _ConnectionPoolModule -import ConnectionPoolTestUtils +import _ConnectionPoolTestUtils import XCTest @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) diff --git a/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionGroupTests.swift b/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionGroupTests.swift index 3ec7dc80..b09bfcb4 100644 --- a/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionGroupTests.swift +++ b/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionGroupTests.swift @@ -1,5 +1,5 @@ @testable import _ConnectionPoolModule -import ConnectionPoolTestUtils +import _ConnectionPoolTestUtils import XCTest @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) diff --git a/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionStateTests.swift b/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionStateTests.swift index 77ad713d..7dd2b726 100644 --- a/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionStateTests.swift +++ b/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionStateTests.swift @@ -1,5 +1,5 @@ @testable import _ConnectionPoolModule -import ConnectionPoolTestUtils +import _ConnectionPoolTestUtils import XCTest @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) diff --git a/Tests/ConnectionPoolModuleTests/PoolStateMachine+RequestQueueTests.swift b/Tests/ConnectionPoolModuleTests/PoolStateMachine+RequestQueueTests.swift index 2ec450a6..b74b86cc 100644 --- a/Tests/ConnectionPoolModuleTests/PoolStateMachine+RequestQueueTests.swift +++ b/Tests/ConnectionPoolModuleTests/PoolStateMachine+RequestQueueTests.swift @@ -1,5 +1,5 @@ @testable import _ConnectionPoolModule -import ConnectionPoolTestUtils +import _ConnectionPoolTestUtils import XCTest @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) diff --git a/Tests/ConnectionPoolModuleTests/PoolStateMachineTests.swift b/Tests/ConnectionPoolModuleTests/PoolStateMachineTests.swift index ca5cb54d..c0b6ddcd 100644 --- a/Tests/ConnectionPoolModuleTests/PoolStateMachineTests.swift +++ b/Tests/ConnectionPoolModuleTests/PoolStateMachineTests.swift @@ -1,6 +1,6 @@ -import ConnectionPoolTestUtils -import XCTest @testable import _ConnectionPoolModule +import _ConnectionPoolTestUtils +import XCTest @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) typealias TestPoolStateMachine = PoolStateMachine< From ecbc3eb092cb41015c02643ff5258cb94ccbd342 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Wed, 2 Apr 2025 17:59:26 +0200 Subject: [PATCH 106/106] Make ConnectionPool faster (#546) --- Sources/ConnectionPoolModule/ConnectionRequest.swift | 5 ++++- .../PoolStateMachine+ConnectionGroup.swift | 10 ++++++++-- .../PoolStateMachine+ConnectionState.swift | 4 ++-- Sources/ConnectionPoolTestUtils/MockRequest.swift | 2 +- 4 files changed, 15 insertions(+), 6 deletions(-) diff --git a/Sources/ConnectionPoolModule/ConnectionRequest.swift b/Sources/ConnectionPoolModule/ConnectionRequest.swift index 19ed9bd2..1d1c55da 100644 --- a/Sources/ConnectionPoolModule/ConnectionRequest.swift +++ b/Sources/ConnectionPoolModule/ConnectionRequest.swift @@ -21,7 +21,8 @@ public struct ConnectionRequest: ConnectionRequest } } -fileprivate let requestIDGenerator = _ConnectionPoolModule.ConnectionIDGenerator() +@usableFromInline +let requestIDGenerator = _ConnectionPoolModule.ConnectionIDGenerator() @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) extension ConnectionPool where Request == ConnectionRequest { @@ -44,6 +45,7 @@ extension ConnectionPool where Request == ConnectionRequest { ) } + @inlinable public func leaseConnection() async throws -> Connection { let requestID = requestIDGenerator.next() @@ -67,6 +69,7 @@ extension ConnectionPool where Request == ConnectionRequest { return connection } + @inlinable public func withConnection(_ closure: (Connection) async throws -> Result) async throws -> Result { let connection = try await self.leaseConnection() defer { self.releaseConnection(connection) } diff --git a/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionGroup.swift b/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionGroup.swift index f26f244d..a8e97ffd 100644 --- a/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionGroup.swift +++ b/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionGroup.swift @@ -132,6 +132,12 @@ extension PoolStateMachine { @usableFromInline var info: ConnectionAvailableInfo + + @inlinable + init(use: ConnectionUse, info: ConnectionAvailableInfo) { + self.use = use + self.info = info + } } mutating func refillConnections() -> [ConnectionRequest] { @@ -623,7 +629,7 @@ extension PoolStateMachine { // MARK: - Private functions - - @usableFromInline + @inlinable /*private*/ func getConnectionUse(index: Int) -> ConnectionUse { switch index { case 0.. AvailableConnectionContext { precondition(self.connections[index].isAvailable) let use = self.getConnectionUse(index: index) diff --git a/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionState.swift b/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionState.swift index 2fb68a2d..9912f13a 100644 --- a/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionState.swift +++ b/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionState.swift @@ -164,7 +164,7 @@ extension PoolStateMachine { } } - @usableFromInline + @inlinable var isLeased: Bool { switch self.state { case .leased: @@ -174,7 +174,7 @@ extension PoolStateMachine { } } - @usableFromInline + @inlinable var isConnected: Bool { switch self.state { case .idle, .leased: diff --git a/Sources/ConnectionPoolTestUtils/MockRequest.swift b/Sources/ConnectionPoolTestUtils/MockRequest.swift index 06fc49bc..5e4e2fc0 100644 --- a/Sources/ConnectionPoolTestUtils/MockRequest.swift +++ b/Sources/ConnectionPoolTestUtils/MockRequest.swift @@ -3,7 +3,7 @@ import _ConnectionPoolModule public final class MockRequest: ConnectionRequestProtocol, Hashable, Sendable { public typealias Connection = MockConnection - public struct ID: Hashable { + public struct ID: Hashable, Sendable { var objectID: ObjectIdentifier init(_ request: MockRequest) {