Copybara import of the project: · google/adk-python@05f4834 · GitHub

Commit 05f4834

lucasnobre212 authored and copybara-github committed
Copybara import of the project:
--
cef3ca1 by Lucas Nobre <lucaas.sn@gmail.com>:

feat: Add index tracking to handle parallel tool call using litellm

COPYBARA_INTEGRATE_REVIEW=#759 from lucasnobre212:fix/issue_484 65e2293
PiperOrigin-RevId: 764902433
1 parent 841e10a commit 05f4834

File tree

src/google/adk/models/lite_llm.py
tests/unittests/models/test_litellm.py

2 files changed: +202 −20 lines changed

src/google/adk/models/lite_llm.py

Lines changed: 30 additions & 20 deletions
@@ -62,6 +62,7 @@ class FunctionChunk(BaseModel):
   id: Optional[str]
   name: Optional[str]
   args: Optional[str]
+  index: Optional[int] = 0


 class TextChunk(BaseModel):

@@ -386,6 +387,7 @@ def _model_response_to_chunk(
              id=tool_call.id,
              name=tool_call.function.name,
              args=tool_call.function.arguments,
+             index=tool_call.index,
          ), finish_reason

    if finish_reason and not (

@@ -661,9 +663,8 @@ async def generate_content_async(

     if stream:
       text = ""
-      function_name = ""
-      function_args = ""
-      function_id = None
+      # Track function calls by index
+      function_calls = {}  # index -> {name, args, id}
       completion_args["stream"] = True
       aggregated_llm_response = None
       aggregated_llm_response_with_tool_call = None

@@ -672,11 +673,17 @@ async def generate_content_async(
       for part in self.llm_client.completion(**completion_args):
         for chunk, finish_reason in _model_response_to_chunk(part):
           if isinstance(chunk, FunctionChunk):
+            index = chunk.index or 0
+            if index not in function_calls:
+              function_calls[index] = {"name": "", "args": "", "id": None}
             if chunk.name:
-              function_name += chunk.name
+              function_calls[index]["name"] += chunk.name
             if chunk.args:
-              function_args += chunk.args
-            function_id = chunk.id or function_id
+              function_calls[index]["args"] += chunk.args
+            function_calls[index]["id"] = (
+                chunk.id or function_calls[index]["id"]
+            )
           elif isinstance(chunk, TextChunk):
             text += chunk.text
             yield _message_to_generate_content_response(

@@ -693,28 +700,31 @@ async def generate_content_async(
                 total_token_count=chunk.total_tokens,
             )

-          if finish_reason == "tool_calls" and function_id:
+          if finish_reason == "tool_calls" and function_calls:
+            tool_calls = []
+            for index, func_data in function_calls.items():
+              if func_data["id"]:
+                tool_calls.append(
+                    ChatCompletionMessageToolCall(
+                        type="function",
+                        id=func_data["id"],
+                        function=Function(
+                            name=func_data["name"],
+                            arguments=func_data["args"],
+                            index=index,
+                        ),
+                    )
+                )
             aggregated_llm_response_with_tool_call = (
                 _message_to_generate_content_response(
                     ChatCompletionAssistantMessage(
                         role="assistant",
                         content="",
-                        tool_calls=[
-                            ChatCompletionMessageToolCall(
-                                type="function",
-                                id=function_id,
-                                function=Function(
-                                    name=function_name,
-                                    arguments=function_args,
-                                ),
-                            )
-                        ],
+                        tool_calls=tool_calls,
                     )
                 )
             )
-            function_name = ""
-            function_args = ""
-            function_id = None
+            function_calls.clear()
           elif finish_reason == "stop" and text:
             aggregated_llm_response = _message_to_generate_content_response(
                 ChatCompletionAssistantMessage(role="assistant", content=text)
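Why the index matters: when a model issues several tool calls in one turn, litellm streams them as fragments tagged with an index, and the old single-accumulator variables (function_name, function_args, function_id) merged fragments from different calls into one corrupted call. Below is a minimal, self-contained sketch of the index-keyed accumulation this diff introduces, using plain dicts as stand-ins for litellm's delta objects; the accumulate_tool_calls helper and the fragment shape are illustrative, not the actual adk/litellm API.

import json
from typing import Any, Dict, List, Optional


def accumulate_tool_calls(
    fragments: List[Dict[str, Any]],
) -> List[Dict[str, Any]]:
  """Folds streamed tool-call fragments into complete calls, keyed by index."""
  calls: Dict[int, Dict[str, Optional[str]]] = {}
  for frag in fragments:
    index = frag.get("index") or 0
    if index not in calls:
      calls[index] = {"name": "", "args": "", "id": None}
    if frag.get("name"):
      calls[index]["name"] += frag["name"]
    if frag.get("args"):
      calls[index]["args"] += frag["args"]
    # The id typically arrives only on a call's first fragment; keep it.
    calls[index]["id"] = frag.get("id") or calls[index]["id"]
  return [
      {"id": c["id"], "name": c["name"], "args": json.loads(c["args"])}
      for c in calls.values()
      if c["id"]
  ]


# Two parallel calls whose JSON arguments arrive split across fragments,
# interleaved the way a streaming provider may emit them.
fragments = [
    {"index": 0, "id": "call_1", "name": "function_1", "args": '{"arg": "val'},
    {"index": 1, "id": "call_2", "name": "function_2", "args": '{"arg": "val'},
    {"index": 0, "args": 'ue1"}'},
    {"index": 1, "args": 'ue2"}'},
]
assert accumulate_tool_calls(fragments) == [
    {"id": "call_1", "name": "function_1", "args": {"arg": "value1"}},
    {"id": "call_2", "name": "function_2", "args": {"arg": "value2"}},
]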

tests/unittests/models/test_litellm.py

Lines changed: 172 additions & 0 deletions
@@ -38,6 +38,7 @@
 from litellm.types.utils import ModelResponse
 from litellm.types.utils import StreamingChoices
 import pytest
+import json

 LLM_REQUEST_WITH_FUNCTION_DECLARATION = LlmRequest(
     contents=[

@@ -170,6 +171,100 @@
     ),
 ]

+MULTIPLE_FUNCTION_CALLS_STREAM = [
+    ModelResponse(
+        choices=[
+            StreamingChoices(
+                finish_reason=None,
+                delta=Delta(
+                    role="assistant",
+                    tool_calls=[
+                        ChatCompletionDeltaToolCall(
+                            type="function",
+                            id="call_1",
+                            function=Function(
+                                name="function_1",
+                                arguments='{"arg": "val',
+                            ),
+                            index=0,
+                        )
+                    ],
+                ),
+            )
+        ]
+    ),
+    ModelResponse(
+        choices=[
+            StreamingChoices(
+                finish_reason=None,
+                delta=Delta(
+                    role="assistant",
+                    tool_calls=[
+                        ChatCompletionDeltaToolCall(
+                            type="function",
+                            id=None,
+                            function=Function(
+                                name=None,
+                                arguments='ue1"}',
+                            ),
+                            index=0,
+                        )
+                    ],
+                ),
+            )
+        ]
+    ),
+    ModelResponse(
+        choices=[
+            StreamingChoices(
+                finish_reason=None,
+                delta=Delta(
+                    role="assistant",
+                    tool_calls=[
+                        ChatCompletionDeltaToolCall(
+                            type="function",
+                            id="call_2",
+                            function=Function(
+                                name="function_2",
+                                arguments='{"arg": "val',
+                            ),
+                            index=1,
+                        )
+                    ],
+                ),
+            )
+        ]
+    ),
+    ModelResponse(
+        choices=[
+            StreamingChoices(
+                finish_reason=None,
+                delta=Delta(
+                    role="assistant",
+                    tool_calls=[
+                        ChatCompletionDeltaToolCall(
+                            type="function",
+                            id=None,
+                            function=Function(
+                                name=None,
+                                arguments='ue2"}',
+                            ),
+                            index=1,
+                        )
+                    ],
+                ),
+            )
+        ]
+    ),
+    ModelResponse(
+        choices=[
+            StreamingChoices(
+                finish_reason="tool_calls",
+            )
+        ]
+    ),
+]

 @pytest.fixture
 def mock_response():

@@ -1089,3 +1184,80 @@ async def test_generate_content_async_stream_with_usage_metadata(
       ]
       == "string"
   )
+
+
+@pytest.mark.asyncio
+async def test_generate_content_async_multiple_function_calls(
+    mock_completion, lite_llm_instance
+):
+  """Test handling of multiple function calls with different indices in streaming mode.
+
+  This test verifies that:
+  1. Multiple function calls with different indices are handled correctly
+  2. Arguments and names are properly accumulated for each function call
+  3. The final response contains all function calls with correct indices
+  """
+  mock_completion.return_value = MULTIPLE_FUNCTION_CALLS_STREAM
+
+  llm_request = LlmRequest(
+      contents=[
+          types.Content(
+              role="user",
+              parts=[types.Part.from_text(text="Test multiple function calls")],
+          )
+      ],
+      config=types.GenerateContentConfig(
+          tools=[
+              types.Tool(
+                  function_declarations=[
+                      types.FunctionDeclaration(
+                          name="function_1",
+                          description="First test function",
+                          parameters=types.Schema(
+                              type=types.Type.OBJECT,
+                              properties={
+                                  "arg": types.Schema(type=types.Type.STRING),
+                              },
+                          ),
+                      ),
+                      types.FunctionDeclaration(
+                          name="function_2",
+                          description="Second test function",
+                          parameters=types.Schema(
+                              type=types.Type.OBJECT,
+                              properties={
+                                  "arg": types.Schema(type=types.Type.STRING),
+                              },
+                          ),
+                      ),
+                  ]
+              )
+          ],
+      ),
+  )
+
+  responses = []
+  async for response in lite_llm_instance.generate_content_async(
+      llm_request, stream=True
+  ):
+    responses.append(response)
+
+  # Verify we got the final response with both function calls
+  assert len(responses) > 0
+  final_response = responses[-1]
+  assert final_response.content.role == "model"
+  assert len(final_response.content.parts) == 2
+
+  # Verify first function call
+  assert final_response.content.parts[0].function_call.name == "function_1"
+  assert final_response.content.parts[0].function_call.id == "call_1"
+  assert final_response.content.parts[0].function_call.args == {
+      "arg": "value1"
+  }
+
+  # Verify second function call
+  assert final_response.content.parts[1].function_call.name == "function_2"
+  assert final_response.content.parts[1].function_call.id == "call_2"
+  assert final_response.content.parts[1].function_call.args == {
+      "arg": "value2"
+  }
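Note that the stream fixture deliberately splits each call's JSON argument string across two chunks ('{"arg": "val' followed by 'ue1"}'), so the final assertions only pass if fragments are concatenated per index before being parsed into args. Assuming the repository's standard pytest setup, the new test can be run in isolation with:

pytest tests/unittests/models/test_litellm.py -k test_generate_content_async_multiple_function_calls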
