8000 Merge pull request #1 from Praznat/fix-malformed-function-calls · Praznat/adk-python@9c2b523 · GitHub
[go: up one dir, main page]

Skip to content

Commit 9c2b523

Browse files
authored
Merge pull request #1 from Praznat/fix-malformed-function-calls
Fix malformed function calls
2 parents d3910e7 + 9ffe2e2 commit 9c2b523

File tree

1 file changed

+44
-0
lines changed

1 file changed

+44
-0
lines changed

src/google/adk/models/google_llm.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515

1616
from __future__ import annotations
1717

18+
import re
19+
import ast
1820
import contextlib
1921
from functools import cached_property
2022
import logging
@@ -42,6 +44,46 @@
4244
_EXCLUDED_PART_FIELD = {'inline_data': {'data'}}
4345

4446

47+
def _extract_function_call(input_string):
48+
""" Use regex to convert a function call string into a FunctionCall object. """
49+
pattern = r"Malformed function call:\s*(\w+)\((.*)\)"
50+
match = re.search(pattern, input_string)
51+
if match:
52+
func_name = match.group(1)
53+
args_str = match.group(2).strip()
54+
55+
# Create a dummy function call with the captured arguments.
56+
# This will allow us to use ast to parse the function call.
57+
dummy_call = f"dummy({args_str})"
58+
59+
# Parse the dummy function call.
60+
tree = ast.parse(dummy_call, mode='eval')
61+
call_node = tree.body
62+
63+
# Extract keyword arguments (if there are any).
64+
args_dict = {kw.arg: ast.literal_eval(kw.value) for kw in call_node.keywords}
65+
66+
return types.FunctionCall(args=args_dict, name=func_name)
67+
else:
68+
return None
69+
70+
71+
def _fix_malformed_function_calls(response):
72+
""" Check if there's malformed error, create FunctionCall object using args.
73+
Then remove the error and insert the FunctionCall into the response.
74+
"""
75+
for candidate in response.candidates:
76+
if candidate.finish_reason == types.FinishReason.MALFORMED_FUNCTION_CALL:
77+
function_call = _extract_function_call(candidate.finish_message)
78+
if function_call is None:
79+
logging.warning("could not parse function call: %s", candidate.finish_message)
80+
continue
81+
logging.warning("malformed function call caught and overwritten: %s", candidate.finish_message)
82+
candidate.content = types.Content(parts=[types.Part(function_call=function_call)], role="model")
83+
candidate.finish_message = None
84+
candidate.finish_reason = types.FinishReason.STOP
85+
86+
4587
class Gemini(BaseLlm):
4688
"""Integration for Gemini models.
4789
@@ -105,6 +147,7 @@ async def generate_content_async(
105147
# previous partial content. The only difference is bidi rely on
106148
# complete_turn flag to detect end while sse depends on finish_reason.
107149
async for response in responses:
150+
_fix_malformed_function_calls(response)
108151
logger.info(_build_response_log(response))
109152
llm_response = LlmResponse.create(response)
110153
usage_metadata = llm_response.usage_metadata
@@ -148,6 +191,7 @@ async def generate_content_async(
148191
contents=llm_request.contents,
149192
config=llm_request.config,
150193
)
194+
_fix_malformed_function_calls(response)
151195
logger.info(_build_response_log(response))
152196
yield LlmResponse.create(response)
153197

0 commit comments

Comments
 (0)
0