Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions src/agents/tool.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import functools
import inspect
import json
from collections.abc import Awaitable
Expand Down Expand Up @@ -179,6 +180,10 @@ class FunctionTool:
and returns whether the tool is enabled. You can use this to dynamically enable/disable a tool
based on your context/state."""

_func: ToolFunction[...] | None = field(default=None, repr=False)
"""The function that implements the tool. Ensures that a reference to the
original function exists when @function_tool is used."""

# Tool-specific guardrails
tool_input_guardrails: list[ToolInputGuardrail[Any]] | None = None
"""Optional list of input guardrails to run before invoking this tool."""
Expand All @@ -190,6 +195,17 @@ def __post_init__(self):
if self.strict_json_schema:
self.params_json_schema = ensure_strict_json_schema(self.params_json_schema)

if self._func:
functools.update_wrapper(self, self._func)

def __call__(self, *args, **kwargs):
if not self._func:
raise AttributeError("""FunctionTool has no attribute `_func` and is not callable.
Likely because it was created directly without the
@function_tool decorator.""")

return self._func(*args, **kwargs)


@dataclass
class FileSearchTool:
Expand Down Expand Up @@ -661,6 +677,7 @@ async def _on_invoke_tool(ctx: ToolContext[Any], input: str) -> Any:
on_invoke_tool=_on_invoke_tool,
strict_json_schema=strict_mode,
is_enabled=is_enabled,
_func=func,
)

# If func is actually a callable, we were used as @function_tool with no parentheses
Expand Down
6 changes: 3 additions & 3 deletions tests/extensions/memory/test_advanced_sqlite_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,15 @@


@function_tool
async def test_tool(query: str) -> str:
async def _test_tool(query: str) -> str:
"""A test tool for testing tool call tracking."""
return f"Tool result for: {query}"


@pytest.fixture
def agent() -> Agent:
"""Fixture for a basic agent with a fake model."""
return Agent(name="test", model=FakeModel(), tools=[test_tool])
return Agent(name="test", model=FakeModel(), tools=[_test_tool])


@pytest.fixture
Expand Down Expand Up @@ -961,7 +961,7 @@ async def test_tool_execution_integration(agent: Agent):
[
{ # type: ignore
"type": "function_call",
"name": "test_tool",
"name": "_test_tool",
"arguments": '{"query": "test query"}',
"call_id": "call_123",
}
Expand Down
50 changes: 50 additions & 0 deletions tests/test_function_tool.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import inspect
import json
from dataclasses import asdict
from typing import Any

import pytest
Expand Down Expand Up @@ -81,6 +83,44 @@ async def test_simple_function():
ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments=""), ""
)

# Direct call
result = tool(2, 2)
assert result == 4


async def async_function(a: int, b: int = 5):
return a + b


@pytest.mark.asyncio
async def test_async_function():
tool = function_tool(async_function, failure_error_function=None)
assert tool.name == "async_function"

result = await tool.on_invoke_tool(
ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments='{"a": 1}'),
'{"a": 1}',
)
assert result == 6

result = await tool.on_invoke_tool(
ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments='{"a": 1, "b": 2}'),
'{"a": 1, "b": 2}',
)
assert result == 3

# Missing required argument should raise an error
with pytest.raises(ModelBehaviorError):
await tool.on_invoke_tool(
ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments=""), ""
)

# Direct call
result = await tool(2, 2)
assert result == 4

assert not inspect.iscoroutinefunction(tool.__call__), "tool.__call__ should sync."


class Foo(BaseModel):
a: int
Expand Down Expand Up @@ -148,6 +188,16 @@ async def test_complex_args_function():
)


def test_absent_func_tool():
tool = function_tool(simple_function)
kwargs = asdict(tool)
kwargs.pop("_func")
manually_defined_tool = FunctionTool(**kwargs)

with pytest.raises(AttributeError, match="not callable"):
manually_defined_tool(1, 1)


def test_function_config_overrides():
tool = function_tool(simple_function, name_override="custom_name")
assert tool.name == "custom_name"
Expand Down