diff --git a/localstack-core/localstack/services/apigateway/next_gen/execute_api/integrations/aws.py b/localstack-core/localstack/services/apigateway/next_gen/execute_api/integrations/aws.py index 7f7c4acebaac3..5bc2474d386ca 100644 --- a/localstack-core/localstack/services/apigateway/next_gen/execute_api/integrations/aws.py +++ b/localstack-core/localstack/services/apigateway/next_gen/execute_api/integrations/aws.py @@ -21,7 +21,6 @@ from localstack.constants import APPLICATION_JSON, INTERNAL_AWS_ACCESS_KEY_ID from localstack.utils.aws.arns import extract_region_from_arn from localstack.utils.aws.client_types import ServicePrincipal -from localstack.utils.collections import merge_dicts from localstack.utils.strings import to_bytes, to_str from ..context import ( @@ -390,10 +389,7 @@ def invoke(self, context: RestApiInvocationContext) -> EndpointResponse: headers = Headers({"Content-Type": APPLICATION_JSON}) - response_headers = merge_dicts( - lambda_response.get("headers") or {}, - lambda_response.get("multiValueHeaders") or {}, - ) + response_headers = self._merge_lambda_response_headers(lambda_response) headers.update(response_headers) return EndpointResponse( @@ -467,8 +463,6 @@ def serialize_header(value: bool | str) -> str: if multi_value_headers := lambda_response.get("multiValueHeaders"): lambda_response["multiValueHeaders"] = { k: [serialize_header(v) for v in values] - if isinstance(values, list) - else serialize_header(values) for k, values in multi_value_headers.items() } @@ -482,13 +476,20 @@ def _is_lambda_response_valid(lambda_response: dict) -> bool: if not validate_sub_dict_of_typed_dict(LambdaProxyResponse, lambda_response): return False - if "headers" in lambda_response: - headers = lambda_response["headers"] + if (headers := lambda_response.get("headers")) is not None: if not isinstance(headers, dict): return False if any(not isinstance(header_value, (str, bool)) for header_value in headers.values()): return False + if (multi_value_headers := lambda_response.get("multiValueHeaders")) is not None: + if not isinstance(multi_value_headers, dict): + return False + if any( + not isinstance(header_value, list) for header_value in multi_value_headers.values() + ): + return False + if "statusCode" in lambda_response: try: int(lambda_response["statusCode"]) @@ -550,3 +551,19 @@ def _format_body(body: bytes) -> tuple[str, bool]: return body.decode("utf-8"), False except UnicodeDecodeError: return to_str(base64.b64encode(body)), True + + @staticmethod + def _merge_lambda_response_headers(lambda_response: LambdaProxyResponse) -> dict: + headers = lambda_response.get("headers") or {} + + if multi_value_headers := lambda_response.get("multiValueHeaders"): + # multiValueHeaders has the priority and will decide the casing of the final headers, as they are merged + headers_low_keys = {k.lower(): v for k, v in headers.items()} + + for k, values in multi_value_headers.items(): + if (k_lower := k.lower()) in headers_low_keys: + headers[k] = [*values, headers_low_keys[k_lower]] + else: + headers[k] = values + + return headers diff --git a/tests/aws/services/apigateway/test_apigateway_lambda.py b/tests/aws/services/apigateway/test_apigateway_lambda.py index 9bd8355e80a9f..4eb10905a1401 100644 --- a/tests/aws/services/apigateway/test_apigateway_lambda.py +++ b/tests/aws/services/apigateway/test_apigateway_lambda.py @@ -1,3 +1,4 @@ +import base64 import json import os @@ -33,6 +34,8 @@ CLOUDFRONT_SKIP_HEADERS = [ "$..Via", "$..X-Amz-Cf-Id", + "$..X-Amz-Cf-Pop", + "$..X-Cache", "$..CloudFront-Forwarded-Proto", "$..CloudFront-Is-Desktop-Viewer", "$..CloudFront-Is-Mobile-Viewer", @@ -42,6 +45,16 @@ "$..CloudFront-Viewer-Country", ] +LAMBDA_RESPONSE_FROM_BODY = """ +import json +import base64 +def handler(event, context, *args): + body = event["body"] + if event.get("isBase64Encoded"): + body = base64.b64decode(body) + return json.loads(body) +""" + @markers.aws.validated @markers.snapshot.skip_snapshot_verify(paths=CLOUDFRONT_SKIP_HEADERS) @@ -871,6 +884,208 @@ def invoke_api(url): assert response.json() == {"message": "Internal server error"} +@markers.snapshot.skip_snapshot_verify( + paths=[ + *CLOUDFRONT_SKIP_HEADERS, + # returned by LocalStack by default + "$..headers.Server", + ] +) +@markers.aws.validated +def test_aws_proxy_response_payload_format_validation( + create_rest_apigw, + create_lambda_function, + create_role_with_policy, + aws_client, + region_name, + snapshot, +): + snapshot.add_transformers_list( + [ + snapshot.transform.key_value("Via"), + snapshot.transform.key_value("X-Cache"), + snapshot.transform.key_value("x-amz-apigw-id"), + snapshot.transform.key_value("X-Amz-Cf-Pop"), + snapshot.transform.key_value("X-Amz-Cf-Id"), + snapshot.transform.key_value("X-Amzn-Trace-Id"), + snapshot.transform.key_value( + "Date", reference_replacement=False, value_replacement="" + ), + ] + ) + snapshot.add_transformers_list( + [ + snapshot.transform.jsonpath("$..headers.Host", value_replacement="host"), + snapshot.transform.jsonpath("$..multiValueHeaders.Host[0]", value_replacement="host"), + snapshot.transform.key_value( + "X-Forwarded-For", + value_replacement="", + reference_replacement=False, + ), + snapshot.transform.key_value( + "X-Forwarded-Port", + value_replacement="", + reference_replacement=False, + ), + snapshot.transform.key_value( + "X-Forwarded-Proto", + value_replacement="", + reference_replacement=False, + ), + ], + priority=-1, + ) + + stage_name = "test" + _, role_arn = create_role_with_policy( + "Allow", "lambda:InvokeFunction", json.dumps(APIGATEWAY_ASSUME_ROLE_POLICY), "*" + ) + + function_name = f"response-format-apigw-{short_uid()}" + create_function_response = create_lambda_function( + handler_file=LAMBDA_RESPONSE_FROM_BODY, + func_name=function_name, + runtime=Runtime.python3_12, + ) + # create invocation role + lambda_arn = create_function_response["CreateFunctionResponse"]["FunctionArn"] + + # create rest api + api_id, _, root = create_rest_apigw( + name=f"test-api-{short_uid()}", + description="Integration test API", + ) + + resource_id = aws_client.apigateway.create_resource( + restApiId=api_id, parentId=root, pathPart="{proxy+}" + )["id"] + + aws_client.apigateway.put_method( + restApiId=api_id, + resourceId=resource_id, + httpMethod="ANY", + authorizationType="NONE", + ) + + # Lambda AWS_PROXY integration + aws_client.apigateway.put_integration( + restApiId=api_id, + resourceId=resource_id, + httpMethod="ANY", + type="AWS_PROXY", + integrationHttpMethod="POST", + uri=f"arn:aws:apigateway:{region_name}:lambda:path/2015-03-31/functions/{lambda_arn}/invocations", + credentials=role_arn, + ) + + aws_client.apigateway.create_deployment(restApiId=api_id, stageName=stage_name) + endpoint = api_invoke_url(api_id=api_id, path="/test", stage=stage_name) + + def _invoke( + body: dict | str, expected_status_code: int = 200, return_headers: bool = False + ) -> dict: + kwargs = {} + if body: + kwargs["json"] = body + + _response = requests.post( + url=endpoint, + headers={"User-Agent": "python/test"}, + verify=False, + **kwargs, + ) + + assert _response.status_code == expected_status_code + + try: + content = _response.json() + except json.JSONDecodeError: + content = _response.content.decode() + + dict_resp = {"content": content} + if return_headers: + dict_resp["headers"] = dict(_response.headers) + + return dict_resp + + response = retry(_invoke, sleep=1, retries=10, body={"statusCode": 200}) + snapshot.match("invoke-api-no-body", response) + + response = _invoke( + body={"statusCode": 200, "headers": {"test-header": "value", "header-bool": True}}, + return_headers=True, + ) + snapshot.match("invoke-api-with-headers", response) + + response = _invoke( + body={"statusCode": 200, "headers": None}, + return_headers=True, + ) + snapshot.match("invoke-api-with-headers-null", response) + + response = _invoke(body={"statusCode": 200, "wrongValue": "value"}, expected_status_code=502) + snapshot.match("invoke-api-wrong-format", response) + + response = _invoke(body={}, expected_status_code=502) + snapshot.match("invoke-api-empty-response", response) + + response = _invoke( + body={ + "statusCode": 200, + "body": base64.b64encode(b"test-data").decode(), + "isBase64Encoded": True, + } + ) + snapshot.match("invoke-api-b64-encoded-true", response) + + response = _invoke(body={"statusCode": 200, "body": base64.b64encode(b"test-data").decode()}) + snapshot.match("invoke-api-b64-encoded-false", response) + + response = _invoke( + body={"statusCode": 200, "multiValueHeaders": {"test-multi": ["value1", "value2"]}}, + return_headers=True, + ) + snapshot.match("invoke-api-multi-headers-valid", response) + + response = _invoke( + body={ + "statusCode": 200, + "multiValueHeaders": {"test-multi": ["value-multi"]}, + "headers": {"test-multi": "value-solo"}, + }, + return_headers=True, + ) + snapshot.match("invoke-api-multi-headers-overwrite", response) + + response = _invoke( + body={ + "statusCode": 200, + "multiValueHeaders": {"tesT-Multi": ["value-multi"]}, + "headers": {"test-multi": "value-solo"}, + }, + return_headers=True, + ) + snapshot.match("invoke-api-multi-headers-overwrite-casing", response) + + response = _invoke( + body={"statusCode": 200, "multiValueHeaders": {"test-multi-invalid": "value1"}}, + expected_status_code=502, + ) + snapshot.match("invoke-api-multi-headers-invalid", response) + + response = _invoke(body={"statusCode": "test"}, expected_status_code=502) + snapshot.match("invoke-api-invalid-status-code", response) + + response = _invoke(body={"statusCode": "201"}, expected_status_code=201) + snapshot.match("invoke-api-status-code-str", response) + + response = _invoke(body="justAString", expected_status_code=502) + snapshot.match("invoke-api-just-string", response) + + response = _invoke(body={"headers": {"test-header": "value"}}, expected_status_code=200) + snapshot.match("invoke-api-only-headers", response) + + # Testing the integration with Rust to prevent future regression with strongly typed language integration # TODO make the test compatible for ARM @markers.aws.validated diff --git a/tests/aws/services/apigateway/test_apigateway_lambda.snapshot.json b/tests/aws/services/apigateway/test_apigateway_lambda.snapshot.json index 66b17012c07a9..1ff320770f0ad 100644 --- a/tests/aws/services/apigateway/test_apigateway_lambda.snapshot.json +++ b/tests/aws/services/apigateway/test_apigateway_lambda.snapshot.json @@ -1665,5 +1665,135 @@ "status_code": 200 } } + }, + "tests/aws/services/apigateway/test_apigateway_lambda.py::test_aws_proxy_response_payload_format_validation": { + "recorded-date": "15-11-2024, 17:48:06", + "recorded-content": { + "invoke-api-no-body": { + "content": "" + }, + "invoke-api-with-headers": { + "content": "", + "headers": { + "Connection": "keep-alive", + "Content-Length": "0", + "Content-Type": "application/json", + "Date": "", + "Via": "", + "X-Amz-Cf-Id": "", + "X-Amz-Cf-Pop": "", + "X-Amzn-Trace-Id": "", + "X-Cache": "", + "header-bool": "true", + "test-header": "value", + "x-amz-apigw-id": "", + "x-amzn-RequestId": "" + } + }, + "invoke-api-with-headers-null": { + "content": "", + "headers": { + "Connection": "keep-alive", + "Content-Length": "0", + "Content-Type": "application/json", + "Date": "", + "Via": "", + "X-Amz-Cf-Id": "", + "X-Amz-Cf-Pop": "", + "X-Amzn-Trace-Id": "", + "X-Cache": "", + "x-amz-apigw-id": "", + "x-amzn-RequestId": "" + } + }, + "invoke-api-wrong-format": { + "content": { + "message": "Internal server error" + } + }, + "invoke-api-empty-response": { + "content": { + "message": "Internal server error" + } + }, + "invoke-api-b64-encoded-true": { + "content": "dGVzdC1kYXRh" + }, + "invoke-api-b64-encoded-false": { + "content": "dGVzdC1kYXRh" + }, + "invoke-api-multi-headers-valid": { + "content": "", + "headers": { + "Connection": "keep-alive", + "Content-Length": "0", + "Content-Type": "application/json", + "Date": "", + "Via": "", + "X-Amz-Cf-Id": "", + "X-Amz-Cf-Pop": "", + "X-Amzn-Trace-Id": "", + "X-Cache": "", + "test-multi": "value1, value2", + "x-amz-apigw-id": "", + "x-amzn-RequestId": "" + } + }, + "invoke-api-multi-headers-overwrite": { + "content": "", + "headers": { + "Connection": "keep-alive", + "Content-Length": "0", + "Content-Type": "application/json", + "Date": "", + "Via": "", + "X-Amz-Cf-Id": "", + "X-Amz-Cf-Pop": "", + "X-Amzn-Trace-Id": "", + "X-Cache": "", + "test-multi": "value-multi, value-solo", + "x-amz-apigw-id": "", + "x-amzn-RequestId": "" + } + }, + "invoke-api-multi-headers-overwrite-casing": { + "content": "", + "headers": { + "Connection": "keep-alive", + "Content-Length": "0", + "Content-Type": "application/json", + "Date": "", + "Via": "", + "X-Amz-Cf-Id": "", + "X-Amz-Cf-Pop": "", + "X-Amzn-Trace-Id": "", + "X-Cache": "", + "tesT-Multi": "value-multi, value-solo", + "x-amz-apigw-id": "", + "x-amzn-RequestId": "" + } + }, + "invoke-api-multi-headers-invalid": { + "content": { + "message": "Internal server error" + } + }, + "invoke-api-invalid-status-code": { + "content": { + "message": "Internal server error" + } + }, + "invoke-api-status-code-str": { + "content": "" + }, + "invoke-api-just-string": { + "content": { + "message": "Internal server error" + } + }, + "invoke-api-only-headers": { + "content": "" + } + } } } diff --git a/tests/aws/services/apigateway/test_apigateway_lambda.validation.json b/tests/aws/services/apigateway/test_apigateway_lambda.validation.json index b37661ae02b59..70ab1fb72eac8 100644 --- a/tests/aws/services/apigateway/test_apigateway_lambda.validation.json +++ b/tests/aws/services/apigateway/test_apigateway_lambda.validation.json @@ -1,4 +1,7 @@ { + "tests/aws/services/apigateway/test_apigateway_lambda.py::test_aws_proxy_response_payload_format_validation": { + "last_validated_date": "2024-11-15T17:48:06+00:00" + }, "tests/aws/services/apigateway/test_apigateway_lambda.py::test_lambda_aws_integration": { "last_validated_date": "2023-05-31T21:11:42+00:00" },