diff --git a/pyproject.toml b/pyproject.toml
index 56c1a40e..e0cc2578 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -65,7 +65,7 @@ docs = [
     "sphinx-autodoc-typehints>=1.12.0,<2.0.0",
 ]
 litellm = [
-    "litellm>=1.69.0,<2.0.0",
+    "litellm>=1.72.6,<2.0.0",
 ]
 llamaapi = [
     "llama-api-client>=0.1.0,<1.0.0",
@@ -264,4 +264,4 @@ style = [
     ["instruction", ""],
     ["text", ""],
     ["disabled", "fg:#858585 italic"]
-]
+]
\ No newline at end of file
diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py
index 56f5b92e..a5e26a07 100644
--- a/src/strands/agent/agent.py
+++ b/src/strands/agent/agent.py
@@ -16,10 +16,11 @@
 import random
 from concurrent.futures import ThreadPoolExecutor
 from threading import Thread
-from typing import Any, AsyncIterator, Callable, Dict, List, Mapping, Optional, Union
+from typing import Any, AsyncIterator, Callable, Dict, List, Mapping, Optional, Type, TypeVar, Union
 from uuid import uuid4
 
 from opentelemetry import trace
+from pydantic import BaseModel
 
 from ..event_loop.event_loop import event_loop_cycle
 from ..handlers.callback_handler import CompositeCallbackHandler, PrintingCallbackHandler, null_callback_handler
@@ -43,6 +44,9 @@
 
 logger = logging.getLogger(__name__)
 
+# TypeVar for generic structured output
+T = TypeVar("T", bound=BaseModel)
+
 
 # Sentinel class and object to distinguish between explicit None and default parameter value
 class _DefaultCallbackHandlerSentinel:
@@ -308,7 +312,6 @@ def __init__(
         # Initialize tracer instance (no-op if not configured)
         self.tracer = get_tracer()
         self.trace_span: Optional[trace.Span] = None
-
         self.tool_caller = Agent.ToolCaller(self)
 
     @property
@@ -387,6 +390,38 @@ def __call__(self, prompt: str, **kwargs: Any) -> AgentResult:
             # Re-raise the exception to preserve original behavior
             raise
 
+    def structured_output(self, output_model: Type[T], prompt: Optional[str] = None) -> T:
+        """Get structured output from the agent.
+
+        If a prompt is provided, it is added to the conversation history and the agent responds to it.
+        If no prompt is provided, the agent responds from the existing conversation history alone.
+        If neither conversation history nor a prompt exists, an error is raised.
+
+        For smaller models, you may want to use the optional prompt to explicitly instruct the model to
+        output the structured data.
+
+        Args:
+            output_model(Type[BaseModel]): The output model (a JSON schema written as a Pydantic BaseModel)
+                that the agent will use when responding.
+            prompt(Optional[str]): The prompt to use for the agent.
+
+        Returns:
+            An instance of output_model populated from the model's response.
+
+        Raises:
+            ValueError: If no conversation history exists and no prompt is provided.
+        """
+        messages = self.messages
+        if not messages and not prompt:
+            raise ValueError("No conversation history or prompt provided")
+
+        # Add the prompt as the last message
+        if prompt:
+            messages.append({"role": "user", "content": [{"text": prompt}]})
+
+        # Get the structured output from the model
+        return self.model.structured_output(output_model, messages, self.callback_handler)
+
     async def stream_async(self, prompt: str, **kwargs: Any) -> AsyncIterator[Any]:
         """Process a natural language prompt and yield events as an async iterator.
 
diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py
index 02a56a1c..e336642c 100644
--- a/src/strands/event_loop/event_loop.py
+++ b/src/strands/event_loop/event_loop.py
@@ -33,23 +33,6 @@
 MAX_DELAY = 240  # 4 minutes
 
 
-def initialize_state(**kwargs: Any) -> Any:
-    """Initialize the request state if not present.
- - Creates an empty request_state dictionary if one doesn't already exist in the - provided keyword arguments. - - Args: - **kwargs: Keyword arguments that may contain a request_state. - - Returns: - The updated kwargs dictionary with request_state initialized if needed. - """ - if "request_state" not in kwargs: - kwargs["request_state"] = {} - return kwargs - - def event_loop_cycle( model: Model, system_prompt: Optional[str], @@ -107,7 +90,8 @@ def event_loop_cycle( event_loop_metrics: EventLoopMetrics = kwargs.get("event_loop_metrics", EventLoopMetrics()) # Initialize state and get cycle trace - kwargs = initialize_state(**kwargs) + if "request_state" not in kwargs: + kwargs["request_state"] = {} cycle_start_time, cycle_trace = event_loop_metrics.start_cycle() kwargs["event_loop_cycle_trace"] = cycle_trace @@ -310,26 +294,6 @@ def recurse_event_loop( ) -def prepare_next_cycle(kwargs: Dict[str, Any], event_loop_metrics: EventLoopMetrics) -> Dict[str, Any]: - """Prepare state for the next event loop cycle. - - Updates the keyword arguments with the current event loop metrics and stores the current cycle ID as the parent - cycle ID for the next cycle. This maintains the parent-child relationship between cycles for tracing and metrics. - - Args: - kwargs: Current keyword arguments containing event loop state. - event_loop_metrics: The metrics object tracking event loop execution. - - Returns: - Updated keyword arguments ready for the next cycle. - """ - # Store parent cycle ID - kwargs["event_loop_metrics"] = event_loop_metrics - kwargs["event_loop_parent_cycle_id"] = kwargs["event_loop_cycle_id"] - - return kwargs - - def _handle_tool_execution( stop_reason: StopReason, message: Message, @@ -403,7 +367,9 @@ def _handle_tool_execution( parallel_tool_executor=tool_execution_handler, ) - kwargs = prepare_next_cycle(kwargs, event_loop_metrics) + # Store parent cycle ID for the next cycle + kwargs["event_loop_metrics"] = event_loop_metrics + kwargs["event_loop_parent_cycle_id"] = kwargs["event_loop_cycle_id"] tool_result_message: Message = { "role": "user", diff --git a/src/strands/models/anthropic.py b/src/strands/models/anthropic.py index 57394e2c..ab427e53 100644 --- a/src/strands/models/anthropic.py +++ b/src/strands/models/anthropic.py @@ -7,11 +7,15 @@ import json import logging import mimetypes -from typing import Any, Iterable, Optional, TypedDict, cast +from typing import Any, Callable, Iterable, Optional, Type, TypedDict, TypeVar, cast import anthropic +from pydantic import BaseModel from typing_extensions import Required, Unpack, override +from ..event_loop.streaming import process_stream +from ..handlers.callback_handler import PrintingCallbackHandler +from ..tools import convert_pydantic_to_tool_spec from ..types.content import ContentBlock, Messages from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException from ..types.models import Model @@ -20,6 +24,8 @@ logger = logging.getLogger(__name__) +T = TypeVar("T", bound=BaseModel) + class AnthropicModel(Model): """Anthropic model provider implementation.""" @@ -356,10 +362,10 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]: with self.client.messages.stream(**request) as stream: for event in stream: if event.type in AnthropicModel.EVENT_TYPES: - yield event.dict() + yield event.model_dump() usage = event.message.usage # type: ignore - yield {"type": "metadata", "usage": usage.dict()} + yield {"type": "metadata", "usage": usage.model_dump()} except anthropic.RateLimitError as error: 
            raise ModelThrottledException(str(error)) from error
@@ -369,3 +375,40 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]:
                 raise ContextWindowOverflowException(str(error)) from error
 
             raise error
+
+    @override
+    def structured_output(
+        self, output_model: Type[T], prompt: Messages, callback_handler: Optional[Callable] = None
+    ) -> T:
+        """Get structured output from the model.
+
+        Args:
+            output_model(Type[BaseModel]): The output model to use for the agent.
+            prompt(Messages): The prompt messages to use for the agent.
+            callback_handler(Optional[Callable]): Optional callback handler for processing events. Defaults to None.
+        """
+        tool_spec = convert_pydantic_to_tool_spec(output_model)
+
+        response = self.converse(messages=prompt, tool_specs=[tool_spec])
+        # Process the stream and get the tool use input
+        results = process_stream(
+            response, callback_handler=callback_handler or PrintingCallbackHandler(), messages=prompt
+        )
+
+        stop_reason, messages, _, _, _ = results
+
+        if stop_reason != "tool_use":
+            raise ValueError("No valid tool use or tool use input was found in the Anthropic response.")
+
+        content = messages["content"]
+        output_response: dict[str, Any] | None = None
+        for block in content:
+            # Skip blocks that are not tool uses or whose name does not match the tool spec;
+            # if no block ever matches, the error below is raised.
+            if block.get("toolUse") and block["toolUse"]["name"] == tool_spec["name"]:
+                output_response = block["toolUse"]["input"]
+
+        if output_response is None:
+            raise ValueError("No valid tool use or tool use input was found in the Anthropic response.")
+
+        return output_model(**output_response)
diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py
index 9bbcca7d..3de41198 100644
--- a/src/strands/models/bedrock.py
+++ b/src/strands/models/bedrock.py
@@ -6,13 +6,17 @@
 import json
 import logging
 import os
-from typing import Any, Iterable, List, Literal, Optional, cast
+from typing import Any, Callable, Iterable, List, Literal, Optional, Type, TypeVar, cast
 
 import boto3
 from botocore.config import Config as BotocoreConfig
 from botocore.exceptions import ClientError
+from pydantic import BaseModel
 from typing_extensions import TypedDict, Unpack, override
 
+from ..event_loop.streaming import process_stream
+from ..handlers.callback_handler import PrintingCallbackHandler
+from ..tools import convert_pydantic_to_tool_spec
 from ..types.content import Messages
 from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException
 from ..types.models import Model
@@ -29,6 +33,8 @@
     "too many total text bytes",
 ]
 
+T = TypeVar("T", bound=BaseModel)
+
 
 class BedrockModel(Model):
     """AWS Bedrock model provider implementation.
@@ -112,8 +118,17 @@ def __init__(
 
         logger.debug("config=<%s> | initializing", self.config)
 
+        region_for_boto = region_name or os.getenv("AWS_REGION")
+        if region_for_boto is None:
+            region_for_boto = "us-west-2"
+            logger.warning("defaulted to us-west-2 because no region was specified")
+            logger.warning(
+                "issue=<%s> | this behavior will change in an upcoming release",
+                "https://github.com/strands-agents/sdk-python/issues/238",
+            )
+
         session = boto_session or boto3.Session(
-            region_name=region_name or os.getenv("AWS_REGION") or "us-west-2",
+            region_name=region_for_boto,
         )
 
         # Add strands-agents to the request user agent
@@ -477,3 +492,40 @@ def _find_detected_and_blocked_policy(self, input: Any) -> bool:
                 return self._find_detected_and_blocked_policy(item)
         # Otherwise return False
         return False
+
+    @override
+    def structured_output(
+        self, output_model: Type[T], prompt: Messages, callback_handler: Optional[Callable] = None
+    ) -> T:
+        """Get structured output from the model.
+
+        Args:
+            output_model(Type[BaseModel]): The output model to use for the agent.
+            prompt(Messages): The prompt messages to use for the agent.
+            callback_handler(Optional[Callable]): Optional callback handler for processing events. Defaults to None.
+        """
+        tool_spec = convert_pydantic_to_tool_spec(output_model)
+
+        response = self.converse(messages=prompt, tool_specs=[tool_spec])
+        # Process the stream and get the tool use input
+        results = process_stream(
+            response, callback_handler=callback_handler or PrintingCallbackHandler(), messages=prompt
+        )
+
+        stop_reason, messages, _, _, _ = results
+
+        if stop_reason != "tool_use":
+            raise ValueError("No valid tool use or tool use input was found in the Bedrock response.")
+
+        content = messages["content"]
+        output_response: dict[str, Any] | None = None
+        for block in content:
+            # Skip blocks that are not tool uses or whose name does not match the tool spec;
+            # if no block ever matches, the error below is raised.
+            if block.get("toolUse") and block["toolUse"]["name"] == tool_spec["name"]:
+                output_response = block["toolUse"]["input"]
+
+        if output_response is None:
+            raise ValueError("No valid tool use or tool use input was found in the Bedrock response.")
+
+        return output_model(**output_response)
diff --git a/src/strands/models/litellm.py b/src/strands/models/litellm.py
index 62f16d31..66138186 100644
--- a/src/strands/models/litellm.py
+++ b/src/strands/models/litellm.py
@@ -3,17 +3,22 @@
 - Docs: https://docs.litellm.ai/
 """
 
+import json
 import logging
-from typing import Any, Optional, TypedDict, cast
+from typing import Any, Callable, Optional, Type, TypedDict, TypeVar, cast
 
 import litellm
+from litellm.utils import supports_response_schema
+from pydantic import BaseModel
 from typing_extensions import Unpack, override
 
-from ..types.content import ContentBlock
+from ..types.content import ContentBlock, Messages
 from .openai import OpenAIModel
 
 logger = logging.getLogger(__name__)
 
+T = TypeVar("T", bound=BaseModel)
+
 
 class LiteLLMModel(OpenAIModel):
     """LiteLLM model provider implementation."""
@@ -97,3 +102,43 @@ def format_request_message_content(cls, content: ContentBlock) -> dict[str, Any]
         }
 
         return super().format_request_message_content(content)
+
+    @override
+    def structured_output(
+        self, output_model: Type[T], prompt: Messages, callback_handler: Optional[Callable] = None
+    ) -> T:
+        """Get structured output from the model.
+
+        Args:
+            output_model(Type[BaseModel]): The output model to use for the agent.
+            prompt(Messages): The prompt messages to use for the agent.
+            callback_handler(Optional[Callable]): Optional callback handler for processing events. Defaults to None.
+
+        """
+        if not supports_response_schema(self.get_config()["model_id"]):
+            raise ValueError("Model does not support response_format")
+
+        # The LiteLLM client wraps the underlying completion API:
+        # `client.chat.completions.create()` delegates to LiteLLM's completion call.
+        response = self.client.chat.completions.create(
+            model=self.get_config()["model_id"],
+            messages=super().format_request(prompt)["messages"],
+            response_format=output_model,
+        )
+
+        if len(response.choices) > 1:
+            raise ValueError("Multiple choices found in the response.")
+
+        # Find the first choice that finished with tool calls
+        for choice in response.choices:
+            if choice.finish_reason == "tool_calls":
+                try:
+                    # Parse the tool call content as JSON
+                    tool_call_data = json.loads(choice.message.content)
+                    # Instantiate the output model with the parsed data
+                    return output_model(**tool_call_data)
+                except (json.JSONDecodeError, TypeError, ValueError) as e:
+                    raise ValueError(f"Failed to parse or load content into model: {e}") from e
+
+        # If no tool_calls found, raise an error
+        raise ValueError("No tool_calls found in response")
diff --git a/src/strands/models/llamaapi.py b/src/strands/models/llamaapi.py
index 583db2f2..755e07ad 100644
--- a/src/strands/models/llamaapi.py
+++ b/src/strands/models/llamaapi.py
@@ -8,10 +8,11 @@
 import json
 import logging
 import mimetypes
-from typing import Any, Iterable, Optional, cast
+from typing import Any, Callable, Iterable, Optional, Type, TypeVar, cast
 
 import llama_api_client
 from llama_api_client import LlamaAPIClient
+from pydantic import BaseModel
 from typing_extensions import TypedDict, Unpack, override
 
 from ..types.content import ContentBlock, Messages
@@ -22,6 +23,8 @@
 
 logger = logging.getLogger(__name__)
 
+T = TypeVar("T", bound=BaseModel)
+
 
 class LlamaAPIModel(Model):
     """Llama API model provider implementation."""
@@ -384,3 +387,35 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]:
             # we may have a metrics event here
             if metrics_event:
                 yield {"chunk_type": "metadata", "data": metrics_event}
+
+    @override
+    def structured_output(
+        self, output_model: Type[T], prompt: Messages, callback_handler: Optional[Callable] = None
+    ) -> T:
+        """Get structured output from the model.
+
+        Args:
+            output_model(Type[BaseModel]): The output model to use for the agent.
+            prompt(Messages): The prompt messages to use for the agent.
+            callback_handler(Optional[Callable]): Optional callback handler for processing events. Defaults to None.
+
+        Raises:
+            NotImplementedError: Structured output is not currently supported for LlamaAPI models.
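+
+        Note:
+            The commented-out code below sketches how this might be implemented with the
+            Llama API response_format parameter once structured output is supported.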
+ """ + # response_format: ResponseFormat = { + # "type": "json_schema", + # "json_schema": { + # "name": output_model.__name__, + # "schema": output_model.model_json_schema(), + # }, + # } + # response = self.client.chat.completions.create( + # model=self.config["model_id"], + # messages=self.format_request(prompt)["messages"], + # response_format=response_format, + # ) + raise NotImplementedError("Strands sdk-python does not implement this in the Llama API Preview.") diff --git a/src/strands/models/ollama.py b/src/strands/models/ollama.py index 7ed12216..b062fe14 100644 --- a/src/strands/models/ollama.py +++ b/src/strands/models/ollama.py @@ -5,9 +5,10 @@ import json import logging -from typing import Any, Iterable, Optional, cast +from typing import Any, Callable, Iterable, Optional, Type, TypeVar, cast from ollama import Client as OllamaClient +from pydantic import BaseModel from typing_extensions import TypedDict, Unpack, override from ..types.content import ContentBlock, Messages @@ -17,6 +18,8 @@ logger = logging.getLogger(__name__) +T = TypeVar("T", bound=BaseModel) + class OllamaModel(Model): """Ollama model provider implementation. @@ -310,3 +313,25 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]: yield {"chunk_type": "content_stop", "data_type": "text"} yield {"chunk_type": "message_stop", "data": "tool_use" if tool_requested else event.done_reason} yield {"chunk_type": "metadata", "data": event} + + @override + def structured_output( + self, output_model: Type[T], prompt: Messages, callback_handler: Optional[Callable] = None + ) -> T: + """Get structured output from the model. + + Args: + output_model(Type[BaseModel]): The output model to use for the agent. + prompt(Messages): The prompt messages to use for the agent. + callback_handler(Optional[Callable]): Optional callback handler for processing events. Defaults to None. + """ + formatted_request = self.format_request(messages=prompt) + formatted_request["format"] = output_model.model_json_schema() + formatted_request["stream"] = False + response = self.client.chat(**formatted_request) + + try: + content = response.message.content.strip() + return output_model.model_validate_json(content) + except Exception as e: + raise ValueError(f"Failed to parse or load content into model: {e}") from e diff --git a/src/strands/models/openai.py b/src/strands/models/openai.py index 6cbef664..783ce379 100644 --- a/src/strands/models/openai.py +++ b/src/strands/models/openai.py @@ -4,15 +4,20 @@ """ import logging -from typing import Any, Iterable, Optional, Protocol, TypedDict, cast +from typing import Any, Callable, Iterable, Optional, Protocol, Type, TypedDict, TypeVar, cast import openai +from openai.types.chat.parsed_chat_completion import ParsedChatCompletion +from pydantic import BaseModel from typing_extensions import Unpack, override +from ..types.content import Messages from ..types.models import OpenAIModel as SAOpenAIModel logger = logging.getLogger(__name__) +T = TypeVar("T", bound=BaseModel) + class Client(Protocol): """Protocol defining the OpenAI-compatible interface for the underlying provider client.""" @@ -125,3 +130,35 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]: _ = event yield {"chunk_type": "metadata", "data": event.usage} + + @override + def structured_output( + self, output_model: Type[T], prompt: Messages, callback_handler: Optional[Callable] = None + ) -> T: + """Get structured output from the model. 
+
+        Args:
+            output_model(Type[BaseModel]): The output model to use for the agent.
+            prompt(Messages): The prompt messages to use for the agent.
+            callback_handler(Optional[Callable]): Optional callback handler for processing events. Defaults to None.
+        """
+        response: ParsedChatCompletion = self.client.beta.chat.completions.parse(  # type: ignore
+            model=self.get_config()["model_id"],
+            messages=super().format_request(prompt)["messages"],
+            response_format=output_model,
+        )
+
+        parsed: T | None = None
+        if len(response.choices) > 1:
+            raise ValueError("Multiple choices found in the OpenAI response.")
+
+        # Find the first choice with a parsed instance of the output model
+        for choice in response.choices:
+            if isinstance(choice.message.parsed, output_model):
+                parsed = choice.message.parsed
+                break
+
+        if parsed:
+            return parsed
+        else:
+            raise ValueError("No valid tool use or tool use input was found in the OpenAI response.")
diff --git a/src/strands/telemetry/__init__.py b/src/strands/telemetry/__init__.py
index 15981216..21dd6ebf 100644
--- a/src/strands/telemetry/__init__.py
+++ b/src/strands/telemetry/__init__.py
@@ -3,7 +3,8 @@
 This module provides metrics and tracing functionality.
 """
 
-from .metrics import EventLoopMetrics, Trace, metrics_to_string
+from .config import get_otel_resource
+from .metrics import EventLoopMetrics, MetricsClient, Trace, metrics_to_string
 from .tracer import Tracer, get_tracer
 
 __all__ = [
@@ -12,4 +13,6 @@
     "metrics_to_string",
     "Tracer",
     "get_tracer",
+    "MetricsClient",
+    "get_otel_resource",
 ]
diff --git a/src/strands/telemetry/config.py b/src/strands/telemetry/config.py
new file mode 100644
index 00000000..9f5a05fd
--- /dev/null
+++ b/src/strands/telemetry/config.py
@@ -0,0 +1,27 @@
+"""OpenTelemetry configuration and setup utilities for Strands agents.
+
+This module provides centralized configuration and initialization functionality
+for OpenTelemetry components and other telemetry infrastructure shared across Strands applications.
+"""
+
+from importlib.metadata import version
+
+from opentelemetry.sdk.resources import Resource
+
+
+def get_otel_resource() -> Resource:
+    """Create a standard OpenTelemetry resource with service information.
+
+    Returns:
+        Resource object with standard service information.
+    """
+    resource = Resource.create(
+        {
+            "service.name": __name__,
+            "service.version": version("strands-agents"),
+            "telemetry.sdk.name": "opentelemetry",
+            "telemetry.sdk.language": "python",
+        }
+    )
+
+    return resource
diff --git a/src/strands/telemetry/metrics.py b/src/strands/telemetry/metrics.py
index cd70819b..af940643 100644
--- a/src/strands/telemetry/metrics.py
+++ b/src/strands/telemetry/metrics.py
@@ -6,6 +6,10 @@
 from dataclasses import dataclass, field
 from typing import Any, Dict, Iterable, List, Optional, Set, Tuple
 
+import opentelemetry.metrics as metrics_api
+from opentelemetry.metrics import Counter, Meter
+
+from ..telemetry import metrics_constants as constants
 from ..types.content import Message
 from ..types.streaming import Metrics, Usage
 from ..types.tools import ToolUse
@@ -355,3 +359,45 @@ def metrics_to_string(event_loop_metrics: EventLoopMetrics, allowed_names: Optio
         A formatted string representation of the metrics.
""" return "\n".join(_metrics_summary_to_lines(event_loop_metrics, allowed_names or set())) + + +class MetricsClient: + """Singleton client for managing OpenTelemetry metrics instruments. + + The actual metrics export destination (console, OTLP endpoint, etc.) is configured + through OpenTelemetry SDK configuration by users, not by this client. + """ + + _instance: Optional["MetricsClient"] = None + meter: Meter + strands_agent_invocation_count: Counter + + def __new__(cls) -> "MetricsClient": + """Create or return the singleton instance of MetricsClient. + + Returns: + The single MetricsClient instance. + """ + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + def __init__(self) -> None: + """Initialize the MetricsClient. + + This method only runs once due to the singleton pattern. + Sets up the OpenTelemetry meter and creates metric instruments. + """ + if hasattr(self, "meter"): + return + + logger.info("Creating Strands MetricsClient") + meter_provider: metrics_api.MeterProvider = metrics_api.get_meter_provider() + self.meter = meter_provider.get_meter(__name__) + self.create_instruments() + + def create_instruments(self) -> None: + """Create and initialize all OpenTelemetry metric instruments.""" + self.strands_agent_invocation_count = self.meter.create_counter( + name=constants.STRANDS_AGENT_INVOCATION_COUNT, unit="Count" + ) diff --git a/src/strands/telemetry/metrics_constants.py b/src/strands/telemetry/metrics_constants.py new file mode 100644 index 00000000..d3d3e81f --- /dev/null +++ b/src/strands/telemetry/metrics_constants.py @@ -0,0 +1,3 @@ +"""Metrics that are emitted in Strands-Agent.""" + +STRANDS_AGENT_INVOCATION_COUNT = "strands.agent.invocation_count" diff --git a/src/strands/telemetry/tracer.py b/src/strands/telemetry/tracer.py index e9a37a4a..813c90e1 100644 --- a/src/strands/telemetry/tracer.py +++ b/src/strands/telemetry/tracer.py @@ -8,20 +8,19 @@ import logging import os from datetime import date, datetime, timezone -from importlib.metadata import version from typing import Any, Dict, Mapping, Optional import opentelemetry.trace as trace_api from opentelemetry import propagate from opentelemetry.baggage.propagation import W3CBaggagePropagator from opentelemetry.propagators.composite import CompositePropagator -from opentelemetry.sdk.resources import Resource from opentelemetry.sdk.trace import TracerProvider as SDKTracerProvider from opentelemetry.sdk.trace.export import BatchSpanProcessor, ConsoleSpanExporter, SimpleSpanProcessor from opentelemetry.trace import Span, StatusCode from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator from ..agent.agent_result import AgentResult +from ..telemetry import get_otel_resource from ..types.content import Message, Messages from ..types.streaming import Usage from ..types.tools import ToolResult, ToolUse @@ -151,7 +150,6 @@ def __init__( self.otlp_headers = otlp_headers or {} self.tracer_provider: Optional[trace_api.TracerProvider] = None self.tracer: Optional[trace_api.Tracer] = None - propagate.set_global_textmap( CompositePropagator( [ @@ -173,15 +171,7 @@ def _initialize_tracer(self) -> None: self.tracer = self.tracer_provider.get_tracer(self.service_name) return - # Create resource with service information - resource = Resource.create( - { - "service.name": self.service_name, - "service.version": version("strands-agents"), - "telemetry.sdk.name": "opentelemetry", - "telemetry.sdk.language": "python", - } - ) + resource = get_otel_resource() # 
Create tracer provider self.tracer_provider = SDKTracerProvider(resource=resource) @@ -216,6 +206,7 @@ def _initialize_tracer(self) -> None: batch_processor = BatchSpanProcessor(otlp_exporter) self.tracer_provider.add_span_processor(batch_processor) logger.info("endpoint=<%s> | OTLP exporter configured with endpoint", endpoint) + except Exception as e: logger.exception("error=<%s> | Failed to configure OTLP exporter", e) elif self.otlp_endpoint and self.tracer_provider: diff --git a/src/strands/tools/__init__.py b/src/strands/tools/__init__.py index b3ee1566..12979015 100644 --- a/src/strands/tools/__init__.py +++ b/src/strands/tools/__init__.py @@ -4,6 +4,7 @@ """ from .decorator import tool +from .structured_output import convert_pydantic_to_tool_spec from .thread_pool_executor import ThreadPoolExecutorWrapper from .tools import FunctionTool, InvalidToolUseNameException, PythonAgentTool, normalize_schema, normalize_tool_spec @@ -15,4 +16,5 @@ "normalize_schema", "normalize_tool_spec", "ThreadPoolExecutorWrapper", + "convert_pydantic_to_tool_spec", ] diff --git a/src/strands/tools/structured_output.py b/src/strands/tools/structured_output.py new file mode 100644 index 00000000..5421cdc6 --- /dev/null +++ b/src/strands/tools/structured_output.py @@ -0,0 +1,415 @@ +"""Tools for converting Pydantic models to Bedrock tools.""" + +from typing import Any, Dict, Optional, Type, Union + +from pydantic import BaseModel + +from ..types.tools import ToolSpec + + +def _flatten_schema(schema: Dict[str, Any]) -> Dict[str, Any]: + """Flattens a JSON schema by removing $defs and resolving $ref references. + + Handles required vs optional fields properly. + + Args: + schema: The JSON schema to flatten + + Returns: + Flattened JSON schema + """ + # Extract required fields list + required_fields = schema.get("required", []) + + # Initialize the flattened schema with basic properties + flattened = { + "type": schema.get("type", "object"), + "properties": {}, + } + + # Add title if present + if "title" in schema: + flattened["title"] = schema["title"] + + # Add description from schema if present, or use model docstring + if "description" in schema and schema["description"]: + flattened["description"] = schema["description"] + + # Process properties + required_props: list[str] = [] + if "properties" in schema: + required_props = [] + for prop_name, prop_value in schema["properties"].items(): + # Process the property and add to flattened properties + is_required = prop_name in required_fields + + # If the property already has nested properties (expanded), preserve them + if "properties" in prop_value: + # This is an expanded nested schema, preserve its structure + processed_prop = { + "type": prop_value.get("type", "object"), + "description": prop_value.get("description", ""), + "properties": {}, + } + + # Process each nested property + for nested_prop_name, nested_prop_value in prop_value["properties"].items(): + processed_prop["properties"][nested_prop_name] = nested_prop_value + + # Copy required fields if present + if "required" in prop_value: + processed_prop["required"] = prop_value["required"] + else: + # Process as normal + processed_prop = _process_property(prop_value, schema.get("$defs", {}), is_required) + + flattened["properties"][prop_name] = processed_prop + + # Track which properties are actually required after processing + if is_required and "null" not in str(processed_prop.get("type", "")): + required_props.append(prop_name) + + # Add required fields if any (only those that are truly required 
after processing)
+    # An empty required list after processing indicates a circular reference, which is not supported
+
+    if len(required_props) > 0:
+        flattened["required"] = required_props
+    else:
+        raise ValueError("Circular reference detected and not supported")
+
+    return flattened
+
+
+def _process_property(
+    prop: Dict[str, Any],
+    defs: Dict[str, Any],
+    is_required: bool = False,
+    fully_expand: bool = True,
+) -> Dict[str, Any]:
+    """Process a property in a schema, resolving any references.
+
+    Args:
+        prop: The property to process
+        defs: The definitions dictionary for resolving references
+        is_required: Whether this property is required
+        fully_expand: Whether to fully expand nested properties
+
+    Returns:
+        Processed property
+    """
+    result = {}
+    is_nullable = False
+
+    # Handle anyOf for optional fields (like Optional[Type])
+    if "anyOf" in prop:
+        # Check if this is an Optional[...] case (one null, one type)
+        null_type = False
+        non_null_type = None
+
+        for option in prop["anyOf"]:
+            if option.get("type") == "null":
+                null_type = True
+                is_nullable = True
+            elif "$ref" in option:
+                ref_path = option["$ref"].split("/")[-1]
+                if ref_path in defs:
+                    non_null_type = _process_schema_object(defs[ref_path], defs, fully_expand)
+                else:
+                    # Fail fast when a reference cannot be resolved
+                    raise ValueError(f"Missing reference: {ref_path}")
+            else:
+                non_null_type = option
+
+        if null_type and non_null_type:
+            # For Optional fields, we mark as nullable but copy all properties from the non-null option
+            result = non_null_type.copy() if isinstance(non_null_type, dict) else {}
+
+            # For type, ensure it includes "null"
+            if "type" in result and isinstance(result["type"], str):
+                result["type"] = [result["type"], "null"]
+            elif "type" in result and isinstance(result["type"], list) and "null" not in result["type"]:
+                result["type"].append("null")
+            elif "type" not in result:
+                # Default to object type if not specified
+                result["type"] = ["object", "null"]
+
+            # Copy description if available in the property
+            if "description" in prop:
+                result["description"] = prop["description"]
+
+            return result
+
+    # Handle direct references
+    elif "$ref" in prop:
+        # Resolve reference
+        ref_path = prop["$ref"].split("/")[-1]
+        if ref_path in defs:
+            ref_dict = defs[ref_path]
+            # Process the referenced object to get a complete schema
+            result = _process_schema_object(ref_dict, defs, fully_expand)
+        else:
+            # Fail fast when a reference cannot be resolved
+            raise ValueError(f"Missing reference: {ref_path}")
+
+    # For regular fields, copy all properties
+    for key, value in prop.items():
+        if key not in ["$ref", "anyOf"]:
+            if isinstance(value, dict):
+                result[key] = _process_nested_dict(value, defs)
+            elif key == "type" and not is_required and not is_nullable:
+                # For non-required fields, ensure type is a list with "null"
+                if isinstance(value, str):
+                    result[key] = [value, "null"]
+                elif isinstance(value, list) and "null" not in value:
+                    result[key] = value + ["null"]
+                else:
+                    result[key] = value
+            else:
+                result[key] = value
+
+    return result
+
+
+def _process_schema_object(
+    schema_obj: Dict[str, Any], defs: Dict[str, Any], fully_expand: bool = True
+) -> Dict[str, Any]:
+    """Process a schema object, typically from $defs, to resolve all nested properties.
+
+    Args:
+        schema_obj: The schema object to process
+        defs: The definitions dictionary for resolving references
+        fully_expand: Whether to fully expand nested properties
+
+    Returns:
+        Processed schema object with all properties resolved
+    """
+    result = {}
+
+    # Copy basic attributes
+    for key, value in schema_obj.items():
+        if key != "properties" and key != "required" and key != "$defs":
+            result[key] = value
+
+    # Process properties if present
+    if "properties" in schema_obj:
+        result["properties"] = {}
+        required_props = []
+
+        # Get required fields list
+        required_fields = schema_obj.get("required", [])
+
+        for prop_name, prop_value in schema_obj["properties"].items():
+            # Process each property
+            is_required = prop_name in required_fields
+            processed = _process_property(prop_value, defs, is_required, fully_expand)
+            result["properties"][prop_name] = processed
+
+            # Track which properties are actually required after processing
+            if is_required and "null" not in str(processed.get("type", "")):
+                required_props.append(prop_name)
+
+        # Add required fields if any
+        if required_props:
+            result["required"] = required_props
+
+    return result
+
+
+def _process_nested_dict(d: Dict[str, Any], defs: Dict[str, Any]) -> Dict[str, Any]:
+    """Recursively processes nested dictionaries and resolves $ref references.
+
+    Args:
+        d: The dictionary to process
+        defs: The definitions dictionary for resolving references
+
+    Returns:
+        Processed dictionary
+    """
+    result: Dict[str, Any] = {}
+
+    # Handle direct reference
+    if "$ref" in d:
+        ref_path = d["$ref"].split("/")[-1]
+        if ref_path in defs:
+            ref_dict = defs[ref_path]
+            # Recursively process the referenced object
+            return _process_schema_object(ref_dict, defs)
+        else:
+            # Fail fast when a reference cannot be resolved
+            raise ValueError(f"Missing reference: {ref_path}")
+
+    # Process each key-value pair
+    for key, value in d.items():
+        if key == "$ref":
+            # Already handled above
+            continue
+        elif isinstance(value, dict):
+            result[key] = _process_nested_dict(value, defs)
+        elif isinstance(value, list):
+            # Process lists (like for enum values)
+            result[key] = [_process_nested_dict(item, defs) if isinstance(item, dict) else item for item in value]
+        else:
+            result[key] = value
+
+    return result
+
+
+def convert_pydantic_to_tool_spec(
+    model: Type[BaseModel],
+    description: Optional[str] = None,
+) -> ToolSpec:
+    """Converts a Pydantic model to a tool description for the Amazon Bedrock Converse API.
+
+    Handles optional vs. required fields, resolves $refs, and uses docstrings.
+ + Args: + model: The Pydantic model class to convert + description: Optional description of the tool's purpose + + Returns: + ToolSpec: Dict containing the Bedrock tool specification + """ + name = model.__name__ + + # Get the JSON schema + input_schema = model.model_json_schema() + + # Get model docstring for description if not provided + model_description = description + if not model_description and model.__doc__: + model_description = model.__doc__.strip() + + # Process all referenced models to ensure proper docstrings + # This step is important for gathering descriptions from referenced models + _process_referenced_models(input_schema, model) + + # Now, let's fully expand the nested models with all their properties + _expand_nested_properties(input_schema, model) + + # Flatten the schema + flattened_schema = _flatten_schema(input_schema) + + final_schema = flattened_schema + + # Construct the tool specification + return ToolSpec( + name=name, + description=model_description or f"{name} structured output tool", + inputSchema={"json": final_schema}, + ) + + +def _expand_nested_properties(schema: Dict[str, Any], model: Type[BaseModel]) -> None: + """Expand the properties of nested models in the schema to include their full structure. + + This updates the schema in place. + + Args: + schema: The JSON schema to process + model: The Pydantic model class + """ + # First, process the properties at this level + if "properties" not in schema: + return + + # Create a modified copy of the properties to avoid modifying while iterating + for prop_name, prop_info in list(schema["properties"].items()): + field = model.model_fields.get(prop_name) + if not field: + continue + + field_type = field.annotation + + # Handle Optional types + is_optional = False + if ( + field_type is not None + and hasattr(field_type, "__origin__") + and field_type.__origin__ is Union + and hasattr(field_type, "__args__") + ): + # Look for Optional[BaseModel] + for arg in field_type.__args__: + if arg is type(None): + is_optional = True + elif isinstance(arg, type) and issubclass(arg, BaseModel): + field_type = arg + + # If this is a BaseModel field, expand its properties with full details + if isinstance(field_type, type) and issubclass(field_type, BaseModel): + # Get the nested model's schema with all its properties + nested_model_schema = field_type.model_json_schema() + + # Create a properly expanded nested object + expanded_object = { + "type": ["object", "null"] if is_optional else "object", + "description": prop_info.get("description", field.description or f"The {prop_name}"), + "properties": {}, + } + + # Copy all properties from the nested schema + if "properties" in nested_model_schema: + expanded_object["properties"] = nested_model_schema["properties"] + + # Copy required fields + if "required" in nested_model_schema: + expanded_object["required"] = nested_model_schema["required"] + + # Replace the original property with this expanded version + schema["properties"][prop_name] = expanded_object + + +def _process_referenced_models(schema: Dict[str, Any], model: Type[BaseModel]) -> None: + """Process referenced models to ensure their docstrings are included. + + This updates the schema in place. 
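+
+    It walks $defs, copying each referenced model's docstring into that definition's
+    description when one is not already set, and then processes field descriptions.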
+
+    Args:
+        schema: The JSON schema to process
+        model: The Pydantic model class
+    """
+    # Process $defs to add docstrings from the referenced models
+    if "$defs" in schema:
+        # Look through model fields to find referenced models
+        for _, field in model.model_fields.items():
+            field_type = field.annotation
+
+            # Handle Optional types - with null checks
+            if field_type is not None and hasattr(field_type, "__origin__"):
+                origin = field_type.__origin__
+                if origin is Union and hasattr(field_type, "__args__"):
+                    # Find the non-None type in the Union (for Optional fields)
+                    for arg in field_type.__args__:
+                        if arg is not type(None):
+                            field_type = arg
+                            break
+
+            # Check if this is a BaseModel subclass
+            if isinstance(field_type, type) and issubclass(field_type, BaseModel):
+                # Update $defs with this model's information
+                ref_name = field_type.__name__
+                if ref_name in schema.get("$defs", {}):
+                    ref_def = schema["$defs"][ref_name]
+
+                    # Add docstring as description if available
+                    if field_type.__doc__ and not ref_def.get("description"):
+                        ref_def["description"] = field_type.__doc__.strip()
+
+                    # Recursively process properties in the referenced model
+                    _process_properties(ref_def, field_type)
+
+
+def _process_properties(schema_def: Dict[str, Any], model: Type[BaseModel]) -> None:
+    """Process properties in a schema definition to add descriptions from field metadata.
+
+    Args:
+        schema_def: The schema definition to update
+        model: The model class that defines the schema
+    """
+    if "properties" in schema_def:
+        for prop_name, prop_info in schema_def["properties"].items():
+            field = model.model_fields.get(prop_name)
+
+            # Add field description if available and not already set
+            if field and field.description and not prop_info.get("description"):
+                prop_info["description"] = field.description
diff --git a/src/strands/types/models/model.py b/src/strands/types/models/model.py
index 23e74602..071c8a51 100644
--- a/src/strands/types/models/model.py
+++ b/src/strands/types/models/model.py
@@ -2,7 +2,9 @@
 
 import abc
 import logging
-from typing import Any, Iterable, Optional
+from typing import Any, Callable, Iterable, Optional, Type, TypeVar
+
+from pydantic import BaseModel
 
 from ..content import Messages
 from ..streaming import StreamEvent
@@ -10,6 +12,8 @@
 
 logger = logging.getLogger(__name__)
 
+T = TypeVar("T", bound=BaseModel)
+
 
 class Model(abc.ABC):
     """Abstract base class for AI model implementations.
@@ -38,6 +42,26 @@ def get_config(self) -> Any:
         """
         pass
 
+    @abc.abstractmethod
+    # pragma: no cover
+    def structured_output(
+        self, output_model: Type[T], prompt: Messages, callback_handler: Optional[Callable] = None
+    ) -> T:
+        """Get structured output from the model.
+
+        Args:
+            output_model(Type[BaseModel]): The output model to use for the agent.
+            prompt(Messages): The prompt messages to use for the agent.
+            callback_handler(Optional[Callable]): Optional callback handler for processing events. Defaults to None.
+
+        Returns:
+            The structured output as an instance of the output model.
+ + Raises: + ValidationException: The response format from the model does not match the output_model + """ + pass + @abc.abstractmethod # pragma: no cover def format_request( diff --git a/src/strands/types/models/openai.py b/src/strands/types/models/openai.py index 96f758d5..a6bd93cc 100644 --- a/src/strands/types/models/openai.py +++ b/src/strands/types/models/openai.py @@ -11,8 +11,9 @@ import json import logging import mimetypes -from typing import Any, Optional, cast +from typing import Any, Callable, Optional, Type, TypeVar, cast +from pydantic import BaseModel from typing_extensions import override from ..content import ContentBlock, Messages @@ -22,6 +23,8 @@ logger = logging.getLogger(__name__) +T = TypeVar("T", bound=BaseModel) + class OpenAIModel(Model, abc.ABC): """Base OpenAI model provider implementation. @@ -57,7 +60,17 @@ def format_request_message_content(cls, content: ContentBlock) -> dict[str, Any] if "image" in content: mime_type = mimetypes.types_map.get(f".{content['image']['format']}", "application/octet-stream") - image_data = content["image"]["source"]["bytes"].decode("utf-8") + image_bytes = content["image"]["source"]["bytes"] + try: + base64.b64decode(image_bytes, validate=True) + logger.warning( + "issue=<%s> | base64 encoded images will not be accepted in a future version", + "https://github.com/strands-agents/sdk-python/issues/252", + ) + except ValueError: + image_bytes = base64.b64encode(image_bytes) + + image_data = image_bytes.decode("utf-8") return { "image_url": { "detail": "auto", @@ -262,3 +275,16 @@ def format_chunk(self, event: dict[str, Any]) -> StreamEvent: case _: raise RuntimeError(f"chunk_type=<{event['chunk_type']} | unknown type") + + @override + def structured_output( + self, output_model: Type[T], prompt: Messages, callback_handler: Optional[Callable] = None + ) -> T: + """Get structured output from the model. + + Args: + output_model(Type[BaseModel]): The output model to use for the agent. + prompt(Messages): The prompt to use for the agent. + callback_handler(Optional[Callable]): Optional callback handler for processing events. Defaults to None. 
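+
+        Note:
+            This base implementation is a stub that returns a default-constructed instance;
+            concrete providers (e.g., strands.models.openai.OpenAIModel) override it.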
+ """ + return output_model() diff --git a/tests-integ/test_model_anthropic.py b/tests-integ/test_model_anthropic.py index 1b0412c9..95bfceb5 100644 --- a/tests-integ/test_model_anthropic.py +++ b/tests-integ/test_model_anthropic.py @@ -1,6 +1,7 @@ import os import pytest +from pydantic import BaseModel import strands from strands import Agent @@ -47,3 +48,16 @@ def test_agent(agent): text = result.message["content"][0]["text"].lower() assert all(string in text for string in ["12:00", "sunny", "&"]) + + +@pytest.mark.skipif("ANTHROPIC_API_KEY" not in os.environ, reason="ANTHROPIC_API_KEY environment variable missing") +def test_structured_output(model): + class Weather(BaseModel): + time: str + weather: str + + agent = Agent(model=model) + result = agent.structured_output(Weather, "The time is 12:00 and the weather is sunny") + assert isinstance(result, Weather) + assert result.time == "12:00" + assert result.weather == "sunny" diff --git a/tests-integ/test_model_bedrock.py b/tests-integ/test_model_bedrock.py index a6a29aa9..5378a9b2 100644 --- a/tests-integ/test_model_bedrock.py +++ b/tests-integ/test_model_bedrock.py @@ -1,4 +1,5 @@ import pytest +from pydantic import BaseModel import strands from strands import Agent @@ -118,3 +119,33 @@ def calculator(expression: str) -> float: agent("What is 123 + 456?") assert tool_was_called + + +def test_structured_output_streaming(streaming_model): + """Test structured output with streaming model.""" + + class Weather(BaseModel): + time: str + weather: str + + agent = Agent(model=streaming_model) + + result = agent.structured_output(Weather, "The time is 12:00 and the weather is sunny") + assert isinstance(result, Weather) + assert result.time == "12:00" + assert result.weather == "sunny" + + +def test_structured_output_non_streaming(non_streaming_model): + """Test structured output with non-streaming model.""" + + class Weather(BaseModel): + time: str + weather: str + + agent = Agent(model=non_streaming_model) + + result = agent.structured_output(Weather, "The time is 12:00 and the weather is sunny") + assert isinstance(result, Weather) + assert result.time == "12:00" + assert result.weather == "sunny" diff --git a/tests-integ/test_model_litellm.py b/tests-integ/test_model_litellm.py index 86f6b42f..01a3e121 100644 --- a/tests-integ/test_model_litellm.py +++ b/tests-integ/test_model_litellm.py @@ -1,4 +1,5 @@ import pytest +from pydantic import BaseModel import strands from strands import Agent @@ -33,3 +34,16 @@ def test_agent(agent): text = result.message["content"][0]["text"].lower() assert all(string in text for string in ["12:00", "sunny"]) + + +def test_structured_output(model): + class Weather(BaseModel): + time: str + weather: str + + agent_no_tools = Agent(model=model) + + result = agent_no_tools.structured_output(Weather, "The time is 12:00 and the weather is sunny") + assert isinstance(result, Weather) + assert result.time == "12:00" + assert result.weather == "sunny" diff --git a/tests-integ/test_model_ollama.py b/tests-integ/test_model_ollama.py new file mode 100644 index 00000000..38b46821 --- /dev/null +++ b/tests-integ/test_model_ollama.py @@ -0,0 +1,47 @@ +import pytest +import requests +from pydantic import BaseModel + +from strands import Agent +from strands.models.ollama import OllamaModel + + +def is_server_available() -> bool: + try: + return requests.get("http://localhost:11434").ok + except requests.exceptions.ConnectionError: + return False + + +@pytest.fixture +def model(): + return 
OllamaModel(host="http://localhost:11434", model_id="llama3.3:70b") + + +@pytest.fixture +def agent(model): + return Agent(model=model) + + +@pytest.mark.skipif(not is_server_available(), reason="Local Ollama endpoint not available at localhost:11434") +def test_agent(agent): + result = agent("Say 'hello world' with no other text") + assert isinstance(result.message["content"][0]["text"], str) + + +@pytest.mark.skipif(not is_server_available(), reason="Local Ollama endpoint not available at localhost:11434") +def test_structured_output(agent): + class Weather(BaseModel): + """Extract the time and weather. + + Time format: HH:MM + Weather: sunny, cloudy, rainy, etc. + """ + + time: str + weather: str + + result = agent.structured_output(Weather, "The time is 12:00 and the weather is sunny") + assert isinstance(result, Weather) + assert result.time == "12:00" + assert result.weather == "sunny" diff --git a/tests-integ/test_model_openai.py b/tests-integ/test_model_openai.py index c9046ad5..b0790ba0 100644 --- a/tests-integ/test_model_openai.py +++ b/tests-integ/test_model_openai.py @@ -1,6 +1,7 @@ import os import pytest +from pydantic import BaseModel import strands from strands import Agent @@ -44,3 +45,22 @@ def test_agent(agent): text = result.message["content"][0]["text"].lower() assert all(string in text for string in ["12:00", "sunny"]) + + +@pytest.mark.skipif( + "OPENAI_API_KEY" not in os.environ, + reason="OPENAI_API_KEY environment variable missing", +) +def test_structured_output(model): + class Weather(BaseModel): + """Extracts the time and weather from the user's message with the exact strings.""" + + time: str + weather: str + + agent = Agent(model=model) + + result = agent.structured_output(Weather, "The time is 12:00 and the weather is sunny") + assert isinstance(result, Weather) + assert result.time == "12:00" + assert result.weather == "sunny" diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index d6f47be0..85d17544 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -7,6 +7,7 @@ from time import sleep import pytest +from pydantic import BaseModel import strands from strands import Agent @@ -793,6 +794,31 @@ def test_agent_callback_handler_custom_handler_used(): assert agent.callback_handler is custom_handler +# mock the User(name='Jane Doe', age=30, email='jane@doe.com') +class User(BaseModel): + """A user of the system.""" + + name: str + age: int + email: str + + +def test_agent_method_structured_output(agent): + # Mock the structured_output method on the model + expected_user = User(name="Jane Doe", age=30, email="jane@doe.com") + agent.model.structured_output = unittest.mock.Mock(return_value=expected_user) + + prompt = "Jane Doe is 30 years old and her email is jane@doe.com" + + result = agent.structured_output(User, prompt) + assert result == expected_user + + # Verify the model's structured_output was called with correct arguments + agent.model.structured_output.assert_called_once_with( + User, [{"role": "user", "content": [{"text": prompt}]}], agent.callback_handler + ) + + @pytest.mark.asyncio async def test_stream_async_returns_all_events(mock_event_loop_cycle): agent = Agent() diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py index 734457aa..efdf7af8 100644 --- a/tests/strands/event_loop/test_event_loop.py +++ b/tests/strands/event_loop/test_event_loop.py @@ -111,27 +111,6 @@ def mock_tracer(): return tracer -@pytest.mark.parametrize( - 
("kwargs", "exp_state"), - [ - ( - {"request_state": {"key1": "value1"}}, - {"key1": "value1"}, - ), - ( - {}, - {}, - ), - ], -) -def test_initialize_state(kwargs, exp_state): - kwargs = strands.event_loop.event_loop.initialize_state(**kwargs) - - tru_state = kwargs["request_state"] - - assert tru_state == exp_state - - def test_event_loop_cycle_text_response( model, model_id, @@ -465,19 +444,6 @@ def test_event_loop_cycle_stop( assert tru_stop_reason == exp_stop_reason and tru_message == exp_message and tru_request_state == exp_request_state -def test_prepare_next_cycle(): - kwargs = {"event_loop_cycle_id": "c1"} - event_loop_metrics = strands.telemetry.metrics.EventLoopMetrics() - tru_result = strands.event_loop.event_loop.prepare_next_cycle(kwargs, event_loop_metrics) - exp_result = { - "event_loop_cycle_id": "c1", - "event_loop_parent_cycle_id": "c1", - "event_loop_metrics": event_loop_metrics, - } - - assert tru_result == exp_result - - def test_cycle_exception( model, system_prompt, @@ -733,3 +699,76 @@ def test_event_loop_cycle_with_parent_span( mock_tracer.start_event_loop_cycle_span.assert_called_once_with( event_loop_kwargs=unittest.mock.ANY, parent_span=parent_span, messages=messages ) + + +def test_request_state_initialization(): + # Call without providing request_state + tru_stop_reason, tru_message, _, tru_request_state = strands.event_loop.event_loop.event_loop_cycle( + model=MagicMock(), + model_id=MagicMock(), + system_prompt=MagicMock(), + messages=MagicMock(), + tool_config=MagicMock(), + callback_handler=MagicMock(), + tool_handler=MagicMock(), + tool_execution_handler=MagicMock(), + ) + + # Verify request_state was initialized to empty dict + assert tru_request_state == {} + + # Call with pre-existing request_state + initial_request_state = {"key": "value"} + tru_stop_reason, tru_message, _, tru_request_state = strands.event_loop.event_loop.event_loop_cycle( + model=MagicMock(), + model_id=MagicMock(), + system_prompt=MagicMock(), + messages=MagicMock(), + tool_config=MagicMock(), + callback_handler=MagicMock(), + tool_handler=MagicMock(), + request_state=initial_request_state, + ) + + # Verify existing request_state was preserved + assert tru_request_state == initial_request_state + + +def test_prepare_next_cycle_in_tool_execution(model, tool_stream): + """Test that cycle ID and metrics are properly updated during tool execution.""" + model.converse.side_effect = [ + tool_stream, + [ + {"contentBlockStop": {}}, + ], + ] + + # Create a mock for recurse_event_loop to capture the kwargs passed to it + with unittest.mock.patch.object(strands.event_loop.event_loop, "recurse_event_loop") as mock_recurse: + # Set up mock to return a valid response + mock_recurse.return_value = ( + "end_turn", + {"role": "assistant", "content": [{"text": "test text"}]}, + strands.telemetry.metrics.EventLoopMetrics(), + {}, + ) + + # Call event_loop_cycle which should execute a tool and then call recurse_event_loop + strands.event_loop.event_loop.event_loop_cycle( + model=model, + model_id=MagicMock(), + system_prompt=MagicMock(), + messages=MagicMock(), + tool_config=MagicMock(), + callback_handler=MagicMock(), + tool_handler=MagicMock(), + tool_execution_handler=MagicMock(), + ) + + assert mock_recurse.called + + # Verify required properties are present + recursive_kwargs = mock_recurse.call_args[1] + assert "event_loop_metrics" in recursive_kwargs + assert "event_loop_parent_cycle_id" in recursive_kwargs + assert recursive_kwargs["event_loop_parent_cycle_id"] == 
recursive_kwargs["event_loop_cycle_id"] diff --git a/tests/strands/models/test_anthropic.py b/tests/strands/models/test_anthropic.py index 9421650e..a0cfc4d4 100644 --- a/tests/strands/models/test_anthropic.py +++ b/tests/strands/models/test_anthropic.py @@ -615,10 +615,24 @@ def test_format_chunk_unknown(model): def test_stream(anthropic_client, model): - mock_event_1 = unittest.mock.Mock(type="message_start", dict=lambda: {"type": "message_start"}) - mock_event_2 = unittest.mock.Mock(type="unknown") + mock_event_1 = unittest.mock.Mock( + type="message_start", + dict=lambda: {"type": "message_start"}, + model_dump=lambda: {"type": "message_start"}, + ) + mock_event_2 = unittest.mock.Mock( + type="unknown", + dict=lambda: {"type": "unknown"}, + model_dump=lambda: {"type": "unknown"}, + ) mock_event_3 = unittest.mock.Mock( - type="metadata", message=unittest.mock.Mock(usage=unittest.mock.Mock(dict=lambda: {"input_tokens": 1})) + type="metadata", + message=unittest.mock.Mock( + usage=unittest.mock.Mock( + dict=lambda: {"input_tokens": 1, "output_tokens": 2}, + model_dump=lambda: {"input_tokens": 1, "output_tokens": 2}, + ) + ), ) mock_stream = unittest.mock.MagicMock() @@ -631,7 +645,10 @@ def test_stream(anthropic_client, model): tru_events = list(response) exp_events = [ {"type": "message_start"}, - {"type": "metadata", "usage": {"input_tokens": 1}}, + { + "type": "metadata", + "usage": {"input_tokens": 1, "output_tokens": 2}, + }, ] assert tru_events == exp_events diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index b326eee7..137b57c8 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -91,6 +91,23 @@ def test__init__default_model_id(bedrock_client): assert tru_model_id == exp_model_id +def test__init__with_default_region(bedrock_client): + """Test that BedrockModel uses the provided region.""" + _ = bedrock_client + default_region = "us-west-2" + + with unittest.mock.patch("strands.models.bedrock.boto3.Session") as mock_session_cls: + with unittest.mock.patch("strands.models.bedrock.logger.warning") as mock_warning: + _ = BedrockModel() + mock_session_cls.assert_called_once_with(region_name=default_region) + # Assert that warning logs are emitted + mock_warning.assert_any_call("defaulted to us-west-2 because no region was specified") + mock_warning.assert_any_call( + "issue=<%s> | this behavior will change in an upcoming release", + "https://github.com/strands-agents/sdk-python/issues/238", + ) + + def test__init__with_custom_region(bedrock_client): """Test that BedrockModel uses the provided region.""" _ = bedrock_client diff --git a/tests/strands/telemetry/test_metrics.py b/tests/strands/telemetry/test_metrics.py index 4e84f0fd..cafbd4bb 100644 --- a/tests/strands/telemetry/test_metrics.py +++ b/tests/strands/telemetry/test_metrics.py @@ -1,9 +1,13 @@ import dataclasses import unittest +from unittest import mock import pytest +from opentelemetry.metrics._internal import _ProxyMeter +from opentelemetry.sdk.metrics import MeterProvider import strands +from strands.telemetry import MetricsClient from strands.types.streaming import Metrics, Usage @@ -117,6 +121,30 @@ def test_trace_end(mock_time, end_time, trace): assert tru_end_time == exp_end_time +@pytest.fixture +def mock_get_meter_provider(): + with mock.patch("strands.telemetry.metrics.metrics_api.get_meter_provider") as mock_get_meter_provider: + meter_provider_mock = mock.MagicMock(spec=MeterProvider) + mock_get_meter_provider.return_value = 
diff --git a/tests/strands/telemetry/test_tracer.py b/tests/strands/telemetry/test_tracer.py
index 030dcd37..6ae3e1ad 100644
--- a/tests/strands/telemetry/test_tracer.py
+++ b/tests/strands/telemetry/test_tracer.py
@@ -73,7 +73,9 @@ def mock_console_exporter():
 
 @pytest.fixture
 def mock_resource():
-    with mock.patch("strands.telemetry.tracer.Resource") as mock_resource:
+    with mock.patch("strands.telemetry.tracer.get_otel_resource") as mock_resource:
+        mock_resource_instance = mock.MagicMock()
+        mock_resource.return_value = mock_resource_instance
         yield mock_resource
 
 
@@ -175,14 +177,12 @@ def test_initialize_tracer_with_console(
 ):
     """Test initializing the tracer with console exporter."""
     mock_is_initialized.return_value = False
-    mock_resource_instance = mock.MagicMock()
-    mock_resource.create.return_value = mock_resource_instance
 
     # Initialize Tracer
     Tracer(enable_console_export=True)
 
     # Verify the tracer provider was created with correct resource
-    mock_tracer_provider.assert_called_once_with(resource=mock_resource_instance)
+    mock_tracer_provider.assert_called_once_with(resource=mock_resource.return_value)
 
     # Verify console exporter was added
     mock_console_exporter.assert_called_once()
@@ -198,9 +198,6 @@ def test_initialize_tracer_with_otlp(
     """Test initializing the tracer with OTLP exporter."""
     mock_is_initialized.return_value = False
 
-    mock_resource_instance = mock.MagicMock()
-    mock_resource.create.return_value = mock_resource_instance
-
     # Initialize Tracer
     with (
         mock.patch("strands.telemetry.tracer.HAS_OTEL_EXPORTER_MODULE", True),
@@ -209,7 +206,7 @@ def test_initialize_tracer_with_otlp(
         Tracer(otlp_endpoint="http://test-endpoint")
 
         # Verify the tracer provider was created with correct resource
-        mock_tracer_provider.assert_called_once_with(resource=mock_resource_instance)
+        mock_tracer_provider.assert_called_once_with(resource=mock_resource.return_value)
 
         # Verify OTLP exporter was added with correct endpoint
         mock_otlp_exporter.assert_called_once()
@@ -508,8 +505,6 @@ def test_initialize_tracer_with_invalid_otlp_endpoint(
     """Test initializing the tracer with an invalid OTLP endpoint."""
     mock_is_initialized.return_value = False
-    mock_resource_instance = mock.MagicMock()
-    mock_resource.create.return_value = mock_resource_instance
     mock_otlp_exporter.side_effect = Exception("Connection error")
 
     # This should not raise an exception, but should log an error
@@ -522,7 +517,7 @@ def test_initialize_tracer_with_invalid_otlp_endpoint(
         Tracer(otlp_endpoint="http://invalid-endpoint")
 
     # Verify the tracer provider was created with correct resource
-    mock_tracer_provider.assert_called_once_with(resource=mock_resource_instance)
+    mock_tracer_provider.assert_called_once_with(resource=mock_resource.return_value)
 
     # Verify OTLP exporter was attempted
     mock_otlp_exporter.assert_called_once()
@@ -537,9 +532,6 @@ def test_initialize_tracer_with_missing_module(
     """Test initializing the tracer when the OTLP exporter module is missing."""
     mock_is_initialized.return_value = False
 
-    mock_resource_instance = mock.MagicMock()
-    mock_resource.create.return_value = mock_resource_instance
-
     # Initialize Tracer with OTLP endpoint but missing module
     with (
         mock.patch("strands.telemetry.tracer.HAS_OTEL_EXPORTER_MODULE", False),
@@ -552,13 +544,13 @@ def test_initialize_tracer_with_missing_module(
     assert "otel http exporting is currently DISABLED" in str(excinfo.value)
 
     # Verify the tracer provider was created with correct resource
-    mock_tracer_provider.assert_called_once_with(resource=mock_resource_instance)
+    mock_tracer_provider.assert_called_once_with(resource=mock_resource.return_value)
 
     # Verify set_tracer_provider was not called since an exception was raised
     mock_set_tracer_provider.assert_not_called()
 
 
-def test_initialize_tracer_with_custom_tracer_provider(mock_get_tracer_provider, mock_resource):
+def test_initialize_tracer_with_custom_tracer_provider(mock_is_initialized, mock_get_tracer_provider, mock_resource):
     """Test initializing the tracer with NoOpTracerProvider."""
     mock_is_initialized.return_value = True
     tracer = Tracer(otlp_endpoint="http://invalid-endpoint")
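The tracer tests now patch strands.telemetry.tracer.get_otel_resource instead of Resource directly, which suggests resource construction was factored into a shared helper used by both tracing and metrics setup. A plausible sketch under that assumption (the attributes passed to Resource.create are illustrative, not the SDK's actual values):

from opentelemetry.sdk.resources import Resource


def get_otel_resource() -> Resource:
    # A single place to build the Resource keeps tracer and metrics consistent.
    return Resource.create({"service.name": "strands-agents"})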
diff --git a/tests/strands/tools/test_structured_output.py b/tests/strands/tools/test_structured_output.py
new file mode 100644
index 00000000..2e354b83
--- /dev/null
+++ b/tests/strands/tools/test_structured_output.py
@@ -0,0 +1,228 @@
+from typing import Literal, Optional
+
+import pytest
+from pydantic import BaseModel, Field
+
+from strands.tools.structured_output import convert_pydantic_to_tool_spec
+from strands.types.tools import ToolSpec
+
+
+# Basic test model
+class User(BaseModel):
+    """User model with name and age."""
+
+    name: str = Field(description="The name of the user")
+    age: int = Field(description="The age of the user", ge=18, le=100)
+
+
+# Test model with inheritance and literals
+class UserWithPlanet(User):
+    """User with planet."""
+
+    planet: Literal["Earth", "Mars"] = Field(description="The planet")
+
+
+# Test model with two fields of the same type, one of them optional
+class TwoUsersWithPlanet(BaseModel):
+    """Two users model with planet."""
+
+    user1: UserWithPlanet = Field(description="The first user")
+    user2: Optional[UserWithPlanet] = Field(description="The second user", default=None)
+
+
+# Test model with a list of same-type fields
+class ListOfUsersWithPlanet(BaseModel):
+    """List of users model with planet."""
+
+    users: list[UserWithPlanet] = Field(description="The users", min_length=2, max_length=3)
+
+
+def test_convert_pydantic_to_tool_spec_basic():
+    tool_spec = convert_pydantic_to_tool_spec(User)
+
+    expected_spec = {
+        "name": "User",
+        "description": "User model with name and age.",
+        "inputSchema": {
+            "json": {
+                "type": "object",
+                "properties": {
+                    "name": {"description": "The name of the user", "title": "Name", "type": "string"},
+                    "age": {
+                        "description": "The age of the user",
+                        "maximum": 100,
+                        "minimum": 18,
+                        "title": "Age",
+                        "type": "integer",
+                    },
+                },
+                "title": "User",
+                "description": "User model with name and age.",
+                "required": ["name", "age"],
+            }
+        },
+    }
+
+    # Verify we can construct a valid ToolSpec
+    tool_spec_obj = ToolSpec(**tool_spec)
+    assert tool_spec_obj is not None
+    assert tool_spec == expected_spec
+
+
+def test_convert_pydantic_to_tool_spec_complex():
+    tool_spec = convert_pydantic_to_tool_spec(ListOfUsersWithPlanet)
+
+    expected_spec = {
+        "name": "ListOfUsersWithPlanet",
+        "description": "List of users model with planet.",
+        "inputSchema": {
+            "json": {
+                "type": "object",
+                "properties": {
+                    "users": {
+                        "description": "The users",
+                        "items": {
+                            "description": "User with planet.",
+                            "title": "UserWithPlanet",
+                            "type": "object",
+                            "properties": {
+                                "name": {"description": "The name of the user", "title": "Name", "type": "string"},
+                                "age": {
+                                    "description": "The age of the user",
+                                    "maximum": 100,
+                                    "minimum": 18,
+                                    "title": "Age",
+                                    "type": "integer",
+                                },
+                                "planet": {
+                                    "description": "The planet",
+                                    "enum": ["Earth", "Mars"],
+                                    "title": "Planet",
+                                    "type": "string",
+                                },
+                            },
+                            "required": ["name", "age", "planet"],
+                        },
+                        "maxItems": 3,
+                        "minItems": 2,
+                        "title": "Users",
+                        "type": "array",
+                    }
+                },
+                "title": "ListOfUsersWithPlanet",
+                "description": "List of users model with planet.",
+                "required": ["users"],
+            }
+        },
+    }
+
+    assert tool_spec == expected_spec
+
+    # Verify we can construct a valid ToolSpec
+    tool_spec_obj = ToolSpec(**tool_spec)
+    assert tool_spec_obj is not None
+
+
+def test_convert_pydantic_to_tool_spec_multiple_same_type():
+    tool_spec = convert_pydantic_to_tool_spec(TwoUsersWithPlanet)
+
+    expected_spec = {
+        "name": "TwoUsersWithPlanet",
+        "description": "Two users model with planet.",
+        "inputSchema": {
+            "json": {
+                "type": "object",
+                "properties": {
+                    "user1": {
+                        "type": "object",
+                        "description": "The first user",
+                        "properties": {
+                            "name": {"description": "The name of the user", "title": "Name", "type": "string"},
+                            "age": {
+                                "description": "The age of the user",
+                                "maximum": 100,
+                                "minimum": 18,
+                                "title": "Age",
+                                "type": "integer",
+                            },
+                            "planet": {
+                                "description": "The planet",
+                                "enum": ["Earth", "Mars"],
+                                "title": "Planet",
+                                "type": "string",
+                            },
+                        },
+                        "required": ["name", "age", "planet"],
+                    },
+                    "user2": {
+                        "type": ["object", "null"],
+                        "description": "The second user",
+                        "properties": {
+                            "name": {"description": "The name of the user", "title": "Name", "type": "string"},
+                            "age": {
+                                "description": "The age of the user",
+                                "maximum": 100,
+                                "minimum": 18,
+                                "title": "Age",
+                                "type": "integer",
+                            },
+                            "planet": {
+                                "description": "The planet",
+                                "enum": ["Earth", "Mars"],
+                                "title": "Planet",
+                                "type": "string",
+                            },
+                        },
+                        "required": ["name", "age", "planet"],
+                    },
+                },
+                "title": "TwoUsersWithPlanet",
+                "description": "Two users model with planet.",
+                "required": ["user1"],
+            }
+        },
+    }
+
+    assert tool_spec == expected_spec
+
+    # Verify we can construct a valid ToolSpec
+    tool_spec_obj = ToolSpec(**tool_spec)
+    assert tool_spec_obj is not None
+
+
+def test_convert_pydantic_with_missing_refs():
+    """Test that the tool handles missing $refs gracefully."""
+    # This test checks that our error handling for missing $refs works correctly
+    # by testing with a model that has circular references
+
+    class NodeWithCircularRef(BaseModel):
+        """A node with a circular reference to itself."""
+
+        name: str = Field(description="The name of the node")
+        parent: Optional["NodeWithCircularRef"] = Field(None, description="Parent node")
+        children: list["NodeWithCircularRef"] = Field(default_factory=list, description="Child nodes")
+
+    # This forward reference normally causes issues with schema generation,
+    # but our error handling should surface a clear error instead
+    with pytest.raises(ValueError, match="Circular reference detected and not supported"):
+        convert_pydantic_to_tool_spec(NodeWithCircularRef)
+
+
+def test_convert_pydantic_with_custom_description():
+    """Test that custom descriptions override model docstrings."""
+
+    # Test with a custom description
+    custom_description = "Custom tool description for user model"
+    tool_spec = convert_pydantic_to_tool_spec(User, description=custom_description)
+
+    assert tool_spec["description"] == custom_description
+
+
+def test_convert_pydantic_with_empty_docstring():
+    """Test that empty docstrings fall back to the default description."""
+
+    class EmptyDocUser(BaseModel):
+        name: str = Field(description="The name of the user")
+
+    tool_spec = convert_pydantic_to_tool_spec(EmptyDocUser)
+    assert tool_spec["description"] == "EmptyDocUser structured output tool"
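As a quick orientation for the new test module above, this is roughly how convert_pydantic_to_tool_spec is used end to end; the Weather model is illustrative only, but the name/description/inputSchema layout follows the expected_spec dictionaries in the tests:

from pydantic import BaseModel, Field

from strands.tools.structured_output import convert_pydantic_to_tool_spec


class Weather(BaseModel):
    """Current weather conditions."""

    city: str = Field(description="City name")
    temperature_c: float = Field(description="Temperature in Celsius")


tool_spec = convert_pydantic_to_tool_spec(Weather)
# The class name becomes the tool name and the docstring its description; the
# generated JSON schema lands under inputSchema["json"].
assert tool_spec["name"] == "Weather"
assert tool_spec["description"] == "Current weather conditions."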
diff --git a/tests/strands/types/models/test_model.py b/tests/strands/types/models/test_model.py
index f2797fe5..03690733 100644
--- a/tests/strands/types/models/test_model.py
+++ b/tests/strands/types/models/test_model.py
@@ -1,8 +1,16 @@
+from typing import Type
+
 import pytest
+from pydantic import BaseModel
 
 from strands.types.models import Model as SAModel
 
 
+class Person(BaseModel):
+    name: str
+    age: int
+
+
 class TestModel(SAModel):
     def update_config(self, **model_config):
         return model_config
@@ -10,6 +18,9 @@ def update_config(self, **model_config):
     def get_config(self):
         return
 
+    def structured_output(self, output_model: Type[BaseModel]) -> BaseModel:
+        return output_model(name="test", age=20)
+
     def format_request(self, messages, tool_specs, system_prompt):
         return {
             "messages": messages,
@@ -79,3 +90,9 @@ def test_converse(model, messages, tool_specs, system_prompt):
         },
     ]
     assert tru_events == exp_events
+
+
+def test_structured_output(model):
+    response = model.structured_output(Person)
+
+    assert response == Person(name="test", age=20)
diff --git a/tests/strands/types/models/test_openai.py b/tests/strands/types/models/test_openai.py
index 9db08bc9..2827969d 100644
--- a/tests/strands/types/models/test_openai.py
+++ b/tests/strands/types/models/test_openai.py
@@ -1,3 +1,4 @@
+import base64
 import unittest.mock
 
 import pytest
@@ -90,7 +91,24 @@ def system_prompt():
             "image_url": {
                 "detail": "auto",
                 "format": "image/jpeg",
-                "url": "data:image/jpeg;base64,image",
+                "url": "data:image/jpeg;base64,aW1hZ2U=",
             },
             "type": "image_url",
         },
     ),
+    # Image - base64 encoded
+    (
+        {
+            "image": {
+                "format": "jpg",
+                "source": {"bytes": base64.b64encode(b"image")},
+            },
+        },
+        {
+            "image_url": {
+                "detail": "auto",
+                "format": "image/jpeg",
+                "url": "data:image/jpeg;base64,aW1hZ2U=",
+            },
+            "type": "image_url",
+        },