@@ -661,7 +661,6 @@ async def _arun_step(
661
661
step , task .extra_state ["new_memory" ], verbose = self ._verbose
662
662
)
663
663
664
- # TODO: see if we want to do step-based inputs
665
664
tools = self .get_tools (task .input )
666
665
openai_tools = [tool .metadata .to_openai_tool () for tool in tools ]
667
666
@@ -670,40 +669,46 @@ async def _arun_step(
670
669
task , mode = mode , ** llm_chat_kwargs
671
670
)
672
671
673
- # TODO: implement _should_continue
674
672
latest_tool_calls = self .get_latest_tool_calls (task ) or []
675
673
latest_tool_outputs : List [ToolOutput ] = []
676
674
677
675
if not self ._should_continue (
678
676
latest_tool_calls , task .extra_state ["n_function_calls" ]
679
677
):
680
678
is_done = True
681
-
682
679
else :
683
680
is_done = False
681
+
682
+ # Validate all tool calls first
684
683
for tool_call in latest_tool_calls :
685
- # Some validation
686
684
if not isinstance (tool_call , get_args (OpenAIToolCall )):
687
685
raise ValueError ("Invalid tool_call object" )
688
-
689
686
if tool_call .type != "function" :
690
687
raise ValueError ("Invalid tool type. Unsupported by OpenAI" )
691
688
692
- # TODO: maybe execute this with multi-threading
693
- return_direct = await self ._acall_function (
694
- tools ,
695
- tool_call ,
696
- task .extra_state ["new_memory" ],
697
- latest_tool_outputs ,
698
- )
689
+ # Execute all tool calls in parallel using asyncio.gather
690
+ tool_results = await asyncio .gather (
691
+ * [
692
+ self ._acall_function (
693
+ tools ,
694
+ tool_call ,
695
+ task .extra_state ["new_memory" ],
696
+ latest_tool_outputs ,
697
+ )
698
+ for tool_call in latest_tool_calls
699
+ ]
700
+ )
701
+
702
+ # Process results
703
+ for return_direct in tool_results :
699
704
task .extra_state ["sources" ].append (latest_tool_outputs [- 1 ])
700
705
701
- # change function call to the default value, if a custom function was given
702
- # as an argument (none and auto are predefined by OpenAI)
706
+ # change function call to the default value if a custom function was given
703
707
if tool_choice not in ("auto" , "none" ):
704
708
tool_choice = "auto"
705
709
task .extra_state ["n_function_calls" ] += 1
706
710
711
+ # If any tool call requests direct return and it's the only call
707
712
if return_direct and len (latest_tool_calls ) == 1 :
708
713
is_done = True
709
714
response_str = latest_tool_outputs [- 1 ].content
@@ -723,7 +728,6 @@ async def _arun_step(
723
728
[
724
729
step .get_next_step (
725
730
step_id = str (uuid .uuid4 ()),
726
- # NOTE: input is unused
727
731
input = None ,
728
732
)
729
733
]
0 commit comments