8000 Introduce AgentRunner class by pakrym-oai · Pull Request #886 · openai/openai-agents-python · GitHub
[go: up one dir, main page]

Skip to content

Introduce AgentRunner class #886

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions src/agents/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
from .models.openai_responses import OpenAIResponsesModel
from .repl import run_demo_loop
from .result import RunResult, RunResultStreaming
from .run import DefaultRunner, RunConfig, Runner, set_default_runner
from .run import AgentRunner, DefaultAgentRunner, RunConfig, Runner
from .run_context import RunContextWrapper, TContext
from .stream_events import (
AgentUpdatedStreamEvent,
Expand Down Expand Up @@ -162,7 +162,8 @@ def enable_verbose_stdout_logging():
"ToolsToFinalOutputFunction",
"ToolsToFinalOutputResult",
"Runner",
"DefaultRunner",
"AgentRunner",
"DefaultAgentRunner",
"run_demo_loop",
"Model",
"ModelProvider",
Expand Down Expand Up @@ -241,7 +242,6 @@ def enable_verbose_stdout_logging():
"generation_span",
"get_current_span",
"get_current_trace",
"get_default_runner",
"guardrail_span",
"handoff_span",
"set_trace_processors",
Expand Down Expand Up @@ -270,7 +270,6 @@ def enable_verbose_stdout_logging():
"set_default_openai_key",
"set_default_openai_client",
"set_default_openai_api",
"set_default_runner",
"set_tracing_export_api_key",
"enable_verbose_stdout_logging",
"gen_trace_id",
Expand Down
131 changes: 61 additions & 70 deletions src/agents/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
import asyncio
import copy
from dataclasses import dataclass, field
from typing import Any, cast
from typing import Any, Generic, cast

from openai.types.responses import ResponseCompletedEvent
from typing_extensions import TypedDict, Unpack

from ._run_impl import (
AgentToolUseTracker,
Expand Down Expand Up @@ -47,23 +48,6 @@
from .util import _coro, _error_tracing

DEFAULT_MAX_TURNS = 10
DEFAULT_RUNNER: Runner = None # type: ignore
# assigned at the end of the module initialization


def set_default_runner(runner: Runner | None) -> None:
"""
Set the default runner to use for the agent run.
"""
global DEFAULT_RUNNER
DEFAULT_RUNNER = runner or DefaultRunner()

def get_default_runner() -> Runner | None:
"""
Get the default runner to use for the agent run.
"""
global DEFAULT_RUNNER
return DEFAULT_RUNNER


@dataclass
Expand Down Expand Up @@ -125,48 +109,57 @@ class RunConfig:
"""


class Runner(abc.ABC):
class AgentRunnerParams(TypedDict, Generic[TContext]):
"""Arguments for ``AgentRunner`` methods."""

context: TContext | None
"""The context for the run."""

max_turns: int
"""The maximum number of turns to run for."""

hooks: RunHooks[TContext] | None
"""Lifecycle hooks for the run."""

run_config: RunConfig | None
"""Run configuration."""

previous_response_id: str | None
"""The ID of the previous response, if any."""


class AgentRunner(abc.ABC):
@abc.abstractmethod
async def _run_impl(
async def run(
self,
starting_agent: Agent[TContext],
input: str | list[TResponseInputItem],
*,
context: TContext | None = None,
max_turns: int = DEFAULT_MAX_TURNS,
hooks: RunHooks[TContext] | None = None,
run_config: RunConfig | None = None,
previous_response_id: str | None = None,
**kwargs: Unpack[AgentRunnerParams[TContext]],
) -> RunResult:
pass

@abc.abstractmethod
def _run_sync_impl(
def run_sync(
self,
starting_agent: Agent[TContext],
input: str | list[TResponseInputItem],
*,
context: TContext | None = None,
max_turns: int = DEFAULT_MAX_TURNS,
hooks: RunHooks[TContext] | None = None,
run_config: RunConfig | None = None,
previous_response_id: str | None = None,
**kwargs: Unpack[AgentRunnerParams[TContext]],
) -> RunResult:
pass

@abc.abstractmethod
def _run_streamed_impl(
def run_streamed(
self,
starting_agent: Agent[TContext],
input: str | list[TResponseInputItem],
context: TContext | None = None,
max_turns: int = DEFAULT_MAX_TURNS,
hooks: RunHooks[TContext] | None = None,
run_config: RunConfig | None = None,
previous_response_id: str | None = None,
**kwargs: Unpack[AgentRunnerParams[TContext]],
) -> RunResultStreaming:
pass


class Runner:
pass

@classmethod
async def run(
cls,
Expand Down Expand Up @@ -205,8 +198,8 @@ async def run(
A run result containing all the inputs, guardrail results and the output of the last
agent. Agents may perform handoffs, so we don't know the specific type of the output.
"""
runner = DEFAULT_RUNNER
return await runner._run_impl(
runner = DefaultAgentRunner()
return await runner.run(
starting_agent,
input,
context=context,
Expand Down Expand Up @@ -257,8 +250,8 @@ def run_sync(
A run result containing all the inputs, guardrail results and the output of the last
agent. Agents may perform handoffs, so we don't know the specific type of the output.
"""
runner = DEFAULT_RUNNER
return runner._run_sync_impl(
runner = DefaultAgentRunner()
return runner.run_sync(
starting_agent,
input,
context=context,
Expand Down Expand Up @@ -305,8 +298,8 @@ def run_streamed(
Returns:
A result object that contains data about the run, as well as a method to stream events.
"""
runner = DEFAULT_RUNNER
return runner._run_streamed_impl(
runner = DefaultAgentRunner()
return runner.run_streamed(
starting_agent,
input,
context=context,
Expand All @@ -316,7 +309,6 @@ def run_streamed(
previous_response_id=previous_response_id,
)


@classmethod
def _get_output_schema(cls, agent: Agent[Any]) -> AgentOutputSchemaBase | None:
if agent.output_type is None or agent.output_type is str:
Expand Down Expand Up @@ -353,18 +345,19 @@ def _get_model(cls, agent: Agent[Any], run_config: RunConfig) -> Model:

return run_config.model_provider.get_model(agent.model)

class DefaultRunner(Runner):
async def _run_impl(

class DefaultAgentRunner(AgentRunner, Runner):
async def run( # type: ignore[override]
self,
starting_agent: Agent[TContext],
input: str | list[TResponseInputItem],
*,
context: TContext | None = None,
max_turns: int = DEFAULT_MAX_TURNS,
hooks: RunHooks[TContext] | None = None,
run_config: RunConfig | None = None,
previous_response_id: str | None = None,
**kwargs: Unpack[AgentRunnerParams[TContext]],
) -> RunResult:
context = kwargs.get("context")
max_turns = kwargs.get("max_turns", DEFAULT_MAX_TURNS)
hooks = kwargs.get("hooks")
run_config = kwargs.get("run_config")
previous_response_id = kwargs.get("previous_response_id")
if hooks is None:
hooks = RunHooks[Any]()
if run_config is None:
Expand Down Expand Up @@ -514,17 +507,17 @@ async def _run_impl(
if current_span:
current_span.finish(reset_current=True)

def _run_sync_impl(
def run_sync( # type: ignore[override]
self,
starting_agent: Agent[TContext],
input: str | list[TResponseInputItem],
*,
context: TContext | None = None,
max_turns: int = DEFAULT_MAX_TURNS,
hooks: RunHooks[TContext] | None = None,
run_config: RunConfig | None = None,
previous_response_id: str | None = None,
**kwargs: Unpack[AgentRunnerParams[TContext]],
) -> RunResult:
context = kwargs.get("context")
max_turns = kwargs.get("max_turns", DEFAULT_MAX_TURNS)
hooks = kwargs.get("hooks")
run_config = kwargs.get("run_config")
previous_response_id = kwargs.get("previous_response_id")
return asyncio.get_event_loop().run_until_complete(
self.run(
starting_agent,
Expand All @@ -537,16 +530,17 @@ def _run_sync_impl(
)
)

def _run_streamed_impl(
def run_streamed( # type: ignore[override]
self,
starting_agent: Agent[TContext],
input: str | list[TResponseInputItem],
context: TContext | None = None,
max_turns: int = DEFAULT_MAX_TURNS,
A3E2 hooks: RunHooks[TContext] | None = None,
run_config: RunConfig | None = None,
previous_response_id: str | None = None,
**kwargs: Unpack[AgentRunnerParams[TContext]],
) -> RunResultStreaming:
context = kwargs.get("context")
max_turns = kwargs.get("max_turns", DEFAULT_MAX_TURNS)
hooks = kwargs.get("hooks")
run_config = kwargs.get("run_config")
previous_response_id = kwargs.get("previous_response_id")
if hooks is None:
hooks = RunHooks[Any]()
if run_config is None:
Expand Down Expand Up @@ -1108,6 +1102,3 @@ async def _get_new_response(
context_wrapper.usage.add(new_response.usage)

return new_response


DEFAULT_RUNNER = DefaultRunner()
6 changes: 0 additions & 6 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from agents.models import _openai_shared
from agents.models.openai_chatcompletions import OpenAIChatCompletionsModel
from agents.models.openai_responses import OpenAIResponsesModel
from agents.run import set_default_runner
from agents.tracing import set_trace_processors
from agents.tracing.setup import get_trace_provider

Expand Down Expand Up @@ -34,11 +33,6 @@ def clear_openai_settings():
_openai_shared._use_responses_by_default = True


@pytest.fixture(autouse=True)
def clear_default_runner():
set_default_runner(None)


# This fixture will run after all tests end
@pytest.fixture(autouse=True, scope="session")
def shutdown_trace_provider():
Expand Down
26 changes: 0 additions & 26 deletions tests/test_run.py

This file was deleted.

0