feat: Add non-streaming support to BedrockModel (#75) · jer96/sdk-python@3100ea0 · GitHub


Commit 3100ea0

feat: Add non-streaming support to BedrockModel (strands-agents#75)
* feat: Add non-streaming support to BedrockModel
* fix: Add more test coverage
* fix: Update with pr comments
1 parent c3895d4 commit 3100ea0

File tree

5 files changed, +627 -54 lines changed

README.md

Lines changed: 1 addition & 0 deletions
@@ -123,6 +123,7 @@ from strands.models.llamaapi import LlamaAPIModel
 bedrock_model = BedrockModel(
     model_id="us.amazon.nova-pro-v1:0",
     temperature=0.3,
+    streaming=True, # Enable/disable streaming
 )
 agent = Agent(model=bedrock_model)
 agent("Tell me about Agentic AI")

src/strands/models/bedrock.py

Lines changed: 178 additions & 49 deletions
@@ -3,13 +3,14 @@
 - Docs: https://aws.amazon.com/bedrock/
 """
 
+import json
 import logging
 import os
-from typing import Any, Iterable, Literal, Optional, cast
+from typing import Any, Iterable, List, Literal, Optional, cast
 
 import boto3
 from botocore.config import Config as BotocoreConfig
-from botocore.exceptions import ClientError, EventStreamError
+from botocore.exceptions import ClientError
 from typing_extensions import TypedDict, Unpack, override
 
 from ..types.content import Messages
@@ -61,6 +62,7 @@ class BedrockConfig(TypedDict, total=False):
             max_tokens: Maximum number of tokens to generate in the response
             model_id: The Bedrock model ID (e.g., "us.anthropic.claude-3-7-sonnet-20250219-v1:0")
             stop_sequences: List of sequences that will stop generation when encountered
+            streaming: Flag to enable/disable streaming. Defaults to True.
             temperature: Controls randomness in generation (higher = more random)
             top_p: Controls diversity via nucleus sampling (alternative to temperature)
         """
@@ -81,6 +83,7 @@ class BedrockConfig(TypedDict, total=False):
         max_tokens: Optional[int]
         model_id: str
         stop_sequences: Optional[list[str]]
+        streaming: Optional[bool]
         temperature: Optional[float]
         top_p: Optional[float]
 
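Because BedrockConfig is declared with total=False, the new streaming key is optional; stream() (further down) reads it with a default of True, so existing configurations keep their streaming behavior. A small sketch of how the flag resolves:

```python
# Sketch of the default resolution stream() uses: an absent flag means streaming.
config = {"model_id": "us.amazon.nova-pro-v1:0"}
config.get("streaming", True)   # True: falls back to the streaming path

config["streaming"] = False
config.get("streaming", True)   # False: takes the non-streaming converse path
```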
@@ -246,11 +249,68 @@ def format_chunk(self, event: dict[str, Any]) -> StreamEvent:
         """
         return cast(StreamEvent, event)
 
+    def _has_blocked_guardrail(self, guardrail_data: dict[str, Any]) -> bool:
+        """Check if guardrail data contains any blocked policies.
+
+        Args:
+            guardrail_data: Guardrail data from trace information.
+
+        Returns:
+            True if any blocked guardrail is detected, False otherwise.
+        """
+        input_assessment = guardrail_data.get("inputAssessment", {})
+        output_assessments = guardrail_data.get("outputAssessments", {})
+
+        # Check input assessments
+        if any(self._find_detected_and_blocked_policy(assessment) for assessment in input_assessment.values()):
+            return True
+
+        # Check output assessments
+        if any(self._find_detected_and_blocked_policy(assessment) for assessment in output_assessments.values()):
+            return True
+
+        return False
+
+    def _generate_redaction_events(self) -> list[StreamEvent]:
+        """Generate redaction events based on configuration.
+
+        Returns:
+            List of redaction events to yield.
+        """
+        events: List[StreamEvent] = []
+
+        if self.config.get("guardrail_redact_input", True):
+            logger.debug("Redacting user input due to guardrail.")
+            events.append(
+                {
+                    "redactContent": {
+                        "redactUserContentMessage": self.config.get(
+                            "guardrail_redact_input_message", "[User input redacted.]"
+                        )
+                    }
+                }
+            )
+
+        if self.config.get("guardrail_redact_output", False):
+            logger.debug("Redacting assistant output due to guardrail.")
+            events.append(
+                {
+                    "redactContent": {
+                        "redactAssistantContentMessage": self.config.get(
+                            "guardrail_redact_output_message", "[Assistant output redacted.]"
+                        )
+                    }
+                }
+            )
+
+        return events
+
     @override
-    def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]:
-        """Send the request to the Bedrock model and get the streaming response.
+    def stream(self, request: dict[str, Any]) -> Iterable[StreamEvent]:
+        """Send the request to the Bedrock model and get the response.
 
-        This method calls the Bedrock converse_stream API and returns the stream of response events.
+        This method calls either the Bedrock converse_stream API or the converse API
+        based on the streaming parameter in the configuration.
 
         Args:
             request: The formatted request to send to the Bedrock model
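The two helpers above lift guardrail handling out of stream(): _has_blocked_guardrail inspects the trace assessments, and _generate_redaction_events builds the redaction events that both code paths yield. For illustration, the events produced under the default configuration versus a fully customized one (the custom message strings below are hypothetical values, not defaults from this commit):

```python
# Events from _generate_redaction_events() with the default config
# (guardrail_redact_input=True, guardrail_redact_output=False):
[
    {"redactContent": {"redactUserContentMessage": "[User input redacted.]"}},
]

# Events with both redaction flags enabled and custom messages configured
# (guardrail_redact_output=True plus *_message overrides; values are hypothetical):
[
    {"redactContent": {"redactUserContentMessage": "Input blocked by policy."}},
    {"redactContent": {"redactAssistantContentMessage": "Output blocked by policy."}},
]
```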
@@ -260,63 +320,132 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]:
 
         Raises:
             ContextWindowOverflowException: If the input exceeds the model's context window.
-            EventStreamError: For all other Bedrock API errors.
+            ModelThrottledException: If the model service is throttling requests.
         """
+        streaming = self.config.get("streaming", True)
+
         try:
-            response = self.client.converse_stream(**request)
-            for chunk in response["stream"]:
-                if self.config.get("guardrail_redact_input", True) or self.config.get("guardrail_redact_output", False):
+            if streaming:
+                # Streaming implementation
+                response = self.client.converse_stream(**request)
+                for chunk in response["stream"]:
                     if (
                         "metadata" in chunk
                         and "trace" in chunk["metadata"]
                         and "guardrail" in chunk["metadata"]["trace"]
                     ):
-                        inputAssessment = chunk["metadata"]["trace"]["guardrail"].get("inputAssessment", {})
-                        outputAssessments = chunk["metadata"]["trace"]["guardrail"].get("outputAssessments", {})
-
-                        # Check if an input or output guardrail was triggered
-                        if any(
-                            self._find_detected_and_blocked_policy(assessment)
-                            for assessment in inputAssessment.values()
-                        ) or any(
-                            self._find_detected_and_blocked_policy(assessment)
-                            for assessment in outputAssessments.values()
-                        ):
-                            if self.config.get("guardrail_redact_input", True):
-                                logger.debug("Found blocked input guardrail. Redacting input.")
-                                yield {
-                                    "redactContent": {
-                                        "redactUserContentMessage": self.config.get(
-                                            "guardrail_redact_input_message", "[User input redacted.]"
-                                        )
-                                    }
-                                }
-                            if self.config.get("guardrail_redact_output", False):
-                                logger.debug("Found blocked output guardrail. Redacting output.")
-                                yield {
-                                    "redactContent": {
-                                        "redactAssistantContentMessage": self.config.get(
-                                            "guardrail_redact_output_message", "[Assistant output redacted.]"
-                                        )
-                                    }
-                                }
+                        guardrail_data = chunk["metadata"]["trace"]["guardrail"]
+                        if self._has_blocked_guardrail(guardrail_data):
+                            yield from self._generate_redaction_events()
+                    yield chunk
+            else:
+                # Non-streaming implementation
+                response = self.client.converse(**request)
+
+                # Convert and yield from the response
+                yield from self._convert_non_streaming_to_streaming(response)
 
-                yield chunk
-        except EventStreamError as e:
-            # Handle throttling that occurs mid-stream?
-            if "ThrottlingException" in str(e) and "ConverseStream" in str(e):
-                raise ModelThrottledException(str(e)) from e
+                # Check for guardrail triggers after yielding any events (same as streaming path)
+                if (
+                    "trace" in response
+                    and "guardrail" in response["trace"]
+                    and self._has_blocked_guardrail(response["trace"]["guardrail"])
+                ):
+                    yield from self._generate_redaction_events()
 
-            if any(overflow_message in str(e) for overflow_message in BEDROCK_CONTEXT_WINDOW_OVERFLOW_MESSAGES):
+        except ClientError as e:
+            error_message = str(e)
+
+            # Handle throttling error
+            if e.response["Error"]["Code"] == "ThrottlingException":
+                raise ModelThrottledException(error_message) from e
+
+            # Handle context window overflow
+            if any(overflow_message in error_message for overflow_message in BEDROCK_CONTEXT_WINDOW_OVERFLOW_MESSAGES):
                 logger.warning("bedrock threw context window overflow error")
                 raise ContextWindowOverflowException(e) from e
+
+            # Otherwise raise the error
             raise e
-        except ClientError as e:
-            # Handle throttling that occurs at the beginning of the call
-            if e.response["Error"]["Code"] == "ThrottlingException":
-                raise ModelThrottledException(str(e)) from e
 
-            raise
+    def _convert_non_streaming_to_streaming(self, response: dict[str, Any]) -> Iterable[StreamEvent]:
+        """Convert a non-streaming response to the streaming format.
+
+        Args:
+            response: The non-streaming response from the Bedrock model.
+
+        Returns:
+            An iterable of response events in the streaming format.
+        """
+        # Yield messageStart event
+        yield {"messageStart": {"role": response["output"]["message"]["role"]}}
+
+        # Process content blocks
+        for content in response["output"]["message"]["content"]:
+            # Yield contentBlockStart event if needed
+            if "toolUse" in content:
+                yield {
+                    "contentBlockStart": {
+                        "start": {
+                            "toolUse": {
+                                "toolUseId": content["toolUse"]["toolUseId"],
+                                "name": content["toolUse"]["name"],
+                            }
+                        },
+                    }
+                }
+
+                # For tool use, we need to yield the input as a delta
+                input_value = json.dumps(content["toolUse"]["input"])
+
+                yield {"contentBlockDelta": {"delta": {"toolUse": {"input": input_value}}}}
+            elif "text" in content:
+                # Then yield the text as a delta
+                yield {
+                    "contentBlockDelta": {
+                        "delta": {"text": content["text"]},
+                    }
+                }
+            elif "reasoningContent" in content:
+                # Then yield the reasoning content as a delta
+                yield {
+                    "contentBlockDelta": {
+                        "delta": {"reasoningContent": {"text": content["reasoningContent"]["reasoningText"]["text"]}}
+                    }
+                }
+
+                if "signature" in content["reasoningContent"]["reasoningText"]:
+                    yield {
+                        "contentBlockDelta": {
+                            "delta": {
+                                "reasoningContent": {
+                                    "signature": content["reasoningContent"]["reasoningText"]["signature"]
+                                }
+                            }
+                        }
+                    }
+
+            # Yield contentBlockStop event
+            yield {"contentBlockStop": {}}
+
+        # Yield messageStop event
+        yield {
+            "messageStop": {
+                "stopReason": response["stopReason"],
+                "additionalModelResponseFields": response.get("additionalModelResponseFields"),
+            }
+        }
+
+        # Yield metadata event
+        if "usage" in response or "metrics" in response or "trace" in response:
+            metadata: StreamEvent = {"metadata": {}}
+            if "usage" in response:
+                metadata["metadata"]["usage"] = response["usage"]
+            if "metrics" in response:
+                metadata["metadata"]["metrics"] = response["metrics"]
+            if "trace" in response:
+                metadata["metadata"]["trace"] = response["trace"]
+            yield metadata
 
     def _find_detected_and_blocked_policy(self, input: Any) -> bool:
         """Recursively checks if the assessment contains a detected and blocked guardrail.

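With the EventStreamError handling removed, stream() now reports failures for both code paths through a single ClientError handler, raising ModelThrottledException for throttling and ContextWindowOverflowException for context overflow. A hedged sketch of a caller reacting to those exceptions; the import path and the helper functions are assumptions, not part of this commit:

```python
# Sketch only: exception import path and helper functions are assumed.
from strands.types.exceptions import ContextWindowOverflowException, ModelThrottledException

try:
    for event in model.stream(request):       # model/request prepared elsewhere
        handle(event)                         # hypothetical event consumer
except ModelThrottledException:
    schedule_retry(request)                   # hypothetical backoff hook
except ContextWindowOverflowException:
    request = shrink_context(request)         # hypothetical context reduction
```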
src/strands/types/streaming.py

Lines changed: 1 addition & 1 deletion
@@ -157,7 +157,7 @@ class ModelStreamErrorEvent(ExceptionEvent):
     originalStatusCode: int
 
 
-class RedactContentEvent(TypedDict):
+class RedactContentEvent(TypedDict, total=False):
     """Event for redacting content.
 
     Attributes:
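Marking RedactContentEvent as total=False makes both message fields optional, which is what the redaction events above rely on: each event carries only the user message or the assistant message, not necessarily both. A simplified sketch; the real class lives in src/strands/types/streaming.py and its field types are assumed here to be str:

```python
# Simplified re-declaration for illustration only; field types are assumed.
from typing_extensions import TypedDict


class RedactContentEvent(TypedDict, total=False):
    redactUserContentMessage: str
    redactAssistantContentMessage: str


input_only: RedactContentEvent = {"redactUserContentMessage": "[User input redacted.]"}
output_only: RedactContentEvent = {"redactAssistantContentMessage": "[Assistant output redacted.]"}
```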

0 commit comments
