8000 fix: Handles input_embeds in GenerateReqInput when n>1 (#7830) · sgl-project/sglang@136c6e0 · GitHub
[go: up one dir, main page]

Skip to content

Commit 136c6e0

Browse files
fix: Handles input_embeds in GenerateReqInput when n>1 (#7830)
Signed-off-by: Xinyuan Tong <justinning0323@outlook.com>
1 parent 43e20c0 commit 136c6e0

File tree

3 files changed

+75
-1
lines changed

3 files changed

+75
-1
lines changed

python/sglang/srt/managers/io_struct.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,8 @@ def _handle_parallel_sampling(self):
200200
self.text = [self.text]
201201
if self.input_ids is not None:
202202
self.input_ids = [self.input_ids]
203+
if self.input_embeds is not None:
204+
self.input_embeds = [self.input_embeds]
203205

204206
def _normalize_single_inputs(self):
205207
"""Normalize inputs for a single example."""
@@ -324,7 +326,9 @@ def _normalize_rid(self, num):
324326
new_rids = [f"{self.rid}_{i}" for i in range(num)]
325327
self.rid = new_rids
326328
elif isinstance(self.rid, list):
327-
if len(self.rid) != num:
329+
# Note: the length of rid shall be the same as the batch_size,
330+
# as the rid would be expanded for parallel sampling in tokenizer_manager
331+
if len(self.rid) != self.batch_size:
328332
raise ValueError(
329333
"The specified rids length mismatch with the batch_size for batch processing."
330334
)
@@ -400,6 +404,9 @@ def __getitem__(self, i):
400404
return GenerateReqInput(
401405
text=self.text[i] if self.text is not None else None,
402406
input_ids=self.input_ids[i] if self.input_ids is not None else None,
407+
input_embeds=(
408+
self.input_embeds[i] if self.input_embeds is not None else None
409+
),
403410
image_data=self.image_data[i],
404411
audio_data=self.audio_data[i],
405412
sampling_params=self.sampling_params[i],

test/srt/run_suite.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ class TestFile:
6767
TestFile("test_hidden_states.py", 55),
6868
TestFile("test_int8_kernel.py", 8),
6969
TestFile("test_input_embeddings.py", 38),
70+
TestFile("test_io_struct.py", 8),
7071
TestFile("test_jinja_template_utils.py", 1),
7172
TestFile("test_metrics.py", 32),
7273
TestFile("test_mla.py", 167),

test/srt/test_io_struct.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,7 @@ def test_mixed_none_and_images_with_parallel_samples(self):
159159
"""Test that when some batch items have images and others None, parallel expansion works correctly."""
160160
req = copy.deepcopy(self.base_req)
161161
req.text = ["Prompt 1", "Prompt 2", "Prompt 3"]
162+
req.rid = ["id1", "id2", "id3"]
162163
req.image_data = [
163164
["image1.jpg"],
164165
None,
@@ -311,6 +312,71 @@ def test_input_embeds_normalization(self):
311312
self.assertFalse(req.is_single)
312313
self.assertEqual(req.batch_size, 2)
313314

315+
def test_input_embeds_with_parallel_sampling(self):
316+
"""Test input_embeds normalization with parallel sampling (n > 1)."""
317+
# Test single input_embeds with parallel sampling
318+
req = GenerateReqInput(
319+
input_embeds=[[0.1, 0.2]], # single embedding vector
320+
sampling_params={"n": 2},
321+
)
322+
req.normalize_batch_and_arguments()
323+
324+
# Should be converted from single to batch and then expanded
325+
self.assertFalse(req.is_single)
326+
self.assertEqual(len(req.input_embeds), 2)
327+
# Both should be the same input_embeds
328+
self.assertEqual(req.input_embeds[0], [[0.1, 0.2]])
329+
self.assertEqual(req.input_embeds[1], [[0.1, 0.2]])
330+
331+
# Test batch input_embeds with parallel sampling
332+
req = GenerateReqInput(
333+
input_embeds=[[[0.1, 0.2]], [[0.3, 0.4]]], sampling_params={"n": 3}
334+
)
335+
req.normalize_batch_and_arguments()
336+
337+
# Should be expanded
338+
self.assertFalse(req.is_single)
339+
self.assertEqual(len(req.input_embeds), 6)
340+
341+
# Check that the expansion is correct
342+
expected_embeds = [[[0.1, 0.2]], [[0.3, 0.4]]] * 3
343+
self.assertEqual(req.input_embeds, expected_embeds)
344+
345+
# Test with different n values per sample (should raise error)
346+
req = GenerateReqInput(
347+
input_embeds=[[[0.1, 0.2]], [[0.3, 0.4]]],
348+
sampling_params=[{"n": 2}, {"n": 3}],
349+
)
350+
with self.assertRaises(ValueError):
351+
req.normalize_batch_and_arguments()
352+
353+
def test_input_embeds_single_to_batch_conversion(self):
354+
"""Test that single input_embeds are properly converted to batch when using parallel sampling."""
355+
# Test the specific case that was fixed: single input_embeds with n > 1
356+
req = GenerateReqInput(
357+
input_embeds=[[0.1, 0.2, 0.3]], sampling_params={"n": 2} # Single embedding
358+
)
359+
req.normalize_batch_and_arguments()
360+
361+
# Should convert single to batch and then expand
362+
self.assertFalse(req.is_single)
363+
self.assertEqual(len(req.input_embeds), 2)
364+
365+
# Both should be the same single embedding
366+
self.assertEqual(req.input_embeds[0], [[0.1, 0.2, 0.3]])
367+
self.assertEqual(req.input_embeds[1], [[0.1, 0.2, 0.3]])
368+
369+
# Test with higher n value
370+
req = GenerateReqInput(input_embeds=[[0.1, 0.2, 0.3]], sampling_params={"n": 5})
371+
req.normalize_batch_and_arguments()
372+
373+
self.assertFalse(req.is_single)
374+
self.assertEqual(len(req.input_embeds), 5)
375+
376+
# All should be the same
377+
for i in range(5):
378+
self.assertEqual(req.input_embeds[i], [[0.1, 0.2, 0.3]])
379+
314380
def test_lora_path_normalization(self):
315381
"""Test normalization of lora_path."""
316382
# Test single lora_path with batch input

0 commit comments

Comments
 (0)
0