@@ -1292,6 +1292,7 @@ def _create_completion(
1292
1292
repeat_penalty : float = 1.1 ,
1293
1293
top_k : int = 40 ,
1294
1294
stream : bool = False ,
1295
+ seed : Optional [int ] = None ,
1295
1296
tfs_z : float = 1.0 ,
1296
1297
mirostat_mode : int = 0 ,
1297
1298
mirostat_tau : float = 5.0 ,
@@ -1367,6 +1368,9 @@ def _create_completion(
1367
1368
except KeyError :
1368
1369
if self .verbose :
1369
1370
print ("Llama._create_completion: cache miss" , file = sys .stderr )
1371
+
1372
+ if seed is not None :
1373
+ self ._ctx .set_rng_seed (seed )
1370
1374
1371
1375
finish_reason = "length"
1372
1376
multibyte_fix = 0
@@ -1750,6 +1754,7 @@ def create_completion(
1750
1754
repeat_penalty : float = 1.1 ,
1751
1755
top_k : int = 40 ,
1752
1756
stream : bool = False ,
1757
+ seed : Optional [int ] = None ,
1753
1758
tfs_z : float = 1.0 ,
1754
1759
mirostat_mode : int = 0 ,
1755
1760
mirostat_tau : float = 5.0 ,
@@ -1795,6 +1800,7 @@ def create_completion(
1795
1800
repeat_penalty = repeat_penalty ,
1796
1801
top_k = top_k ,
1797
1802
stream = stream ,
1803
+ seed = seed ,
1798
1804
tfs_z = tfs_z ,
1799
1805
mirostat_mode = mirostat_mode ,
1800
1806
mirostat_tau = mirostat_tau ,
@@ -1825,6 +1831,7 @@ def __call__(
1825
1831
repeat_penalty : float = 1.1 ,
1826
1832
top_k : int = 40 ,
1827
1833
stream : bool = False ,
1834
+ seed : Optional [int ] = None ,
1828
1835
tfs_z : float = 1.0 ,
1829
1836
mirostat_mode : int = 0 ,
1830
1837
mirostat_tau : float = 5.0 ,
@@ -1870,6 +1877,7 @@ def __call__(
1870
1877
repeat_penalty = repeat_penalty ,
1871
1878
top_k = top_k ,
1872
1879
stream = stream ,
1880
+ seed = seed ,
1873
1881
tfs_z = tfs_z ,
1874
1882
mirostat_mode = mirostat_mode ,
1875
1883
mirostat_tau = mirostat_tau ,
@@ -1892,6 +1900,7 @@ def create_chat_completion(
1892
1900
top_k : int = 40 ,
1893
1901
stream : bool = False ,
1894
1902
stop : Optional [Union [str , List [str ]]] = [],
1903
+ seed : Optional [int ] = None ,
1895
1904
max_tokens : int = 256 ,
1896
1905
presence_penalty : float = 0.0 ,
1897
1906
frequency_penalty : float = 0.0 ,
@@ -1936,6 +1945,7 @@ def create_chat_completion(
1936
1945
top_k = top_k ,
1937
1946
stream = stream ,
1938
1947
stop = stop ,
1948
+ seed = seed ,
1939
1949
max_tokens = max_tokens ,
1940
1950
presence_penalty = presence_penalty ,
1941
1951
frequency_penalty = frequency_penalty ,
0 commit comments