File tree Expand file tree Collapse file tree 1 file changed +9
-7
lines changed Expand file tree Collapse file tree 1 file changed +9
-7
lines changed Original file line number Diff line number Diff line change @@ -451,7 +451,7 @@ def free_lora_adapter():
451
451
self .n_tokens = 0
452
452
self .input_ids : npt .NDArray [np .intc ] = np .ndarray ((n_ctx ,), dtype = np .intc )
453
453
self .scores : npt .NDArray [np .single ] = np .ndarray (
454
- (n_ctx , self ._n_vocab ), dtype = np .single
454
+ (n_ctx if logits_all == True else n_batch , self ._n_vocab ), dtype = np .single
455
455
)
456
456
457
457
self ._mirostat_mu = ctypes .c_float (
@@ -648,12 +648,14 @@ def eval(self, tokens: Sequence[int]):
648
648
)
649
649
self .scores [n_past : n_past + n_tokens , :].reshape (- 1 )[::] = logits
650
650
else :
651
- rows = 1
652
- cols = self ._n_vocab
653
- logits = np .ctypeslib .as_array (
654
- self ._ctx .get_logits (), shape = (rows * cols ,)
655
- )
656
- self .scores [n_past + n_tokens - 1 , :].reshape (- 1 )[::] = logits
651
+ # rows = 1
652
+ # cols = self._n_vocab
653
+ # logits = np.ctypeslib.as_array(
654
+ # self._ctx.get_logits(), shape=(rows * cols,)
655 + # )
656
+ # self.scores[n_past + n_tokens - 1, :].reshape(-1)[::] = logits
657
+ # NOTE: Now that sampling is done inside the sampler, logits are only needed for logprobs which requires logits_all
658
+ pass
657
659
# Update n_tokens
658
660
self .n_tokens += n_tokens
659
661
You can’t perform that action at this time.
0 commit comments