8000 iterative streaming (#241) · 0xLiam-bit/sdk-python@d8ce2d5 · GitHub
[go: up one dir, main page]

Skip to content

Commit d8ce2d5

Browse files
authored
iterative streaming (strands-agents#241)
1 parent e693738 commit d8ce2d5

File tree

7 files changed

+427
-166
lines changed

7 files changed

+427
-166
lines changed

src/strands/event_loop/event_loop.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -130,14 +130,19 @@ def event_loop_cycle(
130130
)
131131

132132
try:
133-
stop_reason, message, usage, metrics, kwargs["request_state"] = stream_messages(
134-
model,
135-
system_prompt,
136-
messages,
137-
tool_config,
138-
callback_handler,
139-
**kwargs,
140-
)
133+
# TODO: As part of the migration to async-iterator, we will continue moving callback_handler calls up the
134+
# call stack. At this point, we converted all events that were previously passed to the handler in
135+
# `stream_messages` into yielded events that now have the "callback" key. To maintain backwards
136+
# compatibility, we need to combine the event with kwargs before passing to the handler. This we will
137+
# revisit when migrating to strongly typed events.
138+
for event in stream_messages(model, system_prompt, messages, tool_config):
139+
if "callback" in event:
140+
inputs = {**event["callback"], **(kwargs if "delta" in event["callback"] else {})}
141+
callback_handler(**inputs)
142+
else:
143+
stop_reason, message, usage, metrics = event["stop"]
144+
kwargs.setdefault("request_state", {})
145+
141146
if model_invoke_span:
142147
tracer.end_model_invoke_span(model_invoke_span, message, usage)
143148
break # Success! Break out of retry loop
@@ -334,7 +339,7 @@ def _handle_tool_execution(
334339
kwargs (Dict[str, Any]): Additional keyword arguments, including request state.
335340
336341
Returns:
337-
Tuple[StopReason, Message, EventLoopMetrics, Dict[str, Any]]:
342+
Tuple[StopReason, Message, EventLoopMetrics, Dict[str, Any]]:
338343
- The stop reason,
339344
- The updated message,
340345
- The updated event loop metrics,

src/strands/event_loop/streaming.py

Lines changed: 33 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import json
44
import logging
5-
from typing import Any, Dict, Iterable, List, Optional, Tuple
5+
from typing import Any, Generator, Iterable, Optional
66

77
from ..types.content import ContentBlock, Message, Messages
88
from ..types.models import Model
@@ -80,7 +80,7 @@ def handle_message_start(event: MessageStartEvent, message: Message) -> Message:
8080
return message
8181

8282

83-
def handle_content_block_start(event: ContentBlockStartEvent) -> Dict[str, Any]:
83+
def handle_content_block_start(event: ContentBlockStartEvent) -> dict[str, Any]:
8484
"""Handles the start of a content block by extracting tool usage information if any.
8585
8686
Args:
@@ -102,61 +102,59 @@ def handle_content_block_start(event: ContentBlockStartEvent) -> Dict[str, Any]:
102102

103103

104104
def handle_content_block_delta(
105-
event: ContentBlockDeltaEvent, state: Dict[str, Any], callback_handler: Any, **kwargs: Any
106-
) -> Dict[str, Any]:
105+
event: ContentBlockDeltaEvent, state: dict[str, Any]
106+
) -> tuple[dict[str, Any], dict[str, Any]]:
107107
"""Handles content block delta updates by appending text, tool input, or reasoning content to the state.
108108
109109
Args:
110110
event: Delta event.
111111
state: The current state of message processing.
112-
callback_handler: Callback for processing events as they happen.
113-
**kwargs: Additional keyword arguments to pass to the callback handler.
114112
115113
Returns:
116114
Updated state with appended text or tool input.
117115
"""
118116
delta_content = event["delta"]
119117

118+
callback_event = {}
119+
120120
if "toolUse" in delta_content:
121121
if "input" not in state["current_tool_use"]:
122122
state["current_tool_use"]["input"] = ""
123123

124124
state["current_tool_use"]["input"] += delta_content["toolUse"]["input"]
125-
callback_handler(delta=delta_content, current_tool_use=state["current_tool_use"], **kwargs)
125+
callback_event["callback"] = {"delta": delta_content, "current_tool_use": state["current_tool_use"]}
126126

127127
elif "text" in delta_content:
128128
state["text"] += delta_content["text"]
129-
callback_handler(data=delta_content["text"], delta=delta_content, **kwargs)
129+
callback_event["callback"] = {"data": delta_content["text"], "delta": delta_content}
130130

131131
elif "reasoningContent" in delta_content:
132132
if "text" in delta_content["reasoningContent"]:
133133
if "reasoningText" not in state:
134134
state["reasoningText"] = ""
135135

136136
state["reasoningText"] += delta_content["reasoningContent"]["text"]
137-
callback_handler(
138-
reasoningText=delta_content["reasoningContent"]["text"],
139-
delta=delta_content,
140-
reasoning=True,
141-
**kwargs,
142-
)
137+
callback_event["callback"] = {
138+
"reasoningText": delta_content["reasoningContent"]["text"],
139+
"delta": delta_content,
140+
"reasoning": True,
141+
}
143142

144143
elif "signature" in delta_content["reasoningContent"]:
145144
if "signature" not in state:
146145
state["signature"] = ""
147146

148147
state["signature"] += delta_content["reasoningContent"]["signature"]
149-
callback_handler(
150-
reasoning_signature=delta_content["reasoningContent"]["signature"],
151-
delta=delta_content,
152-
reasoning=True,
153-
**kwargs,
154-
)
148+
callback_event["callback"] = {
149+
"reasoning_signature": delta_content["reasoningContent"]["signature"],
150+
"delta": delta_content,
151+
"reasoning": True,
152+
}
155153

156-
return state
154+
return state, callback_event
157155

158156

159-
def handle_content_block_stop(state: Dict[str, Any]) -> Dict[str, Any]:
157+
def handle_content_block_stop(state: dict[str, Any]) -> dict[str, Any]:
160158
"""Handles the end of a content block by finalizing tool usage, text content, or reasoning content.
161159
162160
Args:
@@ -165,7 +163,7 @@ def handle_content_block_stop(state: Dict[str, Any]) -> Dict[str, Any]:
165163
Returns:
166164
Updated state with finalized content block.
167165
"""
168-
content: List[ContentBlock] = state["content"]
166+
content: list[ContentBlock] = state["content"]
169167

170168
current_tool_use = state["current_tool_use"]
171169
text = state["text"]
@@ -223,7 +221,7 @@ def handle_message_stop(event: MessageStopEvent) -> StopReason:
223221
return event["stopReason"]
224222

225223

226-
def handle_redact_content(event: RedactContentEvent, messages: Messages, state: Dict[str, Any]) -> None:
224+
def handle_redact_content(event: RedactContentEvent, messages: Messages, state: dict[str, Any]) -> None:
227225
"""Handles redacting content from the input or output.
228226
229227
Args:
@@ -238,7 +236,7 @@ def handle_redact_content(event: RedactContentEvent, messages: Messages, state:
238236
state["message"]["content"] = [{"text": event["redactAssistantContentMessage"]}]
239237

240238

241-
def extract_usage_metrics(event: MetadataEvent) -> Tuple[Usage, Metrics]:
239+
def extract_usage_metrics(event: MetadataEvent) -> tuple[Usage, Metrics]:
242240
"""Extracts usage metrics from the metadata chunk.
243241
244242
Args:
@@ -255,25 +253,20 @@ def extract_usage_metrics(event: MetadataEvent) -> Tuple[Usage, Metrics]:
255253

256254
def process_stream(
257255
chunks: Iterable[StreamEvent],
258-
callback_handler: Any,
259256
messages: Messages,
260-
**kwargs: Any,
261-
) -> Tuple[StopReason, Message, Usage, Metrics, Any]:
257+
) -> Generator[dict[str, Any], None, None]:
262258
"""Processes the response stream from the API, constructing the final message and extracting usage metrics.
263259
264260
Args:
265261
chunks: The chunks of the response stream from the model.
266-
callback_handler: Callback for processing events as they happen.
267262
messages: The agent's messages.
268-
**kwargs: Additional keyword arguments that will be passed to the callback handler.
269-
And also returned in the request_state.
270263
271264
Returns:
272-
The reason for stopping, the constructed message, the usage metrics, and the updated request state.
265+
The reason for stopping, the constructed message, and the usage metrics.
273266
"""
274267
stop_reason: StopReason = "end_turn"
275268

276-
state: Dict[str, Any] = {
269+
state: dict[str, Any] = {
277270
"message": {"role": "assistant", "content": []},
278271
"text": "",
279272
"current_tool_use": {},
@@ -285,18 +278,16 @@ def process_stream(
285278
usage: Usage = Usage(inputTokens=0, outputTokens=0, totalTokens=0)
286279
metrics: Metrics = Metrics(latencyMs=0)
287280

288-
kwargs.setdefault("request_state", {})
289-
290281
for chunk in chunks:
291-
# Callback handler call here allows each event to be visible to the caller
292-
callback_handler(event=chunk)
282+
yield {"callback": {"event": chunk}}
293283

294284
if "messageStart" in chunk:
295285
state["message"] = handle_message_start(chunk["messageStart"], state["message"])
296286
elif "contentBlockStart" in chunk:
297287
state["current_tool_use"] = handle_content_block_start(chunk["contentBlockStart"])
298288
elif "contentBlockDelta" in chunk:
299-
state = handle_content_block_delta(chunk["contentBlockDelta"], state, callback_handler, **kwargs)
289+
state, callback_event = handle_content_block_delta(chunk["contentBlockDelta"], state)
290+
yield callback_event
300291
elif "contentBlockStop" in chunk:
301292
state = handle_content_block_stop(state)
302293
elif "messageStop" in chunk:
@@ -306,35 +297,30 @@ def process_stream(
306297
elif "redactContent" in chunk:
307298
handle_redact_content(chunk["redactContent"], messages, state)
308299

309-
return stop_reason, state["message"], usage, metrics, kwargs["request_state"]
300+
yield {"stop": (stop_reason, state["message"], usage, metrics)}
310301

311302

312303
def stream_messages(
313304
model: Model,
314305
system_prompt: Optional[str],
315306
messages: Messages,
316307
tool_config: Optional[ToolConfig],
317-
callback_handler: Any,
318-
**kwargs: Any,
319-
) -> Tuple[StopReason, Message, Usage, Metrics, Any]:
308+
) -> Generator[dict[str, Any], None, None]:
320309
"""Streams messages to the model and processes the response.
321310
322311
Args:
323312
model: Model provider.
324313
system_prompt: The system prompt to send.
325314
messages: List of messages to send.
326315
tool_config: Configuration for the tools to use.
327-
callback_handler: Callback for processing events as they happen.
328-
**kwargs: Additional keyword arguments that will be passed to the callback handler.
329-
And also returned in the request_state.
330316
331317
Returns:
332-
The reason for stopping, the final message, the usage metrics, and updated request state.
318+
The reason for stopping, the final message, and the usage metrics.
333319
"""
334320
logger.debug("model=<%s> | streaming messages", model)
335321

336322
messages = remove_blank_messages_content_text(messages)
337323
tool_specs = [tool["toolSpec"] for tool in tool_config.get("tools", [])] or None if tool_config else None
338324

339325
chunks = model.converse(messages, tool_specs, system_prompt)
340-
return process_stream(chunks, callback_handler, messages, **kwargs)
326+
yield from process_stream(chunks, messages)

src/strands/models/anthropic.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -387,15 +387,15 @@ def structured_output(
387387
prompt(Messages): The prompt messages to use for the agent.
388388
callback_handler(Optional[Callable]): Optional callback handler for processing events. Defaults to None.
389389
"""
390+
callback_handler = callback_handler or PrintingCallbackHandler()
390391
tool_spec = convert_pydantic_to_tool_spec(output_model)
391392

392393
response = self.converse(messages=prompt, tool_specs=[tool_spec])
393-
# process the stream and get the tool use input
394-
results = process_stream(
395-
response, callback_handler=callback_handler or PrintingCallbackHandler(), messages=prompt
396-
)
397-
398-
stop_reason, messages, _, _, _ = results
394+
for event in process_stream(response, prompt):
395+
if "callback" in event:
396+
callback_handler(**event["callback"])
397+
else:
398+
stop_reason, messages, _, _ = event["stop"]
399399

400400
if stop_reason != "tool_use":
401401
raise ValueError("No valid tool use or tool use input was found in the Anthropic response.")

src/strands/models/bedrock.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -504,15 +504,15 @@ def structured_output(
504504
prompt(Messages): The prompt messages to use for the agent.
505505
callback_handler(Optional[Callable]): Optional callback handler for processing events. Defaults to None.
506506
"""
507+
callback_handler = callback_handler or PrintingCallbackHandler()
507508
tool_spec = convert_pydantic_to_tool_spec(output_model)
508509

509510
response = self.converse(messages=prompt, tool_specs=[tool_spec])
510-
# process the stream and get the tool use input
511-
results = process_stream(
512-
response, callback_handler=callback_handler or PrintingCallbackHandler(), messages=prompt
513-
)
514-
515-
stop_reason, messages, _, _, _ = results
511+
for event in process_stream(response, prompt):
512+
if "callback" in event:
513+
callback_handler(**event["callback"])
514+
else:
515+
stop_reason, messages, _, _ = event["stop"]
516516

517517
if stop_reason != "tool_use":
518518
raise ValueError("No valid tool use or tool use input was found in the Bedrock response.")

tests-integ/test_model_anthropic.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def tool_weather() -> str:
3434

3535
@pytest.fixture
3636
def system_prompt():
37-
return "You are an AI assistant that uses & instead of ."
37+
return "You are an AI assistant."
3838

3939

4040
@pytest.fixture
@@ -47,7 +47,7 @@ def test_agent(agent):
4747
result = agent("What is the time and weather in New York?")
4848
text = result.message["content"][0]["text"].lower()
4949

50-
assert all(string in text for string in ["12:00", "sunny", "&"])
50+
assert all(string in text for string in ["12:00", "sunny"])
5151

5252

5353
@pytest.mark.skipif("ANTHROPIC_API_KEY" not in os.environ, reason="ANTHROPIC_API_KEY environment variable missing")

0 commit comments

Comments
 (0)
0