8000 Add parallel tool call support for litellm by selcukgun · Pull Request #172 · google/adk-python · GitHub
[go: up one dir, main page]

Skip to content

Add parallel tool call support for litellm #172

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
10000 Loading
Diff view
Diff view
81 changes: 48 additions & 33 deletions src/google/adk/models/lite_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,49 +136,61 @@ def _safe_json_serialize(obj) -> str:

def _content_to_message_param(
content: types.Content,
) -> Message:
"""Converts a types.Content to a litellm Message.
) -> Union[Message, list[Message]]:
"""Converts a types.Content to a litellm Message or list of Messages.

Handles multipart function responses by returning a list of
ChatCompletionToolMessage objects if multiple function_response parts exist.

Args:
content: The content to convert.

Returns:
The litellm Message.
A litellm Message, a list of litellm Messages.
"""

if content.parts and content.parts[0].function_response:
return ChatCompletionToolMessage(
role="tool",
tool_call_id=content.parts[0].function_response.id,
content=_safe_json_serialize(
content.parts[0].function_response.response
),
)
tool_messages = []
for part in content.parts:
if part.function_response:
tool_messages.append(
ChatCompletionToolMessage(
role="tool",
tool_call_id=part.function_response.id,
content=_safe_json_serialize(part.function_response.response),
)
)
if tool_messages:
return tool_messages if len(tool_messages) > 1 else tool_messages[0]

# Handle user or assistant messages
role = _to_litellm_role(content.role)
message_content = _get_content(content.parts) or None

if role == "user":
return ChatCompletionUserMessage(
role="user", content=_get_content(content.parts)
)
else:

tool_calls = [
ChatCompletionMessageToolCall(
type="function",
id=part.function_call.id,
function=Function(
name=part.function_call.name,
arguments=part.function_call.args,
),
)
for part in content.parts
if part.function_call
]
return ChatCompletionUserMessage(role="user", content=message_content)
else: # assistant/model
tool_calls = []
content_present = False
for part in content.parts:
if part.function_call:
tool_calls.append(
ChatCompletionMessageToolCall(
type="function",
id=part.function_call.id,
function=Function(
name=part.function_call.name,
arguments=part.function_call.args,
),
)
)
elif part.text or part.inline_data:
content_present = True

final_content = message_content if content_present else None

return ChatCompletionAssistantMessage(
role=role,
content=_get_content(content.parts) or None,
content=final_content,
tool_calls=tool_calls or None,
)

Expand Down Expand Up @@ -437,10 +449,13 @@ def _get_completion_inputs(
Returns:
The litellm inputs (message list and tool dictionary).
"""
messages = [
_content_to_message_param(content)
for content in llm_request.contents or []
]
messages = []
for content in llm_request.contents or []:
message_param_or_list = _content_to_message_param(content)
if isinstance(message_param_or_list, list):
messages.extend(message_param_or_list)
elif message_param_or_list: # Ensure it's not None before appending
messages.append(message_param_or_list)

if llm_request.config.system_instruction:
messages.insert(
Expand Down
30 changes: 30 additions & 0 deletions tests/unittests/models/test_litellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,6 +515,36 @@ def test_content_to_message_param_user_message():
assert message["content"] == "Test prompt"


def test_content_to_message_param_multi_part_function_response():
part1 = types.Part.from_function_response(
name="function_one",
response={"result": "result_one"},
)
part1.function_response.id = "tool_call_1"

part2 = types.Part.from_function_response(
name="function_two",
response={"value": 123},
)
part2.function_response.id = "tool_call_2"

content = types.Content(
role="tool",
parts=[part1, part2],
)
messages = _content_to_message_param(content)
assert isinstance(messages, list)
assert len(messages) == 2

assert messages[0]["role"] == "tool"
assert messages[0]["tool_call_id"] == "tool_call_1"
assert messages[0]["content"] == '{"result": "result_one"}'

assert messages[1]["role"] == "tool"
assert messages[1]["tool_call_id"] == "tool_call_2"
assert messages[1]["content"] == '{"value": 123}'


def test_content_to_message_param_assistant_message():
content = types.Content(
role="assistant", parts=[types.Part.from_text(text="Test response")]
Expand Down
0