diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index a6efda65..fa894231 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -1,37 +1,38 @@ ## Description -[Provide a detailed description of the changes in this PR] + ## Related Issues -[Link to related issues using #issue-number format] + + ## Documentation PR -[Link to related associated PR in the agent-docs repo] + + ## Type of Change -- Bug fix -- New feature -- Breaking change -- Documentation update -- Other (please describe): -[Choose one of the above types of changes] + +Bug fix +New feature +Breaking change +Documentation update +Other (please describe): ## Testing -[How have you tested the change?] -* `hatch fmt --linter` -* `hatch fmt --formatter` -* `hatch test --all` -* Verify that the changes do not break functionality or introduce warnings in consuming repositories: agents-docs, agents-tools, agents-cli +How have you tested the change? Verify that the changes do not break functionality or introduce warnings in consuming repositories: agents-docs, agents-tools, agents-cli +- [ ] I ran `hatch run prepare` ## Checklist - [ ] I have read the CONTRIBUTING document -- [ ] I have added tests that prove my fix is effective or my feature works +- [ ] I have added any necessary tests that prove my fix is effective or my feature works - [ ] I have updated the documentation accordingly -- [ ] I have added an appropriate example to the documentation to outline the feature +- [ ] I have added an appropriate example to the documentation to outline the feature, or no new docs are needed - [ ] My changes generate no new warnings - [ ] Any dependent changes have been merged and published +---- + By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice. diff --git a/.github/workflows/integration-test.yml b/.github/workflows/integration-test.yml new file mode 100644 index 00000000..39b53c49 --- /dev/null +++ b/.github/workflows/integration-test.yml @@ -0,0 +1,74 @@ +name: Secure Integration test + +on: + pull_request_target: + branches: main + +jobs: + authorization-check: + permissions: read-all + runs-on: ubuntu-latest + outputs: + approval-env: ${{ steps.collab-check.outputs.result }} + steps: + - name: Collaborator Check + uses: actions/github-script@v7 + id: collab-check + with: + result-encoding: string + script: | + try { + const permissionResponse = await github.rest.repos.getCollaboratorPermissionLevel({ + owner: context.repo.owner, + repo: context.repo.repo, + username: context.payload.pull_request.user.login, + }); + const permission = permissionResponse.data.permission; + const hasWriteAccess = ['write', 'admin'].includes(permission); + if (!hasWriteAccess) { + console.log(`User ${context.payload.pull_request.user.login} does not have write access to the repository (permission: ${permission})`); + return "manual-approval" + } else { + console.log(`Verifed ${context.payload.pull_request.user.login} has write access. Auto Approving PR Checks.`) + return "auto-approve" + } + } catch (error) { + console.log(`${context.payload.pull_request.user.login} does not have write access. 
Requiring Manual Approval to run PR Checks.`) + return "manual-approval" + } + check-access-and-checkout: + runs-on: ubuntu-latest + needs: authorization-check + environment: ${{ needs.authorization-check.outputs.approval-env }} + permissions: + id-token: write + pull-requests: read + contents: read + steps: + - name: Configure Credentials + uses: aws-actions/configure-aws-credentials@v4 + with: + role-to-assume: ${{ secrets.STRANDS_INTEG_TEST_ROLE }} + aws-region: us-east-1 + mask-aws-account-id: true + - name: Checkout head commit + uses: actions/checkout@v4 + with: + ref: ${{ github.event.pull_request.head.sha }} # Pull the commit from the forked repo + persist-credentials: false # Don't persist credentials for subsequent actions + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.10' + - name: Install dependencies + run: | + pip install --no-cache-dir hatch + - name: Run integration tests + env: + AWS_REGION: us-east-1 + AWS_REGION_NAME: us-east-1 # Needed for LiteLLM + id: tests + run: | + hatch test tests-integ + + diff --git a/pyproject.toml b/pyproject.toml index bd309732..56c1a40e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,14 +28,13 @@ classifiers = [ dependencies = [ "boto3>=1.26.0,<2.0.0", "botocore>=1.29.0,<2.0.0", - "docstring_parser>=0.15,<0.16.0", + "docstring_parser>=0.15,<1.0", "mcp>=1.8.0,<2.0.0", "pydantic>=2.0.0,<3.0.0", "typing-extensions>=4.13.2,<5.0.0", "watchdog>=6.0.0,<7.0.0", "opentelemetry-api>=1.30.0,<2.0.0", "opentelemetry-sdk>=1.30.0,<2.0.0", - "opentelemetry-exporter-otlp-proto-http>=1.30.0,<2.0.0", ] [project.urls] @@ -59,7 +58,6 @@ dev = [ "pytest>=8.0.0,<9.0.0", "pytest-asyncio>=0.26.0,<0.27.0", "ruff>=0.4.4,<0.5.0", - "swagger-parser>=1.0.2,<2.0.0", ] docs = [ "sphinx>=5.0.0,<6.0.0", @@ -78,13 +76,23 @@ ollama = [ openai = [ "openai>=1.68.0,<2.0.0", ] +otel = [ + "opentelemetry-exporter-otlp-proto-http>=1.30.0,<2.0.0", +] +a2a = [ + "a2a-sdk>=0.2.6", + "uvicorn>=0.34.2", + "httpx>=0.28.1", + "fastapi>=0.115.12", + "starlette>=0.46.2", +] [tool.hatch.version] # Tells Hatch to use your version control system (git) to determine the version. 
source = "vcs" [tool.hatch.envs.hatch-static-analysis] -features = ["anthropic", "litellm", "llamaapi", "ollama", "openai"] +features = ["anthropic", "litellm", "llamaapi", "ollama", "openai", "otel"] dependencies = [ "mypy>=1.15.0,<2.0.0", "ruff>=0.11.6,<0.12.0", @@ -107,7 +115,7 @@ lint-fix = [ ] [tool.hatch.envs.hatch-test] -features = ["anthropic", "litellm", "llamaapi", "ollama", "openai"] +features = ["anthropic", "litellm", "llamaapi", "ollama", "openai", "otel"] extra-dependencies = [ "moto>=5.1.0,<6.0.0", "pytest>=8.0.0,<9.0.0", @@ -123,8 +131,11 @@ extra-args = [ [tool.hatch.envs.dev] dev-mode = true -features = ["dev", "docs", "anthropic", "litellm", "llamaapi", "ollama"] +features = ["dev", "docs", "anthropic", "litellm", "llamaapi", "ollama", "otel"] +[tool.hatch.envs.a2a] +dev-mode = true +features = ["dev", "docs", "anthropic", "litellm", "llamaapi", "ollama", "a2a"] [[tool.hatch.envs.hatch-test.matrix]] @@ -165,7 +176,11 @@ test = [ test-integ = [ "hatch test tests-integ {args}" ] - +prepare = [ + "hatch fmt --linter", + "hatch fmt --formatter", + "hatch test --all" +] [tool.mypy] python_version = "3.10" diff --git a/src/strands/agent/__init__.py b/src/strands/agent/__init__.py index 4d2fa1fe..6618d332 100644 --- a/src/strands/agent/__init__.py +++ b/src/strands/agent/__init__.py @@ -8,7 +8,12 @@ from .agent import Agent from .agent_result import AgentResult -from .conversation_manager import ConversationManager, NullConversationManager, SlidingWindowConversationManager +from .conversation_manager import ( + ConversationManager, + NullConversationManager, + SlidingWindowConversationManager, + SummarizingConversationManager, +) __all__ = [ "Agent", @@ -16,4 +21,5 @@ "ConversationManager", "NullConversationManager", "SlidingWindowConversationManager", + "SummarizingConversationManager", ] diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 0651d452..56f5b92e 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -84,6 +84,7 @@ def __getattr__(self, name: str) -> Callable[..., Any]: """Call tool as a function. This method enables the method-style interface (e.g., `agent.tool.tool_name(param="value")`). + It matches underscore-separated names to hyphenated tool names (e.g., 'some_thing' matches 'some-thing'). Args: name: The name of the attribute (tool) being accessed. @@ -92,9 +93,30 @@ def __getattr__(self, name: str) -> Callable[..., Any]: A function that when called will execute the named tool. Raises: - AttributeError: If no tool with the given name exists. + AttributeError: If no tool with the given name exists or if multiple tools match the given name. """ + def find_normalized_tool_name() -> Optional[str]: + """Lookup the tool represented by name, replacing characters with underscores as necessary.""" + tool_registry = self._agent.tool_registry.registry + + if tool_registry.get(name, None): + return name + + # If the desired name contains underscores, it might be a placeholder for characters that can't be + # represented as python identifiers but are valid as tool names, such as dashes. 
In that case, find + # all tools that can be represented with the normalized name + if "_" in name: + filtered_tools = [ + tool_name for (tool_name, tool) in tool_registry.items() if tool_name.replace("-", "_") == name + ] + + # The registry itself defends against similar names, so we can just take the first match + if filtered_tools: + return filtered_tools[0] + + raise AttributeError(f"Tool '{name}' not found") + def caller(**kwargs: Any) -> Any: """Call a tool directly by name. @@ -115,14 +137,13 @@ def caller(**kwargs: Any) -> Any: Raises: AttributeError: If the tool doesn't exist. """ - if name not in self._agent.tool_registry.registry: - raise AttributeError(f"Tool '{name}' not found") + normalized_name = find_normalized_tool_name() # Create unique tool ID and set up the tool request tool_id = f"tooluse_{name}_{random.randint(100000000, 999999999)}" tool_use = { "toolUseId": tool_id, - "name": name, + "name": normalized_name, "input": kwargs.copy(), } diff --git a/src/strands/agent/conversation_manager/__init__.py b/src/strands/agent/conversation_manager/__init__.py index 68541877..c5962321 100644 --- a/src/strands/agent/conversation_manager/__init__.py +++ b/src/strands/agent/conversation_manager/__init__.py @@ -6,6 +6,8 @@ - NullConversationManager: A no-op implementation that does not modify conversation history - SlidingWindowConversationManager: An implementation that maintains a sliding window of messages to control context size while preserving conversation coherence +- SummarizingConversationManager: An implementation that summarizes older context instead + of simply trimming it Conversation managers help control memory usage and context length while maintaining relevant conversation state, which is critical for effective agent interactions. @@ -14,5 +16,11 @@ from .conversation_manager import ConversationManager from .null_conversation_manager import NullConversationManager from .sliding_window_conversation_manager import SlidingWindowConversationManager +from .summarizing_conversation_manager import SummarizingConversationManager -__all__ = ["ConversationManager", "NullConversationManager", "SlidingWindowConversationManager"] +__all__ = [ + "ConversationManager", + "NullConversationManager", + "SlidingWindowConversationManager", + "SummarizingConversationManager", +] diff --git a/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py b/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py index 3381247c..53ac374f 100644 --- a/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py +++ b/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py @@ -44,14 +44,16 @@ class SlidingWindowConversationManager(ConversationManager): invalid window states. """ - def __init__(self, window_size: int = 40): + def __init__(self, window_size: int = 40, should_truncate_results: bool = True): """Initialize the sliding window conversation manager. Args: window_size: Maximum number of messages to keep in the agent's history. Defaults to 40 messages. + should_truncate_results: Truncate tool results when a message is too large for the model's context window """ self.window_size = window_size + self.should_truncate_results = should_truncate_results def apply_management(self, agent: "Agent") -> None: """Apply the sliding window to the agent's messages array to maintain a manageable history size. @@ -127,6 +129,19 @@ def reduce_context(self, agent: "Agent", e: Optional[Exception] = None) -> None: converted. 
""" messages = agent.messages + + # Try to truncate the tool result first + last_message_idx_with_tool_results = self._find_last_message_with_tool_results(messages) + if last_message_idx_with_tool_results is not None and self.should_truncate_results: + logger.debug( + "message_index=<%s> | found message with tool results at index", last_message_idx_with_tool_results + ) + results_truncated = self._truncate_tool_results(messages, last_message_idx_with_tool_results) + if results_truncated: + logger.debug("message_index=<%s> | tool results truncated", last_message_idx_with_tool_results) + return + + # Try to trim index id when tool result cannot be truncated anymore # If the number of messages is less than the window_size, then we default to 2, otherwise, trim to window size trim_index = 2 if len(messages) <= self.window_size else len(messages) - self.window_size @@ -151,3 +166,69 @@ def reduce_context(self, agent: "Agent", e: Optional[Exception] = None) -> None: # Overwrite message history messages[:] = messages[trim_index:] + + def _truncate_tool_results(self, messages: Messages, msg_idx: int) -> bool: + """Truncate tool results in a message to reduce context size. + + When a message contains tool results that are too large for the model's context window, this function + replaces the content of those tool results with a simple error message. + + Args: + messages: The conversation message history. + msg_idx: Index of the message containing tool results to truncate. + + Returns: + True if any changes were made to the message, False otherwise. + """ + if msg_idx >= len(messages) or msg_idx < 0: + return False + + message = messages[msg_idx] + changes_made = False + tool_result_too_large_message = "The tool result was too large!" + for i, content in enumerate(message.get("content", [])): + if isinstance(content, dict) and "toolResult" in content: + tool_result_content_text = next( + (item["text"] for item in content["toolResult"]["content"] if "text" in item), + "", + ) + # make the overwriting logic togglable + if ( + message["content"][i]["toolResult"]["status"] == "error" + and tool_result_content_text == tool_result_too_large_message + ): + logger.info("ToolResult has already been updated, skipping overwrite") + return False + # Update status to error with informative message + message["content"][i]["toolResult"]["status"] = "error" + message["content"][i]["toolResult"]["content"] = [{"text": tool_result_too_large_message}] + changes_made = True + + return changes_made + + def _find_last_message_with_tool_results(self, messages: Messages) -> Optional[int]: + """Find the index of the last message containing tool results. + + This is useful for identifying messages that might need to be truncated to reduce context size. + + Args: + messages: The conversation message history. + + Returns: + Index of the last message with tool results, or None if no such message exists. 
+ """ + # Iterate backwards through all messages (from newest to oldest) + for idx in range(len(messages) - 1, -1, -1): + # Check if this message has any content with toolResult + current_message = messages[idx] + has_tool_result = False + + for content in current_message.get("content", []): + if isinstance(content, dict) and "toolResult" in content: + has_tool_result = True + break + + if has_tool_result: + return idx + + return None diff --git a/src/strands/agent/conversation_manager/summarizing_conversation_manager.py b/src/strands/agent/conversation_manager/summarizing_conversation_manager.py new file mode 100644 index 00000000..a6b112dd --- /dev/null +++ b/src/strands/agent/conversation_manager/summarizing_conversation_manager.py @@ -0,0 +1,222 @@ +"""Summarizing conversation history management with configurable options.""" + +import logging +from typing import TYPE_CHECKING, List, Optional + +from ...types.content import Message +from ...types.exceptions import ContextWindowOverflowException +from .conversation_manager import ConversationManager + +if TYPE_CHECKING: + from ..agent import Agent + + +logger = logging.getLogger(__name__) + + +DEFAULT_SUMMARIZATION_PROMPT = """You are a conversation summarizer. Provide a concise summary of the conversation \ +history. + +Format Requirements: +- You MUST create a structured and concise summary in bullet-point format. +- You MUST NOT respond conversationally. +- You MUST NOT address the user directly. + +Task: +Your task is to create a structured summary document: +- It MUST contain bullet points with key topics and questions covered +- It MUST contain bullet points for all significant tools executed and their results +- It MUST contain bullet points for any code or technical information shared +- It MUST contain a section of key insights gained +- It MUST format the summary in the third person + +Example format: + +## Conversation Summary +* Topic 1: Key information +* Topic 2: Key information +* +## Tools Executed +* Tool X: Result Y""" + + +class SummarizingConversationManager(ConversationManager): + """Implements a summarizing window manager. + + This manager provides a configurable option to summarize older context instead of + simply trimming it, helping preserve important information while staying within + context limits. + """ + + def __init__( + self, + summary_ratio: float = 0.3, + preserve_recent_messages: int = 10, + summarization_agent: Optional["Agent"] = None, + summarization_system_prompt: Optional[str] = None, + ): + """Initialize the summarizing conversation manager. + + Args: + summary_ratio: Ratio of messages to summarize vs keep when context overflow occurs. + Value between 0.1 and 0.8. Defaults to 0.3 (summarize 30% of oldest messages). + preserve_recent_messages: Minimum number of recent messages to always keep. + Defaults to 10 messages. + summarization_agent: Optional agent to use for summarization instead of the parent agent. + If provided, this agent can use tools as part of the summarization process. + summarization_system_prompt: Optional system prompt override for summarization. + If None, uses the default summarization prompt. + """ + if summarization_agent is not None and summarization_system_prompt is not None: + raise ValueError( + "Cannot provide both summarization_agent and summarization_system_prompt. " + "Agents come with their own system prompt." 
+ ) + + self.summary_ratio = max(0.1, min(0.8, summary_ratio)) + self.preserve_recent_messages = preserve_recent_messages + self.summarization_agent = summarization_agent + self.summarization_system_prompt = summarization_system_prompt + + def apply_management(self, agent: "Agent") -> None: + """Apply management strategy to conversation history. + + For the summarizing conversation manager, no proactive management is performed. + Summarization only occurs when there's a context overflow that triggers reduce_context. + + Args: + agent: The agent whose conversation history will be managed. + The agent's messages list is modified in-place. + """ + # No proactive management - summarization only happens on context overflow + pass + + def reduce_context(self, agent: "Agent", e: Optional[Exception] = None) -> None: + """Reduce context using summarization. + + Args: + agent: The agent whose conversation history will be reduced. + The agent's messages list is modified in-place. + e: The exception that triggered the context reduction, if any. + + Raises: + ContextWindowOverflowException: If the context cannot be summarized. + """ + try: + # Calculate how many messages to summarize + messages_to_summarize_count = max(1, int(len(agent.messages) * self.summary_ratio)) + + # Ensure we don't summarize recent messages + messages_to_summarize_count = min( + messages_to_summarize_count, len(agent.messages) - self.preserve_recent_messages + ) + + if messages_to_summarize_count <= 0: + raise ContextWindowOverflowException("Cannot summarize: insufficient messages for summarization") + + # Adjust split point to avoid breaking ToolUse/ToolResult pairs + messages_to_summarize_count = self._adjust_split_point_for_tool_pairs( + agent.messages, messages_to_summarize_count + ) + + if messages_to_summarize_count <= 0: + raise ContextWindowOverflowException("Cannot summarize: insufficient messages for summarization") + + # Extract messages to summarize + messages_to_summarize = agent.messages[:messages_to_summarize_count] + remaining_messages = agent.messages[messages_to_summarize_count:] + + # Generate summary + summary_message = self._generate_summary(messages_to_summarize, agent) + + # Replace the summarized messages with the summary + agent.messages[:] = [summary_message] + remaining_messages + + except Exception as summarization_error: + logger.error("Summarization failed: %s", summarization_error) + raise summarization_error from e + + def _generate_summary(self, messages: List[Message], agent: "Agent") -> Message: + """Generate a summary of the provided messages. + + Args: + messages: The messages to summarize. + agent: The agent instance to use for summarization. + + Returns: + A message containing the conversation summary. + + Raises: + Exception: If summary generation fails. 
+ """ + # Choose which agent to use for summarization + summarization_agent = self.summarization_agent if self.summarization_agent is not None else agent + + # Save original system prompt and messages to restore later + original_system_prompt = summarization_agent.system_prompt + original_messages = summarization_agent.messages.copy() + + try: + # Only override system prompt if no agent was provided during initialization + if self.summarization_agent is None: + # Use custom system prompt if provided, otherwise use default + system_prompt = ( + self.summarization_system_prompt + if self.summarization_system_prompt is not None + else DEFAULT_SUMMARIZATION_PROMPT + ) + # Temporarily set the system prompt for summarization + summarization_agent.system_prompt = system_prompt + summarization_agent.messages = messages + + # Use the agent to generate summary with rich content (can use tools if needed) + result = summarization_agent("Please summarize this conversation.") + + return result.message + + finally: + # Restore original agent state + summarization_agent.system_prompt = original_system_prompt + summarization_agent.messages = original_messages + + def _adjust_split_point_for_tool_pairs(self, messages: List[Message], split_point: int) -> int: + """Adjust the split point to avoid breaking ToolUse/ToolResult pairs. + + Uses the same logic as SlidingWindowConversationManager for consistency. + + Args: + messages: The full list of messages. + split_point: The initially calculated split point. + + Returns: + The adjusted split point that doesn't break ToolUse/ToolResult pairs. + + Raises: + ContextWindowOverflowException: If no valid split point can be found. + """ + if split_point > len(messages): + raise ContextWindowOverflowException("Split point exceeds message array length") + + if split_point == len(messages): + return split_point + + # Find the next valid split_point + while split_point < len(messages): + if ( + # Oldest message cannot be a toolResult because it needs a toolUse preceding it + any("toolResult" in content for content in messages[split_point]["content"]) + or ( + # Oldest message can be a toolUse only if a toolResult immediately follows it. 
+ any("toolUse" in content for content in messages[split_point]["content"]) + and split_point + 1 < len(messages) + and not any("toolResult" in content for content in messages[split_point + 1]["content"]) + ) + ): + split_point += 1 + else: + break + else: + # If we didn't find a valid split_point, then we throw + raise ContextWindowOverflowException("Unable to trim conversation context!") + + return split_point diff --git a/src/strands/event_loop/error_handler.py b/src/strands/event_loop/error_handler.py index a5c85668..6dc0d9ee 100644 --- a/src/strands/event_loop/error_handler.py +++ b/src/strands/event_loop/error_handler.py @@ -6,14 +6,9 @@ import logging import time -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, Tuple -from ..telemetry.metrics import EventLoopMetrics -from ..types.content import Message, Messages -from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException -from ..types.models import Model -from ..types.streaming import StopReason -from .message_processor import find_last_message_with_tool_results, truncate_tool_results +from ..types.exceptions import ModelThrottledException logger = logging.getLogger(__name__) @@ -59,63 +54,3 @@ def handle_throttling_error( callback_handler(force_stop=True, force_stop_reason=str(e)) return False, current_delay - - -def handle_input_too_long_error( - e: ContextWindowOverflowException, - messages: Messages, - model: Model, - system_prompt: Optional[str], - tool_config: Any, - callback_handler: Any, - tool_handler: Any, - kwargs: Dict[str, Any], -) -> Tuple[StopReason, Message, EventLoopMetrics, Any]: - """Handle 'Input is too long' errors by truncating tool results. - - When a context window overflow exception occurs (input too long for the model), this function attempts to recover - by finding and truncating the most recent tool results in the conversation history. If truncation is successful, the - function will make a call to the event loop. - - Args: - e: The ContextWindowOverflowException that occurred. - messages: The conversation message history. - model: Model provider for running inference. - system_prompt: System prompt for the model. - tool_config: Tool configuration for the conversation. - callback_handler: Callback for processing events as they happen. - tool_handler: Handler for tool execution. - kwargs: Additional arguments for the event loop. - - Returns: - The results from the event loop call if successful. - - Raises: - ContextWindowOverflowException: If messages cannot be truncated. 
- """ - from .event_loop import recurse_event_loop # Import here to avoid circular imports - - # Find the last message with tool results - last_message_with_tool_results = find_last_message_with_tool_results(messages) - - # If we found a message with toolResult - if last_message_with_tool_results is not None: - logger.debug("message_index=<%s> | found message with tool results at index", last_message_with_tool_results) - - # Truncate the tool results in this message - truncate_tool_results(messages, last_message_with_tool_results) - - return recurse_event_loop( - model=model, - system_prompt=system_prompt, - messages=messages, - tool_config=tool_config, - callback_handler=callback_handler, - tool_handler=tool_handler, - **kwargs, - ) - - # If we can't handle this error, pass it up - callback_handler(force_stop=True, force_stop_reason=str(e)) - logger.error("an exception occurred in event_loop_cycle | %s", e) - raise ContextWindowOverflowException() from e diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index 23d7bd0f..02a56a1c 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -22,7 +22,7 @@ from ..types.models import Model from ..types.streaming import Metrics, StopReason from ..types.tools import ToolConfig, ToolHandler, ToolResult, ToolUse -from .error_handler import handle_input_too_long_error, handle_throttling_error +from .error_handler import handle_throttling_error from .message_processor import clean_orphaned_empty_tool_uses from .streaming import stream_messages @@ -136,6 +136,7 @@ def event_loop_cycle( metrics: Metrics # Retry loop for handling throttling exceptions + current_delay = INITIAL_DELAY for attempt in range(MAX_ATTEMPTS): model_id = model.config.get("model_id") if hasattr(model, "config") else None model_invoke_span = tracer.start_model_invoke_span( @@ -160,16 +161,7 @@ def event_loop_cycle( except ContextWindowOverflowException as e: if model_invoke_span: tracer.end_span_with_error(model_invoke_span, str(e), e) - return handle_input_too_long_error( - e, - messages, - model, - system_prompt, - tool_config, - callback_handler, - tool_handler, - kwargs, - ) + raise e except ModelThrottledException as e: if model_invoke_span: @@ -177,7 +169,7 @@ def event_loop_cycle( # Handle throttling errors with exponential backoff should_retry, current_delay = handle_throttling_error( - e, attempt, MAX_ATTEMPTS, INITIAL_DELAY, MAX_DELAY, callback_handler, kwargs + e, attempt, MAX_ATTEMPTS, current_delay, MAX_DELAY, callback_handler, kwargs ) if should_retry: continue @@ -248,6 +240,10 @@ def event_loop_cycle( # Don't invoke the callback_handler or log the exception - we already did it when we # raised the exception and we don't need that duplication. 
raise + except ContextWindowOverflowException as e: + if cycle_span: + tracer.end_span_with_error(cycle_span, str(e), e) + raise e except Exception as e: if cycle_span: tracer.end_span_with_error(cycle_span, str(e), e) diff --git a/src/strands/event_loop/message_processor.py b/src/strands/event_loop/message_processor.py index 61e1c1d7..4e1a39dc 100644 --- a/src/strands/event_loop/message_processor.py +++ b/src/strands/event_loop/message_processor.py @@ -5,7 +5,7 @@ """ import logging -from typing import Dict, Optional, Set, Tuple +from typing import Dict, Set, Tuple from ..types.content import Messages @@ -103,60 +103,3 @@ def clean_orphaned_empty_tool_uses(messages: Messages) -> bool: logger.warning("failed to fix orphaned tool use | %s", e) return True - - -def find_last_message_with_tool_results(messages: Messages) -> Optional[int]: - """Find the index of the last message containing tool results. - - This is useful for identifying messages that might need to be truncated to reduce context size. - - Args: - messages: The conversation message history. - - Returns: - Index of the last message with tool results, or None if no such message exists. - """ - # Iterate backwards through all messages (from newest to oldest) - for idx in range(len(messages) - 1, -1, -1): - # Check if this message has any content with toolResult - current_message = messages[idx] - has_tool_result = False - - for content in current_message.get("content", []): - if isinstance(content, dict) and "toolResult" in content: - has_tool_result = True - break - - if has_tool_result: - return idx - - return None - - -def truncate_tool_results(messages: Messages, msg_idx: int) -> bool: - """Truncate tool results in a message to reduce context size. - - When a message contains tool results that are too large for the model's context window, this function replaces the - content of those tool results with a simple error message. - - Args: - messages: The conversation message history. - msg_idx: Index of the message containing tool results to truncate. - - Returns: - True if any changes were made to the message, False otherwise. 
- """ - if msg_idx >= len(messages) or msg_idx < 0: - return False - - message = messages[msg_idx] - changes_made = False - - for i, content in enumerate(message.get("content", [])): - if isinstance(content, dict) and "toolResult" in content: - # Update status to error with informative message - message["content"][i]["toolResult"]["status"] = "error" - message["content"][i]["toolResult"]["content"] = [{"text": "The tool result was too large!"}] - changes_made = True - - return changes_made diff --git a/src/strands/telemetry/tracer.py b/src/strands/telemetry/tracer.py index 9f731996..e9a37a4a 100644 --- a/src/strands/telemetry/tracer.py +++ b/src/strands/telemetry/tracer.py @@ -11,14 +11,15 @@ from importlib.metadata import version from typing import Any, Dict, Mapping, Optional -from opentelemetry import trace -from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter - -# See https://github.com/open-telemetry/opentelemetry-python/issues/4615 for the type ignore -from opentelemetry.sdk.resources import Resource # type: ignore[attr-defined] -from opentelemetry.sdk.trace import TracerProvider +import opentelemetry.trace as trace_api +from opentelemetry import propagate +from opentelemetry.baggage.propagation import W3CBaggagePropagator +from opentelemetry.propagators.composite import CompositePropagator +from opentelemetry.sdk.resources import Resource +from opentelemetry.sdk.trace import TracerProvider as SDKTracerProvider from opentelemetry.sdk.trace.export import BatchSpanProcessor, ConsoleSpanExporter, SimpleSpanProcessor -from opentelemetry.trace import StatusCode +from opentelemetry.trace import Span, StatusCode +from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator from ..agent.agent_result import AgentResult from ..types.content import Message, Messages @@ -28,6 +29,19 @@ logger = logging.getLogger(__name__) +HAS_OTEL_EXPORTER_MODULE = False +OTEL_EXPORTER_MODULE_ERROR = ( + "opentelemetry-exporter-otlp-proto-http not detected;" + "please install strands-agents with the optional 'otel' target" + "otel http exporting is currently DISABLED" +) +try: + from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter + + HAS_OTEL_EXPORTER_MODULE = True +except ImportError: + pass + class JSONEncoder(json.JSONEncoder): """Custom JSON encoder that handles non-serializable types.""" @@ -135,16 +149,30 @@ def __init__( self.service_name = service_name self.otlp_headers = otlp_headers or {} - self.tracer_provider: Optional[TracerProvider] = None - self.tracer: Optional[trace.Tracer] = None - + self.tracer_provider: Optional[trace_api.TracerProvider] = None + self.tracer: Optional[trace_api.Tracer] = None + + propagate.set_global_textmap( + CompositePropagator( + [ + W3CBaggagePropagator(), + TraceContextTextMapPropagator(), + ] + ) + ) if self.otlp_endpoint or self.enable_console_export: + # Create our own tracer provider self._initialize_tracer() def _initialize_tracer(self) -> None: """Initialize the OpenTelemetry tracer.""" logger.info("initializing tracer") + if self._is_initialized(): + self.tracer_provider = trace_api.get_tracer_provider() + self.tracer = self.tracer_provider.get_tracer(self.service_name) + return + # Create resource with service information resource = Resource.create( { @@ -156,7 +184,7 @@ def _initialize_tracer(self) -> None: ) # Create tracer provider - self.tracer_provider = TracerProvider(resource=resource) + self.tracer_provider = SDKTracerProvider(resource=resource) # Add console exporter 
if enabled if self.enable_console_export and self.tracer_provider: @@ -165,7 +193,7 @@ def _initialize_tracer(self) -> None: self.tracer_provider.add_span_processor(console_processor) # Add OTLP exporter if endpoint is provided - if self.otlp_endpoint and self.tracer_provider: + if HAS_OTEL_EXPORTER_MODULE and self.otlp_endpoint and self.tracer_provider: try: # Ensure endpoint has the right format endpoint = self.otlp_endpoint @@ -190,17 +218,23 @@ def _initialize_tracer(self) -> None: logger.info("endpoint=<%s> | OTLP exporter configured with endpoint", endpoint) except Exception as e: logger.exception("error=<%s> | Failed to configure OTLP exporter", e) + elif self.otlp_endpoint and self.tracer_provider: + raise ModuleNotFoundError(OTEL_EXPORTER_MODULE_ERROR) # Set as global tracer provider - trace.set_tracer_provider(self.tracer_provider) - self.tracer = trace.get_tracer(self.service_name) + trace_api.set_tracer_provider(self.tracer_provider) + self.tracer = trace_api.get_tracer(self.service_name) + + def _is_initialized(self) -> bool: + tracer_provider = trace_api.get_tracer_provider() + return isinstance(tracer_provider, SDKTracerProvider) def _start_span( self, span_name: str, - parent_span: Optional[trace.Span] = None, + parent_span: Optional[Span] = None, attributes: Optional[Dict[str, AttributeValue]] = None, - ) -> Optional[trace.Span]: + ) -> Optional[Span]: """Generic helper method to start a span with common attributes. Args: @@ -214,7 +248,7 @@ def _start_span( if self.tracer is None: return None - context = trace.set_span_in_context(parent_span) if parent_span else None + context = trace_api.set_span_in_context(parent_span) if parent_span else None span = self.tracer.start_span(name=span_name, context=context) # Set start time as a common attribute @@ -226,7 +260,7 @@ def _start_span( return span - def _set_attributes(self, span: trace.Span, attributes: Dict[str, AttributeValue]) -> None: + def _set_attributes(self, span: Span, attributes: Dict[str, AttributeValue]) -> None: """Set attributes on a span, handling different value types appropriately. Args: @@ -241,7 +275,7 @@ def _set_attributes(self, span: trace.Span, attributes: Dict[str, AttributeValue def _end_span( self, - span: trace.Span, + span: Span, attributes: Optional[Dict[str, AttributeValue]] = None, error: Optional[Exception] = None, ) -> None: @@ -274,13 +308,13 @@ def _end_span( finally: span.end() # Force flush to ensure spans are exported - if self.tracer_provider: + if self.tracer_provider and hasattr(self.tracer_provider, "force_flush"): try: self.tracer_provider.force_flush() except Exception as e: logger.warning("error=<%s> | failed to force flush tracer provider", e) - def end_span_with_error(self, span: trace.Span, error_message: str, exception: Optional[Exception] = None) -> None: + def end_span_with_error(self, span: Span, error_message: str, exception: Optional[Exception] = None) -> None: """End a span with error status. Args: @@ -296,12 +330,12 @@ def end_span_with_error(self, span: trace.Span, error_message: str, exception: O def start_model_invoke_span( self, - parent_span: Optional[trace.Span] = None, + parent_span: Optional[Span] = None, agent_name: str = "Strands Agent", messages: Optional[Messages] = None, model_id: Optional[str] = None, **kwargs: Any, - ) -> Optional[trace.Span]: + ) -> Optional[Span]: """Start a new span for a model invocation. 
Args: @@ -330,7 +364,7 @@ def start_model_invoke_span( return self._start_span("Model invoke", parent_span, attributes) def end_model_invoke_span( - self, span: trace.Span, message: Message, usage: Usage, error: Optional[Exception] = None + self, span: Span, message: Message, usage: Usage, error: Optional[Exception] = None ) -> None: """End a model invocation span with results and metrics. @@ -349,9 +383,7 @@ def end_model_invoke_span( self._end_span(span, attributes, error) - def start_tool_call_span( - self, tool: ToolUse, parent_span: Optional[trace.Span] = None, **kwargs: Any - ) -> Optional[trace.Span]: + def start_tool_call_span(self, tool: ToolUse, parent_span: Optional[Span] = None, **kwargs: Any) -> Optional[Span]: """Start a new span for a tool call. Args: @@ -376,7 +408,7 @@ def start_tool_call_span( return self._start_span(span_name, parent_span, attributes) def end_tool_call_span( - self, span: trace.Span, tool_result: Optional[ToolResult], error: Optional[Exception] = None + self, span: Span, tool_result: Optional[ToolResult], error: Optional[Exception] = None ) -> None: """End a tool call span with results. @@ -404,10 +436,10 @@ def end_tool_call_span( def start_event_loop_cycle_span( self, event_loop_kwargs: Any, - parent_span: Optional[trace.Span] = None, + parent_span: Optional[Span] = None, messages: Optional[Messages] = None, **kwargs: Any, - ) -> Optional[trace.Span]: + ) -> Optional[Span]: """Start a new span for an event loop cycle. Args: @@ -438,7 +470,7 @@ def start_event_loop_cycle_span( def end_event_loop_cycle_span( self, - span: trace.Span, + span: Span, message: Message, tool_result_message: Optional[Message] = None, error: Optional[Exception] = None, @@ -468,7 +500,7 @@ def start_agent_span( tools: Optional[list] = None, custom_trace_attributes: Optional[Mapping[str, AttributeValue]] = None, **kwargs: Any, - ) -> Optional[trace.Span]: + ) -> Optional[Span]: """Start a new span for an agent invocation. Args: @@ -508,7 +540,7 @@ def start_agent_span( def end_agent_span( self, - span: trace.Span, + span: Span, response: Optional[AgentResult] = None, error: Optional[Exception] = None, ) -> None: @@ -559,13 +591,16 @@ def get_tracer( otlp_endpoint: OTLP endpoint URL for sending traces. otlp_headers: Headers to include with OTLP requests. enable_console_export: Whether to also export traces to console. + tracer_provider: Optional existing TracerProvider to use instead of creating a new one. Returns: The global tracer instance. """ global _tracer_instance - if _tracer_instance is None or (otlp_endpoint and _tracer_instance.otlp_endpoint != otlp_endpoint): # type: ignore[unreachable] + if ( + _tracer_instance is None or (otlp_endpoint and _tracer_instance.otlp_endpoint != otlp_endpoint) # type: ignore[unreachable] + ): _tracer_instance = Tracer( service_name=service_name, otlp_endpoint=otlp_endpoint, diff --git a/src/strands/tools/registry.py b/src/strands/tools/registry.py index 2efdd600..e56ee999 100644 --- a/src/strands/tools/registry.py +++ b/src/strands/tools/registry.py @@ -189,6 +189,21 @@ def register_tool(self, tool: AgentTool) -> None: tool.is_dynamic, ) + if self.registry.get(tool.tool_name) is None: + normalized_name = tool.tool_name.replace("-", "_") + + matching_tools = [ + tool_name + for (tool_name, tool) in self.registry.items() + if tool_name.replace("-", "_") == normalized_name + ] + + if matching_tools: + raise ValueError( + f"Tool name '{tool.tool_name}' already exists as '{matching_tools[0]}'." 
+ " Cannot add a duplicate tool which differs by a '-' or '_'" + ) + # Register in main registry self.registry[tool.tool_name] = tool diff --git a/tests-integ/test_bedrock_guardrails.py b/tests-integ/test_bedrock_guardrails.py index 9ffd1bdf..bf0be706 100644 --- a/tests-integ/test_bedrock_guardrails.py +++ b/tests-integ/test_bedrock_guardrails.py @@ -12,7 +12,7 @@ @pytest.fixture(scope="module") def boto_session(): - return boto3.Session(region_name="us-west-2") + return boto3.Session(region_name="us-east-1") @pytest.fixture(scope="module") @@ -142,7 +142,7 @@ def test_guardrail_output_intervention_redact_output(bedrock_guardrail, processi guardrail_stream_processing_mode=processing_mode, guardrail_redact_output=True, guardrail_redact_output_message=REDACT_MESSAGE, - region_name="us-west-2", + region_name="us-east-1", ) agent = Agent( diff --git a/tests-integ/test_mcp_client.py b/tests-integ/test_mcp_client.py index 59ae2a14..8b1dade3 100644 --- a/tests-integ/test_mcp_client.py +++ b/tests-integ/test_mcp_client.py @@ -1,8 +1,10 @@ import base64 +import os import threading import time from typing import List, Literal +import pytest from mcp import StdioServerParameters, stdio_client from mcp.client.sse import sse_client from mcp.client.streamable_http import streamablehttp_client @@ -101,6 +103,10 @@ def test_can_reuse_mcp_client(): assert any([block["name"] == "echo" for block in tool_use_content_blocks]) +@pytest.mark.skipif( + condition=os.environ.get("GITHUB_ACTIONS") == "true", + reason="streamable transport is failing in GitHub actions, debugging if linux compatibility issue", +) def test_streamable_http_mcp_client(): server_thread = threading.Thread( target=start_calculator_server, kwargs={"transport": "streamable-http", "port": 8001}, daemon=True diff --git a/tests-integ/test_model_litellm.py b/tests-integ/test_model_litellm.py index f1afb61f..86f6b42f 100644 --- a/tests-integ/test_model_litellm.py +++ b/tests-integ/test_model_litellm.py @@ -7,7 +7,7 @@ @pytest.fixture def model(): - return LiteLLMModel(model_id="us.anthropic.claude-3-7-sonnet-20250219-v1:0") + return LiteLLMModel(model_id="bedrock/us.anthropic.claude-3-7-sonnet-20250219-v1:0") @pytest.fixture diff --git a/tests-integ/test_summarizing_conversation_manager_integration.py b/tests-integ/test_summarizing_conversation_manager_integration.py new file mode 100644 index 00000000..5dcf4944 --- /dev/null +++ b/tests-integ/test_summarizing_conversation_manager_integration.py @@ -0,0 +1,374 @@ +"""Integration tests for SummarizingConversationManager with actual AI models. + +These tests validate the end-to-end functionality of the SummarizingConversationManager +by testing with real AI models and API calls. They ensure that: + +1. **Real summarization** - Tests that actual model-generated summaries work correctly +2. **Context overflow handling** - Validates real context overflow scenarios and recovery +3. **Tool preservation** - Ensures ToolUse/ToolResult pairs survive real summarization +4. **Message structure** - Verifies real model outputs maintain proper message structure +5. **Agent integration** - Tests that conversation managers work with real Agent workflows + +These tests require API keys (`ANTHROPIC_API_KEY`) and make real API calls, so they should be run sparingly +and may be skipped in CI environments without proper credentials. 
+""" + +import os + +import pytest + +import strands +from strands import Agent +from strands.agent.conversation_manager import SummarizingConversationManager +from strands.models.anthropic import AnthropicModel + + +@pytest.fixture +def model(): + """Real Anthropic model for integration testing.""" + return AnthropicModel( + client_args={ + "api_key": os.getenv("ANTHROPIC_API_KEY"), + }, + model_id="claude-3-haiku-20240307", # Using Haiku for faster/cheaper tests + max_tokens=1024, + ) + + +@pytest.fixture +def summarization_model(): + """Separate model instance for summarization to test dedicated agent functionality.""" + return AnthropicModel( + client_args={ + "api_key": os.getenv("ANTHROPIC_API_KEY"), + }, + model_id="claude-3-haiku-20240307", + max_tokens=512, + ) + + +@pytest.fixture +def tools(): + """Real tools for testing tool preservation during summarization.""" + + @strands.tool + def get_current_time() -> str: + """Get the current time.""" + return "2024-01-15 14:30:00" + + @strands.tool + def get_weather(city: str) -> str: + """Get weather information for a city.""" + return f"The weather in {city} is sunny and 72°F" + + @strands.tool + def calculate_sum(a: int, b: int) -> int: + """Calculate the sum of two numbers.""" + return a + b + + return [get_current_time, get_weather, calculate_sum] + + +@pytest.mark.skipif("ANTHROPIC_API_KEY" not in os.environ, reason="ANTHROPIC_API_KEY environment variable missing") +def test_summarization_with_context_overflow(model): + """Test that summarization works when context overflow occurs.""" + # Mock conversation data to avoid API calls + greeting_response = """ + Hello! I'm here to help you test your conversation manager. What specifically would you like + me to do as part of this test? I can respond to different types of prompts, maintain context + throughout our conversation, or demonstrate other capabilities of the AI assistant. Just let + me know what aspects you'd like to evaluate. + """.strip() + + computer_history_response = """ + # History of Computers + + The history of computers spans many centuries, evolving from simple calculating tools to + the powerful machines we use today. 
+ + ## Early Computing Devices + - **Ancient abacus** (3000 BCE): One of the earliest computing devices used for arithmetic calculations + - **Pascaline** (1642): Mechanical calculator invented by Blaise Pascal + - **Difference Engine** (1822): Designed by Charles Babbage to compute polynomial functions + - **Analytical Engine**: Babbage's more ambitious design, considered the first general-purpose computer concept + - **Hollerith's Tabulating Machine** (1890s): Used punch cards to process data for the US Census + + ## Early Electronic Computers + - **ENIAC** (1945): First general-purpose electronic computer, weighed 30 tons + - **EDVAC** (1949): Introduced the stored program concept + - **UNIVAC I** (1951): First commercial computer in the United States + """.strip() + + first_computers_response = """ + # The First Computers + + Early computers were dramatically different from today's machines in almost every aspect: + + ## Physical Characteristics + - **Enormous size**: Room-filling or even building-filling machines + - **ENIAC** (1945) weighed about 30 tons, occupied 1,800 square feet + - Consisted of large metal frames or cabinets filled with components + - Required special cooling systems due to excessive heat generation + + ## Technology and Components + - **Vacuum tubes**: Thousands of fragile glass tubes served as switches and amplifiers + - ENIAC contained over 17,000 vacuum tubes + - Generated tremendous heat and frequently failed + - **Memory**: Limited storage using delay lines, cathode ray tubes, or magnetic drums + """.strip() + + messages = [ + {"role": "user", "content": [{"text": "Hello, I'm testing a conversation manager."}]}, + {"role": "assistant", "content": [{"text": greeting_response}]}, + {"role": "user", "content": [{"text": "Can you tell me about the history of computers?"}]}, + {"role": "assistant", "content": [{"text": computer_history_response}]}, + {"role": "user", "content": [{"text": "What were the first computers like?"}]}, + {"role": "assistant", "content": [{"text": first_computers_response}]}, + ] + + # Create agent with very aggressive summarization settings and pre-built conversation + agent = Agent( + model=model, + conversation_manager=SummarizingConversationManager( + summary_ratio=0.5, # Summarize 50% of messages + preserve_recent_messages=2, # Keep only 2 recent messages + ), + load_tools_from_directory=False, + messages=messages, + ) + + # Should have the pre-built conversation history + initial_message_count = len(agent.messages) + assert initial_message_count == 6 # 3 user + 3 assistant messages + + # Store the last 2 messages before summarization to verify they're preserved + messages_before_summary = agent.messages[-2:].copy() + + # Now manually trigger context reduction to test summarization + agent.conversation_manager.reduce_context(agent) + + # Verify summarization occurred + assert len(agent.messages) < initial_message_count + # Should have: 1 summary + remaining messages + # With 6 messages, summary_ratio=0.5, preserve_recent_messages=2: + # messages_to_summarize = min(6 * 0.5, 6 - 2) = min(3, 4) = 3 + # So we summarize 3 messages, leaving 3 remaining + 1 summary = 4 total + expected_total_messages = 4 + assert len(agent.messages) == expected_total_messages + + # First message should be the summary (assistant message) + summary_message = agent.messages[0] + assert summary_message["role"] == "assistant" + assert len(summary_message["content"]) > 0 + + # Verify the summary contains actual text content + summary_content = None + for 
content_block in summary_message["content"]: + if "text" in content_block: + summary_content = content_block["text"] + break + + assert summary_content is not None + assert len(summary_content) > 50 # Should be a substantial summary + + # Recent messages should be preserved - verify they're exactly the same + recent_messages = agent.messages[-2:] # Last 2 messages should be preserved + assert len(recent_messages) == 2 + assert recent_messages == messages_before_summary, "The last 2 messages should be preserved exactly as they were" + + # Agent should still be functional after summarization + post_summary_result = agent("That's very interesting, thank you!") + assert post_summary_result.message["role"] == "assistant" + + +@pytest.mark.skipif("ANTHROPIC_API_KEY" not in os.environ, reason="ANTHROPIC_API_KEY environment variable missing") +def test_tool_preservation_during_summarization(model, tools): + """Test that ToolUse/ToolResult pairs are preserved during summarization.""" + agent = Agent( + model=model, + tools=tools, + conversation_manager=SummarizingConversationManager( + summary_ratio=0.6, # Aggressive summarization + preserve_recent_messages=3, + ), + load_tools_from_directory=False, + ) + + # Mock conversation with tool usage to avoid API calls and speed up tests + greeting_text = """ + Hello! I'd be happy to help you with calculations. I have access to tools that can + help with math, time, and weather information. What would you like me to calculate for you? + """.strip() + + weather_response = "The weather in San Francisco is sunny and 72°F. Perfect weather for being outside!" + + tool_conversation_data = [ + # Initial greeting exchange + {"role": "user", "content": [{"text": "Hello, can you help me with some calculations?"}]}, + {"role": "assistant", "content": [{"text": greeting_text}]}, + # Time query with tool use/result pair + {"role": "user", "content": [{"text": "What's the current time?"}]}, + { + "role": "assistant", + "content": [{"toolUse": {"toolUseId": "time_001", "name": "get_current_time", "input": {}}}], + }, + { + "role": "user", + "content": [ + { + "toolResult": { + "toolUseId": "time_001", + "content": [{"text": "2024-01-15 14:30:00"}], + "status": "success", + } + } + ], + }, + {"role": "assistant", "content": [{"text": "The current time is 2024-01-15 14:30:00."}]}, + # Math calculation with tool use/result pair + {"role": "user", "content": [{"text": "What's 25 + 37?"}]}, + { + "role": "assistant", + "content": [{"toolUse": {"toolUseId": "calc_001", "name": "calculate_sum", "input": {"a": 25, "b": 37}}}], + }, + { + "role": "user", + "content": [{"toolResult": {"toolUseId": "calc_001", "content": [{"text": "62"}], "status": "success"}}], + }, + {"role": "assistant", "content": [{"text": "25 + 37 = 62"}]}, + # Weather query with tool use/result pair + {"role": "user", "content": [{"text": "What's the weather like in San Francisco?"}]}, + { + "role": "assistant", + "content": [ + {"toolUse": {"toolUseId": "weather_001", "name": "get_weather", "input": {"city": "San Francisco"}}} + ], + }, + { + "role": "user", + "content": [ + { + "toolResult": { + "toolUseId": "weather_001", + "content": [{"text": "The weather in San Francisco is sunny and 72°F"}], + "status": "success", + } + } + ], + }, + {"role": "assistant", "content": [{"text": weather_response}]}, + ] + + # Add all the mocked conversation messages to avoid real API calls + agent.messages.extend(tool_conversation_data) + + # Force summarization + agent.conversation_manager.reduce_context(agent) + + # 
Verify tool pairs are still balanced after summarization + post_summary_tool_use_count = 0 + post_summary_tool_result_count = 0 + + for message in agent.messages: + for content in message.get("content", []): + if "toolUse" in content: + post_summary_tool_use_count += 1 + if "toolResult" in content: + post_summary_tool_result_count += 1 + + # Tool uses and results should be balanced (no orphaned tools) + assert post_summary_tool_use_count == post_summary_tool_result_count, ( + "Tool use and tool result counts should be balanced after summarization" + ) + + # Agent should still be able to use tools after summarization + agent("Calculate 15 + 28 for me.") + + # Should have triggered the calculate_sum tool + found_calculation = False + for message in agent.messages[-2:]: # Check recent messages + for content in message.get("content", []): + if "toolResult" in content and "43" in str(content): # 15 + 28 = 43 + found_calculation = True + break + + assert found_calculation, "Tool should still work after summarization" + + +@pytest.mark.skipif("ANTHROPIC_API_KEY" not in os.environ, reason="ANTHROPIC_API_KEY environment variable missing") +def test_dedicated_summarization_agent(model, summarization_model): + """Test that a dedicated summarization agent works correctly.""" + # Create a dedicated summarization agent + summarization_agent = Agent( + model=summarization_model, + system_prompt="You are a conversation summarizer. Create concise, structured summaries.", + load_tools_from_directory=False, + ) + + # Create main agent with dedicated summarization agent + agent = Agent( + model=model, + conversation_manager=SummarizingConversationManager( + summary_ratio=0.5, + preserve_recent_messages=2, + summarization_agent=summarization_agent, + ), + load_tools_from_directory=False, + ) + + # Mock conversation data for space exploration topic + space_intro_response = """ + Space exploration has been one of humanity's greatest achievements, beginning with early + satellite launches in the 1950s and progressing to human spaceflight, moon landings, and now + commercial space ventures. + """.strip() + + space_milestones_response = """ + Key milestones include Sputnik 1 (1957), Yuri Gagarin's first human spaceflight (1961), + the Apollo 11 moon landing (1969), the Space Shuttle program, and the International Space + Station construction. + """.strip() + + apollo_missions_response = """ + The Apollo program was NASA's lunar exploration program from 1961-1975. Apollo 11 achieved + the first moon landing in 1969 with Neil Armstrong and Buzz Aldrin, followed by five more + successful lunar missions through Apollo 17. + """.strip() + + spacex_response = """ + SpaceX has revolutionized space travel with reusable rockets, reducing launch costs dramatically. + They've achieved crew transportation to the ISS, satellite deployments, and are developing + Starship for Mars missions. 
+ """.strip() + + conversation_pairs = [ + ("I'm interested in learning about space exploration.", space_intro_response), + ("What were the key milestones in space exploration?", space_milestones_response), + ("Tell me about the Apollo missions.", apollo_missions_response), + ("What about modern space exploration with SpaceX?", spacex_response), + ] + + # Manually build the conversation history to avoid real API calls + for user_input, assistant_response in conversation_pairs: + agent.messages.append({"role": "user", "content": [{"text": user_input}]}) + agent.messages.append({"role": "assistant", "content": [{"text": assistant_response}]}) + + # Force summarization + original_length = len(agent.messages) + agent.conversation_manager.reduce_context(agent) + + # Verify summarization occurred + assert len(agent.messages) < original_length + + # Get the summary message + summary_message = agent.messages[0] + assert summary_message["role"] == "assistant" + + # Extract summary text + summary_text = None + for content in summary_message["content"]: + if "text" in content: + summary_text = content["text"] + break + + assert summary_text diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 0ea20b64..d6f47be0 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -438,7 +438,7 @@ def test_agent__call__retry_with_reduced_context(mock_model, agent, tool): def test_agent__call__always_sliding_window_conversation_manager_doesnt_infinite_loop(mock_model, agent, tool): - conversation_manager = SlidingWindowConversationManager(window_size=500) + conversation_manager = SlidingWindowConversationManager(window_size=500, should_truncate_results=False) conversation_manager_spy = unittest.mock.Mock(wraps=conversation_manager) agent.conversation_manager = conversation_manager_spy @@ -484,10 +484,43 @@ def test_agent__call__null_conversation_window_manager__doesnt_infinite_loop(moc agent("Test!") +def test_agent__call__tool_truncation_doesnt_infinite_loop(mock_model, agent): + messages: Messages = [ + {"role": "user", "content": [{"text": "Hello!"}]}, + { + "role": "assistant", + "content": [{"toolUse": {"toolUseId": "123", "input": {"hello": "world"}, "name": "test"}}], + }, + { + "role": "user", + "content": [ + {"toolResult": {"toolUseId": "123", "content": [{"text": "Some large input!"}], "status": "success"}} + ], + }, + ] + agent.messages = messages + + mock_model.mock_converse.side_effect = ContextWindowOverflowException( + RuntimeError("Input is too long for requested model") + ) + + with pytest.raises(ContextWindowOverflowException): + agent("Test!") + + def test_agent__call__retry_with_overwritten_tool(mock_model, agent, tool): conversation_manager_spy = unittest.mock.Mock(wraps=agent.conversation_manager) agent.conversation_manager = conversation_manager_spy + messages: Messages = [ + {"role": "user", "content": [{"text": "Hello!"}]}, + { + "role": "assistant", + "content": [{"text": "Hi!"}], + }, + ] + agent.messages = messages + mock_model.mock_converse.side_effect = [ [ { @@ -504,6 +537,9 @@ def test_agent__call__retry_with_overwritten_tool(mock_model, agent, tool): {"contentBlockStop": {}}, {"messageStop": {"stopReason": "tool_use"}}, ], + # Will truncate the tool result + ContextWindowOverflowException(RuntimeError("Input is too long for requested model")), + # Will reduce the context ContextWindowOverflowException(RuntimeError("Input is too long for requested model")), [], ] @@ -538,7 +574,7 @@ def 
test_agent__call__retry_with_overwritten_tool(mock_model, agent, tool): unittest.mock.ANY, ) - conversation_manager_spy.reduce_context.assert_not_called() + assert conversation_manager_spy.reduce_context.call_count == 2 assert conversation_manager_spy.apply_management.call_count == 1 @@ -674,6 +710,46 @@ def function(system_prompt: str) -> str: ) +def test_agent_tool_with_name_normalization(agent, tool_registry, mock_randint): + agent.tool_handler = unittest.mock.Mock() + + tool_name = "system-prompter" + + @strands.tools.tool(name=tool_name) + def function(system_prompt: str) -> str: + return system_prompt + + tool = strands.tools.tools.FunctionTool(function) + agent.tool_registry.register_tool(tool) + + mock_randint.return_value = 1 + + agent.tool.system_prompter(system_prompt="tool prompt") + + # Verify the correct tool was invoked + assert agent.tool_handler.process.call_count == 1 + tool_call = agent.tool_handler.process.call_args.kwargs.get("tool") + + assert tool_call == { + # Note that the tool-use uses the "python safe" name + "toolUseId": "tooluse_system_prompter_1", + # But the name of the tool is the one in the registry + "name": tool_name, + "input": {"system_prompt": "tool prompt"}, + } + + +def test_agent_tool_with_no_normalized_match(agent, tool_registry, mock_randint): + agent.tool_handler = unittest.mock.Mock() + + mock_randint.return_value = 1 + + with pytest.raises(AttributeError) as err: + agent.tool.system_prompter_1(system_prompt="tool prompt") + + assert str(err.value) == "Tool 'system_prompter_1' not found" + + def test_agent_with_none_callback_handler_prints_nothing(): agent = Agent() diff --git a/tests/strands/agent/test_conversation_manager.py b/tests/strands/agent/test_conversation_manager.py index bbec3cd1..7d43199e 100644 --- a/tests/strands/agent/test_conversation_manager.py +++ b/tests/strands/agent/test_conversation_manager.py @@ -25,6 +25,7 @@ def test_is_assistant_message(role, exp_result): def conversation_manager(request): params = { "window_size": 2, + "should_truncate_results": False, } if hasattr(request, "param"): params.update(request.param) @@ -168,7 +169,7 @@ def test_apply_management(conversation_manager, messages, expected_messages): def test_sliding_window_conversation_manager_with_untrimmable_history_raises_context_window_overflow_exception(): - manager = strands.agent.conversation_manager.SlidingWindowConversationManager(1) + manager = strands.agent.conversation_manager.SlidingWindowConversationManager(1, False) messages = [ {"role": "assistant", "content": [{"toolUse": {"toolUseId": "456", "name": "tool1", "input": {}}}]}, {"role": "user", "content": [{"toolResult": {"toolUseId": "789", "content": [], "status": "success"}}]}, @@ -182,6 +183,40 @@ def test_sliding_window_conversation_manager_with_untrimmable_history_raises_con assert messages == original_messages +def test_sliding_window_conversation_manager_with_tool_results_truncated(): + manager = strands.agent.conversation_manager.SlidingWindowConversationManager(1) + messages = [ + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "456", "name": "tool1", "input": {}}}]}, + { + "role": "user", + "content": [ + {"toolResult": {"toolUseId": "789", "content": [{"text": "large input"}], "status": "success"}} + ], + }, + ] + test_agent = Agent(messages=messages) + + manager.reduce_context(test_agent) + + expected_messages = [ + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "456", "name": "tool1", "input": {}}}]}, + { + "role": "user", + "content": [ + { + 
"toolResult": { + "toolUseId": "789", + "content": [{"text": "The tool result was too large!"}], + "status": "error", + } + } + ], + }, + ] + + assert messages == expected_messages + + def test_null_conversation_manager_reduce_context_raises_context_window_overflow_exception(): """Test that NullConversationManager doesn't modify messages.""" manager = strands.agent.conversation_manager.NullConversationManager() diff --git a/tests/strands/agent/test_summarizing_conversation_manager.py b/tests/strands/agent/test_summarizing_conversation_manager.py new file mode 100644 index 00000000..9952203e --- /dev/null +++ b/tests/strands/agent/test_summarizing_conversation_manager.py @@ -0,0 +1,566 @@ +from typing import TYPE_CHECKING, cast +from unittest.mock import Mock, patch + +import pytest + +from strands.agent.conversation_manager.summarizing_conversation_manager import SummarizingConversationManager +from strands.types.content import Messages +from strands.types.exceptions import ContextWindowOverflowException + +if TYPE_CHECKING: + from strands.agent.agent import Agent + + +class MockAgent: + """Mock agent for testing summarization.""" + + def __init__(self, summary_response="This is a summary of the conversation."): + self.summary_response = summary_response + self.system_prompt = None + self.messages = [] + self.model = Mock() + self.call_tracker = Mock() + + def __call__(self, prompt): + """Mock agent call that returns a summary.""" + self.call_tracker(prompt) + result = Mock() + result.message = {"role": "assistant", "content": [{"text": self.summary_response}]} + return result + + +def create_mock_agent(summary_response="This is a summary of the conversation.") -> "Agent": + """Factory function that returns a properly typed MockAgent.""" + return cast("Agent", MockAgent(summary_response)) + + +@pytest.fixture +def mock_agent(): + """Fixture for mock agent.""" + return create_mock_agent() + + +@pytest.fixture +def summarizing_manager(): + """Fixture for summarizing conversation manager with default settings.""" + return SummarizingConversationManager( + summary_ratio=0.5, + preserve_recent_messages=2, + ) + + +def test_init_default_values(): + """Test initialization with default values.""" + manager = SummarizingConversationManager() + + assert manager.summarization_agent is None + assert manager.summary_ratio == 0.3 + assert manager.preserve_recent_messages == 10 + + +def test_init_clamps_summary_ratio(): + """Test that summary_ratio is clamped to valid range.""" + # Test lower bound + manager = SummarizingConversationManager(summary_ratio=0.05) + assert manager.summary_ratio == 0.1 + + # Test upper bound + manager = SummarizingConversationManager(summary_ratio=0.95) + assert manager.summary_ratio == 0.8 + + +def test_reduce_context_raises_when_no_agent(): + """Test that reduce_context raises exception when agent has no messages.""" + manager = SummarizingConversationManager() + + # Create a mock agent with no messages + mock_agent = Mock() + empty_messages: Messages = [] + mock_agent.messages = empty_messages + + with pytest.raises(ContextWindowOverflowException, match="insufficient messages for summarization"): + manager.reduce_context(mock_agent) + + +def test_reduce_context_with_summarization(summarizing_manager, mock_agent): + """Test reduce_context with summarization enabled.""" + test_messages: Messages = [ + {"role": "user", "content": [{"text": "Message 1"}]}, + {"role": "assistant", "content": [{"text": "Response 1"}]}, + {"role": "user", "content": [{"text": "Message 2"}]}, + 
{"role": "assistant", "content": [{"text": "Response 2"}]}, + {"role": "user", "content": [{"text": "Message 3"}]}, + {"role": "assistant", "content": [{"text": "Response 3"}]}, + ] + mock_agent.messages = test_messages + + summarizing_manager.reduce_context(mock_agent) + + # Should have: 1 summary message + 2 preserved recent messages + remaining from summarization + assert len(mock_agent.messages) == 4 + + # First message should be the summary + assert mock_agent.messages[0]["role"] == "assistant" + first_content = mock_agent.messages[0]["content"][0] + assert "text" in first_content and "This is a summary of the conversation." in first_content["text"] + + # Recent messages should be preserved + assert "Message 3" in str(mock_agent.messages[-2]["content"]) + assert "Response 3" in str(mock_agent.messages[-1]["content"]) + + +def test_reduce_context_too_few_messages_raises_exception(summarizing_manager, mock_agent): + """Test that reduce_context raises exception when there are too few messages to summarize effectively.""" + # Create a scenario where calculation results in 0 messages to summarize + manager = SummarizingConversationManager( + summary_ratio=0.1, # Very small ratio + preserve_recent_messages=5, # High preservation + ) + + insufficient_test_messages: Messages = [ + {"role": "user", "content": [{"text": "Message 1"}]}, + {"role": "assistant", "content": [{"text": "Response 1"}]}, + {"role": "user", "content": [{"text": "Message 2"}]}, + ] + mock_agent.messages = insufficient_test_messages # 3 messages, preserve_recent_messages=5, so nothing to summarize + + with pytest.raises(ContextWindowOverflowException, match="insufficient messages for summarization"): + manager.reduce_context(mock_agent) + + +def test_reduce_context_insufficient_messages_for_summarization(mock_agent): + """Test reduce_context when there aren't enough messages to summarize.""" + manager = SummarizingConversationManager( + summary_ratio=0.5, + preserve_recent_messages=3, + ) + + insufficient_messages: Messages = [ + {"role": "user", "content": [{"text": "Message 1"}]}, + {"role": "assistant", "content": [{"text": "Response 1"}]}, + {"role": "user", "content": [{"text": "Message 2"}]}, + ] + mock_agent.messages = insufficient_messages + + # This should raise an exception since there aren't enough messages to summarize + with pytest.raises(ContextWindowOverflowException, match="insufficient messages for summarization"): + manager.reduce_context(mock_agent) + + +def test_reduce_context_raises_on_summarization_failure(): + """Test that reduce_context raises exception when summarization fails.""" + # Create an agent that will fail + failing_agent = Mock() + failing_agent.side_effect = Exception("Agent failed") + failing_agent.system_prompt = None + failing_agent_messages: Messages = [ + {"role": "user", "content": [{"text": "Message 1"}]}, + {"role": "assistant", "content": [{"text": "Response 1"}]}, + {"role": "user", "content": [{"text": "Message 2"}]}, + {"role": "assistant", "content": [{"text": "Response 2"}]}, + ] + failing_agent.messages = failing_agent_messages + + manager = SummarizingConversationManager( + summary_ratio=0.5, + preserve_recent_messages=1, + ) + + with patch("strands.agent.conversation_manager.summarizing_conversation_manager.logger") as mock_logger: + with pytest.raises(Exception, match="Agent failed"): + manager.reduce_context(failing_agent) + + # Should log the error + mock_logger.error.assert_called_once() + + +def test_generate_summary(summarizing_manager, mock_agent): + """Test the
_generate_summary method.""" + test_messages: Messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + {"role": "assistant", "content": [{"text": "Hi there"}]}, + ] + + summary = summarizing_manager._generate_summary(test_messages, mock_agent) + + summary_content = summary["content"][0] + assert "text" in summary_content and summary_content["text"] == "This is a summary of the conversation." + + +def test_generate_summary_with_tool_content(summarizing_manager, mock_agent): + """Test summary generation with tool use and results.""" + tool_messages: Messages = [ + {"role": "user", "content": [{"text": "Use a tool"}]}, + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "123", "name": "test_tool", "input": {}}}]}, + { + "role": "user", + "content": [ + {"toolResult": {"toolUseId": "123", "content": [{"text": "Tool output"}], "status": "success"}} + ], + }, + ] + + summary = summarizing_manager._generate_summary(tool_messages, mock_agent) + + summary_content = summary["content"][0] + assert "text" in summary_content and summary_content["text"] == "This is a summary of the conversation." + + +def test_generate_summary_raises_on_agent_failure(): + """Test that _generate_summary raises exception when agent fails.""" + failing_agent = Mock() + failing_agent.side_effect = Exception("Agent failed") + failing_agent.system_prompt = None + empty_failing_messages: Messages = [] + failing_agent.messages = empty_failing_messages + + manager = SummarizingConversationManager() + + messages: Messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + {"role": "assistant", "content": [{"text": "Hi there"}]}, + ] + + # Should raise the exception from the agent + with pytest.raises(Exception, match="Agent failed"): + manager._generate_summary(messages, failing_agent) + + +def test_adjust_split_point_for_tool_pairs(summarizing_manager): + """Test that the split point is adjusted to avoid breaking ToolUse/ToolResult pairs.""" + messages: Messages = [ + {"role": "user", "content": [{"text": "Message 1"}]}, + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "123", "name": "test_tool", "input": {}}}]}, + { + "role": "user", + "content": [ + { + "toolResult": { + "toolUseId": "123", + "content": [{"text": "Tool output"}], + "status": "success", + } + } + ], + }, + {"role": "assistant", "content": [{"text": "Response after tool"}]}, + ] + + # If we try to split at message 2 (the ToolResult), it should move forward to message 3 + adjusted_split = summarizing_manager._adjust_split_point_for_tool_pairs(messages, 2) + assert adjusted_split == 3 # Should move to after the ToolResult + + # If we try to split at message 3, it should be fine (no tool issues) + adjusted_split = summarizing_manager._adjust_split_point_for_tool_pairs(messages, 3) + assert adjusted_split == 3 + + # If we try to split at message 1 (toolUse with following toolResult), it should be valid + adjusted_split = summarizing_manager._adjust_split_point_for_tool_pairs(messages, 1) + assert adjusted_split == 1 # Should be valid because toolResult follows + + +def test_apply_management_no_op(summarizing_manager, mock_agent): + """Test apply_management does not modify messages (no-op behavior).""" + apply_test_messages: Messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + {"role": "assistant", "content": [{"text": "Hi"}]}, + {"role": "user", "content": [{"text": "More messages"}]}, + {"role": "assistant", "content": [{"text": "Even more"}]}, + ] + mock_agent.messages = apply_test_messages + original_messages 
= mock_agent.messages.copy() + + summarizing_manager.apply_management(mock_agent) + + # Should never modify messages - summarization only happens on context overflow + assert mock_agent.messages == original_messages + + +def test_init_with_custom_parameters(): + """Test initialization with custom parameters.""" + mock_agent = create_mock_agent() + + manager = SummarizingConversationManager( + summary_ratio=0.4, + preserve_recent_messages=5, + summarization_agent=mock_agent, + ) + assert manager.summary_ratio == 0.4 + assert manager.preserve_recent_messages == 5 + assert manager.summarization_agent == mock_agent + assert manager.summarization_system_prompt is None + + +def test_init_with_both_agent_and_prompt_raises_error(): + """Test that providing both agent and system prompt raises ValueError.""" + mock_agent = create_mock_agent() + custom_prompt = "Custom summarization prompt" + + with pytest.raises(ValueError, match="Cannot provide both summarization_agent and summarization_system_prompt"): + SummarizingConversationManager( + summarization_agent=mock_agent, + summarization_system_prompt=custom_prompt, + ) + + +def test_uses_summarization_agent_when_provided(): + """Test that summarization_agent is used when provided.""" + summary_agent = create_mock_agent("Custom summary from dedicated agent") + manager = SummarizingConversationManager(summarization_agent=summary_agent) + + messages: Messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + {"role": "assistant", "content": [{"text": "Hi there"}]}, + ] + + parent_agent = create_mock_agent("Parent agent summary") + summary = manager._generate_summary(messages, parent_agent) + + # Should use the dedicated summarization agent, not the parent agent + summary_content = summary["content"][0] + assert "text" in summary_content and summary_content["text"] == "Custom summary from dedicated agent" + + # Assert that the summarization agent was called + summary_agent.call_tracker.assert_called_once() + + +def test_uses_parent_agent_when_no_summarization_agent(): + """Test that parent agent is used when no summarization_agent is provided.""" + manager = SummarizingConversationManager() + + messages: Messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + {"role": "assistant", "content": [{"text": "Hi there"}]}, + ] + + parent_agent = create_mock_agent("Parent agent summary") + summary = manager._generate_summary(messages, parent_agent) + + # Should use the parent agent + summary_content = summary["content"][0] + assert "text" in summary_content and summary_content["text"] == "Parent agent summary" + + # Assert that the parent agent was called + parent_agent.call_tracker.assert_called_once() + + +def test_uses_custom_system_prompt(): + """Test that custom system prompt is used when provided.""" + custom_prompt = "Custom system prompt for summarization" + manager = SummarizingConversationManager(summarization_system_prompt=custom_prompt) + mock_agent = create_mock_agent() + + messages: Messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + {"role": "assistant", "content": [{"text": "Hi there"}]}, + ] + + # Capture the agent's system prompt changes + original_prompt = mock_agent.system_prompt + manager._generate_summary(messages, mock_agent) + + # The agent's system prompt should be restored after summarization + assert mock_agent.system_prompt == original_prompt + + +def test_agent_state_restoration(): + """Test that agent state is properly restored after summarization.""" + manager = SummarizingConversationManager() + 
mock_agent = create_mock_agent() + + # Set initial state + original_system_prompt = "Original system prompt" + original_messages: Messages = [{"role": "user", "content": [{"text": "Original message"}]}] + mock_agent.system_prompt = original_system_prompt + mock_agent.messages = original_messages.copy() + + messages: Messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + {"role": "assistant", "content": [{"text": "Hi there"}]}, + ] + + manager._generate_summary(messages, mock_agent) + + # State should be restored + assert mock_agent.system_prompt == original_system_prompt + assert mock_agent.messages == original_messages + + +def test_agent_state_restoration_on_exception(): + """Test that agent state is restored even when summarization fails.""" + manager = SummarizingConversationManager() + + # Create an agent that fails during summarization + mock_agent = Mock() + mock_agent.system_prompt = "Original prompt" + agent_messages: Messages = [{"role": "user", "content": [{"text": "Original"}]}] + mock_agent.messages = agent_messages + mock_agent.side_effect = Exception("Summarization failed") + + messages: Messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + {"role": "assistant", "content": [{"text": "Hi there"}]}, + ] + + # Should restore state even on exception + with pytest.raises(Exception, match="Summarization failed"): + manager._generate_summary(messages, mock_agent) + + # State should still be restored + assert mock_agent.system_prompt == "Original prompt" + + +def test_reduce_context_tool_pair_adjustment_works_with_forward_search(): + """Test that tool pair adjustment works correctly with the forward-search logic.""" + manager = SummarizingConversationManager( + summary_ratio=0.5, + preserve_recent_messages=1, + ) + + mock_agent = create_mock_agent() + # Create messages where the split point would be adjusted to 0 due to tool pairs + tool_pair_messages: Messages = [ + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "123", "name": "test_tool", "input": {}}}]}, + { + "role": "user", + "content": [ + {"toolResult": {"toolUseId": "123", "content": [{"text": "Tool output"}], "status": "success"}} + ], + }, + {"role": "user", "content": [{"text": "Latest message"}]}, + ] + mock_agent.messages = tool_pair_messages + + # With 3 messages, preserve_recent_messages=1, summary_ratio=0.5: + # messages_to_summarize_count = (3 - 1) * 0.5 = 1 + # But split point adjustment will move forward from the toolUse, potentially increasing count + manager.reduce_context(mock_agent) + # Should have summary + remaining messages + assert len(mock_agent.messages) == 2 + + # First message should be the summary + assert mock_agent.messages[0]["role"] == "assistant" + summary_content = mock_agent.messages[0]["content"][0] + assert "text" in summary_content and "This is a summary of the conversation." 
in summary_content["text"] + + # Last message should be the preserved recent message + assert mock_agent.messages[1]["role"] == "user" + assert mock_agent.messages[1]["content"][0]["text"] == "Latest message" + + +def test_adjust_split_point_exceeds_message_length(summarizing_manager): + """Test that split point exceeding message array length raises exception.""" + messages: Messages = [ + {"role": "user", "content": [{"text": "Message 1"}]}, + {"role": "assistant", "content": [{"text": "Response 1"}]}, + ] + + # Try to split at point 5 when there are only 2 messages + with pytest.raises(ContextWindowOverflowException, match="Split point exceeds message array length"): + summarizing_manager._adjust_split_point_for_tool_pairs(messages, 5) + + +def test_adjust_split_point_equals_message_length(summarizing_manager): + """Test that split point equal to message array length returns unchanged.""" + messages: Messages = [ + {"role": "user", "content": [{"text": "Message 1"}]}, + {"role": "assistant", "content": [{"text": "Response 1"}]}, + ] + + # Split point equals message length (2) - should return unchanged + result = summarizing_manager._adjust_split_point_for_tool_pairs(messages, 2) + assert result == 2 + + +def test_adjust_split_point_no_tool_result_at_split(summarizing_manager): + """Test split point that doesn't contain tool result, ensuring we reach return split_point.""" + messages: Messages = [ + {"role": "user", "content": [{"text": "Message 1"}]}, + {"role": "assistant", "content": [{"text": "Response 1"}]}, + {"role": "user", "content": [{"text": "Message 2"}]}, + ] + + # Split point message is not a tool result, so it should directly return split_point + result = summarizing_manager._adjust_split_point_for_tool_pairs(messages, 1) + assert result == 1 + + +def test_adjust_split_point_tool_result_without_tool_use(summarizing_manager): + """Test that having tool results without tool uses raises exception.""" + messages: Messages = [ + {"role": "user", "content": [{"text": "Message 1"}]}, + { + "role": "user", + "content": [ + {"toolResult": {"toolUseId": "123", "content": [{"text": "Tool output"}], "status": "success"}} + ], + }, + ] + + # Has tool result but no tool use - invalid state + with pytest.raises(ContextWindowOverflowException, match="Unable to trim conversation context!"): + summarizing_manager._adjust_split_point_for_tool_pairs(messages, 1) + + +def test_adjust_split_point_tool_result_moves_to_end(summarizing_manager): + """Test tool result at split point moves forward to valid position at end.""" + messages: Messages = [ + {"role": "user", "content": [{"text": "Message 1"}]}, + {"role": "assistant", "content": [{"text": "Response 1"}]}, + { + "role": "user", + "content": [ + {"toolResult": {"toolUseId": "123", "content": [{"text": "Tool output"}], "status": "success"}} + ], + }, + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "456", "name": "different_tool", "input": {}}}]}, + ] + + # Split at message 2 (toolResult) - will move forward to message 3 (toolUse at end is valid) + result = summarizing_manager._adjust_split_point_for_tool_pairs(messages, 2) + assert result == 3 + + +def test_adjust_split_point_tool_result_no_forward_position(summarizing_manager): + """Test tool result at split point where forward search finds no valid position.""" + messages: Messages = [ + {"role": "user", "content": [{"text": "Message 1"}]}, + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "123", "name": "test_tool", "input": {}}}]}, + {"role": "user", "content": 
[{"text": "Message between"}]}, + { + "role": "user", + "content": [ + {"toolResult": {"toolUseId": "123", "content": [{"text": "Tool output"}], "status": "success"}} + ], + }, + ] + + # Split at message 3 (toolResult) - will try to move forward but no valid position exists + with pytest.raises(ContextWindowOverflowException, match="Unable to trim conversation context!"): + summarizing_manager._adjust_split_point_for_tool_pairs(messages, 3) + + +def test_reduce_context_adjustment_returns_zero(): + """Test that tool pair adjustment can return zero, triggering the check at line 122.""" + manager = SummarizingConversationManager( + summary_ratio=0.5, + preserve_recent_messages=1, + ) + + # Mock the adjustment method to return 0 + def mock_adjust(messages, split_point): + return 0 # This should trigger the <= 0 check at line 122 + + manager._adjust_split_point_for_tool_pairs = mock_adjust + + mock_agent = Mock() + simple_messages: Messages = [ + {"role": "user", "content": [{"text": "Message 1"}]}, + {"role": "assistant", "content": [{"text": "Response 1"}]}, + {"role": "user", "content": [{"text": "Message 2"}]}, + ] + mock_agent.messages = simple_messages + + # The adjustment method will return 0, which should trigger line 122-123 + with pytest.raises(ContextWindowOverflowException, match="insufficient messages for summarization"): + manager.reduce_context(mock_agent) diff --git a/tests/strands/event_loop/test_error_handler.py b/tests/strands/event_loop/test_error_handler.py index 7249adef..fe1b3e9f 100644 --- a/tests/strands/event_loop/test_error_handler.py +++ b/tests/strands/event_loop/test_error_handler.py @@ -4,7 +4,7 @@ import pytest import strands -from strands.types.exceptions import ContextWindowOverflowException, ModelThrottledException +from strands.types.exceptions import ModelThrottledException @pytest.fixture @@ -95,126 +95,3 @@ def test_handle_throttling_error_does_not_exist(callback_handler, kwargs): assert tru_retry == exp_retry and tru_delay == exp_delay callback_handler.assert_called_with(force_stop=True, force_stop_reason=str(exception)) - - -@pytest.mark.parametrize("event_stream_error", ["Input is too long for requested model"], indirect=True) -def test_handle_input_too_long_error( - sdk_event_loop, - event_stream_error, - model, - system_prompt, - tool_config, - callback_handler, - tool_handler, - kwargs, -): - sdk_event_loop.return_value = "success" - - messages = [ - { - "role": "user", - "content": [ - {"toolResult": {"toolUseId": "t1", "status": "success", "content": [{"text": "needs truncation"}]}} - ], - } - ] - - tru_result = strands.event_loop.error_handler.handle_input_too_long_error( - event_stream_error, - messages, - model, - system_prompt, - tool_config, - callback_handler, - tool_handler, - kwargs, - ) - exp_result = "success" - - tru_messages = messages - exp_messages = [ - { - "role": "user", - "content": [ - { - "toolResult": { - "toolUseId": "t1", - "status": "error", - "content": [{"text": "The tool result was too large!"}], - }, - }, - ], - }, - ] - - assert tru_result == exp_result and tru_messages == exp_messages - - sdk_event_loop.assert_called_once_with( - model=model, - system_prompt=system_prompt, - messages=messages, - tool_config=tool_config, - callback_handler=callback_handler, - tool_handler=tool_handler, - request_state="value", - ) - - callback_handler.assert_not_called() - - -@pytest.mark.parametrize("event_stream_error", ["Other error"], indirect=True) -def test_handle_input_too_long_error_does_not_exist( - sdk_event_loop, - 
event_stream_error, - model, - system_prompt, - tool_config, - callback_handler, - tool_handler, - kwargs, -): - messages = [] - - with pytest.raises(ContextWindowOverflowException): - strands.event_loop.error_handler.handle_input_too_long_error( - event_stream_error, - messages, - model, - system_prompt, - tool_config, - callback_handler, - tool_handler, - kwargs, - ) - - sdk_event_loop.assert_not_called() - callback_handler.assert_called_with(force_stop=True, force_stop_reason=str(event_stream_error)) - - -@pytest.mark.parametrize("event_stream_error", ["Input is too long for requested model"], indirect=True) -def test_handle_input_too_long_error_no_tool_result( - sdk_event_loop, - event_stream_error, - model, - system_prompt, - tool_config, - callback_handler, - tool_handler, - kwargs, -): - messages = [] - - with pytest.raises(ContextWindowOverflowException): - strands.event_loop.error_handler.handle_input_too_long_error( - event_stream_error, - messages, - model, - system_prompt, - tool_config, - callback_handler, - tool_handler, - kwargs, - ) - - sdk_event_loop.assert_not_called() - callback_handler.assert_called_with(force_stop=True, force_stop_reason=str(event_stream_error)) diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py index 569462f1..734457aa 100644 --- a/tests/strands/event_loop/test_event_loop.py +++ b/tests/strands/event_loop/test_event_loop.py @@ -11,6 +11,13 @@ from strands.types.exceptions import ContextWindowOverflowException, EventLoopException, ModelThrottledException +@pytest.fixture +def mock_time(): + """Fixture to mock the time module in the error_handler.""" + with unittest.mock.patch.object(strands.event_loop.error_handler, "time") as mock: + yield mock + + @pytest.fixture def model(): return unittest.mock.Mock() @@ -157,7 +164,8 @@ def test_event_loop_cycle_text_response( assert tru_stop_reason == exp_stop_reason and tru_message == exp_message and tru_request_state == exp_request_state -def test_event_loop_cycle_text_response_input_too_long( +def test_event_loop_cycle_text_response_throttling( + mock_time, model, model_id, system_prompt, @@ -168,26 +176,12 @@ def test_event_loop_cycle_text_response_input_too_long( tool_execution_handler, ): model.converse.side_effect = [ - ContextWindowOverflowException(RuntimeError("Input is too long for requested model")), + ModelThrottledException("ThrottlingException | ConverseStream"), [ {"contentBlockDelta": {"delta": {"text": "test text"}}}, {"contentBlockStop": {}}, ], ] - messages.append( - { - "role": "user", - "content": [ - { - "toolResult": { - "toolUseId": "t1", - "status": "success", - "content": [{"text": "2025-04-01T00:00:00"}], - }, - }, - ], - } - ) tru_stop_reason, tru_message, _, tru_request_state = strands.event_loop.event_loop.event_loop_cycle( model=model, @@ -204,10 +198,12 @@ def test_event_loop_cycle_text_response_input_too_long( exp_request_state = {} assert tru_stop_reason == exp_stop_reason and tru_message == exp_message and tru_request_state == exp_request_state + # Verify that sleep was called once with the initial delay + mock_time.sleep.assert_called_once() -@unittest.mock.patch.object(strands.event_loop.error_handler, "time") -def test_event_loop_cycle_text_response_throttling( +def test_event_loop_cycle_exponential_backoff( + mock_time, model, model_id, system_prompt, @@ -217,7 +213,11 @@ def test_event_loop_cycle_text_response_throttling( tool_handler, tool_execution_handler, ): + """Test that the exponential backoff works correctly with 
multiple retries.""" + # Set up the model to raise throttling exceptions multiple times before succeeding model.converse.side_effect = [ + ModelThrottledException("ThrottlingException | ConverseStream"), + ModelThrottledException("ThrottlingException | ConverseStream"), ModelThrottledException("ThrottlingException | ConverseStream"), [ {"contentBlockDelta": {"delta": {"text": "test text"}}}, @@ -235,11 +235,16 @@ def test_event_loop_cycle_text_response_throttling( tool_handler=tool_handler, tool_execution_handler=tool_execution_handler, ) - exp_stop_reason = "end_turn" - exp_message = {"role": "assistant", "content": [{"text": "test text"}]} - exp_request_state = {} - assert tru_stop_reason == exp_stop_reason and tru_message == exp_message and tru_request_state == exp_request_state + # Verify the final response + assert tru_stop_reason == "end_turn" + assert tru_message == {"role": "assistant", "content": [{"text": "test text"}]} + assert tru_request_state == {} + + # Verify that sleep was called with increasing delays + # Initial delay is 4, then 8, then 16 + assert mock_time.sleep.call_count == 3 + assert mock_time.sleep.call_args_list == [call(4), call(8), call(16)] def test_event_loop_cycle_text_response_error( diff --git a/tests/strands/event_loop/test_message_processor.py b/tests/strands/event_loop/test_message_processor.py index 395c71a1..fcf531df 100644 --- a/tests/strands/event_loop/test_message_processor.py +++ b/tests/strands/event_loop/test_message_processor.py @@ -45,84 +45,3 @@ def test_clean_orphaned_empty_tool_uses(messages, expected, expected_messages): result = message_processor.clean_orphaned_empty_tool_uses(test_messages) assert result == expected assert test_messages == expected_messages - - -@pytest.mark.parametrize( - "messages,expected_idx", - [ - ( - [ - {"role": "user", "content": [{"text": "hi"}]}, - {"role": "user", "content": [{"toolResult": {"toolUseId": "1"}}]}, - {"role": "assistant", "content": [{"text": "ok"}]}, - ], - 1, - ), - ( - [ - {"role": "user", "content": [{"text": "hi"}]}, - {"role": "assistant", "content": [{"text": "ok"}]}, - ], - None, - ), - ( - [], - None, - ), - ], -) -def test_find_last_message_with_tool_results(messages, expected_idx): - idx = message_processor.find_last_message_with_tool_results(messages) - assert idx == expected_idx - - -@pytest.mark.parametrize( - "messages,msg_idx,expected_changed,expected_content", - [ - ( - [ - { - "role": "user", - "content": [{"toolResult": {"toolUseId": "1", "status": "ok", "content": [{"text": "big"}]}}], - } - ], - 0, - True, - [ - { - "toolResult": { - "toolUseId": "1", - "status": "error", - "content": [{"text": "The tool result was too large!"}], - } - } - ], - ), - ( - [{"role": "user", "content": [{"text": "no tool result"}]}], - 0, - False, - [{"text": "no tool result"}], - ), - ( - [], - 0, - False, - [], - ), - ( - [{"role": "user", "content": [{"toolResult": {"toolUseId": "1"}}]}], - 2, - False, - [{"toolResult": {"toolUseId": "1"}}], - ), - ], -) -def test_truncate_tool_results(messages, msg_idx, expected_changed, expected_content): - test_messages = copy.deepcopy(messages) - changed = message_processor.truncate_tool_results(test_messages, msg_idx) - assert changed == expected_changed - if 0 <= msg_idx < len(test_messages): - assert test_messages[msg_idx]["content"] == expected_content - else: - assert test_messages == messages diff --git a/tests/strands/telemetry/test_tracer.py b/tests/strands/telemetry/test_tracer.py index 32a4ac0a..030dcd37 100644 --- 
a/tests/strands/telemetry/test_tracer.py +++ b/tests/strands/telemetry/test_tracer.py @@ -4,7 +4,9 @@ from unittest import mock import pytest -from opentelemetry.trace import StatusCode # type: ignore +from opentelemetry.trace import ( + StatusCode, # type: ignore +) from strands.telemetry.tracer import JSONEncoder, Tracer, get_tracer, serialize from strands.types.streaming import Usage @@ -18,13 +20,25 @@ def moto_autouse(moto_env, moto_mock_aws): @pytest.fixture def mock_tracer_provider(): - with mock.patch("strands.telemetry.tracer.TracerProvider") as mock_provider: + with mock.patch("strands.telemetry.tracer.SDKTracerProvider") as mock_provider: yield mock_provider +@pytest.fixture +def mock_is_initialized(): + with mock.patch("strands.telemetry.tracer.Tracer._is_initialized") as mock_is_initialized: + yield mock_is_initialized + + +@pytest.fixture +def mock_get_tracer_provider(): + with mock.patch("strands.telemetry.tracer.trace_api.get_tracer_provider") as mock_get_tracer_provider: + yield mock_get_tracer_provider + + @pytest.fixture def mock_tracer(): - with mock.patch("strands.telemetry.tracer.trace.get_tracer") as mock_get_tracer: + with mock.patch("strands.telemetry.tracer.trace_api.get_tracer") as mock_get_tracer: mock_tracer = mock.MagicMock() mock_get_tracer.return_value = mock_tracer yield mock_tracer @@ -38,13 +52,16 @@ def mock_span(): @pytest.fixture def mock_set_tracer_provider(): - with mock.patch("strands.telemetry.tracer.trace.set_tracer_provider") as mock_set: + with mock.patch("strands.telemetry.tracer.trace_api.set_tracer_provider") as mock_set: yield mock_set @pytest.fixture def mock_otlp_exporter(): - with mock.patch("strands.telemetry.tracer.OTLPSpanExporter") as mock_exporter: + with ( + mock.patch("strands.telemetry.tracer.HAS_OTEL_EXPORTER_MODULE", True), + mock.patch("opentelemetry.exporter.otlp.proto.http.trace_exporter.OTLPSpanExporter") as mock_exporter, + ): yield mock_exporter @@ -104,8 +121,17 @@ def env_with_both(): yield -def test_init_default(): +@pytest.fixture +def mock_initialize(): + with mock.patch("strands.telemetry.tracer.Tracer._initialize_tracer") as mock_initialize: + yield mock_initialize + + +def test_init_default(mock_is_initialized, mock_get_tracer_provider): """Test initializing the Tracer with default parameters.""" + mock_is_initialized.return_value = False + mock_get_tracer_provider.return_value = None + tracer = Tracer() assert tracer.service_name == "strands-agents" @@ -141,9 +167,14 @@ def test_init_with_env_headers(): def test_initialize_tracer_with_console( - mock_tracer_provider, mock_set_tracer_provider, mock_console_exporter, mock_resource + mock_is_initialized, + mock_tracer_provider, + mock_set_tracer_provider, + mock_console_exporter, + mock_resource, ): """Test initializing the tracer with console exporter.""" + mock_is_initialized.return_value = False mock_resource_instance = mock.MagicMock() mock_resource.create.return_value = mock_resource_instance @@ -161,13 +192,21 @@ def test_initialize_tracer_with_console( mock_set_tracer_provider.assert_called_once_with(mock_tracer_provider.return_value) -def test_initialize_tracer_with_otlp(mock_tracer_provider, mock_set_tracer_provider, mock_otlp_exporter, mock_resource): +def test_initialize_tracer_with_otlp( + mock_is_initialized, mock_tracer_provider, mock_set_tracer_provider, mock_otlp_exporter, mock_resource +): """Test initializing the tracer with OTLP exporter.""" + mock_is_initialized.return_value = False + mock_resource_instance = mock.MagicMock() 
mock_resource.create.return_value = mock_resource_instance # Initialize Tracer - Tracer(otlp_endpoint="http://test-endpoint") + with ( + mock.patch("strands.telemetry.tracer.HAS_OTEL_EXPORTER_MODULE", True), + mock.patch("strands.telemetry.tracer.OTLPSpanExporter", mock_otlp_exporter), + ): + Tracer(otlp_endpoint="http://test-endpoint") # Verify the tracer provider was created with correct resource mock_tracer_provider.assert_called_once_with(resource=mock_resource_instance) @@ -191,7 +230,7 @@ def test_start_span_no_tracer(): def test_start_span(mock_tracer): """Test starting a span with attributes.""" - with mock.patch("strands.telemetry.tracer.trace.get_tracer", return_value=mock_tracer): + with mock.patch("strands.telemetry.tracer.trace_api.get_tracer", return_value=mock_tracer): tracer = Tracer(enable_console_export=True) tracer.tracer = mock_tracer @@ -262,7 +301,7 @@ def test_end_span_with_error_message(mock_span): def test_start_model_invoke_span(mock_tracer): """Test starting a model invoke span.""" - with mock.patch("strands.telemetry.tracer.trace.get_tracer", return_value=mock_tracer): + with mock.patch("strands.telemetry.tracer.trace_api.get_tracer", return_value=mock_tracer): tracer = Tracer(enable_console_export=True) tracer.tracer = mock_tracer @@ -300,7 +339,7 @@ def test_end_model_invoke_span(mock_span): def test_start_tool_call_span(mock_tracer): """Test starting a tool call span.""" - with mock.patch("strands.telemetry.tracer.trace.get_tracer", return_value=mock_tracer): + with mock.patch("strands.telemetry.tracer.trace_api.get_tracer", return_value=mock_tracer): tracer = Tracer(enable_console_export=True) tracer.tracer = mock_tracer @@ -338,7 +377,7 @@ def test_end_tool_call_span(mock_span): def test_start_event_loop_cycle_span(mock_tracer): """Test starting an event loop cycle span.""" - with mock.patch("strands.telemetry.tracer.trace.get_tracer", return_value=mock_tracer): + with mock.patch("strands.telemetry.tracer.trace_api.get_tracer", return_value=mock_tracer): tracer = Tracer(enable_console_export=True) tracer.tracer = mock_tracer @@ -373,7 +412,7 @@ def test_end_event_loop_cycle_span(mock_span): def test_start_agent_span(mock_tracer): """Test starting an agent span.""" - with mock.patch("strands.telemetry.tracer.trace.get_tracer", return_value=mock_tracer): + with mock.patch("strands.telemetry.tracer.trace_api.get_tracer", return_value=mock_tracer): tracer = Tracer(enable_console_export=True) tracer.tracer = mock_tracer @@ -464,9 +503,11 @@ def test_get_tracer_parameters(): def test_initialize_tracer_with_invalid_otlp_endpoint( - mock_tracer_provider, mock_set_tracer_provider, mock_otlp_exporter, mock_resource + mock_is_initialized, mock_tracer_provider, mock_set_tracer_provider, mock_otlp_exporter, mock_resource ): """Test initializing the tracer with an invalid OTLP endpoint.""" + mock_is_initialized.return_value = False + mock_resource_instance = mock.MagicMock() mock_resource.create.return_value = mock_resource_instance mock_otlp_exporter.side_effect = Exception("Connection error") @@ -474,7 +515,11 @@ def test_initialize_tracer_with_invalid_otlp_endpoint( # This should not raise an exception, but should log an error # Initialize Tracer - Tracer(otlp_endpoint="http://invalid-endpoint") + with ( + mock.patch("strands.telemetry.tracer.HAS_OTEL_EXPORTER_MODULE", True), + mock.patch("strands.telemetry.tracer.OTLPSpanExporter", mock_otlp_exporter), + ): + Tracer(otlp_endpoint="http://invalid-endpoint") # Verify the tracer provider was created with correct 
resource mock_tracer_provider.assert_called_once_with(resource=mock_resource_instance) @@ -486,6 +531,45 @@ +def test_initialize_tracer_with_missing_module( + mock_is_initialized, mock_tracer_provider, mock_set_tracer_provider, mock_resource +): + """Test initializing the tracer when the OTLP exporter module is missing.""" + mock_is_initialized.return_value = False + + mock_resource_instance = mock.MagicMock() + mock_resource.create.return_value = mock_resource_instance + + # Initialize Tracer with OTLP endpoint but missing module + with ( + mock.patch("strands.telemetry.tracer.HAS_OTEL_EXPORTER_MODULE", False), + pytest.raises(ModuleNotFoundError) as excinfo, + ): + Tracer(otlp_endpoint="http://test-endpoint") + + # Verify the error message + assert "opentelemetry-exporter-otlp-proto-http not detected" in str(excinfo.value) + assert "otel http exporting is currently DISABLED" in str(excinfo.value) + + # Verify the tracer provider was created with correct resource + mock_tracer_provider.assert_called_once_with(resource=mock_resource_instance) + + # Verify set_tracer_provider was not called since an exception was raised + mock_set_tracer_provider.assert_not_called() + + +def test_initialize_tracer_with_custom_tracer_provider(mock_is_initialized, mock_get_tracer_provider, mock_resource): + """Test initializing the tracer with NoOpTracerProvider.""" + mock_is_initialized.return_value = True + tracer = Tracer(otlp_endpoint="http://invalid-endpoint") + + mock_get_tracer_provider.assert_called() + mock_resource.assert_not_called() + + assert tracer.tracer_provider is not None + assert tracer.tracer is not None + + def test_end_span_with_exception_handling(mock_span): """Test ending a span with exception handling.""" tracer = Tracer() @@ -530,7 +614,7 @@ def test_end_tool_call_span_with_none(mock_span): def test_start_model_invoke_span_with_parent(mock_tracer): """Test starting a model invoke span with a parent span.""" - with mock.patch("strands.telemetry.tracer.trace.get_tracer", return_value=mock_tracer): + with mock.patch("strands.telemetry.tracer.trace_api.get_tracer", return_value=mock_tracer): tracer = Tracer(enable_console_export=True) tracer.tracer = mock_tracer diff --git a/tests/strands/tools/test_registry.py b/tests/strands/tools/test_registry.py index 3dca5371..1b274f46 100644 --- a/tests/strands/tools/test_registry.py +++ b/tests/strands/tools/test_registry.py @@ -2,8 +2,11 @@ Tests for the SDK tool registry module. """ +from unittest.mock import MagicMock + import pytest +from strands.tools import PythonAgentTool from strands.tools.registry import ToolRegistry @@ -23,3 +26,20 @@ def test_process_tools_with_invalid_path(): with pytest.raises(ValueError, match=f"Failed to load tool {invalid_path.split('.')[0]}: Tool file not found:.*"): tool_registry.process_tools([invalid_path]) + + +def test_register_tool_with_similar_name_raises(): + tool_1 = PythonAgentTool(tool_name="tool-like-this", tool_spec=MagicMock(), callback=lambda: None) + tool_2 = PythonAgentTool(tool_name="tool_like_this", tool_spec=MagicMock(), callback=lambda: None) + + tool_registry = ToolRegistry() + + tool_registry.register_tool(tool_1) + + with pytest.raises(ValueError) as err: + tool_registry.register_tool(tool_2) + + assert ( + str(err.value) == "Tool name 'tool_like_this' already exists as 'tool-like-this'. " + "Cannot add a duplicate tool which differs by a '-' or '_'" + )