diff --git a/sentry_sdk/consts.py b/sentry_sdk/consts.py index 3260fa3b35..cd8e576a26 100644 --- a/sentry_sdk/consts.py +++ b/sentry_sdk/consts.py @@ -37,6 +37,7 @@ "debug": bool, "attach_stacktrace": bool, "ca_certs": Optional[str], + "propagate_traces": bool, }, total=False, ) @@ -69,6 +70,7 @@ "debug": False, "attach_stacktrace": False, "ca_certs": None, + "propagate_traces": True, } diff --git a/sentry_sdk/hub.py b/sentry_sdk/hub.py index d81bf0171b..cfaa97349c 100644 --- a/sentry_sdk/hub.py +++ b/sentry_sdk/hub.py @@ -415,6 +415,18 @@ def flush(self, timeout=None, callback=None): if client is not None: return client.flush(timeout=timeout, callback=callback) + def iter_trace_propagation_headers(self): + client, scope = self._stack[-1] + if scope._span is None: + return + + propagate_traces = client and client.options["propagate_traces"] + if not propagate_traces: + return + + for item in scope._span.iter_headers(): + yield item + GLOBAL_HUB = Hub() _local.set(GLOBAL_HUB) diff --git a/sentry_sdk/integrations/celery.py b/sentry_sdk/integrations/celery.py index 5b7646d429..75b0c442af 100644 --- a/sentry_sdk/integrations/celery.py +++ b/sentry_sdk/integrations/celery.py @@ -6,6 +6,7 @@ from sentry_sdk.hub import Hub from sentry_sdk.utils import capture_internal_exceptions, event_from_exception +from sentry_sdk.tracing import SpanContext from sentry_sdk._compat import reraise from sentry_sdk.integrations import Integration from sentry_sdk.integrations.logging import ignore_logger @@ -14,6 +15,9 @@ class CeleryIntegration(Integration): identifier = "celery" + def __init__(self, propagate_traces=True): + self.propagate_traces = propagate_traces + @staticmethod def setup_once(): import celery.app.trace as trace # type: ignore @@ -25,6 +29,7 @@ def sentry_build_tracer(name, task, *args, **kwargs): # short-circuits to task.run if it thinks it's safe. 
def _wrap_apply_async(task, f):
    """Wrap a task's ``apply_async`` so trace headers travel with the message.

    :param task: The task being instrumented (unused here, kept so the
        signature stays parallel to the other ``_wrap_*`` helpers).
    :param f: The original ``apply_async`` callable. It is passed in as
        ``task.apply_async``, i.e. it is already bound to the task.
    :returns: A callable with the same calling convention as ``f``.
    """

    # The wrapper is stored as an *instance* attribute, so Python never binds
    # it and no implicit ``self`` is supplied. Declaring a ``self`` parameter
    # would swallow the caller's first positional argument and raise
    # ``TypeError`` for calls that pass none (``task.apply_async()`` or
    # ``task.apply_async(kwargs={...})``). Forward everything untouched.
    def apply_async(*args, **kwargs):
        hub = Hub.current
        integration = hub.get_integration(CeleryIntegration)
        if integration is not None and integration.propagate_traces:
            trace_headers = dict(hub.iter_trace_propagation_headers())
            if trace_headers:
                # Copy before updating so a headers dict owned by the caller
                # is never mutated.
                headers = dict(kwargs.get("headers") or {})
                headers.update(trace_headers)
                kwargs["headers"] = headers
        return f(*args, **kwargs)

    return apply_async


def _continue_trace(headers, scope):
    """Attach a span context to ``scope`` for the task about to run.

    :param headers: Mapping of message headers; may be empty.
    :param scope: The scope that receives the span context.

    With headers present the trace is continued from them, otherwise a fresh
    trace is started.
    """
    if headers:
        span_context = SpanContext.continue_from_headers(headers)
    else:
        span_context = SpanContext.start_trace()
    scope.set_span_context(span_context)
diff --git a/sentry_sdk/integrations/flask.py b/sentry_sdk/integrations/flask.py index 47bdb5d6a8..437e9fed0b 100644 --- a/sentry_sdk/integrations/flask.py +++ b/sentry_sdk/integrations/flask.py @@ -96,9 +96,10 @@ def _request_started(sender, **kwargs): if integration is None: return - weak_request = weakref.ref(_request_ctx_stack.top.request) app = _app_ctx_stack.top.app with hub.configure_scope() as scope: + request = _request_ctx_stack.top.request + weak_request = weakref.ref(request) scope.add_event_processor( _make_request_event_processor( # type: ignore app, weak_request, integration diff --git a/sentry_sdk/integrations/stdlib.py b/sentry_sdk/integrations/stdlib.py index 0d0b20cbf8..8b16c73ce4 100644 --- a/sentry_sdk/integrations/stdlib.py +++ b/sentry_sdk/integrations/stdlib.py @@ -24,7 +24,8 @@ def install_httplib(): def putrequest(self, method, url, *args, **kwargs): rv = real_putrequest(self, method, url, *args, **kwargs) - if Hub.current.get_integration(StdlibIntegration) is None: + hub = Hub.current + if hub.get_integration(StdlibIntegration) is None: return rv self._sentrysdk_data_dict = data = {} @@ -42,6 +43,9 @@ def putrequest(self, method, url, *args, **kwargs): url, ) + for key, value in hub.iter_trace_propagation_headers(): + self.putheader(key, value) + data["url"] = real_url data["method"] = method return rv diff --git a/sentry_sdk/integrations/wsgi.py b/sentry_sdk/integrations/wsgi.py index 05587c60d2..a9f2c615eb 100644 --- a/sentry_sdk/integrations/wsgi.py +++ b/sentry_sdk/integrations/wsgi.py @@ -3,6 +3,7 @@ from sentry_sdk.hub import Hub, _should_send_default_pii from sentry_sdk.utils import capture_internal_exceptions, event_from_exception from sentry_sdk._compat import PY2, reraise +from sentry_sdk.tracing import SpanContext from sentry_sdk.integrations._wsgi_common import _filter_headers if False: @@ -81,6 +82,7 @@ def __call__(self, environ, start_response): with hub.configure_scope() as scope: scope.clear_breadcrumbs() scope._name = 
"wsgi" + scope.set_span_context(SpanContext.continue_from_environ(environ)) scope.add_event_processor(_make_wsgi_event_processor(environ)) try: diff --git a/sentry_sdk/scope.py b/sentry_sdk/scope.py index 1762046058..ce4cb2c501 100644 --- a/sentry_sdk/scope.py +++ b/sentry_sdk/scope.py @@ -59,6 +59,7 @@ class Scope(object): "_event_processors", "_error_processors", "_should_capture", + "_span", ) def __init__(self): @@ -88,6 +89,10 @@ def user(self, value): """When set a specific user is bound to the scope.""" self._user = value + def set_span_context(self, span_context): + """Sets the span context.""" + self._span = span_context + def set_tag(self, key, value): """Sets a tag for a key to a specific value.""" self._tags[key] = value @@ -127,6 +132,8 @@ def clear(self): self.clear_breadcrumbs() self._should_capture = True + self._span = None + def clear_breadcrumbs(self): # type: () -> None """Clears breadcrumb buffer.""" @@ -193,6 +200,12 @@ def _drop(event, cause, ty): if self._contexts: event.setdefault("contexts", {}).update(self._contexts) + if self._span is not None: + event.setdefault("contexts", {})["trace"] = { + "trace_id": self._span.trace_id, + "span_id": self._span.span_id, + } + exc_info = hint.get("exc_info") if hint is not None else None if exc_info is not None: for processor in self._error_processors: @@ -230,6 +243,7 @@ def __copy__(self): rv._error_processors = list(self._error_processors) rv._should_capture = self._should_capture + rv._span = self._span return rv diff --git a/sentry_sdk/tracing.py b/sentry_sdk/tracing.py new file mode 100644 index 0000000000..37c1ee356d --- /dev/null +++ b/sentry_sdk/tracing.py @@ -0,0 +1,93 @@ +import re +import uuid + +_traceparent_header_format_re = re.compile( + "^[ \t]*([0-9a-f]{2})-([0-9a-f]{32})-([0-9a-f]{16})-([0-9a-f]{2})" "(-.*)?[ \t]*$" +) + + +class _EnvironHeaders(object): + def __init__(self, environ): + self.environ = environ + + def get(self, key): + return self.environ.get("HTTP_" + 
class SpanContext(object):
    """Identity of a single span inside a distributed trace.

    A span context is serialized into the ``sentry-trace`` header as
    ``<version>-<trace_id>-<span_id>-<flags>`` (see ``to_traceparent``).

    :param trace_id: Hex string identifying the trace (32 hex characters when
        generated by ``start_trace``).
    :param span_id: Hex string identifying this span (16 hex characters when
        generated here).
    :param recorded: Whether the "recorded" flag (bit 0 of the flags byte) is
        set.
    :param parent: The ``SpanContext`` this span was derived from, if any.
    """

    def __init__(self, trace_id, span_id, recorded=False, parent=None):
        self.trace_id = trace_id
        self.span_id = span_id
        self.recorded = recorded
        # Store the argument: ``new_span`` passes the span it derives from,
        # and dropping it would make every child look like a root span.
        self.parent = parent

    def __repr__(self):
        return "%s(trace_id=%r, span_id=%r, recorded=%r)" % (
            self.__class__.__name__,
            self.trace_id,
            self.span_id,
            self.recorded,
        )

    @classmethod
    def start_trace(cls, recorded=False):
        """Create the root span context of a brand-new trace."""
        return cls(
            trace_id=uuid.uuid4().hex,
            # Second half of a uuid4 hex string: 16 hex characters.
            span_id=uuid.uuid4().hex[16:],
            recorded=recorded,
        )

    def new_span(self):
        """Return a child span context that shares this span's trace.

        If this context has no trace id, a completely new trace is started
        instead.
        """
        if self.trace_id is None:
            return SpanContext.start_trace()
        return SpanContext(
            trace_id=self.trace_id,
            span_id=uuid.uuid4().hex[16:],
            parent=self,
            recorded=self.recorded,
        )

    @classmethod
    def continue_from_environ(cls, environ):
        """Continue a trace from the request headers in a WSGI ``environ``."""
        return cls.continue_from_headers(_EnvironHeaders(environ))

    @classmethod
    def continue_from_headers(cls, headers):
        """Continue a trace from ``headers`` (any object with ``.get``).

        Falls back to starting a new trace when the ``sentry-trace`` header is
        missing or cannot be parsed.
        """
        parent = cls.from_traceparent(headers.get("sentry-trace"))
        if parent is None:
            return cls.start_trace()
        return parent.new_span()

    def iter_headers(self):
        """Yield the ``(name, value)`` header pairs that propagate this span."""
        yield "sentry-trace", self.to_traceparent()

    @classmethod
    def from_traceparent(cls, traceparent):
        """Parse a traceparent string; return ``None`` if absent or invalid."""
        if not traceparent:
            return None

        match = _traceparent_header_format_re.match(traceparent)
        if match is None:
            return None

        version, trace_id, span_id, trace_options, extra = match.groups()

        # All-zero ids are rejected.
        if int(trace_id, 16) == 0 or int(span_id, 16) == 0:
            return None

        version = int(version, 16)
        if version == 0:
            # Version 0 must not carry trailing fields.
            if extra:
                return None
        elif version == 255:
            return None

        options = int(trace_options, 16)

        return cls(trace_id=trace_id, span_id=span_id, recorded=(options & 1) != 0)

    def to_traceparent(self):
        """Serialize as ``00-<trace_id>-<span_id>-<flags>``."""
        return "%02x-%s-%s-%02x" % (
            0,
            self.trace_id,
            self.span_id,
            1 if self.recorded else 0,
        )
a/tests/integrations/celery/test_celery.py +++ b/tests/integrations/celery/test_celery.py @@ -4,8 +4,9 @@ pytest.importorskip("celery") -from sentry_sdk import Hub +from sentry_sdk import Hub, configure_scope from sentry_sdk.integrations.celery import CeleryIntegration +from sentry_sdk.tracing import SpanContext from celery import Celery, VERSION from celery.bin import worker @@ -22,8 +23,8 @@ def inner(signal, f): @pytest.fixture def init_celery(sentry_init): - def inner(): - sentry_init(integrations=[CeleryIntegration()]) + def inner(propagate_traces=True): + sentry_init(integrations=[CeleryIntegration(propagate_traces=propagate_traces)]) celery = Celery(__name__) if VERSION < (4,): celery.conf.CELERY_ALWAYS_EAGER = True @@ -47,9 +48,15 @@ def dummy_task(x, y): foo = 42 # noqa return x / y + span_context = SpanContext.start_trace() + with configure_scope() as scope: + scope.set_span_context(span_context) dummy_task.delay(1, 2) dummy_task.delay(1, 0) + event, = events + assert event["contexts"]["trace"]["trace_id"] == span_context.trace_id + assert event["contexts"]["trace"]["span_id"] != span_context.span_id assert event["transaction"] == "dummy_task" assert event["extra"]["celery-job"] == { "args": [1, 0], @@ -63,6 +70,26 @@ def dummy_task(x, y): assert exception["stacktrace"]["frames"][0]["vars"]["foo"] == "42" +def test_simple_no_propagation(capture_events, init_celery): + celery = init_celery(propagate_traces=False) + events = capture_events() + + @celery.task(name="dummy_task") + def dummy_task(): + 1 / 0 + + span_context = SpanContext.start_trace() + with configure_scope() as scope: + scope.set_span_context(span_context) + dummy_task.delay() + + event, = events + assert event["contexts"]["trace"]["trace_id"] != span_context.trace_id + assert event["transaction"] == "dummy_task" + exception, = event["exception"]["values"] + assert exception["type"] == "ZeroDivisionError" + + def test_ignore_expected(capture_events, celery): events = capture_events()