File tree Expand file tree Collapse file tree 1 file changed +2
-3
lines changed Expand file tree Collapse file tree 1 file changed +2
-3
lines changed Original file line number Diff line number Diff line change @@ -553,7 +553,7 @@ def sample(
553
553
logits [:] = (
554
554
logits_processor (self ._input_ids , logits )
555
555
if idx is None
556
- else logits_processor (self ._input_ids [:idx ], logits )
556
+ else logits_processor (self ._input_ids [:idx + 1 ], logits )
557
557
)
558
558
559
559
sampling_params = _LlamaSamplingParams (
@@ -658,7 +658,6 @@ def generate(
658
658
while True :
659
659
self .eval (tokens )
660
660
while sample_idx < self .n_tokens :
661
- next_sample_idx = sample_idx + 1
662
661
token = self .sample (
663
662
top_k = top_k ,
664
663
top_p = top_p ,
@@ -675,7 +674,7 @@ def generate(
675
674
logits_processor = logits_processor ,
676
675
grammar = grammar ,
677
676
penalize_nl = penalize_nl ,
678
- idx = next_sample_idx ,
677
+ idx = sample_idx ,
679
678
)
680
679
681
680
sample_idx += 1
You can’t perform that action at this time.
0 commit comments