diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS new file mode 100644 index 00000000..6413432f --- /dev/null +++ b/.github/CODEOWNERS @@ -0,0 +1 @@ +* @fabianfett @gwynne diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml deleted file mode 100644 index 03fed6f3..00000000 --- a/.github/FUNDING.yml +++ /dev/null @@ -1 +0,0 @@ -github: [tanner0101] diff --git a/.github/workflows/api-docs.yml b/.github/workflows/api-docs.yml index f27346af..dc2e0634 100644 --- a/.github/workflows/api-docs.yml +++ b/.github/workflows/api-docs.yml @@ -1,18 +1,14 @@ name: deploy-api-docs on: - push: - branches: - - master + push: + branches: + - main jobs: - deploy: - name: api.vapor.codes - runs-on: ubuntu-latest - steps: - - name: Deploy api-docs - uses: appleboy/ssh-action@master - with: - host: vapor.codes - username: vapor - key: ${{ secrets.VAPOR_CODES_SSH_KEY }} - script: ./github-actions/deploy-api-docs.sh + build-and-deploy: + uses: vapor/api-docs/.github/workflows/build-and-deploy-docs-workflow.yml@main + secrets: inherit + with: + package_name: postgres-nio + modules: PostgresNIO + pathsToInvalidate: /postgresnio/* diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 4e4cfccb..8364e8ae 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -1,132 +1,204 @@ -name: test +name: CI +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true on: - - pull_request -defaults: - run: - shell: bash + push: + branches: + - "main" + pull_request: + branches: + - "*" +env: + LOG_LEVEL: info + jobs: - dependents: - runs-on: ubuntu-latest - services: - psql-a: - image: ${{ matrix.dbimage }} - env: - POSTGRES_USER: vapor_username - POSTGRES_DB: vapor_database - POSTGRES_PASSWORD: vapor_password - psql-b: - image: ${{ matrix.dbimage }} - env: - POSTGRES_USER: vapor_username - POSTGRES_DB: vapor_database - POSTGRES_PASSWORD: vapor_password - container: swift:5.2-bionic + linux-unit: strategy: fail-fast: false matrix: - dbimage: - - postgres:12 - - postgres:11 - dependent: - - postgres-kit - - fluent-postgres-driver + swift-image: + - swift:5.9-jammy + - swift:5.10-noble + - swift:6.0-noble + - swiftlang/swift:nightly-main-jammy + container: ${{ matrix.swift-image }} + runs-on: ubuntu-latest steps: + - name: Display OS and Swift versions + shell: bash + run: | + [[ -z "${SWIFT_PLATFORM}" ]] && SWIFT_PLATFORM="$(. /etc/os-release && echo "${ID}${VERSION_ID}")" + [[ -z "${SWIFT_VERSION}" ]] && SWIFT_VERSION="$(cat /.swift_tag 2>/dev/null || true)" + printf 'OS: %s\nTag: %s\nVersion:\n' "${SWIFT_PLATFORM}-${RUNNER_ARCH}" "${SWIFT_VERSION}" + swift --version - name: Check out package - uses: actions/checkout@v2 - with: - path: package - - name: Check out dependent - uses: actions/checkout@v2 + uses: actions/checkout@v4 + - name: Run unit tests with Thread Sanitizer + run: | + swift test --filter='^(PostgresNIOTests|ConnectionPoolModuleTests)' --sanitize=thread --enable-code-coverage + - name: Submit code coverage + uses: vapor/swift-codecov-action@v0.3 with: - repository: vapor/${{ matrix.dependent }} - path: dependent - - name: Use local package - run: swift package edit postgres-nio --path ../package - working-directory: dependent - - name: Run tests with Thread Sanitizer - run: swift test --enable-test-discovery --sanitize=thread - working-directory: dependent - env: - POSTGRES_HOSTNAME: psql-a - POSTGRES_HOSTNAME_A: psql-a - POSTGRES_HOSTNAME_B: psql-b - LOG_LEVEL: notice - linux: + codecov_token: ${{ secrets.CODECOV_TOKEN }} + + linux-integration-and-dependencies: strategy: fail-fast: false matrix: - dbimage: + postgres-image: + - postgres:17 + - postgres:15 - postgres:12 - - postgres:11 - runner: - # 5.2 Stable - - swift:5.2-xenial - - swift:5.2-bionic - # 5.2 Unstable - - swiftlang/swift:nightly-5.2-xenial - - swiftlang/swift:nightly-5.2-bionic - # 5.3 Unstable - - swiftlang/swift:nightly-5.3-xenial - - swiftlang/swift:nightly-5.3-bionic - # Master Unsable - - swiftlang/swift:nightly-master-xenial - - swiftlang/swift:nightly-master-bionic - - swiftlang/swift:nightly-master-focal - - swiftlang/swift:nightly-master-centos8 - - swiftlang/swift:nightly-master-amazonlinux2 - container: ${{ matrix.runner }} + include: + - postgres-image: postgres:17 + postgres-auth: scram-sha-256 + - postgres-image: postgres:15 + postgres-auth: md5 + - postgres-image: postgres:12 + postgres-auth: trust + container: + image: swift:5.10-noble + volumes: [ 'pgrunshare:/var/run/postgresql' ] runs-on: ubuntu-latest + env: + # Unfortunately, fluent-postgres-driver details leak through here + POSTGRES_DB: 'test_database' + POSTGRES_DB_A: 'test_database' + POSTGRES_DB_B: 'test_database' + POSTGRES_USER: 'test_username' + POSTGRES_USER_A: 'test_username' + POSTGRES_USER_B: 'test_username' + POSTGRES_PASSWORD: 'test_password' + POSTGRES_PASSWORD_A: 'test_password' + POSTGRES_PASSWORD_B: 'test_password' + POSTGRES_HOSTNAME: 'psql-a' + POSTGRES_HOSTNAME_A: 'psql-a' + POSTGRES_HOSTNAME_B: 'psql-b' + POSTGRES_SOCKET: '/var/run/postgresql/.s.PGSQL.5432' + POSTGRES_HOST_AUTH_METHOD: ${{ matrix.postgres-auth }} services: - psql: - image: ${{ matrix.dbimage }} - env: - POSTGRES_USER: vapor_username - POSTGRES_DB: vapor_database - POSTGRES_PASSWORD: vapor_password - steps: - - name: Check out code - uses: actions/checkout@v2 - - name: Run tests with Thread Sanitizer - run: swift test --enable-test-discovery --sanitize=thread + psql-a: + image: ${{ matrix.postgres-image }} + volumes: [ 'pgrunshare:/var/run/postgresql' ] env: - POSTGRES_HOSTNAME: psql - LOG_LEVEL: notice - macOS: + POSTGRES_USER: 'test_username' + POSTGRES_DB: 'test_database' + POSTGRES_PASSWORD: 'test_password' + POSTGRES_HOST_AUTH_METHOD: ${{ matrix.postgres-auth }} + POSTGRES_INITDB_ARGS: --auth-host=${{ matrix.postgres-auth }} + psql-b: + image: ${{ matrix.postgres-image }} + volumes: [ 'pgrunshare:/var/run/postgresql' ] + env: + POSTGRES_USER: 'test_username' + POSTGRES_DB: 'test_database' + POSTGRES_PASSWORD: 'test_password' + POSTGRES_HOST_AUTH_METHOD: ${{ matrix.postgres-auth }} + POSTGRES_INITDB_ARGS: --auth-host=${{ matrix.postgres-auth }} + steps: + - name: Display OS and Swift versions + run: | + [[ -z "${SWIFT_PLATFORM}" ]] && SWIFT_PLATFORM="$(. /etc/os-release && echo "${ID}${VERSION_ID}")" + [[ -z "${SWIFT_VERSION}" ]] && SWIFT_VERSION="$(cat /.swift_tag 2>/dev/null || true)" + printf 'OS: %s\nTag: %s\nVersion:\n' "${SWIFT_PLATFORM}-${RUNNER_ARCH}" "${SWIFT_VERSION}" && swift --version + - name: Check out package + uses: actions/checkout@v4 + with: { path: 'postgres-nio' } + - name: Run integration tests + run: swift test --package-path postgres-nio --filter=^IntegrationTests + - name: Check out postgres-kit dependent + uses: actions/checkout@v4 + with: { repository: 'vapor/postgres-kit', path: 'postgres-kit' } + - name: Check out fluent-postgres-driver dependent + uses: actions/checkout@v4 + with: { repository: 'vapor/fluent-postgres-driver', path: 'fluent-postgres-driver' } + - name: Use local package in dependents + run: | + swift package --package-path postgres-kit edit postgres-nio --path postgres-nio + swift package --package-path fluent-postgres-driver edit postgres-nio --path postgres-nio + - name: Run postgres-kit tests + run: swift test --package-path postgres-kit + - name: Run fluent-postgres-driver tests + run: swift test --package-path fluent-postgres-driver + + macos-all: strategy: fail-fast: false matrix: - include: - - formula: postgresql@11 - datadir: postgresql@11 - - formula: postgresql@12 - datadir: postgres - runs-on: macos-latest + postgres-formula: + # Only test one version on macOS, let Linux do the rest + - postgresql@16 + postgres-auth: + # Only test one auth method on macOS, Linux tests will cover the others + - scram-sha-256 + xcode-version: + - '~15' + include: + - xcode-version: '~15' + macos-version: 'macos-14' + runs-on: ${{ matrix.macos-version }} + env: + POSTGRES_HOSTNAME: 127.0.0.1 + POSTGRES_USER: 'test_username' + POSTGRES_PASSWORD: 'test_password' + POSTGRES_DB: 'postgres' + POSTGRES_AUTH_METHOD: ${{ matrix.postgres-auth }} + POSTGRES_SOCKET: '/tmp/.s.PGSQL.5432' + POSTGRES_FORMULA: ${{ matrix.postgres-formula }} steps: - name: Select latest available Xcode - uses: maxim-lobanov/setup-xcode@1.0 + uses: maxim-lobanov/setup-xcode@v1 with: - xcode-version: latest - - name: Replace Postgres install and start server + xcode-version: ${{ matrix.xcode-version }} + - name: Install Postgres, setup DB and auth, and wait for server start run: | - brew uninstall --force postgresql php && rm -rf /usr/local/{etc,var}/{postgres,pg}* - brew install ${{ matrix.formula }} && brew link --force ${{ matrix.formula }} - initdb --locale=C -E UTF-8 $(brew --prefix)/var/${{ matrix.datadir }} - brew services start ${{ matrix.formula }} - - name: Wait for server to be ready - run: until pg_isready; do sleep 1; done - timeout-minutes: 2 - - name: Setup users and databases for Postgres - run: | - createuser --createdb --login vapor_username - for db in vapor_database_{a,b}; do - createdb -Ovapor_username $db && psql $db <<<"ALTER SCHEMA public OWNER TO vapor_username;" - done + export PATH="$(brew --prefix)/opt/${POSTGRES_FORMULA}/bin:$PATH" PGDATA=/tmp/vapor-postgres-test + # ** BEGIN ** Work around bug in both Homebrew and GHA + (brew upgrade python@3.11 || true) && (brew link --force --overwrite python@3.11 || true) + (brew upgrade python@3.12 || true) && (brew link --force --overwrite python@3.12 || true) + (brew upgrade || true) + # ** END ** Work around bug in both Homebrew and GHA + brew install --overwrite "${POSTGRES_FORMULA}" + brew link --overwrite --force "${POSTGRES_FORMULA}" + initdb --locale=C --auth-host "${POSTGRES_AUTH_METHOD}" -U "${POSTGRES_USER}" --pwfile=<(echo "${POSTGRES_PASSWORD}") + pg_ctl start --wait + timeout-minutes: 15 - name: Checkout code - uses: actions/checkout@v2 - - name: Run tests with Thread Sanitizer - run: swift test --enable-test-discovery --sanitize=thread - env: - POSTGRES_DATABASE: vapor_database_a - POSTGRES_DATABASE_A: vapor_database_a - POSTGRES_DATABASE_B: vapor_database_b - LOG_LEVEL: notice + uses: actions/checkout@v4 + - name: Run all tests + run: swift test + + api-breakage: + if: github.event_name == 'pull_request' + runs-on: ubuntu-latest + container: swift:noble + steps: + - name: Checkout + uses: actions/checkout@v4 + with: + fetch-depth: 0 + # https://github.com/actions/checkout/issues/766 + - name: API breaking changes + run: | + git config --global --add safe.directory "${GITHUB_WORKSPACE}" + swift package diagnose-api-breaking-changes origin/main + +# gh-codeql: +# if: ${{ false }} +# runs-on: ubuntu-latest +# container: swift:noble +# permissions: { actions: write, contents: read, security-events: write } +# steps: +# - name: Check out code +# uses: actions/checkout@v4 +# - name: Mark repo safe in non-fake global config +# run: git config --global --add safe.directory "${GITHUB_WORKSPACE}" +# - name: Initialize CodeQL +# uses: github/codeql-action/init@v3 +# with: +# languages: swift +# - name: Perform build +# run: swift build +# - name: Run CodeQL analyze +# uses: github/codeql-action/analyze@v3 diff --git a/.spi.yml b/.spi.yml new file mode 100644 index 00000000..690e4f2a --- /dev/null +++ b/.spi.yml @@ -0,0 +1,4 @@ +version: 1 +external_links: + documentation: "https://api.vapor.codes/postgresnio/documentation/postgresnio/" + diff --git a/Benchmarks/.gitignore b/Benchmarks/.gitignore new file mode 100644 index 00000000..24e5b0a1 --- /dev/null +++ b/Benchmarks/.gitignore @@ -0,0 +1 @@ +.build diff --git a/Benchmarks/Benchmarks/ConnectionPoolBenchmarks/ConnectionPoolBenchmarks.swift b/Benchmarks/Benchmarks/ConnectionPoolBenchmarks/ConnectionPoolBenchmarks.swift new file mode 100644 index 00000000..98f21f62 --- /dev/null +++ b/Benchmarks/Benchmarks/ConnectionPoolBenchmarks/ConnectionPoolBenchmarks.swift @@ -0,0 +1,51 @@ +import _ConnectionPoolModule +import _ConnectionPoolTestUtils +import Benchmark + +let benchmarks: @Sendable () -> Void = { + Benchmark("Minimal benchmark", configuration: .init(scalingFactor: .kilo)) { benchmark in + let clock = MockClock() + let factory = MockConnectionFactory(autoMaxStreams: 1) + var configuration = ConnectionPoolConfiguration() + configuration.maximumConnectionSoftLimit = 50 + configuration.maximumConnectionHardLimit = 50 + + let pool = ConnectionPool( + configuration: configuration, + idGenerator: ConnectionIDGenerator(), + keepAliveBehavior: MockPingPongBehavior(keepAliveFrequency: nil, connectionType: MockConnection.self), + observabilityDelegate: NoOpConnectionPoolMetrics(connectionIDType: MockConnection.ID.self), + clock: clock + ) { + try await factory.makeConnection(id: $0, for: $1) + } + + await withTaskGroup { taskGroup in + + taskGroup.addTask { + await pool.run() + } + + let sequential = benchmark.scaledIterations.upperBound / configuration.maximumConnectionSoftLimit + + for parallel in 0.. +

+ + + + PostgresNIO +
- - Documentation - - - Team Chat +
+
+ Documentation - MIT License + MIT License - - Continuous Integration + + Continuous Integration - Swift 5.2 + Swift 5.8+ -
-
- -🐘 Non-blocking, event-driven Swift client for PostgreSQL built on [SwiftNIO](https://github.com/apple/swift-nio). - -### Major Releases - -The table below shows a list of PostgresNIO major releases alongside their compatible NIO and Swift versions. - -|Version|NIO|Swift|SPM| -|-|-|-|-| -|1.0|2.0+|5.2+|`from: "1.0.0"`| - -Use the SPM string to easily include the dependendency in your `Package.swift` file. - -```swift -.package(url: "https://github.com/vapor/postgres-nio.git", from: ...) -``` - -### Supported Platforms - -PostgresNIO supports the following platforms: - -- Ubuntu 16.04+ -- macOS 10.15+ - -## Overview - -PostgresNIO is a client package for connecting to, authorizing, and querying a PostgreSQL server. At the heart of this module are NIO channel handlers for parsing and serializing messages in PostgreSQL's proprietary wire protocol. These channel handlers are combined in a request / response style connection type that provides a convenient, client-like interface for performing queries. - -Support for both simple (text) and parameterized (binary) querying is provided out of the box alongside a `PostgresData` type that handles conversion between PostgreSQL's wire format and native Swift types. - -### Motivation - -Most Swift implementations of Postgres clients are based on the [libpq](https://www.postgresql.org/docs/11/libpq.html) C library which handles transport internally. Building a library directly on top of Postgres' wire protocol using SwiftNIO should yield a more reliable, maintainable, and performant interface for PostgreSQL databases. - -### Goals - -This package is meant to be a low-level, unopinionated PostgreSQL wire-protocol implementation for Swift. The hope is that higher level packages can share PostgresNIO as a foundation for interacting with PostgreSQL servers without needing to duplicate complex logic. - -Because of this, PostgresNIO excludes some important concepts for the sake of simplicity, such as: - -- Connection pooling -- Swift `Codable` integration -- Query building - -If you are looking for a PostgreSQL client package to use in your project, take a look at these higher-level packages built on top of PostgresNIO: - -- [`vapor/postgres-kit`](https://github.com/vapor/postgresql) - -### Dependencies + + SSWG Incubation Level: Graduated + +

-This package has four dependencies: +🐘 Non-blocking, event-driven Swift client for PostgreSQL built on [SwiftNIO]. -- [`apple/swift-nio`](https://github.com/apple/swift-nio) for IO -- [`apple/swift-nio-ssl`](https://github.com/apple/swift-nio-ssl) for TLS -- [`apple/swift-log`](https://github.com/apple/swift-log) for logging -- [`apple/swift-metrics`](https://github.com/apple/swift-metrics) for metrics +Features: -This package has no additional system dependencies. +- A [`PostgresConnection`] which allows you to connect to, authorize with, query, and retrieve results from a PostgreSQL server +- A [`PostgresClient`] which pools and manages connections +- An async/await interface that supports backpressure +- Automatic conversions between Swift primitive types and the Postgres wire format +- Integrated with the Swift server ecosystem, including use of [SwiftLog] and [ServiceLifecycle]. +- Designed to run efficiently on all supported platforms (tested extensively on Linux and Darwin systems) +- Support for `Network.framework` when available (e.g. on Apple platforms) +- Supports running on Unix Domain Sockets ## API Docs -Check out the [PostgresNIO API docs](https://api.vapor.codes/postgres-nio/master/PostgresNIO/) for a detailed look at all of the classes, structs, protocols, and more. +Check out the [PostgresNIO API docs][Documentation] for a +detailed look at all of the classes, structs, protocols, and more. -## Getting Started +## Getting started -This section will provide a quick look at using PostgresNIO. +Interested in an example? We prepared a simple [Birthday example](https://github.com/vapor/postgres-nio/blob/main/Snippets/Birthdays.swift) +in the Snippets folder. -### Creating a Connection +#### Adding the dependency -The first step to making a query is creating a new `PostgresConnection`. The minimum requirements to create one are a `SocketAddress` and `EventLoop`. +Add `PostgresNIO` as dependency to your `Package.swift`: ```swift -import PostgresNIO - -let eventLoop: EventLoop = ... -let conn = try PostgresConnection.connect( - to: .makeAddressResolvingHost("my.psql.server", port: 5432), - on: eventLoop -).wait() -defer { try! conn.close().wait() } + dependencies: [ + .package(url: "https://github.com/vapor/postgres-nio.git", from: "1.21.0"), + ... + ] ``` -Note: These examples will make use of `wait()` for simplicity. This is appropriate if you are using PostgresNIO on the main thread, like for a CLI tool or in tests. However, you should never use `wait()` on an event loop. - -There are a few ways to create a `SocketAddress`: - -- `init(ipAddress: String, port: Int)` -- `init(unixDomainSocketPath: String)` -- `makeAddressResolvingHost(_ host: String, port: Int)` - -There are also some additional arguments you can supply to `connect`. - -- `tlsConfiguration` An optional `TLSConfiguration` struct. If supplied, the PostgreSQL connection will be upgraded to use SSL. -- `serverHostname` An optional `String` to use in conjunction with `tlsConfiguration` to specify the server's hostname. - -`connect` will return a future `PostgresConnection`, or an error if it could not connect. Make sure you close the connection before it deinitializes. - -### Authentication - -Once you have a connection, you will need to authenticate with the server using the `authenticate` method. - +Add `PostgresNIO` to the target you want to use it in: ```swift -try conn.authenticate( - username: "vapor_username", - database: "vapor_database", - password: "vapor_password" -).wait() + targets: [ + .target(name: "MyFancyTarget", dependencies: [ + .product(name: "PostgresNIO", package: "postgres-nio"), + ]) + ] ``` -This requires a username. You may supply a database name and password if needed. - -### Database Protocol +#### Creating a client -Interaction with a server revolves around the `PostgresDatabase` protocol. This protocol includes methods like `query(_:)` for executing SQL queries and reading the resulting rows. - -`PostgresConnection` is the default implementation of `PostgresDatabase` provided by this package. Assume `db` here is the connection from the previous example. +To create a [`PostgresClient`], which pools connections for you, first create a configuration object: ```swift import PostgresNIO -let db: PostgresDatabase = ... -// now we can use client to do queries +let config = PostgresClient.Configuration( + host: "localhost", + port: 5432, + username: "my_username", + password: "my_password", + database: "my_database", + tls: .disable +) ``` -### Simple Query - -Simple (or text) queries allow you to execute a SQL string on the connected PostgreSQL server. These queries do not support binding parameters, so any values sent must be escaped manually. - -These queries are most useful for schema or transactional queries, or simple selects. Note that values returned by simple queries will be transferred in the less efficient text format. - -`simpleQuery` has two overloads, one that returns an array of rows, and one that accepts a closure for handling each row as it is returned. - +Next you can create you client with it: ```swift -let rows = try db.simpleQuery("SELECT version()").wait() -print(rows) // [["version": "12.x.x"]] - -try db.simpleQuery("SELECT version()") { row in - print(row) // ["version": "12.x.x"] -}.wait() +let client = PostgresClient(configuration: config) ``` -### Parameterized Query - -Parameterized (or binary) queries allow you to execute a SQL string on the connected PostgreSQL server. These queries support passing bound parameters as a separate argument. Each parameter is represented in the SQL string using incrementing placeholders, starting at `$1`. - -These queries are most useful for selecting, inserting, and updating data. Data for these queries is transferred using the highly efficient binary format. - -Just like `simpleQuery`, `query` also offers two overloads. One that returns an array of rows, and one that accepts a closure for handling each row as it is returned. - +Once you have create your client, you must [`run()`] it: ```swift -let rows = try db.query("SELECT * FROM planets WHERE name = $1", ["Earth"]).wait() -print(rows) // [["id": 42, "name": "Earth"]] +await withTaskGroup(of: Void.self) { taskGroup in + taskGroup.addTask { + await client.run() // !important + } -try db.query("SELECT * FROM planets WHERE name = $1", ["Earth"]) { row in - print(row) // ["id": 42, "name": "Earth"] -}.wait() + // You can use the client while the `client.run()` method is not cancelled. + + // To shutdown the client, cancel its run method, by cancelling the taskGroup. + taskGroup.cancelAll() +} ``` -### Rows and Data +#### Querying -Both `simpleQuery` and `query` return the same `PostgresRow` type. Columns can be fetched from the row using the `column(_: String)` method. +Once a client is running, queries can be sent to the server. This is straightforward: ```swift -let row: PostgresRow = ... -let version = row.column("version") -print(version) // PostgresData? +let rows = try await client.query("SELECT id, username, birthday FROM users") ``` -`PostgresRow` columns are stored as `PostgresData`. This struct contains the raw bytes returned by PostgreSQL as well as some information for parsing them, such as: - -- Postgres column type -- Wire format: binary or text -- Value as array of bytes - -`PostgresData` has a variety of convenience methods for converting column data to usable Swift types. +The query will return a [`PostgresRowSequence`], which is an AsyncSequence of [`PostgresRow`]s. +The rows can be iterated one-by-one: ```swift -let data: PostgresData= ... +for try await row in rows { + // do something with the row +} +``` -print(data.string) // String? +#### Decoding from PostgresRow -// Postgres only supports signed Ints. -print(data.int) // Int? -print(data.int16) // Int16? -print(data.int32) // Int32? -print(data.int64) // Int64? +However, in most cases it is much easier to request a row's fields as a set of Swift types: -// 'char' can be interpreted as a UInt8. -// It will show in db as a character though. -print(data.uint8) // UInt8? +```swift +for try await (id, username, birthday) in rows.decode((Int, String, Date).self) { + // do something with the datatypes. +} +``` -print(data.bool) // Bool? +A type must implement the [`PostgresDecodable`] protocol in order to be decoded from a row. PostgresNIO provides default implementations for most of Swift's builtin types, as well as some types provided by Foundation: -print(try data.jsonb(as: Foo.self)) // Foo? +- `Bool` +- `Bytes`, `Data`, `ByteBuffer` +- `Date` +- `UInt8`, `Int16`, `Int32`, `Int64`, `Int` +- `Float`, `Double` +- `String` +- `UUID` -print(data.float) // Float? -print(data.double) // Double? +#### Querying with parameters -print(data.date) // Date? -print(data.uuid) // UUID? +Sending parameterized queries to the database is also supported (in the coolest way possible): -print(data.numeric) // PostgresNumeric? +```swift +let id = 1 +let username = "fancyuser" +let birthday = Date() +try await client.query(""" + INSERT INTO users (id, username, birthday) VALUES (\(id), \(username), \(birthday)) + """, + logger: logger +) ``` -`PostgresData` is also used for sending data _to_ the server via parameterized values. To create `PostgresData` from a Swift type, use the available intializer methods. +While this looks at first glance like a classic case of [SQL injection](https://en.wikipedia.org/wiki/SQL_injection) 😱, PostgresNIO's API ensures that this usage is safe. The first parameter of the [`query(_:logger:)`] method is not a plain `String`, but a [`PostgresQuery`], which implements Swift's `ExpressibleByStringInterpolation` protocol. PostgresNIO uses the literal parts of the provided string as the SQL query and replaces each interpolated value with a parameter binding. Only values which implement the [`PostgresEncodable`] protocol may be interpolated in this way. As with [`PostgresDecodable`], PostgresNIO provides default implementations for most common types. + +Some queries do not receive any rows from the server (most often `INSERT`, `UPDATE`, and `DELETE` queries with no `RETURNING` clause, not to mention most DDL queries). To support this, the [`query(_:logger:)`] method is marked `@discardableResult`, so that the compiler does not issue a warning if the return value is not used. + +## Security + +Please see [SECURITY.md] for details on the security process. + +[SSWG Incubation]: https://github.com/swift-server/sswg/blob/main/process/incubation.md#graduated-level +[Documentation]: https://api.vapor.codes/postgresnio/documentation/postgresnio +[Team Chat]: https://discord.gg/vapor +[MIT License]: LICENSE +[Continuous Integration]: https://github.com/vapor/postgres-nio/actions +[Swift 5.8]: https://swift.org +[Security.md]: https://github.com/vapor/.github/blob/main/SECURITY.md + +[`PostgresConnection`]: https://api.vapor.codes/postgresnio/documentation/postgresnio/postgresconnection +[`PostgresClient`]: https://api.vapor.codes/postgresnio/documentation/postgresnio/postgresclient +[`run()`]: https://api.vapor.codes/postgresnio/documentation/postgresnio/postgresclient/run() +[`query(_:logger:)`]: https://api.vapor.codes/postgresnio/documentation/postgresnio/postgresconnection/query(_:logger:file:line:)-9mkfn +[`PostgresQuery`]: https://api.vapor.codes/postgresnio/documentation/postgresnio/postgresquery +[`PostgresRow`]: https://api.vapor.codes/postgresnio/documentation/postgresnio/postgresrow +[`PostgresRowSequence`]: https://api.vapor.codes/postgresnio/documentation/postgresnio/postgresrowsequence +[`PostgresDecodable`]: https://api.vapor.codes/postgresnio/documentation/postgresnio/postgresdecodable +[`PostgresEncodable`]: https://api.vapor.codes/postgresnio/documentation/postgresnio/postgresencodable +[SwiftNIO]: https://github.com/apple/swift-nio +[PostgresKit]: https://github.com/vapor/postgres-kit +[SwiftLog]: https://github.com/apple/swift-log +[ServiceLifecycle]: https://github.com/swift-server/swift-service-lifecycle +[`Logger`]: https://apple.github.io/swift-log/docs/current/Logging/Structs/Logger.html diff --git a/Snippets/Birthdays.swift b/Snippets/Birthdays.swift new file mode 100644 index 00000000..60516aa1 --- /dev/null +++ b/Snippets/Birthdays.swift @@ -0,0 +1,74 @@ +import PostgresNIO +import Foundation + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +enum Birthday { + static func main() async throws { + // 1. Create a configuration to match server's parameters + let config = PostgresClient.Configuration( + host: "localhost", + port: 5432, + username: "test_username", + password: "test_password", + database: "test_database", + tls: .disable + ) + + // 2. Create a client + let client = PostgresClient(configuration: config) + + // 3. Run the client + try await withThrowingTaskGroup(of: Void.self) { taskGroup in + taskGroup.addTask { + await client.run() // !important + } + + // 4. Create a friends table to store data into + try await client.query(""" + CREATE TABLE IF NOT EXISTS "friends" ( + id SERIAL PRIMARY KEY, + given_name TEXT, + last_name TEXT, + birthday TIMESTAMP WITH TIME ZONE + ) + """ + ) + + // 5. Create a Swift friend representation + struct Friend { + var firstName: String + var lastName: String + var birthday: Date + } + + // 6. Create John Appleseed with special birthday + let dateFormatter = DateFormatter() + dateFormatter.dateFormat = "yyyy-MM-dd" + let johnsBirthday = dateFormatter.date(from: "1960-09-26")! + let friend = Friend(firstName: "Hans", lastName: "Müller", birthday: johnsBirthday) + + // 7. Store friend into the database + try await client.query(""" + INSERT INTO "friends" (given_name, last_name, birthday) + VALUES + (\(friend.firstName), \(friend.lastName), \(friend.birthday)); + """ + ) + + // 8. Query database for the friend we just inserted + let rows = try await client.query(""" + SELECT id, given_name, last_name, birthday FROM "friends" WHERE given_name = \(friend.firstName) + """ + ) + + // 9. Iterate the returned rows, decoding the rows into Swift primitives + for try await (id, firstName, lastName, birthday) in rows.decode((Int, String, String, Date).self) { + print("\(id) | \(firstName) \(lastName), \(birthday)") + } + + // 10. Shutdown the client, by cancelling its run method, through cancelling the taskGroup. + taskGroup.cancelAll() + } + } +} + diff --git a/Snippets/PostgresClient.swift b/Snippets/PostgresClient.swift new file mode 100644 index 00000000..9bfacc28 --- /dev/null +++ b/Snippets/PostgresClient.swift @@ -0,0 +1,40 @@ +import PostgresNIO +import struct Foundation.UUID + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +enum Runner { + static func main() async throws { + +// snippet.configuration +let config = PostgresClient.Configuration( + host: "localhost", + port: 5432, + username: "my_username", + password: "my_password", + database: "my_database", + tls: .disable +) +// snippet.end + +// snippet.makeClient +let client = PostgresClient(configuration: config) +// snippet.end + + } + + static func runAndCancel(client: PostgresClient) async { +// snippet.run +await withTaskGroup(of: Void.self) { taskGroup in + taskGroup.addTask { + await client.run() // !important + } + + // You can use the client while the `client.run()` method is not cancelled. + + // To shutdown the client, cancel its run method, by cancelling the taskGroup. + taskGroup.cancelAll() +} +// snippet.end + } +} + diff --git a/Sources/ConnectionPoolModule/ConnectionIDGenerator.swift b/Sources/ConnectionPoolModule/ConnectionIDGenerator.swift new file mode 100644 index 00000000..b428d805 --- /dev/null +++ b/Sources/ConnectionPoolModule/ConnectionIDGenerator.swift @@ -0,0 +1,15 @@ +import Atomics + +public struct ConnectionIDGenerator: ConnectionIDGeneratorProtocol { + static let globalGenerator = ConnectionIDGenerator() + + private let atomic: ManagedAtomic + + public init() { + self.atomic = .init(0) + } + + public func next() -> Int { + return self.atomic.loadThenWrappingIncrement(ordering: .relaxed) + } +} diff --git a/Sources/ConnectionPoolModule/ConnectionPool.swift b/Sources/ConnectionPoolModule/ConnectionPool.swift new file mode 100644 index 00000000..5cdb980d --- /dev/null +++ b/Sources/ConnectionPoolModule/ConnectionPool.swift @@ -0,0 +1,597 @@ + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +public struct ConnectionAndMetadata: Sendable { + + public var connection: Connection + + public var maximalStreamsOnConnection: UInt16 + + public init(connection: Connection, maximalStreamsOnConnection: UInt16) { + self.connection = connection + self.maximalStreamsOnConnection = maximalStreamsOnConnection + } +} + +/// A connection that can be pooled in a ``ConnectionPool`` +public protocol PooledConnection: AnyObject, Sendable { + /// The connections identifier type. + associatedtype ID: Hashable & Sendable + + /// The connections identifier. The identifier is passed to + /// the connection factory method and must stay attached to + /// the connection at all times. It must not change during + /// the connections lifetime. + var id: ID { get } + + /// A method to register closures that are invoked when the + /// connection is closed. If the connection closed unexpectedly + /// the closure shall be called with the underlying error. + /// In most NIO clients this can be easily implemented by + /// attaching to the `channel.closeFuture`: + /// ``` + /// func onClose( + /// _ closure: @escaping @Sendable ((any Error)?) -> () + /// ) { + /// channel.closeFuture.whenComplete { _ in + /// closure(previousError) + /// } + /// } + /// ``` + func onClose(_ closure: @escaping @Sendable ((any Error)?) -> ()) + + /// Close the running connection. Once the close has completed + /// closures that were registered in `onClose` must be + /// invoked. + func close() +} + +/// A connection id generator. Its returned connection IDs will +/// be used when creating new ``PooledConnection``s +public protocol ConnectionIDGeneratorProtocol: Sendable { + /// The connections identifier type. + associatedtype ID: Hashable & Sendable + + /// The next connection ID that shall be used. + func next() -> ID +} + +/// A keep alive behavior for connections maintained by the pool +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +public protocol ConnectionKeepAliveBehavior: Sendable { + /// the connection type + associatedtype Connection: PooledConnection + + /// The time after which a keep-alive shall + /// be triggered. + /// If nil is returned, keep-alive is deactivated + var keepAliveFrequency: Duration? { get } + + /// This method is invoked when the keep-alive shall be + /// run. + func runKeepAlive(for connection: Connection) async throws +} + +/// A request to get a connection from the `ConnectionPool` +public protocol ConnectionRequestProtocol: Sendable { + /// A connection lease request ID type. + associatedtype ID: Hashable & Sendable + /// The leased connection type + associatedtype Connection: PooledConnection + + /// A connection lease request ID. This ID must be generated + /// by users of the `ConnectionPool` outside the + /// `ConnectionPool`. It is not generated inside the pool like + /// the `ConnectionID`s. The lease request ID must be unique + /// and must not change, if your implementing type is a + /// reference type. + var id: ID { get } + + /// A function that is called with a connection or a + /// `PoolError`. + func complete(with: Result) +} + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +public struct ConnectionPoolConfiguration: Sendable { + /// The minimum number of connections to preserve in the pool. + /// + /// If the pool is mostly idle and the remote servers closes + /// idle connections, + /// the `ConnectionPool` will initiate new outbound + /// connections proactively to avoid the number of available + /// connections dropping below this number. + public var minimumConnectionCount: Int + + /// Between the `minimumConnectionCount` and + /// `maximumConnectionSoftLimit` the connection pool creates + /// _preserved_ connections. Preserved connections are closed + /// if they have been idle for ``idleTimeout``. + public var maximumConnectionSoftLimit: Int + + /// The maximum number of connections for this pool, that can + /// exist at any point in time. The pool can create _overflow_ + /// connections, if all connections are leased, and the + /// `maximumConnectionHardLimit` > `maximumConnectionSoftLimit ` + /// Overflow connections are closed immediately as soon as they + /// become idle. + public var maximumConnectionHardLimit: Int + + /// The time that a _preserved_ idle connection stays in the + /// pool before it is closed. + public var idleTimeout: Duration + + /// initializer + public init() { + self.minimumConnectionCount = 0 + self.maximumConnectionSoftLimit = 16 + self.maximumConnectionHardLimit = 16 + self.idleTimeout = .seconds(60) + } +} + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +public final class ConnectionPool< + Connection: PooledConnection, + ConnectionID: Hashable & Sendable, + ConnectionIDGenerator: ConnectionIDGeneratorProtocol, + Request: ConnectionRequestProtocol, + RequestID: Hashable & Sendable, + KeepAliveBehavior: ConnectionKeepAliveBehavior, + ObservabilityDelegate: ConnectionPoolObservabilityDelegate, + Clock: _Concurrency.Clock +>: Sendable where + Connection.ID == ConnectionID, + ConnectionIDGenerator.ID == ConnectionID, + Request.Connection == Connection, + Request.ID == RequestID, + KeepAliveBehavior.Connection == Connection, + ObservabilityDelegate.ConnectionID == ConnectionID, + Clock.Duration == Duration +{ + public typealias ConnectionFactory = @Sendable (ConnectionID, ConnectionPool) async throws -> ConnectionAndMetadata + + @usableFromInline + typealias StateMachine = PoolStateMachine> + + @usableFromInline + let factory: ConnectionFactory + + @usableFromInline + let keepAliveBehavior: KeepAliveBehavior + + @usableFromInline + let observabilityDelegate: ObservabilityDelegate + + @usableFromInline + let clock: Clock + + @usableFromInline + let configuration: ConnectionPoolConfiguration + + @usableFromInline + struct State: Sendable { + @usableFromInline + var stateMachine: StateMachine + @usableFromInline + var lastConnectError: (any Error)? + } + + @usableFromInline let stateBox: NIOLockedValueBox + + private let requestIDGenerator = _ConnectionPoolModule.ConnectionIDGenerator() + + @usableFromInline + let eventStream: AsyncStream + + @usableFromInline + let eventContinuation: AsyncStream.Continuation + + public init( + configuration: ConnectionPoolConfiguration, + idGenerator: ConnectionIDGenerator, + requestType: Request.Type, + keepAliveBehavior: KeepAliveBehavior, + observabilityDelegate: ObservabilityDelegate, + clock: Clock, + connectionFactory: @escaping ConnectionFactory + ) { + self.clock = clock + self.factory = connectionFactory + self.keepAliveBehavior = keepAliveBehavior + self.observabilityDelegate = observabilityDelegate + self.configuration = configuration + var stateMachine = StateMachine( + configuration: .init(configuration, keepAliveBehavior: keepAliveBehavior), + generator: idGenerator, + timerCancellationTokenType: CheckedContinuation.self + ) + + let (stream, continuation) = AsyncStream.makeStream(of: NewPoolActions.self) + self.eventStream = stream + self.eventContinuation = continuation + + let connectionRequests = stateMachine.refillConnections() + + self.stateBox = NIOLockedValueBox(.init(stateMachine: stateMachine)) + + for request in connectionRequests { + self.eventContinuation.yield(.makeConnection(request)) + } + } + + @inlinable + public func releaseConnection(_ connection: Connection, streams: UInt16 = 1) { + self.modifyStateAndRunActions { state in + state.stateMachine.releaseConnection(connection, streams: streams) + } + } + + @inlinable + public func leaseConnection(_ request: Request) { + self.modifyStateAndRunActions { state in + state.stateMachine.leaseConnection(request) + } + } + + @inlinable + public func leaseConnections(_ requests: some Collection) { + let actions = self.stateBox.withLockedValue { state in + var actions = [StateMachine.Action]() + actions.reserveCapacity(requests.count) + + for request in requests { + let stateMachineAction = state.stateMachine.leaseConnection(request) + actions.append(stateMachineAction) + } + + return actions + } + + for action in actions { + self.runRequestAction(action.request) + self.runConnectionAction(action.connection) + } + } + + public func cancelLeaseConnection(_ requestID: RequestID) { + self.modifyStateAndRunActions { state in + state.stateMachine.cancelRequest(id: requestID) + } + } + + /// Mark a connection as going away. Connection implementors have to call this method if the connection + /// has received a close intent from the server. For example: an HTTP/2 GOWAY frame. + public func connectionWillClose(_ connection: Connection) { + + } + + public func connectionReceivedNewMaxStreamSetting(_ connection: Connection, newMaxStreamSetting maxStreams: UInt16) { + self.modifyStateAndRunActions { state in + state.stateMachine.connectionReceivedNewMaxStreamSetting(connection.id, newMaxStreamSetting: maxStreams) + } + } + + public func run() async { + await withTaskCancellationHandler { + #if os(Linux) || compiler(>=5.9) + if #available(macOS 14.0, iOS 17.0, tvOS 17.0, watchOS 10.0, *) { + return await withDiscardingTaskGroup() { taskGroup in + await self.run(in: &taskGroup) + } + } + #endif + return await withTaskGroup(of: Void.self) { taskGroup in + await self.run(in: &taskGroup) + } + } onCancel: { + let actions = self.stateBox.withLockedValue { state in + state.stateMachine.triggerForceShutdown() + } + + self.runStateMachineActions(actions) + } + } + + // MARK: - Private Methods - + + @inlinable + func connectionDidClose(_ connection: Connection, error: (any Error)?) { + self.observabilityDelegate.connectionClosed(id: connection.id, error: error) + + self.modifyStateAndRunActions { state in + state.stateMachine.connectionClosed(connection) + } + } + + // MARK: Events + + @usableFromInline + enum NewPoolActions: Sendable { + case makeConnection(StateMachine.ConnectionRequest) + case runKeepAlive(Connection) + + case scheduleTimer(StateMachine.Timer) + } + + #if os(Linux) || compiler(>=5.9) + @available(macOS 14.0, iOS 17.0, tvOS 17.0, watchOS 10.0, *) + private func run(in taskGroup: inout DiscardingTaskGroup) async { + for await event in self.eventStream { + self.runEvent(event, in: &taskGroup) + } + } + #endif + + private func run(in taskGroup: inout TaskGroup) async { + var running = 0 + for await event in self.eventStream { + running += 1 + self.runEvent(event, in: &taskGroup) + + if running == 100 { + _ = await taskGroup.next() + running -= 1 + } + } + } + + private func runEvent(_ event: NewPoolActions, in taskGroup: inout some TaskGroupProtocol) { + switch event { + case .makeConnection(let request): + self.makeConnection(for: request, in: &taskGroup) + + case .runKeepAlive(let connection): + self.runKeepAlive(connection, in: &taskGroup) + + case .scheduleTimer(let timer): + self.runTimer(timer, in: &taskGroup) + } + } + + // MARK: Run actions + + @inlinable + /*private*/ func modifyStateAndRunActions(_ closure: (inout State) -> StateMachine.Action) { + let actions = self.stateBox.withLockedValue { state -> StateMachine.Action in + closure(&state) + } + self.runStateMachineActions(actions) + } + + @inlinable + /*private*/ func runStateMachineActions(_ actions: StateMachine.Action) { + self.runConnectionAction(actions.connection) + self.runRequestAction(actions.request) + } + + @inlinable + /*private*/ func runConnectionAction(_ action: StateMachine.ConnectionAction) { + switch action { + case .makeConnection(let request, let timers): + self.cancelTimers(timers) + self.eventContinuation.yield(.makeConnection(request)) + + case .runKeepAlive(let connection, let cancelContinuation): + cancelContinuation?.resume(returning: ()) + self.eventContinuation.yield(.runKeepAlive(connection)) + + case .scheduleTimers(let timers): + for timer in timers { + self.eventContinuation.yield(.scheduleTimer(timer)) + } + + case .cancelTimers(let timers): + self.cancelTimers(timers) + + case .closeConnection(let connection, let timers): + self.closeConnection(connection) + self.cancelTimers(timers) + + case .shutdown(let cleanup): + for connection in cleanup.connections { + self.closeConnection(connection) + } + self.cancelTimers(cleanup.timersToCancel) + + case .none: + break + } + } + + @inlinable + /*private*/ func runRequestAction(_ action: StateMachine.RequestAction) { + switch action { + case .leaseConnection(let requests, let connection): + for request in requests { + request.complete(with: .success(connection)) + } + + case .failRequest(let request, let error): + request.complete(with: .failure(error)) + + case .failRequests(let requests, let error): + for request in requests { request.complete(with: .failure(error)) } + + case .none: + break + } + } + + @inlinable + /*private*/ func makeConnection(for request: StateMachine.ConnectionRequest, in taskGroup: inout some TaskGroupProtocol) { + taskGroup.addTask_ { + self.observabilityDelegate.startedConnecting(id: request.connectionID) + + do { + let bundle = try await self.factory(request.connectionID, self) + self.connectionEstablished(bundle) + + // after the connection has been established, we keep the task open. This ensures + // that the pools run method can not be exited before all connections have been + // closed. + await withCheckedContinuation { (continuation: CheckedContinuation) in + bundle.connection.onClose { + self.connectionDidClose(bundle.connection, error: $0) + continuation.resume() + } + } + } catch { + self.connectionEstablishFailed(error, for: request) + } + } + } + + @inlinable + /*private*/ func connectionEstablished(_ connectionBundle: ConnectionAndMetadata) { + self.observabilityDelegate.connectSucceeded(id: connectionBundle.connection.id, streamCapacity: connectionBundle.maximalStreamsOnConnection) + + self.modifyStateAndRunActions { state in + state.lastConnectError = nil + return state.stateMachine.connectionEstablished( + connectionBundle.connection, + maxStreams: connectionBundle.maximalStreamsOnConnection + ) + } + } + + @inlinable + /*private*/ func connectionEstablishFailed(_ error: Error, for request: StateMachine.ConnectionRequest) { + self.observabilityDelegate.connectFailed(id: request.connectionID, error: error) + + self.modifyStateAndRunActions { state in + state.lastConnectError = error + return state.stateMachine.connectionEstablishFailed(error, for: request) + } + } + + @inlinable + /*private*/ func runKeepAlive(_ connection: Connection, in taskGroup: inout some TaskGroupProtocol) { + self.observabilityDelegate.keepAliveTriggered(id: connection.id) + + taskGroup.addTask_ { + do { + try await self.keepAliveBehavior.runKeepAlive(for: connection) + + self.observabilityDelegate.keepAliveSucceeded(id: connection.id) + + self.modifyStateAndRunActions { state in + state.stateMachine.connectionKeepAliveDone(connection) + } + } catch { + self.observabilityDelegate.keepAliveFailed(id: connection.id, error: error) + + self.modifyStateAndRunActions { state in + state.stateMachine.connectionKeepAliveFailed(connection.id) + } + } + } + } + + @inlinable + /*private*/ func closeConnection(_ connection: Connection) { + self.observabilityDelegate.connectionClosing(id: connection.id) + + connection.close() + } + + @usableFromInline + enum TimerRunResult: Sendable { + case timerTriggered + case timerCancelled + case cancellationContinuationFinished + } + + @inlinable + /*private*/ func runTimer(_ timer: StateMachine.Timer, in poolGroup: inout some TaskGroupProtocol) { + poolGroup.addTask_ { () async -> () in + await withTaskGroup(of: TimerRunResult.self, returning: Void.self) { taskGroup in + taskGroup.addTask { + do { + #if os(Linux) || compiler(>=5.9) + try await self.clock.sleep(for: timer.duration) + #else + try await self.clock.sleep(until: self.clock.now.advanced(by: timer.duration), tolerance: nil) + #endif + return .timerTriggered + } catch { + return .timerCancelled + } + } + + taskGroup.addTask { + await withCheckedContinuation { (continuation: CheckedContinuation) in + let continuation = self.stateBox.withLockedValue { state in + state.stateMachine.timerScheduled(timer, cancelContinuation: continuation) + } + + continuation?.resume(returning: ()) + } + + return .cancellationContinuationFinished + } + + switch await taskGroup.next()! { + case .cancellationContinuationFinished: + taskGroup.cancelAll() + + case .timerTriggered: + let action = self.stateBox.withLockedValue { state in + state.stateMachine.timerTriggered(timer) + } + + self.runStateMachineActions(action) + + case .timerCancelled: + // the only way to reach this, is if the state machine decided to cancel the + // timer. therefore we don't need to report it back! + break + } + + return + } + } + } + + @inlinable + /*private*/ func cancelTimers(_ cancellationTokens: some Sequence>) { + for token in cancellationTokens { + token.resume() + } + } +} + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +extension PoolConfiguration { + init(_ configuration: ConnectionPoolConfiguration, keepAliveBehavior: KeepAliveBehavior) { + self.minimumConnectionCount = configuration.minimumConnectionCount + self.maximumConnectionSoftLimit = configuration.maximumConnectionSoftLimit + self.maximumConnectionHardLimit = configuration.maximumConnectionHardLimit + self.keepAliveDuration = keepAliveBehavior.keepAliveFrequency + self.idleTimeoutDuration = configuration.idleTimeout + } +} + +@usableFromInline +protocol TaskGroupProtocol { + // We need to call this `addTask_` because some Swift versions define this + // under exactly this name and others have different attributes. So let's pick + // a name that doesn't clash anywhere and implement it using the standard `addTask`. + mutating func addTask_(operation: @escaping @Sendable () async -> Void) +} + +#if os(Linux) || swift(>=5.9) +@available(macOS 14.0, iOS 17.0, tvOS 17.0, watchOS 10.0, *) +extension DiscardingTaskGroup: TaskGroupProtocol { + @inlinable + mutating func addTask_(operation: @escaping @Sendable () async -> Void) { + self.addTask(priority: nil, operation: operation) + } +} +#endif + +extension TaskGroup: TaskGroupProtocol { + @inlinable + mutating func addTask_(operation: @escaping @Sendable () async -> Void) { + self.addTask(priority: nil, operation: operation) + } +} diff --git a/Sources/ConnectionPoolModule/ConnectionPoolError.swift b/Sources/ConnectionPoolModule/ConnectionPoolError.swift new file mode 100644 index 00000000..1f1e1d2c --- /dev/null +++ b/Sources/ConnectionPoolModule/ConnectionPoolError.swift @@ -0,0 +1,16 @@ + +public struct ConnectionPoolError: Error, Hashable { + enum Base: Error, Hashable { + case requestCancelled + case poolShutdown + } + + private let base: Base + + init(_ base: Base) { self.base = base } + + /// The connection requests got cancelled + public static let requestCancelled = ConnectionPoolError(.requestCancelled) + /// The connection requests can't be fulfilled as the pool has already been shutdown + public static let poolShutdown = ConnectionPoolError(.poolShutdown) +} diff --git a/Sources/ConnectionPoolModule/ConnectionPoolObservabilityDelegate.swift b/Sources/ConnectionPoolModule/ConnectionPoolObservabilityDelegate.swift new file mode 100644 index 00000000..fc1e300c --- /dev/null +++ b/Sources/ConnectionPoolModule/ConnectionPoolObservabilityDelegate.swift @@ -0,0 +1,62 @@ + +public protocol ConnectionPoolObservabilityDelegate: Sendable { + associatedtype ConnectionID: Hashable & Sendable + + /// The connection with the given ID has started trying to establish a connection. The outcome + /// of the connection will be reported as either ``connectSucceeded(id:streamCapacity:)`` or + /// ``connectFailed(id:error:)``. + func startedConnecting(id: ConnectionID) + + /// A connection attempt failed with the given error. After some period of + /// time ``startedConnecting(id:)`` may be called again. + func connectFailed(id: ConnectionID, error: Error) + + /// A connection was established on the connection with the given ID. `streamCapacity` streams are + /// available to use on the connection. The maximum number of available streams may change over + /// time and is reported via ````. The + func connectSucceeded(id: ConnectionID, streamCapacity: UInt16) + + /// The utlization of the connection changed; a stream may have been used, returned or the + /// maximum number of concurrent streams available on the connection changed. + func connectionUtilizationChanged(id:ConnectionID, streamsUsed: UInt16, streamCapacity: UInt16) + + func keepAliveTriggered(id: ConnectionID) + + func keepAliveSucceeded(id: ConnectionID) + + func keepAliveFailed(id: ConnectionID, error: Error) + + /// The remote peer is quiescing the connection: no new streams will be created on it. The + /// connection will eventually be closed and removed from the pool. + func connectionClosing(id: ConnectionID) + + /// The connection was closed. The connection may be established again in the future (notified + /// via ``startedConnecting(id:)``). + func connectionClosed(id: ConnectionID, error: Error?) + + func requestQueueDepthChanged(_ newDepth: Int) +} + +public struct NoOpConnectionPoolMetrics: ConnectionPoolObservabilityDelegate { + public init(connectionIDType: ConnectionID.Type) {} + + public func startedConnecting(id: ConnectionID) {} + + public func connectFailed(id: ConnectionID, error: Error) {} + + public func connectSucceeded(id: ConnectionID, streamCapacity: UInt16) {} + + public func connectionUtilizationChanged(id: ConnectionID, streamsUsed: UInt16, streamCapacity: UInt16) {} + + public func keepAliveTriggered(id: ConnectionID) {} + + public func keepAliveSucceeded(id: ConnectionID) {} + + public func keepAliveFailed(id: ConnectionID, error: Error) {} + + public func connectionClosing(id: ConnectionID) {} + + public func connectionClosed(id: ConnectionID, error: Error?) {} + + public func requestQueueDepthChanged(_ newDepth: Int) {} +} diff --git a/Sources/ConnectionPoolModule/ConnectionRequest.swift b/Sources/ConnectionPoolModule/ConnectionRequest.swift new file mode 100644 index 00000000..1d1c55da --- /dev/null +++ b/Sources/ConnectionPoolModule/ConnectionRequest.swift @@ -0,0 +1,78 @@ + +public struct ConnectionRequest: ConnectionRequestProtocol { + public typealias ID = Int + + public var id: ID + + @usableFromInline + private(set) var continuation: CheckedContinuation + + @inlinable + init( + id: Int, + continuation: CheckedContinuation + ) { + self.id = id + self.continuation = continuation + } + + public func complete(with result: Result) { + self.continuation.resume(with: result) + } +} + +@usableFromInline +let requestIDGenerator = _ConnectionPoolModule.ConnectionIDGenerator() + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +extension ConnectionPool where Request == ConnectionRequest { + public convenience init( + configuration: ConnectionPoolConfiguration, + idGenerator: ConnectionIDGenerator = _ConnectionPoolModule.ConnectionIDGenerator(), + keepAliveBehavior: KeepAliveBehavior, + observabilityDelegate: ObservabilityDelegate, + clock: Clock = ContinuousClock(), + connectionFactory: @escaping ConnectionFactory + ) { + self.init( + configuration: configuration, + idGenerator: idGenerator, + requestType: ConnectionRequest.self, + keepAliveBehavior: keepAliveBehavior, + observabilityDelegate: observabilityDelegate, + clock: clock, + connectionFactory: connectionFactory + ) + } + + @inlinable + public func leaseConnection() async throws -> Connection { + let requestID = requestIDGenerator.next() + + let connection = try await withTaskCancellationHandler { + if Task.isCancelled { + throw CancellationError() + } + + return try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in + let request = Request( + id: requestID, + continuation: continuation + ) + + self.leaseConnection(request) + } + } onCancel: { + self.cancelLeaseConnection(requestID) + } + + return connection + } + + @inlinable + public func withConnection(_ closure: (Connection) async throws -> Result) async throws -> Result { + let connection = try await self.leaseConnection() + defer { self.releaseConnection(connection) } + return try await closure(connection) + } +} diff --git a/Sources/ConnectionPoolModule/Max2Sequence.swift b/Sources/ConnectionPoolModule/Max2Sequence.swift new file mode 100644 index 00000000..9b7d972b --- /dev/null +++ b/Sources/ConnectionPoolModule/Max2Sequence.swift @@ -0,0 +1,104 @@ +// A `Sequence` that can contain at most two elements. However it does not heap allocate. +@usableFromInline +struct Max2Sequence: Sequence { + @usableFromInline + private(set) var first: Element? + @usableFromInline + private(set) var second: Element? + + @inlinable + var count: Int { + if self.first == nil { return 0 } + if self.second == nil { return 1 } + return 2 + } + + @inlinable + var isEmpty: Bool { + self.first == nil + } + + @inlinable + init(_ first: Element?, _ second: Element? = nil) { + if let first = first { + self.first = first + self.second = second + } else { + self.first = second + self.second = nil + } + } + + @inlinable + init() { + self.first = nil + self.second = nil + } + + @inlinable + func makeIterator() -> Iterator { + Iterator(first: self.first, second: self.second) + } + + @usableFromInline + struct Iterator: IteratorProtocol { + @usableFromInline + let first: Element? + @usableFromInline + let second: Element? + + @usableFromInline + private(set) var index: UInt8 = 0 + + @inlinable + init(first: Element?, second: Element?) { + self.first = first + self.second = second + self.index = 0 + } + + @inlinable + mutating func next() -> Element? { + switch self.index { + case 0: + self.index += 1 + return self.first + case 1: + self.index += 1 + return self.second + default: + return nil + } + } + } + + @inlinable + mutating func append(_ element: Element) { + precondition(self.second == nil) + if self.first == nil { + self.first = element + } else if self.second == nil { + self.second = element + } else { + fatalError("Max2Sequence can only hold two Elements.") + } + } + + @inlinable + func map(_ transform: (Element) throws -> (NewElement)) rethrows -> Max2Sequence { + try Max2Sequence(self.first.flatMap(transform), self.second.flatMap(transform)) + } +} + +extension Max2Sequence: ExpressibleByArrayLiteral { + @inlinable + init(arrayLiteral elements: Element...) { + precondition(elements.count <= 2) + var iterator = elements.makeIterator() + self.init(iterator.next(), iterator.next()) + } +} + +extension Max2Sequence: Equatable where Element: Equatable {} +extension Max2Sequence: Hashable where Element: Hashable {} +extension Max2Sequence: Sendable where Element: Sendable {} diff --git a/Sources/ConnectionPoolModule/NIOLock.swift b/Sources/ConnectionPoolModule/NIOLock.swift new file mode 100644 index 00000000..b6cd7164 --- /dev/null +++ b/Sources/ConnectionPoolModule/NIOLock.swift @@ -0,0 +1,279 @@ +// Implementation vendored from SwiftNIO: +// https://github.com/apple/swift-nio + +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftNIO open source project +// +// Copyright (c) 2017-2022 Apple Inc. and the SwiftNIO project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftNIO project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +#if canImport(Darwin) +import Darwin +#elseif os(Windows) +import ucrt +import WinSDK +#elseif canImport(Glibc) +import Glibc +#elseif canImport(Musl) +import Musl +#elseif canImport(Bionic) +import Bionic +#elseif canImport(WASILibc) +import WASILibc +#if canImport(wasi_pthread) +import wasi_pthread +#endif +#else +#error("The concurrency NIOLock module was unable to identify your C library.") +#endif + +#if os(Windows) +@usableFromInline +typealias LockPrimitive = SRWLOCK +#else +@usableFromInline +typealias LockPrimitive = pthread_mutex_t +#endif + +@usableFromInline +enum LockOperations {} + +extension LockOperations { + @inlinable + static func create(_ mutex: UnsafeMutablePointer) { + mutex.assertValidAlignment() + + #if os(Windows) + InitializeSRWLock(mutex) + #elseif (compiler(<6.1) && !os(WASI)) || (compiler(>=6.1) && _runtime(_multithreaded)) + var attr = pthread_mutexattr_t() + pthread_mutexattr_init(&attr) + debugOnly { + pthread_mutexattr_settype(&attr, .init(PTHREAD_MUTEX_ERRORCHECK)) + } + + let err = pthread_mutex_init(mutex, &attr) + precondition(err == 0, "\(#function) failed in pthread_mutex with error \(err)") + #endif + } + + @inlinable + static func destroy(_ mutex: UnsafeMutablePointer) { + mutex.assertValidAlignment() + + #if os(Windows) + // SRWLOCK does not need to be free'd + #elseif (compiler(<6.1) && !os(WASI)) || (compiler(>=6.1) && _runtime(_multithreaded)) + let err = pthread_mutex_destroy(mutex) + precondition(err == 0, "\(#function) failed in pthread_mutex with error \(err)") + #endif + } + + @inlinable + static func lock(_ mutex: UnsafeMutablePointer) { + mutex.assertValidAlignment() + + #if os(Windows) + AcquireSRWLockExclusive(mutex) + #elseif (compiler(<6.1) && !os(WASI)) || (compiler(>=6.1) && _runtime(_multithreaded)) + let err = pthread_mutex_lock(mutex) + precondition(err == 0, "\(#function) failed in pthread_mutex with error \(err)") + #endif + } + + @inlinable + static func unlock(_ mutex: UnsafeMutablePointer) { + mutex.assertValidAlignment() + + #if os(Windows) + ReleaseSRWLockExclusive(mutex) + #elseif (compiler(<6.1) && !os(WASI)) || (compiler(>=6.1) && _runtime(_multithreaded)) + let err = pthread_mutex_unlock(mutex) + precondition(err == 0, "\(#function) failed in pthread_mutex with error \(err)") + #endif + } +} + +// Tail allocate both the mutex and a generic value using ManagedBuffer. +// Both the header pointer and the elements pointer are stable for +// the class's entire lifetime. +// +// However, for safety reasons, we elect to place the lock in the "elements" +// section of the buffer instead of the head. The reasoning here is subtle, +// so buckle in. +// +// _As a practical matter_, the implementation of ManagedBuffer ensures that +// the pointer to the header is stable across the lifetime of the class, and so +// each time you call `withUnsafeMutablePointers` or `withUnsafeMutablePointerToHeader` +// the value of the header pointer will be the same. This is because ManagedBuffer uses +// `Builtin.addressOf` to load the value of the header, and that does ~magic~ to ensure +// that it does not invoke any weird Swift accessors that might copy the value. +// +// _However_, the header is also available via the `.header` field on the ManagedBuffer. +// This presents a problem! The reason there's an issue is that `Builtin.addressOf` and friends +// do not interact with Swift's exclusivity model. That is, the various `with` functions do not +// conceptually trigger a mutating access to `.header`. For elements this isn't a concern because +// there's literally no other way to perform the access, but for `.header` it's entirely possible +// to accidentally recursively read it. +// +// Our implementation is free from these issues, so we don't _really_ need to worry about it. +// However, out of an abundance of caution, we store the Value in the header, and the LockPrimitive +// in the trailing elements. We still don't use `.header`, but it's better to be safe than sorry, +// and future maintainers will be happier that we were cautious. +// +// See also: https://github.com/apple/swift/pull/40000 +@usableFromInline +final class LockStorage: ManagedBuffer { + + @inlinable + static func create(value: Value) -> Self { + let buffer = Self.create(minimumCapacity: 1) { _ in + value + } + // Intentionally using a force cast here to avoid a miss compiliation in 5.10. + // This is as fast as an unsafeDownCast since ManagedBuffer is inlined and the optimizer + // can eliminate the upcast/downcast pair + let storage = buffer as! Self + + storage.withUnsafeMutablePointers { _, lockPtr in + LockOperations.create(lockPtr) + } + + return storage + } + + @inlinable + func lock() { + self.withUnsafeMutablePointerToElements { lockPtr in + LockOperations.lock(lockPtr) + } + } + + @inlinable + func unlock() { + self.withUnsafeMutablePointerToElements { lockPtr in + LockOperations.unlock(lockPtr) + } + } + + @inlinable + deinit { + self.withUnsafeMutablePointerToElements { lockPtr in + LockOperations.destroy(lockPtr) + } + } + + @inlinable + func withLockPrimitive(_ body: (UnsafeMutablePointer) throws -> T) rethrows -> T { + try self.withUnsafeMutablePointerToElements { lockPtr in + try body(lockPtr) + } + } + + @inlinable + func withLockedValue(_ mutate: (inout Value) throws -> T) rethrows -> T { + try self.withUnsafeMutablePointers { valuePtr, lockPtr in + LockOperations.lock(lockPtr) + defer { LockOperations.unlock(lockPtr) } + return try mutate(&valuePtr.pointee) + } + } +} + +/// A threading lock based on `libpthread` instead of `libdispatch`. +/// +/// - Note: ``NIOLock`` has reference semantics. +/// +/// This object provides a lock on top of a single `pthread_mutex_t`. This kind +/// of lock is safe to use with `libpthread`-based threading models, such as the +/// one used by NIO. On Windows, the lock is based on the substantially similar +/// `SRWLOCK` type. +struct NIOLock { + @usableFromInline + internal let _storage: LockStorage + + /// Create a new lock. + @inlinable + init() { + self._storage = .create(value: ()) + } + + /// Acquire the lock. + /// + /// Whenever possible, consider using `withLock` instead of this method and + /// `unlock`, to simplify lock handling. + @inlinable + func lock() { + self._storage.lock() + } + + /// Release the lock. + /// + /// Whenever possible, consider using `withLock` instead of this method and + /// `lock`, to simplify lock handling. + @inlinable + func unlock() { + self._storage.unlock() + } + + @inlinable + internal func withLockPrimitive(_ body: (UnsafeMutablePointer) throws -> T) rethrows -> T { + try self._storage.withLockPrimitive(body) + } +} + +extension NIOLock { + /// Acquire the lock for the duration of the given block. + /// + /// This convenience method should be preferred to `lock` and `unlock` in + /// most situations, as it ensures that the lock will be released regardless + /// of how `body` exits. + /// + /// - Parameter body: The block to execute while holding the lock. + /// - Returns: The value returned by the block. + @inlinable + func withLock(_ body: () throws -> T) rethrows -> T { + self.lock() + defer { + self.unlock() + } + return try body() + } + + @inlinable + func withLockVoid(_ body: () throws -> Void) rethrows { + try self.withLock(body) + } +} + +extension NIOLock: @unchecked Sendable {} + +extension UnsafeMutablePointer { + @inlinable + func assertValidAlignment() { + assert(UInt(bitPattern: self) % UInt(MemoryLayout.alignment) == 0) + } +} + +/// A utility function that runs the body code only in debug builds, without +/// emitting compiler warnings. +/// +/// This is currently the only way to do this in Swift: see +/// https://forums.swift.org/t/support-debug-only-code/11037 for a discussion. +@inlinable +internal func debugOnly(_ body: () -> Void) { + assert( + { + body() + return true + }() + ) +} diff --git a/Sources/ConnectionPoolModule/NIOLockedValueBox.swift b/Sources/ConnectionPoolModule/NIOLockedValueBox.swift new file mode 100644 index 00000000..c9cd89e0 --- /dev/null +++ b/Sources/ConnectionPoolModule/NIOLockedValueBox.swift @@ -0,0 +1,86 @@ +// Implementation vendored from SwiftNIO: +// https://github.com/apple/swift-nio + +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftNIO open source project +// +// Copyright (c) 2022 Apple Inc. and the SwiftNIO project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftNIO project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +/// Provides locked access to `Value`. +/// +/// - Note: ``NIOLockedValueBox`` has reference semantics and holds the `Value` +/// alongside a lock behind a reference. +/// +/// This is no different than creating a ``Lock`` and protecting all +/// accesses to a value using the lock. But it's easy to forget to actually +/// acquire/release the lock in the correct place. ``NIOLockedValueBox`` makes +/// that much easier. +@usableFromInline +struct NIOLockedValueBox { + + @usableFromInline + internal let _storage: LockStorage + + /// Initialize the `Value`. + @inlinable + init(_ value: Value) { + self._storage = .create(value: value) + } + + /// Access the `Value`, allowing mutation of it. + @inlinable + func withLockedValue(_ mutate: (inout Value) throws -> T) rethrows -> T { + try self._storage.withLockedValue(mutate) + } + + /// Provides an unsafe view over the lock and its value. + /// + /// This can be beneficial when you require fine grained control over the lock in some + /// situations but don't want lose the benefits of ``withLockedValue(_:)`` in others by + /// switching to ``NIOLock``. + var unsafe: Unsafe { + Unsafe(_storage: self._storage) + } + + /// Provides an unsafe view over the lock and its value. + struct Unsafe { + @usableFromInline + let _storage: LockStorage + + /// Manually acquire the lock. + @inlinable + func lock() { + self._storage.lock() + } + + /// Manually release the lock. + @inlinable + func unlock() { + self._storage.unlock() + } + + /// Mutate the value, assuming the lock has been acquired manually. + /// + /// - Parameter mutate: A closure with scoped access to the value. + /// - Returns: The result of the `mutate` closure. + @inlinable + func withValueAssumingLockIsAcquired( + _ mutate: (_ value: inout Value) throws -> Result + ) rethrows -> Result { + try self._storage.withUnsafeMutablePointerToHeader { value in + try mutate(&value.pointee) + } + } + } +} + +extension NIOLockedValueBox: @unchecked Sendable where Value: Sendable {} diff --git a/Sources/ConnectionPoolModule/NoKeepAliveBehavior.swift b/Sources/ConnectionPoolModule/NoKeepAliveBehavior.swift new file mode 100644 index 00000000..0a7b2dee --- /dev/null +++ b/Sources/ConnectionPoolModule/NoKeepAliveBehavior.swift @@ -0,0 +1,8 @@ +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +public struct NoOpKeepAliveBehavior: ConnectionKeepAliveBehavior { + public var keepAliveFrequency: Duration? { nil } + + public func runKeepAlive(for connection: Connection) async throws {} + + public init(connectionType: Connection.Type) {} +} diff --git a/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionGroup.swift b/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionGroup.swift new file mode 100644 index 00000000..a8e97ffd --- /dev/null +++ b/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionGroup.swift @@ -0,0 +1,733 @@ +import Atomics + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +extension PoolStateMachine { + + @usableFromInline + struct LeaseResult { + @usableFromInline + var connection: Connection + @usableFromInline + var timersToCancel: Max2Sequence + @usableFromInline + var wasIdle: Bool + @usableFromInline + var use: ConnectionGroup.ConnectionUse + + @inlinable + init( + connection: Connection, + timersToCancel: Max2Sequence, + wasIdle: Bool, + use: ConnectionGroup.ConnectionUse + ) { + self.connection = connection + self.timersToCancel = timersToCancel + self.wasIdle = wasIdle + self.use = use + } + } + + @usableFromInline + struct ConnectionGroup: Sendable { + @usableFromInline + struct Stats: Hashable, Sendable { + @usableFromInline var connecting: UInt16 = 0 + @usableFromInline var backingOff: UInt16 = 0 + @usableFromInline var idle: UInt16 = 0 + @usableFromInline var leased: UInt16 = 0 + @usableFromInline var runningKeepAlive: UInt16 = 0 + @usableFromInline var closing: UInt16 = 0 + + @usableFromInline var availableStreams: UInt16 = 0 + @usableFromInline var leasedStreams: UInt16 = 0 + + @usableFromInline var soonAvailable: UInt16 { + self.connecting + self.backingOff + self.runningKeepAlive + } + + @usableFromInline var active: UInt16 { + self.idle + self.leased + self.connecting + self.backingOff + } + } + + /// The minimum number of connections + @usableFromInline + let minimumConcurrentConnections: Int + + /// The maximum number of preserved connections + @usableFromInline + let maximumConcurrentConnectionSoftLimit: Int + + /// The absolute maximum number of connections + @usableFromInline + let maximumConcurrentConnectionHardLimit: Int + + @usableFromInline + let keepAlive: Bool + + @usableFromInline + let keepAliveReducesAvailableStreams: Bool + + /// A connectionID generator. + @usableFromInline + let generator: ConnectionIDGenerator + + /// The connections states + @usableFromInline + private(set) var connections: [ConnectionState] + + @usableFromInline + private(set) var stats = Stats() + + @inlinable + init( + generator: ConnectionIDGenerator, + minimumConcurrentConnections: Int, + maximumConcurrentConnectionSoftLimit: Int, + maximumConcurrentConnectionHardLimit: Int, + keepAlive: Bool, + keepAliveReducesAvailableStreams: Bool + ) { + self.generator = generator + self.connections = [] + self.minimumConcurrentConnections = minimumConcurrentConnections + self.maximumConcurrentConnectionSoftLimit = maximumConcurrentConnectionSoftLimit + self.maximumConcurrentConnectionHardLimit = maximumConcurrentConnectionHardLimit + self.keepAlive = keepAlive + self.keepAliveReducesAvailableStreams = keepAliveReducesAvailableStreams + } + + var isEmpty: Bool { + self.connections.isEmpty + } + + @usableFromInline + var canGrow: Bool { + self.stats.active < self.maximumConcurrentConnectionHardLimit + } + + @usableFromInline + var soonAvailableConnections: UInt16 { + self.stats.soonAvailable + } + + // MARK: - Mutations - + + /// A connection's use. Is it persisted or an overflow connection? + @usableFromInline + enum ConnectionUse: Equatable { + case persisted + case demand + case overflow + } + + /// Information around an idle connection. + @usableFromInline + struct AvailableConnectionContext { + /// The connection's use. Either general purpose or for requests with `EventLoop` + /// requirements. + @usableFromInline + var use: ConnectionUse + + @usableFromInline + var info: ConnectionAvailableInfo + + @inlinable + init(use: ConnectionUse, info: ConnectionAvailableInfo) { + self.use = use + self.info = info + } + } + + mutating func refillConnections() -> [ConnectionRequest] { + let existingConnections = self.stats.active + let missingConnection = self.minimumConcurrentConnections - Int(existingConnections) + guard missingConnection > 0 else { + return [] + } + + var requests = [ConnectionRequest]() + requests.reserveCapacity(missingConnection) + + for _ in 0.. ConnectionRequest? { + precondition(self.minimumConcurrentConnections <= self.stats.active) + guard self.maximumConcurrentConnectionSoftLimit > self.stats.active else { + return nil + } + return self.createNewConnection() + } + + @inlinable + mutating func createNewOverflowConnectionIfPossible() -> ConnectionRequest? { + precondition(self.maximumConcurrentConnectionSoftLimit <= self.stats.active) + guard self.maximumConcurrentConnectionHardLimit > self.stats.active else { + return nil + } + return self.createNewConnection() + } + + @inlinable + /*private*/ mutating func createNewConnection() -> ConnectionRequest { + precondition(self.canGrow) + self.stats.connecting += 1 + let connectionID = self.generator.next() + let connection = ConnectionState(id: connectionID) + self.connections.append(connection) + return ConnectionRequest(connectionID: connectionID) + } + + /// A new ``Connection`` was established. + /// + /// This will put the connection into the idle state. + /// + /// - Parameter connection: The new established connection. + /// - Returns: An index and an IdleConnectionContext to determine the next action for the now idle connection. + /// Call ``parkConnection(at:)``, ``leaseConnection(at:)`` or ``closeConnection(at:)`` + /// with the supplied index after this. + @inlinable + mutating func newConnectionEstablished(_ connection: Connection, maxStreams: UInt16) -> (Int, AvailableConnectionContext) { + guard let index = self.connections.firstIndex(where: { $0.id == connection.id }) else { + preconditionFailure("There is a new connection that we didn't request!") + } + self.stats.connecting -= 1 + self.stats.idle += 1 + self.stats.availableStreams += maxStreams + let connectionInfo = self.connections[index].connected(connection, maxStreams: maxStreams) + // TODO: If this is an overflow connection, but we are currently also creating a + // persisted connection, we might want to swap those. + let context = self.makeAvailableConnectionContextForConnection(at: index, info: connectionInfo) + return (index, context) + } + + @inlinable + mutating func backoffNextConnectionAttempt(_ connectionID: Connection.ID) -> ConnectionTimer { + guard let index = self.connections.firstIndex(where: { $0.id == connectionID }) else { + preconditionFailure("We tried to create a new connection that we know nothing about?") + } + + self.stats.connecting -= 1 + self.stats.backingOff += 1 + + return self.connections[index].failedToConnect() + } + + @usableFromInline + enum BackoffDoneAction { + case createConnection(ConnectionRequest, TimerCancellationToken?) + case cancelTimers(Max2Sequence) + } + + @inlinable + mutating func backoffDone(_ connectionID: Connection.ID, retry: Bool) -> BackoffDoneAction { + guard let index = self.connections.firstIndex(where: { $0.id == connectionID }) else { + preconditionFailure("We tried to create a new connection that we know nothing about?") + } + + self.stats.backingOff -= 1 + + if retry || self.stats.active < self.minimumConcurrentConnections { + self.stats.connecting += 1 + let backoffTimerCancellation = self.connections[index].retryConnect() + return .createConnection(.init(connectionID: connectionID), backoffTimerCancellation) + } + + let backoffTimerCancellation = self.connections[index].destroyBackingOffConnection() + var timerCancellations = Max2Sequence(backoffTimerCancellation) + + if let timerCancellationToken = self.swapForDeletion(index: index) { + timerCancellations.append(timerCancellationToken) + } + return .cancelTimers(timerCancellations) + } + + @inlinable + mutating func timerScheduled( + _ timer: ConnectionTimer, + cancelContinuation: TimerCancellationToken + ) -> TimerCancellationToken? { + guard let index = self.connections.firstIndex(where: { $0.id == timer.connectionID }) else { + return cancelContinuation + } + + return self.connections[index].timerScheduled(timer, cancelContinuation: cancelContinuation) + } + + // MARK: Changes at runtime + + @usableFromInline + struct NewMaxStreamInfo { + + @usableFromInline + var index: Int + + @usableFromInline + var newMaxStreams: UInt16 + + @usableFromInline + var oldMaxStreams: UInt16 + + @usableFromInline + var usedStreams: UInt16 + + @inlinable + init(index: Int, info: ConnectionState.NewMaxStreamInfo) { + self.index = index + self.newMaxStreams = info.newMaxStreams + self.oldMaxStreams = info.oldMaxStreams + self.usedStreams = info.usedStreams + } + } + + @inlinable + mutating func connectionReceivedNewMaxStreamSetting( + _ connectionID: ConnectionID, + newMaxStreamSetting maxStreams: UInt16 + ) -> NewMaxStreamInfo? { + guard let index = self.connections.firstIndex(where: { $0.id == connectionID }) else { + return nil + } + + guard let info = self.connections[index].newMaxStreamSetting(maxStreams) else { + return nil + } + + self.stats.availableStreams += maxStreams - info.oldMaxStreams + + return NewMaxStreamInfo(index: index, info: info) + } + + // MARK: Leasing and releasing + + /// Lease a connection, if an idle connection is available. + /// + /// - Returns: A connection to execute a request on. + @inlinable + mutating func leaseConnection() -> LeaseResult? { + if self.stats.availableStreams == 0 { + return nil + } + + guard let index = self.findAvailableConnection() else { + preconditionFailure("Stats and actual count are of.") + } + + return self.leaseConnection(at: index, streams: 1) + } + + @usableFromInline + enum LeasedConnectionOrStartingCount { + case leasedConnection(LeaseResult) + case startingCount(UInt16) + } + + @inlinable + mutating func leaseConnectionOrSoonAvailableConnectionCount() -> LeasedConnectionOrStartingCount { + if let result = self.leaseConnection() { + return .leasedConnection(result) + } + return .startingCount(self.stats.soonAvailable) + } + + @inlinable + mutating func leaseConnection(at index: Int, streams: UInt16) -> LeaseResult { + let leaseResult = self.connections[index].lease(streams: streams) + let use = self.getConnectionUse(index: index) + + if leaseResult.wasIdle { + self.stats.idle -= 1 + self.stats.leased += 1 + } + self.stats.leasedStreams += streams + self.stats.availableStreams -= streams + return LeaseResult( + connection: leaseResult.connection, + timersToCancel: leaseResult.timersToCancel, + wasIdle: leaseResult.wasIdle, + use: use + ) + } + + @inlinable + mutating func parkConnection(at index: Int, hasBecomeIdle newIdle: Bool) -> Max2Sequence { + let scheduleIdleTimeoutTimer: Bool + switch index { + case 0.. (Int, AvailableConnectionContext)? { + guard let index = self.connections.firstIndex(where: { $0.id == connectionID }) else { + return nil + } + + let connectionInfo = self.connections[index].release(streams: streams) + self.stats.availableStreams += streams + self.stats.leasedStreams -= streams + switch connectionInfo { + case .idle: + self.stats.idle += 1 + self.stats.leased -= 1 + case .leased: + break + } + + let context = self.makeAvailableConnectionContextForConnection(at: index, info: connectionInfo) + return (index, context) + } + + @inlinable + mutating func keepAliveIfIdle(_ connectionID: Connection.ID) -> KeepAliveAction? { + guard let index = self.connections.firstIndex(where: { $0.id == connectionID }) else { + // because of a race this connection (connection close runs against trigger of ping pong) + // was already removed from the state machine. + return nil + } + + guard let action = self.connections[index].runKeepAliveIfIdle(reducesAvailableStreams: self.keepAliveReducesAvailableStreams) else { + return nil + } + + self.stats.runningKeepAlive += 1 + if self.keepAliveReducesAvailableStreams { + self.stats.availableStreams -= 1 + } + + return action + } + + @inlinable + mutating func keepAliveSucceeded(_ connectionID: Connection.ID) -> (Int, AvailableConnectionContext)? { + guard let index = self.connections.firstIndex(where: { $0.id == connectionID }) else { + // keepAliveSucceeded can race against, closeIfIdle, shutdowns or connection errors + return nil + } + + guard let connectionInfo = self.connections[index].keepAliveSucceeded() else { + // if we don't get connection info here this means, that the connection already was + // transitioned to closing. when we did this we already decremented the + // runningKeepAlive timer. + return nil + } + + self.stats.runningKeepAlive -= 1 + if self.keepAliveReducesAvailableStreams { + self.stats.availableStreams += 1 + } + + let context = self.makeAvailableConnectionContextForConnection(at: index, info: connectionInfo) + return (index, context) + } + + @inlinable + mutating func keepAliveFailed(_ connectionID: Connection.ID) -> CloseAction? { + guard let index = self.connections.firstIndex(where: { $0.id == connectionID }) else { + // Connection has already been closed + return nil + } + + guard let closeAction = self.connections[index].keepAliveFailed() else { + return nil + } + + self.stats.idle -= 1 + self.stats.closing += 1 + self.stats.runningKeepAlive -= closeAction.runningKeepAlive ? 1 : 0 + self.stats.availableStreams -= closeAction.maxStreams - closeAction.usedStreams + + // force unwrapping the connection is fine, because a close action due to failed + // keepAlive cannot happen without a connection + return CloseAction( + connection: closeAction.connection!, + timersToCancel: closeAction.cancelTimers + ) + } + + // MARK: Connection close/removal + + @usableFromInline + struct CloseAction { + @usableFromInline + private(set) var connection: Connection + + @usableFromInline + private(set) var timersToCancel: Max2Sequence + + @inlinable + init(connection: Connection, timersToCancel: Max2Sequence) { + self.connection = connection + self.timersToCancel = timersToCancel + } + } + + /// Closes the connection at the given index. + @inlinable + mutating func closeConnectionIfIdle(at index: Int) -> CloseAction? { + guard let closeAction = self.connections[index].closeIfIdle() else { + return nil // apparently the connection isn't idle + } + + self.stats.idle -= 1 + self.stats.closing += 1 + self.stats.runningKeepAlive -= closeAction.runningKeepAlive ? 1 : 0 + self.stats.availableStreams -= closeAction.maxStreams - closeAction.usedStreams + + return CloseAction( + connection: closeAction.connection!, + timersToCancel: closeAction.cancelTimers + ) + } + + @inlinable + mutating func closeConnectionIfIdle(_ connectionID: Connection.ID) -> CloseAction? { + guard let index = self.connections.firstIndex(where: { $0.id == connectionID }) else { + // because of a race this connection (connection close runs against trigger of timeout) + // was already removed from the state machine. + return nil + } + + if index < self.minimumConcurrentConnections { + // because of a race a connection might receive a idle timeout after it was moved into + // the persisted connections. If a connection is now persisted, we now need to ignore + // the trigger + return nil + } + + return self.closeConnectionIfIdle(at: index) + } + + /// Information around the failed/closed connection. + @usableFromInline + struct ClosedAction { + /// Connections that are currently starting + @usableFromInline + var connectionsStarting: Int + + @usableFromInline + var timersToCancel: TinyFastSequence + + @usableFromInline + var newConnectionRequest: ConnectionRequest? + + @inlinable + init( + connectionsStarting: Int, + timersToCancel: TinyFastSequence, + newConnectionRequest: ConnectionRequest? = nil + ) { + self.connectionsStarting = connectionsStarting + self.timersToCancel = timersToCancel + self.newConnectionRequest = newConnectionRequest + } + } + + /// Connection closed. Call this method, if a connection is closed. + /// + /// This will put the position into the closed state. + /// + /// - Parameter connectionID: The failed connection's id. + /// - Returns: An optional index and an IdleConnectionContext to determine the next action for the closed connection. + /// You must call ``removeConnection(at:)`` or ``replaceConnection(at:)`` with the + /// supplied index after this. If nil is returned the connection was closed by the state machine and was + /// therefore already removed. + @inlinable + mutating func connectionClosed(_ connectionID: Connection.ID) -> ClosedAction { + guard let index = self.connections.firstIndex(where: { $0.id == connectionID }) else { + preconditionFailure("All connections that have been created should say goodbye exactly once!") + } + + let closedAction = self.connections[index].closed() + var timersToCancel = TinyFastSequence(closedAction.cancelTimers) + + if closedAction.wasRunningKeepAlive { + self.stats.runningKeepAlive -= 1 + } + self.stats.leasedStreams -= closedAction.usedStreams + self.stats.availableStreams -= closedAction.maxStreams - closedAction.usedStreams + + switch closedAction.previousConnectionState { + case .idle: + self.stats.idle -= 1 + + case .leased: + self.stats.leased -= 1 + + case .closing: + self.stats.closing -= 1 + } + + if let cancellationTimer = self.swapForDeletion(index: index) { + timersToCancel.append(cancellationTimer) + } + + let newConnectionRequest: ConnectionRequest? + if self.connections.count < self.minimumConcurrentConnections { + newConnectionRequest = self.createNewConnection() + } else { + newConnectionRequest = .none + } + + return ClosedAction( + connectionsStarting: 0, + timersToCancel: timersToCancel, + newConnectionRequest: newConnectionRequest + ) + } + + // MARK: Shutdown + + mutating func triggerForceShutdown(_ cleanup: inout ConnectionAction.Shutdown) { + for var connectionState in self.connections { + guard let closeAction = connectionState.close() else { + continue + } + + if let connection = closeAction.connection { + cleanup.connections.append(connection) + } + cleanup.timersToCancel.append(contentsOf: closeAction.cancelTimers) + } + + self.connections = [] + } + + // MARK: - Private functions - + + @inlinable + /*private*/ func getConnectionUse(index: Int) -> ConnectionUse { + switch index { + case 0.. AvailableConnectionContext { + precondition(self.connections[index].isAvailable) + let use = self.getConnectionUse(index: index) + return AvailableConnectionContext(use: use, info: info) + } + + @inlinable + /*private*/ func findAvailableConnection() -> Int? { + return self.connections.firstIndex(where: { $0.isAvailable }) + } + + @inlinable + /*private*/ mutating func swapForDeletion(index indexToDelete: Int) -> TimerCancellationToken? { + let maybeLastConnectedIndex = self.connections.lastIndex(where: { $0.isConnected }) + + if maybeLastConnectedIndex == nil || maybeLastConnectedIndex! < indexToDelete { + self.removeO1(indexToDelete) + return nil + } + + // if maybeLastConnectedIndex == nil, we return early in the above if case. + let lastConnectedIndex = maybeLastConnectedIndex! + + switch indexToDelete { + case 0.. TimerCancellationToken? { + switch self { + case .scheduled(let timer): + self = .notScheduled + return timer.cancellationContinuation + case .running, .notScheduled: + return nil + } + } + } + + @usableFromInline + struct Timer: Sendable { + @usableFromInline + let timerID: Int + + @usableFromInline + private(set) var cancellationContinuation: TimerCancellationToken? + + @inlinable + init(id: Int) { + self.timerID = id + self.cancellationContinuation = nil + } + + @inlinable + mutating func registerCancellationContinuation(_ continuation: TimerCancellationToken) { + precondition(self.cancellationContinuation == nil) + self.cancellationContinuation = continuation + } + } + + /// The pool is creating a connection. Valid transitions are to: `.backingOff`, `.idle`, and `.closed` + case starting + /// The pool is waiting to retry establishing a connection. Valid transitions are to: `.closed`. + /// This means, the connection can be removed from the connections without cancelling external + /// state. The connection state can then be replaced by a new one. + case backingOff(Timer) + /// The connection is `idle` and ready to execute a new query. Valid transitions to: `.pingpong`, `.leased`, + /// `.closing` and `.closed` + case idle(Connection, maxStreams: UInt16, keepAlive: KeepAlive, idleTimer: Timer?) + /// The connection is leased and executing a query. Valid transitions to: `.idle` and `.closed` + case leased(Connection, usedStreams: UInt16, maxStreams: UInt16, keepAlive: KeepAlive) + /// The connection is closing. Valid transitions to: `.closed` + case closing(Connection) + /// The connection is closed. Final state. + case closed + } + + @usableFromInline + let id: Connection.ID + + @usableFromInline + private(set) var state: State = .starting + + @usableFromInline + private(set) var nextTimerID: Int = 0 + + @inlinable + init(id: Connection.ID) { + self.id = id + } + + @inlinable + var isIdle: Bool { + switch self.state { + case .idle(_, _, .notScheduled, _), .idle(_, _, .scheduled, _): + return true + case .idle(_, _, .running, _): + return false + case .backingOff, .starting, .closed, .closing, .leased: + return false + } + } + + @inlinable + var isAvailable: Bool { + switch self.state { + case .idle(_, let maxStreams, .running(true), _): + return maxStreams > 1 + case .idle(_, let maxStreams, let keepAlive, _): + return keepAlive.usedStreams < maxStreams + case .leased(_, let usedStreams, let maxStreams, let keepAlive): + return usedStreams + keepAlive.usedStreams < maxStreams + case .backingOff, .starting, .closed, .closing: + return false + } + } + + @inlinable + var isLeased: Bool { + switch self.state { + case .leased: + return true + case .backingOff, .starting, .closed, .closing, .idle: + return false + } + } + + @inlinable + var isConnected: Bool { + switch self.state { + case .idle, .leased: + return true + case .backingOff, .starting, .closed, .closing: + return false + } + } + + @inlinable + mutating func connected(_ connection: Connection, maxStreams: UInt16) -> ConnectionAvailableInfo { + switch self.state { + case .starting: + self.state = .idle(connection, maxStreams: maxStreams, keepAlive: .notScheduled, idleTimer: nil) + return .idle(availableStreams: maxStreams, newIdle: true) + case .backingOff, .idle, .leased, .closing, .closed: + preconditionFailure("Invalid state: \(self.state)") + } + } + + @usableFromInline + struct NewMaxStreamInfo { + @usableFromInline + var newMaxStreams: UInt16 + + @usableFromInline + var oldMaxStreams: UInt16 + + @usableFromInline + var usedStreams: UInt16 + + @inlinable + init(newMaxStreams: UInt16, oldMaxStreams: UInt16, usedStreams: UInt16) { + self.newMaxStreams = newMaxStreams + self.oldMaxStreams = oldMaxStreams + self.usedStreams = usedStreams + } + } + + @inlinable + mutating func newMaxStreamSetting(_ newMaxStreams: UInt16) -> NewMaxStreamInfo? { + switch self.state { + case .starting, .backingOff: + preconditionFailure("Invalid state: \(self.state)") + + case .idle(let connection, let oldMaxStreams, let keepAlive, idleTimer: let idleTimer): + self.state = .idle(connection, maxStreams: newMaxStreams, keepAlive: keepAlive, idleTimer: idleTimer) + return NewMaxStreamInfo( + newMaxStreams: newMaxStreams, + oldMaxStreams: oldMaxStreams, + usedStreams: keepAlive.usedStreams + ) + + case .leased(let connection, let usedStreams, let oldMaxStreams, let keepAlive): + self.state = .leased(connection, usedStreams: usedStreams, maxStreams: newMaxStreams, keepAlive: keepAlive) + return NewMaxStreamInfo( + newMaxStreams: newMaxStreams, + oldMaxStreams: oldMaxStreams, + usedStreams: usedStreams + keepAlive.usedStreams + ) + + case .closing, .closed: + return nil + } + } + + + @inlinable + mutating func parkConnection(scheduleKeepAliveTimer: Bool, scheduleIdleTimeoutTimer: Bool) -> Max2Sequence { + var keepAliveTimer: ConnectionTimer? + var keepAliveTimerState: State.Timer? + var idleTimer: ConnectionTimer? + var idleTimerState: State.Timer? + + switch self.state { + case .backingOff, .starting, .leased, .closing, .closed: + preconditionFailure("Invalid state: \(self.state)") + + case .idle(let connection, let maxStreams, .notScheduled, .none): + let keepAlive: State.KeepAlive + if scheduleKeepAliveTimer { + keepAliveTimerState = self._nextTimer() + keepAliveTimer = ConnectionTimer(timerID: keepAliveTimerState!.timerID, connectionID: self.id, usecase: .keepAlive) + keepAlive = .scheduled(keepAliveTimerState!) + } else { + keepAlive = .notScheduled + } + if scheduleIdleTimeoutTimer { + idleTimerState = self._nextTimer() + idleTimer = ConnectionTimer(timerID: idleTimerState!.timerID, connectionID: self.id, usecase: .idleTimeout) + } + self.state = .idle(connection, maxStreams: maxStreams, keepAlive: keepAlive, idleTimer: idleTimerState) + return Max2Sequence(keepAliveTimer, idleTimer) + + case .idle(_, _, .scheduled, .some): + precondition(!scheduleKeepAliveTimer) + precondition(!scheduleIdleTimeoutTimer) + return Max2Sequence() + + case .idle(let connection, let maxStreams, .notScheduled, let idleTimer): + precondition(!scheduleIdleTimeoutTimer) + let keepAlive: State.KeepAlive + if scheduleKeepAliveTimer { + keepAliveTimerState = self._nextTimer() + keepAliveTimer = ConnectionTimer(timerID: keepAliveTimerState!.timerID, connectionID: self.id, usecase: .keepAlive) + keepAlive = .scheduled(keepAliveTimerState!) + } else { + keepAlive = .notScheduled + } + self.state = .idle(connection, maxStreams: maxStreams, keepAlive: keepAlive, idleTimer: idleTimer) + return Max2Sequence(keepAliveTimer) + + case .idle(let connection, let maxStreams, .scheduled(let keepAliveTimer), .none): + precondition(!scheduleKeepAliveTimer) + + if scheduleIdleTimeoutTimer { + idleTimerState = self._nextTimer() + idleTimer = ConnectionTimer(timerID: idleTimerState!.timerID, connectionID: self.id, usecase: .keepAlive) + } + self.state = .idle(connection, maxStreams: maxStreams, keepAlive: .scheduled(keepAliveTimer), idleTimer: idleTimerState) + return Max2Sequence(idleTimer, nil) + + case .idle(let connection, let maxStreams, keepAlive: .running(let usingStream), idleTimer: .none): + if scheduleIdleTimeoutTimer { + idleTimerState = self._nextTimer() + idleTimer = ConnectionTimer(timerID: idleTimerState!.timerID, connectionID: self.id, usecase: .keepAlive) + } + self.state = .idle(connection, maxStreams: maxStreams, keepAlive: .running(usingStream), idleTimer: idleTimerState) + return Max2Sequence(keepAliveTimer, idleTimer) + + case .idle(_, _, keepAlive: .running(_), idleTimer: .some): + precondition(!scheduleKeepAliveTimer) + precondition(!scheduleIdleTimeoutTimer) + return Max2Sequence() + } + } + + /// The connection failed to start + @inlinable + mutating func failedToConnect() -> ConnectionTimer { + switch self.state { + case .starting: + let backoffTimerState = self._nextTimer() + self.state = .backingOff(backoffTimerState) + return ConnectionTimer(timerID: backoffTimerState.timerID, connectionID: self.id, usecase: .backoff) + + case .backingOff, .idle, .leased, .closing, .closed: + preconditionFailure("Invalid state: \(self.state)") + } + } + + /// Moves a connection, that has previously ``failedToConnect()`` back into the connecting state. + /// + /// - Returns: A ``TimerCancellationToken`` that was previously registered with the state machine + /// for the ``ConnectionTimer`` returned in ``failedToConnect()``. If no token was registered + /// nil is returned. + @inlinable + mutating func retryConnect() -> TimerCancellationToken? { + switch self.state { + case .backingOff(let timer): + self.state = .starting + return timer.cancellationContinuation + case .starting, .idle, .leased, .closing, .closed: + preconditionFailure("Invalid state: \(self.state)") + } + } + + @inlinable + mutating func destroyBackingOffConnection() -> TimerCancellationToken? { + switch self.state { + case .backingOff(let timer): + self.state = .closed + return timer.cancellationContinuation + case .starting, .idle, .leased, .closing, .closed: + preconditionFailure("Invalid state: \(self.state)") + } + } + + @usableFromInline + struct LeaseAction { + @usableFromInline + var connection: Connection + @usableFromInline + var timersToCancel: Max2Sequence + @usableFromInline + var wasIdle: Bool + + @inlinable + init(connection: Connection, timersToCancel: Max2Sequence, wasIdle: Bool) { + self.connection = connection + self.timersToCancel = timersToCancel + self.wasIdle = wasIdle + } + } + + @inlinable + mutating func lease(streams newLeasedStreams: UInt16 = 1) -> LeaseAction { + switch self.state { + case .idle(let connection, let maxStreams, var keepAlive, let idleTimer): + var cancel = Max2Sequence() + if let token = idleTimer?.cancellationContinuation { + cancel.append(token) + } + if let token = keepAlive.cancelTimerIfScheduled() { + cancel.append(token) + } + precondition(maxStreams >= newLeasedStreams + keepAlive.usedStreams, "Invalid state: \(self.state)") + self.state = .leased(connection, usedStreams: newLeasedStreams, maxStreams: maxStreams, keepAlive: keepAlive) + return LeaseAction(connection: connection, timersToCancel: cancel, wasIdle: true) + + case .leased(let connection, let usedStreams, let maxStreams, let keepAlive): + precondition(maxStreams >= usedStreams + newLeasedStreams + keepAlive.usedStreams, "Invalid state: \(self.state)") + self.state = .leased(connection, usedStreams: usedStreams + newLeasedStreams, maxStreams: maxStreams, keepAlive: keepAlive) + return LeaseAction(connection: connection, timersToCancel: .init(), wasIdle: false) + + case .backingOff, .starting, .closing, .closed: + preconditionFailure("Invalid state: \(self.state)") + } + } + + @inlinable + mutating func release(streams returnedStreams: UInt16) -> ConnectionAvailableInfo { + switch self.state { + case .leased(let connection, let usedStreams, let maxStreams, let keepAlive): + precondition(usedStreams >= returnedStreams) + let newUsedStreams = usedStreams - returnedStreams + let availableStreams = maxStreams - (newUsedStreams + keepAlive.usedStreams) + if newUsedStreams == 0 { + self.state = .idle(connection, maxStreams: maxStreams, keepAlive: keepAlive, idleTimer: nil) + return .idle(availableStreams: availableStreams, newIdle: true) + } else { + self.state = .leased(connection, usedStreams: newUsedStreams, maxStreams: maxStreams, keepAlive: keepAlive) + return .leased(availableStreams: availableStreams) + } + case .backingOff, .starting, .idle, .closing, .closed: + preconditionFailure("Invalid state: \(self.state)") + } + } + + @inlinable + mutating func runKeepAliveIfIdle(reducesAvailableStreams: Bool) -> KeepAliveAction? { + switch self.state { + case .idle(let connection, let maxStreams, .scheduled(let timer), let idleTimer): + self.state = .idle(connection, maxStreams: maxStreams, keepAlive: .running(reducesAvailableStreams), idleTimer: idleTimer) + return KeepAliveAction( + connection: connection, + keepAliveTimerCancellationContinuation: timer.cancellationContinuation + ) + + case .leased, .closed, .closing: + return nil + + case .backingOff, .starting, .idle(_, _, .running, _), .idle(_, _, .notScheduled, _): + preconditionFailure("Invalid state: \(self.state)") + } + } + + @inlinable + mutating func keepAliveSucceeded() -> ConnectionAvailableInfo? { + switch self.state { + case .idle(let connection, let maxStreams, .running, let idleTimer): + self.state = .idle(connection, maxStreams: maxStreams, keepAlive: .notScheduled, idleTimer: idleTimer) + return .idle(availableStreams: maxStreams, newIdle: false) + + case .leased(let connection, let usedStreams, let maxStreams, .running): + self.state = .leased(connection, usedStreams: usedStreams, maxStreams: maxStreams, keepAlive: .notScheduled) + return .leased(availableStreams: maxStreams - usedStreams) + + case .closed, .closing: + return nil + + case .backingOff, .starting, + .leased(_, _, _, .notScheduled), + .leased(_, _, _, .scheduled), + .idle(_, _, .notScheduled, _), + .idle(_, _, .scheduled, _): + preconditionFailure("Invalid state: \(self.state)") + } + } + + @inlinable + mutating func keepAliveFailed() -> CloseAction? { + return self.close() + } + + @inlinable + mutating func timerScheduled( + _ timer: ConnectionTimer, + cancelContinuation: TimerCancellationToken + ) -> TimerCancellationToken? { + switch timer.usecase { + case .backoff: + switch self.state { + case .backingOff(var timerState): + if timerState.timerID == timer.timerID { + timerState.registerCancellationContinuation(cancelContinuation) + self.state = .backingOff(timerState) + return nil + } else { + return cancelContinuation + } + + case .starting, .idle, .leased, .closing, .closed: + return cancelContinuation + } + + case .idleTimeout: + switch self.state { + case .idle(let connection, let maxStreams, let keepAlive, let idleTimerState): + if var idleTimerState = idleTimerState, idleTimerState.timerID == timer.timerID { + idleTimerState.registerCancellationContinuation(cancelContinuation) + self.state = .idle(connection, maxStreams: maxStreams, keepAlive: keepAlive, idleTimer: idleTimerState) + return nil + } else { + return cancelContinuation + } + + case .starting, .backingOff, .leased, .closing, .closed: + return cancelContinuation + } + + case .keepAlive: + switch self.state { + case .idle(let connection, let maxStreams, .scheduled(var keepAliveTimerState), let idleTimerState): + if keepAliveTimerState.timerID == timer.timerID { + keepAliveTimerState.registerCancellationContinuation(cancelContinuation) + self.state = .idle(connection, maxStreams: maxStreams, keepAlive: .scheduled(keepAliveTimerState), idleTimer: idleTimerState) + return nil + } else { + return cancelContinuation + } + + case .starting, .backingOff, .leased, .closing, .closed, + .idle(_, _, .running, _), + .idle(_, _, .notScheduled, _): + return cancelContinuation + } + } + } + + @inlinable + mutating func cancelIdleTimer() -> TimerCancellationToken? { + switch self.state { + case .starting, .backingOff, .leased, .closing, .closed: + return nil + + case .idle(let connection, let maxStreams, let keepAlive, let idleTimer): + self.state = .idle(connection, maxStreams: maxStreams, keepAlive: keepAlive, idleTimer: nil) + return idleTimer?.cancellationContinuation + } + } + + @usableFromInline + struct CloseAction { + + @usableFromInline + enum PreviousConnectionState { + case idle + case leased + case closing + case backingOff + } + + @usableFromInline + var connection: Connection? + @usableFromInline + var previousConnectionState: PreviousConnectionState + @usableFromInline + var cancelTimers: Max2Sequence + @usableFromInline + var usedStreams: UInt16 + @usableFromInline + var maxStreams: UInt16 + @usableFromInline + var runningKeepAlive: Bool + + + @inlinable + init( + connection: Connection?, + previousConnectionState: PreviousConnectionState, + cancelTimers: Max2Sequence, + usedStreams: UInt16, + maxStreams: UInt16, + runningKeepAlive: Bool + ) { + self.connection = connection + self.previousConnectionState = previousConnectionState + self.cancelTimers = cancelTimers + self.usedStreams = usedStreams + self.maxStreams = maxStreams + self.runningKeepAlive = runningKeepAlive + } + } + + @inlinable + mutating func closeIfIdle() -> CloseAction? { + switch self.state { + case .idle(let connection, let maxStreams, var keepAlive, let idleTimerState): + self.state = .closing(connection) + return CloseAction( + connection: connection, + previousConnectionState: .idle, + cancelTimers: Max2Sequence( + keepAlive.cancelTimerIfScheduled(), + idleTimerState?.cancellationContinuation + ), + usedStreams: keepAlive.usedStreams, + maxStreams: maxStreams, + runningKeepAlive: keepAlive.isRunning + ) + + case .leased, .closed: + return nil + + case .backingOff, .starting, .closing: + preconditionFailure("Invalid state: \(self.state)") + } + } + + @inlinable + mutating func close() -> CloseAction? { + switch self.state { + case .starting: + // If we are currently starting, there is nothing we can do about it right now. + // Only once the connection has come up, or failed, we can actually act. + return nil + + case .closing, .closed: + // If we are already closing, we can't do anything else. + return nil + + case .idle(let connection, let maxStreams, var keepAlive, let idleTimerState): + self.state = .closing(connection) + return CloseAction( + connection: connection, + previousConnectionState: .idle, + cancelTimers: Max2Sequence( + keepAlive.cancelTimerIfScheduled(), + idleTimerState?.cancellationContinuation + ), + usedStreams: keepAlive.usedStreams, + maxStreams: maxStreams, + runningKeepAlive: keepAlive.isRunning + ) + + case .leased(let connection, usedStreams: let usedStreams, maxStreams: let maxStreams, var keepAlive): + self.state = .closing(connection) + return CloseAction( + connection: connection, + previousConnectionState: .leased, + cancelTimers: Max2Sequence( + keepAlive.cancelTimerIfScheduled() + ), + usedStreams: keepAlive.usedStreams + usedStreams, + maxStreams: maxStreams, + runningKeepAlive: keepAlive.isRunning + ) + + case .backingOff(let timer): + self.state = .closed + return CloseAction( + connection: nil, + previousConnectionState: .backingOff, + cancelTimers: Max2Sequence(timer.cancellationContinuation), + usedStreams: 0, + maxStreams: 0, + runningKeepAlive: false + ) + } + } + + @usableFromInline + struct ClosedAction { + + @usableFromInline + enum PreviousConnectionState { + case idle + case leased + case closing + } + + @usableFromInline + var previousConnectionState: PreviousConnectionState + @usableFromInline + var cancelTimers: Max2Sequence + @usableFromInline + var maxStreams: UInt16 + @usableFromInline + var usedStreams: UInt16 + @usableFromInline + var wasRunningKeepAlive: Bool + + @inlinable + init( + previousConnectionState: PreviousConnectionState, + cancelTimers: Max2Sequence, + maxStreams: UInt16, + usedStreams: UInt16, + wasRunningKeepAlive: Bool + ) { + self.previousConnectionState = previousConnectionState + self.cancelTimers = cancelTimers + self.maxStreams = maxStreams + self.usedStreams = usedStreams + self.wasRunningKeepAlive = wasRunningKeepAlive + } + } + + @inlinable + mutating func closed() -> ClosedAction { + switch self.state { + case .starting, .backingOff, .closed: + preconditionFailure("Invalid state: \(self.state)") + + case .idle(_, let maxStreams, var keepAlive, let idleTimer): + self.state = .closed + return ClosedAction( + previousConnectionState: .idle, + cancelTimers: .init(keepAlive.cancelTimerIfScheduled(), idleTimer?.cancellationContinuation), + maxStreams: maxStreams, + usedStreams: keepAlive.usedStreams, + wasRunningKeepAlive: keepAlive.isRunning + ) + + case .leased(_, let usedStreams, let maxStreams, let keepAlive): + self.state = .closed + return ClosedAction( + previousConnectionState: .leased, + cancelTimers: .init(), + maxStreams: maxStreams, + usedStreams: usedStreams + keepAlive.usedStreams, + wasRunningKeepAlive: keepAlive.isRunning + ) + + case .closing: + self.state = .closed + return ClosedAction( + previousConnectionState: .closing, + cancelTimers: .init(), + maxStreams: 0, + usedStreams: 0, + wasRunningKeepAlive: false + ) + } + } + + // MARK: - Private Methods - + + @inlinable + mutating /*private*/ func _nextTimer() -> State.Timer { + defer { self.nextTimerID += 1 } + return State.Timer(id: self.nextTimerID) + } + } + + @usableFromInline + enum ConnectionAvailableInfo: Equatable { + case leased(availableStreams: UInt16) + case idle(availableStreams: UInt16, newIdle: Bool) + + @usableFromInline + var availableStreams: UInt16 { + switch self { + case .leased(let availableStreams): + return availableStreams + case .idle(let availableStreams, newIdle: _): + return availableStreams + } + } + } +} + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +extension PoolStateMachine.KeepAliveAction: Equatable where TimerCancellationToken: Equatable { + @inlinable + static func == (lhs: Self, rhs: Self) -> Bool { + lhs.connection === rhs.connection && lhs.keepAliveTimerCancellationContinuation == rhs.keepAliveTimerCancellationContinuation + } +} + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +extension PoolStateMachine.ConnectionState.LeaseAction: Equatable where TimerCancellationToken: Equatable { + @inlinable + static func == (lhs: Self, rhs: Self) -> Bool { + lhs.wasIdle == rhs.wasIdle && lhs.connection === rhs.connection && lhs.timersToCancel == rhs.timersToCancel + } +} + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +extension PoolStateMachine.ConnectionState.CloseAction: Equatable where TimerCancellationToken: Equatable { + @inlinable + static func == (lhs: Self, rhs: Self) -> Bool { + lhs.cancelTimers == rhs.cancelTimers && lhs.connection === rhs.connection && lhs.maxStreams == rhs.maxStreams + } +} diff --git a/Sources/ConnectionPoolModule/PoolStateMachine+RequestQueue.swift b/Sources/ConnectionPoolModule/PoolStateMachine+RequestQueue.swift new file mode 100644 index 00000000..99ec4896 --- /dev/null +++ b/Sources/ConnectionPoolModule/PoolStateMachine+RequestQueue.swift @@ -0,0 +1,71 @@ +import DequeModule + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +extension PoolStateMachine { + + /// A request queue, which can enqueue requests in O(1), dequeue requests in O(1) and even cancel requests in O(1). + /// + /// While enqueueing and dequeueing on O(1) is trivial, cancellation is hard, as it normally requires a removal within the + /// underlying Deque. However thanks to having an additional `requests` dictionary, we can remove the cancelled + /// request from the dictionary and keep it inside the queue. Whenever we pop a request from the deque, we validate + /// that it hasn't been cancelled in the meantime by checking if the popped request is still in the `requests` dictionary. + @usableFromInline + struct RequestQueue: Sendable { + @usableFromInline + private(set) var queue: Deque + + @usableFromInline + private(set) var requests: [RequestID: Request] + + @inlinable + var count: Int { + self.requests.count + } + + @inlinable + var isEmpty: Bool { + self.count == 0 + } + + @usableFromInline + init() { + self.queue = .init(minimumCapacity: 256) + self.requests = .init(minimumCapacity: 256) + } + + @inlinable + mutating func queue(_ request: Request) { + self.requests[request.id] = request + self.queue.append(request.id) + } + + @inlinable + mutating func pop(max: UInt16) -> TinyFastSequence { + var result = TinyFastSequence() + result.reserveCapacity(Int(max)) + var popped = 0 + while popped < max, let requestID = self.queue.popFirst() { + if let requestIndex = self.requests.index(forKey: requestID) { + popped += 1 + result.append(self.requests.remove(at: requestIndex).value) + } + } + + assert(result.count <= max) + return result + } + + @inlinable + mutating func remove(_ requestID: RequestID) -> Request? { + self.requests.removeValue(forKey: requestID) + } + + @inlinable + mutating func removeAll() -> TinyFastSequence { + let result = TinyFastSequence(self.requests.values) + self.requests.removeAll() + self.queue.removeAll() + return result + } + } +} diff --git a/Sources/ConnectionPoolModule/PoolStateMachine.swift b/Sources/ConnectionPoolModule/PoolStateMachine.swift new file mode 100644 index 00000000..6e41f730 --- /dev/null +++ b/Sources/ConnectionPoolModule/PoolStateMachine.swift @@ -0,0 +1,635 @@ +#if canImport(Darwin) +import Darwin +#elseif canImport(Glibc) +import Glibc +#elseif canImport(Musl) +import Musl +#endif + +@usableFromInline +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +struct PoolConfiguration: Sendable { + /// The minimum number of connections to preserve in the pool. + /// + /// If the pool is mostly idle and the remote servers closes idle connections, + /// the `ConnectionPool` will initiate new outbound connections proactively + /// to avoid the number of available connections dropping below this number. + @usableFromInline + var minimumConnectionCount: Int = 0 + + /// The maximum number of connections to for this pool, to be preserved. + @usableFromInline + var maximumConnectionSoftLimit: Int = 10 + + @usableFromInline + var maximumConnectionHardLimit: Int = 10 + + @usableFromInline + var keepAliveDuration: Duration? + + @usableFromInline + var idleTimeoutDuration: Duration = .seconds(30) +} + +@usableFromInline +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +struct PoolStateMachine< + Connection: PooledConnection, + ConnectionIDGenerator: ConnectionIDGeneratorProtocol, + ConnectionID: Hashable & Sendable, + Request: ConnectionRequestProtocol, + RequestID, + TimerCancellationToken: Sendable +>: Sendable where Connection.ID == ConnectionID, ConnectionIDGenerator.ID == ConnectionID, RequestID == Request.ID { + + @usableFromInline + struct ConnectionRequest: Hashable, Sendable { + @usableFromInline var connectionID: ConnectionID + + @inlinable + init(connectionID: ConnectionID) { + self.connectionID = connectionID + } + } + + @usableFromInline + struct Action { + @usableFromInline let request: RequestAction + @usableFromInline let connection: ConnectionAction + + @inlinable + init(request: RequestAction, connection: ConnectionAction) { + self.request = request + self.connection = connection + } + + @inlinable + static func none() -> Action { Action(request: .none, connection: .none) } + } + + @usableFromInline + enum ConnectionAction { + @usableFromInline + struct Shutdown { + @usableFromInline + var connections: [Connection] + @usableFromInline + var timersToCancel: [TimerCancellationToken] + + @inlinable + init() { + self.connections = [] + self.timersToCancel = [] + } + } + + case scheduleTimers(Max2Sequence) + case makeConnection(ConnectionRequest, TinyFastSequence) + case runKeepAlive(Connection, TimerCancellationToken?) + case cancelTimers(TinyFastSequence) + case closeConnection(Connection, Max2Sequence) + case shutdown(Shutdown) + + case none + } + + @usableFromInline + enum RequestAction { + case leaseConnection(TinyFastSequence, Connection) + + case failRequest(Request, ConnectionPoolError) + case failRequests(TinyFastSequence, ConnectionPoolError) + + case none + } + + @usableFromInline + enum PoolState: Sendable { + case running + case shuttingDown(graceful: Bool) + case shutDown + } + + @usableFromInline + struct Timer: Hashable, Sendable { + @usableFromInline + var underlying: ConnectionTimer + + @usableFromInline + var duration: Duration + + @inlinable + var connectionID: ConnectionID { + self.underlying.connectionID + } + + @inlinable + init(_ connectionTimer: ConnectionTimer, duration: Duration) { + self.underlying = connectionTimer + self.duration = duration + } + } + + @usableFromInline let configuration: PoolConfiguration + @usableFromInline let generator: ConnectionIDGenerator + + @usableFromInline + private(set) var connections: ConnectionGroup + @usableFromInline + private(set) var requestQueue: RequestQueue + @usableFromInline + private(set) var poolState: PoolState = .running + @usableFromInline + private(set) var cacheNoMoreConnectionsAllowed: Bool = false + + @usableFromInline + private(set) var failedConsecutiveConnectionAttempts: Int = 0 + + @inlinable + init( + configuration: PoolConfiguration, + generator: ConnectionIDGenerator, + timerCancellationTokenType: TimerCancellationToken.Type + ) { + self.configuration = configuration + self.generator = generator + self.connections = ConnectionGroup( + generator: generator, + minimumConcurrentConnections: configuration.minimumConnectionCount, + maximumConcurrentConnectionSoftLimit: configuration.maximumConnectionSoftLimit, + maximumConcurrentConnectionHardLimit: configuration.maximumConnectionHardLimit, + keepAlive: configuration.keepAliveDuration != nil, + keepAliveReducesAvailableStreams: true + ) + self.requestQueue = RequestQueue() + } + + mutating func refillConnections() -> [ConnectionRequest] { + return self.connections.refillConnections() + } + + @inlinable + mutating func leaseConnection(_ request: Request) -> Action { + switch self.poolState { + case .running: + break + + case .shuttingDown, .shutDown: + return .init( + request: .failRequest(request, ConnectionPoolError.poolShutdown), + connection: .none + ) + } + + if !self.requestQueue.isEmpty && self.cacheNoMoreConnectionsAllowed { + self.requestQueue.queue(request) + return .none() + } + + var soonAvailable: UInt16 = 0 + + // check if any other EL has an idle connection + switch self.connections.leaseConnectionOrSoonAvailableConnectionCount() { + case .leasedConnection(let leaseResult): + return .init( + request: .leaseConnection(TinyFastSequence(element: request), leaseResult.connection), + connection: .cancelTimers(.init(leaseResult.timersToCancel)) + ) + + case .startingCount(let count): + soonAvailable += count + } + + // we tried everything. there is no connection available. now we must check, if and where we + // can create further connections. but first we must enqueue the new request + + self.requestQueue.queue(request) + + let requestAction = RequestAction.none + + if soonAvailable >= self.requestQueue.count { + // if more connections will be soon available then we have waiters, we don't need to + // create further new connections. + return .init( + request: requestAction, + connection: .none + ) + } else if let request = self.connections.createNewDemandConnectionIfPossible() { + // Can we create a demand connection + return .init( + request: requestAction, + connection: .makeConnection(request, .init()) + ) + } else if let request = self.connections.createNewOverflowConnectionIfPossible() { + // Can we create an overflow connection + return .init( + request: requestAction, + connection: .makeConnection(request, .init()) + ) + } else { + self.cacheNoMoreConnectionsAllowed = true + + // no new connections allowed: + return .init(request: requestAction, connection: .none) + } + } + + @inlinable + mutating func releaseConnection(_ connection: Connection, streams: UInt16) -> Action { + guard let (index, context) = self.connections.releaseConnection(connection.id, streams: streams) else { + return .none() + } + return self.handleAvailableConnection(index: index, availableContext: context) + } + + mutating func cancelRequest(id: RequestID) -> Action { + guard let request = self.requestQueue.remove(id) else { + return .none() + } + + return .init( + request: .failRequest(request, ConnectionPoolError.requestCancelled), + connection: .none + ) + } + + @inlinable + mutating func connectionEstablished(_ connection: Connection, maxStreams: UInt16) -> Action { + switch self.poolState { + case .running, .shuttingDown(graceful: true): + let (index, context) = self.connections.newConnectionEstablished(connection, maxStreams: maxStreams) + return self.handleAvailableConnection(index: index, availableContext: context) + case .shuttingDown(graceful: false), .shutDown: + return .init(request: .none, connection: .closeConnection(connection, [])) + } + } + + @inlinable + mutating func connectionReceivedNewMaxStreamSetting( + _ connection: ConnectionID, + newMaxStreamSetting maxStreams: UInt16 + ) -> Action { + guard let info = self.connections.connectionReceivedNewMaxStreamSetting(connection, newMaxStreamSetting: maxStreams) else { + return .none() + } + + let waitingRequests = self.requestQueue.count + + guard waitingRequests > 0 else { + return .none() + } + + // the only thing we can do if we receive a new max stream setting is check if the new stream + // setting is higher and then dequeue some waiting requests + + guard info.newMaxStreams > info.oldMaxStreams && info.newMaxStreams > info.usedStreams else { + return .none() + } + + let leaseStreams = min(info.newMaxStreams - info.oldMaxStreams, info.newMaxStreams - info.usedStreams, UInt16(clamping: waitingRequests)) + let requests = self.requestQueue.pop(max: leaseStreams) + precondition(Int(leaseStreams) == requests.count) + let leaseResult = self.connections.leaseConnection(at: info.index, streams: leaseStreams) + + return .init( + request: .leaseConnection(requests, leaseResult.connection), + connection: .cancelTimers(.init(leaseResult.timersToCancel)) + ) + } + + @inlinable + mutating func timerScheduled(_ timer: Timer, cancelContinuation: TimerCancellationToken) -> TimerCancellationToken? { + self.connections.timerScheduled(timer.underlying, cancelContinuation: cancelContinuation) + } + + @inlinable + mutating func timerTriggered(_ timer: Timer) -> Action { + switch timer.underlying.usecase { + case .backoff: + return self.connectionCreationBackoffDone(timer.connectionID) + case .keepAlive: + return self.connectionKeepAliveTimerTriggered(timer.connectionID) + case .idleTimeout: + return self.connectionIdleTimerTriggered(timer.connectionID) + } + } + + @inlinable + mutating func connectionEstablishFailed(_ error: Error, for request: ConnectionRequest) -> Action { + switch self.poolState { + case .running, .shuttingDown(graceful: true): + self.failedConsecutiveConnectionAttempts += 1 + + let connectionTimer = self.connections.backoffNextConnectionAttempt(request.connectionID) + let backoff = Self.calculateBackoff(failedAttempt: self.failedConsecutiveConnectionAttempts) + let timer = Timer(connectionTimer, duration: backoff) + return .init(request: .none, connection: .scheduleTimers(.init(timer))) + + case .shuttingDown(graceful: false), .shutDown: + return .none() + } + } + + @inlinable + mutating func connectionCreationBackoffDone(_ connectionID: ConnectionID) -> Action { + switch self.poolState { + case .running, .shuttingDown(graceful: true): + let soonAvailable = self.connections.soonAvailableConnections + let retry = (soonAvailable - 1) < self.requestQueue.count + + switch self.connections.backoffDone(connectionID, retry: retry) { + case .createConnection(let request, let continuation): + let timers: TinyFastSequence + if let continuation { + timers = .init(element: continuation) + } else { + timers = .init() + } + return .init(request: .none, connection: .makeConnection(request, timers)) + + case .cancelTimers(let timers): + return .init(request: .none, connection: .cancelTimers(.init(timers))) + } + + case .shuttingDown(graceful: false), .shutDown: + return .none() + } + } + + @inlinable + mutating func connectionKeepAliveTimerTriggered(_ connectionID: ConnectionID) -> Action { + precondition(self.configuration.keepAliveDuration != nil) + precondition(self.requestQueue.isEmpty) + + guard let keepAliveAction = self.connections.keepAliveIfIdle(connectionID) else { + return .none() + } + return .init(request: .none, connection: .runKeepAlive(keepAliveAction.connection, keepAliveAction.keepAliveTimerCancellationContinuation)) + } + + @inlinable + mutating func connectionKeepAliveDone(_ connection: Connection) -> Action { + precondition(self.configuration.keepAliveDuration != nil) + guard let (index, context) = self.connections.keepAliveSucceeded(connection.id) else { + return .none() + } + return self.handleAvailableConnection(index: index, availableContext: context) + } + + @inlinable + mutating func connectionKeepAliveFailed(_ connectionID: ConnectionID) -> Action { + guard let closeAction = self.connections.keepAliveFailed(connectionID) else { + return .none() + } + + return .init(request: .none, connection: .closeConnection(closeAction.connection, closeAction.timersToCancel)) + } + + @inlinable + mutating func connectionIdleTimerTriggered(_ connectionID: ConnectionID) -> Action { + precondition(self.requestQueue.isEmpty) + + guard let closeAction = self.connections.closeConnectionIfIdle(connectionID) else { + return .none() + } + + self.cacheNoMoreConnectionsAllowed = false + return .init(request: .none, connection: .closeConnection(closeAction.connection, closeAction.timersToCancel)) + } + + @inlinable + mutating func connectionClosed(_ connection: Connection) -> Action { + switch self.poolState { + case .running, .shuttingDown(graceful: true): + self.cacheNoMoreConnectionsAllowed = false + + let closedConnectionAction = self.connections.connectionClosed(connection.id) + + let connectionAction: ConnectionAction + if let newRequest = closedConnectionAction.newConnectionRequest { + connectionAction = .makeConnection(newRequest, closedConnectionAction.timersToCancel) + } else { + connectionAction = .cancelTimers(closedConnectionAction.timersToCancel) + } + + return .init(request: .none, connection: connectionAction) + + case .shuttingDown(graceful: false), .shutDown: + return .none() + } + } + + struct CleanupAction { + struct ConnectionToDrop { + var connection: Connection + var keepAliveTimer: Bool + var idleTimer: Bool + } + + var connections: [ConnectionToDrop] + var requests: [Request] + } + + mutating func triggerGracefulShutdown() -> Action { + fatalError("Unimplemented") + } + + mutating func triggerForceShutdown() -> Action { + switch self.poolState { + case .running: + self.poolState = .shuttingDown(graceful: false) + var shutdown = ConnectionAction.Shutdown() + self.connections.triggerForceShutdown(&shutdown) + + if shutdown.connections.isEmpty { + self.poolState = .shutDown + } + + return .init( + request: .failRequests(self.requestQueue.removeAll(), ConnectionPoolError.poolShutdown), + connection: .shutdown(shutdown) + ) + + case .shuttingDown: + return .none() + + case .shutDown: + return .init(request: .none, connection: .none) + } + } + + @inlinable + /*private*/ mutating func handleAvailableConnection( + index: Int, + availableContext: ConnectionGroup.AvailableConnectionContext + ) -> Action { + // this connection was busy before + let requests = self.requestQueue.pop(max: availableContext.info.availableStreams) + if !requests.isEmpty { + let leaseResult = self.connections.leaseConnection(at: index, streams: UInt16(requests.count)) + return .init( + request: .leaseConnection(requests, leaseResult.connection), + connection: .cancelTimers(.init(leaseResult.timersToCancel)) + ) + } + + switch availableContext.use { + case .persisted, .demand: + switch availableContext.info { + case .leased: + return .none() + + case .idle(_, let newIdle): + let timers = self.connections.parkConnection(at: index, hasBecomeIdle: newIdle).map(self.mapTimers) + + return .init( + request: .none, + connection: .scheduleTimers(timers) + ) + } + + case .overflow: + if let closeAction = self.connections.closeConnectionIfIdle(at: index) { + return .init( + request: .none, + connection: .closeConnection(closeAction.connection, closeAction.timersToCancel) + ) + } else { + return .none() + } + } + + } + + @inlinable + /* private */ func mapTimers(_ connectionTimer: ConnectionTimer) -> Timer { + switch connectionTimer.usecase { + case .backoff: + return Timer( + connectionTimer, + duration: Self.calculateBackoff(failedAttempt: self.failedConsecutiveConnectionAttempts) + ) + + case .keepAlive: + return Timer(connectionTimer, duration: self.configuration.keepAliveDuration!) + + case .idleTimeout: + return Timer(connectionTimer, duration: self.configuration.idleTimeoutDuration) + + } + } +} + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +extension PoolStateMachine { + /// Calculates the delay for the next connection attempt after the given number of failed `attempts`. + /// + /// Our backoff formula is: 100ms * 1.25^(attempts - 1) with 3% jitter that is capped of at 1 minute. + /// This means for: + /// - 1 failed attempt : 100ms + /// - 5 failed attempts: ~300ms + /// - 10 failed attempts: ~930ms + /// - 15 failed attempts: ~2.84s + /// - 20 failed attempts: ~8.67s + /// - 25 failed attempts: ~26s + /// - 29 failed attempts: ~60s (max out) + /// + /// - Parameter attempts: number of failed attempts in a row + /// - Returns: time to wait until trying to establishing a new connection + @usableFromInline + static func calculateBackoff(failedAttempt attempts: Int) -> Duration { + // Our backoff formula is: 100ms * 1.25^(attempts - 1) that is capped of at 1minute + // This means for: + // - 1 failed attempt : 100ms + // - 5 failed attempts: ~300ms + // - 10 failed attempts: ~930ms + // - 15 failed attempts: ~2.84s + // - 20 failed attempts: ~8.67s + // - 25 failed attempts: ~26s + // - 29 failed attempts: ~60s (max out) + + let start = Double(100_000_000) + let backoffNanosecondsDouble = start * pow(1.25, Double(attempts - 1)) + + // Cap to 60s _before_ we convert to Int64, to avoid trapping in the Int64 initializer. + let backoffNanoseconds = Int64(min(backoffNanosecondsDouble, Double(60_000_000_000))) + + let backoff = Duration.nanoseconds(backoffNanoseconds) + + // Calculate a 3% jitter range + let jitterRange = (backoffNanoseconds / 100) * 3 + // Pick a random element from the range +/- jitter range. + let jitter: Duration = .nanoseconds((-jitterRange...jitterRange).randomElement()!) + let jitteredBackoff = backoff + jitter + return jitteredBackoff + } +} + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +extension PoolStateMachine.Action: Equatable where TimerCancellationToken: Equatable, Request: Equatable {} + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +extension PoolStateMachine.ConnectionAction: Equatable where TimerCancellationToken: Equatable { + @usableFromInline + static func ==(lhs: Self, rhs: Self) -> Bool { + switch (lhs, rhs) { + case (.scheduleTimers(let lhs), .scheduleTimers(let rhs)): + return lhs == rhs + case (.makeConnection(let lhsRequest, let lhsToken), .makeConnection(let rhsRequest, let rhsToken)): + return lhsRequest == rhsRequest && lhsToken == rhsToken + case (.runKeepAlive(let lhsConn, let lhsToken), .runKeepAlive(let rhsConn, let rhsToken)): + return lhsConn === rhsConn && lhsToken == rhsToken + case (.closeConnection(let lhsConn, let lhsTimers), .closeConnection(let rhsConn, let rhsTimers)): + return lhsConn === rhsConn && lhsTimers == rhsTimers + case (.shutdown(let lhs), .shutdown(let rhs)): + return lhs == rhs + case (.cancelTimers(let lhs), .cancelTimers(let rhs)): + return lhs == rhs + case (.none, .none), + (.cancelTimers([]), .none), (.none, .cancelTimers([])), + (.scheduleTimers([]), .none), (.none, .scheduleTimers([])): + return true + default: + return false + } + } +} + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +extension PoolStateMachine.ConnectionAction.Shutdown: Equatable where TimerCancellationToken: Equatable { + @usableFromInline + static func ==(lhs: Self, rhs: Self) -> Bool { + Set(lhs.connections.lazy.map(\.id)) == Set(rhs.connections.lazy.map(\.id)) && lhs.timersToCancel == rhs.timersToCancel + } +} + + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +extension PoolStateMachine.RequestAction: Equatable where Request: Equatable { + + @usableFromInline + static func ==(lhs: Self, rhs: Self) -> Bool { + switch (lhs, rhs) { + case (.leaseConnection(let lhsRequests, let lhsConn), .leaseConnection(let rhsRequests, let rhsConn)): + guard lhsRequests.count == rhsRequests.count else { return false } + var lhsIterator = lhsRequests.makeIterator() + var rhsIterator = rhsRequests.makeIterator() + while let lhsNext = lhsIterator.next(), let rhsNext = rhsIterator.next() { + guard lhsNext.id == rhsNext.id else { return false } + } + return lhsConn === rhsConn + + case (.failRequest(let lhsRequest, let lhsError), .failRequest(let rhsRequest, let rhsError)): + return lhsRequest.id == rhsRequest.id && lhsError == rhsError + + case (.failRequests(let lhsRequests, let lhsError), .failRequests(let rhsRequests, let rhsError)): + return Set(lhsRequests.lazy.map(\.id)) == Set(rhsRequests.lazy.map(\.id)) && lhsError == rhsError + + case (.none, .none): + return true + + default: + return false + } + } +} diff --git a/Sources/ConnectionPoolModule/TinyFastSequence.swift b/Sources/ConnectionPoolModule/TinyFastSequence.swift new file mode 100644 index 00000000..dff8a30b --- /dev/null +++ b/Sources/ConnectionPoolModule/TinyFastSequence.swift @@ -0,0 +1,205 @@ +/// A `Sequence` that does not heap allocate, if it only carries a single element +@usableFromInline +struct TinyFastSequence: Sequence { + @usableFromInline + enum Base { + case none(reserveCapacity: Int) + case one(Element, reserveCapacity: Int) + case two(Element, Element, reserveCapacity: Int) + case n([Element]) + } + + @usableFromInline + private(set) var base: Base + + @inlinable + init() { + self.base = .none(reserveCapacity: 0) + } + + @inlinable + init(element: Element) { + self.base = .one(element, reserveCapacity: 1) + } + + @inlinable + init(_ collection: some Collection) { + switch collection.count { + case 0: + self.base = .none(reserveCapacity: 0) + case 1: + self.base = .one(collection.first!, reserveCapacity: 0) + default: + if let collection = collection as? Array { + self.base = .n(collection) + } else { + self.base = .n(Array(collection)) + } + } + } + + @inlinable + init(_ max2Sequence: Max2Sequence) { + switch max2Sequence.count { + case 0: + self.base = .none(reserveCapacity: 0) + case 1: + self.base = .one(max2Sequence.first!, reserveCapacity: 0) + case 2: + self.base = .n(Array(max2Sequence)) + default: + fatalError() + } + } + + @usableFromInline + var count: Int { + switch self.base { + case .none: + return 0 + case .one: + return 1 + case .two: + return 2 + case .n(let array): + return array.count + } + } + + @inlinable + var first: Element? { + switch self.base { + case .none: + return nil + case .one(let element, _): + return element + case .two(let first, _, _): + return first + case .n(let array): + return array.first + } + } + + @usableFromInline + var isEmpty: Bool { + switch self.base { + case .none: + return true + case .one, .two, .n: + return false + } + } + + @inlinable + mutating func reserveCapacity(_ minimumCapacity: Int) { + switch self.base { + case .none(let reservedCapacity): + self.base = .none(reserveCapacity: Swift.max(reservedCapacity, minimumCapacity)) + case .one(let element, let reservedCapacity): + self.base = .one(element, reserveCapacity: Swift.max(reservedCapacity, minimumCapacity)) + case .two(let first, let second, let reservedCapacity): + self.base = .two(first, second, reserveCapacity: Swift.max(reservedCapacity, minimumCapacity)) + case .n(var array): + self.base = .none(reserveCapacity: 0) // prevent CoW + array.reserveCapacity(minimumCapacity) + self.base = .n(array) + } + } + + @inlinable + mutating func append(_ element: Element) { + switch self.base { + case .none(let reserveCapacity): + self.base = .one(element, reserveCapacity: reserveCapacity) + case .one(let first, let reserveCapacity): + self.base = .two(first, element, reserveCapacity: reserveCapacity) + + case .two(let first, let second, let reserveCapacity): + var new = [Element]() + new.reserveCapacity(Swift.max(4, reserveCapacity)) + new.append(first) + new.append(second) + new.append(element) + self.base = .n(new) + + case .n(var existing): + self.base = .none(reserveCapacity: 0) // prevent CoW + existing.append(element) + self.base = .n(existing) + } + } + + @inlinable + func makeIterator() -> Iterator { + Iterator(self) + } + + @usableFromInline + struct Iterator: IteratorProtocol { + @usableFromInline private(set) var index: Int = 0 + @usableFromInline private(set) var backing: TinyFastSequence + + @inlinable + init(_ backing: TinyFastSequence) { + self.backing = backing + } + + @inlinable + mutating func next() -> Element? { + switch self.backing.base { + case .none: + return nil + case .one(let element, _): + if self.index == 0 { + self.index += 1 + return element + } + return nil + + case .two(let first, let second, _): + defer { self.index += 1 } + switch self.index { + case 0: + return first + case 1: + return second + default: + return nil + } + + case .n(let array): + if self.index < array.endIndex { + defer { self.index += 1} + return array[self.index] + } + return nil + } + } + } +} + +extension TinyFastSequence: Equatable where Element: Equatable {} +extension TinyFastSequence.Base: Equatable where Element: Equatable {} + +extension TinyFastSequence: Hashable where Element: Hashable {} +extension TinyFastSequence.Base: Hashable where Element: Hashable {} + +extension TinyFastSequence: Sendable where Element: Sendable {} +extension TinyFastSequence.Base: Sendable where Element: Sendable {} + +extension TinyFastSequence: ExpressibleByArrayLiteral { + @inlinable + init(arrayLiteral elements: Element...) { + var iterator = elements.makeIterator() + switch elements.count { + case 0: + self.base = .none(reserveCapacity: 0) + case 1: + self.base = .one(iterator.next()!, reserveCapacity: 0) + case 2: + self.base = .two(iterator.next()!, iterator.next()!, reserveCapacity: 0) + default: + self.base = .n(elements) + } + } +} diff --git a/Sources/ConnectionPoolTestUtils/MockClock.swift b/Sources/ConnectionPoolTestUtils/MockClock.swift new file mode 100644 index 00000000..34bf17e3 --- /dev/null +++ b/Sources/ConnectionPoolTestUtils/MockClock.swift @@ -0,0 +1,176 @@ +import _ConnectionPoolModule +import Atomics +import DequeModule +import NIOConcurrencyHelpers + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +public final class MockClock: Clock { + public struct Instant: InstantProtocol, Comparable { + public typealias Duration = Swift.Duration + + public func advanced(by duration: Self.Duration) -> Self { + .init(self.base + duration) + } + + public func duration(to other: Self) -> Self.Duration { + self.base - other.base + } + + private var base: Swift.Duration + + public init(_ base: Duration) { + self.base = base + } + + public static func < (lhs: Self, rhs: Self) -> Bool { + lhs.base < rhs.base + } + + public static func == (lhs: Self, rhs: Self) -> Bool { + lhs.base == rhs.base + } + } + + private struct State: Sendable { + var now: Instant + + var sleepersHeap: Array + + var waiters: Deque + var nextDeadlines: Deque + + init() { + self.now = .init(.seconds(0)) + self.sleepersHeap = Array() + self.waiters = Deque() + self.nextDeadlines = Deque() + } + } + + private struct Waiter { + var continuation: CheckedContinuation + } + + private struct Sleeper { + var id: Int + + var deadline: Instant + + var continuation: CheckedContinuation + } + + public typealias Duration = Swift.Duration + + public var minimumResolution: Duration { .nanoseconds(1) } + + public var now: Instant { self.stateBox.withLockedValue { $0.now } } + + private let stateBox = NIOLockedValueBox(State()) + private let waiterIDGenerator = ManagedAtomic(0) + + public init() {} + + public func sleep(until deadline: Instant, tolerance: Duration?) async throws { + let waiterID = self.waiterIDGenerator.loadThenWrappingIncrement(ordering: .relaxed) + + return try await withTaskCancellationHandler { + try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in + enum SleepAction { + case none + case resume + case cancel + } + + let action = self.stateBox.withLockedValue { state -> (SleepAction, Waiter?) in + let waiter: Waiter? + if let next = state.waiters.popFirst() { + waiter = next + } else { + state.nextDeadlines.append(deadline) + waiter = nil + } + + if Task.isCancelled { + return (.cancel, waiter) + } + + if state.now >= deadline { + return (.resume, waiter) + } + + let newSleeper = Sleeper(id: waiterID, deadline: deadline, continuation: continuation) + + if let index = state.sleepersHeap.lastIndex(where: { $0.deadline < deadline }) { + state.sleepersHeap.insert(newSleeper, at: index + 1) + } else if let first = state.sleepersHeap.first, first.deadline > deadline { + state.sleepersHeap.insert(newSleeper, at: 0) + } else { + state.sleepersHeap.append(newSleeper) + } + + return (.none, waiter) + } + + switch action.0 { + case .cancel: + continuation.resume(throwing: CancellationError()) + case .resume: + continuation.resume() + case .none: + break + } + + action.1?.continuation.resume(returning: deadline) + } + } onCancel: { + let continuation = self.stateBox.withLockedValue { state -> CheckedContinuation? in + if let index = state.sleepersHeap.firstIndex(where: { $0.id == waiterID }) { + return state.sleepersHeap.remove(at: index).continuation + } + return nil + } + continuation?.resume(throwing: CancellationError()) + } + } + + @discardableResult + public func nextTimerScheduled() async -> Instant { + await withCheckedContinuation { (continuation: CheckedContinuation) in + let instant = self.stateBox.withLockedValue { state -> Instant? in + if let scheduled = state.nextDeadlines.popFirst() { + return scheduled + } else { + let waiter = Waiter(continuation: continuation) + state.waiters.append(waiter) + return nil + } + } + + if let instant { + continuation.resume(returning: instant) + } + } + } + + public func advance(to deadline: Instant) { + let waiters = self.stateBox.withLockedValue { state -> ArraySlice in + precondition(deadline > state.now, "Time can only move forward") + state.now = deadline + + if let newFirstIndex = state.sleepersHeap.firstIndex(where: { $0.deadline > deadline }) { + defer { state.sleepersHeap.removeFirst(newFirstIndex) } + return state.sleepersHeap[0..], [@Sendable ((any Error)?) -> ()]) + case closing([@Sendable ((any Error)?) -> ()]) + case closed + } + + private let lock: NIOLockedValueBox = NIOLockedValueBox(.running([], [])) + + public init(id: Int) { + self.id = id + } + + public var signalToClose: Void { + get async throws { + try await withCheckedThrowingContinuation { continuation in + let runRightAway = self.lock.withLockedValue { state -> Bool in + switch state { + case .running(var continuations, let callbacks): + continuations.append(continuation) + state = .running(continuations, callbacks) + return false + + case .closing, .closed: + return true + } + } + + if runRightAway { + continuation.resume() + } + } + } + } + + public func onClose(_ closure: @escaping @Sendable ((any Error)?) -> ()) { + let enqueued = self.lock.withLockedValue { state -> Bool in + switch state { + case .closed: + return false + + case .running(let continuations, var callbacks): + callbacks.append(closure) + state = .running(continuations, callbacks) + return true + + case .closing(var callbacks): + callbacks.append(closure) + state = .closing(callbacks) + return true + } + } + + if !enqueued { + closure(nil) + } + } + + public func close() { + let continuations = self.lock.withLockedValue { state -> [CheckedContinuation] in + switch state { + case .running(let continuations, let callbacks): + state = .closing(callbacks) + return continuations + + case .closing, .closed: + return [] + } + } + + for continuation in continuations { + continuation.resume() + } + } + + public func closeIfClosing() { + let callbacks = self.lock.withLockedValue { state -> [@Sendable ((any Error)?) -> ()] in + switch state { + case .running, .closed: + return [] + + case .closing(let callbacks): + state = .closed + return callbacks + } + } + + for callback in callbacks { + callback(nil) + } + } +} + +extension MockConnection: CustomStringConvertible { + public var description: String { + let state = self.lock.withLockedValue { $0 } + return "MockConnection(id: \(self.id), state: \(state))" + } +} diff --git a/Sources/ConnectionPoolTestUtils/MockConnectionFactory.swift b/Sources/ConnectionPoolTestUtils/MockConnectionFactory.swift new file mode 100644 index 00000000..936b47cc --- /dev/null +++ b/Sources/ConnectionPoolTestUtils/MockConnectionFactory.swift @@ -0,0 +1,108 @@ +import _ConnectionPoolModule +import DequeModule +import NIOConcurrencyHelpers + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +public final class MockConnectionFactory: Sendable where Clock.Duration == Duration { + public typealias ConnectionIDGenerator = _ConnectionPoolModule.ConnectionIDGenerator + public typealias Request = ConnectionRequest + public typealias KeepAliveBehavior = MockPingPongBehavior + public typealias MetricsDelegate = NoOpConnectionPoolMetrics + public typealias ConnectionID = Int + public typealias Connection = MockConnection + + let stateBox = NIOLockedValueBox(State()) + + struct State { + var attempts = Deque<(ConnectionID, CheckedContinuation<(MockConnection, UInt16), any Error>)>() + + var waiter = Deque), Never>>() + + var runningConnections = [ConnectionID: Connection]() + } + + let autoMaxStreams: UInt16? + + public init(autoMaxStreams: UInt16? = nil) { + self.autoMaxStreams = autoMaxStreams + } + + public var pendingConnectionAttemptsCount: Int { + self.stateBox.withLockedValue { $0.attempts.count } + } + + public var runningConnections: [Connection] { + self.stateBox.withLockedValue { Array($0.runningConnections.values) } + } + + public func makeConnection( + id: Int, + for pool: ConnectionPool, NoOpConnectionPoolMetrics, Clock> + ) async throws -> ConnectionAndMetadata { + if let autoMaxStreams = self.autoMaxStreams { + let connection = MockConnection(id: id) + Task { + try? await connection.signalToClose + connection.closeIfClosing() + } + return .init(connection: connection, maximalStreamsOnConnection: autoMaxStreams) + } + + // we currently don't support cancellation when creating a connection + let result = try await withCheckedThrowingContinuation { (checkedContinuation: CheckedContinuation<(MockConnection, UInt16), any Error>) in + let waiter = self.stateBox.withLockedValue { state -> (CheckedContinuation<(ConnectionID, CheckedContinuation<(MockConnection, UInt16), any Error>), Never>)? in + if let waiter = state.waiter.popFirst() { + return waiter + } else { + state.attempts.append((id, checkedContinuation)) + return nil + } + } + + if let waiter { + waiter.resume(returning: (id, checkedContinuation)) + } + } + + return .init(connection: result.0, maximalStreamsOnConnection: result.1) + } + + @discardableResult + public func nextConnectAttempt(_ closure: (ConnectionID) async throws -> UInt16) async rethrows -> Connection { + let (connectionID, continuation) = await withCheckedContinuation { (continuation: CheckedContinuation<(ConnectionID, CheckedContinuation<(MockConnection, UInt16), any Error>), Never>) in + let attempt = self.stateBox.withLockedValue { state -> (ConnectionID, CheckedContinuation<(MockConnection, UInt16), any Error>)? in + if let attempt = state.attempts.popFirst() { + return attempt + } else { + state.waiter.append(continuation) + return nil + } + } + + if let attempt { + continuation.resume(returning: attempt) + } + } + + do { + let streamCount = try await closure(connectionID) + let connection = MockConnection(id: connectionID) + + connection.onClose { _ in + self.stateBox.withLockedValue { state in + _ = state.runningConnections.removeValue(forKey: connectionID) + } + } + + self.stateBox.withLockedValue { state in + _ = state.runningConnections[connectionID] = connection + } + + continuation.resume(returning: (connection, streamCount)) + return connection + } catch { + continuation.resume(throwing: error) + throw error + } + } +} diff --git a/Sources/ConnectionPoolTestUtils/MockPingPongBehaviour.swift b/Sources/ConnectionPoolTestUtils/MockPingPongBehaviour.swift new file mode 100644 index 00000000..de1a7275 --- /dev/null +++ b/Sources/ConnectionPoolTestUtils/MockPingPongBehaviour.swift @@ -0,0 +1,70 @@ +import _ConnectionPoolModule +import DequeModule +import NIOConcurrencyHelpers + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +public final class MockPingPongBehavior: ConnectionKeepAliveBehavior { + public let keepAliveFrequency: Duration? + + let stateBox = NIOLockedValueBox(State()) + + struct State { + var runs = Deque<(Connection, CheckedContinuation)>() + + var waiter = Deque), Never>>() + } + + public init(keepAliveFrequency: Duration?, connectionType: Connection.Type) { + self.keepAliveFrequency = keepAliveFrequency + } + + public func runKeepAlive(for connection: Connection) async throws { + precondition(self.keepAliveFrequency != nil) + + // we currently don't support cancellation when creating a connection + let success = try await withCheckedThrowingContinuation { (checkedContinuation: CheckedContinuation) -> () in + let waiter = self.stateBox.withLockedValue { state -> (CheckedContinuation<(Connection, CheckedContinuation), Never>)? in + if let waiter = state.waiter.popFirst() { + return waiter + } else { + state.runs.append((connection, checkedContinuation)) + return nil + } + } + + if let waiter { + waiter.resume(returning: (connection, checkedContinuation)) + } + } + + precondition(success) + } + + @discardableResult + public func nextKeepAlive(_ closure: (Connection) async throws -> Bool) async rethrows -> Connection { + let (connection, continuation) = await withCheckedContinuation { (continuation: CheckedContinuation<(Connection, CheckedContinuation), Never>) in + let run = self.stateBox.withLockedValue { state -> (Connection, CheckedContinuation)? in + if let run = state.runs.popFirst() { + return run + } else { + state.waiter.append(continuation) + return nil + } + } + + if let run { + continuation.resume(returning: run) + } + } + + do { + let success = try await closure(connection) + + continuation.resume(returning: success) + return connection + } catch { + continuation.resume(throwing: error) + throw error + } + } +} diff --git a/Sources/ConnectionPoolTestUtils/MockRequest.swift b/Sources/ConnectionPoolTestUtils/MockRequest.swift new file mode 100644 index 00000000..5e4e2fc0 --- /dev/null +++ b/Sources/ConnectionPoolTestUtils/MockRequest.swift @@ -0,0 +1,29 @@ +import _ConnectionPoolModule + +public final class MockRequest: ConnectionRequestProtocol, Hashable, Sendable { + public typealias Connection = MockConnection + + public struct ID: Hashable, Sendable { + var objectID: ObjectIdentifier + + init(_ request: MockRequest) { + self.objectID = ObjectIdentifier(request) + } + } + + public init() {} + + public var id: ID { ID(self) } + + public static func ==(lhs: MockRequest, rhs: MockRequest) -> Bool { + lhs.id == rhs.id + } + + public func hash(into hasher: inout Hasher) { + hasher.combine(self.id) + } + + public func complete(with: Result) { + + } +} diff --git a/Sources/PostgresNIO/Connection/PostgresConnection+Authenticate.swift b/Sources/PostgresNIO/Connection/PostgresConnection+Authenticate.swift deleted file mode 100644 index ac2c896c..00000000 --- a/Sources/PostgresNIO/Connection/PostgresConnection+Authenticate.swift +++ /dev/null @@ -1,109 +0,0 @@ -import Crypto -import NIO -import Logging - -extension PostgresConnection { - public func authenticate( - username: String, - database: String? = nil, - password: String? = nil, - logger: Logger = .init(label: "codes.vapor.postgres") - ) -> EventLoopFuture { - let auth = PostgresAuthenticationRequest( - username: username, - database: database, - password: password - ) - return self.send(auth, logger: self.logger) - } -} - -// MARK: Private - -private final class PostgresAuthenticationRequest: PostgresRequest { - enum State { - case ready - case done - } - - let username: String - let database: String? - let password: String? - var state: State - - init(username: String, database: String?, password: String?) { - self.state = .ready - self.username = username - self.database = database - self.password = password - } - - func log(to logger: Logger) { - logger.debug("Logging into Postgres db \(self.database ?? "nil") as \(self.username)") - } - - func respond(to message: PostgresMessage) throws -> [PostgresMessage]? { - if case .error = message.identifier { - // terminate immediately on error - return nil - } - - switch self.state { - case .ready: - switch message.identifier { - case .authentication: - let auth = try PostgresMessage.Authentication(message: message) - switch auth { - case .md5(let salt): - let pwdhash = self.md5((self.password ?? "") + self.username).hexdigest() - let hash = "md5" + self.md5(self.bytes(pwdhash) + salt).hexdigest() - return try [PostgresMessage.Password(string: hash).message()] - case .plaintext: - return try [PostgresMessage.Password(string: self.password ?? "").message()] - case .ok: - self.state = .done - return [] - } - default: throw PostgresError.protocol("Unexpected response to start message: \(message)") - } - case .done: - switch message.identifier { - case .parameterStatus: - // self.status[status.parameter] = status.value - return [] - case .backendKeyData: - // self.processID = data.processID - // self.secretKey = data.secretKey - return [] - case .readyForQuery: - return nil - default: throw PostgresError.protocol("Unexpected response to password authentication: \(message)") - } - } - - } - - func start() throws -> [PostgresMessage] { - return try [ - PostgresMessage.Startup.versionThree(parameters: [ - "user": self.username, - "database": self.database ?? username - ]).message() - ] - } - - // MARK: Private - - private func md5(_ string: String) -> [UInt8] { - return md5(self.bytes(string)) - } - - private func md5(_ message: [UInt8]) -> [UInt8] { - let digest = Insecure.MD5.hash(data: message) - return .init(digest) - } - - func bytes(_ string: String) -> [UInt8] { - return Array(string.utf8) - } -} diff --git a/Sources/PostgresNIO/Connection/PostgresConnection+Configuration.swift b/Sources/PostgresNIO/Connection/PostgresConnection+Configuration.swift new file mode 100644 index 00000000..b260723a --- /dev/null +++ b/Sources/PostgresNIO/Connection/PostgresConnection+Configuration.swift @@ -0,0 +1,294 @@ +import NIOCore +import NIOPosix // inet_pton() et al. +import NIOSSL + +extension PostgresConnection { + /// A configuration object for a connection + public struct Configuration: Sendable { + + // MARK: - TLS + + /// The possible modes of operation for TLS encapsulation of a connection. + public struct TLS: Sendable { + // MARK: Initializers + + /// Do not try to create a TLS connection to the server. + public static var disable: Self { .init(base: .disable) } + + /// Try to create a TLS connection to the server. If the server supports TLS, create a TLS connection. + /// If the server does not support TLS, create an insecure connection. + public static func prefer(_ sslContext: NIOSSLContext) -> Self { + self.init(base: .prefer(sslContext)) + } + + /// Try to create a TLS connection to the server. If the server supports TLS, create a TLS connection. + /// If the server does not support TLS, fail the connection creation. + public static func require(_ sslContext: NIOSSLContext) -> Self { + self.init(base: .require(sslContext)) + } + + // MARK: Accessors + + /// Whether TLS will be attempted on the connection (`false` only when mode is ``disable``). + public var isAllowed: Bool { + if case .disable = self.base { return false } + else { return true } + } + + /// Whether TLS will be enforced on the connection (`true` only when mode is ``require(_:)``). + public var isEnforced: Bool { + if case .require(_) = self.base { return true } + else { return false } + } + + /// The `NIOSSLContext` that will be used. `nil` when TLS is disabled. + public var sslContext: NIOSSLContext? { + switch self.base { + case .prefer(let context), .require(let context): return context + case .disable: return nil + } + } + + // MARK: Implementation details + + enum Base { + case disable + case prefer(NIOSSLContext) + case require(NIOSSLContext) + } + let base: Base + private init(base: Base) { self.base = base } + } + + // MARK: - Connection options + + /// Describes options affecting how the underlying connection is made. + public struct Options: Sendable { + /// A timeout for connection attempts. Defaults to ten seconds. + /// + /// Ignored when using a preexisting communcation channel. (See + /// ``PostgresConnection/Configuration/init(establishedChannel:username:password:database:)``.) + public var connectTimeout: TimeAmount + + /// The server name to use for certificate validation and SNI (Server Name Indication) when TLS is enabled. + /// Defaults to none (but see below). + /// + /// > When set to `nil`: + /// If the connection is made to a server over TCP using + /// ``PostgresConnection/Configuration/init(host:port:username:password:database:tls:)``, the given `host` + /// is used, unless it was an IP address string. If it _was_ an IP, or the connection is made by any other + /// method, SNI is disabled. + public var tlsServerName: String? + + /// Whether the connection is required to provide backend key data (internal Postgres stuff). + /// + /// This property is provided for compatibility with Amazon RDS Proxy, which requires it to be `false`. + /// If you are not using Amazon RDS Proxy, you should leave this set to `true` (the default). + public var requireBackendKeyData: Bool + + /// Additional parameters to send to the server on startup. The name value pairs are added to the initial + /// startup message that the client sends to the server. + public var additionalStartupParameters: [(String, String)] + + /// Create an options structure with default values. + /// + /// Most users should not need to adjust the defaults. + public init() { + self.connectTimeout = .seconds(10) + self.tlsServerName = nil + self.requireBackendKeyData = true + self.additionalStartupParameters = [] + } + } + + // MARK: - Accessors + + /// The hostname to connect to for TCP configurations. + /// + /// Always `nil` for other configurations. + public var host: String? { + if case let .connectTCP(host, _) = self.endpointInfo { return host } + else { return nil } + } + + /// The port to connect to for TCP configurations. + /// + /// Always `nil` for other configurations. + public var port: Int? { + if case let .connectTCP(_, port) = self.endpointInfo { return port } + else { return nil } + } + + /// The socket path to connect to for Unix domain socket connections. + /// + /// Always `nil` for other configurations. + public var unixSocketPath: String? { + if case let .bindUnixDomainSocket(path) = self.endpointInfo { return path } + else { return nil } + } + + /// The `Channel` to use in existing-channel configurations. + /// + /// Always `nil` for other configurations. + public var establishedChannel: Channel? { + if case let .configureChannel(channel) = self.endpointInfo { return channel } + else { return nil } + } + + /// The TLS mode to use for the connection. Valid for all configurations. + /// + /// See ``TLS-swift.struct``. + public var tls: TLS + + /// Options for handling the communication channel. Most users don't need to change these. + /// + /// See ``Options-swift.struct``. + public var options: Options = .init() + + /// The username to connect with. + public var username: String + + /// The password, if any, for the user specified by ``username``. + /// + /// - Warning: `nil` means "no password provided", whereas `""` (the empty string) is a password of zero + /// length; these are not the same thing. + public var password: String? + + /// The name of the database to open. + /// + /// - Note: If set to `nil` or an empty string, the provided ``username`` is used. + public var database: String? + + // MARK: - Initializers + + /// Create a configuration for connecting to a server with a hostname and optional port. + /// + /// This specifies a TCP connection. If you're unsure which kind of connection you want, you almost + /// definitely want this one. + /// + /// - Parameters: + /// - host: The hostname to connect to. + /// - port: The TCP port to connect to (defaults to 5432). + /// - tls: The TLS mode to use. + public init(host: String, port: Int = 5432, username: String, password: String?, database: String?, tls: TLS) { + self.init(endpointInfo: .connectTCP(host: host, port: port), tls: tls, username: username, password: password, database: database) + } + + /// Create a configuration for connecting to a server through a UNIX domain socket. + /// + /// - Parameters: + /// - path: The filesystem path of the socket to connect to. + /// - tls: The TLS mode to use. Defaults to ``TLS-swift.struct/disable``. + public init(unixSocketPath: String, username: String, password: String?, database: String?) { + self.init(endpointInfo: .bindUnixDomainSocket(path: unixSocketPath), tls: .disable, username: username, password: password, database: database) + } + + /// Create a configuration for establishing a connection to a Postgres server over a preestablished + /// `NIOCore/Channel`. + /// + /// This is provided for calling code which wants to manage the underlying connection transport on its + /// own, such as when tunneling a connection through SSH. + /// + /// - Parameters: + /// - channel: The `NIOCore/Channel` to use. The channel must already be active and connected to an + /// endpoint (i.e. `NIOCore/Channel/isActive` must be `true`). + /// - tls: The TLS mode to use. + public init(establishedChannel channel: Channel, tls: PostgresConnection.Configuration.TLS, username: String, password: String?, database: String?) { + self.init(endpointInfo: .configureChannel(channel), tls: tls, username: username, password: password, database: database) + } + + /// Create a configuration for establishing a connection to a Postgres server over a preestablished + /// `NIOCore/Channel`. + /// + /// This is provided for calling code which wants to manage the underlying connection transport on its + /// own, such as when tunneling a connection through SSH. + /// + /// - Parameters: + /// - channel: The `NIOCore/Channel` to use. The channel must already be active and connected to an + /// endpoint (i.e. `NIOCore/Channel/isActive` must be `true`). + public init(establishedChannel channel: Channel, username: String, password: String?, database: String?) { + self.init(establishedChannel: channel, tls: .disable, username: username, password: password, database: database) + } + + // MARK: - Implementation details + + enum EndpointInfo { + case configureChannel(Channel) + case bindUnixDomainSocket(path: String) + case connectTCP(host: String, port: Int) + } + + var endpointInfo: EndpointInfo + + init(endpointInfo: EndpointInfo, tls: TLS, username: String, password: String?, database: String?) { + self.endpointInfo = endpointInfo + self.tls = tls + self.username = username + self.password = password + self.database = database + } + } +} + +// MARK: - Internal config details + +extension PostgresConnection { + /// A configuration object to bring the new ``PostgresConnection.Configuration`` together with + /// the deprecated configuration. + /// + /// TODO: Drop with next major release + struct InternalConfiguration: Sendable { + enum Connection { + case unresolvedTCP(host: String, port: Int) + case unresolvedUDS(path: String) + case resolved(address: SocketAddress) + case bootstrapped(channel: Channel) + } + + let connection: InternalConfiguration.Connection + let username: String? + let password: String? + let database: String? + var tls: Configuration.TLS + let options: Configuration.Options + } +} + +extension PostgresConnection.InternalConfiguration { + init(_ config: PostgresConnection.Configuration) { + switch config.endpointInfo { + case .connectTCP(let host, let port): self.connection = .unresolvedTCP(host: host, port: port) + case .bindUnixDomainSocket(let path): self.connection = .unresolvedUDS(path: path) + case .configureChannel(let channel): self.connection = .bootstrapped(channel: channel) + } + self.username = config.username + self.password = config.password + self.database = config.database + self.tls = config.tls + self.options = config.options + } + + var serverNameForTLS: String? { + // If a name was explicitly configured, always use it. + if let tlsServerName = self.options.tlsServerName { return tlsServerName } + + // Otherwise, if the connection is TCP and the hostname wasn't an IP (not valid in SNI), use that. + if case .unresolvedTCP(let host, _) = self.connection, !host.isIPAddress() { return host } + + // Otherwise, disable SNI + return nil + } +} + +// originally taken from NIOSSL +private extension String { + func isIPAddress() -> Bool { + // We need some scratch space to let inet_pton write into. + var ipv4Addr = in_addr(), ipv6Addr = in6_addr() // inet_pton() assumes the provided address buffer is non-NULL + + /// N.B.: ``String/withCString(_:)`` is much more efficient than directly passing `self`, especially twice. + return self.withCString { ptr in + inet_pton(AF_INET, ptr, &ipv4Addr) == 1 || inet_pton(AF_INET6, ptr, &ipv6Addr) == 1 + } + } +} diff --git a/Sources/PostgresNIO/Connection/PostgresConnection+Connect.swift b/Sources/PostgresNIO/Connection/PostgresConnection+Connect.swift deleted file mode 100644 index c178becb..00000000 --- a/Sources/PostgresNIO/Connection/PostgresConnection+Connect.swift +++ /dev/null @@ -1,51 +0,0 @@ -import Logging -import NIO - -extension PostgresConnection { - public static func connect( - to socketAddress: SocketAddress, - tlsConfiguration: TLSConfiguration? = nil, - serverHostname: String? = nil, - logger: Logger = .init(label: "codes.vapor.postgres"), - on eventLoop: EventLoop - ) -> EventLoopFuture { - let bootstrap = ClientBootstrap(group: eventLoop) - .channelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1) - return bootstrap.connect(to: socketAddress).flatMap { channel in - return channel.pipeline.addHandlers([ - ByteToMessageHandler(PostgresMessageDecoder(logger: logger)), - MessageToByteHandler(PostgresMessageEncoder(logger: logger)), - PostgresRequestHandler(logger: logger), - PostgresErrorHandler(logger: logger) - ]).map { - return PostgresConnection(channel: channel, logger: logger) - } - }.flatMap { (conn: PostgresConnection) in - if let tlsConfiguration = tlsConfiguration { - return conn.requestTLS( - using: tlsConfiguration, - serverHostname: serverHostname, - logger: logger - ).map { conn } - } else { - return eventLoop.makeSucceededFuture(conn) - } - } - } -} - - -private final class PostgresErrorHandler: ChannelInboundHandler { - typealias InboundIn = Never - - let logger: Logger - init(logger: Logger) { - self.logger = logger - } - - func errorCaught(context: ChannelHandlerContext, error: Error) { - self.logger.error("Uncaught error: \(error)") - context.close(promise: nil) - context.fireErrorCaught(error) - } -} diff --git a/Sources/PostgresNIO/Connection/PostgresConnection+Database.swift b/Sources/PostgresNIO/Connection/PostgresConnection+Database.swift deleted file mode 100644 index 9c6ce553..00000000 --- a/Sources/PostgresNIO/Connection/PostgresConnection+Database.swift +++ /dev/null @@ -1,127 +0,0 @@ -import Logging - -extension PostgresConnection: PostgresDatabase { - public func send( - _ request: PostgresRequest, - logger: Logger - ) -> EventLoopFuture { - request.log(to: logger) - let promise = self.channel.eventLoop.makePromise(of: Void.self) - let request = PostgresRequestContext(delegate: request, promise: promise) - self.channel.write(request).cascadeFailure(to: promise) - self.channel.flush() - return promise.futureResult - } - - public func withConnection(_ closure: (PostgresConnection) -> EventLoopFuture) -> EventLoopFuture { - closure(self) - } -} - -final class PostgresRequestContext { - let delegate: PostgresRequest - let promise: EventLoopPromise - var lastError: Error? - - init(delegate: PostgresRequest, promise: EventLoopPromise) { - self.delegate = delegate - self.promise = promise - } -} - -class PostgresRequestHandler: ChannelDuplexHandler { - typealias InboundIn = PostgresMessage - typealias OutboundIn = PostgresRequestContext - typealias OutboundOut = PostgresMessage - - private var queue: [PostgresRequestContext] - let logger: Logger - - public init(logger: Logger) { - self.queue = [] - self.logger = logger - } - - private func _channelRead(context: ChannelHandlerContext, data: NIOAny) throws { - let message = self.unwrapInboundIn(data) - guard self.queue.count > 0 else { - // discard packet - return - } - let request = self.queue[0] - - switch message.identifier { - case .error: - let error = try PostgresMessage.Error(message: message) - self.logger.error("\(error)") - request.lastError = PostgresError.server(error) - case .notice: - let notice = try PostgresMessage.Error(message: message) - self.logger.notice("\(notice)") - default: break - } - - if let responses = try request.delegate.respond(to: message) { - for response in responses { - context.write(self.wrapOutboundOut(response), promise: nil) - } - context.flush() - } else { - self.queue.removeFirst() - if let error = request.lastError { - request.promise.fail(error) - } else { - request.promise.succeed(()) - } - } - } - - func channelRead(context: ChannelHandlerContext, data: NIOAny) { - do { - try self._channelRead(context: context, data: data) - } catch { - self.errorCaught(context: context, error: error) - } - // Regardless of error, also pass the message downstream; this is necessary for PostgresNotificationHandler (which is appended at the end) to receive notifications - context.fireChannelRead(data) - } - - func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise?) { - let request = self.unwrapOutboundIn(data) - self.queue.append(request) - do { - let messages = try request.delegate.start() - self.write(context: context, items: messages, promise: promise) - context.flush() - } catch { - promise?.fail(error) - self.errorCaught(context: context, error: error) - } - } - - func close(context: ChannelHandlerContext, mode: CloseMode, promise: EventLoopPromise?) { - let terminate = try! PostgresMessage.Terminate().message() - context.write(self.wrapOutboundOut(terminate), promise: nil) - context.close(mode: mode, promise: promise) - - for current in self.queue { - current.promise.fail(PostgresError.connectionClosed) - } - self.queue = [] - } -} - - -extension ChannelInboundHandler { - func write(context: ChannelHandlerContext, items: [OutboundOut], promise: EventLoopPromise?) { - var items = items - if let last = items.popLast() { - for item in items { - context.write(self.wrapOutboundOut(item), promise: nil) - } - context.write(self.wrapOutboundOut(last), promise: promise) - } else { - promise?.succeed(()) - } - } -} diff --git a/Sources/PostgresNIO/Connection/PostgresConnection+Notifications.swift b/Sources/PostgresNIO/Connection/PostgresConnection+Notifications.swift deleted file mode 100644 index 6acb721a..00000000 --- a/Sources/PostgresNIO/Connection/PostgresConnection+Notifications.swift +++ /dev/null @@ -1,62 +0,0 @@ -import NIO -import Logging - -/// Context for receiving NotificationResponse messages on a connection, used for PostgreSQL's `LISTEN`/`NOTIFY` support. -public final class PostgresListenContext { - var stopper: (() -> Void)? - - /// Detach this listener so it no longer receives notifications. Other listeners, including those for the same channel, are unaffected. `UNLISTEN` is not sent; you are responsible for issuing an `UNLISTEN` query yourself if it is appropriate for your application. - public func stop() { - stopper?() - stopper = nil - } -} - -extension PostgresConnection { - /// Add a handler for NotificationResponse messages on a certain channel. This is used in conjunction with PostgreSQL's `LISTEN`/`NOTIFY` support: to listen on a channel, you add a listener using this method to handle the NotificationResponse messages, then issue a `LISTEN` query to instruct PostgreSQL to begin sending NotificationResponse messages. - @discardableResult - public func addListener(channel: String, handler notificationHandler: @escaping (PostgresListenContext, PostgresMessage.NotificationResponse) -> Void) -> PostgresListenContext { - let listenContext = PostgresListenContext() - let channelHandler = PostgresNotificationHandler(logger: self.logger, channel: channel, notificationHandler: notificationHandler, listenContext: listenContext) - let pipeline = self.channel.pipeline - _ = pipeline.addHandler(channelHandler, name: nil, position: .last) - listenContext.stopper = { [pipeline, unowned channelHandler] in - _ = pipeline.removeHandler(channelHandler) - } - return listenContext - } -} - -final class PostgresNotificationHandler: ChannelInboundHandler, RemovableChannelHandler { - typealias InboundIn = PostgresMessage - typealias InboundOut = PostgresMessage - - let logger: Logger - let channel: String - let notificationHandler: (PostgresListenContext, PostgresMessage.NotificationResponse) -> Void - let listenContext: PostgresListenContext - - init(logger: Logger, channel: String, notificationHandler: @escaping (PostgresListenContext, PostgresMessage.NotificationResponse) -> Void, listenContext: PostgresListenContext) { - self.logger = logger - self.channel = channel - self.notificationHandler = notificationHandler - self.listenContext = listenContext - } - - func channelRead(context: ChannelHandlerContext, data: NIOAny) { - let request = self.unwrapInboundIn(data) - // Slightly complicated: We need to dispatch downstream _before_ we handle the notification ourselves, because the notification handler could try to stop the listen, which removes ourselves from the pipeline and makes fireChannelRead not work any more. - context.fireChannelRead(self.wrapInboundOut(request)) - if request.identifier == .notificationResponse { - do { - var data = request.data - let notification = try PostgresMessage.NotificationResponse.parse(from: &data) - if notification.channel == channel { - self.notificationHandler(self.listenContext, notification) - } - } catch let error { - self.logger.error("\(error)") - } - } - } -} diff --git a/Sources/PostgresNIO/Connection/PostgresConnection+RequestTLS.swift b/Sources/PostgresNIO/Connection/PostgresConnection+RequestTLS.swift deleted file mode 100644 index f9fab9ab..00000000 --- a/Sources/PostgresNIO/Connection/PostgresConnection+RequestTLS.swift +++ /dev/null @@ -1,53 +0,0 @@ -import NIOSSL -import Logging - -extension PostgresConnection { - internal func requestTLS( - using tlsConfig: TLSConfiguration, - serverHostname: String?, - logger: Logger - ) -> EventLoopFuture { - let tls = RequestTLSQuery() - return self.send(tls, logger: logger).flatMapThrowing { _ in - guard tls.isSupported else { - throw PostgresError.protocol("Server does not support TLS") - } - let sslContext = try NIOSSLContext(configuration: tlsConfig) - let handler = try NIOSSLClientHandler(context: sslContext, serverHostname: serverHostname) - _ = self.channel.pipeline.addHandler(handler, position: .first) - } - } -} - -// MARK: Private - -private final class RequestTLSQuery: PostgresRequest { - var isSupported: Bool - - init() { - self.isSupported = false - } - - func log(to logger: Logger) { - logger.debug("Requesting TLS") - } - - func respond(to message: PostgresMessage) throws -> [PostgresMessage]? { - switch message.identifier { - case .sslSupported: - self.isSupported = true - return nil - case .sslUnsupported: - self.isSupported = false - return nil - default: throw PostgresError.protocol("Unexpected message during TLS request: \(message)") - } - } - - func start() throws -> [PostgresMessage] { - return try [ - PostgresMessage.SSLRequest().message() - ] - } -} - diff --git a/Sources/PostgresNIO/Connection/PostgresConnection.swift b/Sources/PostgresNIO/Connection/PostgresConnection.swift index 7cc8d728..e267d8f9 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection.swift @@ -1,36 +1,867 @@ -import Foundation +import Atomics +import NIOCore +import NIOPosix +#if canImport(Network) +import NIOTransportServices +#endif +import NIOSSL import Logging -public final class PostgresConnection { - let channel: Channel - +/// A Postgres connection. Use it to run queries against a Postgres server. +/// +/// Thread safety is achieved by dispatching all access to shared state onto the underlying EventLoop. +public final class PostgresConnection: @unchecked Sendable { + /// A Postgres connection ID + public typealias ID = Int + + /// The connection's underlying channel + /// + /// This should be private, but it is needed for `PostgresConnection` compatibility. + internal let channel: Channel + + /// The underlying `EventLoop` of both the connection and its channel. public var eventLoop: EventLoop { return self.channel.eventLoop } - + public var closeFuture: EventLoopFuture { - return channel.closeFuture + return self.channel.closeFuture + } + + /// A logger to use in case + public var logger: Logger { + get { + self._logger + } + set { + // ignore + } } - - public var logger: Logger + + private let internalListenID = ManagedAtomic(0) public var isClosed: Bool { return !self.channel.isActive } - - init(channel: Channel, logger: Logger) { + + public let id: ID + + private var _logger: Logger + + init(channel: Channel, connectionID: ID, logger: Logger) { self.channel = channel - self.logger = logger + self.id = connectionID + self._logger = logger + } + deinit { + assert(self.isClosed, "PostgresConnection deinitialized before being closed.") + } + + func start(configuration: InternalConfiguration) -> EventLoopFuture { + // 1. configure handlers + + let configureSSLCallback: ((Channel, PostgresChannelHandler) throws -> ())? + + switch configuration.tls.base { + case .prefer(let context), .require(let context): + configureSSLCallback = { channel, postgresChannelHandler in + channel.eventLoop.assertInEventLoop() + + let sslHandler = try NIOSSLClientHandler( + context: context, + serverHostname: configuration.serverNameForTLS + ) + try channel.pipeline.syncOperations.addHandler(sslHandler, position: .before(postgresChannelHandler)) + } + case .disable: + configureSSLCallback = nil + } + + let channelHandler = PostgresChannelHandler( + configuration: configuration, + eventLoop: channel.eventLoop, + logger: logger, + configureSSLCallback: configureSSLCallback + ) + + let eventHandler = PSQLEventsHandler(logger: logger) + + // 2. add handlers + + do { + try self.channel.pipeline.syncOperations.addHandler(eventHandler) + try self.channel.pipeline.syncOperations.addHandler(channelHandler, position: .before(eventHandler)) + } catch { + return self.eventLoop.makeFailedFuture(error) + } + + let startupFuture: EventLoopFuture + if configuration.username == nil { + startupFuture = eventHandler.readyForStartupFuture + } else { + startupFuture = eventHandler.authenticateFuture + } + + // 3. wait for startup future to succeed. + + return startupFuture.flatMapError { error in + // in case of an startup error, the connection must be closed and after that + // the originating error should be surfaced + + self.channel.closeFuture.flatMapThrowing { _ in + throw error + } + } + } + + /// Create a new connection to a Postgres server + /// + /// - Parameters: + /// - eventLoop: The `EventLoop` the request shall be created on + /// - configuration: A ``Configuration`` that shall be used for the connection + /// - connectionID: An `Int` id, used for metadata logging + /// - logger: A logger to log background events into + /// - Returns: A SwiftNIO `EventLoopFuture` that will provide a ``PostgresConnection`` + /// at a later point in time. + public static func connect( + on eventLoop: EventLoop, + configuration: PostgresConnection.Configuration, + id connectionID: ID, + logger: Logger + ) -> EventLoopFuture { + self.connect( + connectionID: connectionID, + configuration: .init(configuration), + logger: logger, + on: eventLoop + ) + } + + static func connect( + connectionID: ID, + configuration: PostgresConnection.InternalConfiguration, + logger: Logger, + on eventLoop: EventLoop + ) -> EventLoopFuture { + + var mlogger = logger + mlogger[postgresMetadataKey: .connectionID] = "\(connectionID)" + let logger = mlogger + + // Here we dispatch to the `eventLoop` first before we setup the EventLoopFuture chain, to + // ensure all `flatMap`s are executed on the EventLoop (this means the enqueuing of the + // callbacks). + // + // This saves us a number of context switches between the thread the Connection is created + // on and the EventLoop. In addition, it eliminates all potential races between the creating + // thread and the EventLoop. + return eventLoop.flatSubmit { () -> EventLoopFuture in + let connectFuture: EventLoopFuture + + switch configuration.connection { + case .resolved(let address): + let bootstrap = self.makeBootstrap(on: eventLoop, configuration: configuration) + connectFuture = bootstrap.connect(to: address) + case .unresolvedTCP(let host, let port): + let bootstrap = self.makeBootstrap(on: eventLoop, configuration: configuration) + connectFuture = bootstrap.connect(host: host, port: port) + case .unresolvedUDS(let path): + let bootstrap = self.makeBootstrap(on: eventLoop, configuration: configuration) + connectFuture = bootstrap.connect(unixDomainSocketPath: path) + case .bootstrapped(let channel): + guard channel.isActive else { + return eventLoop.makeFailedFuture(PSQLError.connectionError(underlying: ChannelError.alreadyClosed)) + } + connectFuture = eventLoop.makeSucceededFuture(channel) + } + + return connectFuture.flatMap { channel -> EventLoopFuture in + let connection = PostgresConnection(channel: channel, connectionID: connectionID, logger: logger) + return connection.start(configuration: configuration).map { _ in connection } + }.flatMapErrorThrowing { error -> PostgresConnection in + switch error { + case is PSQLError: + throw error + default: + throw PSQLError.connectionError(underlying: error) + } + } + } + } + + static func makeBootstrap( + on eventLoop: EventLoop, + configuration: PostgresConnection.InternalConfiguration + ) -> NIOClientTCPBootstrapProtocol { + #if canImport(Network) + if let tsBootstrap = NIOTSConnectionBootstrap(validatingGroup: eventLoop) { + return tsBootstrap.connectTimeout(configuration.options.connectTimeout) + } + #endif + + if let nioBootstrap = ClientBootstrap(validatingGroup: eventLoop) { + return nioBootstrap.connectTimeout(configuration.options.connectTimeout) + } + + fatalError("No matching bootstrap found") + } + + // MARK: Query + + private func queryStream(_ query: PostgresQuery, logger: Logger) -> EventLoopFuture { + var logger = logger + logger[postgresMetadataKey: .connectionID] = "\(self.id)" + guard query.binds.count <= Int(UInt16.max) else { + return self.channel.eventLoop.makeFailedFuture(PSQLError(code: .tooManyParameters, query: query)) + } + + let promise = self.channel.eventLoop.makePromise(of: PSQLRowStream.self) + let context = ExtendedQueryContext( + query: query, + logger: logger, + promise: promise + ) + + self.channel.write(HandlerTask.extendedQuery(context), promise: nil) + + return promise.futureResult } - + + // MARK: Prepared statements + + func prepareStatement(_ query: String, with name: String, logger: Logger) -> EventLoopFuture { + let promise = self.channel.eventLoop.makePromise(of: RowDescription?.self) + let context = ExtendedQueryContext( + name: name, + query: query, + bindingDataTypes: [], + logger: logger, + promise: promise + ) + + self.channel.write(HandlerTask.extendedQuery(context), promise: nil) + return promise.futureResult.map { rowDescription in + PSQLPreparedStatement(name: name, query: query, connection: self, rowDescription: rowDescription) + } + } + + func execute(_ executeStatement: PSQLExecuteStatement, logger: Logger) -> EventLoopFuture { + guard executeStatement.binds.count <= Int(UInt16.max) else { + return self.channel.eventLoop.makeFailedFuture(PSQLError(code: .tooManyParameters)) + } + let promise = self.channel.eventLoop.makePromise(of: PSQLRowStream.self) + let context = ExtendedQueryContext( + executeStatement: executeStatement, + logger: logger, + promise: promise) + + self.channel.write(HandlerTask.extendedQuery(context), promise: nil) + return promise.futureResult + } + + func close(_ target: CloseTarget, logger: Logger) -> EventLoopFuture { + let promise = self.channel.eventLoop.makePromise(of: Void.self) + let context = CloseCommandContext(target: target, logger: logger, promise: promise) + + self.channel.write(HandlerTask.closeCommand(context), promise: nil) + return promise.futureResult + } + + + /// Closes the connection to the server. + /// + /// - Returns: An EventLoopFuture that is succeeded once the connection is closed. public func close() -> EventLoopFuture { guard !self.isClosed else { return self.eventLoop.makeSucceededFuture(()) } - return self.channel.close(mode: .all) + + self.channel.close(mode: .all, promise: nil) + return self.closeFuture } - - deinit { - assert(self.isClosed, "PostgresConnection deinitialized before being closed.") +} + +// MARK: Connect + +extension PostgresConnection { + static let idGenerator = ManagedAtomic(0) + + @available(*, deprecated, + message: "Use the new connect method that allows you to connect and authenticate in a single step", + renamed: "connect(on:configuration:id:logger:)" + ) + public static func connect( + to socketAddress: SocketAddress, + tlsConfiguration: TLSConfiguration? = nil, + serverHostname: String? = nil, + logger: Logger = .init(label: "codes.vapor.postgres"), + on eventLoop: EventLoop + ) -> EventLoopFuture { + var tlsFuture: EventLoopFuture + + if let tlsConfiguration = tlsConfiguration { + tlsFuture = eventLoop.makeSucceededVoidFuture().flatMapBlocking(onto: .global(qos: .default)) { + try .require(.init(configuration: tlsConfiguration)) + } + } else { + tlsFuture = eventLoop.makeSucceededFuture(.disable) + } + + return tlsFuture.flatMap { tls in + var options = PostgresConnection.Configuration.Options() + options.tlsServerName = serverHostname + let configuration = PostgresConnection.InternalConfiguration( + connection: .resolved(address: socketAddress), + username: nil, + password: nil, + database: nil, + tls: tls, + options: options + ) + + return PostgresConnection.connect( + connectionID: self.idGenerator.wrappingIncrementThenLoad(ordering: .relaxed), + configuration: configuration, + logger: logger, + on: eventLoop + ) + }.flatMapErrorThrowing { error in + throw error.asAppropriatePostgresError + } + } + + @available(*, deprecated, + message: "Use the new connect method that allows you to connect and authenticate in a single step", + renamed: "connect(on:configuration:id:logger:)" + ) + public func authenticate( + username: String, + database: String? = nil, + password: String? = nil, + logger: Logger = .init(label: "codes.vapor.postgres") + ) -> EventLoopFuture { + let authContext = AuthContext( + username: username, + password: password, + database: database) + let outgoing = PSQLOutgoingEvent.authenticate(authContext) + self.channel.triggerUserOutboundEvent(outgoing, promise: nil) + + return self.channel.pipeline.handler(type: PSQLEventsHandler.self).flatMap { handler in + handler.authenticateFuture + }.flatMapErrorThrowing { error in + throw error.asAppropriatePostgresError + } + } +} + +// MARK: Async/Await Interface + +extension PostgresConnection { + + /// Creates a new connection to a Postgres server. + /// + /// - Parameters: + /// - eventLoop: The `EventLoop` the connection shall be created on. + /// - configuration: A ``Configuration`` that shall be used for the connection + /// - connectionID: An `Int` id, used for metadata logging + /// - logger: A logger to log background events into + /// - Returns: An established ``PostgresConnection`` asynchronously that can be used to run queries. + public static func connect( + on eventLoop: EventLoop = PostgresConnection.defaultEventLoopGroup.any(), + configuration: PostgresConnection.Configuration, + id connectionID: ID, + logger: Logger + ) async throws -> PostgresConnection { + try await self.connect( + connectionID: connectionID, + configuration: .init(configuration), + logger: logger, + on: eventLoop + ).get() + } + + /// Closes the connection to the server. + public func close() async throws { + try await self.close().get() + } + + /// Closes the connection to the server, _after all queries_ that have been created on this connection have been run. + public func closeGracefully() async throws { + try await withTaskCancellationHandler { () async throws -> () in + let promise = self.eventLoop.makePromise(of: Void.self) + self.channel.triggerUserOutboundEvent(PSQLOutgoingEvent.gracefulShutdown, promise: promise) + return try await promise.futureResult.get() + } onCancel: { + self.close() + } + } + + /// Run a query on the Postgres server the connection is connected to. + /// + /// - Parameters: + /// - query: The ``PostgresQuery`` to run + /// - logger: The `Logger` to log into for the query + /// - file: The file, the query was started in. Used for better error reporting. + /// - line: The line, the query was started in. Used for better error reporting. + /// - Returns: A ``PostgresRowSequence`` containing the rows the server sent as the query result. + /// The sequence be discarded. + @discardableResult + public func query( + _ query: PostgresQuery, + logger: Logger, + file: String = #fileID, + line: Int = #line + ) async throws -> PostgresRowSequence { + var logger = logger + logger[postgresMetadataKey: .connectionID] = "\(self.id)" + + guard query.binds.count <= Int(UInt16.max) else { + throw PSQLError(code: .tooManyParameters, query: query, file: file, line: line) + } + let promise = self.channel.eventLoop.makePromise(of: PSQLRowStream.self) + let context = ExtendedQueryContext( + query: query, + logger: logger, + promise: promise + ) + + self.channel.write(HandlerTask.extendedQuery(context), promise: nil) + + do { + return try await promise.futureResult.map({ $0.asyncSequence() }).get() + } catch var error as PSQLError { + error.file = file + error.line = line + error.query = query + throw error // rethrow with more metadata + } + } + + /// Start listening for a channel + public func listen(_ channel: String) async throws -> PostgresNotificationSequence { + let id = self.internalListenID.loadThenWrappingIncrement(ordering: .relaxed) + + return try await withTaskCancellationHandler { + try Task.checkCancellation() + + return try await withCheckedThrowingContinuation { continuation in + let listener = NotificationListener( + channel: channel, + id: id, + eventLoop: self.eventLoop, + checkedContinuation: continuation + ) + + let task = HandlerTask.startListening(listener) + + self.channel.write(task, promise: nil) + } + } onCancel: { + let task = HandlerTask.cancelListening(channel, id) + self.channel.write(task, promise: nil) + } + } + + /// Execute a prepared statement, taking care of the preparation when necessary + public func execute( + _ preparedStatement: Statement, + logger: Logger, + file: String = #fileID, + line: Int = #line + ) async throws -> AsyncThrowingMapSequence where Row == Statement.Row { + let bindings = try preparedStatement.makeBindings() + let promise = self.channel.eventLoop.makePromise(of: PSQLRowStream.self) + let task = HandlerTask.executePreparedStatement(.init( + name: Statement.name, + sql: Statement.sql, + bindings: bindings, + bindingDataTypes: Statement.bindingDataTypes, + logger: logger, + promise: promise + )) + self.channel.write(task, promise: nil) + do { + return try await promise.futureResult + .map { $0.asyncSequence() } + .get() + .map { try preparedStatement.decodeRow($0) } + } catch var error as PSQLError { + error.file = file + error.line = line + error.query = .init( + unsafeSQL: Statement.sql, + binds: bindings + ) + throw error // rethrow with more metadata + } + } + + /// Execute a prepared statement, taking care of the preparation when necessary + @_disfavoredOverload + public func execute( + _ preparedStatement: Statement, + logger: Logger, + file: String = #fileID, + line: Int = #line + ) async throws -> String where Statement.Row == () { + let bindings = try preparedStatement.makeBindings() + let promise = self.channel.eventLoop.makePromise(of: PSQLRowStream.self) + let task = HandlerTask.executePreparedStatement(.init( + name: Statement.name, + sql: Statement.sql, + bindings: bindings, + bindingDataTypes: Statement.bindingDataTypes, + logger: logger, + promise: promise + )) + self.channel.write(task, promise: nil) + do { + return try await promise.futureResult + .map { $0.commandTag } + .get() + } catch var error as PSQLError { + error.file = file + error.line = line + error.query = .init( + unsafeSQL: Statement.sql, + binds: bindings + ) + throw error // rethrow with more metadata + } + } + + #if compiler(>=6.0) + /// Puts the connection into an open transaction state, for the provided `closure`'s lifetime. + /// + /// The function starts a transaction by running a `BEGIN` query on the connection against the database. It then + /// lends the connection to the user provided closure. The user can then modify the database as they wish. If the user + /// provided closure returns successfully, the function will attempt to commit the changes by running a `COMMIT` + /// query against the database. If the user provided closure throws an error, the function will attempt to rollback the + /// changes made within the closure. + /// + /// - Parameters: + /// - logger: The `Logger` to log into for the transaction. + /// - file: The file, the transaction was started in. Used for better error reporting. + /// - line: The line, the transaction was started in. Used for better error reporting. + /// - closure: The user provided code to modify the database. Use the provided connection to run queries. + /// The connection must stay in the transaction mode. Otherwise this method will throw! + /// - Returns: The closure's return value. + public func withTransaction( + logger: Logger, + file: String = #file, + line: Int = #line, + isolation: isolated (any Actor)? = #isolation, + // DO NOT FIX THE WHITESPACE IN THE NEXT LINE UNTIL 5.10 IS UNSUPPORTED + // https://github.com/swiftlang/swift/issues/79285 + _ process: (PostgresConnection) async throws -> sending Result) async throws -> sending Result { + do { + try await self.query("BEGIN;", logger: logger) + } catch { + throw PostgresTransactionError(file: file, line: line, beginError: error) + } + + var closureHasFinished: Bool = false + do { + let value = try await process(self) + closureHasFinished = true + try await self.query("COMMIT;", logger: logger) + return value + } catch { + var transactionError = PostgresTransactionError(file: file, line: line) + if !closureHasFinished { + transactionError.closureError = error + do { + try await self.query("ROLLBACK;", logger: logger) + } catch { + transactionError.rollbackError = error + } + } else { + transactionError.commitError = error + } + + throw transactionError + } + } + #else + /// Puts the connection into an open transaction state, for the provided `closure`'s lifetime. + /// + /// The function starts a transaction by running a `BEGIN` query on the connection against the database. It then + /// lends the connection to the user provided closure. The user can then modify the database as they wish. If the user + /// provided closure returns successfully, the function will attempt to commit the changes by running a `COMMIT` + /// query against the database. If the user provided closure throws an error, the function will attempt to rollback the + /// changes made within the closure. + /// + /// - Parameters: + /// - logger: The `Logger` to log into for the transaction. + /// - file: The file, the transaction was started in. Used for better error reporting. + /// - line: The line, the transaction was started in. Used for better error reporting. + /// - closure: The user provided code to modify the database. Use the provided connection to run queries. + /// The connection must stay in the transaction mode. Otherwise this method will throw! + /// - Returns: The closure's return value. + public func withTransaction( + logger: Logger, + file: String = #file, + line: Int = #line, + _ process: (PostgresConnection) async throws -> Result + ) async throws -> Result { + do { + try await self.query("BEGIN;", logger: logger) + } catch { + throw PostgresTransactionError(file: file, line: line, beginError: error) + } + + var closureHasFinished: Bool = false + do { + let value = try await process(self) + closureHasFinished = true + try await self.query("COMMIT;", logger: logger) + return value + } catch { + var transactionError = PostgresTransactionError(file: file, line: line) + if !closureHasFinished { + transactionError.closureError = error + do { + try await self.query("ROLLBACK;", logger: logger) + } catch { + transactionError.rollbackError = error + } + } else { + transactionError.commitError = error + } + + throw transactionError + } + } + #endif +} + +// MARK: EventLoopFuture interface + +extension PostgresConnection { + + /// Run a query on the Postgres server the connection is connected to and collect all rows. + /// + /// - Parameters: + /// - query: The ``PostgresQuery`` to run + /// - logger: The `Logger` to log into for the query + /// - file: The file, the query was started in. Used for better error reporting. + /// - line: The line, the query was started in. Used for better error reporting. + /// - Returns: An EventLoopFuture, that allows access to the future ``PostgresQueryResult``. + public func query( + _ query: PostgresQuery, + logger: Logger, + file: String = #fileID, + line: Int = #line + ) -> EventLoopFuture { + self.queryStream(query, logger: logger).flatMap { rowStream in + rowStream.all().flatMapThrowing { rows -> PostgresQueryResult in + guard let metadata = PostgresQueryMetadata(string: rowStream.commandTag) else { + throw PSQLError.invalidCommandTag(rowStream.commandTag) + } + return PostgresQueryResult(metadata: metadata, rows: rows) + } + }.enrichPSQLError(query: query, file: file, line: line) + } + + /// Run a query on the Postgres server the connection is connected to and iterate the rows in a callback. + /// + /// - Note: This API does not support back-pressure. If you need back-pressure please use the query + /// API, that supports structured concurrency. + /// - Parameters: + /// - query: The ``PostgresQuery`` to run + /// - logger: The `Logger` to log into for the query + /// - file: The file, the query was started in. Used for better error reporting. + /// - line: The line, the query was started in. Used for better error reporting. + /// - onRow: A closure that is invoked for every row. + /// - Returns: An EventLoopFuture, that allows access to the future ``PostgresQueryMetadata``. + @preconcurrency + public func query( + _ query: PostgresQuery, + logger: Logger, + file: String = #fileID, + line: Int = #line, + _ onRow: @escaping @Sendable (PostgresRow) throws -> () + ) -> EventLoopFuture { + self.queryStream(query, logger: logger).flatMap { rowStream in + rowStream.onRow(onRow).flatMapThrowing { () -> PostgresQueryMetadata in + guard let metadata = PostgresQueryMetadata(string: rowStream.commandTag) else { + throw PSQLError.invalidCommandTag(rowStream.commandTag) + } + return metadata + } + }.enrichPSQLError(query: query, file: file, line: line) + } +} + +// MARK: PostgresDatabase conformance + +extension PostgresConnection: PostgresDatabase { + public func send( + _ request: PostgresRequest, + logger: Logger + ) -> EventLoopFuture { + guard let command = request as? PostgresCommands else { + preconditionFailure("\(#function) requires an instance of PostgresCommands. This will be a compile-time error in the future.") + } + + let resultFuture: EventLoopFuture + + switch command { + case .query(let query, let onMetadata, let onRow): + resultFuture = self.queryStream(query, logger: logger).flatMap { stream in + return stream.onRow(onRow).map { _ in + onMetadata(PostgresQueryMetadata(string: stream.commandTag)!) + } + } + + case .queryAll(let query, let onResult): + resultFuture = self.queryStream(query, logger: logger).flatMap { rows in + return rows.all().map { allrows in + onResult(.init(metadata: PostgresQueryMetadata(string: rows.commandTag)!, rows: allrows)) + } + } + + case .prepareQuery(let request): + resultFuture = self.prepareStatement(request.query, with: request.name, logger: logger).map { + request.prepared = PreparedQuery(underlying: $0, database: self) + } + + case .executePreparedStatement(let preparedQuery, let binds, let onRow): + var bindings = PostgresBindings(capacity: binds.count) + binds.forEach { bindings.append($0) } + + let statement = PSQLExecuteStatement( + name: preparedQuery.underlying.name, + binds: bindings, + rowDescription: preparedQuery.underlying.rowDescription + ) + + resultFuture = self.execute(statement, logger: logger).flatMap { rows in + return rows.onRow(onRow) + } + } + + return resultFuture.flatMapErrorThrowing { error in + throw error.asAppropriatePostgresError + } + } + + @preconcurrency + public func withConnection(_ closure: (PostgresConnection) -> EventLoopFuture) -> EventLoopFuture { + closure(self) + } +} + +internal enum PostgresCommands: PostgresRequest { + case query(PostgresQuery, + onMetadata: @Sendable (PostgresQueryMetadata) -> () = { _ in }, + onRow: @Sendable (PostgresRow) throws -> ()) + case queryAll(PostgresQuery, onResult: @Sendable (PostgresQueryResult) -> ()) + case prepareQuery(request: PrepareQueryRequest) + case executePreparedStatement(query: PreparedQuery, binds: [PostgresData], onRow: @Sendable (PostgresRow) throws -> ()) + + func respond(to message: PostgresMessage) throws -> [PostgresMessage]? { + fatalError("This function must not be called") + } + + func start() throws -> [PostgresMessage] { + fatalError("This function must not be called") + } + + func log(to logger: Logger) { + fatalError("This function must not be called") + } +} + +// MARK: Notifications + +/// Context for receiving NotificationResponse messages on a connection, used for PostgreSQL's `LISTEN`/`NOTIFY` support. +public final class PostgresListenContext: Sendable { + private let promise: EventLoopPromise + + var future: EventLoopFuture { + self.promise.futureResult + } + + init(promise: EventLoopPromise) { + self.promise = promise + } + + func cancel() { + self.promise.succeed() + } + + /// Detach this listener so it no longer receives notifications. Other listeners, including those for the same channel, are unaffected. `UNLISTEN` is not sent; you are responsible for issuing an `UNLISTEN` query yourself if it is appropriate for your application. + public func stop() { + self.promise.succeed() + } +} + +extension PostgresConnection { + /// Add a handler for NotificationResponse messages on a certain channel. This is used in conjunction with PostgreSQL's `LISTEN`/`NOTIFY` support: to listen on a channel, you add a listener using this method to handle the NotificationResponse messages, then issue a `LISTEN` query to instruct PostgreSQL to begin sending NotificationResponse messages. + @discardableResult + @preconcurrency + public func addListener( + channel: String, + handler notificationHandler: @Sendable @escaping (PostgresListenContext, PostgresMessage.NotificationResponse) -> Void + ) -> PostgresListenContext { + let listenContext = PostgresListenContext(promise: self.eventLoop.makePromise(of: Void.self)) + let id = self.internalListenID.loadThenWrappingIncrement(ordering: .relaxed) + + let listener = NotificationListener( + channel: channel, + id: id, + eventLoop: self.eventLoop, + context: listenContext, + closure: notificationHandler + ) + + let task = HandlerTask.startListening(listener) + self.channel.write(task, promise: nil) + + listenContext.future.whenComplete { _ in + let task = HandlerTask.cancelListening(channel, id) + self.channel.write(task, promise: nil) + } + + return listenContext + } +} + +enum CloseTarget { + case preparedStatement(String) + case portal(String) +} + +extension EventLoopFuture { + func enrichPSQLError(query: PostgresQuery, file: String, line: Int) -> EventLoopFuture { + return self.flatMapErrorThrowing { error in + if var error = error as? PSQLError { + error.file = file + error.line = line + error.query = query + throw error + } else { + throw error + } + } + } +} + +extension PostgresConnection { + /// Returns the default `EventLoopGroup` singleton, automatically selecting the best for the platform. + /// + /// This will select the concrete `EventLoopGroup` depending which platform this is running on. + public static var defaultEventLoopGroup: EventLoopGroup { +#if canImport(Network) + if #available(OSX 10.14, iOS 12.0, tvOS 12.0, watchOS 6.0, *) { + return NIOTSEventLoopGroup.singleton + } else { + return MultiThreadedEventLoopGroup.singleton + } +#else + return MultiThreadedEventLoopGroup.singleton +#endif } } diff --git a/Sources/PostgresNIO/Connection/PostgresDatabase+Close.swift b/Sources/PostgresNIO/Connection/PostgresDatabase+Close.swift deleted file mode 100644 index 881f98c3..00000000 --- a/Sources/PostgresNIO/Connection/PostgresDatabase+Close.swift +++ /dev/null @@ -1,34 +0,0 @@ -import NIO - - -/// PostgreSQL request to close a prepared statement or portal. -final class CloseRequest: PostgresRequest { - - /// Name of the prepared statement or portal to close. - let name: String - - /// Close - let target: PostgresMessage.Close.Target - - init(name: String, closeType: PostgresMessage.Close.Target) { - self.name = name - self.target = closeType - } - - func respond(to message: PostgresMessage) throws -> [PostgresMessage]? { - if message.identifier != .closeComplete { - fatalError("Unexpected PostgreSQL message \(message)") - } - return nil - } - - func start() throws -> [PostgresMessage] { - let close = try PostgresMessage.Close(target: target, name: name).message() - let sync = try PostgresMessage.Sync().message() - return [close, sync] - } - - func log(to logger: Logger) { - logger.debug("Requesting Close of \(name)") - } -} diff --git a/Sources/PostgresNIO/Connection/PostgresDatabase+PreparedQuery.swift b/Sources/PostgresNIO/Connection/PostgresDatabase+PreparedQuery.swift index cd38160b..56496172 100644 --- a/Sources/PostgresNIO/Connection/PostgresDatabase+PreparedQuery.swift +++ b/Sources/PostgresNIO/Connection/PostgresDatabase+PreparedQuery.swift @@ -1,16 +1,22 @@ -import Foundation +import NIOCore +import NIOConcurrencyHelpers +import struct Foundation.UUID extension PostgresDatabase { public func prepare(query: String) -> EventLoopFuture { let name = "nio-postgres-\(UUID().uuidString)" - let prepare = PrepareQueryRequest(query, as: name) - return self.send(prepare, logger: self.logger).map { () -> (PreparedQuery) in - let prepared = PreparedQuery(database: self, name: name, rowDescription: prepare.rowLookupTable) - return prepared + let request = PrepareQueryRequest(query, as: name) + return self.send(PostgresCommands.prepareQuery(request: request), logger: self.logger).map { _ in + // we can force unwrap the prepared here, since in a success case it must be set + // in the send method of `PostgresDatabase`. We do this dirty trick to work around + // the fact that the send method only returns an `EventLoopFuture`. + // Eventually we should move away from the `PostgresDatabase.send` API. + request.prepared! } } - public func prepare(query: String, handler: @escaping (PreparedQuery) -> EventLoopFuture<[[PostgresRow]]>) -> EventLoopFuture<[[PostgresRow]]> { + @preconcurrency + public func prepare(query: String, handler: @Sendable @escaping (PreparedQuery) -> EventLoopFuture<[[PostgresRow]]>) -> EventLoopFuture<[[PostgresRow]]> { prepare(query: query) .flatMap { preparedQuery in handler(preparedQuery) @@ -22,148 +28,52 @@ extension PostgresDatabase { } -public struct PreparedQuery { +public struct PreparedQuery: Sendable { + let underlying: PSQLPreparedStatement let database: PostgresDatabase - let name: String - let rowLookupTable: PostgresRow.LookupTable? - init(database: PostgresDatabase, name: String, rowDescription: PostgresRow.LookupTable?) { + init(underlying: PSQLPreparedStatement, database: PostgresDatabase) { + self.underlying = underlying self.database = database - self.name = name - self.rowLookupTable = rowDescription } public func execute(_ binds: [PostgresData] = []) -> EventLoopFuture<[PostgresRow]> { - var rows: [PostgresRow] = [] - return self.execute(binds) { rows.append($0) }.map { rows } + let rowsBoxed = NIOLockedValueBox([PostgresRow]()) + return self.execute(binds) { row in + rowsBoxed.withLockedValue { + $0.append(row) + } + }.map { rowsBoxed.withLockedValue { $0 } } } - public func execute(_ binds: [PostgresData] = [], _ onRow: @escaping (PostgresRow) throws -> ()) -> EventLoopFuture { - let handler = ExecutePreparedQuery(query: self, binds: binds, onRow: onRow) - return database.send(handler, logger: database.logger) + @preconcurrency + public func execute(_ binds: [PostgresData] = [], _ onRow: @Sendable @escaping (PostgresRow) throws -> ()) -> EventLoopFuture { + let command = PostgresCommands.executePreparedStatement(query: self, binds: binds, onRow: onRow) + return self.database.send(command, logger: self.database.logger) } public func deallocate() -> EventLoopFuture { - database.send(CloseRequest(name: self.name, - closeType: .preparedStatement), - logger:database.logger) - + self.underlying.connection.close(.preparedStatement(self.underlying.name), logger: self.database.logger) } } - -private final class PrepareQueryRequest: PostgresRequest { +final class PrepareQueryRequest: Sendable { let query: String let name: String - var rowLookupTable: PostgresRow.LookupTable? - var resultFormatCodes: [PostgresFormatCode] - var logger: Logger? - - init(_ query: String, as name: String) { - self.query = query - self.name = name - self.resultFormatCodes = [.binary] - } - - func respond(to message: PostgresMessage) throws -> [PostgresMessage]? { - switch message.identifier { - case .rowDescription: - let row = try PostgresMessage.RowDescription(message: message) - self.rowLookupTable = PostgresRow.LookupTable( - rowDescription: row, - resultFormat: self.resultFormatCodes - ) - return [] - case .noData: - return [] - case .parseComplete, .parameterDescription: - return [] - case .readyForQuery: - return nil - default: - fatalError("Unexpected message: \(message)") + var prepared: PreparedQuery? { + get { + self._prepared.withLockedValue { $0 } } - - } - - func start() throws -> [PostgresMessage] { - let parse = PostgresMessage.Parse( - statementName: self.name, - query: self.query, - parameterTypes: [] - ) - let describe = PostgresMessage.Describe( - command: .statement, - name: self.name - ) - return try [parse.message(), describe.message(), PostgresMessage.Sync().message()] - } - - - func log(to logger: Logger) { - self.logger = logger - logger.debug("\(self.query) prepared as \(self.name)") - } -} - - -private final class ExecutePreparedQuery: PostgresRequest { - let query: PreparedQuery - let binds: [PostgresData] - var onRow: (PostgresRow) throws -> () - var resultFormatCodes: [PostgresFormatCode] - var logger: Logger? - - init(query: PreparedQuery, binds: [PostgresData], onRow: @escaping (PostgresRow) throws -> ()) { - self.query = query - self.binds = binds - self.onRow = onRow - self.resultFormatCodes = [.binary] - } - - func respond(to message: PostgresMessage) throws -> [PostgresMessage]? { - switch message.identifier { - case .bindComplete: - return [] - case .dataRow: - let data = try PostgresMessage.DataRow(message: message) - guard let rowLookupTable = query.rowLookupTable else { - fatalError("row lookup was requested but never set") + set { + self._prepared.withLockedValue { + $0 = newValue } - let row = PostgresRow(dataRow: data, lookupTable: rowLookupTable) - try onRow(row) - return [] - case .noData: - return [] - case .commandComplete: - return [] - case .readyForQuery: - return nil - default: throw PostgresError.protocol("Unexpected message during query: \(message)") } } + let _prepared: NIOLockedValueBox = .init(nil) - func start() throws -> [PostgresMessage] { - - let bind = PostgresMessage.Bind( - portalName: "", - statementName: query.name, - parameterFormatCodes: self.binds.map { $0.formatCode }, - parameters: self.binds.map { .init(value: $0.value) }, - resultFormatCodes: self.resultFormatCodes - ) - let execute = PostgresMessage.Execute( - portalName: "", - maxRows: 0 - ) - - let sync = PostgresMessage.Sync() - return try [bind.message(), execute.message(), sync.message()] - } - - func log(to logger: Logger) { - self.logger = logger - logger.debug("Execute Prepared Query: \(query.name)") + init(_ query: String, as name: String) { + self.query = query + self.name = name } - } diff --git a/Sources/PostgresNIO/Data/PostgresData+Array.swift b/Sources/PostgresNIO/Data/PostgresData+Array.swift index 4febed36..5d648db6 100644 --- a/Sources/PostgresNIO/Data/PostgresData+Array.swift +++ b/Sources/PostgresNIO/Data/PostgresData+Array.swift @@ -1,17 +1,19 @@ +import NIOCore + extension PostgresData { - public init(array: [T]) - where T: PostgresDataConvertible - { + @available(*, deprecated, message: "Use ``PostgresQuery`` and ``PostgresBindings`` instead.") + public init(array: [T]) where T: PostgresDataConvertible { self.init( array: array.map { $0.postgresData }, elementType: T.postgresDataType ) } + public init(array: [PostgresData?], elementType: PostgresDataType) { var buffer = ByteBufferAllocator().buffer(capacity: 0) // 0 if empty, 1 if not buffer.writeInteger(array.isEmpty ? 0 : 1, as: UInt32.self) - // b + // b - this gets ignored by psql buffer.writeInteger(0, as: UInt32.self) // array element type buffer.writeInteger(elementType.rawValue) @@ -28,7 +30,7 @@ extension PostgresData { buffer.writeInteger(numericCast(value.readableBytes), as: UInt32.self) buffer.writeBuffer(&value) } else { - buffer.writeInteger(0, as: UInt32.self) + buffer.writeInteger(-1, as: Int32.self) } } } @@ -44,9 +46,8 @@ extension PostgresData { ) } - public func array(of type: T.Type = T.self) -> [T]? - where T: PostgresDataConvertible - { + @available(*, deprecated, message: "Use ``PostgresRow`` and ``PostgresDecodable`` instead.") + public func array(of type: T.Type = T.self) -> [T]? where T: PostgresDataConvertible { guard let array = self.array else { return nil } @@ -75,10 +76,10 @@ extension PostgresData { guard let isNotEmpty = value.readInteger(as: UInt32.self) else { return nil } - guard let b = value.readInteger(as: UInt32.self) else { + // b + guard let _ = value.readInteger(as: UInt32.self) else { return nil } - assert(b == 0, "Array b field did not equal zero") guard let type = value.readInteger(as: PostgresDataType.self) else { return nil } @@ -97,9 +98,9 @@ extension PostgresData { var array: [PostgresData] = [] while - let itemLength = value.readInteger(as: UInt32.self), - let itemValue = value.readSlice(length: numericCast(itemLength)) + let itemLength = value.readInteger(as: Int32.self) { + let itemValue = itemLength == -1 ? nil : value.readSlice(length: numericCast(itemLength)) let data = PostgresData( type: type, typeModifier: nil, @@ -112,6 +113,7 @@ extension PostgresData { } } +@available(*, deprecated, message: "Deprecating conformance to `PostgresDataConvertible`, since it is deprecated.") extension Array: PostgresDataConvertible where Element: PostgresDataConvertible { public static var postgresDataType: PostgresDataType { guard let arrayType = Element.postgresDataType.arrayType else { diff --git a/Sources/PostgresNIO/Data/PostgresData+Bool.swift b/Sources/PostgresNIO/Data/PostgresData+Bool.swift index 79c31dd8..0b9f2738 100644 --- a/Sources/PostgresNIO/Data/PostgresData+Bool.swift +++ b/Sources/PostgresNIO/Data/PostgresData+Bool.swift @@ -1,3 +1,5 @@ +import NIOCore + extension PostgresData { public init(bool: Bool) { var buffer = ByteBufferAllocator().buffer(capacity: 1) @@ -45,6 +47,7 @@ extension PostgresData: ExpressibleByBooleanLiteral { } } +@available(*, deprecated, message: "Deprecating conformance to `PostgresDataConvertible`, since it is deprecated.") extension Bool: PostgresDataConvertible { public static var postgresDataType: PostgresDataType { return .bool diff --git a/Sources/PostgresNIO/Data/PostgresData+Bytes.swift b/Sources/PostgresNIO/Data/PostgresData+Bytes.swift index 8316f61e..5ec507cd 100644 --- a/Sources/PostgresNIO/Data/PostgresData+Bytes.swift +++ b/Sources/PostgresNIO/Data/PostgresData+Bytes.swift @@ -1,4 +1,5 @@ import struct Foundation.Data +import NIOCore extension PostgresData { public init(bytes: Bytes) @@ -20,6 +21,7 @@ extension PostgresData { } } +@available(*, deprecated, message: "Deprecating conformance to `PostgresDataConvertible`, since it is deprecated.") extension Data: PostgresDataConvertible { public static var postgresDataType: PostgresDataType { return .bytea diff --git a/Sources/PostgresNIO/Data/PostgresData+Date.swift b/Sources/PostgresNIO/Data/PostgresData+Date.swift index 0afbc78f..6d730f25 100644 --- a/Sources/PostgresNIO/Data/PostgresData+Date.swift +++ b/Sources/PostgresNIO/Data/PostgresData+Date.swift @@ -1,4 +1,5 @@ -import Foundation +import struct Foundation.Date +import NIOCore extension PostgresData { public init(date: Date) { @@ -35,6 +36,7 @@ extension PostgresData { } } +@available(*, deprecated, message: "Deprecating conformance to `PostgresDataConvertible`, since it is deprecated.") extension Date: PostgresDataConvertible { public static var postgresDataType: PostgresDataType { return .timestamptz diff --git a/Sources/PostgresNIO/Data/PostgresData+Decimal.swift b/Sources/PostgresNIO/Data/PostgresData+Decimal.swift index f98e06af..3af709e5 100644 --- a/Sources/PostgresNIO/Data/PostgresData+Decimal.swift +++ b/Sources/PostgresNIO/Data/PostgresData+Decimal.swift @@ -16,9 +16,10 @@ extension PostgresData { } } +@available(*, deprecated, message: "Deprecating conformance to `PostgresDataConvertible`, since it is deprecated.") extension Decimal: PostgresDataConvertible { public static var postgresDataType: PostgresDataType { - return String.postgresDataType + return .numeric } public init?(postgresData: PostgresData) { @@ -29,6 +30,6 @@ extension Decimal: PostgresDataConvertible { } public var postgresData: PostgresData? { - return .init(decimal: self) + return .init(numeric: PostgresNumeric(decimal: self)) } } diff --git a/Sources/PostgresNIO/Data/PostgresData+Double.swift b/Sources/PostgresNIO/Data/PostgresData+Double.swift index 6012cc03..2d7735ef 100644 --- a/Sources/PostgresNIO/Data/PostgresData+Double.swift +++ b/Sources/PostgresNIO/Data/PostgresData+Double.swift @@ -1,7 +1,9 @@ +import NIOCore + extension PostgresData { public init(double: Double) { var buffer = ByteBufferAllocator().buffer(capacity: 0) - buffer.writeDouble(double) + buffer.psqlWriteDouble(double) self.init(type: .float8, formatCode: .binary, value: buffer) } @@ -14,10 +16,10 @@ extension PostgresData { case .binary: switch self.type { case .float4: - return value.readFloat() + return value.psqlReadFloat() .flatMap { Double($0) } case .float8: - return value.readDouble() + return value.psqlReadDouble() case .numeric: return self.numeric?.double default: @@ -32,6 +34,7 @@ extension PostgresData { } } +@available(*, deprecated, message: "Deprecating conformance to `PostgresDataConvertible`, since it is deprecated.") extension Double: PostgresDataConvertible { public static var postgresDataType: PostgresDataType { return .float8 diff --git a/Sources/PostgresNIO/Data/PostgresData+Float.swift b/Sources/PostgresNIO/Data/PostgresData+Float.swift index e9b7b572..45430934 100644 --- a/Sources/PostgresNIO/Data/PostgresData+Float.swift +++ b/Sources/PostgresNIO/Data/PostgresData+Float.swift @@ -12,9 +12,9 @@ extension PostgresData { case .binary: switch self.type { case .float4: - return value.readFloat() + return value.psqlReadFloat() case .float8: - return value.readDouble() + return value.psqlReadDouble() .flatMap { Float($0) } default: return nil @@ -28,6 +28,7 @@ extension PostgresData { } } +@available(*, deprecated, message: "Deprecating conformance to `PostgresDataConvertible`, since it is deprecated.") extension Float: PostgresDataConvertible { public static var postgresDataType: PostgresDataType { return .float4 diff --git a/Sources/PostgresNIO/Data/PostgresData+Int.swift b/Sources/PostgresNIO/Data/PostgresData+Int.swift index ce77dd43..5a97b3fb 100644 --- a/Sources/PostgresNIO/Data/PostgresData+Int.swift +++ b/Sources/PostgresNIO/Data/PostgresData+Int.swift @@ -1,7 +1,6 @@ extension PostgresData { public init(int value: Int) { - assert(Int.bitWidth == 64) - self.init(type: .int8, value: .init(integer: value)) + self.init(type: .int8, value: .init(integer: Int64(value))) } public init(uint8 value: UInt8) { @@ -32,25 +31,19 @@ extension PostgresData { guard value.readableBytes == 1 else { return nil } - return value.readInteger(as: UInt8.self) - .flatMap(Int.init) + return value.readInteger(as: UInt8.self).flatMap(Int.init) case .int2: assert(value.readableBytes == 2) - return value.readInteger(as: Int16.self) - .flatMap(Int.init) + return value.readInteger(as: Int16.self).flatMap(Int.init) case .int4, .regproc: assert(value.readableBytes == 4) - return value.readInteger(as: Int32.self) - .flatMap(Int.init) + return value.readInteger(as: Int32.self).flatMap(Int.init) case .oid: assert(value.readableBytes == 4) - assert(Int.bitWidth == 64) // or else overflow is possible - return value.readInteger(as: UInt32.self) - .flatMap(Int.init) + return value.readInteger(as: UInt32.self).flatMap { Int(exactly: $0) } case .int8: assert(value.readableBytes == 8) - assert(Int.bitWidth == 64) - return value.readInteger(as: Int.self) + return value.readInteger(as: Int64.self).flatMap { Int(exactly: $0) } default: return nil } @@ -190,6 +183,7 @@ extension PostgresData { } } +@available(*, deprecated, message: "Deprecating conformance to `PostgresDataConvertible`, since it is deprecated.") extension Int: PostgresDataConvertible { public static var postgresDataType: PostgresDataType { .int8 } @@ -205,6 +199,7 @@ extension Int: PostgresDataConvertible { } } +@available(*, deprecated, message: "Deprecating conformance to `PostgresDataConvertible`, since it is deprecated.") extension UInt8: PostgresDataConvertible { public static var postgresDataType: PostgresDataType { .char } @@ -220,6 +215,7 @@ extension UInt8: PostgresDataConvertible { } } +@available(*, deprecated, message: "Deprecating conformance to `PostgresDataConvertible`, since it is deprecated.") extension Int16: PostgresDataConvertible { public static var postgresDataType: PostgresDataType { .int2 } @@ -235,6 +231,7 @@ extension Int16: PostgresDataConvertible { } } +@available(*, deprecated, message: "Deprecating conformance to `PostgresDataConvertible`, since it is deprecated.") extension Int32: PostgresDataConvertible { public static var postgresDataType: PostgresDataType { .int4 } @@ -250,6 +247,7 @@ extension Int32: PostgresDataConvertible { } } +@available(*, deprecated, message: "Deprecating conformance to `PostgresDataConvertible`, since it is deprecated.") extension Int64: PostgresDataConvertible { public static var postgresDataType: PostgresDataType { .int8 } diff --git a/Sources/PostgresNIO/Data/PostgresData+JSON.swift b/Sources/PostgresNIO/Data/PostgresData+JSON.swift index 21a2abb4..53a2d84c 100644 --- a/Sources/PostgresNIO/Data/PostgresData+JSON.swift +++ b/Sources/PostgresNIO/Data/PostgresData+JSON.swift @@ -1,4 +1,5 @@ -import Foundation +import struct Foundation.Data +import NIOCore extension PostgresData { public init(json jsonData: Data) { @@ -11,7 +12,7 @@ extension PostgresData { } public init(json value: T) throws where T: Encodable { - let jsonData = try JSONEncoder().encode(value) + let jsonData = try PostgresNIO._defaultJSONEncoder.encode(value) self.init(json: jsonData) } @@ -32,12 +33,14 @@ extension PostgresData { guard let data = self.json else { return nil } - return try JSONDecoder().decode(T.self, from: data) + return try PostgresNIO._defaultJSONDecoder.decode(T.self, from: data) } } +@available(*, deprecated, message: "This protocol is going to be replaced with ``PostgresEncodable`` and ``PostgresDecodable`` and conforming to ``Codable`` at the same time") public protocol PostgresJSONCodable: Codable, PostgresDataConvertible { } +@available(*, deprecated, message: "Deprecating conformance to `PostgresDataConvertible`, since it is deprecated.") extension PostgresJSONCodable { public static var postgresDataType: PostgresDataType { return .json diff --git a/Sources/PostgresNIO/Data/PostgresData+JSONB.swift b/Sources/PostgresNIO/Data/PostgresData+JSONB.swift index 8fade8a5..0d5befa3 100644 --- a/Sources/PostgresNIO/Data/PostgresData+JSONB.swift +++ b/Sources/PostgresNIO/Data/PostgresData+JSONB.swift @@ -1,4 +1,5 @@ -import Foundation +import NIOCore +import struct Foundation.Data fileprivate let jsonBVersionBytes: [UInt8] = [0x01] @@ -15,7 +16,7 @@ extension PostgresData { } public init(jsonb value: T) throws where T: Encodable { - let jsonData = try JSONEncoder().encode(value) + let jsonData = try PostgresNIO._defaultJSONEncoder.encode(value) self.init(jsonb: jsonData) } @@ -43,12 +44,14 @@ extension PostgresData { return nil } - return try JSONDecoder().decode(T.self, from: data) + return try PostgresNIO._defaultJSONDecoder.decode(T.self, from: data) } } +@available(*, deprecated, message: "This protocol is going to be replaced with ``PostgresEncodable`` and ``PostgresDecodable`` and conforming to ``Codable`` at the same time") public protocol PostgresJSONBCodable: Codable, PostgresDataConvertible { } +@available(*, deprecated, message: "Deprecating conformance to `PostgresDataConvertible`, since it is deprecated.") extension PostgresJSONBCodable { public static var postgresDataType: PostgresDataType { return .jsonb diff --git a/Sources/PostgresNIO/Data/PostgresData+Numeric.swift b/Sources/PostgresNIO/Data/PostgresData+Numeric.swift index 96bd6a77..e736a61c 100644 --- a/Sources/PostgresNIO/Data/PostgresData+Numeric.swift +++ b/Sources/PostgresNIO/Data/PostgresData+Numeric.swift @@ -1,3 +1,4 @@ +import NIOCore import struct Foundation.Decimal public struct PostgresNumeric: CustomStringConvertible, CustomDebugStringConvertible, ExpressibleByStringLiteral { @@ -267,16 +268,10 @@ private extension Collection { // splits the collection into chunks of the supplied size // if the collection is not evenly divisible, the first chunk will be smaller func reverseChunked(by maxSize: Int) -> [SubSequence] { - var lastDistance = 0 var chunkStartIndex = self.startIndex return stride(from: 0, to: self.count, by: maxSize).reversed().map { current in - let distance = (self.count - current) - lastDistance - lastDistance = distance - let chunkEndOffset = Swift.min( - self.distance(from: chunkStartIndex, to: self.endIndex), - distance - ) - let chunkEndIndex = self.index(chunkStartIndex, offsetBy: chunkEndOffset) + let distance = self.count - current + let chunkEndIndex = self.index(self.startIndex, offsetBy: distance) defer { chunkStartIndex = chunkEndIndex } return self[chunkStartIndex.." @@ -93,12 +96,16 @@ public struct PostgresData: CustomStringConvertible, CustomDebugStringConvertibl return "\(raw) (\(self.type))" } } +} +@available(*, deprecated, message: "Deprecating conformance to `CustomDebugStringConvertible` as a first step of deprecating `PostgresData`. Please use `PostgresBindings` or `PostgresCell` instead.") +extension PostgresData: CustomDebugStringConvertible { public var debugDescription: String { return self.description } } +@available(*, deprecated, message: "Deprecating conformance to `PostgresDataConvertible`, since it is deprecated.") extension PostgresData: PostgresDataConvertible { public static var postgresDataType: PostgresDataType { fatalError("PostgresData cannot be statically represented as a single data type") diff --git a/Sources/PostgresNIO/Data/PostgresDataConvertible.swift b/Sources/PostgresNIO/Data/PostgresDataConvertible.swift index 32e7fc41..675ed6fe 100644 --- a/Sources/PostgresNIO/Data/PostgresDataConvertible.swift +++ b/Sources/PostgresNIO/Data/PostgresDataConvertible.swift @@ -1,5 +1,6 @@ import Foundation +@available(*, deprecated, message: "This protocol is going to be replaced with ``PostgresEncodable`` and ``PostgresDecodable``") public protocol PostgresDataConvertible { static var postgresDataType: PostgresDataType { get } init?(postgresData: PostgresData) diff --git a/Sources/PostgresNIO/Data/PostgresDataType.swift b/Sources/PostgresNIO/Data/PostgresDataType.swift index c9c96eb7..c3e4e747 100644 --- a/Sources/PostgresNIO/Data/PostgresDataType.swift +++ b/Sources/PostgresNIO/Data/PostgresDataType.swift @@ -1,11 +1,14 @@ -/// The format code being used for the field. -/// Currently will be zero (text) or one (binary). -/// In a RowDescription returned from the statement variant of Describe, -/// the format code is not yet known and will always be zero. -public enum PostgresFormatCode: Int16, Codable, CustomStringConvertible { +/// The format the postgres types are encoded in on the wire. +/// +/// Currently there a two wire formats supported: +/// - text +/// - binary +public enum PostgresFormat: Int16, Sendable { case text = 0 case binary = 1 - +} + +extension PostgresFormat: CustomStringConvertible { public var description: String { switch self { case .text: return "text" @@ -14,9 +17,20 @@ public enum PostgresFormatCode: Int16, Codable, CustomStringConvertible { } } -/// The data type's raw object ID. -/// Use `select * from pg_type where oid = ;` to lookup more information. -public struct PostgresDataType: Codable, Equatable, ExpressibleByIntegerLiteral, CustomStringConvertible, RawRepresentable { +// TODO: The Codable conformance does not make any sense. Let's remove this with next major break. +extension PostgresFormat: Codable {} + +// TODO: Renamed during 1.x. Remove this with next major break. +@available(*, deprecated, renamed: "PostgresFormat") +public typealias PostgresFormatCode = PostgresFormat + +/// Data types and their raw OIDs. +/// +/// Use `select * from pg_type where oid = ` to look up more information for a given type. +/// +/// This list was generated by running `select oid, typname from pg_type where oid < 10000 order by oid` +/// and manually trimming Postgres-internal types. +public struct PostgresDataType: RawRepresentable, Sendable, Hashable, CustomStringConvertible { /// `0` public static let null = PostgresDataType(0) /// `16` @@ -31,6 +45,8 @@ public struct PostgresDataType: Codable, Equatable, ExpressibleByIntegerLiteral, public static let int8 = PostgresDataType(20) /// `21` public static let int2 = PostgresDataType(21) + /// `22` + public static let int2vector = PostgresDataType(22) /// `23` public static let int4 = PostgresDataType(23) /// `24` @@ -39,18 +55,77 @@ public struct PostgresDataType: Codable, Equatable, ExpressibleByIntegerLiteral, public static let text = PostgresDataType(25) /// `26` public static let oid = PostgresDataType(26) + /// `27` + public static let tid = PostgresDataType(27) + /// `28` + public static let xid = PostgresDataType(28) + /// `29` + public static let cid = PostgresDataType(29) + /// `30` + public static let oidvector = PostgresDataType(30) + /// `32` + public static let pgDDLCommand = PostgresDataType(32) /// `114` public static let json = PostgresDataType(114) + /// `142` + public static let xml = PostgresDataType(142) + /// `143` + public static let xmlArray = PostgresDataType(143) /// `194` pg_node_tree + @available(*, deprecated, message: "This is internal to Postgres and should not be used.") public static let pgNodeTree = PostgresDataType(194) + /// `199` + public static let jsonArray = PostgresDataType(199) + /// `269` + public static let tableAMHandler = PostgresDataType(269) + /// `271` + public static let xid8Array = PostgresDataType(271) + /// `325` + public static let indexAMHandler = PostgresDataType(325) /// `600` public static let point = PostgresDataType(600) + /// `601` + public static let lseg = PostgresDataType(601) + /// `602` + public static let path = PostgresDataType(602) + /// `603` + public static let box = PostgresDataType(603) + /// `604` + public static let polygon = PostgresDataType(604) + /// `628` + public static let line = PostgresDataType(628) + /// `629` + public static let lineArray = PostgresDataType(629) + /// `650` + public static let cidr = PostgresDataType(650) + /// `651` + public static let cidrArray = PostgresDataType(651) /// `700` public static let float4 = PostgresDataType(700) /// `701` public static let float8 = PostgresDataType(701) + /// `705` + public static let unknown = PostgresDataType(705) + /// `718` + public static let circle = PostgresDataType(718) + /// `719` + public static let circleArray = PostgresDataType(719) + /// `774` + public static let macaddr8 = PostgresDataType(774) + /// `775` + @available(*, deprecated, renamed: "macaddr8Array") + public static let macaddr8Aray = Self.macaddr8Array + public static let macaddr8Array = PostgresDataType(775) /// `790` public static let money = PostgresDataType(790) + /// `791` + @available(*, deprecated, renamed: "moneyArray") + public static let _money = Self.moneyArray + public static let moneyArray = PostgresDataType(791) + /// `829` + public static let macaddr = PostgresDataType(829) + /// `869` + public static let inet = PostgresDataType(869) /// `1000` _bool public static let boolArray = PostgresDataType(1000) /// `1001` _bytea @@ -61,22 +136,52 @@ public struct PostgresDataType: Codable, Equatable, ExpressibleByIntegerLiteral, public static let nameArray = PostgresDataType(1003) /// `1005` _int2 public static let int2Array = PostgresDataType(1005) + /// `1006` + public static let int2vectorArray = PostgresDataType(1006) /// `1007` _int4 public static let int4Array = PostgresDataType(1007) + /// `1008` + public static let regprocArray = PostgresDataType(1008) /// `1009` _text public static let textArray = PostgresDataType(1009) + /// `1010` + public static let tidArray = PostgresDataType(1010) + /// `1011` + public static let xidArray = PostgresDataType(1011) + /// `1012` + public static let cidArray = PostgresDataType(1012) + /// `1013` + public static let oidvectorArray = PostgresDataType(1013) + /// `1014` + public static let bpcharArray = PostgresDataType(1014) /// `1015` _varchar public static let varcharArray = PostgresDataType(1015) /// `1016` _int8 public static let int8Array = PostgresDataType(1016) /// `1017` _point public static let pointArray = PostgresDataType(1017) + /// `1018` + public static let lsegArray = PostgresDataType(1018) + /// `1019` + public static let pathArray = PostgresDataType(1019) + /// `1020` + public static let boxArray = PostgresDataType(1020) /// `1021` _float4 public static let float4Array = PostgresDataType(1021) /// `1022` _float8 public static let float8Array = PostgresDataType(1022) + /// `1027` + public static let polygonArray = PostgresDataType(1027) + /// `1028` + public static let oidArray = PostgresDataType(1018) + /// `1033` + public static let aclitem = PostgresDataType(1033) /// `1034` _aclitem public static let aclitemArray = PostgresDataType(1034) + /// `1040` + public static let macaddrArray = PostgresDataType(1040) + /// `1041` + public static let inetArray = PostgresDataType(1041) /// `1042` public static let bpchar = PostgresDataType(1042) /// `1043` @@ -89,22 +194,202 @@ public struct PostgresDataType: Codable, Equatable, ExpressibleByIntegerLiteral, public static let timestamp = PostgresDataType(1114) /// `1115` _timestamp public static let timestampArray = PostgresDataType(1115) + /// `1182` + public static let dateArray = PostgresDataType(1182) + /// `1183` + public static let timeArray = PostgresDataType(1183) /// `1184` public static let timestamptz = PostgresDataType(1184) + /// `1185` + public static let timestamptzArray = PostgresDataType(1185) + /// `1186` + public static let interval = PostgresDataType(1186) + /// `1187` + public static let intervalArray = PostgresDataType(1187) + /// `1231` + public static let numericArray = PostgresDataType(1231) + /// `1263` + public static let cstringArray = PostgresDataType(1263) /// `1266` public static let timetz = PostgresDataType(1266) + /// `1270` + public static let timetzArray = PostgresDataType(1270) + /// `1560` + public static let bit = PostgresDataType(1560) + /// `1561` + public static let bitArray = PostgresDataType(1561) + /// `1562` + public static let varbit = PostgresDataType(1562) + /// `1563` + public static let varbitArray = PostgresDataType(1563) /// `1700` public static let numeric = PostgresDataType(1700) + /// `1790` + public static let refcursor = PostgresDataType(1790) + /// `2201` + public static let refcursorArray = PostgresDataType(2201) + /// `2202` + public static let regprocedure = PostgresDataType(2202) + /// `2203` + public static let regoper = PostgresDataType(2203) + /// `2204` + public static let regoperator = PostgresDataType(2204) + /// `2205` + public static let regclass = PostgresDataType(2205) + /// `2206` + public static let regtype = PostgresDataType(2206) + /// `2207` + public static let regprocedureArray = PostgresDataType(2207) + /// `2208` + public static let regoperArray = PostgresDataType(2208) + /// `2209` + public static let regoperatorArray = PostgresDataType(2209) + /// `2210` + public static let regclassArray = PostgresDataType(2210) + /// `2211` + public static let regtypeArray = PostgresDataType(2211) + /// `2249` + public static let record = PostgresDataType(2249) + /// `2275` + public static let cstring = PostgresDataType(2275) + /// `2276` + public static let any = PostgresDataType(2276) + /// `2277` + public static let anyarray = PostgresDataType(2277) /// `2278` public static let void = PostgresDataType(2278) + /// `2279` + public static let trigger = PostgresDataType(2279) + /// `2280` + public static let languageHandler = PostgresDataType(2280) + /// `2281` + public static let `internal` = PostgresDataType(2281) + /// `2283` + public static let anyelement = PostgresDataType(2283) + /// `2287` + public static let recordArray = PostgresDataType(2287) + /// `2776` + public static let anynonarray = PostgresDataType(2776) /// `2950` public static let uuid = PostgresDataType(2950) /// `2951` _uuid public static let uuidArray = PostgresDataType(2951) + /// `3115` + public static let fdwHandler = PostgresDataType(3115) + /// `3220` + public static let pgLSN = PostgresDataType(3220) + /// `3221` + public static let pgLSNArray = PostgresDataType(3221) + /// `3310` + public static let tsmHandler = PostgresDataType(3310) + /// `3500` + public static let anyenum = PostgresDataType(3500) + /// `3614` + public static let tsvector = PostgresDataType(3614) + /// `3615` + public static let tsquery = PostgresDataType(3615) + /// `3642` + public static let gtsvector = PostgresDataType(3642) + /// `3643` + public static let tsvectorArray = PostgresDataType(3643) + /// `3644` + public static let gtsvectorArray = PostgresDataType(3644) + /// `3645` + public static let tsqueryArray = PostgresDataType(3645) + /// `3734` + public static let regconfig = PostgresDataType(3734) + /// `3735` + public static let regconfigArray = PostgresDataType(3735) + /// `3769` + public static let regdictionary = PostgresDataType(3769) + /// `3770` + public static let regdictionaryArray = PostgresDataType(3770) /// `3802` public static let jsonb = PostgresDataType(3802) /// `3807` _jsonb public static let jsonbArray = PostgresDataType(3807) + /// `3831` + public static let anyrange = PostgresDataType(3831) + /// `3838` + public static let eventTrigger = PostgresDataType(3838) + /// `3904` + public static let int4Range = PostgresDataType(3904) + /// `3905` _int4range + public static let int4RangeArray = PostgresDataType(3905) + /// `3906` + public static let numrange = PostgresDataType(3906) + /// `3907` + public static let numrangeArray = PostgresDataType(3907) + /// `3908` + public static let tsrange = PostgresDataType(3908) + /// `3909` + public static let tsrangeArray = PostgresDataType(3909) + /// `3910` + public static let tstzrange = PostgresDataType(3910) + /// `3911` + public static let tstzrangeArray = PostgresDataType(3911) + /// `3912` + public static let daterange = PostgresDataType(3912) + /// `3913` + public static let daterangeArray = PostgresDataType(3913) + /// `3926` + public static let int8Range = PostgresDataType(3926) + /// `3927` _int8range + public static let int8RangeArray = PostgresDataType(3927) + /// `4072` + public static let jsonpath = PostgresDataType(4072) + /// `4073` + public static let jsonpathArray = PostgresDataType(4073) + /// `4089` + public static let regnamespace = PostgresDataType(4089) + /// `4090` + public static let regnamespaceArray = PostgresDataType(4090) + /// `4096` + public static let regrole = PostgresDataType(4096) + /// `4097` + public static let regroleArray = PostgresDataType(4097) + /// `4191` + public static let regcollation = PostgresDataType(4191) + /// `4192` + public static let regcollationArray = PostgresDataType(4192) + /// `4451` + public static let int4multirange = PostgresDataType(4451) + /// `4532` + public static let nummultirange = PostgresDataType(4532) + /// `4533` + public static let tsmultirange = PostgresDataType(4533) + /// `4534` + public static let tstzmultirange = PostgresDataType(4534) + /// `4535` + public static let datemultirange = PostgresDataType(4535) + /// `4536` + public static let int8multirange = PostgresDataType(4536) + /// `4537` + public static let anymultirange = PostgresDataType(4537) + /// `4538` + public static let anycompatiblemultirange = PostgresDataType(4538) + /// `5069` + public static let xid8 = PostgresDataType(5069) + /// `5077` + public static let anycompatible = PostgresDataType(5077) + /// `5078` + public static let anycompatiblearray = PostgresDataType(5078) + /// `5079` + public static let anycompatiblenonarray = PostgresDataType(5079) + /// `5080` + public static let anycompatiblerange = PostgresDataType(5080) + /// `6150` + public static let int4multirangeArray = PostgresDataType(6150) + /// `6151` + public static let nummultirangeArray = PostgresDataType(6151) + /// `6152` + public static let tsmultirangeArray = PostgresDataType(6152) + /// `6153` + public static let tstzmultirangeArray = PostgresDataType(6153) + /// `6155` + public static let datemultirangeArray = PostgresDataType(6155) + /// `6157` + public static let int8multirangeArray = PostgresDataType(6157) /// The raw data type code recognized by PostgreSQL. public var rawValue: UInt32 @@ -115,12 +400,7 @@ public struct PostgresDataType: Codable, Equatable, ExpressibleByIntegerLiteral, public var isUserDefined: Bool { self.rawValue >= 1 << 14 } - - /// See `ExpressibleByIntegerLiteral.init(integerLiteral:)` - public init(integerLiteral value: UInt32) { - self.init(value) - } - + public init(_ rawValue: UInt32) { self.rawValue = rawValue } @@ -128,60 +408,255 @@ public struct PostgresDataType: Codable, Equatable, ExpressibleByIntegerLiteral, public init?(rawValue: UInt32) { self.init(rawValue) } - + /// Returns the known SQL name, if one exists. /// Note: This only supports a limited subset of all PSQL types and is meant for convenience only. + /// This list was manually generated. public var knownSQLName: String? { switch self { + case .null: return "NULL" case .bool: return "BOOLEAN" case .bytea: return "BYTEA" case .char: return "CHAR" case .name: return "NAME" case .int8: return "BIGINT" case .int2: return "SMALLINT" + case .int2vector: return "INT2VECTOR" case .int4: return "INTEGER" case .regproc: return "REGPROC" case .text: return "TEXT" case .oid: return "OID" + case .tid: return "TID" + case .xid: return "XID" + case .cid: return "CID" + case .oidvector: return "OIDVECTOR" + case .pgDDLCommand: return "PG_DDL_COMMAND" case .json: return "JSON" - case .pgNodeTree: return "PGNODETREE" + case .xml: return "XML" + case .xmlArray: return "XML[]" + case .jsonArray: return "JSON[]" + case .tableAMHandler: return "TABLE_AM_HANDLER" + case .xid8Array: return "XID8[]" + case .indexAMHandler: return "INDEX_AM_HANDLER" case .point: return "POINT" + case .lseg: return "LSEG" + case .path: return "PATH" + case .box: return "BOX" + case .polygon: return "POLYGON" + case .line: return "LINE" + case .lineArray: return "LINE[]" + case .cidr: return "CIDR" + case .cidrArray: return "CIDR[]" case .float4: return "REAL" case .float8: return "DOUBLE PRECISION" + case .circle: return "CIRCLE" + case .circleArray: return "CIRCLE[]" + case .macaddr8: return "MACADDR8" + case .macaddr8Array: return "MACADDR8[]" case .money: return "MONEY" + case .moneyArray: return "MONEY[]" + case .macaddr: return "MACADDR" + case .inet: return "INET" case .boolArray: return "BOOLEAN[]" case .byteaArray: return "BYTEA[]" case .charArray: return "CHAR[]" case .nameArray: return "NAME[]" case .int2Array: return "SMALLINT[]" + case .int2vectorArray: return "INT2VECTOR[]" case .int4Array: return "INTEGER[]" + case .regprocArray: return "REGPROC[]" case .textArray: return "TEXT[]" + case .tidArray: return "TID[]" + case .xidArray: return "XID[]" + case .cidArray: return "CID[]" + case .oidvectorArray: return "OIDVECTOR[]" + case .bpcharArray: return "CHARACTER[]" case .varcharArray: return "VARCHAR[]" case .int8Array: return "BIGINT[]" case .pointArray: return "POINT[]" + case .lsegArray: return "LSEG[]" + case .pathArray: return "PATH[]" + case .boxArray: return "BOX[]" case .float4Array: return "REAL[]" case .float8Array: return "DOUBLE PRECISION[]" + case .polygonArray: return "POLYGON[]" + case .oidArray: return "OID[]" + case .aclitem: return "ACLITEM" case .aclitemArray: return "ACLITEM[]" - case .bpchar: return "BPCHAR" + case .macaddrArray: return "MACADDR[]" + case .inetArray: return "INET[]" + case .bpchar: return "CHARACTER" case .varchar: return "VARCHAR" case .date: return "DATE" case .time: return "TIME" case .timestamp: return "TIMESTAMP" - case .timestamptz: return "TIMESTAMPTZ" case .timestampArray: return "TIMESTAMP[]" + case .dateArray: return "DATE[]" + case .timeArray: return "TIME[]" + case .timestamptz: return "TIMESTAMPTZ" + case .timestamptzArray: return "TIMESTAMPTZ[]" + case .interval: return "INTERVAL" + case .intervalArray: return "INTERVAL[]" + case .numericArray: return "NUMERIC[]" + case .cstringArray: return "CSTRING[]" + case .timetz: return "TIMETZ" + case .timetzArray: return "TIMETZ[]" + case .bit: return "BIT" + case .bitArray: return "BIT[]" + case .varbit: return "VARBIT" + case .varbitArray: return "VARBIT[]" case .numeric: return "NUMERIC" + case .refcursor: return "REFCURSOR" + case .refcursorArray: return "REFCURSOR[]" + case .regprocedure: return "REGPROCEDURE" + case .regoper: return "REGOPER" + case .regoperator: return "REGOPERATOR" + case .regclass: return "REGCLASS" + case .regtype: return "REGTYPE" + case .regprocedureArray: return "REGPROCEDURE[]" + case .regoperArray: return "REGOPER[]" + case .regoperatorArray: return "REGOPERATOR[]" + case .regclassArray: return "REGCLASS[]" + case .regtypeArray: return "REGTYPE[]" + case .record: return "RECORD" + case .cstring: return "CSTRING" + case .any: return "ANY" + case .anyarray: return "ANYARRAY" case .void: return "VOID" + case .trigger: return "TRIGGER" + case .languageHandler: return "LANGUAGE_HANDLER" + case .`internal`: return "INTERNAL" + case .anyelement: return "ANYELEMENT" + case .recordArray: return "RECORD[]" + case .anynonarray: return "ANYNONARRAY" case .uuid: return "UUID" case .uuidArray: return "UUID[]" + case .fdwHandler: return "FDW_HANDLER" + case .pgLSN: return "PG_LSN" + case .pgLSNArray: return "PG_LSN[]" + case .tsmHandler: return "TSM_HANDLER" + case .anyenum: return "ANYENUM" + case .tsvector: return "TSVECTOR" + case .tsquery: return "TSQUERY" + case .gtsvector: return "GTSVECTOR" + case .tsvectorArray: return "TSVECTOR[]" + case .gtsvectorArray: return "GTSVECTOR[]" + case .tsqueryArray: return "TSQUERY[]" + case .regconfig: return "REGCONFIG" + case .regconfigArray: return "REGCONFIG[]" + case .regdictionary: return "REGDICTIONARY" + case .regdictionaryArray: return "REGDICTIONARY[]" case .jsonb: return "JSONB" case .jsonbArray: return "JSONB[]" + case .anyrange: return "ANYRANGE" + case .eventTrigger: return "EVENT_TRIGGER" + case .int4Range: return "INT4RANGE" + case .int4RangeArray: return "INT4RANGE[]" + case .numrange: return "NUMRANGE" + case .numrangeArray: return "NUMRANGE[]" + case .tsrange: return "TSRANGE" + case .tsrangeArray: return "TSRANGE[]" + case .tstzrange: return "TSTZRANGE" + case .tstzrangeArray: return "TSTZRANGE[]" + case .daterange: return "DATERANGE" + case .daterangeArray: return "DATERANGE[]" + case .int8Range: return "INT8RANGE" + case .int8RangeArray: return "INT8RANGE[]" + case .jsonpath: return "JSONPATH" + case .jsonpathArray: return "JSONPATH[]" + case .regnamespace: return "REGNAMESPACE" + case .regnamespaceArray: return "REGNAMESPACE[]" + case .regrole: return "REGROLE" + case .regroleArray: return "REGROLE[]" + case .regcollation: return "REGCOLLATION" + case .regcollationArray: return "REGCOLLATION[]" + case .int4multirange: return "INT4MULTIRANGE" + case .nummultirange: return "NUMMULTIRANGE" + case .tsmultirange: return "TSMULTIRANGE" + case .tstzmultirange: return "TSTZMULTIRANGE" + case .datemultirange: return "DATEMULTIRANGE" + case .int8multirange: return "INT8MULTIRANGE" + case .anymultirange: return "ANYMULTIRANGE" + case .anycompatiblemultirange: return "ANYCOMPATIBLEMULTIRANGE" + case .xid8: return "XID8" + case .anycompatible: return "ANYCOMPATIBLE" + case .anycompatiblearray: return "ANYCOMPATIBLEARRAY" + case .anycompatiblenonarray: return "ANYCOMPATIBLENONARRAY" + case .anycompatiblerange: return "ANYCOMPATIBLERANG" + case .int4multirangeArray: return "INT4MULTIRANGE[]" + case .nummultirangeArray: return "NUMMULTIRANGE[]" + case .tsmultirangeArray: return "TSMULTIRANGE[]" + case .tstzmultirangeArray: return "TSTZMULTIRANGE[]" + case .datemultirangeArray: return "DATEMULTIRANGE[]" + case .int8multirangeArray: return "INT8MULTIRANGE[]" default: return nil } } /// Returns the array type for this type if one is known. + /// + /// This list was manually generated. internal var arrayType: PostgresDataType? { switch self { + case .xml: return .xmlArray + case .json: return .jsonArray + case .xid8: return .xid8Array + case .line: return .lineArray + case .cidr: return .cidrArray + case .circle: return .circleArray + case .macaddr8: return .macaddr8Array + case .money: return .moneyArray + case .int2vector: return .int2vectorArray + case .regproc: return .regprocArray + case .tid: return .tidArray + case .xid: return .xidArray + case .cid: return .cidArray + case .oidvector: return .oidvectorArray + case .bpchar: return .bpcharArray + case .lseg: return .lsegArray + case .path: return .pathArray + case .box: return .boxArray + case .polygon: return .polygonArray + case .oid: return .oidArray + case .aclitem: return .aclitemArray + case .macaddr: return .macaddrArray + case .inet: return .inetArray + case .timestamp: return .timestampArray + case .date: return .dateArray + case .time: return .timeArray + case .timestamptz: return .timestamptzArray + case .interval: return .intervalArray + case .numeric: return .numericArray + case .cstring: return .cstringArray + case .timetz: return .timetzArray + case .bit: return .bitArray + case .varbit: return .varbitArray + case .refcursor: return .refcursorArray + case .regprocedure: return .regprocedureArray + case .regoper: return .regoperArray + case .regoperator: return .regoperatorArray + case .regclass: return .regclassArray + case .regtype: return .regtypeArray + case .record: return .recordArray + case .pgLSN: return .pgLSNArray + case .tsvector: return .tsvectorArray + case .gtsvector: return .gtsvectorArray + case .tsquery: return .tsqueryArray + case .regconfig: return .regconfigArray + case .regdictionary: return .regdictionaryArray + case .numrange: return .numrangeArray + case .tsrange: return .tsrangeArray + case .tstzrange: return .tstzrangeArray + case .daterange: return .daterangeArray + case .jsonpath: return .jsonpathArray + case .regnamespace: return .regnamespaceArray + case .regrole: return .regroleArray + case .regcollation: return .regcollationArray + case .int4multirange: return .int4multirangeArray + case .tsmultirange: return .tsmultirangeArray + case .tstzmultirange: return .tstzmultirangeArray + case .datemultirange: return .datemultirangeArray + case .int8multirange: return .int8multirangeArray case .bool: return .boolArray case .bytea: return .byteaArray case .char: return .charArray @@ -196,14 +671,77 @@ public struct PostgresDataType: Codable, Equatable, ExpressibleByIntegerLiteral, case .jsonb: return .jsonbArray case .text: return .textArray case .varchar: return .varcharArray + case .int4Range: return .int4RangeArray + case .int8Range: return .int8RangeArray default: return nil } } /// Returns the element type for this type if one is known. /// Returns nil if this is not an array type. + /// + /// This list was manually generated. internal var elementType: PostgresDataType? { switch self { + case .xmlArray: return .xml + case .jsonArray: return .json + case .xid8Array: return .xid8 + case .lineArray: return .line + case .cidrArray: return .cidr + case .circleArray: return .circle + case .macaddr8Array: return .macaddr8 + case .moneyArray: return .money + case .int2vectorArray: return .int2vector + case .regprocArray: return .regproc + case .tidArray: return .tid + case .xidArray: return .xid + case .cidArray: return .cid + case .oidvectorArray: return .oidvector + case .bpcharArray: return .bpchar + case .lsegArray: return .lseg + case .pathArray: return .path + case .boxArray: return .box + case .polygonArray: return .polygon + case .oidArray: return .oid + case .aclitemArray: return .aclitem + case .macaddrArray: return .macaddr + case .inetArray: return .inet + case .timestampArray: return .timestamp + case .dateArray: return .date + case .timeArray: return .time + case .timestamptzArray: return .timestamptz + case .intervalArray: return .interval + case .numericArray: return .numeric + case .cstringArray: return .cstring + case .timetzArray: return .timetz + case .bitArray: return .bit + case .varbitArray: return .varbit + case .refcursorArray: return .refcursor + case .regprocedureArray: return .regprocedure + case .regoperArray: return .regoper + case .regoperatorArray: return .regoperator + case .regclassArray: return .regclass + case .regtypeArray: return .regtype + case .recordArray: return .record + case .pgLSNArray: return .pgLSN + case .tsvectorArray: return .tsvector + case .gtsvectorArray: return .gtsvector + case .tsqueryArray: return .tsquery + case .regconfigArray: return .regconfig + case .regdictionaryArray: return .regdictionary + case .numrangeArray: return .numrange + case .tsrangeArray: return .tsrange + case .tstzrangeArray: return .tstzrange + case .daterangeArray: return .daterange + case .jsonpathArray: return .jsonpath + case .regnamespaceArray: return .regnamespace + case .regroleArray: return .regrole + case .regcollationArray: return .regcollation + case .int4multirangeArray: return .int4multirange + case .tsmultirangeArray: return .tsmultirange + case .tstzmultirangeArray: return .tstzmultirange + case .datemultirangeArray: return .datemultirange + case .int8multirangeArray: return .int8multirange case .boolArray: return .bool case .byteaArray: return .bytea case .charArray: return .char @@ -218,12 +756,41 @@ public struct PostgresDataType: Codable, Equatable, ExpressibleByIntegerLiteral, case .jsonbArray: return .jsonb case .textArray: return .text case .varcharArray: return .varchar + case .int4RangeArray: return .int4Range + case .int8RangeArray: return .int8Range + default: return nil + } + } + + /// Returns the bound type for this type if one is known. + /// Returns nil if this is not a range type. + /// + /// This list was manually generated. + @usableFromInline + internal var boundType: PostgresDataType? { + switch self { + case .int4Range: return .int4 + case .int8Range: return .int8 + case .numrange: return .numeric + case .tsrange: return .timestamp + case .tstzrange: return .timestamptz + case .daterange: return .date default: return nil } } - /// See `CustomStringConvertible`. + // See `CustomStringConvertible.description`. public var description: String { return self.knownSQLName ?? "UNKNOWN \(self.rawValue)" } } + +// TODO: The Codable conformance does not make any sense. Let's remove this with next major break. +extension PostgresDataType: Codable {} + +// TODO: The ExpressibleByIntegerLiteral conformance does not make any sense and is not used anywhere. Remove with next major break. +extension PostgresDataType: ExpressibleByIntegerLiteral { + public init(integerLiteral value: UInt32) { + self.init(value) + } +} diff --git a/Sources/PostgresNIO/Data/PostgresRow.swift b/Sources/PostgresNIO/Data/PostgresRow.swift index 7c80fe91..e3aea692 100644 --- a/Sources/PostgresNIO/Data/PostgresRow.swift +++ b/Sources/PostgresNIO/Data/PostgresRow.swift @@ -1,76 +1,323 @@ -public struct PostgresRow: CustomStringConvertible { - final class LookupTable { - let rowDescription: PostgresMessage.RowDescription - let resultFormat: [PostgresFormatCode] - - struct Value { - let index: Int - let field: PostgresMessage.RowDescription.Field +import NIOCore +import class Foundation.JSONDecoder + +/// `PostgresRow` represents a single table row that is received from the server for a query or a prepared statement. +/// Its element type is ``PostgresCell``. +/// +/// - Warning: Please note that random access to cells in a ``PostgresRow`` have O(n) time complexity. If you require +/// random access to cells in O(1) create a new ``PostgresRandomAccessRow`` with the given row and +/// access it instead. +public struct PostgresRow: Sendable { + @usableFromInline + let lookupTable: [String: Int] + @usableFromInline + let data: DataRow + @usableFromInline + let columns: [RowDescription.Column] + + init(data: DataRow, lookupTable: [String: Int], columns: [RowDescription.Column]) { + self.data = data + self.lookupTable = lookupTable + self.columns = columns + } +} + +extension PostgresRow: Equatable { + public static func ==(lhs: Self, rhs: Self) -> Bool { + // we don't need to compare the lookup table here, as the looup table is only derived + // from the column description. + lhs.data == rhs.data && lhs.columns == rhs.columns + } +} + +extension PostgresRow: Sequence { + public typealias Element = PostgresCell + + public struct Iterator: IteratorProtocol { + public typealias Element = PostgresCell + + private(set) var columnIndex: Array.Index + private(set) var columnIterator: Array.Iterator + private(set) var dataIterator: DataRow.Iterator + + init(_ row: PostgresRow) { + self.columnIndex = 0 + self.columnIterator = row.columns.makeIterator() + self.dataIterator = row.data.makeIterator() } - - private var _storage: [String: Value]? - var storage: [String: Value] { - if let existing = self._storage { - return existing - } else { - let all = self.rowDescription.fields.enumerated().map { (index, field) in - return (field.name, Value(index: index, field: field)) - } - let storage = [String: Value](all) { a, b in - // take the first value - return a - } - self._storage = storage - return storage + + public mutating func next() -> PostgresCell? { + guard let bytes = self.dataIterator.next() else { + return nil } + + let column = self.columnIterator.next()! + + defer { self.columnIndex += 1 } + + return PostgresCell( + bytes: bytes, + dataType: column.dataType, + format: column.format, + columnName: column.name, + columnIndex: columnIndex + ) } + } - init( - rowDescription: PostgresMessage.RowDescription, - resultFormat: [PostgresFormatCode] - ) { - self.rowDescription = rowDescription - self.resultFormat = resultFormat + public func makeIterator() -> Iterator { + Iterator(self) + } +} + +extension PostgresRow: Collection { + public struct Index: Comparable { + var cellIndex: DataRow.Index + var columnIndex: Array.Index + + // Only needed implementation for comparable. The compiler synthesizes the rest from this. + public static func < (lhs: Self, rhs: Self) -> Bool { + lhs.columnIndex < rhs.columnIndex } + } - func lookup(column: String) -> Value? { - if let value = self.storage[column] { - return value - } else { - return nil - } + public subscript(position: Index) -> PostgresCell { + let column = self.columns[position.columnIndex] + return PostgresCell( + bytes: self.data[position.cellIndex], + dataType: column.dataType, + format: column.format, + columnName: column.name, + columnIndex: position.columnIndex + ) + } + + public var startIndex: Index { + Index( + cellIndex: self.data.startIndex, + columnIndex: 0 + ) + } + + public var endIndex: Index { + Index( + cellIndex: self.data.endIndex, + columnIndex: self.columns.count + ) + } + + public func index(after i: Index) -> Index { + Index( + cellIndex: self.data.index(after: i.cellIndex), + columnIndex: self.columns.index(after: i.columnIndex) + ) + } + + public var count: Int { + self.data.count + } +} + +extension PostgresRow { + public func makeRandomAccess() -> PostgresRandomAccessRow { + PostgresRandomAccessRow(self) + } +} + +/// A random access row of ``PostgresCell``s. Its initialization is O(n) where n is the number of columns +/// in the row. All subsequent cell access are O(1). +public struct PostgresRandomAccessRow { + let columns: [RowDescription.Column] + let cells: [ByteBuffer?] + let lookupTable: [String: Int] + + public init(_ row: PostgresRow) { + self.cells = [ByteBuffer?](row.data) + self.columns = row.columns + self.lookupTable = row.lookupTable + } +} + +extension PostgresRandomAccessRow: Sendable, RandomAccessCollection { + public typealias Element = PostgresCell + public typealias Index = Int + + public var startIndex: Int { + 0 + } + + public var endIndex: Int { + self.columns.count + } + + public var count: Int { + self.columns.count + } + + public subscript(index: Int) -> PostgresCell { + guard index < self.endIndex else { + preconditionFailure("index out of bounds") + } + let column = self.columns[index] + return PostgresCell( + bytes: self.cells[index], + dataType: column.dataType, + format: column.format, + columnName: column.name, + columnIndex: index + ) + } + + public subscript(name: String) -> PostgresCell { + guard let index = self.lookupTable[name] else { + fatalError(#"A column "\#(name)" does not exist."#) + } + return self[index] + } + + /// Checks if the row contains a cell for the given column name. + /// - Parameter column: The column name to check against + /// - Returns: `true` if the row contains this column, `false` if it does not. + public func contains(_ column: String) -> Bool { + self.lookupTable[column] != nil + } +} + +extension PostgresRandomAccessRow { + public subscript(data index: Int) -> PostgresData { + guard index < self.endIndex else { + preconditionFailure("index out of bounds") + } + let column = self.columns[index] + return PostgresData( + type: column.dataType, + typeModifier: column.dataTypeModifier, + formatCode: .binary, + value: self.cells[index] + ) + } + + public subscript(data column: String) -> PostgresData { + guard let index = self.lookupTable[column] else { + fatalError(#"A column "\#(column)" does not exist."#) } + return self[data: index] } +} - public let dataRow: PostgresMessage.DataRow +extension PostgresRandomAccessRow { + /// Access the data in the provided column and decode it into the target type. + /// + /// - Parameters: + /// - column: The column name to read the data from + /// - type: The type to decode the data into + /// - Throws: The error of the decoding implementation. See also `PSQLDecodable` protocol for this. + /// - Returns: The decoded value of Type T. + func decode( + column: String, + as type: T.Type, + context: PostgresDecodingContext, + file: String = #fileID, line: Int = #line + ) throws -> T { + guard let index = self.lookupTable[column] else { + fatalError(#"A column "\#(column)" does not exist."#) + } + return try self.decode(column: index, as: type, context: context, file: file, line: line) + } + + /// Access the data in the provided column and decode it into the target type. + /// + /// - Parameters: + /// - column: The column index to read the data from + /// - type: The type to decode the data into + /// - Throws: The error of the decoding implementation. See also `PSQLDecodable` protocol for this. + /// - Returns: The decoded value of Type T. + func decode( + column index: Int, + as type: T.Type, + context: PostgresDecodingContext, + file: String = #fileID, line: Int = #line + ) throws -> T { + precondition(index < self.columns.count) + + let column = self.columns[index] + + var cellSlice = self.cells[index] + do { + return try T._decodeRaw(from: &cellSlice, type: column.dataType, format: column.format, context: context) + } catch let code as PostgresDecodingError.Code { + throw PostgresDecodingError( + code: code, + columnName: self.columns[index].name, + columnIndex: index, + targetType: T.self, + postgresType: self.columns[index].dataType, + postgresFormat: self.columns[index].format, + postgresData: cellSlice, + file: file, + line: line + ) + } + } +} + +// MARK: Deprecated API + +extension PostgresRow { + @available(*, deprecated, message: "Will be removed from public API.") public var rowDescription: PostgresMessage.RowDescription { - self.lookupTable.rowDescription + let fields = self.columns.map { column in + PostgresMessage.RowDescription.Field( + name: column.name, + tableOID: UInt32(column.tableOID), + columnAttributeNumber: column.columnAttributeNumber, + dataType: PostgresDataType(UInt32(column.dataType.rawValue)), + dataTypeSize: column.dataTypeSize, + dataTypeModifier: column.dataTypeModifier, + formatCode: .init(psqlFormatCode: column.format) + ) + } + return PostgresMessage.RowDescription(fields: fields) } - let lookupTable: LookupTable + @available(*, deprecated, message: "Iterate the cells on `PostgresRow` instead.") + public var dataRow: PostgresMessage.DataRow { + let columns = self.data.map { + PostgresMessage.DataRow.Column(value: $0) + } + return PostgresMessage.DataRow(columns: columns) + } + @available(*, deprecated, message: """ + This call is O(n) where n is the number of cells in the row. For random access to cells + in a row create a PostgresRandomAccessRow from the row first and use its subscript + methods. (see `makeRandomAccess()`) + """) public func column(_ column: String) -> PostgresData? { - guard let entry = self.lookupTable.lookup(column: column) else { + guard let index = self.lookupTable[column] else { return nil } - let formatCode: PostgresFormatCode - switch self.lookupTable.resultFormat.count { - case 1: formatCode = self.lookupTable.resultFormat[0] - default: formatCode = entry.field.formatCode - } + return PostgresData( - type: entry.field.dataType, - typeModifier: entry.field.dataTypeModifier, - formatCode: formatCode, - value: self.dataRow.columns[entry.index].value + type: self.columns[index].dataType, + typeModifier: self.columns[index].dataTypeModifier, + formatCode: .binary, + value: self.data[column: index] ) } +} +extension PostgresRow: CustomStringConvertible { public var description: String { var row: [String: PostgresData] = [:] - for field in self.lookupTable.rowDescription.fields { - row[field.name] = self.column(field.name) + for cell in self { + row[cell.columnName] = PostgresData( + type: cell.dataType, + typeModifier: 0, + formatCode: cell.format, + value: cell.bytes + ) } return row.description } diff --git a/Sources/PostgresNIO/Deprecated/PostgresConnection+Configuration+Deprecated.swift b/Sources/PostgresNIO/Deprecated/PostgresConnection+Configuration+Deprecated.swift new file mode 100644 index 00000000..9619c182 --- /dev/null +++ b/Sources/PostgresNIO/Deprecated/PostgresConnection+Configuration+Deprecated.swift @@ -0,0 +1,95 @@ +import NIOCore + +extension PostgresConnection.Configuration { + /// Legacy connection parameters structure. Replaced by ``PostgresConnection/Configuration/host`` etc. + @available(*, deprecated, message: "Use `Configuration.host` etc. instead.") + public struct Connection { + /// See ``PostgresConnection/Configuration/host``. + public var host: String + + /// See ``PostgresConnection/Configuration/port``. + public var port: Int + + /// See ``PostgresConnection/Configuration/Options-swift.struct/requireBackendKeyData``. + public var requireBackendKeyData: Bool = true + + /// See ``PostgresConnection/Configuration/Options-swift.struct/connectTimeout``. + public var connectTimeout: TimeAmount = .seconds(10) + + /// Create a configuration for connecting to a server. + /// + /// - Parameters: + /// - host: The hostname to connect to. + /// - port: The TCP port to connect to (defaults to 5432). + public init(host: String, port: Int = 5432) { + self.host = host + self.port = port + } + } + + /// Legacy authentication parameters structure. Replaced by ``PostgresConnection/Configuration/username`` etc. + @available(*, deprecated, message: "Use `Configuration.username` etc. instead.") + public struct Authentication { + /// See ``PostgresConnection/Configuration/username``. + public var username: String + + /// See ``PostgresConnection/Configuration/password``. + public var password: String? + + /// See ``PostgresConnection/Configuration/database``. + public var database: String? + + public init(username: String, database: String?, password: String?) { + self.username = username + self.database = database + self.password = password + } + } + + /// Accessor for legacy connection parameters. Replaced by ``PostgresConnection/Configuration/host`` etc. + @available(*, deprecated, message: "Use `Configuration.host` etc. instead.") + public var connection: Connection { + get { + var conn: Connection + switch self.endpointInfo { + case .connectTCP(let host, let port): + conn = .init(host: host, port: port) + case .bindUnixDomainSocket(_), .configureChannel(_): + conn = .init(host: "!invalid!", port: 0) // best we can do, really + } + conn.requireBackendKeyData = self.options.requireBackendKeyData + conn.connectTimeout = self.options.connectTimeout + return conn + } + set { + self.endpointInfo = .connectTCP(host: newValue.host, port: newValue.port) + self.options.connectTimeout = newValue.connectTimeout + self.options.requireBackendKeyData = newValue.requireBackendKeyData + } + } + + @available(*, deprecated, message: "Use `Configuration.username` etc. instead.") + public var authentication: Authentication { + get { + .init(username: self.username, database: self.database, password: self.password) + } + set { + self.username = newValue.username + self.password = newValue.password + self.database = newValue.database + } + } + + /// Legacy initializer. + /// Replaced by ``PostgresConnection/Configuration/init(host:port:username:password:database:tls:)`` etc. + @available(*, deprecated, message: "Use `init(host:port:username:password:database:tls:)` instead.") + public init(connection: Connection, authentication: Authentication, tls: TLS) { + self.init( + host: connection.host, port: connection.port, + username: authentication.username, password: authentication.password, database: authentication.database, + tls: tls + ) + self.options.connectTimeout = connection.connectTimeout + self.options.requireBackendKeyData = connection.requireBackendKeyData + } +} diff --git a/Sources/PostgresNIO/Deprecated/PostgresMessage+Authentication.swift b/Sources/PostgresNIO/Deprecated/PostgresMessage+Authentication.swift new file mode 100644 index 00000000..da7c25d5 --- /dev/null +++ b/Sources/PostgresNIO/Deprecated/PostgresMessage+Authentication.swift @@ -0,0 +1,119 @@ +import NIOCore + +extension PostgresMessage { + /// Authentication request returned by the server. + @available(*, deprecated, message: "Will be removed from public API") + public enum Authentication: PostgresMessageType { + public static var identifier: PostgresMessage.Identifier { + return .authentication + } + + /// Parses an instance of this message type from a byte buffer. + public static func parse(from buffer: inout ByteBuffer) throws -> Authentication { + guard let type = buffer.readInteger(as: Int32.self) else { + throw PostgresError.protocol("Could not read authentication message type") + } + switch type { + case 0: return .ok + case 3: return .plaintext + case 5: + guard let salt = buffer.readBytes(length: 4) else { + throw PostgresError.protocol("Could not parse MD5 salt from authentication message") + } + return .md5(salt) + case 10: + var mechanisms: [String] = [] + while buffer.readableBytes > 0 { + guard let nextString = buffer.readNullTerminatedString() else { + throw PostgresError.protocol("Could not parse SASL mechanisms from authentication message") + } + if nextString.isEmpty { + break + } + mechanisms.append(nextString) + } + guard buffer.readableBytes == 0 else { + throw PostgresError.protocol("Trailing data at end of SASL mechanisms authentication message") + } + return .saslMechanisms(mechanisms) + case 11: + guard let challengeData = buffer.readBytes(length: buffer.readableBytes) else { + throw PostgresError.protocol("Could not parse SASL challenge from authentication message") + } + return .saslContinue(challengeData) + case 12: + guard let finalData = buffer.readBytes(length: buffer.readableBytes) else { + throw PostgresError.protocol("Could not parse SASL final data from authentication message") + } + return .saslFinal(finalData) + + case 2, 7...9: + throw PostgresError.protocol("Support for KRBv5, GSSAPI, and SSPI authentication are not implemented") + case 6: + throw PostgresError.protocol("Support for SCM credential authentication is obsolete") + + default: + throw PostgresError.protocol("Unknown authentication request type: \(type)") + } + } + + public func serialize(into buffer: inout ByteBuffer) throws { + switch self { + case .ok: + buffer.writeInteger(0, as: Int32.self) + case .plaintext: + buffer.writeInteger(3, as: Int32.self) + case .md5(let salt): + buffer.writeInteger(5, as: Int32.self) + buffer.writeBytes(salt) + case .saslMechanisms(let mechanisms): + buffer.writeInteger(10, as: Int32.self) + mechanisms.forEach { + buffer.writeNullTerminatedString($0) + } + case .saslContinue(let challenge): + buffer.writeInteger(11, as: Int32.self) + buffer.writeBytes(challenge) + case .saslFinal(let data): + buffer.writeInteger(12, as: Int32.self) + buffer.writeBytes(data) + } + } + + /// AuthenticationOk + /// Specifies that the authentication was successful. + case ok + + /// AuthenticationCleartextPassword + /// Specifies that a clear-text password is required. + case plaintext + + /// AuthenticationMD5Password + /// Specifies that an MD5-encrypted password is required. + case md5([UInt8]) + + /// AuthenticationSASL + /// Specifies the start of SASL mechanism negotiation. + case saslMechanisms([String]) + + /// AuthenticationSASLContinue + /// Specifies SASL mechanism-specific challenge data. + case saslContinue([UInt8]) + + /// AuthenticationSASLFinal + /// Specifies mechanism-specific post-authentication client data. + case saslFinal([UInt8]) + + /// See `CustomStringConvertible`. + public var description: String { + switch self { + case .ok: return "Ok" + case .plaintext: return "CleartextPassword" + case .md5(let salt): return "MD5Password(salt: 0x\(salt.hexdigest()))" + case .saslMechanisms(let mech): return "SASLMechanisms(\(mech))" + case .saslContinue(let data): return "SASLChallenge(\(data))" + case .saslFinal(let data): return "SASLFinal(\(data))" + } + } + } +} diff --git a/Sources/PostgresNIO/Message/PostgresMessage+Bind.swift b/Sources/PostgresNIO/Deprecated/PostgresMessage+Bind.swift similarity index 89% rename from Sources/PostgresNIO/Message/PostgresMessage+Bind.swift rename to Sources/PostgresNIO/Deprecated/PostgresMessage+Bind.swift index 3b7d250c..5ff4bbf0 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+Bind.swift +++ b/Sources/PostgresNIO/Deprecated/PostgresMessage+Bind.swift @@ -1,7 +1,8 @@ -import NIO +import NIOCore extension PostgresMessage { /// Identifies the message as a Bind command. + @available(*, deprecated, message: "Will be removed from public API") public struct Bind: PostgresMessageType { public static var identifier: PostgresMessage.Identifier { return .bind @@ -26,7 +27,7 @@ extension PostgresMessage { /// This can be zero to indicate that there are no parameters or that the parameters all use the default format (text); /// or one, in which case the specified format code is applied to all parameters; or it can equal the actual number of parameters. /// The parameter format codes. Each must presently be zero (text) or one (binary). - public var parameterFormatCodes: [PostgresFormatCode] + public var parameterFormatCodes: [PostgresFormat] /// The number of parameter values that follow (possibly zero). This must match the number of parameters needed by the query. public var parameters: [Parameter] @@ -35,12 +36,12 @@ extension PostgresMessage { /// This can be zero to indicate that there are no result columns or that the result columns should all use the default format (text); /// or one, in which case the specified format code is applied to all result columns (if any); /// or it can equal the actual number of result columns of the query. - public var resultFormatCodes: [PostgresFormatCode] + public var resultFormatCodes: [PostgresFormat] /// Serializes this message into a byte buffer. public func serialize(into buffer: inout ByteBuffer) { - buffer.write(nullTerminated: self.portalName) - buffer.write(nullTerminated: self.statementName) + buffer.writeNullTerminatedString(self.portalName) + buffer.writeNullTerminatedString(self.statementName) buffer.write(array: self.parameterFormatCodes) buffer.write(array: self.parameters) { diff --git a/Sources/PostgresNIO/Message/PostgresMessage+Close.swift b/Sources/PostgresNIO/Deprecated/PostgresMessage+Close.swift similarity index 90% rename from Sources/PostgresNIO/Message/PostgresMessage+Close.swift rename to Sources/PostgresNIO/Deprecated/PostgresMessage+Close.swift index 69871df6..9bcc8aa1 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+Close.swift +++ b/Sources/PostgresNIO/Deprecated/PostgresMessage+Close.swift @@ -1,7 +1,8 @@ -import NIO +import NIOCore extension PostgresMessage { /// Identifies the message as a Close Command + @available(*, deprecated, message: "Will be removed from public API") public struct Close: PostgresMessageType { public static var identifier: PostgresMessage.Identifier { return .close @@ -33,7 +34,7 @@ extension PostgresMessage { /// Serializes this message into a byte buffer. public func serialize(into buffer: inout ByteBuffer) throws { buffer.writeInteger(target.rawValue) - buffer.write(nullTerminated: name) + buffer.writeNullTerminatedString(name) } } } diff --git a/Sources/PostgresNIO/Message/PostgresMessage+CommandComplete.swift b/Sources/PostgresNIO/Deprecated/PostgresMessage+CommandComplete.swift similarity index 88% rename from Sources/PostgresNIO/Message/PostgresMessage+CommandComplete.swift rename to Sources/PostgresNIO/Deprecated/PostgresMessage+CommandComplete.swift index 8ac6e706..c9370402 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+CommandComplete.swift +++ b/Sources/PostgresNIO/Deprecated/PostgresMessage+CommandComplete.swift @@ -1,7 +1,8 @@ -import NIO +import NIOCore extension PostgresMessage { /// Identifies the message as a Close command. + @available(*, deprecated, message: "Will be removed from public API") public struct CommandComplete: PostgresMessageType { /// Parses an instance of this message type from a byte buffer. public static func parse(from buffer: inout ByteBuffer) throws -> CommandComplete { diff --git a/Sources/PostgresNIO/Message/PostgresMessage+Describe.swift b/Sources/PostgresNIO/Deprecated/PostgresMessage+Describe.swift similarity index 89% rename from Sources/PostgresNIO/Message/PostgresMessage+Describe.swift rename to Sources/PostgresNIO/Deprecated/PostgresMessage+Describe.swift index 6bfe20d1..787355db 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+Describe.swift +++ b/Sources/PostgresNIO/Deprecated/PostgresMessage+Describe.swift @@ -1,7 +1,8 @@ -import NIO +import NIOCore extension PostgresMessage { /// Identifies the message as a Describe command. + @available(*, deprecated, message: "Will be removed from public API") public struct Describe: PostgresMessageType { public static var identifier: PostgresMessage.Identifier { return .describe @@ -31,7 +32,7 @@ extension PostgresMessage { /// Serializes this message into a byte buffer. public func serialize(into buffer: inout ByteBuffer) { buffer.writeInteger(command.rawValue) - buffer.write(nullTerminated: name) + buffer.writeNullTerminatedString(name) } } } diff --git a/Sources/PostgresNIO/Message/PostgresMessage+Execute.swift b/Sources/PostgresNIO/Deprecated/PostgresMessage+Execute.swift similarity index 85% rename from Sources/PostgresNIO/Message/PostgresMessage+Execute.swift rename to Sources/PostgresNIO/Deprecated/PostgresMessage+Execute.swift index 8566355d..39b447a4 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+Execute.swift +++ b/Sources/PostgresNIO/Deprecated/PostgresMessage+Execute.swift @@ -1,7 +1,8 @@ -import NIO +import NIOCore extension PostgresMessage { /// Identifies the message as an Execute command. + @available(*, deprecated, message: "Will be removed from public API") public struct Execute: PostgresMessageType { public static var identifier: PostgresMessage.Identifier { return .execute @@ -20,7 +21,7 @@ extension PostgresMessage { /// Serializes this message into a byte buffer. public func serialize(into buffer: inout ByteBuffer) { - buffer.write(nullTerminated: portalName) + buffer.writeNullTerminatedString(portalName) buffer.writeInteger(self.maxRows) } } diff --git a/Sources/PostgresNIO/Message/PostgresMessage+ParameterDescription.swift b/Sources/PostgresNIO/Deprecated/PostgresMessage+ParameterDescription.swift similarity index 91% rename from Sources/PostgresNIO/Message/PostgresMessage+ParameterDescription.swift rename to Sources/PostgresNIO/Deprecated/PostgresMessage+ParameterDescription.swift index a3806d0f..89e67682 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+ParameterDescription.swift +++ b/Sources/PostgresNIO/Deprecated/PostgresMessage+ParameterDescription.swift @@ -1,7 +1,8 @@ -import NIO +import NIOCore extension PostgresMessage { /// Identifies the message as a parameter description. + @available(*, deprecated, message: "Will be removed from public API") public struct ParameterDescription: PostgresMessageType { /// Parses an instance of this message type from a byte buffer. public static func parse(from buffer: inout ByteBuffer) throws -> ParameterDescription { diff --git a/Sources/PostgresNIO/Message/PostgresMessage+ParameterStatus.swift b/Sources/PostgresNIO/Deprecated/PostgresMessage+ParameterStatus.swift similarity index 92% rename from Sources/PostgresNIO/Message/PostgresMessage+ParameterStatus.swift rename to Sources/PostgresNIO/Deprecated/PostgresMessage+ParameterStatus.swift index 09939bef..5ad6f95e 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+ParameterStatus.swift +++ b/Sources/PostgresNIO/Deprecated/PostgresMessage+ParameterStatus.swift @@ -1,6 +1,7 @@ -import NIO +import NIOCore extension PostgresMessage { + @available(*, deprecated, message: "Will be removed from public API") public struct ParameterStatus: PostgresMessageType, CustomStringConvertible { /// Parses an instance of this message type from a byte buffer. public static func parse(from buffer: inout ByteBuffer) throws -> ParameterStatus { diff --git a/Sources/PostgresNIO/Message/PostgresMessage+Parse.swift b/Sources/PostgresNIO/Deprecated/PostgresMessage+Parse.swift similarity index 93% rename from Sources/PostgresNIO/Message/PostgresMessage+Parse.swift rename to Sources/PostgresNIO/Deprecated/PostgresMessage+Parse.swift index 749a6949..8fb5a1ff 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+Parse.swift +++ b/Sources/PostgresNIO/Deprecated/PostgresMessage+Parse.swift @@ -1,7 +1,8 @@ -import NIO +import NIOCore extension PostgresMessage { /// Identifies the message as a Parse command. + @available(*, deprecated, message: "Will be removed from public API") public struct Parse: PostgresMessageType { public static var identifier: PostgresMessage.Identifier { return .parse diff --git a/Sources/PostgresNIO/Message/PostgresMessage+Password.swift b/Sources/PostgresNIO/Deprecated/PostgresMessage+Password.swift similarity index 91% rename from Sources/PostgresNIO/Message/PostgresMessage+Password.swift rename to Sources/PostgresNIO/Deprecated/PostgresMessage+Password.swift index f28463e7..cafe9cda 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+Password.swift +++ b/Sources/PostgresNIO/Deprecated/PostgresMessage+Password.swift @@ -1,9 +1,10 @@ -import NIO +import NIOCore extension PostgresMessage { /// Identifies the message as a password response. Note that this is also used for /// GSSAPI and SSPI response messages (which is really a design error, since the contained /// data is not a null-terminated string in that case, but can be arbitrary binary data). + @available(*, deprecated, message: "Will be removed from public API") public struct Password: PostgresMessageType { public static var identifier: PostgresMessage.Identifier { return .passwordMessage diff --git a/Sources/PostgresNIO/Message/PostgresMessage+ReadyForQuery.swift b/Sources/PostgresNIO/Deprecated/PostgresMessage+ReadyForQuery.swift similarity index 93% rename from Sources/PostgresNIO/Message/PostgresMessage+ReadyForQuery.swift rename to Sources/PostgresNIO/Deprecated/PostgresMessage+ReadyForQuery.swift index b05e833b..5afc0910 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+ReadyForQuery.swift +++ b/Sources/PostgresNIO/Deprecated/PostgresMessage+ReadyForQuery.swift @@ -1,7 +1,8 @@ -import NIO +import NIOCore extension PostgresMessage { /// Identifies the message type. ReadyForQuery is sent whenever the backend is ready for a new query cycle. + @available(*, deprecated, message: "Will be removed from public API") public struct ReadyForQuery: CustomStringConvertible { /// Parses an instance of this message type from a byte buffer. public static func parse(from buffer: inout ByteBuffer) throws -> ReadyForQuery { diff --git a/Sources/PostgresNIO/Deprecated/PostgresMessage+SASLResponse.swift b/Sources/PostgresNIO/Deprecated/PostgresMessage+SASLResponse.swift new file mode 100644 index 00000000..dc3b1772 --- /dev/null +++ b/Sources/PostgresNIO/Deprecated/PostgresMessage+SASLResponse.swift @@ -0,0 +1,78 @@ +import NIOCore + +extension PostgresMessage { + /// SASL ongoing challenge response message sent by the client. + @available(*, deprecated, message: "Will be removed from public API") + public struct SASLResponse: PostgresMessageType { + public static var identifier: PostgresMessage.Identifier { + return .saslResponse + } + + public let responseData: [UInt8] + + public static func parse(from buffer: inout ByteBuffer) throws -> SASLResponse { + guard let data = buffer.readBytes(length: buffer.readableBytes) else { + throw PostgresError.protocol("Could not parse SASL response from response message") + } + + return SASLResponse(responseData: data) + } + + public func serialize(into buffer: inout ByteBuffer) throws { + buffer.writeBytes(responseData) + } + + public var description: String { + return "SASLResponse(\(responseData))" + } + } +} + +extension PostgresMessage { + /// SASL initial challenge response message sent by the client. + @available(*, deprecated, message: "Will be removed from public API") + public struct SASLInitialResponse { + public let mechanism: String + public let initialData: [UInt8] + + public func serialize(into buffer: inout ByteBuffer) throws { + buffer.writeNullTerminatedString(self.mechanism) + if initialData.count > 0 { + buffer.writeInteger(Int32(initialData.count), as: Int32.self) // write(array:) writes Int16, which is incorrect here + buffer.writeBytes(initialData) + } else { + buffer.writeInteger(-1, as: Int32.self) + } + } + + public var description: String { + return "SASLInitialResponse(\(mechanism), data: \(initialData))" + } + } +} + +@available(*, deprecated, message: "Deprecating conformance to `PostgresMessageType` since it is deprecated.") +extension PostgresMessage.SASLInitialResponse: PostgresMessageType { + public static var identifier: PostgresMessage.Identifier { + return .saslInitialResponse + } + + public static func parse(from buffer: inout ByteBuffer) throws -> Self { + guard let mechanism = buffer.readNullTerminatedString() else { + throw PostgresError.protocol("Could not parse SASL mechanism from initial response message") + } + guard let dataLength = buffer.readInteger(as: Int32.self) else { + throw PostgresError.protocol("Could not parse SASL initial data length from initial response message") + } + + var actualData: [UInt8] = [] + + if dataLength != -1 { + guard let data = buffer.readBytes(length: Int(dataLength)) else { + throw PostgresError.protocol("Could not parse SASL initial data from initial response message") + } + actualData = data + } + return .init(mechanism: mechanism, initialData: actualData) + } +} diff --git a/Sources/PostgresNIO/Message/PostgresMessage+SSLRequest.swift b/Sources/PostgresNIO/Deprecated/PostgresMessage+SSLRequest.swift similarity index 90% rename from Sources/PostgresNIO/Message/PostgresMessage+SSLRequest.swift rename to Sources/PostgresNIO/Deprecated/PostgresMessage+SSLRequest.swift index 9133d26a..ee504932 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+SSLRequest.swift +++ b/Sources/PostgresNIO/Deprecated/PostgresMessage+SSLRequest.swift @@ -1,8 +1,9 @@ -import NIO +import NIOCore extension PostgresMessage { /// A message asking the PostgreSQL server if SSL is supported /// For more info, see https://www.postgresql.org/docs/10/static/protocol-flow.html#id-1.10.5.7.11 + @available(*, deprecated, message: "Will be removed from public API") public struct SSLRequest: PostgresMessageType { /// The SSL request code. The value is chosen to contain 1234 in the most significant 16 bits, /// and 5679 in the least significant 16 bits. diff --git a/Sources/PostgresNIO/Message/PostgresMessage+SimpleQuery.swift b/Sources/PostgresNIO/Deprecated/PostgresMessage+SimpleQuery.swift similarity index 85% rename from Sources/PostgresNIO/Message/PostgresMessage+SimpleQuery.swift rename to Sources/PostgresNIO/Deprecated/PostgresMessage+SimpleQuery.swift index 80e106b5..a0a6cfcf 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+SimpleQuery.swift +++ b/Sources/PostgresNIO/Deprecated/PostgresMessage+SimpleQuery.swift @@ -1,7 +1,8 @@ -import NIO +import NIOCore extension PostgresMessage { /// Identifies the message as a simple query. + @available(*, deprecated, message: "Will be removed from public API") public struct SimpleQuery: PostgresMessageType { public static var identifier: PostgresMessage.Identifier { return .query diff --git a/Sources/PostgresNIO/Message/PostgresMessage+Startup.swift b/Sources/PostgresNIO/Deprecated/PostgresMessage+Startup.swift similarity index 95% rename from Sources/PostgresNIO/Message/PostgresMessage+Startup.swift rename to Sources/PostgresNIO/Deprecated/PostgresMessage+Startup.swift index 25f68772..e9762439 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+Startup.swift +++ b/Sources/PostgresNIO/Deprecated/PostgresMessage+Startup.swift @@ -1,7 +1,8 @@ -import NIO +import NIOCore extension PostgresMessage { /// First message sent from the frontend during startup. + @available(*, deprecated, message: "Will be removed from public API") public struct Startup: PostgresMessageType { public static var identifier: PostgresMessage.Identifier { return .none diff --git a/Sources/PostgresNIO/Message/PostgresMessage+Sync.swift b/Sources/PostgresNIO/Deprecated/PostgresMessage+Sync.swift similarity index 82% rename from Sources/PostgresNIO/Message/PostgresMessage+Sync.swift rename to Sources/PostgresNIO/Deprecated/PostgresMessage+Sync.swift index 6e47cefb..0560ef7a 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+Sync.swift +++ b/Sources/PostgresNIO/Deprecated/PostgresMessage+Sync.swift @@ -1,7 +1,8 @@ -import NIO +import NIOCore extension PostgresMessage { /// Identifies the message as a Bind command. + @available(*, deprecated, message: "Will be removed from public API") public struct Sync: PostgresMessageType { public static var identifier: PostgresMessage.Identifier { return .sync diff --git a/Sources/PostgresNIO/Message/PostgresMessage+Terminate.swift b/Sources/PostgresNIO/Deprecated/PostgresMessage+Terminate.swift similarity index 73% rename from Sources/PostgresNIO/Message/PostgresMessage+Terminate.swift rename to Sources/PostgresNIO/Deprecated/PostgresMessage+Terminate.swift index 61227fdf..afeae5bf 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+Terminate.swift +++ b/Sources/PostgresNIO/Deprecated/PostgresMessage+Terminate.swift @@ -1,4 +1,7 @@ +import NIOCore + extension PostgresMessage { + @available(*, deprecated, message: "Will be removed from public API") public struct Terminate: PostgresMessageType { public static var identifier: PostgresMessage.Identifier { .terminate diff --git a/Sources/PostgresNIO/Message/PostgresMessageDecoder.swift b/Sources/PostgresNIO/Deprecated/PostgresMessageDecoder.swift similarity index 96% rename from Sources/PostgresNIO/Message/PostgresMessageDecoder.swift rename to Sources/PostgresNIO/Deprecated/PostgresMessageDecoder.swift index 9a64e827..e092c234 100644 --- a/Sources/PostgresNIO/Message/PostgresMessageDecoder.swift +++ b/Sources/PostgresNIO/Deprecated/PostgresMessageDecoder.swift @@ -1,5 +1,7 @@ -import NIO +import NIOCore +import Logging +@available(*, deprecated, message: "Will be removed from public API") public final class PostgresMessageDecoder: ByteToMessageDecoder { /// See `ByteToMessageDecoder`. public typealias InboundOut = PostgresMessage diff --git a/Sources/PostgresNIO/Message/PostgresMessageEncoder.swift b/Sources/PostgresNIO/Deprecated/PostgresMessageEncoder.swift similarity index 92% rename from Sources/PostgresNIO/Message/PostgresMessageEncoder.swift rename to Sources/PostgresNIO/Deprecated/PostgresMessageEncoder.swift index 4eca9bc5..8dd4c38d 100644 --- a/Sources/PostgresNIO/Message/PostgresMessageEncoder.swift +++ b/Sources/PostgresNIO/Deprecated/PostgresMessageEncoder.swift @@ -1,5 +1,7 @@ -import NIO +import NIOCore +import Logging +@available(*, deprecated, message: "Will be removed from public API") public final class PostgresMessageEncoder: MessageToByteEncoder { /// See `MessageToByteEncoder`. public typealias OutboundIn = PostgresMessage diff --git a/Sources/PostgresNIO/Docs.docc/coding.md b/Sources/PostgresNIO/Docs.docc/coding.md new file mode 100644 index 00000000..3bcc4a7e --- /dev/null +++ b/Sources/PostgresNIO/Docs.docc/coding.md @@ -0,0 +1,39 @@ +# PostgreSQL data types + +Translate Swift data types to Postgres data types and vica versa. Learn how to write translations +for your own custom Swift types. + +## Topics + +### Essentials + +- ``PostgresCodable`` +- ``PostgresDataType`` +- ``PostgresFormat`` +- ``PostgresNumeric`` + +### Encoding + +- ``PostgresEncodable`` +- ``PostgresNonThrowingEncodable`` +- ``PostgresDynamicTypeEncodable`` +- ``PostgresThrowingDynamicTypeEncodable`` +- ``PostgresArrayEncodable`` +- ``PostgresRangeEncodable`` +- ``PostgresRangeArrayEncodable`` +- ``PostgresEncodingContext`` + +### Decoding + +- ``PostgresDecodable`` +- ``PostgresArrayDecodable`` +- ``PostgresRangeDecodable`` +- ``PostgresRangeArrayDecodable`` +- ``PostgresDecodingContext`` + +### JSON + +- ``PostgresJSONEncoder`` +- ``PostgresJSONDecoder`` + + diff --git a/Sources/PostgresNIO/Docs.docc/deprecated.md b/Sources/PostgresNIO/Docs.docc/deprecated.md new file mode 100644 index 00000000..a29465f6 --- /dev/null +++ b/Sources/PostgresNIO/Docs.docc/deprecated.md @@ -0,0 +1,43 @@ +# Deprecations + +`PostgresNIO` follows SemVer 2.0.0. Learn which APIs are considered deprecated and how to migrate to +their replacements. + +``PostgresNIO`` reached 1.0 in April 2020. Since then the maintainers have been hard at work to +guarantee API stability. However as the Swift and Swift on server ecosystem have matured approaches +have changed. The introduction of structured concurrency changed what developers expect from a +modern Swift library. Because of this ``PostgresNIO`` added various APIs that embrace the new Swift +patterns. This means however, that PostgresNIO still offers APIs that have fallen out of favor. +Those are documented here. All those APIs will be removed once the maintainers release the next +major version. The maintainers recommend all adopters to move of those APIs sooner rather than +later. + +## Topics + +### Migrate of deprecated APIs + +- + +### Deprecated APIs + +These types are already deprecated or will be deprecated in the near future. All of them will be +removed from the public API with the next major release. + +- ``PostgresDatabase`` +- ``PostgresData`` +- ``PostgresDataConvertible`` +- ``PostgresQueryResult`` +- ``PostgresJSONCodable`` +- ``PostgresJSONBCodable`` +- ``PostgresMessageEncoder`` +- ``PostgresMessageDecoder`` +- ``PostgresRequest`` +- ``PostgresMessage`` +- ``PostgresMessageType`` +- ``PostgresFormatCode`` +- ``PostgresListenContext`` +- ``PreparedQuery`` +- ``SASLAuthenticationManager`` +- ``SASLAuthenticationMechanism`` +- ``SASLAuthenticationError`` +- ``SASLAuthenticationStepResult`` diff --git a/Sources/PostgresNIO/Docs.docc/images/vapor-postgresnio-logo.svg b/Sources/PostgresNIO/Docs.docc/images/vapor-postgresnio-logo.svg new file mode 100644 index 00000000..a831189d --- /dev/null +++ b/Sources/PostgresNIO/Docs.docc/images/vapor-postgresnio-logo.svg @@ -0,0 +1,21 @@ + + + + + + + + + + + + + + + + + + diff --git a/Sources/PostgresNIO/Docs.docc/index.md b/Sources/PostgresNIO/Docs.docc/index.md new file mode 100644 index 00000000..6355a7a4 --- /dev/null +++ b/Sources/PostgresNIO/Docs.docc/index.md @@ -0,0 +1,58 @@ +# ``PostgresNIO`` + +@Metadata { + @TitleHeading(Package) +} + +🐘 Non-blocking, event-driven Swift client for PostgreSQL built on SwiftNIO. + +## Overview + +``PostgresNIO`` allows you to connect to, authorize with, query, and retrieve results from a +PostgreSQL server. PostgreSQL is an open source relational database. + +Use a ``PostgresConnection`` to create a connection to the PostgreSQL server. You can then use it to +run queries and prepared statements against the server. ``PostgresConnection`` also supports +PostgreSQL's Listen & Notify API. + +Developers, who don't want to manage connections themselves, can use the ``PostgresClient``, which +offers the same functionality as ``PostgresConnection``. ``PostgresClient`` +pools connections for rapid connection reuse and hides the complexities of connection +management from the user, allowing developers to focus on their SQL queries. ``PostgresClient`` +implements the `Service` protocol from Service Lifecycle allowing an easy adoption in Swift server +applications. + +``PostgresNIO`` embraces Swift structured concurrency, offering async/await APIs which handle +task cancellation. The query interface makes use of backpressure to ensure that memory can not grow +unbounded for queries that return thousands of rows. + +``PostgresNIO`` runs efficiently on Linux and Apple platforms. On Apple platforms developers can +configure ``PostgresConnection`` to use `Network.framework` as the underlying transport framework. + +## Topics + +### Essentials + +- ``PostgresClient`` +- ``PostgresClient/Configuration`` +- ``PostgresConnection`` +- + +### Advanced + +- +- +- + +### Errors + +- ``PostgresError`` +- ``PostgresDecodingError`` +- ``PSQLError`` + +### Deprecations + +- + +[SwiftNIO]: https://github.com/apple/swift-nio +[SwiftLog]: https://github.com/apple/swift-log diff --git a/Sources/PostgresNIO/Docs.docc/listen.md b/Sources/PostgresNIO/Docs.docc/listen.md new file mode 100644 index 00000000..10c5d8bf --- /dev/null +++ b/Sources/PostgresNIO/Docs.docc/listen.md @@ -0,0 +1,9 @@ +# Listen & Notify + +``PostgresNIO`` supports PostgreSQL's listen and notify API. Learn how to listen for changes and +notify other listeners. + +## Topics + +- ``PostgresNotification`` +- ``PostgresNotificationSequence`` diff --git a/Sources/PostgresNIO/Docs.docc/migrations.md b/Sources/PostgresNIO/Docs.docc/migrations.md new file mode 100644 index 00000000..3a7c634a --- /dev/null +++ b/Sources/PostgresNIO/Docs.docc/migrations.md @@ -0,0 +1,90 @@ +# Adopting the new PostgresRow cell API + +This article describes how to adopt the new ``PostgresRow`` cell APIs in existing Postgres code +which use the ``PostgresRow/column(_:)`` API today. + +## TLDR + +1. Map your sequence of ``PostgresRow``s to ``PostgresRandomAccessRow``s. +2. Use the ``PostgresRandomAccessRow/subscript(_:)-3facl`` API to receive a ``PostgresCell`` +3. Decode the ``PostgresCell`` into a Swift type using the ``PostgresCell/decode(_:file:line:)`` method. + +```swift +let rows: [PostgresRow] // your existing return value +for row in rows.map({ PostgresRandomAccessRow($0) }) { + let id = try row["id"].decode(UUID.self) + let name = try row["name"].decode(String.self) + let email = try row["email"].decode(String.self) + let age = try row["age"].decode(Int.self) +} +``` + +## Overview + +When Postgres [`1.9.0`] was released we changed the default behaviour of ``PostgresRow``s. +Previously for each row we created an internal lookup table, that allowed you to access the rows' +cells by name: + +```swift +connection.query("SELECT id, name, email, age FROM users").whenComplete { + switch $0 { + case .success(let result): + for row in result.rows { + let id = row.column("id").uuid + let name = row.column("name").string + let email = row.column("email").string + let age = row.column("age").int + // do further processing + } + case .failure(let error): + // handle the error + } +} +``` + +During the last year we introduced a new API that let's you consume ``PostgresRow`` by iterating +its cells. This approach has the performance benefit of not needing to create an internal cell +lookup table for each row: + +```swift +connection.query("SELECT id, name, email, age FROM users").whenComplete { + switch $0 { + case .success(let result): + for row in result.rows { + let (id, name, email, age) = try row.decode((UUID, String, String, Int).self) + // do further processing + } + case .failure(let error): + // handle the error + } +} +``` + +However, since we still supported the ``PostgresRow/column(_:)`` API, which requires a precomputed +lookup table within the row, users were not seeing any performance benefits. To allow users to +benefit of the new fastpath, we changed ``PostgresRow``'s behavior: + +By default the ``PostgresRow`` does not create an internal lookup table for its cells on creation +anymore. Because of this, when using the ``PostgresRow/column(_:)`` API, a throwaway lookup table +needs to be produced on every call. Since this is wasteful we have deprecated this API. Instead we +allow users now to explicitly opt-in into the cell lookup API by using the new +``PostgresRandomAccessRow``. + +```swift +connection.query("SELECT id, name, email, age FROM users").whenComplete { + switch $0 { + case .success(let result): + for row in result.rows.map { PostgresRandomAccessRow($0) } { + let id = try row["id"].decode(UUID.self) + let name = try row["name"].decode(String.self) + let email = try row["email"].decode(String.self) + let age = try row["age"].decode(Int.self) + // do further processing + } + case .failure(let error): + // handle the error + } +} +``` + +[`1.9.0`]: https://github.com/vapor/postgres-nio/releases/tag/1.9.0 diff --git a/Sources/PostgresNIO/Docs.docc/prepared-statement.md b/Sources/PostgresNIO/Docs.docc/prepared-statement.md new file mode 100644 index 00000000..ff4b1c62 --- /dev/null +++ b/Sources/PostgresNIO/Docs.docc/prepared-statement.md @@ -0,0 +1,7 @@ +# Boosting Performance with Prepared Statements + +Improve performance by leveraging PostgreSQL's prepared statements. + +## Topics + +- ``PostgresPreparedStatement`` diff --git a/Sources/PostgresNIO/Docs.docc/running-queries.md b/Sources/PostgresNIO/Docs.docc/running-queries.md new file mode 100644 index 00000000..b2c4586f --- /dev/null +++ b/Sources/PostgresNIO/Docs.docc/running-queries.md @@ -0,0 +1,27 @@ +# Running Queries + +Interact with the PostgreSQL database by running Queries. + +## Overview + + + +You interact with the Postgres database by running SQL [Queries]. + + + +``PostgresQuery`` conforms to + + +## Topics + +- ``PostgresQuery`` +- ``PostgresBindings`` +- ``PostgresRow`` +- ``PostgresRowSequence`` +- ``PostgresRandomAccessRow`` +- ``PostgresCell`` +- ``PostgresQueryMetadata`` + +[Queries]: doc:PostgresQuery +[`ExpressibleByStringInterpolation`]: https://developer.apple.com/documentation/swift/expressiblebystringinterpolation diff --git a/Sources/PostgresNIO/Docs.docc/theme-settings.json b/Sources/PostgresNIO/Docs.docc/theme-settings.json new file mode 100644 index 00000000..38914a04 --- /dev/null +++ b/Sources/PostgresNIO/Docs.docc/theme-settings.json @@ -0,0 +1,24 @@ +{ + "theme": { + "aside": { "border-radius": "16px", "border-style": "double", "border-width": "3px" }, + "border-radius": "0", + "button": { "border-radius": "16px", "border-width": "1px", "border-style": "solid" }, + "code": { "border-radius": "16px", "border-width": "1px", "border-style": "solid" }, + "color": { + "fill": { "dark": "#000", "light": "#fff" }, + "psqlnio": "#336791", + "documentation-intro-fill": "radial-gradient(circle at top, var(--color-psqlnio) 30%, #000 100%)", + "documentation-intro-accent": "var(--color-psqlnio)", + "documentation-intro-eyebrow": "white", + "documentation-intro-figure": "white", + "documentation-intro-title": "white", + "logo-base": { "dark": "#fff", "light": "#000" }, + "logo-shape": { "dark": "#000", "light": "#fff" } + }, + "icons": { "technology": "/postgresnio/images/vapor-postgresnio-logo.svg" } + }, + "features": { + "quickNavigation": { "enable": true }, + "i18n": { "enable": true } + } +} diff --git a/Sources/PostgresNIO/Message/PostgresMessage+0.swift b/Sources/PostgresNIO/Message/PostgresMessage+0.swift index 96fe0b37..386ffd34 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+0.swift +++ b/Sources/PostgresNIO/Message/PostgresMessage+0.swift @@ -1,8 +1,12 @@ +import NIOCore + /// A frontend or backend Postgres message. public struct PostgresMessage: Equatable { - public var identifier: Identifier + @available(*, deprecated, message: "Will be removed from public API.") + public var identifier: Identifier public var data: ByteBuffer + @available(*, deprecated, message: "Will be removed from public API.") public init(identifier: Identifier, bytes: Data) where Data: Sequence, Data.Element == UInt8 { @@ -11,6 +15,7 @@ public struct PostgresMessage: Equatable { self.init(identifier: identifier, data: buffer) } + @available(*, deprecated, message: "Will be removed from public API.") public init(identifier: Identifier, data: ByteBuffer) { self.identifier = identifier self.data = data diff --git a/Sources/PostgresNIO/Message/PostgresMessage+Authentication.swift b/Sources/PostgresNIO/Message/PostgresMessage+Authentication.swift deleted file mode 100644 index 20a64ba7..00000000 --- a/Sources/PostgresNIO/Message/PostgresMessage+Authentication.swift +++ /dev/null @@ -1,61 +0,0 @@ -import NIO - -extension PostgresMessage { - /// Authentication request returned by the server. - public enum Authentication: PostgresMessageType { - public static var identifier: PostgresMessage.Identifier { - return .authentication - } - - /// Parses an instance of this message type from a byte buffer. - public static func parse(from buffer: inout ByteBuffer) throws -> Authentication { - guard let type = buffer.readInteger(as: Int32.self) else { - throw PostgresError.protocol("Could not read authentication message type") - } - switch type { - case 0: return .ok - case 3: return .plaintext - case 5: - guard let salt = buffer.readBytes(length: 4) else { - throw PostgresError.protocol("Could not parse MD5 salt from authentication message") - } - return .md5(salt) - default: - throw PostgresError.protocol("Unkonwn authentication request type: \(type)") - } - } - - public func serialize(into buffer: inout ByteBuffer) throws { - switch self { - case .ok: - buffer.writeInteger(0, as: Int32.self) - case .plaintext: - buffer.writeInteger(3, as: Int32.self) - case .md5(let salt): - buffer.writeInteger(5, as: Int32.self) - buffer.writeBytes(salt) - } - } - - /// AuthenticationOk - /// Specifies that the authentication was successful. - case ok - - /// AuthenticationCleartextPassword - /// Specifies that a clear-text password is required. - case plaintext - - /// AuthenticationMD5Password - /// Specifies that an MD5-encrypted password is required. - case md5([UInt8]) - - /// See `CustomStringConvertible`. - public var description: String { - switch self { - case .ok: return "Ok" - case .plaintext: return "CleartextPassword" - case .md5(let salt): return "MD5Password(salt: 0x\(salt.hexdigest()))" - } - } - } -} diff --git a/Sources/PostgresNIO/Message/PostgresMessage+BackendKeyData.swift b/Sources/PostgresNIO/Message/PostgresMessage+BackendKeyData.swift index f25994d8..63a6af7d 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+BackendKeyData.swift +++ b/Sources/PostgresNIO/Message/PostgresMessage+BackendKeyData.swift @@ -1,24 +1,9 @@ -import NIO +import NIOCore extension PostgresMessage { /// Identifies the message as cancellation key data. /// The frontend must save these values if it wishes to be able to issue CancelRequest messages later. - public struct BackendKeyData: PostgresMessageType { - public static var identifier: PostgresMessage.Identifier { - .backendKeyData - } - - /// Parses an instance of this message type from a byte buffer. - public static func parse(from buffer: inout ByteBuffer) throws -> BackendKeyData { - guard let processID = buffer.readInteger(as: Int32.self) else { - throw PostgresError.protocol("Could not parse process id from backend key data") - } - guard let secretKey = buffer.readInteger(as: Int32.self) else { - throw PostgresError.protocol("Could not parse secret key from backend key data") - } - return .init(processID: processID, secretKey: secretKey) - } - + public struct BackendKeyData { /// The process ID of this backend. public var processID: Int32 @@ -26,3 +11,21 @@ extension PostgresMessage { public var secretKey: Int32 } } + +@available(*, deprecated, message: "Deprecating conformance to `PostgresMessageType` since it is deprecated.") +extension PostgresMessage.BackendKeyData: PostgresMessageType { + public static var identifier: PostgresMessage.Identifier { + .backendKeyData + } + + /// Parses an instance of this message type from a byte buffer. + public static func parse(from buffer: inout ByteBuffer) throws -> Self { + guard let processID = buffer.readInteger(as: Int32.self) else { + throw PostgresError.protocol("Could not parse process id from backend key data") + } + guard let secretKey = buffer.readInteger(as: Int32.self) else { + throw PostgresError.protocol("Could not parse secret key from backend key data") + } + return .init(processID: processID, secretKey: secretKey) + } +} diff --git a/Sources/PostgresNIO/Message/PostgresMessage+DataRow.swift b/Sources/PostgresNIO/Message/PostgresMessage+DataRow.swift index 2ead1b3a..655bfb1e 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+DataRow.swift +++ b/Sources/PostgresNIO/Message/PostgresMessage+DataRow.swift @@ -1,12 +1,8 @@ -import NIO +import NIOCore extension PostgresMessage { /// Identifies the message as a data row. - public struct DataRow: PostgresMessageType { - public static var identifier: PostgresMessage.Identifier { - return .dataRow - } - + public struct DataRow { public struct Column: CustomStringConvertible { /// The length of the column value, in bytes (this count does not include itself). /// Can be zero. As a special case, -1 indicates a NULL column value. No value bytes follow in the NULL case. @@ -23,23 +19,7 @@ extension PostgresMessage { } } } - - /// Parses an instance of this message type from a byte buffer. - public static func parse(from buffer: inout ByteBuffer) throws -> DataRow { - guard let columns = buffer.read(array: Column.self, { buffer in - if var slice = buffer.readNullableBytes() { - var copy = ByteBufferAllocator().buffer(capacity: slice.readableBytes) - copy.writeBuffer(&slice) - return .init(value: copy) - } else { - return .init(value: nil) - } - }) else { - throw PostgresError.protocol("Could not parse data row columns") - } - return .init(columns: columns) - } - + /// The data row's columns public var columns: [Column] @@ -49,3 +29,26 @@ extension PostgresMessage { } } } + +@available(*, deprecated, message: "Deprecating conformance to `PostgresMessageType` since it is deprecated.") +extension PostgresMessage.DataRow: PostgresMessageType { + public static var identifier: PostgresMessage.Identifier { + return .dataRow + } + + /// Parses an instance of this message type from a byte buffer. + public static func parse(from buffer: inout ByteBuffer) throws -> Self { + guard let columns = buffer.read(array: Column.self, { buffer in + if var slice = buffer.readNullableBytes() { + var copy = ByteBufferAllocator().buffer(capacity: slice.readableBytes) + copy.writeBuffer(&slice) + return .init(value: copy) + } else { + return .init(value: nil) + } + }) else { + throw PostgresError.protocol("Could not parse data row columns") + } + return .init(columns: columns) + } +} diff --git a/Sources/PostgresNIO/Message/PostgresMessage+Error.swift b/Sources/PostgresNIO/Message/PostgresMessage+Error.swift index 9b0a18cd..45cda21f 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+Error.swift +++ b/Sources/PostgresNIO/Message/PostgresMessage+Error.swift @@ -1,25 +1,9 @@ -import NIO +import NIOCore extension PostgresMessage { /// First message sent from the frontend during startup. - public struct Error: PostgresMessageType, CustomStringConvertible { - public static var identifier: PostgresMessage.Identifier { - return .error - } - - /// Parses an instance of this message type from a byte buffer. - public static func parse(from buffer: inout ByteBuffer) throws -> Error { - var fields: [Field: String] = [:] - while let field = buffer.readInteger(as: Field.self) { - guard let string = buffer.readNullTerminatedString() else { - throw PostgresError.protocol("Could not read error response string.") - } - fields[field] = string - } - return .init(fields: fields) - } - - public enum Field: UInt8, Hashable { + public struct Error: CustomStringConvertible, Sendable { + public enum Field: UInt8, Hashable, Sendable { /// Severity: the field contents are ERROR, FATAL, or PANIC (in an error message), /// or WARNING, NOTICE, DEBUG, INFO, or LOG (in a notice message), or a //// localized translation of one of these. Always present. @@ -108,3 +92,22 @@ extension PostgresMessage { } } } + +@available(*, deprecated, message: "Deprecating conformance to `PostgresMessageType` since it is deprecated.") +extension PostgresMessage.Error: PostgresMessageType { + public static var identifier: PostgresMessage.Identifier { + return .error + } + + /// Parses an instance of this message type from a byte buffer. + public static func parse(from buffer: inout ByteBuffer) throws -> Self { + var fields: [Field: String] = [:] + while let field = buffer.readInteger(as: Field.self) { + guard let string = buffer.readNullTerminatedString() else { + throw PostgresError.protocol("Could not read error response string.") + } + fields[field] = string + } + return .init(fields: fields) + } +} diff --git a/Sources/PostgresNIO/Message/PostgresMessage+Identifier.swift b/Sources/PostgresNIO/Message/PostgresMessage+Identifier.swift index 92b8d2e4..5d111e3b 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+Identifier.swift +++ b/Sources/PostgresNIO/Message/PostgresMessage+Identifier.swift @@ -1,9 +1,10 @@ -import NIO +import NIOCore extension PostgresMessage { /// Identifies an incoming or outgoing postgres message. Sent as the first byte, before the message size. /// Values are not unique across all identifiers, meaning some messages will require keeping state to identify. - public struct Identifier: ExpressibleByIntegerLiteral, Equatable, CustomStringConvertible { + @available(*, deprecated, message: "Will be removed from public API.") + public struct Identifier: Sendable, ExpressibleByIntegerLiteral, Equatable, CustomStringConvertible { // special public static let none: Identifier = 0x00 // special @@ -132,7 +133,7 @@ extension PostgresMessage { /// See `CustomStringConvertible`. public var description: String { - return String(Character(Unicode.Scalar(value))) + return String(Unicode.Scalar(self.value)) } /// See `ExpressibleByIntegerLiteral`. @@ -143,6 +144,7 @@ extension PostgresMessage { } extension ByteBuffer { + @available(*, deprecated, message: "Will be removed from public API") mutating func write(identifier: PostgresMessage.Identifier) { self.writeInteger(identifier.value) } diff --git a/Sources/PostgresNIO/Message/PostgresMessage+NotificationResponse.swift b/Sources/PostgresNIO/Message/PostgresMessage+NotificationResponse.swift index b381bfaf..1a3b596d 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+NotificationResponse.swift +++ b/Sources/PostgresNIO/Message/PostgresMessage+NotificationResponse.swift @@ -1,26 +1,29 @@ -import NIO +import NIOCore extension PostgresMessage { /// Identifies the message as a notification response. - public struct NotificationResponse: PostgresMessageType { - public static let identifier = Identifier.notificationResponse - - /// Parses an instance of this message type from a byte buffer. - public static func parse(from buffer: inout ByteBuffer) throws -> Self { - guard let backendPID: Int32 = buffer.readInteger() else { - throw PostgresError.protocol("Invalid NotificationResponse message: unable to read backend PID") - } - guard let channel = buffer.readNullTerminatedString() else { - throw PostgresError.protocol("Invalid NotificationResponse message: unable to read channel") - } - guard let payload = buffer.readNullTerminatedString() else { - throw PostgresError.protocol("Invalid NotificationResponse message: unable to read payload") - } - return .init(backendPID: backendPID, channel: channel, payload: payload) - } - + public struct NotificationResponse { public var backendPID: Int32 public var channel: String public var payload: String } } + +@available(*, deprecated, message: "Deprecating conformance to `PostgresMessageType` since it is deprecated.") +extension PostgresMessage.NotificationResponse: PostgresMessageType { + public static let identifier = PostgresMessage.Identifier.notificationResponse + + /// Parses an instance of this message type from a byte buffer. + public static func parse(from buffer: inout ByteBuffer) throws -> Self { + guard let backendPID: Int32 = buffer.readInteger() else { + throw PostgresError.protocol("Invalid NotificationResponse message: unable to read backend PID") + } + guard let channel = buffer.readNullTerminatedString() else { + throw PostgresError.protocol("Invalid NotificationResponse message: unable to read channel") + } + guard let payload = buffer.readNullTerminatedString() else { + throw PostgresError.protocol("Invalid NotificationResponse message: unable to read payload") + } + return .init(backendPID: backendPID, channel: channel, payload: payload) + } +} diff --git a/Sources/PostgresNIO/Message/PostgresMessage+RowDescription.swift b/Sources/PostgresNIO/Message/PostgresMessage+RowDescription.swift index 61aec62b..5713cc99 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+RowDescription.swift +++ b/Sources/PostgresNIO/Message/PostgresMessage+RowDescription.swift @@ -1,13 +1,8 @@ -import NIO +import NIOCore extension PostgresMessage { /// Identifies the message as a row description. - public struct RowDescription: PostgresMessageType { - /// See `PostgresMessageType`. - public static var identifier: PostgresMessage.Identifier { - return .rowDescription - } - + public struct RowDescription { /// Describes a single field returns in a `RowDescription` message. public struct Field: CustomStringConvertible { static func parse(from buffer: inout ByteBuffer) throws -> Field { @@ -29,7 +24,7 @@ extension PostgresMessage { guard let dataTypeModifier = buffer.readInteger(as: Int32.self) else { throw PostgresError.protocol("Could not read row description field data type modifier") } - guard let formatCode = buffer.readInteger(as: PostgresFormatCode.self) else { + guard let formatCode = buffer.readInteger(as: PostgresFormat.self) else { throw PostgresError.protocol("Could not read row description field format code") } return .init( @@ -65,7 +60,7 @@ extension PostgresMessage { /// Currently will be zero (text) or one (binary). /// In a RowDescription returned from the statement variant of Describe, /// the format code is not yet known and will always be zero. - public var formatCode: PostgresFormatCode + public var formatCode: PostgresFormat /// See `CustomStringConvertible`. public var description: String { @@ -73,15 +68,7 @@ extension PostgresMessage { } } - /// Parses an instance of this message type from a byte buffer. - public static func parse(from buffer: inout ByteBuffer) throws -> RowDescription { - guard let fields = try buffer.read(array: Field.self, { buffer in - return try.parse(from: &buffer) - }) else { - throw PostgresError.protocol("Could not read row description fields") - } - return .init(fields: fields) - } + /// The fields supplied in the row description. public var fields: [Field] @@ -92,3 +79,21 @@ extension PostgresMessage { } } } + +@available(*, deprecated, message: "Deprecating conformance to `PostgresMessageType` since it is deprecated.") +extension PostgresMessage.RowDescription: PostgresMessageType { + /// See `PostgresMessageType`. + public static var identifier: PostgresMessage.Identifier { + return .rowDescription + } + + /// Parses an instance of this message type from a byte buffer. + public static func parse(from buffer: inout ByteBuffer) throws -> Self { + guard let fields = try buffer.read(array: Field.self, { buffer in + return try.parse(from: &buffer) + }) else { + throw PostgresError.protocol("Could not read row description fields") + } + return .init(fields: fields) + } +} diff --git a/Sources/PostgresNIO/Message/PostgresMessageType.swift b/Sources/PostgresNIO/Message/PostgresMessageType.swift index 9a69fa30..170c4aec 100644 --- a/Sources/PostgresNIO/Message/PostgresMessageType.swift +++ b/Sources/PostgresNIO/Message/PostgresMessageType.swift @@ -1,10 +1,15 @@ +import NIOCore + +@available(*, deprecated, message: "Will be removed from public API. Internally we now use `PostgresBackendMessage` and `PostgresFrontendMessage`") public protocol PostgresMessageType { static var identifier: PostgresMessage.Identifier { get } static func parse(from buffer: inout ByteBuffer) throws -> Self func serialize(into buffer: inout ByteBuffer) throws } +@available(*, deprecated, message: "`PostgresMessageType` protocol is deprecated.") extension PostgresMessageType { + @available(*, deprecated, message: "Will be removed from public API.") func message() throws -> PostgresMessage { var buffer = ByteBufferAllocator().buffer(capacity: 0) try self.serialize(into: &buffer) @@ -15,7 +20,8 @@ extension PostgresMessageType { var message = message self = try Self.parse(from: &message.data) } - + + @available(*, deprecated, message: "Will be removed from public API.") public static var identifier: PostgresMessage.Identifier { return .none } diff --git a/Sources/PostgresNIO/New/Connection State Machine/AuthenticationStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/AuthenticationStateMachine.swift new file mode 100644 index 00000000..245e5efd --- /dev/null +++ b/Sources/PostgresNIO/New/Connection State Machine/AuthenticationStateMachine.swift @@ -0,0 +1,226 @@ +import NIOCore + +struct AuthenticationStateMachine { + + enum State { + case initialized + case startupMessageSent + case passwordAuthenticationSent + + case saslInitialResponseSent(SASLAuthenticationManager) + case saslChallengeResponseSent(SASLAuthenticationManager) + case saslFinalReceived + + case error(PSQLError) + case authenticated + } + + enum Action { + case sendStartupMessage(AuthContext) + case sendPassword(PasswordAuthencationMode, AuthContext) + case sendSaslInitialResponse(name: String, initialResponse: [UInt8]) + case sendSaslResponse([UInt8]) + case wait + case authenticated + + case reportAuthenticationError(PSQLError) + } + + let authContext: AuthContext + var state: State + + init(authContext: AuthContext) { + self.authContext = authContext + self.state = .initialized + } + + mutating func start() -> Action { + guard case .initialized = self.state else { + preconditionFailure("Unexpected state") + } + self.state = .startupMessageSent + return .sendStartupMessage(self.authContext) + } + + mutating func authenticationMessageReceived(_ message: PostgresBackendMessage.Authentication) -> Action { + switch self.state { + case .startupMessageSent: + switch message { + case .ok: + self.state = .authenticated + return .authenticated + case .md5(let salt): + guard self.authContext.password != nil else { + return self.setAndFireError(PSQLError(code: .authMechanismRequiresPassword)) + } + self.state = .passwordAuthenticationSent + return .sendPassword(.md5(salt: salt), self.authContext) + case .plaintext: + self.state = .passwordAuthenticationSent + return .sendPassword(.cleartext, authContext) + case .kerberosV5: + return self.setAndFireError(.unsupportedAuthMechanism(.kerberosV5)) + case .scmCredential: + return self.setAndFireError(.unsupportedAuthMechanism(.scmCredential)) + case .gss: + return self.setAndFireError(.unsupportedAuthMechanism(.gss)) + case .sspi: + return self.setAndFireError(.unsupportedAuthMechanism(.sspi)) + case .sasl(let mechanisms): + guard mechanisms.contains(SASLMechanism.SCRAM.SHA256.name) else { + return self.setAndFireError(.unsupportedAuthMechanism(.sasl(mechanisms: mechanisms))) + } + + guard let password = self.authContext.password else { + return self.setAndFireError(.authMechanismRequiresPassword) + } + + let saslManager = SASLAuthenticationManager(asClientSpeaking: + SASLMechanism.SCRAM.SHA256(username: self.authContext.username, password: { password })) + + do { + var bytes: [UInt8]? + let done = try saslManager.handle(message: nil, sender: { bytes = $0 }) + // TODO: Gwynne reminds herself to refactor `SASLAuthenticationManager` to + // be async instead of very badly done synchronous. + + guard let output = bytes, done == false else { + preconditionFailure("TODO: SASL auth is always a three step process in Postgres.") + } + + self.state = .saslInitialResponseSent(saslManager) + return .sendSaslInitialResponse(name: SASLMechanism.SCRAM.SHA256.name, initialResponse: output) + } catch { + return self.setAndFireError(.sasl(underlying: error)) + } + case .gssContinue, + .saslContinue, + .saslFinal: + return self.setAndFireError(.unexpectedBackendMessage(.authentication(message))) + } + case .passwordAuthenticationSent, .saslFinalReceived: + guard case .ok = message else { + return self.setAndFireError(.unexpectedBackendMessage(.authentication(message))) + } + + self.state = .authenticated + return .authenticated + + case .saslInitialResponseSent(let saslManager): + guard case .saslContinue(data: var data) = message else { + return self.setAndFireError(.unexpectedBackendMessage(.authentication(message))) + } + + let input = data.readBytes(length: data.readableBytes) + + do { + var bytes: [UInt8]? + let done = try saslManager.handle(message: input, sender: { bytes = $0 }) + + guard let output = bytes, done == false else { + preconditionFailure("TODO: SASL auth is always a three step process in Postgres.") + } + + self.state = .saslChallengeResponseSent(saslManager) + return .sendSaslResponse(output) + } catch { + return self.setAndFireError(.sasl(underlying: error)) + } + + case .saslChallengeResponseSent(let saslManager): + guard case .saslFinal(data: var data) = message else { + return self.setAndFireError(.unexpectedBackendMessage(.authentication(message))) + } + + let input = data.readBytes(length: data.readableBytes) + + do { + var bytes: [UInt8]? + let done = try saslManager.handle(message: input, sender: { bytes = $0 }) + + guard bytes == nil, done == true else { + preconditionFailure("TODO: SASL auth is always a three step process in Postgres.") + } + + self.state = .saslFinalReceived + return .wait + } catch { + return self.setAndFireError(.sasl(underlying: error)) + } + + case .initialized: + preconditionFailure("Invalid state") + + case .authenticated, .error: + preconditionFailure("This state machine must not receive messages, after authenticate or error") + } + } + + mutating func errorReceived(_ message: PostgresBackendMessage.ErrorResponse) -> Action { + return self.setAndFireError(.server(message)) + } + + mutating func errorHappened(_ error: PSQLError) -> Action { + return self.setAndFireError(error) + } + + private mutating func setAndFireError(_ error: PSQLError) -> Action { + switch self.state { + case .initialized: + preconditionFailure(""" + The `AuthenticationStateMachine` must be immidiatly started after creation. + """) + case .startupMessageSent, + .passwordAuthenticationSent, + .saslInitialResponseSent, + .saslChallengeResponseSent, + .saslFinalReceived: + self.state = .error(error) + return .reportAuthenticationError(error) + case .authenticated, .error: + preconditionFailure(""" + This state must not be reached. If the auth state `.isComplete`, the + ConnectionStateMachine must not send any further events to the substate machine. + """) + } + } + + var isComplete: Bool { + switch self.state { + case .authenticated, .error: + return true + case .initialized, + .startupMessageSent, + .passwordAuthenticationSent, + .saslInitialResponseSent, + .saslChallengeResponseSent, + .saslFinalReceived: + return false + } + } +} + +extension AuthenticationStateMachine.State: CustomDebugStringConvertible { + var debugDescription: String { + switch self { + case .initialized: + return ".initialized" + case .startupMessageSent: + return ".startupMessageSent" + case .passwordAuthenticationSent: + return ".passwordAuthenticationSent" + + case .saslInitialResponseSent(let saslManager): + return ".saslInitialResponseSent(\(String(reflecting: saslManager)))" + case .saslChallengeResponseSent(let saslManager): + return ".saslChallengeResponseSent(\(String(reflecting: saslManager)))" + case .saslFinalReceived: + return ".saslFinalReceived" + + case .error(let error): + return ".error(\(String(reflecting: error)))" + case .authenticated: + return ".authenticated" + } + } +} diff --git a/Sources/PostgresNIO/New/Connection State Machine/CloseStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/CloseStateMachine.swift new file mode 100644 index 00000000..791cebdd --- /dev/null +++ b/Sources/PostgresNIO/New/Connection State Machine/CloseStateMachine.swift @@ -0,0 +1,99 @@ + +struct CloseStateMachine { + + enum State { + case initialized(CloseCommandContext) + case closeSyncSent(CloseCommandContext) + case closeCompleteReceived + + case error(PSQLError) + } + + enum Action { + case sendCloseSync(CloseTarget) + case succeedClose(CloseCommandContext) + case failClose(CloseCommandContext, with: PSQLError) + + case read + case wait + } + + var state: State + + init(closeContext: CloseCommandContext) { + self.state = .initialized(closeContext) + } + + mutating func start() -> Action { + guard case .initialized(let closeContext) = self.state else { + preconditionFailure("Start should only be called, if the query has been initialized") + } + + self.state = .closeSyncSent(closeContext) + + return .sendCloseSync(closeContext.target) + } + + mutating func closeCompletedReceived() -> Action { + guard case .closeSyncSent(let closeContext) = self.state else { + return self.setAndFireError(.unexpectedBackendMessage(.closeComplete)) + } + + self.state = .closeCompleteReceived + return .succeedClose(closeContext) + } + + mutating func errorReceived(_ errorMessage: PostgresBackendMessage.ErrorResponse) -> Action { + let error = PSQLError.server(errorMessage) + switch self.state { + case .initialized: + return self.setAndFireError(.unexpectedBackendMessage(.error(errorMessage))) + + case .closeSyncSent: + return self.setAndFireError(error) + + case .closeCompleteReceived: + return self.setAndFireError(.unexpectedBackendMessage(.error(errorMessage))) + + case .error: + preconditionFailure(""" + This state must not be reached. If the query `.isComplete`, the + ConnectionStateMachine must not send any further events to the substate machine. + """) + } + } + + mutating func errorHappened(_ error: PSQLError) -> Action { + return self.setAndFireError(error) + } + + // MARK: Channel actions + + mutating func readEventCaught() -> Action { + return .read + } + + var isComplete: Bool { + switch self.state { + case .closeCompleteReceived, .error: + return true + case .initialized, .closeSyncSent: + return false + } + } + + // MARK: Private Methods + + private mutating func setAndFireError(_ error: PSQLError) -> Action { + switch self.state { + case .closeSyncSent(let closeContext): + self.state = .error(error) + return .failClose(closeContext, with: error) + case .initialized, .closeCompleteReceived, .error: + preconditionFailure(""" + This state must not be reached. If the query `.isComplete`, the + ConnectionStateMachine must not send any further events to the substate machine. + """) + } + } +} diff --git a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift new file mode 100644 index 00000000..9d264bcc --- /dev/null +++ b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift @@ -0,0 +1,1213 @@ +import NIOCore + +struct ConnectionStateMachine { + + typealias TransactionState = PostgresBackendMessage.TransactionState + + struct ConnectionContext { + let backendKeyData: Optional + var parameters: [String: String] + var transactionState: TransactionState + } + + struct BackendKeyData { + let processID: Int32 + let secretKey: Int32 + } + + enum State { + enum TLSConfiguration { + case prefer + case require + } + + case initialized + case sslRequestSent(TLSConfiguration) + case sslNegotiated + case sslHandlerAdded + case waitingToStartAuthentication + case authenticating(AuthenticationStateMachine) + case authenticated(BackendKeyData?, [String: String]) + + case readyForQuery(ConnectionContext) + case extendedQuery(ExtendedQueryStateMachine, ConnectionContext) + case closeCommand(CloseStateMachine, ConnectionContext) + + case closing(PSQLError?) + case closed(clientInitiated: Bool, error: PSQLError?) + + case modifying + } + + enum QuiescingState { + case notQuiescing + case quiescing(closePromise: EventLoopPromise?) + } + + enum ConnectionAction { + + struct CleanUpContext { + enum Action { + case close + case fireChannelInactive + } + + let action: Action + + /// Tasks to fail with the error + let tasks: [PSQLTask] + + let error: PSQLError + + let closePromise: EventLoopPromise? + } + + case read + case wait + case sendSSLRequest + case establishSSLConnection + case provideAuthenticationContext + case forwardNotificationToListeners(PostgresBackendMessage.NotificationResponse) + case fireEventReadyForQuery + case fireChannelInactive + /// Close the connection by sending a `Terminate` message and then closing the connection. This is for clean shutdowns. + case closeConnection(EventLoopPromise?) + + /// Close connection because of an error state. Fail all tasks with the provided error. + case closeConnectionAndCleanup(CleanUpContext) + + // Auth Actions + case sendStartupMessage(AuthContext) + case sendPasswordMessage(PasswordAuthencationMode, AuthContext) + case sendSaslInitialResponse(name: String, initialResponse: [UInt8]) + case sendSaslResponse([UInt8]) + + // Connection Actions + + // --- general actions + case sendParseDescribeBindExecuteSync(PostgresQuery) + case sendBindExecuteSync(PSQLExecuteStatement) + case failQuery(EventLoopPromise, with: PSQLError, cleanupContext: CleanUpContext?) + case succeedQuery(EventLoopPromise, with: QueryResult) + + // --- streaming actions + // actions if query has requested next row but we are waiting for backend + case forwardRows([DataRow]) + case forwardStreamComplete([DataRow], commandTag: String) + case forwardStreamError(PSQLError, read: Bool, cleanupContext: CleanUpContext?) + + // Prepare statement actions + case sendParseDescribeSync(name: String, query: String, bindingDataTypes: [PostgresDataType]) + case succeedPreparedStatementCreation(EventLoopPromise, with: RowDescription?) + case failPreparedStatementCreation(EventLoopPromise, with: PSQLError, cleanupContext: CleanUpContext?) + + // Close actions + case sendCloseSync(CloseTarget) + case succeedClose(CloseCommandContext) + case failClose(CloseCommandContext, with: PSQLError, cleanupContext: CleanUpContext?) + } + + private var state: State + private let requireBackendKeyData: Bool + private var taskQueue = CircularBuffer() + private var quiescingState: QuiescingState = .notQuiescing + + init(requireBackendKeyData: Bool) { + self.state = .initialized + self.requireBackendKeyData = requireBackendKeyData + } + + #if DEBUG + /// for testing purposes only + init(_ state: State, requireBackendKeyData: Bool = true) { + self.state = state + self.requireBackendKeyData = requireBackendKeyData + } + #endif + + enum TLSConfiguration { + case disable + case prefer + case require + } + + mutating func connected(tls: TLSConfiguration) -> ConnectionAction { + switch self.state { + case .initialized: + switch tls { + case .disable: + self.state = .waitingToStartAuthentication + return .provideAuthenticationContext + + case .prefer: + self.state = .sslRequestSent(.prefer) + return .sendSSLRequest + + case .require: + self.state = .sslRequestSent(.require) + return .sendSSLRequest + } + + case .sslRequestSent, + .sslNegotiated, + .sslHandlerAdded, + .waitingToStartAuthentication, + .authenticating, + .authenticated, + .readyForQuery, + .extendedQuery, + .closeCommand, + .closing, + .closed, + .modifying: + return .wait + } + } + + mutating func provideAuthenticationContext(_ authContext: AuthContext) -> ConnectionAction { + self.startAuthentication(authContext) + } + + mutating func gracefulClose(_ promise: EventLoopPromise?) -> ConnectionAction { + switch self.state { + case .closing, .closed: + // we are already closed, but sometimes an upstream handler might want to close the + // connection, though it has already been closed by the remote. Typical race condition. + return .closeConnection(promise) + case .readyForQuery: + precondition(self.taskQueue.isEmpty, """ + The state should only be .readyForQuery if there are no more tasks in the queue + """) + self.state = .closing(nil) + return .closeConnection(promise) + default: + switch self.quiescingState { + case .notQuiescing: + self.quiescingState = .quiescing(closePromise: promise) + case .quiescing(.some(let closePromise)): + closePromise.futureResult.cascade(to: promise) + case .quiescing(.none): + self.quiescingState = .quiescing(closePromise: promise) + } + return .wait + } + } + + mutating func close(promise: EventLoopPromise?) -> ConnectionAction { + return self.closeConnectionAndCleanup(.clientClosedConnection(underlying: nil), closePromise: promise) + } + + mutating func closed() -> ConnectionAction { + switch self.state { + case .initialized: + preconditionFailure("How can a connection be closed, if it was never connected.") + + case .closed: + return .wait + + case .authenticated, + .sslRequestSent, + .sslNegotiated, + .sslHandlerAdded, + .waitingToStartAuthentication, + .authenticating, + .readyForQuery, + .extendedQuery, + .closeCommand: + return self.errorHappened(.serverClosedConnection(underlying: nil)) + + case .closing(let error): + self.state = .closed(clientInitiated: true, error: error) + self.quiescingState = .notQuiescing + return .fireChannelInactive + + case .modifying: + preconditionFailure("Invalid state") + } + } + + mutating func sslSupportedReceived(unprocessedBytes: Int) -> ConnectionAction { + switch self.state { + case .sslRequestSent: + if unprocessedBytes > 0 { + return self.closeConnectionAndCleanup(.receivedUnencryptedDataAfterSSLRequest) + } + self.state = .sslNegotiated + return .establishSSLConnection + + case .initialized, + .sslNegotiated, + .sslHandlerAdded, + .waitingToStartAuthentication, + .authenticating, + .authenticated, + .readyForQuery, + .extendedQuery, + .closeCommand, + .closing, + .closed: + return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.sslSupported)) + + case .modifying: + preconditionFailure("Invalid state: \(self.state)") + } + } + + mutating func sslUnsupportedReceived() -> ConnectionAction { + switch self.state { + case .sslRequestSent(.require): + return self.closeConnectionAndCleanup(.sslUnsupported) + + case .sslRequestSent(.prefer): + self.state = .waitingToStartAuthentication + return .provideAuthenticationContext + + case .initialized, + .sslNegotiated, + .sslHandlerAdded, + .waitingToStartAuthentication, + .authenticating, + .authenticated, + .readyForQuery, + .extendedQuery, + .closeCommand, + .closing, + .closed: + return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.sslSupported)) + + case .modifying: + preconditionFailure("Invalid state: \(self.state)") + } + } + + mutating func sslHandlerAdded() -> ConnectionAction { + switch self.state { + case .initialized, + .sslRequestSent, + .sslHandlerAdded, + .waitingToStartAuthentication, + .authenticating, + .authenticated, + .readyForQuery, + .extendedQuery, + .closeCommand, + .closing, + .closed: + preconditionFailure("Can only add a ssl handler after negotiation: \(self.state)") + + case .sslNegotiated: + self.state = .sslHandlerAdded + return .wait + + case .modifying: + preconditionFailure("Invalid state: \(self.state)") + } + } + + mutating func sslEstablished() -> ConnectionAction { + switch self.state { + case .initialized, + .sslRequestSent, + .sslNegotiated, + .waitingToStartAuthentication, + .authenticating, + .authenticated, + .readyForQuery, + .extendedQuery, + .closeCommand, + .closing, + .closed: + preconditionFailure("Can only establish a ssl connection after adding a ssl handler: \(self.state)") + + case .sslHandlerAdded: + self.state = .waitingToStartAuthentication + return .provideAuthenticationContext + + case .modifying: + preconditionFailure("Invalid state: \(self.state)") + } + } + + mutating func authenticationMessageReceived(_ message: PostgresBackendMessage.Authentication) -> ConnectionAction { + guard case .authenticating(var authState) = self.state else { + return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.authentication(message))) + } + + self.state = .modifying // avoid CoW + let action = authState.authenticationMessageReceived(message) + self.state = .authenticating(authState) + return self.modify(with: action) + } + + mutating func backendKeyDataReceived(_ keyData: PostgresBackendMessage.BackendKeyData) -> ConnectionAction { + guard case .authenticated(_, let parameters) = self.state else { + return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.backendKeyData(keyData))) + } + + let keyData = BackendKeyData( + processID: keyData.processID, + secretKey: keyData.secretKey) + + self.state = .authenticated(keyData, parameters) + return .wait + } + + mutating func parameterStatusReceived(_ status: PostgresBackendMessage.ParameterStatus) -> ConnectionAction { + switch self.state { + case .sslRequestSent, + .sslNegotiated, + .sslHandlerAdded, + .waitingToStartAuthentication, + .authenticating, + .closing: + return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.parameterStatus(status))) + case .authenticated(let keyData, var parameters): + self.state = .modifying // avoid CoW + parameters[status.parameter] = status.value + self.state = .authenticated(keyData, parameters) + return .wait + + case .readyForQuery(var connectionContext): + self.state = .modifying // avoid CoW + connectionContext.parameters[status.parameter] = status.value + self.state = .readyForQuery(connectionContext) + return .wait + + case .extendedQuery(let query, var connectionContext): + self.state = .modifying // avoid CoW + connectionContext.parameters[status.parameter] = status.value + self.state = .extendedQuery(query, connectionContext) + return .wait + + case .closeCommand(let closeState, var connectionContext): + self.state = .modifying // avoid CoW + connectionContext.parameters[status.parameter] = status.value + self.state = .closeCommand(closeState, connectionContext) + return .wait + + case .initialized, + .closed: + preconditionFailure("We shouldn't receive messages if we are not connected") + case .modifying: + preconditionFailure("Invalid state") + } + } + + mutating func errorReceived(_ errorMessage: PostgresBackendMessage.ErrorResponse) -> ConnectionAction { + switch self.state { + case .sslRequestSent, + .sslNegotiated, + .sslHandlerAdded, + .waitingToStartAuthentication, + .authenticated, + .readyForQuery: + return self.closeConnectionAndCleanup(.server(errorMessage)) + case .authenticating(var authState): + if authState.isComplete { + return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.error(errorMessage))) + } + self.state = .modifying // avoid CoW + let action = authState.errorReceived(errorMessage) + self.state = .authenticating(authState) + return self.modify(with: action) + + case .closeCommand(var closeStateMachine, let connectionContext): + if closeStateMachine.isComplete { + return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.error(errorMessage))) + } + self.state = .modifying // avoid CoW + let action = closeStateMachine.errorReceived(errorMessage) + self.state = .closeCommand(closeStateMachine, connectionContext) + return self.modify(with: action) + + case .extendedQuery(var extendedQueryState, let connectionContext): + if extendedQueryState.isComplete { + return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.error(errorMessage))) + } + self.state = .modifying // avoid CoW + let action = extendedQueryState.errorReceived(errorMessage) + self.state = .extendedQuery(extendedQueryState, connectionContext) + return self.modify(with: action) + + case .closing: + // If the state machine is in state `.closing`, the connection shutdown was initiated + // by the client. This means a `TERMINATE` message has already been sent and the + // connection close was passed on to the channel. Therefore we await a channelInactive + // as the next event. + // Since a connection close was already issued, we should keep cool and just wait. + return .wait + case .initialized, .closed: + preconditionFailure("We should not receive server errors if we are not connected") + case .modifying: + preconditionFailure("Invalid state") + } + } + + mutating func errorHappened(_ error: PSQLError) -> ConnectionAction { + switch self.state { + case .initialized, + .sslRequestSent, + .sslNegotiated, + .sslHandlerAdded, + .waitingToStartAuthentication, + .authenticated, + .readyForQuery: + return self.closeConnectionAndCleanup(error) + case .authenticating(var authState): + let action = authState.errorHappened(error) + return self.modify(with: action) + case .extendedQuery(var queryState, _): + if queryState.isComplete { + return self.closeConnectionAndCleanup(error) + } else { + let action = queryState.errorHappened(error) + return self.modify(with: action) + } + case .closeCommand(var closeState, _): + if closeState.isComplete { + return self.closeConnectionAndCleanup(error) + } else { + let action = closeState.errorHappened(error) + return self.modify(with: action) + } + case .closing: + // If the state machine is in state `.closing`, the connection shutdown was initiated + // by the client. This means a `TERMINATE` message has already been sent and the + // connection close was passed on to the channel. Therefore we await a channelInactive + // as the next event. + // For some reason Azure Postgres does not end ssl cleanly when terminating the + // connection. More documentation can be found in the issue: + // https://github.com/vapor/postgres-nio/issues/150 + // Since a connection close was already issued, we should keep cool and just wait. + return .wait + case .closed: + return self.closeConnectionAndCleanup(error) + + case .modifying: + preconditionFailure("Invalid state") + } + } + + mutating func noticeReceived(_ notice: PostgresBackendMessage.NoticeResponse) -> ConnectionAction { + switch self.state { + case .extendedQuery(var extendedQuery, let connectionContext): + self.state = .modifying // avoid CoW + let action = extendedQuery.noticeReceived(notice) + self.state = .extendedQuery(extendedQuery, connectionContext) + return self.modify(with: action) + + default: + return .wait + } + } + + mutating func notificationReceived(_ notification: PostgresBackendMessage.NotificationResponse) -> ConnectionAction { + return .forwardNotificationToListeners(notification) + } + + mutating func readyForQueryReceived(_ transactionState: PostgresBackendMessage.TransactionState) -> ConnectionAction { + switch self.state { + case .authenticated(let backendKeyData, let parameters): + if self.requireBackendKeyData && backendKeyData == nil { + return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.readyForQuery(transactionState))) + } + + let connectionContext = ConnectionContext( + backendKeyData: backendKeyData, + parameters: parameters, + transactionState: transactionState) + + self.state = .readyForQuery(connectionContext) + return self.executeNextQueryFromQueue() + case .extendedQuery(let extendedQuery, var connectionContext): + guard extendedQuery.isComplete else { + return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.readyForQuery(transactionState))) + } + + connectionContext.transactionState = transactionState + + self.state = .readyForQuery(connectionContext) + return self.executeNextQueryFromQueue() + case .closeCommand(let closeStateMachine, var connectionContext): + guard closeStateMachine.isComplete else { + return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.readyForQuery(transactionState))) + } + + connectionContext.transactionState = transactionState + + self.state = .readyForQuery(connectionContext) + return self.executeNextQueryFromQueue() + + default: + return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.readyForQuery(transactionState))) + } + } + + mutating func enqueue(task: PSQLTask) -> ConnectionAction { + let psqlErrror: PSQLError + + // check if we are quiescing. if so fail task immidiatly + switch self.quiescingState { + case .quiescing: + psqlErrror = PSQLError.clientClosedConnection(underlying: nil) + + case .notQuiescing: + switch self.state { + case .initialized, + .authenticated, + .authenticating, + .closeCommand, + .extendedQuery, + .sslNegotiated, + .sslHandlerAdded, + .sslRequestSent, + .waitingToStartAuthentication: + self.taskQueue.append(task) + return .wait + + case .readyForQuery: + return self.executeTask(task) + + case .closing(let error): + psqlErrror = PSQLError.clientClosedConnection(underlying: error) + + case .closed(clientInitiated: true, error: let error): + psqlErrror = PSQLError.clientClosedConnection(underlying: error) + + case .closed(clientInitiated: false, error: let error): + psqlErrror = PSQLError.serverClosedConnection(underlying: error) + + case .modifying: + preconditionFailure("Invalid state: \(self.state)") + } + } + + switch task { + case .extendedQuery(let queryContext): + switch queryContext.query { + case .executeStatement(_, let promise), .unnamed(_, let promise): + return .failQuery(promise, with: psqlErrror, cleanupContext: nil) + case .prepareStatement(_, _, _, let promise): + return .failPreparedStatementCreation(promise, with: psqlErrror, cleanupContext: nil) + } + case .closeCommand(let closeContext): + return .failClose(closeContext, with: psqlErrror, cleanupContext: nil) + } + } + + mutating func channelReadComplete() -> ConnectionAction { + switch self.state { + case .initialized, + .sslRequestSent, + .sslNegotiated, + .sslHandlerAdded, + .waitingToStartAuthentication, + .authenticating, + .authenticated, + .readyForQuery, + .closeCommand, + .closing, + .closed: + return .wait + + case .extendedQuery(var extendedQuery, let connectionContext): + self.state = .modifying // avoid CoW + let action = extendedQuery.channelReadComplete() + self.state = .extendedQuery(extendedQuery, connectionContext) + return self.modify(with: action) + + case .modifying: + preconditionFailure("Invalid state") + } + } + + mutating func readEventCaught() -> ConnectionAction { + switch self.state { + case .initialized: + preconditionFailure("Invalid state: \(self.state). Read event before connection established?") + + case .sslRequestSent, + .sslNegotiated, + .sslHandlerAdded, + .waitingToStartAuthentication, + .authenticating, + .authenticated, + .readyForQuery, + .closing: + // all states in which we definitely want to make further forward progress... + return .read + + case .extendedQuery(var extendedQuery, let connectionContext): + self.state = .modifying // avoid CoW + let action = extendedQuery.readEventCaught() + self.state = .extendedQuery(extendedQuery, connectionContext) + return self.modify(with: action) + + case .closeCommand(var closeState, let connectionContext): + self.state = .modifying // avoid CoW + let action = closeState.readEventCaught() + self.state = .closeCommand(closeState, connectionContext) + return self.modify(with: action) + + case .closed: + // Generally we shouldn't see this event (read after connection closed?!). + // But truth is, adopters run into this, again and again. So preconditioning here leads + // to unnecessary crashes. So let's be resilient and just make more forward progress. + // If we really care, we probably need to dive deep into PostgresNIO and SwiftNIO. + return .read + + case .modifying: + preconditionFailure("Invalid state: \(self.state)") + } + } + + // MARK: - Running Queries - + + mutating func parseCompleteReceived() -> ConnectionAction { + switch self.state { + case .extendedQuery(var queryState, let connectionContext) where !queryState.isComplete: + self.state = .modifying // avoid CoW + let action = queryState.parseCompletedReceived() + self.state = .extendedQuery(queryState, connectionContext) + return self.modify(with: action) + + default: + return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.parseComplete)) + } + } + + mutating func bindCompleteReceived() -> ConnectionAction { + guard case .extendedQuery(var queryState, let connectionContext) = self.state, !queryState.isComplete else { + return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.bindComplete)) + } + + self.state = .modifying // avoid CoW + let action = queryState.bindCompleteReceived() + self.state = .extendedQuery(queryState, connectionContext) + return self.modify(with: action) + } + + mutating func parameterDescriptionReceived(_ description: PostgresBackendMessage.ParameterDescription) -> ConnectionAction { + switch self.state { + case .extendedQuery(var queryState, let connectionContext) where !queryState.isComplete: + self.state = .modifying // avoid CoW + let action = queryState.parameterDescriptionReceived(description) + self.state = .extendedQuery(queryState, connectionContext) + return self.modify(with: action) + + default: + return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.parameterDescription(description))) + } + } + + mutating func rowDescriptionReceived(_ description: RowDescription) -> ConnectionAction { + switch self.state { + case .extendedQuery(var queryState, let connectionContext) where !queryState.isComplete: + self.state = .modifying // avoid CoW + let action = queryState.rowDescriptionReceived(description) + self.state = .extendedQuery(queryState, connectionContext) + return self.modify(with: action) + + default: + return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.rowDescription(description))) + } + } + + mutating func noDataReceived() -> ConnectionAction { + switch self.state { + case .extendedQuery(var queryState, let connectionContext) where !queryState.isComplete: + self.state = .modifying // avoid CoW + let action = queryState.noDataReceived() + self.state = .extendedQuery(queryState, connectionContext) + return self.modify(with: action) + + default: + return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.noData)) + } + } + + mutating func portalSuspendedReceived() -> ConnectionAction { + self.closeConnectionAndCleanup(.unexpectedBackendMessage(.portalSuspended)) + } + + mutating func closeCompletedReceived() -> ConnectionAction { + guard case .closeCommand(var closeState, let connectionContext) = self.state, !closeState.isComplete else { + return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.closeComplete)) + } + + self.state = .modifying // avoid CoW + let action = closeState.closeCompletedReceived() + self.state = .closeCommand(closeState, connectionContext) + return self.modify(with: action) + } + + mutating func commandCompletedReceived(_ commandTag: String) -> ConnectionAction { + guard case .extendedQuery(var queryState, let connectionContext) = self.state, !queryState.isComplete else { + return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.commandComplete(commandTag))) + } + + self.state = .modifying // avoid CoW + let action = queryState.commandCompletedReceived(commandTag) + self.state = .extendedQuery(queryState, connectionContext) + return self.modify(with: action) + } + + mutating func emptyQueryResponseReceived() -> ConnectionAction { + guard case .extendedQuery(var queryState, let connectionContext) = self.state, !queryState.isComplete else { + return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.emptyQueryResponse)) + } + + self.state = .modifying // avoid CoW + let action = queryState.emptyQueryResponseReceived() + self.state = .extendedQuery(queryState, connectionContext) + return self.modify(with: action) + } + + mutating func dataRowReceived(_ dataRow: DataRow) -> ConnectionAction { + guard case .extendedQuery(var queryState, let connectionContext) = self.state, !queryState.isComplete else { + return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.dataRow(dataRow))) + } + + self.state = .modifying // avoid CoW + let action = queryState.dataRowReceived(dataRow) + self.state = .extendedQuery(queryState, connectionContext) + return self.modify(with: action) + } + + // MARK: Consumer + + mutating func cancelQueryStream() -> ConnectionAction { + guard case .extendedQuery(var queryState, let connectionContext) = self.state else { + preconditionFailure("Tried to cancel stream without active query") + } + + self.state = .modifying // avoid CoW + let action = queryState.cancel() + self.state = .extendedQuery(queryState, connectionContext) + return self.modify(with: action) + } + + mutating func requestQueryRows() -> ConnectionAction { + guard case .extendedQuery(var queryState, let connectionContext) = self.state, !queryState.isComplete else { + preconditionFailure("Tried to consume next row, without active query") + } + + self.state = .modifying // avoid CoW + let action = queryState.requestQueryRows() + self.state = .extendedQuery(queryState, connectionContext) + return self.modify(with: action) + } + + // MARK: - Private Methods - + + private mutating func startAuthentication(_ authContext: AuthContext) -> ConnectionAction { + guard case .waitingToStartAuthentication = self.state else { + preconditionFailure("Can only start authentication after connect or ssl establish") + } + + self.state = .modifying // avoid CoW + var authState = AuthenticationStateMachine(authContext: authContext) + let action = authState.start() + self.state = .authenticating(authState) + return self.modify(with: action) + } + + private mutating func closeConnectionAndCleanup(_ error: PSQLError, closePromise: EventLoopPromise? = nil) -> ConnectionAction { + switch self.state { + case .initialized, + .sslRequestSent, + .sslNegotiated, + .sslHandlerAdded, + .waitingToStartAuthentication, + .authenticated, + .readyForQuery: + let cleanupContext = self.setErrorAndCreateCleanupContext(error, closePromise: closePromise) + return .closeConnectionAndCleanup(cleanupContext) + + case .authenticating(var authState): + let cleanupContext = self.setErrorAndCreateCleanupContext(error, closePromise: closePromise) + + if authState.isComplete { + // in case the auth state machine is complete all necessary actions have already + // been forwarded to the consumer. We can close and cleanup without caring about the + // substate machine. + return .closeConnectionAndCleanup(cleanupContext) + } + + let action = authState.errorHappened(error) + guard case .reportAuthenticationError = action else { + preconditionFailure("Expect to fail auth") + } + return .closeConnectionAndCleanup(cleanupContext) + + case .extendedQuery(var queryStateMachine, _): + let cleanupContext = self.setErrorAndCreateCleanupContext(error, closePromise: closePromise) + + if queryStateMachine.isComplete { + // in case the query state machine is complete all necessary actions have already + // been forwarded to the consumer. We can close and cleanup without caring about the + // substate machine. + return .closeConnectionAndCleanup(cleanupContext) + } + + let action = queryStateMachine.errorHappened(error) + switch action { + case .sendParseDescribeBindExecuteSync, + .sendParseDescribeSync, + .sendBindExecuteSync, + .succeedQuery, + .succeedPreparedStatementCreation, + .forwardRows, + .forwardStreamComplete, + .wait, + .read: + preconditionFailure("Invalid query state machine action in state: \(self.state), action: \(action)") + + case .evaluateErrorAtConnectionLevel: + return .closeConnectionAndCleanup(cleanupContext) + + case .failQuery(let queryContext, with: let error): + return .failQuery(queryContext, with: error, cleanupContext: cleanupContext) + + case .forwardStreamError(let error, let read): + return .forwardStreamError(error, read: read, cleanupContext: cleanupContext) + + case .failPreparedStatementCreation(let promise, with: let error): + return .failPreparedStatementCreation(promise, with: error, cleanupContext: cleanupContext) + } + + case .closeCommand(var closeStateMachine, _): + let cleanupContext = self.setErrorAndCreateCleanupContext(error, closePromise: closePromise) + + if closeStateMachine.isComplete { + // in case the close state machine is complete all necessary actions have already + // been forwarded to the consumer. We can close and cleanup without caring about the + // substate machine. + return .closeConnectionAndCleanup(cleanupContext) + } + + let action = closeStateMachine.errorHappened(error) + switch action { + case .sendCloseSync, + .succeedClose, + .read, + .wait: + preconditionFailure("Invalid close state machine action in state: \(self.state), action: \(action)") + case .failClose(let closeCommandContext, with: let error): + return .failClose(closeCommandContext, with: error, cleanupContext: cleanupContext) + } + + case .closing, .closed: + // We might run into this case because of reentrancy. For example: After we received an + // backend unexpected message, that we read of the wire, we bring this connection into + // the error state and will try to close the connection. However the server might have + // send further follow up messages. In those cases we will run into this method again + // and again. We should just ignore those events. + return .closeConnection(closePromise) + + case .modifying: + preconditionFailure("Invalid state: \(self.state)") + } + } + + private mutating func executeNextQueryFromQueue() -> ConnectionAction { + guard case .readyForQuery = self.state else { + preconditionFailure("Only expected to be invoked, if we are readyToQuery") + } + + if let task = self.taskQueue.popFirst() { + return self.executeTask(task) + } + + // if we don't have anything left to do and we are quiescing, next we should close + if case .quiescing(let promise) = self.quiescingState { + self.state = .closing(nil) + return .closeConnection(promise) + } + + return .fireEventReadyForQuery + } + + private mutating func executeTask(_ task: PSQLTask) -> ConnectionAction { + guard case .readyForQuery(let connectionContext) = self.state else { + preconditionFailure("Only expected to be invoked, if we are readyToQuery") + } + + switch task { + case .extendedQuery(let queryContext): + self.state = .modifying // avoid CoW + var extendedQuery = ExtendedQueryStateMachine(queryContext: queryContext) + let action = extendedQuery.start() + self.state = .extendedQuery(extendedQuery, connectionContext) + return self.modify(with: action) + + case .closeCommand(let closeContext): + self.state = .modifying // avoid CoW + var closeStateMachine = CloseStateMachine(closeContext: closeContext) + let action = closeStateMachine.start() + self.state = .closeCommand(closeStateMachine, connectionContext) + return self.modify(with: action) + } + } + + struct Configuration { + let requireTLS: Bool + } +} + +extension ConnectionStateMachine { + func shouldCloseConnection(reason error: PSQLError) -> Bool { + switch error.code.base { + case .failedToAddSSLHandler, + .receivedUnencryptedDataAfterSSLRequest, + .sslUnsupported, + .messageDecodingFailure, + .unexpectedBackendMessage, + .unsupportedAuthMechanism, + .authMechanismRequiresPassword, + .saslError, + .tooManyParameters, + .invalidCommandTag, + .connectionError, + .uncleanShutdown, + .unlistenFailed: + return true + case .queryCancelled: + return false + case .server, .listenFailed: + guard let sqlState = error.serverInfo?[.sqlState] else { + // any error message that doesn't have a sql state field, is unexpected by default. + return true + } + + if sqlState.starts(with: "28") { + // these are authentication errors + return true + } + + return false + case .clientClosedConnection, .poolClosed: + preconditionFailure("A pure client error was thrown directly in PostgresConnection, this shouldn't happen") + case .serverClosedConnection: + return true + } + } + + mutating func setErrorAndCreateCleanupContextIfNeeded(_ error: PSQLError) -> ConnectionAction.CleanUpContext? { + if self.shouldCloseConnection(reason: error) { + return self.setErrorAndCreateCleanupContext(error) + } + + return nil + } + + mutating func setErrorAndCreateCleanupContext(_ error: PSQLError, closePromise: EventLoopPromise? = nil) -> ConnectionAction.CleanUpContext { + let tasks = Array(self.taskQueue) + self.taskQueue.removeAll() + + var forwardedPromise: EventLoopPromise? = nil + if case .quiescing(.some(let quiescePromise)) = self.quiescingState, let closePromise = closePromise { + quiescePromise.futureResult.cascade(to: closePromise) + forwardedPromise = quiescePromise + } else if case .quiescing(.some(let quiescePromise)) = self.quiescingState { + forwardedPromise = quiescePromise + } else { + forwardedPromise = closePromise + } + + let action: ConnectionAction.CleanUpContext.Action + if case .serverClosedConnection = error.code.base { + self.state = .closed(clientInitiated: false, error: error) + action = .fireChannelInactive + } else { + self.state = .closing(error) + action = .close + } + + return .init(action: action, tasks: tasks, error: error, closePromise: forwardedPromise) + } +} + +extension ConnectionStateMachine { + mutating func modify(with action: ExtendedQueryStateMachine.Action) -> ConnectionStateMachine.ConnectionAction { + switch action { + case .sendParseDescribeBindExecuteSync(let query): + return .sendParseDescribeBindExecuteSync(query) + case .sendBindExecuteSync(let executeStatement): + return .sendBindExecuteSync(executeStatement) + case .failQuery(let requestContext, with: let error): + let cleanupContext = self.setErrorAndCreateCleanupContextIfNeeded(error) + return .failQuery(requestContext, with: error, cleanupContext: cleanupContext) + case .succeedQuery(let requestContext, with: let result): + return .succeedQuery(requestContext, with: result) + case .forwardRows(let buffer): + return .forwardRows(buffer) + case .forwardStreamComplete(let buffer, let commandTag): + return .forwardStreamComplete(buffer, commandTag: commandTag) + case .forwardStreamError(let error, let read): + let cleanupContext = self.setErrorAndCreateCleanupContextIfNeeded(error) + return .forwardStreamError(error, read: read, cleanupContext: cleanupContext) + + case .evaluateErrorAtConnectionLevel(let error): + if let cleanupContext = self.setErrorAndCreateCleanupContextIfNeeded(error) { + return .closeConnectionAndCleanup(cleanupContext) + } + return .wait + case .read: + return .read + case .wait: + return .wait + case .sendParseDescribeSync(name: let name, query: let query, bindingDataTypes: let bindingDataTypes): + return .sendParseDescribeSync(name: name, query: query, bindingDataTypes: bindingDataTypes) + case .succeedPreparedStatementCreation(let promise, with: let rowDescription): + return .succeedPreparedStatementCreation(promise, with: rowDescription) + case .failPreparedStatementCreation(let promise, with: let error): + let cleanupContext = self.setErrorAndCreateCleanupContextIfNeeded(error) + return .failPreparedStatementCreation(promise, with: error, cleanupContext: cleanupContext) + } + } +} + +extension ConnectionStateMachine { + mutating func modify(with action: AuthenticationStateMachine.Action) -> ConnectionStateMachine.ConnectionAction { + switch action { + case .sendStartupMessage(let authContext): + return .sendStartupMessage(authContext) + case .sendPassword(let mode, let authContext): + return .sendPasswordMessage(mode, authContext) + case .sendSaslInitialResponse(let name, let initialResponse): + return .sendSaslInitialResponse(name: name, initialResponse: initialResponse) + case .sendSaslResponse(let bytes): + return .sendSaslResponse(bytes) + case .authenticated: + self.state = .authenticated(nil, [:]) + return .wait + case .wait: + return .wait + case .reportAuthenticationError(let error): + let cleanupContext = self.setErrorAndCreateCleanupContext(error) + return .closeConnectionAndCleanup(cleanupContext) + } + } +} + +extension ConnectionStateMachine { + mutating func modify(with action: CloseStateMachine.Action) -> ConnectionStateMachine.ConnectionAction { + switch action { + case .sendCloseSync(let sendClose): + return .sendCloseSync(sendClose) + case .succeedClose(let closeContext): + return .succeedClose(closeContext) + case .failClose(let closeContext, with: let error): + let cleanupContext = self.setErrorAndCreateCleanupContextIfNeeded(error) + return .failClose(closeContext, with: error, cleanupContext: cleanupContext) + case .read: + return .read + case .wait: + return .wait + } + } +} + +struct SendPrepareStatement { + let name: String + let query: String +} + +struct AuthContext: CustomDebugStringConvertible { + var username: String + var password: String? + var database: String? + var additionalParameters: [(String, String)] + + init(username: String, password: String? = nil, database: String? = nil, additionalParameters: [(String, String)] = []) { + self.username = username + self.password = password + self.database = database + self.additionalParameters = additionalParameters + } + + var debugDescription: String { + """ + AuthContext(username: \(String(reflecting: self.username)), \ + password: \(self.password != nil ? "********" : "nil"), \ + database: \(self.database != nil ? String(reflecting: self.database!) : "nil")) + """ + } +} + +extension AuthContext: Equatable { + static func ==(lhs: Self, rhs: Self) -> Bool { + guard lhs.username == rhs.username + && lhs.password == rhs.password + && lhs.database == rhs.database + && lhs.additionalParameters.count == rhs.additionalParameters.count + else { + return false + } + + return lhs.additionalParameters.elementsEqual(rhs.additionalParameters) { lhs, rhs in + lhs.0 == rhs.0 && lhs.1 == rhs.1 + } + } +} + +enum PasswordAuthencationMode: Equatable { + case cleartext + case md5(salt: UInt32) +} + +extension ConnectionStateMachine.State: CustomDebugStringConvertible { + var debugDescription: String { + switch self { + case .initialized: + return ".initialized" + case .sslRequestSent: + return ".sslRequestSent" + case .sslNegotiated: + return ".sslNegotiated" + case .sslHandlerAdded: + return ".sslHandlerAdded" + case .waitingToStartAuthentication: + return ".waitingToStartAuthentication" + case .authenticating(let authStateMachine): + return ".authenticating(\(String(reflecting: authStateMachine)))" + case .authenticated(let backendKeyData, let parameters): + return ".authenticated(\(String(reflecting: backendKeyData)), \(String(reflecting: parameters)))" + case .readyForQuery(let connectionContext): + return ".readyForQuery(connectionContext: \(String(reflecting: connectionContext)))" + case .extendedQuery(let subStateMachine, let connectionContext): + return ".extendedQuery(\(String(reflecting: subStateMachine)), connectionContext: \(String(reflecting: connectionContext)))" + case .closeCommand(let subStateMachine, let connectionContext): + return ".closeCommand(\(String(reflecting: subStateMachine)), connectionContext: \(String(reflecting: connectionContext)))" + case .closing: + return ".closing" + case .closed: + return ".closed" + case .modifying: + return ".modifying" + } + } +} + +extension ConnectionStateMachine.ConnectionContext: CustomDebugStringConvertible { + var debugDescription: String { + """ + (processID: \(self.backendKeyData?.processID != nil ? String(self.backendKeyData!.processID) : "nil")), \ + secretKey: \(self.backendKeyData?.secretKey != nil ? String(self.backendKeyData!.secretKey) : "nil")), \ + parameters: \(String(reflecting: self.parameters))) + """ + } +} + +extension ConnectionStateMachine.QuiescingState: CustomDebugStringConvertible { + var debugDescription: String { + switch self { + case .notQuiescing: + return ".notQuiescing" + case .quiescing(let closePromise): + return ".quiescing(\(closePromise != nil ? "\(closePromise!)" : "nil"))" + } + } +} + diff --git a/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift new file mode 100644 index 00000000..087a6c24 --- /dev/null +++ b/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift @@ -0,0 +1,585 @@ +import NIOCore + +struct ExtendedQueryStateMachine { + + private enum State { + case initialized(ExtendedQueryContext) + case messagesSent(ExtendedQueryContext) + + case parseCompleteReceived(ExtendedQueryContext) + case parameterDescriptionReceived(ExtendedQueryContext) + case rowDescriptionReceived(ExtendedQueryContext, [RowDescription.Column]) + case noDataMessageReceived(ExtendedQueryContext) + case emptyQueryResponseReceived + + /// A state that is used if a noData message was received before. If a row description was received `bufferingRows` is + /// used after receiving a `bindComplete` message + case bindCompleteReceived(ExtendedQueryContext) + case streaming([RowDescription.Column], RowStreamStateMachine) + /// Indicates that the current query was cancelled and we want to drain rows from the connection ASAP + case drain([RowDescription.Column]) + + case commandComplete(commandTag: String) + case error(PSQLError) + + case modifying + } + + enum Action { + case sendParseDescribeBindExecuteSync(PostgresQuery) + case sendParseDescribeSync(name: String, query: String, bindingDataTypes: [PostgresDataType]) + case sendBindExecuteSync(PSQLExecuteStatement) + + // --- general actions + case failQuery(EventLoopPromise, with: PSQLError) + case succeedQuery(EventLoopPromise, with: QueryResult) + + case evaluateErrorAtConnectionLevel(PSQLError) + + case succeedPreparedStatementCreation(EventLoopPromise, with: RowDescription?) + case failPreparedStatementCreation(EventLoopPromise, with: PSQLError) + + // --- streaming actions + // actions if query has requested next row but we are waiting for backend + case forwardRows([DataRow]) + case forwardStreamComplete([DataRow], commandTag: String) + case forwardStreamError(PSQLError, read: Bool) + + case read + case wait + } + + private var state: State + private var isCancelled: Bool + + init(queryContext: ExtendedQueryContext) { + self.isCancelled = false + self.state = .initialized(queryContext) + } + + mutating func start() -> Action { + guard case .initialized(let queryContext) = self.state else { + preconditionFailure("Start should only be called, if the query has been initialized") + } + + switch queryContext.query { + case .unnamed(let query, _): + return self.avoidingStateMachineCoW { state -> Action in + state = .messagesSent(queryContext) + return .sendParseDescribeBindExecuteSync(query) + } + + case .executeStatement(let prepared, _): + return self.avoidingStateMachineCoW { state -> Action in + switch prepared.rowDescription { + case .some(let rowDescription): + state = .rowDescriptionReceived(queryContext, rowDescription.columns) + case .none: + state = .noDataMessageReceived(queryContext) + } + return .sendBindExecuteSync(prepared) + } + + case .prepareStatement(let name, let query, let bindingDataTypes, _): + return self.avoidingStateMachineCoW { state -> Action in + state = .messagesSent(queryContext) + return .sendParseDescribeSync(name: name, query: query, bindingDataTypes: bindingDataTypes) + } + } + } + + mutating func cancel() -> Action { + switch self.state { + case .initialized: + preconditionFailure("Start must be called immediatly after the query was created") + + case .messagesSent(let queryContext), + .parseCompleteReceived(let queryContext), + .parameterDescriptionReceived(let queryContext), + .rowDescriptionReceived(let queryContext, _), + .noDataMessageReceived(let queryContext), + .bindCompleteReceived(let queryContext): + guard !self.isCancelled else { + return .wait + } + + self.isCancelled = true + switch queryContext.query { + case .unnamed(_, let eventLoopPromise), .executeStatement(_, let eventLoopPromise): + return .failQuery(eventLoopPromise, with: .queryCancelled) + + case .prepareStatement(_, _, _, let eventLoopPromise): + return .failPreparedStatementCreation(eventLoopPromise, with: .queryCancelled) + } + + case .streaming(let columns, var streamStateMachine): + precondition(!self.isCancelled) + self.isCancelled = true + self.state = .drain(columns) + switch streamStateMachine.fail() { + case .wait: + return .forwardStreamError(.queryCancelled, read: false) + case .read: + return .forwardStreamError(.queryCancelled, read: true) + } + + case .commandComplete, .emptyQueryResponseReceived, .error, .drain: + // the stream has already finished. + return .wait + + case .modifying: + preconditionFailure("Invalid state: \(self.state)") + } + } + + mutating func parseCompletedReceived() -> Action { + guard case .messagesSent(let queryContext) = self.state else { + return self.setAndFireError(.unexpectedBackendMessage(.parseComplete)) + } + + return self.avoidingStateMachineCoW { state -> Action in + state = .parseCompleteReceived(queryContext) + return .wait + } + } + + mutating func parameterDescriptionReceived(_ parameterDescription: PostgresBackendMessage.ParameterDescription) -> Action { + guard case .parseCompleteReceived(let queryContext) = self.state else { + return self.setAndFireError(.unexpectedBackendMessage(.parameterDescription(parameterDescription))) + } + + return self.avoidingStateMachineCoW { state -> Action in + state = .parameterDescriptionReceived(queryContext) + return .wait + } + } + + mutating func noDataReceived() -> Action { + guard case .parameterDescriptionReceived(let queryContext) = self.state else { + return self.setAndFireError(.unexpectedBackendMessage(.noData)) + } + + switch queryContext.query { + case .unnamed, .executeStatement: + return self.avoidingStateMachineCoW { state -> Action in + state = .noDataMessageReceived(queryContext) + return .wait + } + + case .prepareStatement(_, _, _, let promise): + return self.avoidingStateMachineCoW { state -> Action in + state = .noDataMessageReceived(queryContext) + return .succeedPreparedStatementCreation(promise, with: nil) + } + } + } + + mutating func rowDescriptionReceived(_ rowDescription: RowDescription) -> Action { + guard case .parameterDescriptionReceived(let queryContext) = self.state else { + return self.setAndFireError(.unexpectedBackendMessage(.rowDescription(rowDescription))) + } + + // In Postgres extended queries we receive the `rowDescription` before we send the + // `Bind` message. Well actually it's vice versa, but this is only true since we do + // pipelining during a query. + // + // In the actual protocol description we receive a rowDescription before the Bind + + // In Postgres extended queries we always request the response rows to be returned in + // `.binary` format. + let columns = rowDescription.columns.map { column -> RowDescription.Column in + var column = column + column.format = .binary + return column + } + + self.avoidingStateMachineCoW { state in + state = .rowDescriptionReceived(queryContext, columns) + } + + switch queryContext.query { + case .unnamed, .executeStatement: + return .wait + + case .prepareStatement(_, _, _, let eventLoopPromise): + return .succeedPreparedStatementCreation(eventLoopPromise, with: rowDescription) + } + } + + mutating func bindCompleteReceived() -> Action { + switch self.state { + case .rowDescriptionReceived(let queryContext, let columns): + switch queryContext.query { + case .unnamed(_, let eventLoopPromise), .executeStatement(_, let eventLoopPromise): + return self.avoidingStateMachineCoW { state -> Action in + state = .streaming(columns, .init()) + let result = QueryResult(value: .rowDescription(columns), logger: queryContext.logger) + return .succeedQuery(eventLoopPromise, with: result) + } + + case .prepareStatement: + return .evaluateErrorAtConnectionLevel(.unexpectedBackendMessage(.bindComplete)) + } + + case .noDataMessageReceived(let queryContext): + return self.avoidingStateMachineCoW { state -> Action in + state = .bindCompleteReceived(queryContext) + return .wait + } + case .initialized, + .messagesSent, + .parseCompleteReceived, + .parameterDescriptionReceived, + .emptyQueryResponseReceived, + .bindCompleteReceived, + .streaming, + .drain, + .commandComplete, + .error: + return self.setAndFireError(.unexpectedBackendMessage(.bindComplete)) + + case .modifying: + preconditionFailure("Invalid state") + } + } + + mutating func dataRowReceived(_ dataRow: DataRow) -> Action { + switch self.state { + case .streaming(let columns, var demandStateMachine): + // When receiving a data row, we must ensure that the data row column count + // matches the previously received row description column count. + guard dataRow.columnCount == columns.count else { + return self.setAndFireError(.unexpectedBackendMessage(.dataRow(dataRow))) + } + + return self.avoidingStateMachineCoW { state -> Action in + demandStateMachine.receivedRow(dataRow) + state = .streaming(columns, demandStateMachine) + return .wait + } + + case .drain(let columns): + guard dataRow.columnCount == columns.count else { + return self.setAndFireError(.unexpectedBackendMessage(.dataRow(dataRow))) + } + // we ignore all rows and wait for readyForQuery + return .wait + + case .initialized, + .messagesSent, + .parseCompleteReceived, + .parameterDescriptionReceived, + .noDataMessageReceived, + .emptyQueryResponseReceived, + .rowDescriptionReceived, + .bindCompleteReceived, + .commandComplete, + .error: + return self.setAndFireError(.unexpectedBackendMessage(.dataRow(dataRow))) + case .modifying: + preconditionFailure("Invalid state") + } + } + + mutating func commandCompletedReceived(_ commandTag: String) -> Action { + switch self.state { + case .bindCompleteReceived(let context): + switch context.query { + case .unnamed(_, let eventLoopPromise), .executeStatement(_, let eventLoopPromise): + return self.avoidingStateMachineCoW { state -> Action in + state = .commandComplete(commandTag: commandTag) + let result = QueryResult(value: .noRows(.tag(commandTag)), logger: context.logger) + return .succeedQuery(eventLoopPromise, with: result) + } + + case .prepareStatement: + preconditionFailure("Invalid state: \(self.state)") + } + + case .streaming(_, var demandStateMachine): + return self.avoidingStateMachineCoW { state -> Action in + state = .commandComplete(commandTag: commandTag) + return .forwardStreamComplete(demandStateMachine.end(), commandTag: commandTag) + } + + case .drain: + precondition(self.isCancelled) + self.state = .commandComplete(commandTag: commandTag) + return .wait + + case .initialized, + .messagesSent, + .parseCompleteReceived, + .parameterDescriptionReceived, + .noDataMessageReceived, + .emptyQueryResponseReceived, + .rowDescriptionReceived, + .commandComplete, + .error: + return self.setAndFireError(.unexpectedBackendMessage(.commandComplete(commandTag))) + case .modifying: + preconditionFailure("Invalid state") + } + } + + mutating func emptyQueryResponseReceived() -> Action { + guard case .bindCompleteReceived(let queryContext) = self.state else { + return self.setAndFireError(.unexpectedBackendMessage(.emptyQueryResponse)) + } + + switch queryContext.query { + case .unnamed(_, let eventLoopPromise), + .executeStatement(_, let eventLoopPromise): + return self.avoidingStateMachineCoW { state -> Action in + state = .emptyQueryResponseReceived + let result = QueryResult(value: .noRows(.emptyResponse), logger: queryContext.logger) + return .succeedQuery(eventLoopPromise, with: result) + } + + case .prepareStatement(_, _, _, _): + return self.setAndFireError(.unexpectedBackendMessage(.emptyQueryResponse)) + } + } + + mutating func errorReceived(_ errorMessage: PostgresBackendMessage.ErrorResponse) -> Action { + let error = PSQLError.server(errorMessage) + switch self.state { + case .initialized: + return self.setAndFireError(.unexpectedBackendMessage(.error(errorMessage))) + case .messagesSent, + .parseCompleteReceived, + .parameterDescriptionReceived, + .bindCompleteReceived: + return self.setAndFireError(error) + case .rowDescriptionReceived, .noDataMessageReceived: + return self.setAndFireError(error) + case .streaming, .drain: + return self.setAndFireError(error) + case .commandComplete, .emptyQueryResponseReceived: + return self.setAndFireError(.unexpectedBackendMessage(.error(errorMessage))) + case .error: + preconditionFailure(""" + This state must not be reached. If the query `.isComplete`, the + ConnectionStateMachine must not send any further events to the substate machine. + """) + + case .modifying: + preconditionFailure("Invalid state") + } + } + + mutating func noticeReceived(_ notice: PostgresBackendMessage.NoticeResponse) -> Action { + //self.queryObject.noticeReceived(notice) + return .wait + } + + mutating func errorHappened(_ error: PSQLError) -> Action { + return self.setAndFireError(error) + } + + // MARK: Customer Actions + + mutating func requestQueryRows() -> Action { + switch self.state { + case .streaming(let columns, var demandStateMachine): + return self.avoidingStateMachineCoW { state -> Action in + let action = demandStateMachine.demandMoreResponseBodyParts() + state = .streaming(columns, demandStateMachine) + switch action { + case .read: + return .read + case .wait: + return .wait + } + } + + case .drain: + return .wait + + case .initialized, + .messagesSent, + .parseCompleteReceived, + .parameterDescriptionReceived, + .noDataMessageReceived, + .emptyQueryResponseReceived, + .rowDescriptionReceived, + .bindCompleteReceived: + preconditionFailure("Requested to consume next row without anything going on.") + + case .commandComplete, .error: + preconditionFailure("The stream is already closed or in a failure state; rows can not be consumed at this time.") + case .modifying: + preconditionFailure("Invalid state") + } + } + + // MARK: Channel actions + + mutating func channelReadComplete() -> Action { + switch self.state { + case .initialized, + .commandComplete, + .drain, + .error, + .messagesSent, + .parseCompleteReceived, + .parameterDescriptionReceived, + .noDataMessageReceived, + .emptyQueryResponseReceived, + .rowDescriptionReceived, + .bindCompleteReceived: + return .wait + + case .streaming(let columns, var demandStateMachine): + return self.avoidingStateMachineCoW { state -> Action in + let rows = demandStateMachine.channelReadComplete() + state = .streaming(columns, demandStateMachine) + switch rows { + case .some(let rows): + return .forwardRows(rows) + case .none: + return .wait + } + } + + case .modifying: + preconditionFailure("Invalid state") + } + } + + mutating func readEventCaught() -> Action { + switch self.state { + case .messagesSent, + .parseCompleteReceived, + .parameterDescriptionReceived, + .noDataMessageReceived, + .rowDescriptionReceived, + .bindCompleteReceived: + return .read + case .streaming(let columns, var demandStateMachine): + precondition(!self.isCancelled) + return self.avoidingStateMachineCoW { state -> Action in + let action = demandStateMachine.read() + state = .streaming(columns, demandStateMachine) + switch action { + case .wait: + return .wait + case .read: + return .read + } + } + case .initialized, + .commandComplete, + .emptyQueryResponseReceived, + .drain, + .error: + // we already have the complete stream received, now we are waiting for a + // `readyForQuery` package. To receive this we need to read! + return .read + case .modifying: + preconditionFailure("Invalid state") + } + } + + // MARK: Private Methods + + private mutating func setAndFireError(_ error: PSQLError) -> Action { + switch self.state { + case .initialized(let context), + .messagesSent(let context), + .parseCompleteReceived(let context), + .parameterDescriptionReceived(let context), + .rowDescriptionReceived(let context, _), + .noDataMessageReceived(let context), + .bindCompleteReceived(let context): + self.state = .error(error) + if self.isCancelled { + return .evaluateErrorAtConnectionLevel(error) + } else { + switch context.query { + case .unnamed(_, let eventLoopPromise), .executeStatement(_, let eventLoopPromise): + return .failQuery(eventLoopPromise, with: error) + case .prepareStatement(_, _, _, let eventLoopPromise): + return .failPreparedStatementCreation(eventLoopPromise, with: error) + } + } + + case .drain: + self.state = .error(error) + return .evaluateErrorAtConnectionLevel(error) + + case .streaming(_, var streamStateMachine): + self.state = .error(error) + switch streamStateMachine.fail() { + case .wait: + return .forwardStreamError(error, read: false) + case .read: + return .forwardStreamError(error, read: true) + } + + case .commandComplete, .emptyQueryResponseReceived, .error: + preconditionFailure(""" + This state must not be reached. If the query `.isComplete`, the + ConnectionStateMachine must not send any further events to the substate machine. + """) + case .modifying: + preconditionFailure("Invalid state") + } + } + + var isComplete: Bool { + switch self.state { + case .commandComplete, .emptyQueryResponseReceived, .error: + return true + + case .noDataMessageReceived(let context), .rowDescriptionReceived(let context, _): + switch context.query { + case .prepareStatement: + return true + case .unnamed, .executeStatement: + return false + } + + case .initialized, .messagesSent, .parseCompleteReceived, .parameterDescriptionReceived, .bindCompleteReceived, .streaming, .drain: + return false + + case .modifying: + preconditionFailure("Invalid state: \(self.state)") + } + } +} + +extension ExtendedQueryStateMachine { + /// So, uh...this function needs some explaining. + /// + /// While the state machine logic above is great, there is a downside to having all of the state machine data in + /// associated data on enumerations: any modification of that data will trigger copy on write for heap-allocated + /// data. That means that for _every operation on the state machine_ we will CoW our underlying state, which is + /// not good. + /// + /// The way we can avoid this is by using this helper function. It will temporarily set state to a value with no + /// associated data, before attempting the body of the function. It will also verify that the state machine never + /// remains in this bad state. + /// + /// A key note here is that all callers must ensure that they return to a good state before they exit. + /// + /// Sadly, because it's generic and has a closure, we need to force it to be inlined at all call sites, which is + /// not ideal. + @inline(__always) + private mutating func avoidingStateMachineCoW(_ body: (inout State) -> ReturnType) -> ReturnType { + self.state = .modifying + defer { + assert(!self.isModifying) + } + + return body(&self.state) + } + + private var isModifying: Bool { + if case .modifying = self.state { + return true + } else { + return false + } + } +} diff --git a/Sources/PostgresNIO/New/Connection State Machine/ListenStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ListenStateMachine.swift new file mode 100644 index 00000000..89f40469 --- /dev/null +++ b/Sources/PostgresNIO/New/Connection State Machine/ListenStateMachine.swift @@ -0,0 +1,254 @@ +import NIOCore + +struct ListenStateMachine { + var channels: [String: ChannelState] + + init() { + self.channels = [:] + } + + enum StartListeningAction { + case none + case startListening(String) + case succeedListenStart(NotificationListener) + } + + mutating func startListening(_ new: NotificationListener) -> StartListeningAction { + return self.channels[new.channel, default: .init()].start(new) + } + + enum StartListeningSuccessAction { + case stopListening + case activateListeners(Dictionary.Values) + } + + mutating func startListeningSucceeded(channel: String) -> StartListeningSuccessAction { + return self.channels[channel]!.startListeningSucceeded() + } + + mutating func startListeningFailed(channel: String, error: Error) -> Dictionary.Values { + return self.channels[channel]!.startListeningFailed(error) + } + + enum StopListeningSuccessAction { + case startListening + case none + } + + mutating func stopListeningSucceeded(channel: String) -> StopListeningSuccessAction { + switch self.channels[channel]!.stopListeningSucceeded() { + case .none: + self.channels.removeValue(forKey: channel) + return .none + + case .startListening: + return .startListening + } + } + + enum CancelAction { + case stopListening(String, cancelListener: NotificationListener) + case cancelListener(NotificationListener) + case none + } + + mutating func cancelNotificationListener(channel: String, id: Int) -> CancelAction { + return self.channels[channel]?.cancelListening(id: id) ?? .none + } + + mutating func fail(_ error: Error) -> [NotificationListener] { + var result = [NotificationListener]() + while var (_, channel) = self.channels.popFirst() { + switch channel.fail(error) { + case .none: + continue + + case .failListeners(let listeners): + result.append(contentsOf: listeners) + } + } + return result + } + + enum ReceivedAction { + case none + case notify(Dictionary.Values) + } + + func notificationReceived(channel: String) -> ReceivedAction { + // TODO: Do we want to close the connection, if we receive a notification on a channel that we don't listen to? + // We can only change this with the next major release, as it would break current functionality. + return self.channels[channel]?.notificationReceived() ?? .none + } +} + +extension ListenStateMachine { + struct ChannelState { + enum State { + case initialized + case starting([Int: NotificationListener]) + case listening([Int: NotificationListener]) + case stopping([Int: NotificationListener]) + case failed(Error) + } + + private var state: State + + init() { + self.state = .initialized + } + + mutating func start(_ new: NotificationListener) -> StartListeningAction { + switch self.state { + case .initialized: + self.state = .starting([new.id: new]) + return .startListening(new.channel) + + case .starting(var listeners): + listeners[new.id] = new + self.state = .starting(listeners) + return .none + + case .listening(var listeners): + listeners[new.id] = new + self.state = .listening(listeners) + return .succeedListenStart(new) + + case .stopping(var listeners): + listeners[new.id] = new + self.state = .stopping(listeners) + return .none + + case .failed: + fatalError("Invalid state: \(self.state)") + } + } + + mutating func startListeningSucceeded() -> StartListeningSuccessAction { + switch self.state { + case .initialized, .listening, .stopping: + fatalError("Invalid state: \(self.state)") + + case .starting(let listeners): + if listeners.isEmpty { + self.state = .stopping(listeners) + return .stopListening + } else { + self.state = .listening(listeners) + return .activateListeners(listeners.values) + } + + case .failed: + fatalError("Invalid state: \(self.state)") + } + } + + mutating func startListeningFailed(_ error: Error) -> Dictionary.Values { + switch self.state { + case .initialized, .listening, .stopping: + fatalError("Invalid state: \(self.state)") + + case .starting(let listeners): + self.state = .initialized + return listeners.values + + case .failed: + fatalError("Invalid state: \(self.state)") + } + } + + mutating func stopListeningSucceeded() -> StopListeningSuccessAction { + switch self.state { + case .initialized, .listening, .starting: + fatalError("Invalid state: \(self.state)") + + case .stopping(let listeners): + if listeners.isEmpty { + self.state = .initialized + return .none + } else { + self.state = .starting(listeners) + return .startListening + } + + case .failed: + return .none + } + } + + mutating func cancelListening(id: Int) -> CancelAction { + switch self.state { + case .initialized: + fatalError("Invalid state: \(self.state)") + + case .starting(var listeners): + let removed = listeners.removeValue(forKey: id) + self.state = .starting(listeners) + if let removed = removed { + return .cancelListener(removed) + } + return .none + + case .listening(var listeners): + precondition(!listeners.isEmpty) + let maybeLast = listeners.removeValue(forKey: id) + if let last = maybeLast, listeners.isEmpty { + self.state = .stopping(listeners) + return .stopListening(last.channel, cancelListener: last) + } else { + self.state = .listening(listeners) + if let notLast = maybeLast { + return .cancelListener(notLast) + } + return .none + } + + case .stopping(var listeners): + let removed = listeners.removeValue(forKey: id) + self.state = .stopping(listeners) + if let removed = removed { + return .cancelListener(removed) + } + return .none + + case .failed: + return .none + } + } + + enum FailAction { + case failListeners(Dictionary.Values) + case none + } + + mutating func fail(_ error: Error) -> FailAction { + switch self.state { + case .initialized: + fatalError("Invalid state: \(self.state)") + + case .starting(let listeners), .listening(let listeners), .stopping(let listeners): + self.state = .failed(error) + return .failListeners(listeners.values) + + case .failed: + return .none + } + } + + func notificationReceived() -> ReceivedAction { + switch self.state { + case .initialized, .starting: + fatalError("Invalid state: \(self.state)") + + case .listening(let listeners): + return .notify(listeners.values) + + case .stopping: + return .none + + default: + preconditionFailure("TODO: Implemented") + } + } + } +} diff --git a/Sources/PostgresNIO/New/Connection State Machine/PreparedStatementStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/PreparedStatementStateMachine.swift new file mode 100644 index 00000000..5afa4d0b --- /dev/null +++ b/Sources/PostgresNIO/New/Connection State Machine/PreparedStatementStateMachine.swift @@ -0,0 +1,93 @@ +import NIOCore + +struct PreparedStatementStateMachine { + enum State { + case preparing([PreparedStatementContext]) + case prepared(RowDescription?) + case error(PSQLError) + } + + var preparedStatements: [String: State] = [:] + + enum LookupAction { + case prepareStatement + case waitForAlreadyInFlightPreparation + case executeStatement(RowDescription?) + case returnError(PSQLError) + } + + mutating func lookup(preparedStatement: PreparedStatementContext) -> LookupAction { + if let state = self.preparedStatements[preparedStatement.name] { + switch state { + case .preparing(var statements): + statements.append(preparedStatement) + self.preparedStatements[preparedStatement.name] = .preparing(statements) + return .waitForAlreadyInFlightPreparation + case .prepared(let rowDescription): + return .executeStatement(rowDescription) + case .error(let error): + return .returnError(error) + } + } else { + self.preparedStatements[preparedStatement.name] = .preparing([preparedStatement]) + return .prepareStatement + } + } + + struct PreparationCompleteAction { + var statements: [PreparedStatementContext] + var rowDescription: RowDescription? + } + + mutating func preparationComplete( + name: String, + rowDescription: RowDescription? + ) -> PreparationCompleteAction { + guard let state = self.preparedStatements[name] else { + fatalError("Unknown prepared statement \(name)") + } + switch state { + case .preparing(let statements): + // When sending the bindings we are going to ask for binary data. + if var rowDescription = rowDescription { + for i in 0.. ErrorHappenedAction { + guard let state = self.preparedStatements[name] else { + fatalError("Unknown prepared statement \(name)") + } + switch state { + case .preparing(let statements): + self.preparedStatements[name] = .error(error) + return ErrorHappenedAction( + statements: statements, + error: error + ) + case .prepared, .error: + preconditionFailure("Error happened in an unexpected state \(state)") + } + } +} diff --git a/Sources/PostgresNIO/New/Connection State Machine/RowStreamStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/RowStreamStateMachine.swift new file mode 100644 index 00000000..4bfd5e9b --- /dev/null +++ b/Sources/PostgresNIO/New/Connection State Machine/RowStreamStateMachine.swift @@ -0,0 +1,214 @@ +import NIOCore + +/// A sub state for receiving data rows. Stores whether the consumer has either signaled demand and whether the +/// channel has issued `read` events. +/// +/// This should be used as a SubStateMachine in QuerySubStateMachines. +struct RowStreamStateMachine { + + enum Action { + case read + case wait + } + + private enum State { + /// The state machines expects further writes to `channelRead`. The writes are appended to the buffer. + case waitingForRows([DataRow]) + /// The state machines expects a call to `demandMoreResponseBodyParts` or `read`. The buffer is + /// empty. It is preserved for performance reasons. + case waitingForReadOrDemand([DataRow]) + /// The state machines expects a call to `read`. The buffer is empty. It is preserved for performance reasons. + case waitingForRead([DataRow]) + /// The state machines expects a call to `demandMoreResponseBodyParts`. The buffer is empty. It is + /// preserved for performance reasons. + case waitingForDemand([DataRow]) + + case failed + + case modifying + } + + private var state: State + + init() { + var buffer = [DataRow]() + buffer.reserveCapacity(32) + self.state = .waitingForRows(buffer) + } + + mutating func receivedRow(_ newRow: DataRow) { + switch self.state { + case .waitingForRows(var buffer): + self.state = .modifying + buffer.append(newRow) + self.state = .waitingForRows(buffer) + + // For all the following cases, please note: + // Normally these code paths should never be hit. However there is one way to trigger + // this: + // + // If the server decides to close a connection, NIO will forward all outstanding + // `channelRead`s without waiting for a next `context.read` call. For this reason we might + // receive new rows, when we don't expect them here. + case .waitingForRead(var buffer): + self.state = .modifying + buffer.append(newRow) + self.state = .waitingForRead(buffer) + + case .waitingForDemand(var buffer): + self.state = .modifying + buffer.append(newRow) + self.state = .waitingForDemand(buffer) + + case .waitingForReadOrDemand(var buffer): + self.state = .modifying + buffer.append(newRow) + self.state = .waitingForReadOrDemand(buffer) + + case .failed: + // Once the row stream state machine is marked as failed, no further events must be + // forwarded to it. + preconditionFailure("Invalid state: \(self.state)") + + case .modifying: + preconditionFailure("Invalid state: \(self.state)") + } + } + + mutating func channelReadComplete() -> [DataRow]? { + switch self.state { + case .waitingForRows(let buffer): + if buffer.isEmpty { + self.state = .waitingForRead(buffer) + return nil + } else { + var newBuffer = buffer + newBuffer.removeAll(keepingCapacity: true) + self.state = .waitingForReadOrDemand(newBuffer) + return buffer + } + + case .waitingForRead, + .waitingForDemand, + .waitingForReadOrDemand: + preconditionFailure("How can we receive a body part, after a channelReadComplete, but no read has been forwarded yet. Invalid state: \(self.state)") + + case .failed: + // Once the row stream state machine is marked as failed, no further events must be + // forwarded to it. + preconditionFailure("Invalid state: \(self.state)") + + case .modifying: + preconditionFailure("Invalid state: \(self.state)") + } + } + + mutating func demandMoreResponseBodyParts() -> Action { + switch self.state { + case .waitingForDemand(let buffer): + self.state = .waitingForRows(buffer) + return .read + + case .waitingForReadOrDemand(let buffer): + self.state = .waitingForRead(buffer) + return .wait + + case .waitingForRead: + // If we are `.waitingForRead`, no action needs to be taken. Demand has already been + // signaled. Once we receive the next `read`, we will forward it, right away + return .wait + + case .waitingForRows: + // If we are `.waitingForRows`, no action needs to be taken. As soon as we receive + // the next `channelReadComplete` we will forward all buffered data + return .wait + + case .failed: + // Once the row stream state machine is marked as failed, no further events must be + // forwarded to it. + preconditionFailure("Invalid state: \(self.state)") + + case .modifying: + preconditionFailure("Invalid state: \(self.state)") + } + } + + mutating func read() -> Action { + switch self.state { + case .waitingForRows: + // This should never happen. But we don't want to precondition this behavior. Let's just + // pass the read event on + return .read + + case .waitingForReadOrDemand(let buffer): + self.state = .waitingForDemand(buffer) + return .wait + + case .waitingForRead(let buffer): + self.state = .waitingForRows(buffer) + return .read + + case .waitingForDemand: + // we have already received a read event. We will issue it as soon as we received demand + // from the consumer + return .wait + + case .failed: + // Once the row stream state machine is marked as failed, no further events must be + // forwarded to it. + preconditionFailure("Invalid state: \(self.state)") + + case .modifying: + preconditionFailure("Invalid state: \(self.state)") + } + } + + mutating func end() -> [DataRow] { + switch self.state { + case .waitingForRows(let buffer): + return buffer + + case .waitingForReadOrDemand(let buffer), + .waitingForRead(let buffer), + .waitingForDemand(let buffer): + + // Normally this code path should never be hit. However there is one way to trigger + // this: + // + // If the server decides to close a connection, NIO will forward all outstanding + // `channelRead`s without waiting for a next `context.read` call. For this reason we might + // receive a call to `end()`, when we don't expect it here. + return buffer + + case .failed: + // Once the row stream state machine is marked as failed, no further events must be + // forwarded to it. + preconditionFailure("Invalid state: \(self.state)") + + case .modifying: + preconditionFailure("Invalid state: \(self.state)") + } + } + + mutating func fail() -> Action { + switch self.state { + case .waitingForRows, + .waitingForReadOrDemand, + .waitingForRead: + self.state = .failed + return .wait + + case .waitingForDemand: + self.state = .failed + return .read + + case .failed: + // Once the row stream state machine is marked as failed, no further events must be + // forwarded to it. + preconditionFailure("Invalid state: \(self.state)") + + case .modifying: + preconditionFailure("Invalid state: \(self.state)") + } + } +} diff --git a/Sources/PostgresNIO/New/Data/Array+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/Array+PostgresCodable.swift new file mode 100644 index 00000000..ddab0fff --- /dev/null +++ b/Sources/PostgresNIO/New/Data/Array+PostgresCodable.swift @@ -0,0 +1,242 @@ +import NIOCore +import struct Foundation.Date +import struct Foundation.UUID + +// MARK: Protocols + +/// A type, of which arrays can be encoded into and decoded from a postgres binary format +public protocol PostgresArrayEncodable: PostgresEncodable { + static var psqlArrayType: PostgresDataType { get } +} + +/// A type that can be decoded into a Swift Array of its own type from a Postgres array. +public protocol PostgresArrayDecodable: PostgresDecodable {} + +// MARK: Element conformances + +extension Bool: PostgresArrayDecodable {} + +extension Bool: PostgresArrayEncodable { + public static var psqlArrayType: PostgresDataType { .boolArray } +} + +extension ByteBuffer: PostgresArrayDecodable {} + +extension ByteBuffer: PostgresArrayEncodable { + public static var psqlArrayType: PostgresDataType { .byteaArray } +} + +extension UInt8: PostgresArrayDecodable {} + +extension UInt8: PostgresArrayEncodable { + public static var psqlArrayType: PostgresDataType { .charArray } +} + + +extension Int16: PostgresArrayDecodable {} + +extension Int16: PostgresArrayEncodable { + public static var psqlArrayType: PostgresDataType { .int2Array } +} + +extension Int32: PostgresArrayDecodable {} + +extension Int32: PostgresArrayEncodable { + public static var psqlArrayType: PostgresDataType { .int4Array } +} + +extension Int64: PostgresArrayDecodable {} + +extension Int64: PostgresArrayEncodable { + public static var psqlArrayType: PostgresDataType { .int8Array } +} + +extension Int: PostgresArrayDecodable {} + +extension Int: PostgresArrayEncodable { + public static var psqlArrayType: PostgresDataType { + if MemoryLayout.size == 8 { + return .int8Array + } + return .int4Array + } +} + +extension Float: PostgresArrayDecodable {} + +extension Float: PostgresArrayEncodable { + public static var psqlArrayType: PostgresDataType { .float4Array } +} + +extension Double: PostgresArrayDecodable {} + +extension Double: PostgresArrayEncodable { + public static var psqlArrayType: PostgresDataType { .float8Array } +} + +extension String: PostgresArrayDecodable {} + +extension String: PostgresArrayEncodable { + public static var psqlArrayType: PostgresDataType { .textArray } +} + +extension UUID: PostgresArrayDecodable {} + +extension UUID: PostgresArrayEncodable { + public static var psqlArrayType: PostgresDataType { .uuidArray } +} + +extension Date: PostgresArrayDecodable {} + +extension Date: PostgresArrayEncodable { + public static var psqlArrayType: PostgresDataType { .timestamptzArray } +} + +extension Range: PostgresArrayDecodable where Bound: PostgresRangeArrayDecodable {} + +extension Range: PostgresArrayEncodable where Bound: PostgresRangeArrayEncodable { + public static var psqlArrayType: PostgresDataType { Bound.psqlRangeArrayType } +} + +extension ClosedRange: PostgresArrayDecodable where Bound: PostgresRangeArrayDecodable {} + +extension ClosedRange: PostgresArrayEncodable where Bound: PostgresRangeArrayEncodable { + public static var psqlArrayType: PostgresDataType { Bound.psqlRangeArrayType } +} + +// MARK: Array conformances + +extension Array: PostgresEncodable where Element: PostgresArrayEncodable { + public static var psqlType: PostgresDataType { + Element.psqlArrayType + } + + public static var psqlFormat: PostgresFormat { + .binary + } + + @inlinable + public func encode( + into buffer: inout ByteBuffer, + context: PostgresEncodingContext + ) throws { + // 0 if empty, 1 if not + buffer.writeInteger(self.isEmpty ? 0 : 1, as: UInt32.self) + // b + buffer.writeInteger(0, as: Int32.self) + // array element type + buffer.writeInteger(Element.psqlType.rawValue) + + // continue if the array is not empty + guard !self.isEmpty else { + return + } + + // length of array + buffer.writeInteger(numericCast(self.count), as: Int32.self) + // dimensions + buffer.writeInteger(1, as: Int32.self) + + try self.forEach { element in + try element.encodeRaw(into: &buffer, context: context) + } + } +} + +// explicitly conforming to PostgresThrowingDynamicTypeEncodable because of: +// https://github.com/apple/swift/issues/54132 +extension Array: PostgresThrowingDynamicTypeEncodable where Element: PostgresArrayEncodable {} + +extension Array: PostgresNonThrowingEncodable where Element: PostgresArrayEncodable & PostgresNonThrowingEncodable { + public static var psqlType: PostgresDataType { + Element.psqlArrayType + } + + public static var psqlFormat: PostgresFormat { + .binary + } + + @inlinable + public func encode( + into buffer: inout ByteBuffer, + context: PostgresEncodingContext + ) { + // 0 if empty, 1 if not + buffer.writeInteger(self.isEmpty ? 0 : 1, as: UInt32.self) + // b + buffer.writeInteger(0, as: Int32.self) + // array element type + buffer.writeInteger(Element.psqlType.rawValue) + + // continue if the array is not empty + guard !self.isEmpty else { + return + } + + // length of array + buffer.writeInteger(numericCast(self.count), as: Int32.self) + // dimensions + buffer.writeInteger(1, as: Int32.self) + + self.forEach { element in + element.encodeRaw(into: &buffer, context: context) + } + } +} + +// explicitly conforming to PostgresDynamicTypeEncodable because of: +// https://github.com/apple/swift/issues/54132 +extension Array: PostgresDynamicTypeEncodable where Element: PostgresArrayEncodable & PostgresNonThrowingEncodable {} + +extension Array: PostgresDecodable where Element: PostgresArrayDecodable, Element == Element._DecodableType { + public init( + from buffer: inout ByteBuffer, + type: PostgresDataType, + format: PostgresFormat, + context: PostgresDecodingContext + ) throws { + guard case .binary = format else { + // currently we only support decoding arrays in binary format. + throw PostgresDecodingError.Code.failure + } + + guard let (isNotEmpty, b, element) = buffer.readMultipleIntegers(endianness: .big, as: (Int32, Int32, UInt32).self), + 0 <= isNotEmpty, isNotEmpty <= 1, b == 0 + else { + throw PostgresDecodingError.Code.failure + } + + let elementType = PostgresDataType(element) + + guard isNotEmpty == 1 else { + self = [] + return + } + + guard let (expectedArrayCount, dimensions) = buffer.readMultipleIntegers(endianness: .big, as: (Int32, Int32).self), + expectedArrayCount > 0, + dimensions == 1 + else { + throw PostgresDecodingError.Code.failure + } + + var result = Array() + result.reserveCapacity(Int(expectedArrayCount)) + + for _ in 0 ..< expectedArrayCount { + guard let elementLength = buffer.readInteger(as: Int32.self), elementLength >= 0 else { + throw PostgresDecodingError.Code.failure + } + + guard var elementBuffer = buffer.readSlice(length: numericCast(elementLength)) else { + throw PostgresDecodingError.Code.failure + } + + let element = try Element.init(from: &elementBuffer, type: elementType, format: format, context: context) + + result.append(element) + } + + self = result + } +} diff --git a/Sources/PostgresNIO/New/Data/Bool+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/Bool+PostgresCodable.swift new file mode 100644 index 00000000..515d167a --- /dev/null +++ b/Sources/PostgresNIO/New/Data/Bool+PostgresCodable.swift @@ -0,0 +1,62 @@ +import NIOCore + +extension Bool: PostgresDecodable { + @inlinable + public init( + from buffer: inout ByteBuffer, + type: PostgresDataType, + format: PostgresFormat, + context: PostgresDecodingContext + ) throws { + guard type == .bool else { + throw PostgresDecodingError.Code.typeMismatch + } + + switch format { + case .binary: + guard buffer.readableBytes == 1 else { + throw PostgresDecodingError.Code.failure + } + + switch buffer.readInteger(as: UInt8.self) { + case .some(0): + self = false + case .some(1): + self = true + default: + throw PostgresDecodingError.Code.failure + } + case .text: + guard buffer.readableBytes == 1 else { + throw PostgresDecodingError.Code.failure + } + + switch buffer.readInteger(as: UInt8.self) { + case .some(UInt8(ascii: "f")): + self = false + case .some(UInt8(ascii: "t")): + self = true + default: + throw PostgresDecodingError.Code.failure + } + } + } +} + +extension Bool: PostgresNonThrowingEncodable { + public static var psqlType: PostgresDataType { + .bool + } + + public static var psqlFormat: PostgresFormat { + .binary + } + + @inlinable + public func encode( + into byteBuffer: inout ByteBuffer, + context: PostgresEncodingContext + ) { + byteBuffer.writeInteger(self ? 1 : 0, as: UInt8.self) + } +} diff --git a/Sources/PostgresNIO/New/Data/Bytes+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/Bytes+PostgresCodable.swift new file mode 100644 index 00000000..f6544df0 --- /dev/null +++ b/Sources/PostgresNIO/New/Data/Bytes+PostgresCodable.swift @@ -0,0 +1,84 @@ +import struct Foundation.Data +import NIOCore +import NIOFoundationCompat + +extension PostgresEncodable where Self: Sequence, Self.Element == UInt8 { + public static var psqlType: PostgresDataType { + .bytea + } + + public static var psqlFormat: PostgresFormat { + .binary + } + + @inlinable + public func encode( + into byteBuffer: inout ByteBuffer, + context: PostgresEncodingContext + ) { + byteBuffer.writeBytes(self) + } +} + +extension PostgresNonThrowingEncodable where Self: Sequence, Self.Element == UInt8 {} + +extension ByteBuffer: PostgresNonThrowingEncodable { + public static var psqlType: PostgresDataType { + .bytea + } + + public static var psqlFormat: PostgresFormat { + .binary + } + + @inlinable + public func encode( + into byteBuffer: inout ByteBuffer, + context: PostgresEncodingContext + ) { + var copyOfSelf = self // dirty hack + byteBuffer.writeBuffer(©OfSelf) + } +} + +extension ByteBuffer: PostgresDecodable { + @inlinable + public init( + from buffer: inout ByteBuffer, + type: PostgresDataType, + format: PostgresFormat, + context: PostgresDecodingContext + ) { + self = buffer + } +} + +extension Data: PostgresEncodable { + public static var psqlType: PostgresDataType { + .bytea + } + + public static var psqlFormat: PostgresFormat { + .binary + } + + @inlinable + public func encode( + into byteBuffer: inout ByteBuffer, + context: PostgresEncodingContext + ) { + byteBuffer.writeBytes(self) + } +} + +extension Data: PostgresDecodable { + @inlinable + public init( + from buffer: inout ByteBuffer, + type: PostgresDataType, + format: PostgresFormat, + context: PostgresDecodingContext + ) { + self = buffer.readData(length: buffer.readableBytes, byteTransferStrategy: .automatic)! + } +} diff --git a/Sources/PostgresNIO/New/Data/Date+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/Date+PostgresCodable.swift new file mode 100644 index 00000000..31d8d749 --- /dev/null +++ b/Sources/PostgresNIO/New/Data/Date+PostgresCodable.swift @@ -0,0 +1,59 @@ +import NIOCore +import struct Foundation.Date + +extension Date: PostgresNonThrowingEncodable { + public static var psqlType: PostgresDataType { + .timestamptz + } + + public static var psqlFormat: PostgresFormat { + .binary + } + + @inlinable + public func encode( + into byteBuffer: inout ByteBuffer, + context: PostgresEncodingContext + ) { + let seconds = self.timeIntervalSince(Self._psqlDateStart) * Double(Self._microsecondsPerSecond) + byteBuffer.writeInteger(Int64(seconds)) + } + + // MARK: Private Constants + + @usableFromInline + static let _microsecondsPerSecond: Int64 = 1_000_000 + @usableFromInline + static let _secondsInDay: Int64 = 24 * 60 * 60 + + /// values are stored as seconds before or after midnight 2000-01-01 + @usableFromInline + static let _psqlDateStart = Date(timeIntervalSince1970: 946_684_800) +} + +extension Date: PostgresDecodable { + @inlinable + public init( + from buffer: inout ByteBuffer, + type: PostgresDataType, + format: PostgresFormat, + context: PostgresDecodingContext + ) throws { + switch type { + case .timestamp, .timestamptz: + guard buffer.readableBytes == 8, let microseconds = buffer.readInteger(as: Int64.self) else { + throw PostgresDecodingError.Code.failure + } + let seconds = Double(microseconds) / Double(Self._microsecondsPerSecond) + self = Date(timeInterval: seconds, since: Self._psqlDateStart) + case .date: + guard buffer.readableBytes == 4, let days = buffer.readInteger(as: Int32.self) else { + throw PostgresDecodingError.Code.failure + } + let seconds = Int64(days) * Self._secondsInDay + self = Date(timeInterval: Double(seconds), since: Self._psqlDateStart) + default: + throw PostgresDecodingError.Code.typeMismatch + } + } +} diff --git a/Sources/PostgresNIO/New/Data/Decimal+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/Decimal+PostgresCodable.swift new file mode 100644 index 00000000..f634d4ae --- /dev/null +++ b/Sources/PostgresNIO/New/Data/Decimal+PostgresCodable.swift @@ -0,0 +1,49 @@ +import NIOCore +import struct Foundation.Decimal + +extension Decimal: PostgresEncodable { + public static var psqlType: PostgresDataType { + .numeric + } + + public static var psqlFormat: PostgresFormat { + .binary + } + + public func encode( + into byteBuffer: inout ByteBuffer, + context: PostgresEncodingContext + ) { + let numeric = PostgresNumeric(decimal: self) + byteBuffer.writeInteger(numeric.ndigits) + byteBuffer.writeInteger(numeric.weight) + byteBuffer.writeInteger(numeric.sign) + byteBuffer.writeInteger(numeric.dscale) + var value = numeric.value + byteBuffer.writeBuffer(&value) + } +} + +extension Decimal: PostgresDecodable { + public init( + from buffer: inout ByteBuffer, + type: PostgresDataType, + format: PostgresFormat, + context: PostgresDecodingContext + ) throws { + switch (format, type) { + case (.binary, .numeric): + guard let numeric = PostgresNumeric(buffer: &buffer) else { + throw PostgresDecodingError.Code.failure + } + self = numeric.decimal + case (.text, .numeric): + guard let string = buffer.readString(length: buffer.readableBytes), let value = Decimal(string: string) else { + throw PostgresDecodingError.Code.failure + } + self = value + default: + throw PostgresDecodingError.Code.typeMismatch + } + } +} diff --git a/Sources/PostgresNIO/New/Data/Float+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/Float+PostgresCodable.swift new file mode 100644 index 00000000..8b5e4472 --- /dev/null +++ b/Sources/PostgresNIO/New/Data/Float+PostgresCodable.swift @@ -0,0 +1,97 @@ +import NIOCore + +extension Float: PostgresNonThrowingEncodable { + public static var psqlType: PostgresDataType { + .float4 + } + + public static var psqlFormat: PostgresFormat { + .binary + } + + @inlinable + public func encode( + into byteBuffer: inout ByteBuffer, + context: PostgresEncodingContext + ) { + byteBuffer.psqlWriteFloat(self) + } +} + +extension Float: PostgresDecodable { + @inlinable + public init( + from buffer: inout ByteBuffer, + type: PostgresDataType, + format: PostgresFormat, + context: PostgresDecodingContext + ) throws { + switch (format, type) { + case (.binary, .float4): + guard buffer.readableBytes == 4, let float = buffer.psqlReadFloat() else { + throw PostgresDecodingError.Code.failure + } + self = float + case (.binary, .float8): + guard buffer.readableBytes == 8, let double = buffer.psqlReadDouble() else { + throw PostgresDecodingError.Code.failure + } + self = Float(double) + case (.text, .float4), (.text, .float8): + guard let string = buffer.readString(length: buffer.readableBytes), let value = Float(string) else { + throw PostgresDecodingError.Code.failure + } + self = value + default: + throw PostgresDecodingError.Code.typeMismatch + } + } +} + +extension Double: PostgresNonThrowingEncodable { + public static var psqlType: PostgresDataType { + .float8 + } + + public static var psqlFormat: PostgresFormat { + .binary + } + + @inlinable + public func encode( + into byteBuffer: inout ByteBuffer, + context: PostgresEncodingContext + ) { + byteBuffer.psqlWriteDouble(self) + } +} + +extension Double: PostgresDecodable { + @inlinable + public init( + from buffer: inout ByteBuffer, + type: PostgresDataType, + format: PostgresFormat, + context: PostgresDecodingContext + ) throws { + switch (format, type) { + case (.binary, .float4): + guard buffer.readableBytes == 4, let float = buffer.psqlReadFloat() else { + throw PostgresDecodingError.Code.failure + } + self = Double(float) + case (.binary, .float8): + guard buffer.readableBytes == 8, let double = buffer.psqlReadDouble() else { + throw PostgresDecodingError.Code.failure + } + self = double + case (.text, .float4), (.text, .float8): + guard let string = buffer.readString(length: buffer.readableBytes), let value = Double(string) else { + throw PostgresDecodingError.Code.failure + } + self = value + default: + throw PostgresDecodingError.Code.typeMismatch + } + } +} diff --git a/Sources/PostgresNIO/New/Data/Int+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/Int+PostgresCodable.swift new file mode 100644 index 00000000..c2f3b339 --- /dev/null +++ b/Sources/PostgresNIO/New/Data/Int+PostgresCodable.swift @@ -0,0 +1,254 @@ +import NIOCore + +// MARK: UInt8 + +extension UInt8: PostgresNonThrowingEncodable { + public static var psqlType: PostgresDataType { + .char + } + + public static var psqlFormat: PostgresFormat { + .binary + } + + @inlinable + public func encode( + into byteBuffer: inout ByteBuffer, + context: PostgresEncodingContext + ) { + byteBuffer.writeInteger(self, as: UInt8.self) + } +} + +extension UInt8: PostgresDecodable { + @inlinable + public init( + from buffer: inout ByteBuffer, + type: PostgresDataType, + format: PostgresFormat, + context: PostgresDecodingContext + ) throws { + switch type { + case .bpchar, .char: + guard buffer.readableBytes == 1, let value = buffer.readInteger(as: UInt8.self) else { + throw PostgresDecodingError.Code.failure + } + + self = value + default: + throw PostgresDecodingError.Code.typeMismatch + } + } +} + +// MARK: Int16 + +extension Int16: PostgresNonThrowingEncodable { + public static var psqlType: PostgresDataType { + .int2 + } + + public static var psqlFormat: PostgresFormat { + .binary + } + + @inlinable + public func encode( + into byteBuffer: inout ByteBuffer, + context: PostgresEncodingContext + ) { + byteBuffer.writeInteger(self, as: Int16.self) + } +} + +extension Int16: PostgresDecodable { + @inlinable + public init( + from buffer: inout ByteBuffer, + type: PostgresDataType, + format: PostgresFormat, + context: PostgresDecodingContext + ) throws { + switch (format, type) { + case (.binary, .int2): + guard buffer.readableBytes == 2, let value = buffer.readInteger(as: Int16.self) else { + throw PostgresDecodingError.Code.failure + } + self = value + case (.text, .int2): + guard let string = buffer.readString(length: buffer.readableBytes), let value = Int16(string) else { + throw PostgresDecodingError.Code.failure + } + self = value + default: + throw PostgresDecodingError.Code.typeMismatch + } + } +} + +// MARK: Int32 + +extension Int32: PostgresNonThrowingEncodable { + public static var psqlType: PostgresDataType { + .int4 + } + + public static var psqlFormat: PostgresFormat { + .binary + } + + @inlinable + public func encode( + into byteBuffer: inout ByteBuffer, + context: PostgresEncodingContext + ) { + byteBuffer.writeInteger(self, as: Int32.self) + } +} + +extension Int32: PostgresDecodable { + @inlinable + public init( + from buffer: inout ByteBuffer, + type: PostgresDataType, + format: PostgresFormat, + context: PostgresDecodingContext + ) throws { + switch (format, type) { + case (.binary, .int2): + guard buffer.readableBytes == 2, let value = buffer.readInteger(as: Int16.self) else { + throw PostgresDecodingError.Code.failure + } + self = Int32(value) + case (.binary, .int4): + guard buffer.readableBytes == 4, let value = buffer.readInteger(as: Int32.self) else { + throw PostgresDecodingError.Code.failure + } + self = Int32(value) + case (.text, .int2), (.text, .int4): + guard let string = buffer.readString(length: buffer.readableBytes), let value = Int32(string) else { + throw PostgresDecodingError.Code.failure + } + self = value + default: + throw PostgresDecodingError.Code.typeMismatch + } + } +} + +// MARK: Int64 + +extension Int64: PostgresNonThrowingEncodable { + public static var psqlType: PostgresDataType { + .int8 + } + + public static var psqlFormat: PostgresFormat { + .binary + } + + @inlinable + public func encode( + into byteBuffer: inout ByteBuffer, + context: PostgresEncodingContext + ) { + byteBuffer.writeInteger(self, as: Int64.self) + } +} + +extension Int64: PostgresDecodable { + @inlinable + public init( + from buffer: inout ByteBuffer, + type: PostgresDataType, + format: PostgresFormat, + context: PostgresDecodingContext + ) throws { + switch (format, type) { + case (.binary, .int2): + guard buffer.readableBytes == 2, let value = buffer.readInteger(as: Int16.self) else { + throw PostgresDecodingError.Code.failure + } + self = Int64(value) + case (.binary, .int4): + guard buffer.readableBytes == 4, let value = buffer.readInteger(as: Int32.self) else { + throw PostgresDecodingError.Code.failure + } + self = Int64(value) + case (.binary, .int8): + guard buffer.readableBytes == 8, let value = buffer.readInteger(as: Int64.self) else { + throw PostgresDecodingError.Code.failure + } + self = value + case (.text, .int2), (.text, .int4), (.text, .int8): + guard let string = buffer.readString(length: buffer.readableBytes), let value = Int64(string) else { + throw PostgresDecodingError.Code.failure + } + self = value + default: + throw PostgresDecodingError.Code.typeMismatch + } + } +} + +// MARK: Int + +extension Int: PostgresNonThrowingEncodable { + public static var psqlType: PostgresDataType { + switch MemoryLayout.size { + case 4: + return .int4 + case 8: + return .int8 + default: + preconditionFailure("Int is expected to be an Int32 or Int64") + } + } + + public static var psqlFormat: PostgresFormat { + .binary + } + + @inlinable + public func encode( + into byteBuffer: inout ByteBuffer, + context: PostgresEncodingContext + ) { + byteBuffer.writeInteger(self, as: Int.self) + } +} + +extension Int: PostgresDecodable { + @inlinable + public init( + from buffer: inout ByteBuffer, + type: PostgresDataType, + format: PostgresFormat, + context: PostgresDecodingContext + ) throws { + switch (format, type) { + case (.binary, .int2): + guard buffer.readableBytes == 2, let value = buffer.readInteger(as: Int16.self) else { + throw PostgresDecodingError.Code.failure + } + self = Int(value) + case (.binary, .int4): + guard buffer.readableBytes == 4, let value = buffer.readInteger(as: Int32.self).flatMap({ Int(exactly: $0) }) else { + throw PostgresDecodingError.Code.failure + } + self = value + case (.binary, .int8): + guard buffer.readableBytes == 8, let value = buffer.readInteger(as: Int.self).flatMap({ Int(exactly: $0) }) else { + throw PostgresDecodingError.Code.failure + } + self = value + case (.text, .int2), (.text, .int4), (.text, .int8): + guard let string = buffer.readString(length: buffer.readableBytes), let value = Int(string) else { + throw PostgresDecodingError.Code.failure + } + self = value + default: + throw PostgresDecodingError.Code.typeMismatch + } + } +} diff --git a/Sources/PostgresNIO/New/Data/JSON+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/JSON+PostgresCodable.swift new file mode 100644 index 00000000..e469f0e5 --- /dev/null +++ b/Sources/PostgresNIO/New/Data/JSON+PostgresCodable.swift @@ -0,0 +1,47 @@ +import NIOCore +import NIOFoundationCompat +import class Foundation.JSONEncoder +import class Foundation.JSONDecoder + +@usableFromInline +let JSONBVersionByte: UInt8 = 0x01 + +extension PostgresEncodable where Self: Encodable { + public static var psqlType: PostgresDataType { + .jsonb + } + + public static var psqlFormat: PostgresFormat { + .binary + } + + @inlinable + public func encode( + into byteBuffer: inout ByteBuffer, + context: PostgresEncodingContext + ) throws { + byteBuffer.writeInteger(JSONBVersionByte) + try context.jsonEncoder.encode(self, into: &byteBuffer) + } +} + +extension PostgresDecodable where Self: Decodable { + public init( + from buffer: inout ByteBuffer, + type: PostgresDataType, + format: PostgresFormat, + context: PostgresDecodingContext + ) throws { + switch (format, type) { + case (.binary, .jsonb): + guard JSONBVersionByte == buffer.readInteger(as: UInt8.self) else { + throw PostgresDecodingError.Code.failure + } + self = try context.jsonDecoder.decode(Self.self, from: buffer) + case (.binary, .json), (.text, .jsonb), (.text, .json): + self = try context.jsonDecoder.decode(Self.self, from: buffer) + default: + throw PostgresDecodingError.Code.typeMismatch + } + } +} diff --git a/Sources/PostgresNIO/New/Data/Range+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/Range+PostgresCodable.swift new file mode 100644 index 00000000..6279cf4b --- /dev/null +++ b/Sources/PostgresNIO/New/Data/Range+PostgresCodable.swift @@ -0,0 +1,325 @@ +import NIOCore + +// MARK: Protocols + +/// A type that can be encoded into a Postgres range type where it is the bound type +public protocol PostgresRangeEncodable: PostgresNonThrowingEncodable { + static var psqlRangeType: PostgresDataType { get } +} + +/// A type that can be decoded into a Swift RangeExpression type from a Postgres range where it is the bound type +public protocol PostgresRangeDecodable: PostgresDecodable { + /// If a Postgres range type has a well-defined step, + /// Postgres automatically converts it to a canonical form. + /// Types such as `int4range` get converted to upper-bound-exclusive. + /// This method is needed when converting an upper bound to inclusive. + /// It should throw if the type lacks a well-defined step. + func upperBoundExclusiveToUpperBoundInclusive() throws -> Self + + /// Postgres does not store any bound values for empty ranges, + /// but Swift requires a value to initialize an empty Range. + static var valueForEmptyRange: Self { get } +} + +/// A type that can be encoded into a Postgres range array type where it is the bound type +public protocol PostgresRangeArrayEncodable: PostgresRangeEncodable { + static var psqlRangeArrayType: PostgresDataType { get } +} + +/// A type that can be decoded into a Swift RangeExpression array type from a Postgres range array where it is the bound type +public protocol PostgresRangeArrayDecodable: PostgresRangeDecodable {} + +// MARK: Bound conformances + +extension FixedWidthInteger where Self: PostgresRangeDecodable { + public func upperBoundExclusiveToUpperBoundInclusive() -> Self { + return self - 1 + } + + public static var valueForEmptyRange: Self { + return .zero + } +} + +extension Int32: PostgresRangeEncodable { + public static var psqlRangeType: PostgresDataType { return .int4Range } +} + +extension Int32: PostgresRangeDecodable {} + +extension Int32: PostgresRangeArrayEncodable { + public static var psqlRangeArrayType: PostgresDataType { return .int4RangeArray } +} + +extension Int32: PostgresRangeArrayDecodable {} + +extension Int64: PostgresRangeEncodable { + public static var psqlRangeType: PostgresDataType { return .int8Range } +} + +extension Int64: PostgresRangeDecodable {} + +extension Int64: PostgresRangeArrayEncodable { + public static var psqlRangeArrayType: PostgresDataType { return .int8RangeArray } +} + +extension Int64: PostgresRangeArrayDecodable {} + +// MARK: PostgresRange + +@usableFromInline +struct PostgresRange { + @usableFromInline let lowerBound: Bound? + @usableFromInline let upperBound: Bound? + @usableFromInline let isLowerBoundInclusive: Bool + @usableFromInline let isUpperBoundInclusive: Bool + + @inlinable + init( + lowerBound: Bound?, + upperBound: Bound?, + isLowerBoundInclusive: Bool, + isUpperBoundInclusive: Bool + ) { + self.lowerBound = lowerBound + self.upperBound = upperBound + self.isLowerBoundInclusive = isLowerBoundInclusive + self.isUpperBoundInclusive = isUpperBoundInclusive + } +} + +/// Used by Postgres to represent certain range properties +@usableFromInline +struct PostgresRangeFlag { + @usableFromInline static let isEmpty: UInt8 = 0x01 + @usableFromInline static let isLowerBoundInclusive: UInt8 = 0x02 + @usableFromInline static let isUpperBoundInclusive: UInt8 = 0x04 +} + +extension PostgresRange: PostgresDecodable where Bound: PostgresRangeDecodable { + @inlinable + init( + from byteBuffer: inout ByteBuffer, + type: PostgresDataType, + format: PostgresFormat, + context: PostgresDecodingContext + ) throws { + guard case .binary = format else { + throw PostgresDecodingError.Code.failure + } + + guard let boundType: PostgresDataType = type.boundType else { + throw PostgresDecodingError.Code.failure + } + + // flags byte contains certain properties of the range + guard let flags: UInt8 = byteBuffer.readInteger(as: UInt8.self) else { + throw PostgresDecodingError.Code.failure + } + + let isEmpty: Bool = flags & PostgresRangeFlag.isEmpty != 0 + if isEmpty { + self = PostgresRange( + lowerBound: Bound.valueForEmptyRange, + upperBound: Bound.valueForEmptyRange, + isLowerBoundInclusive: true, + isUpperBoundInclusive: false + ) + return + } + + guard let lowerBoundSize: Int32 = byteBuffer.readInteger(as: Int32.self), + Int(lowerBoundSize) == MemoryLayout.size, + var lowerBoundBytes: ByteBuffer = byteBuffer.readSlice(length: Int(lowerBoundSize)) + else { + throw PostgresDecodingError.Code.failure + } + + let lowerBound = try Bound(from: &lowerBoundBytes, type: boundType, format: format, context: context) + + guard let upperBoundSize = byteBuffer.readInteger(as: Int32.self), + Int(upperBoundSize) == MemoryLayout.size, + var upperBoundBytes: ByteBuffer = byteBuffer.readSlice(length: Int(upperBoundSize)) + else { + throw PostgresDecodingError.Code.failure + } + + let upperBound = try Bound(from: &upperBoundBytes, type: boundType, format: format, context: context) + + let isLowerBoundInclusive: Bool = flags & PostgresRangeFlag.isLowerBoundInclusive != 0 + let isUpperBoundInclusive: Bool = flags & PostgresRangeFlag.isUpperBoundInclusive != 0 + + self = PostgresRange( + lowerBound: lowerBound, + upperBound: upperBound, + isLowerBoundInclusive: isLowerBoundInclusive, + isUpperBoundInclusive: isUpperBoundInclusive + ) + + } +} + +extension PostgresRange: PostgresEncodable & PostgresNonThrowingEncodable where Bound: PostgresRangeEncodable { + @usableFromInline + static var psqlType: PostgresDataType { return Bound.psqlRangeType } + + @usableFromInline + static var psqlFormat: PostgresFormat { return .binary } + + @inlinable + func encode(into byteBuffer: inout ByteBuffer, context: PostgresEncodingContext) { + // flags byte contains certain properties of the range + var flags: UInt8 = 0 + if self.isLowerBoundInclusive { + flags |= PostgresRangeFlag.isLowerBoundInclusive + } + if self.isUpperBoundInclusive { + flags |= PostgresRangeFlag.isUpperBoundInclusive + } + + let boundMemorySize = Int32(MemoryLayout.size) + + byteBuffer.writeInteger(flags) + if let lowerBound = self.lowerBound { + byteBuffer.writeInteger(boundMemorySize) + lowerBound.encode(into: &byteBuffer, context: context) + } + if let upperBound = self.upperBound { + byteBuffer.writeInteger(boundMemorySize) + upperBound.encode(into: &byteBuffer, context: context) + } + } +} + +// explicitly conforming to PostgresDynamicTypeEncodable and PostgresThrowingDynamicTypeEncodable because of: +// https://github.com/apple/swift/issues/54132 +extension PostgresRange: PostgresThrowingDynamicTypeEncodable & PostgresDynamicTypeEncodable + where Bound: PostgresRangeEncodable {} + +extension PostgresRange where Bound: Comparable { + @inlinable + init(range: Range) { + self.lowerBound = range.lowerBound + self.upperBound = range.upperBound + self.isLowerBoundInclusive = true + self.isUpperBoundInclusive = false + } + + @inlinable + init(closedRange: ClosedRange) { + self.lowerBound = closedRange.lowerBound + self.upperBound = closedRange.upperBound + self.isLowerBoundInclusive = true + self.isUpperBoundInclusive = true + } +} + +// MARK: Range + +extension Range: PostgresEncodable where Bound: PostgresRangeEncodable { + public static var psqlType: PostgresDataType { return Bound.psqlRangeType } + public static var psqlFormat: PostgresFormat { return .binary } + + @inlinable + public func encode( + into byteBuffer: inout ByteBuffer, + context: PostgresEncodingContext + ) { + let postgresRange = PostgresRange(range: self) + postgresRange.encode(into: &byteBuffer, context: context) + } +} + +extension Range: PostgresNonThrowingEncodable where Bound: PostgresRangeEncodable {} + +// explicitly conforming to PostgresDynamicTypeEncodable and PostgresThrowingDynamicTypeEncodable because of: +// https://github.com/apple/swift/issues/54132 +extension Range: PostgresDynamicTypeEncodable & PostgresThrowingDynamicTypeEncodable + where Bound: PostgresRangeEncodable {} + +extension Range: PostgresDecodable where Bound: PostgresRangeDecodable { + @inlinable + public init( + from buffer: inout ByteBuffer, + type: PostgresDataType, + format: PostgresFormat, + context: PostgresDecodingContext + ) throws { + let postgresRange = try PostgresRange( + from: &buffer, + type: type, + format: format, + context: context + ) + + guard let lowerBound: Bound = postgresRange.lowerBound, + let upperBound: Bound = postgresRange.upperBound, + postgresRange.isLowerBoundInclusive, + !postgresRange.isUpperBoundInclusive + else { + throw PostgresDecodingError.Code.failure + } + + self = lowerBound..( + into byteBuffer: inout ByteBuffer, + context: PostgresEncodingContext + ) { + let postgresRange = PostgresRange(closedRange: self) + postgresRange.encode(into: &byteBuffer, context: context) + } +} + +// explicitly conforming to PostgresThrowingDynamicTypeEncodable because of: +// https://github.com/apple/swift/issues/54132 +extension ClosedRange: PostgresThrowingDynamicTypeEncodable where Bound: PostgresRangeEncodable {} + +extension ClosedRange: PostgresNonThrowingEncodable where Bound: PostgresRangeEncodable {} + +// explicitly conforming to PostgresDynamicTypeEncodable because of: +// https://github.com/apple/swift/issues/54132 +extension ClosedRange: PostgresDynamicTypeEncodable where Bound: PostgresRangeEncodable {} + +extension ClosedRange: PostgresDecodable where Bound: PostgresRangeDecodable { + @inlinable + public init( + from buffer: inout ByteBuffer, + type: PostgresDataType, + format: PostgresFormat, + context: PostgresDecodingContext + ) throws { + let postgresRange = try PostgresRange( + from: &buffer, + type: type, + format: format, + context: context + ) + + guard let lowerBound: Bound = postgresRange.lowerBound, + var upperBound: Bound = postgresRange.upperBound, + postgresRange.isLowerBoundInclusive + else { + throw PostgresDecodingError.Code.failure + } + + if !postgresRange.isUpperBoundInclusive { + upperBound = try upperBound.upperBoundExclusiveToUpperBoundInclusive() + } + + if lowerBound > upperBound { + throw PostgresDecodingError.Code.failure + } + + self = lowerBound...upperBound + } +} diff --git a/Sources/PostgresNIO/New/Data/RawRepresentable+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/RawRepresentable+PostgresCodable.swift new file mode 100644 index 00000000..ea097963 --- /dev/null +++ b/Sources/PostgresNIO/New/Data/RawRepresentable+PostgresCodable.swift @@ -0,0 +1,35 @@ +import NIOCore + +extension PostgresEncodable where Self: RawRepresentable, RawValue: PostgresEncodable { + public static var psqlType: PostgresDataType { + RawValue.psqlType + } + + public static var psqlFormat: PostgresFormat { + RawValue.psqlFormat + } + + @inlinable + public func encode( + into byteBuffer: inout ByteBuffer, + context: PostgresEncodingContext + ) throws { + try rawValue.encode(into: &byteBuffer, context: context) + } +} + +extension PostgresDecodable where Self: RawRepresentable, RawValue: PostgresDecodable, RawValue._DecodableType == RawValue { + public init( + from buffer: inout ByteBuffer, + type: PostgresDataType, + format: PostgresFormat, + context: PostgresDecodingContext + ) throws { + guard let rawValue = try? RawValue(from: &buffer, type: type, format: format, context: context), + let selfValue = Self.init(rawValue: rawValue) else { + throw PostgresDecodingError.Code.failure + } + + self = selfValue + } +} diff --git a/Sources/PostgresNIO/New/Data/String+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/String+PostgresCodable.swift new file mode 100644 index 00000000..41091ab3 --- /dev/null +++ b/Sources/PostgresNIO/New/Data/String+PostgresCodable.swift @@ -0,0 +1,48 @@ +import NIOCore +import struct Foundation.UUID + +extension String: PostgresNonThrowingEncodable { + public static var psqlType: PostgresDataType { + .text + } + + public static var psqlFormat: PostgresFormat { + .binary + } + + @inlinable + public func encode( + into byteBuffer: inout ByteBuffer, + context: PostgresEncodingContext + ) { + byteBuffer.writeString(self) + } +} + +extension String: PostgresDecodable { + + @inlinable + public init( + from buffer: inout ByteBuffer, + type: PostgresDataType, + format: PostgresFormat, + context: PostgresDecodingContext + ) throws { + switch (format, type) { + case (_, .varchar), + (_, .bpchar), + (_, .text), + (_, .name): + // we can force unwrap here, since this method only fails if there are not enough + // bytes available. + self = buffer.readString(length: buffer.readableBytes)! + case (_, .uuid): + guard let uuid = try? UUID(from: &buffer, type: .uuid, format: format, context: context) else { + throw PostgresDecodingError.Code.failure + } + self = uuid.uuidString + default: + throw PostgresDecodingError.Code.typeMismatch + } + } +} diff --git a/Sources/PostgresNIO/New/Data/UUID+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/UUID+PostgresCodable.swift new file mode 100644 index 00000000..1de0f394 --- /dev/null +++ b/Sources/PostgresNIO/New/Data/UUID+PostgresCodable.swift @@ -0,0 +1,56 @@ +import NIOCore +import NIOFoundationCompat +import struct Foundation.UUID +import typealias Foundation.uuid_t +import NIOFoundationCompat + +extension UUID: PostgresNonThrowingEncodable { + public static var psqlType: PostgresDataType { + .uuid + } + + public static var psqlFormat: PostgresFormat { + .binary + } + + @inlinable + public func encode( + into byteBuffer: inout ByteBuffer, + context: PostgresEncodingContext + ) { + byteBuffer.writeUUIDBytes(self) + } +} + +extension UUID: PostgresDecodable { + @inlinable + public init( + from buffer: inout ByteBuffer, + type: PostgresDataType, + format: PostgresFormat, + context: PostgresDecodingContext + ) throws { + switch (format, type) { + case (.binary, .uuid): + guard let uuid = buffer.readUUIDBytes() else { + throw PostgresDecodingError.Code.failure + } + self = uuid + case (.binary, .varchar), + (.binary, .text), + (.text, .uuid), + (.text, .text), + (.text, .varchar): + guard buffer.readableBytes == 36 else { + throw PostgresDecodingError.Code.failure + } + + guard let uuid = buffer.readString(length: 36).flatMap({ UUID(uuidString: $0) }) else { + throw PostgresDecodingError.Code.failure + } + self = uuid + default: + throw PostgresDecodingError.Code.typeMismatch + } + } +} diff --git a/Sources/PostgresNIO/New/Extensions/ByteBuffer+PSQL.swift b/Sources/PostgresNIO/New/Extensions/ByteBuffer+PSQL.swift new file mode 100644 index 00000000..838e624d --- /dev/null +++ b/Sources/PostgresNIO/New/Extensions/ByteBuffer+PSQL.swift @@ -0,0 +1,24 @@ +import NIOCore + +internal extension ByteBuffer { + + @usableFromInline + mutating func psqlReadFloat() -> Float? { + return self.readInteger(as: UInt32.self).map { Float(bitPattern: $0) } + } + + @usableFromInline + mutating func psqlReadDouble() -> Double? { + return self.readInteger(as: UInt64.self).map { Double(bitPattern: $0) } + } + + @usableFromInline + mutating func psqlWriteFloat(_ float: Float) { + self.writeInteger(float.bitPattern) + } + + @usableFromInline + mutating func psqlWriteDouble(_ double: Double) { + self.writeInteger(double.bitPattern) + } +} diff --git a/Sources/PostgresNIO/New/Extensions/Logging+PSQL.swift b/Sources/PostgresNIO/New/Extensions/Logging+PSQL.swift new file mode 100644 index 00000000..97c729f0 --- /dev/null +++ b/Sources/PostgresNIO/New/Extensions/Logging+PSQL.swift @@ -0,0 +1,147 @@ +import Logging + +@usableFromInline +enum PSQLConnection {} + +extension PSQLConnection { + @usableFromInline + enum LoggerMetaDataKey: String { + case connectionID = "psql_connection_id" + case query = "psql_query" + case name = "psql_name" + case error = "psql_error" + case notice = "psql_notice" + case binds = "psql_binds" + case commandTag = "psql_command_tag" + + case connectionState = "psql_connection_state" + case connectionAction = "psql_connection_action" + case message = "psql_message" + case messageID = "psql_message_id" + case messagePayload = "psql_message_payload" + + + case database = "psql_database" + case username = "psql_username" + + case userEvent = "psql_user_event" + } +} + +@usableFromInline +struct PSQLLoggingMetadata: ExpressibleByDictionaryLiteral { + @usableFromInline + typealias Key = PSQLConnection.LoggerMetaDataKey + @usableFromInline + typealias Value = Logger.MetadataValue + + @usableFromInline var _baseRepresentation: Logger.Metadata + + @usableFromInline + init(dictionaryLiteral elements: (PSQLConnection.LoggerMetaDataKey, Logger.MetadataValue)...) { + let values = elements.lazy.map { (key, value) -> (String, Self.Value) in + (key.rawValue, value) + } + + self._baseRepresentation = Logger.Metadata(uniqueKeysWithValues: values) + } + + @usableFromInline + subscript(postgresLoggingKey loggingKey: PSQLConnection.LoggerMetaDataKey) -> Logger.Metadata.Value? { + get { + return self._baseRepresentation[loggingKey.rawValue] + } + set { + self._baseRepresentation[loggingKey.rawValue] = newValue + } + } + + @inlinable + var representation: Logger.Metadata { + self._baseRepresentation + } +} + + +extension Logger { + + static let psqlNoOpLogger = Logger(label: "psql_do_not_log", factory: { _ in SwiftLogNoOpLogHandler() }) + + @usableFromInline + subscript(postgresMetadataKey metadataKey: PSQLConnection.LoggerMetaDataKey) -> Logger.Metadata.Value? { + get { + return self[metadataKey: metadataKey.rawValue] + } + set { + self[metadataKey: metadataKey.rawValue] = newValue + } + } + +} + +extension Logger { + + /// See `Logger.trace(_:metadata:source:file:function:line:)` + @usableFromInline + func trace(_ message: @autoclosure () -> Logger.Message, + metadata: @autoclosure () -> PSQLLoggingMetadata, + source: @autoclosure () -> String? = nil, + file: String = #fileID, function: String = #function, line: UInt = #line) { + self.log(level: .trace, message(), metadata: metadata().representation, source: source(), file: file, function: function, line: line) + } + + /// See `Logger.debug(_:metadata:source:file:function:line:)` + @usableFromInline + func debug(_ message: @autoclosure () -> Logger.Message, + metadata: @autoclosure () -> PSQLLoggingMetadata, + source: @autoclosure () -> String? = nil, + file: String = #fileID, function: String = #function, line: UInt = #line) { + self.log(level: .debug, message(), metadata: metadata().representation, source: source(), file: file, function: function, line: line) + } + + /// See `Logger.info(_:metadata:source:file:function:line:)` + @usableFromInline + func info(_ message: @autoclosure () -> Logger.Message, + metadata: @autoclosure () -> PSQLLoggingMetadata, + source: @autoclosure () -> String? = nil, + file: String = #fileID, function: String = #function, line: UInt = #line) { + self.log(level: .info, message(), metadata: metadata().representation, source: source(), file: file, function: function, line: line) + } + + /// See `Logger.notice(_:metadata:source:file:function:line:)` + @usableFromInline + func notice(_ message: @autoclosure () -> Logger.Message, + metadata: @autoclosure () -> PSQLLoggingMetadata, + source: @autoclosure () -> String? = nil, + file: String = #fileID, function: String = #function, line: UInt = #line) { + self.log(level: .notice, message(), metadata: metadata().representation, source: source(), file: file, function: function, line: line) + } + + /// See `Logger.warning(_:metadata:source:file:function:line:)` + @usableFromInline + func warning(_ message: @autoclosure () -> Logger.Message, + metadata: @autoclosure () -> PSQLLoggingMetadata, + source: @autoclosure () -> String? = nil, + file: String = #fileID, function: String = #function, line: UInt = #line) { + self.log(level: .warning, message(), metadata: metadata().representation, source: source(), file: file, function: function, line: line) + } + + /// See `Logger.error(_:metadata:source:file:function:line:)` + @usableFromInline + func error(_ message: @autoclosure () -> Logger.Message, + metadata: @autoclosure () -> PSQLLoggingMetadata, + source: @autoclosure () -> String? = nil, + file: String = #fileID, function: String = #function, line: UInt = #line) { + self.log(level: .error, message(), metadata: metadata().representation, source: source(), file: file, function: function, line: line) + } + + /// See `Logger.critical(_:metadata:source:file:function:line:)` + @usableFromInline + func critical(_ message: @autoclosure () -> Logger.Message, + metadata: @autoclosure () -> PSQLLoggingMetadata, + source: @autoclosure () -> String? = nil, + file: String = #fileID, function: String = #function, line: UInt = #line) { + self.log(level: .critical, message(), metadata: metadata().representation, source: source(), file: file, function: function, line: line) + } +} + diff --git a/Sources/PostgresNIO/New/Messages/Authentication.swift b/Sources/PostgresNIO/New/Messages/Authentication.swift new file mode 100644 index 00000000..eff62e91 --- /dev/null +++ b/Sources/PostgresNIO/New/Messages/Authentication.swift @@ -0,0 +1,91 @@ +import NIOCore + +extension PostgresBackendMessage { + + enum Authentication: PayloadDecodable, Hashable { + case ok + case kerberosV5 + case md5(salt: UInt32) + case plaintext + case scmCredential + case gss + case sspi + case gssContinue(data: ByteBuffer) + case sasl(names: [String]) + case saslContinue(data: ByteBuffer) + case saslFinal(data: ByteBuffer) + + static func decode(from buffer: inout ByteBuffer) throws -> Self { + let authID = try buffer.throwingReadInteger(as: Int32.self) + + switch authID { + case 0: + return .ok + case 2: + return .kerberosV5 + case 3: + return .plaintext + case 5: + guard let salt = buffer.readInteger(as: UInt32.self) else { + throw PSQLPartialDecodingError.expectedAtLeastNRemainingBytes(4, actual: buffer.readableBytes) + } + return .md5(salt: salt) + case 6: + return .scmCredential + case 7: + return .gss + case 8: + let data = buffer.readSlice(length: buffer.readableBytes)! + return .gssContinue(data: data) + case 9: + return .sspi + case 10: + var names = [String]() + let endIndex = buffer.readerIndex + buffer.readableBytes + while buffer.readerIndex < endIndex, let next = buffer.readNullTerminatedString() { + names.append(next) + } + + return .sasl(names: names) + case 11: + let data = buffer.readSlice(length: buffer.readableBytes)! + return .saslContinue(data: data) + case 12: + let data = buffer.readSlice(length: buffer.readableBytes)! + return .saslFinal(data: data) + default: + throw PSQLPartialDecodingError.unexpectedValue(value: authID) + } + } + + } +} + +extension PostgresBackendMessage.Authentication: CustomDebugStringConvertible { + var debugDescription: String { + switch self { + case .ok: + return ".ok" + case .kerberosV5: + return ".kerberosV5" + case .md5(salt: let salt): + return ".md5(salt: \(String(reflecting: salt)))" + case .plaintext: + return ".plaintext" + case .scmCredential: + return ".scmCredential" + case .gss: + return ".gss" + case .sspi: + return ".sspi" + case .gssContinue(data: let data): + return ".gssContinue(data: \(String(reflecting: data)))" + case .sasl(names: let names): + return ".sasl(names: \(String(reflecting: names)))" + case .saslContinue(data: let data): + return ".saslContinue(salt: \(String(reflecting: data)))" + case .saslFinal(data: let data): + return ".saslFinal(salt: \(String(reflecting: data)))" + } + } +} diff --git a/Sources/PostgresNIO/New/Messages/BackendKeyData.swift b/Sources/PostgresNIO/New/Messages/BackendKeyData.swift new file mode 100644 index 00000000..31a676d2 --- /dev/null +++ b/Sources/PostgresNIO/New/Messages/BackendKeyData.swift @@ -0,0 +1,23 @@ +import NIOCore + +extension PostgresBackendMessage { + + struct BackendKeyData: PayloadDecodable, Hashable { + let processID: Int32 + let secretKey: Int32 + + static func decode(from buffer: inout ByteBuffer) throws -> Self { + guard let (processID, secretKey) = buffer.readMultipleIntegers(endianness: .big, as: (Int32, Int32).self) else { + throw PSQLPartialDecodingError.expectedAtLeastNRemainingBytes(8, actual: buffer.readableBytes) + } + + return .init(processID: processID, secretKey: secretKey) + } + } +} + +extension PostgresBackendMessage.BackendKeyData: CustomDebugStringConvertible { + var debugDescription: String { + "processID: \(processID), secretKey: \(secretKey)" + } +} diff --git a/Sources/PostgresNIO/New/Messages/DataRow.swift b/Sources/PostgresNIO/New/Messages/DataRow.swift new file mode 100644 index 00000000..491e10dc --- /dev/null +++ b/Sources/PostgresNIO/New/Messages/DataRow.swift @@ -0,0 +1,118 @@ +import NIOCore + +/// A backend data row message. +/// +/// - NOTE: This struct is not part of the ``PSQLBackendMessage`` namespace even +/// though this is where it actually belongs. The reason for this is, that we want +/// this type to be @usableFromInline. If a type is made @usableFromInline in an +/// enclosing type, the enclosing type must be @usableFromInline as well. +/// Not putting `DataRow` in ``PSQLBackendMessage`` is our way to trick +/// the Swift compiler +@usableFromInline +struct DataRow: Sendable, PostgresBackendMessage.PayloadDecodable, Hashable { + @usableFromInline + var columnCount: Int16 + @usableFromInline + var bytes: ByteBuffer + + static func decode(from buffer: inout ByteBuffer) throws -> Self { + let columnCount = try buffer.throwingReadInteger(as: Int16.self) + let firstColumnIndex = buffer.readerIndex + + for _ in 0..= 0 else { + // if buffer length is negative, this means that the value is null + continue + } + + try buffer.throwingMoveReaderIndex(forwardBy: Int(bufferLength)) + } + + buffer.moveReaderIndex(to: firstColumnIndex) + let columnSlice = buffer.readSlice(length: buffer.readableBytes)! + return DataRow(columnCount: columnCount, bytes: columnSlice) + } +} + +extension DataRow: Sequence { + @usableFromInline + typealias Element = ByteBuffer? +} + +extension DataRow: Collection { + + @usableFromInline + struct ColumnIndex: Comparable { + @usableFromInline + var offset: Int + + @inlinable + init(_ index: Int) { + self.offset = index + } + + // Only needed implementation for comparable. The compiler synthesizes the rest from this. + @inlinable + static func < (lhs: Self, rhs: Self) -> Bool { + lhs.offset < rhs.offset + } + } + + @usableFromInline + typealias Index = DataRow.ColumnIndex + + @inlinable + var startIndex: ColumnIndex { + ColumnIndex(self.bytes.readerIndex) + } + + @inlinable + var endIndex: ColumnIndex { + ColumnIndex(self.bytes.readerIndex + self.bytes.readableBytes) + } + + @inlinable + var count: Int { + Int(self.columnCount) + } + + @inlinable + func index(after index: ColumnIndex) -> ColumnIndex { + guard index < self.endIndex else { + preconditionFailure("index out of bounds") + } + var elementLength = Int(self.bytes.getInteger(at: index.offset, as: Int32.self)!) + if elementLength < 0 { + elementLength = 0 + } + return ColumnIndex(index.offset + MemoryLayout.size + elementLength) + } + + @inlinable + subscript(index: ColumnIndex) -> Element { + guard index < self.endIndex else { + preconditionFailure("index out of bounds") + } + let elementLength = Int(self.bytes.getInteger(at: index.offset, as: Int32.self)!) + if elementLength < 0 { + return nil + } + return self.bytes.getSlice(at: index.offset + MemoryLayout.size, length: elementLength)! + } +} + +extension DataRow { + subscript(column index: Int) -> Element { + guard index < self.columnCount else { + preconditionFailure("index out of bounds") + } + + var byteIndex = self.startIndex + for _ in 0.. Self { + var fields: [PostgresBackendMessage.Field: String] = [:] + while let id = buffer.readInteger(as: UInt8.self) { + if id == 0 { + break + } + guard let field = PostgresBackendMessage.Field(rawValue: id) else { + throw PSQLPartialDecodingError.valueNotRawRepresentable( + value: id, + asType: PostgresBackendMessage.Field.self) + } + + guard let string = buffer.readNullTerminatedString() else { + throw PSQLPartialDecodingError.fieldNotDecodable(type: String.self) + } + fields[field] = string + } + return Self.init(fields: fields) + } +} + +extension PostgresBackendMessage.Field: CustomStringConvertible { + + var description: String { + switch self { + case .localizedSeverity: + return "Localized Severity" + case .severity: + return "Severity" + case .sqlState: + return "Code" + case .message: + return "Message" + case .detail: + return "Detail" + case .hint: + return "Hint" + case .position: + return "Position" + case .internalPosition: + return "Internal position" + case .internalQuery: + return "Internal query" + case .locationContext: + return "Where" + case .schemaName: + return "Schema name" + case .tableName: + return "Table name" + case .columnName: + return "Column name" + case .dataTypeName: + return "Data type name" + case .constraintName: + return "Constraint name" + case .file: + return "File" + case .line: + return "Line" + case .routine: + return "Routine" + } + } +} diff --git a/Sources/PostgresNIO/New/Messages/NotificationResponse.swift b/Sources/PostgresNIO/New/Messages/NotificationResponse.swift new file mode 100644 index 00000000..01b9ab4a --- /dev/null +++ b/Sources/PostgresNIO/New/Messages/NotificationResponse.swift @@ -0,0 +1,23 @@ +import NIOCore + +extension PostgresBackendMessage { + + struct NotificationResponse: PayloadDecodable, Hashable { + let backendPID: Int32 + let channel: String + let payload: String + + static func decode(from buffer: inout ByteBuffer) throws -> PostgresBackendMessage.NotificationResponse { + let backendPID = try buffer.throwingReadInteger(as: Int32.self) + + guard let channel = buffer.readNullTerminatedString() else { + throw PSQLPartialDecodingError.fieldNotDecodable(type: String.self) + } + guard let payload = buffer.readNullTerminatedString() else { + throw PSQLPartialDecodingError.fieldNotDecodable(type: String.self) + } + + return NotificationResponse(backendPID: backendPID, channel: channel, payload: payload) + } + } +} diff --git a/Sources/PostgresNIO/New/Messages/ParameterDescription.swift b/Sources/PostgresNIO/New/Messages/ParameterDescription.swift new file mode 100644 index 00000000..4d12b1b6 --- /dev/null +++ b/Sources/PostgresNIO/New/Messages/ParameterDescription.swift @@ -0,0 +1,23 @@ +import NIOCore + +extension PostgresBackendMessage { + + struct ParameterDescription: PayloadDecodable, Hashable { + /// Specifies the object ID of the parameter data type. + var dataTypes: [PostgresDataType] + + static func decode(from buffer: inout ByteBuffer) throws -> Self { + let parameterCount = try buffer.throwingReadInteger(as: UInt16.self) + + var result = [PostgresDataType]() + result.reserveCapacity(Int(parameterCount)) + + for _ in 0.. Self { + guard let name = buffer.readNullTerminatedString() else { + throw PSQLPartialDecodingError.fieldNotDecodable(type: String.self) + } + + guard let value = buffer.readNullTerminatedString() else { + throw PSQLPartialDecodingError.fieldNotDecodable(type: String.self) + } + + return ParameterStatus(parameter: name, value: value) + } + } +} + +extension PostgresBackendMessage.ParameterStatus: CustomDebugStringConvertible { + var debugDescription: String { + "parameter: \(String(reflecting: self.parameter)), value: \(String(reflecting: self.value))" + } +} + diff --git a/Sources/PostgresNIO/New/Messages/ReadyForQuery.swift b/Sources/PostgresNIO/New/Messages/ReadyForQuery.swift new file mode 100644 index 00000000..41af1b60 --- /dev/null +++ b/Sources/PostgresNIO/New/Messages/ReadyForQuery.swift @@ -0,0 +1,31 @@ +import NIOCore + +extension PostgresBackendMessage { + enum TransactionState: UInt8, PayloadDecodable, Hashable { + case idle = 73 // ascii: I + case inTransaction = 84 // ascii: T + case inFailedTransaction = 69 // ascii: E + + static func decode(from buffer: inout ByteBuffer) throws -> Self { + let value = try buffer.throwingReadInteger(as: UInt8.self) + guard let state = Self.init(rawValue: value) else { + throw PSQLPartialDecodingError.valueNotRawRepresentable(value: value, asType: TransactionState.self) + } + + return state + } + } +} + +extension PostgresBackendMessage.TransactionState: CustomDebugStringConvertible { + var debugDescription: String { + switch self { + case .idle: + return ".idle" + case .inTransaction: + return ".inTransaction" + case .inFailedTransaction: + return ".inFailedTransaction" + } + } +} diff --git a/Sources/PostgresNIO/New/Messages/RowDescription.swift b/Sources/PostgresNIO/New/Messages/RowDescription.swift new file mode 100644 index 00000000..766d06e9 --- /dev/null +++ b/Sources/PostgresNIO/New/Messages/RowDescription.swift @@ -0,0 +1,88 @@ +import NIOCore + +/// A backend row description message. +/// +/// - NOTE: This struct is not part of the ``PSQLBackendMessage`` namespace even +/// though this is where it actually belongs. The reason for this is, that we want +/// this type to be @usableFromInline. If a type is made @usableFromInline in an +/// enclosing type, the enclosing type must be @usableFromInline as well. +/// Not putting `DataRow` in ``PSQLBackendMessage`` is our way to trick +/// the Swift compiler. +@usableFromInline +struct RowDescription: PostgresBackendMessage.PayloadDecodable, Sendable, Hashable { + /// Specifies the object ID of the parameter data type. + @usableFromInline + var columns: [Column] + + @usableFromInline + struct Column: Hashable, Sendable { + /// The field name. + @usableFromInline + var name: String + + /// If the field can be identified as a column of a specific table, the object ID of the table; otherwise zero. + @usableFromInline + var tableOID: Int32 + + /// If the field can be identified as a column of a specific table, the attribute number of the column; otherwise zero. + @usableFromInline + var columnAttributeNumber: Int16 + + /// The object ID of the field's data type. + @usableFromInline + var dataType: PostgresDataType + + /// The data type size (see pg_type.typlen). Note that negative values denote variable-width types. + @usableFromInline + var dataTypeSize: Int16 + + /// The type modifier (see pg_attribute.atttypmod). The meaning of the modifier is type-specific. + @usableFromInline + var dataTypeModifier: Int32 + + /// The format being used for the field. Currently will be text or binary. In a RowDescription returned + /// from the statement variant of Describe, the format code is not yet known and will always be text. + @usableFromInline + var format: PostgresFormat + } + + static func decode(from buffer: inout ByteBuffer) throws -> Self { + let columnCount = try buffer.throwingReadInteger(as: Int16.self) + + guard columnCount >= 0 else { + throw PSQLPartialDecodingError.integerMustBePositiveOrNull(columnCount) + } + + var result = [Column]() + result.reserveCapacity(Int(columnCount)) + + for _ in 0..) + case streamListening(AsyncThrowingStream.Continuation) + + case closure(PostgresListenContext, (PostgresListenContext, PostgresMessage.NotificationResponse) -> Void) + case done + } + + init( + channel: String, + id: Int, + eventLoop: EventLoop, + checkedContinuation: CheckedContinuation + ) { + self.channel = channel + self.id = id + self.eventLoop = eventLoop + self.state = .streamInitialized(checkedContinuation) + } + + init( + channel: String, + id: Int, + eventLoop: EventLoop, + context: PostgresListenContext, + closure: @Sendable @escaping (PostgresListenContext, PostgresMessage.NotificationResponse) -> Void + ) { + self.channel = channel + self.id = id + self.eventLoop = eventLoop + self.state = .closure(context, closure) + } + + func startListeningSucceeded(handler: PostgresChannelHandler) { + self.eventLoop.preconditionInEventLoop() + let handlerLoopBound = NIOLoopBound(handler, eventLoop: self.eventLoop) + + switch self.state { + case .streamInitialized(let checkedContinuation): + let (stream, continuation) = AsyncThrowingStream.makeStream(of: PostgresNotification.self) + let eventLoop = self.eventLoop + let channel = self.channel + let listenerID = self.id + continuation.onTermination = { reason in + switch reason { + case .cancelled: + eventLoop.execute { + handlerLoopBound.value.cancelNotificationListener(channel: channel, id: listenerID) + } + + case .finished: + break + + @unknown default: + break + } + } + self.state = .streamListening(continuation) + + let notificationSequence = PostgresNotificationSequence(base: stream) + checkedContinuation.resume(returning: notificationSequence) + + case .streamListening, .done: + fatalError("Invalid state: \(self.state)") + + case .closure: + break // ignore + } + } + + func notificationReceived(_ backendMessage: PostgresBackendMessage.NotificationResponse) { + self.eventLoop.preconditionInEventLoop() + + switch self.state { + case .streamInitialized, .done: + fatalError("Invalid state: \(self.state)") + case .streamListening(let continuation): + continuation.yield(.init(payload: backendMessage.payload)) + + case .closure(let postgresListenContext, let closure): + let message = PostgresMessage.NotificationResponse( + backendPID: backendMessage.backendPID, + channel: backendMessage.channel, + payload: backendMessage.payload + ) + closure(postgresListenContext, message) + } + } + + func failed(_ error: Error) { + self.eventLoop.preconditionInEventLoop() + + switch self.state { + case .streamInitialized(let checkedContinuation): + self.state = .done + checkedContinuation.resume(throwing: error) + + case .streamListening(let continuation): + self.state = .done + continuation.finish(throwing: error) + + case .closure(let postgresListenContext, _): + self.state = .done + postgresListenContext.cancel() + + case .done: + break // ignore + } + } + + func cancelled() { + self.eventLoop.preconditionInEventLoop() + + switch self.state { + case .streamInitialized(let checkedContinuation): + self.state = .done + checkedContinuation.resume(throwing: PSQLError(code: .queryCancelled)) + + case .streamListening(let continuation): + self.state = .done + continuation.finish() + + case .closure(let postgresListenContext, _): + self.state = .done + postgresListenContext.cancel() + + case .done: + break // ignore + } + } +} diff --git a/Sources/PostgresNIO/New/PSQLError.swift b/Sources/PostgresNIO/New/PSQLError.swift new file mode 100644 index 00000000..4a9f9216 --- /dev/null +++ b/Sources/PostgresNIO/New/PSQLError.swift @@ -0,0 +1,650 @@ +import NIOCore + +/// An error that is thrown from the PostgresClient. +/// Sendability enforced through Copy on Write semantics +public struct PSQLError: Error, @unchecked Sendable { + + public struct Code: Sendable, Hashable, CustomStringConvertible { + enum Base: Sendable, Hashable { + case sslUnsupported + case failedToAddSSLHandler + case receivedUnencryptedDataAfterSSLRequest + case server + case messageDecodingFailure + case unexpectedBackendMessage + case unsupportedAuthMechanism + case authMechanismRequiresPassword + case saslError + case invalidCommandTag + + case queryCancelled + case tooManyParameters + case clientClosedConnection + case serverClosedConnection + case connectionError + case uncleanShutdown + + case listenFailed + case unlistenFailed + case poolClosed + } + + internal var base: Base + + private init(_ base: Base) { + self.base = base + } + + public static let sslUnsupported = Self(.sslUnsupported) + public static let failedToAddSSLHandler = Self(.failedToAddSSLHandler) + public static let receivedUnencryptedDataAfterSSLRequest = Self(.receivedUnencryptedDataAfterSSLRequest) + public static let server = Self(.server) + public static let messageDecodingFailure = Self(.messageDecodingFailure) + public static let unexpectedBackendMessage = Self(.unexpectedBackendMessage) + public static let unsupportedAuthMechanism = Self(.unsupportedAuthMechanism) + public static let authMechanismRequiresPassword = Self(.authMechanismRequiresPassword) + public static let saslError = Self(.saslError) + public static let invalidCommandTag = Self(.invalidCommandTag) + public static let queryCancelled = Self(.queryCancelled) + public static let tooManyParameters = Self(.tooManyParameters) + public static let clientClosedConnection = Self(.clientClosedConnection) + public static let serverClosedConnection = Self(.serverClosedConnection) + public static let connectionError = Self(.connectionError) + + public static let uncleanShutdown = Self(.uncleanShutdown) + public static let poolClosed = Self(.poolClosed) + + public static let listenFailed = Self.init(.listenFailed) + public static let unlistenFailed = Self.init(.unlistenFailed) + + @available(*, deprecated, renamed: "clientClosedConnection") + public static let connectionQuiescing = Self.clientClosedConnection + + @available(*, deprecated, message: "Use the more specific `serverClosedConnection` or `clientClosedConnection` instead") + public static let connectionClosed = Self.serverClosedConnection + + public var description: String { + switch self.base { + case .sslUnsupported: + return "sslUnsupported" + case .failedToAddSSLHandler: + return "failedToAddSSLHandler" + case .receivedUnencryptedDataAfterSSLRequest: + return "receivedUnencryptedDataAfterSSLRequest" + case .server: + return "server" + case .messageDecodingFailure: + return "messageDecodingFailure" + case .unexpectedBackendMessage: + return "unexpectedBackendMessage" + case .unsupportedAuthMechanism: + return "unsupportedAuthMechanism" + case .authMechanismRequiresPassword: + return "authMechanismRequiresPassword" + case .saslError: + return "saslError" + case .invalidCommandTag: + return "invalidCommandTag" + case .queryCancelled: + return "queryCancelled" + case .tooManyParameters: + return "tooManyParameters" + case .clientClosedConnection: + return "clientClosedConnection" + case .serverClosedConnection: + return "serverClosedConnection" + case .connectionError: + return "connectionError" + case .uncleanShutdown: + return "uncleanShutdown" + case .poolClosed: + return "poolClosed" + case .listenFailed: + return "listenFailed" + case .unlistenFailed: + return "unlistenFailed" + } + } + } + + private var backing: Backing + + private mutating func copyBackingStorageIfNecessary() { + if !isKnownUniquelyReferenced(&self.backing) { + self.backing = self.backing.copy() + } + } + + /// The ``PSQLError/Code-swift.struct`` code + public internal(set) var code: Code { + get { self.backing.code } + set { + self.copyBackingStorageIfNecessary() + self.backing.code = newValue + } + } + + /// The info that was received from the server + public internal(set) var serverInfo: ServerInfo? { + get { self.backing.serverInfo } + set { + self.copyBackingStorageIfNecessary() + self.backing.serverInfo = newValue + } + } + + /// The underlying error + public internal(set) var underlying: Error? { + get { self.backing.underlying } + set { + self.copyBackingStorageIfNecessary() + self.backing.underlying = newValue + } + } + + /// The file in which the Postgres operation was triggered that failed + public internal(set) var file: String? { + get { self.backing.file } + set { + self.copyBackingStorageIfNecessary() + self.backing.file = newValue + } + } + + /// The line in which the Postgres operation was triggered that failed + public internal(set) var line: Int? { + get { self.backing.line } + set { + self.copyBackingStorageIfNecessary() + self.backing.line = newValue + } + } + + /// The query that failed + public internal(set) var query: PostgresQuery? { + get { self.backing.query } + set { + self.copyBackingStorageIfNecessary() + self.backing.query = newValue + } + } + + /// the backend message... we should keep this internal but we can use it to print more + /// advanced debug reasons. + var backendMessage: PostgresBackendMessage? { + get { self.backing.backendMessage } + set { + self.copyBackingStorageIfNecessary() + self.backing.backendMessage = newValue + } + } + + /// the unsupported auth scheme... we should keep this internal but we can use it to print more + /// advanced debug reasons. + var unsupportedAuthScheme: UnsupportedAuthScheme? { + get { self.backing.unsupportedAuthScheme } + set { + self.copyBackingStorageIfNecessary() + self.backing.unsupportedAuthScheme = newValue + } + } + + /// the invalid command tag... we should keep this internal but we can use it to print more + /// advanced debug reasons. + var invalidCommandTag: String? { + get { self.backing.invalidCommandTag } + set { + self.copyBackingStorageIfNecessary() + self.backing.invalidCommandTag = newValue + } + } + + init(code: Code, query: PostgresQuery, file: String? = nil, line: Int? = nil) { + self.backing = .init(code: code) + self.query = query + self.file = file + self.line = line + } + + init(code: Code) { + self.backing = .init(code: code) + } + + private final class Backing { + fileprivate var code: Code + fileprivate var serverInfo: ServerInfo? + fileprivate var underlying: Error? + fileprivate var file: String? + fileprivate var line: Int? + fileprivate var query: PostgresQuery? + fileprivate var backendMessage: PostgresBackendMessage? + fileprivate var unsupportedAuthScheme: UnsupportedAuthScheme? + fileprivate var invalidCommandTag: String? + + init(code: Code) { + self.code = code + } + + func copy() -> Self { + let new = Self.init(code: self.code) + new.serverInfo = self.serverInfo + new.underlying = self.underlying + new.file = self.file + new.line = self.line + new.query = self.query + new.backendMessage = self.backendMessage + return new + } + } + + public struct ServerInfo { + public struct Field: Hashable, Sendable, CustomStringConvertible { + fileprivate let backing: PostgresBackendMessage.Field + + fileprivate init(_ backing: PostgresBackendMessage.Field) { + self.backing = backing + } + + /// Severity: the field contents are ERROR, FATAL, or PANIC (in an error message), + /// or WARNING, NOTICE, DEBUG, INFO, or LOG (in a notice message), or a + /// localized translation of one of these. Always present. + public static let localizedSeverity = Self(.localizedSeverity) + + /// Severity: the field contents are ERROR, FATAL, or PANIC (in an error message), + /// or WARNING, NOTICE, DEBUG, INFO, or LOG (in a notice message). + /// This is identical to the S field except that the contents are never localized. + /// This is present only in messages generated by PostgreSQL versions 9.6 and later. + public static let severity = Self(.severity) + + /// Code: the SQLSTATE code for the error (see Appendix A). Not localizable. Always present. + public static let sqlState = Self(.sqlState) + + /// Message: the primary human-readable error message. This should be accurate but terse (typically one line). + /// Always present. + public static let message = Self(.message) + + /// Detail: an optional secondary error message carrying more detail about the problem. + /// Might run to multiple lines. + public static let detail = Self(.detail) + + /// Hint: an optional suggestion what to do about the problem. + /// This is intended to differ from Detail in that it offers advice (potentially inappropriate) + /// rather than hard facts. Might run to multiple lines. + public static let hint = Self(.hint) + + /// Position: the field value is a decimal ASCII integer, indicating an error cursor + /// position as an index into the original query string. The first character has index 1, + /// and positions are measured in characters not bytes. + public static let position = Self(.position) + + /// Internal position: this is defined the same as the P field, but it is used when the + /// cursor position refers to an internally generated command rather than the one submitted by the client. + /// The q field will always appear when this field appears. + public static let internalPosition = Self(.internalPosition) + + /// Internal query: the text of a failed internally-generated command. + /// This could be, for example, a SQL query issued by a PL/pgSQL function. + public static let internalQuery = Self(.internalQuery) + + /// Where: an indication of the context in which the error occurred. + /// Presently this includes a call stack traceback of active procedural language functions and + /// internally-generated queries. The trace is one entry per line, most recent first. + public static let locationContext = Self(.locationContext) + + /// Schema name: if the error was associated with a specific database object, the name of + /// the schema containing that object, if any. + public static let schemaName = Self(.schemaName) + + /// Table name: if the error was associated with a specific table, the name of the table. + /// (Refer to the schema name field for the name of the table's schema.) + public static let tableName = Self(.tableName) + + /// Column name: if the error was associated with a specific table column, the name of the column. + /// (Refer to the schema and table name fields to identify the table.) + public static let columnName = Self(.columnName) + + /// Data type name: if the error was associated with a specific data type, the name of the data type. + /// (Refer to the schema name field for the name of the data type's schema.) + public static let dataTypeName = Self(.dataTypeName) + + /// Constraint name: if the error was associated with a specific constraint, the name of the constraint. + /// Refer to fields listed above for the associated table or domain. (For this purpose, indexes are + /// treated as constraints, even if they weren't created with constraint syntax.) + public static let constraintName = Self(.constraintName) + + /// File: the file name of the source-code location where the error was reported. + public static let file = Self(.file) + + /// Line: the line number of the source-code location where the error was reported. + public static let line = Self(.line) + + /// Routine: the name of the source-code routine reporting the error. + public static let routine = Self(.routine) + + public var description: String { + switch self.backing { + case .localizedSeverity: + return "localizedSeverity" + case .severity: + return "severity" + case .sqlState: + return "sqlState" + case .message: + return "message" + case .detail: + return "detail" + case .hint: + return "hint" + case .position: + return "position" + case .internalPosition: + return "internalPosition" + case .internalQuery: + return "internalQuery" + case .locationContext: + return "locationContext" + case .schemaName: + return "schemaName" + case .tableName: + return "tableName" + case .columnName: + return "columnName" + case .dataTypeName: + return "dataTypeName" + case .constraintName: + return "constraintName" + case .file: + return "file" + case .line: + return "line" + case .routine: + return "routine" + } + } + } + + let underlying: PostgresBackendMessage.ErrorResponse + + fileprivate init(_ underlying: PostgresBackendMessage.ErrorResponse) { + self.underlying = underlying + } + + /// The detailed server error information. This field is set if the ``PSQLError/code-swift.property`` is + /// ``PSQLError/Code-swift.struct/server``. + public subscript(field: Field) -> String? { + self.underlying.fields[field.backing] + } + } + + // MARK: - Internal convenience factory methods - + + static func unexpectedBackendMessage(_ message: PostgresBackendMessage) -> Self { + var new = Self(code: .unexpectedBackendMessage) + new.backendMessage = message + return new + } + + static func messageDecodingFailure(_ error: PostgresMessageDecodingError) -> Self { + var new = Self(code: .messageDecodingFailure) + new.underlying = error + return new + } + + static func clientClosedConnection(underlying: Error?) -> PSQLError { + var error = PSQLError(code: .clientClosedConnection) + error.underlying = underlying + return error + } + + static func serverClosedConnection(underlying: Error?) -> PSQLError { + var error = PSQLError(code: .serverClosedConnection) + error.underlying = underlying + return error + } + + static let authMechanismRequiresPassword = PSQLError(code: .authMechanismRequiresPassword) + + static let sslUnsupported = PSQLError(code: .sslUnsupported) + + static let queryCancelled = PSQLError(code: .queryCancelled) + + static let uncleanShutdown = PSQLError(code: .uncleanShutdown) + + static let receivedUnencryptedDataAfterSSLRequest = PSQLError(code: .receivedUnencryptedDataAfterSSLRequest) + + static func server(_ response: PostgresBackendMessage.ErrorResponse) -> PSQLError { + var error = PSQLError(code: .server) + error.serverInfo = .init(response) + return error + } + + static func sasl(underlying: Error) -> PSQLError { + var error = PSQLError(code: .saslError) + error.underlying = underlying + return error + } + + static func failedToAddSSLHandler(underlying: Error) -> PSQLError { + var error = PSQLError(code: .failedToAddSSLHandler) + error.underlying = underlying + return error + } + + static func connectionError(underlying: Error) -> PSQLError { + var error = PSQLError(code: .connectionError) + error.underlying = underlying + return error + } + + static func unsupportedAuthMechanism(_ authScheme: UnsupportedAuthScheme) -> PSQLError { + var error = PSQLError(code: .unsupportedAuthMechanism) + error.unsupportedAuthScheme = authScheme + return error + } + + static func invalidCommandTag(_ value: String) -> PSQLError { + var error = PSQLError(code: .invalidCommandTag) + error.invalidCommandTag = value + return error + } + + static func unlistenError(underlying: Error) -> PSQLError { + var error = PSQLError(code: .unlistenFailed) + error.underlying = underlying + return error + } + + enum UnsupportedAuthScheme { + case none + case kerberosV5 + case md5 + case plaintext + case scmCredential + case gss + case sspi + case sasl(mechanisms: [String]) + } + + static var poolClosed: PSQLError { + Self.init(code: .poolClosed) + } +} + +extension PSQLError: CustomStringConvertible { + public var description: String { + // This may seem very odd... But we are afraid that users might accidentally send the + // unfiltered errors out to end-users. This may leak security relevant information. For this + // reason we overwrite the error description by default to this generic "Database error" + """ + PSQLError – Generic description to prevent accidental leakage of sensitive data. For debugging details, use `String(reflecting: error)`. + """ + } +} + +extension PSQLError: CustomDebugStringConvertible { + public var debugDescription: String { + var result = #"PSQLError(code: \#(self.code)"# + + if let serverInfo = self.serverInfo?.underlying { + result.append(", serverInfo: [") + result.append( + serverInfo.fields + .sorted(by: { $0.key.rawValue < $1.key.rawValue }) + .map { "\(PSQLError.ServerInfo.Field($0.0)): \($0.1)" } + .joined(separator: ", ") + ) + result.append("]") + } + + if let backendMessage = self.backendMessage { + result.append(", backendMessage: \(String(reflecting: backendMessage))") + } + + if let unsupportedAuthScheme = self.unsupportedAuthScheme { + result.append(", unsupportedAuthScheme: \(unsupportedAuthScheme)") + } + + if let invalidCommandTag = self.invalidCommandTag { + result.append(", invalidCommandTag: \(invalidCommandTag)") + } + + if let underlying = self.underlying { + result.append(", underlying: \(String(reflecting: underlying))") + } + + if let file = self.file { + result.append(", triggeredFromRequestInFile: \(file)") + if let line = self.line { + result.append(", line: \(line)") + } + } + + if let query = self.query { + result.append(", query: \(String(reflecting: query))") + } + + result.append(")") + + return result + } +} + +/// An error that may happen when a ``PostgresRow`` or ``PostgresCell`` is decoded to native Swift types. +public struct PostgresDecodingError: Error, Equatable { + public struct Code: Hashable, Error, CustomStringConvertible { + enum Base { + case missingData + case typeMismatch + case failure + } + + var base: Base + + init(_ base: Base) { + self.base = base + } + + public static let missingData = Self.init(.missingData) + public static let typeMismatch = Self.init(.typeMismatch) + public static let failure = Self.init(.failure) + + public var description: String { + switch self.base { + case .missingData: + return "missingData" + case .typeMismatch: + return "typeMismatch" + case .failure: + return "failure" + } + } + } + + /// The decoding error code + public let code: Code + + /// The cell's column name for which the decoding failed + public let columnName: String + /// The cell's column index for which the decoding failed + public let columnIndex: Int + /// The swift type the cell should have been decoded into + public let targetType: Any.Type + /// The cell's postgres data type for which the decoding failed + public let postgresType: PostgresDataType + /// The cell's postgres format for which the decoding failed + public let postgresFormat: PostgresFormat + /// A copy of the cell data which was attempted to be decoded + public let postgresData: ByteBuffer? + + /// The file the decoding was attempted in + public let file: String + /// The line the decoding was attempted in + public let line: Int + + @usableFromInline + init( + code: Code, + columnName: String, + columnIndex: Int, + targetType: Any.Type, + postgresType: PostgresDataType, + postgresFormat: PostgresFormat, + postgresData: ByteBuffer?, + file: String, + line: Int + ) { + self.code = code + self.columnName = columnName + self.columnIndex = columnIndex + self.targetType = targetType + self.postgresType = postgresType + self.postgresFormat = postgresFormat + self.postgresData = postgresData + self.file = file + self.line = line + } + + public static func ==(lhs: PostgresDecodingError, rhs: PostgresDecodingError) -> Bool { + return lhs.code == rhs.code + && lhs.columnName == rhs.columnName + && lhs.columnIndex == rhs.columnIndex + && lhs.targetType == rhs.targetType + && lhs.postgresType == rhs.postgresType + && lhs.postgresFormat == rhs.postgresFormat + && lhs.postgresData == rhs.postgresData + && lhs.file == rhs.file + && lhs.line == rhs.line + } +} + +extension PostgresDecodingError: CustomStringConvertible { + public var description: String { + // This may seem very odd... But we are afraid that users might accidentally send the + // unfiltered errors out to end-users. This may leak security relevant information. For this + // reason we overwrite the error description by default to this generic "Database error" + """ + PostgresDecodingError – Generic description to prevent accidental leakage of sensitive data. For debugging details, use `String(reflecting: error)`. + """ + } +} + +extension PostgresDecodingError: CustomDebugStringConvertible { + public var debugDescription: String { + var result = #"PostgresDecodingError(code: \#(self.code)"# + + result.append(#", columnName: \#(String(reflecting: self.columnName))"#) + result.append(#", columnIndex: \#(self.columnIndex)"#) + result.append(#", targetType: \#(String(reflecting: self.targetType))"#) + result.append(#", postgresType: \#(self.postgresType)"#) + result.append(#", postgresFormat: \#(self.postgresFormat)"#) + if let postgresData = self.postgresData { + result.append(#", postgresData: \#(String(reflecting: postgresData))"#) + } + result.append(#", file: \#(self.file)"#) + result.append(#", line: \#(self.line)"#) + result.append(")") + + return result + } +} + diff --git a/Sources/PostgresNIO/New/PSQLEventsHandler.swift b/Sources/PostgresNIO/New/PSQLEventsHandler.swift new file mode 100644 index 00000000..0f426f20 --- /dev/null +++ b/Sources/PostgresNIO/New/PSQLEventsHandler.swift @@ -0,0 +1,116 @@ +import NIOCore +import NIOTLS +import Logging + +enum PSQLOutgoingEvent { + /// the event we send down the channel to inform the ``PostgresChannelHandler`` to authenticate + /// + /// this shall be removed with the next breaking change and always supplied with `PSQLConnection.Configuration` + case authenticate(AuthContext) + + case gracefulShutdown +} + +enum PSQLEvent { + + /// the event that is used to inform upstream handlers that ``PostgresChannelHandler`` has established a connection + case readyForStartup + + /// the event that is used to inform upstream handlers that ``PostgresChannelHandler`` is currently idle + case readyForQuery +} + + +final class PSQLEventsHandler: ChannelInboundHandler { + typealias InboundIn = Never + + let logger: Logger + var readyForStartupFuture: EventLoopFuture! { + self.readyForStartupPromise!.futureResult + } + var authenticateFuture: EventLoopFuture! { + self.authenticatePromise!.futureResult + } + + + private enum State { + case initialized + case connected + case readyForStartup + case authenticated + } + + private var readyForStartupPromise: EventLoopPromise! + private var authenticatePromise: EventLoopPromise! + private var state: State = .initialized + + init(logger: Logger) { + self.logger = logger + } + + func userInboundEventTriggered(context: ChannelHandlerContext, event: Any) { + switch event { + case PSQLEvent.readyForStartup: + guard case .connected = self.state else { + preconditionFailure() + } + self.state = .readyForStartup + self.readyForStartupPromise.succeed(Void()) + case PSQLEvent.readyForQuery: + switch self.state { + case .initialized, .connected: + preconditionFailure("Expected to get a `readyForStartUp` before we get a `readyForQuery` event") + case .readyForStartup: + // for the first time, we are ready to query, this means startup/auth was + // successful + self.state = .authenticated + self.authenticatePromise.succeed(Void()) + case .authenticated: + break + } + default: + context.fireUserInboundEventTriggered(event) + } + } + + func handlerAdded(context: ChannelHandlerContext) { + self.readyForStartupPromise = context.eventLoop.makePromise(of: Void.self) + self.authenticatePromise = context.eventLoop.makePromise(of: Void.self) + + if context.channel.isActive, case .initialized = self.state { + self.state = .connected + } + } + + func channelActive(context: ChannelHandlerContext) { + if case .initialized = self.state { + self.state = .connected + } + context.fireChannelActive() + } + + func errorCaught(context: ChannelHandlerContext, error: Error) { + switch self.state { + case .initialized: + preconditionFailure("Unexpected message for state") + case .connected: + self.readyForStartupPromise.fail(error) + self.authenticatePromise.fail(error) + case .readyForStartup: + self.authenticatePromise.fail(error) + case .authenticated: + break + } + + context.fireErrorCaught(error) + } + + func handlerRemoved(context: ChannelHandlerContext) { + struct HandlerRemovedConnectionError: Error {} + + if case .initialized = self.state { + self.readyForStartupPromise.fail(HandlerRemovedConnectionError()) + self.authenticatePromise.fail(HandlerRemovedConnectionError()) + } + } +} diff --git a/Sources/PostgresNIO/New/PSQLPreparedStatement.swift b/Sources/PostgresNIO/New/PSQLPreparedStatement.swift new file mode 100644 index 00000000..5a9abf7e --- /dev/null +++ b/Sources/PostgresNIO/New/PSQLPreparedStatement.swift @@ -0,0 +1,14 @@ +struct PSQLPreparedStatement { + + /// The name with which the statement was prepared at the backend + let name: String + + /// The query that is executed when using this `PSQLPreparedStatement` + let query: String + + /// The postgres connection the statement was prepared on + let connection: PostgresConnection + + /// The `RowDescription` to apply to all `DataRow`s when executing this `PSQLPreparedStatement` + let rowDescription: RowDescription? +} diff --git a/Sources/PostgresNIO/New/PSQLRowStream.swift b/Sources/PostgresNIO/New/PSQLRowStream.swift new file mode 100644 index 00000000..ee925d0e --- /dev/null +++ b/Sources/PostgresNIO/New/PSQLRowStream.swift @@ -0,0 +1,448 @@ +import NIOCore +import Logging + +struct QueryResult { + enum Value: Equatable { + case noRows(PSQLRowStream.StatementSummary) + case rowDescription([RowDescription.Column]) + } + + var value: Value + + var logger: Logger +} + +// Thread safety is guaranteed in the RowStream through dispatching onto the NIO EventLoop. +final class PSQLRowStream: @unchecked Sendable { + private typealias AsyncSequenceSource = NIOThrowingAsyncSequenceProducer.Source + + enum StatementSummary: Equatable { + case tag(String) + case emptyResponse + } + + enum Source { + case stream([RowDescription.Column], PSQLRowsDataSource) + case noRows(Result) + } + + let eventLoop: EventLoop + let logger: Logger + + private enum BufferState { + case streaming(buffer: CircularBuffer, dataSource: PSQLRowsDataSource) + case finished(buffer: CircularBuffer, summary: StatementSummary) + case failure(Error) + } + + private enum DownstreamState { + case waitingForConsumer(BufferState) + case iteratingRows(onRow: (PostgresRow) throws -> (), EventLoopPromise, PSQLRowsDataSource) + case waitingForAll([PostgresRow], EventLoopPromise<[PostgresRow]>, PSQLRowsDataSource) + case consumed(Result) + case asyncSequence(AsyncSequenceSource, PSQLRowsDataSource, onFinish: @Sendable () -> ()) + } + + internal let rowDescription: [RowDescription.Column] + private let lookupTable: [String: Int] + private var downstreamState: DownstreamState + + init( + source: Source, + eventLoop: EventLoop, + logger: Logger + ) { + let bufferState: BufferState + switch source { + case .stream(let rowDescription, let dataSource): + self.rowDescription = rowDescription + bufferState = .streaming(buffer: .init(), dataSource: dataSource) + case .noRows(.success(let summary)): + self.rowDescription = [] + bufferState = .finished(buffer: .init(), summary: summary) + case .noRows(.failure(let error)): + self.rowDescription = [] + bufferState = .failure(error) + } + + self.downstreamState = .waitingForConsumer(bufferState) + + self.eventLoop = eventLoop + self.logger = logger + + var lookup = [String: Int]() + lookup.reserveCapacity(rowDescription.count) + rowDescription.enumerated().forEach { (index, column) in + lookup[column.name] = index + } + self.lookupTable = lookup + } + + // MARK: Async Sequence + + func asyncSequence(onFinish: @escaping @Sendable () -> () = {}) -> PostgresRowSequence { + self.eventLoop.preconditionInEventLoop() + + guard case .waitingForConsumer(let bufferState) = self.downstreamState else { + preconditionFailure("Invalid state: \(self.downstreamState)") + } + + let producer = NIOThrowingAsyncSequenceProducer.makeSequence( + elementType: DataRow.self, + failureType: Error.self, + backPressureStrategy: AdaptiveRowBuffer(), + finishOnDeinit: false, + delegate: self + ) + + let source = producer.source + + switch bufferState { + case .streaming(let bufferedRows, let dataSource): + let yieldResult = source.yield(contentsOf: bufferedRows) + self.downstreamState = .asyncSequence(source, dataSource, onFinish: onFinish) + self.executeActionBasedOnYieldResult(yieldResult, source: dataSource) + + case .finished(let buffer, let summary): + _ = source.yield(contentsOf: buffer) + source.finish() + onFinish() + self.downstreamState = .consumed(.success(summary)) + + case .failure(let error): + source.finish(error) + self.downstreamState = .consumed(.failure(error)) + } + + return PostgresRowSequence(producer.sequence, lookupTable: self.lookupTable, columns: self.rowDescription) + } + + func demand() { + if self.eventLoop.inEventLoop { + self.demand0() + } else { + self.eventLoop.execute { + self.demand0() + } + } + } + + private func demand0() { + switch self.downstreamState { + case .waitingForConsumer, .iteratingRows, .waitingForAll: + preconditionFailure("Invalid state: \(self.downstreamState)") + + case .consumed: + break + + case .asyncSequence(_, let dataSource, _): + dataSource.request(for: self) + } + } + + func cancel() { + if self.eventLoop.inEventLoop { + self.cancel0() + } else { + self.eventLoop.execute { + self.cancel0() + } + } + } + + private func cancel0() { + switch self.downstreamState { + case .asyncSequence(_, let dataSource, let onFinish): + self.downstreamState = .consumed(.failure(CancellationError())) + dataSource.cancel(for: self) + onFinish() + + case .consumed: + return + + case .waitingForConsumer, .iteratingRows, .waitingForAll: + preconditionFailure("Invalid state: \(self.downstreamState)") + } + } + + // MARK: Consume in array + + func all() -> EventLoopFuture<[PostgresRow]> { + if self.eventLoop.inEventLoop { + return self.all0() + } else { + return self.eventLoop.flatSubmit { + self.all0() + } + } + } + + private func all0() -> EventLoopFuture<[PostgresRow]> { + self.eventLoop.preconditionInEventLoop() + + guard case .waitingForConsumer(let bufferState) = self.downstreamState else { + preconditionFailure("Invalid state: \(self.downstreamState)") + } + + switch bufferState { + case .streaming(let bufferedRows, let dataSource): + let promise = self.eventLoop.makePromise(of: [PostgresRow].self) + let rows = bufferedRows.map { data in + PostgresRow(data: data, lookupTable: self.lookupTable, columns: self.rowDescription) + } + self.downstreamState = .waitingForAll(rows, promise, dataSource) + // immediately request more + dataSource.request(for: self) + return promise.futureResult + + case .finished(let buffer, let summary): + let rows = buffer.map { + PostgresRow(data: $0, lookupTable: self.lookupTable, columns: self.rowDescription) + } + + self.downstreamState = .consumed(.success(summary)) + return self.eventLoop.makeSucceededFuture(rows) + + case .failure(let error): + self.downstreamState = .consumed(.failure(error)) + return self.eventLoop.makeFailedFuture(error) + } + } + + // MARK: Consume on EventLoop + + func onRow(_ onRow: @Sendable @escaping (PostgresRow) throws -> ()) -> EventLoopFuture { + if self.eventLoop.inEventLoop { + return self.onRow0(onRow) + } else { + return self.eventLoop.flatSubmit { + self.onRow0(onRow) + } + } + } + + private func onRow0(_ onRow: @escaping (PostgresRow) throws -> ()) -> EventLoopFuture { + self.eventLoop.preconditionInEventLoop() + + guard case .waitingForConsumer(let bufferState) = self.downstreamState else { + preconditionFailure("Invalid state: \(self.downstreamState)") + } + + switch bufferState { + case .streaming(var buffer, let dataSource): + let promise = self.eventLoop.makePromise(of: Void.self) + do { + for data in buffer { + let row = PostgresRow( + data: data, + lookupTable: self.lookupTable, + columns: self.rowDescription + ) + try onRow(row) + } + + buffer.removeAll() + self.downstreamState = .iteratingRows(onRow: onRow, promise, dataSource) + // immediately request more + dataSource.request(for: self) + } catch { + self.downstreamState = .consumed(.failure(error)) + dataSource.cancel(for: self) + promise.fail(error) + } + + return promise.futureResult + + case .finished(let buffer, let summary): + do { + for data in buffer { + let row = PostgresRow( + data: data, + lookupTable: self.lookupTable, + columns: self.rowDescription + ) + try onRow(row) + } + + self.downstreamState = .consumed(.success(summary)) + return self.eventLoop.makeSucceededVoidFuture() + } catch { + self.downstreamState = .consumed(.failure(error)) + return self.eventLoop.makeFailedFuture(error) + } + + case .failure(let error): + self.downstreamState = .consumed(.failure(error)) + return self.eventLoop.makeFailedFuture(error) + } + } + + internal func noticeReceived(_ notice: PostgresBackendMessage.NoticeResponse) { + self.logger.debug("Notice Received", metadata: [ + .notice: "\(notice)" + ]) + } + + internal func receive(_ newRows: [DataRow]) { + precondition(!newRows.isEmpty, "Expected to get rows!") + self.eventLoop.preconditionInEventLoop() + self.logger.trace("Row stream received rows", metadata: [ + "row_count": "\(newRows.count)" + ]) + + switch self.downstreamState { + case .waitingForConsumer(.streaming(buffer: var buffer, dataSource: let dataSource)): + buffer.append(contentsOf: newRows) + self.downstreamState = .waitingForConsumer(.streaming(buffer: buffer, dataSource: dataSource)) + + case .waitingForConsumer(.finished), .waitingForConsumer(.failure): + preconditionFailure("How can new rows be received, if an end was already signalled?") + + case .iteratingRows(let onRow, let promise, let dataSource): + do { + for data in newRows { + let row = PostgresRow( + data: data, + lookupTable: self.lookupTable, + columns: self.rowDescription + ) + try onRow(row) + } + // immediately request more + dataSource.request(for: self) + } catch { + dataSource.cancel(for: self) + self.downstreamState = .consumed(.failure(error)) + promise.fail(error) + return + } + + case .waitingForAll(var rows, let promise, let dataSource): + newRows.forEach { data in + let row = PostgresRow(data: data, lookupTable: self.lookupTable, columns: self.rowDescription) + rows.append(row) + } + self.downstreamState = .waitingForAll(rows, promise, dataSource) + // immediately request more + dataSource.request(for: self) + + case .asyncSequence(let consumer, let source, _): + let yieldResult = consumer.yield(contentsOf: newRows) + self.executeActionBasedOnYieldResult(yieldResult, source: source) + + case .consumed(.success): + preconditionFailure("How can we receive further rows, if we are supposed to be done") + + case .consumed(.failure): + break + } + } + + internal func receive(completion result: Result) { + self.eventLoop.preconditionInEventLoop() + + switch result { + case .success(let commandTag): + self.receiveEnd(commandTag) + case .failure(let error): + self.receiveError(error) + } + } + + private func receiveEnd(_ commandTag: String) { + switch self.downstreamState { + case .waitingForConsumer(.streaming(buffer: let buffer, _)): + self.downstreamState = .waitingForConsumer(.finished(buffer: buffer, summary: .tag(commandTag))) + + case .waitingForConsumer(.finished), .waitingForConsumer(.failure), .consumed(.success(.emptyResponse)): + preconditionFailure("How can we get another end, if an end was already signalled?") + + case .iteratingRows(_, let promise, _): + self.downstreamState = .consumed(.success(.tag(commandTag))) + promise.succeed(()) + + case .waitingForAll(let rows, let promise, _): + self.downstreamState = .consumed(.success(.tag(commandTag))) + promise.succeed(rows) + + case .asyncSequence(let source, _, let onFinish): + self.downstreamState = .consumed(.success(.tag(commandTag))) + source.finish() + onFinish() + + case .consumed(.success(.tag)), .consumed(.failure): + break + } + } + + private func receiveError(_ error: Error) { + switch self.downstreamState { + case .waitingForConsumer(.streaming): + self.downstreamState = .waitingForConsumer(.failure(error)) + + case .waitingForConsumer(.finished), .waitingForConsumer(.failure), .consumed(.success(.emptyResponse)): + preconditionFailure("How can we get another end, if an end was already signalled?") + + case .iteratingRows(_, let promise, _): + self.downstreamState = .consumed(.failure(error)) + promise.fail(error) + + case .waitingForAll(_, let promise, _): + self.downstreamState = .consumed(.failure(error)) + promise.fail(error) + + case .asyncSequence(let consumer, _, let onFinish): + self.downstreamState = .consumed(.failure(error)) + consumer.finish(error) + onFinish() + + case .consumed(.success(.tag)), .consumed(.failure): + break + } + } + + private func executeActionBasedOnYieldResult(_ yieldResult: AsyncSequenceSource.YieldResult, source: PSQLRowsDataSource) { + self.eventLoop.preconditionInEventLoop() + switch yieldResult { + case .dropped: + // ignore + break + + case .produceMore: + source.request(for: self) + + case .stopProducing: + // ignore + break + } + } + + var commandTag: String { + guard case .consumed(.success(let consumed)) = self.downstreamState else { + preconditionFailure("commandTag may only be called if all rows have been consumed") + } + switch consumed { + case .tag(let tag): + return tag + case .emptyResponse: + return "" + } + } +} + +extension PSQLRowStream: NIOAsyncSequenceProducerDelegate { + func produceMore() { + self.demand() + } + + func didTerminate() { + self.cancel() + } +} + +protocol PSQLRowsDataSource { + + func request(for stream: PSQLRowStream) + func cancel(for stream: PSQLRowStream) + +} diff --git a/Sources/PostgresNIO/New/PSQLTask.swift b/Sources/PostgresNIO/New/PSQLTask.swift new file mode 100644 index 00000000..6106fd21 --- /dev/null +++ b/Sources/PostgresNIO/New/PSQLTask.swift @@ -0,0 +1,118 @@ +import Logging +import NIOCore + +enum HandlerTask: Sendable { + case extendedQuery(ExtendedQueryContext) + case closeCommand(CloseCommandContext) + case startListening(NotificationListener) + case cancelListening(String, Int) + case executePreparedStatement(PreparedStatementContext) +} + +enum PSQLTask { + case extendedQuery(ExtendedQueryContext) + case closeCommand(CloseCommandContext) + + func failWithError(_ error: PSQLError) { + switch self { + case .extendedQuery(let extendedQueryContext): + switch extendedQueryContext.query { + case .unnamed(_, let eventLoopPromise): + eventLoopPromise.fail(error) + case .executeStatement(_, let eventLoopPromise): + eventLoopPromise.fail(error) + case .prepareStatement(_, _, _, let eventLoopPromise): + eventLoopPromise.fail(error) + } + + case .closeCommand(let closeCommandContext): + closeCommandContext.promise.fail(error) + } + } +} + +final class ExtendedQueryContext: Sendable { + enum Query { + case unnamed(PostgresQuery, EventLoopPromise) + case executeStatement(PSQLExecuteStatement, EventLoopPromise) + case prepareStatement(name: String, query: String, bindingDataTypes: [PostgresDataType], EventLoopPromise) + } + + let query: Query + let logger: Logger + + init( + query: PostgresQuery, + logger: Logger, + promise: EventLoopPromise + ) { + self.query = .unnamed(query, promise) + self.logger = logger + } + + init( + executeStatement: PSQLExecuteStatement, + logger: Logger, + promise: EventLoopPromise + ) { + self.query = .executeStatement(executeStatement, promise) + self.logger = logger + } + + init( + name: String, + query: String, + bindingDataTypes: [PostgresDataType], + logger: Logger, + promise: EventLoopPromise + ) { + self.query = .prepareStatement(name: name, query: query, bindingDataTypes: bindingDataTypes, promise) + self.logger = logger + } +} + +final class PreparedStatementContext: Sendable { + let name: String + let sql: String + let bindingDataTypes: [PostgresDataType] + let bindings: PostgresBindings + let logger: Logger + let promise: EventLoopPromise + + init( + name: String, + sql: String, + bindings: PostgresBindings, + bindingDataTypes: [PostgresDataType], + logger: Logger, + promise: EventLoopPromise + ) { + self.name = name + self.sql = sql + self.bindings = bindings + if bindingDataTypes.isEmpty { + self.bindingDataTypes = bindings.metadata.map(\.dataType) + } else { + self.bindingDataTypes = bindingDataTypes + } + self.logger = logger + self.promise = promise + } +} + +final class CloseCommandContext: Sendable { + let target: CloseTarget + let logger: Logger + let promise: EventLoopPromise + + init( + target: CloseTarget, + logger: Logger, + promise: EventLoopPromise + ) { + self.target = target + self.logger = logger + self.promise = promise + } +} + diff --git a/Sources/PostgresNIO/New/PostgresBackendMessage.swift b/Sources/PostgresNIO/New/PostgresBackendMessage.swift new file mode 100644 index 00000000..792beec3 --- /dev/null +++ b/Sources/PostgresNIO/New/PostgresBackendMessage.swift @@ -0,0 +1,184 @@ +import NIOCore +//import struct Foundation.Data + + +/// A protocol to implement for all associated value in the `PostgresBackendMessage` enum +protocol PSQLMessagePayloadDecodable { + + /// Decodes the associated value for a `PostgresBackendMessage` from the given `ByteBuffer`. + /// + /// When the decoding is done all bytes in the given `ByteBuffer` must be consumed. + /// `buffer.readableBytes` must be `0`. In case of an error a `PartialDecodingError` + /// must be thrown. + /// + /// - Parameter buffer: The `ByteBuffer` to read the message from. When done the `ByteBuffer` + /// must be fully consumed. + static func decode(from buffer: inout ByteBuffer) throws -> Self +} + +/// A wire message that is created by a Postgres server to be consumed by Postgres client. +/// +/// All messages are defined in the official Postgres Documentation in the section +/// [Frontend/Backend Protocol – Message Formats](https://www.postgresql.org/docs/13/protocol-message-formats.html) +enum PostgresBackendMessage: Hashable { + + typealias PayloadDecodable = PSQLMessagePayloadDecodable + + case authentication(Authentication) + case backendKeyData(BackendKeyData) + case bindComplete + case closeComplete + case commandComplete(String) + case dataRow(DataRow) + case emptyQueryResponse + case error(ErrorResponse) + case noData + case notice(NoticeResponse) + case notification(NotificationResponse) + case parameterDescription(ParameterDescription) + case parameterStatus(ParameterStatus) + case parseComplete + case portalSuspended + case readyForQuery(TransactionState) + case rowDescription(RowDescription) + case sslSupported + case sslUnsupported +} + +extension PostgresBackendMessage { + enum ID: UInt8, Hashable { + case authentication = 82 // ascii: R + case backendKeyData = 75 // ascii: K + case bindComplete = 50 // ascii: 2 + case closeComplete = 51 // ascii: 3 + case commandComplete = 67 // ascii: C + case copyData = 100 // ascii: d + case copyDone = 99 // ascii: c + case copyInResponse = 71 // ascii: G + case copyOutResponse = 72 // ascii: H + case copyBothResponse = 87 // ascii: W + case dataRow = 68 // ascii: D + case emptyQueryResponse = 73 // ascii: I + case error = 69 // ascii: E + case functionCallResponse = 86 // ascii: V + case negotiateProtocolVersion = 118 // ascii: v + case noData = 110 // ascii: n + case noticeResponse = 78 // ascii: N + case notificationResponse = 65 // ascii: A + case parameterDescription = 116 // ascii: t + case parameterStatus = 83 // ascii: S + case parseComplete = 49 // ascii: 1 + case portalSuspended = 115 // ascii: s + case readyForQuery = 90 // ascii: Z + case rowDescription = 84 // ascii: T + } +} + +extension PostgresBackendMessage { + + static func decode(from buffer: inout ByteBuffer, for messageID: ID) throws -> PostgresBackendMessage { + switch messageID { + case .authentication: + return try .authentication(.decode(from: &buffer)) + + case .backendKeyData: + return try .backendKeyData(.decode(from: &buffer)) + + case .bindComplete: + return .bindComplete + + case .closeComplete: + return .closeComplete + + case .commandComplete: + guard let commandTag = buffer.readNullTerminatedString() else { + throw PSQLPartialDecodingError.fieldNotDecodable(type: String.self) + } + return .commandComplete(commandTag) + + case .dataRow: + return try .dataRow(.decode(from: &buffer)) + + case .emptyQueryResponse: + return .emptyQueryResponse + + case .parameterStatus: + return try .parameterStatus(.decode(from: &buffer)) + + case .error: + return try .error(.decode(from: &buffer)) + + case .noData: + return .noData + + case .noticeResponse: + return try .notice(.decode(from: &buffer)) + + case .notificationResponse: + return try .notification(.decode(from: &buffer)) + + case .parameterDescription: + return try .parameterDescription(.decode(from: &buffer)) + + case .parseComplete: + return .parseComplete + + case .portalSuspended: + return .portalSuspended + + case .readyForQuery: + return try .readyForQuery(.decode(from: &buffer)) + + case .rowDescription: + return try .rowDescription(.decode(from: &buffer)) + + case .copyData, .copyDone, .copyInResponse, .copyOutResponse, .copyBothResponse, .functionCallResponse, .negotiateProtocolVersion: + preconditionFailure() + } + } +} + +extension PostgresBackendMessage: CustomDebugStringConvertible { + var debugDescription: String { + switch self { + case .authentication(let authentication): + return ".authentication(\(String(reflecting: authentication)))" + case .backendKeyData(let backendKeyData): + return ".backendKeyData(\(String(reflecting: backendKeyData)))" + case .bindComplete: + return ".bindComplete" + case .closeComplete: + return ".closeComplete" + case .commandComplete(let commandTag): + return ".commandComplete(\(String(reflecting: commandTag)))" + case .dataRow(let dataRow): + return ".dataRow(\(String(reflecting: dataRow)))" + case .emptyQueryResponse: + return ".emptyQueryResponse" + case .error(let error): + return ".error(\(String(reflecting: error)))" + case .noData: + return ".noData" + case .notice(let notice): + return ".notice(\(String(reflecting: notice)))" + case .notification(let notification): + return ".notification(\(String(reflecting: notification)))" + case .parameterDescription(let parameterDescription): + return ".parameterDescription(\(String(reflecting: parameterDescription)))" + case .parameterStatus(let parameterStatus): + return ".parameterStatus(\(String(reflecting: parameterStatus)))" + case .parseComplete: + return ".parseComplete" + case .portalSuspended: + return ".portalSuspended" + case .readyForQuery(let transactionState): + return ".readyForQuery(\(String(reflecting: transactionState)))" + case .rowDescription(let rowDescription): + return ".rowDescription(\(String(reflecting: rowDescription)))" + case .sslSupported: + return ".sslSupported" + case .sslUnsupported: + return ".sslUnsupported" + } + } +} diff --git a/Sources/PostgresNIO/New/PostgresBackendMessageDecoder.swift b/Sources/PostgresNIO/New/PostgresBackendMessageDecoder.swift new file mode 100644 index 00000000..6f6be7ec --- /dev/null +++ b/Sources/PostgresNIO/New/PostgresBackendMessageDecoder.swift @@ -0,0 +1,208 @@ +import NIOCore + +struct PostgresBackendMessageDecoder: NIOSingleStepByteToMessageDecoder { + typealias InboundOut = PostgresBackendMessage + + private(set) var hasAlreadyReceivedBytes: Bool + + init(hasAlreadyReceivedBytes: Bool = false) { + self.hasAlreadyReceivedBytes = hasAlreadyReceivedBytes + } + + mutating func decode(buffer: inout ByteBuffer) throws -> PostgresBackendMessage? { + + if !self.hasAlreadyReceivedBytes { + // We have not received any bytes yet! Let's peek at the first message id. If it + // is a "S" or "N" we assume that it is connected to an SSL upgrade request. All + // other messages that we expect now, don't start with either "S" or "N" + + let startReaderIndex = buffer.readerIndex + guard let firstByte = buffer.readInteger(as: UInt8.self) else { + return nil + } + + switch firstByte { + case UInt8(ascii: "S"): + self.hasAlreadyReceivedBytes = true + return .sslSupported + + case UInt8(ascii: "N"): + self.hasAlreadyReceivedBytes = true + return .sslUnsupported + + default: + // move reader index back + buffer.moveReaderIndex(to: startReaderIndex) + self.hasAlreadyReceivedBytes = true + } + } + + // all other packages start with a MessageID (UInt8) and their message length (UInt32). + // do we have enough bytes for that? + let startReaderIndex = buffer.readerIndex + guard let (idByte, length) = buffer.readMultipleIntegers(endianness: .big, as: (UInt8, UInt32).self) else { + // if this fails, the readerIndex wasn't changed + return nil + } + + // 1. try to read the message + guard var message = buffer.readSlice(length: Int(length) - 4) else { + // we need to move the reader index back to its start point + buffer.moveReaderIndex(to: startReaderIndex) + return nil + } + + // 2. make sure we have a known message identifier + guard let messageID = PostgresBackendMessage.ID(rawValue: idByte) else { + buffer.moveReaderIndex(to: startReaderIndex) + let completeMessage = buffer.readSlice(length: Int(length) + 1)! + throw PostgresMessageDecodingError.unknownMessageIDReceived(messageID: idByte, messageBytes: completeMessage) + } + + // 3. decode the message + do { + let result = try PostgresBackendMessage.decode(from: &message, for: messageID) + if message.readableBytes > 0 { + throw PSQLPartialDecodingError.expectedExactlyNRemainingBytes(0, actual: message.readableBytes) + } + return result + } catch let error as PSQLPartialDecodingError { + buffer.moveReaderIndex(to: startReaderIndex) + let completeMessage = buffer.readSlice(length: Int(length) + 1)! + throw PostgresMessageDecodingError.withPartialError(error, messageID: messageID.rawValue, messageBytes: completeMessage) + } catch { + preconditionFailure("Expected to only see `PartialDecodingError`s here.") + } + } + + mutating func decodeLast(buffer: inout ByteBuffer, seenEOF: Bool) throws -> PostgresBackendMessage? { + try self.decode(buffer: &buffer) + } +} + + + +/// An error representing a failure to decode [a Postgres wire message](https://www.postgresql.org/docs/13/protocol-message-formats.html) +/// to the Swift structure `PSQLBackendMessage`. +/// +/// If you encounter a `DecodingError` when using a trusted Postgres server please make to file an issue at: +/// [https://github.com/vapor/postgres-nio/issues](https://github.com/vapor/postgres-nio/issues) +struct PostgresMessageDecodingError: Error { + + /// The backend message ID bytes + let messageID: UInt8 + + /// The backend message's payload encoded in base64 + let payload: String + + /// A textual description of the error + let description: String + + /// The file this error was thrown in + let file: String + + /// The line in `file` this error was thrown + let line: Int + + static func withPartialError( + _ partialError: PSQLPartialDecodingError, + messageID: UInt8, + messageBytes: ByteBuffer + ) -> Self { + var byteBuffer = messageBytes + let data = byteBuffer.readData(length: byteBuffer.readableBytes)! + + return PostgresMessageDecodingError( + messageID: messageID, + payload: data.base64EncodedString(), + description: partialError.description, + file: partialError.file, + line: partialError.line) + } + + static func unknownMessageIDReceived( + messageID: UInt8, + messageBytes: ByteBuffer, + file: String = #fileID, + line: Int = #line + ) -> Self { + var byteBuffer = messageBytes + let data = byteBuffer.readData(length: byteBuffer.readableBytes)! + + return PostgresMessageDecodingError( + messageID: messageID, + payload: data.base64EncodedString(), + description: "Received a message with messageID '\(Character(UnicodeScalar(messageID)))'. There is no message type associated with this message identifier.", + file: file, + line: line) + } + +} + +struct PSQLPartialDecodingError: Error { + /// A textual description of the error + let description: String + + /// The file this error was thrown in + let file: String + + /// The line in `file` this error was thrown + let line: Int + + static func valueNotRawRepresentable( + value: Target.RawValue, + asType: Target.Type, + file: String = #fileID, + line: Int = #line + ) -> Self { + return PSQLPartialDecodingError( + description: "Can not represent '\(value)' with type '\(asType)'.", + file: file, line: line) + } + + static func unexpectedValue(value: Any, file: String = #fileID, line: Int = #line) -> Self { + return PSQLPartialDecodingError( + description: "Value '\(value)' is not expected.", + file: file, line: line) + } + + static func expectedAtLeastNRemainingBytes(_ expected: Int, actual: Int, file: String = #fileID, line: Int = #line) -> Self { + return PSQLPartialDecodingError( + description: "Expected at least '\(expected)' remaining bytes. But only found \(actual).", + file: file, line: line) + } + + static func expectedExactlyNRemainingBytes(_ expected: Int, actual: Int, file: String = #fileID, line: Int = #line) -> Self { + return PSQLPartialDecodingError( + description: "Expected exactly '\(expected)' remaining bytes. But found \(actual).", + file: file, line: line) + } + + static func fieldNotDecodable(type: Any.Type, file: String = #fileID, line: Int = #line) -> Self { + return PSQLPartialDecodingError( + description: "Could not read '\(type)' from ByteBuffer.", + file: file, line: line) + } + + static func integerMustBePositiveOrNull(_ actual: Number, file: String = #fileID, line: Int = #line) -> Self { + return PSQLPartialDecodingError( + description: "Expected the integer to be positive or null, but got \(actual).", + file: file, line: line) + } +} + +extension ByteBuffer { + mutating func throwingReadInteger(as: I.Type, file: String = #fileID, line: Int = #line) throws -> I { + guard let result = self.readInteger(endianness: .big, as: I.self) else { + throw PSQLPartialDecodingError.expectedAtLeastNRemainingBytes(MemoryLayout.size, actual: self.readableBytes, file: file, line: line) + } + return result + } + + mutating func throwingMoveReaderIndex(forwardBy offset: Int, file: String = #fileID, line: Int = #line) throws { + guard self.readSlice(length: offset) != nil else { + throw PSQLPartialDecodingError.expectedAtLeastNRemainingBytes(offset, actual: self.readableBytes, file: file, line: line) + } + } +} + diff --git a/Sources/PostgresNIO/New/PostgresCell.swift b/Sources/PostgresNIO/New/PostgresCell.swift new file mode 100644 index 00000000..7598a31a --- /dev/null +++ b/Sources/PostgresNIO/New/PostgresCell.swift @@ -0,0 +1,88 @@ +import NIOCore + +/// A representation of a cell value within a ``PostgresRow`` and ``PostgresRandomAccessRow``. +public struct PostgresCell: Sendable, Equatable { + /// The cell's value as raw bytes. + public var bytes: ByteBuffer? + /// The cell's data type. This is important metadata when decoding the cell. + public var dataType: PostgresDataType + /// The format in which the cell's bytes are encoded. + public var format: PostgresFormat + + /// The cell's column name within the row. + public var columnName: String + /// The cell's column index within the row. + public var columnIndex: Int + + public init( + bytes: ByteBuffer?, + dataType: PostgresDataType, + format: PostgresFormat, + columnName: String, + columnIndex: Int + ) { + self.bytes = bytes + self.dataType = dataType + self.format = format + + self.columnName = columnName + self.columnIndex = columnIndex + } +} + +extension PostgresCell { + /// Decode the cell into a Swift type, that conforms to ``PostgresDecodable`` + /// + /// - Parameters: + /// - _: The Swift type, which conforms to ``PostgresDecodable``, to decode from the cell's ``PostgresCell/bytes`` values. + /// - context: A ``PostgresDecodingContext`` to supply a custom ``PostgresJSONDecoder`` for decoding JSON fields. + /// - file: The source file in which this method was called. Used in the error case in ``PostgresDecodingError``. + /// - line: The source file line in which this method was called. Used in the error case in ``PostgresDecodingError``. + /// - Returns: A decoded Swift type. + @inlinable + public func decode( + _: T.Type, + context: PostgresDecodingContext, + file: String = #fileID, + line: Int = #line + ) throws -> T { + var copy = self.bytes + do { + return try T._decodeRaw( + from: ©, + type: self.dataType, + format: self.format, + context: context + ) + } catch let code as PostgresDecodingError.Code { + throw PostgresDecodingError( + code: code, + columnName: self.columnName, + columnIndex: self.columnIndex, + targetType: T.self, + postgresType: self.dataType, + postgresFormat: self.format, + postgresData: copy, + file: file, + line: line + ) + } + } + + + /// Decode the cell into a Swift type, that conforms to ``PostgresDecodable`` + /// + /// - Parameters: + /// - _: The Swift type, which conforms to ``PostgresDecodable``, to decode from the cell's ``PostgresCell/bytes`` values. + /// - file: The source file in which this method was called. Used in the error case in ``PostgresDecodingError``. + /// - line: The source file line in which this method was called. Used in the error case in ``PostgresDecodingError``. + /// - Returns: A decoded Swift type. + @inlinable + public func decode( + _: T.Type, + file: String = #fileID, + line: Int = #line + ) throws -> T { + try self.decode(T.self, context: .default, file: file, line: line) + } +} diff --git a/Sources/PostgresNIO/New/PostgresChannelHandler.swift b/Sources/PostgresNIO/New/PostgresChannelHandler.swift new file mode 100644 index 00000000..0a14849a --- /dev/null +++ b/Sources/PostgresNIO/New/PostgresChannelHandler.swift @@ -0,0 +1,855 @@ +import NIOCore +import NIOTLS +import Crypto +import Logging + +final class PostgresChannelHandler: ChannelDuplexHandler { + typealias OutboundIn = HandlerTask + typealias InboundIn = ByteBuffer + typealias OutboundOut = ByteBuffer + + private let logger: Logger + private let eventLoop: EventLoop + private var state: ConnectionStateMachine + + /// A `ChannelHandlerContext` to be used for non channel related events. (for example: More rows needed). + /// + /// The context is captured in `handlerAdded` and released` in `handlerRemoved` + private var handlerContext: ChannelHandlerContext? + private var rowStream: PSQLRowStream? + private var decoder: NIOSingleStepByteToMessageProcessor + private var encoder: PostgresFrontendMessageEncoder! + private let configuration: PostgresConnection.InternalConfiguration + private let configureSSLCallback: ((Channel, PostgresChannelHandler) throws -> Void)? + + private var listenState = ListenStateMachine() + private var preparedStatementState = PreparedStatementStateMachine() + + init( + configuration: PostgresConnection.InternalConfiguration, + eventLoop: EventLoop, + logger: Logger, + configureSSLCallback: ((Channel, PostgresChannelHandler) throws -> Void)? + ) { + self.state = ConnectionStateMachine(requireBackendKeyData: configuration.options.requireBackendKeyData) + self.eventLoop = eventLoop + self.configuration = configuration + self.configureSSLCallback = configureSSLCallback + self.logger = logger + self.decoder = NIOSingleStepByteToMessageProcessor(PostgresBackendMessageDecoder()) + } + + #if DEBUG + /// for testing purposes only + init( + configuration: PostgresConnection.InternalConfiguration, + eventLoop: EventLoop, + state: ConnectionStateMachine = .init(.initialized), + logger: Logger = .psqlNoOpLogger, + configureSSLCallback: ((Channel, PostgresChannelHandler) throws -> Void)? + ) { + self.state = state + self.eventLoop = eventLoop + self.configuration = configuration + self.configureSSLCallback = configureSSLCallback + self.logger = logger + self.decoder = NIOSingleStepByteToMessageProcessor(PostgresBackendMessageDecoder()) + } + #endif + + // MARK: Handler lifecycle + + func handlerAdded(context: ChannelHandlerContext) { + self.handlerContext = context + self.encoder = PostgresFrontendMessageEncoder(buffer: context.channel.allocator.buffer(capacity: 256)) + + if context.channel.isActive { + self.connected(context: context) + } + } + + func handlerRemoved(context: ChannelHandlerContext) { + self.handlerContext = nil + } + + // MARK: Channel handler incoming + + func channelActive(context: ChannelHandlerContext) { + // `fireChannelActive` needs to be called BEFORE we set the state machine to connected, + // since we want to make sure that upstream handlers know about the active connection before + // it receives a + context.fireChannelActive() + + self.connected(context: context) + } + + func channelInactive(context: ChannelHandlerContext) { + do { + try self.decoder.finishProcessing(seenEOF: true) { message in + self.handleMessage(message, context: context) + } + } catch let error as PostgresMessageDecodingError { + let action = self.state.errorHappened(.messageDecodingFailure(error)) + self.run(action, with: context) + } catch { + preconditionFailure("Expected to only get PSQLDecodingErrors from the PSQLBackendMessageDecoder.") + } + + self.logger.trace("Channel inactive.") + let action = self.state.closed() + self.run(action, with: context) + } + + func errorCaught(context: ChannelHandlerContext, error: Error) { + self.logger.debug("Channel error caught.", metadata: [.error: "\(error)"]) + let action = self.state.errorHappened(.connectionError(underlying: error)) + self.run(action, with: context) + } + + func channelRead(context: ChannelHandlerContext, data: NIOAny) { + let buffer = self.unwrapInboundIn(data) + + do { + try self.decoder.process(buffer: buffer) { message in + self.handleMessage(message, context: context) + } + } catch let error as PostgresMessageDecodingError { + let action = self.state.errorHappened(.messageDecodingFailure(error)) + self.run(action, with: context) + } catch { + preconditionFailure("Expected to only get PSQLDecodingErrors from the PSQLBackendMessageDecoder.") + } + } + + private func handleMessage(_ message: PostgresBackendMessage, context: ChannelHandlerContext) { + self.logger.trace("Backend message received", metadata: [.message: "\(message)"]) + let action: ConnectionStateMachine.ConnectionAction + + switch message { + case .authentication(let authentication): + action = self.state.authenticationMessageReceived(authentication) + case .backendKeyData(let keyData): + action = self.state.backendKeyDataReceived(keyData) + case .bindComplete: + action = self.state.bindCompleteReceived() + case .closeComplete: + action = self.state.closeCompletedReceived() + case .commandComplete(let commandTag): + action = self.state.commandCompletedReceived(commandTag) + case .dataRow(let dataRow): + action = self.state.dataRowReceived(dataRow) + case .emptyQueryResponse: + action = self.state.emptyQueryResponseReceived() + case .error(let errorResponse): + action = self.state.errorReceived(errorResponse) + case .noData: + action = self.state.noDataReceived() + case .notice(let noticeResponse): + action = self.state.noticeReceived(noticeResponse) + case .notification(let notification): + action = self.state.notificationReceived(notification) + case .parameterDescription(let parameterDescription): + action = self.state.parameterDescriptionReceived(parameterDescription) + case .parameterStatus(let parameterStatus): + action = self.state.parameterStatusReceived(parameterStatus) + case .parseComplete: + action = self.state.parseCompleteReceived() + case .portalSuspended: + action = self.state.portalSuspendedReceived() + case .readyForQuery(let transactionState): + action = self.state.readyForQueryReceived(transactionState) + case .rowDescription(let rowDescription): + action = self.state.rowDescriptionReceived(rowDescription) + case .sslSupported: + action = self.state.sslSupportedReceived(unprocessedBytes: self.decoder.unprocessedBytes) + case .sslUnsupported: + action = self.state.sslUnsupportedReceived() + } + + self.run(action, with: context) + } + + func channelReadComplete(context: ChannelHandlerContext) { + let action = self.state.channelReadComplete() + self.run(action, with: context) + } + + func userInboundEventTriggered(context: ChannelHandlerContext, event: Any) { + self.logger.trace("User inbound event received", metadata: [ + .userEvent: "\(event)" + ]) + + switch event { + case TLSUserEvent.handshakeCompleted: + let action = self.state.sslEstablished() + self.run(action, with: context) + default: + context.fireUserInboundEventTriggered(event) + } + } + + // MARK: Channel handler outgoing + + func read(context: ChannelHandlerContext) { + self.logger.trace("Channel read event received") + let action = self.state.readEventCaught() + self.run(action, with: context) + } + + func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise?) { + let handlerTask = self.unwrapOutboundIn(data) + let psqlTask: PSQLTask + + switch handlerTask { + case .closeCommand(let command): + psqlTask = .closeCommand(command) + case .extendedQuery(let query): + psqlTask = .extendedQuery(query) + + case .startListening(let listener): + switch self.listenState.startListening(listener) { + case .startListening(let channel): + psqlTask = self.makeStartListeningQuery(channel: channel, context: context) + + case .none: + return + + case .succeedListenStart(let listener): + listener.startListeningSucceeded(handler: self) + return + } + + case .cancelListening(let channel, let id): + switch self.listenState.cancelNotificationListener(channel: channel, id: id) { + case .none: + return + + case .stopListening(let channel, let listener): + psqlTask = self.makeUnlistenQuery(channel: channel, context: context) + listener.failed(CancellationError()) + + case .cancelListener(let listener): + listener.failed(CancellationError()) + return + } + case .executePreparedStatement(let preparedStatement): + let action = self.preparedStatementState.lookup( + preparedStatement: preparedStatement + ) + switch action { + case .prepareStatement: + psqlTask = self.makePrepareStatementTask( + preparedStatement: preparedStatement, + context: context + ) + case .waitForAlreadyInFlightPreparation: + // The state machine already keeps track of this + // and will execute the statement as soon as it's prepared + return + case .executeStatement(let rowDescription): + psqlTask = self.makeExecutePreparedStatementTask( + preparedStatement: preparedStatement, + rowDescription: rowDescription + ) + case .returnError(let error): + preparedStatement.promise.fail(error) + return + } + } + + let action = self.state.enqueue(task: psqlTask) + self.run(action, with: context) + } + + func close(context: ChannelHandlerContext, mode: CloseMode, promise: EventLoopPromise?) { + self.logger.trace("Close triggered by upstream.") + guard mode == .all else { + // TODO: Support also other modes ? + promise?.fail(ChannelError.operationUnsupported) + return + } + + let action = self.state.close(promise: promise) + self.run(action, with: context) + } + + func triggerUserOutboundEvent(context: ChannelHandlerContext, event: Any, promise: EventLoopPromise?) { + self.logger.trace("User outbound event received", metadata: [.userEvent: "\(event)"]) + + switch event { + case PSQLOutgoingEvent.authenticate(let authContext): + let action = self.state.provideAuthenticationContext(authContext) + self.run(action, with: context) + + case PSQLOutgoingEvent.gracefulShutdown: + let action = self.state.gracefulClose(promise) + self.run(action, with: context) + + default: + context.triggerUserOutboundEvent(event, promise: promise) + } + } + + // MARK: Listening + + func cancelNotificationListener(channel: String, id: Int) { + self.eventLoop.preconditionInEventLoop() + + switch self.listenState.cancelNotificationListener(channel: channel, id: id) { + case .cancelListener(let listener): + listener.cancelled() + + case .stopListening(let channel, cancelListener: let listener): + listener.cancelled() + + guard let context = self.handlerContext else { + return + } + + let query = self.makeUnlistenQuery(channel: channel, context: context) + let action = self.state.enqueue(task: query) + self.run(action, with: context) + + case .none: + break + } + } + + // MARK: Channel handler actions + + private func run(_ action: ConnectionStateMachine.ConnectionAction, with context: ChannelHandlerContext) { + self.logger.trace("Run action", metadata: [.connectionAction: "\(action)"]) + + switch action { + case .establishSSLConnection: + self.establishSSLConnection(context: context) + case .read: + context.read() + case .wait: + break + case .sendStartupMessage(let authContext): + self.encoder.startup(user: authContext.username, database: authContext.database, options: authContext.additionalParameters) + context.writeAndFlush(self.wrapOutboundOut(self.encoder.flushBuffer()), promise: nil) + case .sendSSLRequest: + self.encoder.ssl() + context.writeAndFlush(self.wrapOutboundOut(self.encoder.flushBuffer()), promise: nil) + case .sendPasswordMessage(let mode, let authContext): + self.sendPasswordMessage(mode: mode, authContext: authContext, context: context) + case .sendSaslInitialResponse(let name, let initialResponse): + self.encoder.saslInitialResponse(mechanism: name, bytes: initialResponse) + context.writeAndFlush(self.wrapOutboundOut(self.encoder.flushBuffer()), promise: nil) + case .sendSaslResponse(let bytes): + self.encoder.saslResponse(bytes) + context.writeAndFlush(self.wrapOutboundOut(self.encoder.flushBuffer()), promise: nil) + case .closeConnectionAndCleanup(let cleanupContext): + self.closeConnectionAndCleanup(cleanupContext, context: context) + case .fireChannelInactive: + context.fireChannelInactive() + case .sendParseDescribeSync(let name, let query, let bindingDataTypes): + self.sendParseDescribeAndSyncMessage(statementName: name, query: query, bindingDataTypes: bindingDataTypes, context: context) + case .sendBindExecuteSync(let executeStatement): + self.sendBindExecuteAndSyncMessage(executeStatement: executeStatement, context: context) + case .sendParseDescribeBindExecuteSync(let query): + self.sendParseDescribeBindExecuteAndSyncMessage(query: query, context: context) + case .succeedQuery(let promise, with: let result): + self.succeedQuery(promise, result: result, context: context) + case .failQuery(let promise, with: let error, let cleanupContext): + promise.fail(error) + if let cleanupContext = cleanupContext { + self.closeConnectionAndCleanup(cleanupContext, context: context) + } + + case .forwardRows(let rows): + self.rowStream!.receive(rows) + + case .forwardStreamComplete(let buffer, let commandTag): + guard let rowStream = self.rowStream else { + // if the stream was cancelled we don't have it here anymore. + return + } + self.rowStream = nil + if buffer.count > 0 { + rowStream.receive(buffer) + } + rowStream.receive(completion: .success(commandTag)) + + + case .forwardStreamError(let error, let read, let cleanupContext): + self.rowStream!.receive(completion: .failure(error)) + self.rowStream = nil + if let cleanupContext = cleanupContext { + self.closeConnectionAndCleanup(cleanupContext, context: context) + } else if read { + context.read() + } + + case .provideAuthenticationContext: + context.fireUserInboundEventTriggered(PSQLEvent.readyForStartup) + + if let username = self.configuration.username { + let authContext = AuthContext( + username: username, + password: self.configuration.password, + database: self.configuration.database, + additionalParameters: self.configuration.options.additionalStartupParameters + ) + let action = self.state.provideAuthenticationContext(authContext) + return self.run(action, with: context) + } + case .fireEventReadyForQuery: + context.fireUserInboundEventTriggered(PSQLEvent.readyForQuery) + case .closeConnection(let promise): + if context.channel.isActive { + // The normal, graceful termination procedure is that the frontend sends a Terminate + // message and immediately closes the connection. On receipt of this message, the + // backend closes the connection and terminates. + self.encoder.terminate() + context.writeAndFlush(self.wrapOutboundOut(self.encoder.flushBuffer()), promise: nil) + } + context.close(mode: .all, promise: promise) + case .succeedPreparedStatementCreation(let promise, with: let rowDescription): + promise.succeed(rowDescription) + case .failPreparedStatementCreation(let promise, with: let error, let cleanupContext): + promise.fail(error) + if let cleanupContext = cleanupContext { + self.closeConnectionAndCleanup(cleanupContext, context: context) + } + case .sendCloseSync(let sendClose): + self.sendCloseAndSyncMessage(sendClose, context: context) + case .succeedClose(let closeContext): + closeContext.promise.succeed(Void()) + case .failClose(let closeContext, with: let error, let cleanupContext): + closeContext.promise.fail(error) + if let cleanupContext = cleanupContext { + self.closeConnectionAndCleanup(cleanupContext, context: context) + } + case .forwardNotificationToListeners(let notification): + self.forwardNotificationToListeners(notification, context: context) + } + } + + // MARK: - Private Methods - + + private func connected(context: ChannelHandlerContext) { + let action = self.state.connected(tls: .init(self.configuration.tls)) + self.run(action, with: context) + } + + private func establishSSLConnection(context: ChannelHandlerContext) { + // This method must only be called, if we signalized the StateMachine before that we are + // able to setup a SSL connection. + do { + try self.configureSSLCallback!(context.channel, self) + let action = self.state.sslHandlerAdded() + self.run(action, with: context) + } catch { + let action = self.state.errorHappened(.failedToAddSSLHandler(underlying: error)) + self.run(action, with: context) + } + } + + private func sendPasswordMessage( + mode: PasswordAuthencationMode, + authContext: AuthContext, + context: ChannelHandlerContext + ) { + switch mode { + case .md5(let salt): + let hash1 = (authContext.password ?? "") + authContext.username + let pwdhash = Insecure.MD5.hash(data: [UInt8](hash1.utf8)).asciiHexDigest() + + var hash2 = [UInt8]() + hash2.reserveCapacity(pwdhash.count + 4) + hash2.append(contentsOf: pwdhash) + var saltNetworkOrder = salt.bigEndian + withUnsafeBytes(of: &saltNetworkOrder) { ptr in + hash2.append(contentsOf: ptr) + } + let hash = Insecure.MD5.hash(data: hash2).md5PrefixHexdigest() + + self.encoder.password(hash.utf8) + context.writeAndFlush(self.wrapOutboundOut(self.encoder.flushBuffer()), promise: nil) + + case .cleartext: + self.encoder.password((authContext.password ?? "").utf8) + context.writeAndFlush(self.wrapOutboundOut(self.encoder.flushBuffer()), promise: nil) + } + } + + private func sendCloseAndSyncMessage(_ sendClose: CloseTarget, context: ChannelHandlerContext) { + switch sendClose { + case .preparedStatement(let name): + self.encoder.closePreparedStatement(name) + self.encoder.sync() + context.writeAndFlush(self.wrapOutboundOut(self.encoder.flushBuffer()), promise: nil) + + case .portal(let name): + self.encoder.closePortal(name) + self.encoder.sync() + context.writeAndFlush(self.wrapOutboundOut(self.encoder.flushBuffer()), promise: nil) + } + } + + private func sendParseDescribeAndSyncMessage( + statementName: String, + query: String, + bindingDataTypes: [PostgresDataType], + context: ChannelHandlerContext + ) { + precondition(self.rowStream == nil, "Expected to not have an open stream at this point") + self.encoder.parse(preparedStatementName: statementName, query: query, parameters: bindingDataTypes) + self.encoder.describePreparedStatement(statementName) + self.encoder.sync() + context.writeAndFlush(self.wrapOutboundOut(self.encoder.flushBuffer()), promise: nil) + } + + private func sendBindExecuteAndSyncMessage( + executeStatement: PSQLExecuteStatement, + context: ChannelHandlerContext + ) { + self.encoder.bind( + portalName: "", + preparedStatementName: executeStatement.name, + bind: executeStatement.binds + ) + self.encoder.execute(portalName: "") + self.encoder.sync() + context.writeAndFlush(self.wrapOutboundOut(self.encoder.flushBuffer()), promise: nil) + } + + private func sendParseDescribeBindExecuteAndSyncMessage( + query: PostgresQuery, + context: ChannelHandlerContext + ) { + precondition(self.rowStream == nil, "Expected to not have an open stream at this point") + let unnamedStatementName = "" + self.encoder.parse( + preparedStatementName: unnamedStatementName, + query: query.sql, + parameters: query.binds.metadata.lazy.map(\.dataType) + ) + self.encoder.describePreparedStatement(unnamedStatementName) + self.encoder.bind(portalName: "", preparedStatementName: unnamedStatementName, bind: query.binds) + self.encoder.execute(portalName: "") + self.encoder.sync() + context.writeAndFlush(self.wrapOutboundOut(self.encoder.flushBuffer()), promise: nil) + } + + private func succeedQuery( + _ promise: EventLoopPromise, + result: QueryResult, + context: ChannelHandlerContext + ) { + let rows: PSQLRowStream + switch result.value { + case .rowDescription(let columns): + rows = PSQLRowStream( + source: .stream(columns, self), + eventLoop: context.channel.eventLoop, + logger: result.logger + ) + self.rowStream = rows + + case .noRows(let summary): + rows = PSQLRowStream( + source: .noRows(.success(summary)), + eventLoop: context.channel.eventLoop, + logger: result.logger + ) + } + + promise.succeed(rows) + } + + private func closeConnectionAndCleanup( + _ cleanup: ConnectionStateMachine.ConnectionAction.CleanUpContext, + context: ChannelHandlerContext + ) { + self.logger.debug("Cleaning up and closing connection.", metadata: [.error: "\(cleanup.error)"]) + + // 1. fail all tasks + cleanup.tasks.forEach { task in + task.failWithError(cleanup.error) + } + + // 2. stop all listeners + for listener in self.listenState.fail(cleanup.error) { + listener.failed(cleanup.error) + } + + // 3. fire an error + if cleanup.error.code != .clientClosedConnection { + context.fireErrorCaught(cleanup.error) + } + + // 4. close the connection or fire channel inactive + switch cleanup.action { + case .close: + context.close(mode: .all, promise: cleanup.closePromise) + case .fireChannelInactive: + cleanup.closePromise?.succeed(()) + context.fireChannelInactive() + } + } + + private func makeStartListeningQuery(channel: String, context: ChannelHandlerContext) -> PSQLTask { + let promise = context.eventLoop.makePromise(of: PSQLRowStream.self) + let query = ExtendedQueryContext( + query: PostgresQuery(unsafeSQL: #"LISTEN "\#(channel)";"#), + logger: self.logger, + promise: promise + ) + let loopBound = NIOLoopBound((self, context), eventLoop: self.eventLoop) + promise.futureResult.whenComplete { result in + let (selfTransferred, context) = loopBound.value + selfTransferred.startListenCompleted(result, for: channel, context: context) + } + + return .extendedQuery(query) + } + + private func startListenCompleted(_ result: Result, for channel: String, context: ChannelHandlerContext) { + switch result { + case .success: + switch self.listenState.startListeningSucceeded(channel: channel) { + case .activateListeners(let listeners): + for list in listeners { + list.startListeningSucceeded(handler: self) + } + + case .stopListening: + let task = self.makeUnlistenQuery(channel: channel, context: context) + let action = self.state.enqueue(task: task) + self.run(action, with: context) + } + + case .failure(let error): + let finalError: PSQLError + if var psqlError = error as? PSQLError { + psqlError.code = .listenFailed + finalError = psqlError + } else { + var psqlError = PSQLError(code: .listenFailed) + psqlError.underlying = error + finalError = psqlError + } + let listeners = self.listenState.startListeningFailed(channel: channel, error: finalError) + for list in listeners { + list.failed(finalError) + } + } + } + + private func makeUnlistenQuery(channel: String, context: ChannelHandlerContext) -> PSQLTask { + let promise = context.eventLoop.makePromise(of: PSQLRowStream.self) + let query = ExtendedQueryContext( + query: PostgresQuery(unsafeSQL: #"UNLISTEN "\#(channel)";"#), + logger: self.logger, + promise: promise + ) + let loopBound = NIOLoopBound((self, context), eventLoop: self.eventLoop) + promise.futureResult.whenComplete { result in + let (selfTransferred, context) = loopBound.value + selfTransferred.stopListenCompleted(result, for: channel, context: context) + } + + return .extendedQuery(query) + } + + private func stopListenCompleted( + _ result: Result, + for channel: String, + context: ChannelHandlerContext + ) { + switch result { + case .success: + switch self.listenState.stopListeningSucceeded(channel: channel) { + case .none: + break + + case .startListening: + let task = self.makeStartListeningQuery(channel: channel, context: context) + let action = self.state.enqueue(task: task) + self.run(action, with: context) + } + + case .failure(let error): + let action = self.state.errorHappened(.unlistenError(underlying: error)) + self.run(action, with: context) + } + } + + private func forwardNotificationToListeners( + _ notification: PostgresBackendMessage.NotificationResponse, + context: ChannelHandlerContext + ) { + switch self.listenState.notificationReceived(channel: notification.channel) { + case .none: + break + + case .notify(let listeners): + for listener in listeners { + listener.notificationReceived(notification) + } + } + } + + private func makePrepareStatementTask( + preparedStatement: PreparedStatementContext, + context: ChannelHandlerContext + ) -> PSQLTask { + let promise = self.eventLoop.makePromise(of: RowDescription?.self) + let loopBound = NIOLoopBound((self, context), eventLoop: self.eventLoop) + promise.futureResult.whenComplete { result in + let (selfTransferred, context) = loopBound.value + switch result { + case .success(let rowDescription): + selfTransferred.prepareStatementComplete( + name: preparedStatement.name, + rowDescription: rowDescription, + context: context + ) + case .failure(let error): + let psqlError: PSQLError + if let error = error as? PSQLError { + psqlError = error + } else { + psqlError = .connectionError(underlying: error) + } + selfTransferred.prepareStatementFailed( + name: preparedStatement.name, + error: psqlError, + context: context + ) + } + } + return .extendedQuery(.init( + name: preparedStatement.name, + query: preparedStatement.sql, + bindingDataTypes: preparedStatement.bindingDataTypes, + logger: preparedStatement.logger, + promise: promise + )) + } + + private func makeExecutePreparedStatementTask( + preparedStatement: PreparedStatementContext, + rowDescription: RowDescription? + ) -> PSQLTask { + return .extendedQuery(.init( + executeStatement: .init( + name: preparedStatement.name, + binds: preparedStatement.bindings, + rowDescription: rowDescription + ), + logger: preparedStatement.logger, + promise: preparedStatement.promise + )) + } + + private func prepareStatementComplete( + name: String, + rowDescription: RowDescription?, + context: ChannelHandlerContext + ) { + let action = self.preparedStatementState.preparationComplete( + name: name, + rowDescription: rowDescription + ) + for preparedStatement in action.statements { + let action = self.state.enqueue(task: .extendedQuery(.init( + executeStatement: .init( + name: preparedStatement.name, + binds: preparedStatement.bindings, + rowDescription: action.rowDescription + ), + logger: preparedStatement.logger, + promise: preparedStatement.promise + )) + ) + self.run(action, with: context) + } + } + + private func prepareStatementFailed( + name: String, + error: PSQLError, + context: ChannelHandlerContext + ) { + let action = self.preparedStatementState.errorHappened( + name: name, + error: error + ) + for statement in action.statements { + statement.promise.fail(action.error) + } + } +} + +extension PostgresChannelHandler: PSQLRowsDataSource { + func request(for stream: PSQLRowStream) { + guard self.rowStream === stream, let handlerContext = self.handlerContext else { + return + } + let action = self.state.requestQueryRows() + self.run(action, with: handlerContext) + } + + func cancel(for stream: PSQLRowStream) { + guard self.rowStream === stream, let handlerContext = self.handlerContext else { + return + } + let action = self.state.cancelQueryStream() + self.run(action, with: handlerContext) + } +} + +private extension Insecure.MD5.Digest { + + private static let lowercaseLookup: [UInt8] = [ + UInt8(ascii: "0"), UInt8(ascii: "1"), UInt8(ascii: "2"), UInt8(ascii: "3"), + UInt8(ascii: "4"), UInt8(ascii: "5"), UInt8(ascii: "6"), UInt8(ascii: "7"), + UInt8(ascii: "8"), UInt8(ascii: "9"), UInt8(ascii: "a"), UInt8(ascii: "b"), + UInt8(ascii: "c"), UInt8(ascii: "d"), UInt8(ascii: "e"), UInt8(ascii: "f"), + ] + + func asciiHexDigest() -> [UInt8] { + var result = [UInt8]() + result.reserveCapacity(2 * Insecure.MD5Digest.byteCount) + for byte in self { + result.append(Self.lowercaseLookup[Int(byte >> 4)]) + result.append(Self.lowercaseLookup[Int(byte & 0x0F)]) + } + return result + } + + func md5PrefixHexdigest() -> String { + // TODO: The array should be stack allocated in the best case. But we support down to 5.2. + // Given that this method is called only on startup of a new connection, this is an + // okay tradeoff for now. + var result = [UInt8]() + result.reserveCapacity(3 + 2 * Insecure.MD5Digest.byteCount) + result.append(UInt8(ascii: "m")) + result.append(UInt8(ascii: "d")) + result.append(UInt8(ascii: "5")) + + for byte in self { + result.append(Self.lowercaseLookup[Int(byte >> 4)]) + result.append(Self.lowercaseLookup[Int(byte & 0x0F)]) + } + return String(decoding: result, as: Unicode.UTF8.self) + } +} + +extension ConnectionStateMachine.TLSConfiguration { + fileprivate init(_ tls: PostgresConnection.Configuration.TLS) { + switch (tls.isAllowed, tls.isEnforced) { + case (false, _): + self = .disable + case (true, true): + self = .require + case (true, false): + self = .prefer + } + } +} diff --git a/Sources/PostgresNIO/New/PostgresCodable.swift b/Sources/PostgresNIO/New/PostgresCodable.swift new file mode 100644 index 00000000..fd82c8ea --- /dev/null +++ b/Sources/PostgresNIO/New/PostgresCodable.swift @@ -0,0 +1,235 @@ +import NIOCore +import class Foundation.JSONEncoder +import class Foundation.JSONDecoder + +/// A type that can encode itself to a Postgres wire binary representation. +/// Dynamic types are types that don't have a well-known Postgres type OID at compile time. +/// For example, custom types created at runtime, such as enums, or extension types whose OID is not stable between +/// databases. +public protocol PostgresThrowingDynamicTypeEncodable { + /// The data type encoded into the `byteBuffer` in ``encode(into:context:)`` + var psqlType: PostgresDataType { get } + + /// The Postgres encoding format used to encode the value into `byteBuffer` in ``encode(into:context:)``. + var psqlFormat: PostgresFormat { get } + + /// Encode the entity into ``byteBuffer`` in the format specified by ``psqlFormat``, + /// using the provided ``context`` as needed, without setting the byte count. + /// + /// This method is called by ``PostgresBindings``. + func encode( + into byteBuffer: inout ByteBuffer, + context: PostgresEncodingContext + ) throws +} + +/// A type that can encode itself to a Postgres wire binary representation. +/// Dynamic types are types that don't have a well-known Postgres type OID at compile time. +/// For example, custom types created at runtime, such as enums, or extension types whose OID is not stable between +/// databases. +/// +/// This is the non-throwing alternative to ``PostgresThrowingDynamicTypeEncodable``. It allows users +/// to create ``PostgresQuery``s via `ExpressibleByStringInterpolation` without having to spell `try`. +public protocol PostgresDynamicTypeEncodable: PostgresThrowingDynamicTypeEncodable { + /// Encode the entity into ``byteBuffer`` in the format specified by ``psqlFormat``, + /// using the provided ``context`` as needed, without setting the byte count. + /// + /// This method is called by ``PostgresBindings``. + func encode( + into byteBuffer: inout ByteBuffer, + context: PostgresEncodingContext + ) +} + +/// A type that can encode itself to a postgres wire binary representation. +public protocol PostgresEncodable: PostgresThrowingDynamicTypeEncodable { + // TODO: Rename to `PostgresThrowingEncodable` with next major release + + /// The data type encoded into the `byteBuffer` in ``encode(into:context:)``. + static var psqlType: PostgresDataType { get } + + /// The Postgres encoding format used to encode the value into `byteBuffer` in ``encode(into:context:)``. + static var psqlFormat: PostgresFormat { get } +} + +/// A type that can encode itself to a postgres wire binary representation. It enforces that the +/// ``PostgresEncodable/encode(into:context:)-1jkcp`` does not throw. This allows users +/// to create ``PostgresQuery``s via `ExpressibleByStringInterpolation` without +/// having to spell `try`. +public protocol PostgresNonThrowingEncodable: PostgresEncodable, PostgresDynamicTypeEncodable { + // TODO: Rename to `PostgresEncodable` with next major release +} + +/// A type that can decode itself from a postgres wire binary representation. +/// +/// If you want to conform a type to PostgresDecodable you must implement the decode method. +public protocol PostgresDecodable { + /// A type definition of the type that actually implements the PostgresDecodable protocol. This is an escape hatch to + /// prevent a cycle in the conformace of the Optional type to PostgresDecodable. + /// + /// String? should be PostgresDecodable, String?? should not be PostgresDecodable + associatedtype _DecodableType: PostgresDecodable = Self + + /// Create an entity from the `byteBuffer` in postgres wire format + /// + /// - Parameters: + /// - byteBuffer: A `ByteBuffer` to decode. The byteBuffer is sliced in such a way that it is expected + /// that the complete buffer is consumed for decoding + /// - type: The postgres data type. Depending on this type the `byteBuffer`'s bytes need to be interpreted + /// in different ways. + /// - format: The postgres wire format. Can be `.text` or `.binary` + /// - context: A `PSQLDecodingContext` providing context for decoding. This includes a `JSONDecoder` + /// to use when decoding json and metadata to create better errors. + /// - Returns: A decoded object + init( + from byteBuffer: inout ByteBuffer, + type: PostgresDataType, + format: PostgresFormat, + context: PostgresDecodingContext + ) throws + + /// Decode an entity from the `byteBuffer` in postgres wire format. This method has a default implementation and + /// is only overwritten for `Optional`s. Other than in the + static func _decodeRaw( + from byteBuffer: inout ByteBuffer?, + type: PostgresDataType, + format: PostgresFormat, + context: PostgresDecodingContext + ) throws -> Self +} + +extension PostgresDecodable { + @inlinable + public static func _decodeRaw( + from byteBuffer: inout ByteBuffer?, + type: PostgresDataType, + format: PostgresFormat, + context: PostgresDecodingContext + ) throws -> Self { + guard var buffer = byteBuffer else { + throw PostgresDecodingError.Code.missingData + } + return try self.init(from: &buffer, type: type, format: format, context: context) + } +} + +/// A type that can be encoded into and decoded from a postgres binary format +public typealias PostgresCodable = PostgresEncodable & PostgresDecodable + +extension PostgresEncodable { + @inlinable + public var psqlType: PostgresDataType { Self.psqlType } + + @inlinable + public var psqlFormat: PostgresFormat { Self.psqlFormat } +} + +extension PostgresThrowingDynamicTypeEncodable { + @inlinable + func encodeRaw( + into buffer: inout ByteBuffer, + context: PostgresEncodingContext + ) throws { + // The length of the parameter value, in bytes (this count does not include + // itself). Can be zero. + let lengthIndex = buffer.writerIndex + buffer.writeInteger(0, as: Int32.self) + let startIndex = buffer.writerIndex + // The value of the parameter, in the format indicated by the associated format + // code. n is the above length. + try self.encode(into: &buffer, context: context) + + // overwrite the empty length, with the real value + buffer.setInteger(numericCast(buffer.writerIndex - startIndex), at: lengthIndex, as: Int32.self) + } +} + +extension PostgresDynamicTypeEncodable { + @inlinable + func encodeRaw( + into buffer: inout ByteBuffer, + context: PostgresEncodingContext + ) { + // The length of the parameter value, in bytes (this count does not include + // itself). Can be zero. + let lengthIndex = buffer.writerIndex + buffer.writeInteger(0, as: Int32.self) + let startIndex = buffer.writerIndex + // The value of the parameter, in the format indicated by the associated format + // code. n is the above length. + self.encode(into: &buffer, context: context) + + // overwrite the empty length, with the real value + buffer.setInteger(numericCast(buffer.writerIndex - startIndex), at: lengthIndex, as: Int32.self) + } +} + +/// A context that is passed to Swift objects that are encoded into the Postgres wire format. Used +/// to pass further information to the encoding method. +public struct PostgresEncodingContext: Sendable { + /// A ``PostgresJSONEncoder`` used to encode the object to json. + public var jsonEncoder: JSONEncoder + + /// Creates a ``PostgresEncodingContext`` with the given ``PostgresJSONEncoder``. In case you want + /// to use the a ``PostgresEncodingContext`` with an unconfigured Foundation `JSONEncoder` + /// you can use the ``default`` context instead. + /// + /// - Parameter jsonEncoder: A ``PostgresJSONEncoder`` to use when encoding objects to json + public init(jsonEncoder: JSONEncoder) { + self.jsonEncoder = jsonEncoder + } +} + +extension PostgresEncodingContext where JSONEncoder == Foundation.JSONEncoder { + /// A default ``PostgresEncodingContext`` that uses a Foundation `JSONEncoder`. + public static let `default` = PostgresEncodingContext(jsonEncoder: JSONEncoder()) +} + +/// A context that is passed to Swift objects that are decoded from the Postgres wire format. Used +/// to pass further information to the decoding method. +public struct PostgresDecodingContext: Sendable { + /// A ``PostgresJSONDecoder`` used to decode the object from json. + public var jsonDecoder: JSONDecoder + + /// Creates a ``PostgresDecodingContext`` with the given ``PostgresJSONDecoder``. In case you want + /// to use the a ``PostgresDecodingContext`` with an unconfigured Foundation `JSONDecoder` + /// you can use the ``default`` context instead. + /// + /// - Parameter jsonDecoder: A ``PostgresJSONDecoder`` to use when decoding objects from json + public init(jsonDecoder: JSONDecoder) { + self.jsonDecoder = jsonDecoder + } +} + +extension PostgresDecodingContext where JSONDecoder == Foundation.JSONDecoder { + /// A default ``PostgresDecodingContext`` that uses a Foundation `JSONDecoder`. + public static let `default` = PostgresDecodingContext(jsonDecoder: Foundation.JSONDecoder()) +} + +extension Optional: PostgresDecodable where Wrapped: PostgresDecodable, Wrapped._DecodableType == Wrapped { + public typealias _DecodableType = Wrapped + + public init( + from byteBuffer: inout ByteBuffer, + type: PostgresDataType, + format: PostgresFormat, + context: PostgresDecodingContext + ) throws { + preconditionFailure("This should not be called") + } + + @inlinable + public static func _decodeRaw( + from byteBuffer: inout ByteBuffer?, + type: PostgresDataType, + format: PostgresFormat, + context: PostgresDecodingContext + ) throws -> Optional { + switch byteBuffer { + case .some(var buffer): + return try Wrapped(from: &buffer, type: type, format: format, context: context) + case .none: + return .none + } + } +} diff --git a/Sources/PostgresNIO/New/PostgresFrontendMessageEncoder.swift b/Sources/PostgresNIO/New/PostgresFrontendMessageEncoder.swift new file mode 100644 index 00000000..97805418 --- /dev/null +++ b/Sources/PostgresNIO/New/PostgresFrontendMessageEncoder.swift @@ -0,0 +1,233 @@ +import NIOCore + +struct PostgresFrontendMessageEncoder { + + /// The SSL request code. The value is chosen to contain 1234 in the most significant 16 bits, + /// and 5679 in the least significant 16 bits. + static let sslRequestCode: Int32 = 80877103 + + /// The cancel request code. The value is chosen to contain 1234 in the most significant 16 bits, + /// and 5678 in the least significant 16 bits. (To avoid confusion, this code must not be the same + /// as any protocol version number.) + static let cancelRequestCode: Int32 = 80877102 + + static let startupVersionThree: Int32 = 0x00_03_00_00 + + private enum State { + case flushed + case writable + } + + private var buffer: ByteBuffer + private var state: State = .writable + + init(buffer: ByteBuffer) { + self.buffer = buffer + } + + mutating func startup(user: String, database: String?, options: [(String, String)]) { + self.clearIfNeeded() + self.buffer.psqlLengthPrefixed { buffer in + buffer.writeInteger(Self.startupVersionThree) + buffer.writeNullTerminatedString("user") + buffer.writeNullTerminatedString(user) + + if let database = database { + buffer.writeNullTerminatedString("database") + buffer.writeNullTerminatedString(database) + } + + // we don't send replication parameters, as the default is false and this is what we + // need for a client + for (key, value) in options { + buffer.writeNullTerminatedString(key) + buffer.writeNullTerminatedString(value) + } + + buffer.writeInteger(UInt8(0)) + } + } + + mutating func bind(portalName: String, preparedStatementName: String, bind: PostgresBindings) { + self.clearIfNeeded() + self.buffer.psqlLengthPrefixed(id: .bind) { buffer in + buffer.writeNullTerminatedString(portalName) + buffer.writeNullTerminatedString(preparedStatementName) + + // The number of parameter format codes that follow (denoted C below). This can be + // zero to indicate that there are no parameters or that the parameters all use the + // default format (text); or one, in which case the specified format code is applied + // to all parameters; or it can equal the actual number of parameters. + buffer.writeInteger(UInt16(bind.count)) + + // The parameter format codes. Each must presently be zero (text) or one (binary). + bind.metadata.forEach { + buffer.writeInteger($0.format.rawValue) + } + + buffer.writeInteger(UInt16(bind.count)) + + var parametersCopy = bind.bytes + buffer.writeBuffer(¶metersCopy) + + // The number of result-column format codes that follow (denoted R below). This can be + // zero to indicate that there are no result columns or that the result columns should + // all use the default format (text); or one, in which case the specified format code + // is applied to all result columns (if any); or it can equal the actual number of + // result columns of the query. + buffer.writeInteger(1, as: Int16.self) + // The result-column format codes. Each must presently be zero (text) or one (binary). + buffer.writeInteger(PostgresFormat.binary.rawValue, as: Int16.self) + } + } + + mutating func cancel(processID: Int32, secretKey: Int32) { + self.clearIfNeeded() + self.buffer.writeMultipleIntegers(UInt32(16), Self.cancelRequestCode, processID, secretKey) + } + + mutating func closePreparedStatement(_ preparedStatement: String) { + self.clearIfNeeded() + self.buffer.psqlWriteMultipleIntegers(id: .close, length: UInt32(2 + preparedStatement.utf8.count), UInt8(ascii: "S")) + self.buffer.writeNullTerminatedString(preparedStatement) + } + + mutating func closePortal(_ portal: String) { + self.clearIfNeeded() + self.buffer.psqlWriteMultipleIntegers(id: .close, length: UInt32(2 + portal.utf8.count), UInt8(ascii: "P")) + self.buffer.writeNullTerminatedString(portal) + } + + mutating func describePreparedStatement(_ preparedStatement: String) { + self.clearIfNeeded() + self.buffer.psqlWriteMultipleIntegers(id: .describe, length: UInt32(2 + preparedStatement.utf8.count), UInt8(ascii: "S")) + self.buffer.writeNullTerminatedString(preparedStatement) + } + + mutating func describePortal(_ portal: String) { + self.clearIfNeeded() + self.buffer.psqlWriteMultipleIntegers(id: .describe, length: UInt32(2 + portal.utf8.count), UInt8(ascii: "P")) + self.buffer.writeNullTerminatedString(portal) + } + + mutating func execute(portalName: String, maxNumberOfRows: Int32 = 0) { + self.clearIfNeeded() + self.buffer.psqlWriteMultipleIntegers(id: .execute, length: UInt32(5 + portalName.utf8.count)) + self.buffer.writeNullTerminatedString(portalName) + self.buffer.writeInteger(maxNumberOfRows) + } + + mutating func parse(preparedStatementName: String, query: String, parameters: Parameters) where Parameters.Element == PostgresDataType { + self.clearIfNeeded() + self.buffer.psqlWriteMultipleIntegers( + id: .parse, + length: UInt32(preparedStatementName.utf8.count + 1 + query.utf8.count + 1 + 2 + MemoryLayout.size * parameters.count) + ) + self.buffer.writeNullTerminatedString(preparedStatementName) + self.buffer.writeNullTerminatedString(query) + self.buffer.writeInteger(UInt16(parameters.count)) + + for dataType in parameters { + self.buffer.writeInteger(dataType.rawValue) + } + } + + mutating func password(_ bytes: Bytes) where Bytes.Element == UInt8 { + self.clearIfNeeded() + self.buffer.psqlWriteMultipleIntegers(id: .password, length: UInt32(bytes.count) + 1) + self.buffer.writeBytes(bytes) + self.buffer.writeInteger(UInt8(0)) + } + + mutating func flush() { + self.clearIfNeeded() + self.buffer.psqlWriteMultipleIntegers(id: .flush, length: 0) + } + + mutating func saslResponse(_ bytes: Bytes) where Bytes.Element == UInt8 { + self.clearIfNeeded() + self.buffer.psqlWriteMultipleIntegers(id: .password, length: UInt32(bytes.count)) + self.buffer.writeBytes(bytes) + } + + mutating func saslInitialResponse(mechanism: String, bytes: Bytes) where Bytes.Element == UInt8 { + self.clearIfNeeded() + self.buffer.psqlWriteMultipleIntegers(id: .password, length: UInt32(mechanism.utf8.count + 1 + 4 + bytes.count)) + self.buffer.writeNullTerminatedString(mechanism) + if bytes.count > 0 { + self.buffer.writeInteger(Int32(bytes.count)) + self.buffer.writeBytes(bytes) + } else { + self.buffer.writeInteger(Int32(-1)) + } + } + + mutating func ssl() { + self.clearIfNeeded() + self.buffer.writeMultipleIntegers(UInt32(8), Self.sslRequestCode) + } + + mutating func sync() { + self.clearIfNeeded() + self.buffer.psqlWriteMultipleIntegers(id: .sync, length: 0) + } + + mutating func terminate() { + self.clearIfNeeded() + self.buffer.psqlWriteMultipleIntegers(id: .terminate, length: 0) + } + + mutating func flushBuffer() -> ByteBuffer { + self.state = .flushed + return self.buffer + } + + private mutating func clearIfNeeded() { + switch self.state { + case .flushed: + self.state = .writable + self.buffer.clear() + + case .writable: + break + } + } +} + +private enum FrontendMessageID: UInt8, Hashable, Sendable { + case bind = 66 // B + case close = 67 // C + case describe = 68 // D + case execute = 69 // E + case flush = 72 // H + case parse = 80 // P + case password = 112 // p - also both sasl values + case sync = 83 // S + case terminate = 88 // X +} + +extension ByteBuffer { + mutating fileprivate func psqlWriteMultipleIntegers(id: FrontendMessageID, length: UInt32) { + self.writeMultipleIntegers(id.rawValue, 4 + length) + } + + mutating fileprivate func psqlWriteMultipleIntegers(id: FrontendMessageID, length: UInt32, _ t1: T1) { + self.writeMultipleIntegers(id.rawValue, 4 + length, t1) + } + + mutating fileprivate func psqlLengthPrefixed(id: FrontendMessageID, _ encode: (inout ByteBuffer) -> ()) { + let lengthIndex = self.writerIndex + 1 + self.psqlWriteMultipleIntegers(id: id, length: 0) + encode(&self) + let length = UInt32(self.writerIndex - lengthIndex) + self.setInteger(length, at: lengthIndex) + } + + mutating fileprivate func psqlLengthPrefixed(_ encode: (inout ByteBuffer) -> ()) { + let lengthIndex = self.writerIndex + self.writeInteger(UInt32(0)) // placeholder + encode(&self) + let length = UInt32(self.writerIndex - lengthIndex) + self.setInteger(length, at: lengthIndex) + } +} diff --git a/Sources/PostgresNIO/New/PostgresNotificationSequence.swift b/Sources/PostgresNIO/New/PostgresNotificationSequence.swift new file mode 100644 index 00000000..d8f525eb --- /dev/null +++ b/Sources/PostgresNIO/New/PostgresNotificationSequence.swift @@ -0,0 +1,25 @@ + +public struct PostgresNotification: Sendable { + public let payload: String +} + +public struct PostgresNotificationSequence: AsyncSequence, Sendable { + public typealias Element = PostgresNotification + + let base: AsyncThrowingStream + + public func makeAsyncIterator() -> AsyncIterator { + AsyncIterator(base: self.base.makeAsyncIterator()) + } + + public struct AsyncIterator: AsyncIteratorProtocol { + var base: AsyncThrowingStream.AsyncIterator + + public mutating func next() async throws -> Element? { + try await self.base.next() + } + } +} + +@available(*, unavailable) +extension PostgresNotificationSequence.AsyncIterator: Sendable {} diff --git a/Sources/PostgresNIO/New/PostgresQuery.swift b/Sources/PostgresNIO/New/PostgresQuery.swift new file mode 100644 index 00000000..6449ab29 --- /dev/null +++ b/Sources/PostgresNIO/New/PostgresQuery.swift @@ -0,0 +1,359 @@ +import NIOCore + +/// A Postgres SQL query, that can be executed on a Postgres server. Contains the raw sql string and bindings. +public struct PostgresQuery: Sendable, Hashable { + /// The query string + public var sql: String + /// The query binds + public var binds: PostgresBindings + + public init(unsafeSQL sql: String, binds: PostgresBindings = PostgresBindings()) { + self.sql = sql + self.binds = binds + } +} + +extension PostgresQuery: ExpressibleByStringInterpolation { + public init(stringInterpolation: StringInterpolation) { + self.sql = stringInterpolation.sql + self.binds = stringInterpolation.binds + } + + public init(stringLiteral value: String) { + self.sql = value + self.binds = PostgresBindings() + } +} + +extension PostgresQuery { + public struct StringInterpolation: StringInterpolationProtocol, Sendable { + public typealias StringLiteralType = String + + @usableFromInline + var sql: String + @usableFromInline + var binds: PostgresBindings + + public init(literalCapacity: Int, interpolationCount: Int) { + self.sql = "" + self.binds = PostgresBindings(capacity: interpolationCount) + } + + public mutating func appendLiteral(_ literal: String) { + self.sql.append(contentsOf: literal) + } + + @inlinable + public mutating func appendInterpolation(_ value: Value) throws { + try self.binds.append(value, context: .default) + self.sql.append(contentsOf: "$\(self.binds.count)") + } + + @inlinable + public mutating func appendInterpolation(_ value: Optional) throws { + switch value { + case .none: + self.binds.appendNull() + case .some(let value): + try self.binds.append(value, context: .default) + } + + self.sql.append(contentsOf: "$\(self.binds.count)") + } + + @inlinable + public mutating func appendInterpolation(_ value: Value) { + self.binds.append(value, context: .default) + self.sql.append(contentsOf: "$\(self.binds.count)") + } + + @inlinable + public mutating func appendInterpolation(_ value: Optional) { + switch value { + case .none: + self.binds.appendNull() + case .some(let value): + self.binds.append(value, context: .default) + } + + self.sql.append(contentsOf: "$\(self.binds.count)") + } + + @inlinable + public mutating func appendInterpolation( + _ value: Value, + context: PostgresEncodingContext + ) throws { + try self.binds.append(value, context: context) + self.sql.append(contentsOf: "$\(self.binds.count)") + } + + @inlinable + public mutating func appendInterpolation(unescaped interpolated: String) { + self.sql.append(contentsOf: interpolated) + } + } +} + +extension PostgresQuery: CustomStringConvertible { + // See `CustomStringConvertible.description`. + public var description: String { + "\(self.sql) \(self.binds)" + } +} + +extension PostgresQuery: CustomDebugStringConvertible { + // See `CustomDebugStringConvertible.debugDescription`. + public var debugDescription: String { + "PostgresQuery(sql: \(String(describing: self.sql)), binds: \(String(reflecting: self.binds)))" + } +} + +struct PSQLExecuteStatement { + /// The statements name + var name: String + /// The binds + var binds: PostgresBindings + + var rowDescription: RowDescription? +} + +public struct PostgresBindings: Sendable, Hashable { + @usableFromInline + struct Metadata: Sendable, Hashable { + @usableFromInline + var dataType: PostgresDataType + @usableFromInline + var format: PostgresFormat + @usableFromInline + var protected: Bool + + @inlinable + init(dataType: PostgresDataType, format: PostgresFormat, protected: Bool) { + self.dataType = dataType + self.format = format + self.protected = protected + } + + @inlinable + init(value: Value, protected: Bool) { + self.init(dataType: value.psqlType, format: value.psqlFormat, protected: protected) + } + } + + @usableFromInline + var metadata: [Metadata] + @usableFromInline + var bytes: ByteBuffer + + public var count: Int { + self.metadata.count + } + + public init() { + self.metadata = [] + self.bytes = ByteBuffer() + } + + public init(capacity: Int) { + self.metadata = [] + self.metadata.reserveCapacity(capacity) + self.bytes = ByteBuffer() + self.bytes.reserveCapacity(128 * capacity) + } + + public mutating func appendNull() { + self.bytes.writeInteger(-1, as: Int32.self) + self.metadata.append(.init(dataType: .null, format: .binary, protected: true)) + } + + @inlinable + public mutating func append(_ value: Value) throws { + try self.append(value, context: .default) + } + + @inlinable + public mutating func append(_ value: Optional) throws { + switch value { + case .none: + self.appendNull() + case let .some(value): + try self.append(value) + } + } + + @inlinable + public mutating func append( + _ value: Value, + context: PostgresEncodingContext + ) throws { + try value.encodeRaw(into: &self.bytes, context: context) + self.metadata.append(.init(value: value, protected: true)) + } + + @inlinable + public mutating func append( + _ value: Optional, + context: PostgresEncodingContext + ) throws { + switch value { + case .none: + self.appendNull() + case let .some(value): + try self.append(value, context: context) + } + } + + @inlinable + public mutating func append(_ value: Value) { + self.append(value, context: .default) + } + + @inlinable + public mutating func append(_ value: Optional) { + switch value { + case .none: + self.appendNull() + case let .some(value): + self.append(value) + } + } + + @inlinable + public mutating func append( + _ value: Value, + context: PostgresEncodingContext + ) { + value.encodeRaw(into: &self.bytes, context: context) + self.metadata.append(.init(value: value, protected: true)) + } + + @inlinable + public mutating func append( + _ value: Optional, + context: PostgresEncodingContext + ) { + switch value { + case .none: + self.appendNull() + case let .some(value): + self.append(value, context: context) + } + } + + @inlinable + mutating func appendUnprotected( + _ value: Value, + context: PostgresEncodingContext + ) throws { + try value.encodeRaw(into: &self.bytes, context: context) + self.metadata.append(.init(value: value, protected: false)) + } + + @inlinable + mutating func appendUnprotected( + _ value: Value, + context: PostgresEncodingContext + ) { + value.encodeRaw(into: &self.bytes, context: context) + self.metadata.append(.init(value: value, protected: false)) + } + + public mutating func append(_ postgresData: PostgresData) { + switch postgresData.value { + case .none: + self.bytes.writeInteger(-1, as: Int32.self) + case .some(var input): + self.bytes.writeInteger(Int32(input.readableBytes)) + self.bytes.writeBuffer(&input) + } + self.metadata.append(.init(dataType: postgresData.type, format: .binary, protected: true)) + } +} + +extension PostgresBindings: CustomStringConvertible, CustomDebugStringConvertible { + // See `CustomStringConvertible.description`. + public var description: String { + """ + [\(zip(self.metadata, BindingsReader(buffer: self.bytes)) + .lazy.map({ Self.makeBindingPrintable(protected: $0.protected, type: $0.dataType, format: $0.format, buffer: $1) }) + .joined(separator: ", "))] + """ + } + + // See `CustomDebugStringConvertible.description`. + public var debugDescription: String { + """ + [\(zip(self.metadata, BindingsReader(buffer: self.bytes)) + .lazy.map({ Self.makeDebugDescription(protected: $0.protected, type: $0.dataType, format: $0.format, buffer: $1) }) + .joined(separator: ", "))] + """ + } + + private static func makeDebugDescription(protected: Bool, type: PostgresDataType, format: PostgresFormat, buffer: ByteBuffer?) -> String { + "(\(Self.makeBindingPrintable(protected: protected, type: type, format: format, buffer: buffer)); \(type); format: \(format))" + } + + private static func makeBindingPrintable(protected: Bool, type: PostgresDataType, format: PostgresFormat, buffer: ByteBuffer?) -> String { + if protected { + return "****" + } + + guard var buffer = buffer else { + return "null" + } + + do { + switch (type, format) { + case (.int4, _), (.int2, _), (.int8, _): + let number = try Int64.init(from: &buffer, type: type, format: format, context: .default) + return String(describing: number) + + case (.bool, _): + let bool = try Bool.init(from: &buffer, type: type, format: format, context: .default) + return String(describing: bool) + + case (.varchar, _), (.bpchar, _), (.text, _), (.name, _): + let value = try String.init(from: &buffer, type: type, format: format, context: .default) + return String(reflecting: value) // adds quotes + + default: + return "\(buffer.readableBytes) bytes" + } + } catch { + return "\(buffer.readableBytes) bytes" + } + } +} + +/// A small helper to inspect encoded bindings +private struct BindingsReader: Sequence { + typealias Element = Optional + + var buffer: ByteBuffer + + struct Iterator: IteratorProtocol { + typealias Element = Optional + private var buffer: ByteBuffer + + init(buffer: ByteBuffer) { + self.buffer = buffer + } + + mutating func next() -> Optional> { + guard let length = self.buffer.readInteger(as: Int32.self) else { + return .none + } + + if length < 0 { + return .some(.none) + } + + return .some(self.buffer.readSlice(length: Int(length))!) + } + } + + func makeIterator() -> Iterator { + Iterator(buffer: self.buffer) + } +} diff --git a/Sources/PostgresNIO/New/PostgresRowSequence.swift b/Sources/PostgresNIO/New/PostgresRowSequence.swift new file mode 100644 index 00000000..3936b51e --- /dev/null +++ b/Sources/PostgresNIO/New/PostgresRowSequence.swift @@ -0,0 +1,116 @@ +import NIOCore +import NIOConcurrencyHelpers + +/// An async sequence of ``PostgresRow``s. +/// +/// - Note: This is a struct to allow us to move to a move only type easily once they become available. +public struct PostgresRowSequence: AsyncSequence, Sendable { + public typealias Element = PostgresRow + + typealias BackingSequence = NIOThrowingAsyncSequenceProducer + + let backing: BackingSequence + let lookupTable: [String: Int] + let columns: [RowDescription.Column] + + init(_ backing: BackingSequence, lookupTable: [String: Int], columns: [RowDescription.Column]) { + self.backing = backing + self.lookupTable = lookupTable + self.columns = columns + } + + public func makeAsyncIterator() -> AsyncIterator { + AsyncIterator( + backing: self.backing.makeAsyncIterator(), + lookupTable: self.lookupTable, + columns: self.columns + ) + } +} + +extension PostgresRowSequence { + public struct AsyncIterator: AsyncIteratorProtocol { + public typealias Element = PostgresRow + + let backing: BackingSequence.AsyncIterator + + let lookupTable: [String: Int] + let columns: [RowDescription.Column] + + init(backing: BackingSequence.AsyncIterator, lookupTable: [String: Int], columns: [RowDescription.Column]) { + self.backing = backing + self.lookupTable = lookupTable + self.columns = columns + } + + public mutating func next() async throws -> PostgresRow? { + if let dataRow = try await self.backing.next() { + return PostgresRow( + data: dataRow, + lookupTable: self.lookupTable, + columns: self.columns + ) + } + return nil + } + } +} + +@available(*, unavailable) +extension PostgresRowSequence.AsyncIterator: Sendable {} + +extension PostgresRowSequence { + public func collect() async throws -> [PostgresRow] { + var result = [PostgresRow]() + for try await row in self { + result.append(row) + } + return result + } +} + +struct AdaptiveRowBuffer: NIOAsyncSequenceProducerBackPressureStrategy { + static let defaultBufferTarget = 256 + static let defaultBufferMinimum = 1 + static let defaultBufferMaximum = 16384 + + let minimum: Int + let maximum: Int + + private var target: Int + private var canShrink: Bool = false + + init(minimum: Int, maximum: Int, target: Int) { + precondition(minimum <= target && target <= maximum) + self.minimum = minimum + self.maximum = maximum + self.target = target + } + + init() { + self.init( + minimum: Self.defaultBufferMinimum, + maximum: Self.defaultBufferMaximum, + target: Self.defaultBufferTarget + ) + } + + mutating func didYield(bufferDepth: Int) -> Bool { + if bufferDepth > self.target, self.canShrink, self.target > self.minimum { + self.target &>>= 1 + } + self.canShrink = true + + return false // bufferDepth < self.target + } + + mutating func didConsume(bufferDepth: Int) -> Bool { + // If the buffer is drained now, we should double our target size. + if bufferDepth == 0, self.target < self.maximum { + self.target = self.target * 2 + self.canShrink = false + } + + return bufferDepth < self.target + } +} diff --git a/Sources/PostgresNIO/New/PostgresTransactionError.swift b/Sources/PostgresNIO/New/PostgresTransactionError.swift new file mode 100644 index 00000000..35038446 --- /dev/null +++ b/Sources/PostgresNIO/New/PostgresTransactionError.swift @@ -0,0 +1,21 @@ +/// A wrapper around the errors that can occur during a transaction. +public struct PostgresTransactionError: Error { + + /// The file in which the transaction was started + public var file: String + /// The line in which the transaction was started + public var line: Int + + /// The error thrown when running the `BEGIN` query + public var beginError: Error? + /// The error thrown in the transaction closure + public var closureError: Error? + + /// The error thrown while rolling the transaction back. If the ``closureError`` is set, + /// but the ``rollbackError`` is empty, the rollback was successful. If the ``rollbackError`` + /// is set, the rollback failed. + public var rollbackError: Error? + + /// The error thrown while commiting the transaction. + public var commitError: Error? +} diff --git a/Sources/PostgresNIO/New/PreparedStatement.swift b/Sources/PostgresNIO/New/PreparedStatement.swift new file mode 100644 index 00000000..21165388 --- /dev/null +++ b/Sources/PostgresNIO/New/PreparedStatement.swift @@ -0,0 +1,61 @@ +/// A prepared statement. +/// +/// Structs conforming to this protocol will need to provide the SQL statement to +/// send to the server and a way of creating bindings are decoding the result. +/// +/// As an example, consider this struct: +/// ```swift +/// struct Example: PostgresPreparedStatement { +/// static let sql = "SELECT pid, datname FROM pg_stat_activity WHERE state = $1" +/// typealias Row = (Int, String) +/// +/// var state: String +/// +/// func makeBindings() -> PostgresBindings { +/// var bindings = PostgresBindings() +/// bindings.append(self.state) +/// return bindings +/// } +/// +/// func decodeRow(_ row: PostgresNIO.PostgresRow) throws -> Row { +/// try row.decode(Row.self) +/// } +/// } +/// ``` +/// +/// Structs conforming to this protocol can then be used with `PostgresConnection.execute(_ preparedStatement:, logger:)`, +/// which will take care of preparing the statement on the server side and executing it. +public protocol PostgresPreparedStatement: Sendable { + /// The prepared statements name. + /// + /// > Note: There is a default implementation that returns the implementor's name. + static var name: String { get } + + /// The type rows returned by the statement will be decoded into + associatedtype Row + + /// The SQL statement to prepare on the database server. + static var sql: String { get } + + /// The postgres data types of the values that are bind when this statement is executed. + /// + /// If an empty array is returned the datatypes are inferred from the ``PostgresBindings`` returned + /// from ``PostgresPreparedStatement/makeBindings()``. + /// + /// > Note: There is a default implementation that returns an empty array, which will lead to + /// automatic inference. + static var bindingDataTypes: [PostgresDataType] { get } + + /// Make the bindings to provided concrete values to use when executing the prepared SQL statement. + /// The order must match ``PostgresPreparedStatement/bindingDataTypes-4b6tx``. + func makeBindings() throws -> PostgresBindings + + /// Decode a row returned by the database into an instance of `Row` + func decodeRow(_ row: PostgresRow) throws -> Row +} + +extension PostgresPreparedStatement { + public static var name: String { String(reflecting: self) } + + public static var bindingDataTypes: [PostgresDataType] { [] } +} diff --git a/Sources/PostgresNIO/New/VariadicGenerics.swift b/Sources/PostgresNIO/New/VariadicGenerics.swift new file mode 100644 index 00000000..7931c90c --- /dev/null +++ b/Sources/PostgresNIO/New/VariadicGenerics.swift @@ -0,0 +1,172 @@ + +extension PostgresRow { + // --- snip TODO: Remove once bug is fixed, that disallows tuples of one + @inlinable + public func decode( + _: Column.Type, + file: String = #fileID, + line: Int = #line + ) throws -> (Column) { + try self.decode(Column.self, context: .default, file: file, line: line) + } + + @inlinable + public func decode( + _: Column.Type, + context: PostgresDecodingContext, + file: String = #fileID, + line: Int = #line + ) throws -> (Column) { + precondition(self.columns.count >= 1) + let columnIndex = 0 + var cellIterator = self.data.makeIterator() + var cellData = cellIterator.next().unsafelyUnwrapped + var columnIterator = self.columns.makeIterator() + let column = columnIterator.next().unsafelyUnwrapped + let swiftTargetType: Any.Type = Column.self + + do { + let r0 = try Column._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + return (r0) + } catch let code as PostgresDecodingError.Code { + throw PostgresDecodingError( + code: code, + columnName: column.name, + columnIndex: columnIndex, + targetType: swiftTargetType, + postgresType: column.dataType, + postgresFormat: column.format, + postgresData: cellData, + file: file, + line: line + ) + } + } + // --- snap TODO: Remove once bug is fixed, that disallows tuples of one + + @inlinable + public func decode( + _ columnType: (repeat each Column).Type, + context: PostgresDecodingContext, + file: String = #fileID, + line: Int = #line + ) throws -> (repeat each Column) { + let packCount = ComputeParameterPackLength.count(ofPack: repeat (each Column).self) + precondition(self.columns.count >= packCount) + + var columnIndex = 0 + var cellIterator = self.data.makeIterator() + var columnIterator = self.columns.makeIterator() + + return ( + repeat try Self.decodeNextColumn( + (each Column).self, + cellIterator: &cellIterator, + columnIterator: &columnIterator, + columnIndex: &columnIndex, + context: context, + file: file, + line: line + ) + ) + } + + @inlinable + static func decodeNextColumn( + _ columnType: Column.Type, + cellIterator: inout IndexingIterator, + columnIterator: inout IndexingIterator<[RowDescription.Column]>, + columnIndex: inout Int, + context: PostgresDecodingContext, + file: String, + line: Int + ) throws -> Column { + defer { columnIndex += 1 } + + let column = columnIterator.next().unsafelyUnwrapped + var cellData = cellIterator.next().unsafelyUnwrapped + do { + return try Column._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + } catch let code as PostgresDecodingError.Code { + throw PostgresDecodingError( + code: code, + columnName: column.name, + columnIndex: columnIndex, + targetType: Column.self, + postgresType: column.dataType, + postgresFormat: column.format, + postgresData: cellData, + file: file, + line: line + ) + } + } + + @inlinable + public func decode( + _ columnType: (repeat each Column).Type, + file: String = #fileID, + line: Int = #line + ) throws -> (repeat each Column) { + try self.decode(columnType, context: .default, file: file, line: line) + } +} + +extension AsyncSequence where Element == PostgresRow { + // --- snip TODO: Remove once bug is fixed, that disallows tuples of one + @inlinable + public func decode( + _: Column.Type, + context: PostgresDecodingContext, + file: String = #fileID, + line: Int = #line + ) -> AsyncThrowingMapSequence { + self.map { row in + try row.decode(Column.self, context: context, file: file, line: line) + } + } + + @inlinable + public func decode( + _: Column.Type, + file: String = #fileID, + line: Int = #line + ) -> AsyncThrowingMapSequence { + self.decode(Column.self, context: .default, file: file, line: line) + } + // --- snap TODO: Remove once bug is fixed, that disallows tuples of one + + public func decode( + _ columnType: (repeat each Column).Type, + context: PostgresDecodingContext, + file: String = #fileID, + line: Int = #line + ) -> AsyncThrowingMapSequence { + self.map { row in + try row.decode(columnType, context: context, file: file, line: line) + } + } + + public func decode( + _ columnType: (repeat each Column).Type, + file: String = #fileID, + line: Int = #line + ) -> AsyncThrowingMapSequence { + self.decode(columnType, context: .default, file: file, line: line) + } +} + +@usableFromInline +enum ComputeParameterPackLength { + @usableFromInline + enum BoolConverter { + @usableFromInline + typealias Bool = Swift.Bool + } + + @inlinable + static func count(ofPack t: repeat each T) -> Int { + MemoryLayout<(repeat BoolConverter.Bool)>.size / MemoryLayout.stride + } +} diff --git a/Sources/PostgresNIO/Pool/ConnectionFactory.swift b/Sources/PostgresNIO/Pool/ConnectionFactory.swift new file mode 100644 index 00000000..319b86c4 --- /dev/null +++ b/Sources/PostgresNIO/Pool/ConnectionFactory.swift @@ -0,0 +1,207 @@ +import Logging +import NIOConcurrencyHelpers +import NIOCore +import NIOSSL + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +final class ConnectionFactory: Sendable { + + struct ConfigCache: Sendable { + var config: PostgresClient.Configuration + } + + let configBox: NIOLockedValueBox + + struct SSLContextCache: Sendable { + enum State { + case none + case producing(TLSConfiguration, [CheckedContinuation]) + case cached(TLSConfiguration, NIOSSLContext) + case failed(TLSConfiguration, any Error) + } + + var state: State = .none + } + + let sslContextBox = NIOLockedValueBox(SSLContextCache()) + + let eventLoopGroup: any EventLoopGroup + + let logger: Logger + + init(config: PostgresClient.Configuration, eventLoopGroup: any EventLoopGroup, logger: Logger) { + self.eventLoopGroup = eventLoopGroup + self.configBox = NIOLockedValueBox(ConfigCache(config: config)) + self.logger = logger + } + + func makeConnection(_ connectionID: PostgresConnection.ID, pool: PostgresClient.Pool) async throws -> PostgresConnection { + let config = try await self.makeConnectionConfig() + + var connectionLogger = self.logger + connectionLogger[postgresMetadataKey: .connectionID] = "\(connectionID)" + + return try await PostgresConnection.connect( + on: self.eventLoopGroup.any(), + configuration: config, + id: connectionID, + logger: connectionLogger + ).get() + } + + func makeConnectionConfig() async throws -> PostgresConnection.Configuration { + let config = self.configBox.withLockedValue { $0.config } + + let tls: PostgresConnection.Configuration.TLS + switch config.tls.base { + case .prefer(let tlsConfiguration): + let sslContext = try await self.getSSLContext(for: tlsConfiguration) + tls = .prefer(sslContext) + + case .require(let tlsConfiguration): + let sslContext = try await self.getSSLContext(for: tlsConfiguration) + tls = .require(sslContext) + case .disable: + tls = .disable + } + + var connectionConfig: PostgresConnection.Configuration + switch config.endpointInfo { + case .bindUnixDomainSocket(let path): + connectionConfig = PostgresConnection.Configuration( + unixSocketPath: path, + username: config.username, + password: config.password, + database: config.database + ) + + case .connectTCP(let host, let port): + connectionConfig = PostgresConnection.Configuration( + host: host, + port: port, + username: config.username, + password: config.password, + database: config.database, + tls: tls + ) + } + + connectionConfig.options.connectTimeout = TimeAmount(config.options.connectTimeout) + connectionConfig.options.tlsServerName = config.options.tlsServerName + connectionConfig.options.requireBackendKeyData = config.options.requireBackendKeyData + connectionConfig.options.additionalStartupParameters = config.options.additionalStartupParameters + + return connectionConfig + } + + private func getSSLContext(for tlsConfiguration: TLSConfiguration) async throws -> NIOSSLContext { + enum Action { + case produce + case succeed(NIOSSLContext) + case fail(any Error) + case wait + } + + return try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in + let action = self.sslContextBox.withLockedValue { cache -> Action in + switch cache.state { + case .none: + cache.state = .producing(tlsConfiguration, [continuation]) + return .produce + + case .cached(let cachedTLSConfiguration, let context): + if cachedTLSConfiguration.bestEffortEquals(tlsConfiguration) { + return .succeed(context) + } else { + cache.state = .producing(tlsConfiguration, [continuation]) + return .produce + } + + case .failed(let cachedTLSConfiguration, let error): + if cachedTLSConfiguration.bestEffortEquals(tlsConfiguration) { + return .fail(error) + } else { + cache.state = .producing(tlsConfiguration, [continuation]) + return .produce + } + + case .producing(let cachedTLSConfiguration, var continuations): + continuations.append(continuation) + if cachedTLSConfiguration.bestEffortEquals(tlsConfiguration) { + cache.state = .producing(cachedTLSConfiguration, continuations) + return .wait + } else { + cache.state = .producing(tlsConfiguration, continuations) + return .produce + } + } + } + + switch action { + case .wait: + break + + case .produce: + // TBD: we might want to consider moving this off the concurrent executor + self.reportProduceSSLContextResult( + Result(catching: {try NIOSSLContext(configuration: tlsConfiguration)}), + for: tlsConfiguration + ) + + case .succeed(let context): + continuation.resume(returning: context) + + case .fail(let error): + continuation.resume(throwing: error) + } + } + } + + private func reportProduceSSLContextResult(_ result: Result, for tlsConfiguration: TLSConfiguration) { + enum Action { + case fail(any Error, [CheckedContinuation]) + case succeed(NIOSSLContext, [CheckedContinuation]) + case none + } + + let action = self.sslContextBox.withLockedValue { cache -> Action in + switch cache.state { + case .none: + preconditionFailure("Invalid state: \(cache.state)") + + case .cached, .failed: + return .none + + case .producing(let cachedTLSConfiguration, let continuations): + if cachedTLSConfiguration.bestEffortEquals(tlsConfiguration) { + switch result { + case .success(let context): + cache.state = .cached(cachedTLSConfiguration, context) + return .succeed(context, continuations) + + case .failure(let failure): + cache.state = .failed(cachedTLSConfiguration, failure) + return .fail(failure, continuations) + } + } else { + return .none + } + } + } + + switch action { + case .none: + break + + case .succeed(let context, let continuations): + for continuation in continuations { + continuation.resume(returning: context) + } + + case .fail(let error, let continuations): + for continuation in continuations { + continuation.resume(throwing: error) + } + } + } +} diff --git a/Sources/PostgresNIO/Pool/PostgresClient.swift b/Sources/PostgresNIO/Pool/PostgresClient.swift new file mode 100644 index 00000000..d54e34eb --- /dev/null +++ b/Sources/PostgresNIO/Pool/PostgresClient.swift @@ -0,0 +1,581 @@ +import NIOCore +import NIOSSL +import Atomics +import Logging +import ServiceLifecycle +import _ConnectionPoolModule + +/// A Postgres client that is backed by an underlying connection pool. Use ``Configuration`` to change the client's +/// behavior. +/// +/// ## Creating a client +/// +/// You create a ``PostgresClient`` by first creating a ``PostgresClient/Configuration`` struct that you can +/// use to modify the client's behavior. +/// +/// @Snippet(path: "postgres-nio/Snippets/PostgresClient", slice: "configuration") +/// +/// Now you can create a client with your configuration object: +/// +/// @Snippet(path: "postgres-nio/Snippets/PostgresClient", slice: "makeClient") +/// +/// ## Running a client +/// +/// ``PostgresClient`` relies on structured concurrency. Because of this it needs a task in which it can schedule all the +/// background work that it needs to do in order to manage connections on the users behave. For this reason, developers +/// must provide a task to the client by scheduling the client's run method in a long running task: +/// +/// @Snippet(path: "postgres-nio/Snippets/PostgresClient", slice: "run") +/// +/// ``PostgresClient`` can not lease connections, if its ``run()`` method isn't active. Cancelling the ``run()`` method +/// is equivalent to closing the client. Once a client's ``run()`` method has been cancelled, executing queries or prepared +/// statements will fail. +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +public final class PostgresClient: Sendable, ServiceLifecycle.Service { + public struct Configuration: Sendable { + public struct TLS: Sendable { + enum Base { + case disable + case prefer(NIOSSL.TLSConfiguration) + case require(NIOSSL.TLSConfiguration) + } + + var base: Base + + private init(_ base: Base) { + self.base = base + } + + /// Do not try to create a TLS connection to the server. + public static let disable: Self = Self.init(.disable) + + /// Try to create a TLS connection to the server. If the server supports TLS, create a TLS connection. + /// If the server does not support TLS, create an insecure connection. + public static func prefer(_ sslContext: NIOSSL.TLSConfiguration) -> Self { + self.init(.prefer(sslContext)) + } + + /// Try to create a TLS connection to the server. If the server supports TLS, create a TLS connection. + /// If the server does not support TLS, fail the connection creation. + public static func require(_ sslContext: NIOSSL.TLSConfiguration) -> Self { + self.init(.require(sslContext)) + } + } + + // MARK: Client options + + /// Describes general client behavior options. Those settings are considered advanced options. + public struct Options: Sendable { + /// A keep-alive behavior for Postgres connections. The ``frequency`` defines after which time an idle + /// connection shall run a keep-alive ``query``. + public struct KeepAliveBehavior: Sendable { + /// The amount of time that shall pass before an idle connection runs a keep-alive ``query``. + public var frequency: Duration + + /// The ``query`` that is run on an idle connection after it has been idle for ``frequency``. + public var query: PostgresQuery + + /// Create a new `KeepAliveBehavior`. + /// - Parameters: + /// - frequency: The amount of time that shall pass before an idle connection runs a keep-alive `query`. + /// Defaults to `30` seconds. + /// - query: The `query` that is run on an idle connection after it has been idle for `frequency`. + /// Defaults to `SELECT 1;`. + public init(frequency: Duration = .seconds(30), query: PostgresQuery = "SELECT 1;") { + self.frequency = frequency + self.query = query + } + } + + /// A timeout for creating a TCP/Unix domain socket connection. Defaults to `10` seconds. + public var connectTimeout: Duration = .seconds(10) + + /// The server name to use for certificate validation and SNI (Server Name Indication) when TLS is enabled. + /// Defaults to none (but see below). + /// + /// > When set to `nil`: + /// If the connection is made to a server over TCP using + /// ``PostgresConnection/Configuration/init(host:port:username:password:database:tls:)``, the given `host` + /// is used, unless it was an IP address string. If it _was_ an IP, or the connection is made by any other + /// method, SNI is disabled. + public var tlsServerName: String? = nil + + /// Whether the connection is required to provide backend key data (internal Postgres stuff). + /// + /// This property is provided for compatibility with Amazon RDS Proxy, which requires it to be `false`. + /// If you are not using Amazon RDS Proxy, you should leave this set to `true` (the default). + public var requireBackendKeyData: Bool = true + + /// Additional parameters to send to the server on startup. The name value pairs are added to the initial + /// startup message that the client sends to the server. + public var additionalStartupParameters: [(String, String)] = [] + + /// The minimum number of connections that the client shall keep open at any time, even if there is no + /// demand. Default to `0`. + /// + /// If the open connection count becomes less than ``minimumConnections`` new connections + /// are created immidiatly. Must be greater or equal to zero and less than ``maximumConnections``. + /// + /// Idle connections are kept alive using the ``keepAliveBehavior``. + public var minimumConnections: Int = 0 + + /// The maximum number of connections that the client may open to the server at any time. Must be greater + /// than ``minimumConnections``. Defaults to `20` connections. + /// + /// Connections, that are created in response to demand are kept alive for the ``connectionIdleTimeout`` + /// before they are dropped. + public var maximumConnections: Int = 20 + + /// The maximum amount time that a connection that is not part of the ``minimumConnections`` is kept + /// open without being leased. Defaults to `60` seconds. + public var connectionIdleTimeout: Duration = .seconds(60) + + /// The ``KeepAliveBehavior-swift.struct`` to ensure that the underlying tcp-connection is still active + /// for idle connections. `Nil` means that the client shall not run keep alive queries to the server. Defaults to a + /// keep alive query of `SELECT 1;` every `30` seconds. + public var keepAliveBehavior: KeepAliveBehavior? = KeepAliveBehavior() + + /// Create an options structure with default values. + /// + /// Most users should not need to adjust the defaults. + public init() {} + } + + // MARK: - Accessors + + /// The hostname to connect to for TCP configurations. + /// + /// Always `nil` for other configurations. + public var host: String? { + if case let .connectTCP(host, _) = self.endpointInfo { return host } + else { return nil } + } + + /// The port to connect to for TCP configurations. + /// + /// Always `nil` for other configurations. + public var port: Int? { + if case let .connectTCP(_, port) = self.endpointInfo { return port } + else { return nil } + } + + /// The socket path to connect to for Unix domain socket connections. + /// + /// Always `nil` for other configurations. + public var unixSocketPath: String? { + if case let .bindUnixDomainSocket(path) = self.endpointInfo { return path } + else { return nil } + } + + /// The TLS mode to use for the connection. Valid for all configurations. + /// + /// See ``TLS-swift.struct``. + public var tls: TLS = .prefer(.makeClientConfiguration()) + + /// Options for handling the communication channel. Most users don't need to change these. + /// + /// See ``Options-swift.struct``. + public var options: Options = .init() + + /// The username to connect with. + public var username: String + + /// The password, if any, for the user specified by ``username``. + /// + /// - Warning: `nil` means "no password provided", whereas `""` (the empty string) is a password of zero + /// length; these are not the same thing. + public var password: String? + + /// The name of the database to open. + /// + /// - Note: If set to `nil` or an empty string, the provided ``username`` is used. + public var database: String? + + // MARK: - Initializers + + /// Create a configuration for connecting to a server with a hostname and optional port. + /// + /// This specifies a TCP connection. If you're unsure which kind of connection you want, you almost + /// definitely want this one. + /// + /// - Parameters: + /// - host: The hostname to connect to. + /// - port: The TCP port to connect to (defaults to 5432). + /// - tls: The TLS mode to use. + public init(host: String, port: Int = 5432, username: String, password: String?, database: String?, tls: TLS) { + self.init(endpointInfo: .connectTCP(host: host, port: port), tls: tls, username: username, password: password, database: database) + } + + /// Create a configuration for connecting to a server through a UNIX domain socket. + /// + /// - Parameters: + /// - path: The filesystem path of the socket to connect to. + /// - tls: The TLS mode to use. Defaults to ``TLS-swift.struct/disable``. + public init(unixSocketPath: String, username: String, password: String?, database: String?) { + self.init(endpointInfo: .bindUnixDomainSocket(path: unixSocketPath), tls: .disable, username: username, password: password, database: database) + } + + // MARK: - Implementation details + + enum EndpointInfo { + case bindUnixDomainSocket(path: String) + case connectTCP(host: String, port: Int) + } + + var endpointInfo: EndpointInfo + + init(endpointInfo: EndpointInfo, tls: TLS, username: String, password: String?, database: String?) { + self.endpointInfo = endpointInfo + self.tls = tls + self.username = username + self.password = password + self.database = database + } + } + + typealias Pool = ConnectionPool< + PostgresConnection, + PostgresConnection.ID, + ConnectionIDGenerator, + ConnectionRequest, + ConnectionRequest.ID, + PostgresKeepAliveBehavor, + PostgresClientMetrics, + ContinuousClock + > + + let pool: Pool + let factory: ConnectionFactory + let runningAtomic = ManagedAtomic(false) + let backgroundLogger: Logger + + /// Creates a new ``PostgresClient``, that does not log any background information. + /// + /// > Warning: + /// The client can only lease connections if the user is running the client's ``run()`` method in a long running task. + /// + /// - Parameters: + /// - configuration: The client's configuration. See ``Configuration`` for details. + /// - eventLoopGroup: The underlying NIO `EventLoopGroup`. Defaults to ``defaultEventLoopGroup``. + public convenience init( + configuration: Configuration, + eventLoopGroup: any EventLoopGroup = PostgresClient.defaultEventLoopGroup + ) { + self.init(configuration: configuration, eventLoopGroup: eventLoopGroup, backgroundLogger: Self.loggingDisabled) + } + + /// Creates a new ``PostgresClient``. Don't forget to run ``run()`` the client in a long running task. + /// + /// - Parameters: + /// - configuration: The client's configuration. See ``Configuration`` for details. + /// - eventLoopGroup: The underlying NIO `EventLoopGroup`. Defaults to ``defaultEventLoopGroup``. + /// - backgroundLogger: A `swift-log` `Logger` to log background messages to. A copy of this logger is also + /// forwarded to the created connections as a background logger. + public init( + configuration: Configuration, + eventLoopGroup: any EventLoopGroup = PostgresClient.defaultEventLoopGroup, + backgroundLogger: Logger + ) { + let factory = ConnectionFactory(config: configuration, eventLoopGroup: eventLoopGroup, logger: backgroundLogger) + self.factory = factory + self.backgroundLogger = backgroundLogger + + self.pool = ConnectionPool( + configuration: .init(configuration), + idGenerator: ConnectionIDGenerator(), + requestType: ConnectionRequest.self, + keepAliveBehavior: .init(configuration.options.keepAliveBehavior, logger: backgroundLogger), + observabilityDelegate: .init(logger: backgroundLogger), + clock: ContinuousClock() + ) { (connectionID, pool) in + let connection = try await factory.makeConnection(connectionID, pool: pool) + + return ConnectionAndMetadata(connection: connection, maximalStreamsOnConnection: 1) + } + } + + /// Lease a connection for the provided `closure`'s lifetime. + /// + /// - Parameter closure: A closure that uses the passed `PostgresConnection`. The closure **must not** capture + /// the provided `PostgresConnection`. + /// - Returns: The closure's return value. + @_disfavoredOverload + public func withConnection(_ closure: (PostgresConnection) async throws -> Result) async throws -> Result { + let connection = try await self.leaseConnection() + + defer { self.pool.releaseConnection(connection) } + + return try await closure(connection) + } + + #if compiler(>=6.0) + /// Lease a connection for the provided `closure`'s lifetime. + /// + /// - Parameter closure: A closure that uses the passed `PostgresConnection`. The closure **must not** capture + /// the provided `PostgresConnection`. + /// - Returns: The closure's return value. + public func withConnection( + isolation: isolated (any Actor)? = #isolation, + // DO NOT FIX THE WHITESPACE IN THE NEXT LINE UNTIL 5.10 IS UNSUPPORTED + // https://github.com/swiftlang/swift/issues/79285 + _ closure: (PostgresConnection) async throws -> sending Result) async throws -> sending Result { + let connection = try await self.leaseConnection() + + defer { self.pool.releaseConnection(connection) } + + return try await closure(connection) + } + + /// Lease a connection, which is in an open transaction state, for the provided `closure`'s lifetime. + /// + /// The function leases a connection from the underlying connection pool and starts a transaction by running a `BEGIN` + /// query on the leased connection against the database. It then lends the connection to the user provided closure. + /// The user can then modify the database as they wish. If the user provided closure returns successfully, the function + /// will attempt to commit the changes by running a `COMMIT` query against the database. If the user provided closure + /// throws an error, the function will attempt to rollback the changes made within the closure. + /// + /// - Parameters: + /// - logger: The `Logger` to log into for the transaction. + /// - file: The file, the transaction was started in. Used for better error reporting. + /// - line: The line, the transaction was started in. Used for better error reporting. + /// - closure: The user provided code to modify the database. Use the provided connection to run queries. + /// The connection must stay in the transaction mode. Otherwise this method will throw! + /// - Returns: The closure's return value. + public func withTransaction( + logger: Logger, + file: String = #file, + line: Int = #line, + isolation: isolated (any Actor)? = #isolation, + // DO NOT FIX THE WHITESPACE IN THE NEXT LINE UNTIL 5.10 IS UNSUPPORTED + // https://github.com/swiftlang/swift/issues/79285 + _ closure: (PostgresConnection) async throws -> sending Result) async throws -> sending Result { + try await self.withConnection { connection in + try await connection.withTransaction(logger: logger, file: file, line: line, closure) + } + } + #else + + /// Lease a connection, which is in an open transaction state, for the provided `closure`'s lifetime. + /// + /// The function leases a connection from the underlying connection pool and starts a transaction by running a `BEGIN` + /// query on the leased connection against the database. It then lends the connection to the user provided closure. + /// The user can then modify the database as they wish. If the user provided closure returns successfully, the function + /// will attempt to commit the changes by running a `COMMIT` query against the database. If the user provided closure + /// throws an error, the function will attempt to rollback the changes made within the closure. + /// + /// - Parameters: + /// - logger: The `Logger` to log into for the transaction. + /// - file: The file, the transaction was started in. Used for better error reporting. + /// - line: The line, the transaction was started in. Used for better error reporting. + /// - closure: The user provided code to modify the database. Use the provided connection to run queries. + /// The connection must stay in the transaction mode. Otherwise this method will throw! + /// - Returns: The closure's return value. + public func withTransaction( + logger: Logger, + file: String = #file, + line: Int = #line, + _ closure: (PostgresConnection) async throws -> Result + ) async throws -> Result { + try await self.withConnection { connection in + try await connection.withTransaction(logger: logger, file: file, line: line, closure) + } + } + #endif + + /// Run a query on the Postgres server the client is connected to. + /// + /// - Parameters: + /// - query: The ``PostgresQuery`` to run + /// - logger: The `Logger` to log into for the query + /// - file: The file, the query was started in. Used for better error reporting. + /// - line: The line, the query was started in. Used for better error reporting. + /// - Returns: A ``PostgresRowSequence`` containing the rows the server sent as the query result. + /// The sequence be discarded. + @discardableResult + public func query( + _ query: PostgresQuery, + logger: Logger? = nil, + file: String = #fileID, + line: Int = #line + ) async throws -> PostgresRowSequence { + let logger = logger ?? Self.loggingDisabled + do { + guard query.binds.count <= Int(UInt16.max) else { + throw PSQLError(code: .tooManyParameters, query: query, file: file, line: line) + } + + let connection = try await self.leaseConnection() + + var logger = logger + logger[postgresMetadataKey: .connectionID] = "\(connection.id)" + + let promise = connection.channel.eventLoop.makePromise(of: PSQLRowStream.self) + let context = ExtendedQueryContext( + query: query, + logger: logger, + promise: promise + ) + + connection.channel.write(HandlerTask.extendedQuery(context), promise: nil) + + promise.futureResult.whenFailure { _ in + self.pool.releaseConnection(connection) + } + + return try await promise.futureResult.map { + $0.asyncSequence(onFinish: { + self.pool.releaseConnection(connection) + }) + }.get() + } catch var error as PSQLError { + error.file = file + error.line = line + error.query = query + throw error // rethrow with more metadata + } + } + + /// Execute a prepared statement, taking care of the preparation when necessary + public func execute( + _ preparedStatement: Statement, + logger: Logger? = nil, + file: String = #fileID, + line: Int = #line + ) async throws -> AsyncThrowingMapSequence where Row == Statement.Row { + let bindings = try preparedStatement.makeBindings() + let logger = logger ?? Self.loggingDisabled + + do { + let connection = try await self.leaseConnection() + + let promise = connection.channel.eventLoop.makePromise(of: PSQLRowStream.self) + let task = HandlerTask.executePreparedStatement(.init( + name: String(reflecting: Statement.self), + sql: Statement.sql, + bindings: bindings, + bindingDataTypes: Statement.bindingDataTypes, + logger: logger, + promise: promise + )) + connection.channel.write(task, promise: nil) + + promise.futureResult.whenFailure { _ in + self.pool.releaseConnection(connection) + } + + return try await promise.futureResult + .map { $0.asyncSequence(onFinish: { self.pool.releaseConnection(connection) }) } + .get() + .map { try preparedStatement.decodeRow($0) } + } catch var error as PSQLError { + error.file = file + error.line = line + error.query = .init( + unsafeSQL: Statement.sql, + binds: bindings + ) + throw error // rethrow with more metadata + } + } + + /// The structured root task for the client's background work. + /// + /// > Warning: + /// Users must call this function in order to allow the client to process any background work. Executing queries, + /// prepared statements or leasing connections will hang until the developer executes the client's ``run()`` + /// method. + /// + /// Cancelling the task which executes the ``run()`` method, is equivalent to closing the client. Once the task + /// has been cancelled the client is not able to process any new queries or prepared statements. + /// + /// @Snippet(path: "postgres-nio/Snippets/PostgresClient", slice: "run") + /// + /// > Note: + /// ``PostgresClient`` implements [ServiceLifecycle](https://github.com/swift-server/swift-service-lifecycle)'s `Service` protocol. Because of this + /// ``PostgresClient`` can be passed to a `ServiceGroup` for easier lifecycle management. + public func run() async { + let atomicOp = self.runningAtomic.compareExchange(expected: false, desired: true, ordering: .relaxed) + precondition(!atomicOp.original, "PostgresClient.run() should just be called once!") + + await cancelWhenGracefulShutdown { + await self.pool.run() + } + } + + // MARK: - Private Methods - + + private func leaseConnection() async throws -> PostgresConnection { + if !self.runningAtomic.load(ordering: .relaxed) { + self.backgroundLogger.warning("Trying to lease connection from `PostgresClient`, but `PostgresClient.run()` hasn't been called yet.") + } + return try await self.pool.leaseConnection() + } + + /// Returns the default `EventLoopGroup` singleton, automatically selecting the best for the platform. + /// + /// This will select the concrete `EventLoopGroup` depending which platform this is running on. + public static var defaultEventLoopGroup: EventLoopGroup { + PostgresConnection.defaultEventLoopGroup + } + + static let loggingDisabled = Logger(label: "Postgres-do-not-log", factory: { _ in SwiftLogNoOpLogHandler() }) +} + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +struct PostgresKeepAliveBehavor: ConnectionKeepAliveBehavior { + let behavior: PostgresClient.Configuration.Options.KeepAliveBehavior? + let logger: Logger + + init(_ behavior: PostgresClient.Configuration.Options.KeepAliveBehavior?, logger: Logger) { + self.behavior = behavior + self.logger = logger + } + + var keepAliveFrequency: Duration? { + self.behavior?.frequency + } + + func runKeepAlive(for connection: PostgresConnection) async throws { + try await connection.query(self.behavior!.query, logger: self.logger).map { _ in }.get() + } +} + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +extension ConnectionPoolConfiguration { + init(_ config: PostgresClient.Configuration) { + self = ConnectionPoolConfiguration() + self.minimumConnectionCount = config.options.minimumConnections + self.maximumConnectionSoftLimit = config.options.maximumConnections + self.maximumConnectionHardLimit = config.options.maximumConnections + self.idleTimeout = config.options.connectionIdleTimeout + } +} + +extension PostgresConnection: PooledConnection { + public func close() { + self.channel.close(mode: .all, promise: nil) + } + + public func onClose(_ closure: @escaping @Sendable ((any Error)?) -> ()) { + self.closeFuture.whenComplete { _ in closure(nil) } + } +} + +extension ConnectionPoolError { + func mapToPSQLError(lastConnectError: Error?) -> Error { + var psqlError: PSQLError + switch self { + case .poolShutdown: + psqlError = PSQLError.poolClosed + psqlError.underlying = self + + case .requestCancelled: + psqlError = PSQLError.queryCancelled + psqlError.underlying = self + + default: + return self + } + return psqlError + } +} diff --git a/Sources/PostgresNIO/Pool/PostgresClientMetrics.swift b/Sources/PostgresNIO/Pool/PostgresClientMetrics.swift new file mode 100644 index 00000000..aa8215db --- /dev/null +++ b/Sources/PostgresNIO/Pool/PostgresClientMetrics.swift @@ -0,0 +1,85 @@ +import _ConnectionPoolModule +import Logging + +final class PostgresClientMetrics: ConnectionPoolObservabilityDelegate { + typealias ConnectionID = PostgresConnection.ID + + let logger: Logger + + init(logger: Logger) { + self.logger = logger + } + + func startedConnecting(id: ConnectionID) { + self.logger.debug("Creating new connection", metadata: [ + .connectionID: "\(id)", + ]) + } + + /// A connection attempt failed with the given error. After some period of + /// time ``startedConnecting(id:)`` may be called again. + func connectFailed(id: ConnectionID, error: Error) { + self.logger.debug("Connection creation failed", metadata: [ + .connectionID: "\(id)", + .error: "\(String(reflecting: error))" + ]) + } + + func connectSucceeded(id: ConnectionID) { + self.logger.debug("Connection established", metadata: [ + .connectionID: "\(id)" + ]) + } + + /// The utlization of the connection changed; a stream may have been used, returned or the + /// maximum number of concurrent streams available on the connection changed. + func connectionLeased(id: ConnectionID) { + self.logger.debug("Connection leased", metadata: [ + .connectionID: "\(id)" + ]) + } + + func connectionReleased(id: ConnectionID) { + self.logger.debug("Connection released", metadata: [ + .connectionID: "\(id)" + ]) + } + + func keepAliveTriggered(id: ConnectionID) { + self.logger.debug("run ping pong", metadata: [ + .connectionID: "\(id)", + ]) + } + + func keepAliveSucceeded(id: ConnectionID) {} + + func keepAliveFailed(id: PostgresConnection.ID, error: Error) {} + + /// The remote peer is quiescing the connection: no new streams will be created on it. The + /// connection will eventually be closed and removed from the pool. + func connectionClosing(id: ConnectionID) { + self.logger.debug("Close connection", metadata: [ + .connectionID: "\(id)" + ]) + } + + /// The connection was closed. The connection may be established again in the future (notified + /// via ``startedConnecting(id:)``). + func connectionClosed(id: ConnectionID, error: Error?) { + self.logger.debug("Connection closed", metadata: [ + .connectionID: "\(id)" + ]) + } + + func requestQueueDepthChanged(_ newDepth: Int) { + + } + + func connectSucceeded(id: PostgresConnection.ID, streamCapacity: UInt16) { + + } + + func connectionUtilizationChanged(id: PostgresConnection.ID, streamsUsed: UInt16, streamCapacity: UInt16) { + + } +} diff --git a/Sources/PostgresNIO/Postgres+PSQLCompat.swift b/Sources/PostgresNIO/Postgres+PSQLCompat.swift new file mode 100644 index 00000000..7d464c2b --- /dev/null +++ b/Sources/PostgresNIO/Postgres+PSQLCompat.swift @@ -0,0 +1,74 @@ +import NIOCore + +extension PSQLError { + func toPostgresError() -> Error { + switch self.code.base { + case .queryCancelled: + return self + case .server, .listenFailed: + guard let serverInfo = self.serverInfo else { + return self + } + + var fields = [PostgresMessage.Error.Field: String]() + fields.reserveCapacity(serverInfo.underlying.fields.count) + serverInfo.underlying.fields.forEach { (key, value) in + fields[PostgresMessage.Error.Field(rawValue: key.rawValue)!] = value + } + return PostgresError.server(PostgresMessage.Error(fields: fields)) + case .sslUnsupported: + return PostgresError.protocol("Server does not support TLS") + case .failedToAddSSLHandler: + return self.underlying ?? self + case .messageDecodingFailure: + let message = self.underlying != nil ? String(describing: self.underlying!) : "no message" + return PostgresError.protocol("Error decoding message: \(message)") + case .unexpectedBackendMessage: + let message = self.backendMessage != nil ? String(describing: self.backendMessage!) : "no message" + return PostgresError.protocol("Unexpected message: \(message)") + case .unsupportedAuthMechanism: + let message = self.unsupportedAuthScheme != nil ? String(describing: self.unsupportedAuthScheme!) : "no scheme" + return PostgresError.protocol("Unsupported auth scheme: \(message)") + case .authMechanismRequiresPassword: + return PostgresError.protocol("Unable to authenticate without password") + case .receivedUnencryptedDataAfterSSLRequest: + return PostgresError.protocol("Received unencrypted data after SSL request") + case .saslError: + return self.underlying ?? self + case .tooManyParameters, .invalidCommandTag: + return self + case .clientClosedConnection, + .serverClosedConnection: + return PostgresError.connectionClosed + case .connectionError: + return self.underlying ?? self + case .unlistenFailed: + return self.underlying ?? self + case .uncleanShutdown: + return PostgresError.protocol("Unexpected connection close") + case .poolClosed: + return self + } + } +} + +extension PostgresFormat { + init(psqlFormatCode: PostgresFormat) { + switch psqlFormatCode { + case .binary: + self = .binary + case .text: + self = .text + } + } +} + +extension Error { + internal var asAppropriatePostgresError: Error { + if let psqlError = self as? PSQLError { + return psqlError.toPostgresError() + } else { + return self + } + } +} diff --git a/Sources/PostgresNIO/PostgresDatabase+Query.swift b/Sources/PostgresNIO/PostgresDatabase+Query.swift index a03b6339..8de93814 100644 --- a/Sources/PostgresNIO/PostgresDatabase+Query.swift +++ b/Sources/PostgresNIO/PostgresDatabase+Query.swift @@ -1,39 +1,46 @@ -import NIO +import NIOCore import Logging +import NIOConcurrencyHelpers extension PostgresDatabase { public func query( _ string: String, _ binds: [PostgresData] = [] ) -> EventLoopFuture { - var rows: [PostgresRow] = [] - var metadata: PostgresQueryMetadata? - return self.query(string, binds, onMetadata: { - metadata = $0 - }) { - rows.append($0) + let box = NIOLockedValueBox((metadata: PostgresQueryMetadata?.none, rows: [PostgresRow]())) + + return self.query(string, binds, onMetadata: { metadata in + box.withLockedValue { + $0.metadata = metadata + } + }) { row in + box.withLockedValue { + $0.rows.append(row) + } }.map { - .init(metadata: metadata!, rows: rows) + box.withLockedValue { + PostgresQueryResult(metadata: $0.metadata!, rows: $0.rows) + } } } + @preconcurrency public func query( _ string: String, _ binds: [PostgresData] = [], - onMetadata: @escaping (PostgresQueryMetadata) -> () = { _ in }, - onRow: @escaping (PostgresRow) throws -> () + onMetadata: @Sendable @escaping (PostgresQueryMetadata) -> () = { _ in }, + onRow: @Sendable @escaping (PostgresRow) throws -> () ) -> EventLoopFuture { - let query = PostgresParameterizedQuery( - query: string, - binds: binds, - onMetadata: onMetadata, - onRow: onRow - ) - return self.send(query, logger: self.logger) + var bindings = PostgresBindings(capacity: binds.count) + binds.forEach { bindings.append($0) } + let query = PostgresQuery(unsafeSQL: string, binds: bindings) + let request = PostgresCommands.query(query, onMetadata: onMetadata, onRow: onRow) + + return self.send(request, logger: logger) } } -public struct PostgresQueryResult { +public struct PostgresQueryResult: Sendable { public let metadata: PostgresQueryMetadata public let rows: [PostgresRow] } @@ -59,17 +66,14 @@ extension PostgresQueryResult: Collection { } } -public struct PostgresQueryMetadata { +public struct PostgresQueryMetadata: Sendable { public let command: String public var oid: Int? public var rows: Int? init?(string: String) { let parts = string.split(separator: " ") - guard parts.count >= 1 else { - return nil - } - switch parts[0] { + switch parts.first { case "INSERT": // INSERT oid rows guard parts.count == 3 else { @@ -78,7 +82,13 @@ public struct PostgresQueryMetadata { self.command = .init(parts[0]) self.oid = Int(parts[1]) self.rows = Int(parts[2]) - case "DELETE", "UPDATE", "SELECT", "MOVE", "FETCH", "COPY": + case "SELECT" where parts.count == 1: + // AWS Redshift does not return the actual row count as defined in the postgres wire spec for SELECT: + // https://www.postgresql.org/docs/13/protocol-message-formats.html in section `CommandComplete` + self.command = "SELECT" + self.oid = nil + self.rows = nil + case "SELECT", "DELETE", "UPDATE", "MOVE", "FETCH", "COPY": // rows guard parts.count == 2 else { return nil @@ -94,118 +104,3 @@ public struct PostgresQueryMetadata { } } } - -// MARK: Private - -private final class PostgresParameterizedQuery: PostgresRequest { - let query: String - let binds: [PostgresData] - var onMetadata: (PostgresQueryMetadata) -> () - var onRow: (PostgresRow) throws -> () - var rowLookupTable: PostgresRow.LookupTable? - var resultFormatCodes: [PostgresFormatCode] - var logger: Logger? - - init( - query: String, - binds: [PostgresData], - onMetadata: @escaping (PostgresQueryMetadata) -> (), - onRow: @escaping (PostgresRow) throws -> () - ) { - self.query = query - self.binds = binds - self.onMetadata = onMetadata - self.onRow = onRow - self.resultFormatCodes = [.binary] - } - - func log(to logger: Logger) { - self.logger = logger - logger.debug("\(self.query) \(self.binds)") - } - - func respond(to message: PostgresMessage) throws -> [PostgresMessage]? { - if case .error = message.identifier { - // we should continue after errors - return [] - } - switch message.identifier { - case .bindComplete: - return [] - case .dataRow: - let data = try PostgresMessage.DataRow(message: message) - guard let rowLookupTable = self.rowLookupTable else { fatalError() } - let row = PostgresRow(dataRow: data, lookupTable: rowLookupTable) - try onRow(row) - return [] - case .rowDescription: - let row = try PostgresMessage.RowDescription(message: message) - self.rowLookupTable = PostgresRow.LookupTable( - rowDescription: row, - resultFormat: self.resultFormatCodes - ) - return [] - case .noData: - return [] - case .parseComplete: - return [] - case .parameterDescription: - let params = try PostgresMessage.ParameterDescription(message: message) - if params.dataTypes.count != self.binds.count { - self.logger!.warning("Expected parameters count (\(params.dataTypes.count)) does not equal binds count (\(binds.count))") - } else { - for (i, item) in zip(params.dataTypes, self.binds).enumerated() { - if item.0 != item.1.type { - self.logger!.warning("bind $\(i + 1) type (\(item.1.type)) does not match expected parameter type (\(item.0))") - } - } - } - return [] - case .commandComplete: - let complete = try PostgresMessage.CommandComplete(message: message) - guard let metadata = PostgresQueryMetadata(string: complete.tag) else { - throw PostgresError.protocol("Unexpected query metadata: \(complete.tag)") - } - self.onMetadata(metadata) - return [] - case .notice: - return [] - case .notificationResponse: - return [] - case .readyForQuery: - return nil - case .parameterStatus: - return [] - default: throw PostgresError.protocol("Unexpected message during query: \(message)") - } - } - - func start() throws -> [PostgresMessage] { - guard self.binds.count <= Int16.max else { - throw PostgresError.protocol("Bind count must be <= \(Int16.max).") - } - let parse = PostgresMessage.Parse( - statementName: "", - query: self.query, - parameterTypes: self.binds.map { $0.type } - ) - let describe = PostgresMessage.Describe( - command: .statement, - name: "" - ) - let bind = PostgresMessage.Bind( - portalName: "", - statementName: "", - parameterFormatCodes: self.binds.map { $0.formatCode }, - parameters: self.binds.map { .init(value: $0.value) }, - resultFormatCodes: self.resultFormatCodes - ) - let execute = PostgresMessage.Execute( - portalName: "", - maxRows: 0 - ) - - let sync = PostgresMessage.Sync() - return try [parse.message(), describe.message(), bind.message(), execute.message(), sync.message()] - } -} diff --git a/Sources/PostgresNIO/PostgresDatabase+SimpleQuery.swift b/Sources/PostgresNIO/PostgresDatabase+SimpleQuery.swift index 756c163c..5cf2d7a4 100644 --- a/Sources/PostgresNIO/PostgresDatabase+SimpleQuery.swift +++ b/Sources/PostgresNIO/PostgresDatabase+SimpleQuery.swift @@ -1,71 +1,19 @@ -import NIO +import NIOCore +import NIOConcurrencyHelpers import Logging extension PostgresDatabase { public func simpleQuery(_ string: String) -> EventLoopFuture<[PostgresRow]> { - var rows: [PostgresRow] = [] - return simpleQuery(string) { rows.append($0) }.map { rows } + let rowsBoxed = NIOLockedValueBox([PostgresRow]()) + return self.simpleQuery(string) { row in + rowsBoxed.withLockedValue { + $0.append(row) + } + }.map { rowsBoxed.withLockedValue { $0 } } } - public func simpleQuery(_ string: String, _ onRow: @escaping (PostgresRow) throws -> ()) -> EventLoopFuture { - let query = PostgresSimpleQuery(query: string, onRow: onRow) - return self.send(query, logger: self.logger) - } -} - -// MARK: Private - -private final class PostgresSimpleQuery: PostgresRequest { - var query: String - var onRow: (PostgresRow) throws -> () - var rowLookupTable: PostgresRow.LookupTable? - - init(query: String, onRow: @escaping (PostgresRow) throws -> ()) { - self.query = query - self.onRow = onRow - } - - func log(to logger: Logger) { - logger.debug("\(self.query)") - } - - func respond(to message: PostgresMessage) throws -> [PostgresMessage]? { - if case .error = message.identifier { - // we should continue after errors - return [] - } - switch message.identifier { - case .dataRow: - let data = try PostgresMessage.DataRow(message: message) - guard let rowLookupTable = self.rowLookupTable else { fatalError() } - let row = PostgresRow(dataRow: data, lookupTable: rowLookupTable) - try onRow(row) - return [] - case .rowDescription: - let row = try PostgresMessage.RowDescription(message: message) - self.rowLookupTable = PostgresRow.LookupTable( - rowDescription: row, - resultFormat: [] - ) - return [] - case .commandComplete: - return [] - case .readyForQuery: - return nil - case .notice: - return [] - case .notificationResponse: - return [] - case .parameterStatus: - return [] - default: - throw PostgresError.protocol("Unexpected message during simple query: \(message)") - } - } - - func start() throws -> [PostgresMessage] { - return try [ - PostgresMessage.SimpleQuery(string: self.query).message() - ] + @preconcurrency + public func simpleQuery(_ string: String, _ onRow: @Sendable @escaping (PostgresRow) throws -> ()) -> EventLoopFuture { + self.query(string, onRow: onRow) } } diff --git a/Sources/PostgresNIO/PostgresDatabase.swift b/Sources/PostgresNIO/PostgresDatabase.swift index 3f6f826f..fcd1afc7 100644 --- a/Sources/PostgresNIO/PostgresDatabase.swift +++ b/Sources/PostgresNIO/PostgresDatabase.swift @@ -1,11 +1,15 @@ -public protocol PostgresDatabase { +import NIOCore +import Logging + +@preconcurrency +public protocol PostgresDatabase: Sendable { var logger: Logger { get } var eventLoop: EventLoop { get } func send( _ request: PostgresRequest, logger: Logger ) -> EventLoopFuture - + func withConnection(_ closure: @escaping (PostgresConnection) -> EventLoopFuture) -> EventLoopFuture } diff --git a/Sources/PostgresNIO/PostgresRequest.swift b/Sources/PostgresNIO/PostgresRequest.swift index 71ec1cbb..71bb6cd1 100644 --- a/Sources/PostgresNIO/PostgresRequest.swift +++ b/Sources/PostgresNIO/PostgresRequest.swift @@ -1,5 +1,8 @@ import Logging +/// Protocol to encapsulate a function call on the Postgres server +/// +/// This protocol is deprecated going forward. public protocol PostgresRequest { // return nil to end request func respond(to message: PostgresMessage) throws -> [PostgresMessage]? diff --git a/Sources/PostgresNIO/Utilities/Exports.swift b/Sources/PostgresNIO/Utilities/Exports.swift index 9c388b65..144ff3c9 100644 --- a/Sources/PostgresNIO/Utilities/Exports.swift +++ b/Sources/PostgresNIO/Utilities/Exports.swift @@ -1,3 +1,3 @@ -@_exported import NIO -@_exported import NIOSSL -@_exported import struct Logging.Logger +@_documentation(visibility: internal) @_exported import NIO +@_documentation(visibility: internal) @_exported import NIOSSL +@_documentation(visibility: internal) @_exported import struct Logging.Logger diff --git a/Sources/PostgresNIO/Utilities/NIOUtils.swift b/Sources/PostgresNIO/Utilities/NIOUtils.swift index a1345ebc..75ab8c20 100644 --- a/Sources/PostgresNIO/Utilities/NIOUtils.swift +++ b/Sources/PostgresNIO/Utilities/NIOUtils.swift @@ -1,21 +1,7 @@ import Foundation -import NIO +import NIOCore internal extension ByteBuffer { - mutating func readNullTerminatedString() -> String? { - if let nullIndex = readableBytesView.firstIndex(of: 0) { - defer { moveReaderIndex(forwardBy: 1) } - return readString(length: nullIndex - readerIndex) - } else { - return nil - } - } - - mutating func write(nullTerminated string: String) { - self.writeString(string) - self.writeInteger(0, as: UInt8.self) - } - mutating func readInteger(endianness: Endianness = .big, as rawRepresentable: E.Type) -> E? where E: RawRepresentable, E.RawValue: FixedWidthInteger { guard let rawValue = readInteger(endianness: endianness, as: E.RawValue.self) else { return nil @@ -65,44 +51,6 @@ internal extension ByteBuffer { } return array } - - mutating func readFloat() -> Float? { - return self.readInteger(as: UInt32.self).map { Float(bitPattern: $0) } - } - - mutating func readDouble() -> Double? { - return self.readInteger(as: UInt64.self).map { Double(bitPattern: $0) } - } - - mutating func writeFloat(_ float: Float) { - self.writeInteger(float.bitPattern) - } - - mutating func writeDouble(_ double: Double) { - self.writeInteger(double.bitPattern) - } - - mutating func readUUID() -> UUID? { - guard self.readableBytes >= MemoryLayout.size else { - return nil - } - - let value: UUID = self.getUUID(at: self.readerIndex)! /* must work as we have enough bytes */ - // should be MoveReaderIndex - self.moveReaderIndex(forwardBy: MemoryLayout.size) - return value - } - - func getUUID(at index: Int) -> UUID? { - var uuid: uuid_t = (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) - return self.viewBytes(at: index, length: MemoryLayout.size(ofValue: uuid)).map { bufferBytes in - withUnsafeMutableBytes(of: &uuid) { target in - precondition(target.count <= bufferBytes.count) - target.copyBytes(from: bufferBytes) - } - return UUID(uuid: uuid) - } - } } internal extension Sequence where Element == UInt8 { diff --git a/Sources/PostgresNIO/Utilities/PostgresError+Code.swift b/Sources/PostgresNIO/Utilities/PostgresError+Code.swift index 11224f4b..fae903fe 100644 --- a/Sources/PostgresNIO/Utilities/PostgresError+Code.swift +++ b/Sources/PostgresNIO/Utilities/PostgresError+Code.swift @@ -1,5 +1,5 @@ extension PostgresError { - public struct Code: ExpressibleByStringLiteral, Equatable { + public struct Code: Sendable, ExpressibleByStringLiteral, Equatable { // Class 00 — Successful Completion public static let successfulCompletion: Code = "00000" diff --git a/Sources/PostgresNIO/Utilities/PostgresError.swift b/Sources/PostgresNIO/Utilities/PostgresError.swift index 2ccd7495..b9524275 100644 --- a/Sources/PostgresNIO/Utilities/PostgresError.swift +++ b/Sources/PostgresNIO/Utilities/PostgresError.swift @@ -14,7 +14,8 @@ public enum PostgresError: Error, LocalizedError, CustomStringConvertible { public var description: String { let description: String switch self { - case .protocol(let message): description = "protocol error: \(message)" + case .protocol(let message): + description = "protocol error: \(message)" case .server(let error): return "server: \(error.description)" case .connectionClosed: diff --git a/Sources/PostgresNIO/Utilities/PostgresJSONDecoder.swift b/Sources/PostgresNIO/Utilities/PostgresJSONDecoder.swift new file mode 100644 index 00000000..ba57ee9b --- /dev/null +++ b/Sources/PostgresNIO/Utilities/PostgresJSONDecoder.swift @@ -0,0 +1,41 @@ +import class Foundation.JSONDecoder +import struct Foundation.Data +import NIOFoundationCompat +import NIOCore +import NIOConcurrencyHelpers + +/// A protocol that mimicks the Foundation `JSONDecoder.decode(_:from:)` function. +/// Conform a non-Foundation JSON decoder to this protocol if you want PostgresNIO to be +/// able to use it when decoding JSON & JSONB values (see `PostgresNIO._defaultJSONDecoder`) +@preconcurrency +public protocol PostgresJSONDecoder: Sendable { + func decode(_ type: T.Type, from data: Data) throws -> T where T : Decodable + + func decode(_ type: T.Type, from buffer: ByteBuffer) throws -> T +} + +extension PostgresJSONDecoder { + public func decode(_ type: T.Type, from buffer: ByteBuffer) throws -> T { + var copy = buffer + let data = copy.readData(length: buffer.readableBytes)! + return try self.decode(type, from: data) + } +} + +//@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +extension JSONDecoder: PostgresJSONDecoder {} + +private let jsonDecoderLocked: NIOLockedValueBox = NIOLockedValueBox(JSONDecoder()) + +/// The default JSON decoder used by PostgresNIO when decoding JSON & JSONB values. +/// As `_defaultJSONDecoder` will be reused for decoding all JSON & JSONB values +/// from potentially multiple threads at once, you must ensure your custom JSON decoder is +/// thread safe internally like `Foundation.JSONDecoder`. +public var _defaultJSONDecoder: PostgresJSONDecoder { + set { + jsonDecoderLocked.withLockedValue { $0 = newValue } + } + get { + jsonDecoderLocked.withLockedValue { $0 } + } +} diff --git a/Sources/PostgresNIO/Utilities/PostgresJSONEncoder.swift b/Sources/PostgresNIO/Utilities/PostgresJSONEncoder.swift new file mode 100644 index 00000000..9585f20b --- /dev/null +++ b/Sources/PostgresNIO/Utilities/PostgresJSONEncoder.swift @@ -0,0 +1,39 @@ +import Foundation +import NIOFoundationCompat +import NIOCore +import NIOConcurrencyHelpers + +/// A protocol that mimicks the Foundation `JSONEncoder.encode(_:)` function. +/// Conform a non-Foundation JSON encoder to this protocol if you want PostgresNIO to be +/// able to use it when encoding JSON & JSONB values (see `PostgresNIO._defaultJSONEncoder`) +@preconcurrency +public protocol PostgresJSONEncoder: Sendable { + func encode(_ value: T) throws -> Data where T : Encodable + + func encode(_ value: T, into buffer: inout ByteBuffer) throws +} + +extension PostgresJSONEncoder { + public func encode(_ value: T, into buffer: inout ByteBuffer) throws { + let data = try self.encode(value) + buffer.writeData(data) + } +} + +extension JSONEncoder: PostgresJSONEncoder {} + +private let jsonEncoderLocked: NIOLockedValueBox = NIOLockedValueBox(JSONEncoder()) + +/// The default JSON encoder used by PostgresNIO when encoding JSON & JSONB values. +/// As `_defaultJSONEncoder` will be reused for encoding all JSON & JSONB values +/// from potentially multiple threads at once, you must ensure your custom JSON encoder is +/// thread safe internally like `Foundation.JSONEncoder`. +public var _defaultJSONEncoder: PostgresJSONEncoder { + set { + jsonEncoderLocked.withLockedValue { $0 = newValue } + } + get { + jsonEncoderLocked.withLockedValue { $0 } + } +} + diff --git a/Sources/PostgresNIO/Utilities/SASLAuthentication+SCRAM-SHA256.swift b/Sources/PostgresNIO/Utilities/SASLAuthentication+SCRAM-SHA256.swift new file mode 100644 index 00000000..2a717b6b --- /dev/null +++ b/Sources/PostgresNIO/Utilities/SASLAuthentication+SCRAM-SHA256.swift @@ -0,0 +1,658 @@ +import Crypto +import Foundation + +extension UInt8 { + fileprivate static var NUL: UInt8 { return 0x00 /* yeah, just U+0000 man */ } + fileprivate static var comma: UInt8 { return 0x2c /* .init(ascii: ",") */ } + fileprivate static var equals: UInt8 { return 0x3d /* .init(ascii: "=") */ } +} + +fileprivate extension String { + /** + ```` + The characters ',' or '=' in usernames are sent as '=2C' and + '=3D' respectively. If the server receives a username that + contains '=' not followed by either '2C' or '3D', then the + server MUST fail the authentication. + ```` + */ + var decodedAsSaslName: String? { + guard !self.contains(",") else { return nil } + + let partial = self.replacingOccurrences(of: "=2C", with: ",") + .replacingOccurrences(of: "=3D", with: "@@@TEMPORARY_REPLACEMENT_MARKER_EQUALS@@@") + + guard !partial.contains("=") else { return nil } + return partial.replacingOccurrences(of: "@@@TEMPORARY_REPLACEMENT_MARKER_EQUALS@@@", with: "=") + } + + var encodedAsSaslName: String { + return self.replacingOccurrences(of: ",", with: "=2C").replacingOccurrences(of: "=", with: "=3D") + } + + init?(printableAscii: [UInt8]) { + // `isdigit()` is (bad) libc, `CharacterSet` is Foundation. Rather than pull in either one + // directly, just hardcode it. Not a great ideal in general, mind you. UTF-8 is designed so + // all non-ASCII scalars will have high bits in their encoding somewhere, so this check will + // catch those too. Plus we ask `String` to explicitly accept ASCII only for good measure. + // `printable = %x21-2B / %x2D-7E` + guard !printableAscii.contains(where: { $0 < 0x21 || $0 == 0x2c || $0 > 0x7e }) else { return nil } + self.init(bytes: printableAscii, encoding: .ascii) + } + + init?(asciiAlphanumericMorse: [UInt8]) { + guard asciiAlphanumericMorse.isAsciiAlphanumericMorse else { return nil } + self.init(bytes: asciiAlphanumericMorse, encoding: .ascii) + } +} + +fileprivate extension Array where Element == UInt8 { + + // TODO: Use the Base64 coder from NIOWebSocket or Crypto rather than yanking in Foundation + func decodingBase64() -> [UInt8]? { + var actual = self + if actual.count % 4 != 0 { + actual.append(contentsOf: Array.init(repeating: .equals, count: 4 - (actual.count % 4))) + } + return actual.withUnsafeBytes({ (buf: UnsafeRawBufferPointer) -> Data? in + Data(base64Encoded: Data(bytesNoCopy: .init(mutating: buf.baseAddress!), count: buf.count, deallocator: .none)) + }).map { .init($0) } + } + + func encodingBase64() -> [UInt8] { + return Array(self.withUnsafeBytes { + Data(bytesNoCopy: .init(mutating: $0.baseAddress!), count: $0.count, deallocator: .none).base64EncodedData() + }) + } + + /// `1*(ALPHA / DIGIT / "." / "-")` + var isAsciiAlphanumericMorse: Bool { + // This is dumb. Match if we don't contain not containing any of the valid ranges. Yep. + return !self.contains(where: { c in ![0x30...0x39, 0x41...0x51, 0x61...0x7a].reduce(false, { $0 || $1.contains(c) }) }) + } + + var isAllDigits: Bool { + return !self.contains(where: { !(0x30...0x39).contains($0) }) + } + + /** + ```` + value = 1*value-char + value-safe-char = %x01-2B / %x2D-3C / %x3E-7F / UTF8-2 / UTF8-3 / UTF8-4 + value-char = value-safe-char / "=" + ```` + */ + var isValidScramValue: Bool { + // TODO: FInd a better way than doing a whole construction of String... + return self.count > 0 && !(String(decoding: self, as: Unicode.UTF8.self).contains(",")) + } + +} + +fileprivate enum SCRAMServerError: Error, RawRepresentable { + // Really could just use a string for this... + case invalidEncoding, extensionsNotSupported, invalidProof, channelBindingsDontMatch, + serverDoesSupportChannelBinding, channelBindingNotSupported, unsupportedChannelBindingType, + unknownUser, invalidUsernameEncoding, noResources, otherError, serverErrorValueExt(String) + + init?(rawValue: String) { + switch rawValue { + case "invalid-encoding": self = .invalidEncoding + case "extensions-not-supported": self = .extensionsNotSupported + case "invalid-proof": self = .invalidProof + case "channel-bindings-dont-match": self = .channelBindingsDontMatch + case "server-does-support-channel-binding": self = .serverDoesSupportChannelBinding + case "channel-binding-not-supported": self = .channelBindingNotSupported + case "unsupported-channel-binding-type": self = .unsupportedChannelBindingType + case "unknown-user": self = .unknownUser + case "invalid-username-encoding": self = .invalidUsernameEncoding + case "no-resources": self = .noResources + case "other-error": self = .otherError + default: self = .serverErrorValueExt(rawValue) + } + } + + var rawValue: String { + switch self { + case .invalidEncoding: return "invalid-encoding" + case .extensionsNotSupported: return "extensions-not-supported" + case .invalidProof: return "invalid-proof" + case .channelBindingsDontMatch: return "channel-bindings-dont-match" + case .serverDoesSupportChannelBinding: return "server-does-support-channel-binding" + case .channelBindingNotSupported: return "channel-binding-not-supported" + case .unsupportedChannelBindingType: return "unsupported-channel-binding-type" + case .unknownUser: return "unknown-user" + case .invalidUsernameEncoding: return "invalid-username-encoding" + case .noResources: return "no-resources" + case .otherError: return "other-error" + case .serverErrorValueExt(let raw): return raw + } + } +} + +fileprivate enum SCRAMAttribute { + enum GS2ChannelBinding { + case unsupported // client lacks support: `"n"` + case unused // client thinks server lacks support: `"y"` + case bind(String, [UInt8]?) // explicit channel binding: `"p=" 1*(ALPHA / DIGIT / "." / "-")`, `cbind-data`, per RFC 5056§7 + } + /// authorization identity: `"a=" saslname` + case a(String?) + /// authentication identity: `"n=" saslname` + case n(String) + /// reserved for mandatory extension signaling: `"m=" 1*(value-char)` + case m([UInt8]) + /// nonce: `"r=" printable` + case r(String) + /// GS2 header and channel binding: `"c=" base64(cbind-input)` + case c(binding: GS2ChannelBinding = .unsupported, authIdentity: String? = nil) + /// salt: `"s=" base64` + case s([UInt8]) + /// iteration count: `"i=" posit-number` + case i(UInt32) + /// computed proof: `"p=" base64` (notably slightly conflicts with GS2 header's channel binding) + case p([UInt8]) + /// verifier (computed server signature): `"v=" base64` + case v([UInt8]) + /// server error: `"e=" server-error-value` + case e(SCRAMServerError) + /// unknown optional extension: `attr-val` ... `ALPHA "=" 1*value-char` + case optional(name: CChar, value: [UInt8]) + /// partial GS2 header signal (binding type, no data) + case gp(GS2ChannelBinding) + /// partial GS2 header signal (binding data) + case gm([UInt8]) +} + +fileprivate struct SCRAMMessageParser { + static func parseAttributePair(name: [UInt8], value: [UInt8], isGS2Header: Bool = false) -> SCRAMAttribute? { + guard name.count == 1 || isGS2Header else { return nil } + switch name.first { + case UInt8(ascii: "m") where !isGS2Header: return .m(value) + case UInt8(ascii: "r") where !isGS2Header: return String(printableAscii: value).map { .r($0) } + case UInt8(ascii: "c") where !isGS2Header: + guard let parsedAttrs = value.decodingBase64().flatMap({ parse(raw: $0, isGS2Header: true) }) else { return nil } + guard (1...3).contains(parsedAttrs.count) else { return nil } + switch (parsedAttrs.first, parsedAttrs.dropFirst(1).first, parsedAttrs.dropFirst(2).first) { + case let (.gp(.bind(name, .none)), .a(ident), .gm(data)): return .c(binding: .bind(name, data), authIdentity: ident) + case let (.gp(.bind(name, .none)), .gm(data), .none): return .c(binding: .bind(name, data)) + case let (.gp(bind), .a(ident), .none): return .c(binding: bind, authIdentity: ident) + case let (.gp(bind), .none, .none): return .c(binding: bind) + default: return nil + } + case UInt8(ascii: "n") where !isGS2Header: return String(decoding: value, as: Unicode.UTF8.self).decodedAsSaslName.map { .n($0) } + case UInt8(ascii: "s") where !isGS2Header: return value.decodingBase64().map { .s($0) } + case UInt8(ascii: "i") where !isGS2Header: return String(printableAscii: value).flatMap { UInt32.init($0) }.map { .i($0) } + case UInt8(ascii: "p") where !isGS2Header: return value.decodingBase64().map { .p($0) } + case UInt8(ascii: "v") where !isGS2Header: return value.decodingBase64().map { .v($0) } + case UInt8(ascii: "e") where !isGS2Header: // TODO: actually map the specific enum string values + guard value.isValidScramValue else { return nil } + return SCRAMServerError(rawValue: String(decoding: value, as: Unicode.UTF8.self)).flatMap { .e($0) } + + case UInt8(ascii: "y") where isGS2Header && value.count == 0: return .gp(.unused) + case UInt8(ascii: "n") where isGS2Header && value.count == 0: return .gp(.unsupported) + case UInt8(ascii: "p") where isGS2Header: return String(asciiAlphanumericMorse: value).map { .gp(.bind($0, nil)) } + case UInt8(ascii: "a") where isGS2Header: return String(decoding: value, as: Unicode.UTF8.self).decodedAsSaslName.map { .a($0) } + case .none where isGS2Header: return .a(nil) + + default: + if isGS2Header { + return .gm(name + value) + } else { + guard value.count > 0, value.isValidScramValue else { return nil } + return .optional(name: CChar(name[0]), value: value) + } + } + } + + static func parse(raw: [UInt8], isGS2Header: Bool = false) -> [SCRAMAttribute]? { + // There are two ways to implement this parse: + // 1. All-at-once: Split on comma, split each on equals, validate + // each results in a valid attribute. + // 2. Sequential: State machine lookahead parse. + // The former is simpler. The latter provides better validation. + let likelyAttributeSets = raw.split(separator: .comma, maxSplits: isGS2Header ? 2 : Int.max, omittingEmptySubsequences: false) + let likelyAttributePairs = likelyAttributeSets.map { $0.split(separator: .equals, maxSplits: 1, omittingEmptySubsequences: false) } + + let results = likelyAttributePairs.map { parseAttributePair(name: Array($0[0]), value: $0.dropFirst().first.map { Array($0) } ?? [], isGS2Header: isGS2Header) } + let validResults = results.compactMap { $0 } + guard validResults.count == results.count else { return nil } + + return validResults + } + + static func serialize(_ attributes: [SCRAMAttribute], isInitialGS2Header: Bool = true) -> [UInt8]? { + var result: [UInt8] = [] + + for attribute in attributes { + switch attribute { + case .m(let value): + result.append(UInt8(ascii: "m")); result.append(.equals); result.append(contentsOf: value) + case .r(let nonce): + result.append(UInt8(ascii: "r")); result.append(.equals); result.append(contentsOf: nonce.utf8.map { UInt8($0) }) + case .n(let name): + result.append(UInt8(ascii: "n")); result.append(.equals); result.append(contentsOf: name.encodedAsSaslName.utf8.map { UInt8($0) }) + case .s(let salt): + result.append(UInt8(ascii: "s")); result.append(.equals); result.append(contentsOf: salt.encodingBase64()) + case .i(let count): + result.append(UInt8(ascii: "i")); result.append(.equals); result.append(contentsOf: "\(count)".utf8.map { UInt8($0) }) + case .p(let proof): + result.append(UInt8(ascii: "p")); result.append(.equals); result.append(contentsOf: proof.encodingBase64()) + case .v(let signature): + result.append(UInt8(ascii: "v")); result.append(.equals); result.append(contentsOf: signature.encodingBase64()) + case .e(let error): + result.append(UInt8(ascii: "e")); result.append(.equals); result.append(contentsOf: error.rawValue.utf8.map { UInt8($0) }) + case .c(let binding, let identity): + if isInitialGS2Header { + switch binding { + case .unsupported: result.append(UInt8(ascii: "n")) + case .unused: result.append(UInt8(ascii: "y")) + case .bind(let name, _): result.append(UInt8(ascii: "p")); result.append(.equals); result.append(contentsOf: name.utf8.map { UInt8($0) }) + } + result.append(.comma) + if let identity = identity { + result.append(UInt8(ascii: "a")); result.append(.equals); result.append(contentsOf: identity.encodedAsSaslName.utf8.map { UInt8($0) }) + } + result.append(.comma) + } else { + guard var partial = serialize([attribute], isInitialGS2Header: true) else { return nil } + if case let .bind(_, data) = binding { + guard let data = data else { return nil } + partial.append(contentsOf: data) + } + result.append(UInt8(ascii: "c")); result.append(.equals); result.append(contentsOf: partial.encodingBase64()) + } + default: + return nil + } + result.append(.comma) + } + return result.dropLast() + } +} + +internal enum SASLMechanism { +internal enum SCRAM { + +/// Implementation of `SCRAM-SHA-256` as a `SASLAuthenticationMechanism` +/// +/// Implements SCRAM-SHA-256 as described by: +/// - [RFC 7677 (SCRAM-SHA-256 and SCRAM-SHA-256-PLUS SASL Mechanisms)](https://tools.ietf.org/html/rfc7677) +/// - [RFC 5802 (SCRAM SASL and GSS-API Mechanisms)](https://tools.ietf.org/html/rfc5802) +/// - [RFC 4422 (Simple Authentication and Security Layer)](https://tools.ietf.org/html/rfc4422) +internal struct SHA256: SASLAuthenticationMechanism { + + static internal var name: String { return "SCRAM-SHA-256" } + + /// Set up a client-side `SCRAM-SHA-256` authentication. + /// + /// - Parameters: + /// - username: The username to authenticate as. + /// - password: A closure which returns the plaintext password for the + /// authenticating user. If the closure throws, authentication + /// immediately fails with the thrown error. + internal init(username: String, password: @escaping () throws -> String) { + self._impl = .init(username: username, passwordGrabber: { _ in try (Array(password().data(using: .utf8)!), []) }, bindingInfo: .unsupported) + } + + /// Set up a server-side `SCRAM-SHA-256` authentication. + /// + /// - Parameters: + /// - iterations: The number of iterations applied to salted passwords. + /// Must be at least 4096. + /// - saltedPassword: A closure which receives the username of the user + /// attempting to authentication and must return the + /// salted password for that user, as well as the salt + /// itself. If the closure throw, authentication + /// immediately fails with the thrown error. + internal init(serveWithIterations iterations: UInt32 = 4096, saltedPassword: @escaping (String) throws -> ([UInt8], [UInt8])) { + self._impl = .init(iterationCount: iterations, passwordGrabber: saltedPassword, requireBinding: false) + } + + internal func step(message: [UInt8]?) -> SASLAuthenticationStepResult { + return _impl.step(message: message) + } + + private let _impl: SASLMechanism_SCRAM_SHA256_Common +} + +/// Implementation of `SCRAM-SHA-256-PLUS` as a `SASLAuthenticationMechanism` +/// +/// Implements SCRAM-SHA-256-PLUS as described by: +/// - [RFC 7677 (SCRAM-SHA-256 and SCRAM-SHA-256-PLUS SASL Mechanisms)](https://tools.ietf.org/html/rfc7677) +/// - [RFC 5802 (SCRAM SASL and GSS-API Mechanisms)](https://tools.ietf.org/html/rfc5802) +/// - [RFC 4422 (Simple Authentication and Security Layer)](https://tools.ietf.org/html/rfc4422) +internal struct SHA256_PLUS: SASLAuthenticationMechanism { + + static internal var name: String { return "SCRAM-SHA-256-PLUS" } + + /// Set up a client-side `SCRAM-SHA-256-PLUS` authentication. + /// + /// - Parameters: + /// - username: The username to authenticate as. + /// - password: A closure which returns the plaintext password for the + /// authenticating user. If the closure throws, authentication + /// immediately fails with the thrown error. + /// - channelBindingName: The RFC5056 channel binding to apply to the + /// authentication. + /// - channelBindingData: The appropriate data associated with the RFC5056 + /// channel binding specified. + internal init(username: String, password: @escaping () throws -> String, channelBindingName: String, channelBindingData: [UInt8]) { + self._impl = .init(username: username, passwordGrabber: { _ in try (Array(password().data(using: .utf8)!), []) }, bindingInfo: .bind(channelBindingName, channelBindingData)) + } + + /// Set up a server-side `SCRAM-SHA-256` authentication. + /// + /// - Parameters: + /// - iterations: The number of iterations applied to salted passwords. + /// Must be at least 4096. + /// - saltedPassword: A closure which receives the username of the user + /// attempting to authentication and must return the + /// salted password for that user, as well as the salt + /// itself. If the closure throw, authentication + /// immediately fails with the thrown error. + internal init(serveWithIterations iterations: UInt32 = 4096, saltedPassword: @escaping (String) throws -> ([UInt8], [UInt8])) { + self._impl = .init(iterationCount: iterations, passwordGrabber: saltedPassword, requireBinding: true) + } + + internal func step(message: [UInt8]?) -> SASLAuthenticationStepResult { + return _impl.step(message: message) + } + + private let _impl: SASLMechanism_SCRAM_SHA256_Common +} + +} // enum SCRAM +} // enum SASLMechanism + +/// Common implementation of SCRAM-SHA-256 and SCRAM-SHA-256-PLUS +fileprivate final class SASLMechanism_SCRAM_SHA256_Common { + + /// Initialized with initial client state + init(username: String, passwordGrabber: @escaping (String) throws -> ([UInt8], [UInt8]), bindingInfo: SCRAMAttribute.GS2ChannelBinding) { + let nonce = Data((0..<18).map { _ in UInt8.random(in: .min...(.max)) }).base64EncodedString() + self.state = .clientInitial(username: username, nonce: nonce, binding: bindingInfo) + self.passwordGrabber = passwordGrabber + } + + /// Initialized with initial server state + init(iterationCount: UInt32, passwordGrabber: @escaping (String) throws -> ([UInt8], [UInt8]), requireBinding: Bool) { + let nonce = Data((0..<18).map { _ in UInt8.random(in: .min...(.max)) }).base64EncodedString() + self.state = .serverInitial(extraNonce: nonce, iterationCount: iterationCount, bindingRequired: requireBinding) + self.passwordGrabber = passwordGrabber + } + + private var state: State + private let passwordGrabber: (String) throws -> ([UInt8], [UInt8]) + + private enum State { + case clientInitial(username: String, nonce: String, binding: SCRAMAttribute.GS2ChannelBinding) + case clientSentFirstMessage(username: String, nonce: String, binding: SCRAMAttribute.GS2ChannelBinding, bareMessage: [UInt8]) + case clientSentFinalMessage(saltedPassword: [UInt8], authMessage: [UInt8]) + case clientDone + + case serverInitial(extraNonce: String, iterationCount: UInt32, bindingRequired: Bool) + case serverSentFirstMessage(clientBareFirstMessage: [UInt8], nonce: String, binding: SCRAMAttribute.GS2ChannelBinding, saltedPassword: [UInt8], serverFirstMessage: [UInt8]) + case serverDone + } + + public func step(message: [UInt8]?) -> SASLAuthenticationStepResult { + do { + switch state { + case .clientInitial(let username, let nonce, let binding): + guard message == nil else { throw SASLAuthenticationError.initialRequestNotSent } + return try self.handleClientInitial(username: username, nonce: nonce, binding: binding) + case .clientSentFirstMessage(let username, let nonce, let binding, let firstMessageBare): + guard let serverFirstMessage = message else { throw SASLAuthenticationError.initialRequestAlreadySent } + return try self.handleClientSentFirst(message: serverFirstMessage, username: username, nonce: nonce, binding: binding, firstMessageBare: firstMessageBare) + case .clientSentFinalMessage(let saltedPassword, let authMessage): + guard let serverFinalMessage = message else { throw SASLAuthenticationError.initialRequestAlreadySent } + return try self.handleClientSentFinal(message: serverFinalMessage, saltedPassword: saltedPassword, authMessage: authMessage) + case .clientDone: + throw SASLAuthenticationError.resultAlreadyDelivered + + case .serverInitial(let extraNonce, let iterationCount, let bindingRequired): + return try self.handleServerInitial(message!, extraNonce: extraNonce, iterationCount: iterationCount, bindingRequired: bindingRequired) + case .serverSentFirstMessage(let clientBareFirstMessage, let nonce, let previousBinding, let saltedPassword, let serverFirstMessage): + return try self.handleServerSentFirst(message!, clientBareFirstMessage: clientBareFirstMessage, nonce: nonce, previousBinding: previousBinding, saltedPassword: saltedPassword, serverFirstMessage: serverFirstMessage) + case .serverDone: + throw SASLAuthenticationError.resultAlreadyDelivered + } + } catch { + return .fail(response: nil, error: error) + } + } + + private func handleClientInitial(username: String, nonce: String, binding: SCRAMAttribute.GS2ChannelBinding) throws -> SASLAuthenticationStepResult { + // Generate a `client-first-message-bare` + guard let clientFirstMessageBare = SCRAMMessageParser.serialize([.n(username), .r(nonce)]) else { + throw SASLAuthenticationError.genericAuthenticationFailure + } + + // Generate a `gs2-header` + guard let clientGs2Header = SCRAMMessageParser.serialize([.c(binding: binding, authIdentity: nil)]) else { + throw SASLAuthenticationError.genericAuthenticationFailure + } + + // Paste them together to make a `client-first-message` + let clientFirstMessage = clientGs2Header + clientFirstMessageBare + + // Save state and send + self.state = .clientSentFirstMessage(username: username, nonce: nonce, binding: binding, bareMessage: clientFirstMessageBare) + return .continue(response: clientFirstMessage) + } + + private func handleClientSentFirst(message: [UInt8], username: String, nonce: String, + binding: SCRAMAttribute.GS2ChannelBinding, firstMessageBare: [UInt8]) throws -> SASLAuthenticationStepResult { + // Parse incoming + guard let incomingAttributes = SCRAMMessageParser.parse(raw: message) else { + throw SASLAuthenticationError.genericAuthenticationFailure + } + + // Validate as `server-first-message` and extract data + guard incomingAttributes.count >= 3 else { throw SASLAuthenticationError.genericAuthenticationFailure } + guard case let .r(serverNonce) = incomingAttributes.dropFirst(0).first else { throw SASLAuthenticationError.genericAuthenticationFailure } + guard case let .s(serverSalt) = incomingAttributes.dropFirst(1).first else { throw SASLAuthenticationError.genericAuthenticationFailure } + guard case let .i(serverIterations) = incomingAttributes.dropFirst(2).first else { throw SASLAuthenticationError.genericAuthenticationFailure } + + // Generate a `client-final-message-no-proof` + guard let clientFinalNoProof = SCRAMMessageParser.serialize([.c(binding: binding), .r(serverNonce)], isInitialGS2Header: false) else { + throw SASLAuthenticationError.genericAuthenticationFailure + } + + // Retrieve the authentication credential + let (password, _) = try self.passwordGrabber("") + + // TODO: Perform `Normalize(password)`, aka the SASLprep profile (RFC4013) of stringprep (RFC3454) + + // Calculate `AuthMessage`, `ClientSignature`, and `ClientProof` + let saltedPassword = Hi(string: password, salt: serverSalt, iterations: serverIterations) + let clientKey = HMAC.authenticationCode(for: "Client Key".data(using: .utf8)!, using: .init(data: saltedPassword)) + let storedKey = SHA256.hash(data: Data(clientKey)) + var authMessage = firstMessageBare; authMessage.append(.comma); authMessage.append(contentsOf: message); authMessage.append(.comma); authMessage.append(contentsOf: clientFinalNoProof) + let clientSignature = HMAC.authenticationCode(for: authMessage, using: .init(data: storedKey)) + var clientProof = Array(clientKey) + + clientProof.withUnsafeMutableBytes { proofBuf in + clientSignature.withUnsafeBytes { signatureBuf in + for i in 0.. SASLAuthenticationStepResult { + // Parse incoming + guard let incomingAttributes = SCRAMMessageParser.parse(raw: message) else { + throw SASLAuthenticationError.genericAuthenticationFailure + } + + // Validate as `server-final-message` and extract data + switch incomingAttributes.first { + case .v(let verifier): + // Verify server signature + let serverKey = HMAC.authenticationCode(for: "Server Key".data(using: .utf8)!, using: .init(data: saltedPassword)) + let serverSignature = HMAC.authenticationCode(for: authMessage, using: .init(data: serverKey)) + + guard Array(serverSignature) == verifier else { + return .fail(response: nil, error: SASLAuthenticationError.genericAuthenticationFailure) + } + case .e(let error): + return .fail(response: nil, error: error) + default: throw SASLAuthenticationError.genericAuthenticationFailure + } + + // Mark done and return success + self.state = .clientDone + return .succeed(response: nil) + } + + private func handleServerInitial(_ message: [UInt8], extraNonce: String, + iterationCount: UInt32, bindingRequired: Bool) throws -> SASLAuthenticationStepResult { + var binding: SCRAMAttribute.GS2ChannelBinding = .unsupported + + // Parse as GS2 header first. This is awful and the parser should be refactored. + guard var channelAttributes = SCRAMMessageParser.parse(raw: message, isGS2Header: true), channelAttributes.count > 0 else { + throw SASLAuthenticationError.genericAuthenticationFailure + } + // Channel binding flag is required. Binding data may not be specified in initial hello. + switch channelAttributes.removeFirst() { + case .gp(.unsupported): + guard !bindingRequired else { throw SCRAMServerError.serverErrorValueExt("client-negotiated-badly") } + binding = .unsupported + case .gp(.unused): + if bindingRequired { throw SCRAMServerError.serverDoesSupportChannelBinding } + else { throw SCRAMServerError.serverErrorValueExt("channel-bindings-expected-from-client") } + case .gp(.bind(let type, .none)): + guard bindingRequired else { throw SCRAMServerError.channelBindingNotSupported } + binding = .bind(type, nil) + default: throw SASLAuthenticationError.genericAuthenticationFailure + } + // Optional authorization name may appear in GS2 header + if case .a(_) = channelAttributes.first { + channelAttributes.removeFirst() + // TODO: Allow callers to handle authorization names + } + // Extract remaining message content. Again, parser needs refactoring. + guard case let .gm(clientFirstMessageBare) = channelAttributes.first, channelAttributes.count == 1, + let incomingAttributes = SCRAMMessageParser.parse(raw: clientFirstMessageBare), + case let .n(username) = incomingAttributes.dropFirst(0).first, + case let .r(clientNonce) = incomingAttributes.dropFirst(1).first else { + throw SASLAuthenticationError.genericAuthenticationFailure + } + + // Retrieve credentials + let (saltedPassword, salt) = try self.passwordGrabber(username) + + // Generate a `server-first-message` + guard let serverFirstMessage = SCRAMMessageParser.serialize([.r(clientNonce + extraNonce), .s(salt), .i(iterationCount)]) else { + throw SASLAuthenticationError.genericAuthenticationFailure + } + + // Save state and send + self.state = .serverSentFirstMessage(clientBareFirstMessage: clientFirstMessageBare, nonce: clientNonce + extraNonce, binding: binding, saltedPassword: saltedPassword, serverFirstMessage: serverFirstMessage) + return .continue(response: serverFirstMessage) + } + + private func handleServerSentFirst( + _ message: [UInt8], + clientBareFirstMessage: [UInt8], nonce: String, previousBinding: SCRAMAttribute.GS2ChannelBinding, + saltedPassword: [UInt8], serverFirstMessage: [UInt8] + ) throws -> SASLAuthenticationStepResult { + guard let incomingAttributes = SCRAMMessageParser.parse(raw: message) else { throw SASLAuthenticationError.genericAuthenticationFailure } + guard case let .c(binding, _) = incomingAttributes.dropFirst(0).first, + case let .r(repeatNonce) = incomingAttributes.dropFirst(1).first, + case let .p(proof) = incomingAttributes.last else { + throw SASLAuthenticationError.genericAuthenticationFailure + } + switch (binding, previousBinding) { + case (.unsupported, .unsupported): break // all good + case (.bind(let type, _), .bind(let prevType, .none)) where type == prevType: + // TODO: Actually handle the binding + break + default: throw SCRAMServerError.channelBindingsDontMatch + } + guard nonce == repeatNonce else { throw SASLAuthenticationError.genericAuthenticationFailure } + + // Compute client signature + let clientKey = HMAC.authenticationCode(for: "Client Key".data(using: .utf8)!, using: .init(data: saltedPassword)) + let storedKey = SHA256.hash(data: Data(clientKey)) + var authMessage = clientBareFirstMessage; authMessage.append(.comma); authMessage.append(contentsOf: serverFirstMessage); authMessage.append(.comma); authMessage.append(contentsOf: message.dropLast(proof.count + 3)) + let clientSignature = HMAC.authenticationCode(for: authMessage, using: .init(data: storedKey)) + + // Recompute client key from signature and proof, verify match + var clientProofKey = Array(clientSignature) + + clientProofKey.withUnsafeMutableBytes { proofBuf in + proof.withUnsafeBytes { signatureBuf in + for i in 0...authenticationCode(for: "Server Key".data(using: .utf8)!, using: .init(data: saltedPassword)) + let serverSignature = HMAC.authenticationCode(for: authMessage, using: .init(data: serverKey)) + + // Generate a `server-final-message` + guard let serverFinalMessage = SCRAMMessageParser.serialize([.v(Array(serverSignature))]) else { + throw SASLAuthenticationError.genericAuthenticationFailure + } + + // Save state and signal success with the reply + self.state = .serverDone + return .succeed(response: serverFinalMessage) + } + +} + +/** + ```` + o Hi(str, salt, i): + + U1 := HMAC(str, salt + INT(1)) + U2 := HMAC(str, U1) + ... + Ui-1 := HMAC(str, Ui-2) + Ui := HMAC(str, Ui-1) + + Hi := U1 XOR U2 XOR ... XOR Ui + + where "i" is the iteration count, "+" is the string concatenation + operator, and INT(g) is a 4-octet encoding of the integer g, most + significant octet first. + + Hi() is, essentially, PBKDF2 [RFC2898] with HMAC() as the + pseudorandom function (PRF) and with dkLen == output length of + HMAC() == output length of H(). + ```` +*/ +private func Hi(string: [UInt8], salt: [UInt8], iterations: UInt32) -> [UInt8] { + let key = SymmetricKey(data: string) + var Ui = HMAC.authenticationCode(for: salt + [0x00, 0x00, 0x00, 0x01], using: key) // salt + 0x00000001 as big-endian + var Hi = Array(Ui) + + Hi.withUnsafeMutableBytes { Hibuf -> Void in + for _ in 2...iterations { + Ui = HMAC.authenticationCode(for: Data(Ui), using: key) + + Ui.withUnsafeBytes { Uibuf -> Void in + for i in 0.. { + + private enum Role { + case client, server + } + + private enum State { + /// Client: Waiting to send initial request. May transition to `waitingNextStep`. + /// Server: Waiting for initial request. May transition to `waitingNextStep`, `done`. + case waitingForInitial + + /// Client: Initial request sent, waiting for next challenge. May transition to `done`. + /// Server: Latest challenge sent, waiting for next response. May transition to `done`. + case waitingNextStep + + /// Client: Received success or failure. No more operations permitted. + /// Server: Sent success or failure. No more operations permitted. + case done + } + + private let mechanism: M + private let role: Role + private var state: State = .waitingForInitial + + public init(asClientSpeaking mechanism: M) { + self.mechanism = mechanism + self.role = .client + } + + public init(asServerAccepting mechanism: M) { + self.mechanism = mechanism + self.role = .server + } + + /// Handle an incoming message via the provided mechanism. The `sender` + /// closure will be invoked with any data that should be transmitted to the + /// other side of the negotiation. An error thrown from the closure will + /// immediately result in an authentication failure state. The closure may + /// be invoked even if authentication otherwise fails (such as for + /// mechanisms which send failure responses). On authentication failure, an + /// error is thrown. Otherwise, `true` is returned to indicate that + /// authentication has successfully completed. `false` is returned to + /// indicate that further steps are required by the current mechanism. + /// + /// Pass a `nil` message to start the initial request from a client. It is + /// invalid to do this for a server. + public func handle(message: [UInt8]?, sender: ([UInt8]) throws -> Void) throws -> Bool { + guard self.state != .done else { + // Already did whatever we were gonna do. + throw SASLAuthenticationError.resultAlreadyDelivered + } + + if message == nil { + guard self.role == .client else { + // Can't respond to `nil` as server + self.state = .done + throw SASLAuthenticationError.serverRoleRequiresMessage + } + guard self.state == .waitingForInitial else { + // Can't respond to `nil` as client twice. + self.state = .done + throw SASLAuthenticationError.initialRequestAlreadySent + } + } else if self.role == .client && state == .waitingForInitial { + // Must respond to `nil` as client first and exactly once. + self.state = .done + throw SASLAuthenticationError.initialRequestNotSent + } + + switch self.mechanism.step(message: message) { + case .continue(let response): + if let response = response { + try sender(response) + } + self.state = .waitingNextStep + return false + case .succeed(let response): + if let response = response { + try sender(response) + } + self.state = .done + return true + case .fail(let response, let error): + if let response = response { + try sender(response) + } + self.state = .done + if let error = error { + throw error + } else { + throw SASLAuthenticationError.genericAuthenticationFailure + } + } + } + +} + +/// Various errors that can occur during SASL negotiation that are not specific +/// to the particular SASL mechanism in use. +public enum SASLAuthenticationError: Error { + /// A server can not handle a nonexistent message. Only an initial-state + /// client can do that, and even then it's really just a proxy for the API + /// having difficulty expressing "this must be done once and then never + /// again" clearly. + case serverRoleRequiresMessage + + /// A client may only receive a nonexistent message once during the initial + /// state. This is a proxy for the API not being good at expressing a "must + /// do this first and only once." + case initialRequestAlreadySent + + /// A client must receive a nonexistent message exactly once before doing + /// anything else. This is ALSO a proxy for the API just being bad at + /// expressing the requirement. + case initialRequestNotSent + + /// Authentication failed, and the underlying mechanism declined to provide + /// a more specific error message. + case genericAuthenticationFailure + + /// This `SASLAuthenticationManager` has already delivered a success or + /// failure result (which may include a fatal state management error). It + /// can not be reused. + case resultAlreadyDelivered +} + +/// Signifies an action to be taken as the result of a single step of a SASL +/// mechanism. +public enum SASLAuthenticationStepResult { + + /// More steps are needed. Assume neither success nor failure. If data is + /// provided, send it. A value of `nil` signifies sending no response at + /// all, whereas a value of `[]` signifies sending an empty response, which + /// may not be the same action depending on the underlying protocol + case `continue`(response: [UInt8]? = nil) + + /// Signal authentication success. If data is provided, send it. A value of + /// `nil` signifies sending no response at all, whereas a value of `[]` + /// signifies sending an empty response, which may not be the same action + /// depending on the underlying protocol. + case succeed(response: [UInt8]? = nil) + + /// Signal authentication failure. If data is provided, send it. A value of + /// `nil` signifies sending no response at all, whereas a value of `[]` + /// signifies sending an empty response, which may not be the same action + /// depending on the underlying protocol. The provided error, if any, is + /// surfaced. If none is provided, a generic failure is surfaced instead. + case fail(response: [UInt8]? = nil, error: Error? = nil) + +} + +/// The protocol to which all SASL mechanism implementations must conform. It is +/// the responsibility of each individual implementation to provide an API for +/// creating instances of itself which are able to retrieve information from the +/// caller (such as usernames and passwords) by some mechanism. +public protocol SASLAuthenticationMechanism { + + /// The IANA-registered SASL mechanism name. This may be a family prefix or + /// a specific mechanism name. It is explicitly suitable for use in + /// negotiation via whatever underlying application-specific protocol is in + /// use for the purpose. + static var name: String { get } + + /// Single-step the mechanism. The message may be `nil` in particular when + /// the local side of the negotiation is a client starting its initial + /// authentication request. + func step(message: [UInt8]?) -> SASLAuthenticationStepResult + +} diff --git a/Tests/ConnectionPoolModuleTests/ConnectionIDGeneratorTests.swift b/Tests/ConnectionPoolModuleTests/ConnectionIDGeneratorTests.swift new file mode 100644 index 00000000..fb0bfce1 --- /dev/null +++ b/Tests/ConnectionPoolModuleTests/ConnectionIDGeneratorTests.swift @@ -0,0 +1,22 @@ +import _ConnectionPoolModule +import XCTest + +final class ConnectionIDGeneratorTests: XCTestCase { + func testGenerateConnectionIDs() async { + let idGenerator = ConnectionIDGenerator() + + XCTAssertEqual(idGenerator.next(), 0) + XCTAssertEqual(idGenerator.next(), 1) + XCTAssertEqual(idGenerator.next(), 2) + + await withTaskGroup(of: Void.self) { taskGroup in + for _ in 0..<1000 { + taskGroup.addTask { + _ = idGenerator.next() + } + } + } + + XCTAssertEqual(idGenerator.next(), 1003) + } +} diff --git a/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift b/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift new file mode 100644 index 00000000..c745b4a0 --- /dev/null +++ b/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift @@ -0,0 +1,858 @@ +@testable import _ConnectionPoolModule +import _ConnectionPoolTestUtils +import Atomics +import NIOEmbedded +import XCTest + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +final class ConnectionPoolTests: XCTestCase { + + func test1000ConsecutiveRequestsOnSingleConnection() async { + let factory = MockConnectionFactory() + + var config = ConnectionPoolConfiguration() + config.minimumConnectionCount = 1 + + let pool = ConnectionPool( + configuration: config, + idGenerator: ConnectionIDGenerator(), + requestType: ConnectionRequest.self, + keepAliveBehavior: MockPingPongBehavior(keepAliveFrequency: nil, connectionType: MockConnection.self), + observabilityDelegate: NoOpConnectionPoolMetrics(connectionIDType: MockConnection.ID.self), + clock: ContinuousClock() + ) { + try await factory.makeConnection(id: $0, for: $1) + } + + // the same connection is reused 1000 times + + await withTaskGroup(of: Void.self) { taskGroup in + taskGroup.addTask_ { + await pool.run() + } + + let createdConnection = await factory.nextConnectAttempt { _ in + return 1 + } + XCTAssertNotNil(createdConnection) + + do { + for _ in 0..<1000 { + async let connectionFuture = try await pool.leaseConnection() + var leasedConnection: MockConnection? + XCTAssertEqual(factory.pendingConnectionAttemptsCount, 0) + leasedConnection = try await connectionFuture + XCTAssertNotNil(leasedConnection) + XCTAssert(createdConnection === leasedConnection) + + if let leasedConnection { + pool.releaseConnection(leasedConnection) + } + } + } catch { + XCTFail("Unexpected error: \(error)") + } + + taskGroup.cancelAll() + + XCTAssertEqual(factory.pendingConnectionAttemptsCount, 0) + for connection in factory.runningConnections { + connection.closeIfClosing() + } + } + + XCTAssertEqual(factory.runningConnections.count, 0) + } + + func testShutdownPoolWhileConnectionIsBeingCreated() async { + let clock = MockClock() + let factory = MockConnectionFactory() + + var config = ConnectionPoolConfiguration() + config.minimumConnectionCount = 1 + + let pool = ConnectionPool( + configuration: config, + idGenerator: ConnectionIDGenerator(), + requestType: ConnectionRequest.self, + keepAliveBehavior: MockPingPongBehavior(keepAliveFrequency: nil, connectionType: MockConnection.self), + observabilityDelegate: NoOpConnectionPoolMetrics(connectionIDType: MockConnection.ID.self), + clock: clock + ) { + try await factory.makeConnection(id: $0, for: $1) + } + + await withTaskGroup(of: Void.self) { taskGroup in + taskGroup.addTask_ { + await pool.run() + } + + let (blockCancelStream, blockCancelContinuation) = AsyncStream.makeStream(of: Void.self) + let (blockConnCreationStream, blockConnCreationContinuation) = AsyncStream.makeStream(of: Void.self) + + taskGroup.addTask_ { + _ = try? await factory.nextConnectAttempt { _ in + blockCancelContinuation.yield() + var iterator = blockConnCreationStream.makeAsyncIterator() + await iterator.next() + throw ConnectionCreationError() + } + } + + var iterator = blockCancelStream.makeAsyncIterator() + await iterator.next() + + taskGroup.cancelAll() + blockConnCreationContinuation.yield() + } + + struct ConnectionCreationError: Error {} + } + + func testShutdownPoolWhileConnectionIsBackingOff() async { + let clock = MockClock() + let factory = MockConnectionFactory() + + var config = ConnectionPoolConfiguration() + config.minimumConnectionCount = 1 + + let pool = ConnectionPool( + configuration: config, + idGenerator: ConnectionIDGenerator(), + requestType: ConnectionRequest.self, + keepAliveBehavior: MockPingPongBehavior(keepAliveFrequency: nil, connectionType: MockConnection.self), + observabilityDelegate: NoOpConnectionPoolMetrics(connectionIDType: MockConnection.ID.self), + clock: clock + ) { + try await factory.makeConnection(id: $0, for: $1) + } + + await withTaskGroup(of: Void.self) { taskGroup in + taskGroup.addTask_ { + await pool.run() + } + + _ = try? await factory.nextConnectAttempt { _ in + throw ConnectionCreationError() + } + + await clock.nextTimerScheduled() + + taskGroup.cancelAll() + } + + struct ConnectionCreationError: Error {} + } + + func testConnectionHardLimitIsRespected() async { + let factory = MockConnectionFactory() + + var mutableConfig = ConnectionPoolConfiguration() + mutableConfig.minimumConnectionCount = 0 + mutableConfig.maximumConnectionSoftLimit = 4 + mutableConfig.maximumConnectionHardLimit = 8 + let config = mutableConfig + + let pool = ConnectionPool( + configuration: config, + idGenerator: ConnectionIDGenerator(), + requestType: ConnectionRequest.self, + keepAliveBehavior: MockPingPongBehavior(keepAliveFrequency: nil, connectionType: MockConnection.self), + observabilityDelegate: NoOpConnectionPoolMetrics(connectionIDType: MockConnection.ID.self), + clock: ContinuousClock() + ) { + try await factory.makeConnection(id: $0, for: $1) + } + + let hasFinished = ManagedAtomic(false) + let createdConnections = ManagedAtomic(0) + let iterations = 10_000 + + // the same connection is reused 1000 times + + await withTaskGroup(of: Void.self) { taskGroup in + taskGroup.addTask_ { + await pool.run() + XCTAssertFalse(hasFinished.compareExchange(expected: false, desired: true, ordering: .relaxed).original) + } + + taskGroup.addTask_ { + var usedConnectionIDs = Set() + for _ in 0..() + let keepAliveDuration = Duration.seconds(30) + let keepAlive = MockPingPongBehavior(keepAliveFrequency: keepAliveDuration, connectionType: MockConnection.self) + + var mutableConfig = ConnectionPoolConfiguration() + mutableConfig.minimumConnectionCount = 0 + mutableConfig.maximumConnectionSoftLimit = 1 + mutableConfig.maximumConnectionHardLimit = 1 + let config = mutableConfig + + let pool = ConnectionPool( + configuration: config, + idGenerator: ConnectionIDGenerator(), + requestType: ConnectionRequest.self, + keepAliveBehavior: keepAlive, + observabilityDelegate: NoOpConnectionPoolMetrics(connectionIDType: MockConnection.ID.self), + clock: clock + ) { + try await factory.makeConnection(id: $0, for: $1) + } + + try await withThrowingTaskGroup(of: Void.self) { taskGroup in + taskGroup.addTask { + await pool.run() + } + + async let lease1ConnectionAsync = pool.leaseConnection() + + let connection = await factory.nextConnectAttempt { connectionID in + return 1 + } + + let lease1Connection = try await lease1ConnectionAsync + XCTAssert(connection === lease1Connection) + + pool.releaseConnection(lease1Connection) + + // keep alive 1 + + // validate that a keep alive timer and an idle timeout timer is scheduled + var expectedInstants: Set = [.init(keepAliveDuration), .init(config.idleTimeout)] + let deadline1 = await clock.nextTimerScheduled() + print(deadline1) + XCTAssertNotNil(expectedInstants.remove(deadline1)) + let deadline2 = await clock.nextTimerScheduled() + print(deadline2) + XCTAssertNotNil(expectedInstants.remove(deadline2)) + XCTAssert(expectedInstants.isEmpty) + + // move clock forward to keep alive + let newTime = clock.now.advanced(by: keepAliveDuration) + clock.advance(to: newTime) + print("clock advanced to: \(newTime)") + + await keepAlive.nextKeepAlive { keepAliveConnection in + defer { print("keep alive 1 has run") } + XCTAssertTrue(keepAliveConnection === lease1Connection) + return true + } + + // keep alive 2 + + let deadline3 = await clock.nextTimerScheduled() + XCTAssertEqual(deadline3, clock.now.advanced(by: keepAliveDuration)) + print(deadline3) + + // race keep alive vs timeout + clock.advance(to: clock.now.advanced(by: keepAliveDuration)) + + taskGroup.cancelAll() + + for connection in factory.runningConnections { + connection.closeIfClosing() + } + } + } + + func testKeepAliveOnClose() async throws { + let clock = MockClock() + let factory = MockConnectionFactory() + let keepAliveDuration = Duration.seconds(20) + let keepAlive = MockPingPongBehavior(keepAliveFrequency: keepAliveDuration, connectionType: MockConnection.self) + + var mutableConfig = ConnectionPoolConfiguration() + mutableConfig.minimumConnectionCount = 0 + mutableConfig.maximumConnectionSoftLimit = 1 + mutableConfig.maximumConnectionHardLimit = 1 + let config = mutableConfig + + let pool = ConnectionPool( + configuration: config, + idGenerator: ConnectionIDGenerator(), + requestType: ConnectionRequest.self, + keepAliveBehavior: keepAlive, + observabilityDelegate: NoOpConnectionPoolMetrics(connectionIDType: MockConnection.ID.self), + clock: clock + ) { + try await factory.makeConnection(id: $0, for: $1) + } + + try await withThrowingTaskGroup(of: Void.self) { taskGroup in + taskGroup.addTask { + await pool.run() + } + + async let lease1ConnectionAsync = pool.leaseConnection() + + let connection = await factory.nextConnectAttempt { connectionID in + return 1 + } + + let lease1Connection = try await lease1ConnectionAsync + XCTAssert(connection === lease1Connection) + + pool.releaseConnection(lease1Connection) + + // keep alive 1 + + // validate that a keep alive timer and an idle timeout timer is scheduled + var expectedInstants: Set = [.init(keepAliveDuration), .init(config.idleTimeout)] + let deadline1 = await clock.nextTimerScheduled() + print(deadline1) + XCTAssertNotNil(expectedInstants.remove(deadline1)) + let deadline2 = await clock.nextTimerScheduled() + print(deadline2) + XCTAssertNotNil(expectedInstants.remove(deadline2)) + XCTAssert(expectedInstants.isEmpty) + + // move clock forward to keep alive + let newTime = clock.now.advanced(by: keepAliveDuration) + clock.advance(to: newTime) + + await keepAlive.nextKeepAlive { keepAliveConnection in + XCTAssertTrue(keepAliveConnection === lease1Connection) + return true + } + + // keep alive 2 + let deadline3 = await clock.nextTimerScheduled() + XCTAssertEqual(deadline3, clock.now.advanced(by: keepAliveDuration)) + clock.advance(to: clock.now.advanced(by: keepAliveDuration)) + + let failingKeepAliveDidRun = ManagedAtomic(false) + // the following keep alive should not cause a crash + _ = try? await keepAlive.nextKeepAlive { keepAliveConnection in + defer { + XCTAssertFalse(failingKeepAliveDidRun + .compareExchange(expected: false, desired: true, ordering: .relaxed).original) + } + XCTAssertTrue(keepAliveConnection === lease1Connection) + keepAliveConnection.close() + throw CancellationError() // any error + } // will fail and it's expected + XCTAssertTrue(failingKeepAliveDidRun.load(ordering: .relaxed)) + + taskGroup.cancelAll() + + for connection in factory.runningConnections { + connection.closeIfClosing() + } + } + } + + func testKeepAliveWorksRacesAgainstShutdown() async throws { + let clock = MockClock() + let factory = MockConnectionFactory() + let keepAliveDuration = Duration.seconds(30) + let keepAlive = MockPingPongBehavior(keepAliveFrequency: keepAliveDuration, connectionType: MockConnection.self) + + var mutableConfig = ConnectionPoolConfiguration() + mutableConfig.minimumConnectionCount = 0 + mutableConfig.maximumConnectionSoftLimit = 1 + mutableConfig.maximumConnectionHardLimit = 1 + let config = mutableConfig + + let pool = ConnectionPool( + configuration: config, + idGenerator: ConnectionIDGenerator(), + requestType: ConnectionRequest.self, + keepAliveBehavior: keepAlive, + observabilityDelegate: NoOpConnectionPoolMetrics(connectionIDType: MockConnection.ID.self), + clock: clock + ) { + try await factory.makeConnection(id: $0, for: $1) + } + + try await withThrowingTaskGroup(of: Void.self) { taskGroup in + taskGroup.addTask { + await pool.run() + } + + async let lease1ConnectionAsync = pool.leaseConnection() + + let connection = await factory.nextConnectAttempt { connectionID in + return 1 + } + + let lease1Connection = try await lease1ConnectionAsync + XCTAssert(connection === lease1Connection) + + pool.releaseConnection(lease1Connection) + + // keep alive 1 + + // validate that a keep alive timer and an idle timeout timer is scheduled + var expectedInstants: Set = [.init(keepAliveDuration), .init(config.idleTimeout)] + let deadline1 = await clock.nextTimerScheduled() + print(deadline1) + XCTAssertNotNil(expectedInstants.remove(deadline1)) + let deadline2 = await clock.nextTimerScheduled() + print(deadline2) + XCTAssertNotNil(expectedInstants.remove(deadline2)) + XCTAssert(expectedInstants.isEmpty) + + clock.advance(to: clock.now.advanced(by: keepAliveDuration)) + + await keepAlive.nextKeepAlive { keepAliveConnection in + defer { print("keep alive 1 has run") } + XCTAssertTrue(keepAliveConnection === lease1Connection) + return true + } + + taskGroup.cancelAll() + print("cancelled") + + for connection in factory.runningConnections { + connection.closeIfClosing() + } + } + } + + func testCancelConnectionRequestWorks() async throws { + let clock = MockClock() + let factory = MockConnectionFactory() + let keepAliveDuration = Duration.seconds(30) + let keepAlive = MockPingPongBehavior(keepAliveFrequency: keepAliveDuration, connectionType: MockConnection.self) + + var mutableConfig = ConnectionPoolConfiguration() + mutableConfig.minimumConnectionCount = 0 + mutableConfig.maximumConnectionSoftLimit = 4 + mutableConfig.maximumConnectionHardLimit = 4 + mutableConfig.idleTimeout = .seconds(10) + let config = mutableConfig + + let pool = ConnectionPool( + configuration: config, + idGenerator: ConnectionIDGenerator(), + requestType: ConnectionRequest.self, + keepAliveBehavior: keepAlive, + observabilityDelegate: NoOpConnectionPoolMetrics(connectionIDType: MockConnection.ID.self), + clock: clock + ) { + try await factory.makeConnection(id: $0, for: $1) + } + + try await withThrowingTaskGroup(of: Void.self) { taskGroup in + taskGroup.addTask { + await pool.run() + } + + let leaseTask = Task { + _ = try await pool.leaseConnection() + } + + let connectionAttemptWaiter = Future(of: Void.self) + + taskGroup.addTask { + try await factory.nextConnectAttempt { connectionID in + connectionAttemptWaiter.yield(value: ()) + throw CancellationError() + } + } + + try await connectionAttemptWaiter.success + leaseTask.cancel() + + let taskResult = await leaseTask.result + switch taskResult { + case .success: + XCTFail("Expected task failure") + case .failure(let failure): + XCTAssertEqual(failure as? ConnectionPoolError, .requestCancelled) + } + + taskGroup.cancelAll() + for connection in factory.runningConnections { + connection.closeIfClosing() + } + } + } + + func testLeasingMultipleConnectionsAtOnceWorks() async throws { + let clock = MockClock() + let factory = MockConnectionFactory() + let keepAliveDuration = Duration.seconds(30) + let keepAlive = MockPingPongBehavior(keepAliveFrequency: keepAliveDuration, connectionType: MockConnection.self) + + var mutableConfig = ConnectionPoolConfiguration() + mutableConfig.minimumConnectionCount = 4 + mutableConfig.maximumConnectionSoftLimit = 4 + mutableConfig.maximumConnectionHardLimit = 4 + mutableConfig.idleTimeout = .seconds(10) + let config = mutableConfig + + let pool = ConnectionPool( + configuration: config, + idGenerator: ConnectionIDGenerator(), + requestType: ConnectionFuture.self, + keepAliveBehavior: keepAlive, + observabilityDelegate: NoOpConnectionPoolMetrics(connectionIDType: MockConnection.ID.self), + clock: clock + ) { + try await factory.makeConnection(id: $0, for: $1) + } + + try await withThrowingTaskGroup(of: Void.self) { taskGroup in + taskGroup.addTask { + await pool.run() + } + + // create 4 persisted connections + for _ in 0..<4 { + await factory.nextConnectAttempt { connectionID in + return 1 + } + } + + // create 4 connection requests + let requests = (0..<4).map { ConnectionFuture(id: $0) } + + // lease 4 connections at once + pool.leaseConnections(requests) + var connections = [MockConnection]() + + for request in requests { + let connection = try await request.future.success + connections.append(connection) + } + + // Ensure that we got 4 distinct connections + XCTAssertEqual(Set(connections.lazy.map(\.id)).count, 4) + + // release all 4 leased connections + for connection in connections { + pool.releaseConnection(connection) + } + + // shutdown + taskGroup.cancelAll() + for connection in factory.runningConnections { + connection.closeIfClosing() + } + } + } + + func testLeasingConnectionAfterShutdownIsInvokedFails() async throws { + let clock = MockClock() + let factory = MockConnectionFactory() + let keepAliveDuration = Duration.seconds(30) + let keepAlive = MockPingPongBehavior(keepAliveFrequency: keepAliveDuration, connectionType: MockConnection.self) + + var mutableConfig = ConnectionPoolConfiguration() + mutableConfig.minimumConnectionCount = 4 + mutableConfig.maximumConnectionSoftLimit = 4 + mutableConfig.maximumConnectionHardLimit = 4 + mutableConfig.idleTimeout = .seconds(10) + let config = mutableConfig + + let pool = ConnectionPool( + configuration: config, + idGenerator: ConnectionIDGenerator(), + requestType: ConnectionRequest.self, + keepAliveBehavior: keepAlive, + observabilityDelegate: NoOpConnectionPoolMetrics(connectionIDType: MockConnection.ID.self), + clock: clock + ) { + try await factory.makeConnection(id: $0, for: $1) + } + + try await withThrowingTaskGroup(of: Void.self) { taskGroup in + taskGroup.addTask { + await pool.run() + } + + // create 4 persisted connections + for _ in 0..<4 { + await factory.nextConnectAttempt { connectionID in + return 1 + } + } + + // shutdown + taskGroup.cancelAll() + + do { + _ = try await pool.leaseConnection() + XCTFail("Expected a failure") + } catch { + print("failed") + XCTAssertEqual(error as? ConnectionPoolError, .poolShutdown) + } + + print("will close connections: \(factory.runningConnections)") + for connection in factory.runningConnections { + try await connection.signalToClose + connection.closeIfClosing() + } + } + } + + func testLeasingConnectionsAfterShutdownIsInvokedFails() async throws { + let clock = MockClock() + let factory = MockConnectionFactory() + let keepAliveDuration = Duration.seconds(30) + let keepAlive = MockPingPongBehavior(keepAliveFrequency: keepAliveDuration, connectionType: MockConnection.self) + + var mutableConfig = ConnectionPoolConfiguration() + mutableConfig.minimumConnectionCount = 4 + mutableConfig.maximumConnectionSoftLimit = 4 + mutableConfig.maximumConnectionHardLimit = 4 + mutableConfig.idleTimeout = .seconds(10) + let config = mutableConfig + + let pool = ConnectionPool( + configuration: config, + idGenerator: ConnectionIDGenerator(), + requestType: ConnectionFuture.self, + keepAliveBehavior: keepAlive, + observabilityDelegate: NoOpConnectionPoolMetrics(connectionIDType: MockConnection.ID.self), + clock: clock + ) { + try await factory.makeConnection(id: $0, for: $1) + } + + try await withThrowingTaskGroup(of: Void.self) { taskGroup in + taskGroup.addTask { + await pool.run() + } + + // create 4 persisted connections + for _ in 0..<4 { + await factory.nextConnectAttempt { connectionID in + return 1 + } + } + + // shutdown + taskGroup.cancelAll() + + // create 4 connection requests + let requests = (0..<4).map { ConnectionFuture(id: $0) } + + // lease 4 connections at once + pool.leaseConnections(requests) + + for request in requests { + do { + _ = try await request.future.success + XCTFail("Expected a failure") + } catch { + XCTAssertEqual(error as? ConnectionPoolError, .poolShutdown) + } + } + + for connection in factory.runningConnections { + try await connection.signalToClose + connection.closeIfClosing() + } + } + } + + func testLeasingMultipleStreamsFromOneConnectionWorks() async throws { + let clock = MockClock() + let factory = MockConnectionFactory() + let keepAliveDuration = Duration.seconds(30) + let keepAlive = MockPingPongBehavior(keepAliveFrequency: keepAliveDuration, connectionType: MockConnection.self) + + var mutableConfig = ConnectionPoolConfiguration() + mutableConfig.minimumConnectionCount = 0 + mutableConfig.maximumConnectionSoftLimit = 1 + mutableConfig.maximumConnectionHardLimit = 10 + mutableConfig.idleTimeout = .seconds(10) + let config = mutableConfig + + let pool = ConnectionPool( + configuration: config, + idGenerator: ConnectionIDGenerator(), + requestType: ConnectionFuture.self, + keepAliveBehavior: keepAlive, + observabilityDelegate: NoOpConnectionPoolMetrics(connectionIDType: MockConnection.ID.self), + clock: clock + ) { + try await factory.makeConnection(id: $0, for: $1) + } + + try await withThrowingTaskGroup(of: Void.self) { taskGroup in + taskGroup.addTask { + await pool.run() + } + + // create 4 connection requests + let requests = (0..<10).map { ConnectionFuture(id: $0) } + pool.leaseConnections(requests) + var connections = [MockConnection]() + + await factory.nextConnectAttempt { connectionID in + return 10 + } + + for request in requests { + let connection = try await request.future.success + connections.append(connection) + } + + // Ensure that all requests got the same connection + XCTAssertEqual(Set(connections.lazy.map(\.id)).count, 1) + + // release all 10 leased streams + for connection in connections { + pool.releaseConnection(connection) + } + + for _ in 0..<9 { + _ = try? await factory.nextConnectAttempt { connectionID in + throw CancellationError() + } + } + + // shutdown + taskGroup.cancelAll() + for connection in factory.runningConnections { + connection.closeIfClosing() + } + } + } + + func testIncreasingAvailableStreamsWorks() async throws { + let clock = MockClock() + let factory = MockConnectionFactory() + let keepAliveDuration = Duration.seconds(30) + let keepAlive = MockPingPongBehavior(keepAliveFrequency: keepAliveDuration, connectionType: MockConnection.self) + + var mutableConfig = ConnectionPoolConfiguration() + mutableConfig.minimumConnectionCount = 0 + mutableConfig.maximumConnectionSoftLimit = 1 + mutableConfig.maximumConnectionHardLimit = 1 + mutableConfig.idleTimeout = .seconds(10) + let config = mutableConfig + + let pool = ConnectionPool( + configuration: config, + idGenerator: ConnectionIDGenerator(), + requestType: ConnectionFuture.self, + keepAliveBehavior: keepAlive, + observabilityDelegate: NoOpConnectionPoolMetrics(connectionIDType: MockConnection.ID.self), + clock: clock + ) { + try await factory.makeConnection(id: $0, for: $1) + } + + try await withThrowingTaskGroup(of: Void.self) { taskGroup in + taskGroup.addTask { + await pool.run() + } + + // create 4 connection requests + var requests = (0..<21).map { ConnectionFuture(id: $0) } + pool.leaseConnections(requests) + var connections = [MockConnection]() + + await factory.nextConnectAttempt { connectionID in + return 1 + } + + let connection = try await requests.first!.future.success + connections.append(connection) + requests.removeFirst() + + pool.connectionReceivedNewMaxStreamSetting(connection, newMaxStreamSetting: 21) + + for (_, request) in requests.enumerated() { + let connection = try await request.future.success + connections.append(connection) + } + + // Ensure that all requests got the same connection + XCTAssertEqual(Set(connections.lazy.map(\.id)).count, 1) + + requests = (22..<42).map { ConnectionFuture(id: $0) } + pool.leaseConnections(requests) + + // release all 21 leased streams in a single call + pool.releaseConnection(connection, streams: 21) + + // ensure all 20 new requests got fulfilled + for request in requests { + let connection = try await request.future.success + connections.append(connection) + } + + // release all 20 leased streams one by one + for _ in requests { + pool.releaseConnection(connection, streams: 1) + } + + // shutdown + taskGroup.cancelAll() + for connection in factory.runningConnections { + connection.closeIfClosing() + } + } + } +} + +struct ConnectionFuture: ConnectionRequestProtocol { + let id: Int + let future: Future + + init(id: Int) { + self.id = id + self.future = Future(of: MockConnection.self) + } + + func complete(with result: Result) { + switch result { + case .success(let success): + self.future.yield(value: success) + case .failure(let failure): + self.future.yield(error: failure) + } + } +} diff --git a/Tests/ConnectionPoolModuleTests/ConnectionRequestTests.swift b/Tests/ConnectionPoolModuleTests/ConnectionRequestTests.swift new file mode 100644 index 00000000..537efbd9 --- /dev/null +++ b/Tests/ConnectionPoolModuleTests/ConnectionRequestTests.swift @@ -0,0 +1,28 @@ +@testable import _ConnectionPoolModule +import _ConnectionPoolTestUtils +import XCTest + +final class ConnectionRequestTests: XCTestCase { + + func testHappyPath() async throws { + let mockConnection = MockConnection(id: 1) + let connection = try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in + let request = ConnectionRequest(id: 42, continuation: continuation) + XCTAssertEqual(request.id, 42) + continuation.resume(with: .success(mockConnection)) + } + + XCTAssert(connection === mockConnection) + } + + func testSadPath() async throws { + do { + _ = try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in + continuation.resume(with: .failure(ConnectionPoolError.requestCancelled)) + } + XCTFail("This point should not be reached") + } catch { + XCTAssertEqual(error as? ConnectionPoolError, .requestCancelled) + } + } +} diff --git a/Tests/ConnectionPoolModuleTests/Max2SequenceTests.swift b/Tests/ConnectionPoolModuleTests/Max2SequenceTests.swift new file mode 100644 index 00000000..081e867b --- /dev/null +++ b/Tests/ConnectionPoolModuleTests/Max2SequenceTests.swift @@ -0,0 +1,60 @@ +@testable import _ConnectionPoolModule +import XCTest + +final class Max2SequenceTests: XCTestCase { + func testCountAndIsEmpty() async { + var sequence = Max2Sequence() + XCTAssertEqual(sequence.count, 0) + XCTAssertEqual(sequence.isEmpty, true) + sequence.append(1) + XCTAssertEqual(sequence.count, 1) + XCTAssertEqual(sequence.isEmpty, false) + sequence.append(2) + XCTAssertEqual(sequence.count, 2) + XCTAssertEqual(sequence.isEmpty, false) + } + + func testOptionalInitializer() { + let emptySequence = Max2Sequence(nil, nil) + XCTAssertEqual(emptySequence.count, 0) + XCTAssertEqual(emptySequence.isEmpty, true) + var emptySequenceIterator = emptySequence.makeIterator() + XCTAssertNil(emptySequenceIterator.next()) + XCTAssertNil(emptySequenceIterator.next()) + XCTAssertNil(emptySequenceIterator.next()) + + let oneElemSequence1 = Max2Sequence(1, nil) + XCTAssertEqual(oneElemSequence1.count, 1) + XCTAssertEqual(oneElemSequence1.isEmpty, false) + var oneElemSequence1Iterator = oneElemSequence1.makeIterator() + XCTAssertEqual(oneElemSequence1Iterator.next(), 1) + XCTAssertNil(oneElemSequence1Iterator.next()) + XCTAssertNil(oneElemSequence1Iterator.next()) + + let oneElemSequence2 = Max2Sequence(nil, 2) + XCTAssertEqual(oneElemSequence2.count, 1) + XCTAssertEqual(oneElemSequence2.isEmpty, false) + var oneElemSequence2Iterator = oneElemSequence2.makeIterator() + XCTAssertEqual(oneElemSequence2Iterator.next(), 2) + XCTAssertNil(oneElemSequence2Iterator.next()) + XCTAssertNil(oneElemSequence2Iterator.next()) + + let twoElemSequence = Max2Sequence(1, 2) + XCTAssertEqual(twoElemSequence.count, 2) + XCTAssertEqual(twoElemSequence.isEmpty, false) + var twoElemSequenceIterator = twoElemSequence.makeIterator() + XCTAssertEqual(twoElemSequenceIterator.next(), 1) + XCTAssertEqual(twoElemSequenceIterator.next(), 2) + XCTAssertNil(twoElemSequenceIterator.next()) + } + + func testMap() { + let twoElemSequence = Max2Sequence(1, 2).map({ "\($0)" }) + XCTAssertEqual(twoElemSequence.count, 2) + XCTAssertEqual(twoElemSequence.isEmpty, false) + var twoElemSequenceIterator = twoElemSequence.makeIterator() + XCTAssertEqual(twoElemSequenceIterator.next(), "1") + XCTAssertEqual(twoElemSequenceIterator.next(), "2") + XCTAssertNil(twoElemSequenceIterator.next()) + } +} diff --git a/Tests/ConnectionPoolModuleTests/Mocks/MockTimerCancellationToken.swift b/Tests/ConnectionPoolModuleTests/Mocks/MockTimerCancellationToken.swift new file mode 100644 index 00000000..27035ee9 --- /dev/null +++ b/Tests/ConnectionPoolModuleTests/Mocks/MockTimerCancellationToken.swift @@ -0,0 +1,18 @@ +@testable import _ConnectionPoolModule + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +struct MockTimerCancellationToken: Hashable, Sendable { + enum Backing: Hashable, Sendable { + case timer(TestPoolStateMachine.Timer) + case connectionTimer(TestPoolStateMachine.ConnectionTimer) + } + var backing: Backing + + init(_ timer: TestPoolStateMachine.Timer) { + self.backing = .timer(timer) + } + + init(_ timer: TestPoolStateMachine.ConnectionTimer) { + self.backing = .connectionTimer(timer) + } +} diff --git a/Tests/ConnectionPoolModuleTests/NoKeepAliveBehaviorTests.swift b/Tests/ConnectionPoolModuleTests/NoKeepAliveBehaviorTests.swift new file mode 100644 index 00000000..b1b54023 --- /dev/null +++ b/Tests/ConnectionPoolModuleTests/NoKeepAliveBehaviorTests.swift @@ -0,0 +1,11 @@ +import _ConnectionPoolModule +import _ConnectionPoolTestUtils +import XCTest + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +final class NoKeepAliveBehaviorTests: XCTestCase { + func testNoKeepAlive() { + let keepAliveBehavior = NoOpKeepAliveBehavior(connectionType: MockConnection.self) + XCTAssertNil(keepAliveBehavior.keepAliveFrequency) + } +} diff --git a/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionGroupTests.swift b/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionGroupTests.swift new file mode 100644 index 00000000..b09bfcb4 --- /dev/null +++ b/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionGroupTests.swift @@ -0,0 +1,328 @@ +@testable import _ConnectionPoolModule +import _ConnectionPoolTestUtils +import XCTest + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +final class PoolStateMachine_ConnectionGroupTests: XCTestCase { + var idGenerator: ConnectionIDGenerator! + + override func setUp() { + self.idGenerator = ConnectionIDGenerator() + super.setUp() + } + + override func tearDown() { + self.idGenerator = nil + super.tearDown() + } + + func testRefillConnections() { + var connections = TestPoolStateMachine.ConnectionGroup( + generator: self.idGenerator, + minimumConcurrentConnections: 4, + maximumConcurrentConnectionSoftLimit: 4, + maximumConcurrentConnectionHardLimit: 4, + keepAlive: true, + keepAliveReducesAvailableStreams: true + ) + + XCTAssertTrue(connections.isEmpty) + let requests = connections.refillConnections() + XCTAssertFalse(connections.isEmpty) + + XCTAssertEqual(requests.count, 4) + XCTAssertNil(connections.createNewDemandConnectionIfPossible()) + XCTAssertNil(connections.createNewOverflowConnectionIfPossible()) + XCTAssertEqual(connections.stats, .init(connecting: 4)) + XCTAssertEqual(connections.soonAvailableConnections, 4) + + let requests2 = connections.refillConnections() + XCTAssertTrue(requests2.isEmpty) + + var connected: UInt16 = 0 + for request in requests { + let newConnection = MockConnection(id: request.connectionID) + let (_, context) = connections.newConnectionEstablished(newConnection, maxStreams: 1) + XCTAssertEqual(context.info, .idle(availableStreams: 1, newIdle: true)) + XCTAssertEqual(context.use, .persisted) + connected += 1 + XCTAssertEqual(connections.stats, .init(connecting: 4 - connected, idle: connected, availableStreams: connected)) + XCTAssertEqual(connections.soonAvailableConnections, 4 - connected) + } + + let requests3 = connections.refillConnections() + XCTAssertTrue(requests3.isEmpty) + } + + func testMakeConnectionLeaseItAndDropItHappyPath() { + var connections = TestPoolStateMachine.ConnectionGroup( + generator: self.idGenerator, + minimumConcurrentConnections: 0, + maximumConcurrentConnectionSoftLimit: 4, + maximumConcurrentConnectionHardLimit: 4, + keepAlive: true, + keepAliveReducesAvailableStreams: true + ) + + let requests = connections.refillConnections() + XCTAssertTrue(connections.isEmpty) + XCTAssertTrue(requests.isEmpty) + + guard let request = connections.createNewDemandConnectionIfPossible() else { + return XCTFail("Expected to receive a connection request") + } + XCTAssertEqual(request, .init(connectionID: 0)) + XCTAssertFalse(connections.isEmpty) + XCTAssertEqual(connections.soonAvailableConnections, 1) + XCTAssertEqual(connections.stats, .init(connecting: 1)) + + let newConnection = MockConnection(id: request.connectionID) + let (_, establishedContext) = connections.newConnectionEstablished(newConnection, maxStreams: 1) + XCTAssertEqual(establishedContext.info, .idle(availableStreams: 1, newIdle: true)) + XCTAssertEqual(establishedContext.use, .demand) + XCTAssertEqual(connections.stats, .init(idle: 1, availableStreams: 1)) + XCTAssertEqual(connections.soonAvailableConnections, 0) + + guard case .leasedConnection(let leaseResult) = connections.leaseConnectionOrSoonAvailableConnectionCount() else { + return XCTFail("Expected to lease a connection") + } + XCTAssert(newConnection === leaseResult.connection) + XCTAssertEqual(connections.stats, .init(leased: 1, leasedStreams: 1)) + + guard let (index, releasedContext) = connections.releaseConnection(leaseResult.connection.id, streams: 1) else { + return XCTFail("Expected that this connection is still active") + } + XCTAssertEqual(releasedContext.info, .idle(availableStreams: 1, newIdle: true)) + XCTAssertEqual(releasedContext.use, .demand) + XCTAssertEqual(connections.stats, .init(idle: 1, availableStreams: 1)) + + let parkTimers = connections.parkConnection(at: index, hasBecomeIdle: true) + XCTAssertEqual(parkTimers, [ + .init(timerID: 0, connectionID: newConnection.id, usecase: .keepAlive), + .init(timerID: 1, connectionID: newConnection.id, usecase: .idleTimeout), + ]) + + guard let keepAliveAction = connections.keepAliveIfIdle(newConnection.id) else { + return XCTFail("Expected to get a connection for ping pong") + } + XCTAssert(newConnection === keepAliveAction.connection) + XCTAssertEqual(connections.stats, .init(idle: 1, runningKeepAlive: 1, availableStreams: 0)) + + guard let (_, pingPongContext) = connections.keepAliveSucceeded(newConnection.id) else { + return XCTFail("Expected to get an AvailableContext") + } + XCTAssertEqual(pingPongContext.info, .idle(availableStreams: 1, newIdle: false)) + XCTAssertEqual(releasedContext.use, .demand) + XCTAssertEqual(connections.stats, .init(idle: 1, availableStreams: 1)) + + guard let closeAction = connections.closeConnectionIfIdle(newConnection.id) else { + return XCTFail("Expected to get a connection for ping pong") + } + XCTAssertEqual(closeAction.timersToCancel, []) + XCTAssert(closeAction.connection === newConnection) + XCTAssertEqual(connections.stats, .init(closing: 1, availableStreams: 0)) + + let closeContext = connections.connectionClosed(newConnection.id) + XCTAssertEqual(closeContext.connectionsStarting, 0) + XCTAssertTrue(connections.isEmpty) + XCTAssertEqual(connections.stats, .init()) + } + + func testBackoffDoneCreatesANewConnectionToReachMinimumConnectionsEvenThoughRetryIsSetToFalse() { + var connections = TestPoolStateMachine.ConnectionGroup( + generator: self.idGenerator, + minimumConcurrentConnections: 1, + maximumConcurrentConnectionSoftLimit: 4, + maximumConcurrentConnectionHardLimit: 4, + keepAlive: true, + keepAliveReducesAvailableStreams: true + ) + + let requests = connections.refillConnections() + XCTAssertEqual(connections.stats, .init(connecting: 1)) + XCTAssertEqual(connections.soonAvailableConnections, 1) + XCTAssertFalse(connections.isEmpty) + XCTAssertEqual(requests.count, 1) + + guard let request = requests.first else { return XCTFail("Expected to receive a connection request") } + XCTAssertEqual(request, .init(connectionID: 0)) + + let backoffTimer = connections.backoffNextConnectionAttempt(request.connectionID) + XCTAssertEqual(connections.stats, .init(backingOff: 1)) + let backoffTimerCancellationToken = MockTimerCancellationToken(backoffTimer) + XCTAssertNil(connections.timerScheduled(backoffTimer, cancelContinuation: backoffTimerCancellationToken)) + + let backoffDoneAction = connections.backoffDone(request.connectionID, retry: false) + XCTAssertEqual(backoffDoneAction, .createConnection(.init(connectionID: 0), backoffTimerCancellationToken)) + + XCTAssertEqual(connections.stats, .init(connecting: 1)) + } + + func testBackoffDoneCancelsIdleTimerIfAPersistedConnectionIsNotRetried() { + var connections = TestPoolStateMachine.ConnectionGroup( + generator: self.idGenerator, + minimumConcurrentConnections: 2, + maximumConcurrentConnectionSoftLimit: 4, + maximumConcurrentConnectionHardLimit: 4, + keepAlive: true, + keepAliveReducesAvailableStreams: true + ) + + let requests = connections.refillConnections() + XCTAssertEqual(connections.stats, .init(connecting: 2)) + XCTAssertEqual(connections.soonAvailableConnections, 2) + XCTAssertFalse(connections.isEmpty) + XCTAssertEqual(requests.count, 2) + + var requestIterator = requests.makeIterator() + guard let firstRequest = requestIterator.next(), let secondRequest = requestIterator.next() else { + return XCTFail("Expected to get two requests") + } + + guard let thirdRequest = connections.createNewDemandConnectionIfPossible() else { + return XCTFail("Expected to get another request") + } + XCTAssertEqual(connections.stats, .init(connecting: 3)) + + let newSecondConnection = MockConnection(id: secondRequest.connectionID) + let (_, establishedSecondConnectionContext) = connections.newConnectionEstablished(newSecondConnection, maxStreams: 1) + XCTAssertEqual(establishedSecondConnectionContext.info, .idle(availableStreams: 1, newIdle: true)) + XCTAssertEqual(establishedSecondConnectionContext.use, .persisted) + XCTAssertEqual(connections.stats, .init(connecting: 2, idle: 1, availableStreams: 1)) + XCTAssertEqual(connections.soonAvailableConnections, 2) + + let newThirdConnection = MockConnection(id: thirdRequest.connectionID) + let (thirdConnectionIndex, establishedThirdConnectionContext) = connections.newConnectionEstablished(newThirdConnection, maxStreams: 1) + XCTAssertEqual(establishedThirdConnectionContext.info, .idle(availableStreams: 1, newIdle: true)) + XCTAssertEqual(establishedThirdConnectionContext.use, .demand) + XCTAssertEqual(connections.stats, .init(connecting: 1, idle: 2, availableStreams: 2)) + XCTAssertEqual(connections.soonAvailableConnections, 1) + let thirdConnKeepTimer = TestPoolStateMachine.ConnectionTimer(timerID: 0, connectionID: thirdRequest.connectionID, usecase: .keepAlive) + let thirdConnIdleTimer = TestPoolStateMachine.ConnectionTimer(timerID: 1, connectionID: thirdRequest.connectionID, usecase: .idleTimeout) + let thirdConnIdleTimerCancellationToken = MockTimerCancellationToken(thirdConnIdleTimer) + XCTAssertEqual(connections.parkConnection(at: thirdConnectionIndex, hasBecomeIdle: true), [thirdConnKeepTimer, thirdConnIdleTimer]) + + XCTAssertNil(connections.timerScheduled(thirdConnKeepTimer, cancelContinuation: .init(thirdConnKeepTimer))) + XCTAssertNil(connections.timerScheduled(thirdConnIdleTimer, cancelContinuation: thirdConnIdleTimerCancellationToken)) + + let backoffTimer = connections.backoffNextConnectionAttempt(firstRequest.connectionID) + XCTAssertEqual(connections.stats, .init(backingOff: 1, idle: 2, availableStreams: 2)) + let backoffTimerCancellationToken = MockTimerCancellationToken(backoffTimer) + XCTAssertNil(connections.timerScheduled(backoffTimer, cancelContinuation: backoffTimerCancellationToken)) + XCTAssertEqual(connections.stats, .init(backingOff: 1, idle: 2, availableStreams: 2)) + + // connection three should be moved to connection one and for this reason become permanent + + XCTAssertEqual(connections.backoffDone(firstRequest.connectionID, retry: false), .cancelTimers([backoffTimerCancellationToken, thirdConnIdleTimerCancellationToken])) + XCTAssertEqual(connections.stats, .init(idle: 2, availableStreams: 2)) + + XCTAssertNil(connections.closeConnectionIfIdle(newThirdConnection.id)) + } + + func testBackoffDoneReturnsNilIfOverflowConnection() { + var connections = TestPoolStateMachine.ConnectionGroup( + generator: self.idGenerator, + minimumConcurrentConnections: 0, + maximumConcurrentConnectionSoftLimit: 4, + maximumConcurrentConnectionHardLimit: 4, + keepAlive: true, + keepAliveReducesAvailableStreams: true + ) + + guard let firstRequest = connections.createNewDemandConnectionIfPossible() else { + return XCTFail("Expected to get two requests") + } + + guard let secondRequest = connections.createNewDemandConnectionIfPossible() else { + return XCTFail("Expected to get another request") + } + XCTAssertEqual(connections.stats, .init(connecting: 2)) + + let newFirstConnection = MockConnection(id: firstRequest.connectionID) + let (_, establishedFirstConnectionContext) = connections.newConnectionEstablished(newFirstConnection, maxStreams: 1) + XCTAssertEqual(establishedFirstConnectionContext.info, .idle(availableStreams: 1, newIdle: true)) + XCTAssertEqual(establishedFirstConnectionContext.use, .demand) + XCTAssertEqual(connections.stats, .init(connecting: 1, idle: 1, availableStreams: 1)) + XCTAssertEqual(connections.soonAvailableConnections, 1) + + let backoffTimer = connections.backoffNextConnectionAttempt(secondRequest.connectionID) + let backoffTimerCancellationToken = MockTimerCancellationToken(backoffTimer) + XCTAssertEqual(connections.stats, .init(backingOff: 1, idle: 1, availableStreams: 1)) + XCTAssertNil(connections.timerScheduled(backoffTimer, cancelContinuation: backoffTimerCancellationToken)) + + XCTAssertEqual(connections.backoffDone(secondRequest.connectionID, retry: false), .cancelTimers([backoffTimerCancellationToken])) + XCTAssertEqual(connections.stats, .init(idle: 1, availableStreams: 1)) + + XCTAssertNotNil(connections.closeConnectionIfIdle(newFirstConnection.id)) + } + + func testPingPong() { + var connections = TestPoolStateMachine.ConnectionGroup( + generator: self.idGenerator, + minimumConcurrentConnections: 1, + maximumConcurrentConnectionSoftLimit: 4, + maximumConcurrentConnectionHardLimit: 4, + keepAlive: true, + keepAliveReducesAvailableStreams: true + ) + + let requests = connections.refillConnections() + XCTAssertFalse(connections.isEmpty) + XCTAssertEqual(connections.stats, .init(connecting: 1)) + + XCTAssertEqual(requests.count, 1) + guard let firstRequest = requests.first else { return XCTFail("Expected to have a request here") } + + let newConnection = MockConnection(id: firstRequest.connectionID) + let (connectionIndex, establishedConnectionContext) = connections.newConnectionEstablished(newConnection, maxStreams: 1) + XCTAssertEqual(establishedConnectionContext.info, .idle(availableStreams: 1, newIdle: true)) + XCTAssertEqual(establishedConnectionContext.use, .persisted) + XCTAssertEqual(connections.stats, .init(idle: 1, availableStreams: 1)) + let timers = connections.parkConnection(at: connectionIndex, hasBecomeIdle: true) + let keepAliveTimer = TestPoolStateMachine.ConnectionTimer(timerID: 0, connectionID: firstRequest.connectionID, usecase: .keepAlive) + let keepAliveTimerCancellationToken = MockTimerCancellationToken(keepAliveTimer) + XCTAssertEqual(timers, [keepAliveTimer]) + XCTAssertNil(connections.timerScheduled(keepAliveTimer, cancelContinuation: keepAliveTimerCancellationToken)) + let keepAliveAction = connections.keepAliveIfIdle(newConnection.id) + XCTAssertEqual(keepAliveAction, .init(connection: newConnection, keepAliveTimerCancellationContinuation: keepAliveTimerCancellationToken)) + XCTAssertEqual(connections.stats, .init(idle: 1, runningKeepAlive: 1, availableStreams: 0)) + + guard let (_, afterPingIdleContext) = connections.keepAliveSucceeded(newConnection.id) else { + return XCTFail("Expected to receive an AvailableContext") + } + XCTAssertEqual(afterPingIdleContext.info, .idle(availableStreams: 1, newIdle: false)) + XCTAssertEqual(afterPingIdleContext.use, .persisted) + XCTAssertEqual(connections.stats, .init(idle: 1, availableStreams: 1)) + } + + func testKeepAliveShouldNotIndicateCloseConnectionAfterClosed() { + var connections = TestPoolStateMachine.ConnectionGroup( + generator: self.idGenerator, + minimumConcurrentConnections: 0, + maximumConcurrentConnectionSoftLimit: 2, + maximumConcurrentConnectionHardLimit: 2, + keepAlive: true, + keepAliveReducesAvailableStreams: true + ) + + guard let firstRequest = connections.createNewDemandConnectionIfPossible() else { return XCTFail("Expected to have a request here") } + + let newConnection = MockConnection(id: firstRequest.connectionID) + let (connectionIndex, establishedConnectionContext) = connections.newConnectionEstablished(newConnection, maxStreams: 1) + XCTAssertEqual(establishedConnectionContext.info, .idle(availableStreams: 1, newIdle: true)) + XCTAssertEqual(connections.stats, .init(idle: 1, availableStreams: 1)) + _ = connections.parkConnection(at: connectionIndex, hasBecomeIdle: true) + let keepAliveTimer = TestPoolStateMachine.ConnectionTimer(timerID: 0, connectionID: firstRequest.connectionID, usecase: .keepAlive) + let keepAliveTimerCancellationToken = MockTimerCancellationToken(keepAliveTimer) + XCTAssertNil(connections.timerScheduled(keepAliveTimer, cancelContinuation: keepAliveTimerCancellationToken)) + let keepAliveAction = connections.keepAliveIfIdle(newConnection.id) + XCTAssertEqual(keepAliveAction, .init(connection: newConnection, keepAliveTimerCancellationContinuation: keepAliveTimerCancellationToken)) + XCTAssertEqual(connections.stats, .init(idle: 1, runningKeepAlive: 1, availableStreams: 0)) + + _ = connections.closeConnectionIfIdle(newConnection.id) + guard connections.keepAliveFailed(newConnection.id) == nil else { + return XCTFail("Expected keepAliveFailed not to cause close again") + } + XCTAssertEqual(connections.stats, .init(closing: 1)) + } +} diff --git a/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionStateTests.swift b/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionStateTests.swift new file mode 100644 index 00000000..7dd2b726 --- /dev/null +++ b/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionStateTests.swift @@ -0,0 +1,265 @@ +@testable import _ConnectionPoolModule +import _ConnectionPoolTestUtils +import XCTest + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +final class PoolStateMachine_ConnectionStateTests: XCTestCase { + + typealias TestConnectionState = TestPoolStateMachine.ConnectionState + + func testStartupLeaseReleaseParkLease() { + let connectionID = 1 + var state = TestConnectionState(id: connectionID) + XCTAssertEqual(state.id, connectionID) + XCTAssertEqual(state.isIdle, false) + XCTAssertEqual(state.isAvailable, false) + XCTAssertEqual(state.isConnected, false) + XCTAssertEqual(state.isLeased, false) + let connection = MockConnection(id: connectionID) + XCTAssertEqual(state.connected(connection, maxStreams: 1), .idle(availableStreams: 1, newIdle: true)) + XCTAssertEqual(state.isIdle, true) + XCTAssertEqual(state.isAvailable, true) + XCTAssertEqual(state.isConnected, true) + XCTAssertEqual(state.isLeased, false) + XCTAssertEqual(state.lease(streams: 1), .init(connection: connection, timersToCancel: .init(), wasIdle: true)) + + XCTAssertEqual(state.isIdle, false) + XCTAssertEqual(state.isAvailable, false) + XCTAssertEqual(state.isConnected, true) + XCTAssertEqual(state.isLeased, true) + + XCTAssertEqual(state.release(streams: 1), .idle(availableStreams: 1, newIdle: true)) + let parkResult = state.parkConnection(scheduleKeepAliveTimer: true, scheduleIdleTimeoutTimer: true) + XCTAssert( + parkResult.elementsEqual([ + .init(timerID: 0, connectionID: connectionID, usecase: .keepAlive), + .init(timerID: 1, connectionID: connectionID, usecase: .idleTimeout) + ]) + ) + + guard let keepAliveTimer = parkResult.first, let idleTimer = parkResult.second else { + return XCTFail("Expected to get two timers") + } + + let keepAliveTimerCancellationToken = MockTimerCancellationToken(keepAliveTimer) + let idleTimerCancellationToken = MockTimerCancellationToken(idleTimer) + + XCTAssertNil(state.timerScheduled(keepAliveTimer, cancelContinuation: keepAliveTimerCancellationToken)) + XCTAssertNil(state.timerScheduled(idleTimer, cancelContinuation: idleTimerCancellationToken)) + + let expectLeaseAction = TestConnectionState.LeaseAction( + connection: connection, + timersToCancel: [idleTimerCancellationToken, keepAliveTimerCancellationToken], + wasIdle: true + ) + XCTAssertEqual(state.lease(streams: 1), expectLeaseAction) + } + + func testStartupParkLeaseBeforeTimersRegistered() { + let connectionID = 1 + var state = TestConnectionState(id: connectionID) + let connection = MockConnection(id: connectionID) + XCTAssertEqual(state.connected(connection, maxStreams: 1), .idle(availableStreams: 1, newIdle: true)) + let parkResult = state.parkConnection(scheduleKeepAliveTimer: true, scheduleIdleTimeoutTimer: true) + XCTAssertEqual( + parkResult, + [ + .init(timerID: 0, connectionID: connectionID, usecase: .keepAlive), + .init(timerID: 1, connectionID: connectionID, usecase: .idleTimeout) + ] + ) + + guard let keepAliveTimer = parkResult.first, let idleTimer = parkResult.second else { + return XCTFail("Expected to get two timers") + } + + let keepAliveTimerCancellationToken = MockTimerCancellationToken(keepAliveTimer) + let idleTimerCancellationToken = MockTimerCancellationToken(idleTimer) + XCTAssertEqual(state.lease(streams: 1), .init(connection: connection, timersToCancel: .init(), wasIdle: true)) + + XCTAssertEqual(state.timerScheduled(keepAliveTimer, cancelContinuation: keepAliveTimerCancellationToken), keepAliveTimerCancellationToken) + XCTAssertEqual(state.timerScheduled(idleTimer, cancelContinuation: idleTimerCancellationToken), idleTimerCancellationToken) + } + + func testStartupParkLeasePark() { + let connectionID = 1 + var state = TestConnectionState(id: connectionID) + let connection = MockConnection(id: connectionID) + XCTAssertEqual(state.connected(connection, maxStreams: 1), .idle(availableStreams: 1, newIdle: true)) + let parkResult = state.parkConnection(scheduleKeepAliveTimer: true, scheduleIdleTimeoutTimer: true) + XCTAssert( + parkResult.elementsEqual([ + .init(timerID: 0, connectionID: connectionID, usecase: .keepAlive), + .init(timerID: 1, connectionID: connectionID, usecase: .idleTimeout) + ]) + ) + + guard let keepAliveTimer = parkResult.first, let idleTimer = parkResult.second else { + return XCTFail("Expected to get two timers") + } + + let initialKeepAliveTimerCancellationToken = MockTimerCancellationToken(keepAliveTimer) + let initialIdleTimerCancellationToken = MockTimerCancellationToken(idleTimer) + XCTAssertEqual(state.lease(streams: 1), .init(connection: connection, timersToCancel: .init(), wasIdle: true)) + + XCTAssertEqual(state.release(streams: 1), .idle(availableStreams: 1, newIdle: true)) + XCTAssertEqual( + state.parkConnection(scheduleKeepAliveTimer: true, scheduleIdleTimeoutTimer: true), + [ + .init(timerID: 2, connectionID: connectionID, usecase: .keepAlive), + .init(timerID: 3, connectionID: connectionID, usecase: .idleTimeout) + ] + ) + + XCTAssertEqual(state.timerScheduled(keepAliveTimer, cancelContinuation: initialKeepAliveTimerCancellationToken), initialKeepAliveTimerCancellationToken) + XCTAssertEqual(state.timerScheduled(idleTimer, cancelContinuation: initialIdleTimerCancellationToken), initialIdleTimerCancellationToken) + } + + func testStartupFailed() { + let connectionID = 1 + var state = TestConnectionState(id: connectionID) + let firstBackoffTimer = state.failedToConnect() + let firstBackoffTimerCancellationToken = MockTimerCancellationToken(firstBackoffTimer) + XCTAssertNil(state.timerScheduled(firstBackoffTimer, cancelContinuation: firstBackoffTimerCancellationToken)) + XCTAssertEqual(state.retryConnect(), firstBackoffTimerCancellationToken) + + let secondBackoffTimer = state.failedToConnect() + let secondBackoffTimerCancellationToken = MockTimerCancellationToken(secondBackoffTimer) + XCTAssertNil(state.retryConnect()) + XCTAssertEqual( + state.timerScheduled(secondBackoffTimer, cancelContinuation: secondBackoffTimerCancellationToken), + secondBackoffTimerCancellationToken + ) + + let thirdBackoffTimer = state.failedToConnect() + let thirdBackoffTimerCancellationToken = MockTimerCancellationToken(thirdBackoffTimer) + XCTAssertNil(state.retryConnect()) + let forthBackoffTimer = state.failedToConnect() + let forthBackoffTimerCancellationToken = MockTimerCancellationToken(forthBackoffTimer) + XCTAssertEqual( + state.timerScheduled(thirdBackoffTimer, cancelContinuation: thirdBackoffTimerCancellationToken), + thirdBackoffTimerCancellationToken + ) + XCTAssertNil( + state.timerScheduled(forthBackoffTimer, cancelContinuation: forthBackoffTimerCancellationToken) + ) + XCTAssertEqual(state.retryConnect(), forthBackoffTimerCancellationToken) + + let connection = MockConnection(id: connectionID) + XCTAssertEqual(state.connected(connection, maxStreams: 1), .idle(availableStreams: 1, newIdle: true)) + } + + func testLeaseMultipleStreams() { + let connectionID = 1 + var state = TestConnectionState(id: connectionID) + let connection = MockConnection(id: connectionID) + XCTAssertEqual(state.connected(connection, maxStreams: 100), .idle(availableStreams: 100, newIdle: true)) + let timers = state.parkConnection(scheduleKeepAliveTimer: true, scheduleIdleTimeoutTimer: false) + guard let keepAliveTimer = timers.first else { return XCTFail("Expected to get a keepAliveTimer") } + + let keepAliveTimerCancellationToken = MockTimerCancellationToken(keepAliveTimer) + XCTAssertNil(state.timerScheduled(keepAliveTimer, cancelContinuation: keepAliveTimerCancellationToken)) + + XCTAssertEqual( + state.lease(streams: 30), + TestConnectionState.LeaseAction(connection: connection, timersToCancel: [keepAliveTimerCancellationToken], wasIdle: true) + ) + + XCTAssertEqual(state.release(streams: 10), .leased(availableStreams: 80)) + + XCTAssertEqual( + state.lease(streams: 40), + TestConnectionState.LeaseAction(connection: connection, timersToCancel: [], wasIdle: false) + ) + + XCTAssertEqual( + state.lease(streams: 40), + TestConnectionState.LeaseAction(connection: connection, timersToCancel: [], wasIdle: false) + ) + + XCTAssertEqual(state.release(streams: 1), .leased(availableStreams: 1)) + XCTAssertEqual(state.release(streams: 98), .leased(availableStreams: 99)) + XCTAssertEqual(state.release(streams: 1), .idle(availableStreams: 100, newIdle: true)) + } + + func testRunningKeepAliveReducesAvailableStreams() { + let connectionID = 1 + var state = TestConnectionState(id: connectionID) + let connection = MockConnection(id: connectionID) + XCTAssertEqual(state.connected(connection, maxStreams: 100), .idle(availableStreams: 100, newIdle: true)) + let timers = state.parkConnection(scheduleKeepAliveTimer: true, scheduleIdleTimeoutTimer: false) + guard let keepAliveTimer = timers.first else { return XCTFail("Expected to get a keepAliveTimer") } + + let keepAliveTimerCancellationToken = MockTimerCancellationToken(keepAliveTimer) + XCTAssertNil(state.timerScheduled(keepAliveTimer, cancelContinuation: keepAliveTimerCancellationToken)) + + XCTAssertEqual( + state.runKeepAliveIfIdle(reducesAvailableStreams: true), + .init(connection: connection, keepAliveTimerCancellationContinuation: keepAliveTimerCancellationToken) + ) + + XCTAssertEqual( + state.lease(streams: 30), + TestConnectionState.LeaseAction(connection: connection, timersToCancel: [], wasIdle: true) + ) + + XCTAssertEqual(state.release(streams: 10), .leased(availableStreams: 79)) + XCTAssertEqual(state.isAvailable, true) + XCTAssertEqual( + state.lease(streams: 79), + TestConnectionState.LeaseAction(connection: connection, timersToCancel: [], wasIdle: false) + ) + XCTAssertEqual(state.isAvailable, false) + XCTAssertEqual(state.keepAliveSucceeded(), .leased(availableStreams: 1)) + XCTAssertEqual(state.isAvailable, true) + } + + func testRunningKeepAliveDoesNotReduceAvailableStreams() { + let connectionID = 1 + var state = TestConnectionState(id: connectionID) + let connection = MockConnection(id: connectionID) + XCTAssertEqual(state.connected(connection, maxStreams: 100), .idle(availableStreams: 100, newIdle: true)) + let timers = state.parkConnection(scheduleKeepAliveTimer: true, scheduleIdleTimeoutTimer: false) + guard let keepAliveTimer = timers.first else { return XCTFail("Expected to get a keepAliveTimer") } + + let keepAliveTimerCancellationToken = MockTimerCancellationToken(keepAliveTimer) + XCTAssertNil(state.timerScheduled(keepAliveTimer, cancelContinuation: keepAliveTimerCancellationToken)) + + XCTAssertEqual( + state.runKeepAliveIfIdle(reducesAvailableStreams: false), + .init(connection: connection, keepAliveTimerCancellationContinuation: keepAliveTimerCancellationToken) + ) + + XCTAssertEqual( + state.lease(streams: 30), + TestConnectionState.LeaseAction(connection: connection, timersToCancel: [], wasIdle: true) + ) + + XCTAssertEqual(state.release(streams: 10), .leased(availableStreams: 80)) + XCTAssertEqual(state.keepAliveSucceeded(), .leased(availableStreams: 80)) + } + + func testRunKeepAliveRacesAgainstIdleClose() { + let connectionID = 1 + var state = TestConnectionState(id: connectionID) + let connection = MockConnection(id: connectionID) + XCTAssertEqual(state.connected(connection, maxStreams: 1), .idle(availableStreams: 1, newIdle: true)) + let parkResult = state.parkConnection(scheduleKeepAliveTimer: true, scheduleIdleTimeoutTimer: true) + guard let keepAliveTimer = parkResult.first, let idleTimer = parkResult.second else { + return XCTFail("Expected to get two timers") + } + + XCTAssertEqual(keepAliveTimer, .init(timerID: 0, connectionID: connectionID, usecase: .keepAlive)) + XCTAssertEqual(idleTimer, .init(timerID: 1, connectionID: connectionID, usecase: .idleTimeout)) + + let keepAliveTimerCancellationToken = MockTimerCancellationToken(keepAliveTimer) + let idleTimerCancellationToken = MockTimerCancellationToken(idleTimer) + + XCTAssertNil(state.timerScheduled(keepAliveTimer, cancelContinuation: keepAliveTimerCancellationToken)) + XCTAssertNil(state.timerScheduled(idleTimer, cancelContinuation: idleTimerCancellationToken)) + + XCTAssertEqual(state.closeIfIdle(), .init(connection: connection, previousConnectionState: .idle, cancelTimers: [keepAliveTimerCancellationToken, idleTimerCancellationToken], usedStreams: 0, maxStreams: 1, runningKeepAlive: false)) + XCTAssertEqual(state.runKeepAliveIfIdle(reducesAvailableStreams: true), .none) + + } +} diff --git a/Tests/ConnectionPoolModuleTests/PoolStateMachine+RequestQueueTests.swift b/Tests/ConnectionPoolModuleTests/PoolStateMachine+RequestQueueTests.swift new file mode 100644 index 00000000..b74b86cc --- /dev/null +++ b/Tests/ConnectionPoolModuleTests/PoolStateMachine+RequestQueueTests.swift @@ -0,0 +1,148 @@ +@testable import _ConnectionPoolModule +import _ConnectionPoolTestUtils +import XCTest + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +final class PoolStateMachine_RequestQueueTests: XCTestCase { + + typealias TestQueue = TestPoolStateMachine.RequestQueue + + func testHappyPath() { + var queue = TestQueue() + XCTAssert(queue.isEmpty) + + let request1 = MockRequest() + queue.queue(request1) + XCTAssertEqual(queue.count, 1) + XCTAssertFalse(queue.isEmpty) + let popResult = queue.pop(max: 3) + XCTAssert(popResult.elementsEqual([request1])) + XCTAssert(queue.isEmpty) + XCTAssertEqual(queue.count, 0) + } + + func testEnqueueAndPopMultipleRequests() { + var queue = TestQueue() + XCTAssert(queue.isEmpty) + + var request1 = MockRequest() + queue.queue(request1) + var request2 = MockRequest() + queue.queue(request2) + var request3 = MockRequest() + queue.queue(request3) + + do { + XCTAssertEqual(queue.count, 3) + XCTAssertFalse(queue.isEmpty) + let popResult = queue.pop(max: 3) + XCTAssert(popResult.elementsEqual([request1, request2, request3])) + XCTAssert(queue.isEmpty) + XCTAssertEqual(queue.count, 0) + } + XCTAssert(isKnownUniquelyReferenced(&request1)) + XCTAssert(isKnownUniquelyReferenced(&request2)) + XCTAssert(isKnownUniquelyReferenced(&request3)) + } + + func testEnqueueAndPopOnlyOne() { + var queue = TestQueue() + XCTAssert(queue.isEmpty) + + var request1 = MockRequest() + queue.queue(request1) + var request2 = MockRequest() + queue.queue(request2) + var request3 = MockRequest() + queue.queue(request3) + + do { + XCTAssertEqual(queue.count, 3) + XCTAssertFalse(queue.isEmpty) + let popResult = queue.pop(max: 1) + XCTAssert(popResult.elementsEqual([request1])) + XCTAssertFalse(queue.isEmpty) + XCTAssertEqual(queue.count, 2) + + let removeAllResult = queue.removeAll() + XCTAssert(Set(removeAllResult) == [request2, request3]) + } + XCTAssert(isKnownUniquelyReferenced(&request1)) + XCTAssert(isKnownUniquelyReferenced(&request2)) + XCTAssert(isKnownUniquelyReferenced(&request3)) + } + + func testCancellation() { + var queue = TestQueue() + XCTAssert(queue.isEmpty) + + var request1 = MockRequest() + queue.queue(request1) + var request2 = MockRequest() + queue.queue(request2) + var request3 = MockRequest() + queue.queue(request3) + + do { + XCTAssertEqual(queue.count, 3) + let returnedRequest2 = queue.remove(request2.id) + XCTAssert(returnedRequest2 === request2) + XCTAssertEqual(queue.count, 2) + XCTAssertFalse(queue.isEmpty) + } + + // still retained by the deque inside the queue + XCTAssertEqual(queue.requests.count, 2) + XCTAssertEqual(queue.queue.count, 3) + + do { + XCTAssertEqual(queue.count, 2) + XCTAssertFalse(queue.isEmpty) + let popResult = queue.pop(max: 3) + XCTAssert(popResult.elementsEqual([request1, request3])) + XCTAssert(queue.isEmpty) + XCTAssertEqual(queue.count, 0) + } + + XCTAssert(isKnownUniquelyReferenced(&request1)) + XCTAssert(isKnownUniquelyReferenced(&request2)) + XCTAssert(isKnownUniquelyReferenced(&request3)) + } + + func testRemoveAllAfterCancellation() { + var queue = TestQueue() + XCTAssert(queue.isEmpty) + + var request1 = MockRequest() + queue.queue(request1) + var request2 = MockRequest() + queue.queue(request2) + var request3 = MockRequest() + queue.queue(request3) + + do { + XCTAssertEqual(queue.count, 3) + let returnedRequest2 = queue.remove(request2.id) + XCTAssert(returnedRequest2 === request2) + XCTAssertEqual(queue.count, 2) + XCTAssertFalse(queue.isEmpty) + } + + // still retained by the deque inside the queue + XCTAssertEqual(queue.requests.count, 2) + XCTAssertEqual(queue.queue.count, 3) + + do { + XCTAssertEqual(queue.count, 2) + XCTAssertFalse(queue.isEmpty) + let removeAllResult = queue.removeAll() + XCTAssert(Set(removeAllResult) == [request1, request3]) + XCTAssert(queue.isEmpty) + XCTAssertEqual(queue.count, 0) + } + + XCTAssert(isKnownUniquelyReferenced(&request1)) + XCTAssert(isKnownUniquelyReferenced(&request2)) + XCTAssert(isKnownUniquelyReferenced(&request3)) + } +} diff --git a/Tests/ConnectionPoolModuleTests/PoolStateMachineTests.swift b/Tests/ConnectionPoolModuleTests/PoolStateMachineTests.swift new file mode 100644 index 00000000..c0b6ddcd --- /dev/null +++ b/Tests/ConnectionPoolModuleTests/PoolStateMachineTests.swift @@ -0,0 +1,385 @@ +@testable import _ConnectionPoolModule +import _ConnectionPoolTestUtils +import XCTest + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +typealias TestPoolStateMachine = PoolStateMachine< + MockConnection, + ConnectionIDGenerator, + MockConnection.ID, + MockRequest, + MockRequest.ID, + MockTimerCancellationToken +> + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +final class PoolStateMachineTests: XCTestCase { + + func testConnectionsAreCreatedAndParkedOnStartup() { + var configuration = PoolConfiguration() + configuration.minimumConnectionCount = 2 + configuration.maximumConnectionSoftLimit = 4 + configuration.maximumConnectionHardLimit = 6 + configuration.keepAliveDuration = .seconds(10) + + var stateMachine = TestPoolStateMachine( + configuration: configuration, + generator: .init(), + timerCancellationTokenType: MockTimerCancellationToken.self + ) + + let connection1 = MockConnection(id: 0) + let connection2 = MockConnection(id: 1) + + do { + let requests = stateMachine.refillConnections() + XCTAssertEqual(requests.count, 2) + let createdAction1 = stateMachine.connectionEstablished(connection1, maxStreams: 1) + let connection1KeepAliveTimer = TestPoolStateMachine.Timer(.init(timerID: 0, connectionID: 0, usecase: .keepAlive), duration: .seconds(10)) + let connection1KeepAliveTimerCancellationToken = MockTimerCancellationToken(connection1KeepAliveTimer) + XCTAssertEqual(createdAction1.request, .none) + XCTAssertEqual(createdAction1.connection, .scheduleTimers([connection1KeepAliveTimer])) + + XCTAssertEqual(stateMachine.timerScheduled(connection1KeepAliveTimer, cancelContinuation: connection1KeepAliveTimerCancellationToken), .none) + + let createdAction2 = stateMachine.connectionEstablished(connection2, maxStreams: 1) + let connection2KeepAliveTimer = TestPoolStateMachine.Timer(.init(timerID: 0, connectionID: 1, usecase: .keepAlive), duration: .seconds(10)) + let connection2KeepAliveTimerCancellationToken = MockTimerCancellationToken(connection2KeepAliveTimer) + XCTAssertEqual(createdAction2.request, .none) + XCTAssertEqual(createdAction2.connection, .scheduleTimers([connection2KeepAliveTimer])) + XCTAssertEqual(stateMachine.timerScheduled(connection2KeepAliveTimer, cancelContinuation: connection2KeepAliveTimerCancellationToken), .none) + } + } + + func testConnectionsNoKeepAliveRun() { + var configuration = PoolConfiguration() + configuration.minimumConnectionCount = 1 + configuration.maximumConnectionSoftLimit = 4 + configuration.maximumConnectionHardLimit = 6 + configuration.keepAliveDuration = nil + configuration.idleTimeoutDuration = .seconds(5) + + var stateMachine = TestPoolStateMachine( + configuration: configuration, + generator: .init(), + timerCancellationTokenType: MockTimerCancellationToken.self + ) + + let connection1 = MockConnection(id: 0) + + // refill pool to at least one connection + let requests = stateMachine.refillConnections() + XCTAssertEqual(requests.count, 1) + let createdAction1 = stateMachine.connectionEstablished(connection1, maxStreams: 1) + XCTAssertEqual(createdAction1.request, .none) + XCTAssertEqual(createdAction1.connection, .scheduleTimers([])) + + // lease connection 1 + let request1 = MockRequest() + let leaseRequest1 = stateMachine.leaseConnection(request1) + XCTAssertEqual(leaseRequest1.connection, .cancelTimers([])) + XCTAssertEqual(leaseRequest1.request, .leaseConnection(.init(element: request1), connection1)) + + // release connection 1 + XCTAssertEqual(stateMachine.releaseConnection(connection1, streams: 1), .none()) + + // lease connection 1 + let request2 = MockRequest() + let leaseRequest2 = stateMachine.leaseConnection(request2) + XCTAssertEqual(leaseRequest2.connection, .cancelTimers([])) + XCTAssertEqual(leaseRequest2.request, .leaseConnection(.init(element: request2), connection1)) + + // request connection while none is available + let request3 = MockRequest() + let leaseRequest3 = stateMachine.leaseConnection(request3) + XCTAssertEqual(leaseRequest3.connection, .makeConnection(.init(connectionID: 1), [])) + XCTAssertEqual(leaseRequest3.request, .none) + + // make connection 2 and lease immediately + let connection2 = MockConnection(id: 1) + let createdAction2 = stateMachine.connectionEstablished(connection2, maxStreams: 1) + XCTAssertEqual(createdAction2.request, .leaseConnection(.init(element: request3), connection2)) + XCTAssertEqual(createdAction2.connection, .none) + + // release connection 2 + let connection2IdleTimer = TestPoolStateMachine.Timer(.init(timerID: 0, connectionID: 1, usecase: .idleTimeout), duration: configuration.idleTimeoutDuration) + let connection2IdleTimerCancellationToken = MockTimerCancellationToken(connection2IdleTimer) + XCTAssertEqual( + stateMachine.releaseConnection(connection2, streams: 1), + .init(request: .none, connection: .scheduleTimers([connection2IdleTimer])) + ) + + XCTAssertEqual(stateMachine.timerScheduled(connection2IdleTimer, cancelContinuation: connection2IdleTimerCancellationToken), .none) + XCTAssertEqual(stateMachine.timerTriggered(connection2IdleTimer), .init(request: .none, connection: .closeConnection(connection2, [connection2IdleTimerCancellationToken]))) + } + + func testOnlyOverflowConnections() { + var configuration = PoolConfiguration() + configuration.minimumConnectionCount = 0 + configuration.maximumConnectionSoftLimit = 0 + configuration.maximumConnectionHardLimit = 6 + configuration.keepAliveDuration = nil + configuration.idleTimeoutDuration = .seconds(3) + + var stateMachine = TestPoolStateMachine( + configuration: configuration, + generator: .init(), + timerCancellationTokenType: MockTimerCancellationToken.self + ) + + // don't refill pool + let requests = stateMachine.refillConnections() + XCTAssertEqual(requests.count, 0) + + // request connection while none exists + let request1 = MockRequest() + let leaseRequest1 = stateMachine.leaseConnection(request1) + XCTAssertEqual(leaseRequest1.connection, .makeConnection(.init(connectionID: 0), [])) + XCTAssertEqual(leaseRequest1.request, .none) + + // make connection 1 and lease immediately + let connection1 = MockConnection(id: 0) + let createdAction1 = stateMachine.connectionEstablished(connection1, maxStreams: 1) + XCTAssertEqual(createdAction1.request, .leaseConnection(.init(element: request1), connection1)) + XCTAssertEqual(createdAction1.connection, .none) + + // request connection while none is available + let request2 = MockRequest() + let leaseRequest2 = stateMachine.leaseConnection(request2) + XCTAssertEqual(leaseRequest2.connection, .makeConnection(.init(connectionID: 1), [])) + XCTAssertEqual(leaseRequest2.request, .none) + + // release connection 1 should be leased again immediately + let releaseRequest1 = stateMachine.releaseConnection(connection1, streams: 1) + XCTAssertEqual(releaseRequest1.request, .leaseConnection(.init(element: request2), connection1)) + XCTAssertEqual(releaseRequest1.connection, .none) + + // connection 2 comes up and should be closed right away + let connection2 = MockConnection(id: 1) + let createdAction2 = stateMachine.connectionEstablished(connection2, maxStreams: 1) + XCTAssertEqual(createdAction2.request, .none) + XCTAssertEqual(createdAction2.connection, .closeConnection(connection2, [])) + XCTAssertEqual(stateMachine.connectionClosed(connection2), .none()) + + // release connection 1 should be closed as well + let releaseRequest2 = stateMachine.releaseConnection(connection1, streams: 1) + XCTAssertEqual(releaseRequest2.request, .none) + XCTAssertEqual(releaseRequest2.connection, .closeConnection(connection1, [])) + + let shutdownAction = stateMachine.triggerForceShutdown() + XCTAssertEqual(shutdownAction.request, .failRequests(.init(), .poolShutdown)) + XCTAssertEqual(shutdownAction.connection, .shutdown(.init())) + } + + func testDemandConnectionIsMadePermanentIfPermanentIsClose() { + var configuration = PoolConfiguration() + configuration.minimumConnectionCount = 1 + configuration.maximumConnectionSoftLimit = 2 + configuration.maximumConnectionHardLimit = 6 + configuration.keepAliveDuration = nil + configuration.idleTimeoutDuration = .seconds(3) + + var stateMachine = TestPoolStateMachine( + configuration: configuration, + generator: .init(), + timerCancellationTokenType: MockTimerCancellationToken.self + ) + + let connection1 = MockConnection(id: 0) + + // refill pool to at least one connection + let requests = stateMachine.refillConnections() + XCTAssertEqual(requests.count, 1) + let createdAction1 = stateMachine.connectionEstablished(connection1, maxStreams: 1) + XCTAssertEqual(createdAction1.request, .none) + XCTAssertEqual(createdAction1.connection, .scheduleTimers([])) + + // lease connection 1 + let request1 = MockRequest() + let leaseRequest1 = stateMachine.leaseConnection(request1) + XCTAssertEqual(leaseRequest1.connection, .cancelTimers([])) + XCTAssertEqual(leaseRequest1.request, .leaseConnection(.init(element: request1), connection1)) + + // request connection while none is available + let request2 = MockRequest() + let leaseRequest2 = stateMachine.leaseConnection(request2) + XCTAssertEqual(leaseRequest2.connection, .makeConnection(.init(connectionID: 1), [])) + XCTAssertEqual(leaseRequest2.request, .none) + + // make connection 2 and lease immediately + let connection2 = MockConnection(id: 1) + let createdAction2 = stateMachine.connectionEstablished(connection2, maxStreams: 1) + XCTAssertEqual(createdAction2.request, .leaseConnection(.init(element: request2), connection2)) + XCTAssertEqual(createdAction2.connection, .none) + + // release connection 2 + let connection2IdleTimer = TestPoolStateMachine.Timer(.init(timerID: 0, connectionID: 1, usecase: .idleTimeout), duration: configuration.idleTimeoutDuration) + let connection2IdleTimerCancellationToken = MockTimerCancellationToken(connection2IdleTimer) + XCTAssertEqual( + stateMachine.releaseConnection(connection2, streams: 1), + .init(request: .none, connection: .scheduleTimers([connection2IdleTimer])) + ) + + XCTAssertEqual(stateMachine.timerScheduled(connection2IdleTimer, cancelContinuation: connection2IdleTimerCancellationToken), .none) + + // connection 1 is dropped + XCTAssertEqual(stateMachine.connectionClosed(connection1), .init(request: .none, connection: .cancelTimers([connection2IdleTimerCancellationToken]))) + } + + func testReleaseLoosesRaceAgainstClosed() { + var configuration = PoolConfiguration() + configuration.minimumConnectionCount = 0 + configuration.maximumConnectionSoftLimit = 2 + configuration.maximumConnectionHardLimit = 2 + configuration.keepAliveDuration = nil + configuration.idleTimeoutDuration = .seconds(3) + + var stateMachine = TestPoolStateMachine( + configuration: configuration, + generator: .init(), + timerCancellationTokenType: MockTimerCancellationToken.self + ) + + // don't refill pool + let requests = stateMachine.refillConnections() + XCTAssertEqual(requests.count, 0) + + // request connection while none exists + let request1 = MockRequest() + let leaseRequest1 = stateMachine.leaseConnection(request1) + XCTAssertEqual(leaseRequest1.connection, .makeConnection(.init(connectionID: 0), [])) + XCTAssertEqual(leaseRequest1.request, .none) + + // make connection 1 and lease immediately + let connection1 = MockConnection(id: 0) + let createdAction1 = stateMachine.connectionEstablished(connection1, maxStreams: 1) + XCTAssertEqual(createdAction1.request, .leaseConnection(.init(element: request1), connection1)) + XCTAssertEqual(createdAction1.connection, .none) + + // connection got closed + let closedAction = stateMachine.connectionClosed(connection1) + XCTAssertEqual(closedAction.connection, .none) + XCTAssertEqual(closedAction.request, .none) + + // release connection 1 should be leased again immediately + let releaseRequest1 = stateMachine.releaseConnection(connection1, streams: 1) + XCTAssertEqual(releaseRequest1.request, .none) + XCTAssertEqual(releaseRequest1.connection, .none) + } + + func testKeepAliveOnClosingConnection() { + var configuration = PoolConfiguration() + configuration.minimumConnectionCount = 0 + configuration.maximumConnectionSoftLimit = 2 + configuration.maximumConnectionHardLimit = 2 + configuration.keepAliveDuration = .seconds(2) + configuration.idleTimeoutDuration = .seconds(4) + + + var stateMachine = TestPoolStateMachine( + configuration: configuration, + generator: .init(), + timerCancellationTokenType: MockTimerCancellationToken.self + ) + + // don't refill pool + let requests = stateMachine.refillConnections() + XCTAssertEqual(requests.count, 0) + + // request connection while none exists + let request1 = MockRequest() + let leaseRequest1 = stateMachine.leaseConnection(request1) + XCTAssertEqual(leaseRequest1.connection, .makeConnection(.init(connectionID: 0), [])) + XCTAssertEqual(leaseRequest1.request, .none) + + // make connection 1 + let connection1 = MockConnection(id: 0) + let createdAction1 = stateMachine.connectionEstablished(connection1, maxStreams: 1) + XCTAssertEqual(createdAction1.request, .leaseConnection(.init(element: request1), connection1)) + XCTAssertEqual(createdAction1.connection, .none) + _ = stateMachine.releaseConnection(connection1, streams: 1) + + // trigger keep alive + let keepAliveAction1 = stateMachine.connectionKeepAliveTimerTriggered(connection1.id) + XCTAssertEqual(keepAliveAction1.connection, .runKeepAlive(connection1, nil)) + + // fail keep alive and cause closed + let keepAliveFailed1 = stateMachine.connectionKeepAliveFailed(connection1.id) + XCTAssertEqual(keepAliveFailed1.connection, .closeConnection(connection1, [])) + connection1.closeIfClosing() + + // request connection while none exists anymore + let request2 = MockRequest() + let leaseRequest2 = stateMachine.leaseConnection(request2) + XCTAssertEqual(leaseRequest2.connection, .makeConnection(.init(connectionID: 1), [])) + XCTAssertEqual(leaseRequest2.request, .none) + + // make connection 2 + let connection2 = MockConnection(id: 1) + let createdAction2 = stateMachine.connectionEstablished(connection2, maxStreams: 1) + XCTAssertEqual(createdAction2.request, .leaseConnection(.init(element: request2), connection2)) + XCTAssertEqual(createdAction2.connection, .none) + _ = stateMachine.releaseConnection(connection2, streams: 1) + + // trigger keep alive while connection is still open + let keepAliveAction2 = stateMachine.connectionKeepAliveTimerTriggered(connection2.id) + XCTAssertEqual(keepAliveAction2.connection, .runKeepAlive(connection2, nil)) + + // close connection in the middle of keep alive + connection2.close() + connection2.closeIfClosing() + + // fail keep alive and cause closed + let keepAliveFailed2 = stateMachine.connectionKeepAliveFailed(connection2.id) + XCTAssertEqual(keepAliveFailed2.connection, .closeConnection(connection2, [])) + } + + func testConnectionIsEstablishedAfterFailedKeepAliveIfNotEnoughConnectionsLeft() { + var configuration = PoolConfiguration() + configuration.minimumConnectionCount = 1 + configuration.maximumConnectionSoftLimit = 2 + configuration.maximumConnectionHardLimit = 2 + configuration.keepAliveDuration = .seconds(2) + configuration.idleTimeoutDuration = .seconds(4) + + + var stateMachine = TestPoolStateMachine( + configuration: configuration, + generator: .init(), + timerCancellationTokenType: MockTimerCancellationToken.self + ) + + // refill pool + let requests = stateMachine.refillConnections() + XCTAssertEqual(requests.count, 1) + + // one connection should exist + let request = MockRequest() + let leaseRequest = stateMachine.leaseConnection(request) + XCTAssertEqual(leaseRequest.connection, .none) + XCTAssertEqual(leaseRequest.request, .none) + + // make connection 1 + let connection = MockConnection(id: 0) + let createdAction = stateMachine.connectionEstablished(connection, maxStreams: 1) + XCTAssertEqual(createdAction.request, .leaseConnection(.init(element: request), connection)) + XCTAssertEqual(createdAction.connection, .none) + _ = stateMachine.releaseConnection(connection, streams: 1) + + // trigger keep alive + let keepAliveAction = stateMachine.connectionKeepAliveTimerTriggered(connection.id) + XCTAssertEqual(keepAliveAction.connection, .runKeepAlive(connection, nil)) + + // fail keep alive, cause closed and make new connection + let keepAliveFailed = stateMachine.connectionKeepAliveFailed(connection.id) + XCTAssertEqual(keepAliveFailed.connection, .closeConnection(connection, [])) + let connectionClosed = stateMachine.connectionClosed(connection) + XCTAssertEqual(connectionClosed.connection, .makeConnection(.init(connectionID: 1), [])) + connection.closeIfClosing() + let establishAction = stateMachine.connectionEstablished(.init(id: 1), maxStreams: 1) + XCTAssertEqual(establishAction.request, .none) + guard case .scheduleTimers(let timers) = establishAction.connection else { return XCTFail("Unexpected connection action") } + XCTAssertEqual(timers, [.init(.init(timerID: 0, connectionID: 1, usecase: .keepAlive), duration: configuration.keepAliveDuration!)]) + } + +} diff --git a/Tests/ConnectionPoolModuleTests/TinyFastSequenceTests.swift b/Tests/ConnectionPoolModuleTests/TinyFastSequenceTests.swift new file mode 100644 index 00000000..1a2836b9 --- /dev/null +++ b/Tests/ConnectionPoolModuleTests/TinyFastSequenceTests.swift @@ -0,0 +1,72 @@ +@testable import _ConnectionPoolModule +import XCTest + +final class TinyFastSequenceTests: XCTestCase { + func testCountIsEmptyAndIterator() async { + var sequence = TinyFastSequence() + XCTAssertEqual(sequence.count, 0) + XCTAssertEqual(sequence.isEmpty, true) + XCTAssertEqual(sequence.first, nil) + XCTAssertEqual(Array(sequence), []) + sequence.append(1) + XCTAssertEqual(sequence.count, 1) + XCTAssertEqual(sequence.isEmpty, false) + XCTAssertEqual(sequence.first, 1) + XCTAssertEqual(Array(sequence), [1]) + sequence.append(2) + XCTAssertEqual(sequence.count, 2) + XCTAssertEqual(sequence.isEmpty, false) + XCTAssertEqual(sequence.first, 1) + XCTAssertEqual(Array(sequence), [1, 2]) + sequence.append(3) + XCTAssertEqual(sequence.count, 3) + XCTAssertEqual(sequence.isEmpty, false) + XCTAssertEqual(sequence.first, 1) + XCTAssertEqual(Array(sequence), [1, 2, 3]) + } + + func testReserveCapacityIsForwarded() { + var emptySequence = TinyFastSequence() + emptySequence.reserveCapacity(8) + emptySequence.append(1) + emptySequence.append(2) + emptySequence.append(3) + guard case .n(let array) = emptySequence.base else { + return XCTFail("Expected sequence to be backed by an array") + } + XCTAssertEqual(array.capacity, 8) + + var oneElemSequence = TinyFastSequence(element: 1) + oneElemSequence.reserveCapacity(8) + oneElemSequence.append(2) + oneElemSequence.append(3) + guard case .n(let array) = oneElemSequence.base else { + return XCTFail("Expected sequence to be backed by an array") + } + XCTAssertEqual(array.capacity, 8) + + var twoElemSequence = TinyFastSequence([1, 2]) + twoElemSequence.reserveCapacity(8) + guard case .n(let array) = twoElemSequence.base else { + return XCTFail("Expected sequence to be backed by an array") + } + XCTAssertEqual(array.capacity, 8) + } + + func testNewSequenceSlowPath() { + let sequence = TinyFastSequence("AB".utf8) + XCTAssertEqual(Array(sequence), [UInt8(ascii: "A"), UInt8(ascii: "B")]) + } + + func testSingleItem() { + let sequence = TinyFastSequence("A".utf8) + XCTAssertEqual(Array(sequence), [UInt8(ascii: "A")]) + } + + func testEmptyCollection() { + let sequence = TinyFastSequence("".utf8) + XCTAssertTrue(sequence.isEmpty) + XCTAssertEqual(sequence.count, 0) + XCTAssertEqual(Array(sequence), []) + } +} diff --git a/Tests/ConnectionPoolModuleTests/Utils/Future.swift b/Tests/ConnectionPoolModuleTests/Utils/Future.swift new file mode 100644 index 00000000..2bee3216 --- /dev/null +++ b/Tests/ConnectionPoolModuleTests/Utils/Future.swift @@ -0,0 +1,112 @@ +import Atomics +@testable import _ConnectionPoolModule + +/// This is a `Future` type that shall make writing tests a bit simpler. I'm well aware, that this is a pattern +/// that should not be embraced with structured concurrency. However writing all tests in full structured +/// concurrency is an effort, that isn't worth the endgoals in my view. +final class Future: Sendable { + struct State: Sendable { + + var result: Swift.Result? = nil + var continuations: [(Int, CheckedContinuation)] = [] + + } + + let waiterID = ManagedAtomic(0) + let stateBox: NIOLockedValueBox = NIOLockedValueBox(State()) + + init(of: Success.Type) {} + + enum GetAction { + case fail(any Error) + case succeed(Success) + case none + } + + var success: Success { + get async throws { + let waiterID = self.waiterID.loadThenWrappingIncrement(ordering: .relaxed) + + return try await withTaskCancellationHandler { + return try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in + let action = self.stateBox.withLockedValue { state -> GetAction in + if Task.isCancelled { + return .fail(CancellationError()) + } + + switch state.result { + case .none: + state.continuations.append((waiterID, continuation)) + return .none + + case .success(let result): + return .succeed(result) + + case .failure(let error): + return .fail(error) + } + } + + switch action { + case .fail(let error): + continuation.resume(throwing: error) + + case .succeed(let result): + continuation.resume(returning: result) + + case .none: + break + } + } + } onCancel: { + let cont = self.stateBox.withLockedValue { state -> CheckedContinuation? in + guard state.result == nil else { return nil } + + guard let contIndex = state.continuations.firstIndex(where: { $0.0 == waiterID }) else { + return nil + } + let (_, continuation) = state.continuations.remove(at: contIndex) + return continuation + } + + cont?.resume(throwing: CancellationError()) + } + } + } + + func yield(value: Success) { + let continuations = self.stateBox.withLockedValue { state in + guard state.result == nil else { + return [(Int, CheckedContinuation)]().lazy.map(\.1) + } + state.result = .success(value) + + let continuations = state.continuations + state.continuations = [] + + return continuations.lazy.map(\.1) + } + + for continuation in continuations { + continuation.resume(returning: value) + } + } + + func yield(error: any Error) { + let continuations = self.stateBox.withLockedValue { state in + guard state.result == nil else { + return [(Int, CheckedContinuation)]().lazy.map(\.1) + } + state.result = .failure(error) + + let continuations = state.continuations + state.continuations = [] + + return continuations.lazy.map(\.1) + } + + for continuation in continuations { + continuation.resume(throwing: error) + } + } +} diff --git a/Tests/IntegrationTests/AsyncTests.swift b/Tests/IntegrationTests/AsyncTests.swift new file mode 100644 index 00000000..b4c8e93f --- /dev/null +++ b/Tests/IntegrationTests/AsyncTests.swift @@ -0,0 +1,583 @@ +import Logging +import XCTest +import PostgresNIO +#if canImport(Network) +import NIOTransportServices +#endif +import NIOPosix +import NIOCore + +final class AsyncPostgresConnectionTests: XCTestCase { + func test1kRoundTrips() async throws { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + let eventLoop = eventLoopGroup.next() + + try await withTestConnection(on: eventLoop) { connection in + for _ in 0..<1_000 { + let rows = try await connection.query("SELECT version()", logger: .psqlTest) + var iterator = rows.makeAsyncIterator() + let firstRow = try await iterator.next() + XCTAssertEqual(try firstRow?.decode(String.self, context: .default).contains("PostgreSQL"), true) + let done = try await iterator.next() + XCTAssertNil(done) + } + } + } + + func testSelect10kRows() async throws { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + let eventLoop = eventLoopGroup.next() + + let start = 1 + let end = 10000 + + try await withTestConnection(on: eventLoop) { connection in + let rows = try await connection.query("SELECT generate_series(\(start), \(end));", logger: .psqlTest) + var counter = 0 + for try await row in rows { + let element = try row.decode(Int.self) + XCTAssertEqual(element, counter + 1) + counter += 1 + } + + XCTAssertEqual(counter, end) + } + } + + func testSelectActiveConnection() async throws { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + let eventLoop = eventLoopGroup.next() + + let query: PostgresQuery = """ + SELECT + pid + ,datname + ,usename + ,application_name + ,client_hostname + ,client_port + ,backend_start + ,query_start + ,query + ,state + FROM pg_stat_activity + WHERE state = 'active'; + """ + + try await withTestConnection(on: eventLoop) { connection in + let rows = try await connection.query(query, logger: .psqlTest) + var counter = 0 + + for try await element in rows.decode((Int, String, String, String, String?, Int, Date, Date, String, String).self) { + XCTAssertEqual(element.1, env("POSTGRES_DB") ?? "test_database") + XCTAssertEqual(element.2, env("POSTGRES_USER") ?? "test_username") + + XCTAssertEqual(element.8, query.sql) + XCTAssertEqual(element.9, "active") + counter += 1 + } + + XCTAssertGreaterThanOrEqual(counter, 1) + } + } + + func testAdditionalParametersTakeEffect() async throws { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + let eventLoop = eventLoopGroup.next() + + let query: PostgresQuery = """ + SELECT + current_setting('application_name'); + """ + + let applicationName = "postgres-nio-test" + var options = PostgresConnection.Configuration.Options() + options.additionalStartupParameters = [ + ("application_name", applicationName) + ] + + try await withTestConnection(on: eventLoop, options: options) { connection in + let rows = try await connection.query(query, logger: .psqlTest) + var counter = 0 + + for try await element in rows.decode(String.self) { + XCTAssertEqual(element, applicationName) + + counter += 1 + } + + XCTAssertGreaterThanOrEqual(counter, 1) + } + } + + func testSelectTimeoutWhileLongRunningQuery() async throws { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + let eventLoop = eventLoopGroup.next() + + let start = 1 + let end = 10000000 + + try await withTestConnection(on: eventLoop) { connection -> () in + try await connection.query("SET statement_timeout=1000;", logger: .psqlTest) + + let rows = try await connection.query("SELECT generate_series(\(start), \(end));", logger: .psqlTest) + var counter = 0 + do { + for try await element in rows.decode(Int.self, context: .default) { + XCTAssertEqual(element, counter + 1) + counter += 1 + } + XCTFail("Expected to get cancelled while reading the query") + } catch { + guard let error = error as? PSQLError else { return XCTFail("Unexpected error type") } + + XCTAssertEqual(error.code, .server) + XCTAssertEqual(error.serverInfo?[.severity], "ERROR") + } + + XCTAssertFalse(connection.isClosed, "Connection should survive!") + + for num in 0..<10 { + for try await decoded in try await connection.query("SELECT \(num);", logger: .psqlTest).decode(Int.self) { + XCTAssertEqual(decoded, num) + } + } + } + } + + func testConnectionSurvives1kQueriesWithATypo() async throws { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + let eventLoop = eventLoopGroup.next() + + let start = 1 + let end = 10000 + + try await withTestConnection(on: eventLoop) { connection -> () in + for _ in 0..<1000 { + do { + try await connection.query("SELECT generte_series(\(start), \(end));", logger: .psqlTest) + XCTFail("Expected to throw from the request") + } catch { + guard let error = error as? PSQLError else { return XCTFail("Unexpected error type: \(error)") } + + XCTAssertEqual(error.code, .server) + XCTAssertEqual(error.serverInfo?[.severity], "ERROR") + } + } + + // the connection survived all of this, we can still run normal queries: + + for num in 0..<10 { + for try await decoded in try await connection.query("SELECT \(num);", logger: .psqlTest).decode(Int.self) { + XCTAssertEqual(decoded, num) + } + } + } + } + + func testSelect10times10kRows() async throws { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + let eventLoop = eventLoopGroup.next() + + let start = 1 + let end = 10000 + + try await withTestConnection(on: eventLoop) { connection in + await withThrowingTaskGroup(of: Void.self) { taskGroup in + for _ in 0..<10 { + taskGroup.addTask { + try await connection.query("SELECT generate_series(\(start), \(end));", logger: .psqlTest) + } + } + } + } + } + + func testBindMaximumParameters() async throws { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + let eventLoop = eventLoopGroup.next() + + try await withTestConnection(on: eventLoop) { connection in + // Max binds limit is UInt16.max which is 65535 which is 3 * 5 * 17 * 257 + // Max columns limit is 1664, so we will only make 5 * 257 columns which is less + // Then we will insert 3 * 17 rows + // In the insertion, there will be a total of 3 * 17 * 5 * 257 == UInt16.max bindings + // If the test is successful, it means Postgres supports UInt16.max bindings + let columnsCount = 5 * 257 + let rowsCount = 3 * 17 + + let createQuery = PostgresQuery( + unsafeSQL: """ + CREATE TABLE table1 ( + \((0.. String in + "$\(rowIndex * columnsCount + columnIndex + 1)" + } + return "(\(indices.joined(separator: ", ")))" + }.joined(separator: ", ") + let insertionQuery = PostgresQuery( + unsafeSQL: "INSERT INTO table1 VALUES \(insertionValues)", + binds: binds + ) + try await connection.query(insertionQuery, logger: .psqlTest) + + let countQuery = PostgresQuery(unsafeSQL: "SELECT COUNT(*) FROM table1") + let countRows = try await connection.query(countQuery, logger: .psqlTest) + var countIterator = countRows.makeAsyncIterator() + let insertedRowsCount = try await countIterator.next()?.decode(Int.self, context: .default) + XCTAssertEqual(rowsCount, insertedRowsCount) + + let dropQuery = PostgresQuery(unsafeSQL: "DROP TABLE table1") + try await connection.query(dropQuery, logger: .psqlTest) + } + } + + func testListenAndNotify() async throws { + let channelNames = [ + "foo", + "default" + ] + + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + let eventLoop = eventLoopGroup.next() + + for channelName in channelNames { + try await self.withTestConnection(on: eventLoop) { connection in + let stream = try await connection.listen(channelName) + var iterator = stream.makeAsyncIterator() + + try await self.withTestConnection(on: eventLoop) { other in + try await other.query(#"NOTIFY "\#(unescaped: channelName)", 'bar';"#, logger: .psqlTest) + + try await other.query(#"NOTIFY "\#(unescaped: channelName)", 'foo';"#, logger: .psqlTest) + } + + let first = try await iterator.next() + XCTAssertEqual(first?.payload, "bar") + + let second = try await iterator.next() + XCTAssertEqual(second?.payload, "foo") + } + } + } + + #if canImport(Network) + func testSelect10kRowsNetworkFramework() async throws { + let eventLoopGroup = NIOTSEventLoopGroup() + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + let eventLoop = eventLoopGroup.next() + + let start = 1 + let end = 10000 + + try await withTestConnection(on: eventLoop) { connection in + let rows = try await connection.query("SELECT generate_series(\(start), \(end));", logger: .psqlTest) + var counter = 1 + for try await row in rows { + let element = try row.decode(Int.self, context: .default) + XCTAssertEqual(element, counter) + counter += 1 + } + + XCTAssertEqual(counter, end + 1) + } + } + #endif + + func testCancelTaskThatIsVeryLongRunningWhichAlsoFailsWhileInStreamingMode() async throws { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + let eventLoop = eventLoopGroup.next() + + // we cancel the query after 400ms. + // the server times out the query after 1sec. + + try await withTestConnection(on: eventLoop) { connection -> () in + try await connection.query("SET statement_timeout=1000;", logger: .psqlTest) // 1000 milliseconds + + try await withThrowingTaskGroup(of: Void.self) { group in + group.addTask { + let start = 1 + let end = 100_000_000 + + let rows = try await connection.query("SELECT generate_series(\(start), \(end));", logger: .psqlTest) + var counter = 0 + do { + for try await element in rows.decode(Int.self, context: .default) { + XCTAssertEqual(element, counter + 1) + counter += 1 + } + XCTFail("Expected to get cancelled while reading the query") + XCTAssertEqual(counter, end) + } catch let error as CancellationError { + XCTAssertGreaterThanOrEqual(counter, 1) + // Expected + print("\(error)") + } catch { + XCTFail("Unexpected error: \(error)") + } + + XCTAssertTrue(Task.isCancelled) + XCTAssertFalse(connection.isClosed, "Connection should survive!") + } + + let delay: UInt64 = 400_000_000 // 400 milliseconds + try await Task.sleep(nanoseconds: delay) + + group.cancelAll() + } + + try await connection.query("SELECT 1;", logger: .psqlTest) + } + } + + func testPreparedStatement() async throws { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + let eventLoop = eventLoopGroup.next() + + struct TestPreparedStatement: PostgresPreparedStatement { + static let sql = "SELECT pid, datname FROM pg_stat_activity WHERE state = $1" + typealias Row = (Int, String) + + var state: String + + func makeBindings() -> PostgresBindings { + var bindings = PostgresBindings() + bindings.append(self.state) + return bindings + } + + func decodeRow(_ row: PostgresNIO.PostgresRow) throws -> Row { + try row.decode(Row.self) + } + } + let preparedStatement = TestPreparedStatement(state: "active") + try await withTestConnection(on: eventLoop) { connection in + var results = try await connection.execute(preparedStatement, logger: .psqlTest) + var counter = 0 + + for try await element in results { + XCTAssertEqual(element.1, env("POSTGRES_DB") ?? "test_database") + counter += 1 + } + + XCTAssertGreaterThanOrEqual(counter, 1) + + // Second execution, which reuses the existing prepared statement + results = try await connection.execute(preparedStatement, logger: .psqlTest) + for try await element in results { + XCTAssertEqual(element.1, env("POSTGRES_DB") ?? "test_database") + counter += 1 + } + } + } + + static let preparedStatementTestTable = "AsyncTestPreparedStatementTestTable" + func testPreparedStatementWithIntegerBinding() async throws { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + let eventLoop = eventLoopGroup.next() + + struct InsertPreparedStatement: PostgresPreparedStatement { + static let name = "INSERT-AsyncTestPreparedStatementTestTable" + + static let sql = #"INSERT INTO "\#(AsyncPostgresConnectionTests.preparedStatementTestTable)" (uuid) VALUES ($1);"# + typealias Row = () + + var uuid: UUID + + func makeBindings() -> PostgresBindings { + var bindings = PostgresBindings() + bindings.append(self.uuid) + return bindings + } + + func decodeRow(_ row: PostgresNIO.PostgresRow) throws -> Row { + () + } + } + + struct SelectPreparedStatement: PostgresPreparedStatement { + static let name = "SELECT-AsyncTestPreparedStatementTestTable" + + static let sql = #"SELECT id, uuid FROM "\#(AsyncPostgresConnectionTests.preparedStatementTestTable)" WHERE id <= $1;"# + typealias Row = (Int, UUID) + + var id: Int + + func makeBindings() -> PostgresBindings { + var bindings = PostgresBindings() + bindings.append(self.id) + return bindings + } + + func decodeRow(_ row: PostgresNIO.PostgresRow) throws -> Row { + try row.decode((Int, UUID).self) + } + } + + do { + try await withTestConnection(on: eventLoop) { connection in + try await connection.query(""" + CREATE TABLE IF NOT EXISTS "\(unescaped: Self.preparedStatementTestTable)" ( + id SERIAL PRIMARY KEY, + uuid UUID NOT NULL + ) + """, + logger: .psqlTest + ) + + _ = try await connection.execute(InsertPreparedStatement(uuid: .init()), logger: .psqlTest) + _ = try await connection.execute(InsertPreparedStatement(uuid: .init()), logger: .psqlTest) + _ = try await connection.execute(InsertPreparedStatement(uuid: .init()), logger: .psqlTest) + _ = try await connection.execute(InsertPreparedStatement(uuid: .init()), logger: .psqlTest) + _ = try await connection.execute(InsertPreparedStatement(uuid: .init()), logger: .psqlTest) + + let rows = try await connection.execute(SelectPreparedStatement(id: 3), logger: .psqlTest) + var counter = 0 + for try await (id, uuid) in rows { + Logger.psqlTest.info("Received row", metadata: [ + "id": "\(id)", "uuid": "\(uuid)" + ]) + counter += 1 + } + + try await connection.query(""" + DROP TABLE "\(unescaped: Self.preparedStatementTestTable)"; + """, + logger: .psqlTest + ) + } + } catch { + XCTFail("Unexpected error: \(String(describing: error))") + } + } + + static let preparedStatementWithOptionalTestTable = "AsyncTestPreparedStatementWithOptionalTestTable" + func testPreparedStatementWithOptionalBinding() async throws { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + let eventLoop = eventLoopGroup.next() + + struct InsertPreparedStatement: PostgresPreparedStatement { + static let name = "INSERT-AsyncTestPreparedStatementWithOptionalTestTable" + + static let sql = #"INSERT INTO "\#(AsyncPostgresConnectionTests.preparedStatementWithOptionalTestTable)" (uuid) VALUES ($1);"# + typealias Row = () + + var uuid: UUID? + + func makeBindings() -> PostgresBindings { + var bindings = PostgresBindings() + bindings.append(self.uuid) + return bindings + } + + func decodeRow(_ row: PostgresNIO.PostgresRow) throws -> Row { + () + } + } + + struct SelectPreparedStatement: PostgresPreparedStatement { + static let name = "SELECT-AsyncTestPreparedStatementWithOptionalTestTable" + + static let sql = #"SELECT id, uuid FROM "\#(AsyncPostgresConnectionTests.preparedStatementWithOptionalTestTable)" WHERE id <= $1;"# + typealias Row = (Int, UUID?) + + var id: Int + + func makeBindings() -> PostgresBindings { + var bindings = PostgresBindings() + bindings.append(self.id) + return bindings + } + + func decodeRow(_ row: PostgresNIO.PostgresRow) throws -> Row { + try row.decode((Int, UUID?).self) + } + } + + do { + try await withTestConnection(on: eventLoop) { connection in + try await connection.query(""" + CREATE TABLE IF NOT EXISTS "\(unescaped: Self.preparedStatementWithOptionalTestTable)" ( + id SERIAL PRIMARY KEY, + uuid UUID + ) + """, + logger: .psqlTest + ) + + _ = try await connection.execute(InsertPreparedStatement(uuid: nil), logger: .psqlTest) + _ = try await connection.execute(InsertPreparedStatement(uuid: .init()), logger: .psqlTest) + _ = try await connection.execute(InsertPreparedStatement(uuid: nil), logger: .psqlTest) + _ = try await connection.execute(InsertPreparedStatement(uuid: .init()), logger: .psqlTest) + _ = try await connection.execute(InsertPreparedStatement(uuid: nil), logger: .psqlTest) + + let rows = try await connection.execute(SelectPreparedStatement(id: 3), logger: .psqlTest) + var counter = 0 + for try await (id, uuid) in rows { + Logger.psqlTest.info("Received row", metadata: [ + "id": "\(id)", "uuid": "\(String(describing: uuid))" + ]) + counter += 1 + } + + try await connection.query(""" + DROP TABLE "\(unescaped: Self.preparedStatementWithOptionalTestTable)"; + """, + logger: .psqlTest + ) + } + } catch { + XCTFail("Unexpected error: \(String(describing: error))") + } + } +} + +extension XCTestCase { + + func withTestConnection( + on eventLoop: EventLoop, + options: PostgresConnection.Configuration.Options? = nil, + file: StaticString = #filePath, + line: UInt = #line, + _ closure: (PostgresConnection) async throws -> Result + ) async throws -> Result { + let connection = try await PostgresConnection.test(on: eventLoop, options: options).get() + + do { + let result = try await closure(connection) + try await connection.close() + return result + } catch { + XCTFail("Unexpected error: \(String(reflecting: error))", file: file, line: line) + try await connection.close() + throw error + } + } +} diff --git a/Tests/IntegrationTests/PSQLIntegrationTests.swift b/Tests/IntegrationTests/PSQLIntegrationTests.swift new file mode 100644 index 00000000..d541899b --- /dev/null +++ b/Tests/IntegrationTests/PSQLIntegrationTests.swift @@ -0,0 +1,382 @@ +import Atomics +import XCTest +import Logging +import PostgresNIO +import NIOCore +import NIOPosix +import NIOTestUtils + +final class IntegrationTests: XCTestCase { + + func testConnectAndClose() { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + let eventLoop = eventLoopGroup.next() + + var conn: PostgresConnection? + XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) + XCTAssertNoThrow(try conn?.close().wait()) + } + + func testAuthenticationFailure() throws { + // If the postgres server trusts every connection, it is really hard to create an + // authentication failure. + try XCTSkipIf(env("POSTGRES_HOST_AUTH_METHOD") == "trust") + + let config = PostgresConnection.Configuration( + host: env("POSTGRES_HOSTNAME") ?? "localhost", + port: env("POSTGRES_PORT").flatMap(Int.init(_:)) ?? 5432, + username: env("POSTGRES_USER") ?? "test_username", + password: "wrong_password", + database: env("POSTGRES_DB") ?? "test_database", + tls: .disable + ) + + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + + var logger = Logger.psqlTest + logger.logLevel = .info + + var connection: PostgresConnection? + XCTAssertThrowsError(connection = try PostgresConnection.connect(on: eventLoopGroup.next(), configuration: config, id: 1, logger: logger).wait()) { + XCTAssertTrue($0 is PSQLError) + } + + // In case of a test failure the created connection must be closed. + XCTAssertNoThrow(try connection?.close().wait()) + } + + func testQueryVersion() { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + let eventLoop = eventLoopGroup.next() + + var conn: PostgresConnection? + XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) + defer { XCTAssertNoThrow(try conn?.close().wait()) } + + var result: PostgresQueryResult? + XCTAssertNoThrow(result = try conn?.query("SELECT version()", logger: .psqlTest).wait()) + let rows = result?.rows + var version: String? + XCTAssertNoThrow(version = try rows?.first?.decode(String.self, context: .default)) + XCTAssertEqual(version?.contains("PostgreSQL"), true) + } + + func testQuery10kItems() { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + let eventLoop = eventLoopGroup.next() + + var conn: PostgresConnection? + XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) + defer { XCTAssertNoThrow(try conn?.close().wait()) } + + var metadata: PostgresQueryMetadata? + let received = ManagedAtomic(0) + XCTAssertNoThrow(metadata = try conn?.query("SELECT generate_series(1, 10000);", logger: .psqlTest) { row in + func workaround() { + let expected = received.wrappingIncrementThenLoad(ordering: .relaxed) + XCTAssertEqual(expected, try row.decode(Int64.self, context: .default)) + } + + workaround() + }.wait()) + + XCTAssertEqual(received.load(ordering: .relaxed), 10000) + XCTAssertEqual(metadata?.command, "SELECT") + XCTAssertEqual(metadata?.rows, 10000) + } + + func test1kRoundTrips() { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + let eventLoop = eventLoopGroup.next() + + var conn: PostgresConnection? + XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) + defer { XCTAssertNoThrow(try conn?.close().wait()) } + + for _ in 0..<1_000 { + var result: PostgresQueryResult? + XCTAssertNoThrow(result = try conn?.query("SELECT version()", logger: .psqlTest).wait()) + var version: String? + XCTAssertNoThrow(version = try result?.rows.first?.decode(String.self, context: .default)) + XCTAssertEqual(version?.contains("PostgreSQL"), true) + } + } + + func testQuerySelectParameter() { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + let eventLoop = eventLoopGroup.next() + + var conn: PostgresConnection? + XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) + defer { XCTAssertNoThrow(try conn?.close().wait()) } + + var result: PostgresQueryResult? + XCTAssertNoThrow(result = try conn?.query("SELECT \("hello")::TEXT as foo", logger: .psqlTest).wait()) + var foo: String? + XCTAssertNoThrow(foo = try result?.rows.first?.decode(String.self, context: .default)) + XCTAssertEqual(foo, "hello") + } + + func testQueryNothing() throws { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + let eventLoop = eventLoopGroup.next() + + var conn: PostgresConnection? + XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) + defer { XCTAssertNoThrow(try conn?.close().wait()) } + + var _result: PostgresQueryResult? + XCTAssertNoThrow(_result = try conn?.query(""" + -- Some comments + """, logger: .psqlTest).wait()) + + let result = try XCTUnwrap(_result) + XCTAssertEqual(result.rows, []) + XCTAssertEqual(result.metadata.command, "") + } + + func testDecodeIntegers() { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + let eventLoop = eventLoopGroup.next() + + var conn: PostgresConnection? + XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) + defer { XCTAssertNoThrow(try conn?.close().wait()) } + + var result: PostgresQueryResult? + XCTAssertNoThrow(result = try conn?.query(""" + SELECT + 1::SMALLINT as smallint, + -32767::SMALLINT as smallint_min, + 32767::SMALLINT as smallint_max, + 1::INT as int, + -2147483647::INT as int_min, + 2147483647::INT as int_max, + 1::BIGINT as bigint, + -9223372036854775807::BIGINT as bigint_min, + 9223372036854775807::BIGINT as bigint_max + """, logger: .psqlTest).wait()) + + XCTAssertEqual(result?.rows.count, 1) + let row = result?.rows.first + + var cells: (Int16, Int16, Int16, Int32, Int32, Int32, Int64, Int64, Int64)? + XCTAssertNoThrow(cells = try row?.decode((Int16, Int16, Int16, Int32, Int32, Int32, Int64, Int64, Int64).self, context: .default)) + + XCTAssertEqual(cells?.0, 1) + XCTAssertEqual(cells?.1, -32_767) + XCTAssertEqual(cells?.2, 32_767) + XCTAssertEqual(cells?.3, 1) + XCTAssertEqual(cells?.4, -2_147_483_647) + XCTAssertEqual(cells?.5, 2_147_483_647) + XCTAssertEqual(cells?.6, 1) + XCTAssertEqual(cells?.7, -9_223_372_036_854_775_807) + XCTAssertEqual(cells?.8, 9_223_372_036_854_775_807) + } + + func testEncodeAndDecodeIntArray() { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + let eventLoop = eventLoopGroup.next() + + var conn: PostgresConnection? + XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) + defer { XCTAssertNoThrow(try conn?.close().wait()) } + + var result: PostgresQueryResult? + let array: [Int64] = [1, 2, 3] + XCTAssertNoThrow(result = try conn?.query("SELECT \(array)::int8[] as array", logger: .psqlTest).wait()) + XCTAssertEqual(result?.rows.count, 1) + XCTAssertEqual(try result?.rows.first?.decode([Int64].self, context: .default), array) + } + + func testDecodeEmptyIntegerArray() { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + let eventLoop = eventLoopGroup.next() + + var conn: PostgresConnection? + XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) + defer { XCTAssertNoThrow(try conn?.close().wait()) } + + var result: PostgresQueryResult? + XCTAssertNoThrow(result = try conn?.query("SELECT '{}'::int[] as array", logger: .psqlTest).wait()) + + XCTAssertEqual(result?.rows.count, 1) + XCTAssertEqual(try result?.rows.first?.decode([Int64].self, context: .default), []) + } + + func testDoubleArraySerialization() { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + let eventLoop = eventLoopGroup.next() + + var conn: PostgresConnection? + XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) + defer { XCTAssertNoThrow(try conn?.close().wait()) } + + var result: PostgresQueryResult? + let doubles: [Double] = [3.14, 42] + XCTAssertNoThrow(result = try conn?.query("SELECT \(doubles)::double precision[] as doubles", logger: .psqlTest).wait()) + XCTAssertEqual(result?.rows.count, 1) + XCTAssertEqual(try result?.rows.first?.decode([Double].self, context: .default), doubles) + } + + func testDecodeDates() { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + let eventLoop = eventLoopGroup.next() + + var conn: PostgresConnection? + XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) + defer { XCTAssertNoThrow(try conn?.close().wait()) } + + var result: PostgresQueryResult? + XCTAssertNoThrow(result = try conn?.query(""" + SELECT + '2016-01-18 01:02:03 +0042'::DATE as date, + '2016-01-18 01:02:03 +0042'::TIMESTAMP as timestamp, + '2016-01-18 01:02:03 +0042'::TIMESTAMPTZ as timestamptz + """, logger: .psqlTest).wait()) + + XCTAssertEqual(result?.rows.count, 1) + + var cells: (Date, Date, Date)? + XCTAssertNoThrow(cells = try result?.rows.first?.decode((Date, Date, Date).self, context: .default)) + + XCTAssertEqual(cells?.0.description, "2016-01-18 00:00:00 +0000") + XCTAssertEqual(cells?.1.description, "2016-01-18 01:02:03 +0000") + XCTAssertEqual(cells?.2.description, "2016-01-18 00:20:03 +0000") + } + + func testDecodeDecimals() { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + let eventLoop = eventLoopGroup.next() + + var conn: PostgresConnection? + XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) + defer { XCTAssertNoThrow(try conn?.close().wait()) } + + var result: PostgresQueryResult? + XCTAssertNoThrow(result = try conn?.query(""" + SELECT + \(Decimal(string: "123456.789123")!)::numeric as numeric, + \(Decimal(string: "-123456.789123")!)::numeric as numeric_negative + """, logger: .psqlTest).wait()) + XCTAssertEqual(result?.rows.count, 1) + + var cells: (Decimal, Decimal)? + XCTAssertNoThrow(cells = try result?.rows.first?.decode((Decimal, Decimal).self, context: .default)) + + XCTAssertEqual(cells?.0, Decimal(string: "123456.789123")) + XCTAssertEqual(cells?.1, Decimal(string: "-123456.789123")) + } + + func testDecodeRawRepresentables() { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + let eventLoop = eventLoopGroup.next() + + var conn: PostgresConnection? + XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) + defer { XCTAssertNoThrow(try conn?.close().wait()) } + + enum StringRR: String, PostgresDecodable { + case a + } + + enum IntRR: Int, PostgresDecodable { + case b + } + + let stringValue = StringRR.a + let intValue = IntRR.b + + var result: PostgresQueryResult? + XCTAssertNoThrow(result = try conn?.query(""" + SELECT + \(stringValue.rawValue)::varchar as string, + \(intValue.rawValue)::int8 as int + """, logger: .psqlTest).wait()) + XCTAssertEqual(result?.rows.count, 1) + + var cells: (StringRR, IntRR)? + XCTAssertNoThrow(cells = try result?.rows.first?.decode((StringRR, IntRR).self, context: .default)) + + XCTAssertEqual(cells?.0, stringValue) + XCTAssertEqual(cells?.1, intValue) + } + + func testRoundTripUUID() { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + let eventLoop = eventLoopGroup.next() + + var conn: PostgresConnection? + XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) + defer { XCTAssertNoThrow(try conn?.close().wait()) } + + let uuidString = "2c68f645-9ca6-468b-b193-ee97f241c2f8" + + var result: PostgresQueryResult? + XCTAssertNoThrow(result = try conn?.query(""" + SELECT \(uuidString)::UUID as uuid + """, + logger: .psqlTest + ).wait()) + + XCTAssertEqual(result?.rows.count, 1) + XCTAssertEqual(try result?.rows.first?.decode(UUID.self, context: .default), UUID(uuidString: uuidString)) + } + + func testRoundTripJSONB() { + struct Object: Codable, PostgresCodable { + let foo: Int + let bar: Int + } + + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + let eventLoop = eventLoopGroup.next() + + var conn: PostgresConnection? + XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) + defer { XCTAssertNoThrow(try conn?.close().wait()) } + + do { + var result: PostgresQueryResult? + XCTAssertNoThrow(result = try conn?.query(""" + select \(Object(foo: 1, bar: 2))::jsonb as jsonb + """, logger: .psqlTest).wait()) + + XCTAssertEqual(result?.rows.count, 1) + var obj: Object? + XCTAssertNoThrow(obj = try result?.rows.first?.decode(Object.self, context: .default)) + XCTAssertEqual(obj?.foo, 1) + XCTAssertEqual(obj?.bar, 2) + } + + do { + var result: PostgresQueryResult? + XCTAssertNoThrow(result = try conn?.query(""" + select \(Object(foo: 1, bar: 2))::json as json + """, logger: .psqlTest).wait()) + + XCTAssertEqual(result?.rows.count, 1) + var obj: Object? + XCTAssertNoThrow(obj = try result?.rows.first?.decode(Object.self, context: .default)) + XCTAssertEqual(obj?.foo, 1) + XCTAssertEqual(obj?.bar, 2) + } + } + +} diff --git a/Tests/PostgresNIOTests/PerformanceTests.swift b/Tests/IntegrationTests/PerformanceTests.swift similarity index 80% rename from Tests/PostgresNIOTests/PerformanceTests.swift rename to Tests/IntegrationTests/PerformanceTests.swift index a26748c4..6f730560 100644 --- a/Tests/PostgresNIOTests/PerformanceTests.swift +++ b/Tests/IntegrationTests/PerformanceTests.swift @@ -1,6 +1,8 @@ -import Logging -import PostgresNIO import XCTest +import Logging +import NIOCore +import NIOPosix +@testable import PostgresNIO import NIOTestUtils final class PerformanceTests: XCTestCase { @@ -36,7 +38,7 @@ final class PerformanceTests: XCTestCase { do { for _ in 0..<5 { try conn.query("SELECT * FROM generate_series(1, 10000) num") { row in - _ = row.column("num")?.int + _ = try row.decode(Int.self, context: .default) }.wait() } } catch { @@ -63,14 +65,15 @@ final class PerformanceTests: XCTestCase { measure { do { try conn.query("SELECT * FROM \"measureSelectPerformance\"") { row in - _ = row.column("int")?.int - }.wait() + _ = try row.decode(Int.self, context: .default) + }.wait() } catch { XCTFail("\(error)") } } } + @available(*, deprecated, message: "Testing deprecated functionality") func testPerformanceSelectMediumModel() throws { let conn = try PostgresConnection.test(on: eventLoop).wait() defer { try! conn.close().wait() } @@ -99,12 +102,13 @@ final class PerformanceTests: XCTestCase { measure { do { - try conn.query("SELECT * FROM \"measureSelectPerformance\"") { row in - _ = row.column("id")?.int - _ = row.column("string")?.string - _ = row.column("int")?.int - _ = row.column("date")?.date - _ = row.column("uuid")?.uuid + try conn.query("SELECT * FROM \"measureSelectPerformance\"") { + let row = $0.makeRandomAccess() + _ = row[data: "id"].int + _ = row[data: "string"].string + _ = row[data: "int"].int + _ = row[data: "date"].date + _ = row[data: "uuid"].uuid }.wait() } catch { XCTFail("\(error)") @@ -112,6 +116,7 @@ final class PerformanceTests: XCTestCase { } } + @available(*, deprecated, message: "Testing deprecated functionality") func testPerformanceSelectLargeModel() throws { let conn = try PostgresConnection.test(on: eventLoop).wait() defer { try! conn.close().wait() } @@ -172,28 +177,29 @@ final class PerformanceTests: XCTestCase { measure { do { - try conn.query("SELECT * FROM \"measureSelectPerformance\"") { row in - _ = row.column("id")?.int - _ = row.column("string1")?.string - _ = row.column("string2")?.string - _ = row.column("string3")?.string - _ = row.column("string4")?.string - _ = row.column("string5")?.string - _ = row.column("int1")?.int - _ = row.column("int2")?.int - _ = row.column("int3")?.int - _ = row.column("int4")?.int - _ = row.column("int5")?.int - _ = row.column("date1")?.date - _ = row.column("date2")?.date - _ = row.column("date3")?.date - _ = row.column("date4")?.date - _ = row.column("date5")?.date - _ = row.column("uuid1")?.uuid - _ = row.column("uuid2")?.uuid - _ = row.column("uuid3")?.uuid - _ = row.column("uuid4")?.uuid - _ = row.column("uuid5")?.uuid + try conn.query("SELECT * FROM \"measureSelectPerformance\"") { + let row = $0.makeRandomAccess() + _ = row[data: "id"].int + _ = row[data: "string1"].string + _ = row[data: "string2"].string + _ = row[data: "string3"].string + _ = row[data: "string4"].string + _ = row[data: "string5"].string + _ = row[data: "int1"].int + _ = row[data: "int2"].int + _ = row[data: "int3"].int + _ = row[data: "int4"].int + _ = row[data: "int5"].int + _ = row[data: "date1"].date + _ = row[data: "date2"].date + _ = row[data: "date3"].date + _ = row[data: "date4"].date + _ = row[data: "date5"].date + _ = row[data: "uuid1"].uuid + _ = row[data: "uuid2"].uuid + _ = row[data: "uuid3"].uuid + _ = row[data: "uuid4"].uuid + _ = row[data: "uuid5"].uuid }.wait() } catch { XCTFail("\(error)") @@ -217,10 +223,11 @@ final class PerformanceTests: XCTestCase { measure { do { - try conn.query("SELECT * FROM \"measureSelectPerformance\"") { row in - _ = row.column("id")?.int + try conn.query("SELECT * FROM \"measureSelectPerformance\"") { + let row = $0.makeRandomAccess() + _ = row[data: "id"].int for fieldName in fieldNames { - _ = row.column(fieldName)?.int + _ = row[data: fieldName].int } }.wait() } catch { @@ -245,10 +252,11 @@ final class PerformanceTests: XCTestCase { measure { do { - try conn.query("SELECT * FROM \"measureSelectPerformance\"") { row in - _ = row.column("id")?.int + try conn.query("SELECT * FROM \"measureSelectPerformance\"") { + let row = $0.makeRandomAccess() + _ = row[data: "id"].int for fieldName in fieldNames { - _ = row.column(fieldName)?.int + _ = row[data: fieldName].int } }.wait() } catch { @@ -265,7 +273,7 @@ private func prepareTableToMeasureSelectPerformance( schema: String, fixtureData: [PostgresData], on eventLoop: EventLoop, - file: StaticString = #file, + file: StaticString = #filePath, line: UInt = #line ) throws { XCTAssertEqual(rowCount % batchSize, 0, "`rowCount` must be a multiple of `batchSize`", file: (file), line: line) diff --git a/Tests/IntegrationTests/PostgresClientTests.swift b/Tests/IntegrationTests/PostgresClientTests.swift new file mode 100644 index 00000000..34a8ad2a --- /dev/null +++ b/Tests/IntegrationTests/PostgresClientTests.swift @@ -0,0 +1,317 @@ +@_spi(ConnectionPool) import PostgresNIO +import XCTest +import NIOPosix +import NIOSSL +import Logging +import Atomics + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +final class PostgresClientTests: XCTestCase { + + func testGetConnection() async throws { + var mlogger = Logger(label: "test") + mlogger.logLevel = .debug + let logger = mlogger + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 8) + self.addTeardownBlock { + try await eventLoopGroup.shutdownGracefully() + } + + let clientConfig = PostgresClient.Configuration.makeTestConfiguration() + let client = PostgresClient(configuration: clientConfig, eventLoopGroup: eventLoopGroup, backgroundLogger: logger) + + await withThrowingTaskGroup(of: Void.self) { taskGroup in + taskGroup.addTask { + await client.run() + } + + let iterations = 1000 + + for _ in 0.. PostgresBindings { + var bindings = PostgresBindings() + bindings.append(self.id) + return bindings + } + func decodeRow(_ row: PostgresNIO.PostgresRow) throws -> Row { + try row.decode(Row.self) + } + } + + for try await (id, uuid) in try await client.execute(Example(id: 200), logger: logger) { + logger.info("id: \(id), uuid: \(uuid.uuidString)") + } + + try await client.query( + """ + DROP TABLE "\(unescaped: tableName)"; + """, + logger: logger + ) + + taskGroup.cancelAll() + } + } catch { + XCTFail("Unexpected error: \(String(reflecting: error))") + } + } +} + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +extension PostgresClient.Configuration { + static func makeTestConfiguration() -> PostgresClient.Configuration { + var tlsConfiguration = TLSConfiguration.makeClientConfiguration() + tlsConfiguration.certificateVerification = .none + var clientConfig = PostgresClient.Configuration( + host: env("POSTGRES_HOSTNAME") ?? "localhost", + port: env("POSTGRES_PORT").flatMap({ Int($0) }) ?? 5432, + username: env("POSTGRES_USER") ?? "test_username", + password: env("POSTGRES_PASSWORD") ?? "test_password", + database: env("POSTGRES_DB") ?? "test_database", + tls: .prefer(tlsConfiguration) + ) + clientConfig.options.minimumConnections = 0 + clientConfig.options.maximumConnections = 12*4 + clientConfig.options.keepAliveBehavior = .init(frequency: .seconds(5)) + clientConfig.options.connectionIdleTimeout = .seconds(15) + + return clientConfig + } +} diff --git a/Tests/IntegrationTests/PostgresNIOTests.swift b/Tests/IntegrationTests/PostgresNIOTests.swift new file mode 100644 index 00000000..9a58f050 --- /dev/null +++ b/Tests/IntegrationTests/PostgresNIOTests.swift @@ -0,0 +1,1467 @@ +import Logging +@testable import PostgresNIO +import Atomics +import XCTest +import NIOCore +import NIOPosix +import NIOTestUtils +import NIOSSL + +final class PostgresNIOTests: XCTestCase { + + private var group: EventLoopGroup! + private var eventLoop: EventLoop { self.group.next() } + + override class func setUp() { + XCTAssertTrue(isLoggingConfigured) + } + + override func setUpWithError() throws { + try super.setUpWithError() + self.group = MultiThreadedEventLoopGroup(numberOfThreads: 1) + } + + override func tearDownWithError() throws { + try self.group?.syncShutdownGracefully() + self.group = nil + try super.tearDownWithError() + } + + // MARK: Tests + + func testConnectAndClose() { + var conn: PostgresConnection? + XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) + XCTAssertNoThrow(try conn?.close().wait()) + } + + func testConnectUDSAndClose() throws { + try XCTSkipUnless(env("POSTGRES_SOCKET") != nil) + let conn = try PostgresConnection.testUDS(on: eventLoop).wait() + try conn.close().wait() + } + + func testConnectEstablishedChannelAndClose() throws { + let channel = try ClientBootstrap(group: self.group).connect(to: PostgresConnection.address()).wait() + let conn = try PostgresConnection.testChannel(channel, on: self.eventLoop).wait() + try conn.close().wait() + } + + func testSimpleQueryVersion() { + var conn: PostgresConnection? + XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) + defer { XCTAssertNoThrow( try conn?.close().wait() ) } + var rows: [PostgresRow]? + XCTAssertNoThrow(rows = try conn?.simpleQuery("SELECT version()").wait()) + XCTAssertEqual(rows?.count, 1) + XCTAssertEqual(try rows?.first?.decode(String.self, context: .default).contains("PostgreSQL"), true) + } + + func testSimpleQueryVersionUsingUDS() throws { + try XCTSkipUnless(env("POSTGRES_SOCKET") != nil) + var conn: PostgresConnection? + XCTAssertNoThrow(conn = try PostgresConnection.testUDS(on: eventLoop).wait()) + defer { XCTAssertNoThrow( try conn?.close().wait() ) } + var rows: [PostgresRow]? + XCTAssertNoThrow(rows = try conn?.simpleQuery("SELECT version()").wait()) + XCTAssertEqual(rows?.count, 1) + XCTAssertEqual(try rows?.first?.decode(String.self, context: .default).contains("PostgreSQL"), true) + } + + func testSimpleQueryVersionUsingEstablishedChannel() throws { + let channel = try ClientBootstrap(group: self.group).connect(to: PostgresConnection.address()).wait() + let conn = try PostgresConnection.testChannel(channel, on: self.eventLoop).wait() + defer { XCTAssertNoThrow(try conn.close().wait()) } + + let rows = try conn.simpleQuery("SELECT version()").wait() + XCTAssertEqual(rows.count, 1) + XCTAssertEqual(try rows.first?.decode(String.self, context: .default).contains("PostgreSQL"), true) + } + + func testQueryVersion() { + var conn: PostgresConnection? + XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) + defer { XCTAssertNoThrow( try conn?.close().wait() ) } + var rows: PostgresQueryResult? + XCTAssertNoThrow(rows = try conn?.query("SELECT version()", .init()).wait()) + XCTAssertEqual(rows?.count, 1) + XCTAssertEqual(try rows?.first?.decode(String.self, context: .default).contains("PostgreSQL"), true) + } + + func testQuerySelectParameter() { + var conn: PostgresConnection? + XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) + defer { XCTAssertNoThrow( try conn?.close().wait() ) } + var rows: PostgresQueryResult? + XCTAssertNoThrow(rows = try conn?.query("SELECT $1::TEXT as foo", ["hello"]).wait()) + XCTAssertEqual(rows?.count, 1) + XCTAssertEqual(try rows?.first?.decode(String.self, context: .default), "hello") + } + + func testSQLError() throws { + var conn: PostgresConnection? + XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) + defer { XCTAssertNoThrow( try conn?.close().wait() ) } + + XCTAssertThrowsError(_ = try conn?.simpleQuery("SELECT &").wait()) { error in + XCTAssertEqual((error as? PostgresError)?.code, .syntaxError) + } + } + + func testNotificationsEmptyPayload() { + var conn: PostgresConnection? + XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) + defer { XCTAssertNoThrow( try conn?.close().wait() ) } + + let receivedNotifications = ManagedAtomic(0) + conn?.addListener(channel: "example") { context, notification in + receivedNotifications.wrappingIncrement(ordering: .relaxed) + XCTAssertEqual(notification.channel, "example") + XCTAssertEqual(notification.payload, "") + } + XCTAssertNoThrow(_ = try conn?.simpleQuery("LISTEN example").wait()) + XCTAssertNoThrow(_ = try conn?.simpleQuery("NOTIFY example").wait()) + // Notifications are asynchronous, so we should run at least one more query to make sure we'll have received the notification response by then + XCTAssertNoThrow(_ = try conn?.simpleQuery("SELECT 1").wait()) + XCTAssertEqual(receivedNotifications.load(ordering: .relaxed), 1) + } + + func testNotificationsNonEmptyPayload() { + var conn: PostgresConnection? + XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) + defer { XCTAssertNoThrow( try conn?.close().wait() ) } + let receivedNotifications = ManagedAtomic(0) + conn?.addListener(channel: "example") { context, notification in + receivedNotifications.wrappingIncrement(ordering: .relaxed) + XCTAssertEqual(notification.channel, "example") + XCTAssertEqual(notification.payload, "Notification payload example") + } + XCTAssertNoThrow(_ = try conn?.simpleQuery("LISTEN example").wait()) + XCTAssertNoThrow(_ = try conn?.simpleQuery("NOTIFY example, 'Notification payload example'").wait()) + // Notifications are asynchronous, so we should run at least one more query to make sure we'll have received the notification response by then + XCTAssertNoThrow(_ = try conn?.simpleQuery("SELECT 1").wait()) + XCTAssertEqual(receivedNotifications.load(ordering: .relaxed), 1) + } + + func testNotificationsRemoveHandlerWithinHandler() { + var conn: PostgresConnection? + XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) + defer { XCTAssertNoThrow( try conn?.close().wait() ) } + let receivedNotifications = ManagedAtomic(0) + conn?.addListener(channel: "example") { context, notification in + receivedNotifications.wrappingIncrement(ordering: .relaxed) + context.stop() + } + XCTAssertNoThrow(_ = try conn?.simpleQuery("LISTEN example").wait()) + XCTAssertNoThrow(_ = try conn?.simpleQuery("NOTIFY example").wait()) + XCTAssertNoThrow(_ = try conn?.simpleQuery("NOTIFY example").wait()) + XCTAssertNoThrow(_ = try conn?.simpleQuery("SELECT 1").wait()) + XCTAssertEqual(receivedNotifications.load(ordering: .relaxed), 1) + } + + func testNotificationsRemoveHandlerOutsideHandler() { + var conn: PostgresConnection? + XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) + defer { XCTAssertNoThrow( try conn?.close().wait() ) } + let receivedNotifications = ManagedAtomic(0) + let context = conn?.addListener(channel: "example") { context, notification in + receivedNotifications.wrappingIncrement(ordering: .relaxed) + } + XCTAssertNotNil(context) + XCTAssertNoThrow(_ = try conn?.simpleQuery("LISTEN example").wait()) + XCTAssertNoThrow(_ = try conn?.simpleQuery("NOTIFY example").wait()) + XCTAssertNoThrow(_ = try conn?.simpleQuery("SELECT 1").wait()) + context?.stop() + XCTAssertNoThrow(_ = try conn?.simpleQuery("NOTIFY example").wait()) + XCTAssertNoThrow(_ = try conn?.simpleQuery("SELECT 1").wait()) + XCTAssertEqual(receivedNotifications.load(ordering: .relaxed), 1) + } + + func testNotificationsMultipleRegisteredHandlers() { + var conn: PostgresConnection? + XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) + defer { XCTAssertNoThrow( try conn?.close().wait() ) } + let receivedNotifications1 = ManagedAtomic(0) + conn?.addListener(channel: "example") { context, notification in + receivedNotifications1.wrappingIncrement(ordering: .relaxed) + } + let receivedNotifications2 = ManagedAtomic(0) + conn?.addListener(channel: "example") { context, notification in + receivedNotifications2.wrappingIncrement(ordering: .relaxed) + } + XCTAssertNoThrow(_ = try conn?.simpleQuery("LISTEN example").wait()) + XCTAssertNoThrow(_ = try conn?.simpleQuery("NOTIFY example").wait()) + XCTAssertNoThrow(_ = try conn?.simpleQuery("SELECT 1").wait()) + XCTAssertEqual(receivedNotifications1.load(ordering: .relaxed), 1) + XCTAssertEqual(receivedNotifications2.load(ordering: .relaxed), 1) + } + + func testNotificationsMultipleRegisteredHandlersRemoval() throws { + var conn: PostgresConnection? + XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) + defer { XCTAssertNoThrow( try conn?.close().wait() ) } + let receivedNotifications1 = ManagedAtomic(0) + XCTAssertNotNil(conn?.addListener(channel: "example") { context, notification in + receivedNotifications1.wrappingIncrement(ordering: .relaxed) + context.stop() + }) + let receivedNotifications2 = ManagedAtomic(0) + XCTAssertNotNil(conn?.addListener(channel: "example") { context, notification in + receivedNotifications2.wrappingIncrement(ordering: .relaxed) + }) + XCTAssertNoThrow(_ = try conn?.simpleQuery("LISTEN example").wait()) + XCTAssertNoThrow(_ = try conn?.simpleQuery("NOTIFY example").wait()) + XCTAssertNoThrow(_ = try conn?.simpleQuery("NOTIFY example").wait()) + XCTAssertNoThrow(_ = try conn?.simpleQuery("SELECT 1").wait()) + XCTAssertEqual(receivedNotifications1.load(ordering: .relaxed), 1) + XCTAssertEqual(receivedNotifications2.load(ordering: .relaxed), 2) + } + + func testNotificationHandlerFiltersOnChannel() { + var conn: PostgresConnection? + XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) + defer { XCTAssertNoThrow( try conn?.close().wait() ) } + XCTAssertNotNil(conn?.addListener(channel: "desired") { context, notification in + XCTFail("Received notification on channel that handler was not registered for") + }) + XCTAssertNoThrow(_ = try conn?.simpleQuery("LISTEN undesired").wait()) + XCTAssertNoThrow(_ = try conn?.simpleQuery("NOTIFY undesired").wait()) + XCTAssertNoThrow(_ = try conn?.simpleQuery("SELECT 1").wait()) + } + + func testSelectTypes() throws { + var conn: PostgresConnection? + XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) + defer { XCTAssertNoThrow( try conn?.close().wait() ) } + var results: [PostgresRow]? + XCTAssertNoThrow(results = try conn?.simpleQuery("SELECT * FROM pg_type").wait()) + XCTAssert((results?.count ?? 0) > 350, "Results count not large enough") + } + + func testSelectType() { + var conn: PostgresConnection? + XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) + defer { XCTAssertNoThrow( try conn?.close().wait() ) } + var results: [PostgresRow]? + XCTAssertNoThrow(results = try conn?.simpleQuery("SELECT * FROM pg_type WHERE typname = 'float8'").wait()) + // [ + // "typreceive": "float8recv", + // "typelem": "0", + // "typarray": "1022", + // "typalign": "d", + // "typanalyze": "-", + // "typtypmod": "-1", + // "typname": "float8", + // "typnamespace": "11", + // "typdefault": "", + // "typdefaultbin": "", + // "typcollation": "0", + // "typispreferred": "t", + // "typrelid": "0", + // "typbyval": "t", + // "typnotnull": "f", + // "typinput": "float8in", + // "typlen": "8", + // "typcategory": "N", + // "typowner": "10", + // "typtype": "b", + // "typdelim": ",", + // "typndims": "0", + // "typbasetype": "0", + // "typacl": "", + // "typisdefined": "t", + // "typmodout": "-", + // "typmodin": "-", + // "typsend": "float8send", + // "typstorage": "p", + // "typoutput": "float8out" + // ] + XCTAssertEqual(results?.count, 1) + let row = results?.first?.makeRandomAccess() + XCTAssertEqual(row?[data: "typname"].string, "float8") + XCTAssertEqual(row?[data: "typnamespace"].int, 11) + XCTAssertEqual(row?[data: "typowner"].int, 10) + XCTAssertEqual(row?[data: "typlen"].int, 8) + } + + func testIntegers() { + var conn: PostgresConnection? + XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) + defer { XCTAssertNoThrow( try conn?.close().wait() ) } + struct Integers: Decodable { + let smallint: Int16 + let smallint_min: Int16 + let smallint_max: Int16 + let int: Int32 + let int_min: Int32 + let int_max: Int32 + let bigint: Int64 + let bigint_min: Int64 + let bigint_max: Int64 + } + var results: PostgresQueryResult? + XCTAssertNoThrow(results = try conn?.query(""" + SELECT + 1::SMALLINT as smallint, + -32767::SMALLINT as smallint_min, + 32767::SMALLINT as smallint_max, + 1::INT as int, + -2147483647::INT as int_min, + 2147483647::INT as int_max, + 1::BIGINT as bigint, + -9223372036854775807::BIGINT as bigint_min, + 9223372036854775807::BIGINT as bigint_max + """).wait()) + XCTAssertEqual(results?.count, 1) + + let row = results?.first?.makeRandomAccess() + XCTAssertEqual(row?[data: "smallint"].int16, 1) + XCTAssertEqual(row?[data: "smallint_min"].int16, -32_767) + XCTAssertEqual(row?[data: "smallint_max"].int16, 32_767) + XCTAssertEqual(row?[data: "int"].int32, 1) + XCTAssertEqual(row?[data: "int_min"].int32, -2_147_483_647) + XCTAssertEqual(row?[data: "int_max"].int32, 2_147_483_647) + XCTAssertEqual(row?[data: "bigint"].int64, 1) + XCTAssertEqual(row?[data: "bigint_min"].int64, -9_223_372_036_854_775_807) + XCTAssertEqual(row?[data: "bigint_max"].int64, 9_223_372_036_854_775_807) + } + + func testPi() { + var conn: PostgresConnection? + XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) + defer { XCTAssertNoThrow( try conn?.close().wait() ) } + + struct Pi: Decodable { + let text: String + let numeric_string: String + let numeric_decimal: Decimal + let double: Double + let float: Float + } + var results: PostgresQueryResult? + XCTAssertNoThrow(results = try conn?.query(""" + SELECT + pi()::TEXT as text, + pi()::NUMERIC as numeric_string, + pi()::NUMERIC as numeric_decimal, + pi()::FLOAT8 as double, + pi()::FLOAT4 as float + """).wait()) + XCTAssertEqual(results?.count, 1) + let row = results?.first?.makeRandomAccess() + XCTAssertEqual(row?[data: "text"].string?.hasPrefix("3.14159265"), true) + XCTAssertEqual(row?[data: "numeric_string"].string?.hasPrefix("3.14159265"), true) + XCTAssertTrue(row?[data: "numeric_decimal"].decimal?.isLess(than: 3.14159265358980) ?? false) + XCTAssertFalse(row?[data: "numeric_decimal"].decimal?.isLess(than: 3.14159265358978) ?? true) + XCTAssertTrue(row?[data: "double"].double?.description.hasPrefix("3.141592") ?? false) + XCTAssertTrue(row?[data: "float"].float?.description.hasPrefix("3.141592") ?? false) + } + + func testUUID() { + var conn: PostgresConnection? + XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) + defer { XCTAssertNoThrow( try conn?.close().wait() ) } + struct Model: Decodable { + let id: UUID + let string: String + } + var results: PostgresQueryResult? + XCTAssertNoThrow(results = try conn?.query(""" + SELECT + '123e4567-e89b-12d3-a456-426655440000'::UUID as id, + '123e4567-e89b-12d3-a456-426655440000'::UUID as string + """).wait()) + XCTAssertEqual(results?.count, 1) + let row = results?.first?.makeRandomAccess() + XCTAssertEqual(row?[data: "id"].uuid, UUID(uuidString: "123E4567-E89B-12D3-A456-426655440000")) + XCTAssertEqual(UUID(uuidString: row?[data: "id"].string ?? ""), UUID(uuidString: "123E4567-E89B-12D3-A456-426655440000")) + } + + func testInt4Range() async throws { + let conn: PostgresConnection = try await PostgresConnection.test(on: eventLoop).get() + self.addTeardownBlock { + try await conn.close() + } + struct Model: Decodable { + let range: Range + } + let results1: PostgresQueryResult = try await conn.query(""" + SELECT + '[\(Int32.min), \(Int32.max))'::int4range AS range + """).get() + XCTAssertEqual(results1.count, 1) + var row = results1.first?.makeRandomAccess() + let expectedRange: Range = Int32.min...self, context: .default) + XCTAssertEqual(decodedRange, expectedRange) + + let results2 = try await conn.query(""" + SELECT + ARRAY[ + '[0, 1)'::int4range, + '[10, 11)'::int4range + ] AS ranges + """).get() + XCTAssertEqual(results2.count, 1) + row = results2.first?.makeRandomAccess() + let decodedRangeArray = try row?.decode(column: "ranges", as: [Range].self, context: .default) + let decodedClosedRangeArray = try row?.decode(column: "ranges", as: [ClosedRange].self, context: .default) + XCTAssertEqual(decodedRangeArray, [0..<1, 10..<11]) + XCTAssertEqual(decodedClosedRangeArray, [0...0, 10...10]) + } + + func testEmptyInt4Range() async throws { + let conn: PostgresConnection = try await PostgresConnection.test(on: eventLoop).get() + self.addTeardownBlock { + try await conn.close() + } + struct Model: Decodable { + let range: Range + } + let randomValue = Int32.random(in: Int32.min...Int32.max) + let results: PostgresQueryResult = try await conn.query(""" + SELECT + '[\(randomValue),\(randomValue))'::int4range AS range + """).get() + XCTAssertEqual(results.count, 1) + let row = results.first?.makeRandomAccess() + let expectedRange: Range = Int32.valueForEmptyRange...self, context: .default) + XCTAssertEqual(decodedRange, expectedRange) + + XCTAssertThrowsError( + try row?.decode(column: "range", as: ClosedRange.self, context: .default) + ) + } + + func testInt8Range() async throws { + let conn: PostgresConnection = try await PostgresConnection.test(on: eventLoop).get() + self.addTeardownBlock { + try await conn.close() + } + struct Model: Decodable { + let range: Range + } + let results1: PostgresQueryResult = try await conn.query(""" + SELECT + '[\(Int64.min), \(Int64.max))'::int8range AS range + """).get() + XCTAssertEqual(results1.count, 1) + var row = results1.first?.makeRandomAccess() + let expectedRange: Range = Int64.min...self, context: .default) + XCTAssertEqual(decodedRange, expectedRange) + + let results2: PostgresQueryResult = try await conn.query(""" + SELECT + ARRAY[ + '[0, 1)'::int8range, + '[10, 11)'::int8range + ] AS ranges + """).get() + XCTAssertEqual(results2.count, 1) + row = results2.first?.makeRandomAccess() + let decodedRangeArray = try row?.decode(column: "ranges", as: [Range].self, context: .default) + let decodedClosedRangeArray = try row?.decode(column: "ranges", as: [ClosedRange].self, context: .default) + XCTAssertEqual(decodedRangeArray, [0..<1, 10..<11]) + XCTAssertEqual(decodedClosedRangeArray, [0...0, 10...10]) + } + + func testEmptyInt8Range() async throws { + let conn: PostgresConnection = try await PostgresConnection.test(on: eventLoop).get() + self.addTeardownBlock { + try await conn.close() + } + struct Model: Decodable { + let range: Range + } + let randomValue = Int64.random(in: Int64.min...Int64.max) + let results: PostgresQueryResult = try await conn.query(""" + SELECT + '[\(randomValue),\(randomValue))'::int8range AS range + """).get() + XCTAssertEqual(results.count, 1) + let row = results.first?.makeRandomAccess() + let expectedRange: Range = Int64.valueForEmptyRange...self, context: .default) + XCTAssertEqual(decodedRange, expectedRange) + + XCTAssertThrowsError( + try row?.decode(column: "range", as: ClosedRange.self, context: .default) + ) + } + + func testDates() { + var conn: PostgresConnection? + XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) + defer { XCTAssertNoThrow( try conn?.close().wait() ) } + struct Dates: Decodable { + var date: Date + var timestamp: Date + var timestamptz: Date + } + var results: PostgresQueryResult? + XCTAssertNoThrow(results = try conn?.query(""" + SELECT + '2016-01-18 01:02:03 +0042'::DATE as date, + '2016-01-18 01:02:03 +0042'::TIMESTAMP as timestamp, + '2016-01-18 01:02:03 +0042'::TIMESTAMPTZ as timestamptz + """).wait()) + XCTAssertEqual(results?.count, 1) + let row = results?.first?.makeRandomAccess() + XCTAssertEqual(row?[data: "date"].date?.description, "2016-01-18 00:00:00 +0000") + XCTAssertEqual(row?[data: "timestamp"].date?.description, "2016-01-18 01:02:03 +0000") + XCTAssertEqual(row?[data: "timestamptz"].date?.description, "2016-01-18 00:20:03 +0000") + } + + /// https://github.com/vapor/nio-postgres/issues/20 + func testBindInteger() { + var conn: PostgresConnection? + XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) + defer { XCTAssertNoThrow( try conn?.close().wait() ) } + XCTAssertNoThrow(_ = try conn?.simpleQuery("drop table if exists person;").wait()) + XCTAssertNoThrow(_ = try conn?.simpleQuery("create table person(id serial primary key, first_name text, last_name text);").wait()) + defer { XCTAssertNoThrow(_ = try conn?.simpleQuery("drop table person;").wait()) } + let id = PostgresData(int32: 5) + XCTAssertNoThrow(_ = try conn?.query("SELECT id, first_name, last_name FROM person WHERE id = $1", [id]).wait()) + } + + // https://github.com/vapor/nio-postgres/issues/21 + func testAverageLengthNumeric() { + var conn: PostgresConnection? + XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) + defer { XCTAssertNoThrow( try conn?.close().wait() ) } + var results: PostgresQueryResult? + XCTAssertNoThrow(results = try conn?.query("select avg(length('foo')) as average_length").wait()) + let row = results?.first?.makeRandomAccess() + XCTAssertEqual(row?[data: 0].double, 3.0) + } + + func testNumericParsing() { + var conn: PostgresConnection? + XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) + defer { XCTAssertNoThrow( try conn?.close().wait() ) } + var rows: PostgresQueryResult? + XCTAssertNoThrow(rows = try conn?.query(""" + select + '1234.5678'::numeric as a, + '-123.456'::numeric as b, + '123456.789123'::numeric as c, + '3.14159265358979'::numeric as d, + '10000'::numeric as e, + '0.00001'::numeric as f, + '100000000'::numeric as g, + '0.000000001'::numeric as h, + '100000000000'::numeric as i, + '0.000000000001'::numeric as j, + '123000000000'::numeric as k, + '0.000000000123'::numeric as l, + '0.5'::numeric as m + """).wait()) + XCTAssertEqual(rows?.count, 1) + let row = rows?.first?.makeRandomAccess() + XCTAssertEqual(row?[data: "a"].string, "1234.5678") + XCTAssertEqual(row?[data: "b"].string, "-123.456") + XCTAssertEqual(row?[data: "c"].string, "123456.789123") + XCTAssertEqual(row?[data: "d"].string, "3.14159265358979") + XCTAssertEqual(row?[data: "e"].string, "10000") + XCTAssertEqual(row?[data: "f"].string, "0.00001") + XCTAssertEqual(row?[data: "g"].string, "100000000") + XCTAssertEqual(row?[data: "h"].string, "0.000000001") + XCTAssertEqual(row?[data: "k"].string, "123000000000") + XCTAssertEqual(row?[data: "l"].string, "0.000000000123") + XCTAssertEqual(row?[data: "m"].string, "0.5") + } + + func testSingleNumericParsing() { + // this seemingly duped test is useful for debugging numeric parsing + var conn: PostgresConnection? + XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) + defer { XCTAssertNoThrow( try conn?.close().wait() ) } + let numeric = "790226039477542363.6032384900176272473" + var rows: PostgresQueryResult? + XCTAssertNoThrow(rows = try conn?.query(""" + select + '\(numeric)'::numeric as n + """).wait()) + let row = rows?.first?.makeRandomAccess() + XCTAssertEqual(row?[data: "n"].string, numeric) + } + + func testRandomlyGeneratedNumericParsing() throws { + // this test takes a long time to run + try XCTSkipUnless(Self.shouldRunLongRunningTests) + + var conn: PostgresConnection? + XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) + defer { XCTAssertNoThrow( try conn?.close().wait() ) } + + for _ in 0..<1_000_000 { + let integer = UInt.random(in: UInt.min.. = Int32.min..? = try row?.decode(Range.self, context: .default) + XCTAssertEqual(range, decodedRange) + } + do { + let emptyRange: Range = Int32.min..? = try row?.decode(Range.self, context: .default) + let expectedRange: Range = Int32.valueForEmptyRange.. = Int32.min...(Int32.max - 1) + var binds = PostgresBindings() + binds.append(closedRange, context: .default) + let query = PostgresQuery( + unsafeSQL: "select $1::int4range as range", + binds: binds + ) + let rowSequence: PostgresRowSequence? = try await conn.query(query, logger: .psqlTest) + var rowIterator: PostgresRowSequence.AsyncIterator? = rowSequence?.makeAsyncIterator() + let row: PostgresRow? = try await rowIterator?.next() + let decodedClosedRange: ClosedRange? = try row?.decode(ClosedRange.self, context: .default) + XCTAssertEqual(closedRange, decodedClosedRange) + } + } + + func testInt8RangeSerialize() async throws { + let conn: PostgresConnection = try await PostgresConnection.test(on: eventLoop).get() + self.addTeardownBlock { + try await conn.close() + } + do { + let range: Range = Int64.min..? = try row?.decode(Range.self, context: .default) + XCTAssertEqual(range, decodedRange) + } + do { + let emptyRange: Range = Int64.min..? = try row?.decode(Range.self, context: .default) + let expectedRange: Range = Int64.valueForEmptyRange.. = Int64.min...(Int64.max - 1) + var binds = PostgresBindings() + binds.append(closedRange, context: .default) + let query = PostgresQuery( + unsafeSQL: "select $1::int8range as range", + binds: binds + ) + let rowSequence: PostgresRowSequence? = try await conn.query(query, logger: .psqlTest) + var rowIterator: PostgresRowSequence.AsyncIterator? = rowSequence?.makeAsyncIterator() + let row: PostgresRow? = try await rowIterator?.next() + let decodedClosedRange: ClosedRange? = try row?.decode(ClosedRange.self, context: .default) + XCTAssertEqual(closedRange, decodedClosedRange) + } + } + + @available(*, deprecated, message: "Test deprecated functionality") + func testFailingTLSConnectionClosesConnection() { + // There was a bug (https://github.com/vapor/postgres-nio/issues/133) where we would hit + // an assert because we didn't close the connection. This test should succeed without hitting + // the assert + + // postgres://uymgphwj:7_tHbREdRwkqAdu4KoIS7hQnNxr8J1LA@elmer.db.elephantsql.com:5432/uymgphwj + + // We should get an error because you can't use an IP address for SNI, but we shouldn't bomb out by + // hitting the assert + XCTAssertThrowsError( + try PostgresConnection.connect( + to: SocketAddress.makeAddressResolvingHost("elmer.db.elephantsql.com", port: 5432), + tlsConfiguration: .makeClientConfiguration(), + serverHostname: "34.228.73.168", + on: eventLoop + ).wait() + ) + // If we hit this, we're all good + XCTAssertTrue(true) + } + + @available(*, deprecated, message: "Test deprecated functionality") + func testInvalidPassword() { + var conn: PostgresConnection? + XCTAssertNoThrow(conn = try PostgresConnection.testUnauthenticated(on: eventLoop).wait()) + let authFuture = conn?.authenticate(username: "invalid", database: "invalid", password: "bad") + XCTAssertThrowsError(_ = try authFuture?.wait()) { error in + XCTAssert((error as? PostgresError)?.code == .invalidPassword || (error as? PostgresError)?.code == .invalidAuthorizationSpecification) + } + + // in this case the connection will be closed by the remote + XCTAssertNoThrow(try conn?.closeFuture.wait()) + } + + func testColumnsInJoin() { + var conn: PostgresConnection? + XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) + defer { XCTAssertNoThrow( try conn?.close().wait() ) } + + let dateInTable1 = Date(timeIntervalSince1970: 1234) + let dateInTable2 = Date(timeIntervalSince1970: 5678) + XCTAssertNoThrow(_ = try conn?.simpleQuery("DROP TABLE IF EXISTS \"table1\"").wait()) + XCTAssertNoThrow(_ = try conn?.simpleQuery(""" + CREATE TABLE table1 ( + "id" int8 NOT NULL, + "table2_id" int8, + "intValue" int8, + "stringValue" text, + "dateValue" timestamptz, + PRIMARY KEY ("id") + ); + """).wait()) + defer { XCTAssertNoThrow(_ = try conn?.simpleQuery("DROP TABLE \"table1\"").wait()) } + + XCTAssertNoThrow(_ = try conn?.simpleQuery("DROP TABLE IF EXISTS \"table2\"").wait()) + XCTAssertNoThrow(_ = try conn?.simpleQuery(""" + CREATE TABLE table2 ( + "id" int8 NOT NULL, + "intValue" int8, + "stringValue" text, + "dateValue" timestamptz, + PRIMARY KEY ("id") + ); + """).wait()) + defer { XCTAssertNoThrow(_ = try conn?.simpleQuery("DROP TABLE \"table2\"").wait()) } + + XCTAssertNoThrow(_ = try conn?.simpleQuery("INSERT INTO table1 VALUES (12, 34, 56, 'stringInTable1', to_timestamp(1234))").wait()) + XCTAssertNoThrow(_ = try conn?.simpleQuery("INSERT INTO table2 VALUES (34, 78, 'stringInTable2', to_timestamp(5678))").wait()) + + var rows: PostgresQueryResult? + XCTAssertNoThrow(rows = try conn?.query(""" + SELECT + "table1"."id" as "t1_id", + "table1"."intValue" as "t1_intValue", + "table1"."dateValue" as "t1_dateValue", + "table1"."stringValue" as "t1_stringValue", + "table2"."id" as "t2_id", + "table2"."intValue" as "t2_intValue", + "table2"."dateValue" as "t2_dateValue", + "table2"."stringValue" as "t2_stringValue", + * + FROM table1 INNER JOIN table2 ON table1.table2_id = table2.id + """).wait()) + let row = rows?.first?.makeRandomAccess() + XCTAssertEqual(row?[data: "t1_id"].int, 12) + XCTAssertEqual(row?[data: "table2_id"].int, 34) + XCTAssertEqual(row?[data: "t1_intValue"].int, 56) + XCTAssertEqual(row?[data: "t1_stringValue"].string, "stringInTable1") + XCTAssertEqual(row?[data: "t1_dateValue"].date, dateInTable1) + XCTAssertEqual(row?[data: "t2_id"].int, 34) + XCTAssertEqual(row?[data: "t2_intValue"].int, 78) + XCTAssertEqual(row?[data: "t2_stringValue"].string, "stringInTable2") + XCTAssertEqual(row?[data: "t2_dateValue"].date, dateInTable2) + } + + @available(*, deprecated, message: "Testing deprecated functionality") + func testStringArrays() { + let query = """ + SELECT + $1::uuid as "id", + $2::bigint as "revision", + $3::timestamp as "updated_at", + $4::timestamp as "created_at", + $5::text as "name", + $6::text[] as "countries", + $7::text[] as "languages", + $8::text[] as "currencies" + """ + + var conn: PostgresConnection? + XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) + defer { XCTAssertNoThrow( try conn?.close().wait() ) } + var rows: PostgresQueryResult? + XCTAssertNoThrow(rows = try conn?.query(query, [ + PostgresData(uuid: UUID(uuidString: "D2710E16-EB07-4FD6-A87E-B1BE41C9BD3D")!), + PostgresData(int: Int(0)), + PostgresData(date: Date(timeIntervalSince1970: 0)), + PostgresData(date: Date(timeIntervalSince1970: 0)), + PostgresData(string: "Foo"), + PostgresData(array: ["US"]), + PostgresData(array: ["en"]), + PostgresData(array: ["USD", "DKK"]), + ]).wait()) + let row = rows?.first?.makeRandomAccess() + XCTAssertEqual(row?[data: "countries"].array(of: String.self), ["US"]) + XCTAssertEqual(row?[data: "languages"].array(of: String.self), ["en"]) + XCTAssertEqual(row?[data: "currencies"].array(of: String.self), ["USD", "DKK"]) + } + + func testBindDate() { + // https://github.com/vapor/postgres-nio/issues/53 + let date = Date(timeIntervalSince1970: 1571425782) + let query = """ + SELECT $1::json as "date" + """ + var conn: PostgresConnection? + XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) + defer { XCTAssertNoThrow( try conn?.close().wait() ) } + XCTAssertThrowsError(_ = try conn?.query(query, [.init(date: date)]).wait()) { error in + guard let postgresError = try? XCTUnwrap(error as? PostgresError) else { return } + guard case let .server(serverError) = postgresError else { + XCTFail("Expected a .serverError but got \(postgresError)") + return + } + XCTAssertEqual(serverError.fields[.routine], "transformTypeCast") + } + + } + + func testBindCharString() { + // https://github.com/vapor/postgres-nio/issues/53 + let query = """ + SELECT $1::char as "char" + """ + var conn: PostgresConnection? + XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) + defer { XCTAssertNoThrow( try conn?.close().wait() ) } + var rows: PostgresQueryResult? + XCTAssertNoThrow(rows = try conn?.query(query, [.init(string: "f")]).wait()) + let row = rows?.first?.makeRandomAccess() + XCTAssertEqual(row?[data: "char"].string, "f") + } + + func testBindCharUInt8() { + // https://github.com/vapor/postgres-nio/issues/53 + let query = """ + SELECT $1::char as "char" + """ + var conn: PostgresConnection? + XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) + defer { XCTAssertNoThrow( try conn?.close().wait() ) } + var rows: PostgresQueryResult? + XCTAssertNoThrow(rows = try conn?.query(query, [.init(uint8: 42)]).wait()) + let row = rows?.first?.makeRandomAccess() + XCTAssertEqual(row?[data: "char"].string, "*") + } + + @available(*, deprecated, message: "Testing deprecated functionality") + func testDoubleArraySerialization() { + var conn: PostgresConnection? + XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) + defer { XCTAssertNoThrow( try conn?.close().wait() ) } + let doubles: [Double] = [3.14, 42] + var rows: PostgresQueryResult? + XCTAssertNoThrow(rows = try conn?.query(""" + select + $1::double precision[] as doubles + """, [ + .init(array: doubles) + ]).wait()) + let row = rows?.first?.makeRandomAccess() + XCTAssertEqual(row?[data: "doubles"].array(of: Double.self), doubles) + } + + // https://github.com/vapor/postgres-nio/issues/42 + func testUInt8Serialization() { + var conn: PostgresConnection? + XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) + defer { XCTAssertNoThrow( try conn?.close().wait() ) } + var rows: PostgresQueryResult? + XCTAssertNoThrow(rows = try conn?.query(""" + select + $1::"char" as int + """, [ + .init(uint8: 5) + ]).wait()) + let row = rows?.first?.makeRandomAccess() + XCTAssertEqual(row?[data: "int"].uint8, 5) + } + + func testPreparedQuery() { + var conn: PostgresConnection? + XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) + defer { XCTAssertNoThrow( try conn?.close().wait() ) } + var prepared: PreparedQuery? + XCTAssertNoThrow(prepared = try conn?.prepare(query: "SELECT 1 as one;").wait()) + var rows: [PostgresRow]? + XCTAssertNoThrow(rows = try prepared?.execute().wait()) + + XCTAssertEqual(rows?.count, 1) + let row = rows?.first?.makeRandomAccess() + XCTAssertEqual(row?[data: "one"].int, 1) + } + + func testPrepareQueryClosure() { + var conn: PostgresConnection? + XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) + defer { XCTAssertNoThrow( try conn?.close().wait() ) } + var queries: [[PostgresRow]]? + XCTAssertNoThrow(queries = try conn?.prepare(query: "SELECT $1::text as foo;", handler: { [eventLoop] query in + let a = query.execute(["a"]) + let b = query.execute(["b"]) + let c = query.execute(["c"]) + return EventLoopFuture.whenAllSucceed([a, b, c], on: eventLoop) + }).wait()) + XCTAssertEqual(queries?.count, 3) + var resultIterator = queries?.makeIterator() + XCTAssertEqual(try resultIterator?.next()?.first?.decode(String.self, context: .default), "a") + XCTAssertEqual(try resultIterator?.next()?.first?.decode(String.self, context: .default), "b") + XCTAssertEqual(try resultIterator?.next()?.first?.decode(String.self, context: .default), "c") + } + + // https://github.com/vapor/postgres-nio/issues/122 + func testPreparedQueryNoResults() { + var conn: PostgresConnection? + XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) + defer { XCTAssertNoThrow( try conn?.close().wait() ) } + + XCTAssertNoThrow(_ = try conn?.simpleQuery("DROP TABLE IF EXISTS \"table_no_results\"").wait()) + XCTAssertNoThrow(_ = try conn?.simpleQuery(""" + CREATE TABLE table_no_results ( + "id" int8 NOT NULL, + "stringValue" text, + PRIMARY KEY ("id") + ); + """).wait()) + defer { XCTAssertNoThrow(_ = try conn?.simpleQuery("DROP TABLE \"table_no_results\"").wait()) } + + XCTAssertNoThrow(_ = try conn?.prepare(query: "DELETE FROM \"table_no_results\" WHERE id = $1").wait()) + } + + + // https://github.com/vapor/postgres-nio/issues/71 + func testChar1Serialization() { + var conn: PostgresConnection? + XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) + defer { XCTAssertNoThrow( try conn?.close().wait() ) } + var rows: PostgresQueryResult? + XCTAssertNoThrow(rows = try conn?.query(""" + select + '5'::char(1) as one, + '5'::char(2) as two + """).wait()) + + let row = rows?.first?.makeRandomAccess() + XCTAssertEqual(row?[data: "one"].uint8, 53) + XCTAssertEqual(row?[data: "one"].int16, 53) + XCTAssertEqual(row?[data: "one"].string, "5") + XCTAssertEqual(row?[data: "two"].uint8, nil) + XCTAssertEqual(row?[data: "two"].int16, nil) + XCTAssertEqual(row?[data: "two"].string, "5 ") + } + + func testUserDefinedType() { + var conn: PostgresConnection? + XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) + defer { XCTAssertNoThrow( try conn?.close().wait() ) } + + XCTAssertNoThrow(_ = try conn?.query("DROP TYPE IF EXISTS foo").wait()) + XCTAssertNoThrow(_ = try conn?.query("CREATE TYPE foo AS ENUM ('bar', 'qux')").wait()) + defer { + XCTAssertNoThrow(_ = try conn?.query("DROP TYPE foo").wait()) + } + var res: PostgresQueryResult? + XCTAssertNoThrow(res = try conn?.query("SELECT 'qux'::foo as foo").wait()) + let row = res?.first?.makeRandomAccess() + XCTAssertEqual(row?[data: "foo"].string, "qux") + } + + @available(*, deprecated, message: "Testing deprecated functionality") + func testNullBind() { + var conn: PostgresConnection? + XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) + defer { XCTAssertNoThrow( try conn?.close().wait() ) } + + var res: PostgresQueryResult? + XCTAssertNoThrow(res = try conn?.query("SELECT $1::text as foo", [String?.none.postgresData!]).wait()) + let row = res?.first?.makeRandomAccess() + XCTAssertNil(row?[data: "foo"].string) + } + + func testUpdateMetadata() { + var conn: PostgresConnection? + XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) + defer { XCTAssertNoThrow( try conn?.close().wait() ) } + XCTAssertNoThrow(_ = try conn?.simpleQuery("DROP TABLE IF EXISTS test_table").wait()) + XCTAssertNoThrow(_ = try conn?.simpleQuery("CREATE TABLE test_table(pk int PRIMARY KEY)").wait()) + XCTAssertNoThrow(_ = try conn?.simpleQuery("INSERT INTO test_table VALUES(1)").wait()) + XCTAssertNoThrow(try conn?.query("DELETE FROM test_table", onMetadata: { metadata in + XCTAssertEqual(metadata.command, "DELETE") + XCTAssertEqual(metadata.oid, nil) + XCTAssertEqual(metadata.rows, 1) + }, onRow: { _ in }).wait()) + var rows: PostgresQueryResult? + XCTAssertNoThrow(rows = try conn?.query("DELETE FROM test_table").wait()) + XCTAssertEqual(rows?.metadata.command, "DELETE") + XCTAssertEqual(rows?.metadata.oid, nil) + XCTAssertEqual(rows?.metadata.rows, 0) + } + + func testTooManyBinds() { + var conn: PostgresConnection? + XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) + defer { XCTAssertNoThrow( try conn?.close().wait() ) } + let binds = [PostgresData].init(repeating: .null, count: Int(UInt16.max) + 1) + XCTAssertThrowsError(try conn?.query("SELECT version()", binds).wait()) { error in + guard case .tooManyParameters = (error as? PSQLError)?.code.base else { + return XCTFail("Unexpected error: \(error)") + } + } + } + + func testRemoteClose() { + var conn: PostgresConnection? + XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) + XCTAssertNoThrow( try conn?.channel.close().wait() ) + } + + // https://github.com/vapor/postgres-nio/issues/113 + @available(*, deprecated, message: "Testing deprecated functionality") + func testVaryingCharArray() { + var conn: PostgresConnection? + XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) + defer { XCTAssertNoThrow( try conn?.close().wait() ) } + + var res: PostgresQueryResult? + XCTAssertNoThrow(res = try conn?.query(#"SELECT '{"foo", "bar", "baz"}'::VARCHAR[] as foo"#).wait()) + let row = res?.first?.makeRandomAccess() + XCTAssertEqual(row?[data: "foo"].array(of: String.self), ["foo", "bar", "baz"]) + } + + // https://github.com/vapor/postgres-nio/issues/115 + func testSetTimeZone() { + var conn: PostgresConnection? + XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) + defer { XCTAssertNoThrow( try conn?.close().wait() ) } + + XCTAssertNoThrow(_ = try conn?.simpleQuery("SET TIME ZONE INTERVAL '+5:45' HOUR TO MINUTE").wait()) + XCTAssertNoThrow(_ = try conn?.query("SET TIME ZONE INTERVAL '+5:45' HOUR TO MINUTE").wait()) + } + + func testIntegerConversions() throws { + var conn: PostgresConnection? + XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) + defer { XCTAssertNoThrow( try conn?.close().wait() ) } + var rows: PostgresQueryResult? + XCTAssertNoThrow(rows = try conn?.query(""" + select + 'a'::char as test8, + + '-32768'::smallint as min16, + '32767'::smallint as max16, + + '-2147483648'::integer as min32, + '2147483647'::integer as max32, + + '-9223372036854775808'::bigint as min64, + '9223372036854775807'::bigint as max64 + """).wait()) + let row = rows?.first?.makeRandomAccess() + XCTAssertEqual(row?[data: "test8"].uint8, 97) + XCTAssertEqual(row?[data: "test8"].int16, 97) + XCTAssertEqual(row?[data: "test8"].int32, 97) + XCTAssertEqual(row?[data: "test8"].int64, 97) + + XCTAssertEqual(row?[data: "min16"].uint8, nil) + XCTAssertEqual(row?[data: "max16"].uint8, nil) + XCTAssertEqual(row?[data: "min16"].int16, .min) + XCTAssertEqual(row?[data: "max16"].int16, .max) + XCTAssertEqual(row?[data: "min16"].int32, -32768) + XCTAssertEqual(row?[data: "max16"].int32, 32767) + XCTAssertEqual(row?[data: "min16"].int64, -32768) + XCTAssertEqual(row?[data: "max16"].int64, 32767) + + XCTAssertEqual(row?[data: "min32"].uint8, nil) + XCTAssertEqual(row?[data: "max32"].uint8, nil) + XCTAssertEqual(row?[data: "min32"].int16, nil) + XCTAssertEqual(row?[data: "max32"].int16, nil) + XCTAssertEqual(row?[data: "min32"].int32, .min) + XCTAssertEqual(row?[data: "max32"].int32, .max) + XCTAssertEqual(row?[data: "min32"].int64, -2147483648) + XCTAssertEqual(row?[data: "max32"].int64, 2147483647) + + XCTAssertEqual(row?[data: "min64"].uint8, nil) + XCTAssertEqual(row?[data: "max64"].uint8, nil) + XCTAssertEqual(row?[data: "min64"].int16, nil) + XCTAssertEqual(row?[data: "max64"].int16, nil) + XCTAssertEqual(row?[data: "min64"].int32, nil) + XCTAssertEqual(row?[data: "max64"].int32, nil) + XCTAssertEqual(row?[data: "min64"].int64, .min) + XCTAssertEqual(row?[data: "max64"].int64, .max) + } +} + +let isLoggingConfigured: Bool = { + LoggingSystem.bootstrap { label in + var handler = StreamLogHandler.standardOutput(label: label) + handler.logLevel = env("LOG_LEVEL").flatMap { .init(rawValue: $0) } ?? .info + return handler + } + return true +}() diff --git a/Tests/IntegrationTests/Utilities.swift b/Tests/IntegrationTests/Utilities.swift new file mode 100644 index 00000000..91dbb62e --- /dev/null +++ b/Tests/IntegrationTests/Utilities.swift @@ -0,0 +1,100 @@ +import XCTest +import PostgresNIO +import NIOCore +import Logging +#if canImport(Darwin) +import Darwin.C +#else +import Glibc +#endif + +extension PostgresConnection { + static func address() throws -> SocketAddress { + try .makeAddressResolvingHost(env("POSTGRES_HOSTNAME") ?? "localhost", port: env("POSTGRES_PORT").flatMap(Int.init(_:)) ?? 5432) + } + + @available(*, deprecated, message: "Test deprecated functionality") + static func testUnauthenticated(on eventLoop: EventLoop, logLevel: Logger.Level = .info) -> EventLoopFuture { + var logger = Logger(label: "postgres.connection.test") + logger.logLevel = logLevel + do { + return connect(to: try address(), logger: logger, on: eventLoop) + } catch { + return eventLoop.makeFailedFuture(error) + } + } + + static func test(on eventLoop: EventLoop, options: Configuration.Options? = nil) -> EventLoopFuture { + let logger = Logger(label: "postgres.connection.test") + var config = PostgresConnection.Configuration( + host: env("POSTGRES_HOSTNAME") ?? "localhost", + port: env("POSTGRES_PORT").flatMap(Int.init(_:)) ?? 5432, + username: env("POSTGRES_USER") ?? "test_username", + password: env("POSTGRES_PASSWORD") ?? "test_password", + database: env("POSTGRES_DB") ?? "test_database", + tls: .disable + ) + if let options { + config.options = options + } + + return PostgresConnection.connect(on: eventLoop, configuration: config, id: 0, logger: logger) + } + + static func testUDS(on eventLoop: EventLoop) -> EventLoopFuture { + let logger = Logger(label: "postgres.connection.test") + let config = PostgresConnection.Configuration( + unixSocketPath: env("POSTGRES_SOCKET") ?? "/tmp/.s.PGSQL.\(env("POSTGRES_PORT").flatMap(Int.init(_:)) ?? 5432)", + username: env("POSTGRES_USER") ?? "test_username", + password: env("POSTGRES_PASSWORD") ?? "test_password", + database: env("POSTGRES_DB") ?? "test_database" + ) + + return PostgresConnection.connect(on: eventLoop, configuration: config, id: 0, logger: logger) + } + + static func testChannel(_ channel: Channel, on eventLoop: EventLoop) -> EventLoopFuture { + let logger = Logger(label: "postgres.connection.test") + let config = PostgresConnection.Configuration( + establishedChannel: channel, + username: env("POSTGRES_USER") ?? "test_username", + password: env("POSTGRES_PASSWORD") ?? "test_password", + database: env("POSTGRES_DB") ?? "test_database" + ) + + return PostgresConnection.connect(on: eventLoop, configuration: config, id: 0, logger: logger) + } +} + +extension Logger { + static var psqlTest: Logger { + .init(label: "psql.test") + } +} + +func env(_ name: String) -> String? { + getenv(name).flatMap { String(cString: $0) } +} + +extension XCTestCase { + + public static var shouldRunLongRunningTests: Bool { + // The env var must be set and have the value `"true"`, `"1"`, or `"yes"` (case-insensitive). + // For the sake of sheer annoying pedantry, values like `"2"` are treated as false. + guard let rawValue = env("POSTGRES_LONG_RUNNING_TESTS") else { return false } + if let boolValue = Bool(rawValue) { return boolValue } + if let intValue = Int(rawValue) { return intValue == 1 } + return rawValue.lowercased() == "yes" + } + + public static var shouldRunPerformanceTests: Bool { + // Same semantics as above. Any present non-truthy value will explicitly disable performance + // tests even if they would've overwise run in the current configuration. + let defaultValue = !_isDebugAssertConfiguration() // default to not running in debug builds + + guard let rawValue = env("POSTGRES_PERFORMANCE_TESTS") else { return defaultValue } + if let boolValue = Bool(rawValue) { return boolValue } + if let intValue = Int(rawValue) { return intValue == 1 } + return rawValue.lowercased() == "yes" + } +} diff --git a/Tests/PostgresNIOTests/Data/PostgresData+JSONTests.swift b/Tests/PostgresNIOTests/Data/PostgresData+JSONTests.swift new file mode 100644 index 00000000..47dd89a1 --- /dev/null +++ b/Tests/PostgresNIOTests/Data/PostgresData+JSONTests.swift @@ -0,0 +1,21 @@ +import PostgresNIO +import XCTest + +class PostgresData_JSONTests: XCTestCase { + @available(*, deprecated, message: "Testing deprecated functionality") + func testJSONBConvertible() { + struct Object: PostgresJSONBCodable { + let foo: Int + let bar: Int + } + + XCTAssertEqual(Object.postgresDataType, .jsonb) + + let postgresData = Object(foo: 1, bar: 2).postgresData + XCTAssertEqual(postgresData?.type, .jsonb) + + let object = Object(postgresData: postgresData!) + XCTAssertEqual(object?.foo, 1) + XCTAssertEqual(object?.bar, 2) + } +} diff --git a/Tests/PostgresNIOTests/Message/PostgresMessageDecoderTests.swift b/Tests/PostgresNIOTests/Message/PostgresMessageDecoderTests.swift new file mode 100644 index 00000000..bbd022db --- /dev/null +++ b/Tests/PostgresNIOTests/Message/PostgresMessageDecoderTests.swift @@ -0,0 +1,39 @@ +import PostgresNIO +import XCTest +import NIOTestUtils +import NIOCore + +class PostgresMessageDecoderTests: XCTestCase { + @available(*, deprecated, message: "Tests deprecated API") + func testMessageDecoder() { + let sample: [UInt8] = [ + 0x52, // R - authentication + 0x00, 0x00, 0x00, 0x0C, // length = 12 + 0x00, 0x00, 0x00, 0x05, // md5 + 0x01, 0x02, 0x03, 0x04, // salt + 0x4B, // B - backend key data + 0x00, 0x00, 0x00, 0x0C, // length = 12 + 0x05, 0x05, 0x05, 0x05, // process id + 0x01, 0x01, 0x01, 0x01, // secret key + ] + var input = ByteBufferAllocator().buffer(capacity: 0) + input.writeBytes(sample) + + let output: [PostgresMessage] = [ + PostgresMessage(identifier: .authentication, bytes: [ + 0x00, 0x00, 0x00, 0x05, + 0x01, 0x02, 0x03, 0x04, + ]), + PostgresMessage(identifier: .backendKeyData, bytes: [ + 0x05, 0x05, 0x05, 0x05, + 0x01, 0x01, 0x01, 0x01, + ]) + ] + XCTAssertNoThrow(try XCTUnwrap(ByteToMessageDecoderVerifier.verifyDecoder( + inputOutputPairs: [(input, output)], + decoderFactory: { + PostgresMessageDecoder() + } + ))) + } +} diff --git a/Tests/PostgresNIOTests/New/Connection State Machine/AuthenticationStateMachineTests.swift b/Tests/PostgresNIOTests/New/Connection State Machine/AuthenticationStateMachineTests.swift new file mode 100644 index 00000000..df881f90 --- /dev/null +++ b/Tests/PostgresNIOTests/New/Connection State Machine/AuthenticationStateMachineTests.swift @@ -0,0 +1,160 @@ +import XCTest +import NIOCore +@testable import PostgresNIO + +class AuthenticationStateMachineTests: XCTestCase { + + func testAuthenticatePlaintext() { + let authContext = AuthContext(username: "test", password: "abc123", database: "test") + + var state = ConnectionStateMachine(requireBackendKeyData: true) + XCTAssertEqual(state.connected(tls: .disable), .provideAuthenticationContext) + + XCTAssertEqual(state.provideAuthenticationContext(authContext), .sendStartupMessage(authContext)) + XCTAssertEqual(state.authenticationMessageReceived(.plaintext), .sendPasswordMessage(.cleartext, authContext)) + XCTAssertEqual(state.authenticationMessageReceived(.ok), .wait) + } + + func testAuthenticateMD5() { + let authContext = AuthContext(username: "test", password: "abc123", database: "test") + var state = ConnectionStateMachine(requireBackendKeyData: true) + XCTAssertEqual(state.connected(tls: .disable), .provideAuthenticationContext) + let salt: UInt32 = 0x00_01_02_03 + + XCTAssertEqual(state.provideAuthenticationContext(authContext), .sendStartupMessage(authContext)) + XCTAssertEqual(state.authenticationMessageReceived(.md5(salt: salt)), .sendPasswordMessage(.md5(salt: salt), authContext)) + XCTAssertEqual(state.authenticationMessageReceived(.ok), .wait) + } + + func testAuthenticateMD5WithoutPassword() { + let authContext = AuthContext(username: "test", password: nil, database: "test") + var state = ConnectionStateMachine(requireBackendKeyData: true) + XCTAssertEqual(state.connected(tls: .disable), .provideAuthenticationContext) + let salt: UInt32 = 0x00_01_02_03 + + XCTAssertEqual(state.provideAuthenticationContext(authContext), .sendStartupMessage(authContext)) + XCTAssertEqual(state.authenticationMessageReceived(.md5(salt: salt)), + .closeConnectionAndCleanup(.init(action: .close, tasks: [], error: .authMechanismRequiresPassword, closePromise: nil))) + } + + func testAuthenticateOkAfterStartUpWithoutAuthChallenge() { + let authContext = AuthContext(username: "test", password: "abc123", database: "test") + var state = ConnectionStateMachine(requireBackendKeyData: true) + XCTAssertEqual(state.connected(tls: .disable), .provideAuthenticationContext) + XCTAssertEqual(state.provideAuthenticationContext(authContext), .sendStartupMessage(authContext)) + XCTAssertEqual(state.authenticationMessageReceived(.ok), .wait) + } + + func testAuthenticateSCRAMSHA256WithAtypicalEncoding() { + let authContext = AuthContext(username: "test", password: "abc123", database: "test") + var state = ConnectionStateMachine(requireBackendKeyData: true) + XCTAssertEqual(state.connected(tls: .disable), .provideAuthenticationContext) + XCTAssertEqual(state.provideAuthenticationContext(authContext), .sendStartupMessage(authContext)) + + let saslResponse = state.authenticationMessageReceived(.sasl(names: ["SCRAM-SHA-256"])) + guard case .sendSaslInitialResponse(name: let name, initialResponse: let responseData) = saslResponse else { + return XCTFail("\(saslResponse) is not .sendSaslInitialResponse") + } + let responseString = String(decoding: responseData, as: UTF8.self) + XCTAssertEqual(name, "SCRAM-SHA-256") + XCTAssert(responseString.starts(with: "n,,n=test,r=")) + + let saslContinueResponse = state.authenticationMessageReceived(.saslContinue(data: .init(bytes: + "r=\(responseString.dropFirst(12))RUJSZHhkeUVFNzRLNERKMkxmU05ITU1NZWcxaQ==,s=ijgUVaWgCDLRJyF963BKNA==,i=4096".utf8 + ))) + guard case .sendSaslResponse(let responseData2) = saslContinueResponse else { + return XCTFail("\(saslContinueResponse) is not .sendSaslResponse") + } + let response2String = String(decoding: responseData2, as: UTF8.self) + XCTAssertEqual(response2String.prefix(76), "c=biws,r=\(responseString.dropFirst(12))RUJSZHhkeUVFNzRLNERKMkxmU05ITU1NZWcxaQ==,p=") + } + + func testAuthenticationFailure() { + let authContext = AuthContext(username: "test", password: "abc123", database: "test") + var state = ConnectionStateMachine(requireBackendKeyData: true) + XCTAssertEqual(state.connected(tls: .disable), .provideAuthenticationContext) + let salt: UInt32 = 0x00_01_02_03 + + XCTAssertEqual(state.provideAuthenticationContext(authContext), .sendStartupMessage(authContext)) + XCTAssertEqual(state.authenticationMessageReceived(.md5(salt: salt)), .sendPasswordMessage(.md5(salt: salt), authContext)) + let fields: [PostgresBackendMessage.Field: String] = [ + .message: "password authentication failed for user \"postgres\"", + .severity: "FATAL", + .sqlState: "28P01", + .localizedSeverity: "FATAL", + .routine: "auth_failed", + .line: "334", + .file: "auth.c" + ] + XCTAssertEqual(state.errorReceived(.init(fields: fields)), + .closeConnectionAndCleanup(.init(action: .close, tasks: [], error: .server(.init(fields: fields)), closePromise: nil))) + } + + // MARK: Test unsupported messages + + func testUnsupportedAuthMechanism() { + let unsupported: [(PostgresBackendMessage.Authentication, PSQLError.UnsupportedAuthScheme)] = [ + (.kerberosV5, .kerberosV5), + (.scmCredential, .scmCredential), + (.gss, .gss), + (.sspi, .sspi), + (.sasl(names: ["haha"]), .sasl(mechanisms: ["haha"])), + ] + + for (message, mechanism) in unsupported { + let authContext = AuthContext(username: "test", password: "abc123", database: "test") + var state = ConnectionStateMachine(requireBackendKeyData: true) + XCTAssertEqual(state.connected(tls: .disable), .provideAuthenticationContext) + XCTAssertEqual(state.provideAuthenticationContext(authContext), .sendStartupMessage(authContext)) + XCTAssertEqual(state.authenticationMessageReceived(message), + .closeConnectionAndCleanup(.init(action: .close, tasks: [], error: .unsupportedAuthMechanism(mechanism), closePromise: nil))) + } + } + + func testUnexpectedMessagesAfterStartUp() { + var buffer = ByteBuffer() + buffer.writeBytes([0, 1, 2, 3, 4, 5, 6, 7, 8]) + let unexpected: [PostgresBackendMessage.Authentication] = [ + .gssContinue(data: buffer), + .saslContinue(data: buffer), + .saslFinal(data: buffer) + ] + + for message in unexpected { + let authContext = AuthContext(username: "test", password: "abc123", database: "test") + var state = ConnectionStateMachine(requireBackendKeyData: true) + XCTAssertEqual(state.connected(tls: .disable), .provideAuthenticationContext) + XCTAssertEqual(state.provideAuthenticationContext(authContext), .sendStartupMessage(authContext)) + XCTAssertEqual(state.authenticationMessageReceived(message), + .closeConnectionAndCleanup(.init(action: .close, tasks: [], error: .unexpectedBackendMessage(.authentication(message)), closePromise: nil))) + } + } + + func testUnexpectedMessagesAfterPasswordSent() { + let salt: UInt32 = 0x00_01_02_03 + var buffer = ByteBuffer() + buffer.writeBytes([0, 1, 2, 3, 4, 5, 6, 7, 8]) + let unexpected: [PostgresBackendMessage.Authentication] = [ + .kerberosV5, + .md5(salt: salt), + .plaintext, + .scmCredential, + .gss, + .sspi, + .gssContinue(data: buffer), + .sasl(names: ["haha"]), + .saslContinue(data: buffer), + .saslFinal(data: buffer), + ] + + for message in unexpected { + let authContext = AuthContext(username: "test", password: "abc123", database: "test") + var state = ConnectionStateMachine(requireBackendKeyData: true) + XCTAssertEqual(state.connected(tls: .disable), .provideAuthenticationContext) + XCTAssertEqual(state.provideAuthenticationContext(authContext), .sendStartupMessage(authContext)) + XCTAssertEqual(state.authenticationMessageReceived(.md5(salt: salt)), .sendPasswordMessage(.md5(salt: salt), authContext)) + XCTAssertEqual(state.authenticationMessageReceived(message), + .closeConnectionAndCleanup(.init(action: .close, tasks: [], error: .unexpectedBackendMessage(.authentication(message)), closePromise: nil))) + } + } +} diff --git a/Tests/PostgresNIOTests/New/Connection State Machine/ConnectionStateMachineTests.swift b/Tests/PostgresNIOTests/New/Connection State Machine/ConnectionStateMachineTests.swift new file mode 100644 index 00000000..f3d72a5e --- /dev/null +++ b/Tests/PostgresNIOTests/New/Connection State Machine/ConnectionStateMachineTests.swift @@ -0,0 +1,188 @@ +import XCTest +@testable import PostgresNIO +@testable import NIOCore +import NIOPosix +import NIOSSL + +class ConnectionStateMachineTests: XCTestCase { + + func testStartup() { + let authContext = AuthContext(username: "test", password: "abc123", database: "test") + var state = ConnectionStateMachine(requireBackendKeyData: true) + XCTAssertEqual(state.connected(tls: .disable), .provideAuthenticationContext) + XCTAssertEqual(state.provideAuthenticationContext(authContext), .sendStartupMessage(authContext)) + XCTAssertEqual(state.authenticationMessageReceived(.plaintext), .sendPasswordMessage(.cleartext, authContext)) + XCTAssertEqual(state.authenticationMessageReceived(.ok), .wait) + } + + func testSSLStartupSuccess() { + let authContext = AuthContext(username: "test", password: "abc123", database: "test") + var state = ConnectionStateMachine(requireBackendKeyData: true) + XCTAssertEqual(state.connected(tls: .require), .sendSSLRequest) + XCTAssertEqual(state.sslSupportedReceived(unprocessedBytes: 0), .establishSSLConnection) + XCTAssertEqual(state.sslHandlerAdded(), .wait) + XCTAssertEqual(state.sslEstablished(), .provideAuthenticationContext) + XCTAssertEqual(state.provideAuthenticationContext(authContext), .sendStartupMessage(authContext)) + let salt: UInt32 = 0x00_01_02_03 + XCTAssertEqual(state.authenticationMessageReceived(.md5(salt: salt)), .sendPasswordMessage(.md5(salt: salt), authContext)) + } + + func testSSLStartupFailureTooManyBytesRemaining() { + var state = ConnectionStateMachine(requireBackendKeyData: true) + XCTAssertEqual(state.connected(tls: .require), .sendSSLRequest) + let failError = PSQLError.receivedUnencryptedDataAfterSSLRequest + XCTAssertEqual(state.sslSupportedReceived(unprocessedBytes: 1), .closeConnectionAndCleanup(.init(action: .close, tasks: [], error: failError, closePromise: nil))) + } + + func testSSLStartupFailHandler() { + struct SSLHandlerAddError: Error, Equatable {} + + var state = ConnectionStateMachine(requireBackendKeyData: true) + XCTAssertEqual(state.connected(tls: .require), .sendSSLRequest) + XCTAssertEqual(state.sslSupportedReceived(unprocessedBytes: 0), .establishSSLConnection) + let failError = PSQLError.failedToAddSSLHandler(underlying: SSLHandlerAddError()) + XCTAssertEqual(state.errorHappened(failError), .closeConnectionAndCleanup(.init(action: .close, tasks: [], error: failError, closePromise: nil))) + } + + func testTLSRequiredStartupSSLUnsupported() { + var state = ConnectionStateMachine(requireBackendKeyData: true) + + XCTAssertEqual(state.connected(tls: .require), .sendSSLRequest) + XCTAssertEqual(state.sslUnsupportedReceived(), + .closeConnectionAndCleanup(.init(action: .close, tasks: [], error: PSQLError.sslUnsupported, closePromise: nil))) + } + + func testTLSPreferredStartupSSLUnsupported() { + var state = ConnectionStateMachine(requireBackendKeyData: true) + + XCTAssertEqual(state.connected(tls: .prefer), .sendSSLRequest) + XCTAssertEqual(state.sslUnsupportedReceived(), .provideAuthenticationContext) + } + + func testParameterStatusReceivedAndBackendKeyAfterAuthenticated() { + var state = ConnectionStateMachine(.authenticated(nil, [:])) + + XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "DateStyle", value: "ISO, MDY")), .wait) + XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "application_name", value: "")), .wait) + XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "server_encoding", value: "UTF8")), .wait) + XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "integer_datetimes", value: "on")), .wait) + XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "client_encoding", value: "UTF8")), .wait) + XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "TimeZone", value: "Etc/UTC")), .wait) + XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "is_superuser", value: "on")), .wait) + XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "server_version", value: "13.1 (Debian 13.1-1.pgdg100+1)")), .wait) + XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "session_authorization", value: "postgres")), .wait) + XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "IntervalStyle", value: "postgres")), .wait) + XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "standard_conforming_strings", value: "on")), .wait) + + XCTAssertEqual(state.backendKeyDataReceived(.init(processID: 2730, secretKey: 882037977)), .wait) + XCTAssertEqual(state.readyForQueryReceived(.idle), .fireEventReadyForQuery) + } + + func testBackendKeyAndParameterStatusReceivedAfterAuthenticated() { + var state = ConnectionStateMachine(.authenticated(nil, [:])) + + XCTAssertEqual(state.backendKeyDataReceived(.init(processID: 2730, secretKey: 882037977)), .wait) + + XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "DateStyle", value: "ISO, MDY")), .wait) + XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "application_name", value: "")), .wait) + XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "server_encoding", value: "UTF8")), .wait) + XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "integer_datetimes", value: "on")), .wait) + XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "client_encoding", value: "UTF8")), .wait) + XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "TimeZone", value: "Etc/UTC")), .wait) + XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "is_superuser", value: "on")), .wait) + XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "server_version", value: "13.1 (Debian 13.1-1.pgdg100+1)")), .wait) + XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "session_authorization", value: "postgres")), .wait) + XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "IntervalStyle", value: "postgres")), .wait) + XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "standard_conforming_strings", value: "on")), .wait) + + XCTAssertEqual(state.readyForQueryReceived(.idle), .fireEventReadyForQuery) + } + + func testReadyForQueryReceivedWithoutBackendKeyAfterAuthenticated() { + var state = ConnectionStateMachine(.authenticated(nil, [:]), requireBackendKeyData: true) + + XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "DateStyle", value: "ISO, MDY")), .wait) + XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "application_name", value: "")), .wait) + XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "server_encoding", value: "UTF8")), .wait) + XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "integer_datetimes", value: "on")), .wait) + XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "client_encoding", value: "UTF8")), .wait) + XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "TimeZone", value: "Etc/UTC")), .wait) + XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "is_superuser", value: "on")), .wait) + XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "server_version", value: "13.1 (Debian 13.1-1.pgdg100+1)")), .wait) + XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "session_authorization", value: "postgres")), .wait) + XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "IntervalStyle", value: "postgres")), .wait) + XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "standard_conforming_strings", value: "on")), .wait) + + XCTAssertEqual(state.readyForQueryReceived(.idle), + .closeConnectionAndCleanup(.init(action: .close, tasks: [], error: PSQLError.unexpectedBackendMessage(.readyForQuery(.idle)), closePromise: nil))) + } + + func testReadyForQueryReceivedWithoutUnneededBackendKeyAfterAuthenticated() { + var state = ConnectionStateMachine(.authenticated(nil, [:]), requireBackendKeyData: false) + + XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "DateStyle", value: "ISO, MDY")), .wait) + XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "application_name", value: "")), .wait) + XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "server_encoding", value: "UTF8")), .wait) + XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "integer_datetimes", value: "on")), .wait) + XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "client_encoding", value: "UTF8")), .wait) + XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "TimeZone", value: "Etc/UTC")), .wait) + XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "is_superuser", value: "on")), .wait) + XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "server_version", value: "13.1 (Debian 13.1-1.pgdg100+1)")), .wait) + XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "session_authorization", value: "postgres")), .wait) + XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "IntervalStyle", value: "postgres")), .wait) + XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "standard_conforming_strings", value: "on")), .wait) + + XCTAssertEqual(state.readyForQueryReceived(.idle), .fireEventReadyForQuery) + } + + func testErrorIsIgnoredWhenClosingConnection() { + // test ignore unclean shutdown when closing connection + var stateIgnoreChannelError = ConnectionStateMachine(.closing(nil)) + + XCTAssertEqual(stateIgnoreChannelError.errorHappened(.connectionError(underlying: NIOSSLError.uncleanShutdown)), .wait) + XCTAssertEqual(stateIgnoreChannelError.closed(), .fireChannelInactive) + + // test ignore any other error when closing connection + + var stateIgnoreErrorMessage = ConnectionStateMachine(.closing(nil)) + XCTAssertEqual(stateIgnoreErrorMessage.errorReceived(.init(fields: [:])), .wait) + XCTAssertEqual(stateIgnoreErrorMessage.closed(), .fireChannelInactive) + } + + func testFailQueuedQueriesOnAuthenticationFailure() throws { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + + let authContext = AuthContext(username: "test", password: "abc123", database: "test") + let salt: UInt32 = 0x00_01_02_03 + + let queryPromise = eventLoopGroup.next().makePromise(of: PSQLRowStream.self) + + var state = ConnectionStateMachine(requireBackendKeyData: true) + let extendedQueryContext = ExtendedQueryContext( + query: "Select version()", + logger: .psqlTest, + promise: queryPromise) + + XCTAssertEqual(state.enqueue(task: .extendedQuery(extendedQueryContext)), .wait) + XCTAssertEqual(state.connected(tls: .disable), .provideAuthenticationContext) + XCTAssertEqual(state.provideAuthenticationContext(authContext), .sendStartupMessage(authContext)) + XCTAssertEqual(state.authenticationMessageReceived(.md5(salt: salt)), .sendPasswordMessage(.md5(salt: salt), authContext)) + let fields: [PostgresBackendMessage.Field: String] = [ + .message: "password authentication failed for user \"postgres\"", + .severity: "FATAL", + .sqlState: "28P01", + .localizedSeverity: "FATAL", + .routine: "auth_failed", + .line: "334", + .file: "auth.c" + ] + XCTAssertEqual(state.errorReceived(.init(fields: fields)), + .closeConnectionAndCleanup(.init(action: .close, tasks: [.extendedQuery(extendedQueryContext)], error: .server(.init(fields: fields)), closePromise: nil))) + + XCTAssertNil(queryPromise.futureResult._value) + + // make sure we don't crash + queryPromise.fail(PSQLError.server(.init(fields: fields))) + } +} diff --git a/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift b/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift new file mode 100644 index 00000000..ae484acc --- /dev/null +++ b/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift @@ -0,0 +1,298 @@ +import XCTest +import NIOCore +import NIOEmbedded +import Logging +@testable import PostgresNIO + +class ExtendedQueryStateMachineTests: XCTestCase { + + func testExtendedQueryWithoutDataRowsHappyPath() { + var state = ConnectionStateMachine.readyForQuery() + + let logger = Logger.psqlTest + let promise = EmbeddedEventLoop().makePromise(of: PSQLRowStream.self) + promise.fail(PSQLError.uncleanShutdown) // we don't care about the error at all. + let query: PostgresQuery = "DELETE FROM table WHERE id=\(1)" + let queryContext = ExtendedQueryContext(query: query, logger: logger, promise: promise) + + XCTAssertEqual(state.enqueue(task: .extendedQuery(queryContext)), .sendParseDescribeBindExecuteSync(query)) + XCTAssertEqual(state.parseCompleteReceived(), .wait) + XCTAssertEqual(state.parameterDescriptionReceived(.init(dataTypes: [.int8])), .wait) + XCTAssertEqual(state.noDataReceived(), .wait) + XCTAssertEqual(state.bindCompleteReceived(), .wait) + XCTAssertEqual(state.commandCompletedReceived("DELETE 1"), .succeedQuery(promise, with: .init(value: .noRows(.tag("DELETE 1")), logger: logger))) + XCTAssertEqual(state.readyForQueryReceived(.idle), .fireEventReadyForQuery) + } + + func testExtendedQueryWithDataRowsHappyPath() { + var state = ConnectionStateMachine.readyForQuery() + + let logger = Logger.psqlTest + let promise = EmbeddedEventLoop().makePromise(of: PSQLRowStream.self) + promise.fail(PSQLError.uncleanShutdown) // we don't care about the error at all. + let query: PostgresQuery = "SELECT version()" + let queryContext = ExtendedQueryContext(query: query, logger: logger, promise: promise) + + XCTAssertEqual(state.enqueue(task: .extendedQuery(queryContext)), .sendParseDescribeBindExecuteSync(query)) + XCTAssertEqual(state.parseCompleteReceived(), .wait) + XCTAssertEqual(state.parameterDescriptionReceived(.init(dataTypes: [.int8])), .wait) + + // We need to ensure that even though the row description from the wire says that we + // will receive data in `.text` format, we will actually receive it in binary format, + // since we requested it in binary with our bind message. + let input: [RowDescription.Column] = [ + .init(name: "version", tableOID: 0, columnAttributeNumber: 0, dataType: .text, dataTypeSize: -1, dataTypeModifier: -1, format: .text) + ] + let expected: [RowDescription.Column] = input.map { + .init(name: $0.name, tableOID: $0.tableOID, columnAttributeNumber: $0.columnAttributeNumber, dataType: $0.dataType, + dataTypeSize: $0.dataTypeSize, dataTypeModifier: $0.dataTypeModifier, format: .binary) + } + + XCTAssertEqual(state.rowDescriptionReceived(.init(columns: input)), .wait) + XCTAssertEqual(state.bindCompleteReceived(), .succeedQuery(promise, with: .init(value: .rowDescription(expected), logger: logger))) + let row1: DataRow = [ByteBuffer(string: "test1")] + XCTAssertEqual(state.dataRowReceived(row1), .wait) + XCTAssertEqual(state.channelReadComplete(), .forwardRows([row1])) + XCTAssertEqual(state.readEventCaught(), .wait) + XCTAssertEqual(state.requestQueryRows(), .read) + + let row2: DataRow = [ByteBuffer(string: "test2")] + let row3: DataRow = [ByteBuffer(string: "test3")] + let row4: DataRow = [ByteBuffer(string: "test4")] + XCTAssertEqual(state.dataRowReceived(row2), .wait) + XCTAssertEqual(state.dataRowReceived(row3), .wait) + XCTAssertEqual(state.dataRowReceived(row4), .wait) + XCTAssertEqual(state.channelReadComplete(), .forwardRows([row2, row3, row4])) + XCTAssertEqual(state.requestQueryRows(), .wait) + XCTAssertEqual(state.readEventCaught(), .read) + + XCTAssertEqual(state.channelReadComplete(), .wait) + XCTAssertEqual(state.readEventCaught(), .read) + + let row5: DataRow = [ByteBuffer(string: "test5")] + let row6: DataRow = [ByteBuffer(string: "test6")] + XCTAssertEqual(state.dataRowReceived(row5), .wait) + XCTAssertEqual(state.dataRowReceived(row6), .wait) + + XCTAssertEqual(state.commandCompletedReceived("SELECT 2"), .forwardStreamComplete([row5, row6], commandTag: "SELECT 2")) + XCTAssertEqual(state.readyForQueryReceived(.idle), .fireEventReadyForQuery) + } + + func testExtendedQueryWithNoQuery() { + var state = ConnectionStateMachine.readyForQuery() + + let logger = Logger.psqlTest + let promise = EmbeddedEventLoop().makePromise(of: PSQLRowStream.self) + promise.fail(PSQLError.uncleanShutdown) // we don't care about the error at all. + let query: PostgresQuery = "-- some comments" + let queryContext = ExtendedQueryContext(query: query, logger: logger, promise: promise) + + XCTAssertEqual(state.enqueue(task: .extendedQuery(queryContext)), .sendParseDescribeBindExecuteSync(query)) + XCTAssertEqual(state.parseCompleteReceived(), .wait) + XCTAssertEqual(state.parameterDescriptionReceived(.init(dataTypes: [.int8])), .wait) + XCTAssertEqual(state.noDataReceived(), .wait) + XCTAssertEqual(state.bindCompleteReceived(), .wait) + XCTAssertEqual(state.emptyQueryResponseReceived(), .succeedQuery(promise, with: .init(value: .noRows(.emptyResponse), logger: logger))) + XCTAssertEqual(state.readyForQueryReceived(.idle), .fireEventReadyForQuery) + } + + func testReceiveTotallyUnexpectedMessageInQuery() { + var state = ConnectionStateMachine.readyForQuery() + + let logger = Logger.psqlTest + let promise = EmbeddedEventLoop().makePromise(of: PSQLRowStream.self) + promise.fail(PSQLError.uncleanShutdown) // we don't care about the error at all. + let query: PostgresQuery = "DELETE FROM table WHERE id=\(1)" + let queryContext = ExtendedQueryContext(query: query, logger: logger, promise: promise) + + XCTAssertEqual(state.enqueue(task: .extendedQuery(queryContext)), .sendParseDescribeBindExecuteSync(query)) + XCTAssertEqual(state.parseCompleteReceived(), .wait) + XCTAssertEqual(state.parameterDescriptionReceived(.init(dataTypes: [.int8])), .wait) + + let psqlError = PSQLError.unexpectedBackendMessage(.authentication(.ok)) + XCTAssertEqual(state.authenticationMessageReceived(.ok), + .failQuery(promise, with: psqlError, cleanupContext: .init(action: .close, tasks: [], error: psqlError, closePromise: nil))) + } + + func testExtendedQueryIsCancelledImmediatly() { + var state = ConnectionStateMachine.readyForQuery() + + let logger = Logger.psqlTest + let promise = EmbeddedEventLoop().makePromise(of: PSQLRowStream.self) + promise.fail(PSQLError.uncleanShutdown) // we don't care about the error at all. + let query: PostgresQuery = "SELECT version()" + let queryContext = ExtendedQueryContext(query: query, logger: logger, promise: promise) + + XCTAssertEqual(state.enqueue(task: .extendedQuery(queryContext)), .sendParseDescribeBindExecuteSync(query)) + XCTAssertEqual(state.parseCompleteReceived(), .wait) + XCTAssertEqual(state.parameterDescriptionReceived(.init(dataTypes: [.int8])), .wait) + + // We need to ensure that even though the row description from the wire says that we + // will receive data in `.text` format, we will actually receive it in binary format, + // since we requested it in binary with our bind message. + let input: [RowDescription.Column] = [ + .init(name: "version", tableOID: 0, columnAttributeNumber: 0, dataType: .text, dataTypeSize: -1, dataTypeModifier: -1, format: .text) + ] + let expected: [RowDescription.Column] = input.map { + .init(name: $0.name, tableOID: $0.tableOID, columnAttributeNumber: $0.columnAttributeNumber, dataType: $0.dataType, + dataTypeSize: $0.dataTypeSize, dataTypeModifier: $0.dataTypeModifier, format: .binary) + } + + XCTAssertEqual(state.rowDescriptionReceived(.init(columns: input)), .wait) + XCTAssertEqual(state.bindCompleteReceived(), .succeedQuery(promise, with: .init(value: .rowDescription(expected), logger: logger))) + XCTAssertEqual(state.cancelQueryStream(), .forwardStreamError(.queryCancelled, read: false, cleanupContext: nil)) + XCTAssertEqual(state.dataRowReceived([ByteBuffer(string: "test1")]), .wait) + XCTAssertEqual(state.channelReadComplete(), .wait) + XCTAssertEqual(state.readEventCaught(), .read) + + XCTAssertEqual(state.dataRowReceived([ByteBuffer(string: "test2")]), .wait) + XCTAssertEqual(state.dataRowReceived([ByteBuffer(string: "test3")]), .wait) + XCTAssertEqual(state.dataRowReceived([ByteBuffer(string: "test4")]), .wait) + XCTAssertEqual(state.channelReadComplete(), .wait) + XCTAssertEqual(state.readEventCaught(), .read) + + XCTAssertEqual(state.channelReadComplete(), .wait) + XCTAssertEqual(state.readEventCaught(), .read) + + XCTAssertEqual(state.commandCompletedReceived("SELECT 2"), .wait) + XCTAssertEqual(state.readyForQueryReceived(.idle), .fireEventReadyForQuery) + } + + func testExtendedQueryIsCancelledWithReadPending() { + var state = ConnectionStateMachine.readyForQuery() + + let logger = Logger.psqlTest + let promise = EmbeddedEventLoop().makePromise(of: PSQLRowStream.self) + promise.fail(PSQLError.uncleanShutdown) // we don't care about the error at all. + let query: PostgresQuery = "SELECT version()" + let queryContext = ExtendedQueryContext(query: query, logger: logger, promise: promise) + + XCTAssertEqual(state.enqueue(task: .extendedQuery(queryContext)), .sendParseDescribeBindExecuteSync(query)) + XCTAssertEqual(state.parseCompleteReceived(), .wait) + XCTAssertEqual(state.parameterDescriptionReceived(.init(dataTypes: [.int8])), .wait) + + // We need to ensure that even though the row description from the wire says that we + // will receive data in `.text` format, we will actually receive it in binary format, + // since we requested it in binary with our bind message. + let input: [RowDescription.Column] = [ + .init(name: "version", tableOID: 0, columnAttributeNumber: 0, dataType: .text, dataTypeSize: -1, dataTypeModifier: -1, format: .text) + ] + let expected: [RowDescription.Column] = input.map { + .init(name: $0.name, tableOID: $0.tableOID, columnAttributeNumber: $0.columnAttributeNumber, dataType: $0.dataType, + dataTypeSize: $0.dataTypeSize, dataTypeModifier: $0.dataTypeModifier, format: .binary) + } + + XCTAssertEqual(state.rowDescriptionReceived(.init(columns: input)), .wait) + XCTAssertEqual(state.bindCompleteReceived(), .succeedQuery(promise, with: .init(value: .rowDescription(expected), logger: logger))) + let row1: DataRow = [ByteBuffer(string: "test1")] + XCTAssertEqual(state.dataRowReceived(row1), .wait) + XCTAssertEqual(state.channelReadComplete(), .forwardRows([row1])) + XCTAssertEqual(state.readEventCaught(), .wait) + XCTAssertEqual(state.cancelQueryStream(), .forwardStreamError(.queryCancelled, read: true, cleanupContext: nil)) + + XCTAssertEqual(state.dataRowReceived([ByteBuffer(string: "test2")]), .wait) + XCTAssertEqual(state.dataRowReceived([ByteBuffer(string: "test3")]), .wait) + XCTAssertEqual(state.dataRowReceived([ByteBuffer(string: "test4")]), .wait) + XCTAssertEqual(state.channelReadComplete(), .wait) + XCTAssertEqual(state.readEventCaught(), .read) + + XCTAssertEqual(state.commandCompletedReceived("SELECT 4"), .wait) + XCTAssertEqual(state.readyForQueryReceived(.idle), .fireEventReadyForQuery) + } + + func testCancelQueryAfterServerError() { + var state = ConnectionStateMachine.readyForQuery() + + let logger = Logger.psqlTest + let promise = EmbeddedEventLoop().makePromise(of: PSQLRowStream.self) + promise.fail(PSQLError.uncleanShutdown) // we don't care about the error at all. + let query: PostgresQuery = "SELECT version()" + let queryContext = ExtendedQueryContext(query: query, logger: logger, promise: promise) + + XCTAssertEqual(state.enqueue(task: .extendedQuery(queryContext)), .sendParseDescribeBindExecuteSync(query)) + XCTAssertEqual(state.parseCompleteReceived(), .wait) + XCTAssertEqual(state.parameterDescriptionReceived(.init(dataTypes: [.int8])), .wait) + + // We need to ensure that even though the row description from the wire says that we + // will receive data in `.text` format, we will actually receive it in binary format, + // since we requested it in binary with our bind message. + let input: [RowDescription.Column] = [ + .init(name: "version", tableOID: 0, columnAttributeNumber: 0, dataType: .text, dataTypeSize: -1, dataTypeModifier: -1, format: .text) + ] + let expected: [RowDescription.Column] = input.map { + .init(name: $0.name, tableOID: $0.tableOID, columnAttributeNumber: $0.columnAttributeNumber, dataType: $0.dataType, + dataTypeSize: $0.dataTypeSize, dataTypeModifier: $0.dataTypeModifier, format: .binary) + } + + XCTAssertEqual(state.rowDescriptionReceived(.init(columns: input)), .wait) + XCTAssertEqual(state.bindCompleteReceived(), .succeedQuery(promise, with: .init(value: .rowDescription(expected), logger: logger))) + let dataRows1: [DataRow] = [ + [ByteBuffer(string: "test1")], + [ByteBuffer(string: "test2")], + [ByteBuffer(string: "test3")] + ] + for row in dataRows1 { + XCTAssertEqual(state.dataRowReceived(row), .wait) + } + XCTAssertEqual(state.channelReadComplete(), .forwardRows(dataRows1)) + XCTAssertEqual(state.readEventCaught(), .wait) + XCTAssertEqual(state.requestQueryRows(), .read) + let dataRows2: [DataRow] = [ + [ByteBuffer(string: "test4")], + [ByteBuffer(string: "test5")], + [ByteBuffer(string: "test6")] + ] + for row in dataRows2 { + XCTAssertEqual(state.dataRowReceived(row), .wait) + } + let serverError = PostgresBackendMessage.ErrorResponse(fields: [.severity: "Error", .sqlState: "123"]) + XCTAssertEqual(state.errorReceived(serverError), .forwardStreamError(.server(serverError), read: false, cleanupContext: .none)) + + XCTAssertEqual(state.channelReadComplete(), .wait) + XCTAssertEqual(state.readEventCaught(), .read) + + XCTAssertEqual(state.readyForQueryReceived(.idle), .fireEventReadyForQuery) + } + + func testQueryErrorDoesNotKillConnection() { + var state = ConnectionStateMachine.readyForQuery() + + let logger = Logger.psqlTest + let promise = EmbeddedEventLoop().makePromise(of: PSQLRowStream.self) + promise.fail(PSQLError.uncleanShutdown) // we don't care about the error at all. + let query: PostgresQuery = "SELECT version()" + let queryContext = ExtendedQueryContext(query: query, logger: logger, promise: promise) + + XCTAssertEqual(state.enqueue(task: .extendedQuery(queryContext)), .sendParseDescribeBindExecuteSync(query)) + XCTAssertEqual(state.parseCompleteReceived(), .wait) + XCTAssertEqual(state.parameterDescriptionReceived(.init(dataTypes: [.int8])), .wait) + + let serverError = PostgresBackendMessage.ErrorResponse(fields: [.severity: "Error", .sqlState: "123"]) + XCTAssertEqual( + state.errorReceived(serverError), .failQuery(promise, with: .server(serverError), cleanupContext: .none) + ) + + XCTAssertEqual(state.readyForQueryReceived(.idle), .fireEventReadyForQuery) + } + + func testQueryErrorAfterCancelDoesNotKillConnection() { + var state = ConnectionStateMachine.readyForQuery() + + let logger = Logger.psqlTest + let promise = EmbeddedEventLoop().makePromise(of: PSQLRowStream.self) + promise.fail(PSQLError.uncleanShutdown) // we don't care about the error at all. + let query: PostgresQuery = "SELECT version()" + let queryContext = ExtendedQueryContext(query: query, logger: logger, promise: promise) + + XCTAssertEqual(state.enqueue(task: .extendedQuery(queryContext)), .sendParseDescribeBindExecuteSync(query)) + XCTAssertEqual(state.parseCompleteReceived(), .wait) + XCTAssertEqual(state.parameterDescriptionReceived(.init(dataTypes: [.int8])), .wait) + XCTAssertEqual(state.cancelQueryStream(), .failQuery(promise, with: .queryCancelled, cleanupContext: .none)) + + let serverError = PostgresBackendMessage.ErrorResponse(fields: [.severity: "Error", .sqlState: "123"]) + XCTAssertEqual(state.errorReceived(serverError), .wait) + + XCTAssertEqual(state.readyForQueryReceived(.idle), .fireEventReadyForQuery) + } + +} diff --git a/Tests/PostgresNIOTests/New/Connection State Machine/PrepareStatementStateMachineTests.swift b/Tests/PostgresNIOTests/New/Connection State Machine/PrepareStatementStateMachineTests.swift new file mode 100644 index 00000000..547f5cdf --- /dev/null +++ b/Tests/PostgresNIOTests/New/Connection State Machine/PrepareStatementStateMachineTests.swift @@ -0,0 +1,78 @@ +import XCTest +import NIOEmbedded +@testable import PostgresNIO + +class PrepareStatementStateMachineTests: XCTestCase { + func testCreatePreparedStatementReturningRowDescription() { + var state = ConnectionStateMachine.readyForQuery() + + let promise = EmbeddedEventLoop().makePromise(of: RowDescription?.self) + promise.fail(PSQLError.uncleanShutdown) // we don't care about the error at all. + + let name = "haha" + let query = #"SELECT id FROM users WHERE id = $1 "# + let prepareStatementContext = ExtendedQueryContext( + name: name, query: query, bindingDataTypes: [], logger: .psqlTest, promise: promise + ) + + XCTAssertEqual(state.enqueue(task: .extendedQuery(prepareStatementContext)), + .sendParseDescribeSync(name: name, query: query, bindingDataTypes: [])) + XCTAssertEqual(state.parseCompleteReceived(), .wait) + XCTAssertEqual(state.parameterDescriptionReceived(.init(dataTypes: [.int8])), .wait) + + let columns: [RowDescription.Column] = [ + .init(name: "id", tableOID: 0, columnAttributeNumber: 0, dataType: .int8, dataTypeSize: 8, dataTypeModifier: -1, format: .binary) + ] + + XCTAssertEqual(state.rowDescriptionReceived(.init(columns: columns)), + .succeedPreparedStatementCreation(promise, with: .init(columns: columns))) + XCTAssertEqual(state.readyForQueryReceived(.idle), .fireEventReadyForQuery) + } + + func testCreatePreparedStatementReturningNoData() { + var state = ConnectionStateMachine.readyForQuery() + + let promise = EmbeddedEventLoop().makePromise(of: RowDescription?.self) + promise.fail(PSQLError.uncleanShutdown) // we don't care about the error at all. + + let name = "haha" + let query = #"DELETE FROM users WHERE id = $1 "# + let prepareStatementContext = ExtendedQueryContext( + name: name, query: query, bindingDataTypes: [], logger: .psqlTest, promise: promise + ) + + XCTAssertEqual(state.enqueue(task: .extendedQuery(prepareStatementContext)), + .sendParseDescribeSync(name: name, query: query, bindingDataTypes: [])) + XCTAssertEqual(state.parseCompleteReceived(), .wait) + XCTAssertEqual(state.parameterDescriptionReceived(.init(dataTypes: [.int8])), .wait) + + XCTAssertEqual(state.noDataReceived(), + .succeedPreparedStatementCreation(promise, with: nil)) + XCTAssertEqual(state.readyForQueryReceived(.idle), .fireEventReadyForQuery) + } + + func testErrorReceivedAfter() { + var state = ConnectionStateMachine.readyForQuery() + + let promise = EmbeddedEventLoop().makePromise(of: RowDescription?.self) + promise.fail(PSQLError.uncleanShutdown) // we don't care about the error at all. + + let name = "haha" + let query = #"DELETE FROM users WHERE id = $1 "# + let prepareStatementContext = ExtendedQueryContext( + name: name, query: query, bindingDataTypes: [], logger: .psqlTest, promise: promise + ) + + XCTAssertEqual(state.enqueue(task: .extendedQuery(prepareStatementContext)), + .sendParseDescribeSync(name: name, query: query, bindingDataTypes: [])) + XCTAssertEqual(state.parseCompleteReceived(), .wait) + XCTAssertEqual(state.parameterDescriptionReceived(.init(dataTypes: [.int8])), .wait) + + XCTAssertEqual(state.noDataReceived(), + .succeedPreparedStatementCreation(promise, with: nil)) + XCTAssertEqual(state.readyForQueryReceived(.idle), .fireEventReadyForQuery) + + XCTAssertEqual(state.authenticationMessageReceived(.ok), + .closeConnectionAndCleanup(.init(action: .close, tasks: [], error: .unexpectedBackendMessage(.authentication(.ok)), closePromise: nil))) + } +} diff --git a/Tests/PostgresNIOTests/New/Connection State Machine/PreparedStatementStateMachineTests.swift b/Tests/PostgresNIOTests/New/Connection State Machine/PreparedStatementStateMachineTests.swift new file mode 100644 index 00000000..e35e93f7 --- /dev/null +++ b/Tests/PostgresNIOTests/New/Connection State Machine/PreparedStatementStateMachineTests.swift @@ -0,0 +1,160 @@ +import XCTest +import NIOEmbedded +@testable import PostgresNIO + +class PreparedStatementStateMachineTests: XCTestCase { + func testPrepareAndExecuteStatement() { + let eventLoop = EmbeddedEventLoop() + var stateMachine = PreparedStatementStateMachine() + + let firstPreparedStatement = self.makePreparedStatementContext(eventLoop: eventLoop) + // Initial lookup, the statement hasn't been prepared yet + let lookupAction = stateMachine.lookup(preparedStatement: firstPreparedStatement) + guard case .preparing = stateMachine.preparedStatements["test"] else { + XCTFail("State machine in the wrong state") + return + } + guard case .prepareStatement = lookupAction else { + XCTFail("State machine returned the wrong action") + return + } + + // Once preparation is complete we transition to a prepared state + let preparationCompleteAction = stateMachine.preparationComplete(name: "test", rowDescription: nil) + guard case .prepared(nil) = stateMachine.preparedStatements["test"] else { + XCTFail("State machine in the wrong state") + return + } + XCTAssertEqual(preparationCompleteAction.statements.count, 1) + XCTAssertNil(preparationCompleteAction.rowDescription) + firstPreparedStatement.promise.succeed(PSQLRowStream( + source: .noRows(.success(.tag("tag"))), + eventLoop: eventLoop, + logger: .psqlTest + )) + + // Create a new prepared statement + let secondPreparedStatement = self.makePreparedStatementContext(eventLoop: eventLoop) + // The statement is already preparead, lookups tell us to execute it + let secondLookupAction = stateMachine.lookup(preparedStatement: secondPreparedStatement) + guard case .prepared(nil) = stateMachine.preparedStatements["test"] else { + XCTFail("State machine in the wrong state") + return + } + guard case .executeStatement(nil) = secondLookupAction else { + XCTFail("State machine returned the wrong action") + return + } + secondPreparedStatement.promise.succeed(PSQLRowStream( + source: .noRows(.success(.tag("tag"))), + eventLoop: eventLoop, + logger: .psqlTest + )) + } + + func testPrepareAndExecuteStatementWithError() { + let eventLoop = EmbeddedEventLoop() + var stateMachine = PreparedStatementStateMachine() + + let firstPreparedStatement = self.makePreparedStatementContext(eventLoop: eventLoop) + // Initial lookup, the statement hasn't been prepared yet + let lookupAction = stateMachine.lookup(preparedStatement: firstPreparedStatement) + guard case .preparing = stateMachine.preparedStatements["test"] else { + XCTFail("State machine in the wrong state") + return + } + guard case .prepareStatement = lookupAction else { + XCTFail("State machine returned the wrong action") + return + } + + // Simulate an error occurring during preparation + let error = PSQLError(code: .server) + let preparationCompleteAction = stateMachine.errorHappened( + name: "test", + error: error + ) + guard case .error = stateMachine.preparedStatements["test"] else { + XCTFail("State machine in the wrong state") + return + } + XCTAssertEqual(preparationCompleteAction.statements.count, 1) + firstPreparedStatement.promise.fail(error) + + // Create a new prepared statement + let secondPreparedStatement = self.makePreparedStatementContext(eventLoop: eventLoop) + // Ensure that we don't try again to prepare a statement we know will fail + let secondLookupAction = stateMachine.lookup(preparedStatement: secondPreparedStatement) + guard case .error = stateMachine.preparedStatements["test"] else { + XCTFail("State machine in the wrong state") + return + } + guard case .returnError = secondLookupAction else { + XCTFail("State machine returned the wrong action") + return + } + secondPreparedStatement.promise.fail(error) + } + + func testBatchStatementPreparation() { + let eventLoop = EmbeddedEventLoop() + var stateMachine = PreparedStatementStateMachine() + + let firstPreparedStatement = self.makePreparedStatementContext(eventLoop: eventLoop) + // Initial lookup, the statement hasn't been prepared yet + let lookupAction = stateMachine.lookup(preparedStatement: firstPreparedStatement) + guard case .preparing = stateMachine.preparedStatements["test"] else { + XCTFail("State machine in the wrong state") + return + } + guard case .prepareStatement = lookupAction else { + XCTFail("State machine returned the wrong action") + return + } + + // A new request comes in before the statement completes + let secondPreparedStatement = self.makePreparedStatementContext(eventLoop: eventLoop) + let secondLookupAction = stateMachine.lookup(preparedStatement: secondPreparedStatement) + guard case .preparing = stateMachine.preparedStatements["test"] else { + XCTFail("State machine in the wrong state") + return + } + guard case .waitForAlreadyInFlightPreparation = secondLookupAction else { + XCTFail("State machine returned the wrong action") + return + } + + // Once preparation is complete we transition to a prepared state. + // The action tells us to execute both the pending statements. + let preparationCompleteAction = stateMachine.preparationComplete(name: "test", rowDescription: nil) + guard case .prepared(nil) = stateMachine.preparedStatements["test"] else { + XCTFail("State machine in the wrong state") + return + } + XCTAssertEqual(preparationCompleteAction.statements.count, 2) + XCTAssertNil(preparationCompleteAction.rowDescription) + + firstPreparedStatement.promise.succeed(PSQLRowStream( + source: .noRows(.success(.tag("tag"))), + eventLoop: eventLoop, + logger: .psqlTest + )) + secondPreparedStatement.promise.succeed(PSQLRowStream( + source: .noRows(.success(.tag("tag"))), + eventLoop: eventLoop, + logger: .psqlTest + )) + } + + private func makePreparedStatementContext(eventLoop: EmbeddedEventLoop) -> PreparedStatementContext { + let promise = eventLoop.makePromise(of: PSQLRowStream.self) + return PreparedStatementContext( + name: "test", + sql: "INSERT INTO test_table (column1) VALUES (1)", + bindings: PostgresBindings(), + bindingDataTypes: [], + logger: .psqlTest, + promise: promise + ) + } +} diff --git a/Tests/PostgresNIOTests/New/Data/Array+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/Array+PSQLCodableTests.swift new file mode 100644 index 00000000..bfffef52 --- /dev/null +++ b/Tests/PostgresNIOTests/New/Data/Array+PSQLCodableTests.swift @@ -0,0 +1,186 @@ +import XCTest +import NIOCore +@testable import PostgresNIO + +class Array_PSQLCodableTests: XCTestCase { + + func testArrayTypes() { + + XCTAssertEqual(Bool.psqlArrayType, .boolArray) + XCTAssertEqual(Bool.psqlType, .bool) + XCTAssertEqual([Bool].psqlType, .boolArray) + + XCTAssertEqual(ByteBuffer.psqlArrayType, .byteaArray) + XCTAssertEqual(ByteBuffer.psqlType, .bytea) + XCTAssertEqual([ByteBuffer].psqlType, .byteaArray) + + XCTAssertEqual(UInt8.psqlArrayType, .charArray) + XCTAssertEqual(UInt8.psqlType, .char) + XCTAssertEqual([UInt8].psqlType, .charArray) + + XCTAssertEqual(Int16.psqlArrayType, .int2Array) + XCTAssertEqual(Int16.psqlType, .int2) + XCTAssertEqual([Int16].psqlType, .int2Array) + + XCTAssertEqual(Int32.psqlArrayType, .int4Array) + XCTAssertEqual(Int32.psqlType, .int4) + XCTAssertEqual([Int32].psqlType, .int4Array) + + XCTAssertEqual(Int64.psqlArrayType, .int8Array) + XCTAssertEqual(Int64.psqlType, .int8) + XCTAssertEqual([Int64].psqlType, .int8Array) + + #if (arch(i386) || arch(arm)) + XCTAssertEqual(Int.psqlArrayType, .int4Array) + XCTAssertEqual(Int.psqlType, .int4) + XCTAssertEqual([Int].psqlType, .int4Array) + #else + XCTAssertEqual(Int.psqlArrayType, .int8Array) + XCTAssertEqual(Int.psqlType, .int8) + XCTAssertEqual([Int].psqlType, .int8Array) + #endif + + XCTAssertEqual(Float.psqlArrayType, .float4Array) + XCTAssertEqual(Float.psqlType, .float4) + XCTAssertEqual([Float].psqlType, .float4Array) + + XCTAssertEqual(Double.psqlArrayType, .float8Array) + XCTAssertEqual(Double.psqlType, .float8) + XCTAssertEqual([Double].psqlType, .float8Array) + + XCTAssertEqual(String.psqlArrayType, .textArray) + XCTAssertEqual(String.psqlType, .text) + XCTAssertEqual([String].psqlType, .textArray) + + XCTAssertEqual(UUID.psqlArrayType, .uuidArray) + XCTAssertEqual(UUID.psqlType, .uuid) + XCTAssertEqual([UUID].psqlType, .uuidArray) + + XCTAssertEqual(Date.psqlArrayType, .timestamptzArray) + XCTAssertEqual(Date.psqlType, .timestamptz) + XCTAssertEqual([Date].psqlType, .timestamptzArray) + + XCTAssertEqual(Range.psqlArrayType, .int4RangeArray) + XCTAssertEqual(Range.psqlType, .int4Range) + XCTAssertEqual([Range].psqlType, .int4RangeArray) + + XCTAssertEqual(ClosedRange.psqlArrayType, .int4RangeArray) + XCTAssertEqual(ClosedRange.psqlType, .int4Range) + XCTAssertEqual([ClosedRange].psqlType, .int4RangeArray) + + XCTAssertEqual(Range.psqlArrayType, .int8RangeArray) + XCTAssertEqual(Range.psqlType, .int8Range) + XCTAssertEqual([Range].psqlType, .int8RangeArray) + + XCTAssertEqual(ClosedRange.psqlArrayType, .int8RangeArray) + XCTAssertEqual(ClosedRange.psqlType, .int8Range) + XCTAssertEqual([ClosedRange].psqlType, .int8RangeArray) + } + + func testStringArrayRoundTrip() { + let values = ["foo", "bar", "hello", "world"] + + var buffer = ByteBuffer() + values.encode(into: &buffer, context: .default) + + var result: [String]? + XCTAssertNoThrow(result = try [String](from: &buffer, type: .textArray, format: .binary, context: .default)) + XCTAssertEqual(values, result) + } + + func testEmptyStringArrayRoundTrip() { + let values: [String] = [] + + var buffer = ByteBuffer() + values.encode(into: &buffer, context: .default) + + var result: [String]? + XCTAssertNoThrow(result = try [String](from: &buffer, type: .textArray, format: .binary, context: .default)) + XCTAssertEqual(values, result) + } + + func testDecodeFailureIsNotEmptyOutOfScope() { + var buffer = ByteBuffer() + buffer.writeInteger(Int32(2)) // invalid value + buffer.writeInteger(Int32(0)) + buffer.writeInteger(String.psqlType.rawValue) + + XCTAssertThrowsError(try [String](from: &buffer, type: .textArray, format: .binary, context: .default)) { + XCTAssertEqual($0 as? PostgresDecodingError.Code, .failure) + } + } + + func testDecodeFailureSecondValueIsUnexpected() { + var buffer = ByteBuffer() + buffer.writeInteger(Int32(0)) // is empty + buffer.writeInteger(Int32(1)) // invalid value, must always be 0 + buffer.writeInteger(String.psqlType.rawValue) + + XCTAssertThrowsError(try [String](from: &buffer, type: .textArray, format: .binary, context: .default)) { + XCTAssertEqual($0 as? PostgresDecodingError.Code, .failure) + } + } + + func testDecodeFailureTriesDecodeInt8() { + let value: Int64 = 1 << 32 + var buffer = ByteBuffer() + value.encode(into: &buffer, context: .default) + + XCTAssertThrowsError(try [String](from: &buffer, type: .textArray, format: .binary, context: .default)) { + XCTAssertEqual($0 as? PostgresDecodingError.Code, .failure) + } + } + + func testDecodeFailureInvalidNumberOfArrayElements() { + var buffer = ByteBuffer() + buffer.writeInteger(Int32(1)) // invalid value + buffer.writeInteger(Int32(0)) + buffer.writeInteger(String.psqlType.rawValue) + buffer.writeInteger(Int32(-123)) // expected element count + buffer.writeInteger(Int32(1)) // dimensions... must be one + + XCTAssertThrowsError(try [String](from: &buffer, type: .textArray, format: .binary, context: .default)) { + XCTAssertEqual($0 as? PostgresDecodingError.Code, .failure) + } + } + + func testDecodeFailureInvalidNumberOfDimensions() { + var buffer = ByteBuffer() + buffer.writeInteger(Int32(1)) // invalid value + buffer.writeInteger(Int32(0)) + buffer.writeInteger(String.psqlType.rawValue) + buffer.writeInteger(Int32(1)) // expected element count + buffer.writeInteger(Int32(2)) // dimensions... must be one + + XCTAssertThrowsError(try [String](from: &buffer, type: .textArray, format: .binary, context: .default)) { + XCTAssertEqual($0 as? PostgresDecodingError.Code, .failure) + } + } + + func testDecodeUnexpectedEnd() { + var unexpectedEndInElementLengthBuffer = ByteBuffer() + unexpectedEndInElementLengthBuffer.writeInteger(Int32(1)) // invalid value + unexpectedEndInElementLengthBuffer.writeInteger(Int32(0)) + unexpectedEndInElementLengthBuffer.writeInteger(String.psqlType.rawValue) + unexpectedEndInElementLengthBuffer.writeInteger(Int32(1)) // expected element count + unexpectedEndInElementLengthBuffer.writeInteger(Int32(1)) // dimensions + unexpectedEndInElementLengthBuffer.writeInteger(Int16(1)) // length of element, must be Int32 + + XCTAssertThrowsError(try [String](from: &unexpectedEndInElementLengthBuffer, type: .textArray, format: .binary, context: .default)) { + XCTAssertEqual($0 as? PostgresDecodingError.Code, .failure) + } + + var unexpectedEndInElementBuffer = ByteBuffer() + unexpectedEndInElementBuffer.writeInteger(Int32(1)) // invalid value + unexpectedEndInElementBuffer.writeInteger(Int32(0)) + unexpectedEndInElementBuffer.writeInteger(String.psqlType.rawValue) + unexpectedEndInElementBuffer.writeInteger(Int32(1)) // expected element count + unexpectedEndInElementBuffer.writeInteger(Int32(1)) // dimensions + unexpectedEndInElementBuffer.writeInteger(Int32(12)) // length of element, must be Int32 + unexpectedEndInElementBuffer.writeString("Hello World") // only 11 bytes, 12 needed! + + XCTAssertThrowsError(try [String](from: &unexpectedEndInElementBuffer, type: .textArray, format: .binary, context: .default)) { + XCTAssertEqual($0 as? PostgresDecodingError.Code, .failure) + } + } +} diff --git a/Tests/PostgresNIOTests/New/Data/Bool+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/Bool+PSQLCodableTests.swift new file mode 100644 index 00000000..e6e43f0b --- /dev/null +++ b/Tests/PostgresNIOTests/New/Data/Bool+PSQLCodableTests.swift @@ -0,0 +1,89 @@ +import XCTest +import NIOCore +@testable import PostgresNIO + +class Bool_PSQLCodableTests: XCTestCase { + + // MARK: - Binary + + func testBinaryTrueRoundTrip() { + let value = true + + var buffer = ByteBuffer() + value.encode(into: &buffer, context: .default) + XCTAssertEqual(Bool.psqlType, .bool) + XCTAssertEqual(Bool.psqlFormat, .binary) + XCTAssertEqual(buffer.readableBytes, 1) + XCTAssertEqual(buffer.getInteger(at: buffer.readerIndex, as: UInt8.self), 1) + + var result: Bool? + XCTAssertNoThrow(result = try Bool(from: &buffer, type: .bool, format: .binary, context: .default)) + XCTAssertEqual(value, result) + } + + func testBinaryFalseRoundTrip() { + let value = false + + var buffer = ByteBuffer() + value.encode(into: &buffer, context: .default) + XCTAssertEqual(Bool.psqlType, .bool) + XCTAssertEqual(Bool.psqlFormat, .binary) + XCTAssertEqual(buffer.readableBytes, 1) + XCTAssertEqual(buffer.getInteger(at: buffer.readerIndex, as: UInt8.self), 0) + + var result: Bool? + XCTAssertNoThrow(result = try Bool(from: &buffer, type: .bool, format: .binary, context: .default)) + XCTAssertEqual(value, result) + } + + func testBinaryDecodeBoolInvalidLength() { + var buffer = ByteBuffer() + buffer.writeInteger(Int64(1)) + + XCTAssertThrowsError(try Bool(from: &buffer, type: .bool, format: .binary, context: .default)) { + XCTAssertEqual($0 as? PostgresDecodingError.Code, .failure) + } + } + + func testBinaryDecodeBoolInvalidValue() { + var buffer = ByteBuffer() + buffer.writeInteger(UInt8(13)) + + XCTAssertThrowsError(try Bool(from: &buffer, type: .bool, format: .binary, context: .default)) { + XCTAssertEqual($0 as? PostgresDecodingError.Code, .failure) + } + } + + // MARK: - Text + + func testTextTrueDecode() { + let value = true + + var buffer = ByteBuffer() + buffer.writeInteger(UInt8(ascii: "t")) + + var result: Bool? + XCTAssertNoThrow(result = try Bool(from: &buffer, type: .bool, format: .text, context: .default)) + XCTAssertEqual(value, result) + } + + func testTextFalseDecode() { + let value = false + + var buffer = ByteBuffer() + buffer.writeInteger(UInt8(ascii: "f")) + + var result: Bool? + XCTAssertNoThrow(result = try Bool(from: &buffer, type: .bool, format: .text, context: .default)) + XCTAssertEqual(value, result) + } + + func testTextDecodeBoolInvalidValue() { + var buffer = ByteBuffer() + buffer.writeInteger(UInt8(13)) + + XCTAssertThrowsError(try Bool(from: &buffer, type: .bool, format: .text, context: .default)) { + XCTAssertEqual($0 as? PostgresDecodingError.Code, .failure) + } + } +} diff --git a/Tests/PostgresNIOTests/New/Data/Bytes+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/Bytes+PSQLCodableTests.swift new file mode 100644 index 00000000..9230aee7 --- /dev/null +++ b/Tests/PostgresNIOTests/New/Data/Bytes+PSQLCodableTests.swift @@ -0,0 +1,53 @@ +import XCTest +import NIOCore +@testable import PostgresNIO + +class Bytes_PSQLCodableTests: XCTestCase { + + func testDataRoundTrip() { + let data = Data((0...UInt8.max)) + + var buffer = ByteBuffer() + data.encode(into: &buffer, context: .default) + XCTAssertEqual(ByteBuffer.psqlType, .bytea) + + var result: Data? + result = Data(from: &buffer, type: .bytea, format: .binary, context: .default) + XCTAssertEqual(data, result) + } + + func testByteBufferRoundTrip() { + let bytes = ByteBuffer(bytes: (0...UInt8.max)) + + var buffer = ByteBuffer() + bytes.encode(into: &buffer, context: .default) + XCTAssertEqual(ByteBuffer.psqlType, .bytea) + + var result: ByteBuffer? + result = ByteBuffer(from: &buffer, type: .bytea, format: .binary, context: .default) + XCTAssertEqual(bytes, result) + } + + func testEncodeSequenceWhereElementUInt8() { + struct ByteSequence: Sequence, PostgresEncodable { + typealias Element = UInt8 + typealias Iterator = Array.Iterator + + let bytes: [UInt8] + + init() { + self.bytes = [UInt8]((0...UInt8.max)) + } + + func makeIterator() -> Array.Iterator { + self.bytes.makeIterator() + } + } + + let sequence = ByteSequence() + var buffer = ByteBuffer() + sequence.encode(into: &buffer, context: .default) + XCTAssertEqual(ByteSequence.psqlType, .bytea) + XCTAssertEqual(buffer.readableBytes, 256) + } +} diff --git a/Tests/PostgresNIOTests/New/Data/Date+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/Date+PSQLCodableTests.swift new file mode 100644 index 00000000..3f406598 --- /dev/null +++ b/Tests/PostgresNIOTests/New/Data/Date+PSQLCodableTests.swift @@ -0,0 +1,89 @@ +import XCTest +import NIOCore +@testable import PostgresNIO + +class Date_PSQLCodableTests: XCTestCase { + + func testNowRoundTrip() { + let value = Date() + + var buffer = ByteBuffer() + value.encode(into: &buffer, context: .default) + XCTAssertEqual(Date.psqlType, .timestamptz) + XCTAssertEqual(buffer.readableBytes, 8) + + var result: Date? + XCTAssertNoThrow(result = try Date(from: &buffer, type: .timestamptz, format: .binary, context: .default)) + XCTAssertEqual(value.timeIntervalSince1970, result?.timeIntervalSince1970 ?? 0, accuracy: 0.001) + } + + func testDecodeRandomDate() { + var buffer = ByteBuffer() + buffer.writeInteger(Int64.random(in: Int64.min...Int64.max)) + + var result: Date? + XCTAssertNoThrow(result = try Date(from: &buffer, type: .timestamptz, format: .binary, context: .default)) + XCTAssertNotNil(result) + } + + func testDecodeFailureInvalidLength() { + var buffer = ByteBuffer() + buffer.writeInteger(Int64.random(in: Int64.min...Int64.max)) + buffer.writeInteger(Int64.random(in: Int64.min...Int64.max)) + + XCTAssertThrowsError(try Date(from: &buffer, type: .timestamptz, format: .binary, context: .default)) { + XCTAssertEqual($0 as? PostgresDecodingError.Code, .failure) + } + } + + func testDecodeDate() { + var firstDateBuffer = ByteBuffer() + firstDateBuffer.writeInteger(Int32.min) + + var firstDate: Date? + XCTAssertNoThrow(firstDate = try Date(from: &firstDateBuffer, type: .date, format: .binary, context: .default)) + XCTAssertNotNil(firstDate) + + var lastDateBuffer = ByteBuffer() + lastDateBuffer.writeInteger(Int32.max) + + var lastDate: Date? + XCTAssertNoThrow(lastDate = try Date(from: &lastDateBuffer, type: .date, format: .binary, context: .default)) + XCTAssertNotNil(lastDate) + } + + func testDecodeDateFromTimestamp() { + var firstDateBuffer = ByteBuffer() + firstDateBuffer.writeInteger(Int32.min) + + var firstDate: Date? + XCTAssertNoThrow(firstDate = try Date(from: &firstDateBuffer, type: .date, format: .binary, context: .default)) + XCTAssertNotNil(firstDate) + + var lastDateBuffer = ByteBuffer() + lastDateBuffer.writeInteger(Int32.max) + + var lastDate: Date? + XCTAssertNoThrow(lastDate = try Date(from: &lastDateBuffer, type: .date, format: .binary, context: .default)) + XCTAssertNotNil(lastDate) + } + + func testDecodeDateFailsWithTooMuchData() { + var buffer = ByteBuffer() + buffer.writeInteger(Int64(0)) + + XCTAssertThrowsError(try Date(from: &buffer, type: .date, format: .binary, context: .default)) { + XCTAssertEqual($0 as? PostgresDecodingError.Code, .failure) + } + } + + func testDecodeDateFailsWithWrongDataType() { + var buffer = ByteBuffer() + buffer.writeInteger(Int64(0)) + + XCTAssertThrowsError(try Date(from: &buffer, type: .int8, format: .binary, context: .default)) { + XCTAssertEqual($0 as? PostgresDecodingError.Code, .typeMismatch) + } + } + +} diff --git a/Tests/PostgresNIOTests/New/Data/Decimal+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/Decimal+PSQLCodableTests.swift new file mode 100644 index 00000000..f9d57397 --- /dev/null +++ b/Tests/PostgresNIOTests/New/Data/Decimal+PSQLCodableTests.swift @@ -0,0 +1,30 @@ +import XCTest +import NIOCore +@testable import PostgresNIO + +class Decimal_PSQLCodableTests: XCTestCase { + + func testRoundTrip() { + let values: [Decimal] = [1.1, .pi, -5e-12] + + for value in values { + var buffer = ByteBuffer() + value.encode(into: &buffer, context: .default) + XCTAssertEqual(Decimal.psqlType, .numeric) + + var result: Decimal? + XCTAssertNoThrow(result = try Decimal(from: &buffer, type: .numeric, format: .binary, context: .default)) + XCTAssertEqual(value, result) + } + } + + func testDecodeFailureInvalidType() { + var buffer = ByteBuffer() + buffer.writeInteger(Int64(0)) + + XCTAssertThrowsError(try Decimal(from: &buffer, type: .int8, format: .binary, context: .default)) { + XCTAssertEqual($0 as? PostgresDecodingError.Code, .typeMismatch) + } + } + +} diff --git a/Tests/PostgresNIOTests/New/Data/Float+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/Float+PSQLCodableTests.swift new file mode 100644 index 00000000..728b87b7 --- /dev/null +++ b/Tests/PostgresNIOTests/New/Data/Float+PSQLCodableTests.swift @@ -0,0 +1,134 @@ +import XCTest +import NIOCore +@testable import PostgresNIO + +class Float_PSQLCodableTests: XCTestCase { + + func testRoundTripDoubles() { + let values: [Double] = [1.1, .pi, -5e-12] + + for value in values { + var buffer = ByteBuffer() + value.encode(into: &buffer, context: .default) + XCTAssertEqual(Double.psqlType, .float8) + XCTAssertEqual(buffer.readableBytes, 8) + + var result: Double? + XCTAssertNoThrow(result = try Double(from: &buffer, type: .float8, format: .binary, context: .default)) + XCTAssertEqual(value, result) + } + } + + func testRoundTripFloat() { + let values: [Float] = [1.1, .pi, -5e-12] + + for value in values { + var buffer = ByteBuffer() + value.encode(into: &buffer, context: .default) + XCTAssertEqual(Float.psqlType, .float4) + XCTAssertEqual(buffer.readableBytes, 4) + + var result: Float? + XCTAssertNoThrow(result = try Float(from: &buffer, type: .float4, format: .binary, context: .default)) + XCTAssertEqual(value, result) + } + } + + func testRoundTripDoubleNaN() { + let value: Double = .nan + + var buffer = ByteBuffer() + value.encode(into: &buffer, context: .default) + XCTAssertEqual(Double.psqlType, .float8) + XCTAssertEqual(buffer.readableBytes, 8) + + var result: Double? + XCTAssertNoThrow(result = try Double(from: &buffer, type: .float8, format: .binary, context: .default)) + XCTAssertEqual(result?.isNaN, true) + } + + func testRoundTripDoubleInfinity() { + let value: Double = .infinity + + var buffer = ByteBuffer() + value.encode(into: &buffer, context: .default) + XCTAssertEqual(Double.psqlType, .float8) + XCTAssertEqual(buffer.readableBytes, 8) + + var result: Double? + XCTAssertNoThrow(result = try Double(from: &buffer, type: .float8, format: .binary, context: .default)) + XCTAssertEqual(result?.isInfinite, true) + } + + func testRoundTripFromFloatToDouble() { + let values: [Float] = [1.1, .pi, -5e-12] + + for value in values { + var buffer = ByteBuffer() + value.encode(into: &buffer, context: .default) + XCTAssertEqual(Float.psqlType, .float4) + XCTAssertEqual(buffer.readableBytes, 4) + + var result: Double? + XCTAssertNoThrow(result = try Double(from: &buffer, type: .float4, format: .binary, context: .default)) + XCTAssertEqual(result, Double(value)) + } + } + + func testRoundTripFromDoubleToFloat() { + let values: [Double] = [1.1, .pi, -5e-12] + + for value in values { + var buffer = ByteBuffer() + value.encode(into: &buffer, context: .default) + XCTAssertEqual(Double.psqlType, .float8) + XCTAssertEqual(buffer.readableBytes, 8) + + var result: Float? + XCTAssertNoThrow(result = try Float(from: &buffer, type: .float8, format: .binary, context: .default)) + XCTAssertEqual(result, Float(value)) + } + } + + func testDecodeFailureInvalidLength() { + var eightByteBuffer = ByteBuffer() + eightByteBuffer.writeInteger(Int64(0)) + var fourByteBuffer = ByteBuffer() + fourByteBuffer.writeInteger(Int32(0)) + + var toLongBuffer1 = eightByteBuffer + XCTAssertThrowsError(try Double(from: &toLongBuffer1, type: .float4, format: .binary, context: .default)) { + XCTAssertEqual($0 as? PostgresDecodingError.Code, .failure) + } + + var toLongBuffer2 = eightByteBuffer + XCTAssertThrowsError(try Float(from: &toLongBuffer2, type: .float4, format: .binary, context: .default)) { + XCTAssertEqual($0 as? PostgresDecodingError.Code, .failure) + } + + var toShortBuffer1 = fourByteBuffer + XCTAssertThrowsError(try Double(from: &toShortBuffer1, type: .float8, format: .binary, context: .default)) { + XCTAssertEqual($0 as? PostgresDecodingError.Code, .failure) + } + + var toShortBuffer2 = fourByteBuffer + XCTAssertThrowsError(try Float(from: &toShortBuffer2, type: .float8, format: .binary, context: .default)) { + XCTAssertEqual($0 as? PostgresDecodingError.Code, .failure) + } + } + + func testDecodeFailureInvalidType() { + var buffer = ByteBuffer() + buffer.writeInteger(Int64(0)) + + var copy1 = buffer + XCTAssertThrowsError(try Double(from: ©1, type: .int8, format: .binary, context: .default)) { + XCTAssertEqual($0 as? PostgresDecodingError.Code, .typeMismatch) + } + + var copy2 = buffer + XCTAssertThrowsError(try Float(from: ©2, type: .int8, format: .binary, context: .default)) { + XCTAssertEqual($0 as? PostgresDecodingError.Code, .typeMismatch) + } + } +} diff --git a/Tests/PostgresNIOTests/New/Data/Int+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/Int+PSQLCodableTests.swift new file mode 100644 index 00000000..0f58fc72 --- /dev/null +++ b/Tests/PostgresNIOTests/New/Data/Int+PSQLCodableTests.swift @@ -0,0 +1,6 @@ +import XCTest +@testable import PostgresNIO + +class Int_PSQLCodableTests: XCTestCase { + +} diff --git a/Tests/PostgresNIOTests/New/Data/JSON+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/JSON+PSQLCodableTests.swift new file mode 100644 index 00000000..52dead6a --- /dev/null +++ b/Tests/PostgresNIOTests/New/Data/JSON+PSQLCodableTests.swift @@ -0,0 +1,91 @@ +import XCTest +import Atomics +import NIOCore +@testable import PostgresNIO + +class JSON_PSQLCodableTests: XCTestCase { + + struct Hello: Equatable, Codable, PostgresCodable { + let hello: String + + init(name: String) { + self.hello = name + } + } + + func testRoundTrip() { + var buffer = ByteBuffer() + let hello = Hello(name: "world") + XCTAssertNoThrow(try hello.encode(into: &buffer, context: .default)) + XCTAssertEqual(Hello.psqlType, .jsonb) + + // verify jsonb prefix byte + XCTAssertEqual(buffer.getInteger(at: buffer.readerIndex, as: UInt8.self), 1) + + var result: Hello? + XCTAssertNoThrow(result = try Hello(from: &buffer, type: .jsonb, format: .binary, context: .default)) + XCTAssertEqual(result, hello) + } + + func testDecodeFromJSON() { + var buffer = ByteBuffer() + buffer.writeString(#"{"hello":"world"}"#) + + var result: Hello? + XCTAssertNoThrow(result = try Hello(from: &buffer, type: .json, format: .binary, context: .default)) + XCTAssertEqual(result, Hello(name: "world")) + } + + func testDecodeFromJSONAsText() { + let combinations : [(PostgresFormat, PostgresDataType)] = [ + (.text, .json), (.text, .jsonb), + ] + var buffer = ByteBuffer() + buffer.writeString(#"{"hello":"world"}"#) + + for (format, dataType) in combinations { + var loopBuffer = buffer + var result: Hello? + XCTAssertNoThrow(result = try Hello(from: &loopBuffer, type: dataType, format: format, context: .default)) + XCTAssertEqual(result, Hello(name: "world")) + } + } + + func testDecodeFromJSONBWithoutVersionPrefixByte() { + var buffer = ByteBuffer() + buffer.writeString(#"{"hello":"world"}"#) + + XCTAssertThrowsError(try Hello(from: &buffer, type: .jsonb, format: .binary, context: .default)) { + XCTAssertEqual($0 as? PostgresDecodingError.Code, .failure) + } + } + + func testDecodeFromJSONBWithWrongDataType() { + var buffer = ByteBuffer() + buffer.writeString(#"{"hello":"world"}"#) + + XCTAssertThrowsError(try Hello(from: &buffer, type: .text, format: .binary, context: .default)) { + XCTAssertEqual($0 as? PostgresDecodingError.Code, .typeMismatch) + } + } + + func testCustomEncoderIsUsed() { + final class TestEncoder: PostgresJSONEncoder { + let encodeHits = ManagedAtomic(0) + + func encode(_ value: T, into buffer: inout ByteBuffer) throws where T : Encodable { + self.encodeHits.wrappingIncrement(ordering: .relaxed) + } + + func encode(_ value: T) throws -> Data where T : Encodable { + preconditionFailure() + } + } + + let hello = Hello(name: "world") + let encoder = TestEncoder() + var buffer = ByteBuffer() + XCTAssertNoThrow(try hello.encode(into: &buffer, context: .init(jsonEncoder: encoder))) + XCTAssertEqual(encoder.encodeHits.load(ordering: .relaxed), 1) + } +} diff --git a/Tests/PostgresNIOTests/New/Data/Range+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/Range+PSQLCodableTests.swift new file mode 100644 index 00000000..a040c3f4 --- /dev/null +++ b/Tests/PostgresNIOTests/New/Data/Range+PSQLCodableTests.swift @@ -0,0 +1,105 @@ +import XCTest +import NIOCore +@testable import PostgresNIO + +class Range_PSQLCodableTests: XCTestCase { + func testInt32RangeRoundTrip() { + let lowerBound = Int32.min + let upperBound = Int32.max + let value: Range = lowerBound...psqlType, .int4Range) + XCTAssertEqual(buffer.readableBytes, 17) + XCTAssertEqual(buffer.getInteger(at: 0, as: UInt8.self), 2) + XCTAssertEqual(buffer.getInteger(at: 1, as: UInt32.self), 4) + XCTAssertEqual(buffer.getInteger(at: 5, as: Int32.self), lowerBound) + XCTAssertEqual(buffer.getInteger(at: 9, as: UInt32.self), 4) + XCTAssertEqual(buffer.getInteger(at: 13, as: Int32.self), upperBound) + + var result: Range? + XCTAssertNoThrow(result = try Range(from: &buffer, type: .int4Range, format: .binary, context: .default)) + XCTAssertEqual(value, result) + } + + func testInt32ClosedRangeRoundTrip() { + let lowerBound = Int32.min + let upperBound = Int32.max - 1 + let value: ClosedRange = lowerBound...upperBound + + var buffer = ByteBuffer() + value.encode(into: &buffer, context: .default) + XCTAssertEqual(ClosedRange.psqlType, .int4Range) + XCTAssertEqual(buffer.readableBytes, 17) + XCTAssertEqual(buffer.getInteger(at: 0, as: UInt8.self), 6) + XCTAssertEqual(buffer.getInteger(at: 1, as: UInt32.self), 4) + XCTAssertEqual(buffer.getInteger(at: 5, as: Int32.self), lowerBound) + XCTAssertEqual(buffer.getInteger(at: 9, as: UInt32.self), 4) + XCTAssertEqual(buffer.getInteger(at: 13, as: Int32.self), upperBound) + + var result: ClosedRange? + XCTAssertNoThrow(result = try ClosedRange(from: &buffer, type: .int4Range, format: .binary, context: .default)) + XCTAssertEqual(value, result) + } + + func testInt64RangeRoundTrip() { + let lowerBound = Int64.min + let upperBound = Int64.max + let value: Range = lowerBound...psqlType, .int8Range) + XCTAssertEqual(buffer.readableBytes, 25) + XCTAssertEqual(buffer.getInteger(at: 0, as: UInt8.self), 2) + XCTAssertEqual(buffer.getInteger(at: 1, as: UInt32.self), 8) + XCTAssertEqual(buffer.getInteger(at: 5, as: Int64.self), lowerBound) + XCTAssertEqual(buffer.getInteger(at: 13, as: UInt32.self), 8) + XCTAssertEqual(buffer.getInteger(at: 17, as: Int64.self), upperBound) + + var result: Range? + XCTAssertNoThrow(result = try Range(from: &buffer, type: .int8Range, format: .binary, context: .default)) + XCTAssertEqual(value, result) + } + + func testInt64ClosedRangeRoundTrip() { + let lowerBound = Int64.min + let upperBound = Int64.max - 1 + let value: ClosedRange = lowerBound...upperBound + + var buffer = ByteBuffer() + value.encode(into: &buffer, context: .default) + XCTAssertEqual(ClosedRange.psqlType, .int8Range) + XCTAssertEqual(buffer.readableBytes, 25) + XCTAssertEqual(buffer.getInteger(at: 0, as: UInt8.self), 6) + XCTAssertEqual(buffer.getInteger(at: 1, as: UInt32.self), 8) + XCTAssertEqual(buffer.getInteger(at: 5, as: Int64.self), lowerBound) + XCTAssertEqual(buffer.getInteger(at: 13, as: UInt32.self), 8) + XCTAssertEqual(buffer.getInteger(at: 17, as: Int64.self), upperBound) + + var result: ClosedRange? + XCTAssertNoThrow(result = try ClosedRange(from: &buffer, type: .int8Range, format: .binary, context: .default)) + XCTAssertEqual(value, result) + } + + func testInt64RangeDecodeFailureInvalidLength() { + var buffer = ByteBuffer() + buffer.writeInteger(0) + buffer.writeInteger(Int64.random(in: Int64.min...Int64.max)) + buffer.writeInteger(Int64.random(in: Int64.min...Int64.max)) + + XCTAssertThrowsError(try Range(from: &buffer, type: .int8Range, format: .binary, context: .default)) { + XCTAssertEqual($0 as? PostgresDecodingError.Code, .failure) + } + } + + func testInt64RangeDecodeFailureWrongDataType() { + var buffer = ByteBuffer() + (Int64.min...Int64.max).encode(into: &buffer, context: .default) + + XCTAssertThrowsError(try Range(from: &buffer, type: .int8, format: .binary, context: .default)) { + XCTAssertEqual($0 as? PostgresDecodingError.Code, .failure) + } + } +} diff --git a/Tests/PostgresNIOTests/New/Data/RawRepresentable+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/RawRepresentable+PSQLCodableTests.swift new file mode 100644 index 00000000..0868a4ee --- /dev/null +++ b/Tests/PostgresNIOTests/New/Data/RawRepresentable+PSQLCodableTests.swift @@ -0,0 +1,46 @@ +import XCTest +import NIOCore +@testable import PostgresNIO + +class RawRepresentable_PSQLCodableTests: XCTestCase { + + enum MyRawRepresentable: Int16, PostgresCodable { + case testing = 1 + case staging = 2 + case production = 3 + } + + func testRoundTrip() { + let values: [MyRawRepresentable] = [.testing, .staging, .production] + + for value in values { + var buffer = ByteBuffer() + XCTAssertNoThrow(try value.encode(into: &buffer, context: .default)) + XCTAssertEqual(MyRawRepresentable.psqlType, Int16.psqlType) + XCTAssertEqual(buffer.readableBytes, 2) + + var result: MyRawRepresentable? + XCTAssertNoThrow(result = try MyRawRepresentable(from: &buffer, type: Int16.psqlType, format: .binary, context: .default)) + XCTAssertEqual(value, result) + } + } + + func testDecodeInvalidRawTypeValue() { + var buffer = ByteBuffer() + buffer.writeInteger(Int16(4)) // out of bounds + + XCTAssertThrowsError(try MyRawRepresentable(from: &buffer, type: Int16.psqlType, format: .binary, context: .default)) { + XCTAssertEqual($0 as? PostgresDecodingError.Code, .failure) + } + } + + func testDecodeInvalidUnderlyingTypeValue() { + var buffer = ByteBuffer() + buffer.writeInteger(Int32(1)) // out of bounds + + XCTAssertThrowsError(try MyRawRepresentable(from: &buffer, type: Int32.psqlType, format: .binary, context: .default)) { + XCTAssertEqual($0 as? PostgresDecodingError.Code, .failure) + } + } + +} diff --git a/Tests/PostgresNIOTests/New/Data/String+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/String+PSQLCodableTests.swift new file mode 100644 index 00000000..6ff35130 --- /dev/null +++ b/Tests/PostgresNIOTests/New/Data/String+PSQLCodableTests.swift @@ -0,0 +1,67 @@ +import XCTest +import NIOCore +@testable import PostgresNIO + +class String_PSQLCodableTests: XCTestCase { + + func testEncode() { + let value = "Hello World" + var buffer = ByteBuffer() + + value.encode(into: &buffer, context: .default) + + XCTAssertEqual(String.psqlType, .text) + XCTAssertEqual(buffer.readString(length: buffer.readableBytes), value) + } + + func testDecodeStringFromTextVarchar() { + let expected = "Hello World" + var buffer = ByteBuffer() + buffer.writeString(expected) + + let dataTypes: [PostgresDataType] = [ + .text, .varchar, .name, .bpchar + ] + + for dataType in dataTypes { + var loopBuffer = buffer + var result: String? + XCTAssertNoThrow(result = try String(from: &loopBuffer, type: dataType, format: .binary, context: .default)) + XCTAssertEqual(result, expected) + } + } + + func testDecodeFailureFromInvalidType() { + let buffer = ByteBuffer() + let dataTypes: [PostgresDataType] = [.bool, .float4Array, .float8Array] + + for dataType in dataTypes { + var loopBuffer = buffer + XCTAssertThrowsError(try String(from: &loopBuffer, type: dataType, format: .binary, context: .default)) { + XCTAssertEqual($0 as? PostgresDecodingError.Code, .typeMismatch) + } + } + } + + func testDecodeFromUUID() { + let uuid = UUID() + var buffer = ByteBuffer() + uuid.encode(into: &buffer, context: .default) + + var decoded: String? + XCTAssertNoThrow(decoded = try String(from: &buffer, type: .uuid, format: .binary, context: .default)) + XCTAssertEqual(decoded, uuid.uuidString) + } + + func testDecodeFailureFromInvalidUUID() { + let uuid = UUID() + var buffer = ByteBuffer() + uuid.encode(into: &buffer, context: .default) + // this makes only 15 bytes readable. this should lead to an error + buffer.moveReaderIndex(forwardBy: 1) + + XCTAssertThrowsError(try String(from: &buffer, type: .uuid, format: .binary, context: .default)) { + XCTAssertEqual($0 as? PostgresDecodingError.Code, .failure) + } + } +} diff --git a/Tests/PostgresNIOTests/New/Data/UUID+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/UUID+PSQLCodableTests.swift new file mode 100644 index 00000000..2ca2d1d8 --- /dev/null +++ b/Tests/PostgresNIOTests/New/Data/UUID+PSQLCodableTests.swift @@ -0,0 +1,121 @@ +import XCTest +import NIOCore +@testable import PostgresNIO + +class UUID_PSQLCodableTests: XCTestCase { + + func testRoundTrip() { + for _ in 0..<100 { + let uuid = UUID() + var buffer = ByteBuffer() + + uuid.encode(into: &buffer, context: .default) + + XCTAssertEqual(UUID.psqlType, .uuid) + XCTAssertEqual(UUID.psqlFormat, .binary) + XCTAssertEqual(buffer.readableBytes, 16) + var byteIterator = buffer.readableBytesView.makeIterator() + + XCTAssertEqual(byteIterator.next(), uuid.uuid.0) + XCTAssertEqual(byteIterator.next(), uuid.uuid.1) + XCTAssertEqual(byteIterator.next(), uuid.uuid.2) + XCTAssertEqual(byteIterator.next(), uuid.uuid.3) + XCTAssertEqual(byteIterator.next(), uuid.uuid.4) + XCTAssertEqual(byteIterator.next(), uuid.uuid.5) + XCTAssertEqual(byteIterator.next(), uuid.uuid.6) + XCTAssertEqual(byteIterator.next(), uuid.uuid.7) + XCTAssertEqual(byteIterator.next(), uuid.uuid.8) + XCTAssertEqual(byteIterator.next(), uuid.uuid.9) + XCTAssertEqual(byteIterator.next(), uuid.uuid.10) + XCTAssertEqual(byteIterator.next(), uuid.uuid.11) + XCTAssertEqual(byteIterator.next(), uuid.uuid.12) + XCTAssertEqual(byteIterator.next(), uuid.uuid.13) + XCTAssertEqual(byteIterator.next(), uuid.uuid.14) + XCTAssertEqual(byteIterator.next(), uuid.uuid.15) + + var decoded: UUID? + XCTAssertNoThrow(decoded = try UUID(from: &buffer, type: .uuid, format: .binary, context: .default)) + XCTAssertEqual(decoded, uuid) + } + } + + func testDecodeFromString() { + let options: [(PostgresFormat, PostgresDataType)] = [ + (.binary, .text), + (.binary, .varchar), + (.text, .uuid), + (.text, .text), + (.text, .varchar), + ] + + for _ in 0..<100 { + // use uppercase + let uuid = UUID() + var lowercaseBuffer = ByteBuffer() + lowercaseBuffer.writeString(uuid.uuidString.lowercased()) + + for (format, dataType) in options { + var loopBuffer = lowercaseBuffer + var decoded: UUID? + XCTAssertNoThrow(decoded = try UUID(from: &loopBuffer, type: dataType, format: format, context: .default)) + XCTAssertEqual(decoded, uuid) + } + + // use lowercase + var uppercaseBuffer = ByteBuffer() + uppercaseBuffer.writeString(uuid.uuidString) + + for (format, dataType) in options { + var loopBuffer = uppercaseBuffer + var decoded: UUID? + XCTAssertNoThrow(decoded = try UUID(from: &loopBuffer, type: dataType, format: format, context: .default)) + XCTAssertEqual(decoded, uuid) + } + } + } + + func testDecodeFailureFromBytes() { + let uuid = UUID() + var buffer = ByteBuffer() + + uuid.encode(into: &buffer, context: .default) + // this makes only 15 bytes readable. this should lead to an error + buffer.moveReaderIndex(forwardBy: 1) + + XCTAssertThrowsError(try UUID(from: &buffer, type: .uuid, format: .binary, context: .default)) { error in + XCTAssertEqual(error as? PostgresDecodingError.Code, .failure) + } + } + + func testDecodeFailureFromString() { + let uuid = UUID() + var buffer = ByteBuffer() + buffer.writeString(uuid.uuidString) + // this makes only 15 bytes readable. this should lead to an error + buffer.moveReaderIndex(forwardBy: 1) + + let dataTypes: [PostgresDataType] = [.varchar, .text] + + for dataType in dataTypes { + var loopBuffer = buffer + XCTAssertThrowsError(try UUID(from: &loopBuffer, type: dataType, format: .binary, context: .default)) { + XCTAssertEqual($0 as? PostgresDecodingError.Code, .failure) + } + } + } + + func testDecodeFailureFromInvalidPostgresType() { + let uuid = UUID() + var buffer = ByteBuffer() + buffer.writeString(uuid.uuidString) + + let dataTypes: [PostgresDataType] = [.bool, .int8, .int2, .int4Array] + + for dataType in dataTypes { + var copy = buffer + XCTAssertThrowsError(try UUID(from: ©, type: dataType, format: .binary, context: .default)) { + XCTAssertEqual($0 as? PostgresDecodingError.Code, .typeMismatch) + } + } + } +} diff --git a/Tests/PostgresNIOTests/New/Extensions/ByteBuffer+Utils.swift b/Tests/PostgresNIOTests/New/Extensions/ByteBuffer+Utils.swift new file mode 100644 index 00000000..7d073873 --- /dev/null +++ b/Tests/PostgresNIOTests/New/Extensions/ByteBuffer+Utils.swift @@ -0,0 +1,23 @@ +import NIOCore +@testable import PostgresNIO + +extension ByteBuffer { + mutating func psqlWriteBackendMessageID(_ messageID: PostgresBackendMessage.ID) { + self.writeInteger(messageID.rawValue) + } + + static func backendMessage(id: PostgresBackendMessage.ID, _ payload: (inout ByteBuffer) throws -> ()) rethrows -> ByteBuffer { + var byteBuffer = ByteBuffer() + try byteBuffer.writeBackendMessage(id: id, payload) + return byteBuffer + } + + mutating func writeBackendMessage(id: PostgresBackendMessage.ID, _ payload: (inout ByteBuffer) throws -> ()) rethrows { + self.psqlWriteBackendMessageID(id) + let lengthIndex = self.writerIndex + self.writeInteger(Int32(0)) + try payload(&self) + let length = self.writerIndex - lengthIndex + self.setInteger(Int32(length), at: lengthIndex) + } +} diff --git a/Tests/PostgresNIOTests/New/Extensions/ConnectionAction+TestUtils.swift b/Tests/PostgresNIOTests/New/Extensions/ConnectionAction+TestUtils.swift new file mode 100644 index 00000000..9a1224d8 --- /dev/null +++ b/Tests/PostgresNIOTests/New/Extensions/ConnectionAction+TestUtils.swift @@ -0,0 +1,120 @@ +import class Foundation.JSONEncoder +import NIOCore +@testable import PostgresNIO + +// fully-qualifying all types in the extension has the same effect as adding a `@retroactive` before the protocol +extension PostgresNIO.ConnectionStateMachine.ConnectionAction: Swift.Equatable { + public static func == (lhs: Self, rhs: Self) -> Bool { + switch (lhs, rhs) { + case (.read, read): + return true + case (.wait, .wait): + return true + case (.provideAuthenticationContext, .provideAuthenticationContext): + return true + case (.sendStartupMessage, sendStartupMessage): + return true + case (.sendSSLRequest, sendSSLRequest): + return true + case (.establishSSLConnection, establishSSLConnection): + return true + case (.closeConnectionAndCleanup(let lhs), .closeConnectionAndCleanup(let rhs)): + return lhs == rhs + case (.sendPasswordMessage(let lhsMethod, let lhsAuthContext), sendPasswordMessage(let rhsMethod, let rhsAuthContext)): + return lhsMethod == rhsMethod && lhsAuthContext == rhsAuthContext + case (.sendParseDescribeBindExecuteSync(let lquery), sendParseDescribeBindExecuteSync(let rquery)): + return lquery == rquery + case (.fireEventReadyForQuery, .fireEventReadyForQuery): + return true + case (.succeedQuery(let lhsPromise, let lhsResult), .succeedQuery(let rhsPromise, let rhsResult)): + return lhsPromise.futureResult === rhsPromise.futureResult && lhsResult.value == rhsResult.value + case (.failQuery(let lhsPromise, let lhsError, let lhsCleanupContext), .failQuery(let rhsPromise, let rhsError, let rhsCleanupContext)): + return lhsPromise.futureResult === rhsPromise.futureResult && lhsError == rhsError && lhsCleanupContext == rhsCleanupContext + case (.forwardRows(let lhsRows), .forwardRows(let rhsRows)): + return lhsRows == rhsRows + case (.forwardStreamComplete(let lhsBuffer, let lhsCommandTag), .forwardStreamComplete(let rhsBuffer, let rhsCommandTag)): + return lhsBuffer == rhsBuffer && lhsCommandTag == rhsCommandTag + case (.forwardStreamError(let lhsError, let lhsRead, let lhsCleanupContext), .forwardStreamError(let rhsError , let rhsRead, let rhsCleanupContext)): + return lhsError == rhsError && lhsRead == rhsRead && lhsCleanupContext == rhsCleanupContext + case (.sendParseDescribeSync(let lhsName, let lhsQuery, let lhsDataTypes), .sendParseDescribeSync(let rhsName, let rhsQuery, let rhsDataTypes)): + return lhsName == rhsName && lhsQuery == rhsQuery && lhsDataTypes == rhsDataTypes + case (.succeedPreparedStatementCreation(let lhsPromise, let lhsRowDescription), .succeedPreparedStatementCreation(let rhsPromise, let rhsRowDescription)): + return lhsPromise.futureResult === rhsPromise.futureResult && lhsRowDescription == rhsRowDescription + case (.fireChannelInactive, .fireChannelInactive): + return true + default: + return false + } + } +} + +// fully-qualifying all types in the extension has the same effect as adding a `@retroactive` before the protocol' +extension PostgresNIO.ConnectionStateMachine.ConnectionAction.CleanUpContext: Swift.Equatable { + public static func == (lhs: Self, rhs: Self) -> Bool { + guard lhs.closePromise?.futureResult === rhs.closePromise?.futureResult else { + return false + } + + guard lhs.error == rhs.error else { + return false + } + + guard lhs.tasks == rhs.tasks else { + return false + } + + return true + } +} + +extension ConnectionStateMachine { + static func readyForQuery(transactionState: PostgresBackendMessage.TransactionState = .idle) -> Self { + let connectionContext = Self.createConnectionContext(transactionState: transactionState) + return ConnectionStateMachine(.readyForQuery(connectionContext)) + } + + static func createConnectionContext(transactionState: PostgresBackendMessage.TransactionState = .idle) -> ConnectionContext { + let backendKeyData = BackendKeyData(processID: 2730, secretKey: 882037977) + + let paramaters = [ + "DateStyle": "ISO, MDY", + "application_name": "", + "server_encoding": "UTF8", + "integer_datetimes": "on", + "client_encoding": "UTF8", + "TimeZone": "Etc/UTC", + "is_superuser": "on", + "server_version": "13.1 (Debian 13.1-1.pgdg100+1)", + "session_authorization": "postgres", + "IntervalStyle": "postgres", + "standard_conforming_strings": "on" + ] + + return ConnectionContext( + backendKeyData: backendKeyData, + parameters: paramaters, + transactionState: transactionState + ) + } +} + +// fully-qualifying all types in the extension has the same effect as adding a `@retroactive` before the protocol +extension PostgresNIO.PSQLError: Swift.Equatable { + public static func == (lhs: PSQLError, rhs: PSQLError) -> Bool { + return true + } +} + +// fully-qualifying all types in the extension has the same effect as adding a `@retroactive` before the protocol +extension PostgresNIO.PSQLTask: Swift.Equatable { + public static func == (lhs: PSQLTask, rhs: PSQLTask) -> Bool { + switch (lhs, rhs) { + case (.extendedQuery(let lhs), .extendedQuery(let rhs)): + return lhs === rhs + case (.closeCommand(let lhs), .closeCommand(let rhs)): + return lhs === rhs + default: + return false + } + } +} diff --git a/Tests/PostgresNIOTests/New/Extensions/PSQLBackendMessageEncoder.swift b/Tests/PostgresNIOTests/New/Extensions/PSQLBackendMessageEncoder.swift new file mode 100644 index 00000000..9614bf1e --- /dev/null +++ b/Tests/PostgresNIOTests/New/Extensions/PSQLBackendMessageEncoder.swift @@ -0,0 +1,263 @@ +import NIOCore +@testable import PostgresNIO + +struct PSQLBackendMessageEncoder: MessageToByteEncoder { + typealias OutboundIn = PostgresBackendMessage + + /// Called once there is data to encode. + /// + /// - parameters: + /// - data: The data to encode into a `ByteBuffer`. + /// - out: The `ByteBuffer` into which we want to encode. + func encode(data message: PostgresBackendMessage, out buffer: inout ByteBuffer) { + switch message { + case .authentication(let authentication): + self.encode(messageID: message.id, payload: authentication, into: &buffer) + + case .backendKeyData(let keyData): + self.encode(messageID: message.id, payload: keyData, into: &buffer) + + case .bindComplete, + .closeComplete, + .emptyQueryResponse, + .noData, + .parseComplete, + .portalSuspended: + self.encode(messageID: message.id, payload: EmptyPayload(), into: &buffer) + + case .commandComplete(let string): + self.encode(messageID: message.id, payload: StringPayload(string), into: &buffer) + + case .dataRow(let row): + self.encode(messageID: message.id, payload: row, into: &buffer) + + case .error(let errorResponse): + self.encode(messageID: message.id, payload: errorResponse, into: &buffer) + + case .notice(let noticeResponse): + self.encode(messageID: message.id, payload: noticeResponse, into: &buffer) + + case .notification(let notificationResponse): + self.encode(messageID: message.id, payload: notificationResponse, into: &buffer) + + case .parameterDescription(let description): + self.encode(messageID: message.id, payload: description, into: &buffer) + + case .parameterStatus(let status): + self.encode(messageID: message.id, payload: status, into: &buffer) + + case .readyForQuery(let transactionState): + self.encode(messageID: message.id, payload: transactionState, into: &buffer) + + case .rowDescription(let description): + self.encode(messageID: message.id, payload: description, into: &buffer) + + case .sslSupported: + buffer.writeInteger(UInt8(ascii: "S")) + + case .sslUnsupported: + buffer.writeInteger(UInt8(ascii: "N")) + } + } + + private struct EmptyPayload: PSQLMessagePayloadEncodable { + func encode(into buffer: inout ByteBuffer) {} + } + + private struct StringPayload: PSQLMessagePayloadEncodable { + var string: String + init(_ string: String) { self.string = string } + func encode(into buffer: inout ByteBuffer) { + buffer.writeNullTerminatedString(self.string) + } + } + + private func encode( + messageID: PostgresBackendMessage.ID, + payload: Payload, + into buffer: inout ByteBuffer) + { + buffer.psqlWriteBackendMessageID(messageID) + let startIndex = buffer.writerIndex + buffer.writeInteger(Int32(0)) // placeholder for length + payload.encode(into: &buffer) + let length = Int32(buffer.writerIndex - startIndex) + buffer.setInteger(length, at: startIndex) + } +} + +extension PostgresBackendMessage { + var id: ID { + switch self { + case .authentication: + return .authentication + case .backendKeyData: + return .backendKeyData + case .bindComplete: + return .bindComplete + case .closeComplete: + return .closeComplete + case .commandComplete: + return .commandComplete + case .dataRow: + return .dataRow + case .emptyQueryResponse: + return .emptyQueryResponse + case .error: + return .error + case .noData: + return .noData + case .notice: + return .noticeResponse + case .notification: + return .notificationResponse + case .parameterDescription: + return .parameterDescription + case .parameterStatus: + return .parameterStatus + case .parseComplete: + return .parseComplete + case .portalSuspended: + return .portalSuspended + case .readyForQuery: + return .readyForQuery + case .rowDescription: + return .rowDescription + case .sslSupported, + .sslUnsupported: + preconditionFailure("Message has no id.") + } + } +} + +extension PostgresBackendMessage.Authentication: PSQLMessagePayloadEncodable { + + public func encode(into buffer: inout ByteBuffer) { + switch self { + case .ok: + buffer.writeInteger(Int32(0)) + + case .kerberosV5: + buffer.writeInteger(Int32(2)) + + case .plaintext: + buffer.writeInteger(Int32(3)) + + case .md5(salt: let salt): + buffer.writeMultipleIntegers(Int32(5), salt) + + case .scmCredential: + buffer.writeInteger(Int32(6)) + + case .gss: + buffer.writeInteger(Int32(7)) + + case .gssContinue(var data): + buffer.writeInteger(Int32(8)) + buffer.writeBuffer(&data) + + case .sspi: + buffer.writeInteger(Int32(9)) + + case .sasl(names: let names): + buffer.writeInteger(Int32(10)) + for name in names { + buffer.writeNullTerminatedString(name) + } + + case .saslContinue(data: var data): + buffer.writeInteger(Int32(11)) + buffer.writeBuffer(&data) + + case .saslFinal(data: var data): + buffer.writeInteger(Int32(12)) + buffer.writeBuffer(&data) + } + } + +} + +extension PostgresBackendMessage.BackendKeyData: PSQLMessagePayloadEncodable { + public func encode(into buffer: inout ByteBuffer) { + buffer.writeInteger(self.processID) + buffer.writeInteger(self.secretKey) + } +} + +extension DataRow: PSQLMessagePayloadEncodable { + public func encode(into buffer: inout ByteBuffer) { + buffer.writeInteger(self.columnCount, as: Int16.self) + buffer.writeBytes(self.bytes.readableBytesView) + } +} + +extension PostgresBackendMessage.ErrorResponse: PSQLMessagePayloadEncodable { + public func encode(into buffer: inout ByteBuffer) { + for (key, value) in self.fields { + buffer.writeInteger(key.rawValue, as: UInt8.self) + buffer.writeNullTerminatedString(value) + } + buffer.writeInteger(0, as: UInt8.self) // signal done + } +} + +extension PostgresBackendMessage.NoticeResponse: PSQLMessagePayloadEncodable { + public func encode(into buffer: inout ByteBuffer) { + for (key, value) in self.fields { + buffer.writeInteger(key.rawValue, as: UInt8.self) + buffer.writeNullTerminatedString(value) + } + buffer.writeInteger(0, as: UInt8.self) // signal done + } +} + +extension PostgresBackendMessage.NotificationResponse: PSQLMessagePayloadEncodable { + public func encode(into buffer: inout ByteBuffer) { + buffer.writeInteger(self.backendPID) + buffer.writeNullTerminatedString(self.channel) + buffer.writeNullTerminatedString(self.payload) + } +} + +extension PostgresBackendMessage.ParameterDescription: PSQLMessagePayloadEncodable { + public func encode(into buffer: inout ByteBuffer) { + buffer.writeInteger(Int16(self.dataTypes.count)) + + for dataType in self.dataTypes { + buffer.writeInteger(dataType.rawValue) + } + } +} + +extension PostgresBackendMessage.ParameterStatus: PSQLMessagePayloadEncodable { + public func encode(into buffer: inout ByteBuffer) { + buffer.writeNullTerminatedString(self.parameter) + buffer.writeNullTerminatedString(self.value) + } +} + +extension PostgresBackendMessage.TransactionState: PSQLMessagePayloadEncodable { + public func encode(into buffer: inout ByteBuffer) { + buffer.writeInteger(self.rawValue) + } +} + +extension RowDescription: PSQLMessagePayloadEncodable { + public func encode(into buffer: inout ByteBuffer) { + buffer.writeInteger(Int16(self.columns.count)) + + for column in self.columns { + buffer.writeNullTerminatedString(column.name) + buffer.writeInteger(column.tableOID) + buffer.writeInteger(column.columnAttributeNumber) + buffer.writeInteger(column.dataType.rawValue) + buffer.writeInteger(column.dataTypeSize) + buffer.writeInteger(column.dataTypeModifier) + buffer.writeInteger(column.format.rawValue) + } + } +} + +protocol PSQLMessagePayloadEncodable { + func encode(into buffer: inout ByteBuffer) +} diff --git a/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift b/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift new file mode 100644 index 00000000..55ccd0a9 --- /dev/null +++ b/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift @@ -0,0 +1,245 @@ +@testable import PostgresNIO +import NIOCore + +struct PSQLFrontendMessageDecoder: NIOSingleStepByteToMessageDecoder { + typealias InboundOut = PostgresFrontendMessage + + private(set) var isInStartup: Bool + + init() { + self.isInStartup = true + } + + mutating func decode(buffer: inout ByteBuffer) throws -> PostgresFrontendMessage? { + // make sure we have at least one byte to read + guard buffer.readableBytes > 0 else { + return nil + } + + if self.isInStartup { + guard let length = buffer.getInteger(at: buffer.readerIndex, as: UInt32.self) else { + return nil + } + + guard var messageSlice = buffer.getSlice(at: buffer.readerIndex + 4, length: Int(length) - 4) else { + return nil + } + buffer.moveReaderIndex(to: Int(length)) + let finalIndex = buffer.readerIndex + + guard let code = messageSlice.readInteger(as: UInt32.self) else { + throw PSQLPartialDecodingError.fieldNotDecodable(type: UInt32.self) + } + + switch code { + case 80877103: + self.isInStartup = true + return .sslRequest + + case 196608: + var user: String? + var database: String? + var options = [(String, String)]() + + while let name = messageSlice.readNullTerminatedString(), messageSlice.readerIndex < finalIndex { + let value = messageSlice.readNullTerminatedString() + + switch name { + case "user": + user = value + + case "database": + database = value + + default: + if let value = value { + options.append((name, value)) + } + } + } + + let parameters = PostgresFrontendMessage.Startup.Parameters( + user: user!, + database: database, + options: options, + replication: .false + ) + + let startup = PostgresFrontendMessage.Startup( + protocolVersion: 0x00_03_00_00, + parameters: parameters + ) + + precondition(buffer.readerIndex == finalIndex) + self.isInStartup = false + + return .startup(startup) + + default: + throw PostgresMessageDecodingError.unknownStartupCodeReceived(code: code, messageBytes: messageSlice) + } + } + + // all other packages have an Int32 after the identifier that determines their length. + // do we have enough bytes for that? + guard let idByte = buffer.getInteger(at: buffer.readerIndex, as: UInt8.self), + let length = buffer.getInteger(at: buffer.readerIndex + 1, as: Int32.self) else { + return nil + } + + // At this point we are sure, that we have enough bytes to decode the next message. + // 1. Create a byteBuffer that represents exactly the next message. This can be force + // unwrapped, since it was verified that enough bytes are available. + guard let completeMessageBuffer = buffer.readSlice(length: 1 + Int(length)) else { + return nil + } + + // 2. make sure we have a known message identifier + guard let messageID = PostgresFrontendMessage.ID(rawValue: idByte) else { + throw PostgresMessageDecodingError.unknownMessageIDReceived(messageID: idByte, messageBytes: completeMessageBuffer) + } + + // 3. decode the message + do { + // get a mutable byteBuffer copy + var slice = completeMessageBuffer + // move reader index forward by five bytes + slice.moveReaderIndex(forwardBy: 5) + + return try PostgresFrontendMessage.decode(from: &slice, for: messageID) + } catch let error as PSQLPartialDecodingError { + throw PostgresMessageDecodingError.withPartialError(error, messageID: messageID.rawValue, messageBytes: completeMessageBuffer) + } catch { + preconditionFailure("Expected to only see `PartialDecodingError`s here.") + } + } + + mutating func decodeLast(buffer: inout ByteBuffer, seenEOF: Bool) throws -> PostgresFrontendMessage? { + try self.decode(buffer: &buffer) + } +} + +extension PostgresFrontendMessage { + + static func decode(from buffer: inout ByteBuffer, for messageID: ID) throws -> PostgresFrontendMessage { + switch messageID { + case .bind: + guard let portalName = buffer.readNullTerminatedString() else { + throw PSQLPartialDecodingError.fieldNotDecodable(type: String.self) + } + guard let preparedStatementName = buffer.readNullTerminatedString() else { + throw PSQLPartialDecodingError.fieldNotDecodable(type: String.self) + } + guard let parameterFormatCount = buffer.readInteger(as: UInt16.self) else { + preconditionFailure("TODO: Unimplemented") + } + + let parameterFormats = (0.. ByteBuffer? in + let length = buffer.readInteger(as: UInt32.self) + switch length { + case .some(..<0): + return nil + case .some(0...): + return buffer.readSlice(length: Int(length!)) + default: + preconditionFailure("TODO: Unimplemented") + } + } + + guard let resultColumnFormatCount = buffer.readInteger(as: UInt16.self) else { + preconditionFailure("TODO: Unimplemented") + } + + let resultColumnFormats = (0.. Self + { + var byteBuffer = messageBytes + let data = byteBuffer.readData(length: byteBuffer.readableBytes)! + + return PostgresMessageDecodingError( + messageID: 0, + payload: data.base64EncodedString(), + description: "Received a startup code '\(code)'. There is no message associated with this code.", + file: file, + line: line) + } +} diff --git a/Tests/PostgresNIOTests/New/Extensions/PostgresFrontendMessage.swift b/Tests/PostgresNIOTests/New/Extensions/PostgresFrontendMessage.swift new file mode 100644 index 00000000..2532959a --- /dev/null +++ b/Tests/PostgresNIOTests/New/Extensions/PostgresFrontendMessage.swift @@ -0,0 +1,293 @@ +import NIOCore +import PostgresNIO + +/// A wire message that is created by a Postgres client to be consumed by Postgres server. +/// +/// All messages are defined in the official Postgres Documentation in the section +/// [Frontend/Backend Protocol – Message Formats](https://www.postgresql.org/docs/13/protocol-message-formats.html) +enum PostgresFrontendMessage: Equatable { + + struct Bind: Hashable { + /// The name of the destination portal (an empty string selects the unnamed portal). + var portalName: String + + /// The name of the source prepared statement (an empty string selects the unnamed prepared statement). + var preparedStatementName: String + + /// The number of parameter values that follow (possibly zero). This must match the number of parameters needed by the query. + var parameterFormats: [PostgresFormat] + + /// The number of parameter values that follow (possibly zero). This must match the number of parameters needed by the query. + var parameters: [ByteBuffer?] + + var resultColumnFormats: [PostgresFormat] + } + + struct Cancel: Equatable { + /// The cancel request code. The value is chosen to contain 1234 in the most significant 16 bits, + /// and 5678 in the least significant 16 bits. (To avoid confusion, this code must not be the same + /// as any protocol version number.) + static let requestCode: Int32 = 80877102 + + /// The process ID of the target backend. + let processID: Int32 + + /// The secret key for the target backend. + let secretKey: Int32 + } + + enum Close: Hashable { + case preparedStatement(String) + case portal(String) + } + + enum Describe: Hashable { + case preparedStatement(String) + case portal(String) + } + + struct Execute: Hashable { + /// The name of the portal to execute (an empty string selects the unnamed portal). + let portalName: String + + /// Maximum number of rows to return, if portal contains a query that returns rows (ignored otherwise). Zero denotes “no limit”. + let maxNumberOfRows: Int32 + + init(portalName: String, maxNumberOfRows: Int32 = 0) { + self.portalName = portalName + self.maxNumberOfRows = maxNumberOfRows + } + } + + struct Parse: Hashable { + /// The name of the destination prepared statement (an empty string selects the unnamed prepared statement). + let preparedStatementName: String + + /// The query string to be parsed. + let query: String + + /// The number of parameter data types specified (can be zero). Note that this is not an indication of the number of parameters that might appear in the query string, only the number that the frontend wants to prespecify types for. + let parameters: [PostgresDataType] + } + + struct Password: Hashable { + let value: String + } + + struct SASLInitialResponse: Hashable { + + let saslMechanism: String + let initialData: [UInt8] + + /// Creates a new `SSLRequest`. + init(saslMechanism: String, initialData: [UInt8]) { + self.saslMechanism = saslMechanism + self.initialData = initialData + } + } + + struct SASLResponse: Hashable { + var data: [UInt8] + + /// Creates a new `SSLRequest`. + init(data: [UInt8]) { + self.data = data + } + } + + /// A message asking the PostgreSQL server if TLS is supported + /// For more info, see https://www.postgresql.org/docs/10/static/protocol-flow.html#id-1.10.5.7.11 + struct SSLRequest: Hashable { + /// The SSL request code. The value is chosen to contain 1234 in the most significant 16 bits, + /// and 5679 in the least significant 16 bits. + static let requestCode: Int32 = 80877103 + } + + struct Startup: Equatable { + static let versionThree: Int32 = 0x00_03_00_00 + + /// Creates a `Startup` with "3.0" as the protocol version. + static func versionThree(parameters: Parameters) -> Startup { + return .init(protocolVersion: Self.versionThree, parameters: parameters) + } + + /// The protocol version number. The most significant 16 bits are the major + /// version number (3 for the protocol described here). The least significant + /// 16 bits are the minor version number (0 for the protocol described here). + var protocolVersion: Int32 + + /// The protocol version number is followed by one or more pairs of parameter + /// name and value strings. A zero byte is required as a terminator after + /// the last name/value pair. `user` is required, others are optional. + struct Parameters: Equatable { + enum Replication { + case `true` + case `false` + case database + } + + /// The database user name to connect as. Required; there is no default. + var user: String + + /// The database to connect to. Defaults to the user name. + var database: String? + + /// Command-line arguments for the backend. (This is deprecated in favor + /// of setting individual run-time parameters.) Spaces within this string are + /// considered to separate arguments, unless escaped with a + /// backslash (\); write \\ to represent a literal backslash. + var options: [(String, String)] + + /// Used to connect in streaming replication mode, where a small set of + /// replication commands can be issued instead of SQL statements. Value + /// can be true, false, or database, and the default is false. + var replication: Replication + + static func ==(lhs: Self, rhs: Self) -> Bool { + guard lhs.user == rhs.user + && lhs.database == rhs.database + && lhs.replication == rhs.replication + && lhs.options.count == rhs.options.count + else { + return false + } + + var lhsIterator = lhs.options.makeIterator() + var rhsIterator = rhs.options.makeIterator() + + while let lhsNext = lhsIterator.next(), let rhsNext = rhsIterator.next() { + guard lhsNext.0 == rhsNext.0 && lhsNext.1 == rhsNext.1 else { + return false + } + } + return true + } + + } + + var parameters: Parameters + } + + case bind(Bind) + case cancel(Cancel) + case close(Close) + case describe(Describe) + case execute(Execute) + case flush + case parse(Parse) + case password(Password) + case saslInitialResponse(SASLInitialResponse) + case saslResponse(SASLResponse) + case sslRequest + case sync + case startup(Startup) + case terminate + + enum ID: UInt8, Equatable { + + case bind + case close + case describe + case execute + case flush + case parse + case password + case saslInitialResponse + case saslResponse + case sync + case terminate + + init?(rawValue: UInt8) { + switch rawValue { + case UInt8(ascii: "B"): + self = .bind + case UInt8(ascii: "C"): + self = .close + case UInt8(ascii: "D"): + self = .describe + case UInt8(ascii: "E"): + self = .execute + case UInt8(ascii: "H"): + self = .flush + case UInt8(ascii: "P"): + self = .parse + case UInt8(ascii: "p"): + self = .password + case UInt8(ascii: "p"): + self = .saslInitialResponse + case UInt8(ascii: "p"): + self = .saslResponse + case UInt8(ascii: "S"): + self = .sync + case UInt8(ascii: "X"): + self = .terminate + default: + return nil + } + } + + var rawValue: UInt8 { + switch self { + case .bind: + return UInt8(ascii: "B") + case .close: + return UInt8(ascii: "C") + case .describe: + return UInt8(ascii: "D") + case .execute: + return UInt8(ascii: "E") + case .flush: + return UInt8(ascii: "H") + case .parse: + return UInt8(ascii: "P") + case .password: + return UInt8(ascii: "p") + case .saslInitialResponse: + return UInt8(ascii: "p") + case .saslResponse: + return UInt8(ascii: "p") + case .sync: + return UInt8(ascii: "S") + case .terminate: + return UInt8(ascii: "X") + } + } + } +} + +extension PostgresFrontendMessage { + + var id: ID { + switch self { + case .bind: + return .bind + case .cancel: + preconditionFailure("Cancel messages don't have an identifier") + case .close: + return .close + case .describe: + return .describe + case .execute: + return .execute + case .flush: + return .flush + case .parse: + return .parse + case .password: + return .password + case .saslInitialResponse: + return .saslInitialResponse + case .saslResponse: + return .saslResponse + case .sslRequest: + preconditionFailure("SSL requests don't have an identifier") + case .startup: + preconditionFailure("Startup messages don't have an identifier") + case .sync: + return .sync + case .terminate: + return .terminate + + } + } +} diff --git a/Tests/PostgresNIOTests/New/Extensions/ReverseByteToMessageHandler.swift b/Tests/PostgresNIOTests/New/Extensions/ReverseByteToMessageHandler.swift new file mode 100644 index 00000000..654a2546 --- /dev/null +++ b/Tests/PostgresNIOTests/New/Extensions/ReverseByteToMessageHandler.swift @@ -0,0 +1,36 @@ +import NIOCore + +/// This is a reverse ``NIOCore/ByteToMessageHandler``. Instead of creating messages from incoming bytes +/// as the normal `ByteToMessageHandler` does, this `ReverseByteToMessageHandler` creates messages +/// from outgoing bytes. This is only important for testing in `EmbeddedChannel`s. +class ReverseByteToMessageHandler: ChannelOutboundHandler { + typealias OutboundIn = ByteBuffer + typealias OutboundOut = Decoder.InboundOut + + let processor: NIOSingleStepByteToMessageProcessor + + init(_ decoder: Decoder) { + self.processor = .init(decoder, maximumBufferSize: nil) + } + + func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise?) { + let buffer = self.unwrapOutboundIn(data) + + do { + var messages = [Decoder.InboundOut]() + try self.processor.process(buffer: buffer) { message in + messages.append(message) + } + + for (index, message) in messages.enumerated() { + if index == messages.index(before: messages.endIndex) { + context.write(self.wrapOutboundOut(message), promise: promise) + } else { + context.write(self.wrapOutboundOut(message), promise: nil) + } + } + } catch { + context.fireErrorCaught(error) + } + } +} diff --git a/Tests/PostgresNIOTests/New/Extensions/ReverseMessageToByteHandler.swift b/Tests/PostgresNIOTests/New/Extensions/ReverseMessageToByteHandler.swift new file mode 100644 index 00000000..135c881d --- /dev/null +++ b/Tests/PostgresNIOTests/New/Extensions/ReverseMessageToByteHandler.swift @@ -0,0 +1,32 @@ +import NIOCore + +/// This is a reverse ``NIOCore/ByteToMessageHandler``. Instead of creating messages from incoming bytes +/// as the normal `ByteToMessageHandler` does, this `ReverseByteToMessageHandler` creates messages +/// from outgoing bytes. This is only important for testing in `EmbeddedChannel`s. +class ReverseMessageToByteHandler: ChannelInboundHandler { + typealias InboundIn = Encoder.OutboundIn + typealias InboundOut = ByteBuffer + + var byteBuffer: ByteBuffer! + let encoder: Encoder + + init(_ encoder: Encoder) { + self.encoder = encoder + } + + func handlerAdded(context: ChannelHandlerContext) { + self.byteBuffer = context.channel.allocator.buffer(capacity: 128) + } + + func channelRead(context: ChannelHandlerContext, data: NIOAny) { + let message = self.unwrapInboundIn(data) + + do { + self.byteBuffer.clear() + try self.encoder.encode(data: message, out: &self.byteBuffer) + context.fireChannelRead(self.wrapInboundOut(self.byteBuffer)) + } catch { + context.fireErrorCaught(error) + } + } +} diff --git a/Tests/PostgresNIOTests/New/Messages/AuthenticationTests.swift b/Tests/PostgresNIOTests/New/Messages/AuthenticationTests.swift new file mode 100644 index 00000000..06e39aae --- /dev/null +++ b/Tests/PostgresNIOTests/New/Messages/AuthenticationTests.swift @@ -0,0 +1,47 @@ +import XCTest +import NIOCore +import NIOTestUtils +@testable import PostgresNIO + +class AuthenticationTests: XCTestCase { + + func testDecodeAuthentication() { + var expected = [PostgresBackendMessage]() + var buffer = ByteBuffer() + let encoder = PSQLBackendMessageEncoder() + + // add ok + encoder.encode(data: .authentication(.ok), out: &buffer) + expected.append(.authentication(.ok)) + + // add kerberos + encoder.encode(data: .authentication(.kerberosV5), out: &buffer) + expected.append(.authentication(.kerberosV5)) + + // add plaintext + encoder.encode(data: .authentication(.plaintext), out: &buffer) + expected.append(.authentication(.plaintext)) + + // add md5 + let salt: UInt32 = 0x01_02_03_04 + encoder.encode(data: .authentication(.md5(salt: salt)), out: &buffer) + expected.append(.authentication(.md5(salt: salt))) + + // add scm credential + encoder.encode(data: .authentication(.scmCredential), out: &buffer) + expected.append(.authentication(.scmCredential)) + + // add gss + encoder.encode(data: .authentication(.gss), out: &buffer) + expected.append(.authentication(.gss)) + + // add sspi + encoder.encode(data: .authentication(.sspi), out: &buffer) + expected.append(.authentication(.sspi)) + + XCTAssertNoThrow(try ByteToMessageDecoderVerifier.verifyDecoder( + inputOutputPairs: [(buffer, expected)], + decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: false) } + )) + } +} diff --git a/Tests/PostgresNIOTests/New/Messages/BackendKeyDataTests.swift b/Tests/PostgresNIOTests/New/Messages/BackendKeyDataTests.swift new file mode 100644 index 00000000..d41607e3 --- /dev/null +++ b/Tests/PostgresNIOTests/New/Messages/BackendKeyDataTests.swift @@ -0,0 +1,39 @@ +import XCTest +import NIOCore +import NIOTestUtils +@testable import PostgresNIO + +class BackendKeyDataTests: XCTestCase { + func testDecode() { + let buffer = ByteBuffer.backendMessage(id: .backendKeyData) { buffer in + buffer.writeInteger(Int32(1234)) + buffer.writeInteger(Int32(4567)) + } + + let expectedInOuts = [ + (buffer, [PostgresBackendMessage.backendKeyData(.init(processID: 1234, secretKey: 4567))]), + ] + + XCTAssertNoThrow(try ByteToMessageDecoderVerifier.verifyDecoder( + inputOutputPairs: expectedInOuts, + decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: false) })) + } + + func testDecodeInvalidLength() { + var buffer = ByteBuffer() + buffer.psqlWriteBackendMessageID(.backendKeyData) + buffer.writeInteger(Int32(11)) + buffer.writeInteger(Int32(1234)) + buffer.writeInteger(Int32(4567)) + + let expected = [ + (buffer, [PostgresBackendMessage.backendKeyData(.init(processID: 1234, secretKey: 4567))]), + ] + + XCTAssertThrowsError(try ByteToMessageDecoderVerifier.verifyDecoder( + inputOutputPairs: expected, + decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: false) })) { + XCTAssert($0 is PostgresMessageDecodingError) + } + } +} diff --git a/Tests/PostgresNIOTests/New/Messages/BindTests.swift b/Tests/PostgresNIOTests/New/Messages/BindTests.swift new file mode 100644 index 00000000..d5ec5b30 --- /dev/null +++ b/Tests/PostgresNIOTests/New/Messages/BindTests.swift @@ -0,0 +1,47 @@ +import XCTest +import NIOCore +@testable import PostgresNIO + +class BindTests: XCTestCase { + + func testEncodeBind() { + var bindings = PostgresBindings() + bindings.append("Hello", context: .default) + bindings.append("World", context: .default) + + var encoder = PostgresFrontendMessageEncoder(buffer: .init()) + + encoder.bind(portalName: "", preparedStatementName: "", bind: bindings) + var byteBuffer = encoder.flushBuffer() + + XCTAssertEqual(byteBuffer.readableBytes, 37) + XCTAssertEqual(PostgresFrontendMessage.ID.bind.rawValue, byteBuffer.readInteger(as: UInt8.self)) + XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), 36) + XCTAssertEqual("", byteBuffer.readNullTerminatedString()) + XCTAssertEqual("", byteBuffer.readNullTerminatedString()) + // the number of parameters + XCTAssertEqual(2, byteBuffer.readInteger(as: Int16.self)) + // all (two) parameters have the same format (binary) + XCTAssertEqual(1, byteBuffer.readInteger(as: Int16.self)) + XCTAssertEqual(1, byteBuffer.readInteger(as: Int16.self)) + + // read number of parameters + XCTAssertEqual(2, byteBuffer.readInteger(as: Int16.self)) + + // hello length + XCTAssertEqual(5, byteBuffer.readInteger(as: Int32.self)) + XCTAssertEqual("Hello", byteBuffer.readString(length: 5)) + + // world length + XCTAssertEqual(5, byteBuffer.readInteger(as: Int32.self)) + XCTAssertEqual("World", byteBuffer.readString(length: 5)) + + // all response values have the same format: therefore one format byte is next + XCTAssertEqual(1, byteBuffer.readInteger(as: Int16.self)) + // all response values have the same format (binary) + XCTAssertEqual(1, byteBuffer.readInteger(as: Int16.self)) + + // nothing left to read + XCTAssertEqual(byteBuffer.readableBytes, 0) + } +} diff --git a/Tests/PostgresNIOTests/New/Messages/CancelTests.swift b/Tests/PostgresNIOTests/New/Messages/CancelTests.swift new file mode 100644 index 00000000..5548aae3 --- /dev/null +++ b/Tests/PostgresNIOTests/New/Messages/CancelTests.swift @@ -0,0 +1,21 @@ +import XCTest +import NIOCore +@testable import PostgresNIO + +class CancelTests: XCTestCase { + + func testEncodeCancel() { + let processID: Int32 = 1234 + let secretKey: Int32 = 4567 + var encoder = PostgresFrontendMessageEncoder(buffer: .init()) + encoder.cancel(processID: processID, secretKey: secretKey) + var byteBuffer = encoder.flushBuffer() + + XCTAssertEqual(byteBuffer.readableBytes, 16) + XCTAssertEqual(16, byteBuffer.readInteger(as: Int32.self)) // payload length + XCTAssertEqual(80877102, byteBuffer.readInteger(as: Int32.self)) // cancel request code + XCTAssertEqual(processID, byteBuffer.readInteger(as: Int32.self)) + XCTAssertEqual(secretKey, byteBuffer.readInteger(as: Int32.self)) + XCTAssertEqual(byteBuffer.readableBytes, 0) + } +} diff --git a/Tests/PostgresNIOTests/New/Messages/CloseTests.swift b/Tests/PostgresNIOTests/New/Messages/CloseTests.swift new file mode 100644 index 00000000..a8e1cfeb --- /dev/null +++ b/Tests/PostgresNIOTests/New/Messages/CloseTests.swift @@ -0,0 +1,31 @@ +import XCTest +import NIOCore +@testable import PostgresNIO + +class CloseTests: XCTestCase { + func testEncodeClosePortal() { + var encoder = PostgresFrontendMessageEncoder(buffer: .init()) + encoder.closePortal("Hello") + var byteBuffer = encoder.flushBuffer() + + XCTAssertEqual(byteBuffer.readableBytes, 12) + XCTAssertEqual(PostgresFrontendMessage.ID.close.rawValue, byteBuffer.readInteger(as: UInt8.self)) + XCTAssertEqual(11, byteBuffer.readInteger(as: Int32.self)) + XCTAssertEqual(UInt8(ascii: "P"), byteBuffer.readInteger(as: UInt8.self)) + XCTAssertEqual("Hello", byteBuffer.readNullTerminatedString()) + XCTAssertEqual(byteBuffer.readableBytes, 0) + } + + func testEncodeCloseUnnamedStatement() { + var encoder = PostgresFrontendMessageEncoder(buffer: .init()) + encoder.closePreparedStatement("") + var byteBuffer = encoder.flushBuffer() + + XCTAssertEqual(byteBuffer.readableBytes, 7) + XCTAssertEqual(PostgresFrontendMessage.ID.close.rawValue, byteBuffer.readInteger(as: UInt8.self)) + XCTAssertEqual(6, byteBuffer.readInteger(as: Int32.self)) + XCTAssertEqual(UInt8(ascii: "S"), byteBuffer.readInteger(as: UInt8.self)) + XCTAssertEqual("", byteBuffer.readNullTerminatedString()) + XCTAssertEqual(byteBuffer.readableBytes, 0) + } +} diff --git a/Tests/PostgresNIOTests/New/Messages/DataRowTests.swift b/Tests/PostgresNIOTests/New/Messages/DataRowTests.swift new file mode 100644 index 00000000..a90d1e93 --- /dev/null +++ b/Tests/PostgresNIOTests/New/Messages/DataRowTests.swift @@ -0,0 +1,145 @@ +import XCTest +import NIOCore +import NIOTestUtils +@testable import PostgresNIO + +class DataRowTests: XCTestCase { + func testDecode() { + let buffer = ByteBuffer.backendMessage(id: .dataRow) { buffer in + // the data row has 3 columns + buffer.writeInteger(3, as: Int16.self) + + // this is a null value + buffer.writeInteger(-1, as: Int32.self) + + // this is an empty value. for example a empty string + buffer.writeInteger(0, as: Int32.self) + + // this is a column with ten bytes + buffer.writeInteger(10, as: Int32.self) + buffer.writeBytes([UInt8](repeating: 5, count: 10)) + } + + let rowSlice = buffer.getSlice(at: 7, length: buffer.readableBytes - 7)! + + let expectedInOuts = [ + (buffer, [PostgresBackendMessage.dataRow(.init(columnCount: 3, bytes: rowSlice))]), + ] + + XCTAssertNoThrow(try ByteToMessageDecoderVerifier.verifyDecoder( + inputOutputPairs: expectedInOuts, + decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: false) })) + } + + func testIteratingElements() { + let dataRow = DataRow.makeTestDataRow(nil, ByteBuffer(), ByteBuffer(repeating: 5, count: 10)) + var iterator = dataRow.makeIterator() + + XCTAssertEqual(dataRow.count, 3) + XCTAssertEqual(iterator.next(), .some(.none)) + XCTAssertEqual(iterator.next(), ByteBuffer()) + XCTAssertEqual(iterator.next(), ByteBuffer(repeating: 5, count: 10)) + XCTAssertEqual(iterator.next(), .none) + } + + func testIndexAfterAndSubscript() { + let dataRow = DataRow.makeTestDataRow( + nil, + ByteBuffer(), + ByteBuffer(repeating: 5, count: 10), + nil + ) + + var index = dataRow.startIndex + XCTAssertEqual(dataRow[index], .none) + index = dataRow.index(after: index) + XCTAssertEqual(dataRow[index], ByteBuffer()) + index = dataRow.index(after: index) + XCTAssertEqual(dataRow[index], ByteBuffer(repeating: 5, count: 10)) + index = dataRow.index(after: index) + XCTAssertEqual(dataRow[index], .none) + index = dataRow.index(after: index) + XCTAssertEqual(index, dataRow.endIndex) + } + + func testIndexComparison() { + let dataRow = DataRow.makeTestDataRow( + nil, + ByteBuffer(), + ByteBuffer(repeating: 5, count: 10), + nil + ) + + let startIndex = dataRow.startIndex + let secondIndex = dataRow.index(after: startIndex) + + XCTAssertLessThanOrEqual(startIndex, secondIndex) + XCTAssertLessThan(startIndex, secondIndex) + + XCTAssertGreaterThanOrEqual(secondIndex, startIndex) + XCTAssertGreaterThan(secondIndex, startIndex) + + XCTAssertFalse(secondIndex == startIndex) + XCTAssertEqual(secondIndex, secondIndex) + XCTAssertEqual(startIndex, startIndex) + } + + func testColumnSubscript() { + let dataRow = DataRow.makeTestDataRow( + nil, + ByteBuffer(), + ByteBuffer(repeating: 5, count: 10), + nil + ) + + XCTAssertEqual(dataRow.count, 4) + XCTAssertEqual(dataRow[column: 0], .none) + XCTAssertEqual(dataRow[column: 1], ByteBuffer()) + XCTAssertEqual(dataRow[column: 2], ByteBuffer(repeating: 5, count: 10)) + XCTAssertEqual(dataRow[column: 3], .none) + } + + func testWithContiguousStorageIfAvailable() { + let dataRow = DataRow.makeTestDataRow( + nil, + ByteBuffer(), + ByteBuffer(repeating: 5, count: 10), + nil + ) + + XCTAssertNil(dataRow.withContiguousStorageIfAvailable { _ in + return XCTFail("DataRow does not have a contiguous storage") + }) + } +} + +extension PostgresNIO.DataRow: Swift.ExpressibleByArrayLiteral { + public typealias ArrayLiteralElement = PostgresEncodable + + public init(arrayLiteral elements: PostgresEncodable...) { + + var buffer = ByteBuffer() + let encodingContext = PostgresEncodingContext(jsonEncoder: JSONEncoder()) + elements.forEach { element in + try! element.encodeRaw(into: &buffer, context: encodingContext) + } + + self.init(columnCount: Int16(elements.count), bytes: buffer) + } + + static func makeTestDataRow(_ buffers: ByteBuffer?...) -> DataRow { + var bytes = ByteBuffer() + buffers.forEach { column in + switch column { + case .none: + bytes.writeInteger(Int32(-1)) + case .some(var input): + bytes.writeInteger(Int32(input.readableBytes)) + bytes.writeBuffer(&input) + } + } + + return DataRow(columnCount: Int16(buffers.count), bytes: bytes) + } +} + diff --git a/Tests/PostgresNIOTests/New/Messages/DescribeTests.swift b/Tests/PostgresNIOTests/New/Messages/DescribeTests.swift new file mode 100644 index 00000000..cb3c745b --- /dev/null +++ b/Tests/PostgresNIOTests/New/Messages/DescribeTests.swift @@ -0,0 +1,33 @@ +import XCTest +import NIOCore +@testable import PostgresNIO + +class DescribeTests: XCTestCase { + + func testEncodeDescribePortal() { + var encoder = PostgresFrontendMessageEncoder(buffer: .init()) + encoder.describePortal("Hello") + var byteBuffer = encoder.flushBuffer() + + XCTAssertEqual(byteBuffer.readableBytes, 12) + XCTAssertEqual(PostgresFrontendMessage.ID.describe.rawValue, byteBuffer.readInteger(as: UInt8.self)) + XCTAssertEqual(11, byteBuffer.readInteger(as: Int32.self)) + XCTAssertEqual(UInt8(ascii: "P"), byteBuffer.readInteger(as: UInt8.self)) + XCTAssertEqual("Hello", byteBuffer.readNullTerminatedString()) + XCTAssertEqual(byteBuffer.readableBytes, 0) + } + + func testEncodeDescribeUnnamedStatement() { + var encoder = PostgresFrontendMessageEncoder(buffer: .init()) + encoder.describePreparedStatement("") + var byteBuffer = encoder.flushBuffer() + + XCTAssertEqual(byteBuffer.readableBytes, 7) + XCTAssertEqual(PostgresFrontendMessage.ID.describe.rawValue, byteBuffer.readInteger(as: UInt8.self)) + XCTAssertEqual(6, byteBuffer.readInteger(as: Int32.self)) + XCTAssertEqual(UInt8(ascii: "S"), byteBuffer.readInteger(as: UInt8.self)) + XCTAssertEqual("", byteBuffer.readNullTerminatedString()) + XCTAssertEqual(byteBuffer.readableBytes, 0) + } + +} diff --git a/Tests/PostgresNIOTests/New/Messages/ErrorResponseTests.swift b/Tests/PostgresNIOTests/New/Messages/ErrorResponseTests.swift new file mode 100644 index 00000000..80015ea0 --- /dev/null +++ b/Tests/PostgresNIOTests/New/Messages/ErrorResponseTests.swift @@ -0,0 +1,35 @@ +import XCTest +import NIOCore +import NIOTestUtils +@testable import PostgresNIO + +class ErrorResponseTests: XCTestCase { + + func testDecode() { + let fields: [PostgresBackendMessage.Field : String] = [ + .file: "auth.c", + .routine: "auth_failed", + .line: "334", + .localizedSeverity: "FATAL", + .sqlState: "28P01", + .severity: "FATAL", + .message: "password authentication failed for user \"postgre3\"", + ] + + let buffer = ByteBuffer.backendMessage(id: .error) { buffer in + fields.forEach { (key, value) in + buffer.writeInteger(key.rawValue, as: UInt8.self) + buffer.writeNullTerminatedString(value) + } + buffer.writeInteger(0, as: UInt8.self) // signal done + } + + let expectedInOuts = [ + (buffer, [PostgresBackendMessage.error(.init(fields: fields))]), + ] + + XCTAssertNoThrow(try ByteToMessageDecoderVerifier.verifyDecoder( + inputOutputPairs: expectedInOuts, + decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: false) })) + } +} diff --git a/Tests/PostgresNIOTests/New/Messages/ExecuteTests.swift b/Tests/PostgresNIOTests/New/Messages/ExecuteTests.swift new file mode 100644 index 00000000..834ad0dd --- /dev/null +++ b/Tests/PostgresNIOTests/New/Messages/ExecuteTests.swift @@ -0,0 +1,18 @@ +import XCTest +import NIOCore +@testable import PostgresNIO + +class ExecuteTests: XCTestCase { + + func testEncodeExecute() { + var encoder = PostgresFrontendMessageEncoder(buffer: .init()) + encoder.execute(portalName: "", maxNumberOfRows: 0) + var byteBuffer = encoder.flushBuffer() + + XCTAssertEqual(byteBuffer.readableBytes, 10) // 1 (id) + 4 (length) + 1 (empty null terminated string) + 4 (count) + XCTAssertEqual(PostgresFrontendMessage.ID.execute.rawValue, byteBuffer.readInteger(as: UInt8.self)) + XCTAssertEqual(9, byteBuffer.readInteger(as: Int32.self)) // length + XCTAssertEqual("", byteBuffer.readNullTerminatedString()) + XCTAssertEqual(0, byteBuffer.readInteger(as: Int32.self)) + } +} diff --git a/Tests/PostgresNIOTests/New/Messages/NotificationResponseTests.swift b/Tests/PostgresNIOTests/New/Messages/NotificationResponseTests.swift new file mode 100644 index 00000000..9a8a1529 --- /dev/null +++ b/Tests/PostgresNIOTests/New/Messages/NotificationResponseTests.swift @@ -0,0 +1,62 @@ +import XCTest +import NIOCore +import NIOTestUtils +@testable import PostgresNIO + +class NotificationResponseTests: XCTestCase { + + func testDecode() { + let expected: [PostgresBackendMessage] = [ + .notification(.init(backendPID: 123, channel: "test", payload: "hello")), + .notification(.init(backendPID: 123, channel: "test", payload: "world")), + .notification(.init(backendPID: 123, channel: "foo", payload: "bar")) + ] + + var buffer = ByteBuffer() + expected.forEach { message in + guard case .notification(let notification) = message else { + return XCTFail("Expected only to get notifications here!") + } + + buffer.writeBackendMessage(id: .notificationResponse) { buffer in + buffer.writeInteger(notification.backendPID) + buffer.writeNullTerminatedString(notification.channel) + buffer.writeNullTerminatedString(notification.payload) + } + } + + XCTAssertNoThrow(try ByteToMessageDecoderVerifier.verifyDecoder( + inputOutputPairs: [(buffer, expected)], + decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: true) })) + } + + func testDecodeFailureBecauseOfMissingNullTermination() { + var buffer = ByteBuffer() + buffer.writeBackendMessage(id: .notificationResponse) { buffer in + buffer.writeInteger(Int32(123)) + buffer.writeString("test") + buffer.writeString("hello") + } + + XCTAssertThrowsError(try ByteToMessageDecoderVerifier.verifyDecoder( + inputOutputPairs: [(buffer, [])], + decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: true) })) { + XCTAssert($0 is PostgresMessageDecodingError) + } + } + + func testDecodeFailureBecauseOfMissingNullTerminationInValue() { + var buffer = ByteBuffer() + buffer.writeBackendMessage(id: .notificationResponse) { buffer in + buffer.writeInteger(Int32(123)) + buffer.writeNullTerminatedString("hello") + buffer.writeString("world") + } + + XCTAssertThrowsError(try ByteToMessageDecoderVerifier.verifyDecoder( + inputOutputPairs: [(buffer, [])], + decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: true) })) { + XCTAssert($0 is PostgresMessageDecodingError) + } + } +} diff --git a/Tests/PostgresNIOTests/New/Messages/ParameterDescriptionTests.swift b/Tests/PostgresNIOTests/New/Messages/ParameterDescriptionTests.swift new file mode 100644 index 00000000..a6bc32a1 --- /dev/null +++ b/Tests/PostgresNIOTests/New/Messages/ParameterDescriptionTests.swift @@ -0,0 +1,69 @@ +import XCTest +import NIOCore +import NIOTestUtils +@testable import PostgresNIO + +class ParameterDescriptionTests: XCTestCase { + + func testDecode() { + let expected: [PostgresBackendMessage] = [ + .parameterDescription(.init(dataTypes: [.bool, .varchar, .uuid, .json, .jsonbArray])), + ] + + var buffer = ByteBuffer() + expected.forEach { message in + guard case .parameterDescription(let description) = message else { + return XCTFail("Expected only to get parameter descriptions here!") + } + + buffer.writeBackendMessage(id: .parameterDescription) { buffer in + buffer.writeInteger(Int16(description.dataTypes.count)) + + description.dataTypes.forEach { dataType in + buffer.writeInteger(dataType.rawValue) + } + } + } + + XCTAssertNoThrow(try ByteToMessageDecoderVerifier.verifyDecoder( + inputOutputPairs: [(buffer, expected)], + decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: true) })) + } + + func testDecodeWithNegativeCount() { + let dataTypes: [PostgresDataType] = [.bool, .varchar, .uuid, .json, .jsonbArray] + var buffer = ByteBuffer() + buffer.writeBackendMessage(id: .parameterDescription) { buffer in + buffer.writeInteger(Int16(-4)) + + dataTypes.forEach { dataType in + buffer.writeInteger(dataType.rawValue) + } + } + + XCTAssertThrowsError(try ByteToMessageDecoderVerifier.verifyDecoder( + inputOutputPairs: [(buffer, [])], + decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: true) })) { + XCTAssert($0 is PostgresMessageDecodingError) + } + } + + func testDecodeColumnCountDoesntMatchMessageLength() { + let dataTypes: [PostgresDataType] = [.bool, .varchar, .uuid, .json, .jsonbArray] + var buffer = ByteBuffer() + buffer.writeBackendMessage(id: .parameterDescription) { buffer in + // means three columns comming, but 5 are in the buffer actually. + buffer.writeInteger(Int16(3)) + + dataTypes.forEach { dataType in + buffer.writeInteger(dataType.rawValue) + } + } + + XCTAssertThrowsError(try ByteToMessageDecoderVerifier.verifyDecoder( + inputOutputPairs: [(buffer, [])], + decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: true) })) { + XCTAssert($0 is PostgresMessageDecodingError) + } + } +} diff --git a/Tests/PostgresNIOTests/New/Messages/ParameterStatusTests.swift b/Tests/PostgresNIOTests/New/Messages/ParameterStatusTests.swift new file mode 100644 index 00000000..4513bbce --- /dev/null +++ b/Tests/PostgresNIOTests/New/Messages/ParameterStatusTests.swift @@ -0,0 +1,75 @@ +import XCTest +import NIOCore +import NIOTestUtils +@testable import PostgresNIO + +class ParameterStatusTests: XCTestCase { + + func testDecode() { + var buffer = ByteBuffer() + + let expected: [PostgresBackendMessage] = [ + .parameterStatus(.init(parameter: "DateStyle", value: "ISO, MDY")), + .parameterStatus(.init(parameter: "application_name", value: "")), + .parameterStatus(.init(parameter: "server_encoding", value: "UTF8")), + .parameterStatus(.init(parameter: "integer_datetimes", value: "on")), + .parameterStatus(.init(parameter: "client_encoding", value: "UTF8")), + .parameterStatus(.init(parameter: "TimeZone", value: "Etc/UTC")), + .parameterStatus(.init(parameter: "is_superuser", value: "on")), + .parameterStatus(.init(parameter: "server_version", value: "13.1 (Debian 13.1-1.pgdg100+1)")), + .parameterStatus(.init(parameter: "session_authorization", value: "postgres")), + .parameterStatus(.init(parameter: "IntervalStyle", value: "postgres")), + .parameterStatus(.init(parameter: "standard_conforming_strings", value: "on")), + .backendKeyData(.init(processID: 1234, secretKey: 5678)) + ] + + expected.forEach { message in + switch message { + case .parameterStatus(let parameterStatus): + buffer.writeBackendMessage(id: .parameterStatus) { buffer in + buffer.writeNullTerminatedString(parameterStatus.parameter) + buffer.writeNullTerminatedString(parameterStatus.value) + } + case .backendKeyData(let backendKeyData): + buffer.writeBackendMessage(id: .backendKeyData) { buffer in + buffer.writeInteger(backendKeyData.processID) + buffer.writeInteger(backendKeyData.secretKey) + } + default: + XCTFail("Unexpected message type") + } + } + + XCTAssertNoThrow(try ByteToMessageDecoderVerifier.verifyDecoder( + inputOutputPairs: [(buffer, expected)], + decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: true) })) + } + + func testDecodeFailureBecauseOfMissingNullTermination() { + var buffer = ByteBuffer() + buffer.writeBackendMessage(id: .parameterStatus) { buffer in + buffer.writeString("DateStyle") + buffer.writeString("ISO, MDY") + } + + XCTAssertThrowsError(try ByteToMessageDecoderVerifier.verifyDecoder( + inputOutputPairs: [(buffer, [])], + decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: true) })) { + XCTAssert($0 is PostgresMessageDecodingError) + } + } + + func testDecodeFailureBecauseOfMissingNullTerminationInValue() { + var buffer = ByteBuffer() + buffer.writeBackendMessage(id: .parameterStatus) { buffer in + buffer.writeNullTerminatedString("DateStyle") + buffer.writeString("ISO, MDY") + } + + XCTAssertThrowsError(try ByteToMessageDecoderVerifier.verifyDecoder( + inputOutputPairs: [(buffer, [])], + decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: true) })) { + XCTAssert($0 is PostgresMessageDecodingError) + } + } +} diff --git a/Tests/PostgresNIOTests/New/Messages/ParseTests.swift b/Tests/PostgresNIOTests/New/Messages/ParseTests.swift new file mode 100644 index 00000000..9f81e4e4 --- /dev/null +++ b/Tests/PostgresNIOTests/New/Messages/ParseTests.swift @@ -0,0 +1,35 @@ +import XCTest +import NIOCore +@testable import PostgresNIO + +class ParseTests: XCTestCase { + func testEncode() { + let preparedStatementName = "test" + let query = "SELECT version()" + let parameters: [PostgresDataType] = [.bool, .int8, .bytea, .varchar, .text, .uuid, .json, .jsonbArray] + var encoder = PostgresFrontendMessageEncoder(buffer: .init()) + encoder.parse( + preparedStatementName: preparedStatementName, + query: query, + parameters: parameters + ) + var byteBuffer = encoder.flushBuffer() + + let length: Int = 1 + 4 + (preparedStatementName.count + 1) + (query.count + 1) + 2 + parameters.count * 4 + + // 1 id + // + 4 length + // + 4 preparedStatement (3 + 1 null terminator) + // + 1 query () + + XCTAssertEqual(byteBuffer.readableBytes, length) + XCTAssertEqual(byteBuffer.readInteger(as: UInt8.self), PostgresFrontendMessage.ID.parse.rawValue) + XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), Int32(length - 1)) + XCTAssertEqual(byteBuffer.readNullTerminatedString(), preparedStatementName) + XCTAssertEqual(byteBuffer.readNullTerminatedString(), query) + XCTAssertEqual(byteBuffer.readInteger(as: UInt16.self), UInt16(parameters.count)) + for dataType in parameters { + XCTAssertEqual(byteBuffer.readInteger(as: UInt32.self), dataType.rawValue) + } + } +} diff --git a/Tests/PostgresNIOTests/New/Messages/PasswordTests.swift b/Tests/PostgresNIOTests/New/Messages/PasswordTests.swift new file mode 100644 index 00000000..4a4833d2 --- /dev/null +++ b/Tests/PostgresNIOTests/New/Messages/PasswordTests.swift @@ -0,0 +1,21 @@ +import XCTest +import NIOCore +@testable import PostgresNIO + +class PasswordTests: XCTestCase { + + func testEncodePassword() { + var encoder = PostgresFrontendMessageEncoder(buffer: .init()) + // md522d085ed8dc3377968dc1c1a40519a2a = "abc123" with salt 1, 2, 3, 4 + let password = "md522d085ed8dc3377968dc1c1a40519a2a" + encoder.password(password.utf8) + var byteBuffer = encoder.flushBuffer() + + let expectedLength = 41 // 1 (id) + 4 (length) + 35 (string) + 1 (null termination) + + XCTAssertEqual(byteBuffer.readableBytes, expectedLength) + XCTAssertEqual(byteBuffer.readInteger(as: UInt8.self), PostgresFrontendMessage.ID.password.rawValue) + XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), Int32(expectedLength - 1)) // length + XCTAssertEqual(byteBuffer.readNullTerminatedString(), "md522d085ed8dc3377968dc1c1a40519a2a") + } +} diff --git a/Tests/PostgresNIOTests/New/Messages/ReadyForQueryTests.swift b/Tests/PostgresNIOTests/New/Messages/ReadyForQueryTests.swift new file mode 100644 index 00000000..62a8c62f --- /dev/null +++ b/Tests/PostgresNIOTests/New/Messages/ReadyForQueryTests.swift @@ -0,0 +1,74 @@ +import XCTest +import NIOCore +import NIOTestUtils +@testable import PostgresNIO + +class ReadyForQueryTests: XCTestCase { + + func testDecode() { + var buffer = ByteBuffer() + + let states: [PostgresBackendMessage.TransactionState] = [ + .idle, + .inFailedTransaction, + .inTransaction, + ] + + states.forEach { state in + buffer.writeBackendMessage(id: .readyForQuery) { buffer in + switch state { + case .idle: + buffer.writeInteger(UInt8(ascii: "I")) + case .inTransaction: + buffer.writeInteger(UInt8(ascii: "T")) + case .inFailedTransaction: + buffer.writeInteger(UInt8(ascii: "E")) + } + } + } + + let expected = states.map { state -> PostgresBackendMessage in + .readyForQuery(state) + } + + XCTAssertNoThrow(try ByteToMessageDecoderVerifier.verifyDecoder( + inputOutputPairs: [(buffer, expected)], + decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: true) })) + + } + + func testDecodeInvalidLength() { + var buffer = ByteBuffer() + + buffer.writeBackendMessage(id: .readyForQuery) { buffer in + buffer.writeInteger(UInt8(ascii: "I")) + buffer.writeInteger(UInt8(ascii: "I")) + } + + XCTAssertThrowsError(try ByteToMessageDecoderVerifier.verifyDecoder( + inputOutputPairs: [(buffer, [])], + decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: true) })) { + XCTAssert($0 is PostgresMessageDecodingError) + } + } + + func testDecodeUnexpectedAscii() { + var buffer = ByteBuffer() + + buffer.writeBackendMessage(id: .readyForQuery) { buffer in + buffer.writeInteger(UInt8(ascii: "F")) + } + + XCTAssertThrowsError(try ByteToMessageDecoderVerifier.verifyDecoder( + inputOutputPairs: [(buffer, [])], + decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: true) })) { + XCTAssert($0 is PostgresMessageDecodingError) + } + } + + func testDebugDescription() { + XCTAssertEqual(String(reflecting: PostgresBackendMessage.TransactionState.idle), ".idle") + XCTAssertEqual(String(reflecting: PostgresBackendMessage.TransactionState.inTransaction), ".inTransaction") + XCTAssertEqual(String(reflecting: PostgresBackendMessage.TransactionState.inFailedTransaction), ".inFailedTransaction") + } +} diff --git a/Tests/PostgresNIOTests/New/Messages/RowDescriptionTests.swift b/Tests/PostgresNIOTests/New/Messages/RowDescriptionTests.swift new file mode 100644 index 00000000..4eed785a --- /dev/null +++ b/Tests/PostgresNIOTests/New/Messages/RowDescriptionTests.swift @@ -0,0 +1,135 @@ +import XCTest +import NIOCore +import NIOTestUtils +@testable import PostgresNIO + +class RowDescriptionTests: XCTestCase { + + func testDecode() { + let columns: [RowDescription.Column] = [ + .init(name: "First", tableOID: 123, columnAttributeNumber: 123, dataType: .bool, dataTypeSize: 2, dataTypeModifier: 8, format: .binary), + .init(name: "Second", tableOID: 123, columnAttributeNumber: 456, dataType: .uuidArray, dataTypeSize: 567, dataTypeModifier: 123, format: .text), + ] + + let expected: [PostgresBackendMessage] = [ + .rowDescription(.init(columns: columns)) + ] + + var buffer = ByteBuffer() + expected.forEach { message in + guard case .rowDescription(let description) = message else { + return XCTFail("Expected only to get row descriptions here!") + } + + buffer.writeBackendMessage(id: .rowDescription) { buffer in + buffer.writeInteger(Int16(description.columns.count)) + + description.columns.forEach { column in + buffer.writeNullTerminatedString(column.name) + buffer.writeInteger(column.tableOID) + buffer.writeInteger(column.columnAttributeNumber) + buffer.writeInteger(column.dataType.rawValue) + buffer.writeInteger(column.dataTypeSize) + buffer.writeInteger(column.dataTypeModifier) + buffer.writeInteger(column.format.rawValue) + } + } + } + + XCTAssertNoThrow(try ByteToMessageDecoderVerifier.verifyDecoder( + inputOutputPairs: [(buffer, expected)], + decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: true) })) + } + + func testDecodeFailureBecauseOfMissingNullTerminationInColumnName() { + let column = RowDescription.Column( + name: "First", tableOID: 123, columnAttributeNumber: 123, dataType: .bool, dataTypeSize: 2, dataTypeModifier: 8, format: .binary) + + var buffer = ByteBuffer() + buffer.writeBackendMessage(id: .rowDescription) { buffer in + buffer.writeInteger(Int16(1)) + buffer.writeString(column.name) + buffer.writeInteger(column.tableOID) + buffer.writeInteger(column.columnAttributeNumber) + buffer.writeInteger(column.dataType.rawValue) + buffer.writeInteger(column.dataTypeSize) + buffer.writeInteger(column.dataTypeModifier) + buffer.writeInteger(column.format.rawValue) + } + + XCTAssertThrowsError(try ByteToMessageDecoderVerifier.verifyDecoder( + inputOutputPairs: [(buffer, [])], + decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: true) })) { + XCTAssert($0 is PostgresMessageDecodingError) + } + } + + func testDecodeFailureBecauseOfMissingColumnCount() { + let column = RowDescription.Column( + name: "First", tableOID: 123, columnAttributeNumber: 123, dataType: .bool, dataTypeSize: 2, dataTypeModifier: 8, format: .binary) + + var buffer = ByteBuffer() + buffer.writeBackendMessage(id: .rowDescription) { buffer in + buffer.writeNullTerminatedString(column.name) + buffer.writeInteger(column.tableOID) + buffer.writeInteger(column.columnAttributeNumber) + buffer.writeInteger(column.dataType.rawValue) + buffer.writeInteger(column.dataTypeSize) + buffer.writeInteger(column.dataTypeModifier) + buffer.writeInteger(column.format.rawValue) + } + + XCTAssertThrowsError(try ByteToMessageDecoderVerifier.verifyDecoder( + inputOutputPairs: [(buffer, [])], + decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: true) })) { + XCTAssert($0 is PostgresMessageDecodingError) + } + } + + func testDecodeFailureBecauseInvalidFormatCode() { + let column = RowDescription.Column( + name: "First", tableOID: 123, columnAttributeNumber: 123, dataType: .bool, dataTypeSize: 2, dataTypeModifier: 8, format: .binary) + + var buffer = ByteBuffer() + buffer.writeBackendMessage(id: .rowDescription) { buffer in + buffer.writeInteger(Int16(1)) + buffer.writeNullTerminatedString(column.name) + buffer.writeInteger(column.tableOID) + buffer.writeInteger(column.columnAttributeNumber) + buffer.writeInteger(column.dataType.rawValue) + buffer.writeInteger(column.dataTypeSize) + buffer.writeInteger(column.dataTypeModifier) + buffer.writeInteger(UInt16(2)) + } + + XCTAssertThrowsError(try ByteToMessageDecoderVerifier.verifyDecoder( + inputOutputPairs: [(buffer, [])], + decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: true) })) { + XCTAssert($0 is PostgresMessageDecodingError) + } + } + + func testDecodeFailureBecauseNegativeColumnCount() { + let column = RowDescription.Column( + name: "First", tableOID: 123, columnAttributeNumber: 123, dataType: .bool, dataTypeSize: 2, dataTypeModifier: 8, format: .binary) + + var buffer = ByteBuffer() + buffer.writeBackendMessage(id: .rowDescription) { buffer in + buffer.writeInteger(Int16(-1)) + buffer.writeNullTerminatedString(column.name) + buffer.writeInteger(column.tableOID) + buffer.writeInteger(column.columnAttributeNumber) + buffer.writeInteger(column.dataType.rawValue) + buffer.writeInteger(column.dataTypeSize) + buffer.writeInteger(column.dataTypeModifier) + buffer.writeInteger(column.format.rawValue) + } + + XCTAssertThrowsError(try ByteToMessageDecoderVerifier.verifyDecoder( + inputOutputPairs: [(buffer, [])], + decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: true) })) { + XCTAssert($0 is PostgresMessageDecodingError) + } + } + +} diff --git a/Tests/PostgresNIOTests/New/Messages/SASLInitialResponseTests.swift b/Tests/PostgresNIOTests/New/Messages/SASLInitialResponseTests.swift new file mode 100644 index 00000000..90aa6b34 --- /dev/null +++ b/Tests/PostgresNIOTests/New/Messages/SASLInitialResponseTests.swift @@ -0,0 +1,54 @@ +import XCTest +import NIOCore +@testable import PostgresNIO + +class SASLInitialResponseTests: XCTestCase { + + func testEncode() { + var encoder = PostgresFrontendMessageEncoder(buffer: .init()) + let saslMechanism = "hello" + let initialData: [UInt8] = [0, 1, 2, 3, 4, 5, 6, 7] + encoder.saslInitialResponse(mechanism: saslMechanism, bytes: initialData) + var byteBuffer = encoder.flushBuffer() + + let length: Int = 1 + 4 + (saslMechanism.count + 1) + 4 + initialData.count + + // 1 id + // + 4 length + // + 6 saslMechanism (5 + 1 null terminator) + // + 4 initialData length + // + 8 initialData + + XCTAssertEqual(byteBuffer.readableBytes, length) + XCTAssertEqual(byteBuffer.readInteger(as: UInt8.self), PostgresFrontendMessage.ID.saslInitialResponse.rawValue) + XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), Int32(length - 1)) + XCTAssertEqual(byteBuffer.readNullTerminatedString(), saslMechanism) + XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), Int32(initialData.count)) + XCTAssertEqual(byteBuffer.readBytes(length: initialData.count), initialData) + XCTAssertEqual(byteBuffer.readableBytes, 0) + } + + func testEncodeWithoutData() { + var encoder = PostgresFrontendMessageEncoder(buffer: .init()) + let saslMechanism = "hello" + let initialData: [UInt8] = [] + encoder.saslInitialResponse(mechanism: saslMechanism, bytes: initialData) + var byteBuffer = encoder.flushBuffer() + + let length: Int = 1 + 4 + (saslMechanism.count + 1) + 4 + initialData.count + + // 1 id + // + 4 length + // + 6 saslMechanism (5 + 1 null terminator) + // + 4 initialData length + // + 0 initialData + + XCTAssertEqual(byteBuffer.readableBytes, length) + XCTAssertEqual(byteBuffer.readInteger(as: UInt8.self), PostgresFrontendMessage.ID.saslInitialResponse.rawValue) + XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), Int32(length - 1)) + XCTAssertEqual(byteBuffer.readNullTerminatedString(), saslMechanism) + XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), Int32(-1)) + XCTAssertEqual(byteBuffer.readBytes(length: initialData.count), initialData) + XCTAssertEqual(byteBuffer.readableBytes, 0) + } +} diff --git a/Tests/PostgresNIOTests/New/Messages/SASLResponseTests.swift b/Tests/PostgresNIOTests/New/Messages/SASLResponseTests.swift new file mode 100644 index 00000000..cdb0f10b --- /dev/null +++ b/Tests/PostgresNIOTests/New/Messages/SASLResponseTests.swift @@ -0,0 +1,35 @@ +import XCTest +import NIOCore +@testable import PostgresNIO + +class SASLResponseTests: XCTestCase { + + func testEncodeWithData() { + var encoder = PostgresFrontendMessageEncoder(buffer: .init()) + let data: [UInt8] = [0, 1, 2, 3, 4, 5, 6, 7] + encoder.saslResponse(data) + var byteBuffer = encoder.flushBuffer() + + let length: Int = 1 + 4 + (data.count) + + XCTAssertEqual(byteBuffer.readableBytes, length) + XCTAssertEqual(byteBuffer.readInteger(as: UInt8.self), PostgresFrontendMessage.ID.saslResponse.rawValue) + XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), Int32(length - 1)) + XCTAssertEqual(byteBuffer.readBytes(length: data.count), data) + XCTAssertEqual(byteBuffer.readableBytes, 0) + } + + func testEncodeWithoutData() { + var encoder = PostgresFrontendMessageEncoder(buffer: .init()) + let data: [UInt8] = [] + encoder.saslResponse(data) + var byteBuffer = encoder.flushBuffer() + + let length: Int = 1 + 4 + + XCTAssertEqual(byteBuffer.readableBytes, length) + XCTAssertEqual(byteBuffer.readInteger(as: UInt8.self), PostgresFrontendMessage.ID.saslResponse.rawValue) + XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), Int32(length - 1)) + XCTAssertEqual(byteBuffer.readableBytes, 0) + } +} diff --git a/Tests/PostgresNIOTests/New/Messages/SSLRequestTests.swift b/Tests/PostgresNIOTests/New/Messages/SSLRequestTests.swift new file mode 100644 index 00000000..e9e6af81 --- /dev/null +++ b/Tests/PostgresNIOTests/New/Messages/SSLRequestTests.swift @@ -0,0 +1,19 @@ +import XCTest +import NIOCore +@testable import PostgresNIO + +class SSLRequestTests: XCTestCase { + + func testSSLRequest() { + var encoder = PostgresFrontendMessageEncoder(buffer: .init()) + encoder.ssl() + var byteBuffer = encoder.flushBuffer() + + let byteBufferLength = Int32(byteBuffer.readableBytes) + XCTAssertEqual(byteBufferLength, byteBuffer.readInteger()) + XCTAssertEqual(PostgresFrontendMessage.SSLRequest.requestCode, byteBuffer.readInteger()) + + XCTAssertEqual(byteBuffer.readableBytes, 0) + } + +} diff --git a/Tests/PostgresNIOTests/New/Messages/StartupTests.swift b/Tests/PostgresNIOTests/New/Messages/StartupTests.swift new file mode 100644 index 00000000..5af3bf34 --- /dev/null +++ b/Tests/PostgresNIOTests/New/Messages/StartupTests.swift @@ -0,0 +1,84 @@ +import XCTest +import NIOCore +@testable import PostgresNIO + +class StartupTests: XCTestCase { + + func testStartupMessageWithDatabase() { + var encoder = PostgresFrontendMessageEncoder(buffer: .init()) + var byteBuffer = ByteBuffer() + + let user = "test" + let database = "abc123" + + encoder.startup(user: user, database: database, options: []) + byteBuffer = encoder.flushBuffer() + + let byteBufferLength = Int32(byteBuffer.readableBytes) + XCTAssertEqual(byteBufferLength, byteBuffer.readInteger()) + XCTAssertEqual(PostgresFrontendMessage.Startup.versionThree, byteBuffer.readInteger()) + XCTAssertEqual(byteBuffer.readNullTerminatedString(), "user") + XCTAssertEqual(byteBuffer.readNullTerminatedString(), "test") + XCTAssertEqual(byteBuffer.readNullTerminatedString(), "database") + XCTAssertEqual(byteBuffer.readNullTerminatedString(), "abc123") + XCTAssertEqual(byteBuffer.readInteger(), UInt8(0)) + + XCTAssertEqual(byteBuffer.readableBytes, 0) + } + + func testStartupMessageWithoutDatabase() { + var encoder = PostgresFrontendMessageEncoder(buffer: .init()) + var byteBuffer = ByteBuffer() + + let user = "test" + + encoder.startup(user: user, database: nil, options: []) + byteBuffer = encoder.flushBuffer() + + let byteBufferLength = Int32(byteBuffer.readableBytes) + XCTAssertEqual(byteBufferLength, byteBuffer.readInteger()) + XCTAssertEqual(PostgresFrontendMessage.Startup.versionThree, byteBuffer.readInteger()) + XCTAssertEqual(byteBuffer.readNullTerminatedString(), "user") + XCTAssertEqual(byteBuffer.readNullTerminatedString(), "test") + XCTAssertEqual(byteBuffer.readInteger(), UInt8(0)) + + XCTAssertEqual(byteBuffer.readableBytes, 0) + } + + func testStartupMessageWithAdditionalOptions() { + var encoder = PostgresFrontendMessageEncoder(buffer: .init()) + var byteBuffer = ByteBuffer() + + let user = "test" + let database = "abc123" + + encoder.startup(user: user, database: database, options: [("some", "options")]) + byteBuffer = encoder.flushBuffer() + + let byteBufferLength = Int32(byteBuffer.readableBytes) + XCTAssertEqual(byteBufferLength, byteBuffer.readInteger()) + XCTAssertEqual(PostgresFrontendMessage.Startup.versionThree, byteBuffer.readInteger()) + XCTAssertEqual(byteBuffer.readNullTerminatedString(), "user") + XCTAssertEqual(byteBuffer.readNullTerminatedString(), "test") + XCTAssertEqual(byteBuffer.readNullTerminatedString(), "database") + XCTAssertEqual(byteBuffer.readNullTerminatedString(), "abc123") + XCTAssertEqual(byteBuffer.readNullTerminatedString(), "some") + XCTAssertEqual(byteBuffer.readNullTerminatedString(), "options") + XCTAssertEqual(byteBuffer.readInteger(), UInt8(0)) + + XCTAssertEqual(byteBuffer.readableBytes, 0) + } +} + +extension PostgresFrontendMessage.Startup.Parameters.Replication { + var stringValue: String { + switch self { + case .true: + return "true" + case .false: + return "false" + case .database: + return "replication" + } + } +} diff --git a/Tests/PostgresNIOTests/New/PSQLBackendMessageTests.swift b/Tests/PostgresNIOTests/New/PSQLBackendMessageTests.swift new file mode 100644 index 00000000..195c7fb4 --- /dev/null +++ b/Tests/PostgresNIOTests/New/PSQLBackendMessageTests.swift @@ -0,0 +1,296 @@ +import NIOCore +import NIOEmbedded +import NIOTestUtils +import XCTest +@testable import PostgresNIO + +class PSQLBackendMessageTests: XCTestCase { + + // MARK: ID + + func testInitMessageIDWithBytes() { + XCTAssertEqual(PostgresBackendMessage.ID(rawValue: UInt8(ascii: "R")), .authentication) + XCTAssertEqual(PostgresBackendMessage.ID(rawValue: UInt8(ascii: "K")), .backendKeyData) + XCTAssertEqual(PostgresBackendMessage.ID(rawValue: UInt8(ascii: "2")), .bindComplete) + XCTAssertEqual(PostgresBackendMessage.ID(rawValue: UInt8(ascii: "3")), .closeComplete) + XCTAssertEqual(PostgresBackendMessage.ID(rawValue: UInt8(ascii: "C")), .commandComplete) + XCTAssertEqual(PostgresBackendMessage.ID(rawValue: UInt8(ascii: "d")), .copyData) + XCTAssertEqual(PostgresBackendMessage.ID(rawValue: UInt8(ascii: "c")), .copyDone) + XCTAssertEqual(PostgresBackendMessage.ID(rawValue: UInt8(ascii: "G")), .copyInResponse) + XCTAssertEqual(PostgresBackendMessage.ID(rawValue: UInt8(ascii: "H")), .copyOutResponse) + XCTAssertEqual(PostgresBackendMessage.ID(rawValue: UInt8(ascii: "W")), .copyBothResponse) + XCTAssertEqual(PostgresBackendMessage.ID(rawValue: UInt8(ascii: "D")), .dataRow) + XCTAssertEqual(PostgresBackendMessage.ID(rawValue: UInt8(ascii: "I")), .emptyQueryResponse) + XCTAssertEqual(PostgresBackendMessage.ID(rawValue: UInt8(ascii: "E")), .error) + XCTAssertEqual(PostgresBackendMessage.ID(rawValue: UInt8(ascii: "V")), .functionCallResponse) + XCTAssertEqual(PostgresBackendMessage.ID(rawValue: UInt8(ascii: "v")), .negotiateProtocolVersion) + XCTAssertEqual(PostgresBackendMessage.ID(rawValue: UInt8(ascii: "n")), .noData) + XCTAssertEqual(PostgresBackendMessage.ID(rawValue: UInt8(ascii: "N")), .noticeResponse) + XCTAssertEqual(PostgresBackendMessage.ID(rawValue: UInt8(ascii: "A")), .notificationResponse) + XCTAssertEqual(PostgresBackendMessage.ID(rawValue: UInt8(ascii: "t")), .parameterDescription) + XCTAssertEqual(PostgresBackendMessage.ID(rawValue: UInt8(ascii: "S")), .parameterStatus) + XCTAssertEqual(PostgresBackendMessage.ID(rawValue: UInt8(ascii: "1")), .parseComplete) + XCTAssertEqual(PostgresBackendMessage.ID(rawValue: UInt8(ascii: "s")), .portalSuspended) + XCTAssertEqual(PostgresBackendMessage.ID(rawValue: UInt8(ascii: "Z")), .readyForQuery) + XCTAssertEqual(PostgresBackendMessage.ID(rawValue: UInt8(ascii: "T")), .rowDescription) + + XCTAssertNil(PostgresBackendMessage.ID(rawValue: 0)) + } + + func testMessageIDHasCorrectRawValue() { + XCTAssertEqual(PostgresBackendMessage.ID.authentication.rawValue, UInt8(ascii: "R")) + XCTAssertEqual(PostgresBackendMessage.ID.backendKeyData.rawValue, UInt8(ascii: "K")) + XCTAssertEqual(PostgresBackendMessage.ID.bindComplete.rawValue, UInt8(ascii: "2")) + XCTAssertEqual(PostgresBackendMessage.ID.closeComplete.rawValue, UInt8(ascii: "3")) + XCTAssertEqual(PostgresBackendMessage.ID.commandComplete.rawValue, UInt8(ascii: "C")) + XCTAssertEqual(PostgresBackendMessage.ID.copyData.rawValue, UInt8(ascii: "d")) + XCTAssertEqual(PostgresBackendMessage.ID.copyDone.rawValue, UInt8(ascii: "c")) + XCTAssertEqual(PostgresBackendMessage.ID.copyInResponse.rawValue, UInt8(ascii: "G")) + XCTAssertEqual(PostgresBackendMessage.ID.copyOutResponse.rawValue, UInt8(ascii: "H")) + XCTAssertEqual(PostgresBackendMessage.ID.copyBothResponse.rawValue, UInt8(ascii: "W")) + XCTAssertEqual(PostgresBackendMessage.ID.dataRow.rawValue, UInt8(ascii: "D")) + XCTAssertEqual(PostgresBackendMessage.ID.emptyQueryResponse.rawValue, UInt8(ascii: "I")) + XCTAssertEqual(PostgresBackendMessage.ID.error.rawValue, UInt8(ascii: "E")) + XCTAssertEqual(PostgresBackendMessage.ID.functionCallResponse.rawValue, UInt8(ascii: "V")) + XCTAssertEqual(PostgresBackendMessage.ID.negotiateProtocolVersion.rawValue, UInt8(ascii: "v")) + XCTAssertEqual(PostgresBackendMessage.ID.noData.rawValue, UInt8(ascii: "n")) + XCTAssertEqual(PostgresBackendMessage.ID.noticeResponse.rawValue, UInt8(ascii: "N")) + XCTAssertEqual(PostgresBackendMessage.ID.notificationResponse.rawValue, UInt8(ascii: "A")) + XCTAssertEqual(PostgresBackendMessage.ID.parameterDescription.rawValue, UInt8(ascii: "t")) + XCTAssertEqual(PostgresBackendMessage.ID.parameterStatus.rawValue, UInt8(ascii: "S")) + XCTAssertEqual(PostgresBackendMessage.ID.parseComplete.rawValue, UInt8(ascii: "1")) + XCTAssertEqual(PostgresBackendMessage.ID.portalSuspended.rawValue, UInt8(ascii: "s")) + XCTAssertEqual(PostgresBackendMessage.ID.readyForQuery.rawValue, UInt8(ascii: "Z")) + XCTAssertEqual(PostgresBackendMessage.ID.rowDescription.rawValue, UInt8(ascii: "T")) + } + + // MARK: Decoder + + func testSSLSupportedAsFirstByte() { + var buffer = ByteBuffer() + buffer.writeInteger(UInt8(ascii: "S")) + + var expectedMessages: [PostgresBackendMessage] = [.sslSupported] + + // we test tons of ParameterStatus messages after the SSLSupported message, since those are + // also identified by an "S" + let parameterStatus: [PostgresBackendMessage.ParameterStatus] = [ + .init(parameter: "DateStyle", value: "ISO, MDY"), + .init(parameter: "application_name", value: ""), + .init(parameter: "server_encoding", value: "UTF8"), + .init(parameter: "integer_datetimes", value: "on"), + .init(parameter: "client_encoding", value: "UTF8"), + .init(parameter: "TimeZone", value: "Etc/UTC"), + .init(parameter: "is_superuser", value: "on"), + .init(parameter: "server_version", value: "13.1 (Debian 13.1-1.pgdg100+1)"), + .init(parameter: "session_authorization", value: "postgres"), + .init(parameter: "IntervalStyle", value: "postgres"), + .init(parameter: "standard_conforming_strings", value: "on"), + ] + + parameterStatus.forEach { parameterStatus in + buffer.writeBackendMessage(id: .parameterStatus) { buffer in + buffer.writeNullTerminatedString(parameterStatus.parameter) + buffer.writeNullTerminatedString(parameterStatus.value) + } + + expectedMessages.append(.parameterStatus(parameterStatus)) + } + + let handler = ByteToMessageHandler(PostgresBackendMessageDecoder()) + let embedded = EmbeddedChannel(handler: handler) + XCTAssertNoThrow(try embedded.writeInbound(buffer)) + + for expected in expectedMessages { + var message: PostgresBackendMessage? + XCTAssertNoThrow(message = try embedded.readInbound(as: PostgresBackendMessage.self)) + XCTAssertEqual(message, expected) + } + } + + func testSSLUnsupportedAsFirstByte() { + var buffer = ByteBuffer() + buffer.writeInteger(UInt8(ascii: "N")) + + // we test a NoticeResponse messages after the SSLUnupported message, since NoticeResponse + // is identified by a "N" + let fields: [PostgresBackendMessage.Field : String] = [ + .file: "auth.c", + .routine: "auth_failed", + .line: "334", + .localizedSeverity: "FATAL", + .sqlState: "28P01", + .severity: "FATAL", + .message: "password authentication failed for user \"postgre3\"", + ] + + let expectedMessages: [PostgresBackendMessage] = [ + .sslUnsupported, + .notice(.init(fields: fields)) + ] + + buffer.writeBackendMessage(id: .noticeResponse) { buffer in + fields.forEach { (key, value) in + buffer.writeInteger(key.rawValue, as: UInt8.self) + buffer.writeNullTerminatedString(value) + } + buffer.writeInteger(0, as: UInt8.self) // signal done + } + + let handler = ByteToMessageHandler(PostgresBackendMessageDecoder()) + let embedded = EmbeddedChannel(handler: handler) + XCTAssertNoThrow(try embedded.writeInbound(buffer)) + + for expected in expectedMessages { + var message: PostgresBackendMessage? + XCTAssertNoThrow(message = try embedded.readInbound(as: PostgresBackendMessage.self)) + XCTAssertEqual(message, expected) + } + } + + func testPayloadsWithoutAssociatedValues() { + let messageIDs: [PostgresBackendMessage.ID] = [ + .bindComplete, + .closeComplete, + .emptyQueryResponse, + .noData, + .parseComplete, + .portalSuspended + ] + + var buffer = ByteBuffer() + messageIDs.forEach { messageID in + buffer.writeBackendMessage(id: messageID) { _ in } + } + + let expected: [PostgresBackendMessage] = [ + .bindComplete, + .closeComplete, + .emptyQueryResponse, + .noData, + .parseComplete, + .portalSuspended + ] + + XCTAssertNoThrow(try ByteToMessageDecoderVerifier.verifyDecoder( + inputOutputPairs: [(buffer, expected)], + decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: false) })) + } + + func testPayloadsWithoutAssociatedValuesInvalidLength() { + let messageIDs: [PostgresBackendMessage.ID] = [ + .bindComplete, + .closeComplete, + .emptyQueryResponse, + .noData, + .parseComplete, + .portalSuspended + ] + + for messageID in messageIDs { + var buffer = ByteBuffer() + buffer.writeBackendMessage(id: messageID) { buffer in + buffer.writeInteger(UInt8(0)) + } + + XCTAssertThrowsError(try ByteToMessageDecoderVerifier.verifyDecoder( + inputOutputPairs: [(buffer, [])], + decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: false) })) { + XCTAssert($0 is PostgresMessageDecodingError) + } + } + } + + func testDecodeCommandCompleteMessage() { + let expected: [PostgresBackendMessage] = [ + .commandComplete("SELECT 100"), + .commandComplete("INSERT 0 1"), + .commandComplete("UPDATE 1"), + .commandComplete("DELETE 1") + ] + + var okBuffer = ByteBuffer() + expected.forEach { message in + guard case .commandComplete(let commandTag) = message else { + return XCTFail("Programming error!") + } + + okBuffer.writeBackendMessage(id: .commandComplete) { buffer in + buffer.writeNullTerminatedString(commandTag) + } + } + + XCTAssertNoThrow(try ByteToMessageDecoderVerifier.verifyDecoder( + inputOutputPairs: [(okBuffer, expected)], + decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: false) })) + + // test commandTag is not null terminated + for message in expected { + guard case .commandComplete(let commandTag) = message else { + return XCTFail("Programming error!") + } + + var failBuffer = ByteBuffer() + failBuffer.writeBackendMessage(id: .commandComplete) { buffer in + buffer.writeString(commandTag) + } + + XCTAssertThrowsError(try ByteToMessageDecoderVerifier.verifyDecoder( + inputOutputPairs: [(failBuffer, [])], + decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: false) })) { + XCTAssert($0 is PostgresMessageDecodingError) + } + } + } + + func testDecodeMessageWithUnknownMessageID() { + var buffer = ByteBuffer() + buffer.writeInteger(UInt8(ascii: "x")) + buffer.writeInteger(Int32(4)) + + XCTAssertThrowsError(try ByteToMessageDecoderVerifier.verifyDecoder( + inputOutputPairs: [(buffer, [])], + decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: false) })) { + XCTAssert($0 is PostgresMessageDecodingError) + } + } + + func testDebugDescription() { + let salt: UInt32 = 0x00_01_02_03 + XCTAssertEqual("\(PostgresBackendMessage.authentication(.ok))", ".authentication(.ok)") + XCTAssertEqual("\(PostgresBackendMessage.authentication(.kerberosV5))", + ".authentication(.kerberosV5)") + XCTAssertEqual("\(PostgresBackendMessage.authentication(.md5(salt: salt)))", + ".authentication(.md5(salt: \(salt)))") + XCTAssertEqual("\(PostgresBackendMessage.authentication(.plaintext))", + ".authentication(.plaintext)") + XCTAssertEqual("\(PostgresBackendMessage.authentication(.scmCredential))", + ".authentication(.scmCredential)") + XCTAssertEqual("\(PostgresBackendMessage.authentication(.gss))", + ".authentication(.gss)") + XCTAssertEqual("\(PostgresBackendMessage.authentication(.sspi))", + ".authentication(.sspi)") + + XCTAssertEqual("\(PostgresBackendMessage.parameterStatus(.init(parameter: "foo", value: "bar")))", + #".parameterStatus(parameter: "foo", value: "bar")"#) + XCTAssertEqual("\(PostgresBackendMessage.backendKeyData(.init(processID: 1234, secretKey: 4567)))", + ".backendKeyData(processID: 1234, secretKey: 4567)") + + XCTAssertEqual("\(PostgresBackendMessage.bindComplete)", ".bindComplete") + XCTAssertEqual("\(PostgresBackendMessage.closeComplete)", ".closeComplete") + XCTAssertEqual("\(PostgresBackendMessage.commandComplete("SELECT 123"))", #".commandComplete("SELECT 123")"#) + XCTAssertEqual("\(PostgresBackendMessage.emptyQueryResponse)", ".emptyQueryResponse") + XCTAssertEqual("\(PostgresBackendMessage.noData)", ".noData") + XCTAssertEqual("\(PostgresBackendMessage.parseComplete)", ".parseComplete") + XCTAssertEqual("\(PostgresBackendMessage.portalSuspended)", ".portalSuspended") + + XCTAssertEqual("\(PostgresBackendMessage.readyForQuery(.idle))", ".readyForQuery(.idle)") + XCTAssertEqual("\(PostgresBackendMessage.readyForQuery(.inTransaction))", + ".readyForQuery(.inTransaction)") + XCTAssertEqual("\(PostgresBackendMessage.readyForQuery(.inFailedTransaction))", + ".readyForQuery(.inFailedTransaction)") + XCTAssertEqual("\(PostgresBackendMessage.sslSupported)", ".sslSupported") + XCTAssertEqual("\(PostgresBackendMessage.sslUnsupported)", ".sslUnsupported") + } + +} diff --git a/Tests/PostgresNIOTests/New/PSQLFrontendMessageTests.swift b/Tests/PostgresNIOTests/New/PSQLFrontendMessageTests.swift new file mode 100644 index 00000000..33afbe0d --- /dev/null +++ b/Tests/PostgresNIOTests/New/PSQLFrontendMessageTests.swift @@ -0,0 +1,55 @@ +import XCTest +import NIOCore +@testable import PostgresNIO + +class PSQLFrontendMessageTests: XCTestCase { + + // MARK: ID + + func testMessageIDs() { + XCTAssertEqual(PostgresFrontendMessage.ID.bind.rawValue, UInt8(ascii: "B")) + XCTAssertEqual(PostgresFrontendMessage.ID.close.rawValue, UInt8(ascii: "C")) + XCTAssertEqual(PostgresFrontendMessage.ID.describe.rawValue, UInt8(ascii: "D")) + XCTAssertEqual(PostgresFrontendMessage.ID.execute.rawValue, UInt8(ascii: "E")) + XCTAssertEqual(PostgresFrontendMessage.ID.flush.rawValue, UInt8(ascii: "H")) + XCTAssertEqual(PostgresFrontendMessage.ID.parse.rawValue, UInt8(ascii: "P")) + XCTAssertEqual(PostgresFrontendMessage.ID.password.rawValue, UInt8(ascii: "p")) + XCTAssertEqual(PostgresFrontendMessage.ID.saslInitialResponse.rawValue, UInt8(ascii: "p")) + XCTAssertEqual(PostgresFrontendMessage.ID.saslResponse.rawValue, UInt8(ascii: "p")) + XCTAssertEqual(PostgresFrontendMessage.ID.sync.rawValue, UInt8(ascii: "S")) + XCTAssertEqual(PostgresFrontendMessage.ID.terminate.rawValue, UInt8(ascii: "X")) + } + + // MARK: Encoder + + func testEncodeFlush() { + var encoder = PostgresFrontendMessageEncoder(buffer: .init()) + encoder.flush() + var byteBuffer = encoder.flushBuffer() + + XCTAssertEqual(byteBuffer.readableBytes, 5) + XCTAssertEqual(PostgresFrontendMessage.ID.flush.rawValue, byteBuffer.readInteger(as: UInt8.self)) + XCTAssertEqual(4, byteBuffer.readInteger(as: Int32.self)) // payload length + } + + func testEncodeSync() { + var encoder = PostgresFrontendMessageEncoder(buffer: .init()) + encoder.sync() + var byteBuffer = encoder.flushBuffer() + + XCTAssertEqual(byteBuffer.readableBytes, 5) + XCTAssertEqual(PostgresFrontendMessage.ID.sync.rawValue, byteBuffer.readInteger(as: UInt8.self)) + XCTAssertEqual(4, byteBuffer.readInteger(as: Int32.self)) // payload length + } + + func testEncodeTerminate() { + var encoder = PostgresFrontendMessageEncoder(buffer: .init()) + encoder.terminate() + var byteBuffer = encoder.flushBuffer() + + XCTAssertEqual(byteBuffer.readableBytes, 5) + XCTAssertEqual(PostgresFrontendMessage.ID.terminate.rawValue, byteBuffer.readInteger(as: UInt8.self)) + XCTAssertEqual(4, byteBuffer.readInteger(as: Int32.self)) // payload length + } + +} diff --git a/Tests/PostgresNIOTests/New/PSQLMetadataTests.swift b/Tests/PostgresNIOTests/New/PSQLMetadataTests.swift new file mode 100644 index 00000000..b069b4f0 --- /dev/null +++ b/Tests/PostgresNIOTests/New/PSQLMetadataTests.swift @@ -0,0 +1,18 @@ +import NIOCore +import XCTest +@testable import PostgresNIO + +class PSQLMetadataTests: XCTestCase { + func testSelect() { + XCTAssertEqual(100, PostgresQueryMetadata(string: "SELECT 100")?.rows) + XCTAssertNotNil(PostgresQueryMetadata(string: "SELECT")) + XCTAssertNil(PostgresQueryMetadata(string: "SELECT")?.rows) + XCTAssertNil(PostgresQueryMetadata(string: "SELECT 100 100")) + } + + func testUpdate() { + XCTAssertEqual(100, PostgresQueryMetadata(string: "UPDATE 100")?.rows) + XCTAssertNil(PostgresQueryMetadata(string: "UPDATE")) + XCTAssertNil(PostgresQueryMetadata(string: "UPDATE 100 100")) + } +} diff --git a/Tests/PostgresNIOTests/New/PSQLRowStreamTests.swift b/Tests/PostgresNIOTests/New/PSQLRowStreamTests.swift new file mode 100644 index 00000000..65ca26c3 --- /dev/null +++ b/Tests/PostgresNIOTests/New/PSQLRowStreamTests.swift @@ -0,0 +1,265 @@ +import Atomics +import NIOCore +import Logging +import XCTest +@testable import PostgresNIO +import NIOCore +import NIOEmbedded + +final class PSQLRowStreamTests: XCTestCase { + let logger = Logger(label: "PSQLRowStreamTests") + let eventLoop = EmbeddedEventLoop() + + func testEmptyStream() { + let stream = PSQLRowStream( + source: .noRows(.success(.tag("INSERT 0 1"))), + eventLoop: self.eventLoop, + logger: self.logger + ) + + XCTAssertEqual(try stream.all().wait(), []) + XCTAssertEqual(stream.commandTag, "INSERT 0 1") + } + + func testFailedStream() { + let stream = PSQLRowStream( + source: .noRows(.failure(PSQLError.serverClosedConnection(underlying: nil))), + eventLoop: self.eventLoop, + logger: self.logger + ) + + XCTAssertThrowsError(try stream.all().wait()) { + XCTAssertEqual($0 as? PSQLError, .serverClosedConnection(underlying: nil)) + } + } + + func testGetArrayAfterStreamHasFinished() { + let dataSource = CountingDataSource() + let stream = PSQLRowStream( + source: .stream( + [self.makeColumnDescription(name: "foo", dataType: .text, format: .binary)], + dataSource + ), + eventLoop: self.eventLoop, + logger: self.logger + ) + XCTAssertEqual(dataSource.hitDemand, 0) + XCTAssertEqual(dataSource.hitCancel, 0) + + stream.receive([ + [ByteBuffer(string: "0")], + [ByteBuffer(string: "1")] + ]) + + XCTAssertEqual(dataSource.hitDemand, 0, "Before we have a consumer demand is not signaled") + stream.receive(completion: .success("SELECT 2")) + + // attach consumer + let future = stream.all() + XCTAssertEqual(dataSource.hitDemand, 0) // TODO: Is this right? + + var rows: [PostgresRow]? + XCTAssertNoThrow(rows = try future.wait()) + XCTAssertEqual(rows?.count, 2) + } + + func testGetArrayBeforeStreamHasFinished() { + let dataSource = CountingDataSource() + let stream = PSQLRowStream( + source: .stream( + [self.makeColumnDescription(name: "foo", dataType: .text, format: .binary)], + dataSource + ), + eventLoop: self.eventLoop, + logger: self.logger + ) + XCTAssertEqual(dataSource.hitDemand, 0) + XCTAssertEqual(dataSource.hitCancel, 0) + + stream.receive([ + [ByteBuffer(string: "0")], + [ByteBuffer(string: "1")] + ]) + + XCTAssertEqual(dataSource.hitDemand, 0, "Before we have a consumer demand is not signaled") + + // attach consumer + let future = stream.all() + XCTAssertEqual(dataSource.hitDemand, 1) + + stream.receive([ + [ByteBuffer(string: "2")], + [ByteBuffer(string: "3")] + ]) + XCTAssertEqual(dataSource.hitDemand, 2) + + stream.receive([ + [ByteBuffer(string: "4")], + [ByteBuffer(string: "5")] + ]) + XCTAssertEqual(dataSource.hitDemand, 3) + + stream.receive(completion: .success("SELECT 2")) + + var rows: [PostgresRow]? + XCTAssertNoThrow(rows = try future.wait()) + XCTAssertEqual(rows?.count, 6) + } + + func testOnRowAfterStreamHasFinished() { + let dataSource = CountingDataSource() + let stream = PSQLRowStream( + source: .stream( + [self.makeColumnDescription(name: "foo", dataType: .text, format: .binary)], + dataSource + ), + eventLoop: self.eventLoop, + logger: self.logger + ) + XCTAssertEqual(dataSource.hitDemand, 0) + XCTAssertEqual(dataSource.hitCancel, 0) + + stream.receive([ + [ByteBuffer(string: "0")], + [ByteBuffer(string: "1")] + ]) + + stream.receive(completion: .success("SELECT 2")) + + XCTAssertEqual(dataSource.hitDemand, 0) + + // attach consumer + let counter = ManagedAtomic(0) + let future = stream.onRow { row in + let expected = counter.loadThenWrappingIncrement(ordering: .relaxed) + XCTAssertEqual(try row.decode(String.self, context: .default), "\(expected)") + } + XCTAssertEqual(counter.load(ordering: .relaxed), 2) + XCTAssertEqual(dataSource.hitDemand, 0) + + XCTAssertNoThrow(try future.wait()) + XCTAssertEqual(stream.commandTag, "SELECT 2") + } + + func testOnRowThrowsErrorOnInitialBatch() { + let dataSource = CountingDataSource() + let stream = PSQLRowStream( + source: .stream( + [self.makeColumnDescription(name: "foo", dataType: .text, format: .binary)], + dataSource + ), + eventLoop: self.eventLoop, + logger: self.logger + ) + XCTAssertEqual(dataSource.hitDemand, 0) + XCTAssertEqual(dataSource.hitCancel, 0) + + stream.receive([ + [ByteBuffer(string: "0")], + [ByteBuffer(string: "1")], + [ByteBuffer(string: "2")], + [ByteBuffer(string: "3")], + ]) + + stream.receive(completion: .success("SELECT 2")) + + XCTAssertEqual(dataSource.hitDemand, 0) + + // attach consumer + let counter = ManagedAtomic(0) + let future = stream.onRow { row in + let expected = counter.loadThenWrappingIncrement(ordering: .relaxed) + XCTAssertEqual(try row.decode(String.self, context: .default), "\(expected)") + if expected == 1 { + throw OnRowError(row: expected) + } + } + XCTAssertEqual(counter.load(ordering: .relaxed), 2) // one more than where we excited, because we already incremented + XCTAssertEqual(dataSource.hitDemand, 0) + + XCTAssertThrowsError(try future.wait()) { + XCTAssertEqual($0 as? OnRowError, OnRowError(row: 1)) + } + } + + func testOnRowBeforeStreamHasFinished() { + let dataSource = CountingDataSource() + let stream = PSQLRowStream( + source: .stream( + [self.makeColumnDescription(name: "foo", dataType: .text, format: .binary)], + dataSource + ), + eventLoop: self.eventLoop, + logger: self.logger + ) + XCTAssertEqual(dataSource.hitDemand, 0) + XCTAssertEqual(dataSource.hitCancel, 0) + + stream.receive([ + [ByteBuffer(string: "0")], + [ByteBuffer(string: "1")] + ]) + + XCTAssertEqual(dataSource.hitDemand, 0, "Before we have a consumer demand is not signaled") + + // attach consumer + let counter = ManagedAtomic(0) + let future = stream.onRow { row in + let expected = counter.loadThenWrappingIncrement(ordering: .relaxed) + XCTAssertEqual(try row.decode(String.self, context: .default), "\(expected)") + } + XCTAssertEqual(counter.load(ordering: .relaxed), 2) + XCTAssertEqual(dataSource.hitDemand, 1) + + stream.receive([ + [ByteBuffer(string: "2")], + [ByteBuffer(string: "3")] + ]) + XCTAssertEqual(counter.load(ordering: .relaxed), 4) + XCTAssertEqual(dataSource.hitDemand, 2) + + stream.receive([ + [ByteBuffer(string: "4")], + [ByteBuffer(string: "5")] + ]) + XCTAssertEqual(counter.load(ordering: .relaxed), 6) + XCTAssertEqual(dataSource.hitDemand, 3) + + stream.receive(completion: .success("SELECT 6")) + + XCTAssertNoThrow(try future.wait()) + XCTAssertEqual(stream.commandTag, "SELECT 6") + } + + func makeColumnDescription(name: String, dataType: PostgresDataType, format: PostgresFormat) -> RowDescription.Column { + RowDescription.Column( + name: "test", + tableOID: 123, + columnAttributeNumber: 1, + dataType: .text, + dataTypeSize: -1, + dataTypeModifier: 0, + format: .binary + ) + } +} + +private struct OnRowError: Error, Equatable { + var row: Int +} + +class CountingDataSource: PSQLRowsDataSource { + + var hitDemand: Int = 0 + var hitCancel: Int = 0 + + init() {} + + func cancel(for stream: PSQLRowStream) { + self.hitCancel += 1 + } + + func request(for stream: PSQLRowStream) { + self.hitDemand += 1 + } +} diff --git a/Tests/PostgresNIOTests/New/PostgresCellTests.swift b/Tests/PostgresNIOTests/New/PostgresCellTests.swift new file mode 100644 index 00000000..6458d063 --- /dev/null +++ b/Tests/PostgresNIOTests/New/PostgresCellTests.swift @@ -0,0 +1,58 @@ +@testable import PostgresNIO +import XCTest +import NIOCore + +final class PostgresCellTests: XCTestCase { + func testDecodingANonOptionalString() { + let cell = PostgresCell( + bytes: ByteBuffer(string: "Hello world"), + dataType: .text, + format: .binary, + columnName: "hello", + columnIndex: 1 + ) + + var result: String? + XCTAssertNoThrow(result = try cell.decode(String.self, context: .default)) + XCTAssertEqual(result, "Hello world") + } + + func testDecodingAnOptionalString() { + let cell = PostgresCell( + bytes: nil, + dataType: .text, + format: .binary, + columnName: "hello", + columnIndex: 1 + ) + + var result: String? = "test" + XCTAssertNoThrow(result = try cell.decode(String?.self, context: .default)) + XCTAssertNil(result) + } + + func testDecodingFailure() { + let cell = PostgresCell( + bytes: ByteBuffer(string: "Hello world"), + dataType: .text, + format: .binary, + columnName: "hello", + columnIndex: 1 + ) + + XCTAssertThrowsError(try cell.decode(Int?.self, context: .default)) { + guard let error = $0 as? PostgresDecodingError else { + return XCTFail("Unexpected error") + } + + XCTAssertEqual(error.file, #fileID) + XCTAssertEqual(error.line, #line - 6) + XCTAssertEqual(error.code, .typeMismatch) + XCTAssertEqual(error.columnName, "hello") + XCTAssertEqual(error.columnIndex, 1) + XCTAssert(error.targetType == Int?.self) + XCTAssertEqual(error.postgresType, .text) + XCTAssertEqual(error.postgresFormat, .binary) + } + } +} diff --git a/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift b/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift new file mode 100644 index 00000000..206f38a3 --- /dev/null +++ b/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift @@ -0,0 +1,288 @@ +import XCTest +import NIOCore +import NIOTLS +import NIOSSL +import NIOEmbedded +@testable import PostgresNIO + +class PostgresChannelHandlerTests: XCTestCase { + + var eventLoop: EmbeddedEventLoop! + + override func setUp() { + super.setUp() + self.eventLoop = EmbeddedEventLoop() + } + + // MARK: Startup + + func testHandlerAddedWithoutSSL() { + let config = self.testConnectionConfiguration() + let handler = PostgresChannelHandler(configuration: config, eventLoop: self.eventLoop, configureSSLCallback: nil) + let embedded = EmbeddedChannel(handlers: [ + ReverseByteToMessageHandler(PSQLFrontendMessageDecoder()), + ReverseMessageToByteHandler(PSQLBackendMessageEncoder()), + handler + ], loop: self.eventLoop) + defer { + XCTAssertNoThrow({ try embedded.finish() }) + } + + var maybeMessage: PostgresFrontendMessage? + XCTAssertNoThrow(embedded.connect(to: try .init(ipAddress: "0.0.0.0", port: 5432), promise: nil)) + XCTAssertNoThrow(maybeMessage = try embedded.readOutbound(as: PostgresFrontendMessage.self)) + guard case .startup(let startup) = maybeMessage else { + return XCTFail("Unexpected message") + } + + XCTAssertEqual(startup.parameters.user, config.username) + XCTAssertEqual(startup.parameters.database, config.database) + XCTAssert(startup.parameters.options.isEmpty) + + XCTAssertNoThrow(try embedded.writeInbound(PostgresBackendMessage.authentication(.ok))) + XCTAssertNoThrow(try embedded.writeInbound(PostgresBackendMessage.backendKeyData(.init(processID: 1234, secretKey: 5678)))) + XCTAssertNoThrow(try embedded.writeInbound(PostgresBackendMessage.readyForQuery(.idle))) + } + + func testEstablishSSLCallbackIsCalledIfSSLIsSupported() { + var config = self.testConnectionConfiguration() + XCTAssertNoThrow(config.tls = .require(try NIOSSLContext(configuration: .makeClientConfiguration()))) + var addSSLCallbackIsHit = false + let handler = PostgresChannelHandler(configuration: config, eventLoop: self.eventLoop) { channel, _ in + addSSLCallbackIsHit = true + } + let embedded = EmbeddedChannel(handlers: [ + ReverseByteToMessageHandler(PSQLFrontendMessageDecoder()), + ReverseMessageToByteHandler(PSQLBackendMessageEncoder()), + handler + ], loop: self.eventLoop) + + XCTAssertNoThrow(embedded.connect(to: try .init(ipAddress: "0.0.0.0", port: 5432), promise: nil)) + XCTAssertEqual(.sslRequest, try embedded.readOutbound(as: PostgresFrontendMessage.self)) + + XCTAssertNoThrow(try embedded.writeInbound(PostgresBackendMessage.sslSupported)) + + // a NIOSSLHandler has been added, after it SSL had been negotiated + XCTAssertTrue(addSSLCallbackIsHit) + + // signal that the ssl connection has been established + embedded.pipeline.fireUserInboundEventTriggered(TLSUserEvent.handshakeCompleted(negotiatedProtocol: "")) + + // startup message should be issued + var maybeStartupMessage: PostgresFrontendMessage? + XCTAssertNoThrow(maybeStartupMessage = try embedded.readOutbound(as: PostgresFrontendMessage.self)) + guard case .startup(let startupMessage) = maybeStartupMessage else { + return XCTFail("Unexpected message") + } + + XCTAssertEqual(startupMessage.parameters.user, config.username) + XCTAssertEqual(startupMessage.parameters.database, config.database) + XCTAssertEqual(startupMessage.parameters.replication, .false) + } + + func testEstablishSSLCallbackIsNotCalledIfSSLIsSupportedButAnotherMEssageIsSentAsWell() { + var config = self.testConnectionConfiguration() + XCTAssertNoThrow(config.tls = .require(try NIOSSLContext(configuration: .makeClientConfiguration()))) + var addSSLCallbackIsHit = false + let handler = PostgresChannelHandler(configuration: config, eventLoop: self.eventLoop) { channel, _ in + addSSLCallbackIsHit = true + } + let eventHandler = TestEventHandler() + let embedded = EmbeddedChannel(handlers: [ + ReverseByteToMessageHandler(PSQLFrontendMessageDecoder()), + handler, + eventHandler + ], loop: self.eventLoop) + + XCTAssertNoThrow(embedded.connect(to: try .init(ipAddress: "0.0.0.0", port: 5432), promise: nil)) + XCTAssertEqual(.sslRequest, try embedded.readOutbound(as: PostgresFrontendMessage.self)) + + var responseBuffer = ByteBuffer() + responseBuffer.writeInteger(UInt8(ascii: "S")) + responseBuffer.writeInteger(UInt8(ascii: "1")) + XCTAssertNoThrow(try embedded.writeInbound(responseBuffer)) + + XCTAssertFalse(addSSLCallbackIsHit) + + // the event handler should have seen an error + XCTAssertEqual(eventHandler.errors.count, 1) + + // the connections should be closed + XCTAssertFalse(embedded.isActive) + } + + func testSSLUnsupportedClosesConnection() throws { + let config = self.testConnectionConfiguration(tls: .require(try NIOSSLContext(configuration: .makeClientConfiguration()))) + + let handler = PostgresChannelHandler(configuration: config, eventLoop: self.eventLoop) { channel, _ in + XCTFail("This callback should never be exectuded") + throw PSQLError.sslUnsupported + } + let embedded = EmbeddedChannel(handlers: [ + ReverseByteToMessageHandler(PSQLFrontendMessageDecoder()), + ReverseMessageToByteHandler(PSQLBackendMessageEncoder()), + handler + ], loop: self.eventLoop) + let eventHandler = TestEventHandler() + try embedded.pipeline.syncOperations.addHandler(eventHandler, position: .last) + + embedded.connect(to: try .init(ipAddress: "0.0.0.0", port: 5432), promise: nil) + XCTAssertTrue(embedded.isActive) + + // read the ssl request message + XCTAssertEqual(try embedded.readOutbound(as: PostgresFrontendMessage.self), .sslRequest) + try embedded.writeInbound(PostgresBackendMessage.sslUnsupported) + + // the event handler should have seen an error + XCTAssertEqual(eventHandler.errors.count, 1) + + // the connections should be closed + XCTAssertFalse(embedded.isActive) + } + + // MARK: Run Actions + + func testRunAuthenticateMD5Password() { + let config = self.testConnectionConfiguration() + let authContext = AuthContext( + username: config.username ?? "something wrong", + password: config.password, + database: config.database + ) + let state = ConnectionStateMachine(.waitingToStartAuthentication) + let handler = PostgresChannelHandler(configuration: config, eventLoop: self.eventLoop, state: state, configureSSLCallback: nil) + let embedded = EmbeddedChannel(handlers: [ + ReverseByteToMessageHandler(PSQLFrontendMessageDecoder()), + handler + ], loop: self.eventLoop) + + embedded.triggerUserOutboundEvent(PSQLOutgoingEvent.authenticate(authContext), promise: nil) + XCTAssertEqual(try embedded.readOutbound(as: PostgresFrontendMessage.self), .startup(.versionThree(parameters: authContext.toStartupParameters()))) + let salt: UInt32 = 0x00_01_02_03 + + let encoder = PSQLBackendMessageEncoder() + var byteBuffer = ByteBuffer() + encoder.encode(data: .authentication(.md5(salt: salt)), out: &byteBuffer) + XCTAssertNoThrow(try embedded.writeInbound(byteBuffer)) + + XCTAssertEqual(try embedded.readOutbound(as: PostgresFrontendMessage.self), .password(.init(value: "md522d085ed8dc3377968dc1c1a40519a2a"))) + } + + func testRunAuthenticateCleartext() { + let password = "postgres" + let config = self.testConnectionConfiguration(password: password) + let authContext = AuthContext( + username: config.username ?? "something wrong", + password: config.password, + database: config.database + ) + let state = ConnectionStateMachine(.waitingToStartAuthentication) + let handler = PostgresChannelHandler(configuration: config, eventLoop: self.eventLoop, state: state, configureSSLCallback: nil) + let embedded = EmbeddedChannel(handlers: [ + ReverseByteToMessageHandler(PSQLFrontendMessageDecoder()), + ReverseMessageToByteHandler(PSQLBackendMessageEncoder()), + handler + ], loop: self.eventLoop) + + embedded.triggerUserOutboundEvent(PSQLOutgoingEvent.authenticate(authContext), promise: nil) + XCTAssertEqual(try embedded.readOutbound(as: PostgresFrontendMessage.self), .startup(.versionThree(parameters: authContext.toStartupParameters()))) + + XCTAssertNoThrow(try embedded.writeInbound(PostgresBackendMessage.authentication(.plaintext))) + XCTAssertEqual(try embedded.readOutbound(as: PostgresFrontendMessage.self), .password(.init(value: password))) + } + + func testHandlerThatSendsMultipleWrongMessages() { + let config = self.testConnectionConfiguration() + let handler = PostgresChannelHandler(configuration: config, eventLoop: self.eventLoop, configureSSLCallback: nil) + let embedded = EmbeddedChannel(handlers: [ + ReverseByteToMessageHandler(PSQLFrontendMessageDecoder()), + handler + ], loop: self.eventLoop) + + var maybeMessage: PostgresFrontendMessage? + XCTAssertNoThrow(embedded.connect(to: try .init(ipAddress: "0.0.0.0", port: 5432), promise: nil)) + XCTAssertNoThrow(maybeMessage = try embedded.readOutbound(as: PostgresFrontendMessage.self)) + guard case .startup(let startup) = maybeMessage else { + return XCTFail("Unexpected message") + } + + XCTAssertEqual(startup.parameters.user, config.username) + XCTAssertEqual(startup.parameters.database, config.database) + XCTAssert(startup.parameters.options.isEmpty) + XCTAssertEqual(startup.parameters.replication, .false) + + var buffer = ByteBuffer() + buffer.writeMultipleIntegers(UInt8(ascii: "R"), UInt32(8), Int32(0)) + buffer.writeMultipleIntegers(UInt8(ascii: "K"), UInt32(12), Int32(1234), Int32(5678)) + buffer.writeMultipleIntegers(UInt8(ascii: "Z"), UInt32(5), UInt8(ascii: "I")) + XCTAssertNoThrow(try embedded.writeInbound(buffer)) + XCTAssertTrue(embedded.isActive) + + buffer.clear() + buffer.writeMultipleIntegers(UInt8(ascii: "Z"), UInt32(5), UInt8(ascii: "I")) + buffer.writeMultipleIntegers(UInt8(ascii: "Z"), UInt32(5), UInt8(ascii: "I")) + buffer.writeMultipleIntegers(UInt8(ascii: "Z"), UInt32(5), UInt8(ascii: "I")) + buffer.writeMultipleIntegers(UInt8(ascii: "Z"), UInt32(5), UInt8(ascii: "I")) + + XCTAssertThrowsError(try embedded.writeInbound(buffer)) + XCTAssertFalse(embedded.isActive) + } + + // MARK: Helpers + + func testConnectionConfiguration( + host: String = "127.0.0.1", + port: Int = 5432, + username: String = "test", + database: String = "postgres", + password: String = "password", + tls: PostgresConnection.Configuration.TLS = .disable, + connectTimeout: TimeAmount = .seconds(10), + requireBackendKeyData: Bool = true + ) -> PostgresConnection.InternalConfiguration { + var options = PostgresConnection.Configuration.Options() + options.connectTimeout = connectTimeout + options.requireBackendKeyData = requireBackendKeyData + + return PostgresConnection.InternalConfiguration( + connection: .unresolvedTCP(host: host, port: port), + username: username, + password: password, + database: database, + tls: tls, + options: options + ) + } +} + +class TestEventHandler: ChannelInboundHandler { + typealias InboundIn = Never + + var errors = [PSQLError]() + var events = [PSQLEvent]() + + func errorCaught(context: ChannelHandlerContext, error: Error) { + guard let psqlError = error as? PSQLError else { + return XCTFail("Unexpected error type received: \(error)") + } + self.errors.append(psqlError) + } + + func userInboundEventTriggered(context: ChannelHandlerContext, event: Any) { + guard let psqlEvent = event as? PSQLEvent else { + return XCTFail("Unexpected event type received: \(event)") + } + self.events.append(psqlEvent) + } +} + +extension AuthContext { + func toStartupParameters() -> PostgresFrontendMessage.Startup.Parameters { + PostgresFrontendMessage.Startup.Parameters( + user: self.username, + database: self.database, + options: self.additionalParameters, + replication: .false + ) + } +} diff --git a/Tests/PostgresNIOTests/New/PostgresCodableTests.swift b/Tests/PostgresNIOTests/New/PostgresCodableTests.swift new file mode 100644 index 00000000..94a0253b --- /dev/null +++ b/Tests/PostgresNIOTests/New/PostgresCodableTests.swift @@ -0,0 +1,65 @@ +import XCTest +@testable import PostgresNIO +import NIOCore + +final class PostgresCodableTests: XCTestCase { + + func testDecodeAnOptionalFromARow() { + let row = PostgresRow( + data: .makeTestDataRow(nil, ByteBuffer(string: "Hello world!")), + lookupTable: ["id": 0, "name": 1], + columns: [ + RowDescription.Column( + name: "id", + tableOID: 1, + columnAttributeNumber: 1, + dataType: .text, + dataTypeSize: 0, + dataTypeModifier: 0, + format: .binary + ), + RowDescription.Column( + name: "id", + tableOID: 1, + columnAttributeNumber: 1, + dataType: .text, + dataTypeSize: 0, + dataTypeModifier: 0, + format: .binary + ) + ] + ) + + var result: (String?, String?) + XCTAssertNoThrow(result = try row.decode((String?, String?).self, context: .default)) + XCTAssertNil(result.0) + XCTAssertEqual(result.1, "Hello world!") + } + + func testDecodeMissingValueError() { + let row = PostgresRow( + data: .makeTestDataRow(nil), + lookupTable: ["name": 0], + columns: [ + RowDescription.Column( + name: "id", + tableOID: 1, + columnAttributeNumber: 1, + dataType: .text, + dataTypeSize: 0, + dataTypeModifier: 0, + format: .binary + ) + ] + ) + + XCTAssertThrowsError(try row.decode(String.self, context: .default)) { + XCTAssertEqual(($0 as? PostgresDecodingError)?.line, #line - 1) + XCTAssertEqual(($0 as? PostgresDecodingError)?.file, #fileID) + + XCTAssertEqual(($0 as? PostgresDecodingError)?.code, .missingData) + XCTAssert(($0 as? PostgresDecodingError)?.targetType == String.self) + XCTAssertEqual(($0 as? PostgresDecodingError)?.postgresType, .text) + } + } +} diff --git a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift new file mode 100644 index 00000000..d0f8e2b0 --- /dev/null +++ b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift @@ -0,0 +1,771 @@ +import NIOCore +import NIOPosix +import NIOEmbedded +import XCTest +import Logging +@testable import PostgresNIO + +class PostgresConnectionTests: XCTestCase { + + let logger = Logger(label: "PostgresConnectionTests") + + func testConnectionFailure() { + // We start a local server and close it immediately to ensure that the port + // number we try to connect to is not used by any other process. + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + + var tempChannel: Channel? + XCTAssertNoThrow(tempChannel = try ServerBootstrap(group: eventLoopGroup) + .bind(to: .init(ipAddress: "127.0.0.1", port: 0)).wait()) + let maybePort = tempChannel?.localAddress?.port + XCTAssertNoThrow(try tempChannel?.close().wait()) + guard let port = maybePort else { + return XCTFail("Could not get port number from temp started server") + } + + let config = PostgresConnection.Configuration( + host: "127.0.0.1", port: port, + username: "postgres", password: "abc123", database: "postgres", + tls: .disable + ) + + var logger = Logger.psqlTest + logger.logLevel = .trace + + XCTAssertThrowsError(try PostgresConnection.connect(on: eventLoopGroup.next(), configuration: config, id: 1, logger: logger).wait()) { + XCTAssertTrue($0 is PSQLError) + } + } + + func testOptionsAreSentOnTheWire() async throws { + let eventLoop = NIOAsyncTestingEventLoop() + let channel = try await NIOAsyncTestingChannel(loop: eventLoop) { channel in + try channel.pipeline.syncOperations.addHandlers(ReverseByteToMessageHandler(PSQLFrontendMessageDecoder())) + try channel.pipeline.syncOperations.addHandlers(ReverseMessageToByteHandler(PSQLBackendMessageEncoder())) + } + try await channel.connect(to: .makeAddressResolvingHost("localhost", port: 5432)) + + let configuration = { + var config = PostgresConnection.Configuration( + establishedChannel: channel, + username: "username", + password: "postgres", + database: "database" + ) + config.options.additionalStartupParameters = [ + ("DateStyle", "ISO, MDY"), + ("application_name", "postgres-nio-test"), + ("server_encoding", "UTF8"), + ("integer_datetimes", "on"), + ("client_encoding", "UTF8"), + ("TimeZone", "Etc/UTC"), + ("is_superuser", "on"), + ("server_version", "13.1 (Debian 13.1-1.pgdg100+1)"), + ("session_authorization", "postgres"), + ("IntervalStyle", "postgres"), + ("standard_conforming_strings", "on") + ] + return config + }() + + async let connectionPromise = PostgresConnection.connect(on: eventLoop, configuration: configuration, id: 1, logger: .psqlTest) + let message = try await channel.waitForOutboundWrite(as: PostgresFrontendMessage.self) + XCTAssertEqual(message, .startup(.versionThree(parameters: .init(user: "username", database: "database", options: configuration.options.additionalStartupParameters, replication: .false)))) + try await channel.writeInbound(PostgresBackendMessage.authentication(.ok)) + try await channel.writeInbound(PostgresBackendMessage.backendKeyData(.init(processID: 1234, secretKey: 5678))) + try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) + + let connection = try await connectionPromise + try await connection.close() + } + + func testSimpleListen() async throws { + let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() + + try await withThrowingTaskGroup(of: Void.self) { taskGroup in + taskGroup.addTask { + let events = try await connection.listen("foo") + for try await event in events { + XCTAssertEqual(event.payload, "wooohooo") + break + } + } + + let listenMessage = try await channel.waitForUnpreparedRequest() + XCTAssertEqual(listenMessage.parse.query, #"LISTEN "foo";"#) + + try await channel.writeInbound(PostgresBackendMessage.parseComplete) + try await channel.writeInbound(PostgresBackendMessage.parameterDescription(.init(dataTypes: []))) + try await channel.writeInbound(PostgresBackendMessage.noData) + try await channel.writeInbound(PostgresBackendMessage.bindComplete) + try await channel.writeInbound(PostgresBackendMessage.commandComplete("LISTEN")) + try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) + + try await channel.writeInbound(PostgresBackendMessage.notification(.init(backendPID: 12, channel: "foo", payload: "wooohooo"))) + + let unlistenMessage = try await channel.waitForUnpreparedRequest() + XCTAssertEqual(unlistenMessage.parse.query, #"UNLISTEN "foo";"#) + + try await channel.writeInbound(PostgresBackendMessage.parseComplete) + try await channel.writeInbound(PostgresBackendMessage.parameterDescription(.init(dataTypes: []))) + try await channel.writeInbound(PostgresBackendMessage.noData) + try await channel.writeInbound(PostgresBackendMessage.bindComplete) + try await channel.writeInbound(PostgresBackendMessage.commandComplete("UNLISTEN")) + try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) + + switch await taskGroup.nextResult()! { + case .success: + break + case .failure(let failure): + XCTFail("Unexpected error: \(failure)") + } + } + } + + func testSimpleListenDoesNotUnlistenIfThereIsAnotherSubscriber() async throws { + let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() + + try await withThrowingTaskGroup(of: Void.self) { taskGroup in + taskGroup.addTask { + let events = try await connection.listen("foo") + for try await event in events { + XCTAssertEqual(event.payload, "wooohooo") + break + } + } + + taskGroup.addTask { + let events = try await connection.listen("foo") + var counter = 0 + loop: for try await event in events { + defer { counter += 1 } + switch counter { + case 0: + XCTAssertEqual(event.payload, "wooohooo") + case 1: + XCTAssertEqual(event.payload, "wooohooo2") + break loop + default: + XCTFail("Unexpected message: \(event)") + } + } + } + + let listenMessage = try await channel.waitForUnpreparedRequest() + XCTAssertEqual(listenMessage.parse.query, #"LISTEN "foo";"#) + + try await channel.writeInbound(PostgresBackendMessage.parseComplete) + try await channel.writeInbound(PostgresBackendMessage.parameterDescription(.init(dataTypes: []))) + try await channel.writeInbound(PostgresBackendMessage.noData) + try await channel.writeInbound(PostgresBackendMessage.bindComplete) + try await channel.writeInbound(PostgresBackendMessage.commandComplete("LISTEN")) + try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) + + try await channel.writeInbound(PostgresBackendMessage.notification(.init(backendPID: 12, channel: "foo", payload: "wooohooo"))) + try await channel.writeInbound(PostgresBackendMessage.notification(.init(backendPID: 12, channel: "foo", payload: "wooohooo2"))) + + let unlistenMessage = try await channel.waitForUnpreparedRequest() + XCTAssertEqual(unlistenMessage.parse.query, #"UNLISTEN "foo";"#) + + try await channel.writeInbound(PostgresBackendMessage.parseComplete) + try await channel.writeInbound(PostgresBackendMessage.parameterDescription(.init(dataTypes: []))) + try await channel.writeInbound(PostgresBackendMessage.noData) + try await channel.writeInbound(PostgresBackendMessage.bindComplete) + try await channel.writeInbound(PostgresBackendMessage.commandComplete("UNLISTEN")) + try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) + + switch await taskGroup.nextResult()! { + case .success: + break + case .failure(let failure): + XCTFail("Unexpected error: \(failure)") + } + } + } + + func testSimpleListenConnectionDrops() async throws { + let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() + + try await withThrowingTaskGroup(of: Void.self) { [logger] taskGroup in + taskGroup.addTask { + let events = try await connection.listen("foo") + var iterator = events.makeAsyncIterator() + let first = try await iterator.next() + XCTAssertEqual(first?.payload, "wooohooo") + do { + _ = try await iterator.next() + XCTFail("Did not expect to not throw") + } catch { + logger.error("error", metadata: ["error": "\(error)"]) + } + } + + let listenMessage = try await channel.waitForUnpreparedRequest() + XCTAssertEqual(listenMessage.parse.query, #"LISTEN "foo";"#) + + try await channel.writeInbound(PostgresBackendMessage.parseComplete) + try await channel.writeInbound(PostgresBackendMessage.parameterDescription(.init(dataTypes: []))) + try await channel.writeInbound(PostgresBackendMessage.noData) + try await channel.writeInbound(PostgresBackendMessage.bindComplete) + try await channel.writeInbound(PostgresBackendMessage.commandComplete("LISTEN")) + try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) + + try await channel.writeInbound(PostgresBackendMessage.notification(.init(backendPID: 12, channel: "foo", payload: "wooohooo"))) + struct MyWeirdError: Error {} + channel.pipeline.fireErrorCaught(MyWeirdError()) + + switch await taskGroup.nextResult()! { + case .success: + break + case .failure(let failure): + XCTFail("Unexpected error: \(failure)") + } + } + } + + func testCloseGracefullyClosesWhenInternalQueueIsEmpty() async throws { + let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() + try await withThrowingTaskGroup(of: Void.self) { [logger] taskGroup async throws -> () in + for _ in 1...2 { + taskGroup.addTask { + let rows = try await connection.query("SELECT 1;", logger: logger) + var iterator = rows.decode(Int.self).makeAsyncIterator() + let first = try await iterator.next() + XCTAssertEqual(first, 1) + let second = try await iterator.next() + XCTAssertNil(second) + } + } + + for i in 0...1 { + let listenMessage = try await channel.waitForUnpreparedRequest() + XCTAssertEqual(listenMessage.parse.query, "SELECT 1;") + + if i == 0 { + taskGroup.addTask { + try await connection.closeGracefully() + } + } + + try await channel.writeInbound(PostgresBackendMessage.parseComplete) + try await channel.writeInbound(PostgresBackendMessage.parameterDescription(.init(dataTypes: []))) + let intDescription = RowDescription.Column( + name: "", + tableOID: 0, + columnAttributeNumber: 0, + dataType: .int8, dataTypeSize: 8, dataTypeModifier: 0, format: .binary + ) + try await channel.writeInbound(PostgresBackendMessage.rowDescription(.init(columns: [intDescription]))) + try await channel.testingEventLoop.executeInContext { channel.read() } + try await channel.writeInbound(PostgresBackendMessage.bindComplete) + try await channel.testingEventLoop.executeInContext { channel.read() } + try await channel.writeInbound(PostgresBackendMessage.dataRow([Int(1)])) + try await channel.testingEventLoop.executeInContext { channel.read() } + try await channel.writeInbound(PostgresBackendMessage.commandComplete("SELECT 1 1")) + try await channel.testingEventLoop.executeInContext { channel.read() } + try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) + } + + let terminate = try await channel.waitForOutboundWrite(as: PostgresFrontendMessage.self) + XCTAssertEqual(terminate, .terminate) + try await channel.closeFuture.get() + XCTAssertEqual(channel.isActive, false) + + while let taskResult = await taskGroup.nextResult() { + switch taskResult { + case .success: + break + case .failure(let failure): + XCTFail("Unexpected error: \(failure)") + } + } + } + } + + func testCloseClosesImmediatly() async throws { + let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() + + try await withThrowingTaskGroup(of: Void.self) { [logger] taskGroup async throws -> () in + for _ in 1...2 { + taskGroup.addTask { + try await connection.query("SELECT 1;", logger: logger) + } + } + + let listenMessage = try await channel.waitForUnpreparedRequest() + XCTAssertEqual(listenMessage.parse.query, "SELECT 1;") + + async let close: () = connection.close() + + try await channel.closeFuture.get() + XCTAssertEqual(channel.isActive, false) + + try await close + + while let taskResult = await taskGroup.nextResult() { + switch taskResult { + case .success: + XCTFail("Expected queries to fail") + case .failure(let failure): + guard let error = failure as? PSQLError else { + return XCTFail("Unexpected error type: \(failure)") + } + XCTAssertEqual(error.code, .clientClosedConnection) + } + } + } + } + + func testIfServerJustClosesTheErrorReflectsThat() async throws { + let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() + let logger = self.logger + + async let response = try await connection.query("SELECT 1;", logger: logger) + + let listenMessage = try await channel.waitForUnpreparedRequest() + XCTAssertEqual(listenMessage.parse.query, "SELECT 1;") + + try await channel.testingEventLoop.executeInContext { channel.pipeline.fireChannelInactive() } + try await channel.testingEventLoop.executeInContext { channel.pipeline.fireChannelUnregistered() } + + do { + _ = try await response + XCTFail("Expected to throw") + } catch { + XCTAssertEqual((error as? PSQLError)?.code, .serverClosedConnection) + } + + // retry on same connection + + do { + _ = try await connection.query("SELECT 1;", logger: self.logger) + XCTFail("Expected to throw") + } catch { + XCTAssertEqual((error as? PSQLError)?.code, .serverClosedConnection) + } + } + + struct TestPrepareStatement: PostgresPreparedStatement { + static let sql = "SELECT datname FROM pg_stat_activity WHERE state = $1" + typealias Row = String + + var state: String + + func makeBindings() -> PostgresBindings { + var bindings = PostgresBindings() + bindings.append(.init(string: self.state)) + return bindings + } + + func decodeRow(_ row: PostgresNIO.PostgresRow) throws -> Row { + try row.decode(Row.self) + } + } + + func testPreparedStatement() async throws { + let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() + + try await withThrowingTaskGroup(of: Void.self) { taskGroup async throws -> () in + taskGroup.addTask { + let preparedStatement = TestPrepareStatement(state: "active") + let result = try await connection.execute(preparedStatement, logger: .psqlTest) + var rows = 0 + for try await database in result { + rows += 1 + XCTAssertEqual("test_database", database) + } + XCTAssertEqual(rows, 1) + } + + let prepareRequest = try await channel.waitForPrepareRequest() + XCTAssertEqual(prepareRequest.parse.query, "SELECT datname FROM pg_stat_activity WHERE state = $1") + XCTAssertEqual(prepareRequest.parse.parameters.first, .text) + guard case .preparedStatement(let name) = prepareRequest.describe else { + fatalError("Describe should contain a prepared statement") + } + XCTAssertEqual(name, String(reflecting: TestPrepareStatement.self)) + + try await channel.sendPrepareResponse( + parameterDescription: .init(dataTypes: [ + PostgresDataType.text + ]), + rowDescription: .init(columns: [ + .init( + name: "datname", + tableOID: 12222, + columnAttributeNumber: 2, + dataType: .name, + dataTypeSize: 64, + dataTypeModifier: -1, + format: .text + ) + ]) + ) + + let preparedRequest = try await channel.waitForPreparedRequest() + XCTAssertEqual(preparedRequest.bind.preparedStatementName, String(reflecting: TestPrepareStatement.self)) + XCTAssertEqual(preparedRequest.bind.parameters.count, 1) + XCTAssertEqual(preparedRequest.bind.resultColumnFormats, [.binary]) + + try await channel.sendPreparedResponse( + dataRows: [ + ["test_database"] + ], + commandTag: TestPrepareStatement.sql + ) + } + } + + func testWeDontCrashOnUnexpectedChannelEvents() async throws { + let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() + + enum MyEvent { + case pleaseDontCrash + } + channel.pipeline.fireUserInboundEventTriggered(MyEvent.pleaseDontCrash) + try await connection.close() + } + + func testSerialExecutionOfSamePreparedStatement() async throws { + let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() + + try await withThrowingTaskGroup(of: Void.self) { taskGroup async throws -> () in + // Send the same prepared statement twice, but with different parameters. + // Send one first and wait to send the other request until preparation is complete + taskGroup.addTask { + let preparedStatement = TestPrepareStatement(state: "active") + let result = try await connection.execute(preparedStatement, logger: .psqlTest) + var rows = 0 + for try await database in result { + rows += 1 + XCTAssertEqual("test_database", database) + } + XCTAssertEqual(rows, 1) + } + + let prepareRequest = try await channel.waitForPrepareRequest() + XCTAssertEqual(prepareRequest.parse.query, "SELECT datname FROM pg_stat_activity WHERE state = $1") + XCTAssertEqual(prepareRequest.parse.parameters.first, .text) + guard case .preparedStatement(let name) = prepareRequest.describe else { + fatalError("Describe should contain a prepared statement") + } + XCTAssertEqual(name, String(reflecting: TestPrepareStatement.self)) + + try await channel.sendPrepareResponse( + parameterDescription: .init(dataTypes: [ + PostgresDataType.text + ]), + rowDescription: .init(columns: [ + .init( + name: "datname", + tableOID: 12222, + columnAttributeNumber: 2, + dataType: .name, + dataTypeSize: 64, + dataTypeModifier: -1, + format: .text + ) + ]) + ) + + let preparedRequest1 = try await channel.waitForPreparedRequest() + var buffer = preparedRequest1.bind.parameters[0]! + let parameter1 = buffer.readString(length: buffer.readableBytes)! + XCTAssertEqual(parameter1, "active") + try await channel.sendPreparedResponse( + dataRows: [ + ["test_database"] + ], + commandTag: TestPrepareStatement.sql + ) + + // Now that the statement has been prepared and executed, send another request that will only get executed + // without preparation + taskGroup.addTask { + let preparedStatement = TestPrepareStatement(state: "idle") + let result = try await connection.execute(preparedStatement, logger: .psqlTest) + var rows = 0 + for try await database in result { + rows += 1 + XCTAssertEqual("test_database", database) + } + XCTAssertEqual(rows, 1) + } + + let preparedRequest2 = try await channel.waitForPreparedRequest() + buffer = preparedRequest2.bind.parameters[0]! + let parameter2 = buffer.readString(length: buffer.readableBytes)! + XCTAssertEqual(parameter2, "idle") + try await channel.sendPreparedResponse( + dataRows: [ + ["test_database"] + ], + commandTag: TestPrepareStatement.sql + ) + // Ensure we received and responded to both the requests + let parameters = [parameter1, parameter2] + XCTAssert(parameters.contains("active")) + XCTAssert(parameters.contains("idle")) + } + } + + func testStatementPreparationOnlyHappensOnceWithConcurrentRequests() async throws { + let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() + + try await withThrowingTaskGroup(of: Void.self) { taskGroup async throws -> () in + // Send the same prepared statement twice, but with different parameters. + // Let them race to tests that requests and responses aren't mixed up + taskGroup.addTask { + let preparedStatement = TestPrepareStatement(state: "active") + let result = try await connection.execute(preparedStatement, logger: .psqlTest) + var rows = 0 + for try await database in result { + rows += 1 + XCTAssertEqual("test_database_active", database) + } + XCTAssertEqual(rows, 1) + } + taskGroup.addTask { + let preparedStatement = TestPrepareStatement(state: "idle") + let result = try await connection.execute(preparedStatement, logger: .psqlTest) + var rows = 0 + for try await database in result { + rows += 1 + XCTAssertEqual("test_database_idle", database) + } + XCTAssertEqual(rows, 1) + } + + // The channel deduplicates prepare requests, we're going to see only one of them + let prepareRequest = try await channel.waitForPrepareRequest() + XCTAssertEqual(prepareRequest.parse.query, "SELECT datname FROM pg_stat_activity WHERE state = $1") + XCTAssertEqual(prepareRequest.parse.parameters.first, .text) + guard case .preparedStatement(let name) = prepareRequest.describe else { + fatalError("Describe should contain a prepared statement") + } + XCTAssertEqual(name, String(reflecting: TestPrepareStatement.self)) + + try await channel.sendPrepareResponse( + parameterDescription: .init(dataTypes: [ + PostgresDataType.text + ]), + rowDescription: .init(columns: [ + .init( + name: "datname", + tableOID: 12222, + columnAttributeNumber: 2, + dataType: .name, + dataTypeSize: 64, + dataTypeModifier: -1, + format: .text + ) + ]) + ) + + // Now both the tasks have their statements prepared. + // We should see both of their execute requests coming in, the order is nondeterministic + let preparedRequest1 = try await channel.waitForPreparedRequest() + var buffer = preparedRequest1.bind.parameters[0]! + let parameter1 = buffer.readString(length: buffer.readableBytes)! + try await channel.sendPreparedResponse( + dataRows: [ + ["test_database_\(parameter1)"] + ], + commandTag: TestPrepareStatement.sql + ) + let preparedRequest2 = try await channel.waitForPreparedRequest() + buffer = preparedRequest2.bind.parameters[0]! + let parameter2 = buffer.readString(length: buffer.readableBytes)! + try await channel.sendPreparedResponse( + dataRows: [ + ["test_database_\(parameter2)"] + ], + commandTag: TestPrepareStatement.sql + ) + // Ensure we received and responded to both the requests + let parameters = [parameter1, parameter2] + XCTAssert(parameters.contains("active")) + XCTAssert(parameters.contains("idle")) + } + } + + func testStatementPreparationFailure() async throws { + let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() + + try await withThrowingTaskGroup(of: Void.self) { taskGroup async throws -> () in + // Send the same prepared statement twice, but with different parameters. + // Send one first and wait to send the other request until preparation is complete + taskGroup.addTask { + let preparedStatement = TestPrepareStatement(state: "active") + do { + _ = try await connection.execute(preparedStatement, logger: .psqlTest) + XCTFail("Was supposed to fail") + } catch { + XCTAssert(error is PSQLError) + } + } + + let prepareRequest = try await channel.waitForPrepareRequest() + XCTAssertEqual(prepareRequest.parse.query, "SELECT datname FROM pg_stat_activity WHERE state = $1") + XCTAssertEqual(prepareRequest.parse.parameters.first, .text) + guard case .preparedStatement(let name) = prepareRequest.describe else { + fatalError("Describe should contain a prepared statement") + } + XCTAssertEqual(name, String(reflecting: TestPrepareStatement.self)) + + // Respond with an error taking care to return a SQLSTATE that isn't + // going to get the connection closed. + try await channel.writeInbound(PostgresBackendMessage.error(.init(fields: [ + .sqlState : "26000" // invalid_sql_statement_name + ]))) + try await channel.testingEventLoop.executeInContext { channel.read() } + try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) + try await channel.testingEventLoop.executeInContext { channel.read() } + + + // Send another requests with the same prepared statement, which should fail straight + // away without any interaction with the server + taskGroup.addTask { + let preparedStatement = TestPrepareStatement(state: "idle") + do { + _ = try await connection.execute(preparedStatement, logger: .psqlTest) + XCTFail("Was supposed to fail") + } catch { + XCTAssert(error is PSQLError) + } + } + } + } + + func makeTestConnectionWithAsyncTestingChannel() async throws -> (PostgresConnection, NIOAsyncTestingChannel) { + let eventLoop = NIOAsyncTestingEventLoop() + let channel = try await NIOAsyncTestingChannel(loop: eventLoop) { channel in + try channel.pipeline.syncOperations.addHandlers(ReverseByteToMessageHandler(PSQLFrontendMessageDecoder())) + try channel.pipeline.syncOperations.addHandlers(ReverseMessageToByteHandler(PSQLBackendMessageEncoder())) + } + try await channel.connect(to: .makeAddressResolvingHost("localhost", port: 5432)) + + let configuration = PostgresConnection.Configuration( + establishedChannel: channel, + username: "username", + password: "postgres", + database: "database" + ) + + let logger = self.logger + async let connectionPromise = PostgresConnection.connect(on: eventLoop, configuration: configuration, id: 1, logger: logger) + let message = try await channel.waitForOutboundWrite(as: PostgresFrontendMessage.self) + XCTAssertEqual(message, .startup(.versionThree(parameters: .init(user: "username", database: "database", options: [], replication: .false)))) + try await channel.writeInbound(PostgresBackendMessage.authentication(.ok)) + try await channel.writeInbound(PostgresBackendMessage.backendKeyData(.init(processID: 1234, secretKey: 5678))) + try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) + + let connection = try await connectionPromise + + self.addTeardownBlock { + try await connection.close() + } + + return (connection, channel) + } +} + +extension NIOAsyncTestingChannel { + + func waitForUnpreparedRequest() async throws -> UnpreparedRequest { + let parse = try await self.waitForOutboundWrite(as: PostgresFrontendMessage.self) + let describe = try await self.waitForOutboundWrite(as: PostgresFrontendMessage.self) + let bind = try await self.waitForOutboundWrite(as: PostgresFrontendMessage.self) + let execute = try await self.waitForOutboundWrite(as: PostgresFrontendMessage.self) + let sync = try await self.waitForOutboundWrite(as: PostgresFrontendMessage.self) + + guard case .parse(let parse) = parse, + case .describe(let describe) = describe, + case .bind(let bind) = bind, + case .execute(let execute) = execute, + case .sync = sync + else { + fatalError() + } + + return UnpreparedRequest(parse: parse, describe: describe, bind: bind, execute: execute) + } + + func waitForPrepareRequest() async throws -> PrepareRequest { + let parse = try await self.waitForOutboundWrite(as: PostgresFrontendMessage.self) + let describe = try await self.waitForOutboundWrite(as: PostgresFrontendMessage.self) + let sync = try await self.waitForOutboundWrite(as: PostgresFrontendMessage.self) + + guard case .parse(let parse) = parse, + case .describe(let describe) = describe, + case .sync = sync + else { + fatalError("Unexpected message") + } + + return PrepareRequest(parse: parse, describe: describe) + } + + func sendPrepareResponse( + parameterDescription: PostgresBackendMessage.ParameterDescription, + rowDescription: RowDescription + ) async throws { + try await self.writeInbound(PostgresBackendMessage.parseComplete) + try await self.testingEventLoop.executeInContext { self.read() } + try await self.writeInbound(PostgresBackendMessage.parameterDescription(parameterDescription)) + try await self.testingEventLoop.executeInContext { self.read() } + try await self.writeInbound(PostgresBackendMessage.rowDescription(rowDescription)) + try await self.testingEventLoop.executeInContext { self.read() } + try await self.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) + try await self.testingEventLoop.executeInContext { self.read() } + } + + func waitForPreparedRequest() async throws -> PreparedRequest { + let bind = try await self.waitForOutboundWrite(as: PostgresFrontendMessage.self) + let execute = try await self.waitForOutboundWrite(as: PostgresFrontendMessage.self) + let sync = try await self.waitForOutboundWrite(as: PostgresFrontendMessage.self) + + guard case .bind(let bind) = bind, + case .execute(let execute) = execute, + case .sync = sync + else { + fatalError() + } + + return PreparedRequest(bind: bind, execute: execute) + } + + func sendPreparedResponse( + dataRows: [DataRow], + commandTag: String + ) async throws { + try await self.writeInbound(PostgresBackendMessage.bindComplete) + try await self.testingEventLoop.executeInContext { self.read() } + for dataRow in dataRows { + try await self.writeInbound(PostgresBackendMessage.dataRow(dataRow)) + } + try await self.testingEventLoop.executeInContext { self.read() } + try await self.writeInbound(PostgresBackendMessage.commandComplete(commandTag)) + try await self.testingEventLoop.executeInContext { self.read() } + try await self.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) + try await self.testingEventLoop.executeInContext { self.read() } + } +} + +struct UnpreparedRequest { + var parse: PostgresFrontendMessage.Parse + var describe: PostgresFrontendMessage.Describe + var bind: PostgresFrontendMessage.Bind + var execute: PostgresFrontendMessage.Execute +} + +struct PrepareRequest { + var parse: PostgresFrontendMessage.Parse + var describe: PostgresFrontendMessage.Describe +} + +struct PreparedRequest { + var bind: PostgresFrontendMessage.Bind + var execute: PostgresFrontendMessage.Execute +} diff --git a/Tests/PostgresNIOTests/New/PostgresErrorTests.swift b/Tests/PostgresNIOTests/New/PostgresErrorTests.swift new file mode 100644 index 00000000..33df5439 --- /dev/null +++ b/Tests/PostgresNIOTests/New/PostgresErrorTests.swift @@ -0,0 +1,137 @@ +@testable import PostgresNIO +import XCTest +import NIOCore + +final class PSQLErrorTests: XCTestCase { + func testPostgresBindingsDescription() { + let testBinds1 = PostgresBindings(capacity: 0) + var testBinds2 = PostgresBindings(capacity: 1) + var testBinds3 = PostgresBindings(capacity: 2) + testBinds2.append(1, context: .default) + testBinds3.appendUnprotected(1, context: .default) + testBinds3.appendUnprotected("foo", context: .default) + testBinds3.append("secret", context: .default) + + XCTAssertEqual(String(describing: testBinds1), "[]") + XCTAssertEqual(String(reflecting: testBinds1), "[]") + XCTAssertEqual(String(describing: testBinds2), "[****]") + XCTAssertEqual(String(reflecting: testBinds2), "[(****; BIGINT; format: binary)]") + XCTAssertEqual(String(describing: testBinds3), #"[1, "foo", ****]"#) + XCTAssertEqual(String(reflecting: testBinds3), #"[(1; BIGINT; format: binary), ("foo"; TEXT; format: binary), (****; TEXT; format: binary)]"#) + } + + func testPostgresQueryDescription() { + let testBinds1 = PostgresBindings(capacity: 0) + var testBinds2 = PostgresBindings(capacity: 1) + testBinds2.append(1, context: .default) + let testQuery1 = PostgresQuery(unsafeSQL: "TEST QUERY") + let testQuery2 = PostgresQuery(unsafeSQL: "TEST QUERY", binds: testBinds1) + let testQuery3 = PostgresQuery(unsafeSQL: "TEST QUERY", binds: testBinds2) + + XCTAssertEqual(String(describing: testQuery1), "TEST QUERY []") + XCTAssertEqual(String(reflecting: testQuery1), "PostgresQuery(sql: TEST QUERY, binds: [])") + XCTAssertEqual(String(describing: testQuery2), "TEST QUERY []") + XCTAssertEqual(String(reflecting: testQuery2), "PostgresQuery(sql: TEST QUERY, binds: [])") + XCTAssertEqual(String(describing: testQuery3), "TEST QUERY [****]") + XCTAssertEqual(String(reflecting: testQuery3), "PostgresQuery(sql: TEST QUERY, binds: [(****; BIGINT; format: binary)])") + } + + func testPSQLErrorDescription() { + var error1 = PSQLError.server(.init(fields: [.localizedSeverity: "ERROR", .severity: "ERROR", .sqlState: "00000", .message: "Test message", .detail: "More test message", .hint: "It's a test, that's your hint", .position: "1", .schemaName: "testsch", .tableName: "testtab", .columnName: "testcol", .dataTypeName: "testtyp", .constraintName: "testcon", .file: #fileID, .line: "0", .routine: #function])) + var testBinds = PostgresBindings(capacity: 1) + testBinds.append(1, context: .default) + error1.query = .init(unsafeSQL: "TEST QUERY", binds: testBinds) + + XCTAssertEqual(String(describing: error1), """ + PSQLError – Generic description to prevent accidental leakage of sensitive data. For debugging details, use `String(reflecting: error)`. + """) + XCTAssertEqual(String(reflecting: error1), """ + PSQLError(code: server, serverInfo: [sqlState: 00000, detail: More test message, file: PostgresNIOTests/PostgresErrorTests.swift, hint: It's a test, that's your hint, line: 0, message: Test message, position: 1, routine: testPSQLErrorDescription(), localizedSeverity: ERROR, severity: ERROR, columnName: testcol, dataTypeName: testtyp, constraintName: testcon, schemaName: testsch, tableName: testtab], query: PostgresQuery(sql: TEST QUERY, binds: [(****; BIGINT; format: binary)])) + """) + } +} + +final class PostgresDecodingErrorTests: XCTestCase { + func testPostgresDecodingErrorEquality() { + let error1 = PostgresDecodingError( + code: .typeMismatch, + columnName: "column", + columnIndex: 0, + targetType: String.self, + postgresType: .text, + postgresFormat: .binary, + postgresData: ByteBuffer(string: "hello world"), + file: "foo.swift", + line: 123 + ) + + let error2 = PostgresDecodingError( + code: .typeMismatch, + columnName: "column", + columnIndex: 0, + targetType: Int.self, + postgresType: .text, + postgresFormat: .binary, + postgresData: ByteBuffer(string: "hello world"), + file: "foo.swift", + line: 123 + ) + + XCTAssertNotEqual(error1, error2) + let error3 = error1 + XCTAssertEqual(error1, error3) + } + + func testPostgresDecodingErrorDescription() { + let error1 = PostgresDecodingError( + code: .typeMismatch, + columnName: "column", + columnIndex: 0, + targetType: String.self, + postgresType: .text, + postgresFormat: .binary, + postgresData: ByteBuffer(string: "hello world"), + file: "foo.swift", + line: 123 + ) + + let error2 = PostgresDecodingError( + code: .missingData, + columnName: "column", + columnIndex: 0, + targetType: [[String: String]].self, + postgresType: .jsonbArray, + postgresFormat: .binary, + postgresData: nil, + file: "bar.swift", + line: 123 + ) + + // Plain description + XCTAssertEqual(String(describing: error1), """ + PostgresDecodingError – Generic description to prevent accidental leakage of sensitive data. For debugging details, use `String(reflecting: error)`. + """) + XCTAssertEqual(String(describing: error2), """ + PostgresDecodingError – Generic description to prevent accidental leakage of sensitive data. For debugging details, use `String(reflecting: error)`. + """) + + // Extended debugDescription + XCTAssertEqual(String(reflecting: error1), """ + PostgresDecodingError(code: typeMismatch,\ + columnName: "column", columnIndex: 0,\ + targetType: Swift.String,\ + postgresType: TEXT, postgresFormat: binary,\ + postgresData: \(error1.postgresData?.debugDescription ?? "nil"),\ + file: foo.swift, line: 123\ + ) + """) + XCTAssertEqual(String(reflecting: error2), """ + PostgresDecodingError(code: missingData,\ + columnName: "column", columnIndex: 0,\ + targetType: Swift.Array>,\ + postgresType: JSONB[], postgresFormat: binary,\ + file: bar.swift, line: 123\ + ) + """) + } +} diff --git a/Tests/PostgresNIOTests/New/PostgresQueryTests.swift b/Tests/PostgresNIOTests/New/PostgresQueryTests.swift new file mode 100644 index 00000000..4930f0c4 --- /dev/null +++ b/Tests/PostgresNIOTests/New/PostgresQueryTests.swift @@ -0,0 +1,128 @@ +@testable import PostgresNIO +import XCTest +import NIOCore + +final class PostgresQueryTests: XCTestCase { + + func testStringInterpolationWithOptional() { + let string = "Hello World" + let null: UUID? = nil + let uuid: UUID? = UUID() + + let query: PostgresQuery = """ + INSERT INTO foo (id, title, something) SET (\(uuid), \(string), \(null)); + """ + + XCTAssertEqual(query.sql, "INSERT INTO foo (id, title, something) SET ($1, $2, $3);") + + var expected = ByteBuffer() + expected.writeInteger(Int32(16)) + expected.writeBytes([ + uuid!.uuid.0, uuid!.uuid.1, uuid!.uuid.2, uuid!.uuid.3, + uuid!.uuid.4, uuid!.uuid.5, uuid!.uuid.6, uuid!.uuid.7, + uuid!.uuid.8, uuid!.uuid.9, uuid!.uuid.10, uuid!.uuid.11, + uuid!.uuid.12, uuid!.uuid.13, uuid!.uuid.14, uuid!.uuid.15, + ]) + + expected.writeInteger(Int32(string.utf8.count)) + expected.writeString(string) + expected.writeInteger(Int32(-1)) + + XCTAssertEqual(query.binds.bytes, expected) + } + + func testStringInterpolationWithDynamicType() { + let type = PostgresDataType(16435) + let format = PostgresFormat.binary + let dynamicString = DynamicString(value: "Hello world", psqlType: type, psqlFormat: format) + + let query: PostgresQuery = """ + INSERT INTO foo (dynamicType) SET (\(dynamicString)); + """ + + XCTAssertEqual(query.sql, "INSERT INTO foo (dynamicType) SET ($1);") + + var expectedBindsBytes = ByteBuffer() + expectedBindsBytes.writeInteger(Int32(dynamicString.value.utf8.count)) + expectedBindsBytes.writeString(dynamicString.value) + + let expectedMetadata: [PostgresBindings.Metadata] = [.init(dataType: type, format: format, protected: true)] + + XCTAssertEqual(query.binds.bytes, expectedBindsBytes) + XCTAssertEqual(query.binds.metadata, expectedMetadata) + } + + func testStringInterpolationWithCustomJSONEncoder() { + struct Foo: Codable, PostgresCodable { + var helloWorld: String + } + + let jsonEncoder = JSONEncoder() + jsonEncoder.keyEncodingStrategy = .convertToSnakeCase + + var query: PostgresQuery? + XCTAssertNoThrow(query = try """ + INSERT INTO test (foo) SET (\(Foo(helloWorld: "bar"), context: .init(jsonEncoder: jsonEncoder))); + """ + ) + + XCTAssertEqual(query?.sql, "INSERT INTO test (foo) SET ($1);") + + let expectedJSON = #"{"hello_world":"bar"}"# + + var expected = ByteBuffer() + expected.writeInteger(Int32(expectedJSON.utf8.count + 1)) + expected.writeInteger(UInt8(0x01)) + expected.writeString(expectedJSON) + + XCTAssertEqual(query?.binds.bytes, expected) + } + + func testAllowUsersToGenerateLotsOfRows() { + let sql = "INSERT INTO test (id) SET (\((1...5).map({"$\($0)"}).joined(separator: ", ")));" + + var query = PostgresQuery(unsafeSQL: sql, binds: .init(capacity: 5)) + for value in 1...5 { + query.binds.append(Int(value), context: .default) + } + + XCTAssertEqual(query.sql, "INSERT INTO test (id) SET ($1, $2, $3, $4, $5);") + + var expected = ByteBuffer() + for value in 1...5 { + expected.writeInteger(UInt32(8)) + expected.writeInteger(value) + } + + XCTAssertEqual(query.binds.bytes, expected) + } + + func testUnescapedSQL() { + let tableName = UUID().uuidString.uppercased() + let value = 1 + + let query: PostgresQuery = "INSERT INTO \(unescaped: tableName) (id) SET (\(value));" + + var expected = ByteBuffer() + expected.writeInteger(UInt32(8)) + expected.writeInteger(value) + + XCTAssertEqual(query.binds.bytes, expected) + } +} + +extension PostgresQueryTests { + struct DynamicString: PostgresDynamicTypeEncodable { + let value: String + + var psqlType: PostgresDataType + var psqlFormat: PostgresFormat + + func encode( + into byteBuffer: inout ByteBuffer, + context: PostgresNIO.PostgresEncodingContext + ) where JSONEncoder: PostgresJSONEncoder { + byteBuffer.writeString(value) + } + } +} diff --git a/Tests/PostgresNIOTests/New/PostgresRowSequenceTests.swift b/Tests/PostgresNIOTests/New/PostgresRowSequenceTests.swift new file mode 100644 index 00000000..9d662252 --- /dev/null +++ b/Tests/PostgresNIOTests/New/PostgresRowSequenceTests.swift @@ -0,0 +1,457 @@ +import Atomics +import NIOEmbedded +import NIOPosix +import XCTest +@testable import PostgresNIO +import NIOCore +import Logging + +final class PostgresRowSequenceTests: XCTestCase { + let logger = Logger(label: "PSQLRowStreamTests") + + func testBackpressureWorks() async throws { + let dataSource = MockRowDataSource() + let embeddedEventLoop = EmbeddedEventLoop() + let stream = PSQLRowStream( + source: .stream( + [ + .init(name: "test", tableOID: 0, columnAttributeNumber: 0, dataType: .int8, dataTypeSize: 8, dataTypeModifier: 0, format: .binary) + ], + dataSource + ), + eventLoop: embeddedEventLoop, + logger: self.logger + ) + + let rowSequence = stream.asyncSequence() + XCTAssertEqual(dataSource.requestCount, 0) + let dataRow: DataRow = [ByteBuffer(integer: Int64(1))] + stream.receive([dataRow]) + + var iterator = rowSequence.makeAsyncIterator() + let row = try await iterator.next() + XCTAssertEqual(dataSource.requestCount, 1) + XCTAssertEqual(row?.data, dataRow) + + stream.receive(completion: .success("SELECT 1")) + let empty = try await iterator.next() + XCTAssertNil(empty) + } + + + func testCancellationWorksWhileIterating() async throws { + let dataSource = MockRowDataSource() + let embeddedEventLoop = EmbeddedEventLoop() + let stream = PSQLRowStream( + source: .stream( + [ + .init(name: "test", tableOID: 0, columnAttributeNumber: 0, dataType: .int8, dataTypeSize: 8, dataTypeModifier: 0, format: .binary) + ], + dataSource + ), + eventLoop: embeddedEventLoop, + logger: self.logger + ) + + let rowSequence = stream.asyncSequence() + XCTAssertEqual(dataSource.requestCount, 0) + let dataRows: [DataRow] = (0..<128).map { [ByteBuffer(integer: Int64($0))] } + stream.receive(dataRows) + + var counter = 0 + for try await row in rowSequence { + XCTAssertEqual(try row.decode(Int.self), counter) + counter += 1 + + if counter == 64 { + break + } + } + + XCTAssertEqual(dataSource.cancelCount, 1) + } + + func testCancellationWorksBeforeIterating() async throws { + let dataSource = MockRowDataSource() + let embeddedEventLoop = EmbeddedEventLoop() + let stream = PSQLRowStream( + source: .stream( + [ + .init(name: "test", tableOID: 0, columnAttributeNumber: 0, dataType: .int8, dataTypeSize: 8, dataTypeModifier: 0, format: .binary) + ], + dataSource + ), + eventLoop: embeddedEventLoop, + logger: self.logger + ) + + let rowSequence = stream.asyncSequence() + XCTAssertEqual(dataSource.requestCount, 0) + let dataRows: [DataRow] = (0..<128).map { [ByteBuffer(integer: Int64($0))] } + stream.receive(dataRows) + + var iterator: PostgresRowSequence.AsyncIterator? = rowSequence.makeAsyncIterator() + iterator = nil + + XCTAssertEqual(dataSource.cancelCount, 1) + XCTAssertNil(iterator, "Surpress warning") + } + + func testDroppingTheSequenceCancelsTheSource() async throws { + let dataSource = MockRowDataSource() + let embeddedEventLoop = EmbeddedEventLoop() + let stream = PSQLRowStream( + source: .stream( + [ + .init(name: "test", tableOID: 0, columnAttributeNumber: 0, dataType: .int8, dataTypeSize: 8, dataTypeModifier: 0, format: .binary) + ], + dataSource + ), + eventLoop: embeddedEventLoop, + logger: self.logger + ) + + var rowSequence: PostgresRowSequence? = stream.asyncSequence() + rowSequence = nil + + XCTAssertEqual(dataSource.cancelCount, 1) + XCTAssertNil(rowSequence, "Surpress warning") + } + + func testStreamBasedOnCompletedQuery() async throws { + let dataSource = MockRowDataSource() + let embeddedEventLoop = EmbeddedEventLoop() + let stream = PSQLRowStream( + source: .stream( + [ + .init(name: "test", tableOID: 0, columnAttributeNumber: 0, dataType: .int8, dataTypeSize: 8, dataTypeModifier: 0, format: .binary) + ], + dataSource + ), + eventLoop: embeddedEventLoop, + logger: self.logger + ) + + let rowSequence = stream.asyncSequence() + let dataRows: [DataRow] = (0..<128).map { [ByteBuffer(integer: Int64($0))] } + stream.receive(dataRows) + stream.receive(completion: .success("SELECT 128")) + + var counter = 0 + for try await row in rowSequence { + XCTAssertEqual(try row.decode(Int.self), counter) + counter += 1 + } + + XCTAssertEqual(dataSource.cancelCount, 0) + } + + func testStreamIfInitializedWithAllData() async throws { + let dataSource = MockRowDataSource() + let embeddedEventLoop = EmbeddedEventLoop() + let stream = PSQLRowStream( + source: .stream( + [ + .init(name: "test", tableOID: 0, columnAttributeNumber: 0, dataType: .int8, dataTypeSize: 8, dataTypeModifier: 0, format: .binary) + ], + dataSource + ), + eventLoop: embeddedEventLoop, + logger: self.logger + ) + + let dataRows: [DataRow] = (0..<128).map { [ByteBuffer(integer: Int64($0))] } + stream.receive(dataRows) + stream.receive(completion: .success("SELECT 128")) + + let rowSequence = stream.asyncSequence() + + var counter = 0 + for try await row in rowSequence { + XCTAssertEqual(try row.decode(Int.self), counter) + counter += 1 + } + + XCTAssertEqual(dataSource.cancelCount, 0) + } + + func testStreamIfInitializedWithError() async throws { + let dataSource = MockRowDataSource() + let embeddedEventLoop = EmbeddedEventLoop() + let stream = PSQLRowStream( + source: .stream( + [ + .init(name: "test", tableOID: 0, columnAttributeNumber: 0, dataType: .int8, dataTypeSize: 8, dataTypeModifier: 0, format: .binary) + ], + dataSource + ), + eventLoop: embeddedEventLoop, + logger: self.logger + ) + + stream.receive(completion: .failure(PSQLError.serverClosedConnection(underlying: nil))) + + let rowSequence = stream.asyncSequence() + + do { + var counter = 0 + for try await _ in rowSequence { + counter += 1 + } + XCTFail("Expected that an error was thrown before.") + } catch { + XCTAssertEqual(error as? PSQLError, .serverClosedConnection(underlying: nil)) + } + } + + func testSucceedingRowContinuationsWorks() async throws { + let dataSource = MockRowDataSource() + let eventLoop = NIOSingletons.posixEventLoopGroup.next() + let stream = PSQLRowStream( + source: .stream( + [ + .init(name: "test", tableOID: 0, columnAttributeNumber: 0, dataType: .int8, dataTypeSize: 8, dataTypeModifier: 0, format: .binary) + ], + dataSource + ), + eventLoop: eventLoop, + logger: self.logger + ) + + let rowSequence = try await eventLoop.submit { stream.asyncSequence() }.get() + var rowIterator = rowSequence.makeAsyncIterator() + + eventLoop.scheduleTask(in: .seconds(1)) { + let dataRows: [DataRow] = (0..<1).map { [ByteBuffer(integer: Int64($0))] } + stream.receive(dataRows) + } + + let row1 = try await rowIterator.next() + XCTAssertEqual(try row1?.decode(Int.self), 0) + + eventLoop.scheduleTask(in: .seconds(1)) { + stream.receive(completion: .success("SELECT 1")) + } + + let row2 = try await rowIterator.next() + XCTAssertNil(row2) + } + + func testFailingRowContinuationsWorks() async throws { + let dataSource = MockRowDataSource() + let eventLoop = NIOSingletons.posixEventLoopGroup.next() + let stream = PSQLRowStream( + source: .stream( + [ + .init(name: "test", tableOID: 0, columnAttributeNumber: 0, dataType: .int8, dataTypeSize: 8, dataTypeModifier: 0, format: .binary) + ], + dataSource + ), + eventLoop: eventLoop, + logger: self.logger + ) + + let rowSequence = try await eventLoop.submit { stream.asyncSequence() }.get() + var rowIterator = rowSequence.makeAsyncIterator() + + eventLoop.scheduleTask(in: .seconds(1)) { + let dataRows: [DataRow] = (0..<1).map { [ByteBuffer(integer: Int64($0))] } + stream.receive(dataRows) + } + + let row1 = try await rowIterator.next() + XCTAssertEqual(try row1?.decode(Int.self), 0) + + eventLoop.scheduleTask(in: .seconds(1)) { + stream.receive(completion: .failure(PSQLError.serverClosedConnection(underlying: nil))) + } + + do { + _ = try await rowIterator.next() + XCTFail("Expected that an error was thrown before.") + } catch { + XCTAssertEqual(error as? PSQLError, .serverClosedConnection(underlying: nil)) + } + } + + func testAdaptiveRowBufferShrinksAndGrows() async throws { + let dataSource = MockRowDataSource() + let embeddedEventLoop = EmbeddedEventLoop() + let stream = PSQLRowStream( + source: .stream( + [ + .init(name: "test", tableOID: 0, columnAttributeNumber: 0, dataType: .int8, dataTypeSize: 8, dataTypeModifier: 0, format: .binary) + ], + dataSource + ), + eventLoop: embeddedEventLoop, + logger: self.logger + ) + + let initialDataRows: [DataRow] = (0.. don't ask for more + XCTAssertEqual(dataSource.requestCount, 0) + _ = try await rowIterator.next() // new buffer will be (target - 1) -> ask for more + XCTAssertEqual(dataSource.requestCount, 1) + + // if the buffer gets new rows so that it has equal or more than target (the target size + // should be halved), however shrinking is only allowed AFTER the first extra rows were + // received. + let addDataRows1: [DataRow] = [[ByteBuffer(integer: Int64(0))]] + stream.receive(addDataRows1) + XCTAssertEqual(dataSource.requestCount, 1) + _ = try await rowIterator.next() // new buffer will be (target - 1) -> ask for more + XCTAssertEqual(dataSource.requestCount, 2) + + // if the buffer gets new rows so that it has equal or more than target (the target size + // should be halved) + let addDataRows2: [DataRow] = [[ByteBuffer(integer: Int64(0))], [ByteBuffer(integer: Int64(0))]] + stream.receive(addDataRows2) // this should to target being halved. + _ = try await rowIterator.next() // new buffer will be (target - 1) -> ask for more + for _ in 0..<(AdaptiveRowBuffer.defaultBufferTarget / 2) { + _ = try await rowIterator.next() // Remove all rows until we are back at target + XCTAssertEqual(dataSource.requestCount, 2) + } + + // if we remove another row we should trigger getting new rows. + _ = try await rowIterator.next() // new buffer will be (target - 1) -> ask for more + XCTAssertEqual(dataSource.requestCount, 3) + + // remove all remaining rows... this will trigger a target size double + for _ in 0..<(AdaptiveRowBuffer.defaultBufferTarget/2 - 1) { + _ = try await rowIterator.next() // Remove all rows until we are back at target + XCTAssertEqual(dataSource.requestCount, 3) + } + + let fillBufferDataRows: [DataRow] = (0.. don't ask for more + XCTAssertEqual(dataSource.requestCount, 3) + _ = try await rowIterator.next() // new buffer will be (target - 1) -> ask for more + XCTAssertEqual(dataSource.requestCount, 4) + } + + func testAdaptiveRowShrinksToMin() async throws { + let dataSource = MockRowDataSource() + let embeddedEventLoop = EmbeddedEventLoop() + let stream = PSQLRowStream( + source: .stream( + [ + .init(name: "test", tableOID: 0, columnAttributeNumber: 0, dataType: .int8, dataTypeSize: 8, dataTypeModifier: 0, format: .binary) + ], + dataSource + ), + eventLoop: embeddedEventLoop, + logger: self.logger + ) + + var currentTarget = AdaptiveRowBuffer.defaultBufferTarget + + let initialDataRows: [DataRow] = (0.. AdaptiveRowBuffer.defaultBufferMinimum { + // the buffer is filled up to currentTarget at that point, if we remove one row and add + // one row it should shrink + XCTAssertEqual(dataSource.requestCount, expectedRequestCount) + _ = try await rowIterator.next() + expectedRequestCount += 1 + XCTAssertEqual(dataSource.requestCount, expectedRequestCount) + + stream.receive([[ByteBuffer(integer: Int64(1))], [ByteBuffer(integer: Int64(1))]]) + let newTarget = currentTarget / 2 + let toDrop = currentTarget + 1 - newTarget + + // consume all messages that are to much. + for _ in 0..= 350, "Results count not large enough: \(results.count)") - } - - func testSelectType() throws { - let conn = try PostgresConnection.test(on: eventLoop).wait() - defer { try! conn.close().wait() } - let results = try conn.simpleQuery("SELECT * FROM pg_type WHERE typname = 'float8'").wait() - // [ - // "typreceive": "float8recv", - // "typelem": "0", - // "typarray": "1022", - // "typalign": "d", - // "typanalyze": "-", - // "typtypmod": "-1", - // "typname": "float8", - // "typnamespace": "11", - // "typdefault": "", - // "typdefaultbin": "", - // "typcollation": "0", - // "typispreferred": "t", - // "typrelid": "0", - // "typbyval": "t", - // "typnotnull": "f", - // "typinput": "float8in", - // "typlen": "8", - // "typcategory": "N", - // "typowner": "10", - // "typtype": "b", - // "typdelim": ",", - // "typndims": "0", - // "typbasetype": "0", - // "typacl": "", - // "typisdefined": "t", - // "typmodout": "-", - // "typmodin": "-", - // "typsend": "float8send", - // "typstorage": "p", - // "typoutput": "float8out" - // ] - switch results.count { - case 1: - XCTAssertEqual(results[0].column("typname")?.string, "float8") - XCTAssertEqual(results[0].column("typnamespace")?.int, 11) - XCTAssertEqual(results[0].column("typowner")?.int, 10) - XCTAssertEqual(results[0].column("typlen")?.int, 8) - default: XCTFail("Should be exactly one result, but got \(results.count)") - } - } - - func testIntegers() throws { - let conn = try PostgresConnection.test(on: eventLoop).wait() - defer { try! conn.close().wait() } - struct Integers: Decodable { - let smallint: Int16 - let smallint_min: Int16 - let smallint_max: Int16 - let int: Int32 - let int_min: Int32 - let int_max: Int32 - let bigint: Int64 - let bigint_min: Int64 - let bigint_max: Int64 - } - let results = try conn.query(""" - SELECT - 1::SMALLINT as smallint, - -32767::SMALLINT as smallint_min, - 32767::SMALLINT as smallint_max, - 1::INT as int, - -2147483647::INT as int_min, - 2147483647::INT as int_max, - 1::BIGINT as bigint, - -9223372036854775807::BIGINT as bigint_min, - 9223372036854775807::BIGINT as bigint_max - """).wait() - switch results.count { - case 1: - XCTAssertEqual(results[0].column("smallint")?.int16, 1) - XCTAssertEqual(results[0].column("smallint_min")?.int16, -32_767) - XCTAssertEqual(results[0].column("smallint_max")?.int16, 32_767) - XCTAssertEqual(results[0].column("int")?.int32, 1) - XCTAssertEqual(results[0].column("int_min")?.int32, -2_147_483_647) - XCTAssertEqual(results[0].column("int_max")?.int32, 2_147_483_647) - XCTAssertEqual(results[0].column("bigint")?.int64, 1) - XCTAssertEqual(results[0].column("bigint_min")?.int64, -9_223_372_036_854_775_807) - XCTAssertEqual(results[0].column("bigint_max")?.int64, 9_223_372_036_854_775_807) - default: XCTFail("Should be exactly one result, but got \(results.count)") - } - } - - func testPi() throws { - let conn = try PostgresConnection.test(on: eventLoop).wait() - defer { try! conn.close().wait() } - struct Pi: Decodable { - let text: String - let numeric_string: String - let numeric_decimal: Decimal - let double: Double - let float: Float - } - let results = try conn.query(""" - SELECT - pi()::TEXT as text, - pi()::NUMERIC as numeric_string, - pi()::NUMERIC as numeric_decimal, - pi()::FLOAT8 as double, - pi()::FLOAT4 as float - """).wait() - switch results.count { - case 1: - //print(results[0]) - XCTAssertEqual(results[0].column("text")?.string?.hasPrefix("3.14159265"), true) - XCTAssertEqual(results[0].column("numeric_string")?.string?.hasPrefix("3.14159265"), true) - XCTAssertTrue(results[0].column("numeric_decimal")?.decimal?.isLess(than: 3.14159265358980) ?? false) - XCTAssertFalse(results[0].column("numeric_decimal")?.decimal?.isLess(than: 3.14159265358978) ?? true) - XCTAssertTrue(results[0].column("double")?.double?.description.hasPrefix("3.141592") ?? false) - XCTAssertTrue(results[0].column("float")?.float?.description.hasPrefix("3.141592") ?? false) - default: XCTFail("Should be exactly one result, but got \(results.count)") - } - } - - func testUUID() throws { - let conn = try PostgresConnection.test(on: eventLoop).wait() - defer { try! conn.close().wait() } - struct Model: Decodable { - let id: UUID - let string: String - } - let results = try conn.query(""" - SELECT - '123e4567-e89b-12d3-a456-426655440000'::UUID as id, - '123e4567-e89b-12d3-a456-426655440000'::UUID as string - """).wait() - switch results.count { - case 1: - //print(results[0]) - XCTAssertEqual(results[0].column("id")?.uuid, UUID(uuidString: "123E4567-E89B-12D3-A456-426655440000")) - XCTAssertEqual(UUID(uuidString: results[0].column("id")?.string ?? ""), UUID(uuidString: "123E4567-E89B-12D3-A456-426655440000")) - default: XCTFail("Should be exactly one result, but got \(results.count)") - } - } - - func testDates() throws { - let conn = try PostgresConnection.test(on: eventLoop).wait() - defer { try! conn.close().wait() } - struct Dates: Decodable { - var date: Date - var timestamp: Date - var timestamptz: Date - } - let results = try conn.query(""" - SELECT - '2016-01-18 01:02:03 +0042'::DATE as date, - '2016-01-18 01:02:03 +0042'::TIMESTAMP as timestamp, - '2016-01-18 01:02:03 +0042'::TIMESTAMPTZ as timestamptz - """).wait() - switch results.count { - case 1: - //print(results[0]) - XCTAssertEqual(results[0].column("date")?.date?.description, "2016-01-18 00:00:00 +0000") - XCTAssertEqual(results[0].column("timestamp")?.date?.description, "2016-01-18 01:02:03 +0000") - XCTAssertEqual(results[0].column("timestamptz")?.date?.description, "2016-01-18 00:20:03 +0000") - default: XCTFail("Should be exactly one result, but got \(results.count)") - } - } - - /// https://github.com/vapor/nio-postgres/issues/20 - func testBindInteger() throws { - let conn = try PostgresConnection.test(on: eventLoop).wait() - defer { try! conn.close().wait() } - _ = try conn.simpleQuery("drop table if exists person;").wait() - _ = try conn.simpleQuery("create table person(id serial primary key, first_name text, last_name text);").wait() - defer { _ = try! conn.simpleQuery("drop table person;").wait() } - let id = PostgresData(int32: 5) - _ = try conn.query("SELECT id, first_name, last_name FROM person WHERE id = $1", [id]).wait() - } - - // https://github.com/vapor/nio-postgres/issues/21 - func testAverageLengthNumeric() throws { - let conn = try PostgresConnection.test(on: eventLoop).wait() - defer { try! conn.close().wait() } - let rows = try conn.query("select avg(length('foo')) as average_length").wait() - let length = try XCTUnwrap(rows[0].column("average_length")?.double) - XCTAssertEqual(length, 3.0) - } - - func testNumericParsing() throws { - let conn = try PostgresConnection.test(on: eventLoop).wait() - defer { try! conn.close().wait() } - let rows = try conn.query(""" - select - '1234.5678'::numeric as a, - '-123.456'::numeric as b, - '123456.789123'::numeric as c, - '3.14159265358979'::numeric as d, - '10000'::numeric as e, - '0.00001'::numeric as f, - '100000000'::numeric as g, - '0.000000001'::numeric as h, - '100000000000'::numeric as i, - '0.000000000001'::numeric as j, - '123000000000'::numeric as k, - '0.000000000123'::numeric as l, - '0.5'::numeric as m - """).wait() - XCTAssertEqual(rows[0].column("a")?.string, "1234.5678") - XCTAssertEqual(rows[0].column("b")?.string, "-123.456") - XCTAssertEqual(rows[0].column("c")?.string, "123456.789123") - XCTAssertEqual(rows[0].column("d")?.string, "3.14159265358979") - XCTAssertEqual(rows[0].column("e")?.string, "10000") - XCTAssertEqual(rows[0].column("f")?.string, "0.00001") - XCTAssertEqual(rows[0].column("g")?.string, "100000000") - XCTAssertEqual(rows[0].column("h")?.string, "0.000000001") - XCTAssertEqual(rows[0].column("k")?.string, "123000000000") - XCTAssertEqual(rows[0].column("l")?.string, "0.000000000123") - XCTAssertEqual(rows[0].column("m")?.string, "0.5") - } - - func testSingleNumericParsing() throws { - // this seemingly duped test is useful for debugging numeric parsing - let conn = try PostgresConnection.test(on: eventLoop).wait() - defer { try! conn.close().wait() } - let numeric = "790226039477542363.6032384900176272473" - let rows = try conn.query(""" - select - '\(numeric)'::numeric as n - """).wait() - XCTAssertEqual(rows[0].column("n")?.string, numeric) - } - - func testRandomlyGeneratedNumericParsing() throws { - // this test takes a long time to run - try XCTSkipUnless(Self.shouldRunLongRunningTests) - - let conn = try PostgresConnection.test(on: eventLoop).wait() - defer { try! conn.close().wait() } - - for _ in 0..<1_000_000 { - let integer = UInt.random(in: UInt.min.. String? { - getenv(name).flatMap { String(cString: $0) } -} - -let isLoggingConfigured: Bool = { - LoggingSystem.bootstrap { label in - var handler = StreamLogHandler.standardOutput(label: label) - handler.logLevel = env("LOG_LEVEL").flatMap { Logger.Level(rawValue: $0) } ?? .debug - return handler - } - return true -}() diff --git a/Tests/PostgresNIOTests/Utilities.swift b/Tests/PostgresNIOTests/Utilities.swift index 21e6a2fb..610d8f10 100644 --- a/Tests/PostgresNIOTests/Utilities.swift +++ b/Tests/PostgresNIOTests/Utilities.swift @@ -1,95 +1,9 @@ import Logging -import PostgresNIO -import XCTest -extension PostgresConnection { - static func address() throws -> SocketAddress { - try .makeAddressResolvingHost( env("POSTGRES_HOSTNAME") ?? "localhost", port: 5432) +extension Logger { + static var psqlTest: Logger { + var logger = Logger(label: "psql.test") + logger.logLevel = .info + return logger } - - static func testUnauthenticated(on eventLoop: EventLoop) -> EventLoopFuture { - do { - return connect(to: try address(), on: eventLoop) - } catch { - return eventLoop.makeFailedFuture(error) - } - } - - static func test(on eventLoop: EventLoop) -> EventLoopFuture { - return testUnauthenticated(on: eventLoop).flatMap { conn in - return conn.authenticate( - username: env("POSTGRES_USERNAME") ?? "vapor_username", - database: env("POSTGRES_DATABASE") ?? "vapor_database", - password: env("POSTGRES_PASSWORD") ?? "vapor_password" - ).map { - return conn - }.flatMapError { error in - conn.close().flatMapThrowing { - throw error - } - } - } - } -} - -extension XCTestCase { - - public static var shouldRunLongRunningTests: Bool { - // The env var must be set and have the value `"true"`, `"1"`, or `"yes"` (case-insensitive). - // For the sake of sheer annoying pedantry, values like `"2"` are treated as false. - guard let rawValue = ProcessInfo.processInfo.environment["POSTGRES_LONG_RUNNING_TESTS"] else { return false } - if let boolValue = Bool(rawValue) { return boolValue } - if let intValue = Int(rawValue) { return intValue == 1 } - return rawValue.lowercased() == "yes" - } - - public static var shouldRunPerformanceTests: Bool { - // Same semantics as above. Any present non-truthy value will explicitly disable performance - // tests even if they would've overwise run in the current configuration. - let defaultValue = !_isDebugAssertConfiguration() // default to not running in debug builds - - guard let rawValue = ProcessInfo.processInfo.environment["POSTGRES_PERFORMANCE_TESTS"] else { return defaultValue } - if let boolValue = Bool(rawValue) { return boolValue } - if let intValue = Int(rawValue) { return intValue == 1 } - return rawValue.lowercased() == "yes" - } - -} - - -// 1247.typisdefined: 0x01 (BOOLEAN) -// 1247.typbasetype: 0x00000000 (OID) -// 1247.typnotnull: 0x00 (BOOLEAN) -// 1247.typcategory: 0x42 (CHAR) -// 1247.typname: 0x626f6f6c (NAME) -// 1247.typbyval: 0x01 (BOOLEAN) -// 1247.typrelid: 0x00000000 (OID) -// 1247.typalign: 0x63 (CHAR) -// 1247.typndims: 0x00000000 (INTEGER) -// 1247.typacl: null -// 1247.typsend: 0x00000985 (REGPROC) -// 1247.typmodout: 0x00000000 (REGPROC) -// 1247.typstorage: 0x70 (CHAR) -// 1247.typispreferred: 0x01 (BOOLEAN) -// 1247.typinput: 0x000004da (REGPROC) -// 1247.typoutput: 0x000004db (REGPROC) -// 1247.typlen: 0x0001 (SMALLINT) -// 1247.typcollation: 0x00000000 (OID) -// 1247.typdefaultbin: null -// 1247.typelem: 0x00000000 (OID) -// 1247.typnamespace: 0x0000000b (OID) -// 1247.typtype: 0x62 (CHAR) -// 1247.typowner: 0x0000000a (OID) -// 1247.typdefault: null -// 1247.typtypmod: 0xffffffff (INTEGER) -// 1247.typarray: 0x000003e8 (OID) -// 1247.typreceive: 0x00000984 (REGPROC) -// 1247.typmodin: 0x00000000 (REGPROC) -// 1247.typanalyze: 0x00000000 (REGPROC) -// 1247.typdelim: 0x2c (CHAR) -struct PGType: Decodable { - var typname: String - var typnamespace: UInt32 - var typowner: UInt32 - var typlen: Int16 } diff --git a/Tests/PostgresNIOTests/Utilities/PostgresJSONCodingTests.swift b/Tests/PostgresNIOTests/Utilities/PostgresJSONCodingTests.swift new file mode 100644 index 00000000..c6f876f2 --- /dev/null +++ b/Tests/PostgresNIOTests/Utilities/PostgresJSONCodingTests.swift @@ -0,0 +1,66 @@ +import Atomics +import NIOCore +import XCTest +import PostgresNIO + +class PostgresJSONCodingTests: XCTestCase { + // https://github.com/vapor/postgres-nio/issues/126 + func testCustomJSONEncoder() { + let previousDefaultJSONEncoder = PostgresNIO._defaultJSONEncoder + defer { + PostgresNIO._defaultJSONEncoder = previousDefaultJSONEncoder + } + final class CustomJSONEncoder: PostgresJSONEncoder { + let counter = ManagedAtomic(0) + func encode(_ value: T) throws -> Data where T : Encodable { + self.counter.wrappingIncrement(ordering: .relaxed) + return try JSONEncoder().encode(value) + } + } + struct Object: Codable { + var foo: Int + var bar: Int + } + let customJSONEncoder = CustomJSONEncoder() + XCTAssertEqual(customJSONEncoder.counter.load(ordering: .relaxed), 0) + PostgresNIO._defaultJSONEncoder = customJSONEncoder + XCTAssertNoThrow(try PostgresData(json: Object(foo: 1, bar: 2))) + XCTAssertEqual(customJSONEncoder.counter.load(ordering: .relaxed), 1) + + let customJSONBEncoder = CustomJSONEncoder() + XCTAssertEqual(customJSONBEncoder.counter.load(ordering: .relaxed), 0) + PostgresNIO._defaultJSONEncoder = customJSONBEncoder + XCTAssertNoThrow(try PostgresData(json: Object(foo: 1, bar: 2))) + XCTAssertEqual(customJSONBEncoder.counter.load(ordering: .relaxed), 1) + } + + // https://github.com/vapor/postgres-nio/issues/126 + func testCustomJSONDecoder() { + let previousDefaultJSONDecoder = PostgresNIO._defaultJSONDecoder + defer { + PostgresNIO._defaultJSONDecoder = previousDefaultJSONDecoder + } + final class CustomJSONDecoder: PostgresJSONDecoder { + let counter = ManagedAtomic(0) + func decode(_ type: T.Type, from data: Data) throws -> T where T : Decodable { + self.counter.wrappingIncrement(ordering: .relaxed) + return try JSONDecoder().decode(type, from: data) + } + } + struct Object: Codable { + var foo: Int + var bar: Int + } + let customJSONDecoder = CustomJSONDecoder() + XCTAssertEqual(customJSONDecoder.counter.load(ordering: .relaxed), 0) + PostgresNIO._defaultJSONDecoder = customJSONDecoder + XCTAssertNoThrow(try PostgresData(json: Object(foo: 1, bar: 2)).json(as: Object.self)) + XCTAssertEqual(customJSONDecoder.counter.load(ordering: .relaxed), 1) + + let customJSONBDecoder = CustomJSONDecoder() + XCTAssertEqual(customJSONBDecoder.counter.load(ordering: .relaxed), 0) + PostgresNIO._defaultJSONDecoder = customJSONBDecoder + XCTAssertNoThrow(try PostgresData(json: Object(foo: 1, bar: 2)).json(as: Object.self)) + XCTAssertEqual(customJSONBDecoder.counter.load(ordering: .relaxed), 1) + } +} diff --git a/dev/generate-postgresrow-multi-decode.sh b/dev/generate-postgresrow-multi-decode.sh new file mode 100755 index 00000000..e641ed8d --- /dev/null +++ b/dev/generate-postgresrow-multi-decode.sh @@ -0,0 +1,142 @@ +#!/bin/bash + +set -eu + +here="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" + +function genWithContextParameter() { + how_many=$1 + + if [[ $how_many -ne 1 ]] ; then + echo "" + fi + + echo " @inlinable" + echo " @_alwaysEmitIntoClient" + echo -n " public func decode(_: (T0" + for ((n = 1; n<$how_many; n +=1)); do + echo -n ", T$(($n))" + done + echo -n ").Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) throws" + + echo -n " -> (T0" + for ((n = 1; n<$how_many; n +=1)); do + echo -n ", T$(($n))" + done + echo ") {" + echo " precondition(self.columns.count >= $how_many)" + #echo " var columnIndex = 0" + if [[ $how_many -eq 1 ]] ; then + echo " let columnIndex = 0" + echo " var cellIterator = self.data.makeIterator()" + echo " var cellData = cellIterator.next().unsafelyUnwrapped" + echo " var columnIterator = self.columns.makeIterator()" + echo " let column = columnIterator.next().unsafelyUnwrapped" + echo " let swiftTargetType: Any.Type = T0.self" + else + echo " var columnIndex = 0" + echo " var cellIterator = self.data.makeIterator()" + echo " var cellData = cellIterator.next().unsafelyUnwrapped" + echo " var columnIterator = self.columns.makeIterator()" + echo " var column = columnIterator.next().unsafelyUnwrapped" + echo " var swiftTargetType: Any.Type = T0.self" + fi + + echo + echo " do {" + echo " let r0 = try T0._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context)" + echo + for ((n = 1; n<$how_many; n +=1)); do + echo " columnIndex = $n" + echo " cellData = cellIterator.next().unsafelyUnwrapped" + echo " column = columnIterator.next().unsafelyUnwrapped" + echo " swiftTargetType = T$n.self" + echo " let r$n = try T$n._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context)" + echo + done + + echo -n " return (r0" + for ((n = 1; n<$how_many; n +=1)); do + echo -n ", r$(($n))" + done + echo ")" + echo " } catch let code as PostgresDecodingError.Code {" + echo " throw PostgresDecodingError(" + echo " code: code," + echo " columnName: column.name," + echo " columnIndex: columnIndex," + echo " targetType: swiftTargetType," + echo " postgresType: column.dataType," + echo " postgresFormat: column.format," + echo " postgresData: cellData," + echo " file: file," + echo " line: line" + echo " )" + echo " }" + echo " }" +} + +function genWithoutContextParameter() { + how_many=$1 + + echo "" + + echo " @inlinable" + echo " @_alwaysEmitIntoClient" + echo -n " public func decode(_: (T0" + for ((n = 1; n<$how_many; n +=1)); do + echo -n ", T$(($n))" + done + echo -n ").Type, file: String = #fileID, line: Int = #line) throws" + + echo -n " -> (T0" + for ((n = 1; n<$how_many; n +=1)); do + echo -n ", T$(($n))" + done + echo ") {" + echo -n " try self.decode(" + if [[ $how_many -eq 1 ]] ; then + echo -n "T0.self" + else + echo -n "(T0" + for ((n = 1; n<$how_many; n +=1)); do + echo -n ", T$(($n))" + done + echo -n ").self" + fi + echo ", context: .default, file: file, line: line)" + echo " }" +} + +grep -q "ByteBuffer" "${BASH_SOURCE[0]}" || { + echo >&2 "ERROR: ${BASH_SOURCE[0]}: file or directory not found (this should be this script)" + exit 1 +} + +{ +cat <<"EOF" +/// NOTE: THIS FILE IS AUTO-GENERATED BY dev/generate-postgresrow-multi-decode.sh +EOF +echo + +echo "extension PostgresRow {" + +# note: +# - widening the inverval below (eg. going from {1..15} to {1..25}) is Semver minor +# - narrowing the interval below is SemVer _MAJOR_! +for n in {1..15}; do + genWithContextParameter "$n" + genWithoutContextParameter "$n" +done +echo "}" +} > "$here/../Sources/PostgresNIO/New/PostgresRow-multi-decode.swift" diff --git a/dev/generate-postgresrowsequence-multi-decode.sh b/dev/generate-postgresrowsequence-multi-decode.sh new file mode 100755 index 00000000..8317149b --- /dev/null +++ b/dev/generate-postgresrowsequence-multi-decode.sh @@ -0,0 +1,111 @@ +#!/bin/bash + +set -eu + +here="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" + +function genWithContextParameter() { + how_many=$1 + + if [[ $how_many -ne 1 ]] ; then + echo "" + fi + + echo " @inlinable" + echo " @_alwaysEmitIntoClient" + echo -n " public func decode(_: (T0" + for ((n = 1; n<$how_many; n +=1)); do + echo -n ", T$(($n))" + done + echo -n ").Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) " + + echo -n "-> AsyncThrowingMapSequence {" + + echo " self.map { row in" + + if [[ $how_many -eq 1 ]] ; then + echo " try row.decode(T0.self, context: context, file: file, line: line)" + else + echo -n " try row.decode((T0" + + for ((n = 1; n<$how_many; n +=1)); do + echo -n ", T$n" + done + echo ").self, context: context, file: file, line: line)" + + fi + + echo " }" + echo " }" +} + +function genWithoutContextParameter() { + how_many=$1 + + echo "" + + echo " @inlinable" + echo " @_alwaysEmitIntoClient" + echo -n " public func decode(_: (T0" + for ((n = 1; n<$how_many; n +=1)); do + echo -n ", T$(($n))" + done + echo -n ").Type, file: String = #fileID, line: Int = #line) " + echo -n "-> AsyncThrowingMapSequence {" + + echo -n " self.decode(" + if [[ $how_many -eq 1 ]] ; then + echo -n "T0.self" + else + echo -n "(T0" + for ((n = 1; n<$how_many; n +=1)); do + echo -n ", T$(($n))" + done + echo -n ").self" + fi + echo ", context: .default, file: file, line: line)" + echo " }" +} + +grep -q "ByteBuffer" "${BASH_SOURCE[0]}" || { + echo >&2 "ERROR: ${BASH_SOURCE[0]}: file or directory not found (this should be this script)" + exit 1 +} + +{ +cat <<"EOF" +/// NOTE: THIS FILE IS AUTO-GENERATED BY dev/generate-postgresrowsequence-multi-decode.sh +EOF +echo + +echo "#if canImport(_Concurrency)" +echo "extension AsyncSequence where Element == PostgresRow {" + +# note: +# - widening the inverval below (eg. going from {1..15} to {1..25}) is Semver minor +# - narrowing the interval below is SemVer _MAJOR_! +for n in {1..15}; do + genWithContextParameter "$n" + genWithoutContextParameter "$n" +done +echo "}" +echo "#endif" +} > "$here/../Sources/PostgresNIO/New/PostgresRowSequence-multi-decode.swift" diff --git a/docker-compose.yml b/docker-compose.yml index ea508229..3eff4249 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,43 +1,30 @@ -version: '2' +version: '3.7' + +x-shared-config: &shared_config + environment: + POSTGRES_HOST_AUTH_METHOD: "${POSTGRES_HOST_AUTH_METHOD:-scram-sha-256}" + POSTGRES_USER: test_username + POSTGRES_DB: test_database + POSTGRES_PASSWORD: test_password + ports: + - 5432:5432 services: + psql-16: + image: postgres:16 + <<: *shared_config + psql-15: + image: postgres:15 + <<: *shared_config + psql-14: + image: postgres:14 + <<: *shared_config + psql-13: + image: postgres:13 + <<: *shared_config psql-12: image: postgres:12 - environment: - POSTGRES_USER: vapor_username - POSTGRES_DB: vapor_database - POSTGRES_PASSWORD: vapor_password - ports: - - 5432:5432 + <<: *shared_config psql-11: image: postgres:11 - environment: - POSTGRES_USER: vapor_username - POSTGRES_DB: vapor_database - POSTGRES_PASSWORD: vapor_password - ports: - - 5432:5432 - psql-10: - image: postgres:10 - environment: - POSTGRES_USER: vapor_username - POSTGRES_DB: vapor_database - POSTGRES_PASSWORD: vapor_password - ports: - - 5432:5432 - psql-9: - image: postgres:9 - environment: - POSTGRES_USER: vapor_username - POSTGRES_DB: vapor_database - POSTGRES_PASSWORD: vapor_password - ports: - - 5432:5432 - psql-ssl: - image: scenecheck/postgres-ssl:latest - environment: - POSTGRES_USER: vapor_username - POSTGRES_DB: vapor_database - POSTGRES_PASSWORD: vapor_password - ports: - - 5432:5432 + <<: *shared_config