diff --git a/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py b/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py index 683cb52f..53ac374f 100644 --- a/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py +++ b/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py @@ -140,7 +140,7 @@ def reduce_context(self, agent: "Agent", e: Optional[Exception] = None) -> None: if results_truncated: logger.debug("message_index=<%s> | tool results truncated", last_message_idx_with_tool_results) return - + # Try to trim index id when tool result cannot be truncated anymore # If the number of messages is less than the window_size, then we default to 2, otherwise, trim to window size trim_index = 2 if len(messages) <= self.window_size else len(messages) - self.window_size diff --git a/src/strands/telemetry/tracer.py b/src/strands/telemetry/tracer.py index 34eb7bed..3353237d 100644 --- a/src/strands/telemetry/tracer.py +++ b/src/strands/telemetry/tracer.py @@ -11,12 +11,16 @@ from importlib.metadata import version from typing import Any, Dict, Mapping, Optional -from opentelemetry import trace +import opentelemetry.trace as trace_api +from opentelemetry import propagate +from opentelemetry.baggage.propagation import W3CBaggagePropagator from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter +from opentelemetry.propagators.composite import CompositePropagator from opentelemetry.sdk.resources import Resource -from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace import TracerProvider as SDKTracerProvider from opentelemetry.sdk.trace.export import BatchSpanProcessor, ConsoleSpanExporter, SimpleSpanProcessor -from opentelemetry.trace import StatusCode +from opentelemetry.trace import Span, StatusCode +from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator from ..agent.agent_result import AgentResult from ..types.content import Message, Messages @@ -133,16 +137,30 @@ def __init__( self.service_name = service_name self.otlp_headers = otlp_headers or {} - self.tracer_provider: Optional[TracerProvider] = None - self.tracer: Optional[trace.Tracer] = None - + self.tracer_provider: Optional[trace_api.TracerProvider] = None + self.tracer: Optional[trace_api.Tracer] = None + + propagate.set_global_textmap( + CompositePropagator( + [ + W3CBaggagePropagator(), + TraceContextTextMapPropagator(), + ] + ) + ) if self.otlp_endpoint or self.enable_console_export: + # Create our own tracer provider self._initialize_tracer() def _initialize_tracer(self) -> None: """Initialize the OpenTelemetry tracer.""" logger.info("initializing tracer") + if self._is_initialized(): + self.tracer_provider = trace_api.get_tracer_provider() + self.tracer = self.tracer_provider.get_tracer(self.service_name) + return + # Create resource with service information resource = Resource.create( { @@ -154,7 +172,7 @@ def _initialize_tracer(self) -> None: ) # Create tracer provider - self.tracer_provider = TracerProvider(resource=resource) + self.tracer_provider = SDKTracerProvider(resource=resource) # Add console exporter if enabled if self.enable_console_export and self.tracer_provider: @@ -190,15 +208,19 @@ def _initialize_tracer(self) -> None: logger.exception("error=<%s> | Failed to configure OTLP exporter", e) # Set as global tracer provider - trace.set_tracer_provider(self.tracer_provider) - self.tracer = trace.get_tracer(self.service_name) + trace_api.set_tracer_provider(self.tracer_provider) + self.tracer = trace_api.get_tracer(self.service_name) + + def _is_initialized(self) -> bool: + tracer_provider = trace_api.get_tracer_provider() + return isinstance(tracer_provider, SDKTracerProvider) def _start_span( self, span_name: str, - parent_span: Optional[trace.Span] = None, + parent_span: Optional[Span] = None, attributes: Optional[Dict[str, AttributeValue]] = None, - ) -> Optional[trace.Span]: + ) -> Optional[Span]: """Generic helper method to start a span with common attributes. Args: @@ -212,7 +234,7 @@ def _start_span( if self.tracer is None: return None - context = trace.set_span_in_context(parent_span) if parent_span else None + context = trace_api.set_span_in_context(parent_span) if parent_span else None span = self.tracer.start_span(name=span_name, context=context) # Set start time as a common attribute @@ -224,7 +246,7 @@ def _start_span( return span - def _set_attributes(self, span: trace.Span, attributes: Dict[str, AttributeValue]) -> None: + def _set_attributes(self, span: Span, attributes: Dict[str, AttributeValue]) -> None: """Set attributes on a span, handling different value types appropriately. Args: @@ -239,7 +261,7 @@ def _set_attributes(self, span: trace.Span, attributes: Dict[str, AttributeValue def _end_span( self, - span: trace.Span, + span: Span, attributes: Optional[Dict[str, AttributeValue]] = None, error: Optional[Exception] = None, ) -> None: @@ -272,13 +294,13 @@ def _end_span( finally: span.end() # Force flush to ensure spans are exported - if self.tracer_provider: + if self.tracer_provider and hasattr(self.tracer_provider, 'force_flush'): try: self.tracer_provider.force_flush() except Exception as e: logger.warning("error=<%s> | failed to force flush tracer provider", e) - def end_span_with_error(self, span: trace.Span, error_message: str, exception: Optional[Exception] = None) -> None: + def end_span_with_error(self, span: Span, error_message: str, exception: Optional[Exception] = None) -> None: """End a span with error status. Args: @@ -294,12 +316,12 @@ def end_span_with_error(self, span: trace.Span, error_message: str, exception: O def start_model_invoke_span( self, - parent_span: Optional[trace.Span] = None, + parent_span: Optional[Span] = None, agent_name: str = "Strands Agent", messages: Optional[Messages] = None, model_id: Optional[str] = None, **kwargs: Any, - ) -> Optional[trace.Span]: + ) -> Optional[Span]: """Start a new span for a model invocation. Args: @@ -328,7 +350,7 @@ def start_model_invoke_span( return self._start_span("Model invoke", parent_span, attributes) def end_model_invoke_span( - self, span: trace.Span, message: Message, usage: Usage, error: Optional[Exception] = None + self, span: Span, message: Message, usage: Usage, error: Optional[Exception] = None ) -> None: """End a model invocation span with results and metrics. @@ -347,9 +369,7 @@ def end_model_invoke_span( self._end_span(span, attributes, error) - def start_tool_call_span( - self, tool: ToolUse, parent_span: Optional[trace.Span] = None, **kwargs: Any - ) -> Optional[trace.Span]: + def start_tool_call_span(self, tool: ToolUse, parent_span: Optional[Span] = None, **kwargs: Any) -> Optional[Span]: """Start a new span for a tool call. Args: @@ -374,7 +394,7 @@ def start_tool_call_span( return self._start_span(span_name, parent_span, attributes) def end_tool_call_span( - self, span: trace.Span, tool_result: Optional[ToolResult], error: Optional[Exception] = None + self, span: Span, tool_result: Optional[ToolResult], error: Optional[Exception] = None ) -> None: """End a tool call span with results. @@ -402,10 +422,10 @@ def end_tool_call_span( def start_event_loop_cycle_span( self, event_loop_kwargs: Any, - parent_span: Optional[trace.Span] = None, + parent_span: Optional[Span] = None, messages: Optional[Messages] = None, **kwargs: Any, - ) -> Optional[trace.Span]: + ) -> Optional[Span]: """Start a new span for an event loop cycle. Args: @@ -436,7 +456,7 @@ def start_event_loop_cycle_span( def end_event_loop_cycle_span( self, - span: trace.Span, + span: Span, message: Message, tool_result_message: Optional[Message] = None, error: Optional[Exception] = None, @@ -466,7 +486,7 @@ def start_agent_span( tools: Optional[list] = None, custom_trace_attributes: Optional[Mapping[str, AttributeValue]] = None, **kwargs: Any, - ) -> Optional[trace.Span]: + ) -> Optional[Span]: """Start a new span for an agent invocation. Args: @@ -506,7 +526,7 @@ def start_agent_span( def end_agent_span( self, - span: trace.Span, + span: Span, response: Optional[AgentResult] = None, error: Optional[Exception] = None, ) -> None: @@ -557,13 +577,16 @@ def get_tracer( otlp_endpoint: OTLP endpoint URL for sending traces. otlp_headers: Headers to include with OTLP requests. enable_console_export: Whether to also export traces to console. + tracer_provider: Optional existing TracerProvider to use instead of creating a new one. Returns: The global tracer instance. """ global _tracer_instance - if _tracer_instance is None or (otlp_endpoint and _tracer_instance.otlp_endpoint != otlp_endpoint): # type: ignore[unreachable] + if ( + _tracer_instance is None or (otlp_endpoint and _tracer_instance.otlp_endpoint != otlp_endpoint) # type: ignore[unreachable] + ): _tracer_instance = Tracer( service_name=service_name, otlp_endpoint=otlp_endpoint, diff --git a/tests/strands/telemetry/test_tracer.py b/tests/strands/telemetry/test_tracer.py index 32a4ac0a..ac4ea257 100644 --- a/tests/strands/telemetry/test_tracer.py +++ b/tests/strands/telemetry/test_tracer.py @@ -4,7 +4,9 @@ from unittest import mock import pytest -from opentelemetry.trace import StatusCode # type: ignore +from opentelemetry.trace import ( + StatusCode, # type: ignore +) from strands.telemetry.tracer import JSONEncoder, Tracer, get_tracer, serialize from strands.types.streaming import Usage @@ -18,13 +20,25 @@ def moto_autouse(moto_env, moto_mock_aws): @pytest.fixture def mock_tracer_provider(): - with mock.patch("strands.telemetry.tracer.TracerProvider") as mock_provider: + with mock.patch("strands.telemetry.tracer.SDKTracerProvider") as mock_provider: yield mock_provider +@pytest.fixture +def mock_is_initialized(): + with mock.patch("strands.telemetry.tracer.Tracer._is_initialized") as mock_is_initialized: + yield mock_is_initialized + + +@pytest.fixture +def mock_get_tracer_provider(): + with mock.patch("strands.telemetry.tracer.trace_api.get_tracer_provider") as mock_get_tracer_provider: + yield mock_get_tracer_provider + + @pytest.fixture def mock_tracer(): - with mock.patch("strands.telemetry.tracer.trace.get_tracer") as mock_get_tracer: + with mock.patch("strands.telemetry.tracer.trace_api.get_tracer") as mock_get_tracer: mock_tracer = mock.MagicMock() mock_get_tracer.return_value = mock_tracer yield mock_tracer @@ -38,7 +52,7 @@ def mock_span(): @pytest.fixture def mock_set_tracer_provider(): - with mock.patch("strands.telemetry.tracer.trace.set_tracer_provider") as mock_set: + with mock.patch("strands.telemetry.tracer.trace_api.set_tracer_provider") as mock_set: yield mock_set @@ -104,8 +118,17 @@ def env_with_both(): yield -def test_init_default(): +@pytest.fixture +def mock_initialize(): + with mock.patch("strands.telemetry.tracer.Tracer._initialize_tracer") as mock_initialize: + yield mock_initialize + + +def test_init_default(mock_is_initialized, mock_get_tracer_provider): """Test initializing the Tracer with default parameters.""" + mock_is_initialized.return_value = False + mock_get_tracer_provider.return_value = None + tracer = Tracer() assert tracer.service_name == "strands-agents" @@ -141,9 +164,14 @@ def test_init_with_env_headers(): def test_initialize_tracer_with_console( - mock_tracer_provider, mock_set_tracer_provider, mock_console_exporter, mock_resource + mock_is_initialized, + mock_tracer_provider, + mock_set_tracer_provider, + mock_console_exporter, + mock_resource, ): """Test initializing the tracer with console exporter.""" + mock_is_initialized.return_value = False mock_resource_instance = mock.MagicMock() mock_resource.create.return_value = mock_resource_instance @@ -161,8 +189,12 @@ def test_initialize_tracer_with_console( mock_set_tracer_provider.assert_called_once_with(mock_tracer_provider.return_value) -def test_initialize_tracer_with_otlp(mock_tracer_provider, mock_set_tracer_provider, mock_otlp_exporter, mock_resource): +def test_initialize_tracer_with_otlp( + mock_is_initialized, mock_tracer_provider, mock_set_tracer_provider, mock_otlp_exporter, mock_resource +): """Test initializing the tracer with OTLP exporter.""" + mock_is_initialized.return_value = False + mock_resource_instance = mock.MagicMock() mock_resource.create.return_value = mock_resource_instance @@ -191,7 +223,7 @@ def test_start_span_no_tracer(): def test_start_span(mock_tracer): """Test starting a span with attributes.""" - with mock.patch("strands.telemetry.tracer.trace.get_tracer", return_value=mock_tracer): + with mock.patch("strands.telemetry.tracer.trace_api.get_tracer", return_value=mock_tracer): tracer = Tracer(enable_console_export=True) tracer.tracer = mock_tracer @@ -262,7 +294,7 @@ def test_end_span_with_error_message(mock_span): def test_start_model_invoke_span(mock_tracer): """Test starting a model invoke span.""" - with mock.patch("strands.telemetry.tracer.trace.get_tracer", return_value=mock_tracer): + with mock.patch("strands.telemetry.tracer.trace_api.get_tracer", return_value=mock_tracer): tracer = Tracer(enable_console_export=True) tracer.tracer = mock_tracer @@ -300,7 +332,7 @@ def test_end_model_invoke_span(mock_span): def test_start_tool_call_span(mock_tracer): """Test starting a tool call span.""" - with mock.patch("strands.telemetry.tracer.trace.get_tracer", return_value=mock_tracer): + with mock.patch("strands.telemetry.tracer.trace_api.get_tracer", return_value=mock_tracer): tracer = Tracer(enable_console_export=True) tracer.tracer = mock_tracer @@ -338,7 +370,7 @@ def test_end_tool_call_span(mock_span): def test_start_event_loop_cycle_span(mock_tracer): """Test starting an event loop cycle span.""" - with mock.patch("strands.telemetry.tracer.trace.get_tracer", return_value=mock_tracer): + with mock.patch("strands.telemetry.tracer.trace_api.get_tracer", return_value=mock_tracer): tracer = Tracer(enable_console_export=True) tracer.tracer = mock_tracer @@ -373,7 +405,7 @@ def test_end_event_loop_cycle_span(mock_span): def test_start_agent_span(mock_tracer): """Test starting an agent span.""" - with mock.patch("strands.telemetry.tracer.trace.get_tracer", return_value=mock_tracer): + with mock.patch("strands.telemetry.tracer.trace_api.get_tracer", return_value=mock_tracer): tracer = Tracer(enable_console_export=True) tracer.tracer = mock_tracer @@ -464,9 +496,11 @@ def test_get_tracer_parameters(): def test_initialize_tracer_with_invalid_otlp_endpoint( - mock_tracer_provider, mock_set_tracer_provider, mock_otlp_exporter, mock_resource + mock_is_initialized, mock_tracer_provider, mock_set_tracer_provider, mock_otlp_exporter, mock_resource ): """Test initializing the tracer with an invalid OTLP endpoint.""" + mock_is_initialized.return_value = False + mock_resource_instance = mock.MagicMock() mock_resource.create.return_value = mock_resource_instance mock_otlp_exporter.side_effect = Exception("Connection error") @@ -486,6 +520,18 @@ def test_initialize_tracer_with_invalid_otlp_endpoint( mock_set_tracer_provider.assert_called_once_with(mock_tracer_provider.return_value) +def test_initialize_tracer_with_custom_tracer_provider(mock_get_tracer_provider, mock_resource): + """Test initializing the tracer with NoOpTracerProvider.""" + mock_is_initialized.return_value = True + tracer = Tracer(otlp_endpoint="http://invalid-endpoint") + + mock_get_tracer_provider.assert_called() + mock_resource.assert_not_called() + + assert tracer.tracer_provider is not None + assert tracer.tracer is not None + + def test_end_span_with_exception_handling(mock_span): """Test ending a span with exception handling.""" tracer = Tracer() @@ -530,7 +576,7 @@ def test_end_tool_call_span_with_none(mock_span): def test_start_model_invoke_span_with_parent(mock_tracer): """Test starting a model invoke span with a parent span.""" - with mock.patch("strands.telemetry.tracer.trace.get_tracer", return_value=mock_tracer): + with mock.patch("strands.telemetry.tracer.trace_api.get_tracer", return_value=mock_tracer): tracer = Tracer(enable_console_export=True) tracer.tracer = mock_tracer