Implement penalize_nl · chabotsi/llama-cpp-python@d28b753 · GitHub
[go: up one dir, main page]

Skip to content

Commit d28b753

Browse files
committed
Implement penalize_nl
1 parent f11e2a7 commit d28b753

File tree

1 file changed

+11
-0
lines changed

1 file changed

+11
-0
lines changed

llama_cpp/llama.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,7 @@ def _sample(
291291
mirostat_mode: llama_cpp.c_int,
292292
mirostat_tau: llama_cpp.c_float,
293293
mirostat_eta: llama_cpp.c_float,
294+
penalize_nl: bool = True,
294295
):
295296
assert self.ctx is not None
296297
assert len(self.eval_logits) > 0
@@ -299,6 +300,7 @@ def _sample(
299300
top_k = llama_cpp.c_int(n_vocab) if top_k.value <= 0 else top_k
300301
last_n_tokens_size = llama_cpp.c_int(n_ctx) if last_n_tokens_size.value < 0 else last_n_tokens_size
301302
logits = self.eval_logits[-1]
303+
nl_logit = logits[llama_cpp.llama_token_nl().value]
302304
data = (llama_cpp.llama_token_data * n_vocab)(
303305
*[
304306
llama_cpp.llama_token_data(
@@ -331,6 +333,8 @@ def _sample(
331333
alpha_frequency=frequency_penalty,
332334
alpha_presence=presence_penalty,
333335
)
336+
if not penalize_nl:
337+
candidates.data[llama_cpp.llama_token_nl().value].logit = nl_logit
334338
if temp.value == 0.0:
335339
return llama_cpp.llama_sample_token_greedy(
336340
ctx=self.ctx,
@@ -413,6 +417,7 @@ def sample(
413417
mirostat_mode: int = 0,
414418
mirostat_eta: float = 0.1,
415419
mirostat_tau: float = 5.0,
420+
penalize_nl: bool = True,
416421
):
417422
"""Sample a token from the model.
418423
@@ -444,6 +449,7 @@ def sample(
444449
mirostat_mode=llama_cpp.c_int(mirostat_mode),
445450
mirostat_tau=llama_cpp.c_float(mirostat_tau),
446451
mirostat_eta=llama_cpp.c_float(mirostat_eta),
452+
penalize_nl=penalize_nl,
447453
)
448454

449455
def generate(
@@ -1170,6 +1176,11 @@ def token_bos() -> llama_cpp.llama_token:
11701176
"""Return the beginning-of-sequence token."""
11711177
return llama_cpp.llama_token_bos()
11721178

1179+
@staticmethod
1180+
def token_nl() -> llama_cpp.llama_token:
1181+
"""Return the newline token."""
1182+
return llama_cpp.llama_token_nl()
1183+
11731184
@staticmethod
11741185
def logits_to_logprobs(logits: List[float]) -> List[float]:
11751186
exps = [math.exp(float(x)) for x in logits]

0 commit comments

Comments (0)