8000 feat(bedrock_agents): add optional fields to response payload by anafalcao · Pull Request #6336 · aws-powertools/powertools-lambda-python · GitHub
[go: up one dir, main page]

Skip to content

feat(bedrock_agents): add optional fields to response payload #6336

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
Apr 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8000
3 changes: 2 additions & 1 deletion aws_lambda_powertools/event_handler/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
Response,
)
from aws_lambda_powertools.event_handler.appsync import AppSyncResolver
from aws_lambda_powertools.event_handler.bedrock_agent import BedrockAgentResolver
from aws_lambda_powertools.event_handler.bedrock_agent import BedrockAgentResolver, BedrockResponse
from aws_lambda_powertools.event_handler.events_appsync.appsync_events import AppSyncEventsResolver
from aws_lambda_powertools.event_handler.lambda_function_url import (
LambdaFunctionUrlResolver,
Expand All @@ -26,6 +26,7 @@
"ALBResolver",
"ApiGatewayResolver",
"BedrockAgentResolver",
"BedrockResponse",
"CORSConfig",
"LambdaFunctionUrlResolver",
"Response",
Expand Down
60 changes: 48 additions & 12 deletions aws_lambda_powertools/event_handler/api_gateway.py
8000 < 628C td id="diff-bdb4b43087f89fdf398381f3d88048f6182562ca67e8907259d5c56d658465c1L2630" data-line-number="2630" class="blob-num blob-num-deletion js-linkable-line-number">
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
_DEFAULT_OPENAPI_RESPONSE_DESCRIPTION = "Successful Response"
_ROUTE_REGEX = "^{}$"
_JSON_DUMP_CALL = partial(json.dumps, separators=(",", ":"), cls=Encoder)
_DEFAULT_CONTENT_TYPE = "application/json"

ResponseEventT = TypeVar("ResponseEventT", bound=BaseProxyEvent)
ResponseT = TypeVar("ResponseT")
Expand Down Expand Up @@ -255,6 +256,35 @@ def build_allow_methods(methods: set[str]) -> str:
return ",".join(sorted(methods))


class BedrockResponse(Generic[ResponseT]):
"""
Contains the response body, status code, content type, and optional attributes
for session management and knowledge base configuration.
"""

def __init__(
self,
body: Any = None,
status_code: int = 200,
content_type: str = _DEFAULT_CONTENT_TYPE,
session_attributes: dict[str, Any] | None = None,
prompt_session_attributes: dict[str, Any] | None = None,
knowledge_bases_configuration: list[dict[str, Any]] | None = None,
) -> None:
self.body = body
self.status_code = status_code
self.content_type = content_type
self.session_attributes = session_attributes
self.prompt_session_attributes = prompt_session_attributes
self.knowledge_bases_configuration = knowledge_bases_configuration

def is_json(self) -> bool:
"""
Returns True if the response is JSON, based on the Content-Type.
"""
return True


class Response(Generic[ResponseT]):
"""Response data class that provides greater control over what is returned from the proxy event"""

Expand Down Expand Up @@ -300,7 +330,7 @@ def is_json(self) -> bool:
content_type = self.headers.get("Content-Type", "")
if isinstance(content_type, list):
content_type = content_type[0]
return content_type.startswith("application/json")
return content_type.startswith(_DEFAULT_CONTENT_TYPE)


class Route:
Expand Down Expand Up @@ -572,7 +602,7 @@ def _get_openapi_path(
operation_responses: dict[int, OpenAPIResponse] = {
422: {
"description": "Validation Error",
"content": {"application/json": {"schema": {"$ref": f"{COMPONENT_REF_PREFIX}HTTPValidationError"}}},
"content": {_DEFAULT_CONTENT_TYPE: {"schema": {"$ref": f"{COMPONENT_REF_PREFIX}HTTPValidationError"}}},
},
}

Expand All @@ -581,7 +611,9 @@ def _get_openapi_path(
http_code = self.custom_response_validation_http_code.value
operation_responses[http_code] = {
"description": "Response Validation Error",
"content": {"application/json": {"schema": {"$ref": f"{COMPONENT_REF_PREFIX}ResponseValidationError"}}},
"content": {
_DEFAULT_CONTENT_TYPE: {"schema": {"$ref": f"{COMPONENT_REF_PREFIX}ResponseValidationError"}},
},
}
# Add model definition
definitions["ResponseValidationError"] = response_validation_error_response_definition
Expand All @@ -594,7 +626,7 @@ def _get_openapi_path(
# Case 1: there is not 'content' key
if "content" not in response:
response["content"] = {
"application/json": self._openapi_operation_return(
_DEFAULT_CONTENT_TYPE: self._openapi_operation_return(
param=dependant.return_param,
model_name_map=model_name_map,
field_mapping=field_mapping,
Expand Down Expand Up @@ -645,7 +677,7 @@ def _get_openapi_path(
# Add the response schema to the OpenAPI 200 response
operation_responses[200] = {
"description": self.response_description or _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION,
"content": {"application/json": response_schema},
"content": {_DEFAULT_CONTENT_TYPE: response_schema},
}

operation["responses"] = operation_responses
Expand Down Expand Up @@ -1474,7 +1506,10 @@ def __call__(self, app: ApiGatewayResolver) -> dict | tuple | Response:
return self.current_middleware(app, self.next_middleware)


def _registered_api_adapter(app: ApiGatewayResolver, next_middleware: Callable[..., Any]) -> dict | tuple | Response:
def _registered_api_adapter(
app: ApiGatewayResolver,
next_middleware: Callable[..., Any],
) -> dict | tuple | Response | BedrockResponse:
"""
Calls the registered API using the "_route_args" from the Resolver context to ensure the last call
in the chain will match the API route function signature and ensure that Powertools passes the API
Expand Down Expand Up @@ -1632,7 +1667,7 @@ def _add_resolver_response_validation_error_response_to_route(
response_validation_error_response = {
"description": "Response Validation Error",
"content": {
"application/json": {
_DEFAULT_CONTENT_TYPE: {
"schema": {"$ref": f"{COMPONENT_REF_PREFIX}ResponseValidationError"},
},
},
Expand Down Expand Up @@ -2151,7 +2186,7 @@ def swagger_handler():
if query_params.get("format") == "json":
return Response(
status_code=200,
content_type="application/json",
content_type=_DEFAULT_CONTENT_TYPE,
body=escaped_spec,
)

Expand Down Expand Up @@ -2538,7 +2573,7 @@ def _call_route(self, route: Route, route_arguments: dict[str, str]) -> Response
self._reset_processed_stack()

return self._response_builder_class(
response=self._to_response(
response=self._to_response( # type: ignore[arg-type]
route(router_middlewares=self._router_middlewares, app=self, route_arguments=route_arguments),
),
serializer=self._serializer,
Expand Down Expand Up @@ -2627,7 +2662,7 @@ def _call_exception_handler(self, exp: Exception, route: Route) -> ResponseBuild

return None

def _to_response(self, result: dict | tuple | Response) -> Response:
def _to_response(self, result: dict | tuple | Response | BedrockResponse) -> Response | BedrockResponse:
"""Convert the route's result to a Response

3 main result types are supported:
Expand All @@ -2638,7 +2673,7 @@ def _to_response(self, result: dict | tuple | Response) -> Response:
- Response: returned as is, and allows for more flexibility
"""
status_code = HTTPStatus.OK
if isinstance(result, Response):
if isinstance(result, (Response, BedrockResponse)):
return result
elif isinstance(result, tuple) and len(result) == 2:
# Unpack result dict and status code from tuple
Expand Down Expand Up @@ -2971,8 +3006,9 @@ def _get_base_path(self) -> str:
# ALB doesn't have a stage variable, so we just return an empty string
return ""

# BedrockResponse is not used here but adding the same signature to keep strong typing
@override
def _to_response(self, result: dict | tuple | Response) -> Response:
def _to_response(self, result: dict | tuple | Response | BedrockResponse) -> Response | BedrockResponse:
"""Convert the route's result to a Response

ALB requires a non-null body otherwise it converts as HTTP 5xx
Expand Down
19 changes: 15 additions & 4 deletions aws_lambda_powertools/event_handler/bedrock_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from aws_lambda_powertools.event_handler import ApiGatewayResolver
from aws_lambda_powertools.event_handler.api_gateway import (
_DEFAULT_OPENAPI_RESPONSE_DESCRIPTION,
BedrockResponse,
ProxyEventType,
ResponseBuilder,
)
Expand All @@ -32,14 +33,11 @@ class BedrockResponseBuilder(ResponseBuilder):

@override
def build(self, event: BedrockAgentEvent, *args) -> dict[str, Any]:
"""Build the full response dict to be returned by the lambda"""
self._route(event, None)

body = self.response.body
if self.response.is_json() and not isinstance(self.response.body, str):
body = self.serializer(self.response.body)

return {
response = {
"messageVersion": "1.0",
"response": {
"actionGroup": event.action_group,
Expand All @@ -54,6 +52,19 @@ def build(self, event: BedrockAgentEvent, *args) -> dict[str, Any]:
},
}

# Add Bedrock-specific attributes
if isinstance(self.response, BedrockResponse):
if self.response.session_attributes:
response["sessionAttributes"] = self.response.session_attributes

if self.response.prompt_session_attributes:
response["promptSessionAttributes"] = self.response.prompt_session_attributes

if self.response.knowledge_bases_configuration:
response["knowledgeBasesConfiguration"] = self.response.knowledge_bases_configuration

return response


class BedrockAgentResolver(ApiGatewayResolver):
"""Bedrock Agent Resolver
Expand Down
11 changes: 11 additions & 0 deletions docs/core/event_handler/bedrock_agents.md
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,17 @@ You can enable user confirmation with Bedrock Agents to have your application as

1. Add an openapi extension

### Fine grained responses

???+ info "Note"
The default response only includes the essential fields to keep the payload size minimal, as AWS Lambda has a maximum response size of 25 KB.

You can use `BedrockResponse` class to add additional fields as needed, such as [session attributes, prompt session attributes, and knowledge base configurations](https://docs.aws.amazon.com/bedrock/latest/userguide/agents-lambda.html#agents-lambda-response){target="_blank"}.

```python title="working_with_bedrockresponse.py" title="Customzing your Bedrock Response" hl_lines="5 16"
--8<-- "examples/event_handler_bedrock_agents/src/working_with_bedrockresponse.py"
```

## Testing your code

Test your routes by passing an [Agent for Amazon Bedrock proxy event](https://docs.aws.amazon.com/bedrock/latest/userguide/agents-lambda.html#agents-lambda-input) request:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from http import HTTPStatus

from aws_lambda_powertools import Logger, Tracer
from aws_lambda_powertools.event_handler import BedrockAgentResolver
from aws_lambda_powertools.event_handler.api_gateway import BedrockResponse
from aws_lambda_powertools.utilities.typing import LambdaContext

tracer = Tracer()
logger = Logger()
app = BedrockAgentResolver()


@app.get("/return_with_session", description="Returns a hello world with session attributes")
@tracer.capture_method
def hello_world():
return BedrockResponse(
status_code=HTTPStatus.OK.value,
body={"message": "Hello from Bedrock!"},
session_attributes={"user_id": "123"},
prompt_session_attributes={"context": "testing"},
knowledge_bases_configuration=[
{
"knowledgeBaseId": "kb-123",
"retrievalConfiguration": {
"vectorSearchConfiguration": {"numberOfResults": 3, "overrideSearchType": "HYBRID"},
},
},
],
)


@logger.inject_lambda_context
@tracer.capture_lambda_handler
def lambda_handler(event: dict, context: LambdaContext):
return app.resolve(event, context)
Loading
Loading
0