diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS deleted file mode 100644 index 7a4f8317..00000000 --- a/.github/CODEOWNERS +++ /dev/null @@ -1,5 +0,0 @@ -# These owners will be the default owners for everything in -# the repo. Unless a later match takes precedence, -# @strands-agents/contributors will be requested for -# review when someone opens a pull request. -* @strands-agents/maintainers \ No newline at end of file diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index a6efda65..fa894231 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -1,37 +1,38 @@ ## Description -[Provide a detailed description of the changes in this PR] + ## Related Issues -[Link to related issues using #issue-number format] + + ## Documentation PR -[Link to related associated PR in the agent-docs repo] + + ## Type of Change -- Bug fix -- New feature -- Breaking change -- Documentation update -- Other (please describe): -[Choose one of the above types of changes] + +Bug fix +New feature +Breaking change +Documentation update +Other (please describe): ## Testing -[How have you tested the change?] -* `hatch fmt --linter` -* `hatch fmt --formatter` -* `hatch test --all` -* Verify that the changes do not break functionality or introduce warnings in consuming repositories: agents-docs, agents-tools, agents-cli +How have you tested the change? Verify that the changes do not break functionality or introduce warnings in consuming repositories: agents-docs, agents-tools, agents-cli +- [ ] I ran `hatch run prepare` ## Checklist - [ ] I have read the CONTRIBUTING document -- [ ] I have added tests that prove my fix is effective or my feature works +- [ ] I have added any necessary tests that prove my fix is effective or my feature works - [ ] I have updated the documentation accordingly -- [ ] I have added an appropriate example to the documentation to outline the feature +- [ ] I have added an appropriate example to the documentation to outline the feature, or no new docs are needed - [ ] My changes generate no new warnings - [ ] Any dependent changes have been merged and published +---- + By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice. diff --git a/.github/workflows/integration-test.yml b/.github/workflows/integration-test.yml new file mode 100644 index 00000000..39b53c49 --- /dev/null +++ b/.github/workflows/integration-test.yml @@ -0,0 +1,74 @@ +name: Secure Integration test + +on: + pull_request_target: + branches: main + +jobs: + authorization-check: + permissions: read-all + runs-on: ubuntu-latest + outputs: + approval-env: ${{ steps.collab-check.outputs.result }} + steps: + - name: Collaborator Check + uses: actions/github-script@v7 + id: collab-check + with: + result-encoding: string + script: | + try { + const permissionResponse = await github.rest.repos.getCollaboratorPermissionLevel({ + owner: context.repo.owner, + repo: context.repo.repo, + username: context.payload.pull_request.user.login, + }); + const permission = permissionResponse.data.permission; + const hasWriteAccess = ['write', 'admin'].includes(permission); + if (!hasWriteAccess) { + console.log(`User ${context.payload.pull_request.user.login} does not have write access to the repository (permission: ${permission})`); + return "manual-approval" + } else { + console.log(`Verifed ${context.payload.pull_request.user.login} has write access. 
Auto Approving PR Checks.`) + return "auto-approve" + } + } catch (error) { + console.log(`${context.payload.pull_request.user.login} does not have write access. Requiring Manual Approval to run PR Checks.`) + return "manual-approval" + } + check-access-and-checkout: + runs-on: ubuntu-latest + needs: authorization-check + environment: ${{ needs.authorization-check.outputs.approval-env }} + permissions: + id-token: write + pull-requests: read + contents: read + steps: + - name: Configure Credentials + uses: aws-actions/configure-aws-credentials@v4 + with: + role-to-assume: ${{ secrets.STRANDS_INTEG_TEST_ROLE }} + aws-region: us-east-1 + mask-aws-account-id: true + - name: Checkout head commit + uses: actions/checkout@v4 + with: + ref: ${{ github.event.pull_request.head.sha }} # Pull the commit from the forked repo + persist-credentials: false # Don't persist credentials for subsequent actions + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.10' + - name: Install dependencies + run: | + pip install --no-cache-dir hatch + - name: Run integration tests + env: + AWS_REGION: us-east-1 + AWS_REGION_NAME: us-east-1 # Needed for LiteLLM + id: tests + run: | + hatch test tests-integ + + diff --git a/.github/workflows/pr-and-push.yml b/.github/workflows/pr-and-push.yml new file mode 100644 index 00000000..2b2d026f --- /dev/null +++ b/.github/workflows/pr-and-push.yml @@ -0,0 +1,19 @@ +name: Pull Request and Push Action + +on: + pull_request: # Safer than pull_request_target for untrusted code + branches: [ main ] + types: [opened, synchronize, reopened, ready_for_review, review_requested, review_request_removed] + push: + branches: [ main ] # Also run on direct pushes to main +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + +jobs: + call-test-lint: + uses: ./.github/workflows/test-lint.yml + permissions: + contents: read + with: + ref: ${{ github.event.pull_request.head.sha }} diff --git a/.github/workflows/pypi-publish-on-release.yml b/.github/workflows/pypi-publish-on-release.yml new file mode 100644 index 00000000..8967c552 --- /dev/null +++ b/.github/workflows/pypi-publish-on-release.yml @@ -0,0 +1,82 @@ +name: Publish Python Package + +on: + release: + types: + - published + +jobs: + call-test-lint: + uses: ./.github/workflows/test-lint.yml + permissions: + contents: read + with: + ref: ${{ github.event.release.target_commitish }} + + build: + name: Build distribution πŸ“¦ + permissions: + contents: read + needs: + - call-test-lint + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + with: + persist-credentials: false + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.10' + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install hatch twine + + - name: Validate version + run: | + version=$(hatch version) + if [[ $version =~ ^[0-9]+\.[0-9]+\.[0-9]+$ ]]; then + echo "Valid version format" + exit 0 + else + echo "Invalid version format" + exit 1 + fi + + - name: Build + run: | + hatch build + + - name: Store the distribution packages + uses: actions/upload-artifact@v4 + with: + name: python-package-distributions + path: dist/ + + deploy: + name: Upload release to PyPI + needs: + - build + runs-on: ubuntu-latest + + # environment is used by PyPI Trusted Publisher and is strongly encouraged + # https://docs.pypi.org/trusted-publishers/adding-a-publisher/ + environment: + name: pypi + url: 
https://pypi.org/p/strands-agents + permissions: + # IMPORTANT: this permission is mandatory for Trusted Publishing + id-token: write + + steps: + - name: Download all the dists + uses: actions/download-artifact@v4 + with: + name: python-package-distributions + path: dist/ + - name: Publish distribution πŸ“¦ to PyPI + uses: pypa/gh-action-pypi-publish@release/v1 diff --git a/.github/workflows/test-lint-pr.yml b/.github/workflows/test-lint-pr.yml deleted file mode 100644 index 15fbebcb..00000000 --- a/.github/workflows/test-lint-pr.yml +++ /dev/null @@ -1,139 +0,0 @@ -name: Test and Lint - -on: - pull_request: # Safer than pull_request_target for untrusted code - branches: [ main ] - types: [opened, synchronize, reopened, ready_for_review, review_requested, review_request_removed] - push: - branches: [ main ] # Also run on direct pushes to main - -jobs: - check-approval: - name: Check if PR has contributor approval - runs-on: ubuntu-latest - permissions: - pull-requests: read - # Skip this check for direct pushes to main - if: github.event_name == 'pull_request' - outputs: - approved: ${{ steps.check-approval.outputs.approved }} - steps: - - name: Check if PR has been approved by a contributor - id: check-approval - uses: actions/github-script@v7 - with: - script: | - const APPROVED_ASSOCIATION = ['COLLABORATOR', 'CONTRIBUTOR', 'MEMBER', 'OWNER'] - const PR_AUTHOR_ASSOCIATION = context.payload.pull_request.author_association; - const { data: reviews } = await github.rest.pulls.listReviews({ - owner: context.repo.owner, - repo: context.repo.repo, - pull_number: context.issue.number, - }); - - const isApprovedContributor = APPROVED_ASSOCIATION.includes(PR_AUTHOR_ASSOCIATION); - - // Check if any contributor has approved - const isApproved = reviews.some(review => - review.state === 'APPROVED' && APPROVED_ASSOCIATION.includes(review.author_association) - ) || isApprovedContributor; - - core.setOutput('approved', isApproved); - - if (!isApproved) { - core.notice('This PR does not have approval from a Contributor. 
Workflow will not run test jobs.'); - return false; - } - - return true; - - unit-test: - name: Unit Tests - Python ${{ matrix.python-version }} - ${{ matrix.os-name }} - needs: check-approval - permissions: - contents: read - # Only run if PR is approved or this is a direct push to main - if: github.event_name == 'push' || needs.check-approval.outputs.approved == 'true' - strategy: - matrix: - include: - # Linux - - os: ubuntu-latest - os-name: linux - python-version: "3.10" - - os: ubuntu-latest - os-name: linux - python-version: "3.11" - - os: ubuntu-latest - os-name: linux - python-version: "3.12" - - os: ubuntu-latest - os-name: linux - python-version: "3.13" - # Windows - - os: windows-latest - os-name: windows - python-version: "3.10" - - os: windows-latest - os-name: windows - python-version: "3.11" - - os: windows-latest - os-name: windows - python-version: "3.12" - - os: windows-latest - os-name: windows - python-version: "3.13" - # MacOS - latest only; not enough runners for MacOS - - os: macos-latest - os-name: macos - python-version: "3.13" - fail-fast: false - runs-on: ${{ matrix.os }} - env: - LOG_LEVEL: DEBUG - steps: - - name: Checkout code - uses: actions/checkout@v4 - with: - ref: ${{ github.event.pull_request.head.sha }} # Explicitly define which commit to checkout - persist-credentials: false # Don't persist credentials for subsequent actions - - name: Set up Python - uses: actions/setup-python@v5 - with: - python-version: ${{ matrix.python-version }} - - name: Install dependencies - run: | - pip install --no-cache-dir hatch - - name: Run Unit tests - id: tests - run: hatch test tests --cover - continue-on-error: false - - lint: - name: Lint - runs-on: ubuntu-latest - needs: check-approval - permissions: - contents: read - if: github.event_name == 'push' || needs.check-approval.outputs.approved == 'true' - steps: - - name: Checkout code - uses: actions/checkout@v4 - with: - ref: ${{ github.event.pull_request.head.sha }} - persist-credentials: false - - - name: Set up Python - uses: actions/setup-python@v5 - with: - python-version: '3.10' - cache: 'pip' - - - name: Install dependencies - run: | - pip install --no-cache-dir hatch - - - name: Run lint - id: lint - run: hatch run test-lint - continue-on-error: false diff --git a/.github/workflows/test-lint.yml b/.github/workflows/test-lint.yml new file mode 100644 index 00000000..35e0f584 --- /dev/null +++ b/.github/workflows/test-lint.yml @@ -0,0 +1,94 @@ +name: Test and Lint + +on: + workflow_call: + inputs: + ref: + required: true + type: string + +jobs: + unit-test: + name: Unit Tests - Python ${{ matrix.python-version }} - ${{ matrix.os-name }} + permissions: + contents: read + strategy: + matrix: + include: + # Linux + - os: ubuntu-latest + os-name: 'linux' + python-version: "3.10" + - os: ubuntu-latest + os-name: 'linux' + python-version: "3.11" + - os: ubuntu-latest + os-name: 'linux' + python-version: "3.12" + - os: ubuntu-latest + os-name: 'linux' + python-version: "3.13" + # Windows + - os: windows-latest + os-name: 'windows' + python-version: "3.10" + - os: windows-latest + os-name: 'windows' + python-version: "3.11" + - os: windows-latest + os-name: 'windows' + python-version: "3.12" + - os: windows-latest + os-name: 'windows' + python-version: "3.13" + # MacOS - latest only; not enough runners for macOS + - os: macos-latest + os-name: 'macOS' + python-version: "3.13" + fail-fast: true + runs-on: ${{ matrix.os }} + env: + LOG_LEVEL: DEBUG + steps: + - name: Checkout code + uses: actions/checkout@v4 + with: + 
ref: ${{ inputs.ref }} # Explicitly define which commit to check out + persist-credentials: false # Don't persist credentials for subsequent actions + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + pip install --no-cache-dir hatch + - name: Run Unit tests + id: tests + run: hatch test tests --cover + continue-on-error: false + lint: + name: Lint + runs-on: ubuntu-latest + permissions: + contents: read + steps: + - name: Checkout code + uses: actions/checkout@v4 + with: + ref: ${{ inputs.ref }} + persist-credentials: false + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.10' + cache: 'pip' + + - name: Install dependencies + run: | + pip install --no-cache-dir hatch + + - name: Run lint + id: lint + run: hatch run test-lint + continue-on-error: false diff --git a/.gitignore b/.gitignore index a80f4bd1..5cdc43db 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,5 @@ __pycache__* .pytest_cache .ruff_cache *.bak +.vscode +dist \ No newline at end of file diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index e1ddbb89..fa724cdd 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -31,16 +31,22 @@ This project uses [hatchling](https://hatch.pypa.io/latest/build/#hatchling) as ### Setting Up Your Development Environment -1. Install development dependencies: +1. Entering virtual environment using `hatch` (recommended), then launch your IDE in the new shell. ```bash - pip install -e ".[dev]" && pip install -e ".[litellm] + hatch shell dev ``` + Alternatively, install development dependencies in a manually created virtual environment: + ```bash + pip install -e ".[dev]" && pip install -e ".[litellm]" + ``` + + 2. Set up pre-commit hooks: ```bash pre-commit install -t pre-commit -t commit-msg ``` - This will automatically run formatters and convention commit checks on your code before each commit. + This will automatically run formatters and conventional commit checks on your code before each commit. 3. Run code formatters manually: ```bash @@ -94,6 +100,8 @@ hatch fmt --linter If you're using an IDE like VS Code or PyCharm, consider configuring it to use these tools automatically. +For additional details on styling, please see our dedicated [Style Guide](./STYLE_GUIDE.md). + ## Contributing via Pull Requests Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that: @@ -115,7 +123,7 @@ To send us a pull request, please: ## Finding contributions to work on -Looking at the existing issues is a great way to find something to contribute on. +Looking at the existing issues is a great way to find something to contribute to. You can check: - Our known bugs list in [Bug Reports](../../issues?q=is%3Aissue%20state%3Aopen%20label%3Abug) for issues that need fixing diff --git a/README.md b/README.md index 08d6bff0..ed98d001 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,14 @@ -# Strands Agents -
+  Strands Agents
A model-driven approach to building AI agents in just a few lines of code.

@@ -8,16 +16,19 @@
GitHub commit activity GitHub open issues + GitHub open pull requests License PyPI version Python versions

- Docs - ◆ Samples + Documentation + ◆ Samples + ◆ Python SDK ◆ Tools ◆ Agent Builder + ◆ MCP Server

@@ -26,7 +37,7 @@ Strands Agents is a simple yet powerful SDK that takes a model-driven approach t ## Feature Overview - **Lightweight & Flexible**: Simple agent loop that just works and is fully customizable -- **Model Agnostic**: Support for Amazon Bedrock, Anthropic, Llama, Ollama, and custom providers +- **Model Agnostic**: Support for Amazon Bedrock, Anthropic, LiteLLM, Llama, Ollama, OpenAI, and custom providers - **Advanced Capabilities**: Multi-agent systems, autonomous agents, and streaming support - **Built-in MCP**: Native support for Model Context Protocol (MCP) servers, enabling access to thousands of pre-built tools @@ -112,16 +123,17 @@ from strands.models.llamaapi import LlamaAPIModel bedrock_model = BedrockModel( model_id="us.amazon.nova-pro-v1:0", temperature=0.3, + streaming=True, # Enable/disable streaming ) agent = Agent(model=bedrock_model) agent("Tell me about Agentic AI") # Ollama -ollama_modal = OllamaModel( +ollama_model = OllamaModel( host="http://localhost:11434", model_id="llama3" ) -agent = Agent(model=ollama_modal) +agent = Agent(model=ollama_model) agent("Tell me about Agentic AI") # Llama API @@ -138,6 +150,7 @@ Built-in providers: - [LiteLLM](https://strandsagents.com/latest/user-guide/concepts/model-providers/litellm/) - [LlamaAPI](https://strandsagents.com/latest/user-guide/concepts/model-providers/llamaapi/) - [Ollama](https://strandsagents.com/latest/user-guide/concepts/model-providers/ollama/) + - [OpenAI](https://strandsagents.com/latest/user-guide/concepts/model-providers/openai/) Custom providers can be implemented using [Custom Providers](https://strandsagents.com/latest/user-guide/concepts/model-providers/custom_model_provider/) @@ -165,9 +178,9 @@ For detailed guidance & examples, explore our documentation: - [API Reference](https://strandsagents.com/latest/api-reference/agent/) - [Production & Deployment Guide](https://strandsagents.com/latest/user-guide/deploy/operating-agents-in-production/) -## Contributing +## Contributing ❀️ -We welcome contributions! See our [Contributing Guide](https://github.com/strands-agents/sdk-python/blob/main/CONTRIBUTING.md) for details on: +We welcome contributions! See our [Contributing Guide](CONTRIBUTING.md) for details on: - Reporting bugs & features - Development setup - Contributing via Pull Requests @@ -178,6 +191,10 @@ We welcome contributions! See our [Contributing Guide](https://github.com/strand This project is licensed under the Apache License 2.0 - see the [LICENSE](LICENSE) file for details. +## Security + +See [CONTRIBUTING](CONTRIBUTING.md#security-issue-notifications) for more information. + ## ⚠️ Preview Status Strands Agents is currently in public preview. During this period: diff --git a/STYLE_GUIDE.md b/STYLE_GUIDE.md new file mode 100644 index 00000000..51dc0a73 --- /dev/null +++ b/STYLE_GUIDE.md @@ -0,0 +1,59 @@ +# Style Guide + +## Overview + +The Strands Agents style guide aims to establish consistent formatting, naming conventions, and structure across all code in the repository. We strive to make our code clean, readable, and maintainable. + +Where possible, we will codify these style guidelines into our linting rules and pre-commit hooks to automate enforcement and reduce the manual review burden. + +## Log Formatting + +The format for Strands Agents logs is as follows: + +```python +logger.debug("field1=<%s>, field2=<%s>, ... | human readable message", field1, field2, ...) +``` + +### Guidelines + +1. 
**Context**: + - Add context as `=` pairs at the beginning of the log + - Many log services (CloudWatch, Splunk, etc.) look for these patterns to extract fields for searching + - Use `,`'s to separate pairs + - Enclose values in `<>` for readability + - This is particularly helpful in displaying empty values (`field=` vs `field=<>`) + - Use `%s` for string interpolation as recommended by Python logging + - This is an optimization to skip string interpolation when the log level is not enabled + +1. **Messages**: + - Add human-readable messages at the end of the log + - Use lowercase for consistency + - Avoid punctuation (periods, exclamation points, etc.) to reduce clutter + - Keep messages concise and focused on a single statement + - If multiple statements are needed, separate them with the pipe character (`|`) + - Example: `"processing request | starting validation"` + +### Examples + +#### Good + +```python +logger.debug("user_id=<%s>, action=<%s> | user performed action", user_id, action) +logger.info("request_id=<%s>, duration_ms=<%d> | request completed", request_id, duration) +logger.warning("attempt=<%d>, max_attempts=<%d> | retry limit approaching", attempt, max_attempts) +``` + +#### Poor + +```python +# Avoid: No structured fields, direct variable interpolation in message +logger.debug(f"User {user_id} performed action {action}") + +# Avoid: Inconsistent formatting, punctuation +logger.info("Request completed in %d ms.", duration) + +# Avoid: No separation between fields and message +logger.warning("Retry limit approaching! attempt=%d max_attempts=%d", attempt, max_attempts) +``` + +By following these log formatting guidelines, we ensure that logs are both human-readable and machine-parseable, making debugging and monitoring more efficient. 
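A minimal, runnable sketch of the log-formatting convention described in the style guide above; the function, field, and value names are illustrative only and do not come from the repository.

```python
import logging

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)


def process_request(request_id: str, attempt: int, max_attempts: int) -> None:
    # key=<value> pairs first, then a pipe, then a lowercase human-readable message.
    # %s placeholders (rather than f-strings) defer interpolation until the record
    # is actually emitted, so logging at disabled levels costs almost nothing.
    logger.debug("request_id=<%s>, attempt=<%s> | starting validation", request_id, attempt)

    if attempt + 1 >= max_attempts:
        logger.warning("attempt=<%d>, max_attempts=<%d> | retry limit approaching", attempt, max_attempts)

    logger.info("request_id=<%s> | request completed", request_id)


process_request("req-42", attempt=2, max_attempts=3)
```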
diff --git a/pyproject.toml b/pyproject.toml index e3e3f372..4bb69ce8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,10 +1,10 @@ [build-system] -requires = ["hatchling"] +requires = ["hatchling", "hatch-vcs"] build-backend = "hatchling.build" [project] name = "strands-agents" -version = "0.1.1" +dynamic = ["version"] description = "A model-driven approach to building AI agents in just a few lines of code" readme = "README.md" requires-python = ">=3.10" @@ -28,14 +28,13 @@ classifiers = [ dependencies = [ "boto3>=1.26.0,<2.0.0", "botocore>=1.29.0,<2.0.0", - "docstring_parser>=0.15,<0.16.0", + "docstring_parser>=0.15,<1.0", "mcp>=1.8.0,<2.0.0", "pydantic>=2.0.0,<3.0.0", "typing-extensions>=4.13.2,<5.0.0", "watchdog>=6.0.0,<7.0.0", - "opentelemetry-api>=1.33.0,<2.0.0", - "opentelemetry-sdk>=1.33.0,<2.0.0", - "opentelemetry-exporter-otlp-proto-http>=1.33.0,<2.0.0", + "opentelemetry-api>=1.30.0,<2.0.0", + "opentelemetry-sdk>=1.30.0,<2.0.0", ] [project.urls] @@ -54,12 +53,11 @@ dev = [ "commitizen>=4.4.0,<5.0.0", "hatch>=1.0.0,<2.0.0", "moto>=5.1.0,<6.0.0", - "mypy>=0.981,<1.0.0", + "mypy>=1.15.0,<2.0.0", "pre-commit>=3.2.0,<4.2.0", "pytest>=8.0.0,<9.0.0", "pytest-asyncio>=0.26.0,<0.27.0", "ruff>=0.4.4,<0.5.0", - "swagger-parser>=1.0.2,<2.0.0", ] docs = [ "sphinx>=5.0.0,<6.0.0", @@ -67,17 +65,34 @@ docs = [ "sphinx-autodoc-typehints>=1.12.0,<2.0.0", ] litellm = [ - "litellm>=1.69.0,<2.0.0", + "litellm>=1.72.6,<2.0.0", +] +llamaapi = [ + "llama-api-client>=0.1.0,<1.0.0", ] ollama = [ "ollama>=0.4.8,<1.0.0", ] -llamaapi = [ - "llama-api-client>=0.1.0,<1.0.0", +openai = [ + "openai>=1.68.0,<2.0.0", +] +otel = [ + "opentelemetry-exporter-otlp-proto-http>=1.30.0,<2.0.0", ] +a2a = [ + "a2a-sdk>=0.2.6", + "uvicorn>=0.34.2", + "httpx>=0.28.1", + "fastapi>=0.115.12", + "starlette>=0.46.2", +] + +[tool.hatch.version] +# Tells Hatch to use your version control system (git) to determine the version. 
+source = "vcs" [tool.hatch.envs.hatch-static-analysis] -features = ["anthropic", "litellm", "llamaapi", "ollama"] +features = ["anthropic", "litellm", "llamaapi", "ollama", "openai", "otel"] dependencies = [ "mypy>=1.15.0,<2.0.0", "ruff>=0.11.6,<0.12.0", @@ -93,14 +108,15 @@ format-fix = [ ] lint-check = [ "ruff check", - "mypy -p src" + # excluding due to A2A and OTEL http exporter dependency conflict + "mypy -p src --exclude src/strands/multiagent" ] lint-fix = [ "ruff check --fix" ] [tool.hatch.envs.hatch-test] -features = ["anthropic", "litellm", "llamaapi", "ollama"] +features = ["anthropic", "litellm", "llamaapi", "ollama", "openai", "otel"] extra-dependencies = [ "moto>=5.1.0,<6.0.0", "pytest>=8.0.0,<9.0.0", @@ -114,17 +130,37 @@ extra-args = [ "-vv", ] +[tool.hatch.envs.dev] +dev-mode = true +features = ["dev", "docs", "anthropic", "litellm", "llamaapi", "ollama", "otel"] + +[tool.hatch.envs.a2a] +dev-mode = true +features = ["dev", "docs", "anthropic", "litellm", "llamaapi", "ollama", "a2a"] + +[tool.hatch.envs.a2a.scripts] +run = [ + "pytest{env:HATCH_TEST_ARGS:} tests/multiagent/a2a {args}" +] +run-cov = [ + "pytest{env:HATCH_TEST_ARGS:} tests/multiagent/a2a --cov --cov-config=pyproject.toml {args}" +] +lint-check = [ + "ruff check", + "mypy -p src/strands/multiagent/a2a" +] [[tool.hatch.envs.hatch-test.matrix]] python = ["3.13", "3.12", "3.11", "3.10"] - [tool.hatch.envs.hatch-test.scripts] run = [ - "pytest{env:HATCH_TEST_ARGS:} {args}" + # excluding due to A2A and OTEL http exporter dependency conflict + "pytest{env:HATCH_TEST_ARGS:} {args} --ignore=tests/multiagent/a2a" ] run-cov = [ - "pytest{env:HATCH_TEST_ARGS:} --cov --cov-config=pyproject.toml {args}" + # excluding due to A2A and OTEL http exporter dependency conflict + "pytest{env:HATCH_TEST_ARGS:} --cov --cov-config=pyproject.toml {args} --ignore=tests/multiagent/a2a" ] cov-combine = [] @@ -153,7 +189,15 @@ test = [ test-integ = [ "hatch test tests-integ {args}" ] - +prepare = [ + "hatch fmt --linter", + "hatch fmt --formatter", + "hatch test --all" +] +test-a2a = [ + # required to run manually due to A2A and OTEL http exporter dependency conflict + "hatch -e a2a run run {args}" +] [tool.mypy] python_version = "3.10" @@ -185,7 +229,9 @@ select = [ "D", # pydocstyle "E", # pycodestyle "F", # pyflakes + "G", # logging format "I", # isort + "LOG", # logging ] [tool.ruff.lint.per-file-ignores] @@ -235,4 +281,4 @@ style = [ ["instruction", ""], ["text", ""], ["disabled", "fg:#858585 italic"] -] +] \ No newline at end of file diff --git a/src/strands/agent/__init__.py b/src/strands/agent/__init__.py index 4d2fa1fe..6618d332 100644 --- a/src/strands/agent/__init__.py +++ b/src/strands/agent/__init__.py @@ -8,7 +8,12 @@ from .agent import Agent from .agent_result import AgentResult -from .conversation_manager import ConversationManager, NullConversationManager, SlidingWindowConversationManager +from .conversation_manager import ( + ConversationManager, + NullConversationManager, + SlidingWindowConversationManager, + SummarizingConversationManager, +) __all__ = [ "Agent", @@ -16,4 +21,5 @@ "ConversationManager", "NullConversationManager", "SlidingWindowConversationManager", + "SummarizingConversationManager", ] diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 89653036..4ecc42a9 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -6,7 +6,7 @@ The Agent interface supports two complementary interaction patterns: 1. 
Natural language for conversation: `agent("Analyze this data")` -2. Method-style for direct tool access: `agent.tool_name(param1="value")` +2. Method-style for direct tool access: `agent.tool.tool_name(param1="value")` """ import asyncio @@ -16,10 +16,11 @@ import random from concurrent.futures import ThreadPoolExecutor from threading import Thread -from typing import Any, AsyncIterator, Callable, Dict, List, Mapping, Optional, Union +from typing import Any, AsyncIterator, Callable, Dict, List, Mapping, Optional, Type, TypeVar, Union from uuid import uuid4 from opentelemetry import trace +from pydantic import BaseModel from ..event_loop.event_loop import event_loop_cycle from ..handlers.callback_handler import CompositeCallbackHandler, PrintingCallbackHandler, null_callback_handler @@ -43,6 +44,19 @@ logger = logging.getLogger(__name__) +# TypeVar for generic structured output +T = TypeVar("T", bound=BaseModel) + + +# Sentinel class and object to distinguish between explicit None and default parameter value +class _DefaultCallbackHandlerSentinel: + """Sentinel class to distinguish between explicit None and default parameter value.""" + + pass + + +_DEFAULT_CALLBACK_HANDLER = _DefaultCallbackHandlerSentinel() + class Agent: """Core Agent interface. @@ -70,10 +84,11 @@ def __init__(self, agent: "Agent") -> None: # agent tools and thus break their execution. self._agent = agent - def __getattr__(self, name: str) -> Callable: + def __getattr__(self, name: str) -> Callable[..., Any]: """Call tool as a function. This method enables the method-style interface (e.g., `agent.tool.tool_name(param="value")`). + It matches underscore-separated names to hyphenated tool names (e.g., 'some_thing' matches 'some-thing'). Args: name: The name of the attribute (tool) being accessed. @@ -82,9 +97,30 @@ def __getattr__(self, name: str) -> Callable: A function that when called will execute the named tool. Raises: - AttributeError: If no tool with the given name exists. + AttributeError: If no tool with the given name exists or if multiple tools match the given name. """ + def find_normalized_tool_name() -> Optional[str]: + """Lookup the tool represented by name, replacing characters with underscores as necessary.""" + tool_registry = self._agent.tool_registry.registry + + if tool_registry.get(name, None): + return name + + # If the desired name contains underscores, it might be a placeholder for characters that can't be + # represented as python identifiers but are valid as tool names, such as dashes. In that case, find + # all tools that can be represented with the normalized name + if "_" in name: + filtered_tools = [ + tool_name for (tool_name, tool) in tool_registry.items() if tool_name.replace("-", "_") == name + ] + + # The registry itself defends against similar names, so we can just take the first match + if filtered_tools: + return filtered_tools[0] + + raise AttributeError(f"Tool '{name}' not found") + def caller(**kwargs: Any) -> Any: """Call a tool directly by name. @@ -105,14 +141,13 @@ def caller(**kwargs: Any) -> Any: Raises: AttributeError: If the tool doesn't exist. 
""" - if name not in self._agent.tool_registry.registry: - raise AttributeError(f"Tool '{name}' not found") + normalized_name = find_normalized_tool_name() # Create unique tool ID and set up the tool request tool_id = f"tooluse_{name}_{random.randint(100000000, 999999999)}" tool_use = { "toolUseId": tool_id, - "name": name, + "name": normalized_name, "input": kwargs.copy(), } @@ -165,7 +200,7 @@ def caller(**kwargs: Any) -> Any: self._agent._record_tool_execution(tool_use, tool_result, user_message_override, messages) # Apply window management - self._agent.conversation_manager.apply_management(self._agent.messages) + self._agent.conversation_manager.apply_management(self._agent) return tool_result @@ -177,12 +212,17 @@ def __init__( messages: Optional[Messages] = None, tools: Optional[List[Union[str, Dict[str, str], Any]]] = None, system_prompt: Optional[str] = None, - callback_handler: Optional[Callable] = PrintingCallbackHandler(), + callback_handler: Optional[ + Union[Callable[..., Any], _DefaultCallbackHandlerSentinel] + ] = _DEFAULT_CALLBACK_HANDLER, conversation_manager: Optional[ConversationManager] = None, max_parallel_tools: int = os.cpu_count() or 1, record_direct_tool_call: bool = True, load_tools_from_directory: bool = True, trace_attributes: Optional[Mapping[str, AttributeValue]] = None, + *, + name: Optional[str] = None, + description: Optional[str] = None, ): """Initialize the Agent with the specified configuration. @@ -204,7 +244,8 @@ def __init__( system_prompt: System prompt to guide model behavior. If None, the model will behave according to its default settings. callback_handler: Callback for processing events as they happen during agent execution. - Defaults to strands.handlers.PrintingCallbackHandler if None. + If not provided (using the default), a new PrintingCallbackHandler instance is created. + If explicitly set to None, null_callback_handler is used. conversation_manager: Manager for conversation history and context window. Defaults to strands.agent.conversation_manager.SlidingWindowConversationManager if None. max_parallel_tools: Maximum number of tools to run in parallel when the model returns multiple tool calls. @@ -214,6 +255,10 @@ def __init__( load_tools_from_directory: Whether to load and automatically reload tools in the `./tools/` directory. Defaults to True. trace_attributes: Custom trace attributes to apply to the agent's trace span. + name: name of the Agent + Defaults to None. + description: description of what the Agent does + Defaults to None. Raises: ValueError: If max_parallel_tools is less than 1. 
@@ -222,7 +267,17 @@ def __init__( self.messages = messages if messages is not None else [] self.system_prompt = system_prompt - self.callback_handler = callback_handler or null_callback_handler + + # If not provided, create a new PrintingCallbackHandler instance + # If explicitly set to None, use null_callback_handler + # Otherwise use the passed callback_handler + self.callback_handler: Union[Callable[..., Any], PrintingCallbackHandler] + if isinstance(callback_handler, _DefaultCallbackHandlerSentinel): + self.callback_handler = PrintingCallbackHandler() + elif callback_handler is None: + self.callback_handler = null_callback_handler + else: + self.callback_handler = callback_handler self.conversation_manager = conversation_manager if conversation_manager else SlidingWindowConversationManager() @@ -264,8 +319,9 @@ def __init__( # Initialize tracer instance (no-op if not configured) self.tracer = get_tracer() self.trace_span: Optional[trace.Span] = None - self.tool_caller = Agent.ToolCaller(self) + self.name = name + self.description = description @property def tool(self) -> ToolCaller: @@ -328,31 +384,47 @@ def __call__(self, prompt: str, **kwargs: Any) -> AgentResult: - metrics: Performance metrics from the event loop - state: The final state of the event loop """ - model_id = self.model.config.get("model_id") if hasattr(self.model, "config") else None - - self.trace_span = self.tracer.start_agent_span( - prompt=prompt, - model_id=model_id, - tools=self.tool_names, - system_prompt=self.system_prompt, - custom_trace_attributes=self.trace_attributes, - ) + self._start_agent_trace_span(prompt) try: # Run the event loop and get the result result = self._run_loop(prompt, kwargs) - if self.trace_span: - self.tracer.end_agent_span(span=self.trace_span, response=result) + self._end_agent_trace_span(response=result) return result except Exception as e: - if self.trace_span: - self.tracer.end_agent_span(span=self.trace_span, error=e) + self._end_agent_trace_span(error=e) # Re-raise the exception to preserve original behavior raise + def structured_output(self, output_model: Type[T], prompt: Optional[str] = None) -> T: + """This method allows you to get structured output from the agent. + + If you pass in a prompt, it will be added to the conversation history and the agent will respond to it. + If you don't pass in a prompt, it will use only the conversation history to respond. + If no conversation history exists and no prompt is provided, an error will be raised. + + For smaller models, you may want to use the optional prompt string to add additional instructions to explicitly + instruct the model to output the structured data. + + Args: + output_model(Type[BaseModel]): The output model (a JSON schema written as a Pydantic BaseModel) + that the agent will use when responding. + prompt(Optional[str]): The prompt to use for the agent. + """ + messages = self.messages + if not messages and not prompt: + raise ValueError("No conversation history or prompt provided") + + # add the prompt as the last message + if prompt: + messages.append({"role": "user", "content": [{"text": prompt}]}) + + # get the structured output from the model + return self.model.structured_output(output_model, messages, self.callback_handler) + async def stream_async(self, prompt: str, **kwargs: Any) -> AsyncIterator[Any]: """Process a natural language prompt and yield events as an async iterator. 
@@ -383,6 +455,8 @@ async def stream_async(self, prompt: str, **kwargs: Any) -> AsyncIterator[Any]: yield event["data"] ``` """ + self._start_agent_trace_span(prompt) + _stop_event = uuid4() queue = asyncio.Queue[Any]() @@ -400,8 +474,10 @@ def target_callback() -> None: nonlocal kwargs try: - self._run_loop(prompt, kwargs, supplementary_callback_handler=queuing_callback_handler) - except BaseException as e: + result = self._run_loop(prompt, kwargs, supplementary_callback_handler=queuing_callback_handler) + self._end_agent_trace_span(response=result) + except Exception as e: + self._end_agent_trace_span(error=e) enqueue(e) finally: enqueue(_stop_event) @@ -414,14 +490,14 @@ def target_callback() -> None: item = await queue.get() if item == _stop_event: break - if isinstance(item, BaseException): + if isinstance(item, Exception): raise item yield item finally: thread.join() def _run_loop( - self, prompt: str, kwargs: Any, supplementary_callback_handler: Optional[Callable] = None + self, prompt: str, kwargs: Dict[str, Any], supplementary_callback_handler: Optional[Callable[..., Any]] = None ) -> AgentResult: """Execute the agent's event loop with the given prompt and parameters.""" try: @@ -445,9 +521,9 @@ def _run_loop( return self._execute_event_loop_cycle(invocation_callback_handler, kwargs) finally: - self.conversation_manager.apply_management(self.messages) + self.conversation_manager.apply_management(self) - def _execute_event_loop_cycle(self, callback_handler: Callable, kwargs: dict[str, Any]) -> AgentResult: + def _execute_event_loop_cycle(self, callback_handler: Callable[..., Any], kwargs: Dict[str, Any]) -> AgentResult: """Execute the event loop cycle with retry logic for context window limits. This internal method handles the execution of the event loop cycle and implements @@ -457,27 +533,28 @@ def _execute_event_loop_cycle(self, callback_handler: Callable, kwargs: dict[str Returns: The result of the event loop cycle. 
""" - kwargs.pop("agent", None) - kwargs.pop("model", None) - kwargs.pop("system_prompt", None) - kwargs.pop("tool_execution_handler", None) - kwargs.pop("event_loop_metrics", None) - kwargs.pop("callback_handler", None) - kwargs.pop("tool_handler", None) - kwargs.pop("messages", None) - kwargs.pop("tool_config", None) + # Extract parameters with fallbacks to instance values + system_prompt = kwargs.pop("system_prompt", self.system_prompt) + model = kwargs.pop("model", self.model) + tool_execution_handler = kwargs.pop("tool_execution_handler", self.thread_pool_wrapper) + event_loop_metrics = kwargs.pop("event_loop_metrics", self.event_loop_metrics) + callback_handler_override = kwargs.pop("callback_handler", callback_handler) + tool_handler = kwargs.pop("tool_handler", self.tool_handler) + messages = kwargs.pop("messages", self.messages) + tool_config = kwargs.pop("tool_config", self.tool_config) + kwargs.pop("agent", None) # Remove agent to avoid conflicts try: # Execute the main event loop cycle stop_reason, message, metrics, state = event_loop_cycle( - model=self.model, - system_prompt=self.system_prompt, - messages=self.messages, # will be modified by event_loop_cycle - tool_config=self.tool_config, - callback_handler=callback_handler, - tool_handler=self.tool_handler, - tool_execution_handler=self.thread_pool_wrapper, - event_loop_metrics=self.event_loop_metrics, + model=model, + system_prompt=system_prompt, + messages=messages, # will be modified by event_loop_cycle + tool_config=tool_config, + callback_handler=callback_handler_override, + tool_handler=tool_handler, + tool_execution_handler=tool_execution_handler, + event_loop_metrics=event_loop_metrics, agent=self, event_loop_parent_span=self.trace_span, **kwargs, @@ -488,8 +565,8 @@ def _execute_event_loop_cycle(self, callback_handler: Callable, kwargs: dict[str except ContextWindowOverflowException as e: # Try reducing the context size and retrying - self.conversation_manager.reduce_context(self.messages, e=e) - return self._execute_event_loop_cycle(callback_handler, kwargs) + self.conversation_manager.reduce_context(self, e=e) + return self._execute_event_loop_cycle(callback_handler_override, kwargs) def _record_tool_execution( self, @@ -515,7 +592,7 @@ def _record_tool_execution( """ # Create user message describing the tool call user_msg_content = [ - {"text": (f"agent.{tool['name']} direct tool call\nInput parameters: {json.dumps(tool['input'])}\n")} + {"text": (f"agent.tool.{tool['name']} direct tool call.\nInput parameters: {json.dumps(tool['input'])}\n")} ] # Add override message if provided @@ -545,3 +622,43 @@ def _record_tool_execution( messages.append(tool_use_msg) messages.append(tool_result_msg) messages.append(assistant_msg) + + def _start_agent_trace_span(self, prompt: str) -> None: + """Starts a trace span for the agent. + + Args: + prompt: The natural language prompt from the user. + """ + model_id = self.model.config.get("model_id") if hasattr(self.model, "config") else None + + self.trace_span = self.tracer.start_agent_span( + prompt=prompt, + model_id=model_id, + tools=self.tool_names, + system_prompt=self.system_prompt, + custom_trace_attributes=self.trace_attributes, + ) + + def _end_agent_trace_span( + self, + response: Optional[AgentResult] = None, + error: Optional[Exception] = None, + ) -> None: + """Ends a trace span for the agent. + + Args: + span: The span to end. + response: Response to record as a trace attribute. + error: Error to record as a trace attribute. 
+ """ + if self.trace_span: + trace_attributes: Dict[str, Any] = { + "span": self.trace_span, + } + + if response: + trace_attributes["response"] = response + if error: + trace_attributes["error"] = error + + self.tracer.end_agent_span(**trace_attributes) diff --git a/src/strands/agent/conversation_manager/__init__.py b/src/strands/agent/conversation_manager/__init__.py index 68541877..c5962321 100644 --- a/src/strands/agent/conversation_manager/__init__.py +++ b/src/strands/agent/conversation_manager/__init__.py @@ -6,6 +6,8 @@ - NullConversationManager: A no-op implementation that does not modify conversation history - SlidingWindowConversationManager: An implementation that maintains a sliding window of messages to control context size while preserving conversation coherence +- SummarizingConversationManager: An implementation that summarizes older context instead + of simply trimming it Conversation managers help control memory usage and context length while maintaining relevant conversation state, which is critical for effective agent interactions. @@ -14,5 +16,11 @@ from .conversation_manager import ConversationManager from .null_conversation_manager import NullConversationManager from .sliding_window_conversation_manager import SlidingWindowConversationManager +from .summarizing_conversation_manager import SummarizingConversationManager -__all__ = ["ConversationManager", "NullConversationManager", "SlidingWindowConversationManager"] +__all__ = [ + "ConversationManager", + "NullConversationManager", + "SlidingWindowConversationManager", + "SummarizingConversationManager", +] diff --git a/src/strands/agent/conversation_manager/conversation_manager.py b/src/strands/agent/conversation_manager/conversation_manager.py index d18ae69a..dbccf941 100644 --- a/src/strands/agent/conversation_manager/conversation_manager.py +++ b/src/strands/agent/conversation_manager/conversation_manager.py @@ -1,9 +1,10 @@ """Abstract interface for conversation history management.""" from abc import ABC, abstractmethod -from typing import Optional +from typing import TYPE_CHECKING, Optional -from ...types.content import Messages +if TYPE_CHECKING: + from ...agent.agent import Agent class ConversationManager(ABC): @@ -19,22 +20,22 @@ class ConversationManager(ABC): @abstractmethod # pragma: no cover - def apply_management(self, messages: Messages) -> None: - """Applies management strategy to the provided list of messages. + def apply_management(self, agent: "Agent") -> None: + """Applies management strategy to the provided agent. Processes the conversation history to maintain appropriate size by modifying the messages list in-place. Implementations should handle message pruning, summarization, or other size management techniques to keep the conversation context within desired bounds. Args: - messages: The conversation history to manage. + agent: The agent whose conversation history will be manage. This list is modified in-place. """ pass @abstractmethod # pragma: no cover - def reduce_context(self, messages: Messages, e: Optional[Exception] = None) -> None: + def reduce_context(self, agent: "Agent", e: Optional[Exception] = None) -> None: """Called when the model's context window is exceeded. This method should implement the specific strategy for reducing the window size when a context overflow occurs. @@ -48,7 +49,7 @@ def reduce_context(self, messages: Messages, e: Optional[Exception] = None) -> N - Maintaining critical conversation markers Args: - messages: The conversation history to reduce. 
+ agent: The agent whose conversation history will be reduced. This list is modified in-place. e: The exception that triggered the context reduction, if any. """ diff --git a/src/strands/agent/conversation_manager/null_conversation_manager.py b/src/strands/agent/conversation_manager/null_conversation_manager.py index 2066c08b..4af4eb78 100644 --- a/src/strands/agent/conversation_manager/null_conversation_manager.py +++ b/src/strands/agent/conversation_manager/null_conversation_manager.py @@ -1,8 +1,10 @@ """Null implementation of conversation management.""" -from typing import Optional +from typing import TYPE_CHECKING, Optional + +if TYPE_CHECKING: + from ...agent.agent import Agent -from ...types.content import Messages from ...types.exceptions import ContextWindowOverflowException from .conversation_manager import ConversationManager @@ -17,19 +19,19 @@ class NullConversationManager(ConversationManager): - Situations where the full conversation history should be preserved """ - def apply_management(self, messages: Messages) -> None: + def apply_management(self, _agent: "Agent") -> None: """Does nothing to the conversation history. Args: - messages: The conversation history that will remain unmodified. + agent: The agent whose conversation history will remain unmodified. """ pass - def reduce_context(self, messages: Messages, e: Optional[Exception] = None) -> None: + def reduce_context(self, _agent: "Agent", e: Optional[Exception] = None) -> None: """Does not reduce context and raises an exception. Args: - messages: The conversation history that will remain unmodified. + agent: The agent whose conversation history will remain unmodified. e: The exception that triggered the context reduction, if any. Raises: diff --git a/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py b/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py index 4b11e81c..53ac374f 100644 --- a/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py +++ b/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py @@ -1,12 +1,13 @@ """Sliding window conversation history management.""" -import json import logging -from typing import List, Optional, cast +from typing import TYPE_CHECKING, Optional -from ...types.content import ContentBlock, Message, Messages +if TYPE_CHECKING: + from ...agent.agent import Agent + +from ...types.content import Message, Messages from ...types.exceptions import ContextWindowOverflowException -from ...types.tools import ToolResult from .conversation_manager import ConversationManager logger = logging.getLogger(__name__) @@ -43,17 +44,19 @@ class SlidingWindowConversationManager(ConversationManager): invalid window states. """ - def __init__(self, window_size: int = 40): + def __init__(self, window_size: int = 40, should_truncate_results: bool = True): """Initialize the sliding window conversation manager. Args: - window_size: Maximum number of messages to keep in history. + window_size: Maximum number of messages to keep in the agent's history. Defaults to 40 messages. + should_truncate_results: Truncate tool results when a message is too large for the model's context window """ self.window_size = window_size + self.should_truncate_results = should_truncate_results - def apply_management(self, messages: Messages) -> None: - """Apply the sliding window to the messages array to maintain a manageable history size. 
+ def apply_management(self, agent: "Agent") -> None: + """Apply the sliding window to the agent's messages array to maintain a manageable history size. This method is called after every event loop cycle, as the messages array may have been modified with tool results and assistant responses. It first removes any dangling messages that might create an invalid @@ -64,9 +67,10 @@ def apply_management(self, messages: Messages) -> None: blocks to maintain conversation coherence. Args: - messages: The messages to manage. + agent: The agent whose messages will be managed. This list is modified in-place. """ + messages = agent.messages self._remove_dangling_messages(messages) if len(messages) <= self.window_size: @@ -74,7 +78,7 @@ def apply_management(self, messages: Messages) -> None: "window_size=<%s>, message_count=<%s> | skipping context reduction", len(messages), self.window_size ) return - self.reduce_context(messages) + self.reduce_context(agent) def _remove_dangling_messages(self, messages: Messages) -> None: """Remove dangling messages that would create an invalid conversation state. @@ -107,14 +111,15 @@ def _remove_dangling_messages(self, messages: Messages) -> None: if not any("toolResult" in content for content in messages[-1]["content"]): messages.pop() - def reduce_context(self, messages: Messages, e: Optional[Exception] = None) -> None: + def reduce_context(self, agent: "Agent", e: Optional[Exception] = None) -> None: """Trim the oldest messages to reduce the conversation context size. - The method handles special cases where tool results need to be converted to regular content blocks to maintain - conversation coherence after trimming. + The method handles special cases where trimming the messages leads to: + - toolResult with no corresponding toolUse + - toolUse with no corresponding toolResult Args: - messages: The messages to reduce. + agent: The agent whose messages will be reduce. This list is modified in-place. e: The exception that triggered the context reduction, if any. @@ -123,55 +128,107 @@ def reduce_context(self, messages: Messages, e: Optional[Exception] = None) -> N Such as when the conversation is already minimal or when tool result messages cannot be properly converted. """ + messages = agent.messages + + # Try to truncate the tool result first + last_message_idx_with_tool_results = self._find_last_message_with_tool_results(messages) + if last_message_idx_with_tool_results is not None and self.should_truncate_results: + logger.debug( + "message_index=<%s> | found message with tool results at index", last_message_idx_with_tool_results + ) + results_truncated = self._truncate_tool_results(messages, last_message_idx_with_tool_results) + if results_truncated: + logger.debug("message_index=<%s> | tool results truncated", last_message_idx_with_tool_results) + return + + # Try to trim index id when tool result cannot be truncated anymore # If the number of messages is less than the window_size, then we default to 2, otherwise, trim to window size trim_index = 2 if len(messages) <= self.window_size else len(messages) - self.window_size - # Throw if we cannot trim any messages from the conversation - if trim_index >= len(messages): - raise ContextWindowOverflowException("Unable to trim conversation context!") from e - - # If the message at the cut index has ToolResultContent, then we map that to ContentBlock. This gets around the - # limitation of needing ToolUse and ToolResults to be paired. 
- if any("toolResult" in content for content in messages[trim_index]["content"]): - if len(messages[trim_index]["content"]) == 1: - messages[trim_index]["content"] = self._map_tool_result_content( - cast(ToolResult, messages[trim_index]["content"][0]["toolResult"]) + # Find the next valid trim_index + while trim_index < len(messages): + if ( + # Oldest message cannot be a toolResult because it needs a toolUse preceding it + any("toolResult" in content for content in messages[trim_index]["content"]) + or ( + # Oldest message can be a toolUse only if a toolResult immediately follows it. + any("toolUse" in content for content in messages[trim_index]["content"]) + and trim_index + 1 < len(messages) + and not any("toolResult" in content for content in messages[trim_index + 1]["content"]) ) - - # If there is more content than just one ToolResultContent, then we cannot cut at this index. + ): + trim_index += 1 else: - raise ContextWindowOverflowException("Unable to trim conversation context!") from e + break + else: + # If we didn't find a valid trim_index, then we throw + raise ContextWindowOverflowException("Unable to trim conversation context!") from e # Overwrite message history messages[:] = messages[trim_index:] - def _map_tool_result_content(self, tool_result: ToolResult) -> List[ContentBlock]: - """Convert a ToolResult to a list of standard ContentBlocks. + def _truncate_tool_results(self, messages: Messages, msg_idx: int) -> bool: + """Truncate tool results in a message to reduce context size. - This method transforms tool result content into standard content blocks that can be preserved when trimming the - conversation history. + When a message contains tool results that are too large for the model's context window, this function + replaces the content of those tool results with a simple error message. Args: - tool_result: The ToolResult to convert. + messages: The conversation message history. + msg_idx: Index of the message containing tool results to truncate. Returns: - A list of content blocks representing the tool result. + True if any changes were made to the message, False otherwise. """ - contents = [] - text_content = "Tool Result Status: " + tool_result["status"] if tool_result["status"] else "" - - for tool_result_content in tool_result["content"]: - if "text" in tool_result_content: - text_content = "\nTool Result Text Content: " + tool_result_content["text"] + f"\n{text_content}" - elif "json" in tool_result_content: - text_content = ( - "\nTool Result JSON Content: " + json.dumps(tool_result_content["json"]) + f"\n{text_content}" + if msg_idx >= len(messages) or msg_idx < 0: + return False + + message = messages[msg_idx] + changes_made = False + tool_result_too_large_message = "The tool result was too large!" 
+ for i, content in enumerate(message.get("content", [])): + if isinstance(content, dict) and "toolResult" in content: + tool_result_content_text = next( + (item["text"] for item in content["toolResult"]["content"] if "text" in item), + "", ) - elif "image" in tool_result_content: - contents.append(ContentBlock(image=tool_result_content["image"])) - elif "document" in tool_result_content: - contents.append(ContentBlock(document=tool_result_content["document"])) - else: - logger.warning("unsupported content type") - contents.append(ContentBlock(text=text_content)) - return contents + # make the overwriting logic togglable + if ( + message["content"][i]["toolResult"]["status"] == "error" + and tool_result_content_text == tool_result_too_large_message + ): + logger.info("ToolResult has already been updated, skipping overwrite") + return False + # Update status to error with informative message + message["content"][i]["toolResult"]["status"] = "error" + message["content"][i]["toolResult"]["content"] = [{"text": tool_result_too_large_message}] + changes_made = True + + return changes_made + + def _find_last_message_with_tool_results(self, messages: Messages) -> Optional[int]: + """Find the index of the last message containing tool results. + + This is useful for identifying messages that might need to be truncated to reduce context size. + + Args: + messages: The conversation message history. + + Returns: + Index of the last message with tool results, or None if no such message exists. + """ + # Iterate backwards through all messages (from newest to oldest) + for idx in range(len(messages) - 1, -1, -1): + # Check if this message has any content with toolResult + current_message = messages[idx] + has_tool_result = False + + for content in current_message.get("content", []): + if isinstance(content, dict) and "toolResult" in content: + has_tool_result = True + break + + if has_tool_result: + return idx + + return None diff --git a/src/strands/agent/conversation_manager/summarizing_conversation_manager.py b/src/strands/agent/conversation_manager/summarizing_conversation_manager.py new file mode 100644 index 00000000..a6b112dd --- /dev/null +++ b/src/strands/agent/conversation_manager/summarizing_conversation_manager.py @@ -0,0 +1,222 @@ +"""Summarizing conversation history management with configurable options.""" + +import logging +from typing import TYPE_CHECKING, List, Optional + +from ...types.content import Message +from ...types.exceptions import ContextWindowOverflowException +from .conversation_manager import ConversationManager + +if TYPE_CHECKING: + from ..agent import Agent + + +logger = logging.getLogger(__name__) + + +DEFAULT_SUMMARIZATION_PROMPT = """You are a conversation summarizer. Provide a concise summary of the conversation \ +history. + +Format Requirements: +- You MUST create a structured and concise summary in bullet-point format. +- You MUST NOT respond conversationally. +- You MUST NOT address the user directly. 
+ +Task: +Your task is to create a structured summary document: +- It MUST contain bullet points with key topics and questions covered +- It MUST contain bullet points for all significant tools executed and their results +- It MUST contain bullet points for any code or technical information shared +- It MUST contain a section of key insights gained +- It MUST format the summary in the third person + +Example format: + +## Conversation Summary +* Topic 1: Key information +* Topic 2: Key information +* +## Tools Executed +* Tool X: Result Y""" + + +class SummarizingConversationManager(ConversationManager): + """Implements a summarizing window manager. + + This manager provides a configurable option to summarize older context instead of + simply trimming it, helping preserve important information while staying within + context limits. + """ + + def __init__( + self, + summary_ratio: float = 0.3, + preserve_recent_messages: int = 10, + summarization_agent: Optional["Agent"] = None, + summarization_system_prompt: Optional[str] = None, + ): + """Initialize the summarizing conversation manager. + + Args: + summary_ratio: Ratio of messages to summarize vs keep when context overflow occurs. + Value between 0.1 and 0.8. Defaults to 0.3 (summarize 30% of oldest messages). + preserve_recent_messages: Minimum number of recent messages to always keep. + Defaults to 10 messages. + summarization_agent: Optional agent to use for summarization instead of the parent agent. + If provided, this agent can use tools as part of the summarization process. + summarization_system_prompt: Optional system prompt override for summarization. + If None, uses the default summarization prompt. + """ + if summarization_agent is not None and summarization_system_prompt is not None: + raise ValueError( + "Cannot provide both summarization_agent and summarization_system_prompt. " + "Agents come with their own system prompt." + ) + + self.summary_ratio = max(0.1, min(0.8, summary_ratio)) + self.preserve_recent_messages = preserve_recent_messages + self.summarization_agent = summarization_agent + self.summarization_system_prompt = summarization_system_prompt + + def apply_management(self, agent: "Agent") -> None: + """Apply management strategy to conversation history. + + For the summarizing conversation manager, no proactive management is performed. + Summarization only occurs when there's a context overflow that triggers reduce_context. + + Args: + agent: The agent whose conversation history will be managed. + The agent's messages list is modified in-place. + """ + # No proactive management - summarization only happens on context overflow + pass + + def reduce_context(self, agent: "Agent", e: Optional[Exception] = None) -> None: + """Reduce context using summarization. + + Args: + agent: The agent whose conversation history will be reduced. + The agent's messages list is modified in-place. + e: The exception that triggered the context reduction, if any. + + Raises: + ContextWindowOverflowException: If the context cannot be summarized. 
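For orientation, a usage sketch for the manager defined above. The import path is inferred from this file's location and the `Agent(conversation_manager=...)` keyword is an assumption, so treat this as illustrative wiring rather than confirmed public API:

```python
# Hypothetical wiring; the import paths and the conversation_manager keyword are
# assumptions inferred from this diff, not verified public API.
from strands import Agent  # assumed top-level export
from strands.agent.conversation_manager.summarizing_conversation_manager import (
    SummarizingConversationManager,
)

manager = SummarizingConversationManager(
    summary_ratio=0.4,           # clamped internally to the [0.1, 0.8] range
    preserve_recent_messages=8,  # the newest 8 messages are never summarized
)
# Passing both summarization_agent and summarization_system_prompt raises ValueError,
# since a dedicated summarization agent already carries its own system prompt.

agent = Agent(conversation_manager=manager)
# apply_management() is a no-op here; summarization only runs when the model raises
# a context overflow and the event loop calls reduce_context().
```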
+ """ + try: + # Calculate how many messages to summarize + messages_to_summarize_count = max(1, int(len(agent.messages) * self.summary_ratio)) + + # Ensure we don't summarize recent messages + messages_to_summarize_count = min( + messages_to_summarize_count, len(agent.messages) - self.preserve_recent_messages + ) + + if messages_to_summarize_count <= 0: + raise ContextWindowOverflowException("Cannot summarize: insufficient messages for summarization") + + # Adjust split point to avoid breaking ToolUse/ToolResult pairs + messages_to_summarize_count = self._adjust_split_point_for_tool_pairs( + agent.messages, messages_to_summarize_count + ) + + if messages_to_summarize_count <= 0: + raise ContextWindowOverflowException("Cannot summarize: insufficient messages for summarization") + + # Extract messages to summarize + messages_to_summarize = agent.messages[:messages_to_summarize_count] + remaining_messages = agent.messages[messages_to_summarize_count:] + + # Generate summary + summary_message = self._generate_summary(messages_to_summarize, agent) + + # Replace the summarized messages with the summary + agent.messages[:] = [summary_message] + remaining_messages + + except Exception as summarization_error: + logger.error("Summarization failed: %s", summarization_error) + raise summarization_error from e + + def _generate_summary(self, messages: List[Message], agent: "Agent") -> Message: + """Generate a summary of the provided messages. + + Args: + messages: The messages to summarize. + agent: The agent instance to use for summarization. + + Returns: + A message containing the conversation summary. + + Raises: + Exception: If summary generation fails. + """ + # Choose which agent to use for summarization + summarization_agent = self.summarization_agent if self.summarization_agent is not None else agent + + # Save original system prompt and messages to restore later + original_system_prompt = summarization_agent.system_prompt + original_messages = summarization_agent.messages.copy() + + try: + # Only override system prompt if no agent was provided during initialization + if self.summarization_agent is None: + # Use custom system prompt if provided, otherwise use default + system_prompt = ( + self.summarization_system_prompt + if self.summarization_system_prompt is not None + else DEFAULT_SUMMARIZATION_PROMPT + ) + # Temporarily set the system prompt for summarization + summarization_agent.system_prompt = system_prompt + summarization_agent.messages = messages + + # Use the agent to generate summary with rich content (can use tools if needed) + result = summarization_agent("Please summarize this conversation.") + + return result.message + + finally: + # Restore original agent state + summarization_agent.system_prompt = original_system_prompt + summarization_agent.messages = original_messages + + def _adjust_split_point_for_tool_pairs(self, messages: List[Message], split_point: int) -> int: + """Adjust the split point to avoid breaking ToolUse/ToolResult pairs. + + Uses the same logic as SlidingWindowConversationManager for consistency. + + Args: + messages: The full list of messages. + split_point: The initially calculated split point. + + Returns: + The adjusted split point that doesn't break ToolUse/ToolResult pairs. + + Raises: + ContextWindowOverflowException: If no valid split point can be found. 
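A quick worked example of the `messages_to_summarize_count` arithmetic in `reduce_context` above (pure Python, no SDK imports):

```python
# Worked example of the split-count arithmetic in reduce_context above.
summary_ratio = 0.3
preserve_recent_messages = 10

history_len = 40
count = max(1, int(history_len * summary_ratio))             # 12 oldest messages
count = min(count, history_len - preserve_recent_messages)   # still 12; 30 are eligible
print(count)  # 12 -> messages[0:12] are summarized, messages[12:] are kept verbatim

# With a short history the guard kicks in:
history_len = 8
count = max(1, int(history_len * summary_ratio))             # 2
count = min(count, history_len - preserve_recent_messages)   # -2, i.e. <= 0
# reduce_context raises ContextWindowOverflowException in this case.
```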
+ """ + if split_point > len(messages): + raise ContextWindowOverflowException("Split point exceeds message array length") + + if split_point == len(messages): + return split_point + + # Find the next valid split_point + while split_point < len(messages): + if ( + # Oldest message cannot be a toolResult because it needs a toolUse preceding it + any("toolResult" in content for content in messages[split_point]["content"]) + or ( + # Oldest message can be a toolUse only if a toolResult immediately follows it. + any("toolUse" in content for content in messages[split_point]["content"]) + and split_point + 1 < len(messages) + and not any("toolResult" in content for content in messages[split_point + 1]["content"]) + ) + ): + split_point += 1 + else: + break + else: + # If we didn't find a valid split_point, then we throw + raise ContextWindowOverflowException("Unable to trim conversation context!") + + return split_point diff --git a/src/strands/event_loop/error_handler.py b/src/strands/event_loop/error_handler.py index 5a78bb3f..6dc0d9ee 100644 --- a/src/strands/event_loop/error_handler.py +++ b/src/strands/event_loop/error_handler.py @@ -6,14 +6,9 @@ import logging import time -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, Tuple -from ..telemetry.metrics import EventLoopMetrics -from ..types.content import Message, Messages -from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException -from ..types.models import Model -from ..types.streaming import StopReason -from .message_processor import find_last_message_with_tool_results, truncate_tool_results +from ..types.exceptions import ModelThrottledException logger = logging.getLogger(__name__) @@ -59,63 +54,3 @@ def handle_throttling_error( callback_handler(force_stop=True, force_stop_reason=str(e)) return False, current_delay - - -def handle_input_too_long_error( - e: ContextWindowOverflowException, - messages: Messages, - model: Model, - system_prompt: Optional[str], - tool_config: Any, - callback_handler: Any, - tool_handler: Any, - kwargs: Dict[str, Any], -) -> Tuple[StopReason, Message, EventLoopMetrics, Any]: - """Handle 'Input is too long' errors by truncating tool results. - - When a context window overflow exception occurs (input too long for the model), this function attempts to recover - by finding and truncating the most recent tool results in the conversation history. If trunction is successful, the - function will make a call to the event loop. - - Args: - e: The ContextWindowOverflowException that occurred. - messages: The conversation message history. - model: Model provider for running inference. - system_prompt: System prompt for the model. - tool_config: Tool configuration for the conversation. - callback_handler: Callback for processing events as they happen. - tool_handler: Handler for tool execution. - kwargs: Additional arguments for the event loop. - - Returns: - The results from the event loop call if successful. - - Raises: - ContextWindowOverflowException: If messages cannot be truncated. 
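The split-point loop above (shared, by design, with the sliding-window trim logic earlier in this diff) guarantees that the summary boundary never separates a `toolResult` from its preceding `toolUse`. A toy illustration:

```python
# Illustrative only: which candidate indices the adjustment loop above would
# accept as a split point for this toy history.
messages = [
    {"role": "user", "content": [{"text": "look this up"}]},               # 0
    {"role": "assistant", "content": [{"toolUse": {"toolUseId": "t1"}}]},  # 1
    {"role": "user", "content": [{"toolResult": {"toolUseId": "t1"}}]},    # 2
    {"role": "assistant", "content": [{"text": "here is the answer"}]},    # 3
]

def advances(messages, idx):
    """True when the loop above would skip past idx (same test as in the diff)."""
    content = messages[idx]["content"]
    return any("toolResult" in block for block in content) or (
        any("toolUse" in block for block in content)
        and idx + 1 < len(messages)
        and not any("toolResult" in block for block in messages[idx + 1]["content"])
    )

print([idx for idx in range(len(messages)) if not advances(messages, idx)])  # [0, 1, 3]
```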
- """ - from .event_loop import recurse_event_loop # Import here to avoid circular imports - - # Find the last message with tool results - last_message_with_tool_results = find_last_message_with_tool_results(messages) - - # If we found a message with toolResult - if last_message_with_tool_results is not None: - logger.debug("message_index=<%s> | found message with tool results at index", last_message_with_tool_results) - - # Truncate the tool results in this message - truncate_tool_results(messages, last_message_with_tool_results) - - return recurse_event_loop( - model=model, - system_prompt=system_prompt, - messages=messages, - tool_config=tool_config, - callback_handler=callback_handler, - tool_handler=tool_handler, - **kwargs, - ) - - # If we can't handle this error, pass it up - callback_handler(force_stop=True, force_stop_reason=str(e)) - logger.error("an exception occurred in event_loop_cycle | %s", e) - raise ContextWindowOverflowException() from e diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index db5a1b97..9580ea35 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -22,28 +22,15 @@ from ..types.models import Model from ..types.streaming import Metrics, StopReason from ..types.tools import ToolConfig, ToolHandler, ToolResult, ToolUse -from .error_handler import handle_input_too_long_error, handle_throttling_error +from .error_handler import handle_throttling_error from .message_processor import clean_orphaned_empty_tool_uses from .streaming import stream_messages logger = logging.getLogger(__name__) - -def initialize_state(**kwargs: Any) -> Any: - """Initialize the request state if not present. - - Creates an empty request_state dictionary if one doesn't already exist in the - provided keyword arguments. - - Args: - **kwargs: Keyword arguments that may contain a request_state. - - Returns: - The updated kwargs dictionary with request_state initialized if needed. 
- """ - if "request_state" not in kwargs: - kwargs["request_state"] = {} - return kwargs +MAX_ATTEMPTS = 6 +INITIAL_DELAY = 4 +MAX_DELAY = 240 # 4 minutes def event_loop_cycle( @@ -51,7 +38,7 @@ def event_loop_cycle( system_prompt: Optional[str], messages: Messages, tool_config: Optional[ToolConfig], - callback_handler: Any, + callback_handler: Callable[..., Any], tool_handler: Optional[ToolHandler], tool_execution_handler: Optional[ParallelToolExecutorInterface] = None, **kwargs: Any, @@ -101,10 +88,11 @@ def event_loop_cycle( kwargs["event_loop_cycle_id"] = uuid.uuid4() event_loop_metrics: EventLoopMetrics = kwargs.get("event_loop_metrics", EventLoopMetrics()) - # Initialize state and get cycle trace - kwargs = initialize_state(**kwargs) - cycle_start_time, cycle_trace = event_loop_metrics.start_cycle() + if "request_state" not in kwargs: + kwargs["request_state"] = {} + attributes = {"event_loop_cycle_id": str(kwargs.get("event_loop_cycle_id"))} + cycle_start_time, cycle_trace = event_loop_metrics.start_cycle(attributes=attributes) kwargs["event_loop_cycle_trace"] = cycle_trace callback_handler(start=True) @@ -130,13 +118,10 @@ def event_loop_cycle( stop_reason: StopReason usage: Any metrics: Metrics - max_attempts = 6 - initial_delay = 4 - max_delay = 240 # 4 minutes - current_delay = initial_delay # Retry loop for handling throttling exceptions - for attempt in range(max_attempts): + current_delay = INITIAL_DELAY + for attempt in range(MAX_ATTEMPTS): model_id = model.config.get("model_id") if hasattr(model, "config") else None model_invoke_span = tracer.start_model_invoke_span( parent_span=cycle_span, @@ -145,14 +130,19 @@ def event_loop_cycle( ) try: - stop_reason, message, usage, metrics, kwargs["request_state"] = stream_messages( - model, - system_prompt, - messages, - tool_config, - callback_handler, - **kwargs, - ) + # TODO: As part of the migration to async-iterator, we will continue moving callback_handler calls up the + # call stack. At this point, we converted all events that were previously passed to the handler in + # `stream_messages` into yielded events that now have the "callback" key. To maintain backwards + # compatability, we need to combine the event with kwargs before passing to the handler. This we will + # revisit when migrating to strongly typed events. + for event in stream_messages(model, system_prompt, messages, tool_config): + if "callback" in event: + inputs = {**event["callback"], **(kwargs if "delta" in event["callback"] else {})} + callback_handler(**inputs) + else: + stop_reason, message, usage, metrics = event["stop"] + kwargs.setdefault("request_state", {}) + if model_invoke_span: tracer.end_model_invoke_span(model_invoke_span, message, usage) break # Success! 
Break out of retry loop @@ -160,16 +150,7 @@ def event_loop_cycle( except ContextWindowOverflowException as e: if model_invoke_span: tracer.end_span_with_error(model_invoke_span, str(e), e) - return handle_input_too_long_error( - e, - messages, - model, - system_prompt, - tool_config, - callback_handler, - tool_handler, - kwargs, - ) + raise e except ModelThrottledException as e: if model_invoke_span: @@ -177,7 +158,7 @@ def event_loop_cycle( # Handle throttling errors with exponential backoff should_retry, current_delay = handle_throttling_error( - e, attempt, max_attempts, current_delay, max_delay, callback_handler, kwargs + e, attempt, MAX_ATTEMPTS, current_delay, MAX_DELAY, callback_handler, kwargs ) if should_retry: continue @@ -204,83 +185,38 @@ def event_loop_cycle( # If the model is requesting to use tools if stop_reason == "tool_use": - tool_uses: List[ToolUse] = [] - tool_results: List[ToolResult] = [] - invalid_tool_use_ids: List[str] = [] - - # Extract and validate tools - validate_and_prepare_tools(message, tool_uses, tool_results, invalid_tool_use_ids) - - # Check if tools are available for execution - if tool_uses: - if tool_handler is None: - raise ValueError("toolUse present but tool handler not set") - if tool_config is None: - raise ValueError("toolUse present but tool config not set") - - # Create the tool handler process callable - tool_handler_process: Callable[[ToolUse], ToolResult] = partial( - tool_handler.process, - messages=messages, - model=model, - system_prompt=system_prompt, - tool_config=tool_config, - callback_handler=callback_handler, - **kwargs, + if not tool_handler: + raise EventLoopException( + Exception("Model requested tool use but no tool handler provided"), + kwargs["request_state"], ) - # Execute tools (parallel or sequential) - run_tools( - handler=tool_handler_process, - tool_uses=tool_uses, - event_loop_metrics=event_loop_metrics, - request_state=cast(Any, kwargs["request_state"]), - invalid_tool_use_ids=invalid_tool_use_ids, - tool_results=tool_results, - cycle_trace=cycle_trace, - parent_span=cycle_span, - parallel_tool_executor=tool_execution_handler, + if tool_config is None: + raise EventLoopException( + Exception("Model requested tool use but no tool config provided"), + kwargs["request_state"], ) - # Update state for the next cycle - kwargs = prepare_next_cycle(kwargs, event_loop_metrics) - - # Create the tool result message - tool_result_message: Message = { - "role": "user", - "content": [{"toolResult": result} for result in tool_results], - } - messages.append(tool_result_message) - callback_handler(message=tool_result_message) - - if cycle_span: - tracer.end_event_loop_cycle_span( - span=cycle_span, message=message, tool_result_message=tool_result_message - ) - - # Check if we should stop the event loop - if kwargs["request_state"].get("stop_event_loop"): - event_loop_metrics.end_cycle(cycle_start_time, cycle_trace) - return ( - stop_reason, - message, - event_loop_metrics, - kwargs["request_state"], - ) - - # Recursive call to continue the conversation - return recurse_event_loop( - model=model, - system_prompt=system_prompt, - messages=messages, - tool_config=tool_config, - callback_handler=callback_handler, - tool_handler=tool_handler, - **kwargs, - ) + # Handle tool execution + return _handle_tool_execution( + stop_reason, + message, + model, + system_prompt, + messages, + tool_config, + tool_handler, + callback_handler, + tool_execution_handler, + event_loop_metrics, + cycle_trace, + cycle_span, + cycle_start_time, + kwargs, + ) 
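The event-loop hunk above consumes the new generator protocol: `stream_messages` now yields intermediate events under a `"callback"` key and a single terminal event under a `"stop"` key. A simplified consumption sketch (the kwargs merging, tracing, and retry handling from the real loop are omitted):

```python
# Simplified sketch of the "callback"/"stop" event protocol introduced above.
# The real event loop also merges kwargs into delta callbacks and handles retries.
def drive_stream(stream_events, callback_handler):
    """Consume event dicts yielded by stream_messages/process_stream."""
    for event in stream_events:
        if "callback" in event:
            callback_handler(**event["callback"])  # intermediate streaming event
        else:
            return event["stop"]  # (stop_reason, message, usage, metrics)

# e.g. stop_reason, message, usage, metrics = drive_stream(
#     stream_messages(model, system_prompt, messages, tool_config), callback_handler
# )
```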
# End the cycle and return results - event_loop_metrics.end_cycle(cycle_start_time, cycle_trace) + event_loop_metrics.end_cycle(cycle_start_time, cycle_trace, attributes) if cycle_span: tracer.end_event_loop_cycle_span( span=cycle_span, @@ -293,6 +229,10 @@ def event_loop_cycle( # Don't invoke the callback_handler or log the exception - we already did it when we # raised the exception and we don't need that duplication. raise + except ContextWindowOverflowException as e: + if cycle_span: + tracer.end_span_with_error(cycle_span, str(e), e) + raise e except Exception as e: if cycle_span: tracer.end_span_with_error(cycle_span, str(e), e) @@ -359,21 +299,104 @@ def recurse_event_loop( ) -def prepare_next_cycle(kwargs: Dict[str, Any], event_loop_metrics: EventLoopMetrics) -> Dict[str, Any]: - """Prepare state for the next event loop cycle. +def _handle_tool_execution( + stop_reason: StopReason, + message: Message, + model: Model, + system_prompt: Optional[str], + messages: Messages, + tool_config: ToolConfig, + tool_handler: ToolHandler, + callback_handler: Callable[..., Any], + tool_execution_handler: Optional[ParallelToolExecutorInterface], + event_loop_metrics: EventLoopMetrics, + cycle_trace: Trace, + cycle_span: Any, + cycle_start_time: float, + kwargs: Dict[str, Any], +) -> Tuple[StopReason, Message, EventLoopMetrics, Dict[str, Any]]: + tool_uses: List[ToolUse] = [] + tool_results: List[ToolResult] = [] + invalid_tool_use_ids: List[str] = [] - Updates the keyword arguments with the current event loop metrics and stores the current cycle ID as the parent - cycle ID for the next cycle. This maintains the parent-child relationship between cycles for tracing and metrics. + """ + Handles the execution of tools requested by the model during an event loop cycle. Args: - kwargs: Current keyword arguments containing event loop state. - event_loop_metrics: The metrics object tracking event loop execution. + stop_reason (StopReason): The reason the model stopped generating. + message (Message): The message from the model that may contain tool use requests. + model (Model): The model provider instance. + system_prompt (Optional[str]): The system prompt instructions for the model. + messages (Messages): The conversation history messages. + tool_config (ToolConfig): Configuration for available tools. + tool_handler (ToolHandler): Handler for tool execution. + callback_handler (Callable[..., Any]): Callback for processing events as they happen. + tool_execution_handler (Optional[ParallelToolExecutorInterface]): Optional handler for parallel tool execution. + event_loop_metrics (EventLoopMetrics): Metrics tracking object for the event loop. + cycle_trace (Trace): Trace object for the current event loop cycle. + cycle_span (Any): Span object for tracing the cycle (type may vary). + cycle_start_time (float): Start time of the current cycle. + kwargs (Dict[str, Any]): Additional keyword arguments, including request state. Returns: - Updated keyword arguments ready for the next cycle. + Tuple[StopReason, Message, EventLoopMetrics, Dict[str, Any]]: + - The stop reason, + - The updated message, + - The updated event loop metrics, + - The updated request state. 
""" - # Store parent cycle ID + validate_and_prepare_tools(message, tool_uses, tool_results, invalid_tool_use_ids) + + if not tool_uses: + return stop_reason, message, event_loop_metrics, kwargs["request_state"] + tool_handler_process = partial( + tool_handler.process, + messages=messages, + model=model, + system_prompt=system_prompt, + tool_config=tool_config, + callback_handler=callback_handler, + **kwargs, + ) + + run_tools( + handler=tool_handler_process, + tool_uses=tool_uses, + event_loop_metrics=event_loop_metrics, + request_state=cast(Any, kwargs["request_state"]), + invalid_tool_use_ids=invalid_tool_use_ids, + tool_results=tool_results, + cycle_trace=cycle_trace, + parent_span=cycle_span, + parallel_tool_executor=tool_execution_handler, + ) + + # Store parent cycle ID for the next cycle kwargs["event_loop_metrics"] = event_loop_metrics kwargs["event_loop_parent_cycle_id"] = kwargs["event_loop_cycle_id"] - return kwargs + tool_result_message: Message = { + "role": "user", + "content": [{"toolResult": result} for result in tool_results], + } + + messages.append(tool_result_message) + callback_handler(message=tool_result_message) + + if cycle_span: + tracer = get_tracer() + tracer.end_event_loop_cycle_span(span=cycle_span, message=message, tool_result_message=tool_result_message) + + if kwargs["request_state"].get("stop_event_loop", False): + event_loop_metrics.end_cycle(cycle_start_time, cycle_trace) + return stop_reason, message, event_loop_metrics, kwargs["request_state"] + + return recurse_event_loop( + model=model, + system_prompt=system_prompt, + messages=messages, + tool_config=tool_config, + callback_handler=callback_handler, + tool_handler=tool_handler, + **kwargs, + ) diff --git a/src/strands/event_loop/message_processor.py b/src/strands/event_loop/message_processor.py index 61e1c1d7..4e1a39dc 100644 --- a/src/strands/event_loop/message_processor.py +++ b/src/strands/event_loop/message_processor.py @@ -5,7 +5,7 @@ """ import logging -from typing import Dict, Optional, Set, Tuple +from typing import Dict, Set, Tuple from ..types.content import Messages @@ -103,60 +103,3 @@ def clean_orphaned_empty_tool_uses(messages: Messages) -> bool: logger.warning("failed to fix orphaned tool use | %s", e) return True - - -def find_last_message_with_tool_results(messages: Messages) -> Optional[int]: - """Find the index of the last message containing tool results. - - This is useful for identifying messages that might need to be truncated to reduce context size. - - Args: - messages: The conversation message history. - - Returns: - Index of the last message with tool results, or None if no such message exists. - """ - # Iterate backwards through all messages (from newest to oldest) - for idx in range(len(messages) - 1, -1, -1): - # Check if this message has any content with toolResult - current_message = messages[idx] - has_tool_result = False - - for content in current_message.get("content", []): - if isinstance(content, dict) and "toolResult" in content: - has_tool_result = True - break - - if has_tool_result: - return idx - - return None - - -def truncate_tool_results(messages: Messages, msg_idx: int) -> bool: - """Truncate tool results in a message to reduce context size. - - When a message contains tool results that are too large for the model's context window, this function replaces the - content of those tool results with a simple error message. - - Args: - messages: The conversation message history. - msg_idx: Index of the message containing tool results to truncate. 
- - Returns: - True if any changes were made to the message, False otherwise. - """ - if msg_idx >= len(messages) or msg_idx < 0: - return False - - message = messages[msg_idx] - changes_made = False - - for i, content in enumerate(message.get("content", [])): - if isinstance(content, dict) and "toolResult" in content: - # Update status to error with informative message - message["content"][i]["toolResult"]["status"] = "error" - message["content"][i]["toolResult"]["content"] = [{"text": "The tool result was too large!"}] - changes_made = True - - return changes_made diff --git a/src/strands/event_loop/streaming.py b/src/strands/event_loop/streaming.py index 6e8a806f..0e9d472b 100644 --- a/src/strands/event_loop/streaming.py +++ b/src/strands/event_loop/streaming.py @@ -2,7 +2,7 @@ import json import logging -from typing import Any, Dict, Iterable, List, Optional, Tuple +from typing import Any, Generator, Iterable, Optional from ..types.content import ContentBlock, Message, Messages from ..types.models import Model @@ -80,7 +80,7 @@ def handle_message_start(event: MessageStartEvent, message: Message) -> Message: return message -def handle_content_block_start(event: ContentBlockStartEvent) -> Dict[str, Any]: +def handle_content_block_start(event: ContentBlockStartEvent) -> dict[str, Any]: """Handles the start of a content block by extracting tool usage information if any. Args: @@ -102,31 +102,31 @@ def handle_content_block_start(event: ContentBlockStartEvent) -> Dict[str, Any]: def handle_content_block_delta( - event: ContentBlockDeltaEvent, state: Dict[str, Any], callback_handler: Any, **kwargs: Any -) -> Dict[str, Any]: + event: ContentBlockDeltaEvent, state: dict[str, Any] +) -> tuple[dict[str, Any], dict[str, Any]]: """Handles content block delta updates by appending text, tool input, or reasoning content to the state. Args: event: Delta event. state: The current state of message processing. - callback_handler: Callback for processing events as they happen. - **kwargs: Additional keyword arguments to pass to the callback handler. Returns: Updated state with appended text or tool input. 
""" delta_content = event["delta"] + callback_event = {} + if "toolUse" in delta_content: if "input" not in state["current_tool_use"]: state["current_tool_use"]["input"] = "" state["current_tool_use"]["input"] += delta_content["toolUse"]["input"] - callback_handler(delta=delta_content, current_tool_use=state["current_tool_use"], **kwargs) + callback_event["callback"] = {"delta": delta_content, "current_tool_use": state["current_tool_use"]} elif "text" in delta_content: state["text"] += delta_content["text"] - callback_handler(data=delta_content["text"], delta=delta_content, **kwargs) + callback_event["callback"] = {"data": delta_content["text"], "delta": delta_content} elif "reasoningContent" in delta_content: if "text" in delta_content["reasoningContent"]: @@ -134,29 +134,27 @@ def handle_content_block_delta( state["reasoningText"] = "" state["reasoningText"] += delta_content["reasoningContent"]["text"] - callback_handler( - reasoningText=delta_content["reasoningContent"]["text"], - delta=delta_content, - reasoning=True, - **kwargs, - ) + callback_event["callback"] = { + "reasoningText": delta_content["reasoningContent"]["text"], + "delta": delta_content, + "reasoning": True, + } elif "signature" in delta_content["reasoningContent"]: if "signature" not in state: state["signature"] = "" state["signature"] += delta_content["reasoningContent"]["signature"] - callback_handler( - reasoning_signature=delta_content["reasoningContent"]["signature"], - delta=delta_content, - reasoning=True, - **kwargs, - ) + callback_event["callback"] = { + "reasoning_signature": delta_content["reasoningContent"]["signature"], + "delta": delta_content, + "reasoning": True, + } - return state + return state, callback_event -def handle_content_block_stop(state: Dict[str, Any]) -> Dict[str, Any]: +def handle_content_block_stop(state: dict[str, Any]) -> dict[str, Any]: """Handles the end of a content block by finalizing tool usage, text content, or reasoning content. Args: @@ -165,7 +163,7 @@ def handle_content_block_stop(state: Dict[str, Any]) -> Dict[str, Any]: Returns: Updated state with finalized content block. """ - content: List[ContentBlock] = state["content"] + content: list[ContentBlock] = state["content"] current_tool_use = state["current_tool_use"] text = state["text"] @@ -223,7 +221,7 @@ def handle_message_stop(event: MessageStopEvent) -> StopReason: return event["stopReason"] -def handle_redact_content(event: RedactContentEvent, messages: Messages, state: Dict[str, Any]) -> None: +def handle_redact_content(event: RedactContentEvent, messages: Messages, state: dict[str, Any]) -> None: """Handles redacting content from the input or output. Args: @@ -238,7 +236,7 @@ def handle_redact_content(event: RedactContentEvent, messages: Messages, state: state["message"]["content"] = [{"text": event["redactAssistantContentMessage"]}] -def extract_usage_metrics(event: MetadataEvent) -> Tuple[Usage, Metrics]: +def extract_usage_metrics(event: MetadataEvent) -> tuple[Usage, Metrics]: """Extracts usage metrics from the metadata chunk. Args: @@ -255,25 +253,20 @@ def extract_usage_metrics(event: MetadataEvent) -> Tuple[Usage, Metrics]: def process_stream( chunks: Iterable[StreamEvent], - callback_handler: Any, messages: Messages, - **kwargs: Any, -) -> Tuple[StopReason, Message, Usage, Metrics, Any]: +) -> Generator[dict[str, Any], None, None]: """Processes the response stream from the API, constructing the final message and extracting usage metrics. Args: chunks: The chunks of the response stream from the model. 
- callback_handler: Callback for processing events as they happen. messages: The agents messages. - **kwargs: Additional keyword arguments that will be passed to the callback handler. - And also returned in the request_state. Returns: - The reason for stopping, the constructed message, the usage metrics, and the updated request state. + The reason for stopping, the constructed message, and the usage metrics. """ stop_reason: StopReason = "end_turn" - state: Dict[str, Any] = { + state: dict[str, Any] = { "message": {"role": "assistant", "content": []}, "text": "", "current_tool_use": {}, @@ -285,18 +278,16 @@ def process_stream( usage: Usage = Usage(inputTokens=0, outputTokens=0, totalTokens=0) metrics: Metrics = Metrics(latencyMs=0) - kwargs.setdefault("request_state", {}) - for chunk in chunks: - # Callback handler call here allows each event to be visible to the caller - callback_handler(event=chunk) + yield {"callback": {"event": chunk}} if "messageStart" in chunk: state["message"] = handle_message_start(chunk["messageStart"], state["message"]) elif "contentBlockStart" in chunk: state["current_tool_use"] = handle_content_block_start(chunk["contentBlockStart"]) elif "contentBlockDelta" in chunk: - state = handle_content_block_delta(chunk["contentBlockDelta"], state, callback_handler, **kwargs) + state, callback_event = handle_content_block_delta(chunk["contentBlockDelta"], state) + yield callback_event elif "contentBlockStop" in chunk: state = handle_content_block_stop(state) elif "messageStop" in chunk: @@ -306,7 +297,7 @@ def process_stream( elif "redactContent" in chunk: handle_redact_content(chunk["redactContent"], messages, state) - return stop_reason, state["message"], usage, metrics, kwargs["request_state"] + yield {"stop": (stop_reason, state["message"], usage, metrics)} def stream_messages( @@ -314,9 +305,7 @@ def stream_messages( system_prompt: Optional[str], messages: Messages, tool_config: Optional[ToolConfig], - callback_handler: Any, - **kwargs: Any, -) -> Tuple[StopReason, Message, Usage, Metrics, Any]: +) -> Generator[dict[str, Any], None, None]: """Streams messages to the model and processes the response. Args: @@ -324,12 +313,9 @@ def stream_messages( system_prompt: The system prompt to send. messages: List of messages to send. tool_config: Configuration for the tools to use. - callback_handler: Callback for processing events as they happen. - **kwargs: Additional keyword arguments that will be passed to the callback handler. - And also returned in the request_state. Returns: - The reason for stopping, the final message, the usage metrics, and updated request state. + The reason for stopping, the final message, and the usage metrics """ logger.debug("model=<%s> | streaming messages", model) @@ -337,4 +323,4 @@ def stream_messages( tool_specs = [tool["toolSpec"] for tool in tool_config.get("tools", [])] or None if tool_config else None chunks = model.converse(messages, tool_specs, system_prompt) - return process_stream(chunks, callback_handler, messages, **kwargs) + yield from process_stream(chunks, messages) diff --git a/src/strands/handlers/callback_handler.py b/src/strands/handlers/callback_handler.py index d6d104d8..4b794b4f 100644 --- a/src/strands/handlers/callback_handler.py +++ b/src/strands/handlers/callback_handler.py @@ -17,15 +17,19 @@ def __call__(self, **kwargs: Any) -> None: Args: **kwargs: Callback event data including: - + - reasoningText (Optional[str]): Reasoning text to print if provided. - data (str): Text content to stream. 
- complete (bool): Whether this is the final chunk of a response. - current_tool_use (dict): Information about the current tool being used. """ + reasoningText = kwargs.get("reasoningText", False) data = kwargs.get("data", "") complete = kwargs.get("complete", False) current_tool_use = kwargs.get("current_tool_use", {}) + if reasoningText: + print(reasoningText, end="") + if data: print(data, end="" if not complete else "\n") diff --git a/src/strands/handlers/tool_handler.py b/src/strands/handlers/tool_handler.py index 0803eca5..bc4ec1ce 100644 --- a/src/strands/handlers/tool_handler.py +++ b/src/strands/handlers/tool_handler.py @@ -46,6 +46,7 @@ def preprocess( def process( self, tool: Any, + *, model: Model, system_prompt: Optional[str], messages: List[Any], diff --git a/src/strands/models/anthropic.py b/src/strands/models/anthropic.py index 704114eb..51089d47 100644 --- a/src/strands/models/anthropic.py +++ b/src/strands/models/anthropic.py @@ -7,11 +7,15 @@ import json import logging import mimetypes -from typing import Any, Iterable, Optional, TypedDict, cast +from typing import Any, Callable, Iterable, Optional, Type, TypedDict, TypeVar, cast import anthropic +from pydantic import BaseModel from typing_extensions import Required, Unpack, override +from ..event_loop.streaming import process_stream +from ..handlers.callback_handler import PrintingCallbackHandler +from ..tools import convert_pydantic_to_tool_spec from ..types.content import ContentBlock, Messages from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException from ..types.models import Model @@ -20,6 +24,8 @@ logger = logging.getLogger(__name__) +T = TypeVar("T", bound=BaseModel) + class AnthropicModel(Model): """Anthropic model provider implementation.""" @@ -95,15 +101,21 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An Returns: Anthropic formatted content block. + + Raises: + TypeError: If the content block type cannot be converted to an Anthropic-compatible format. 
""" if "document" in content: + mime_type = mimetypes.types_map.get(f".{content['document']['format']}", "application/octet-stream") return { "source": { - "data": base64.b64encode(content["document"]["source"]["bytes"]).decode("utf-8"), - "media_type": mimetypes.types_map.get( - f".{content['document']['format']}", "application/octet-stream" + "data": ( + content["document"]["source"]["bytes"].decode("utf-8") + if mime_type == "text/plain" + else base64.b64encode(content["document"]["source"]["bytes"]).decode("utf-8") ), - "type": "base64", + "media_type": mime_type, + "type": "text" if mime_type == "text/plain" else "base64", }, "title": content["document"]["name"], "type": "document", @@ -140,7 +152,11 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An if "toolResult" in content: return { "content": [ - self._format_request_message_content(cast(ContentBlock, tool_result_content)) + self._format_request_message_content( + {"text": json.dumps(tool_result_content["json"])} + if "json" in tool_result_content + else cast(ContentBlock, tool_result_content) + ) for tool_result_content in content["toolResult"]["content"] ], "is_error": content["toolResult"]["status"] == "error", @@ -148,7 +164,7 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An "type": "tool_result", } - return {"text": json.dumps(content), "type": "text"} + raise TypeError(f"content_type=<{next(iter(content))}> | unsupported type") def _format_request_messages(self, messages: Messages) -> list[dict[str, Any]]: """Format an Anthropic messages array. @@ -189,6 +205,10 @@ def format_request( Returns: An Anthropic streaming request. + + Raises: + TypeError: If a message contains a content block type that cannot be converted to an Anthropic-compatible + format. """ return { "max_tokens": self.config["max_tokens"], @@ -342,10 +362,10 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]: with self.client.messages.stream(**request) as stream: for event in stream: if event.type in AnthropicModel.EVENT_TYPES: - yield event.dict() + yield event.model_dump() usage = event.message.usage # type: ignore - yield {"type": "metadata", "usage": usage.dict()} + yield {"type": "metadata", "usage": usage.model_dump()} except anthropic.RateLimitError as error: raise ModelThrottledException(str(error)) from error @@ -355,3 +375,42 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]: raise ContextWindowOverflowException(str(error)) from error raise error + + @override + def structured_output( + self, output_model: Type[T], prompt: Messages, callback_handler: Optional[Callable] = None + ) -> T: + """Get structured output from the model. + + Args: + output_model(Type[BaseModel]): The output model to use for the agent. + prompt(Messages): The prompt messages to use for the agent. + callback_handler(Optional[Callable]): Optional callback handler for processing events. Defaults to None. 
+ """ + callback_handler = callback_handler or PrintingCallbackHandler() + tool_spec = convert_pydantic_to_tool_spec(output_model) + + response = self.converse(messages=prompt, tool_specs=[tool_spec]) + for event in process_stream(response, prompt): + if "callback" in event: + callback_handler(**event["callback"]) + else: + stop_reason, messages, _, _ = event["stop"] + + if stop_reason != "tool_use": + raise ValueError("No valid tool use or tool use input was found in the Anthropic response.") + + content = messages["content"] + output_response: dict[str, Any] | None = None + for block in content: + # if the tool use name doesn't match the tool spec name, skip, and if the block is not a tool use, skip. + # if the tool use name never matches, raise an error. + if block.get("toolUse") and block["toolUse"]["name"] == tool_spec["name"]: + output_response = block["toolUse"]["input"] + else: + continue + + if output_response is None: + raise ValueError("No valid tool use or tool use input was found in the Anthropic response.") + + return output_model(**output_response) diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index b1b18b3e..a5ffb539 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -3,14 +3,20 @@ - Docs: https://aws.amazon.com/bedrock/ """ +import json import logging -from typing import Any, Iterable, Literal, Optional, cast +import os +from typing import Any, Callable, Iterable, List, Literal, Optional, Type, TypeVar, cast import boto3 from botocore.config import Config as BotocoreConfig -from botocore.exceptions import ClientError, EventStreamError +from botocore.exceptions import ClientError +from pydantic import BaseModel from typing_extensions import TypedDict, Unpack, override +from ..event_loop.streaming import process_stream +from ..handlers.callback_handler import PrintingCallbackHandler +from ..tools import convert_pydantic_to_tool_spec from ..types.content import Messages from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException from ..types.models import Model @@ -27,6 +33,8 @@ "too many total text bytes", ] +T = TypeVar("T", bound=BaseModel) + class BedrockModel(Model): """AWS Bedrock model provider implementation. @@ -60,6 +68,7 @@ class BedrockConfig(TypedDict, total=False): max_tokens: Maximum number of tokens to generate in the response model_id: The Bedrock model ID (e.g., "us.anthropic.claude-3-7-sonnet-20250219-v1:0") stop_sequences: List of sequences that will stop generation when encountered + streaming: Flag to enable/disable streaming. Defaults to True. temperature: Controls randomness in generation (higher = more random) top_p: Controls diversity via nucleus sampling (alternative to temperature) """ @@ -80,6 +89,7 @@ class BedrockConfig(TypedDict, total=False): max_tokens: Optional[int] model_id: str stop_sequences: Optional[list[str]] + streaming: Optional[bool] temperature: Optional[float] top_p: Optional[float] @@ -96,7 +106,8 @@ def __init__( Args: boto_session: Boto Session to use when calling the Bedrock Model. boto_client_config: Configuration to use when creating the Bedrock-Runtime Boto Client. - region_name: AWS region to use for the Bedrock service. Defaults to "us-west-2". + region_name: AWS region to use for the Bedrock service. + Defaults to the AWS_REGION environment variable if set, or "us-west-2" if not set. **model_config: Configuration options for the Bedrock model. 
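The `structured_output` method above converts a Pydantic model into a tool spec, forces a tool call, and validates the returned tool input back into the model; `BedrockModel` gains the same method later in this diff. A usage sketch (the Pydantic model and prompt are illustrative; only the method signature comes from the diff):

```python
# Usage sketch for structured_output. Only the method signature is taken from
# this diff; the model instance, prompt, and schema are illustrative.
from pydantic import BaseModel

class WeatherReport(BaseModel):
    city: str
    temperature_c: float
    conditions: str

def extract_weather(model):
    """model: an already-configured AnthropicModel (or BedrockModel, see below)."""
    prompt = [{"role": "user", "content": [{"text": "Weather in Seattle: 12C and raining."}]}]
    report = model.structured_output(WeatherReport, prompt)
    return report.temperature_c  # parsed from the forced tool-use input
```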
""" if region_name and boto_session: @@ -107,10 +118,33 @@ def __init__( logger.debug("config=<%s> | initializing", self.config) + region_for_boto = region_name or os.getenv("AWS_REGION") + if region_for_boto is None: + region_for_boto = "us-west-2" + logger.warning("defaulted to us-west-2 because no region was specified") + logger.warning( + "issue=<%s> | this behavior will change in an upcoming release", + "https://github.com/strands-agents/sdk-python/issues/238", + ) + session = boto_session or boto3.Session( - region_name=region_name or "us-west-2", + region_name=region_for_boto, ) - client_config = boto_client_config or BotocoreConfig(user_agent_extra="strands-agents") + + # Add strands-agents to the request user agent + if boto_client_config: + existing_user_agent = getattr(boto_client_config, "user_agent_extra", None) + + # Append 'strands-agents' to existing user_agent_extra or set it if not present + if existing_user_agent: + new_user_agent = f"{existing_user_agent} strands-agents" + else: + new_user_agent = "strands-agents" + + client_config = boto_client_config.merge(BotocoreConfig(user_agent_extra=new_user_agent)) + else: + client_config = BotocoreConfig(user_agent_extra="strands-agents") + self.client = session.client( service_name="bedrock-runtime", config=client_config, @@ -130,7 +164,7 @@ def get_config(self) -> BedrockConfig: """Get the current Bedrock Model configuration. Returns: - The Bedrok model configuration. + The Bedrock model configuration. """ return self.config @@ -230,11 +264,68 @@ def format_chunk(self, event: dict[str, Any]) -> StreamEvent: """ return cast(StreamEvent, event) + def _has_blocked_guardrail(self, guardrail_data: dict[str, Any]) -> bool: + """Check if guardrail data contains any blocked policies. + + Args: + guardrail_data: Guardrail data from trace information. + + Returns: + True if any blocked guardrail is detected, False otherwise. + """ + input_assessment = guardrail_data.get("inputAssessment", {}) + output_assessments = guardrail_data.get("outputAssessments", {}) + + # Check input assessments + if any(self._find_detected_and_blocked_policy(assessment) for assessment in input_assessment.values()): + return True + + # Check output assessments + if any(self._find_detected_and_blocked_policy(assessment) for assessment in output_assessments.values()): + return True + + return False + + def _generate_redaction_events(self) -> list[StreamEvent]: + """Generate redaction events based on configuration. + + Returns: + List of redaction events to yield. + """ + events: List[StreamEvent] = [] + + if self.config.get("guardrail_redact_input", True): + logger.debug("Redacting user input due to guardrail.") + events.append( + { + "redactContent": { + "redactUserContentMessage": self.config.get( + "guardrail_redact_input_message", "[User input redacted.]" + ) + } + } + ) + + if self.config.get("guardrail_redact_output", False): + logger.debug("Redacting assistant output due to guardrail.") + events.append( + { + "redactContent": { + "redactAssistantContentMessage": self.config.get( + "guardrail_redact_output_message", "[Assistant output redacted.]" + ) + } + } + ) + + return events + @override - def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]: - """Send the request to the Bedrock model and get the streaming response. + def stream(self, request: dict[str, Any]) -> Iterable[StreamEvent]: + """Send the request to the Bedrock model and get the response. 
- This method calls the Bedrock converse_stream API and returns the stream of response events. + This method calls either the Bedrock converse_stream API or the converse API + based on the streaming parameter in the configuration. Args: request: The formatted request to send to the Bedrock model @@ -244,63 +335,132 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]: Raises: ContextWindowOverflowException: If the input exceeds the model's context window. - EventStreamError: For all other Bedrock API errors. + ModelThrottledException: If the model service is throttling requests. """ + streaming = self.config.get("streaming", True) + try: - response = self.client.converse_stream(**request) - for chunk in response["stream"]: - if self.config.get("guardrail_redact_input", True) or self.config.get("guardrail_redact_output", False): + if streaming: + # Streaming implementation + response = self.client.converse_stream(**request) + for chunk in response["stream"]: if ( "metadata" in chunk and "trace" in chunk["metadata"] and "guardrail" in chunk["metadata"]["trace"] ): - inputAssessment = chunk["metadata"]["trace"]["guardrail"].get("inputAssessment", {}) - outputAssessments = chunk["metadata"]["trace"]["guardrail"].get("outputAssessments", {}) - - # Check if an input or output guardrail was triggered - if any( - self._find_detected_and_blocked_policy(assessment) - for assessment in inputAssessment.values() - ) or any( - self._find_detected_and_blocked_policy(assessment) - for assessment in outputAssessments.values() - ): - if self.config.get("guardrail_redact_input", True): - logger.debug("Found blocked input guardrail. Redacting input.") - yield { - "redactContent": { - "redactUserContentMessage": self.config.get( - "guardrail_redact_input_message", "[User input redacted.]" - ) - } - } - if self.config.get("guardrail_redact_output", False): - logger.debug("Found blocked output guardrail. Redacting output.") - yield { - "redactContent": { - "redactAssistantContentMessage": self.config.get( - "guardrail_redact_output_message", "[Assistant output redacted.]" - ) - } - } + guardrail_data = chunk["metadata"]["trace"]["guardrail"] + if self._has_blocked_guardrail(guardrail_data): + yield from self._generate_redaction_events() + yield chunk + else: + # Non-streaming implementation + response = self.client.converse(**request) + + # Convert and yield from the response + yield from self._convert_non_streaming_to_streaming(response) + + # Check for guardrail triggers after yielding any events (same as streaming path) + if ( + "trace" in response + and "guardrail" in response["trace"] + and self._has_blocked_guardrail(response["trace"]["guardrail"]) + ): + yield from self._generate_redaction_events() - yield chunk - except EventStreamError as e: - # Handle throttling that occurs mid-stream? 
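Two behavioral notes from the `BedrockModel` constructor changes above: the region now resolves as explicit argument, then `AWS_REGION`, then the warned `us-west-2` fallback, and a caller-supplied `BotocoreConfig` keeps its settings while `strands-agents` is appended to `user_agent_extra`. A sketch of that precedence and merge (`BotocoreConfig.merge` is real botocore API; the helper is illustrative):

```python
# Sketch of the region precedence and user-agent merging described above.
import os
from botocore.config import Config as BotocoreConfig

def resolve_region(region_name=None):
    # Explicit argument > AWS_REGION env var > "us-west-2" (with a logged warning).
    return region_name or os.getenv("AWS_REGION") or "us-west-2"

caller_config = BotocoreConfig(user_agent_extra="my-app/1.0", retries={"max_attempts": 10})
existing = getattr(caller_config, "user_agent_extra", None)
new_user_agent = f"{existing} strands-agents" if existing else "strands-agents"
merged = caller_config.merge(BotocoreConfig(user_agent_extra=new_user_agent))

print(resolve_region("eu-west-1"))  # "eu-west-1": explicit argument wins
print(merged.user_agent_extra)      # "my-app/1.0 strands-agents"
print(merged.retries)               # caller retry settings are preserved by merge()
```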
- if "ThrottlingException" in str(e) and "ConverseStream" in str(e): - raise ModelThrottledException(str(e)) from e + except ClientError as e: + error_message = str(e) + + # Handle throttling error + if e.response["Error"]["Code"] == "ThrottlingException": + raise ModelThrottledException(error_message) from e - if any(overflow_message in str(e) for overflow_message in BEDROCK_CONTEXT_WINDOW_OVERFLOW_MESSAGES): + # Handle context window overflow + if any(overflow_message in error_message for overflow_message in BEDROCK_CONTEXT_WINDOW_OVERFLOW_MESSAGES): logger.warning("bedrock threw context window overflow error") raise ContextWindowOverflowException(e) from e + + # Otherwise raise the error raise e - except ClientError as e: - # Handle throttling that occurs at the beginning of the call - if e.response["Error"]["Code"] == "ThrottlingException": - raise ModelThrottledException(str(e)) from e - raise + def _convert_non_streaming_to_streaming(self, response: dict[str, Any]) -> Iterable[StreamEvent]: + """Convert a non-streaming response to the streaming format. + + Args: + response: The non-streaming response from the Bedrock model. + + Returns: + An iterable of response events in the streaming format. + """ + # Yield messageStart event + yield {"messageStart": {"role": response["output"]["message"]["role"]}} + + # Process content blocks + for content in response["output"]["message"]["content"]: + # Yield contentBlockStart event if needed + if "toolUse" in content: + yield { + "contentBlockStart": { + "start": { + "toolUse": { + "toolUseId": content["toolUse"]["toolUseId"], + "name": content["toolUse"]["name"], + } + }, + } + } + + # For tool use, we need to yield the input as a delta + input_value = json.dumps(content["toolUse"]["input"]) + + yield {"contentBlockDelta": {"delta": {"toolUse": {"input": input_value}}}} + elif "text" in content: + # Then yield the text as a delta + yield { + "contentBlockDelta": { + "delta": {"text": content["text"]}, + } + } + elif "reasoningContent" in content: + # Then yield the reasoning content as a delta + yield { + "contentBlockDelta": { + "delta": {"reasoningContent": {"text": content["reasoningContent"]["reasoningText"]["text"]}} + } + } + + if "signature" in content["reasoningContent"]["reasoningText"]: + yield { + "contentBlockDelta": { + "delta": { + "reasoningContent": { + "signature": content["reasoningContent"]["reasoningText"]["signature"] + } + } + } + } + + # Yield contentBlockStop event + yield {"contentBlockStop": {}} + + # Yield messageStop event + yield { + "messageStop": { + "stopReason": response["stopReason"], + "additionalModelResponseFields": response.get("additionalModelResponseFields"), + } + } + + # Yield metadata event + if "usage" in response or "metrics" in response or "trace" in response: + metadata: StreamEvent = {"metadata": {}} + if "usage" in response: + metadata["metadata"]["usage"] = response["usage"] + if "metrics" in response: + metadata["metadata"]["metrics"] = response["metrics"] + if "trace" in response: + metadata["metadata"]["trace"] = response["trace"] + yield metadata def _find_detected_and_blocked_policy(self, input: Any) -> bool: """Recursively checks if the assessment contains a detected and blocked guardrail. 
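With `streaming=False`, a single `converse` response is replayed through `_convert_non_streaming_to_streaming` above so downstream event handling stays identical. For a plain text reply, the synthesized event sequence looks like this (trimmed example response; derived from the converter's branches):

```python
# Illustrative input/output pair for the non-streaming-to-streaming converter above.
response = {
    "output": {"message": {"role": "assistant", "content": [{"text": "Hello!"}]}},
    "stopReason": "end_turn",
    "usage": {"inputTokens": 12, "outputTokens": 3, "totalTokens": 15},
}

expected_events = [
    {"messageStart": {"role": "assistant"}},
    {"contentBlockDelta": {"delta": {"text": "Hello!"}}},  # no contentBlockStart for plain text
    {"contentBlockStop": {}},
    {"messageStop": {"stopReason": "end_turn", "additionalModelResponseFields": None}},
    {"metadata": {"usage": {"inputTokens": 12, "outputTokens": 3, "totalTokens": 15}}},
]
```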
@@ -332,3 +492,42 @@ def _find_detected_and_blocked_policy(self, input: Any) -> bool: return self._find_detected_and_blocked_policy(item) # Otherwise return False return False + + @override + def structured_output( + self, output_model: Type[T], prompt: Messages, callback_handler: Optional[Callable] = None + ) -> T: + """Get structured output from the model. + + Args: + output_model(Type[BaseModel]): The output model to use for the agent. + prompt(Messages): The prompt messages to use for the agent. + callback_handler(Optional[Callable]): Optional callback handler for processing events. Defaults to None. + """ + callback_handler = callback_handler or PrintingCallbackHandler() + tool_spec = convert_pydantic_to_tool_spec(output_model) + + response = self.converse(messages=prompt, tool_specs=[tool_spec]) + for event in process_stream(response, prompt): + if "callback" in event: + callback_handler(**event["callback"]) + else: + stop_reason, messages, _, _ = event["stop"] + + if stop_reason != "tool_use": + raise ValueError("No valid tool use or tool use input was found in the Bedrock response.") + + content = messages["content"] + output_response: dict[str, Any] | None = None + for block in content: + # if the tool use name doesn't match the tool spec name, skip, and if the block is not a tool use, skip. + # if the tool use name never matches, raise an error. + if block.get("toolUse") and block["toolUse"]["name"] == tool_spec["name"]: + output_response = block["toolUse"]["input"] + else: + continue + + if output_response is None: + raise ValueError("No valid tool use or tool use input was found in the Bedrock response.") + + return output_model(**output_response) diff --git a/src/strands/models/litellm.py b/src/strands/models/litellm.py index a7563133..66138186 100644 --- a/src/strands/models/litellm.py +++ b/src/strands/models/litellm.py @@ -5,21 +5,22 @@ import json import logging -import mimetypes -from typing import Any, Iterable, Optional, TypedDict +from typing import Any, Callable, Optional, Type, TypedDict, TypeVar, cast import litellm +from litellm.utils import supports_response_schema +from pydantic import BaseModel from typing_extensions import Unpack, override from ..types.content import ContentBlock, Messages -from ..types.models import Model -from ..types.streaming import StreamEvent -from ..types.tools import ToolResult, ToolSpec, ToolUse +from .openai import OpenAIModel logger = logging.getLogger(__name__) +T = TypeVar("T", bound=BaseModel) -class LiteLLMModel(Model): + +class LiteLLMModel(OpenAIModel): """LiteLLM model provider implementation.""" class LiteLLMConfig(TypedDict, total=False): @@ -45,7 +46,7 @@ def __init__(self, client_args: Optional[dict[str, Any]] = None, **model_config: https://github.com/BerriAI/litellm/blob/main/litellm/main.py. **model_config: Configuration options for the LiteLLM model. """ - self.config = LiteLLMModel.LiteLLMConfig(**model_config) + self.config = dict(model_config) logger.debug("config=<%s> | initializing", self.config) @@ -68,9 +69,11 @@ def get_config(self) -> LiteLLMConfig: Returns: The LiteLLM model configuration. """ - return self.config + return cast(LiteLLMModel.LiteLLMConfig, self.config) - def _format_request_message_content(self, content: ContentBlock) -> dict[str, Any]: + @override + @classmethod + def format_request_message_content(cls, content: ContentBlock) -> dict[str, Any]: """Format a LiteLLM content block. 
Args: @@ -78,19 +81,10 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An Returns: LiteLLM formatted content block. - """ - if "image" in content: - mime_type = mimetypes.types_map.get(f".{content['image']['format']}", "application/octet-stream") - image_data = content["image"]["source"]["bytes"].decode("utf-8") - return { - "image_url": { - "detail": "auto", - "format": mime_type, - "url": f"data:{mime_type};base64,{image_data}", - }, - "type": "image_url", - } + Raises: + TypeError: If the content block type cannot be converted to a LiteLLM-compatible format. + """ if "reasoningContent" in content: return { "signature": content["reasoningContent"]["reasoningText"]["signature"], @@ -98,9 +92,6 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An "type": "thinking", } - if "text" in content: - return {"text": content["text"], "type": "text"} - if "video" in content: return { "type": "video_url", @@ -110,230 +101,44 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An }, } - return {"text": json.dumps(content), "type": "text"} - - def _format_request_message_tool_call(self, tool_use: ToolUse) -> dict[str, Any]: - """Format a LiteLLM tool call. - - Args: - tool_use: Tool use requested by the model. - - Returns: - LiteLLM formatted tool call. - """ - return { - "function": { - "arguments": json.dumps(tool_use["input"]), - "name": tool_use["name"], - }, - "id": tool_use["toolUseId"], - "type": "function", - } - - def _format_request_tool_message(self, tool_result: ToolResult) -> dict[str, Any]: - """Format a LiteLLM tool message. - - Args: - tool_result: Tool result collected from a tool execution. - - Returns: - LiteLLM formatted tool message. - """ - return { - "role": "tool", - "tool_call_id": tool_result["toolUseId"], - "content": json.dumps( - { - "content": tool_result["content"], - "status": tool_result["status"], - } - ), - } - - def _format_request_messages(self, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]: - """Format a LiteLLM messages array. - - Args: - messages: List of message objects to be processed by the model. - system_prompt: System prompt to provide context to the model. - - Returns: - A LiteLLM messages array. 
- """ - formatted_messages: list[dict[str, Any]] - formatted_messages = [{"role": "system", "content": system_prompt}] if system_prompt else [] - - for message in messages: - contents = message["content"] - - formatted_contents = [ - self._format_request_message_content(content) - for content in contents - if not any(block_type in content for block_type in ["toolResult", "toolUse"]) - ] - formatted_tool_calls = [ - self._format_request_message_tool_call(content["toolUse"]) - for content in contents - if "toolUse" in content - ] - formatted_tool_messages = [ - self._format_request_tool_message(content["toolResult"]) - for content in contents - if "toolResult" in content - ] - - formatted_message = { - "role": message["role"], - "content": formatted_contents, - **({"tool_calls": formatted_tool_calls} if formatted_tool_calls else {}), - } - formatted_messages.append(formatted_message) - formatted_messages.extend(formatted_tool_messages) - - return [message for message in formatted_messages if message["content"] or "tool_calls" in message] + return super().format_request_message_content(content) @override - def format_request( - self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None - ) -> dict[str, Any]: - """Format a LiteLLM chat streaming request. + def structured_output( + self, output_model: Type[T], prompt: Messages, callback_handler: Optional[Callable] = None + ) -> T: + """Get structured output from the model. Args: - messages: List of message objects to be processed by the model. - tool_specs: List of tool specifications to make available to the model. - system_prompt: System prompt to provide context to the model. + output_model(Type[BaseModel]): The output model to use for the agent. + prompt(Messages): The prompt messages to use for the agent. + callback_handler(Optional[Callable]): Optional callback handler for processing events. Defaults to None. - Returns: - A LiteLLM chat streaming request. """ - return { - "messages": self._format_request_messages(messages, system_prompt), - "model": self.config["model_id"], - "stream": True, - "stream_options": {"include_usage": True}, - "tools": [ - { - "type": "function", - "function": { - "name": tool_spec["name"], - "description": tool_spec["description"], - "parameters": tool_spec["inputSchema"]["json"], - }, - } - for tool_spec in tool_specs or [] - ], - **(self.config.get("params") or {}), - } - - @override - def format_chunk(self, event: dict[str, Any]) -> StreamEvent: - """Format the LiteLLM response events into standardized message chunks. - - Args: - event: A response event from the LiteLLM model. - - Returns: - The formatted chunk. - - Raises: - RuntimeError: If chunk_type is not recognized. - This error should never be encountered as we control chunk_type in the stream method. 
- """ - match event["chunk_type"]: - case "message_start": - return {"messageStart": {"role": "assistant"}} - - case "content_start": - if event["data_type"] == "tool": - return { - "contentBlockStart": { - "start": { - "toolUse": { - "name": event["data"].function.name, - "toolUseId": event["data"].id, - } - } - } - } - - return {"contentBlockStart": {"start": {}}} - - case "content_delta": - if event["data_type"] == "tool": - return {"contentBlockDelta": {"delta": {"toolUse": {"input": event["data"].function.arguments}}}} - - return {"contentBlockDelta": {"delta": {"text": event["data"]}}} - - case "content_stop": - return {"contentBlockStop": {}} - - case "message_stop": - match event["data"]: - case "tool_calls": - return {"messageStop": {"stopReason": "tool_use"}} - case "length": - return {"messageStop": {"stopReason": "max_tokens"}} - case _: - return {"messageStop": {"stopReason": "end_turn"}} - - case "metadata": - return { - "metadata": { - "usage": { - "inputTokens": event["data"].prompt_tokens, - "outputTokens": event["data"].completion_tokens, - "totalTokens": event["data"].total_tokens, - }, - "metrics": { - "latencyMs": 0, # TODO - }, - }, - } - - case _: - raise RuntimeError(f"chunk_type=<{event['chunk_type']} | unknown type") - - @override - def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]: - """Send the request to the LiteLLM model and get the streaming response. - - Args: - request: The formatted request to send to the LiteLLM model. - - Returns: - An iterable of response events from the LiteLLM model. - """ - response = self.client.chat.completions.create(**request) - - yield {"chunk_type": "message_start"} - yield {"chunk_type": "content_start", "data_type": "text"} - - tool_calls: dict[int, list[Any]] = {} - - for event in response: - choice = event.choices[0] - if choice.finish_reason: - break - - if choice.delta.content: - yield {"chunk_type": "content_delta", "data_type": "text", "data": choice.delta.content} - - for tool_call in choice.delta.tool_calls or []: - tool_calls.setdefault(tool_call.index, []).append(tool_call) - - yield {"chunk_type": "content_stop", "data_type": "text"} - - for tool_deltas in tool_calls.values(): - tool_start, tool_deltas = tool_deltas[0], tool_deltas[1:] - yield {"chunk_type": "content_start", "data_type": "tool", "data": tool_start} - - for tool_delta in tool_deltas: - yield {"chunk_type": "content_delta", "data_type": "tool", "data": tool_delta} - - yield {"chunk_type": "content_stop", "data_type": "tool"} - - yield {"chunk_type": "message_stop", "data": choice.finish_reason} - - event = next(response) - if hasattr(event, "usage"): - yield {"chunk_type": "metadata", "data": event.usage} + # The LiteLLM `Client` inits with Chat(). 
+ # Chat() inits with self.completions + # completions() has a method `create()` which wraps the real completion API of Litellm + response = self.client.chat.completions.create( + model=self.get_config()["model_id"], + messages=super().format_request(prompt)["messages"], + response_format=output_model, + ) + + if not supports_response_schema(self.get_config()["model_id"]): + raise ValueError("Model does not support response_format") + if len(response.choices) > 1: + raise ValueError("Multiple choices found in the response.") + + # Find the first choice with tool_calls + for choice in response.choices: + if choice.finish_reason == "tool_calls": + try: + # Parse the tool call content as JSON + tool_call_data = json.loads(choice.message.content) + # Instantiate the output model with the parsed data + return output_model(**tool_call_data) + except (json.JSONDecodeError, TypeError, ValueError) as e: + raise ValueError(f"Failed to parse or load content into model: {e}") from e + + # If no tool_calls found, raise an error + raise ValueError("No tool_calls found in response") diff --git a/src/strands/models/llamaapi.py b/src/strands/models/llamaapi.py index 307614db..755e07ad 100644 --- a/src/strands/models/llamaapi.py +++ b/src/strands/models/llamaapi.py @@ -1,3 +1,4 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates """Llama API model provider. - Docs: https://llama.developer.meta.com/ @@ -7,10 +8,11 @@ import json import logging import mimetypes -from typing import Any, Iterable, Optional +from typing import Any, Callable, Iterable, Optional, Type, TypeVar, cast import llama_api_client from llama_api_client import LlamaAPIClient +from pydantic import BaseModel from typing_extensions import TypedDict, Unpack, override from ..types.content import ContentBlock, Messages @@ -21,6 +23,8 @@ logger = logging.getLogger(__name__) +T = TypeVar("T", bound=BaseModel) + class LlamaAPIModel(Model): """Llama API model provider implementation.""" @@ -92,6 +96,9 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An Returns: LllamaAPI formatted content block. + + Raises: + TypeError: If the content block type cannot be converted to a LlamaAPI-compatible format. """ if "image" in content: mime_type = mimetypes.types_map.get(f".{content['image']['format']}", "application/octet-stream") @@ -107,7 +114,7 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An if "text" in content: return {"text": content["text"], "type": "text"} - return {"text": json.dumps(content), "type": "text"} + raise TypeError(f"content_type=<{next(iter(content))}> | unsupported type") def _format_request_message_tool_call(self, tool_use: ToolUse) -> dict[str, Any]: """Format a Llama API tool call. @@ -135,18 +142,30 @@ def _format_request_tool_message(self, tool_result: ToolResult) -> dict[str, Any Returns: Llama API formatted tool message. """ + contents = cast( + list[ContentBlock], + [ + {"text": json.dumps(content["json"])} if "json" in content else content + for content in tool_result["content"] + ], + ) + return { "role": "tool", "tool_call_id": tool_result["toolUseId"], - "content": json.dumps( - { - "content": tool_result["content"], - "status": tool_result["status"], - } - ), + "content": [self._format_request_message_content(content) for content in contents], } def _format_request_messages(self, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]: + """Format a LlamaAPI compatible messages array. 
+ + Args: + messages: List of message objects to be processed by the model. + system_prompt: System prompt to provide context to the model. + + Returns: + An LlamaAPI compatible messages array. + """ formatted_messages: list[dict[str, Any]] formatted_messages = [{"role": "system", "content": system_prompt}] if system_prompt else [] @@ -196,6 +215,10 @@ def format_request( Returns: An Llama API chat streaming request. + + Raises: + TypeError: If a message contains a content block type that cannot be converted to a LlamaAPI-compatible + format. """ request = { "messages": self._format_request_messages(messages, system_prompt), @@ -364,3 +387,31 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]: # we may have a metrics event here if metrics_event: yield {"chunk_type": "metadata", "data": metrics_event} + + @override + def structured_output( + self, output_model: Type[T], prompt: Messages, callback_handler: Optional[Callable] = None + ) -> T: + """Get structured output from the model. + + Args: + output_model(Type[BaseModel]): The output model to use for the agent. + prompt(Messages): The prompt messages to use for the agent. + callback_handler(Optional[Callable]): Optional callback handler for processing events. Defaults to None. + + Raises: + NotImplementedError: Structured output is not currently supported for LlamaAPI models. + """ + # response_format: ResponseFormat = { + # "type": "json_schema", + # "json_schema": { + # "name": output_model.__name__, + # "schema": output_model.model_json_schema(), + # }, + # } + # response = self.client.chat.completions.create( + # model=self.config["model_id"], + # messages=self.format_request(prompt)["messages"], + # response_format=response_format, + # ) + raise NotImplementedError("Strands sdk-python does not implement this in the Llama API Preview.") diff --git a/src/strands/models/ollama.py b/src/strands/models/ollama.py index f632da34..b062fe14 100644 --- a/src/strands/models/ollama.py +++ b/src/strands/models/ollama.py @@ -5,19 +5,21 @@ import json import logging -from typing import Any, Iterable, Optional, Union +from typing import Any, Callable, Iterable, Optional, Type, TypeVar, cast from ollama import Client as OllamaClient +from pydantic import BaseModel from typing_extensions import TypedDict, Unpack, override -from ..types.content import ContentBlock, Message, Messages -from ..types.media import DocumentContent, ImageContent +from ..types.content import ContentBlock, Messages from ..types.models import Model from ..types.streaming import StopReason, StreamEvent from ..types.tools import ToolSpec logger = logging.getLogger(__name__) +T = TypeVar("T", bound=BaseModel) + class OllamaModel(Model): """Ollama model provider implementation. @@ -92,31 +94,31 @@ def get_config(self) -> OllamaConfig: """ return self.config - @override - def format_request( - self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None - ) -> dict[str, Any]: - """Format an Ollama chat streaming request. + def _format_request_message_contents(self, role: str, content: ContentBlock) -> list[dict[str, Any]]: + """Format Ollama compatible message contents. + + Ollama doesn't support an array of contents, so we must flatten everything into separate message blocks. Args: - messages: List of message objects to be processed by the model. - tool_specs: List of tool specifications to make available to the model. - system_prompt: System prompt to provide context to the model. + role: E.g., user. 
+ content: Content block to format. Returns: - An Ollama chat streaming request. - """ + Ollama formatted message contents. - def format_message(message: Message, content: ContentBlock) -> dict[str, Any]: - if "text" in content: - return {"role": message["role"], "content": content["text"]} + Raises: + TypeError: If the content block type cannot be converted to an Ollama-compatible format. + """ + if "text" in content: + return [{"role": role, "content": content["text"]}] - if "image" in content: - return {"role": message["role"], "images": [content["image"]["source"]["bytes"]]} + if "image" in content: + return [{"role": role, "images": [content["image"]["source"]["bytes"]]}] - if "toolUse" in content: - return { - "role": "assistant", + if "toolUse" in content: + return [ + { + "role": role, "tool_calls": [ { "function": { @@ -126,45 +128,63 @@ def format_message(message: Message, content: ContentBlock) -> dict[str, Any]: } ], } + ] + + if "toolResult" in content: + return [ + formatted_tool_result_content + for tool_result_content in content["toolResult"]["content"] + for formatted_tool_result_content in self._format_request_message_contents( + "tool", + ( + {"text": json.dumps(tool_result_content["json"])} + if "json" in tool_result_content + else cast(ContentBlock, tool_result_content) + ), + ) + ] - if "toolResult" in content: - result_content: Union[str, ImageContent, DocumentContent, Any] = None - result_images = [] - for tool_result_content in content["toolResult"]["content"]: - if "text" in tool_result_content: - result_content = tool_result_content["text"] - elif "json" in tool_result_content: - result_content = tool_result_content["json"] - elif "image" in tool_result_content: - result_content = "see images" - result_images.append(tool_result_content["image"]["source"]["bytes"]) - else: - result_content = content["toolResult"]["content"] + raise TypeError(f"content_type=<{next(iter(content))}> | unsupported type") - return { - "role": "tool", - "content": json.dumps( - { - "name": content["toolResult"]["toolUseId"], - "result": result_content, - "status": content["toolResult"]["status"], - } - ), - **({"images": result_images} if result_images else {}), - } + def _format_request_messages(self, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]: + """Format an Ollama compatible messages array. - return {"role": message["role"], "content": json.dumps(content)} + Args: + messages: List of message objects to be processed by the model. + system_prompt: System prompt to provide context to the model. - def format_messages() -> list[dict[str, Any]]: - return [format_message(message, content) for message in messages for content in message["content"]] + Returns: + An Ollama compatible messages array. + """ + system_message = [{"role": "system", "content": system_prompt}] if system_prompt else [] - formatted_messages = format_messages() + return system_message + [ + formatted_message + for message in messages + for content in message["content"] + for formatted_message in self._format_request_message_contents(message["role"], content) + ] + @override + def format_request( + self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None + ) -> dict[str, Any]: + """Format an Ollama chat streaming request. + + Args: + messages: List of message objects to be processed by the model. + tool_specs: List of tool specifications to make available to the model. + system_prompt: System prompt to provide context to the model. 
+ + Returns: + An Ollama chat streaming request. + + Raises: + TypeError: If a message contains a content block type that cannot be converted to an Ollama-compatible + format. + """ return { - "messages": [ - *([{"role": "system", "content": system_prompt}] if system_prompt else []), - *formatted_messages, - ], + "messages": self._format_request_messages(messages, system_prompt), "model": self.config["model_id"], "options": { **(self.config.get("options") or {}), @@ -213,52 +233,54 @@ def format_chunk(self, event: dict[str, Any]) -> StreamEvent: RuntimeError: If chunk_type is not recognized. This error should never be encountered as we control chunk_type in the stream method. """ - if event["chunk_type"] == "message_start": - return {"messageStart": {"role": "assistant"}} - - if event["chunk_type"] == "content_start": - if event["data_type"] == "text": - return {"contentBlockStart": {"start": {}}} - - tool_name = event["data"].function.name - return {"contentBlockStart": {"start": {"toolUse": {"name": tool_name, "toolUseId": tool_name}}}} - - if event["chunk_type"] == "content_delta": - if event["data_type"] == "text": - return {"contentBlockDelta": {"delta": {"text": event["data"]}}} - - tool_arguments = event["data"].function.arguments - return {"contentBlockDelta": {"delta": {"toolUse": {"input": json.dumps(tool_arguments)}}}} - - if event["chunk_type"] == "content_stop": - return {"contentBlockStop": {}} - - if event["chunk_type"] == "message_stop": - reason: StopReason - if event["data"] == "tool_use": - reason = "tool_use" - elif event["data"] == "length": - reason = "max_tokens" - else: - reason = "end_turn" - - return {"messageStop": {"stopReason": reason}} - - if event["chunk_type"] == "metadata": - return { - "metadata": { - "usage": { - "inputTokens": event["data"].eval_count, - "outputTokens": event["data"].prompt_eval_count, - "totalTokens": event["data"].eval_count + event["data"].prompt_eval_count, - }, - "metrics": { - "latencyMs": event["data"].total_duration / 1e6, + match event["chunk_type"]: + case "message_start": + return {"messageStart": {"role": "assistant"}} + + case "content_start": + if event["data_type"] == "text": + return {"contentBlockStart": {"start": {}}} + + tool_name = event["data"].function.name + return {"contentBlockStart": {"start": {"toolUse": {"name": tool_name, "toolUseId": tool_name}}}} + + case "content_delta": + if event["data_type"] == "text": + return {"contentBlockDelta": {"delta": {"text": event["data"]}}} + + tool_arguments = event["data"].function.arguments + return {"contentBlockDelta": {"delta": {"toolUse": {"input": json.dumps(tool_arguments)}}}} + + case "content_stop": + return {"contentBlockStop": {}} + + case "message_stop": + reason: StopReason + if event["data"] == "tool_use": + reason = "tool_use" + elif event["data"] == "length": + reason = "max_tokens" + else: + reason = "end_turn" + + return {"messageStop": {"stopReason": reason}} + + case "metadata": + return { + "metadata": { + "usage": { + "inputTokens": event["data"].eval_count, + "outputTokens": event["data"].prompt_eval_count, + "totalTokens": event["data"].eval_count + event["data"].prompt_eval_count, + }, + "metrics": { + "latencyMs": event["data"].total_duration / 1e6, + }, }, - }, - } + } - raise RuntimeError(f"chunk_type=<{event['chunk_type']} | unknown type") + case _: + raise RuntimeError(f"chunk_type=<{event['chunk_type']} | unknown type") @override def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]: @@ -291,3 +313,25 @@ def stream(self, 
request: dict[str, Any]) -> Iterable[dict[str, Any]]: yield {"chunk_type": "content_stop", "data_type": "text"} yield {"chunk_type": "message_stop", "data": "tool_use" if tool_requested else event.done_reason} yield {"chunk_type": "metadata", "data": event} + + @override + def structured_output( + self, output_model: Type[T], prompt: Messages, callback_handler: Optional[Callable] = None + ) -> T: + """Get structured output from the model. + + Args: + output_model(Type[BaseModel]): The output model to use for the agent. + prompt(Messages): The prompt messages to use for the agent. + callback_handler(Optional[Callable]): Optional callback handler for processing events. Defaults to None. + """ + formatted_request = self.format_request(messages=prompt) + formatted_request["format"] = output_model.model_json_schema() + formatted_request["stream"] = False + response = self.client.chat(**formatted_request) + + try: + content = response.message.content.strip() + return output_model.model_validate_json(content) + except Exception as e: + raise ValueError(f"Failed to parse or load content into model: {e}") from e diff --git a/src/strands/models/openai.py b/src/strands/models/openai.py new file mode 100644 index 00000000..783ce379 --- /dev/null +++ b/src/strands/models/openai.py @@ -0,0 +1,164 @@ +"""OpenAI model provider. + +- Docs: https://platform.openai.com/docs/overview +""" + +import logging +from typing import Any, Callable, Iterable, Optional, Protocol, Type, TypedDict, TypeVar, cast + +import openai +from openai.types.chat.parsed_chat_completion import ParsedChatCompletion +from pydantic import BaseModel +from typing_extensions import Unpack, override + +from ..types.content import Messages +from ..types.models import OpenAIModel as SAOpenAIModel + +logger = logging.getLogger(__name__) + +T = TypeVar("T", bound=BaseModel) + + +class Client(Protocol): + """Protocol defining the OpenAI-compatible interface for the underlying provider client.""" + + @property + # pragma: no cover + def chat(self) -> Any: + """Chat completions interface.""" + ... + + +class OpenAIModel(SAOpenAIModel): + """OpenAI model provider implementation.""" + + client: Client + + class OpenAIConfig(TypedDict, total=False): + """Configuration options for OpenAI models. + + Attributes: + model_id: Model ID (e.g., "gpt-4o"). + For a complete list of supported models, see https://platform.openai.com/docs/models. + params: Model parameters (e.g., max_tokens). + For a complete list of supported parameters, see + https://platform.openai.com/docs/api-reference/chat/create. + """ + + model_id: str + params: Optional[dict[str, Any]] + + def __init__(self, client_args: Optional[dict[str, Any]] = None, **model_config: Unpack[OpenAIConfig]) -> None: + """Initialize provider instance. + + Args: + client_args: Arguments for the OpenAI client. + For a complete list of supported arguments, see https://pypi.org/project/openai/. + **model_config: Configuration options for the OpenAI model. + """ + self.config = dict(model_config) + + logger.debug("config=<%s> | initializing", self.config) + + client_args = client_args or {} + self.client = openai.OpenAI(**client_args) + + @override + def update_config(self, **model_config: Unpack[OpenAIConfig]) -> None: # type: ignore[override] + """Update the OpenAI model configuration with the provided arguments. + + Args: + **model_config: Configuration overrides. + """ + self.config.update(model_config) + + @override + def get_config(self) -> OpenAIConfig: + """Get the OpenAI model configuration. 
+ + Returns: + The OpenAI model configuration. + """ + return cast(OpenAIModel.OpenAIConfig, self.config) + + @override + def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]: + """Send the request to the OpenAI model and get the streaming response. + + Args: + request: The formatted request to send to the OpenAI model. + + Returns: + An iterable of response events from the OpenAI model. + """ + response = self.client.chat.completions.create(**request) + + yield {"chunk_type": "message_start"} + yield {"chunk_type": "content_start", "data_type": "text"} + + tool_calls: dict[int, list[Any]] = {} + + for event in response: + # Defensive: skip events with empty or missing choices + if not getattr(event, "choices", None): + continue + choice = event.choices[0] + + if choice.delta.content: + yield {"chunk_type": "content_delta", "data_type": "text", "data": choice.delta.content} + + for tool_call in choice.delta.tool_calls or []: + tool_calls.setdefault(tool_call.index, []).append(tool_call) + + if choice.finish_reason: + break + + yield {"chunk_type": "content_stop", "data_type": "text"} + + for tool_deltas in tool_calls.values(): + yield {"chunk_type": "content_start", "data_type": "tool", "data": tool_deltas[0]} + + for tool_delta in tool_deltas: + yield {"chunk_type": "content_delta", "data_type": "tool", "data": tool_delta} + + yield {"chunk_type": "content_stop", "data_type": "tool"} + + yield {"chunk_type": "message_stop", "data": choice.finish_reason} + + # Skip remaining events as we don't have use for anything except the final usage payload + for event in response: + _ = event + + yield {"chunk_type": "metadata", "data": event.usage} + + @override + def structured_output( + self, output_model: Type[T], prompt: Messages, callback_handler: Optional[Callable] = None + ) -> T: + """Get structured output from the model. + + Args: + output_model(Type[BaseModel]): The output model to use for the agent. + prompt(Messages): The prompt messages to use for the agent. + callback_handler(Optional[Callable]): Optional callback handler for processing events. Defaults to None. + """ + response: ParsedChatCompletion = self.client.beta.chat.completions.parse( # type: ignore + model=self.get_config()["model_id"], + messages=super().format_request(prompt)["messages"], + response_format=output_model, + ) + + parsed: T | None = None + # Find the first choice with tool_calls + if len(response.choices) > 1: + raise ValueError("Multiple choices found in the OpenAI response.") + + for choice in response.choices: + if isinstance(choice.message.parsed, output_model): + parsed = choice.message.parsed + break + + if parsed: + return parsed + else: + raise ValueError("No valid tool use or tool use input was found in the OpenAI response.") diff --git a/src/strands/multiagent/__init__.py b/src/strands/multiagent/__init__.py new file mode 100644 index 00000000..1cef1425 --- /dev/null +++ b/src/strands/multiagent/__init__.py @@ -0,0 +1,13 @@ +"""Multiagent capabilities for Strands Agents. + +This module provides support for multiagent systems, including agent-to-agent (A2A) +communication protocols and coordination mechanisms. + +Submodules: + a2a: Implementation of the Agent-to-Agent (A2A) protocol, which enables + standardized communication between agents. +""" + +from . 
import a2a + +__all__ = ["a2a"] diff --git a/src/strands/multiagent/a2a/__init__.py b/src/strands/multiagent/a2a/__init__.py new file mode 100644 index 00000000..c5425618 --- /dev/null +++ b/src/strands/multiagent/a2a/__init__.py @@ -0,0 +1,14 @@ +"""Agent-to-Agent (A2A) communication protocol implementation for Strands Agents. + +This module provides classes and utilities for enabling Strands Agents to communicate +with other agents using the Agent-to-Agent (A2A) protocol. + +Docs: https://google-a2a.github.io/A2A/latest/ + +Classes: + A2AAgent: A wrapper that adapts a Strands Agent to be A2A-compatible. +""" + +from .agent import A2AAgent + +__all__ = ["A2AAgent"] diff --git a/src/strands/multiagent/a2a/agent.py b/src/strands/multiagent/a2a/agent.py new file mode 100644 index 00000000..4359100d --- /dev/null +++ b/src/strands/multiagent/a2a/agent.py @@ -0,0 +1,149 @@ +"""A2A-compatible wrapper for Strands Agent. + +This module provides the A2AAgent class, which adapts a Strands Agent to the A2A protocol, +allowing it to be used in A2A-compatible systems. +""" + +import logging +from typing import Any, Literal + +import uvicorn +from a2a.server.apps import A2AFastAPIApplication, A2AStarletteApplication +from a2a.server.request_handlers import DefaultRequestHandler +from a2a.server.tasks import InMemoryTaskStore +from a2a.types import AgentCapabilities, AgentCard, AgentSkill +from fastapi import FastAPI +from starlette.applications import Starlette + +from ...agent.agent import Agent as SAAgent +from .executor import StrandsA2AExecutor + +logger = logging.getLogger(__name__) + + +class A2AAgent: + """A2A-compatible wrapper for Strands Agent.""" + + def __init__( + self, + agent: SAAgent, + *, + # AgentCard + host: str = "0.0.0.0", + port: int = 9000, + version: str = "0.0.1", + ): + """Initialize an A2A-compatible agent from a Strands agent. + + Args: + agent: The Strands Agent to wrap with A2A compatibility. + name: The name of the agent, used in the AgentCard. + description: A description of the agent's capabilities, used in the AgentCard. + host: The hostname or IP address to bind the A2A server to. Defaults to "0.0.0.0". + port: The port to bind the A2A server to. Defaults to 9000. + version: The version of the agent. Defaults to "0.0.1". + """ + self.host = host + self.port = port + self.http_url = f"http://{self.host}:{self.port}/" + self.version = version + self.strands_agent = agent + self.name = self.strands_agent.name + self.description = self.strands_agent.description + # TODO: enable configurable capabilities and request handler + self.capabilities = AgentCapabilities() + self.request_handler = DefaultRequestHandler( + agent_executor=StrandsA2AExecutor(self.strands_agent), + task_store=InMemoryTaskStore(), + ) + logger.info("Strands' integration with A2A is experimental. Be aware of frequent breaking changes.") + + @property + def public_agent_card(self) -> AgentCard: + """Get the public AgentCard for this agent. + + The AgentCard contains metadata about the agent, including its name, + description, URL, version, skills, and capabilities. This information + is used by other agents and systems to discover and interact with this agent. + + Returns: + AgentCard: The public agent card containing metadata about this agent. + + Raises: + ValueError: If name or description is None or empty. 
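+
+        Example (illustrative; ``my_strands_agent`` is assumed to be an existing
+        Strands ``Agent`` whose ``name`` and ``description`` are set):
+
+            from strands.multiagent.a2a import A2AAgent
+
+            a2a_agent = A2AAgent(agent=my_strands_agent)
+            card = a2a_agent.public_agent_card
+            print(card.name, card.url)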
+ """ + if not self.name: + raise ValueError("A2A agent name cannot be None or empty") + if not self.description: + raise ValueError("A2A agent description cannot be None or empty") + + return AgentCard( + name=self.name, + description=self.description, + url=self.http_url, + version=self.version, + skills=self.agent_skills, + defaultInputModes=["text"], + defaultOutputModes=["text"], + capabilities=self.capabilities, + ) + + @property + def agent_skills(self) -> list[AgentSkill]: + """Get the list of skills this agent provides. + + Skills represent specific capabilities that the agent can perform. + Strands agent tools are adapted to A2A skills. + + Returns: + list[AgentSkill]: A list of skills this agent provides. + """ + # TODO: translate Strands tools (native & MCP) to skills + return [] + + def to_starlette_app(self) -> Starlette: + """Create a Starlette application for serving this agent via HTTP. + + This method creates a Starlette application that can be used to serve + the agent via HTTP using the A2A protocol. + + Returns: + Starlette: A Starlette application configured to serve this agent. + """ + return A2AStarletteApplication(agent_card=self.public_agent_card, http_handler=self.request_handler).build() + + def to_fastapi_app(self) -> FastAPI: + """Create a FastAPI application for serving this agent via HTTP. + + This method creates a FastAPI application that can be used to serve + the agent via HTTP using the A2A protocol. + + Returns: + FastAPI: A FastAPI application configured to serve this agent. + """ + return A2AFastAPIApplication(agent_card=self.public_agent_card, http_handler=self.request_handler).build() + + def serve(self, app_type: Literal["fastapi", "starlette"] = "starlette", **kwargs: Any) -> None: + """Start the A2A server with the specified application type. + + This method starts an HTTP server that exposes the agent via the A2A protocol. + The server can be implemented using either FastAPI or Starlette, depending on + the specified app_type. + + Args: + app_type: The type of application to serve, either "fastapi" or "starlette". + Defaults to "starlette". + **kwargs: Additional keyword arguments to pass to uvicorn.run. + """ + try: + logger.info("Starting Strands A2A server...") + if app_type == "fastapi": + uvicorn.run(self.to_fastapi_app(), host=self.host, port=self.port, **kwargs) + else: + uvicorn.run(self.to_starlette_app(), host=self.host, port=self.port, **kwargs) + except KeyboardInterrupt: + logger.warning("Strands A2A server shutdown requested (KeyboardInterrupt).") + except Exception: + logger.exception("Strands A2A server encountered exception.") + finally: + logger.info("Strands A2A server has shutdown.") diff --git a/src/strands/multiagent/a2a/executor.py b/src/strands/multiagent/a2a/executor.py new file mode 100644 index 00000000..b7a7af09 --- /dev/null +++ b/src/strands/multiagent/a2a/executor.py @@ -0,0 +1,67 @@ +"""Strands Agent executor for the A2A protocol. + +This module provides the StrandsA2AExecutor class, which adapts a Strands Agent +to be used as an executor in the A2A protocol. It handles the execution of agent +requests and the conversion of Strands Agent responses to A2A events. 
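+
+Example (minimal sketch mirroring how ``A2AAgent`` wires this executor;
+``my_strands_agent`` is assumed to be an existing Strands ``Agent``):
+
+    from a2a.server.request_handlers import DefaultRequestHandler
+    from a2a.server.tasks import InMemoryTaskStore
+
+    handler = DefaultRequestHandler(
+        agent_executor=StrandsA2AExecutor(my_strands_agent),
+        task_store=InMemoryTaskStore(),
+    )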
+""" + +import logging + +from a2a.server.agent_execution import AgentExecutor, RequestContext +from a2a.server.events import EventQueue +from a2a.types import UnsupportedOperationError +from a2a.utils import new_agent_text_message +from a2a.utils.errors import ServerError + +from ...agent.agent import Agent as SAAgent +from ...agent.agent_result import AgentResult as SAAgentResult + +log = logging.getLogger(__name__) + + +class StrandsA2AExecutor(AgentExecutor): + """Executor that adapts a Strands Agent to the A2A protocol.""" + + def __init__(self, agent: SAAgent): + """Initialize a StrandsA2AExecutor. + + Args: + agent: The Strands Agent to adapt to the A2A protocol. + """ + self.agent = agent + + async def execute( + self, + context: RequestContext, + event_queue: EventQueue, + ) -> None: + """Execute a request using the Strands Agent and send the response as A2A events. + + This method executes the user's input using the Strands Agent and converts + the agent's response to A2A events, which are then sent to the event queue. + + Args: + context: The A2A request context, containing the user's input and other metadata. + event_queue: The A2A event queue, used to send response events. + """ + result: SAAgentResult = self.agent(context.get_user_input()) + if result.message and "content" in result.message: + for content_block in result.message["content"]: + if "text" in content_block: + await event_queue.enqueue_event(new_agent_text_message(content_block["text"])) + + async def cancel(self, context: RequestContext, event_queue: EventQueue) -> None: + """Cancel an ongoing execution. + + This method is called when a request is cancelled. Currently, cancellation + is not supported, so this method raises an UnsupportedOperationError. + + Args: + context: The A2A request context. + event_queue: The A2A event queue. + + Raises: + ServerError: Always raised with an UnsupportedOperationError, as cancellation + is not currently supported. + """ + raise ServerError(error=UnsupportedOperationError()) diff --git a/src/strands/telemetry/__init__.py b/src/strands/telemetry/__init__.py index 15981216..21dd6ebf 100644 --- a/src/strands/telemetry/__init__.py +++ b/src/strands/telemetry/__init__.py @@ -3,7 +3,8 @@ This module provides metrics and tracing functionality. """ -from .metrics import EventLoopMetrics, Trace, metrics_to_string +from .config import get_otel_resource +from .metrics import EventLoopMetrics, MetricsClient, Trace, metrics_to_string from .tracer import Tracer, get_tracer __all__ = [ @@ -12,4 +13,6 @@ "metrics_to_string", "Tracer", "get_tracer", + "MetricsClient", + "get_otel_resource", ] diff --git a/src/strands/telemetry/config.py b/src/strands/telemetry/config.py new file mode 100644 index 00000000..9f5a05fd --- /dev/null +++ b/src/strands/telemetry/config.py @@ -0,0 +1,33 @@ +"""OpenTelemetry configuration and setup utilities for Strands agents. + +This module provides centralized configuration and initialization functionality +for OpenTelemetry components and other telemetry infrastructure shared across Strands applications. +""" + +from importlib.metadata import version + +from opentelemetry.sdk.resources import Resource + + +def get_otel_resource() -> Resource: + """Create a standard OpenTelemetry resource with service information. + + This function implements a singleton pattern - it will return the same + Resource object for the same service_name parameter. + + Args: + service_name: Name of the service for OpenTelemetry. 
+ + Returns: + Resource object with standard service information. + """ + resource = Resource.create( + { + "service.name": __name__, + "service.version": version("strands-agents"), + "telemetry.sdk.name": "opentelemetry", + "telemetry.sdk.language": "python", + } + ) + + return resource diff --git a/src/strands/telemetry/metrics.py b/src/strands/telemetry/metrics.py index cd70819b..332ab2ae 100644 --- a/src/strands/telemetry/metrics.py +++ b/src/strands/telemetry/metrics.py @@ -6,6 +6,10 @@ from dataclasses import dataclass, field from typing import Any, Dict, Iterable, List, Optional, Set, Tuple +import opentelemetry.metrics as metrics_api +from opentelemetry.metrics import Counter, Histogram, Meter + +from ..telemetry import metrics_constants as constants from ..types.content import Message from ..types.streaming import Metrics, Usage from ..types.tools import ToolUse @@ -117,22 +121,34 @@ class ToolMetrics: error_count: int = 0 total_time: float = 0.0 - def add_call(self, tool: ToolUse, duration: float, success: bool) -> None: + def add_call( + self, + tool: ToolUse, + duration: float, + success: bool, + metrics_client: "MetricsClient", + attributes: Optional[Dict[str, Any]] = None, + ) -> None: """Record a new tool call with its outcome. Args: tool: The tool that was called. duration: How long the call took in seconds. success: Whether the call was successful. + metrics_client: The metrics client for recording the metrics. + attributes: attributes of the metrics. """ self.tool = tool # Update with latest tool state self.call_count += 1 self.total_time += duration - + metrics_client.tool_call_count.add(1, attributes=attributes) + metrics_client.tool_duration.record(duration, attributes=attributes) if success: self.success_count += 1 + metrics_client.tool_success_count.add(1, attributes=attributes) else: self.error_count += 1 + metrics_client.tool_error_count.add(1, attributes=attributes) @dataclass @@ -155,32 +171,53 @@ class EventLoopMetrics: accumulated_usage: Usage = field(default_factory=lambda: Usage(inputTokens=0, outputTokens=0, totalTokens=0)) accumulated_metrics: Metrics = field(default_factory=lambda: Metrics(latencyMs=0)) - def start_cycle(self) -> Tuple[float, Trace]: + @property + def _metrics_client(self) -> "MetricsClient": + """Get the singleton MetricsClient instance.""" + return MetricsClient() + + def start_cycle( + self, + attributes: Optional[Dict[str, Any]] = None, + ) -> Tuple[float, Trace]: """Start a new event loop cycle and create a trace for it. + Args: + attributes: attributes of the metrics. + Returns: A tuple containing the start time and the cycle trace object. """ + self._metrics_client.event_loop_cycle_count.add(1, attributes=attributes) + self._metrics_client.event_loop_start_cycle.add(1, attributes=attributes) self.cycle_count += 1 start_time = time.time() cycle_trace = Trace(f"Cycle {self.cycle_count}", start_time=start_time) self.traces.append(cycle_trace) return start_time, cycle_trace - def end_cycle(self, start_time: float, cycle_trace: Trace) -> None: + def end_cycle(self, start_time: float, cycle_trace: Trace, attributes: Optional[Dict[str, Any]] = None) -> None: """End the current event loop cycle and record its duration. Args: start_time: The timestamp when the cycle started. cycle_trace: The trace object for this cycle. + attributes: attributes of the metrics. 
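+
+        Example (illustrative pairing with ``start_cycle``; ``metrics`` is assumed
+        to be an ``EventLoopMetrics`` instance and the attributes are placeholders):
+
+            start_time, cycle_trace = metrics.start_cycle(attributes={"agent": "demo"})
+            # ... run the event loop cycle ...
+            metrics.end_cycle(start_time, cycle_trace, attributes={"agent": "demo"})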
""" + self._metrics_client.event_loop_end_cycle.add(1, attributes) end_time = time.time() duration = end_time - start_time + self._metrics_client.event_loop_cycle_duration.record(duration, attributes) self.cycle_durations.append(duration) cycle_trace.end(end_time) def add_tool_usage( - self, tool: ToolUse, duration: float, tool_trace: Trace, success: bool, message: Message + self, + tool: ToolUse, + duration: float, + tool_trace: Trace, + success: bool, + message: Message, ) -> None: """Record metrics for a tool invocation. @@ -203,8 +240,16 @@ def add_tool_usage( tool_trace.raw_name = f"{tool_name} - {tool_use_id}" tool_trace.add_message(message) - self.tool_metrics.setdefault(tool_name, ToolMetrics(tool)).add_call(tool, duration, success) - + self.tool_metrics.setdefault(tool_name, ToolMetrics(tool)).add_call( + tool, + duration, + success, + self._metrics_client, + attributes={ + "tool_name": tool_name, + "tool_use_id": tool_use_id, + }, + ) tool_trace.end() def update_usage(self, usage: Usage) -> None: @@ -213,6 +258,8 @@ def update_usage(self, usage: Usage) -> None: Args: usage: The usage data to add to the accumulated totals. """ + self._metrics_client.event_loop_input_tokens.record(usage["inputTokens"]) + self._metrics_client.event_loop_output_tokens.record(usage["outputTokens"]) self.accumulated_usage["inputTokens"] += usage["inputTokens"] self.accumulated_usage["outputTokens"] += usage["outputTokens"] self.accumulated_usage["totalTokens"] += usage["totalTokens"] @@ -223,6 +270,7 @@ def update_metrics(self, metrics: Metrics) -> None: Args: metrics: The metrics data to add to the accumulated totals. """ + self._metrics_client.event_loop_latency.record(metrics["latencyMs"]) self.accumulated_metrics["latencyMs"] += metrics["latencyMs"] def get_summary(self) -> Dict[str, Any]: @@ -355,3 +403,74 @@ def metrics_to_string(event_loop_metrics: EventLoopMetrics, allowed_names: Optio A formatted string representation of the metrics. """ return "\n".join(_metrics_summary_to_lines(event_loop_metrics, allowed_names or set())) + + +class MetricsClient: + """Singleton client for managing OpenTelemetry metrics instruments. + + The actual metrics export destination (console, OTLP endpoint, etc.) is configured + through OpenTelemetry SDK configuration by users, not by this client. + """ + + _instance: Optional["MetricsClient"] = None + meter: Meter + event_loop_cycle_count: Counter + event_loop_start_cycle: Counter + event_loop_end_cycle: Counter + event_loop_cycle_duration: Histogram + event_loop_latency: Histogram + event_loop_input_tokens: Histogram + event_loop_output_tokens: Histogram + + tool_call_count: Counter + tool_success_count: Counter + tool_error_count: Counter + tool_duration: Histogram + + def __new__(cls) -> "MetricsClient": + """Create or return the singleton instance of MetricsClient. + + Returns: + The single MetricsClient instance. + """ + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + def __init__(self) -> None: + """Initialize the MetricsClient. + + This method only runs once due to the singleton pattern. + Sets up the OpenTelemetry meter and creates metric instruments. 
+ """ + if hasattr(self, "meter"): + return + + logger.info("Creating Strands MetricsClient") + meter_provider: metrics_api.MeterProvider = metrics_api.get_meter_provider() + self.meter = meter_provider.get_meter(__name__) + self.create_instruments() + + def create_instruments(self) -> None: + """Create and initialize all OpenTelemetry metric instruments.""" + self.event_loop_cycle_count = self.meter.create_counter( + name=constants.STRANDS_EVENT_LOOP_CYCLE_COUNT, unit="Count" + ) + self.event_loop_start_cycle = self.meter.create_counter( + name=constants.STRANDS_EVENT_LOOP_START_CYCLE, unit="Count" + ) + self.event_loop_end_cycle = self.meter.create_counter(name=constants.STRANDS_EVENT_LOOP_END_CYCLE, unit="Count") + self.event_loop_cycle_duration = self.meter.create_histogram( + name=constants.STRANDS_EVENT_LOOP_CYCLE_DURATION, unit="s" + ) + self.event_loop_latency = self.meter.create_histogram(name=constants.STRANDS_EVENT_LOOP_LATENCY, unit="ms") + self.tool_call_count = self.meter.create_counter(name=constants.STRANDS_TOOL_CALL_COUNT, unit="Count") + self.tool_success_count = self.meter.create_counter(name=constants.STRANDS_TOOL_SUCCESS_COUNT, unit="Count") + self.tool_error_count = self.meter.create_counter(name=constants.STRANDS_TOOL_ERROR_COUNT, unit="Count") + self.tool_duration = self.meter.create_histogram(name=constants.STRANDS_TOOL_DURATION, unit="s") + self.event_loop_input_tokens = self.meter.create_histogram( + name=constants.STRANDS_EVENT_LOOP_INPUT_TOKENS, unit="token" + ) + self.event_loop_output_tokens = self.meter.create_histogram( + name=constants.STRANDS_EVENT_LOOP_OUTPUT_TOKENS, unit="token" + ) diff --git a/src/strands/telemetry/metrics_constants.py b/src/strands/telemetry/metrics_constants.py new file mode 100644 index 00000000..b622eebf --- /dev/null +++ b/src/strands/telemetry/metrics_constants.py @@ -0,0 +1,15 @@ +"""Metrics that are emitted in Strands-Agents.""" + +STRANDS_EVENT_LOOP_CYCLE_COUNT = "strands.event_loop.cycle_count" +STRANDS_EVENT_LOOP_START_CYCLE = "strands.event_loop.start_cycle" +STRANDS_EVENT_LOOP_END_CYCLE = "strands.event_loop.end_cycle" +STRANDS_TOOL_CALL_COUNT = "strands.tool.call_count" +STRANDS_TOOL_SUCCESS_COUNT = "strands.tool.success_count" +STRANDS_TOOL_ERROR_COUNT = "strands.tool.error_count" + +# Histograms +STRANDS_EVENT_LOOP_LATENCY = "strands.event_loop.latency" +STRANDS_TOOL_DURATION = "strands.tool.duration" +STRANDS_EVENT_LOOP_CYCLE_DURATION = "strands.event_loop.cycle_duration" +STRANDS_EVENT_LOOP_INPUT_TOKENS = "strands.event_loop.input.tokens" +STRANDS_EVENT_LOOP_OUTPUT_TOKENS = "strands.event_loop.output.tokens" diff --git a/src/strands/telemetry/tracer.py b/src/strands/telemetry/tracer.py index ad30a445..813c90e1 100644 --- a/src/strands/telemetry/tracer.py +++ b/src/strands/telemetry/tracer.py @@ -7,18 +7,20 @@ import json import logging import os -from datetime import datetime, timezone -from importlib.metadata import version +from datetime import date, datetime, timezone from typing import Any, Dict, Mapping, Optional -from opentelemetry import trace -from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter -from opentelemetry.sdk.resources import Resource -from opentelemetry.sdk.trace import TracerProvider +import opentelemetry.trace as trace_api +from opentelemetry import propagate +from opentelemetry.baggage.propagation import W3CBaggagePropagator +from opentelemetry.propagators.composite import CompositePropagator +from opentelemetry.sdk.trace import TracerProvider as SDKTracerProvider 
from opentelemetry.sdk.trace.export import BatchSpanProcessor, ConsoleSpanExporter, SimpleSpanProcessor -from opentelemetry.trace import StatusCode +from opentelemetry.trace import Span, StatusCode +from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator from ..agent.agent_result import AgentResult +from ..telemetry import get_otel_resource from ..types.content import Message, Messages from ..types.streaming import Usage from ..types.tools import ToolResult, ToolUse @@ -26,25 +28,66 @@ logger = logging.getLogger(__name__) +HAS_OTEL_EXPORTER_MODULE = False +OTEL_EXPORTER_MODULE_ERROR = ( + "opentelemetry-exporter-otlp-proto-http not detected;" + "please install strands-agents with the optional 'otel' target" + "otel http exporting is currently DISABLED" +) +try: + from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter + + HAS_OTEL_EXPORTER_MODULE = True +except ImportError: + pass + class JSONEncoder(json.JSONEncoder): """Custom JSON encoder that handles non-serializable types.""" - def default(self, obj: Any) -> Any: - """Handle non-serializable types. + def encode(self, obj: Any) -> str: + """Recursively encode objects, preserving structure and only replacing unserializable values. Args: - obj: The object to serialize + obj: The object to encode Returns: - A JSON serializable version of the object + JSON string representation of the object """ - value = "" - try: - value = super().default(obj) - except TypeError: - value = "" - return value + # Process the object to handle non-serializable values + processed_obj = self._process_value(obj) + # Use the parent class to encode the processed object + return super().encode(processed_obj) + + def _process_value(self, value: Any) -> Any: + """Process any value, handling containers recursively. + + Args: + value: The value to process + + Returns: + Processed value with unserializable parts replaced + """ + # Handle datetime objects directly + if isinstance(value, (datetime, date)): + return value.isoformat() + + # Handle dictionaries + elif isinstance(value, dict): + return {k: self._process_value(v) for k, v in value.items()} + + # Handle lists + elif isinstance(value, list): + return [self._process_value(item) for item in value] + + # Handle all other values + else: + try: + # Test if the value is JSON serializable + json.dumps(value) + return value + except (TypeError, OverflowError, ValueError): + return "" class Tracer: @@ -65,7 +108,7 @@ def __init__( service_name: str = "strands-agents", otlp_endpoint: Optional[str] = None, otlp_headers: Optional[Dict[str, str]] = None, - enable_console_export: bool = False, + enable_console_export: Optional[bool] = None, ): """Initialize the tracer. 
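+
+        Example (illustrative; constructor arguments now take precedence over the
+        ``OTEL_EXPORTER_OTLP_ENDPOINT`` and ``STRANDS_OTEL_ENABLE_CONSOLE_EXPORT``
+        environment variables, and the endpoint below is a placeholder):
+
+            tracer = Tracer(otlp_endpoint="http://localhost:4318", enable_console_export=True)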
@@ -77,13 +120,17 @@ def __init__( """ # Check environment variables first env_endpoint = os.environ.get("OTEL_EXPORTER_OTLP_ENDPOINT") - env_console_export = os.environ.get("STRANDS_OTEL_ENABLE_CONSOLE_EXPORT", "").lower() in ("true", "1", "yes") + env_console_export_str = os.environ.get("STRANDS_OTEL_ENABLE_CONSOLE_EXPORT") + + # Constructor parameters take precedence over environment variables + self.otlp_endpoint = otlp_endpoint or env_endpoint - # Environment variables take precedence over constructor parameters - if env_endpoint: - otlp_endpoint = env_endpoint - if env_console_export: - enable_console_export = True + if enable_console_export is not None: + self.enable_console_export = enable_console_export + elif env_console_export_str: + self.enable_console_export = env_console_export_str.lower() in ("true", "1", "yes") + else: + self.enable_console_export = False # Parse headers from environment if available env_headers = os.environ.get("OTEL_EXPORTER_OTLP_HEADERS") @@ -97,35 +144,37 @@ def __init__( headers_dict[key.strip()] = value.strip() otlp_headers = headers_dict except Exception as e: - logger.warning(f"error=<{e}> | failed to parse OTEL_EXPORTER_OTLP_HEADERS") + logger.warning("error=<%s> | failed to parse OTEL_EXPORTER_OTLP_HEADERS", e) self.service_name = service_name - self.otlp_endpoint = otlp_endpoint self.otlp_headers = otlp_headers or {} - self.enable_console_export = enable_console_export - - self.tracer_provider: Optional[TracerProvider] = None - self.tracer: Optional[trace.Tracer] = None - - if otlp_endpoint or enable_console_export: + self.tracer_provider: Optional[trace_api.TracerProvider] = None + self.tracer: Optional[trace_api.Tracer] = None + propagate.set_global_textmap( + CompositePropagator( + [ + W3CBaggagePropagator(), + TraceContextTextMapPropagator(), + ] + ) + ) + if self.otlp_endpoint or self.enable_console_export: + # Create our own tracer provider self._initialize_tracer() def _initialize_tracer(self) -> None: """Initialize the OpenTelemetry tracer.""" logger.info("initializing tracer") - # Create resource with service information - resource = Resource.create( - { - "service.name": self.service_name, - "service.version": version("strands-agents"), - "telemetry.sdk.name": "opentelemetry", - "telemetry.sdk.language": "python", - } - ) + if self._is_initialized(): + self.tracer_provider = trace_api.get_tracer_provider() + self.tracer = self.tracer_provider.get_tracer(self.service_name) + return + + resource = get_otel_resource() # Create tracer provider - self.tracer_provider = TracerProvider(resource=resource) + self.tracer_provider = SDKTracerProvider(resource=resource) # Add console exporter if enabled if self.enable_console_export and self.tracer_provider: @@ -134,7 +183,7 @@ def _initialize_tracer(self) -> None: self.tracer_provider.add_span_processor(console_processor) # Add OTLP exporter if endpoint is provided - if self.otlp_endpoint and self.tracer_provider: + if HAS_OTEL_EXPORTER_MODULE and self.otlp_endpoint and self.tracer_provider: try: # Ensure endpoint has the right format endpoint = self.otlp_endpoint @@ -156,20 +205,27 @@ def _initialize_tracer(self) -> None: batch_processor = BatchSpanProcessor(otlp_exporter) self.tracer_provider.add_span_processor(batch_processor) - logger.info(f"endpoint=<{endpoint}> | OTLP exporter configured with endpoint") + logger.info("endpoint=<%s> | OTLP exporter configured with endpoint", endpoint) + except Exception as e: - logger.error(f"error=<{e}> | Failed to configure OTLP exporter", exc_info=True) + 
logger.exception("error=<%s> | Failed to configure OTLP exporter", e) + elif self.otlp_endpoint and self.tracer_provider: + raise ModuleNotFoundError(OTEL_EXPORTER_MODULE_ERROR) # Set as global tracer provider - trace.set_tracer_provider(self.tracer_provider) - self.tracer = trace.get_tracer(self.service_name) + trace_api.set_tracer_provider(self.tracer_provider) + self.tracer = trace_api.get_tracer(self.service_name) + + def _is_initialized(self) -> bool: + tracer_provider = trace_api.get_tracer_provider() + return isinstance(tracer_provider, SDKTracerProvider) def _start_span( self, span_name: str, - parent_span: Optional[trace.Span] = None, + parent_span: Optional[Span] = None, attributes: Optional[Dict[str, AttributeValue]] = None, - ) -> Optional[trace.Span]: + ) -> Optional[Span]: """Generic helper method to start a span with common attributes. Args: @@ -183,7 +239,7 @@ def _start_span( if self.tracer is None: return None - context = trace.set_span_in_context(parent_span) if parent_span else None + context = trace_api.set_span_in_context(parent_span) if parent_span else None span = self.tracer.start_span(name=span_name, context=context) # Set start time as a common attribute @@ -195,7 +251,7 @@ def _start_span( return span - def _set_attributes(self, span: trace.Span, attributes: Dict[str, AttributeValue]) -> None: + def _set_attributes(self, span: Span, attributes: Dict[str, AttributeValue]) -> None: """Set attributes on a span, handling different value types appropriately. Args: @@ -210,7 +266,7 @@ def _set_attributes(self, span: trace.Span, attributes: Dict[str, AttributeValue def _end_span( self, - span: trace.Span, + span: Span, attributes: Optional[Dict[str, AttributeValue]] = None, error: Optional[Exception] = None, ) -> None: @@ -239,17 +295,17 @@ def _end_span( else: span.set_status(StatusCode.OK) except Exception as e: - logger.warning(f"error=<{e}> | error while ending span", exc_info=True) + logger.warning("error=<%s> | error while ending span", e, exc_info=True) finally: span.end() # Force flush to ensure spans are exported - if self.tracer_provider: + if self.tracer_provider and hasattr(self.tracer_provider, "force_flush"): try: self.tracer_provider.force_flush() except Exception as e: - logger.warning(f"error=<{e}> | failed to force flush tracer provider") + logger.warning("error=<%s> | failed to force flush tracer provider", e) - def end_span_with_error(self, span: trace.Span, error_message: str, exception: Optional[Exception] = None) -> None: + def end_span_with_error(self, span: Span, error_message: str, exception: Optional[Exception] = None) -> None: """End a span with error status. Args: @@ -265,12 +321,12 @@ def end_span_with_error(self, span: trace.Span, error_message: str, exception: O def start_model_invoke_span( self, - parent_span: Optional[trace.Span] = None, + parent_span: Optional[Span] = None, agent_name: str = "Strands Agent", messages: Optional[Messages] = None, model_id: Optional[str] = None, **kwargs: Any, - ) -> Optional[trace.Span]: + ) -> Optional[Span]: """Start a new span for a model invocation. 
Args: @@ -287,7 +343,7 @@ def start_model_invoke_span( "gen_ai.system": "strands-agents", "agent.name": agent_name, "gen_ai.agent.name": agent_name, - "gen_ai.prompt": json.dumps(messages, cls=JSONEncoder), + "gen_ai.prompt": serialize(messages), } if model_id: @@ -299,7 +355,7 @@ def start_model_invoke_span( return self._start_span("Model invoke", parent_span, attributes) def end_model_invoke_span( - self, span: trace.Span, message: Message, usage: Usage, error: Optional[Exception] = None + self, span: Span, message: Message, usage: Usage, error: Optional[Exception] = None ) -> None: """End a model invocation span with results and metrics. @@ -310,7 +366,7 @@ def end_model_invoke_span( error: Optional exception if the model call failed. """ attributes: Dict[str, AttributeValue] = { - "gen_ai.completion": json.dumps(message["content"], cls=JSONEncoder), + "gen_ai.completion": serialize(message["content"]), "gen_ai.usage.prompt_tokens": usage["inputTokens"], "gen_ai.usage.completion_tokens": usage["outputTokens"], "gen_ai.usage.total_tokens": usage["totalTokens"], @@ -318,9 +374,7 @@ def end_model_invoke_span( self._end_span(span, attributes, error) - def start_tool_call_span( - self, tool: ToolUse, parent_span: Optional[trace.Span] = None, **kwargs: Any - ) -> Optional[trace.Span]: + def start_tool_call_span(self, tool: ToolUse, parent_span: Optional[Span] = None, **kwargs: Any) -> Optional[Span]: """Start a new span for a tool call. Args: @@ -332,9 +386,10 @@ def start_tool_call_span( The created span, or None if tracing is not enabled. """ attributes: Dict[str, AttributeValue] = { + "gen_ai.prompt": serialize(tool), "tool.name": tool["name"], "tool.id": tool["toolUseId"], - "tool.parameters": json.dumps(tool["input"], cls=JSONEncoder), + "tool.parameters": serialize(tool["input"]), } # Add additional kwargs as attributes @@ -344,7 +399,7 @@ def start_tool_call_span( return self._start_span(span_name, parent_span, attributes) def end_tool_call_span( - self, span: trace.Span, tool_result: Optional[ToolResult], error: Optional[Exception] = None + self, span: Span, tool_result: Optional[ToolResult], error: Optional[Exception] = None ) -> None: """End a tool call span with results. @@ -358,10 +413,11 @@ def end_tool_call_span( status = tool_result.get("status") status_str = str(status) if status is not None else "" + tool_result_content_json = serialize(tool_result.get("content")) attributes.update( { - "tool.result": json.dumps(tool_result.get("content"), cls=JSONEncoder), - "gen_ai.completion": json.dumps(tool_result.get("content"), cls=JSONEncoder), + "tool.result": tool_result_content_json, + "gen_ai.completion": tool_result_content_json, "tool.status": status_str, } ) @@ -371,10 +427,10 @@ def end_tool_call_span( def start_event_loop_cycle_span( self, event_loop_kwargs: Any, - parent_span: Optional[trace.Span] = None, + parent_span: Optional[Span] = None, messages: Optional[Messages] = None, **kwargs: Any, - ) -> Optional[trace.Span]: + ) -> Optional[Span]: """Start a new span for an event loop cycle. 
Args: @@ -390,7 +446,7 @@ def start_event_loop_cycle_span( parent_span = parent_span if parent_span else event_loop_kwargs.get("event_loop_parent_span") attributes: Dict[str, AttributeValue] = { - "gen_ai.prompt": json.dumps(messages, cls=JSONEncoder), + "gen_ai.prompt": serialize(messages), "event_loop.cycle_id": event_loop_cycle_id, } @@ -405,7 +461,7 @@ def start_event_loop_cycle_span( def end_event_loop_cycle_span( self, - span: trace.Span, + span: Span, message: Message, tool_result_message: Optional[Message] = None, error: Optional[Exception] = None, @@ -419,11 +475,11 @@ def end_event_loop_cycle_span( error: Optional exception if the cycle failed. """ attributes: Dict[str, AttributeValue] = { - "gen_ai.completion": json.dumps(message["content"], cls=JSONEncoder), + "gen_ai.completion": serialize(message["content"]), } if tool_result_message: - attributes["tool.result"] = json.dumps(tool_result_message["content"], cls=JSONEncoder) + attributes["tool.result"] = serialize(tool_result_message["content"]) self._end_span(span, attributes, error) @@ -435,7 +491,7 @@ def start_agent_span( tools: Optional[list] = None, custom_trace_attributes: Optional[Mapping[str, AttributeValue]] = None, **kwargs: Any, - ) -> Optional[trace.Span]: + ) -> Optional[Span]: """Start a new span for an agent invocation. Args: @@ -460,7 +516,7 @@ def start_agent_span( attributes["gen_ai.request.model"] = model_id if tools: - tools_json = json.dumps(tools, cls=JSONEncoder) + tools_json = serialize(tools) attributes["agent.tools"] = tools_json attributes["gen_ai.agent.tools"] = tools_json @@ -475,7 +531,7 @@ def start_agent_span( def end_agent_span( self, - span: trace.Span, + span: Span, response: Optional[AgentResult] = None, error: Optional[Exception] = None, ) -> None: @@ -492,7 +548,7 @@ def end_agent_span( if response: attributes.update( { - "gen_ai.completion": json.dumps(response, cls=JSONEncoder), + "gen_ai.completion": str(response), } ) @@ -517,7 +573,7 @@ def get_tracer( service_name: str = "strands-agents", otlp_endpoint: Optional[str] = None, otlp_headers: Optional[Dict[str, str]] = None, - enable_console_export: bool = False, + enable_console_export: Optional[bool] = None, ) -> Tracer: """Get or create the global tracer. @@ -526,13 +582,16 @@ def get_tracer( otlp_endpoint: OTLP endpoint URL for sending traces. otlp_headers: Headers to include with OTLP requests. enable_console_export: Whether to also export traces to console. + tracer_provider: Optional existing TracerProvider to use instead of creating a new one. Returns: The global tracer instance. """ global _tracer_instance - if _tracer_instance is None or (otlp_endpoint and _tracer_instance.otlp_endpoint != otlp_endpoint): # type: ignore[unreachable] + if ( + _tracer_instance is None or (otlp_endpoint and _tracer_instance.otlp_endpoint != otlp_endpoint) # type: ignore[unreachable] + ): _tracer_instance = Tracer( service_name=service_name, otlp_endpoint=otlp_endpoint, @@ -541,3 +600,15 @@ def get_tracer( ) return _tracer_instance + + +def serialize(obj: Any) -> str: + """Serialize an object to JSON with consistent settings. 
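# --- Editor's note: a hedged sketch of the new serialize() helper that the span code
# above now calls in place of the per-call-site json.dumps(..., cls=JSONEncoder); the
# import path is assumed to be the same tracer module.
from strands.telemetry.tracer import serialize

messages = [{"role": "user", "content": [{"text": "Bonjour, ça va ?"}]}]
# ensure_ascii=False keeps non-ASCII text readable in trace attributes instead of
# escaping it to \uXXXX sequences.
assert "ça" in serialize(messages)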
+ + Args: + obj: The object to serialize + + Returns: + JSON string representation of the object + """ + return json.dumps(obj, ensure_ascii=False, cls=JSONEncoder) diff --git a/src/strands/tools/__init__.py b/src/strands/tools/__init__.py index b3ee1566..12979015 100644 --- a/src/strands/tools/__init__.py +++ b/src/strands/tools/__init__.py @@ -4,6 +4,7 @@ """ from .decorator import tool +from .structured_output import convert_pydantic_to_tool_spec from .thread_pool_executor import ThreadPoolExecutorWrapper from .tools import FunctionTool, InvalidToolUseNameException, PythonAgentTool, normalize_schema, normalize_tool_spec @@ -15,4 +16,5 @@ "normalize_schema", "normalize_tool_spec", "ThreadPoolExecutorWrapper", + "convert_pydantic_to_tool_spec", ] diff --git a/src/strands/tools/mcp/mcp_client.py b/src/strands/tools/mcp/mcp_client.py index c6890945..a2298813 100644 --- a/src/strands/tools/mcp/mcp_client.py +++ b/src/strands/tools/mcp/mcp_client.py @@ -41,6 +41,12 @@ "image/webp": "webp", } +CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE = ( + "the client session is not running. Ensure the agent is used within " + "the MCP client context manager. For more information see: " + "https://strandsagents.com/latest/user-guide/concepts/tools/mcp-tools/#mcpclientinitializationerror" +) + class MCPClient: """Represents a connection to a Model Context Protocol (MCP) server. @@ -145,7 +151,7 @@ def list_tools_sync(self) -> List[MCPAgentTool]: """ self._log_debug_with_thread("listing MCP tools synchronously") if not self._is_session_active(): - raise MCPClientInitializationError("the client session is not running") + raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE) async def _list_tools_async() -> ListToolsResult: return await self._background_thread_session.list_tools() @@ -180,7 +186,7 @@ def call_tool_sync( """ self._log_debug_with_thread("calling MCP tool '%s' synchronously with tool_use_id=%s", name, tool_use_id) if not self._is_session_active(): - raise MCPClientInitializationError("the client session is not running") + raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE) async def _call_tool_async() -> MCPCallToolResult: return await self._background_thread_session.call_tool(name, arguments, read_timeout_seconds) diff --git a/src/strands/tools/registry.py b/src/strands/tools/registry.py index 2efdd600..e56ee999 100644 --- a/src/strands/tools/registry.py +++ b/src/strands/tools/registry.py @@ -189,6 +189,21 @@ def register_tool(self, tool: AgentTool) -> None: tool.is_dynamic, ) + if self.registry.get(tool.tool_name) is None: + normalized_name = tool.tool_name.replace("-", "_") + + matching_tools = [ + tool_name + for (tool_name, tool) in self.registry.items() + if tool_name.replace("-", "_") == normalized_name + ] + + if matching_tools: + raise ValueError( + f"Tool name '{tool.tool_name}' already exists as '{matching_tools[0]}'." 
+ " Cannot add a duplicate tool which differs by a '-' or '_'" + ) + # Register in main registry self.registry[tool.tool_name] = tool diff --git a/src/strands/tools/structured_output.py b/src/strands/tools/structured_output.py new file mode 100644 index 00000000..5421cdc6 --- /dev/null +++ b/src/strands/tools/structured_output.py @@ -0,0 +1,415 @@ +"""Tools for converting Pydantic models to Bedrock tools.""" + +from typing import Any, Dict, Optional, Type, Union + +from pydantic import BaseModel + +from ..types.tools import ToolSpec + + +def _flatten_schema(schema: Dict[str, Any]) -> Dict[str, Any]: + """Flattens a JSON schema by removing $defs and resolving $ref references. + + Handles required vs optional fields properly. + + Args: + schema: The JSON schema to flatten + + Returns: + Flattened JSON schema + """ + # Extract required fields list + required_fields = schema.get("required", []) + + # Initialize the flattened schema with basic properties + flattened = { + "type": schema.get("type", "object"), + "properties": {}, + } + + # Add title if present + if "title" in schema: + flattened["title"] = schema["title"] + + # Add description from schema if present, or use model docstring + if "description" in schema and schema["description"]: + flattened["description"] = schema["description"] + + # Process properties + required_props: list[str] = [] + if "properties" in schema: + required_props = [] + for prop_name, prop_value in schema["properties"].items(): + # Process the property and add to flattened properties + is_required = prop_name in required_fields + + # If the property already has nested properties (expanded), preserve them + if "properties" in prop_value: + # This is an expanded nested schema, preserve its structure + processed_prop = { + "type": prop_value.get("type", "object"), + "description": prop_value.get("description", ""), + "properties": {}, + } + + # Process each nested property + for nested_prop_name, nested_prop_value in prop_value["properties"].items(): + processed_prop["properties"][nested_prop_name] = nested_prop_value + + # Copy required fields if present + if "required" in prop_value: + processed_prop["required"] = prop_value["required"] + else: + # Process as normal + processed_prop = _process_property(prop_value, schema.get("$defs", {}), is_required) + + flattened["properties"][prop_name] = processed_prop + + # Track which properties are actually required after processing + if is_required and "null" not in str(processed_prop.get("type", "")): + required_props.append(prop_name) + + # Add required fields if any (only those that are truly required after processing) + # Check if required props are empty, if so, raise an error because it means there is a circular reference + + if len(required_props) > 0: + flattened["required"] = required_props + else: + raise ValueError("Circular reference detected and not supported") + + return flattened + + +def _process_property( + prop: Dict[str, Any], + defs: Dict[str, Any], + is_required: bool = False, + fully_expand: bool = True, +) -> Dict[str, Any]: + """Process a property in a schema, resolving any references. + + Args: + prop: The property to process + defs: The definitions dictionary for resolving references + is_required: Whether this property is required + fully_expand: Whether to fully expand nested properties + + Returns: + Processed property + """ + result = {} + is_nullable = False + + # Handle anyOf for optional fields (like Optional[Type]) + if "anyOf" in prop: + # Check if this is an Optional[...] 
case (one null, one type) + null_type = False + non_null_type = None + + for option in prop["anyOf"]: + if option.get("type") == "null": + null_type = True + is_nullable = True + elif "$ref" in option: + ref_path = option["$ref"].split("/")[-1] + if ref_path in defs: + non_null_type = _process_schema_object(defs[ref_path], defs, fully_expand) + else: + # Handle missing reference path gracefully + raise ValueError(f"Missing reference: {ref_path}") + else: + non_null_type = option + + if null_type and non_null_type: + # For Optional fields, we mark as nullable but copy all properties from the non-null option + result = non_null_type.copy() if isinstance(non_null_type, dict) else {} + + # For type, ensure it includes "null" + if "type" in result and isinstance(result["type"], str): + result["type"] = [result["type"], "null"] + elif "type" in result and isinstance(result["type"], list) and "null" not in result["type"]: + result["type"].append("null") + elif "type" not in result: + # Default to object type if not specified + result["type"] = ["object", "null"] + + # Copy description if available in the property + if "description" in prop: + result["description"] = prop["description"] + + return result + + # Handle direct references + elif "$ref" in prop: + # Resolve reference + ref_path = prop["$ref"].split("/")[-1] + if ref_path in defs: + ref_dict = defs[ref_path] + # Process the referenced object to get a complete schema + result = _process_schema_object(ref_dict, defs, fully_expand) + else: + # Handle missing reference path gracefully + raise ValueError(f"Missing reference: {ref_path}") + + # For regular fields, copy all properties + for key, value in prop.items(): + if key not in ["$ref", "anyOf"]: + if isinstance(value, dict): + result[key] = _process_nested_dict(value, defs) + elif key == "type" and not is_required and not is_nullable: + # For non-required fields, ensure type is a list with "null" + if isinstance(value, str): + result[key] = [value, "null"] + elif isinstance(value, list) and "null" not in value: + result[key] = value + ["null"] + else: + result[key] = value + else: + result[key] = value + + return result + + +def _process_schema_object( + schema_obj: Dict[str, Any], defs: Dict[str, Any], fully_expand: bool = True +) -> Dict[str, Any]: + """Process a schema object, typically from $defs, to resolve all nested properties. 
+ + Args: + schema_obj: The schema object to process + defs: The definitions dictionary for resolving references + fully_expand: Whether to fully expand nested properties + + Returns: + Processed schema object with all properties resolved + """ + result = {} + + # Copy basic attributes + for key, value in schema_obj.items(): + if key != "properties" and key != "required" and key != "$defs": + result[key] = value + + # Process properties if present + if "properties" in schema_obj: + result["properties"] = {} + required_props = [] + + # Get required fields list + required_fields = schema_obj.get("required", []) + + for prop_name, prop_value in schema_obj["properties"].items(): + # Process each property + is_required = prop_name in required_fields + processed = _process_property(prop_value, defs, is_required, fully_expand) + result["properties"][prop_name] = processed + + # Track which properties are actually required after processing + if is_required and "null" not in str(processed.get("type", "")): + required_props.append(prop_name) + + # Add required fields if any + if required_props: + result["required"] = required_props + + return result + + +def _process_nested_dict(d: Dict[str, Any], defs: Dict[str, Any]) -> Dict[str, Any]: + """Recursively processes nested dictionaries and resolves $ref references. + + Args: + d: The dictionary to process + defs: The definitions dictionary for resolving references + + Returns: + Processed dictionary + """ + result: Dict[str, Any] = {} + + # Handle direct reference + if "$ref" in d: + ref_path = d["$ref"].split("/")[-1] + if ref_path in defs: + ref_dict = defs[ref_path] + # Recursively process the referenced object + return _process_schema_object(ref_dict, defs) + else: + # Handle missing reference path gracefully + raise ValueError(f"Missing reference: {ref_path}") + + # Process each key-value pair + for key, value in d.items(): + if key == "$ref": + # Already handled above + continue + elif isinstance(value, dict): + result[key] = _process_nested_dict(value, defs) + elif isinstance(value, list): + # Process lists (like for enum values) + result[key] = [_process_nested_dict(item, defs) if isinstance(item, dict) else item for item in value] + else: + result[key] = value + + return result + + +def convert_pydantic_to_tool_spec( + model: Type[BaseModel], + description: Optional[str] = None, +) -> ToolSpec: + """Converts a Pydantic model to a tool description for the Amazon Bedrock Converse API. + + Handles optional vs. required fields, resolves $refs, and uses docstrings. 
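# --- Editor's note: a usage sketch for convert_pydantic_to_tool_spec(), condensed from
# the behaviour described above; the Weather model is illustrative only.
from pydantic import BaseModel, Field

from strands.tools import convert_pydantic_to_tool_spec


class Weather(BaseModel):
    """Extract the time and weather from a message."""

    time: str = Field(description="Time in HH:MM format")
    weather: str = Field(description="Weather condition, e.g. sunny")


spec = convert_pydantic_to_tool_spec(Weather)
# The model name becomes the tool name, the docstring becomes the description, and the
# flattened JSON schema is placed under inputSchema["json"].
assert spec["name"] == "Weather"
assert "time" in spec["inputSchema"]["json"]["properties"]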
+ + Args: + model: The Pydantic model class to convert + description: Optional description of the tool's purpose + + Returns: + ToolSpec: Dict containing the Bedrock tool specification + """ + name = model.__name__ + + # Get the JSON schema + input_schema = model.model_json_schema() + + # Get model docstring for description if not provided + model_description = description + if not model_description and model.__doc__: + model_description = model.__doc__.strip() + + # Process all referenced models to ensure proper docstrings + # This step is important for gathering descriptions from referenced models + _process_referenced_models(input_schema, model) + + # Now, let's fully expand the nested models with all their properties + _expand_nested_properties(input_schema, model) + + # Flatten the schema + flattened_schema = _flatten_schema(input_schema) + + final_schema = flattened_schema + + # Construct the tool specification + return ToolSpec( + name=name, + description=model_description or f"{name} structured output tool", + inputSchema={"json": final_schema}, + ) + + +def _expand_nested_properties(schema: Dict[str, Any], model: Type[BaseModel]) -> None: + """Expand the properties of nested models in the schema to include their full structure. + + This updates the schema in place. + + Args: + schema: The JSON schema to process + model: The Pydantic model class + """ + # First, process the properties at this level + if "properties" not in schema: + return + + # Create a modified copy of the properties to avoid modifying while iterating + for prop_name, prop_info in list(schema["properties"].items()): + field = model.model_fields.get(prop_name) + if not field: + continue + + field_type = field.annotation + + # Handle Optional types + is_optional = False + if ( + field_type is not None + and hasattr(field_type, "__origin__") + and field_type.__origin__ is Union + and hasattr(field_type, "__args__") + ): + # Look for Optional[BaseModel] + for arg in field_type.__args__: + if arg is type(None): + is_optional = True + elif isinstance(arg, type) and issubclass(arg, BaseModel): + field_type = arg + + # If this is a BaseModel field, expand its properties with full details + if isinstance(field_type, type) and issubclass(field_type, BaseModel): + # Get the nested model's schema with all its properties + nested_model_schema = field_type.model_json_schema() + + # Create a properly expanded nested object + expanded_object = { + "type": ["object", "null"] if is_optional else "object", + "description": prop_info.get("description", field.description or f"The {prop_name}"), + "properties": {}, + } + + # Copy all properties from the nested schema + if "properties" in nested_model_schema: + expanded_object["properties"] = nested_model_schema["properties"] + + # Copy required fields + if "required" in nested_model_schema: + expanded_object["required"] = nested_model_schema["required"] + + # Replace the original property with this expanded version + schema["properties"][prop_name] = expanded_object + + +def _process_referenced_models(schema: Dict[str, Any], model: Type[BaseModel]) -> None: + """Process referenced models to ensure their docstrings are included. + + This updates the schema in place. 
+ + Args: + schema: The JSON schema to process + model: The Pydantic model class + """ + # Process $defs to add docstrings from the referenced models + if "$defs" in schema: + # Look through model fields to find referenced models + for _, field in model.model_fields.items(): + field_type = field.annotation + + # Handle Optional types - with null checks + if field_type is not None and hasattr(field_type, "__origin__"): + origin = field_type.__origin__ + if origin is Union and hasattr(field_type, "__args__"): + # Find the non-None type in the Union (for Optional fields) + for arg in field_type.__args__: + if arg is not type(None): + field_type = arg + break + + # Check if this is a BaseModel subclass + if isinstance(field_type, type) and issubclass(field_type, BaseModel): + # Update $defs with this model's information + ref_name = field_type.__name__ + if ref_name in schema.get("$defs", {}): + ref_def = schema["$defs"][ref_name] + + # Add docstring as description if available + if field_type.__doc__ and not ref_def.get("description"): + ref_def["description"] = field_type.__doc__.strip() + + # Recursively process properties in the referenced model + _process_properties(ref_def, field_type) + + +def _process_properties(schema_def: Dict[str, Any], model: Type[BaseModel]) -> None: + """Process properties in a schema definition to add descriptions from field metadata. + + Args: + schema_def: The schema definition to update + model: The model class that defines the schema + """ + if "properties" in schema_def: + for prop_name, prop_info in schema_def["properties"].items(): + field = model.model_fields.get(prop_name) + + # Add field description if available and not already set + if field and field.description and not prop_info.get("description"): + prop_info["description"] = field.description diff --git a/src/strands/tools/tools.py b/src/strands/tools/tools.py index 40565a24..a449c74e 100644 --- a/src/strands/tools/tools.py +++ b/src/strands/tools/tools.py @@ -40,14 +40,14 @@ def validate_tool_use_name(tool: ToolUse) -> None: Raises: InvalidToolUseNameException: If the tool name is invalid. """ - # We need to fix some typing here, because we dont actually expect a ToolUse, but dict[str, Any] + # We need to fix some typing here, because we don't actually expect a ToolUse, but dict[str, Any] if "name" not in tool: message = "tool name missing" # type: ignore[unreachable] logger.warning(message) raise InvalidToolUseNameException(message) tool_name = tool["name"] - tool_name_pattern = r"^[a-zA-Z][a-zA-Z0-9_]*$" + tool_name_pattern = r"^[a-zA-Z][a-zA-Z0-9_\-]*$" tool_name_max_length = 64 valid_name_pattern = bool(re.match(tool_name_pattern, tool_name)) tool_name_len = len(tool_name) @@ -63,50 +63,54 @@ def validate_tool_use_name(tool: ToolUse) -> None: raise InvalidToolUseNameException(message) +def _normalize_property(prop_name: str, prop_def: Any) -> Dict[str, Any]: + """Normalize a single property definition. + + Args: + prop_name: The name of the property. + prop_def: The property definition to normalize. + + Returns: + The normalized property definition. 
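# --- Editor's note: a hedged sketch of what the reworked normalize_schema just below is
# meant to preserve; the input schema here is illustrative.
from strands.tools import normalize_schema

schema = {
    "type": "object",
    "properties": {
        "address": {
            "type": "object",
            "properties": {"city": {"type": "string", "description": "City name"}},
            "required": ["city"],
        },
        "age": {"type": "integer", "minimum": 0},
    },
    "required": ["address"],
}

normalized = normalize_schema(schema)
# Nested objects are normalized recursively instead of being collapsed to strings, and
# keywords such as "minimum" survive the copy-then-normalize approach.
assert normalized["properties"]["address"]["properties"]["city"]["type"] == "string"
assert normalized["properties"]["age"]["minimum"] == 0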
+ """ + if not isinstance(prop_def, dict): + return {"type": "string", "description": f"Property {prop_name}"} + + if prop_def.get("type") == "object" and "properties" in prop_def: + return normalize_schema(prop_def) # Recursive call + + # Copy existing property, ensuring defaults + normalized_prop = prop_def.copy() + normalized_prop.setdefault("type", "string") + normalized_prop.setdefault("description", f"Property {prop_name}") + return normalized_prop + + def normalize_schema(schema: Dict[str, Any]) -> Dict[str, Any]: """Normalize a JSON schema to match expectations. + This function recursively processes nested objects to preserve the complete schema structure. + Uses a copy-then-normalize approach to preserve all original schema properties. + Args: schema: The schema to normalize. Returns: The normalized schema. """ - normalized = {"type": schema.get("type", "object"), "properties": {}} - - # Handle properties - if "properties" in schema: - for prop_name, prop_def in schema["properties"].items(): - if isinstance(prop_def, dict): - normalized_prop = { - "type": prop_def.get("type", "string"), - "description": prop_def.get("description", f"Property {prop_name}"), - } - - # Handle enum values correctly - if "enum" in prop_def: - normalized_prop["enum"] = prop_def["enum"] - - # Handle numeric constraints - if prop_def.get("type") in ["number", "integer"]: - if "minimum" in prop_def: - normalized_prop["minimum"] = prop_def["minimum"] - if "maximum" in prop_def: - normalized_prop["maximum"] = prop_def["maximum"] - - normalized["properties"][prop_name] = normalized_prop - else: - # Handle non-dict property definitions (like simple strings) - normalized["properties"][prop_name] = { - "type": "string", - "description": f"Property {prop_name}", - } - - # Required fields - if "required" in schema: - normalized["required"] = schema["required"] - else: - normalized["required"] = [] + # Start with a complete copy to preserve all existing properties + normalized = schema.copy() + + # Ensure essential structure exists + normalized.setdefault("type", "object") + normalized.setdefault("properties", {}) + normalized.setdefault("required", []) + + # Process properties recursively + if "properties" in normalized: + properties = normalized["properties"] + for prop_name, prop_def in properties.items(): + normalized["properties"][prop_name] = _normalize_property(prop_name, prop_def) return normalized diff --git a/src/strands/types/content.py b/src/strands/types/content.py index c64b6773..790e9094 100644 --- a/src/strands/types/content.py +++ b/src/strands/types/content.py @@ -60,10 +60,21 @@ class ReasoningContentBlock(TypedDict, total=False): redactedContent: bytes +class CachePoint(TypedDict): + """A cache point configuration for optimizing conversation history. + + Attributes: + type: The type of cache point, typically "default". + """ + + type: str + + class ContentBlock(TypedDict, total=False): """A block of content for a message that you pass to, or receive from, a model. Attributes: + cachePoint: A cache point configuration to optimize conversation history. document: A document to include in the message. guardContent: Contains the content to assess with the guardrail. image: Image to include in the message. @@ -74,6 +85,7 @@ class ContentBlock(TypedDict, total=False): video: Video to include in the message. 
""" + cachePoint: CachePoint document: DocumentContent guardContent: GuardContent image: ImageContent diff --git a/src/strands/types/guardrails.py b/src/strands/types/guardrails.py index 6055b9ab..c15ba1be 100644 --- a/src/strands/types/guardrails.py +++ b/src/strands/types/guardrails.py @@ -16,7 +16,7 @@ class GuardrailConfig(TypedDict, total=False): Attributes: guardrailIdentifier: Unique identifier for the guardrail. guardrailVersion: Version of the guardrail to apply. - streamProcessingMode: Procesing mode. + streamProcessingMode: Processing mode. trace: The trace behavior for the guardrail. """ @@ -219,7 +219,7 @@ class GuardrailAssessment(TypedDict): contentPolicy: The content policy. contextualGroundingPolicy: The contextual grounding policy used for the guardrail assessment. sensitiveInformationPolicy: The sensitive information policy. - topicPolic: The topic policy. + topicPolicy: The topic policy. wordPolicy: The word policy. """ diff --git a/src/strands/types/media.py b/src/strands/types/media.py index 058a09ea..29b89e5c 100644 --- a/src/strands/types/media.py +++ b/src/strands/types/media.py @@ -68,7 +68,7 @@ class ImageContent(TypedDict): class VideoSource(TypedDict): - """Contains the content of a vidoe. + """Contains the content of a video. Attributes: bytes: The binary content of the video. diff --git a/src/strands/types/models/__init__.py b/src/strands/types/models/__init__.py new file mode 100644 index 00000000..5ce0a498 --- /dev/null +++ b/src/strands/types/models/__init__.py @@ -0,0 +1,6 @@ +"""Model-related type definitions for the SDK.""" + +from .model import Model +from .openai import OpenAIModel + +__all__ = ["Model", "OpenAIModel"] diff --git a/src/strands/types/models.py b/src/strands/types/models/model.py similarity index 78% rename from src/strands/types/models.py rename to src/strands/types/models/model.py index e3d96e29..071c8a51 100644 --- a/src/strands/types/models.py +++ b/src/strands/types/models/model.py @@ -2,14 +2,18 @@ import abc import logging -from typing import Any, Iterable, Optional +from typing import Any, Callable, Iterable, Optional, Type, TypeVar -from .content import Messages -from .streaming import StreamEvent -from .tools import ToolSpec +from pydantic import BaseModel + +from ..content import Messages +from ..streaming import StreamEvent +from ..tools import ToolSpec logger = logging.getLogger(__name__) +T = TypeVar("T", bound=BaseModel) + class Model(abc.ABC): """Abstract base class for AI model implementations. @@ -38,6 +42,26 @@ def get_config(self) -> Any: """ pass + @abc.abstractmethod + # pragma: no cover + def structured_output( + self, output_model: Type[T], prompt: Messages, callback_handler: Optional[Callable] = None + ) -> T: + """Get structured output from the model. + + Args: + output_model(Type[BaseModel]): The output model to use for the agent. + prompt(Messages): The prompt messages to use for the agent. + callback_handler(Optional[Callable]): Optional callback handler for processing events. Defaults to None. + + Returns: + The structured output as a serialized instance of the output model. + + Raises: + ValidationException: The response format from the model does not match the output_model + """ + pass + @abc.abstractmethod # pragma: no cover def format_request( diff --git a/src/strands/types/models/openai.py b/src/strands/types/models/openai.py new file mode 100644 index 00000000..8ff37d35 --- /dev/null +++ b/src/strands/types/models/openai.py @@ -0,0 +1,307 @@ +"""Base OpenAI model provider. 
+ +This module provides the base OpenAI model provider class which implements shared logic for formatting requests and +responses to and from the OpenAI specification. + +- Docs: https://pypi.org/project/openai +""" + +import abc +import base64 +import json +import logging +import mimetypes +from typing import Any, Callable, Optional, Type, TypeVar, cast + +from pydantic import BaseModel +from typing_extensions import override + +from ..content import ContentBlock, Messages +from ..streaming import StreamEvent +from ..tools import ToolResult, ToolSpec, ToolUse +from .model import Model + +logger = logging.getLogger(__name__) + +T = TypeVar("T", bound=BaseModel) + + +class OpenAIModel(Model, abc.ABC): + """Base OpenAI model provider implementation. + + Implements shared logic for formatting requests and responses to and from the OpenAI specification. + """ + + config: dict[str, Any] + + @staticmethod + def b64encode(data: bytes) -> bytes: + """Base64 encode the provided data. + + If the data is already base64 encoded, we do nothing. + Note, this is a temporary method used to provide a warning to users who pass in base64 encoded data. In future + versions, images and documents will be base64 encoded on behalf of customers for consistency with the other + providers and general convenience. + + Args: + data: Data to encode. + + Returns: + Base64 encoded data. + """ + try: + base64.b64decode(data, validate=True) + logger.warning( + "issue=<%s> | base64 encoded images and documents will not be accepted in future versions", + "https://github.com/strands-agents/sdk-python/issues/252", + ) + except ValueError: + data = base64.b64encode(data) + + return data + + @classmethod + def format_request_message_content(cls, content: ContentBlock) -> dict[str, Any]: + """Format an OpenAI compatible content block. + + Args: + content: Message content. + + Returns: + OpenAI compatible content block. + + Raises: + TypeError: If the content block type cannot be converted to an OpenAI-compatible format. + """ + if "document" in content: + mime_type = mimetypes.types_map.get(f".{content['document']['format']}", "application/octet-stream") + file_data = base64.b64encode(content["document"]["source"]["bytes"]).decode("utf-8") + return { + "file": { + "file_data": f"data:{mime_type};base64,{file_data}", + "filename": content["document"]["name"], + }, + "type": "file", + } + + if "image" in content: + mime_type = mimetypes.types_map.get(f".{content['image']['format']}", "application/octet-stream") + image_data = OpenAIModel.b64encode(content["image"]["source"]["bytes"]).decode("utf-8") + + return { + "image_url": { + "detail": "auto", + "format": mime_type, + "url": f"data:{mime_type};base64,{image_data}", + }, + "type": "image_url", + } + + if "text" in content: + return {"text": content["text"], "type": "text"} + + raise TypeError(f"content_type=<{next(iter(content))}> | unsupported type") + + @classmethod + def format_request_message_tool_call(cls, tool_use: ToolUse) -> dict[str, Any]: + """Format an OpenAI compatible tool call. + + Args: + tool_use: Tool use requested by the model. + + Returns: + OpenAI compatible tool call. + """ + return { + "function": { + "arguments": json.dumps(tool_use["input"]), + "name": tool_use["name"], + }, + "id": tool_use["toolUseId"], + "type": "function", + } + + @classmethod + def format_request_tool_message(cls, tool_result: ToolResult) -> dict[str, Any]: + """Format an OpenAI compatible tool message. + + Args: + tool_result: Tool result collected from a tool execution. 
+ + Returns: + OpenAI compatible tool message. + """ + contents = cast( + list[ContentBlock], + [ + {"text": json.dumps(content["json"])} if "json" in content else content + for content in tool_result["content"] + ], + ) + + return { + "role": "tool", + "tool_call_id": tool_result["toolUseId"], + "content": [cls.format_request_message_content(content) for content in contents], + } + + @classmethod + def format_request_messages(cls, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]: + """Format an OpenAI compatible messages array. + + Args: + messages: List of message objects to be processed by the model. + system_prompt: System prompt to provide context to the model. + + Returns: + An OpenAI compatible messages array. + """ + formatted_messages: list[dict[str, Any]] + formatted_messages = [{"role": "system", "content": system_prompt}] if system_prompt else [] + + for message in messages: + contents = message["content"] + + formatted_contents = [ + cls.format_request_message_content(content) + for content in contents + if not any(block_type in content for block_type in ["toolResult", "toolUse"]) + ] + formatted_tool_calls = [ + cls.format_request_message_tool_call(content["toolUse"]) for content in contents if "toolUse" in content + ] + formatted_tool_messages = [ + cls.format_request_tool_message(content["toolResult"]) + for content in contents + if "toolResult" in content + ] + + formatted_message = { + "role": message["role"], + "content": formatted_contents, + **({"tool_calls": formatted_tool_calls} if formatted_tool_calls else {}), + } + formatted_messages.append(formatted_message) + formatted_messages.extend(formatted_tool_messages) + + return [message for message in formatted_messages if message["content"] or "tool_calls" in message] + + @override + def format_request( + self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None + ) -> dict[str, Any]: + """Format an OpenAI compatible chat streaming request. + + Args: + messages: List of message objects to be processed by the model. + tool_specs: List of tool specifications to make available to the model. + system_prompt: System prompt to provide context to the model. + + Returns: + An OpenAI compatible chat streaming request. + + Raises: + TypeError: If a message contains a content block type that cannot be converted to an OpenAI-compatible + format. + """ + return { + "messages": self.format_request_messages(messages, system_prompt), + "model": self.config["model_id"], + "stream": True, + "stream_options": {"include_usage": True}, + "tools": [ + { + "type": "function", + "function": { + "name": tool_spec["name"], + "description": tool_spec["description"], + "parameters": tool_spec["inputSchema"]["json"], + }, + } + for tool_spec in tool_specs or [] + ], + **(self.config.get("params") or {}), + } + + @override + def format_chunk(self, event: dict[str, Any]) -> StreamEvent: + """Format an OpenAI response event into a standardized message chunk. + + Args: + event: A response event from the OpenAI compatible model. + + Returns: + The formatted chunk. + + Raises: + RuntimeError: If chunk_type is not recognized. + This error should never be encountered as chunk_type is controlled in the stream method. 
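# --- Editor's note: a sketch of the shared OpenAI request formatting defined above.
# format_request_messages is a classmethod, so it can be exercised without a concrete
# provider subclass; the conversation content is illustrative.
from strands.types.models import OpenAIModel

messages = [{"role": "user", "content": [{"text": "What is 2 + 2?"}]}]
formatted = OpenAIModel.format_request_messages(messages, system_prompt="You are terse.")

# The system prompt becomes the leading system message and each text block is converted
# to the OpenAI {"type": "text", ...} shape.
assert formatted[0] == {"role": "system", "content": "You are terse."}
assert formatted[1]["content"] == [{"text": "What is 2 + 2?", "type": "text"}]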
+ """ + match event["chunk_type"]: + case "message_start": + return {"messageStart": {"role": "assistant"}} + + case "content_start": + if event["data_type"] == "tool": + return { + "contentBlockStart": { + "start": { + "toolUse": { + "name": event["data"].function.name, + "toolUseId": event["data"].id, + } + } + } + } + + return {"contentBlockStart": {"start": {}}} + + case "content_delta": + if event["data_type"] == "tool": + return { + "contentBlockDelta": {"delta": {"toolUse": {"input": event["data"].function.arguments or ""}}} + } + + return {"contentBlockDelta": {"delta": {"text": event["data"]}}} + + case "content_stop": + return {"contentBlockStop": {}} + + case "message_stop": + match event["data"]: + case "tool_calls": + return {"messageStop": {"stopReason": "tool_use"}} + case "length": + return {"messageStop": {"stopReason": "max_tokens"}} + case _: + return {"messageStop": {"stopReason": "end_turn"}} + + case "metadata": + return { + "metadata": { + "usage": { + "inputTokens": event["data"].prompt_tokens, + "outputTokens": event["data"].completion_tokens, + "totalTokens": event["data"].total_tokens, + }, + "metrics": { + "latencyMs": 0, # TODO + }, + }, + } + + case _: + raise RuntimeError(f"chunk_type=<{event['chunk_type']} | unknown type") + + @override + def structured_output( + self, output_model: Type[T], prompt: Messages, callback_handler: Optional[Callable] = None + ) -> T: + """Get structured output from the model. + + Args: + output_model(Type[BaseModel]): The output model to use for the agent. + prompt(Messages): The prompt to use for the agent. + callback_handler(Optional[Callable]): Optional callback handler for processing events. Defaults to None. + """ + return output_model() diff --git a/src/strands/types/streaming.py b/src/strands/types/streaming.py index db600f15..9c99b210 100644 --- a/src/strands/types/streaming.py +++ b/src/strands/types/streaming.py @@ -157,7 +157,7 @@ class ModelStreamErrorEvent(ExceptionEvent): originalStatusCode: int -class RedactContentEvent(TypedDict): +class RedactContentEvent(TypedDict, total=False): """Event for redacting content. 
Attributes: diff --git a/tests-integ/test_bedrock_guardrails.py b/tests-integ/test_bedrock_guardrails.py index 9ffd1bdf..bf0be706 100644 --- a/tests-integ/test_bedrock_guardrails.py +++ b/tests-integ/test_bedrock_guardrails.py @@ -12,7 +12,7 @@ @pytest.fixture(scope="module") def boto_session(): - return boto3.Session(region_name="us-west-2") + return boto3.Session(region_name="us-east-1") @pytest.fixture(scope="module") @@ -142,7 +142,7 @@ def test_guardrail_output_intervention_redact_output(bedrock_guardrail, processi guardrail_stream_processing_mode=processing_mode, guardrail_redact_output=True, guardrail_redact_output_message=REDACT_MESSAGE, - region_name="us-west-2", + region_name="us-east-1", ) agent = Agent( diff --git a/tests-integ/test_mcp_client.py b/tests-integ/test_mcp_client.py index 59ae2a14..8b1dade3 100644 --- a/tests-integ/test_mcp_client.py +++ b/tests-integ/test_mcp_client.py @@ -1,8 +1,10 @@ import base64 +import os import threading import time from typing import List, Literal +import pytest from mcp import StdioServerParameters, stdio_client from mcp.client.sse import sse_client from mcp.client.streamable_http import streamablehttp_client @@ -101,6 +103,10 @@ def test_can_reuse_mcp_client(): assert any([block["name"] == "echo" for block in tool_use_content_blocks]) +@pytest.mark.skipif( + condition=os.environ.get("GITHUB_ACTIONS") == "true", + reason="streamable transport is failing in GitHub actions, debugging if linux compatibility issue", +) def test_streamable_http_mcp_client(): server_thread = threading.Thread( target=start_calculator_server, kwargs={"transport": "streamable-http", "port": 8001}, daemon=True diff --git a/tests-integ/test_model_anthropic.py b/tests-integ/test_model_anthropic.py index 1b0412c9..50033f8f 100644 --- a/tests-integ/test_model_anthropic.py +++ b/tests-integ/test_model_anthropic.py @@ -1,6 +1,7 @@ import os import pytest +from pydantic import BaseModel import strands from strands import Agent @@ -33,7 +34,7 @@ def tool_weather() -> str: @pytest.fixture def system_prompt(): - return "You are an AI assistant that uses & instead of ." + return "You are an AI assistant." @pytest.fixture @@ -46,4 +47,17 @@ def test_agent(agent): result = agent("What is the time and weather in New York?") text = result.message["content"][0]["text"].lower() - assert all(string in text for string in ["12:00", "sunny", "&"]) + assert all(string in text for string in ["12:00", "sunny"]) + + +@pytest.mark.skipif("ANTHROPIC_API_KEY" not in os.environ, reason="ANTHROPIC_API_KEY environment variable missing") +def test_structured_output(model): + class Weather(BaseModel): + time: str + weather: str + + agent = Agent(model=model) + result = agent.structured_output(Weather, "The time is 12:00 and the weather is sunny") + assert isinstance(result, Weather) + assert result.time == "12:00" + assert result.weather == "sunny" diff --git a/tests-integ/test_model_bedrock.py b/tests-integ/test_model_bedrock.py new file mode 100644 index 00000000..5378a9b2 --- /dev/null +++ b/tests-integ/test_model_bedrock.py @@ -0,0 +1,151 @@ +import pytest +from pydantic import BaseModel + +import strands +from strands import Agent +from strands.models import BedrockModel + + +@pytest.fixture +def system_prompt(): + return "You are an AI assistant that uses & instead of ." 
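# --- Editor's note: the integration tests in this section all exercise the same
# structured-output call shape; a condensed sketch for reference. The Bedrock model id
# mirrors the fixture below and is otherwise illustrative.
from pydantic import BaseModel

from strands import Agent
from strands.models import BedrockModel


class Weather(BaseModel):
    time: str
    weather: str


agent = Agent(model=BedrockModel(model_id="us.anthropic.claude-3-7-sonnet-20250219-v1:0"))
result = agent.structured_output(Weather, "The time is 12:00 and the weather is sunny")
assert isinstance(result, Weather)
assert result.time == "12:00"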
+ + +@pytest.fixture +def streaming_model(): + return BedrockModel( + model_id="us.anthropic.claude-3-7-sonnet-20250219-v1:0", + streaming=True, + ) + + +@pytest.fixture +def non_streaming_model(): + return BedrockModel( + model_id="us.meta.llama3-2-90b-instruct-v1:0", + streaming=False, + ) + + +@pytest.fixture +def streaming_agent(streaming_model, system_prompt): + return Agent(model=streaming_model, system_prompt=system_prompt, load_tools_from_directory=False) + + +@pytest.fixture +def non_streaming_agent(non_streaming_model, system_prompt): + return Agent(model=non_streaming_model, system_prompt=system_prompt, load_tools_from_directory=False) + + +def test_streaming_agent(streaming_agent): + """Test agent with streaming model.""" + result = streaming_agent("Hello!") + + assert len(str(result)) > 0 + + +def test_non_streaming_agent(non_streaming_agent): + """Test agent with non-streaming model.""" + result = non_streaming_agent("Hello!") + + assert len(str(result)) > 0 + + +def test_streaming_model_events(streaming_model): + """Test streaming model events.""" + messages = [{"role": "user", "content": [{"text": "Hello"}]}] + + # Call converse and collect events + events = list(streaming_model.converse(messages)) + + # Verify basic structure of events + assert any("messageStart" in event for event in events) + assert any("contentBlockDelta" in event for event in events) + assert any("messageStop" in event for event in events) + + +def test_non_streaming_model_events(non_streaming_model): + """Test non-streaming model events.""" + messages = [{"role": "user", "content": [{"text": "Hello"}]}] + + # Call converse and collect events + events = list(non_streaming_model.converse(messages)) + + # Verify basic structure of events + assert any("messageStart" in event for event in events) + assert any("contentBlockDelta" in event for event in events) + assert any("messageStop" in event for event in events) + + +def test_tool_use_streaming(streaming_model): + """Test tool use with streaming model.""" + + tool_was_called = False + + @strands.tool + def calculator(expression: str) -> float: + """Calculate the result of a mathematical expression.""" + + nonlocal tool_was_called + tool_was_called = True + return eval(expression) + + agent = Agent(model=streaming_model, tools=[calculator], load_tools_from_directory=False) + result = agent("What is 123 + 456?") + + # Print the full message content for debugging + print("\nFull message content:") + import json + + print(json.dumps(result.message["content"], indent=2)) + + assert tool_was_called + + +def test_tool_use_non_streaming(non_streaming_model): + """Test tool use with non-streaming model.""" + + tool_was_called = False + + @strands.tool + def calculator(expression: str) -> float: + """Calculate the result of a mathematical expression.""" + + nonlocal tool_was_called + tool_was_called = True + return eval(expression) + + agent = Agent(model=non_streaming_model, tools=[calculator], load_tools_from_directory=False) + agent("What is 123 + 456?") + + assert tool_was_called + + +def test_structured_output_streaming(streaming_model): + """Test structured output with streaming model.""" + + class Weather(BaseModel): + time: str + weather: str + + agent = Agent(model=streaming_model) + + result = agent.structured_output(Weather, "The time is 12:00 and the weather is sunny") + assert isinstance(result, Weather) + assert result.time == "12:00" + assert result.weather == "sunny" + + +def test_structured_output_non_streaming(non_streaming_model): + """Test 
structured output with non-streaming model.""" + + class Weather(BaseModel): + time: str + weather: str + + agent = Agent(model=non_streaming_model) + + result = agent.structured_output(Weather, "The time is 12:00 and the weather is sunny") + assert isinstance(result, Weather) + assert result.time == "12:00" + assert result.weather == "sunny" diff --git a/tests-integ/test_model_litellm.py b/tests-integ/test_model_litellm.py index f1afb61f..01a3e121 100644 --- a/tests-integ/test_model_litellm.py +++ b/tests-integ/test_model_litellm.py @@ -1,4 +1,5 @@ import pytest +from pydantic import BaseModel import strands from strands import Agent @@ -7,7 +8,7 @@ @pytest.fixture def model(): - return LiteLLMModel(model_id="us.anthropic.claude-3-7-sonnet-20250219-v1:0") + return LiteLLMModel(model_id="bedrock/us.anthropic.claude-3-7-sonnet-20250219-v1:0") @pytest.fixture @@ -33,3 +34,16 @@ def test_agent(agent): text = result.message["content"][0]["text"].lower() assert all(string in text for string in ["12:00", "sunny"]) + + +def test_structured_output(model): + class Weather(BaseModel): + time: str + weather: str + + agent_no_tools = Agent(model=model) + + result = agent_no_tools.structured_output(Weather, "The time is 12:00 and the weather is sunny") + assert isinstance(result, Weather) + assert result.time == "12:00" + assert result.weather == "sunny" diff --git a/tests-integ/test_model_llamaapi.py b/tests-integ/test_model_llamaapi.py index 5cddc1a0..dad6919e 100644 --- a/tests-integ/test_model_llamaapi.py +++ b/tests-integ/test_model_llamaapi.py @@ -1,3 +1,4 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates import os import pytest diff --git a/tests-integ/test_model_ollama.py b/tests-integ/test_model_ollama.py new file mode 100644 index 00000000..38b46821 --- /dev/null +++ b/tests-integ/test_model_ollama.py @@ -0,0 +1,47 @@ +import pytest +import requests +from pydantic import BaseModel + +from strands import Agent +from strands.models.ollama import OllamaModel + + +def is_server_available() -> bool: + try: + return requests.get("http://localhost:11434").ok + except requests.exceptions.ConnectionError: + return False + + +@pytest.fixture +def model(): + return OllamaModel(host="http://localhost:11434", model_id="llama3.3:70b") + + +@pytest.fixture +def agent(model): + return Agent(model=model) + + +@pytest.mark.skipif(not is_server_available(), reason="Local Ollama endpoint not available at localhost:11434") +def test_agent(agent): + result = agent("Say 'hello world' with no other text") + assert isinstance(result.message["content"][0]["text"], str) + + +@pytest.mark.skipif(not is_server_available(), reason="Local Ollama endpoint not available at localhost:11434") +def test_structured_output(agent): + class Weather(BaseModel): + """Extract the time and weather. + + Time format: HH:MM + Weather: sunny, cloudy, rainy, etc. 
+ """ + + time: str + weather: str + + result = agent.structured_output(Weather, "The time is 12:00 and the weather is sunny") + assert isinstance(result, Weather) + assert result.time == "12:00" + assert result.weather == "sunny" diff --git a/tests-integ/test_model_openai.py b/tests-integ/test_model_openai.py new file mode 100644 index 00000000..b0790ba0 --- /dev/null +++ b/tests-integ/test_model_openai.py @@ -0,0 +1,66 @@ +import os + +import pytest +from pydantic import BaseModel + +import strands +from strands import Agent +from strands.models.openai import OpenAIModel + + +@pytest.fixture +def model(): + return OpenAIModel( + model_id="gpt-4o", + client_args={ + "api_key": os.getenv("OPENAI_API_KEY"), + }, + ) + + +@pytest.fixture +def tools(): + @strands.tool + def tool_time() -> str: + return "12:00" + + @strands.tool + def tool_weather() -> str: + return "sunny" + + return [tool_time, tool_weather] + + +@pytest.fixture +def agent(model, tools): + return Agent(model=model, tools=tools) + + +@pytest.mark.skipif( + "OPENAI_API_KEY" not in os.environ, + reason="OPENAI_API_KEY environment variable missing", +) +def test_agent(agent): + result = agent("What is the time and weather in New York?") + text = result.message["content"][0]["text"].lower() + + assert all(string in text for string in ["12:00", "sunny"]) + + +@pytest.mark.skipif( + "OPENAI_API_KEY" not in os.environ, + reason="OPENAI_API_KEY environment variable missing", +) +def test_structured_output(model): + class Weather(BaseModel): + """Extracts the time and weather from the user's message with the exact strings.""" + + time: str + weather: str + + agent = Agent(model=model) + + result = agent.structured_output(Weather, "The time is 12:00 and the weather is sunny") + assert isinstance(result, Weather) + assert result.time == "12:00" + assert result.weather == "sunny" diff --git a/tests-integ/test_stream_agent.py b/tests-integ/test_stream_agent.py index 4c97db6b..01f20339 100644 --- a/tests-integ/test_stream_agent.py +++ b/tests-integ/test_stream_agent.py @@ -1,5 +1,5 @@ """ -Test script for Strands's custom callback handler functionality. +Test script for Strands' custom callback handler functionality. Demonstrates different patterns of callback handling and processing. """ diff --git a/tests-integ/test_summarizing_conversation_manager_integration.py b/tests-integ/test_summarizing_conversation_manager_integration.py new file mode 100644 index 00000000..5dcf4944 --- /dev/null +++ b/tests-integ/test_summarizing_conversation_manager_integration.py @@ -0,0 +1,374 @@ +"""Integration tests for SummarizingConversationManager with actual AI models. + +These tests validate the end-to-end functionality of the SummarizingConversationManager +by testing with real AI models and API calls. They ensure that: + +1. **Real summarization** - Tests that actual model-generated summaries work correctly +2. **Context overflow handling** - Validates real context overflow scenarios and recovery +3. **Tool preservation** - Ensures ToolUse/ToolResult pairs survive real summarization +4. **Message structure** - Verifies real model outputs maintain proper message structure +5. **Agent integration** - Tests that conversation managers work with real Agent workflows + +These tests require API keys (`ANTHROPIC_API_KEY`) and make real API calls, so they should be run sparingly +and may be skipped in CI environments without proper credentials. 
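# --- Editor's note: a hedged sketch of the conversation-manager wiring these integration
# tests exercise; parameter values are illustrative, and the default Bedrock model is
# assumed when no model is passed.
from strands import Agent
from strands.agent.conversation_manager import SummarizingConversationManager

agent = Agent(
    conversation_manager=SummarizingConversationManager(
        summary_ratio=0.5,           # summarize roughly half of the oldest messages
        preserve_recent_messages=2,  # always keep the two most recent messages verbatim
    ),
    load_tools_from_directory=False,
)
# On context overflow the manager replaces the oldest messages with a single assistant
# summary message; the tests below trigger this directly via reduce_context(agent).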
+""" + +import os + +import pytest + +import strands +from strands import Agent +from strands.agent.conversation_manager import SummarizingConversationManager +from strands.models.anthropic import AnthropicModel + + +@pytest.fixture +def model(): + """Real Anthropic model for integration testing.""" + return AnthropicModel( + client_args={ + "api_key": os.getenv("ANTHROPIC_API_KEY"), + }, + model_id="claude-3-haiku-20240307", # Using Haiku for faster/cheaper tests + max_tokens=1024, + ) + + +@pytest.fixture +def summarization_model(): + """Separate model instance for summarization to test dedicated agent functionality.""" + return AnthropicModel( + client_args={ + "api_key": os.getenv("ANTHROPIC_API_KEY"), + }, + model_id="claude-3-haiku-20240307", + max_tokens=512, + ) + + +@pytest.fixture +def tools(): + """Real tools for testing tool preservation during summarization.""" + + @strands.tool + def get_current_time() -> str: + """Get the current time.""" + return "2024-01-15 14:30:00" + + @strands.tool + def get_weather(city: str) -> str: + """Get weather information for a city.""" + return f"The weather in {city} is sunny and 72Β°F" + + @strands.tool + def calculate_sum(a: int, b: int) -> int: + """Calculate the sum of two numbers.""" + return a + b + + return [get_current_time, get_weather, calculate_sum] + + +@pytest.mark.skipif("ANTHROPIC_API_KEY" not in os.environ, reason="ANTHROPIC_API_KEY environment variable missing") +def test_summarization_with_context_overflow(model): + """Test that summarization works when context overflow occurs.""" + # Mock conversation data to avoid API calls + greeting_response = """ + Hello! I'm here to help you test your conversation manager. What specifically would you like + me to do as part of this test? I can respond to different types of prompts, maintain context + throughout our conversation, or demonstrate other capabilities of the AI assistant. Just let + me know what aspects you'd like to evaluate. + """.strip() + + computer_history_response = """ + # History of Computers + + The history of computers spans many centuries, evolving from simple calculating tools to + the powerful machines we use today. 
+ + ## Early Computing Devices + - **Ancient abacus** (3000 BCE): One of the earliest computing devices used for arithmetic calculations + - **Pascaline** (1642): Mechanical calculator invented by Blaise Pascal + - **Difference Engine** (1822): Designed by Charles Babbage to compute polynomial functions + - **Analytical Engine**: Babbage's more ambitious design, considered the first general-purpose computer concept + - **Hollerith's Tabulating Machine** (1890s): Used punch cards to process data for the US Census + + ## Early Electronic Computers + - **ENIAC** (1945): First general-purpose electronic computer, weighed 30 tons + - **EDVAC** (1949): Introduced the stored program concept + - **UNIVAC I** (1951): First commercial computer in the United States + """.strip() + + first_computers_response = """ + # The First Computers + + Early computers were dramatically different from today's machines in almost every aspect: + + ## Physical Characteristics + - **Enormous size**: Room-filling or even building-filling machines + - **ENIAC** (1945) weighed about 30 tons, occupied 1,800 square feet + - Consisted of large metal frames or cabinets filled with components + - Required special cooling systems due to excessive heat generation + + ## Technology and Components + - **Vacuum tubes**: Thousands of fragile glass tubes served as switches and amplifiers + - ENIAC contained over 17,000 vacuum tubes + - Generated tremendous heat and frequently failed + - **Memory**: Limited storage using delay lines, cathode ray tubes, or magnetic drums + """.strip() + + messages = [ + {"role": "user", "content": [{"text": "Hello, I'm testing a conversation manager."}]}, + {"role": "assistant", "content": [{"text": greeting_response}]}, + {"role": "user", "content": [{"text": "Can you tell me about the history of computers?"}]}, + {"role": "assistant", "content": [{"text": computer_history_response}]}, + {"role": "user", "content": [{"text": "What were the first computers like?"}]}, + {"role": "assistant", "content": [{"text": first_computers_response}]}, + ] + + # Create agent with very aggressive summarization settings and pre-built conversation + agent = Agent( + model=model, + conversation_manager=SummarizingConversationManager( + summary_ratio=0.5, # Summarize 50% of messages + preserve_recent_messages=2, # Keep only 2 recent messages + ), + load_tools_from_directory=False, + messages=messages, + ) + + # Should have the pre-built conversation history + initial_message_count = len(agent.messages) + assert initial_message_count == 6 # 3 user + 3 assistant messages + + # Store the last 2 messages before summarization to verify they're preserved + messages_before_summary = agent.messages[-2:].copy() + + # Now manually trigger context reduction to test summarization + agent.conversation_manager.reduce_context(agent) + + # Verify summarization occurred + assert len(agent.messages) < initial_message_count + # Should have: 1 summary + remaining messages + # With 6 messages, summary_ratio=0.5, preserve_recent_messages=2: + # messages_to_summarize = min(6 * 0.5, 6 - 2) = min(3, 4) = 3 + # So we summarize 3 messages, leaving 3 remaining + 1 summary = 4 total + expected_total_messages = 4 + assert len(agent.messages) == expected_total_messages + + # First message should be the summary (assistant message) + summary_message = agent.messages[0] + assert summary_message["role"] == "assistant" + assert len(summary_message["content"]) > 0 + + # Verify the summary contains actual text content + summary_content = None + for 
content_block in summary_message["content"]: + if "text" in content_block: + summary_content = content_block["text"] + break + + assert summary_content is not None + assert len(summary_content) > 50 # Should be a substantial summary + + # Recent messages should be preserved - verify they're exactly the same + recent_messages = agent.messages[-2:] # Last 2 messages should be preserved + assert len(recent_messages) == 2 + assert recent_messages == messages_before_summary, "The last 2 messages should be preserved exactly as they were" + + # Agent should still be functional after summarization + post_summary_result = agent("That's very interesting, thank you!") + assert post_summary_result.message["role"] == "assistant" + + +@pytest.mark.skipif("ANTHROPIC_API_KEY" not in os.environ, reason="ANTHROPIC_API_KEY environment variable missing") +def test_tool_preservation_during_summarization(model, tools): + """Test that ToolUse/ToolResult pairs are preserved during summarization.""" + agent = Agent( + model=model, + tools=tools, + conversation_manager=SummarizingConversationManager( + summary_ratio=0.6, # Aggressive summarization + preserve_recent_messages=3, + ), + load_tools_from_directory=False, + ) + + # Mock conversation with tool usage to avoid API calls and speed up tests + greeting_text = """ + Hello! I'd be happy to help you with calculations. I have access to tools that can + help with math, time, and weather information. What would you like me to calculate for you? + """.strip() + + weather_response = "The weather in San Francisco is sunny and 72Β°F. Perfect weather for being outside!" + + tool_conversation_data = [ + # Initial greeting exchange + {"role": "user", "content": [{"text": "Hello, can you help me with some calculations?"}]}, + {"role": "assistant", "content": [{"text": greeting_text}]}, + # Time query with tool use/result pair + {"role": "user", "content": [{"text": "What's the current time?"}]}, + { + "role": "assistant", + "content": [{"toolUse": {"toolUseId": "time_001", "name": "get_current_time", "input": {}}}], + }, + { + "role": "user", + "content": [ + { + "toolResult": { + "toolUseId": "time_001", + "content": [{"text": "2024-01-15 14:30:00"}], + "status": "success", + } + } + ], + }, + {"role": "assistant", "content": [{"text": "The current time is 2024-01-15 14:30:00."}]}, + # Math calculation with tool use/result pair + {"role": "user", "content": [{"text": "What's 25 + 37?"}]}, + { + "role": "assistant", + "content": [{"toolUse": {"toolUseId": "calc_001", "name": "calculate_sum", "input": {"a": 25, "b": 37}}}], + }, + { + "role": "user", + "content": [{"toolResult": {"toolUseId": "calc_001", "content": [{"text": "62"}], "status": "success"}}], + }, + {"role": "assistant", "content": [{"text": "25 + 37 = 62"}]}, + # Weather query with tool use/result pair + {"role": "user", "content": [{"text": "What's the weather like in San Francisco?"}]}, + { + "role": "assistant", + "content": [ + {"toolUse": {"toolUseId": "weather_001", "name": "get_weather", "input": {"city": "San Francisco"}}} + ], + }, + { + "role": "user", + "content": [ + { + "toolResult": { + "toolUseId": "weather_001", + "content": [{"text": "The weather in San Francisco is sunny and 72Β°F"}], + "status": "success", + } + } + ], + }, + {"role": "assistant", "content": [{"text": weather_response}]}, + ] + + # Add all the mocked conversation messages to avoid real API calls + agent.messages.extend(tool_conversation_data) + + # Force summarization + agent.conversation_manager.reduce_context(agent) + + # 
Verify tool pairs are still balanced after summarization + post_summary_tool_use_count = 0 + post_summary_tool_result_count = 0 + + for message in agent.messages: + for content in message.get("content", []): + if "toolUse" in content: + post_summary_tool_use_count += 1 + if "toolResult" in content: + post_summary_tool_result_count += 1 + + # Tool uses and results should be balanced (no orphaned tools) + assert post_summary_tool_use_count == post_summary_tool_result_count, ( + "Tool use and tool result counts should be balanced after summarization" + ) + + # Agent should still be able to use tools after summarization + agent("Calculate 15 + 28 for me.") + + # Should have triggered the calculate_sum tool + found_calculation = False + for message in agent.messages[-2:]: # Check recent messages + for content in message.get("content", []): + if "toolResult" in content and "43" in str(content): # 15 + 28 = 43 + found_calculation = True + break + + assert found_calculation, "Tool should still work after summarization" + + +@pytest.mark.skipif("ANTHROPIC_API_KEY" not in os.environ, reason="ANTHROPIC_API_KEY environment variable missing") +def test_dedicated_summarization_agent(model, summarization_model): + """Test that a dedicated summarization agent works correctly.""" + # Create a dedicated summarization agent + summarization_agent = Agent( + model=summarization_model, + system_prompt="You are a conversation summarizer. Create concise, structured summaries.", + load_tools_from_directory=False, + ) + + # Create main agent with dedicated summarization agent + agent = Agent( + model=model, + conversation_manager=SummarizingConversationManager( + summary_ratio=0.5, + preserve_recent_messages=2, + summarization_agent=summarization_agent, + ), + load_tools_from_directory=False, + ) + + # Mock conversation data for space exploration topic + space_intro_response = """ + Space exploration has been one of humanity's greatest achievements, beginning with early + satellite launches in the 1950s and progressing to human spaceflight, moon landings, and now + commercial space ventures. + """.strip() + + space_milestones_response = """ + Key milestones include Sputnik 1 (1957), Yuri Gagarin's first human spaceflight (1961), + the Apollo 11 moon landing (1969), the Space Shuttle program, and the International Space + Station construction. + """.strip() + + apollo_missions_response = """ + The Apollo program was NASA's lunar exploration program from 1961-1975. Apollo 11 achieved + the first moon landing in 1969 with Neil Armstrong and Buzz Aldrin, followed by five more + successful lunar missions through Apollo 17. + """.strip() + + spacex_response = """ + SpaceX has revolutionized space travel with reusable rockets, reducing launch costs dramatically. + They've achieved crew transportation to the ISS, satellite deployments, and are developing + Starship for Mars missions. 
+ """.strip() + + conversation_pairs = [ + ("I'm interested in learning about space exploration.", space_intro_response), + ("What were the key milestones in space exploration?", space_milestones_response), + ("Tell me about the Apollo missions.", apollo_missions_response), + ("What about modern space exploration with SpaceX?", spacex_response), + ] + + # Manually build the conversation history to avoid real API calls + for user_input, assistant_response in conversation_pairs: + agent.messages.append({"role": "user", "content": [{"text": user_input}]}) + agent.messages.append({"role": "assistant", "content": [{"text": assistant_response}]}) + + # Force summarization + original_length = len(agent.messages) + agent.conversation_manager.reduce_context(agent) + + # Verify summarization occurred + assert len(agent.messages) < original_length + + # Get the summary message + summary_message = agent.messages[0] + assert summary_message["role"] == "assistant" + + # Extract summary text + summary_text = None + for content in summary_message["content"]: + if "text" in content: + summary_text = content["text"] + break + + assert summary_text diff --git a/tests/conftest.py b/tests/conftest.py index 4f0b5b21..cd18b698 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -24,6 +24,8 @@ def moto_env(monkeypatch): monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "test") monkeypatch.setenv("AWS_SECURITY_TOKEN", "test") monkeypatch.setenv("AWS_DEFAULT_REGION", "us-west-2") + monkeypatch.delenv("OTEL_EXPORTER_OTLP_ENDPOINT", raising=False) + monkeypatch.delenv("OTEL_EXPORTER_OTLP_HEADERS", raising=False) @pytest.fixture diff --git a/tests/multiagent/__init__.py b/tests/multiagent/__init__.py new file mode 100644 index 00000000..b43bae53 --- /dev/null +++ b/tests/multiagent/__init__.py @@ -0,0 +1 @@ +"""Tests for the multiagent module.""" diff --git a/tests/multiagent/a2a/__init__.py b/tests/multiagent/a2a/__init__.py new file mode 100644 index 00000000..eb5487d9 --- /dev/null +++ b/tests/multiagent/a2a/__init__.py @@ -0,0 +1 @@ +"""Tests for the A2A module.""" diff --git a/tests/multiagent/a2a/conftest.py b/tests/multiagent/a2a/conftest.py new file mode 100644 index 00000000..558a4594 --- /dev/null +++ b/tests/multiagent/a2a/conftest.py @@ -0,0 +1,41 @@ +"""Common fixtures for A2A module tests.""" + +from unittest.mock import AsyncMock, MagicMock + +import pytest +from a2a.server.agent_execution import RequestContext +from a2a.server.events import EventQueue + +from strands.agent.agent import Agent as SAAgent +from strands.agent.agent_result import AgentResult as SAAgentResult + + +@pytest.fixture +def mock_strands_agent(): + """Create a mock Strands Agent for testing.""" + agent = MagicMock(spec=SAAgent) + agent.name = "Test Agent" + agent.description = "A test agent for unit testing" + + # Setup default response + mock_result = MagicMock(spec=SAAgentResult) + mock_result.message = {"content": [{"text": "Test response"}]} + agent.return_value = mock_result + + return agent + + +@pytest.fixture +def mock_request_context(): + """Create a mock RequestContext for testing.""" + context = MagicMock(spec=RequestContext) + context.get_user_input.return_value = "Test input" + return context + + +@pytest.fixture +def mock_event_queue(): + """Create a mock EventQueue for testing.""" + queue = MagicMock(spec=EventQueue) + queue.enqueue_event = AsyncMock() + return queue diff --git a/tests/multiagent/a2a/test_agent.py b/tests/multiagent/a2a/test_agent.py new file mode 100644 index 00000000..5558c2af --- /dev/null +++ 
b/tests/multiagent/a2a/test_agent.py @@ -0,0 +1,165 @@ +"""Tests for the A2AAgent class.""" + +from unittest.mock import patch + +import pytest +from a2a.types import AgentCapabilities, AgentCard +from fastapi import FastAPI +from starlette.applications import Starlette + +from strands.multiagent.a2a.agent import A2AAgent + + +def test_a2a_agent_initialization(mock_strands_agent): + """Test that A2AAgent initializes correctly with default values.""" + a2a_agent = A2AAgent(mock_strands_agent) + + assert a2a_agent.strands_agent == mock_strands_agent + assert a2a_agent.name == "Test Agent" + assert a2a_agent.description == "A test agent for unit testing" + assert a2a_agent.host == "0.0.0.0" + assert a2a_agent.port == 9000 + assert a2a_agent.http_url == "http://0.0.0.0:9000/" + assert a2a_agent.version == "0.0.1" + assert isinstance(a2a_agent.capabilities, AgentCapabilities) + + +def test_a2a_agent_initialization_with_custom_values(mock_strands_agent): + """Test that A2AAgent initializes correctly with custom values.""" + a2a_agent = A2AAgent( + mock_strands_agent, + host="127.0.0.1", + port=8080, + version="1.0.0", + ) + + assert a2a_agent.host == "127.0.0.1" + assert a2a_agent.port == 8080 + assert a2a_agent.http_url == "http://127.0.0.1:8080/" + assert a2a_agent.version == "1.0.0" + + +def test_public_agent_card(mock_strands_agent): + """Test that public_agent_card returns a valid AgentCard.""" + a2a_agent = A2AAgent(mock_strands_agent) + + card = a2a_agent.public_agent_card + + assert isinstance(card, AgentCard) + assert card.name == "Test Agent" + assert card.description == "A test agent for unit testing" + assert card.url == "http://0.0.0.0:9000/" + assert card.version == "0.0.1" + assert card.defaultInputModes == ["text"] + assert card.defaultOutputModes == ["text"] + assert card.skills == [] + assert card.capabilities == a2a_agent.capabilities + + +def test_public_agent_card_with_missing_name(mock_strands_agent): + """Test that public_agent_card raises ValueError when name is missing.""" + mock_strands_agent.name = "" + a2a_agent = A2AAgent(mock_strands_agent) + + with pytest.raises(ValueError, match="A2A agent name cannot be None or empty"): + _ = a2a_agent.public_agent_card + + +def test_public_agent_card_with_missing_description(mock_strands_agent): + """Test that public_agent_card raises ValueError when description is missing.""" + mock_strands_agent.description = "" + a2a_agent = A2AAgent(mock_strands_agent) + + with pytest.raises(ValueError, match="A2A agent description cannot be None or empty"): + _ = a2a_agent.public_agent_card + + +def test_agent_skills(mock_strands_agent): + """Test that agent_skills returns an empty list (current implementation).""" + a2a_agent = A2AAgent(mock_strands_agent) + + skills = a2a_agent.agent_skills + + assert isinstance(skills, list) + assert len(skills) == 0 + + +def test_to_starlette_app(mock_strands_agent): + """Test that to_starlette_app returns a Starlette application.""" + a2a_agent = A2AAgent(mock_strands_agent) + + app = a2a_agent.to_starlette_app() + + assert isinstance(app, Starlette) + + +def test_to_fastapi_app(mock_strands_agent): + """Test that to_fastapi_app returns a FastAPI application.""" + a2a_agent = A2AAgent(mock_strands_agent) + + app = a2a_agent.to_fastapi_app() + + assert isinstance(app, FastAPI) + + +@patch("uvicorn.run") +def test_serve_with_starlette(mock_run, mock_strands_agent): + """Test that serve starts a Starlette server by default.""" + a2a_agent = A2AAgent(mock_strands_agent) + + a2a_agent.serve() + + mock_run.assert_called_once() + args, kwargs = mock_run.call_args + assert isinstance(args[0], Starlette) + assert kwargs["host"] == "0.0.0.0" + assert kwargs["port"] == 9000 + + +@patch("uvicorn.run") +def test_serve_with_fastapi(mock_run, mock_strands_agent): + """Test that serve starts a FastAPI server when specified.""" + a2a_agent = A2AAgent(mock_strands_agent) + + a2a_agent.serve(app_type="fastapi") + + mock_run.assert_called_once() + args, kwargs = mock_run.call_args + assert isinstance(args[0], FastAPI) + assert kwargs["host"] == "0.0.0.0" + assert kwargs["port"] == 9000 + + +@patch("uvicorn.run") +def test_serve_with_custom_kwargs(mock_run, mock_strands_agent): + """Test that serve passes additional kwargs to uvicorn.run.""" + a2a_agent = A2AAgent(mock_strands_agent) + + a2a_agent.serve(log_level="debug", reload=True) + + mock_run.assert_called_once() + _, kwargs = mock_run.call_args + assert kwargs["log_level"] == "debug" + assert kwargs["reload"] is True + + +@patch("uvicorn.run", side_effect=KeyboardInterrupt) +def test_serve_handles_keyboard_interrupt(mock_run, mock_strands_agent, caplog): + """Test that serve handles KeyboardInterrupt gracefully.""" + a2a_agent = A2AAgent(mock_strands_agent) + + a2a_agent.serve() + + assert "Strands A2A server shutdown requested (KeyboardInterrupt)" in caplog.text + assert "Strands A2A server has shutdown" in caplog.text + + +@patch("uvicorn.run", side_effect=Exception("Test exception")) +def test_serve_handles_general_exception(mock_run, mock_strands_agent, caplog): + """Test that serve handles general exceptions gracefully.""" + a2a_agent = A2AAgent(mock_strands_agent) + + a2a_agent.serve() + + assert "Strands A2A server encountered exception" in caplog.text + assert "Strands A2A server has shutdown" in caplog.text diff --git a/tests/multiagent/a2a/test_executor.py b/tests/multiagent/a2a/test_executor.py new file mode 100644 index 00000000..2ac9bed9 --- /dev/null +++ b/tests/multiagent/a2a/test_executor.py @@ -0,0 +1,118 @@ +"""Tests for the StrandsA2AExecutor class.""" + +from unittest.mock import MagicMock + +import pytest +from a2a.types import UnsupportedOperationError +from a2a.utils.errors import ServerError + +from strands.agent.agent_result import AgentResult as SAAgentResult +from strands.multiagent.a2a.executor import StrandsA2AExecutor + + +def test_executor_initialization(mock_strands_agent): + """Test that StrandsA2AExecutor initializes correctly.""" + executor = StrandsA2AExecutor(mock_strands_agent) + + assert executor.agent == mock_strands_agent + + +@pytest.mark.asyncio +async def test_execute_with_text_response(mock_strands_agent, mock_request_context, mock_event_queue): + """Test that execute processes text responses correctly.""" + # Setup mock agent response + mock_result = MagicMock(spec=SAAgentResult) + mock_result.message = {"content": [{"text": "Test response"}]} + mock_strands_agent.return_value = mock_result + + # Create executor and call execute + executor = StrandsA2AExecutor(mock_strands_agent) + await executor.execute(mock_request_context, mock_event_queue) + + # Verify agent was called with correct input + mock_strands_agent.assert_called_once_with("Test input") + + # Verify event was enqueued + mock_event_queue.enqueue_event.assert_called_once() + args, _ = mock_event_queue.enqueue_event.call_args + event = args[0] + assert event.parts[0].root.text == "Test response" + + +@pytest.mark.asyncio +async def test_execute_with_multiple_text_blocks(mock_strands_agent, mock_request_context, mock_event_queue): + """Test
that execute processes multiple text blocks correctly.""" + # Setup mock agent response with multiple text blocks + mock_result = MagicMock(spec=SAAgentResult) + mock_result.message = {"content": [{"text": "First response"}, {"text": "Second response"}]} + mock_strands_agent.return_value = mock_result + + # Create executor and call execute + executor = StrandsA2AExecutor(mock_strands_agent) + await executor.execute(mock_request_context, mock_event_queue) + + # Verify agent was called with correct input + mock_strands_agent.assert_called_once_with("Test input") + + # Verify events were enqueued + assert mock_event_queue.enqueue_event.call_count == 2 + + # Check first event + args1, _ = mock_event_queue.enqueue_event.call_args_list[0] + event1 = args1[0] + assert event1.parts[0].root.text == "First response" + + # Check second event + args2, _ = mock_event_queue.enqueue_event.call_args_list[1] + event2 = args2[0] + assert event2.parts[0].root.text == "Second response" + + +@pytest.mark.asyncio +async def test_execute_with_empty_response(mock_strands_agent, mock_request_context, mock_event_queue): + """Test that execute handles empty responses correctly.""" + # Setup mock agent response with empty content + mock_result = MagicMock(spec=SAAgentResult) + mock_result.message = {"content": []} + mock_strands_agent.return_value = mock_result + + # Create executor and call execute + executor = StrandsA2AExecutor(mock_strands_agent) + await executor.execute(mock_request_context, mock_event_queue) + + # Verify agent was called with correct input + mock_strands_agent.assert_called_once_with("Test input") + + # Verify no events were enqueued + mock_event_queue.enqueue_event.assert_not_called() + + +@pytest.mark.asyncio +async def test_execute_with_no_message(mock_strands_agent, mock_request_context, mock_event_queue): + """Test that execute handles responses with no message correctly.""" + # Setup mock agent response with no message + mock_result = MagicMock(spec=SAAgentResult) + mock_result.message = None + mock_strands_agent.return_value = mock_result + + # Create executor and call execute + executor = StrandsA2AExecutor(mock_strands_agent) + await executor.execute(mock_request_context, mock_event_queue) + + # Verify agent was called with correct input + mock_strands_agent.assert_called_once_with("Test input") + + # Verify no events were enqueued + mock_event_queue.enqueue_event.assert_not_called() + + +@pytest.mark.asyncio +async def test_cancel_raises_unsupported_operation_error(mock_strands_agent, mock_request_context, mock_event_queue): + """Test that cancel raises UnsupportedOperationError.""" + executor = StrandsA2AExecutor(mock_strands_agent) + + with pytest.raises(ServerError) as excinfo: + await executor.cancel(mock_request_context, mock_event_queue) + + # Verify the error is a ServerError containing an UnsupportedOperationError + assert isinstance(excinfo.value.error, UnsupportedOperationError) diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 5d2ffb23..85d17544 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -7,9 +7,11 @@ from time import sleep import pytest +from pydantic import BaseModel import strands -from strands.agent.agent import Agent +from strands import Agent +from strands.agent import AgentResult from strands.agent.conversation_manager.null_conversation_manager import NullConversationManager from strands.agent.conversation_manager.sliding_window_conversation_manager import 
SlidingWindowConversationManager from strands.handlers.callback_handler import PrintingCallbackHandler, null_callback_handler @@ -317,7 +319,7 @@ def test_agent__call__( ) callback_handler.assert_called() - conversation_manager_spy.apply_management.assert_called_with(agent.messages) + conversation_manager_spy.apply_management.assert_called_with(agent) def test_agent__call__passes_kwargs(mock_model, system_prompt, callback_handler, agent, tool, mock_event_loop_cycle): @@ -337,17 +339,47 @@ def test_agent__call__passes_kwargs(mock_model, system_prompt, callback_handler, ], ] + override_system_prompt = "Override system prompt" + override_model = unittest.mock.Mock() + override_tool_execution_handler = unittest.mock.Mock() + override_event_loop_metrics = unittest.mock.Mock() + override_callback_handler = unittest.mock.Mock() + override_tool_handler = unittest.mock.Mock() + override_messages = [{"role": "user", "content": [{"text": "override msg"}]}] + override_tool_config = {"test": "config"} + def check_kwargs(some_value, **kwargs): assert some_value == "a_value" assert kwargs is not None + assert kwargs["system_prompt"] == override_system_prompt + assert kwargs["model"] == override_model + assert kwargs["tool_execution_handler"] == override_tool_execution_handler + assert kwargs["event_loop_metrics"] == override_event_loop_metrics + assert kwargs["callback_handler"] == override_callback_handler + assert kwargs["tool_handler"] == override_tool_handler + assert kwargs["messages"] == override_messages + assert kwargs["tool_config"] == override_tool_config + assert kwargs["agent"] == agent # Return expected values from event_loop_cycle return "stop", {"role": "assistant", "content": [{"text": "Response"}]}, {}, {} mock_event_loop_cycle.side_effect = check_kwargs - agent("test message", some_value="a_value") - assert mock_event_loop_cycle.call_count == 1 + agent( + "test message", + some_value="a_value", + system_prompt=override_system_prompt, + model=override_model, + tool_execution_handler=override_tool_execution_handler, + event_loop_metrics=override_event_loop_metrics, + callback_handler=override_callback_handler, + tool_handler=override_tool_handler, + messages=override_messages, + tool_config=override_tool_config, + ) + + mock_event_loop_cycle.assert_called_once() def test_agent__call__retry_with_reduced_context(mock_model, agent, tool): @@ -407,7 +439,7 @@ def test_agent__call__retry_with_reduced_context(mock_model, agent, tool): def test_agent__call__always_sliding_window_conversation_manager_doesnt_infinite_loop(mock_model, agent, tool): - conversation_manager = SlidingWindowConversationManager(window_size=500) + conversation_manager = SlidingWindowConversationManager(window_size=500, should_truncate_results=False) conversation_manager_spy = unittest.mock.Mock(wraps=conversation_manager) agent.conversation_manager = conversation_manager_spy @@ -428,7 +460,7 @@ def test_agent__call__always_sliding_window_conversation_manager_doesnt_infinite with pytest.raises(ContextWindowOverflowException): agent("Test!") - assert conversation_manager_spy.reduce_context.call_count == 251 + assert conversation_manager_spy.reduce_context.call_count > 0 assert conversation_manager_spy.apply_management.call_count == 1 @@ -453,10 +485,43 @@ def test_agent__call__null_conversation_window_manager__doesnt_infinite_loop(moc agent("Test!") +def test_agent__call__tool_truncation_doesnt_infinite_loop(mock_model, agent): + messages: Messages = [ + {"role": "user", "content": [{"text": "Hello!"}]}, + { + "role": 
"assistant", + "content": [{"toolUse": {"toolUseId": "123", "input": {"hello": "world"}, "name": "test"}}], + }, + { + "role": "user", + "content": [ + {"toolResult": {"toolUseId": "123", "content": [{"text": "Some large input!"}], "status": "success"}} + ], + }, + ] + agent.messages = messages + + mock_model.mock_converse.side_effect = ContextWindowOverflowException( + RuntimeError("Input is too long for requested model") + ) + + with pytest.raises(ContextWindowOverflowException): + agent("Test!") + + def test_agent__call__retry_with_overwritten_tool(mock_model, agent, tool): conversation_manager_spy = unittest.mock.Mock(wraps=agent.conversation_manager) agent.conversation_manager = conversation_manager_spy + messages: Messages = [ + {"role": "user", "content": [{"text": "Hello!"}]}, + { + "role": "assistant", + "content": [{"text": "Hi!"}], + }, + ] + agent.messages = messages + mock_model.mock_converse.side_effect = [ [ { @@ -473,6 +538,9 @@ def test_agent__call__retry_with_overwritten_tool(mock_model, agent, tool): {"contentBlockStop": {}}, {"messageStop": {"stopReason": "tool_use"}}, ], + # Will truncate the tool result + ContextWindowOverflowException(RuntimeError("Input is too long for requested model")), + # Will reduce the context ContextWindowOverflowException(RuntimeError("Input is too long for requested model")), [], ] @@ -507,7 +575,7 @@ def test_agent__call__retry_with_overwritten_tool(mock_model, agent, tool): unittest.mock.ANY, ) - conversation_manager_spy.reduce_context.assert_not_called() + assert conversation_manager_spy.reduce_context.call_count == 2 assert conversation_manager_spy.apply_management.call_count == 1 @@ -552,7 +620,7 @@ def test_agent_tool(mock_randint, agent): } assert tru_result == exp_result - conversation_manager_spy.apply_management.assert_called_with(agent.messages) + conversation_manager_spy.apply_management.assert_called_with(agent) def test_agent_tool_user_message_override(agent): @@ -566,7 +634,7 @@ def test_agent_tool_user_message_override(agent): }, { "text": ( - "agent.tool_decorated direct tool call\n" + "agent.tool.tool_decorated direct tool call.\n" "Input parameters: " '{"random_string": "abcdEfghI123", "user_message_override": "test override"}\n' ), @@ -643,6 +711,46 @@ def function(system_prompt: str) -> str: ) +def test_agent_tool_with_name_normalization(agent, tool_registry, mock_randint): + agent.tool_handler = unittest.mock.Mock() + + tool_name = "system-prompter" + + @strands.tools.tool(name=tool_name) + def function(system_prompt: str) -> str: + return system_prompt + + tool = strands.tools.tools.FunctionTool(function) + agent.tool_registry.register_tool(tool) + + mock_randint.return_value = 1 + + agent.tool.system_prompter(system_prompt="tool prompt") + + # Verify the correct tool was invoked + assert agent.tool_handler.process.call_count == 1 + tool_call = agent.tool_handler.process.call_args.kwargs.get("tool") + + assert tool_call == { + # Note that the tool-use uses the "python safe" name + "toolUseId": "tooluse_system_prompter_1", + # But the name of the tool is the one in the registry + "name": tool_name, + "input": {"system_prompt": "tool prompt"}, + } + + +def test_agent_tool_with_no_normalized_match(agent, tool_registry, mock_randint): + agent.tool_handler = unittest.mock.Mock() + + mock_randint.return_value = 1 + + with pytest.raises(AttributeError) as err: + agent.tool.system_prompter_1(system_prompt="tool prompt") + + assert str(err.value) == "Tool 'system_prompter_1' not found" + + def 
test_agent_with_none_callback_handler_prints_nothing(): agent = Agent() @@ -655,10 +763,64 @@ def test_agent_with_callback_handler_none_uses_null_handler(): assert agent.callback_handler == null_callback_handler +def test_agent_callback_handler_not_provided_creates_new_instances(): + """Test that when callback_handler is not provided, new PrintingCallbackHandler instances are created.""" + # Create two agents without providing callback_handler + agent1 = Agent() + agent2 = Agent() + + # Both should have PrintingCallbackHandler instances + assert isinstance(agent1.callback_handler, PrintingCallbackHandler) + assert isinstance(agent2.callback_handler, PrintingCallbackHandler) + + # But they should be different object instances + assert agent1.callback_handler is not agent2.callback_handler + + +def test_agent_callback_handler_explicit_none_uses_null_handler(): + """Test that when callback_handler is explicitly set to None, null_callback_handler is used.""" + agent = Agent(callback_handler=None) + + # Should use null_callback_handler + assert agent.callback_handler is null_callback_handler + + +def test_agent_callback_handler_custom_handler_used(): + """Test that when a custom callback_handler is provided, it is used.""" + custom_handler = unittest.mock.Mock() + agent = Agent(callback_handler=custom_handler) + + # Should use the provided custom handler + assert agent.callback_handler is custom_handler + + +# mock the User(name='Jane Doe', age=30, email='jane@doe.com') +class User(BaseModel): + """A user of the system.""" + + name: str + age: int + email: str + + +def test_agent_method_structured_output(agent): + # Mock the structured_output method on the model + expected_user = User(name="Jane Doe", age=30, email="jane@doe.com") + agent.model.structured_output = unittest.mock.Mock(return_value=expected_user) + + prompt = "Jane Doe is 30 years old and her email is jane@doe.com" + + result = agent.structured_output(User, prompt) + assert result == expected_user + + # Verify the model's structured_output was called with correct arguments + agent.model.structured_output.assert_called_once_with( + User, [{"role": "user", "content": [{"text": prompt}]}], agent.callback_handler + ) + + @pytest.mark.asyncio async def test_stream_async_returns_all_events(mock_event_loop_cycle): - mock_event_loop_cycle.side_effect = ValueError("Test exception") - agent = Agent() # Define the side effect to simulate callback handler being called multiple times @@ -922,6 +1084,52 @@ def test_agent_call_creates_and_ends_span_on_success(mock_get_tracer, mock_model mock_tracer.end_agent_span.assert_called_once_with(span=mock_span, response=result) +@pytest.mark.asyncio +@unittest.mock.patch("strands.agent.agent.get_tracer") +async def test_agent_stream_async_creates_and_ends_span_on_success(mock_get_tracer, mock_event_loop_cycle): + """Test that stream_async creates and ends a span when the call succeeds.""" + # Setup mock tracer and span + mock_tracer = unittest.mock.MagicMock() + mock_span = unittest.mock.MagicMock() + mock_tracer.start_agent_span.return_value = mock_span + mock_get_tracer.return_value = mock_tracer + + # Define the side effect to simulate callback handler being called multiple times + def call_callback_handler(*args, **kwargs): + # Extract the callback handler from kwargs + callback_handler = kwargs.get("callback_handler") + # Call the callback handler with different data values + callback_handler(data="First chunk") + callback_handler(data="Second chunk") + callback_handler(data="Final chunk", 
complete=True) + # Return expected values from event_loop_cycle + return "stop", {"role": "assistant", "content": [{"text": "Agent Response"}]}, {}, {} + + mock_event_loop_cycle.side_effect = call_callback_handler + + # Create agent and make a call + agent = Agent(model=mock_model) + iterator = agent.stream_async("test prompt") + async for _event in iterator: + pass # NoOp + + # Verify span was created + mock_tracer.start_agent_span.assert_called_once_with( + prompt="test prompt", + model_id=unittest.mock.ANY, + tools=agent.tool_names, + system_prompt=agent.system_prompt, + custom_trace_attributes=agent.trace_attributes, + ) + + expected_response = AgentResult( + stop_reason="stop", message={"role": "assistant", "content": [{"text": "Agent Response"}]}, metrics={}, state={} + ) + + # Verify span was ended with the result + mock_tracer.end_agent_span.assert_called_once_with(span=mock_span, response=expected_response) + + @unittest.mock.patch("strands.agent.agent.get_tracer") def test_agent_call_creates_and_ends_span_on_exception(mock_get_tracer, mock_model): """Test that __call__ creates and ends a span when an exception occurs.""" @@ -955,6 +1163,42 @@ def test_agent_call_creates_and_ends_span_on_exception(mock_get_tracer, mock_mod mock_tracer.end_agent_span.assert_called_once_with(span=mock_span, error=test_exception) +@pytest.mark.asyncio +@unittest.mock.patch("strands.agent.agent.get_tracer") +async def test_agent_stream_async_creates_and_ends_span_on_exception(mock_get_tracer, mock_model): + """Test that stream_async creates and ends a span when an exception occurs.""" + # Setup mock tracer and span + mock_tracer = unittest.mock.MagicMock() + mock_span = unittest.mock.MagicMock() + mock_tracer.start_agent_span.return_value = mock_span + mock_get_tracer.return_value = mock_tracer + + # Define the side effect to simulate the model raising an exception + test_exception = ValueError("Test exception") + mock_model.mock_converse.side_effect = test_exception + + # Create agent and make a call + agent = Agent(model=mock_model) + + # Call the agent and catch the exception + with pytest.raises(ValueError): + iterator = agent.stream_async("test prompt") + async for _event in iterator: + pass # NoOp + + # Verify span was created + mock_tracer.start_agent_span.assert_called_once_with( + prompt="test prompt", + model_id=unittest.mock.ANY, + tools=agent.tool_names, + system_prompt=agent.system_prompt, + custom_trace_attributes=agent.trace_attributes, + ) + + # Verify span was ended with the exception + mock_tracer.end_agent_span.assert_called_once_with(span=mock_span, error=test_exception) + + @unittest.mock.patch("strands.agent.agent.get_tracer") def test_event_loop_cycle_includes_parent_span(mock_get_tracer, mock_event_loop_cycle, mock_model): """Test that event_loop_cycle is called with the parent span.""" diff --git a/tests/strands/agent/test_conversation_manager.py b/tests/strands/agent/test_conversation_manager.py index 2f6ee77d..7d43199e 100644 --- a/tests/strands/agent/test_conversation_manager.py +++ b/tests/strands/agent/test_conversation_manager.py @@ -1,6 +1,7 @@ import pytest import strands +from strands.agent.agent import Agent from strands.types.exceptions import ContextWindowOverflowException @@ -24,6 +25,7 @@ def test_is_assistant_message(role, exp_result): def conversation_manager(request): params = { "window_size": 2, + "should_truncate_results": False, } if hasattr(request, "param"): params.update(request.param) @@ -111,41 +113,7 @@ def conversation_manager(request): {"role":
"assistant", "content": [{"text": "Second response"}]}, ], ), - # 7 - Message count above max window size - Remove dangling tool uses and tool results - ( - {"window_size": 1}, - [ - {"role": "user", "content": [{"text": "First message"}]}, - {"role": "assistant", "content": [{"toolUse": {"toolUseId": "321", "name": "tool1", "input": {}}}]}, - { - "role": "user", - "content": [ - {"toolResult": {"toolUseId": "123", "content": [{"text": "Hello!"}], "status": "success"}} - ], - }, - ], - [ - { - "role": "user", - "content": [{"text": "\nTool Result Text Content: Hello!\nTool Result Status: success"}], - }, - ], - ), - # 8 - Message count above max window size - Remove multiple tool use/tool result pairs - ( - {"window_size": 1}, - [ - {"role": "user", "content": [{"toolResult": {"toolUseId": "123", "content": [], "status": "success"}}]}, - {"role": "assistant", "content": [{"toolUse": {"toolUseId": "123", "name": "tool1", "input": {}}}]}, - {"role": "user", "content": [{"toolResult": {"toolUseId": "456", "content": [], "status": "success"}}]}, - {"role": "assistant", "content": [{"toolUse": {"toolUseId": "456", "name": "tool1", "input": {}}}]}, - {"role": "user", "content": [{"toolResult": {"toolUseId": "789", "content": [], "status": "success"}}]}, - ], - [ - {"role": "user", "content": [{"text": "Tool Result Status: success"}]}, - ], - ), - # 9 - Message count above max window size - Preserve tool use/tool result pairs + # 7 - Message count above max window size - Preserve tool use/tool result pairs ( {"window_size": 2}, [ @@ -158,7 +126,7 @@ def conversation_manager(request): {"role": "user", "content": [{"toolResult": {"toolUseId": "456", "content": [], "status": "success"}}]}, ], ), - # 10 - Test sliding window behavior - preserve tool use/result pairs across cut boundary + # 8 - Test sliding window behavior - preserve tool use/result pairs across cut boundary ( {"window_size": 3}, [ @@ -173,7 +141,7 @@ def conversation_manager(request): {"role": "assistant", "content": [{"text": "Response after tool use"}]}, ], ), - # 11 - Test sliding window with multiple tool pairs that need preservation + # 9 - Test sliding window with multiple tool pairs that need preservation ( {"window_size": 4}, [ @@ -185,7 +153,6 @@ def conversation_manager(request): {"role": "assistant", "content": [{"text": "Final response"}]}, ], [ - {"role": "user", "content": [{"text": "Tool Result Status: success"}]}, {"role": "assistant", "content": [{"toolUse": {"toolUseId": "456", "name": "tool2", "input": {}}}]}, {"role": "user", "content": [{"toolResult": {"toolUseId": "456", "content": [], "status": "success"}}]}, {"role": "assistant", "content": [{"text": "Final response"}]}, @@ -195,7 +162,57 @@ def conversation_manager(request): indirect=["conversation_manager"], ) def test_apply_management(conversation_manager, messages, expected_messages): - conversation_manager.apply_management(messages) + test_agent = Agent(messages=messages) + conversation_manager.apply_management(test_agent) + + assert messages == expected_messages + + +def test_sliding_window_conversation_manager_with_untrimmable_history_raises_context_window_overflow_exception(): + manager = strands.agent.conversation_manager.SlidingWindowConversationManager(1, False) + messages = [ + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "456", "name": "tool1", "input": {}}}]}, + {"role": "user", "content": [{"toolResult": {"toolUseId": "789", "content": [], "status": "success"}}]}, + ] + original_messages = messages.copy() + test_agent = 
Agent(messages=messages) + + with pytest.raises(ContextWindowOverflowException): + manager.apply_management(test_agent) + + assert messages == original_messages + + +def test_sliding_window_conversation_manager_with_tool_results_truncated(): + manager = strands.agent.conversation_manager.SlidingWindowConversationManager(1) + messages = [ + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "456", "name": "tool1", "input": {}}}]}, + { + "role": "user", + "content": [ + {"toolResult": {"toolUseId": "789", "content": [{"text": "large input"}], "status": "success"}} + ], + }, + ] + test_agent = Agent(messages=messages) + + manager.reduce_context(test_agent) + + expected_messages = [ + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "456", "name": "tool1", "input": {}}}]}, + { + "role": "user", + "content": [ + { + "toolResult": { + "toolUseId": "789", + "content": [{"text": "The tool result was too large!"}], + "status": "error", + } + } + ], + }, + ] assert messages == expected_messages @@ -208,8 +225,9 @@ def test_null_conversation_manager_reduce_context_raises_context_window_overflow {"role": "assistant", "content": [{"text": "Hi there"}]}, ] original_messages = messages.copy() + test_agent = Agent(messages=messages) - manager.apply_management(messages) + manager.apply_management(test_agent) with pytest.raises(ContextWindowOverflowException): manager.reduce_context(messages) @@ -225,8 +243,9 @@ def test_null_conversation_manager_reduce_context_with_exception_raises_same_exc {"role": "assistant", "content": [{"text": "Hi there"}]}, ] original_messages = messages.copy() + test_agent = Agent(messages=messages) - manager.apply_management(messages) + manager.apply_management(test_agent) with pytest.raises(RuntimeError): manager.reduce_context(messages, RuntimeError("test")) diff --git a/tests/strands/agent/test_summarizing_conversation_manager.py b/tests/strands/agent/test_summarizing_conversation_manager.py new file mode 100644 index 00000000..9952203e --- /dev/null +++ b/tests/strands/agent/test_summarizing_conversation_manager.py @@ -0,0 +1,566 @@ +from typing import TYPE_CHECKING, cast +from unittest.mock import Mock, patch + +import pytest + +from strands.agent.conversation_manager.summarizing_conversation_manager import SummarizingConversationManager +from strands.types.content import Messages +from strands.types.exceptions import ContextWindowOverflowException + +if TYPE_CHECKING: + from strands.agent.agent import Agent + + +class MockAgent: + """Mock agent for testing summarization.""" + + def __init__(self, summary_response="This is a summary of the conversation."): + self.summary_response = summary_response + self.system_prompt = None + self.messages = [] + self.model = Mock() + self.call_tracker = Mock() + + def __call__(self, prompt): + """Mock agent call that returns a summary.""" + self.call_tracker(prompt) + result = Mock() + result.message = {"role": "assistant", "content": [{"text": self.summary_response}]} + return result + + +def create_mock_agent(summary_response="This is a summary of the conversation.") -> "Agent": + """Factory function that returns a properly typed MockAgent.""" + return cast("Agent", MockAgent(summary_response)) + + +@pytest.fixture +def mock_agent(): + """Fixture for mock agent.""" + return create_mock_agent() + + +@pytest.fixture +def summarizing_manager(): + """Fixture for summarizing conversation manager with default settings.""" + return SummarizingConversationManager( + summary_ratio=0.5, + preserve_recent_messages=2, + ) + + 
+def test_init_default_values(): + """Test initialization with default values.""" + manager = SummarizingConversationManager() + + assert manager.summarization_agent is None + assert manager.summary_ratio == 0.3 + assert manager.preserve_recent_messages == 10 + + +def test_init_clamps_summary_ratio(): + """Test that summary_ratio is clamped to valid range.""" + # Test lower bound + manager = SummarizingConversationManager(summary_ratio=0.05) + assert manager.summary_ratio == 0.1 + + # Test upper bound + manager = SummarizingConversationManager(summary_ratio=0.95) + assert manager.summary_ratio == 0.8 + + +def test_reduce_context_raises_when_no_agent(): + """Test that reduce_context raises exception when agent has no messages.""" + manager = SummarizingConversationManager() + + # Create a mock agent with no messages + mock_agent = Mock() + empty_messages: Messages = [] + mock_agent.messages = empty_messages + + with pytest.raises(ContextWindowOverflowException, match="insufficient messages for summarization"): + manager.reduce_context(mock_agent) + + +def test_reduce_context_with_summarization(summarizing_manager, mock_agent): + """Test reduce_context with summarization enabled.""" + test_messages: Messages = [ + {"role": "user", "content": [{"text": "Message 1"}]}, + {"role": "assistant", "content": [{"text": "Response 1"}]}, + {"role": "user", "content": [{"text": "Message 2"}]}, + {"role": "assistant", "content": [{"text": "Response 2"}]}, + {"role": "user", "content": [{"text": "Message 3"}]}, + {"role": "assistant", "content": [{"text": "Response 3"}]}, + ] + mock_agent.messages = test_messages + + summarizing_manager.reduce_context(mock_agent) + + # Should have: 1 summary message + 2 preserved recent messages + remaining from summarization + assert len(mock_agent.messages) == 4 + + # First message should be the summary + assert mock_agent.messages[0]["role"] == "assistant" + first_content = mock_agent.messages[0]["content"][0] + assert "text" in first_content and "This is a summary of the conversation." 
in first_content["text"] + + # Recent messages should be preserved + assert "Message 3" in str(mock_agent.messages[-2]["content"]) + assert "Response 3" in str(mock_agent.messages[-1]["content"]) + + +def test_reduce_context_too_few_messages_raises_exception(summarizing_manager, mock_agent): + """Test that reduce_context raises exception when there are too few messages to summarize effectively.""" + # Create a scenario where calculation results in 0 messages to summarize + manager = SummarizingConversationManager( + summary_ratio=0.1, # Very small ratio + preserve_recent_messages=5, # High preservation + ) + + insufficient_test_messages: Messages = [ + {"role": "user", "content": [{"text": "Message 1"}]}, + {"role": "assistant", "content": [{"text": "Response 1"}]}, + {"role": "user", "content": [{"text": "Message 2"}]}, + ] + mock_agent.messages = insufficient_test_messages # 3 messages, preserve_recent_messages=5, so nothing to summarize + + with pytest.raises(ContextWindowOverflowException, match="insufficient messages for summarization"): + manager.reduce_context(mock_agent) + + +def test_reduce_context_insufficient_messages_for_summarization(mock_agent): + """Test reduce_context when there aren't enough messages to summarize.""" + manager = SummarizingConversationManager( + summary_ratio=0.5, + preserve_recent_messages=3, + ) + + insufficient_messages: Messages = [ + {"role": "user", "content": [{"text": "Message 1"}]}, + {"role": "assistant", "content": [{"text": "Response 1"}]}, + {"role": "user", "content": [{"text": "Message 2"}]}, + ] + mock_agent.messages = insufficient_messages + + # This should raise an exception since there aren't enough messages to summarize + with pytest.raises(ContextWindowOverflowException, match="insufficient messages for summarization"): + manager.reduce_context(mock_agent) + + +def test_reduce_context_raises_on_summarization_failure(): + """Test that reduce_context raises exception when summarization fails.""" + # Create an agent that will fail + failing_agent = Mock() + failing_agent.side_effect = Exception("Agent failed") + failing_agent.system_prompt = None + failing_agent_messages: Messages = [ + {"role": "user", "content": [{"text": "Message 1"}]}, + {"role": "assistant", "content": [{"text": "Response 1"}]}, + {"role": "user", "content": [{"text": "Message 2"}]}, + {"role": "assistant", "content": [{"text": "Response 2"}]}, + ] + failing_agent.messages = failing_agent_messages + + manager = SummarizingConversationManager( + summary_ratio=0.5, + preserve_recent_messages=1, + ) + + with patch("strands.agent.conversation_manager.summarizing_conversation_manager.logger") as mock_logger: + with pytest.raises(Exception, match="Agent failed"): + manager.reduce_context(failing_agent) + + # Should log the error + mock_logger.error.assert_called_once() + + +def test_generate_summary(summarizing_manager, mock_agent): + """Test the _generate_summary method.""" + test_messages: Messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + {"role": "assistant", "content": [{"text": "Hi there"}]}, + ] + + summary = summarizing_manager._generate_summary(test_messages, mock_agent) + + summary_content = summary["content"][0] + assert "text" in summary_content and summary_content["text"] == "This is a summary of the conversation."
+ + +def test_generate_summary_with_tool_content(summarizing_manager, mock_agent): + """Test summary generation with tool use and results.""" + tool_messages: Messages = [ + {"role": "user", "content": [{"text": "Use a tool"}]}, + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "123", "name": "test_tool", "input": {}}}]}, + { + "role": "user", + "content": [ + {"toolResult": {"toolUseId": "123", "content": [{"text": "Tool output"}], "status": "success"}} + ], + }, + ] + + summary = summarizing_manager._generate_summary(tool_messages, mock_agent) + + summary_content = summary["content"][0] + assert "text" in summary_content and summary_content["text"] == "This is a summary of the conversation." + + +def test_generate_summary_raises_on_agent_failure(): + """Test that _generate_summary raises exception when agent fails.""" + failing_agent = Mock() + failing_agent.side_effect = Exception("Agent failed") + failing_agent.system_prompt = None + empty_failing_messages: Messages = [] + failing_agent.messages = empty_failing_messages + + manager = SummarizingConversationManager() + + messages: Messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + {"role": "assistant", "content": [{"text": "Hi there"}]}, + ] + + # Should raise the exception from the agent + with pytest.raises(Exception, match="Agent failed"): + manager._generate_summary(messages, failing_agent) + + +def test_adjust_split_point_for_tool_pairs(summarizing_manager): + """Test that the split point is adjusted to avoid breaking ToolUse/ToolResult pairs.""" + messages: Messages = [ + {"role": "user", "content": [{"text": "Message 1"}]}, + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "123", "name": "test_tool", "input": {}}}]}, + { + "role": "user", + "content": [ + { + "toolResult": { + "toolUseId": "123", + "content": [{"text": "Tool output"}], + "status": "success", + } + } + ], + }, + {"role": "assistant", "content": [{"text": "Response after tool"}]}, + ] + + # If we try to split at message 2 (the ToolResult), it should move forward to message 3 + adjusted_split = summarizing_manager._adjust_split_point_for_tool_pairs(messages, 2) + assert adjusted_split == 3 # Should move to after the ToolResult + + # If we try to split at message 3, it should be fine (no tool issues) + adjusted_split = summarizing_manager._adjust_split_point_for_tool_pairs(messages, 3) + assert adjusted_split == 3 + + # If we try to split at message 1 (toolUse with following toolResult), it should be valid + adjusted_split = summarizing_manager._adjust_split_point_for_tool_pairs(messages, 1) + assert adjusted_split == 1 # Should be valid because toolResult follows + + +def test_apply_management_no_op(summarizing_manager, mock_agent): + """Test apply_management does not modify messages (no-op behavior).""" + apply_test_messages: Messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + {"role": "assistant", "content": [{"text": "Hi"}]}, + {"role": "user", "content": [{"text": "More messages"}]}, + {"role": "assistant", "content": [{"text": "Even more"}]}, + ] + mock_agent.messages = apply_test_messages + original_messages = mock_agent.messages.copy() + + summarizing_manager.apply_management(mock_agent) + + # Should never modify messages - summarization only happens on context overflow + assert mock_agent.messages == original_messages + + +def test_init_with_custom_parameters(): + """Test initialization with custom parameters.""" + mock_agent = create_mock_agent() + + manager = SummarizingConversationManager( + 
summary_ratio=0.4, + preserve_recent_messages=5, + summarization_agent=mock_agent, + ) + assert manager.summary_ratio == 0.4 + assert manager.preserve_recent_messages == 5 + assert manager.summarization_agent == mock_agent + assert manager.summarization_system_prompt is None + + +def test_init_with_both_agent_and_prompt_raises_error(): + """Test that providing both agent and system prompt raises ValueError.""" + mock_agent = create_mock_agent() + custom_prompt = "Custom summarization prompt" + + with pytest.raises(ValueError, match="Cannot provide both summarization_agent and summarization_system_prompt"): + SummarizingConversationManager( + summarization_agent=mock_agent, + summarization_system_prompt=custom_prompt, + ) + + +def test_uses_summarization_agent_when_provided(): + """Test that summarization_agent is used when provided.""" + summary_agent = create_mock_agent("Custom summary from dedicated agent") + manager = SummarizingConversationManager(summarization_agent=summary_agent) + + messages: Messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + {"role": "assistant", "content": [{"text": "Hi there"}]}, + ] + + parent_agent = create_mock_agent("Parent agent summary") + summary = manager._generate_summary(messages, parent_agent) + + # Should use the dedicated summarization agent, not the parent agent + summary_content = summary["content"][0] + assert "text" in summary_content and summary_content["text"] == "Custom summary from dedicated agent" + + # Assert that the summarization agent was called + summary_agent.call_tracker.assert_called_once() + + +def test_uses_parent_agent_when_no_summarization_agent(): + """Test that parent agent is used when no summarization_agent is provided.""" + manager = SummarizingConversationManager() + + messages: Messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + {"role": "assistant", "content": [{"text": "Hi there"}]}, + ] + + parent_agent = create_mock_agent("Parent agent summary") + summary = manager._generate_summary(messages, parent_agent) + + # Should use the parent agent + summary_content = summary["content"][0] + assert "text" in summary_content and summary_content["text"] == "Parent agent summary" + + # Assert that the parent agent was called + parent_agent.call_tracker.assert_called_once() + + +def test_uses_custom_system_prompt(): + """Test that custom system prompt is used when provided.""" + custom_prompt = "Custom system prompt for summarization" + manager = SummarizingConversationManager(summarization_system_prompt=custom_prompt) + mock_agent = create_mock_agent() + + messages: Messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + {"role": "assistant", "content": [{"text": "Hi there"}]}, + ] + + # Capture the agent's system prompt changes + original_prompt = mock_agent.system_prompt + manager._generate_summary(messages, mock_agent) + + # The agent's system prompt should be restored after summarization + assert mock_agent.system_prompt == original_prompt + + +def test_agent_state_restoration(): + """Test that agent state is properly restored after summarization.""" + manager = SummarizingConversationManager() + mock_agent = create_mock_agent() + + # Set initial state + original_system_prompt = "Original system prompt" + original_messages: Messages = [{"role": "user", "content": [{"text": "Original message"}]}] + mock_agent.system_prompt = original_system_prompt + mock_agent.messages = original_messages.copy() + + messages: Messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + {"role": 
"assistant", "content": [{"text": "Hi there"}]}, + ] + + manager._generate_summary(messages, mock_agent) + + # State should be restored + assert mock_agent.system_prompt == original_system_prompt + assert mock_agent.messages == original_messages + + +def test_agent_state_restoration_on_exception(): + """Test that agent state is restored even when summarization fails.""" + manager = SummarizingConversationManager() + + # Create an agent that fails during summarization + mock_agent = Mock() + mock_agent.system_prompt = "Original prompt" + agent_messages: Messages = [{"role": "user", "content": [{"text": "Original"}]}] + mock_agent.messages = agent_messages + mock_agent.side_effect = Exception("Summarization failed") + + messages: Messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + {"role": "assistant", "content": [{"text": "Hi there"}]}, + ] + + # Should restore state even on exception + with pytest.raises(Exception, match="Summarization failed"): + manager._generate_summary(messages, mock_agent) + + # State should still be restored + assert mock_agent.system_prompt == "Original prompt" + + +def test_reduce_context_tool_pair_adjustment_works_with_forward_search(): + """Test that tool pair adjustment works correctly with the forward-search logic.""" + manager = SummarizingConversationManager( + summary_ratio=0.5, + preserve_recent_messages=1, + ) + + mock_agent = create_mock_agent() + # Create messages where the split point would be adjusted to 0 due to tool pairs + tool_pair_messages: Messages = [ + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "123", "name": "test_tool", "input": {}}}]}, + { + "role": "user", + "content": [ + {"toolResult": {"toolUseId": "123", "content": [{"text": "Tool output"}], "status": "success"}} + ], + }, + {"role": "user", "content": [{"text": "Latest message"}]}, + ] + mock_agent.messages = tool_pair_messages + + # With 3 messages, preserve_recent_messages=1, summary_ratio=0.5: + # messages_to_summarize_count = (3 - 1) * 0.5 = 1 + # But split point adjustment will move forward from the toolUse, potentially increasing count + manager.reduce_context(mock_agent) + # Should have summary + remaining messages + assert len(mock_agent.messages) == 2 + + # First message should be the summary + assert mock_agent.messages[0]["role"] == "assistant" + summary_content = mock_agent.messages[0]["content"][0] + assert "text" in summary_content and "This is a summary of the conversation." 
in summary_content["text"] + + # Last message should be the preserved recent message + assert mock_agent.messages[1]["role"] == "user" + assert mock_agent.messages[1]["content"][0]["text"] == "Latest message" + + +def test_adjust_split_point_exceeds_message_length(summarizing_manager): + """Test that split point exceeding message array length raises exception.""" + messages: Messages = [ + {"role": "user", "content": [{"text": "Message 1"}]}, + {"role": "assistant", "content": [{"text": "Response 1"}]}, + ] + + # Try to split at point 5 when there are only 2 messages + with pytest.raises(ContextWindowOverflowException, match="Split point exceeds message array length"): + summarizing_manager._adjust_split_point_for_tool_pairs(messages, 5) + + +def test_adjust_split_point_equals_message_length(summarizing_manager): + """Test that split point equal to message array length returns unchanged.""" + messages: Messages = [ + {"role": "user", "content": [{"text": "Message 1"}]}, + {"role": "assistant", "content": [{"text": "Response 1"}]}, + ] + + # Split point equals message length (2) - should return unchanged + result = summarizing_manager._adjust_split_point_for_tool_pairs(messages, 2) + assert result == 2 + + +def test_adjust_split_point_no_tool_result_at_split(summarizing_manager): + """Test split point that doesn't contain tool result, ensuring we reach return split_point.""" + messages: Messages = [ + {"role": "user", "content": [{"text": "Message 1"}]}, + {"role": "assistant", "content": [{"text": "Response 1"}]}, + {"role": "user", "content": [{"text": "Message 2"}]}, + ] + + # Split point message is not a tool result, so it should directly return split_point + result = summarizing_manager._adjust_split_point_for_tool_pairs(messages, 1) + assert result == 1 + + +def test_adjust_split_point_tool_result_without_tool_use(summarizing_manager): + """Test that having tool results without tool uses raises exception.""" + messages: Messages = [ + {"role": "user", "content": [{"text": "Message 1"}]}, + { + "role": "user", + "content": [ + {"toolResult": {"toolUseId": "123", "content": [{"text": "Tool output"}], "status": "success"}} + ], + }, + ] + + # Has tool result but no tool use - invalid state + with pytest.raises(ContextWindowOverflowException, match="Unable to trim conversation context!"): + summarizing_manager._adjust_split_point_for_tool_pairs(messages, 1) + + +def test_adjust_split_point_tool_result_moves_to_end(summarizing_manager): + """Test tool result at split point moves forward to valid position at end.""" + messages: Messages = [ + {"role": "user", "content": [{"text": "Message 1"}]}, + {"role": "assistant", "content": [{"text": "Response 1"}]}, + { + "role": "user", + "content": [ + {"toolResult": {"toolUseId": "123", "content": [{"text": "Tool output"}], "status": "success"}} + ], + }, + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "456", "name": "different_tool", "input": {}}}]}, + ] + + # Split at message 2 (toolResult) - will move forward to message 3 (toolUse at end is valid) + result = summarizing_manager._adjust_split_point_for_tool_pairs(messages, 2) + assert result == 3 + + +def test_adjust_split_point_tool_result_no_forward_position(summarizing_manager): + """Test tool result at split point where forward search finds no valid position.""" + messages: Messages = [ + {"role": "user", "content": [{"text": "Message 1"}]}, + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "123", "name": "test_tool", "input": {}}}]}, + {"role": "user", "content": 
[{"text": "Message between"}]}, + { + "role": "user", + "content": [ + {"toolResult": {"toolUseId": "123", "content": [{"text": "Tool output"}], "status": "success"}} + ], + }, + ] + + # Split at message 3 (toolResult) - will try to move forward but no valid position exists + with pytest.raises(ContextWindowOverflowException, match="Unable to trim conversation context!"): + summarizing_manager._adjust_split_point_for_tool_pairs(messages, 3) + + +def test_reduce_context_adjustment_returns_zero(): + """Test that tool pair adjustment can return zero, triggering the check at line 122.""" + manager = SummarizingConversationManager( + summary_ratio=0.5, + preserve_recent_messages=1, + ) + + # Mock the adjustment method to return 0 + def mock_adjust(messages, split_point): + return 0 # This should trigger the <= 0 check at line 122 + + manager._adjust_split_point_for_tool_pairs = mock_adjust + + mock_agent = Mock() + simple_messages: Messages = [ + {"role": "user", "content": [{"text": "Message 1"}]}, + {"role": "assistant", "content": [{"text": "Response 1"}]}, + {"role": "user", "content": [{"text": "Message 2"}]}, + ] + mock_agent.messages = simple_messages + + # The adjustment method will return 0, which should trigger line 122-123 + with pytest.raises(ContextWindowOverflowException, match="insufficient messages for summarization"): + manager.reduce_context(mock_agent) diff --git a/tests/strands/event_loop/test_error_handler.py b/tests/strands/event_loop/test_error_handler.py index 7249adef..fe1b3e9f 100644 --- a/tests/strands/event_loop/test_error_handler.py +++ b/tests/strands/event_loop/test_error_handler.py @@ -4,7 +4,7 @@ import pytest import strands -from strands.types.exceptions import ContextWindowOverflowException, ModelThrottledException +from strands.types.exceptions import ModelThrottledException @pytest.fixture @@ -95,126 +95,3 @@ def test_handle_throttling_error_does_not_exist(callback_handler, kwargs): assert tru_retry == exp_retry and tru_delay == exp_delay callback_handler.assert_called_with(force_stop=True, force_stop_reason=str(exception)) - - -@pytest.mark.parametrize("event_stream_error", ["Input is too long for requested model"], indirect=True) -def test_handle_input_too_long_error( - sdk_event_loop, - event_stream_error, - model, - system_prompt, - tool_config, - callback_handler, - tool_handler, - kwargs, -): - sdk_event_loop.return_value = "success" - - messages = [ - { - "role": "user", - "content": [ - {"toolResult": {"toolUseId": "t1", "status": "success", "content": [{"text": "needs truncation"}]}} - ], - } - ] - - tru_result = strands.event_loop.error_handler.handle_input_too_long_error( - event_stream_error, - messages, - model, - system_prompt, - tool_config, - callback_handler, - tool_handler, - kwargs, - ) - exp_result = "success" - - tru_messages = messages - exp_messages = [ - { - "role": "user", - "content": [ - { - "toolResult": { - "toolUseId": "t1", - "status": "error", - "content": [{"text": "The tool result was too large!"}], - }, - }, - ], - }, - ] - - assert tru_result == exp_result and tru_messages == exp_messages - - sdk_event_loop.assert_called_once_with( - model=model, - system_prompt=system_prompt, - messages=messages, - tool_config=tool_config, - callback_handler=callback_handler, - tool_handler=tool_handler, - request_state="value", - ) - - callback_handler.assert_not_called() - - -@pytest.mark.parametrize("event_stream_error", ["Other error"], indirect=True) -def test_handle_input_too_long_error_does_not_exist( - sdk_event_loop, - 
event_stream_error, - model, - system_prompt, - tool_config, - callback_handler, - tool_handler, - kwargs, -): - messages = [] - - with pytest.raises(ContextWindowOverflowException): - strands.event_loop.error_handler.handle_input_too_long_error( - event_stream_error, - messages, - model, - system_prompt, - tool_config, - callback_handler, - tool_handler, - kwargs, - ) - - sdk_event_loop.assert_not_called() - callback_handler.assert_called_with(force_stop=True, force_stop_reason=str(event_stream_error)) - - -@pytest.mark.parametrize("event_stream_error", ["Input is too long for requested model"], indirect=True) -def test_handle_input_too_long_error_no_tool_result( - sdk_event_loop, - event_stream_error, - model, - system_prompt, - tool_config, - callback_handler, - tool_handler, - kwargs, -): - messages = [] - - with pytest.raises(ContextWindowOverflowException): - strands.event_loop.error_handler.handle_input_too_long_error( - event_stream_error, - messages, - model, - system_prompt, - tool_config, - callback_handler, - tool_handler, - kwargs, - ) - - sdk_event_loop.assert_not_called() - callback_handler.assert_called_with(force_stop=True, force_stop_reason=str(event_stream_error)) diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py index 569462f1..11f14503 100644 --- a/tests/strands/event_loop/test_event_loop.py +++ b/tests/strands/event_loop/test_event_loop.py @@ -11,6 +11,13 @@ from strands.types.exceptions import ContextWindowOverflowException, EventLoopException, ModelThrottledException +@pytest.fixture +def mock_time(): + """Fixture to mock the time module in the error_handler.""" + with unittest.mock.patch.object(strands.event_loop.error_handler, "time") as mock: + yield mock + + @pytest.fixture def model(): return unittest.mock.Mock() @@ -104,27 +111,6 @@ def mock_tracer(): return tracer -@pytest.mark.parametrize( - ("kwargs", "exp_state"), - [ - ( - {"request_state": {"key1": "value1"}}, - {"key1": "value1"}, - ), - ( - {}, - {}, - ), - ], -) -def test_initialize_state(kwargs, exp_state): - kwargs = strands.event_loop.event_loop.initialize_state(**kwargs) - - tru_state = kwargs["request_state"] - - assert tru_state == exp_state - - def test_event_loop_cycle_text_response( model, model_id, @@ -157,7 +143,8 @@ def test_event_loop_cycle_text_response( assert tru_stop_reason == exp_stop_reason and tru_message == exp_message and tru_request_state == exp_request_state -def test_event_loop_cycle_text_response_input_too_long( +def test_event_loop_cycle_text_response_throttling( + mock_time, model, model_id, system_prompt, @@ -168,26 +155,12 @@ def test_event_loop_cycle_text_response_input_too_long( tool_execution_handler, ): model.converse.side_effect = [ - ContextWindowOverflowException(RuntimeError("Input is too long for requested model")), + ModelThrottledException("ThrottlingException | ConverseStream"), [ {"contentBlockDelta": {"delta": {"text": "test text"}}}, {"contentBlockStop": {}}, ], ] - messages.append( - { - "role": "user", - "content": [ - { - "toolResult": { - "toolUseId": "t1", - "status": "success", - "content": [{"text": "2025-04-01T00:00:00"}], - }, - }, - ], - } - ) tru_stop_reason, tru_message, _, tru_request_state = strands.event_loop.event_loop.event_loop_cycle( model=model, @@ -204,10 +177,12 @@ def test_event_loop_cycle_text_response_input_too_long( exp_request_state = {} assert tru_stop_reason == exp_stop_reason and tru_message == exp_message and tru_request_state == exp_request_state + # Verify that sleep was 
called once with the initial delay + mock_time.sleep.assert_called_once() -@unittest.mock.patch.object(strands.event_loop.error_handler, "time") -def test_event_loop_cycle_text_response_throttling( +def test_event_loop_cycle_exponential_backoff( + mock_time, model, model_id, system_prompt, @@ -217,7 +192,11 @@ def test_event_loop_cycle_text_response_throttling( tool_handler, tool_execution_handler, ): + """Test that the exponential backoff works correctly with multiple retries.""" + # Set up the model to raise throttling exceptions multiple times before succeeding model.converse.side_effect = [ + ModelThrottledException("ThrottlingException | ConverseStream"), + ModelThrottledException("ThrottlingException | ConverseStream"), ModelThrottledException("ThrottlingException | ConverseStream"), [ {"contentBlockDelta": {"delta": {"text": "test text"}}}, @@ -235,11 +214,16 @@ def test_event_loop_cycle_text_response_throttling( tool_handler=tool_handler, tool_execution_handler=tool_execution_handler, ) - exp_stop_reason = "end_turn" - exp_message = {"role": "assistant", "content": [{"text": "test text"}]} - exp_request_state = {} - assert tru_stop_reason == exp_stop_reason and tru_message == exp_message and tru_request_state == exp_request_state + # Verify the final response + assert tru_stop_reason == "end_turn" + assert tru_message == {"role": "assistant", "content": [{"text": "test text"}]} + assert tru_request_state == {} + + # Verify that sleep was called with increasing delays + # Initial delay is 4, then 8, then 16 + assert mock_time.sleep.call_count == 3 + assert mock_time.sleep.call_args_list == [call(4), call(8), call(16)] def test_event_loop_cycle_text_response_error( @@ -460,19 +444,6 @@ def test_event_loop_cycle_stop( assert tru_stop_reason == exp_stop_reason and tru_message == exp_message and tru_request_state == exp_request_state -def test_prepare_next_cycle(): - kwargs = {"event_loop_cycle_id": "c1"} - event_loop_metrics = strands.telemetry.metrics.EventLoopMetrics() - tru_result = strands.event_loop.event_loop.prepare_next_cycle(kwargs, event_loop_metrics) - exp_result = { - "event_loop_cycle_id": "c1", - "event_loop_parent_cycle_id": "c1", - "event_loop_metrics": event_loop_metrics, - } - - assert tru_result == exp_result - - def test_cycle_exception( model, system_prompt, @@ -728,3 +699,176 @@ def test_event_loop_cycle_with_parent_span( mock_tracer.start_event_loop_cycle_span.assert_called_once_with( event_loop_kwargs=unittest.mock.ANY, parent_span=parent_span, messages=messages ) + + +def test_event_loop_cycle_callback( + model, + model_id, + system_prompt, + messages, + tool_config, + callback_handler, + tool_handler, + tool_execution_handler, +): + model.converse.return_value = [ + {"contentBlockStart": {"start": {"toolUse": {"toolUseId": "123", "name": "test"}}}}, + {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"value"}'}}}}, + {"contentBlockStop": {}}, + {"contentBlockStart": {"start": {}}}, + {"contentBlockDelta": {"delta": {"reasoningContent": {"text": "value"}}}}, + {"contentBlockDelta": {"delta": {"reasoningContent": {"signature": "value"}}}}, + {"contentBlockStop": {}}, + {"contentBlockStart": {"start": {}}}, + {"contentBlockDelta": {"delta": {"text": "value"}}}, + {"contentBlockStop": {}}, + ] + + strands.event_loop.event_loop.event_loop_cycle( + model=model, + model_id=model_id, + system_prompt=system_prompt, + messages=messages, + tool_config=tool_config, + callback_handler=callback_handler, + tool_handler=tool_handler, + 
tool_execution_handler=tool_execution_handler, + ) + + callback_handler.assert_has_calls( + [ + call(start=True), + call(start_event_loop=True), + call(event={"contentBlockStart": {"start": {"toolUse": {"toolUseId": "123", "name": "test"}}}}), + call(event={"contentBlockDelta": {"delta": {"toolUse": {"input": '{"value"}'}}}}), + call( + delta={"toolUse": {"input": '{"value"}'}}, + current_tool_use={"toolUseId": "123", "name": "test", "input": {}}, + model_id="m1", + event_loop_cycle_id=unittest.mock.ANY, + request_state={}, + event_loop_cycle_trace=unittest.mock.ANY, + event_loop_cycle_span=None, + ), + call(event={"contentBlockStop": {}}), + call(event={"contentBlockStart": {"start": {}}}), + call(event={"contentBlockDelta": {"delta": {"reasoningContent": {"text": "value"}}}}), + call( + reasoningText="value", + delta={"reasoningContent": {"text": "value"}}, + reasoning=True, + model_id="m1", + event_loop_cycle_id=unittest.mock.ANY, + request_state={}, + event_loop_cycle_trace=unittest.mock.ANY, + event_loop_cycle_span=None, + ), + call(event={"contentBlockDelta": {"delta": {"reasoningContent": {"signature": "value"}}}}), + call( + reasoning_signature="value", + delta={"reasoningContent": {"signature": "value"}}, + reasoning=True, + model_id="m1", + event_loop_cycle_id=unittest.mock.ANY, + request_state={}, + event_loop_cycle_trace=unittest.mock.ANY, + event_loop_cycle_span=None, + ), + call(event={"contentBlockStop": {}}), + call(event={"contentBlockStart": {"start": {}}}), + call(event={"contentBlockDelta": {"delta": {"text": "value"}}}), + call( + data="value", + delta={"text": "value"}, + model_id="m1", + event_loop_cycle_id=unittest.mock.ANY, + request_state={}, + event_loop_cycle_trace=unittest.mock.ANY, + event_loop_cycle_span=None, + ), + call(event={"contentBlockStop": {}}), + call( + message={ + "role": "assistant", + "content": [ + {"toolUse": {"toolUseId": "123", "name": "test", "input": {}}}, + {"reasoningContent": {"reasoningText": {"text": "value", "signature": "value"}}}, + {"text": "value"}, + ], + }, + ), + ], + ) + + +def test_request_state_initialization(): + # Call without providing request_state + tru_stop_reason, tru_message, _, tru_request_state = strands.event_loop.event_loop.event_loop_cycle( + model=MagicMock(), + model_id=MagicMock(), + system_prompt=MagicMock(), + messages=MagicMock(), + tool_config=MagicMock(), + callback_handler=MagicMock(), + tool_handler=MagicMock(), + tool_execution_handler=MagicMock(), + ) + + # Verify request_state was initialized to empty dict + assert tru_request_state == {} + + # Call with pre-existing request_state + initial_request_state = {"key": "value"} + tru_stop_reason, tru_message, _, tru_request_state = strands.event_loop.event_loop.event_loop_cycle( + model=MagicMock(), + model_id=MagicMock(), + system_prompt=MagicMock(), + messages=MagicMock(), + tool_config=MagicMock(), + callback_handler=MagicMock(), + tool_handler=MagicMock(), + request_state=initial_request_state, + ) + + # Verify existing request_state was preserved + assert tru_request_state == initial_request_state + + +def test_prepare_next_cycle_in_tool_execution(model, tool_stream): + """Test that cycle ID and metrics are properly updated during tool execution.""" + model.converse.side_effect = [ + tool_stream, + [ + {"contentBlockStop": {}}, + ], + ] + + # Create a mock for recurse_event_loop to capture the kwargs passed to it + with unittest.mock.patch.object(strands.event_loop.event_loop, "recurse_event_loop") as mock_recurse: + # Set up mock to return a valid 
response + mock_recurse.return_value = ( + "end_turn", + {"role": "assistant", "content": [{"text": "test text"}]}, + strands.telemetry.metrics.EventLoopMetrics(), + {}, + ) + + # Call event_loop_cycle which should execute a tool and then call recurse_event_loop + strands.event_loop.event_loop.event_loop_cycle( + model=model, + model_id=MagicMock(), + system_prompt=MagicMock(), + messages=MagicMock(), + tool_config=MagicMock(), + callback_handler=MagicMock(), + tool_handler=MagicMock(), + tool_execution_handler=MagicMock(), + ) + + assert mock_recurse.called + + # Verify required properties are present + recursive_kwargs = mock_recurse.call_args[1] + assert "event_loop_metrics" in recursive_kwargs + assert "event_loop_parent_cycle_id" in recursive_kwargs + assert recursive_kwargs["event_loop_parent_cycle_id"] == recursive_kwargs["event_loop_cycle_id"] diff --git a/tests/strands/event_loop/test_message_processor.py b/tests/strands/event_loop/test_message_processor.py new file mode 100644 index 00000000..fcf531df --- /dev/null +++ b/tests/strands/event_loop/test_message_processor.py @@ -0,0 +1,47 @@ +import copy + +import pytest + +from strands.event_loop import message_processor + + +@pytest.mark.parametrize( + "messages,expected,expected_messages", + [ + # Orphaned toolUse with empty input, no toolResult + ( + [ + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "1", "input": {}, "name": "foo"}}]}, + {"role": "user", "content": [{"toolResult": {"toolUseId": "2"}}]}, + ], + True, + [ + {"role": "assistant", "content": [{"text": "[Attempted to use foo, but operation was canceled]"}]}, + {"role": "user", "content": [{"toolResult": {"toolUseId": "2"}}]}, + ], + ), + # toolUse with input, has matching toolResult + ( + [ + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "1", "input": {"a": 1}, "name": "foo"}}]}, + {"role": "user", "content": [{"toolResult": {"toolUseId": "1"}}]}, + ], + False, + [ + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "1", "input": {"a": 1}, "name": "foo"}}]}, + {"role": "user", "content": [{"toolResult": {"toolUseId": "1"}}]}, + ], + ), + # No messages + ( + [], + False, + [], + ), + ], +) +def test_clean_orphaned_empty_tool_uses(messages, expected, expected_messages): + test_messages = copy.deepcopy(messages) + result = message_processor.clean_orphaned_empty_tool_uses(test_messages) + assert result == expected + assert test_messages == expected_messages diff --git a/tests/strands/event_loop/test_streaming.py b/tests/strands/event_loop/test_streaming.py index c24e7e48..e91f4986 100644 --- a/tests/strands/event_loop/test_streaming.py +++ b/tests/strands/event_loop/test_streaming.py @@ -3,6 +3,7 @@ import pytest import strands +import strands.event_loop from strands.types.streaming import ( ContentBlockDeltaEvent, ContentBlockStartEvent, @@ -17,13 +18,6 @@ def moto_autouse(moto_env, moto_mock_aws): _ = moto_mock_aws -@pytest.fixture -def agent(): - mock = unittest.mock.Mock() - - return mock - - @pytest.mark.parametrize( ("messages", "exp_result"), [ @@ -81,7 +75,7 @@ def test_handle_content_block_start(chunk: ContentBlockStartEvent, exp_tool_use) @pytest.mark.parametrize( - ("event", "state", "exp_updated_state", "exp_handler_args"), + ("event", "state", "exp_updated_state", "callback_args"), [ # Tool Use - Existing input ( @@ -148,21 +142,13 @@ def test_handle_content_block_start(chunk: ContentBlockStartEvent, exp_tool_use) ), ], ) -def test_handle_content_block_delta(event: ContentBlockDeltaEvent, state, exp_updated_state, 
exp_handler_args): - if exp_handler_args: - exp_handler_args.update({"delta": event["delta"], "extra_arg": 1}) +def test_handle_content_block_delta(event: ContentBlockDeltaEvent, state, exp_updated_state, callback_args): + exp_callback_event = {"callback": {**callback_args, "delta": event["delta"]}} if callback_args else {} - tru_handler_args = {} - - def callback_handler(**kwargs): - tru_handler_args.update(kwargs) - - tru_updated_state = strands.event_loop.streaming.handle_content_block_delta( - event, state, callback_handler, extra_arg=1 - ) + tru_updated_state, tru_callback_event = strands.event_loop.streaming.handle_content_block_delta(event, state) assert tru_updated_state == exp_updated_state - assert tru_handler_args == exp_handler_args + assert tru_callback_event == exp_callback_event @pytest.mark.parametrize( @@ -275,8 +261,9 @@ def test_extract_usage_metrics(): @pytest.mark.parametrize( - ("response", "exp_stop_reason", "exp_message", "exp_usage", "exp_metrics", "exp_request_state", "exp_messages"), + ("response", "exp_events"), [ + # Standard Message ( [ {"messageStart": {"role": "assistant"}}, @@ -297,28 +284,127 @@ def test_extract_usage_metrics(): } }, ], - "tool_use", - { - "role": "assistant", - "content": [{"toolUse": {"toolUseId": "123", "name": "test", "input": {"key": "value"}}}], - }, - {"inputTokens": 1, "outputTokens": 1, "totalTokens": 1}, - {"latencyMs": 1}, - {"calls": 1}, - [{"role": "user", "content": [{"text": "Some input!"}]}], + [ + { + "callback": { + "event": { + "messageStart": { + "role": "assistant", + }, + }, + }, + }, + { + "callback": { + "event": { + "contentBlockStart": { + "start": { + "toolUse": { + "name": "test", + "toolUseId": "123", + }, + }, + }, + }, + }, + }, + { + "callback": { + "event": { + "contentBlockDelta": { + "delta": { + "toolUse": { + "input": '{"key": "value"}', + }, + }, + }, + }, + }, + }, + { + "callback": { + "current_tool_use": { + "input": { + "key": "value", + }, + "name": "test", + "toolUseId": "123", + }, + "delta": { + "toolUse": { + "input": '{"key": "value"}', + }, + }, + }, + }, + { + "callback": { + "event": { + "contentBlockStop": {}, + }, + }, + }, + { + "callback": { + "event": { + "messageStop": { + "stopReason": "tool_use", + }, + }, + }, + }, + { + "callback": { + "event": { + "metadata": { + "metrics": { + "latencyMs": 1, + }, + "usage": { + "inputTokens": 1, + "outputTokens": 1, + "totalTokens": 1, + }, + }, + }, + }, + }, + { + "stop": ( + "tool_use", + { + "role": "assistant", + "content": [{"toolUse": {"toolUseId": "123", "name": "test", "input": {"key": "value"}}}], + }, + {"inputTokens": 1, "outputTokens": 1, "totalTokens": 1}, + {"latencyMs": 1}, + ) + }, + ], ), + # Empty Message ( [{}], - "end_turn", - { - "role": "assistant", - "content": [], - }, - {"inputTokens": 0, "outputTokens": 0, "totalTokens": 0}, - {"latencyMs": 0}, - {}, - [{"role": "user", "content": [{"text": "Some input!"}]}], + [ + { + "callback": { + "event": {}, + }, + }, + { + "stop": ( + "end_turn", + { + "role": "assistant", + "content": [], + }, + {"inputTokens": 0, "outputTokens": 0, "totalTokens": 0}, + {"latencyMs": 0}, + ), + }, + ], ), + # Redacted Message ( [ {"messageStart": {"role": "assistant"}}, @@ -345,77 +431,161 @@ def test_extract_usage_metrics(): } }, ], - "guardrail_intervened", - { - "role": "assistant", - "content": [{"text": "REDACTED."}], - }, - {"inputTokens": 1, "outputTokens": 1, "totalTokens": 1}, - {"latencyMs": 1}, - {"calls": 1}, - [{"role": "user", "content": [{"text": "REDACTED"}]}], + [ + { + 
"callback": { + "event": { + "messageStart": { + "role": "assistant", + }, + }, + }, + }, + { + "callback": { + "event": { + "contentBlockStart": { + "start": {}, + }, + }, + }, + }, + { + "callback": { + "event": { + "contentBlockDelta": { + "delta": { + "text": "Hello!", + }, + }, + }, + }, + }, + { + "callback": { + "data": "Hello!", + "delta": { + "text": "Hello!", + }, + }, + }, + { + "callback": { + "event": { + "contentBlockStop": {}, + }, + }, + }, + { + "callback": { + "event": { + "messageStop": { + "stopReason": "guardrail_intervened", + }, + }, + }, + }, + { + "callback": { + "event": { + "redactContent": { + "redactAssistantContentMessage": "REDACTED.", + "redactUserContentMessage": "REDACTED", + }, + }, + }, + }, + { + "callback": { + "event": { + "metadata": { + "metrics": { + "latencyMs": 1, + }, + "usage": { + "inputTokens": 1, + "outputTokens": 1, + "totalTokens": 1, + }, + }, + }, + }, + }, + { + "stop": ( + "guardrail_intervened", + { + "role": "assistant", + "content": [{"text": "REDACTED."}], + }, + {"inputTokens": 1, "outputTokens": 1, "totalTokens": 1}, + {"latencyMs": 1}, + ), + }, + ], ), ], ) -def test_process_stream( - response, exp_stop_reason, exp_message, exp_usage, exp_metrics, exp_request_state, exp_messages -): - def callback_handler(**kwargs): - if "request_state" in kwargs: - kwargs["request_state"].setdefault("calls", 0) - kwargs["request_state"]["calls"] += 1 - - tru_messages = [{"role": "user", "content": [{"text": "Some input!"}]}] - - tru_stop_reason, tru_message, tru_usage, tru_metrics, tru_request_state = ( - strands.event_loop.streaming.process_stream(response, callback_handler, tru_messages) - ) - - assert tru_stop_reason == exp_stop_reason - assert tru_message == exp_message - assert tru_usage == exp_usage - assert tru_metrics == exp_metrics - assert tru_request_state == exp_request_state - assert tru_messages == exp_messages +def test_process_stream(response, exp_events): + messages = [{"role": "user", "content": [{"text": "Some input!"}]}] + stream = strands.event_loop.streaming.process_stream(response, messages) + tru_events = list(stream) + assert tru_events == exp_events -def test_stream_messages(agent): - def callback_handler(**kwargs): - if "request_state" in kwargs: - kwargs["request_state"].setdefault("calls", 0) - kwargs["request_state"]["calls"] += 1 +def test_stream_messages(): mock_model = unittest.mock.MagicMock() mock_model.converse.return_value = [ {"contentBlockDelta": {"delta": {"text": "test"}}}, {"contentBlockStop": {}}, ] - tru_stop_reason, tru_message, tru_usage, tru_metrics, tru_request_state = ( - strands.event_loop.streaming.stream_messages( - mock_model, - model_id="test_model", - system_prompt="test prompt", - messages=[{"role": "assistant", "content": [{"text": "a"}, {"text": " \n"}]}], - tool_config=None, - callback_handler=callback_handler, - agent=agent, - ) + stream = strands.event_loop.streaming.stream_messages( + mock_model, + system_prompt="test prompt", + messages=[{"role": "assistant", "content": [{"text": "a"}, {"text": " \n"}]}], + tool_config=None, ) - exp_stop_reason = "end_turn" - exp_message = {"role": "assistant", "content": [{"text": "test"}]} - exp_usage = {"inputTokens": 0, "outputTokens": 0, "totalTokens": 0} - exp_metrics = {"latencyMs": 0} - exp_request_state = {"calls": 1} - - assert ( - tru_stop_reason == exp_stop_reason - and tru_message == exp_message - and tru_usage == exp_usage - and tru_metrics == exp_metrics - and tru_request_state == exp_request_state - ) + tru_events = list(stream) + 
exp_events = [ + { + "callback": { + "event": { + "contentBlockDelta": { + "delta": { + "text": "test", + }, + }, + }, + }, + }, + { + "callback": { + "data": "test", + "delta": { + "text": "test", + }, + }, + }, + { + "callback": { + "event": { + "contentBlockStop": {}, + }, + }, + }, + { + "stop": ( + "end_turn", + {"role": "assistant", "content": [{"text": "test"}]}, + {"inputTokens": 0, "outputTokens": 0, "totalTokens": 0}, + {"latencyMs": 0}, + ) + }, + ] + assert tru_events == exp_events mock_model.converse.assert_called_with( [{"role": "assistant", "content": [{"text": "a"}, {"text": "[blank text]"}]}], diff --git a/tests/strands/handlers/test_callback_handler.py b/tests/strands/handlers/test_callback_handler.py index 20e238cb..6fb2af07 100644 --- a/tests/strands/handlers/test_callback_handler.py +++ b/tests/strands/handlers/test_callback_handler.py @@ -30,6 +30,31 @@ def test_call_with_empty_args(handler, mock_print): mock_print.assert_not_called() +def test_call_handler_reasoningText(handler, mock_print): + """Test calling the handler with reasoningText.""" + handler(reasoningText="This is reasoning text") + # Should print reasoning text without newline + mock_print.assert_called_once_with("This is reasoning text", end="") + + +def test_call_without_reasoningText(handler, mock_print): + """Test calling the handler without reasoningText argument.""" + handler(data="Some output") + # Should only print data, not reasoningText + mock_print.assert_called_once_with("Some output", end="") + + +def test_call_with_reasoningText_and_data(handler, mock_print): + """Test calling the handler with both reasoningText and data.""" + handler(reasoningText="Reasoning", data="Output") + # Should print reasoningText and data, both without newline + calls = [ + unittest.mock.call("Reasoning", end=""), + unittest.mock.call("Output", end=""), + ] + mock_print.assert_has_calls(calls) + + def test_call_with_data_incomplete(handler, mock_print): """Test calling the handler with data but not complete.""" handler(data="Test output") diff --git a/tests/strands/models/test_anthropic.py b/tests/strands/models/test_anthropic.py index 48a1da37..a0cfc4d4 100644 --- a/tests/strands/models/test_anthropic.py +++ b/tests/strands/models/test_anthropic.py @@ -1,4 +1,3 @@ -import json import unittest.mock import anthropic @@ -102,19 +101,46 @@ def test_format_request_with_system_prompt(model, messages, model_id, max_tokens assert tru_request == exp_request -def test_format_request_with_document(model, model_id, max_tokens): +@pytest.mark.parametrize( + ("content", "formatted_content"), + [ + # PDF + ( + { + "document": {"format": "pdf", "name": "test doc", "source": {"bytes": b"pdf"}}, + }, + { + "source": { + "data": "cGRm", + "media_type": "application/pdf", + "type": "base64", + }, + "title": "test doc", + "type": "document", + }, + ), + # Plain text + ( + { + "document": {"format": "txt", "name": "test doc", "source": {"bytes": b"txt"}}, + }, + { + "source": { + "data": "txt", + "media_type": "text/plain", + "type": "text", + }, + "title": "test doc", + "type": "document", + }, + ), + ], +) +def test_format_request_with_document(content, formatted_content, model, model_id, max_tokens): messages = [ { "role": "user", - "content": [ - { - "document": { - "format": "pdf", - "name": "test-doc", - "source": {"bytes": b"base64encodeddoc"}, - }, - }, - ], + "content": [content], }, ] @@ -124,17 +150,7 @@ def test_format_request_with_document(model, model_id, max_tokens): "messages": [ { "role": "user", - "content": [ - { - 
"source": { - "data": "YmFzZTY0ZW5jb2RlZGRvYw==", - "media_type": "application/pdf", - "type": "base64", - }, - "title": "test-doc", - "type": "document", - }, - ], + "content": [formatted_content], }, ], "model": model_id, @@ -273,6 +289,7 @@ def test_format_request_with_tool_results(model, model_id, max_tokens): "status": "success", "content": [ {"text": "see image"}, + {"json": ["see image"]}, { "image": { "format": "jpg", @@ -299,6 +316,10 @@ def test_format_request_with_tool_results(model, model_id, max_tokens): "text": "see image", "type": "text", }, + { + "text": '["see image"]', + "type": "text", + }, { "source": { "data": "YmFzZTY0ZW5jb2RlZGltYWdl", @@ -322,33 +343,16 @@ def test_format_request_with_tool_results(model, model_id, max_tokens): assert tru_request == exp_request -def test_format_request_with_other(model, model_id, max_tokens): +def test_format_request_with_unsupported_type(model): messages = [ { "role": "user", - "content": [{"other": {"a": 1}}], + "content": [{"unsupported": {}}], }, ] - tru_request = model.format_request(messages) - exp_request = { - "max_tokens": max_tokens, - "messages": [ - { - "role": "user", - "content": [ - { - "text": json.dumps({"other": {"a": 1}}), - "type": "text", - }, - ], - }, - ], - "model": model_id, - "tools": [], - } - - assert tru_request == exp_request + with pytest.raises(TypeError, match="content_type= | unsupported type"): + model.format_request(messages) def test_format_request_with_cache_point(model, model_id, max_tokens): @@ -611,10 +615,24 @@ def test_format_chunk_unknown(model): def test_stream(anthropic_client, model): - mock_event_1 = unittest.mock.Mock(type="message_start", dict=lambda: {"type": "message_start"}) - mock_event_2 = unittest.mock.Mock(type="unknown") + mock_event_1 = unittest.mock.Mock( + type="message_start", + dict=lambda: {"type": "message_start"}, + model_dump=lambda: {"type": "message_start"}, + ) + mock_event_2 = unittest.mock.Mock( + type="unknown", + dict=lambda: {"type": "unknown"}, + model_dump=lambda: {"type": "unknown"}, + ) mock_event_3 = unittest.mock.Mock( - type="metadata", message=unittest.mock.Mock(usage=unittest.mock.Mock(dict=lambda: {"input_tokens": 1})) + type="metadata", + message=unittest.mock.Mock( + usage=unittest.mock.Mock( + dict=lambda: {"input_tokens": 1, "output_tokens": 2}, + model_dump=lambda: {"input_tokens": 1, "output_tokens": 2}, + ) + ), ) mock_stream = unittest.mock.MagicMock() @@ -627,7 +645,10 @@ def test_stream(anthropic_client, model): tru_events = list(response) exp_events = [ {"type": "message_start"}, - {"type": "metadata", "usage": {"input_tokens": 1}}, + { + "type": "metadata", + "usage": {"input_tokens": 1, "output_tokens": 2}, + }, ] assert tru_events == exp_events diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index 566671ce..137b57c8 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -1,7 +1,9 @@ +import os import unittest.mock import boto3 import pytest +from botocore.config import Config as BotocoreConfig from botocore.exceptions import ClientError, EventStreamError import strands @@ -89,6 +91,23 @@ def test__init__default_model_id(bedrock_client): assert tru_model_id == exp_model_id +def test__init__with_default_region(bedrock_client): + """Test that BedrockModel uses the provided region.""" + _ = bedrock_client + default_region = "us-west-2" + + with unittest.mock.patch("strands.models.bedrock.boto3.Session") as mock_session_cls: + with 
unittest.mock.patch("strands.models.bedrock.logger.warning") as mock_warning: + _ = BedrockModel() + mock_session_cls.assert_called_once_with(region_name=default_region) + # Assert that warning logs are emitted + mock_warning.assert_any_call("defaulted to us-west-2 because no region was specified") + mock_warning.assert_any_call( + "issue=<%s> | this behavior will change in an upcoming release", + "https://github.com/strands-agents/sdk-python/issues/238", + ) + + def test__init__with_custom_region(bedrock_client): """Test that BedrockModel uses the provided region.""" _ = bedrock_client @@ -99,12 +118,70 @@ def test__init__with_custom_region(bedrock_client): mock_session_cls.assert_called_once_with(region_name=custom_region) +def test__init__with_environment_variable_region(bedrock_client): + """Test that BedrockModel uses the provided region.""" + _ = bedrock_client + os.environ["AWS_REGION"] = "eu-west-1" + + with unittest.mock.patch("strands.models.bedrock.boto3.Session") as mock_session_cls: + _ = BedrockModel() + mock_session_cls.assert_called_once_with(region_name="eu-west-1") + + def test__init__with_region_and_session_raises_value_error(): """Test that BedrockModel raises ValueError when both region and session are provided.""" with pytest.raises(ValueError): _ = BedrockModel(region_name="us-east-1", boto_session=boto3.Session(region_name="us-east-1")) +def test__init__default_user_agent(bedrock_client): + """Set user agent when no boto_client_config is provided.""" + with unittest.mock.patch("strands.models.bedrock.boto3.Session") as mock_session_cls: + mock_session = mock_session_cls.return_value + _ = BedrockModel() + + # Verify the client was created with the correct config + mock_session.client.assert_called_once() + args, kwargs = mock_session.client.call_args + assert kwargs["service_name"] == "bedrock-runtime" + assert isinstance(kwargs["config"], BotocoreConfig) + assert kwargs["config"].user_agent_extra == "strands-agents" + + +def test__init__with_custom_boto_client_config_no_user_agent(bedrock_client): + """Set user agent when boto_client_config is provided without user_agent_extra.""" + custom_config = BotocoreConfig(read_timeout=900) + + with unittest.mock.patch("strands.models.bedrock.boto3.Session") as mock_session_cls: + mock_session = mock_session_cls.return_value + _ = BedrockModel(boto_client_config=custom_config) + + # Verify the client was created with the correct config + mock_session.client.assert_called_once() + args, kwargs = mock_session.client.call_args + assert kwargs["service_name"] == "bedrock-runtime" + assert isinstance(kwargs["config"], BotocoreConfig) + assert kwargs["config"].user_agent_extra == "strands-agents" + assert kwargs["config"].read_timeout == 900 + + +def test__init__with_custom_boto_client_config_with_user_agent(bedrock_client): + """Append to existing user agent when boto_client_config is provided with user_agent_extra.""" + custom_config = BotocoreConfig(user_agent_extra="existing-agent", read_timeout=900) + + with unittest.mock.patch("strands.models.bedrock.boto3.Session") as mock_session_cls: + mock_session = mock_session_cls.return_value + _ = BedrockModel(boto_client_config=custom_config) + + # Verify the client was created with the correct config + mock_session.client.assert_called_once() + args, kwargs = mock_session.client.call_args + assert kwargs["service_name"] == "bedrock-runtime" + assert isinstance(kwargs["config"], BotocoreConfig) + assert kwargs["config"].user_agent_extra == "existing-agent strands-agents" + assert 
kwargs["config"].read_timeout == 900 + + def test__init__model_config(bedrock_client): _ = bedrock_client @@ -294,9 +371,9 @@ def test_stream(bedrock_client, model): def test_stream_throttling_exception_from_event_stream_error(bedrock_client, model): - error_message = "ThrottlingException - Rate exceeded" + error_message = "Rate exceeded" bedrock_client.converse_stream.side_effect = EventStreamError( - {"Error": {"Message": error_message}}, "ConverseStream" + {"Error": {"Message": error_message, "Code": "ThrottlingException"}}, "ConverseStream" ) request = {"a": 1} @@ -361,7 +438,9 @@ def test_converse(bedrock_client, model, messages, tool_spec, model_id, addition bedrock_client.converse_stream.assert_called_once_with(**request) -def test_converse_input_guardrails(bedrock_client, model, messages, tool_spec, model_id, additional_request_fields): +def test_converse_stream_input_guardrails( + bedrock_client, model, messages, tool_spec, model_id, additional_request_fields +): metadata_event = { "metadata": { "usage": {"inputTokens": 0, "outputTokens": 0, "totalTokens": 0}, @@ -370,7 +449,15 @@ def test_converse_input_guardrails(bedrock_client, model, messages, tool_spec, m "guardrail": { "inputAssessment": { "3e59qlue4hag": { - "wordPolicy": {"customWords": [{"match": "CACTUS", "action": "BLOCKED", "detected": True}]} + "wordPolicy": { + "customWords": [ + { + "match": "CACTUS", + "action": "BLOCKED", + "detected": True, + } + ] + } } } } @@ -395,13 +482,18 @@ def test_converse_input_guardrails(bedrock_client, model, messages, tool_spec, m chunks = model.converse(messages, [tool_spec]) tru_chunks = list(chunks) - exp_chunks = [{"redactContent": {"redactUserContentMessage": "[User input redacted.]"}}, metadata_event] + exp_chunks = [ + {"redactContent": {"redactUserContentMessage": "[User input redacted.]"}}, + metadata_event, + ] assert tru_chunks == exp_chunks bedrock_client.converse_stream.assert_called_once_with(**request) -def test_converse_output_guardrails(bedrock_client, model, messages, tool_spec, model_id, additional_request_fields): +def test_converse_stream_output_guardrails( + bedrock_client, model, messages, tool_spec, model_id, additional_request_fields +): model.update_config(guardrail_redact_input=False, guardrail_redact_output=True) metadata_event = { "metadata": { @@ -413,7 +505,13 @@ def test_converse_output_guardrails(bedrock_client, model, messages, tool_spec, "3e59qlue4hag": [ { "wordPolicy": { - "customWords": [{"match": "CACTUS", "action": "BLOCKED", "detected": True}] + "customWords": [ + { + "match": "CACTUS", + "action": "BLOCKED", + "detected": True, + } + ] }, } ] @@ -440,7 +538,10 @@ def test_converse_output_guardrails(bedrock_client, model, messages, tool_spec, chunks = model.converse(messages, [tool_spec]) tru_chunks = list(chunks) - exp_chunks = [{"redactContent": {"redactAssistantContentMessage": "[Assistant output redacted.]"}}, metadata_event] + exp_chunks = [ + {"redactContent": {"redactAssistantContentMessage": "[Assistant output redacted.]"}}, + metadata_event, + ] assert tru_chunks == exp_chunks bedrock_client.converse_stream.assert_called_once_with(**request) @@ -460,7 +561,13 @@ def test_converse_output_guardrails_redacts_input_and_output( "3e59qlue4hag": [ { "wordPolicy": { - "customWords": [{"match": "CACTUS", "action": "BLOCKED", "detected": True}] + "customWords": [ + { + "match": "CACTUS", + "action": "BLOCKED", + "detected": True, + } + ] }, } ] @@ -510,7 +617,13 @@ def test_converse_output_no_blocked_guardrails_doesnt_redact( "3e59qlue4hag": [ { 
"wordPolicy": { - "customWords": [{"match": "CACTUS", "action": "NONE", "detected": True}] + "customWords": [ + { + "match": "CACTUS", + "action": "NONE", + "detected": True, + } + ] }, } ] @@ -556,7 +669,13 @@ def test_converse_output_no_guardrail_redact( "3e59qlue4hag": [ { "wordPolicy": { - "customWords": [{"match": "CACTUS", "action": "BLOCKED", "detected": True}] + "customWords": [ + { + "match": "CACTUS", + "action": "BLOCKED", + "detected": True, + } + ] }, } ] @@ -580,7 +699,9 @@ def test_converse_output_no_guardrail_redact( } model.update_config( - additional_request_fields=additional_request_fields, guardrail_redact_output=False, guardrail_redact_input=False + additional_request_fields=additional_request_fields, + guardrail_redact_output=False, + guardrail_redact_input=False, ) chunks = model.converse(messages, [tool_spec]) @@ -589,3 +710,322 @@ def test_converse_output_no_guardrail_redact( assert tru_chunks == exp_chunks bedrock_client.converse_stream.assert_called_once_with(**request) + + +def test_stream_with_streaming_false(bedrock_client): + """Test stream method with streaming=False.""" + bedrock_client.converse.return_value = { + "output": {"message": {"role": "assistant", "content": [{"text": "test"}]}}, + "stopReason": "end_turn", + } + expected_events = [ + {"messageStart": {"role": "assistant"}}, + {"contentBlockDelta": {"delta": {"text": "test"}}}, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "end_turn", "additionalModelResponseFields": None}}, + ] + + # Create model and call stream + model = BedrockModel(model_id="test-model", streaming=False) + request = {"modelId": "test-model"} + events = list(model.stream(request)) + + assert expected_events == events + + bedrock_client.converse.assert_called_once() + bedrock_client.converse_stream.assert_not_called() + + +def test_stream_with_streaming_false_and_tool_use(bedrock_client): + """Test stream method with streaming=False.""" + bedrock_client.converse.return_value = { + "output": { + "message": { + "role": "assistant", + "content": [{"toolUse": {"toolUseId": "123", "name": "dummyTool", "input": {"hello": "world!"}}}], + } + }, + "stopReason": "tool_use", + } + + expected_events = [ + {"messageStart": {"role": "assistant"}}, + {"contentBlockStart": {"start": {"toolUse": {"toolUseId": "123", "name": "dummyTool"}}}}, + {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"hello": "world!"}'}}}}, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "tool_use", "additionalModelResponseFields": None}}, + ] + + # Create model and call stream + model = BedrockModel(model_id="test-model", streaming=False) + request = {"modelId": "test-model"} + events = list(model.stream(request)) + + assert expected_events == events + + bedrock_client.converse.assert_called_once() + bedrock_client.converse_stream.assert_not_called() + + +def test_stream_with_streaming_false_and_reasoning(bedrock_client): + """Test stream method with streaming=False.""" + bedrock_client.converse.return_value = { + "output": { + "message": { + "role": "assistant", + "content": [ + { + "reasoningContent": { + "reasoningText": {"text": "Thinking really hard....", "signature": "123"}, + } + } + ], + } + }, + "stopReason": "tool_use", + } + + expected_events = [ + {"messageStart": {"role": "assistant"}}, + {"contentBlockDelta": {"delta": {"reasoningContent": {"text": "Thinking really hard...."}}}}, + {"contentBlockDelta": {"delta": {"reasoningContent": {"signature": "123"}}}}, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": 
"tool_use", "additionalModelResponseFields": None}}, + ] + + # Create model and call stream + model = BedrockModel(model_id="test-model", streaming=False) + request = {"modelId": "test-model"} + events = list(model.stream(request)) + + assert expected_events == events + + # Verify converse was called + bedrock_client.converse.assert_called_once() + bedrock_client.converse_stream.assert_not_called() + + +def test_converse_and_reasoning_no_signature(bedrock_client): + """Test stream method with streaming=False.""" + bedrock_client.converse.return_value = { + "output": { + "message": { + "role": "assistant", + "content": [ + { + "reasoningContent": { + "reasoningText": {"text": "Thinking really hard...."}, + } + } + ], + } + }, + "stopReason": "tool_use", + } + + expected_events = [ + {"messageStart": {"role": "assistant"}}, + {"contentBlockDelta": {"delta": {"reasoningContent": {"text": "Thinking really hard...."}}}}, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "tool_use", "additionalModelResponseFields": None}}, + ] + + # Create model and call stream + model = BedrockModel(model_id="test-model", streaming=False) + request = {"modelId": "test-model"} + events = list(model.stream(request)) + + assert expected_events == events + + bedrock_client.converse.assert_called_once() + bedrock_client.converse_stream.assert_not_called() + + +def test_stream_with_streaming_false_with_metrics_and_usage(bedrock_client): + """Test stream method with streaming=False.""" + bedrock_client.converse.return_value = { + "output": {"message": {"role": "assistant", "content": [{"text": "test"}]}}, + "usage": {"inputTokens": 1234, "outputTokens": 1234, "totalTokens": 2468}, + "metrics": {"latencyMs": 1234}, + "stopReason": "tool_use", + } + + expected_events = [ + {"messageStart": {"role": "assistant"}}, + {"contentBlockDelta": {"delta": {"text": "test"}}}, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "tool_use", "additionalModelResponseFields": None}}, + { + "metadata": { + "usage": {"inputTokens": 1234, "outputTokens": 1234, "totalTokens": 2468}, + "metrics": {"latencyMs": 1234}, + } + }, + ] + + # Create model and call stream + model = BedrockModel(model_id="test-model", streaming=False) + request = {"modelId": "test-model"} + events = list(model.stream(request)) + + assert expected_events == events + + # Verify converse was called + bedrock_client.converse.assert_called_once() + bedrock_client.converse_stream.assert_not_called() + + +def test_converse_input_guardrails(bedrock_client): + """Test stream method with streaming=False.""" + bedrock_client.converse.return_value = { + "output": {"message": {"role": "assistant", "content": [{"text": "test"}]}}, + "trace": { + "guardrail": { + "inputAssessment": { + "3e59qlue4hag": { + "wordPolicy": {"customWords": [{"match": "CACTUS", "action": "BLOCKED", "detected": True}]} + } + } + } + }, + "stopReason": "end_turn", + } + + expected_events = [ + {"messageStart": {"role": "assistant"}}, + {"contentBlockDelta": {"delta": {"text": "test"}}}, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "end_turn", "additionalModelResponseFields": None}}, + { + "metadata": { + "trace": { + "guardrail": { + "inputAssessment": { + "3e59qlue4hag": { + "wordPolicy": { + "customWords": [{"match": "CACTUS", "action": "BLOCKED", "detected": True}] + } + } + } + } + } + } + }, + {"redactContent": {"redactUserContentMessage": "[User input redacted.]"}}, + ] + + # Create model and call stream + model = BedrockModel(model_id="test-model", 
streaming=False) + request = {"modelId": "test-model"} + events = list(model.stream(request)) + + assert expected_events == events + + bedrock_client.converse.assert_called_once() + bedrock_client.converse_stream.assert_not_called() + + +def test_converse_output_guardrails(bedrock_client): + """Test stream method with streaming=False.""" + bedrock_client.converse.return_value = { + "output": {"message": {"role": "assistant", "content": [{"text": "test"}]}}, + "trace": { + "guardrail": { + "outputAssessments": { + "3e59qlue4hag": [ + { + "wordPolicy": {"customWords": [{"match": "CACTUS", "action": "BLOCKED", "detected": True}]}, + } + ] + }, + } + }, + "stopReason": "end_turn", + } + + expected_events = [ + {"messageStart": {"role": "assistant"}}, + {"contentBlockDelta": {"delta": {"text": "test"}}}, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "end_turn", "additionalModelResponseFields": None}}, + { + "metadata": { + "trace": { + "guardrail": { + "outputAssessments": { + "3e59qlue4hag": [ + { + "wordPolicy": { + "customWords": [{"match": "CACTUS", "action": "BLOCKED", "detected": True}] + } + } + ] + } + } + } + } + }, + {"redactContent": {"redactUserContentMessage": "[User input redacted.]"}}, + ] + + model = BedrockModel(model_id="test-model", streaming=False) + request = {"modelId": "test-model"} + events = list(model.stream(request)) + + assert expected_events == events + + bedrock_client.converse.assert_called_once() + bedrock_client.converse_stream.assert_not_called() + + +def test_converse_output_guardrails_redacts_output(bedrock_client): + """Test stream method with streaming=False.""" + bedrock_client.converse.return_value = { + "output": {"message": {"role": "assistant", "content": [{"text": "test"}]}}, + "trace": { + "guardrail": { + "outputAssessments": { + "3e59qlue4hag": [ + { + "wordPolicy": {"customWords": [{"match": "CACTUS", "action": "BLOCKED", "detected": True}]}, + } + ] + }, + } + }, + "stopReason": "end_turn", + } + + expected_events = [ + {"messageStart": {"role": "assistant"}}, + {"contentBlockDelta": {"delta": {"text": "test"}}}, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "end_turn", "additionalModelResponseFields": None}}, + { + "metadata": { + "trace": { + "guardrail": { + "outputAssessments": { + "3e59qlue4hag": [ + { + "wordPolicy": { + "customWords": [{"match": "CACTUS", "action": "BLOCKED", "detected": True}] + } + } + ] + } + } + } + } + }, + {"redactContent": {"redactUserContentMessage": "[User input redacted.]"}}, + ] + + model = BedrockModel(model_id="test-model", streaming=False) + request = {"modelId": "test-model"} + events = list(model.stream(request)) + + assert expected_events == events + + bedrock_client.converse.assert_called_once() + bedrock_client.converse_stream.assert_not_called() diff --git a/tests/strands/models/test_litellm.py b/tests/strands/models/test_litellm.py index 5d4d9b40..528d1498 100644 --- a/tests/strands/models/test_litellm.py +++ b/tests/strands/models/test_litellm.py @@ -1,4 +1,3 @@ -import json import unittest.mock import pytest @@ -8,9 +7,14 @@ @pytest.fixture -def litellm_client(): +def litellm_client_cls(): with unittest.mock.patch.object(strands.models.litellm.litellm, "LiteLLM") as mock_client_cls: - yield mock_client_cls.return_value + yield mock_client_cls + + +@pytest.fixture +def litellm_client(litellm_client_cls): + return litellm_client_cls.return_value @pytest.fixture @@ -35,15 +39,15 @@ def system_prompt(): return "s1" -def test__init__model_configs(litellm_client, model_id): - 
_ = litellm_client +def test__init__(litellm_client_cls, model_id): + model = LiteLLMModel({"api_key": "k1"}, model_id=model_id, params={"max_tokens": 1}) - model = LiteLLMModel(model_id=model_id, params={"max_tokens": 1}) + tru_config = model.get_config() + exp_config = {"model_id": "m1", "params": {"max_tokens": 1}} - tru_max_tokens = model.get_config().get("params") - exp_max_tokens = {"max_tokens": 1} + assert tru_config == exp_config - assert tru_max_tokens == exp_max_tokens + litellm_client_cls.assert_called_once_with(api_key="k1") def test_update_config(model, model_id): @@ -55,513 +59,47 @@ def test_update_config(model, model_id): assert tru_model_id == exp_model_id -def test_format_request_default(model, messages, model_id): - tru_request = model.format_request(messages) - exp_request = { - "messages": [{"role": "user", "content": [{"type": "text", "text": "test"}]}], - "model": model_id, - "stream": True, - "stream_options": { - "include_usage": True, - }, - "tools": [], - } - - assert tru_request == exp_request - - -def test_format_request_with_params(model, messages, model_id): - model.update_config(params={"max_tokens": 1}) - - tru_request = model.format_request(messages) - exp_request = { - "messages": [{"role": "user", "content": [{"type": "text", "text": "test"}]}], - "model": model_id, - "stream": True, - "stream_options": { - "include_usage": True, - }, - "tools": [], - "max_tokens": 1, - } - - assert tru_request == exp_request - - -def test_format_request_with_system_prompt(model, messages, model_id, system_prompt): - tru_request = model.format_request(messages, system_prompt=system_prompt) - exp_request = { - "messages": [ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": [{"type": "text", "text": "test"}]}, - ], - "model": model_id, - "stream": True, - "stream_options": { - "include_usage": True, - }, - "tools": [], - } - - assert tru_request == exp_request - - -def test_format_request_with_image(model, model_id): - messages = [ - { - "role": "user", - "content": [ - { - "image": { - "format": "jpg", - "source": {"bytes": b"base64encodedimage"}, - }, - }, - ], - }, - ] - - tru_request = model.format_request(messages) - exp_request = { - "messages": [ - { - "role": "user", - "content": [ - { - "image_url": { - "detail": "auto", - "format": "image/jpeg", - "url": "data:image/jpeg;base64,base64encodedimage", - }, - "type": "image_url", - }, - ], - }, - ], - "model": model_id, - "stream": True, - "stream_options": { - "include_usage": True, - }, - "tools": [], - } - - assert tru_request == exp_request - - -def test_format_request_with_reasoning(model, model_id): - messages = [ - { - "role": "user", - "content": [ - { - "reasoningContent": { - "reasoningText": { - "signature": "reasoning_signature", - "text": "reasoning_text", - }, - }, - }, - ], - }, - ] - - tru_request = model.format_request(messages) - exp_request = { - "messages": [ +@pytest.mark.parametrize( + "content, exp_result", + [ + # Case 1: Thinking + ( { - "role": "user", - "content": [ - { + "reasoningContent": { + "reasoningText": { "signature": "reasoning_signature", - "thinking": "reasoning_text", - "type": "thinking", - }, - ], - }, - ], - "model": model_id, - "stream": True, - "stream_options": { - "include_usage": True, - }, - "tools": [], - } - - assert tru_request == exp_request - - -def test_format_request_with_video(model, model_id): - messages = [ - { - "role": "user", - "content": [ - { - "video": { - "source": {"bytes": "base64encodedvideo"}, + "text": 
"reasoning_text", }, }, - ], - }, - ] - - tru_request = model.format_request(messages) - exp_request = { - "messages": [ - { - "role": "user", - "content": [ - { - "type": "video_url", - "video_url": { - "detail": "auto", - "url": "base64encodedvideo", - }, - }, - ], }, - ], - "model": model_id, - "stream": True, - "stream_options": { - "include_usage": True, - }, - "tools": [], - } - - assert tru_request == exp_request - - -def test_format_request_with_other(model, model_id): - messages = [ - { - "role": "user", - "content": [{"other": {"a": 1}}], - }, - ] - - tru_request = model.format_request(messages) - exp_request = { - "messages": [ { - "role": "user", - "content": [ - { - "text": json.dumps({"other": {"a": 1}}), - "type": "text", - }, - ], + "signature": "reasoning_signature", + "thinking": "reasoning_text", + "type": "thinking", }, - ], - "model": model_id, - "stream": True, - "stream_options": { - "include_usage": True, - }, - "tools": [], - } - - assert tru_request == exp_request - - -def test_format_request_with_tool_result(model, model_id): - messages = [ - { - "role": "user", - "content": [{"toolResult": {"toolUseId": "c1", "status": "success", "content": [{"value": 4}]}}], - } - ] - - tru_request = model.format_request(messages) - exp_request = { - "messages": [ + ), + # Case 2: Video + ( { - "content": json.dumps( - { - "content": [{"value": 4}], - "status": "success", - } - ), - "role": "tool", - "tool_call_id": "c1", - }, - ], - "model": model_id, - "stream": True, - "stream_options": { - "include_usage": True, - }, - "tools": [], - } - - assert tru_request == exp_request - - -def test_format_request_with_tool_use(model, model_id): - messages = [ - { - "role": "assistant", - "content": [ - { - "toolUse": { - "toolUseId": "c1", - "name": "calculator", - "input": {"expression": "2+2"}, - }, + "video": { + "source": {"bytes": "base64encodedvideo"}, }, - ], - }, - ] - - tru_request = model.format_request(messages) - exp_request = { - "messages": [ - { - "content": [], - "role": "assistant", - "tool_calls": [ - { - "function": { - "name": "calculator", - "arguments": '{"expression": "2+2"}', - }, - "id": "c1", - "type": "function", - } - ], - } - ], - "model": model_id, - "stream": True, - "stream_options": { - "include_usage": True, - }, - "tools": [], - } - - assert tru_request == exp_request - - -def test_format_request_with_tool_specs(model, messages, model_id): - tool_specs = [ - { - "name": "calculator", - "description": "Calculate mathematical expressions", - "inputSchema": { - "json": {"type": "object", "properties": {"expression": {"type": "string"}}, "required": ["expression"]} }, - } - ] - - tru_request = model.format_request(messages, tool_specs) - exp_request = { - "messages": [{"role": "user", "content": [{"type": "text", "text": "test"}]}], - "model": model_id, - "stream": True, - "stream_options": { - "include_usage": True, - }, - "tools": [ { - "type": "function", - "function": { - "name": "calculator", - "description": "Calculate mathematical expressions", - "parameters": { - "type": "object", - "properties": {"expression": {"type": "string"}}, - "required": ["expression"], - }, + "type": "video_url", + "video_url": { + "detail": "auto", + "url": "base64encodedvideo", }, - } - ], - } - - assert tru_request == exp_request - - -def test_format_chunk_message_start(model): - event = {"chunk_type": "message_start"} - - tru_chunk = model.format_chunk(event) - exp_chunk = {"messageStart": {"role": "assistant"}} - - assert tru_chunk == exp_chunk - - -def 
test_format_chunk_content_start_text(model): - event = {"chunk_type": "content_start", "data_type": "text"} - - tru_chunk = model.format_chunk(event) - exp_chunk = {"contentBlockStart": {"start": {}}} - - assert tru_chunk == exp_chunk - - -def test_format_chunk_content_start_tool(model): - mock_tool_use = unittest.mock.Mock() - mock_tool_use.function.name = "calculator" - mock_tool_use.id = "c1" - - event = {"chunk_type": "content_start", "data_type": "tool", "data": mock_tool_use} - - tru_chunk = model.format_chunk(event) - exp_chunk = {"contentBlockStart": {"start": {"toolUse": {"name": "calculator", "toolUseId": "c1"}}}} - - assert tru_chunk == exp_chunk - - -def test_format_chunk_content_delta_text(model): - event = {"chunk_type": "content_delta", "data_type": "text", "data": "Hello"} - - tru_chunk = model.format_chunk(event) - exp_chunk = {"contentBlockDelta": {"delta": {"text": "Hello"}}} - - assert tru_chunk == exp_chunk - - -def test_format_chunk_content_delta_tool(model): - event = { - "chunk_type": "content_delta", - "data_type": "tool", - "data": unittest.mock.Mock(function=unittest.mock.Mock(arguments='{"expression": "2+2"}')), - } - - tru_chunk = model.format_chunk(event) - exp_chunk = {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"expression": "2+2"}'}}}} - - assert tru_chunk == exp_chunk - - -def test_format_chunk_content_stop(model): - event = {"chunk_type": "content_stop"} - - tru_chunk = model.format_chunk(event) - exp_chunk = {"contentBlockStop": {}} - - assert tru_chunk == exp_chunk - - -def test_format_chunk_message_stop_end_turn(model): - event = {"chunk_type": "message_stop", "data": "stop"} - - tru_chunk = model.format_chunk(event) - exp_chunk = {"messageStop": {"stopReason": "end_turn"}} - - assert tru_chunk == exp_chunk - - -def test_format_chunk_message_stop_tool_use(model): - event = {"chunk_type": "message_stop", "data": "tool_calls"} - - tru_chunk = model.format_chunk(event) - exp_chunk = {"messageStop": {"stopReason": "tool_use"}} - - assert tru_chunk == exp_chunk - - -def test_format_chunk_message_stop_max_tokens(model): - event = {"chunk_type": "message_stop", "data": "length"} - - tru_chunk = model.format_chunk(event) - exp_chunk = {"messageStop": {"stopReason": "max_tokens"}} - - assert tru_chunk == exp_chunk - - -def test_format_chunk_metadata(model): - event = { - "chunk_type": "metadata", - "data": unittest.mock.Mock(prompt_tokens=100, completion_tokens=50, total_tokens=150), - } - - tru_chunk = model.format_chunk(event) - exp_chunk = { - "metadata": { - "usage": { - "inputTokens": 100, - "outputTokens": 50, - "totalTokens": 150, - }, - "metrics": { - "latencyMs": 0, }, - }, - } - - assert tru_chunk == exp_chunk - - -def test_format_chunk_other(model): - event = {"chunk_type": "other"} - - with pytest.raises(RuntimeError, match="chunk_type= | unknown type"): - model.format_chunk(event) - - -def test_stream(litellm_client, model): - mock_tool_call_1_part_1 = unittest.mock.Mock(index=0) - mock_tool_call_2_part_1 = unittest.mock.Mock(index=1) - mock_delta_1 = unittest.mock.Mock( - content="I'll calculate", tool_calls=[mock_tool_call_1_part_1, mock_tool_call_2_part_1] - ) - - mock_tool_call_1_part_2 = unittest.mock.Mock(index=0) - mock_tool_call_2_part_2 = unittest.mock.Mock(index=1) - mock_delta_2 = unittest.mock.Mock( - content="that for you", tool_calls=[mock_tool_call_1_part_2, mock_tool_call_2_part_2] - ) - - mock_event_1 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta_1)]) - mock_event_2 = 
unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta_2)]) - mock_event_3 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="tool_calls")]) - mock_event_4 = unittest.mock.Mock() - - litellm_client.chat.completions.create.return_value = iter([mock_event_1, mock_event_2, mock_event_3, mock_event_4]) - - request = {"model": "m1", "messages": [{"role": "user", "content": [{"type": "text", "text": "calculate 2+2"}]}]} - response = model.stream(request) - - tru_events = list(response) - exp_events = [ - {"chunk_type": "message_start"}, - {"chunk_type": "content_start", "data_type": "text"}, - {"chunk_type": "content_delta", "data_type": "text", "data": "I'll calculate"}, - {"chunk_type": "content_delta", "data_type": "text", "data": "that for you"}, - {"chunk_type": "content_stop", "data_type": "text"}, - {"chunk_type": "content_start", "data_type": "tool", "data": mock_tool_call_1_part_1}, - {"chunk_type": "content_delta", "data_type": "tool", "data": mock_tool_call_1_part_2}, - {"chunk_type": "content_stop", "data_type": "tool"}, - {"chunk_type": "content_start", "data_type": "tool", "data": mock_tool_call_2_part_1}, - {"chunk_type": "content_delta", "data_type": "tool", "data": mock_tool_call_2_part_2}, - {"chunk_type": "content_stop", "data_type": "tool"}, - {"chunk_type": "message_stop", "data": "tool_calls"}, - {"chunk_type": "metadata", "data": mock_event_4.usage}, - ] - - assert tru_events == exp_events - litellm_client.chat.completions.create.assert_called_once_with(**request) - - -def test_stream_empty(litellm_client, model): - mock_delta = unittest.mock.Mock(content=None, tool_calls=None) - - mock_event_1 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta)]) - mock_event_2 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="stop")]) - mock_event_3 = unittest.mock.Mock(spec=[]) - - litellm_client.chat.completions.create.return_value = iter([mock_event_1, mock_event_2, mock_event_3]) - - request = {"model": "m1", "messages": [{"role": "user", "content": []}]} - response = model.stream(request) - - tru_events = list(response) - exp_events = [ - {"chunk_type": "message_start"}, - {"chunk_type": "content_start", "data_type": "text"}, - {"chunk_type": "content_stop", "data_type": "text"}, - {"chunk_type": "message_stop", "data": "stop"}, - ] - - assert tru_events == exp_events - litellm_client.chat.completions.create.assert_called_once_with(**request) + ), + # Case 3: Text + ( + {"text": "hello"}, + {"type": "text", "text": "hello"}, + ), + ], +) +def test_format_request_message_content(content, exp_result): + tru_result = LiteLLMModel.format_request_message_content(content) + assert tru_result == exp_result diff --git a/tests/test_llamaapi.py b/tests/strands/models/test_llamaapi.py similarity index 94% rename from tests/test_llamaapi.py rename to tests/strands/models/test_llamaapi.py index 5d7cd0f4..309dac2e 100644 --- a/tests/test_llamaapi.py +++ b/tests/strands/models/test_llamaapi.py @@ -1,4 +1,4 @@ -import json +# Copyright (c) Meta Platforms, Inc. 
and affiliates import unittest.mock import pytest @@ -144,7 +144,7 @@ def test_format_request_with_tool_result(model, model_id): "toolResult": { "toolUseId": "c1", "status": "success", - "content": [{"value": 4}], + "content": [{"text": "4"}, {"json": ["4"]}], } } ], @@ -155,12 +155,7 @@ def test_format_request_with_tool_result(model, model_id): exp_request = { "messages": [ { - "content": json.dumps( - { - "content": [{"value": 4}], - "status": "success", - } - ), + "content": [{"text": "4", "type": "text"}, {"text": '["4"]', "type": "text"}], "role": "tool", "tool_call_id": "c1", }, @@ -233,6 +228,18 @@ def test_format_request_with_empty_content(model, model_id): assert tru_request == exp_request +def test_format_request_with_unsupported_type(model): + messages = [ + { + "role": "user", + "content": [{"unsupported": {}}], + }, + ] + + with pytest.raises(TypeError, match="content_type= | unsupported type"): + model.format_request(messages) + + def test_format_chunk_message_start(model): event = {"chunk_type": "message_start"} diff --git a/tests/strands/models/test_ollama.py b/tests/strands/models/test_ollama.py index d3bd8f7f..fe590dff 100644 --- a/tests/strands/models/test_ollama.py +++ b/tests/strands/models/test_ollama.py @@ -148,81 +148,24 @@ def test_format_request_with_tool_use(model, model_id): def test_format_request_with_tool_result(model, model_id): messages: Messages = [ { - "role": "tool", - "content": [{"toolResult": {"toolUseId": "calculator", "status": "success", "content": [{"text": "4"}]}}], - } - ] - - tru_request = model.format_request(messages) - exp_request = { - "messages": [ - { - "role": "tool", - "content": json.dumps( - { - "name": "calculator", - "result": "4", - "status": "success", - } - ), - } - ], - "model": model_id, - "options": {}, - "stream": True, - "tools": [], - } - - assert tru_request == exp_request - - -def test_format_request_with_tool_result_json(model, model_id): - messages: Messages = [ - { - "role": "tool", - "content": [ - {"toolResult": {"toolUseId": "calculator", "status": "success", "content": [{"json": {"result": 4}}]}} - ], - } - ] - - tru_request = model.format_request(messages) - exp_request = { - "messages": [ - { - "role": "tool", - "content": json.dumps( - { - "name": "calculator", - "result": {"result": 4}, - "status": "success", - } - ), - } - ], - "model": model_id, - "options": {}, - "stream": True, - "tools": [], - } - - assert tru_request == exp_request - - -def test_format_request_with_tool_result_image(model, model_id): - messages: Messages = [ - { - "role": "tool", + "role": "user", "content": [ { "toolResult": { - "toolUseId": "image_generator", + "toolUseId": "calculator", "status": "success", - "content": [{"image": {"source": {"bytes": "base64encodedimage"}}}], - } - } + "content": [ + {"text": "4"}, + {"image": {"source": {"bytes": b"image"}}}, + {"json": ["4"]}, + ], + }, + }, + { + "text": "see results", + }, ], - } + }, ] tru_request = model.format_request(messages) @@ -230,46 +173,20 @@ def test_format_request_with_tool_result_image(model, model_id): "messages": [ { "role": "tool", - "content": json.dumps( - { - "name": "image_generator", - "result": "see images", - "status": "success", - } - ), - "images": ["base64encodedimage"], - } - ], - "model": model_id, - "options": {}, - "stream": True, - "tools": [], - } - - assert tru_request == exp_request - - -def test_format_request_with_tool_result_other(model, model_id): - messages = [ - { - "role": "tool", - "content": [{"toolResult": {"toolUseId": "other", 
"status": "success", "content": {"other": {"a": 1}}}}], - } - ] - - tru_request = model.format_request(messages) - exp_request = { - "messages": [ + "content": "4", + }, { "role": "tool", - "content": json.dumps( - { - "name": "other", - "result": {"other": {"a": 1}}, - "status": "success", - } - ), - } + "images": [b"image"], + }, + { + "role": "tool", + "content": '["4"]', + }, + { + "role": "user", + "content": "see results", + }, ], "model": model_id, "options": {}, @@ -280,29 +197,16 @@ def test_format_request_with_tool_result_other(model, model_id): assert tru_request == exp_request -def test_format_request_with_other(model, model_id): +def test_format_request_with_unsupported_type(model): messages = [ { "role": "user", - "content": [{"other": {"a": 1}}], - } + "content": [{"unsupported": {}}], + }, ] - tru_request = model.format_request(messages) - exp_request = { - "messages": [ - { - "role": "user", - "content": json.dumps({"other": {"a": 1}}), - } - ], - "model": model_id, - "options": {}, - "stream": True, - "tools": [], - } - - assert tru_request == exp_request + with pytest.raises(TypeError, match="content_type= | unsupported type"): + model.format_request(messages) def test_format_request_with_tool_specs(model, messages, model_id): diff --git a/tests/strands/models/test_openai.py b/tests/strands/models/test_openai.py new file mode 100644 index 00000000..4c1f8528 --- /dev/null +++ b/tests/strands/models/test_openai.py @@ -0,0 +1,175 @@ +import unittest.mock + +import pytest + +import strands +from strands.models.openai import OpenAIModel + + +@pytest.fixture +def openai_client_cls(): + with unittest.mock.patch.object(strands.models.openai.openai, "OpenAI") as mock_client_cls: + yield mock_client_cls + + +@pytest.fixture +def openai_client(openai_client_cls): + return openai_client_cls.return_value + + +@pytest.fixture +def model_id(): + return "m1" + + +@pytest.fixture +def model(openai_client, model_id): + _ = openai_client + + return OpenAIModel(model_id=model_id) + + +@pytest.fixture +def messages(): + return [{"role": "user", "content": [{"text": "test"}]}] + + +@pytest.fixture +def system_prompt(): + return "s1" + + +def test__init__(openai_client_cls, model_id): + model = OpenAIModel({"api_key": "k1"}, model_id=model_id, params={"max_tokens": 1}) + + tru_config = model.get_config() + exp_config = {"model_id": "m1", "params": {"max_tokens": 1}} + + assert tru_config == exp_config + + openai_client_cls.assert_called_once_with(api_key="k1") + + +def test_update_config(model, model_id): + model.update_config(model_id=model_id) + + tru_model_id = model.get_config().get("model_id") + exp_model_id = model_id + + assert tru_model_id == exp_model_id + + +def test_stream(openai_client, model): + mock_tool_call_1_part_1 = unittest.mock.Mock(index=0) + mock_tool_call_2_part_1 = unittest.mock.Mock(index=1) + mock_delta_1 = unittest.mock.Mock( + content="I'll calculate", tool_calls=[mock_tool_call_1_part_1, mock_tool_call_2_part_1] + ) + + mock_tool_call_1_part_2 = unittest.mock.Mock(index=0) + mock_tool_call_2_part_2 = unittest.mock.Mock(index=1) + mock_delta_2 = unittest.mock.Mock( + content="that for you", tool_calls=[mock_tool_call_1_part_2, mock_tool_call_2_part_2] + ) + + mock_delta_3 = unittest.mock.Mock(content="", tool_calls=None) + + mock_event_1 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta_1)]) + mock_event_2 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta_2)]) + mock_event_3 = 
unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="tool_calls", delta=mock_delta_3)]) + mock_event_4 = unittest.mock.Mock() + + openai_client.chat.completions.create.return_value = iter([mock_event_1, mock_event_2, mock_event_3, mock_event_4]) + + request = {"model": "m1", "messages": [{"role": "user", "content": [{"type": "text", "text": "calculate 2+2"}]}]} + response = model.stream(request) + + tru_events = list(response) + exp_events = [ + {"chunk_type": "message_start"}, + {"chunk_type": "content_start", "data_type": "text"}, + {"chunk_type": "content_delta", "data_type": "text", "data": "I'll calculate"}, + {"chunk_type": "content_delta", "data_type": "text", "data": "that for you"}, + {"chunk_type": "content_stop", "data_type": "text"}, + {"chunk_type": "content_start", "data_type": "tool", "data": mock_tool_call_1_part_1}, + {"chunk_type": "content_delta", "data_type": "tool", "data": mock_tool_call_1_part_1}, + {"chunk_type": "content_delta", "data_type": "tool", "data": mock_tool_call_1_part_2}, + {"chunk_type": "content_stop", "data_type": "tool"}, + {"chunk_type": "content_start", "data_type": "tool", "data": mock_tool_call_2_part_1}, + {"chunk_type": "content_delta", "data_type": "tool", "data": mock_tool_call_2_part_1}, + {"chunk_type": "content_delta", "data_type": "tool", "data": mock_tool_call_2_part_2}, + {"chunk_type": "content_stop", "data_type": "tool"}, + {"chunk_type": "message_stop", "data": "tool_calls"}, + {"chunk_type": "metadata", "data": mock_event_4.usage}, + ] + + assert tru_events == exp_events + openai_client.chat.completions.create.assert_called_once_with(**request) + + +def test_stream_empty(openai_client, model): + mock_delta = unittest.mock.Mock(content=None, tool_calls=None) + mock_usage = unittest.mock.Mock(prompt_tokens=0, completion_tokens=0, total_tokens=0) + + mock_event_1 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta)]) + mock_event_2 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="stop", delta=mock_delta)]) + mock_event_3 = unittest.mock.Mock() + mock_event_4 = unittest.mock.Mock(usage=mock_usage) + + openai_client.chat.completions.create.return_value = iter([mock_event_1, mock_event_2, mock_event_3, mock_event_4]) + + request = {"model": "m1", "messages": [{"role": "user", "content": []}]} + response = model.stream(request) + + tru_events = list(response) + exp_events = [ + {"chunk_type": "message_start"}, + {"chunk_type": "content_start", "data_type": "text"}, + {"chunk_type": "content_stop", "data_type": "text"}, + {"chunk_type": "message_stop", "data": "stop"}, + {"chunk_type": "metadata", "data": mock_usage}, + ] + + assert tru_events == exp_events + openai_client.chat.completions.create.assert_called_once_with(**request) + + +def test_stream_with_empty_choices(openai_client, model): + mock_delta = unittest.mock.Mock(content="content", tool_calls=None) + mock_usage = unittest.mock.Mock(prompt_tokens=10, completion_tokens=20, total_tokens=30) + + # Event with no choices attribute + mock_event_1 = unittest.mock.Mock(spec=[]) + + # Event with empty choices list + mock_event_2 = unittest.mock.Mock(choices=[]) + + # Valid event with content + mock_event_3 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta)]) + + # Event with finish reason + mock_event_4 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="stop", delta=mock_delta)]) + + # Final event with usage info + mock_event_5 = unittest.mock.Mock(usage=mock_usage) + + 
openai_client.chat.completions.create.return_value = iter( + [mock_event_1, mock_event_2, mock_event_3, mock_event_4, mock_event_5] + ) + + request = {"model": "m1", "messages": [{"role": "user", "content": ["test"]}]} + response = model.stream(request) + + tru_events = list(response) + exp_events = [ + {"chunk_type": "message_start"}, + {"chunk_type": "content_start", "data_type": "text"}, + {"chunk_type": "content_delta", "data_type": "text", "data": "content"}, + {"chunk_type": "content_delta", "data_type": "text", "data": "content"}, + {"chunk_type": "content_stop", "data_type": "text"}, + {"chunk_type": "message_stop", "data": "stop"}, + {"chunk_type": "metadata", "data": mock_usage}, + ] + + assert tru_events == exp_events + openai_client.chat.completions.create.assert_called_once_with(**request) diff --git a/tests/strands/telemetry/test_metrics.py b/tests/strands/telemetry/test_metrics.py index 4e84f0fd..215e1efd 100644 --- a/tests/strands/telemetry/test_metrics.py +++ b/tests/strands/telemetry/test_metrics.py @@ -1,9 +1,13 @@ import dataclasses import unittest +from unittest import mock import pytest +from opentelemetry.metrics._internal import _ProxyMeter +from opentelemetry.sdk.metrics import MeterProvider import strands +from strands.telemetry import MetricsClient from strands.types.streaming import Metrics, Usage @@ -117,6 +121,37 @@ def test_trace_end(mock_time, end_time, trace): assert tru_end_time == exp_end_time +@pytest.fixture +def mock_get_meter_provider(): + with mock.patch("strands.telemetry.metrics.metrics_api.get_meter_provider") as mock_get_meter_provider: + MetricsClient._instance = None + meter_provider_mock = mock.MagicMock(spec=MeterProvider) + + mock_meter = mock.MagicMock() + mock_create_counter = mock.MagicMock() + mock_meter.create_counter.return_value = mock_create_counter + + mock_create_histogram = mock.MagicMock() + mock_meter.create_histogram.return_value = mock_create_histogram + meter_provider_mock.get_meter.return_value = mock_meter + + mock_get_meter_provider.return_value = meter_provider_mock + + yield mock_get_meter_provider + + +@pytest.fixture +def mock_sdk_meter_provider(): + with mock.patch("strands.telemetry.metrics.metrics_sdk.MeterProvider") as mock_meter_provider: + yield mock_meter_provider + + +@pytest.fixture +def mock_resource(): + with mock.patch("opentelemetry.sdk.resources.Resource") as mock_resource: + yield mock_resource + + def test_trace_add_child(child_trace, trace): trace.add_child(child_trace) @@ -162,11 +197,14 @@ def test_trace_to_dict(trace): @pytest.mark.parametrize("success", [True, False]) -def test_tool_metrics_add_call(success, tool, tool_metrics): +def test_tool_metrics_add_call(success, tool, tool_metrics, mock_get_meter_provider): tool = dict(tool, **{"name": "updated"}) duration = 1 + metrics_client = MetricsClient() - tool_metrics.add_call(tool, duration, success) + attributes = {"foo": "bar"} + + tool_metrics.add_call(tool, duration, success, metrics_client, attributes=attributes) tru_attrs = dataclasses.asdict(tool_metrics) exp_attrs = { @@ -177,12 +215,17 @@ def test_tool_metrics_add_call(success, tool, tool_metrics): "total_time": duration, } + mock_get_meter_provider.return_value.get_meter.assert_called() + metrics_client.tool_call_count.add.assert_called_with(1, attributes=attributes) + metrics_client.tool_duration.record.assert_called_with(duration, attributes=attributes) + if success: + metrics_client.tool_success_count.add.assert_called_with(1, attributes=attributes) assert tru_attrs == exp_attrs 
@unittest.mock.patch.object(strands.telemetry.metrics.time, "time") @unittest.mock.patch.object(strands.telemetry.metrics.uuid, "uuid4") -def test_event_loop_metrics_start_cycle(mock_uuid4, mock_time, event_loop_metrics): +def test_event_loop_metrics_start_cycle(mock_uuid4, mock_time, event_loop_metrics, mock_get_meter_provider): mock_time.return_value = 1 mock_uuid4.return_value = "i1" @@ -192,6 +235,8 @@ def test_event_loop_metrics_start_cycle(mock_uuid4, mock_time, event_loop_metric tru_attrs = {"cycle_count": event_loop_metrics.cycle_count, "traces": event_loop_metrics.traces} exp_attrs = {"cycle_count": 1, "traces": [tru_cycle_trace]} + mock_get_meter_provider.return_value.get_meter.assert_called() + event_loop_metrics._metrics_client.event_loop_cycle_count.add.assert_called() assert ( tru_start_time == exp_start_time and tru_cycle_trace.to_dict() == exp_cycle_trace.to_dict() @@ -200,10 +245,11 @@ def test_event_loop_metrics_start_cycle(mock_uuid4, mock_time, event_loop_metric @unittest.mock.patch.object(strands.telemetry.metrics.time, "time") -def test_event_loop_metrics_end_cycle(mock_time, trace, event_loop_metrics): +def test_event_loop_metrics_end_cycle(mock_time, trace, event_loop_metrics, mock_get_meter_provider): mock_time.return_value = 1 - event_loop_metrics.end_cycle(start_time=0, cycle_trace=trace) + attributes = {"foo": "bar"} + event_loop_metrics.end_cycle(start_time=0, cycle_trace=trace, attributes=attributes) tru_cycle_durations = event_loop_metrics.cycle_durations exp_cycle_durations = [1] @@ -215,17 +261,23 @@ def test_event_loop_metrics_end_cycle(mock_time, trace, event_loop_metrics): assert tru_trace_end_time == exp_trace_end_time + mock_get_meter_provider.return_value.get_meter.assert_called() + metrics_client = event_loop_metrics._metrics_client + metrics_client.event_loop_end_cycle.add.assert_called_with(1, attributes) + metrics_client.event_loop_cycle_duration.record.assert_called() + @unittest.mock.patch.object(strands.telemetry.metrics.time, "time") -def test_event_loop_metrics_add_tool_usage(mock_time, trace, tool, event_loop_metrics): +def test_event_loop_metrics_add_tool_usage(mock_time, trace, tool, event_loop_metrics, mock_get_meter_provider): mock_time.return_value = 1 - duration = 1 success = True message = {"role": "user", "content": [{"toolResult": {"toolUseId": "123", "tool_name": "tool1"}}]} event_loop_metrics.add_tool_usage(tool, duration, trace, success, message) + mock_get_meter_provider.return_value.get_meter.assert_called() + tru_event_loop_metrics_attrs = {"tool_metrics": event_loop_metrics.tool_metrics} exp_event_loop_metrics_attrs = { "tool_metrics": { @@ -258,7 +310,7 @@ def test_event_loop_metrics_add_tool_usage(mock_time, trace, tool, event_loop_me assert tru_trace_attrs == exp_trace_attrs -def test_event_loop_metrics_update_usage(usage, event_loop_metrics): +def test_event_loop_metrics_update_usage(usage, event_loop_metrics, mock_get_meter_provider): for _ in range(3): event_loop_metrics.update_usage(usage) @@ -270,9 +322,13 @@ def test_event_loop_metrics_update_usage(usage, event_loop_metrics): ) assert tru_usage == exp_usage + mock_get_meter_provider.return_value.get_meter.assert_called() + metrics_client = event_loop_metrics._metrics_client + metrics_client.event_loop_input_tokens.record.assert_called() + metrics_client.event_loop_output_tokens.record.assert_called() -def test_event_loop_metrics_update_metrics(metrics, event_loop_metrics): +def test_event_loop_metrics_update_metrics(metrics, event_loop_metrics, 
mock_get_meter_provider): for _ in range(3): event_loop_metrics.update_metrics(metrics) @@ -282,9 +338,11 @@ def test_event_loop_metrics_update_metrics(metrics, event_loop_metrics): ) assert tru_metrics == exp_metrics + mock_get_meter_provider.return_value.get_meter.assert_called() + event_loop_metrics._metrics_client.event_loop_latency.record.assert_called_with(1) -def test_event_loop_metrics_get_summary(trace, tool, event_loop_metrics): +def test_event_loop_metrics_get_summary(trace, tool, event_loop_metrics, mock_get_meter_provider): duration = 1 success = True message = {"role": "user", "content": [{"toolResult": {"toolUseId": "123", "tool_name": "tool1"}}]} @@ -379,3 +437,31 @@ def test_metrics_to_string(trace, child_trace, tool_metrics, exp_str, event_loop tru_str = strands.telemetry.metrics.metrics_to_string(event_loop_metrics) assert tru_str == exp_str + + +def test_setup_meter_if_meter_provider_is_set( + mock_get_meter_provider, + mock_resource, +): + """Test global meter_provider and meter are used""" + mock_resource_instance = mock.MagicMock() + mock_resource.create.return_value = mock_resource_instance + + metrics_client = MetricsClient() + + mock_get_meter_provider.assert_called() + mock_get_meter_provider.return_value.get_meter.assert_called() + + assert metrics_client is not None + + +def test_use_ProxyMeter_if_no_global_meter_provider(): + """Return _ProxyMeter""" + # Reset the singleton instance + strands.telemetry.metrics.MetricsClient._instance = None + + # Create a new instance which should use the real _ProxyMeter + metrics_client = MetricsClient() + + # Verify it's using a _ProxyMeter + assert isinstance(metrics_client.meter, _ProxyMeter) diff --git a/tests/strands/telemetry/test_tracer.py b/tests/strands/telemetry/test_tracer.py index 55018c5e..6ae3e1ad 100644 --- a/tests/strands/telemetry/test_tracer.py +++ b/tests/strands/telemetry/test_tracer.py @@ -1,11 +1,14 @@ import json import os +from datetime import date, datetime, timezone from unittest import mock import pytest -from opentelemetry.trace import StatusCode # type: ignore +from opentelemetry.trace import ( + StatusCode, # type: ignore +) -from strands.telemetry.tracer import Tracer, get_tracer +from strands.telemetry.tracer import JSONEncoder, Tracer, get_tracer, serialize from strands.types.streaming import Usage @@ -17,13 +20,25 @@ def moto_autouse(moto_env, moto_mock_aws): @pytest.fixture def mock_tracer_provider(): - with mock.patch("strands.telemetry.tracer.TracerProvider") as mock_provider: + with mock.patch("strands.telemetry.tracer.SDKTracerProvider") as mock_provider: yield mock_provider +@pytest.fixture +def mock_is_initialized(): + with mock.patch("strands.telemetry.tracer.Tracer._is_initialized") as mock_is_initialized: + yield mock_is_initialized + + +@pytest.fixture +def mock_get_tracer_provider(): + with mock.patch("strands.telemetry.tracer.trace_api.get_tracer_provider") as mock_get_tracer_provider: + yield mock_get_tracer_provider + + @pytest.fixture def mock_tracer(): - with mock.patch("strands.telemetry.tracer.trace.get_tracer") as mock_get_tracer: + with mock.patch("strands.telemetry.tracer.trace_api.get_tracer") as mock_get_tracer: mock_tracer = mock.MagicMock() mock_get_tracer.return_value = mock_tracer yield mock_tracer @@ -37,13 +52,16 @@ def mock_span(): @pytest.fixture def mock_set_tracer_provider(): - with mock.patch("strands.telemetry.tracer.trace.set_tracer_provider") as mock_set: + with mock.patch("strands.telemetry.tracer.trace_api.set_tracer_provider") as mock_set: yield 
mock_set @pytest.fixture def mock_otlp_exporter(): - with mock.patch("strands.telemetry.tracer.OTLPSpanExporter") as mock_exporter: + with ( + mock.patch("strands.telemetry.tracer.HAS_OTEL_EXPORTER_MODULE", True), + mock.patch("opentelemetry.exporter.otlp.proto.http.trace_exporter.OTLPSpanExporter") as mock_exporter, + ): yield mock_exporter @@ -55,12 +73,67 @@ def mock_console_exporter(): @pytest.fixture def mock_resource(): - with mock.patch("strands.telemetry.tracer.Resource") as mock_resource: + with mock.patch("strands.telemetry.tracer.get_otel_resource") as mock_resource: + mock_resource_instance = mock.MagicMock() + mock_resource.return_value = mock_resource_instance yield mock_resource -def test_init_default(): +@pytest.fixture +def clean_env(): + """Fixture to provide a clean environment for each test.""" + with mock.patch.dict(os.environ, {}, clear=True): + yield + + +@pytest.fixture +def env_with_otlp(): + """Fixture with OTLP environment variables.""" + with mock.patch.dict( + os.environ, + { + "OTEL_EXPORTER_OTLP_ENDPOINT": "http://env-endpoint", + }, + ): + yield + + +@pytest.fixture +def env_with_console(): + """Fixture with console export environment variables.""" + with mock.patch.dict( + os.environ, + { + "STRANDS_OTEL_ENABLE_CONSOLE_EXPORT": "true", + }, + ): + yield + + +@pytest.fixture +def env_with_both(): + """Fixture with both OTLP and console export environment variables.""" + with mock.patch.dict( + os.environ, + { + "OTEL_EXPORTER_OTLP_ENDPOINT": "http://env-endpoint", + "STRANDS_OTEL_ENABLE_CONSOLE_EXPORT": "true", + }, + ): + yield + + +@pytest.fixture +def mock_initialize(): + with mock.patch("strands.telemetry.tracer.Tracer._initialize_tracer") as mock_initialize: + yield mock_initialize + + +def test_init_default(mock_is_initialized, mock_get_tracer_provider): """Test initializing the Tracer with default parameters.""" + mock_is_initialized.return_value = False + mock_get_tracer_provider.return_value = None + tracer = Tracer() assert tracer.service_name == "strands-agents" @@ -96,17 +169,20 @@ def test_init_with_env_headers(): def test_initialize_tracer_with_console( - mock_tracer_provider, mock_set_tracer_provider, mock_console_exporter, mock_resource + mock_is_initialized, + mock_tracer_provider, + mock_set_tracer_provider, + mock_console_exporter, + mock_resource, ): """Test initializing the tracer with console exporter.""" - mock_resource_instance = mock.MagicMock() - mock_resource.create.return_value = mock_resource_instance + mock_is_initialized.return_value = False # Initialize Tracer Tracer(enable_console_export=True) # Verify the tracer provider was created with correct resource - mock_tracer_provider.assert_called_once_with(resource=mock_resource_instance) + mock_tracer_provider.assert_called_once_with(resource=mock_resource.return_value) # Verify console exporter was added mock_console_exporter.assert_called_once() @@ -116,16 +192,21 @@ def test_initialize_tracer_with_console( mock_set_tracer_provider.assert_called_once_with(mock_tracer_provider.return_value) -def test_initialize_tracer_with_otlp(mock_tracer_provider, mock_set_tracer_provider, mock_otlp_exporter, mock_resource): +def test_initialize_tracer_with_otlp( + mock_is_initialized, mock_tracer_provider, mock_set_tracer_provider, mock_otlp_exporter, mock_resource +): """Test initializing the tracer with OTLP exporter.""" - mock_resource_instance = mock.MagicMock() - mock_resource.create.return_value = mock_resource_instance + mock_is_initialized.return_value = False # Initialize Tracer - 
Tracer(otlp_endpoint="http://test-endpoint") + with ( + mock.patch("strands.telemetry.tracer.HAS_OTEL_EXPORTER_MODULE", True), + mock.patch("strands.telemetry.tracer.OTLPSpanExporter", mock_otlp_exporter), + ): + Tracer(otlp_endpoint="http://test-endpoint") # Verify the tracer provider was created with correct resource - mock_tracer_provider.assert_called_once_with(resource=mock_resource_instance) + mock_tracer_provider.assert_called_once_with(resource=mock_resource.return_value) # Verify OTLP exporter was added with correct endpoint mock_otlp_exporter.assert_called_once() @@ -146,7 +227,7 @@ def test_start_span_no_tracer(): def test_start_span(mock_tracer): """Test starting a span with attributes.""" - with mock.patch("strands.telemetry.tracer.trace.get_tracer", return_value=mock_tracer): + with mock.patch("strands.telemetry.tracer.trace_api.get_tracer", return_value=mock_tracer): tracer = Tracer(enable_console_export=True) tracer.tracer = mock_tracer @@ -217,7 +298,7 @@ def test_end_span_with_error_message(mock_span): def test_start_model_invoke_span(mock_tracer): """Test starting a model invoke span.""" - with mock.patch("strands.telemetry.tracer.trace.get_tracer", return_value=mock_tracer): + with mock.patch("strands.telemetry.tracer.trace_api.get_tracer", return_value=mock_tracer): tracer = Tracer(enable_console_export=True) tracer.tracer = mock_tracer @@ -255,7 +336,7 @@ def test_end_model_invoke_span(mock_span): def test_start_tool_call_span(mock_tracer): """Test starting a tool call span.""" - with mock.patch("strands.telemetry.tracer.trace.get_tracer", return_value=mock_tracer): + with mock.patch("strands.telemetry.tracer.trace_api.get_tracer", return_value=mock_tracer): tracer = Tracer(enable_console_export=True) tracer.tracer = mock_tracer @@ -268,6 +349,9 @@ def test_start_tool_call_span(mock_tracer): mock_tracer.start_span.assert_called_once() assert mock_tracer.start_span.call_args[1]["name"] == "Tool: test-tool" + mock_span.set_attribute.assert_any_call( + "gen_ai.prompt", json.dumps({"name": "test-tool", "toolUseId": "123", "input": {"param": "value"}}) + ) mock_span.set_attribute.assert_any_call("tool.name", "test-tool") mock_span.set_attribute.assert_any_call("tool.id", "123") mock_span.set_attribute.assert_any_call("tool.parameters", json.dumps({"param": "value"})) @@ -290,7 +374,7 @@ def test_end_tool_call_span(mock_span): def test_start_event_loop_cycle_span(mock_tracer): """Test starting an event loop cycle span.""" - with mock.patch("strands.telemetry.tracer.trace.get_tracer", return_value=mock_tracer): + with mock.patch("strands.telemetry.tracer.trace_api.get_tracer", return_value=mock_tracer): tracer = Tracer(enable_console_export=True) tracer.tracer = mock_tracer @@ -325,7 +409,7 @@ def test_end_event_loop_cycle_span(mock_span): def test_start_agent_span(mock_tracer): """Test starting an agent span.""" - with mock.patch("strands.telemetry.tracer.trace.get_tracer", return_value=mock_tracer): + with mock.patch("strands.telemetry.tracer.trace_api.get_tracer", return_value=mock_tracer): tracer = Tracer(enable_console_export=True) tracer.tracer = mock_tracer @@ -369,7 +453,7 @@ def test_end_agent_span(mock_span): tracer.end_agent_span(mock_span, mock_response) - mock_span.set_attribute.assert_any_call("gen_ai.completion", '""') + mock_span.set_attribute.assert_any_call("gen_ai.completion", "Agent response") mock_span.set_attribute.assert_any_call("gen_ai.usage.prompt_tokens", 50) mock_span.set_attribute.assert_any_call("gen_ai.usage.completion_tokens", 100) 
mock_span.set_attribute.assert_any_call("gen_ai.usage.total_tokens", 150) @@ -416,20 +500,24 @@ def test_get_tracer_parameters(): def test_initialize_tracer_with_invalid_otlp_endpoint( - mock_tracer_provider, mock_set_tracer_provider, mock_otlp_exporter, mock_resource + mock_is_initialized, mock_tracer_provider, mock_set_tracer_provider, mock_otlp_exporter, mock_resource ): """Test initializing the tracer with an invalid OTLP endpoint.""" - mock_resource_instance = mock.MagicMock() - mock_resource.create.return_value = mock_resource_instance + mock_is_initialized.return_value = False + mock_otlp_exporter.side_effect = Exception("Connection error") # This should not raise an exception, but should log an error # Initialize Tracer - Tracer(otlp_endpoint="http://invalid-endpoint") + with ( + mock.patch("strands.telemetry.tracer.HAS_OTEL_EXPORTER_MODULE", True), + mock.patch("strands.telemetry.tracer.OTLPSpanExporter", mock_otlp_exporter), + ): + Tracer(otlp_endpoint="http://invalid-endpoint") # Verify the tracer provider was created with correct resource - mock_tracer_provider.assert_called_once_with(resource=mock_resource_instance) + mock_tracer_provider.assert_called_once_with(resource=mock_resource.return_value) # Verify OTLP exporter was attempted mock_otlp_exporter.assert_called_once() @@ -438,6 +526,42 @@ def test_initialize_tracer_with_invalid_otlp_endpoint( mock_set_tracer_provider.assert_called_once_with(mock_tracer_provider.return_value) +def test_initialize_tracer_with_missing_module( + mock_is_initialized, mock_tracer_provider, mock_set_tracer_provider, mock_resource +): + """Test initializing the tracer when the OTLP exporter module is missing.""" + mock_is_initialized.return_value = False + + # Initialize Tracer with OTLP endpoint but missing module + with ( + mock.patch("strands.telemetry.tracer.HAS_OTEL_EXPORTER_MODULE", False), + pytest.raises(ModuleNotFoundError) as excinfo, + ): + Tracer(otlp_endpoint="http://test-endpoint") + + # Verify the error message + assert "opentelemetry-exporter-otlp-proto-http not detected" in str(excinfo.value) + assert "otel http exporting is currently DISABLED" in str(excinfo.value) + + # Verify the tracer provider was created with correct resource + mock_tracer_provider.assert_called_once_with(resource=mock_resource.return_value) + + # Verify set_tracer_provider was not called since an exception was raised + mock_set_tracer_provider.assert_not_called() + + +def test_initialize_tracer_with_custom_tracer_provider(mock_is_initialized, mock_get_tracer_provider, mock_resource): + """Test initializing the tracer with NoOpTracerProvider.""" + mock_is_initialized.return_value = True + tracer = Tracer(otlp_endpoint="http://invalid-endpoint") + + mock_get_tracer_provider.assert_called() + mock_resource.assert_not_called() + + assert tracer.tracer_provider is not None + assert tracer.tracer is not None + + def test_end_span_with_exception_handling(mock_span): """Test ending a span with exception handling.""" tracer = Tracer() @@ -482,7 +606,7 @@ def test_end_tool_call_span_with_none(mock_span): def test_start_model_invoke_span_with_parent(mock_tracer): """Test starting a model invoke span with a parent span.""" - with mock.patch("strands.telemetry.tracer.trace.get_tracer", return_value=mock_tracer): + with mock.patch("strands.telemetry.tracer.trace_api.get_tracer", return_value=mock_tracer): tracer = Tracer(enable_console_export=True) tracer.tracer = mock_tracer @@ -497,3 +621,230 @@ def test_start_model_invoke_span_with_parent(mock_tracer): # Verify span 
was returned assert span is mock_span + + +@pytest.mark.parametrize( + "input_data, expected_result", + [ + ("test string", '"test string"'), + (1234, "1234"), + (13.37, "13.37"), + (False, "false"), + (None, "null"), + ], +) +def test_json_encoder_serializable(input_data, expected_result): + """Test encoding of serializable values.""" + encoder = JSONEncoder() + + result = encoder.encode(input_data) + assert result == expected_result + + +def test_json_encoder_datetime(): + """Test encoding datetime and date objects.""" + encoder = JSONEncoder() + + dt = datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + result = encoder.encode(dt) + assert result == f'"{dt.isoformat()}"' + + d = date(2025, 1, 1) + result = encoder.encode(d) + assert result == f'"{d.isoformat()}"' + + +def test_json_encoder_list(): + """Test encoding a list with mixed content.""" + encoder = JSONEncoder() + + non_serializable = lambda x: x # noqa: E731 + + data = ["value", 42, 13.37, non_serializable, None, {"key": True}, ["value here"]] + + result = json.loads(encoder.encode(data)) + assert result == ["value", 42, 13.37, "", None, {"key": True}, ["value here"]] + + +def test_json_encoder_dict(): + """Test encoding a dict with mixed content.""" + encoder = JSONEncoder() + + class UnserializableClass: + def __str__(self): + return "Unserializable Object" + + non_serializable = lambda x: x # noqa: E731 + + now = datetime.now(timezone.utc) + + data = { + "metadata": { + "timestamp": now, + "version": "1.0", + "debug_info": {"object": non_serializable, "callable": lambda x: x + 1}, # noqa: E731 + }, + "content": [ + {"type": "text", "value": "Hello world"}, + {"type": "binary", "value": non_serializable}, + {"type": "mixed", "values": [1, "text", non_serializable, {"nested": non_serializable}]}, + ], + "statistics": { + "processed": 100, + "failed": 5, + "details": [{"id": 1, "status": "ok"}, {"id": 2, "status": "error", "error_obj": non_serializable}], + }, + "list": [ + non_serializable, + 1234, + 13.37, + True, + None, + "string here", + ], + } + + expected = { + "metadata": { + "timestamp": now.isoformat(), + "version": "1.0", + "debug_info": {"object": "", "callable": ""}, + }, + "content": [ + {"type": "text", "value": "Hello world"}, + {"type": "binary", "value": ""}, + {"type": "mixed", "values": [1, "text", "", {"nested": ""}]}, + ], + "statistics": { + "processed": 100, + "failed": 5, + "details": [{"id": 1, "status": "ok"}, {"id": 2, "status": "error", "error_obj": ""}], + }, + "list": [ + "", + 1234, + 13.37, + True, + None, + "string here", + ], + } + + result = json.loads(encoder.encode(data)) + + assert result == expected + + +def test_json_encoder_value_error(): + """Test encoding values that cause ValueError.""" + encoder = JSONEncoder() + + # A very large integer that exceeds JSON limits and throws ValueError + huge_number = 2**100000 + + # Test in a dictionary + dict_data = {"normal": 42, "huge": huge_number} + result = json.loads(encoder.encode(dict_data)) + assert result == {"normal": 42, "huge": ""} + + # Test in a list + list_data = [42, huge_number] + result = json.loads(encoder.encode(list_data)) + assert result == [42, ""] + + # Test just the value + result = json.loads(encoder.encode(huge_number)) + assert result == "" + + +def test_serialize_non_ascii_characters(): + """Test that non-ASCII characters are preserved in JSON serialization.""" + + # Test with Japanese text + japanese_text = "γ“γ‚“γ«γ‘γ―δΈ–η•Œ" + result = serialize({"text": japanese_text}) + assert japanese_text in result + 
assert "\\u" not in result + + # Test with emoji + emoji_text = "Hello 🌍" + result = serialize({"text": emoji_text}) + assert emoji_text in result + assert "\\u" not in result + + # Test with Chinese characters + chinese_text = "δ½ ε₯½οΌŒδΈ–η•Œ" + result = serialize({"text": chinese_text}) + assert chinese_text in result + assert "\\u" not in result + + # Test with mixed content + mixed_text = {"ja": "こんにけは", "emoji": "😊", "zh": "δ½ ε₯½", "en": "hello"} + result = serialize(mixed_text) + assert "こんにけは" in result + assert "😊" in result + assert "δ½ ε₯½" in result + assert "\\u" not in result + + +def test_serialize_vs_json_dumps(): + """Test that serialize behaves differently from default json.dumps for non-ASCII characters.""" + + # Test with Japanese text + japanese_text = "γ“γ‚“γ«γ‘γ―δΈ–η•Œ" + + # Default json.dumps should escape non-ASCII characters + default_result = json.dumps({"text": japanese_text}) + assert "\\u" in default_result + + # Our serialize function should preserve non-ASCII characters + custom_result = serialize({"text": japanese_text}) + assert japanese_text in custom_result + assert "\\u" not in custom_result + + +def test_init_with_no_env_or_param(clean_env): + """Test initializing with neither environment variable nor constructor parameter.""" + tracer = Tracer() + assert tracer.otlp_endpoint is None + assert tracer.enable_console_export is False + + tracer = Tracer(otlp_endpoint="http://param-endpoint") + assert tracer.otlp_endpoint == "http://param-endpoint" + + tracer = Tracer(enable_console_export=True) + assert tracer.enable_console_export is True + + +def test_constructor_params_with_otlp_env(env_with_otlp): + """Test constructor parameters precedence over OTLP environment variable.""" + # Constructor parameter should take precedence + tracer = Tracer(otlp_endpoint="http://constructor-endpoint") + assert tracer.otlp_endpoint == "http://constructor-endpoint" + + # Without constructor parameter, should use env var + tracer = Tracer() + assert tracer.otlp_endpoint == "http://env-endpoint" + + +def test_constructor_params_with_console_env(env_with_console): + """Test constructor parameters precedence over console environment variable.""" + # Constructor parameter should take precedence + tracer = Tracer(enable_console_export=False) + assert tracer.enable_console_export is False + + # Without explicit constructor parameter, should use env var + tracer = Tracer() + assert tracer.enable_console_export is True + + +def test_fallback_to_env_vars(env_with_both): + """Test fallback to environment variables when no constructor parameters.""" + tracer = Tracer() + assert tracer.otlp_endpoint == "http://env-endpoint" + assert tracer.enable_console_export is True + + # Constructor parameters should still take precedence + tracer = Tracer(otlp_endpoint="http://constructor-endpoint", enable_console_export=False) + assert tracer.otlp_endpoint == "http://constructor-endpoint" + assert tracer.enable_console_export is False diff --git a/tests/strands/tools/test_executor.py b/tests/strands/tools/test_executor.py index ced2bd7f..4b238792 100644 --- a/tests/strands/tools/test_executor.py +++ b/tests/strands/tools/test_executor.py @@ -1,6 +1,7 @@ import concurrent import functools import unittest.mock +import uuid import pytest @@ -9,6 +10,11 @@ from strands.types.content import Message +@pytest.fixture(autouse=True) +def moto_autouse(moto_env): + _ = moto_env + + @pytest.fixture def tool_handler(request): def handler(tool_use): @@ -37,6 +43,12 @@ def tool_uses(request, tool_use): 
return request.param if hasattr(request, "param") else [tool_use] +@pytest.fixture +def mock_metrics_client(): + with unittest.mock.patch("strands.telemetry.MetricsClient") as mock_metrics_client: + yield mock_metrics_client + + @pytest.fixture def event_loop_metrics(): return strands.telemetry.metrics.EventLoopMetrics() @@ -52,10 +64,10 @@ def invalid_tool_use_ids(request): return request.param if hasattr(request, "param") else [] -@unittest.mock.patch.object(strands.telemetry.metrics, "uuid4", return_value="trace1") @pytest.fixture def cycle_trace(): - return strands.telemetry.metrics.Trace(name="test trace", raw_name="raw_name") + with unittest.mock.patch.object(uuid, "uuid4", return_value="trace1"): + return strands.telemetry.metrics.Trace(name="test trace", raw_name="raw_name") @pytest.fixture @@ -297,6 +309,7 @@ def test_run_tools_creates_and_ends_span_on_success( mock_get_tracer, tool_handler, tool_uses, + mock_metrics_client, event_loop_metrics, request_state, invalid_tool_use_ids, diff --git a/tests/strands/tools/test_registry.py b/tests/strands/tools/test_registry.py index 3dca5371..1b274f46 100644 --- a/tests/strands/tools/test_registry.py +++ b/tests/strands/tools/test_registry.py @@ -2,8 +2,11 @@ Tests for the SDK tool registry module. """ +from unittest.mock import MagicMock + import pytest +from strands.tools import PythonAgentTool from strands.tools.registry import ToolRegistry @@ -23,3 +26,20 @@ def test_process_tools_with_invalid_path(): with pytest.raises(ValueError, match=f"Failed to load tool {invalid_path.split('.')[0]}: Tool file not found:.*"): tool_registry.process_tools([invalid_path]) + + +def test_register_tool_with_similar_name_raises(): + tool_1 = PythonAgentTool(tool_name="tool-like-this", tool_spec=MagicMock(), callback=lambda: None) + tool_2 = PythonAgentTool(tool_name="tool_like_this", tool_spec=MagicMock(), callback=lambda: None) + + tool_registry = ToolRegistry() + + tool_registry.register_tool(tool_1) + + with pytest.raises(ValueError) as err: + tool_registry.register_tool(tool_2) + + assert ( + str(err.value) == "Tool name 'tool_like_this' already exists as 'tool-like-this'. 
" + "Cannot add a duplicate tool which differs by a '-' or '_'" + ) diff --git a/tests/strands/tools/test_structured_output.py b/tests/strands/tools/test_structured_output.py new file mode 100644 index 00000000..2e354b83 --- /dev/null +++ b/tests/strands/tools/test_structured_output.py @@ -0,0 +1,228 @@ +from typing import Literal, Optional + +import pytest +from pydantic import BaseModel, Field + +from strands.tools.structured_output import convert_pydantic_to_tool_spec +from strands.types.tools import ToolSpec + + +# Basic test model +class User(BaseModel): + """User model with name and age.""" + + name: str = Field(description="The name of the user") + age: int = Field(description="The age of the user", ge=18, le=100) + + +# Test model with inheritance and literals +class UserWithPlanet(User): + """User with planet.""" + + planet: Literal["Earth", "Mars"] = Field(description="The planet") + + +# Test model with multiple same type fields and optional field +class TwoUsersWithPlanet(BaseModel): + """Two users model with planet.""" + + user1: UserWithPlanet = Field(description="The first user") + user2: Optional[UserWithPlanet] = Field(description="The second user", default=None) + + +# Test model with list of same type fields +class ListOfUsersWithPlanet(BaseModel): + """List of users model with planet.""" + + users: list[UserWithPlanet] = Field(description="The users", min_length=2, max_length=3) + + +def test_convert_pydantic_to_tool_spec_basic(): + tool_spec = convert_pydantic_to_tool_spec(User) + + expected_spec = { + "name": "User", + "description": "User model with name and age.", + "inputSchema": { + "json": { + "type": "object", + "properties": { + "name": {"description": "The name of the user", "title": "Name", "type": "string"}, + "age": { + "description": "The age of the user", + "maximum": 100, + "minimum": 18, + "title": "Age", + "type": "integer", + }, + }, + "title": "User", + "description": "User model with name and age.", + "required": ["name", "age"], + } + }, + } + + # Verify we can construct a valid ToolSpec + tool_spec_obj = ToolSpec(**tool_spec) + assert tool_spec_obj is not None + assert tool_spec == expected_spec + + +def test_convert_pydantic_to_tool_spec_complex(): + tool_spec = convert_pydantic_to_tool_spec(ListOfUsersWithPlanet) + + expected_spec = { + "name": "ListOfUsersWithPlanet", + "description": "List of users model with planet.", + "inputSchema": { + "json": { + "type": "object", + "properties": { + "users": { + "description": "The users", + "items": { + "description": "User with planet.", + "title": "UserWithPlanet", + "type": "object", + "properties": { + "name": {"description": "The name of the user", "title": "Name", "type": "string"}, + "age": { + "description": "The age of the user", + "maximum": 100, + "minimum": 18, + "title": "Age", + "type": "integer", + }, + "planet": { + "description": "The planet", + "enum": ["Earth", "Mars"], + "title": "Planet", + "type": "string", + }, + }, + "required": ["name", "age", "planet"], + }, + "maxItems": 3, + "minItems": 2, + "title": "Users", + "type": "array", + } + }, + "title": "ListOfUsersWithPlanet", + "description": "List of users model with planet.", + "required": ["users"], + } + }, + } + + assert tool_spec == expected_spec + + # Verify we can construct a valid ToolSpec + tool_spec_obj = ToolSpec(**tool_spec) + assert tool_spec_obj is not None + + +def test_convert_pydantic_to_tool_spec_multiple_same_type(): + tool_spec = convert_pydantic_to_tool_spec(TwoUsersWithPlanet) + + expected_spec = { + 
"name": "TwoUsersWithPlanet", + "description": "Two users model with planet.", + "inputSchema": { + "json": { + "type": "object", + "properties": { + "user1": { + "type": "object", + "description": "The first user", + "properties": { + "name": {"description": "The name of the user", "title": "Name", "type": "string"}, + "age": { + "description": "The age of the user", + "maximum": 100, + "minimum": 18, + "title": "Age", + "type": "integer", + }, + "planet": { + "description": "The planet", + "enum": ["Earth", "Mars"], + "title": "Planet", + "type": "string", + }, + }, + "required": ["name", "age", "planet"], + }, + "user2": { + "type": ["object", "null"], + "description": "The second user", + "properties": { + "name": {"description": "The name of the user", "title": "Name", "type": "string"}, + "age": { + "description": "The age of the user", + "maximum": 100, + "minimum": 18, + "title": "Age", + "type": "integer", + }, + "planet": { + "description": "The planet", + "enum": ["Earth", "Mars"], + "title": "Planet", + "type": "string", + }, + }, + "required": ["name", "age", "planet"], + }, + }, + "title": "TwoUsersWithPlanet", + "description": "Two users model with planet.", + "required": ["user1"], + } + }, + } + + assert tool_spec == expected_spec + + # Verify we can construct a valid ToolSpec + tool_spec_obj = ToolSpec(**tool_spec) + assert tool_spec_obj is not None + + +def test_convert_pydantic_with_missing_refs(): + """Test that the tool handles missing $refs gracefully.""" + # This test checks that our error handling for missing $refs works correctly + # by testing with a model that has circular references + + class NodeWithCircularRef(BaseModel): + """A node with a circular reference to itself.""" + + name: str = Field(description="The name of the node") + parent: Optional["NodeWithCircularRef"] = Field(None, description="Parent node") + children: list["NodeWithCircularRef"] = Field(default_factory=list, description="Child nodes") + + # This forward reference normally causes issues with schema generation + # but our error handling should prevent errors + with pytest.raises(ValueError, match="Circular reference detected and not supported"): + convert_pydantic_to_tool_spec(NodeWithCircularRef) + + +def test_convert_pydantic_with_custom_description(): + """Test that custom descriptions override model docstrings.""" + + # Test with custom description + custom_description = "Custom tool description for user model" + tool_spec = convert_pydantic_to_tool_spec(User, description=custom_description) + + assert tool_spec["description"] == custom_description + + +def test_convert_pydantic_with_empty_docstring(): + """Test that empty docstrings use default description.""" + + class EmptyDocUser(BaseModel): + name: str = Field(description="The name of the user") + + tool_spec = convert_pydantic_to_tool_spec(EmptyDocUser) + assert tool_spec["description"] == "EmptyDocUser structured output tool" diff --git a/tests/strands/tools/test_tools.py b/tests/strands/tools/test_tools.py index 8a6b406f..1b65156b 100644 --- a/tests/strands/tools/test_tools.py +++ b/tests/strands/tools/test_tools.py @@ -14,9 +14,13 @@ def test_validate_tool_use_name_valid(): - tool = {"name": "valid_tool_name", "toolUseId": "123"} + tool1 = {"name": "valid_tool_name", "toolUseId": "123"} + # Should not raise an exception + validate_tool_use_name(tool1) + + tool2 = {"name": "valid-name", "toolUseId": "123"} # Should not raise an exception - validate_tool_use_name(tool) + validate_tool_use_name(tool2) def 
test_validate_tool_use_name_missing(): @@ -46,11 +50,10 @@ def test_validate_tool_use(): def test_normalize_schema_basic(): schema = {"type": "object"} normalized = normalize_schema(schema) - assert normalized["type"] == "object" - assert "properties" in normalized - assert normalized["properties"] == {} - assert "required" in normalized - assert normalized["required"] == [] + + expected = {"type": "object", "properties": {}, "required": []} + + assert normalized == expected def test_normalize_schema_with_properties(): @@ -62,14 +65,17 @@ def test_normalize_schema_with_properties(): }, } normalized = normalize_schema(schema) - assert normalized["type"] == "object" - assert "properties" in normalized - assert "name" in normalized["properties"] - assert normalized["properties"]["name"]["type"] == "string" - assert normalized["properties"]["name"]["description"] == "User name" - assert "age" in normalized["properties"] - assert normalized["properties"]["age"]["type"] == "integer" - assert normalized["properties"]["age"]["description"] == "User age" + + expected = { + "type": "object", + "properties": { + "name": {"type": "string", "description": "User name"}, + "age": {"type": "integer", "description": "User age"}, + }, + "required": [], + } + + assert normalized == expected def test_normalize_schema_with_property_removed(): @@ -78,27 +84,40 @@ def test_normalize_schema_with_property_removed(): "properties": {"name": "invalid"}, } normalized = normalize_schema(schema) - assert "name" in normalized["properties"] - assert normalized["properties"]["name"]["type"] == "string" - assert normalized["properties"]["name"]["description"] == "Property name" + + expected = { + "type": "object", + "properties": {"name": {"type": "string", "description": "Property name"}}, + "required": [], + } + + assert normalized == expected def test_normalize_schema_with_property_defaults(): schema = {"properties": {"name": {}}} normalized = normalize_schema(schema) - assert "name" in normalized["properties"] - assert normalized["properties"]["name"]["type"] == "string" - assert normalized["properties"]["name"]["description"] == "Property name" + + expected = { + "type": "object", + "properties": {"name": {"type": "string", "description": "Property name"}}, + "required": [], + } + + assert normalized == expected def test_normalize_schema_with_property_enum(): schema = {"properties": {"color": {"type": "string", "description": "color", "enum": ["red", "green", "blue"]}}} normalized = normalize_schema(schema) - assert "color" in normalized["properties"] - assert normalized["properties"]["color"]["type"] == "string" - assert normalized["properties"]["color"]["description"] == "color" - assert "enum" in normalized["properties"]["color"] - assert normalized["properties"]["color"]["enum"] == ["red", "green", "blue"] + + expected = { + "type": "object", + "properties": {"color": {"type": "string", "description": "color", "enum": ["red", "green", "blue"]}}, + "required": [], + } + + assert normalized == expected def test_normalize_schema_with_property_numeric_constraints(): @@ -109,21 +128,170 @@ def test_normalize_schema_with_property_numeric_constraints(): } } normalized = normalize_schema(schema) - assert "age" in normalized["properties"] - assert normalized["properties"]["age"]["type"] == "integer" - assert normalized["properties"]["age"]["minimum"] == 0 - assert normalized["properties"]["age"]["maximum"] == 120 - assert "score" in normalized["properties"] - assert normalized["properties"]["score"]["type"] == "number" - 
assert normalized["properties"]["score"]["minimum"] == 0.0 - assert normalized["properties"]["score"]["maximum"] == 100.0 + + expected = { + "type": "object", + "properties": { + "age": {"type": "integer", "description": "age", "minimum": 0, "maximum": 120}, + "score": {"type": "number", "description": "score", "minimum": 0.0, "maximum": 100.0}, + }, + "required": [], + } + + assert normalized == expected def test_normalize_schema_with_required(): schema = {"type": "object", "required": ["name", "email"]} normalized = normalize_schema(schema) - assert "required" in normalized - assert normalized["required"] == ["name", "email"] + + expected = {"type": "object", "properties": {}, "required": ["name", "email"]} + + assert normalized == expected + + +def test_normalize_schema_with_nested_object(): + """Test normalization of schemas with nested objects.""" + schema = { + "type": "object", + "properties": { + "user": { + "type": "object", + "properties": { + "name": {"type": "string", "description": "User name"}, + "age": {"type": "integer", "description": "User age"}, + }, + "required": ["name"], + } + }, + "required": ["user"], + } + + normalized = normalize_schema(schema) + + expected = { + "type": "object", + "properties": { + "user": { + "type": "object", + "properties": { + "name": {"type": "string", "description": "User name"}, + "age": {"type": "integer", "description": "User age"}, + }, + "required": ["name"], + } + }, + "required": ["user"], + } + + assert normalized == expected + + +def test_normalize_schema_with_deeply_nested_objects(): + """Test normalization of deeply nested object structures.""" + schema = { + "type": "object", + "properties": { + "level1": { + "type": "object", + "properties": { + "level2": { + "type": "object", + "properties": { + "level3": {"type": "object", "properties": {"value": {"type": "string", "const": "fixed"}}} + }, + } + }, + } + }, + } + + normalized = normalize_schema(schema) + + expected = { + "type": "object", + "properties": { + "level1": { + "type": "object", + "properties": { + "level2": { + "type": "object", + "properties": { + "level3": { + "type": "object", + "properties": { + "value": {"type": "string", "description": "Property value", "const": "fixed"} + }, + "required": [], + } + }, + "required": [], + } + }, + "required": [], + } + }, + "required": [], + } + + assert normalized == expected + + +def test_normalize_schema_with_const_constraint(): + """Test that const constraints are preserved.""" + schema = { + "type": "object", + "properties": { + "status": {"type": "string", "const": "ACTIVE"}, + "config": {"type": "object", "properties": {"mode": {"type": "string", "const": "PRODUCTION"}}}, + }, + } + + normalized = normalize_schema(schema) + + expected = { + "type": "object", + "properties": { + "status": {"type": "string", "description": "Property status", "const": "ACTIVE"}, + "config": { + "type": "object", + "properties": {"mode": {"type": "string", "description": "Property mode", "const": "PRODUCTION"}}, + "required": [], + }, + }, + "required": [], + } + + assert normalized == expected + + +def test_normalize_schema_with_additional_properties(): + """Test that additionalProperties constraint is preserved.""" + schema = { + "type": "object", + "additionalProperties": False, + "properties": { + "data": {"type": "object", "properties": {"id": {"type": "string"}}, "additionalProperties": False} + }, + } + + normalized = normalize_schema(schema) + + expected = { + "type": "object", + "additionalProperties": False, + "properties": { + 
"data": { + "type": "object", + "additionalProperties": False, + "properties": {"id": {"type": "string", "description": "Property id"}}, + "required": [], + } + }, + "required": [], + } + + assert normalized == expected def test_normalize_tool_spec_with_json_schema(): @@ -133,14 +301,20 @@ def test_normalize_tool_spec_with_json_schema(): "inputSchema": {"json": {"type": "object", "properties": {"query": {}}, "required": ["query"]}}, } normalized = normalize_tool_spec(tool_spec) - assert normalized["name"] == "test_tool" - assert normalized["description"] == "A test tool" - assert "inputSchema" in normalized - assert "json" in normalized["inputSchema"] - assert normalized["inputSchema"]["json"]["type"] == "object" - assert "query" in normalized["inputSchema"]["json"]["properties"] - assert normalized["inputSchema"]["json"]["properties"]["query"]["type"] == "string" - assert normalized["inputSchema"]["json"]["properties"]["query"]["description"] == "Property query" + + expected = { + "name": "test_tool", + "description": "A test tool", + "inputSchema": { + "json": { + "type": "object", + "properties": {"query": {"type": "string", "description": "Property query"}}, + "required": ["query"], + } + }, + } + + assert normalized == expected def test_normalize_tool_spec_with_direct_schema(): @@ -150,22 +324,29 @@ def test_normalize_tool_spec_with_direct_schema(): "inputSchema": {"type": "object", "properties": {"query": {}}, "required": ["query"]}, } normalized = normalize_tool_spec(tool_spec) - assert normalized["name"] == "test_tool" - assert normalized["description"] == "A test tool" - assert "inputSchema" in normalized - assert "json" in normalized["inputSchema"] - assert normalized["inputSchema"]["json"]["type"] == "object" - assert "query" in normalized["inputSchema"]["json"]["properties"] - assert normalized["inputSchema"]["json"]["required"] == ["query"] + + expected = { + "name": "test_tool", + "description": "A test tool", + "inputSchema": { + "json": { + "type": "object", + "properties": {"query": {"type": "string", "description": "Property query"}}, + "required": ["query"], + } + }, + } + + assert normalized == expected def test_normalize_tool_spec_without_input_schema(): tool_spec = {"name": "test_tool", "description": "A test tool"} normalized = normalize_tool_spec(tool_spec) - assert normalized["name"] == "test_tool" - assert normalized["description"] == "A test tool" - # Should not modify the spec if inputSchema is not present - assert "inputSchema" not in normalized + + expected = {"name": "test_tool", "description": "A test tool"} + + assert normalized == expected def test_normalize_tool_spec_empty_input_schema(): @@ -175,10 +356,10 @@ def test_normalize_tool_spec_empty_input_schema(): "inputSchema": "", } normalized = normalize_tool_spec(tool_spec) - assert normalized["name"] == "test_tool" - assert normalized["description"] == "A test tool" - # Should not modify the spec if inputSchema is not a dict - assert normalized["inputSchema"] == "" + + expected = {"name": "test_tool", "description": "A test tool", "inputSchema": ""} + + assert normalized == expected def test_validate_tool_use_with_valid_input(): diff --git a/tests/strands/types/models/__init__.py b/tests/strands/types/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/strands/types/models/test_model.py b/tests/strands/types/models/test_model.py new file mode 100644 index 00000000..03690733 --- /dev/null +++ b/tests/strands/types/models/test_model.py @@ -0,0 +1,98 @@ +from typing import 
Type + +import pytest +from pydantic import BaseModel + +from strands.types.models import Model as SAModel + + +class Person(BaseModel): + name: str + age: int + + +class TestModel(SAModel): + def update_config(self, **model_config): + return model_config + + def get_config(self): + return + + def structured_output(self, output_model: Type[BaseModel]) -> BaseModel: + return output_model(name="test", age=20) + + def format_request(self, messages, tool_specs, system_prompt): + return { + "messages": messages, + "tool_specs": tool_specs, + "system_prompt": system_prompt, + } + + def format_chunk(self, event): + return {"event": event} + + def stream(self, request): + yield {"request": request} + + +@pytest.fixture +def model(): + return TestModel() + + +@pytest.fixture +def messages(): + return [ + { + "role": "user", + "content": [{"text": "hello"}], + }, + ] + + +@pytest.fixture +def tool_specs(): + return [ + { + "name": "test_tool", + "description": "A test tool", + "inputSchema": { + "json": { + "type": "object", + "properties": { + "input": {"type": "string"}, + }, + "required": ["input"], + }, + }, + }, + ] + + +@pytest.fixture +def system_prompt(): + return "s1" + + +def test_converse(model, messages, tool_specs, system_prompt): + response = model.converse(messages, tool_specs, system_prompt) + + tru_events = list(response) + exp_events = [ + { + "event": { + "request": { + "messages": messages, + "tool_specs": tool_specs, + "system_prompt": system_prompt, + }, + }, + }, + ] + assert tru_events == exp_events + + +def test_structured_output(model): + response = model.structured_output(Person) + + assert response == Person(name="test", age=20) diff --git a/tests/strands/types/models/test_openai.py b/tests/strands/types/models/test_openai.py new file mode 100644 index 00000000..3a1a940b --- /dev/null +++ b/tests/strands/types/models/test_openai.py @@ -0,0 +1,376 @@ +import base64 +import unittest.mock + +import pytest + +from strands.types.models import OpenAIModel as SAOpenAIModel + + +class TestOpenAIModel(SAOpenAIModel): + def __init__(self): + self.config = {"model_id": "m1", "params": {"max_tokens": 1}} + + def update_config(self, **model_config): + return model_config + + def get_config(self): + return + + def stream(self, request): + yield {"request": request} + + +@pytest.fixture +def model(): + return TestOpenAIModel() + + +@pytest.fixture +def messages(): + return [ + { + "role": "user", + "content": [{"text": "hello"}], + }, + ] + + +@pytest.fixture +def tool_specs(): + return [ + { + "name": "test_tool", + "description": "A test tool", + "inputSchema": { + "json": { + "type": "object", + "properties": { + "input": {"type": "string"}, + }, + "required": ["input"], + }, + }, + }, + ] + + +@pytest.fixture +def system_prompt(): + return "s1" + + +@pytest.mark.parametrize( + "content, exp_result", + [ + # Document + ( + { + "document": { + "format": "pdf", + "name": "test doc", + "source": {"bytes": b"document"}, + }, + }, + { + "file": { + "file_data": "data:application/pdf;base64,ZG9jdW1lbnQ=", + "filename": "test doc", + }, + "type": "file", + }, + ), + # Image + ( + { + "image": { + "format": "jpg", + "source": {"bytes": b"image"}, + }, + }, + { + "image_url": { + "detail": "auto", + "format": "image/jpeg", + "url": "data:image/jpeg;base64,aW1hZ2U=", + }, + "type": "image_url", + }, + ), + # Image - base64 encoded + ( + { + "image": { + "format": "jpg", + "source": {"bytes": base64.b64encode(b"image")}, + }, + }, + { + "image_url": { + "detail": "auto", + "format": 
"image/jpeg", + "url": "data:image/jpeg;base64,aW1hZ2U=", + }, + "type": "image_url", + }, + ), + # Text + ( + {"text": "hello"}, + {"type": "text", "text": "hello"}, + ), + ], +) +def test_format_request_message_content(content, exp_result): + tru_result = SAOpenAIModel.format_request_message_content(content) + assert tru_result == exp_result + + +def test_format_request_message_content_unsupported_type(): + content = {"unsupported": {}} + + with pytest.raises(TypeError, match="content_type= | unsupported type"): + SAOpenAIModel.format_request_message_content(content) + + +def test_format_request_message_tool_call(): + tool_use = { + "input": {"expression": "2+2"}, + "name": "calculator", + "toolUseId": "c1", + } + + tru_result = SAOpenAIModel.format_request_message_tool_call(tool_use) + exp_result = { + "function": { + "arguments": '{"expression": "2+2"}', + "name": "calculator", + }, + "id": "c1", + "type": "function", + } + assert tru_result == exp_result + + +def test_format_request_tool_message(): + tool_result = { + "content": [{"text": "4"}, {"json": ["4"]}], + "status": "success", + "toolUseId": "c1", + } + + tru_result = SAOpenAIModel.format_request_tool_message(tool_result) + exp_result = { + "content": [{"text": "4", "type": "text"}, {"text": '["4"]', "type": "text"}], + "role": "tool", + "tool_call_id": "c1", + } + assert tru_result == exp_result + + +def test_format_request_messages(system_prompt): + messages = [ + { + "content": [], + "role": "user", + }, + { + "content": [{"text": "hello"}], + "role": "user", + }, + { + "content": [ + {"text": "call tool"}, + { + "toolUse": { + "input": {"expression": "2+2"}, + "name": "calculator", + "toolUseId": "c1", + }, + }, + ], + "role": "assistant", + }, + { + "content": [{"toolResult": {"toolUseId": "c1", "status": "success", "content": [{"text": "4"}]}}], + "role": "user", + }, + ] + + tru_result = SAOpenAIModel.format_request_messages(messages, system_prompt) + exp_result = [ + { + "content": system_prompt, + "role": "system", + }, + { + "content": [{"text": "hello", "type": "text"}], + "role": "user", + }, + { + "content": [{"text": "call tool", "type": "text"}], + "role": "assistant", + "tool_calls": [ + { + "function": { + "name": "calculator", + "arguments": '{"expression": "2+2"}', + }, + "id": "c1", + "type": "function", + } + ], + }, + { + "content": [{"text": "4", "type": "text"}], + "role": "tool", + "tool_call_id": "c1", + }, + ] + assert tru_result == exp_result + + +def test_format_request(model, messages, tool_specs, system_prompt): + tru_request = model.format_request(messages, tool_specs, system_prompt) + exp_request = { + "messages": [ + { + "content": system_prompt, + "role": "system", + }, + { + "content": [{"text": "hello", "type": "text"}], + "role": "user", + }, + ], + "model": "m1", + "stream": True, + "stream_options": {"include_usage": True}, + "tools": [ + { + "function": { + "description": "A test tool", + "name": "test_tool", + "parameters": { + "properties": { + "input": {"type": "string"}, + }, + "required": ["input"], + "type": "object", + }, + }, + "type": "function", + }, + ], + "max_tokens": 1, + } + assert tru_request == exp_request + + +@pytest.mark.parametrize( + ("event", "exp_chunk"), + [ + # Message start + ( + {"chunk_type": "message_start"}, + {"messageStart": {"role": "assistant"}}, + ), + # Content Start - Tool Use + ( + { + "chunk_type": "content_start", + "data_type": "tool", + "data": unittest.mock.Mock(**{"function.name": "calculator", "id": "c1"}), + }, + {"contentBlockStart": 
{"start": {"toolUse": {"name": "calculator", "toolUseId": "c1"}}}}, + ), + # Content Start - Text + ( + {"chunk_type": "content_start", "data_type": "text"}, + {"contentBlockStart": {"start": {}}}, + ), + # Content Delta - Tool Use + ( + { + "chunk_type": "content_delta", + "data_type": "tool", + "data": unittest.mock.Mock(function=unittest.mock.Mock(arguments='{"expression": "2+2"}')), + }, + {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"expression": "2+2"}'}}}}, + ), + # Content Delta - Tool Use - None + ( + { + "chunk_type": "content_delta", + "data_type": "tool", + "data": unittest.mock.Mock(function=unittest.mock.Mock(arguments=None)), + }, + {"contentBlockDelta": {"delta": {"toolUse": {"input": ""}}}}, + ), + # Content Delta - Text + ( + {"chunk_type": "content_delta", "data_type": "text", "data": "hello"}, + {"contentBlockDelta": {"delta": {"text": "hello"}}}, + ), + # Content Stop + ( + {"chunk_type": "content_stop"}, + {"contentBlockStop": {}}, + ), + # Message Stop - Tool Use + ( + {"chunk_type": "message_stop", "data": "tool_calls"}, + {"messageStop": {"stopReason": "tool_use"}}, + ), + # Message Stop - Max Tokens + ( + {"chunk_type": "message_stop", "data": "length"}, + {"messageStop": {"stopReason": "max_tokens"}}, + ), + # Message Stop - End Turn + ( + {"chunk_type": "message_stop", "data": "stop"}, + {"messageStop": {"stopReason": "end_turn"}}, + ), + # Metadata + ( + { + "chunk_type": "metadata", + "data": unittest.mock.Mock(prompt_tokens=100, completion_tokens=50, total_tokens=150), + }, + { + "metadata": { + "usage": { + "inputTokens": 100, + "outputTokens": 50, + "totalTokens": 150, + }, + "metrics": { + "latencyMs": 0, + }, + }, + }, + ), + ], +) +def test_format_chunk(event, exp_chunk, model): + tru_chunk = model.format_chunk(event) + assert tru_chunk == exp_chunk + + +def test_format_chunk_unknown_type(model): + event = {"chunk_type": "unknown"} + + with pytest.raises(RuntimeError, match="chunk_type= | unknown type"): + model.format_chunk(event) + + +@pytest.mark.parametrize( + ("data", "exp_result"), + [ + (b"image", b"aW1hZ2U="), + (b"aW1hZ2U=", b"aW1hZ2U="), + ], +) +def test_b64encode(data, exp_result): + tru_result = SAOpenAIModel.b64encode(data) + assert tru_result == exp_result