diff --git a/spannerlib/wrappers/spannerlib-python/.gitignore b/spannerlib/wrappers/spannerlib-python/.gitignore new file mode 100644 index 00000000..a4341510 --- /dev/null +++ b/spannerlib/wrappers/spannerlib-python/.gitignore @@ -0,0 +1,43 @@ +# Python +__pycache__/ +*.py[cod] +*$py.class + +# Distribution / packaging +build/ +dist/ +.eggs/ +lib/ +lib64/ +*.egg-info/ +*.egg +MANIFEST + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +.pytest_cache/ +cover/ + +# Environments +.env +.venv +env/ +venv/ + +# mypy +.mypy_cache/ + +# IDEs and editors +.idea/ +.vscode/ +.DS_Store + +# Build Artifacts +google/cloud/spannerlib/spannerlib-artifacts + +*_sponge_log.xml \ No newline at end of file diff --git a/spannerlib/wrappers/spannerlib-python/MANIFEST.in b/spannerlib/wrappers/spannerlib-python/MANIFEST.in new file mode 100644 index 00000000..1a112459 --- /dev/null +++ b/spannerlib/wrappers/spannerlib-python/MANIFEST.in @@ -0,0 +1,2 @@ +include google/cloud/spannerlib/internal/spannerlib-artifacts/spannerlib.so +include google/cloud/spannerlib/internal/spannerlib-artifacts/spannerlib.h \ No newline at end of file diff --git a/spannerlib/wrappers/spannerlib-python/README.md b/spannerlib/wrappers/spannerlib-python/README.md new file mode 100644 index 00000000..14606bc8 --- /dev/null +++ b/spannerlib/wrappers/spannerlib-python/README.md @@ -0,0 +1,87 @@ +# SPANNERLIB-PY: A High-Performance Python Wrapper for the Go Spanner Client Shared lib 🐍 + +## Introduction +spannerlib-py provides a high-performance, idiomatic Python interface for Google Cloud Spanner by wrapping the official Go Client Shared library. + +The Go library is compiled into a C-shared library (.so), and this project uses ctypes to call it directly from Python, aiming to combine Go's performance with Python's ease of use. + +**Code Structure** + +```bash +spannerlib-python/ +|___google/cloud/spannerlib/ + |_____internal - SpannerLib wrapper +|___tests/ + |___unit/ - Unit tests + |___system/ - System tests +|___samples +README.md +noxfile.py +myproject.toml - Project config for packaging +``` + +**Lint support** + +```bash +nox -s format lint +``` + +## Running Tests + +### Unit Tests + +To run the unit tests, navigate to the root directory of this wrapper (`spannerlib-python`) and run: + +```bash +python3 -m unittest tests/unit/test_spannerlib_wrapper.py +``` + +### System Tests + +The system tests require a Cloud Spanner Emulator instance running. + +1. **Pull and Run the Emulator:** + ```bash + docker pull gcr.io/cloud-spanner-emulator/emulator + docker run -p 9010:9010 -p 9020:9020 -d gcr.io/cloud-spanner-emulator/emulator + ``` + +2. **Set Environment Variable:** + Ensure the `SPANNER_EMULATOR_HOST` environment variable is set: + ```bash + export SPANNER_EMULATOR_HOST=localhost:9010 + ``` + +3. **Create Test Instance and Database:** + You need the `gcloud` CLI installed and configured. + ```bash + gcloud spanner instances create test-instance --config=emulator-config --description="Test Instance" --nodes=1 + gcloud spanner databases create testdb --instance=test-instance + ``` + +4. **Run the System Tests:** + Navigate to the root directory of this wrapper (`spannerlib-python`) and run: + ```bash + python3 -m unittest tests/system/test_spannerlib_wrapper.py + ``` + +## Build and install + +**Install locally** +```bash +pip3 install -e . +``` + +**Package** +```bash +pip3 install build +python3 -m build +``` + +**Validate** +```bash +pip3 install twine +twine check dist/* +unzip -l dist/spannerlib-0.1.0-py3-none-any.whl +tar -tvzf dist/spannerlib-0.1.0.tar.gz +``` diff --git a/spannerlib/wrappers/spannerlib-python/google/cloud/spannerlib/__init__.py b/spannerlib/wrappers/spannerlib-python/google/cloud/spannerlib/__init__.py new file mode 100644 index 00000000..e9e24859 --- /dev/null +++ b/spannerlib/wrappers/spannerlib-python/google/cloud/spannerlib/__init__.py @@ -0,0 +1,32 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Python wrapper for the Spanner Go library.""" +from __future__ import absolute_import + +import logging + +from google.cloud.spannerlib.connection import Connection +from google.cloud.spannerlib.internal.errors import SpannerLibError +from google.cloud.spannerlib.pool import Pool +from google.cloud.spannerlib.rows import Rows + +logging.basicConfig(level=logging.INFO) + +__all__ = [ + "Pool", + "Connection", + "Rows", + "SpannerLibError", +] diff --git a/spannerlib/wrappers/spannerlib-python/google/cloud/spannerlib/connection.py b/spannerlib/wrappers/spannerlib-python/google/cloud/spannerlib/connection.py new file mode 100644 index 00000000..a904f69c --- /dev/null +++ b/spannerlib/wrappers/spannerlib-python/google/cloud/spannerlib/connection.py @@ -0,0 +1,264 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Module for the Connection class, representing a single connection to Spanner.""" + +from __future__ import absolute_import + +import logging +from typing import Any + +from google.cloud.spanner_v1 import ( + BatchWriteRequest, + CommitResponse, + ExecuteBatchDmlRequest, + ExecuteBatchDmlResponse, + ExecuteSqlRequest, + TransactionOptions, +) + +from google.cloud.spannerlib.internal.spannerlib import check_error, get_lib +from google.cloud.spannerlib.internal.types import ( + serialized_bytes_to_go_slice, + to_bytes, +) +from google.cloud.spannerlib.library_object import AbstractLibraryObject +from google.cloud.spannerlib.rows import Rows + +logger = logging.getLogger(__name__) + + +class Connection(AbstractLibraryObject): + """Represents a single connection to the Spanner database. + + This class wraps the connection handle from the underlying Go library, + providing methods to manage the connection lifecycle. + """ + + def __init__(self, id: int, pool: Any): + """ + Initializes a Connection. + + Args: + pool: The parent Pool object. + id: The pinner ID for this object in the Go library. + id: The connection ID from the Go library. + """ + super().__init__(id) + self._pool = pool + self._closed = False + logger.debug( + f"Connection ID: {self.id} initialized for pool ID: {self.pool.id}" + ) + + def __enter__(self): + """Enter the runtime context related to this object.""" + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Exit the runtime context related to this object, ensuring the connection is closed.""" + self.close() + + @property + def pool(self): + """Returns the parent Pool object.""" + return self._pool + + def close(self): + """Closes the connection and releases resources in the Go library. + + If the connection is already closed, this method does nothing. + It also checks if the parent pool is closed. + """ + if not self.closed: + if self.pool.closed: + logger.debug( + f"Connection ID: {self.id} implicitly closed because pool is closed." + ) + self.closed = True + return + + logger.info( + f"Closing connection ID: {self.id} for pool ID: {self.pool.id}" + ) + # Call the Go library function to close the connection. + ret = get_lib().CloseConnection(self.pool.id, self.id) + check_error(ret, "CloseConnection") + # Release the pinner ID in the Go library. + self._release() + logger.info(f"Connection ID: {self.id} closed") + + def execute(self, request: ExecuteSqlRequest) -> Rows: + """Executes a SQL statement on the connection. + + Args: + sql: The SQL statement to execute. + + Returns: + A Rows object representing the result of the execution. + """ + if self.closed: + raise RuntimeError("Connection is closed.") + + logger.info( + f"Executing SQL on connection ID: {self.id} for pool ID: {self.pool.id}" + ) + + request_slice = serialized_bytes_to_go_slice( + ExecuteSqlRequest.serialize(request) + ) + + # Call the Go library function to execute the SQL statement. + ret = get_lib().Execute( + self.pool.id, + self.id, + request_slice, + ) + check_error(ret, "Execute") + logger.info( + f"SQL execution successful on connection ID: {self.id}. Got Rows ID: {ret.object_id}" + ) + return Rows(ret.object_id, self.pool, self) + + def begin_transaction(self, options: TransactionOptions = None): + """Begins a new transaction on the connection. + + Args: + options: Optional transaction options from google.cloud.spanner_v1. + + Raises: + RuntimeError: If the connection is closed. + SpannerLibraryError: If the Go library call fails. + """ + if self.closed: + raise RuntimeError("Connection is closed.") + + logger.info( + f"Beginning transaction on connection ID: {self.id} for pool ID: {self.pool.id}" + ) + + if options is None: + options = TransactionOptions() + + options_slice = serialized_bytes_to_go_slice( + TransactionOptions.serialize(options) + ) + + ret = get_lib().BeginTransaction(self.pool.id, self.id, options_slice) + check_error(ret, "BeginTransaction") + + logger.info(f"Transaction started on connection ID: {self.id}") + + def commit(self) -> CommitResponse: + """Commits the transaction. + + Raises: + RuntimeError: If the connection is closed. + SpannerLibraryError: If the Go library call fails. + + Returns: + A CommitResponse object. + """ + if self.closed: + raise RuntimeError("Connection is closed.") + + logger.info(f"Committing on connection ID: {self.id}") + ret = get_lib().Commit(self.pool.id, self.id) + check_error(ret, "Commit") + logger.info("Committed") + response_bytes = to_bytes(ret.msg, ret.msg_len) + return CommitResponse.deserialize(response_bytes) + + def rollback(self): + """Rolls back the transaction. + + Raises: + RuntimeError: If the connection is closed. + SpannerLibraryError: If the Go library call fails. + """ + if self.closed: + raise RuntimeError("Connection is closed.") + + logger.info(f"Rolling back on connection ID: {self.id}") + ret = get_lib().Rollback(self.pool.id, self.id) + check_error(ret, "Rollback") + logger.info("Rolled back") + + def execute_batch( + self, request: ExecuteBatchDmlRequest + ) -> ExecuteBatchDmlResponse: + """Executes a batch of DML statements on the connection. + + Args: + request: The ExecuteBatchDmlRequest object. + + Returns: + An ExecuteBatchDmlResponse object representing the result of the execution. + """ + if self.closed: + raise RuntimeError("Connection is closed.") + + logger.info( + f"Executing batch DML on connection ID: {self.id} for pool ID: {self.pool.id}" + ) + + request_slice = serialized_bytes_to_go_slice( + ExecuteBatchDmlRequest.serialize(request) + ) + + # Call the Go library function to execute the batch DML statement. + ret = get_lib().ExecuteBatch( + self.pool.id, + self.id, + request_slice, + ) + check_error(ret, "ExecuteBatch") + logger.info( + f"Batch DML execution successful on connection ID: {self.id}." + ) + response_bytes = to_bytes(ret.msg, ret.msg_len) + return ExecuteBatchDmlResponse.deserialize(response_bytes) + + def write_mutations( + self, request: BatchWriteRequest.MutationGroup + ) -> CommitResponse: + """Writes a mutation to the connection. + + Args: + request: The BatchWriteRequest_MutationGroup object. + + Returns: + A CommitResponse object. + """ + if self.closed: + raise RuntimeError("Connection is closed.") + + logger.info( + f"Writing mutation on connection ID: {self.id} for pool ID: {self.pool.id}" + ) + + request_slice = serialized_bytes_to_go_slice( + BatchWriteRequest.MutationGroup.serialize(request) + ) + + # Call the Go library function to write the mutation. + ret = get_lib().WriteMutations( + self.pool.id, + self.id, + request_slice, + ) + check_error(ret, "WriteMutations") + logger.info(f"Mutation write successful on connection ID: {self.id}.") + response_bytes = to_bytes(ret.msg, ret.msg_len) + return CommitResponse.deserialize(response_bytes) diff --git a/spannerlib/wrappers/spannerlib-python/google/cloud/spannerlib/errors.py b/spannerlib/wrappers/spannerlib-python/google/cloud/spannerlib/errors.py new file mode 100644 index 00000000..a561968f --- /dev/null +++ b/spannerlib/wrappers/spannerlib-python/google/cloud/spannerlib/errors.py @@ -0,0 +1,31 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Custom error types for the spannerlib package.""" + +from __future__ import absolute_import + +from google.cloud.spannerlib.internal import SpannerError + + +class SpannerPoolError(SpannerError): + """Error related to Pool operations.""" + + pass + + +class SpannerConnectionError(SpannerError): + """Error related to Connection operations.""" + + pass diff --git a/spannerlib/wrappers/spannerlib-python/google/cloud/spannerlib/internal/__init__.py b/spannerlib/wrappers/spannerlib-python/google/cloud/spannerlib/internal/__init__.py new file mode 100644 index 00000000..c15d7daf --- /dev/null +++ b/spannerlib/wrappers/spannerlib-python/google/cloud/spannerlib/internal/__init__.py @@ -0,0 +1,37 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Internal module for the spannerlib package.""" +from __future__ import absolute_import + +from google.cloud.spannerlib.internal.errors import ( + SpannerError, + SpannerLibError, +) +from google.cloud.spannerlib.internal.spannerlib import check_error, get_lib +from google.cloud.spannerlib.internal.types import ( + GoReturn, + GoString, + to_go_string, +) + +__all__ = [ + "check_error", + "get_lib", + "to_go_string", + "GoString", + "GoReturn", + "SpannerError", + "SpannerLibError", +] diff --git a/spannerlib/wrappers/spannerlib-python/google/cloud/spannerlib/internal/_helper.py b/spannerlib/wrappers/spannerlib-python/google/cloud/spannerlib/internal/_helper.py new file mode 100644 index 00000000..f533a505 --- /dev/null +++ b/spannerlib/wrappers/spannerlib-python/google/cloud/spannerlib/internal/_helper.py @@ -0,0 +1,61 @@ +from __future__ import absolute_import + +import ctypes +import logging + +from google.cloud.spannerlib.internal.types import GoReturn, GoSlice, GoString + +logger = logging.getLogger(__name__) + + +def log_go_string(go_string: GoString): + """Helper function to logger.debug the contents of a GoString for debugging.""" + logger.debug(f"GoString Length (n): {go_string.n}") + if go_string.p: + try: + py_string = go_string.p[: go_string.n].decode("utf-8") + logger.debug(f"GoString Content (p): {py_string}") + except UnicodeDecodeError: + logger.debug( + f"GoString Content (p) as bytes: {go_string.p[:go_string.n]}" + ) + else: + logger.debug("GoString Content (p): NULL") + + +def log_go_slice(slc: GoSlice): + """Helper function to logger.debug the contents of a GoSlice for debugging.""" + logger.debug("--- GoSlice ---") + logger.debug(f" Len: {slc.len}, Cap: {slc.cap}") + logger.debug(f" Data Address: {hex(slc.data) if slc.data else 'None'}") + + # Check if the slice has any data to read + if not slc.data or slc.len == 0: + logger.debug(" Content: (empty)") + logger.debug("---------------") + return + + # Read slc.len bytes from the memory address in slc.data + content_bytes = ctypes.string_at(slc.data, slc.len) + logger.debug(f" Content (raw bytes): {content_bytes}") + + # Attempt to decode the bytes as a UTF-8 string for readability + try: + content_string = content_bytes.decode("utf-8") + logger.debug(f" Content (decoded str): '{content_string}'") + except UnicodeDecodeError: + logger.debug(" Content (decoded str): [Data is not valid UTF-8]") + + +def log_go_return(go_return: GoReturn): + """Helper function to logger.debug the contents of a GoReturn for debugging.""" + logger.debug( + f"GoReturn: pinner_id: {go_return.pinner_id}, " + f"error_code: {go_return.error_code}, " + f"object_id: {go_return.object_id}, " + f"msg_len: {go_return.msg_len}" + ) + if go_return.msg_len: + retrieved_bytes = ctypes.string_at(go_return.msg, go_return.msg_len) + retrieved_string = retrieved_bytes.decode("utf-8") + logger.debug(f"GoReturn Message: {retrieved_string}") diff --git a/spannerlib/wrappers/spannerlib-python/google/cloud/spannerlib/internal/errors.py b/spannerlib/wrappers/spannerlib-python/google/cloud/spannerlib/internal/errors.py new file mode 100644 index 00000000..c441d9ed --- /dev/null +++ b/spannerlib/wrappers/spannerlib-python/google/cloud/spannerlib/internal/errors.py @@ -0,0 +1,40 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Internal error types for the spannerlib package.""" +from __future__ import absolute_import + + +class SpannerError(Exception): + """Base exception for spannerlib_py.""" + + pass + + +class SpannerLibError(SpannerError): + """Error related to an underlying Go library call. + + Attributes: + error_code: The error code returned by the Go library, if available. + """ + + def __init__(self, message, error_code=None): + """Initializes a SpannerLibraryError. + + Args: + message: The error message. + error_code: The optional error code from the Go library. + """ + super().__init__(message) + self.error_code = error_code diff --git a/spannerlib/wrappers/spannerlib-python/google/cloud/spannerlib/internal/spannerlib.py b/spannerlib/wrappers/spannerlib-python/google/cloud/spannerlib/internal/spannerlib.py new file mode 100644 index 00000000..132da65c --- /dev/null +++ b/spannerlib/wrappers/spannerlib-python/google/cloud/spannerlib/internal/spannerlib.py @@ -0,0 +1,236 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import absolute_import + +import ctypes +import logging +import os +import threading + +from google.cloud.spannerlib.internal.errors import SpannerLibError +from google.cloud.spannerlib.internal.types import GoReturn, GoSlice, GoString + +logger = logging.getLogger(__name__) + + +class Spannerlib: + _instance = None + _lib = None + _load_lock = threading.Lock() + + def __init__(self): + raise RuntimeError("Call get_instance() instead") + + @classmethod + def get_instance(cls): + if cls._instance is None: + with cls._load_lock: + if cls._instance is None: + cls._instance = cls.__new__(cls) + cls._instance.load() + return cls._instance + + @classmethod + def get_lib_path(self): + _lib_path = os.path.abspath( + os.path.join( + os.path.dirname(__file__), + "spannerlib-artifacts/spannerlib.so", + ) + ) + return _lib_path + + def load(self): + if Spannerlib._lib is None: + _lib_path = Spannerlib.get_lib_path() + logger.info(f"Loading shared library from {_lib_path}") + try: + Spannerlib._lib = ctypes.CDLL(_lib_path) + self._setup_functions() + except OSError as e: + logger.error( + f"Failed to load shared library from {_lib_path}: {e}" + ) + Spannerlib._lib = None # Ensure _lib is None if loading failed + raise SpannerLibError(f"Failed to load shared library: {e}") + + def _setup_functions(self): + if Spannerlib._lib is None: + return + + # --- Function Definitions --- + # These are set up to match the exported functions in spannerlib.h + + # Release + Spannerlib._lib.Release.argtypes = [ctypes.c_longlong] + Spannerlib._lib.Release.restype = ctypes.c_int32 + + # CreatePool + Spannerlib._lib.CreatePool.argtypes = [GoString] + Spannerlib._lib.CreatePool.restype = GoReturn + + # ClosePool + Spannerlib._lib.ClosePool.argtypes = [ctypes.c_longlong] + Spannerlib._lib.ClosePool.restype = GoReturn + + # CreateConnection + Spannerlib._lib.CreateConnection.argtypes = [ctypes.c_longlong] + Spannerlib._lib.CreateConnection.restype = GoReturn + + # CloseConnection + Spannerlib._lib.CloseConnection.argtypes = [ + ctypes.c_longlong, + ctypes.c_longlong, + ] + Spannerlib._lib.CloseConnection.restype = GoReturn + + # Execute + # Corresponds to: GoReturn Execute(int64_t poolId, int64_t connId, GoSlice sql); + Spannerlib._lib.Execute.argtypes = [ + ctypes.c_longlong, + ctypes.c_longlong, + GoSlice, + ] + Spannerlib._lib.Execute.restype = GoReturn + + # CloseRows + # Corresponds to: GoReturn CloseRows(int64_t poolId, int64_t connId, int64_t rowsId); + Spannerlib._lib.CloseRows.argtypes = [ + ctypes.c_longlong, + ctypes.c_longlong, + ctypes.c_longlong, + ] + Spannerlib._lib.CloseRows.restype = GoReturn + + # Metadata + # Corresponds to: GoReturn Metadata(int64_t poolId, int64_t connId, int64_t rowsId); + Spannerlib._lib.Metadata.argtypes = [ + ctypes.c_longlong, + ctypes.c_longlong, + ctypes.c_longlong, + ] + Spannerlib._lib.Metadata.restype = GoReturn + + # ResultSetStats + # Corresponds to: GoReturn ResultSetStats(int64_t poolId, int64_t connId, int64_t rowsId); + Spannerlib._lib.ResultSetStats.argtypes = [ + ctypes.c_longlong, + ctypes.c_longlong, + ctypes.c_longlong, + ] + Spannerlib._lib.ResultSetStats.restype = GoReturn + + # BeginTransaction + # Corresponds to: GoReturn BeginTransaction(int64_t poolId, int64_t connId, GoSlice txOpts); + Spannerlib._lib.BeginTransaction.argtypes = [ + ctypes.c_longlong, + ctypes.c_longlong, + GoSlice, + ] + Spannerlib._lib.BeginTransaction.restype = GoReturn + + # Commit + # Corresponds to: GoReturn Commit(int64_t poolId, int64_t connId); + Spannerlib._lib.Commit.argtypes = [ + ctypes.c_longlong, + ctypes.c_longlong, + ] + Spannerlib._lib.Commit.restype = GoReturn + + # Rollback + # Corresponds to: GoReturn Rollback(int64_t poolId, int64_t connId); + Spannerlib._lib.Rollback.argtypes = [ + ctypes.c_longlong, + ctypes.c_longlong, + ] + Spannerlib._lib.Rollback.restype = GoReturn + + # ExecuteBatch + # Corresponds to: GoReturn ExecuteBatch(int64_t poolId, int64_t connId, GoSlice statements); + Spannerlib._lib.ExecuteBatch.argtypes = [ + ctypes.c_longlong, + ctypes.c_longlong, + GoSlice, + ] + Spannerlib._lib.ExecuteBatch.restype = GoReturn + + # WriteMutations + # Corresponds to: GoReturn WriteMutations(int64_t poolId, int64_t connId, GoSlice mutations); + Spannerlib._lib.WriteMutations.argtypes = [ + ctypes.c_longlong, + ctypes.c_longlong, + GoSlice, + ] + Spannerlib._lib.WriteMutations.restype = GoReturn + + # Next + # Corresponds to: + # GoReturn Next(int64_t poolId, int64_t connId, int64_t rowsId, int32_t numRows, int32_t encodeRowOption); + Spannerlib._lib.Next.argtypes = [ + ctypes.c_longlong, + ctypes.c_longlong, + ctypes.c_longlong, + ctypes.c_int32, + ctypes.c_int32, + ] + Spannerlib._lib.Next.restype = GoReturn + + @staticmethod + def check_error(ret: GoReturn, func_name: str): + """Checks the return value from Go functions for errors.""" + if ret.error_code != 0: + error_msg = f"{func_name} failed" + if ret.msg_len != 0: + try: + # Attempt to convert the error message from bytes + go_error_msg = ctypes.cast(ret.msg, ctypes.c_char_p).value + if go_error_msg: + error_msg += f": {go_error_msg.decode('utf-8', errors='replace')}" + except Exception as e: + error_msg += f" (Failed to decode error message: {e})" + logger.error(error_msg) + # Release the pinner_ids + if ret.pinner_id != 0: + try: + lib = Spannerlib.get_instance().lib + if lib: + lib.Release(ret.pinner_id) + except Exception as e: + logger.warning( + f"Error releasing pinnerId {ret.pinner_id}: {e}" + ) + + raise SpannerLibError(error_msg, error_code=ret.error_code) + + @property + def lib(self): + if Spannerlib._lib is None: + self.load() + return Spannerlib._lib + + +# Module-level functions to interact with the singleton +def check_error(ret: GoReturn, func_name: str): + Spannerlib.check_error(ret, func_name) + + +def get_lib(): + return Spannerlib.get_instance().lib + + +# Attempt to initialize the singleton on module load +try: + Spannerlib.get_instance() +except SpannerLibError: + logger.error("Spannerlib failed to initialize on module load.") diff --git a/spannerlib/wrappers/spannerlib-python/google/cloud/spannerlib/internal/types.py b/spannerlib/wrappers/spannerlib-python/google/cloud/spannerlib/internal/types.py new file mode 100644 index 00000000..291a0f98 --- /dev/null +++ b/spannerlib/wrappers/spannerlib-python/google/cloud/spannerlib/internal/types.py @@ -0,0 +1,109 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""CTypes definitions for interacting with the Go library.""" +from __future__ import absolute_import + +import ctypes +import logging + +logger = logging.getLogger(__name__) + + +# Define GoString structure, matching the Go layout. +class GoString(ctypes.Structure): + """Represents a Go string for C interop. + + Fields: + p: Pointer to the first byte of the string data. + n: Length of the string. + """ + + _fields_ = [("p", ctypes.c_char_p), ("n", ctypes.c_ssize_t)] + + +# Define common return structure from Go functions. +class GoReturn(ctypes.Structure): + """Represents the common return structure from Go functions. + + Fields: + pinner_id: ID for managing memory in Go (r0). + error_code: Error code, 0 for success (r1). + object_id: ID of the created object in Go, if any (r2). + msg_len: Length of the error message (r3). + msg: Pointer to the error message string, if any (r4). + """ + + _fields_ = [ + ("pinner_id", ctypes.c_longlong), # result pinnerId - r0 + ("error_code", ctypes.c_int32), # error code - r1 + ("object_id", ctypes.c_longlong), # object code - r2 + ("msg_len", ctypes.c_int32), # msg length - r3 + ("msg", ctypes.c_void_p), # msg string - r4 + ] + + +class GoSlice(ctypes.Structure): + _fields_ = [ + ("data", ctypes.c_void_p), + ("len", ctypes.c_longlong), + ("cap", ctypes.c_longlong), + ] + + +def to_go_string(s: str) -> GoString: + """Converts a Python string to a GoString. + + Args: + s: The Python string to convert. + + Returns: + GoString: A GoString instance.""" + encoded_s = s.encode("utf-8") + return GoString(encoded_s, len(encoded_s)) + + +def to_go_slice(s: str) -> GoSlice: + """Converts a Python string to a GoSlice.""" + encoded_s = s.encode("utf-8") + n = len(encoded_s) + + # Create a C-compatible mutable buffer from the bytes + # This is the memory that the GoSlice will point to. + buffer = ctypes.create_string_buffer(encoded_s) + # Create the GoSlice + return GoSlice( + data=ctypes.cast( + buffer, ctypes.c_void_p + ), # Cast the buffer to a void pointer + len=n, + cap=n, # For a new slice from a string, len and cap are the same + ) + + +def serialized_bytes_to_go_slice(serialized_bytes: bytes) -> GoSlice: + """Converts a Python string to a GoSlice.""" + slice_len = len(serialized_bytes) + go_slice = GoSlice( + data=ctypes.cast(serialized_bytes, ctypes.c_void_p), + len=slice_len, + cap=slice_len, + ) + go_slice._keepalive = serialized_bytes + return go_slice + + +def to_bytes(msg: ctypes.c_void_p, len: ctypes.c_int32) -> bytes: + """Converts shared lib msg to a bytes.""" + return ctypes.string_at(msg, len) diff --git a/spannerlib/wrappers/spannerlib-python/google/cloud/spannerlib/library_object.py b/spannerlib/wrappers/spannerlib-python/google/cloud/spannerlib/library_object.py new file mode 100644 index 00000000..92e1f309 --- /dev/null +++ b/spannerlib/wrappers/spannerlib-python/google/cloud/spannerlib/library_object.py @@ -0,0 +1,75 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Module for the AbstractLibraryObject class.""" +from __future__ import absolute_import + +import logging + +from google.cloud.spannerlib.internal.spannerlib import Spannerlib + +logger = logging.getLogger(__name__) + + +class AbstractLibraryObject: + """Abstract base class for objects that are managed by the Go library. + + This class provides a common interface for releasing resources in the Go library. + """ + + def __init__(self, id): + """Initializes the AbstractLibraryObject. + + Args: + id: The ID for this library object in the Go library. + """ + self._id = id + self._closed = False + + @property + def id(self): + """Returns the ID for this library object in the Go library.""" + return self._id + + @id.setter + def id(self, value): + """Sets the ID.""" + self._id = value + + @property + def closed(self): + """Returns True if the library object is closed, False otherwise.""" + return self._closed + + @closed.setter + def closed(self, value): + """Sets the closed state of the library object.""" + self._closed = value + + def _release(self): + """Releases the object in the Go library. + + This method calls the Release function in the Go library to free the resources + associated with this object. + """ + if self._id == 0: + return + try: + lib = Spannerlib.get_instance().lib + if lib: + lib.Release(self._id) + logger.debug(f"Released {self._id}") + self._closed = True + except Exception as e: + logger.warning(f"Error releasing {self._id}: {e}") diff --git a/spannerlib/wrappers/spannerlib-python/google/cloud/spannerlib/pool.py b/spannerlib/wrappers/spannerlib-python/google/cloud/spannerlib/pool.py new file mode 100644 index 00000000..2d923e10 --- /dev/null +++ b/spannerlib/wrappers/spannerlib-python/google/cloud/spannerlib/pool.py @@ -0,0 +1,108 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Module for the Pool class, representing a connection pool to Spanner.""" +from __future__ import absolute_import + +import logging + +from google.cloud.spannerlib.connection import Connection +from google.cloud.spannerlib.internal.errors import SpannerLibError +from google.cloud.spannerlib.internal.spannerlib import check_error, get_lib +from google.cloud.spannerlib.internal.types import to_go_string +from google.cloud.spannerlib.library_object import AbstractLibraryObject + +logger = logging.getLogger(__name__) + + +class Pool(AbstractLibraryObject): + """Manages a pool of connections to the Spanner database. + + This class wraps the connection pool handle from the underlying Go library, + providing methods to create connections and manage the pool lifecycle. + """ + + def __init__(self, id): + """ + Initializes the connection pool. + + Args: + id: The pinner ID for this object in the Go library. + """ + super().__init__(id) + + def __enter__(self): + """Enter the runtime context related to this object.""" + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Exit the runtime context related to this object, ensuring the pool is closed.""" + self.close() + + def close(self): + """Closes the connection pool and releases resources in the Go library. + + If the pool is already closed, this method does nothing. + """ + if not self.closed: + logger.info(f"Closing pool ID: {self.id}") + # Call the Go library function to close the pool. + ret = get_lib().ClosePool(self.id) + check_error(ret, "ClosePool") + # Release the object in the Go library. + self._release() + logger.info(f"Pool ID: {self.id} closed") + + def create_connection(self): + """ + Creates a new connection from the pool. + + Returns: + Connection: A new Connection object. + + Raises: + SpannerLibError: If the pool is closed. + """ + if self.closed: + logger.error("Attempted to create connection from a closed pool") + raise SpannerLibError("Pool is closed") + logger.debug(f"Creating connection from pool ID: {self.id}") + # Call the Go library function to create a connection. + ret = get_lib().CreateConnection(self.id) + check_error(ret, "CreateConnection") + logger.info( + f"Connection created with ID: {ret.object_id} from pool ID: {self.id}" + ) + return Connection(ret.object_id, self) + + @classmethod + def create_pool(cls, connection_string: str): + """ + Creates a new connection pool. + + Args: + connection_string (str): The connection string for the database. + + Returns: + Pool: A new Pool object. + """ + logger.info( + f"Creating pool with connection string: {connection_string}" + ) + # Call the Go library function to create a pool. + ret = get_lib().CreatePool(to_go_string(connection_string)) + check_error(ret, "CreatePool") + pool = cls(ret.object_id) + logger.info(f"Pool created with ID: {pool.id}") + return pool diff --git a/spannerlib/wrappers/spannerlib-python/google/cloud/spannerlib/rows.py b/spannerlib/wrappers/spannerlib-python/google/cloud/spannerlib/rows.py new file mode 100644 index 00000000..6d995fee --- /dev/null +++ b/spannerlib/wrappers/spannerlib-python/google/cloud/spannerlib/rows.py @@ -0,0 +1,183 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Module for the Rows class, representing the result of execute statement.""" + +import ctypes +import logging +from typing import Any + +from google.cloud.spanner_v1 import ResultSetMetadata, ResultSetStats +from google.protobuf.struct_pb2 import ListValue + +from google.cloud.spannerlib.internal.spannerlib import check_error, get_lib +from google.cloud.spannerlib.library_object import AbstractLibraryObject + +logger = logging.getLogger(__name__) + + +class Rows(AbstractLibraryObject): + """Represents the result of an executed SQL statement.""" + + def __init__(self, id: int, pool: Any, conn: Any): + super().__init__(id) + self._pool = pool + self._conn = conn + + def __enter__(self): + """Enter the runtime context related to this object.""" + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Exit the runtime context related to this object, ensuring the rows are closed.""" + self.close() + + @property + def pool(self): + """Returns the parent Pool object.""" + return self._pool + + @property + def conn(self): + """Returns the parent Connection object.""" + return self._conn + + def close(self): + """Closes the rows and releases resources in the Go library. + + If the Rows object is already closed, this method does nothing. + It also checks if the parent pool or conn is closed. + """ + if not self.closed: + if self.pool.closed: + logger.debug( + f"Rows ID: {self.id} implicitly closed because pool is closed." + ) + self.closed = True + return + if self.conn.closed: + logger.debug( + f"Rows ID: {self.id} implicitly closed because connection is closed." + ) + self.closed = True + return + + logger.info( + f"Closing rows ID: {self.id} for pool ID: {self.pool.id} and connection ID: {self.conn.id}" + ) + # Call the Go library function to close the connection. + ret = get_lib().CloseRows(self.pool.id, self.conn.id, self.id) + check_error(ret, "CloseRows") + self.closed = True + logger.info(f"Rows ID: {self.id} closed") + # Release the pinner ID in the Go library. + self._release() + + def metadata(self) -> ResultSetMetadata: + """Retrieves the metadata for the result set. + + Returns: + ResultSetMetadata object containing the metadata. + """ + if self.closed: + raise RuntimeError("Rows object is closed.") + + logger.debug(f"Getting metadata for Rows ID: {self.id}") + ret = get_lib().Metadata(self.pool.id, self.conn.id, self.id) + check_error(ret, "Metadata") + + if ret.msg_len > 0: + try: + proto_bytes = ctypes.string_at(ret.msg, ret.msg_len) + return ResultSetMetadata.deserialize(proto_bytes) + except Exception as e: + logger.error(f"Failed to decode/parse metadata JSON: {e}") + raise RuntimeError(f"Failed to get metadata: {e}") + return ResultSetMetadata() + + def next(self) -> ListValue: + """Fetches the next row(s) from the result set. + + Returns: + The fetched row(s), likely as a list of lists or list of dicts, + depending on the JSON structure returned by the Go layer. + Returns None if no more rows are available. + + Raises: + RuntimeError: If the Rows object is closed or if parsing fails. + SpannerLibraryError: If the Go library call fails. + """ + if self.closed: + raise RuntimeError("Rows object is closed.") + + logger.debug(f"Fetching next row for Rows ID: {self.id}") + ret = get_lib().Next( + self.pool.id, + self.conn.id, + self.id, + 1, + 1, + ) + check_error(ret, "Next") + + if ret.msg_len > 0 and ret.msg: + try: + proto_bytes = ctypes.string_at(ret.msg, ret.msg_len) + next_row = ListValue() + next_row.ParseFromString(proto_bytes) + return next_row + except Exception as e: + logger.error(f"Failed to decode/parse row data JSON: {e}") + raise RuntimeError(f"Failed to get next row(s): {e}") + else: + # Assuming no message means no more rows + logger.debug("No more rows...") + return None + + def result_set_stats(self) -> ResultSetStats: + """Retrieves the result set statistics. + + Returns: + ResultSetStats object containing the statistics. + """ + if self.closed: + raise RuntimeError("Rows object is closed.") + + logger.debug(f"Getting ResultSetStats for Rows ID: {self.id}") + ret = get_lib().ResultSetStats(self.pool.id, self.conn.id, self.id) + check_error(ret, "ResultSetStats") + + if ret.msg_len > 0: + try: + proto_bytes = ctypes.string_at(ret.msg, ret.msg_len) + return ResultSetStats.deserialize(proto_bytes) + except Exception as e: + logger.error(f"Failed to decode/parse ResultSetStats JSON: {e}") + raise RuntimeError(f"Failed to get ResultSetStats: {e}") + + return ResultSetStats() + + def update_count(self) -> int: + """Retrieves the update count. + + Returns: + int representing the update count. + """ + stats = self.result_set_stats() + if stats.row_count_exact: + return stats.row_count_exact + elif stats.row_count_lower_bound: + return stats.row_count_lower_bound + + return 0 diff --git a/spannerlib/wrappers/spannerlib-python/noxfile.py b/spannerlib/wrappers/spannerlib-python/noxfile.py new file mode 100644 index 00000000..5596d9f3 --- /dev/null +++ b/spannerlib/wrappers/spannerlib-python/noxfile.py @@ -0,0 +1,129 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from typing import List + +import nox + +DEFAULT_PYTHON_VERSION = "3.13" +PYTHON_VERSIONS = ["3.13"] + +UNIT_TEST_PYTHON_VERSIONS: List[str] = ["3.13"] +SYSTEM_TEST_PYTHON_VERSIONS: List[str] = ["3.13"] + + +FLAKE8_VERSION = "flake8==6.1.0" +BLACK_VERSION = "black[jupyter]==23.7.0" +ISORT_VERSION = "isort==5.11.0" +LINT_PATHS = ["google", "tests", "samples", "noxfile.py"] + +STANDARD_DEPENDENCIES = [ + "google-cloud-spanner", +] + +UNIT_TEST_STANDARD_DEPENDENCIES = [ + "mock", + "asyncmock", + "pytest", + "pytest-cov", + "pytest-asyncio", +] + +SYSTEM_TEST_STANDARD_DEPENDENCIES = [ + "pytest", +] + +VERBOSE = True +MODE = "--verbose" if VERBOSE else "--quiet" + +# Error if a python version is missing +nox.options.error_on_missing_interpreters = True + + +@nox.session(python=DEFAULT_PYTHON_VERSION) +def format(session): + """ + Run isort to sort imports. Then run black + to format code to uniform standard. + """ + session.install(BLACK_VERSION, ISORT_VERSION) + # Use the --fss option to sort imports using strict alphabetical order. + # See https://pycqa.github.io/isort/docs/configuration/options.html#force-sort-within-sections + session.run( + "isort", + "--fss", + *LINT_PATHS, + ) + session.run( + "black", + "--line-length=80", + *LINT_PATHS, + ) + + +@nox.session +def lint(session): + """Run linters. + + Returns a failure if the linters find linting errors or sufficiently + serious code quality issues. + """ + session.install(FLAKE8_VERSION, BLACK_VERSION) + session.install("black", "isort") + session.run( + "flake8", + "--max-line-length=124", + *LINT_PATHS, + ) + + +@nox.session(python=UNIT_TEST_PYTHON_VERSIONS) +def unit(session): + """Run unit tests.""" + + session.install(*STANDARD_DEPENDENCIES, *UNIT_TEST_STANDARD_DEPENDENCIES) + + # Run py.test against the unit tests. + session.run( + "py.test", + MODE, + f"--junitxml=unit_{session.python}_sponge_log.xml", + "--cov=google", + "--cov=tests/unit", + "--cov-append", + "--cov-config=.coveragerc", + "--cov-report=", + "--cov-fail-under=0", + os.path.join("tests", "unit"), + *session.posargs, + env={}, + ) + + +@nox.session(python=SYSTEM_TEST_PYTHON_VERSIONS) +def system(session): + """Run system tests.""" + + session.install(*STANDARD_DEPENDENCIES, *SYSTEM_TEST_STANDARD_DEPENDENCIES) + + # Run py.test against the unit tests. + session.run( + "py.test", + MODE, + f"--junitxml=system_{session.python}_sponge_log.xml", + os.path.join("tests", "system"), + *session.posargs, + env={}, + ) diff --git a/spannerlib/wrappers/spannerlib-python/pyproject.toml b/spannerlib/wrappers/spannerlib-python/pyproject.toml new file mode 100644 index 00000000..dada4b6a --- /dev/null +++ b/spannerlib/wrappers/spannerlib-python/pyproject.toml @@ -0,0 +1,36 @@ +[build-system] +requires = ["setuptools>=61.0"] +build-backend = "setuptools.build_meta" + +[project] +name = "spannerlib" +version = "0.1.0" +authors = [ + { name="Google LLC", email="googleapis-packages@google.com" }, +] +description = "A Python wrapper for the Go spannerlib" +readme = "README.md" +license = {text = "Apache License 2.0"} +requires-python = ">=3.8" +classifiers = [ + "Intended Audience :: Developers", + "Topic :: Software Development :: Libraries", + "Programming Language :: Python :: 3", + "License :: OSI Approved :: Apache Software License", + "Operating System :: OS Independent", +] +dependencies = [ + "google-cloud-spanner", +] + +[project.optional-dependencies] +dev = [ + "pytest", + "nox", +] + +[tool.setuptools] + +[tool.setuptools.packages.find] +where = ["."] +include = ["google*"] diff --git a/spannerlib/wrappers/spannerlib-python/samples/_helper.py b/spannerlib/wrappers/spannerlib-python/samples/_helper.py new file mode 100644 index 00000000..a3ccf649 --- /dev/null +++ b/spannerlib/wrappers/spannerlib-python/samples/_helper.py @@ -0,0 +1,98 @@ +#!/usr/bin/env python + +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +from google.cloud.spanner_v1 import ExecuteSqlRequest + +from google.cloud.spannerlib import SpannerLibError + + +def setup_env(): + # Set environment variable for Spanner Emulator if not set + if not os.environ.get("SPANNER_EMULATOR_HOST"): + os.environ["SPANNER_EMULATOR_HOST"] = "localhost:9010" + print( + f"Set SPANNER_EMULATOR_HOST to {os.environ['SPANNER_EMULATOR_HOST']}" + ) + + +def setup(conn): + print("\nSetting up the environment...") + try: + conn.execute(ExecuteSqlRequest(sql="DROP TABLE IF EXISTS Singers")) + print("Dropped existing Singers table.") + except SpannerLibError as e: + print(f"Error dropping table: {e}") + conn.execute( + ExecuteSqlRequest( + sql=( + "CREATE TABLE Singers " + "(SingerId INT64, FirstName STRING(1024), LastName STRING(1024)) " + "PRIMARY KEY (SingerId)" + ) + ) + ) + print("Created Singers table.") + + +def cleanup(conn): + print("\nCleaning up the environment...") + try: + conn.execute(ExecuteSqlRequest(sql="DROP TABLE IF EXISTS Singers")) + print("Dropped Singers table.") + except SpannerLibError as e: + print(f"Error dropping table: {e}") + + +def count_rows(rows) -> int: + """Counts the number of rows in the result set.""" + count = 0 + while rows.next() is not None: + count += 1 + return count + + +def format_results(metadata, rows_data): + """Formats the results as a table string.""" + if not metadata or not metadata.row_type or not metadata.row_type.fields: + return "No column information available." + + headers = [ + field.name if field.name else "-" for field in metadata.row_type.fields + ] + column_widths = [len(header) for header in headers] + + # Calculate maximum width for each column + for row in rows_data: + for i, value in enumerate(row): + column_widths[i] = max(column_widths[i], len(str(value))) + + header_line = " | ".join( + header.ljust(column_widths[i]) for i, header in enumerate(headers) + ) + separator_line = "-+-".join("-" * width for width in column_widths) + + table_lines = [header_line, separator_line] + + for row in rows_data: + row_values = [str(value) for value in row] + row_line = " | ".join( + value.ljust(column_widths[i]) for i, value in enumerate(row_values) + ) + table_lines.append(row_line) + + return "\n".join(table_lines) diff --git a/spannerlib/wrappers/spannerlib-python/samples/ddl_stats.py b/spannerlib/wrappers/spannerlib-python/samples/ddl_stats.py new file mode 100644 index 00000000..8a623b92 --- /dev/null +++ b/spannerlib/wrappers/spannerlib-python/samples/ddl_stats.py @@ -0,0 +1,96 @@ +#!/usr/bin/env python + +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from _helper import cleanup, setup, setup_env +from google.cloud.spanner_v1 import ExecuteSqlRequest # noqa: E402 + +from google.cloud.spannerlib import Pool, SpannerLibError # noqa: E402 + +EMULATOR_TEST_CONNECTION_STRING = ( + "localhost:9010" + "/projects/test-project" + "/instances/test-instance" + "/databases/testdb" + "?autoConfigEmulator=true" +) + + +def run_dml_stats_sample(test_connection_string): + try: + pool = Pool.create_pool(test_connection_string) + print(f"Successfully created pool with ID: {pool.id}") + + print("Attempting to create a connection from the pool...") + conn = None # Initialize conn to None + try: + conn = pool.create_connection() + print(f"Successfully created connection with ID: {conn.id}") + + setup(conn) + + print("\nAttempting to execute an INSERT statement...") + insert_sql = "INSERT INTO Singers (SingerId, FirstName, LastName) VALUES (1, 'Marc', 'Richards')" + print(f"Executing SQL: {insert_sql}") + insert_request = ExecuteSqlRequest(sql=insert_sql) + insert_rows = conn.execute(insert_request) + stats = insert_rows.result_set_stats() + if stats: + print(f"Insert count: {stats.row_count_exact}") + insert_rows.close() + + print("\nAttempting to execute an UPDATE statement...") + update_sql = ( + "UPDATE Singers SET LastName = 'Richardson' WHERE SingerId = 1" + ) + print(f"Executing SQL: {update_sql}") + update_request = ExecuteSqlRequest(sql=update_sql) + update_rows = conn.execute(update_request) + stats = update_rows.result_set_stats() + if stats: + print(f"Update count: {stats.row_count_exact}") + update_rows.close() + + print("\nAttempting to execute a DELETE statement...") + delete_sql = "DELETE FROM Singers WHERE SingerId = 1" # noqa: E501 + print(f"Executing SQL: {delete_sql}") + delete_request = ExecuteSqlRequest(sql=delete_sql) + delete_rows = conn.execute(delete_request) + stats = delete_rows.result_set_stats() + if stats: + print(f"Delete count: {stats.row_count_exact}") + delete_rows.close() + + except SpannerLibError as e: + print(f"Error during DML operations: {e}") + finally: + if conn: + cleanup(conn) + conn.close() + print("\nConnection closed.") + + print("Closing the pool...") + pool.close() + print("Pool closed.") + + except SpannerLibError as e: + print(f"Error creating pool: {e}") + except Exception as e: + print(f"An unexpected error occurred: {e}") + + +if __name__ == "__main__": + setup_env() + run_dml_stats_sample(EMULATOR_TEST_CONNECTION_STRING) diff --git a/spannerlib/wrappers/spannerlib-python/samples/execute_batch.py b/spannerlib/wrappers/spannerlib-python/samples/execute_batch.py new file mode 100644 index 00000000..8824aea8 --- /dev/null +++ b/spannerlib/wrappers/spannerlib-python/samples/execute_batch.py @@ -0,0 +1,131 @@ +#!/usr/bin/env python + +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from _helper import cleanup, count_rows, setup, setup_env +from google.cloud.spanner_v1 import ExecuteBatchDmlRequest, ExecuteSqlRequest + +from google.cloud.spannerlib import Pool, SpannerLibError + +EMULATOR_TEST_CONNECTION_STRING = ( + "localhost:9010" + "/projects/test-project" + "/instances/test-instance" + "/databases/testdb" + "?autoConfigEmulator=true" +) + + +def run_execute_batch_sample(test_connection_string): + try: + pool = Pool.create_pool(test_connection_string) + print(f"Successfully created pool with ID: {pool.id}") + + print("Attempting to create a connection from the pool...") + conn = None # Initialize conn to None + try: + conn = pool.create_connection() + print(f"Successfully created connection with ID: {conn.id}") + + setup(conn) + + print("\nAttempting to run ExecuteBatchDmlRequest...") + try: + statements = [ + ExecuteBatchDmlRequest.Statement( + sql="INSERT INTO Singers (SingerId, FirstName, LastName) VALUES (100, 'Batch', 'User1')" + ), + ExecuteBatchDmlRequest.Statement( + sql="INSERT INTO Singers (SingerId, FirstName, LastName) VALUES (101, 'Batch', 'User2')" + ), + ] + request = ExecuteBatchDmlRequest(statements=statements) + + response = conn.execute_batch(request) + print(f"ExecuteBatchDmlResponse: {response}") + + if response.status.code == 0: + print("Batch DML executed successfully.") + for i, result_set in enumerate(response.result_sets): + print( + f" Statement {i}: Rows affected: {result_set.stats.row_count_exact}" + ) + else: + print(f"Batch DML failed with status: {response.status}") + + except SpannerLibError as e: + print(f"Error during execute_batch: {e}") + + # Verify + rows = conn.execute( + ExecuteSqlRequest( + sql="SELECT * FROM Singers WHERE SingerId >= 100" + ) + ) + print( + "\nVerification Query: SELECT * FROM Singers WHERE SingerId >= 100" + ) + print(f"Found {count_rows(rows)} rows.") + rows.close() + + print("\nAttempting to run ExecuteBatchDmlRequest with an error...") + try: + statements = [ + ExecuteBatchDmlRequest.Statement( + sql="INSERT INTO Singers (SingerId, FirstName, LastName) VALUES (200, 'Good', 'Batch')" + ), + ExecuteBatchDmlRequest.Statement( + sql="INSERT INTO NonExistentTable (Id) VALUES (1)" # This will fail + ), + ] + request = ExecuteBatchDmlRequest(statements=statements) + + conn.execute_batch(request) + except SpannerLibError as e: + print(f"Error during execute_batch as expected: {e}") + + # Verify the transaction was rolled back + rows = conn.execute( + ExecuteSqlRequest( + sql="SELECT * FROM Singers WHERE SingerId >= 200" + ) + ) + print( + "\nVerification Query: SELECT * FROM Singers WHERE SingerId >= 200" + ) + print(f"Found {count_rows(rows)} rows (should be 0).") + rows.close() + + except SpannerLibError as e: + print(f"Error creating or using connection: {e}") + finally: + if conn: + cleanup(conn) + conn.close() + print("\nConnection closed.") + + print("Closing the pool...") + pool.close() + print("Pool closed.") + + except SpannerLibError as e: + print(f"Error creating pool: {e}") + except Exception as e: + print(f"An unexpected error occurred: {e}") + + +if __name__ == "__main__": + setup_env() + run_execute_batch_sample(EMULATOR_TEST_CONNECTION_STRING) diff --git a/spannerlib/wrappers/spannerlib-python/samples/quickstart.py b/spannerlib/wrappers/spannerlib-python/samples/quickstart.py new file mode 100644 index 00000000..0eeec218 --- /dev/null +++ b/spannerlib/wrappers/spannerlib-python/samples/quickstart.py @@ -0,0 +1,90 @@ +#!/usr/bin/env python + +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +from _helper import format_results +from google.cloud.spanner_v1 import ExecuteSqlRequest # noqa: E402 + +from google.cloud.spannerlib import Pool, SpannerLibError # noqa: E402 + +EMULATOR_TEST_CONNECTION_STRING = ( + "localhost:9010" + "/projects/test-project" + "/instances/test-instance" + "/databases/testdb" + "?autoConfigEmulator=true" +) + + +def setup_env(): + # Set environment variable for Spanner Emulator if not set + if not os.environ.get("SPANNER_EMULATOR_HOST"): + os.environ["SPANNER_EMULATOR_HOST"] = "localhost:9010" + print( + f"Set SPANNER_EMULATOR_HOST to {os.environ['SPANNER_EMULATOR_HOST']}" + ) + + +def run_quickstart(test_connection_string): + try: + pool = Pool.create_pool(test_connection_string) + print(f"Successfully created pool with ID: {pool.id}") + + print("Attempting to create a connection from the pool...") + try: + with pool.create_connection() as conn: + print(f"Successfully created connection with ID: {conn.id}") + print("Connection test successful!") + print("Attempting to execute a statement on the connection...") + + sql = "SELECT 1 as one, 'hello' as greeting;" + print(f"Executing SQL: {sql}") + request = ExecuteSqlRequest(sql=sql) + + rows = conn.execute(request) + + metadata = rows.metadata() + + rows_data = [] + row = rows.next() + while row is not None: + rows_data.append(row) + row = rows.next() + + if rows_data: + print("\nResults:") + print(format_results(metadata, rows_data)) + print("") + else: + print("No rows returned.") + + rows.close() + except SpannerLibError as e: + print(f"Error creating or using connection: {e}") + + print("Closing the pool...") + pool.close() + print("Pool closed.") + + except SpannerLibError as e: + print(f"Error creating pool: {e}") + except Exception as e: + print(f"An unexpected error occurred: {e}") + + +if __name__ == "__main__": + run_quickstart(EMULATOR_TEST_CONNECTION_STRING) diff --git a/spannerlib/wrappers/spannerlib-python/samples/requirements.txt b/spannerlib/wrappers/spannerlib-python/samples/requirements.txt new file mode 100644 index 00000000..58cf3064 --- /dev/null +++ b/spannerlib/wrappers/spannerlib-python/samples/requirements.txt @@ -0,0 +1,2 @@ +google-cloud-spanner==3.57.0 +futures==3.4.0; python_version < "3" diff --git a/spannerlib/wrappers/spannerlib-python/samples/transactions.py b/spannerlib/wrappers/spannerlib-python/samples/transactions.py new file mode 100644 index 00000000..684c92d8 --- /dev/null +++ b/spannerlib/wrappers/spannerlib-python/samples/transactions.py @@ -0,0 +1,111 @@ +#!/usr/bin/env python + +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from _helper import cleanup, count_rows, setup, setup_env +from google.cloud.spanner_v1 import ExecuteSqlRequest + +from google.cloud.spannerlib import Pool, SpannerLibError + +EMULATOR_TEST_CONNECTION_STRING = ( + "localhost:9010" + "/projects/test-project" + "/instances/test-instance" + "/databases/testdb" + "?autoConfigEmulator=true" +) + + +def run_transaction_sample(test_connection_string): + try: + pool = Pool.create_pool(test_connection_string) + print(f"Successfully created pool with ID: {pool.id}") + + print("Attempting to create a connection from the pool...") + conn = None # Initialize conn to None + try: + conn = pool.create_connection() + print(f"Successfully created connection with ID: {conn.id}") + + setup(conn) + + print("\nAttempting to run a transaction | BEGIN/ROLLBACK...") + try: + conn.begin_transaction() + print("Transaction started.") + conn.execute( + ExecuteSqlRequest( + sql="INSERT INTO Singers (SingerId, FirstName, LastName) VALUES (1, 'Catalina', 'Smith')" + ) + ) + conn.rollback() + print("Transaction rolled back.") + except SpannerLibError as e: + print(f"Error during transaction: {e}") + + # Verify the transaction + rows = conn.execute(ExecuteSqlRequest(sql="SELECT * FROM Singers")) + print("\nVerification Query: SELECT * FROM Singers") + print(f"Found {count_rows(rows)} rows.") + rows.close() + + print("\nAttempting to run a transaction | BEGIN/COMMIT...") + try: + conn.begin_transaction() + print("Transaction started.") + conn.execute( + ExecuteSqlRequest( + sql="INSERT INTO Singers (SingerId, FirstName, LastName) VALUES (2, 'Catalina', 'Smith')" + ) + ) + print("Executed first INSERT.") + conn.execute( + ExecuteSqlRequest( + sql="INSERT INTO Singers (SingerId, FirstName, LastName) VALUES (3, 'Alice', 'Trentor')" + ) + ) + print("Executed second INSERT.") + commit_response = conn.commit() + print(f"Transaction committed successfully: {commit_response}") + except SpannerLibError as e: + print(f"Error during transaction: {e}") + + # Verify the transaction + rows = conn.execute(ExecuteSqlRequest(sql="SELECT * FROM Singers")) + print("\nVerification Query: SELECT * FROM Singers") + print(f"Found {count_rows(rows)} rows.") + rows.close() + + except SpannerLibError as e: + print(f"Error creating or using connection: {e}") + finally: + if conn: + cleanup(conn) + conn.close() + print("\nConnection closed.") + + print("Closing the pool...") + pool.close() + print("Pool closed.") + + except SpannerLibError as e: + print(f"Error creating pool: {e}") + except Exception as e: + print(f"An unexpected error occurred: {e}") + + +if __name__ == "__main__": + setup_env() + run_transaction_sample(EMULATOR_TEST_CONNECTION_STRING) diff --git a/spannerlib/wrappers/spannerlib-python/samples/write_mutation.py b/spannerlib/wrappers/spannerlib-python/samples/write_mutation.py new file mode 100644 index 00000000..059cf288 --- /dev/null +++ b/spannerlib/wrappers/spannerlib-python/samples/write_mutation.py @@ -0,0 +1,108 @@ +#!/usr/bin/env python + +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from _helper import cleanup, count_rows, setup, setup_env +from google.cloud.spanner_v1 import ( + BatchWriteRequest, + ExecuteSqlRequest, + Mutation, +) +from google.protobuf.struct_pb2 import ListValue, Value + +from google.cloud.spannerlib import Pool, SpannerLibError + +EMULATOR_TEST_CONNECTION_STRING = ( + "localhost:9010" + "/projects/test-project" + "/instances/test-instance" + "/databases/testdb" + "?autoConfigEmulator=true" +) + + +def run_write_mutation_sample(test_connection_string): + try: + pool = Pool.create_pool(test_connection_string) + print(f"Successfully created pool with ID: {pool.id}") + + print("Attempting to create a connection from the pool...") + conn = None # Initialize conn to None + try: + conn = pool.create_connection() + print(f"Successfully created connection with ID: {conn.id}") + + setup(conn) + + print("\nAttempting to run write_mutation...") + try: + mutation = Mutation( + insert=Mutation.Write( + table="Singers", + columns=["SingerId", "FirstName", "LastName"], + values=[ + ListValue( + values=[ + Value(string_value="200"), + Value(string_value="Mutation"), + Value(string_value="User1"), + ] + ) + ], + ) + ) + mutation_group = BatchWriteRequest.MutationGroup( + mutations=[mutation] + ) + + response = conn.write_mutations(mutation_group) + print(f"write_mutations response: {response}") + + except SpannerLibError as e: + print(f"Error during write_mutations: {e}") + + # Verify + rows = conn.execute( + ExecuteSqlRequest( + sql="SELECT * FROM Singers WHERE SingerId >= 200" + ) + ) + print( + "\nVerification Query: SELECT * FROM Singers WHERE SingerId >= 200" + ) + print(f"Found {count_rows(rows)} rows.") + rows.close() + + except SpannerLibError as e: + print(f"Error creating or using connection: {e}") + finally: + if conn: + cleanup(conn) + conn.close() + print("\nConnection closed.") + + print("Closing the pool...") + pool.close() + print("Pool closed.") + + except SpannerLibError as e: + print(f"Error creating pool: {e}") + except Exception as e: + print(f"An unexpected error occurred: {e}") + + +if __name__ == "__main__": + setup_env() + run_write_mutation_sample(EMULATOR_TEST_CONNECTION_STRING) diff --git a/spannerlib/wrappers/spannerlib-python/setup.cfg b/spannerlib/wrappers/spannerlib-python/setup.cfg new file mode 100644 index 00000000..42af1cfe --- /dev/null +++ b/spannerlib/wrappers/spannerlib-python/setup.cfg @@ -0,0 +1,2 @@ +[options] +include_package_data = True diff --git a/spannerlib/wrappers/spannerlib-python/tests/__init__.py b/spannerlib/wrappers/spannerlib-python/tests/__init__.py new file mode 100644 index 00000000..38e805ce --- /dev/null +++ b/spannerlib/wrappers/spannerlib-python/tests/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/spannerlib/wrappers/spannerlib-python/tests/system/__init__.py b/spannerlib/wrappers/spannerlib-python/tests/system/__init__.py new file mode 100644 index 00000000..38e805ce --- /dev/null +++ b/spannerlib/wrappers/spannerlib-python/tests/system/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/spannerlib/wrappers/spannerlib-python/tests/system/_helper.py b/spannerlib/wrappers/spannerlib-python/tests/system/_helper.py new file mode 100644 index 00000000..4783d6fb --- /dev/null +++ b/spannerlib/wrappers/spannerlib-python/tests/system/_helper.py @@ -0,0 +1,52 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +TEST_ON_PROD = False + +EMULATOR_TEST_CONNECTION_STRING = ( + "localhost:9010" + "/projects/test-project" + "/instances/test-instance" + "/databases/testdb" + "?autoConfigEmulator=true" +) +PROD_TEST_CONNECTION_STRING = ( + "projects/span-cloud-testing/instances/asapha-test/databases/testdb" +) + +TEST_CONNECTION_STRING = ( + PROD_TEST_CONNECTION_STRING + if TEST_ON_PROD + else EMULATOR_TEST_CONNECTION_STRING +) + + +def setup_test_env(): + if not TEST_ON_PROD: + # Set environment variable for Spanner Emulator + os.environ["SPANNER_EMULATOR_HOST"] = "localhost:9010" + print( + f"Set SPANNER_EMULATOR_HOST to {os.environ['SPANNER_EMULATOR_HOST']}" + ) + print(f"Using Connection String: {get_test_connection_string()}") + + +def get_test_connection_string(): + return ( + PROD_TEST_CONNECTION_STRING + if TEST_ON_PROD + else EMULATOR_TEST_CONNECTION_STRING + ) diff --git a/spannerlib/wrappers/spannerlib-python/tests/system/test_connection.py b/spannerlib/wrappers/spannerlib-python/tests/system/test_connection.py new file mode 100644 index 00000000..f1c2b6a4 --- /dev/null +++ b/spannerlib/wrappers/spannerlib-python/tests/system/test_connection.py @@ -0,0 +1,232 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import absolute_import + +import os +import sys +import unittest + +from ._helper import get_test_connection_string, setup_test_env + +# Adjust path to import from src +sys.path.insert( + 0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../..")) +) + +from google.cloud.spanner_v1 import BatchWriteRequest # noqa: E402 +from google.cloud.spanner_v1 import CommitResponse # noqa: E402 +from google.cloud.spanner_v1 import ExecuteBatchDmlRequest # noqa: E402 +from google.cloud.spanner_v1 import ExecuteSqlRequest # noqa: E402 +from google.cloud.spanner_v1 import Mutation # noqa: E402 +from google.protobuf.struct_pb2 import ListValue, Value # noqa: E402 + +from google.cloud.spannerlib import Pool, SpannerLibError # noqa: E402 +from google.cloud.spannerlib.rows import Rows # noqa: E402 + + +class TestConnectionE2E(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_test_env() + print(f"Using Connection String: {get_test_connection_string()}") + try: + with Pool.create_pool(get_test_connection_string()) as pool: + with pool.create_connection() as conn: + try: + conn.execute( + ExecuteSqlRequest(sql="DROP TABLE test_table") + ) + except Exception: + pass # Ignore error if table doesn't exist + conn.execute( + ExecuteSqlRequest( + sql="CREATE TABLE test_table (id INT64, name STRING(MAX)) PRIMARY KEY (id)" + ) + ) + except Exception as e: + print(f"Error in setUpClass: {e}") + raise + + def setUp(self): + self.pool = Pool.create_pool(get_test_connection_string()) + self.conn = self.pool.create_connection() + # Clean up the table before each test + try: + self.conn.execute( + ExecuteSqlRequest(sql="DELETE FROM test_table WHERE TRUE") + ) + except Exception as e: + print(f"Error in setUp: {e}") + raise + + def tearDown(self): + if self.conn: + self.conn.close() + if self.pool: + self.pool.close() + + def test_execute_query(self): + """Test ExecuteSqlRequest with a SELECT statement.""" + request = ExecuteSqlRequest(sql="SELECT 1") + rows = self.conn.execute(request) + self.assertIsInstance(rows, Rows) + row = rows.next() + self.assertIsNotNone(row) + self.assertEqual(row.values[0].string_value, "1") + self.assertIsNone(rows.next()) # No more rows + rows.close() + + def test_transaction_commit(self): + """Test begin_transaction, execute, and commit.""" + self.conn.begin_transaction() + + insert_request = ExecuteSqlRequest( + sql="INSERT INTO test_table (id, name) VALUES (1, 'Test User')" + ) + self.conn.execute(insert_request) + + commit_response = self.conn.commit() + self.assertIsInstance(commit_response, CommitResponse) + + # Verify the insert + select_request = ExecuteSqlRequest( + sql="SELECT name FROM test_table WHERE id = 1" + ) + rows = self.conn.execute(select_request) + row = rows.next() + self.assertEqual(row.values[0].string_value, "Test User") + rows.close() + + def test_transaction_rollback(self): + """Test begin_transaction, execute, and rollback.""" + self.conn.begin_transaction() + + insert_request = ExecuteSqlRequest( + sql="INSERT INTO test_table (id, name) VALUES (2, 'Rollback User')" + ) + self.conn.execute(insert_request) + + self.conn.rollback() + + # Verify the insert was rolled back + select_request = ExecuteSqlRequest( + sql="SELECT name FROM test_table WHERE id = 2" + ) + rows = self.conn.execute(select_request) + self.assertIsNone(rows.next()) + rows.close() + + def test_execute_batch_dml(self): + """Test ExecuteBatchDmlRequest with INSERT statements.""" + statements = [ + ExecuteBatchDmlRequest.Statement( + sql="INSERT INTO test_table (id, name) VALUES (10, 'Batch User 1')" + ), + ExecuteBatchDmlRequest.Statement( + sql="INSERT INTO test_table (id, name) VALUES (11, 'Batch User 2')" + ), + ] + request = ExecuteBatchDmlRequest(statements=statements) + + response = self.conn.execute_batch(request) + self.assertIsNotNone(response) + self.assertEqual(len(response.result_sets), 2) + self.assertEqual(response.status.code, 0) # OK + + for i, result_set in enumerate(response.result_sets): + self.assertEqual(result_set.stats.row_count_exact, 1) + + # Verify the inserts + select_request = ExecuteSqlRequest( + sql="SELECT name FROM test_table WHERE id IN (10, 11) ORDER BY id" + ) + rows = self.conn.execute(select_request) + row = rows.next() + self.assertEqual(row.values[0].string_value, "Batch User 1") + row = rows.next() + self.assertEqual(row.values[0].string_value, "Batch User 2") + self.assertIsNone(rows.next()) + rows.close() + + def test_execute_batch_dml_with_error(self): + """Test ExecuteBatchDmlRequest with a statement that causes an error.""" + statements = [ + ExecuteBatchDmlRequest.Statement( + sql="INSERT INTO test_table (id, name) VALUES (20, 'Good Batch')" + ), + ExecuteBatchDmlRequest.Statement( + sql="INSERT INTO non_existent_table (id, name) VALUES (21, 'Bad Batch')" + ), + ] + request = ExecuteBatchDmlRequest(statements=statements) + + with self.assertRaises(SpannerLibError) as cm: + self.conn.execute_batch(request) + + self.assertIn("non_existent_table", str(cm.exception)) + + # The first statement should have been rolled back + select_request = ExecuteSqlRequest( + sql="SELECT name FROM test_table WHERE id = 20" + ) + rows = self.conn.execute(select_request) + self.assertIsNone(rows.next()) + rows.close() + + def test_write_mutations(self): + """Test write_mutation with an INSERT statement.""" + mutation = Mutation( + insert=Mutation.Write( + table="test_table", + columns=["id", "name"], + values=[ + ListValue( + values=[ + Value(string_value="30"), + Value(string_value="Mutation User"), + ] + ) + ], + ) + ) + mutation_group = BatchWriteRequest.MutationGroup(mutations=[mutation]) + + response = self.conn.write_mutations(mutation_group) + self.assertIsInstance(response, CommitResponse) + self.assertIsNotNone(response.commit_timestamp) + + # Verify the insert + select_request = ExecuteSqlRequest( + sql="SELECT name FROM test_table WHERE id = 30" + ) + rows = self.conn.execute(select_request) + row = rows.next() + self.assertEqual(row.values[0].string_value, "Mutation User") + rows.close() + + def test_write_mutations_empty(self): + """Test write_mutation with an empty mutation group.""" + mutation_group = BatchWriteRequest.MutationGroup(mutations=[]) + # This should not raise an error, but the behavior might be a no-op. + # Depending on the Go library, this might return an error. + # For now, just check if it runs without exceptions. + try: + response = self.conn.write_mutations(mutation_group) + self.assertIsInstance(response, CommitResponse) + except SpannerLibError as e: + self.fail(f"write_mutations with empty mutations failed: {e}") + + +if __name__ == "__main__": + unittest.main() diff --git a/spannerlib/wrappers/spannerlib-python/tests/system/test_pool.py b/spannerlib/wrappers/spannerlib-python/tests/system/test_pool.py new file mode 100644 index 00000000..c1ded49f --- /dev/null +++ b/spannerlib/wrappers/spannerlib-python/tests/system/test_pool.py @@ -0,0 +1,102 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import absolute_import + +import os +import sys +import unittest + +from ._helper import get_test_connection_string, setup_test_env + +# Adjust path to import from src +sys.path.insert( + 0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../..")) +) + +from google.cloud.spannerlib import Connection, Pool # noqa: E402 +from google.cloud.spannerlib import SpannerLibError # noqa: E402 + +# To run these E2E tests against a Cloud Spanner Emulator: +# 1. Start the emulator: gcloud emulators spanner start +# docker pull gcr.io/cloud-spanner-emulator/emulator +# docker run -p 9010:9010 -p 9020:9020 -d gcr.io/cloud-spanner-emulator/emulator +# 2. Set the environment variable: export SPANNER_EMULATOR_HOST=localhost:9010 +# 3. Create a test instance and database in the emulator: +# gcloud spanner instances create test-instance --config=emulator-config --description="Test Instance" --nodes=1 +# gcloud spanner databases create testdb --instance=test-instance +# 4. Run the tests: python3 -m unittest src/tests/e2e/test_spannerlib_wrapper.py +# +# You can also override the connection string by setting SPANNER_CONNECTION_STRING. + + +class TestPoolE2E(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_test_env() + print(f"Using Connection String: {get_test_connection_string()}") + + def test_pool_creation_and_close(self): + """Test basic pool creation and explicit close.""" + pool = Pool.create_pool(get_test_connection_string()) + self.assertIsNotNone(pool.id, "Pool ID should not be None") + self.assertFalse(pool.closed, "Pool should not be closed initially") + pool.close() + self.assertTrue(pool.closed, "Pool should be closed") + # Test closing again is safe + pool.close() + self.assertTrue(pool.closed, "Pool should remain closed") + + def test_pool_context_manager(self): + """Test pool creation and closure using a context manager.""" + with Pool.create_pool(get_test_connection_string()) as pool: + self.assertIsNotNone(pool.id) + self.assertFalse(pool.closed) + self.assertTrue( + pool.closed, "Pool should be closed after exiting with block" + ) + + def test_create_connection_success(self): + """Test creating a connection from an open pool.""" + with Pool.create_pool(get_test_connection_string()) as pool: + conn = pool.create_connection() + self.assertIsInstance(conn, Connection) + self.assertIsNotNone(conn.id) + self.assertFalse(conn.closed) + conn.close() + self.assertTrue(conn.closed) + + def test_create_connection_context_manager(self): + """Test creating a connection using a context manager.""" + with Pool.create_pool(get_test_connection_string()) as pool: + with pool.create_connection() as conn: + self.assertIsInstance(conn, Connection) + self.assertFalse(conn.closed) + self.assertTrue( + conn.closed, + "Connection should be closed after exiting with block", + ) + + def test_create_connection_from_closed_pool(self): + """Test creating a connection from a closed pool raises an error.""" + pool = Pool.create_pool(get_test_connection_string()) + pool.close() + with self.assertRaises(SpannerLibError): + pool.create_connection() + + +if __name__ == "__main__": + print( + "Running Pool E2E tests... This requires a live Spanner instance or Emulator." + ) + unittest.main() diff --git a/spannerlib/wrappers/spannerlib-python/tests/system/test_rows.py b/spannerlib/wrappers/spannerlib-python/tests/system/test_rows.py new file mode 100644 index 00000000..8e8a6437 --- /dev/null +++ b/spannerlib/wrappers/spannerlib-python/tests/system/test_rows.py @@ -0,0 +1,148 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import absolute_import + +import os +import sys +import unittest + +from ._helper import get_test_connection_string, setup_test_env + +# Adjust path to import from src +sys.path.insert( + 0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../..")) +) + +from google.cloud.spanner_v1 import ExecuteSqlRequest # noqa: E402 +from google.cloud.spanner_v1 import ResultSetMetadata # noqa: E402 + +from google.cloud.spannerlib import Pool # noqa: E402 +from google.cloud.spannerlib.rows import Rows # noqa: E402 + + +class TestRowsE2E(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_test_env() + print(f"Using Connection String: {get_test_connection_string()}") + try: + with Pool.create_pool(get_test_connection_string()) as pool: + with pool.create_connection() as conn: + try: + conn.execute( + ExecuteSqlRequest(sql="DROP TABLE rows_test") + ) + except Exception: + pass # Ignore error if table doesn't exist + conn.execute( + ExecuteSqlRequest( + sql="CREATE TABLE rows_test (id INT64, name STRING(MAX)) PRIMARY KEY (id)" + ) + ) + conn.execute( + ExecuteSqlRequest( + sql="INSERT INTO rows_test (id, name) VALUES (1, 'One'), (2, 'Two')" + ) + ) + except Exception as e: + print(f"Error in setUpClass: {e}") + raise + + def setUp(self): + self.pool = Pool.create_pool(get_test_connection_string()) + self.conn = self.pool.create_connection() + + def tearDown(self): + if self.conn: + self.conn.close() + if self.pool: + self.pool.close() + + def test_rows_iteration(self): + """Test iterating through rows using next().""" + with self.conn.execute( + ExecuteSqlRequest(sql="SELECT id, name FROM rows_test ORDER BY id") + ) as rows: + self.assertIsInstance(rows, Rows) + + row1 = rows.next() + self.assertIsNotNone(row1) + self.assertEqual(row1.values[0].string_value, "1") + self.assertEqual(row1.values[1].string_value, "One") + + row2 = rows.next() + self.assertIsNotNone(row2) + self.assertEqual(row2.values[0].string_value, "2") + self.assertEqual(row2.values[1].string_value, "Two") + + self.assertIsNone(rows.next()) # No more rows + + def test_rows_metadata(self): + """Test retrieving metadata.""" + with self.conn.execute( + ExecuteSqlRequest(sql="SELECT id, name FROM rows_test") + ) as rows: + metadata = rows.metadata() + self.assertIsInstance(metadata, ResultSetMetadata) + self.assertEqual(len(metadata.row_type.fields), 2) + self.assertEqual(metadata.row_type.fields[0].name, "id") + self.assertEqual(metadata.row_type.fields[0].type.code, 2) # INT64 + self.assertEqual(metadata.row_type.fields[1].name, "name") + self.assertEqual(metadata.row_type.fields[1].type.code, 6) # STRING + + def test_rows_context_manager(self): + """Test that rows are closed when exiting context manager.""" + rows = self.conn.execute(ExecuteSqlRequest(sql="SELECT 1")) + with rows: + self.assertFalse(rows.closed) + self.assertTrue(rows.closed) + + def test_rows_stats_select(self): + """Test ResultSetStats for a SELECT statement.""" + with self.conn.execute( + ExecuteSqlRequest(sql="SELECT id, name FROM rows_test") + ) as rows: + stats = rows.result_set_stats() + # Stats are not typically populated for SELECT in the same way as DML + self.assertIsNotNone(stats) + + def test_ddl_update_count(self): + """Test update_count for DDL.""" + with self.conn.execute( + ExecuteSqlRequest( + sql="CREATE TABLE dummy (id INT64) PRIMARY KEY (id)" + ) + ) as rows: + self.assertEqual(rows.update_count(), 0) + with self.conn.execute( + ExecuteSqlRequest(sql="DROP TABLE dummy") + ) as rows: + self.assertEqual(rows.update_count(), 0) + + def test_dml_update_count(self): + """Test update_count for DML.""" + with self.conn.execute( + ExecuteSqlRequest( + sql="INSERT INTO rows_test (id, name) VALUES (100, 'DML Test')" + ) + ) as rows: + self.assertEqual(rows.update_count(), 1) + # Clean up + self.conn.execute( + ExecuteSqlRequest(sql="DELETE FROM rows_test WHERE id = 100") + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/spannerlib/wrappers/spannerlib-python/tests/unit/__init__.py b/spannerlib/wrappers/spannerlib-python/tests/unit/__init__.py new file mode 100644 index 00000000..38e805ce --- /dev/null +++ b/spannerlib/wrappers/spannerlib-python/tests/unit/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/spannerlib/wrappers/spannerlib-python/tests/unit/test_connection.py b/spannerlib/wrappers/spannerlib-python/tests/unit/test_connection.py new file mode 100644 index 00000000..4382ead9 --- /dev/null +++ b/spannerlib/wrappers/spannerlib-python/tests/unit/test_connection.py @@ -0,0 +1,303 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for the Connection class.""" +from __future__ import absolute_import + +import ctypes +import unittest +from unittest.mock import MagicMock, patch + +from google.cloud.spanner_v1 import ( + BatchWriteRequest, + CommitResponse, + ExecuteBatchDmlRequest, + ExecuteSqlRequest, +) + +from google.cloud.spannerlib import Connection, Rows, SpannerLibError +from google.cloud.spannerlib.internal import GoReturn + + +class TestConnection(unittest.TestCase): + """Unit tests for the Connection class.""" + + def setUp(self): + """Set up the test environment.""" + self.mock_pool = MagicMock() + self.mock_pool.id = 1 + self.mock_pool.closed = False + + self.conn = Connection(id=123, pool=self.mock_pool) + self.mock_lib = MagicMock() + + def tearDown(self): + """Tear down the test environment.""" + if not self.conn.closed: + try: + # Minimal mock to avoid errors in close + with patch( + "google.cloud.spannerlib.connection.get_lib" + ) as mock_get_lib: + mock_lib = MagicMock() + mock_get_lib.return_value = mock_lib + mock_lib.CloseConnection.return_value = GoReturn( + pinner_id=0, + error_code=0, + object_id=0, + msg_len=0, + msg=None, + ) + self.conn.close() + except SpannerLibError: + pass + + @patch("google.cloud.spannerlib.connection.get_lib") + def test_close_success(self, mock_get_lib): + """Test the close method in case of success.""" + mock_get_lib.return_value = self.mock_lib + self.mock_lib.CloseConnection.return_value = GoReturn( + pinner_id=0, error_code=0, object_id=0, msg_len=0, msg=None + ) + + self.conn.close() + self.mock_lib.CloseConnection.assert_called_once_with(1, 123) + self.assertTrue(self.conn.closed) + + @patch("google.cloud.spannerlib.connection.get_lib") + def test_execute_success(self, mock_get_lib): + """Test the execute method in case of success.""" + mock_get_lib.return_value = self.mock_lib + self.mock_lib.Execute.return_value = GoReturn( + pinner_id=789, error_code=0, object_id=101, msg_len=0, msg=None + ) + + sql = "SELECT 1" + request = ExecuteSqlRequest(sql=sql) + rows = self.conn.execute(request) + rows.close = MagicMock() # Prevent __del__ from calling mock lib + + self.assertIsInstance(rows, Rows) + self.assertEqual(rows.id, 101) + self.assertEqual(rows._pool, self.mock_pool) + self.assertEqual(rows._conn, self.conn) + self.mock_lib.Execute.assert_called_once() + + @patch("google.cloud.spannerlib.connection.get_lib") + def test_execute_failure(self, mock_get_lib): + """Test the execute method in case of failure.""" + mock_get_lib.return_value = self.mock_lib + self.mock_lib.Execute.return_value = GoReturn( + pinner_id=0, + error_code=1, + object_id=0, + msg_len=13, + msg=ctypes.cast(ctypes.c_char_p(b"Test error"), ctypes.c_void_p), + ) + + sql = "SELECT 1" + request = ExecuteSqlRequest(sql=sql) + with self.assertRaises(SpannerLibError): + self.conn.execute(request) + + self.mock_lib.Execute.assert_called_once() + + def test_execute_closed_connection(self): + """Test executing on a closed connection.""" + self.conn.closed = True + with self.assertRaises(RuntimeError): + sql = "SELECT 1" + request = ExecuteSqlRequest(sql=sql) + self.conn.execute(request) + + @patch("google.cloud.spannerlib.connection.get_lib") + def test_begin_transaction_success(self, mock_get_lib): + """Test the begin_transaction method in case of success.""" + mock_get_lib.return_value = self.mock_lib + self.mock_lib.BeginTransaction.return_value = GoReturn( + pinner_id=0, error_code=0, object_id=0, msg_len=0, msg=None + ) + + self.conn.begin_transaction() + self.mock_lib.BeginTransaction.assert_called_once() + + @patch("google.cloud.spannerlib.connection.get_lib") + def test_begin_transaction_failure(self, mock_get_lib): + """Test the begin_transaction method in case of failure.""" + mock_get_lib.return_value = self.mock_lib + self.mock_lib.BeginTransaction.return_value = GoReturn( + pinner_id=0, error_code=1, object_id=0, msg_len=0, msg=None + ) + + with self.assertRaises(SpannerLibError): + self.conn.begin_transaction() + self.mock_lib.BeginTransaction.assert_called_once() + + @patch("google.cloud.spannerlib.connection.get_lib") + def test_commit_success(self, mock_get_lib): + """Test the commit method in case of success.""" + mock_get_lib.return_value = self.mock_lib + commit_response = CommitResponse() + commit_response_bytes = CommitResponse.serialize(commit_response) + + self.mock_lib.Commit.return_value = GoReturn( + pinner_id=0, + error_code=0, + object_id=0, + msg_len=len(commit_response_bytes), + msg=ctypes.cast( + ctypes.c_char_p(commit_response_bytes), ctypes.c_void_p + ), + ) + + response = self.conn.commit() + self.assertIsInstance(response, CommitResponse) + self.mock_lib.Commit.assert_called_once_with(1, 123) + + @patch("google.cloud.spannerlib.connection.get_lib") + def test_commit_failure(self, mock_get_lib): + """Test the commit method in case of failure.""" + mock_get_lib.return_value = self.mock_lib + self.mock_lib.Commit.return_value = GoReturn( + pinner_id=0, error_code=1, object_id=0, msg_len=0, msg=None + ) + + with self.assertRaises(SpannerLibError): + self.conn.commit() + self.mock_lib.Commit.assert_called_once_with(1, 123) + + @patch("google.cloud.spannerlib.connection.get_lib") + def test_rollback_success(self, mock_get_lib): + """Test the rollback method in case of success.""" + mock_get_lib.return_value = self.mock_lib + self.mock_lib.Rollback.return_value = GoReturn( + pinner_id=0, error_code=0, object_id=0, msg_len=0, msg=None + ) + + self.conn.rollback() + self.mock_lib.Rollback.assert_called_once() + + @patch("google.cloud.spannerlib.connection.get_lib") + def test_rollback_failure(self, mock_get_lib): + """Test the rollback method in case of failure.""" + mock_get_lib.return_value = self.mock_lib + self.mock_lib.Rollback.return_value = GoReturn( + pinner_id=0, error_code=1, object_id=0, msg_len=0, msg=None + ) + + with self.assertRaises(SpannerLibError): + self.conn.rollback() + self.mock_lib.Rollback.assert_called_once() + + @patch("google.cloud.spannerlib.connection.ExecuteBatchDmlResponse") + @patch("google.cloud.spannerlib.connection.get_lib") + def test_execute_batch_success(self, mock_get_lib, mock_response_cls): + """Test the execute_batch method in case of success.""" + mock_get_lib.return_value = self.mock_lib + mock_deserialize = MagicMock() + mock_response_cls.deserialize = mock_deserialize + mock_response_obj = MagicMock() + mock_deserialize.return_value = mock_response_obj + + # Simulate a serialized response + dummy_response_bytes = b"dummy" + self.mock_lib.ExecuteBatch.return_value = GoReturn( + pinner_id=0, + error_code=0, + object_id=0, + msg_len=len(dummy_response_bytes), + msg=ctypes.cast( + ctypes.c_char_p(dummy_response_bytes), ctypes.c_void_p + ), + ) + + request = ExecuteBatchDmlRequest() + response = self.conn.execute_batch(request) + + self.mock_lib.ExecuteBatch.assert_called_once() + mock_deserialize.assert_called_once_with(dummy_response_bytes) + self.assertIs(response, mock_response_obj) + + @patch("google.cloud.spannerlib.connection.get_lib") + def test_execute_batch_failure(self, mock_get_lib): + """Test the execute_batch method in case of failure.""" + mock_get_lib.return_value = self.mock_lib + self.mock_lib.ExecuteBatch.return_value = GoReturn( + pinner_id=0, error_code=1, object_id=0, msg_len=0, msg=None + ) + + request = ExecuteBatchDmlRequest() + with self.assertRaises(SpannerLibError): + self.conn.execute_batch(request) + self.mock_lib.ExecuteBatch.assert_called_once() + + def test_execute_batch_closed_connection(self): + """Test executing batch on a closed connection.""" + self.conn.closed = True + with self.assertRaises(RuntimeError): + request = ExecuteBatchDmlRequest() + self.conn.execute_batch(request) + + @patch("google.cloud.spannerlib.connection.CommitResponse") + @patch("google.cloud.spannerlib.connection.get_lib") + def test_write_mutations_success(self, mock_get_lib, mock_response_cls): + """Test the write_mutations method in case of success.""" + mock_get_lib.return_value = self.mock_lib + mock_deserialize = MagicMock() + mock_response_cls.deserialize = mock_deserialize + mock_response_obj = MagicMock() + mock_deserialize.return_value = mock_response_obj + + # Simulate a serialized response + dummy_response_bytes = b"dummy" + self.mock_lib.WriteMutations.return_value = GoReturn( + pinner_id=0, + error_code=0, + object_id=0, + msg_len=len(dummy_response_bytes), + msg=ctypes.cast( + ctypes.c_char_p(dummy_response_bytes), ctypes.c_void_p + ), + ) + + request = BatchWriteRequest.MutationGroup() + response = self.conn.write_mutations(request) + + self.mock_lib.WriteMutations.assert_called_once() + mock_deserialize.assert_called_once_with(dummy_response_bytes) + self.assertIs(response, mock_response_obj) + + @patch("google.cloud.spannerlib.connection.get_lib") + def test_write_mutations_failure(self, mock_get_lib): + """Test the write_mutations method in case of failure.""" + mock_get_lib.return_value = self.mock_lib + self.mock_lib.WriteMutations.return_value = GoReturn( + pinner_id=0, error_code=1, object_id=0, msg_len=0, msg=None + ) + + request = BatchWriteRequest.MutationGroup() + with self.assertRaises(SpannerLibError): + self.conn.write_mutations(request) + self.mock_lib.WriteMutations.assert_called_once() + + def test_write_mutations_closed_connection(self): + """Test writing mutation on a closed connection.""" + self.conn.closed = True + with self.assertRaises(RuntimeError): + request = BatchWriteRequest.MutationGroup() + self.conn.write_mutations(request) + + +if __name__ == "__main__": + unittest.main() diff --git a/spannerlib/wrappers/spannerlib-python/tests/unit/test_pool.py b/spannerlib/wrappers/spannerlib-python/tests/unit/test_pool.py new file mode 100644 index 00000000..eefa9808 --- /dev/null +++ b/spannerlib/wrappers/spannerlib-python/tests/unit/test_pool.py @@ -0,0 +1,103 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import absolute_import + +import os +import sys +import unittest +from unittest import mock + +from google.cloud.spannerlib import Pool +from google.cloud.spannerlib.internal import GoReturn + +# Adjust path to import from src +sys.path.insert( + 0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../..")) +) + +TEST_CONNECTION_STRING = ( + "projects/test-project/instances/test-instance/databases/test-database" +) + + +class TestPool(unittest.TestCase): + @mock.patch("google.cloud.spannerlib.internal.spannerlib.Spannerlib._lib") + def test_pool_creation_and_close(self, mock_lib): + mock_lib.CreatePool.return_value = GoReturn( + pinner_id=0, error_code=0, object_id=1, msg_len=0, msg=None + ) + mock_lib.ClosePool.return_value = GoReturn( + pinner_id=0, error_code=0, object_id=0, msg_len=0, msg=None + ) + + pool = Pool.create_pool(TEST_CONNECTION_STRING) + self.assertEqual(pool.id, 1) + self.assertFalse(pool.closed) + pool.close() + self.assertTrue(pool.closed) + mock_lib.CreatePool.assert_called_once() + mock_lib.ClosePool.assert_called_once_with(1) + + @mock.patch("google.cloud.spannerlib.internal.spannerlib.Spannerlib._lib") + def test_connection_creation_and_close(self, mock_lib): + mock_lib.CreatePool.return_value = GoReturn( + pinner_id=0, error_code=0, object_id=1, msg_len=0, msg=None + ) + mock_lib.CreateConnection.return_value = GoReturn( + pinner_id=0, error_code=0, object_id=101, msg_len=0, msg=None + ) + mock_lib.CloseConnection.return_value = GoReturn( + pinner_id=0, error_code=0, object_id=0, msg_len=0, msg=None + ) + mock_lib.ClosePool.return_value = GoReturn( + pinner_id=0, error_code=0, object_id=0, msg_len=0, msg=None + ) + + with Pool.create_pool(TEST_CONNECTION_STRING) as pool: + conn = pool.create_connection() + self.assertEqual(pool.id, 1) + self.assertEqual(conn.id, 101) + self.assertFalse(conn.closed) + conn.close() + self.assertTrue(conn.closed) + pool.close() + mock_lib.CreateConnection.assert_called_once_with(1) + mock_lib.CloseConnection.assert_called_once_with(1, 101) + + mock_lib.ClosePool.assert_called_once_with(1) + + @mock.patch("google.cloud.spannerlib.internal.spannerlib.Spannerlib._lib") + def test_connection_with_statement(self, mock_lib): + mock_lib.CreatePool.return_value = GoReturn( + pinner_id=0, error_code=0, object_id=1, msg_len=0, msg=None + ) + mock_lib.CreateConnection.return_value = GoReturn( + pinner_id=0, error_code=0, object_id=101, msg_len=0, msg=None + ) + mock_lib.CloseConnection.return_value = GoReturn( + pinner_id=0, error_code=0, object_id=0, msg_len=0, msg=None + ) + mock_lib.ClosePool.return_value = GoReturn( + pinner_id=0, error_code=0, object_id=0, msg_len=0, msg=None + ) + + with Pool.create_pool(TEST_CONNECTION_STRING) as pool: + with pool.create_connection() as conn: + self.assertEqual(conn.id, 101) + mock_lib.CloseConnection.assert_called_once_with(1, 101) + mock_lib.ClosePool.assert_called_once_with(1) + + +if __name__ == "__main__": + unittest.main() diff --git a/spannerlib/wrappers/spannerlib-python/tests/unit/test_rows.py b/spannerlib/wrappers/spannerlib-python/tests/unit/test_rows.py new file mode 100644 index 00000000..a81da385 --- /dev/null +++ b/spannerlib/wrappers/spannerlib-python/tests/unit/test_rows.py @@ -0,0 +1,202 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for the Rows class.""" +from __future__ import absolute_import + +import ctypes +import unittest +from unittest.mock import MagicMock, patch + +from google.cloud.spanner_v1 import ResultSetMetadata, ResultSetStats +from google.protobuf.struct_pb2 import ListValue, Value + +from google.cloud.spannerlib import Rows, SpannerLibError +from google.cloud.spannerlib.internal import GoReturn + + +class TestRows(unittest.TestCase): + """Unit tests for the Rows class.""" + + def setUp(self): + """Set up the test environment.""" + self.mock_pool = MagicMock() + self.mock_pool.id = 1 + self.mock_pool.closed = False + + self.mock_conn = MagicMock() + self.mock_conn.id = 123 + self.mock_conn.closed = False + + self.rows = Rows(id=101, pool=self.mock_pool, conn=self.mock_conn) + self.mock_lib = MagicMock() + + def tearDown(self): + """Tear down the test environment.""" + if not self.rows.closed: + try: + # Minimal mock to avoid errors in close + with patch( + "google.cloud.spannerlib.rows.get_lib" + ) as mock_get_lib: + mock_lib = MagicMock() + mock_get_lib.return_value = mock_lib + mock_lib.CloseRows.return_value = GoReturn( + pinner_id=0, + error_code=0, + object_id=0, + msg_len=0, + msg=None, + ) + self.rows.close() + except SpannerLibError: + pass + + @patch("google.cloud.spannerlib.rows.get_lib") + def test_close_success(self, mock_get_lib): + """Test the close method.""" + mock_get_lib.return_value = self.mock_lib + self.mock_lib.CloseRows.return_value = GoReturn( + pinner_id=0, error_code=0, object_id=0, msg_len=0, msg=None + ) + self.rows.close() + self.mock_lib.CloseRows.assert_called_once_with(1, 123, 101) + self.assertTrue(self.rows.closed) + + @patch("google.cloud.spannerlib.rows.get_lib") + def test_metadata_success(self, mock_get_lib): + """Test metadata success.""" + mock_get_lib.return_value = self.mock_lib + metadata = ResultSetMetadata() + metadata_bytes = ResultSetMetadata.serialize(metadata) + self.mock_lib.Metadata.return_value = GoReturn( + pinner_id=0, + error_code=0, + object_id=0, + msg_len=len(metadata_bytes), + msg=ctypes.cast(ctypes.c_char_p(metadata_bytes), ctypes.c_void_p), + ) + + result = self.rows.metadata() + self.assertIsInstance(result, ResultSetMetadata) + self.mock_lib.Metadata.assert_called_once_with(1, 123, 101) + + @patch("google.cloud.spannerlib.rows.get_lib") + def test_metadata_failure(self, mock_get_lib): + """Test metadata failure.""" + mock_get_lib.return_value = self.mock_lib + self.mock_lib.Metadata.return_value = GoReturn( + pinner_id=0, error_code=1, object_id=0, msg_len=0, msg=None + ) + with self.assertRaises(SpannerLibError): + self.rows.metadata() + + def test_metadata_closed(self): + """Test metadata on closed Rows.""" + self.rows.closed = True + with self.assertRaises(RuntimeError): + self.rows.metadata() + + @patch("google.cloud.spannerlib.rows.get_lib") + def test_next_success(self, mock_get_lib): + """Test next success.""" + mock_get_lib.return_value = self.mock_lib + list_value = ListValue(values=[Value(string_value="test")]) + list_value_bytes = list_value.SerializeToString() + self.mock_lib.Next.return_value = GoReturn( + pinner_id=0, + error_code=0, + object_id=0, + msg_len=len(list_value_bytes), + msg=ctypes.cast(ctypes.c_char_p(list_value_bytes), ctypes.c_void_p), + ) + + result = self.rows.next() + self.assertIsInstance(result, ListValue) + self.assertEqual(result.values[0].string_value, "test") + self.mock_lib.Next.assert_called_once_with(1, 123, 101, 1, 1) + + @patch("google.cloud.spannerlib.rows.get_lib") + def test_next_no_more_rows(self, mock_get_lib): + """Test next when no more rows.""" + mock_get_lib.return_value = self.mock_lib + self.mock_lib.Next.return_value = GoReturn( + pinner_id=0, error_code=0, object_id=0, msg_len=0, msg=None + ) + result = self.rows.next() + self.assertIsNone(result) + + @patch("google.cloud.spannerlib.rows.get_lib") + def test_next_failure(self, mock_get_lib): + """Test next failure.""" + mock_get_lib.return_value = self.mock_lib + self.mock_lib.Next.return_value = GoReturn( + pinner_id=0, error_code=1, object_id=0, msg_len=0, msg=None + ) + with self.assertRaises(SpannerLibError): + self.rows.next() + + def test_next_closed(self): + """Test next on closed Rows.""" + self.rows.closed = True + with self.assertRaises(RuntimeError): + self.rows.next() + + @patch("google.cloud.spannerlib.rows.get_lib") + def test_result_set_stats_success(self, mock_get_lib): + """Test result_set_stats success.""" + mock_get_lib.return_value = self.mock_lib + stats = ResultSetStats(row_count_exact=5) + stats_bytes = ResultSetStats.serialize(stats) + self.mock_lib.ResultSetStats.return_value = GoReturn( + pinner_id=0, + error_code=0, + object_id=0, + msg_len=len(stats_bytes), + msg=ctypes.cast(ctypes.c_char_p(stats_bytes), ctypes.c_void_p), + ) + + result = self.rows.result_set_stats() + self.assertIsInstance(result, ResultSetStats) + self.assertEqual(result.row_count_exact, 5) + self.mock_lib.ResultSetStats.assert_called_once_with(1, 123, 101) + + @patch("google.cloud.spannerlib.rows.get_lib") + def test_result_set_stats_failure(self, mock_get_lib): + """Test result_set_stats failure.""" + mock_get_lib.return_value = self.mock_lib + self.mock_lib.ResultSetStats.return_value = GoReturn( + pinner_id=0, error_code=1, object_id=0, msg_len=0, msg=None + ) + with self.assertRaises(SpannerLibError): + self.rows.result_set_stats() + + @patch("google.cloud.spannerlib.rows.get_lib") + def test_update_count(self, mock_get_lib): + """Test update_count.""" + mock_get_lib.return_value = self.mock_lib + stats = ResultSetStats(row_count_exact=10) + stats_bytes = ResultSetStats.serialize(stats) + self.mock_lib.ResultSetStats.return_value = GoReturn( + pinner_id=0, + error_code=0, + object_id=0, + msg_len=len(stats_bytes), + msg=ctypes.cast(ctypes.c_char_p(stats_bytes), ctypes.c_void_p), + ) + self.assertEqual(self.rows.update_count(), 10) + + +if __name__ == "__main__": + unittest.main()