fix: Handle non-indexed function call chunks with incremental fallback index · google/adk-python@b181cbc

Commit b181cbc

selcukgun authored and copybara-github committed
fix: Handle non-indexed function call chunks with incremental fallback index
This is in response to litellm v1.71.2 + ollama v0.9.0 sending function call chunks with 0 indices across multiple calls and lacking call ids.

Solutions introduced:
1. Increment a fallback index once the accumulated arg string becomes JSON-parsable, so the next 0-indexed chunk starts a new call.
2. Tolerate finish_reason == "stop" when tool calls are present.
3. Fall back to the index when the tool call id is None.

Fixes #294

PiperOrigin-RevId: 766258344
1 parent bd588bc commit b181cbc
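
The core of the workaround is easiest to see in isolation: because the offending chunks all report index 0 and carry no call id, the only reliable signal that one tool call has finished is that its accumulated argument string parses as complete JSON; at that point the fallback index advances so the next 0-indexed chunk lands in a fresh slot. A minimal, self-contained sketch of that logic (simplified chunk tuples stand in for litellm's streaming objects; this is an illustration, not the adk code):

import json

# Hypothetical simplified chunks: (index, name, args_fragment, call_id),
# mirroring the non-compliant litellm/ollama stream (all index 0, no ids).
chunks = [
    (0, "function_1", '{"arg": "val', None),
    (0, None, 'ue1"}', None),
    (0, "function_2", '{"arg": "val', None),
    (0, None, 'ue2"}', None),
]

function_calls = {}  # effective index -> {"name", "args", "id"}
fallback_index = 0

for index, name, args, call_id in chunks:
    # A reported index of 0 (or None) falls through to the fallback index.
    effective = index or fallback_index
    slot = function_calls.setdefault(
        effective, {"name": "", "args": "", "id": None}
    )
    if name:
        slot["name"] += name
    if args:
        slot["args"] += args

    # Once the accumulated args parse as JSON the call is complete, so the
    # next 0-indexed chunk should start a new slot.
    try:
        json.loads(slot["args"])
        fallback_index += 1
    except json.JSONDecodeError:
        pass

    # With no call id from the provider, fall back to the index as the id.
    slot["id"] = call_id or slot["id"] or str(effective)

print(function_calls)
# {0: {'name': 'function_1', 'args': '{"arg": "value1"}', 'id': '0'},
#  1: {'name': 'function_2', 'args': '{"arg": "value2"}', 'id': '1'}}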

3 files changed, +189 -5 lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -83,7 +83,7 @@ test = [
   "anthropic>=0.43.0", # For anthropic model tests
   "langchain-community>=0.3.17",
   "langgraph>=0.2.60", # For LangGraphAgent
-  "litellm>=1.63.11", # For LiteLLM tests
+  "litellm>=1.71.2", # For LiteLLM tests
   "llama-index-readers-file>=0.4.0", # For retrieval tests

   "pytest-asyncio>=0.25.0",

src/google/adk/models/lite_llm.py

Lines changed: 16 additions & 4 deletions
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import annotations

 import base64
 import json
@@ -669,20 +670,29 @@ async def generate_content_async(
       aggregated_llm_response = None
       aggregated_llm_response_with_tool_call = None
       usage_metadata = None
-
+      fallback_index = 0
       for part in self.llm_client.completion(**completion_args):
         for chunk, finish_reason in _model_response_to_chunk(part):
           if isinstance(chunk, FunctionChunk):
-            index = chunk.index or 0
+            index = chunk.index or fallback_index
             if index not in function_calls:
               function_calls[index] = {"name": "", "args": "", "id": None}

             if chunk.name:
               function_calls[index]["name"] += chunk.name
             if chunk.args:
               function_calls[index]["args"] += chunk.args
+
+            # check if args is completed (workaround for improper chunk
+            # indexing)
+            try:
+              json.loads(function_calls[index]["args"])
+              fallback_index += 1
+            except json.JSONDecodeError:
+              pass
+
             function_calls[index]["id"] = (
-                chunk.id or function_calls[index]["id"]
+                chunk.id or function_calls[index]["id"] or str(index)
             )
           elif isinstance(chunk, TextChunk):
             text += chunk.text
@@ -700,7 +710,9 @@ async def generate_content_async(
                   total_token_count=chunk.total_tokens,
               )

-          if finish_reason == "tool_calls" and function_calls:
+          if (
+              finish_reason == "tool_calls" or finish_reason == "stop"
+          ) and function_calls:
             tool_calls = []
             for index, func_data in function_calls.items():
               if func_data["id"]:
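
Two details of the streaming loop above are worth calling out. The expression chunk.index or fallback_index treats an explicit index of 0 the same as a missing index, which the fix relies on since the misbehaving providers always report 0. And the aggregation now fires on finish_reason "stop" as well as "tool_calls" whenever function_calls is non-empty, because ollama reports "stop" even after emitting tool call deltas.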

tests/unittests/models/test_litellm.py

Lines changed: 172 additions & 0 deletions
@@ -290,6 +290,105 @@ def mock_response():
   )


+# Test case reflecting litellm v1.71.2, ollama v0.9.0 streaming response
+# no tool call ids
+# indices all 0
+# finish_reason stop instead of tool_calls
+NON_COMPLIANT_MULTIPLE_FUNCTION_CALLS_STREAM = [
+    ModelResponse(
+        choices=[
+            StreamingChoices(
+                finish_reason=None,
+                delta=Delta(
+                    role="assistant",
+                    tool_calls=[
+                        ChatCompletionDeltaToolCall(
+                            type="function",
+                            id=None,
+                            function=Function(
+                                name="function_1",
+                                arguments='{"arg": "val',
+                            ),
+                            index=0,
+                        )
+                    ],
+                ),
+            )
+        ]
+    ),
+    ModelResponse(
+        choices=[
+            StreamingChoices(
+                finish_reason=None,
+                delta=Delta(
+                    role="assistant",
+                    tool_calls=[
+                        ChatCompletionDeltaToolCall(
+                            type="function",
+                            id=None,
+                            function=Function(
+                                name=None,
+                                arguments='ue1"}',
+                            ),
+                            index=0,
+                        )
+                    ],
+                ),
+            )
+        ]
+    ),
+    ModelResponse(
+        choices=[
+            StreamingChoices(
+                finish_reason=None,
+                delta=Delta(
+                    role="assistant",
+                    tool_calls=[
+                        ChatCompletionDeltaToolCall(
+                            type="function",
+                            id=None,
+                            function=Function(
+                                name="function_2",
+                                arguments='{"arg": "val',
+                            ),
+                            index=0,
+                        )
+                    ],
+                ),
+            )
+        ]
+    ),
+    ModelResponse(
+        choices=[
+            StreamingChoices(
+                finish_reason=None,
+                delta=Delta(
+                    role="assistant",
+                    tool_calls=[
+                        ChatCompletionDeltaToolCall(
+                            type="function",
+                            id=None,
+                            function=Function(
+                                name=None,
+                                arguments='ue2"}',
+                            ),
+                            index=0,
+                        )
+                    ],
+                ),
+            )
+        ]
+    ),
+    ModelResponse(
+        choices=[
+            StreamingChoices(
+                finish_reason="stop",
+            )
+        ]
+    ),
+]
+
+
 @pytest.fixture
 def mock_acompletion(mock_response):
   return AsyncMock(return_value=mock_response)
@@ -1257,3 +1356,76 @@ async def test_generate_content_async_multiple_function_calls(
   assert final_response.content.parts[1].function_call.name == "function_2"
   assert final_response.content.parts[1].function_call.id == "call_2"
   assert final_response.content.parts[1].function_call.args == {"arg": "value2"}
+
+
+@pytest.mark.asyncio
+async def test_generate_content_async_non_compliant_multiple_function_calls(
+    mock_completion, lite_llm_instance
+):
+  """Test handling of multiple function calls with same 0 indices in streaming mode.
+
+  This test verifies that:
+  1. Multiple function calls with same indices (0) are handled correctly
+  2. Arguments and names are properly accumulated for each function call
+  3. The final response contains all function calls with correct incremented indices
+  """
+  mock_completion.return_value = NON_COMPLIANT_MULTIPLE_FUNCTION_CALLS_STREAM
+
+  llm_request = LlmRequest(
+      contents=[
+          types.Content(
+              role="user",
+              parts=[types.Part.from_text(text="Test multiple function calls")],
+          )
+      ],
+      config=types.GenerateContentConfig(
+          tools=[
+              types.Tool(
+                  function_declarations=[
+                      types.FunctionDeclaration(
+                          name="function_1",
+                          description="First test function",
+                          parameters=types.Schema(
+                              type=types.Type.OBJECT,
+                              properties={
+                                  "arg": types.Schema(type=types.Type.STRING),
+                              },
+                          ),
+                      ),
+                      types.FunctionDeclaration(
+                          name="function_2",
+                          description="Second test function",
+                          parameters=types.Schema(
+                              type=types.Type.OBJECT,
+                              properties={
+                                  "arg": types.Schema(type=types.Type.STRING),
+                              },
+                          ),
+                      ),
+                  ]
+              )
+          ],
+      ),
+  )
+
+  responses = []
+  async for response in lite_llm_instance.generate_content_async(
+      llm_request, stream=True
+  ):
+    responses.append(response)
+
+  # Verify we got the final response with both function calls
+  assert len(responses) > 0
+  final_response = responses[-1]
+  assert final_response.content.role == "model"
+  assert len(final_response.content.parts) == 2
+
+  # Verify first function call
+  assert final_response.content.parts[0].function_call.name == "function_1"
+  assert final_response.content.parts[0].function_call.id == "0"
+  assert final_response.content.parts[0].function_call.args == {"arg": "value1"}
+
+  # Verify second function call
+  assert final_response.content.parts[1].function_call.name == "function_2"
+  assert final_response.content.parts[1].function_call.id == "1"
+  assert final_response.content.parts[1].function_call.args == {"arg": "value2"}
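
Assuming the test extras from pyproject.toml (including litellm>=1.71.2) are installed, the new regression case can be run on its own with: pytest tests/unittests/models/test_litellm.py -k non_compliant_multiple_function_calls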
