@@ -159,6 +159,7 @@ def test_mixed_none_and_images_with_parallel_samples(self):
159
159
"""Test that when some batch items have images and others None, parallel expansion works correctly."""
160
160
req = copy .deepcopy (self .base_req )
161
161
req .text = ["Prompt 1" , "Prompt 2" , "Prompt 3" ]
162
+ req .rid = ["id1" , "id2" , "id3" ]
162
163
req .image_data = [
163
164
["image1.jpg" ],
164
165
None ,
@@ -311,6 +312,71 @@ def test_input_embeds_normalization(self):
311
312
self .assertFalse (req .is_single )
312
313
self .assertEqual (req .batch_size , 2 )
313
314
315
def test_input_embeds_with_parallel_sampling(self):
    """Verify input_embeds normalization when sampling_params request n > 1.

    Three scenarios are exercised in order:
      1. A single embedding with n=2 ends up as a two-entry batch.
      2. A batch of two embeddings with n=3 ends up with six entries.
      3. A batch whose per-sample n values differ raises ValueError.
    """
    # Scenario 1: one embedding, two parallel samples.
    expected_single = [[0.1, 0.2]]
    single_req = GenerateReqInput(
        input_embeds=[[0.1, 0.2]],
        sampling_params={"n": 2},
    )
    single_req.normalize_batch_and_arguments()

    # The request must no longer be flagged as single, and it must hold
    # exactly one copy of the embedding per requested sample.
    self.assertFalse(single_req.is_single)
    self.assertEqual(len(single_req.input_embeds), 2)
    for replica_idx in (0, 1):
        self.assertEqual(single_req.input_embeds[replica_idx], expected_single)

    # Scenario 2: two embeddings, three parallel samples each.
    batch_req = GenerateReqInput(
        input_embeds=[[[0.1, 0.2]], [[0.3, 0.4]]],
        sampling_params={"n": 3},
    )
    batch_req.normalize_batch_and_arguments()

    self.assertFalse(batch_req.is_single)
    self.assertEqual(len(batch_req.input_embeds), 6)

    # The expected layout is the whole batch repeated three times.
    repeated_batch = 3 * [[[0.1, 0.2]], [[0.3, 0.4]]]
    self.assertEqual(batch_req.input_embeds, repeated_batch)

    # Scenario 3: per-sample n values that disagree must be rejected.
    mismatched_req = GenerateReqInput(
        input_embeds=[[[0.1, 0.2]], [[0.3, 0.4]]],
        sampling_params=[{"n": 2}, {"n": 3}],
    )
    with self.assertRaises(ValueError):
        mismatched_req.normalize_batch_and_arguments()
352
+
353
def test_input_embeds_single_to_batch_conversion(self):
    """Verify a single input_embeds entry becomes a batch under parallel sampling.

    Covers n=2 and n=5: after normalization the request must not be
    flagged as single, must hold n entries, and every entry must equal
    the original embedding.
    """
    expected_embedding = [[0.1, 0.2, 0.3]]

    # n=2: a lone embedding combined with parallel sampling.
    pair_req = GenerateReqInput(
        input_embeds=[[0.1, 0.2, 0.3]],
        sampling_params={"n": 2},
    )
    pair_req.normalize_batch_and_arguments()

    self.assertFalse(pair_req.is_single)
    self.assertEqual(len(pair_req.input_embeds), 2)

    # Each of the two entries must match the embedding that was passed in.
    self.assertEqual(pair_req.input_embeds[0], expected_embedding)
    self.assertEqual(pair_req.input_embeds[1], expected_embedding)

    # n=5: same check with a larger sample count.
    five_req = GenerateReqInput(
        input_embeds=[[0.1, 0.2, 0.3]],
        sampling_params={"n": 5},
    )
    five_req.normalize_batch_and_arguments()

    self.assertFalse(five_req.is_single)
    self.assertEqual(len(five_req.input_embeds), 5)

    # The length was just asserted to be 5, so iterating visits exactly
    # the five entries.
    for replica in five_req.input_embeds:
        self.assertEqual(replica, expected_embedding)
379
+
314
380
def test_lora_path_normalization (self ):
315
381
"""Test normalization of lora_path."""
316
382
# Test single lora_path with batch input
0 commit comments