feat: add structured output support using Pydantic models by theagenticguy · Pull Request #60 · strands-agents/sdk-python · GitHub

feat: add structured output support using Pydantic models #60


Merged. 24 commits were merged into main from feature/structured-output on Jun 19, 2025.
Changes shown are from 22 of 24 commits.

Commits (24)
e183907 feat: add structured output support using Pydantic models (theagenticguy, May 20, 2025)
03942ae fix: import cleanups and unused vars (theagenticguy, May 20, 2025)
19a580d Merge branch 'main' into feature/structured-output (theagenticguy, Jun 5, 2025)
510def6 feat: wip adding `structured_output` methods (theagenticguy, Jun 5, 2025)
c3ffbce feat: wip added structured output to bedrock and anthropic (theagenticguy, Jun 5, 2025)
0f03889 Merge branch 'strands-agents:main' into feature/structured-output (theagenticguy, Jun 5, 2025)
dce0a81 feat: litellm structured output and some integ tests (theagenticguy, Jun 7, 2025)
5262dfc feat: all structured outputs working, tbd llama api (theagenticguy, Jun 8, 2025)
2a1f5ed Merge branch 'strands-agents:main' into feature/structured-output (theagenticguy, Jun 8, 2025)
23df2c6 feat: updated docstring (theagenticguy, Jun 8, 2025)
cc78b6f fix: otel ci dep issue (theagenticguy, Jun 8, 2025)
e8ef600 fix: remove unnecessary changes and comments (theagenticguy, Jun 9, 2025)
6eeeaa8 feat: basic test WIP (theagenticguy, Jun 9, 2025)
51f1f1d feat: better test coverage (theagenticguy, Jun 9, 2025)
d5bef96 fix: remove unused fixture (theagenticguy, Jun 9, 2025)
c66fa32 fix: resolve some comments (theagenticguy, Jun 13, 2025)
422bc25 fix: inline basemodel classes (theagenticguy, Jun 13, 2025)
eabf075 feat: update litellm, add checks (theagenticguy, Jun 17, 2025)
7194d6c Merge branch 'main' into feature/structured-output (theagenticguy, Jun 17, 2025)
885d3ac fix: autoformatting issue (theagenticguy, Jun 17, 2025)
7308491 feat: resolves comments (theagenticguy, Jun 17, 2025)
a88c93b Merge branch 'main' into feature/structured-output (theagenticguy, Jun 17, 2025)
0216bcc fix: ollama skip tests, pyproject whitespace diffs (theagenticguy, Jun 18, 2025)
49ccfb5 Merge branch 'strands-agents:main' into feature/structured-output (theagenticguy, Jun 18, 2025)
pyproject.toml (125 changes: 68 additions & 57 deletions)

@@ -8,10 +8,8 @@ dynamic = ["version"]
 description = "A model-driven approach to building AI agents in just a few lines of code"
 readme = "README.md"
 requires-python = ">=3.10"
-license = {text = "Apache-2.0"}
-authors = [
-    {name = "AWS", email = "opensource@amazon.com"},
-]
+license = { text = "Apache-2.0" }
+authors = [{ name = "AWS", email = "opensource@amazon.com" }]
 classifiers = [
     "Development Status :: 3 - Alpha",
     "Intended Audience :: Developers",
@@ -46,9 +44,7 @@ Documentation = "https://strandsagents.com"
 packages = ["src/strands"]

 [project.optional-dependencies]
-anthropic = [
-    "anthropic>=0.21.0,<1.0.0",
-]
+anthropic = ["anthropic>=0.21.0,<1.0.0"]
 dev = [
     "commitizen>=4.4.0,<5.0.0",
     "hatch>=1.0.0,<2.0.0",
@@ -94,25 +90,16 @@ source = "vcs"
 [tool.hatch.envs.hatch-static-analysis]
 features = ["anthropic", "litellm", "llamaapi", "ollama", "openai", "otel"]
 dependencies = [
-  "mypy>=1.15.0,<2.0.0",
-  "ruff>=0.11.6,<0.12.0",
-  "strands-agents @ {root:uri}"
+    "mypy>=1.15.0,<2.0.0",
+    "ruff>=0.11.6,<0.12.0",
+    "strands-agents @ {root:uri}",
 ]

 [tool.hatch.envs.hatch-static-analysis.scripts]
-format-check = [
-    "ruff format --check"
-]
-format-fix = [
-    "ruff format"
-]
-lint-check = [
-    "ruff check",
-    "mypy -p src"
-]
-lint-fix = [
-    "ruff check --fix"
-]
+format-check = ["ruff format --check"]
+format-fix = ["ruff format"]
+lint-check = ["ruff check", "mypy -p src"]
+lint-fix = ["ruff check --fix"]

 [tool.hatch.envs.hatch-test]
 features = ["anthropic", "litellm", "llamaapi", "ollama", "openai", "otel"]
@@ -123,11 +110,7 @@ extra-dependencies = [
     "pytest-cov>=4.1.0,<5.0.0",
     "pytest-xdist>=3.0.0,<4.0.0",
 ]
-extra-args = [
-    "-n",
-    "auto",
-    "-vv",
-]
+extra-args = ["-n", "auto", "-vv"]

 [tool.hatch.envs.dev]
 dev-mode = true
@@ -137,17 +120,14 @@ features = ["dev", "docs", "anthropic", "litellm", "llamaapi", "ollama", "otel"]
 dev-mode = true
 features = ["dev", "docs", "anthropic", "litellm", "llamaapi", "ollama", "a2a"]

-
 [[tool.hatch.envs.hatch-test.matrix]]
 python = ["3.13", "3.12", "3.11", "3.10"]

-
 [tool.hatch.envs.hatch-test.scripts]
-run = [
-    "pytest{env:HATCH_TEST_ARGS:} {args}"
-]
+run = ["pytest{env:HATCH_TEST_ARGS:} {args}"]
 run-cov = [
-    "pytest{env:HATCH_TEST_ARGS:} --cov --cov-config=pyproject.toml {args}"
+    "pytest{env:HATCH_TEST_ARGS:} --cov --cov-config=pyproject.toml {args}",
 ]

 cov-combine = []
@@ -204,17 +184,22 @@ ignore_missing_imports = true

 [tool.ruff]
 line-length = 120
-include = ["examples/**/*.py", "src/**/*.py", "tests/**/*.py", "tests-integ/**/*.py"]
+include = [
+    "examples/**/*.py",
+    "src/**/*.py",
+    "tests/**/*.py",
+    "tests-integ/**/*.py",
+]

 [tool.ruff.lint]
 select = [
-  "B", # flake8-bugbear
-  "D", # pydocstyle
-  "E", # pycodestyle
-  "F", # pyflakes
-  "G", # logging format
-  "I", # isort
-  "LOG", # logging
+    "B",   # flake8-bugbear
+    "D",   # pydocstyle
+    "E",   # pycodestyle
+    "F",   # pyflakes
+    "G",   # logging format
+    "I",   # isort
+    "LOG", # logging
 ]

 [tool.ruff.lint.per-file-ignores]
@@ -224,9 +209,7 @@ select = [
 convention = "google"

 [tool.pytest.ini_options]
-testpaths = [
-    "tests"
-]
+testpaths = ["tests"]
 asyncio_default_fixture_loop_scope = "function"

 [tool.coverage.run]
@@ -249,19 +232,47 @@ output = "build/coverage/coverage.xml"
 name = "cz_conventional_commits"
 tag_format = "v$version"
 bump_message = "chore(release): bump version $current_version -> $new_version"
-version_files = [
-    "pyproject.toml:version",
-]
+version_files = ["pyproject.toml:version"]
 update_changelog_on_bump = true
 style = [
-    ["qmark", "fg:#ff9d00 bold"],
-    ["question", "bold"],
-    ["answer", "fg:#ff9d00 bold"],
-    ["pointer", "fg:#ff9d00 bold"],
-    ["highlighted", "fg:#ff9d00 bold"],
-    ["selected", "fg:#cc5454"],
-    ["separator", "fg:#cc5454"],
-    ["instruction", ""],
-    ["text", ""],
-    ["disabled", "fg:#858585 italic"]
+    [
+        "qmark",
+        "fg:#ff9d00 bold",
+    ],
+    [
+        "question",
+        "bold",
+    ],
+    [
+        "answer",
+        "fg:#ff9d00 bold",
+    ],
+    [
+        "pointer",
+        "fg:#ff9d00 bold",
+    ],
+    [
+        "highlighted",
+        "fg:#ff9d00 bold",
+    ],
+    [
+        "selected",
+        "fg:#cc5454",
+    ],
+    [
+        "separator",
+        "fg:#cc5454",
+    ],
+    [
+        "instruction",
+        "",
+    ],
+    [
+        "text",
+        "",
+    ],
+    [
+        "disabled",
+        "fg:#858585 italic",
+    ],
 ]
src/strands/agent/agent.py (32 changes: 31 additions & 1 deletion)

@@ -16,10 +16,11 @@
 import random
 from concurrent.futures import ThreadPoolExecutor
 from threading import Thread
-from typing import Any, AsyncIterator, Callable, Dict, List, Mapping, Optional, Union
+from typing import Any, AsyncIterator, Callable, Dict, List, Mapping, Optional, Type, TypeVar, Union
 from uuid import uuid4

 from opentelemetry import trace
+from pydantic import BaseModel

 from ..event_loop.event_loop import event_loop_cycle
 from ..handlers.callback_handler import CompositeCallbackHandler, PrintingCallbackHandler, null_callback_handler
@@ -43,6 +44,9 @@

 logger = logging.getLogger(__name__)

+# TypeVar for generic structured output
+T = TypeVar("T", bound=BaseModel)
+

 # Sentinel class and object to distinguish between explicit None and default parameter value
 class _DefaultCallbackHandlerSentinel:
@@ -387,6 +391,32 @@ def __call__(self, prompt: str, **kwargs: Any) -> AgentResult:
             # Re-raise the exception to preserve original behavior
             raise

+    def structured_output(self, output_model: Type[T], prompt: Optional[str] = None) -> T:
+        """This method allows you to get structured output from the agent.
+
+        If you pass in a prompt, it will be added to the conversation history and the agent will respond to it.
+        If you don't pass in a prompt, it will use only the conversation history to respond.
+        If no conversation history exists and no prompt is provided, an error will be raised.
+
+        For smaller models, you may want to use the optional prompt string to add additional instructions to explicitly
+        instruct the model to output the structured data.
+
+        Args:
+            output_model(Type[BaseModel]): The output model (a JSON schema written as a Pydantic BaseModel)
+                that the agent will use when responding.
+            prompt(Optional[str]): The prompt to use for the agent.
+        """
+        messages = self.messages
+        if not messages and not prompt:
+            raise ValueError("No conversation history or prompt provided")
+
+        # add the prompt as the last message
+        if prompt:
+            messages.append({"role": "user", "content": [{"text": prompt}]})
+
+        # get the structured output from the model
+        return self.model.structured_output(output_model, messages, self.callback_handler)
+
     async def stream_async(self, prompt: str, **kwargs: Any) -> AsyncIterator[Any]:
         """Process a natural language prompt and yield events as an async iterator.
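To make the new agent-level API concrete, here is a minimal usage sketch. The `PersonInfo` model, the bare `Agent()` construction, and the top-level `strands` import path are illustrative assumptions rather than code from this PR:

```python
from pydantic import BaseModel

from strands import Agent  # assumed import path for the Agent class


class PersonInfo(BaseModel):
    """Hypothetical schema describing the fields the model must return."""

    name: str
    age: int
    city: str


agent = Agent()  # assumes a default model configuration is available

# The prompt is appended to the agent's conversation history, and the model
# responds through a tool generated from the PersonInfo schema.
person = agent.structured_output(
    PersonInfo,
    "John Smith is a 30-year-old software engineer living in Seattle.",
)

print(person.name, person.age, person.city)  # a validated PersonInfo instance
```

Note that `structured_output` appends the optional prompt to `self.messages` in place, so the request also becomes part of the agent's ongoing conversation history.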
src/strands/models/anthropic.py (51 changes: 48 additions & 3 deletions)

@@ -7,11 +7,15 @@
 import json
 import logging
 import mimetypes
-from typing import Any, Iterable, Optional, TypedDict, cast
+from typing import Any, Callable, Iterable, Optional, Type, TypedDict, TypeVar, cast

 import anthropic
+from pydantic import BaseModel
 from typing_extensions import Required, Unpack, override

+from ..event_loop.streaming import process_stream
+from ..handlers.callback_handler import PrintingCallbackHandler
+from ..tools import convert_pydantic_to_tool_spec
 from ..types.content import ContentBlock, Messages
 from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException
 from ..types.models import Model
@@ -20,6 +24,8 @@

 logger = logging.getLogger(__name__)

+T = TypeVar("T", bound=BaseModel)
+

 class AnthropicModel(Model):
     """Anthropic model provider implementation."""
@@ -356,10 +362,10 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]:
             with self.client.messages.stream(**request) as stream:
                 for event in stream:
                     if event.type in AnthropicModel.EVENT_TYPES:
-                        yield event.dict()
+                        yield event.model_dump()

                 usage = event.message.usage # type: ignore
-                yield {"type": "metadata", "usage": usage.dict()}
+                yield {"type": "metadata", "usage": usage.model_dump()}

         except anthropic.RateLimitError as error:
             raise ModelThrottledException(str(error)) from error
@@ -369,3 +375,42 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]:
                 raise ContextWindowOverflowException(str(error)) from error

             raise error
+
+    @override
+    def structured_output(
+        self, output_model: Type[T], prompt: Messages, callback_handler: Optional[Callable] = None
+    ) -> T:
+        """Get structured output from the model.
+
+        Args:
+            output_model(Type[BaseModel]): The output model to use for the agent.
+            prompt(Messages): The prompt messages to use for the agent.
+            callback_handler(Optional[Callable]): Optional callback handler for processing events. Defaults to None.
+        """
+        tool_spec = convert_pydantic_to_tool_spec(output_model)
+
+        response = self.converse(messages=prompt, tool_specs=[tool_spec])
+        # process the stream and get the tool use input
+        results = process_stream(
+            response, callback_handler=callback_handler or PrintingCallbackHandler(), messages=prompt
+        )
+
+        stop_reason, messages, _, _, _ = results
+
+        if stop_reason != "tool_use":
+            raise ValueError("No valid tool use or tool use input was found in the Anthropic response.")
+
+        content = messages["content"]
+        output_response: dict[str, Any] | None = None
+        for block in content:
+            # if the tool use name doesn't match the tool spec name, skip, and if the block is not a tool use, skip.
+            # if the tool use name never matches, raise an error.
+            if block.get("toolUse") and block["toolUse"]["name"] == tool_spec["name"]:
+                output_response = block["toolUse"]["input"]
+            else:
+                continue
+
+        if output_response is None:
+            raise ValueError("No valid tool use or tool use input was found in the Anthropic response.")
+
+        return output_model(**output_response)
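The provider-side implementation rides entirely on tool use: `convert_pydantic_to_tool_spec` presents the Pydantic model to Claude as a single tool, and a `tool_use` stop reason then carries the structured fields as that tool's input. As a rough sketch of the conversion's shape (the converse-style `toolSpec` layout and field choices below are assumptions; the real helper lives in `..tools` and is not shown in this diff):

```python
from typing import Any, Type

from pydantic import BaseModel


def convert_pydantic_to_tool_spec_sketch(output_model: Type[BaseModel]) -> dict[str, Any]:
    """Approximate the real helper: expose a Pydantic model as a tool whose
    input schema is the model's JSON schema (exact layout is assumed)."""
    return {
        "name": output_model.__name__,
        "description": output_model.__doc__ or f"Produce a {output_model.__name__}.",
        "inputSchema": {"json": output_model.model_json_schema()},
    }
```

On the return path, `output_model(**output_response)` re-validates the tool input, so a response that does not match the schema surfaces as a Pydantic `ValidationError` instead of a silently malformed dict.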