feat: Add index tracking to handle parallel tool call using litellm · google/adk-python@cef3ca1 · GitHub

Commit cef3ca1

feat: Add index tracking to handle parallel tool call using litellm
1 parent b72573c commit cef3ca1

File tree

2 files changed (+203 −20 lines changed):

src/google/adk/models/lite_llm.py
tests/unittests/models/test_litellm.py
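The change replaces the single function_name/function_args/function_id buffers in the streaming path with a dict keyed by the index that litellm attaches to each streamed tool-call delta, so argument fragments from parallel tool calls are no longer merged into one call. Below is a minimal, self-contained sketch of that accumulation pattern; it is illustrative only and uses a hypothetical FunctionChunk stand-in rather than the actual ADK classes.

# Illustrative sketch only (not the ADK code): accumulate streamed tool-call
# fragments per `index` so two parallel calls keep separate name/args buffers.
from dataclasses import dataclass
from typing import Dict, Optional


@dataclass
class FunctionChunk:
  """Hypothetical stand-in for one streamed tool-call fragment."""
  id: Optional[str]
  name: Optional[str]
  args: Optional[str]
  index: int = 0


def accumulate_tool_calls(chunks) -> Dict[int, dict]:
  """Groups fragments by index and returns one assembled call per index."""
  function_calls = {}  # index -> {"name": ..., "args": ..., "id": ...}
  for chunk in chunks:
    entry = function_calls.setdefault(
        chunk.index, {"name": "", "args": "", "id": None}
    )
    if chunk.name:
      entry["name"] += chunk.name
    if chunk.args:
      entry["args"] += chunk.args
    # Continuation deltas carry id=None, so keep the first non-empty id.
    entry["id"] = chunk.id or entry["id"]
  return function_calls


# Two parallel calls whose argument JSON arrives split across deltas.
stream = [
    FunctionChunk(id="call_1", name="function_1", args='{"arg": "val', index=0),
    FunctionChunk(id="call_2", name="function_2", args='{"arg": "val', index=1),
    FunctionChunk(id=None, name=None, args='ue1"}', index=0),
    FunctionChunk(id=None, name=None, args='ue2"}', index=1),
]
assert accumulate_tool_calls(stream) == {
    0: {"name": "function_1", "args": '{"arg": "value1"}', "id": "call_1"},
    1: {"name": "function_2", "args": '{"arg": "value2"}', "id": "call_2"},
}

Keying on index rather than on id matters because continuation deltas arrive with id=None; only the index identifies which call a fragment belongs to.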

src/google/adk/models/lite_llm.py

Lines changed: 30 additions & 20 deletions
@@ -61,6 +61,7 @@ class FunctionChunk(BaseModel):
   id: Optional[str]
   name: Optional[str]
   args: Optional[str]
+  index: Optional[int] = 0


 class TextChunk(BaseModel):
@@ -385,6 +386,7 @@ def _model_response_to_chunk(
               id=tool_call.id,
               name=tool_call.function.name,
               args=tool_call.function.arguments,
+              index=tool_call.index,
           ), finish_reason

   if finish_reason and not (
@@ -653,9 +655,8 @@ async def generate_content_async(

     if stream:
       text = ""
-      function_name = ""
-      function_args = ""
-      function_id = None
+      # Track function calls by index
+      function_calls = {}  # index -> {name, args, id}
       completion_args["stream"] = True
       aggregated_llm_response = None
       aggregated_llm_response_with_tool_call = None
@@ -664,11 +665,17 @@ async def generate_content_async(
       for part in self.llm_client.completion(**completion_args):
         for chunk, finish_reason in _model_response_to_chunk(part):
           if isinstance(chunk, FunctionChunk):
+            index = chunk.index or 0
+            if index not in function_calls:
+              function_calls[index] = {"name": "", "args": "", "id": None}
+
             if chunk.name:
-              function_name += chunk.name
+              function_calls[index]["name"] += chunk.name
             if chunk.args:
-              function_args += chunk.args
-            function_id = chunk.id or function_id
+              function_calls[index]["args"] += chunk.args
+            function_calls[index]["id"] = (
+                chunk.id or function_calls[index]["id"]
+            )
           elif isinstance(chunk, TextChunk):
             text += chunk.text
             yield _message_to_generate_content_response(
@@ -685,28 +692,31 @@ async def generate_content_async(
                 total_token_count=chunk.total_tokens,
             )

-          if finish_reason == "tool_calls" and function_id:
+          if finish_reason == "tool_calls" and function_calls:
+            tool_calls = []
+            for index, func_data in function_calls.items():
+              if func_data["id"]:
+                tool_calls.append(
+                    ChatCompletionMessageToolCall(
+                        type="function",
+                        id=func_data["id"],
+                        function=Function(
+                            name=func_data["name"],
+                            arguments=func_data["args"],
+                            index=index,
+                        ),
+                    )
+                )
             aggregated_llm_response_with_tool_call = (
                 _message_to_generate_content_response(
                     ChatCompletionAssistantMessage(
                         role="assistant",
                         content="",
-                        tool_calls=[
-                            ChatCompletionMessageToolCall(
-                                type="function",
-                                id=function_id,
-                                function=Function(
-                                    name=function_name,
-                                    arguments=function_args,
-                                ),
-                            )
-                        ],
+                        tool_calls=tool_calls,
                     )
                 )
             )
-            function_name = ""
-            function_args = ""
-            function_id = None
+            function_calls.clear()
           elif finish_reason == "stop" and text:
             aggregated_llm_response = _message_to_generate_content_response(
                 ChatCompletionAssistantMessage(role="assistant", content=text)

tests/unittests/models/test_litellm.py

Lines changed: 173 additions & 0 deletions
@@ -37,6 +37,7 @@
 from litellm.types.utils import ModelResponse
 from litellm.types.utils import StreamingChoices
 import pytest
+import json

 LLM_REQUEST_WITH_FUNCTION_DECLARATION = LlmRequest(
     contents=[
@@ -169,6 +170,101 @@
     ),
 ]

+MULTIPLE_FUNCTION_CALLS_STREAM = [
+    ModelResponse(
+        choices=[
+            StreamingChoices(
+                finish_reason=None,
+                delta=Delta(
+                    role="assistant",
+                    tool_calls=[
+                        ChatCompletionDeltaToolCall(
+                            type="function",
+                            id="call_1",
+                            function=Function(
+                                name="function_1",
+                                arguments='{"arg": "val',
+                            ),
+                            index=0,
+                        )
+                    ],
+                ),
+            )
+        ]
+    ),
+    ModelResponse(
+        choices=[
+            StreamingChoices(
+                finish_reason=None,
+                delta=Delta(
+                    role="assistant",
+                    tool_calls=[
+                        ChatCompletionDeltaToolCall(
+                            type="function",
+                            id=None,
+                            function=Function(
+                                name=None,
+                                arguments='ue1"}',
+                            ),
+                            index=0,
+                        )
+                    ],
+                ),
+            )
+        ]
+    ),
+    ModelResponse(
+        choices=[
+            StreamingChoices(
+                finish_reason=None,
+                delta=Delta(
+                    role="assistant",
+                    tool_calls=[
+                        ChatCompletionDeltaToolCall(
+                            type="function",
+                            id="call_2",
+                            function=Function(
+                                name="function_2",
+                                arguments='{"arg": "val',
+                            ),
+                            index=1,
+                        )
+                    ],
+                ),
+            )
+        ]
+    ),
+    ModelResponse(
+        choices=[
+            StreamingChoices(
+                finish_reason=None,
+                delta=Delta(
+                    role="assistant",
+                    tool_calls=[
+                        ChatCompletionDeltaToolCall(
+                            type="function",
+                            id=None,
+                            function=Function(
+                                name=None,
+                                arguments='ue2"}',
+                            ),
+                            index=1,
+                        )
+                    ],
+                ),
+            )
+        ]
+    ),
+    ModelResponse(
+        choices=[
+            StreamingChoices(
+                finish_reason="tool_calls",
+            )
+        ]
+    ),
+]
+
+
 @pytest.fixture
 def mock_response():
   return ModelResponse(
@@ -1086,3 +1182,80 @@ async def test_generate_content_async_stream_with_usage_metadata(
       ]
       == "string"
   )
+
+
+@pytest.mark.asyncio
+async def test_generate_content_async_multiple_function_calls(
+    mock_completion, lite_llm_instance
+):
+  """Test handling of multiple function calls with different indices in streaming mode.
+
+  This test verifies that:
+  1. Multiple function calls with different indices are handled correctly
+  2. Arguments and names are properly accumulated for each function call
+  3. The final response contains all function calls with correct indices
+  """
+  mock_completion.return_value = MULTIPLE_FUNCTION_CALLS_STREAM
+
+  llm_request = LlmRequest(
+      contents=[
+          types.Content(
+              role="user",
+              parts=[types.Part.from_text(text="Test multiple function calls")],
+          )
+      ],
+      config=types.GenerateContentConfig(
+          tools=[
+              types.Tool(
+                  function_declarations=[
+                      types.FunctionDeclaration(
+                          name="function_1",
+                          description="First test function",
+                          parameters=types.Schema(
+                              type=types.Type.OBJECT,
+                              properties={
+                                  "arg": types.Schema(type=types.Type.STRING),
+                              },
+                          ),
+                      ),
+                      types.FunctionDeclaration(
+                          name="function_2",
+                          description="Second test function",
+                          parameters=types.Schema(
+                              type=types.Type.OBJECT,
+                              properties={
+                                  "arg": types.Schema(type=types.Type.STRING),
+                              },
+                          ),
+                      ),
+                  ]
+              )
+          ],
+      ),
+  )
+
+  responses = []
+  async for response in lite_llm_instance.generate_content_async(
+      llm_request, stream=True
+  ):
+    responses.append(response)
+
+  # Verify we got the final response with both function calls
+  assert len(responses) > 0
+  final_response = responses[-1]
+  assert final_response.content.role == "model"
+  assert len(final_response.content.parts) == 2
+
+  # Verify first function call
+  assert final_response.content.parts[0].function_call.name == "function_1"
+  assert final_response.content.parts[0].function_call.id == "call_1"
+  assert final_response.content.parts[0].function_call.args == {
+      "arg": "value1"
+  }
+
+  # Verify second function call
+  assert final_response.content.parts[1].function_call.name == "function_2"
+  assert final_response.content.parts[1].function_call.id == "call_2"
+  assert final_response.content.parts[1].function_call.args == {
+      "arg": "value2"
+  }
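The added test drives generate_content_async with a mocked stream that interleaves two tool calls (indices 0 and 1) and asserts the final response reassembles both calls with their ids and parsed arguments. Assuming a standard development checkout with pytest installed, the new case can be run on its own with something like:

pytest tests/unittests/models/test_litellm.py -k multiple_function_calls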
