diff --git a/contributing/samples/quickstart/agent.py b/contributing/samples/quickstart/agent.py
index fdd6b7f9d..b251069ad 100644
--- a/contributing/samples/quickstart/agent.py
+++ b/contributing/samples/quickstart/agent.py
@@ -29,7 +29,7 @@ def get_weather(city: str) -> dict:
         "status": "success",
         "report": (
             "The weather in New York is sunny with a temperature of 25 degrees"
-            " Celsius (41 degrees Fahrenheit)."
+            " Celsius (77 degrees Fahrenheit)."
         ),
     }
   else:
diff --git a/src/google/adk/models/lite_llm.py b/src/google/adk/models/lite_llm.py
index e34299f6f..dce5ed7c4 100644
--- a/src/google/adk/models/lite_llm.py
+++ b/src/google/adk/models/lite_llm.py
@@ -23,7 +23,6 @@
 from typing import Dict
 from typing import Generator
 from typing import Iterable
-from typing import List
 from typing import Literal
 from typing import Optional
 from typing import Tuple
@@ -482,22 +481,16 @@ def _message_to_generate_content_response(
 
 def _get_completion_inputs(
     llm_request: LlmRequest,
-) -> Tuple[
-    List[Message],
-    Optional[List[dict]],
-    Optional[types.SchemaUnion],
-    Optional[Dict],
-]:
-  """Converts an LlmRequest to litellm inputs and extracts generation params.
+) -> tuple[Iterable[Message], Iterable[dict]]:
+  """Converts an LlmRequest to litellm inputs.
 
   Args:
     llm_request: The LlmRequest to convert.
 
   Returns:
-    The litellm inputs (message list, tool dictionary, response format and generation params).
+    The litellm inputs (message list, tool dictionary and response format).
   """
-  # 1. Construct messages
-  messages: List[Message] = []
+  messages = []
   for content in llm_request.contents or []:
     message_param_or_list = _content_to_message_param(content)
     if isinstance(message_param_or_list, list):
@@ -514,8 +507,7 @@ def _get_completion_inputs(
         ),
     )
 
-  # 2. Convert tool declarations
-  tools: Optional[List[Dict]] = None
+  tools = None
   if (
       llm_request.config
       and llm_request.config.tools
@@ -526,39 +518,12 @@ def _get_completion_inputs(
         for tool in llm_request.config.tools[0].function_declarations
     ]
 
-  # 3. Handle response format
-  response_format: Optional[types.SchemaUnion] = None
-  if llm_request.config and llm_request.config.response_schema:
-    response_format = llm_request.config.response_schema
-
-  # 4. Extract generation parameters
-  generation_params: Optional[Dict] = None
-  if llm_request.config:
-    config_dict = llm_request.config.model_dump(exclude_none=True)
-    # Generate LiteLlm parameters here,
-    # Following https://docs.litellm.ai/docs/completion/input.
-    generation_params = {}
-    param_mapping = {
-        "max_output_tokens": "max_completion_tokens",
-        "stop_sequences": "stop",
-    }
-    for key in (
-        "temperature",
-        "max_output_tokens",
-        "top_p",
-        "top_k",
-        "stop_sequences",
-        "presence_penalty",
-        "frequency_penalty",
-    ):
-      if key in config_dict:
-        mapped_key = param_mapping.get(key, key)
-        generation_params[mapped_key] = config_dict[key]
+  response_format = None
 
-    if not generation_params:
-      generation_params = None
+  if llm_request.config.response_schema:
+    response_format = llm_request.config.response_schema
 
-  return messages, tools, response_format, generation_params
+  return messages, tools, response_format
 
 
 def _build_function_declaration_log(
@@ -695,9 +660,7 @@ async def generate_content_async(
     self._maybe_append_user_content(llm_request)
     logger.debug(_build_request_log(llm_request))
 
-    messages, tools, response_format, generation_params = (
-        _get_completion_inputs(llm_request)
-    )
+    messages, tools, response_format = _get_completion_inputs(llm_request)
 
     completion_args = {
         "model": self.model,
@@ -707,9 +670,6 @@ async def generate_content_async(
     }
     completion_args.update(self._additional_args)
 
-    if generation_params:
-      completion_args.update(generation_params)
-
    if stream:
       text = ""
       # Track function calls by index
diff --git a/tests/unittests/models/test_litellm.py b/tests/unittests/models/test_litellm.py
index 0125872fd..8b43cc48b 100644
--- a/tests/unittests/models/test_litellm.py
+++ b/tests/unittests/models/test_litellm.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 
+import json
 from unittest.mock import AsyncMock
 from unittest.mock import Mock
 
@@ -1429,35 +1430,3 @@ async def test_generate_content_async_non_compliant_multiple_function_calls(
   assert final_response.content.parts[1].function_call.name == "function_2"
   assert final_response.content.parts[1].function_call.id == "1"
   assert final_response.content.parts[1].function_call.args == {"arg": "value2"}
-
-
-@pytest.mark.asyncio
-def test_get_completion_inputs_generation_params():
-  # Test that generation_params are extracted and mapped correctly
-  req = LlmRequest(
-      contents=[
-          types.Content(role="user", parts=[types.Part.from_text(text="hi")]),
-      ],
-      config=types.GenerateContentConfig(
-          temperature=0.33,
-          max_output_tokens=123,
-          top_p=0.88,
-          top_k=7,
-          stop_sequences=["foo", "bar"],
-          presence_penalty=0.1,
-          frequency_penalty=0.2,
-      ),
-  )
-  from google.adk.models.lite_llm import _get_completion_inputs
-
-  _, _, _, generation_params = _get_completion_inputs(req)
-  assert generation_params["temperature"] == 0.33
-  assert generation_params["max_completion_tokens"] == 123
-  assert generation_params["top_p"] == 0.88
-  assert generation_params["top_k"] == 7
-  assert generation_params["stop"] == ["foo", "bar"]
-  assert generation_params["presence_penalty"] == 0.1
-  assert generation_params["frequency_penalty"] == 0.2
-  # Should not include max_output_tokens
-  assert "max_output_tokens" not in generation_params
-  assert "stop_sequences" not in generation_params
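Illustrative note (not part of the patch): with the automatic GenerateContentConfig-to-litellm parameter mapping removed above, one way to keep supplying litellm generation parameters is through LiteLlm's extra constructor arguments, which the wrapper already merges into completion_args via completion_args.update(self._additional_args). A minimal sketch under that assumption; the agent name, model string, and parameter values below are placeholders, and the assumption is that unrecognized LiteLlm kwargs are forwarded to the litellm completion call unchanged:

from google.adk.agents import Agent
from google.adk.models.lite_llm import LiteLlm

# Sketch under the assumption above: extra kwargs are stored in
# self._additional_args and passed straight to litellm, so litellm's own
# parameter names apply (max_completion_tokens / stop, not the
# GenerateContentConfig names max_output_tokens / stop_sequences).
root_agent = Agent(
    name="weather_agent",                # placeholder
    model=LiteLlm(
        model="openai/gpt-4o",           # placeholder litellm model string
        temperature=0.3,                 # passed through as-is
        max_completion_tokens=123,       # litellm name, not max_output_tokens
        stop=["foo", "bar"],             # litellm name, not stop_sequences
    ),
    instruction="Answer weather questions briefly.",
)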