diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS deleted file mode 100644 index 7a4f8317..00000000 --- a/.github/CODEOWNERS +++ /dev/null @@ -1,5 +0,0 @@ -# These owners will be the default owners for everything in -# the repo. Unless a later match takes precedence, -# @strands-agents/contributors will be requested for -# review when someone opens a pull request. -* @strands-agents/maintainers \ No newline at end of file diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index a6efda65..fa894231 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -1,37 +1,38 @@ ## Description -[Provide a detailed description of the changes in this PR] + ## Related Issues -[Link to related issues using #issue-number format] + + ## Documentation PR -[Link to related associated PR in the agent-docs repo] + + ## Type of Change -- Bug fix -- New feature -- Breaking change -- Documentation update -- Other (please describe): -[Choose one of the above types of changes] + +Bug fix +New feature +Breaking change +Documentation update +Other (please describe): ## Testing -[How have you tested the change?] -* `hatch fmt --linter` -* `hatch fmt --formatter` -* `hatch test --all` -* Verify that the changes do not break functionality or introduce warnings in consuming repositories: agents-docs, agents-tools, agents-cli +How have you tested the change? Verify that the changes do not break functionality or introduce warnings in consuming repositories: agents-docs, agents-tools, agents-cli +- [ ] I ran `hatch run prepare` ## Checklist - [ ] I have read the CONTRIBUTING document -- [ ] I have added tests that prove my fix is effective or my feature works +- [ ] I have added any necessary tests that prove my fix is effective or my feature works - [ ] I have updated the documentation accordingly -- [ ] I have added an appropriate example to the documentation to outline the feature +- [ ] I have added an appropriate example to the documentation to outline the feature, or no new docs are needed - [ ] My changes generate no new warnings - [ ] Any dependent changes have been merged and published +---- + By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice. diff --git a/.github/workflows/integration-test.yml b/.github/workflows/integration-test.yml new file mode 100644 index 00000000..39b53c49 --- /dev/null +++ b/.github/workflows/integration-test.yml @@ -0,0 +1,74 @@ +name: Secure Integration test + +on: + pull_request_target: + branches: main + +jobs: + authorization-check: + permissions: read-all + runs-on: ubuntu-latest + outputs: + approval-env: ${{ steps.collab-check.outputs.result }} + steps: + - name: Collaborator Check + uses: actions/github-script@v7 + id: collab-check + with: + result-encoding: string + script: | + try { + const permissionResponse = await github.rest.repos.getCollaboratorPermissionLevel({ + owner: context.repo.owner, + repo: context.repo.repo, + username: context.payload.pull_request.user.login, + }); + const permission = permissionResponse.data.permission; + const hasWriteAccess = ['write', 'admin'].includes(permission); + if (!hasWriteAccess) { + console.log(`User ${context.payload.pull_request.user.login} does not have write access to the repository (permission: ${permission})`); + return "manual-approval" + } else { + console.log(`Verifed ${context.payload.pull_request.user.login} has write access. Auto Approving PR Checks.`) + return "auto-approve" + } + } catch (error) { + console.log(`${context.payload.pull_request.user.login} does not have write access. Requiring Manual Approval to run PR Checks.`) + return "manual-approval" + } + check-access-and-checkout: + runs-on: ubuntu-latest + needs: authorization-check + environment: ${{ needs.authorization-check.outputs.approval-env }} + permissions: + id-token: write + pull-requests: read + contents: read + steps: + - name: Configure Credentials + uses: aws-actions/configure-aws-credentials@v4 + with: + role-to-assume: ${{ secrets.STRANDS_INTEG_TEST_ROLE }} + aws-region: us-east-1 + mask-aws-account-id: true + - name: Checkout head commit + uses: actions/checkout@v4 + with: + ref: ${{ github.event.pull_request.head.sha }} # Pull the commit from the forked repo + persist-credentials: false # Don't persist credentials for subsequent actions + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.10' + - name: Install dependencies + run: | + pip install --no-cache-dir hatch + - name: Run integration tests + env: + AWS_REGION: us-east-1 + AWS_REGION_NAME: us-east-1 # Needed for LiteLLM + id: tests + run: | + hatch test tests-integ + + diff --git a/.github/workflows/pr-and-push.yml b/.github/workflows/pr-and-push.yml index 38e88691..2b2d026f 100644 --- a/.github/workflows/pr-and-push.yml +++ b/.github/workflows/pr-and-push.yml @@ -13,5 +13,7 @@ concurrency: jobs: call-test-lint: uses: ./.github/workflows/test-lint.yml + permissions: + contents: read with: - ref: ${{ github.event.pull_request.head.sha }} \ No newline at end of file + ref: ${{ github.event.pull_request.head.sha }} diff --git a/.github/workflows/pypi-publish-on-release.yml b/.github/workflows/pypi-publish-on-release.yml index 4047f596..8967c552 100644 --- a/.github/workflows/pypi-publish-on-release.yml +++ b/.github/workflows/pypi-publish-on-release.yml @@ -8,11 +8,15 @@ on: jobs: call-test-lint: uses: ./.github/workflows/test-lint.yml + permissions: + contents: read with: ref: ${{ github.event.release.target_commitish }} build: name: Build distribution 📦 + permissions: + contents: read needs: - call-test-lint runs-on: ubuntu-latest @@ -75,4 +79,4 @@ jobs: name: python-package-distributions path: dist/ - name: Publish distribution 📦 to PyPI - uses: pypa/gh-action-pypi-publish@release/v1 \ No newline at end of file + uses: pypa/gh-action-pypi-publish@release/v1 diff --git a/pyproject.toml b/pyproject.toml index bd309732..4bb69ce8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,14 +28,13 @@ classifiers = [ dependencies = [ "boto3>=1.26.0,<2.0.0", "botocore>=1.29.0,<2.0.0", - "docstring_parser>=0.15,<0.16.0", + "docstring_parser>=0.15,<1.0", "mcp>=1.8.0,<2.0.0", "pydantic>=2.0.0,<3.0.0", "typing-extensions>=4.13.2,<5.0.0", "watchdog>=6.0.0,<7.0.0", "opentelemetry-api>=1.30.0,<2.0.0", "opentelemetry-sdk>=1.30.0,<2.0.0", - "opentelemetry-exporter-otlp-proto-http>=1.30.0,<2.0.0", ] [project.urls] @@ -59,7 +58,6 @@ dev = [ "pytest>=8.0.0,<9.0.0", "pytest-asyncio>=0.26.0,<0.27.0", "ruff>=0.4.4,<0.5.0", - "swagger-parser>=1.0.2,<2.0.0", ] docs = [ "sphinx>=5.0.0,<6.0.0", @@ -67,7 +65,7 @@ docs = [ "sphinx-autodoc-typehints>=1.12.0,<2.0.0", ] litellm = [ - "litellm>=1.69.0,<2.0.0", + "litellm>=1.72.6,<2.0.0", ] llamaapi = [ "llama-api-client>=0.1.0,<1.0.0", @@ -78,13 +76,23 @@ ollama = [ openai = [ "openai>=1.68.0,<2.0.0", ] +otel = [ + "opentelemetry-exporter-otlp-proto-http>=1.30.0,<2.0.0", +] +a2a = [ + "a2a-sdk>=0.2.6", + "uvicorn>=0.34.2", + "httpx>=0.28.1", + "fastapi>=0.115.12", + "starlette>=0.46.2", +] [tool.hatch.version] # Tells Hatch to use your version control system (git) to determine the version. source = "vcs" [tool.hatch.envs.hatch-static-analysis] -features = ["anthropic", "litellm", "llamaapi", "ollama", "openai"] +features = ["anthropic", "litellm", "llamaapi", "ollama", "openai", "otel"] dependencies = [ "mypy>=1.15.0,<2.0.0", "ruff>=0.11.6,<0.12.0", @@ -100,14 +108,15 @@ format-fix = [ ] lint-check = [ "ruff check", - "mypy -p src" + # excluding due to A2A and OTEL http exporter dependency conflict + "mypy -p src --exclude src/strands/multiagent" ] lint-fix = [ "ruff check --fix" ] [tool.hatch.envs.hatch-test] -features = ["anthropic", "litellm", "llamaapi", "ollama", "openai"] +features = ["anthropic", "litellm", "llamaapi", "ollama", "openai", "otel"] extra-dependencies = [ "moto>=5.1.0,<6.0.0", "pytest>=8.0.0,<9.0.0", @@ -123,20 +132,35 @@ extra-args = [ [tool.hatch.envs.dev] dev-mode = true -features = ["dev", "docs", "anthropic", "litellm", "llamaapi", "ollama"] +features = ["dev", "docs", "anthropic", "litellm", "llamaapi", "ollama", "otel"] +[tool.hatch.envs.a2a] +dev-mode = true +features = ["dev", "docs", "anthropic", "litellm", "llamaapi", "ollama", "a2a"] +[tool.hatch.envs.a2a.scripts] +run = [ + "pytest{env:HATCH_TEST_ARGS:} tests/multiagent/a2a {args}" +] +run-cov = [ + "pytest{env:HATCH_TEST_ARGS:} tests/multiagent/a2a --cov --cov-config=pyproject.toml {args}" +] +lint-check = [ + "ruff check", + "mypy -p src/strands/multiagent/a2a" +] [[tool.hatch.envs.hatch-test.matrix]] python = ["3.13", "3.12", "3.11", "3.10"] - [tool.hatch.envs.hatch-test.scripts] run = [ - "pytest{env:HATCH_TEST_ARGS:} {args}" + # excluding due to A2A and OTEL http exporter dependency conflict + "pytest{env:HATCH_TEST_ARGS:} {args} --ignore=tests/multiagent/a2a" ] run-cov = [ - "pytest{env:HATCH_TEST_ARGS:} --cov --cov-config=pyproject.toml {args}" + # excluding due to A2A and OTEL http exporter dependency conflict + "pytest{env:HATCH_TEST_ARGS:} --cov --cov-config=pyproject.toml {args} --ignore=tests/multiagent/a2a" ] cov-combine = [] @@ -165,7 +189,15 @@ test = [ test-integ = [ "hatch test tests-integ {args}" ] - +prepare = [ + "hatch fmt --linter", + "hatch fmt --formatter", + "hatch test --all" +] +test-a2a = [ + # required to run manually due to A2A and OTEL http exporter dependency conflict + "hatch -e a2a run run {args}" +] [tool.mypy] python_version = "3.10" @@ -249,4 +281,4 @@ style = [ ["instruction", ""], ["text", ""], ["disabled", "fg:#858585 italic"] -] +] \ No newline at end of file diff --git a/src/strands/agent/__init__.py b/src/strands/agent/__init__.py index 4d2fa1fe..6618d332 100644 --- a/src/strands/agent/__init__.py +++ b/src/strands/agent/__init__.py @@ -8,7 +8,12 @@ from .agent import Agent from .agent_result import AgentResult -from .conversation_manager import ConversationManager, NullConversationManager, SlidingWindowConversationManager +from .conversation_manager import ( + ConversationManager, + NullConversationManager, + SlidingWindowConversationManager, + SummarizingConversationManager, +) __all__ = [ "Agent", @@ -16,4 +21,5 @@ "ConversationManager", "NullConversationManager", "SlidingWindowConversationManager", + "SummarizingConversationManager", ] diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index bfa83fe2..4ecc42a9 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -16,10 +16,11 @@ import random from concurrent.futures import ThreadPoolExecutor from threading import Thread -from typing import Any, AsyncIterator, Callable, Dict, List, Mapping, Optional, Union +from typing import Any, AsyncIterator, Callable, Dict, List, Mapping, Optional, Type, TypeVar, Union from uuid import uuid4 from opentelemetry import trace +from pydantic import BaseModel from ..event_loop.event_loop import event_loop_cycle from ..handlers.callback_handler import CompositeCallbackHandler, PrintingCallbackHandler, null_callback_handler @@ -43,6 +44,19 @@ logger = logging.getLogger(__name__) +# TypeVar for generic structured output +T = TypeVar("T", bound=BaseModel) + + +# Sentinel class and object to distinguish between explicit None and default parameter value +class _DefaultCallbackHandlerSentinel: + """Sentinel class to distinguish between explicit None and default parameter value.""" + + pass + + +_DEFAULT_CALLBACK_HANDLER = _DefaultCallbackHandlerSentinel() + class Agent: """Core Agent interface. @@ -70,10 +84,11 @@ def __init__(self, agent: "Agent") -> None: # agent tools and thus break their execution. self._agent = agent - def __getattr__(self, name: str) -> Callable: + def __getattr__(self, name: str) -> Callable[..., Any]: """Call tool as a function. This method enables the method-style interface (e.g., `agent.tool.tool_name(param="value")`). + It matches underscore-separated names to hyphenated tool names (e.g., 'some_thing' matches 'some-thing'). Args: name: The name of the attribute (tool) being accessed. @@ -82,9 +97,30 @@ def __getattr__(self, name: str) -> Callable: A function that when called will execute the named tool. Raises: - AttributeError: If no tool with the given name exists. + AttributeError: If no tool with the given name exists or if multiple tools match the given name. """ + def find_normalized_tool_name() -> Optional[str]: + """Lookup the tool represented by name, replacing characters with underscores as necessary.""" + tool_registry = self._agent.tool_registry.registry + + if tool_registry.get(name, None): + return name + + # If the desired name contains underscores, it might be a placeholder for characters that can't be + # represented as python identifiers but are valid as tool names, such as dashes. In that case, find + # all tools that can be represented with the normalized name + if "_" in name: + filtered_tools = [ + tool_name for (tool_name, tool) in tool_registry.items() if tool_name.replace("-", "_") == name + ] + + # The registry itself defends against similar names, so we can just take the first match + if filtered_tools: + return filtered_tools[0] + + raise AttributeError(f"Tool '{name}' not found") + def caller(**kwargs: Any) -> Any: """Call a tool directly by name. @@ -105,14 +141,13 @@ def caller(**kwargs: Any) -> Any: Raises: AttributeError: If the tool doesn't exist. """ - if name not in self._agent.tool_registry.registry: - raise AttributeError(f"Tool '{name}' not found") + normalized_name = find_normalized_tool_name() # Create unique tool ID and set up the tool request tool_id = f"tooluse_{name}_{random.randint(100000000, 999999999)}" tool_use = { "toolUseId": tool_id, - "name": name, + "name": normalized_name, "input": kwargs.copy(), } @@ -177,12 +212,17 @@ def __init__( messages: Optional[Messages] = None, tools: Optional[List[Union[str, Dict[str, str], Any]]] = None, system_prompt: Optional[str] = None, - callback_handler: Optional[Callable] = PrintingCallbackHandler(), + callback_handler: Optional[ + Union[Callable[..., Any], _DefaultCallbackHandlerSentinel] + ] = _DEFAULT_CALLBACK_HANDLER, conversation_manager: Optional[ConversationManager] = None, max_parallel_tools: int = os.cpu_count() or 1, record_direct_tool_call: bool = True, load_tools_from_directory: bool = True, trace_attributes: Optional[Mapping[str, AttributeValue]] = None, + *, + name: Optional[str] = None, + description: Optional[str] = None, ): """Initialize the Agent with the specified configuration. @@ -204,7 +244,8 @@ def __init__( system_prompt: System prompt to guide model behavior. If None, the model will behave according to its default settings. callback_handler: Callback for processing events as they happen during agent execution. - Defaults to strands.handlers.PrintingCallbackHandler if None. + If not provided (using the default), a new PrintingCallbackHandler instance is created. + If explicitly set to None, null_callback_handler is used. conversation_manager: Manager for conversation history and context window. Defaults to strands.agent.conversation_manager.SlidingWindowConversationManager if None. max_parallel_tools: Maximum number of tools to run in parallel when the model returns multiple tool calls. @@ -214,6 +255,10 @@ def __init__( load_tools_from_directory: Whether to load and automatically reload tools in the `./tools/` directory. Defaults to True. trace_attributes: Custom trace attributes to apply to the agent's trace span. + name: name of the Agent + Defaults to None. + description: description of what the Agent does + Defaults to None. Raises: ValueError: If max_parallel_tools is less than 1. @@ -222,7 +267,17 @@ def __init__( self.messages = messages if messages is not None else [] self.system_prompt = system_prompt - self.callback_handler = callback_handler or null_callback_handler + + # If not provided, create a new PrintingCallbackHandler instance + # If explicitly set to None, use null_callback_handler + # Otherwise use the passed callback_handler + self.callback_handler: Union[Callable[..., Any], PrintingCallbackHandler] + if isinstance(callback_handler, _DefaultCallbackHandlerSentinel): + self.callback_handler = PrintingCallbackHandler() + elif callback_handler is None: + self.callback_handler = null_callback_handler + else: + self.callback_handler = callback_handler self.conversation_manager = conversation_manager if conversation_manager else SlidingWindowConversationManager() @@ -264,8 +319,9 @@ def __init__( # Initialize tracer instance (no-op if not configured) self.tracer = get_tracer() self.trace_span: Optional[trace.Span] = None - self.tool_caller = Agent.ToolCaller(self) + self.name = name + self.description = description @property def tool(self) -> ToolCaller: @@ -343,6 +399,32 @@ def __call__(self, prompt: str, **kwargs: Any) -> AgentResult: # Re-raise the exception to preserve original behavior raise + def structured_output(self, output_model: Type[T], prompt: Optional[str] = None) -> T: + """This method allows you to get structured output from the agent. + + If you pass in a prompt, it will be added to the conversation history and the agent will respond to it. + If you don't pass in a prompt, it will use only the conversation history to respond. + If no conversation history exists and no prompt is provided, an error will be raised. + + For smaller models, you may want to use the optional prompt string to add additional instructions to explicitly + instruct the model to output the structured data. + + Args: + output_model(Type[BaseModel]): The output model (a JSON schema written as a Pydantic BaseModel) + that the agent will use when responding. + prompt(Optional[str]): The prompt to use for the agent. + """ + messages = self.messages + if not messages and not prompt: + raise ValueError("No conversation history or prompt provided") + + # add the prompt as the last message + if prompt: + messages.append({"role": "user", "content": [{"text": prompt}]}) + + # get the structured output from the model + return self.model.structured_output(output_model, messages, self.callback_handler) + async def stream_async(self, prompt: str, **kwargs: Any) -> AsyncIterator[Any]: """Process a natural language prompt and yield events as an async iterator. @@ -415,7 +497,7 @@ def target_callback() -> None: thread.join() def _run_loop( - self, prompt: str, kwargs: Any, supplementary_callback_handler: Optional[Callable] = None + self, prompt: str, kwargs: Dict[str, Any], supplementary_callback_handler: Optional[Callable[..., Any]] = None ) -> AgentResult: """Execute the agent's event loop with the given prompt and parameters.""" try: @@ -441,7 +523,7 @@ def _run_loop( finally: self.conversation_manager.apply_management(self) - def _execute_event_loop_cycle(self, callback_handler: Callable, kwargs: dict[str, Any]) -> AgentResult: + def _execute_event_loop_cycle(self, callback_handler: Callable[..., Any], kwargs: Dict[str, Any]) -> AgentResult: """Execute the event loop cycle with retry logic for context window limits. This internal method handles the execution of the event loop cycle and implements diff --git a/src/strands/agent/conversation_manager/__init__.py b/src/strands/agent/conversation_manager/__init__.py index 68541877..c5962321 100644 --- a/src/strands/agent/conversation_manager/__init__.py +++ b/src/strands/agent/conversation_manager/__init__.py @@ -6,6 +6,8 @@ - NullConversationManager: A no-op implementation that does not modify conversation history - SlidingWindowConversationManager: An implementation that maintains a sliding window of messages to control context size while preserving conversation coherence +- SummarizingConversationManager: An implementation that summarizes older context instead + of simply trimming it Conversation managers help control memory usage and context length while maintaining relevant conversation state, which is critical for effective agent interactions. @@ -14,5 +16,11 @@ from .conversation_manager import ConversationManager from .null_conversation_manager import NullConversationManager from .sliding_window_conversation_manager import SlidingWindowConversationManager +from .summarizing_conversation_manager import SummarizingConversationManager -__all__ = ["ConversationManager", "NullConversationManager", "SlidingWindowConversationManager"] +__all__ = [ + "ConversationManager", + "NullConversationManager", + "SlidingWindowConversationManager", + "SummarizingConversationManager", +] diff --git a/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py b/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py index 3381247c..53ac374f 100644 --- a/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py +++ b/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py @@ -44,14 +44,16 @@ class SlidingWindowConversationManager(ConversationManager): invalid window states. """ - def __init__(self, window_size: int = 40): + def __init__(self, window_size: int = 40, should_truncate_results: bool = True): """Initialize the sliding window conversation manager. Args: window_size: Maximum number of messages to keep in the agent's history. Defaults to 40 messages. + should_truncate_results: Truncate tool results when a message is too large for the model's context window """ self.window_size = window_size + self.should_truncate_results = should_truncate_results def apply_management(self, agent: "Agent") -> None: """Apply the sliding window to the agent's messages array to maintain a manageable history size. @@ -127,6 +129,19 @@ def reduce_context(self, agent: "Agent", e: Optional[Exception] = None) -> None: converted. """ messages = agent.messages + + # Try to truncate the tool result first + last_message_idx_with_tool_results = self._find_last_message_with_tool_results(messages) + if last_message_idx_with_tool_results is not None and self.should_truncate_results: + logger.debug( + "message_index=<%s> | found message with tool results at index", last_message_idx_with_tool_results + ) + results_truncated = self._truncate_tool_results(messages, last_message_idx_with_tool_results) + if results_truncated: + logger.debug("message_index=<%s> | tool results truncated", last_message_idx_with_tool_results) + return + + # Try to trim index id when tool result cannot be truncated anymore # If the number of messages is less than the window_size, then we default to 2, otherwise, trim to window size trim_index = 2 if len(messages) <= self.window_size else len(messages) - self.window_size @@ -151,3 +166,69 @@ def reduce_context(self, agent: "Agent", e: Optional[Exception] = None) -> None: # Overwrite message history messages[:] = messages[trim_index:] + + def _truncate_tool_results(self, messages: Messages, msg_idx: int) -> bool: + """Truncate tool results in a message to reduce context size. + + When a message contains tool results that are too large for the model's context window, this function + replaces the content of those tool results with a simple error message. + + Args: + messages: The conversation message history. + msg_idx: Index of the message containing tool results to truncate. + + Returns: + True if any changes were made to the message, False otherwise. + """ + if msg_idx >= len(messages) or msg_idx < 0: + return False + + message = messages[msg_idx] + changes_made = False + tool_result_too_large_message = "The tool result was too large!" + for i, content in enumerate(message.get("content", [])): + if isinstance(content, dict) and "toolResult" in content: + tool_result_content_text = next( + (item["text"] for item in content["toolResult"]["content"] if "text" in item), + "", + ) + # make the overwriting logic togglable + if ( + message["content"][i]["toolResult"]["status"] == "error" + and tool_result_content_text == tool_result_too_large_message + ): + logger.info("ToolResult has already been updated, skipping overwrite") + return False + # Update status to error with informative message + message["content"][i]["toolResult"]["status"] = "error" + message["content"][i]["toolResult"]["content"] = [{"text": tool_result_too_large_message}] + changes_made = True + + return changes_made + + def _find_last_message_with_tool_results(self, messages: Messages) -> Optional[int]: + """Find the index of the last message containing tool results. + + This is useful for identifying messages that might need to be truncated to reduce context size. + + Args: + messages: The conversation message history. + + Returns: + Index of the last message with tool results, or None if no such message exists. + """ + # Iterate backwards through all messages (from newest to oldest) + for idx in range(len(messages) - 1, -1, -1): + # Check if this message has any content with toolResult + current_message = messages[idx] + has_tool_result = False + + for content in current_message.get("content", []): + if isinstance(content, dict) and "toolResult" in content: + has_tool_result = True + break + + if has_tool_result: + return idx + + return None diff --git a/src/strands/agent/conversation_manager/summarizing_conversation_manager.py b/src/strands/agent/conversation_manager/summarizing_conversation_manager.py new file mode 100644 index 00000000..a6b112dd --- /dev/null +++ b/src/strands/agent/conversation_manager/summarizing_conversation_manager.py @@ -0,0 +1,222 @@ +"""Summarizing conversation history management with configurable options.""" + +import logging +from typing import TYPE_CHECKING, List, Optional + +from ...types.content import Message +from ...types.exceptions import ContextWindowOverflowException +from .conversation_manager import ConversationManager + +if TYPE_CHECKING: + from ..agent import Agent + + +logger = logging.getLogger(__name__) + + +DEFAULT_SUMMARIZATION_PROMPT = """You are a conversation summarizer. Provide a concise summary of the conversation \ +history. + +Format Requirements: +- You MUST create a structured and concise summary in bullet-point format. +- You MUST NOT respond conversationally. +- You MUST NOT address the user directly. + +Task: +Your task is to create a structured summary document: +- It MUST contain bullet points with key topics and questions covered +- It MUST contain bullet points for all significant tools executed and their results +- It MUST contain bullet points for any code or technical information shared +- It MUST contain a section of key insights gained +- It MUST format the summary in the third person + +Example format: + +## Conversation Summary +* Topic 1: Key information +* Topic 2: Key information +* +## Tools Executed +* Tool X: Result Y""" + + +class SummarizingConversationManager(ConversationManager): + """Implements a summarizing window manager. + + This manager provides a configurable option to summarize older context instead of + simply trimming it, helping preserve important information while staying within + context limits. + """ + + def __init__( + self, + summary_ratio: float = 0.3, + preserve_recent_messages: int = 10, + summarization_agent: Optional["Agent"] = None, + summarization_system_prompt: Optional[str] = None, + ): + """Initialize the summarizing conversation manager. + + Args: + summary_ratio: Ratio of messages to summarize vs keep when context overflow occurs. + Value between 0.1 and 0.8. Defaults to 0.3 (summarize 30% of oldest messages). + preserve_recent_messages: Minimum number of recent messages to always keep. + Defaults to 10 messages. + summarization_agent: Optional agent to use for summarization instead of the parent agent. + If provided, this agent can use tools as part of the summarization process. + summarization_system_prompt: Optional system prompt override for summarization. + If None, uses the default summarization prompt. + """ + if summarization_agent is not None and summarization_system_prompt is not None: + raise ValueError( + "Cannot provide both summarization_agent and summarization_system_prompt. " + "Agents come with their own system prompt." + ) + + self.summary_ratio = max(0.1, min(0.8, summary_ratio)) + self.preserve_recent_messages = preserve_recent_messages + self.summarization_agent = summarization_agent + self.summarization_system_prompt = summarization_system_prompt + + def apply_management(self, agent: "Agent") -> None: + """Apply management strategy to conversation history. + + For the summarizing conversation manager, no proactive management is performed. + Summarization only occurs when there's a context overflow that triggers reduce_context. + + Args: + agent: The agent whose conversation history will be managed. + The agent's messages list is modified in-place. + """ + # No proactive management - summarization only happens on context overflow + pass + + def reduce_context(self, agent: "Agent", e: Optional[Exception] = None) -> None: + """Reduce context using summarization. + + Args: + agent: The agent whose conversation history will be reduced. + The agent's messages list is modified in-place. + e: The exception that triggered the context reduction, if any. + + Raises: + ContextWindowOverflowException: If the context cannot be summarized. + """ + try: + # Calculate how many messages to summarize + messages_to_summarize_count = max(1, int(len(agent.messages) * self.summary_ratio)) + + # Ensure we don't summarize recent messages + messages_to_summarize_count = min( + messages_to_summarize_count, len(agent.messages) - self.preserve_recent_messages + ) + + if messages_to_summarize_count <= 0: + raise ContextWindowOverflowException("Cannot summarize: insufficient messages for summarization") + + # Adjust split point to avoid breaking ToolUse/ToolResult pairs + messages_to_summarize_count = self._adjust_split_point_for_tool_pairs( + agent.messages, messages_to_summarize_count + ) + + if messages_to_summarize_count <= 0: + raise ContextWindowOverflowException("Cannot summarize: insufficient messages for summarization") + + # Extract messages to summarize + messages_to_summarize = agent.messages[:messages_to_summarize_count] + remaining_messages = agent.messages[messages_to_summarize_count:] + + # Generate summary + summary_message = self._generate_summary(messages_to_summarize, agent) + + # Replace the summarized messages with the summary + agent.messages[:] = [summary_message] + remaining_messages + + except Exception as summarization_error: + logger.error("Summarization failed: %s", summarization_error) + raise summarization_error from e + + def _generate_summary(self, messages: List[Message], agent: "Agent") -> Message: + """Generate a summary of the provided messages. + + Args: + messages: The messages to summarize. + agent: The agent instance to use for summarization. + + Returns: + A message containing the conversation summary. + + Raises: + Exception: If summary generation fails. + """ + # Choose which agent to use for summarization + summarization_agent = self.summarization_agent if self.summarization_agent is not None else agent + + # Save original system prompt and messages to restore later + original_system_prompt = summarization_agent.system_prompt + original_messages = summarization_agent.messages.copy() + + try: + # Only override system prompt if no agent was provided during initialization + if self.summarization_agent is None: + # Use custom system prompt if provided, otherwise use default + system_prompt = ( + self.summarization_system_prompt + if self.summarization_system_prompt is not None + else DEFAULT_SUMMARIZATION_PROMPT + ) + # Temporarily set the system prompt for summarization + summarization_agent.system_prompt = system_prompt + summarization_agent.messages = messages + + # Use the agent to generate summary with rich content (can use tools if needed) + result = summarization_agent("Please summarize this conversation.") + + return result.message + + finally: + # Restore original agent state + summarization_agent.system_prompt = original_system_prompt + summarization_agent.messages = original_messages + + def _adjust_split_point_for_tool_pairs(self, messages: List[Message], split_point: int) -> int: + """Adjust the split point to avoid breaking ToolUse/ToolResult pairs. + + Uses the same logic as SlidingWindowConversationManager for consistency. + + Args: + messages: The full list of messages. + split_point: The initially calculated split point. + + Returns: + The adjusted split point that doesn't break ToolUse/ToolResult pairs. + + Raises: + ContextWindowOverflowException: If no valid split point can be found. + """ + if split_point > len(messages): + raise ContextWindowOverflowException("Split point exceeds message array length") + + if split_point == len(messages): + return split_point + + # Find the next valid split_point + while split_point < len(messages): + if ( + # Oldest message cannot be a toolResult because it needs a toolUse preceding it + any("toolResult" in content for content in messages[split_point]["content"]) + or ( + # Oldest message can be a toolUse only if a toolResult immediately follows it. + any("toolUse" in content for content in messages[split_point]["content"]) + and split_point + 1 < len(messages) + and not any("toolResult" in content for content in messages[split_point + 1]["content"]) + ) + ): + split_point += 1 + else: + break + else: + # If we didn't find a valid split_point, then we throw + raise ContextWindowOverflowException("Unable to trim conversation context!") + + return split_point diff --git a/src/strands/event_loop/error_handler.py b/src/strands/event_loop/error_handler.py index a5c85668..6dc0d9ee 100644 --- a/src/strands/event_loop/error_handler.py +++ b/src/strands/event_loop/error_handler.py @@ -6,14 +6,9 @@ import logging import time -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, Tuple -from ..telemetry.metrics import EventLoopMetrics -from ..types.content import Message, Messages -from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException -from ..types.models import Model -from ..types.streaming import StopReason -from .message_processor import find_last_message_with_tool_results, truncate_tool_results +from ..types.exceptions import ModelThrottledException logger = logging.getLogger(__name__) @@ -59,63 +54,3 @@ def handle_throttling_error( callback_handler(force_stop=True, force_stop_reason=str(e)) return False, current_delay - - -def handle_input_too_long_error( - e: ContextWindowOverflowException, - messages: Messages, - model: Model, - system_prompt: Optional[str], - tool_config: Any, - callback_handler: Any, - tool_handler: Any, - kwargs: Dict[str, Any], -) -> Tuple[StopReason, Message, EventLoopMetrics, Any]: - """Handle 'Input is too long' errors by truncating tool results. - - When a context window overflow exception occurs (input too long for the model), this function attempts to recover - by finding and truncating the most recent tool results in the conversation history. If truncation is successful, the - function will make a call to the event loop. - - Args: - e: The ContextWindowOverflowException that occurred. - messages: The conversation message history. - model: Model provider for running inference. - system_prompt: System prompt for the model. - tool_config: Tool configuration for the conversation. - callback_handler: Callback for processing events as they happen. - tool_handler: Handler for tool execution. - kwargs: Additional arguments for the event loop. - - Returns: - The results from the event loop call if successful. - - Raises: - ContextWindowOverflowException: If messages cannot be truncated. - """ - from .event_loop import recurse_event_loop # Import here to avoid circular imports - - # Find the last message with tool results - last_message_with_tool_results = find_last_message_with_tool_results(messages) - - # If we found a message with toolResult - if last_message_with_tool_results is not None: - logger.debug("message_index=<%s> | found message with tool results at index", last_message_with_tool_results) - - # Truncate the tool results in this message - truncate_tool_results(messages, last_message_with_tool_results) - - return recurse_event_loop( - model=model, - system_prompt=system_prompt, - messages=messages, - tool_config=tool_config, - callback_handler=callback_handler, - tool_handler=tool_handler, - **kwargs, - ) - - # If we can't handle this error, pass it up - callback_handler(force_stop=True, force_stop_reason=str(e)) - logger.error("an exception occurred in event_loop_cycle | %s", e) - raise ContextWindowOverflowException() from e diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index 23d7bd0f..9580ea35 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -22,7 +22,7 @@ from ..types.models import Model from ..types.streaming import Metrics, StopReason from ..types.tools import ToolConfig, ToolHandler, ToolResult, ToolUse -from .error_handler import handle_input_too_long_error, handle_throttling_error +from .error_handler import handle_throttling_error from .message_processor import clean_orphaned_empty_tool_uses from .streaming import stream_messages @@ -33,23 +33,6 @@ MAX_DELAY = 240 # 4 minutes -def initialize_state(**kwargs: Any) -> Any: - """Initialize the request state if not present. - - Creates an empty request_state dictionary if one doesn't already exist in the - provided keyword arguments. - - Args: - **kwargs: Keyword arguments that may contain a request_state. - - Returns: - The updated kwargs dictionary with request_state initialized if needed. - """ - if "request_state" not in kwargs: - kwargs["request_state"] = {} - return kwargs - - def event_loop_cycle( model: Model, system_prompt: Optional[str], @@ -105,10 +88,11 @@ def event_loop_cycle( kwargs["event_loop_cycle_id"] = uuid.uuid4() event_loop_metrics: EventLoopMetrics = kwargs.get("event_loop_metrics", EventLoopMetrics()) - # Initialize state and get cycle trace - kwargs = initialize_state(**kwargs) - cycle_start_time, cycle_trace = event_loop_metrics.start_cycle() + if "request_state" not in kwargs: + kwargs["request_state"] = {} + attributes = {"event_loop_cycle_id": str(kwargs.get("event_loop_cycle_id"))} + cycle_start_time, cycle_trace = event_loop_metrics.start_cycle(attributes=attributes) kwargs["event_loop_cycle_trace"] = cycle_trace callback_handler(start=True) @@ -136,6 +120,7 @@ def event_loop_cycle( metrics: Metrics # Retry loop for handling throttling exceptions + current_delay = INITIAL_DELAY for attempt in range(MAX_ATTEMPTS): model_id = model.config.get("model_id") if hasattr(model, "config") else None model_invoke_span = tracer.start_model_invoke_span( @@ -145,14 +130,19 @@ def event_loop_cycle( ) try: - stop_reason, message, usage, metrics, kwargs["request_state"] = stream_messages( - model, - system_prompt, - messages, - tool_config, - callback_handler, - **kwargs, - ) + # TODO: As part of the migration to async-iterator, we will continue moving callback_handler calls up the + # call stack. At this point, we converted all events that were previously passed to the handler in + # `stream_messages` into yielded events that now have the "callback" key. To maintain backwards + # compatability, we need to combine the event with kwargs before passing to the handler. This we will + # revisit when migrating to strongly typed events. + for event in stream_messages(model, system_prompt, messages, tool_config): + if "callback" in event: + inputs = {**event["callback"], **(kwargs if "delta" in event["callback"] else {})} + callback_handler(**inputs) + else: + stop_reason, message, usage, metrics = event["stop"] + kwargs.setdefault("request_state", {}) + if model_invoke_span: tracer.end_model_invoke_span(model_invoke_span, message, usage) break # Success! Break out of retry loop @@ -160,16 +150,7 @@ def event_loop_cycle( except ContextWindowOverflowException as e: if model_invoke_span: tracer.end_span_with_error(model_invoke_span, str(e), e) - return handle_input_too_long_error( - e, - messages, - model, - system_prompt, - tool_config, - callback_handler, - tool_handler, - kwargs, - ) + raise e except ModelThrottledException as e: if model_invoke_span: @@ -177,7 +158,7 @@ def event_loop_cycle( # Handle throttling errors with exponential backoff should_retry, current_delay = handle_throttling_error( - e, attempt, MAX_ATTEMPTS, INITIAL_DELAY, MAX_DELAY, callback_handler, kwargs + e, attempt, MAX_ATTEMPTS, current_delay, MAX_DELAY, callback_handler, kwargs ) if should_retry: continue @@ -235,7 +216,7 @@ def event_loop_cycle( ) # End the cycle and return results - event_loop_metrics.end_cycle(cycle_start_time, cycle_trace) + event_loop_metrics.end_cycle(cycle_start_time, cycle_trace, attributes) if cycle_span: tracer.end_event_loop_cycle_span( span=cycle_span, @@ -248,6 +229,10 @@ def event_loop_cycle( # Don't invoke the callback_handler or log the exception - we already did it when we # raised the exception and we don't need that duplication. raise + except ContextWindowOverflowException as e: + if cycle_span: + tracer.end_span_with_error(cycle_span, str(e), e) + raise e except Exception as e: if cycle_span: tracer.end_span_with_error(cycle_span, str(e), e) @@ -314,26 +299,6 @@ def recurse_event_loop( ) -def prepare_next_cycle(kwargs: Dict[str, Any], event_loop_metrics: EventLoopMetrics) -> Dict[str, Any]: - """Prepare state for the next event loop cycle. - - Updates the keyword arguments with the current event loop metrics and stores the current cycle ID as the parent - cycle ID for the next cycle. This maintains the parent-child relationship between cycles for tracing and metrics. - - Args: - kwargs: Current keyword arguments containing event loop state. - event_loop_metrics: The metrics object tracking event loop execution. - - Returns: - Updated keyword arguments ready for the next cycle. - """ - # Store parent cycle ID - kwargs["event_loop_metrics"] = event_loop_metrics - kwargs["event_loop_parent_cycle_id"] = kwargs["event_loop_cycle_id"] - - return kwargs - - def _handle_tool_execution( stop_reason: StopReason, message: Message, @@ -374,7 +339,7 @@ def _handle_tool_execution( kwargs (Dict[str, Any]): Additional keyword arguments, including request state. Returns: - Tuple[StopReason, Message, EventLoopMetrics, Dict[str, Any]]: + Tuple[StopReason, Message, EventLoopMetrics, Dict[str, Any]]: - The stop reason, - The updated message, - The updated event loop metrics, @@ -384,7 +349,6 @@ def _handle_tool_execution( if not tool_uses: return stop_reason, message, event_loop_metrics, kwargs["request_state"] - tool_handler_process = partial( tool_handler.process, messages=messages, @@ -407,7 +371,9 @@ def _handle_tool_execution( parallel_tool_executor=tool_execution_handler, ) - kwargs = prepare_next_cycle(kwargs, event_loop_metrics) + # Store parent cycle ID for the next cycle + kwargs["event_loop_metrics"] = event_loop_metrics + kwargs["event_loop_parent_cycle_id"] = kwargs["event_loop_cycle_id"] tool_result_message: Message = { "role": "user", diff --git a/src/strands/event_loop/message_processor.py b/src/strands/event_loop/message_processor.py index 61e1c1d7..4e1a39dc 100644 --- a/src/strands/event_loop/message_processor.py +++ b/src/strands/event_loop/message_processor.py @@ -5,7 +5,7 @@ """ import logging -from typing import Dict, Optional, Set, Tuple +from typing import Dict, Set, Tuple from ..types.content import Messages @@ -103,60 +103,3 @@ def clean_orphaned_empty_tool_uses(messages: Messages) -> bool: logger.warning("failed to fix orphaned tool use | %s", e) return True - - -def find_last_message_with_tool_results(messages: Messages) -> Optional[int]: - """Find the index of the last message containing tool results. - - This is useful for identifying messages that might need to be truncated to reduce context size. - - Args: - messages: The conversation message history. - - Returns: - Index of the last message with tool results, or None if no such message exists. - """ - # Iterate backwards through all messages (from newest to oldest) - for idx in range(len(messages) - 1, -1, -1): - # Check if this message has any content with toolResult - current_message = messages[idx] - has_tool_result = False - - for content in current_message.get("content", []): - if isinstance(content, dict) and "toolResult" in content: - has_tool_result = True - break - - if has_tool_result: - return idx - - return None - - -def truncate_tool_results(messages: Messages, msg_idx: int) -> bool: - """Truncate tool results in a message to reduce context size. - - When a message contains tool results that are too large for the model's context window, this function replaces the - content of those tool results with a simple error message. - - Args: - messages: The conversation message history. - msg_idx: Index of the message containing tool results to truncate. - - Returns: - True if any changes were made to the message, False otherwise. - """ - if msg_idx >= len(messages) or msg_idx < 0: - return False - - message = messages[msg_idx] - changes_made = False - - for i, content in enumerate(message.get("content", [])): - if isinstance(content, dict) and "toolResult" in content: - # Update status to error with informative message - message["content"][i]["toolResult"]["status"] = "error" - message["content"][i]["toolResult"]["content"] = [{"text": "The tool result was too large!"}] - changes_made = True - - return changes_made diff --git a/src/strands/event_loop/streaming.py b/src/strands/event_loop/streaming.py index 6e8a806f..0e9d472b 100644 --- a/src/strands/event_loop/streaming.py +++ b/src/strands/event_loop/streaming.py @@ -2,7 +2,7 @@ import json import logging -from typing import Any, Dict, Iterable, List, Optional, Tuple +from typing import Any, Generator, Iterable, Optional from ..types.content import ContentBlock, Message, Messages from ..types.models import Model @@ -80,7 +80,7 @@ def handle_message_start(event: MessageStartEvent, message: Message) -> Message: return message -def handle_content_block_start(event: ContentBlockStartEvent) -> Dict[str, Any]: +def handle_content_block_start(event: ContentBlockStartEvent) -> dict[str, Any]: """Handles the start of a content block by extracting tool usage information if any. Args: @@ -102,31 +102,31 @@ def handle_content_block_start(event: ContentBlockStartEvent) -> Dict[str, Any]: def handle_content_block_delta( - event: ContentBlockDeltaEvent, state: Dict[str, Any], callback_handler: Any, **kwargs: Any -) -> Dict[str, Any]: + event: ContentBlockDeltaEvent, state: dict[str, Any] +) -> tuple[dict[str, Any], dict[str, Any]]: """Handles content block delta updates by appending text, tool input, or reasoning content to the state. Args: event: Delta event. state: The current state of message processing. - callback_handler: Callback for processing events as they happen. - **kwargs: Additional keyword arguments to pass to the callback handler. Returns: Updated state with appended text or tool input. """ delta_content = event["delta"] + callback_event = {} + if "toolUse" in delta_content: if "input" not in state["current_tool_use"]: state["current_tool_use"]["input"] = "" state["current_tool_use"]["input"] += delta_content["toolUse"]["input"] - callback_handler(delta=delta_content, current_tool_use=state["current_tool_use"], **kwargs) + callback_event["callback"] = {"delta": delta_content, "current_tool_use": state["current_tool_use"]} elif "text" in delta_content: state["text"] += delta_content["text"] - callback_handler(data=delta_content["text"], delta=delta_content, **kwargs) + callback_event["callback"] = {"data": delta_content["text"], "delta": delta_content} elif "reasoningContent" in delta_content: if "text" in delta_content["reasoningContent"]: @@ -134,29 +134,27 @@ def handle_content_block_delta( state["reasoningText"] = "" state["reasoningText"] += delta_content["reasoningContent"]["text"] - callback_handler( - reasoningText=delta_content["reasoningContent"]["text"], - delta=delta_content, - reasoning=True, - **kwargs, - ) + callback_event["callback"] = { + "reasoningText": delta_content["reasoningContent"]["text"], + "delta": delta_content, + "reasoning": True, + } elif "signature" in delta_content["reasoningContent"]: if "signature" not in state: state["signature"] = "" state["signature"] += delta_content["reasoningContent"]["signature"] - callback_handler( - reasoning_signature=delta_content["reasoningContent"]["signature"], - delta=delta_content, - reasoning=True, - **kwargs, - ) + callback_event["callback"] = { + "reasoning_signature": delta_content["reasoningContent"]["signature"], + "delta": delta_content, + "reasoning": True, + } - return state + return state, callback_event -def handle_content_block_stop(state: Dict[str, Any]) -> Dict[str, Any]: +def handle_content_block_stop(state: dict[str, Any]) -> dict[str, Any]: """Handles the end of a content block by finalizing tool usage, text content, or reasoning content. Args: @@ -165,7 +163,7 @@ def handle_content_block_stop(state: Dict[str, Any]) -> Dict[str, Any]: Returns: Updated state with finalized content block. """ - content: List[ContentBlock] = state["content"] + content: list[ContentBlock] = state["content"] current_tool_use = state["current_tool_use"] text = state["text"] @@ -223,7 +221,7 @@ def handle_message_stop(event: MessageStopEvent) -> StopReason: return event["stopReason"] -def handle_redact_content(event: RedactContentEvent, messages: Messages, state: Dict[str, Any]) -> None: +def handle_redact_content(event: RedactContentEvent, messages: Messages, state: dict[str, Any]) -> None: """Handles redacting content from the input or output. Args: @@ -238,7 +236,7 @@ def handle_redact_content(event: RedactContentEvent, messages: Messages, state: state["message"]["content"] = [{"text": event["redactAssistantContentMessage"]}] -def extract_usage_metrics(event: MetadataEvent) -> Tuple[Usage, Metrics]: +def extract_usage_metrics(event: MetadataEvent) -> tuple[Usage, Metrics]: """Extracts usage metrics from the metadata chunk. Args: @@ -255,25 +253,20 @@ def extract_usage_metrics(event: MetadataEvent) -> Tuple[Usage, Metrics]: def process_stream( chunks: Iterable[StreamEvent], - callback_handler: Any, messages: Messages, - **kwargs: Any, -) -> Tuple[StopReason, Message, Usage, Metrics, Any]: +) -> Generator[dict[str, Any], None, None]: """Processes the response stream from the API, constructing the final message and extracting usage metrics. Args: chunks: The chunks of the response stream from the model. - callback_handler: Callback for processing events as they happen. messages: The agents messages. - **kwargs: Additional keyword arguments that will be passed to the callback handler. - And also returned in the request_state. Returns: - The reason for stopping, the constructed message, the usage metrics, and the updated request state. + The reason for stopping, the constructed message, and the usage metrics. """ stop_reason: StopReason = "end_turn" - state: Dict[str, Any] = { + state: dict[str, Any] = { "message": {"role": "assistant", "content": []}, "text": "", "current_tool_use": {}, @@ -285,18 +278,16 @@ def process_stream( usage: Usage = Usage(inputTokens=0, outputTokens=0, totalTokens=0) metrics: Metrics = Metrics(latencyMs=0) - kwargs.setdefault("request_state", {}) - for chunk in chunks: - # Callback handler call here allows each event to be visible to the caller - callback_handler(event=chunk) + yield {"callback": {"event": chunk}} if "messageStart" in chunk: state["message"] = handle_message_start(chunk["messageStart"], state["message"]) elif "contentBlockStart" in chunk: state["current_tool_use"] = handle_content_block_start(chunk["contentBlockStart"]) elif "contentBlockDelta" in chunk: - state = handle_content_block_delta(chunk["contentBlockDelta"], state, callback_handler, **kwargs) + state, callback_event = handle_content_block_delta(chunk["contentBlockDelta"], state) + yield callback_event elif "contentBlockStop" in chunk: state = handle_content_block_stop(state) elif "messageStop" in chunk: @@ -306,7 +297,7 @@ def process_stream( elif "redactContent" in chunk: handle_redact_content(chunk["redactContent"], messages, state) - return stop_reason, state["message"], usage, metrics, kwargs["request_state"] + yield {"stop": (stop_reason, state["message"], usage, metrics)} def stream_messages( @@ -314,9 +305,7 @@ def stream_messages( system_prompt: Optional[str], messages: Messages, tool_config: Optional[ToolConfig], - callback_handler: Any, - **kwargs: Any, -) -> Tuple[StopReason, Message, Usage, Metrics, Any]: +) -> Generator[dict[str, Any], None, None]: """Streams messages to the model and processes the response. Args: @@ -324,12 +313,9 @@ def stream_messages( system_prompt: The system prompt to send. messages: List of messages to send. tool_config: Configuration for the tools to use. - callback_handler: Callback for processing events as they happen. - **kwargs: Additional keyword arguments that will be passed to the callback handler. - And also returned in the request_state. Returns: - The reason for stopping, the final message, the usage metrics, and updated request state. + The reason for stopping, the final message, and the usage metrics """ logger.debug("model=<%s> | streaming messages", model) @@ -337,4 +323,4 @@ def stream_messages( tool_specs = [tool["toolSpec"] for tool in tool_config.get("tools", [])] or None if tool_config else None chunks = model.converse(messages, tool_specs, system_prompt) - return process_stream(chunks, callback_handler, messages, **kwargs) + yield from process_stream(chunks, messages) diff --git a/src/strands/models/anthropic.py b/src/strands/models/anthropic.py index 57394e2c..51089d47 100644 --- a/src/strands/models/anthropic.py +++ b/src/strands/models/anthropic.py @@ -7,11 +7,15 @@ import json import logging import mimetypes -from typing import Any, Iterable, Optional, TypedDict, cast +from typing import Any, Callable, Iterable, Optional, Type, TypedDict, TypeVar, cast import anthropic +from pydantic import BaseModel from typing_extensions import Required, Unpack, override +from ..event_loop.streaming import process_stream +from ..handlers.callback_handler import PrintingCallbackHandler +from ..tools import convert_pydantic_to_tool_spec from ..types.content import ContentBlock, Messages from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException from ..types.models import Model @@ -20,6 +24,8 @@ logger = logging.getLogger(__name__) +T = TypeVar("T", bound=BaseModel) + class AnthropicModel(Model): """Anthropic model provider implementation.""" @@ -356,10 +362,10 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]: with self.client.messages.stream(**request) as stream: for event in stream: if event.type in AnthropicModel.EVENT_TYPES: - yield event.dict() + yield event.model_dump() usage = event.message.usage # type: ignore - yield {"type": "metadata", "usage": usage.dict()} + yield {"type": "metadata", "usage": usage.model_dump()} except anthropic.RateLimitError as error: raise ModelThrottledException(str(error)) from error @@ -369,3 +375,42 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]: raise ContextWindowOverflowException(str(error)) from error raise error + + @override + def structured_output( + self, output_model: Type[T], prompt: Messages, callback_handler: Optional[Callable] = None + ) -> T: + """Get structured output from the model. + + Args: + output_model(Type[BaseModel]): The output model to use for the agent. + prompt(Messages): The prompt messages to use for the agent. + callback_handler(Optional[Callable]): Optional callback handler for processing events. Defaults to None. + """ + callback_handler = callback_handler or PrintingCallbackHandler() + tool_spec = convert_pydantic_to_tool_spec(output_model) + + response = self.converse(messages=prompt, tool_specs=[tool_spec]) + for event in process_stream(response, prompt): + if "callback" in event: + callback_handler(**event["callback"]) + else: + stop_reason, messages, _, _ = event["stop"] + + if stop_reason != "tool_use": + raise ValueError("No valid tool use or tool use input was found in the Anthropic response.") + + content = messages["content"] + output_response: dict[str, Any] | None = None + for block in content: + # if the tool use name doesn't match the tool spec name, skip, and if the block is not a tool use, skip. + # if the tool use name never matches, raise an error. + if block.get("toolUse") and block["toolUse"]["name"] == tool_spec["name"]: + output_response = block["toolUse"]["input"] + else: + continue + + if output_response is None: + raise ValueError("No valid tool use or tool use input was found in the Anthropic response.") + + return output_model(**output_response) diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index 9bbcca7d..a5ffb539 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -6,13 +6,17 @@ import json import logging import os -from typing import Any, Iterable, List, Literal, Optional, cast +from typing import Any, Callable, Iterable, List, Literal, Optional, Type, TypeVar, cast import boto3 from botocore.config import Config as BotocoreConfig from botocore.exceptions import ClientError +from pydantic import BaseModel from typing_extensions import TypedDict, Unpack, override +from ..event_loop.streaming import process_stream +from ..handlers.callback_handler import PrintingCallbackHandler +from ..tools import convert_pydantic_to_tool_spec from ..types.content import Messages from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException from ..types.models import Model @@ -29,6 +33,8 @@ "too many total text bytes", ] +T = TypeVar("T", bound=BaseModel) + class BedrockModel(Model): """AWS Bedrock model provider implementation. @@ -112,8 +118,17 @@ def __init__( logger.debug("config=<%s> | initializing", self.config) + region_for_boto = region_name or os.getenv("AWS_REGION") + if region_for_boto is None: + region_for_boto = "us-west-2" + logger.warning("defaulted to us-west-2 because no region was specified") + logger.warning( + "issue=<%s> | this behavior will change in an upcoming release", + "https://github.com/strands-agents/sdk-python/issues/238", + ) + session = boto_session or boto3.Session( - region_name=region_name or os.getenv("AWS_REGION") or "us-west-2", + region_name=region_for_boto, ) # Add strands-agents to the request user agent @@ -477,3 +492,42 @@ def _find_detected_and_blocked_policy(self, input: Any) -> bool: return self._find_detected_and_blocked_policy(item) # Otherwise return False return False + + @override + def structured_output( + self, output_model: Type[T], prompt: Messages, callback_handler: Optional[Callable] = None + ) -> T: + """Get structured output from the model. + + Args: + output_model(Type[BaseModel]): The output model to use for the agent. + prompt(Messages): The prompt messages to use for the agent. + callback_handler(Optional[Callable]): Optional callback handler for processing events. Defaults to None. + """ + callback_handler = callback_handler or PrintingCallbackHandler() + tool_spec = convert_pydantic_to_tool_spec(output_model) + + response = self.converse(messages=prompt, tool_specs=[tool_spec]) + for event in process_stream(response, prompt): + if "callback" in event: + callback_handler(**event["callback"]) + else: + stop_reason, messages, _, _ = event["stop"] + + if stop_reason != "tool_use": + raise ValueError("No valid tool use or tool use input was found in the Bedrock response.") + + content = messages["content"] + output_response: dict[str, Any] | None = None + for block in content: + # if the tool use name doesn't match the tool spec name, skip, and if the block is not a tool use, skip. + # if the tool use name never matches, raise an error. + if block.get("toolUse") and block["toolUse"]["name"] == tool_spec["name"]: + output_response = block["toolUse"]["input"] + else: + continue + + if output_response is None: + raise ValueError("No valid tool use or tool use input was found in the Bedrock response.") + + return output_model(**output_response) diff --git a/src/strands/models/litellm.py b/src/strands/models/litellm.py index 62f16d31..66138186 100644 --- a/src/strands/models/litellm.py +++ b/src/strands/models/litellm.py @@ -3,17 +3,22 @@ - Docs: https://docs.litellm.ai/ """ +import json import logging -from typing import Any, Optional, TypedDict, cast +from typing import Any, Callable, Optional, Type, TypedDict, TypeVar, cast import litellm +from litellm.utils import supports_response_schema +from pydantic import BaseModel from typing_extensions import Unpack, override -from ..types.content import ContentBlock +from ..types.content import ContentBlock, Messages from .openai import OpenAIModel logger = logging.getLogger(__name__) +T = TypeVar("T", bound=BaseModel) + class LiteLLMModel(OpenAIModel): """LiteLLM model provider implementation.""" @@ -97,3 +102,43 @@ def format_request_message_content(cls, content: ContentBlock) -> dict[str, Any] } return super().format_request_message_content(content) + + @override + def structured_output( + self, output_model: Type[T], prompt: Messages, callback_handler: Optional[Callable] = None + ) -> T: + """Get structured output from the model. + + Args: + output_model(Type[BaseModel]): The output model to use for the agent. + prompt(Messages): The prompt messages to use for the agent. + callback_handler(Optional[Callable]): Optional callback handler for processing events. Defaults to None. + + """ + # The LiteLLM `Client` inits with Chat(). + # Chat() inits with self.completions + # completions() has a method `create()` which wraps the real completion API of Litellm + response = self.client.chat.completions.create( + model=self.get_config()["model_id"], + messages=super().format_request(prompt)["messages"], + response_format=output_model, + ) + + if not supports_response_schema(self.get_config()["model_id"]): + raise ValueError("Model does not support response_format") + if len(response.choices) > 1: + raise ValueError("Multiple choices found in the response.") + + # Find the first choice with tool_calls + for choice in response.choices: + if choice.finish_reason == "tool_calls": + try: + # Parse the tool call content as JSON + tool_call_data = json.loads(choice.message.content) + # Instantiate the output model with the parsed data + return output_model(**tool_call_data) + except (json.JSONDecodeError, TypeError, ValueError) as e: + raise ValueError(f"Failed to parse or load content into model: {e}") from e + + # If no tool_calls found, raise an error + raise ValueError("No tool_calls found in response") diff --git a/src/strands/models/llamaapi.py b/src/strands/models/llamaapi.py index 583db2f2..755e07ad 100644 --- a/src/strands/models/llamaapi.py +++ b/src/strands/models/llamaapi.py @@ -8,10 +8,11 @@ import json import logging import mimetypes -from typing import Any, Iterable, Optional, cast +from typing import Any, Callable, Iterable, Optional, Type, TypeVar, cast import llama_api_client from llama_api_client import LlamaAPIClient +from pydantic import BaseModel from typing_extensions import TypedDict, Unpack, override from ..types.content import ContentBlock, Messages @@ -22,6 +23,8 @@ logger = logging.getLogger(__name__) +T = TypeVar("T", bound=BaseModel) + class LlamaAPIModel(Model): """Llama API model provider implementation.""" @@ -384,3 +387,31 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]: # we may have a metrics event here if metrics_event: yield {"chunk_type": "metadata", "data": metrics_event} + + @override + def structured_output( + self, output_model: Type[T], prompt: Messages, callback_handler: Optional[Callable] = None + ) -> T: + """Get structured output from the model. + + Args: + output_model(Type[BaseModel]): The output model to use for the agent. + prompt(Messages): The prompt messages to use for the agent. + callback_handler(Optional[Callable]): Optional callback handler for processing events. Defaults to None. + + Raises: + NotImplementedError: Structured output is not currently supported for LlamaAPI models. + """ + # response_format: ResponseFormat = { + # "type": "json_schema", + # "json_schema": { + # "name": output_model.__name__, + # "schema": output_model.model_json_schema(), + # }, + # } + # response = self.client.chat.completions.create( + # model=self.config["model_id"], + # messages=self.format_request(prompt)["messages"], + # response_format=response_format, + # ) + raise NotImplementedError("Strands sdk-python does not implement this in the Llama API Preview.") diff --git a/src/strands/models/ollama.py b/src/strands/models/ollama.py index 7ed12216..b062fe14 100644 --- a/src/strands/models/ollama.py +++ b/src/strands/models/ollama.py @@ -5,9 +5,10 @@ import json import logging -from typing import Any, Iterable, Optional, cast +from typing import Any, Callable, Iterable, Optional, Type, TypeVar, cast from ollama import Client as OllamaClient +from pydantic import BaseModel from typing_extensions import TypedDict, Unpack, override from ..types.content import ContentBlock, Messages @@ -17,6 +18,8 @@ logger = logging.getLogger(__name__) +T = TypeVar("T", bound=BaseModel) + class OllamaModel(Model): """Ollama model provider implementation. @@ -310,3 +313,25 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]: yield {"chunk_type": "content_stop", "data_type": "text"} yield {"chunk_type": "message_stop", "data": "tool_use" if tool_requested else event.done_reason} yield {"chunk_type": "metadata", "data": event} + + @override + def structured_output( + self, output_model: Type[T], prompt: Messages, callback_handler: Optional[Callable] = None + ) -> T: + """Get structured output from the model. + + Args: + output_model(Type[BaseModel]): The output model to use for the agent. + prompt(Messages): The prompt messages to use for the agent. + callback_handler(Optional[Callable]): Optional callback handler for processing events. Defaults to None. + """ + formatted_request = self.format_request(messages=prompt) + formatted_request["format"] = output_model.model_json_schema() + formatted_request["stream"] = False + response = self.client.chat(**formatted_request) + + try: + content = response.message.content.strip() + return output_model.model_validate_json(content) + except Exception as e: + raise ValueError(f"Failed to parse or load content into model: {e}") from e diff --git a/src/strands/models/openai.py b/src/strands/models/openai.py index 764cb851..783ce379 100644 --- a/src/strands/models/openai.py +++ b/src/strands/models/openai.py @@ -4,15 +4,20 @@ """ import logging -from typing import Any, Iterable, Optional, Protocol, TypedDict, cast +from typing import Any, Callable, Iterable, Optional, Protocol, Type, TypedDict, TypeVar, cast import openai +from openai.types.chat.parsed_chat_completion import ParsedChatCompletion +from pydantic import BaseModel from typing_extensions import Unpack, override +from ..types.content import Messages from ..types.models import OpenAIModel as SAOpenAIModel logger = logging.getLogger(__name__) +T = TypeVar("T", bound=BaseModel) + class Client(Protocol): """Protocol defining the OpenAI-compatible interface for the underlying provider client.""" @@ -94,6 +99,9 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]: tool_calls: dict[int, list[Any]] = {} for event in response: + # Defensive: skip events with empty or missing choices + if not getattr(event, "choices", None): + continue choice = event.choices[0] if choice.delta.content: @@ -122,3 +130,35 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]: _ = event yield {"chunk_type": "metadata", "data": event.usage} + + @override + def structured_output( + self, output_model: Type[T], prompt: Messages, callback_handler: Optional[Callable] = None + ) -> T: + """Get structured output from the model. + + Args: + output_model(Type[BaseModel]): The output model to use for the agent. + prompt(Messages): The prompt messages to use for the agent. + callback_handler(Optional[Callable]): Optional callback handler for processing events. Defaults to None. + """ + response: ParsedChatCompletion = self.client.beta.chat.completions.parse( # type: ignore + model=self.get_config()["model_id"], + messages=super().format_request(prompt)["messages"], + response_format=output_model, + ) + + parsed: T | None = None + # Find the first choice with tool_calls + if len(response.choices) > 1: + raise ValueError("Multiple choices found in the OpenAI response.") + + for choice in response.choices: + if isinstance(choice.message.parsed, output_model): + parsed = choice.message.parsed + break + + if parsed: + return parsed + else: + raise ValueError("No valid tool use or tool use input was found in the OpenAI response.") diff --git a/src/strands/multiagent/__init__.py b/src/strands/multiagent/__init__.py new file mode 100644 index 00000000..1cef1425 --- /dev/null +++ b/src/strands/multiagent/__init__.py @@ -0,0 +1,13 @@ +"""Multiagent capabilities for Strands Agents. + +This module provides support for multiagent systems, including agent-to-agent (A2A) +communication protocols and coordination mechanisms. + +Submodules: + a2a: Implementation of the Agent-to-Agent (A2A) protocol, which enables + standardized communication between agents. +""" + +from . import a2a + +__all__ = ["a2a"] diff --git a/src/strands/multiagent/a2a/__init__.py b/src/strands/multiagent/a2a/__init__.py new file mode 100644 index 00000000..c5425618 --- /dev/null +++ b/src/strands/multiagent/a2a/__init__.py @@ -0,0 +1,14 @@ +"""Agent-to-Agent (A2A) communication protocol implementation for Strands Agents. + +This module provides classes and utilities for enabling Strands Agents to communicate +with other agents using the Agent-to-Agent (A2A) protocol. + +Docs: https://google-a2a.github.io/A2A/latest/ + +Classes: + A2AAgent: A wrapper that adapts a Strands Agent to be A2A-compatible. +""" + +from .agent import A2AAgent + +__all__ = ["A2AAgent"] diff --git a/src/strands/multiagent/a2a/agent.py b/src/strands/multiagent/a2a/agent.py new file mode 100644 index 00000000..4359100d --- /dev/null +++ b/src/strands/multiagent/a2a/agent.py @@ -0,0 +1,149 @@ +"""A2A-compatible wrapper for Strands Agent. + +This module provides the A2AAgent class, which adapts a Strands Agent to the A2A protocol, +allowing it to be used in A2A-compatible systems. +""" + +import logging +from typing import Any, Literal + +import uvicorn +from a2a.server.apps import A2AFastAPIApplication, A2AStarletteApplication +from a2a.server.request_handlers import DefaultRequestHandler +from a2a.server.tasks import InMemoryTaskStore +from a2a.types import AgentCapabilities, AgentCard, AgentSkill +from fastapi import FastAPI +from starlette.applications import Starlette + +from ...agent.agent import Agent as SAAgent +from .executor import StrandsA2AExecutor + +logger = logging.getLogger(__name__) + + +class A2AAgent: + """A2A-compatible wrapper for Strands Agent.""" + + def __init__( + self, + agent: SAAgent, + *, + # AgentCard + host: str = "0.0.0.0", + port: int = 9000, + version: str = "0.0.1", + ): + """Initialize an A2A-compatible agent from a Strands agent. + + Args: + agent: The Strands Agent to wrap with A2A compatibility. + name: The name of the agent, used in the AgentCard. + description: A description of the agent's capabilities, used in the AgentCard. + host: The hostname or IP address to bind the A2A server to. Defaults to "0.0.0.0". + port: The port to bind the A2A server to. Defaults to 9000. + version: The version of the agent. Defaults to "0.0.1". + """ + self.host = host + self.port = port + self.http_url = f"http://{self.host}:{self.port}/" + self.version = version + self.strands_agent = agent + self.name = self.strands_agent.name + self.description = self.strands_agent.description + # TODO: enable configurable capabilities and request handler + self.capabilities = AgentCapabilities() + self.request_handler = DefaultRequestHandler( + agent_executor=StrandsA2AExecutor(self.strands_agent), + task_store=InMemoryTaskStore(), + ) + logger.info("Strands' integration with A2A is experimental. Be aware of frequent breaking changes.") + + @property + def public_agent_card(self) -> AgentCard: + """Get the public AgentCard for this agent. + + The AgentCard contains metadata about the agent, including its name, + description, URL, version, skills, and capabilities. This information + is used by other agents and systems to discover and interact with this agent. + + Returns: + AgentCard: The public agent card containing metadata about this agent. + + Raises: + ValueError: If name or description is None or empty. + """ + if not self.name: + raise ValueError("A2A agent name cannot be None or empty") + if not self.description: + raise ValueError("A2A agent description cannot be None or empty") + + return AgentCard( + name=self.name, + description=self.description, + url=self.http_url, + version=self.version, + skills=self.agent_skills, + defaultInputModes=["text"], + defaultOutputModes=["text"], + capabilities=self.capabilities, + ) + + @property + def agent_skills(self) -> list[AgentSkill]: + """Get the list of skills this agent provides. + + Skills represent specific capabilities that the agent can perform. + Strands agent tools are adapted to A2A skills. + + Returns: + list[AgentSkill]: A list of skills this agent provides. + """ + # TODO: translate Strands tools (native & MCP) to skills + return [] + + def to_starlette_app(self) -> Starlette: + """Create a Starlette application for serving this agent via HTTP. + + This method creates a Starlette application that can be used to serve + the agent via HTTP using the A2A protocol. + + Returns: + Starlette: A Starlette application configured to serve this agent. + """ + return A2AStarletteApplication(agent_card=self.public_agent_card, http_handler=self.request_handler).build() + + def to_fastapi_app(self) -> FastAPI: + """Create a FastAPI application for serving this agent via HTTP. + + This method creates a FastAPI application that can be used to serve + the agent via HTTP using the A2A protocol. + + Returns: + FastAPI: A FastAPI application configured to serve this agent. + """ + return A2AFastAPIApplication(agent_card=self.public_agent_card, http_handler=self.request_handler).build() + + def serve(self, app_type: Literal["fastapi", "starlette"] = "starlette", **kwargs: Any) -> None: + """Start the A2A server with the specified application type. + + This method starts an HTTP server that exposes the agent via the A2A protocol. + The server can be implemented using either FastAPI or Starlette, depending on + the specified app_type. + + Args: + app_type: The type of application to serve, either "fastapi" or "starlette". + Defaults to "starlette". + **kwargs: Additional keyword arguments to pass to uvicorn.run. + """ + try: + logger.info("Starting Strands A2A server...") + if app_type == "fastapi": + uvicorn.run(self.to_fastapi_app(), host=self.host, port=self.port, **kwargs) + else: + uvicorn.run(self.to_starlette_app(), host=self.host, port=self.port, **kwargs) + except KeyboardInterrupt: + logger.warning("Strands A2A server shutdown requested (KeyboardInterrupt).") + except Exception: + logger.exception("Strands A2A server encountered exception.") + finally: + logger.info("Strands A2A server has shutdown.") diff --git a/src/strands/multiagent/a2a/executor.py b/src/strands/multiagent/a2a/executor.py new file mode 100644 index 00000000..b7a7af09 --- /dev/null +++ b/src/strands/multiagent/a2a/executor.py @@ -0,0 +1,67 @@ +"""Strands Agent executor for the A2A protocol. + +This module provides the StrandsA2AExecutor class, which adapts a Strands Agent +to be used as an executor in the A2A protocol. It handles the execution of agent +requests and the conversion of Strands Agent responses to A2A events. +""" + +import logging + +from a2a.server.agent_execution import AgentExecutor, RequestContext +from a2a.server.events import EventQueue +from a2a.types import UnsupportedOperationError +from a2a.utils import new_agent_text_message +from a2a.utils.errors import ServerError + +from ...agent.agent import Agent as SAAgent +from ...agent.agent_result import AgentResult as SAAgentResult + +log = logging.getLogger(__name__) + + +class StrandsA2AExecutor(AgentExecutor): + """Executor that adapts a Strands Agent to the A2A protocol.""" + + def __init__(self, agent: SAAgent): + """Initialize a StrandsA2AExecutor. + + Args: + agent: The Strands Agent to adapt to the A2A protocol. + """ + self.agent = agent + + async def execute( + self, + context: RequestContext, + event_queue: EventQueue, + ) -> None: + """Execute a request using the Strands Agent and send the response as A2A events. + + This method executes the user's input using the Strands Agent and converts + the agent's response to A2A events, which are then sent to the event queue. + + Args: + context: The A2A request context, containing the user's input and other metadata. + event_queue: The A2A event queue, used to send response events. + """ + result: SAAgentResult = self.agent(context.get_user_input()) + if result.message and "content" in result.message: + for content_block in result.message["content"]: + if "text" in content_block: + await event_queue.enqueue_event(new_agent_text_message(content_block["text"])) + + async def cancel(self, context: RequestContext, event_queue: EventQueue) -> None: + """Cancel an ongoing execution. + + This method is called when a request is cancelled. Currently, cancellation + is not supported, so this method raises an UnsupportedOperationError. + + Args: + context: The A2A request context. + event_queue: The A2A event queue. + + Raises: + ServerError: Always raised with an UnsupportedOperationError, as cancellation + is not currently supported. + """ + raise ServerError(error=UnsupportedOperationError()) diff --git a/src/strands/telemetry/__init__.py b/src/strands/telemetry/__init__.py index 15981216..21dd6ebf 100644 --- a/src/strands/telemetry/__init__.py +++ b/src/strands/telemetry/__init__.py @@ -3,7 +3,8 @@ This module provides metrics and tracing functionality. """ -from .metrics import EventLoopMetrics, Trace, metrics_to_string +from .config import get_otel_resource +from .metrics import EventLoopMetrics, MetricsClient, Trace, metrics_to_string from .tracer import Tracer, get_tracer __all__ = [ @@ -12,4 +13,6 @@ "metrics_to_string", "Tracer", "get_tracer", + "MetricsClient", + "get_otel_resource", ] diff --git a/src/strands/telemetry/config.py b/src/strands/telemetry/config.py new file mode 100644 index 00000000..9f5a05fd --- /dev/null +++ b/src/strands/telemetry/config.py @@ -0,0 +1,33 @@ +"""OpenTelemetry configuration and setup utilities for Strands agents. + +This module provides centralized configuration and initialization functionality +for OpenTelemetry components and other telemetry infrastructure shared across Strands applications. +""" + +from importlib.metadata import version + +from opentelemetry.sdk.resources import Resource + + +def get_otel_resource() -> Resource: + """Create a standard OpenTelemetry resource with service information. + + This function implements a singleton pattern - it will return the same + Resource object for the same service_name parameter. + + Args: + service_name: Name of the service for OpenTelemetry. + + Returns: + Resource object with standard service information. + """ + resource = Resource.create( + { + "service.name": __name__, + "service.version": version("strands-agents"), + "telemetry.sdk.name": "opentelemetry", + "telemetry.sdk.language": "python", + } + ) + + return resource diff --git a/src/strands/telemetry/metrics.py b/src/strands/telemetry/metrics.py index cd70819b..332ab2ae 100644 --- a/src/strands/telemetry/metrics.py +++ b/src/strands/telemetry/metrics.py @@ -6,6 +6,10 @@ from dataclasses import dataclass, field from typing import Any, Dict, Iterable, List, Optional, Set, Tuple +import opentelemetry.metrics as metrics_api +from opentelemetry.metrics import Counter, Histogram, Meter + +from ..telemetry import metrics_constants as constants from ..types.content import Message from ..types.streaming import Metrics, Usage from ..types.tools import ToolUse @@ -117,22 +121,34 @@ class ToolMetrics: error_count: int = 0 total_time: float = 0.0 - def add_call(self, tool: ToolUse, duration: float, success: bool) -> None: + def add_call( + self, + tool: ToolUse, + duration: float, + success: bool, + metrics_client: "MetricsClient", + attributes: Optional[Dict[str, Any]] = None, + ) -> None: """Record a new tool call with its outcome. Args: tool: The tool that was called. duration: How long the call took in seconds. success: Whether the call was successful. + metrics_client: The metrics client for recording the metrics. + attributes: attributes of the metrics. """ self.tool = tool # Update with latest tool state self.call_count += 1 self.total_time += duration - + metrics_client.tool_call_count.add(1, attributes=attributes) + metrics_client.tool_duration.record(duration, attributes=attributes) if success: self.success_count += 1 + metrics_client.tool_success_count.add(1, attributes=attributes) else: self.error_count += 1 + metrics_client.tool_error_count.add(1, attributes=attributes) @dataclass @@ -155,32 +171,53 @@ class EventLoopMetrics: accumulated_usage: Usage = field(default_factory=lambda: Usage(inputTokens=0, outputTokens=0, totalTokens=0)) accumulated_metrics: Metrics = field(default_factory=lambda: Metrics(latencyMs=0)) - def start_cycle(self) -> Tuple[float, Trace]: + @property + def _metrics_client(self) -> "MetricsClient": + """Get the singleton MetricsClient instance.""" + return MetricsClient() + + def start_cycle( + self, + attributes: Optional[Dict[str, Any]] = None, + ) -> Tuple[float, Trace]: """Start a new event loop cycle and create a trace for it. + Args: + attributes: attributes of the metrics. + Returns: A tuple containing the start time and the cycle trace object. """ + self._metrics_client.event_loop_cycle_count.add(1, attributes=attributes) + self._metrics_client.event_loop_start_cycle.add(1, attributes=attributes) self.cycle_count += 1 start_time = time.time() cycle_trace = Trace(f"Cycle {self.cycle_count}", start_time=start_time) self.traces.append(cycle_trace) return start_time, cycle_trace - def end_cycle(self, start_time: float, cycle_trace: Trace) -> None: + def end_cycle(self, start_time: float, cycle_trace: Trace, attributes: Optional[Dict[str, Any]] = None) -> None: """End the current event loop cycle and record its duration. Args: start_time: The timestamp when the cycle started. cycle_trace: The trace object for this cycle. + attributes: attributes of the metrics. """ + self._metrics_client.event_loop_end_cycle.add(1, attributes) end_time = time.time() duration = end_time - start_time + self._metrics_client.event_loop_cycle_duration.record(duration, attributes) self.cycle_durations.append(duration) cycle_trace.end(end_time) def add_tool_usage( - self, tool: ToolUse, duration: float, tool_trace: Trace, success: bool, message: Message + self, + tool: ToolUse, + duration: float, + tool_trace: Trace, + success: bool, + message: Message, ) -> None: """Record metrics for a tool invocation. @@ -203,8 +240,16 @@ def add_tool_usage( tool_trace.raw_name = f"{tool_name} - {tool_use_id}" tool_trace.add_message(message) - self.tool_metrics.setdefault(tool_name, ToolMetrics(tool)).add_call(tool, duration, success) - + self.tool_metrics.setdefault(tool_name, ToolMetrics(tool)).add_call( + tool, + duration, + success, + self._metrics_client, + attributes={ + "tool_name": tool_name, + "tool_use_id": tool_use_id, + }, + ) tool_trace.end() def update_usage(self, usage: Usage) -> None: @@ -213,6 +258,8 @@ def update_usage(self, usage: Usage) -> None: Args: usage: The usage data to add to the accumulated totals. """ + self._metrics_client.event_loop_input_tokens.record(usage["inputTokens"]) + self._metrics_client.event_loop_output_tokens.record(usage["outputTokens"]) self.accumulated_usage["inputTokens"] += usage["inputTokens"] self.accumulated_usage["outputTokens"] += usage["outputTokens"] self.accumulated_usage["totalTokens"] += usage["totalTokens"] @@ -223,6 +270,7 @@ def update_metrics(self, metrics: Metrics) -> None: Args: metrics: The metrics data to add to the accumulated totals. """ + self._metrics_client.event_loop_latency.record(metrics["latencyMs"]) self.accumulated_metrics["latencyMs"] += metrics["latencyMs"] def get_summary(self) -> Dict[str, Any]: @@ -355,3 +403,74 @@ def metrics_to_string(event_loop_metrics: EventLoopMetrics, allowed_names: Optio A formatted string representation of the metrics. """ return "\n".join(_metrics_summary_to_lines(event_loop_metrics, allowed_names or set())) + + +class MetricsClient: + """Singleton client for managing OpenTelemetry metrics instruments. + + The actual metrics export destination (console, OTLP endpoint, etc.) is configured + through OpenTelemetry SDK configuration by users, not by this client. + """ + + _instance: Optional["MetricsClient"] = None + meter: Meter + event_loop_cycle_count: Counter + event_loop_start_cycle: Counter + event_loop_end_cycle: Counter + event_loop_cycle_duration: Histogram + event_loop_latency: Histogram + event_loop_input_tokens: Histogram + event_loop_output_tokens: Histogram + + tool_call_count: Counter + tool_success_count: Counter + tool_error_count: Counter + tool_duration: Histogram + + def __new__(cls) -> "MetricsClient": + """Create or return the singleton instance of MetricsClient. + + Returns: + The single MetricsClient instance. + """ + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + def __init__(self) -> None: + """Initialize the MetricsClient. + + This method only runs once due to the singleton pattern. + Sets up the OpenTelemetry meter and creates metric instruments. + """ + if hasattr(self, "meter"): + return + + logger.info("Creating Strands MetricsClient") + meter_provider: metrics_api.MeterProvider = metrics_api.get_meter_provider() + self.meter = meter_provider.get_meter(__name__) + self.create_instruments() + + def create_instruments(self) -> None: + """Create and initialize all OpenTelemetry metric instruments.""" + self.event_loop_cycle_count = self.meter.create_counter( + name=constants.STRANDS_EVENT_LOOP_CYCLE_COUNT, unit="Count" + ) + self.event_loop_start_cycle = self.meter.create_counter( + name=constants.STRANDS_EVENT_LOOP_START_CYCLE, unit="Count" + ) + self.event_loop_end_cycle = self.meter.create_counter(name=constants.STRANDS_EVENT_LOOP_END_CYCLE, unit="Count") + self.event_loop_cycle_duration = self.meter.create_histogram( + name=constants.STRANDS_EVENT_LOOP_CYCLE_DURATION, unit="s" + ) + self.event_loop_latency = self.meter.create_histogram(name=constants.STRANDS_EVENT_LOOP_LATENCY, unit="ms") + self.tool_call_count = self.meter.create_counter(name=constants.STRANDS_TOOL_CALL_COUNT, unit="Count") + self.tool_success_count = self.meter.create_counter(name=constants.STRANDS_TOOL_SUCCESS_COUNT, unit="Count") + self.tool_error_count = self.meter.create_counter(name=constants.STRANDS_TOOL_ERROR_COUNT, unit="Count") + self.tool_duration = self.meter.create_histogram(name=constants.STRANDS_TOOL_DURATION, unit="s") + self.event_loop_input_tokens = self.meter.create_histogram( + name=constants.STRANDS_EVENT_LOOP_INPUT_TOKENS, unit="token" + ) + self.event_loop_output_tokens = self.meter.create_histogram( + name=constants.STRANDS_EVENT_LOOP_OUTPUT_TOKENS, unit="token" + ) diff --git a/src/strands/telemetry/metrics_constants.py b/src/strands/telemetry/metrics_constants.py new file mode 100644 index 00000000..b622eebf --- /dev/null +++ b/src/strands/telemetry/metrics_constants.py @@ -0,0 +1,15 @@ +"""Metrics that are emitted in Strands-Agents.""" + +STRANDS_EVENT_LOOP_CYCLE_COUNT = "strands.event_loop.cycle_count" +STRANDS_EVENT_LOOP_START_CYCLE = "strands.event_loop.start_cycle" +STRANDS_EVENT_LOOP_END_CYCLE = "strands.event_loop.end_cycle" +STRANDS_TOOL_CALL_COUNT = "strands.tool.call_count" +STRANDS_TOOL_SUCCESS_COUNT = "strands.tool.success_count" +STRANDS_TOOL_ERROR_COUNT = "strands.tool.error_count" + +# Histograms +STRANDS_EVENT_LOOP_LATENCY = "strands.event_loop.latency" +STRANDS_TOOL_DURATION = "strands.tool.duration" +STRANDS_EVENT_LOOP_CYCLE_DURATION = "strands.event_loop.cycle_duration" +STRANDS_EVENT_LOOP_INPUT_TOKENS = "strands.event_loop.input.tokens" +STRANDS_EVENT_LOOP_OUTPUT_TOKENS = "strands.event_loop.output.tokens" diff --git a/src/strands/telemetry/tracer.py b/src/strands/telemetry/tracer.py index 34eb7bed..813c90e1 100644 --- a/src/strands/telemetry/tracer.py +++ b/src/strands/telemetry/tracer.py @@ -8,17 +8,19 @@ import logging import os from datetime import date, datetime, timezone -from importlib.metadata import version from typing import Any, Dict, Mapping, Optional -from opentelemetry import trace -from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter -from opentelemetry.sdk.resources import Resource -from opentelemetry.sdk.trace import TracerProvider +import opentelemetry.trace as trace_api +from opentelemetry import propagate +from opentelemetry.baggage.propagation import W3CBaggagePropagator +from opentelemetry.propagators.composite import CompositePropagator +from opentelemetry.sdk.trace import TracerProvider as SDKTracerProvider from opentelemetry.sdk.trace.export import BatchSpanProcessor, ConsoleSpanExporter, SimpleSpanProcessor -from opentelemetry.trace import StatusCode +from opentelemetry.trace import Span, StatusCode +from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator from ..agent.agent_result import AgentResult +from ..telemetry import get_otel_resource from ..types.content import Message, Messages from ..types.streaming import Usage from ..types.tools import ToolResult, ToolUse @@ -26,6 +28,19 @@ logger = logging.getLogger(__name__) +HAS_OTEL_EXPORTER_MODULE = False +OTEL_EXPORTER_MODULE_ERROR = ( + "opentelemetry-exporter-otlp-proto-http not detected;" + "please install strands-agents with the optional 'otel' target" + "otel http exporting is currently DISABLED" +) +try: + from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter + + HAS_OTEL_EXPORTER_MODULE = True +except ImportError: + pass + class JSONEncoder(json.JSONEncoder): """Custom JSON encoder that handles non-serializable types.""" @@ -133,28 +148,33 @@ def __init__( self.service_name = service_name self.otlp_headers = otlp_headers or {} - self.tracer_provider: Optional[TracerProvider] = None - self.tracer: Optional[trace.Tracer] = None - + self.tracer_provider: Optional[trace_api.TracerProvider] = None + self.tracer: Optional[trace_api.Tracer] = None + propagate.set_global_textmap( + CompositePropagator( + [ + W3CBaggagePropagator(), + TraceContextTextMapPropagator(), + ] + ) + ) if self.otlp_endpoint or self.enable_console_export: + # Create our own tracer provider self._initialize_tracer() def _initialize_tracer(self) -> None: """Initialize the OpenTelemetry tracer.""" logger.info("initializing tracer") - # Create resource with service information - resource = Resource.create( - { - "service.name": self.service_name, - "service.version": version("strands-agents"), - "telemetry.sdk.name": "opentelemetry", - "telemetry.sdk.language": "python", - } - ) + if self._is_initialized(): + self.tracer_provider = trace_api.get_tracer_provider() + self.tracer = self.tracer_provider.get_tracer(self.service_name) + return + + resource = get_otel_resource() # Create tracer provider - self.tracer_provider = TracerProvider(resource=resource) + self.tracer_provider = SDKTracerProvider(resource=resource) # Add console exporter if enabled if self.enable_console_export and self.tracer_provider: @@ -163,7 +183,7 @@ def _initialize_tracer(self) -> None: self.tracer_provider.add_span_processor(console_processor) # Add OTLP exporter if endpoint is provided - if self.otlp_endpoint and self.tracer_provider: + if HAS_OTEL_EXPORTER_MODULE and self.otlp_endpoint and self.tracer_provider: try: # Ensure endpoint has the right format endpoint = self.otlp_endpoint @@ -186,19 +206,26 @@ def _initialize_tracer(self) -> None: batch_processor = BatchSpanProcessor(otlp_exporter) self.tracer_provider.add_span_processor(batch_processor) logger.info("endpoint=<%s> | OTLP exporter configured with endpoint", endpoint) + except Exception as e: logger.exception("error=<%s> | Failed to configure OTLP exporter", e) + elif self.otlp_endpoint and self.tracer_provider: + raise ModuleNotFoundError(OTEL_EXPORTER_MODULE_ERROR) # Set as global tracer provider - trace.set_tracer_provider(self.tracer_provider) - self.tracer = trace.get_tracer(self.service_name) + trace_api.set_tracer_provider(self.tracer_provider) + self.tracer = trace_api.get_tracer(self.service_name) + + def _is_initialized(self) -> bool: + tracer_provider = trace_api.get_tracer_provider() + return isinstance(tracer_provider, SDKTracerProvider) def _start_span( self, span_name: str, - parent_span: Optional[trace.Span] = None, + parent_span: Optional[Span] = None, attributes: Optional[Dict[str, AttributeValue]] = None, - ) -> Optional[trace.Span]: + ) -> Optional[Span]: """Generic helper method to start a span with common attributes. Args: @@ -212,7 +239,7 @@ def _start_span( if self.tracer is None: return None - context = trace.set_span_in_context(parent_span) if parent_span else None + context = trace_api.set_span_in_context(parent_span) if parent_span else None span = self.tracer.start_span(name=span_name, context=context) # Set start time as a common attribute @@ -224,7 +251,7 @@ def _start_span( return span - def _set_attributes(self, span: trace.Span, attributes: Dict[str, AttributeValue]) -> None: + def _set_attributes(self, span: Span, attributes: Dict[str, AttributeValue]) -> None: """Set attributes on a span, handling different value types appropriately. Args: @@ -239,7 +266,7 @@ def _set_attributes(self, span: trace.Span, attributes: Dict[str, AttributeValue def _end_span( self, - span: trace.Span, + span: Span, attributes: Optional[Dict[str, AttributeValue]] = None, error: Optional[Exception] = None, ) -> None: @@ -272,13 +299,13 @@ def _end_span( finally: span.end() # Force flush to ensure spans are exported - if self.tracer_provider: + if self.tracer_provider and hasattr(self.tracer_provider, "force_flush"): try: self.tracer_provider.force_flush() except Exception as e: logger.warning("error=<%s> | failed to force flush tracer provider", e) - def end_span_with_error(self, span: trace.Span, error_message: str, exception: Optional[Exception] = None) -> None: + def end_span_with_error(self, span: Span, error_message: str, exception: Optional[Exception] = None) -> None: """End a span with error status. Args: @@ -294,12 +321,12 @@ def end_span_with_error(self, span: trace.Span, error_message: str, exception: O def start_model_invoke_span( self, - parent_span: Optional[trace.Span] = None, + parent_span: Optional[Span] = None, agent_name: str = "Strands Agent", messages: Optional[Messages] = None, model_id: Optional[str] = None, **kwargs: Any, - ) -> Optional[trace.Span]: + ) -> Optional[Span]: """Start a new span for a model invocation. Args: @@ -328,7 +355,7 @@ def start_model_invoke_span( return self._start_span("Model invoke", parent_span, attributes) def end_model_invoke_span( - self, span: trace.Span, message: Message, usage: Usage, error: Optional[Exception] = None + self, span: Span, message: Message, usage: Usage, error: Optional[Exception] = None ) -> None: """End a model invocation span with results and metrics. @@ -347,9 +374,7 @@ def end_model_invoke_span( self._end_span(span, attributes, error) - def start_tool_call_span( - self, tool: ToolUse, parent_span: Optional[trace.Span] = None, **kwargs: Any - ) -> Optional[trace.Span]: + def start_tool_call_span(self, tool: ToolUse, parent_span: Optional[Span] = None, **kwargs: Any) -> Optional[Span]: """Start a new span for a tool call. Args: @@ -374,7 +399,7 @@ def start_tool_call_span( return self._start_span(span_name, parent_span, attributes) def end_tool_call_span( - self, span: trace.Span, tool_result: Optional[ToolResult], error: Optional[Exception] = None + self, span: Span, tool_result: Optional[ToolResult], error: Optional[Exception] = None ) -> None: """End a tool call span with results. @@ -402,10 +427,10 @@ def end_tool_call_span( def start_event_loop_cycle_span( self, event_loop_kwargs: Any, - parent_span: Optional[trace.Span] = None, + parent_span: Optional[Span] = None, messages: Optional[Messages] = None, **kwargs: Any, - ) -> Optional[trace.Span]: + ) -> Optional[Span]: """Start a new span for an event loop cycle. Args: @@ -436,7 +461,7 @@ def start_event_loop_cycle_span( def end_event_loop_cycle_span( self, - span: trace.Span, + span: Span, message: Message, tool_result_message: Optional[Message] = None, error: Optional[Exception] = None, @@ -466,7 +491,7 @@ def start_agent_span( tools: Optional[list] = None, custom_trace_attributes: Optional[Mapping[str, AttributeValue]] = None, **kwargs: Any, - ) -> Optional[trace.Span]: + ) -> Optional[Span]: """Start a new span for an agent invocation. Args: @@ -506,7 +531,7 @@ def start_agent_span( def end_agent_span( self, - span: trace.Span, + span: Span, response: Optional[AgentResult] = None, error: Optional[Exception] = None, ) -> None: @@ -557,13 +582,16 @@ def get_tracer( otlp_endpoint: OTLP endpoint URL for sending traces. otlp_headers: Headers to include with OTLP requests. enable_console_export: Whether to also export traces to console. + tracer_provider: Optional existing TracerProvider to use instead of creating a new one. Returns: The global tracer instance. """ global _tracer_instance - if _tracer_instance is None or (otlp_endpoint and _tracer_instance.otlp_endpoint != otlp_endpoint): # type: ignore[unreachable] + if ( + _tracer_instance is None or (otlp_endpoint and _tracer_instance.otlp_endpoint != otlp_endpoint) # type: ignore[unreachable] + ): _tracer_instance = Tracer( service_name=service_name, otlp_endpoint=otlp_endpoint, diff --git a/src/strands/tools/__init__.py b/src/strands/tools/__init__.py index b3ee1566..12979015 100644 --- a/src/strands/tools/__init__.py +++ b/src/strands/tools/__init__.py @@ -4,6 +4,7 @@ """ from .decorator import tool +from .structured_output import convert_pydantic_to_tool_spec from .thread_pool_executor import ThreadPoolExecutorWrapper from .tools import FunctionTool, InvalidToolUseNameException, PythonAgentTool, normalize_schema, normalize_tool_spec @@ -15,4 +16,5 @@ "normalize_schema", "normalize_tool_spec", "ThreadPoolExecutorWrapper", + "convert_pydantic_to_tool_spec", ] diff --git a/src/strands/tools/mcp/mcp_client.py b/src/strands/tools/mcp/mcp_client.py index c6890945..a2298813 100644 --- a/src/strands/tools/mcp/mcp_client.py +++ b/src/strands/tools/mcp/mcp_client.py @@ -41,6 +41,12 @@ "image/webp": "webp", } +CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE = ( + "the client session is not running. Ensure the agent is used within " + "the MCP client context manager. For more information see: " + "https://strandsagents.com/latest/user-guide/concepts/tools/mcp-tools/#mcpclientinitializationerror" +) + class MCPClient: """Represents a connection to a Model Context Protocol (MCP) server. @@ -145,7 +151,7 @@ def list_tools_sync(self) -> List[MCPAgentTool]: """ self._log_debug_with_thread("listing MCP tools synchronously") if not self._is_session_active(): - raise MCPClientInitializationError("the client session is not running") + raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE) async def _list_tools_async() -> ListToolsResult: return await self._background_thread_session.list_tools() @@ -180,7 +186,7 @@ def call_tool_sync( """ self._log_debug_with_thread("calling MCP tool '%s' synchronously with tool_use_id=%s", name, tool_use_id) if not self._is_session_active(): - raise MCPClientInitializationError("the client session is not running") + raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE) async def _call_tool_async() -> MCPCallToolResult: return await self._background_thread_session.call_tool(name, arguments, read_timeout_seconds) diff --git a/src/strands/tools/registry.py b/src/strands/tools/registry.py index 2efdd600..e56ee999 100644 --- a/src/strands/tools/registry.py +++ b/src/strands/tools/registry.py @@ -189,6 +189,21 @@ def register_tool(self, tool: AgentTool) -> None: tool.is_dynamic, ) + if self.registry.get(tool.tool_name) is None: + normalized_name = tool.tool_name.replace("-", "_") + + matching_tools = [ + tool_name + for (tool_name, tool) in self.registry.items() + if tool_name.replace("-", "_") == normalized_name + ] + + if matching_tools: + raise ValueError( + f"Tool name '{tool.tool_name}' already exists as '{matching_tools[0]}'." + " Cannot add a duplicate tool which differs by a '-' or '_'" + ) + # Register in main registry self.registry[tool.tool_name] = tool diff --git a/src/strands/tools/structured_output.py b/src/strands/tools/structured_output.py new file mode 100644 index 00000000..5421cdc6 --- /dev/null +++ b/src/strands/tools/structured_output.py @@ -0,0 +1,415 @@ +"""Tools for converting Pydantic models to Bedrock tools.""" + +from typing import Any, Dict, Optional, Type, Union + +from pydantic import BaseModel + +from ..types.tools import ToolSpec + + +def _flatten_schema(schema: Dict[str, Any]) -> Dict[str, Any]: + """Flattens a JSON schema by removing $defs and resolving $ref references. + + Handles required vs optional fields properly. + + Args: + schema: The JSON schema to flatten + + Returns: + Flattened JSON schema + """ + # Extract required fields list + required_fields = schema.get("required", []) + + # Initialize the flattened schema with basic properties + flattened = { + "type": schema.get("type", "object"), + "properties": {}, + } + + # Add title if present + if "title" in schema: + flattened["title"] = schema["title"] + + # Add description from schema if present, or use model docstring + if "description" in schema and schema["description"]: + flattened["description"] = schema["description"] + + # Process properties + required_props: list[str] = [] + if "properties" in schema: + required_props = [] + for prop_name, prop_value in schema["properties"].items(): + # Process the property and add to flattened properties + is_required = prop_name in required_fields + + # If the property already has nested properties (expanded), preserve them + if "properties" in prop_value: + # This is an expanded nested schema, preserve its structure + processed_prop = { + "type": prop_value.get("type", "object"), + "description": prop_value.get("description", ""), + "properties": {}, + } + + # Process each nested property + for nested_prop_name, nested_prop_value in prop_value["properties"].items(): + processed_prop["properties"][nested_prop_name] = nested_prop_value + + # Copy required fields if present + if "required" in prop_value: + processed_prop["required"] = prop_value["required"] + else: + # Process as normal + processed_prop = _process_property(prop_value, schema.get("$defs", {}), is_required) + + flattened["properties"][prop_name] = processed_prop + + # Track which properties are actually required after processing + if is_required and "null" not in str(processed_prop.get("type", "")): + required_props.append(prop_name) + + # Add required fields if any (only those that are truly required after processing) + # Check if required props are empty, if so, raise an error because it means there is a circular reference + + if len(required_props) > 0: + flattened["required"] = required_props + else: + raise ValueError("Circular reference detected and not supported") + + return flattened + + +def _process_property( + prop: Dict[str, Any], + defs: Dict[str, Any], + is_required: bool = False, + fully_expand: bool = True, +) -> Dict[str, Any]: + """Process a property in a schema, resolving any references. + + Args: + prop: The property to process + defs: The definitions dictionary for resolving references + is_required: Whether this property is required + fully_expand: Whether to fully expand nested properties + + Returns: + Processed property + """ + result = {} + is_nullable = False + + # Handle anyOf for optional fields (like Optional[Type]) + if "anyOf" in prop: + # Check if this is an Optional[...] case (one null, one type) + null_type = False + non_null_type = None + + for option in prop["anyOf"]: + if option.get("type") == "null": + null_type = True + is_nullable = True + elif "$ref" in option: + ref_path = option["$ref"].split("/")[-1] + if ref_path in defs: + non_null_type = _process_schema_object(defs[ref_path], defs, fully_expand) + else: + # Handle missing reference path gracefully + raise ValueError(f"Missing reference: {ref_path}") + else: + non_null_type = option + + if null_type and non_null_type: + # For Optional fields, we mark as nullable but copy all properties from the non-null option + result = non_null_type.copy() if isinstance(non_null_type, dict) else {} + + # For type, ensure it includes "null" + if "type" in result and isinstance(result["type"], str): + result["type"] = [result["type"], "null"] + elif "type" in result and isinstance(result["type"], list) and "null" not in result["type"]: + result["type"].append("null") + elif "type" not in result: + # Default to object type if not specified + result["type"] = ["object", "null"] + + # Copy description if available in the property + if "description" in prop: + result["description"] = prop["description"] + + return result + + # Handle direct references + elif "$ref" in prop: + # Resolve reference + ref_path = prop["$ref"].split("/")[-1] + if ref_path in defs: + ref_dict = defs[ref_path] + # Process the referenced object to get a complete schema + result = _process_schema_object(ref_dict, defs, fully_expand) + else: + # Handle missing reference path gracefully + raise ValueError(f"Missing reference: {ref_path}") + + # For regular fields, copy all properties + for key, value in prop.items(): + if key not in ["$ref", "anyOf"]: + if isinstance(value, dict): + result[key] = _process_nested_dict(value, defs) + elif key == "type" and not is_required and not is_nullable: + # For non-required fields, ensure type is a list with "null" + if isinstance(value, str): + result[key] = [value, "null"] + elif isinstance(value, list) and "null" not in value: + result[key] = value + ["null"] + else: + result[key] = value + else: + result[key] = value + + return result + + +def _process_schema_object( + schema_obj: Dict[str, Any], defs: Dict[str, Any], fully_expand: bool = True +) -> Dict[str, Any]: + """Process a schema object, typically from $defs, to resolve all nested properties. + + Args: + schema_obj: The schema object to process + defs: The definitions dictionary for resolving references + fully_expand: Whether to fully expand nested properties + + Returns: + Processed schema object with all properties resolved + """ + result = {} + + # Copy basic attributes + for key, value in schema_obj.items(): + if key != "properties" and key != "required" and key != "$defs": + result[key] = value + + # Process properties if present + if "properties" in schema_obj: + result["properties"] = {} + required_props = [] + + # Get required fields list + required_fields = schema_obj.get("required", []) + + for prop_name, prop_value in schema_obj["properties"].items(): + # Process each property + is_required = prop_name in required_fields + processed = _process_property(prop_value, defs, is_required, fully_expand) + result["properties"][prop_name] = processed + + # Track which properties are actually required after processing + if is_required and "null" not in str(processed.get("type", "")): + required_props.append(prop_name) + + # Add required fields if any + if required_props: + result["required"] = required_props + + return result + + +def _process_nested_dict(d: Dict[str, Any], defs: Dict[str, Any]) -> Dict[str, Any]: + """Recursively processes nested dictionaries and resolves $ref references. + + Args: + d: The dictionary to process + defs: The definitions dictionary for resolving references + + Returns: + Processed dictionary + """ + result: Dict[str, Any] = {} + + # Handle direct reference + if "$ref" in d: + ref_path = d["$ref"].split("/")[-1] + if ref_path in defs: + ref_dict = defs[ref_path] + # Recursively process the referenced object + return _process_schema_object(ref_dict, defs) + else: + # Handle missing reference path gracefully + raise ValueError(f"Missing reference: {ref_path}") + + # Process each key-value pair + for key, value in d.items(): + if key == "$ref": + # Already handled above + continue + elif isinstance(value, dict): + result[key] = _process_nested_dict(value, defs) + elif isinstance(value, list): + # Process lists (like for enum values) + result[key] = [_process_nested_dict(item, defs) if isinstance(item, dict) else item for item in value] + else: + result[key] = value + + return result + + +def convert_pydantic_to_tool_spec( + model: Type[BaseModel], + description: Optional[str] = None, +) -> ToolSpec: + """Converts a Pydantic model to a tool description for the Amazon Bedrock Converse API. + + Handles optional vs. required fields, resolves $refs, and uses docstrings. + + Args: + model: The Pydantic model class to convert + description: Optional description of the tool's purpose + + Returns: + ToolSpec: Dict containing the Bedrock tool specification + """ + name = model.__name__ + + # Get the JSON schema + input_schema = model.model_json_schema() + + # Get model docstring for description if not provided + model_description = description + if not model_description and model.__doc__: + model_description = model.__doc__.strip() + + # Process all referenced models to ensure proper docstrings + # This step is important for gathering descriptions from referenced models + _process_referenced_models(input_schema, model) + + # Now, let's fully expand the nested models with all their properties + _expand_nested_properties(input_schema, model) + + # Flatten the schema + flattened_schema = _flatten_schema(input_schema) + + final_schema = flattened_schema + + # Construct the tool specification + return ToolSpec( + name=name, + description=model_description or f"{name} structured output tool", + inputSchema={"json": final_schema}, + ) + + +def _expand_nested_properties(schema: Dict[str, Any], model: Type[BaseModel]) -> None: + """Expand the properties of nested models in the schema to include their full structure. + + This updates the schema in place. + + Args: + schema: The JSON schema to process + model: The Pydantic model class + """ + # First, process the properties at this level + if "properties" not in schema: + return + + # Create a modified copy of the properties to avoid modifying while iterating + for prop_name, prop_info in list(schema["properties"].items()): + field = model.model_fields.get(prop_name) + if not field: + continue + + field_type = field.annotation + + # Handle Optional types + is_optional = False + if ( + field_type is not None + and hasattr(field_type, "__origin__") + and field_type.__origin__ is Union + and hasattr(field_type, "__args__") + ): + # Look for Optional[BaseModel] + for arg in field_type.__args__: + if arg is type(None): + is_optional = True + elif isinstance(arg, type) and issubclass(arg, BaseModel): + field_type = arg + + # If this is a BaseModel field, expand its properties with full details + if isinstance(field_type, type) and issubclass(field_type, BaseModel): + # Get the nested model's schema with all its properties + nested_model_schema = field_type.model_json_schema() + + # Create a properly expanded nested object + expanded_object = { + "type": ["object", "null"] if is_optional else "object", + "description": prop_info.get("description", field.description or f"The {prop_name}"), + "properties": {}, + } + + # Copy all properties from the nested schema + if "properties" in nested_model_schema: + expanded_object["properties"] = nested_model_schema["properties"] + + # Copy required fields + if "required" in nested_model_schema: + expanded_object["required"] = nested_model_schema["required"] + + # Replace the original property with this expanded version + schema["properties"][prop_name] = expanded_object + + +def _process_referenced_models(schema: Dict[str, Any], model: Type[BaseModel]) -> None: + """Process referenced models to ensure their docstrings are included. + + This updates the schema in place. + + Args: + schema: The JSON schema to process + model: The Pydantic model class + """ + # Process $defs to add docstrings from the referenced models + if "$defs" in schema: + # Look through model fields to find referenced models + for _, field in model.model_fields.items(): + field_type = field.annotation + + # Handle Optional types - with null checks + if field_type is not None and hasattr(field_type, "__origin__"): + origin = field_type.__origin__ + if origin is Union and hasattr(field_type, "__args__"): + # Find the non-None type in the Union (for Optional fields) + for arg in field_type.__args__: + if arg is not type(None): + field_type = arg + break + + # Check if this is a BaseModel subclass + if isinstance(field_type, type) and issubclass(field_type, BaseModel): + # Update $defs with this model's information + ref_name = field_type.__name__ + if ref_name in schema.get("$defs", {}): + ref_def = schema["$defs"][ref_name] + + # Add docstring as description if available + if field_type.__doc__ and not ref_def.get("description"): + ref_def["description"] = field_type.__doc__.strip() + + # Recursively process properties in the referenced model + _process_properties(ref_def, field_type) + + +def _process_properties(schema_def: Dict[str, Any], model: Type[BaseModel]) -> None: + """Process properties in a schema definition to add descriptions from field metadata. + + Args: + schema_def: The schema definition to update + model: The model class that defines the schema + """ + if "properties" in schema_def: + for prop_name, prop_info in schema_def["properties"].items(): + field = model.model_fields.get(prop_name) + + # Add field description if available and not already set + if field and field.description and not prop_info.get("description"): + prop_info["description"] = field.description diff --git a/src/strands/tools/tools.py b/src/strands/tools/tools.py index b595c3d6..a449c74e 100644 --- a/src/strands/tools/tools.py +++ b/src/strands/tools/tools.py @@ -63,50 +63,54 @@ def validate_tool_use_name(tool: ToolUse) -> None: raise InvalidToolUseNameException(message) +def _normalize_property(prop_name: str, prop_def: Any) -> Dict[str, Any]: + """Normalize a single property definition. + + Args: + prop_name: The name of the property. + prop_def: The property definition to normalize. + + Returns: + The normalized property definition. + """ + if not isinstance(prop_def, dict): + return {"type": "string", "description": f"Property {prop_name}"} + + if prop_def.get("type") == "object" and "properties" in prop_def: + return normalize_schema(prop_def) # Recursive call + + # Copy existing property, ensuring defaults + normalized_prop = prop_def.copy() + normalized_prop.setdefault("type", "string") + normalized_prop.setdefault("description", f"Property {prop_name}") + return normalized_prop + + def normalize_schema(schema: Dict[str, Any]) -> Dict[str, Any]: """Normalize a JSON schema to match expectations. + This function recursively processes nested objects to preserve the complete schema structure. + Uses a copy-then-normalize approach to preserve all original schema properties. + Args: schema: The schema to normalize. Returns: The normalized schema. """ - normalized = {"type": schema.get("type", "object"), "properties": {}} - - # Handle properties - if "properties" in schema: - for prop_name, prop_def in schema["properties"].items(): - if isinstance(prop_def, dict): - normalized_prop = { - "type": prop_def.get("type", "string"), - "description": prop_def.get("description", f"Property {prop_name}"), - } - - # Handle enum values correctly - if "enum" in prop_def: - normalized_prop["enum"] = prop_def["enum"] - - # Handle numeric constraints - if prop_def.get("type") in ["number", "integer"]: - if "minimum" in prop_def: - normalized_prop["minimum"] = prop_def["minimum"] - if "maximum" in prop_def: - normalized_prop["maximum"] = prop_def["maximum"] - - normalized["properties"][prop_name] = normalized_prop - else: - # Handle non-dict property definitions (like simple strings) - normalized["properties"][prop_name] = { - "type": "string", - "description": f"Property {prop_name}", - } - - # Required fields - if "required" in schema: - normalized["required"] = schema["required"] - else: - normalized["required"] = [] + # Start with a complete copy to preserve all existing properties + normalized = schema.copy() + + # Ensure essential structure exists + normalized.setdefault("type", "object") + normalized.setdefault("properties", {}) + normalized.setdefault("required", []) + + # Process properties recursively + if "properties" in normalized: + properties = normalized["properties"] + for prop_name, prop_def in properties.items(): + normalized["properties"][prop_name] = _normalize_property(prop_name, prop_def) return normalized diff --git a/src/strands/types/models/model.py b/src/strands/types/models/model.py index 23e74602..071c8a51 100644 --- a/src/strands/types/models/model.py +++ b/src/strands/types/models/model.py @@ -2,7 +2,9 @@ import abc import logging -from typing import Any, Iterable, Optional +from typing import Any, Callable, Iterable, Optional, Type, TypeVar + +from pydantic import BaseModel from ..content import Messages from ..streaming import StreamEvent @@ -10,6 +12,8 @@ logger = logging.getLogger(__name__) +T = TypeVar("T", bound=BaseModel) + class Model(abc.ABC): """Abstract base class for AI model implementations. @@ -38,6 +42,26 @@ def get_config(self) -> Any: """ pass + @abc.abstractmethod + # pragma: no cover + def structured_output( + self, output_model: Type[T], prompt: Messages, callback_handler: Optional[Callable] = None + ) -> T: + """Get structured output from the model. + + Args: + output_model(Type[BaseModel]): The output model to use for the agent. + prompt(Messages): The prompt messages to use for the agent. + callback_handler(Optional[Callable]): Optional callback handler for processing events. Defaults to None. + + Returns: + The structured output as a serialized instance of the output model. + + Raises: + ValidationException: The response format from the model does not match the output_model + """ + pass + @abc.abstractmethod # pragma: no cover def format_request( diff --git a/src/strands/types/models/openai.py b/src/strands/types/models/openai.py index 96f758d5..8ff37d35 100644 --- a/src/strands/types/models/openai.py +++ b/src/strands/types/models/openai.py @@ -11,8 +11,9 @@ import json import logging import mimetypes -from typing import Any, Optional, cast +from typing import Any, Callable, Optional, Type, TypeVar, cast +from pydantic import BaseModel from typing_extensions import override from ..content import ContentBlock, Messages @@ -22,6 +23,8 @@ logger = logging.getLogger(__name__) +T = TypeVar("T", bound=BaseModel) + class OpenAIModel(Model, abc.ABC): """Base OpenAI model provider implementation. @@ -31,6 +34,32 @@ class OpenAIModel(Model, abc.ABC): config: dict[str, Any] + @staticmethod + def b64encode(data: bytes) -> bytes: + """Base64 encode the provided data. + + If the data is already base64 encoded, we do nothing. + Note, this is a temporary method used to provide a warning to users who pass in base64 encoded data. In future + versions, images and documents will be base64 encoded on behalf of customers for consistency with the other + providers and general convenience. + + Args: + data: Data to encode. + + Returns: + Base64 encoded data. + """ + try: + base64.b64decode(data, validate=True) + logger.warning( + "issue=<%s> | base64 encoded images and documents will not be accepted in future versions", + "https://github.com/strands-agents/sdk-python/issues/252", + ) + except ValueError: + data = base64.b64encode(data) + + return data + @classmethod def format_request_message_content(cls, content: ContentBlock) -> dict[str, Any]: """Format an OpenAI compatible content block. @@ -57,7 +86,8 @@ def format_request_message_content(cls, content: ContentBlock) -> dict[str, Any] if "image" in content: mime_type = mimetypes.types_map.get(f".{content['image']['format']}", "application/octet-stream") - image_data = content["image"]["source"]["bytes"].decode("utf-8") + image_data = OpenAIModel.b64encode(content["image"]["source"]["bytes"]).decode("utf-8") + return { "image_url": { "detail": "auto", @@ -262,3 +292,16 @@ def format_chunk(self, event: dict[str, Any]) -> StreamEvent: case _: raise RuntimeError(f"chunk_type=<{event['chunk_type']} | unknown type") + + @override + def structured_output( + self, output_model: Type[T], prompt: Messages, callback_handler: Optional[Callable] = None + ) -> T: + """Get structured output from the model. + + Args: + output_model(Type[BaseModel]): The output model to use for the agent. + prompt(Messages): The prompt to use for the agent. + callback_handler(Optional[Callable]): Optional callback handler for processing events. Defaults to None. + """ + return output_model() diff --git a/tests-integ/test_bedrock_guardrails.py b/tests-integ/test_bedrock_guardrails.py index 9ffd1bdf..bf0be706 100644 --- a/tests-integ/test_bedrock_guardrails.py +++ b/tests-integ/test_bedrock_guardrails.py @@ -12,7 +12,7 @@ @pytest.fixture(scope="module") def boto_session(): - return boto3.Session(region_name="us-west-2") + return boto3.Session(region_name="us-east-1") @pytest.fixture(scope="module") @@ -142,7 +142,7 @@ def test_guardrail_output_intervention_redact_output(bedrock_guardrail, processi guardrail_stream_processing_mode=processing_mode, guardrail_redact_output=True, guardrail_redact_output_message=REDACT_MESSAGE, - region_name="us-west-2", + region_name="us-east-1", ) agent = Agent( diff --git a/tests-integ/test_mcp_client.py b/tests-integ/test_mcp_client.py index 59ae2a14..8b1dade3 100644 --- a/tests-integ/test_mcp_client.py +++ b/tests-integ/test_mcp_client.py @@ -1,8 +1,10 @@ import base64 +import os import threading import time from typing import List, Literal +import pytest from mcp import StdioServerParameters, stdio_client from mcp.client.sse import sse_client from mcp.client.streamable_http import streamablehttp_client @@ -101,6 +103,10 @@ def test_can_reuse_mcp_client(): assert any([block["name"] == "echo" for block in tool_use_content_blocks]) +@pytest.mark.skipif( + condition=os.environ.get("GITHUB_ACTIONS") == "true", + reason="streamable transport is failing in GitHub actions, debugging if linux compatibility issue", +) def test_streamable_http_mcp_client(): server_thread = threading.Thread( target=start_calculator_server, kwargs={"transport": "streamable-http", "port": 8001}, daemon=True diff --git a/tests-integ/test_model_anthropic.py b/tests-integ/test_model_anthropic.py index 1b0412c9..50033f8f 100644 --- a/tests-integ/test_model_anthropic.py +++ b/tests-integ/test_model_anthropic.py @@ -1,6 +1,7 @@ import os import pytest +from pydantic import BaseModel import strands from strands import Agent @@ -33,7 +34,7 @@ def tool_weather() -> str: @pytest.fixture def system_prompt(): - return "You are an AI assistant that uses & instead of ." + return "You are an AI assistant." @pytest.fixture @@ -46,4 +47,17 @@ def test_agent(agent): result = agent("What is the time and weather in New York?") text = result.message["content"][0]["text"].lower() - assert all(string in text for string in ["12:00", "sunny", "&"]) + assert all(string in text for string in ["12:00", "sunny"]) + + +@pytest.mark.skipif("ANTHROPIC_API_KEY" not in os.environ, reason="ANTHROPIC_API_KEY environment variable missing") +def test_structured_output(model): + class Weather(BaseModel): + time: str + weather: str + + agent = Agent(model=model) + result = agent.structured_output(Weather, "The time is 12:00 and the weather is sunny") + assert isinstance(result, Weather) + assert result.time == "12:00" + assert result.weather == "sunny" diff --git a/tests-integ/test_model_bedrock.py b/tests-integ/test_model_bedrock.py index a6a29aa9..5378a9b2 100644 --- a/tests-integ/test_model_bedrock.py +++ b/tests-integ/test_model_bedrock.py @@ -1,4 +1,5 @@ import pytest +from pydantic import BaseModel import strands from strands import Agent @@ -118,3 +119,33 @@ def calculator(expression: str) -> float: agent("What is 123 + 456?") assert tool_was_called + + +def test_structured_output_streaming(streaming_model): + """Test structured output with streaming model.""" + + class Weather(BaseModel): + time: str + weather: str + + agent = Agent(model=streaming_model) + + result = agent.structured_output(Weather, "The time is 12:00 and the weather is sunny") + assert isinstance(result, Weather) + assert result.time == "12:00" + assert result.weather == "sunny" + + +def test_structured_output_non_streaming(non_streaming_model): + """Test structured output with non-streaming model.""" + + class Weather(BaseModel): + time: str + weather: str + + agent = Agent(model=non_streaming_model) + + result = agent.structured_output(Weather, "The time is 12:00 and the weather is sunny") + assert isinstance(result, Weather) + assert result.time == "12:00" + assert result.weather == "sunny" diff --git a/tests-integ/test_model_litellm.py b/tests-integ/test_model_litellm.py index f1afb61f..01a3e121 100644 --- a/tests-integ/test_model_litellm.py +++ b/tests-integ/test_model_litellm.py @@ -1,4 +1,5 @@ import pytest +from pydantic import BaseModel import strands from strands import Agent @@ -7,7 +8,7 @@ @pytest.fixture def model(): - return LiteLLMModel(model_id="us.anthropic.claude-3-7-sonnet-20250219-v1:0") + return LiteLLMModel(model_id="bedrock/us.anthropic.claude-3-7-sonnet-20250219-v1:0") @pytest.fixture @@ -33,3 +34,16 @@ def test_agent(agent): text = result.message["content"][0]["text"].lower() assert all(string in text for string in ["12:00", "sunny"]) + + +def test_structured_output(model): + class Weather(BaseModel): + time: str + weather: str + + agent_no_tools = Agent(model=model) + + result = agent_no_tools.structured_output(Weather, "The time is 12:00 and the weather is sunny") + assert isinstance(result, Weather) + assert result.time == "12:00" + assert result.weather == "sunny" diff --git a/tests-integ/test_model_ollama.py b/tests-integ/test_model_ollama.py new file mode 100644 index 00000000..38b46821 --- /dev/null +++ b/tests-integ/test_model_ollama.py @@ -0,0 +1,47 @@ +import pytest +import requests +from pydantic import BaseModel + +from strands import Agent +from strands.models.ollama import OllamaModel + + +def is_server_available() -> bool: + try: + return requests.get("http://localhost:11434").ok + except requests.exceptions.ConnectionError: + return False + + +@pytest.fixture +def model(): + return OllamaModel(host="http://localhost:11434", model_id="llama3.3:70b") + + +@pytest.fixture +def agent(model): + return Agent(model=model) + + +@pytest.mark.skipif(not is_server_available(), reason="Local Ollama endpoint not available at localhost:11434") +def test_agent(agent): + result = agent("Say 'hello world' with no other text") + assert isinstance(result.message["content"][0]["text"], str) + + +@pytest.mark.skipif(not is_server_available(), reason="Local Ollama endpoint not available at localhost:11434") +def test_structured_output(agent): + class Weather(BaseModel): + """Extract the time and weather. + + Time format: HH:MM + Weather: sunny, cloudy, rainy, etc. + """ + + time: str + weather: str + + result = agent.structured_output(Weather, "The time is 12:00 and the weather is sunny") + assert isinstance(result, Weather) + assert result.time == "12:00" + assert result.weather == "sunny" diff --git a/tests-integ/test_model_openai.py b/tests-integ/test_model_openai.py index c9046ad5..b0790ba0 100644 --- a/tests-integ/test_model_openai.py +++ b/tests-integ/test_model_openai.py @@ -1,6 +1,7 @@ import os import pytest +from pydantic import BaseModel import strands from strands import Agent @@ -44,3 +45,22 @@ def test_agent(agent): text = result.message["content"][0]["text"].lower() assert all(string in text for string in ["12:00", "sunny"]) + + +@pytest.mark.skipif( + "OPENAI_API_KEY" not in os.environ, + reason="OPENAI_API_KEY environment variable missing", +) +def test_structured_output(model): + class Weather(BaseModel): + """Extracts the time and weather from the user's message with the exact strings.""" + + time: str + weather: str + + agent = Agent(model=model) + + result = agent.structured_output(Weather, "The time is 12:00 and the weather is sunny") + assert isinstance(result, Weather) + assert result.time == "12:00" + assert result.weather == "sunny" diff --git a/tests-integ/test_summarizing_conversation_manager_integration.py b/tests-integ/test_summarizing_conversation_manager_integration.py new file mode 100644 index 00000000..5dcf4944 --- /dev/null +++ b/tests-integ/test_summarizing_conversation_manager_integration.py @@ -0,0 +1,374 @@ +"""Integration tests for SummarizingConversationManager with actual AI models. + +These tests validate the end-to-end functionality of the SummarizingConversationManager +by testing with real AI models and API calls. They ensure that: + +1. **Real summarization** - Tests that actual model-generated summaries work correctly +2. **Context overflow handling** - Validates real context overflow scenarios and recovery +3. **Tool preservation** - Ensures ToolUse/ToolResult pairs survive real summarization +4. **Message structure** - Verifies real model outputs maintain proper message structure +5. **Agent integration** - Tests that conversation managers work with real Agent workflows + +These tests require API keys (`ANTHROPIC_API_KEY`) and make real API calls, so they should be run sparingly +and may be skipped in CI environments without proper credentials. +""" + +import os + +import pytest + +import strands +from strands import Agent +from strands.agent.conversation_manager import SummarizingConversationManager +from strands.models.anthropic import AnthropicModel + + +@pytest.fixture +def model(): + """Real Anthropic model for integration testing.""" + return AnthropicModel( + client_args={ + "api_key": os.getenv("ANTHROPIC_API_KEY"), + }, + model_id="claude-3-haiku-20240307", # Using Haiku for faster/cheaper tests + max_tokens=1024, + ) + + +@pytest.fixture +def summarization_model(): + """Separate model instance for summarization to test dedicated agent functionality.""" + return AnthropicModel( + client_args={ + "api_key": os.getenv("ANTHROPIC_API_KEY"), + }, + model_id="claude-3-haiku-20240307", + max_tokens=512, + ) + + +@pytest.fixture +def tools(): + """Real tools for testing tool preservation during summarization.""" + + @strands.tool + def get_current_time() -> str: + """Get the current time.""" + return "2024-01-15 14:30:00" + + @strands.tool + def get_weather(city: str) -> str: + """Get weather information for a city.""" + return f"The weather in {city} is sunny and 72°F" + + @strands.tool + def calculate_sum(a: int, b: int) -> int: + """Calculate the sum of two numbers.""" + return a + b + + return [get_current_time, get_weather, calculate_sum] + + +@pytest.mark.skipif("ANTHROPIC_API_KEY" not in os.environ, reason="ANTHROPIC_API_KEY environment variable missing") +def test_summarization_with_context_overflow(model): + """Test that summarization works when context overflow occurs.""" + # Mock conversation data to avoid API calls + greeting_response = """ + Hello! I'm here to help you test your conversation manager. What specifically would you like + me to do as part of this test? I can respond to different types of prompts, maintain context + throughout our conversation, or demonstrate other capabilities of the AI assistant. Just let + me know what aspects you'd like to evaluate. + """.strip() + + computer_history_response = """ + # History of Computers + + The history of computers spans many centuries, evolving from simple calculating tools to + the powerful machines we use today. + + ## Early Computing Devices + - **Ancient abacus** (3000 BCE): One of the earliest computing devices used for arithmetic calculations + - **Pascaline** (1642): Mechanical calculator invented by Blaise Pascal + - **Difference Engine** (1822): Designed by Charles Babbage to compute polynomial functions + - **Analytical Engine**: Babbage's more ambitious design, considered the first general-purpose computer concept + - **Hollerith's Tabulating Machine** (1890s): Used punch cards to process data for the US Census + + ## Early Electronic Computers + - **ENIAC** (1945): First general-purpose electronic computer, weighed 30 tons + - **EDVAC** (1949): Introduced the stored program concept + - **UNIVAC I** (1951): First commercial computer in the United States + """.strip() + + first_computers_response = """ + # The First Computers + + Early computers were dramatically different from today's machines in almost every aspect: + + ## Physical Characteristics + - **Enormous size**: Room-filling or even building-filling machines + - **ENIAC** (1945) weighed about 30 tons, occupied 1,800 square feet + - Consisted of large metal frames or cabinets filled with components + - Required special cooling systems due to excessive heat generation + + ## Technology and Components + - **Vacuum tubes**: Thousands of fragile glass tubes served as switches and amplifiers + - ENIAC contained over 17,000 vacuum tubes + - Generated tremendous heat and frequently failed + - **Memory**: Limited storage using delay lines, cathode ray tubes, or magnetic drums + """.strip() + + messages = [ + {"role": "user", "content": [{"text": "Hello, I'm testing a conversation manager."}]}, + {"role": "assistant", "content": [{"text": greeting_response}]}, + {"role": "user", "content": [{"text": "Can you tell me about the history of computers?"}]}, + {"role": "assistant", "content": [{"text": computer_history_response}]}, + {"role": "user", "content": [{"text": "What were the first computers like?"}]}, + {"role": "assistant", "content": [{"text": first_computers_response}]}, + ] + + # Create agent with very aggressive summarization settings and pre-built conversation + agent = Agent( + model=model, + conversation_manager=SummarizingConversationManager( + summary_ratio=0.5, # Summarize 50% of messages + preserve_recent_messages=2, # Keep only 2 recent messages + ), + load_tools_from_directory=False, + messages=messages, + ) + + # Should have the pre-built conversation history + initial_message_count = len(agent.messages) + assert initial_message_count == 6 # 3 user + 3 assistant messages + + # Store the last 2 messages before summarization to verify they're preserved + messages_before_summary = agent.messages[-2:].copy() + + # Now manually trigger context reduction to test summarization + agent.conversation_manager.reduce_context(agent) + + # Verify summarization occurred + assert len(agent.messages) < initial_message_count + # Should have: 1 summary + remaining messages + # With 6 messages, summary_ratio=0.5, preserve_recent_messages=2: + # messages_to_summarize = min(6 * 0.5, 6 - 2) = min(3, 4) = 3 + # So we summarize 3 messages, leaving 3 remaining + 1 summary = 4 total + expected_total_messages = 4 + assert len(agent.messages) == expected_total_messages + + # First message should be the summary (assistant message) + summary_message = agent.messages[0] + assert summary_message["role"] == "assistant" + assert len(summary_message["content"]) > 0 + + # Verify the summary contains actual text content + summary_content = None + for content_block in summary_message["content"]: + if "text" in content_block: + summary_content = content_block["text"] + break + + assert summary_content is not None + assert len(summary_content) > 50 # Should be a substantial summary + + # Recent messages should be preserved - verify they're exactly the same + recent_messages = agent.messages[-2:] # Last 2 messages should be preserved + assert len(recent_messages) == 2 + assert recent_messages == messages_before_summary, "The last 2 messages should be preserved exactly as they were" + + # Agent should still be functional after summarization + post_summary_result = agent("That's very interesting, thank you!") + assert post_summary_result.message["role"] == "assistant" + + +@pytest.mark.skipif("ANTHROPIC_API_KEY" not in os.environ, reason="ANTHROPIC_API_KEY environment variable missing") +def test_tool_preservation_during_summarization(model, tools): + """Test that ToolUse/ToolResult pairs are preserved during summarization.""" + agent = Agent( + model=model, + tools=tools, + conversation_manager=SummarizingConversationManager( + summary_ratio=0.6, # Aggressive summarization + preserve_recent_messages=3, + ), + load_tools_from_directory=False, + ) + + # Mock conversation with tool usage to avoid API calls and speed up tests + greeting_text = """ + Hello! I'd be happy to help you with calculations. I have access to tools that can + help with math, time, and weather information. What would you like me to calculate for you? + """.strip() + + weather_response = "The weather in San Francisco is sunny and 72°F. Perfect weather for being outside!" + + tool_conversation_data = [ + # Initial greeting exchange + {"role": "user", "content": [{"text": "Hello, can you help me with some calculations?"}]}, + {"role": "assistant", "content": [{"text": greeting_text}]}, + # Time query with tool use/result pair + {"role": "user", "content": [{"text": "What's the current time?"}]}, + { + "role": "assistant", + "content": [{"toolUse": {"toolUseId": "time_001", "name": "get_current_time", "input": {}}}], + }, + { + "role": "user", + "content": [ + { + "toolResult": { + "toolUseId": "time_001", + "content": [{"text": "2024-01-15 14:30:00"}], + "status": "success", + } + } + ], + }, + {"role": "assistant", "content": [{"text": "The current time is 2024-01-15 14:30:00."}]}, + # Math calculation with tool use/result pair + {"role": "user", "content": [{"text": "What's 25 + 37?"}]}, + { + "role": "assistant", + "content": [{"toolUse": {"toolUseId": "calc_001", "name": "calculate_sum", "input": {"a": 25, "b": 37}}}], + }, + { + "role": "user", + "content": [{"toolResult": {"toolUseId": "calc_001", "content": [{"text": "62"}], "status": "success"}}], + }, + {"role": "assistant", "content": [{"text": "25 + 37 = 62"}]}, + # Weather query with tool use/result pair + {"role": "user", "content": [{"text": "What's the weather like in San Francisco?"}]}, + { + "role": "assistant", + "content": [ + {"toolUse": {"toolUseId": "weather_001", "name": "get_weather", "input": {"city": "San Francisco"}}} + ], + }, + { + "role": "user", + "content": [ + { + "toolResult": { + "toolUseId": "weather_001", + "content": [{"text": "The weather in San Francisco is sunny and 72°F"}], + "status": "success", + } + } + ], + }, + {"role": "assistant", "content": [{"text": weather_response}]}, + ] + + # Add all the mocked conversation messages to avoid real API calls + agent.messages.extend(tool_conversation_data) + + # Force summarization + agent.conversation_manager.reduce_context(agent) + + # Verify tool pairs are still balanced after summarization + post_summary_tool_use_count = 0 + post_summary_tool_result_count = 0 + + for message in agent.messages: + for content in message.get("content", []): + if "toolUse" in content: + post_summary_tool_use_count += 1 + if "toolResult" in content: + post_summary_tool_result_count += 1 + + # Tool uses and results should be balanced (no orphaned tools) + assert post_summary_tool_use_count == post_summary_tool_result_count, ( + "Tool use and tool result counts should be balanced after summarization" + ) + + # Agent should still be able to use tools after summarization + agent("Calculate 15 + 28 for me.") + + # Should have triggered the calculate_sum tool + found_calculation = False + for message in agent.messages[-2:]: # Check recent messages + for content in message.get("content", []): + if "toolResult" in content and "43" in str(content): # 15 + 28 = 43 + found_calculation = True + break + + assert found_calculation, "Tool should still work after summarization" + + +@pytest.mark.skipif("ANTHROPIC_API_KEY" not in os.environ, reason="ANTHROPIC_API_KEY environment variable missing") +def test_dedicated_summarization_agent(model, summarization_model): + """Test that a dedicated summarization agent works correctly.""" + # Create a dedicated summarization agent + summarization_agent = Agent( + model=summarization_model, + system_prompt="You are a conversation summarizer. Create concise, structured summaries.", + load_tools_from_directory=False, + ) + + # Create main agent with dedicated summarization agent + agent = Agent( + model=model, + conversation_manager=SummarizingConversationManager( + summary_ratio=0.5, + preserve_recent_messages=2, + summarization_agent=summarization_agent, + ), + load_tools_from_directory=False, + ) + + # Mock conversation data for space exploration topic + space_intro_response = """ + Space exploration has been one of humanity's greatest achievements, beginning with early + satellite launches in the 1950s and progressing to human spaceflight, moon landings, and now + commercial space ventures. + """.strip() + + space_milestones_response = """ + Key milestones include Sputnik 1 (1957), Yuri Gagarin's first human spaceflight (1961), + the Apollo 11 moon landing (1969), the Space Shuttle program, and the International Space + Station construction. + """.strip() + + apollo_missions_response = """ + The Apollo program was NASA's lunar exploration program from 1961-1975. Apollo 11 achieved + the first moon landing in 1969 with Neil Armstrong and Buzz Aldrin, followed by five more + successful lunar missions through Apollo 17. + """.strip() + + spacex_response = """ + SpaceX has revolutionized space travel with reusable rockets, reducing launch costs dramatically. + They've achieved crew transportation to the ISS, satellite deployments, and are developing + Starship for Mars missions. + """.strip() + + conversation_pairs = [ + ("I'm interested in learning about space exploration.", space_intro_response), + ("What were the key milestones in space exploration?", space_milestones_response), + ("Tell me about the Apollo missions.", apollo_missions_response), + ("What about modern space exploration with SpaceX?", spacex_response), + ] + + # Manually build the conversation history to avoid real API calls + for user_input, assistant_response in conversation_pairs: + agent.messages.append({"role": "user", "content": [{"text": user_input}]}) + agent.messages.append({"role": "assistant", "content": [{"text": assistant_response}]}) + + # Force summarization + original_length = len(agent.messages) + agent.conversation_manager.reduce_context(agent) + + # Verify summarization occurred + assert len(agent.messages) < original_length + + # Get the summary message + summary_message = agent.messages[0] + assert summary_message["role"] == "assistant" + + # Extract summary text + summary_text = None + for content in summary_message["content"]: + if "text" in content: + summary_text = content["text"] + break + + assert summary_text diff --git a/tests/conftest.py b/tests/conftest.py index 4f0b5b21..cd18b698 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -24,6 +24,8 @@ def moto_env(monkeypatch): monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "test") monkeypatch.setenv("AWS_SECURITY_TOKEN", "test") monkeypatch.setenv("AWS_DEFAULT_REGION", "us-west-2") + monkeypatch.delenv("OTEL_EXPORTER_OTLP_ENDPOINT", raising=False) + monkeypatch.delenv("OTEL_EXPORTER_OTLP_HEADERS", raising=False) @pytest.fixture diff --git a/tests/multiagent/__init__.py b/tests/multiagent/__init__.py new file mode 100644 index 00000000..b43bae53 --- /dev/null +++ b/tests/multiagent/__init__.py @@ -0,0 +1 @@ +"""Tests for the multiagent module.""" diff --git a/tests/multiagent/a2a/__init__.py b/tests/multiagent/a2a/__init__.py new file mode 100644 index 00000000..eb5487d9 --- /dev/null +++ b/tests/multiagent/a2a/__init__.py @@ -0,0 +1 @@ +"""Tests for the A2A module.""" diff --git a/tests/multiagent/a2a/conftest.py b/tests/multiagent/a2a/conftest.py new file mode 100644 index 00000000..558a4594 --- /dev/null +++ b/tests/multiagent/a2a/conftest.py @@ -0,0 +1,41 @@ +"""Common fixtures for A2A module tests.""" + +from unittest.mock import AsyncMock, MagicMock + +import pytest +from a2a.server.agent_execution import RequestContext +from a2a.server.events import EventQueue + +from strands.agent.agent import Agent as SAAgent +from strands.agent.agent_result import AgentResult as SAAgentResult + + +@pytest.fixture +def mock_strands_agent(): + """Create a mock Strands Agent for testing.""" + agent = MagicMock(spec=SAAgent) + agent.name = "Test Agent" + agent.description = "A test agent for unit testing" + + # Setup default response + mock_result = MagicMock(spec=SAAgentResult) + mock_result.message = {"content": [{"text": "Test response"}]} + agent.return_value = mock_result + + return agent + + +@pytest.fixture +def mock_request_context(): + """Create a mock RequestContext for testing.""" + context = MagicMock(spec=RequestContext) + context.get_user_input.return_value = "Test input" + return context + + +@pytest.fixture +def mock_event_queue(): + """Create a mock EventQueue for testing.""" + queue = MagicMock(spec=EventQueue) + queue.enqueue_event = AsyncMock() + return queue diff --git a/tests/multiagent/a2a/test_agent.py b/tests/multiagent/a2a/test_agent.py new file mode 100644 index 00000000..5558c2af --- /dev/null +++ b/tests/multiagent/a2a/test_agent.py @@ -0,0 +1,165 @@ +"""Tests for the A2AAgent class.""" + +from unittest.mock import patch + +import pytest +from a2a.types import AgentCapabilities, AgentCard +from fastapi import FastAPI +from starlette.applications import Starlette + +from strands.multiagent.a2a.agent import A2AAgent + + +def test_a2a_agent_initialization(mock_strands_agent): + """Test that A2AAgent initializes correctly with default values.""" + a2a_agent = A2AAgent(mock_strands_agent) + + assert a2a_agent.strands_agent == mock_strands_agent + assert a2a_agent.name == "Test Agent" + assert a2a_agent.description == "A test agent for unit testing" + assert a2a_agent.host == "0.0.0" + assert a2a_agent.port == 9000 + assert a2a_agent.http_url == "http://0.0.0:9000/" + assert a2a_agent.version == "0.0.1" + assert isinstance(a2a_agent.capabilities, AgentCapabilities) + + +def test_a2a_agent_initialization_with_custom_values(mock_strands_agent): + """Test that A2AAgent initializes correctly with custom values.""" + a2a_agent = A2AAgent( + mock_strands_agent, + host="127.0.0.1", + port=8080, + version="1.0.0", + ) + + assert a2a_agent.host == "127.0.0.1" + assert a2a_agent.port == 8080 + assert a2a_agent.http_url == "http://127.0.0.1:8080/" + assert a2a_agent.version == "1.0.0" + + +def test_public_agent_card(mock_strands_agent): + """Test that public_agent_card returns a valid AgentCard.""" + a2a_agent = A2AAgent(mock_strands_agent) + + card = a2a_agent.public_agent_card + + assert isinstance(card, AgentCard) + assert card.name == "Test Agent" + assert card.description == "A test agent for unit testing" + assert card.url == "http://0.0.0:9000/" + assert card.version == "0.0.1" + assert card.defaultInputModes == ["text"] + assert card.defaultOutputModes == ["text"] + assert card.skills == [] + assert card.capabilities == a2a_agent.capabilities + + +def test_public_agent_card_with_missing_name(mock_strands_agent): + """Test that public_agent_card raises ValueError when name is missing.""" + mock_strands_agent.name = "" + a2a_agent = A2AAgent(mock_strands_agent) + + with pytest.raises(ValueError, match="A2A agent name cannot be None or empty"): + _ = a2a_agent.public_agent_card + + +def test_public_agent_card_with_missing_description(mock_strands_agent): + """Test that public_agent_card raises ValueError when description is missing.""" + mock_strands_agent.description = "" + a2a_agent = A2AAgent(mock_strands_agent) + + with pytest.raises(ValueError, match="A2A agent description cannot be None or empty"): + _ = a2a_agent.public_agent_card + + +def test_agent_skills(mock_strands_agent): + """Test that agent_skills returns an empty list (current implementation).""" + a2a_agent = A2AAgent(mock_strands_agent) + + skills = a2a_agent.agent_skills + + assert isinstance(skills, list) + assert len(skills) == 0 + + +def test_to_starlette_app(mock_strands_agent): + """Test that to_starlette_app returns a Starlette application.""" + a2a_agent = A2AAgent(mock_strands_agent) + + app = a2a_agent.to_starlette_app() + + assert isinstance(app, Starlette) + + +def test_to_fastapi_app(mock_strands_agent): + """Test that to_fastapi_app returns a FastAPI application.""" + a2a_agent = A2AAgent(mock_strands_agent) + + app = a2a_agent.to_fastapi_app() + + assert isinstance(app, FastAPI) + + +@patch("uvicorn.run") +def test_serve_with_starlette(mock_run, mock_strands_agent): + """Test that serve starts a Starlette server by default.""" + a2a_agent = A2AAgent(mock_strands_agent) + + a2a_agent.serve() + + mock_run.assert_called_once() + args, kwargs = mock_run.call_args + assert isinstance(args[0], Starlette) + assert kwargs["host"] == "0.0.0" + assert kwargs["port"] == 9000 + + +@patch("uvicorn.run") +def test_serve_with_fastapi(mock_run, mock_strands_agent): + """Test that serve starts a FastAPI server when specified.""" + a2a_agent = A2AAgent(mock_strands_agent) + + a2a_agent.serve(app_type="fastapi") + + mock_run.assert_called_once() + args, kwargs = mock_run.call_args + assert isinstance(args[0], FastAPI) + assert kwargs["host"] == "0.0.0" + assert kwargs["port"] == 9000 + + +@patch("uvicorn.run") +def test_serve_with_custom_kwargs(mock_run, mock_strands_agent): + """Test that serve passes additional kwargs to uvicorn.run.""" + a2a_agent = A2AAgent(mock_strands_agent) + + a2a_agent.serve(log_level="debug", reload=True) + + mock_run.assert_called_once() + _, kwargs = mock_run.call_args + assert kwargs["log_level"] == "debug" + assert kwargs["reload"] is True + + +@patch("uvicorn.run", side_effect=KeyboardInterrupt) +def test_serve_handles_keyboard_interrupt(mock_run, mock_strands_agent, caplog): + """Test that serve handles KeyboardInterrupt gracefully.""" + a2a_agent = A2AAgent(mock_strands_agent) + + a2a_agent.serve() + + assert "Strands A2A server shutdown requested (KeyboardInterrupt)" in caplog.text + assert "Strands A2A server has shutdown" in caplog.text + + +@patch("uvicorn.run", side_effect=Exception("Test exception")) +def test_serve_handles_general_exception(mock_run, mock_strands_agent, caplog): + """Test that serve handles general exceptions gracefully.""" + a2a_agent = A2AAgent(mock_strands_agent) + + a2a_agent.serve() + + assert "Strands A2A server encountered exception" in caplog.text + assert "Strands A2A server has shutdown" in caplog.text diff --git a/tests/multiagent/a2a/test_executor.py b/tests/multiagent/a2a/test_executor.py new file mode 100644 index 00000000..2ac9bed9 --- /dev/null +++ b/tests/multiagent/a2a/test_executor.py @@ -0,0 +1,118 @@ +"""Tests for the StrandsA2AExecutor class.""" + +from unittest.mock import MagicMock + +import pytest +from a2a.types import UnsupportedOperationError +from a2a.utils.errors import ServerError + +from strands.agent.agent_result import AgentResult as SAAgentResult +from strands.multiagent.a2a.executor import StrandsA2AExecutor + + +def test_executor_initialization(mock_strands_agent): + """Test that StrandsA2AExecutor initializes correctly.""" + executor = StrandsA2AExecutor(mock_strands_agent) + + assert executor.agent == mock_strands_agent + + +@pytest.mark.asyncio +async def test_execute_with_text_response(mock_strands_agent, mock_request_context, mock_event_queue): + """Test that execute processes text responses correctly.""" + # Setup mock agent response + mock_result = MagicMock(spec=SAAgentResult) + mock_result.message = {"content": [{"text": "Test response"}]} + mock_strands_agent.return_value = mock_result + + # Create executor and call execute + executor = StrandsA2AExecutor(mock_strands_agent) + await executor.execute(mock_request_context, mock_event_queue) + + # Verify agent was called with correct input + mock_strands_agent.assert_called_once_with("Test input") + + # Verify event was enqueued + mock_event_queue.enqueue_event.assert_called_once() + args, _ = mock_event_queue.enqueue_event.call_args + event = args[0] + assert event.parts[0].root.text == "Test response" + + +@pytest.mark.asyncio +async def test_execute_with_multiple_text_blocks(mock_strands_agent, mock_request_context, mock_event_queue): + """Test that execute processes multiple text blocks correctly.""" + # Setup mock agent response with multiple text blocks + mock_result = MagicMock(spec=SAAgentResult) + mock_result.message = {"content": [{"text": "First response"}, {"text": "Second response"}]} + mock_strands_agent.return_value = mock_result + + # Create executor and call execute + executor = StrandsA2AExecutor(mock_strands_agent) + await executor.execute(mock_request_context, mock_event_queue) + + # Verify agent was called with correct input + mock_strands_agent.assert_called_once_with("Test input") + + # Verify events were enqueued + assert mock_event_queue.enqueue_event.call_count == 2 + + # Check first event + args1, _ = mock_event_queue.enqueue_event.call_args_list[0] + event1 = args1[0] + assert event1.parts[0].root.text == "First response" + + # Check second event + args2, _ = mock_event_queue.enqueue_event.call_args_list[1] + event2 = args2[0] + assert event2.parts[0].root.text == "Second response" + + +@pytest.mark.asyncio +async def test_execute_with_empty_response(mock_strands_agent, mock_request_context, mock_event_queue): + """Test that execute handles empty responses correctly.""" + # Setup mock agent response with empty content + mock_result = MagicMock(spec=SAAgentResult) + mock_result.message = {"content": []} + mock_strands_agent.return_value = mock_result + + # Create executor and call execute + executor = StrandsA2AExecutor(mock_strands_agent) + await executor.execute(mock_request_context, mock_event_queue) + + # Verify agent was called with correct input + mock_strands_agent.assert_called_once_with("Test input") + + # Verify no events were enqueued + mock_event_queue.enqueue_event.assert_not_called() + + +@pytest.mark.asyncio +async def test_execute_with_no_message(mock_strands_agent, mock_request_context, mock_event_queue): + """Test that execute handles responses with no message correctly.""" + # Setup mock agent response with no message + mock_result = MagicMock(spec=SAAgentResult) + mock_result.message = None + mock_strands_agent.return_value = mock_result + + # Create executor and call execute + executor = StrandsA2AExecutor(mock_strands_agent) + await executor.execute(mock_request_context, mock_event_queue) + + # Verify agent was called with correct input + mock_strands_agent.assert_called_once_with("Test input") + + # Verify no events were enqueued + mock_event_queue.enqueue_event.assert_not_called() + + +@pytest.mark.asyncio +async def test_cancel_raises_unsupported_operation_error(mock_strands_agent, mock_request_context, mock_event_queue): + """Test that cancel raises UnsupportedOperationError.""" + executor = StrandsA2AExecutor(mock_strands_agent) + + with pytest.raises(ServerError) as excinfo: + await executor.cancel(mock_request_context, mock_event_queue) + + # Verify the error is a ServerError containing an UnsupportedOperationError + assert isinstance(excinfo.value.error, UnsupportedOperationError) diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 4a63fa31..85d17544 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -7,6 +7,7 @@ from time import sleep import pytest +from pydantic import BaseModel import strands from strands import Agent @@ -438,7 +439,7 @@ def test_agent__call__retry_with_reduced_context(mock_model, agent, tool): def test_agent__call__always_sliding_window_conversation_manager_doesnt_infinite_loop(mock_model, agent, tool): - conversation_manager = SlidingWindowConversationManager(window_size=500) + conversation_manager = SlidingWindowConversationManager(window_size=500, should_truncate_results=False) conversation_manager_spy = unittest.mock.Mock(wraps=conversation_manager) agent.conversation_manager = conversation_manager_spy @@ -484,10 +485,43 @@ def test_agent__call__null_conversation_window_manager__doesnt_infinite_loop(moc agent("Test!") +def test_agent__call__tool_truncation_doesnt_infinite_loop(mock_model, agent): + messages: Messages = [ + {"role": "user", "content": [{"text": "Hello!"}]}, + { + "role": "assistant", + "content": [{"toolUse": {"toolUseId": "123", "input": {"hello": "world"}, "name": "test"}}], + }, + { + "role": "user", + "content": [ + {"toolResult": {"toolUseId": "123", "content": [{"text": "Some large input!"}], "status": "success"}} + ], + }, + ] + agent.messages = messages + + mock_model.mock_converse.side_effect = ContextWindowOverflowException( + RuntimeError("Input is too long for requested model") + ) + + with pytest.raises(ContextWindowOverflowException): + agent("Test!") + + def test_agent__call__retry_with_overwritten_tool(mock_model, agent, tool): conversation_manager_spy = unittest.mock.Mock(wraps=agent.conversation_manager) agent.conversation_manager = conversation_manager_spy + messages: Messages = [ + {"role": "user", "content": [{"text": "Hello!"}]}, + { + "role": "assistant", + "content": [{"text": "Hi!"}], + }, + ] + agent.messages = messages + mock_model.mock_converse.side_effect = [ [ { @@ -504,6 +538,9 @@ def test_agent__call__retry_with_overwritten_tool(mock_model, agent, tool): {"contentBlockStop": {}}, {"messageStop": {"stopReason": "tool_use"}}, ], + # Will truncate the tool result + ContextWindowOverflowException(RuntimeError("Input is too long for requested model")), + # Will reduce the context ContextWindowOverflowException(RuntimeError("Input is too long for requested model")), [], ] @@ -538,7 +575,7 @@ def test_agent__call__retry_with_overwritten_tool(mock_model, agent, tool): unittest.mock.ANY, ) - conversation_manager_spy.reduce_context.assert_not_called() + assert conversation_manager_spy.reduce_context.call_count == 2 assert conversation_manager_spy.apply_management.call_count == 1 @@ -674,6 +711,46 @@ def function(system_prompt: str) -> str: ) +def test_agent_tool_with_name_normalization(agent, tool_registry, mock_randint): + agent.tool_handler = unittest.mock.Mock() + + tool_name = "system-prompter" + + @strands.tools.tool(name=tool_name) + def function(system_prompt: str) -> str: + return system_prompt + + tool = strands.tools.tools.FunctionTool(function) + agent.tool_registry.register_tool(tool) + + mock_randint.return_value = 1 + + agent.tool.system_prompter(system_prompt="tool prompt") + + # Verify the correct tool was invoked + assert agent.tool_handler.process.call_count == 1 + tool_call = agent.tool_handler.process.call_args.kwargs.get("tool") + + assert tool_call == { + # Note that the tool-use uses the "python safe" name + "toolUseId": "tooluse_system_prompter_1", + # But the name of the tool is the one in the registry + "name": tool_name, + "input": {"system_prompt": "tool prompt"}, + } + + +def test_agent_tool_with_no_normalized_match(agent, tool_registry, mock_randint): + agent.tool_handler = unittest.mock.Mock() + + mock_randint.return_value = 1 + + with pytest.raises(AttributeError) as err: + agent.tool.system_prompter_1(system_prompt="tool prompt") + + assert str(err.value) == "Tool 'system_prompter_1' not found" + + def test_agent_with_none_callback_handler_prints_nothing(): agent = Agent() @@ -686,6 +763,62 @@ def test_agent_with_callback_handler_none_uses_null_handler(): assert agent.callback_handler == null_callback_handler +def test_agent_callback_handler_not_provided_creates_new_instances(): + """Test that when callback_handler is not provided, new PrintingCallbackHandler instances are created.""" + # Create two agents without providing callback_handler + agent1 = Agent() + agent2 = Agent() + + # Both should have PrintingCallbackHandler instances + assert isinstance(agent1.callback_handler, PrintingCallbackHandler) + assert isinstance(agent2.callback_handler, PrintingCallbackHandler) + + # But they should be different object instances + assert agent1.callback_handler is not agent2.callback_handler + + +def test_agent_callback_handler_explicit_none_uses_null_handler(): + """Test that when callback_handler is explicitly set to None, null_callback_handler is used.""" + agent = Agent(callback_handler=None) + + # Should use null_callback_handler + assert agent.callback_handler is null_callback_handler + + +def test_agent_callback_handler_custom_handler_used(): + """Test that when a custom callback_handler is provided, it is used.""" + custom_handler = unittest.mock.Mock() + agent = Agent(callback_handler=custom_handler) + + # Should use the provided custom handler + assert agent.callback_handler is custom_handler + + +# mock the User(name='Jane Doe', age=30, email='jane@doe.com') +class User(BaseModel): + """A user of the system.""" + + name: str + age: int + email: str + + +def test_agent_method_structured_output(agent): + # Mock the structured_output method on the model + expected_user = User(name="Jane Doe", age=30, email="jane@doe.com") + agent.model.structured_output = unittest.mock.Mock(return_value=expected_user) + + prompt = "Jane Doe is 30 years old and her email is jane@doe.com" + + result = agent.structured_output(User, prompt) + assert result == expected_user + + # Verify the model's structured_output was called with correct arguments + agent.model.structured_output.assert_called_once_with( + User, [{"role": "user", "content": [{"text": prompt}]}], agent.callback_handler + ) + + @pytest.mark.asyncio async def test_stream_async_returns_all_events(mock_event_loop_cycle): agent = Agent() diff --git a/tests/strands/agent/test_conversation_manager.py b/tests/strands/agent/test_conversation_manager.py index bbec3cd1..7d43199e 100644 --- a/tests/strands/agent/test_conversation_manager.py +++ b/tests/strands/agent/test_conversation_manager.py @@ -25,6 +25,7 @@ def test_is_assistant_message(role, exp_result): def conversation_manager(request): params = { "window_size": 2, + "should_truncate_results": False, } if hasattr(request, "param"): params.update(request.param) @@ -168,7 +169,7 @@ def test_apply_management(conversation_manager, messages, expected_messages): def test_sliding_window_conversation_manager_with_untrimmable_history_raises_context_window_overflow_exception(): - manager = strands.agent.conversation_manager.SlidingWindowConversationManager(1) + manager = strands.agent.conversation_manager.SlidingWindowConversationManager(1, False) messages = [ {"role": "assistant", "content": [{"toolUse": {"toolUseId": "456", "name": "tool1", "input": {}}}]}, {"role": "user", "content": [{"toolResult": {"toolUseId": "789", "content": [], "status": "success"}}]}, @@ -182,6 +183,40 @@ def test_sliding_window_conversation_manager_with_untrimmable_history_raises_con assert messages == original_messages +def test_sliding_window_conversation_manager_with_tool_results_truncated(): + manager = strands.agent.conversation_manager.SlidingWindowConversationManager(1) + messages = [ + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "456", "name": "tool1", "input": {}}}]}, + { + "role": "user", + "content": [ + {"toolResult": {"toolUseId": "789", "content": [{"text": "large input"}], "status": "success"}} + ], + }, + ] + test_agent = Agent(messages=messages) + + manager.reduce_context(test_agent) + + expected_messages = [ + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "456", "name": "tool1", "input": {}}}]}, + { + "role": "user", + "content": [ + { + "toolResult": { + "toolUseId": "789", + "content": [{"text": "The tool result was too large!"}], + "status": "error", + } + } + ], + }, + ] + + assert messages == expected_messages + + def test_null_conversation_manager_reduce_context_raises_context_window_overflow_exception(): """Test that NullConversationManager doesn't modify messages.""" manager = strands.agent.conversation_manager.NullConversationManager() diff --git a/tests/strands/agent/test_summarizing_conversation_manager.py b/tests/strands/agent/test_summarizing_conversation_manager.py new file mode 100644 index 00000000..9952203e --- /dev/null +++ b/tests/strands/agent/test_summarizing_conversation_manager.py @@ -0,0 +1,566 @@ +from typing import TYPE_CHECKING, cast +from unittest.mock import Mock, patch + +import pytest + +from strands.agent.conversation_manager.summarizing_conversation_manager import SummarizingConversationManager +from strands.types.content import Messages +from strands.types.exceptions import ContextWindowOverflowException + +if TYPE_CHECKING: + from strands.agent.agent import Agent + + +class MockAgent: + """Mock agent for testing summarization.""" + + def __init__(self, summary_response="This is a summary of the conversation."): + self.summary_response = summary_response + self.system_prompt = None + self.messages = [] + self.model = Mock() + self.call_tracker = Mock() + + def __call__(self, prompt): + """Mock agent call that returns a summary.""" + self.call_tracker(prompt) + result = Mock() + result.message = {"role": "assistant", "content": [{"text": self.summary_response}]} + return result + + +def create_mock_agent(summary_response="This is a summary of the conversation.") -> "Agent": + """Factory function that returns a properly typed MockAgent.""" + return cast("Agent", MockAgent(summary_response)) + + +@pytest.fixture +def mock_agent(): + """Fixture for mock agent.""" + return create_mock_agent() + + +@pytest.fixture +def summarizing_manager(): + """Fixture for summarizing conversation manager with default settings.""" + return SummarizingConversationManager( + summary_ratio=0.5, + preserve_recent_messages=2, + ) + + +def test_init_default_values(): + """Test initialization with default values.""" + manager = SummarizingConversationManager() + + assert manager.summarization_agent is None + assert manager.summary_ratio == 0.3 + assert manager.preserve_recent_messages == 10 + + +def test_init_clamps_summary_ratio(): + """Test that summary_ratio is clamped to valid range.""" + # Test lower bound + manager = SummarizingConversationManager(summary_ratio=0.05) + assert manager.summary_ratio == 0.1 + + # Test upper bound + manager = SummarizingConversationManager(summary_ratio=0.95) + assert manager.summary_ratio == 0.8 + + +def test_reduce_context_raises_when_no_agent(): + """Test that reduce_context raises exception when agent has no messages.""" + manager = SummarizingConversationManager() + + # Create a mock agent with no messages + mock_agent = Mock() + empty_messages: Messages = [] + mock_agent.messages = empty_messages + + with pytest.raises(ContextWindowOverflowException, match="insufficient messages for summarization"): + manager.reduce_context(mock_agent) + + +def test_reduce_context_with_summarization(summarizing_manager, mock_agent): + """Test reduce_context with summarization enabled.""" + test_messages: Messages = [ + {"role": "user", "content": [{"text": "Message 1"}]}, + {"role": "assistant", "content": [{"text": "Response 1"}]}, + {"role": "user", "content": [{"text": "Message 2"}]}, + {"role": "assistant", "content": [{"text": "Response 2"}]}, + {"role": "user", "content": [{"text": "Message 3"}]}, + {"role": "assistant", "content": [{"text": "Response 3"}]}, + ] + mock_agent.messages = test_messages + + summarizing_manager.reduce_context(mock_agent) + + # Should have: 1 summary message + 2 preserved recent messages + remaining from summarization + assert len(mock_agent.messages) == 4 + + # First message should be the summary + assert mock_agent.messages[0]["role"] == "assistant" + first_content = mock_agent.messages[0]["content"][0] + assert "text" in first_content and "This is a summary of the conversation." in first_content["text"] + + # Recent messages should be preserved + assert "Message 3" in str(mock_agent.messages[-2]["content"]) + assert "Response 3" in str(mock_agent.messages[-1]["content"]) + + +def test_reduce_context_too_few_messages_raises_exception(summarizing_manager, mock_agent): + """Test that reduce_context raises exception when there are too few messages to summarize effectively.""" + # Create a scenario where calculation results in 0 messages to summarize + manager = SummarizingConversationManager( + summary_ratio=0.1, # Very small ratio + preserve_recent_messages=5, # High preservation + ) + + insufficient_test_messages: Messages = [ + {"role": "user", "content": [{"text": "Message 1"}]}, + {"role": "assistant", "content": [{"text": "Response 1"}]}, + {"role": "user", "content": [{"text": "Message 2"}]}, + ] + mock_agent.messages = insufficient_test_messages # 5 messages, preserve_recent_messages=5, so nothing to summarize + + with pytest.raises(ContextWindowOverflowException, match="insufficient messages for summarization"): + manager.reduce_context(mock_agent) + + +def test_reduce_context_insufficient_messages_for_summarization(mock_agent): + """Test reduce_context when there aren't enough messages to summarize.""" + manager = SummarizingConversationManager( + summary_ratio=0.5, + preserve_recent_messages=3, + ) + + insufficient_messages: Messages = [ + {"role": "user", "content": [{"text": "Message 1"}]}, + {"role": "assistant", "content": [{"text": "Response 1"}]}, + {"role": "user", "content": [{"text": "Message 2"}]}, + ] + mock_agent.messages = insufficient_messages + + # This should raise an exception since there aren't enough messages to summarize + with pytest.raises(ContextWindowOverflowException, match="insufficient messages for summarization"): + manager.reduce_context(mock_agent) + + +def test_reduce_context_raises_on_summarization_failure(): + """Test that reduce_context raises exception when summarization fails.""" + # Create an agent that will fail + failing_agent = Mock() + failing_agent.side_effect = Exception("Agent failed") + failing_agent.system_prompt = None + failing_agent_messages: Messages = [ + {"role": "user", "content": [{"text": "Message 1"}]}, + {"role": "assistant", "content": [{"text": "Response 1"}]}, + {"role": "user", "content": [{"text": "Message 2"}]}, + {"role": "assistant", "content": [{"text": "Response 2"}]}, + ] + failing_agent.messages = failing_agent_messages + + manager = SummarizingConversationManager( + summary_ratio=0.5, + preserve_recent_messages=1, + ) + + with patch("strands.agent.conversation_manager.summarizing_conversation_manager.logger") as mock_logger: + with pytest.raises(Exception, match="Agent failed"): + manager.reduce_context(failing_agent) + + # Should log the error + mock_logger.error.assert_called_once() + + +def test_generate_summary(summarizing_manager, mock_agent): + """Test the _generate_summary method.""" + test_messages: Messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + {"role": "assistant", "content": [{"text": "Hi there"}]}, + ] + + summary = summarizing_manager._generate_summary(test_messages, mock_agent) + + summary_content = summary["content"][0] + assert "text" in summary_content and summary_content["text"] == "This is a summary of the conversation." + + +def test_generate_summary_with_tool_content(summarizing_manager, mock_agent): + """Test summary generation with tool use and results.""" + tool_messages: Messages = [ + {"role": "user", "content": [{"text": "Use a tool"}]}, + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "123", "name": "test_tool", "input": {}}}]}, + { + "role": "user", + "content": [ + {"toolResult": {"toolUseId": "123", "content": [{"text": "Tool output"}], "status": "success"}} + ], + }, + ] + + summary = summarizing_manager._generate_summary(tool_messages, mock_agent) + + summary_content = summary["content"][0] + assert "text" in summary_content and summary_content["text"] == "This is a summary of the conversation." + + +def test_generate_summary_raises_on_agent_failure(): + """Test that _generate_summary raises exception when agent fails.""" + failing_agent = Mock() + failing_agent.side_effect = Exception("Agent failed") + failing_agent.system_prompt = None + empty_failing_messages: Messages = [] + failing_agent.messages = empty_failing_messages + + manager = SummarizingConversationManager() + + messages: Messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + {"role": "assistant", "content": [{"text": "Hi there"}]}, + ] + + # Should raise the exception from the agent + with pytest.raises(Exception, match="Agent failed"): + manager._generate_summary(messages, failing_agent) + + +def test_adjust_split_point_for_tool_pairs(summarizing_manager): + """Test that the split point is adjusted to avoid breaking ToolUse/ToolResult pairs.""" + messages: Messages = [ + {"role": "user", "content": [{"text": "Message 1"}]}, + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "123", "name": "test_tool", "input": {}}}]}, + { + "role": "user", + "content": [ + { + "toolResult": { + "toolUseId": "123", + "content": [{"text": "Tool output"}], + "status": "success", + } + } + ], + }, + {"role": "assistant", "content": [{"text": "Response after tool"}]}, + ] + + # If we try to split at message 2 (the ToolResult), it should move forward to message 3 + adjusted_split = summarizing_manager._adjust_split_point_for_tool_pairs(messages, 2) + assert adjusted_split == 3 # Should move to after the ToolResult + + # If we try to split at message 3, it should be fine (no tool issues) + adjusted_split = summarizing_manager._adjust_split_point_for_tool_pairs(messages, 3) + assert adjusted_split == 3 + + # If we try to split at message 1 (toolUse with following toolResult), it should be valid + adjusted_split = summarizing_manager._adjust_split_point_for_tool_pairs(messages, 1) + assert adjusted_split == 1 # Should be valid because toolResult follows + + +def test_apply_management_no_op(summarizing_manager, mock_agent): + """Test apply_management does not modify messages (no-op behavior).""" + apply_test_messages: Messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + {"role": "assistant", "content": [{"text": "Hi"}]}, + {"role": "user", "content": [{"text": "More messages"}]}, + {"role": "assistant", "content": [{"text": "Even more"}]}, + ] + mock_agent.messages = apply_test_messages + original_messages = mock_agent.messages.copy() + + summarizing_manager.apply_management(mock_agent) + + # Should never modify messages - summarization only happens on context overflow + assert mock_agent.messages == original_messages + + +def test_init_with_custom_parameters(): + """Test initialization with custom parameters.""" + mock_agent = create_mock_agent() + + manager = SummarizingConversationManager( + summary_ratio=0.4, + preserve_recent_messages=5, + summarization_agent=mock_agent, + ) + assert manager.summary_ratio == 0.4 + assert manager.preserve_recent_messages == 5 + assert manager.summarization_agent == mock_agent + assert manager.summarization_system_prompt is None + + +def test_init_with_both_agent_and_prompt_raises_error(): + """Test that providing both agent and system prompt raises ValueError.""" + mock_agent = create_mock_agent() + custom_prompt = "Custom summarization prompt" + + with pytest.raises(ValueError, match="Cannot provide both summarization_agent and summarization_system_prompt"): + SummarizingConversationManager( + summarization_agent=mock_agent, + summarization_system_prompt=custom_prompt, + ) + + +def test_uses_summarization_agent_when_provided(): + """Test that summarization_agent is used when provided.""" + summary_agent = create_mock_agent("Custom summary from dedicated agent") + manager = SummarizingConversationManager(summarization_agent=summary_agent) + + messages: Messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + {"role": "assistant", "content": [{"text": "Hi there"}]}, + ] + + parent_agent = create_mock_agent("Parent agent summary") + summary = manager._generate_summary(messages, parent_agent) + + # Should use the dedicated summarization agent, not the parent agent + summary_content = summary["content"][0] + assert "text" in summary_content and summary_content["text"] == "Custom summary from dedicated agent" + + # Assert that the summarization agent was called + summary_agent.call_tracker.assert_called_once() + + +def test_uses_parent_agent_when_no_summarization_agent(): + """Test that parent agent is used when no summarization_agent is provided.""" + manager = SummarizingConversationManager() + + messages: Messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + {"role": "assistant", "content": [{"text": "Hi there"}]}, + ] + + parent_agent = create_mock_agent("Parent agent summary") + summary = manager._generate_summary(messages, parent_agent) + + # Should use the parent agent + summary_content = summary["content"][0] + assert "text" in summary_content and summary_content["text"] == "Parent agent summary" + + # Assert that the parent agent was called + parent_agent.call_tracker.assert_called_once() + + +def test_uses_custom_system_prompt(): + """Test that custom system prompt is used when provided.""" + custom_prompt = "Custom system prompt for summarization" + manager = SummarizingConversationManager(summarization_system_prompt=custom_prompt) + mock_agent = create_mock_agent() + + messages: Messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + {"role": "assistant", "content": [{"text": "Hi there"}]}, + ] + + # Capture the agent's system prompt changes + original_prompt = mock_agent.system_prompt + manager._generate_summary(messages, mock_agent) + + # The agent's system prompt should be restored after summarization + assert mock_agent.system_prompt == original_prompt + + +def test_agent_state_restoration(): + """Test that agent state is properly restored after summarization.""" + manager = SummarizingConversationManager() + mock_agent = create_mock_agent() + + # Set initial state + original_system_prompt = "Original system prompt" + original_messages: Messages = [{"role": "user", "content": [{"text": "Original message"}]}] + mock_agent.system_prompt = original_system_prompt + mock_agent.messages = original_messages.copy() + + messages: Messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + {"role": "assistant", "content": [{"text": "Hi there"}]}, + ] + + manager._generate_summary(messages, mock_agent) + + # State should be restored + assert mock_agent.system_prompt == original_system_prompt + assert mock_agent.messages == original_messages + + +def test_agent_state_restoration_on_exception(): + """Test that agent state is restored even when summarization fails.""" + manager = SummarizingConversationManager() + + # Create an agent that fails during summarization + mock_agent = Mock() + mock_agent.system_prompt = "Original prompt" + agent_messages: Messages = [{"role": "user", "content": [{"text": "Original"}]}] + mock_agent.messages = agent_messages + mock_agent.side_effect = Exception("Summarization failed") + + messages: Messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + {"role": "assistant", "content": [{"text": "Hi there"}]}, + ] + + # Should restore state even on exception + with pytest.raises(Exception, match="Summarization failed"): + manager._generate_summary(messages, mock_agent) + + # State should still be restored + assert mock_agent.system_prompt == "Original prompt" + + +def test_reduce_context_tool_pair_adjustment_works_with_forward_search(): + """Test that tool pair adjustment works correctly with the forward-search logic.""" + manager = SummarizingConversationManager( + summary_ratio=0.5, + preserve_recent_messages=1, + ) + + mock_agent = create_mock_agent() + # Create messages where the split point would be adjusted to 0 due to tool pairs + tool_pair_messages: Messages = [ + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "123", "name": "test_tool", "input": {}}}]}, + { + "role": "user", + "content": [ + {"toolResult": {"toolUseId": "123", "content": [{"text": "Tool output"}], "status": "success"}} + ], + }, + {"role": "user", "content": [{"text": "Latest message"}]}, + ] + mock_agent.messages = tool_pair_messages + + # With 3 messages, preserve_recent_messages=1, summary_ratio=0.5: + # messages_to_summarize_count = (3 - 1) * 0.5 = 1 + # But split point adjustment will move forward from the toolUse, potentially increasing count + manager.reduce_context(mock_agent) + # Should have summary + remaining messages + assert len(mock_agent.messages) == 2 + + # First message should be the summary + assert mock_agent.messages[0]["role"] == "assistant" + summary_content = mock_agent.messages[0]["content"][0] + assert "text" in summary_content and "This is a summary of the conversation." in summary_content["text"] + + # Last message should be the preserved recent message + assert mock_agent.messages[1]["role"] == "user" + assert mock_agent.messages[1]["content"][0]["text"] == "Latest message" + + +def test_adjust_split_point_exceeds_message_length(summarizing_manager): + """Test that split point exceeding message array length raises exception.""" + messages: Messages = [ + {"role": "user", "content": [{"text": "Message 1"}]}, + {"role": "assistant", "content": [{"text": "Response 1"}]}, + ] + + # Try to split at point 5 when there are only 2 messages + with pytest.raises(ContextWindowOverflowException, match="Split point exceeds message array length"): + summarizing_manager._adjust_split_point_for_tool_pairs(messages, 5) + + +def test_adjust_split_point_equals_message_length(summarizing_manager): + """Test that split point equal to message array length returns unchanged.""" + messages: Messages = [ + {"role": "user", "content": [{"text": "Message 1"}]}, + {"role": "assistant", "content": [{"text": "Response 1"}]}, + ] + + # Split point equals message length (2) - should return unchanged + result = summarizing_manager._adjust_split_point_for_tool_pairs(messages, 2) + assert result == 2 + + +def test_adjust_split_point_no_tool_result_at_split(summarizing_manager): + """Test split point that doesn't contain tool result, ensuring we reach return split_point.""" + messages: Messages = [ + {"role": "user", "content": [{"text": "Message 1"}]}, + {"role": "assistant", "content": [{"text": "Response 1"}]}, + {"role": "user", "content": [{"text": "Message 2"}]}, + ] + + # Split point message is not a tool result, so it should directly return split_point + result = summarizing_manager._adjust_split_point_for_tool_pairs(messages, 1) + assert result == 1 + + +def test_adjust_split_point_tool_result_without_tool_use(summarizing_manager): + """Test that having tool results without tool uses raises exception.""" + messages: Messages = [ + {"role": "user", "content": [{"text": "Message 1"}]}, + { + "role": "user", + "content": [ + {"toolResult": {"toolUseId": "123", "content": [{"text": "Tool output"}], "status": "success"}} + ], + }, + ] + + # Has tool result but no tool use - invalid state + with pytest.raises(ContextWindowOverflowException, match="Unable to trim conversation context!"): + summarizing_manager._adjust_split_point_for_tool_pairs(messages, 1) + + +def test_adjust_split_point_tool_result_moves_to_end(summarizing_manager): + """Test tool result at split point moves forward to valid position at end.""" + messages: Messages = [ + {"role": "user", "content": [{"text": "Message 1"}]}, + {"role": "assistant", "content": [{"text": "Response 1"}]}, + { + "role": "user", + "content": [ + {"toolResult": {"toolUseId": "123", "content": [{"text": "Tool output"}], "status": "success"}} + ], + }, + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "456", "name": "different_tool", "input": {}}}]}, + ] + + # Split at message 2 (toolResult) - will move forward to message 3 (toolUse at end is valid) + result = summarizing_manager._adjust_split_point_for_tool_pairs(messages, 2) + assert result == 3 + + +def test_adjust_split_point_tool_result_no_forward_position(summarizing_manager): + """Test tool result at split point where forward search finds no valid position.""" + messages: Messages = [ + {"role": "user", "content": [{"text": "Message 1"}]}, + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "123", "name": "test_tool", "input": {}}}]}, + {"role": "user", "content": [{"text": "Message between"}]}, + { + "role": "user", + "content": [ + {"toolResult": {"toolUseId": "123", "content": [{"text": "Tool output"}], "status": "success"}} + ], + }, + ] + + # Split at message 3 (toolResult) - will try to move forward but no valid position exists + with pytest.raises(ContextWindowOverflowException, match="Unable to trim conversation context!"): + summarizing_manager._adjust_split_point_for_tool_pairs(messages, 3) + + +def test_reduce_context_adjustment_returns_zero(): + """Test that tool pair adjustment can return zero, triggering the check at line 122.""" + manager = SummarizingConversationManager( + summary_ratio=0.5, + preserve_recent_messages=1, + ) + + # Mock the adjustment method to return 0 + def mock_adjust(messages, split_point): + return 0 # This should trigger the <= 0 check at line 122 + + manager._adjust_split_point_for_tool_pairs = mock_adjust + + mock_agent = Mock() + simple_messages: Messages = [ + {"role": "user", "content": [{"text": "Message 1"}]}, + {"role": "assistant", "content": [{"text": "Response 1"}]}, + {"role": "user", "content": [{"text": "Message 2"}]}, + ] + mock_agent.messages = simple_messages + + # The adjustment method will return 0, which should trigger line 122-123 + with pytest.raises(ContextWindowOverflowException, match="insufficient messages for summarization"): + manager.reduce_context(mock_agent) diff --git a/tests/strands/event_loop/test_error_handler.py b/tests/strands/event_loop/test_error_handler.py index 7249adef..fe1b3e9f 100644 --- a/tests/strands/event_loop/test_error_handler.py +++ b/tests/strands/event_loop/test_error_handler.py @@ -4,7 +4,7 @@ import pytest import strands -from strands.types.exceptions import ContextWindowOverflowException, ModelThrottledException +from strands.types.exceptions import ModelThrottledException @pytest.fixture @@ -95,126 +95,3 @@ def test_handle_throttling_error_does_not_exist(callback_handler, kwargs): assert tru_retry == exp_retry and tru_delay == exp_delay callback_handler.assert_called_with(force_stop=True, force_stop_reason=str(exception)) - - -@pytest.mark.parametrize("event_stream_error", ["Input is too long for requested model"], indirect=True) -def test_handle_input_too_long_error( - sdk_event_loop, - event_stream_error, - model, - system_prompt, - tool_config, - callback_handler, - tool_handler, - kwargs, -): - sdk_event_loop.return_value = "success" - - messages = [ - { - "role": "user", - "content": [ - {"toolResult": {"toolUseId": "t1", "status": "success", "content": [{"text": "needs truncation"}]}} - ], - } - ] - - tru_result = strands.event_loop.error_handler.handle_input_too_long_error( - event_stream_error, - messages, - model, - system_prompt, - tool_config, - callback_handler, - tool_handler, - kwargs, - ) - exp_result = "success" - - tru_messages = messages - exp_messages = [ - { - "role": "user", - "content": [ - { - "toolResult": { - "toolUseId": "t1", - "status": "error", - "content": [{"text": "The tool result was too large!"}], - }, - }, - ], - }, - ] - - assert tru_result == exp_result and tru_messages == exp_messages - - sdk_event_loop.assert_called_once_with( - model=model, - system_prompt=system_prompt, - messages=messages, - tool_config=tool_config, - callback_handler=callback_handler, - tool_handler=tool_handler, - request_state="value", - ) - - callback_handler.assert_not_called() - - -@pytest.mark.parametrize("event_stream_error", ["Other error"], indirect=True) -def test_handle_input_too_long_error_does_not_exist( - sdk_event_loop, - event_stream_error, - model, - system_prompt, - tool_config, - callback_handler, - tool_handler, - kwargs, -): - messages = [] - - with pytest.raises(ContextWindowOverflowException): - strands.event_loop.error_handler.handle_input_too_long_error( - event_stream_error, - messages, - model, - system_prompt, - tool_config, - callback_handler, - tool_handler, - kwargs, - ) - - sdk_event_loop.assert_not_called() - callback_handler.assert_called_with(force_stop=True, force_stop_reason=str(event_stream_error)) - - -@pytest.mark.parametrize("event_stream_error", ["Input is too long for requested model"], indirect=True) -def test_handle_input_too_long_error_no_tool_result( - sdk_event_loop, - event_stream_error, - model, - system_prompt, - tool_config, - callback_handler, - tool_handler, - kwargs, -): - messages = [] - - with pytest.raises(ContextWindowOverflowException): - strands.event_loop.error_handler.handle_input_too_long_error( - event_stream_error, - messages, - model, - system_prompt, - tool_config, - callback_handler, - tool_handler, - kwargs, - ) - - sdk_event_loop.assert_not_called() - callback_handler.assert_called_with(force_stop=True, force_stop_reason=str(event_stream_error)) diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py index 569462f1..11f14503 100644 --- a/tests/strands/event_loop/test_event_loop.py +++ b/tests/strands/event_loop/test_event_loop.py @@ -11,6 +11,13 @@ from strands.types.exceptions import ContextWindowOverflowException, EventLoopException, ModelThrottledException +@pytest.fixture +def mock_time(): + """Fixture to mock the time module in the error_handler.""" + with unittest.mock.patch.object(strands.event_loop.error_handler, "time") as mock: + yield mock + + @pytest.fixture def model(): return unittest.mock.Mock() @@ -104,27 +111,6 @@ def mock_tracer(): return tracer -@pytest.mark.parametrize( - ("kwargs", "exp_state"), - [ - ( - {"request_state": {"key1": "value1"}}, - {"key1": "value1"}, - ), - ( - {}, - {}, - ), - ], -) -def test_initialize_state(kwargs, exp_state): - kwargs = strands.event_loop.event_loop.initialize_state(**kwargs) - - tru_state = kwargs["request_state"] - - assert tru_state == exp_state - - def test_event_loop_cycle_text_response( model, model_id, @@ -157,7 +143,8 @@ def test_event_loop_cycle_text_response( assert tru_stop_reason == exp_stop_reason and tru_message == exp_message and tru_request_state == exp_request_state -def test_event_loop_cycle_text_response_input_too_long( +def test_event_loop_cycle_text_response_throttling( + mock_time, model, model_id, system_prompt, @@ -168,26 +155,12 @@ def test_event_loop_cycle_text_response_input_too_long( tool_execution_handler, ): model.converse.side_effect = [ - ContextWindowOverflowException(RuntimeError("Input is too long for requested model")), + ModelThrottledException("ThrottlingException | ConverseStream"), [ {"contentBlockDelta": {"delta": {"text": "test text"}}}, {"contentBlockStop": {}}, ], ] - messages.append( - { - "role": "user", - "content": [ - { - "toolResult": { - "toolUseId": "t1", - "status": "success", - "content": [{"text": "2025-04-01T00:00:00"}], - }, - }, - ], - } - ) tru_stop_reason, tru_message, _, tru_request_state = strands.event_loop.event_loop.event_loop_cycle( model=model, @@ -204,10 +177,12 @@ def test_event_loop_cycle_text_response_input_too_long( exp_request_state = {} assert tru_stop_reason == exp_stop_reason and tru_message == exp_message and tru_request_state == exp_request_state + # Verify that sleep was called once with the initial delay + mock_time.sleep.assert_called_once() -@unittest.mock.patch.object(strands.event_loop.error_handler, "time") -def test_event_loop_cycle_text_response_throttling( +def test_event_loop_cycle_exponential_backoff( + mock_time, model, model_id, system_prompt, @@ -217,7 +192,11 @@ def test_event_loop_cycle_text_response_throttling( tool_handler, tool_execution_handler, ): + """Test that the exponential backoff works correctly with multiple retries.""" + # Set up the model to raise throttling exceptions multiple times before succeeding model.converse.side_effect = [ + ModelThrottledException("ThrottlingException | ConverseStream"), + ModelThrottledException("ThrottlingException | ConverseStream"), ModelThrottledException("ThrottlingException | ConverseStream"), [ {"contentBlockDelta": {"delta": {"text": "test text"}}}, @@ -235,11 +214,16 @@ def test_event_loop_cycle_text_response_throttling( tool_handler=tool_handler, tool_execution_handler=tool_execution_handler, ) - exp_stop_reason = "end_turn" - exp_message = {"role": "assistant", "content": [{"text": "test text"}]} - exp_request_state = {} - assert tru_stop_reason == exp_stop_reason and tru_message == exp_message and tru_request_state == exp_request_state + # Verify the final response + assert tru_stop_reason == "end_turn" + assert tru_message == {"role": "assistant", "content": [{"text": "test text"}]} + assert tru_request_state == {} + + # Verify that sleep was called with increasing delays + # Initial delay is 4, then 8, then 16 + assert mock_time.sleep.call_count == 3 + assert mock_time.sleep.call_args_list == [call(4), call(8), call(16)] def test_event_loop_cycle_text_response_error( @@ -460,19 +444,6 @@ def test_event_loop_cycle_stop( assert tru_stop_reason == exp_stop_reason and tru_message == exp_message and tru_request_state == exp_request_state -def test_prepare_next_cycle(): - kwargs = {"event_loop_cycle_id": "c1"} - event_loop_metrics = strands.telemetry.metrics.EventLoopMetrics() - tru_result = strands.event_loop.event_loop.prepare_next_cycle(kwargs, event_loop_metrics) - exp_result = { - "event_loop_cycle_id": "c1", - "event_loop_parent_cycle_id": "c1", - "event_loop_metrics": event_loop_metrics, - } - - assert tru_result == exp_result - - def test_cycle_exception( model, system_prompt, @@ -728,3 +699,176 @@ def test_event_loop_cycle_with_parent_span( mock_tracer.start_event_loop_cycle_span.assert_called_once_with( event_loop_kwargs=unittest.mock.ANY, parent_span=parent_span, messages=messages ) + + +def test_event_loop_cycle_callback( + model, + model_id, + system_prompt, + messages, + tool_config, + callback_handler, + tool_handler, + tool_execution_handler, +): + model.converse.return_value = [ + {"contentBlockStart": {"start": {"toolUse": {"toolUseId": "123", "name": "test"}}}}, + {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"value"}'}}}}, + {"contentBlockStop": {}}, + {"contentBlockStart": {"start": {}}}, + {"contentBlockDelta": {"delta": {"reasoningContent": {"text": "value"}}}}, + {"contentBlockDelta": {"delta": {"reasoningContent": {"signature": "value"}}}}, + {"contentBlockStop": {}}, + {"contentBlockStart": {"start": {}}}, + {"contentBlockDelta": {"delta": {"text": "value"}}}, + {"contentBlockStop": {}}, + ] + + strands.event_loop.event_loop.event_loop_cycle( + model=model, + model_id=model_id, + system_prompt=system_prompt, + messages=messages, + tool_config=tool_config, + callback_handler=callback_handler, + tool_handler=tool_handler, + tool_execution_handler=tool_execution_handler, + ) + + callback_handler.assert_has_calls( + [ + call(start=True), + call(start_event_loop=True), + call(event={"contentBlockStart": {"start": {"toolUse": {"toolUseId": "123", "name": "test"}}}}), + call(event={"contentBlockDelta": {"delta": {"toolUse": {"input": '{"value"}'}}}}), + call( + delta={"toolUse": {"input": '{"value"}'}}, + current_tool_use={"toolUseId": "123", "name": "test", "input": {}}, + model_id="m1", + event_loop_cycle_id=unittest.mock.ANY, + request_state={}, + event_loop_cycle_trace=unittest.mock.ANY, + event_loop_cycle_span=None, + ), + call(event={"contentBlockStop": {}}), + call(event={"contentBlockStart": {"start": {}}}), + call(event={"contentBlockDelta": {"delta": {"reasoningContent": {"text": "value"}}}}), + call( + reasoningText="value", + delta={"reasoningContent": {"text": "value"}}, + reasoning=True, + model_id="m1", + event_loop_cycle_id=unittest.mock.ANY, + request_state={}, + event_loop_cycle_trace=unittest.mock.ANY, + event_loop_cycle_span=None, + ), + call(event={"contentBlockDelta": {"delta": {"reasoningContent": {"signature": "value"}}}}), + call( + reasoning_signature="value", + delta={"reasoningContent": {"signature": "value"}}, + reasoning=True, + model_id="m1", + event_loop_cycle_id=unittest.mock.ANY, + request_state={}, + event_loop_cycle_trace=unittest.mock.ANY, + event_loop_cycle_span=None, + ), + call(event={"contentBlockStop": {}}), + call(event={"contentBlockStart": {"start": {}}}), + call(event={"contentBlockDelta": {"delta": {"text": "value"}}}), + call( + data="value", + delta={"text": "value"}, + model_id="m1", + event_loop_cycle_id=unittest.mock.ANY, + request_state={}, + event_loop_cycle_trace=unittest.mock.ANY, + event_loop_cycle_span=None, + ), + call(event={"contentBlockStop": {}}), + call( + message={ + "role": "assistant", + "content": [ + {"toolUse": {"toolUseId": "123", "name": "test", "input": {}}}, + {"reasoningContent": {"reasoningText": {"text": "value", "signature": "value"}}}, + {"text": "value"}, + ], + }, + ), + ], + ) + + +def test_request_state_initialization(): + # Call without providing request_state + tru_stop_reason, tru_message, _, tru_request_state = strands.event_loop.event_loop.event_loop_cycle( + model=MagicMock(), + model_id=MagicMock(), + system_prompt=MagicMock(), + messages=MagicMock(), + tool_config=MagicMock(), + callback_handler=MagicMock(), + tool_handler=MagicMock(), + tool_execution_handler=MagicMock(), + ) + + # Verify request_state was initialized to empty dict + assert tru_request_state == {} + + # Call with pre-existing request_state + initial_request_state = {"key": "value"} + tru_stop_reason, tru_message, _, tru_request_state = strands.event_loop.event_loop.event_loop_cycle( + model=MagicMock(), + model_id=MagicMock(), + system_prompt=MagicMock(), + messages=MagicMock(), + tool_config=MagicMock(), + callback_handler=MagicMock(), + tool_handler=MagicMock(), + request_state=initial_request_state, + ) + + # Verify existing request_state was preserved + assert tru_request_state == initial_request_state + + +def test_prepare_next_cycle_in_tool_execution(model, tool_stream): + """Test that cycle ID and metrics are properly updated during tool execution.""" + model.converse.side_effect = [ + tool_stream, + [ + {"contentBlockStop": {}}, + ], + ] + + # Create a mock for recurse_event_loop to capture the kwargs passed to it + with unittest.mock.patch.object(strands.event_loop.event_loop, "recurse_event_loop") as mock_recurse: + # Set up mock to return a valid response + mock_recurse.return_value = ( + "end_turn", + {"role": "assistant", "content": [{"text": "test text"}]}, + strands.telemetry.metrics.EventLoopMetrics(), + {}, + ) + + # Call event_loop_cycle which should execute a tool and then call recurse_event_loop + strands.event_loop.event_loop.event_loop_cycle( + model=model, + model_id=MagicMock(), + system_prompt=MagicMock(), + messages=MagicMock(), + tool_config=MagicMock(), + callback_handler=MagicMock(), + tool_handler=MagicMock(), + tool_execution_handler=MagicMock(), + ) + + assert mock_recurse.called + + # Verify required properties are present + recursive_kwargs = mock_recurse.call_args[1] + assert "event_loop_metrics" in recursive_kwargs + assert "event_loop_parent_cycle_id" in recursive_kwargs + assert recursive_kwargs["event_loop_parent_cycle_id"] == recursive_kwargs["event_loop_cycle_id"] diff --git a/tests/strands/event_loop/test_message_processor.py b/tests/strands/event_loop/test_message_processor.py index 395c71a1..fcf531df 100644 --- a/tests/strands/event_loop/test_message_processor.py +++ b/tests/strands/event_loop/test_message_processor.py @@ -45,84 +45,3 @@ def test_clean_orphaned_empty_tool_uses(messages, expected, expected_messages): result = message_processor.clean_orphaned_empty_tool_uses(test_messages) assert result == expected assert test_messages == expected_messages - - -@pytest.mark.parametrize( - "messages,expected_idx", - [ - ( - [ - {"role": "user", "content": [{"text": "hi"}]}, - {"role": "user", "content": [{"toolResult": {"toolUseId": "1"}}]}, - {"role": "assistant", "content": [{"text": "ok"}]}, - ], - 1, - ), - ( - [ - {"role": "user", "content": [{"text": "hi"}]}, - {"role": "assistant", "content": [{"text": "ok"}]}, - ], - None, - ), - ( - [], - None, - ), - ], -) -def test_find_last_message_with_tool_results(messages, expected_idx): - idx = message_processor.find_last_message_with_tool_results(messages) - assert idx == expected_idx - - -@pytest.mark.parametrize( - "messages,msg_idx,expected_changed,expected_content", - [ - ( - [ - { - "role": "user", - "content": [{"toolResult": {"toolUseId": "1", "status": "ok", "content": [{"text": "big"}]}}], - } - ], - 0, - True, - [ - { - "toolResult": { - "toolUseId": "1", - "status": "error", - "content": [{"text": "The tool result was too large!"}], - } - } - ], - ), - ( - [{"role": "user", "content": [{"text": "no tool result"}]}], - 0, - False, - [{"text": "no tool result"}], - ), - ( - [], - 0, - False, - [], - ), - ( - [{"role": "user", "content": [{"toolResult": {"toolUseId": "1"}}]}], - 2, - False, - [{"toolResult": {"toolUseId": "1"}}], - ), - ], -) -def test_truncate_tool_results(messages, msg_idx, expected_changed, expected_content): - test_messages = copy.deepcopy(messages) - changed = message_processor.truncate_tool_results(test_messages, msg_idx) - assert changed == expected_changed - if 0 <= msg_idx < len(test_messages): - assert test_messages[msg_idx]["content"] == expected_content - else: - assert test_messages == messages diff --git a/tests/strands/event_loop/test_streaming.py b/tests/strands/event_loop/test_streaming.py index c24e7e48..e91f4986 100644 --- a/tests/strands/event_loop/test_streaming.py +++ b/tests/strands/event_loop/test_streaming.py @@ -3,6 +3,7 @@ import pytest import strands +import strands.event_loop from strands.types.streaming import ( ContentBlockDeltaEvent, ContentBlockStartEvent, @@ -17,13 +18,6 @@ def moto_autouse(moto_env, moto_mock_aws): _ = moto_mock_aws -@pytest.fixture -def agent(): - mock = unittest.mock.Mock() - - return mock - - @pytest.mark.parametrize( ("messages", "exp_result"), [ @@ -81,7 +75,7 @@ def test_handle_content_block_start(chunk: ContentBlockStartEvent, exp_tool_use) @pytest.mark.parametrize( - ("event", "state", "exp_updated_state", "exp_handler_args"), + ("event", "state", "exp_updated_state", "callback_args"), [ # Tool Use - Existing input ( @@ -148,21 +142,13 @@ def test_handle_content_block_start(chunk: ContentBlockStartEvent, exp_tool_use) ), ], ) -def test_handle_content_block_delta(event: ContentBlockDeltaEvent, state, exp_updated_state, exp_handler_args): - if exp_handler_args: - exp_handler_args.update({"delta": event["delta"], "extra_arg": 1}) +def test_handle_content_block_delta(event: ContentBlockDeltaEvent, state, exp_updated_state, callback_args): + exp_callback_event = {"callback": {**callback_args, "delta": event["delta"]}} if callback_args else {} - tru_handler_args = {} - - def callback_handler(**kwargs): - tru_handler_args.update(kwargs) - - tru_updated_state = strands.event_loop.streaming.handle_content_block_delta( - event, state, callback_handler, extra_arg=1 - ) + tru_updated_state, tru_callback_event = strands.event_loop.streaming.handle_content_block_delta(event, state) assert tru_updated_state == exp_updated_state - assert tru_handler_args == exp_handler_args + assert tru_callback_event == exp_callback_event @pytest.mark.parametrize( @@ -275,8 +261,9 @@ def test_extract_usage_metrics(): @pytest.mark.parametrize( - ("response", "exp_stop_reason", "exp_message", "exp_usage", "exp_metrics", "exp_request_state", "exp_messages"), + ("response", "exp_events"), [ + # Standard Message ( [ {"messageStart": {"role": "assistant"}}, @@ -297,28 +284,127 @@ def test_extract_usage_metrics(): } }, ], - "tool_use", - { - "role": "assistant", - "content": [{"toolUse": {"toolUseId": "123", "name": "test", "input": {"key": "value"}}}], - }, - {"inputTokens": 1, "outputTokens": 1, "totalTokens": 1}, - {"latencyMs": 1}, - {"calls": 1}, - [{"role": "user", "content": [{"text": "Some input!"}]}], + [ + { + "callback": { + "event": { + "messageStart": { + "role": "assistant", + }, + }, + }, + }, + { + "callback": { + "event": { + "contentBlockStart": { + "start": { + "toolUse": { + "name": "test", + "toolUseId": "123", + }, + }, + }, + }, + }, + }, + { + "callback": { + "event": { + "contentBlockDelta": { + "delta": { + "toolUse": { + "input": '{"key": "value"}', + }, + }, + }, + }, + }, + }, + { + "callback": { + "current_tool_use": { + "input": { + "key": "value", + }, + "name": "test", + "toolUseId": "123", + }, + "delta": { + "toolUse": { + "input": '{"key": "value"}', + }, + }, + }, + }, + { + "callback": { + "event": { + "contentBlockStop": {}, + }, + }, + }, + { + "callback": { + "event": { + "messageStop": { + "stopReason": "tool_use", + }, + }, + }, + }, + { + "callback": { + "event": { + "metadata": { + "metrics": { + "latencyMs": 1, + }, + "usage": { + "inputTokens": 1, + "outputTokens": 1, + "totalTokens": 1, + }, + }, + }, + }, + }, + { + "stop": ( + "tool_use", + { + "role": "assistant", + "content": [{"toolUse": {"toolUseId": "123", "name": "test", "input": {"key": "value"}}}], + }, + {"inputTokens": 1, "outputTokens": 1, "totalTokens": 1}, + {"latencyMs": 1}, + ) + }, + ], ), + # Empty Message ( [{}], - "end_turn", - { - "role": "assistant", - "content": [], - }, - {"inputTokens": 0, "outputTokens": 0, "totalTokens": 0}, - {"latencyMs": 0}, - {}, - [{"role": "user", "content": [{"text": "Some input!"}]}], + [ + { + "callback": { + "event": {}, + }, + }, + { + "stop": ( + "end_turn", + { + "role": "assistant", + "content": [], + }, + {"inputTokens": 0, "outputTokens": 0, "totalTokens": 0}, + {"latencyMs": 0}, + ), + }, + ], ), + # Redacted Message ( [ {"messageStart": {"role": "assistant"}}, @@ -345,77 +431,161 @@ def test_extract_usage_metrics(): } }, ], - "guardrail_intervened", - { - "role": "assistant", - "content": [{"text": "REDACTED."}], - }, - {"inputTokens": 1, "outputTokens": 1, "totalTokens": 1}, - {"latencyMs": 1}, - {"calls": 1}, - [{"role": "user", "content": [{"text": "REDACTED"}]}], + [ + { + "callback": { + "event": { + "messageStart": { + "role": "assistant", + }, + }, + }, + }, + { + "callback": { + "event": { + "contentBlockStart": { + "start": {}, + }, + }, + }, + }, + { + "callback": { + "event": { + "contentBlockDelta": { + "delta": { + "text": "Hello!", + }, + }, + }, + }, + }, + { + "callback": { + "data": "Hello!", + "delta": { + "text": "Hello!", + }, + }, + }, + { + "callback": { + "event": { + "contentBlockStop": {}, + }, + }, + }, + { + "callback": { + "event": { + "messageStop": { + "stopReason": "guardrail_intervened", + }, + }, + }, + }, + { + "callback": { + "event": { + "redactContent": { + "redactAssistantContentMessage": "REDACTED.", + "redactUserContentMessage": "REDACTED", + }, + }, + }, + }, + { + "callback": { + "event": { + "metadata": { + "metrics": { + "latencyMs": 1, + }, + "usage": { + "inputTokens": 1, + "outputTokens": 1, + "totalTokens": 1, + }, + }, + }, + }, + }, + { + "stop": ( + "guardrail_intervened", + { + "role": "assistant", + "content": [{"text": "REDACTED."}], + }, + {"inputTokens": 1, "outputTokens": 1, "totalTokens": 1}, + {"latencyMs": 1}, + ), + }, + ], ), ], ) -def test_process_stream( - response, exp_stop_reason, exp_message, exp_usage, exp_metrics, exp_request_state, exp_messages -): - def callback_handler(**kwargs): - if "request_state" in kwargs: - kwargs["request_state"].setdefault("calls", 0) - kwargs["request_state"]["calls"] += 1 - - tru_messages = [{"role": "user", "content": [{"text": "Some input!"}]}] - - tru_stop_reason, tru_message, tru_usage, tru_metrics, tru_request_state = ( - strands.event_loop.streaming.process_stream(response, callback_handler, tru_messages) - ) - - assert tru_stop_reason == exp_stop_reason - assert tru_message == exp_message - assert tru_usage == exp_usage - assert tru_metrics == exp_metrics - assert tru_request_state == exp_request_state - assert tru_messages == exp_messages +def test_process_stream(response, exp_events): + messages = [{"role": "user", "content": [{"text": "Some input!"}]}] + stream = strands.event_loop.streaming.process_stream(response, messages) + tru_events = list(stream) + assert tru_events == exp_events -def test_stream_messages(agent): - def callback_handler(**kwargs): - if "request_state" in kwargs: - kwargs["request_state"].setdefault("calls", 0) - kwargs["request_state"]["calls"] += 1 +def test_stream_messages(): mock_model = unittest.mock.MagicMock() mock_model.converse.return_value = [ {"contentBlockDelta": {"delta": {"text": "test"}}}, {"contentBlockStop": {}}, ] - tru_stop_reason, tru_message, tru_usage, tru_metrics, tru_request_state = ( - strands.event_loop.streaming.stream_messages( - mock_model, - model_id="test_model", - system_prompt="test prompt", - messages=[{"role": "assistant", "content": [{"text": "a"}, {"text": " \n"}]}], - tool_config=None, - callback_handler=callback_handler, - agent=agent, - ) + stream = strands.event_loop.streaming.stream_messages( + mock_model, + system_prompt="test prompt", + messages=[{"role": "assistant", "content": [{"text": "a"}, {"text": " \n"}]}], + tool_config=None, ) - exp_stop_reason = "end_turn" - exp_message = {"role": "assistant", "content": [{"text": "test"}]} - exp_usage = {"inputTokens": 0, "outputTokens": 0, "totalTokens": 0} - exp_metrics = {"latencyMs": 0} - exp_request_state = {"calls": 1} - - assert ( - tru_stop_reason == exp_stop_reason - and tru_message == exp_message - and tru_usage == exp_usage - and tru_metrics == exp_metrics - and tru_request_state == exp_request_state - ) + tru_events = list(stream) + exp_events = [ + { + "callback": { + "event": { + "contentBlockDelta": { + "delta": { + "text": "test", + }, + }, + }, + }, + }, + { + "callback": { + "data": "test", + "delta": { + "text": "test", + }, + }, + }, + { + "callback": { + "event": { + "contentBlockStop": {}, + }, + }, + }, + { + "stop": ( + "end_turn", + {"role": "assistant", "content": [{"text": "test"}]}, + {"inputTokens": 0, "outputTokens": 0, "totalTokens": 0}, + {"latencyMs": 0}, + ) + }, + ] + assert tru_events == exp_events mock_model.converse.assert_called_with( [{"role": "assistant", "content": [{"text": "a"}, {"text": "[blank text]"}]}], diff --git a/tests/strands/models/test_anthropic.py b/tests/strands/models/test_anthropic.py index 9421650e..a0cfc4d4 100644 --- a/tests/strands/models/test_anthropic.py +++ b/tests/strands/models/test_anthropic.py @@ -615,10 +615,24 @@ def test_format_chunk_unknown(model): def test_stream(anthropic_client, model): - mock_event_1 = unittest.mock.Mock(type="message_start", dict=lambda: {"type": "message_start"}) - mock_event_2 = unittest.mock.Mock(type="unknown") + mock_event_1 = unittest.mock.Mock( + type="message_start", + dict=lambda: {"type": "message_start"}, + model_dump=lambda: {"type": "message_start"}, + ) + mock_event_2 = unittest.mock.Mock( + type="unknown", + dict=lambda: {"type": "unknown"}, + model_dump=lambda: {"type": "unknown"}, + ) mock_event_3 = unittest.mock.Mock( - type="metadata", message=unittest.mock.Mock(usage=unittest.mock.Mock(dict=lambda: {"input_tokens": 1})) + type="metadata", + message=unittest.mock.Mock( + usage=unittest.mock.Mock( + dict=lambda: {"input_tokens": 1, "output_tokens": 2}, + model_dump=lambda: {"input_tokens": 1, "output_tokens": 2}, + ) + ), ) mock_stream = unittest.mock.MagicMock() @@ -631,7 +645,10 @@ def test_stream(anthropic_client, model): tru_events = list(response) exp_events = [ {"type": "message_start"}, - {"type": "metadata", "usage": {"input_tokens": 1}}, + { + "type": "metadata", + "usage": {"input_tokens": 1, "output_tokens": 2}, + }, ] assert tru_events == exp_events diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index b326eee7..137b57c8 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -91,6 +91,23 @@ def test__init__default_model_id(bedrock_client): assert tru_model_id == exp_model_id +def test__init__with_default_region(bedrock_client): + """Test that BedrockModel uses the provided region.""" + _ = bedrock_client + default_region = "us-west-2" + + with unittest.mock.patch("strands.models.bedrock.boto3.Session") as mock_session_cls: + with unittest.mock.patch("strands.models.bedrock.logger.warning") as mock_warning: + _ = BedrockModel() + mock_session_cls.assert_called_once_with(region_name=default_region) + # Assert that warning logs are emitted + mock_warning.assert_any_call("defaulted to us-west-2 because no region was specified") + mock_warning.assert_any_call( + "issue=<%s> | this behavior will change in an upcoming release", + "https://github.com/strands-agents/sdk-python/issues/238", + ) + + def test__init__with_custom_region(bedrock_client): """Test that BedrockModel uses the provided region.""" _ = bedrock_client diff --git a/tests/strands/models/test_openai.py b/tests/strands/models/test_openai.py index 89aa591f..4c1f8528 100644 --- a/tests/strands/models/test_openai.py +++ b/tests/strands/models/test_openai.py @@ -132,3 +132,44 @@ def test_stream_empty(openai_client, model): assert tru_events == exp_events openai_client.chat.completions.create.assert_called_once_with(**request) + + +def test_stream_with_empty_choices(openai_client, model): + mock_delta = unittest.mock.Mock(content="content", tool_calls=None) + mock_usage = unittest.mock.Mock(prompt_tokens=10, completion_tokens=20, total_tokens=30) + + # Event with no choices attribute + mock_event_1 = unittest.mock.Mock(spec=[]) + + # Event with empty choices list + mock_event_2 = unittest.mock.Mock(choices=[]) + + # Valid event with content + mock_event_3 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta)]) + + # Event with finish reason + mock_event_4 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="stop", delta=mock_delta)]) + + # Final event with usage info + mock_event_5 = unittest.mock.Mock(usage=mock_usage) + + openai_client.chat.completions.create.return_value = iter( + [mock_event_1, mock_event_2, mock_event_3, mock_event_4, mock_event_5] + ) + + request = {"model": "m1", "messages": [{"role": "user", "content": ["test"]}]} + response = model.stream(request) + + tru_events = list(response) + exp_events = [ + {"chunk_type": "message_start"}, + {"chunk_type": "content_start", "data_type": "text"}, + {"chunk_type": "content_delta", "data_type": "text", "data": "content"}, + {"chunk_type": "content_delta", "data_type": "text", "data": "content"}, + {"chunk_type": "content_stop", "data_type": "text"}, + {"chunk_type": "message_stop", "data": "stop"}, + {"chunk_type": "metadata", "data": mock_usage}, + ] + + assert tru_events == exp_events + openai_client.chat.completions.create.assert_called_once_with(**request) diff --git a/tests/strands/telemetry/test_metrics.py b/tests/strands/telemetry/test_metrics.py index 4e84f0fd..215e1efd 100644 --- a/tests/strands/telemetry/test_metrics.py +++ b/tests/strands/telemetry/test_metrics.py @@ -1,9 +1,13 @@ import dataclasses import unittest +from unittest import mock import pytest +from opentelemetry.metrics._internal import _ProxyMeter +from opentelemetry.sdk.metrics import MeterProvider import strands +from strands.telemetry import MetricsClient from strands.types.streaming import Metrics, Usage @@ -117,6 +121,37 @@ def test_trace_end(mock_time, end_time, trace): assert tru_end_time == exp_end_time +@pytest.fixture +def mock_get_meter_provider(): + with mock.patch("strands.telemetry.metrics.metrics_api.get_meter_provider") as mock_get_meter_provider: + MetricsClient._instance = None + meter_provider_mock = mock.MagicMock(spec=MeterProvider) + + mock_meter = mock.MagicMock() + mock_create_counter = mock.MagicMock() + mock_meter.create_counter.return_value = mock_create_counter + + mock_create_histogram = mock.MagicMock() + mock_meter.create_histogram.return_value = mock_create_histogram + meter_provider_mock.get_meter.return_value = mock_meter + + mock_get_meter_provider.return_value = meter_provider_mock + + yield mock_get_meter_provider + + +@pytest.fixture +def mock_sdk_meter_provider(): + with mock.patch("strands.telemetry.metrics.metrics_sdk.MeterProvider") as mock_meter_provider: + yield mock_meter_provider + + +@pytest.fixture +def mock_resource(): + with mock.patch("opentelemetry.sdk.resources.Resource") as mock_resource: + yield mock_resource + + def test_trace_add_child(child_trace, trace): trace.add_child(child_trace) @@ -162,11 +197,14 @@ def test_trace_to_dict(trace): @pytest.mark.parametrize("success", [True, False]) -def test_tool_metrics_add_call(success, tool, tool_metrics): +def test_tool_metrics_add_call(success, tool, tool_metrics, mock_get_meter_provider): tool = dict(tool, **{"name": "updated"}) duration = 1 + metrics_client = MetricsClient() - tool_metrics.add_call(tool, duration, success) + attributes = {"foo": "bar"} + + tool_metrics.add_call(tool, duration, success, metrics_client, attributes=attributes) tru_attrs = dataclasses.asdict(tool_metrics) exp_attrs = { @@ -177,12 +215,17 @@ def test_tool_metrics_add_call(success, tool, tool_metrics): "total_time": duration, } + mock_get_meter_provider.return_value.get_meter.assert_called() + metrics_client.tool_call_count.add.assert_called_with(1, attributes=attributes) + metrics_client.tool_duration.record.assert_called_with(duration, attributes=attributes) + if success: + metrics_client.tool_success_count.add.assert_called_with(1, attributes=attributes) assert tru_attrs == exp_attrs @unittest.mock.patch.object(strands.telemetry.metrics.time, "time") @unittest.mock.patch.object(strands.telemetry.metrics.uuid, "uuid4") -def test_event_loop_metrics_start_cycle(mock_uuid4, mock_time, event_loop_metrics): +def test_event_loop_metrics_start_cycle(mock_uuid4, mock_time, event_loop_metrics, mock_get_meter_provider): mock_time.return_value = 1 mock_uuid4.return_value = "i1" @@ -192,6 +235,8 @@ def test_event_loop_metrics_start_cycle(mock_uuid4, mock_time, event_loop_metric tru_attrs = {"cycle_count": event_loop_metrics.cycle_count, "traces": event_loop_metrics.traces} exp_attrs = {"cycle_count": 1, "traces": [tru_cycle_trace]} + mock_get_meter_provider.return_value.get_meter.assert_called() + event_loop_metrics._metrics_client.event_loop_cycle_count.add.assert_called() assert ( tru_start_time == exp_start_time and tru_cycle_trace.to_dict() == exp_cycle_trace.to_dict() @@ -200,10 +245,11 @@ def test_event_loop_metrics_start_cycle(mock_uuid4, mock_time, event_loop_metric @unittest.mock.patch.object(strands.telemetry.metrics.time, "time") -def test_event_loop_metrics_end_cycle(mock_time, trace, event_loop_metrics): +def test_event_loop_metrics_end_cycle(mock_time, trace, event_loop_metrics, mock_get_meter_provider): mock_time.return_value = 1 - event_loop_metrics.end_cycle(start_time=0, cycle_trace=trace) + attributes = {"foo": "bar"} + event_loop_metrics.end_cycle(start_time=0, cycle_trace=trace, attributes=attributes) tru_cycle_durations = event_loop_metrics.cycle_durations exp_cycle_durations = [1] @@ -215,17 +261,23 @@ def test_event_loop_metrics_end_cycle(mock_time, trace, event_loop_metrics): assert tru_trace_end_time == exp_trace_end_time + mock_get_meter_provider.return_value.get_meter.assert_called() + metrics_client = event_loop_metrics._metrics_client + metrics_client.event_loop_end_cycle.add.assert_called_with(1, attributes) + metrics_client.event_loop_cycle_duration.record.assert_called() + @unittest.mock.patch.object(strands.telemetry.metrics.time, "time") -def test_event_loop_metrics_add_tool_usage(mock_time, trace, tool, event_loop_metrics): +def test_event_loop_metrics_add_tool_usage(mock_time, trace, tool, event_loop_metrics, mock_get_meter_provider): mock_time.return_value = 1 - duration = 1 success = True message = {"role": "user", "content": [{"toolResult": {"toolUseId": "123", "tool_name": "tool1"}}]} event_loop_metrics.add_tool_usage(tool, duration, trace, success, message) + mock_get_meter_provider.return_value.get_meter.assert_called() + tru_event_loop_metrics_attrs = {"tool_metrics": event_loop_metrics.tool_metrics} exp_event_loop_metrics_attrs = { "tool_metrics": { @@ -258,7 +310,7 @@ def test_event_loop_metrics_add_tool_usage(mock_time, trace, tool, event_loop_me assert tru_trace_attrs == exp_trace_attrs -def test_event_loop_metrics_update_usage(usage, event_loop_metrics): +def test_event_loop_metrics_update_usage(usage, event_loop_metrics, mock_get_meter_provider): for _ in range(3): event_loop_metrics.update_usage(usage) @@ -270,9 +322,13 @@ def test_event_loop_metrics_update_usage(usage, event_loop_metrics): ) assert tru_usage == exp_usage + mock_get_meter_provider.return_value.get_meter.assert_called() + metrics_client = event_loop_metrics._metrics_client + metrics_client.event_loop_input_tokens.record.assert_called() + metrics_client.event_loop_output_tokens.record.assert_called() -def test_event_loop_metrics_update_metrics(metrics, event_loop_metrics): +def test_event_loop_metrics_update_metrics(metrics, event_loop_metrics, mock_get_meter_provider): for _ in range(3): event_loop_metrics.update_metrics(metrics) @@ -282,9 +338,11 @@ def test_event_loop_metrics_update_metrics(metrics, event_loop_metrics): ) assert tru_metrics == exp_metrics + mock_get_meter_provider.return_value.get_meter.assert_called() + event_loop_metrics._metrics_client.event_loop_latency.record.assert_called_with(1) -def test_event_loop_metrics_get_summary(trace, tool, event_loop_metrics): +def test_event_loop_metrics_get_summary(trace, tool, event_loop_metrics, mock_get_meter_provider): duration = 1 success = True message = {"role": "user", "content": [{"toolResult": {"toolUseId": "123", "tool_name": "tool1"}}]} @@ -379,3 +437,31 @@ def test_metrics_to_string(trace, child_trace, tool_metrics, exp_str, event_loop tru_str = strands.telemetry.metrics.metrics_to_string(event_loop_metrics) assert tru_str == exp_str + + +def test_setup_meter_if_meter_provider_is_set( + mock_get_meter_provider, + mock_resource, +): + """Test global meter_provider and meter are used""" + mock_resource_instance = mock.MagicMock() + mock_resource.create.return_value = mock_resource_instance + + metrics_client = MetricsClient() + + mock_get_meter_provider.assert_called() + mock_get_meter_provider.return_value.get_meter.assert_called() + + assert metrics_client is not None + + +def test_use_ProxyMeter_if_no_global_meter_provider(): + """Return _ProxyMeter""" + # Reset the singleton instance + strands.telemetry.metrics.MetricsClient._instance = None + + # Create a new instance which should use the real _ProxyMeter + metrics_client = MetricsClient() + + # Verify it's using a _ProxyMeter + assert isinstance(metrics_client.meter, _ProxyMeter) diff --git a/tests/strands/telemetry/test_tracer.py b/tests/strands/telemetry/test_tracer.py index 32a4ac0a..6ae3e1ad 100644 --- a/tests/strands/telemetry/test_tracer.py +++ b/tests/strands/telemetry/test_tracer.py @@ -4,7 +4,9 @@ from unittest import mock import pytest -from opentelemetry.trace import StatusCode # type: ignore +from opentelemetry.trace import ( + StatusCode, # type: ignore +) from strands.telemetry.tracer import JSONEncoder, Tracer, get_tracer, serialize from strands.types.streaming import Usage @@ -18,13 +20,25 @@ def moto_autouse(moto_env, moto_mock_aws): @pytest.fixture def mock_tracer_provider(): - with mock.patch("strands.telemetry.tracer.TracerProvider") as mock_provider: + with mock.patch("strands.telemetry.tracer.SDKTracerProvider") as mock_provider: yield mock_provider +@pytest.fixture +def mock_is_initialized(): + with mock.patch("strands.telemetry.tracer.Tracer._is_initialized") as mock_is_initialized: + yield mock_is_initialized + + +@pytest.fixture +def mock_get_tracer_provider(): + with mock.patch("strands.telemetry.tracer.trace_api.get_tracer_provider") as mock_get_tracer_provider: + yield mock_get_tracer_provider + + @pytest.fixture def mock_tracer(): - with mock.patch("strands.telemetry.tracer.trace.get_tracer") as mock_get_tracer: + with mock.patch("strands.telemetry.tracer.trace_api.get_tracer") as mock_get_tracer: mock_tracer = mock.MagicMock() mock_get_tracer.return_value = mock_tracer yield mock_tracer @@ -38,13 +52,16 @@ def mock_span(): @pytest.fixture def mock_set_tracer_provider(): - with mock.patch("strands.telemetry.tracer.trace.set_tracer_provider") as mock_set: + with mock.patch("strands.telemetry.tracer.trace_api.set_tracer_provider") as mock_set: yield mock_set @pytest.fixture def mock_otlp_exporter(): - with mock.patch("strands.telemetry.tracer.OTLPSpanExporter") as mock_exporter: + with ( + mock.patch("strands.telemetry.tracer.HAS_OTEL_EXPORTER_MODULE", True), + mock.patch("opentelemetry.exporter.otlp.proto.http.trace_exporter.OTLPSpanExporter") as mock_exporter, + ): yield mock_exporter @@ -56,7 +73,9 @@ def mock_console_exporter(): @pytest.fixture def mock_resource(): - with mock.patch("strands.telemetry.tracer.Resource") as mock_resource: + with mock.patch("strands.telemetry.tracer.get_otel_resource") as mock_resource: + mock_resource_instance = mock.MagicMock() + mock_resource.return_value = mock_resource_instance yield mock_resource @@ -104,8 +123,17 @@ def env_with_both(): yield -def test_init_default(): +@pytest.fixture +def mock_initialize(): + with mock.patch("strands.telemetry.tracer.Tracer._initialize_tracer") as mock_initialize: + yield mock_initialize + + +def test_init_default(mock_is_initialized, mock_get_tracer_provider): """Test initializing the Tracer with default parameters.""" + mock_is_initialized.return_value = False + mock_get_tracer_provider.return_value = None + tracer = Tracer() assert tracer.service_name == "strands-agents" @@ -141,17 +169,20 @@ def test_init_with_env_headers(): def test_initialize_tracer_with_console( - mock_tracer_provider, mock_set_tracer_provider, mock_console_exporter, mock_resource + mock_is_initialized, + mock_tracer_provider, + mock_set_tracer_provider, + mock_console_exporter, + mock_resource, ): """Test initializing the tracer with console exporter.""" - mock_resource_instance = mock.MagicMock() - mock_resource.create.return_value = mock_resource_instance + mock_is_initialized.return_value = False # Initialize Tracer Tracer(enable_console_export=True) # Verify the tracer provider was created with correct resource - mock_tracer_provider.assert_called_once_with(resource=mock_resource_instance) + mock_tracer_provider.assert_called_once_with(resource=mock_resource.return_value) # Verify console exporter was added mock_console_exporter.assert_called_once() @@ -161,16 +192,21 @@ def test_initialize_tracer_with_console( mock_set_tracer_provider.assert_called_once_with(mock_tracer_provider.return_value) -def test_initialize_tracer_with_otlp(mock_tracer_provider, mock_set_tracer_provider, mock_otlp_exporter, mock_resource): +def test_initialize_tracer_with_otlp( + mock_is_initialized, mock_tracer_provider, mock_set_tracer_provider, mock_otlp_exporter, mock_resource +): """Test initializing the tracer with OTLP exporter.""" - mock_resource_instance = mock.MagicMock() - mock_resource.create.return_value = mock_resource_instance + mock_is_initialized.return_value = False # Initialize Tracer - Tracer(otlp_endpoint="http://test-endpoint") + with ( + mock.patch("strands.telemetry.tracer.HAS_OTEL_EXPORTER_MODULE", True), + mock.patch("strands.telemetry.tracer.OTLPSpanExporter", mock_otlp_exporter), + ): + Tracer(otlp_endpoint="http://test-endpoint") # Verify the tracer provider was created with correct resource - mock_tracer_provider.assert_called_once_with(resource=mock_resource_instance) + mock_tracer_provider.assert_called_once_with(resource=mock_resource.return_value) # Verify OTLP exporter was added with correct endpoint mock_otlp_exporter.assert_called_once() @@ -191,7 +227,7 @@ def test_start_span_no_tracer(): def test_start_span(mock_tracer): """Test starting a span with attributes.""" - with mock.patch("strands.telemetry.tracer.trace.get_tracer", return_value=mock_tracer): + with mock.patch("strands.telemetry.tracer.trace_api.get_tracer", return_value=mock_tracer): tracer = Tracer(enable_console_export=True) tracer.tracer = mock_tracer @@ -262,7 +298,7 @@ def test_end_span_with_error_message(mock_span): def test_start_model_invoke_span(mock_tracer): """Test starting a model invoke span.""" - with mock.patch("strands.telemetry.tracer.trace.get_tracer", return_value=mock_tracer): + with mock.patch("strands.telemetry.tracer.trace_api.get_tracer", return_value=mock_tracer): tracer = Tracer(enable_console_export=True) tracer.tracer = mock_tracer @@ -300,7 +336,7 @@ def test_end_model_invoke_span(mock_span): def test_start_tool_call_span(mock_tracer): """Test starting a tool call span.""" - with mock.patch("strands.telemetry.tracer.trace.get_tracer", return_value=mock_tracer): + with mock.patch("strands.telemetry.tracer.trace_api.get_tracer", return_value=mock_tracer): tracer = Tracer(enable_console_export=True) tracer.tracer = mock_tracer @@ -338,7 +374,7 @@ def test_end_tool_call_span(mock_span): def test_start_event_loop_cycle_span(mock_tracer): """Test starting an event loop cycle span.""" - with mock.patch("strands.telemetry.tracer.trace.get_tracer", return_value=mock_tracer): + with mock.patch("strands.telemetry.tracer.trace_api.get_tracer", return_value=mock_tracer): tracer = Tracer(enable_console_export=True) tracer.tracer = mock_tracer @@ -373,7 +409,7 @@ def test_end_event_loop_cycle_span(mock_span): def test_start_agent_span(mock_tracer): """Test starting an agent span.""" - with mock.patch("strands.telemetry.tracer.trace.get_tracer", return_value=mock_tracer): + with mock.patch("strands.telemetry.tracer.trace_api.get_tracer", return_value=mock_tracer): tracer = Tracer(enable_console_export=True) tracer.tracer = mock_tracer @@ -464,20 +500,24 @@ def test_get_tracer_parameters(): def test_initialize_tracer_with_invalid_otlp_endpoint( - mock_tracer_provider, mock_set_tracer_provider, mock_otlp_exporter, mock_resource + mock_is_initialized, mock_tracer_provider, mock_set_tracer_provider, mock_otlp_exporter, mock_resource ): """Test initializing the tracer with an invalid OTLP endpoint.""" - mock_resource_instance = mock.MagicMock() - mock_resource.create.return_value = mock_resource_instance + mock_is_initialized.return_value = False + mock_otlp_exporter.side_effect = Exception("Connection error") # This should not raise an exception, but should log an error # Initialize Tracer - Tracer(otlp_endpoint="http://invalid-endpoint") + with ( + mock.patch("strands.telemetry.tracer.HAS_OTEL_EXPORTER_MODULE", True), + mock.patch("strands.telemetry.tracer.OTLPSpanExporter", mock_otlp_exporter), + ): + Tracer(otlp_endpoint="http://invalid-endpoint") # Verify the tracer provider was created with correct resource - mock_tracer_provider.assert_called_once_with(resource=mock_resource_instance) + mock_tracer_provider.assert_called_once_with(resource=mock_resource.return_value) # Verify OTLP exporter was attempted mock_otlp_exporter.assert_called_once() @@ -486,6 +526,42 @@ def test_initialize_tracer_with_invalid_otlp_endpoint( mock_set_tracer_provider.assert_called_once_with(mock_tracer_provider.return_value) +def test_initialize_tracer_with_missing_module( + mock_is_initialized, mock_tracer_provider, mock_set_tracer_provider, mock_resource +): + """Test initializing the tracer when the OTLP exporter module is missing.""" + mock_is_initialized.return_value = False + + # Initialize Tracer with OTLP endpoint but missing module + with ( + mock.patch("strands.telemetry.tracer.HAS_OTEL_EXPORTER_MODULE", False), + pytest.raises(ModuleNotFoundError) as excinfo, + ): + Tracer(otlp_endpoint="http://test-endpoint") + + # Verify the error message + assert "opentelemetry-exporter-otlp-proto-http not detected" in str(excinfo.value) + assert "otel http exporting is currently DISABLED" in str(excinfo.value) + + # Verify the tracer provider was created with correct resource + mock_tracer_provider.assert_called_once_with(resource=mock_resource.return_value) + + # Verify set_tracer_provider was not called since an exception was raised + mock_set_tracer_provider.assert_not_called() + + +def test_initialize_tracer_with_custom_tracer_provider(mock_is_initialized, mock_get_tracer_provider, mock_resource): + """Test initializing the tracer with NoOpTracerProvider.""" + mock_is_initialized.return_value = True + tracer = Tracer(otlp_endpoint="http://invalid-endpoint") + + mock_get_tracer_provider.assert_called() + mock_resource.assert_not_called() + + assert tracer.tracer_provider is not None + assert tracer.tracer is not None + + def test_end_span_with_exception_handling(mock_span): """Test ending a span with exception handling.""" tracer = Tracer() @@ -530,7 +606,7 @@ def test_end_tool_call_span_with_none(mock_span): def test_start_model_invoke_span_with_parent(mock_tracer): """Test starting a model invoke span with a parent span.""" - with mock.patch("strands.telemetry.tracer.trace.get_tracer", return_value=mock_tracer): + with mock.patch("strands.telemetry.tracer.trace_api.get_tracer", return_value=mock_tracer): tracer = Tracer(enable_console_export=True) tracer.tracer = mock_tracer diff --git a/tests/strands/tools/test_executor.py b/tests/strands/tools/test_executor.py index ced2bd7f..4b238792 100644 --- a/tests/strands/tools/test_executor.py +++ b/tests/strands/tools/test_executor.py @@ -1,6 +1,7 @@ import concurrent import functools import unittest.mock +import uuid import pytest @@ -9,6 +10,11 @@ from strands.types.content import Message +@pytest.fixture(autouse=True) +def moto_autouse(moto_env): + _ = moto_env + + @pytest.fixture def tool_handler(request): def handler(tool_use): @@ -37,6 +43,12 @@ def tool_uses(request, tool_use): return request.param if hasattr(request, "param") else [tool_use] +@pytest.fixture +def mock_metrics_client(): + with unittest.mock.patch("strands.telemetry.MetricsClient") as mock_metrics_client: + yield mock_metrics_client + + @pytest.fixture def event_loop_metrics(): return strands.telemetry.metrics.EventLoopMetrics() @@ -52,10 +64,10 @@ def invalid_tool_use_ids(request): return request.param if hasattr(request, "param") else [] -@unittest.mock.patch.object(strands.telemetry.metrics, "uuid4", return_value="trace1") @pytest.fixture def cycle_trace(): - return strands.telemetry.metrics.Trace(name="test trace", raw_name="raw_name") + with unittest.mock.patch.object(uuid, "uuid4", return_value="trace1"): + return strands.telemetry.metrics.Trace(name="test trace", raw_name="raw_name") @pytest.fixture @@ -297,6 +309,7 @@ def test_run_tools_creates_and_ends_span_on_success( mock_get_tracer, tool_handler, tool_uses, + mock_metrics_client, event_loop_metrics, request_state, invalid_tool_use_ids, diff --git a/tests/strands/tools/test_registry.py b/tests/strands/tools/test_registry.py index 3dca5371..1b274f46 100644 --- a/tests/strands/tools/test_registry.py +++ b/tests/strands/tools/test_registry.py @@ -2,8 +2,11 @@ Tests for the SDK tool registry module. """ +from unittest.mock import MagicMock + import pytest +from strands.tools import PythonAgentTool from strands.tools.registry import ToolRegistry @@ -23,3 +26,20 @@ def test_process_tools_with_invalid_path(): with pytest.raises(ValueError, match=f"Failed to load tool {invalid_path.split('.')[0]}: Tool file not found:.*"): tool_registry.process_tools([invalid_path]) + + +def test_register_tool_with_similar_name_raises(): + tool_1 = PythonAgentTool(tool_name="tool-like-this", tool_spec=MagicMock(), callback=lambda: None) + tool_2 = PythonAgentTool(tool_name="tool_like_this", tool_spec=MagicMock(), callback=lambda: None) + + tool_registry = ToolRegistry() + + tool_registry.register_tool(tool_1) + + with pytest.raises(ValueError) as err: + tool_registry.register_tool(tool_2) + + assert ( + str(err.value) == "Tool name 'tool_like_this' already exists as 'tool-like-this'. " + "Cannot add a duplicate tool which differs by a '-' or '_'" + ) diff --git a/tests/strands/tools/test_structured_output.py b/tests/strands/tools/test_structured_output.py new file mode 100644 index 00000000..2e354b83 --- /dev/null +++ b/tests/strands/tools/test_structured_output.py @@ -0,0 +1,228 @@ +from typing import Literal, Optional + +import pytest +from pydantic import BaseModel, Field + +from strands.tools.structured_output import convert_pydantic_to_tool_spec +from strands.types.tools import ToolSpec + + +# Basic test model +class User(BaseModel): + """User model with name and age.""" + + name: str = Field(description="The name of the user") + age: int = Field(description="The age of the user", ge=18, le=100) + + +# Test model with inheritance and literals +class UserWithPlanet(User): + """User with planet.""" + + planet: Literal["Earth", "Mars"] = Field(description="The planet") + + +# Test model with multiple same type fields and optional field +class TwoUsersWithPlanet(BaseModel): + """Two users model with planet.""" + + user1: UserWithPlanet = Field(description="The first user") + user2: Optional[UserWithPlanet] = Field(description="The second user", default=None) + + +# Test model with list of same type fields +class ListOfUsersWithPlanet(BaseModel): + """List of users model with planet.""" + + users: list[UserWithPlanet] = Field(description="The users", min_length=2, max_length=3) + + +def test_convert_pydantic_to_tool_spec_basic(): + tool_spec = convert_pydantic_to_tool_spec(User) + + expected_spec = { + "name": "User", + "description": "User model with name and age.", + "inputSchema": { + "json": { + "type": "object", + "properties": { + "name": {"description": "The name of the user", "title": "Name", "type": "string"}, + "age": { + "description": "The age of the user", + "maximum": 100, + "minimum": 18, + "title": "Age", + "type": "integer", + }, + }, + "title": "User", + "description": "User model with name and age.", + "required": ["name", "age"], + } + }, + } + + # Verify we can construct a valid ToolSpec + tool_spec_obj = ToolSpec(**tool_spec) + assert tool_spec_obj is not None + assert tool_spec == expected_spec + + +def test_convert_pydantic_to_tool_spec_complex(): + tool_spec = convert_pydantic_to_tool_spec(ListOfUsersWithPlanet) + + expected_spec = { + "name": "ListOfUsersWithPlanet", + "description": "List of users model with planet.", + "inputSchema": { + "json": { + "type": "object", + "properties": { + "users": { + "description": "The users", + "items": { + "description": "User with planet.", + "title": "UserWithPlanet", + "type": "object", + "properties": { + "name": {"description": "The name of the user", "title": "Name", "type": "string"}, + "age": { + "description": "The age of the user", + "maximum": 100, + "minimum": 18, + "title": "Age", + "type": "integer", + }, + "planet": { + "description": "The planet", + "enum": ["Earth", "Mars"], + "title": "Planet", + "type": "string", + }, + }, + "required": ["name", "age", "planet"], + }, + "maxItems": 3, + "minItems": 2, + "title": "Users", + "type": "array", + } + }, + "title": "ListOfUsersWithPlanet", + "description": "List of users model with planet.", + "required": ["users"], + } + }, + } + + assert tool_spec == expected_spec + + # Verify we can construct a valid ToolSpec + tool_spec_obj = ToolSpec(**tool_spec) + assert tool_spec_obj is not None + + +def test_convert_pydantic_to_tool_spec_multiple_same_type(): + tool_spec = convert_pydantic_to_tool_spec(TwoUsersWithPlanet) + + expected_spec = { + "name": "TwoUsersWithPlanet", + "description": "Two users model with planet.", + "inputSchema": { + "json": { + "type": "object", + "properties": { + "user1": { + "type": "object", + "description": "The first user", + "properties": { + "name": {"description": "The name of the user", "title": "Name", "type": "string"}, + "age": { + "description": "The age of the user", + "maximum": 100, + "minimum": 18, + "title": "Age", + "type": "integer", + }, + "planet": { + "description": "The planet", + "enum": ["Earth", "Mars"], + "title": "Planet", + "type": "string", + }, + }, + "required": ["name", "age", "planet"], + }, + "user2": { + "type": ["object", "null"], + "description": "The second user", + "properties": { + "name": {"description": "The name of the user", "title": "Name", "type": "string"}, + "age": { + "description": "The age of the user", + "maximum": 100, + "minimum": 18, + "title": "Age", + "type": "integer", + }, + "planet": { + "description": "The planet", + "enum": ["Earth", "Mars"], + "title": "Planet", + "type": "string", + }, + }, + "required": ["name", "age", "planet"], + }, + }, + "title": "TwoUsersWithPlanet", + "description": "Two users model with planet.", + "required": ["user1"], + } + }, + } + + assert tool_spec == expected_spec + + # Verify we can construct a valid ToolSpec + tool_spec_obj = ToolSpec(**tool_spec) + assert tool_spec_obj is not None + + +def test_convert_pydantic_with_missing_refs(): + """Test that the tool handles missing $refs gracefully.""" + # This test checks that our error handling for missing $refs works correctly + # by testing with a model that has circular references + + class NodeWithCircularRef(BaseModel): + """A node with a circular reference to itself.""" + + name: str = Field(description="The name of the node") + parent: Optional["NodeWithCircularRef"] = Field(None, description="Parent node") + children: list["NodeWithCircularRef"] = Field(default_factory=list, description="Child nodes") + + # This forward reference normally causes issues with schema generation + # but our error handling should prevent errors + with pytest.raises(ValueError, match="Circular reference detected and not supported"): + convert_pydantic_to_tool_spec(NodeWithCircularRef) + + +def test_convert_pydantic_with_custom_description(): + """Test that custom descriptions override model docstrings.""" + + # Test with custom description + custom_description = "Custom tool description for user model" + tool_spec = convert_pydantic_to_tool_spec(User, description=custom_description) + + assert tool_spec["description"] == custom_description + + +def test_convert_pydantic_with_empty_docstring(): + """Test that empty docstrings use default description.""" + + class EmptyDocUser(BaseModel): + name: str = Field(description="The name of the user") + + tool_spec = convert_pydantic_to_tool_spec(EmptyDocUser) + assert tool_spec["description"] == "EmptyDocUser structured output tool" diff --git a/tests/strands/tools/test_tools.py b/tests/strands/tools/test_tools.py index f24cc22d..1b65156b 100644 --- a/tests/strands/tools/test_tools.py +++ b/tests/strands/tools/test_tools.py @@ -50,11 +50,10 @@ def test_validate_tool_use(): def test_normalize_schema_basic(): schema = {"type": "object"} normalized = normalize_schema(schema) - assert normalized["type"] == "object" - assert "properties" in normalized - assert normalized["properties"] == {} - assert "required" in normalized - assert normalized["required"] == [] + + expected = {"type": "object", "properties": {}, "required": []} + + assert normalized == expected def test_normalize_schema_with_properties(): @@ -66,14 +65,17 @@ def test_normalize_schema_with_properties(): }, } normalized = normalize_schema(schema) - assert normalized["type"] == "object" - assert "properties" in normalized - assert "name" in normalized["properties"] - assert normalized["properties"]["name"]["type"] == "string" - assert normalized["properties"]["name"]["description"] == "User name" - assert "age" in normalized["properties"] - assert normalized["properties"]["age"]["type"] == "integer" - assert normalized["properties"]["age"]["description"] == "User age" + + expected = { + "type": "object", + "properties": { + "name": {"type": "string", "description": "User name"}, + "age": {"type": "integer", "description": "User age"}, + }, + "required": [], + } + + assert normalized == expected def test_normalize_schema_with_property_removed(): @@ -82,27 +84,40 @@ def test_normalize_schema_with_property_removed(): "properties": {"name": "invalid"}, } normalized = normalize_schema(schema) - assert "name" in normalized["properties"] - assert normalized["properties"]["name"]["type"] == "string" - assert normalized["properties"]["name"]["description"] == "Property name" + + expected = { + "type": "object", + "properties": {"name": {"type": "string", "description": "Property name"}}, + "required": [], + } + + assert normalized == expected def test_normalize_schema_with_property_defaults(): schema = {"properties": {"name": {}}} normalized = normalize_schema(schema) - assert "name" in normalized["properties"] - assert normalized["properties"]["name"]["type"] == "string" - assert normalized["properties"]["name"]["description"] == "Property name" + + expected = { + "type": "object", + "properties": {"name": {"type": "string", "description": "Property name"}}, + "required": [], + } + + assert normalized == expected def test_normalize_schema_with_property_enum(): schema = {"properties": {"color": {"type": "string", "description": "color", "enum": ["red", "green", "blue"]}}} normalized = normalize_schema(schema) - assert "color" in normalized["properties"] - assert normalized["properties"]["color"]["type"] == "string" - assert normalized["properties"]["color"]["description"] == "color" - assert "enum" in normalized["properties"]["color"] - assert normalized["properties"]["color"]["enum"] == ["red", "green", "blue"] + + expected = { + "type": "object", + "properties": {"color": {"type": "string", "description": "color", "enum": ["red", "green", "blue"]}}, + "required": [], + } + + assert normalized == expected def test_normalize_schema_with_property_numeric_constraints(): @@ -113,21 +128,170 @@ def test_normalize_schema_with_property_numeric_constraints(): } } normalized = normalize_schema(schema) - assert "age" in normalized["properties"] - assert normalized["properties"]["age"]["type"] == "integer" - assert normalized["properties"]["age"]["minimum"] == 0 - assert normalized["properties"]["age"]["maximum"] == 120 - assert "score" in normalized["properties"] - assert normalized["properties"]["score"]["type"] == "number" - assert normalized["properties"]["score"]["minimum"] == 0.0 - assert normalized["properties"]["score"]["maximum"] == 100.0 + + expected = { + "type": "object", + "properties": { + "age": {"type": "integer", "description": "age", "minimum": 0, "maximum": 120}, + "score": {"type": "number", "description": "score", "minimum": 0.0, "maximum": 100.0}, + }, + "required": [], + } + + assert normalized == expected def test_normalize_schema_with_required(): schema = {"type": "object", "required": ["name", "email"]} normalized = normalize_schema(schema) - assert "required" in normalized - assert normalized["required"] == ["name", "email"] + + expected = {"type": "object", "properties": {}, "required": ["name", "email"]} + + assert normalized == expected + + +def test_normalize_schema_with_nested_object(): + """Test normalization of schemas with nested objects.""" + schema = { + "type": "object", + "properties": { + "user": { + "type": "object", + "properties": { + "name": {"type": "string", "description": "User name"}, + "age": {"type": "integer", "description": "User age"}, + }, + "required": ["name"], + } + }, + "required": ["user"], + } + + normalized = normalize_schema(schema) + + expected = { + "type": "object", + "properties": { + "user": { + "type": "object", + "properties": { + "name": {"type": "string", "description": "User name"}, + "age": {"type": "integer", "description": "User age"}, + }, + "required": ["name"], + } + }, + "required": ["user"], + } + + assert normalized == expected + + +def test_normalize_schema_with_deeply_nested_objects(): + """Test normalization of deeply nested object structures.""" + schema = { + "type": "object", + "properties": { + "level1": { + "type": "object", + "properties": { + "level2": { + "type": "object", + "properties": { + "level3": {"type": "object", "properties": {"value": {"type": "string", "const": "fixed"}}} + }, + } + }, + } + }, + } + + normalized = normalize_schema(schema) + + expected = { + "type": "object", + "properties": { + "level1": { + "type": "object", + "properties": { + "level2": { + "type": "object", + "properties": { + "level3": { + "type": "object", + "properties": { + "value": {"type": "string", "description": "Property value", "const": "fixed"} + }, + "required": [], + } + }, + "required": [], + } + }, + "required": [], + } + }, + "required": [], + } + + assert normalized == expected + + +def test_normalize_schema_with_const_constraint(): + """Test that const constraints are preserved.""" + schema = { + "type": "object", + "properties": { + "status": {"type": "string", "const": "ACTIVE"}, + "config": {"type": "object", "properties": {"mode": {"type": "string", "const": "PRODUCTION"}}}, + }, + } + + normalized = normalize_schema(schema) + + expected = { + "type": "object", + "properties": { + "status": {"type": "string", "description": "Property status", "const": "ACTIVE"}, + "config": { + "type": "object", + "properties": {"mode": {"type": "string", "description": "Property mode", "const": "PRODUCTION"}}, + "required": [], + }, + }, + "required": [], + } + + assert normalized == expected + + +def test_normalize_schema_with_additional_properties(): + """Test that additionalProperties constraint is preserved.""" + schema = { + "type": "object", + "additionalProperties": False, + "properties": { + "data": {"type": "object", "properties": {"id": {"type": "string"}}, "additionalProperties": False} + }, + } + + normalized = normalize_schema(schema) + + expected = { + "type": "object", + "additionalProperties": False, + "properties": { + "data": { + "type": "object", + "additionalProperties": False, + "properties": {"id": {"type": "string", "description": "Property id"}}, + "required": [], + } + }, + "required": [], + } + + assert normalized == expected def test_normalize_tool_spec_with_json_schema(): @@ -137,14 +301,20 @@ def test_normalize_tool_spec_with_json_schema(): "inputSchema": {"json": {"type": "object", "properties": {"query": {}}, "required": ["query"]}}, } normalized = normalize_tool_spec(tool_spec) - assert normalized["name"] == "test_tool" - assert normalized["description"] == "A test tool" - assert "inputSchema" in normalized - assert "json" in normalized["inputSchema"] - assert normalized["inputSchema"]["json"]["type"] == "object" - assert "query" in normalized["inputSchema"]["json"]["properties"] - assert normalized["inputSchema"]["json"]["properties"]["query"]["type"] == "string" - assert normalized["inputSchema"]["json"]["properties"]["query"]["description"] == "Property query" + + expected = { + "name": "test_tool", + "description": "A test tool", + "inputSchema": { + "json": { + "type": "object", + "properties": {"query": {"type": "string", "description": "Property query"}}, + "required": ["query"], + } + }, + } + + assert normalized == expected def test_normalize_tool_spec_with_direct_schema(): @@ -154,22 +324,29 @@ def test_normalize_tool_spec_with_direct_schema(): "inputSchema": {"type": "object", "properties": {"query": {}}, "required": ["query"]}, } normalized = normalize_tool_spec(tool_spec) - assert normalized["name"] == "test_tool" - assert normalized["description"] == "A test tool" - assert "inputSchema" in normalized - assert "json" in normalized["inputSchema"] - assert normalized["inputSchema"]["json"]["type"] == "object" - assert "query" in normalized["inputSchema"]["json"]["properties"] - assert normalized["inputSchema"]["json"]["required"] == ["query"] + + expected = { + "name": "test_tool", + "description": "A test tool", + "inputSchema": { + "json": { + "type": "object", + "properties": {"query": {"type": "string", "description": "Property query"}}, + "required": ["query"], + } + }, + } + + assert normalized == expected def test_normalize_tool_spec_without_input_schema(): tool_spec = {"name": "test_tool", "description": "A test tool"} normalized = normalize_tool_spec(tool_spec) - assert normalized["name"] == "test_tool" - assert normalized["description"] == "A test tool" - # Should not modify the spec if inputSchema is not present - assert "inputSchema" not in normalized + + expected = {"name": "test_tool", "description": "A test tool"} + + assert normalized == expected def test_normalize_tool_spec_empty_input_schema(): @@ -179,10 +356,10 @@ def test_normalize_tool_spec_empty_input_schema(): "inputSchema": "", } normalized = normalize_tool_spec(tool_spec) - assert normalized["name"] == "test_tool" - assert normalized["description"] == "A test tool" - # Should not modify the spec if inputSchema is not a dict - assert normalized["inputSchema"] == "" + + expected = {"name": "test_tool", "description": "A test tool", "inputSchema": ""} + + assert normalized == expected def test_validate_tool_use_with_valid_input(): diff --git a/tests/strands/types/models/test_model.py b/tests/strands/types/models/test_model.py index f2797fe5..03690733 100644 --- a/tests/strands/types/models/test_model.py +++ b/tests/strands/types/models/test_model.py @@ -1,8 +1,16 @@ +from typing import Type + import pytest +from pydantic import BaseModel from strands.types.models import Model as SAModel +class Person(BaseModel): + name: str + age: int + + class TestModel(SAModel): def update_config(self, **model_config): return model_config @@ -10,6 +18,9 @@ def update_config(self, **model_config): def get_config(self): return + def structured_output(self, output_model: Type[BaseModel]) -> BaseModel: + return output_model(name="test", age=20) + def format_request(self, messages, tool_specs, system_prompt): return { "messages": messages, @@ -79,3 +90,9 @@ def test_converse(model, messages, tool_specs, system_prompt): }, ] assert tru_events == exp_events + + +def test_structured_output(model): + response = model.structured_output(Person) + + assert response == Person(name="test", age=20) diff --git a/tests/strands/types/models/test_openai.py b/tests/strands/types/models/test_openai.py index 9db08bc9..3a1a940b 100644 --- a/tests/strands/types/models/test_openai.py +++ b/tests/strands/types/models/test_openai.py @@ -1,3 +1,4 @@ +import base64 import unittest.mock import pytest @@ -90,7 +91,24 @@ def system_prompt(): "image_url": { "detail": "auto", "format": "image/jpeg", - "url": "data:image/jpeg;base64,image", + "url": "data:image/jpeg;base64,aW1hZ2U=", + }, + "type": "image_url", + }, + ), + # Image - base64 encoded + ( + { + "image": { + "format": "jpg", + "source": {"bytes": base64.b64encode(b"image")}, + }, + }, + { + "image_url": { + "detail": "auto", + "format": "image/jpeg", + "url": "data:image/jpeg;base64,aW1hZ2U=", }, "type": "image_url", }, @@ -344,3 +362,15 @@ def test_format_chunk_unknown_type(model): with pytest.raises(RuntimeError, match="chunk_type= | unknown type"): model.format_chunk(event) + + +@pytest.mark.parametrize( + ("data", "exp_result"), + [ + (b"image", b"aW1hZ2U="), + (b"aW1hZ2U=", b"aW1hZ2U="), + ], +) +def test_b64encode(data, exp_result): + tru_result = SAOpenAIModel.b64encode(data) + assert tru_result == exp_result