Add set_seed to Llama class · chiensen/llama-cpp-python@fd41ed3 · GitHub

Commit fd41ed3

Add set_seed to Llama class

1 parent ca4cb88

1 file changed: +18 −5 lines changed

llama_cpp/llama.py

Lines changed: 18 additions & 5 deletions
@@ -998,6 +998,15 @@ def set_cache(self, cache: Optional[BaseLlamaCache]):
         """
         self.cache = cache
 
+    def set_seed(self, seed: int):
+        """Set the random seed.
+
+        Args:
+            seed: The random seed.
+        """
+        assert self._ctx.ctx is not None
+        llama_cpp.llama_set_rng_seed(self._ctx.ctx, seed)
+
     def reset(self):
         """Reset the model state."""
         self.n_tokens = 0
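With set_seed in place, callers can reseed the sampler between generations without rebuilding the context. A minimal usage sketch, assuming a local GGUF model (the path and prompt below are placeholders, not from the commit):

from llama_cpp import Llama

llm = Llama(model_path="./models/model.gguf")  # placeholder path

# Reseeding with the same value before each call should make
# temperature-based sampling reproducible.
llm.set_seed(42)
first = llm("Q: Name a color. A:", max_tokens=8)

llm.set_seed(42)
second = llm("Q: Name a color. A:", max_tokens=8)

# Same seed, same prompt: the completions should match.
assert first["choices"][0]["text"] == second["choices"][0]["text"]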
@@ -1318,10 +1327,14 @@ def _create_completion(
         completion_tokens: List[int] = []
         # Add blank space to start of prompt to match OG llama tokenizer
         prompt_tokens: List[int] = (
-            self.tokenize(prompt.encode("utf-8"), special=True)
-            if prompt != ""
-            else [self.token_bos()]
-        ) if isinstance(prompt, str) else prompt
+            (
+                self.tokenize(prompt.encode("utf-8"), special=True)
+                if prompt != ""
+                else [self.token_bos()]
+            )
+            if isinstance(prompt, str)
+            else prompt
+        )
         text: bytes = b""
         returned_tokens: int = 0
         stop = (
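The prompt-tokenization rewrite above only re-wraps the existing conditional expression for readability; behavior is unchanged. An illustrative if/else expansion of the same logic (not code from the commit):

# Equivalent spelled-out form of the conditional expression:
if isinstance(prompt, str):
    if prompt != "":
        # Non-empty string: tokenize it, keeping special tokens.
        prompt_tokens = self.tokenize(prompt.encode("utf-8"), special=True)
    else:
        # Empty string: fall back to a lone BOS token.
        prompt_tokens = [self.token_bos()]
else:
    # Already a list of token ids: use it as-is.
    prompt_tokens = prompt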
@@ -1374,7 +1387,7 @@ def _create_completion(
         except KeyError:
             if self.verbose:
                 print("Llama._create_completion: cache miss", file=sys.stderr)
-
+
         if seed is not None:
             self._ctx.set_rng_seed(seed)
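This hunk also highlights that _create_completion reseeds the context whenever a per-request seed is supplied, overriding any seed set earlier. A hedged sketch, assuming create_completion forwards the seed parameter shown above, and reusing the llm instance from the earlier example:

# Two identical requests with the same per-call seed should agree.
out_a = llm.create_completion("Once upon a time", max_tokens=16, seed=1234)
out_b = llm.create_completion("Once upon a time", max_tokens=16, seed=1234)
assert out_a["choices"][0]["text"] == out_b["choices"][0]["text"]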

0 commit comments