From 682f49cb9b16ae7880ba787f2958fa36b0610422 Mon Sep 17 00:00:00 2001
From: Rohan Mehta
Date: Mon, 2 Jun 2025 14:08:56 -0400
Subject: [PATCH] Don't cache agent tools during a run

### Summary:

Towards #767. We were caching the list of tools for an agent, so if you did
`agent.tools.append(...)` from a tool call, the next call to the model
wouldn't include the new tool. This is a bug. (A short sketch of the
now-supported pattern is appended after the diff.)

### Test Plan:

Unit tests. Note that MCP tools are now listed each time the agent runs
(users can still cache the `list_tools` result themselves, however; see the
note appended after the diff).

---
 src/agents/exceptions.py                      |  2 +
 src/agents/extensions/models/litellm_model.py |  6 +-
 src/agents/mcp/server.py                      | 10 ++--
 src/agents/result.py                          |  1 -
 src/agents/run.py                             | 15 ++---
 src/agents/voice/model.py                     |  2 +
 tests/mcp/test_mcp_tracing.py                 | 60 ++++++++++++-------
 tests/models/test_litellm_extra_body.py       |  3 +-
 tests/test_agent_runner.py                    | 35 +++++++++++
 tests/test_agent_runner_streamed.py           | 37 ++++++++++++
 tests/test_run_error_details.py               | 20 ++++---
 11 files changed, 143 insertions(+), 48 deletions(-)

diff --git a/src/agents/exceptions.py b/src/agents/exceptions.py
index 4f6e2e768..c00024c2e 100644
--- a/src/agents/exceptions.py
+++ b/src/agents/exceptions.py
@@ -15,6 +15,7 @@
 @dataclass
 class RunErrorDetails:
     """Data collected from an agent run when an exception occurs."""
+
     input: str | list[TResponseInputItem]
     new_items: list[RunItem]
     raw_responses: list[ModelResponse]
@@ -29,6 +30,7 @@ def __str__(self) -> str:
 
 class AgentsException(Exception):
     """Base class for all exceptions in the Agents SDK."""
+
     run_data: RunErrorDetails | None
 
     def __init__(self, *args: object) -> None:
diff --git a/src/agents/extensions/models/litellm_model.py b/src/agents/extensions/models/litellm_model.py
index 14130c5e3..7cf7a2de5 100644
--- a/src/agents/extensions/models/litellm_model.py
+++ b/src/agents/extensions/models/litellm_model.py
@@ -110,12 +110,14 @@ async def get_response(
                     input_tokens_details=InputTokensDetails(
                         cached_tokens=getattr(
                             response_usage.prompt_tokens_details, "cached_tokens", 0
-                        ) or 0
+                        )
+                        or 0
                     ),
                     output_tokens_details=OutputTokensDetails(
                         reasoning_tokens=getattr(
                             response_usage.completion_tokens_details, "reasoning_tokens", 0
-                        ) or 0
+                        )
+                        or 0
                     ),
                 )
                 if response.usage
diff --git a/src/agents/mcp/server.py b/src/agents/mcp/server.py
index 414b517ab..b012a0968 100644
--- a/src/agents/mcp/server.py
+++ b/src/agents/mcp/server.py
@@ -88,7 +88,7 @@ def create_streams(
         tuple[
             MemoryObjectReceiveStream[SessionMessage | Exception],
             MemoryObjectSendStream[SessionMessage],
-            GetSessionIdCallback | None
+            GetSessionIdCallback | None,
         ]
     ]:
         """Create the streams for the server."""
@@ -243,7 +243,7 @@ def create_streams(
         tuple[
             MemoryObjectReceiveStream[SessionMessage | Exception],
             MemoryObjectSendStream[SessionMessage],
-            GetSessionIdCallback | None
+            GetSessionIdCallback | None,
         ]
     ]:
         """Create the streams for the server."""
@@ -314,7 +314,7 @@ def create_streams(
         tuple[
             MemoryObjectReceiveStream[SessionMessage | Exception],
             MemoryObjectSendStream[SessionMessage],
-            GetSessionIdCallback | None
+            GetSessionIdCallback | None,
         ]
     ]:
         """Create the streams for the server."""
@@ -394,16 +394,16 @@ def create_streams(
         tuple[
             MemoryObjectReceiveStream[SessionMessage | Exception],
             MemoryObjectSendStream[SessionMessage],
-            GetSessionIdCallback | None
+            GetSessionIdCallback | None,
         ]
     ]:
         """Create the streams for the server."""
         return streamablehttp_client(
             url=self.params["url"],
             headers=self.params.get("headers", None),
             timeout=self.params.get("timeout", timedelta(seconds=30)),
             sse_read_timeout=self.params.get("sse_read_timeout", timedelta(seconds=60 * 5)),
-            terminate_on_close=self.params.get("terminate_on_close", True)
+            terminate_on_close=self.params.get("terminate_on_close", True),
         )
 
     @property
diff --git a/src/agents/result.py b/src/agents/result.py
index 764815246..5cf0e74c8 100644
--- a/src/agents/result.py
+++ b/src/agents/result.py
@@ -274,4 +274,3 @@ def _cleanup_tasks(self):
 
     def __str__(self) -> str:
         return pretty_print_run_result_streaming(self)
-
diff --git a/src/agents/run.py b/src/agents/run.py
index c67386495..bfbdacd45 100644
--- a/src/agents/run.py
+++ b/src/agents/run.py
@@ -1,4 +1,3 @@
-
 from __future__ import annotations
 
 import asyncio
@@ -182,6 +181,8 @@ async def run(
 
         try:
             while True:
+                all_tools = await cls._get_all_tools(current_agent)
+
                 # Start an agent span if we don't have one. This span is ended if the current
                 # agent changes, or if the agent loop ends.
                 if current_span is None:
@@ -197,8 +198,6 @@ async def run(
                         output_type=output_type_name,
                     )
                     current_span.start(mark_as_current=True)
-
-                    all_tools = await cls._get_all_tools(current_agent)
                     current_span.span_data.tools = [t.name for t in all_tools]
 
                 current_turn += 1
@@ -210,9 +209,7 @@ async def run(
                             data={"max_turns": max_turns},
                         ),
                     )
-                    raise MaxTurnsExceeded(
-                        f"Max turns ({max_turns}) exceeded"
-                    )
+                    raise MaxTurnsExceeded(f"Max turns ({max_turns}) exceeded")
                 logger.debug(
                     f"Running agent {current_agent.name} (turn {current_turn})",
                 )
@@ -295,7 +292,7 @@ async def run(
                     last_agent=current_agent,
                     context_wrapper=context_wrapper,
                     input_guardrail_results=input_guardrail_results,
-                    output_guardrail_results=[]
+                    output_guardrail_results=[],
                 )
                 raise
         finally:
@@ -528,6 +525,8 @@ async def _run_streamed_impl(
                 if streamed_result.is_complete:
                     break
 
+                all_tools = await cls._get_all_tools(current_agent)
+
                 # Start an agent span if we don't have one. This span is ended if the current
                 # agent changes, or if the agent loop ends.
                 if current_span is None:
@@ -543,8 +542,6 @@ async def _run_streamed_impl(
                         output_type=output_type_name,
                     )
                     current_span.start(mark_as_current=True)
-
-                    all_tools = await cls._get_all_tools(current_agent)
                     tool_names = [t.name for t in all_tools]
                     current_span.span_data.tools = tool_names
                     current_turn += 1
diff --git a/src/agents/voice/model.py b/src/agents/voice/model.py
index c36a4de76..b048a452d 100644
--- a/src/agents/voice/model.py
+++ b/src/agents/voice/model.py
@@ -17,9 +17,11 @@
 TTSVoice = Literal["alloy", "ash", "coral", "echo", "fable", "onyx", "nova", "sage", "shimmer"]
 """Exportable type for the TTSModelSettings voice enum"""
 
+
 @dataclass
 class TTSModelSettings:
     """Settings for a TTS model."""
+
     voice: TTSVoice | None = None
     """
     The voice to use for the TTS model. If not provided, the default voice for the respective model
diff --git a/tests/mcp/test_mcp_tracing.py b/tests/mcp/test_mcp_tracing.py
index b71954b5b..54575dcb5 100644
--- a/tests/mcp/test_mcp_tracing.py
+++ b/tests/mcp/test_mcp_tracing.py
@@ -44,6 +44,10 @@ async def test_mcp_tracing():
         {
             "workflow_name": "Agent workflow",
             "children": [
+                {
+                    "type": "mcp_tools",
+                    "data": {"server": "fake_mcp_server", "result": ["test_tool_1"]},
+                },
                 {
                     "type": "agent",
                     "data": {
@@ -53,10 +57,6 @@ async def test_mcp_tracing():
                         "output_type": "str",
                     },
                     "children": [
-                        {
-                            "type": "mcp_tools",
-                            "data": {"server": "fake_mcp_server", "result": ["test_tool_1"]},
-                        },
                         {
                             "type": "function",
                             "data": {
@@ -66,8 +66,12 @@ async def test_mcp_tracing():
                                 "mcp_data": {"server": "fake_mcp_server"},
                             },
                         },
+                        {
+                            "type": "mcp_tools",
+                            "data": {"server": "fake_mcp_server", "result": ["test_tool_1"]},
+                        },
                     ],
-                }
+                },
             ],
         }
     ]
@@ -100,6 +104,13 @@ async def test_mcp_tracing():
         {
             "workflow_name": "Agent workflow",
             "children": [
+                {
+                    "type": "mcp_tools",
+                    "data": {
+                        "server": "fake_mcp_server",
+                        "result": ["test_tool_1", "test_tool_2"],
+                    },
+                },
                 {
                     "type": "agent",
                     "data": {
@@ -109,13 +120,6 @@ async def test_mcp_tracing():
                         "output_type": "str",
                     },
                     "children": [
-                        {
-                            "type": "mcp_tools",
-                            "data": {
-                                "server": "fake_mcp_server",
-                                "result": ["test_tool_1", "test_tool_2"],
-                            },
-                        },
                         {
                             "type": "function",
                             "data": {
@@ -133,8 +137,15 @@ async def test_mcp_tracing():
                                 "mcp_data": {"server": "fake_mcp_server"},
                             },
                         },
+                        {
+                            "type": "mcp_tools",
+                            "data": {
+                                "server": "fake_mcp_server",
+                                "result": ["test_tool_1", "test_tool_2"],
+                            },
+                        },
                     ],
-                }
+                },
             ],
         }
     ]
@@ -165,6 +176,13 @@ async def test_mcp_tracing():
         {
             "workflow_name": "Agent workflow",
             "children": [
+                {
+                    "type": "mcp_tools",
+                    "data": {
+                        "server": "fake_mcp_server",
+                        "result": ["test_tool_1", "test_tool_2", "test_tool_3"],
+                    },
+                },
                 {
                     "type": "agent",
                     "data": {
@@ -174,13 +192,6 @@ async def test_mcp_tracing():
                         "output_type": "str",
                     },
                     "children": [
-                        {
-                            "type": "mcp_tools",
-                            "data": {
-                                "server": "fake_mcp_server",
-                                "result": ["test_tool_1", "test_tool_2", "test_tool_3"],
-                            },
-                        },
                         {
                             "type": "function",
                             "data": {
@@ -190,8 +201,15 @@ async def test_mcp_tracing():
                                 "mcp_data": {"server": "fake_mcp_server"},
                             },
                         },
+                        {
+                            "type": "mcp_tools",
+                            "data": {
+                                "server": "fake_mcp_server",
+                                "result": ["test_tool_1", "test_tool_2", "test_tool_3"],
+                            },
+                        },
                     ],
-                }
+                },
            ],
         }
     ]
diff --git a/tests/models/test_litellm_extra_body.py b/tests/models/test_litellm_extra_body.py
index ac56c25cf..3c83b0607 100644
--- a/tests/models/test_litellm_extra_body.py
+++ b/tests/models/test_litellm_extra_body.py
@@ -26,8 +26,7 @@ async def fake_acompletion(model, messages=None, **kwargs):
 
     monkeypatch.setattr(litellm, "acompletion", fake_acompletion)
     settings = ModelSettings(
-        temperature=0.1,
-        extra_body={"cached_content": "some_cache", "foo": 123}
+        temperature=0.1, extra_body={"cached_content": "some_cache", "foo": 123}
     )
     model = LitellmModel(model="test-model")
 
diff --git a/tests/test_agent_runner.py b/tests/test_agent_runner.py
index 14a278a90..77eb8019e 100644
--- a/tests/test_agent_runner.py
+++ b/tests/test_agent_runner.py
@@ -745,3 +745,38 @@ async def test_previous_response_id_passed_between_runs_streamed_multi_turn():
         pass
 
     assert model.last_turn_args.get("previous_response_id") == "resp-stream-test"
+
+
+@pytest.mark.asyncio
+async def test_dynamic_tool_addition_run() -> None:
+    """Test that tools can be added to an agent during a run."""
+    model = FakeModel()
+
+    executed: dict[str, bool] = {"called": False}
+
+    agent = Agent(name="test", model=model, tool_use_behavior="run_llm_again")
+
+    @function_tool(name_override="tool2")
+    def tool2() -> str:
+        executed["called"] = True
+        return "result2"
+
+    @function_tool(name_override="add_tool")
+    async def add_tool() -> str:
+        agent.tools.append(tool2)
+        return "added"
+
+    agent.tools.append(add_tool)
+
+    model.add_multiple_turn_outputs(
+        [
+            [get_function_tool_call("add_tool", json.dumps({}))],
+            [get_function_tool_call("tool2", json.dumps({}))],
+            [get_text_message("done")],
+        ]
+    )
+
+    result = await Runner.run(agent, input="start")
+
+    assert executed["called"] is True
+    assert result.final_output == "done"
diff --git a/tests/test_agent_runner_streamed.py b/tests/test_agent_runner_streamed.py
index 87a76a706..31fe2979b 100644
--- a/tests/test_agent_runner_streamed.py
+++ b/tests/test_agent_runner_streamed.py
@@ -18,6 +18,7 @@
     RunContextWrapper,
     Runner,
     UserError,
+    function_tool,
     handoff,
 )
 from agents.items import RunItem
@@ -684,3 +685,39 @@ async def test_streaming_events():
     assert len(agent_data) == 2, "should have 2 agent updated events"
     assert agent_data[0].new_agent == agent_2, "should have started with agent_2"
     assert agent_data[1].new_agent == agent_1, "should have handed off to agent_1"
+
+
+@pytest.mark.asyncio
+async def test_dynamic_tool_addition_run_streamed() -> None:
+    model = FakeModel()
+
+    executed: dict[str, bool] = {"called": False}
+
+    agent = Agent(name="test", model=model, tool_use_behavior="run_llm_again")
+
+    @function_tool(name_override="tool2")
+    def tool2() -> str:
+        executed["called"] = True
+        return "result2"
+
+    @function_tool(name_override="add_tool")
+    async def add_tool() -> str:
+        agent.tools.append(tool2)
+        return "added"
+
+    agent.tools.append(add_tool)
+
+    model.add_multiple_turn_outputs(
+        [
+            [get_function_tool_call("add_tool", json.dumps({}))],
+            [get_function_tool_call("tool2", json.dumps({}))],
+            [get_text_message("done")],
+        ]
+    )
+
+    result = Runner.run_streamed(agent, input="start")
+    async for _ in result.stream_events():
+        pass
+
+    assert executed["called"] is True
+    assert result.final_output == "done"
diff --git a/tests/test_run_error_details.py b/tests/test_run_error_details.py
index 2268b3780..104b248fc 100644
--- a/tests/test_run_error_details.py
+++ b/tests/test_run_error_details.py
@@ -12,10 +12,12 @@ async def test_run_error_includes_data():
     model = FakeModel()
     agent = Agent(name="test", model=model, tools=[get_function_tool("foo", "res")])
 
-    model.add_multiple_turn_outputs([
-        [get_text_message("1"), get_function_tool_call("foo", json.dumps({"a": "b"}))],
-        [get_text_message("done")],
-    ])
+    model.add_multiple_turn_outputs(
+        [
+            [get_text_message("1"), get_function_tool_call("foo", json.dumps({"a": "b"}))],
+            [get_text_message("done")],
+        ]
+    )
     with pytest.raises(MaxTurnsExceeded) as exc:
         await Runner.run(agent, input="hello", max_turns=1)
     data = exc.value.run_data
@@ -29,10 +31,12 @@ async def test_streamed_run_error_includes_data():
     model = FakeModel()
     agent = Agent(name="test", model=model, tools=[get_function_tool("foo", "res")])
 
-    model.add_multiple_turn_outputs([
-        [get_text_message("1"), get_function_tool_call("foo", json.dumps({"a": "b"}))],
-        [get_text_message("done")],
-    ])
+    model.add_multiple_turn_outputs(
+        [
+            [get_text_message("1"), get_function_tool_call("foo", json.dumps({"a": "b"}))],
+            [get_text_message("done")],
+        ]
+    )
     result = Runner.run_streamed(agent, input="hello", max_turns=1)
     with pytest.raises(MaxTurnsExceeded) as exc:
         async for _ in result.stream_events():
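
As referenced in the summary, here is a minimal sketch of the pattern this fix enables, adapted from the new unit tests in this patch. The agent and tool names are illustrative only, and `Runner.run` must be awaited from an async context:

```python
from agents import Agent, Runner, function_tool

agent = Agent(name="assistant", tools=[])

@function_tool
def multiply(a: int, b: int) -> int:
    """Multiply two integers."""
    return a * b

@function_tool
def enable_math() -> str:
    # Tools are now re-fetched at the top of every turn, so a tool appended
    # mid-run is visible to the model on the very next turn.
    agent.tools.append(multiply)
    return "math tools enabled"

agent.tools.append(enable_math)

# Inside an async context:
# result = await Runner.run(agent, input="Enable math, then compute 6 * 7.")
# print(result.final_output)
```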
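On the test-plan note about MCP: because tools are re-fetched each turn, `list_tools` is now called on every MCP server once per turn. If a server's tool set is static, callers can opt back into caching on the server object itself. A sketch, assuming the `cache_tools_list` constructor flag the SDK's MCP server classes expose for this (the command/args below are placeholders for a real MCP server):

```python
from agents.mcp import MCPServerStdio

# cache_tools_list=True caches the server's list_tools() result, so the
# per-turn re-fetch hits the cache instead of a round trip to the server.
server = MCPServerStdio(
    params={"command": "npx", "args": ["-y", "@modelcontextprotocol/server-filesystem", "."]},
    cache_tools_list=True,
)
```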