- Docs: https://aws.amazon.com/bedrock/
"""

+import json
import logging
import os
-from typing import Any, Iterable, Literal, Optional, cast
+from typing import Any, Iterable, List, Literal, Optional, cast

import boto3
from botocore.config import Config as BotocoreConfig
-from botocore.exceptions import ClientError, EventStreamError
+from botocore.exceptions import ClientError
from typing_extensions import TypedDict, Unpack, override

from ..types.content import Messages
@@ -61,6 +62,7 @@ class BedrockConfig(TypedDict, total=False):
        max_tokens: Maximum number of tokens to generate in the response
        model_id: The Bedrock model ID (e.g., "us.anthropic.claude-3-7-sonnet-20250219-v1:0")
        stop_sequences: List of sequences that will stop generation when encountered
+        streaming: Flag to enable/disable streaming. Defaults to True.
        temperature: Controls randomness in generation (higher = more random)
        top_p: Controls diversity via nucleus sampling (alternative to temperature)
    """
@@ -81,6 +83,7 @@ class BedrockConfig(TypedDict, total=False):
    max_tokens: Optional[int]
    model_id: str
    stop_sequences: Optional[list[str]]
+    streaming: Optional[bool]
    temperature: Optional[float]
    top_p: Optional[float]

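For reference, a minimal sketch (not part of the diff) of a config dict exercising the new flag. The keys mirror `BedrockConfig` above; the values are illustrative assumptions.

```python
# Illustrative only: keys follow BedrockConfig from this diff; values are made up.
config = {
    "model_id": "us.anthropic.claude-3-7-sonnet-20250219-v1:0",
    "max_tokens": 1024,
    "temperature": 0.3,
    "streaming": False,  # use the non-streaming converse API; omit or set True for converse_stream
}
```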
@@ -246,11 +249,68 @@ def format_chunk(self, event: dict[str, Any]) -> StreamEvent:
        """
        return cast(StreamEvent, event)

+    def _has_blocked_guardrail(self, guardrail_data: dict[str, Any]) -> bool:
+        """Check if guardrail data contains any blocked policies.
+
+        Args:
+            guardrail_data: Guardrail data from trace information.
+
+        Returns:
+            True if any blocked guardrail is detected, False otherwise.
+        """
+        input_assessment = guardrail_data.get("inputAssessment", {})
+        output_assessments = guardrail_data.get("outputAssessments", {})
+
+        # Check input assessments
+        if any(self._find_detected_and_blocked_policy(assessment) for assessment in input_assessment.values()):
+            return True
+
+        # Check output assessments
+        if any(self._find_detected_and_blocked_policy(assessment) for assessment in output_assessments.values()):
+            return True
+
+        return False
+
+    def _generate_redaction_events(self) -> list[StreamEvent]:
+        """Generate redaction events based on configuration.
+
+        Returns:
+            List of redaction events to yield.
+        """
+        events: List[StreamEvent] = []
+
+        if self.config.get("guardrail_redact_input", True):
+            logger.debug("Redacting user input due to guardrail.")
+            events.append(
+                {
+                    "redactContent": {
+                        "redactUserContentMessage": self.config.get(
+                            "guardrail_redact_input_message", "[User input redacted.]"
+                        )
+                    }
+                }
+            )
+
+        if self.config.get("guardrail_redact_output", False):
+            logger.debug("Redacting assistant output due to guardrail.")
+            events.append(
+                {
+                    "redactContent": {
+                        "redactAssistantContentMessage": self.config.get(
+                            "guardrail_redact_output_message", "[Assistant output redacted.]"
+                        )
+                    }
+                }
+            )
+
+        return events
+
    @override
-    def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]:
-        """Send the request to the Bedrock model and get the streaming response.
+    def stream(self, request: dict[str, Any]) -> Iterable[StreamEvent]:
+        """Send the request to the Bedrock model and get the response.

-        This method calls the Bedrock converse_stream API and returns the stream of response events.
+        This method calls either the Bedrock converse_stream API or the converse API
+        based on the streaming parameter in the configuration.

        Args:
            request: The formatted request to send to the Bedrock model
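The guardrail handling that used to live inline in `stream` is now factored into `_has_blocked_guardrail` and `_generate_redaction_events`. The snippet below is a standalone sketch of the recursive detected-and-blocked walk that `_find_detected_and_blocked_policy` performs, run against a hypothetical trace fragment; the assessment payload shape and the `action`/`detected` field names are assumptions based on the method's docstring, not the exact Bedrock schema.

```python
# Standalone sketch, not the provider's code. The payload shape and the
# "action"/"detected" keys are assumptions for illustration.
from typing import Any


def is_detected_and_blocked(node: Any) -> bool:
    """Recursively look for a policy entry that is both detected and blocked."""
    if isinstance(node, dict):
        if node.get("action") == "BLOCKED" and node.get("detected") is True:
            return True
        return any(is_detected_and_blocked(value) for value in node.values())
    if isinstance(node, list):
        return any(is_detected_and_blocked(item) for item in node)
    return False


# Hypothetical guardrail trace fragment.
guardrail_data = {
    "inputAssessment": {
        "my-guardrail-id": {
            "contentPolicy": {"filters": [{"action": "BLOCKED", "detected": True}]}
        }
    },
    "outputAssessments": {},
}

blocked = any(
    is_detected_and_blocked(assessment)
    for assessment in guardrail_data["inputAssessment"].values()
)
print(blocked)  # True -> stream() would yield the redaction events before the chunk
```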
@@ -260,63 +320,132 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]:

        Raises:
            ContextWindowOverflowException: If the input exceeds the model's context window.
-            EventStreamError: For all other Bedrock API errors.
+            ModelThrottledException: If the model service is throttling requests.
        """
+        streaming = self.config.get("streaming", True)
+
        try:
-            response = self.client.converse_stream(**request)
-            for chunk in response["stream"]:
-                if self.config.get("guardrail_redact_input", True) or self.config.get("guardrail_redact_output", False):
+            if streaming:
+                # Streaming implementation
+                response = self.client.converse_stream(**request)
+                for chunk in response["stream"]:
                    if (
                        "metadata" in chunk
                        and "trace" in chunk["metadata"]
                        and "guardrail" in chunk["metadata"]["trace"]
                    ):
-                        inputAssessment = chunk["metadata"]["trace"]["guardrail"].get("inputAssessment", {})
-                        outputAssessments = chunk["metadata"]["trace"]["guardrail"].get("outputAssessments", {})
-
-                        # Check if an input or output guardrail was triggered
-                        if any(
-                            self._find_detected_and_blocked_policy(assessment)
-                            for assessment in inputAssessment.values()
-                        ) or any(
-                            self._find_detected_and_blocked_policy(assessment)
-                            for assessment in outputAssessments.values()
-                        ):
-                            if self.config.get("guardrail_redact_input", True):
-                                logger.debug("Found blocked input guardrail. Redacting input.")
-                                yield {
-                                    "redactContent": {
-                                        "redactUserContentMessage": self.config.get(
-                                            "guardrail_redact_input_message", "[User input redacted.]"
-                                        )
-                                    }
-                                }
-                            if self.config.get("guardrail_redact_output", False):
-                                logger.debug("Found blocked output guardrail. Redacting output.")
-                                yield {
-                                    "redactContent": {
-                                        "redactAssistantContentMessage": self.config.get(
-                                            "guardrail_redact_output_message", "[Assistant output redacted.]"
-                                        )
-                                    }
-                                }
+                        guardrail_data = chunk["metadata"]["trace"]["guardrail"]
+                        if self._has_blocked_guardrail(guardrail_data):
+                            yield from self._generate_redaction_events()
+                    yield chunk
+            else:
+                # Non-streaming implementation
+                response = self.client.converse(**request)
+
+                # Convert and yield from the response
+                yield from self._convert_non_streaming_to_streaming(response)

-                yield chunk
-        except EventStreamError as e:
-            # Handle throttling that occurs mid-stream?
-            if "ThrottlingException" in str(e) and "ConverseStream" in str(e):
-                raise ModelThrottledException(str(e)) from e
+                # Check for guardrail triggers after yielding any events (same as streaming path)
+                if (
+                    "trace" in response
+                    and "guardrail" in response["trace"]
+                    and self._has_blocked_guardrail(response["trace"]["guardrail"])
+                ):
+                    yield from self._generate_redaction_events()

-            if any(overflow_message in str(e) for overflow_message in BEDROCK_CONTEXT_WINDOW_OVERFLOW_MESSAGES):
+        except ClientError as e:
+            error_message = str(e)
+
+            # Handle throttling error
+            if e.response["Error"]["Code"] == "ThrottlingException":
+                raise ModelThrottledException(error_message) from e
+
+            # Handle context window overflow
+            if any(overflow_message in error_message for overflow_message in BEDROCK_CONTEXT_WINDOW_OVERFLOW_MESSAGES):
                logger.warning("bedrock threw context window overflow error")
                raise ContextWindowOverflowException(e) from e
+
+            # Otherwise raise the error
            raise e
-        except ClientError as e:
-            # Handle throttling that occurs at the beginning of the call
-            if e.response["Error"]["Code"] == "ThrottlingException":
-                raise ModelThrottledException(str(e)) from e

-            raise
+    def _convert_non_streaming_to_streaming(self, response: dict[str, Any]) -> Iterable[StreamEvent]:
+        """Convert a non-streaming response to the streaming format.
+
+        Args:
+            response: The non-streaming response from the Bedrock model.
+
+        Returns:
+            An iterable of response events in the streaming format.
+        """
+        # Yield messageStart event
+        yield {"messageStart": {"role": response["output"]["message"]["role"]}}
+
+        # Process content blocks
+        for content in response["output"]["message"]["content"]:
+            # Yield contentBlockStart event if needed
+            if "toolUse" in content:
+                yield {
+                    "contentBlockStart": {
+                        "start": {
+                            "toolUse": {
+                                "toolUseId": content["toolUse"]["toolUseId"],
+                                "name": content["toolUse"]["name"],
+                            }
+                        },
+                    }
+                }
+
+                # For tool use, we need to yield the input as a delta
+                input_value = json.dumps(content["toolUse"]["input"])
+
+                yield {"contentBlockDelta": {"delta": {"toolUse": {"input": input_value}}}}
+            elif "text" in content:
+                # Then yield the text as a delta
+                yield {
+                    "contentBlockDelta": {
+                        "delta": {"text": content["text"]},
+                    }
+                }
+            elif "reasoningContent" in content:
+                # Then yield the reasoning content as a delta
+                yield {
+                    "contentBlockDelta": {
+                        "delta": {"reasoningContent": {"text": content["reasoningContent"]["reasoningText"]["text"]}}
+                    }
+                }
+
+                if "signature" in content["reasoningContent"]["reasoningText"]:
+                    yield {
+                        "contentBlockDelta": {
+                            "delta": {
+                                "reasoningContent": {
+                                    "signature": content["reasoningContent"]["reasoningText"]["signature"]
+                                }
+                            }
+                        }
+                    }
+
+            # Yield contentBlockStop event
+            yield {"contentBlockStop": {}}
+
+        # Yield messageStop event
+        yield {
+            "messageStop": {
+                "stopReason": response["stopReason"],
+                "additionalModelResponseFields": response.get("additionalModelResponseFields"),
+            }
+        }
+
+        # Yield metadata event
+        if "usage" in response or "metrics" in response or "trace" in response:
+            metadata: StreamEvent = {"metadata": {}}
+            if "usage" in response:
+                metadata["metadata"]["usage"] = response["usage"]
+            if "metrics" in response:
+                metadata["metadata"]["metrics"] = response["metrics"]
+            if "trace" in response:
+                metadata["metadata"]["trace"] = response["trace"]
+            yield metadata

    def _find_detected_and_blocked_policy(self, input: Any) -> bool:
        """Recursively checks if the assessment contains a detected and blocked guardrail.