 import sys
 import uuid
 import time
+import json
+import fnmatch
 import multiprocessing
 from typing import (
     List,
@@ -16,6 +18,7 @@
     Callable,
 )
 from collections import deque
+from pathlib import Path
 
 import ctypes
 
@@ -29,10 +32,7 @@
     LlamaDiskCache,  # type: ignore
     LlamaRAMCache,  # type: ignore
 )
-from .llama_tokenizer import (
-    BaseLlamaTokenizer,
-    LlamaTokenizer
-)
+from .llama_tokenizer import BaseLlamaTokenizer, LlamaTokenizer
 import llama_cpp.llama_cpp as llama_cpp
 import llama_cpp.llama_chat_format as llama_chat_format
 
@@ -50,9 +50,7 @@
     _LlamaSamplingContext,  # type: ignore
 )
 from ._logger import set_verbose
-from ._utils import (
-    suppress_stdout_stderr
-)
+from ._utils import suppress_stdout_stderr
 
 
 class Llama:
@@ -189,7 +187,11 @@ def __init__(
             Llama.__backend_initialized = True
 
         if isinstance(numa, bool):
-            self.numa = llama_cpp.GGML_NUMA_STRATEGY_DISTRIBUTE if numa else llama_cpp.GGML_NUMA_STRATEGY_DISABLED
+            self.numa = (
+                llama_cpp.GGML_NUMA_STRATEGY_DISTRIBUTE
+                if numa
+                else llama_cpp.GGML_NUMA_STRATEGY_DISABLED
+            )
         else:
             self.numa = numa
 
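Example (editor's note, not part of the diff): a minimal sketch of the boolean-to-enum mapping restyled above; the model path is a placeholder.

from llama_cpp import Llama

# numa=True maps to GGML_NUMA_STRATEGY_DISTRIBUTE and numa=False to
# GGML_NUMA_STRATEGY_DISABLED; an explicit GGML_NUMA_STRATEGY_* value is
# stored unchanged by the else branch above. Hypothetical model path.
llm = Llama(model_path="./models/example-7b.gguf", numa=True)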
@@ -246,17 +248,17 @@ def __init__(
                 else:
                     raise ValueError(f"Unknown value type for {k}: {v}")
 
-            self._kv_overrides_array[
-                -1
-            ].key = b"\0"  # ensure sentinel element is zeroed
+            self._kv_overrides_array[-1].key = (
+                b"\0"  # ensure sentinel element is zeroed
+            )
             self.model_params.kv_overrides = self._kv_overrides_array
 
         self.n_batch = min(n_ctx, n_batch)  # ???
         self.n_threads = n_threads or max(multiprocessing.cpu_count() // 2, 1)
         self.n_threads_batch = n_threads_batch or max(
             multiprocessing.cpu_count() // 2, 1
         )
-
+
         # Context Params
         self.context_params = llama_cpp.llama_context_default_params()
         self.context_params.seed = seed
@@ -289,7 +291,9 @@ def __init__(
         )
         self.context_params.yarn_orig_ctx = yarn_orig_ctx if yarn_orig_ctx != 0 else 0
         self.context_params.mul_mat_q = mul_mat_q
-        self.context_params.logits_all = logits_all if draft_model is None else True  # Must be set to True for speculative decoding
+        self.context_params.logits_all = (
+            logits_all if draft_model is None else True
+        )  # Must be set to True for speculative decoding
         self.context_params.embedding = embedding
         self.context_params.offload_kqv = offload_kqv
 
@@ -379,8 +383,14 @@ def __init__(
         if self.verbose:
             print(f"Model metadata: {self.metadata}", file=sys.stderr)
 
-        if self.chat_format is None and self.chat_handler is None and "tokenizer.chat_template" in self.metadata:
-            chat_format = llama_chat_format.guess_chat_format_from_gguf_metadata(self.metadata)
+        if (
+            self.chat_format is None
+            and self.chat_handler is None
+            and "tokenizer.chat_template" in self.metadata
+        ):
+            chat_format = llama_chat_format.guess_chat_format_from_gguf_metadata(
+                self.metadata
+            )
 
         if chat_format is not None:
             self.chat_format = chat_format
@@ -406,9 +416,7 @@ def __init__(
                 print(f"Using chat bos_token: {bos_token}", file=sys.stderr)
 
             self.chat_handler = llama_chat_format.Jinja2ChatFormatter(
-                template=template,
-                eos_token=eos_token,
-                bos_token=bos_token
+                template=template, eos_token=eos_token, bos_token=bos_token
             ).to_chat_handler()
 
         if self.chat_format is None and self.chat_handler is None:
@@ -459,7 +467,9 @@ def tokenize(
         """
         return self.tokenizer_.tokenize(text, add_bos, special)
 
-    def detokenize(self, tokens: List[int], prev_tokens: Optional[List[int]] = None) -> bytes:
+    def detokenize(
+        self, tokens: List[int], prev_tokens: Optional[List[int]] = None
+    ) -> bytes:
         """Detokenize a list of tokens.
 
         Args:
@@ -565,7 +575,7 @@ def sample(
             logits[:] = (
                 logits_processor(self._input_ids, logits)
                 if idx is None
-                else logits_processor(self._input_ids[:idx + 1], logits)
+                else logits_processor(self._input_ids[: idx + 1], logits)
             )
 
         sampling_params = _LlamaSamplingParams(
@@ -707,7 +717,9 @@ def generate(
 
             if self.draft_model is not None:
                 self.input_ids[self.n_tokens : self.n_tokens + len(tokens)] = tokens
-                draft_tokens = self.draft_model(self.input_ids[:self.n_tokens + len(tokens)])
+                draft_tokens = self.draft_model(
+                    self.input_ids[: self.n_tokens + len(tokens)]
+                )
                 tokens.extend(
                     draft_tokens.astype(int)[
                         : self._n_ctx - self.n_tokens - len(tokens)
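Example (editor's note, not part of the diff): a hedged sketch of supplying a draft model for the speculative-decoding path reshaped above; the model path is a placeholder, and logits_all is forced to True internally whenever a draft model is set (see the earlier logits_all hunk).

from llama_cpp import Llama
from llama_cpp.llama_speculative import LlamaPromptLookupDecoding

# Hypothetical model path; LlamaPromptLookupDecoding is the bundled draft
# model that generate() calls on the current input_ids, as in the hunk above.
llm = Llama(
    model_path="./models/example-7b.gguf",
    draft_model=LlamaPromptLookupDecoding(num_pred_tokens=10),
)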
@@ -792,6 +804,7 @@ def embed(
 
         # decode and fetch embeddings
         data: List[List[float]] = []
+
        def decode_batch(n_seq: int):
             assert self._ctx.ctx is not None
             llama_cpp.llama_kv_cache_clear(self._ctx.ctx)
@@ -800,9 +813,9 @@ def decode_batch(n_seq: int):
 
             # store embeddings
             for i in range(n_seq):
-                embedding: List[float] = llama_cpp.llama_get_embeddings_ith(self._ctx.ctx, i)[
-                    :n_embd
-                ]
+                embedding: List[float] = llama_cpp.llama_get_embeddings_ith(
+                    self._ctx.ctx, i
+                )[:n_embd]
                 if normalize:
                     norm = float(np.linalg.norm(embedding))
                     embedding = [v / norm for v in embedding]
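Example (editor's note, not part of the diff): a minimal sketch of the embed API whose batch-decoding helper is reformatted above; the model path is a placeholder and an embedding-capable GGUF model is assumed.

from llama_cpp import Llama

# Hypothetical model path; embedding=True enables the embedding context.
llm = Llama(model_path="./models/example-embedding.gguf", embedding=True)

# normalize=True applies the np.linalg.norm scaling shown in the hunk above.
vectors = llm.embed(["hello world", "goodbye world"], normalize=True)
print(len(vectors), len(vectors[0]))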
@@ -1669,12 +1682,13 @@ def create_chat_completion_openai_v1(
         """
         try:
             from openai.types.chat import ChatCompletion, ChatCompletionChunk
-            stream = kwargs.get("stream", False) # type: ignore
+
+            stream = kwargs.get("stream", False)  # type: ignore
             assert isinstance(stream, bool)
             if stream:
-                return (ChatCompletionChunk(**chunk) for chunk in self.create_chat_completion(*args, **kwargs)) # type: ignore
+                return (ChatCompletionChunk(**chunk) for chunk in self.create_chat_completion(*args, **kwargs))  # type: ignore
             else:
-                return ChatCompletion(**self.create_chat_completion(*args, **kwargs)) # type: ignore
+                return ChatCompletion(**self.create_chat_completion(*args, **kwargs))  # type: ignore
         except ImportError:
             raise ImportError(
                 "To use create_chat_completion_openai_v1, you must install the openai package."
@@ -1866,7 +1880,88 @@ def longest_token_prefix(a: Sequence[int], b: Sequence[int]):
                 break
         return longest_prefix
 
+    @classmethod
+    def from_pretrained(
+        cls,
+        repo_id: str,
+        filename: Optional[str],
+        local_dir: Optional[Union[str, os.PathLike[str]]] = ".",
+        local_dir_use_symlinks: Union[bool, Literal["auto"]] = "auto",
+        **kwargs: Any,
+    ) -> "Llama":
+        """Create a Llama model from a pretrained model name or path.
+        This method requires the huggingface-hub package.
+        You can install it with `pip install huggingface-hub`.
+
+        Args:
+            repo_id: The model repo id.
+            filename: A filename or glob pattern to match the model file in the repo.
+            local_dir: The local directory to save the model to.
+            local_dir_use_symlinks: Whether to use symlinks when downloading the model.
+            **kwargs: Additional keyword arguments to pass to the Llama constructor.
+
+        Returns:
+            A Llama model."""
+        try:
+            from huggingface_hub import hf_hub_download, HfFileSystem
+            from huggingface_hub.utils import validate_repo_id
+        except ImportError:
+            raise ImportError(
+                "Llama.from_pretrained requires the huggingface-hub package. "
+                "You can install it with `pip install huggingface-hub`."
+            )
+
+        validate_repo_id(repo_id)
+
+        hffs = HfFileSystem()
+
+        files = [
+            file["name"] if isinstance(file, dict) else file
+            for file in hffs.ls(repo_id)
+        ]
+
+        # split each file into repo_id, subfolder, filename
+        file_list: List[str] = []
+        for file in files:
+            rel_path = Path(file).relative_to(repo_id)
+            file_list.append(str(rel_path))
 
+        matching_files = [file for file in file_list if fnmatch.fnmatch(file, filename)]  # type: ignore
+
+        if len(matching_files) == 0:
+            raise ValueError(
+                f"No file found in {repo_id} that match {filename}\n\n"
+                f"Available Files:\n{json.dumps(file_list)}"
+            )
+
+        if len(matching_files) > 1:
+            raise ValueError(
+                f"Multiple files found in {repo_id} matching {filename}\n\n"
+                f"Available Files:\n{json.dumps(files)}"
+            )
+
+        (matching_file,) = matching_files
+
+        subfolder = str(Path(matching_file).parent)
+        filename = Path(matching_file).name
+
+        local_dir = "."
+
+        # download the file
+        hf_hub_download(
+            repo_id=repo_id,
+            local_dir=local_dir,
+            filename=filename,
+            subfolder=subfolder,
+            local_dir_use_symlinks=local_dir_use_symlinks,
+        )
+
+        model_path = os.path.join(local_dir, filename)
+
+        return cls(
+            model_path=model_path,
+            **kwargs,
+        )
 
 
 class LlamaState:
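Example (editor's note, not part of the diff): a minimal usage sketch of the new from_pretrained classmethod added above; the repo id and glob are illustrative placeholders and huggingface-hub must be installed.

from llama_cpp import Llama

# Hypothetical repo id and pattern; `filename` is matched against the repo's
# file list with fnmatch, and exactly one file must match.
llm = Llama.from_pretrained(
    repo_id="your-org/your-model-GGUF",
    filename="*Q4_K_M.gguf",
    verbose=False,
)
print(llm.model_path)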