diff --git a/.release-please-manifest.json b/.release-please-manifest.json index 92a25ac4..3a80e144 100644 --- a/.release-please-manifest.json +++ b/.release-please-manifest.json @@ -1,3 +1,3 @@ { - ".": "0.2.0-alpha.88" + ".": "0.2.0-alpha.89" } \ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md index c4447ed6..68b4cad0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,40 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/) and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html). +## 0.2.0-alpha.89 (2025-09-08) + +Full Changelog: [v0.2.0-alpha.88...v0.2.0-alpha.89](https://github.com/openlayer-ai/openlayer-python/compare/v0.2.0-alpha.88...v0.2.0-alpha.89) + +### Features + +* add guardrails system with PII protection ([0bcc636](https://github.com/openlayer-ai/openlayer-python/commit/0bcc6366c438f6e501381601245b114c8990dbdc)) +* integrate guardrails into tracing system ([b846ba5](https://github.com/openlayer-ai/openlayer-python/commit/b846ba512fcaba2ad5b1ccf743983ffcc7aa9610)) +* introduce guardrail step type ([73c53c6](https://github.com/openlayer-ai/openlayer-python/commit/73c53c627193ae78f7752b08eb9b5d301d6107be)) + + +### Bug Fixes + +* PII redaction and trace function calls ([5f49d4a](https://github.com/openlayer-ai/openlayer-python/commit/5f49d4a3e41c45056162cb7620e5061616752d8f)) + + +### Chores + +* add missing type annotations ([86d2d54](https://github.com/openlayer-ai/openlayer-python/commit/86d2d54a1880c199aa22f9eb790d9632c59a0a13)) +* cleanup unnecessary files ([d91fd55](https://github.com/openlayer-ai/openlayer-python/commit/d91fd55a10f4f6d72e07674245645b181478c7d9)) +* completes OPEN-7287 remove concrete guardrail implementations ([623f812](https://github.com/openlayer-ai/openlayer-python/commit/623f812901d8574d399604394048282335ea071b)) +* completes OPEN-7289 write unit tests for the tracer 
([c80943f](https://github.com/openlayer-ai/openlayer-python/commit/c80943f0b8f147ab6b0be2feadac0efeab4cf539)) +* update tracer implementation ([9abe566](https://github.com/openlayer-ai/openlayer-python/commit/9abe566c3bed316d890cddd77ae9560a6899d293)) + + +### Documentation + +* add comprehensive guardrails usage examples ([be7d827](https://github.com/openlayer-ai/openlayer-python/commit/be7d82759ce3dbf1bec27e0d1790c6e4776c892c)) + + +### Refactors + +* simplify guardrails integration and clean up examples ([4dde617](https://github.com/openlayer-ai/openlayer-python/commit/4dde6175d904aacb6d0519386880b35f9b8bf4ef)) + ## 0.2.0-alpha.88 (2025-09-04) Full Changelog: [v0.2.0-alpha.87...v0.2.0-alpha.88](https://github.com/openlayer-ai/openlayer-python/compare/v0.2.0-alpha.87...v0.2.0-alpha.88) diff --git a/examples/tracing/programmatic_configuration.py b/examples/tracing/programmatic_configuration.py index ce37393b..a8c22396 100644 --- a/examples/tracing/programmatic_configuration.py +++ b/examples/tracing/programmatic_configuration.py @@ -123,7 +123,9 @@ def mixed_config_function(query: str) -> str: print("=" * 50) # Note: Replace the placeholder API keys and IDs with real values - print("Note: Replace placeholder API keys and pipeline IDs with real values before running.") + print( + "Note: Replace placeholder API keys and pipeline IDs with real values before running." + ) print() try: diff --git a/pyproject.toml b/pyproject.toml index cff3d9f5..fe5c2806 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "openlayer" -version = "0.2.0-alpha.88" +version = "0.2.0-alpha.89" description = "The official Python library for the openlayer API" dynamic = ["readme"] license = "Apache-2.0" diff --git a/src/openlayer/_version.py b/src/openlayer/_version.py index 337938d0..a28c29ff 100644 --- a/src/openlayer/_version.py +++ b/src/openlayer/_version.py @@ -1,4 +1,4 @@ # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 
__title__ = "openlayer" -__version__ = "0.2.0-alpha.88" # x-release-please-version +__version__ = "0.2.0-alpha.89" # x-release-please-version diff --git a/src/openlayer/lib/guardrails/__init__.py b/src/openlayer/lib/guardrails/__init__.py new file mode 100644 index 00000000..a6c9e5b8 --- /dev/null +++ b/src/openlayer/lib/guardrails/__init__.py @@ -0,0 +1,17 @@ +"""Guardrails module for Openlayer tracing.""" + +from .base import ( + GuardrailAction, + BlockStrategy, + GuardrailResult, + BaseGuardrail, + GuardrailBlockedException, +) + +__all__ = [ + "GuardrailAction", + "BlockStrategy", + "GuardrailResult", + "BaseGuardrail", + "GuardrailBlockedException", +] diff --git a/src/openlayer/lib/guardrails/base.py b/src/openlayer/lib/guardrails/base.py new file mode 100644 index 00000000..ea12e2c9 --- /dev/null +++ b/src/openlayer/lib/guardrails/base.py @@ -0,0 +1,118 @@ +"""Base classes and interfaces for guardrails system.""" + +import abc +import enum +import logging +from typing import Any, Dict, Optional +from dataclasses import dataclass + +logger = logging.getLogger(__name__) + + +class GuardrailAction(enum.Enum): + """Actions that a guardrail can take.""" + + ALLOW = "allow" + BLOCK = "block" + MODIFY = "modify" + + +class BlockStrategy(enum.Enum): + """Strategies for handling blocked requests.""" + + RAISE_EXCEPTION = ( + "raise_exception" # Raise GuardrailBlockedException (breaks pipeline) + ) + RETURN_EMPTY = "return_empty" # Return empty/None response (graceful) + RETURN_ERROR_MESSAGE = "return_error_message" # Return error message (graceful) + SKIP_FUNCTION = "skip_function" # Skip function execution, return None (graceful) + + +@dataclass +class GuardrailResult: + """Result of applying a guardrail.""" + + action: GuardrailAction + modified_data: Optional[Any] = None + metadata: Optional[Dict[str, Any]] = None + reason: Optional[str] = None + block_strategy: Optional[BlockStrategy] = None + error_message: Optional[str] = None + + def __post_init__(self): + 
"""Validate the result after initialization.""" + if self.action == GuardrailAction.MODIFY and self.modified_data is None: + raise ValueError("modified_data must be provided when action is MODIFY") + if self.action == GuardrailAction.BLOCK and self.block_strategy is None: + self.block_strategy = ( + BlockStrategy.RAISE_EXCEPTION + ) # Default to existing behavior + + +class GuardrailBlockedException(Exception): + """Exception raised when a guardrail blocks execution.""" + + def __init__( + self, + guardrail_name: str, + reason: str, + metadata: Optional[Dict[str, Any]] = None, + ): + self.guardrail_name = guardrail_name + self.reason = reason + self.metadata = metadata or {} + super().__init__(f"Guardrail '{guardrail_name}' blocked execution: {reason}") + + +class BaseGuardrail(abc.ABC): + """Base class for all guardrails.""" + + def __init__(self, name: str, enabled: bool = True, **config): + """Initialize the guardrail. + + Args: + name: Human-readable name for this guardrail + enabled: Whether this guardrail is active + **config: Guardrail-specific configuration + """ + self.name = name + self.enabled = enabled + self.config = config + + @abc.abstractmethod + def check_input(self, inputs: Dict[str, Any]) -> GuardrailResult: + """Check and potentially modify function inputs. + + Args: + inputs: Dictionary of function inputs (parameter_name -> value) + + Returns: + GuardrailResult indicating the action to take + """ + pass + + @abc.abstractmethod + def check_output(self, output: Any, inputs: Dict[str, Any]) -> GuardrailResult: + """Check and potentially modify function output. 
+ + Args: + output: The function's output + inputs: Dictionary of function inputs for context + + Returns: + GuardrailResult indicating the action to take + """ + pass + + def is_enabled(self) -> bool: + """Check if this guardrail is enabled.""" + return self.enabled + + def get_metadata(self) -> Dict[str, Any]: + """Get metadata about this guardrail for trace logging.""" + return { + "name": self.name, + "type": self.__class__.__name__, + "enabled": self.enabled, + "config": self.config, + } diff --git a/src/openlayer/lib/tracing/enums.py b/src/openlayer/lib/tracing/enums.py index 9b467a96..a188f4ff 100644 --- a/src/openlayer/lib/tracing/enums.py +++ b/src/openlayer/lib/tracing/enums.py @@ -4,9 +4,10 @@ class StepType(enum.Enum): - USER_CALL = "user_call" - CHAT_COMPLETION = "chat_completion" AGENT = "agent" + CHAT_COMPLETION = "chat_completion" + GUARDRAIL = "guardrail" + HANDOFF = "handoff" RETRIEVER = "retriever" TOOL = "tool" - HANDOFF = "handoff" + USER_CALL = "user_call" diff --git a/src/openlayer/lib/tracing/steps.py b/src/openlayer/lib/tracing/steps.py index 122890f2..4afe13f4 100644 --- a/src/openlayer/lib/tracing/steps.py +++ b/src/openlayer/lib/tracing/steps.py @@ -2,7 +2,7 @@ import time import uuid -from typing import Any, Dict, List, Optional +from typing import Any, Dict, Optional, List from .. import utils from . 
import enums @@ -229,6 +229,39 @@ def to_dict(self) -> Dict[str, Any]: return step_dict +class GuardrailStep(Step): + """Step for tracking guardrail execution.""" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.step_type = enums.StepType.GUARDRAIL + self.action: Optional[str] = None + self.blocked_entities: Optional[List[str]] = None + self.confidence_threshold: float = None + self.reason: Optional[str] = None + self.detected_entities: Optional[List[str]] = None + self.redacted_entities: Optional[List[str]] = None + self.block_strategy: Optional[str] = None + self.data_type: Optional[str] = None + + def to_dict(self) -> Dict[str, Any]: + """Dictionary representation of the GuardrailStep.""" + step_dict = super().to_dict() + step_dict.update( + { + "action": self.action, + "blockedEntities": self.blocked_entities, + "confidenceThreshold": self.confidence_threshold, + "reason": self.reason, + "detectedEntities": self.detected_entities, + "blockStrategy": self.block_strategy, + "redactedEntities": self.redacted_entities, + "dataType": self.data_type, + } + ) + return step_dict + + # ----------------------------- Factory function ----------------------------- # def step_factory(step_type: enums.StepType, *args, **kwargs) -> Step: """Factory function to create a step based on the step_type.""" @@ -241,5 +274,6 @@ def step_factory(step_type: enums.StepType, *args, **kwargs) -> Step: enums.StepType.RETRIEVER: RetrieverStep, enums.StepType.TOOL: ToolStep, enums.StepType.HANDOFF: HandoffStep, + enums.StepType.GUARDRAIL: GuardrailStep, } return step_type_mapping[step_type](*args, **kwargs) diff --git a/src/openlayer/lib/tracing/tracer.py b/src/openlayer/lib/tracing/tracer.py index c04e56c8..a616eb2b 100644 --- a/src/openlayer/lib/tracing/tracer.py +++ b/src/openlayer/lib/tracing/tracer.py @@ -15,6 +15,7 @@ from ...types.inference_pipelines.data_stream_params import ConfigLlmData from .. import utils from . 
import enums, steps, traces +from ..guardrails.base import GuardrailResult, GuardrailAction logger = logging.getLogger(__name__) @@ -23,7 +24,9 @@ TRUE_LIST = ["true", "on", "1"] _publish = utils.get_env_variable("OPENLAYER_DISABLE_PUBLISH") not in TRUE_LIST -_verify_ssl = (utils.get_env_variable("OPENLAYER_VERIFY_SSL") or "true").lower() in TRUE_LIST +_verify_ssl = ( + utils.get_env_variable("OPENLAYER_VERIFY_SSL") or "true" +).lower() in TRUE_LIST _client = None # Configuration variables for programmatic setup @@ -167,9 +170,10 @@ def trace( *step_args, inference_pipeline_id: Optional[str] = None, context_kwarg: Optional[str] = None, + guardrails: Optional[List[Any]] = None, **step_kwargs, ): - """Decorator to trace a function. + """Decorator to trace a function with optional guardrails. Examples -------- @@ -180,13 +184,17 @@ def trace( >>> import os >>> from openlayer.tracing import tracer + >>> from openlayer.lib.guardrails import PIIGuardrail >>> >>> # Set the environment variables >>> os.environ["OPENLAYER_API_KEY"] = "YOUR_OPENLAYER_API_KEY_HERE" >>> os.environ["OPENLAYER_PROJECT_NAME"] = "YOUR_OPENLAYER_PROJECT_NAME_HERE" >>> - >>> # Decorate all the functions you want to trace - >>> @tracer.trace() + >>> # Create guardrail instance + >>> pii_guardrail = PIIGuardrail(name="PII Protection") + >>> + >>> # Decorate functions with tracing and guardrails + >>> @tracer.trace(guardrails=[pii_guardrail]) >>> def main(user_query: str) -> str: >>> context = retrieve_context(user_query) >>> answer = generate_answer(user_query, context) @@ -196,7 +204,7 @@ def trace( >>> def retrieve_context(user_query: str) -> str: >>> return "Some context" >>> - >>> @tracer.trace() + >>> @tracer.trace(guardrails=[pii_guardrail]) >>> def generate_answer(user_query: str, context: str) -> str: >>> return "Some answer" >>> @@ -234,12 +242,14 @@ def __next__(self): # Initialize tracing on first iteration only if not self._trace_initialized: self._original_gen = func(*func_args, 
**func_kwargs) - self._step, self._is_root_step, self._token = _create_and_initialize_step( - step_name=step_name, - step_type=enums.StepType.USER_CALL, - inputs=None, - output=None, - metadata=None, + self._step, self._is_root_step, self._token = ( + _create_and_initialize_step( + step_name=step_name, + step_type=enums.StepType.USER_CALL, + inputs=None, + output=None, + metadata=None, + ) ) self._inputs = _extract_function_inputs( func_signature=func_signature, @@ -286,17 +296,103 @@ def __next__(self): return sync_generator_wrapper else: - # Handle regular functions + # Handle regular functions with guardrail support @wraps(func) def wrapper(*func_args, **func_kwargs): if step_kwargs.get("name") is None: step_kwargs["name"] = func.__name__ - with create_step(*step_args, inference_pipeline_id=inference_pipeline_id, **step_kwargs) as step: + with create_step( + *step_args, + inference_pipeline_id=inference_pipeline_id, + **step_kwargs, + ) as step: output = exception = None + original_inputs = None + modified_inputs = None + guardrail_metadata = {} + try: - output = func(*func_args, **func_kwargs) + # Extract original inputs for guardrail processing + original_inputs = _extract_function_inputs( + func_signature=func_signature, + func_args=func_args, + func_kwargs=func_kwargs, + context_kwarg=context_kwarg, + ) + + # Apply input guardrails + modified_inputs, input_guardrail_metadata = ( + _apply_input_guardrails( + guardrails or [], + original_inputs, + ) + ) + guardrail_metadata.update(input_guardrail_metadata) + + # Check if function execution should be skipped + if ( + hasattr(modified_inputs, "__class__") + and modified_inputs.__class__.__name__ + == "SkipFunctionExecution" + ): + # Function execution was blocked with SKIP_FUNCTION strategy + output = None + logger.debug( + "Function %s execution skipped by guardrail", + func.__name__, + ) + else: + # Execute function with potentially modified inputs + if modified_inputs != original_inputs: + # Reconstruct 
function arguments from modified inputs + bound = func_signature.bind(*func_args, **func_kwargs) + bound.apply_defaults() + + # Update bound arguments with modified values + for ( + param_name, + modified_value, + ) in modified_inputs.items(): + if param_name in bound.arguments: + bound.arguments[param_name] = modified_value + + output = func(*bound.args, **bound.kwargs) + else: + output = func(*func_args, **func_kwargs) + + # Apply output guardrails (skip if function was skipped) + if ( + hasattr(modified_inputs, "__class__") + and modified_inputs.__class__.__name__ + == "SkipFunctionExecution" + ): + final_output, output_guardrail_metadata = output, {} + # Use original inputs for logging since modified_inputs + # is a special marker + modified_inputs = original_inputs + else: + final_output, output_guardrail_metadata = ( + _apply_output_guardrails( + guardrails or [], + output, + modified_inputs or original_inputs, + ) + ) + guardrail_metadata.update(output_guardrail_metadata) + + if final_output != output: + output = final_output + except Exception as exc: + # Check if this is a guardrail exception + if hasattr(exc, "guardrail_name"): + guardrail_metadata[f"{exc.guardrail_name}_blocked"] = { + "action": "blocked", + "reason": exc.reason, + "metadata": getattr(exc, "metadata", {}), + } + _log_step_exception(step, exc) exception = exc @@ -308,6 +404,7 @@ def wrapper(*func_args, **func_kwargs): func_kwargs=func_kwargs, context_kwarg=context_kwarg, output=output, + guardrail_metadata=guardrail_metadata, ) if exception is not None: @@ -323,11 +420,13 @@ def trace_async( *step_args, inference_pipeline_id: Optional[str] = None, context_kwarg: Optional[str] = None, + guardrails: Optional[List[Any]] = None, **step_kwargs, ): """Decorator to trace async functions and async generators. 
- This decorator automatically detects whether the function is a regular async function + This decorator automatically detects whether the function is a regular async + function or an async generator and handles both cases appropriately. Examples @@ -379,12 +478,14 @@ async def __anext__(self): # Initialize tracing on first iteration only if not self._trace_initialized: self._original_gen = func(*func_args, **func_kwargs) - self._step, self._is_root_step, self._token = _create_and_initialize_step( - step_name=step_name, - step_type=enums.StepType.USER_CALL, - inputs=None, - output=None, - metadata=None, + self._step, self._is_root_step, self._token = ( + _create_and_initialize_step( + step_name=step_name, + step_type=enums.StepType.USER_CALL, + inputs=None, + output=None, + metadata=None, + ) ) self._inputs = _extract_function_inputs( func_signature=func_signature, @@ -440,13 +541,82 @@ async def async_function_wrapper(*func_args, **func_kwargs): **step_kwargs, ) as step: output = exception = None + guardrail_metadata = {} try: - output = await func(*func_args, **func_kwargs) + # Apply input guardrails if provided + if guardrails: + try: + inputs = _extract_function_inputs( + func_signature=func_signature, + func_args=func_args, + func_kwargs=func_kwargs, + context_kwarg=context_kwarg, + ) + + # Process inputs through guardrails + modified_inputs, input_metadata = ( + _apply_input_guardrails( + guardrails, + inputs, + ) + ) + guardrail_metadata.update(input_metadata) + + # Execute function with potentially modified inputs + if modified_inputs != inputs: + # Reconstruct function arguments from modified inputs + bound = func_signature.bind( + *func_args, **func_kwargs + ) + bound.apply_defaults() + + # Update bound arguments with modified values + for ( + param_name, + modified_value, + ) in modified_inputs.items(): + if param_name in bound.arguments: + bound.arguments[param_name] = ( + modified_value + ) + + output = await func(*bound.args, **bound.kwargs) + else: + 
output = await func(*func_args, **func_kwargs) + except Exception as e: + # Log guardrail errors but don't fail function execution + logger.error("Guardrail error: %s", e) + output = await func(*func_args, **func_kwargs) + else: + output = await func(*func_args, **func_kwargs) + except Exception as exc: _log_step_exception(step, exc) - exception = exc - raise + raise exc + + # Apply output guardrails if provided + if guardrails and output is not None: + try: + final_output, output_metadata = ( + _apply_output_guardrails( + guardrails, + output, + _extract_function_inputs( + func_signature=func_signature, + func_args=func_args, + func_kwargs=func_kwargs, + context_kwarg=context_kwarg, + ), + ) + ) + guardrail_metadata.update(output_metadata) + + if final_output != output: + output = final_output + except Exception as e: + # Log guardrail errors but don't fail function execution + logger.error("Output guardrail error: %s", e) # Extract inputs and finalize logging _process_wrapper_inputs_and_outputs( @@ -456,6 +626,7 @@ async def async_function_wrapper(*func_args, **func_kwargs): func_kwargs=func_kwargs, context_kwarg=context_kwarg, output=output, + guardrail_metadata=guardrail_metadata, ) return output @@ -471,12 +642,78 @@ def sync_wrapper(*func_args, **func_kwargs): **step_kwargs, ) as step: output = exception = None + guardrail_metadata = {} try: - output = func(*func_args, **func_kwargs) + # Apply input guardrails if provided + if guardrails: + try: + inputs = _extract_function_inputs( + func_signature=func_signature, + func_args=func_args, + func_kwargs=func_kwargs, + context_kwarg=context_kwarg, + ) + + # Process inputs through guardrails + modified_inputs, input_metadata = ( + _apply_input_guardrails( + guardrails, + inputs, + ) + ) + guardrail_metadata.update(input_metadata) + + # Execute function with potentially modified inputs + if modified_inputs != inputs: + # Reconstruct function arguments from modified inputs + bound = func_signature.bind( + 
*func_args, **func_kwargs + ) + bound.apply_defaults() + + # Update bound arguments with modified values + for ( + param_name, + modified_value, + ) in modified_inputs.items(): + if param_name in bound.arguments: + bound.arguments[param_name] = modified_value + + output = func(*bound.args, **bound.kwargs) + else: + output = func(*func_args, **func_kwargs) + except Exception as e: + # Log guardrail errors but don't fail function execution + logger.error("Guardrail error: %s", e) + output = func(*func_args, **func_kwargs) + else: + output = func(*func_args, **func_kwargs) + except Exception as exc: _log_step_exception(step, exc) exception = exc + # Apply output guardrails if provided + if guardrails and output is not None: + try: + final_output, output_metadata = _apply_output_guardrails( + guardrails, + output, + _extract_function_inputs( + func_signature=func_signature, + func_args=func_args, + func_kwargs=func_kwargs, + context_kwarg=context_kwarg, + ), + ) + guardrail_metadata.update(output_metadata) + + if final_output != output: + output = final_output + except Exception as e: + # Log guardrail errors but don't fail function execution + logger.error("Output guardrail error: %s", e) + # Extract inputs and finalize logging _process_wrapper_inputs_and_outputs( step=step, @@ -485,6 +722,7 @@ def sync_wrapper(*func_args, **func_kwargs): func_kwargs=func_kwargs, context_kwarg=context_kwarg, output=output, + guardrail_metadata=guardrail_metadata, ) if exception is not None: @@ -528,15 +766,15 @@ def log_context(context: List[str]) -> None: def update_current_trace(**kwargs) -> None: """Updates the current trace metadata with the provided values. - + This function allows users to set trace-level metadata dynamically during execution without having to pass it through function arguments. - + All provided key-value pairs will be stored in the trace metadata. 
- + Example: >>> from openlayer.lib import trace, update_current_trace - >>> + >>> >>> @trace() >>> def my_function(): >>> # Update trace with user context @@ -555,27 +793,27 @@ def update_current_trace(**kwargs) -> None: "(e.g., inside a function decorated with @trace)." ) return - + current_trace.update_metadata(**kwargs) logger.debug("Updated current trace metadata") def update_current_step( attributes: Optional[Dict[str, Any]] = None, - metadata: Optional[Dict[str, Any]] = None + metadata: Optional[Dict[str, Any]] = None, ) -> None: """Updates the current step with the provided attributes. - + This function allows users to set step-level metadata dynamically during execution. - + Args: attributes: Optional dictionary of attributes to set on the step metadata: Optional dictionary of metadata to merge with existing metadata - + Example: >>> from openlayer.lib import trace, update_current_step - >>> + >>> >>> @trace() >>> def my_function(): >>> # Update current step with additional context @@ -592,7 +830,7 @@ def update_current_step( "(e.g., inside a function decorated with @trace)." 
) return - + # Update step attributes using the existing log method update_data = {} if metadata is not None: @@ -600,15 +838,15 @@ def update_current_step( existing_metadata = current_step.metadata or {} existing_metadata.update(metadata) update_data["metadata"] = existing_metadata - + # Handle generic attributes by setting them directly on the step if attributes is not None: for key, value in attributes.items(): setattr(current_step, key, value) - + if update_data: current_step.log(**update_data) - + logger.debug("Updated current step metadata") @@ -664,7 +902,9 @@ def _create_and_initialize_step( return new_step, is_root_step, token -def _handle_trace_completion(is_root_step: bool, step_name: str, inference_pipeline_id: Optional[str] = None) -> None: +def _handle_trace_completion( + is_root_step: bool, step_name: str, inference_pipeline_id: Optional[str] = None +) -> None: """Handle trace completion and data streaming.""" if is_root_step: logger.debug("Ending the trace...") @@ -740,15 +980,23 @@ def _process_wrapper_inputs_and_outputs( func_kwargs: dict, context_kwarg: Optional[str], output: Any, + guardrail_metadata: Optional[Dict[str, Any]] = None, ) -> None: - """Extract function inputs and finalize step logging - common pattern across wrappers.""" + """Extract function inputs and finalize step logging - common pattern across + wrappers.""" inputs = _extract_function_inputs( func_signature=func_signature, func_args=func_args, func_kwargs=func_kwargs, context_kwarg=context_kwarg, ) - _finalize_step_logging(step=step, inputs=inputs, output=output, start_time=step.start_time) + _finalize_step_logging( + step=step, + inputs=inputs, + output=output, + start_time=step.start_time, + guardrail_metadata=guardrail_metadata, + ) def _extract_function_inputs( @@ -782,6 +1030,7 @@ def _finalize_step_logging( inputs: dict, output: Any, start_time: float, + guardrail_metadata: Optional[Dict[str, Any]] = None, ) -> None: """Finalize step timing and logging.""" if 
step.end_time is None: @@ -795,10 +1044,38 @@ def _finalize_step_logging( else: step.log(output=output) + # Start with existing metadata instead of overwriting it + step_metadata = step.metadata.copy() if step.metadata else {} + + # Add guardrail metadata to step metadata + if guardrail_metadata: + step_metadata["guardrails"] = guardrail_metadata + + # Add summary fields for easy filtering + step_metadata["has_guardrails"] = True + step_metadata["guardrail_actions"] = [ + metadata.get("action") for metadata in guardrail_metadata.values() + ] + step_metadata["guardrail_names"] = [ + key.replace("input_", "").replace("output_", "") + for key in guardrail_metadata.keys() + ] + + # Add flags for specific actions for easy filtering + actions = step_metadata["guardrail_actions"] + step_metadata["guardrail_blocked"] = "blocked" in actions + step_metadata["guardrail_modified"] = ( + "redacted" in actions or "modified" in actions + ) + step_metadata["guardrail_allowed"] = "allow" in actions + else: + step_metadata["has_guardrails"] = False + step.log( inputs=inputs, end_time=step.end_time, latency=step.latency, + metadata=step_metadata, ) @@ -821,9 +1098,14 @@ def _finalize_sync_generator_step( # Context variable was created in a different context (e.g., different thread) # This can happen in async/multi-threaded environments like FastAPI/OpenWebUI # We can safely ignore this as the step finalization will still complete - logger.debug("Context variable reset failed - generator consumed in different context") - - _finalize_step_logging(step=step, inputs=inputs, output=output, start_time=step.start_time) + logger.debug( + "Context variable reset failed - generator consumed in different context" + ) + + _finalize_step_logging( + step=step, inputs=inputs, output=output, start_time=step.start_time + ) + _handle_trace_completion( is_root_step=is_root_step, step_name=step_name, @@ -842,7 +1124,9 @@ def _finalize_async_generator_step( ) -> None: """Finalize async generator step - 
called when generator is consumed.""" _current_step.reset(token) - _finalize_step_logging(step=step, inputs=inputs, output=output, start_time=step.start_time) + _finalize_step_logging( + step=step, inputs=inputs, output=output, start_time=step.start_time + ) _handle_trace_completion( is_root_step=is_root_step, step_name=step_name, @@ -894,12 +1178,12 @@ def post_process_trace( "steps": processed_steps, **root_step.metadata, } - + # Include trace-level metadata if set - extract keys to row/record level if trace_obj.metadata is not None: # Add each trace metadata key directly to the row/record level trace_data.update(trace_obj.metadata) - + if root_step.ground_truth: trace_data["groundTruth"] = root_step.ground_truth if input_variables: @@ -910,3 +1194,384 @@ def post_process_trace( trace_data["context"] = context return trace_data, input_variable_names + + +# ----------------------------- Guardrail helper functions ----------------------------- # + + +def _apply_input_guardrails( + guardrails: List[Any], + inputs: Dict[str, Any], +) -> Tuple[Dict[str, Any], Dict[str, Any]]: + """Apply guardrails to function inputs, creating guardrail steps. 
+ + Args: + guardrails: List of guardrail instances + inputs: Extracted function inputs + + Returns: + Tuple of (modified_inputs, guardrail_metadata) + """ + if not guardrails: + return inputs, {} + + modified_inputs = inputs.copy() + overall_metadata = {} + + for i, guardrail in enumerate(guardrails): + try: + # Import here to avoid circular imports + from ..guardrails.base import BaseGuardrail, GuardrailBlockedException + + if not isinstance(guardrail, BaseGuardrail): + logger.warning("Skipping invalid guardrail: %s", guardrail) + continue + + if not guardrail.is_enabled(): + continue + + # Create a guardrail step for this check + with create_step( + name=f"{guardrail.name} - Input", + step_type=enums.StepType.GUARDRAIL, + ) as guardrail_step: + try: + # Apply the guardrail + result = guardrail.check_input(modified_inputs) + + # Store guardrail metadata for main function step + guardrail_key = f"input_{guardrail.name.lower().replace(' ', '_')}" + overall_metadata[guardrail_key] = { + "action": result.action.value, + "reason": result.reason, + "metadata": result.metadata or {}, + } + + # Prepare step logging data + step_log_data = { + "action": result.action.value, + "reason": result.reason, + "data_type": "input", + "inputs": {"original_data": modified_inputs}, + } + + if result.action.value == "block": + # Handle the block according to strategy + final_inputs, block_metadata = _handle_guardrail_block( + guardrail=guardrail, + result=result, + modified_inputs=modified_inputs, + guardrail_metadata=overall_metadata, + guardrail_key=guardrail_key, + is_input=True, + ) + + # Add final output if different + if final_inputs != modified_inputs: + step_log_data["output"] = final_inputs + + # Log once with all data + guardrail_step.log(**step_log_data) + return final_inputs, overall_metadata + + elif ( + result.action.value == "modify" + and result.modified_data is not None + ): + step_log_data["output"] = result.modified_data + modified_inputs = result.modified_data + 
logger.debug("Guardrail %s modified inputs", guardrail.name) + + else: # allow + step_log_data["output"] = modified_inputs + + # Single log call with all data + guardrail_step.log(**step_log_data) + + except Exception as e: + # Create error result for the guardrail step + error_result = GuardrailResult( + action=GuardrailAction.ALLOW, # Default to allow on error + reason=f"Guardrail error: {str(e)}", + metadata={"error": str(e), "error_type": type(e).__name__}, + ) + guardrail_step.log( + inputs={"original_data": modified_inputs}, + output=modified_inputs, + ) + + if hasattr(e, "guardrail_name"): + # Re-raise guardrail exceptions + raise + else: + # Log other exceptions but don't fail the trace + logger.error( + "Error applying input guardrail %s: %s", guardrail.name, e + ) + guardrail_key = ( + f"input_{guardrail.name.lower().replace(' ', '_')}" + ) + overall_metadata[guardrail_key] = { + "action": "error", + "reason": str(e), + "metadata": {"error_type": type(e).__name__}, + "guardrail_name": guardrail.name, + } + + except Exception as e: + # Handle exceptions that occur outside the guardrail step context + if hasattr(e, "guardrail_name"): + raise + else: + logger.error( + "Error setting up input guardrail %s: %s", + getattr(guardrail, "name", f"guardrail_{i}"), + e, + ) + + return modified_inputs, overall_metadata + + +def _apply_output_guardrails( + guardrails: List[Any], output: Any, inputs: Dict[str, Any] +) -> Tuple[Any, Dict[str, Any]]: + """Apply guardrails to function output, creating guardrail steps. 
+ + Args: + guardrails: List of guardrail instances + output: Function output + inputs: Function inputs for context + + Returns: + Tuple of (modified_output, guardrail_metadata) + """ + if not guardrails: + return output, {} + + modified_output = output + overall_metadata = {} + + for i, guardrail in enumerate(guardrails): + try: + # Import here to avoid circular imports + from ..guardrails.base import BaseGuardrail, GuardrailBlockedException + + if not isinstance(guardrail, BaseGuardrail): + logger.warning("Skipping invalid guardrail: %s", guardrail) + continue + + if not guardrail.is_enabled(): + continue + + # Create a guardrail step for this check + with create_step( + name=f"{guardrail.name} - Output", + step_type=enums.StepType.GUARDRAIL, + ) as guardrail_step: + try: + # Apply the guardrail + result = guardrail.check_output(modified_output, inputs) + + # Store guardrail metadata for main function step + guardrail_key = f"output_{guardrail.name.lower().replace(' ', '_')}" + overall_metadata[guardrail_key] = { + "action": result.action.value, + "reason": result.reason, + "metadata": result.metadata or {}, + } + + # Prepare step logging data + step_log_data = { + "action": result.action.value, + "reason": result.reason, + "data_type": "output", + "inputs": {"original_data": modified_output}, + } + + if result.action.value == "block": + # Handle the block according to strategy + final_output, block_metadata = _handle_guardrail_block( + guardrail=guardrail, + result=result, + modified_output=modified_output, + guardrail_metadata=overall_metadata, + guardrail_key=guardrail_key, + is_input=False, + ) + + # Add final output if different + if final_output != modified_output: + step_log_data["output"] = final_output + + # Log once with all data + guardrail_step.log(**step_log_data) + return final_output, overall_metadata + + elif ( + result.action.value == "modify" + and result.modified_data is not None + ): + step_log_data["output"] = result.modified_data + 
modified_output = result.modified_data + logger.debug("Guardrail %s modified output", guardrail.name) + + else: # allow + step_log_data["output"] = modified_output + + # Single log call with all data + guardrail_step.log(**step_log_data) + + except Exception as e: + # Create error result for the guardrail step + error_result = GuardrailResult( + action=GuardrailAction.ALLOW, # Default to allow on error + reason=f"Guardrail error: {str(e)}", + metadata={"error": str(e), "error_type": type(e).__name__}, + ) + guardrail_step.log( + inputs={"original_data": modified_output}, + output=modified_output, + ) + + if hasattr(e, "guardrail_name"): + # Re-raise guardrail exceptions + raise + else: + # Log other exceptions but don't fail the trace + logger.error( + "Error applying output guardrail %s: %s", guardrail.name, e + ) + guardrail_key = ( + f"output_{guardrail.name.lower().replace(' ', '_')}" + ) + overall_metadata[guardrail_key] = { + "action": "error", + "reason": str(e), + "metadata": {"error_type": type(e).__name__}, + } + guardrail_step.log(**overall_metadata[guardrail_key]) + + except Exception as e: + # Handle exceptions that occur outside the guardrail step context + if hasattr(e, "guardrail_name"): + raise + else: + logger.error( + "Error setting up output guardrail %s: %s", + getattr(guardrail, "name", f"guardrail_{i}"), + e, + ) + + return modified_output, overall_metadata + + +def _handle_guardrail_block( + guardrail: Any, + result: Any, + modified_inputs: Optional[Dict[str, Any]] = None, + modified_output: Optional[Any] = None, + guardrail_metadata: Optional[Dict[str, Any]] = None, + guardrail_key: Optional[str] = None, + is_input: bool = True, +) -> Tuple[Any, Dict[str, Any]]: + """Handle different block strategies for guardrails. 
+ + Args: + guardrail: The guardrail instance + result: The GuardrailResult with block action + modified_inputs: Current inputs (for input guardrails) + modified_output: Current output (for output guardrails) + guardrail_metadata: Current guardrail metadata + guardrail_key: Key for storing metadata + is_input: True if this is an input guardrail, False for output + + Returns: + Tuple of (data, metadata) or raises exception based on strategy + """ + from ..guardrails.base import BlockStrategy, GuardrailBlockedException + + strategy = getattr(result, "block_strategy", None) + if strategy is None: + strategy = BlockStrategy.RAISE_EXCEPTION + + # Update metadata to reflect the blocking strategy used + if guardrail_metadata is not None and guardrail_key is not None: + guardrail_metadata[guardrail_key].update( + { + "block_strategy": strategy.value, + "error_message": getattr(result, "error_message", None), + } + ) + + if strategy == BlockStrategy.RAISE_EXCEPTION: + # Original behavior - raise exception (breaks pipeline) + raise GuardrailBlockedException( + guardrail_name=guardrail.name, + reason=result.reason + or f"{'Input' if is_input else 'Output'} blocked by guardrail", + metadata=result.metadata, + ) + + elif strategy == BlockStrategy.RETURN_EMPTY: + # Return empty/None response (graceful) + if is_input: + # For input blocking, return empty inputs + empty_inputs = {key: "" for key in (modified_inputs or {})} + logger.info( + "Guardrail %s blocked input, returning empty inputs", guardrail.name + ) + return empty_inputs, guardrail_metadata or {} + else: + # For output blocking, return None + logger.info("Guardrail %s blocked output, returning None", guardrail.name) + return None, guardrail_metadata or {} + + elif strategy == BlockStrategy.RETURN_ERROR_MESSAGE: + # Return error message (graceful) + error_msg = getattr( + result, "error_message", "Request blocked due to policy violation" + ) + logger.info( + "Guardrail %s blocked %s, returning error message", + 
guardrail.name, + "input" if is_input else "output", + ) + + if is_input: + # For input blocking, replace inputs with error message + error_inputs = {key: error_msg for key in (modified_inputs or {})} + return error_inputs, guardrail_metadata or {} + else: + # For output blocking, return error message + return error_msg, guardrail_metadata or {} + + elif strategy == BlockStrategy.SKIP_FUNCTION: + # Skip function execution, return None (graceful) + logger.info( + "Guardrail %s blocked %s, skipping execution", + guardrail.name, + "input" if is_input else "output", + ) + + if is_input: + # For input blocking, this will be handled by the main wrapper + # We'll use a special marker to indicate function should be skipped + class SkipFunctionExecution: + pass + + return SkipFunctionExecution(), guardrail_metadata or {} + else: + # For output blocking, return None + return None, guardrail_metadata or {} + + else: + # Fallback to raising exception + logger.warning( + "Unknown block strategy %s, falling back to raising exception", strategy + ) + raise GuardrailBlockedException( + guardrail_name=guardrail.name, + reason=result.reason + or f"{'Input' if is_input else 'Output'} blocked by guardrail", + metadata=result.metadata, + ) diff --git a/tests/test_tracing_core.py b/tests/test_tracing_core.py new file mode 100644 index 00000000..7ebdda3d --- /dev/null +++ b/tests/test_tracing_core.py @@ -0,0 +1,607 @@ +"""Core tracing functionality tests. 
+ +Usage: +pytest tests/test_tracing_core.py -v +""" + +# ruff: noqa: ARG001 +# pyright: reportUnknownMemberType=false +# pyright: reportUnknownArgumentType=false + +import asyncio +from typing import Any, Set, Dict, List, Generator +from unittest.mock import patch + +import pytest + +from openlayer.lib.tracing import enums, steps, tracer, traces + + +class TestBasicTracing: + """Test basic tracing functionality.""" + + def setup_method(self) -> None: + """Setup before each test - reset global state.""" + tracer._configured_api_key = None + tracer._configured_pipeline_id = None + tracer._configured_base_url = None + tracer._client = None + + def teardown_method(self) -> None: + """Cleanup after each test.""" + tracer._configured_api_key = None + tracer._configured_pipeline_id = None + tracer._configured_base_url = None + tracer._client = None + + @patch.object(tracer, "_publish", False) + def test_sync_function_tracing(self) -> None: + """Test that sync functions are traced correctly.""" + + @tracer.trace() + def simple_function(x: int, y: str = "default") -> str: + return f"{y}: {x}" + + result = simple_function(42, "test") + assert result == "test: 42" + + @patch.object(tracer, "_publish", False) + def test_async_function_tracing(self) -> None: + """Test that async functions are traced correctly.""" + + @tracer.trace_async() + async def async_function(x: int) -> int: + await asyncio.sleep(0.001) + return x * 2 + + result = asyncio.run(async_function(21)) + assert result == 42 + + @patch.object(tracer, "_publish", False) + def test_sync_generator_tracing(self) -> None: + """Test that sync generators are traced correctly.""" + + @tracer.trace() + def generator_function(n: int) -> Generator[int, None, None]: + for i in range(n): + yield i + + gen = generator_function(3) + results = list(gen) + assert results == [0, 1, 2] + + @patch.object(tracer, "_publish", False) + def test_nested_tracing(self) -> None: + """Test that nested traced functions work correctly.""" + + 
@tracer.trace() + def inner_function(x: int) -> int: + return x * 2 + + @tracer.trace() + def outer_function(x: int) -> int: + return inner_function(x) + 1 + + result = outer_function(5) + assert result == 11 + + +class TestContextManagement: + """Test context management functionality.""" + + def setup_method(self) -> None: + tracer._configured_api_key = None + tracer._configured_pipeline_id = None + tracer._configured_base_url = None + tracer._client = None + + def teardown_method(self) -> None: + tracer._configured_api_key = None + tracer._configured_pipeline_id = None + tracer._configured_base_url = None + tracer._client = None + + @patch.object(tracer, "_publish", False) + def test_create_step_context_manager(self) -> None: + """Test the create_step context manager.""" + with tracer.create_step("test_step") as step: + assert step.name == "test_step" + + +class TestTraceDataStructure: + """Test trace data structure and content.""" + + def setup_method(self) -> None: + tracer._configured_api_key = None + tracer._configured_pipeline_id = None + tracer._configured_base_url = None + tracer._client = None + + def teardown_method(self) -> None: + tracer._configured_api_key = None + tracer._configured_pipeline_id = None + tracer._configured_base_url = None + tracer._client = None + + @patch.object(tracer, "_publish", False) + def test_trace_captures_inputs_and_outputs(self) -> None: + """Test that trace captures function inputs and outputs correctly.""" + captured_trace = None + + @tracer.trace() + def test_function(a: int, b: str = "default", **kwargs: Any) -> Dict[str, Any]: + current_trace = tracer.get_current_trace() + nonlocal captured_trace + captured_trace = current_trace + return {"result": a * 2, "message": b} + + result = test_function(42, "test", extra="value") + + # Verify function result + assert result == {"result": 84, "message": "test"} + + # Verify trace structure + assert captured_trace is not None + assert len(captured_trace.steps) == 1 + + root_step 
= captured_trace.steps[0] + assert root_step.name == "test_function" + assert root_step.step_type == enums.StepType.USER_CALL + + # Verify inputs were captured (excluding 'self' and 'cls') + assert "a" in root_step.inputs + assert root_step.inputs["a"] == 42 + assert root_step.inputs["b"] == "test" + assert root_step.inputs["kwargs"] == {"extra": "value"} + + # Verify output was captured + assert root_step.output == {"result": 84, "message": "test"} + + # Verify timing data + assert root_step.start_time is not None + assert root_step.end_time is not None + assert root_step.latency is not None + assert root_step.latency > 0 # Should have some latency + + @patch.object(tracer, "_publish", False) + def test_nested_trace_structure(self) -> None: + """Test that nested traces create proper parent-child relationships.""" + captured_trace = None + + @tracer.trace() + def inner_function(x: int) -> int: + return x * 3 + + @tracer.trace() + def middle_function(x: int) -> int: + return inner_function(x) + 10 + + @tracer.trace() + def outer_function(x: int) -> int: + current_trace = tracer.get_current_trace() + nonlocal captured_trace + captured_trace = current_trace + return middle_function(x) + 1 + + result = outer_function(5) + assert result == 26 # (5 * 3) + 10 + 1 + + # Verify trace structure + assert captured_trace is not None + assert len(captured_trace.steps) == 1 # Only root step at trace level + + root_step = captured_trace.steps[0] + assert root_step.name == "outer_function" + + # Verify nested structure + assert len(root_step.steps) == 1 # middle_function + middle_step = root_step.steps[0] + assert middle_step.name == "middle_function" + + assert len(middle_step.steps) == 1 # inner_function + inner_step = middle_step.steps[0] + assert inner_step.name == "inner_function" + assert len(inner_step.steps) == 0 # leaf node + + # Verify all steps have proper data + assert inner_step.inputs["x"] == 5 + assert inner_step.output == 15 + assert middle_step.inputs["x"] == 5 + 
assert middle_step.output == 25 + assert root_step.inputs["x"] == 5 + assert root_step.output == 26 + + @patch.object(tracer, "_publish", False) + def test_step_timing_data(self) -> None: + """Test that step timing data is captured correctly.""" + + @tracer.trace() + def timed_function() -> str: + import time + + time.sleep(0.01) # 10ms delay + return "done" + + result = timed_function() + + assert result == "done" + + # Get the trace to examine timing + with tracer.create_step("dummy") as _dummy_step: + pass # This will finish the previous trace + + # The timing should be reasonable + # Note: We can't access the previous trace easily, so let's test timing + # with a context manager approach + + @patch.object(tracer, "_publish", False) + def test_step_ids_are_unique(self) -> None: + """Test that each step gets a unique ID.""" + step_ids: Set[str] = set() + + @tracer.trace() + def function1() -> str: + step = tracer.get_current_step() + if step is not None: + step_ids.add(str(step.id)) + return "result1" + + @tracer.trace() + def function2() -> str: + step = tracer.get_current_step() + if step is not None: + step_ids.add(str(step.id)) + return "result2" + + function1() + function2() + + # Should have 2 unique IDs + assert len(step_ids) == 2 + + @patch.object(tracer, "_publish", False) + def test_context_kwarg_functionality(self) -> None: + """Test that context_kwarg properly captures context data.""" + captured_context = None + + @tracer.trace(context_kwarg="context_data") + def rag_function(query: str, context_data: List[str]) -> str: # noqa: ARG001 + nonlocal captured_context + captured_context = tracer.get_rag_context() + return f"Answer for {query} using context" + + context_list = ["context1", "context2", "context3"] + result = rag_function("test query", context_list) + + assert result == "Answer for test query using context" + assert captured_context == context_list + + +class TestTraceMetadata: + """Test trace metadata functionality.""" + + def 
setup_method(self) -> None: + tracer._configured_api_key = None + tracer._configured_pipeline_id = None + tracer._configured_base_url = None + tracer._client = None + + def teardown_method(self) -> None: + tracer._configured_api_key = None + tracer._configured_pipeline_id = None + tracer._configured_base_url = None + tracer._client = None + + @patch.object(tracer, "_publish", False) + def test_update_current_trace_metadata(self) -> None: + """Test that trace metadata can be updated during execution.""" + captured_trace = None + + @tracer.trace() + def test_function() -> str: + tracer.update_current_trace( + user_id="user123", session_id="session456", custom_field="custom_value" + ) + nonlocal captured_trace + captured_trace = tracer.get_current_trace() + return "result" + + test_function() + + assert captured_trace is not None + assert captured_trace.metadata is not None + assert captured_trace.metadata["user_id"] == "user123" + assert captured_trace.metadata["session_id"] == "session456" + assert captured_trace.metadata["custom_field"] == "custom_value" + + @patch.object(tracer, "_publish", False) + def test_update_current_step_metadata(self) -> None: + """Test that step metadata can be updated during execution.""" + captured_step = None + + @tracer.trace() + def test_function() -> str: + tracer.update_current_step( + metadata={"model_version": "v1.2.3"}, + attributes={"custom_attr": "value"}, + ) + nonlocal captured_step + captured_step = tracer.get_current_step() + return "result" + + test_function() + + assert captured_step is not None + assert captured_step.metadata is not None + assert captured_step.metadata["model_version"] == "v1.2.3" + assert captured_step.custom_attr == "value" + + @patch.object(tracer, "_publish", False) + def test_log_output_overrides_function_output(self) -> None: + """Test that log_output overrides the function's return value in trace.""" + captured_step = None + + @tracer.trace() + def test_function() -> str: + 
tracer.log_output("manual output") + nonlocal captured_step + captured_step = tracer.get_current_step() + return "function output" # This should be overridden + + result = test_function() + + # Function still returns its normal output + assert result == "function output" + + # But trace should show manual output + # Note: The manual output logging happens via metadata flag + assert captured_step is not None + assert captured_step.metadata is not None + assert captured_step.metadata.get("manual_output_logged") is True + + +class TestTraceSerialization: + """Test trace serialization and post-processing.""" + + def setup_method(self) -> None: + tracer._configured_api_key = None + tracer._configured_pipeline_id = None + tracer._configured_base_url = None + tracer._client = None + + def teardown_method(self) -> None: + tracer._configured_api_key = None + tracer._configured_pipeline_id = None + tracer._configured_base_url = None + tracer._client = None + + def test_step_to_dict_format(self) -> None: + """Test step serialization format.""" + step = steps.Step( + name="test_step", + inputs={"input1": "value1", "input2": 42}, + output={"result": "success"}, + metadata={"meta1": "metavalue1"}, + ) + step.step_type = enums.StepType.USER_CALL + # Fix the assignment issue by setting the end_time and latency properly + step.end_time = step.start_time + 0.1 # type: ignore + step.latency = 100.0 # type: ignore + + step_dict = step.to_dict() + + # Verify required fields + assert step_dict["name"] == "test_step" + assert step_dict["type"] == "user_call" + assert "id" in step_dict + assert "startTime" in step_dict + assert step_dict["endTime"] is not None + assert step_dict["latency"] == 100.0 + assert step_dict["inputs"] == {"input1": "value1", "input2": 42} + assert step_dict["output"] == {"result": "success"} + assert step_dict["metadata"] == {"meta1": "metavalue1"} + + def test_trace_to_dict_format(self) -> None: + """Test trace serialization format.""" + trace = traces.Trace() + 
+ # Add a step to the trace + step = steps.Step(name="root_step") + step.step_type = enums.StepType.USER_CALL + trace.add_step(step) + + # Add nested step + nested_step = steps.Step(name="nested_step") + nested_step.step_type = enums.StepType.CHAT_COMPLETION + step.add_nested_step(nested_step) + + trace_dict = trace.to_dict() + + assert isinstance(trace_dict, list) + assert len(trace_dict) == 1 + assert trace_dict[0]["name"] == "root_step" + assert len(trace_dict[0]["steps"]) == 1 + assert trace_dict[0]["steps"][0]["name"] == "nested_step" + + @patch.object(tracer, "_publish", False) + def test_post_process_trace_format(self) -> None: + """Test the post_process_trace function output format.""" + captured_trace = None + + @tracer.trace() + def test_function(param1: str, param2: int) -> Dict[str, str]: # noqa: ARG001 + tracer.update_current_trace(user_id="test_user") + tracer.log_context(["context1", "context2"]) + nonlocal captured_trace + captured_trace = tracer.get_current_trace() + return {"answer": "test response"} + + test_function("test_param", 42) + + # Process the trace + assert captured_trace is not None + trace_data, input_variable_names = tracer.post_process_trace(captured_trace) + + # Verify trace_data structure + assert isinstance(trace_data, dict) + + # Check required fields + required_fields = [ + "inferenceTimestamp", + "inferenceId", + "output", + "latency", + "cost", + "tokens", + "steps", + ] + for field in required_fields: + assert field in trace_data, f"Missing field: {field}" + + # Verify input variables + assert "param1" in input_variable_names + assert "param2" in input_variable_names + assert trace_data["param1"] == "test_param" + assert trace_data["param2"] == 42 + + # Verify trace-level metadata was included + assert trace_data["user_id"] == "test_user" + + # Verify context was captured + assert trace_data["context"] == ["context1", "context2"] + + # Verify steps structure + assert isinstance(trace_data["steps"], list) + assert 
len(trace_data["steps"]) == 1 + assert trace_data["steps"][0]["name"] == "test_function" + + +class TestStepTypes: + """Test different step types and their specific behavior.""" + + def setup_method(self) -> None: + tracer._configured_api_key = None + tracer._configured_pipeline_id = None + tracer._configured_base_url = None + tracer._client = None + + def teardown_method(self) -> None: + tracer._configured_api_key = None + tracer._configured_pipeline_id = None + tracer._configured_base_url = None + tracer._client = None + + def test_step_factory_creates_correct_types(self) -> None: + """Test that step factory creates the correct step types.""" + step_types_mapping = { + enums.StepType.USER_CALL: steps.UserCallStep, + enums.StepType.CHAT_COMPLETION: steps.ChatCompletionStep, + enums.StepType.AGENT: steps.AgentStep, + enums.StepType.RETRIEVER: steps.RetrieverStep, + enums.StepType.TOOL: steps.ToolStep, + enums.StepType.GUARDRAIL: steps.GuardrailStep, + } + + for step_type, expected_class in step_types_mapping.items(): + step = steps.step_factory(step_type, name=f"test_{step_type.value}") + assert isinstance(step, expected_class) + assert step.step_type == step_type + + def test_chat_completion_step_serialization(self) -> None: + """Test ChatCompletionStep specific serialization.""" + step = steps.ChatCompletionStep(name="chat_step") + step.inputs = {"prompt": [{"role": "user", "content": "Hello"}]} + step.model = "gpt-3.5-turbo" + step.provider = "openai" + + step_dict = step.to_dict() + + assert step_dict["type"] == "chat_completion" + assert step_dict["inputs"]["prompt"] == [{"role": "user", "content": "Hello"}] + assert step_dict["model"] == "gpt-3.5-turbo" + assert step_dict["provider"] == "openai" + + @patch.object(tracer, "_publish", False) + def test_add_chat_completion_step(self) -> None: + """Test adding a chat completion step to trace.""" + captured_steps: List[Any] = [] + + with tracer.create_step("main_step") as main_step: + 
tracer.add_chat_completion_step_to_trace( + name="Chat Step", + model="gpt-4", + prompt=[{"role": "user", "content": "Test"}], + provider="openai", + ) + captured_steps = main_step.steps + + assert len(captured_steps) == 1 + chat_step = captured_steps[0] + assert chat_step.name == "Chat Step" + assert chat_step.step_type == enums.StepType.CHAT_COMPLETION + + +class TestErrorHandlingInTraces: + """Test error handling and exception capture in traces.""" + + def setup_method(self) -> None: + tracer._configured_api_key = None + tracer._configured_pipeline_id = None + tracer._configured_base_url = None + tracer._client = None + + def teardown_method(self) -> None: + tracer._configured_api_key = None + tracer._configured_pipeline_id = None + tracer._configured_base_url = None + tracer._client = None + + @patch.object(tracer, "_publish", False) + def test_exception_captured_in_metadata(self) -> None: + """Test that exceptions are captured in step metadata.""" + captured_step = None + + @tracer.trace() + def error_function() -> None: + nonlocal captured_step + captured_step = tracer.get_current_step() + raise ValueError("Test error message") + + with pytest.raises(ValueError, match="Test error message"): + error_function() + + # Verify exception was logged in metadata + assert captured_step is not None + assert captured_step.metadata is not None + assert "Exceptions" in captured_step.metadata + assert "Test error message" in captured_step.metadata["Exceptions"] + + @patch.object(tracer, "_publish", False) + def test_nested_exception_handling(self) -> None: + """Test exception handling in nested traced functions.""" + captured_steps: List[Any] = [] + + @tracer.trace() + def inner_error_function() -> None: + step = tracer.get_current_step() + captured_steps.append(step) + raise RuntimeError("Inner error") + + @tracer.trace() + def outer_function() -> None: + step = tracer.get_current_step() + captured_steps.append(step) + return inner_error_function() + + with 
pytest.raises(RuntimeError, match="Inner error"): + outer_function() + + # Both steps should have been captured + assert len(captured_steps) == 2 + + # Inner step should have exception metadata + inner_step = captured_steps[0] + assert inner_step is not None + assert inner_step.metadata is not None + assert "Exceptions" in inner_step.metadata + assert "Inner error" in inner_step.metadata["Exceptions"]