10000 Apigw/add support for response override in request (#12628) · localstack/localstack@5f9ae76 · GitHub
[go: up one dir, main page]

Skip to content

Commit 5f9ae76

Browse files
authored
Apigw/add support for response override in request (#12628)
1 parent 9990b6f commit 5f9ae76

File tree

13 files changed

+202
-29
lines changed

13 files changed

+202
-29
lines changed

localstack-core/localstack/services/apigateway/next_gen/execute_api/context.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from localstack.aws.api.apigateway import Integration, Method, Resource
99
from localstack.services.apigateway.models import RestApiDeployment
1010

11-
from .variables import ContextVariables, LoggingContextVariables
11+
from .variables import ContextVariableOverrides, ContextVariables, LoggingContextVariables
1212

1313

1414
class InvocationRequest(TypedDict, total=False):
@@ -98,6 +98,9 @@ class RestApiInvocationContext(RequestContext):
9898
"""The Stage variables, also used in parameters mapping and mapping templates"""
9999
context_variables: Optional[ContextVariables]
100100
"""The $context used in data models, authorizers, mapping templates, and CloudWatch access logging"""
101+
context_variable_overrides: Optional[ContextVariableOverrides]
102+
"""requestOverrides and responseOverrides are passed from request templates to response templates but are
103+
not in the integration context"""
101104
logging_context_variables: Optional[LoggingContextVariables]
102105
"""Additional $context variables available only for access logging, not yet implemented"""
103106
invocation_request: Optional[InvocationRequest]
@@ -129,3 +132,4 @@ def __init__(self, request: Request):
129132
self.endpoint_response = None
130133
self.invocation_response = None
131134
self.trace_id = None
135+
self.context_variable_overrides = None

localstack-core/localstack/services/apigateway/next_gen/execute_api/handlers/integration_request.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
MappingTemplateParams,
2323
MappingTemplateVariables,
2424
)
25-
from ..variables import ContextVarsRequestOverride
25+
from ..variables import ContextVariableOverrides, ContextVarsRequestOverride
2626

2727
LOG = logging.getLogger(__name__)
2828

@@ -119,13 +119,16 @@ def __call__(
119119

120120
converted_body = self.convert_body(context)
121121

122-
body, request_override = self.render_request_template_mapping(
122+
body, mapped_overrides = self.render_request_template_mapping(
123123
context=context, body=converted_body, template=request_template
124124
)
125+
# Update the context with the returned mapped overrides
126+
context.context_variable_overrides = mapped_overrides
125127
# mutate the ContextVariables with the requestOverride result, as we copy the context when rendering the
126128
# template to avoid mutation on other fields
127-
# the VTL responseTemplate can access the requestOverride
128-
context.context_variables["requestOverride"] = request_override
129+
request_override: ContextVarsRequestOverride = mapped_overrides.get(
130+
"requestOverride", {}
131+
)
129132
# TODO: log every override that happens afterwards (in a loop on `request_override`)
130133
merge_recursive(request_override, request_data_mapping, overwrite=True)
131134

@@ -180,18 +183,18 @@ def render_request_template_mapping(
180183
context: RestApiInvocationContext,
181184
body: str | bytes,
182185
template: str,
183-
) -> tuple[bytes, ContextVarsRequestOverride]:
186+
) -> tuple[bytes, ContextVariableOverrides]:
184187
request: InvocationRequest = context.invocation_request
185188

186189
if not template:
187-
return to_bytes(body), {}
190+
return to_bytes(body), context.context_variable_overrides
188191

189192
try:
190193
body_utf8 = to_str(body)
191194
except UnicodeError:
192195
raise InternalServerError("Internal server error")
193196

194-
body, request_override = self._vtl_template.render_request(
197+
body, mapped_overrides = self._vtl_template.render_request(
195198
template=template,
196199
variables=MappingTemplateVariables(
197200
context=context.context_variables,
@@ -205,8 +208,9 @@ def render_request_template_mapping(
205208
),
206209
),
207210
),
211+
context_overrides=context.context_variable_overrides,
208212
)
209-
return to_bytes(body), request_override
213+
return to_bytes(body), mapped_overrides
210214

211215
@staticmethod
212216
def get_request_template(integration: Integration, request: InvocationRequest) -> str:

localstack-core/localstack/services/apigateway/next_gen/execute_api/handlers/integration_response.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,7 @@ def render_response_template_mapping(
263263
self, context: RestApiInvocationContext, template: str, body: bytes | str
264264
) -> tuple[bytes, ContextVarsResponseOverride]:
265265
if not template:
266-
return to_bytes(body), ContextVarsResponseOverride(status=0, header={})
266+
return to_bytes(body), context.context_variable_overrides["responseOverride"]
267267

268268
# if there are no template, we can pass binary data through
269269
if not isinstance(body, str):
@@ -284,6 +284,7 @@ def render_response_template_mapping(
284284
),
285285
),
286286
),
287+
context_overrides=context.context_variable_overrides,
287288
)
288289

289290
# AWS ignores the status if the override isn't an integer between 100 and 599

localstack-core/localstack/services/apigateway/next_gen/execute_api/handlers/parse.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,13 @@
1818
from ..header_utils import should_drop_header_from_invocation
1919
from ..helpers import generate_trace_id, generate_trace_parent, parse_trace_id
2020
from ..moto_helpers import get_stage_variables
21-
from ..variables import ContextVariables, ContextVarsIdentity
21+
from ..variables import (
22+
ContextVariableOverrides,
23+
ContextVariables,
24+
ContextVarsIdentity,
25+
ContextVarsRequestOverride,
26+
ContextVarsResponseOverride,
27+
)
2228

2329
LOG = logging.getLogger(__name__)
2430

@@ -40,6 +46,10 @@ def parse_and_enrich(self, context: RestApiInvocationContext):
4046
# then we can create the ContextVariables, used throughout the invocation as payload and to render authorizer
4147
# payload, mapping templates and such.
4248
context.context_variables = self.create_context_variables(context)
49+
context.context_variable_overrides = ContextVariableOverrides(
50+
requestOverride=ContextVarsRequestOverride(header={}, querystring={}, path={}),
51+
responseOverride=ContextVarsResponseOverride(header={}, status=0),
52+
)
4353
# TODO: maybe adjust the logging
4454
LOG.debug("Initializing $context='%s'", context.context_variables)
4555
# then populate the stage variables

localstack-core/localstack/services/apigateway/next_gen/execute_api/template_mapping.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@
2626

2727
from localstack import config
2828
from localstack.services.apigateway.next_gen.execute_api.variables import (
29+
ContextVariableOverrides,
2930
ContextVariables,
30-
ContextVarsRequestOverride,
3131
ContextVarsResponseOverride,
3232
)
3333
from localstack.utils.aws.templating import APIGW_SOURCE, VelocityUtil, VtlTemplate
@@ -261,22 +261,27 @@ def prepare_namespace(self, variables, source: str = APIGW_SOURCE) -> dict[str,
261261
return namespace
262262

263263
def render_request(
264-
self, template: str, variables: MappingTemplateVariables
265-
) -> tuple[str, ContextVarsRequestOverride]:
264+
self,
265+
template: str,
266+
variables: MappingTemplateVariables,
267+
context_overrides: ContextVariableOverrides,
268+
) -> tuple[str, ContextVariableOverrides]:
266269
variables_copy: MappingTemplateVariables = copy.deepcopy(variables)
267-
variables_copy["context"]["requestOverride"] = ContextVarsRequestOverride(
268-
querystring={}, header={}, path={}
269-
)
270+
variables_copy["context"].update(copy.deepcopy(context_overrides))
270271
result = self.render_vtl(template=template.strip(), variables=variables_copy)
271-
return result, variables_copy["context"]["requestOverride"]
272+
return result, ContextVariableOverrides(
273+
requestOverride=variables_copy["context"]["requestOverride"],
274+
responseOverride=variables_copy["context"]["responseOverride"],
275+
)
272276

273277
def render_response(
274-
self, template: str, variables: MappingTemplateVariables
278+
self,
279+
template: str,
280+
variables: MappingTemplateVariables,
281+
context_overrides: ContextVariableOverrides,
275282
) -> tuple[str, ContextVarsResponseOverride]:
276283
variables_copy: MappingTemplateVariables = copy.deepcopy(variables)
277-
variables_copy["context"]["responseOverride"] = ContextVarsResponseOverride(
278-
header={}, status=0
279-
)
284+
variables_copy["context"].update(copy.deepcopy(context_overrides))
280285
result = self.render_vtl(template=template.strip(), variables=variables_copy)
281286
return result, variables_copy["context"]["responseOverride"]
282287

localstack-core/localstack/services/apigateway/next_gen/execute_api/test_invoke.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,11 @@
1616
from .handlers.resource_router import RestAPIResourceRouter
1717
from .header_utils import build_multi_value_headers
1818
from .template_mapping import dict_to_string
19+
from .variables import (
20+
ContextVariableOverrides,
21+
ContextVarsRequestOverride,
22+
ContextVarsResponseOverride,
23+
)
1924

2025
# TODO: we probably need to write and populate those logs as part of the handler chain itself
2126
# and store it in the InvocationContext. That way, we could also retrieve in when calling TestInvoke
@@ -150,8 +155,11 @@ def create_test_invocation_context(
150155
invocation_context.context_variables = parse_handler.create_context_variables(
151156
invocation_context
152157
)
< F438 /td>158+
invocation_context.context_variable_overrides = ContextVariableOverrides(
159+
requestOverride=ContextVarsRequestOverride(header={}, path={}, querystring={}),
160+
responseOverride=ContextVarsResponseOverride(header={}, status=0),
161+
)
153162
invocation_context.trace_id = parse_handler.populate_trace_id({})
154-
155163
resource = deployment.rest_api.resources[test_request["resourceId"]]
156164
resource_method = resource["resourceMethods"][http_method]
157165
invocation_context.resource = resource

localstack-core/localstack/services/apigateway/next_gen/execute_api/variables.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,11 @@ class ContextVarsResponseOverride(TypedDict):
7575
status: int
7676

7777

78+
class ContextVariableOverrides(TypedDict):
79+
requestOverride: ContextVarsRequestOverride
80+
responseOverride: ContextVarsResponseOverride
81+
82+
7883
class GatewayResponseContextVarsError(TypedDict, total=False):
7984
# This variable can only be used for simple variable substitution in a GatewayResponse body-mapping template,
8085
# which is not processed by the Velocity Template Language engine, and in access logging.

tests/aws/services/apigateway/test_apigateway_integrations.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -724,6 +724,88 @@ def invoke_api(url) -> requests.Response:
724724
snapshot.match("invoke-path-else", response_data_3.json())
725725

726726

727+
@markers.aws.validated
728+
@pytest.mark.parametrize("create_response_template", [True, False])
729+
def test_integration_mock_with_response_override_in_request_template(
730+
create_rest_apigw, aws_client, snapshot, create_response_template
731+
):
732+
expected_status = 444
733+
api_id, _, root_id = create_rest_apigw(
734+
name=f&qu 10000 ot;test-api-{short_uid()}",
735+
description="this is my api",
736+
)
737+
738+
aws_client.apigateway.put_method(
739+
restApiId=api_id,
740+
resourceId=root_id,
741+
httpMethod="GET",
742+
authorizationType="NONE",
743+
)
744+
745+
aws_client.apigateway.put_method_response(
746+
restApiId=api_id, resourceId=root_id, httpMethod="GET", statusCode="200"
747+
)
748+
749+
request_template = textwrap.dedent(f"""
750+
#set($context.responseOverride.status = {expected_status})
751+
#set($context.responseOverride.header.foo = "bar")
752+
#set($context.responseOverride.custom = "is also passed around")
753+
{{
754+
"statusCode": 200
755+
}}
756+
""")
757+
758+
aws_client.apigateway.put_integration(
759+
restApiId=api_id,
760+
resourceId=root_id,
761+
httpMethod="GET",
762+
integrationHttpMethod="POST",
763+
type="MOCK",
764+
requestParameters={},
765+
requestTemplates={"application/json": request_template},
766+
)
767+
response_template = textwrap.dedent("""
768+
#set($statusOverride = $context.responseOverride.status)
769+
#set($fooHeader = $context.responseOverride.header.foo)
770+
#set($custom = $context.responseOverride.custom)
771+
{
772+
"statusOverride": "$statusOverride",
773+
"fooHeader": "$fooHeader",
774+
"custom": "$custom"
775+
}
776+
""")
777+
778+
aws_client.apigateway.put_integration_response(
779+
restApiId=api_id,
780+
resourceId=root_id,
781+
httpMethod="GET",
782+
statusCode="200",
783+
selectionPattern="2\\d{2}",
784+
responseTemplates={"application/json": response_template}
785+
if create_response_template
786+
else {},
787+
)
788+
stage_name = "dev"
789+
aws_client.apigateway.create_deployment(restApiId=api_id, stageName=stage_name)
790+
791+
invocation_url = api_invoke_url(api_id=api_id, stage=stage_name)
792+
793+
def invoke_api(url) -> requests.Response:
794+
_response = requests.get(url, verify=False)
795+
assert _response.status_code == expected_status
796+
return _response
797+
798+
response_data = retry(invoke_api, sleep=2, retries=10, url=invocation_url)
799+
assert response_data.headers["foo"] == "bar"
800+
snapshot.match(
801+
"response",
802+
{
803+
"body": response_data.json() if create_response_template else response_data.content,
804+
"status_code": response_data.status_code,
805+
},
806+
)
807+
808+
727809
@pytest.fixture
728810
def default_vpc(aws_client):
729811
vpcs = aws_client.ec2.describe_vpcs()

tests/aws/services/apigateway/test_apigateway_integrations.snapshot.json

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1078,5 +1078,27 @@
10781078
}
10791079
}
10801080
}
1081+
},
1082+
"tests/aws/services/apigateway/test_apigateway_integrations.py::test_integration_mock_with_response_override_in_request_template[True]": {
1083+
"recorded-date": "16-05-2025, 10:22:21",
1084+
"recorded-content": {
1085+
"response": {
1086+
"body": {
1087+
"custom": "is also passed around",
1088+
"fooHeader": "bar",
1089+
"statusOverride": "444"
1090+
},
1091+
"status_code": 444
1092+
}
1093+
}
1094+
},
1095+
"tests/aws/services/apigateway/test_apigateway_integrations.py::test_integration_mock_with_response_override_in_request_template[False]": {
1096+
"recorded-date": "16-05-2025, 10:22:27",
1097+
"recorded-content": {
1098+
"response": {
1099+
"body": "b''",
1100+
"status_code": 444
1101+
}
1102+
}
10811103
}
10821104
}

tests/aws/services/apigateway/test_apigateway_integrations.validation.json

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,12 @@
2020
"tests/aws/services/apigateway/test_apigateway_integrations.py::test_integration_mock_with_request_overrides_in_response_template": {
2121
"last_validated_date": "2024-11-06T23:09:04+00:00"
2222
},
23+
"tests/aws/services/apigateway/test_apigateway_integrations.py::test_integration_mock_with_response_override_in_request_template[False]": {
24+
"last_validated_date": "2025-05-16T10:22:27+00:00"
25+
},
26+
"tests/aws/services/apigateway/test_apigateway_integrations.py::test_integration_mock_with_response_override_in_request_template[True]": {
27+
"last_validated_date": "2025-05-16T10:22:21+00:00"
28+
},
2329
"tests/aws/services/apigateway/test_apigateway_integrations.py::test_put_integration_response_with_response_template": {
2430
"last_validated_date": "2024-05-30T16:15:58+00:00"
2531
},

0 commit comments

Comments
 (0)
0