diff --git a/.gitignore b/.gitignore index b4e4ffd6..7bb319c2 100644 --- a/.gitignore +++ b/.gitignore @@ -22,3 +22,4 @@ pip-selfcheck.json htmlcov venv +.idea diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 66d9b0f4..3dab014c 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -5,6 +5,20 @@ CHANGELOG Unreleased ========== +2.11.0 +========== +* bugfix: Fix TypeError by patching register_default_jsonb from psycopg2 `PR350 https://github.com/aws/aws-xray-sdk-python/pull/350` +* improvement: Add annotations `PR348 https://github.com/aws/aws-xray-sdk-python/pull/348` +* bugfix: Use service parameter to match centralized sampling rules `PR 353 https://github.com/aws/aws-xray-sdk-python/pull/353` +* bugfix: Implement PEP3134 to discover underlying problems with python3 `PR355 https://github.com/aws/aws-xray-sdk-python/pull/355` +* improvement: Allow list TopicArn for SNS PublishBatch request `PR358 https://github.com/aws/aws-xray-sdk-python/pull/358` +* bugfix: Version pinning flask-sqlalchemy version to 2.5.1 or less `PR360 https://github.com/aws/aws-xray-sdk-python/pull/360` +* bugfix: Fix UnboundLocalError when aiohttp server raises a CancelledError `PR356 https://github.com/aws/aws-xray-sdk-python/pull/356` +* improvement: Instrument httpx >= 0.20 `PR357 https://github.com/aws/aws-xray-sdk-python/pull/357` +* improvement: [LambdaContext] persist original trace header `PR362 https://github.com/aws/aws-xray-sdk-python/pull/362` +* bugfix: Run tests against Django 4.x `PR361 https://github.com/aws/aws-xray-sdk-python/pull/361` +* improvement: Oversampling Mitigation `PR366 https://github.com/aws/aws-xray-sdk-python/pull/366` + 2.10.0 ========== * bugfix: Only import future for py2. `PR343 `_. diff --git a/README.md b/README.md index cb31e5cf..0788d13e 100644 --- a/README.md +++ b/README.md @@ -383,6 +383,19 @@ If `AUTO_PATCH_PARENT_SEGMENT_NAME` is also specified, then a segment parent wil with the supplied name, wrapping the automatic patching so that it captures any dangling subsegments created on the import patching. +### Django in Lambda +X-Ray can't search on http annotations in subsegments. To enable searching the middleware adds the http values as annotations +This allows searching in the X-Ray console like so + +This is configurable in settings with `URLS_AS_ANNOTATION` that has 3 valid values +`LAMBDA` - the default, which uses URLs as annotations by default if running in a lambda context +`ALL` - do this for every request (useful if running in a mixed lambda/other deployment) +`NONE` - don't do this for any (avoiding hitting the 50 annotation limit) + +``` +annotation.url BEGINSWITH "https://your.url.com/here" +``` + ### Add Flask middleware ```python diff --git a/aws_xray_sdk/core/async_recorder.py b/aws_xray_sdk/core/async_recorder.py index d7b3d198..9e53bb7a 100644 --- a/aws_xray_sdk/core/async_recorder.py +++ b/aws_xray_sdk/core/async_recorder.py @@ -1,4 +1,5 @@ import time +import six from aws_xray_sdk.core.recorder import AWSXRayRecorder from aws_xray_sdk.core.utils import stacktrace @@ -81,10 +82,10 @@ async def record_subsegment_async(self, wrapped, instance, args, kwargs, name, try: return_value = await wrapped(*args, **kwargs) return return_value - except Exception as e: - exception = e + except Exception as exc: + exception = exc stack = stacktrace.get_stacktrace(limit=self._max_trace_back) - raise + six.raise_from(exc, exc) finally: # No-op if subsegment is `None` due to `LOG_ERROR`. if subsegment is not None: diff --git a/aws_xray_sdk/core/lambda_launcher.py b/aws_xray_sdk/core/lambda_launcher.py index f35b2d99..9efccc6b 100644 --- a/aws_xray_sdk/core/lambda_launcher.py +++ b/aws_xray_sdk/core/lambda_launcher.py @@ -142,5 +142,6 @@ def _initialize_context(self, trace_header): entityid=trace_header.parent, sampled=sampled, ) + segment.save_origin_trace_header(trace_header) setattr(self._local, 'segment', segment) setattr(self._local, 'entities', []) diff --git a/aws_xray_sdk/core/models/dummy_entities.py b/aws_xray_sdk/core/models/dummy_entities.py index 9e4a0379..6d962d71 100644 --- a/aws_xray_sdk/core/models/dummy_entities.py +++ b/aws_xray_sdk/core/models/dummy_entities.py @@ -11,7 +11,7 @@ class DummySegment(Segment): the segment based on sampling rules. Adding data to a dummy segment becomes a no-op except for subsegments. This is to reduce the memory footprint of the SDK. - A dummy segment will not be sent to the X-Ray daemon. Manually create + A dummy segment will not be sent to the X-Ray daemon. Manually creating dummy segments is not recommended. """ diff --git a/aws_xray_sdk/core/models/entity.py b/aws_xray_sdk/core/models/entity.py index 41ef3893..3583bf28 100644 --- a/aws_xray_sdk/core/models/entity.py +++ b/aws_xray_sdk/core/models/entity.py @@ -81,6 +81,10 @@ def add_subsegment(self, subsegment): """ self._check_ended() subsegment.parent_id = self.id + + if not self.sampled and subsegment.sampled: + log.warning("This sampled subsegment is being added to an unsampled parent segment/subsegment and will be orphaned.") + self.subsegments.append(subsegment) def remove_subsegment(self, subsegment): diff --git a/aws_xray_sdk/core/patcher.py b/aws_xray_sdk/core/patcher.py index 3319a26a..1a700dd9 100644 --- a/aws_xray_sdk/core/patcher.py +++ b/aws_xray_sdk/core/patcher.py @@ -6,6 +6,7 @@ import re import sys import wrapt +import six from aws_xray_sdk import global_sdk_config from .utils.compat import PY2, is_classmethod, is_instance_method @@ -25,6 +26,7 @@ 'psycopg2', 'pg8000', 'sqlalchemy_core', + 'httpx', ) NO_DOUBLE_PATCH = ( @@ -39,6 +41,7 @@ 'psycopg2', 'pg8000', 'sqlalchemy_core', + 'httpx', ) _PATCHED_MODULES = set() @@ -107,9 +110,9 @@ def patch(modules_to_patch, raise_errors=True, ignore_module_patterns=None): def _patch_module(module_to_patch, raise_errors=True): try: _patch(module_to_patch) - except Exception: + except Exception as exc: if raise_errors: - raise + six.raise_from(exc, exc) log.debug('failed to patch module %s', module_to_patch) diff --git a/aws_xray_sdk/core/recorder.py b/aws_xray_sdk/core/recorder.py index 3169e3a2..ff4f20b5 100644 --- a/aws_xray_sdk/core/recorder.py +++ b/aws_xray_sdk/core/recorder.py @@ -4,6 +4,7 @@ import os import platform import time +import six from aws_xray_sdk import global_sdk_config from aws_xray_sdk.version import VERSION @@ -232,7 +233,7 @@ def begin_segment(self, name=None, traceid=None, elif sampling: decision = sampling elif self.sampling: - decision = self._sampler.should_trace() + decision = self._sampler.should_trace({'service': seg_name}) if not decision: segment = DummySegment(seg_name) @@ -274,16 +275,10 @@ def current_segment(self): else: return entity - def begin_subsegment(self, name, namespace='local'): - """ - Begin a new subsegment. - If there is open subsegment, the newly created subsegment will be the - child of latest opened subsegment. - If not, it will be the child of the current open segment. - - :param str name: the name of the subsegment. - :param str namespace: currently can only be 'local', 'remote', 'aws'. - """ + def _begin_subsegment_helper(self, name, namespace='local', beginWithoutSampling=False): + ''' + Helper method to begin_subsegment and begin_subsegment_without_sampling + ''' # Generating the parent dummy segment is necessary. # We don't need to store anything in context. Assumption here # is that we only work with recorder-level APIs. @@ -294,16 +289,42 @@ def begin_subsegment(self, name, namespace='local'): if not segment: log.warning("No segment found, cannot begin subsegment %s." % name) return None - - if not segment.sampled: + + current_entity = self.get_trace_entity() + if not current_entity.sampled or beginWithoutSampling: subsegment = DummySubsegment(segment, name) else: subsegment = Subsegment(name, namespace, segment) self.context.put_subsegment(subsegment) - return subsegment + + + def begin_subsegment(self, name, namespace='local'): + """ + Begin a new subsegment. + If there is open subsegment, the newly created subsegment will be the + child of latest opened subsegment. + If not, it will be the child of the current open segment. + + :param str name: the name of the subsegment. + :param str namespace: currently can only be 'local', 'remote', 'aws'. + """ + return self._begin_subsegment_helper(name, namespace) + + + def begin_subsegment_without_sampling(self, name): + """ + Begin a new unsampled subsegment. + If there is open subsegment, the newly created subsegment will be the + child of latest opened subsegment. + If not, it will be the child of the current open segment. + + :param str name: the name of the subsegment. + """ + return self._begin_subsegment_helper(name, beginWithoutSampling=True) + def current_subsegment(self): """ Return the latest opened subsegment. In a multithreading environment, @@ -435,10 +456,10 @@ def record_subsegment(self, wrapped, instance, args, kwargs, name, try: return_value = wrapped(*args, **kwargs) return return_value - except Exception as e: - exception = e + except Exception as exc: + exception = exc stack = stacktrace.get_stacktrace(limit=self.max_trace_back) - raise + six.raise_from(exc, exc) finally: # No-op if subsegment is `None` due to `LOG_ERROR`. if subsegment is not None: @@ -486,7 +507,8 @@ def _send_segment(self): def _stream_subsegment_out(self, subsegment): log.debug("streaming subsegments...") - self.emitter.send_entity(subsegment) + if subsegment.sampled: + self.emitter.send_entity(subsegment) def _load_sampling_rules(self, sampling_rules): diff --git a/aws_xray_sdk/core/utils/sqs_message_helper.py b/aws_xray_sdk/core/utils/sqs_message_helper.py new file mode 100644 index 00000000..f2a1a1c8 --- /dev/null +++ b/aws_xray_sdk/core/utils/sqs_message_helper.py @@ -0,0 +1,11 @@ +SQS_XRAY_HEADER = "AWSTraceHeader" +class SqsMessageHelper: + + @staticmethod + def isSampled(sqs_message): + attributes = sqs_message['attributes'] + + if SQS_XRAY_HEADER not in attributes: + return False + + return 'Sampled=1' in attributes[SQS_XRAY_HEADER] \ No newline at end of file diff --git a/aws_xray_sdk/ext/aiohttp/middleware.py b/aws_xray_sdk/ext/aiohttp/middleware.py index a58316fc..cc54c482 100644 --- a/aws_xray_sdk/ext/aiohttp/middleware.py +++ b/aws_xray_sdk/ext/aiohttp/middleware.py @@ -1,3 +1,5 @@ +import six + """ AioHttp Middleware """ @@ -64,14 +66,14 @@ async def middleware(request, handler): except HTTPException as exc: # Non 2XX responses are raised as HTTPExceptions response = exc - raise - except Exception as err: + six.raise_from(exc, exc) + except BaseException as exc: # Store exception information including the stacktrace to the segment response = None segment.put_http_meta(http.STATUS, 500) stack = stacktrace.get_stacktrace(limit=xray_recorder.max_trace_back) - segment.add_exception(err, stack) - raise + segment.add_exception(exc, stack) + six.raise_from(exc, exc) finally: if response is not None: segment.put_http_meta(http.STATUS, response.status) diff --git a/aws_xray_sdk/ext/django/conf.py b/aws_xray_sdk/ext/django/conf.py index a1b51c0a..1b5c23d0 100644 --- a/aws_xray_sdk/ext/django/conf.py +++ b/aws_xray_sdk/ext/django/conf.py @@ -19,6 +19,7 @@ 'PATCH_MODULES': [], 'AUTO_PATCH_PARENT_SEGMENT_NAME': None, 'IGNORE_MODULE_PATTERNS': [], + 'URLS_AS_ANNOTATION': 'LAMBDA', # 3 valid values, NONE -> don't ever, LAMBDA -> only for AWS Lambdas, ALL -> every time } XRAY_NAMESPACE = 'XRAY_RECORDER' diff --git a/aws_xray_sdk/ext/django/middleware.py b/aws_xray_sdk/ext/django/middleware.py index bb0c8a3a..d565be61 100644 --- a/aws_xray_sdk/ext/django/middleware.py +++ b/aws_xray_sdk/ext/django/middleware.py @@ -1,4 +1,5 @@ import logging +from .conf import settings from aws_xray_sdk.core import xray_recorder from aws_xray_sdk.core.models import http @@ -30,6 +31,14 @@ def __init__(self, get_response): if check_in_lambda() and type(xray_recorder.context) == LambdaContext: self.in_lambda_ctx = True + def _urls_as_annotation(self): + if settings.URLS_AS_ANNOTATION == "LAMBDA" and self.in_lambda_ctx: + return True + elif settings.URLS_AS_ANNOTATION == "ALL": + return True + return False + + # hooks for django version >= 1.10 def __call__(self, request): @@ -50,9 +59,10 @@ def __call__(self, request): recorder=xray_recorder, sampling_req=sampling_req, ) - if self.in_lambda_ctx: segment = xray_recorder.begin_subsegment(name) + # X-Ray can't search/filter subsegments on URL but it can search annotations + # So for lambda to be able to filter by annotation we add these as annotations else: segment = xray_recorder.begin_segment( name=name, @@ -64,23 +74,37 @@ def __call__(self, request): segment.save_origin_trace_header(xray_header) segment.put_http_meta(http.URL, request.build_absolute_uri()) segment.put_http_meta(http.METHOD, request.method) + if self._urls_as_annotation(): + segment.put_annotation(http.URL, request.build_absolute_uri()) + segment.put_annotation(http.METHOD, request.method) if meta.get(USER_AGENT_KEY): segment.put_http_meta(http.USER_AGENT, meta.get(USER_AGENT_KEY)) + if self._urls_as_annotation(): + segment.put_annotation(http.USER_AGENT, meta.get(USER_AGENT_KEY)) if meta.get(X_FORWARDED_KEY): # X_FORWARDED_FOR may come from untrusted source so we # need to set the flag to true as additional information segment.put_http_meta(http.CLIENT_IP, meta.get(X_FORWARDED_KEY)) segment.put_http_meta(http.X_FORWARDED_FOR, True) + if self._urls_as_annotation(): + segment.put_annotation(http.CLIENT_IP, meta.get(X_FORWARDED_KEY)) + segment.put_annotation(http.X_FORWARDED_FOR, True) elif meta.get(REMOTE_ADDR_KEY): segment.put_http_meta(http.CLIENT_IP, meta.get(REMOTE_ADDR_KEY)) + if self._urls_as_annotation(): + segment.put_annotation(http.CLIENT_IP, meta.get(REMOTE_ADDR_KEY)) response = self.get_response(request) segment.put_http_meta(http.STATUS, response.status_code) + if self._urls_as_annotation(): + segment.put_annotation(http.STATUS, response.status_code) if response.has_header(CONTENT_LENGTH_KEY): length = int(response[CONTENT_LENGTH_KEY]) segment.put_http_meta(http.CONTENT_LENGTH, length) + if self._urls_as_annotation(): + segment.put_annotation(http.CONTENT_LENGTH, length) response[http.XRAY_HEADER] = prepare_response_header(xray_header, segment) if self.in_lambda_ctx: diff --git a/aws_xray_sdk/ext/httpx/__init__.py b/aws_xray_sdk/ext/httpx/__init__.py new file mode 100644 index 00000000..4e8acac6 --- /dev/null +++ b/aws_xray_sdk/ext/httpx/__init__.py @@ -0,0 +1,3 @@ +from .patch import patch + +__all__ = ['patch'] diff --git a/aws_xray_sdk/ext/httpx/patch.py b/aws_xray_sdk/ext/httpx/patch.py new file mode 100644 index 00000000..dfcd9bf8 --- /dev/null +++ b/aws_xray_sdk/ext/httpx/patch.py @@ -0,0 +1,71 @@ +import httpx + +from aws_xray_sdk.core import xray_recorder +from aws_xray_sdk.core.models import http +from aws_xray_sdk.ext.util import inject_trace_header, get_hostname + + +def patch(): + httpx.Client = _InstrumentedClient + httpx.AsyncClient = _InstrumentedAsyncClient + httpx._api.Client = _InstrumentedClient + + +class _InstrumentedClient(httpx.Client): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self._original_transport = self._transport + self._transport = SyncInstrumentedTransport(self._transport) + + +class _InstrumentedAsyncClient(httpx.AsyncClient): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self._original_transport = self._transport + self._transport = AsyncInstrumentedTransport(self._transport) + + +class SyncInstrumentedTransport(httpx.BaseTransport): + def __init__(self, transport: httpx.BaseTransport): + self._wrapped_transport = transport + + def handle_request(self, request: httpx.Request) -> httpx.Response: + with xray_recorder.in_subsegment( + get_hostname(str(request.url)), namespace="remote" + ) as subsegment: + if subsegment is not None: + subsegment.put_http_meta(http.METHOD, request.method) + subsegment.put_http_meta( + http.URL, + str(request.url.copy_with(password=None, query=None, fragment=None)), + ) + inject_trace_header(request.headers, subsegment) + + response = self._wrapped_transport.handle_request(request) + if subsegment is not None: + subsegment.put_http_meta(http.STATUS, response.status_code) + return response + + +class AsyncInstrumentedTransport(httpx.AsyncBaseTransport): + def __init__(self, transport: httpx.AsyncBaseTransport): + self._wrapped_transport = transport + + async def handle_async_request(self, request: httpx.Request) -> httpx.Response: + async with xray_recorder.in_subsegment_async( + get_hostname(str(request.url)), namespace="remote" + ) as subsegment: + if subsegment is not None: + subsegment.put_http_meta(http.METHOD, request.method) + subsegment.put_http_meta( + http.URL, + str(request.url.copy_with(password=None, query=None, fragment=None)), + ) + inject_trace_header(request.headers, subsegment) + + response = await self._wrapped_transport.handle_async_request(request) + if subsegment is not None: + subsegment.put_http_meta(http.STATUS, response.status_code) + return response diff --git a/aws_xray_sdk/ext/psycopg2/patch.py b/aws_xray_sdk/ext/psycopg2/patch.py index 708bf45c..f6a1d4c6 100644 --- a/aws_xray_sdk/ext/psycopg2/patch.py +++ b/aws_xray_sdk/ext/psycopg2/patch.py @@ -7,7 +7,6 @@ def patch(): - wrapt.wrap_function_wrapper( 'psycopg2', 'connect', @@ -24,11 +23,16 @@ def patch(): _xray_register_type_fix ) + wrapt.wrap_function_wrapper( + 'psycopg2.extras', + 'register_default_jsonb', + _xray_register_default_jsonb_fix + ) -def _xray_traced_connect(wrapped, instance, args, kwargs): +def _xray_traced_connect(wrapped, instance, args, kwargs): conn = wrapped(*args, **kwargs) - parameterized_dsn = { c[0]: c[-1] for c in map(methodcaller('split', '='), conn.dsn.split(' '))} + parameterized_dsn = {c[0]: c[-1] for c in map(methodcaller('split', '='), conn.dsn.split(' '))} meta = { 'database_type': 'PostgreSQL', 'url': 'postgresql://{}@{}:{}/{}'.format( @@ -44,6 +48,7 @@ def _xray_traced_connect(wrapped, instance, args, kwargs): return XRayTracedConn(conn, meta) + def _xray_register_type_fix(wrapped, instance, args, kwargs): """Send the actual connection or curser to register type.""" our_args = list(copy.copy(args)) @@ -51,3 +56,14 @@ def _xray_register_type_fix(wrapped, instance, args, kwargs): our_args[1] = our_args[1].__wrapped__ return wrapped(*our_args, **kwargs) + + +def _xray_register_default_jsonb_fix(wrapped, instance, args, kwargs): + our_kwargs = dict() + for key, value in kwargs.items(): + if key == "conn_or_curs" and isinstance(value, (XRayTracedConn, XRayTracedCursor)): + # unwrap the connection or cursor to be sent to register_default_jsonb + value = value.__wrapped__ + our_kwargs[key] = value + + return wrapped(*args, **our_kwargs) diff --git a/aws_xray_sdk/ext/resources/aws_para_whitelist.json b/aws_xray_sdk/ext/resources/aws_para_whitelist.json index 5f0d9b54..30a45cd6 100644 --- a/aws_xray_sdk/ext/resources/aws_para_whitelist.json +++ b/aws_xray_sdk/ext/resources/aws_para_whitelist.json @@ -6,6 +6,11 @@ "request_parameters": [ "TopicArn" ] + }, + "PublishBatch": { + "request_parameters": [ + "TopicArn" + ] } } }, @@ -912,4 +917,4 @@ } } } -} \ No newline at end of file +} diff --git a/aws_xray_sdk/ext/sqlalchemy_core/patch.py b/aws_xray_sdk/ext/sqlalchemy_core/patch.py index acab1fd4..42aeb521 100644 --- a/aws_xray_sdk/ext/sqlalchemy_core/patch.py +++ b/aws_xray_sdk/ext/sqlalchemy_core/patch.py @@ -1,12 +1,13 @@ import logging import sys +import wrapt +import six if sys.version_info >= (3, 0, 0): from urllib.parse import urlparse, uses_netloc else: from urlparse import urlparse, uses_netloc -import wrapt from aws_xray_sdk.core import xray_recorder from aws_xray_sdk.core.patcher import _PATCHED_MODULES @@ -72,12 +73,12 @@ def _process_request(wrapped, engine_instance, args, kwargs): subsegment = None try: res = wrapped(*args, **kwargs) - except Exception: + except Exception as exc: if subsegment is not None: exception = sys.exc_info()[1] stack = stacktrace.get_stacktrace(limit=xray_recorder._max_trace_back) subsegment.add_exception(exception, stack) - raise + six.raise_from(exc, exc) finally: if subsegment is not None: subsegment.set_sql(sql) diff --git a/aws_xray_sdk/ext/util.py b/aws_xray_sdk/ext/util.py index 8390f9ee..ad9d5207 100644 --- a/aws_xray_sdk/ext/util.py +++ b/aws_xray_sdk/ext/util.py @@ -35,7 +35,6 @@ def inject_trace_header(headers, entity): else: header = entity.get_origin_trace_header() data = header.data if header else None - to_insert = TraceHeader( root=entity.trace_id, parent=entity.id, diff --git a/aws_xray_sdk/version.py b/aws_xray_sdk/version.py index 5bed97f9..a13b4dce 100644 --- a/aws_xray_sdk/version.py +++ b/aws_xray_sdk/version.py @@ -1 +1 @@ -VERSION = '2.10.0' +VERSION = '2.11.0' diff --git a/docs/conf.py b/docs/conf.py index 2cc56f2b..59c3e5de 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -62,9 +62,9 @@ # built documents. # # The short X.Y version. -version = u'2.10.0' +version = u'2.11.0' # The full version, including alpha/beta/rc tags. -release = u'2.10.0' +release = u'2.11.0' # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. diff --git a/tests/ext/aiohttp/test_middleware.py b/tests/ext/aiohttp/test_middleware.py index b5fc9f14..4c875de9 100644 --- a/tests/ext/aiohttp/test_middleware.py +++ b/tests/ext/aiohttp/test_middleware.py @@ -75,9 +75,10 @@ async def handle_unauthorized(self, request: web.Request) -> web.Response: async def handle_exception(self, request: web.Request) -> web.Response: """ - Handle /exception which raises a KeyError + Handle /exception which raises a CancelledError; this is important, as starting from python 3.8 CancelledError + extends BaseException instead of Exception """ - return {}['key'] + raise asyncio.CancelledError() async def handle_delay(self, request: web.Request) -> web.Response: """ @@ -213,8 +214,8 @@ async def test_exception(test_client, loop, recorder): """ client = await test_client(ServerTest.app(loop=loop)) - resp = await client.get('/exception') - await resp.text() # Need this to trigger Exception + with pytest.raises(Exception): + await client.get('/exception') segment = recorder.emitter.pop() assert not segment.in_progress @@ -227,7 +228,7 @@ async def test_exception(test_client, loop, recorder): assert request['url'] == 'http://127.0.0.1:{port}/exception'.format(port=client.port) assert request['client_ip'] == '127.0.0.1' assert response['status'] == 500 - assert exception.type == 'KeyError' + assert exception.type == 'CancelledError' async def test_unhauthorized(test_client, loop, recorder): diff --git a/tests/ext/django/app/views.py b/tests/ext/django/app/views.py index 1b3b6f62..ec76c846 100644 --- a/tests/ext/django/app/views.py +++ b/tests/ext/django/app/views.py @@ -1,7 +1,7 @@ import sqlite3 from django.http import HttpResponse -from django.conf.urls import url +from django.urls import path from django.views.generic import TemplateView @@ -32,9 +32,9 @@ def call_db(request): urlpatterns = [ - url(r'^200ok/$', ok, name='200ok'), - url(r'^500fault/$', fault, name='500fault'), - url(r'^call_db/$', call_db, name='call_db'), - url(r'^template/$', IndexView.as_view(), name='template'), - url(r'^template_block/$', TemplateBlockView.as_view(), name='template_block'), + path('200ok/', ok, name='200ok'), + path('500fault/', fault, name='500fault'), + path('call_db/', call_db, name='call_db'), + path('template/', IndexView.as_view(), name='template'), + path('template_block/', TemplateBlockView.as_view(), name='template_block'), ] diff --git a/tests/ext/httpx/__init__.py b/tests/ext/httpx/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/ext/httpx/test_httpx.py b/tests/ext/httpx/test_httpx.py new file mode 100644 index 00000000..3bfeb967 --- /dev/null +++ b/tests/ext/httpx/test_httpx.py @@ -0,0 +1,218 @@ +import pytest + +import httpx +from aws_xray_sdk.core import patch +from aws_xray_sdk.core import xray_recorder +from aws_xray_sdk.core.context import Context +from aws_xray_sdk.ext.util import strip_url, get_hostname + + +patch(("httpx",)) + +# httpbin.org is created by the same author of requests to make testing http easy. +BASE_URL = "httpbin.org" + + +@pytest.fixture(autouse=True) +def construct_ctx(): + """ + Clean up context storage on each test run and begin a segment + so that later subsegment can be attached. After each test run + it cleans up context storage again. + """ + xray_recorder.configure(service="test", sampling=False, context=Context()) + xray_recorder.clear_trace_entities() + xray_recorder.begin_segment("name") + yield + xray_recorder.clear_trace_entities() + + +@pytest.mark.parametrize("use_client", (True, False)) +def test_ok(use_client): + status_code = 200 + url = "http://{}/status/{}?foo=bar".format(BASE_URL, status_code) + if use_client: + with httpx.Client() as client: + response = client.get(url) + else: + response = httpx.get(url) + assert "x-amzn-trace-id" in response._request.headers + + subsegment = xray_recorder.current_segment().subsegments[0] + assert get_hostname(url) == BASE_URL + assert subsegment.namespace == "remote" + assert subsegment.name == get_hostname(url) + + http_meta = subsegment.http + assert http_meta["request"]["url"] == strip_url(url) + assert http_meta["request"]["method"].upper() == "GET" + assert http_meta["response"]["status"] == status_code + + +@pytest.mark.parametrize("use_client", (True, False)) +def test_error(use_client): + status_code = 400 + url = "http://{}/status/{}".format(BASE_URL, status_code) + if use_client: + with httpx.Client() as client: + response = client.post(url) + else: + response = httpx.post(url) + assert "x-amzn-trace-id" in response._request.headers + + subsegment = xray_recorder.current_segment().subsegments[0] + assert subsegment.namespace == "remote" + assert subsegment.name == get_hostname(url) + assert subsegment.error + + http_meta = subsegment.http + assert http_meta["request"]["url"] == strip_url(url) + assert http_meta["request"]["method"].upper() == "POST" + assert http_meta["response"]["status"] == status_code + + +@pytest.mark.parametrize("use_client", (True, False)) +def test_throttle(use_client): + status_code = 429 + url = "http://{}/status/{}".format(BASE_URL, status_code) + if use_client: + with httpx.Client() as client: + response = client.head(url) + else: + response = httpx.head(url) + assert "x-amzn-trace-id" in response._request.headers + + subsegment = xray_recorder.current_segment().subsegments[0] + assert subsegment.namespace == "remote" + assert subsegment.name == get_hostname(url) + assert subsegment.error + assert subsegment.throttle + + http_meta = subsegment.http + assert http_meta["request"]["url"] == strip_url(url) + assert http_meta["request"]["method"].upper() == "HEAD" + assert http_meta["response"]["status"] == status_code + + +@pytest.mark.parametrize("use_client", (True, False)) +def test_fault(use_client): + status_code = 500 + url = "http://{}/status/{}".format(BASE_URL, status_code) + if use_client: + with httpx.Client() as client: + response = client.put(url) + else: + response = httpx.put(url) + assert "x-amzn-trace-id" in response._request.headers + + subsegment = xray_recorder.current_segment().subsegments[0] + assert subsegment.namespace == "remote" + assert subsegment.name == get_hostname(url) + assert subsegment.fault + + http_meta = subsegment.http + assert http_meta["request"]["url"] == strip_url(url) + assert http_meta["request"]["method"].upper() == "PUT" + assert http_meta["response"]["status"] == status_code + + +@pytest.mark.parametrize("use_client", (True, False)) +def test_nonexistent_domain(use_client): + with pytest.raises(httpx.ConnectError): + if use_client: + with httpx.Client() as client: + client.get("http://doesnt.exist") + else: + httpx.get("http://doesnt.exist") + + subsegment = xray_recorder.current_segment().subsegments[0] + assert subsegment.namespace == "remote" + assert subsegment.fault + + exception = subsegment.cause["exceptions"][0] + assert exception.type == "ConnectError" + + +@pytest.mark.parametrize("use_client", (True, False)) +def test_invalid_url(use_client): + url = "KLSDFJKLSDFJKLSDJF" + with pytest.raises(httpx.UnsupportedProtocol): + if use_client: + with httpx.Client() as client: + client.get(url) + else: + httpx.get(url) + + subsegment = xray_recorder.current_segment().subsegments[0] + assert subsegment.namespace == "remote" + assert subsegment.name == get_hostname(url) + assert subsegment.fault + + http_meta = subsegment.http + assert http_meta["request"]["url"] == "/{}".format(strip_url(url)) + + exception = subsegment.cause["exceptions"][0] + assert exception.type == "UnsupportedProtocol" + + +@pytest.mark.parametrize("use_client", (True, False)) +def test_name_uses_hostname(use_client): + if use_client: + client = httpx.Client() + else: + client = httpx + + try: + url1 = "http://{}/fakepath/stuff/koo/lai/ahh".format(BASE_URL) + client.get(url1) + subsegment = xray_recorder.current_segment().subsegments[-1] + assert subsegment.namespace == "remote" + assert subsegment.name == BASE_URL + http_meta1 = subsegment.http + assert http_meta1["request"]["url"] == strip_url(url1) + assert http_meta1["request"]["method"].upper() == "GET" + + url2 = "http://{}/".format(BASE_URL) + client.get(url2, params={"some": "payload", "not": "toBeIncluded"}) + subsegment = xray_recorder.current_segment().subsegments[-1] + assert subsegment.namespace == "remote" + assert subsegment.name == BASE_URL + http_meta2 = subsegment.http + assert http_meta2["request"]["url"] == strip_url(url2) + assert http_meta2["request"]["method"].upper() == "GET" + + url3 = "http://subdomain.{}/fakepath/stuff/koo/lai/ahh".format(BASE_URL) + try: + client.get(url3) + except httpx.ConnectError: + pass + subsegment = xray_recorder.current_segment().subsegments[-1] + assert subsegment.namespace == "remote" + assert subsegment.name == "subdomain." + BASE_URL + http_meta3 = subsegment.http + assert http_meta3["request"]["url"] == strip_url(url3) + assert http_meta3["request"]["method"].upper() == "GET" + finally: + if use_client: + client.close() + + +@pytest.mark.parametrize("use_client", (True, False)) +def test_strip_http_url(use_client): + status_code = 200 + url = "http://{}/get?foo=bar".format(BASE_URL) + if use_client: + with httpx.Client() as client: + response = client.get(url) + else: + response = httpx.get(url) + assert "x-amzn-trace-id" in response._request.headers + + subsegment = xray_recorder.current_segment().subsegments[0] + assert subsegment.namespace == "remote" + assert subsegment.name == get_hostname(url) + + http_meta = subsegment.http + assert http_meta["request"]["url"] == strip_url(url) + assert http_meta["request"]["method"].upper() == "GET" + assert http_meta["response"]["status"] == status_code diff --git a/tests/ext/httpx/test_httpx_async.py b/tests/ext/httpx/test_httpx_async.py new file mode 100644 index 00000000..c5d0560a --- /dev/null +++ b/tests/ext/httpx/test_httpx_async.py @@ -0,0 +1,190 @@ +import pytest + +import httpx +from aws_xray_sdk.core import patch +from aws_xray_sdk.core import xray_recorder +from aws_xray_sdk.core.context import Context +from aws_xray_sdk.ext.util import strip_url, get_hostname + + +patch(("httpx",)) + +# httpbin.org is created by the same author of requests to make testing http easy. +BASE_URL = "httpbin.org" + + +@pytest.fixture(autouse=True) +def construct_ctx(): + """ + Clean up context storage on each test run and begin a segment + so that later subsegment can be attached. After each test run + it cleans up context storage again. + """ + xray_recorder.configure(service="test", sampling=False, context=Context()) + xray_recorder.clear_trace_entities() + xray_recorder.begin_segment("name") + yield + xray_recorder.clear_trace_entities() + + +@pytest.mark.asyncio +async def test_ok_async(): + status_code = 200 + url = "http://{}/status/{}?foo=bar".format(BASE_URL, status_code) + async with httpx.AsyncClient() as client: + response = await client.get(url) + assert "x-amzn-trace-id" in response._request.headers + + subsegment = xray_recorder.current_segment().subsegments[0] + assert get_hostname(url) == BASE_URL + assert subsegment.namespace == "remote" + assert subsegment.name == get_hostname(url) + + http_meta = subsegment.http + assert http_meta["request"]["url"] == strip_url(url) + assert http_meta["request"]["method"].upper() == "GET" + assert http_meta["response"]["status"] == status_code + + +@pytest.mark.asyncio +async def test_error_async(): + status_code = 400 + url = "http://{}/status/{}".format(BASE_URL, status_code) + async with httpx.AsyncClient() as client: + response = await client.post(url) + assert "x-amzn-trace-id" in response._request.headers + + subsegment = xray_recorder.current_segment().subsegments[0] + assert subsegment.namespace == "remote" + assert subsegment.name == get_hostname(url) + assert subsegment.error + + http_meta = subsegment.http + assert http_meta["request"]["url"] == strip_url(url) + assert http_meta["request"]["method"].upper() == "POST" + assert http_meta["response"]["status"] == status_code + + +@pytest.mark.asyncio +async def test_throttle_async(): + status_code = 429 + url = "http://{}/status/{}".format(BASE_URL, status_code) + async with httpx.AsyncClient() as client: + response = await client.head(url) + assert "x-amzn-trace-id" in response._request.headers + + subsegment = xray_recorder.current_segment().subsegments[0] + assert subsegment.namespace == "remote" + assert subsegment.name == get_hostname(url) + assert subsegment.error + assert subsegment.throttle + + http_meta = subsegment.http + assert http_meta["request"]["url"] == strip_url(url) + assert http_meta["request"]["method"].upper() == "HEAD" + assert http_meta["response"]["status"] == status_code + + +@pytest.mark.asyncio +async def test_fault_async(): + status_code = 500 + url = "http://{}/status/{}".format(BASE_URL, status_code) + async with httpx.AsyncClient() as client: + response = await client.put(url) + assert "x-amzn-trace-id" in response._request.headers + + subsegment = xray_recorder.current_segment().subsegments[0] + assert subsegment.namespace == "remote" + assert subsegment.name == get_hostname(url) + assert subsegment.fault + + http_meta = subsegment.http + assert http_meta["request"]["url"] == strip_url(url) + assert http_meta["request"]["method"].upper() == "PUT" + assert http_meta["response"]["status"] == status_code + + +@pytest.mark.asyncio +async def test_nonexistent_domain_async(): + with pytest.raises(httpx.ConnectError): + async with httpx.AsyncClient() as client: + await client.get("http://doesnt.exist") + + subsegment = xray_recorder.current_segment().subsegments[0] + assert subsegment.namespace == "remote" + assert subsegment.fault + + exception = subsegment.cause["exceptions"][0] + assert exception.type == "ConnectError" + + +@pytest.mark.asyncio +async def test_invalid_url_async(): + url = "KLSDFJKLSDFJKLSDJF" + with pytest.raises(httpx.UnsupportedProtocol): + async with httpx.AsyncClient() as client: + await client.get(url) + + subsegment = xray_recorder.current_segment().subsegments[0] + assert subsegment.namespace == "remote" + assert subsegment.name == get_hostname(url) + assert subsegment.fault + + http_meta = subsegment.http + assert http_meta["request"]["url"] == "/{}".format(strip_url(url)) + + exception = subsegment.cause["exceptions"][0] + assert exception.type == "UnsupportedProtocol" + + +@pytest.mark.asyncio +async def test_name_uses_hostname_async(): + async with httpx.AsyncClient() as client: + url1 = "http://{}/fakepath/stuff/koo/lai/ahh".format(BASE_URL) + await client.get(url1) + subsegment = xray_recorder.current_segment().subsegments[-1] + assert subsegment.namespace == "remote" + assert subsegment.name == BASE_URL + http_meta1 = subsegment.http + assert http_meta1["request"]["url"] == strip_url(url1) + assert http_meta1["request"]["method"].upper() == "GET" + + url2 = "http://{}/".format(BASE_URL) + await client.get(url2, params={"some": "payload", "not": "toBeIncluded"}) + subsegment = xray_recorder.current_segment().subsegments[-1] + assert subsegment.namespace == "remote" + assert subsegment.name == BASE_URL + http_meta2 = subsegment.http + assert http_meta2["request"]["url"] == strip_url(url2) + assert http_meta2["request"]["method"].upper() == "GET" + + url3 = "http://subdomain.{}/fakepath/stuff/koo/lai/ahh".format(BASE_URL) + try: + await client.get(url3) + except Exception: + # This is an invalid url so we dont want to break the test + pass + subsegment = xray_recorder.current_segment().subsegments[-1] + assert subsegment.namespace == "remote" + assert subsegment.name == "subdomain." + BASE_URL + http_meta3 = subsegment.http + assert http_meta3["request"]["url"] == strip_url(url3) + assert http_meta3["request"]["method"].upper() == "GET" + + +@pytest.mark.asyncio +async def test_strip_http_url_async(): + status_code = 200 + url = "http://{}/get?foo=bar".format(BASE_URL) + async with httpx.AsyncClient() as client: + response = await client.get(url) + assert "x-amzn-trace-id" in response._request.headers + + subsegment = xray_recorder.current_segment().subsegments[0] + assert subsegment.namespace == "remote" + assert subsegment.name == get_hostname(url) + + http_meta = subsegment.http + assert http_meta["request"]["url"] == strip_url(url) + assert http_meta["request"]["method"].upper() == "GET" + assert http_meta["response"]["status"] == status_code diff --git a/tests/ext/psycopg2/test_psycopg2.py b/tests/ext/psycopg2/test_psycopg2.py index 4736b5c0..9ab80069 100644 --- a/tests/ext/psycopg2/test_psycopg2.py +++ b/tests/ext/psycopg2/test_psycopg2.py @@ -173,3 +173,17 @@ def test_query_as_string(): test_sql = psycopg2.sql.Identifier('test') assert test_sql.as_string(conn) assert test_sql.as_string(conn.cursor()) + + +def test_register_default_jsonb(): + with testing.postgresql.Postgresql() as postgresql: + url = postgresql.url() + dsn = postgresql.dsn() + conn = psycopg2.connect('dbname=' + dsn['database'] + + ' password=mypassword' + + ' host=' + dsn['host'] + + ' port=' + str(dsn['port']) + + ' user=' + dsn['user']) + + assert psycopg2.extras.register_default_jsonb(conn_or_curs=conn, loads=lambda x: x) + assert psycopg2.extras.register_default_jsonb(conn_or_curs=conn.cursor(), loads=lambda x: x) diff --git a/tests/test_facade_segment.py b/tests/test_facade_segment.py index 30842019..5b95115a 100644 --- a/tests/test_facade_segment.py +++ b/tests/test_facade_segment.py @@ -55,3 +55,18 @@ def test_structure_intact(): assert segment.subsegments[0] is subsegment assert subsegment.subsegments[0] is subsegment2 + +def test_adding_unsampled_subsegment(): + + segment = FacadeSegment('name', 'id', 'id', True) + subsegment = Subsegment('sampled', 'local', segment) + subsegment2 = Subsegment('unsampled', 'local', segment) + subsegment2.sampled = False + + segment.add_subsegment(subsegment) + subsegment.add_subsegment(subsegment2) + + + assert segment.subsegments[0] is subsegment + assert subsegment.subsegments[0] is subsegment2 + assert subsegment2.sampled == False diff --git a/tests/test_lambda_context.py b/tests/test_lambda_context.py index 29b1cc42..98e1687c 100644 --- a/tests/test_lambda_context.py +++ b/tests/test_lambda_context.py @@ -8,7 +8,8 @@ TRACE_ID = '1-5759e988-bd862e3fe1be46a994272793' PARENT_ID = '53995c3f42cd8ad8' -HEADER_VAR = "Root=%s;Parent=%s;Sampled=1" % (TRACE_ID, PARENT_ID) +DATA = 'Foo=Bar' +HEADER_VAR = "Root=%s;Parent=%s;Sampled=1;%s" % (TRACE_ID, PARENT_ID, DATA) os.environ[lambda_launcher.LAMBDA_TRACE_HEADER_KEY] = HEADER_VAR context = lambda_launcher.LambdaContext() @@ -26,6 +27,7 @@ def test_facade_segment_generation(): assert segment.id == PARENT_ID assert segment.trace_id == TRACE_ID assert segment.sampled + assert DATA in segment.get_origin_trace_header().to_header_str() def test_put_subsegment(): @@ -43,6 +45,7 @@ def test_put_subsegment(): assert subsegment2.parent_id == subsegment.id assert subsegment.parent_id == segment.id assert subsegment2.parent_segment is segment + assert DATA in subsegment2.parent_segment.get_origin_trace_header().to_header_str() context.end_subsegment() assert context.get_trace_entity().id == subsegment.id @@ -60,6 +63,7 @@ def test_disable(): global_sdk_config.set_sdk_enabled(False) segment = context.get_trace_entity() assert not segment.sampled + assert DATA in segment.get_origin_trace_header().to_header_str() def test_non_initialized(): diff --git a/tests/test_recorder.py b/tests/test_recorder.py index c060d3b7..614de01b 100644 --- a/tests/test_recorder.py +++ b/tests/test_recorder.py @@ -1,7 +1,11 @@ import platform +import time import pytest +from aws_xray_sdk.core.sampling.sampling_rule import SamplingRule +from aws_xray_sdk.core.sampling.rule_cache import RuleCache +from aws_xray_sdk.core.sampling.sampler import DefaultSampler from aws_xray_sdk.version import VERSION from .util import get_new_stubbed_recorder @@ -38,7 +42,6 @@ def test_default_runtime_context(): def test_subsegment_parenting(): - segment = xray_recorder.begin_segment('name') subsegment = xray_recorder.begin_subsegment('name') xray_recorder.end_subsegment('name') @@ -97,7 +100,6 @@ def test_put_annotation_metadata(): def test_pass_through_with_missing_context(): - xray_recorder = get_new_stubbed_recorder() xray_recorder.configure(sampling=False, context_missing='LOG_ERROR') assert not xray_recorder.is_sampled() @@ -139,6 +141,24 @@ def test_first_begin_segment_sampled(): assert segment.sampled +def test_unsampled_subsegment_of_sampled_parent(): + xray_recorder = get_new_stubbed_recorder() + xray_recorder.configure(sampling=True) + segment = xray_recorder.begin_segment('name', sampling=True) + subsegment = xray_recorder.begin_subsegment_without_sampling('unsampled') + + assert segment.sampled == True + assert subsegment.sampled == False + +def test_begin_subsegment_unsampled(): + xray_recorder = get_new_stubbed_recorder() + xray_recorder.configure(sampling=False) + segment = xray_recorder.begin_segment('name', sampling=False) + subsegment = xray_recorder.begin_subsegment_without_sampling('unsampled') + + assert segment.sampled == False + assert subsegment.sampled == False + def test_in_segment_closing(): xray_recorder = get_new_stubbed_recorder() @@ -175,7 +195,6 @@ def test_in_segment_exception(): assert segment.fault is True assert len(segment.cause['exceptions']) == 1 - with pytest.raises(Exception): with xray_recorder.in_segment('name') as segment: with xray_recorder.in_subsegment('name') as subsegment: @@ -200,6 +219,23 @@ def test_disable_is_dummy(): assert type(xray_recorder.current_segment()) is DummySegment assert type(xray_recorder.current_subsegment()) is DummySubsegment +def test_unsampled_subsegment_is_dummy(): + assert global_sdk_config.sdk_enabled() + segment = xray_recorder.begin_segment('name') + subsegment = xray_recorder.begin_subsegment_without_sampling('name') + + assert type(xray_recorder.current_subsegment()) is DummySubsegment + +def test_subsegment_respects_parent_sampling_decision(): + assert global_sdk_config.sdk_enabled() + segment = xray_recorder.begin_segment('name') + subsegment = xray_recorder.begin_subsegment_without_sampling('name2') + subsegment2 = xray_recorder.begin_subsegment('unsampled-subsegment') + + assert type(xray_recorder.current_subsegment()) is DummySubsegment + assert subsegment.sampled == False + assert subsegment2.sampled == False + def test_disabled_empty_context_current_calls(): global_sdk_config.set_sdk_enabled(False) @@ -259,7 +295,6 @@ def test_disabled_get_context_entity(): assert type(entity) is DummySegment - def test_max_stack_trace_zero(): xray_recorder.configure(max_trace_back=1) with pytest.raises(Exception): @@ -279,3 +314,41 @@ def test_max_stack_trace_zero(): assert len(segment_with_stack.cause['exceptions'][0].stack) == 1 assert len(segment_no_stack.cause['exceptions'][0].stack) == 0 + + +# CustomSampler to mimic the DefaultSampler, +# but without the rule and target polling logic. +class CustomSampler(DefaultSampler): + def start(self): + pass + + def should_trace(self, sampling_req=None): + rule_cache = RuleCache() + rule_cache.last_updated = int(time.time()) + sampling_rule_a = SamplingRule(name='rule_a', + priority=2, + rate=0.5, + reservoir_size=1, + service='app_a') + sampling_rule_b = SamplingRule(name='rule_b', + priority=2, + rate=0.5, + reservoir_size=1, + service='app_b') + rule_cache.load_rules([sampling_rule_a, sampling_rule_b]) + now = int(time.time()) + if sampling_req and not sampling_req.get('service_type', None): + sampling_req['service_type'] = self._origin + elif sampling_req is None: + sampling_req = {'service_type': self._origin} + matched_rule = rule_cache.get_matched_rule(sampling_req, now) + if matched_rule: + return self._process_matched_rule(matched_rule, now) + else: + return self._local_sampler.should_trace(sampling_req) + + +def test_begin_segment_matches_sampling_rule_on_name(): + xray_recorder.configure(sampling=True, sampler=CustomSampler()) + segment = xray_recorder.begin_segment("app_b") + assert segment.aws.get('xray').get('sampling_rule_name') == 'rule_b' diff --git a/tests/test_sqs_message_helper.py b/tests/test_sqs_message_helper.py new file mode 100644 index 00000000..6ee44b8d --- /dev/null +++ b/tests/test_sqs_message_helper.py @@ -0,0 +1,68 @@ +from aws_xray_sdk.core.utils.sqs_message_helper import SqsMessageHelper + +import pytest + +sampleSqsMessageEvent = { + "Records": [ + { + "messageId": "059f36b4-87a3-44ab-83d2-661975830a7d", + "receiptHandle": "AQEBwJnKyrHigUMZj6rYigCgxlaS3SLy0a...", + "body": "Test message.", + "attributes": { + "ApproximateReceiveCount": "1", + "SentTimestamp": "1545082649183", + "SenderId": "AIDAIENQZJOLO23YVJ4VO", + "ApproximateFirstReceiveTimestamp": "1545082649185", + "AWSTraceHeader":"Root=1-632BB806-bd862e3fe1be46a994272793;Sampled=1" + }, + "messageAttributes": {}, + "md5OfBody": "e4e68fb7bd0e697a0ae8f1bb342846b3", + "eventSource": "aws:sqs", + "eventSourceARN": "arn:aws:sqs:us-east-2:123456789012:my-queue", + "awsRegion": "us-east-2" + }, + { + "messageId": "2e1424d4-f796-459a-8184-9c92662be6da", + "receiptHandle": "AQEBzWwaftRI0KuVm4tP+/7q1rGgNqicHq...", + "body": "Test message.", + "attributes": { + "ApproximateReceiveCount": "1", + "SentTimestamp": "1545082650636", + "SenderId": "AIDAIENQZJOLO23YVJ4VO", + "ApproximateFirstReceiveTimestamp": "1545082650649", + "AWSTraceHeader":"Root=1-5759e988-bd862e3fe1be46a994272793;Parent=53995c3f42cd8ad8;Sampled=0" + }, + "messageAttributes": {}, + "md5OfBody": "e4e68fb7bd0e697a0ae8f1bb342846b3", + "eventSource": "aws:sqs", + "eventSourceARN": "arn:aws:sqs:us-east-2:123456789012:my-queue", + "awsRegion": "us-east-2" + }, + { + "messageId": "2e1424d4-f796-459a-8184-9c92662be6da", + "receiptHandle": "AQEBzWwaftRI0KuVm4tP+/7q1rGgNqicHq...", + "body": "Test message.", + "attributes": { + "ApproximateReceiveCount": "1", + "SentTimestamp": "1545082650636", + "SenderId": "AIDAIENQZJOLO23YVJ4VO", + "ApproximateFirstReceiveTimestamp": "1545082650649", + "AWSTraceHeader":"Root=1-5759e988-bd862e3fe1be46a994272793;Parent=53995c3f42cd8ad8" + }, + "messageAttributes": {}, + "md5OfBody": "e4e68fb7bd0e697a0ae8f1bb342846b3", + "eventSource": "aws:sqs", + "eventSourceARN": "arn:aws:sqs:us-east-2:123456789012:my-queue", + "awsRegion": "us-east-2" + } + ] + } + +def test_return_true_when_sampling_1(): + assert SqsMessageHelper.isSampled(sampleSqsMessageEvent['Records'][0]) == True + +def test_return_false_when_sampling_0(): + assert SqsMessageHelper.isSampled(sampleSqsMessageEvent['Records'][1]) == False + +def test_return_false_with_no_sampling_flag(): + assert SqsMessageHelper.isSampled(sampleSqsMessageEvent['Records'][2]) == False \ No newline at end of file diff --git a/tests/test_trace_entities.py b/tests/test_trace_entities.py index e42cee0c..7d987ed0 100644 --- a/tests/test_trace_entities.py +++ b/tests/test_trace_entities.py @@ -11,6 +11,9 @@ from aws_xray_sdk.core.exceptions.exceptions import AlreadyEndedException from .util import entity_to_dict +from .util import get_new_stubbed_recorder + +xray_recorder = get_new_stubbed_recorder() def test_unicode_entity_name(): @@ -263,3 +266,19 @@ def test_add_exception_appending_exceptions(): assert isinstance(segment.cause, dict) assert len(segment.cause['exceptions']) == 2 + +def test_adding_subsegments_with_recorder(): + xray_recorder.configure(sampling=False) + xray_recorder.clear_trace_entities() + + segment = xray_recorder.begin_segment('parent'); + subsegment = xray_recorder.begin_subsegment('sampled-child') + unsampled_subsegment = xray_recorder.begin_subsegment_without_sampling('unsampled-child1') + unsampled_child_subsegment = xray_recorder.begin_subsegment('unsampled-child2') + + assert segment.sampled == True + assert subsegment.sampled == True + assert unsampled_subsegment.sampled == False + assert unsampled_child_subsegment.sampled == False + + xray_recorder.clear_trace_entities() \ No newline at end of file diff --git a/tests/test_utils.py b/tests/test_utils.py index 939fde42..9c35ad84 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,4 +1,10 @@ -from aws_xray_sdk.ext.util import to_snake_case, get_hostname, strip_url +from aws_xray_sdk.ext.util import to_snake_case, get_hostname, strip_url, inject_trace_header +from aws_xray_sdk.core.models.segment import Segment +from aws_xray_sdk.core.models.subsegment import Subsegment +from aws_xray_sdk.core.models.dummy_entities import DummySegment, DummySubsegment +from .util import get_new_stubbed_recorder + +xray_recorder = get_new_stubbed_recorder() UNKNOWN_HOST = "UNKNOWN HOST" @@ -52,3 +58,38 @@ def test_strip_url(): assert strip_url("") == "" assert not strip_url(None) + + +def test_inject_trace_header_unsampled(): + headers = {'host': 'test', 'accept': '*/*', 'connection': 'keep-alive', 'X-Amzn-Trace-Id': 'Root=1-6369739a-7d8bb07e519b795eb24d382d;Parent=089e3de743fb9e79;Sampled=1'} + xray_recorder = get_new_stubbed_recorder() + xray_recorder.configure(sampling=True) + segment = xray_recorder.begin_segment('name', sampling=True) + subsegment = xray_recorder.begin_subsegment_without_sampling('unsampled') + + inject_trace_header(headers, subsegment) + + assert 'Sampled=0' in headers['X-Amzn-Trace-Id'] + +def test_inject_trace_header_respects_parent_subsegment(): + headers = {'host': 'test', 'accept': '*/*', 'connection': 'keep-alive', 'X-Amzn-Trace-Id': 'Root=1-6369739a-7d8bb07e519b795eb24d382d;Parent=089e3de743fb9e79;Sampled=1'} + + xray_recorder = get_new_stubbed_recorder() + xray_recorder.configure(sampling=True) + segment = xray_recorder.begin_segment('name', sampling=True) + subsegment = xray_recorder.begin_subsegment_without_sampling('unsampled') + subsegment2 = xray_recorder.begin_subsegment('unsampled2') + inject_trace_header(headers, subsegment2) + + assert 'Sampled=0' in headers['X-Amzn-Trace-Id'] + +def test_inject_trace_header_sampled(): + headers = {'host': 'test', 'accept': '*/*', 'connection': 'keep-alive', 'X-Amzn-Trace-Id': 'Root=1-6369739a-7d8bb07e519b795eb24d382d;Parent=089e3de743fb9e79;Sampled=1'} + xray_recorder = get_new_stubbed_recorder() + xray_recorder.configure(sampling=True) + segment = xray_recorder.begin_segment('name') + subsegment = xray_recorder.begin_subsegment('unsampled') + + inject_trace_header(headers, subsegment) + + assert 'Sampled=1' in headers['X-Amzn-Trace-Id'] \ No newline at end of file diff --git a/tox.ini b/tox.ini index e9dc394f..f973c697 100644 --- a/tox.ini +++ b/tox.ini @@ -19,12 +19,17 @@ envlist = ; Django3 is only for python 3.6+ py{36,37,38,39}-ext-django-3 + ; Django4 is only for python 3.8+ + py{38,39}-ext-django-4 + py{27,34,35,36,37,38,39}-ext-flask py{27,34,35,36,37,38,39}-ext-flask_sqlalchemy py{27,34,35,36,37,38,39}-ext-httplib + py{37,38,39}-ext-httpx + py{27,34,35,36,37,38,39}-ext-pg8000 py{27,34,35,36,37,38,39}-ext-psycopg2 @@ -75,6 +80,9 @@ deps = ; Also, the stable version is only supported for Python 3.7+ ext-aiohttp: pytest-aiohttp < 1.0.0 + ext-httpx: httpx >= 0.20 + ext-httpx: pytest-asyncio >= 0.19 + ext-requests: requests ext-bottle: bottle >= 0.10 @@ -83,7 +91,7 @@ deps = ext-flask: flask >= 0.10 ext-flask_sqlalchemy: flask >= 0.10 - ext-flask_sqlalchemy: Flask-SQLAlchemy + ext-flask_sqlalchemy: Flask-SQLAlchemy <= 2.5.1 ext-sqlalchemy: sqlalchemy @@ -93,6 +101,7 @@ deps = ext-django-2: Django >=2.0,<3.0 ext-django-3: Django >=3.0,<4.0 + ext-django-4: Django >=4.0,<5.0 ext-django: django-fake-model ext-pynamodb: pynamodb >= 3.3.1 @@ -135,6 +144,8 @@ commands = ext-httplib: coverage run --append --source aws_xray_sdk -m pytest tests/ext/httplib + ext-httpx: coverage run --append --source aws_xray_sdk -m pytest tests/ext/httpx + ext-pg8000: coverage run --append --source aws_xray_sdk -m pytest tests/ext/pg8000 ext-psycopg2: coverage run --append --source aws_xray_sdk -m pytest tests/ext/psycopg2