fix: Support for parallel tool calling using SSE with LiteLLM by lucasnobre212 · Pull Request #759 · google/adk-python · GitHub

fix: Support for parallel tool calling using SSE with LiteLLM #759


Closed
wants to merge 3 commits
50 changes: 30 additions & 20 deletions src/google/adk/models/lite_llm.py
@@ -62,6 +62,7 @@ class FunctionChunk(BaseModel):
id: Optional[str]
name: Optional[str]
args: Optional[str]
index: Optional[int] = 0


class TextChunk(BaseModel):
@@ -386,6 +387,7 @@ def _model_response_to_chunk(
id=tool_call.id,
name=tool_call.function.name,
args=tool_call.function.arguments,
index=tool_call.index,
), finish_reason

if finish_reason and not (
@@ -661,9 +663,8 @@ async def generate_content_async(

if stream:
text = ""
function_name = ""
function_args = ""
function_id = None
# Track function calls by index
function_calls = {} # index -> {name, args, id}
completion_args["stream"] = True
aggregated_llm_response = None
aggregated_llm_response_with_tool_call = None
@@ -672,11 +673,17 @@ async def generate_content_async(
for part in self.llm_client.completion(**completion_args):
for chunk, finish_reason in _model_response_to_chunk(part):
if isinstance(chunk, FunctionChunk):
index = chunk.index or 0
if index not in function_calls:
function_calls[index] = {"name": "", "args": "", "id": None}

if chunk.name:
function_name += chunk.name
function_calls[index]["name"] += chunk.name
if chunk.args:
function_args += chunk.args
function_id = chunk.id or function_id
function_calls[index]["args"] += chunk.args
function_calls[index]["id"] = (
chunk.id or function_calls[index]["id"]
)
elif isinstance(chunk, TextChunk):
text += chunk.text
yield _message_to_generate_content_response(
@@ -693,28 +700,31 @@ async def generate_content_async(
total_token_count=chunk.total_tokens,
)

if finish_reason == "tool_calls" and function_id:
if finish_reason == "tool_calls" and function_calls:
tool_calls = []
for index, func_data in function_calls.items():
if func_data["id"]:
tool_calls.append(
ChatCompletionMessageToolCall(
type="function",
id=func_data["id"],
function=Function(
name=func_data["name"],
arguments=func_data["args"],
index=index,
),
)
)
aggregated_llm_response_with_tool_call = (
_message_to_generate_content_response(
ChatCompletionAssistantMessage(
role="assistant",
content="",
tool_calls=[
ChatCompletionMessageToolCall(
type="function",
id=function_id,
function=Function(
name=function_name,
arguments=function_args,
),
)
],
tool_calls=tool_calls,
)
)
)
function_name = ""
function_args = ""
function_id = None
function_calls.clear()
elif finish_reason == "stop" and text:
aggregated_llm_response = _message_to_generate_content_response(
ChatCompletionAssistantMessage(role="assistant", content=text)
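For readers skimming the diff, here is a minimal standalone sketch of the accumulation strategy it implements: streamed tool-call fragments are grouped by their index field, so interleaved deltas from parallel calls are reassembled into separate, complete calls. The ToolCallDelta type and accumulate_tool_calls helper below are illustrative stand-ins, not ADK or LiteLLM types; the example data mirrors the interleaving exercised by the new test further down.

from dataclasses import dataclass
from typing import Dict, Optional


@dataclass
class ToolCallDelta:
  """Illustrative stand-in for one streamed tool-call fragment."""

  index: int
  id: Optional[str] = None
  name: Optional[str] = None
  args: Optional[str] = None


def accumulate_tool_calls(deltas) -> Dict[int, dict]:
  """Group fragments by index and concatenate their partial fields."""
  calls: Dict[int, dict] = {}
  for delta in deltas:
    entry = calls.setdefault(delta.index, {"id": None, "name": "", "args": ""})
    if delta.id:
      entry["id"] = delta.id
    if delta.name:
      entry["name"] += delta.name
    if delta.args:
      entry["args"] += delta.args
  return calls


# Interleaved fragments from two parallel calls are reassembled independently.
deltas = [
    ToolCallDelta(index=0, id="call_1", name="function_1", args='{"arg": "val'),
    ToolCallDelta(index=1, id="call_2", name="function_2", args='{"arg": "val'),
    ToolCallDelta(index=0, args='ue1"}'),
    ToolCallDelta(index=1, args='ue2"}'),
]
assert accumulate_tool_calls(deltas)[0]["args"] == '{"arg": "value1"}'
assert accumulate_tool_calls(deltas)[1]["args"] == '{"arg": "value2"}'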
172 changes: 172 additions & 0 deletions tests/unittests/models/test_litellm.py
@@ -38,6 +38,7 @@
from litellm.types.utils import ModelResponse
from litellm.types.utils import StreamingChoices
import pytest
import json

LLM_REQUEST_WITH_FUNCTION_DECLARATION = LlmRequest(
contents=[
@@ -170,6 +171,100 @@
),
]

MULTIPLE_FUNCTION_CALLS_STREAM = [
ModelResponse(
choices=[
StreamingChoices(
finish_reason=None,
delta=Delta(
role="assistant",
tool_calls=[
ChatCompletionDeltaToolCall(
type="function",
id="call_1",
function=Function(
name="function_1",
arguments='{"arg": "val',
),
index=0,
)
],
),
)
]
),
ModelResponse(
choices=[
StreamingChoices(
finish_reason=None,
delta=Delta(
role="assistant",
tool_calls=[
ChatCompletionDeltaToolCall(
type="function",
id=None,
function=Function(
name=None,
arguments='ue1"}',
),
index=0,
)
],
),
)
]
),
ModelResponse(
choices=[
StreamingChoices(
finish_reason=None,
delta=Delta(
role="assistant",
tool_calls=[
ChatCompletionDeltaToolCall(
type="function",
id="call_2",
function=Function(
name="function_2",
arguments='{"arg": "val',
),
index=1,
)
],
),
)
]
),
ModelResponse(
choices=[
StreamingChoices(
finish_reason=None,
delta=Delta(
role="assistant",
tool_calls=[
ChatCompletionDeltaToolCall(
type="function",
id=None,
function=Function(
name=None,
arguments='ue2"}',
),
index=1,
)
],
),
)
]
),
ModelResponse(
choices=[
StreamingChoices(
finish_reason="tool_calls",
)
]
),
]


@pytest.fixture
def mock_response():
@@ -1089,3 +1184,80 @@ async def test_generate_content_async_stream_with_usage_metadata(
]
== "string"
)


@pytest.mark.asyncio
async def test_generate_content_async_multiple_function_calls(
mock_completion, lite_llm_instance
):
"""Test handling of multiple function calls with different indices in streaming mode.

This test verifies that:
1. Multiple function calls with different indices are handled correctly
2. Arguments and names are properly accumulated for each function call
3. The final response contains all function calls with correct indices
"""
mock_completion.return_value = MULTIPLE_FUNCTION_CALLS_STREAM

llm_request = LlmRequest(
contents=[
types.Content(
role="user",
parts=[types.Part.from_text(text="Test multiple function calls")],
)
],
config=types.GenerateContentConfig(
tools=[
types.Tool(
function_declarations=[
types.FunctionDeclaration(
name="function_1",
description="First test function",
parameters=types.Schema(
type=types.Type.OBJECT,
properties={
"arg": types.Schema(type=types.Type.STRING),
},
),
),
types.FunctionDeclaration(
name="function_2",
description="Second test function",
parameters=types.Schema(
type=types.Type.OBJECT,
properties={
"arg": types.Schema(type=types.Type.STRING),
},
),
),
]
)
],
),
)

responses = []
async for response in lite_llm_instance.generate_content_async(
llm_request, stream=True
):
responses.append(response)

# Verify we got the final response with both function calls
assert len(responses) > 0
final_response = responses[-1]
assert final_response.content.role == "model"
assert len(final_response.content.parts) == 2

# Verify first function call
assert final_response.content.parts[0].function_call.name == "function_1"
assert final_response.content.parts[0].function_call.id == "call_1"
assert final_response.content.parts[0].function_call.args == {
"arg": "value1"
}

# Verify second function call
assert final_response.content.parts[1].function_call.name == "function_2"
assert final_response.content.parts[1].function_call.id == "call_2"
assert final_response.content.parts[1].function_call.args == {
"arg": "value2"
}
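As a rough consumer-side sketch (assuming a lite_llm_instance and llm_request configured as in the test above), the parallel calls can be read back from the final streamed response; the attribute names follow the test's assertions rather than a separately documented API.

# Sketch only: assumes `lite_llm_instance` and `llm_request` are set up as in
# the test above; the last streamed response aggregates the parallel tool calls.
async def collect_parallel_calls(lite_llm_instance, llm_request):
  final_response = None
  async for response in lite_llm_instance.generate_content_async(
      llm_request, stream=True
  ):
    final_response = response
  return [
      (part.function_call.name, part.function_call.args)
      for part in final_response.content.parts
      if part.function_call
  ]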