@@ -1552,7 +1552,7 @@ def logit_bias_processor(
        self.detokenize(completion_tokens[:returned_tokens])
    )
    token_offset = len(prompt_tokens) + returned_tokens
-   logits = self._scores[token_offset - 1, :].tolist()
+   logits = self._scores[token_offset - 1, :]
    current_logprobs = Llama.logits_to_logprobs(logits)
    sorted_logprobs = list(
        sorted(
@@ -1671,7 +1671,7 @@ def logit_bias_processor(
        self.detokenize(completion_tokens[:returned_tokens])
    )
    token_offset = len(prompt_tokens) + returned_tokens - 1
-   logits = self._scores[token_offset, :].tolist()
+   logits = self._scores[token_offset, :]
    current_logprobs = Llama.logits_to_logprobs(logits)
    sorted_logprobs = list(
        sorted(
@@ -1785,9 +1785,8 @@ def logit_bias_processor(
        self.detokenize([token]).decode("utf-8", errors="ignore")
        for token in all_tokens
    ]
-   all_logprobs = [
-       Llama.logits_to_logprobs(row.tolist()) for row in self._scores
-   ][token_offset:]
+   all_logprobs = Llama.logits_to_logprobs(self._scores)[token_offset:]
+   # TODO: may be able to change this loop to use np.take_along_axis
    for token, token_str, logprobs_token in zip(
        all_tokens, all_token_strs, all_logprobs
    ):
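
For context: the hunk above replaces a per-row .tolist() conversion plus one logits_to_logprobs call per row with a single vectorized call over the whole self._scores matrix. A minimal sketch of the idea (not part of the commit), using a hypothetical 2-D score array in place of self._scores, with the real function's non-finite-max guard trimmed and np.take_along_axis standing in for the TODO's gather:

    import numpy as np

    def logits_to_logprobs(logits, axis=-1):
        # Stabilized log-softmax over the given axis, as in the diff.
        maxs = np.amax(logits, axis=axis, keepdims=True)
        shifted = np.subtract(logits, maxs, dtype=np.single)
        with np.errstate(divide="ignore"):
            return shifted - np.log(np.sum(np.exp(shifted), axis=axis, keepdims=True))

    # Hypothetical stand-ins: 5 generated positions over an 8-token vocabulary.
    scores = np.random.rand(5, 8).astype(np.single)
    token_offset = 2

    # One vectorized call replaces the old per-row .tolist() list comprehension.
    all_logprobs = logits_to_logprobs(scores)[token_offset:]

    # The TODO's idea: gather each position's chosen-token logprob in one shot.
    # (np.take_along_axis is NumPy's name for torch's take_along_dim.)
    all_tokens = np.array([[3], [1], [7]])  # hypothetical token ids, shape (rows, 1)
    chosen = np.take_along_axis(all_logprobs, all_tokens, axis=-1).squeeze(-1)
    print(chosen.shape)  # (3,)
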
@@ -2282,7 +2281,7 @@ def token_nl(self) -> int:

@staticmethod
def logits_to_logprobs(
-    logits: Union[List, npt.NDArray[np.single]], axis: int = -1
+    logits: Union[npt.NDArray[np.single], List], axis: int = -1
) -> npt.NDArray[np.single]:
    # https://docs.scipy.org/doc/scipy/reference/generated/scipy.special.log_softmax.html
    logits_maxs: np.ndarray = np.amax(logits, axis=axis, keepdims=True)
@@ -2293,7 +2292,7 @@ def logits_to_logprobs(
subtract_maxs = np.subtract(logits, logits_maxs, dtype=np.single)
exp = np.exp(subtract_maxs)
# Suppress warnings about log of zero
-with np.errstate(divide='ignore'):
+with np.errstate(divide="ignore"):
    summed = np.sum(exp, axis=axis, keepdims=True)
    out = np.log(summed)
return subtract_maxs - out
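
As a quick sanity check (mine, not part of the commit) that the function behaves the same on both input types its annotation allows, assuming llama_cpp is installed:

    import numpy as np
    from llama_cpp import Llama

    logits = np.array([[0.0, 1.0, 2.0], [5.0, 5.0, 5.0]], dtype=np.single)

    # logits_to_logprobs is a @staticmethod, so no model needs to be loaded.
    a = Llama.logits_to_logprobs(logits)           # 2-D array, as at the new call sites
    b = Llama.logits_to_logprobs(logits.tolist())  # plain nested list, still accepted
    assert np.allclose(a, b)

    # Each row is a valid log-distribution: probabilities sum to 1.
    assert np.allclose(np.exp(a).sum(axis=-1), 1.0, atol=1e-5)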