add Lambda Event Filtering for DynamoDB Streams, SQS (#6212) · localstack/localstack@2a9be8c · GitHub
[go: up one dir, main page]

Skip to content

Commit 2a9be8c

Browse files
authored
add Lambda Event Filtering for DynamoDB Streams, SQS (#6212)
1 parent e759813 commit 2a9be8c

File tree

9 files changed

+693
-13
lines changed

9 files changed

+693
-13
lines changed

localstack/services/apigateway/helpers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -389,7 +389,7 @@ def extract_path_params(path: str, extracted_path: str) -> Dict[str, str]:
389389
return path_params
390390

391391

392-
def extract_query_string_params(path: str) -> list[str, Dict[str, str]]:
392+
def extract_query_string_params(path: str) -> Tuple[str, Dict[str, str]]:
393393
parsed_path = urlparse.urlparse(path)
394394
path = parsed_path.path
395395
parsed_query_string_params = urlparse.parse_qs(parsed_path.query)
@@ -403,7 +403,7 @@ def extract_query_string_params(path: str) -> list[str, Dict[str, str]]:
403403

404404
# strip trailing slashes from path to fix downstream lookups
405405
path = path.rstrip("/") or "/"
406-
return [path, query_string_params]
406+
return path, query_string_params
407407

408408

409409
def get_cors_response(headers):

localstack/services/awslambda/event_source_listeners/sqs_event_source_listener.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,10 @@
1313
run_lambda,
1414
)
1515
from localstack.services.awslambda.lambda_executors import InvocationResult
16+
from localstack.services.awslambda.lambda_utils import (
17+
filter_stream_records,
18+
get_lambda_event_filters_for_arn,
19+
)
1620
from localstack.utils.aws import aws_stack
1721
from localstack.utils.threads import FuncThread
1822

@@ -80,7 +84,7 @@ def _listener_loop(self, *args):
8084
finally:
8185
time.sleep(self.SQS_POLL_INTERVAL_SEC)
8286

83-
def _process_messages_for_event_source(self, source, messages):
87+
def _process_messages_for_event_source(self, source, messages) -> bool:
8488
lambda_arn = source["FunctionArn"]
8589
queue_arn = source["EventSourceArn"]
8690
# https://docs.aws.amazon.com/lambda/latest/dg/with-sqs.html#services-sqs-batchfailurereporting
@@ -186,6 +190,29 @@ def delete_messages(result: InvocationResult, func_arn, event, error=None, **kwa
186190

187191
records.append(record)
188192

193+
event_filter_criterias = get_lambda_event_filters_for_arn(lambda_arn, queue_arn)
194+
if len(event_filter_criterias) > 0:
195+
# convert to json for filtering
196+
for record in records:
197+
try:
198+
record["body"] = json.loads(record["body"])
199+
except json.JSONDecodeError:
200+
LOG.warning(
201+
f"Unable to convert record '{record['body']}' to json... Record might be dropped."
202+
)
203+
records = filter_stream_records(records, event_filter_criterias)
204+
# convert them back
205+
for record in records:
206+
record["body"] = (
207+
json.dumps(record["body"])
208+
if not isinstance(record["body"], str)
209+
else record["body"]
210+
)
211+
212+
# all messages were filtered out
213+
if not len(records) > 0:
214+
return True
215+
189216
event = {"Records": records}
190217

191218
res = run_lambda(

localstack/services/awslambda/event_source_listeners/stream_event_source_listener.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,10 @@
1111
)
1212
from localstack.services.awslambda.lambda_api import run_lambda
1313
from localstack.services.awslambda.lambda_executors import InvocationResult
14+
from localstack.services.awslambda.lambda_utils import (
15+
filter_stream_records,
16+
get_lambda_event_filters_for_arn,
17+
)
1418
from localstack.utils.aws.message_forwarding import send_event_to_target
1519
from localstack.utils.common import long_uid, timestamp_millis
1620
from localstack.utils.threads import FuncThread
@@ -183,6 +187,10 @@ def _listen_to_shard_and_invoke_lambda(self, params: Dict):
183187
ShardIterator=shard_iterator, Limit=batch_size
184188
)
185189
records = records_response.get("Records")
190+
event_filter_criterias = get_lambda_event_filters_for_arn(function_arn, stream_arn)
191+
if len(event_filter_criterias) > 0:
192+
records = filter_stream_records(records, event_filter_criterias)
193+
186194
should_get_next_batch = True
187195
if records:
188196
payload = self._create_lambda_event_payload(stream_arn, records, shard_id=shard_id)

localstack/services/awslambda/lambda_api.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
get_lambda_runtime,
4848
get_zip_bytes,
4949
multi_value_dict_for_list,
50+
validate_filters,
5051
)
5152
from localstack.services.generic_proxy import RegionBackend
5253
from localstack.services.install import INSTALL_DIR_STEPFUNCTIONS, install_go_lambda_runtime
@@ -245,6 +246,16 @@ def build_mapping_obj(data) -> Dict:
245246
if data.get("FunctionResponseTypes"):
246247
mapping["FunctionResponseTypes"] = data.get("FunctionResponseTypes")
247248

249+
if data.get("FilterCriteria"):
250+
# validate for valid json
251+
if not validate_filters(data.get("FilterCriteria")):
252+
# AWS raises following Exception when FilterCriteria is not valid:
253+
# An error occurred (InvalidParameterValueException) when calling the CreateEventSourceMapping operation:
254+
# Invalid filter pattern definition.
255+
raise ValueError(
256+
INVALID_PARAMETER_VALUE_EXCEPTION, "Invalid filter pattern definition."
257+
)
258+
mapping["FilterCriteria"] = data.get("FilterCriteria")
248259
return mapping
249260

250261

localstack/services/awslambda/lambda_utils.py

Lines changed: 142 additions & 1 deletion
< 10000 td data-grid-cell-id="diff-da118a05c30e47b0a45a82347563f0de41763281781993ab9c9185cceafe949e-302-361-1" data-selected="false" role="gridcell" style="background-color:var(--diffBlob-additionNum-bgColor, var(--diffBlob-addition-bgColor-num));text-align:center" tabindex="-1" valign="top" class="focusable-grid-cell diff-line-number position-relative left-side">361
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import base64
2+
import json
23
import logging
34
import os
45
import re
@@ -7,7 +8,7 @@
78
from collections import defaultdict
89
from functools import lru_cache
910
from io import BytesIO
10-
from typing import Any, Dict, List, Optional, Union
11+
from typing import Any, Dict, List, Optional, TypedDict, Union
1112

1213
from flask import Response
1314

@@ -300,3 +301,143 @@ def generate_lambda_arn(
300301
return f"arn:aws:lambda:{region}:{account_id}:function:{fn_name}:{qualifier}"
301302
else:
302303
return f"arn:aws:lambda:{region}:{account_id}:function:{fn_name}"
304+
305+
306+
class FilterCriteria(TypedDict):
307+
Filters: List[Dict[str, any]]
308+
309+
310+
def parse_and_apply_numeric_filter(record_value: Dict, numeric_filter: List[str | int]) -> bool:
311+
if len(numeric_filter) % 2 > 0:
312+
LOG.warn("Invalid numeric lambda filter given")
313+
return True
314+
315+
if not isinstance(record_value, (int, float)):
316+
LOG.warn(f"Record {record_value} seem not to be a valid number")
317+
return False
318+
319+
for idx in range(0, len(numeric_filter), 2):
320+
321+
try:
322+
if numeric_filter[idx] == ">" and not (record_value > float(numeric_filter[idx + 1])):
323+
return False
324+
if numeric_filter[idx] == ">=" and not (record_value >= float(numeric_filter[idx + 1])):
325+
return False
326+
if numeric_filter[idx] == "=" and not (record_value == float(numeric_filter[idx + 1])):
327+
return False
328+
if numeric_filter[idx] == "<" and not (record_value < float(numeric_filter[idx + 1])):
329+
return False
330+
if numeric_filter[idx] == "<=" and not (record_value <= float(numeric_filter[idx + 1])):
331+
return False
332+
except ValueError:
333+
LOG.warn(
334+
f"Could not convert filter value {numeric_filter[idx + 1]} to a valid number value for filtering"
335+
)
336+
return True
337+
338+
339+
def verify_dict_filter(record_value: any, dict_filter: Dict[str, any]) -> bool:
340+
# https://docs.aws.amazon.com/lambda/latest/dg/invocation-eventfiltering.html#filtering-syntax
341+
fits_filter = False
342+
for key, filter_value in dict_filter.items():
343+
if key.lower() == "anything-but":
344+
fits_filter = record_value not in filter_value
345+
elif key.lower() == "numeric":
346+
fits_filter = parse_and_apply_numeric_filter(record_value, filter_value)
347+
elif key.lower() == "exists":
348+
fits_filter = bool(filter_value) # exists means that the key exists in the event record
349+
elif key.lower() == "prefix":
350+
if not isinstance(record_value, str):
351+
LOG.warn(f"Record Value {record_value} does not seem to be a valid string.")
352+
fits_filter = isinstance(record_value, str) and record_value.startswith(
353+
str(filter_value)
354+
)
355+
356+
if fits_filter:
357+
return True
358+
return fits_filter
359+
360+
361+
def filter_stream_record(filter_rule: Dict[str, Any], record: Dict[str, Any]) -> bool:
    """Check whether a record satisfies every key of a (possibly nested) filter rule.

    :param filter_rule: parsed filter pattern; each value is either a list of
        accepted alternatives (plain values and/or operator dicts) or a nested
        rule dict applied to the record's sub-document
    :param record: event record, or a sub-document of it during recursion
    :return: ``True`` if the rule is empty or all of its keys match
    """
    if not filter_rule:
        return True

    record_is_dict = isinstance(record, dict)

    # https://docs.aws.amazon.com/lambda/latest/dg/invocation-eventfiltering.html#filtering-syntax
    for key, value in filter_rule.items():
        # the lower-cased key is looked up first, then the key as written
        record_value = record.get(key.lower(), record.get(key)) if record_is_dict else None

        if record_value is None:
            # special case 'exists': a missing key can only satisfy an
            # {"exists": false} alternative; plain values never match it
            key_matches = isinstance(value, list) and any(
                isinstance(candidate, dict) and not candidate.get("exists", True)
                for candidate in value
            )
        elif isinstance(value, list):
            # leaf of the rule tree: the alternatives are OR-ed together
            if not value:
                LOG.warning(f"Empty lambda filter: {key}")
            key_matches = any(
                verify_dict_filter(record_value, candidate)
                if isinstance(candidate, dict)
                else candidate == record_value
                for candidate in value
            )
        elif isinstance(value, dict):
            # inner node of the rule tree: recurse into the sub-document
            key_matches = filter_stream_record(value, record_value)
        else:
            key_matches = False

        if not key_matches:
            return False
    return True
391+
392+
393+
def filter_stream_records(records, filters: List[FilterCriteria]):
    """Keep the records that match at least one pattern of any filter criteria.

    :param records: event records to filter; ``None`` or empty yields ``[]``
    :param filters: filter criteria whose ``Filters`` entries carry a
        JSON-encoded rule under ``"Pattern"``
    :return: the matching records in their original order, each at most once
        even when several patterns or criteria accept it
    """
    if not records:
        return []

    # decode every pattern once instead of once per record
    patterns = [
        json.loads(rule["Pattern"])
        for filter_criteria in filters
        for rule in filter_criteria["Filters"]
    ]

    return [
        record
        for record in records
        if any(filter_stream_record(pattern, record) for pattern in patterns)
    ]
402+
403+
404+
def contains_list(filter: Dict) -> bool:
    """Report whether a filter pattern holds a non-empty list at any nesting level.

    :param filter: parsed filter pattern; anything that is not a dict yields ``False``
    :return: ``True`` if some value, directly or inside nested dicts, is a
        non-empty list, regardless of key order
    """
    if not isinstance(filter, dict):
        return False
    return any(
        (isinstance(value, list) and len(value) > 0) or contains_list(value)
        for value in filter.values()
    )
411+
412+
413+
def validate_filters(filter: FilterCriteria) -> bool:
    """Validate every rule of the given filter criteria.

    A rule is valid when its ``"Pattern"`` is a JSON document that decodes to
    a non-empty value and names something to filter on (some list with
    criteria).

    :param filter: filter criteria with a ``Filters`` list
    :return: ``True`` only if all rules are valid; ``False`` as soon as one
        rule has a missing, non-JSON, empty or list-less pattern
    """
    # https://docs.aws.amazon.com/lambda/latest/dg/invocation-eventfiltering.html#filtering-syntax
    for rule in filter["Filters"]:
        try:
            # filter needs to be JSON deserializable
            filter_pattern = json.loads(rule["Pattern"])
        except (json.JSONDecodeError, KeyError, TypeError):
            return False

        # needs to contain on what to filter (some list with criteria)
        if not filter_pattern or not contains_list(filter_pattern):
            return False

    return True
426+
427+
428+
def get_lambda_event_filters_for_arn(lambda_arn: str, event_arn: str) -> List[Dict]:
    """Collect the filter criteria configured for a function/event-source pair.

    :param lambda_arn: ARN of the Lambda function; its fourth colon-separated
        field is used as the region name
    :param event_arn: ARN of the event source
    :return: the ``FilterCriteria`` values of all event source mappings in
        that region whose ``FunctionArn`` and ``EventSourceArn`` match and
        that define filter criteria
    """
    # late import to avoid circular import
    from localstack.services.awslambda.lambda_api import LambdaRegion

    lambda_region = LambdaRegion.get(lambda_arn.split(":")[3])

    matching_criteria = []
    for mapping in lambda_region.event_source_mappings:
        if mapping.get("FunctionArn") != lambda_arn:
            continue
        if mapping.get("EventSourceArn") != event_arn:
            continue
        criteria = mapping.get("FilterCriteria")
        if criteria is not None:
            matching_criteria.append(criteria)

    return matching_criteria

localstack/testing/pytest/fixtures.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -721,6 +721,18 @@ def is_stream_ready():
721721
return _wait_for_stream_ready
722722

723723

724+
@pytest.fixture
def wait_for_dynamodb_stream_ready(dynamodbstreams_client):
    """Provide a helper that polls a DynamoDB stream until its status is ``ENABLED``.

    The returned callable takes the stream ARN and hands the status check to
    ``poll_condition``.
    """

    def _stream_is_enabled(stream_arn: str):
        stream_description = dynamodbstreams_client.describe_stream(StreamArn=stream_arn)[
            "StreamDescription"
        ]
        return stream_description["StreamStatus"] == "ENABLED"

    def _wait_for_stream_ready(stream_arn: str):
        poll_condition(lambda: _stream_is_enabled(stream_arn))

    return _wait_for_stream_ready
734+
735+
724736
@pytest.fixture()
725737
def kms_create_key(kms_client):
726738
key_ids = []

0 commit comments

Comments
 (0)
0