|
15 | 15 |
|
16 | 16 | from __future__ import annotations
|
17 | 17 |
|
| 18 | +import re |
| 19 | +import ast |
18 | 20 | import contextlib
|
19 | 21 | from functools import cached_property
|
20 | 22 | import logging
|
|
42 | 44 | _EXCLUDED_PART_FIELD = {'inline_data': {'data'}}
|
43 | 45 |
|
44 | 46 |
|
| 47 | +def _extract_function_call(input_string): |
| 48 | + """ Use regex to convert a function call string into a FunctionCall object. """ |
| 49 | + pattern = r"Malformed function call:\s*(\w+)\((.*)\)" |
| 50 | + match = re.search(pattern, input_string) |
| 51 | + if match: |
| 52 | + func_name = match.group(1) |
| 53 | + args_str = match.group(2).strip() |
| 54 | + |
| 55 | + # Create a dummy function call with the captured arguments. |
| 56 | + # This will allow us to use ast to parse the function call. |
| 57 | + dummy_call = f"dummy({args_str})" |
| 58 | + |
| 59 | + # Parse the dummy function call. |
| 60 | + tree = ast.parse(dummy_call, mode='eval') |
| 61 | + call_node = tree.body |
| 62 | + |
| 63 | + # Extract keyword arguments (if there are any). |
| 64 | + args_dict = {kw.arg: ast.literal_eval(kw.value) for kw in call_node.keywords} |
| 65 | + |
| 66 | + return types.FunctionCall(args=args_dict, name=func_name) |
| 67 | + else: |
| 68 | + return None |
| 69 | + |
| 70 | + |
| 71 | +def _fix_malformed_function_calls(response): |
| 72 | + """ Check if there's malformed error, create FunctionCall object using args. |
| 73 | + Then remove the error and insert the FunctionCall into the response. |
| 74 | + """ |
| 75 | + for candidate in response.candidates: |
| 76 | + if candidate.finish_reason == types.FinishReason.MALFORMED_FUNCTION_CALL: |
| 77 | + function_call = _extract_function_call(candidate.finish_message) |
| 78 | + if function_call is None: |
| 79 | + logging.warning("could not parse function call: %s", candidate.finish_message) |
| 80 | + continue |
| 81 | + logging.warning("malformed function call caught and overwritten: %s", candidate.finish_message) |
| 82 | + candidate.content = types.Content(parts=[types.Part(function_call=function_call)], role="model") |
| 83 | + candidate.finish_message = None |
| 84 | + candidate.finish_reason = types.FinishReason.STOP |
| 85 | + |
| 86 | + |
45 | 87 | class Gemini(BaseLlm):
|
46 | 88 | """Integration for Gemini models.
|
47 | 89 |
|
@@ -105,6 +147,7 @@ async def generate_content_async(
|
105 | 147 | # previous partial content. The only difference is bidi rely on
|
106 | 148 | # complete_turn flag to detect end while sse depends on finish_reason.
|
107 | 149 | async for response in responses:
|
| 150 | + _fix_malformed_function_calls(response) |
108 | 151 | logger.info(_build_response_log(response))
|
109 | 152 | llm_response = LlmResponse.create(response)
|
110 | 153 | usage_metadata = llm_response.usage_metadata
|
@@ -148,6 +191,7 @@ async def generate_content_async(
|
148 | 191 | contents=llm_request.contents,
|
149 | 192 | config=llm_request.config,
|
150 | 193 | )
|
| 194 | + _fix_malformed_function_calls(response) |
151 | 195 | logger.info(_build_response_log(response))
|
152 | 196 | yield LlmResponse.create(response)
|
153 | 197 |
|
|
0 commit comments