@@ -291,6 +291,7 @@ def _sample(
291
291
mirostat_mode : llama_cpp .c_int ,
292
292
mirostat_tau : llama_cpp .c_float ,
293
293
mirostat_eta : llama_cpp .c_float ,
294
+ penalize_nl : bool = True ,
294
295
):
295
296
assert self .ctx is not None
296
297
assert len (self .eval_logits ) > 0
@@ -299,6 +300,7 @@ def _sample(
299
300
top_k = llama_cpp .c_int (n_vocab ) if top_k .value <= 0 else top_k
300
301
last_n_tokens_size = llama_cpp .c_int (n_ctx ) if last_n_tokens_size .value < 0 else last_n_tokens_size
301
302
logits = self .eval_logits [- 1 ]
303
+ nl_logit = logits [llama_cpp .llama_token_nl ().value ]
302
304
data = (llama_cpp .llama_token_data * n_vocab )(
303
305
* [
304
306
llama_cpp .llama_token_data (
@@ -331,6 +333,8 @@ def _sample(
331
333
alpha_frequency = frequency_penalty ,
332
334
alpha_presence = presence_penalty ,
333
335
)
336
+ if not penalize_nl :
337
+ candidates .data [llama_cpp .llama_token_nl ().value ].logit = nl_logit
334
338
if temp .value == 0.0 :
335
339
return llama_cpp .llama_sample_token_greedy (
336
340
ctx = self .ctx ,
@@ -413,6 +417,7 @@ def sample(
413
417
mirostat_mode : int = 0 ,
414
418
mirostat_eta : float = 0.1 ,
415
419
mirostat_tau : float = 5.0 ,
420
+ penalize_nl : bool = True ,
416
421
):
417
422
"""Sample a token from the model.
418
423
@@ -444,6 +449,7 @@ def sample(
444
449
mirostat_mode = llama_cpp .c_int (mirostat_mode ),
445
450
mirostat_tau = llama_cpp .c_float (mirostat_tau ),
446
451
mirostat_eta = llama_cpp .c_float (mirostat_eta ),
452
+ penalize_nl = penalize_nl ,
447
453
)
448
454
449
455
def generate (
@@ -1170,6 +1176,11 @@ def token_bos() -> llama_cpp.llama_token:
1170
1176
"""Return the beginning-of-sequence token."""
1171
1177
return llama_cpp .llama_token_bos ()
1172
1178
1179
+ @staticmethod
1180
+ def token_nl () -> llama_cpp .llama_token :
1181
+ """Return the newline token."""
1182
+ return llama_cpp .llama_token_nl ()
1183
+
1173
1184
@staticmethod
1174
1185
def logits_to_logprobs (logits : List [float ]) -> List [float ]:
1175
1186
exps = [math .exp (float (x )) for x in logits ]
0 commit comments