|
9 | 9 | import pytest
|
10 | 10 |
|
11 | 11 | import strands
|
12 |
| -from strands.agent.agent import Agent |
| 12 | +from strands import Agent |
| 13 | +from strands.agent import AgentResult |
13 | 14 | from strands.agent.conversation_manager.null_conversation_manager import NullConversationManager
|
14 | 15 | from strands.agent.conversation_manager.sliding_window_conversation_manager import SlidingWindowConversationManager
|
15 | 16 | from strands.handlers.callback_handler import PrintingCallbackHandler, null_callback_handler
|
@@ -687,8 +688,6 @@ def test_agent_with_callback_handler_none_uses_null_handler():
|
687 | 688 |
|
688 | 689 | @pytest.mark.asyncio
|
689 | 690 | async def test_stream_async_returns_all_events(mock_event_loop_cycle):
|
690 |
| - mock_event_loop_cycle.side_effect = ValueError("Test exception") |
691 |
| - |
692 | 691 | agent = Agent()
|
693 | 692 |
|
694 | 693 | # Define the side effect to simulate callback handler being called multiple times
|
@@ -952,6 +951,52 @@ def test_agent_call_creates_and_ends_span_on_success(mock_get_tracer, mock_model
|
952 | 951 | mock_tracer.end_agent_span.assert_called_once_with(span=mock_span, response=result)
|
953 | 952 |
|
954 | 953 |
|
| 954 | +@pytest.mark.asyncio |
| 955 | +@unittest.mock.patch("strands.agent.agent.get_tracer") |
| 956 | +async def test_agent_stream_async_creates_and_ends_span_on_success(mock_get_tracer, mock_event_loop_cycle): |
| 957 | + """Test that stream_async creates and ends a span when the call succeeds.""" |
| 958 | + # Setup mock tracer and span |
| 959 | + mock_tracer = unittest.mock.MagicMock() |
| 960 | + mock_span = unittest.mock.MagicMock() |
| 961 | + mock_tracer.start_agent_span.return_value = mock_span |
| 962 | + mock_get_tracer.return_value = mock_tracer |
| 963 | + |
| 964 | + # Define the side effect to simulate callback handler being called multiple times |
| 965 | + def call_callback_handler(*args, **kwargs): |
| 966 | + # Extract the callback handler from kwargs |
| 967 | + callback_handler = kwargs.get("callback_handler") |
| 968 | + # Call the callback handler with different data values |
| 969 | + callback_handler(data="First chunk") |
| 970 | + callback_handler(data="Second chunk") |
| 971 | + callback_handler(data="Final chunk", complete=True) |
| 972 | + # Return expected values from event_loop_cycle |
| 973 | + return "stop", {"role": "assistant", "content": [{"text": "Agent Response"}]}, {}, {} |
| 974 | + |
| 975 | + mock_event_loop_cycle.side_effect = call_callback_handler |
| 976 | + |
| 977 | + # Create agent and make a call |
| 978 | + agent = Agent(model=mock_model) |
| 979 | + iterator = agent.stream_async("test prompt") |
| 980 | + async for _event in iterator: |
| 981 | + pass # NoOp |
| 982 | + |
| 983 | + # Verify span was created |
| 984 | + mock_tracer.start_agent_span.assert_called_once_with( |
| 985 | + prompt="test prompt", |
| 986 | + model_id=unittest.mock.ANY, |
| 987 | + tools=agent.tool_names, |
| 988 | + system_prompt=agent.system_prompt, |
| 989 | + custom_trace_attributes=agent.trace_attributes, |
| 990 | + ) |
| 991 | + |
| 992 | + expected_response = AgentResult( |
| 993 | + stop_reason="stop", message={"role": "assistant", "content": [{"text": "Agent Response"}]}, metrics={}, state={} |
| 994 | + ) |
| 995 | + |
| 996 | + # Verify span was ended with the result |
| 997 | + mock_tracer.end_agent_span.assert_called_once_with(span=mock_span, response=expected_response) |
| 998 | + |
| 999 | + |
955 | 1000 | @unittest.mock.patch("strands.agent.agent.get_tracer")
|
956 | 1001 | def test_agent_call_creates_and_ends_span_on_exception(mock_get_tracer, mock_model):
|
957 | 1002 | """Test that __call__ creates and ends a span when an exception occurs."""
|
@@ -985,6 +1030,42 @@ def test_agent_call_creates_and_ends_span_on_exception(mock_get_tracer, mock_mod
|
985 | 1030 | mock_tracer.end_agent_span.assert_called_once_with(span=mock_span, error=test_exception)
|
986 | 1031 |
|
987 | 1032 |
|
| 1033 | +@pytest.mark.asyncio |
| 1034 | +@unittest.mock.patch("strands.agent.agent.get_tracer") |
| 1035 | +async def test_agent_stream_async_creates_and_ends_span_on_exception(mock_get_tracer, mock_model): |
| 1036 | + """Test that stream_async creates and ends a span when the call succeeds.""" |
| 1037 | + # Setup mock tracer and span |
| 1038 | + mock_tracer = unittest.mock.MagicMock() |
| 1039 | + mock_span = unittest.mock.MagicMock() |
| 1040 | + mock_tracer.start_agent_span.return_value = mock_span |
| 1041 | + mock_get_tracer.return_value = mock_tracer |
| 1042 | + |
| 1043 | + # Define the side effect to simulate callback handler raising an Exception |
| 1044 | + test_exception = ValueError("Test exception") |
| 1045 | + mock_model.mock_converse.side_effect = test_exception |
| 1046 | + |
| 1047 | + # Create agent and make a call |
| 1048 | + agent = Agent(model=mock_model) |
| 1049 | + |
| 1050 | + # Call the agent and catch the exception |
| 1051 | + with pytest.raises(ValueError): |
| 1052 | + iterator = agent.stream_async("test prompt") |
| 1053 | + async for _event in iterator: |
| 1054 | + pass # NoOp |
| 1055 | + |
| 1056 | + # Verify span was created |
| 1057 | + mock_tracer.start_agent_span.assert_called_once_with( |
| 1058 | + prompt="test prompt", |
| 1059 | + model_id=unittest.mock.ANY, |
| 1060 | + tools=agent.tool_names, |
| 1061 | + system_prompt=agent.system_prompt, |
| 1062 | + custom_trace_attributes=agent.trace_attributes, |
| 1063 | + ) |
| 1064 | + |
| 1065 | + # Verify span was ended with the exception |
| 1066 | + mock_tracer.end_agent_span.assert_called_once_with(span=mock_span, error=test_exception) |
| 1067 | + |
| 1068 | + |
988 | 1069 | @unittest.mock.patch("strands.agent.agent.get_tracer")
|
989 | 1070 | def test_event_loop_cycle_includes_parent_span(mock_get_tracer, mock_event_loop_cycle, mock_model):
|
990 | 1071 | """Test that event_loop_cycle is called with the parent span."""
|
|
0 commit comments