8000 Add parallel tool call support for litellm (#172) · JJGuilloryGit/adk-python@9a44831 · GitHub
[go: up one dir, main page]

Skip to content

Commit 9a44831

Browse files
selcukgunhangfei
andauthored
Add parallel tool call support for litellm (google#172)
Co-authored-by: Hangfei Lin <hangfei@google.com>
1 parent 4e8b944 commit 9a44831

File tree

2 files changed

+78
-33
lines changed

2 files changed

+78
-33
lines changed

src/google/adk/models/lite_llm.py

Lines changed: 48 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -136,49 +136,61 @@ def _safe_json_serialize(obj) -> str:
136136

137137
def _content_to_message_param(
138138
content: types.Content,
139-
) -> Message:
140-
"""Converts a types.Content to a litellm Message.
139+
) -> Union[Message, list[Message]]:
140+
"""Converts a types.Content to a litellm Message or list of Messages.
141+
142+
Handles multipart function responses by returning a list of
143+
ChatCompletionToolMessage objects if multiple function_response parts exist.
141144
142145
Args:
143146
content: The content to convert.
144147
145148
Returns:
146-
The litellm Message.
149+
A litellm Message, a list of litellm Messages.
147150
"""
148151

149-
if content.parts and content.parts[0].function_response:
150-
return ChatCompletionToolMessage(
151-
role="tool",
152-
tool_call_id=content.parts[0].function_response.id,
153-
content=_safe_json_serialize(
154-
content.parts[0].function_response.response
155-
),
156-
)
152+
tool_messages = []
153+
for part in content.parts:
154+
if part.function_response:
155+
tool_messages.append(
156+
ChatCompletionToolMessage(
157+
role="tool",
158+
tool_call_id=part.function_response.id,
159+
content=_safe_json_serialize(part.function_response.response),
160+
)
161+
)
162+
if tool_messages:
163+
return tool_messages if len(tool_messages) > 1 else tool_messages[0]
157164

165+
# Handle user or assistant messages
158166
role = _to_litellm_role(content.role)
167+
message_content = _get_content(content.parts) or None
159168

160169
if role == "user":
161-
return ChatCompletionUserMessage(
162-
role="user", content=_get_content(content.parts)
163-
)
164-
else:
165-
166-
tool_calls = [
167-
ChatCompletionMessageToolCall(
168-
type="function",
169-
id=part.function_call.id,
170-
function=Function(
171-
name=part.function_call.name,
172-
arguments=part.function_call.args,
173-
),
174-
)
175-
for part in content.parts
176-
if part.function_call
177-
]
170+
return ChatCompletionUserMessage(role="user", content=message_content)
171+
else: # assistant/model
172+
tool_calls = []
173+
content_present = False
174+
for part in content.parts:
175+
if part.function_call:
176+
tool_calls.append(
177+
ChatCompletionMessageToolCall(
178+
type="function",
179+
id=part.function_call.id,
180+
function=Function(
181+
name=part.function_call.name,
182+
arguments=part.function_call.args,
183+
),
184+
)
185+
)
186+
elif part.text or part.inline_data:
187+
content_present = True
188+
189+
final_content = message_content if content_present else None
178190

179191
return ChatCompletionAssistantMessage(
180192
role=role,
181-
content=_get_content(content.parts) or None,
193+
content=final_content,
182194
tool_calls=tool_calls or None,
183195
)
184196

@@ -437,10 +449,13 @@ def _get_completion_inputs(
437449
Returns:
438450
The litellm inputs (message list and tool dictionary).
439451
"""
440-
messages = [
441-
_content_to_message_param(content)
442-
for content in llm_request.contents or []
443-
]
452+
messages = []
453+
for content in llm_request.contents or []:
454+
message_param_or_list = _content_to_message_param(content)
455+
if isinstance(message_param_or_list, list):
456+
messages.extend(message_param_or_list)
457+
elif message_param_or_list: # Ensure it's not None before appending
458+
messages.append(message_param_or_list)
444459

445460
if llm_request.config.system_instruction:
446461
messages.insert(

tests/unittests/models/test_litellm.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -515,6 +515,36 @@ def test_content_to_message_param_user_message():
515515
assert message["content"] == "Test prompt"
516516

517517

518+
def test_content_to_message_param_multi_part_function_response():
519+
part1 = types.Part.from_function_response(
520+
name="function_one",
521+
response={"result": "result_one"},
522+
)
523+
part1.function_response.id = "tool_call_1"
524+
525+
part2 = types.Part.from_function_response(
526+
name="function_two",
527+
response={"value": 123},
528+
)
529+
part2.function_response.id = "tool_call_2"
530+
531+
content = types.Content(
532+
role="tool",
533+
parts=[part1, part2],
534+
)
535+
messages = _content_to_message_param(content)
536+
assert isinstance(messages, list)
537+
assert len(messages) == 2
538+
539+
assert messages[0]["role"] == "tool"
540+
assert messages[0]["tool_call_id"] == "tool_call_1"
541+
assert messages[0]["content"] == '{"result": "result_one"}'
542+
543+
assert messages[1]["role"] == "tool"
544+
assert messages[1]["tool_call_id"] == "tool_call_2"
545+
assert messages[1]["content"] == '{"value": 123}'
546+
547+
518548
def test_content_to_message_param_assistant_message():
519549
content = types.Content(
520550
role="assistant", parts=[types.Part.from_text(text="Test response")]

0 commit comments

Comments
 (0)
0