8000 fix AWS_PROXY lambda response validation · localstack/localstack@f8a00e4 · GitHub
[go: up one dir, main page]

Skip to content

Commit f8a00e4

Browse files
committed
fix AWS_PROXY lambda response validation
1 parent 385a141 commit f8a00e4

File tree

4 files changed

+375
-8
lines changed

4 files changed

+375
-8
lines changed

localstack-core/localstack/services/apigateway/next_gen/execute_api/integrations/aws.py

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
from localstack.constants import APPLICATION_JSON, INTERNAL_AWS_ACCESS_KEY_ID
2222
from localstack.utils.aws.arns import extract_region_from_arn
2323
from localstack.utils.aws.client_types import ServicePrincipal
24-
from localstack.utils.collections import merge_dicts
2524
from localstack.utils.strings import to_bytes, to_str
2625

2726
from ..context import (
@@ -390,10 +389,7 @@ def invoke(self, context: RestApiInvocationContext) -> EndpointResponse:
390389

391390
headers = Headers({"Content-Type": APPLICATION_JSON})
392391

393-
response_headers = merge_dicts(
394-
lambda_response.get("headers") or {},
395-
lambda_response.get("multiValueHeaders") or {},
396-
)
392+
response_headers = self._merge_lambda_response_headers(lambda_response)
397393
headers.update(response_headers)
398394

399395
return EndpointResponse(
@@ -467,8 +463,6 @@ def serialize_header(value: bool | str) -> str:
467463
if multi_value_headers := lambda_response.get("multiValueHeaders"):
468464
lambda_response["multiValueHeaders"] = {
469465
k: [serialize_header(v) for v in values]
470-
if isinstance(values, list)
471-
else serialize_header(values)
472466
for k, values in multi_value_headers.items()
473467
}
474468

@@ -482,13 +476,22 @@ def _is_lambda_response_valid(lambda_response: dict) -> bool:
482476
if not validate_sub_dict_of_typed_dict(LambdaProxyResponse, lambda_response):
483477
return False
484478

485-
if "headers" in lambda_response:
479+
if lambda_response.get("headers") is not None:
486480
headers = lambda_response["headers"]
487481
if not isinstance(headers, dict):
488482
return False
489483
if any(not isinstance(header_value, (str, bool)) for header_value in headers.values()):
490484
return False
491485

486+
if lambda_response.get("multiValueHeaders") is not None:
487+
multi_value_headers = lambda_response["multiValueHeaders"]
488+
if not isinstance(multi_value_headers, dict):
489+
return False
490+
if any(
491+
not isinstance(header_value, list) for header_value in multi_value_headers.values()
492+
):
493+
return False
494+
492495
if "statusCode" in lambda_response:
493496
try:
494497
int(lambda_response["statusCode"])
@@ -550,3 +553,19 @@ def _format_body(body: bytes) -> tuple[str, bool]:
550553
return body.decode("utf-8"), False
551554
except UnicodeDecodeError:
552555
return to_str(base64.b64encode(body)), True
556+
557+
@staticmethod
558+
def _merge_lambda_response_headers(lambda_response: LambdaProxyResponse) -> dict:
559+
headers = lambda_response.get("headers") or {}
560+
561+
if multi_value_headers := lambda_response.get("multiValueHeaders"):
562+
# multiValueHeaders has the priority and will decide the casing of the final headers, as they are merged
563+
headers_low_keys = {k.lower(): v for k, v in headers.items()}
564+
565+
for k, values in multi_value_headers.items():
566+
if (k_lower := k.lower()) in headers_low_keys:
567+
headers[k] = [*values, headers_low_keys[k_lower]]
568+
else:
569+
headers[k] = values
570+
571+
return headers

tests/aws/services/apigateway/test_apigateway_lambda.py

Lines changed: 215 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import base64
12
import json
23
import os
34

@@ -33,6 +34,8 @@
3334
CLOUDFRONT_SKIP_HEADERS = [
3435
"$..Via",
3536
"$..X-Amz-Cf-Id",
37+
"$..X-Amz-Cf-Pop",
38+
"$..X-Cache",
3639
"$..CloudFront-Forwarded-Proto",
3740
"$..CloudFront-Is-Desktop-Viewer",
3841
"$..CloudFront-Is-Mobile-Viewer",
@@ -42,6 +45,16 @@
4245
"$..CloudFront-Viewer-Country",
4346
]
4447

48+
LAMBDA_RESPONSE_FROM_BODY = """
49+
import json
50+
import base64
51+
def handler(event, context, *args):
52+
body = event["body"]
53+
if event.get("isBase64Encoded"):
54+
body = base64.b64decode(body)
55+
return json.loads(body)
56+
"""
57+
4558

4659
@markers.aws.validated
4760
@markers.snapshot.skip_snapshot_verify(paths=CLOUDFRONT_SKIP_HEADERS)
@@ -871,6 +884,208 @@ def invoke_api(url):
871884
assert response.json() == {"message": "Internal server error"}
872885

873886

887+
@markers.snapshot.skip_snapshot_verify(
888+
paths=[
889+
*CLOUDFRONT_SKIP_HEADERS,
890+
# returned by LocalStack by default
891+
"$..headers.Server",
892+
]
893+
)
894+
@markers.aws.validated
895+
def test_aws_proxy_response_payload_format_validation(
896+
create_rest_apigw,
897+
create_lambda_function,
898+
create_role_with_policy,
899+
aws_client,
900+
region_name,
901+
snapshot,
902+
):
903+
snapshot.add_transformers_list(
904+
[
905+
snapshot.transform.key_value("Via"),
906+
snapshot.transform.key_value("X-Cache"),
907+
snapshot.transform.key_value("x-amz-apigw-id"),
908+
snapshot.transform.key_value("X-Amz-Cf-Pop"),
909+
snapshot.transform.key_value("X-Amz-Cf-Id"),
910+
snapshot.transform.key_value("X-Amzn-Trace-Id"),
911+
snapshot.transform.key_value(
912+
"Date", reference_replacement=False, value_replacement="<date>"
913+
),
914+
]
915+
)
916+
snapshot.add_transformers_list(
917+
[
918+
snapshot.transform.jsonpath("$..headers.Host", value_replacement="host"),
919+
snapshot.transform.jsonpath("$..multiValueHeaders.Host[0]", value_replacement="host"),
920+
snapshot.transform.key_value(
921+
"X-Forwarded-For",
922+
value_replacement="<X-Forwarded-For>",
923+
reference_replacement=False,
924+
),
925+
snapshot.transform.key_value(
926+
"X-Forwarded-Port",
927+
value_replacement="<X-Forwarded-Port>",
928+
reference_replacement=False,
929+
),
930+
snapshot.transform.key_value(
931+
"X-Forwarded-Proto",
932+
value_replacement="<X-Forwarded-Proto>",
933+
reference_replacement=False,
934+
),
935+
],
936+
priority=-1,
937+
)
938+
939+
stage_name = "test"
940+
_, role_arn = create_role_with_policy(
941+
"Allow", "lambda:InvokeFunction", json.dumps(APIGATEWAY_ASSUME_ROLE_POLICY), "*"
942+
)
943+
944+
function_name = f"response-format-apigw-{short_uid()}"
945+
create_function_response = create_lambda_function(
946+
handler_file=LAMBDA_RESPONSE_FROM_BODY,
947+
func_name=function_name,
948+
runtime=Runtime.python3_12,
949+
)
950+
# create invocation role
951+
lambda_arn = create_function_response["CreateFunctionResponse"]["FunctionArn"]
952+
953+
# create rest api
954+
api_id, _, root = create_rest_apigw(
955+
name=f"test-api-{short_uid()}",
956+
description="Integration test API",
957+
)
958+
959+
resource_id = aws_client.apigateway.create_resource(
960+
restApiId=api_id, parentId=root, pathPart="{proxy+}"
961+
)["id"]
962+
963+
aws_client.apigateway.put_method(
964+
restApiId=api_id,
965+
resourceId=resource_id,
966+
httpMethod="ANY",
967+
authorizationType="NONE",
968+
)
969+
970+
# Lambda AWS_PROXY integration
971+
aws_client.apigateway.put_integration(
972+
restApiId=api_id,
973+
resourceId=resource_id,
974+
httpMethod="ANY",
975+
type="AWS_PROXY",
976+
integrationHttpMethod="POST",
977+
uri=f"arn:aws:apigateway:{region_name}:lambda:path/2015-03-31/functions/{lambda_arn}/invocations",
978+
credentials=role_arn,
979+
)
980+
981+
aws_client.apigateway.create_deployment(restApiId=api_id, stageName=stage_name)
982+
endpoint = api_invoke_url(api_id=api_id, path="/test", stage=stage_name)
983+
984+
def _invoke(
985+
body: dict | str, expected_status_code: int = 200, return_headers: bool = False
986+
) -> dict:
987+
kwargs = {}
988+
if body:
989+
kwargs["json"] = body
990+
991+
_response = requests.post(
992+
url=endpoint,
993+
headers={"User-Agent": "python/test"},
994+
verify=False,
995+
**kwargs,
996+
)
997+
998+
assert _response.status_code == expected_status_code
999+
1000+
try:
1001+
content = _response.json()
1002+
except json.JSONDecodeError:
1003+
content = _response.content.decode()
1004+
1005+
dict_resp = {"content": content}
1006+
if return_headers:
1007+
dict_resp["headers"] = dict(_response.headers)
1008+
1009+
return dict_resp
1010+
1011+
response = retry(_invoke, sleep=1, retries=10, body={"statusCode": 200})
1012+
snapshot.match("invoke-api-no-body", response)
1013+
1014+
response = _invoke(
1015+
body={"statusCode": 200, "headers": {"test-header": "value", "header-bool": True}},
1016+
return_headers=True,
1017+
)
1018+
snapshot.match("invoke-api-with-headers", response)
1019+
1020+
response = _invoke(
1021+
body={"statusCode": 200, "headers": None},
1022+
return_headers=True,
1023+
)
1024+
snapshot.match("invoke-api-with-headers-null", response)
1025+
1026+
response = _invoke(body={"statusCode": 200, "wrongValue": "value"}, expected_status_code=502)
1027+
snapshot.match("invoke-api-wrong-format", response)
1028+
1029+
response = _invoke(body={}, expected_status_code=502)
1030+
snapshot.match("invoke-api-empty-response", response)
1031+
1032+
response = _invoke(
1033+
body={
1034+
"statusCode": 200,
1035+
"body": base64.b64encode(b"test-data").decode(),
1036+
"isBase64Encoded": True,
1037+
}
1038+
)
1039+
snapshot.match("invoke-api-b64-encoded-true", response)
1040+
1041+
response = _invoke(body={"statusCode": 200, "body": base64.b64encode(b"test-data").decode()})
1042+
snapshot.match("invoke-api-b64-encoded-false", response)
1043+
1044+
response = _invoke(
1045+
body={"statusCode": 200, "multiValueHeaders": {"test-multi": ["value1", "value2"]}},
1046+
return_headers=True,
1047+
)
1048+
snapshot.match("invoke-api-multi-headers-valid", response)
1049+
1050+
response = _invoke(
1051+
body={
1052+
"statusCode": 200,
1053+
"multiValueHeaders": {"test-multi": ["value-multi"]},
1054+
"headers": {"test-multi": "value-solo"},
1055+
},
1056+
return_headers=True,
1057+
)
1058+
snapshot.match("invoke-api-multi-headers-overwrite", response)
1059+
1060+
response = _invoke(
1061+
body={
1062+
"statusCode": 200,
1063+
"multiValueHeaders": {"tesT-Multi": ["value-multi"]},
1064+
"headers": {"test-multi": "value-solo"},
1065+
},
1066+
return_headers=True,
1067+
)
1068+
snapshot.match("invoke-api-multi-headers-overwrite-casing", response)
1069+
1070+
response = _invoke(
1071+
body={"statusCode": 200, "multiValueHeaders": {"test-multi-invalid": "value1"}},
1072+
expected_status_code=502,
1073+
)
1074+
snapshot.match("invoke-api-multi-headers-invalid", response)
1075+
1076+
response = _invoke(body={"statusCode": "test"}, expected_status_code=502)
1077+
snapshot.match("invoke-api-invalid-status-code", response)
1078+
1079+
response = _invoke(body={"statusCode": "201"}, expected_status_code=201)
1080+
snapshot.match("invoke-api-status-code-str", response)
1081+
1082+
response = _invoke(body="justAString", expected_status_code=502)
1083+
snapshot.match("invoke-api-just-string", response)
1084+
1085+
response = _invoke(body={"headers": {"test-header": "value"}}, expected_status_code=200)
1086+
snapshot.match("invoke-api-only-headers", response)
1087+
1088+
8741089
# Testing the integration with Rust to prevent future regression with strongly typed language integration
8751090
# TODO make the test compatible for ARM
8761091
@markers.aws.validated

0 commit comments

Comments
 (0)
0