diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index a6efda65..fa894231 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -1,37 +1,38 @@ ## Description -[Provide a detailed description of the changes in this PR] + ## Related Issues -[Link to related issues using #issue-number format] + + ## Documentation PR -[Link to related associated PR in the agent-docs repo] + + ## Type of Change -- Bug fix -- New feature -- Breaking change -- Documentation update -- Other (please describe): -[Choose one of the above types of changes] + +Bug fix +New feature +Breaking change +Documentation update +Other (please describe): ## Testing -[How have you tested the change?] -* `hatch fmt --linter` -* `hatch fmt --formatter` -* `hatch test --all` -* Verify that the changes do not break functionality or introduce warnings in consuming repositories: agents-docs, agents-tools, agents-cli +How have you tested the change? Verify that the changes do not break functionality or introduce warnings in consuming repositories: agents-docs, agents-tools, agents-cli +- [ ] I ran `hatch run prepare` ## Checklist - [ ] I have read the CONTRIBUTING document -- [ ] I have added tests that prove my fix is effective or my feature works +- [ ] I have added any necessary tests that prove my fix is effective or my feature works - [ ] I have updated the documentation accordingly -- [ ] I have added an appropriate example to the documentation to outline the feature +- [ ] I have added an appropriate example to the documentation to outline the feature, or no new docs are needed - [ ] My changes generate no new warnings - [ ] Any dependent changes have been merged and published +---- + By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice. diff --git a/.github/workflows/integration-test.yml b/.github/workflows/integration-test.yml index 294a2f3e..39b53c49 100644 --- a/.github/workflows/integration-test.yml +++ b/.github/workflows/integration-test.yml @@ -2,41 +2,59 @@ name: Secure Integration test on: pull_request_target: - types: [opened, synchronize, labeled, unlabled, reopened] + branches: main jobs: + authorization-check: + permissions: read-all + runs-on: ubuntu-latest + outputs: + approval-env: ${{ steps.collab-check.outputs.result }} + steps: + - name: Collaborator Check + uses: actions/github-script@v7 + id: collab-check + with: + result-encoding: string + script: | + try { + const permissionResponse = await github.rest.repos.getCollaboratorPermissionLevel({ + owner: context.repo.owner, + repo: context.repo.repo, + username: context.payload.pull_request.user.login, + }); + const permission = permissionResponse.data.permission; + const hasWriteAccess = ['write', 'admin'].includes(permission); + if (!hasWriteAccess) { + console.log(`User ${context.payload.pull_request.user.login} does not have write access to the repository (permission: ${permission})`); + return "manual-approval" + } else { + console.log(`Verifed ${context.payload.pull_request.user.login} has write access. Auto Approving PR Checks.`) + return "auto-approve" + } + } catch (error) { + console.log(`${context.payload.pull_request.user.login} does not have write access. Requiring Manual Approval to run PR Checks.`) + return "manual-approval" + } check-access-and-checkout: runs-on: ubuntu-latest + needs: authorization-check + environment: ${{ needs.authorization-check.outputs.approval-env }} permissions: id-token: write pull-requests: read contents: read steps: - - name: Check PR labels and author - id: check - uses: actions/github-script@v7 - with: - script: | - const pr = context.payload.pull_request; - - const labels = pr.labels.map(label => label.name); - const hasLabel = labels.includes('approved-for-integ-test') - if (hasLabel) { - core.info('PR contains label approved-for-integ-test') - return - } - - core.setFailed('Pull Request must either have label approved-for-integ-test') - name: Configure Credentials uses: aws-actions/configure-aws-credentials@v4 with: role-to-assume: ${{ secrets.STRANDS_INTEG_TEST_ROLE }} aws-region: us-east-1 mask-aws-account-id: true - - name: Checkout base branch + - name: Checkout head commit uses: actions/checkout@v4 with: - ref: ${{ github.event.pull_request.head.ref }} # Pull the commit from the forked repo + ref: ${{ github.event.pull_request.head.sha }} # Pull the commit from the forked repo persist-credentials: false # Don't persist credentials for subsequent actions - name: Set up Python uses: actions/setup-python@v5 diff --git a/pyproject.toml b/pyproject.toml index 835def0f..4bb69ce8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,7 +28,7 @@ classifiers = [ dependencies = [ "boto3>=1.26.0,<2.0.0", "botocore>=1.29.0,<2.0.0", - "docstring_parser>=0.15,<0.16.0", + "docstring_parser>=0.15,<1.0", "mcp>=1.8.0,<2.0.0", "pydantic>=2.0.0,<3.0.0", "typing-extensions>=4.13.2,<5.0.0", @@ -58,7 +58,6 @@ dev = [ "pytest>=8.0.0,<9.0.0", "pytest-asyncio>=0.26.0,<0.27.0", "ruff>=0.4.4,<0.5.0", - "swagger-parser>=1.0.2,<2.0.0", ] docs = [ "sphinx>=5.0.0,<6.0.0", @@ -66,7 +65,7 @@ docs = [ "sphinx-autodoc-typehints>=1.12.0,<2.0.0", ] litellm = [ - "litellm>=1.69.0,<2.0.0", + "litellm>=1.72.6,<2.0.0", ] llamaapi = [ "llama-api-client>=0.1.0,<1.0.0", @@ -109,7 +108,8 @@ format-fix = [ ] lint-check = [ "ruff check", - "mypy -p src" + # excluding due to A2A and OTEL http exporter dependency conflict + "mypy -p src --exclude src/strands/multiagent" ] lint-fix = [ "ruff check --fix" @@ -138,17 +138,29 @@ features = ["dev", "docs", "anthropic", "litellm", "llamaapi", "ollama", "otel"] dev-mode = true features = ["dev", "docs", "anthropic", "litellm", "llamaapi", "ollama", "a2a"] +[tool.hatch.envs.a2a.scripts] +run = [ + "pytest{env:HATCH_TEST_ARGS:} tests/multiagent/a2a {args}" +] +run-cov = [ + "pytest{env:HATCH_TEST_ARGS:} tests/multiagent/a2a --cov --cov-config=pyproject.toml {args}" +] +lint-check = [ + "ruff check", + "mypy -p src/strands/multiagent/a2a" +] [[tool.hatch.envs.hatch-test.matrix]] python = ["3.13", "3.12", "3.11", "3.10"] - [tool.hatch.envs.hatch-test.scripts] run = [ - "pytest{env:HATCH_TEST_ARGS:} {args}" + # excluding due to A2A and OTEL http exporter dependency conflict + "pytest{env:HATCH_TEST_ARGS:} {args} --ignore=tests/multiagent/a2a" ] run-cov = [ - "pytest{env:HATCH_TEST_ARGS:} --cov --cov-config=pyproject.toml {args}" + # excluding due to A2A and OTEL http exporter dependency conflict + "pytest{env:HATCH_TEST_ARGS:} --cov --cov-config=pyproject.toml {args} --ignore=tests/multiagent/a2a" ] cov-combine = [] @@ -177,7 +189,15 @@ test = [ test-integ = [ "hatch test tests-integ {args}" ] - +prepare = [ + "hatch fmt --linter", + "hatch fmt --formatter", + "hatch test --all" +] +test-a2a = [ + # required to run manually due to A2A and OTEL http exporter dependency conflict + "hatch -e a2a run run {args}" +] [tool.mypy] python_version = "3.10" @@ -261,4 +281,4 @@ style = [ ["instruction", ""], ["text", ""], ["disabled", "fg:#858585 italic"] -] +] \ No newline at end of file diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 56f5b92e..4ecc42a9 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -16,10 +16,11 @@ import random from concurrent.futures import ThreadPoolExecutor from threading import Thread -from typing import Any, AsyncIterator, Callable, Dict, List, Mapping, Optional, Union +from typing import Any, AsyncIterator, Callable, Dict, List, Mapping, Optional, Type, TypeVar, Union from uuid import uuid4 from opentelemetry import trace +from pydantic import BaseModel from ..event_loop.event_loop import event_loop_cycle from ..handlers.callback_handler import CompositeCallbackHandler, PrintingCallbackHandler, null_callback_handler @@ -43,6 +44,9 @@ logger = logging.getLogger(__name__) +# TypeVar for generic structured output +T = TypeVar("T", bound=BaseModel) + # Sentinel class and object to distinguish between explicit None and default parameter value class _DefaultCallbackHandlerSentinel: @@ -216,6 +220,9 @@ def __init__( record_direct_tool_call: bool = True, load_tools_from_directory: bool = True, trace_attributes: Optional[Mapping[str, AttributeValue]] = None, + *, + name: Optional[str] = None, + description: Optional[str] = None, ): """Initialize the Agent with the specified configuration. @@ -248,6 +255,10 @@ def __init__( load_tools_from_directory: Whether to load and automatically reload tools in the `./tools/` directory. Defaults to True. trace_attributes: Custom trace attributes to apply to the agent's trace span. + name: name of the Agent + Defaults to None. + description: description of what the Agent does + Defaults to None. Raises: ValueError: If max_parallel_tools is less than 1. @@ -308,8 +319,9 @@ def __init__( # Initialize tracer instance (no-op if not configured) self.tracer = get_tracer() self.trace_span: Optional[trace.Span] = None - self.tool_caller = Agent.ToolCaller(self) + self.name = name + self.description = description @property def tool(self) -> ToolCaller: @@ -387,6 +399,32 @@ def __call__(self, prompt: str, **kwargs: Any) -> AgentResult: # Re-raise the exception to preserve original behavior raise + def structured_output(self, output_model: Type[T], prompt: Optional[str] = None) -> T: + """This method allows you to get structured output from the agent. + + If you pass in a prompt, it will be added to the conversation history and the agent will respond to it. + If you don't pass in a prompt, it will use only the conversation history to respond. + If no conversation history exists and no prompt is provided, an error will be raised. + + For smaller models, you may want to use the optional prompt string to add additional instructions to explicitly + instruct the model to output the structured data. + + Args: + output_model(Type[BaseModel]): The output model (a JSON schema written as a Pydantic BaseModel) + that the agent will use when responding. + prompt(Optional[str]): The prompt to use for the agent. + """ + messages = self.messages + if not messages and not prompt: + raise ValueError("No conversation history or prompt provided") + + # add the prompt as the last message + if prompt: + messages.append({"role": "user", "content": [{"text": prompt}]}) + + # get the structured output from the model + return self.model.structured_output(output_model, messages, self.callback_handler) + async def stream_async(self, prompt: str, **kwargs: Any) -> AsyncIterator[Any]: """Process a natural language prompt and yield events as an async iterator. diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index 71165926..9580ea35 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -33,23 +33,6 @@ MAX_DELAY = 240 # 4 minutes -def initialize_state(**kwargs: Any) -> Any: - """Initialize the request state if not present. - - Creates an empty request_state dictionary if one doesn't already exist in the - provided keyword arguments. - - Args: - **kwargs: Keyword arguments that may contain a request_state. - - Returns: - The updated kwargs dictionary with request_state initialized if needed. - """ - if "request_state" not in kwargs: - kwargs["request_state"] = {} - return kwargs - - def event_loop_cycle( model: Model, system_prompt: Optional[str], @@ -105,10 +88,11 @@ def event_loop_cycle( kwargs["event_loop_cycle_id"] = uuid.uuid4() event_loop_metrics: EventLoopMetrics = kwargs.get("event_loop_metrics", EventLoopMetrics()) - # Initialize state and get cycle trace - kwargs = initialize_state(**kwargs) - cycle_start_time, cycle_trace = event_loop_metrics.start_cycle() + if "request_state" not in kwargs: + kwargs["request_state"] = {} + attributes = {"event_loop_cycle_id": str(kwargs.get("event_loop_cycle_id"))} + cycle_start_time, cycle_trace = event_loop_metrics.start_cycle(attributes=attributes) kwargs["event_loop_cycle_trace"] = cycle_trace callback_handler(start=True) @@ -136,6 +120,7 @@ def event_loop_cycle( metrics: Metrics # Retry loop for handling throttling exceptions + current_delay = INITIAL_DELAY for attempt in range(MAX_ATTEMPTS): model_id = model.config.get("model_id") if hasattr(model, "config") else None model_invoke_span = tracer.start_model_invoke_span( @@ -145,14 +130,19 @@ def event_loop_cycle( ) try: - stop_reason, message, usage, metrics, kwargs["request_state"] = stream_messages( - model, - system_prompt, - messages, - tool_config, - callback_handler, - **kwargs, - ) + # TODO: As part of the migration to async-iterator, we will continue moving callback_handler calls up the + # call stack. At this point, we converted all events that were previously passed to the handler in + # `stream_messages` into yielded events that now have the "callback" key. To maintain backwards + # compatability, we need to combine the event with kwargs before passing to the handler. This we will + # revisit when migrating to strongly typed events. + for event in stream_messages(model, system_prompt, messages, tool_config): + if "callback" in event: + inputs = {**event["callback"], **(kwargs if "delta" in event["callback"] else {})} + callback_handler(**inputs) + else: + stop_reason, message, usage, metrics = event["stop"] + kwargs.setdefault("request_state", {}) + if model_invoke_span: tracer.end_model_invoke_span(model_invoke_span, message, usage) break # Success! Break out of retry loop @@ -168,7 +158,7 @@ def event_loop_cycle( # Handle throttling errors with exponential backoff should_retry, current_delay = handle_throttling_error( - e, attempt, MAX_ATTEMPTS, INITIAL_DELAY, MAX_DELAY, callback_handler, kwargs + e, attempt, MAX_ATTEMPTS, current_delay, MAX_DELAY, callback_handler, kwargs ) if should_retry: continue @@ -226,7 +216,7 @@ def event_loop_cycle( ) # End the cycle and return results - event_loop_metrics.end_cycle(cycle_start_time, cycle_trace) + event_loop_metrics.end_cycle(cycle_start_time, cycle_trace, attributes) if cycle_span: tracer.end_event_loop_cycle_span( span=cycle_span, @@ -309,26 +299,6 @@ def recurse_event_loop( ) -def prepare_next_cycle(kwargs: Dict[str, Any], event_loop_metrics: EventLoopMetrics) -> Dict[str, Any]: - """Prepare state for the next event loop cycle. - - Updates the keyword arguments with the current event loop metrics and stores the current cycle ID as the parent - cycle ID for the next cycle. This maintains the parent-child relationship between cycles for tracing and metrics. - - Args: - kwargs: Current keyword arguments containing event loop state. - event_loop_metrics: The metrics object tracking event loop execution. - - Returns: - Updated keyword arguments ready for the next cycle. - """ - # Store parent cycle ID - kwargs["event_loop_metrics"] = event_loop_metrics - kwargs["event_loop_parent_cycle_id"] = kwargs["event_loop_cycle_id"] - - return kwargs - - def _handle_tool_execution( stop_reason: StopReason, message: Message, @@ -369,7 +339,7 @@ def _handle_tool_execution( kwargs (Dict[str, Any]): Additional keyword arguments, including request state. Returns: - Tuple[StopReason, Message, EventLoopMetrics, Dict[str, Any]]: + Tuple[StopReason, Message, EventLoopMetrics, Dict[str, Any]]: - The stop reason, - The updated message, - The updated event loop metrics, @@ -379,7 +349,6 @@ def _handle_tool_execution( if not tool_uses: return stop_reason, message, event_loop_metrics, kwargs["request_state"] - tool_handler_process = partial( tool_handler.process, messages=messages, @@ -402,7 +371,9 @@ def _handle_tool_execution( parallel_tool_executor=tool_execution_handler, ) - kwargs = prepare_next_cycle(kwargs, event_loop_metrics) + # Store parent cycle ID for the next cycle + kwargs["event_loop_metrics"] = event_loop_metrics + kwargs["event_loop_parent_cycle_id"] = kwargs["event_loop_cycle_id"] tool_result_message: Message = { "role": "user", diff --git a/src/strands/event_loop/streaming.py b/src/strands/event_loop/streaming.py index 6e8a806f..0e9d472b 100644 --- a/src/strands/event_loop/streaming.py +++ b/src/strands/event_loop/streaming.py @@ -2,7 +2,7 @@ import json import logging -from typing import Any, Dict, Iterable, List, Optional, Tuple +from typing import Any, Generator, Iterable, Optional from ..types.content import ContentBlock, Message, Messages from ..types.models import Model @@ -80,7 +80,7 @@ def handle_message_start(event: MessageStartEvent, message: Message) -> Message: return message -def handle_content_block_start(event: ContentBlockStartEvent) -> Dict[str, Any]: +def handle_content_block_start(event: ContentBlockStartEvent) -> dict[str, Any]: """Handles the start of a content block by extracting tool usage information if any. Args: @@ -102,31 +102,31 @@ def handle_content_block_start(event: ContentBlockStartEvent) -> Dict[str, Any]: def handle_content_block_delta( - event: ContentBlockDeltaEvent, state: Dict[str, Any], callback_handler: Any, **kwargs: Any -) -> Dict[str, Any]: + event: ContentBlockDeltaEvent, state: dict[str, Any] +) -> tuple[dict[str, Any], dict[str, Any]]: """Handles content block delta updates by appending text, tool input, or reasoning content to the state. Args: event: Delta event. state: The current state of message processing. - callback_handler: Callback for processing events as they happen. - **kwargs: Additional keyword arguments to pass to the callback handler. Returns: Updated state with appended text or tool input. """ delta_content = event["delta"] + callback_event = {} + if "toolUse" in delta_content: if "input" not in state["current_tool_use"]: state["current_tool_use"]["input"] = "" state["current_tool_use"]["input"] += delta_content["toolUse"]["input"] - callback_handler(delta=delta_content, current_tool_use=state["current_tool_use"], **kwargs) + callback_event["callback"] = {"delta": delta_content, "current_tool_use": state["current_tool_use"]} elif "text" in delta_content: state["text"] += delta_content["text"] - callback_handler(data=delta_content["text"], delta=delta_content, **kwargs) + callback_event["callback"] = {"data": delta_content["text"], "delta": delta_content} elif "reasoningContent" in delta_content: if "text" in delta_content["reasoningContent"]: @@ -134,29 +134,27 @@ def handle_content_block_delta( state["reasoningText"] = "" state["reasoningText"] += delta_content["reasoningContent"]["text"] - callback_handler( - reasoningText=delta_content["reasoningContent"]["text"], - delta=delta_content, - reasoning=True, - **kwargs, - ) + callback_event["callback"] = { + "reasoningText": delta_content["reasoningContent"]["text"], + "delta": delta_content, + "reasoning": True, + } elif "signature" in delta_content["reasoningContent"]: if "signature" not in state: state["signature"] = "" state["signature"] += delta_content["reasoningContent"]["signature"] - callback_handler( - reasoning_signature=delta_content["reasoningContent"]["signature"], - delta=delta_content, - reasoning=True, - **kwargs, - ) + callback_event["callback"] = { + "reasoning_signature": delta_content["reasoningContent"]["signature"], + "delta": delta_content, + "reasoning": True, + } - return state + return state, callback_event -def handle_content_block_stop(state: Dict[str, Any]) -> Dict[str, Any]: +def handle_content_block_stop(state: dict[str, Any]) -> dict[str, Any]: """Handles the end of a content block by finalizing tool usage, text content, or reasoning content. Args: @@ -165,7 +163,7 @@ def handle_content_block_stop(state: Dict[str, Any]) -> Dict[str, Any]: Returns: Updated state with finalized content block. """ - content: List[ContentBlock] = state["content"] + content: list[ContentBlock] = state["content"] current_tool_use = state["current_tool_use"] text = state["text"] @@ -223,7 +221,7 @@ def handle_message_stop(event: MessageStopEvent) -> StopReason: return event["stopReason"] -def handle_redact_content(event: RedactContentEvent, messages: Messages, state: Dict[str, Any]) -> None: +def handle_redact_content(event: RedactContentEvent, messages: Messages, state: dict[str, Any]) -> None: """Handles redacting content from the input or output. Args: @@ -238,7 +236,7 @@ def handle_redact_content(event: RedactContentEvent, messages: Messages, state: state["message"]["content"] = [{"text": event["redactAssistantContentMessage"]}] -def extract_usage_metrics(event: MetadataEvent) -> Tuple[Usage, Metrics]: +def extract_usage_metrics(event: MetadataEvent) -> tuple[Usage, Metrics]: """Extracts usage metrics from the metadata chunk. Args: @@ -255,25 +253,20 @@ def extract_usage_metrics(event: MetadataEvent) -> Tuple[Usage, Metrics]: def process_stream( chunks: Iterable[StreamEvent], - callback_handler: Any, messages: Messages, - **kwargs: Any, -) -> Tuple[StopReason, Message, Usage, Metrics, Any]: +) -> Generator[dict[str, Any], None, None]: """Processes the response stream from the API, constructing the final message and extracting usage metrics. Args: chunks: The chunks of the response stream from the model. - callback_handler: Callback for processing events as they happen. messages: The agents messages. - **kwargs: Additional keyword arguments that will be passed to the callback handler. - And also returned in the request_state. Returns: - The reason for stopping, the constructed message, the usage metrics, and the updated request state. + The reason for stopping, the constructed message, and the usage metrics. """ stop_reason: StopReason = "end_turn" - state: Dict[str, Any] = { + state: dict[str, Any] = { "message": {"role": "assistant", "content": []}, "text": "", "current_tool_use": {}, @@ -285,18 +278,16 @@ def process_stream( usage: Usage = Usage(inputTokens=0, outputTokens=0, totalTokens=0) metrics: Metrics = Metrics(latencyMs=0) - kwargs.setdefault("request_state", {}) - for chunk in chunks: - # Callback handler call here allows each event to be visible to the caller - callback_handler(event=chunk) + yield {"callback": {"event": chunk}} if "messageStart" in chunk: state["message"] = handle_message_start(chunk["messageStart"], state["message"]) elif "contentBlockStart" in chunk: state["current_tool_use"] = handle_content_block_start(chunk["contentBlockStart"]) elif "contentBlockDelta" in chunk: - state = handle_content_block_delta(chunk["contentBlockDelta"], state, callback_handler, **kwargs) + state, callback_event = handle_content_block_delta(chunk["contentBlockDelta"], state) + yield callback_event elif "contentBlockStop" in chunk: state = handle_content_block_stop(state) elif "messageStop" in chunk: @@ -306,7 +297,7 @@ def process_stream( elif "redactContent" in chunk: handle_redact_content(chunk["redactContent"], messages, state) - return stop_reason, state["message"], usage, metrics, kwargs["request_state"] + yield {"stop": (stop_reason, state["message"], usage, metrics)} def stream_messages( @@ -314,9 +305,7 @@ def stream_messages( system_prompt: Optional[str], messages: Messages, tool_config: Optional[ToolConfig], - callback_handler: Any, - **kwargs: Any, -) -> Tuple[StopReason, Message, Usage, Metrics, Any]: +) -> Generator[dict[str, Any], None, None]: """Streams messages to the model and processes the response. Args: @@ -324,12 +313,9 @@ def stream_messages( system_prompt: The system prompt to send. messages: List of messages to send. tool_config: Configuration for the tools to use. - callback_handler: Callback for processing events as they happen. - **kwargs: Additional keyword arguments that will be passed to the callback handler. - And also returned in the request_state. Returns: - The reason for stopping, the final message, the usage metrics, and updated request state. + The reason for stopping, the final message, and the usage metrics """ logger.debug("model=<%s> | streaming messages", model) @@ -337,4 +323,4 @@ def stream_messages( tool_specs = [tool["toolSpec"] for tool in tool_config.get("tools", [])] or None if tool_config else None chunks = model.converse(messages, tool_specs, system_prompt) - return process_stream(chunks, callback_handler, messages, **kwargs) + yield from process_stream(chunks, messages) diff --git a/src/strands/models/anthropic.py b/src/strands/models/anthropic.py index 57394e2c..51089d47 100644 --- a/src/strands/models/anthropic.py +++ b/src/strands/models/anthropic.py @@ -7,11 +7,15 @@ import json import logging import mimetypes -from typing import Any, Iterable, Optional, TypedDict, cast +from typing import Any, Callable, Iterable, Optional, Type, TypedDict, TypeVar, cast import anthropic +from pydantic import BaseModel from typing_extensions import Required, Unpack, override +from ..event_loop.streaming import process_stream +from ..handlers.callback_handler import PrintingCallbackHandler +from ..tools import convert_pydantic_to_tool_spec from ..types.content import ContentBlock, Messages from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException from ..types.models import Model @@ -20,6 +24,8 @@ logger = logging.getLogger(__name__) +T = TypeVar("T", bound=BaseModel) + class AnthropicModel(Model): """Anthropic model provider implementation.""" @@ -356,10 +362,10 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]: with self.client.messages.stream(**request) as stream: for event in stream: if event.type in AnthropicModel.EVENT_TYPES: - yield event.dict() + yield event.model_dump() usage = event.message.usage # type: ignore - yield {"type": "metadata", "usage": usage.dict()} + yield {"type": "metadata", "usage": usage.model_dump()} except anthropic.RateLimitError as error: raise ModelThrottledException(str(error)) from error @@ -369,3 +375,42 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]: raise ContextWindowOverflowException(str(error)) from error raise error + + @override + def structured_output( + self, output_model: Type[T], prompt: Messages, callback_handler: Optional[Callable] = None + ) -> T: + """Get structured output from the model. + + Args: + output_model(Type[BaseModel]): The output model to use for the agent. + prompt(Messages): The prompt messages to use for the agent. + callback_handler(Optional[Callable]): Optional callback handler for processing events. Defaults to None. + """ + callback_handler = callback_handler or PrintingCallbackHandler() + tool_spec = convert_pydantic_to_tool_spec(output_model) + + response = self.converse(messages=prompt, tool_specs=[tool_spec]) + for event in process_stream(response, prompt): + if "callback" in event: + callback_handler(**event["callback"]) + else: + stop_reason, messages, _, _ = event["stop"] + + if stop_reason != "tool_use": + raise ValueError("No valid tool use or tool use input was found in the Anthropic response.") + + content = messages["content"] + output_response: dict[str, Any] | None = None + for block in content: + # if the tool use name doesn't match the tool spec name, skip, and if the block is not a tool use, skip. + # if the tool use name never matches, raise an error. + if block.get("toolUse") and block["toolUse"]["name"] == tool_spec["name"]: + output_response = block["toolUse"]["input"] + else: + continue + + if output_response is None: + raise ValueError("No valid tool use or tool use input was found in the Anthropic response.") + + return output_model(**output_response) diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index 9bbcca7d..a5ffb539 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -6,13 +6,17 @@ import json import logging import os -from typing import Any, Iterable, List, Literal, Optional, cast +from typing import Any, Callable, Iterable, List, Literal, Optional, Type, TypeVar, cast import boto3 from botocore.config import Config as BotocoreConfig from botocore.exceptions import ClientError +from pydantic import BaseModel from typing_extensions import TypedDict, Unpack, override +from ..event_loop.streaming import process_stream +from ..handlers.callback_handler import PrintingCallbackHandler +from ..tools import convert_pydantic_to_tool_spec from ..types.content import Messages from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException from ..types.models import Model @@ -29,6 +33,8 @@ "too many total text bytes", ] +T = TypeVar("T", bound=BaseModel) + class BedrockModel(Model): """AWS Bedrock model provider implementation. @@ -112,8 +118,17 @@ def __init__( logger.debug("config=<%s> | initializing", self.config) + region_for_boto = region_name or os.getenv("AWS_REGION") + if region_for_boto is None: + region_for_boto = "us-west-2" + logger.warning("defaulted to us-west-2 because no region was specified") + logger.warning( + "issue=<%s> | this behavior will change in an upcoming release", + "https://github.com/strands-agents/sdk-python/issues/238", + ) + session = boto_session or boto3.Session( - region_name=region_name or os.getenv("AWS_REGION") or "us-west-2", + region_name=region_for_boto, ) # Add strands-agents to the request user agent @@ -477,3 +492,42 @@ def _find_detected_and_blocked_policy(self, input: Any) -> bool: return self._find_detected_and_blocked_policy(item) # Otherwise return False return False + + @override + def structured_output( + self, output_model: Type[T], prompt: Messages, callback_handler: Optional[Callable] = None + ) -> T: + """Get structured output from the model. + + Args: + output_model(Type[BaseModel]): The output model to use for the agent. + prompt(Messages): The prompt messages to use for the agent. + callback_handler(Optional[Callable]): Optional callback handler for processing events. Defaults to None. + """ + callback_handler = callback_handler or PrintingCallbackHandler() + tool_spec = convert_pydantic_to_tool_spec(output_model) + + response = self.converse(messages=prompt, tool_specs=[tool_spec]) + for event in process_stream(response, prompt): + if "callback" in event: + callback_handler(**event["callback"]) + else: + stop_reason, messages, _, _ = event["stop"] + + if stop_reason != "tool_use": + raise ValueError("No valid tool use or tool use input was found in the Bedrock response.") + + content = messages["content"] + output_response: dict[str, Any] | None = None + for block in content: + # if the tool use name doesn't match the tool spec name, skip, and if the block is not a tool use, skip. + # if the tool use name never matches, raise an error. + if block.get("toolUse") and block["toolUse"]["name"] == tool_spec["name"]: + output_response = block["toolUse"]["input"] + else: + continue + + if output_response is None: + raise ValueError("No valid tool use or tool use input was found in the Bedrock response.") + + return output_model(**output_response) diff --git a/src/strands/models/litellm.py b/src/strands/models/litellm.py index 62f16d31..66138186 100644 --- a/src/strands/models/litellm.py +++ b/src/strands/models/litellm.py @@ -3,17 +3,22 @@ - Docs: https://docs.litellm.ai/ """ +import json import logging -from typing import Any, Optional, TypedDict, cast +from typing import Any, Callable, Optional, Type, TypedDict, TypeVar, cast import litellm +from litellm.utils import supports_response_schema +from pydantic import BaseModel from typing_extensions import Unpack, override -from ..types.content import ContentBlock +from ..types.content import ContentBlock, Messages from .openai import OpenAIModel logger = logging.getLogger(__name__) +T = TypeVar("T", bound=BaseModel) + class LiteLLMModel(OpenAIModel): """LiteLLM model provider implementation.""" @@ -97,3 +102,43 @@ def format_request_message_content(cls, content: ContentBlock) -> dict[str, Any] } return super().format_request_message_content(content) + + @override + def structured_output( + self, output_model: Type[T], prompt: Messages, callback_handler: Optional[Callable] = None + ) -> T: + """Get structured output from the model. + + Args: + output_model(Type[BaseModel]): The output model to use for the agent. + prompt(Messages): The prompt messages to use for the agent. + callback_handler(Optional[Callable]): Optional callback handler for processing events. Defaults to None. + + """ + # The LiteLLM `Client` inits with Chat(). + # Chat() inits with self.completions + # completions() has a method `create()` which wraps the real completion API of Litellm + response = self.client.chat.completions.create( + model=self.get_config()["model_id"], + messages=super().format_request(prompt)["messages"], + response_format=output_model, + ) + + if not supports_response_schema(self.get_config()["model_id"]): + raise ValueError("Model does not support response_format") + if len(response.choices) > 1: + raise ValueError("Multiple choices found in the response.") + + # Find the first choice with tool_calls + for choice in response.choices: + if choice.finish_reason == "tool_calls": + try: + # Parse the tool call content as JSON + tool_call_data = json.loads(choice.message.content) + # Instantiate the output model with the parsed data + return output_model(**tool_call_data) + except (json.JSONDecodeError, TypeError, ValueError) as e: + raise ValueError(f"Failed to parse or load content into model: {e}") from e + + # If no tool_calls found, raise an error + raise ValueError("No tool_calls found in response") diff --git a/src/strands/models/llamaapi.py b/src/strands/models/llamaapi.py index 583db2f2..755e07ad 100644 --- a/src/strands/models/llamaapi.py +++ b/src/strands/models/llamaapi.py @@ -8,10 +8,11 @@ import json import logging import mimetypes -from typing import Any, Iterable, Optional, cast +from typing import Any, Callable, Iterable, Optional, Type, TypeVar, cast import llama_api_client from llama_api_client import LlamaAPIClient +from pydantic import BaseModel from typing_extensions import TypedDict, Unpack, override from ..types.content import ContentBlock, Messages @@ -22,6 +23,8 @@ logger = logging.getLogger(__name__) +T = TypeVar("T", bound=BaseModel) + class LlamaAPIModel(Model): """Llama API model provider implementation.""" @@ -384,3 +387,31 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]: # we may have a metrics event here if metrics_event: yield {"chunk_type": "metadata", "data": metrics_event} + + @override + def structured_output( + self, output_model: Type[T], prompt: Messages, callback_handler: Optional[Callable] = None + ) -> T: + """Get structured output from the model. + + Args: + output_model(Type[BaseModel]): The output model to use for the agent. + prompt(Messages): The prompt messages to use for the agent. + callback_handler(Optional[Callable]): Optional callback handler for processing events. Defaults to None. + + Raises: + NotImplementedError: Structured output is not currently supported for LlamaAPI models. + """ + # response_format: ResponseFormat = { + # "type": "json_schema", + # "json_schema": { + # "name": output_model.__name__, + # "schema": output_model.model_json_schema(), + # }, + # } + # response = self.client.chat.completions.create( + # model=self.config["model_id"], + # messages=self.format_request(prompt)["messages"], + # response_format=response_format, + # ) + raise NotImplementedError("Strands sdk-python does not implement this in the Llama API Preview.") diff --git a/src/strands/models/ollama.py b/src/strands/models/ollama.py index 7ed12216..b062fe14 100644 --- a/src/strands/models/ollama.py +++ b/src/strands/models/ollama.py @@ -5,9 +5,10 @@ import json import logging -from typing import Any, Iterable, Optional, cast +from typing import Any, Callable, Iterable, Optional, Type, TypeVar, cast from ollama import Client as OllamaClient +from pydantic import BaseModel from typing_extensions import TypedDict, Unpack, override from ..types.content import ContentBlock, Messages @@ -17,6 +18,8 @@ logger = logging.getLogger(__name__) +T = TypeVar("T", bound=BaseModel) + class OllamaModel(Model): """Ollama model provider implementation. @@ -310,3 +313,25 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]: yield {"chunk_type": "content_stop", "data_type": "text"} yield {"chunk_type": "message_stop", "data": "tool_use" if tool_requested else event.done_reason} yield {"chunk_type": "metadata", "data": event} + + @override + def structured_output( + self, output_model: Type[T], prompt: Messages, callback_handler: Optional[Callable] = None + ) -> T: + """Get structured output from the model. + + Args: + output_model(Type[BaseModel]): The output model to use for the agent. + prompt(Messages): The prompt messages to use for the agent. + callback_handler(Optional[Callable]): Optional callback handler for processing events. Defaults to None. + """ + formatted_request = self.format_request(messages=prompt) + formatted_request["format"] = output_model.model_json_schema() + formatted_request["stream"] = False + response = self.client.chat(**formatted_request) + + try: + content = response.message.content.strip() + return output_model.model_validate_json(content) + except Exception as e: + raise ValueError(f"Failed to parse or load content into model: {e}") from e diff --git a/src/strands/models/openai.py b/src/strands/models/openai.py index 6cbef664..783ce379 100644 --- a/src/strands/models/openai.py +++ b/src/strands/models/openai.py @@ -4,15 +4,20 @@ """ import logging -from typing import Any, Iterable, Optional, Protocol, TypedDict, cast +from typing import Any, Callable, Iterable, Optional, Protocol, Type, TypedDict, TypeVar, cast import openai +from openai.types.chat.parsed_chat_completion import ParsedChatCompletion +from pydantic import BaseModel from typing_extensions import Unpack, override +from ..types.content import Messages from ..types.models import OpenAIModel as SAOpenAIModel logger = logging.getLogger(__name__) +T = TypeVar("T", bound=BaseModel) + class Client(Protocol): """Protocol defining the OpenAI-compatible interface for the underlying provider client.""" @@ -125,3 +130,35 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]: _ = event yield {"chunk_type": "metadata", "data": event.usage} + + @override + def structured_output( + self, output_model: Type[T], prompt: Messages, callback_handler: Optional[Callable] = None + ) -> T: + """Get structured output from the model. + + Args: + output_model(Type[BaseModel]): The output model to use for the agent. + prompt(Messages): The prompt messages to use for the agent. + callback_handler(Optional[Callable]): Optional callback handler for processing events. Defaults to None. + """ + response: ParsedChatCompletion = self.client.beta.chat.completions.parse( # type: ignore + model=self.get_config()["model_id"], + messages=super().format_request(prompt)["messages"], + response_format=output_model, + ) + + parsed: T | None = None + # Find the first choice with tool_calls + if len(response.choices) > 1: + raise ValueError("Multiple choices found in the OpenAI response.") + + for choice in response.choices: + if isinstance(choice.message.parsed, output_model): + parsed = choice.message.parsed + break + + if parsed: + return parsed + else: + raise ValueError("No valid tool use or tool use input was found in the OpenAI response.") diff --git a/src/strands/multiagent/__init__.py b/src/strands/multiagent/__init__.py new file mode 100644 index 00000000..1cef1425 --- /dev/null +++ b/src/strands/multiagent/__init__.py @@ -0,0 +1,13 @@ +"""Multiagent capabilities for Strands Agents. + +This module provides support for multiagent systems, including agent-to-agent (A2A) +communication protocols and coordination mechanisms. + +Submodules: + a2a: Implementation of the Agent-to-Agent (A2A) protocol, which enables + standardized communication between agents. +""" + +from . import a2a + +__all__ = ["a2a"] diff --git a/src/strands/multiagent/a2a/__init__.py b/src/strands/multiagent/a2a/__init__.py new file mode 100644 index 00000000..c5425618 --- /dev/null +++ b/src/strands/multiagent/a2a/__init__.py @@ -0,0 +1,14 @@ +"""Agent-to-Agent (A2A) communication protocol implementation for Strands Agents. + +This module provides classes and utilities for enabling Strands Agents to communicate +with other agents using the Agent-to-Agent (A2A) protocol. + +Docs: https://google-a2a.github.io/A2A/latest/ + +Classes: + A2AAgent: A wrapper that adapts a Strands Agent to be A2A-compatible. +""" + +from .agent import A2AAgent + +__all__ = ["A2AAgent"] diff --git a/src/strands/multiagent/a2a/agent.py b/src/strands/multiagent/a2a/agent.py new file mode 100644 index 00000000..4359100d --- /dev/null +++ b/src/strands/multiagent/a2a/agent.py @@ -0,0 +1,149 @@ +"""A2A-compatible wrapper for Strands Agent. + +This module provides the A2AAgent class, which adapts a Strands Agent to the A2A protocol, +allowing it to be used in A2A-compatible systems. +""" + +import logging +from typing import Any, Literal + +import uvicorn +from a2a.server.apps import A2AFastAPIApplication, A2AStarletteApplication +from a2a.server.request_handlers import DefaultRequestHandler +from a2a.server.tasks import InMemoryTaskStore +from a2a.types import AgentCapabilities, AgentCard, AgentSkill +from fastapi import FastAPI +from starlette.applications import Starlette + +from ...agent.agent import Agent as SAAgent +from .executor import StrandsA2AExecutor + +logger = logging.getLogger(__name__) + + +class A2AAgent: + """A2A-compatible wrapper for Strands Agent.""" + + def __init__( + self, + agent: SAAgent, + *, + # AgentCard + host: str = "0.0.0.0", + port: int = 9000, + version: str = "0.0.1", + ): + """Initialize an A2A-compatible agent from a Strands agent. + + Args: + agent: The Strands Agent to wrap with A2A compatibility. + name: The name of the agent, used in the AgentCard. + description: A description of the agent's capabilities, used in the AgentCard. + host: The hostname or IP address to bind the A2A server to. Defaults to "0.0.0.0". + port: The port to bind the A2A server to. Defaults to 9000. + version: The version of the agent. Defaults to "0.0.1". + """ + self.host = host + self.port = port + self.http_url = f"http://{self.host}:{self.port}/" + self.version = version + self.strands_agent = agent + self.name = self.strands_agent.name + self.description = self.strands_agent.description + # TODO: enable configurable capabilities and request handler + self.capabilities = AgentCapabilities() + self.request_handler = DefaultRequestHandler( + agent_executor=StrandsA2AExecutor(self.strands_agent), + task_store=InMemoryTaskStore(), + ) + logger.info("Strands' integration with A2A is experimental. Be aware of frequent breaking changes.") + + @property + def public_agent_card(self) -> AgentCard: + """Get the public AgentCard for this agent. + + The AgentCard contains metadata about the agent, including its name, + description, URL, version, skills, and capabilities. This information + is used by other agents and systems to discover and interact with this agent. + + Returns: + AgentCard: The public agent card containing metadata about this agent. + + Raises: + ValueError: If name or description is None or empty. + """ + if not self.name: + raise ValueError("A2A agent name cannot be None or empty") + if not self.description: + raise ValueError("A2A agent description cannot be None or empty") + + return AgentCard( + name=self.name, + description=self.description, + url=self.http_url, + version=self.version, + skills=self.agent_skills, + defaultInputModes=["text"], + defaultOutputModes=["text"], + capabilities=self.capabilities, + ) + + @property + def agent_skills(self) -> list[AgentSkill]: + """Get the list of skills this agent provides. + + Skills represent specific capabilities that the agent can perform. + Strands agent tools are adapted to A2A skills. + + Returns: + list[AgentSkill]: A list of skills this agent provides. + """ + # TODO: translate Strands tools (native & MCP) to skills + return [] + + def to_starlette_app(self) -> Starlette: + """Create a Starlette application for serving this agent via HTTP. + + This method creates a Starlette application that can be used to serve + the agent via HTTP using the A2A protocol. + + Returns: + Starlette: A Starlette application configured to serve this agent. + """ + return A2AStarletteApplication(agent_card=self.public_agent_card, http_handler=self.request_handler).build() + + def to_fastapi_app(self) -> FastAPI: + """Create a FastAPI application for serving this agent via HTTP. + + This method creates a FastAPI application that can be used to serve + the agent via HTTP using the A2A protocol. + + Returns: + FastAPI: A FastAPI application configured to serve this agent. + """ + return A2AFastAPIApplication(agent_card=self.public_agent_card, http_handler=self.request_handler).build() + + def serve(self, app_type: Literal["fastapi", "starlette"] = "starlette", **kwargs: Any) -> None: + """Start the A2A server with the specified application type. + + This method starts an HTTP server that exposes the agent via the A2A protocol. + The server can be implemented using either FastAPI or Starlette, depending on + the specified app_type. + + Args: + app_type: The type of application to serve, either "fastapi" or "starlette". + Defaults to "starlette". + **kwargs: Additional keyword arguments to pass to uvicorn.run. + """ + try: + logger.info("Starting Strands A2A server...") + if app_type == "fastapi": + uvicorn.run(self.to_fastapi_app(), host=self.host, port=self.port, **kwargs) + else: + uvicorn.run(self.to_starlette_app(), host=self.host, port=self.port, **kwargs) + except KeyboardInterrupt: + logger.warning("Strands A2A server shutdown requested (KeyboardInterrupt).") + except Exception: + logger.exception("Strands A2A server encountered exception.") + finally: + logger.info("Strands A2A server has shutdown.") diff --git a/src/strands/multiagent/a2a/executor.py b/src/strands/multiagent/a2a/executor.py new file mode 100644 index 00000000..b7a7af09 --- /dev/null +++ b/src/strands/multiagent/a2a/executor.py @@ -0,0 +1,67 @@ +"""Strands Agent executor for the A2A protocol. + +This module provides the StrandsA2AExecutor class, which adapts a Strands Agent +to be used as an executor in the A2A protocol. It handles the execution of agent +requests and the conversion of Strands Agent responses to A2A events. +""" + +import logging + +from a2a.server.agent_execution import AgentExecutor, RequestContext +from a2a.server.events import EventQueue +from a2a.types import UnsupportedOperationError +from a2a.utils import new_agent_text_message +from a2a.utils.errors import ServerError + +from ...agent.agent import Agent as SAAgent +from ...agent.agent_result import AgentResult as SAAgentResult + +log = logging.getLogger(__name__) + + +class StrandsA2AExecutor(AgentExecutor): + """Executor that adapts a Strands Agent to the A2A protocol.""" + + def __init__(self, agent: SAAgent): + """Initialize a StrandsA2AExecutor. + + Args: + agent: The Strands Agent to adapt to the A2A protocol. + """ + self.agent = agent + + async def execute( + self, + context: RequestContext, + event_queue: EventQueue, + ) -> None: + """Execute a request using the Strands Agent and send the response as A2A events. + + This method executes the user's input using the Strands Agent and converts + the agent's response to A2A events, which are then sent to the event queue. + + Args: + context: The A2A request context, containing the user's input and other metadata. + event_queue: The A2A event queue, used to send response events. + """ + result: SAAgentResult = self.agent(context.get_user_input()) + if result.message and "content" in result.message: + for content_block in result.message["content"]: + if "text" in content_block: + await event_queue.enqueue_event(new_agent_text_message(content_block["text"])) + + async def cancel(self, context: RequestContext, event_queue: EventQueue) -> None: + """Cancel an ongoing execution. + + This method is called when a request is cancelled. Currently, cancellation + is not supported, so this method raises an UnsupportedOperationError. + + Args: + context: The A2A request context. + event_queue: The A2A event queue. + + Raises: + ServerError: Always raised with an UnsupportedOperationError, as cancellation + is not currently supported. + """ + raise ServerError(error=UnsupportedOperationError()) diff --git a/src/strands/telemetry/__init__.py b/src/strands/telemetry/__init__.py index 15981216..21dd6ebf 100644 --- a/src/strands/telemetry/__init__.py +++ b/src/strands/telemetry/__init__.py @@ -3,7 +3,8 @@ This module provides metrics and tracing functionality. """ -from .metrics import EventLoopMetrics, Trace, metrics_to_string +from .config import get_otel_resource +from .metrics import EventLoopMetrics, MetricsClient, Trace, metrics_to_string from .tracer import Tracer, get_tracer __all__ = [ @@ -12,4 +13,6 @@ "metrics_to_string", "Tracer", "get_tracer", + "MetricsClient", + "get_otel_resource", ] diff --git a/src/strands/telemetry/config.py b/src/strands/telemetry/config.py new file mode 100644 index 00000000..9f5a05fd --- /dev/null +++ b/src/strands/telemetry/config.py @@ -0,0 +1,33 @@ +"""OpenTelemetry configuration and setup utilities for Strands agents. + +This module provides centralized configuration and initialization functionality +for OpenTelemetry components and other telemetry infrastructure shared across Strands applications. +""" + +from importlib.metadata import version + +from opentelemetry.sdk.resources import Resource + + +def get_otel_resource() -> Resource: + """Create a standard OpenTelemetry resource with service information. + + This function implements a singleton pattern - it will return the same + Resource object for the same service_name parameter. + + Args: + service_name: Name of the service for OpenTelemetry. + + Returns: + Resource object with standard service information. + """ + resource = Resource.create( + { + "service.name": __name__, + "service.version": version("strands-agents"), + "telemetry.sdk.name": "opentelemetry", + "telemetry.sdk.language": "python", + } + ) + + return resource diff --git a/src/strands/telemetry/metrics.py b/src/strands/telemetry/metrics.py index cd70819b..332ab2ae 100644 --- a/src/strands/telemetry/metrics.py +++ b/src/strands/telemetry/metrics.py @@ -6,6 +6,10 @@ from dataclasses import dataclass, field from typing import Any, Dict, Iterable, List, Optional, Set, Tuple +import opentelemetry.metrics as metrics_api +from opentelemetry.metrics import Counter, Histogram, Meter + +from ..telemetry import metrics_constants as constants from ..types.content import Message from ..types.streaming import Metrics, Usage from ..types.tools import ToolUse @@ -117,22 +121,34 @@ class ToolMetrics: error_count: int = 0 total_time: float = 0.0 - def add_call(self, tool: ToolUse, duration: float, success: bool) -> None: + def add_call( + self, + tool: ToolUse, + duration: float, + success: bool, + metrics_client: "MetricsClient", + attributes: Optional[Dict[str, Any]] = None, + ) -> None: """Record a new tool call with its outcome. Args: tool: The tool that was called. duration: How long the call took in seconds. success: Whether the call was successful. + metrics_client: The metrics client for recording the metrics. + attributes: attributes of the metrics. """ self.tool = tool # Update with latest tool state self.call_count += 1 self.total_time += duration - + metrics_client.tool_call_count.add(1, attributes=attributes) + metrics_client.tool_duration.record(duration, attributes=attributes) if success: self.success_count += 1 + metrics_client.tool_success_count.add(1, attributes=attributes) else: self.error_count += 1 + metrics_client.tool_error_count.add(1, attributes=attributes) @dataclass @@ -155,32 +171,53 @@ class EventLoopMetrics: accumulated_usage: Usage = field(default_factory=lambda: Usage(inputTokens=0, outputTokens=0, totalTokens=0)) accumulated_metrics: Metrics = field(default_factory=lambda: Metrics(latencyMs=0)) - def start_cycle(self) -> Tuple[float, Trace]: + @property + def _metrics_client(self) -> "MetricsClient": + """Get the singleton MetricsClient instance.""" + return MetricsClient() + + def start_cycle( + self, + attributes: Optional[Dict[str, Any]] = None, + ) -> Tuple[float, Trace]: """Start a new event loop cycle and create a trace for it. + Args: + attributes: attributes of the metrics. + Returns: A tuple containing the start time and the cycle trace object. """ + self._metrics_client.event_loop_cycle_count.add(1, attributes=attributes) + self._metrics_client.event_loop_start_cycle.add(1, attributes=attributes) self.cycle_count += 1 start_time = time.time() cycle_trace = Trace(f"Cycle {self.cycle_count}", start_time=start_time) self.traces.append(cycle_trace) return start_time, cycle_trace - def end_cycle(self, start_time: float, cycle_trace: Trace) -> None: + def end_cycle(self, start_time: float, cycle_trace: Trace, attributes: Optional[Dict[str, Any]] = None) -> None: """End the current event loop cycle and record its duration. Args: start_time: The timestamp when the cycle started. cycle_trace: The trace object for this cycle. + attributes: attributes of the metrics. """ + self._metrics_client.event_loop_end_cycle.add(1, attributes) end_time = time.time() duration = end_time - start_time + self._metrics_client.event_loop_cycle_duration.record(duration, attributes) self.cycle_durations.append(duration) cycle_trace.end(end_time) def add_tool_usage( - self, tool: ToolUse, duration: float, tool_trace: Trace, success: bool, message: Message + self, + tool: ToolUse, + duration: float, + tool_trace: Trace, + success: bool, + message: Message, ) -> None: """Record metrics for a tool invocation. @@ -203,8 +240,16 @@ def add_tool_usage( tool_trace.raw_name = f"{tool_name} - {tool_use_id}" tool_trace.add_message(message) - self.tool_metrics.setdefault(tool_name, ToolMetrics(tool)).add_call(tool, duration, success) - + self.tool_metrics.setdefault(tool_name, ToolMetrics(tool)).add_call( + tool, + duration, + success, + self._metrics_client, + attributes={ + "tool_name": tool_name, + "tool_use_id": tool_use_id, + }, + ) tool_trace.end() def update_usage(self, usage: Usage) -> None: @@ -213,6 +258,8 @@ def update_usage(self, usage: Usage) -> None: Args: usage: The usage data to add to the accumulated totals. """ + self._metrics_client.event_loop_input_tokens.record(usage["inputTokens"]) + self._metrics_client.event_loop_output_tokens.record(usage["outputTokens"]) self.accumulated_usage["inputTokens"] += usage["inputTokens"] self.accumulated_usage["outputTokens"] += usage["outputTokens"] self.accumulated_usage["totalTokens"] += usage["totalTokens"] @@ -223,6 +270,7 @@ def update_metrics(self, metrics: Metrics) -> None: Args: metrics: The metrics data to add to the accumulated totals. """ + self._metrics_client.event_loop_latency.record(metrics["latencyMs"]) self.accumulated_metrics["latencyMs"] += metrics["latencyMs"] def get_summary(self) -> Dict[str, Any]: @@ -355,3 +403,74 @@ def metrics_to_string(event_loop_metrics: EventLoopMetrics, allowed_names: Optio A formatted string representation of the metrics. """ return "\n".join(_metrics_summary_to_lines(event_loop_metrics, allowed_names or set())) + + +class MetricsClient: + """Singleton client for managing OpenTelemetry metrics instruments. + + The actual metrics export destination (console, OTLP endpoint, etc.) is configured + through OpenTelemetry SDK configuration by users, not by this client. + """ + + _instance: Optional["MetricsClient"] = None + meter: Meter + event_loop_cycle_count: Counter + event_loop_start_cycle: Counter + event_loop_end_cycle: Counter + event_loop_cycle_duration: Histogram + event_loop_latency: Histogram + event_loop_input_tokens: Histogram + event_loop_output_tokens: Histogram + + tool_call_count: Counter + tool_success_count: Counter + tool_error_count: Counter + tool_duration: Histogram + + def __new__(cls) -> "MetricsClient": + """Create or return the singleton instance of MetricsClient. + + Returns: + The single MetricsClient instance. + """ + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + def __init__(self) -> None: + """Initialize the MetricsClient. + + This method only runs once due to the singleton pattern. + Sets up the OpenTelemetry meter and creates metric instruments. + """ + if hasattr(self, "meter"): + return + + logger.info("Creating Strands MetricsClient") + meter_provider: metrics_api.MeterProvider = metrics_api.get_meter_provider() + self.meter = meter_provider.get_meter(__name__) + self.create_instruments() + + def create_instruments(self) -> None: + """Create and initialize all OpenTelemetry metric instruments.""" + self.event_loop_cycle_count = self.meter.create_counter( + name=constants.STRANDS_EVENT_LOOP_CYCLE_COUNT, unit="Count" + ) + self.event_loop_start_cycle = self.meter.create_counter( + name=constants.STRANDS_EVENT_LOOP_START_CYCLE, unit="Count" + ) + self.event_loop_end_cycle = self.meter.create_counter(name=constants.STRANDS_EVENT_LOOP_END_CYCLE, unit="Count") + self.event_loop_cycle_duration = self.meter.create_histogram( + name=constants.STRANDS_EVENT_LOOP_CYCLE_DURATION, unit="s" + ) + self.event_loop_latency = self.meter.create_histogram(name=constants.STRANDS_EVENT_LOOP_LATENCY, unit="ms") + self.tool_call_count = self.meter.create_counter(name=constants.STRANDS_TOOL_CALL_COUNT, unit="Count") + self.tool_success_count = self.meter.create_counter(name=constants.STRANDS_TOOL_SUCCESS_COUNT, unit="Count") + self.tool_error_count = self.meter.create_counter(name=constants.STRANDS_TOOL_ERROR_COUNT, unit="Count") + self.tool_duration = self.meter.create_histogram(name=constants.STRANDS_TOOL_DURATION, unit="s") + self.event_loop_input_tokens = self.meter.create_histogram( + name=constants.STRANDS_EVENT_LOOP_INPUT_TOKENS, unit="token" + ) + self.event_loop_output_tokens = self.meter.create_histogram( + name=constants.STRANDS_EVENT_LOOP_OUTPUT_TOKENS, unit="token" + ) diff --git a/src/strands/telemetry/metrics_constants.py b/src/strands/telemetry/metrics_constants.py new file mode 100644 index 00000000..b622eebf --- /dev/null +++ b/src/strands/telemetry/metrics_constants.py @@ -0,0 +1,15 @@ +"""Metrics that are emitted in Strands-Agents.""" + +STRANDS_EVENT_LOOP_CYCLE_COUNT = "strands.event_loop.cycle_count" +STRANDS_EVENT_LOOP_START_CYCLE = "strands.event_loop.start_cycle" +STRANDS_EVENT_LOOP_END_CYCLE = "strands.event_loop.end_cycle" +STRANDS_TOOL_CALL_COUNT = "strands.tool.call_count" +STRANDS_TOOL_SUCCESS_COUNT = "strands.tool.success_count" +STRANDS_TOOL_ERROR_COUNT = "strands.tool.error_count" + +# Histograms +STRANDS_EVENT_LOOP_LATENCY = "strands.event_loop.latency" +STRANDS_TOOL_DURATION = "strands.tool.duration" +STRANDS_EVENT_LOOP_CYCLE_DURATION = "strands.event_loop.cycle_duration" +STRANDS_EVENT_LOOP_INPUT_TOKENS = "strands.event_loop.input.tokens" +STRANDS_EVENT_LOOP_OUTPUT_TOKENS = "strands.event_loop.output.tokens" diff --git a/src/strands/telemetry/tracer.py b/src/strands/telemetry/tracer.py index e9a37a4a..813c90e1 100644 --- a/src/strands/telemetry/tracer.py +++ b/src/strands/telemetry/tracer.py @@ -8,20 +8,19 @@ import logging import os from datetime import date, datetime, timezone -from importlib.metadata import version from typing import Any, Dict, Mapping, Optional import opentelemetry.trace as trace_api from opentelemetry import propagate from opentelemetry.baggage.propagation import W3CBaggagePropagator from opentelemetry.propagators.composite import CompositePropagator -from opentelemetry.sdk.resources import Resource from opentelemetry.sdk.trace import TracerProvider as SDKTracerProvider from opentelemetry.sdk.trace.export import BatchSpanProcessor, ConsoleSpanExporter, SimpleSpanProcessor from opentelemetry.trace import Span, StatusCode from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator from ..agent.agent_result import AgentResult +from ..telemetry import get_otel_resource from ..types.content import Message, Messages from ..types.streaming import Usage from ..types.tools import ToolResult, ToolUse @@ -151,7 +150,6 @@ def __init__( self.otlp_headers = otlp_headers or {} self.tracer_provider: Optional[trace_api.TracerProvider] = None self.tracer: Optional[trace_api.Tracer] = None - propagate.set_global_textmap( CompositePropagator( [ @@ -173,15 +171,7 @@ def _initialize_tracer(self) -> None: self.tracer = self.tracer_provider.get_tracer(self.service_name) return - # Create resource with service information - resource = Resource.create( - { - "service.name": self.service_name, - "service.version": version("strands-agents"), - "telemetry.sdk.name": "opentelemetry", - "telemetry.sdk.language": "python", - } - ) + resource = get_otel_resource() # Create tracer provider self.tracer_provider = SDKTracerProvider(resource=resource) @@ -216,6 +206,7 @@ def _initialize_tracer(self) -> None: batch_processor = BatchSpanProcessor(otlp_exporter) self.tracer_provider.add_span_processor(batch_processor) logger.info("endpoint=<%s> | OTLP exporter configured with endpoint", endpoint) + except Exception as e: logger.exception("error=<%s> | Failed to configure OTLP exporter", e) elif self.otlp_endpoint and self.tracer_provider: diff --git a/src/strands/tools/__init__.py b/src/strands/tools/__init__.py index b3ee1566..12979015 100644 --- a/src/strands/tools/__init__.py +++ b/src/strands/tools/__init__.py @@ -4,6 +4,7 @@ """ from .decorator import tool +from .structured_output import convert_pydantic_to_tool_spec from .thread_pool_executor import ThreadPoolExecutorWrapper from .tools import FunctionTool, InvalidToolUseNameException, PythonAgentTool, normalize_schema, normalize_tool_spec @@ -15,4 +16,5 @@ "normalize_schema", "normalize_tool_spec", "ThreadPoolExecutorWrapper", + "convert_pydantic_to_tool_spec", ] diff --git a/src/strands/tools/structured_output.py b/src/strands/tools/structured_output.py new file mode 100644 index 00000000..5421cdc6 --- /dev/null +++ b/src/strands/tools/structured_output.py @@ -0,0 +1,415 @@ +"""Tools for converting Pydantic models to Bedrock tools.""" + +from typing import Any, Dict, Optional, Type, Union + +from pydantic import BaseModel + +from ..types.tools import ToolSpec + + +def _flatten_schema(schema: Dict[str, Any]) -> Dict[str, Any]: + """Flattens a JSON schema by removing $defs and resolving $ref references. + + Handles required vs optional fields properly. + + Args: + schema: The JSON schema to flatten + + Returns: + Flattened JSON schema + """ + # Extract required fields list + required_fields = schema.get("required", []) + + # Initialize the flattened schema with basic properties + flattened = { + "type": schema.get("type", "object"), + "properties": {}, + } + + # Add title if present + if "title" in schema: + flattened["title"] = schema["title"] + + # Add description from schema if present, or use model docstring + if "description" in schema and schema["description"]: + flattened["description"] = schema["description"] + + # Process properties + required_props: list[str] = [] + if "properties" in schema: + required_props = [] + for prop_name, prop_value in schema["properties"].items(): + # Process the property and add to flattened properties + is_required = prop_name in required_fields + + # If the property already has nested properties (expanded), preserve them + if "properties" in prop_value: + # This is an expanded nested schema, preserve its structure + processed_prop = { + "type": prop_value.get("type", "object"), + "description": prop_value.get("description", ""), + "properties": {}, + } + + # Process each nested property + for nested_prop_name, nested_prop_value in prop_value["properties"].items(): + processed_prop["properties"][nested_prop_name] = nested_prop_value + + # Copy required fields if present + if "required" in prop_value: + processed_prop["required"] = prop_value["required"] + else: + # Process as normal + processed_prop = _process_property(prop_value, schema.get("$defs", {}), is_required) + + flattened["properties"][prop_name] = processed_prop + + # Track which properties are actually required after processing + if is_required and "null" not in str(processed_prop.get("type", "")): + required_props.append(prop_name) + + # Add required fields if any (only those that are truly required after processing) + # Check if required props are empty, if so, raise an error because it means there is a circular reference + + if len(required_props) > 0: + flattened["required"] = required_props + else: + raise ValueError("Circular reference detected and not supported") + + return flattened + + +def _process_property( + prop: Dict[str, Any], + defs: Dict[str, Any], + is_required: bool = False, + fully_expand: bool = True, +) -> Dict[str, Any]: + """Process a property in a schema, resolving any references. + + Args: + prop: The property to process + defs: The definitions dictionary for resolving references + is_required: Whether this property is required + fully_expand: Whether to fully expand nested properties + + Returns: + Processed property + """ + result = {} + is_nullable = False + + # Handle anyOf for optional fields (like Optional[Type]) + if "anyOf" in prop: + # Check if this is an Optional[...] case (one null, one type) + null_type = False + non_null_type = None + + for option in prop["anyOf"]: + if option.get("type") == "null": + null_type = True + is_nullable = True + elif "$ref" in option: + ref_path = option["$ref"].split("/")[-1] + if ref_path in defs: + non_null_type = _process_schema_object(defs[ref_path], defs, fully_expand) + else: + # Handle missing reference path gracefully + raise ValueError(f"Missing reference: {ref_path}") + else: + non_null_type = option + + if null_type and non_null_type: + # For Optional fields, we mark as nullable but copy all properties from the non-null option + result = non_null_type.copy() if isinstance(non_null_type, dict) else {} + + # For type, ensure it includes "null" + if "type" in result and isinstance(result["type"], str): + result["type"] = [result["type"], "null"] + elif "type" in result and isinstance(result["type"], list) and "null" not in result["type"]: + result["type"].append("null") + elif "type" not in result: + # Default to object type if not specified + result["type"] = ["object", "null"] + + # Copy description if available in the property + if "description" in prop: + result["description"] = prop["description"] + + return result + + # Handle direct references + elif "$ref" in prop: + # Resolve reference + ref_path = prop["$ref"].split("/")[-1] + if ref_path in defs: + ref_dict = defs[ref_path] + # Process the referenced object to get a complete schema + result = _process_schema_object(ref_dict, defs, fully_expand) + else: + # Handle missing reference path gracefully + raise ValueError(f"Missing reference: {ref_path}") + + # For regular fields, copy all properties + for key, value in prop.items(): + if key not in ["$ref", "anyOf"]: + if isinstance(value, dict): + result[key] = _process_nested_dict(value, defs) + elif key == "type" and not is_required and not is_nullable: + # For non-required fields, ensure type is a list with "null" + if isinstance(value, str): + result[key] = [value, "null"] + elif isinstance(value, list) and "null" not in value: + result[key] = value + ["null"] + else: + result[key] = value + else: + result[key] = value + + return result + + +def _process_schema_object( + schema_obj: Dict[str, Any], defs: Dict[str, Any], fully_expand: bool = True +) -> Dict[str, Any]: + """Process a schema object, typically from $defs, to resolve all nested properties. + + Args: + schema_obj: The schema object to process + defs: The definitions dictionary for resolving references + fully_expand: Whether to fully expand nested properties + + Returns: + Processed schema object with all properties resolved + """ + result = {} + + # Copy basic attributes + for key, value in schema_obj.items(): + if key != "properties" and key != "required" and key != "$defs": + result[key] = value + + # Process properties if present + if "properties" in schema_obj: + result["properties"] = {} + required_props = [] + + # Get required fields list + required_fields = schema_obj.get("required", []) + + for prop_name, prop_value in schema_obj["properties"].items(): + # Process each property + is_required = prop_name in required_fields + processed = _process_property(prop_value, defs, is_required, fully_expand) + result["properties"][prop_name] = processed + + # Track which properties are actually required after processing + if is_required and "null" not in str(processed.get("type", "")): + required_props.append(prop_name) + + # Add required fields if any + if required_props: + result["required"] = required_props + + return result + + +def _process_nested_dict(d: Dict[str, Any], defs: Dict[str, Any]) -> Dict[str, Any]: + """Recursively processes nested dictionaries and resolves $ref references. + + Args: + d: The dictionary to process + defs: The definitions dictionary for resolving references + + Returns: + Processed dictionary + """ + result: Dict[str, Any] = {} + + # Handle direct reference + if "$ref" in d: + ref_path = d["$ref"].split("/")[-1] + if ref_path in defs: + ref_dict = defs[ref_path] + # Recursively process the referenced object + return _process_schema_object(ref_dict, defs) + else: + # Handle missing reference path gracefully + raise ValueError(f"Missing reference: {ref_path}") + + # Process each key-value pair + for key, value in d.items(): + if key == "$ref": + # Already handled above + continue + elif isinstance(value, dict): + result[key] = _process_nested_dict(value, defs) + elif isinstance(value, list): + # Process lists (like for enum values) + result[key] = [_process_nested_dict(item, defs) if isinstance(item, dict) else item for item in value] + else: + result[key] = value + + return result + + +def convert_pydantic_to_tool_spec( + model: Type[BaseModel], + description: Optional[str] = None, +) -> ToolSpec: + """Converts a Pydantic model to a tool description for the Amazon Bedrock Converse API. + + Handles optional vs. required fields, resolves $refs, and uses docstrings. + + Args: + model: The Pydantic model class to convert + description: Optional description of the tool's purpose + + Returns: + ToolSpec: Dict containing the Bedrock tool specification + """ + name = model.__name__ + + # Get the JSON schema + input_schema = model.model_json_schema() + + # Get model docstring for description if not provided + model_description = description + if not model_description and model.__doc__: + model_description = model.__doc__.strip() + + # Process all referenced models to ensure proper docstrings + # This step is important for gathering descriptions from referenced models + _process_referenced_models(input_schema, model) + + # Now, let's fully expand the nested models with all their properties + _expand_nested_properties(input_schema, model) + + # Flatten the schema + flattened_schema = _flatten_schema(input_schema) + + final_schema = flattened_schema + + # Construct the tool specification + return ToolSpec( + name=name, + description=model_description or f"{name} structured output tool", + inputSchema={"json": final_schema}, + ) + + +def _expand_nested_properties(schema: Dict[str, Any], model: Type[BaseModel]) -> None: + """Expand the properties of nested models in the schema to include their full structure. + + This updates the schema in place. + + Args: + schema: The JSON schema to process + model: The Pydantic model class + """ + # First, process the properties at this level + if "properties" not in schema: + return + + # Create a modified copy of the properties to avoid modifying while iterating + for prop_name, prop_info in list(schema["properties"].items()): + field = model.model_fields.get(prop_name) + if not field: + continue + + field_type = field.annotation + + # Handle Optional types + is_optional = False + if ( + field_type is not None + and hasattr(field_type, "__origin__") + and field_type.__origin__ is Union + and hasattr(field_type, "__args__") + ): + # Look for Optional[BaseModel] + for arg in field_type.__args__: + if arg is type(None): + is_optional = True + elif isinstance(arg, type) and issubclass(arg, BaseModel): + field_type = arg + + # If this is a BaseModel field, expand its properties with full details + if isinstance(field_type, type) and issubclass(field_type, BaseModel): + # Get the nested model's schema with all its properties + nested_model_schema = field_type.model_json_schema() + + # Create a properly expanded nested object + expanded_object = { + "type": ["object", "null"] if is_optional else "object", + "description": prop_info.get("description", field.description or f"The {prop_name}"), + "properties": {}, + } + + # Copy all properties from the nested schema + if "properties" in nested_model_schema: + expanded_object["properties"] = nested_model_schema["properties"] + + # Copy required fields + if "required" in nested_model_schema: + expanded_object["required"] = nested_model_schema["required"] + + # Replace the original property with this expanded version + schema["properties"][prop_name] = expanded_object + + +def _process_referenced_models(schema: Dict[str, Any], model: Type[BaseModel]) -> None: + """Process referenced models to ensure their docstrings are included. + + This updates the schema in place. + + Args: + schema: The JSON schema to process + model: The Pydantic model class + """ + # Process $defs to add docstrings from the referenced models + if "$defs" in schema: + # Look through model fields to find referenced models + for _, field in model.model_fields.items(): + field_type = field.annotation + + # Handle Optional types - with null checks + if field_type is not None and hasattr(field_type, "__origin__"): + origin = field_type.__origin__ + if origin is Union and hasattr(field_type, "__args__"): + # Find the non-None type in the Union (for Optional fields) + for arg in field_type.__args__: + if arg is not type(None): + field_type = arg + break + + # Check if this is a BaseModel subclass + if isinstance(field_type, type) and issubclass(field_type, BaseModel): + # Update $defs with this model's information + ref_name = field_type.__name__ + if ref_name in schema.get("$defs", {}): + ref_def = schema["$defs"][ref_name] + + # Add docstring as description if available + if field_type.__doc__ and not ref_def.get("description"): + ref_def["description"] = field_type.__doc__.strip() + + # Recursively process properties in the referenced model + _process_properties(ref_def, field_type) + + +def _process_properties(schema_def: Dict[str, Any], model: Type[BaseModel]) -> None: + """Process properties in a schema definition to add descriptions from field metadata. + + Args: + schema_def: The schema definition to update + model: The model class that defines the schema + """ + if "properties" in schema_def: + for prop_name, prop_info in schema_def["properties"].items(): + field = model.model_fields.get(prop_name) + + # Add field description if available and not already set + if field and field.description and not prop_info.get("description"): + prop_info["description"] = field.description diff --git a/src/strands/types/models/model.py b/src/strands/types/models/model.py index 23e74602..071c8a51 100644 --- a/src/strands/types/models/model.py +++ b/src/strands/types/models/model.py @@ -2,7 +2,9 @@ import abc import logging -from typing import Any, Iterable, Optional +from typing import Any, Callable, Iterable, Optional, Type, TypeVar + +from pydantic import BaseModel from ..content import Messages from ..streaming import StreamEvent @@ -10,6 +12,8 @@ logger = logging.getLogger(__name__) +T = TypeVar("T", bound=BaseModel) + class Model(abc.ABC): """Abstract base class for AI model implementations. @@ -38,6 +42,26 @@ def get_config(self) -> Any: """ pass + @abc.abstractmethod + # pragma: no cover + def structured_output( + self, output_model: Type[T], prompt: Messages, callback_handler: Optional[Callable] = None + ) -> T: + """Get structured output from the model. + + Args: + output_model(Type[BaseModel]): The output model to use for the agent. + prompt(Messages): The prompt messages to use for the agent. + callback_handler(Optional[Callable]): Optional callback handler for processing events. Defaults to None. + + Returns: + The structured output as a serialized instance of the output model. + + Raises: + ValidationException: The response format from the model does not match the output_model + """ + pass + @abc.abstractmethod # pragma: no cover def format_request( diff --git a/src/strands/types/models/openai.py b/src/strands/types/models/openai.py index 96f758d5..8ff37d35 100644 --- a/src/strands/types/models/openai.py +++ b/src/strands/types/models/openai.py @@ -11,8 +11,9 @@ import json import logging import mimetypes -from typing import Any, Optional, cast +from typing import Any, Callable, Optional, Type, TypeVar, cast +from pydantic import BaseModel from typing_extensions import override from ..content import ContentBlock, Messages @@ -22,6 +23,8 @@ logger = logging.getLogger(__name__) +T = TypeVar("T", bound=BaseModel) + class OpenAIModel(Model, abc.ABC): """Base OpenAI model provider implementation. @@ -31,6 +34,32 @@ class OpenAIModel(Model, abc.ABC): config: dict[str, Any] + @staticmethod + def b64encode(data: bytes) -> bytes: + """Base64 encode the provided data. + + If the data is already base64 encoded, we do nothing. + Note, this is a temporary method used to provide a warning to users who pass in base64 encoded data. In future + versions, images and documents will be base64 encoded on behalf of customers for consistency with the other + providers and general convenience. + + Args: + data: Data to encode. + + Returns: + Base64 encoded data. + """ + try: + base64.b64decode(data, validate=True) + logger.warning( + "issue=<%s> | base64 encoded images and documents will not be accepted in future versions", + "https://github.com/strands-agents/sdk-python/issues/252", + ) + except ValueError: + data = base64.b64encode(data) + + return data + @classmethod def format_request_message_content(cls, content: ContentBlock) -> dict[str, Any]: """Format an OpenAI compatible content block. @@ -57,7 +86,8 @@ def format_request_message_content(cls, content: ContentBlock) -> dict[str, Any] if "image" in content: mime_type = mimetypes.types_map.get(f".{content['image']['format']}", "application/octet-stream") - image_data = content["image"]["source"]["bytes"].decode("utf-8") + image_data = OpenAIModel.b64encode(content["image"]["source"]["bytes"]).decode("utf-8") + return { "image_url": { "detail": "auto", @@ -262,3 +292,16 @@ def format_chunk(self, event: dict[str, Any]) -> StreamEvent: case _: raise RuntimeError(f"chunk_type=<{event['chunk_type']} | unknown type") + + @override + def structured_output( + self, output_model: Type[T], prompt: Messages, callback_handler: Optional[Callable] = None + ) -> T: + """Get structured output from the model. + + Args: + output_model(Type[BaseModel]): The output model to use for the agent. + prompt(Messages): The prompt to use for the agent. + callback_handler(Optional[Callable]): Optional callback handler for processing events. Defaults to None. + """ + return output_model() diff --git a/tests-integ/test_model_anthropic.py b/tests-integ/test_model_anthropic.py index 1b0412c9..50033f8f 100644 --- a/tests-integ/test_model_anthropic.py +++ b/tests-integ/test_model_anthropic.py @@ -1,6 +1,7 @@ import os import pytest +from pydantic import BaseModel import strands from strands import Agent @@ -33,7 +34,7 @@ def tool_weather() -> str: @pytest.fixture def system_prompt(): - return "You are an AI assistant that uses & instead of ." + return "You are an AI assistant." @pytest.fixture @@ -46,4 +47,17 @@ def test_agent(agent): result = agent("What is the time and weather in New York?") text = result.message["content"][0]["text"].lower() - assert all(string in text for string in ["12:00", "sunny", "&"]) + assert all(string in text for string in ["12:00", "sunny"]) + + +@pytest.mark.skipif("ANTHROPIC_API_KEY" not in os.environ, reason="ANTHROPIC_API_KEY environment variable missing") +def test_structured_output(model): + class Weather(BaseModel): + time: str + weather: str + + agent = Agent(model=model) + result = agent.structured_output(Weather, "The time is 12:00 and the weather is sunny") + assert isinstance(result, Weather) + assert result.time == "12:00" + assert result.weather == "sunny" diff --git a/tests-integ/test_model_bedrock.py b/tests-integ/test_model_bedrock.py index a6a29aa9..5378a9b2 100644 --- a/tests-integ/test_model_bedrock.py +++ b/tests-integ/test_model_bedrock.py @@ -1,4 +1,5 @@ import pytest +from pydantic import BaseModel import strands from strands import Agent @@ -118,3 +119,33 @@ def calculator(expression: str) -> float: agent("What is 123 + 456?") assert tool_was_called + + +def test_structured_output_streaming(streaming_model): + """Test structured output with streaming model.""" + + class Weather(BaseModel): + time: str + weather: str + + agent = Agent(model=streaming_model) + + result = agent.structured_output(Weather, "The time is 12:00 and the weather is sunny") + assert isinstance(result, Weather) + assert result.time == "12:00" + assert result.weather == "sunny" + + +def test_structured_output_non_streaming(non_streaming_model): + """Test structured output with non-streaming model.""" + + class Weather(BaseModel): + time: str + weather: str + + agent = Agent(model=non_streaming_model) + + result = agent.structured_output(Weather, "The time is 12:00 and the weather is sunny") + assert isinstance(result, Weather) + assert result.time == "12:00" + assert result.weather == "sunny" diff --git a/tests-integ/test_model_litellm.py b/tests-integ/test_model_litellm.py index 86f6b42f..01a3e121 100644 --- a/tests-integ/test_model_litellm.py +++ b/tests-integ/test_model_litellm.py @@ -1,4 +1,5 @@ import pytest +from pydantic import BaseModel import strands from strands import Agent @@ -33,3 +34,16 @@ def test_agent(agent): text = result.message["content"][0]["text"].lower() assert all(string in text for string in ["12:00", "sunny"]) + + +def test_structured_output(model): + class Weather(BaseModel): + time: str + weather: str + + agent_no_tools = Agent(model=model) + + result = agent_no_tools.structured_output(Weather, "The time is 12:00 and the weather is sunny") + assert isinstance(result, Weather) + assert result.time == "12:00" + assert result.weather == "sunny" diff --git a/tests-integ/test_model_ollama.py b/tests-integ/test_model_ollama.py new file mode 100644 index 00000000..38b46821 --- /dev/null +++ b/tests-integ/test_model_ollama.py @@ -0,0 +1,47 @@ +import pytest +import requests +from pydantic import BaseModel + +from strands import Agent +from strands.models.ollama import OllamaModel + + +def is_server_available() -> bool: + try: + return requests.get("http://localhost:11434").ok + except requests.exceptions.ConnectionError: + return False + + +@pytest.fixture +def model(): + return OllamaModel(host="http://localhost:11434", model_id="llama3.3:70b") + + +@pytest.fixture +def agent(model): + return Agent(model=model) + + +@pytest.mark.skipif(not is_server_available(), reason="Local Ollama endpoint not available at localhost:11434") +def test_agent(agent): + result = agent("Say 'hello world' with no other text") + assert isinstance(result.message["content"][0]["text"], str) + + +@pytest.mark.skipif(not is_server_available(), reason="Local Ollama endpoint not available at localhost:11434") +def test_structured_output(agent): + class Weather(BaseModel): + """Extract the time and weather. + + Time format: HH:MM + Weather: sunny, cloudy, rainy, etc. + """ + + time: str + weather: str + + result = agent.structured_output(Weather, "The time is 12:00 and the weather is sunny") + assert isinstance(result, Weather) + assert result.time == "12:00" + assert result.weather == "sunny" diff --git a/tests-integ/test_model_openai.py b/tests-integ/test_model_openai.py index c9046ad5..b0790ba0 100644 --- a/tests-integ/test_model_openai.py +++ b/tests-integ/test_model_openai.py @@ -1,6 +1,7 @@ import os import pytest +from pydantic import BaseModel import strands from strands import Agent @@ -44,3 +45,22 @@ def test_agent(agent): text = result.message["content"][0]["text"].lower() assert all(string in text for string in ["12:00", "sunny"]) + + +@pytest.mark.skipif( + "OPENAI_API_KEY" not in os.environ, + reason="OPENAI_API_KEY environment variable missing", +) +def test_structured_output(model): + class Weather(BaseModel): + """Extracts the time and weather from the user's message with the exact strings.""" + + time: str + weather: str + + agent = Agent(model=model) + + result = agent.structured_output(Weather, "The time is 12:00 and the weather is sunny") + assert isinstance(result, Weather) + assert result.time == "12:00" + assert result.weather == "sunny" diff --git a/tests/multiagent/__init__.py b/tests/multiagent/__init__.py new file mode 100644 index 00000000..b43bae53 --- /dev/null +++ b/tests/multiagent/__init__.py @@ -0,0 +1 @@ +"""Tests for the multiagent module.""" diff --git a/tests/multiagent/a2a/__init__.py b/tests/multiagent/a2a/__init__.py new file mode 100644 index 00000000..eb5487d9 --- /dev/null +++ b/tests/multiagent/a2a/__init__.py @@ -0,0 +1 @@ +"""Tests for the A2A module.""" diff --git a/tests/multiagent/a2a/conftest.py b/tests/multiagent/a2a/conftest.py new file mode 100644 index 00000000..558a4594 --- /dev/null +++ b/tests/multiagent/a2a/conftest.py @@ -0,0 +1,41 @@ +"""Common fixtures for A2A module tests.""" + +from unittest.mock import AsyncMock, MagicMock + +import pytest +from a2a.server.agent_execution import RequestContext +from a2a.server.events import EventQueue + +from strands.agent.agent import Agent as SAAgent +from strands.agent.agent_result import AgentResult as SAAgentResult + + +@pytest.fixture +def mock_strands_agent(): + """Create a mock Strands Agent for testing.""" + agent = MagicMock(spec=SAAgent) + agent.name = "Test Agent" + agent.description = "A test agent for unit testing" + + # Setup default response + mock_result = MagicMock(spec=SAAgentResult) + mock_result.message = {"content": [{"text": "Test response"}]} + agent.return_value = mock_result + + return agent + + +@pytest.fixture +def mock_request_context(): + """Create a mock RequestContext for testing.""" + context = MagicMock(spec=RequestContext) + context.get_user_input.return_value = "Test input" + return context + + +@pytest.fixture +def mock_event_queue(): + """Create a mock EventQueue for testing.""" + queue = MagicMock(spec=EventQueue) + queue.enqueue_event = AsyncMock() + return queue diff --git a/tests/multiagent/a2a/test_agent.py b/tests/multiagent/a2a/test_agent.py new file mode 100644 index 00000000..5558c2af --- /dev/null +++ b/tests/multiagent/a2a/test_agent.py @@ -0,0 +1,165 @@ +"""Tests for the A2AAgent class.""" + +from unittest.mock import patch + +import pytest +from a2a.types import AgentCapabilities, AgentCard +from fastapi import FastAPI +from starlette.applications import Starlette + +from strands.multiagent.a2a.agent import A2AAgent + + +def test_a2a_agent_initialization(mock_strands_agent): + """Test that A2AAgent initializes correctly with default values.""" + a2a_agent = A2AAgent(mock_strands_agent) + + assert a2a_agent.strands_agent == mock_strands_agent + assert a2a_agent.name == "Test Agent" + assert a2a_agent.description == "A test agent for unit testing" + assert a2a_agent.host == "0.0.0" + assert a2a_agent.port == 9000 + assert a2a_agent.http_url == "http://0.0.0:9000/" + assert a2a_agent.version == "0.0.1" + assert isinstance(a2a_agent.capabilities, AgentCapabilities) + + +def test_a2a_agent_initialization_with_custom_values(mock_strands_agent): + """Test that A2AAgent initializes correctly with custom values.""" + a2a_agent = A2AAgent( + mock_strands_agent, + host="127.0.0.1", + port=8080, + version="1.0.0", + ) + + assert a2a_agent.host == "127.0.0.1" + assert a2a_agent.port == 8080 + assert a2a_agent.http_url == "http://127.0.0.1:8080/" + assert a2a_agent.version == "1.0.0" + + +def test_public_agent_card(mock_strands_agent): + """Test that public_agent_card returns a valid AgentCard.""" + a2a_agent = A2AAgent(mock_strands_agent) + + card = a2a_agent.public_agent_card + + assert isinstance(card, AgentCard) + assert card.name == "Test Agent" + assert card.description == "A test agent for unit testing" + assert card.url == "http://0.0.0:9000/" + assert card.version == "0.0.1" + assert card.defaultInputModes == ["text"] + assert card.defaultOutputModes == ["text"] + assert card.skills == [] + assert card.capabilities == a2a_agent.capabilities + + +def test_public_agent_card_with_missing_name(mock_strands_agent): + """Test that public_agent_card raises ValueError when name is missing.""" + mock_strands_agent.name = "" + a2a_agent = A2AAgent(mock_strands_agent) + + with pytest.raises(ValueError, match="A2A agent name cannot be None or empty"): + _ = a2a_agent.public_agent_card + + +def test_public_agent_card_with_missing_description(mock_strands_agent): + """Test that public_agent_card raises ValueError when description is missing.""" + mock_strands_agent.description = "" + a2a_agent = A2AAgent(mock_strands_agent) + + with pytest.raises(ValueError, match="A2A agent description cannot be None or empty"): + _ = a2a_agent.public_agent_card + + +def test_agent_skills(mock_strands_agent): + """Test that agent_skills returns an empty list (current implementation).""" + a2a_agent = A2AAgent(mock_strands_agent) + + skills = a2a_agent.agent_skills + + assert isinstance(skills, list) + assert len(skills) == 0 + + +def test_to_starlette_app(mock_strands_agent): + """Test that to_starlette_app returns a Starlette application.""" + a2a_agent = A2AAgent(mock_strands_agent) + + app = a2a_agent.to_starlette_app() + + assert isinstance(app, Starlette) + + +def test_to_fastapi_app(mock_strands_agent): + """Test that to_fastapi_app returns a FastAPI application.""" + a2a_agent = A2AAgent(mock_strands_agent) + + app = a2a_agent.to_fastapi_app() + + assert isinstance(app, FastAPI) + + +@patch("uvicorn.run") +def test_serve_with_starlette(mock_run, mock_strands_agent): + """Test that serve starts a Starlette server by default.""" + a2a_agent = A2AAgent(mock_strands_agent) + + a2a_agent.serve() + + mock_run.assert_called_once() + args, kwargs = mock_run.call_args + assert isinstance(args[0], Starlette) + assert kwargs["host"] == "0.0.0" + assert kwargs["port"] == 9000 + + +@patch("uvicorn.run") +def test_serve_with_fastapi(mock_run, mock_strands_agent): + """Test that serve starts a FastAPI server when specified.""" + a2a_agent = A2AAgent(mock_strands_agent) + + a2a_agent.serve(app_type="fastapi") + + mock_run.assert_called_once() + args, kwargs = mock_run.call_args + assert isinstance(args[0], FastAPI) + assert kwargs["host"] == "0.0.0" + assert kwargs["port"] == 9000 + + +@patch("uvicorn.run") +def test_serve_with_custom_kwargs(mock_run, mock_strands_agent): + """Test that serve passes additional kwargs to uvicorn.run.""" + a2a_agent = A2AAgent(mock_strands_agent) + + a2a_agent.serve(log_level="debug", reload=True) + + mock_run.assert_called_once() + _, kwargs = mock_run.call_args + assert kwargs["log_level"] == "debug" + assert kwargs["reload"] is True + + +@patch("uvicorn.run", side_effect=KeyboardInterrupt) +def test_serve_handles_keyboard_interrupt(mock_run, mock_strands_agent, caplog): + """Test that serve handles KeyboardInterrupt gracefully.""" + a2a_agent = A2AAgent(mock_strands_agent) + + a2a_agent.serve() + + assert "Strands A2A server shutdown requested (KeyboardInterrupt)" in caplog.text + assert "Strands A2A server has shutdown" in caplog.text + + +@patch("uvicorn.run", side_effect=Exception("Test exception")) +def test_serve_handles_general_exception(mock_run, mock_strands_agent, caplog): + """Test that serve handles general exceptions gracefully.""" + a2a_agent = A2AAgent(mock_strands_agent) + + a2a_agent.serve() + + assert "Strands A2A server encountered exception" in caplog.text + assert "Strands A2A server has shutdown" in caplog.text diff --git a/tests/multiagent/a2a/test_executor.py b/tests/multiagent/a2a/test_executor.py new file mode 100644 index 00000000..2ac9bed9 --- /dev/null +++ b/tests/multiagent/a2a/test_executor.py @@ -0,0 +1,118 @@ +"""Tests for the StrandsA2AExecutor class.""" + +from unittest.mock import MagicMock + +import pytest +from a2a.types import UnsupportedOperationError +from a2a.utils.errors import ServerError + +from strands.agent.agent_result import AgentResult as SAAgentResult +from strands.multiagent.a2a.executor import StrandsA2AExecutor + + +def test_executor_initialization(mock_strands_agent): + """Test that StrandsA2AExecutor initializes correctly.""" + executor = StrandsA2AExecutor(mock_strands_agent) + + assert executor.agent == mock_strands_agent + + +@pytest.mark.asyncio +async def test_execute_with_text_response(mock_strands_agent, mock_request_context, mock_event_queue): + """Test that execute processes text responses correctly.""" + # Setup mock agent response + mock_result = MagicMock(spec=SAAgentResult) + mock_result.message = {"content": [{"text": "Test response"}]} + mock_strands_agent.return_value = mock_result + + # Create executor and call execute + executor = StrandsA2AExecutor(mock_strands_agent) + await executor.execute(mock_request_context, mock_event_queue) + + # Verify agent was called with correct input + mock_strands_agent.assert_called_once_with("Test input") + + # Verify event was enqueued + mock_event_queue.enqueue_event.assert_called_once() + args, _ = mock_event_queue.enqueue_event.call_args + event = args[0] + assert event.parts[0].root.text == "Test response" + + +@pytest.mark.asyncio +async def test_execute_with_multiple_text_blocks(mock_strands_agent, mock_request_context, mock_event_queue): + """Test that execute processes multiple text blocks correctly.""" + # Setup mock agent response with multiple text blocks + mock_result = MagicMock(spec=SAAgentResult) + mock_result.message = {"content": [{"text": "First response"}, {"text": "Second response"}]} + mock_strands_agent.return_value = mock_result + + # Create executor and call execute + executor = StrandsA2AExecutor(mock_strands_agent) + await executor.execute(mock_request_context, mock_event_queue) + + # Verify agent was called with correct input + mock_strands_agent.assert_called_once_with("Test input") + + # Verify events were enqueued + assert mock_event_queue.enqueue_event.call_count == 2 + + # Check first event + args1, _ = mock_event_queue.enqueue_event.call_args_list[0] + event1 = args1[0] + assert event1.parts[0].root.text == "First response" + + # Check second event + args2, _ = mock_event_queue.enqueue_event.call_args_list[1] + event2 = args2[0] + assert event2.parts[0].root.text == "Second response" + + +@pytest.mark.asyncio +async def test_execute_with_empty_response(mock_strands_agent, mock_request_context, mock_event_queue): + """Test that execute handles empty responses correctly.""" + # Setup mock agent response with empty content + mock_result = MagicMock(spec=SAAgentResult) + mock_result.message = {"content": []} + mock_strands_agent.return_value = mock_result + + # Create executor and call execute + executor = StrandsA2AExecutor(mock_strands_agent) + await executor.execute(mock_request_context, mock_event_queue) + + # Verify agent was called with correct input + mock_strands_agent.assert_called_once_with("Test input") + + # Verify no events were enqueued + mock_event_queue.enqueue_event.assert_not_called() + + +@pytest.mark.asyncio +async def test_execute_with_no_message(mock_strands_agent, mock_request_context, mock_event_queue): + """Test that execute handles responses with no message correctly.""" + # Setup mock agent response with no message + mock_result = MagicMock(spec=SAAgentResult) + mock_result.message = None + mock_strands_agent.return_value = mock_result + + # Create executor and call execute + executor = StrandsA2AExecutor(mock_strands_agent) + await executor.execute(mock_request_context, mock_event_queue) + + # Verify agent was called with correct input + mock_strands_agent.assert_called_once_with("Test input") + + # Verify no events were enqueued + mock_event_queue.enqueue_event.assert_not_called() + + +@pytest.mark.asyncio +async def test_cancel_raises_unsupported_operation_error(mock_strands_agent, mock_request_context, mock_event_queue): + """Test that cancel raises UnsupportedOperationError.""" + executor = StrandsA2AExecutor(mock_strands_agent) + + with pytest.raises(ServerError) as excinfo: + await executor.cancel(mock_request_context, mock_event_queue) + + # Verify the error is a ServerError containing an UnsupportedOperationError + assert isinstance(excinfo.value.error, UnsupportedOperationError) diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index d6f47be0..85d17544 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -7,6 +7,7 @@ from time import sleep import pytest +from pydantic import BaseModel import strands from strands import Agent @@ -793,6 +794,31 @@ def test_agent_callback_handler_custom_handler_used(): assert agent.callback_handler is custom_handler +# mock the User(name='Jane Doe', age=30, email='jane@doe.com') +class User(BaseModel): + """A user of the system.""" + + name: str + age: int + email: str + + +def test_agent_method_structured_output(agent): + # Mock the structured_output method on the model + expected_user = User(name="Jane Doe", age=30, email="jane@doe.com") + agent.model.structured_output = unittest.mock.Mock(return_value=expected_user) + + prompt = "Jane Doe is 30 years old and her email is jane@doe.com" + + result = agent.structured_output(User, prompt) + assert result == expected_user + + # Verify the model's structured_output was called with correct arguments + agent.model.structured_output.assert_called_once_with( + User, [{"role": "user", "content": [{"text": prompt}]}], agent.callback_handler + ) + + @pytest.mark.asyncio async def test_stream_async_returns_all_events(mock_event_loop_cycle): agent = Agent() diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py index 8c46e009..11f14503 100644 --- a/tests/strands/event_loop/test_event_loop.py +++ b/tests/strands/event_loop/test_event_loop.py @@ -11,6 +11,13 @@ from strands.types.exceptions import ContextWindowOverflowException, EventLoopException, ModelThrottledException +@pytest.fixture +def mock_time(): + """Fixture to mock the time module in the error_handler.""" + with unittest.mock.patch.object(strands.event_loop.error_handler, "time") as mock: + yield mock + + @pytest.fixture def model(): return unittest.mock.Mock() @@ -104,27 +111,6 @@ def mock_tracer(): return tracer -@pytest.mark.parametrize( - ("kwargs", "exp_state"), - [ - ( - {"request_state": {"key1": "value1"}}, - {"key1": "value1"}, - ), - ( - {}, - {}, - ), - ], -) -def test_initialize_state(kwargs, exp_state): - kwargs = strands.event_loop.event_loop.initialize_state(**kwargs) - - tru_state = kwargs["request_state"] - - assert tru_state == exp_state - - def test_event_loop_cycle_text_response( model, model_id, @@ -157,8 +143,8 @@ def test_event_loop_cycle_text_response( assert tru_stop_reason == exp_stop_reason and tru_message == exp_message and tru_request_state == exp_request_state -@unittest.mock.patch.object(strands.event_loop.error_handler, "time") def test_event_loop_cycle_text_response_throttling( + mock_time, model, model_id, system_prompt, @@ -191,6 +177,53 @@ def test_event_loop_cycle_text_response_throttling( exp_request_state = {} assert tru_stop_reason == exp_stop_reason and tru_message == exp_message and tru_request_state == exp_request_state + # Verify that sleep was called once with the initial delay + mock_time.sleep.assert_called_once() + + +def test_event_loop_cycle_exponential_backoff( + mock_time, + model, + model_id, + system_prompt, + messages, + tool_config, + callback_handler, + tool_handler, + tool_execution_handler, +): + """Test that the exponential backoff works correctly with multiple retries.""" + # Set up the model to raise throttling exceptions multiple times before succeeding + model.converse.side_effect = [ + ModelThrottledException("ThrottlingException | ConverseStream"), + ModelThrottledException("ThrottlingException | ConverseStream"), + ModelThrottledException("ThrottlingException | ConverseStream"), + [ + {"contentBlockDelta": {"delta": {"text": "test text"}}}, + {"contentBlockStop": {}}, + ], + ] + + tru_stop_reason, tru_message, _, tru_request_state = strands.event_loop.event_loop.event_loop_cycle( + model=model, + model_id=model_id, + system_prompt=system_prompt, + messages=messages, + tool_config=tool_config, + callback_handler=callback_handler, + tool_handler=tool_handler, + tool_execution_handler=tool_execution_handler, + ) + + # Verify the final response + assert tru_stop_reason == "end_turn" + assert tru_message == {"role": "assistant", "content": [{"text": "test text"}]} + assert tru_request_state == {} + + # Verify that sleep was called with increasing delays + # Initial delay is 4, then 8, then 16 + assert mock_time.sleep.call_count == 3 + assert mock_time.sleep.call_args_list == [call(4), call(8), call(16)] def test_event_loop_cycle_text_response_error( @@ -411,19 +444,6 @@ def test_event_loop_cycle_stop( assert tru_stop_reason == exp_stop_reason and tru_message == exp_message and tru_request_state == exp_request_state -def test_prepare_next_cycle(): - kwargs = {"event_loop_cycle_id": "c1"} - event_loop_metrics = strands.telemetry.metrics.EventLoopMetrics() - tru_result = strands.event_loop.event_loop.prepare_next_cycle(kwargs, event_loop_metrics) - exp_result = { - "event_loop_cycle_id": "c1", - "event_loop_parent_cycle_id": "c1", - "event_loop_metrics": event_loop_metrics, - } - - assert tru_result == exp_result - - def test_cycle_exception( model, system_prompt, @@ -679,3 +699,176 @@ def test_event_loop_cycle_with_parent_span( mock_tracer.start_event_loop_cycle_span.assert_called_once_with( event_loop_kwargs=unittest.mock.ANY, parent_span=parent_span, messages=messages ) + + +def test_event_loop_cycle_callback( + model, + model_id, + system_prompt, + messages, + tool_config, + callback_handler, + tool_handler, + tool_execution_handler, +): + model.converse.return_value = [ + {"contentBlockStart": {"start": {"toolUse": {"toolUseId": "123", "name": "test"}}}}, + {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"value"}'}}}}, + {"contentBlockStop": {}}, + {"contentBlockStart": {"start": {}}}, + {"contentBlockDelta": {"delta": {"reasoningContent": {"text": "value"}}}}, + {"contentBlockDelta": {"delta": {"reasoningContent": {"signature": "value"}}}}, + {"contentBlockStop": {}}, + {"contentBlockStart": {"start": {}}}, + {"contentBlockDelta": {"delta": {"text": "value"}}}, + {"contentBlockStop": {}}, + ] + + strands.event_loop.event_loop.event_loop_cycle( + model=model, + model_id=model_id, + system_prompt=system_prompt, + messages=messages, + tool_config=tool_config, + callback_handler=callback_handler, + tool_handler=tool_handler, + tool_execution_handler=tool_execution_handler, + ) + + callback_handler.assert_has_calls( + [ + call(start=True), + call(start_event_loop=True), + call(event={"contentBlockStart": {"start": {"toolUse": {"toolUseId": "123", "name": "test"}}}}), + call(event={"contentBlockDelta": {"delta": {"toolUse": {"input": '{"value"}'}}}}), + call( + delta={"toolUse": {"input": '{"value"}'}}, + current_tool_use={"toolUseId": "123", "name": "test", "input": {}}, + model_id="m1", + event_loop_cycle_id=unittest.mock.ANY, + request_state={}, + event_loop_cycle_trace=unittest.mock.ANY, + event_loop_cycle_span=None, + ), + call(event={"contentBlockStop": {}}), + call(event={"contentBlockStart": {"start": {}}}), + call(event={"contentBlockDelta": {"delta": {"reasoningContent": {"text": "value"}}}}), + call( + reasoningText="value", + delta={"reasoningContent": {"text": "value"}}, + reasoning=True, + model_id="m1", + event_loop_cycle_id=unittest.mock.ANY, + request_state={}, + event_loop_cycle_trace=unittest.mock.ANY, + event_loop_cycle_span=None, + ), + call(event={"contentBlockDelta": {"delta": {"reasoningContent": {"signature": "value"}}}}), + call( + reasoning_signature="value", + delta={"reasoningContent": {"signature": "value"}}, + reasoning=True, + model_id="m1", + event_loop_cycle_id=unittest.mock.ANY, + request_state={}, + event_loop_cycle_trace=unittest.mock.ANY, + event_loop_cycle_span=None, + ), + call(event={"contentBlockStop": {}}), + call(event={"contentBlockStart": {"start": {}}}), + call(event={"contentBlockDelta": {"delta": {"text": "value"}}}), + call( + data="value", + delta={"text": "value"}, + model_id="m1", + event_loop_cycle_id=unittest.mock.ANY, + request_state={}, + event_loop_cycle_trace=unittest.mock.ANY, + event_loop_cycle_span=None, + ), + call(event={"contentBlockStop": {}}), + call( + message={ + "role": "assistant", + "content": [ + {"toolUse": {"toolUseId": "123", "name": "test", "input": {}}}, + {"reasoningContent": {"reasoningText": {"text": "value", "signature": "value"}}}, + {"text": "value"}, + ], + }, + ), + ], + ) + + +def test_request_state_initialization(): + # Call without providing request_state + tru_stop_reason, tru_message, _, tru_request_state = strands.event_loop.event_loop.event_loop_cycle( + model=MagicMock(), + model_id=MagicMock(), + system_prompt=MagicMock(), + messages=MagicMock(), + tool_config=MagicMock(), + callback_handler=MagicMock(), + tool_handler=MagicMock(), + tool_execution_handler=MagicMock(), + ) + + # Verify request_state was initialized to empty dict + assert tru_request_state == {} + + # Call with pre-existing request_state + initial_request_state = {"key": "value"} + tru_stop_reason, tru_message, _, tru_request_state = strands.event_loop.event_loop.event_loop_cycle( + model=MagicMock(), + model_id=MagicMock(), + system_prompt=MagicMock(), + messages=MagicMock(), + tool_config=MagicMock(), + callback_handler=MagicMock(), + tool_handler=MagicMock(), + request_state=initial_request_state, + ) + + # Verify existing request_state was preserved + assert tru_request_state == initial_request_state + + +def test_prepare_next_cycle_in_tool_execution(model, tool_stream): + """Test that cycle ID and metrics are properly updated during tool execution.""" + model.converse.side_effect = [ + tool_stream, + [ + {"contentBlockStop": {}}, + ], + ] + + # Create a mock for recurse_event_loop to capture the kwargs passed to it + with unittest.mock.patch.object(strands.event_loop.event_loop, "recurse_event_loop") as mock_recurse: + # Set up mock to return a valid response + mock_recurse.return_value = ( + "end_turn", + {"role": "assistant", "content": [{"text": "test text"}]}, + strands.telemetry.metrics.EventLoopMetrics(), + {}, + ) + + # Call event_loop_cycle which should execute a tool and then call recurse_event_loop + strands.event_loop.event_loop.event_loop_cycle( + model=model, + model_id=MagicMock(), + system_prompt=MagicMock(), + messages=MagicMock(), + tool_config=MagicMock(), + callback_handler=MagicMock(), + tool_handler=MagicMock(), + tool_execution_handler=MagicMock(), + ) + + assert mock_recurse.called + + # Verify required properties are present + recursive_kwargs = mock_recurse.call_args[1] + assert "event_loop_metrics" in recursive_kwargs + assert "event_loop_parent_cycle_id" in recursive_kwargs + assert recursive_kwargs["event_loop_parent_cycle_id"] == recursive_kwargs["event_loop_cycle_id"] diff --git a/tests/strands/event_loop/test_streaming.py b/tests/strands/event_loop/test_streaming.py index c24e7e48..e91f4986 100644 --- a/tests/strands/event_loop/test_streaming.py +++ b/tests/strands/event_loop/test_streaming.py @@ -3,6 +3,7 @@ import pytest import strands +import strands.event_loop from strands.types.streaming import ( ContentBlockDeltaEvent, ContentBlockStartEvent, @@ -17,13 +18,6 @@ def moto_autouse(moto_env, moto_mock_aws): _ = moto_mock_aws -@pytest.fixture -def agent(): - mock = unittest.mock.Mock() - - return mock - - @pytest.mark.parametrize( ("messages", "exp_result"), [ @@ -81,7 +75,7 @@ def test_handle_content_block_start(chunk: ContentBlockStartEvent, exp_tool_use) @pytest.mark.parametrize( - ("event", "state", "exp_updated_state", "exp_handler_args"), + ("event", "state", "exp_updated_state", "callback_args"), [ # Tool Use - Existing input ( @@ -148,21 +142,13 @@ def test_handle_content_block_start(chunk: ContentBlockStartEvent, exp_tool_use) ), ], ) -def test_handle_content_block_delta(event: ContentBlockDeltaEvent, state, exp_updated_state, exp_handler_args): - if exp_handler_args: - exp_handler_args.update({"delta": event["delta"], "extra_arg": 1}) +def test_handle_content_block_delta(event: ContentBlockDeltaEvent, state, exp_updated_state, callback_args): + exp_callback_event = {"callback": {**callback_args, "delta": event["delta"]}} if callback_args else {} - tru_handler_args = {} - - def callback_handler(**kwargs): - tru_handler_args.update(kwargs) - - tru_updated_state = strands.event_loop.streaming.handle_content_block_delta( - event, state, callback_handler, extra_arg=1 - ) + tru_updated_state, tru_callback_event = strands.event_loop.streaming.handle_content_block_delta(event, state) assert tru_updated_state == exp_updated_state - assert tru_handler_args == exp_handler_args + assert tru_callback_event == exp_callback_event @pytest.mark.parametrize( @@ -275,8 +261,9 @@ def test_extract_usage_metrics(): @pytest.mark.parametrize( - ("response", "exp_stop_reason", "exp_message", "exp_usage", "exp_metrics", "exp_request_state", "exp_messages"), + ("response", "exp_events"), [ + # Standard Message ( [ {"messageStart": {"role": "assistant"}}, @@ -297,28 +284,127 @@ def test_extract_usage_metrics(): } }, ], - "tool_use", - { - "role": "assistant", - "content": [{"toolUse": {"toolUseId": "123", "name": "test", "input": {"key": "value"}}}], - }, - {"inputTokens": 1, "outputTokens": 1, "totalTokens": 1}, - {"latencyMs": 1}, - {"calls": 1}, - [{"role": "user", "content": [{"text": "Some input!"}]}], + [ + { + "callback": { + "event": { + "messageStart": { + "role": "assistant", + }, + }, + }, + }, + { + "callback": { + "event": { + "contentBlockStart": { + "start": { + "toolUse": { + "name": "test", + "toolUseId": "123", + }, + }, + }, + }, + }, + }, + { + "callback": { + "event": { + "contentBlockDelta": { + "delta": { + "toolUse": { + "input": '{"key": "value"}', + }, + }, + }, + }, + }, + }, + { + "callback": { + "current_tool_use": { + "input": { + "key": "value", + }, + "name": "test", + "toolUseId": "123", + }, + "delta": { + "toolUse": { + "input": '{"key": "value"}', + }, + }, + }, + }, + { + "callback": { + "event": { + "contentBlockStop": {}, + }, + }, + }, + { + "callback": { + "event": { + "messageStop": { + "stopReason": "tool_use", + }, + }, + }, + }, + { + "callback": { + "event": { + "metadata": { + "metrics": { + "latencyMs": 1, + }, + "usage": { + "inputTokens": 1, + "outputTokens": 1, + "totalTokens": 1, + }, + }, + }, + }, + }, + { + "stop": ( + "tool_use", + { + "role": "assistant", + "content": [{"toolUse": {"toolUseId": "123", "name": "test", "input": {"key": "value"}}}], + }, + {"inputTokens": 1, "outputTokens": 1, "totalTokens": 1}, + {"latencyMs": 1}, + ) + }, + ], ), + # Empty Message ( [{}], - "end_turn", - { - "role": "assistant", - "content": [], - }, - {"inputTokens": 0, "outputTokens": 0, "totalTokens": 0}, - {"latencyMs": 0}, - {}, - [{"role": "user", "content": [{"text": "Some input!"}]}], + [ + { + "callback": { + "event": {}, + }, + }, + { + "stop": ( + "end_turn", + { + "role": "assistant", + "content": [], + }, + {"inputTokens": 0, "outputTokens": 0, "totalTokens": 0}, + {"latencyMs": 0}, + ), + }, + ], ), + # Redacted Message ( [ {"messageStart": {"role": "assistant"}}, @@ -345,77 +431,161 @@ def test_extract_usage_metrics(): } }, ], - "guardrail_intervened", - { - "role": "assistant", - "content": [{"text": "REDACTED."}], - }, - {"inputTokens": 1, "outputTokens": 1, "totalTokens": 1}, - {"latencyMs": 1}, - {"calls": 1}, - [{"role": "user", "content": [{"text": "REDACTED"}]}], + [ + { + "callback": { + "event": { + "messageStart": { + "role": "assistant", + }, + }, + }, + }, + { + "callback": { + "event": { + "contentBlockStart": { + "start": {}, + }, + }, + }, + }, + { + "callback": { + "event": { + "contentBlockDelta": { + "delta": { + "text": "Hello!", + }, + }, + }, + }, + }, + { + "callback": { + "data": "Hello!", + "delta": { + "text": "Hello!", + }, + }, + }, + { + "callback": { + "event": { + "contentBlockStop": {}, + }, + }, + }, + { + "callback": { + "event": { + "messageStop": { + "stopReason": "guardrail_intervened", + }, + }, + }, + }, + { + "callback": { + "event": { + "redactContent": { + "redactAssistantContentMessage": "REDACTED.", + "redactUserContentMessage": "REDACTED", + }, + }, + }, + }, + { + "callback": { + "event": { + "metadata": { + "metrics": { + "latencyMs": 1, + }, + "usage": { + "inputTokens": 1, + "outputTokens": 1, + "totalTokens": 1, + }, + }, + }, + }, + }, + { + "stop": ( + "guardrail_intervened", + { + "role": "assistant", + "content": [{"text": "REDACTED."}], + }, + {"inputTokens": 1, "outputTokens": 1, "totalTokens": 1}, + {"latencyMs": 1}, + ), + }, + ], ), ], ) -def test_process_stream( - response, exp_stop_reason, exp_message, exp_usage, exp_metrics, exp_request_state, exp_messages -): - def callback_handler(**kwargs): - if "request_state" in kwargs: - kwargs["request_state"].setdefault("calls", 0) - kwargs["request_state"]["calls"] += 1 - - tru_messages = [{"role": "user", "content": [{"text": "Some input!"}]}] - - tru_stop_reason, tru_message, tru_usage, tru_metrics, tru_request_state = ( - strands.event_loop.streaming.process_stream(response, callback_handler, tru_messages) - ) - - assert tru_stop_reason == exp_stop_reason - assert tru_message == exp_message - assert tru_usage == exp_usage - assert tru_metrics == exp_metrics - assert tru_request_state == exp_request_state - assert tru_messages == exp_messages +def test_process_stream(response, exp_events): + messages = [{"role": "user", "content": [{"text": "Some input!"}]}] + stream = strands.event_loop.streaming.process_stream(response, messages) + tru_events = list(stream) + assert tru_events == exp_events -def test_stream_messages(agent): - def callback_handler(**kwargs): - if "request_state" in kwargs: - kwargs["request_state"].setdefault("calls", 0) - kwargs["request_state"]["calls"] += 1 +def test_stream_messages(): mock_model = unittest.mock.MagicMock() mock_model.converse.return_value = [ {"contentBlockDelta": {"delta": {"text": "test"}}}, {"contentBlockStop": {}}, ] - tru_stop_reason, tru_message, tru_usage, tru_metrics, tru_request_state = ( - strands.event_loop.streaming.stream_messages( - mock_model, - model_id="test_model", - system_prompt="test prompt", - messages=[{"role": "assistant", "content": [{"text": "a"}, {"text": " \n"}]}], - tool_config=None, - callback_handler=callback_handler, - agent=agent, - ) + stream = strands.event_loop.streaming.stream_messages( + mock_model, + system_prompt="test prompt", + messages=[{"role": "assistant", "content": [{"text": "a"}, {"text": " \n"}]}], + tool_config=None, ) - exp_stop_reason = "end_turn" - exp_message = {"role": "assistant", "content": [{"text": "test"}]} - exp_usage = {"inputTokens": 0, "outputTokens": 0, "totalTokens": 0} - exp_metrics = {"latencyMs": 0} - exp_request_state = {"calls": 1} - - assert ( - tru_stop_reason == exp_stop_reason - and tru_message == exp_message - and tru_usage == exp_usage - and tru_metrics == exp_metrics - and tru_request_state == exp_request_state - ) + tru_events = list(stream) + exp_events = [ + { + "callback": { + "event": { + "contentBlockDelta": { + "delta": { + "text": "test", + }, + }, + }, + }, + }, + { + "callback": { + "data": "test", + "delta": { + "text": "test", + }, + }, + }, + { + "callback": { + "event": { + "contentBlockStop": {}, + }, + }, + }, + { + "stop": ( + "end_turn", + {"role": "assistant", "content": [{"text": "test"}]}, + {"inputTokens": 0, "outputTokens": 0, "totalTokens": 0}, + {"latencyMs": 0}, + ) + }, + ] + assert tru_events == exp_events mock_model.converse.assert_called_with( [{"role": "assistant", "content": [{"text": "a"}, {"text": "[blank text]"}]}], diff --git a/tests/strands/models/test_anthropic.py b/tests/strands/models/test_anthropic.py index 9421650e..a0cfc4d4 100644 --- a/tests/strands/models/test_anthropic.py +++ b/tests/strands/models/test_anthropic.py @@ -615,10 +615,24 @@ def test_format_chunk_unknown(model): def test_stream(anthropic_client, model): - mock_event_1 = unittest.mock.Mock(type="message_start", dict=lambda: {"type": "message_start"}) - mock_event_2 = unittest.mock.Mock(type="unknown") + mock_event_1 = unittest.mock.Mock( + type="message_start", + dict=lambda: {"type": "message_start"}, + model_dump=lambda: {"type": "message_start"}, + ) + mock_event_2 = unittest.mock.Mock( + type="unknown", + dict=lambda: {"type": "unknown"}, + model_dump=lambda: {"type": "unknown"}, + ) mock_event_3 = unittest.mock.Mock( - type="metadata", message=unittest.mock.Mock(usage=unittest.mock.Mock(dict=lambda: {"input_tokens": 1})) + type="metadata", + message=unittest.mock.Mock( + usage=unittest.mock.Mock( + dict=lambda: {"input_tokens": 1, "output_tokens": 2}, + model_dump=lambda: {"input_tokens": 1, "output_tokens": 2}, + ) + ), ) mock_stream = unittest.mock.MagicMock() @@ -631,7 +645,10 @@ def test_stream(anthropic_client, model): tru_events = list(response) exp_events = [ {"type": "message_start"}, - {"type": "metadata", "usage": {"input_tokens": 1}}, + { + "type": "metadata", + "usage": {"input_tokens": 1, "output_tokens": 2}, + }, ] assert tru_events == exp_events diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index b326eee7..137b57c8 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -91,6 +91,23 @@ def test__init__default_model_id(bedrock_client): assert tru_model_id == exp_model_id +def test__init__with_default_region(bedrock_client): + """Test that BedrockModel uses the provided region.""" + _ = bedrock_client + default_region = "us-west-2" + + with unittest.mock.patch("strands.models.bedrock.boto3.Session") as mock_session_cls: + with unittest.mock.patch("strands.models.bedrock.logger.warning") as mock_warning: + _ = BedrockModel() + mock_session_cls.assert_called_once_with(region_name=default_region) + # Assert that warning logs are emitted + mock_warning.assert_any_call("defaulted to us-west-2 because no region was specified") + mock_warning.assert_any_call( + "issue=<%s> | this behavior will change in an upcoming release", + "https://github.com/strands-agents/sdk-python/issues/238", + ) + + def test__init__with_custom_region(bedrock_client): """Test that BedrockModel uses the provided region.""" _ = bedrock_client diff --git a/tests/strands/telemetry/test_metrics.py b/tests/strands/telemetry/test_metrics.py index 4e84f0fd..215e1efd 100644 --- a/tests/strands/telemetry/test_metrics.py +++ b/tests/strands/telemetry/test_metrics.py @@ -1,9 +1,13 @@ import dataclasses import unittest +from unittest import mock import pytest +from opentelemetry.metrics._internal import _ProxyMeter +from opentelemetry.sdk.metrics import MeterProvider import strands +from strands.telemetry import MetricsClient from strands.types.streaming import Metrics, Usage @@ -117,6 +121,37 @@ def test_trace_end(mock_time, end_time, trace): assert tru_end_time == exp_end_time +@pytest.fixture +def mock_get_meter_provider(): + with mock.patch("strands.telemetry.metrics.metrics_api.get_meter_provider") as mock_get_meter_provider: + MetricsClient._instance = None + meter_provider_mock = mock.MagicMock(spec=MeterProvider) + + mock_meter = mock.MagicMock() + mock_create_counter = mock.MagicMock() + mock_meter.create_counter.return_value = mock_create_counter + + mock_create_histogram = mock.MagicMock() + mock_meter.create_histogram.return_value = mock_create_histogram + meter_provider_mock.get_meter.return_value = mock_meter + + mock_get_meter_provider.return_value = meter_provider_mock + + yield mock_get_meter_provider + + +@pytest.fixture +def mock_sdk_meter_provider(): + with mock.patch("strands.telemetry.metrics.metrics_sdk.MeterProvider") as mock_meter_provider: + yield mock_meter_provider + + +@pytest.fixture +def mock_resource(): + with mock.patch("opentelemetry.sdk.resources.Resource") as mock_resource: + yield mock_resource + + def test_trace_add_child(child_trace, trace): trace.add_child(child_trace) @@ -162,11 +197,14 @@ def test_trace_to_dict(trace): @pytest.mark.parametrize("success", [True, False]) -def test_tool_metrics_add_call(success, tool, tool_metrics): +def test_tool_metrics_add_call(success, tool, tool_metrics, mock_get_meter_provider): tool = dict(tool, **{"name": "updated"}) duration = 1 + metrics_client = MetricsClient() - tool_metrics.add_call(tool, duration, success) + attributes = {"foo": "bar"} + + tool_metrics.add_call(tool, duration, success, metrics_client, attributes=attributes) tru_attrs = dataclasses.asdict(tool_metrics) exp_attrs = { @@ -177,12 +215,17 @@ def test_tool_metrics_add_call(success, tool, tool_metrics): "total_time": duration, } + mock_get_meter_provider.return_value.get_meter.assert_called() + metrics_client.tool_call_count.add.assert_called_with(1, attributes=attributes) + metrics_client.tool_duration.record.assert_called_with(duration, attributes=attributes) + if success: + metrics_client.tool_success_count.add.assert_called_with(1, attributes=attributes) assert tru_attrs == exp_attrs @unittest.mock.patch.object(strands.telemetry.metrics.time, "time") @unittest.mock.patch.object(strands.telemetry.metrics.uuid, "uuid4") -def test_event_loop_metrics_start_cycle(mock_uuid4, mock_time, event_loop_metrics): +def test_event_loop_metrics_start_cycle(mock_uuid4, mock_time, event_loop_metrics, mock_get_meter_provider): mock_time.return_value = 1 mock_uuid4.return_value = "i1" @@ -192,6 +235,8 @@ def test_event_loop_metrics_start_cycle(mock_uuid4, mock_time, event_loop_metric tru_attrs = {"cycle_count": event_loop_metrics.cycle_count, "traces": event_loop_metrics.traces} exp_attrs = {"cycle_count": 1, "traces": [tru_cycle_trace]} + mock_get_meter_provider.return_value.get_meter.assert_called() + event_loop_metrics._metrics_client.event_loop_cycle_count.add.assert_called() assert ( tru_start_time == exp_start_time and tru_cycle_trace.to_dict() == exp_cycle_trace.to_dict() @@ -200,10 +245,11 @@ def test_event_loop_metrics_start_cycle(mock_uuid4, mock_time, event_loop_metric @unittest.mock.patch.object(strands.telemetry.metrics.time, "time") -def test_event_loop_metrics_end_cycle(mock_time, trace, event_loop_metrics): +def test_event_loop_metrics_end_cycle(mock_time, trace, event_loop_metrics, mock_get_meter_provider): mock_time.return_value = 1 - event_loop_metrics.end_cycle(start_time=0, cycle_trace=trace) + attributes = {"foo": "bar"} + event_loop_metrics.end_cycle(start_time=0, cycle_trace=trace, attributes=attributes) tru_cycle_durations = event_loop_metrics.cycle_durations exp_cycle_durations = [1] @@ -215,17 +261,23 @@ def test_event_loop_metrics_end_cycle(mock_time, trace, event_loop_metrics): assert tru_trace_end_time == exp_trace_end_time + mock_get_meter_provider.return_value.get_meter.assert_called() + metrics_client = event_loop_metrics._metrics_client + metrics_client.event_loop_end_cycle.add.assert_called_with(1, attributes) + metrics_client.event_loop_cycle_duration.record.assert_called() + @unittest.mock.patch.object(strands.telemetry.metrics.time, "time") -def test_event_loop_metrics_add_tool_usage(mock_time, trace, tool, event_loop_metrics): +def test_event_loop_metrics_add_tool_usage(mock_time, trace, tool, event_loop_metrics, mock_get_meter_provider): mock_time.return_value = 1 - duration = 1 success = True message = {"role": "user", "content": [{"toolResult": {"toolUseId": "123", "tool_name": "tool1"}}]} event_loop_metrics.add_tool_usage(tool, duration, trace, success, message) + mock_get_meter_provider.return_value.get_meter.assert_called() + tru_event_loop_metrics_attrs = {"tool_metrics": event_loop_metrics.tool_metrics} exp_event_loop_metrics_attrs = { "tool_metrics": { @@ -258,7 +310,7 @@ def test_event_loop_metrics_add_tool_usage(mock_time, trace, tool, event_loop_me assert tru_trace_attrs == exp_trace_attrs -def test_event_loop_metrics_update_usage(usage, event_loop_metrics): +def test_event_loop_metrics_update_usage(usage, event_loop_metrics, mock_get_meter_provider): for _ in range(3): event_loop_metrics.update_usage(usage) @@ -270,9 +322,13 @@ def test_event_loop_metrics_update_usage(usage, event_loop_metrics): ) assert tru_usage == exp_usage + mock_get_meter_provider.return_value.get_meter.assert_called() + metrics_client = event_loop_metrics._metrics_client + metrics_client.event_loop_input_tokens.record.assert_called() + metrics_client.event_loop_output_tokens.record.assert_called() -def test_event_loop_metrics_update_metrics(metrics, event_loop_metrics): +def test_event_loop_metrics_update_metrics(metrics, event_loop_metrics, mock_get_meter_provider): for _ in range(3): event_loop_metrics.update_metrics(metrics) @@ -282,9 +338,11 @@ def test_event_loop_metrics_update_metrics(metrics, event_loop_metrics): ) assert tru_metrics == exp_metrics + mock_get_meter_provider.return_value.get_meter.assert_called() + event_loop_metrics._metrics_client.event_loop_latency.record.assert_called_with(1) -def test_event_loop_metrics_get_summary(trace, tool, event_loop_metrics): +def test_event_loop_metrics_get_summary(trace, tool, event_loop_metrics, mock_get_meter_provider): duration = 1 success = True message = {"role": "user", "content": [{"toolResult": {"toolUseId": "123", "tool_name": "tool1"}}]} @@ -379,3 +437,31 @@ def test_metrics_to_string(trace, child_trace, tool_metrics, exp_str, event_loop tru_str = strands.telemetry.metrics.metrics_to_string(event_loop_metrics) assert tru_str == exp_str + + +def test_setup_meter_if_meter_provider_is_set( + mock_get_meter_provider, + mock_resource, +): + """Test global meter_provider and meter are used""" + mock_resource_instance = mock.MagicMock() + mock_resource.create.return_value = mock_resource_instance + + metrics_client = MetricsClient() + + mock_get_meter_provider.assert_called() + mock_get_meter_provider.return_value.get_meter.assert_called() + + assert metrics_client is not None + + +def test_use_ProxyMeter_if_no_global_meter_provider(): + """Return _ProxyMeter""" + # Reset the singleton instance + strands.telemetry.metrics.MetricsClient._instance = None + + # Create a new instance which should use the real _ProxyMeter + metrics_client = MetricsClient() + + # Verify it's using a _ProxyMeter + assert isinstance(metrics_client.meter, _ProxyMeter) diff --git a/tests/strands/telemetry/test_tracer.py b/tests/strands/telemetry/test_tracer.py index 030dcd37..6ae3e1ad 100644 --- a/tests/strands/telemetry/test_tracer.py +++ b/tests/strands/telemetry/test_tracer.py @@ -73,7 +73,9 @@ def mock_console_exporter(): @pytest.fixture def mock_resource(): - with mock.patch("strands.telemetry.tracer.Resource") as mock_resource: + with mock.patch("strands.telemetry.tracer.get_otel_resource") as mock_resource: + mock_resource_instance = mock.MagicMock() + mock_resource.return_value = mock_resource_instance yield mock_resource @@ -175,14 +177,12 @@ def test_initialize_tracer_with_console( ): """Test initializing the tracer with console exporter.""" mock_is_initialized.return_value = False - mock_resource_instance = mock.MagicMock() - mock_resource.create.return_value = mock_resource_instance # Initialize Tracer Tracer(enable_console_export=True) # Verify the tracer provider was created with correct resource - mock_tracer_provider.assert_called_once_with(resource=mock_resource_instance) + mock_tracer_provider.assert_called_once_with(resource=mock_resource.return_value) # Verify console exporter was added mock_console_exporter.assert_called_once() @@ -198,9 +198,6 @@ def test_initialize_tracer_with_otlp( """Test initializing the tracer with OTLP exporter.""" mock_is_initialized.return_value = False - mock_resource_instance = mock.MagicMock() - mock_resource.create.return_value = mock_resource_instance - # Initialize Tracer with ( mock.patch("strands.telemetry.tracer.HAS_OTEL_EXPORTER_MODULE", True), @@ -209,7 +206,7 @@ def test_initialize_tracer_with_otlp( Tracer(otlp_endpoint="http://test-endpoint") # Verify the tracer provider was created with correct resource - mock_tracer_provider.assert_called_once_with(resource=mock_resource_instance) + mock_tracer_provider.assert_called_once_with(resource=mock_resource.return_value) # Verify OTLP exporter was added with correct endpoint mock_otlp_exporter.assert_called_once() @@ -508,8 +505,6 @@ def test_initialize_tracer_with_invalid_otlp_endpoint( """Test initializing the tracer with an invalid OTLP endpoint.""" mock_is_initialized.return_value = False - mock_resource_instance = mock.MagicMock() - mock_resource.create.return_value = mock_resource_instance mock_otlp_exporter.side_effect = Exception("Connection error") # This should not raise an exception, but should log an error @@ -522,7 +517,7 @@ def test_initialize_tracer_with_invalid_otlp_endpoint( Tracer(otlp_endpoint="http://invalid-endpoint") # Verify the tracer provider was created with correct resource - mock_tracer_provider.assert_called_once_with(resource=mock_resource_instance) + mock_tracer_provider.assert_called_once_with(resource=mock_resource.return_value) # Verify OTLP exporter was attempted mock_otlp_exporter.assert_called_once() @@ -537,9 +532,6 @@ def test_initialize_tracer_with_missing_module( """Test initializing the tracer when the OTLP exporter module is missing.""" mock_is_initialized.return_value = False - mock_resource_instance = mock.MagicMock() - mock_resource.create.return_value = mock_resource_instance - # Initialize Tracer with OTLP endpoint but missing module with ( mock.patch("strands.telemetry.tracer.HAS_OTEL_EXPORTER_MODULE", False), @@ -552,13 +544,13 @@ def test_initialize_tracer_with_missing_module( assert "otel http exporting is currently DISABLED" in str(excinfo.value) # Verify the tracer provider was created with correct resource - mock_tracer_provider.assert_called_once_with(resource=mock_resource_instance) + mock_tracer_provider.assert_called_once_with(resource=mock_resource.return_value) # Verify set_tracer_provider was not called since an exception was raised mock_set_tracer_provider.assert_not_called() -def test_initialize_tracer_with_custom_tracer_provider(mock_get_tracer_provider, mock_resource): +def test_initialize_tracer_with_custom_tracer_provider(mock_is_initialized, mock_get_tracer_provider, mock_resource): """Test initializing the tracer with NoOpTracerProvider.""" mock_is_initialized.return_value = True tracer = Tracer(otlp_endpoint="http://invalid-endpoint") diff --git a/tests/strands/tools/test_executor.py b/tests/strands/tools/test_executor.py index a6ea45c3..4b238792 100644 --- a/tests/strands/tools/test_executor.py +++ b/tests/strands/tools/test_executor.py @@ -43,6 +43,12 @@ def tool_uses(request, tool_use): return request.param if hasattr(request, "param") else [tool_use] +@pytest.fixture +def mock_metrics_client(): + with unittest.mock.patch("strands.telemetry.MetricsClient") as mock_metrics_client: + yield mock_metrics_client + + @pytest.fixture def event_loop_metrics(): return strands.telemetry.metrics.EventLoopMetrics() @@ -303,6 +309,7 @@ def test_run_tools_creates_and_ends_span_on_success( mock_get_tracer, tool_handler, tool_uses, + mock_metrics_client, event_loop_metrics, request_state, invalid_tool_use_ids, diff --git a/tests/strands/tools/test_structured_output.py b/tests/strands/tools/test_structured_output.py new file mode 100644 index 00000000..2e354b83 --- /dev/null +++ b/tests/strands/tools/test_structured_output.py @@ -0,0 +1,228 @@ +from typing import Literal, Optional + +import pytest +from pydantic import BaseModel, Field + +from strands.tools.structured_output import convert_pydantic_to_tool_spec +from strands.types.tools import ToolSpec + + +# Basic test model +class User(BaseModel): + """User model with name and age.""" + + name: str = Field(description="The name of the user") + age: int = Field(description="The age of the user", ge=18, le=100) + + +# Test model with inheritance and literals +class UserWithPlanet(User): + """User with planet.""" + + planet: Literal["Earth", "Mars"] = Field(description="The planet") + + +# Test model with multiple same type fields and optional field +class TwoUsersWithPlanet(BaseModel): + """Two users model with planet.""" + + user1: UserWithPlanet = Field(description="The first user") + user2: Optional[UserWithPlanet] = Field(description="The second user", default=None) + + +# Test model with list of same type fields +class ListOfUsersWithPlanet(BaseModel): + """List of users model with planet.""" + + users: list[UserWithPlanet] = Field(description="The users", min_length=2, max_length=3) + + +def test_convert_pydantic_to_tool_spec_basic(): + tool_spec = convert_pydantic_to_tool_spec(User) + + expected_spec = { + "name": "User", + "description": "User model with name and age.", + "inputSchema": { + "json": { + "type": "object", + "properties": { + "name": {"description": "The name of the user", "title": "Name", "type": "string"}, + "age": { + "description": "The age of the user", + "maximum": 100, + "minimum": 18, + "title": "Age", + "type": "integer", + }, + }, + "title": "User", + "description": "User model with name and age.", + "required": ["name", "age"], + } + }, + } + + # Verify we can construct a valid ToolSpec + tool_spec_obj = ToolSpec(**tool_spec) + assert tool_spec_obj is not None + assert tool_spec == expected_spec + + +def test_convert_pydantic_to_tool_spec_complex(): + tool_spec = convert_pydantic_to_tool_spec(ListOfUsersWithPlanet) + + expected_spec = { + "name": "ListOfUsersWithPlanet", + "description": "List of users model with planet.", + "inputSchema": { + "json": { + "type": "object", + "properties": { + "users": { + "description": "The users", + "items": { + "description": "User with planet.", + "title": "UserWithPlanet", + "type": "object", + "properties": { + "name": {"description": "The name of the user", "title": "Name", "type": "string"}, + "age": { + "description": "The age of the user", + "maximum": 100, + "minimum": 18, + "title": "Age", + "type": "integer", + }, + "planet": { + "description": "The planet", + "enum": ["Earth", "Mars"], + "title": "Planet", + "type": "string", + }, + }, + "required": ["name", "age", "planet"], + }, + "maxItems": 3, + "minItems": 2, + "title": "Users", + "type": "array", + } + }, + "title": "ListOfUsersWithPlanet", + "description": "List of users model with planet.", + "required": ["users"], + } + }, + } + + assert tool_spec == expected_spec + + # Verify we can construct a valid ToolSpec + tool_spec_obj = ToolSpec(**tool_spec) + assert tool_spec_obj is not None + + +def test_convert_pydantic_to_tool_spec_multiple_same_type(): + tool_spec = convert_pydantic_to_tool_spec(TwoUsersWithPlanet) + + expected_spec = { + "name": "TwoUsersWithPlanet", + "description": "Two users model with planet.", + "inputSchema": { + "json": { + "type": "object", + "properties": { + "user1": { + "type": "object", + "description": "The first user", + "properties": { + "name": {"description": "The name of the user", "title": "Name", "type": "string"}, + "age": { + "description": "The age of the user", + "maximum": 100, + "minimum": 18, + "title": "Age", + "type": "integer", + }, + "planet": { + "description": "The planet", + "enum": ["Earth", "Mars"], + "title": "Planet", + "type": "string", + }, + }, + "required": ["name", "age", "planet"], + }, + "user2": { + "type": ["object", "null"], + "description": "The second user", + "properties": { + "name": {"description": "The name of the user", "title": "Name", "type": "string"}, + "age": { + "description": "The age of the user", + "maximum": 100, + "minimum": 18, + "title": "Age", + "type": "integer", + }, + "planet": { + "description": "The planet", + "enum": ["Earth", "Mars"], + "title": "Planet", + "type": "string", + }, + }, + "required": ["name", "age", "planet"], + }, + }, + "title": "TwoUsersWithPlanet", + "description": "Two users model with planet.", + "required": ["user1"], + } + }, + } + + assert tool_spec == expected_spec + + # Verify we can construct a valid ToolSpec + tool_spec_obj = ToolSpec(**tool_spec) + assert tool_spec_obj is not None + + +def test_convert_pydantic_with_missing_refs(): + """Test that the tool handles missing $refs gracefully.""" + # This test checks that our error handling for missing $refs works correctly + # by testing with a model that has circular references + + class NodeWithCircularRef(BaseModel): + """A node with a circular reference to itself.""" + + name: str = Field(description="The name of the node") + parent: Optional["NodeWithCircularRef"] = Field(None, description="Parent node") + children: list["NodeWithCircularRef"] = Field(default_factory=list, description="Child nodes") + + # This forward reference normally causes issues with schema generation + # but our error handling should prevent errors + with pytest.raises(ValueError, match="Circular reference detected and not supported"): + convert_pydantic_to_tool_spec(NodeWithCircularRef) + + +def test_convert_pydantic_with_custom_description(): + """Test that custom descriptions override model docstrings.""" + + # Test with custom description + custom_description = "Custom tool description for user model" + tool_spec = convert_pydantic_to_tool_spec(User, description=custom_description) + + assert tool_spec["description"] == custom_description + + +def test_convert_pydantic_with_empty_docstring(): + """Test that empty docstrings use default description.""" + + class EmptyDocUser(BaseModel): + name: str = Field(description="The name of the user") + + tool_spec = convert_pydantic_to_tool_spec(EmptyDocUser) + assert tool_spec["description"] == "EmptyDocUser structured output tool" diff --git a/tests/strands/types/models/test_model.py b/tests/strands/types/models/test_model.py index f2797fe5..03690733 100644 --- a/tests/strands/types/models/test_model.py +++ b/tests/strands/types/models/test_model.py @@ -1,8 +1,16 @@ +from typing import Type + import pytest +from pydantic import BaseModel from strands.types.models import Model as SAModel +class Person(BaseModel): + name: str + age: int + + class TestModel(SAModel): def update_config(self, **model_config): return model_config @@ -10,6 +18,9 @@ def update_config(self, **model_config): def get_config(self): return + def structured_output(self, output_model: Type[BaseModel]) -> BaseModel: + return output_model(name="test", age=20) + def format_request(self, messages, tool_specs, system_prompt): return { "messages": messages, @@ -79,3 +90,9 @@ def test_converse(model, messages, tool_specs, system_prompt): }, ] assert tru_events == exp_events + + +def test_structured_output(model): + response = model.structured_output(Person) + + assert response == Person(name="test", age=20) diff --git a/tests/strands/types/models/test_openai.py b/tests/strands/types/models/test_openai.py index 9db08bc9..3a1a940b 100644 --- a/tests/strands/types/models/test_openai.py +++ b/tests/strands/types/models/test_openai.py @@ -1,3 +1,4 @@ +import base64 import unittest.mock import pytest @@ -90,7 +91,24 @@ def system_prompt(): "image_url": { "detail": "auto", "format": "image/jpeg", - "url": "data:image/jpeg;base64,image", + "url": "data:image/jpeg;base64,aW1hZ2U=", + }, + "type": "image_url", + }, + ), + # Image - base64 encoded + ( + { + "image": { + "format": "jpg", + "source": {"bytes": base64.b64encode(b"image")}, + }, + }, + { + "image_url": { + "detail": "auto", + "format": "image/jpeg", + "url": "data:image/jpeg;base64,aW1hZ2U=", }, "type": "image_url", }, @@ -344,3 +362,15 @@ def test_format_chunk_unknown_type(model): with pytest.raises(RuntimeError, match="chunk_type= | unknown type"): model.format_chunk(event) + + +@pytest.mark.parametrize( + ("data", "exp_result"), + [ + (b"image", b"aW1hZ2U="), + (b"aW1hZ2U=", b"aW1hZ2U="), + ], +) +def test_b64encode(data, exp_result): + tru_result = SAOpenAIModel.b64encode(data) + assert tru_result == exp_result