8000 fix(telemetry): fix agent span start and end when using Agent.stream_… · lgigit200/sdk-python@bd60f90 · GitHub
[go: up one dir, main page]

Skip to content
8000

Commit bd60f90

Browse files
authored
fix(telemetry): fix agent span start and end when using Agent.stream_async() (strands-agents#119)
1 parent a331e63 commit bd60f90

File tree

2 files changed

+134
-19
lines changed

2 files changed

+134
-19
lines changed

src/strands/agent/agent.py

Lines changed: 50 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -328,27 +328,17 @@ def __call__(self, prompt: str, **kwargs: Any) -> AgentResult:
328328
- metrics: Performance metrics from the event loop
329329
- state: The final state of the event loop
330330
"""
331-
model_id = self.model.config.get("model_id") if hasattr(self.model, "config") else None
332-
333-
self.trace_span = self.tracer.start_agent_span(
334-
prompt=prompt,
335-
model_id=model_id,
336-
tools=self.tool_names,
337-
system_prompt=self.system_prompt,
338-
custom_trace_attributes=self.trace_attributes,
339-
)
331+
self._start_agent_trace_span(prompt)
340332

341333
try:
342334
# Run the event loop and get the result
343335
result = self._run_loop(prompt, kwargs)
344336

345-
if self.trace_span:
346-
self.tracer.end_agent_span(span=self.trace_span, response=result)
337+
self._end_agent_trace_span(response=result)
347338

348339
return result
349340
except Exception as e:
350-
if self.trace_span:
351-
self.tracer.end_agent_span(span=self.trace_span, error=e)
341+
self._end_agent_trace_span(error=e)
352342

353343
# Re-raise the exception to preserve original behavior
354344
raise
@@ -383,6 +373,8 @@ async def stream_async(self, prompt: str, **kwargs: Any) -> AsyncIterator[Any]:
383373
yield event["data"]
384374
```
385375
"""
376+
self._start_agent_trace_span(prompt)
377+
386378
_stop_event = uuid4()
387379

388380
queue = asyncio.Queue[Any]()
@@ -400,8 +392,10 @@ def target_callback() -> None:
400392
nonlocal kwargs
401393

402394
try:
403-
self._run_loop(prompt, kwargs, supplementary_callback_handler=queuing_callback_handler)
404-
except BaseException as e:
395+
result = self._run_loop(prompt, kwargs, supplementary_callback_handler=queuing_callback_handler)
396+
self._end_agent_trace_span(response=result)
397+
except Exception as e:
398+
self._end_agent_trace_span(error=e)
405399
enqueue(e)
406400
finally:
407401
enqueue(_stop_event)
@@ -414,7 +408,7 @@ def target_callback() -> None:
414408
item = await queue.get()
415409
if item == _stop_event:
416410
break
417-
if isinstance(item, BaseException):
411+
if isinstance(item, Exception):
418412
raise item
419413
yield item
420414
finally:
@@ -546,3 +540,43 @@ def _record_tool_execution(
546540
messages.append(tool_use_msg)
547541
messages.append(tool_result_msg)
548542
messages.append(assistant_msg)
543+
544+
def _start_agent_trace_span(self, prompt: str) -> None:
545+
"""Starts a trace span for the agent.
546+
547+
Args:
548+
prompt: The natural language prompt from the user.
549+
"""
550+
model_id = self.model.config.get("model_id") if hasattr(self.model, "config") else None
551+
552+
self.trace_span = self.tracer.start_agent_span(
553+
prompt=prompt,
554+
model_id=model_id,
555+
tools=self.tool_names,
556+
system_prompt=self.system_prompt,
557+
custom_trace_attributes=self.trace_attributes,
558+
)
559+
560+
def _end_agent_trace_span(
561+
self,
562+
response: Optional[AgentResult] = None,
563+
error: Optional[Exception] = None,
564+
) -> None:
565+
"""Ends a trace span for the agent.
566+
567+
Args:
568+
span: The span to end.
569+
response: Response to record as a trace attribute.
570+
error: Error to record as a trace attribute.
571+
"""
572+
if self.trace_span:
573+
trace_attributes: Dict[str, Any] = {
574+
"span": self.trace_span,
575+
}
576+
577+
if response:
578+
trace_attributes["response"] = response
579+
if error:
580+
trace_attributes["error"] = error
581+
582+
self.tracer.end_agent_span(**trace_attributes)

tests/strands/agent/test_agent.py

Lines changed: 84 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99
import pytest
1010

1111
import strands
12-
from strands.agent.agent import Agent
12+
from strands import Agent
13+
from strands.agent import AgentResult
1314
from strands.agent.conversation_manager.null_conversation_manager import NullConversationManager
1415
from strands.agent.conversation_manager.sliding_window_conversation_manager import SlidingWindowConversationManager
1516
from strands.handlers.callback_handler import PrintingCallbackHandler, null_callback_handler
@@ -687,8 +688,6 @@ def test_agent_with_callback_handler_none_uses_null_handler():
687688

688689
@pytest.mark.asyncio
689690
async def test_stream_async_returns_all_events(mock_event_loop_cycle):
690-
mock_event_loop_cycle.side_effect = ValueError("Test exception")
691-
692691
agent = Agent()
693692

694693
# Define the side effect to simulate callback handler being called multiple times
@@ -952,6 +951,52 @@ def test_agent_call_creates_and_ends_span_on_success(mock_get_tracer, mock_model
952951
mock_tracer.end_agent_span.assert_called_once_with(span=mock_span, response=result)
953952

954953

954+
@pytest.mark.asyncio
955+
@unittest.mock.patch("strands.agent.agent.get_tracer")
956+
async def test_agent_stream_async_creates_and_ends_span_on_success(mock_get_tracer, mock_event_loop_cycle):
957+
"""Test that stream_async creates and ends a span when the call succeeds."""
958+
# Setup mock tracer and span
959+
mock_tracer = unittest.mock.MagicMock()
960+
mock_span = unittest.mock.MagicMock()
961+
mock_tracer.start_agent_span.return_value = mock_span
962+
mock_get_tracer.return_value = mock_tracer
963+
964+
# Define the side effect to simulate callback handler being called multiple times
965+
def call_callback_handler(*args, **kwargs):
966+
# Extract the callback handler from kwargs
967+
callback_handler = kwargs.get("callback_handler")
968+
# Call the callback handler with different data values
969+
callback_handler(data="First chunk")
970+
callback_handler(data="Second chunk")
971+
callback_handler(data="Final chunk", complete=True)
972+
# Return expected values from event_loop_cycle
973+
return "stop", {"role": "assistant", "content": [{"text": "Agent Response"}]}, {}, {}
974+
975+
mock_event_loop_cycle.side_effect = call_callback_handler
976+
977+
# Create agent and make a call
978+
agent = Agent(model=mock_model)
979+
iterator = agent.stream_async("test prompt")
980+
async for _event in iterator:
981+
pass # NoOp
982+
983+
# Verify span was created
984+
mock_tracer.start_agent_span.assert_called_once_with(
985+
prompt="test prompt",
986+
model_id=unittest.mock.ANY,
987+
tools=agent.tool_names,
988+
system_prompt=agent.system_prompt,
989+
custom_trace_attributes=agent.trace_attributes,
990+
)
991+
992+
expected_response = AgentResult(
993+
stop_reason="stop", message={"role": "assistant", "content": [{"text": "Agent Response"}]}, metrics={}, state={}
994+
)
995+
996+
# Verify span was ended with the result
997+
mock_tracer.end_agent_span.assert_called_once_with(span=mock_span, response=expected_response)
998+
999+
9551000
@unittest.mock.patch("strands.agent.agent.get_tracer")
9561001
def test_agent_call_creates_and_ends_span_on_exception(mock_get_tracer, mock_model):
9571002
"""Test that __call__ creates and ends a span when an exception occurs."""
@@ -985,6 +1030,42 @@ def test_agent_call_creates_and_ends_span_on_exception(mock_get_tracer, mock_mod
9851030
mock_tracer.end_agent_span.assert_called_once_with(span=mock_span, error=test_exception)
9861031

9871032

1033+
@pytest.mark.asyncio
1034+
@unittest.mock.patch("strands.agent.agent.get_tracer")
1035+
async def test_agent_stream_async_creates_and_ends_span_on_exception(mock_get_tracer, mock_model):
1036+
"""Test that stream_async creates and ends a span when the call succeeds."""
1037+
# Setup mock tracer and span
1038+
mock_tracer = unittest.mock.MagicMock()
1039+
mock_span = unittest.mock.MagicMock()
1040+
mock_tracer.start_agent_span.return_value = mock_span
1041+
mock_get_tracer.return_value = mock_tracer
1042+
1043+
# Define the side effect to simulate callback handler raising an Exception
1044+
test_exception = ValueError("Test exception")
1045+
mock_model.mock_converse.side_effect = test_exception
1046+
1047+
# Create agent and make a call
1048+
agent = Agent(model=mock_model)
1049+
1050+
# Call the agent and catch the exception
1051+
with pytest.raises(ValueError):
1052+
iterator = agent.stream_async("test prompt")
1053+
async for _event in iterator:
1054+
pass # NoOp
1055+
1056+
# Verify span was created
1057+
mock_tracer.start_agent_span.assert_called_once_with(
1058+
prompt="test prompt",
1059+
model_id=unittest.mock.ANY,
1060+
tools=agent.tool_names,
1061+
system_prompt=agent.system_prompt,
1062+
custom_trace_attributes=agent.trace_attributes,
1063+
)
1064+
1065+
# Verify span was ended with the exception
1066+
mock_tracer.end_agent_span.assert_called_once_with(span=mock_span, error=test_exception)
1067+
1068+
9881069
@unittest.mock.patch("strands.agent.agent.get_tracer")
9891070
def test_event_loop_cycle_includes_parent_span(mock_get_tracer, mock_event_loop_cycle, mock_model):
9901071
"""Test that event_loop_cycle is called with the parent span."""

0 commit comments

Comments
 (0)
0