From c28737c8aeaaf6d47e3f423ec4da967a72314eab Mon Sep 17 00:00:00 2001 From: Stefano Amorelli Date: Wed, 11 Jun 2025 01:58:39 +0300 Subject: [PATCH 01/22] feat(summarizing_conversation_manager): implement summarization strategy (#112) --- src/strands/agent/__init__.py | 8 +- .../agent/conversation_manager/__init__.py | 10 +- .../summarizing_conversation_manager.py | 222 +++++++ ...rizing_conversation_manager_integration.py | 374 ++++++++++++ .../test_summarizing_conversation_manager.py | 566 ++++++++++++++++++ 5 files changed, 1178 insertions(+), 2 deletions(-) create mode 100644 src/strands/agent/conversation_manager/summarizing_conversation_manager.py create mode 100644 tests-integ/test_summarizing_conversation_manager_integration.py create mode 100644 tests/strands/agent/test_summarizing_conversation_manager.py diff --git a/src/strands/agent/__init__.py b/src/strands/agent/__init__.py index 4d2fa1fe..6618d332 100644 --- a/src/strands/agent/__init__.py +++ b/src/strands/agent/__init__.py @@ -8,7 +8,12 @@ from .agent import Agent from .agent_result import AgentResult -from .conversation_manager import ConversationManager, NullConversationManager, SlidingWindowConversationManager +from .conversation_manager import ( + ConversationManager, + NullConversationManager, + SlidingWindowConversationManager, + SummarizingConversationManager, +) __all__ = [ "Agent", @@ -16,4 +21,5 @@ "ConversationManager", "NullConversationManager", "SlidingWindowConversationManager", + "SummarizingConversationManager", ] diff --git a/src/strands/agent/conversation_manager/__init__.py b/src/strands/agent/conversation_manager/__init__.py index 68541877..c5962321 100644 --- a/src/strands/agent/conversation_manager/__init__.py +++ b/src/strands/agent/conversation_manager/__init__.py @@ -6,6 +6,8 @@ - NullConversationManager: A no-op implementation that does not modify conversation history - SlidingWindowConversationManager: An implementation that maintains a sliding window of messages to control context size while preserving conversation coherence +- SummarizingConversationManager: An implementation that summarizes older context instead + of simply trimming it Conversation managers help control memory usage and context length while maintaining relevant conversation state, which is critical for effective agent interactions. 
@@ -14,5 +16,11 @@ from .conversation_manager import ConversationManager from .null_conversation_manager import NullConversationManager from .sliding_window_conversation_manager import SlidingWindowConversationManager +from .summarizing_conversation_manager import SummarizingConversationManager -__all__ = ["ConversationManager", "NullConversationManager", "SlidingWindowConversationManager"] +__all__ = [ + "ConversationManager", + "NullConversationManager", + "SlidingWindowConversationManager", + "SummarizingConversationManager", +] diff --git a/src/strands/agent/conversation_manager/summarizing_conversation_manager.py b/src/strands/agent/conversation_manager/summarizing_conversation_manager.py new file mode 100644 index 00000000..a6b112dd --- /dev/null +++ b/src/strands/agent/conversation_manager/summarizing_conversation_manager.py @@ -0,0 +1,222 @@ +"""Summarizing conversation history management with configurable options.""" + +import logging +from typing import TYPE_CHECKING, List, Optional + +from ...types.content import Message +from ...types.exceptions import ContextWindowOverflowException +from .conversation_manager import ConversationManager + +if TYPE_CHECKING: + from ..agent import Agent + + +logger = logging.getLogger(__name__) + + +DEFAULT_SUMMARIZATION_PROMPT = """You are a conversation summarizer. Provide a concise summary of the conversation \ +history. + +Format Requirements: +- You MUST create a structured and concise summary in bullet-point format. +- You MUST NOT respond conversationally. +- You MUST NOT address the user directly. + +Task: +Your task is to create a structured summary document: +- It MUST contain bullet points with key topics and questions covered +- It MUST contain bullet points for all significant tools executed and their results +- It MUST contain bullet points for any code or technical information shared +- It MUST contain a section of key insights gained +- It MUST format the summary in the third person + +Example format: + +## Conversation Summary +* Topic 1: Key information +* Topic 2: Key information +* +## Tools Executed +* Tool X: Result Y""" + + +class SummarizingConversationManager(ConversationManager): + """Implements a summarizing window manager. + + This manager provides a configurable option to summarize older context instead of + simply trimming it, helping preserve important information while staying within + context limits. + """ + + def __init__( + self, + summary_ratio: float = 0.3, + preserve_recent_messages: int = 10, + summarization_agent: Optional["Agent"] = None, + summarization_system_prompt: Optional[str] = None, + ): + """Initialize the summarizing conversation manager. + + Args: + summary_ratio: Ratio of messages to summarize vs keep when context overflow occurs. + Value between 0.1 and 0.8. Defaults to 0.3 (summarize 30% of oldest messages). + preserve_recent_messages: Minimum number of recent messages to always keep. + Defaults to 10 messages. + summarization_agent: Optional agent to use for summarization instead of the parent agent. + If provided, this agent can use tools as part of the summarization process. + summarization_system_prompt: Optional system prompt override for summarization. + If None, uses the default summarization prompt. + """ + if summarization_agent is not None and summarization_system_prompt is not None: + raise ValueError( + "Cannot provide both summarization_agent and summarization_system_prompt. " + "Agents come with their own system prompt." 
+ ) + + self.summary_ratio = max(0.1, min(0.8, summary_ratio)) + self.preserve_recent_messages = preserve_recent_messages + self.summarization_agent = summarization_agent + self.summarization_system_prompt = summarization_system_prompt + + def apply_management(self, agent: "Agent") -> None: + """Apply management strategy to conversation history. + + For the summarizing conversation manager, no proactive management is performed. + Summarization only occurs when there's a context overflow that triggers reduce_context. + + Args: + agent: The agent whose conversation history will be managed. + The agent's messages list is modified in-place. + """ + # No proactive management - summarization only happens on context overflow + pass + + def reduce_context(self, agent: "Agent", e: Optional[Exception] = None) -> None: + """Reduce context using summarization. + + Args: + agent: The agent whose conversation history will be reduced. + The agent's messages list is modified in-place. + e: The exception that triggered the context reduction, if any. + + Raises: + ContextWindowOverflowException: If the context cannot be summarized. + """ + try: + # Calculate how many messages to summarize + messages_to_summarize_count = max(1, int(len(agent.messages) * self.summary_ratio)) + + # Ensure we don't summarize recent messages + messages_to_summarize_count = min( + messages_to_summarize_count, len(agent.messages) - self.preserve_recent_messages + ) + + if messages_to_summarize_count <= 0: + raise ContextWindowOverflowException("Cannot summarize: insufficient messages for summarization") + + # Adjust split point to avoid breaking ToolUse/ToolResult pairs + messages_to_summarize_count = self._adjust_split_point_for_tool_pairs( + agent.messages, messages_to_summarize_count + ) + + if messages_to_summarize_count <= 0: + raise ContextWindowOverflowException("Cannot summarize: insufficient messages for summarization") + + # Extract messages to summarize + messages_to_summarize = agent.messages[:messages_to_summarize_count] + remaining_messages = agent.messages[messages_to_summarize_count:] + + # Generate summary + summary_message = self._generate_summary(messages_to_summarize, agent) + + # Replace the summarized messages with the summary + agent.messages[:] = [summary_message] + remaining_messages + + except Exception as summarization_error: + logger.error("Summarization failed: %s", summarization_error) + raise summarization_error from e + + def _generate_summary(self, messages: List[Message], agent: "Agent") -> Message: + """Generate a summary of the provided messages. + + Args: + messages: The messages to summarize. + agent: The agent instance to use for summarization. + + Returns: + A message containing the conversation summary. + + Raises: + Exception: If summary generation fails. 
+ """ + # Choose which agent to use for summarization + summarization_agent = self.summarization_agent if self.summarization_agent is not None else agent + + # Save original system prompt and messages to restore later + original_system_prompt = summarization_agent.system_prompt + original_messages = summarization_agent.messages.copy() + + try: + # Only override system prompt if no agent was provided during initialization + if self.summarization_agent is None: + # Use custom system prompt if provided, otherwise use default + system_prompt = ( + self.summarization_system_prompt + if self.summarization_system_prompt is not None + else DEFAULT_SUMMARIZATION_PROMPT + ) + # Temporarily set the system prompt for summarization + summarization_agent.system_prompt = system_prompt + summarization_agent.messages = messages + + # Use the agent to generate summary with rich content (can use tools if needed) + result = summarization_agent("Please summarize this conversation.") + + return result.message + + finally: + # Restore original agent state + summarization_agent.system_prompt = original_system_prompt + summarization_agent.messages = original_messages + + def _adjust_split_point_for_tool_pairs(self, messages: List[Message], split_point: int) -> int: + """Adjust the split point to avoid breaking ToolUse/ToolResult pairs. + + Uses the same logic as SlidingWindowConversationManager for consistency. + + Args: + messages: The full list of messages. + split_point: The initially calculated split point. + + Returns: + The adjusted split point that doesn't break ToolUse/ToolResult pairs. + + Raises: + ContextWindowOverflowException: If no valid split point can be found. + """ + if split_point > len(messages): + raise ContextWindowOverflowException("Split point exceeds message array length") + + if split_point == len(messages): + return split_point + + # Find the next valid split_point + while split_point < len(messages): + if ( + # Oldest message cannot be a toolResult because it needs a toolUse preceding it + any("toolResult" in content for content in messages[split_point]["content"]) + or ( + # Oldest message can be a toolUse only if a toolResult immediately follows it. + any("toolUse" in content for content in messages[split_point]["content"]) + and split_point + 1 < len(messages) + and not any("toolResult" in content for content in messages[split_point + 1]["content"]) + ) + ): + split_point += 1 + else: + break + else: + # If we didn't find a valid split_point, then we throw + raise ContextWindowOverflowException("Unable to trim conversation context!") + + return split_point diff --git a/tests-integ/test_summarizing_conversation_manager_integration.py b/tests-integ/test_summarizing_conversation_manager_integration.py new file mode 100644 index 00000000..5dcf4944 --- /dev/null +++ b/tests-integ/test_summarizing_conversation_manager_integration.py @@ -0,0 +1,374 @@ +"""Integration tests for SummarizingConversationManager with actual AI models. + +These tests validate the end-to-end functionality of the SummarizingConversationManager +by testing with real AI models and API calls. They ensure that: + +1. **Real summarization** - Tests that actual model-generated summaries work correctly +2. **Context overflow handling** - Validates real context overflow scenarios and recovery +3. **Tool preservation** - Ensures ToolUse/ToolResult pairs survive real summarization +4. **Message structure** - Verifies real model outputs maintain proper message structure +5. 
**Agent integration** - Tests that conversation managers work with real Agent workflows + +These tests require API keys (`ANTHROPIC_API_KEY`) and make real API calls, so they should be run sparingly +and may be skipped in CI environments without proper credentials. +""" + +import os + +import pytest + +import strands +from strands import Agent +from strands.agent.conversation_manager import SummarizingConversationManager +from strands.models.anthropic import AnthropicModel + + +@pytest.fixture +def model(): + """Real Anthropic model for integration testing.""" + return AnthropicModel( + client_args={ + "api_key": os.getenv("ANTHROPIC_API_KEY"), + }, + model_id="claude-3-haiku-20240307", # Using Haiku for faster/cheaper tests + max_tokens=1024, + ) + + +@pytest.fixture +def summarization_model(): + """Separate model instance for summarization to test dedicated agent functionality.""" + return AnthropicModel( + client_args={ + "api_key": os.getenv("ANTHROPIC_API_KEY"), + }, + model_id="claude-3-haiku-20240307", + max_tokens=512, + ) + + +@pytest.fixture +def tools(): + """Real tools for testing tool preservation during summarization.""" + + @strands.tool + def get_current_time() -> str: + """Get the current time.""" + return "2024-01-15 14:30:00" + + @strands.tool + def get_weather(city: str) -> str: + """Get weather information for a city.""" + return f"The weather in {city} is sunny and 72°F" + + @strands.tool + def calculate_sum(a: int, b: int) -> int: + """Calculate the sum of two numbers.""" + return a + b + + return [get_current_time, get_weather, calculate_sum] + + +@pytest.mark.skipif("ANTHROPIC_API_KEY" not in os.environ, reason="ANTHROPIC_API_KEY environment variable missing") +def test_summarization_with_context_overflow(model): + """Test that summarization works when context overflow occurs.""" + # Mock conversation data to avoid API calls + greeting_response = """ + Hello! I'm here to help you test your conversation manager. What specifically would you like + me to do as part of this test? I can respond to different types of prompts, maintain context + throughout our conversation, or demonstrate other capabilities of the AI assistant. Just let + me know what aspects you'd like to evaluate. + """.strip() + + computer_history_response = """ + # History of Computers + + The history of computers spans many centuries, evolving from simple calculating tools to + the powerful machines we use today. 
+ + ## Early Computing Devices + - **Ancient abacus** (3000 BCE): One of the earliest computing devices used for arithmetic calculations + - **Pascaline** (1642): Mechanical calculator invented by Blaise Pascal + - **Difference Engine** (1822): Designed by Charles Babbage to compute polynomial functions + - **Analytical Engine**: Babbage's more ambitious design, considered the first general-purpose computer concept + - **Hollerith's Tabulating Machine** (1890s): Used punch cards to process data for the US Census + + ## Early Electronic Computers + - **ENIAC** (1945): First general-purpose electronic computer, weighed 30 tons + - **EDVAC** (1949): Introduced the stored program concept + - **UNIVAC I** (1951): First commercial computer in the United States + """.strip() + + first_computers_response = """ + # The First Computers + + Early computers were dramatically different from today's machines in almost every aspect: + + ## Physical Characteristics + - **Enormous size**: Room-filling or even building-filling machines + - **ENIAC** (1945) weighed about 30 tons, occupied 1,800 square feet + - Consisted of large metal frames or cabinets filled with components + - Required special cooling systems due to excessive heat generation + + ## Technology and Components + - **Vacuum tubes**: Thousands of fragile glass tubes served as switches and amplifiers + - ENIAC contained over 17,000 vacuum tubes + - Generated tremendous heat and frequently failed + - **Memory**: Limited storage using delay lines, cathode ray tubes, or magnetic drums + """.strip() + + messages = [ + {"role": "user", "content": [{"text": "Hello, I'm testing a conversation manager."}]}, + {"role": "assistant", "content": [{"text": greeting_response}]}, + {"role": "user", "content": [{"text": "Can you tell me about the history of computers?"}]}, + {"role": "assistant", "content": [{"text": computer_history_response}]}, + {"role": "user", "content": [{"text": "What were the first computers like?"}]}, + {"role": "assistant", "content": [{"text": first_computers_response}]}, + ] + + # Create agent with very aggressive summarization settings and pre-built conversation + agent = Agent( + model=model, + conversation_manager=SummarizingConversationManager( + summary_ratio=0.5, # Summarize 50% of messages + preserve_recent_messages=2, # Keep only 2 recent messages + ), + load_tools_from_directory=False, + messages=messages, + ) + + # Should have the pre-built conversation history + initial_message_count = len(agent.messages) + assert initial_message_count == 6 # 3 user + 3 assistant messages + + # Store the last 2 messages before summarization to verify they're preserved + messages_before_summary = agent.messages[-2:].copy() + + # Now manually trigger context reduction to test summarization + agent.conversation_manager.reduce_context(agent) + + # Verify summarization occurred + assert len(agent.messages) < initial_message_count + # Should have: 1 summary + remaining messages + # With 6 messages, summary_ratio=0.5, preserve_recent_messages=2: + # messages_to_summarize = min(6 * 0.5, 6 - 2) = min(3, 4) = 3 + # So we summarize 3 messages, leaving 3 remaining + 1 summary = 4 total + expected_total_messages = 4 + assert len(agent.messages) == expected_total_messages + + # First message should be the summary (assistant message) + summary_message = agent.messages[0] + assert summary_message["role"] == "assistant" + assert len(summary_message["content"]) > 0 + + # Verify the summary contains actual text content + summary_content = None + for 
content_block in summary_message["content"]: + if "text" in content_block: + summary_content = content_block["text"] + break + + assert summary_content is not None + assert len(summary_content) > 50 # Should be a substantial summary + + # Recent messages should be preserved - verify they're exactly the same + recent_messages = agent.messages[-2:] # Last 2 messages should be preserved + assert len(recent_messages) == 2 + assert recent_messages == messages_before_summary, "The last 2 messages should be preserved exactly as they were" + + # Agent should still be functional after summarization + post_summary_result = agent("That's very interesting, thank you!") + assert post_summary_result.message["role"] == "assistant" + + +@pytest.mark.skipif("ANTHROPIC_API_KEY" not in os.environ, reason="ANTHROPIC_API_KEY environment variable missing") +def test_tool_preservation_during_summarization(model, tools): + """Test that ToolUse/ToolResult pairs are preserved during summarization.""" + agent = Agent( + model=model, + tools=tools, + conversation_manager=SummarizingConversationManager( + summary_ratio=0.6, # Aggressive summarization + preserve_recent_messages=3, + ), + load_tools_from_directory=False, + ) + + # Mock conversation with tool usage to avoid API calls and speed up tests + greeting_text = """ + Hello! I'd be happy to help you with calculations. I have access to tools that can + help with math, time, and weather information. What would you like me to calculate for you? + """.strip() + + weather_response = "The weather in San Francisco is sunny and 72°F. Perfect weather for being outside!" + + tool_conversation_data = [ + # Initial greeting exchange + {"role": "user", "content": [{"text": "Hello, can you help me with some calculations?"}]}, + {"role": "assistant", "content": [{"text": greeting_text}]}, + # Time query with tool use/result pair + {"role": "user", "content": [{"text": "What's the current time?"}]}, + { + "role": "assistant", + "content": [{"toolUse": {"toolUseId": "time_001", "name": "get_current_time", "input": {}}}], + }, + { + "role": "user", + "content": [ + { + "toolResult": { + "toolUseId": "time_001", + "content": [{"text": "2024-01-15 14:30:00"}], + "status": "success", + } + } + ], + }, + {"role": "assistant", "content": [{"text": "The current time is 2024-01-15 14:30:00."}]}, + # Math calculation with tool use/result pair + {"role": "user", "content": [{"text": "What's 25 + 37?"}]}, + { + "role": "assistant", + "content": [{"toolUse": {"toolUseId": "calc_001", "name": "calculate_sum", "input": {"a": 25, "b": 37}}}], + }, + { + "role": "user", + "content": [{"toolResult": {"toolUseId": "calc_001", "content": [{"text": "62"}], "status": "success"}}], + }, + {"role": "assistant", "content": [{"text": "25 + 37 = 62"}]}, + # Weather query with tool use/result pair + {"role": "user", "content": [{"text": "What's the weather like in San Francisco?"}]}, + { + "role": "assistant", + "content": [ + {"toolUse": {"toolUseId": "weather_001", "name": "get_weather", "input": {"city": "San Francisco"}}} + ], + }, + { + "role": "user", + "content": [ + { + "toolResult": { + "toolUseId": "weather_001", + "content": [{"text": "The weather in San Francisco is sunny and 72°F"}], + "status": "success", + } + } + ], + }, + {"role": "assistant", "content": [{"text": weather_response}]}, + ] + + # Add all the mocked conversation messages to avoid real API calls + agent.messages.extend(tool_conversation_data) + + # Force summarization + agent.conversation_manager.reduce_context(agent) + + # 
Verify tool pairs are still balanced after summarization + post_summary_tool_use_count = 0 + post_summary_tool_result_count = 0 + + for message in agent.messages: + for content in message.get("content", []): + if "toolUse" in content: + post_summary_tool_use_count += 1 + if "toolResult" in content: + post_summary_tool_result_count += 1 + + # Tool uses and results should be balanced (no orphaned tools) + assert post_summary_tool_use_count == post_summary_tool_result_count, ( + "Tool use and tool result counts should be balanced after summarization" + ) + + # Agent should still be able to use tools after summarization + agent("Calculate 15 + 28 for me.") + + # Should have triggered the calculate_sum tool + found_calculation = False + for message in agent.messages[-2:]: # Check recent messages + for content in message.get("content", []): + if "toolResult" in content and "43" in str(content): # 15 + 28 = 43 + found_calculation = True + break + + assert found_calculation, "Tool should still work after summarization" + + +@pytest.mark.skipif("ANTHROPIC_API_KEY" not in os.environ, reason="ANTHROPIC_API_KEY environment variable missing") +def test_dedicated_summarization_agent(model, summarization_model): + """Test that a dedicated summarization agent works correctly.""" + # Create a dedicated summarization agent + summarization_agent = Agent( + model=summarization_model, + system_prompt="You are a conversation summarizer. Create concise, structured summaries.", + load_tools_from_directory=False, + ) + + # Create main agent with dedicated summarization agent + agent = Agent( + model=model, + conversation_manager=SummarizingConversationManager( + summary_ratio=0.5, + preserve_recent_messages=2, + summarization_agent=summarization_agent, + ), + load_tools_from_directory=False, + ) + + # Mock conversation data for space exploration topic + space_intro_response = """ + Space exploration has been one of humanity's greatest achievements, beginning with early + satellite launches in the 1950s and progressing to human spaceflight, moon landings, and now + commercial space ventures. + """.strip() + + space_milestones_response = """ + Key milestones include Sputnik 1 (1957), Yuri Gagarin's first human spaceflight (1961), + the Apollo 11 moon landing (1969), the Space Shuttle program, and the International Space + Station construction. + """.strip() + + apollo_missions_response = """ + The Apollo program was NASA's lunar exploration program from 1961-1975. Apollo 11 achieved + the first moon landing in 1969 with Neil Armstrong and Buzz Aldrin, followed by five more + successful lunar missions through Apollo 17. + """.strip() + + spacex_response = """ + SpaceX has revolutionized space travel with reusable rockets, reducing launch costs dramatically. + They've achieved crew transportation to the ISS, satellite deployments, and are developing + Starship for Mars missions. 
+ """.strip() + + conversation_pairs = [ + ("I'm interested in learning about space exploration.", space_intro_response), + ("What were the key milestones in space exploration?", space_milestones_response), + ("Tell me about the Apollo missions.", apollo_missions_response), + ("What about modern space exploration with SpaceX?", spacex_response), + ] + + # Manually build the conversation history to avoid real API calls + for user_input, assistant_response in conversation_pairs: + agent.messages.append({"role": "user", "content": [{"text": user_input}]}) + agent.messages.append({"role": "assistant", "content": [{"text": assistant_response}]}) + + # Force summarization + original_length = len(agent.messages) + agent.conversation_manager.reduce_context(agent) + + # Verify summarization occurred + assert len(agent.messages) < original_length + + # Get the summary message + summary_message = agent.messages[0] + assert summary_message["role"] == "assistant" + + # Extract summary text + summary_text = None + for content in summary_message["content"]: + if "text" in content: + summary_text = content["text"] + break + + assert summary_text diff --git a/tests/strands/agent/test_summarizing_conversation_manager.py b/tests/strands/agent/test_summarizing_conversation_manager.py new file mode 100644 index 00000000..9952203e --- /dev/null +++ b/tests/strands/agent/test_summarizing_conversation_manager.py @@ -0,0 +1,566 @@ +from typing import TYPE_CHECKING, cast +from unittest.mock import Mock, patch + +import pytest + +from strands.agent.conversation_manager.summarizing_conversation_manager import SummarizingConversationManager +from strands.types.content import Messages +from strands.types.exceptions import ContextWindowOverflowException + +if TYPE_CHECKING: + from strands.agent.agent import Agent + + +class MockAgent: + """Mock agent for testing summarization.""" + + def __init__(self, summary_response="This is a summary of the conversation."): + self.summary_response = summary_response + self.system_prompt = None + self.messages = [] + self.model = Mock() + self.call_tracker = Mock() + + def __call__(self, prompt): + """Mock agent call that returns a summary.""" + self.call_tracker(prompt) + result = Mock() + result.message = {"role": "assistant", "content": [{"text": self.summary_response}]} + return result + + +def create_mock_agent(summary_response="This is a summary of the conversation.") -> "Agent": + """Factory function that returns a properly typed MockAgent.""" + return cast("Agent", MockAgent(summary_response)) + + +@pytest.fixture +def mock_agent(): + """Fixture for mock agent.""" + return create_mock_agent() + + +@pytest.fixture +def summarizing_manager(): + """Fixture for summarizing conversation manager with default settings.""" + return SummarizingConversationManager( + summary_ratio=0.5, + preserve_recent_messages=2, + ) + + +def test_init_default_values(): + """Test initialization with default values.""" + manager = SummarizingConversationManager() + + assert manager.summarization_agent is None + assert manager.summary_ratio == 0.3 + assert manager.preserve_recent_messages == 10 + + +def test_init_clamps_summary_ratio(): + """Test that summary_ratio is clamped to valid range.""" + # Test lower bound + manager = SummarizingConversationManager(summary_ratio=0.05) + assert manager.summary_ratio == 0.1 + + # Test upper bound + manager = SummarizingConversationManager(summary_ratio=0.95) + assert manager.summary_ratio == 0.8 + + +def test_reduce_context_raises_when_no_agent(): + """Test that 
reduce_context raises exception when agent has no messages.""" + manager = SummarizingConversationManager() + + # Create a mock agent with no messages + mock_agent = Mock() + empty_messages: Messages = [] + mock_agent.messages = empty_messages + + with pytest.raises(ContextWindowOverflowException, match="insufficient messages for summarization"): + manager.reduce_context(mock_agent) + + +def test_reduce_context_with_summarization(summarizing_manager, mock_agent): + """Test reduce_context with summarization enabled.""" + test_messages: Messages = [ + {"role": "user", "content": [{"text": "Message 1"}]}, + {"role": "assistant", "content": [{"text": "Response 1"}]}, + {"role": "user", "content": [{"text": "Message 2"}]}, + {"role": "assistant", "content": [{"text": "Response 2"}]}, + {"role": "user", "content": [{"text": "Message 3"}]}, + {"role": "assistant", "content": [{"text": "Response 3"}]}, + ] + mock_agent.messages = test_messages + + summarizing_manager.reduce_context(mock_agent) + + # Should have: 1 summary message + 2 preserved recent messages + remaining from summarization + assert len(mock_agent.messages) == 4 + + # First message should be the summary + assert mock_agent.messages[0]["role"] == "assistant" + first_content = mock_agent.messages[0]["content"][0] + assert "text" in first_content and "This is a summary of the conversation." in first_content["text"] + + # Recent messages should be preserved + assert "Message 3" in str(mock_agent.messages[-2]["content"]) + assert "Response 3" in str(mock_agent.messages[-1]["content"]) + + +def test_reduce_context_too_few_messages_raises_exception(summarizing_manager, mock_agent): + """Test that reduce_context raises exception when there are too few messages to summarize effectively.""" + # Create a scenario where calculation results in 0 messages to summarize + manager = SummarizingConversationManager( + summary_ratio=0.1, # Very small ratio + preserve_recent_messages=5, # High preservation + ) + + insufficient_test_messages: Messages = [ + {"role": "user", "content": [{"text": "Message 1"}]}, + {"role": "assistant", "content": [{"text": "Response 1"}]}, + {"role": "user", "content": [{"text": "Message 2"}]}, + ] + mock_agent.messages = insufficient_test_messages # 5 messages, preserve_recent_messages=5, so nothing to summarize + + with pytest.raises(ContextWindowOverflowException, match="insufficient messages for summarization"): + manager.reduce_context(mock_agent) + + +def test_reduce_context_insufficient_messages_for_summarization(mock_agent): + """Test reduce_context when there aren't enough messages to summarize.""" + manager = SummarizingConversationManager( + summary_ratio=0.5, + preserve_recent_messages=3, + ) + + insufficient_messages: Messages = [ + {"role": "user", "content": [{"text": "Message 1"}]}, + {"role": "assistant", "content": [{"text": "Response 1"}]}, + {"role": "user", "content": [{"text": "Message 2"}]}, + ] + mock_agent.messages = insufficient_messages + + # This should raise an exception since there aren't enough messages to summarize + with pytest.raises(ContextWindowOverflowException, match="insufficient messages for summarization"): + manager.reduce_context(mock_agent) + + +def test_reduce_context_raises_on_summarization_failure(): + """Test that reduce_context raises exception when summarization fails.""" + # Create an agent that will fail + failing_agent = Mock() + failing_agent.side_effect = Exception("Agent failed") + failing_agent.system_prompt = None + failing_agent_messages: Messages = [ + {"role": 
"user", "content": [{"text": "Message 1"}]}, + {"role": "assistant", "content": [{"text": "Response 1"}]}, + {"role": "user", "content": [{"text": "Message 2"}]}, + {"role": "assistant", "content": [{"text": "Response 2"}]}, + ] + failing_agent.messages = failing_agent_messages + + manager = SummarizingConversationManager( + summary_ratio=0.5, + preserve_recent_messages=1, + ) + + with patch("strands.agent.conversation_manager.summarizing_conversation_manager.logger") as mock_logger: + with pytest.raises(Exception, match="Agent failed"): + manager.reduce_context(failing_agent) + + # Should log the error + mock_logger.error.assert_called_once() + + +def test_generate_summary(summarizing_manager, mock_agent): + """Test the _generate_summary method.""" + test_messages: Messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + {"role": "assistant", "content": [{"text": "Hi there"}]}, + ] + + summary = summarizing_manager._generate_summary(test_messages, mock_agent) + + summary_content = summary["content"][0] + assert "text" in summary_content and summary_content["text"] == "This is a summary of the conversation." + + +def test_generate_summary_with_tool_content(summarizing_manager, mock_agent): + """Test summary generation with tool use and results.""" + tool_messages: Messages = [ + {"role": "user", "content": [{"text": "Use a tool"}]}, + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "123", "name": "test_tool", "input": {}}}]}, + { + "role": "user", + "content": [ + {"toolResult": {"toolUseId": "123", "content": [{"text": "Tool output"}], "status": "success"}} + ], + }, + ] + + summary = summarizing_manager._generate_summary(tool_messages, mock_agent) + + summary_content = summary["content"][0] + assert "text" in summary_content and summary_content["text"] == "This is a summary of the conversation." 
+ + +def test_generate_summary_raises_on_agent_failure(): + """Test that _generate_summary raises exception when agent fails.""" + failing_agent = Mock() + failing_agent.side_effect = Exception("Agent failed") + failing_agent.system_prompt = None + empty_failing_messages: Messages = [] + failing_agent.messages = empty_failing_messages + + manager = SummarizingConversationManager() + + messages: Messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + {"role": "assistant", "content": [{"text": "Hi there"}]}, + ] + + # Should raise the exception from the agent + with pytest.raises(Exception, match="Agent failed"): + manager._generate_summary(messages, failing_agent) + + +def test_adjust_split_point_for_tool_pairs(summarizing_manager): + """Test that the split point is adjusted to avoid breaking ToolUse/ToolResult pairs.""" + messages: Messages = [ + {"role": "user", "content": [{"text": "Message 1"}]}, + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "123", "name": "test_tool", "input": {}}}]}, + { + "role": "user", + "content": [ + { + "toolResult": { + "toolUseId": "123", + "content": [{"text": "Tool output"}], + "status": "success", + } + } + ], + }, + {"role": "assistant", "content": [{"text": "Response after tool"}]}, + ] + + # If we try to split at message 2 (the ToolResult), it should move forward to message 3 + adjusted_split = summarizing_manager._adjust_split_point_for_tool_pairs(messages, 2) + assert adjusted_split == 3 # Should move to after the ToolResult + + # If we try to split at message 3, it should be fine (no tool issues) + adjusted_split = summarizing_manager._adjust_split_point_for_tool_pairs(messages, 3) + assert adjusted_split == 3 + + # If we try to split at message 1 (toolUse with following toolResult), it should be valid + adjusted_split = summarizing_manager._adjust_split_point_for_tool_pairs(messages, 1) + assert adjusted_split == 1 # Should be valid because toolResult follows + + +def test_apply_management_no_op(summarizing_manager, mock_agent): + """Test apply_management does not modify messages (no-op behavior).""" + apply_test_messages: Messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + {"role": "assistant", "content": [{"text": "Hi"}]}, + {"role": "user", "content": [{"text": "More messages"}]}, + {"role": "assistant", "content": [{"text": "Even more"}]}, + ] + mock_agent.messages = apply_test_messages + original_messages = mock_agent.messages.copy() + + summarizing_manager.apply_management(mock_agent) + + # Should never modify messages - summarization only happens on context overflow + assert mock_agent.messages == original_messages + + +def test_init_with_custom_parameters(): + """Test initialization with custom parameters.""" + mock_agent = create_mock_agent() + + manager = SummarizingConversationManager( + summary_ratio=0.4, + preserve_recent_messages=5, + summarization_agent=mock_agent, + ) + assert manager.summary_ratio == 0.4 + assert manager.preserve_recent_messages == 5 + assert manager.summarization_agent == mock_agent + assert manager.summarization_system_prompt is None + + +def test_init_with_both_agent_and_prompt_raises_error(): + """Test that providing both agent and system prompt raises ValueError.""" + mock_agent = create_mock_agent() + custom_prompt = "Custom summarization prompt" + + with pytest.raises(ValueError, match="Cannot provide both summarization_agent and summarization_system_prompt"): + SummarizingConversationManager( + summarization_agent=mock_agent, + 
summarization_system_prompt=custom_prompt, + ) + + +def test_uses_summarization_agent_when_provided(): + """Test that summarization_agent is used when provided.""" + summary_agent = create_mock_agent("Custom summary from dedicated agent") + manager = SummarizingConversationManager(summarization_agent=summary_agent) + + messages: Messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + {"role": "assistant", "content": [{"text": "Hi there"}]}, + ] + + parent_agent = create_mock_agent("Parent agent summary") + summary = manager._generate_summary(messages, parent_agent) + + # Should use the dedicated summarization agent, not the parent agent + summary_content = summary["content"][0] + assert "text" in summary_content and summary_content["text"] == "Custom summary from dedicated agent" + + # Assert that the summarization agent was called + summary_agent.call_tracker.assert_called_once() + + +def test_uses_parent_agent_when_no_summarization_agent(): + """Test that parent agent is used when no summarization_agent is provided.""" + manager = SummarizingConversationManager() + + messages: Messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + {"role": "assistant", "content": [{"text": "Hi there"}]}, + ] + + parent_agent = create_mock_agent("Parent agent summary") + summary = manager._generate_summary(messages, parent_agent) + + # Should use the parent agent + summary_content = summary["content"][0] + assert "text" in summary_content and summary_content["text"] == "Parent agent summary" + + # Assert that the parent agent was called + parent_agent.call_tracker.assert_called_once() + + +def test_uses_custom_system_prompt(): + """Test that custom system prompt is used when provided.""" + custom_prompt = "Custom system prompt for summarization" + manager = SummarizingConversationManager(summarization_system_prompt=custom_prompt) + mock_agent = create_mock_agent() + + messages: Messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + {"role": "assistant", "content": [{"text": "Hi there"}]}, + ] + + # Capture the agent's system prompt changes + original_prompt = mock_agent.system_prompt + manager._generate_summary(messages, mock_agent) + + # The agent's system prompt should be restored after summarization + assert mock_agent.system_prompt == original_prompt + + +def test_agent_state_restoration(): + """Test that agent state is properly restored after summarization.""" + manager = SummarizingConversationManager() + mock_agent = create_mock_agent() + + # Set initial state + original_system_prompt = "Original system prompt" + original_messages: Messages = [{"role": "user", "content": [{"text": "Original message"}]}] + mock_agent.system_prompt = original_system_prompt + mock_agent.messages = original_messages.copy() + + messages: Messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + {"role": "assistant", "content": [{"text": "Hi there"}]}, + ] + + manager._generate_summary(messages, mock_agent) + + # State should be restored + assert mock_agent.system_prompt == original_system_prompt + assert mock_agent.messages == original_messages + + +def test_agent_state_restoration_on_exception(): + """Test that agent state is restored even when summarization fails.""" + manager = SummarizingConversationManager() + + # Create an agent that fails during summarization + mock_agent = Mock() + mock_agent.system_prompt = "Original prompt" + agent_messages: Messages = [{"role": "user", "content": [{"text": "Original"}]}] + mock_agent.messages = agent_messages + 
mock_agent.side_effect = Exception("Summarization failed") + + messages: Messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + {"role": "assistant", "content": [{"text": "Hi there"}]}, + ] + + # Should restore state even on exception + with pytest.raises(Exception, match="Summarization failed"): + manager._generate_summary(messages, mock_agent) + + # State should still be restored + assert mock_agent.system_prompt == "Original prompt" + + +def test_reduce_context_tool_pair_adjustment_works_with_forward_search(): + """Test that tool pair adjustment works correctly with the forward-search logic.""" + manager = SummarizingConversationManager( + summary_ratio=0.5, + preserve_recent_messages=1, + ) + + mock_agent = create_mock_agent() + # Create messages where the split point would be adjusted to 0 due to tool pairs + tool_pair_messages: Messages = [ + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "123", "name": "test_tool", "input": {}}}]}, + { + "role": "user", + "content": [ + {"toolResult": {"toolUseId": "123", "content": [{"text": "Tool output"}], "status": "success"}} + ], + }, + {"role": "user", "content": [{"text": "Latest message"}]}, + ] + mock_agent.messages = tool_pair_messages + + # With 3 messages, preserve_recent_messages=1, summary_ratio=0.5: + # messages_to_summarize_count = (3 - 1) * 0.5 = 1 + # But split point adjustment will move forward from the toolUse, potentially increasing count + manager.reduce_context(mock_agent) + # Should have summary + remaining messages + assert len(mock_agent.messages) == 2 + + # First message should be the summary + assert mock_agent.messages[0]["role"] == "assistant" + summary_content = mock_agent.messages[0]["content"][0] + assert "text" in summary_content and "This is a summary of the conversation." 
in summary_content["text"] + + # Last message should be the preserved recent message + assert mock_agent.messages[1]["role"] == "user" + assert mock_agent.messages[1]["content"][0]["text"] == "Latest message" + + +def test_adjust_split_point_exceeds_message_length(summarizing_manager): + """Test that split point exceeding message array length raises exception.""" + messages: Messages = [ + {"role": "user", "content": [{"text": "Message 1"}]}, + {"role": "assistant", "content": [{"text": "Response 1"}]}, + ] + + # Try to split at point 5 when there are only 2 messages + with pytest.raises(ContextWindowOverflowException, match="Split point exceeds message array length"): + summarizing_manager._adjust_split_point_for_tool_pairs(messages, 5) + + +def test_adjust_split_point_equals_message_length(summarizing_manager): + """Test that split point equal to message array length returns unchanged.""" + messages: Messages = [ + {"role": "user", "content": [{"text": "Message 1"}]}, + {"role": "assistant", "content": [{"text": "Response 1"}]}, + ] + + # Split point equals message length (2) - should return unchanged + result = summarizing_manager._adjust_split_point_for_tool_pairs(messages, 2) + assert result == 2 + + +def test_adjust_split_point_no_tool_result_at_split(summarizing_manager): + """Test split point that doesn't contain tool result, ensuring we reach return split_point.""" + messages: Messages = [ + {"role": "user", "content": [{"text": "Message 1"}]}, + {"role": "assistant", "content": [{"text": "Response 1"}]}, + {"role": "user", "content": [{"text": "Message 2"}]}, + ] + + # Split point message is not a tool result, so it should directly return split_point + result = summarizing_manager._adjust_split_point_for_tool_pairs(messages, 1) + assert result == 1 + + +def test_adjust_split_point_tool_result_without_tool_use(summarizing_manager): + """Test that having tool results without tool uses raises exception.""" + messages: Messages = [ + {"role": "user", "content": [{"text": "Message 1"}]}, + { + "role": "user", + "content": [ + {"toolResult": {"toolUseId": "123", "content": [{"text": "Tool output"}], "status": "success"}} + ], + }, + ] + + # Has tool result but no tool use - invalid state + with pytest.raises(ContextWindowOverflowException, match="Unable to trim conversation context!"): + summarizing_manager._adjust_split_point_for_tool_pairs(messages, 1) + + +def test_adjust_split_point_tool_result_moves_to_end(summarizing_manager): + """Test tool result at split point moves forward to valid position at end.""" + messages: Messages = [ + {"role": "user", "content": [{"text": "Message 1"}]}, + {"role": "assistant", "content": [{"text": "Response 1"}]}, + { + "role": "user", + "content": [ + {"toolResult": {"toolUseId": "123", "content": [{"text": "Tool output"}], "status": "success"}} + ], + }, + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "456", "name": "different_tool", "input": {}}}]}, + ] + + # Split at message 2 (toolResult) - will move forward to message 3 (toolUse at end is valid) + result = summarizing_manager._adjust_split_point_for_tool_pairs(messages, 2) + assert result == 3 + + +def test_adjust_split_point_tool_result_no_forward_position(summarizing_manager): + """Test tool result at split point where forward search finds no valid position.""" + messages: Messages = [ + {"role": "user", "content": [{"text": "Message 1"}]}, + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "123", "name": "test_tool", "input": {}}}]}, + {"role": "user", "content": 
[{"text": "Message between"}]}, + { + "role": "user", + "content": [ + {"toolResult": {"toolUseId": "123", "content": [{"text": "Tool output"}], "status": "success"}} + ], + }, + ] + + # Split at message 3 (toolResult) - will try to move forward but no valid position exists + with pytest.raises(ContextWindowOverflowException, match="Unable to trim conversation context!"): + summarizing_manager._adjust_split_point_for_tool_pairs(messages, 3) + + +def test_reduce_context_adjustment_returns_zero(): + """Test that tool pair adjustment can return zero, triggering the check at line 122.""" + manager = SummarizingConversationManager( + summary_ratio=0.5, + preserve_recent_messages=1, + ) + + # Mock the adjustment method to return 0 + def mock_adjust(messages, split_point): + return 0 # This should trigger the <= 0 check at line 122 + + manager._adjust_split_point_for_tool_pairs = mock_adjust + + mock_agent = Mock() + simple_messages: Messages = [ + {"role": "user", "content": [{"text": "Message 1"}]}, + {"role": "assistant", "content": [{"text": "Response 1"}]}, + {"role": "user", "content": [{"text": "Message 2"}]}, + ] + mock_agent.messages = simple_messages + + # The adjustment method will return 0, which should trigger line 122-123 + with pytest.raises(ContextWindowOverflowException, match="insufficient messages for summarization"): + manager.reduce_context(mock_agent) From 7c7f91eddc0e7d0ecc99f23b31fbbde047797959 Mon Sep 17 00:00:00 2001 From: poshinchen Date: Wed, 11 Jun 2025 10:14:49 -0400 Subject: [PATCH 02/22] chore: moved truncation logic to conversation manager and added should_truncate_results (#192) --- .../sliding_window_conversation_manager.py | 83 +++++++++++- src/strands/event_loop/error_handler.py | 69 +--------- src/strands/event_loop/event_loop.py | 17 +-- src/strands/event_loop/message_processor.py | 59 +-------- tests/strands/agent/test_agent.py | 40 +++++- .../agent/test_conversation_manager.py | 37 +++++- .../strands/event_loop/test_error_handler.py | 125 +----------------- tests/strands/event_loop/test_event_loop.py | 49 ------- .../event_loop/test_message_processor.py | 81 ------------ 9 files changed, 166 insertions(+), 394 deletions(-) diff --git a/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py b/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py index 3381247c..683cb52f 100644 --- a/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py +++ b/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py @@ -44,14 +44,16 @@ class SlidingWindowConversationManager(ConversationManager): invalid window states. """ - def __init__(self, window_size: int = 40): + def __init__(self, window_size: int = 40, should_truncate_results: bool = True): """Initialize the sliding window conversation manager. Args: window_size: Maximum number of messages to keep in the agent's history. Defaults to 40 messages. + should_truncate_results: Truncate tool results when a message is too large for the model's context window """ self.window_size = window_size + self.should_truncate_results = should_truncate_results def apply_management(self, agent: "Agent") -> None: """Apply the sliding window to the agent's messages array to maintain a manageable history size. @@ -127,6 +129,19 @@ def reduce_context(self, agent: "Agent", e: Optional[Exception] = None) -> None: converted. 
""" messages = agent.messages + + # Try to truncate the tool result first + last_message_idx_with_tool_results = self._find_last_message_with_tool_results(messages) + if last_message_idx_with_tool_results is not None and self.should_truncate_results: + logger.debug( + "message_index=<%s> | found message with tool results at index", last_message_idx_with_tool_results + ) + results_truncated = self._truncate_tool_results(messages, last_message_idx_with_tool_results) + if results_truncated: + logger.debug("message_index=<%s> | tool results truncated", last_message_idx_with_tool_results) + return + + # Try to trim index id when tool result cannot be truncated anymore # If the number of messages is less than the window_size, then we default to 2, otherwise, trim to window size trim_index = 2 if len(messages) <= self.window_size else len(messages) - self.window_size @@ -151,3 +166,69 @@ def reduce_context(self, agent: "Agent", e: Optional[Exception] = None) -> None: # Overwrite message history messages[:] = messages[trim_index:] + + def _truncate_tool_results(self, messages: Messages, msg_idx: int) -> bool: + """Truncate tool results in a message to reduce context size. + + When a message contains tool results that are too large for the model's context window, this function + replaces the content of those tool results with a simple error message. + + Args: + messages: The conversation message history. + msg_idx: Index of the message containing tool results to truncate. + + Returns: + True if any changes were made to the message, False otherwise. + """ + if msg_idx >= len(messages) or msg_idx < 0: + return False + + message = messages[msg_idx] + changes_made = False + tool_result_too_large_message = "The tool result was too large!" + for i, content in enumerate(message.get("content", [])): + if isinstance(content, dict) and "toolResult" in content: + tool_result_content_text = next( + (item["text"] for item in content["toolResult"]["content"] if "text" in item), + "", + ) + # make the overwriting logic togglable + if ( + message["content"][i]["toolResult"]["status"] == "error" + and tool_result_content_text == tool_result_too_large_message + ): + logger.info("ToolResult has already been updated, skipping overwrite") + return False + # Update status to error with informative message + message["content"][i]["toolResult"]["status"] = "error" + message["content"][i]["toolResult"]["content"] = [{"text": tool_result_too_large_message}] + changes_made = True + + return changes_made + + def _find_last_message_with_tool_results(self, messages: Messages) -> Optional[int]: + """Find the index of the last message containing tool results. + + This is useful for identifying messages that might need to be truncated to reduce context size. + + Args: + messages: The conversation message history. + + Returns: + Index of the last message with tool results, or None if no such message exists. 
+ """ + # Iterate backwards through all messages (from newest to oldest) + for idx in range(len(messages) - 1, -1, -1): + # Check if this message has any content with toolResult + current_message = messages[idx] + has_tool_result = False + + for content in current_message.get("content", []): + if isinstance(content, dict) and "toolResult" in content: + has_tool_result = True + break + + if has_tool_result: + return idx + + return None diff --git a/src/strands/event_loop/error_handler.py b/src/strands/event_loop/error_handler.py index a5c85668..6dc0d9ee 100644 --- a/src/strands/event_loop/error_handler.py +++ b/src/strands/event_loop/error_handler.py @@ -6,14 +6,9 @@ import logging import time -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, Tuple -from ..telemetry.metrics import EventLoopMetrics -from ..types.content import Message, Messages -from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException -from ..types.models import Model -from ..types.streaming import StopReason -from .message_processor import find_last_message_with_tool_results, truncate_tool_results +from ..types.exceptions import ModelThrottledException logger = logging.getLogger(__name__) @@ -59,63 +54,3 @@ def handle_throttling_error( callback_handler(force_stop=True, force_stop_reason=str(e)) return False, current_delay - - -def handle_input_too_long_error( - e: ContextWindowOverflowException, - messages: Messages, - model: Model, - system_prompt: Optional[str], - tool_config: Any, - callback_handler: Any, - tool_handler: Any, - kwargs: Dict[str, Any], -) -> Tuple[StopReason, Message, EventLoopMetrics, Any]: - """Handle 'Input is too long' errors by truncating tool results. - - When a context window overflow exception occurs (input too long for the model), this function attempts to recover - by finding and truncating the most recent tool results in the conversation history. If truncation is successful, the - function will make a call to the event loop. - - Args: - e: The ContextWindowOverflowException that occurred. - messages: The conversation message history. - model: Model provider for running inference. - system_prompt: System prompt for the model. - tool_config: Tool configuration for the conversation. - callback_handler: Callback for processing events as they happen. - tool_handler: Handler for tool execution. - kwargs: Additional arguments for the event loop. - - Returns: - The results from the event loop call if successful. - - Raises: - ContextWindowOverflowException: If messages cannot be truncated. 
- """ - from .event_loop import recurse_event_loop # Import here to avoid circular imports - - # Find the last message with tool results - last_message_with_tool_results = find_last_message_with_tool_results(messages) - - # If we found a message with toolResult - if last_message_with_tool_results is not None: - logger.debug("message_index=<%s> | found message with tool results at index", last_message_with_tool_results) - - # Truncate the tool results in this message - truncate_tool_results(messages, last_message_with_tool_results) - - return recurse_event_loop( - model=model, - system_prompt=system_prompt, - messages=messages, - tool_config=tool_config, - callback_handler=callback_handler, - tool_handler=tool_handler, - **kwargs, - ) - - # If we can't handle this error, pass it up - callback_handler(force_stop=True, force_stop_reason=str(e)) - logger.error("an exception occurred in event_loop_cycle | %s", e) - raise ContextWindowOverflowException() from e diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index 23d7bd0f..71165926 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -22,7 +22,7 @@ from ..types.models import Model from ..types.streaming import Metrics, StopReason from ..types.tools import ToolConfig, ToolHandler, ToolResult, ToolUse -from .error_handler import handle_input_too_long_error, handle_throttling_error +from .error_handler import handle_throttling_error from .message_processor import clean_orphaned_empty_tool_uses from .streaming import stream_messages @@ -160,16 +160,7 @@ def event_loop_cycle( except ContextWindowOverflowException as e: if model_invoke_span: tracer.end_span_with_error(model_invoke_span, str(e), e) - return handle_input_too_long_error( - e, - messages, - model, - system_prompt, - tool_config, - callback_handler, - tool_handler, - kwargs, - ) + raise e except ModelThrottledException as e: if model_invoke_span: @@ -248,6 +239,10 @@ def event_loop_cycle( # Don't invoke the callback_handler or log the exception - we already did it when we # raised the exception and we don't need that duplication. raise + except ContextWindowOverflowException as e: + if cycle_span: + tracer.end_span_with_error(cycle_span, str(e), e) + raise e except Exception as e: if cycle_span: tracer.end_span_with_error(cycle_span, str(e), e) diff --git a/src/strands/event_loop/message_processor.py b/src/strands/event_loop/message_processor.py index 61e1c1d7..4e1a39dc 100644 --- a/src/strands/event_loop/message_processor.py +++ b/src/strands/event_loop/message_processor.py @@ -5,7 +5,7 @@ """ import logging -from typing import Dict, Optional, Set, Tuple +from typing import Dict, Set, Tuple from ..types.content import Messages @@ -103,60 +103,3 @@ def clean_orphaned_empty_tool_uses(messages: Messages) -> bool: logger.warning("failed to fix orphaned tool use | %s", e) return True - - -def find_last_message_with_tool_results(messages: Messages) -> Optional[int]: - """Find the index of the last message containing tool results. - - This is useful for identifying messages that might need to be truncated to reduce context size. - - Args: - messages: The conversation message history. - - Returns: - Index of the last message with tool results, or None if no such message exists. 
- """ - # Iterate backwards through all messages (from newest to oldest) - for idx in range(len(messages) - 1, -1, -1): - # Check if this message has any content with toolResult - current_message = messages[idx] - has_tool_result = False - - for content in current_message.get("content", []): - if isinstance(content, dict) and "toolResult" in content: - has_tool_result = True - break - - if has_tool_result: - return idx - - return None - - -def truncate_tool_results(messages: Messages, msg_idx: int) -> bool: - """Truncate tool results in a message to reduce context size. - - When a message contains tool results that are too large for the model's context window, this function replaces the - content of those tool results with a simple error message. - - Args: - messages: The conversation message history. - msg_idx: Index of the message containing tool results to truncate. - - Returns: - True if any changes were made to the message, False otherwise. - """ - if msg_idx >= len(messages) or msg_idx < 0: - return False - - message = messages[msg_idx] - changes_made = False - - for i, content in enumerate(message.get("content", [])): - if isinstance(content, dict) and "toolResult" in content: - # Update status to error with informative message - message["content"][i]["toolResult"]["status"] = "error" - message["content"][i]["toolResult"]["content"] = [{"text": "The tool result was too large!"}] - changes_made = True - - return changes_made diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 02b1470b..60b38ffa 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -438,7 +438,7 @@ def test_agent__call__retry_with_reduced_context(mock_model, agent, tool): def test_agent__call__always_sliding_window_conversation_manager_doesnt_infinite_loop(mock_model, agent, tool): - conversation_manager = SlidingWindowConversationManager(window_size=500) + conversation_manager = SlidingWindowConversationManager(window_size=500, should_truncate_results=False) conversation_manager_spy = unittest.mock.Mock(wraps=conversation_manager) agent.conversation_manager = conversation_manager_spy @@ -484,10 +484,43 @@ def test_agent__call__null_conversation_window_manager__doesnt_infinite_loop(moc agent("Test!") +def test_agent__call__tool_truncation_doesnt_infinite_loop(mock_model, agent): + messages: Messages = [ + {"role": "user", "content": [{"text": "Hello!"}]}, + { + "role": "assistant", + "content": [{"toolUse": {"toolUseId": "123", "input": {"hello": "world"}, "name": "test"}}], + }, + { + "role": "user", + "content": [ + {"toolResult": {"toolUseId": "123", "content": [{"text": "Some large input!"}], "status": "success"}} + ], + }, + ] + agent.messages = messages + + mock_model.mock_converse.side_effect = ContextWindowOverflowException( + RuntimeError("Input is too long for requested model") + ) + + with pytest.raises(ContextWindowOverflowException): + agent("Test!") + + def test_agent__call__retry_with_overwritten_tool(mock_model, agent, tool): conversation_manager_spy = unittest.mock.Mock(wraps=agent.conversation_manager) agent.conversation_manager = conversation_manager_spy + messages: Messages = [ + {"role": "user", "content": [{"text": "Hello!"}]}, + { + "role": "assistant", + "content": [{"text": "Hi!"}], + }, + ] + agent.messages = messages + mock_model.mock_converse.side_effect = [ [ { @@ -504,6 +537,9 @@ def test_agent__call__retry_with_overwritten_tool(mock_model, agent, tool): {"contentBlockStop": {}}, {"messageStop": {"stopReason": 
"tool_use"}}, ], + # Will truncate the tool result + ContextWindowOverflowException(RuntimeError("Input is too long for requested model")), + # Will reduce the context ContextWindowOverflowException(RuntimeError("Input is too long for requested model")), [], ] @@ -538,7 +574,7 @@ def test_agent__call__retry_with_overwritten_tool(mock_model, agent, tool): unittest.mock.ANY, ) - conversation_manager_spy.reduce_context.assert_not_called() + assert conversation_manager_spy.reduce_context.call_count == 2 assert conversation_manager_spy.apply_management.call_count == 1 diff --git a/tests/strands/agent/test_conversation_manager.py b/tests/strands/agent/test_conversation_manager.py index bbec3cd1..7d43199e 100644 --- a/tests/strands/agent/test_conversation_manager.py +++ b/tests/strands/agent/test_conversation_manager.py @@ -25,6 +25,7 @@ def test_is_assistant_message(role, exp_result): def conversation_manager(request): params = { "window_size": 2, + "should_truncate_results": False, } if hasattr(request, "param"): params.update(request.param) @@ -168,7 +169,7 @@ def test_apply_management(conversation_manager, messages, expected_messages): def test_sliding_window_conversation_manager_with_untrimmable_history_raises_context_window_overflow_exception(): - manager = strands.agent.conversation_manager.SlidingWindowConversationManager(1) + manager = strands.agent.conversation_manager.SlidingWindowConversationManager(1, False) messages = [ {"role": "assistant", "content": [{"toolUse": {"toolUseId": "456", "name": "tool1", "input": {}}}]}, {"role": "user", "content": [{"toolResult": {"toolUseId": "789", "content": [], "status": "success"}}]}, @@ -182,6 +183,40 @@ def test_sliding_window_conversation_manager_with_untrimmable_history_raises_con assert messages == original_messages +def test_sliding_window_conversation_manager_with_tool_results_truncated(): + manager = strands.agent.conversation_manager.SlidingWindowConversationManager(1) + messages = [ + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "456", "name": "tool1", "input": {}}}]}, + { + "role": "user", + "content": [ + {"toolResult": {"toolUseId": "789", "content": [{"text": "large input"}], "status": "success"}} + ], + }, + ] + test_agent = Agent(messages=messages) + + manager.reduce_context(test_agent) + + expected_messages = [ + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "456", "name": "tool1", "input": {}}}]}, + { + "role": "user", + "content": [ + { + "toolResult": { + "toolUseId": "789", + "content": [{"text": "The tool result was too large!"}], + "status": "error", + } + } + ], + }, + ] + + assert messages == expected_messages + + def test_null_conversation_manager_reduce_context_raises_context_window_overflow_exception(): """Test that NullConversationManager doesn't modify messages.""" manager = strands.agent.conversation_manager.NullConversationManager() diff --git a/tests/strands/event_loop/test_error_handler.py b/tests/strands/event_loop/test_error_handler.py index 7249adef..fe1b3e9f 100644 --- a/tests/strands/event_loop/test_error_handler.py +++ b/tests/strands/event_loop/test_error_handler.py @@ -4,7 +4,7 @@ import pytest import strands -from strands.types.exceptions import ContextWindowOverflowException, ModelThrottledException +from strands.types.exceptions import ModelThrottledException @pytest.fixture @@ -95,126 +95,3 @@ def test_handle_throttling_error_does_not_exist(callback_handler, kwargs): assert tru_retry == exp_retry and tru_delay == exp_delay 
callback_handler.assert_called_with(force_stop=True, force_stop_reason=str(exception)) - - -@pytest.mark.parametrize("event_stream_error", ["Input is too long for requested model"], indirect=True) -def test_handle_input_too_long_error( - sdk_event_loop, - event_stream_error, - model, - system_prompt, - tool_config, - callback_handler, - tool_handler, - kwargs, -): - sdk_event_loop.return_value = "success" - - messages = [ - { - "role": "user", - "content": [ - {"toolResult": {"toolUseId": "t1", "status": "success", "content": [{"text": "needs truncation"}]}} - ], - } - ] - - tru_result = strands.event_loop.error_handler.handle_input_too_long_error( - event_stream_error, - messages, - model, - system_prompt, - tool_config, - callback_handler, - tool_handler, - kwargs, - ) - exp_result = "success" - - tru_messages = messages - exp_messages = [ - { - "role": "user", - "content": [ - { - "toolResult": { - "toolUseId": "t1", - "status": "error", - "content": [{"text": "The tool result was too large!"}], - }, - }, - ], - }, - ] - - assert tru_result == exp_result and tru_messages == exp_messages - - sdk_event_loop.assert_called_once_with( - model=model, - system_prompt=system_prompt, - messages=messages, - tool_config=tool_config, - callback_handler=callback_handler, - tool_handler=tool_handler, - request_state="value", - ) - - callback_handler.assert_not_called() - - -@pytest.mark.parametrize("event_stream_error", ["Other error"], indirect=True) -def test_handle_input_too_long_error_does_not_exist( - sdk_event_loop, - event_stream_error, - model, - system_prompt, - tool_config, - callback_handler, - tool_handler, - kwargs, -): - messages = [] - - with pytest.raises(ContextWindowOverflowException): - strands.event_loop.error_handler.handle_input_too_long_error( - event_stream_error, - messages, - model, - system_prompt, - tool_config, - callback_handler, - tool_handler, - kwargs, - ) - - sdk_event_loop.assert_not_called() - callback_handler.assert_called_with(force_stop=True, force_stop_reason=str(event_stream_error)) - - -@pytest.mark.parametrize("event_stream_error", ["Input is too long for requested model"], indirect=True) -def test_handle_input_too_long_error_no_tool_result( - sdk_event_loop, - event_stream_error, - model, - system_prompt, - tool_config, - callback_handler, - tool_handler, - kwargs, -): - messages = [] - - with pytest.raises(ContextWindowOverflowException): - strands.event_loop.error_handler.handle_input_too_long_error( - event_stream_error, - messages, - model, - system_prompt, - tool_config, - callback_handler, - tool_handler, - kwargs, - ) - - sdk_event_loop.assert_not_called() - callback_handler.assert_called_with(force_stop=True, force_stop_reason=str(event_stream_error)) diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py index 569462f1..8c46e009 100644 --- a/tests/strands/event_loop/test_event_loop.py +++ b/tests/strands/event_loop/test_event_loop.py @@ -157,55 +157,6 @@ def test_event_loop_cycle_text_response( assert tru_stop_reason == exp_stop_reason and tru_message == exp_message and tru_request_state == exp_request_state -def test_event_loop_cycle_text_response_input_too_long( - model, - model_id, - system_prompt, - messages, - tool_config, - callback_handler, - tool_handler, - tool_execution_handler, -): - model.converse.side_effect = [ - ContextWindowOverflowException(RuntimeError("Input is too long for requested model")), - [ - {"contentBlockDelta": {"delta": {"text": "test text"}}}, - {"contentBlockStop": {}}, 
- ], - ] - messages.append( - { - "role": "user", - "content": [ - { - "toolResult": { - "toolUseId": "t1", - "status": "success", - "content": [{"text": "2025-04-01T00:00:00"}], - }, - }, - ], - } - ) - - tru_stop_reason, tru_message, _, tru_request_state = strands.event_loop.event_loop.event_loop_cycle( - model=model, - model_id=model_id, - system_prompt=system_prompt, - messages=messages, - tool_config=tool_config, - callback_handler=callback_handler, - tool_handler=tool_handler, - tool_execution_handler=tool_execution_handler, - ) - exp_stop_reason = "end_turn" - exp_message = {"role": "assistant", "content": [{"text": "test text"}]} - exp_request_state = {} - - assert tru_stop_reason == exp_stop_reason and tru_message == exp_message and tru_request_state == exp_request_state - - @unittest.mock.patch.object(strands.event_loop.error_handler, "time") def test_event_loop_cycle_text_response_throttling( model, diff --git a/tests/strands/event_loop/test_message_processor.py b/tests/strands/event_loop/test_message_processor.py index 395c71a1..fcf531df 100644 --- a/tests/strands/event_loop/test_message_processor.py +++ b/tests/strands/event_loop/test_message_processor.py @@ -45,84 +45,3 @@ def test_clean_orphaned_empty_tool_uses(messages, expected, expected_messages): result = message_processor.clean_orphaned_empty_tool_uses(test_messages) assert result == expected assert test_messages == expected_messages - - -@pytest.mark.parametrize( - "messages,expected_idx", - [ - ( - [ - {"role": "user", "content": [{"text": "hi"}]}, - {"role": "user", "content": [{"toolResult": {"toolUseId": "1"}}]}, - {"role": "assistant", "content": [{"text": "ok"}]}, - ], - 1, - ), - ( - [ - {"role": "user", "content": [{"text": "hi"}]}, - {"role": "assistant", "content": [{"text": "ok"}]}, - ], - None, - ), - ( - [], - None, - ), - ], -) -def test_find_last_message_with_tool_results(messages, expected_idx): - idx = message_processor.find_last_message_with_tool_results(messages) - assert idx == expected_idx - - -@pytest.mark.parametrize( - "messages,msg_idx,expected_changed,expected_content", - [ - ( - [ - { - "role": "user", - "content": [{"toolResult": {"toolUseId": "1", "status": "ok", "content": [{"text": "big"}]}}], - } - ], - 0, - True, - [ - { - "toolResult": { - "toolUseId": "1", - "status": "error", - "content": [{"text": "The tool result was too large!"}], - } - } - ], - ), - ( - [{"role": "user", "content": [{"text": "no tool result"}]}], - 0, - False, - [{"text": "no tool result"}], - ), - ( - [], - 0, - False, - [], - ), - ( - [{"role": "user", "content": [{"toolResult": {"toolUseId": "1"}}]}], - 2, - False, - [{"toolResult": {"toolUseId": "1"}}], - ), - ], -) -def test_truncate_tool_results(messages, msg_idx, expected_changed, expected_content): - test_messages = copy.deepcopy(messages) - changed = message_processor.truncate_tool_results(test_messages, msg_idx) - assert changed == expected_changed - if 0 <= msg_idx < len(test_messages): - assert test_messages[msg_idx]["content"] == expected_content - else: - assert test_messages == messages From 264f5115d08dfc42bb635953be8ff196767c44b8 Mon Sep 17 00:00:00 2001 From: Mackenzie Zastrow <3211021+zastrowm@users.noreply.github.com> Date: Wed, 11 Jun 2025 16:54:45 -0400 Subject: [PATCH 03/22] refactor: Disallow similar tool names in the tool registry (#193) Per follow-up to #178, where we discussed preventing similar_tool and similar-tool from both being added to the tool registry, to avoid ambiguity in direct-method invocations --- 
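A quick illustration of the ambiguity this change guards against. This is a minimal, self-contained sketch, not the SDK's ToolRegistry: the SimpleRegistry class and its register method are hypothetical stand-ins, though the error message mirrors the one added in this patch. Direct-method invocation normalizes '-' to '_', so two tools whose names differ only by that character would resolve to the same attribute.

# Hypothetical stand-in showing why names that differ only by '-' vs '_' are rejected.
class SimpleRegistry:
    def __init__(self):
        self._tools = {}

    def register(self, name):
        normalized = name.replace("-", "_")
        # Reject a new tool whose normalized name collides with an existing one.
        for existing in self._tools:
            if existing.replace("-", "_") == normalized:
                raise ValueError(
                    f"Tool name '{name}' already exists as '{existing}'."
                    " Cannot add a duplicate tool which differs by a '-' or '_'"
                )
        self._tools[name] = object()

registry = SimpleRegistry()
registry.register("similar-tool")
try:
    registry.register("similar_tool")  # differs only by '-' vs '_'
except ValueError as err:
    print(err)
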
src/strands/agent/agent.py | 8 ++------ src/strands/telemetry/tracer.py | 4 +--- src/strands/tools/registry.py | 15 +++++++++++++++ tests/strands/agent/test_agent.py | 22 ---------------------- tests/strands/tools/test_registry.py | 20 ++++++++++++++++++++ 5 files changed, 38 insertions(+), 31 deletions(-) diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 5854fba6..56f5b92e 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -108,14 +108,10 @@ def find_normalized_tool_name() -> Optional[str]: # all tools that can be represented with the normalized name if "_" in name: filtered_tools = [ - tool_name - for (tool_name, tool) in tool_registry.items() - if tool_name.replace("-", "_") == name + tool_name for (tool_name, tool) in tool_registry.items() if tool_name.replace("-", "_") == name ] - if len(filtered_tools) > 1: - raise AttributeError(f"Multiple tools matching '{name}' found: {', '.join(filtered_tools)}") - + # The registry itself defends against similar names, so we can just take the first match if filtered_tools: return filtered_tools[0] diff --git a/src/strands/telemetry/tracer.py b/src/strands/telemetry/tracer.py index 9f731996..34eb7bed 100644 --- a/src/strands/telemetry/tracer.py +++ b/src/strands/telemetry/tracer.py @@ -13,9 +13,7 @@ from opentelemetry import trace from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter - -# See https://github.com/open-telemetry/opentelemetry-python/issues/4615 for the type ignore -from opentelemetry.sdk.resources import Resource # type: ignore[attr-defined] +from opentelemetry.sdk.resources import Resource from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace.export import BatchSpanProcessor, ConsoleSpanExporter, SimpleSpanProcessor from opentelemetry.trace import StatusCode diff --git a/src/strands/tools/registry.py b/src/strands/tools/registry.py index 2efdd600..e56ee999 100644 --- a/src/strands/tools/registry.py +++ b/src/strands/tools/registry.py @@ -189,6 +189,21 @@ def register_tool(self, tool: AgentTool) -> None: tool.is_dynamic, ) + if self.registry.get(tool.tool_name) is None: + normalized_name = tool.tool_name.replace("-", "_") + + matching_tools = [ + tool_name + for (tool_name, tool) in self.registry.items() + if tool_name.replace("-", "_") == normalized_name + ] + + if matching_tools: + raise ValueError( + f"Tool name '{tool.tool_name}' already exists as '{matching_tools[0]}'." 
+ " Cannot add a duplicate tool which differs by a '-' or '_'" + ) + # Register in main registry self.registry[tool.tool_name] = tool diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 60b38ffa..d6f47be0 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -739,28 +739,6 @@ def function(system_prompt: str) -> str: } -def test_agent_tool_with_multiple_normalized_matches(agent, tool_registry, mock_randint): - agent.tool_handler = unittest.mock.Mock() - - @strands.tools.tool(name="system-prompter_1") - def function1(system_prompt: str) -> str: - return system_prompt - - @strands.tools.tool(name="system-prompter-1") - def function2(system_prompt: str) -> str: - return system_prompt - - agent.tool_registry.register_tool(strands.tools.tools.FunctionTool(function1)) - agent.tool_registry.register_tool(strands.tools.tools.FunctionTool(function2)) - - mock_randint.return_value = 1 - - with pytest.raises(AttributeError) as err: - agent.tool.system_prompter_1(system_prompt="tool prompt") - - assert str(err.value) == "Multiple tools matching 'system_prompter_1' found: system-prompter_1, system-prompter-1" - - def test_agent_tool_with_no_normalized_match(agent, tool_registry, mock_randint): agent.tool_handler = unittest.mock.Mock() diff --git a/tests/strands/tools/test_registry.py b/tests/strands/tools/test_registry.py index 3dca5371..1b274f46 100644 --- a/tests/strands/tools/test_registry.py +++ b/tests/strands/tools/test_registry.py @@ -2,8 +2,11 @@ Tests for the SDK tool registry module. """ +from unittest.mock import MagicMock + import pytest +from strands.tools import PythonAgentTool from strands.tools.registry import ToolRegistry @@ -23,3 +26,20 @@ def test_process_tools_with_invalid_path(): with pytest.raises(ValueError, match=f"Failed to load tool {invalid_path.split('.')[0]}: Tool file not found:.*"): tool_registry.process_tools([invalid_path]) + + +def test_register_tool_with_similar_name_raises(): + tool_1 = PythonAgentTool(tool_name="tool-like-this", tool_spec=MagicMock(), callback=lambda: None) + tool_2 = PythonAgentTool(tool_name="tool_like_this", tool_spec=MagicMock(), callback=lambda: None) + + tool_registry = ToolRegistry() + + tool_registry.register_tool(tool_1) + + with pytest.raises(ValueError) as err: + tool_registry.register_tool(tool_2) + + assert ( + str(err.value) == "Tool name 'tool_like_this' already exists as 'tool-like-this'. 
" + "Cannot add a duplicate tool which differs by a '-' or '_'" + ) From 4b44410e276830e23df4ce9dd8f55921dddb4af0 Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Thu, 12 Jun 2025 00:10:55 +0300 Subject: [PATCH 04/22] ci: add integration test workflow (#201) --- .github/workflows/integration-test.yml | 62 ++++++++++++++++++++++++++ tests-integ/test_bedrock_guardrails.py | 4 +- tests-integ/test_mcp_client.py | 6 +++ tests-integ/test_model_litellm.py | 2 +- 4 files changed, 71 insertions(+), 3 deletions(-) create mode 100644 .github/workflows/integration-test.yml diff --git a/.github/workflows/integration-test.yml b/.github/workflows/integration-test.yml new file mode 100644 index 00000000..389924c5 --- /dev/null +++ b/.github/workflows/integration-test.yml @@ -0,0 +1,62 @@ +name: Secure Integration test + +on: + pull_request_target: + types: [opened, synchronize, labeled, unlabled, reopened] + +jobs: + check-access-and-checkout: + runs-on: ubuntu-latest + permissions: + id-token: write + pull-requests: read + contents: read + steps: + - name: Check PR labels and author + id: check + uses: actions/github-script@v7 + with: + script: | + const pr = context.payload.pull_request; + + const labels = pr.labels.map(label => label.name); + const hasLabel = labels.includes('approved-for-integ-test') + if (hasLabel) { + core.info('PR contains label approved-for-integ-test') + return + } + + const isOwner = pr.author_association === 'OWNER' + if (isOwner) { + core.info('PR author is an OWNER') + return + } + + core.setFailed('Pull Request must either have label approved-for-integ-test or be created by an owner') + - name: Configure Credentials + uses: aws-actions/configure-aws-credentials@v4 + with: + role-to-assume: ${{ secrets.STRANDS_INTEG_TEST_ROLE }} + aws-region: us-east-1 + mask-aws-account-id: true + - name: Checkout base branch + uses: actions/checkout@v4 + with: + ref: ${{ github.event.pull_request.head.ref }} # Pull the commit from the forked repo + persist-credentials: false # Don't persist credentials for subsequent actions + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.10' + - name: Install dependencies + run: | + pip install --no-cache-dir hatch + - name: Run integration tests + env: + AWS_REGION: us-east-1 + AWS_REGION_NAME: us-east-1 # Needed for LiteLLM + id: tests + run: | + hatch test tests-integ + + diff --git a/tests-integ/test_bedrock_guardrails.py b/tests-integ/test_bedrock_guardrails.py index 9ffd1bdf..bf0be706 100644 --- a/tests-integ/test_bedrock_guardrails.py +++ b/tests-integ/test_bedrock_guardrails.py @@ -12,7 +12,7 @@ @pytest.fixture(scope="module") def boto_session(): - return boto3.Session(region_name="us-west-2") + return boto3.Session(region_name="us-east-1") @pytest.fixture(scope="module") @@ -142,7 +142,7 @@ def test_guardrail_output_intervention_redact_output(bedrock_guardrail, processi guardrail_stream_processing_mode=processing_mode, guardrail_redact_output=True, guardrail_redact_output_message=REDACT_MESSAGE, - region_name="us-west-2", + region_name="us-east-1", ) agent = Agent( diff --git a/tests-integ/test_mcp_client.py b/tests-integ/test_mcp_client.py index 59ae2a14..f0669284 100644 --- a/tests-integ/test_mcp_client.py +++ b/tests-integ/test_mcp_client.py @@ -1,8 +1,10 @@ import base64 +import os import threading import time from typing import List, Literal +import pytest from mcp import StdioServerParameters, stdio_client from mcp.client.sse import sse_client from mcp.client.streamable_http import 
streamablehttp_client @@ -101,6 +103,10 @@ def test_can_reuse_mcp_client(): assert any([block["name"] == "echo" for block in tool_use_content_blocks]) +@pytest.mark.skipif( + condition=os.environ.get("GITHUB_ACTIONS") == 'true', + reason="streamable transport is failing in GitHub actions, debugging if linux compatibility issue" +) def test_streamable_http_mcp_client(): server_thread = threading.Thread( target=start_calculator_server, kwargs={"transport": "streamable-http", "port": 8001}, daemon=True diff --git a/tests-integ/test_model_litellm.py b/tests-integ/test_model_litellm.py index f1afb61f..d6a83b50 100644 --- a/tests-integ/test_model_litellm.py +++ b/tests-integ/test_model_litellm.py @@ -7,7 +7,7 @@ @pytest.fixture def model(): - return LiteLLMModel(model_id="us.anthropic.claude-3-7-sonnet-20250219-v1:0") + return LiteLLMModel(model_id="bedrock/anthropic.claude-3-7-sonnet-20250219-v1:0") @pytest.fixture From 7c5f7a74dbae351041eaec28d57a55d590999bc7 Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Thu, 12 Jun 2025 20:56:47 +0300 Subject: [PATCH 05/22] fix: add inference profile to litellm test and remove ownership check in workflow (#209) --- .github/workflows/integration-test.yml | 8 +------- tests-integ/test_model_litellm.py | 2 +- 2 files changed, 2 insertions(+), 8 deletions(-) diff --git a/.github/workflows/integration-test.yml b/.github/workflows/integration-test.yml index 389924c5..294a2f3e 100644 --- a/.github/workflows/integration-test.yml +++ b/.github/workflows/integration-test.yml @@ -26,13 +26,7 @@ jobs: return } - const isOwner = pr.author_association === 'OWNER' - if (isOwner) { - core.info('PR author is an OWNER') - return - } - - core.setFailed('Pull Request must either have label approved-for-integ-test or be created by an owner') + core.setFailed('Pull Request must either have label approved-for-integ-test') - name: Configure Credentials uses: aws-actions/configure-aws-credentials@v4 with: diff --git a/tests-integ/test_model_litellm.py b/tests-integ/test_model_litellm.py index d6a83b50..86f6b42f 100644 --- a/tests-integ/test_model_litellm.py +++ b/tests-integ/test_model_litellm.py @@ -7,7 +7,7 @@ @pytest.fixture def model(): - return LiteLLMModel(model_id="bedrock/anthropic.claude-3-7-sonnet-20250219-v1:0") + return LiteLLMModel(model_id="bedrock/us.anthropic.claude-3-7-sonnet-20250219-v1:0") @pytest.fixture From 68740c5439b852e9354a0eef37f0110c08f11ddb Mon Sep 17 00:00:00 2001 From: poshinchen Date: Fri, 13 Jun 2025 09:30:09 -0400 Subject: [PATCH 06/22] chore: allow custom tracer provider in Agent (#207) --- .../sliding_window_conversation_manager.py | 2 +- src/strands/telemetry/tracer.py | 81 ++++++++++++------- tests/strands/telemetry/test_tracer.py | 74 +++++++++++++---- 3 files changed, 113 insertions(+), 44 deletions(-) diff --git a/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py b/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py index 683cb52f..53ac374f 100644 --- a/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py +++ b/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py @@ -140,7 +140,7 @@ def reduce_context(self, agent: "Agent", e: Optional[Exception] = None) -> None: if results_truncated: logger.debug("message_index=<%s> | tool results truncated", last_message_idx_with_tool_results) return - + # Try to trim index id when tool result cannot be truncated anymore # If the number of messages is less than the window_size, then we default to 2, 
otherwise, trim to window size trim_index = 2 if len(messages) <= self.window_size else len(messages) - self.window_size diff --git a/src/strands/telemetry/tracer.py b/src/strands/telemetry/tracer.py index 34eb7bed..3353237d 100644 --- a/src/strands/telemetry/tracer.py +++ b/src/strands/telemetry/tracer.py @@ -11,12 +11,16 @@ from importlib.metadata import version from typing import Any, Dict, Mapping, Optional -from opentelemetry import trace +import opentelemetry.trace as trace_api +from opentelemetry import propagate +from opentelemetry.baggage.propagation import W3CBaggagePropagator from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter +from opentelemetry.propagators.composite import CompositePropagator from opentelemetry.sdk.resources import Resource -from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace import TracerProvider as SDKTracerProvider from opentelemetry.sdk.trace.export import BatchSpanProcessor, ConsoleSpanExporter, SimpleSpanProcessor -from opentelemetry.trace import StatusCode +from opentelemetry.trace import Span, StatusCode +from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator from ..agent.agent_result import AgentResult from ..types.content import Message, Messages @@ -133,16 +137,30 @@ def __init__( self.service_name = service_name self.otlp_headers = otlp_headers or {} - self.tracer_provider: Optional[TracerProvider] = None - self.tracer: Optional[trace.Tracer] = None - + self.tracer_provider: Optional[trace_api.TracerProvider] = None + self.tracer: Optional[trace_api.Tracer] = None + + propagate.set_global_textmap( + CompositePropagator( + [ + W3CBaggagePropagator(), + TraceContextTextMapPropagator(), + ] + ) + ) if self.otlp_endpoint or self.enable_console_export: + # Create our own tracer provider self._initialize_tracer() def _initialize_tracer(self) -> None: """Initialize the OpenTelemetry tracer.""" logger.info("initializing tracer") + if self._is_initialized(): + self.tracer_provider = trace_api.get_tracer_provider() + self.tracer = self.tracer_provider.get_tracer(self.service_name) + return + # Create resource with service information resource = Resource.create( { @@ -154,7 +172,7 @@ def _initialize_tracer(self) -> None: ) # Create tracer provider - self.tracer_provider = TracerProvider(resource=resource) + self.tracer_provider = SDKTracerProvider(resource=resource) # Add console exporter if enabled if self.enable_console_export and self.tracer_provider: @@ -190,15 +208,19 @@ def _initialize_tracer(self) -> None: logger.exception("error=<%s> | Failed to configure OTLP exporter", e) # Set as global tracer provider - trace.set_tracer_provider(self.tracer_provider) - self.tracer = trace.get_tracer(self.service_name) + trace_api.set_tracer_provider(self.tracer_provider) + self.tracer = trace_api.get_tracer(self.service_name) + + def _is_initialized(self) -> bool: + tracer_provider = trace_api.get_tracer_provider() + return isinstance(tracer_provider, SDKTracerProvider) def _start_span( self, span_name: str, - parent_span: Optional[trace.Span] = None, + parent_span: Optional[Span] = None, attributes: Optional[Dict[str, AttributeValue]] = None, - ) -> Optional[trace.Span]: + ) -> Optional[Span]: """Generic helper method to start a span with common attributes. 
Args: @@ -212,7 +234,7 @@ def _start_span( if self.tracer is None: return None - context = trace.set_span_in_context(parent_span) if parent_span else None + context = trace_api.set_span_in_context(parent_span) if parent_span else None span = self.tracer.start_span(name=span_name, context=context) # Set start time as a common attribute @@ -224,7 +246,7 @@ def _start_span( return span - def _set_attributes(self, span: trace.Span, attributes: Dict[str, AttributeValue]) -> None: + def _set_attributes(self, span: Span, attributes: Dict[str, AttributeValue]) -> None: """Set attributes on a span, handling different value types appropriately. Args: @@ -239,7 +261,7 @@ def _set_attributes(self, span: trace.Span, attributes: Dict[str, AttributeValue def _end_span( self, - span: trace.Span, + span: Span, attributes: Optional[Dict[str, AttributeValue]] = None, error: Optional[Exception] = None, ) -> None: @@ -272,13 +294,13 @@ def _end_span( finally: span.end() # Force flush to ensure spans are exported - if self.tracer_provider: + if self.tracer_provider and hasattr(self.tracer_provider, 'force_flush'): try: self.tracer_provider.force_flush() except Exception as e: logger.warning("error=<%s> | failed to force flush tracer provider", e) - def end_span_with_error(self, span: trace.Span, error_message: str, exception: Optional[Exception] = None) -> None: + def end_span_with_error(self, span: Span, error_message: str, exception: Optional[Exception] = None) -> None: """End a span with error status. Args: @@ -294,12 +316,12 @@ def end_span_with_error(self, span: trace.Span, error_message: str, exception: O def start_model_invoke_span( self, - parent_span: Optional[trace.Span] = None, + parent_span: Optional[Span] = None, agent_name: str = "Strands Agent", messages: Optional[Messages] = None, model_id: Optional[str] = None, **kwargs: Any, - ) -> Optional[trace.Span]: + ) -> Optional[Span]: """Start a new span for a model invocation. Args: @@ -328,7 +350,7 @@ def start_model_invoke_span( return self._start_span("Model invoke", parent_span, attributes) def end_model_invoke_span( - self, span: trace.Span, message: Message, usage: Usage, error: Optional[Exception] = None + self, span: Span, message: Message, usage: Usage, error: Optional[Exception] = None ) -> None: """End a model invocation span with results and metrics. @@ -347,9 +369,7 @@ def end_model_invoke_span( self._end_span(span, attributes, error) - def start_tool_call_span( - self, tool: ToolUse, parent_span: Optional[trace.Span] = None, **kwargs: Any - ) -> Optional[trace.Span]: + def start_tool_call_span(self, tool: ToolUse, parent_span: Optional[Span] = None, **kwargs: Any) -> Optional[Span]: """Start a new span for a tool call. Args: @@ -374,7 +394,7 @@ def start_tool_call_span( return self._start_span(span_name, parent_span, attributes) def end_tool_call_span( - self, span: trace.Span, tool_result: Optional[ToolResult], error: Optional[Exception] = None + self, span: Span, tool_result: Optional[ToolResult], error: Optional[Exception] = None ) -> None: """End a tool call span with results. @@ -402,10 +422,10 @@ def end_tool_call_span( def start_event_loop_cycle_span( self, event_loop_kwargs: Any, - parent_span: Optional[trace.Span] = None, + parent_span: Optional[Span] = None, messages: Optional[Messages] = None, **kwargs: Any, - ) -> Optional[trace.Span]: + ) -> Optional[Span]: """Start a new span for an event loop cycle. 
Args: @@ -436,7 +456,7 @@ def start_event_loop_cycle_span( def end_event_loop_cycle_span( self, - span: trace.Span, + span: Span, message: Message, tool_result_message: Optional[Message] = None, error: Optional[Exception] = None, @@ -466,7 +486,7 @@ def start_agent_span( tools: Optional[list] = None, custom_trace_attributes: Optional[Mapping[str, AttributeValue]] = None, **kwargs: Any, - ) -> Optional[trace.Span]: + ) -> Optional[Span]: """Start a new span for an agent invocation. Args: @@ -506,7 +526,7 @@ def start_agent_span( def end_agent_span( self, - span: trace.Span, + span: Span, response: Optional[AgentResult] = None, error: Optional[Exception] = None, ) -> None: @@ -557,13 +577,16 @@ def get_tracer( otlp_endpoint: OTLP endpoint URL for sending traces. otlp_headers: Headers to include with OTLP requests. enable_console_export: Whether to also export traces to console. + tracer_provider: Optional existing TracerProvider to use instead of creating a new one. Returns: The global tracer instance. """ global _tracer_instance - if _tracer_instance is None or (otlp_endpoint and _tracer_instance.otlp_endpoint != otlp_endpoint): # type: ignore[unreachable] + if ( + _tracer_instance is None or (otlp_endpoint and _tracer_instance.otlp_endpoint != otlp_endpoint) # type: ignore[unreachable] + ): _tracer_instance = Tracer( service_name=service_name, otlp_endpoint=otlp_endpoint, diff --git a/tests/strands/telemetry/test_tracer.py b/tests/strands/telemetry/test_tracer.py index 32a4ac0a..ac4ea257 100644 --- a/tests/strands/telemetry/test_tracer.py +++ b/tests/strands/telemetry/test_tracer.py @@ -4,7 +4,9 @@ from unittest import mock import pytest -from opentelemetry.trace import StatusCode # type: ignore +from opentelemetry.trace import ( + StatusCode, # type: ignore +) from strands.telemetry.tracer import JSONEncoder, Tracer, get_tracer, serialize from strands.types.streaming import Usage @@ -18,13 +20,25 @@ def moto_autouse(moto_env, moto_mock_aws): @pytest.fixture def mock_tracer_provider(): - with mock.patch("strands.telemetry.tracer.TracerProvider") as mock_provider: + with mock.patch("strands.telemetry.tracer.SDKTracerProvider") as mock_provider: yield mock_provider +@pytest.fixture +def mock_is_initialized(): + with mock.patch("strands.telemetry.tracer.Tracer._is_initialized") as mock_is_initialized: + yield mock_is_initialized + + +@pytest.fixture +def mock_get_tracer_provider(): + with mock.patch("strands.telemetry.tracer.trace_api.get_tracer_provider") as mock_get_tracer_provider: + yield mock_get_tracer_provider + + @pytest.fixture def mock_tracer(): - with mock.patch("strands.telemetry.tracer.trace.get_tracer") as mock_get_tracer: + with mock.patch("strands.telemetry.tracer.trace_api.get_tracer") as mock_get_tracer: mock_tracer = mock.MagicMock() mock_get_tracer.return_value = mock_tracer yield mock_tracer @@ -38,7 +52,7 @@ def mock_span(): @pytest.fixture def mock_set_tracer_provider(): - with mock.patch("strands.telemetry.tracer.trace.set_tracer_provider") as mock_set: + with mock.patch("strands.telemetry.tracer.trace_api.set_tracer_provider") as mock_set: yield mock_set @@ -104,8 +118,17 @@ def env_with_both(): yield -def test_init_default(): +@pytest.fixture +def mock_initialize(): + with mock.patch("strands.telemetry.tracer.Tracer._initialize_tracer") as mock_initialize: + yield mock_initialize + + +def test_init_default(mock_is_initialized, mock_get_tracer_provider): """Test initializing the Tracer with default parameters.""" + mock_is_initialized.return_value = False + 
mock_get_tracer_provider.return_value = None + tracer = Tracer() assert tracer.service_name == "strands-agents" @@ -141,9 +164,14 @@ def test_init_with_env_headers(): def test_initialize_tracer_with_console( - mock_tracer_provider, mock_set_tracer_provider, mock_console_exporter, mock_resource + mock_is_initialized, + mock_tracer_provider, + mock_set_tracer_provider, + mock_console_exporter, + mock_resource, ): """Test initializing the tracer with console exporter.""" + mock_is_initialized.return_value = False mock_resource_instance = mock.MagicMock() mock_resource.create.return_value = mock_resource_instance @@ -161,8 +189,12 @@ def test_initialize_tracer_with_console( mock_set_tracer_provider.assert_called_once_with(mock_tracer_provider.return_value) -def test_initialize_tracer_with_otlp(mock_tracer_provider, mock_set_tracer_provider, mock_otlp_exporter, mock_resource): +def test_initialize_tracer_with_otlp( + mock_is_initialized, mock_tracer_provider, mock_set_tracer_provider, mock_otlp_exporter, mock_resource +): """Test initializing the tracer with OTLP exporter.""" + mock_is_initialized.return_value = False + mock_resource_instance = mock.MagicMock() mock_resource.create.return_value = mock_resource_instance @@ -191,7 +223,7 @@ def test_start_span_no_tracer(): def test_start_span(mock_tracer): """Test starting a span with attributes.""" - with mock.patch("strands.telemetry.tracer.trace.get_tracer", return_value=mock_tracer): + with mock.patch("strands.telemetry.tracer.trace_api.get_tracer", return_value=mock_tracer): tracer = Tracer(enable_console_export=True) tracer.tracer = mock_tracer @@ -262,7 +294,7 @@ def test_end_span_with_error_message(mock_span): def test_start_model_invoke_span(mock_tracer): """Test starting a model invoke span.""" - with mock.patch("strands.telemetry.tracer.trace.get_tracer", return_value=mock_tracer): + with mock.patch("strands.telemetry.tracer.trace_api.get_tracer", return_value=mock_tracer): tracer = Tracer(enable_console_export=True) tracer.tracer = mock_tracer @@ -300,7 +332,7 @@ def test_end_model_invoke_span(mock_span): def test_start_tool_call_span(mock_tracer): """Test starting a tool call span.""" - with mock.patch("strands.telemetry.tracer.trace.get_tracer", return_value=mock_tracer): + with mock.patch("strands.telemetry.tracer.trace_api.get_tracer", return_value=mock_tracer): tracer = Tracer(enable_console_export=True) tracer.tracer = mock_tracer @@ -338,7 +370,7 @@ def test_end_tool_call_span(mock_span): def test_start_event_loop_cycle_span(mock_tracer): """Test starting an event loop cycle span.""" - with mock.patch("strands.telemetry.tracer.trace.get_tracer", return_value=mock_tracer): + with mock.patch("strands.telemetry.tracer.trace_api.get_tracer", return_value=mock_tracer): tracer = Tracer(enable_console_export=True) tracer.tracer = mock_tracer @@ -373,7 +405,7 @@ def test_end_event_loop_cycle_span(mock_span): def test_start_agent_span(mock_tracer): """Test starting an agent span.""" - with mock.patch("strands.telemetry.tracer.trace.get_tracer", return_value=mock_tracer): + with mock.patch("strands.telemetry.tracer.trace_api.get_tracer", return_value=mock_tracer): tracer = Tracer(enable_console_export=True) tracer.tracer = mock_tracer @@ -464,9 +496,11 @@ def test_get_tracer_parameters(): def test_initialize_tracer_with_invalid_otlp_endpoint( - mock_tracer_provider, mock_set_tracer_provider, mock_otlp_exporter, mock_resource + mock_is_initialized, mock_tracer_provider, mock_set_tracer_provider, mock_otlp_exporter, mock_resource ): 
"""Test initializing the tracer with an invalid OTLP endpoint.""" + mock_is_initialized.return_value = False + mock_resource_instance = mock.MagicMock() mock_resource.create.return_value = mock_resource_instance mock_otlp_exporter.side_effect = Exception("Connection error") @@ -486,6 +520,18 @@ def test_initialize_tracer_with_invalid_otlp_endpoint( mock_set_tracer_provider.assert_called_once_with(mock_tracer_provider.return_value) +def test_initialize_tracer_with_custom_tracer_provider(mock_get_tracer_provider, mock_resource): + """Test initializing the tracer with NoOpTracerProvider.""" + mock_is_initialized.return_value = True + tracer = Tracer(otlp_endpoint="http://invalid-endpoint") + + mock_get_tracer_provider.assert_called() + mock_resource.assert_not_called() + + assert tracer.tracer_provider is not None + assert tracer.tracer is not None + + def test_end_span_with_exception_handling(mock_span): """Test ending a span with exception handling.""" tracer = Tracer() @@ -530,7 +576,7 @@ def test_end_tool_call_span_with_none(mock_span): def test_start_model_invoke_span_with_parent(mock_tracer): """Test starting a model invoke span with a parent span.""" - with mock.patch("strands.telemetry.tracer.trace.get_tracer", return_value=mock_tracer): + with mock.patch("strands.telemetry.tracer.trace_api.get_tracer", return_value=mock_tracer): tracer = Tracer(enable_console_export=True) tracer.tracer = mock_tracer From 5fab010715fa26bc2bd68505367dbb7c09e0e3ed Mon Sep 17 00:00:00 2001 From: Jeremiah Date: Mon, 16 Jun 2025 13:32:17 -0400 Subject: [PATCH 07/22] build(a2a): add a2a deps and mitigate otel conflict (#232) * build(a2a): add a2a deps and mitigate otel conflict --- pyproject.toml | 20 ++++++++++++++++---- src/strands/telemetry/tracer.py | 20 +++++++++++++++++--- tests-integ/test_mcp_client.py | 4 ++-- tests/strands/telemetry/test_tracer.py | 17 ++++++++++++++--- 4 files changed, 49 insertions(+), 12 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index bd309732..835def0f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,7 +35,6 @@ dependencies = [ "watchdog>=6.0.0,<7.0.0", "opentelemetry-api>=1.30.0,<2.0.0", "opentelemetry-sdk>=1.30.0,<2.0.0", - "opentelemetry-exporter-otlp-proto-http>=1.30.0,<2.0.0", ] [project.urls] @@ -78,13 +77,23 @@ ollama = [ openai = [ "openai>=1.68.0,<2.0.0", ] +otel = [ + "opentelemetry-exporter-otlp-proto-http>=1.30.0,<2.0.0", +] +a2a = [ + "a2a-sdk>=0.2.6", + "uvicorn>=0.34.2", + "httpx>=0.28.1", + "fastapi>=0.115.12", + "starlette>=0.46.2", +] [tool.hatch.version] # Tells Hatch to use your version control system (git) to determine the version. 
source = "vcs" [tool.hatch.envs.hatch-static-analysis] -features = ["anthropic", "litellm", "llamaapi", "ollama", "openai"] +features = ["anthropic", "litellm", "llamaapi", "ollama", "openai", "otel"] dependencies = [ "mypy>=1.15.0,<2.0.0", "ruff>=0.11.6,<0.12.0", @@ -107,7 +116,7 @@ lint-fix = [ ] [tool.hatch.envs.hatch-test] -features = ["anthropic", "litellm", "llamaapi", "ollama", "openai"] +features = ["anthropic", "litellm", "llamaapi", "ollama", "openai", "otel"] extra-dependencies = [ "moto>=5.1.0,<6.0.0", "pytest>=8.0.0,<9.0.0", @@ -123,8 +132,11 @@ extra-args = [ [tool.hatch.envs.dev] dev-mode = true -features = ["dev", "docs", "anthropic", "litellm", "llamaapi", "ollama"] +features = ["dev", "docs", "anthropic", "litellm", "llamaapi", "ollama", "otel"] +[tool.hatch.envs.a2a] +dev-mode = true +features = ["dev", "docs", "anthropic", "litellm", "llamaapi", "ollama", "a2a"] [[tool.hatch.envs.hatch-test.matrix]] diff --git a/src/strands/telemetry/tracer.py b/src/strands/telemetry/tracer.py index 3353237d..4d6770bf 100644 --- a/src/strands/telemetry/tracer.py +++ b/src/strands/telemetry/tracer.py @@ -14,7 +14,6 @@ import opentelemetry.trace as trace_api from opentelemetry import propagate from opentelemetry.baggage.propagation import W3CBaggagePropagator -from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter from opentelemetry.propagators.composite import CompositePropagator from opentelemetry.sdk.resources import Resource from opentelemetry.sdk.trace import TracerProvider as SDKTracerProvider @@ -30,6 +29,19 @@ logger = logging.getLogger(__name__) +HAS_OTEL_EXPORTER_MODULE = False +OTEL_EXPORTER_MODULE_ERROR = ( + "opentelemetry-exporter-otlp-proto-http not detected;" + "please install strands-agents with the optional 'otel' target" + "otel http exporting is currently DISABLED" +) +try: + from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter + + HAS_OTEL_EXPORTER_MODULE = True +except ImportError: + pass + class JSONEncoder(json.JSONEncoder): """Custom JSON encoder that handles non-serializable types.""" @@ -181,7 +193,7 @@ def _initialize_tracer(self) -> None: self.tracer_provider.add_span_processor(console_processor) # Add OTLP exporter if endpoint is provided - if self.otlp_endpoint and self.tracer_provider: + if HAS_OTEL_EXPORTER_MODULE and self.otlp_endpoint and self.tracer_provider: try: # Ensure endpoint has the right format endpoint = self.otlp_endpoint @@ -206,6 +218,8 @@ def _initialize_tracer(self) -> None: logger.info("endpoint=<%s> | OTLP exporter configured with endpoint", endpoint) except Exception as e: logger.exception("error=<%s> | Failed to configure OTLP exporter", e) + elif self.otlp_endpoint and self.tracer_provider: + logger.warning(OTEL_EXPORTER_MODULE_ERROR) # Set as global tracer provider trace_api.set_tracer_provider(self.tracer_provider) @@ -294,7 +308,7 @@ def _end_span( finally: span.end() # Force flush to ensure spans are exported - if self.tracer_provider and hasattr(self.tracer_provider, 'force_flush'): + if self.tracer_provider and hasattr(self.tracer_provider, "force_flush"): try: self.tracer_provider.force_flush() except Exception as e: diff --git a/tests-integ/test_mcp_client.py b/tests-integ/test_mcp_client.py index f0669284..8b1dade3 100644 --- a/tests-integ/test_mcp_client.py +++ b/tests-integ/test_mcp_client.py @@ -104,8 +104,8 @@ def test_can_reuse_mcp_client(): @pytest.mark.skipif( - condition=os.environ.get("GITHUB_ACTIONS") == 'true', - reason="streamable transport is failing in 
GitHub actions, debugging if linux compatibility issue" + condition=os.environ.get("GITHUB_ACTIONS") == "true", + reason="streamable transport is failing in GitHub actions, debugging if linux compatibility issue", ) def test_streamable_http_mcp_client(): server_thread = threading.Thread( diff --git a/tests/strands/telemetry/test_tracer.py b/tests/strands/telemetry/test_tracer.py index ac4ea257..98849883 100644 --- a/tests/strands/telemetry/test_tracer.py +++ b/tests/strands/telemetry/test_tracer.py @@ -58,7 +58,10 @@ def mock_set_tracer_provider(): @pytest.fixture def mock_otlp_exporter(): - with mock.patch("strands.telemetry.tracer.OTLPSpanExporter") as mock_exporter: + with ( + mock.patch("strands.telemetry.tracer.HAS_OTEL_EXPORTER_MODULE", True), + mock.patch("opentelemetry.exporter.otlp.proto.http.trace_exporter.OTLPSpanExporter") as mock_exporter, + ): yield mock_exporter @@ -199,7 +202,11 @@ def test_initialize_tracer_with_otlp( mock_resource.create.return_value = mock_resource_instance # Initialize Tracer - Tracer(otlp_endpoint="http://test-endpoint") + with ( + mock.patch("strands.telemetry.tracer.HAS_OTEL_EXPORTER_MODULE", True), + mock.patch("strands.telemetry.tracer.OTLPSpanExporter", mock_otlp_exporter), + ): + Tracer(otlp_endpoint="http://test-endpoint") # Verify the tracer provider was created with correct resource mock_tracer_provider.assert_called_once_with(resource=mock_resource_instance) @@ -508,7 +515,11 @@ def test_initialize_tracer_with_invalid_otlp_endpoint( # This should not raise an exception, but should log an error # Initialize Tracer - Tracer(otlp_endpoint="http://invalid-endpoint") + with ( + mock.patch("strands.telemetry.tracer.HAS_OTEL_EXPORTER_MODULE", True), + mock.patch("strands.telemetry.tracer.OTLPSpanExporter", mock_otlp_exporter), + ): + Tracer(otlp_endpoint="http://invalid-endpoint") # Verify the tracer provider was created with correct resource mock_tracer_provider.assert_called_once_with(resource=mock_resource_instance) From e12bc2f030e73dea9504284f4218de6785ded83f Mon Sep 17 00:00:00 2001 From: Jeremiah Date: Mon, 16 Jun 2025 14:47:37 -0400 Subject: [PATCH 08/22] chore(otel): raise exception if exporter unavailable (#234) --- src/strands/telemetry/tracer.py | 2 +- tests/strands/telemetry/test_tracer.py | 27 ++++++++++++++++++++++++++ 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/src/strands/telemetry/tracer.py b/src/strands/telemetry/tracer.py index 4d6770bf..e9a37a4a 100644 --- a/src/strands/telemetry/tracer.py +++ b/src/strands/telemetry/tracer.py @@ -219,7 +219,7 @@ def _initialize_tracer(self) -> None: except Exception as e: logger.exception("error=<%s> | Failed to configure OTLP exporter", e) elif self.otlp_endpoint and self.tracer_provider: - logger.warning(OTEL_EXPORTER_MODULE_ERROR) + raise ModuleNotFoundError(OTEL_EXPORTER_MODULE_ERROR) # Set as global tracer provider trace_api.set_tracer_provider(self.tracer_provider) diff --git a/tests/strands/telemetry/test_tracer.py b/tests/strands/telemetry/test_tracer.py index 98849883..030dcd37 100644 --- a/tests/strands/telemetry/test_tracer.py +++ b/tests/strands/telemetry/test_tracer.py @@ -531,6 +531,33 @@ def test_initialize_tracer_with_invalid_otlp_endpoint( mock_set_tracer_provider.assert_called_once_with(mock_tracer_provider.return_value) +def test_initialize_tracer_with_missing_module( + mock_is_initialized, mock_tracer_provider, mock_set_tracer_provider, mock_resource +): + """Test initializing the tracer when the OTLP exporter module is missing.""" + 
mock_is_initialized.return_value = False + + mock_resource_instance = mock.MagicMock() + mock_resource.create.return_value = mock_resource_instance + + # Initialize Tracer with OTLP endpoint but missing module + with ( + mock.patch("strands.telemetry.tracer.HAS_OTEL_EXPORTER_MODULE", False), + pytest.raises(ModuleNotFoundError) as excinfo, + ): + Tracer(otlp_endpoint="http://test-endpoint") + + # Verify the error message + assert "opentelemetry-exporter-otlp-proto-http not detected" in str(excinfo.value) + assert "otel http exporting is currently DISABLED" in str(excinfo.value) + + # Verify the tracer provider was created with correct resource + mock_tracer_provider.assert_called_once_with(resource=mock_resource_instance) + + # Verify set_tracer_provider was not called since an exception was raised + mock_set_tracer_provider.assert_not_called() + + def test_initialize_tracer_with_custom_tracer_provider(mock_get_tracer_provider, mock_resource): """Test initializing the tracer with NoOpTracerProvider.""" mock_is_initialized.return_value = True From 756a0276b2bdc3787ffc5d479cf2c34ea13e6a69 Mon Sep 17 00:00:00 2001 From: Adnan Khan Date: Tue, 17 Jun 2025 10:37:22 -0400 Subject: [PATCH 09/22] fix: Update PR Integration Test Workflow (#237) * Use deployment environment gating for integration tests. * Only run on PRs that target main. * Use correct head name. --- .github/workflows/integration-test.yml | 54 +++++++++++++++++--------- 1 file changed, 36 insertions(+), 18 deletions(-) diff --git a/.github/workflows/integration-test.yml b/.github/workflows/integration-test.yml index 294a2f3e..39b53c49 100644 --- a/.github/workflows/integration-test.yml +++ b/.github/workflows/integration-test.yml @@ -2,41 +2,59 @@ name: Secure Integration test on: pull_request_target: - types: [opened, synchronize, labeled, unlabled, reopened] + branches: main jobs: + authorization-check: + permissions: read-all + runs-on: ubuntu-latest + outputs: + approval-env: ${{ steps.collab-check.outputs.result }} + steps: + - name: Collaborator Check + uses: actions/github-script@v7 + id: collab-check + with: + result-encoding: string + script: | + try { + const permissionResponse = await github.rest.repos.getCollaboratorPermissionLevel({ + owner: context.repo.owner, + repo: context.repo.repo, + username: context.payload.pull_request.user.login, + }); + const permission = permissionResponse.data.permission; + const hasWriteAccess = ['write', 'admin'].includes(permission); + if (!hasWriteAccess) { + console.log(`User ${context.payload.pull_request.user.login} does not have write access to the repository (permission: ${permission})`); + return "manual-approval" + } else { + console.log(`Verifed ${context.payload.pull_request.user.login} has write access. Auto Approving PR Checks.`) + return "auto-approve" + } + } catch (error) { + console.log(`${context.payload.pull_request.user.login} does not have write access. 
Requiring Manual Approval to run PR Checks.`) + return "manual-approval" + } check-access-and-checkout: runs-on: ubuntu-latest + needs: authorization-check + environment: ${{ needs.authorization-check.outputs.approval-env }} permissions: id-token: write pull-requests: read contents: read steps: - - name: Check PR labels and author - id: check - uses: actions/github-script@v7 - with: - script: | - const pr = context.payload.pull_request; - - const labels = pr.labels.map(label => label.name); - const hasLabel = labels.includes('approved-for-integ-test') - if (hasLabel) { - core.info('PR contains label approved-for-integ-test') - return - } - - core.setFailed('Pull Request must either have label approved-for-integ-test') - name: Configure Credentials uses: aws-actions/configure-aws-credentials@v4 with: role-to-assume: ${{ secrets.STRANDS_INTEG_TEST_ROLE }} aws-region: us-east-1 mask-aws-account-id: true - - name: Checkout base branch + - name: Checkout head commit uses: actions/checkout@v4 with: - ref: ${{ github.event.pull_request.head.ref }} # Pull the commit from the forked repo + ref: ${{ github.event.pull_request.head.sha }} # Pull the commit from the forked repo persist-credentials: false # Don't persist credentials for subsequent actions - name: Set up Python uses: actions/setup-python@v5 From 4dd0819812681d80953d8b6a981e3d8f3a81d825 Mon Sep 17 00:00:00 2001 From: Mackenzie Zastrow <3211021+zastrowm@users.noreply.github.com> Date: Tue, 17 Jun 2025 10:44:29 -0400 Subject: [PATCH 10/22] fix: remove swagger-parser (#220) Having the dependency breaks uv installs as swagger-parser depends on pre-release library which is not allowed by default As far as I can tell, this dependency is not used anywhere and can be safely removed Co-authored-by: Mackenzie Zastrow --- pyproject.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 835def0f..17bc110e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,7 +58,6 @@ dev = [ "pytest>=8.0.0,<9.0.0", "pytest-asyncio>=0.26.0,<0.27.0", "ruff>=0.4.4,<0.5.0", - "swagger-parser>=1.0.2,<2.0.0", ] docs = [ "sphinx>=5.0.0,<6.0.0", From 52c68aaddb484b0553d6feec77c8f90fdcc915d1 Mon Sep 17 00:00:00 2001 From: Mackenzie Zastrow <3211021+zastrowm@users.noreply.github.com> Date: Tue, 17 Jun 2025 10:45:33 -0400 Subject: [PATCH 11/22] fix: Update throttling logic to use exponential back-off (#223) current_delay was being thrown away and not applied to subsequent retries Co-authored-by: Mackenzie Zastrow --- src/strands/event_loop/event_loop.py | 3 +- tests/strands/event_loop/test_event_loop.py | 56 ++++++++++++++++++++- 2 files changed, 57 insertions(+), 2 deletions(-) diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index 71165926..02a56a1c 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -136,6 +136,7 @@ def event_loop_cycle( metrics: Metrics # Retry loop for handling throttling exceptions + current_delay = INITIAL_DELAY for attempt in range(MAX_ATTEMPTS): model_id = model.config.get("model_id") if hasattr(model, "config") else None model_invoke_span = tracer.start_model_invoke_span( @@ -168,7 +169,7 @@ def event_loop_cycle( # Handle throttling errors with exponential backoff should_retry, current_delay = handle_throttling_error( - e, attempt, MAX_ATTEMPTS, INITIAL_DELAY, MAX_DELAY, callback_handler, kwargs + e, attempt, MAX_ATTEMPTS, current_delay, MAX_DELAY, callback_handler, kwargs ) if should_retry: continue diff --git 
a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py index 8c46e009..734457aa 100644 --- a/tests/strands/event_loop/test_event_loop.py +++ b/tests/strands/event_loop/test_event_loop.py @@ -11,6 +11,13 @@ from strands.types.exceptions import ContextWindowOverflowException, EventLoopException, ModelThrottledException +@pytest.fixture +def mock_time(): + """Fixture to mock the time module in the error_handler.""" + with unittest.mock.patch.object(strands.event_loop.error_handler, "time") as mock: + yield mock + + @pytest.fixture def model(): return unittest.mock.Mock() @@ -157,8 +164,8 @@ def test_event_loop_cycle_text_response( assert tru_stop_reason == exp_stop_reason and tru_message == exp_message and tru_request_state == exp_request_state -@unittest.mock.patch.object(strands.event_loop.error_handler, "time") def test_event_loop_cycle_text_response_throttling( + mock_time, model, model_id, system_prompt, @@ -191,6 +198,53 @@ def test_event_loop_cycle_text_response_throttling( exp_request_state = {} assert tru_stop_reason == exp_stop_reason and tru_message == exp_message and tru_request_state == exp_request_state + # Verify that sleep was called once with the initial delay + mock_time.sleep.assert_called_once() + + +def test_event_loop_cycle_exponential_backoff( + mock_time, + model, + model_id, + system_prompt, + messages, + tool_config, + callback_handler, + tool_handler, + tool_execution_handler, +): + """Test that the exponential backoff works correctly with multiple retries.""" + # Set up the model to raise throttling exceptions multiple times before succeeding + model.converse.side_effect = [ + ModelThrottledException("ThrottlingException | ConverseStream"), + ModelThrottledException("ThrottlingException | ConverseStream"), + ModelThrottledException("ThrottlingException | ConverseStream"), + [ + {"contentBlockDelta": {"delta": {"text": "test text"}}}, + {"contentBlockStop": {}}, + ], + ] + + tru_stop_reason, tru_message, _, tru_request_state = strands.event_loop.event_loop.event_loop_cycle( + model=model, + model_id=model_id, + system_prompt=system_prompt, + messages=messages, + tool_config=tool_config, + callback_handler=callback_handler, + tool_handler=tool_handler, + tool_execution_handler=tool_execution_handler, + ) + + # Verify the final response + assert tru_stop_reason == "end_turn" + assert tru_message == {"role": "assistant", "content": [{"text": "test text"}]} + assert tru_request_state == {} + + # Verify that sleep was called with increasing delays + # Initial delay is 4, then 8, then 16 + assert mock_time.sleep.call_count == 3 + assert mock_time.sleep.call_args_list == [call(4), call(8), call(16)] def test_event_loop_cycle_text_response_error( From eb50073f2ac82302b35d69dfca5a2337aa1da7ab Mon Sep 17 00:00:00 2001 From: Mackenzie Zastrow <3211021+zastrowm@users.noreply.github.com> Date: Tue, 17 Jun 2025 15:25:54 -0400 Subject: [PATCH 12/22] feat: Simplify contribution template + pr scripts to run (#221) - Add a single script that runs the commands we want contributions to run. 
- Revamp the contribution template by: - Using HTML comments to guide the author, reducing the need to delete the text - Converting the type of change to be a flat list, allowing authors to delete all except the one they want - Updated the checklist items to allow items that are unnecessary to still be checked Co-authored-by: Mackenzie Zastrow --- .github/PULL_REQUEST_TEMPLATE.md | 33 ++++++++++++++++---------------- pyproject.toml | 6 +++++- 2 files changed, 22 insertions(+), 17 deletions(-) diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index a6efda65..fa894231 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -1,37 +1,38 @@ ## Description -[Provide a detailed description of the changes in this PR] + ## Related Issues -[Link to related issues using #issue-number format] + + ## Documentation PR -[Link to related associated PR in the agent-docs repo] + + ## Type of Change -- Bug fix -- New feature -- Breaking change -- Documentation update -- Other (please describe): -[Choose one of the above types of changes] + +Bug fix +New feature +Breaking change +Documentation update +Other (please describe): ## Testing -[How have you tested the change?] -* `hatch fmt --linter` -* `hatch fmt --formatter` -* `hatch test --all` -* Verify that the changes do not break functionality or introduce warnings in consuming repositories: agents-docs, agents-tools, agents-cli +How have you tested the change? Verify that the changes do not break functionality or introduce warnings in consuming repositories: agents-docs, agents-tools, agents-cli +- [ ] I ran `hatch run prepare` ## Checklist - [ ] I have read the CONTRIBUTING document -- [ ] I have added tests that prove my fix is effective or my feature works +- [ ] I have added any necessary tests that prove my fix is effective or my feature works - [ ] I have updated the documentation accordingly -- [ ] I have added an appropriate example to the documentation to outline the feature +- [ ] I have added an appropriate example to the documentation to outline the feature, or no new docs are needed - [ ] My changes generate no new warnings - [ ] Any dependent changes have been merged and published +---- + By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice. 
diff --git a/pyproject.toml b/pyproject.toml index 17bc110e..bf7615c2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -176,7 +176,11 @@ test = [ test-integ = [ "hatch test tests-integ {args}" ] - +prepare = [ + "hatch fmt --linter", + "hatch fmt --formatter", + "hatch test --all" +] [tool.mypy] python_version = "3.10" From cc5be1200123d0c48765f264fce92de711dac953 Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Wed, 18 Jun 2025 16:09:55 +0300 Subject: [PATCH 13/22] chore(deps): relax docstring_parser version to allow 1.0 (#239) --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index bf7615c2..56c1a40e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,7 +28,7 @@ classifiers = [ dependencies = [ "boto3>=1.26.0,<2.0.0", "botocore>=1.29.0,<2.0.0", - "docstring_parser>=0.15,<0.16.0", + "docstring_parser>=0.15,<1.0", "mcp>=1.8.0,<2.0.0", "pydantic>=2.0.0,<3.0.0", "typing-extensions>=4.13.2,<5.0.0", From 4d7bb9820fd5153946abe08ac3513a4e2527bd73 Mon Sep 17 00:00:00 2001 From: poshinchen Date: Wed, 18 Jun 2025 18:49:16 -0400 Subject: [PATCH 14/22] feat: initialized meter (#219) --- src/strands/agent/agent.py | 1 - src/strands/telemetry/__init__.py | 5 +- src/strands/telemetry/config.py | 33 +++++++++++++ src/strands/telemetry/metrics.py | 46 ++++++++++++++++++ src/strands/telemetry/metrics_constants.py | 3 ++ src/strands/telemetry/tracer.py | 15 ++---- tests/strands/telemetry/test_metrics.py | 56 ++++++++++++++++++++++ tests/strands/telemetry/test_tracer.py | 24 ++++------ 8 files changed, 153 insertions(+), 30 deletions(-) create mode 100644 src/strands/telemetry/config.py create mode 100644 src/strands/telemetry/metrics_constants.py diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 56f5b92e..2475b87e 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -308,7 +308,6 @@ def __init__( # Initialize tracer instance (no-op if not configured) self.tracer = get_tracer() self.trace_span: Optional[trace.Span] = None - self.tool_caller = Agent.ToolCaller(self) @property diff --git a/src/strands/telemetry/__init__.py b/src/strands/telemetry/__init__.py index 15981216..21dd6ebf 100644 --- a/src/strands/telemetry/__init__.py +++ b/src/strands/telemetry/__init__.py @@ -3,7 +3,8 @@ This module provides metrics and tracing functionality. """ -from .metrics import EventLoopMetrics, Trace, metrics_to_string +from .config import get_otel_resource +from .metrics import EventLoopMetrics, MetricsClient, Trace, metrics_to_string from .tracer import Tracer, get_tracer __all__ = [ @@ -12,4 +13,6 @@ "metrics_to_string", "Tracer", "get_tracer", + "MetricsClient", + "get_otel_resource", ] diff --git a/src/strands/telemetry/config.py b/src/strands/telemetry/config.py new file mode 100644 index 00000000..9f5a05fd --- /dev/null +++ b/src/strands/telemetry/config.py @@ -0,0 +1,33 @@ +"""OpenTelemetry configuration and setup utilities for Strands agents. + +This module provides centralized configuration and initialization functionality +for OpenTelemetry components and other telemetry infrastructure shared across Strands applications. +""" + +from importlib.metadata import version + +from opentelemetry.sdk.resources import Resource + + +def get_otel_resource() -> Resource: + """Create a standard OpenTelemetry resource with service information. + + This function implements a singleton pattern - it will return the same + Resource object for the same service_name parameter. 
+ + Args: + service_name: Name of the service for OpenTelemetry. + + Returns: + Resource object with standard service information. + """ + resource = Resource.create( + { + "service.name": __name__, + "service.version": version("strands-agents"), + "telemetry.sdk.name": "opentelemetry", + "telemetry.sdk.language": "python", + } + ) + + return resource diff --git a/src/strands/telemetry/metrics.py b/src/strands/telemetry/metrics.py index cd70819b..af940643 100644 --- a/src/strands/telemetry/metrics.py +++ b/src/strands/telemetry/metrics.py @@ -6,6 +6,10 @@ from dataclasses import dataclass, field from typing import Any, Dict, Iterable, List, Optional, Set, Tuple +import opentelemetry.metrics as metrics_api +from opentelemetry.metrics import Counter, Meter + +from ..telemetry import metrics_constants as constants from ..types.content import Message from ..types.streaming import Metrics, Usage from ..types.tools import ToolUse @@ -355,3 +359,45 @@ def metrics_to_string(event_loop_metrics: EventLoopMetrics, allowed_names: Optio A formatted string representation of the metrics. """ return "\n".join(_metrics_summary_to_lines(event_loop_metrics, allowed_names or set())) + + +class MetricsClient: + """Singleton client for managing OpenTelemetry metrics instruments. + + The actual metrics export destination (console, OTLP endpoint, etc.) is configured + through OpenTelemetry SDK configuration by users, not by this client. + """ + + _instance: Optional["MetricsClient"] = None + meter: Meter + strands_agent_invocation_count: Counter + + def __new__(cls) -> "MetricsClient": + """Create or return the singleton instance of MetricsClient. + + Returns: + The single MetricsClient instance. + """ + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + def __init__(self) -> None: + """Initialize the MetricsClient. + + This method only runs once due to the singleton pattern. + Sets up the OpenTelemetry meter and creates metric instruments. 
+ """ + if hasattr(self, "meter"): + return + + logger.info("Creating Strands MetricsClient") + meter_provider: metrics_api.MeterProvider = metrics_api.get_meter_provider() + self.meter = meter_provider.get_meter(__name__) + self.create_instruments() + + def create_instruments(self) -> None: + """Create and initialize all OpenTelemetry metric instruments.""" + self.strands_agent_invocation_count = self.meter.create_counter( + name=constants.STRANDS_AGENT_INVOCATION_COUNT, unit="Count" + ) diff --git a/src/strands/telemetry/metrics_constants.py b/src/strands/telemetry/metrics_constants.py new file mode 100644 index 00000000..d3d3e81f --- /dev/null +++ b/src/strands/telemetry/metrics_constants.py @@ -0,0 +1,3 @@ +"""Metrics that are emitted in Strands-Agent.""" + +STRANDS_AGENT_INVOCATION_COUNT = "strands.agent.invocation_count" diff --git a/src/strands/telemetry/tracer.py b/src/strands/telemetry/tracer.py index e9a37a4a..813c90e1 100644 --- a/src/strands/telemetry/tracer.py +++ b/src/strands/telemetry/tracer.py @@ -8,20 +8,19 @@ import logging import os from datetime import date, datetime, timezone -from importlib.metadata import version from typing import Any, Dict, Mapping, Optional import opentelemetry.trace as trace_api from opentelemetry import propagate from opentelemetry.baggage.propagation import W3CBaggagePropagator from opentelemetry.propagators.composite import CompositePropagator -from opentelemetry.sdk.resources import Resource from opentelemetry.sdk.trace import TracerProvider as SDKTracerProvider from opentelemetry.sdk.trace.export import BatchSpanProcessor, ConsoleSpanExporter, SimpleSpanProcessor from opentelemetry.trace import Span, StatusCode from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator from ..agent.agent_result import AgentResult +from ..telemetry import get_otel_resource from ..types.content import Message, Messages from ..types.streaming import Usage from ..types.tools import ToolResult, ToolUse @@ -151,7 +150,6 @@ def __init__( self.otlp_headers = otlp_headers or {} self.tracer_provider: Optional[trace_api.TracerProvider] = None self.tracer: Optional[trace_api.Tracer] = None - propagate.set_global_textmap( CompositePropagator( [ @@ -173,15 +171,7 @@ def _initialize_tracer(self) -> None: self.tracer = self.tracer_provider.get_tracer(self.service_name) return - # Create resource with service information - resource = Resource.create( - { - "service.name": self.service_name, - "service.version": version("strands-agents"), - "telemetry.sdk.name": "opentelemetry", - "telemetry.sdk.language": "python", - } - ) + resource = get_otel_resource() # Create tracer provider self.tracer_provider = SDKTracerProvider(resource=resource) @@ -216,6 +206,7 @@ def _initialize_tracer(self) -> None: batch_processor = BatchSpanProcessor(otlp_exporter) self.tracer_provider.add_span_processor(batch_processor) logger.info("endpoint=<%s> | OTLP exporter configured with endpoint", endpoint) + except Exception as e: logger.exception("error=<%s> | Failed to configure OTLP exporter", e) elif self.otlp_endpoint and self.tracer_provider: diff --git a/tests/strands/telemetry/test_metrics.py b/tests/strands/telemetry/test_metrics.py index 4e84f0fd..cafbd4bb 100644 --- a/tests/strands/telemetry/test_metrics.py +++ b/tests/strands/telemetry/test_metrics.py @@ -1,9 +1,13 @@ import dataclasses import unittest +from unittest import mock import pytest +from opentelemetry.metrics._internal import _ProxyMeter +from opentelemetry.sdk.metrics import MeterProvider import 
strands +from strands.telemetry import MetricsClient from strands.types.streaming import Metrics, Usage @@ -117,6 +121,30 @@ def test_trace_end(mock_time, end_time, trace): assert tru_end_time == exp_end_time +@pytest.fixture +def mock_get_meter_provider(): + with mock.patch("strands.telemetry.metrics.metrics_api.get_meter_provider") as mock_get_meter_provider: + meter_provider_mock = mock.MagicMock(spec=MeterProvider) + mock_get_meter_provider.return_value = meter_provider_mock + + mock_meter = mock.MagicMock() + meter_provider_mock.get_meter.return_value = mock_meter + + yield mock_get_meter_provider + + +@pytest.fixture +def mock_sdk_meter_provider(): + with mock.patch("strands.telemetry.metrics.metrics_sdk.MeterProvider") as mock_meter_provider: + yield mock_meter_provider + + +@pytest.fixture +def mock_resource(): + with mock.patch("opentelemetry.sdk.resources.Resource") as mock_resource: + yield mock_resource + + def test_trace_add_child(child_trace, trace): trace.add_child(child_trace) @@ -379,3 +407,31 @@ def test_metrics_to_string(trace, child_trace, tool_metrics, exp_str, event_loop tru_str = strands.telemetry.metrics.metrics_to_string(event_loop_metrics) assert tru_str == exp_str + + +def test_setup_meter_if_meter_provider_is_set( + mock_get_meter_provider, + mock_resource, +): + """Test global meter_provider and meter are used""" + mock_resource_instance = mock.MagicMock() + mock_resource.create.return_value = mock_resource_instance + + metrics_client = MetricsClient() + + mock_get_meter_provider.assert_called() + mock_get_meter_provider.return_value.get_meter.assert_called() + + assert metrics_client is not None + + +def test_use_ProxyMeter_if_no_global_meter_provider(): + """Return _ProxyMeter""" + # Reset the singleton instance + strands.telemetry.metrics.MetricsClient._instance = None + + # Create a new instance which should use the real _ProxyMeter + metrics_client = MetricsClient() + + # Verify it's using a _ProxyMeter + assert isinstance(metrics_client.meter, _ProxyMeter) diff --git a/tests/strands/telemetry/test_tracer.py b/tests/strands/telemetry/test_tracer.py index 030dcd37..6ae3e1ad 100644 --- a/tests/strands/telemetry/test_tracer.py +++ b/tests/strands/telemetry/test_tracer.py @@ -73,7 +73,9 @@ def mock_console_exporter(): @pytest.fixture def mock_resource(): - with mock.patch("strands.telemetry.tracer.Resource") as mock_resource: + with mock.patch("strands.telemetry.tracer.get_otel_resource") as mock_resource: + mock_resource_instance = mock.MagicMock() + mock_resource.return_value = mock_resource_instance yield mock_resource @@ -175,14 +177,12 @@ def test_initialize_tracer_with_console( ): """Test initializing the tracer with console exporter.""" mock_is_initialized.return_value = False - mock_resource_instance = mock.MagicMock() - mock_resource.create.return_value = mock_resource_instance # Initialize Tracer Tracer(enable_console_export=True) # Verify the tracer provider was created with correct resource - mock_tracer_provider.assert_called_once_with(resource=mock_resource_instance) + mock_tracer_provider.assert_called_once_with(resource=mock_resource.return_value) # Verify console exporter was added mock_console_exporter.assert_called_once() @@ -198,9 +198,6 @@ def test_initialize_tracer_with_otlp( """Test initializing the tracer with OTLP exporter.""" mock_is_initialized.return_value = False - mock_resource_instance = mock.MagicMock() - mock_resource.create.return_value = mock_resource_instance - # Initialize Tracer with ( 
mock.patch("strands.telemetry.tracer.HAS_OTEL_EXPORTER_MODULE", True), @@ -209,7 +206,7 @@ def test_initialize_tracer_with_otlp( Tracer(otlp_endpoint="http://test-endpoint") # Verify the tracer provider was created with correct resource - mock_tracer_provider.assert_called_once_with(resource=mock_resource_instance) + mock_tracer_provider.assert_called_once_with(resource=mock_resource.return_value) # Verify OTLP exporter was added with correct endpoint mock_otlp_exporter.assert_called_once() @@ -508,8 +505,6 @@ def test_initialize_tracer_with_invalid_otlp_endpoint( """Test initializing the tracer with an invalid OTLP endpoint.""" mock_is_initialized.return_value = False - mock_resource_instance = mock.MagicMock() - mock_resource.create.return_value = mock_resource_instance mock_otlp_exporter.side_effect = Exception("Connection error") # This should not raise an exception, but should log an error @@ -522,7 +517,7 @@ def test_initialize_tracer_with_invalid_otlp_endpoint( Tracer(otlp_endpoint="http://invalid-endpoint") # Verify the tracer provider was created with correct resource - mock_tracer_provider.assert_called_once_with(resource=mock_resource_instance) + mock_tracer_provider.assert_called_once_with(resource=mock_resource.return_value) # Verify OTLP exporter was attempted mock_otlp_exporter.assert_called_once() @@ -537,9 +532,6 @@ def test_initialize_tracer_with_missing_module( """Test initializing the tracer when the OTLP exporter module is missing.""" mock_is_initialized.return_value = False - mock_resource_instance = mock.MagicMock() - mock_resource.create.return_value = mock_resource_instance - # Initialize Tracer with OTLP endpoint but missing module with ( mock.patch("strands.telemetry.tracer.HAS_OTEL_EXPORTER_MODULE", False), @@ -552,13 +544,13 @@ def test_initialize_tracer_with_missing_module( assert "otel http exporting is currently DISABLED" in str(excinfo.value) # Verify the tracer provider was created with correct resource - mock_tracer_provider.assert_called_once_with(resource=mock_resource_instance) + mock_tracer_provider.assert_called_once_with(resource=mock_resource.return_value) # Verify set_tracer_provider was not called since an exception was raised mock_set_tracer_provider.assert_not_called() -def test_initialize_tracer_with_custom_tracer_provider(mock_get_tracer_provider, mock_resource): +def test_initialize_tracer_with_custom_tracer_provider(mock_is_initialized, mock_get_tracer_provider, mock_resource): """Test initializing the tracer with NoOpTracerProvider.""" mock_is_initialized.return_value = True tracer = Tracer(otlp_endpoint="http://invalid-endpoint") From 40f622aa80dc865343db33a8bef595e46ef92394 Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Thu, 19 Jun 2025 09:16:31 -0400 Subject: [PATCH 15/22] models - openai - images - b64 validate (#251) --- src/strands/types/models/openai.py | 12 +++++++++++- tests/strands/types/models/test_openai.py | 20 +++++++++++++++++++- 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/src/strands/types/models/openai.py b/src/strands/types/models/openai.py index 96f758d5..0df1dda5 100644 --- a/src/strands/types/models/openai.py +++ b/src/strands/types/models/openai.py @@ -57,7 +57,17 @@ def format_request_message_content(cls, content: ContentBlock) -> dict[str, Any] if "image" in content: mime_type = mimetypes.types_map.get(f".{content['image']['format']}", "application/octet-stream") - image_data = content["image"]["source"]["bytes"].decode("utf-8") + image_bytes = content["image"]["source"]["bytes"] + try: + 
base64.b64decode(image_bytes, validate=True) + logger.warning( + "issue=<%s> | base64 encoded images will not be accepted in a future version", + "https://github.com/strands-agents/sdk-python/issues/252" + ) + except ValueError: + image_bytes = base64.b64encode(image_bytes) + + image_data = image_bytes.decode("utf-8") return { "image_url": { "detail": "auto", diff --git a/tests/strands/types/models/test_openai.py b/tests/strands/types/models/test_openai.py index 9db08bc9..2827969d 100644 --- a/tests/strands/types/models/test_openai.py +++ b/tests/strands/types/models/test_openai.py @@ -1,3 +1,4 @@ +import base64 import unittest.mock import pytest @@ -90,7 +91,24 @@ def system_prompt(): "image_url": { "detail": "auto", "format": "image/jpeg", - "url": "data:image/jpeg;base64,image", + "url": "data:image/jpeg;base64,aW1hZ2U=", + }, + "type": "image_url", + }, + ), + # Image - base64 encoded + ( + { + "image": { + "format": "jpg", + "source": {"bytes": base64.b64encode(b"image")}, + }, + }, + { + "image_url": { + "detail": "auto", + "format": "image/jpeg", + "url": "data:image/jpeg;base64,aW1hZ2U=", }, "type": "image_url", }, From 735d0c078c0c653d462b3a7a84aa81d27a34e749 Mon Sep 17 00:00:00 2001 From: Mackenzie Zastrow <3211021+zastrowm@users.noreply.github.com> Date: Thu, 19 Jun 2025 11:03:50 -0400 Subject: [PATCH 16/22] chore: Inline event loop helper functions (#222) While reading the through the event loop, it made more sense to me to inline the implementation - most of the actual apis are signature + docs, while the actual code is 2 lines each. Co-authored-by: Mackenzie Zastrow --- src/strands/event_loop/event_loop.py | 44 +------- tests/strands/event_loop/test_event_loop.py | 107 +++++++++++++------- 2 files changed, 78 insertions(+), 73 deletions(-) diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index 02a56a1c..e336642c 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -33,23 +33,6 @@ MAX_DELAY = 240 # 4 minutes -def initialize_state(**kwargs: Any) -> Any: - """Initialize the request state if not present. - - Creates an empty request_state dictionary if one doesn't already exist in the - provided keyword arguments. - - Args: - **kwargs: Keyword arguments that may contain a request_state. - - Returns: - The updated kwargs dictionary with request_state initialized if needed. - """ - if "request_state" not in kwargs: - kwargs["request_state"] = {} - return kwargs - - def event_loop_cycle( model: Model, system_prompt: Optional[str], @@ -107,7 +90,8 @@ def event_loop_cycle( event_loop_metrics: EventLoopMetrics = kwargs.get("event_loop_metrics", EventLoopMetrics()) # Initialize state and get cycle trace - kwargs = initialize_state(**kwargs) + if "request_state" not in kwargs: + kwargs["request_state"] = {} cycle_start_time, cycle_trace = event_loop_metrics.start_cycle() kwargs["event_loop_cycle_trace"] = cycle_trace @@ -310,26 +294,6 @@ def recurse_event_loop( ) -def prepare_next_cycle(kwargs: Dict[str, Any], event_loop_metrics: EventLoopMetrics) -> Dict[str, Any]: - """Prepare state for the next event loop cycle. - - Updates the keyword arguments with the current event loop metrics and stores the current cycle ID as the parent - cycle ID for the next cycle. This maintains the parent-child relationship between cycles for tracing and metrics. - - Args: - kwargs: Current keyword arguments containing event loop state. - event_loop_metrics: The metrics object tracking event loop execution. 
- - Returns: - Updated keyword arguments ready for the next cycle. - """ - # Store parent cycle ID - kwargs["event_loop_metrics"] = event_loop_metrics - kwargs["event_loop_parent_cycle_id"] = kwargs["event_loop_cycle_id"] - - return kwargs - - def _handle_tool_execution( stop_reason: StopReason, message: Message, @@ -403,7 +367,9 @@ def _handle_tool_execution( parallel_tool_executor=tool_execution_handler, ) - kwargs = prepare_next_cycle(kwargs, event_loop_metrics) + # Store parent cycle ID for the next cycle + kwargs["event_loop_metrics"] = event_loop_metrics + kwargs["event_loop_parent_cycle_id"] = kwargs["event_loop_cycle_id"] tool_result_message: Message = { "role": "user", diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py index 734457aa..efdf7af8 100644 --- a/tests/strands/event_loop/test_event_loop.py +++ b/tests/strands/event_loop/test_event_loop.py @@ -111,27 +111,6 @@ def mock_tracer(): return tracer -@pytest.mark.parametrize( - ("kwargs", "exp_state"), - [ - ( - {"request_state": {"key1": "value1"}}, - {"key1": "value1"}, - ), - ( - {}, - {}, - ), - ], -) -def test_initialize_state(kwargs, exp_state): - kwargs = strands.event_loop.event_loop.initialize_state(**kwargs) - - tru_state = kwargs["request_state"] - - assert tru_state == exp_state - - def test_event_loop_cycle_text_response( model, model_id, @@ -465,19 +444,6 @@ def test_event_loop_cycle_stop( assert tru_stop_reason == exp_stop_reason and tru_message == exp_message and tru_request_state == exp_request_state -def test_prepare_next_cycle(): - kwargs = {"event_loop_cycle_id": "c1"} - event_loop_metrics = strands.telemetry.metrics.EventLoopMetrics() - tru_result = strands.event_loop.event_loop.prepare_next_cycle(kwargs, event_loop_metrics) - exp_result = { - "event_loop_cycle_id": "c1", - "event_loop_parent_cycle_id": "c1", - "event_loop_metrics": event_loop_metrics, - } - - assert tru_result == exp_result - - def test_cycle_exception( model, system_prompt, @@ -733,3 +699,76 @@ def test_event_loop_cycle_with_parent_span( mock_tracer.start_event_loop_cycle_span.assert_called_once_with( event_loop_kwargs=unittest.mock.ANY, parent_span=parent_span, messages=messages ) + + +def test_request_state_initialization(): + # Call without providing request_state + tru_stop_reason, tru_message, _, tru_request_state = strands.event_loop.event_loop.event_loop_cycle( + model=MagicMock(), + model_id=MagicMock(), + system_prompt=MagicMock(), + messages=MagicMock(), + tool_config=MagicMock(), + callback_handler=MagicMock(), + tool_handler=MagicMock(), + tool_execution_handler=MagicMock(), + ) + + # Verify request_state was initialized to empty dict + assert tru_request_state == {} + + # Call with pre-existing request_state + initial_request_state = {"key": "value"} + tru_stop_reason, tru_message, _, tru_request_state = strands.event_loop.event_loop.event_loop_cycle( + model=MagicMock(), + model_id=MagicMock(), + system_prompt=MagicMock(), + messages=MagicMock(), + tool_config=MagicMock(), + callback_handler=MagicMock(), + tool_handler=MagicMock(), + request_state=initial_request_state, + ) + + # Verify existing request_state was preserved + assert tru_request_state == initial_request_state + + +def test_prepare_next_cycle_in_tool_execution(model, tool_stream): + """Test that cycle ID and metrics are properly updated during tool execution.""" + model.converse.side_effect = [ + tool_stream, + [ + {"contentBlockStop": {}}, + ], + ] + + # Create a mock for recurse_event_loop to capture the 
kwargs passed to it + with unittest.mock.patch.object(strands.event_loop.event_loop, "recurse_event_loop") as mock_recurse: + # Set up mock to return a valid response + mock_recurse.return_value = ( + "end_turn", + {"role": "assistant", "content": [{"text": "test text"}]}, + strands.telemetry.metrics.EventLoopMetrics(), + {}, + ) + + # Call event_loop_cycle which should execute a tool and then call recurse_event_loop + strands.event_loop.event_loop.event_loop_cycle( + model=model, + model_id=MagicMock(), + system_prompt=MagicMock(), + messages=MagicMock(), + tool_config=MagicMock(), + callback_handler=MagicMock(), + tool_handler=MagicMock(), + tool_execution_handler=MagicMock(), + ) + + assert mock_recurse.called + + # Verify required properties are present + recursive_kwargs = mock_recurse.call_args[1] + assert "event_loop_metrics" in recursive_kwargs + assert "event_loop_parent_cycle_id" in recursive_kwargs + assert recursive_kwargs["event_loop_parent_cycle_id"] == recursive_kwargs["event_loop_cycle_id"] From 684b3f7852776140a03a6dcf53e52a3cac4a5ef1 Mon Sep 17 00:00:00 2001 From: Laith Al-Saadoon <9553966+theagenticguy@users.noreply.github.com> Date: Thu, 19 Jun 2025 11:11:44 -0500 Subject: [PATCH 17/22] feat: add structured output support using Pydantic models (#60) * feat: add structured output support using Pydantic models - Add method to Agent class for handling structured outputs - Create structured_output.py utility for converting Pydantic models to tool specs - Improve error handling when extracting model_id from configuration - Add integration tests to validate structured output functionality * fix: import cleanups and unused vars * feat: wip adding `structured_output` methods * feat: wip added structured output to bedrock and anthropic * feat: litellm structured output and some integ tests * feat: all structured outputs working, tbd llama api * feat: updated docstring * fix: otel ci dep issue * fix: remove unnecessary changes and comments * feat: basic test WIP * feat: better test coverage * fix: remove unused fixture * fix: resolve some comments * fix: inline basemodel classes * feat: update litellm, add checks * fix: autoformatting issue * feat: resolves comments * fix: ollama skip tests, pyproject whitespace diffs --- pyproject.toml | 4 +- src/strands/agent/agent.py | 32 +- src/strands/models/anthropic.py | 51 ++- src/strands/models/bedrock.py | 47 +- src/strands/models/litellm.py | 49 ++- src/strands/models/llamaapi.py | 33 +- src/strands/models/ollama.py | 27 +- src/strands/models/openai.py | 39 +- src/strands/tools/__init__.py | 2 + src/strands/tools/structured_output.py | 415 ++++++++++++++++++ src/strands/types/models/model.py | 26 +- src/strands/types/models/openai.py | 18 +- tests-integ/test_model_anthropic.py | 14 + tests-integ/test_model_bedrock.py | 31 ++ tests-integ/test_model_litellm.py | 14 + tests-integ/test_model_ollama.py | 47 ++ tests-integ/test_model_openai.py | 20 + tests/strands/agent/test_agent.py | 26 ++ tests/strands/models/test_anthropic.py | 25 +- tests/strands/tools/test_structured_output.py | 228 ++++++++++ tests/strands/types/models/test_model.py | 17 + 21 files changed, 1147 insertions(+), 18 deletions(-) create mode 100644 src/strands/tools/structured_output.py create mode 100644 tests-integ/test_model_ollama.py create mode 100644 tests/strands/tools/test_structured_output.py diff --git a/pyproject.toml b/pyproject.toml index 56c1a40e..e0cc2578 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -65,7 +65,7 @@ docs = [ 
"sphinx-autodoc-typehints>=1.12.0,<2.0.0", ] litellm = [ - "litellm>=1.69.0,<2.0.0", + "litellm>=1.72.6,<2.0.0", ] llamaapi = [ "llama-api-client>=0.1.0,<1.0.0", @@ -264,4 +264,4 @@ style = [ ["instruction", ""], ["text", ""], ["disabled", "fg:#858585 italic"] -] +] \ No newline at end of file diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 2475b87e..a5e26a07 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -16,10 +16,11 @@ import random from concurrent.futures import ThreadPoolExecutor from threading import Thread -from typing import Any, AsyncIterator, Callable, Dict, List, Mapping, Optional, Union +from typing import Any, AsyncIterator, Callable, Dict, List, Mapping, Optional, Type, TypeVar, Union from uuid import uuid4 from opentelemetry import trace +from pydantic import BaseModel from ..event_loop.event_loop import event_loop_cycle from ..handlers.callback_handler import CompositeCallbackHandler, PrintingCallbackHandler, null_callback_handler @@ -43,6 +44,9 @@ logger = logging.getLogger(__name__) +# TypeVar for generic structured output +T = TypeVar("T", bound=BaseModel) + # Sentinel class and object to distinguish between explicit None and default parameter value class _DefaultCallbackHandlerSentinel: @@ -386,6 +390,32 @@ def __call__(self, prompt: str, **kwargs: Any) -> AgentResult: # Re-raise the exception to preserve original behavior raise + def structured_output(self, output_model: Type[T], prompt: Optional[str] = None) -> T: + """This method allows you to get structured output from the agent. + + If you pass in a prompt, it will be added to the conversation history and the agent will respond to it. + If you don't pass in a prompt, it will use only the conversation history to respond. + If no conversation history exists and no prompt is provided, an error will be raised. + + For smaller models, you may want to use the optional prompt string to add additional instructions to explicitly + instruct the model to output the structured data. + + Args: + output_model(Type[BaseModel]): The output model (a JSON schema written as a Pydantic BaseModel) + that the agent will use when responding. + prompt(Optional[str]): The prompt to use for the agent. + """ + messages = self.messages + if not messages and not prompt: + raise ValueError("No conversation history or prompt provided") + + # add the prompt as the last message + if prompt: + messages.append({"role": "user", "content": [{"text": prompt}]}) + + # get the structured output from the model + return self.model.structured_output(output_model, messages, self.callback_handler) + async def stream_async(self, prompt: str, **kwargs: Any) -> AsyncIterator[Any]: """Process a natural language prompt and yield events as an async iterator. 
diff --git a/src/strands/models/anthropic.py b/src/strands/models/anthropic.py index 57394e2c..ab427e53 100644 --- a/src/strands/models/anthropic.py +++ b/src/strands/models/anthropic.py @@ -7,11 +7,15 @@ import json import logging import mimetypes -from typing import Any, Iterable, Optional, TypedDict, cast +from typing import Any, Callable, Iterable, Optional, Type, TypedDict, TypeVar, cast import anthropic +from pydantic import BaseModel from typing_extensions import Required, Unpack, override +from ..event_loop.streaming import process_stream +from ..handlers.callback_handler import PrintingCallbackHandler +from ..tools import convert_pydantic_to_tool_spec from ..types.content import ContentBlock, Messages from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException from ..types.models import Model @@ -20,6 +24,8 @@ logger = logging.getLogger(__name__) +T = TypeVar("T", bound=BaseModel) + class AnthropicModel(Model): """Anthropic model provider implementation.""" @@ -356,10 +362,10 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]: with self.client.messages.stream(**request) as stream: for event in stream: if event.type in AnthropicModel.EVENT_TYPES: - yield event.dict() + yield event.model_dump() usage = event.message.usage # type: ignore - yield {"type": "metadata", "usage": usage.dict()} + yield {"type": "metadata", "usage": usage.model_dump()} except anthropic.RateLimitError as error: raise ModelThrottledException(str(error)) from error @@ -369,3 +375,42 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]: raise ContextWindowOverflowException(str(error)) from error raise error + + @override + def structured_output( + self, output_model: Type[T], prompt: Messages, callback_handler: Optional[Callable] = None + ) -> T: + """Get structured output from the model. + + Args: + output_model(Type[BaseModel]): The output model to use for the agent. + prompt(Messages): The prompt messages to use for the agent. + callback_handler(Optional[Callable]): Optional callback handler for processing events. Defaults to None. + """ + tool_spec = convert_pydantic_to_tool_spec(output_model) + + response = self.converse(messages=prompt, tool_specs=[tool_spec]) + # process the stream and get the tool use input + results = process_stream( + response, callback_handler=callback_handler or PrintingCallbackHandler(), messages=prompt + ) + + stop_reason, messages, _, _, _ = results + + if stop_reason != "tool_use": + raise ValueError("No valid tool use or tool use input was found in the Anthropic response.") + + content = messages["content"] + output_response: dict[str, Any] | None = None + for block in content: + # if the tool use name doesn't match the tool spec name, skip, and if the block is not a tool use, skip. + # if the tool use name never matches, raise an error. 
+ if block.get("toolUse") and block["toolUse"]["name"] == tool_spec["name"]: + output_response = block["toolUse"]["input"] + else: + continue + + if output_response is None: + raise ValueError("No valid tool use or tool use input was found in the Anthropic response.") + + return output_model(**output_response) diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index 9bbcca7d..ac1c4a38 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -6,13 +6,17 @@ import json import logging import os -from typing import Any, Iterable, List, Literal, Optional, cast +from typing import Any, Callable, Iterable, List, Literal, Optional, Type, TypeVar, cast import boto3 from botocore.config import Config as BotocoreConfig from botocore.exceptions import ClientError +from pydantic import BaseModel from typing_extensions import TypedDict, Unpack, override +from ..event_loop.streaming import process_stream +from ..handlers.callback_handler import PrintingCallbackHandler +from ..tools import convert_pydantic_to_tool_spec from ..types.content import Messages from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException from ..types.models import Model @@ -29,6 +33,8 @@ "too many total text bytes", ] +T = TypeVar("T", bound=BaseModel) + class BedrockModel(Model): """AWS Bedrock model provider implementation. @@ -477,3 +483,42 @@ def _find_detected_and_blocked_policy(self, input: Any) -> bool: return self._find_detected_and_blocked_policy(item) # Otherwise return False return False + + @override + def structured_output( + self, output_model: Type[T], prompt: Messages, callback_handler: Optional[Callable] = None + ) -> T: + """Get structured output from the model. + + Args: + output_model(Type[BaseModel]): The output model to use for the agent. + prompt(Messages): The prompt messages to use for the agent. + callback_handler(Optional[Callable]): Optional callback handler for processing events. Defaults to None. + """ + tool_spec = convert_pydantic_to_tool_spec(output_model) + + response = self.converse(messages=prompt, tool_specs=[tool_spec]) + # process the stream and get the tool use input + results = process_stream( + response, callback_handler=callback_handler or PrintingCallbackHandler(), messages=prompt + ) + + stop_reason, messages, _, _, _ = results + + if stop_reason != "tool_use": + raise ValueError("No valid tool use or tool use input was found in the Bedrock response.") + + content = messages["content"] + output_response: dict[str, Any] | None = None + for block in content: + # if the tool use name doesn't match the tool spec name, skip, and if the block is not a tool use, skip. + # if the tool use name never matches, raise an error. 
+ if block.get("toolUse") and block["toolUse"]["name"] == tool_spec["name"]: + output_response = block["toolUse"]["input"] + else: + continue + + if output_response is None: + raise ValueError("No valid tool use or tool use input was found in the Bedrock response.") + + return output_model(**output_response) diff --git a/src/strands/models/litellm.py b/src/strands/models/litellm.py index 62f16d31..66138186 100644 --- a/src/strands/models/litellm.py +++ b/src/strands/models/litellm.py @@ -3,17 +3,22 @@ - Docs: https://docs.litellm.ai/ """ +import json import logging -from typing import Any, Optional, TypedDict, cast +from typing import Any, Callable, Optional, Type, TypedDict, TypeVar, cast import litellm +from litellm.utils import supports_response_schema +from pydantic import BaseModel from typing_extensions import Unpack, override -from ..types.content import ContentBlock +from ..types.content import ContentBlock, Messages from .openai import OpenAIModel logger = logging.getLogger(__name__) +T = TypeVar("T", bound=BaseModel) + class LiteLLMModel(OpenAIModel): """LiteLLM model provider implementation.""" @@ -97,3 +102,43 @@ def format_request_message_content(cls, content: ContentBlock) -> dict[str, Any] } return super().format_request_message_content(content) + + @override + def structured_output( + self, output_model: Type[T], prompt: Messages, callback_handler: Optional[Callable] = None + ) -> T: + """Get structured output from the model. + + Args: + output_model(Type[BaseModel]): The output model to use for the agent. + prompt(Messages): The prompt messages to use for the agent. + callback_handler(Optional[Callable]): Optional callback handler for processing events. Defaults to None. + + """ + # The LiteLLM `Client` inits with Chat(). + # Chat() inits with self.completions + # completions() has a method `create()` which wraps the real completion API of Litellm + response = self.client.chat.completions.create( + model=self.get_config()["model_id"], + messages=super().format_request(prompt)["messages"], + response_format=output_model, + ) + + if not supports_response_schema(self.get_config()["model_id"]): + raise ValueError("Model does not support response_format") + if len(response.choices) > 1: + raise ValueError("Multiple choices found in the response.") + + # Find the first choice with tool_calls + for choice in response.choices: + if choice.finish_reason == "tool_calls": + try: + # Parse the tool call content as JSON + tool_call_data = json.loads(choice.message.content) + # Instantiate the output model with the parsed data + return output_model(**tool_call_data) + except (json.JSONDecodeError, TypeError, ValueError) as e: + raise ValueError(f"Failed to parse or load content into model: {e}") from e + + # If no tool_calls found, raise an error + raise ValueError("No tool_calls found in response") diff --git a/src/strands/models/llamaapi.py b/src/strands/models/llamaapi.py index 583db2f2..755e07ad 100644 --- a/src/strands/models/llamaapi.py +++ b/src/strands/models/llamaapi.py @@ -8,10 +8,11 @@ import json import logging import mimetypes -from typing import Any, Iterable, Optional, cast +from typing import Any, Callable, Iterable, Optional, Type, TypeVar, cast import llama_api_client from llama_api_client import LlamaAPIClient +from pydantic import BaseModel from typing_extensions import TypedDict, Unpack, override from ..types.content import ContentBlock, Messages @@ -22,6 +23,8 @@ logger = logging.getLogger(__name__) +T = TypeVar("T", bound=BaseModel) + class LlamaAPIModel(Model): 
"""Llama API model provider implementation.""" @@ -384,3 +387,31 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]: # we may have a metrics event here if metrics_event: yield {"chunk_type": "metadata", "data": metrics_event} + + @override + def structured_output( + self, output_model: Type[T], prompt: Messages, callback_handler: Optional[Callable] = None + ) -> T: + """Get structured output from the model. + + Args: + output_model(Type[BaseModel]): The output model to use for the agent. + prompt(Messages): The prompt messages to use for the agent. + callback_handler(Optional[Callable]): Optional callback handler for processing events. Defaults to None. + + Raises: + NotImplementedError: Structured output is not currently supported for LlamaAPI models. + """ + # response_format: ResponseFormat = { + # "type": "json_schema", + # "json_schema": { + # "name": output_model.__name__, + # "schema": output_model.model_json_schema(), + # }, + # } + # response = self.client.chat.completions.create( + # model=self.config["model_id"], + # messages=self.format_request(prompt)["messages"], + # response_format=response_format, + # ) + raise NotImplementedError("Strands sdk-python does not implement this in the Llama API Preview.") diff --git a/src/strands/models/ollama.py b/src/strands/models/ollama.py index 7ed12216..b062fe14 100644 --- a/src/strands/models/ollama.py +++ b/src/strands/models/ollama.py @@ -5,9 +5,10 @@ import json import logging -from typing import Any, Iterable, Optional, cast +from typing import Any, Callable, Iterable, Optional, Type, TypeVar, cast from ollama import Client as OllamaClient +from pydantic import BaseModel from typing_extensions import TypedDict, Unpack, override from ..types.content import ContentBlock, Messages @@ -17,6 +18,8 @@ logger = logging.getLogger(__name__) +T = TypeVar("T", bound=BaseModel) + class OllamaModel(Model): """Ollama model provider implementation. @@ -310,3 +313,25 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]: yield {"chunk_type": "content_stop", "data_type": "text"} yield {"chunk_type": "message_stop", "data": "tool_use" if tool_requested else event.done_reason} yield {"chunk_type": "metadata", "data": event} + + @override + def structured_output( + self, output_model: Type[T], prompt: Messages, callback_handler: Optional[Callable] = None + ) -> T: + """Get structured output from the model. + + Args: + output_model(Type[BaseModel]): The output model to use for the agent. + prompt(Messages): The prompt messages to use for the agent. + callback_handler(Optional[Callable]): Optional callback handler for processing events. Defaults to None. 
+ """ + formatted_request = self.format_request(messages=prompt) + formatted_request["format"] = output_model.model_json_schema() + formatted_request["stream"] = False + response = self.client.chat(**formatted_request) + + try: + content = response.message.content.strip() + return output_model.model_validate_json(content) + except Exception as e: + raise ValueError(f"Failed to parse or load content into model: {e}") from e diff --git a/src/strands/models/openai.py b/src/strands/models/openai.py index 6cbef664..783ce379 100644 --- a/src/strands/models/openai.py +++ b/src/strands/models/openai.py @@ -4,15 +4,20 @@ """ import logging -from typing import Any, Iterable, Optional, Protocol, TypedDict, cast +from typing import Any, Callable, Iterable, Optional, Protocol, Type, TypedDict, TypeVar, cast import openai +from openai.types.chat.parsed_chat_completion import ParsedChatCompletion +from pydantic import BaseModel from typing_extensions import Unpack, override +from ..types.content import Messages from ..types.models import OpenAIModel as SAOpenAIModel logger = logging.getLogger(__name__) +T = TypeVar("T", bound=BaseModel) + class Client(Protocol): """Protocol defining the OpenAI-compatible interface for the underlying provider client.""" @@ -125,3 +130,35 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]: _ = event yield {"chunk_type": "metadata", "data": event.usage} + + @override + def structured_output( + self, output_model: Type[T], prompt: Messages, callback_handler: Optional[Callable] = None + ) -> T: + """Get structured output from the model. + + Args: + output_model(Type[BaseModel]): The output model to use for the agent. + prompt(Messages): The prompt messages to use for the agent. + callback_handler(Optional[Callable]): Optional callback handler for processing events. Defaults to None. 
+ """ + response: ParsedChatCompletion = self.client.beta.chat.completions.parse( # type: ignore + model=self.get_config()["model_id"], + messages=super().format_request(prompt)["messages"], + response_format=output_model, + ) + + parsed: T | None = None + # Find the first choice with tool_calls + if len(response.choices) > 1: + raise ValueError("Multiple choices found in the OpenAI response.") + + for choice in response.choices: + if isinstance(choice.message.parsed, output_model): + parsed = choice.message.parsed + break + + if parsed: + return parsed + else: + raise ValueError("No valid tool use or tool use input was found in the OpenAI response.") diff --git a/src/strands/tools/__init__.py b/src/strands/tools/__init__.py index b3ee1566..12979015 100644 --- a/src/strands/tools/__init__.py +++ b/src/strands/tools/__init__.py @@ -4,6 +4,7 @@ """ from .decorator import tool +from .structured_output import convert_pydantic_to_tool_spec from .thread_pool_executor import ThreadPoolExecutorWrapper from .tools import FunctionTool, InvalidToolUseNameException, PythonAgentTool, normalize_schema, normalize_tool_spec @@ -15,4 +16,5 @@ "normalize_schema", "normalize_tool_spec", "ThreadPoolExecutorWrapper", + "convert_pydantic_to_tool_spec", ] diff --git a/src/strands/tools/structured_output.py b/src/strands/tools/structured_output.py new file mode 100644 index 00000000..5421cdc6 --- /dev/null +++ b/src/strands/tools/structured_output.py @@ -0,0 +1,415 @@ +"""Tools for converting Pydantic models to Bedrock tools.""" + +from typing import Any, Dict, Optional, Type, Union + +from pydantic import BaseModel + +from ..types.tools import ToolSpec + + +def _flatten_schema(schema: Dict[str, Any]) -> Dict[str, Any]: + """Flattens a JSON schema by removing $defs and resolving $ref references. + + Handles required vs optional fields properly. 
+ + Args: + schema: The JSON schema to flatten + + Returns: + Flattened JSON schema + """ + # Extract required fields list + required_fields = schema.get("required", []) + + # Initialize the flattened schema with basic properties + flattened = { + "type": schema.get("type", "object"), + "properties": {}, + } + + # Add title if present + if "title" in schema: + flattened["title"] = schema["title"] + + # Add description from schema if present, or use model docstring + if "description" in schema and schema["description"]: + flattened["description"] = schema["description"] + + # Process properties + required_props: list[str] = [] + if "properties" in schema: + required_props = [] + for prop_name, prop_value in schema["properties"].items(): + # Process the property and add to flattened properties + is_required = prop_name in required_fields + + # If the property already has nested properties (expanded), preserve them + if "properties" in prop_value: + # This is an expanded nested schema, preserve its structure + processed_prop = { + "type": prop_value.get("type", "object"), + "description": prop_value.get("description", ""), + "properties": {}, + } + + # Process each nested property + for nested_prop_name, nested_prop_value in prop_value["properties"].items(): + processed_prop["properties"][nested_prop_name] = nested_prop_value + + # Copy required fields if present + if "required" in prop_value: + processed_prop["required"] = prop_value["required"] + else: + # Process as normal + processed_prop = _process_property(prop_value, schema.get("$defs", {}), is_required) + + flattened["properties"][prop_name] = processed_prop + + # Track which properties are actually required after processing + if is_required and "null" not in str(processed_prop.get("type", "")): + required_props.append(prop_name) + + # Add required fields if any (only those that are truly required after processing) + # Check if required props are empty, if so, raise an error because it means there is a circular reference + + if len(required_props) > 0: + flattened["required"] = required_props + else: + raise ValueError("Circular reference detected and not supported") + + return flattened + + +def _process_property( + prop: Dict[str, Any], + defs: Dict[str, Any], + is_required: bool = False, + fully_expand: bool = True, +) -> Dict[str, Any]: + """Process a property in a schema, resolving any references. + + Args: + prop: The property to process + defs: The definitions dictionary for resolving references + is_required: Whether this property is required + fully_expand: Whether to fully expand nested properties + + Returns: + Processed property + """ + result = {} + is_nullable = False + + # Handle anyOf for optional fields (like Optional[Type]) + if "anyOf" in prop: + # Check if this is an Optional[...] 
case (one null, one type) + null_type = False + non_null_type = None + + for option in prop["anyOf"]: + if option.get("type") == "null": + null_type = True + is_nullable = True + elif "$ref" in option: + ref_path = option["$ref"].split("/")[-1] + if ref_path in defs: + non_null_type = _process_schema_object(defs[ref_path], defs, fully_expand) + else: + # Handle missing reference path gracefully + raise ValueError(f"Missing reference: {ref_path}") + else: + non_null_type = option + + if null_type and non_null_type: + # For Optional fields, we mark as nullable but copy all properties from the non-null option + result = non_null_type.copy() if isinstance(non_null_type, dict) else {} + + # For type, ensure it includes "null" + if "type" in result and isinstance(result["type"], str): + result["type"] = [result["type"], "null"] + elif "type" in result and isinstance(result["type"], list) and "null" not in result["type"]: + result["type"].append("null") + elif "type" not in result: + # Default to object type if not specified + result["type"] = ["object", "null"] + + # Copy description if available in the property + if "description" in prop: + result["description"] = prop["description"] + + return result + + # Handle direct references + elif "$ref" in prop: + # Resolve reference + ref_path = prop["$ref"].split("/")[-1] + if ref_path in defs: + ref_dict = defs[ref_path] + # Process the referenced object to get a complete schema + result = _process_schema_object(ref_dict, defs, fully_expand) + else: + # Handle missing reference path gracefully + raise ValueError(f"Missing reference: {ref_path}") + + # For regular fields, copy all properties + for key, value in prop.items(): + if key not in ["$ref", "anyOf"]: + if isinstance(value, dict): + result[key] = _process_nested_dict(value, defs) + elif key == "type" and not is_required and not is_nullable: + # For non-required fields, ensure type is a list with "null" + if isinstance(value, str): + result[key] = [value, "null"] + elif isinstance(value, list) and "null" not in value: + result[key] = value + ["null"] + else: + result[key] = value + else: + result[key] = value + + return result + + +def _process_schema_object( + schema_obj: Dict[str, Any], defs: Dict[str, Any], fully_expand: bool = True +) -> Dict[str, Any]: + """Process a schema object, typically from $defs, to resolve all nested properties. 
+ + Args: + schema_obj: The schema object to process + defs: The definitions dictionary for resolving references + fully_expand: Whether to fully expand nested properties + + Returns: + Processed schema object with all properties resolved + """ + result = {} + + # Copy basic attributes + for key, value in schema_obj.items(): + if key != "properties" and key != "required" and key != "$defs": + result[key] = value + + # Process properties if present + if "properties" in schema_obj: + result["properties"] = {} + required_props = [] + + # Get required fields list + required_fields = schema_obj.get("required", []) + + for prop_name, prop_value in schema_obj["properties"].items(): + # Process each property + is_required = prop_name in required_fields + processed = _process_property(prop_value, defs, is_required, fully_expand) + result["properties"][prop_name] = processed + + # Track which properties are actually required after processing + if is_required and "null" not in str(processed.get("type", "")): + required_props.append(prop_name) + + # Add required fields if any + if required_props: + result["required"] = required_props + + return result + + +def _process_nested_dict(d: Dict[str, Any], defs: Dict[str, Any]) -> Dict[str, Any]: + """Recursively processes nested dictionaries and resolves $ref references. + + Args: + d: The dictionary to process + defs: The definitions dictionary for resolving references + + Returns: + Processed dictionary + """ + result: Dict[str, Any] = {} + + # Handle direct reference + if "$ref" in d: + ref_path = d["$ref"].split("/")[-1] + if ref_path in defs: + ref_dict = defs[ref_path] + # Recursively process the referenced object + return _process_schema_object(ref_dict, defs) + else: + # Handle missing reference path gracefully + raise ValueError(f"Missing reference: {ref_path}") + + # Process each key-value pair + for key, value in d.items(): + if key == "$ref": + # Already handled above + continue + elif isinstance(value, dict): + result[key] = _process_nested_dict(value, defs) + elif isinstance(value, list): + # Process lists (like for enum values) + result[key] = [_process_nested_dict(item, defs) if isinstance(item, dict) else item for item in value] + else: + result[key] = value + + return result + + +def convert_pydantic_to_tool_spec( + model: Type[BaseModel], + description: Optional[str] = None, +) -> ToolSpec: + """Converts a Pydantic model to a tool description for the Amazon Bedrock Converse API. + + Handles optional vs. required fields, resolves $refs, and uses docstrings. 
+ + Args: + model: The Pydantic model class to convert + description: Optional description of the tool's purpose + + Returns: + ToolSpec: Dict containing the Bedrock tool specification + """ + name = model.__name__ + + # Get the JSON schema + input_schema = model.model_json_schema() + + # Get model docstring for description if not provided + model_description = description + if not model_description and model.__doc__: + model_description = model.__doc__.strip() + + # Process all referenced models to ensure proper docstrings + # This step is important for gathering descriptions from referenced models + _process_referenced_models(input_schema, model) + + # Now, let's fully expand the nested models with all their properties + _expand_nested_properties(input_schema, model) + + # Flatten the schema + flattened_schema = _flatten_schema(input_schema) + + final_schema = flattened_schema + + # Construct the tool specification + return ToolSpec( + name=name, + description=model_description or f"{name} structured output tool", + inputSchema={"json": final_schema}, + ) + + +def _expand_nested_properties(schema: Dict[str, Any], model: Type[BaseModel]) -> None: + """Expand the properties of nested models in the schema to include their full structure. + + This updates the schema in place. + + Args: + schema: The JSON schema to process + model: The Pydantic model class + """ + # First, process the properties at this level + if "properties" not in schema: + return + + # Create a modified copy of the properties to avoid modifying while iterating + for prop_name, prop_info in list(schema["properties"].items()): + field = model.model_fields.get(prop_name) + if not field: + continue + + field_type = field.annotation + + # Handle Optional types + is_optional = False + if ( + field_type is not None + and hasattr(field_type, "__origin__") + and field_type.__origin__ is Union + and hasattr(field_type, "__args__") + ): + # Look for Optional[BaseModel] + for arg in field_type.__args__: + if arg is type(None): + is_optional = True + elif isinstance(arg, type) and issubclass(arg, BaseModel): + field_type = arg + + # If this is a BaseModel field, expand its properties with full details + if isinstance(field_type, type) and issubclass(field_type, BaseModel): + # Get the nested model's schema with all its properties + nested_model_schema = field_type.model_json_schema() + + # Create a properly expanded nested object + expanded_object = { + "type": ["object", "null"] if is_optional else "object", + "description": prop_info.get("description", field.description or f"The {prop_name}"), + "properties": {}, + } + + # Copy all properties from the nested schema + if "properties" in nested_model_schema: + expanded_object["properties"] = nested_model_schema["properties"] + + # Copy required fields + if "required" in nested_model_schema: + expanded_object["required"] = nested_model_schema["required"] + + # Replace the original property with this expanded version + schema["properties"][prop_name] = expanded_object + + +def _process_referenced_models(schema: Dict[str, Any], model: Type[BaseModel]) -> None: + """Process referenced models to ensure their docstrings are included. + + This updates the schema in place. 
+ + Args: + schema: The JSON schema to process + model: The Pydantic model class + """ + # Process $defs to add docstrings from the referenced models + if "$defs" in schema: + # Look through model fields to find referenced models + for _, field in model.model_fields.items(): + field_type = field.annotation + + # Handle Optional types - with null checks + if field_type is not None and hasattr(field_type, "__origin__"): + origin = field_type.__origin__ + if origin is Union and hasattr(field_type, "__args__"): + # Find the non-None type in the Union (for Optional fields) + for arg in field_type.__args__: + if arg is not type(None): + field_type = arg + break + + # Check if this is a BaseModel subclass + if isinstance(field_type, type) and issubclass(field_type, BaseModel): + # Update $defs with this model's information + ref_name = field_type.__name__ + if ref_name in schema.get("$defs", {}): + ref_def = schema["$defs"][ref_name] + + # Add docstring as description if available + if field_type.__doc__ and not ref_def.get("description"): + ref_def["description"] = field_type.__doc__.strip() + + # Recursively process properties in the referenced model + _process_properties(ref_def, field_type) + + +def _process_properties(schema_def: Dict[str, Any], model: Type[BaseModel]) -> None: + """Process properties in a schema definition to add descriptions from field metadata. + + Args: + schema_def: The schema definition to update + model: The model class that defines the schema + """ + if "properties" in schema_def: + for prop_name, prop_info in schema_def["properties"].items(): + field = model.model_fields.get(prop_name) + + # Add field description if available and not already set + if field and field.description and not prop_info.get("description"): + prop_info["description"] = field.description diff --git a/src/strands/types/models/model.py b/src/strands/types/models/model.py index 23e74602..071c8a51 100644 --- a/src/strands/types/models/model.py +++ b/src/strands/types/models/model.py @@ -2,7 +2,9 @@ import abc import logging -from typing import Any, Iterable, Optional +from typing import Any, Callable, Iterable, Optional, Type, TypeVar + +from pydantic import BaseModel from ..content import Messages from ..streaming import StreamEvent @@ -10,6 +12,8 @@ logger = logging.getLogger(__name__) +T = TypeVar("T", bound=BaseModel) + class Model(abc.ABC): """Abstract base class for AI model implementations. @@ -38,6 +42,26 @@ def get_config(self) -> Any: """ pass + @abc.abstractmethod + # pragma: no cover + def structured_output( + self, output_model: Type[T], prompt: Messages, callback_handler: Optional[Callable] = None + ) -> T: + """Get structured output from the model. + + Args: + output_model(Type[BaseModel]): The output model to use for the agent. + prompt(Messages): The prompt messages to use for the agent. + callback_handler(Optional[Callable]): Optional callback handler for processing events. Defaults to None. + + Returns: + The structured output as a serialized instance of the output model. 
+ + Raises: + ValidationException: The response format from the model does not match the output_model + """ + pass + @abc.abstractmethod # pragma: no cover def format_request( diff --git a/src/strands/types/models/openai.py b/src/strands/types/models/openai.py index 0df1dda5..cbd2cfd2 100644 --- a/src/strands/types/models/openai.py +++ b/src/strands/types/models/openai.py @@ -11,8 +11,9 @@ import json import logging import mimetypes -from typing import Any, Optional, cast +from typing import Any, Callable, Optional, Type, TypeVar, cast +from pydantic import BaseModel from typing_extensions import override from ..content import ContentBlock, Messages @@ -22,6 +23,8 @@ logger = logging.getLogger(__name__) +T = TypeVar("T", bound=BaseModel) + class OpenAIModel(Model, abc.ABC): """Base OpenAI model provider implementation. @@ -272,3 +275,16 @@ def format_chunk(self, event: dict[str, Any]) -> StreamEvent: case _: raise RuntimeError(f"chunk_type=<{event['chunk_type']} | unknown type") + + @override + def structured_output( + self, output_model: Type[T], prompt: Messages, callback_handler: Optional[Callable] = None + ) -> T: + """Get structured output from the model. + + Args: + output_model(Type[BaseModel]): The output model to use for the agent. + prompt(Messages): The prompt to use for the agent. + callback_handler(Optional[Callable]): Optional callback handler for processing events. Defaults to None. + """ + return output_model() diff --git a/tests-integ/test_model_anthropic.py b/tests-integ/test_model_anthropic.py index 1b0412c9..95bfceb5 100644 --- a/tests-integ/test_model_anthropic.py +++ b/tests-integ/test_model_anthropic.py @@ -1,6 +1,7 @@ import os import pytest +from pydantic import BaseModel import strands from strands import Agent @@ -47,3 +48,16 @@ def test_agent(agent): text = result.message["content"][0]["text"].lower() assert all(string in text for string in ["12:00", "sunny", "&"]) + + +@pytest.mark.skipif("ANTHROPIC_API_KEY" not in os.environ, reason="ANTHROPIC_API_KEY environment variable missing") +def test_structured_output(model): + class Weather(BaseModel): + time: str + weather: str + + agent = Agent(model=model) + result = agent.structured_output(Weather, "The time is 12:00 and the weather is sunny") + assert isinstance(result, Weather) + assert result.time == "12:00" + assert result.weather == "sunny" diff --git a/tests-integ/test_model_bedrock.py b/tests-integ/test_model_bedrock.py index a6a29aa9..5378a9b2 100644 --- a/tests-integ/test_model_bedrock.py +++ b/tests-integ/test_model_bedrock.py @@ -1,4 +1,5 @@ import pytest +from pydantic import BaseModel import strands from strands import Agent @@ -118,3 +119,33 @@ def calculator(expression: str) -> float: agent("What is 123 + 456?") assert tool_was_called + + +def test_structured_output_streaming(streaming_model): + """Test structured output with streaming model.""" + + class Weather(BaseModel): + time: str + weather: str + + agent = Agent(model=streaming_model) + + result = agent.structured_output(Weather, "The time is 12:00 and the weather is sunny") + assert isinstance(result, Weather) + assert result.time == "12:00" + assert result.weather == "sunny" + + +def test_structured_output_non_streaming(non_streaming_model): + """Test structured output with non-streaming model.""" + + class Weather(BaseModel): + time: str + weather: str + + agent = Agent(model=non_streaming_model) + + result = agent.structured_output(Weather, "The time is 12:00 and the weather is sunny") + assert isinstance(result, Weather) + assert 
result.time == "12:00" + assert result.weather == "sunny" diff --git a/tests-integ/test_model_litellm.py b/tests-integ/test_model_litellm.py index 86f6b42f..01a3e121 100644 --- a/tests-integ/test_model_litellm.py +++ b/tests-integ/test_model_litellm.py @@ -1,4 +1,5 @@ import pytest +from pydantic import BaseModel import strands from strands import Agent @@ -33,3 +34,16 @@ def test_agent(agent): text = result.message["content"][0]["text"].lower() assert all(string in text for string in ["12:00", "sunny"]) + + +def test_structured_output(model): + class Weather(BaseModel): + time: str + weather: str + + agent_no_tools = Agent(model=model) + + result = agent_no_tools.structured_output(Weather, "The time is 12:00 and the weather is sunny") + assert isinstance(result, Weather) + assert result.time == "12:00" + assert result.weather == "sunny" diff --git a/tests-integ/test_model_ollama.py b/tests-integ/test_model_ollama.py new file mode 100644 index 00000000..38b46821 --- /dev/null +++ b/tests-integ/test_model_ollama.py @@ -0,0 +1,47 @@ +import pytest +import requests +from pydantic import BaseModel + +from strands import Agent +from strands.models.ollama import OllamaModel + + +def is_server_available() -> bool: + try: + return requests.get("http://localhost:11434").ok + except requests.exceptions.ConnectionError: + return False + + +@pytest.fixture +def model(): + return OllamaModel(host="http://localhost:11434", model_id="llama3.3:70b") + + +@pytest.fixture +def agent(model): + return Agent(model=model) + + +@pytest.mark.skipif(not is_server_available(), reason="Local Ollama endpoint not available at localhost:11434") +def test_agent(agent): + result = agent("Say 'hello world' with no other text") + assert isinstance(result.message["content"][0]["text"], str) + + +@pytest.mark.skipif(not is_server_available(), reason="Local Ollama endpoint not available at localhost:11434") +def test_structured_output(agent): + class Weather(BaseModel): + """Extract the time and weather. + + Time format: HH:MM + Weather: sunny, cloudy, rainy, etc. 
+ """ + + time: str + weather: str + + result = agent.structured_output(Weather, "The time is 12:00 and the weather is sunny") + assert isinstance(result, Weather) + assert result.time == "12:00" + assert result.weather == "sunny" diff --git a/tests-integ/test_model_openai.py b/tests-integ/test_model_openai.py index c9046ad5..b0790ba0 100644 --- a/tests-integ/test_model_openai.py +++ b/tests-integ/test_model_openai.py @@ -1,6 +1,7 @@ import os import pytest +from pydantic import BaseModel import strands from strands import Agent @@ -44,3 +45,22 @@ def test_agent(agent): text = result.message["content"][0]["text"].lower() assert all(string in text for string in ["12:00", "sunny"]) + + +@pytest.mark.skipif( + "OPENAI_API_KEY" not in os.environ, + reason="OPENAI_API_KEY environment variable missing", +) +def test_structured_output(model): + class Weather(BaseModel): + """Extracts the time and weather from the user's message with the exact strings.""" + + time: str + weather: str + + agent = Agent(model=model) + + result = agent.structured_output(Weather, "The time is 12:00 and the weather is sunny") + assert isinstance(result, Weather) + assert result.time == "12:00" + assert result.weather == "sunny" diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index d6f47be0..85d17544 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -7,6 +7,7 @@ from time import sleep import pytest +from pydantic import BaseModel import strands from strands import Agent @@ -793,6 +794,31 @@ def test_agent_callback_handler_custom_handler_used(): assert agent.callback_handler is custom_handler +# mock the User(name='Jane Doe', age=30, email='jane@doe.com') +class User(BaseModel): + """A user of the system.""" + + name: str + age: int + email: str + + +def test_agent_method_structured_output(agent): + # Mock the structured_output method on the model + expected_user = User(name="Jane Doe", age=30, email="jane@doe.com") + agent.model.structured_output = unittest.mock.Mock(return_value=expected_user) + + prompt = "Jane Doe is 30 years old and her email is jane@doe.com" + + result = agent.structured_output(User, prompt) + assert result == expected_user + + # Verify the model's structured_output was called with correct arguments + agent.model.structured_output.assert_called_once_with( + User, [{"role": "user", "content": [{"text": prompt}]}], agent.callback_handler + ) + + @pytest.mark.asyncio async def test_stream_async_returns_all_events(mock_event_loop_cycle): agent = Agent() diff --git a/tests/strands/models/test_anthropic.py b/tests/strands/models/test_anthropic.py index 9421650e..a0cfc4d4 100644 --- a/tests/strands/models/test_anthropic.py +++ b/tests/strands/models/test_anthropic.py @@ -615,10 +615,24 @@ def test_format_chunk_unknown(model): def test_stream(anthropic_client, model): - mock_event_1 = unittest.mock.Mock(type="message_start", dict=lambda: {"type": "message_start"}) - mock_event_2 = unittest.mock.Mock(type="unknown") + mock_event_1 = unittest.mock.Mock( + type="message_start", + dict=lambda: {"type": "message_start"}, + model_dump=lambda: {"type": "message_start"}, + ) + mock_event_2 = unittest.mock.Mock( + type="unknown", + dict=lambda: {"type": "unknown"}, + model_dump=lambda: {"type": "unknown"}, + ) mock_event_3 = unittest.mock.Mock( - type="metadata", message=unittest.mock.Mock(usage=unittest.mock.Mock(dict=lambda: {"input_tokens": 1})) + type="metadata", + message=unittest.mock.Mock( + usage=unittest.mock.Mock( + dict=lambda: 
{"input_tokens": 1, "output_tokens": 2}, + model_dump=lambda: {"input_tokens": 1, "output_tokens": 2}, + ) + ), ) mock_stream = unittest.mock.MagicMock() @@ -631,7 +645,10 @@ def test_stream(anthropic_client, model): tru_events = list(response) exp_events = [ {"type": "message_start"}, - {"type": "metadata", "usage": {"input_tokens": 1}}, + { + "type": "metadata", + "usage": {"input_tokens": 1, "output_tokens": 2}, + }, ] assert tru_events == exp_events diff --git a/tests/strands/tools/test_structured_output.py b/tests/strands/tools/test_structured_output.py new file mode 100644 index 00000000..2e354b83 --- /dev/null +++ b/tests/strands/tools/test_structured_output.py @@ -0,0 +1,228 @@ +from typing import Literal, Optional + +import pytest +from pydantic import BaseModel, Field + +from strands.tools.structured_output import convert_pydantic_to_tool_spec +from strands.types.tools import ToolSpec + + +# Basic test model +class User(BaseModel): + """User model with name and age.""" + + name: str = Field(description="The name of the user") + age: int = Field(description="The age of the user", ge=18, le=100) + + +# Test model with inheritance and literals +class UserWithPlanet(User): + """User with planet.""" + + planet: Literal["Earth", "Mars"] = Field(description="The planet") + + +# Test model with multiple same type fields and optional field +class TwoUsersWithPlanet(BaseModel): + """Two users model with planet.""" + + user1: UserWithPlanet = Field(description="The first user") + user2: Optional[UserWithPlanet] = Field(description="The second user", default=None) + + +# Test model with list of same type fields +class ListOfUsersWithPlanet(BaseModel): + """List of users model with planet.""" + + users: list[UserWithPlanet] = Field(description="The users", min_length=2, max_length=3) + + +def test_convert_pydantic_to_tool_spec_basic(): + tool_spec = convert_pydantic_to_tool_spec(User) + + expected_spec = { + "name": "User", + "description": "User model with name and age.", + "inputSchema": { + "json": { + "type": "object", + "properties": { + "name": {"description": "The name of the user", "title": "Name", "type": "string"}, + "age": { + "description": "The age of the user", + "maximum": 100, + "minimum": 18, + "title": "Age", + "type": "integer", + }, + }, + "title": "User", + "description": "User model with name and age.", + "required": ["name", "age"], + } + }, + } + + # Verify we can construct a valid ToolSpec + tool_spec_obj = ToolSpec(**tool_spec) + assert tool_spec_obj is not None + assert tool_spec == expected_spec + + +def test_convert_pydantic_to_tool_spec_complex(): + tool_spec = convert_pydantic_to_tool_spec(ListOfUsersWithPlanet) + + expected_spec = { + "name": "ListOfUsersWithPlanet", + "description": "List of users model with planet.", + "inputSchema": { + "json": { + "type": "object", + "properties": { + "users": { + "description": "The users", + "items": { + "description": "User with planet.", + "title": "UserWithPlanet", + "type": "object", + "properties": { + "name": {"description": "The name of the user", "title": "Name", "type": "string"}, + "age": { + "description": "The age of the user", + "maximum": 100, + "minimum": 18, + "title": "Age", + "type": "integer", + }, + "planet": { + "description": "The planet", + "enum": ["Earth", "Mars"], + "title": "Planet", + "type": "string", + }, + }, + "required": ["name", "age", "planet"], + }, + "maxItems": 3, + "minItems": 2, + "title": "Users", + "type": "array", + } + }, + "title": "ListOfUsersWithPlanet", + "description": 
"List of users model with planet.", + "required": ["users"], + } + }, + } + + assert tool_spec == expected_spec + + # Verify we can construct a valid ToolSpec + tool_spec_obj = ToolSpec(**tool_spec) + assert tool_spec_obj is not None + + +def test_convert_pydantic_to_tool_spec_multiple_same_type(): + tool_spec = convert_pydantic_to_tool_spec(TwoUsersWithPlanet) + + expected_spec = { + "name": "TwoUsersWithPlanet", + "description": "Two users model with planet.", + "inputSchema": { + "json": { + "type": "object", + "properties": { + "user1": { + "type": "object", + "description": "The first user", + "properties": { + "name": {"description": "The name of the user", "title": "Name", "type": "string"}, + "age": { + "description": "The age of the user", + "maximum": 100, + "minimum": 18, + "title": "Age", + "type": "integer", + }, + "planet": { + "description": "The planet", + "enum": ["Earth", "Mars"], + "title": "Planet", + "type": "string", + }, + }, + "required": ["name", "age", "planet"], + }, + "user2": { + "type": ["object", "null"], + "description": "The second user", + "properties": { + "name": {"description": "The name of the user", "title": "Name", "type": "string"}, + "age": { + "description": "The age of the user", + "maximum": 100, + "minimum": 18, + "title": "Age", + "type": "integer", + }, + "planet": { + "description": "The planet", + "enum": ["Earth", "Mars"], + "title": "Planet", + "type": "string", + }, + }, + "required": ["name", "age", "planet"], + }, + }, + "title": "TwoUsersWithPlanet", + "description": "Two users model with planet.", + "required": ["user1"], + } + }, + } + + assert tool_spec == expected_spec + + # Verify we can construct a valid ToolSpec + tool_spec_obj = ToolSpec(**tool_spec) + assert tool_spec_obj is not None + + +def test_convert_pydantic_with_missing_refs(): + """Test that the tool handles missing $refs gracefully.""" + # This test checks that our error handling for missing $refs works correctly + # by testing with a model that has circular references + + class NodeWithCircularRef(BaseModel): + """A node with a circular reference to itself.""" + + name: str = Field(description="The name of the node") + parent: Optional["NodeWithCircularRef"] = Field(None, description="Parent node") + children: list["NodeWithCircularRef"] = Field(default_factory=list, description="Child nodes") + + # This forward reference normally causes issues with schema generation + # but our error handling should prevent errors + with pytest.raises(ValueError, match="Circular reference detected and not supported"): + convert_pydantic_to_tool_spec(NodeWithCircularRef) + + +def test_convert_pydantic_with_custom_description(): + """Test that custom descriptions override model docstrings.""" + + # Test with custom description + custom_description = "Custom tool description for user model" + tool_spec = convert_pydantic_to_tool_spec(User, description=custom_description) + + assert tool_spec["description"] == custom_description + + +def test_convert_pydantic_with_empty_docstring(): + """Test that empty docstrings use default description.""" + + class EmptyDocUser(BaseModel): + name: str = Field(description="The name of the user") + + tool_spec = convert_pydantic_to_tool_spec(EmptyDocUser) + assert tool_spec["description"] == "EmptyDocUser structured output tool" diff --git a/tests/strands/types/models/test_model.py b/tests/strands/types/models/test_model.py index f2797fe5..03690733 100644 --- a/tests/strands/types/models/test_model.py +++ b/tests/strands/types/models/test_model.py @@ 
-1,8 +1,16 @@ +from typing import Type + import pytest +from pydantic import BaseModel from strands.types.models import Model as SAModel +class Person(BaseModel): + name: str + age: int + + class TestModel(SAModel): def update_config(self, **model_config): return model_config @@ -10,6 +18,9 @@ def update_config(self, **model_config): def get_config(self): return + def structured_output(self, output_model: Type[BaseModel]) -> BaseModel: + return output_model(name="test", age=20) + def format_request(self, messages, tool_specs, system_prompt): return { "messages": messages, @@ -79,3 +90,9 @@ def test_converse(model, messages, tool_specs, system_prompt): }, ] assert tru_events == exp_events + + +def test_structured_output(model): + response = model.structured_output(Person) + + assert response == Person(name="test", age=20) From 3a23ce2f3d4218444304979f48d2d491b08e6c31 Mon Sep 17 00:00:00 2001 From: Mackenzie Zastrow <3211021+zastrowm@users.noreply.github.com> Date: Thu, 19 Jun 2025 12:24:44 -0400 Subject: [PATCH 18/22] fix: Emit warning that us-west-2 will not be the default region (#254) Related to #238 Co-authored-by: Mackenzie Zastrow --- src/strands/models/bedrock.py | 11 ++++++++++- src/strands/types/models/openai.py | 2 +- tests/strands/models/test_bedrock.py | 17 +++++++++++++++++ 3 files changed, 28 insertions(+), 2 deletions(-) diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index ac1c4a38..3de41198 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -118,8 +118,17 @@ def __init__( logger.debug("config=<%s> | initializing", self.config) + region_for_boto = region_name or os.getenv("AWS_REGION") + if region_for_boto is None: + region_for_boto = "us-west-2" + logger.warning("defaulted to us-west-2 because no region was specified") + logger.warning( + "issue=<%s> | this behavior will change in an upcoming release", + "https://github.com/strands-agents/sdk-python/issues/238", + ) + session = boto_session or boto3.Session( - region_name=region_name or os.getenv("AWS_REGION") or "us-west-2", + region_name=region_for_boto, ) # Add strands-agents to the request user agent diff --git a/src/strands/types/models/openai.py b/src/strands/types/models/openai.py index cbd2cfd2..a6bd93cc 100644 --- a/src/strands/types/models/openai.py +++ b/src/strands/types/models/openai.py @@ -65,7 +65,7 @@ def format_request_message_content(cls, content: ContentBlock) -> dict[str, Any] base64.b64decode(image_bytes, validate=True) logger.warning( "issue=<%s> | base64 encoded images will not be accepted in a future version", - "https://github.com/strands-agents/sdk-python/issues/252" + "https://github.com/strands-agents/sdk-python/issues/252", ) except ValueError: image_bytes = base64.b64encode(image_bytes) diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index b326eee7..137b57c8 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -91,6 +91,23 @@ def test__init__default_model_id(bedrock_client): assert tru_model_id == exp_model_id +def test__init__with_default_region(bedrock_client): + """Test that BedrockModel uses the provided region.""" + _ = bedrock_client + default_region = "us-west-2" + + with unittest.mock.patch("strands.models.bedrock.boto3.Session") as mock_session_cls: + with unittest.mock.patch("strands.models.bedrock.logger.warning") as mock_warning: + _ = BedrockModel() + mock_session_cls.assert_called_once_with(region_name=default_region) + # Assert that warning 
logs are emitted + mock_warning.assert_any_call("defaulted to us-west-2 because no region was specified") + mock_warning.assert_any_call( + "issue=<%s> | this behavior will change in an upcoming release", + "https://github.com/strands-agents/sdk-python/issues/238", + ) + + def test__init__with_custom_region(bedrock_client): """Test that BedrockModel uses the provided region.""" _ = bedrock_client From 76ee1ada81ef0d884c234d810d8332ad3f64ac1a Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Fri, 20 Jun 2025 10:14:47 -0400 Subject: [PATCH 19/22] models - openai - b64encode method (#260) --- src/strands/types/models/openai.py | 39 ++++++++++++++++------- tests/strands/types/models/test_openai.py | 12 +++++++ 2 files changed, 40 insertions(+), 11 deletions(-) diff --git a/src/strands/types/models/openai.py b/src/strands/types/models/openai.py index a6bd93cc..8ff37d35 100644 --- a/src/strands/types/models/openai.py +++ b/src/strands/types/models/openai.py @@ -34,6 +34,32 @@ class OpenAIModel(Model, abc.ABC): config: dict[str, Any] + @staticmethod + def b64encode(data: bytes) -> bytes: + """Base64 encode the provided data. + + If the data is already base64 encoded, we do nothing. + Note, this is a temporary method used to provide a warning to users who pass in base64 encoded data. In future + versions, images and documents will be base64 encoded on behalf of customers for consistency with the other + providers and general convenience. + + Args: + data: Data to encode. + + Returns: + Base64 encoded data. + """ + try: + base64.b64decode(data, validate=True) + logger.warning( + "issue=<%s> | base64 encoded images and documents will not be accepted in future versions", + "https://github.com/strands-agents/sdk-python/issues/252", + ) + except ValueError: + data = base64.b64encode(data) + + return data + @classmethod def format_request_message_content(cls, content: ContentBlock) -> dict[str, Any]: """Format an OpenAI compatible content block. 
@@ -60,17 +86,8 @@ def format_request_message_content(cls, content: ContentBlock) -> dict[str, Any] if "image" in content: mime_type = mimetypes.types_map.get(f".{content['image']['format']}", "application/octet-stream") - image_bytes = content["image"]["source"]["bytes"] - try: - base64.b64decode(image_bytes, validate=True) - logger.warning( - "issue=<%s> | base64 encoded images will not be accepted in a future version", - "https://github.com/strands-agents/sdk-python/issues/252", - ) - except ValueError: - image_bytes = base64.b64encode(image_bytes) - - image_data = image_bytes.decode("utf-8") + image_data = OpenAIModel.b64encode(content["image"]["source"]["bytes"]).decode("utf-8") + return { "image_url": { "detail": "auto", diff --git a/tests/strands/types/models/test_openai.py b/tests/strands/types/models/test_openai.py index 2827969d..3a1a940b 100644 --- a/tests/strands/types/models/test_openai.py +++ b/tests/strands/types/models/test_openai.py @@ -362,3 +362,15 @@ def test_format_chunk_unknown_type(model): with pytest.raises(RuntimeError, match="chunk_type= | unknown type"): model.format_chunk(event) + + +@pytest.mark.parametrize( + ("data", "exp_result"), + [ + (b"image", b"aW1hZ2U="), + (b"aW1hZ2U=", b"aW1hZ2U="), + ], +) +def test_b64encode(data, exp_result): + tru_result = SAOpenAIModel.b64encode(data) + assert tru_result == exp_result From e6937381c94f7614ae461b64dd5e6e89258d0d3f Mon Sep 17 00:00:00 2001 From: poshinchen Date: Fri, 20 Jun 2025 10:23:37 -0400 Subject: [PATCH 20/22] chore: emit strands metrics (#248) --- src/strands/event_loop/event_loop.py | 7 +- src/strands/telemetry/metrics.py | 95 +++++++++++++++++++--- src/strands/telemetry/metrics_constants.py | 16 +++- tests/strands/telemetry/test_metrics.py | 52 +++++++++--- tests/strands/tools/test_executor.py | 7 ++ 5 files changed, 149 insertions(+), 28 deletions(-) diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index e336642c..49769aab 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -88,11 +88,11 @@ def event_loop_cycle( kwargs["event_loop_cycle_id"] = uuid.uuid4() event_loop_metrics: EventLoopMetrics = kwargs.get("event_loop_metrics", EventLoopMetrics()) - # Initialize state and get cycle trace if "request_state" not in kwargs: kwargs["request_state"] = {} - cycle_start_time, cycle_trace = event_loop_metrics.start_cycle() + attributes = {"event_loop_cycle_id": str(kwargs.get("event_loop_cycle_id"))} + cycle_start_time, cycle_trace = event_loop_metrics.start_cycle(attributes=attributes) kwargs["event_loop_cycle_trace"] = cycle_trace callback_handler(start=True) @@ -211,7 +211,7 @@ def event_loop_cycle( ) # End the cycle and return results - event_loop_metrics.end_cycle(cycle_start_time, cycle_trace) + event_loop_metrics.end_cycle(cycle_start_time, cycle_trace, attributes) if cycle_span: tracer.end_event_loop_cycle_span( span=cycle_span, @@ -344,7 +344,6 @@ def _handle_tool_execution( if not tool_uses: return stop_reason, message, event_loop_metrics, kwargs["request_state"] - tool_handler_process = partial( tool_handler.process, messages=messages, diff --git a/src/strands/telemetry/metrics.py b/src/strands/telemetry/metrics.py index af940643..332ab2ae 100644 --- a/src/strands/telemetry/metrics.py +++ b/src/strands/telemetry/metrics.py @@ -7,7 +7,7 @@ from typing import Any, Dict, Iterable, List, Optional, Set, Tuple import opentelemetry.metrics as metrics_api -from opentelemetry.metrics import Counter, Meter +from 
opentelemetry.metrics import Counter, Histogram, Meter from ..telemetry import metrics_constants as constants from ..types.content import Message @@ -121,22 +121,34 @@ class ToolMetrics: error_count: int = 0 total_time: float = 0.0 - def add_call(self, tool: ToolUse, duration: float, success: bool) -> None: + def add_call( + self, + tool: ToolUse, + duration: float, + success: bool, + metrics_client: "MetricsClient", + attributes: Optional[Dict[str, Any]] = None, + ) -> None: """Record a new tool call with its outcome. Args: tool: The tool that was called. duration: How long the call took in seconds. success: Whether the call was successful. + metrics_client: The metrics client for recording the metrics. + attributes: attributes of the metrics. """ self.tool = tool # Update with latest tool state self.call_count += 1 self.total_time += duration - + metrics_client.tool_call_count.add(1, attributes=attributes) + metrics_client.tool_duration.record(duration, attributes=attributes) if success: self.success_count += 1 + metrics_client.tool_success_count.add(1, attributes=attributes) else: self.error_count += 1 + metrics_client.tool_error_count.add(1, attributes=attributes) @dataclass @@ -159,32 +171,53 @@ class EventLoopMetrics: accumulated_usage: Usage = field(default_factory=lambda: Usage(inputTokens=0, outputTokens=0, totalTokens=0)) accumulated_metrics: Metrics = field(default_factory=lambda: Metrics(latencyMs=0)) - def start_cycle(self) -> Tuple[float, Trace]: + @property + def _metrics_client(self) -> "MetricsClient": + """Get the singleton MetricsClient instance.""" + return MetricsClient() + + def start_cycle( + self, + attributes: Optional[Dict[str, Any]] = None, + ) -> Tuple[float, Trace]: """Start a new event loop cycle and create a trace for it. + Args: + attributes: attributes of the metrics. + Returns: A tuple containing the start time and the cycle trace object. """ + self._metrics_client.event_loop_cycle_count.add(1, attributes=attributes) + self._metrics_client.event_loop_start_cycle.add(1, attributes=attributes) self.cycle_count += 1 start_time = time.time() cycle_trace = Trace(f"Cycle {self.cycle_count}", start_time=start_time) self.traces.append(cycle_trace) return start_time, cycle_trace - def end_cycle(self, start_time: float, cycle_trace: Trace) -> None: + def end_cycle(self, start_time: float, cycle_trace: Trace, attributes: Optional[Dict[str, Any]] = None) -> None: """End the current event loop cycle and record its duration. Args: start_time: The timestamp when the cycle started. cycle_trace: The trace object for this cycle. + attributes: attributes of the metrics. """ + self._metrics_client.event_loop_end_cycle.add(1, attributes) end_time = time.time() duration = end_time - start_time + self._metrics_client.event_loop_cycle_duration.record(duration, attributes) self.cycle_durations.append(duration) cycle_trace.end(end_time) def add_tool_usage( - self, tool: ToolUse, duration: float, tool_trace: Trace, success: bool, message: Message + self, + tool: ToolUse, + duration: float, + tool_trace: Trace, + success: bool, + message: Message, ) -> None: """Record metrics for a tool invocation. 
@@ -207,8 +240,16 @@ def add_tool_usage( tool_trace.raw_name = f"{tool_name} - {tool_use_id}" tool_trace.add_message(message) - self.tool_metrics.setdefault(tool_name, ToolMetrics(tool)).add_call(tool, duration, success) - + self.tool_metrics.setdefault(tool_name, ToolMetrics(tool)).add_call( + tool, + duration, + success, + self._metrics_client, + attributes={ + "tool_name": tool_name, + "tool_use_id": tool_use_id, + }, + ) tool_trace.end() def update_usage(self, usage: Usage) -> None: @@ -217,6 +258,8 @@ def update_usage(self, usage: Usage) -> None: Args: usage: The usage data to add to the accumulated totals. """ + self._metrics_client.event_loop_input_tokens.record(usage["inputTokens"]) + self._metrics_client.event_loop_output_tokens.record(usage["outputTokens"]) self.accumulated_usage["inputTokens"] += usage["inputTokens"] self.accumulated_usage["outputTokens"] += usage["outputTokens"] self.accumulated_usage["totalTokens"] += usage["totalTokens"] @@ -227,6 +270,7 @@ def update_metrics(self, metrics: Metrics) -> None: Args: metrics: The metrics data to add to the accumulated totals. """ + self._metrics_client.event_loop_latency.record(metrics["latencyMs"]) self.accumulated_metrics["latencyMs"] += metrics["latencyMs"] def get_summary(self) -> Dict[str, Any]: @@ -370,7 +414,18 @@ class MetricsClient: _instance: Optional["MetricsClient"] = None meter: Meter - strands_agent_invocation_count: Counter + event_loop_cycle_count: Counter + event_loop_start_cycle: Counter + event_loop_end_cycle: Counter + event_loop_cycle_duration: Histogram + event_loop_latency: Histogram + event_loop_input_tokens: Histogram + event_loop_output_tokens: Histogram + + tool_call_count: Counter + tool_success_count: Counter + tool_error_count: Counter + tool_duration: Histogram def __new__(cls) -> "MetricsClient": """Create or return the singleton instance of MetricsClient. 
@@ -398,6 +453,24 @@ def __init__(self) -> None: def create_instruments(self) -> None: """Create and initialize all OpenTelemetry metric instruments.""" - self.strands_agent_invocation_count = self.meter.create_counter( - name=constants.STRANDS_AGENT_INVOCATION_COUNT, unit="Count" + self.event_loop_cycle_count = self.meter.create_counter( + name=constants.STRANDS_EVENT_LOOP_CYCLE_COUNT, unit="Count" + ) + self.event_loop_start_cycle = self.meter.create_counter( + name=constants.STRANDS_EVENT_LOOP_START_CYCLE, unit="Count" + ) + self.event_loop_end_cycle = self.meter.create_counter(name=constants.STRANDS_EVENT_LOOP_END_CYCLE, unit="Count") + self.event_loop_cycle_duration = self.meter.create_histogram( + name=constants.STRANDS_EVENT_LOOP_CYCLE_DURATION, unit="s" + ) + self.event_loop_latency = self.meter.create_histogram(name=constants.STRANDS_EVENT_LOOP_LATENCY, unit="ms") + self.tool_call_count = self.meter.create_counter(name=constants.STRANDS_TOOL_CALL_COUNT, unit="Count") + self.tool_success_count = self.meter.create_counter(name=constants.STRANDS_TOOL_SUCCESS_COUNT, unit="Count") + self.tool_error_count = self.meter.create_counter(name=constants.STRANDS_TOOL_ERROR_COUNT, unit="Count") + self.tool_duration = self.meter.create_histogram(name=constants.STRANDS_TOOL_DURATION, unit="s") + self.event_loop_input_tokens = self.meter.create_histogram( + name=constants.STRANDS_EVENT_LOOP_INPUT_TOKENS, unit="token" + ) + self.event_loop_output_tokens = self.meter.create_histogram( + name=constants.STRANDS_EVENT_LOOP_OUTPUT_TOKENS, unit="token" ) diff --git a/src/strands/telemetry/metrics_constants.py b/src/strands/telemetry/metrics_constants.py index d3d3e81f..b622eebf 100644 --- a/src/strands/telemetry/metrics_constants.py +++ b/src/strands/telemetry/metrics_constants.py @@ -1,3 +1,15 @@ -"""Metrics that are emitted in Strands-Agent.""" +"""Metrics that are emitted in Strands-Agents.""" -STRANDS_AGENT_INVOCATION_COUNT = "strands.agent.invocation_count" +STRANDS_EVENT_LOOP_CYCLE_COUNT = "strands.event_loop.cycle_count" +STRANDS_EVENT_LOOP_START_CYCLE = "strands.event_loop.start_cycle" +STRANDS_EVENT_LOOP_END_CYCLE = "strands.event_loop.end_cycle" +STRANDS_TOOL_CALL_COUNT = "strands.tool.call_count" +STRANDS_TOOL_SUCCESS_COUNT = "strands.tool.success_count" +STRANDS_TOOL_ERROR_COUNT = "strands.tool.error_count" + +# Histograms +STRANDS_EVENT_LOOP_LATENCY = "strands.event_loop.latency" +STRANDS_TOOL_DURATION = "strands.tool.duration" +STRANDS_EVENT_LOOP_CYCLE_DURATION = "strands.event_loop.cycle_duration" +STRANDS_EVENT_LOOP_INPUT_TOKENS = "strands.event_loop.input.tokens" +STRANDS_EVENT_LOOP_OUTPUT_TOKENS = "strands.event_loop.output.tokens" diff --git a/tests/strands/telemetry/test_metrics.py b/tests/strands/telemetry/test_metrics.py index cafbd4bb..215e1efd 100644 --- a/tests/strands/telemetry/test_metrics.py +++ b/tests/strands/telemetry/test_metrics.py @@ -124,12 +124,19 @@ def test_trace_end(mock_time, end_time, trace): @pytest.fixture def mock_get_meter_provider(): with mock.patch("strands.telemetry.metrics.metrics_api.get_meter_provider") as mock_get_meter_provider: + MetricsClient._instance = None meter_provider_mock = mock.MagicMock(spec=MeterProvider) - mock_get_meter_provider.return_value = meter_provider_mock mock_meter = mock.MagicMock() + mock_create_counter = mock.MagicMock() + mock_meter.create_counter.return_value = mock_create_counter + + mock_create_histogram = mock.MagicMock() + mock_meter.create_histogram.return_value = mock_create_histogram 
meter_provider_mock.get_meter.return_value = mock_meter + mock_get_meter_provider.return_value = meter_provider_mock + yield mock_get_meter_provider @@ -190,11 +197,14 @@ def test_trace_to_dict(trace): @pytest.mark.parametrize("success", [True, False]) -def test_tool_metrics_add_call(success, tool, tool_metrics): +def test_tool_metrics_add_call(success, tool, tool_metrics, mock_get_meter_provider): tool = dict(tool, **{"name": "updated"}) duration = 1 + metrics_client = MetricsClient() + + attributes = {"foo": "bar"} - tool_metrics.add_call(tool, duration, success) + tool_metrics.add_call(tool, duration, success, metrics_client, attributes=attributes) tru_attrs = dataclasses.asdict(tool_metrics) exp_attrs = { @@ -205,12 +215,17 @@ def test_tool_metrics_add_call(success, tool, tool_metrics): "total_time": duration, } + mock_get_meter_provider.return_value.get_meter.assert_called() + metrics_client.tool_call_count.add.assert_called_with(1, attributes=attributes) + metrics_client.tool_duration.record.assert_called_with(duration, attributes=attributes) + if success: + metrics_client.tool_success_count.add.assert_called_with(1, attributes=attributes) assert tru_attrs == exp_attrs @unittest.mock.patch.object(strands.telemetry.metrics.time, "time") @unittest.mock.patch.object(strands.telemetry.metrics.uuid, "uuid4") -def test_event_loop_metrics_start_cycle(mock_uuid4, mock_time, event_loop_metrics): +def test_event_loop_metrics_start_cycle(mock_uuid4, mock_time, event_loop_metrics, mock_get_meter_provider): mock_time.return_value = 1 mock_uuid4.return_value = "i1" @@ -220,6 +235,8 @@ def test_event_loop_metrics_start_cycle(mock_uuid4, mock_time, event_loop_metric tru_attrs = {"cycle_count": event_loop_metrics.cycle_count, "traces": event_loop_metrics.traces} exp_attrs = {"cycle_count": 1, "traces": [tru_cycle_trace]} + mock_get_meter_provider.return_value.get_meter.assert_called() + event_loop_metrics._metrics_client.event_loop_cycle_count.add.assert_called() assert ( tru_start_time == exp_start_time and tru_cycle_trace.to_dict() == exp_cycle_trace.to_dict() @@ -228,10 +245,11 @@ def test_event_loop_metrics_start_cycle(mock_uuid4, mock_time, event_loop_metric @unittest.mock.patch.object(strands.telemetry.metrics.time, "time") -def test_event_loop_metrics_end_cycle(mock_time, trace, event_loop_metrics): +def test_event_loop_metrics_end_cycle(mock_time, trace, event_loop_metrics, mock_get_meter_provider): mock_time.return_value = 1 - event_loop_metrics.end_cycle(start_time=0, cycle_trace=trace) + attributes = {"foo": "bar"} + event_loop_metrics.end_cycle(start_time=0, cycle_trace=trace, attributes=attributes) tru_cycle_durations = event_loop_metrics.cycle_durations exp_cycle_durations = [1] @@ -243,17 +261,23 @@ def test_event_loop_metrics_end_cycle(mock_time, trace, event_loop_metrics): assert tru_trace_end_time == exp_trace_end_time + mock_get_meter_provider.return_value.get_meter.assert_called() + metrics_client = event_loop_metrics._metrics_client + metrics_client.event_loop_end_cycle.add.assert_called_with(1, attributes) + metrics_client.event_loop_cycle_duration.record.assert_called() + @unittest.mock.patch.object(strands.telemetry.metrics.time, "time") -def test_event_loop_metrics_add_tool_usage(mock_time, trace, tool, event_loop_metrics): +def test_event_loop_metrics_add_tool_usage(mock_time, trace, tool, event_loop_metrics, mock_get_meter_provider): mock_time.return_value = 1 - duration = 1 success = True message = {"role": "user", "content": [{"toolResult": {"toolUseId": "123", 
"tool_name": "tool1"}}]} event_loop_metrics.add_tool_usage(tool, duration, trace, success, message) + mock_get_meter_provider.return_value.get_meter.assert_called() + tru_event_loop_metrics_attrs = {"tool_metrics": event_loop_metrics.tool_metrics} exp_event_loop_metrics_attrs = { "tool_metrics": { @@ -286,7 +310,7 @@ def test_event_loop_metrics_add_tool_usage(mock_time, trace, tool, event_loop_me assert tru_trace_attrs == exp_trace_attrs -def test_event_loop_metrics_update_usage(usage, event_loop_metrics): +def test_event_loop_metrics_update_usage(usage, event_loop_metrics, mock_get_meter_provider): for _ in range(3): event_loop_metrics.update_usage(usage) @@ -298,9 +322,13 @@ def test_event_loop_metrics_update_usage(usage, event_loop_metrics): ) assert tru_usage == exp_usage + mock_get_meter_provider.return_value.get_meter.assert_called() + metrics_client = event_loop_metrics._metrics_client + metrics_client.event_loop_input_tokens.record.assert_called() + metrics_client.event_loop_output_tokens.record.assert_called() -def test_event_loop_metrics_update_metrics(metrics, event_loop_metrics): +def test_event_loop_metrics_update_metrics(metrics, event_loop_metrics, mock_get_meter_provider): for _ in range(3): event_loop_metrics.update_metrics(metrics) @@ -310,9 +338,11 @@ def test_event_loop_metrics_update_metrics(metrics, event_loop_metrics): ) assert tru_metrics == exp_metrics + mock_get_meter_provider.return_value.get_meter.assert_called() + event_loop_metrics._metrics_client.event_loop_latency.record.assert_called_with(1) -def test_event_loop_metrics_get_summary(trace, tool, event_loop_metrics): +def test_event_loop_metrics_get_summary(trace, tool, event_loop_metrics, mock_get_meter_provider): duration = 1 success = True message = {"role": "user", "content": [{"toolResult": {"toolUseId": "123", "tool_name": "tool1"}}]} diff --git a/tests/strands/tools/test_executor.py b/tests/strands/tools/test_executor.py index a6ea45c3..4b238792 100644 --- a/tests/strands/tools/test_executor.py +++ b/tests/strands/tools/test_executor.py @@ -43,6 +43,12 @@ def tool_uses(request, tool_use): return request.param if hasattr(request, "param") else [tool_use] +@pytest.fixture +def mock_metrics_client(): + with unittest.mock.patch("strands.telemetry.MetricsClient") as mock_metrics_client: + yield mock_metrics_client + + @pytest.fixture def event_loop_metrics(): return strands.telemetry.metrics.EventLoopMetrics() @@ -303,6 +309,7 @@ def test_run_tools_creates_and_ends_span_on_success( mock_get_tracer, tool_handler, tool_uses, + mock_metrics_client, event_loop_metrics, request_state, invalid_tool_use_ids, From d8ce2d5e69322211b281567eb0da99e0ba47b574 Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Fri, 20 Jun 2025 14:47:31 -0400 Subject: [PATCH 21/22] iterative streaming (#241) --- src/strands/event_loop/event_loop.py | 23 +- src/strands/event_loop/streaming.py | 80 ++--- src/strands/models/anthropic.py | 12 +- src/strands/models/bedrock.py | 12 +- tests-integ/test_model_anthropic.py | 4 +- tests/strands/event_loop/test_event_loop.py | 100 ++++++ tests/strands/event_loop/test_streaming.py | 362 ++++++++++++++------ 7 files changed, 427 insertions(+), 166 deletions(-) diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index 49769aab..9580ea35 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -130,14 +130,19 @@ def event_loop_cycle( ) try: - stop_reason, message, usage, metrics, kwargs["request_state"] = stream_messages( - 
model,
-                system_prompt,
-                messages,
-                tool_config,
-                callback_handler,
-                **kwargs,
-            )
+            # TODO: As part of the migration to async-iterator, we will continue moving callback_handler calls up the
+            #   call stack. At this point, we converted all events that were previously passed to the handler in
+            #   `stream_messages` into yielded events that now have the "callback" key. To maintain backwards
+            #   compatibility, we need to combine the event with kwargs before passing to the handler. We will
+            #   revisit this when migrating to strongly typed events.
+            for event in stream_messages(model, system_prompt, messages, tool_config):
+                if "callback" in event:
+                    inputs = {**event["callback"], **(kwargs if "delta" in event["callback"] else {})}
+                    callback_handler(**inputs)
+                else:
+                    stop_reason, message, usage, metrics = event["stop"]
+                    kwargs.setdefault("request_state", {})
+
             if model_invoke_span:
                 tracer.end_model_invoke_span(model_invoke_span, message, usage)
             break  # Success! Break out of retry loop
@@ -334,7 +339,7 @@ def _handle_tool_execution(
         kwargs (Dict[str, Any]): Additional keyword arguments, including request state.
 
     Returns:
-        Tuple[StopReason, Message, EventLoopMetrics, Dict[str, Any]]: 
+        Tuple[StopReason, Message, EventLoopMetrics, Dict[str, Any]]:
             - The stop reason,
             - The updated message,
            - The updated event loop metrics,
diff --git a/src/strands/event_loop/streaming.py b/src/strands/event_loop/streaming.py
index 6e8a806f..0e9d472b 100644
--- a/src/strands/event_loop/streaming.py
+++ b/src/strands/event_loop/streaming.py
@@ -2,7 +2,7 @@
 
 import json
 import logging
-from typing import Any, Dict, Iterable, List, Optional, Tuple
+from typing import Any, Generator, Iterable, Optional
 
 from ..types.content import ContentBlock, Message, Messages
 from ..types.models import Model
@@ -80,7 +80,7 @@ def handle_message_start(event: MessageStartEvent, message: Message) -> Message:
     return message
 
 
-def handle_content_block_start(event: ContentBlockStartEvent) -> Dict[str, Any]:
+def handle_content_block_start(event: ContentBlockStartEvent) -> dict[str, Any]:
     """Handles the start of a content block by extracting tool usage information if any.
 
     Args:
@@ -102,31 +102,31 @@ def handle_content_block_start(event: ContentBlockStartEvent) -> Dict[str, Any]:
 
 
 def handle_content_block_delta(
-    event: ContentBlockDeltaEvent, state: Dict[str, Any], callback_handler: Any, **kwargs: Any
-) -> Dict[str, Any]:
+    event: ContentBlockDeltaEvent, state: dict[str, Any]
+) -> tuple[dict[str, Any], dict[str, Any]]:
     """Handles content block delta updates by appending text, tool input, or reasoning content to the state.
 
     Args:
         event: Delta event.
         state: The current state of message processing.
-        callback_handler: Callback for processing events as they happen.
-        **kwargs: Additional keyword arguments to pass to the callback handler.
 
     Returns:
         Updated state with appended text or tool input.
""" delta_content = event["delta"] + callback_event = {} + if "toolUse" in delta_content: if "input" not in state["current_tool_use"]: state["current_tool_use"]["input"] = "" state["current_tool_use"]["input"] += delta_content["toolUse"]["input"] - callback_handler(delta=delta_content, current_tool_use=state["current_tool_use"], **kwargs) + callback_event["callback"] = {"delta": delta_content, "current_tool_use": state["current_tool_use"]} elif "text" in delta_content: state["text"] += delta_content["text"] - callback_handler(data=delta_content["text"], delta=delta_content, **kwargs) + callback_event["callback"] = {"data": delta_content["text"], "delta": delta_content} elif "reasoningContent" in delta_content: if "text" in delta_content["reasoningContent"]: @@ -134,29 +134,27 @@ def handle_content_block_delta( state["reasoningText"] = "" state["reasoningText"] += delta_content["reasoningContent"]["text"] - callback_handler( - reasoningText=delta_content["reasoningContent"]["text"], - delta=delta_content, - reasoning=True, - **kwargs, - ) + callback_event["callback"] = { + "reasoningText": delta_content["reasoningContent"]["text"], + "delta": delta_content, + "reasoning": True, + } elif "signature" in delta_content["reasoningContent"]: if "signature" not in state: state["signature"] = "" state["signature"] += delta_content["reasoningContent"]["signature"] - callback_handler( - reasoning_signature=delta_content["reasoningContent"]["signature"], - delta=delta_content, - reasoning=True, - **kwargs, - ) + callback_event["callback"] = { + "reasoning_signature": delta_content["reasoningContent"]["signature"], + "delta": delta_content, + "reasoning": True, + } - return state + return state, callback_event -def handle_content_block_stop(state: Dict[str, Any]) -> Dict[str, Any]: +def handle_content_block_stop(state: dict[str, Any]) -> dict[str, Any]: """Handles the end of a content block by finalizing tool usage, text content, or reasoning content. Args: @@ -165,7 +163,7 @@ def handle_content_block_stop(state: Dict[str, Any]) -> Dict[str, Any]: Returns: Updated state with finalized content block. """ - content: List[ContentBlock] = state["content"] + content: list[ContentBlock] = state["content"] current_tool_use = state["current_tool_use"] text = state["text"] @@ -223,7 +221,7 @@ def handle_message_stop(event: MessageStopEvent) -> StopReason: return event["stopReason"] -def handle_redact_content(event: RedactContentEvent, messages: Messages, state: Dict[str, Any]) -> None: +def handle_redact_content(event: RedactContentEvent, messages: Messages, state: dict[str, Any]) -> None: """Handles redacting content from the input or output. Args: @@ -238,7 +236,7 @@ def handle_redact_content(event: RedactContentEvent, messages: Messages, state: state["message"]["content"] = [{"text": event["redactAssistantContentMessage"]}] -def extract_usage_metrics(event: MetadataEvent) -> Tuple[Usage, Metrics]: +def extract_usage_metrics(event: MetadataEvent) -> tuple[Usage, Metrics]: """Extracts usage metrics from the metadata chunk. Args: @@ -255,25 +253,20 @@ def extract_usage_metrics(event: MetadataEvent) -> Tuple[Usage, Metrics]: def process_stream( chunks: Iterable[StreamEvent], - callback_handler: Any, messages: Messages, - **kwargs: Any, -) -> Tuple[StopReason, Message, Usage, Metrics, Any]: +) -> Generator[dict[str, Any], None, None]: """Processes the response stream from the API, constructing the final message and extracting usage metrics. Args: chunks: The chunks of the response stream from the model. 
- callback_handler: Callback for processing events as they happen. messages: The agents messages. - **kwargs: Additional keyword arguments that will be passed to the callback handler. - And also returned in the request_state. Returns: - The reason for stopping, the constructed message, the usage metrics, and the updated request state. + The reason for stopping, the constructed message, and the usage metrics. """ stop_reason: StopReason = "end_turn" - state: Dict[str, Any] = { + state: dict[str, Any] = { "message": {"role": "assistant", "content": []}, "text": "", "current_tool_use": {}, @@ -285,18 +278,16 @@ def process_stream( usage: Usage = Usage(inputTokens=0, outputTokens=0, totalTokens=0) metrics: Metrics = Metrics(latencyMs=0) - kwargs.setdefault("request_state", {}) - for chunk in chunks: - # Callback handler call here allows each event to be visible to the caller - callback_handler(event=chunk) + yield {"callback": {"event": chunk}} if "messageStart" in chunk: state["message"] = handle_message_start(chunk["messageStart"], state["message"]) elif "contentBlockStart" in chunk: state["current_tool_use"] = handle_content_block_start(chunk["contentBlockStart"]) elif "contentBlockDelta" in chunk: - state = handle_content_block_delta(chunk["contentBlockDelta"], state, callback_handler, **kwargs) + state, callback_event = handle_content_block_delta(chunk["contentBlockDelta"], state) + yield callback_event elif "contentBlockStop" in chunk: state = handle_content_block_stop(state) elif "messageStop" in chunk: @@ -306,7 +297,7 @@ def process_stream( elif "redactContent" in chunk: handle_redact_content(chunk["redactContent"], messages, state) - return stop_reason, state["message"], usage, metrics, kwargs["request_state"] + yield {"stop": (stop_reason, state["message"], usage, metrics)} def stream_messages( @@ -314,9 +305,7 @@ def stream_messages( system_prompt: Optional[str], messages: Messages, tool_config: Optional[ToolConfig], - callback_handler: Any, - **kwargs: Any, -) -> Tuple[StopReason, Message, Usage, Metrics, Any]: +) -> Generator[dict[str, Any], None, None]: """Streams messages to the model and processes the response. Args: @@ -324,12 +313,9 @@ def stream_messages( system_prompt: The system prompt to send. messages: List of messages to send. tool_config: Configuration for the tools to use. - callback_handler: Callback for processing events as they happen. - **kwargs: Additional keyword arguments that will be passed to the callback handler. - And also returned in the request_state. Returns: - The reason for stopping, the final message, the usage metrics, and updated request state. + The reason for stopping, the final message, and the usage metrics """ logger.debug("model=<%s> | streaming messages", model) @@ -337,4 +323,4 @@ def stream_messages( tool_specs = [tool["toolSpec"] for tool in tool_config.get("tools", [])] or None if tool_config else None chunks = model.converse(messages, tool_specs, system_prompt) - return process_stream(chunks, callback_handler, messages, **kwargs) + yield from process_stream(chunks, messages) diff --git a/src/strands/models/anthropic.py b/src/strands/models/anthropic.py index ab427e53..51089d47 100644 --- a/src/strands/models/anthropic.py +++ b/src/strands/models/anthropic.py @@ -387,15 +387,15 @@ def structured_output( prompt(Messages): The prompt messages to use for the agent. callback_handler(Optional[Callable]): Optional callback handler for processing events. Defaults to None. 
""" + callback_handler = callback_handler or PrintingCallbackHandler() tool_spec = convert_pydantic_to_tool_spec(output_model) response = self.converse(messages=prompt, tool_specs=[tool_spec]) - # process the stream and get the tool use input - results = process_stream( - response, callback_handler=callback_handler or PrintingCallbackHandler(), messages=prompt - ) - - stop_reason, messages, _, _, _ = results + for event in process_stream(response, prompt): + if "callback" in event: + callback_handler(**event["callback"]) + else: + stop_reason, messages, _, _ = event["stop"] if stop_reason != "tool_use": raise ValueError("No valid tool use or tool use input was found in the Anthropic response.") diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index 3de41198..a5ffb539 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -504,15 +504,15 @@ def structured_output( prompt(Messages): The prompt messages to use for the agent. callback_handler(Optional[Callable]): Optional callback handler for processing events. Defaults to None. """ + callback_handler = callback_handler or PrintingCallbackHandler() tool_spec = convert_pydantic_to_tool_spec(output_model) response = self.converse(messages=prompt, tool_specs=[tool_spec]) - # process the stream and get the tool use input - results = process_stream( - response, callback_handler=callback_handler or PrintingCallbackHandler(), messages=prompt - ) - - stop_reason, messages, _, _, _ = results + for event in process_stream(response, prompt): + if "callback" in event: + callback_handler(**event["callback"]) + else: + stop_reason, messages, _, _ = event["stop"] if stop_reason != "tool_use": raise ValueError("No valid tool use or tool use input was found in the Bedrock response.") diff --git a/tests-integ/test_model_anthropic.py b/tests-integ/test_model_anthropic.py index 95bfceb5..50033f8f 100644 --- a/tests-integ/test_model_anthropic.py +++ b/tests-integ/test_model_anthropic.py @@ -34,7 +34,7 @@ def tool_weather() -> str: @pytest.fixture def system_prompt(): - return "You are an AI assistant that uses & instead of ." + return "You are an AI assistant." 
@pytest.fixture @@ -47,7 +47,7 @@ def test_agent(agent): result = agent("What is the time and weather in New York?") text = result.message["content"][0]["text"].lower() - assert all(string in text for string in ["12:00", "sunny", "&"]) + assert all(string in text for string in ["12:00", "sunny"]) @pytest.mark.skipif("ANTHROPIC_API_KEY" not in os.environ, reason="ANTHROPIC_API_KEY environment variable missing") diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py index efdf7af8..11f14503 100644 --- a/tests/strands/event_loop/test_event_loop.py +++ b/tests/strands/event_loop/test_event_loop.py @@ -701,6 +701,106 @@ def test_event_loop_cycle_with_parent_span( ) +def test_event_loop_cycle_callback( + model, + model_id, + system_prompt, + messages, + tool_config, + callback_handler, + tool_handler, + tool_execution_handler, +): + model.converse.return_value = [ + {"contentBlockStart": {"start": {"toolUse": {"toolUseId": "123", "name": "test"}}}}, + {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"value"}'}}}}, + {"contentBlockStop": {}}, + {"contentBlockStart": {"start": {}}}, + {"contentBlockDelta": {"delta": {"reasoningContent": {"text": "value"}}}}, + {"contentBlockDelta": {"delta": {"reasoningContent": {"signature": "value"}}}}, + {"contentBlockStop": {}}, + {"contentBlockStart": {"start": {}}}, + {"contentBlockDelta": {"delta": {"text": "value"}}}, + {"contentBlockStop": {}}, + ] + + strands.event_loop.event_loop.event_loop_cycle( + model=model, + model_id=model_id, + system_prompt=system_prompt, + messages=messages, + tool_config=tool_config, + callback_handler=callback_handler, + tool_handler=tool_handler, + tool_execution_handler=tool_execution_handler, + ) + + callback_handler.assert_has_calls( + [ + call(start=True), + call(start_event_loop=True), + call(event={"contentBlockStart": {"start": {"toolUse": {"toolUseId": "123", "name": "test"}}}}), + call(event={"contentBlockDelta": {"delta": {"toolUse": {"input": '{"value"}'}}}}), + call( + delta={"toolUse": {"input": '{"value"}'}}, + current_tool_use={"toolUseId": "123", "name": "test", "input": {}}, + model_id="m1", + event_loop_cycle_id=unittest.mock.ANY, + request_state={}, + event_loop_cycle_trace=unittest.mock.ANY, + event_loop_cycle_span=None, + ), + call(event={"contentBlockStop": {}}), + call(event={"contentBlockStart": {"start": {}}}), + call(event={"contentBlockDelta": {"delta": {"reasoningContent": {"text": "value"}}}}), + call( + reasoningText="value", + delta={"reasoningContent": {"text": "value"}}, + reasoning=True, + model_id="m1", + event_loop_cycle_id=unittest.mock.ANY, + request_state={}, + event_loop_cycle_trace=unittest.mock.ANY, + event_loop_cycle_span=None, + ), + call(event={"contentBlockDelta": {"delta": {"reasoningContent": {"signature": "value"}}}}), + call( + reasoning_signature="value", + delta={"reasoningContent": {"signature": "value"}}, + reasoning=True, + model_id="m1", + event_loop_cycle_id=unittest.mock.ANY, + request_state={}, + event_loop_cycle_trace=unittest.mock.ANY, + event_loop_cycle_span=None, + ), + call(event={"contentBlockStop": {}}), + call(event={"contentBlockStart": {"start": {}}}), + call(event={"contentBlockDelta": {"delta": {"text": "value"}}}), + call( + data="value", + delta={"text": "value"}, + model_id="m1", + event_loop_cycle_id=unittest.mock.ANY, + request_state={}, + event_loop_cycle_trace=unittest.mock.ANY, + event_loop_cycle_span=None, + ), + call(event={"contentBlockStop": {}}), + call( + message={ + "role": "assistant", 
+ "content": [ + {"toolUse": {"toolUseId": "123", "name": "test", "input": {}}}, + {"reasoningContent": {"reasoningText": {"text": "value", "signature": "value"}}}, + {"text": "value"}, + ], + }, + ), + ], + ) + + def test_request_state_initialization(): # Call without providing request_state tru_stop_reason, tru_message, _, tru_request_state = strands.event_loop.event_loop.event_loop_cycle( diff --git a/tests/strands/event_loop/test_streaming.py b/tests/strands/event_loop/test_streaming.py index c24e7e48..e91f4986 100644 --- a/tests/strands/event_loop/test_streaming.py +++ b/tests/strands/event_loop/test_streaming.py @@ -3,6 +3,7 @@ import pytest import strands +import strands.event_loop from strands.types.streaming import ( ContentBlockDeltaEvent, ContentBlockStartEvent, @@ -17,13 +18,6 @@ def moto_autouse(moto_env, moto_mock_aws): _ = moto_mock_aws -@pytest.fixture -def agent(): - mock = unittest.mock.Mock() - - return mock - - @pytest.mark.parametrize( ("messages", "exp_result"), [ @@ -81,7 +75,7 @@ def test_handle_content_block_start(chunk: ContentBlockStartEvent, exp_tool_use) @pytest.mark.parametrize( - ("event", "state", "exp_updated_state", "exp_handler_args"), + ("event", "state", "exp_updated_state", "callback_args"), [ # Tool Use - Existing input ( @@ -148,21 +142,13 @@ def test_handle_content_block_start(chunk: ContentBlockStartEvent, exp_tool_use) ), ], ) -def test_handle_content_block_delta(event: ContentBlockDeltaEvent, state, exp_updated_state, exp_handler_args): - if exp_handler_args: - exp_handler_args.update({"delta": event["delta"], "extra_arg": 1}) +def test_handle_content_block_delta(event: ContentBlockDeltaEvent, state, exp_updated_state, callback_args): + exp_callback_event = {"callback": {**callback_args, "delta": event["delta"]}} if callback_args else {} - tru_handler_args = {} - - def callback_handler(**kwargs): - tru_handler_args.update(kwargs) - - tru_updated_state = strands.event_loop.streaming.handle_content_block_delta( - event, state, callback_handler, extra_arg=1 - ) + tru_updated_state, tru_callback_event = strands.event_loop.streaming.handle_content_block_delta(event, state) assert tru_updated_state == exp_updated_state - assert tru_handler_args == exp_handler_args + assert tru_callback_event == exp_callback_event @pytest.mark.parametrize( @@ -275,8 +261,9 @@ def test_extract_usage_metrics(): @pytest.mark.parametrize( - ("response", "exp_stop_reason", "exp_message", "exp_usage", "exp_metrics", "exp_request_state", "exp_messages"), + ("response", "exp_events"), [ + # Standard Message ( [ {"messageStart": {"role": "assistant"}}, @@ -297,28 +284,127 @@ def test_extract_usage_metrics(): } }, ], - "tool_use", - { - "role": "assistant", - "content": [{"toolUse": {"toolUseId": "123", "name": "test", "input": {"key": "value"}}}], - }, - {"inputTokens": 1, "outputTokens": 1, "totalTokens": 1}, - {"latencyMs": 1}, - {"calls": 1}, - [{"role": "user", "content": [{"text": "Some input!"}]}], + [ + { + "callback": { + "event": { + "messageStart": { + "role": "assistant", + }, + }, + }, + }, + { + "callback": { + "event": { + "contentBlockStart": { + "start": { + "toolUse": { + "name": "test", + "toolUseId": "123", + }, + }, + }, + }, + }, + }, + { + "callback": { + "event": { + "contentBlockDelta": { + "delta": { + "toolUse": { + "input": '{"key": "value"}', + }, + }, + }, + }, + }, + }, + { + "callback": { + "current_tool_use": { + "input": { + "key": "value", + }, + "name": "test", + "toolUseId": "123", + }, + "delta": { + "toolUse": { + "input": '{"key": 
"value"}', + }, + }, + }, + }, + { + "callback": { + "event": { + "contentBlockStop": {}, + }, + }, + }, + { + "callback": { + "event": { + "messageStop": { + "stopReason": "tool_use", + }, + }, + }, + }, + { + "callback": { + "event": { + "metadata": { + "metrics": { + "latencyMs": 1, + }, + "usage": { + "inputTokens": 1, + "outputTokens": 1, + "totalTokens": 1, + }, + }, + }, + }, + }, + { + "stop": ( + "tool_use", + { + "role": "assistant", + "content": [{"toolUse": {"toolUseId": "123", "name": "test", "input": {"key": "value"}}}], + }, + {"inputTokens": 1, "outputTokens": 1, "totalTokens": 1}, + {"latencyMs": 1}, + ) + }, + ], ), + # Empty Message ( [{}], - "end_turn", - { - "role": "assistant", - "content": [], - }, - {"inputTokens": 0, "outputTokens": 0, "totalTokens": 0}, - {"latencyMs": 0}, - {}, - [{"role": "user", "content": [{"text": "Some input!"}]}], + [ + { + "callback": { + "event": {}, + }, + }, + { + "stop": ( + "end_turn", + { + "role": "assistant", + "content": [], + }, + {"inputTokens": 0, "outputTokens": 0, "totalTokens": 0}, + {"latencyMs": 0}, + ), + }, + ], ), + # Redacted Message ( [ {"messageStart": {"role": "assistant"}}, @@ -345,77 +431,161 @@ def test_extract_usage_metrics(): } }, ], - "guardrail_intervened", - { - "role": "assistant", - "content": [{"text": "REDACTED."}], - }, - {"inputTokens": 1, "outputTokens": 1, "totalTokens": 1}, - {"latencyMs": 1}, - {"calls": 1}, - [{"role": "user", "content": [{"text": "REDACTED"}]}], + [ + { + "callback": { + "event": { + "messageStart": { + "role": "assistant", + }, + }, + }, + }, + { + "callback": { + "event": { + "contentBlockStart": { + "start": {}, + }, + }, + }, + }, + { + "callback": { + "event": { + "contentBlockDelta": { + "delta": { + "text": "Hello!", + }, + }, + }, + }, + }, + { + "callback": { + "data": "Hello!", + "delta": { + "text": "Hello!", + }, + }, + }, + { + "callback": { + "event": { + "contentBlockStop": {}, + }, + }, + }, + { + "callback": { + "event": { + "messageStop": { + "stopReason": "guardrail_intervened", + }, + }, + }, + }, + { + "callback": { + "event": { + "redactContent": { + "redactAssistantContentMessage": "REDACTED.", + "redactUserContentMessage": "REDACTED", + }, + }, + }, + }, + { + "callback": { + "event": { + "metadata": { + "metrics": { + "latencyMs": 1, + }, + "usage": { + "inputTokens": 1, + "outputTokens": 1, + "totalTokens": 1, + }, + }, + }, + }, + }, + { + "stop": ( + "guardrail_intervened", + { + "role": "assistant", + "content": [{"text": "REDACTED."}], + }, + {"inputTokens": 1, "outputTokens": 1, "totalTokens": 1}, + {"latencyMs": 1}, + ), + }, + ], ), ], ) -def test_process_stream( - response, exp_stop_reason, exp_message, exp_usage, exp_metrics, exp_request_state, exp_messages -): - def callback_handler(**kwargs): - if "request_state" in kwargs: - kwargs["request_state"].setdefault("calls", 0) - kwargs["request_state"]["calls"] += 1 - - tru_messages = [{"role": "user", "content": [{"text": "Some input!"}]}] - - tru_stop_reason, tru_message, tru_usage, tru_metrics, tru_request_state = ( - strands.event_loop.streaming.process_stream(response, callback_handler, tru_messages) - ) - - assert tru_stop_reason == exp_stop_reason - assert tru_message == exp_message - assert tru_usage == exp_usage - assert tru_metrics == exp_metrics - assert tru_request_state == exp_request_state - assert tru_messages == exp_messages +def test_process_stream(response, exp_events): + messages = [{"role": "user", "content": [{"text": "Some input!"}]}] + stream = 
strands.event_loop.streaming.process_stream(response, messages) + tru_events = list(stream) + assert tru_events == exp_events -def test_stream_messages(agent): - def callback_handler(**kwargs): - if "request_state" in kwargs: - kwargs["request_state"].setdefault("calls", 0) - kwargs["request_state"]["calls"] += 1 +def test_stream_messages(): mock_model = unittest.mock.MagicMock() mock_model.converse.return_value = [ {"contentBlockDelta": {"delta": {"text": "test"}}}, {"contentBlockStop": {}}, ] - tru_stop_reason, tru_message, tru_usage, tru_metrics, tru_request_state = ( - strands.event_loop.streaming.stream_messages( - mock_model, - model_id="test_model", - system_prompt="test prompt", - messages=[{"role": "assistant", "content": [{"text": "a"}, {"text": " \n"}]}], - tool_config=None, - callback_handler=callback_handler, - agent=agent, - ) + stream = strands.event_loop.streaming.stream_messages( + mock_model, + system_prompt="test prompt", + messages=[{"role": "assistant", "content": [{"text": "a"}, {"text": " \n"}]}], + tool_config=None, ) - exp_stop_reason = "end_turn" - exp_message = {"role": "assistant", "content": [{"text": "test"}]} - exp_usage = {"inputTokens": 0, "outputTokens": 0, "totalTokens": 0} - exp_metrics = {"latencyMs": 0} - exp_request_state = {"calls": 1} - - assert ( - tru_stop_reason == exp_stop_reason - and tru_message == exp_message - and tru_usage == exp_usage - and tru_metrics == exp_metrics - and tru_request_state == exp_request_state - ) + tru_events = list(stream) + exp_events = [ + { + "callback": { + "event": { + "contentBlockDelta": { + "delta": { + "text": "test", + }, + }, + }, + }, + }, + { + "callback": { + "data": "test", + "delta": { + "text": "test", + }, + }, + }, + { + "callback": { + "event": { + "contentBlockStop": {}, + }, + }, + }, + { + "stop": ( + "end_turn", + {"role": "assistant", "content": [{"text": "test"}]}, + {"inputTokens": 0, "outputTokens": 0, "totalTokens": 0}, + {"latencyMs": 0}, + ) + }, + ] + assert tru_events == exp_events mock_model.converse.assert_called_with( [{"role": "assistant", "content": [{"text": "a"}, {"text": "[blank text]"}]}], From be57089d4f0984f53e55d466fc18d05c0b591e75 Mon Sep 17 00:00:00 2001 From: Jeremiah Date: Fri, 20 Jun 2025 15:27:30 -0400 Subject: [PATCH 22/22] Initial A2A server Integration (#218) --- pyproject.toml | 25 +++- src/strands/agent/agent.py | 9 ++ src/strands/multiagent/__init__.py | 13 ++ src/strands/multiagent/a2a/__init__.py | 14 +++ src/strands/multiagent/a2a/agent.py | 149 ++++++++++++++++++++++ src/strands/multiagent/a2a/executor.py | 67 ++++++++++ tests/multiagent/__init__.py | 1 + tests/multiagent/a2a/__init__.py | 1 + tests/multiagent/a2a/conftest.py | 41 ++++++ tests/multiagent/a2a/test_agent.py | 165 +++++++++++++++++++++++++ tests/multiagent/a2a/test_executor.py | 118 ++++++++++++++++++ 11 files changed, 599 insertions(+), 4 deletions(-) create mode 100644 src/strands/multiagent/__init__.py create mode 100644 src/strands/multiagent/a2a/__init__.py create mode 100644 src/strands/multiagent/a2a/agent.py create mode 100644 src/strands/multiagent/a2a/executor.py create mode 100644 tests/multiagent/__init__.py create mode 100644 tests/multiagent/a2a/__init__.py create mode 100644 tests/multiagent/a2a/conftest.py create mode 100644 tests/multiagent/a2a/test_agent.py create mode 100644 tests/multiagent/a2a/test_executor.py diff --git a/pyproject.toml b/pyproject.toml index e0cc2578..4bb69ce8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -108,7 +108,8 @@ format-fix = [ ] lint-check = 
[ "ruff check", - "mypy -p src" + # excluding due to A2A and OTEL http exporter dependency conflict + "mypy -p src --exclude src/strands/multiagent" ] lint-fix = [ "ruff check --fix" @@ -137,17 +138,29 @@ features = ["dev", "docs", "anthropic", "litellm", "llamaapi", "ollama", "otel"] dev-mode = true features = ["dev", "docs", "anthropic", "litellm", "llamaapi", "ollama", "a2a"] +[tool.hatch.envs.a2a.scripts] +run = [ + "pytest{env:HATCH_TEST_ARGS:} tests/multiagent/a2a {args}" +] +run-cov = [ + "pytest{env:HATCH_TEST_ARGS:} tests/multiagent/a2a --cov --cov-config=pyproject.toml {args}" +] +lint-check = [ + "ruff check", + "mypy -p src/strands/multiagent/a2a" +] [[tool.hatch.envs.hatch-test.matrix]] python = ["3.13", "3.12", "3.11", "3.10"] - [tool.hatch.envs.hatch-test.scripts] run = [ - "pytest{env:HATCH_TEST_ARGS:} {args}" + # excluding due to A2A and OTEL http exporter dependency conflict + "pytest{env:HATCH_TEST_ARGS:} {args} --ignore=tests/multiagent/a2a" ] run-cov = [ - "pytest{env:HATCH_TEST_ARGS:} --cov --cov-config=pyproject.toml {args}" + # excluding due to A2A and OTEL http exporter dependency conflict + "pytest{env:HATCH_TEST_ARGS:} --cov --cov-config=pyproject.toml {args} --ignore=tests/multiagent/a2a" ] cov-combine = [] @@ -181,6 +194,10 @@ prepare = [ "hatch fmt --formatter", "hatch test --all" ] +test-a2a = [ + # required to run manually due to A2A and OTEL http exporter dependency conflict + "hatch -e a2a run run {args}" +] [tool.mypy] python_version = "3.10" diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index a5e26a07..4ecc42a9 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -220,6 +220,9 @@ def __init__( record_direct_tool_call: bool = True, load_tools_from_directory: bool = True, trace_attributes: Optional[Mapping[str, AttributeValue]] = None, + *, + name: Optional[str] = None, + description: Optional[str] = None, ): """Initialize the Agent with the specified configuration. @@ -252,6 +255,10 @@ def __init__( load_tools_from_directory: Whether to load and automatically reload tools in the `./tools/` directory. Defaults to True. trace_attributes: Custom trace attributes to apply to the agent's trace span. + name: name of the Agent + Defaults to None. + description: description of what the Agent does + Defaults to None. Raises: ValueError: If max_parallel_tools is less than 1. @@ -313,6 +320,8 @@ def __init__( self.tracer = get_tracer() self.trace_span: Optional[trace.Span] = None self.tool_caller = Agent.ToolCaller(self) + self.name = name + self.description = description @property def tool(self) -> ToolCaller: diff --git a/src/strands/multiagent/__init__.py b/src/strands/multiagent/__init__.py new file mode 100644 index 00000000..1cef1425 --- /dev/null +++ b/src/strands/multiagent/__init__.py @@ -0,0 +1,13 @@ +"""Multiagent capabilities for Strands Agents. + +This module provides support for multiagent systems, including agent-to-agent (A2A) +communication protocols and coordination mechanisms. + +Submodules: + a2a: Implementation of the Agent-to-Agent (A2A) protocol, which enables + standardized communication between agents. +""" + +from . import a2a + +__all__ = ["a2a"] diff --git a/src/strands/multiagent/a2a/__init__.py b/src/strands/multiagent/a2a/__init__.py new file mode 100644 index 00000000..c5425618 --- /dev/null +++ b/src/strands/multiagent/a2a/__init__.py @@ -0,0 +1,14 @@ +"""Agent-to-Agent (A2A) communication protocol implementation for Strands Agents. 
+ +This module provides classes and utilities for enabling Strands Agents to communicate +with other agents using the Agent-to-Agent (A2A) protocol. + +Docs: https://google-a2a.github.io/A2A/latest/ + +Classes: + A2AAgent: A wrapper that adapts a Strands Agent to be A2A-compatible. +""" + +from .agent import A2AAgent + +__all__ = ["A2AAgent"] diff --git a/src/strands/multiagent/a2a/agent.py b/src/strands/multiagent/a2a/agent.py new file mode 100644 index 00000000..4359100d --- /dev/null +++ b/src/strands/multiagent/a2a/agent.py @@ -0,0 +1,149 @@ +"""A2A-compatible wrapper for Strands Agent. + +This module provides the A2AAgent class, which adapts a Strands Agent to the A2A protocol, +allowing it to be used in A2A-compatible systems. +""" + +import logging +from typing import Any, Literal + +import uvicorn +from a2a.server.apps import A2AFastAPIApplication, A2AStarletteApplication +from a2a.server.request_handlers import DefaultRequestHandler +from a2a.server.tasks import InMemoryTaskStore +from a2a.types import AgentCapabilities, AgentCard, AgentSkill +from fastapi import FastAPI +from starlette.applications import Starlette + +from ...agent.agent import Agent as SAAgent +from .executor import StrandsA2AExecutor + +logger = logging.getLogger(__name__) + + +class A2AAgent: + """A2A-compatible wrapper for Strands Agent.""" + + def __init__( + self, + agent: SAAgent, + *, + # AgentCard + host: str = "0.0.0.0", + port: int = 9000, + version: str = "0.0.1", + ): + """Initialize an A2A-compatible agent from a Strands agent. + + Args: + agent: The Strands Agent to wrap with A2A compatibility. + name: The name of the agent, used in the AgentCard. + description: A description of the agent's capabilities, used in the AgentCard. + host: The hostname or IP address to bind the A2A server to. Defaults to "0.0.0.0". + port: The port to bind the A2A server to. Defaults to 9000. + version: The version of the agent. Defaults to "0.0.1". + """ + self.host = host + self.port = port + self.http_url = f"http://{self.host}:{self.port}/" + self.version = version + self.strands_agent = agent + self.name = self.strands_agent.name + self.description = self.strands_agent.description + # TODO: enable configurable capabilities and request handler + self.capabilities = AgentCapabilities() + self.request_handler = DefaultRequestHandler( + agent_executor=StrandsA2AExecutor(self.strands_agent), + task_store=InMemoryTaskStore(), + ) + logger.info("Strands' integration with A2A is experimental. Be aware of frequent breaking changes.") + + @property + def public_agent_card(self) -> AgentCard: + """Get the public AgentCard for this agent. + + The AgentCard contains metadata about the agent, including its name, + description, URL, version, skills, and capabilities. This information + is used by other agents and systems to discover and interact with this agent. + + Returns: + AgentCard: The public agent card containing metadata about this agent. + + Raises: + ValueError: If name or description is None or empty. 
+ """ + if not self.name: + raise ValueError("A2A agent name cannot be None or empty") + if not self.description: + raise ValueError("A2A agent description cannot be None or empty") + + return AgentCard( + name=self.name, + description=self.description, + url=self.http_url, + version=self.version, + skills=self.agent_skills, + defaultInputModes=["text"], + defaultOutputModes=["text"], + capabilities=self.capabilities, + ) + + @property + def agent_skills(self) -> list[AgentSkill]: + """Get the list of skills this agent provides. + + Skills represent specific capabilities that the agent can perform. + Strands agent tools are adapted to A2A skills. + + Returns: + list[AgentSkill]: A list of skills this agent provides. + """ + # TODO: translate Strands tools (native & MCP) to skills + return [] + + def to_starlette_app(self) -> Starlette: + """Create a Starlette application for serving this agent via HTTP. + + This method creates a Starlette application that can be used to serve + the agent via HTTP using the A2A protocol. + + Returns: + Starlette: A Starlette application configured to serve this agent. + """ + return A2AStarletteApplication(agent_card=self.public_agent_card, http_handler=self.request_handler).build() + + def to_fastapi_app(self) -> FastAPI: + """Create a FastAPI application for serving this agent via HTTP. + + This method creates a FastAPI application that can be used to serve + the agent via HTTP using the A2A protocol. + + Returns: + FastAPI: A FastAPI application configured to serve this agent. + """ + return A2AFastAPIApplication(agent_card=self.public_agent_card, http_handler=self.request_handler).build() + + def serve(self, app_type: Literal["fastapi", "starlette"] = "starlette", **kwargs: Any) -> None: + """Start the A2A server with the specified application type. + + This method starts an HTTP server that exposes the agent via the A2A protocol. + The server can be implemented using either FastAPI or Starlette, depending on + the specified app_type. + + Args: + app_type: The type of application to serve, either "fastapi" or "starlette". + Defaults to "starlette". + **kwargs: Additional keyword arguments to pass to uvicorn.run. + """ + try: + logger.info("Starting Strands A2A server...") + if app_type == "fastapi": + uvicorn.run(self.to_fastapi_app(), host=self.host, port=self.port, **kwargs) + else: + uvicorn.run(self.to_starlette_app(), host=self.host, port=self.port, **kwargs) + except KeyboardInterrupt: + logger.warning("Strands A2A server shutdown requested (KeyboardInterrupt).") + except Exception: + logger.exception("Strands A2A server encountered exception.") + finally: + logger.info("Strands A2A server has shutdown.") diff --git a/src/strands/multiagent/a2a/executor.py b/src/strands/multiagent/a2a/executor.py new file mode 100644 index 00000000..b7a7af09 --- /dev/null +++ b/src/strands/multiagent/a2a/executor.py @@ -0,0 +1,67 @@ +"""Strands Agent executor for the A2A protocol. + +This module provides the StrandsA2AExecutor class, which adapts a Strands Agent +to be used as an executor in the A2A protocol. It handles the execution of agent +requests and the conversion of Strands Agent responses to A2A events. 
+""" + +import logging + +from a2a.server.agent_execution import AgentExecutor, RequestContext +from a2a.server.events import EventQueue +from a2a.types import UnsupportedOperationError +from a2a.utils import new_agent_text_message +from a2a.utils.errors import ServerError + +from ...agent.agent import Agent as SAAgent +from ...agent.agent_result import AgentResult as SAAgentResult + +log = logging.getLogger(__name__) + + +class StrandsA2AExecutor(AgentExecutor): + """Executor that adapts a Strands Agent to the A2A protocol.""" + + def __init__(self, agent: SAAgent): + """Initialize a StrandsA2AExecutor. + + Args: + agent: The Strands Agent to adapt to the A2A protocol. + """ + self.agent = agent + + async def execute( + self, + context: RequestContext, + event_queue: EventQueue, + ) -> None: + """Execute a request using the Strands Agent and send the response as A2A events. + + This method executes the user's input using the Strands Agent and converts + the agent's response to A2A events, which are then sent to the event queue. + + Args: + context: The A2A request context, containing the user's input and other metadata. + event_queue: The A2A event queue, used to send response events. + """ + result: SAAgentResult = self.agent(context.get_user_input()) + if result.message and "content" in result.message: + for content_block in result.message["content"]: + if "text" in content_block: + await event_queue.enqueue_event(new_agent_text_message(content_block["text"])) + + async def cancel(self, context: RequestContext, event_queue: EventQueue) -> None: + """Cancel an ongoing execution. + + This method is called when a request is cancelled. Currently, cancellation + is not supported, so this method raises an UnsupportedOperationError. + + Args: + context: The A2A request context. + event_queue: The A2A event queue. + + Raises: + ServerError: Always raised with an UnsupportedOperationError, as cancellation + is not currently supported. 
+        """
+        raise ServerError(error=UnsupportedOperationError())
diff --git a/tests/multiagent/__init__.py b/tests/multiagent/__init__.py
new file mode 100644
index 00000000..b43bae53
--- /dev/null
+++ b/tests/multiagent/__init__.py
@@ -0,0 +1 @@
+"""Tests for the multiagent module."""
diff --git a/tests/multiagent/a2a/__init__.py b/tests/multiagent/a2a/__init__.py
new file mode 100644
index 00000000..eb5487d9
--- /dev/null
+++ b/tests/multiagent/a2a/__init__.py
@@ -0,0 +1 @@
+"""Tests for the A2A module."""
diff --git a/tests/multiagent/a2a/conftest.py b/tests/multiagent/a2a/conftest.py
new file mode 100644
index 00000000..558a4594
--- /dev/null
+++ b/tests/multiagent/a2a/conftest.py
@@ -0,0 +1,41 @@
+"""Common fixtures for A2A module tests."""
+
+from unittest.mock import AsyncMock, MagicMock
+
+import pytest
+from a2a.server.agent_execution import RequestContext
+from a2a.server.events import EventQueue
+
+from strands.agent.agent import Agent as SAAgent
+from strands.agent.agent_result import AgentResult as SAAgentResult
+
+
+@pytest.fixture
+def mock_strands_agent():
+    """Create a mock Strands Agent for testing."""
+    agent = MagicMock(spec=SAAgent)
+    agent.name = "Test Agent"
+    agent.description = "A test agent for unit testing"
+
+    # Setup default response
+    mock_result = MagicMock(spec=SAAgentResult)
+    mock_result.message = {"content": [{"text": "Test response"}]}
+    agent.return_value = mock_result
+
+    return agent
+
+
+@pytest.fixture
+def mock_request_context():
+    """Create a mock RequestContext for testing."""
+    context = MagicMock(spec=RequestContext)
+    context.get_user_input.return_value = "Test input"
+    return context
+
+
+@pytest.fixture
+def mock_event_queue():
+    """Create a mock EventQueue for testing."""
+    queue = MagicMock(spec=EventQueue)
+    queue.enqueue_event = AsyncMock()
+    return queue
diff --git a/tests/multiagent/a2a/test_agent.py b/tests/multiagent/a2a/test_agent.py
new file mode 100644
index 00000000..5558c2af
--- /dev/null
+++ b/tests/multiagent/a2a/test_agent.py
@@ -0,0 +1,165 @@
+"""Tests for the A2AAgent class."""
+
+from unittest.mock import patch
+
+import pytest
+from a2a.types import AgentCapabilities, AgentCard
+from fastapi import FastAPI
+from starlette.applications import Starlette
+
+from strands.multiagent.a2a.agent import A2AAgent
+
+
+def test_a2a_agent_initialization(mock_strands_agent):
+    """Test that A2AAgent initializes correctly with default values."""
+    a2a_agent = A2AAgent(mock_strands_agent)
+
+    assert a2a_agent.strands_agent == mock_strands_agent
+    assert a2a_agent.name == "Test Agent"
+    assert a2a_agent.description == "A test agent for unit testing"
+    assert a2a_agent.host == "0.0.0.0"
+    assert a2a_agent.port == 9000
+    assert a2a_agent.http_url == "http://0.0.0.0:9000/"
+    assert a2a_agent.version == "0.0.1"
+    assert isinstance(a2a_agent.capabilities, AgentCapabilities)
+
+
+def test_a2a_agent_initialization_with_custom_values(mock_strands_agent):
+    """Test that A2AAgent initializes correctly with custom values."""
+    a2a_agent = A2AAgent(
+        mock_strands_agent,
+        host="127.0.0.1",
+        port=8080,
+        version="1.0.0",
+    )
+
+    assert a2a_agent.host == "127.0.0.1"
+    assert a2a_agent.port == 8080
+    assert a2a_agent.http_url == "http://127.0.0.1:8080/"
+    assert a2a_agent.version == "1.0.0"
+
+
+def test_public_agent_card(mock_strands_agent):
+    """Test that public_agent_card returns a valid AgentCard."""
+    a2a_agent = A2AAgent(mock_strands_agent)
+
+    card = a2a_agent.public_agent_card
+
+    assert isinstance(card, AgentCard)
+    assert card.name == "Test Agent"
+    assert card.description == "A test agent for unit testing"
+    assert card.url == "http://0.0.0.0:9000/"
+    assert card.version == "0.0.1"
+    assert card.defaultInputModes == ["text"]
+    assert card.defaultOutputModes == ["text"]
+    assert card.skills == []
+    assert card.capabilities == a2a_agent.capabilities
+
+
+def test_public_agent_card_with_missing_name(mock_strands_agent):
+    """Test that public_agent_card raises ValueError when name is missing."""
+    mock_strands_agent.name = ""
+    a2a_agent = A2AAgent(mock_strands_agent)
+
+    with pytest.raises(ValueError, match="A2A agent name cannot be None or empty"):
+        _ = a2a_agent.public_agent_card
+
+
+def test_public_agent_card_with_missing_description(mock_strands_agent):
+    """Test that public_agent_card raises ValueError when description is missing."""
+    mock_strands_agent.description = ""
+    a2a_agent = A2AAgent(mock_strands_agent)
+
+    with pytest.raises(ValueError, match="A2A agent description cannot be None or empty"):
+        _ = a2a_agent.public_agent_card
+
+
+def test_agent_skills(mock_strands_agent):
+    """Test that agent_skills returns an empty list (current implementation)."""
+    a2a_agent = A2AAgent(mock_strands_agent)
+
+    skills = a2a_agent.agent_skills
+
+    assert isinstance(skills, list)
+    assert len(skills) == 0
+
+
+def test_to_starlette_app(mock_strands_agent):
+    """Test that to_starlette_app returns a Starlette application."""
+    a2a_agent = A2AAgent(mock_strands_agent)
+
+    app = a2a_agent.to_starlette_app()
+
+    assert isinstance(app, Starlette)
+
+
+def test_to_fastapi_app(mock_strands_agent):
+    """Test that to_fastapi_app returns a FastAPI application."""
+    a2a_agent = A2AAgent(mock_strands_agent)
+
+    app = a2a_agent.to_fastapi_app()
+
+    assert isinstance(app, FastAPI)
+
+
+@patch("uvicorn.run")
+def test_serve_with_starlette(mock_run, mock_strands_agent):
+    """Test that serve starts a Starlette server by default."""
+    a2a_agent = A2AAgent(mock_strands_agent)
+
+    a2a_agent.serve()
+
+    mock_run.assert_called_once()
+    args, kwargs = mock_run.call_args
+    assert isinstance(args[0], Starlette)
+    assert kwargs["host"] == "0.0.0.0"
+    assert kwargs["port"] == 9000
+
+
+@patch("uvicorn.run")
+def test_serve_with_fastapi(mock_run, mock_strands_agent):
+    """Test that serve starts a FastAPI server when specified."""
+    a2a_agent = A2AAgent(mock_strands_agent)
+
+    a2a_agent.serve(app_type="fastapi")
+
+    mock_run.assert_called_once()
+    args, kwargs = mock_run.call_args
+    assert isinstance(args[0], FastAPI)
+    assert kwargs["host"] == "0.0.0.0"
+    assert kwargs["port"] == 9000
+
+
+@patch("uvicorn.run")
+def test_serve_with_custom_kwargs(mock_run, mock_strands_agent):
+    """Test that serve passes additional kwargs to uvicorn.run."""
+    a2a_agent = A2AAgent(mock_strands_agent)
+
+    a2a_agent.serve(log_level="debug", reload=True)
+
+    mock_run.assert_called_once()
+    _, kwargs = mock_run.call_args
+    assert kwargs["log_level"] == "debug"
+    assert kwargs["reload"] is True
+
+
+@patch("uvicorn.run", side_effect=KeyboardInterrupt)
+def test_serve_handles_keyboard_interrupt(mock_run, mock_strands_agent, caplog):
+    """Test that serve handles KeyboardInterrupt gracefully."""
+    a2a_agent = A2AAgent(mock_strands_agent)
+
+    a2a_agent.serve()
+
+    assert "Strands A2A server shutdown requested (KeyboardInterrupt)" in caplog.text
+    assert "Strands A2A server has shutdown" in caplog.text
+
+
+@patch("uvicorn.run", side_effect=Exception("Test exception"))
+def test_serve_handles_general_exception(mock_run, mock_strands_agent, caplog):
+    """Test that serve handles general exceptions
gracefully.""" + a2a_agent = A2AAgent(mock_strands_agent) + + a2a_agent.serve() + + assert "Strands A2A server encountered exception" in caplog.text + assert "Strands A2A server has shutdown" in caplog.text diff --git a/tests/multiagent/a2a/test_executor.py b/tests/multiagent/a2a/test_executor.py new file mode 100644 index 00000000..2ac9bed9 --- /dev/null +++ b/tests/multiagent/a2a/test_executor.py @@ -0,0 +1,118 @@ +"""Tests for the StrandsA2AExecutor class.""" + +from unittest.mock import MagicMock + +import pytest +from a2a.types import UnsupportedOperationError +from a2a.utils.errors import ServerError + +from strands.agent.agent_result import AgentResult as SAAgentResult +from strands.multiagent.a2a.executor import StrandsA2AExecutor + + +def test_executor_initialization(mock_strands_agent): + """Test that StrandsA2AExecutor initializes correctly.""" + executor = StrandsA2AExecutor(mock_strands_agent) + + assert executor.agent == mock_strands_agent + + +@pytest.mark.asyncio +async def test_execute_with_text_response(mock_strands_agent, mock_request_context, mock_event_queue): + """Test that execute processes text responses correctly.""" + # Setup mock agent response + mock_result = MagicMock(spec=SAAgentResult) + mock_result.message = {"content": [{"text": "Test response"}]} + mock_strands_agent.return_value = mock_result + + # Create executor and call execute + executor = StrandsA2AExecutor(mock_strands_agent) + await executor.execute(mock_request_context, mock_event_queue) + + # Verify agent was called with correct input + mock_strands_agent.assert_called_once_with("Test input") + + # Verify event was enqueued + mock_event_queue.enqueue_event.assert_called_once() + args, _ = mock_event_queue.enqueue_event.call_args + event = args[0] + assert event.parts[0].root.text == "Test response" + + +@pytest.mark.asyncio +async def test_execute_with_multiple_text_blocks(mock_strands_agent, mock_request_context, mock_event_queue): + """Test that execute processes multiple text blocks correctly.""" + # Setup mock agent response with multiple text blocks + mock_result = MagicMock(spec=SAAgentResult) + mock_result.message = {"content": [{"text": "First response"}, {"text": "Second response"}]} + mock_strands_agent.return_value = mock_result + + # Create executor and call execute + executor = StrandsA2AExecutor(mock_strands_agent) + await executor.execute(mock_request_context, mock_event_queue) + + # Verify agent was called with correct input + mock_strands_agent.assert_called_once_with("Test input") + + # Verify events were enqueued + assert mock_event_queue.enqueue_event.call_count == 2 + + # Check first event + args1, _ = mock_event_queue.enqueue_event.call_args_list[0] + event1 = args1[0] + assert event1.parts[0].root.text == "First response" + + # Check second event + args2, _ = mock_event_queue.enqueue_event.call_args_list[1] + event2 = args2[0] + assert event2.parts[0].root.text == "Second response" + + +@pytest.mark.asyncio +async def test_execute_with_empty_response(mock_strands_agent, mock_request_context, mock_event_queue): + """Test that execute handles empty responses correctly.""" + # Setup mock agent response with empty content + mock_result = MagicMock(spec=SAAgentResult) + mock_result.message = {"content": []} + mock_strands_agent.return_value = mock_result + + # Create executor and call execute + executor = StrandsA2AExecutor(mock_strands_agent) + await executor.execute(mock_request_context, mock_event_queue) + + # Verify agent was called with correct input + 
mock_strands_agent.assert_called_once_with("Test input") + + # Verify no events were enqueued + mock_event_queue.enqueue_event.assert_not_called() + + +@pytest.mark.asyncio +async def test_execute_with_no_message(mock_strands_agent, mock_request_context, mock_event_queue): + """Test that execute handles responses with no message correctly.""" + # Setup mock agent response with no message + mock_result = MagicMock(spec=SAAgentResult) + mock_result.message = None + mock_strands_agent.return_value = mock_result + + # Create executor and call execute + executor = StrandsA2AExecutor(mock_strands_agent) + await executor.execute(mock_request_context, mock_event_queue) + + # Verify agent was called with correct input + mock_strands_agent.assert_called_once_with("Test input") + + # Verify no events were enqueued + mock_event_queue.enqueue_event.assert_not_called() + + +@pytest.mark.asyncio +async def test_cancel_raises_unsupported_operation_error(mock_strands_agent, mock_request_context, mock_event_queue): + """Test that cancel raises UnsupportedOperationError.""" + executor = StrandsA2AExecutor(mock_strands_agent) + + with pytest.raises(ServerError) as excinfo: + await executor.cancel(mock_request_context, mock_event_queue) + + # Verify the error is a ServerError containing an UnsupportedOperationError + assert isinstance(excinfo.value.error, UnsupportedOperationError)
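
A minimal usage sketch of the A2AAgent wrapper added in this patch, assuming the package is
installed with the "a2a" extra. The agent name, description, host, and port below are
illustrative placeholders, and the underlying Agent's model and tool configuration is omitted;
treat this as a sketch of the API shown above, not a definitive deployment recipe.

    # Illustrative only; names and network settings are placeholder assumptions.
    from strands.agent.agent import Agent
    from strands.multiagent.a2a import A2AAgent

    # The wrapped agent must carry a non-empty name and description,
    # otherwise public_agent_card raises ValueError.
    agent = Agent(
        name="Example Agent",                       # placeholder
        description="Answers plain-text questions",  # placeholder
    )

    # host, port, and version are keyword-only; defaults are 0.0.0.0, 9000, "0.0.1".
    a2a_agent = A2AAgent(agent, host="127.0.0.1", port=9000)

    # Either run a blocking uvicorn server directly...
    a2a_agent.serve(app_type="fastapi")

    # ...or build an ASGI app to mount in an existing deployment:
    # app = a2a_agent.to_starlette_app()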