
Commit 86aeb9f

Add seed parameter support for completion and chat_completion requests. Closes abetlen#884
1 parent da1b802 commit 86aeb9f
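
For context, this change lets callers pin llama.cpp's sampling RNG per request, making stochastic sampling reproducible. A minimal sketch of the new surface (the model path is hypothetical; any GGUF model works):

    from llama_cpp import Llama

    llm = Llama(model_path="./models/llama-2-7b.Q4_K_M.gguf")  # hypothetical path

    # Two completions with the same prompt, seed, and sampling settings
    # should sample the same tokens.
    a = llm.create_completion("The capital of France is", max_tokens=8, seed=42)
    b = llm.create_completion("The capital of France is", max_tokens=8, seed=42)
    assert a["choices"][0]["text"] == b["choices"][0]["text"]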

File tree

2 files changed: +12, -0 lines changed

llama_cpp/llama.py

Lines changed: 10 additions & 0 deletions
@@ -1292,6 +1292,7 @@ def _create_completion(
         repeat_penalty: float = 1.1,
         top_k: int = 40,
         stream: bool = False,
+        seed: Optional[int] = None,
         tfs_z: float = 1.0,
         mirostat_mode: int = 0,
         mirostat_tau: float = 5.0,
@@ -1367,6 +1368,9 @@ def _create_completion(
         except KeyError:
             if self.verbose:
                 print("Llama._create_completion: cache miss", file=sys.stderr)
+
+        if seed is not None:
+            self._ctx.set_rng_seed(seed)
 
         finish_reason = "length"
         multibyte_fix = 0
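
This hunk is the heart of the change: after the cache lookup, and before generation starts, the context's sampling RNG is reseeded, so repeated calls with the same seed and sampling parameters draw the same token sequence. `_ctx.set_rng_seed` presumably wraps llama.cpp's `llama_set_rng_seed` C function; a hypothetical sketch of what that wrapper likely looks like (names inferred, not taken from this diff):

    # Hypothetical sketch of the internal wrapper on the context object.
    def set_rng_seed(self, seed: int):
        # llama_set_rng_seed is the llama.cpp C API call that reseeds
        # the sampling RNG of an existing context.
        llama_cpp.llama_set_rng_seed(self.ctx, seed)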
@@ -1750,6 +1754,7 @@ def create_completion(
         repeat_penalty: float = 1.1,
         top_k: int = 40,
         stream: bool = False,
+        seed: Optional[int] = None,
         tfs_z: float = 1.0,
         mirostat_mode: int = 0,
         mirostat_tau: float = 5.0,
@@ -1795,6 +1800,7 @@ def create_completion(
             repeat_penalty=repeat_penalty,
             top_k=top_k,
             stream=stream,
+            seed=seed,
             tfs_z=tfs_z,
             mirostat_mode=mirostat_mode,
             mirostat_tau=mirostat_tau,
@@ -1825,6 +1831,7 @@ def __call__(
         repeat_penalty: float = 1.1,
         top_k: int = 40,
         stream: bool = False,
+        seed: Optional[int] = None,
         tfs_z: float = 1.0,
         mirostat_mode: int = 0,
         mirostat_tau: float = 5.0,
@@ -1870,6 +1877,7 @@ def __call__(
             repeat_penalty=repeat_penalty,
             top_k=top_k,
             stream=stream,
+            seed=seed,
             tfs_z=tfs_z,
             mirostat_mode=mirostat_mode,
             mirostat_tau=mirostat_tau,
@@ -1892,6 +1900,7 @@ def create_chat_completion(
         top_k: int = 40,
         stream: bool = False,
         stop: Optional[Union[str, List[str]]] = [],
+        seed: Optional[int] = None,
         max_tokens: int = 256,
         presence_penalty: float = 0.0,
         frequency_penalty: float = 0.0,
@@ -1936,6 +1945,7 @@ def create_chat_completion(
             top_k=top_k,
             stream=stream,
             stop=stop,
+            seed=seed,
             max_tokens=max_tokens,
             presence_penalty=presence_penalty,
             frequency_penalty=frequency_penalty,
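
With `create_chat_completion` forwarding the parameter as well, seeded chat requests work the same way. A minimal sketch (assumes a chat-capable model is loaded in the hypothetical `llm`):

    reply = llm.create_chat_completion(
        messages=[
            {"role": "system", "content": "You are a concise assistant."},
            {"role": "user", "content": "Give me one fact about llamas."},
        ],
        seed=7,  # same seed + same settings -> reproducible reply
    )
    print(reply["choices"][0]["message"]["content"])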

llama_cpp/server/app.py

Lines changed: 2 additions & 0 deletions
@@ -608,6 +608,7 @@ class CreateCompletionRequest(BaseModel):
     frequency_penalty: Optional[float] = frequency_penalty_field
     logit_bias: Optional[Dict[str, float]] = Field(None)
     logprobs: Optional[int] = Field(None)
+    seed: Optional[int] = Field(None)
 
     # ignored or currently unsupported
     model: Optional[str] = model_field
@@ -790,6 +791,7 @@ class CreateChatCompletionRequest(BaseModel):
     presence_penalty: Optional[float] = presence_penalty_field
     frequency_penalty: Optional[float] = frequency_penalty_field
     logit_bias: Optional[Dict[str, float]] = Field(None)
+    seed: Optional[int] = Field(None)
 
     # ignored or currently unsupported
     model: Optional[str] = model_field
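
On the server side, the new Pydantic fields mean the OpenAI-compatible endpoints accept a `seed` key in the request body. A sketch using the `requests` library (the local host and port are assumptions; adjust to wherever the server is running):

    import requests

    # POST to the completions endpoint with a seed for reproducible sampling.
    resp = requests.post(
        "http://localhost:8000/v1/completions",
        json={
            "prompt": "The capital of France is",
            "max_tokens": 8,
            "seed": 42,
        },
    )
    print(resp.json()["choices"][0]["text"])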

0 comments