@@ -341,6 +341,255 @@ async def __aexit__(self, *args):
      assert connection is mock_connection


+@pytest.mark.asyncio
+async def test_generate_content_async_with_custom_headers(
+    gemini_llm, llm_request, generate_content_response
+):
+  """Test that tracking headers are updated when custom headers are provided."""
+  # Add custom headers to the request config
+  custom_headers = {"custom-header": "custom-value"}
+  for key in gemini_llm._tracking_headers:
+    custom_headers[key] = "custom " + gemini_llm._tracking_headers[key]
+  llm_request.config.http_options = types.HttpOptions(headers=custom_headers)
+
+  with mock.patch.object(gemini_llm, "api_client") as mock_client:
+    # Create a mock coroutine that returns the generate_content_response
+    async def mock_coro():
+      return generate_content_response
+
+    mock_client.aio.models.generate_content.return_value = mock_coro()
+
+    responses = [
+        resp
+        async for resp in gemini_llm.generate_content_async(
+            llm_request, stream=False
+        )
+    ]
+
+    # Verify that the config passed to generate_content contains merged headers
+    mock_client.aio.models.generate_content.assert_called_once()
+    call_args = mock_client.aio.models.generate_content.call_args
+    config_arg = call_args.kwargs["config"]
+
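+    # Tracking headers should win for overlapping keys, while unrelated custom
+    # headers pass through unchanged.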
+    for key, value in config_arg.http_options.headers.items():
+      if key in gemini_llm._tracking_headers:
+        assert value == gemini_llm._tracking_headers[key]
+      else:
+        assert value == custom_headers[key]
+
+    assert len(responses) == 1
+    assert isinstance(responses[0], LlmResponse)
+
+
+@pytest.mark.asyncio
+async def test_generate_content_async_stream_with_custom_headers(
+    gemini_llm, llm_request
+):
+  """Test that tracking headers are updated when custom headers are provided in streaming mode."""
+  # Add custom headers to the request config
+  custom_headers = {"custom-header": "custom-value"}
+  llm_request.config.http_options = types.HttpOptions(headers=custom_headers)
+
+  with mock.patch.object(gemini_llm, "api_client") as mock_client:
+    # Create mock stream responses
+    class MockAsyncIterator:
+
+      def __init__(self, seq):
+        self.iter = iter(seq)
+
+      def __aiter__(self):
+        return self
+
+      async def __anext__(self):
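+        # next() raises StopIteration on exhaustion; the async iteration
+        # protocol requires StopAsyncIteration instead.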
+        try:
+          return next(self.iter)
+        except StopIteration:
+          raise StopAsyncIteration
+
+    mock_responses = [
+        types.GenerateContentResponse(
+            candidates=[
+                types.Candidate(
+                    content=Content(
+                        role="model", parts=[Part.from_text(text="Hello")]
+                    ),
+                    finish_reason=types.FinishReason.STOP,
+                )
+            ]
+        )
+    ]
+
+    async def mock_coro():
+      return MockAsyncIterator(mock_responses)
+
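+    # The aio client's generate_content_stream is awaited to obtain an async
+    # iterator, so the mock returns a coroutine that resolves to one.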
+    mock_client.aio.models.generate_content_stream.return_value = mock_coro()
+
+    responses = [
+        resp
+        async for resp in gemini_llm.generate_content_async(
+            llm_request, stream=True
+        )
+    ]
+
+    # Verify that the config passed to generate_content_stream contains merged headers
+    mock_client.aio.models.generate_content_stream.assert_called_once()
+    call_args = mock_client.aio.models.generate_content_stream.call_args
+    config_arg = call_args.kwargs["config"]
+
+    expected_headers = custom_headers.copy()
+    expected_headers.update(gemini_llm._tracking_headers)
+    assert config_arg.http_options.headers == expected_headers
+
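+    # The single streamed chunk surfaces as a partial response followed by the
+    # final aggregated response, hence two items.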
+    assert len(responses) == 2
+
+
+@pytest.mark.asyncio
+async def test_generate_content_async_without_custom_headers(
+    gemini_llm, llm_request, generate_content_response
+):
+  """Test that tracking headers are not modified when no custom headers exist."""
+  # Ensure no http_options exist initially
+  llm_request.config.http_options = None
+
+  with mock.patch.object(gemini_llm, "api_client") as mock_client:
+
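+    # generate_content is awaited exactly once, so a single-shot coroutine
+    # is enough here.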
+    async def mock_coro():
+      return generate_content_response
+
+    mock_client.aio.models.generate_content.return_value = mock_coro()
+
+    responses = [
+        resp
+        async for resp in gemini_llm.generate_content_async(
+            llm_request, stream=False
+        )
+    ]
+
+    # Verify that the config passed to generate_content has no http_options
+    mock_client.aio.models.generate_content.assert_called_once()
+    call_args = mock_client.aio.models.generate_content.call_args
+    config_arg = call_args.kwargs["config"]
+    assert config_arg.http_options is None
+
+    assert len(responses) == 1
+
+
+def test_live_api_version_vertex_ai(gemini_llm):
+  """Test that _live_api_version returns 'v1beta1' for Vertex AI backend."""
+  with mock.patch.object(
+      gemini_llm, "_api_backend", GoogleLLMVariant.VERTEX_AI
+  ):
+    assert gemini_llm._live_api_version == "v1beta1"
+
+
+def test_live_api_version_gemini_api(gemini_llm):
+  """Test that _live_api_version returns 'v1alpha' for Gemini API backend."""
+  with mock.patch.object(
+      gemini_llm, "_api_backend", GoogleLLMVariant.GEMINI_API
+  ):
+    assert gemini_llm._live_api_version == "v1alpha"
+
+
+def test_live_api_client_properties(gemini_llm):
+  """Test that _live_api_client is properly configured with tracking headers and API version."""
+  with mock.patch.object(
+      gemini_llm, "_api_backend", GoogleLLMVariant.VERTEX_AI
+  ):
+    client = gemini_llm._live_api_client
+
+    # Verify that the client has the correct headers and API version
+    http_options = client._api_client._http_options
+    assert http_options.api_version == "v1beta1"
+
+    # Check that tracking headers are included
+    tracking_headers = gemini_llm._tracking_headers
+    for key, value in tracking_headers.items():
+      assert key in http_options.headers
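+      # Containment rather than equality: the client may append its own
+      # identifiers to the header value.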
+      assert value in http_options.headers[key]
+
+
+@pytest.mark.asyncio
+async def test_connect_with_custom_headers(gemini_llm, llm_request):
+  """Test that connect method updates tracking headers and API version when custom headers are provided."""
+  # Setup request with live connect config and custom headers
+  custom_headers = {"custom-live-header": "live-value"}
+  llm_request.live_connect_config = types.LiveConnectConfig(
+      http_options=types.HttpOptions(headers=custom_headers)
+  )
+
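+  # Stands in for the live session yielded by the client's async context
+  # manager.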
+  mock_live_session = mock.AsyncMock()
+
+  # Mock the _live_api_client to return a mock client
+  with mock.patch.object(gemini_llm, "_live_api_client") as mock_live_client:
+    # Create a mock context manager
+    class MockLiveConnect:
+
+      async def __aenter__(self):
+        return mock_live_session
+
+      async def __aexit__(self, *args):
+        pass
+
+    mock_live_client.aio.live.connect.return_value = MockLiveConnect()
+
+    async with gemini_llm.connect(llm_request) as connection:
+      # Verify that the connect method was called with the right config
+      mock_live_client.aio.live.connect.assert_called_once()
+      call_args = mock_live_client.aio.live.connect.call_args
+      config_arg = call_args.kwargs["config"]
+
+      # Verify that tracking headers were merged with custom headers
+      expected_headers = custom_headers.copy()
+      expected_headers.update(gemini_llm._tracking_headers)
+      assert config_arg.http_options.headers == expected_headers
+
+      # Verify that API version was set
+      assert config_arg.http_options.api_version == gemini_llm._live_api_version
+
+      # Verify that system instruction and tools were set
+      assert config_arg.system_instruction is not None
+      assert config_arg.tools == llm_request.config.tools
+
+      # Verify connection is properly wrapped
+      assert isinstance(connection, GeminiLlmConnection)
+
+
+@pytest.mark.asyncio
+async def test_connect_without_custom_headers(gemini_llm, llm_request):
+  """Test that connect method works properly when no custom headers are provided."""
+  # Setup request with live connect config but no custom headers
+  llm_request.live_connect_config = types.LiveConnectConfig()
+
+  mock_live_session = mock.AsyncMock()
+
+  with mock.patch.object(gemini_llm, "_live_api_client") as mock_live_client:
+
+    class MockLiveConnect:
+
+      async def __aenter__(self):
+        return mock_live_session
+
+      async def __aexit__(self, *args):
+        pass
+
+    mock_live_client.aio.live.connect.return_value = MockLiveConnect()
+
+    async with gemini_llm.connect(llm_request) as connection:
+      # Verify that the connect method was called with the right config
+      mock_live_client.aio.live.connect.assert_called_once()
+      call_args = mock_live_client.aio.live.connect.call_args
+      config_arg = call_args.kwargs["config"]
+
+      # Verify that http_options remains None since no custom headers were provided
+      assert config_arg.http_options is None
+
+      # Verify that system instruction and tools were still set
+      assert config_arg.system_instruction is not None
+      assert config_arg.tools == llm_request.config.tools
+
+      assert isinstance(connection, GeminiLlmConnection)
+
+
@pytest.mark.parametrize(
    (
        "api_backend, "