File tree Expand file tree Collapse file tree 1 file changed +4
-4
lines changed Expand file tree Collapse file tree 1 file changed +4
-4
lines changed Original file line number Diff line number Diff line change @@ -1029,16 +1029,16 @@ def eval(self, tokens: Sequence[int]):
1029
1029
)
1030
1030
self ._ctx .decode (self ._batch )
1031
1031
# Save tokens
1032
- self .input_ids [self . n_tokens : self . n_tokens + n_tokens ] = batch
1032
+ self .input_ids [n_past : n_past + n_tokens ] = batch
1033
1033
# Save logits
1034
- rows = n_tokens if self . context_params . logits_all else 1
1034
+ rows = n_tokens
1035
1035
cols = self ._n_vocab
1036
1036
offset = (
1037
1037
0 if self .context_params .logits_all else n_tokens - 1
1038
1038
) # NOTE: Only save the last token logits if logits_all is False
1039
- self .scores [self . n_tokens + offset : self . n_tokens + n_tokens , :].reshape (
1039
+ self .scores [n_past + offset : n_past + n_tokens , :].reshape (
1040
1040
- 1
1041
- )[:] = self ._ctx .get_logits ()[: rows * cols ]
1041
+ )[:] = self ._ctx .get_logits ()[offset * cols : rows * cols ]
1042
1042
# Update n_tokens
1043
1043
self .n_tokens += n_tokens
1044
1044
You can’t perform that action at this time.
0 commit comments