Fix mirostat sampling · coderonion/llama-cpp-python@3babe35 · GitHub

Commit 3babe35

Fix mirostat sampling

1 parent 141293a commit 3babe35

File tree

1 file changed: +11 −2 lines changed

llama_cpp/llama.py

Lines changed: 11 additions & 2 deletions
@@ -329,6 +329,8 @@ def __init__(
             (n_ctx, self._n_vocab), dtype=np.single
         )
 
+        self._mirostat_mu = ctypes.c_float(2.0 * 5.0)  # TODO: Move this to sampling context
+
     @property
     def ctx(self) -> llama_cpp.llama_context_p:
         assert self._ctx.ctx is not None
@@ -516,7 +518,7 @@ def sample(
             candidates=self._candidates,
             tau=mirostat_tau,
             eta=mirostat_eta,
-            mu=2.0 * mirostat_tau,
+            mu=ctypes.pointer(self._mirostat_mu),
             m=100,
         )
     elif mirostat_mode == 2:
@@ -525,7 +527,7 @@ def sample(
             candidates=self._candidates,
             tau=mirostat_tau,
             eta=mirostat_eta,
-            mu=2.0 * mirostat_tau,
+            mu=ctypes.pointer(self._mirostat_mu)
         )
     else:
         self._ctx.sample_top_k(candidates=self._candidates, k=top_k, min_keep=1)
@@ -581,6 +583,10 @@ def generate(
         Yields:
             The generated tokens.
         """
+        # Reset mirostat sampling
+        self._mirostat_mu = ctypes.c_float(2.0 * mirostat_tau)
+
+        # Check for kv cache prefix match
         if reset and self.n_tokens > 0:
             longest_prefix = 0
             for a, b in zip(self._input_ids, tokens[:-1]):
@@ -595,12 +601,15 @@ def generate(
                 tokens = tokens[longest_prefix:]
                 self.n_tokens = longest_prefix
 
+        # Reset the model state
         if reset:
             self.reset()
 
+        # Reset the grammar
         if grammar is not None:
             grammar.reset()
 
+        # Eval and sample
         while True:
             self.eval(tokens)
             token = self.sample(
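
Why the change works: llama.cpp's mirostat samplers take mu as a pointer to a float and update it in place after every sampled token. The old code passed mu=2.0 * mirostat_tau, a fresh temporary on each call, so that feedback was discarded and mu never adapted. The commit stores the state once as a ctypes.c_float, passes ctypes.pointer(self._mirostat_mu) so the C side's update sticks, and re-seeds it to 2.0 * mirostat_tau at the top of generate(). A minimal sketch of the pointer mechanics, where update_mu is a hypothetical stand-in for the C sampler's in-place update:

import ctypes

def update_mu(mu_ptr, observed_surprise, tau=5.0, eta=0.1):
    # Stand-in for the update llama.cpp performs through float *mu.
    mu_ptr.contents.value -= eta * (observed_surprise - tau)

mu = ctypes.c_float(2.0 * 5.0)   # persistent state, as stored on the Llama object
update_mu(ctypes.pointer(mu), observed_surprise=7.0)
print(mu.value)                  # ~9.8: the caller sees the in-place update

# The old bug, in miniature: a temporary receives the update, then is discarded.
update_mu(ctypes.pointer(ctypes.c_float(2.0 * 5.0)), observed_surprise=7.0)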

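For context, here is a rough pure-Python sketch of one mirostat-v2 step, paraphrased from the published algorithm rather than copied from llama.cpp: tokens whose surprise exceeds mu are truncated, one token is sampled from the renormalized remainder, and mu is nudged by the error between observed surprise and the target tau. This running mu is exactly the state the commit keeps alive between sampling calls.

import math
import random

def mirostat_v2_step(probs, mu, tau=5.0, eta=0.1):
    """One mirostat-v2 step (illustrative sketch, not llama.cpp's code)."""
    # Truncate tokens whose surprise -log2(p) exceeds the current mu.
    allowed = [(i, p) for i, p in enumerate(probs) if -math.log2(p) <= mu]
    if not allowed:
        # Fall back to the single most probable token.
        allowed = [max(enumerate(probs), key=lambda ip: ip[1])]

    # Renormalize the remainder and sample from it.
    total = sum(p for _, p in allowed)
    r, acc = random.random() * total, 0.0
    token, p_token = allowed[-1]
    for i, p in allowed:
        acc += p
        if acc >= r:
            token, p_token = i, p
            break

    # Feedback: move mu so observed surprise tracks the target tau.
    observed = -math.log2(p_token / total)
    return token, mu - eta * (observed - tau)

# mu must persist across tokens; recomputing 2.0 * tau every step
# (the old bug) disables the feedback loop entirely.
mu = 2.0 * 5.0
for _ in range(3):
    token, mu = mirostat_v2_step([0.5, 0.3, 0.15, 0.05], mu)
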
0 commit comments