From d70eaa12f69230bddcab17ad5d58811edd76e3e9 Mon Sep 17 00:00:00 2001 From: Michael Brewer Date: Fri, 21 Aug 2020 12:04:24 -0700 Subject: [PATCH 1/9] fix(ssm): Make decrypt an explicit option --- .../utilities/parameters/ssm.py | 42 ++++++++++++++++++- 1 file changed, 41 insertions(+), 1 deletion(-) diff --git a/aws_lambda_powertools/utilities/parameters/ssm.py b/aws_lambda_powertools/utilities/parameters/ssm.py index b458f8690d0..9208b6807c7 100644 --- a/aws_lambda_powertools/utilities/parameters/ssm.py +++ b/aws_lambda_powertools/utilities/parameters/ssm.py @@ -8,7 +8,7 @@ import boto3 from botocore.config import Config -from .base import DEFAULT_PROVIDERS, BaseProvider +from .base import DEFAULT_MAX_AGE_SECS, DEFAULT_PROVIDERS, BaseProvider class SSMProvider(BaseProvider): @@ -86,6 +86,46 @@ def __init__( super().__init__() + def get( + _, + name: str, + max_age: int = DEFAULT_MAX_AGE_SECS, + transform: Optional[str] = None, + decrypt: bool = False, + **sdk_options + ) -> Union[str, list, dict, bytes]: + """ + Retrieve a parameter value or return the cached value + + Parameters + ---------- + name: str + Parameter name + max_age: int + Maximum age of the cached value + transform: str + Optional transformation of the parameter value. Supported values + are "json" for JSON strings and "binary" for base 64 encoded + values. + decrypt: bool, optional + If the parameter value should be decrypted + sdk_options: dict, optional + Arguments that will be passed directly to the underlying API call + + Raises + ------ + GetParameterError + When the parameter provider fails to retrieve a parameter value for + a given name. + TransformParameterError + When the parameter provider fails to transform a parameter value. + """ + + # Add to `decrypt` sdk_options to we can have an explicit option for this + sdk_options["decrypt"] = decrypt + + return super().get(name, max_age, transform, **sdk_options) + def _get(self, name: str, decrypt: bool = False, **sdk_options) -> str: """ Retrieve a parameter value from AWS Systems Manager Parameter Store From e9e3a4e94dc6d7bcadb924be796dfdef1dc317c6 Mon Sep 17 00:00:00 2001 From: Michael Brewer Date: Fri, 21 Aug 2020 13:29:28 -0700 Subject: [PATCH 2/9] chore: declare as self --- aws_lambda_powertools/utilities/parameters/ssm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aws_lambda_powertools/utilities/parameters/ssm.py b/aws_lambda_powertools/utilities/parameters/ssm.py index 9208b6807c7..5c20fb54b5e 100644 --- a/aws_lambda_powertools/utilities/parameters/ssm.py +++ b/aws_lambda_powertools/utilities/parameters/ssm.py @@ -87,7 +87,7 @@ def __init__( super().__init__() def get( - _, + self, name: str, max_age: int = DEFAULT_MAX_AGE_SECS, transform: Optional[str] = None, From 674a039c0e1ec5417bc25542f39ead52a5cd1a55 Mon Sep 17 00:00:00 2001 From: Michael Brewer Date: Fri, 21 Aug 2020 19:08:28 -0700 Subject: [PATCH 3/9] fix: update get_parameter and get_parameters Changes: ssm.py - get_parameters - pass through the **sdk_options and merge in the recursive and decrypt params ssm.py - get_parameter - add explicit option for decrypt --- .../utilities/parameters/ssm.py | 20 ++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/aws_lambda_powertools/utilities/parameters/ssm.py b/aws_lambda_powertools/utilities/parameters/ssm.py index 5c20fb54b5e..0f39bfac9c0 100644 --- a/aws_lambda_powertools/utilities/parameters/ssm.py +++ b/aws_lambda_powertools/utilities/parameters/ssm.py @@ -184,7 +184,9 @@ def _get_multiple(self, path: str, decrypt: bool = False, recursive: bool = Fals return parameters -def get_parameter(name: str, transform: Optional[str] = None, **sdk_options) -> Union[str, list, dict, bytes]: +def get_parameter( + name: str, transform: Optional[str] = None, decrypt: bool = False, **sdk_options +) -> Union[str, list, dict, bytes]: """ Retrieve a parameter value from AWS Systems Manager (SSM) Parameter Store @@ -194,6 +196,8 @@ def get_parameter(name: str, transform: Optional[str] = None, **sdk_options) -> Name of the parameter transform: str, optional Transforms the content from a JSON object ('json') or base64 binary string ('binary') + decrypt: bool, optional + If the parameter values should be decrypted sdk_options: dict, optional Dictionary of options that will be passed to the Parameter Store get_parameter API call @@ -230,7 +234,10 @@ def get_parameter(name: str, transform: Optional[str] = None, **sdk_options) -> if "ssm" not in DEFAULT_PROVIDERS: DEFAULT_PROVIDERS["ssm"] = SSMProvider() - return DEFAULT_PROVIDERS["ssm"].get(name, transform=transform) + # Add to `decrypt` sdk_options to we can have an explicit option for this + sdk_options["decrypt"] = decrypt + + return DEFAULT_PROVIDERS["ssm"].get(name, transform=transform, **sdk_options) def get_parameters( @@ -245,10 +252,10 @@ def get_parameters( Path to retrieve the parameters transform: str, optional Transforms the content from a JSON object ('json') or base64 binary string ('binary') - decrypt: bool, optional - If the parameter values should be decrypted recursive: bool, optional If this should retrieve the parameter values recursively or not, defaults to True + decrypt: bool, optional + If the parameter values should be decrypted sdk_options: dict, optional Dictionary of options that will be passed to the Parameter Store get_parameters_by_path API call @@ -285,4 +292,7 @@ def get_parameters( if "ssm" not in DEFAULT_PROVIDERS: DEFAULT_PROVIDERS["ssm"] = SSMProvider() - return DEFAULT_PROVIDERS["ssm"].get_multiple(path, transform=transform, recursive=recursive, decrypt=decrypt) + sdk_options["recursive"] = recursive + sdk_options["decrypt"] = decrypt + + return DEFAULT_PROVIDERS["ssm"].get_multiple(path, transform=transform, **sdk_options) From 824d611afb9c25fd7219343ce00d0d5aa43897a4 Mon Sep 17 00:00:00 2001 From: Michael Brewer Date: Fri, 21 Aug 2020 19:37:57 -0700 Subject: [PATCH 4/9] chore: fix typos and type hinting --- aws_lambda_powertools/utilities/parameters/base.py | 4 ++-- aws_lambda_powertools/utilities/parameters/secrets.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/aws_lambda_powertools/utilities/parameters/base.py b/aws_lambda_powertools/utilities/parameters/base.py index 8a552b53bcb..7e3f838528a 100644 --- a/aws_lambda_powertools/utilities/parameters/base.py +++ b/aws_lambda_powertools/utilities/parameters/base.py @@ -87,7 +87,7 @@ def get( @abstractmethod def _get(self, name: str, **sdk_options) -> str: """ - Retrieve paramater value from the underlying parameter store + Retrieve parameter value from the underlying parameter store """ raise NotImplementedError() @@ -168,7 +168,7 @@ def transform_value(value: str, transform: str) -> Union[dict, bytes]: Parameters --------- value: str - Parameter alue to transform + Parameter value to transform transform: str Type of transform, supported values are "json" and "binary" diff --git a/aws_lambda_powertools/utilities/parameters/secrets.py b/aws_lambda_powertools/utilities/parameters/secrets.py index ee4585309fe..67cb94c340b 100644 --- a/aws_lambda_powertools/utilities/parameters/secrets.py +++ b/aws_lambda_powertools/utilities/parameters/secrets.py @@ -77,7 +77,7 @@ def _get(self, name: str, **sdk_options) -> str: ---------- name: str Name of the parameter - sdk_options: dict + sdk_options: dict, optional Dictionary of options that will be passed to the Secrets Manager get_secret_value API call """ From 2cca4ebb878263bb724b1e5fddde0d6f5d7c39ea Mon Sep 17 00:00:00 2001 From: Michael Brewer Date: Fri, 21 Aug 2020 20:27:07 -0700 Subject: [PATCH 5/9] tests: verify that the default kwargs are set - `decrypt` should be false by default - `recursive` should be true by default --- tests/functional/test_utilities_parameters.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/functional/test_utilities_parameters.py b/tests/functional/test_utilities_parameters.py index 7a0677b2197..8c4e20fac52 100644 --- a/tests/functional/test_utilities_parameters.py +++ b/tests/functional/test_utilities_parameters.py @@ -1310,6 +1310,7 @@ def test_get_parameter_new(monkeypatch, mock_name, mock_value): class TestProvider(BaseProvider): def _get(self, name: str, **kwargs) -> str: assert name == mock_name + assert not kwargs["decrypt"] return mock_value def _get_multiple(self, path: str, **kwargs) -> Dict[str, str]: @@ -1355,6 +1356,8 @@ def _get(self, name: str, **kwargs) -> str: def _get_multiple(self, path: str, **kwargs) -> Dict[str, str]: assert path == mock_name + assert kwargs["recursive"] + assert not kwargs["decrypt"] return mock_value monkeypatch.setattr(parameters.ssm, "DEFAULT_PROVIDERS", {}) From a33bf4ee09bfa6c201ea2f8d5f6b8134e4ae7d46 Mon Sep 17 00:00:00 2001 From: Michael Brewer Date: Sat, 22 Aug 2020 01:50:34 -0700 Subject: [PATCH 6/9] fix(capture_method): should yield inside with (#124) Changes: * capture_method should yield from within the "with" statement * Add missing test cases Closes #112 --- aws_lambda_powertools/tracing/tracer.py | 3 +- tests/unit/test_tracing.py | 91 +++++++++++++++++++++++++ 2 files changed, 92 insertions(+), 2 deletions(-) diff --git a/aws_lambda_powertools/tracing/tracer.py b/aws_lambda_powertools/tracing/tracer.py index 0ce55e60837..4c12be3fc26 100644 --- a/aws_lambda_powertools/tracing/tracer.py +++ b/aws_lambda_powertools/tracing/tracer.py @@ -487,14 +487,13 @@ def decorate(*args, **kwargs): logger.debug(f"Calling method: {method_name}") with method(*args, **kwargs) as return_val: result = return_val + yield result self._add_response_as_metadata(function_name=method_name, data=result, subsegment=subsegment) except Exception as err: logger.exception(f"Exception received from '{method_name}' method") self._add_full_exception_as_metadata(function_name=method_name, error=err, subsegment=subsegment) raise - yield result - return decorate def _decorate_sync_function(self, method: Callable = None): diff --git a/tests/unit/test_tracing.py b/tests/unit/test_tracing.py index 8f7d9a646dd..16c476ee0fc 100644 --- a/tests/unit/test_tracing.py +++ b/tests/unit/test_tracing.py @@ -383,6 +383,72 @@ def handler(event, context): assert "test result" in result +def test_tracer_yield_from_context_manager_exception_metadata(mocker, provider_stub, in_subsegment_mock): + # GIVEN tracer is initialized + provider = provider_stub(in_subsegment=in_subsegment_mock.in_subsegment) + tracer = Tracer(provider=provider, service="booking") + + # WHEN capture_method decorator is used on a context manager + # and the method raises an exception + @tracer.capture_method + @contextlib.contextmanager + def yield_with_capture(): + yield "partial" + raise ValueError("test") + + with pytest.raises(ValueError): + with yield_with_capture() as partial_val: + assert partial_val == "partial" + + # THEN we should add the exception using method name as key plus error + # and their service name as the namespace + put_metadata_mock_args = in_subsegment_mock.put_metadata.call_args[1] + assert put_metadata_mock_args["key"] == "yield_with_capture error" + assert isinstance(put_metadata_mock_args["value"], ValueError) + assert put_metadata_mock_args["namespace"] == "booking" + + +def test_tracer_yield_from_nested_context_manager(mocker, provider_stub, in_subsegment_mock): + # GIVEN tracer is initialized + provider = provider_stub(in_subsegment=in_subsegment_mock.in_subsegment) + tracer = Tracer(provider=provider, service="booking") + + # WHEN capture_method decorator is used on a context manager nesting another context manager + class NestedContextManager(object): + def __enter__(self): + self._value = {"result": "test result"} + return self._value + + def __exit__(self, exc_type, exc_val, exc_tb): + self._value["result"] = "exit was called before yielding" + + @tracer.capture_method + @contextlib.contextmanager + def yield_with_capture(): + with NestedContextManager() as nested_context: + yield nested_context + + @tracer.capture_lambda_handler + def handler(event, context): + response = [] + with yield_with_capture() as yielded_value: + response.append(yielded_value["result"]) + + return response + + result = handler({}, {}) + + # THEN we should have a subsegment named after the method name + # and add its response as trace metadata + handler_trace, yield_function_trace = in_subsegment_mock.in_subsegment.call_args_list + + assert "test result" in in_subsegment_mock.put_metadata.call_args[1]["value"] + assert in_subsegment_mock.in_subsegment.call_count == 2 + assert handler_trace == mocker.call(name="## handler") + assert yield_function_trace == mocker.call(name="## yield_with_capture") + assert "test result" in result + + def test_tracer_yield_from_generator(mocker, provider_stub, in_subsegment_mock): # GIVEN tracer is initialized provider = provider_stub(in_subsegment=in_subsegment_mock.in_subsegment) @@ -411,3 +477,28 @@ def handler(event, context): assert handler_trace == mocker.call(name="## handler") assert generator_fn_trace == mocker.call(name="## generator_fn") assert "test result" in result + + +def test_tracer_yield_from_generator_exception_metadata(mocker, provider_stub, in_subsegment_mock): + # GIVEN tracer is initialized + provider = provider_stub(in_subsegment=in_subsegment_mock.in_subsegment) + tracer = Tracer(provider=provider, service="booking") + + # WHEN capture_method decorator is used on a generator function + # and the method raises an exception + @tracer.capture_method + def generator_fn(): + yield "partial" + raise ValueError("test") + + with pytest.raises(ValueError): + gen = generator_fn() + list(gen) + + # THEN we should add the exception using method name as key plus error + # and their service name as the namespace + put_metadata_mock_args = in_subsegment_mock.put_metadata.call_args[1] + assert put_metadata_mock_args["key"] == "generator_fn error" + assert put_metadata_mock_args["namespace"] == "booking" + assert isinstance(put_metadata_mock_args["value"], ValueError) + assert str(put_metadata_mock_args["value"]) == "test" From 77567c0ac180b4c6297fd8953d4ac44af4ee6049 Mon Sep 17 00:00:00 2001 From: Tom McCarthy Date: Sat, 22 Aug 2020 10:58:03 +0200 Subject: [PATCH 7/9] chore: version bump to 1.3.1 --- CHANGELOG.md | 4 ++++ pyproject.toml | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5981ce910a4..ce5fba748b7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [1.3.1] - 2020-08-22 +### Fixed +- **Tracer**: capture_method decorator did not properly handle nested context managers + ## [1.3.0] - 2020-08-21 ### Added - **Utilities**: Add new `parameters` utility to retrieve a single or multiple parameters from SSM Parameter Store, Secrets Manager, DynamoDB, or your very own diff --git a/pyproject.toml b/pyproject.toml index 75c74fb9bed..0cfd9c45bed 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "aws_lambda_powertools" -version = "1.3.0" +version = "1.3.1" description = "Python utilities for AWS Lambda functions including but not limited to tracing, logging and custom metric" authors = ["Amazon Web Services"] classifiers=[ From d2184bc2871b35629dc9cdd47db152af354cde3b Mon Sep 17 00:00:00 2001 From: Michael Brewer Date: Sat, 22 Aug 2020 18:35:53 -0700 Subject: [PATCH 8/9] refactor: reduce get_multiple complexity Changes: - base.py - update get_multiple to reduce the overall complexity - base.py - `_has_not_expired` returns whether a key exists and has not expired - base.py - `transform_value` add `raise_on_transform_error` and default to True - test_utilities_parameters.py - Add a direct test of transform_value --- .../utilities/parameters/base.py | 66 ++++++++++--------- tests/functional/test_utilities_parameters.py | 10 +++ 2 files changed, 44 insertions(+), 32 deletions(-) diff --git a/aws_lambda_powertools/utilities/parameters/base.py b/aws_lambda_powertools/utilities/parameters/base.py index 7e3f838528a..340a93e22e7 100644 --- a/aws_lambda_powertools/utilities/parameters/base.py +++ b/aws_lambda_powertools/utilities/parameters/base.py @@ -7,7 +7,7 @@ from abc import ABC, abstractmethod from collections import namedtuple from datetime import datetime, timedelta -from typing import Dict, Optional, Union +from typing import Dict, Optional, Tuple, Union from .exceptions import GetParameterError, TransformParameterError @@ -31,6 +31,9 @@ def __init__(self): self.store = {} + def _has_not_expired(self, key: Tuple[str, Optional[str]]) -> bool: + return key in self.store and self.store[key].ttl >= datetime.now() + def get( self, name: str, max_age: int = DEFAULT_MAX_AGE_SECS, transform: Optional[str] = None, **sdk_options ) -> Union[str, list, dict, bytes]: @@ -70,19 +73,21 @@ def get( # an acceptable tradeoff. key = (name, transform) - if key not in self.store or self.store[key].ttl < datetime.now(): - try: - value = self._get(name, **sdk_options) - # Encapsulate all errors into a generic GetParameterError - except Exception as exc: - raise GetParameterError(str(exc)) + if self._has_not_expired(key): + return self.store[key].value + + try: + value = self._get(name, **sdk_options) + # Encapsulate all errors into a generic GetParameterError + except Exception as exc: + raise GetParameterError(str(exc)) - if transform is not None: - value = transform_value(value, transform) + if transform is not None: + value = transform_value(value, transform) - self.store[key] = ExpirableValue(value, datetime.now() + timedelta(seconds=max_age),) + self.store[key] = ExpirableValue(value, datetime.now() + timedelta(seconds=max_age),) - return self.store[key].value + return value @abstractmethod def _get(self, name: str, **sdk_options) -> str: @@ -129,29 +134,21 @@ def get_multiple( key = (path, transform) - if key not in self.store or self.store[key].ttl < datetime.now(): - try: - values = self._get_multiple(path, **sdk_options) - # Encapsulate all errors into a generic GetParameterError - except Exception as exc: - raise GetParameterError(str(exc)) + if self._has_not_expired(key): + return self.store[key].value - if transform is not None: - new_values = {} - for key, value in values.items(): - try: - new_values[key] = transform_value(value, transform) - except Exception as exc: - if raise_on_transform_error: - raise exc - else: - new_values[key] = None + try: + values = self._get_multiple(path, **sdk_options) + # Encapsulate all errors into a generic GetParameterError + except Exception as exc: + raise GetParameterError(str(exc)) - values = new_values + if transform is not None: + values = {k: transform_value(v, transform, raise_on_transform_error) for (k, v) in values.items()} - self.store[key] = ExpirableValue(values, datetime.now() + timedelta(seconds=max_age),) + self.store[key] = ExpirableValue(values, datetime.now() + timedelta(seconds=max_age),) - return self.store[key].value + return values @abstractmethod def _get_multiple(self, path: str, **sdk_options) -> Dict[str, str]: @@ -161,7 +158,7 @@ def _get_multiple(self, path: str, **sdk_options) -> Dict[str, str]: raise NotImplementedError() -def transform_value(value: str, transform: str) -> Union[dict, bytes]: +def transform_value(value: str, transform: str, raise_on_transform_error: bool = True) -> Union[dict, bytes, None]: """ Apply a transform to a value @@ -171,6 +168,9 @@ def transform_value(value: str, transform: str) -> Union[dict, bytes]: Parameter value to transform transform: str Type of transform, supported values are "json" and "binary" + raise_on_transform_error: bool, optional + Raises an exception if any transform fails, otherwise this will + return a None value for each transform that failed Raises ------ @@ -187,4 +187,6 @@ def transform_value(value: str, transform: str) -> Union[dict, bytes]: raise ValueError(f"Invalid transform type '{transform}'") except Exception as exc: - raise TransformParameterError(str(exc)) + if raise_on_transform_error: + raise TransformParameterError(str(exc)) + return None diff --git a/tests/functional/test_utilities_parameters.py b/tests/functional/test_utilities_parameters.py index 8c4e20fac52..abd121540a6 100644 --- a/tests/functional/test_utilities_parameters.py +++ b/tests/functional/test_utilities_parameters.py @@ -1471,3 +1471,13 @@ def test_transform_value_wrong(mock_value): parameters.base.transform_value(mock_value, "INCORRECT") assert "Invalid transform type" in str(excinfo) + + +def test_transform_value_ignore_error(mock_value): + """ + Test transform_value() does not raise errors when raise_on_transform_error is False + """ + + value = parameters.base.transform_value(mock_value, "INCORRECT", raise_on_transform_error=False) + + assert value is None From fc8ac13b788de0b446c76b5970c5db99762f2185 Mon Sep 17 00:00:00 2001 From: Michael Brewer Date: Sun, 23 Aug 2020 09:36:02 -0700 Subject: [PATCH 9/9] refactor: revert to a regular for each Changes: * Add type hint to `values` as it can change later on in transform * Use a slightly faster and easier to read for each over dict comprehension --- aws_lambda_powertools/utilities/parameters/base.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/aws_lambda_powertools/utilities/parameters/base.py b/aws_lambda_powertools/utilities/parameters/base.py index 340a93e22e7..274cd96aace 100644 --- a/aws_lambda_powertools/utilities/parameters/base.py +++ b/aws_lambda_powertools/utilities/parameters/base.py @@ -138,13 +138,14 @@ def get_multiple( return self.store[key].value try: - values = self._get_multiple(path, **sdk_options) + values: Dict[str, Union[str, bytes, dict, None]] = self._get_multiple(path, **sdk_options) # Encapsulate all errors into a generic GetParameterError except Exception as exc: raise GetParameterError(str(exc)) if transform is not None: - values = {k: transform_value(v, transform, raise_on_transform_error) for (k, v) in values.items()} + for (key, value) in values.items(): + values[key] = transform_value(value, transform, raise_on_transform_error) self.store[key] = ExpirableValue(values, datetime.now() + timedelta(seconds=max_age),)