diff --git a/.chglog/CHANGELOG.tpl.md b/.chglog/CHANGELOG.tpl.md index bbbb9caef30..ed0c0f823f6 100755 --- a/.chglog/CHANGELOG.tpl.md +++ b/.chglog/CHANGELOG.tpl.md @@ -20,7 +20,7 @@ ### {{ .Title }} {{ range .Commits -}} -* {{ if .Scope }}**{{ upperFirst .Scope }}:** {{ end }}{{ .Subject }} +* {{ if .Scope }}**{{ .Scope }}:** {{ end }}{{ .Subject }} {{ end }} {{ end -}} diff --git a/.github/workflows/python_build.yml b/.github/workflows/python_build.yml index 0990d6d0152..48f8783a49e 100644 --- a/.github/workflows/python_build.yml +++ b/.github/workflows/python_build.yml @@ -37,7 +37,7 @@ jobs: - name: Complexity baseline run: make complexity-baseline - name: Upload coverage to Codecov - uses: codecov/codecov-action@v1 + uses: codecov/codecov-action@v2.0.1 with: file: ./coverage.xml # flags: unittests diff --git a/CHANGELOG.md b/CHANGELOG.md index aea741921de..32af8abca95 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,46 @@ This project follows [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) fo ## [Unreleased] +## [1.18.0] - 2021-07-20 + +### Bug Fixes + +* **api-gateway:** non-greedy route pattern regex which incorrectly mapped certain route params to function params ([#533](https://github.com/awslabs/aws-lambda-powertools-python/issues/533)) +* **api-gateway:** incorrect plain text mimetype constant [#506](https://github.com/awslabs/aws-lambda-powertools-python/issues/506) +* **data-classes:** include milliseconds in scalar types to correctly align with AppSync scalars ([#504](https://github.com/awslabs/aws-lambda-powertools-python/issues/504)) +* **mypy:** addresses lack of optional types ([#521](https://github.com/awslabs/aws-lambda-powertools-python/issues/521)) +* **parser:** make ApiGateway version, authorizer fields optional ([#532](https://github.com/awslabs/aws-lambda-powertools-python/issues/532)) +* **tracer:** mypy generic to preserve decorated method signature ([#529](https://github.com/awslabs/aws-lambda-powertools-python/issues/529)) + +### Code Refactoring + +* **feature-toggles:** code coverage and housekeeping ([#530](https://github.com/awslabs/aws-lambda-powertools-python/issues/530)) + +### Documentation + +* **api-gateway:** new HTTP service error exceptions ([#546](https://github.com/awslabs/aws-lambda-powertools-python/issues/546)) +* **logger:** new get_correlation_id method ([#545](https://github.com/awslabs/aws-lambda-powertools-python/issues/545)) + +### Features + +* **api-gateway:** add debug mode ([#507](https://github.com/awslabs/aws-lambda-powertools-python/issues/507)) +* **api-gateway:** add common HTTP service errors ([#506](https://github.com/awslabs/aws-lambda-powertools-python/issues/506)) +* **event-handler:** Support AppSyncResolverEvent subclassing ([#526](https://github.com/awslabs/aws-lambda-powertools-python/issues/526)) +* **feat-toggle:** New simple feature toggles rule engine (WIP) ([#494](https://github.com/awslabs/aws-lambda-powertools-python/issues/494)) +* **logger:** add get_correlation_id method ([#516](https://github.com/awslabs/aws-lambda-powertools-python/issues/516)) + +### Maintenance + +* **mypy:** add mypy support to makefile ([#508](https://github.com/awslabs/aws-lambda-powertools-python/issues/508)) +* **deps:** bump codecov/codecov-action from 1 to 2.0.1 ([#539](https://github.com/awslabs/aws-lambda-powertools-python/issues/539)) +* **deps:** bump boto3 from 1.18.0 to 1.18.1 ([#528](https://github.com/awslabs/aws-lambda-powertools-python/issues/528)) +* **deps:** bump boto3 from 1.17.110 to 1.18.0 ([#527](https://github.com/awslabs/aws-lambda-powertools-python/issues/527)) +* **deps:** bump boto3 from 1.17.102 to 1.17.110 ([#523](https://github.com/awslabs/aws-lambda-powertools-python/issues/523)) +* **deps-dev:** bump mkdocs-material from 7.1.10 to 7.1.11 ([#542](https://github.com/awslabs/aws-lambda-powertools-python/issues/542)) +* **deps-dev:** bump mkdocs-material from 7.1.9 to 7.1.10 ([#522](https://github.com/awslabs/aws-lambda-powertools-python/issues/522)) +* **deps-dev:** bump isort from 5.9.1 to 5.9.2 ([#514](https://github.com/awslabs/aws-lambda-powertools-python/issues/514)) +* **event-handler:** adjusts API Gateway/ALB service errors exception docstrings to not confuse AppSync customers + ## [1.17.1] - 2021-07-02 ### Bug Fixes diff --git a/Makefile b/Makefile index da43c1de67a..e098615b86c 100644 --- a/Makefile +++ b/Makefile @@ -83,3 +83,6 @@ release: pr changelog: @echo "[+] Pre-generating CHANGELOG for tag: $$(git describe --abbrev=0 --tag)" docker run -v "${PWD}":/workdir quay.io/git-chglog/git-chglog $$(git describe --abbrev=0 --tag).. > TMP_CHANGELOG.md + +mypy: + poetry run mypy --pretty aws_lambda_powertools diff --git a/aws_lambda_powertools/event_handler/__init__.py b/aws_lambda_powertools/event_handler/__init__.py index 0475982e377..def92f706f9 100644 --- a/aws_lambda_powertools/event_handler/__init__.py +++ b/aws_lambda_powertools/event_handler/__init__.py @@ -2,6 +2,7 @@ Event handler decorators for common Lambda events """ +from .api_gateway import ApiGatewayResolver from .appsync import AppSyncResolver -__all__ = ["AppSyncResolver"] +__all__ = ["AppSyncResolver", "ApiGatewayResolver"] diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 2b1e1fc0900..6948646a360 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -1,11 +1,18 @@ import base64 import json import logging +import os import re +import traceback import zlib from enum import Enum +from http import HTTPStatus from typing import Any, Callable, Dict, List, Optional, Set, Union +from aws_lambda_powertools.event_handler import content_types +from aws_lambda_powertools.event_handler.exceptions import ServiceError +from aws_lambda_powertools.shared import constants +from aws_lambda_powertools.shared.functions import resolve_truthy_env_var_choice from aws_lambda_powertools.shared.json_encoder import Encoder from aws_lambda_powertools.utilities.data_classes import ALBEvent, APIGatewayProxyEvent, APIGatewayProxyEventV2 from aws_lambda_powertools.utilities.data_classes.common import BaseProxyEvent @@ -13,6 +20,9 @@ logger = logging.getLogger(__name__) +_DYNAMIC_ROUTE_PATTERN = r"(<\w+>)" +_NAMED_GROUP_BOUNDARY_PATTERN = r"(?P\1\\w+\\b)" + class ProxyEventType(Enum): """An enumerations of the supported proxy event types.""" @@ -25,43 +35,46 @@ class ProxyEventType(Enum): class CORSConfig(object): """CORS Config - Examples -------- Simple cors example using the default permissive cors, not this should only be used during early prototyping - from aws_lambda_powertools.event_handler.api_gateway import ApiGatewayResolver + ```python + from aws_lambda_powertools.event_handler.api_gateway import ApiGatewayResolver - app = ApiGatewayResolver() + app = ApiGatewayResolver() - @app.get("/my/path", cors=True) - def with_cors(): - return {"message": "Foo"} + @app.get("/my/path", cors=True) + def with_cors(): + return {"message": "Foo"} + ``` Using a custom CORSConfig where `with_cors` used the custom provided CORSConfig and `without_cors` do not include any cors headers. - from aws_lambda_powertools.event_handler.api_gateway import ( - ApiGatewayResolver, CORSConfig - ) - - cors_config = CORSConfig( - allow_origin="https://wwww.example.com/", - expose_headers=["x-exposed-response-header"], - allow_headers=["x-custom-request-header"], - max_age=100, - allow_credentials=True, - ) - app = ApiGatewayResolver(cors=cors_config) - - @app.get("/my/path") - def with_cors(): - return {"message": "Foo"} + ```python + from aws_lambda_powertools.event_handler.api_gateway import ( + ApiGatewayResolver, CORSConfig + ) + + cors_config = CORSConfig( + allow_origin="https://wwww.example.com/", + expose_headers=["x-exposed-response-header"], + allow_headers=["x-custom-request-header"], + max_age=100, + allow_credentials=True, + ) + app = ApiGatewayResolver(cors=cors_config) + + @app.get("/my/path") + def with_cors(): + return {"message": "Foo"} - @app.get("/another-one", cors=False) - def without_cors(): - return {"message": "Foo"} + @app.get("/another-one", cors=False) + def without_cors(): + return {"message": "Foo"} + ``` """ _REQUIRED_HEADERS = ["Authorization", "Content-Type", "X-Amz-Date", "X-Api-Key", "X-Amz-Security-Token"] @@ -116,7 +129,11 @@ class Response: """Response data class that provides greater control over what is returned from the proxy event""" def __init__( - self, status_code: int, content_type: Optional[str], body: Union[str, bytes, None], headers: Dict = None + self, + status_code: int, + content_type: Optional[str], + body: Union[str, bytes, None], + headers: Optional[Dict] = None, ): """ @@ -157,7 +174,7 @@ def __init__( class ResponseBuilder: """Internally used Response builder""" - def __init__(self, response: Response, route: Route = None): + def __init__(self, response: Response, route: Optional[Route] = None): self.response = response self.route = route @@ -189,7 +206,7 @@ def _route(self, event: BaseProxyEvent, cors: Optional[CORSConfig]): if self.route.compress and "gzip" in (event.get_header_value("accept-encoding", "") or ""): self._compress() - def build(self, event: BaseProxyEvent, cors: CORSConfig = None) -> Dict[str, Any]: + def build(self, event: BaseProxyEvent, cors: Optional[CORSConfig] = None) -> Dict[str, Any]: """Build the full response dict to be returned by the lambda""" self._route(event, cors) @@ -237,7 +254,12 @@ def lambda_handler(event, context): current_event: BaseProxyEvent lambda_context: LambdaContext - def __init__(self, proxy_type: Enum = ProxyEventType.APIGatewayProxyEvent, cors: CORSConfig = None): + def __init__( + self, + proxy_type: Enum = ProxyEventType.APIGatewayProxyEvent, + cors: Optional[CORSConfig] = None, + debug: Optional[bool] = None, + ): """ Parameters ---------- @@ -245,14 +267,20 @@ def __init__(self, proxy_type: Enum = ProxyEventType.APIGatewayProxyEvent, cors: Proxy request type, defaults to API Gateway V1 cors: CORSConfig Optionally configure and enabled CORS. Not each route will need to have to cors=True + debug: Optional[bool] + Enables debug mode, by default False. Can be also be enabled by "POWERTOOLS_EVENT_HANDLER_DEBUG" + environment variable """ self._proxy_type = proxy_type self._routes: List[Route] = [] self._cors = cors self._cors_enabled: bool = cors is not None self._cors_methods: Set[str] = {"OPTIONS"} + self._debug = resolve_truthy_env_var_choice( + env=os.getenv(constants.EVENT_HANDLER_DEBUG_ENV, "false"), choice=debug + ) - def get(self, rule: str, cors: bool = None, compress: bool = False, cache_control: str = None): + def get(self, rule: str, cors: Optional[bool] = None, compress: bool = False, cache_control: Optional[str] = None): """Get route decorator with GET `method` Examples @@ -277,7 +305,7 @@ def lambda_handler(event, context): """ return self.route(rule, "GET", cors, compress, cache_control) - def post(self, rule: str, cors: bool = None, compress: bool = False, cache_control: str = None): + def post(self, rule: str, cors: Optional[bool] = None, compress: bool = False, cache_control: Optional[str] = None): """Post route decorator with POST `method` Examples @@ -303,7 +331,7 @@ def lambda_handler(event, context): """ return self.route(rule, "POST", cors, compress, cache_control) - def put(self, rule: str, cors: bool = None, compress: bool = False, cache_control: str = None): + def put(self, rule: str, cors: Optional[bool] = None, compress: bool = False, cache_control: Optional[str] = None): """Put route decorator with PUT `method` Examples @@ -329,7 +357,9 @@ def lambda_handler(event, context): """ return self.route(rule, "PUT", cors, compress, cache_control) - def delete(self, rule: str, cors: bool = None, compress: bool = False, cache_control: str = None): + def delete( + self, rule: str, cors: Optional[bool] = None, compress: bool = False, cache_control: Optional[str] = None + ): """Delete route decorator with DELETE `method` Examples @@ -354,7 +384,9 @@ def lambda_handler(event, context): """ return self.route(rule, "DELETE", cors, compress, cache_control) - def patch(self, rule: str, cors: bool = None, compress: bool = False, cache_control: str = None): + def patch( + self, rule: str, cors: Optional[bool] = None, compress: bool = False, cache_control: Optional[str] = None + ): """Patch route decorator with PATCH `method` Examples @@ -382,7 +414,14 @@ def lambda_handler(event, context): """ return self.route(rule, "PATCH", cors, compress, cache_control) - def route(self, rule: str, method: str, cors: bool = None, compress: bool = False, cache_control: str = None): + def route( + self, + rule: str, + method: str, + cors: Optional[bool] = None, + compress: bool = False, + cache_control: Optional[str] = None, + ): """Route decorator includes parameter `method`""" def register_resolver(func: Callable): @@ -413,6 +452,8 @@ def resolve(self, event, context) -> Dict[str, Any]: dict Returns the dict response """ + if self._debug: + print(self._json_dump(event)) self.current_event = self._to_proxy_event(event) self.lambda_context = context return self._resolve().build(self.current_event, self._cors) @@ -422,8 +463,35 @@ def __call__(self, event, context) -> Any: @staticmethod def _compile_regex(rule: str): - """Precompile regex pattern""" - rule_regex: str = re.sub(r"(<\w+>)", r"(?P\1.+)", rule) + """Precompile regex pattern + + Logic + ----- + + 1. Find any dynamic routes defined as + e.g. @app.get("/accounts/") + 2. Create a new regex by substituting every dynamic route found as a named group (?P), + and match whole words only (word boundary) instead of a greedy match + + non-greedy example with word boundary + + rule: '/accounts/' + regex: r'/accounts/(?P\\w+\\b)' + + value: /accounts/123/some_other_path + account_id: 123 + + greedy example without word boundary + + regex: r'/accounts/(?P.+)' + + value: /accounts/123/some_other_path + account_id: 123/some_other_path + 3. Compiles a regex and include start (^) and end ($) in between for an exact match + + NOTE: See #520 for context + """ + rule_regex: str = re.sub(_DYNAMIC_ROUTE_PATTERN, _NAMED_GROUP_BOUNDARY_PATTERN, rule) return re.compile("^{}$".format(rule_regex)) def _to_proxy_event(self, event: Dict) -> BaseProxyEvent: @@ -447,7 +515,7 @@ def _resolve(self) -> ResponseBuilder: match: Optional[re.Match] = route.rule.match(path) if match: logger.debug("Found a registered route. Calling function") - return self._call_route(route, match.groupdict()) + return self._call_route(route, match.groupdict()) # pass fn args logger.debug(f"No match found for path {path} and method {method}") return self._not_found(method) @@ -466,19 +534,41 @@ def _not_found(self, method: str) -> ResponseBuilder: return ResponseBuilder( Response( - status_code=404, - content_type="application/json", + status_code=HTTPStatus.NOT_FOUND.value, + content_type=content_types.APPLICATION_JSON, headers=headers, - body=json.dumps({"message": "Not found"}), + body=self._json_dump({"statusCode": HTTPStatus.NOT_FOUND.value, "message": "Not found"}), ) ) def _call_route(self, route: Route, args: Dict[str, str]) -> ResponseBuilder: """Actually call the matching route with any provided keyword arguments.""" - return ResponseBuilder(self._to_response(route.func(**args)), route) - - @staticmethod - def _to_response(result: Union[Dict, Response]) -> Response: + try: + return ResponseBuilder(self._to_response(route.func(**args)), route) + except ServiceError as e: + return ResponseBuilder( + Response( + status_code=e.status_code, + content_type=content_types.APPLICATION_JSON, + body=self._json_dump({"statusCode": e.status_code, "message": e.msg}), + ), + route, + ) + except Exception: + if self._debug: + # If the user has turned on debug mode, + # we'll let the original exception propagate so + # they get more information about what went wrong. + return ResponseBuilder( + Response( + status_code=500, + content_type=content_types.TEXT_PLAIN, + body="".join(traceback.format_exc()), + ) + ) + raise + + def _to_response(self, result: Union[Dict, Response]) -> Response: """Convert the route's result to a Response 2 main result types are supported: @@ -493,6 +583,13 @@ def _to_response(result: Union[Dict, Response]) -> Response: logger.debug("Simple response detected, serializing return before constructing final response") return Response( status_code=200, - content_type="application/json", - body=json.dumps(result, separators=(",", ":"), cls=Encoder), + content_type=content_types.APPLICATION_JSON, + body=self._json_dump(result), ) + + def _json_dump(self, obj: Any) -> str: + """Does a concise json serialization or pretty print when in debug mode""" + if self._debug: + return json.dumps(obj, indent=4, cls=Encoder) + else: + return json.dumps(obj, separators=(",", ":"), cls=Encoder) diff --git a/aws_lambda_powertools/event_handler/appsync.py b/aws_lambda_powertools/event_handler/appsync.py index 021afaa6654..69b90c4cbb6 100644 --- a/aws_lambda_powertools/event_handler/appsync.py +++ b/aws_lambda_powertools/event_handler/appsync.py @@ -1,11 +1,13 @@ import logging -from typing import Any, Callable +from typing import Any, Callable, Optional, Type, TypeVar from aws_lambda_powertools.utilities.data_classes import AppSyncResolverEvent from aws_lambda_powertools.utilities.typing import LambdaContext logger = logging.getLogger(__name__) +AppSyncResolverEventT = TypeVar("AppSyncResolverEventT", bound=AppSyncResolverEvent) + class AppSyncResolver: """ @@ -38,13 +40,13 @@ def common_field() -> str: return str(uuid.uuid4()) """ - current_event: AppSyncResolverEvent + current_event: AppSyncResolverEventT # type: ignore[valid-type] lambda_context: LambdaContext def __init__(self): self._resolvers: dict = {} - def resolver(self, type_name: str = "*", field_name: str = None): + def resolver(self, type_name: str = "*", field_name: Optional[str] = None): """Registers the resolver for field_name Parameters @@ -62,7 +64,9 @@ def register_resolver(func): return register_resolver - def resolve(self, event: dict, context: LambdaContext) -> Any: + def resolve( + self, event: dict, context: LambdaContext, data_model: Type[AppSyncResolverEvent] = AppSyncResolverEvent + ) -> Any: """Resolve field_name Parameters @@ -71,6 +75,56 @@ def resolve(self, event: dict, context: LambdaContext) -> Any: Lambda event context : LambdaContext Lambda context + data_model: + Your data data_model to decode AppSync event, by default AppSyncResolverEvent + + Example + ------- + + ```python + from aws_lambda_powertools.event_handler import AppSyncResolver + from aws_lambda_powertools.utilities.typing import LambdaContext + + @app.resolver(field_name="createSomething") + def create_something(id: str): # noqa AA03 VNE003 + return id + + def handler(event, context: LambdaContext): + return app.resolve(event, context) + ``` + + **Bringing custom models** + + ```python + from aws_lambda_powertools import Logger, Tracer + + from aws_lambda_powertools.logging import correlation_paths + from aws_lambda_powertools.event_handler import AppSyncResolver + + tracer = Tracer(service="sample_resolver") + logger = Logger(service="sample_resolver") + app = AppSyncResolver() + + + class MyCustomModel(AppSyncResolverEvent): + @property + def country_viewer(self) -> str: + return self.request_headers.get("cloudfront-viewer-country") + + + @app.resolver(field_name="listLocations") + @app.resolver(field_name="locations") + def get_locations(name: str, description: str = ""): + if app.current_event.country_viewer == "US": + ... + return name + description + + + @logger.inject_lambda_context(correlation_id_path=correlation_paths.APPSYNC_RESOLVER) + @tracer.capture_lambda_handler + def lambda_handler(event, context): + return app.resolve(event, context, data_model=MyCustomModel) + ``` Returns ------- @@ -82,7 +136,7 @@ def resolve(self, event: dict, context: LambdaContext) -> Any: ValueError If we could not find a field resolver """ - self.current_event = AppSyncResolverEvent(event) + self.current_event = data_model(event) self.lambda_context = context resolver = self._get_resolver(self.current_event.type_name, self.current_event.field_name) return resolver(**self.current_event.arguments) @@ -108,6 +162,8 @@ def _get_resolver(self, type_name: str, field_name: str) -> Callable: raise ValueError(f"No resolver found for '{full_name}'") return resolver["func"] - def __call__(self, event, context) -> Any: + def __call__( + self, event: dict, context: LambdaContext, data_model: Type[AppSyncResolverEvent] = AppSyncResolverEvent + ) -> Any: """Implicit lambda handler which internally calls `resolve`""" - return self.resolve(event, context) + return self.resolve(event, context, data_model) diff --git a/aws_lambda_powertools/event_handler/content_types.py b/aws_lambda_powertools/event_handler/content_types.py new file mode 100644 index 00000000000..0f55b1088ad --- /dev/null +++ b/aws_lambda_powertools/event_handler/content_types.py @@ -0,0 +1,5 @@ +# use mimetypes library to be certain, e.g., mimetypes.types_map[".json"] + +APPLICATION_JSON = "application/json" +TEXT_PLAIN = "text/plain" +TEXT_HTML = "text/html" diff --git a/aws_lambda_powertools/event_handler/exceptions.py b/aws_lambda_powertools/event_handler/exceptions.py new file mode 100644 index 00000000000..4a2838275b1 --- /dev/null +++ b/aws_lambda_powertools/event_handler/exceptions.py @@ -0,0 +1,45 @@ +from http import HTTPStatus + + +class ServiceError(Exception): + """API Gateway and ALB HTTP Service Error""" + + def __init__(self, status_code: int, msg: str): + """ + Parameters + ---------- + status_code: int + Http status code + msg: str + Error message + """ + self.status_code = status_code + self.msg = msg + + +class BadRequestError(ServiceError): + """API Gateway and ALB Bad Request Error (400)""" + + def __init__(self, msg: str): + super().__init__(HTTPStatus.BAD_REQUEST, msg) + + +class UnauthorizedError(ServiceError): + """API Gateway and ALB Unauthorized Error (401)""" + + def __init__(self, msg: str): + super().__init__(HTTPStatus.UNAUTHORIZED, msg) + + +class NotFoundError(ServiceError): + """API Gateway and ALB Not Found Error (404)""" + + def __init__(self, msg: str = "Not found"): + super().__init__(HTTPStatus.NOT_FOUND, msg) + + +class InternalServerError(ServiceError): + """API Gateway and ALB Not Found Internal Server Error (500)""" + + def __init__(self, message: str): + super().__init__(HTTPStatus.INTERNAL_SERVER_ERROR, message) diff --git a/aws_lambda_powertools/logging/formatter.py b/aws_lambda_powertools/logging/formatter.py index 7ff9881062a..de9254a3371 100644 --- a/aws_lambda_powertools/logging/formatter.py +++ b/aws_lambda_powertools/logging/formatter.py @@ -60,8 +60,8 @@ def __init__( json_serializer: Optional[Callable[[Dict], str]] = None, json_deserializer: Optional[Callable[[Dict], str]] = None, json_default: Optional[Callable[[Any], Any]] = None, - datefmt: str = None, - log_record_order: List[str] = None, + datefmt: Optional[str] = None, + log_record_order: Optional[List[str]] = None, utc: bool = False, **kwargs ): diff --git a/aws_lambda_powertools/logging/logger.py b/aws_lambda_powertools/logging/logger.py index 689409d9813..35054f86137 100644 --- a/aws_lambda_powertools/logging/logger.py +++ b/aws_lambda_powertools/logging/logger.py @@ -4,7 +4,7 @@ import os import random import sys -from typing import Any, Callable, Dict, Iterable, Optional, TypeVar, Union +from typing import IO, Any, Callable, Dict, Iterable, Optional, TypeVar, Union import jmespath @@ -167,11 +167,11 @@ class Logger(logging.Logger): # lgtm [py/missing-call-to-init] def __init__( self, - service: str = None, - level: Union[str, int] = None, + service: Optional[str] = None, + level: Union[str, int, None] = None, child: bool = False, - sampling_rate: float = None, - stream: sys.stdout = None, + sampling_rate: Optional[float] = None, + stream: Optional[IO[str]] = None, logger_formatter: Optional[PowertoolsFormatter] = None, logger_handler: Optional[logging.Handler] = None, **kwargs, @@ -261,10 +261,10 @@ def _configure_sampling(self): def inject_lambda_context( self, - lambda_handler: Callable[[Dict, Any], Any] = None, - log_event: bool = None, - correlation_id_path: str = None, - clear_state: bool = False, + lambda_handler: Optional[Callable[[Dict, Any], Any]] = None, + log_event: Optional[bool] = None, + correlation_id_path: Optional[str] = None, + clear_state: Optional[bool] = False, ): """Decorator to capture Lambda contextual info and inject into logger @@ -324,7 +324,7 @@ def handler(event, context): ) log_event = resolve_truthy_env_var_choice( - choice=log_event, env=os.getenv(constants.LOGGER_LOG_EVENT_ENV, "false") + env=os.getenv(constants.LOGGER_LOG_EVENT_ENV, "false"), choice=log_event ) @functools.wraps(lambda_handler) @@ -363,7 +363,7 @@ def registered_handler(self) -> logging.Handler: @property def registered_formatter(self) -> Optional[PowertoolsFormatter]: """Convenience property to access logger formatter""" - return self.registered_handler.formatter + return self.registered_handler.formatter # type: ignore def structure_logs(self, append: bool = False, **keys): """Sets logging formatting to JSON. @@ -384,19 +384,29 @@ def structure_logs(self, append: bool = False, **keys): self.append_keys(**keys) else: log_keys = {**self._default_log_keys, **keys} - formatter = self.logger_formatter or LambdaPowertoolsFormatter(**log_keys) + formatter = self.logger_formatter or LambdaPowertoolsFormatter(**log_keys) # type: ignore self.registered_handler.setFormatter(formatter) - def set_correlation_id(self, value: str): + def set_correlation_id(self, value: Optional[str]): """Sets the correlation_id in the logging json Parameters ---------- - value : str - Value for the correlation id + value : str, optional + Value for the correlation id. None will remove the correlation_id """ self.append_keys(correlation_id=value) + def get_correlation_id(self) -> Optional[str]: + """Gets the correlation_id in the logging json + + Returns + ------- + str, optional + Value for the correlation id + """ + return self.registered_formatter.log_format.get("correlation_id") + @staticmethod def _get_log_level(level: Union[str, int, None]) -> Union[str, int]: """Returns preferred log level set by the customer in upper case""" @@ -421,7 +431,9 @@ def _get_caller_filename(): def set_package_logger( - level: Union[str, int] = logging.DEBUG, stream: sys.stdout = None, formatter: logging.Formatter = None + level: Union[str, int] = logging.DEBUG, + stream: Optional[IO[str]] = None, + formatter: Optional[logging.Formatter] = None, ): """Set an additional stream handler, formatter, and log level for aws_lambda_powertools package logger. diff --git a/aws_lambda_powertools/metrics/base.py b/aws_lambda_powertools/metrics/base.py index dc4fe34ee12..853f06f210b 100644 --- a/aws_lambda_powertools/metrics/base.py +++ b/aws_lambda_powertools/metrics/base.py @@ -5,7 +5,7 @@ import os from collections import defaultdict from enum import Enum -from typing import Any, Dict, List, Union +from typing import Any, Dict, List, Optional, Union from ..shared import constants from ..shared.functions import resolve_env_var_choice @@ -76,11 +76,11 @@ class MetricManager: def __init__( self, - metric_set: Dict[str, Any] = None, - dimension_set: Dict = None, - namespace: str = None, - metadata_set: Dict[str, Any] = None, - service: str = None, + metric_set: Optional[Dict[str, Any]] = None, + dimension_set: Optional[Dict] = None, + namespace: Optional[str] = None, + metadata_set: Optional[Dict[str, Any]] = None, + service: Optional[str] = None, ): self.metric_set = metric_set if metric_set is not None else {} self.dimension_set = dimension_set if dimension_set is not None else {} @@ -136,7 +136,9 @@ def add_metric(self, name: str, unit: Union[MetricUnit, str], value: float): # since we could have more than 100 metrics self.metric_set.clear() - def serialize_metric_set(self, metrics: Dict = None, dimensions: Dict = None, metadata: Dict = None) -> Dict: + def serialize_metric_set( + self, metrics: Optional[Dict] = None, dimensions: Optional[Dict] = None, metadata: Optional[Dict] = None + ) -> Dict: """Serializes metric and dimensions set Parameters diff --git a/aws_lambda_powertools/metrics/metric.py b/aws_lambda_powertools/metrics/metric.py index 8bdd0d800b8..1ac2bd9450e 100644 --- a/aws_lambda_powertools/metrics/metric.py +++ b/aws_lambda_powertools/metrics/metric.py @@ -61,7 +61,7 @@ def add_metric(self, name: str, unit: Union[MetricUnit, str], value: float): @contextmanager -def single_metric(name: str, unit: MetricUnit, value: float, namespace: str = None): +def single_metric(name: str, unit: MetricUnit, value: float, namespace: Optional[str] = None): """Context manager to simplify creation of a single metric Example diff --git a/aws_lambda_powertools/metrics/metrics.py b/aws_lambda_powertools/metrics/metrics.py index 8cc4895f03e..fafc604b505 100644 --- a/aws_lambda_powertools/metrics/metrics.py +++ b/aws_lambda_powertools/metrics/metrics.py @@ -71,7 +71,7 @@ def lambda_handler(): _metadata: Dict[str, Any] = {} _default_dimensions: Dict[str, Any] = {} - def __init__(self, service: str = None, namespace: str = None): + def __init__(self, service: Optional[str] = None, namespace: Optional[str] = None): self.metric_set = self._metrics self.service = service self.namespace: Optional[str] = namespace @@ -125,10 +125,10 @@ def clear_metrics(self): def log_metrics( self, - lambda_handler: Callable[[Any, Any], Any] = None, + lambda_handler: Optional[Callable[[Any, Any], Any]] = None, capture_cold_start_metric: bool = False, raise_on_empty_metrics: bool = False, - default_dimensions: Dict[str, str] = None, + default_dimensions: Optional[Dict[str, str]] = None, ): """Decorator to serialize and publish metrics at the end of a function execution. diff --git a/aws_lambda_powertools/middleware_factory/factory.py b/aws_lambda_powertools/middleware_factory/factory.py index 77277052272..74858bf6709 100644 --- a/aws_lambda_powertools/middleware_factory/factory.py +++ b/aws_lambda_powertools/middleware_factory/factory.py @@ -2,7 +2,7 @@ import inspect import logging import os -from typing import Callable +from typing import Callable, Optional from ..shared import constants from ..shared.functions import resolve_truthy_env_var_choice @@ -12,7 +12,7 @@ logger = logging.getLogger(__name__) -def lambda_handler_decorator(decorator: Callable = None, trace_execution: bool = None): +def lambda_handler_decorator(decorator: Optional[Callable] = None, trace_execution: Optional[bool] = None): """Decorator factory for decorating Lambda handlers. You can use lambda_handler_decorator to create your own middlewares, @@ -106,11 +106,11 @@ def lambda_handler(event, context): return functools.partial(lambda_handler_decorator, trace_execution=trace_execution) trace_execution = resolve_truthy_env_var_choice( - choice=trace_execution, env=os.getenv(constants.MIDDLEWARE_FACTORY_TRACE_ENV, "false") + env=os.getenv(constants.MIDDLEWARE_FACTORY_TRACE_ENV, "false"), choice=trace_execution ) @functools.wraps(decorator) - def final_decorator(func: Callable = None, **kwargs): + def final_decorator(func: Optional[Callable] = None, **kwargs): # If called with kwargs return new func with kwargs if func is None: return functools.partial(final_decorator, **kwargs) diff --git a/aws_lambda_powertools/shared/constants.py b/aws_lambda_powertools/shared/constants.py index eaad5640dfd..8388eded654 100644 --- a/aws_lambda_powertools/shared/constants.py +++ b/aws_lambda_powertools/shared/constants.py @@ -10,11 +10,12 @@ METRICS_NAMESPACE_ENV: str = "POWERTOOLS_METRICS_NAMESPACE" +EVENT_HANDLER_DEBUG_ENV: str = "POWERTOOLS_EVENT_HANDLER_DEBUG" + SAM_LOCAL_ENV: str = "AWS_SAM_LOCAL" CHALICE_LOCAL_ENV: str = "AWS_CHALICE_CLI_MODE" SERVICE_NAME_ENV: str = "POWERTOOLS_SERVICE_NAME" XRAY_TRACE_ID_ENV: str = "_X_AMZN_TRACE_ID" - -XRAY_SDK_MODULE = "aws_xray_sdk" -XRAY_SDK_CORE_MODULE = "aws_xray_sdk.core" +XRAY_SDK_MODULE: str = "aws_xray_sdk" +XRAY_SDK_CORE_MODULE: str = "aws_xray_sdk.core" diff --git a/aws_lambda_powertools/shared/functions.py b/aws_lambda_powertools/shared/functions.py index b8f5cb9f74b..0b117cc32bb 100644 --- a/aws_lambda_powertools/shared/functions.py +++ b/aws_lambda_powertools/shared/functions.py @@ -2,14 +2,14 @@ from typing import Any, Optional, Union -def resolve_truthy_env_var_choice(env: Any, choice: bool = None) -> bool: +def resolve_truthy_env_var_choice(env: str, choice: Optional[bool] = None) -> bool: """Pick explicit choice over truthy env value, if available, otherwise return truthy env value NOTE: Environment variable should be resolved by the caller. Parameters ---------- - env : Any + env : str environment variable actual value choice : bool explicit choice diff --git a/aws_lambda_powertools/tracing/base.py b/aws_lambda_powertools/tracing/base.py index 1857ed52a73..722652ce08b 100644 --- a/aws_lambda_powertools/tracing/base.py +++ b/aws_lambda_powertools/tracing/base.py @@ -2,11 +2,11 @@ import numbers import traceback from contextlib import contextmanager -from typing import Any, AsyncContextManager, ContextManager, List, NoReturn, Set, Union +from typing import Any, AsyncContextManager, ContextManager, List, NoReturn, Optional, Set, Union class BaseProvider(abc.ABC): - @abc.abstractmethod + @abc.abstractmethod # type: ignore @contextmanager def in_subsegment(self, name=None, **kwargs) -> ContextManager: """Return a subsegment context manger. @@ -19,7 +19,7 @@ def in_subsegment(self, name=None, **kwargs) -> ContextManager: Optional parameters to be propagated to segment """ - @abc.abstractmethod + @abc.abstractmethod # type: ignore @contextmanager def in_subsegment_async(self, name=None, **kwargs) -> AsyncContextManager: """Return a subsegment async context manger. @@ -81,7 +81,7 @@ class BaseSegment(abc.ABC): """Holds common properties and methods on segment and subsegment.""" @abc.abstractmethod - def close(self, end_time: int = None): + def close(self, end_time: Optional[int] = None): """Close the trace entity by setting `end_time` and flip the in progress flag to False. diff --git a/aws_lambda_powertools/tracing/tracer.py b/aws_lambda_powertools/tracing/tracer.py index 47568802202..5709b1956c2 100644 --- a/aws_lambda_powertools/tracing/tracer.py +++ b/aws_lambda_powertools/tracing/tracer.py @@ -5,7 +5,7 @@ import logging import numbers import os -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Awaitable, Callable, Dict, Optional, Sequence, TypeVar, Union, cast, overload from ..shared import constants from ..shared.functions import resolve_env_var_choice, resolve_truthy_env_var_choice @@ -18,6 +18,9 @@ aws_xray_sdk = LazyLoader(constants.XRAY_SDK_MODULE, globals(), constants.XRAY_SDK_MODULE) aws_xray_sdk.core = LazyLoader(constants.XRAY_SDK_CORE_MODULE, globals(), constants.XRAY_SDK_CORE_MODULE) +AnyCallableT = TypeVar("AnyCallableT", bound=Callable[..., Any]) # noqa: VNE001 +AnyAwaitableT = TypeVar("AnyAwaitableT", bound=Awaitable) + class Tracer: """Tracer using AWS-XRay to provide decorators with known defaults for Lambda functions @@ -53,7 +56,7 @@ class Tracer: disabled: bool Flag to explicitly disable tracing, useful when running/testing locally `Env POWERTOOLS_TRACE_DISABLED="true"` - patch_modules: Tuple[str] + patch_modules: Optional[Sequence[str]] Tuple of modules supported by tracing provider to patch, by default all modules are patched provider: BaseProvider Tracing provider, by default it is aws_xray_sdk.core.xray_recorder @@ -146,11 +149,11 @@ def handler(event: dict, context: Any) -> Dict: def __init__( self, - service: str = None, - disabled: bool = None, - auto_patch: bool = None, - patch_modules: Optional[Tuple[str]] = None, - provider: BaseProvider = None, + service: Optional[str] = None, + disabled: Optional[bool] = None, + auto_patch: Optional[bool] = None, + patch_modules: Optional[Sequence[str]] = None, + provider: Optional[BaseProvider] = None, ): self.__build_config( service=service, disabled=disabled, auto_patch=auto_patch, patch_modules=patch_modules, provider=provider @@ -195,7 +198,7 @@ def put_annotation(self, key: str, value: Union[str, numbers.Number, bool]): logger.debug(f"Annotating on key '{key}' with '{value}'") self.provider.put_annotation(key=key, value=value) - def put_metadata(self, key: str, value: Any, namespace: str = None): + def put_metadata(self, key: str, value: Any, namespace: Optional[str] = None): """Adds metadata to existing segment or subsegment Parameters @@ -223,14 +226,14 @@ def put_metadata(self, key: str, value: Any, namespace: str = None): logger.debug(f"Adding metadata on key '{key}' with '{value}' at namespace '{namespace}'") self.provider.put_metadata(key=key, value=value, namespace=namespace) - def patch(self, modules: Tuple[str] = None): + def patch(self, modules: Optional[Sequence[str]] = None): """Patch modules for instrumentation. Patches all supported modules by default if none are given. Parameters ---------- - modules : Tuple[str] + modules : Optional[Sequence[str]] List of modules to be patched, optional by default """ if self.disabled: @@ -244,7 +247,7 @@ def patch(self, modules: Tuple[str] = None): def capture_lambda_handler( self, - lambda_handler: Union[Callable[[Dict, Any], Any], Callable[[Dict, Any, Optional[Dict]], Any]] = None, + lambda_handler: Union[Callable[[Dict, Any], Any], Optional[Callable[[Dict, Any, Optional[Dict]], Any]]] = None, capture_response: Optional[bool] = None, capture_error: Optional[bool] = None, ): @@ -329,9 +332,26 @@ def decorate(event, context, **kwargs): return decorate + # see #465 + @overload + def capture_method(self, method: "AnyCallableT") -> "AnyCallableT": + ... + + @overload def capture_method( - self, method: Callable = None, capture_response: Optional[bool] = None, capture_error: Optional[bool] = None - ): + self, + method: None = None, + capture_response: Optional[bool] = None, + capture_error: Optional[bool] = None, + ) -> Callable[["AnyCallableT"], "AnyCallableT"]: + ... + + def capture_method( + self, + method: Optional[AnyCallableT] = None, + capture_response: Optional[bool] = None, + capture_error: Optional[bool] = None, + ) -> AnyCallableT: """Decorator to create subsegment for arbitrary functions It also captures both response and exceptions as metadata @@ -484,8 +504,9 @@ async def async_tasks(): # Return a partial function with args filled if method is None: logger.debug("Decorator called with parameters") - return functools.partial( - self.capture_method, capture_response=capture_response, capture_error=capture_error + return cast( + AnyCallableT, + functools.partial(self.capture_method, capture_response=capture_response, capture_error=capture_error), ) method_name = f"{method.__name__}" @@ -506,7 +527,7 @@ async def async_tasks(): return self._decorate_generator_function( method=method, capture_response=capture_response, capture_error=capture_error, method_name=method_name ) - elif hasattr(method, "__wrapped__") and inspect.isgeneratorfunction(method.__wrapped__): + elif hasattr(method, "__wrapped__") and inspect.isgeneratorfunction(method.__wrapped__): # type: ignore return self._decorate_generator_function_with_context_manager( method=method, capture_response=capture_response, capture_error=capture_error, method_name=method_name ) @@ -520,7 +541,7 @@ def _decorate_async_function( method: Callable, capture_response: Optional[Union[bool, str]] = None, capture_error: Optional[Union[bool, str]] = None, - method_name: str = None, + method_name: Optional[str] = None, ): @functools.wraps(method) async def decorate(*args, **kwargs): @@ -547,7 +568,7 @@ def _decorate_generator_function( method: Callable, capture_response: Optional[Union[bool, str]] = None, capture_error: Optional[Union[bool, str]] = None, - method_name: str = None, + method_name: Optional[str] = None, ): @functools.wraps(method) def decorate(*args, **kwargs): @@ -574,7 +595,7 @@ def _decorate_generator_function_with_context_manager( method: Callable, capture_response: Optional[Union[bool, str]] = None, capture_error: Optional[Union[bool, str]] = None, - method_name: str = None, + method_name: Optional[str] = None, ): @functools.wraps(method) @contextlib.contextmanager @@ -599,11 +620,11 @@ def decorate(*args, **kwargs): def _decorate_sync_function( self, - method: Callable, + method: AnyCallableT, capture_response: Optional[Union[bool, str]] = None, capture_error: Optional[Union[bool, str]] = None, - method_name: str = None, - ): + method_name: Optional[str] = None, + ) -> AnyCallableT: @functools.wraps(method) def decorate(*args, **kwargs): with self.provider.in_subsegment(name=f"## {method_name}") as subsegment: @@ -625,13 +646,13 @@ def decorate(*args, **kwargs): return response - return decorate + return cast(AnyCallableT, decorate) def _add_response_as_metadata( self, - method_name: str = None, - data: Any = None, - subsegment: BaseSegment = None, + method_name: Optional[str] = None, + data: Optional[Any] = None, + subsegment: Optional[BaseSegment] = None, capture_response: Optional[Union[bool, str]] = None, ): """Add response as metadata for given subsegment @@ -714,11 +735,11 @@ def _is_tracer_disabled() -> Union[bool, str]: def __build_config( self, - service: str = None, - disabled: bool = None, - auto_patch: bool = None, - patch_modules: Union[List, Tuple] = None, - provider: BaseProvider = None, + service: Optional[str] = None, + disabled: Optional[bool] = None, + auto_patch: Optional[bool] = None, + patch_modules: Optional[Sequence[str]] = None, + provider: Optional[BaseProvider] = None, ): """Populates Tracer config for new and existing initializations""" is_disabled = disabled if disabled is not None else self._is_tracer_disabled() diff --git a/aws_lambda_powertools/utilities/data_classes/appsync/scalar_types_utils.py b/aws_lambda_powertools/utilities/data_classes/appsync/scalar_types_utils.py index 71bfbe7046a..b83a947c3f8 100644 --- a/aws_lambda_powertools/utilities/data_classes/appsync/scalar_types_utils.py +++ b/aws_lambda_powertools/utilities/data_classes/appsync/scalar_types_utils.py @@ -19,15 +19,21 @@ def _formatted_time(now: datetime.date, fmt: str, timezone_offset: int) -> str: str Returns string formatted time with optional timezone offset """ - if timezone_offset == 0: - return now.strftime(fmt + "Z") + if timezone_offset != 0: + now = now + datetime.timedelta(hours=timezone_offset) + + datetime_str = now.strftime(fmt) + if fmt.endswith(".%f"): + datetime_str = datetime_str[:-3] - now = now + datetime.timedelta(hours=timezone_offset) - fmt += "+" if timezone_offset > 0 else "-" - fmt += str(abs(timezone_offset)).zfill(2) - fmt += ":00:00" + if timezone_offset == 0: + postfix = "Z" + else: + postfix = "+" if timezone_offset > 0 else "-" + postfix += str(abs(timezone_offset)).zfill(2) + postfix += ":00:00" - return now.strftime(fmt) + return datetime_str + postfix def make_id() -> str: @@ -65,7 +71,7 @@ def aws_time(timezone_offset: int = 0) -> str: str Returns current time as AWSTime scalar string with optional timezone offset """ - return _formatted_time(datetime.datetime.utcnow(), "%H:%M:%S", timezone_offset) + return _formatted_time(datetime.datetime.utcnow(), "%H:%M:%S.%f", timezone_offset) def aws_datetime(timezone_offset: int = 0) -> str: @@ -81,7 +87,7 @@ def aws_datetime(timezone_offset: int = 0) -> str: str Returns current time as AWSDateTime scalar string with optional timezone offset """ - return _formatted_time(datetime.datetime.utcnow(), "%Y-%m-%dT%H:%M:%S", timezone_offset) + return _formatted_time(datetime.datetime.utcnow(), "%Y-%m-%dT%H:%M:%S.%f", timezone_offset) def aws_timestamp() -> int: diff --git a/aws_lambda_powertools/utilities/feature_toggles/__init__.py b/aws_lambda_powertools/utilities/feature_toggles/__init__.py new file mode 100644 index 00000000000..04237d63812 --- /dev/null +++ b/aws_lambda_powertools/utilities/feature_toggles/__init__.py @@ -0,0 +1,16 @@ +"""Advanced feature toggles utility +""" +from .appconfig_fetcher import AppConfigFetcher +from .configuration_store import ConfigurationStore +from .exceptions import ConfigurationError +from .schema import ACTION, SchemaValidator +from .schema_fetcher import SchemaFetcher + +__all__ = [ + "ConfigurationError", + "ConfigurationStore", + "ACTION", + "SchemaValidator", + "AppConfigFetcher", + "SchemaFetcher", +] diff --git a/aws_lambda_powertools/utilities/feature_toggles/appconfig_fetcher.py b/aws_lambda_powertools/utilities/feature_toggles/appconfig_fetcher.py new file mode 100644 index 00000000000..ae7c6c90e51 --- /dev/null +++ b/aws_lambda_powertools/utilities/feature_toggles/appconfig_fetcher.py @@ -0,0 +1,67 @@ +import logging +from typing import Any, Dict, Optional + +from botocore.config import Config + +from aws_lambda_powertools.utilities.parameters import AppConfigProvider, GetParameterError, TransformParameterError + +from .exceptions import ConfigurationError +from .schema_fetcher import SchemaFetcher + +logger = logging.getLogger(__name__) + + +TRANSFORM_TYPE = "json" + + +class AppConfigFetcher(SchemaFetcher): + def __init__( + self, + environment: str, + service: str, + configuration_name: str, + cache_seconds: int, + config: Optional[Config] = None, + ): + """This class fetches JSON schemas from AWS AppConfig + + Parameters + ---------- + environment: str + what appconfig environment to use 'dev/test' etc. + service: str + what service name to use from the supplied environment + configuration_name: str + what configuration to take from the environment & service combination + cache_seconds: int + cache expiration time, how often to call AppConfig to fetch latest configuration + config: Optional[Config] + boto3 client configuration + """ + super().__init__(configuration_name, cache_seconds) + self._logger = logger + self._conf_store = AppConfigProvider(environment=environment, application=service, config=config) + + def get_json_configuration(self) -> Dict[str, Any]: + """Get configuration string from AWs AppConfig and return the parsed JSON dictionary + + Raises + ------ + ConfigurationError + Any validation error or appconfig error that can occur + + Returns + ------- + Dict[str, Any] + parsed JSON dictionary + """ + try: + return self._conf_store.get( + name=self.configuration_name, + transform=TRANSFORM_TYPE, + max_age=self._cache_seconds, + ) # parse result conf as JSON, keep in cache for self.max_age seconds + except (GetParameterError, TransformParameterError) as exc: + error_str = f"unable to get AWS AppConfig configuration file, exception={str(exc)}" + self._logger.error(error_str) + raise ConfigurationError(error_str) diff --git a/aws_lambda_powertools/utilities/feature_toggles/configuration_store.py b/aws_lambda_powertools/utilities/feature_toggles/configuration_store.py new file mode 100644 index 00000000000..72d00bb9c03 --- /dev/null +++ b/aws_lambda_powertools/utilities/feature_toggles/configuration_store.py @@ -0,0 +1,216 @@ +import logging +from typing import Any, Dict, List, Optional, cast + +from . import schema +from .exceptions import ConfigurationError +from .schema_fetcher import SchemaFetcher + +logger = logging.getLogger(__name__) + + +class ConfigurationStore: + def __init__(self, schema_fetcher: SchemaFetcher): + """constructor + + Parameters + ---------- + schema_fetcher: SchemaFetcher + A schema JSON fetcher, can be AWS AppConfig, Hashicorp Consul etc. + """ + self._logger = logger + self._schema_fetcher = schema_fetcher + self._schema_validator = schema.SchemaValidator(self._logger) + + def _match_by_action(self, action: str, condition_value: Any, context_value: Any) -> bool: + if not context_value: + return False + mapping_by_action = { + schema.ACTION.EQUALS.value: lambda a, b: a == b, + schema.ACTION.STARTSWITH.value: lambda a, b: a.startswith(b), + schema.ACTION.ENDSWITH.value: lambda a, b: a.endswith(b), + schema.ACTION.CONTAINS.value: lambda a, b: a in b, + } + + try: + func = mapping_by_action.get(action, lambda a, b: False) + return func(context_value, condition_value) + except Exception as exc: + self._logger.error(f"caught exception while matching action, action={action}, exception={str(exc)}") + return False + + def _is_rule_matched(self, feature_name: str, rule: Dict[str, Any], rules_context: Dict[str, Any]) -> bool: + rule_name = rule.get(schema.RULE_NAME_KEY, "") + rule_default_value = rule.get(schema.RULE_DEFAULT_VALUE) + conditions = cast(List[Dict], rule.get(schema.CONDITIONS_KEY)) + + for condition in conditions: + context_value = rules_context.get(str(condition.get(schema.CONDITION_KEY))) + if not self._match_by_action( + condition.get(schema.CONDITION_ACTION, ""), + condition.get(schema.CONDITION_VALUE), + context_value, + ): + logger.debug( + f"rule did not match action, rule_name={rule_name}, rule_default_value={rule_default_value}, " + f"feature_name={feature_name}, context_value={str(context_value)} " + ) + # context doesn't match condition + return False + # if we got here, all conditions match + logger.debug( + f"rule matched, rule_name={rule_name}, rule_default_value={rule_default_value}, " + f"feature_name={feature_name}" + ) + return True + return False + + def _handle_rules( + self, + *, + feature_name: str, + rules_context: Dict[str, Any], + feature_default_value: bool, + rules: List[Dict[str, Any]], + ) -> bool: + for rule in rules: + rule_default_value = rule.get(schema.RULE_DEFAULT_VALUE) + if self._is_rule_matched(feature_name, rule, rules_context): + return bool(rule_default_value) + # no rule matched, return default value of feature + logger.debug( + f"no rule matched, returning default value of feature, feature_default_value={feature_default_value}, " + f"feature_name={feature_name}" + ) + return feature_default_value + return False + + def get_configuration(self) -> Dict[str, Any]: + """Get configuration string from AWs AppConfig and returned the parsed JSON dictionary + + Raises + ------ + ConfigurationError + Any validation error or appconfig error that can occur + + Returns + ------ + Dict[str, Any] + parsed JSON dictionary + """ + # parse result conf as JSON, keep in cache for self.max_age seconds + config = self._schema_fetcher.get_json_configuration() + # validate schema + self._schema_validator.validate_json_schema(config) + return config + + def get_feature_toggle( + self, *, feature_name: str, rules_context: Optional[Dict[str, Any]] = None, value_if_missing: bool + ) -> bool: + """Get a feature toggle boolean value. Value is calculated according to a set of rules and conditions. + + See below for explanation. + + Parameters + ---------- + feature_name: str + feature name that you wish to fetch + rules_context: Optional[Dict[str, Any]] + dict of attributes that you would like to match the rules + against, can be {'tenant_id: 'X', 'username':' 'Y', 'region': 'Z'} etc. + value_if_missing: bool + this will be the returned value in case the feature toggle doesn't exist in + the schema or there has been an error while fetching the + configuration from appconfig + + Returns + ------ + bool + calculated feature toggle value. several possibilities: + 1. if the feature doesn't appear in the schema or there has been an error fetching the + configuration -> error/warning log would appear and value_if_missing is returned + 2. feature exists and has no rules or no rules have matched -> return feature_default_value of + the defined feature + 3. feature exists and a rule matches -> rule_default_value of rule is returned + """ + if rules_context is None: + rules_context = {} + + try: + toggles_dict: Dict[str, Any] = self.get_configuration() + except ConfigurationError: + logger.error("unable to get feature toggles JSON, returning provided value_if_missing value") + return value_if_missing + + feature: Dict[str, Dict] = toggles_dict.get(schema.FEATURES_KEY, {}).get(feature_name, None) + if feature is None: + logger.warning( + f"feature does not appear in configuration, using provided value_if_missing, " + f"feature_name={feature_name}, value_if_missing={value_if_missing}" + ) + return value_if_missing + + rules_list = feature.get(schema.RULES_KEY) + feature_default_value = feature.get(schema.FEATURE_DEFAULT_VAL_KEY) + if not rules_list: + # not rules but has a value + logger.debug( + f"no rules found, returning feature default value, feature_name={feature_name}, " + f"default_value={feature_default_value}" + ) + return bool(feature_default_value) + # look for first rule match + logger.debug( + f"looking for rule match, feature_name={feature_name}, feature_default_value={feature_default_value}" + ) + return self._handle_rules( + feature_name=feature_name, + rules_context=rules_context, + feature_default_value=bool(feature_default_value), + rules=cast(List, rules_list), + ) + + def get_all_enabled_feature_toggles(self, *, rules_context: Optional[Dict[str, Any]] = None) -> List[str]: + """Get all enabled feature toggles while also taking into account rule_context + (when a feature has defined rules) + + Parameters + ---------- + rules_context: Optional[Dict[str, Any]] + dict of attributes that you would like to match the rules + against, can be `{'tenant_id: 'X', 'username':' 'Y', 'region': 'Z'}` etc. + + Returns + ---------- + List[str] + a list of all features name that are enabled by also taking into account + rule_context (when a feature has defined rules) + """ + if rules_context is None: + rules_context = {} + + try: + toggles_dict: Dict[str, Any] = self.get_configuration() + except ConfigurationError: + logger.error("unable to get feature toggles JSON") + return [] + + ret_list = [] + features: Dict[str, Any] = toggles_dict.get(schema.FEATURES_KEY, {}) + for feature_name, feature_dict_def in features.items(): + rules_list = feature_dict_def.get(schema.RULES_KEY, []) + feature_default_value = feature_dict_def.get(schema.FEATURE_DEFAULT_VAL_KEY) + if feature_default_value and not rules_list: + self._logger.debug( + f"feature is enabled by default and has no defined rules, feature_name={feature_name}" + ) + ret_list.append(feature_name) + elif self._handle_rules( + feature_name=feature_name, + rules_context=rules_context, + feature_default_value=feature_default_value, + rules=rules_list, + ): + self._logger.debug(f"feature's calculated value is True, feature_name={feature_name}") + ret_list.append(feature_name) + + return ret_list diff --git a/aws_lambda_powertools/utilities/feature_toggles/exceptions.py b/aws_lambda_powertools/utilities/feature_toggles/exceptions.py new file mode 100644 index 00000000000..d87f9a39dec --- /dev/null +++ b/aws_lambda_powertools/utilities/feature_toggles/exceptions.py @@ -0,0 +1,2 @@ +class ConfigurationError(Exception): + """When a a configuration store raises an exception on config retrieval or parsing""" diff --git a/aws_lambda_powertools/utilities/feature_toggles/schema.py b/aws_lambda_powertools/utilities/feature_toggles/schema.py new file mode 100644 index 00000000000..9d995ab59e4 --- /dev/null +++ b/aws_lambda_powertools/utilities/feature_toggles/schema.py @@ -0,0 +1,84 @@ +from enum import Enum +from logging import Logger +from typing import Any, Dict + +from .exceptions import ConfigurationError + +FEATURES_KEY = "features" +RULES_KEY = "rules" +FEATURE_DEFAULT_VAL_KEY = "feature_default_value" +CONDITIONS_KEY = "conditions" +RULE_NAME_KEY = "rule_name" +RULE_DEFAULT_VALUE = "value_when_applies" +CONDITION_KEY = "key" +CONDITION_VALUE = "value" +CONDITION_ACTION = "action" + + +class ACTION(str, Enum): + EQUALS = "EQUALS" + STARTSWITH = "STARTSWITH" + ENDSWITH = "ENDSWITH" + CONTAINS = "CONTAINS" + + +class SchemaValidator: + def __init__(self, logger: Logger): + self._logger = logger + + def _raise_conf_exc(self, error_str: str) -> None: + self._logger.error(error_str) + raise ConfigurationError(error_str) + + def _validate_condition(self, rule_name: str, condition: Dict[str, str]) -> None: + if not condition or not isinstance(condition, dict): + self._raise_conf_exc(f"invalid condition type, not a dictionary, rule_name={rule_name}") + action = condition.get(CONDITION_ACTION, "") + if action not in [ACTION.EQUALS.value, ACTION.STARTSWITH.value, ACTION.ENDSWITH.value, ACTION.CONTAINS.value]: + self._raise_conf_exc(f"invalid action value, rule_name={rule_name}, action={action}") + key = condition.get(CONDITION_KEY, "") + if not key or not isinstance(key, str): + self._raise_conf_exc(f"invalid key value, key has to be a non empty string, rule_name={rule_name}") + value = condition.get(CONDITION_VALUE, "") + if not value: + self._raise_conf_exc(f"missing condition value, rule_name={rule_name}") + + def _validate_rule(self, feature_name: str, rule: Dict[str, Any]) -> None: + if not rule or not isinstance(rule, dict): + self._raise_conf_exc(f"feature rule is not a dictionary, feature_name={feature_name}") + rule_name = rule.get(RULE_NAME_KEY) + if not rule_name or rule_name is None or not isinstance(rule_name, str): + return self._raise_conf_exc(f"invalid rule_name, feature_name={feature_name}") + rule_default_value = rule.get(RULE_DEFAULT_VALUE) + if rule_default_value is None or not isinstance(rule_default_value, bool): + self._raise_conf_exc(f"invalid rule_default_value, rule_name={rule_name}") + conditions = rule.get(CONDITIONS_KEY, {}) + if not conditions or not isinstance(conditions, list): + self._raise_conf_exc(f"invalid condition, rule_name={rule_name}") + # validate conditions + for condition in conditions: + self._validate_condition(rule_name, condition) + + def _validate_feature(self, feature_name: str, feature_dict_def: Dict[str, Any]) -> None: + if not feature_dict_def or not isinstance(feature_dict_def, dict): + self._raise_conf_exc(f"invalid AWS AppConfig JSON schema detected, feature {feature_name} is invalid") + feature_default_value = feature_dict_def.get(FEATURE_DEFAULT_VAL_KEY) + if feature_default_value is None or not isinstance(feature_default_value, bool): + self._raise_conf_exc(f"missing feature_default_value for feature, feature_name={feature_name}") + # validate rules + rules = feature_dict_def.get(RULES_KEY, []) + if not rules: + return + if not isinstance(rules, list): + self._raise_conf_exc(f"feature rules is not a list, feature_name={feature_name}") + for rule in rules: + self._validate_rule(feature_name, rule) + + def validate_json_schema(self, schema: Dict[str, Any]) -> None: + if not isinstance(schema, dict): + self._raise_conf_exc("invalid AWS AppConfig JSON schema detected, root schema is not a dictionary") + features_dict = schema.get(FEATURES_KEY) + if not isinstance(features_dict, dict): + return self._raise_conf_exc("invalid AWS AppConfig JSON schema detected, missing features dictionary") + for feature_name, feature_dict_def in features_dict.items(): + self._validate_feature(feature_name, feature_dict_def) diff --git a/aws_lambda_powertools/utilities/feature_toggles/schema_fetcher.py b/aws_lambda_powertools/utilities/feature_toggles/schema_fetcher.py new file mode 100644 index 00000000000..89fffe1221d --- /dev/null +++ b/aws_lambda_powertools/utilities/feature_toggles/schema_fetcher.py @@ -0,0 +1,24 @@ +from abc import ABC, abstractmethod +from typing import Any, Dict + + +class SchemaFetcher(ABC): + def __init__(self, configuration_name: str, cache_seconds: int): + self.configuration_name = configuration_name + self._cache_seconds = cache_seconds + + @abstractmethod + def get_json_configuration(self) -> Dict[str, Any]: + """Get configuration string from any configuration storing service and return the parsed JSON dictionary + + Raises + ------ + ConfigurationError + Any error that can occur during schema fetch or JSON parse + + Returns + ------- + Dict[str, Any] + parsed JSON dictionary + """ + return NotImplemented # pragma: no cover diff --git a/aws_lambda_powertools/utilities/idempotency/config.py b/aws_lambda_powertools/utilities/idempotency/config.py index 52afb3bad8c..06468cc74a7 100644 --- a/aws_lambda_powertools/utilities/idempotency/config.py +++ b/aws_lambda_powertools/utilities/idempotency/config.py @@ -1,4 +1,4 @@ -from typing import Dict +from typing import Dict, Optional class IdempotencyConfig: @@ -6,7 +6,7 @@ def __init__( self, event_key_jmespath: str = "", payload_validation_jmespath: str = "", - jmespath_options: Dict = None, + jmespath_options: Optional[Dict] = None, raise_on_no_idempotency_key: bool = False, expires_after_seconds: int = 60 * 60, # 1 hour default use_local_cache: bool = False, diff --git a/aws_lambda_powertools/utilities/idempotency/idempotency.py b/aws_lambda_powertools/utilities/idempotency/idempotency.py index 6f73a842af4..c2bcc62fd69 100644 --- a/aws_lambda_powertools/utilities/idempotency/idempotency.py +++ b/aws_lambda_powertools/utilities/idempotency/idempotency.py @@ -31,7 +31,7 @@ def idempotent( event: Dict[str, Any], context: LambdaContext, persistence_store: BasePersistenceLayer, - config: IdempotencyConfig = None, + config: Optional[IdempotencyConfig] = None, ) -> Any: """ Middleware to handle idempotency diff --git a/aws_lambda_powertools/utilities/idempotency/persistence/base.py b/aws_lambda_powertools/utilities/idempotency/persistence/base.py index 31aef6dc0f2..eb43a8b30c5 100644 --- a/aws_lambda_powertools/utilities/idempotency/persistence/base.py +++ b/aws_lambda_powertools/utilities/idempotency/persistence/base.py @@ -39,9 +39,9 @@ def __init__( self, idempotency_key, status: str = "", - expiry_timestamp: int = None, + expiry_timestamp: Optional[int] = None, response_data: Optional[str] = "", - payload_hash: str = None, + payload_hash: Optional[str] = None, ) -> None: """ diff --git a/aws_lambda_powertools/utilities/parameters/appconfig.py b/aws_lambda_powertools/utilities/parameters/appconfig.py index 4490e260364..63a8415f1ec 100644 --- a/aws_lambda_powertools/utilities/parameters/appconfig.py +++ b/aws_lambda_powertools/utilities/parameters/appconfig.py @@ -149,7 +149,7 @@ def get_app_config( >>> print(value) My configuration value - **Retrieves a confiugration value and decodes it using a JSON decoder** + **Retrieves a configuration value and decodes it using a JSON decoder** >>> from aws_lambda_powertools.utilities.parameters import get_parameter >>> diff --git a/aws_lambda_powertools/utilities/parser/models/apigw.py b/aws_lambda_powertools/utilities/parser/models/apigw.py index de968e20ecf..4de8ee96cc5 100644 --- a/aws_lambda_powertools/utilities/parser/models/apigw.py +++ b/aws_lambda_powertools/utilities/parser/models/apigw.py @@ -46,7 +46,7 @@ class APIGatewayEventAuthorizer(BaseModel): class APIGatewayEventRequestContext(BaseModel): accountId: str apiId: str - authorizer: APIGatewayEventAuthorizer + authorizer: Optional[APIGatewayEventAuthorizer] stage: str protocol: str identity: APIGatewayEventIdentity @@ -70,7 +70,7 @@ class APIGatewayEventRequestContext(BaseModel): class APIGatewayProxyEventModel(BaseModel): - version: str + version: Optional[str] resource: str path: str httpMethod: Literal["DELETE", "GET", "HEAD", "OPTIONS", "PATCH", "POST", "PUT"] diff --git a/aws_lambda_powertools/utilities/validation/base.py b/aws_lambda_powertools/utilities/validation/base.py index ec4165f4876..b818f11a40e 100644 --- a/aws_lambda_powertools/utilities/validation/base.py +++ b/aws_lambda_powertools/utilities/validation/base.py @@ -1,9 +1,9 @@ import logging -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, Union -import fastjsonschema +import fastjsonschema # type: ignore import jmespath -from jmespath.exceptions import LexerError +from jmespath.exceptions import LexerError # type: ignore from aws_lambda_powertools.shared.jmespath_functions import PowertoolsFunctions @@ -12,7 +12,7 @@ logger = logging.getLogger(__name__) -def validate_data_against_schema(data: Dict, schema: Dict, formats: Optional[Dict] = None): +def validate_data_against_schema(data: Union[Dict, str], schema: Dict, formats: Optional[Dict] = None): """Validate dict data against given JSON Schema Parameters @@ -41,7 +41,7 @@ def validate_data_against_schema(data: Dict, schema: Dict, formats: Optional[Dic raise SchemaValidationError(message) -def unwrap_event_from_envelope(data: Dict, envelope: str, jmespath_options: Optional[Dict]) -> Any: +def unwrap_event_from_envelope(data: Union[Dict, str], envelope: str, jmespath_options: Optional[Dict]) -> Any: """Searches data using JMESPath expression Parameters diff --git a/aws_lambda_powertools/utilities/validation/validator.py b/aws_lambda_powertools/utilities/validation/validator.py index 3628d486eb3..0497a49a714 100644 --- a/aws_lambda_powertools/utilities/validation/validator.py +++ b/aws_lambda_powertools/utilities/validation/validator.py @@ -12,12 +12,12 @@ def validator( handler: Callable, event: Union[Dict, str], context: Any, - inbound_schema: Dict = None, + inbound_schema: Optional[Dict] = None, inbound_formats: Optional[Dict] = None, - outbound_schema: Dict = None, + outbound_schema: Optional[Dict] = None, outbound_formats: Optional[Dict] = None, - envelope: str = None, - jmespath_options: Dict = None, + envelope: Optional[str] = None, + jmespath_options: Optional[Dict] = None, ) -> Any: """Lambda handler decorator to validate incoming/outbound data using a JSON Schema @@ -135,8 +135,8 @@ def validate( event: Any, schema: Dict, formats: Optional[Dict] = None, - envelope: str = None, - jmespath_options: Dict = None, + envelope: Optional[str] = None, + jmespath_options: Optional[Dict] = None, ): """Standalone function to validate event data using a JSON Schema diff --git a/docs/core/event_handler/api_gateway.md b/docs/core/event_handler/api_gateway.md index 87263062f4f..a87eefbd5cd 100644 --- a/docs/core/event_handler/api_gateway.md +++ b/docs/core/event_handler/api_gateway.md @@ -357,7 +357,7 @@ You can access the raw payload via `body` property, or if it's a JSON string you #### Headers -Similarly to [Query strings](#query-strings), you can access headers as dictionary via `app.current_event.headers`, or by name via `get_header_value`. +Similarly to [Query strings](#query-strings-and-payload), you can access headers as dictionary via `app.current_event.headers`, or by name via `get_header_value`. === "app.py" @@ -377,6 +377,66 @@ Similarly to [Query strings](#query-strings), you can access headers as dictiona return app.resolve(event, context) ``` +### Raising HTTP errors + +You can easily raise any HTTP Error back to the client using `ServiceError` exception. + +!!! info "If you need to send custom headers, use [Response](#fine-grained-responses) class instead." + +Additionally, we provide pre-defined errors for the most popular ones such as HTTP 400, 401, 404, 500. + + +=== "app.py" + + ```python hl_lines="4-10 20 25 30 35 39" + from aws_lambda_powertools import Logger, Tracer + from aws_lambda_powertools.logging import correlation_paths + from aws_lambda_powertools.event_handler.api_gateway import ApiGatewayResolver + from aws_lambda_powertools.event_handler.exceptions import ( + BadRequestError, + InternalServerError, + NotFoundError, + ServiceError, + UnauthorizedError, + ) + + tracer = Tracer() + logger = Logger() + + app = ApiGatewayResolver() + + @app.get(rule="/bad-request-error") + def bad_request_error(): + # HTTP 400 + raise BadRequestError("Missing required parameter") + + @app.get(rule="/unauthorized-error") + def unauthorized_error(): + # HTTP 401 + raise UnauthorizedError("Unauthorized") + + @app.get(rule="/not-found-error") + def not_found_error(): + # HTTP 404 + raise NotFoundError + + @app.get(rule="/internal-server-error") + def internal_server_error(): + # HTTP 500 + raise InternalServerError("Internal server error") + + @app.get(rule="/service-error", cors=True) + def service_error(): + raise ServiceError(502, "Something went wrong!") + # alternatively + # from http import HTTPStatus + # raise ServiceError(HTTPStatus.BAD_GATEWAY.value, "Something went wrong) + + def handler(event, context): + return app.resolve(event, context) + ``` + + ## Advanced ### CORS @@ -401,7 +461,7 @@ This will ensure that CORS headers are always returned as part of the response w @app.get("/hello/") @tracer.capture_method def get_hello_you(name): - return {"message": f"hello {name}}"} + return {"message": f"hello {name}"} @app.get("/hello", cors=False) # optionally exclude CORS from response, if needed @tracer.capture_method @@ -647,6 +707,30 @@ Like `compress` feature, the client must send the `Accept` header with the corre } ``` +### Debug mode + +You can enable debug mode via `debug` param, or via `POWERTOOLS_EVENT_HANDLER_DEBUG` [environment variable](../../index.md#environment-variables). + +This will enable full tracebacks errors in the response, print request and responses, and set CORS in development mode. + +!!! warning "This might reveal sensitive information in your logs and relax CORS restrictions, use it sparingly." + +=== "debug.py" + + ```python hl_lines="3" + from aws_lambda_powertools.event_handler.api_gateway import ApiGatewayResolver + + app = ApiGatewayResolver(debug=True) + + @app.get("/hello") + def get_hello_universe(): + return {"message": "hello universe"} + + def lambda_handler(event, context): + return app.resolve(event, context) + ``` + + ## Testing your code You can test your routes by passing a proxy event request where `path` and `httpMethod`. diff --git a/docs/core/event_handler/appsync.md b/docs/core/event_handler/appsync.md index 67ad1999285..a47b8a4c641 100644 --- a/docs/core/event_handler/appsync.md +++ b/docs/core/event_handler/appsync.md @@ -598,6 +598,118 @@ Use the following code for `merchantInfo` and `searchMerchant` functions respect } ``` +### Custom data models + +You can subclass `AppSyncResolverEvent` to bring your own set of methods to handle incoming events, by using `data_model` param in the `resolve` method. + + +=== "custom_model.py" + + ```python hl_lines="11-14 19 26" + from aws_lambda_powertools import Logger, Tracer + + from aws_lambda_powertools.logging import correlation_paths + from aws_lambda_powertools.event_handler import AppSyncResolver + + tracer = Tracer(service="sample_resolver") + logger = Logger(service="sample_resolver") + app = AppSyncResolver() + + + class MyCustomModel(AppSyncResolverEvent): + @property + def country_viewer(self) -> str: + return self.request_headers.get("cloudfront-viewer-country") + + @app.resolver(field_name="listLocations") + @app.resolver(field_name="locations") + def get_locations(name: str, description: str = ""): + if app.current_event.country_viewer == "US": + ... + return name + description + + @logger.inject_lambda_context(correlation_id_path=correlation_paths.APPSYNC_RESOLVER) + @tracer.capture_lambda_handler + def lambda_handler(event, context): + return app.resolve(event, context, data_model=MyCustomModel) + ``` + +=== "schema.graphql" + + ```typescript hl_lines="6 20" + schema { + query: Query + } + + type Query { + listLocations: [Location] + } + + type Location { + id: ID! + name: String! + description: String + address: String + } + + type Merchant { + id: String! + name: String! + description: String + locations: [Location] + } + ``` + +=== "listLocations_event.json" + + ```json + { + "arguments": {}, + "identity": null, + "source": null, + "request": { + "headers": { + "x-forwarded-for": "1.2.3.4, 5.6.7.8", + "accept-encoding": "gzip, deflate, br", + "cloudfront-viewer-country": "NL", + "cloudfront-is-tablet-viewer": "false", + "referer": "https://eu-west-1.console.aws.amazon.com/appsync/home?region=eu-west-1", + "via": "2.0 9fce949f3749407c8e6a75087e168b47.cloudfront.net (CloudFront)", + "cloudfront-forwarded-proto": "https", + "origin": "https://eu-west-1.console.aws.amazon.com", + "x-api-key": "da1-c33ullkbkze3jg5hf5ddgcs4fq", + "content-type": "application/json", + "x-amzn-trace-id": "Root=1-606eb2f2-1babc433453a332c43fb4494", + "x-amz-cf-id": "SJw16ZOPuMZMINx5Xcxa9pB84oMPSGCzNOfrbJLvd80sPa0waCXzYQ==", + "content-length": "114", + "x-amz-user-agent": "AWS-Console-AppSync/", + "x-forwarded-proto": "https", + "host": "ldcvmkdnd5az3lm3gnf5ixvcyy.appsync-api.eu-west-1.amazonaws.com", + "accept-language": "en-US,en;q=0.5", + "user-agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10.15; rv:78.0) Gecko/20100101 Firefox/78.0", + "cloudfront-is-desktop-viewer": "true", + "cloudfront-is-mobile-viewer": "false", + "accept": "*/*", + "x-forwarded-port": "443", + "cloudfront-is-smarttv-viewer": "false" + } + }, + "prev": null, + "info": { + "parentTypeName": "Query", + "selectionSetList": [ + "id", + "name", + "description" + ], + "selectionSetGraphQL": "{\n id\n name\n description\n}", + "fieldName": "listLocations", + "variables": {} + }, + "stash": {} + } + ``` + ## Testing your code You can test your resolvers by passing a mocked or actual AppSync Lambda event that you're expecting. diff --git a/docs/core/logger.md b/docs/core/logger.md index 43d367e171b..53818bada51 100644 --- a/docs/core/logger.md +++ b/docs/core/logger.md @@ -146,6 +146,8 @@ When debugging in non-production environments, you can instruct Logger to log th You can set a Correlation ID using `correlation_id_path` param by passing a [JMESPath expression](https://jmespath.org/tutorial.html){target="_blank"}. +!!! tip "You can retrieve correlation IDs via `get_correlation_id` method" + === "collect.py" ```python hl_lines="5" @@ -155,6 +157,7 @@ You can set a Correlation ID using `correlation_id_path` param by passing a [JME @logger.inject_lambda_context(correlation_id_path="headers.my_request_id_header") def handler(event, context): + logger.debug(f"Correlation ID => {logger.get_correlation_id()}") logger.info("Collecting payment") ``` @@ -198,6 +201,7 @@ We provide [built-in JMESPath expressions](#built-in-correlation-id-expressions) @logger.inject_lambda_context(correlation_id_path=correlation_paths.API_GATEWAY_REST) def handler(event, context): + logger.debug(f"Correlation ID => {logger.get_correlation_id()}") logger.info("Collecting payment") ``` diff --git a/docs/index.md b/docs/index.md index ce0573915fd..104ed1d85d6 100644 --- a/docs/index.md +++ b/docs/index.md @@ -229,6 +229,7 @@ aws serverlessrepo list-application-versions \ | **POWERTOOLS_LOGGER_LOG_EVENT** | Logs incoming event | [Logging](./core/logger) | `false` | | **POWERTOOLS_LOGGER_SAMPLE_RATE** | Debug log sampling | [Logging](./core/logger) | `0` | | **POWERTOOLS_LOG_DEDUPLICATION_DISABLED** | Disables log deduplication filter protection to use Pytest Live Log feature | [Logging](./core/logger) | `false` | +| **POWERTOOLS_EVENT_HANDLER_DEBUG** | Enables debugging mode for event handler | [Event Handler](./core/event_handler/api_gateway.md#debug-mode) | `false` | | **LOG_LEVEL** | Sets logging level | [Logging](./core/logger) | `INFO` | ## Debug mode diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 00000000000..2436d7074d2 --- /dev/null +++ b/mypy.ini @@ -0,0 +1,27 @@ +[mypy] +warn_return_any=False +warn_unused_configs=True +no_implicit_optional=True +warn_redundant_casts=True +warn_unused_ignores=True +show_column_numbers = True +show_error_codes = True +show_error_context = True + +[mypy-jmespath] +ignore_missing_imports=True + +[mypy-boto3] +ignore_missing_imports = True + +[mypy-boto3.dynamodb.conditions] +ignore_missing_imports = True + +[mypy-botocore.config] +ignore_missing_imports = True + +[mypy-botocore.exceptions] +ignore_missing_imports = True + +[mypy-aws_xray_sdk.ext.aiohttp.client] +ignore_missing_imports = True diff --git a/poetry.lock b/poetry.lock index 9a6f1fdbc04..d15b11daae8 100644 --- a/poetry.lock +++ b/poetry.lock @@ -81,24 +81,24 @@ d = ["aiohttp (>=3.3.2)", "aiohttp-cors"] [[package]] name = "boto3" -version = "1.17.102" +version = "1.18.1" description = "The AWS SDK for Python" category = "main" optional = false -python-versions = ">= 2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*" +python-versions = ">= 3.6" [package.dependencies] -botocore = ">=1.20.102,<1.21.0" +botocore = ">=1.21.1,<1.22.0" jmespath = ">=0.7.1,<1.0.0" -s3transfer = ">=0.4.0,<0.5.0" +s3transfer = ">=0.5.0,<0.6.0" [[package]] name = "botocore" -version = "1.20.102" +version = "1.21.1" description = "Low-level, data-driven core of boto 3." category = "main" optional = false -python-versions = ">= 2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*" +python-versions = ">= 3.6" [package.dependencies] jmespath = ">=0.7.1,<1.0.0" @@ -413,7 +413,7 @@ python-versions = "*" [[package]] name = "isort" -version = "5.9.1" +version = "5.9.2" description = "A Python utility / library to sort Python imports." category = "dev" optional = false @@ -592,7 +592,7 @@ mkdocs = ">=0.17" [[package]] name = "mkdocs-material" -version = "7.1.9" +version = "7.1.11" description = "A Material Design theme for MkDocs" category = "dev" optional = false @@ -616,6 +616,24 @@ python-versions = ">=3.5" [package.dependencies] mkdocs-material = ">=5.0.0" +[[package]] +name = "mypy" +version = "0.910" +description = "Optional static typing for Python" +category = "dev" +optional = false +python-versions = ">=3.5" + +[package.dependencies] +mypy-extensions = ">=0.4.3,<0.5.0" +toml = "*" +typed-ast = {version = ">=1.4.0,<1.5.0", markers = "python_version < \"3.8\""} +typing-extensions = ">=3.7.4" + +[package.extras] +dmypy = ["psutil (>=4.0)"] +python2 = ["typed-ast (>=1.4.0,<1.5.0)"] + [[package]] name = "mypy-extensions" version = "0.4.3" @@ -916,11 +934,11 @@ python-versions = "*" [[package]] name = "s3transfer" -version = "0.4.2" +version = "0.5.0" description = "An Amazon S3 Transfer Manager" category = "main" optional = false -python-versions = "*" +python-versions = ">= 3.6" [package.dependencies] botocore = ">=1.12.36,<2.0a.0" @@ -1066,7 +1084,7 @@ pydantic = ["pydantic", "email-validator"] [metadata] lock-version = "1.1" python-versions = "^3.6.1" -content-hash = "c07aad70013171c8bb000d163f997efef709189584089f68d471f69eb8bb38c0" +content-hash = "1e91beb4537c7042746d638f86154a664cbb840c1a43b2e902586a1dc7b0b9c2" [metadata.files] appdirs = [ @@ -1093,12 +1111,12 @@ black = [ {file = "black-20.8b1.tar.gz", hash = "sha256:1c02557aa099101b9d21496f8a914e9ed2222ef70336404eeeac8edba836fbea"}, ] boto3 = [ - {file = "boto3-1.17.102-py2.py3-none-any.whl", hash = "sha256:6300e9ee9a404038113250bd218e2c4827f5e676efb14e77de2ad2dcb67679bc"}, - {file = "boto3-1.17.102.tar.gz", hash = "sha256:be4714f0475c1f5183eea09ddbf568ced6fa41b0fc9976f2698b8442e1b17303"}, + {file = "boto3-1.18.1-py3-none-any.whl", hash = "sha256:a6399df957bfc7944fbd97e9fb0755cba29b1cb135b91d7e43fd298b268ab804"}, + {file = "boto3-1.18.1.tar.gz", hash = "sha256:ddfe4a78f04cd2d3a7a37d5cdfa07b4889b24296508786969bc968bee6b8b003"}, ] botocore = [ - {file = "botocore-1.20.102-py2.py3-none-any.whl", hash = "sha256:bdf08a4f7f01ead00d386848f089c08270499711447569c18d0db60023619c06"}, - {file = "botocore-1.20.102.tar.gz", hash = "sha256:2f57f7ceed1598d96cc497aeb45317db5d3b21a5aafea4732d0e561d0fc2a8fa"}, + {file = "botocore-1.21.1-py3-none-any.whl", hash = "sha256:b845220eb580d10f7714798a96e380eb8f94dca89905a41d8a3c35119c757b01"}, + {file = "botocore-1.21.1.tar.gz", hash = "sha256:200887ce5f3b47d7499b7ded75dc65c4649abdaaddd06cebc118a3a954d6fd73"}, ] certifi = [ {file = "certifi-2020.12.5-py2.py3-none-any.whl", hash = "sha256:719a74fb9e33b9bd44cc7f3a8d94bc35e4049deebe19ba7d8e108280cfd59830"}, @@ -1256,8 +1274,8 @@ iniconfig = [ {file = "iniconfig-1.1.1.tar.gz", hash = "sha256:bc3af051d7d14b2ee5ef9969666def0cd1a000e121eaea580d4a313df4b37f32"}, ] isort = [ - {file = "isort-5.9.1-py3-none-any.whl", hash = "sha256:8e2c107091cfec7286bc0f68a547d0ba4c094d460b732075b6fba674f1035c0c"}, - {file = "isort-5.9.1.tar.gz", hash = "sha256:83510593e07e433b77bd5bff0f6f607dbafa06d1a89022616f02d8b699cfcd56"}, + {file = "isort-5.9.2-py3-none-any.whl", hash = "sha256:eed17b53c3e7912425579853d078a0832820f023191561fcee9d7cae424e0813"}, + {file = "isort-5.9.2.tar.gz", hash = "sha256:f65ce5bd4cbc6abdfbe29afc2f0245538ab358c14590912df638033f157d555e"}, ] jinja2 = [ {file = "Jinja2-3.0.1-py3-none-any.whl", hash = "sha256:1f06f2da51e7b56b8f238affdd6b4e2c61e39598a378cc49345bc1bd42a978a4"}, @@ -1343,13 +1361,38 @@ mkdocs-git-revision-date-plugin = [ {file = "mkdocs_git_revision_date_plugin-0.3.1-py3-none-any.whl", hash = "sha256:8ae50b45eb75d07b150a69726041860801615aae5f4adbd6b1cf4d51abaa03d5"}, ] mkdocs-material = [ - {file = "mkdocs-material-7.1.9.tar.gz", hash = "sha256:5a2fd487f769f382a7c979e869e4eab1372af58d7dec44c4365dd97ef5268cb5"}, - {file = "mkdocs_material-7.1.9-py2.py3-none-any.whl", hash = "sha256:92c8a2bd3bd44d5948eefc46ba138e2d3285cac658900112b6bf5722c7d067a5"}, + {file = "mkdocs-material-7.1.11.tar.gz", hash = "sha256:cad3a693f1c28823370578e5b9c9aea418bddae0c7348ab734537391e9f2b1e5"}, + {file = "mkdocs_material-7.1.11-py2.py3-none-any.whl", hash = "sha256:0bcfb788020b72b0ebf5b2722ddf89534acaed8c3feb39c2d6dda239b49dec45"}, ] mkdocs-material-extensions = [ {file = "mkdocs-material-extensions-1.0.1.tar.gz", hash = "sha256:6947fb7f5e4291e3c61405bad3539d81e0b3cd62ae0d66ced018128af509c68f"}, {file = "mkdocs_material_extensions-1.0.1-py3-none-any.whl", hash = "sha256:d90c807a88348aa6d1805657ec5c0b2d8d609c110e62b9dce4daf7fa981fa338"}, ] +mypy = [ + {file = "mypy-0.910-cp35-cp35m-macosx_10_9_x86_64.whl", hash = "sha256:a155d80ea6cee511a3694b108c4494a39f42de11ee4e61e72bc424c490e46457"}, + {file = "mypy-0.910-cp35-cp35m-manylinux1_x86_64.whl", hash = "sha256:b94e4b785e304a04ea0828759172a15add27088520dc7e49ceade7834275bedb"}, + {file = "mypy-0.910-cp35-cp35m-manylinux2010_x86_64.whl", hash = "sha256:088cd9c7904b4ad80bec811053272986611b84221835e079be5bcad029e79dd9"}, + {file = "mypy-0.910-cp35-cp35m-win_amd64.whl", hash = "sha256:adaeee09bfde366d2c13fe6093a7df5df83c9a2ba98638c7d76b010694db760e"}, + {file = "mypy-0.910-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:ecd2c3fe726758037234c93df7e98deb257fd15c24c9180dacf1ef829da5f921"}, + {file = "mypy-0.910-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:d9dd839eb0dc1bbe866a288ba3c1afc33a202015d2ad83b31e875b5905a079b6"}, + {file = "mypy-0.910-cp36-cp36m-manylinux2010_x86_64.whl", hash = "sha256:3e382b29f8e0ccf19a2df2b29a167591245df90c0b5a2542249873b5c1d78212"}, + {file = "mypy-0.910-cp36-cp36m-win_amd64.whl", hash = "sha256:53fd2eb27a8ee2892614370896956af2ff61254c275aaee4c230ae771cadd885"}, + {file = "mypy-0.910-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:b6fb13123aeef4a3abbcfd7e71773ff3ff1526a7d3dc538f3929a49b42be03f0"}, + {file = "mypy-0.910-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:e4dab234478e3bd3ce83bac4193b2ecd9cf94e720ddd95ce69840273bf44f6de"}, + {file = "mypy-0.910-cp37-cp37m-manylinux2010_x86_64.whl", hash = "sha256:7df1ead20c81371ccd6091fa3e2878559b5c4d4caadaf1a484cf88d93ca06703"}, + {file = "mypy-0.910-cp37-cp37m-win_amd64.whl", hash = "sha256:0aadfb2d3935988ec3815952e44058a3100499f5be5b28c34ac9d79f002a4a9a"}, + {file = "mypy-0.910-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:ec4e0cd079db280b6bdabdc807047ff3e199f334050db5cbb91ba3e959a67504"}, + {file = "mypy-0.910-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:119bed3832d961f3a880787bf621634ba042cb8dc850a7429f643508eeac97b9"}, + {file = "mypy-0.910-cp38-cp38-manylinux2010_x86_64.whl", hash = "sha256:866c41f28cee548475f146aa4d39a51cf3b6a84246969f3759cb3e9c742fc072"}, + {file = "mypy-0.910-cp38-cp38-win_amd64.whl", hash = "sha256:ceb6e0a6e27fb364fb3853389607cf7eb3a126ad335790fa1e14ed02fba50811"}, + {file = "mypy-0.910-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:1a85e280d4d217150ce8cb1a6dddffd14e753a4e0c3cf90baabb32cefa41b59e"}, + {file = "mypy-0.910-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:42c266ced41b65ed40a282c575705325fa7991af370036d3f134518336636f5b"}, + {file = "mypy-0.910-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:3c4b8ca36877fc75339253721f69603a9c7fdb5d4d5a95a1a1b899d8b86a4de2"}, + {file = "mypy-0.910-cp39-cp39-manylinux2010_x86_64.whl", hash = "sha256:c0df2d30ed496a08de5daed2a9ea807d07c21ae0ab23acf541ab88c24b26ab97"}, + {file = "mypy-0.910-cp39-cp39-win_amd64.whl", hash = "sha256:c6c2602dffb74867498f86e6129fd52a2770c48b7cd3ece77ada4fa38f94eba8"}, + {file = "mypy-0.910-py3-none-any.whl", hash = "sha256:ef565033fa5a958e62796867b1df10c40263ea9ded87164d67572834e57a174d"}, + {file = "mypy-0.910.tar.gz", hash = "sha256:704098302473cb31a218f1775a873b376b30b4c18229421e9e9dc8916fd16150"}, +] mypy-extensions = [ {file = "mypy_extensions-0.4.3-py2.py3-none-any.whl", hash = "sha256:090fedd75945a69ae91ce1303b5824f428daf5a028d2f6ab8a299250a846f15d"}, {file = "mypy_extensions-0.4.3.tar.gz", hash = "sha256:2d82818f5bb3e369420cb3c4060a7970edba416647068eb4c5343488a6c604a8"}, @@ -1565,8 +1608,8 @@ requests = [ {file = "ruamel.yaml.clib-0.2.2.tar.gz", hash = "sha256:2d24bd98af676f4990c4d715bcdc2a60b19c56a3fb3a763164d2d8ca0e806ba7"}, ] s3transfer = [ - {file = "s3transfer-0.4.2-py2.py3-none-any.whl", hash = "sha256:9b3752887a2880690ce628bc263d6d13a3864083aeacff4890c1c9839a5eb0bc"}, - {file = "s3transfer-0.4.2.tar.gz", hash = "sha256:cb022f4b16551edebbb31a377d3f09600dbada7363d8c5db7976e7f47732e1b2"}, + {file = "s3transfer-0.5.0-py3-none-any.whl", hash = "sha256:9c1dc369814391a6bda20ebbf4b70a0f34630592c9aa520856bf384916af2803"}, + {file = "s3transfer-0.5.0.tar.gz", hash = "sha256:50ed823e1dc5868ad40c8dc92072f757aa0e653a192845c94a3b676f4a62da4c"}, ] six = [ {file = "six-1.16.0-py2.py3-none-any.whl", hash = "sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254"}, diff --git a/pyproject.toml b/pyproject.toml index f7743f2eba2..e1200887938 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "aws_lambda_powertools" -version = "1.17.1" +version = "1.18.0" description = "Python utilities for AWS Lambda functions including but not limited to tracing, logging and custom metric" authors = ["Amazon Web Services"] include = ["aws_lambda_powertools/py.typed"] @@ -39,7 +39,7 @@ flake8-debugger = "^4.0.0" flake8-fixme = "^1.1.1" flake8-isort = "^4.0.0" flake8-variables-names = "^0.0.4" -isort = "^5.9.1" +isort = "^5.9.2" pytest-cov = "^2.12.1" pytest-mock = "^3.5.1" pdoc3 = "^0.9.2" @@ -49,9 +49,10 @@ radon = "^4.5.0" xenon = "^0.7.3" flake8-eradicate = "^1.1.0" flake8-bugbear = "^21.3.2" -mkdocs-material = "^7.1.9" +mkdocs-material = "^7.1.11" mkdocs-git-revision-date-plugin = "^0.3.1" mike = "^0.6.0" +mypy = "^0.910" [tool.poetry.extras] diff --git a/tests/events/apiGatewayProxyEvent_noVersionAuth.json b/tests/events/apiGatewayProxyEvent_noVersionAuth.json new file mode 100644 index 00000000000..055301f8f15 --- /dev/null +++ b/tests/events/apiGatewayProxyEvent_noVersionAuth.json @@ -0,0 +1,75 @@ +{ + "resource": "/my/path", + "path": "/my/path", + "httpMethod": "GET", + "headers": { + "Header1": "value1", + "Header2": "value2" + }, + "multiValueHeaders": { + "Header1": [ + "value1" + ], + "Header2": [ + "value1", + "value2" + ] + }, + "queryStringParameters": { + "parameter1": "value1", + "parameter2": "value" + }, + "multiValueQueryStringParameters": { + "parameter1": [ + "value1", + "value2" + ], + "parameter2": [ + "value" + ] + }, + "requestContext": { + "accountId": "123456789012", + "apiId": "id", + "domainName": "id.execute-api.us-east-1.amazonaws.com", + "domainPrefix": "id", + "extendedRequestId": "request-id", + "httpMethod": "GET", + "identity": { + "accessKey": null, + "accountId": null, + "caller": null, + "cognitoAuthenticationProvider": null, + "cognitoAuthenticationType": null, + "cognitoIdentityId": null, + "cognitoIdentityPoolId": null, + "principalOrgId": null, + "sourceIp": "192.168.0.1/32", + "user": null, + "userAgent": "user-agent", + "userArn": null, + "clientCert": { + "clientCertPem": "CERT_CONTENT", + "subjectDN": "www.example.com", + "issuerDN": "Example issuer", + "serialNumber": "a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1", + "validity": { + "notBefore": "May 28 12:30:02 2019 GMT", + "notAfter": "Aug 5 09:36:04 2021 GMT" + } + } + }, + "path": "/my/path", + "protocol": "HTTP/1.1", + "requestId": "id=", + "requestTime": "04/Mar/2020:19:15:17 +0000", + "requestTimeEpoch": 1583349317135, + "resourceId": null, + "resourcePath": "/my/path", + "stage": "$default" + }, + "pathParameters": null, + "stageVariables": null, + "body": "Hello from Lambda!", + "isBase64Encoded": true +} diff --git a/tests/functional/event_handler/test_api_gateway.py b/tests/functional/event_handler/test_api_gateway.py index caaaeb1b97b..1c7c53f7187 100644 --- a/tests/functional/event_handler/test_api_gateway.py +++ b/tests/functional/event_handler/test_api_gateway.py @@ -1,10 +1,14 @@ import base64 import json import zlib +from copy import deepcopy from decimal import Decimal from pathlib import Path from typing import Dict +import pytest + +from aws_lambda_powertools.event_handler import content_types from aws_lambda_powertools.event_handler.api_gateway import ( ApiGatewayResolver, CORSConfig, @@ -12,6 +16,14 @@ Response, ResponseBuilder, ) +from aws_lambda_powertools.event_handler.exceptions import ( + BadRequestError, + InternalServerError, + NotFoundError, + ServiceError, + UnauthorizedError, +) +from aws_lambda_powertools.shared import constants from aws_lambda_powertools.shared.json_encoder import Encoder from aws_lambda_powertools.utilities.data_classes import ALBEvent, APIGatewayProxyEvent, APIGatewayProxyEventV2 from tests.functional.utils import load_event @@ -23,8 +35,6 @@ def read_media(file_name: str) -> bytes: LOAD_GW_EVENT = load_event("apiGatewayProxyEvent.json") -TEXT_HTML = "text/html" -APPLICATION_JSON = "application/json" def test_alb_event(): @@ -35,7 +45,7 @@ def test_alb_event(): def foo(): assert isinstance(app.current_event, ALBEvent) assert app.lambda_context == {} - return Response(200, TEXT_HTML, "foo") + return Response(200, content_types.TEXT_HTML, "foo") # WHEN calling the event handler result = app(load_event("albEvent.json"), {}) @@ -43,7 +53,7 @@ def foo(): # THEN process event correctly # AND set the current_event type as ALBEvent assert result["statusCode"] == 200 - assert result["headers"]["Content-Type"] == TEXT_HTML + assert result["headers"]["Content-Type"] == content_types.TEXT_HTML assert result["body"] == "foo" @@ -55,7 +65,7 @@ def test_api_gateway_v1(): def get_lambda() -> Response: assert isinstance(app.current_event, APIGatewayProxyEvent) assert app.lambda_context == {} - return Response(200, APPLICATION_JSON, json.dumps({"foo": "value"})) + return Response(200, content_types.APPLICATION_JSON, json.dumps({"foo": "value"})) # WHEN calling the event handler result = app(LOAD_GW_EVENT, {}) @@ -63,7 +73,7 @@ def get_lambda() -> Response: # THEN process event correctly # AND set the current_event type as APIGatewayProxyEvent assert result["statusCode"] == 200 - assert result["headers"]["Content-Type"] == APPLICATION_JSON + assert result["headers"]["Content-Type"] == content_types.APPLICATION_JSON def test_api_gateway(): @@ -73,7 +83,7 @@ def test_api_gateway(): @app.get("/my/path") def get_lambda() -> Response: assert isinstance(app.current_event, APIGatewayProxyEvent) - return Response(200, TEXT_HTML, "foo") + return Response(200, content_types.TEXT_HTML, "foo") # WHEN calling the event handler result = app(LOAD_GW_EVENT, {}) @@ -81,7 +91,7 @@ def get_lambda() -> Response: # THEN process event correctly # AND set the current_event type as APIGatewayProxyEvent assert result["statusCode"] == 200 - assert result["headers"]["Content-Type"] == TEXT_HTML + assert result["headers"]["Content-Type"] == content_types.TEXT_HTML assert result["body"] == "foo" @@ -93,7 +103,7 @@ def test_api_gateway_v2(): def my_path() -> Response: assert isinstance(app.current_event, APIGatewayProxyEventV2) post_data = app.current_event.json_body - return Response(200, "plain/text", post_data["username"]) + return Response(200, content_types.TEXT_PLAIN, post_data["username"]) # WHEN calling the event handler result = app(load_event("apiGatewayProxyV2Event.json"), {}) @@ -101,7 +111,7 @@ def my_path() -> Response: # THEN process event correctly # AND set the current_event type as APIGatewayProxyEventV2 assert result["statusCode"] == 200 - assert result["headers"]["Content-Type"] == "plain/text" + assert result["headers"]["Content-Type"] == content_types.TEXT_PLAIN assert result["body"] == "tom" @@ -112,14 +122,14 @@ def test_include_rule_matching(): @app.get("//") def get_lambda(my_id: str, name: str) -> Response: assert name == "my" - return Response(200, TEXT_HTML, my_id) + return Response(200, content_types.TEXT_HTML, my_id) # WHEN calling the event handler result = app(LOAD_GW_EVENT, {}) # THEN assert result["statusCode"] == 200 - assert result["headers"]["Content-Type"] == TEXT_HTML + assert result["headers"]["Content-Type"] == content_types.TEXT_HTML assert result["body"] == "path" @@ -180,11 +190,11 @@ def test_cors(): @app.get("/my/path", cors=True) def with_cors() -> Response: - return Response(200, TEXT_HTML, "test") + return Response(200, content_types.TEXT_HTML, "test") @app.get("/without-cors") def without_cors() -> Response: - return Response(200, TEXT_HTML, "test") + return Response(200, content_types.TEXT_HTML, "test") def handler(event, context): return app.resolve(event, context) @@ -195,7 +205,7 @@ def handler(event, context): # THEN the headers should include cors headers assert "headers" in result headers = result["headers"] - assert headers["Content-Type"] == TEXT_HTML + assert headers["Content-Type"] == content_types.TEXT_HTML assert headers["Access-Control-Allow-Origin"] == "*" assert "Access-Control-Allow-Credentials" not in headers assert headers["Access-Control-Allow-Headers"] == ",".join(sorted(CORSConfig._REQUIRED_HEADERS)) @@ -215,7 +225,7 @@ def test_compress(): @app.get("/my/request", compress=True) def with_compression() -> Response: - return Response(200, APPLICATION_JSON, expected_value) + return Response(200, content_types.APPLICATION_JSON, expected_value) def handler(event, context): return app.resolve(event, context) @@ -261,7 +271,7 @@ def test_compress_no_accept_encoding(): @app.get("/my/path", compress=True) def return_text() -> Response: - return Response(200, "text/plain", expected_value) + return Response(200, content_types.TEXT_PLAIN, expected_value) # WHEN calling the event handler result = app({"path": "/my/path", "httpMethod": "GET", "headers": {}}, None) @@ -277,7 +287,7 @@ def test_cache_control_200(): @app.get("/success", cache_control="max-age=600") def with_cache_control() -> Response: - return Response(200, TEXT_HTML, "has 200 response") + return Response(200, content_types.TEXT_HTML, "has 200 response") def handler(event, context): return app.resolve(event, context) @@ -288,7 +298,7 @@ def handler(event, context): # THEN return the set Cache-Control headers = result["headers"] - assert headers["Content-Type"] == TEXT_HTML + assert headers["Content-Type"] == content_types.TEXT_HTML assert headers["Cache-Control"] == "max-age=600" @@ -298,7 +308,7 @@ def test_cache_control_non_200(): @app.delete("/fails", cache_control="max-age=600") def with_cache_control_has_500() -> Response: - return Response(503, TEXT_HTML, "has 503 response") + return Response(503, content_types.TEXT_HTML, "has 503 response") def handler(event, context): return app.resolve(event, context) @@ -309,7 +319,7 @@ def handler(event, context): # THEN return a Cache-Control of "no-cache" headers = result["headers"] - assert headers["Content-Type"] == TEXT_HTML + assert headers["Content-Type"] == content_types.TEXT_HTML assert headers["Cache-Control"] == "no-cache" @@ -327,7 +337,7 @@ def rest_func() -> Dict: # THEN automatically process this as a json rest api response assert result["statusCode"] == 200 - assert result["headers"]["Content-Type"] == APPLICATION_JSON + assert result["headers"]["Content-Type"] == content_types.APPLICATION_JSON expected_str = json.dumps(expected_dict, separators=(",", ":"), indent=None, cls=Encoder) assert result["body"] == expected_str @@ -382,7 +392,7 @@ def another_one(): # THEN routes by default return the custom cors headers assert "headers" in result headers = result["headers"] - assert headers["Content-Type"] == APPLICATION_JSON + assert headers["Content-Type"] == content_types.APPLICATION_JSON assert headers["Access-Control-Allow-Origin"] == cors_config.allow_origin expected_allows_headers = ",".join(sorted(set(allow_header + cors_config._REQUIRED_HEADERS))) assert headers["Access-Control-Allow-Headers"] == expected_allows_headers @@ -429,6 +439,7 @@ def test_no_matches_with_cors(): # AND cors headers are returned assert result["statusCode"] == 404 assert "Access-Control-Allow-Origin" in result["headers"] + assert "Not found" in result["body"] def test_cors_preflight(): @@ -471,7 +482,7 @@ def test_custom_preflight_response(): def custom_preflight(): return Response( status_code=200, - content_type=TEXT_HTML, + content_type=content_types.TEXT_HTML, body="Foo", headers={"Access-Control-Allow-Methods": "CUSTOM"}, ) @@ -487,6 +498,206 @@ def custom_method(): assert result["statusCode"] == 200 assert result["body"] == "Foo" headers = result["headers"] - assert headers["Content-Type"] == TEXT_HTML + assert headers["Content-Type"] == content_types.TEXT_HTML assert "Access-Control-Allow-Origin" in result["headers"] assert headers["Access-Control-Allow-Methods"] == "CUSTOM" + + +def test_service_error_responses(): + # SCENARIO handling different kind of service errors being raised + app = ApiGatewayResolver(cors=CORSConfig()) + + def json_dump(obj): + return json.dumps(obj, separators=(",", ":")) + + # GIVEN an BadRequestError + @app.get(rule="/bad-request-error", cors=False) + def bad_request_error(): + raise BadRequestError("Missing required parameter") + + # WHEN calling the handler + # AND path is /bad-request-error + result = app({"path": "/bad-request-error", "httpMethod": "GET"}, None) + # THEN return the bad request error response + # AND status code equals 400 + assert result["statusCode"] == 400 + assert result["headers"]["Content-Type"] == content_types.APPLICATION_JSON + expected = {"statusCode": 400, "message": "Missing required parameter"} + assert result["body"] == json_dump(expected) + + # GIVEN an UnauthorizedError + @app.get(rule="/unauthorized-error", cors=False) + def unauthorized_error(): + raise UnauthorizedError("Unauthorized") + + # WHEN calling the handler + # AND path is /unauthorized-error + result = app({"path": "/unauthorized-error", "httpMethod": "GET"}, None) + # THEN return the unauthorized error response + # AND status code equals 401 + assert result["statusCode"] == 401 + assert result["headers"]["Content-Type"] == content_types.APPLICATION_JSON + expected = {"statusCode": 401, "message": "Unauthorized"} + assert result["body"] == json_dump(expected) + + # GIVEN an NotFoundError + @app.get(rule="/not-found-error", cors=False) + def not_found_error(): + raise NotFoundError + + # WHEN calling the handler + # AND path is /not-found-error + result = app({"path": "/not-found-error", "httpMethod": "GET"}, None) + # THEN return the not found error response + # AND status code equals 404 + assert result["statusCode"] == 404 + assert result["headers"]["Content-Type"] == content_types.APPLICATION_JSON + expected = {"statusCode": 404, "message": "Not found"} + assert result["body"] == json_dump(expected) + + # GIVEN an InternalServerError + @app.get(rule="/internal-server-error", cors=False) + def internal_server_error(): + raise InternalServerError("Internal server error") + + # WHEN calling the handler + # AND path is /internal-server-error + result = app({"path": "/internal-server-error", "httpMethod": "GET"}, None) + # THEN return the internal server error response + # AND status code equals 500 + assert result["statusCode"] == 500 + assert result["headers"]["Content-Type"] == content_types.APPLICATION_JSON + expected = {"statusCode": 500, "message": "Internal server error"} + assert result["body"] == json_dump(expected) + + # GIVEN an ServiceError with a custom status code + @app.get(rule="/service-error", cors=True) + def service_error(): + raise ServiceError(502, "Something went wrong!") + + # WHEN calling the handler + # AND path is /service-error + result = app({"path": "/service-error", "httpMethod": "GET"}, None) + # THEN return the service error response + # AND status code equals 502 + assert result["statusCode"] == 502 + assert result["headers"]["Content-Type"] == content_types.APPLICATION_JSON + assert "Access-Control-Allow-Origin" in result["headers"] + expected = {"statusCode": 502, "message": "Something went wrong!"} + assert result["body"] == json_dump(expected) + + +def test_debug_unhandled_exceptions_debug_on(): + # GIVEN debug is enabled + # AND an unhandled exception is raised + app = ApiGatewayResolver(debug=True) + assert app._debug + + @app.get("/raises-error") + def raises_error(): + raise RuntimeError("Foo") + + # WHEN calling the handler + result = app({"path": "/raises-error", "httpMethod": "GET"}, None) + + # THEN return a 500 + # AND Content-Type is set to text/plain + # AND include the exception traceback in the response + assert result["statusCode"] == 500 + assert "Traceback (most recent call last)" in result["body"] + headers = result["headers"] + assert headers["Content-Type"] == content_types.TEXT_PLAIN + + +def test_debug_unhandled_exceptions_debug_off(): + # GIVEN debug is disabled + # AND an unhandled exception is raised + app = ApiGatewayResolver(debug=False) + assert not app._debug + + @app.get("/raises-error") + def raises_error(): + raise RuntimeError("Foo") + + # WHEN calling the handler + # THEN raise the original exception + with pytest.raises(RuntimeError) as e: + app({"path": "/raises-error", "httpMethod": "GET"}, None) + + # AND include the original error + assert e.value.args == ("Foo",) + + +def test_debug_mode_environment_variable(monkeypatch): + # GIVEN a debug mode environment variable is set + monkeypatch.setenv(constants.EVENT_HANDLER_DEBUG_ENV, "true") + app = ApiGatewayResolver() + + # WHEN calling app._debug + # THEN the debug mode is enabled + assert app._debug + + +def test_debug_json_formatting(): + # GIVEN debug is True + app = ApiGatewayResolver(debug=True) + response = {"message": "Foo"} + + @app.get("/foo") + def foo(): + return response + + # WHEN calling the handler + result = app({"path": "/foo", "httpMethod": "GET"}, None) + + # THEN return a pretty print json in the body + assert result["body"] == json.dumps(response, indent=4) + + +def test_debug_print_event(capsys): + # GIVE debug is True + app = ApiGatewayResolver(debug=True) + + # WHEN calling resolve + event = {"path": "/foo", "httpMethod": "GET"} + app(event, None) + + # THEN print the event + out, err = capsys.readouterr() + assert json.loads(out) == event + + +def test_similar_dynamic_routes(): + # GIVEN + app = ApiGatewayResolver() + event = deepcopy(LOAD_GW_EVENT) + + # WHEN + # r'^/accounts/(?P\\w+\\b)$' # noqa: E800 + @app.get("/accounts/") + def get_account(account_id: str): + assert account_id == "single_account" + + # r'^/accounts/(?P\\w+\\b)/source_networks$' # noqa: E800 + @app.get("/accounts//source_networks") + def get_account_networks(account_id: str): + assert account_id == "nested_account" + + # r'^/accounts/(?P\\w+\\b)/source_networks/(?P\\w+\\b)$' # noqa: E800 + @app.get("/accounts//source_networks/") + def get_network_account(account_id: str, network_id: str): + assert account_id == "nested_account" + assert network_id == "network" + + # THEN + event["resource"] = "/accounts/{account_id}" + event["path"] = "/accounts/single_account" + app.resolve(event, None) + + event["resource"] = "/accounts/{account_id}/source_networks" + event["path"] = "/accounts/nested_account/source_networks" + app.resolve(event, None) + + event["resource"] = "/accounts/{account_id}/source_networks/{network_id}" + event["path"] = "/accounts/nested_account/source_networks/network" + app.resolve(event, {}) diff --git a/tests/functional/event_handler/test_appsync.py b/tests/functional/event_handler/test_appsync.py index e260fef89ab..26a3ffdcb1f 100644 --- a/tests/functional/event_handler/test_appsync.py +++ b/tests/functional/event_handler/test_appsync.py @@ -138,3 +138,26 @@ async def get_async(): # THEN assert asyncio.run(result) == "value" + + +def test_resolve_custom_data_model(): + # Check whether we can handle an example appsync direct resolver + mock_event = load_event("appSyncDirectResolver.json") + + class MyCustomModel(AppSyncResolverEvent): + @property + def country_viewer(self): + return self.request_headers.get("cloudfront-viewer-country") + + app = AppSyncResolver() + + @app.resolver(field_name="createSomething") + def create_something(id: str): # noqa AA03 VNE003 + return id + + # Call the implicit handler + result = app(event=mock_event, context=LambdaContext(), data_model=MyCustomModel) + + assert result == "my identifier" + + assert app.current_event.country_viewer == "US" diff --git a/tests/functional/feature_toggles/__init__.py b/tests/functional/feature_toggles/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/functional/feature_toggles/test_feature_toggles.py b/tests/functional/feature_toggles/test_feature_toggles.py new file mode 100644 index 00000000000..bb4b8f24dfc --- /dev/null +++ b/tests/functional/feature_toggles/test_feature_toggles.py @@ -0,0 +1,503 @@ +from typing import Dict, List + +import pytest +from botocore.config import Config + +from aws_lambda_powertools.utilities.feature_toggles import ConfigurationError, schema +from aws_lambda_powertools.utilities.feature_toggles.appconfig_fetcher import AppConfigFetcher +from aws_lambda_powertools.utilities.feature_toggles.configuration_store import ConfigurationStore +from aws_lambda_powertools.utilities.feature_toggles.schema import ACTION +from aws_lambda_powertools.utilities.parameters import GetParameterError + + +@pytest.fixture(scope="module") +def config(): + return Config(region_name="us-east-1") + + +def init_configuration_store(mocker, mock_schema: Dict, config: Config) -> ConfigurationStore: + mocked_get_conf = mocker.patch("aws_lambda_powertools.utilities.parameters.AppConfigProvider.get") + mocked_get_conf.return_value = mock_schema + + app_conf_fetcher = AppConfigFetcher( + environment="test_env", + service="test_app", + configuration_name="test_conf_name", + cache_seconds=600, + config=config, + ) + conf_store: ConfigurationStore = ConfigurationStore(schema_fetcher=app_conf_fetcher) + return conf_store + + +def init_fetcher_side_effect(mocker, config: Config, side_effect) -> AppConfigFetcher: + mocked_get_conf = mocker.patch("aws_lambda_powertools.utilities.parameters.AppConfigProvider.get") + mocked_get_conf.side_effect = side_effect + return AppConfigFetcher( + environment="env", + service="service", + configuration_name="conf", + cache_seconds=1, + config=config, + ) + + +# this test checks that we get correct value of feature that exists in the schema. +# we also don't send an empty rules_context dict in this case +def test_toggles_rule_does_not_match(mocker, config): + expected_value = True + mocked_app_config_schema = { + "features": { + "my_feature": { + "feature_default_value": expected_value, + "rules": [ + { + "rule_name": "tenant id equals 345345435", + "value_when_applies": False, + "conditions": [ + { + "action": ACTION.EQUALS.value, + "key": "tenant_id", + "value": "345345435", + } + ], + }, + ], + } + }, + } + + conf_store = init_configuration_store(mocker, mocked_app_config_schema, config) + toggle = conf_store.get_feature_toggle(feature_name="my_feature", rules_context={}, value_if_missing=False) + assert toggle == expected_value + + +# this test checks that if you try to get a feature that doesn't exist in the schema, +# you get the default value of False that was sent to the get_feature_toggle API +def test_toggles_no_conditions_feature_does_not_exist(mocker, config): + expected_value = False + mocked_app_config_schema = {"features": {"my_fake_feature": {"feature_default_value": True}}} + + conf_store = init_configuration_store(mocker, mocked_app_config_schema, config) + toggle = conf_store.get_feature_toggle(feature_name="my_feature", rules_context={}, value_if_missing=expected_value) + assert toggle == expected_value + + +# check that feature match works when they are no rules and we send rules_context. +# default value is False but the feature has a True default_value. +def test_toggles_no_rules(mocker, config): + expected_value = True + mocked_app_config_schema = {"features": {"my_feature": {"feature_default_value": expected_value}}} + conf_store = init_configuration_store(mocker, mocked_app_config_schema, config) + toggle = conf_store.get_feature_toggle( + feature_name="my_feature", rules_context={"tenant_id": "6", "username": "a"}, value_if_missing=False + ) + assert toggle == expected_value + + +# check a case where the feature exists but the rule doesn't match so we revert to the default value of the feature +def test_toggles_conditions_no_match(mocker, config): + expected_value = True + mocked_app_config_schema = { + "features": { + "my_feature": { + "feature_default_value": expected_value, + "rules": [ + { + "rule_name": "tenant id equals 345345435", + "value_when_applies": False, + "conditions": [ + { + "action": ACTION.EQUALS.value, + "key": "tenant_id", + "value": "345345435", + } + ], + }, + ], + } + }, + } + conf_store = init_configuration_store(mocker, mocked_app_config_schema, config) + toggle = conf_store.get_feature_toggle( + feature_name="my_feature", + rules_context={"tenant_id": "6", "username": "a"}, # rule will not match + value_if_missing=False, + ) + assert toggle == expected_value + + +# check that a rule can match when it has multiple conditions, see rule name for further explanation +def test_toggles_conditions_rule_match_equal_multiple_conditions(mocker, config): + expected_value = False + tenant_id_val = "6" + username_val = "a" + mocked_app_config_schema = { + "features": { + "my_feature": { + "feature_default_value": True, + "rules": [ + { + "rule_name": "tenant id equals 6 and username is a", + "value_when_applies": expected_value, + "conditions": [ + { + "action": ACTION.EQUALS.value, # this rule will match, it has multiple conditions + "key": "tenant_id", + "value": tenant_id_val, + }, + { + "action": ACTION.EQUALS.value, + "key": "username", + "value": username_val, + }, + ], + }, + ], + } + }, + } + conf_store = init_configuration_store(mocker, mocked_app_config_schema, config) + toggle = conf_store.get_feature_toggle( + feature_name="my_feature", + rules_context={ + "tenant_id": tenant_id_val, + "username": username_val, + }, + value_if_missing=True, + ) + assert toggle == expected_value + + +# check a case when rule doesn't match and it has multiple conditions, +# different tenant id causes the rule to not match. +# default value of the feature in this case is True +def test_toggles_conditions_no_rule_match_equal_multiple_conditions(mocker, config): + expected_val = True + mocked_app_config_schema = { + "features": { + "my_feature": { + "feature_default_value": expected_val, + "rules": [ + { + "rule_name": "tenant id equals 645654 and username is a", # rule will not match + "value_when_applies": False, + "conditions": [ + { + "action": ACTION.EQUALS.value, + "key": "tenant_id", + "value": "645654", + }, + { + "action": ACTION.EQUALS.value, + "key": "username", + "value": "a", + }, + ], + }, + ], + } + }, + } + conf_store = init_configuration_store(mocker, mocked_app_config_schema, config) + toggle = conf_store.get_feature_toggle( + feature_name="my_feature", rules_context={"tenant_id": "6", "username": "a"}, value_if_missing=False + ) + assert toggle == expected_val + + +# check rule match for multiple of action types +def test_toggles_conditions_rule_match_multiple_actions_multiple_rules_multiple_conditions(mocker, config): + expected_value_first_check = True + expected_value_second_check = False + expected_value_third_check = False + expected_value_fourth_case = False + mocked_app_config_schema = { + "features": { + "my_feature": { + "feature_default_value": expected_value_third_check, + "rules": [ + { + "rule_name": "tenant id equals 6 and username startswith a", + "value_when_applies": expected_value_first_check, + "conditions": [ + { + "action": ACTION.EQUALS.value, + "key": "tenant_id", + "value": "6", + }, + { + "action": ACTION.STARTSWITH.value, + "key": "username", + "value": "a", + }, + ], + }, + { + "rule_name": "tenant id equals 4446 and username startswith a and endswith z", + "value_when_applies": expected_value_second_check, + "conditions": [ + { + "action": ACTION.EQUALS.value, + "key": "tenant_id", + "value": "4446", + }, + { + "action": ACTION.STARTSWITH.value, + "key": "username", + "value": "a", + }, + { + "action": ACTION.ENDSWITH.value, + "key": "username", + "value": "z", + }, + ], + }, + ], + } + }, + } + + conf_store = init_configuration_store(mocker, mocked_app_config_schema, config) + # match first rule + toggle = conf_store.get_feature_toggle( + feature_name="my_feature", + rules_context={"tenant_id": "6", "username": "abcd"}, + value_if_missing=False, + ) + assert toggle == expected_value_first_check + # match second rule + toggle = conf_store.get_feature_toggle( + feature_name="my_feature", + rules_context={"tenant_id": "4446", "username": "az"}, + value_if_missing=False, + ) + assert toggle == expected_value_second_check + # match no rule + toggle = conf_store.get_feature_toggle( + feature_name="my_feature", + rules_context={"tenant_id": "11114446", "username": "ab"}, + value_if_missing=False, + ) + assert toggle == expected_value_third_check + # feature doesn't exist + toggle = conf_store.get_feature_toggle( + feature_name="my_fake_feature", + rules_context={"tenant_id": "11114446", "username": "ab"}, + value_if_missing=expected_value_fourth_case, + ) + assert toggle == expected_value_fourth_case + + +# check a case where the feature exists but the rule doesn't match so we revert to the default value of the feature +def test_toggles_match_rule_with_contains_action(mocker, config): + expected_value = True + mocked_app_config_schema = { + "features": { + "my_feature": { + "feature_default_value": False, + "rules": [ + { + "rule_name": "tenant id is contained in [6,2] ", + "value_when_applies": expected_value, + "conditions": [ + { + "action": ACTION.CONTAINS.value, + "key": "tenant_id", + "value": ["6", "2"], + } + ], + }, + ], + } + }, + } + conf_store = init_configuration_store(mocker, mocked_app_config_schema, config) + toggle = conf_store.get_feature_toggle( + feature_name="my_feature", + rules_context={"tenant_id": "6", "username": "a"}, # rule will match + value_if_missing=False, + ) + assert toggle == expected_value + + +def test_toggles_no_match_rule_with_contains_action(mocker, config): + expected_value = False + mocked_app_config_schema = { + "features": { + "my_feature": { + "feature_default_value": expected_value, + "rules": [ + { + "rule_name": "tenant id is contained in [6,2] ", + "value_when_applies": True, + "conditions": [ + { + "action": ACTION.CONTAINS.value, + "key": "tenant_id", + "value": ["8", "2"], + } + ], + }, + ], + } + }, + } + conf_store = init_configuration_store(mocker, mocked_app_config_schema, config) + toggle = conf_store.get_feature_toggle( + feature_name="my_feature", + rules_context={"tenant_id": "6", "username": "a"}, # rule will not match + value_if_missing=False, + ) + assert toggle == expected_value + + +def test_multiple_features_enabled(mocker, config): + expected_value = ["my_feature", "my_feature2"] + mocked_app_config_schema = { + "features": { + "my_feature": { + "feature_default_value": False, + "rules": [ + { + "rule_name": "tenant id is contained in [6,2] ", + "value_when_applies": True, + "conditions": [ + { + "action": ACTION.CONTAINS.value, + "key": "tenant_id", + "value": ["6", "2"], + } + ], + }, + ], + }, + "my_feature2": { + "feature_default_value": True, + }, + }, + } + conf_store = init_configuration_store(mocker, mocked_app_config_schema, config) + enabled_list: List[str] = conf_store.get_all_enabled_feature_toggles( + rules_context={"tenant_id": "6", "username": "a"} + ) + assert enabled_list == expected_value + + +def test_multiple_features_only_some_enabled(mocker, config): + expected_value = ["my_feature", "my_feature2", "my_feature4"] + mocked_app_config_schema = { + "features": { + "my_feature": { # rule will match here, feature is enabled due to rule match + "feature_default_value": False, + "rules": [ + { + "rule_name": "tenant id is contained in [6,2] ", + "value_when_applies": True, + "conditions": [ + { + "action": ACTION.CONTAINS.value, + "key": "tenant_id", + "value": ["6", "2"], + } + ], + }, + ], + }, + "my_feature2": { + "feature_default_value": True, + }, + "my_feature3": { + "feature_default_value": False, + }, + "my_feature4": { # rule will not match here, feature is enabled by default + "feature_default_value": True, + "rules": [ + { + "rule_name": "tenant id equals 7", + "value_when_applies": False, + "conditions": [ + { + "action": ACTION.EQUALS.value, + "key": "tenant_id", + "value": "7", + } + ], + }, + ], + }, + }, + } + conf_store = init_configuration_store(mocker, mocked_app_config_schema, config) + enabled_list: List[str] = conf_store.get_all_enabled_feature_toggles( + rules_context={"tenant_id": "6", "username": "a"} + ) + assert enabled_list == expected_value + + +def test_get_feature_toggle_handles_error(mocker, config): + # GIVEN a schema fetch that raises a ConfigurationError + schema_fetcher = init_fetcher_side_effect(mocker, config, GetParameterError()) + conf_store = ConfigurationStore(schema_fetcher) + + # WHEN calling get_feature_toggle + toggle = conf_store.get_feature_toggle(feature_name="Foo", value_if_missing=False) + + # THEN handle the error and return the value_if_missing + assert toggle is False + + +def test_get_all_enabled_feature_toggles_handles_error(mocker, config): + # GIVEN a schema fetch that raises a ConfigurationError + schema_fetcher = init_fetcher_side_effect(mocker, config, GetParameterError()) + conf_store = ConfigurationStore(schema_fetcher) + + # WHEN calling get_all_enabled_feature_toggles + toggles = conf_store.get_all_enabled_feature_toggles(rules_context=None) + + # THEN handle the error and return an empty list + assert toggles == [] + + +def test_app_config_get_parameter_err(mocker, config): + # GIVEN an appconfig with a missing config + app_conf_fetcher = init_fetcher_side_effect(mocker, config, GetParameterError()) + + # WHEN calling get_json_configuration + with pytest.raises(ConfigurationError) as err: + app_conf_fetcher.get_json_configuration() + + # THEN raise ConfigurationError error + assert "AWS AppConfig configuration" in str(err.value) + + +def test_match_by_action_no_matching_action(mocker, config): + # GIVEN an unsupported action + conf_store = init_configuration_store(mocker, {}, config) + # WHEN calling _match_by_action + result = conf_store._match_by_action("Foo", None, "foo") + # THEN default to False + assert result is False + + +def test_match_by_action_attribute_error(mocker, config): + # GIVEN a startswith action and 2 integer + conf_store = init_configuration_store(mocker, {}, config) + # WHEN calling _match_by_action + result = conf_store._match_by_action(ACTION.STARTSWITH.value, 1, 100) + # THEN swallow the AttributeError and return False + assert result is False + + +def test_is_rule_matched_no_matches(mocker, config): + # GIVEN an empty list of conditions + rule = {schema.CONDITIONS_KEY: []} + rules_context = {} + conf_store = init_configuration_store(mocker, {}, config) + + # WHEN calling _is_rule_matched + result = conf_store._is_rule_matched("feature_name", rule, rules_context) + + # THEN return False + assert result is False diff --git a/tests/functional/feature_toggles/test_schema_validation.py b/tests/functional/feature_toggles/test_schema_validation.py new file mode 100644 index 00000000000..184f448322a --- /dev/null +++ b/tests/functional/feature_toggles/test_schema_validation.py @@ -0,0 +1,330 @@ +import logging + +import pytest # noqa: F401 + +from aws_lambda_powertools.utilities.feature_toggles.exceptions import ConfigurationError +from aws_lambda_powertools.utilities.feature_toggles.schema import ( + ACTION, + CONDITION_ACTION, + CONDITION_KEY, + CONDITION_VALUE, + CONDITIONS_KEY, + FEATURE_DEFAULT_VAL_KEY, + FEATURES_KEY, + RULE_DEFAULT_VALUE, + RULE_NAME_KEY, + RULES_KEY, + SchemaValidator, +) + +logger = logging.getLogger(__name__) + + +def test_invalid_features_dict(): + schema = {} + # empty dict + validator = SchemaValidator(logger) + with pytest.raises(ConfigurationError): + validator.validate_json_schema(schema) + + schema = [] + # invalid type + with pytest.raises(ConfigurationError): + validator.validate_json_schema(schema) + + # invalid features key + schema = {FEATURES_KEY: []} + with pytest.raises(ConfigurationError): + validator.validate_json_schema(schema) + + +def test_empty_features_not_fail(): + schema = {FEATURES_KEY: {}} + validator = SchemaValidator(logger) + validator.validate_json_schema(schema) + + +def test_invalid_feature_dict(): + # invalid feature type, not dict + schema = {FEATURES_KEY: {"my_feature": []}} + validator = SchemaValidator(logger) + with pytest.raises(ConfigurationError): + validator.validate_json_schema(schema) + + # empty feature dict + schema = {FEATURES_KEY: {"my_feature": {}}} + with pytest.raises(ConfigurationError): + validator.validate_json_schema(schema) + + # invalid FEATURE_DEFAULT_VAL_KEY type, not boolean + schema = {FEATURES_KEY: {"my_feature": {FEATURE_DEFAULT_VAL_KEY: "False"}}} + with pytest.raises(ConfigurationError): + validator.validate_json_schema(schema) + + # invalid FEATURE_DEFAULT_VAL_KEY type, not boolean #2 + schema = {FEATURES_KEY: {"my_feature": {FEATURE_DEFAULT_VAL_KEY: 5}}} + with pytest.raises(ConfigurationError): + validator.validate_json_schema(schema) + + # invalid rules type, not list + schema = {FEATURES_KEY: {"my_feature": {FEATURE_DEFAULT_VAL_KEY: False, RULES_KEY: "4"}}} + with pytest.raises(ConfigurationError): + validator.validate_json_schema(schema) + + +def test_valid_feature_dict(): + # no rules list at all + schema = {FEATURES_KEY: {"my_feature": {FEATURE_DEFAULT_VAL_KEY: False}}} + validator = SchemaValidator(logger) + validator.validate_json_schema(schema) + + # empty rules list + schema = {FEATURES_KEY: {"my_feature": {FEATURE_DEFAULT_VAL_KEY: False, RULES_KEY: []}}} + validator.validate_json_schema(schema) + + +def test_invalid_rule(): + # rules list is not a list of dict + schema = { + FEATURES_KEY: { + "my_feature": { + FEATURE_DEFAULT_VAL_KEY: False, + RULES_KEY: [ + "a", + "b", + ], + } + } + } + validator = SchemaValidator(logger) + with pytest.raises(ConfigurationError): + validator.validate_json_schema(schema) + + # rules RULE_DEFAULT_VALUE is not bool + schema = { + FEATURES_KEY: { + "my_feature": { + FEATURE_DEFAULT_VAL_KEY: False, + RULES_KEY: [ + { + RULE_NAME_KEY: "tenant id equals 345345435", + RULE_DEFAULT_VALUE: "False", + }, + ], + } + } + } + with pytest.raises(ConfigurationError): + validator.validate_json_schema(schema) + + # missing conditions list + schema = { + FEATURES_KEY: { + "my_feature": { + FEATURE_DEFAULT_VAL_KEY: False, + RULES_KEY: [ + { + RULE_NAME_KEY: "tenant id equals 345345435", + RULE_DEFAULT_VALUE: False, + }, + ], + } + } + } + with pytest.raises(ConfigurationError): + validator.validate_json_schema(schema) + + # condition list is empty + schema = { + FEATURES_KEY: { + "my_feature": { + FEATURE_DEFAULT_VAL_KEY: False, + RULES_KEY: [ + {RULE_NAME_KEY: "tenant id equals 345345435", RULE_DEFAULT_VALUE: False, CONDITIONS_KEY: []}, + ], + } + } + } + with pytest.raises(ConfigurationError): + validator.validate_json_schema(schema) + + # condition is invalid type, not list + schema = { + FEATURES_KEY: { + "my_feature": { + FEATURE_DEFAULT_VAL_KEY: False, + RULES_KEY: [ + {RULE_NAME_KEY: "tenant id equals 345345435", RULE_DEFAULT_VALUE: False, CONDITIONS_KEY: {}}, + ], + } + } + } + with pytest.raises(ConfigurationError): + validator.validate_json_schema(schema) + + +def test_invalid_condition(): + # invalid condition action + schema = { + FEATURES_KEY: { + "my_feature": { + FEATURE_DEFAULT_VAL_KEY: False, + RULES_KEY: [ + { + RULE_NAME_KEY: "tenant id equals 345345435", + RULE_DEFAULT_VALUE: False, + CONDITIONS_KEY: {CONDITION_ACTION: "stuff", CONDITION_KEY: "a", CONDITION_VALUE: "a"}, + }, + ], + } + } + } + validator = SchemaValidator(logger) + with pytest.raises(ConfigurationError): + validator.validate_json_schema(schema) + + # missing condition key and value + schema = { + FEATURES_KEY: { + "my_feature": { + FEATURE_DEFAULT_VAL_KEY: False, + RULES_KEY: [ + { + RULE_NAME_KEY: "tenant id equals 345345435", + RULE_DEFAULT_VALUE: False, + CONDITIONS_KEY: {CONDITION_ACTION: ACTION.EQUALS.value}, + }, + ], + } + } + } + with pytest.raises(ConfigurationError): + validator.validate_json_schema(schema) + + # invalid condition key type, not string + schema = { + FEATURES_KEY: { + "my_feature": { + FEATURE_DEFAULT_VAL_KEY: False, + RULES_KEY: [ + { + RULE_NAME_KEY: "tenant id equals 345345435", + RULE_DEFAULT_VALUE: False, + CONDITIONS_KEY: { + CONDITION_ACTION: ACTION.EQUALS.value, + CONDITION_KEY: 5, + CONDITION_VALUE: "a", + }, + }, + ], + } + } + } + with pytest.raises(ConfigurationError): + validator.validate_json_schema(schema) + + +def test_valid_condition_all_actions(): + validator = SchemaValidator(logger) + schema = { + FEATURES_KEY: { + "my_feature": { + FEATURE_DEFAULT_VAL_KEY: False, + RULES_KEY: [ + { + RULE_NAME_KEY: "tenant id equals 645654 and username is a", + RULE_DEFAULT_VALUE: True, + CONDITIONS_KEY: [ + { + CONDITION_ACTION: ACTION.EQUALS.value, + CONDITION_KEY: "tenant_id", + CONDITION_VALUE: "645654", + }, + { + CONDITION_ACTION: ACTION.STARTSWITH.value, + CONDITION_KEY: "username", + CONDITION_VALUE: "a", + }, + { + CONDITION_ACTION: ACTION.ENDSWITH.value, + CONDITION_KEY: "username", + CONDITION_VALUE: "a", + }, + { + CONDITION_ACTION: ACTION.CONTAINS.value, + CONDITION_KEY: "username", + CONDITION_VALUE: ["a", "b"], + }, + ], + }, + ], + } + }, + } + validator.validate_json_schema(schema) + + +def test_validate_condition_invalid_condition_type(): + # GIVEN an invalid condition type of empty dict + validator = SchemaValidator(logger) + condition = {} + + # WHEN calling _validate_condition + with pytest.raises(ConfigurationError) as err: + validator._validate_condition("foo", condition) + + # THEN raise ConfigurationError + assert "invalid condition type" in str(err) + + +def test_validate_condition_invalid_condition_action(): + # GIVEN an invalid condition action of foo + validator = SchemaValidator(logger) + condition = {"action": "foo"} + + # WHEN calling _validate_condition + with pytest.raises(ConfigurationError) as err: + validator._validate_condition("foo", condition) + + # THEN raise ConfigurationError + assert "invalid action value" in str(err) + + +def test_validate_condition_invalid_condition_key(): + # GIVEN a configuration with a missing "key" + validator = SchemaValidator(logger) + condition = {"action": ACTION.EQUALS.value} + + # WHEN calling _validate_condition + with pytest.raises(ConfigurationError) as err: + validator._validate_condition("foo", condition) + + # THEN raise ConfigurationError + assert "invalid key value" in str(err) + + +def test_validate_condition_missing_condition_value(): + # GIVEN a configuration with a missing condition value + validator = SchemaValidator(logger) + condition = {"action": ACTION.EQUALS.value, "key": "Foo"} + + # WHEN calling _validate_condition + with pytest.raises(ConfigurationError) as err: + validator._validate_condition("foo", condition) + + # THEN raise ConfigurationError + assert "missing condition value" in str(err) + + +def test_validate_rule_invalid_rule_name(): + # GIVEN a rule_name not in the rule dict + validator = SchemaValidator(logger) + rule_name = "invalid_rule_name" + rule = {"missing": ""} + + # WHEN calling _validate_rule + with pytest.raises(ConfigurationError) as err: + validator._validate_rule(rule_name, rule) + + # THEN raise ConfigurationError + assert "invalid rule_name" in str(err) diff --git a/tests/functional/py.typed b/tests/functional/py.typed new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/functional/test_data_classes.py b/tests/functional/test_data_classes.py index 8b412860694..cbbaf834379 100644 --- a/tests/functional/test_data_classes.py +++ b/tests/functional/test_data_classes.py @@ -743,6 +743,70 @@ def test_seq_trigger_event(): assert record.aws_region == "us-east-2" +def test_default_api_gateway_proxy_event(): + event = APIGatewayProxyEvent(load_event("apiGatewayProxyEvent_noVersionAuth.json")) + + assert event.get("version") is None + assert event.resource == event["resource"] + assert event.path == event["path"] + assert event.http_method == event["httpMethod"] + assert event.headers == event["headers"] + assert event.multi_value_headers == event["multiValueHeaders"] + assert event.query_string_parameters == event["queryStringParameters"] + assert event.multi_value_query_string_parameters == event["multiValueQueryStringParameters"] + + request_context = event.request_context + assert request_context.account_id == event["requestContext"]["accountId"] + assert request_context.api_id == event["requestContext"]["apiId"] + + assert request_context.get("authorizer") is None + + assert request_context.domain_name == event["requestContext"]["domainName"] + assert request_context.domain_prefix == event["requestContext"]["domainPrefix"] + assert request_context.extended_request_id == event["requestContext"]["extendedRequestId"] + assert request_context.http_method == event["requestContext"]["httpMethod"] + + identity = request_context.identity + assert identity.access_key == event["requestContext"]["identity"]["accessKey"] + assert identity.account_id == event["requestContext"]["identity"]["accountId"] + assert identity.caller == event["requestContext"]["identity"]["caller"] + assert ( + identity.cognito_authentication_provider == event["requestContext"]["identity"]["cognitoAuthenticationProvider"] + ) + assert identity.cognito_authentication_type == event["requestContext"]["identity"]["cognitoAuthenticationType"] + assert identity.cognito_identity_id == event["requestContext"]["identity"]["cognitoIdentityId"] + assert identity.cognito_identity_pool_id == event["requestContext"]["identity"]["cognitoIdentityPoolId"] + assert identity.principal_org_id == event["requestContext"]["identity"]["principalOrgId"] + assert identity.source_ip == event["requestContext"]["identity"]["sourceIp"] + assert identity.user == event["requestContext"]["identity"]["user"] + assert identity.user_agent == event["requestContext"]["identity"]["userAgent"] + assert identity.user_arn == event["requestContext"]["identity"]["userArn"] + + assert request_context.path == event["requestContext"]["path"] + assert request_context.protocol == event["requestContext"]["protocol"] + assert request_context.request_id == event["requestContext"]["requestId"] + assert request_context.request_time == event["requestContext"]["requestTime"] + assert request_context.request_time_epoch == event["requestContext"]["requestTimeEpoch"] + assert request_context.resource_id == event["requestContext"]["resourceId"] + assert request_context.resource_path == event["requestContext"]["resourcePath"] + assert request_context.stage == event["requestContext"]["stage"] + + assert event.path_parameters == event["pathParameters"] + assert event.stage_variables == event["stageVariables"] + assert event.body == event["body"] + assert event.is_base64_encoded == event["isBase64Encoded"] + + assert request_context.connected_at is None + assert request_context.connection_id is None + assert request_context.event_type is None + assert request_context.message_direction is None + assert request_context.message_id is None + assert request_context.route_key is None + assert request_context.operation_name is None + assert identity.api_key is None + assert identity.api_key_id is None + + def test_api_gateway_proxy_event(): event = APIGatewayProxyEvent(load_event("apiGatewayProxyEvent.json")) @@ -1210,13 +1274,18 @@ def test_aws_date_utc(): def test_aws_time_utc(): time_str = aws_time() assert isinstance(time_str, str) - assert datetime.datetime.strptime(time_str, "%H:%M:%SZ") + assert datetime.datetime.strptime(time_str, "%H:%M:%S.%fZ") def test_aws_datetime_utc(): datetime_str = aws_datetime() - assert isinstance(datetime_str, str) - assert datetime.datetime.strptime(datetime_str, "%Y-%m-%dT%H:%M:%SZ") + assert datetime.datetime.strptime(datetime_str[:-1] + "000Z", "%Y-%m-%dT%H:%M:%S.%fZ") + + +def test_format_time_to_milli(): + now = datetime.datetime(2024, 4, 23, 16, 26, 34, 123021) + datetime_str = _formatted_time(now, "%H:%M:%S.%f", -12) + assert datetime_str == "04:26:34.123-12:00:00" def test_aws_timestamp(): @@ -1227,14 +1296,12 @@ def test_aws_timestamp(): def test_format_time_positive(): now = datetime.datetime(2022, 1, 22) datetime_str = _formatted_time(now, "%Y-%m-%d", 8) - assert isinstance(datetime_str, str) assert datetime_str == "2022-01-22+08:00:00" def test_format_time_negative(): now = datetime.datetime(2022, 1, 22, 14, 22, 33) datetime_str = _formatted_time(now, "%H:%M:%S", -12) - assert isinstance(datetime_str, str) assert datetime_str == "02:22:33-12:00:00" diff --git a/tests/functional/test_logger.py b/tests/functional/test_logger.py index 44249af6250..a8d92c05257 100644 --- a/tests/functional/test_logger.py +++ b/tests/functional/test_logger.py @@ -460,6 +460,18 @@ def handler(event, _): assert request_id == log["correlation_id"] +def test_logger_get_correlation_id(lambda_context, stdout, service_name): + # GIVEN a logger with a correlation_id set + logger = Logger(service=service_name, stream=stdout) + logger.set_correlation_id("foo") + + # WHEN calling get_correlation_id + correlation_id = logger.get_correlation_id() + + # THEN it should return the correlation_id + assert "foo" == correlation_id + + def test_logger_set_correlation_id_path(lambda_context, stdout, service_name): # GIVEN logger = Logger(service=service_name, stream=stdout)