feat: Update for anthropic models · google/adk-python@16f7d98 · GitHub
[go: up one dir, main page]

Skip to content

Commit 16f7d98

Browse files
google-genai-bot authored and copybara-github committed
feat: Update for anthropic models
Enable parallel tool use for anthropic models, add agent examples, and add a functional test for anthropic models.

PiperOrigin-RevId: 766703018
1 parent 44f5078 commit 16f7d98

File tree

5 files changed

+312
-9
lines changed

5 files changed

+312
-9
lines changed
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
from . import agent
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
import random
17+
18+
from google.adk import Agent
19+
from google.adk.models.anthropic_llm import Claude
20+
21+
22+
def roll_die(sides: int) -> int:
  """Simulate one roll of a fair die.

  Args:
    sides: The integer number of sides the die has.

  Returns:
    An integer of the result of rolling the die.
  """
  # random.randint is inclusive on both ends, so the result is in [1, sides].
  rolled = random.randint(1, sides)
  return rolled
32+
33+
34+
async def check_prime(nums: list[int]) -> str:
  """Check if a given list of numbers are prime.

  Args:
    nums: The list of numbers to check.

  Returns:
    A str indicating which number is prime.
  """
  primes = set()
  for candidate in nums:
    candidate = int(candidate)
    # 0, 1, and negatives are never prime; trial-divide the rest up to sqrt(n).
    if candidate > 1 and all(
        candidate % divisor for divisor in range(2, int(candidate**0.5) + 1)
    ):
      primes.add(candidate)
  if not primes:
    return "No prime numbers found."
  return f"{', '.join(str(num) for num in primes)} are prime numbers."
60+
61+
62+
# Example agent backed by a Claude (Anthropic) model. It demonstrates parallel
# tool use: rolling dice via roll_die and testing primality via check_prime.
# Fix: the instruction said "dice roles" where the rest of the prompt
# consistently says "dice rolls".
root_agent = Agent(
    model=Claude(model="claude-3-5-sonnet-v2@20241022"),
    name="data_processing_agent",
    description=(
        "hello world agent that can roll a dice of 8 sides and check prime"
        " numbers."
    ),
    instruction="""
You roll dice and answer questions about the outcome of the dice rolls.
You can roll dice of different sizes.
You can use multiple tools in parallel by calling functions in parallel(in one request and in one round).
It is ok to discuss previous dice rolls, and comment on the dice rolls.
When you are asked to roll a die, you must call the roll_die tool with the number of sides. Be sure to pass in an integer. Do not pass in a string.
You should never roll a die on your own.
When checking prime numbers, call the check_prime tool with a list of integers. Be sure to pass in a list of integers. You should never pass in a string.
You should not check prime numbers before calling the tool.
When you are asked to roll a die and check prime numbers, you should always make the following two function calls:
1. You should first call the roll_die tool to get a roll. Wait for the function response before calling the check_prime tool.
2. After you get the function response from roll_die tool, you should call the check_prime tool with the roll_die result.
  2.1 If user asks you to check primes based on previous rolls, make sure you include the previous rolls in the list.
3. When you respond, you must include the roll_die result from step 1.
You should always perform the previous 3 steps when asking for a roll and checking prime numbers.
You should not rely on the previous history on prime results.
""",
    tools=[
        roll_die,
        check_prime,
    ],
)
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
import asyncio
17+
import time
18+
19+
import agent
20+
from dotenv import load_dotenv
21+
from google.adk import Runner
22+
from google.adk.artifacts import InMemoryArtifactService
23+
from google.adk.cli.utils import logs
24+
from google.adk.sessions import InMemorySessionService
25+
from google.adk.sessions import Session
26+
from google.genai import types
27+
28+
load_dotenv(override=True)
29+
logs.log_to_tmp_folder()
30+
31+
32+
async def main():
  """Run the example agent through two prompts and report the elapsed time."""
  app_name = 'my_app'
  user_id_1 = 'user1'
  session_service = InMemorySessionService()
  artifact_service = InMemoryArtifactService()
  runner = Runner(
      app_name=app_name,
      agent=agent.root_agent,
      artifact_service=artifact_service,
      session_service=session_service,
  )
  session_11 = await session_service.create_session(
      app_name=app_name, user_id=user_id_1
  )

  async def run_prompt(session: Session, new_message: str):
    """Send one user message through the runner and print text replies."""
    content = types.Content(
        role='user', parts=[types.Part.from_text(text=new_message)]
    )
    print('** User says:', content.model_dump(exclude_none=True))
    async for event in runner.run_async(
        user_id=user_id_1,
        session_id=session.id,
        new_message=content,
    ):
      # Guard against events without content: event.content may be None
      # (e.g. pure tool-call bookkeeping events), which previously raised
      # AttributeError on .parts.
      if event.content and event.content.parts and event.content.parts[0].text:
        print(f'** {event.author}: {event.content.parts[0].text}')

  start_time = time.time()
  print('Start time:', start_time)
  print('------------------------------------')
  await run_prompt(session_11, 'Hi, introduce yourself.')
  await run_prompt(
      session_11,
      'Run the following request 10 times: roll a die with 100 sides and check'
      ' if it is prime',
  )
  end_time = time.time()
  print('------------------------------------')
  print('End time:', end_time)
  print('Total time:', end_time - start_time)


if __name__ == '__main__':
  asyncio.run(main())

src/google/adk/models/anthropic_llm.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,10 @@ def content_block_to_part(
135135
def message_to_generate_content_response(
136136
message: anthropic_types.Message,
137137
) -> LlmResponse:
138+
logger.info(
139+
"Claude response: %s",
140+
message.model_dump_json(indent=2, exclude_none=True),
141+
)
138142

139143
return LlmResponse(
140144
content=types.Content(
@@ -229,14 +233,11 @@ async def generate_content_async(
229233
for tool in llm_request.config.tools[0].function_declarations
230234
]
231235
tool_choice = (
232-
anthropic_types.ToolChoiceAutoParam(
233-
type="auto",
234-
# TODO: allow parallel tool use.
235-
disable_parallel_tool_use=True,
236-
)
236+
anthropic_types.ToolChoiceAutoParam(type="auto")
237237
if llm_request.tools_dict
238238
else NOT_GIVEN
239239
)
240+
# TODO(b/421255973): Enable streaming for anthropic models.
240241
message = self._anthropic_client.messages.create(
241242
model=llm_request.model,
242243
system=llm_request.config.system_instruction,
@@ -245,10 +246,6 @@ async def generate_content_async(
245246
tool_choice=tool_choice,
246247
max_tokens=MAX_TOKEN,
247248
)
248-
logger.info(
249-
"Claude response: %s",
250-
message.model_dump_json(indent=2, exclude_none=True),
251-
)
252249
yield message_to_generate_content_response(message)
253250

254251
@cached_property
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import os
16+
import sys
17+
from unittest import mock
18+
19+
from anthropic import types as anthropic_types
20+
from google.adk import version as adk_version
21+
from google.adk.models import anthropic_llm
22+
from google.adk.models.anthropic_llm import Claude
23+
from google.adk.models.llm_request import LlmRequest
24+
from google.adk.models.llm_response import LlmResponse
25+
from google.genai import types
26+
from google.genai import version as genai_version
27+
from google.genai.types import Content
28+
from google.genai.types import Part
29+
import pytest
30+
31+
32+
@pytest.fixture
def generate_content_response():
  """A canned Anthropic Message as returned by the Messages API.

  Fix: the text literal contained scrape residue (" 10000 ;") fused into the
  greeting; restored to the plain greeting string.
  """
  return anthropic_types.Message(
      id="msg_vrtx_testid",
      content=[
          anthropic_types.TextBlock(
              citations=None,
              text="Hi! How can I help you today?",
              type="text",
          )
      ],
      model="claude-3-5-sonnet-v2-20241022",
      role="assistant",
      stop_reason="end_turn",
      stop_sequence=None,
      type="message",
      usage=anthropic_types.Usage(
          cache_creation_input_tokens=0,
          cache_read_input_tokens=0,
          input_tokens=13,
          output_tokens=12,
          server_tool_use=None,
          service_tier=None,
      ),
  )
55+
56+
57+
@pytest.fixture
def generate_llm_response():
  """An LlmResponse wrapping a single model text candidate."""
  candidate = types.Candidate(
      content=Content(
          role="model",
          parts=[Part.from_text(text="Hello, how can I help you?")],
      ),
      finish_reason=types.FinishReason.STOP,
  )
  return LlmResponse.create(
      types.GenerateContentResponse(candidates=[candidate])
  )
72+
73+
74+
@pytest.fixture
def claude_llm():
  """A Claude model wrapper pointed at the sonnet-v2 model id."""
  llm = Claude(model="claude-3-5-sonnet-v2@20241022")
  return llm
77+
78+
79+
@pytest.fixture
def llm_request():
  """A minimal text-only LlmRequest targeting the Claude model."""
  config = types.GenerateContentConfig(
      temperature=0.1,
      response_modalities=[types.Modality.TEXT],
      system_instruction="You are a helpful assistant",
  )
  return LlmRequest(
      model="claude-3-5-sonnet-v2@20241022",
      contents=[Content(role="user", parts=[Part.from_text(text="Hello")])],
      config=config,
  )
90+
91+
92+
def test_supported_models():
  """Claude advertises exactly the claude-3 and claude-*-4 model patterns."""
  patterns = Claude.supported_models()
  assert len(patterns) == 2
  assert patterns[0] == r"claude-3-.*"
  assert patterns[1] == r"claude-.*-4.*"
97+
98+
99+
@pytest.mark.asyncio
async def test_generate_content_async(
    claude_llm, llm_request, generate_content_response, generate_llm_response
):
  """generate_content_async yields one converted LlmResponse per message."""
  with mock.patch.object(
      claude_llm, "_anthropic_client"
  ) as mock_client, mock.patch.object(
      anthropic_llm,
      "message_to_generate_content_response",
      return_value=generate_llm_response,
  ):
    # The mocked create() must hand back an awaitable resolving to the
    # canned Anthropic message (the model implementation awaits the client).
    async def fake_create():
      return generate_content_response

    mock_client.messages.create.return_value = fake_create()

    responses = [
        resp
        async for resp in claude_llm.generate_content_async(
            llm_request, stream=False
        )
    ]

  assert len(responses) == 1
  assert isinstance(responses[0], LlmResponse)
  assert responses[0].content.parts[0].text == "Hello, how can I help you?"

0 commit comments

Comments
 (0)
0