diff --git a/src/agents/__init__.py b/src/agents/__init__.py index afa578b5..d2e0857e 100644 --- a/src/agents/__init__.py +++ b/src/agents/__init__.py @@ -47,7 +47,7 @@ from .models.openai_responses import OpenAIResponsesModel from .repl import run_demo_loop from .result import RunResult, RunResultStreaming -from .run import DefaultRunner, RunConfig, Runner, set_default_runner +from .run import AgentRunner, DefaultAgentRunner, RunConfig, Runner from .run_context import RunContextWrapper, TContext from .stream_events import ( AgentUpdatedStreamEvent, @@ -162,7 +162,8 @@ def enable_verbose_stdout_logging(): "ToolsToFinalOutputFunction", "ToolsToFinalOutputResult", "Runner", - "DefaultRunner", + "AgentRunner", + "DefaultAgentRunner", "run_demo_loop", "Model", "ModelProvider", @@ -241,7 +242,6 @@ def enable_verbose_stdout_logging(): "generation_span", "get_current_span", "get_current_trace", - "get_default_runner", "guardrail_span", "handoff_span", "set_trace_processors", @@ -270,7 +270,6 @@ def enable_verbose_stdout_logging(): "set_default_openai_key", "set_default_openai_client", "set_default_openai_api", - "set_default_runner", "set_tracing_export_api_key", "enable_verbose_stdout_logging", "gen_trace_id", diff --git a/src/agents/run.py b/src/agents/run.py index 1c301cb0..ce48d1dc 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -4,9 +4,10 @@ import asyncio import copy from dataclasses import dataclass, field -from typing import Any, cast +from typing import Any, Generic, cast from openai.types.responses import ResponseCompletedEvent +from typing_extensions import TypedDict, Unpack from ._run_impl import ( AgentToolUseTracker, @@ -47,23 +48,6 @@ from .util import _coro, _error_tracing DEFAULT_MAX_TURNS = 10 -DEFAULT_RUNNER: Runner = None # type: ignore -# assigned at the end of the module initialization - - -def set_default_runner(runner: Runner | None) -> None: - """ - Set the default runner to use for the agent run. - """ - global DEFAULT_RUNNER - DEFAULT_RUNNER = runner or DefaultRunner() - -def get_default_runner() -> Runner | None: - """ - Get the default runner to use for the agent run. - """ - global DEFAULT_RUNNER - return DEFAULT_RUNNER @dataclass @@ -125,48 +109,57 @@ class RunConfig: """ -class Runner(abc.ABC): +class AgentRunnerParams(TypedDict, Generic[TContext]): + """Arguments for ``AgentRunner`` methods.""" + + context: TContext | None + """The context for the run.""" + + max_turns: int + """The maximum number of turns to run for.""" + + hooks: RunHooks[TContext] | None + """Lifecycle hooks for the run.""" + + run_config: RunConfig | None + """Run configuration.""" + + previous_response_id: str | None + """The ID of the previous response, if any.""" + + +class AgentRunner(abc.ABC): @abc.abstractmethod - async def _run_impl( + async def run( self, starting_agent: Agent[TContext], input: str | list[TResponseInputItem], - *, - context: TContext | None = None, - max_turns: int = DEFAULT_MAX_TURNS, - hooks: RunHooks[TContext] | None = None, - run_config: RunConfig | None = None, - previous_response_id: str | None = None, + **kwargs: Unpack[AgentRunnerParams[TContext]], ) -> RunResult: pass @abc.abstractmethod - def _run_sync_impl( + def run_sync( self, starting_agent: Agent[TContext], input: str | list[TResponseInputItem], - *, - context: TContext | None = None, - max_turns: int = DEFAULT_MAX_TURNS, - hooks: RunHooks[TContext] | None = None, - run_config: RunConfig | None = None, - previous_response_id: str | None = None, + **kwargs: Unpack[AgentRunnerParams[TContext]], ) -> RunResult: pass @abc.abstractmethod - def _run_streamed_impl( + def run_streamed( self, starting_agent: Agent[TContext], input: str | list[TResponseInputItem], - context: TContext | None = None, - max_turns: int = DEFAULT_MAX_TURNS, - hooks: RunHooks[TContext] | None = None, - run_config: RunConfig | None = None, - previous_response_id: str | None = None, + **kwargs: Unpack[AgentRunnerParams[TContext]], ) -> RunResultStreaming: pass + +class Runner: + pass + @classmethod async def run( cls, @@ -205,8 +198,8 @@ async def run( A run result containing all the inputs, guardrail results and the output of the last agent. Agents may perform handoffs, so we don't know the specific type of the output. """ - runner = DEFAULT_RUNNER - return await runner._run_impl( + runner = DefaultAgentRunner() + return await runner.run( starting_agent, input, context=context, @@ -257,8 +250,8 @@ def run_sync( A run result containing all the inputs, guardrail results and the output of the last agent. Agents may perform handoffs, so we don't know the specific type of the output. """ - runner = DEFAULT_RUNNER - return runner._run_sync_impl( + runner = DefaultAgentRunner() + return runner.run_sync( starting_agent, input, context=context, @@ -305,8 +298,8 @@ def run_streamed( Returns: A result object that contains data about the run, as well as a method to stream events. """ - runner = DEFAULT_RUNNER - return runner._run_streamed_impl( + runner = DefaultAgentRunner() + return runner.run_streamed( starting_agent, input, context=context, @@ -316,7 +309,6 @@ def run_streamed( previous_response_id=previous_response_id, ) - @classmethod def _get_output_schema(cls, agent: Agent[Any]) -> AgentOutputSchemaBase | None: if agent.output_type is None or agent.output_type is str: @@ -353,18 +345,19 @@ def _get_model(cls, agent: Agent[Any], run_config: RunConfig) -> Model: return run_config.model_provider.get_model(agent.model) -class DefaultRunner(Runner): - async def _run_impl( + +class DefaultAgentRunner(AgentRunner, Runner): + async def run( # type: ignore[override] self, starting_agent: Agent[TContext], input: str | list[TResponseInputItem], - *, - context: TContext | None = None, - max_turns: int = DEFAULT_MAX_TURNS, - hooks: RunHooks[TContext] | None = None, - run_config: RunConfig | None = None, - previous_response_id: str | None = None, + **kwargs: Unpack[AgentRunnerParams[TContext]], ) -> RunResult: + context = kwargs.get("context") + max_turns = kwargs.get("max_turns", DEFAULT_MAX_TURNS) + hooks = kwargs.get("hooks") + run_config = kwargs.get("run_config") + previous_response_id = kwargs.get("previous_response_id") if hooks is None: hooks = RunHooks[Any]() if run_config is None: @@ -514,17 +507,17 @@ async def _run_impl( if current_span: current_span.finish(reset_current=True) - def _run_sync_impl( + def run_sync( # type: ignore[override] self, starting_agent: Agent[TContext], input: str | list[TResponseInputItem], - *, - context: TContext | None = None, - max_turns: int = DEFAULT_MAX_TURNS, - hooks: RunHooks[TContext] | None = None, - run_config: RunConfig | None = None, - previous_response_id: str | None = None, + **kwargs: Unpack[AgentRunnerParams[TContext]], ) -> RunResult: + context = kwargs.get("context") + max_turns = kwargs.get("max_turns", DEFAULT_MAX_TURNS) + hooks = kwargs.get("hooks") + run_config = kwargs.get("run_config") + previous_response_id = kwargs.get("previous_response_id") return asyncio.get_event_loop().run_until_complete( self.run( starting_agent, @@ -537,16 +530,17 @@ def _run_sync_impl( ) ) - def _run_streamed_impl( + def run_streamed( # type: ignore[override] self, starting_agent: Agent[TContext], input: str | list[TResponseInputItem], - context: TContext | None = None, - max_turns: int = DEFAULT_MAX_TURNS, - hooks: RunHooks[TContext] | None = None, - run_config: RunConfig | None = None, - previous_response_id: str | None = None, + **kwargs: Unpack[AgentRunnerParams[TContext]], ) -> RunResultStreaming: + context = kwargs.get("context") + max_turns = kwargs.get("max_turns", DEFAULT_MAX_TURNS) + hooks = kwargs.get("hooks") + run_config = kwargs.get("run_config") + previous_response_id = kwargs.get("previous_response_id") if hooks is None: hooks = RunHooks[Any]() if run_config is None: @@ -1108,6 +1102,3 @@ async def _get_new_response( context_wrapper.usage.add(new_response.usage) return new_response - - -DEFAULT_RUNNER = DefaultRunner() diff --git a/tests/conftest.py b/tests/conftest.py index f87e8559..7527e11b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,7 +5,6 @@ from agents.models import _openai_shared from agents.models.openai_chatcompletions import OpenAIChatCompletionsModel from agents.models.openai_responses import OpenAIResponsesModel -from agents.run import set_default_runner from agents.tracing import set_trace_processors from agents.tracing.setup import get_trace_provider @@ -34,11 +33,6 @@ def clear_openai_settings(): _openai_shared._use_responses_by_default = True -@pytest.fixture(autouse=True) -def clear_default_runner(): - set_default_runner(None) - - # This fixture will run after all tests end @pytest.fixture(autouse=True, scope="session") def shutdown_trace_provider(): diff --git a/tests/test_run.py b/tests/test_run.py deleted file mode 100644 index 57e33d50..00000000 --- a/tests/test_run.py +++ /dev/null @@ -1,26 +0,0 @@ -from __future__ import annotations - -from unittest import mock - -import pytest - -from agents import Agent, Runner -from agents.run import set_default_runner - -from .fake_model import FakeModel - - -@pytest.mark.asyncio -async def test_static_run_methods_call_into_default_runner() -> None: - runner = mock.Mock(spec=Runner) - set_default_runner(runner) - - agent = Agent(name="test", model=FakeModel()) - await Runner.run(agent, input="test") - runner._run_impl.assert_called_once() - - Runner.run_streamed(agent, input="test") - runner._run_streamed_impl.assert_called_once() - - Runner.run_sync(agent, input="test") - runner._run_sync_impl.assert_called_once()