Commit 4ccda99 · pandasanjay/adk-python
seanzhougoogle authored and copybara-github committed
fix: Merge custom http options with adk specific http options in model api request
PiperOrigin-RevId: 770836112
1 parent d22920b commit 4ccda99
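
In short, the commit stops caller-supplied http options from being dropped on the model API request: custom headers are kept, and ADK's tracking headers win whenever a key collides, because the merge is a plain dict.update. A minimal standalone sketch of that merge semantics (the header names and values below are made up for illustration; the real values come from Gemini._tracking_headers):

    # Illustrative values only; not the actual ADK tracking headers.
    custom_headers = {"custom-header": "custom-value", "x-shared-key": "user-set"}
    tracking_headers = {"x-shared-key": "adk-set"}

    merged = dict(custom_headers)    # start from the caller's headers
    merged.update(tracking_headers)  # dict.update: ADK wins on key conflicts

    assert merged == {"custom-header": "custom-value", "x-shared-key": "adk-set"}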

File tree

2 files changed: +282 −14 lines changed

src/google/adk/models/google_llm.py
Lines changed: 33 additions & 14 deletions

@@ -95,6 +95,13 @@ async def generate_content_async(
     )
     logger.info(_build_request_log(llm_request))
 
+    # add tracking headers to custom headers given it will override the headers
+    # set in the api client constructor
+    if llm_request.config and llm_request.config.http_options:
+      if not llm_request.config.http_options.headers:
+        llm_request.config.http_options.headers = {}
+      llm_request.config.http_options.headers.update(self._tracking_headers)
+
     if stream:
       responses = await self.api_client.aio.models.generate_content_stream(
           model=llm_request.model,
@@ -201,24 +208,21 @@ def _tracking_headers(self) -> dict[str, str]:
     return tracking_headers
 
   @cached_property
-  def _live_api_client(self) -> Client:
+  def _live_api_version(self) -> str:
     if self._api_backend == GoogleLLMVariant.VERTEX_AI:
       # use beta version for vertex api
-      api_version = 'v1beta1'
-      # use default api version for vertex
-      return Client(
-          http_options=types.HttpOptions(
-              headers=self._tracking_headers, api_version=api_version
-          )
-      )
+      return 'v1beta1'
     else:
       # use v1alpha for using API KEY from Google AI Studio
-      api_version = 'v1alpha'
-      return Client(
-          http_options=types.HttpOptions(
-              headers=self._tracking_headers, api_version=api_version
-          )
-      )
+      return 'v1alpha'
+
+  @cached_property
+  def _live_api_client(self) -> Client:
+    return Client(
+        http_options=types.HttpOptions(
+            headers=self._tracking_headers, api_version=self._live_api_version
+        )
+    )
 
   @contextlib.asynccontextmanager
   async def connect(self, llm_request: LlmRequest) -> BaseLlmConnection:
@@ -230,6 +234,21 @@ async def connect(self, llm_request: LlmRequest) -> BaseLlmConnection:
     Yields:
       BaseLlmConnection, the connection to the Gemini model.
     """
+    # add tracking headers to custom headers and set api_version given
+    # the customized http options will override the one set in the api client
+    # constructor
+    if (
+        llm_request.live_connect_config
+        and llm_request.live_connect_config.http_options
+    ):
+      if not llm_request.live_connect_config.http_options.headers:
+        llm_request.live_connect_config.http_options.headers = {}
+      llm_request.live_connect_config.http_options.headers.update(
+          self._tracking_headers
+      )
+      llm_request.live_connect_config.http_options.api_version = (
+          self._live_api_version
+      )
 
     llm_request.live_connect_config.system_instruction = types.Content(
         role='system',
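
For context, the request shape this change targets looks roughly like the following sketch. The model name and header are placeholders, and the LlmRequest import path is assumed from this repo's layout:

    from google.adk.models.llm_request import LlmRequest
    from google.genai import types

    llm_request = LlmRequest(
        model="gemini-2.0-flash",  # placeholder model name
        config=types.GenerateContentConfig(
            http_options=types.HttpOptions(
                headers={"custom-header": "custom-value"}  # caller-supplied
            )
        ),
    )
    # After the merge step in generate_content_async, the request's
    # http_options.headers hold the custom header plus the tracking headers.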

tests/unittests/models/test_google_llm.py
Lines changed: 249 additions & 0 deletions

@@ -341,6 +341,255 @@ async def __aexit__(self, *args):
       assert connection is mock_connection
 
 
+@pytest.mark.asyncio
+async def test_generate_content_async_with_custom_headers(
+    gemini_llm, llm_request, generate_content_response
+):
+  """Test that tracking headers are updated when custom headers are provided."""
+  # Add custom headers to the request config
+  custom_headers = {"custom-header": "custom-value"}
+  for key in gemini_llm._tracking_headers:
+    custom_headers[key] = "custom " + gemini_llm._tracking_headers[key]
+  llm_request.config.http_options = types.HttpOptions(headers=custom_headers)
+
+  with mock.patch.object(gemini_llm, "api_client") as mock_client:
+    # Create a mock coroutine that returns the generate_content_response
+    async def mock_coro():
+      return generate_content_response
+
+    mock_client.aio.models.generate_content.return_value = mock_coro()
+
+    responses = [
+        resp
+        async for resp in gemini_llm.generate_content_async(
+            llm_request, stream=False
+        )
+    ]
+
+    # Verify that the config passed to generate_content contains merged headers
+    mock_client.aio.models.generate_content.assert_called_once()
+    call_args = mock_client.aio.models.generate_content.call_args
+    config_arg = call_args.kwargs["config"]
+
+    for key, value in config_arg.http_options.headers.items():
+      if key in gemini_llm._tracking_headers:
+        assert value == gemini_llm._tracking_headers[key]
+      else:
+        assert value == custom_headers[key]
+
+    assert len(responses) == 1
+    assert isinstance(responses[0], LlmResponse)
+
+
+@pytest.mark.asyncio
+async def test_generate_content_async_stream_with_custom_headers(
+    gemini_llm, llm_request
+):
+  """Test that tracking headers are updated when custom headers are provided in streaming mode."""
+  # Add custom headers to the request config
+  custom_headers = {"custom-header": "custom-value"}
+  llm_request.config.http_options = types.HttpOptions(headers=custom_headers)
+
+  with mock.patch.object(gemini_llm, "api_client") as mock_client:
+    # Create mock stream responses
+    class MockAsyncIterator:
+
+      def __init__(self, seq):
+        self.iter = iter(seq)
+
+      def __aiter__(self):
+        return self
+
+      async def __anext__(self):
+        try:
+          return next(self.iter)
+        except StopIteration:
+          raise StopAsyncIteration
+
+    mock_responses = [
+        types.GenerateContentResponse(
+            candidates=[
+                types.Candidate(
+                    content=Content(
+                        role="model", parts=[Part.from_text(text="Hello")]
+                    ),
+                    finish_reason=types.FinishReason.STOP,
+                )
+            ]
+        )
+    ]
+
+    async def mock_coro():
+      return MockAsyncIterator(mock_responses)
+
+    mock_client.aio.models.generate_content_stream.return_value = mock_coro()
+
+    responses = [
+        resp
+        async for resp in gemini_llm.generate_content_async(
+            llm_request, stream=True
+        )
+    ]
+
+    # Verify that the config passed to generate_content_stream contains merged headers
+    mock_client.aio.models.generate_content_stream.assert_called_once()
+    call_args = mock_client.aio.models.generate_content_stream.call_args
+    config_arg = call_args.kwargs["config"]
+
+    expected_headers = custom_headers.copy()
+    expected_headers.update(gemini_llm._tracking_headers)
+    assert config_arg.http_options.headers == expected_headers
+
+    assert len(responses) == 2
+
+
+@pytest.mark.asyncio
+async def test_generate_content_async_without_custom_headers(
+    gemini_llm, llm_request, generate_content_response
+):
+  """Test that tracking headers are not modified when no custom headers exist."""
+  # Ensure no http_options exist initially
+  llm_request.config.http_options = None
+
+  with mock.patch.object(gemini_llm, "api_client") as mock_client:
+
+    async def mock_coro():
+      return generate_content_response
+
+    mock_client.aio.models.generate_content.return_value = mock_coro()
+
+    responses = [
+        resp
+        async for resp in gemini_llm.generate_content_async(
+            llm_request, stream=False
+        )
+    ]
+
+    # Verify that the config passed to generate_content has no http_options
+    mock_client.aio.models.generate_content.assert_called_once()
+    call_args = mock_client.aio.models.generate_content.call_args
+    config_arg = call_args.kwargs["config"]
+    assert config_arg.http_options is None
+
+    assert len(responses) == 1
+
+
+def test_live_api_version_vertex_ai(gemini_llm):
+  """Test that _live_api_version returns 'v1beta1' for Vertex AI backend."""
+  with mock.patch.object(
+      gemini_llm, "_api_backend", GoogleLLMVariant.VERTEX_AI
+  ):
+    assert gemini_llm._live_api_version == "v1beta1"
+
+
+def test_live_api_version_gemini_api(gemini_llm):
+  """Test that _live_api_version returns 'v1alpha' for Gemini API backend."""
+  with mock.patch.object(
+      gemini_llm, "_api_backend", GoogleLLMVariant.GEMINI_API
+  ):
+    assert gemini_llm._live_api_version == "v1alpha"
+
+
+def test_live_api_client_properties(gemini_llm):
+  """Test that _live_api_client is properly configured with tracking headers and API version."""
+  with mock.patch.object(
+      gemini_llm, "_api_backend", GoogleLLMVariant.VERTEX_AI
+  ):
+    client = gemini_llm._live_api_client
+
+    # Verify that the client has the correct headers and API version
+    http_options = client._api_client._http_options
+    assert http_options.api_version == "v1beta1"
+
+    # Check that tracking headers are included
+    tracking_headers = gemini_llm._tracking_headers
+    for key, value in tracking_headers.items():
+      assert key in http_options.headers
+      assert value in http_options.headers[key]
+
+
+@pytest.mark.asyncio
+async def test_connect_with_custom_headers(gemini_llm, llm_request):
+  """Test that connect method updates tracking headers and API version when custom headers are provided."""
+  # Setup request with live connect config and custom headers
+  custom_headers = {"custom-live-header": "live-value"}
+  llm_request.live_connect_config = types.LiveConnectConfig(
+      http_options=types.HttpOptions(headers=custom_headers)
+  )
+
+  mock_live_session = mock.AsyncMock()
+
+  # Mock the _live_api_client to return a mock client
+  with mock.patch.object(gemini_llm, "_live_api_client") as mock_live_client:
+    # Create a mock context manager
+    class MockLiveConnect:
+
+      async def __aenter__(self):
+        return mock_live_session
+
+      async def __aexit__(self, *args):
+        pass
+
+    mock_live_client.aio.live.connect.return_value = MockLiveConnect()
+
+    async with gemini_llm.connect(llm_request) as connection:
+      # Verify that the connect method was called with the right config
+      mock_live_client.aio.live.connect.assert_called_once()
+      call_args = mock_live_client.aio.live.connect.call_args
+      config_arg = call_args.kwargs["config"]
+
+      # Verify that tracking headers were merged with custom headers
+      expected_headers = custom_headers.copy()
+      expected_headers.update(gemini_llm._tracking_headers)
+      assert config_arg.http_options.headers == expected_headers
+
+      # Verify that API version was set
+      assert config_arg.http_options.api_version == gemini_llm._live_api_version
+
+      # Verify that system instruction and tools were set
+      assert config_arg.system_instruction is not None
+      assert config_arg.tools == llm_request.config.tools
+
+      # Verify connection is properly wrapped
+      assert isinstance(connection, GeminiLlmConnection)
+
+
+@pytest.mark.asyncio
+async def test_connect_without_custom_headers(gemini_llm, llm_request):
+  """Test that connect method works properly when no custom headers are provided."""
+  # Setup request with live connect config but no custom headers
+  llm_request.live_connect_config = types.LiveConnectConfig()
+
+  mock_live_session = mock.AsyncMock()
+
+  with mock.patch.object(gemini_llm, "_live_api_client") as mock_live_client:
+
+    class MockLiveConnect:
+
+      async def __aenter__(self):
+        return mock_live_session
+
+      async def __aexit__(self, *args):
+        pass
+
+    mock_live_client.aio.live.connect.return_value = MockLiveConnect()
+
+    async with gemini_llm.connect(llm_request) as connection:
+      # Verify that the connect method was called with the right config
+      mock_live_client.aio.live.connect.assert_called_once()
+      call_args = mock_live_client.aio.live.connect.call_args
+      config_arg = call_args.kwargs["config"]
+
+      # Verify that http_options remains None since no custom headers were provided
+      assert config_arg.http_options is None
+
+      # Verify that system instruction and tools were still set
+      assert config_arg.system_instruction is not None
+      assert config_arg.tools == llm_request.config.tools
+
+      assert isinstance(connection, GeminiLlmConnection)
+
+
 @pytest.mark.parametrize(
     (
         "api_backend, "

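The live path mirrors the generate path, with one extra step: connect() also forces http_options.api_version to _live_api_version, since custom http options would otherwise override the version baked into _live_api_client. A sketch of the input side these connect tests exercise (the header is a placeholder):

    from google.genai import types

    live_connect_config = types.LiveConnectConfig(
        http_options=types.HttpOptions(
            headers={"custom-live-header": "live-value"}  # caller-supplied
        )
    )
    # connect() updates these headers with the tracking headers and sets
    # http_options.api_version to 'v1beta1' (Vertex AI) or 'v1alpha'
    # (Google AI Studio API key).
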
0 commit comments