From cc06031b6da889d0be4316c81fbbb805e9f62392 Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Thu, 9 Nov 2023 15:48:21 +0100 Subject: [PATCH 1/3] fix(event-handler): enable path parameters on Bedrock handler --- .../event_handler/api_gateway.py | 7 +++++- .../event_handler/bedrock_agent.py | 10 ++++++++ .../data_classes/bedrock_agent_event.py | 8 ++++++- .../bedrockAgentEventWithPathParams.json | 23 +++++++++++++++++++ .../event_handler/test_bedrock_agent.py | 22 ++++++++++++++++++ 5 files changed, 68 insertions(+), 2 deletions(-) create mode 100644 tests/events/bedrockAgentEventWithPathParams.json diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 1e494fd1c0f..535d79e5874 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -1829,7 +1829,8 @@ def _resolve(self) -> ResponseBuilder: # Add matched Route reference into the Resolver context self.append_context(_route=route, _path=path) - return self._call_route(route, match_results.groupdict()) # pass fn args + route_keys = self._convert_matches_into_route_keys(match_results) + return self._call_route(route, route_keys) # pass fn args logger.debug(f"No match found for path {path} and method {method}") return self._not_found(method) @@ -1858,6 +1859,10 @@ def _remove_prefix(self, path: str) -> str: return path + def _convert_matches_into_route_keys(self, match: Match) -> Dict[str, str]: + """Converts the regex match into a dict of route keys""" + return match.groupdict() + @staticmethod def _path_starts_with(path: str, prefix: str): """Returns true if the `path` starts with a prefix plus a `/`""" diff --git a/aws_lambda_powertools/event_handler/bedrock_agent.py b/aws_lambda_powertools/event_handler/bedrock_agent.py index 258fc7dcaee..f3a21958d60 100644 --- a/aws_lambda_powertools/event_handler/bedrock_agent.py +++ b/aws_lambda_powertools/event_handler/bedrock_agent.py @@ -1,3 +1,4 @@ +from re import Match from typing import Any, Dict from typing_extensions import override @@ -75,3 +76,12 @@ def __init__(self, debug: bool = False, enable_validation: bool = True): enable_validation=enable_validation, ) self._response_builder_class = BedrockResponseBuilder + + @override + def _convert_matches_into_route_keys(self, match: Match) -> Dict[str, str]: + # In Bedrock Agents, all the parameters come inside the "parameters" key, not on the apiPath + # So we have to search for route parameters in the parameters key + parameters: Dict[str, str] = {} + if match.groupdict() and self.current_event.parameters: + parameters = {parameter["name"]: parameter["value"] for parameter in self.current_event.parameters} + return parameters diff --git a/aws_lambda_powertools/utilities/data_classes/bedrock_agent_event.py b/aws_lambda_powertools/utilities/data_classes/bedrock_agent_event.py index 1577ad62895..9534af0e7f6 100644 --- a/aws_lambda_powertools/utilities/data_classes/bedrock_agent_event.py +++ b/aws_lambda_powertools/utilities/data_classes/bedrock_agent_event.py @@ -98,7 +98,13 @@ def session_attributes(self) -> Dict[str, str]: def prompt_session_attributes(self) -> Dict[str, str]: return self["promptSessionAttributes"] - # For compatibility with BaseProxyEvent + # The following methods add compatibility with BaseProxyEvent @property def path(self) -> str: return self["apiPath"] + + @property + def query_string_parameters(self) -> Optional[Dict[str, str]]: + # In Bedrock Agent events, query string parameters are passed as undifferentiated parameters, + # together with the other parameters. So we just return all parameters here. + return {x["name"]: x["value"] for x in self["parameters"]} if self.get("parameters") else None diff --git a/tests/events/bedrockAgentEventWithPathParams.json b/tests/events/bedrockAgentEventWithPathParams.json new file mode 100644 index 00000000000..c7ce8ccaf54 --- /dev/null +++ b/tests/events/bedrockAgentEventWithPathParams.json @@ -0,0 +1,23 @@ +{ + "actionGroup": "ClaimManagementActionGroup", + "messageVersion": "1.0", + "sessionId": "12345678912345", + "sessionAttributes": {}, + "promptSessionAttributes": {}, + "inputText": "I want to claim my insurance", + "agent": { + "alias": "TSTALIASID", + "name": "test", + "version": "DRAFT", + "id": "8ZXY0W8P1H" + }, + "parameters": [ + { + "type": "string", + "name": "claim_id", + "value": "123" + } + ], + "httpMethod": "GET", + "apiPath": "/claims/" +} diff --git a/tests/functional/event_handler/test_bedrock_agent.py b/tests/functional/event_handler/test_bedrock_agent.py index dcdca460d25..f112acf0463 100644 --- a/tests/functional/event_handler/test_bedrock_agent.py +++ b/tests/functional/event_handler/test_bedrock_agent.py @@ -34,6 +34,28 @@ def claims() -> Dict[str, Any]: assert body == json.dumps({"output": claims_response}) +def test_bedrock_agent_with_path_params(): + # GIVEN a Bedrock Agent event + app = BedrockAgentResolver() + + @app.get("/claims/") + def claims(claim_id: str): + assert isinstance(app.current_event, BedrockAgentEvent) + assert app.lambda_context == {} + assert claim_id == "123" + + # WHEN calling the event handler + result = app(load_event("bedrockAgentEventWithPathParams.json"), {}) + + # THEN process event correctly + # AND set the current_event type as BedrockAgentEvent + assert result["messageVersion"] == "1.0" + assert result["response"]["apiPath"] == "/claims/" + assert result["response"]["actionGroup"] == "ClaimManagementActionGroup" + assert result["response"]["httpMethod"] == "GET" + assert result["response"]["httpStatusCode"] == 200 + + def test_bedrock_agent_event_with_response(): # GIVEN a Bedrock Agent event app = BedrockAgentResolver() From df3ad190acf76c7f4fec8a8060e1487742af9610 Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Thu, 9 Nov 2023 23:54:21 +0100 Subject: [PATCH 2/3] chore: change default openapi version to 3.0.0 --- aws_lambda_powertools/event_handler/openapi/constants.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aws_lambda_powertools/event_handler/openapi/constants.py b/aws_lambda_powertools/event_handler/openapi/constants.py index f5d72d47f7e..e41063f5282 100644 --- a/aws_lambda_powertools/event_handler/openapi/constants.py +++ b/aws_lambda_powertools/event_handler/openapi/constants.py @@ -1,2 +1,2 @@ DEFAULT_API_VERSION = "1.0.0" -DEFAULT_OPENAPI_VERSION = "3.1.0" +DEFAULT_OPENAPI_VERSION = "3.0.0" From ac713570744ccdd3b51148afe1d3a342e001ed7a Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Thu, 9 Nov 2023 23:56:14 +0100 Subject: [PATCH 3/3] fix: unexport Form, File and Header --- .../event_handler/openapi/dependant.py | 18 ++++++++++-------- .../event_handler/openapi/params.py | 6 +++--- .../event_handler/test_openapi_params.py | 6 +++--- 3 files changed, 16 insertions(+), 14 deletions(-) diff --git a/aws_lambda_powertools/event_handler/openapi/dependant.py b/aws_lambda_powertools/event_handler/openapi/dependant.py index 8cbb8b942ed..87e0c7dfb3d 100644 --- a/aws_lambda_powertools/event_handler/openapi/dependant.py +++ b/aws_lambda_powertools/event_handler/openapi/dependant.py @@ -14,12 +14,12 @@ from aws_lambda_powertools.event_handler.openapi.params import ( Body, Dependant, - File, - Form, - Header, Param, ParamTypes, Query, + _File, + _Form, + _Header, analyze_param, create_response_field, get_flat_dependant, @@ -235,7 +235,7 @@ def is_body_param(*, param_field: ModelField, is_path_param: bool) -> bool: return False elif is_scalar_field(field=param_field): return False - elif isinstance(param_field.field_info, (Query, Header)) and is_scalar_sequence_field(param_field): + elif isinstance(param_field.field_info, (Query, _Header)) and is_scalar_sequence_field(param_field): return False else: if not isinstance(param_field.field_info, Body): @@ -326,10 +326,12 @@ def get_body_field_info( if not required: body_field_info_kwargs["default"] = None - if any(isinstance(f.field_info, File) for f in flat_dependant.body_params): - body_field_info: Type[Body] = File - elif any(isinstance(f.field_info, Form) for f in flat_dependant.body_params): - body_field_info = Form + if any(isinstance(f.field_info, _File) for f in flat_dependant.body_params): + # MAINTENANCE: body_field_info: Type[Body] = _File + raise NotImplementedError("_File fields are not supported in request bodies") + elif any(isinstance(f.field_info, _Form) for f in flat_dependant.body_params): + # MAINTENANCE: body_field_info: Type[Body] = _Form + raise NotImplementedError("_Form fields are not supported in request bodies") else: body_field_info = Body diff --git a/aws_lambda_powertools/event_handler/openapi/params.py b/aws_lambda_powertools/event_handler/openapi/params.py index 797b44f6232..c8099d20404 100644 --- a/aws_lambda_powertools/event_handler/openapi/params.py +++ b/aws_lambda_powertools/event_handler/openapi/params.py @@ -308,7 +308,7 @@ def __init__( ) -class Header(Param): +class _Header(Param): """ A class used internally to represent a header parameter in a path operation. """ @@ -471,7 +471,7 @@ def __repr__(self) -> str: return f"{self.__class__.__name__}({self.default})" -class Form(Body): +class _Form(Body): """ A class used internally to represent a form parameter in a path operation. """ @@ -543,7 +543,7 @@ def __init__( ) -class File(Form): +class _File(_Form): """ A class used internally to represent a file parameter in a path operation. """ diff --git a/tests/functional/event_handler/test_openapi_params.py b/tests/functional/event_handler/test_openapi_params.py index 3972b2c4626..ec31bb14236 100644 --- a/tests/functional/event_handler/test_openapi_params.py +++ b/tests/functional/event_handler/test_openapi_params.py @@ -14,11 +14,11 @@ ) from aws_lambda_powertools.event_handler.openapi.params import ( Body, - Header, Param, ParamTypes, Query, _create_model_field, + _Header, ) from aws_lambda_powertools.shared.types import Annotated @@ -375,7 +375,7 @@ def secret(): def test_create_header(): - header = Header(convert_underscores=True) + header = _Header(convert_underscores=True) assert header.convert_underscores is True @@ -400,7 +400,7 @@ def test_create_model_field_with_empty_in(): # Tests that when we try to create a model field with convert_underscore, we convert the field name def test_create_model_field_convert_underscore(): - field_info = Header(alias=None, convert_underscores=True) + field_info = _Header(alias=None, convert_underscores=True) result = _create_model_field(field_info, int, "user_id", False) assert result.alias == "user-id"