Merge 2e8757c8e16de95151f7803b68752d3377bae56b into sapling-pr-archiv… · openai/openai-agents-python@6f4a62e · GitHub


Commit 6f4a62e

Merge 2e8757c into sapling-pr-archive-rm-openai
2 parents: b62a0f0 + 2e8757c · commit 6f4a62e

16 files changed: +165, -5 lines changed

src/agents/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -45,6 +45,7 @@
 from .models.openai_chatcompletions import OpenAIChatCompletionsModel
 from .models.openai_provider import OpenAIProvider
 from .models.openai_responses import OpenAIResponsesModel
+from .prompts import DynamicPrompt, GenerateDynamicPromptData, Prompt
 from .repl import run_demo_loop
 from .result import RunResult, RunResultStreaming
 from .run import RunConfig, Runner
@@ -178,6 +179,9 @@ def enable_verbose_stdout_logging():
     "AgentsException",
     "InputGuardrailTripwireTriggered",
     "OutputGuardrailTripwireTriggered",
+    "DynamicPrompt",
+    "GenerateDynamicPromptData",
+    "Prompt",
     "MaxTurnsExceeded",
     "ModelBehaviorError",
     "UserError",

src/agents/agent.py

Lines changed: 12 additions & 0 deletions
@@ -7,6 +7,7 @@
 from dataclasses import dataclass, field
 from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, cast
 
+from openai.types.responses.response_create_params import Prompt as ResponsesPrompt
 from typing_extensions import NotRequired, TypeAlias, TypedDict
 
 from .agent_output import AgentOutputSchemaBase
@@ -17,6 +18,7 @@
 from .mcp import MCPUtil
 from .model_settings import ModelSettings
 from .models.interface import Model
+from .prompts import DynamicPromptFunction, Prompt, PromptUtil
 from .run_context import RunContextWrapper, TContext
 from .tool import FunctionTool, FunctionToolResult, Tool, function_tool
 from .util import _transforms
@@ -95,6 +97,12 @@ class Agent(Generic[TContext]):
     return a string.
     """
 
+    prompt: Prompt | DynamicPromptFunction | None = None
+    """A prompt object (or a function that returns a Prompt). Prompts allow you to dynamically
+    configure the instructions, tools and other config for an agent outside of your code. Only
+    usable with OpenAI models, using the Responses API.
+    """
+
     handoff_description: str | None = None
     """A description of the agent. This is used when the agent is used as a handoff, so that an
     LLM knows what it does and when to invoke it.
@@ -242,6 +250,10 @@ async def get_system_prompt(self, run_context: RunContextWrapper[TContext]) -> s
 
         return None
 
+    async def get_prompt(self, run_context: RunContextWrapper[TContext]) -> ResponsesPrompt | None:
+        """Get the prompt for the agent."""
+        return await PromptUtil.to_model_input(self.prompt, run_context, self)
+
     async def get_mcp_tools(self) -> list[Tool]:
         """Fetches the available tools from the MCP servers."""
         convert_schemas_to_strict = self.mcp_config.get("convert_schemas_to_strict", False)
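
A sketch of how the new prompt field and get_prompt hook fit together, using a synchronous dynamic prompt function; the prompt ID and variable names are placeholders, not part of the commit:

from agents import Agent, GenerateDynamicPromptData, Prompt

def pick_prompt(data: GenerateDynamicPromptData) -> Prompt:
    # data.context is the RunContextWrapper and data.agent is the agent being run.
    tier = getattr(data.context.context, "tier", "free")
    return {"id": "pmpt_123", "variables": {"tier": tier}}

agent = Agent(name="Support agent", prompt=pick_prompt)
# At run time, agent.get_prompt(run_context) resolves the function via
# PromptUtil.to_model_input, and the result reaches the model layer through the
# new `prompt` parameter added to the model interfaces below.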

src/agents/extensions/models/litellm_model.py

Lines changed: 7 additions & 1 deletion
@@ -71,6 +71,7 @@ async def get_response(
         handoffs: list[Handoff],
         tracing: ModelTracing,
         previous_response_id: str | None,
+        prompt: Any | None = None,
     ) -> ModelResponse:
         with generation_span(
             model=str(self.model),
@@ -88,6 +89,7 @@
                 span_generation,
                 tracing,
                 stream=False,
+                prompt=prompt,
             )
 
             assert isinstance(response.choices[0], litellm.types.utils.Choices)
@@ -153,8 +155,8 @@ async def stream_response(
         output_schema: AgentOutputSchemaBase | None,
         handoffs: list[Handoff],
         tracing: ModelTracing,
-        *,
         previous_response_id: str | None,
+        prompt: Any | None = None,
     ) -> AsyncIterator[TResponseStreamEvent]:
         with generation_span(
             model=str(self.model),
@@ -172,6 +174,7 @@
                 span_generation,
                 tracing,
                 stream=True,
+                prompt=prompt,
             )
 
             final_response: Response | None = None
@@ -202,6 +205,7 @@ async def _fetch_response(
         span: Span[GenerationSpanData],
         tracing: ModelTracing,
         stream: Literal[True],
+        prompt: Any | None = None,
     ) -> tuple[Response, AsyncStream[ChatCompletionChunk]]: ...
 
     @overload
@@ -216,6 +220,7 @@
         span: Span[GenerationSpanData],
         tracing: ModelTracing,
         stream: Literal[False],
+        prompt: Any | None = None,
     ) -> litellm.types.utils.ModelResponse: ...
 
     async def _fetch_response(
@@ -229,6 +234,7 @@
         span: Span[GenerationSpanData],
         tracing: ModelTracing,
         stream: bool = False,
+        prompt: Any | None = None,
     ) -> litellm.types.utils.ModelResponse | tuple[Response, AsyncStream[ChatCompletionChunk]]:
         converted_messages = Converter.items_to_messages(input)
 

src/agents/function_schema.py

Lines changed: 2 additions & 1 deletion
@@ -223,7 +223,8 @@ def function_schema(
         doc_info = None
         param_descs = {}
 
-    func_name = name_override or doc_info.name if doc_info else func.__name__
+    # Ensure name_override takes precedence even if docstring info is disabled.
+    func_name = name_override or (doc_info.name if doc_info else func.__name__)
 
     # 2. Inspect function signature and get type hints
     sig = inspect.signature(func)
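
The parentheses matter because Python's conditional expression binds more loosely than `or`: the old line grouped as `(name_override or doc_info.name) if doc_info else func.__name__`, which silently dropped the override whenever docstring parsing was disabled. A standalone illustration with made-up values:

def func():
    pass

name_override = "my_tool"
doc_info = None  # e.g. function_schema(..., use_docstring_info=False)

old_result = (name_override or doc_info.name) if doc_info else func.__name__
new_result = name_override or (doc_info.name if doc_info else func.__name__)
print(old_result)  # "func"    -- the override is ignored
print(new_result)  # "my_tool" -- the override wins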

src/agents/models/interface.py

Lines changed: 6 additions & 0 deletions
@@ -5,6 +5,8 @@
 from collections.abc import AsyncIterator
 from typing import TYPE_CHECKING
 
+from openai.types.responses.response_create_params import Prompt as ResponsesPrompt
+
 from ..agent_output import AgentOutputSchemaBase
 from ..handoffs import Handoff
 from ..items import ModelResponse, TResponseInputItem, TResponseStreamEvent
@@ -46,6 +48,7 @@ async def get_response(
         tracing: ModelTracing,
         *,
         previous_response_id: str | None,
+        prompt: ResponsesPrompt | None,
     ) -> ModelResponse:
         """Get a response from the model.
 
@@ -59,6 +62,7 @@
             tracing: Tracing configuration.
             previous_response_id: the ID of the previous response. Generally not used by the model,
                 except for the OpenAI Responses API.
+            prompt: The prompt config to use for the model.
 
         Returns:
             The full model response.
@@ -77,6 +81,7 @@ def stream_response(
         tracing: ModelTracing,
         *,
         previous_response_id: str | None,
+        prompt: ResponsesPrompt | None,
     ) -> AsyncIterator[TResponseStreamEvent]:
         """Stream a response from the model.
 
@@ -90,6 +95,7 @@
             tracing: Tracing configuration.
             previous_response_id: the ID of the previous response. Generally not used by the model,
                 except for the OpenAI Responses API.
+            prompt: The prompt config to use for the model.
 
         Returns:
             An iterator of response stream events, in OpenAI Responses format.
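
Because these abstract signatures gained a new keyword-only parameter, every Model implementation has to accept it. A minimal, hypothetical subclass sketch of the updated get_response signature; the class name and pass-through behaviour are illustrative, type annotations and stream_response are omitted for brevity:

from agents.items import ModelResponse
from agents.models.interface import Model


class MyBackendModel(Model):
    async def get_response(
        self,
        system_instructions,
        input,
        model_settings,
        tools,
        output_schema,
        handoffs,
        tracing,
        *,
        previous_response_id,
        prompt,  # new in this commit; a backend without prompt-template support can ignore it
    ) -> ModelResponse:
        raise NotImplementedError("sketch only")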

src/agents/models/openai_chatcompletions.py

Lines changed: 8 additions & 1 deletion
@@ -9,6 +9,7 @@
 from openai.types import ChatModel
 from openai.types.chat import ChatCompletion, ChatCompletionChunk
 from openai.types.responses import Response
+from openai.types.responses.response_create_params import Prompt as ResponsesPrompt
 from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails
 
 from .. import _debug
@@ -53,6 +54,7 @@ async def get_response(
         handoffs: list[Handoff],
         tracing: ModelTracing,
         previous_response_id: str | None,
+        prompt: ResponsesPrompt | None = None,
     ) -> ModelResponse:
         with generation_span(
             model=str(self.model),
@@ -69,6 +71,7 @@
                 span_generation,
                 tracing,
                 stream=False,
+                prompt=prompt,
             )
 
             first_choice = response.choices[0]
@@ -136,8 +139,8 @@ async def stream_response(
         output_schema: AgentOutputSchemaBase | None,
         handoffs: list[Handoff],
         tracing: ModelTracing,
-        *,
         previous_response_id: str | None,
+        prompt: ResponsesPrompt | None = None,
     ) -> AsyncIterator[TResponseStreamEvent]:
         """
         Yields a partial message as it is generated, as well as the usage information.
@@ -157,6 +160,7 @@
                 span_generation,
                 tracing,
                 stream=True,
+                prompt=prompt,
             )
 
             final_response: Response | None = None
@@ -187,6 +191,7 @@ async def _fetch_response(
         span: Span[GenerationSpanData],
         tracing: ModelTracing,
         stream: Literal[True],
+        prompt: ResponsesPrompt | None = None,
     ) -> tuple[Response, AsyncStream[ChatCompletionChunk]]: ...
 
     @overload
@@ -201,6 +206,7 @@
         span: Span[GenerationSpanData],
         tracing: ModelTracing,
         stream: Literal[False],
+        prompt: ResponsesPrompt | None = None,
     ) -> ChatCompletion: ...
 
     async def _fetch_response(
@@ -214,6 +220,7 @@
         span: Span[GenerationSpanData],
         tracing: ModelTracing,
         stream: bool = False,
+        prompt: ResponsesPrompt | None = None,
     ) -> ChatCompletion | tuple[Response, AsyncStream[ChatCompletionChunk]]:
         converted_messages = Converter.items_to_messages(input)
 

src/agents/models/openai_responses.py

Lines changed: 7 additions & 0 deletions
@@ -74,6 +74,7 @@ async def get_response(
         handoffs: list[Handoff],
         tracing: ModelTracing,
         previous_response_id: str | None,
+        prompt: ResponsesPrompt | None = None,
     ) -> ModelResponse:
         with response_span(disabled=tracing.is_disabled()) as span_response:
             try:
@@ -86,6 +87,7 @@
                     handoffs,
                     previous_response_id,
                     stream=False,
+                    prompt=prompt,
                 )
 
                 if _debug.DONT_LOG_MODEL_DATA:
@@ -141,6 +143,7 @@ async def stream_response(
         handoffs: list[Handoff],
         tracing: ModelTracing,
         previous_response_id: str | None,
+        prompt: ResponsesPrompt | None = None,
     ) -> AsyncIterator[ResponseStreamEvent]:
         """
         Yields a partial message as it is generated, as well as the usage information.
@@ -156,6 +159,7 @@
                 handoffs,
                 previous_response_id,
                 stream=True,
+                prompt=prompt,
             )
 
             final_response: Response | None = None
@@ -192,6 +196,7 @@ async def _fetch_response(
         handoffs: list[Handoff],
         previous_response_id: str | None,
         stream: Literal[True],
+        prompt: ResponsesPrompt | None = None,
     ) -> AsyncStream[ResponseStreamEvent]: ...
 
     @overload
@@ -205,6 +210,7 @@
         handoffs: list[Handoff],
         previous_response_id: str | None,
         stream: Literal[False],
+        prompt: ResponsesPrompt | None = None,
     ) -> Response: ...
 
     async def _fetch_response(
@@ -217,6 +223,7 @@
         handoffs: list[Handoff],
         previous_response_id: str | None,
         stream: Literal[True] | Literal[False] = False,
+        prompt: ResponsesPrompt | None = None,
     ) -> Response | AsyncStream[ResponseStreamEvent]:
         list_input = ItemHelpers.input_to_new_input_list(input)
 
src/agents/prompts.py

Lines changed: 76 additions & 0 deletions
@@ -0,0 +1,76 @@
+from __future__ import annotations
+
+import inspect
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any, Callable
+
+from openai.types.responses.response_create_params import (
+    Prompt as ResponsesPrompt,
+    PromptVariables as ResponsesPromptVariables,
+)
+from typing_extensions import NotRequired, TypedDict
+
+from agents.util._types import MaybeAwaitable
+
+from .exceptions import UserError
+from .run_context import RunContextWrapper
+
+if TYPE_CHECKING:
+    from .agent import Agent
+
+
+class Prompt(TypedDict):
+    """Prompt configuration to use for interacting with an OpenAI model."""
+
+    id: str
+    """The unique ID of the prompt."""
+
+    version: NotRequired[str]
+    """Optional version of the prompt."""
+
+    variables: NotRequired[dict[str, ResponsesPromptVariables]]
+    """Optional variables to substitute into the prompt."""
+
+
+@dataclass
+class GenerateDynamicPromptData:
+    """Inputs to a function that allows you to dynamically generate a prompt."""
+
+    context: RunContextWrapper[Any]
+    """The run context."""
+
+    agent: Agent[Any]
+    """The agent for which the prompt is being generated."""
+
+
+DynamicPromptFunction = Callable[[GenerateDynamicPromptData], MaybeAwaitable[Prompt]]
+"""A function that dynamically generates a prompt."""
+
+
+class PromptUtil:
+    @staticmethod
+    async def to_model_input(
+        prompt: Prompt | DynamicPromptFunction | None,
+        context: RunContextWrapper[Any],
+        agent: Agent[Any],
+    ) -> ResponsesPrompt | None:
+        if prompt is None:
+            return None
+
+        resolved_prompt: Prompt
+        if isinstance(prompt, Prompt):
+            resolved_prompt = prompt
+        else:
+            func_result = prompt(GenerateDynamicPromptData(context=context, agent=agent))
+            if inspect.isawaitable(func_result):
+                resolved_prompt = await func_result
+            else:
+                resolved_prompt = func_result
+            if not isinstance(resolved_prompt, Prompt):
+                raise UserError("Dynamic prompt function must return a Prompt")
+
+        return {
+            "id": resolved_prompt["id"],
+            "version": resolved_prompt.get("version"),
+            "variables": resolved_prompt.get("variables"),
+        }
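
Since DynamicPromptFunction is typed with MaybeAwaitable, the dynamic prompt function may be synchronous or asynchronous; to_model_input awaits the result only when needed. A small sketch of the async variant attached to an agent (prompt ID and variables are placeholders):

import asyncio

from agents import Agent, GenerateDynamicPromptData, Prompt


async def fetch_prompt(data: GenerateDynamicPromptData) -> Prompt:
    # Pretend the prompt version is looked up from some async source.
    await asyncio.sleep(0)
    return {"id": "pmpt_123", "version": "3", "variables": {"agent_name": data.agent.name}}


agent = Agent(name="Concierge", prompt=fetch_prompt)
# agent.get_prompt(run_context) awaits fetch_prompt and returns the
# {"id", "version", "variables"} dict that the Responses model consumes.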
