feat: Pull models directly from huggingface (#1206) · sunnykim1206/llama-cpp-python@0f8aa4a

Commit 0f8aa4a

feat: Pull models directly from huggingface (abetlen#1206)

* Add from_pretrained method to Llama class
* Update docs
* Merge filename and pattern
1 parent e42f62c commit 0f8aa4a
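
At a glance, the API this commit introduces: `Llama.from_pretrained` downloads a matching GGUF file from a Hugging Face repo and constructs a `Llama` instance from it. The sketch below reuses the repo id and glob pattern from the README example added in this commit; the chat-completion call at the end uses the pre-existing `create_chat_completion` API and is illustrative only, not part of the change.

```python
from llama_cpp import Llama

# Download a quantized GGUF file from the Hub and load it (new in this commit).
# "*q8_0.gguf" is a glob: it must match exactly one file in the repo.
llama = Llama.from_pretrained(
    repo_id="Qwen/Qwen1.5-0.5B-Chat-GGUF",
    filename="*q8_0.gguf",
    verbose=False,
)

# The returned object is an ordinary Llama instance; this chat call is illustrative.
response = llama.create_chat_completion(
    messages=[{"role": "user", "content": "Name the planets in the solar system."}],
    max_tokens=64,
)
print(response["choices"][0]["message"]["content"])
```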

3 files changed: 136 additions, 27 deletions

README.md

Lines changed: 13 additions & 0 deletions
```diff
@@ -212,6 +212,19 @@ Below is a short example demonstrating how to use the high-level API to for basi
 
 Text completion is available through the [`__call__`](https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.__call__) and [`create_completion`](https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.create_completion) methods of the [`Llama`](https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama) class.
 
+## Pulling models from Hugging Face
+
+You can pull `Llama` models from Hugging Face using the `from_pretrained` method.
+You'll need to install the `huggingface-hub` package to use this feature (`pip install huggingface-hub`).
+
+```python
+llama = Llama.from_pretrained(
+    repo_id="Qwen/Qwen1.5-0.5B-Chat-GGUF",
+    filename="*q8_0.gguf",
+    verbose=False
+)
+```
+
 ### Chat Completion
 
 The high-level API also provides a simple interface for chat completion.
```
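
The README snippet stops once the model is loaded. As a rough continuation (not part of the commit), the returned object behaves like any other `Llama` instance, so the `__call__` text-completion API documented just above applies directly; the prompt and sampling values below are illustrative:

```python
# Continues the README example above: `llama` is the model pulled from the Hub.
output = llama(
    "Q: Name the planets in the solar system. A: ",  # illustrative prompt
    max_tokens=32,          # cap the length of the completion
    stop=["Q:", "\n"],      # stop at the next question or newline
    echo=True,              # include the prompt in the returned text
)
print(output["choices"][0]["text"])
```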

docs/api-reference.md

Lines changed: 1 addition & 0 deletions
```diff
@@ -26,6 +26,7 @@ High-level Python bindings for llama.cpp.
             - load_state
             - token_bos
             - token_eos
+            - from_pretrained
         show_root_heading: true
 
 ::: llama_cpp.LlamaGrammar
```

llama_cpp/llama.py

Lines changed: 122 additions & 27 deletions
```diff
@@ -4,6 +4,8 @@
 import sys
 import uuid
 import time
+import json
+import fnmatch
 import multiprocessing
 from typing import (
     List,
@@ -16,6 +18,7 @@
     Callable,
 )
 from collections import deque
+from pathlib import Path
 
 import ctypes
 
@@ -29,10 +32,7 @@
     LlamaDiskCache,  # type: ignore
     LlamaRAMCache,  # type: ignore
 )
-from .llama_tokenizer import (
-    BaseLlamaTokenizer,
-    LlamaTokenizer
-)
+from .llama_tokenizer import BaseLlamaTokenizer, LlamaTokenizer
 import llama_cpp.llama_cpp as llama_cpp
 import llama_cpp.llama_chat_format as llama_chat_format
 
@@ -50,9 +50,7 @@
     _LlamaSamplingContext,  # type: ignore
 )
 from ._logger import set_verbose
-from ._utils import (
-    suppress_stdout_stderr
-)
+from ._utils import suppress_stdout_stderr
 
 
 class Llama:
@@ -189,7 +187,11 @@ def __init__(
         Llama.__backend_initialized = True
 
         if isinstance(numa, bool):
-            self.numa = llama_cpp.GGML_NUMA_STRATEGY_DISTRIBUTE if numa else llama_cpp.GGML_NUMA_STRATEGY_DISABLED
+            self.numa = (
+                llama_cpp.GGML_NUMA_STRATEGY_DISTRIBUTE
+                if numa
+                else llama_cpp.GGML_NUMA_STRATEGY_DISABLED
+            )
         else:
             self.numa = numa
 
@@ -246,17 +248,17 @@ def __init__(
                 else:
                     raise ValueError(f"Unknown value type for {k}: {v}")
 
-            self._kv_overrides_array[
-                -1
-            ].key = b"\0"  # ensure sentinel element is zeroed
+            self._kv_overrides_array[-1].key = (
+                b"\0"  # ensure sentinel element is zeroed
+            )
             self.model_params.kv_overrides = self._kv_overrides_array
 
         self.n_batch = min(n_ctx, n_batch)  # ???
         self.n_threads = n_threads or max(multiprocessing.cpu_count() // 2, 1)
         self.n_threads_batch = n_threads_batch or max(
             multiprocessing.cpu_count() // 2, 1
         )
-
+
         # Context Params
         self.context_params = llama_cpp.llama_context_default_params()
         self.context_params.seed = seed
@@ -289,7 +291,9 @@ def __init__(
         )
         self.context_params.yarn_orig_ctx = yarn_orig_ctx if yarn_orig_ctx != 0 else 0
         self.context_params.mul_mat_q = mul_mat_q
-        self.context_params.logits_all = logits_all if draft_model is None else True # Must be set to True for speculative decoding
+        self.context_params.logits_all = (
+            logits_all if draft_model is None else True
+        )  # Must be set to True for speculative decoding
         self.context_params.embedding = embedding
         self.context_params.offload_kqv = offload_kqv
 
@@ -379,8 +383,14 @@ def __init__(
         if self.verbose:
             print(f"Model metadata: {self.metadata}", file=sys.stderr)
 
-        if self.chat_format is None and self.chat_handler is None and "tokenizer.chat_template" in self.metadata:
-            chat_format = llama_chat_format.guess_chat_format_from_gguf_metadata(self.metadata)
+        if (
+            self.chat_format is None
+            and self.chat_handler is None
+            and "tokenizer.chat_template" in self.metadata
+        ):
+            chat_format = llama_chat_format.guess_chat_format_from_gguf_metadata(
+                self.metadata
+            )
 
             if chat_format is not None:
                 self.chat_format = chat_format
@@ -406,9 +416,7 @@ def __init__(
                     print(f"Using chat bos_token: {bos_token}", file=sys.stderr)
 
                 self.chat_handler = llama_chat_format.Jinja2ChatFormatter(
-                    template=template,
-                    eos_token=eos_token,
-                    bos_token=bos_token
+                    template=template, eos_token=eos_token, bos_token=bos_token
                 ).to_chat_handler()
 
         if self.chat_format is None and self.chat_handler is None:
@@ -459,7 +467,9 @@ def tokenize(
         """
         return self.tokenizer_.tokenize(text, add_bos, special)
 
-    def detokenize(self, tokens: List[int], prev_tokens: Optional[List[int]] = None) -> bytes:
+    def detokenize(
+        self, tokens: List[int], prev_tokens: Optional[List[int]] = None
+    ) -> bytes:
         """Detokenize a list of tokens.
 
         Args:
@@ -565,7 +575,7 @@ def sample(
             logits[:] = (
                 logits_processor(self._input_ids, logits)
                 if idx is None
-                else logits_processor(self._input_ids[:idx + 1], logits)
+                else logits_processor(self._input_ids[: idx + 1], logits)
             )
 
         sampling_params = _LlamaSamplingParams(
@@ -707,7 +717,9 @@ def generate(
 
         if self.draft_model is not None:
             self.input_ids[self.n_tokens : self.n_tokens + len(tokens)] = tokens
-            draft_tokens = self.draft_model(self.input_ids[:self.n_tokens + len(tokens)])
+            draft_tokens = self.draft_model(
+                self.input_ids[: self.n_tokens + len(tokens)]
+            )
             tokens.extend(
                 draft_tokens.astype(int)[
                     : self._n_ctx - self.n_tokens - len(tokens)
@@ -792,6 +804,7 @@ def embed(
 
         # decode and fetch embeddings
         data: List[List[float]] = []
+
        def decode_batch(n_seq: int):
            assert self._ctx.ctx is not None
            llama_cpp.llama_kv_cache_clear(self._ctx.ctx)
@@ -800,9 +813,9 @@ def decode_batch(n_seq: int):
 
            # store embeddings
            for i in range(n_seq):
-                embedding: List[float] = llama_cpp.llama_get_embeddings_ith(self._ctx.ctx, i)[
-                    :n_embd
-                ]
+                embedding: List[float] = llama_cpp.llama_get_embeddings_ith(
+                    self._ctx.ctx, i
+                )[:n_embd]
                 if normalize:
                     norm = float(np.linalg.norm(embedding))
                     embedding = [v / norm for v in embedding]
@@ -1669,12 +1682,13 @@ def create_chat_completion_openai_v1(
         """
         try:
             from openai.types.chat import ChatCompletion, ChatCompletionChunk
-            stream = kwargs.get("stream", False) # type: ignore
+
+            stream = kwargs.get("stream", False)  # type: ignore
             assert isinstance(stream, bool)
             if stream:
-                return (ChatCompletionChunk(**chunk) for chunk in self.create_chat_completion(*args, **kwargs)) # type: ignore
+                return (ChatCompletionChunk(**chunk) for chunk in self.create_chat_completion(*args, **kwargs))  # type: ignore
             else:
-                return ChatCompletion(**self.create_chat_completion(*args, **kwargs)) # type: ignore
+                return ChatCompletion(**self.create_chat_completion(*args, **kwargs))  # type: ignore
         except ImportError:
             raise ImportError(
                 "To use create_chat_completion_openai_v1, you must install the openai package."
@@ -1866,7 +1880,88 @@ def longest_token_prefix(a: Sequence[int], b: Sequence[int]):
                 break
         return longest_prefix
 
+    @classmethod
+    def from_pretrained(
+        cls,
+        repo_id: str,
+        filename: Optional[str],
+        local_dir: Optional[Union[str, os.PathLike[str]]] = ".",
+        local_dir_use_symlinks: Union[bool, Literal["auto"]] = "auto",
+        **kwargs: Any,
+    ) -> "Llama":
+        """Create a Llama model from a pretrained model name or path.
+        This method requires the huggingface-hub package.
+        You can install it with `pip install huggingface-hub`.
+
+        Args:
+            repo_id: The model repo id.
+            filename: A filename or glob pattern to match the model file in the repo.
+            local_dir: The local directory to save the model to.
+            local_dir_use_symlinks: Whether to use symlinks when downloading the model.
+            **kwargs: Additional keyword arguments to pass to the Llama constructor.
+
+        Returns:
+            A Llama model."""
+        try:
+            from huggingface_hub import hf_hub_download, HfFileSystem
+            from huggingface_hub.utils import validate_repo_id
+        except ImportError:
+            raise ImportError(
+                "Llama.from_pretrained requires the huggingface-hub package. "
+                "You can install it with `pip install huggingface-hub`."
+            )
+
+        validate_repo_id(repo_id)
+
+        hffs = HfFileSystem()
+
+        files = [
+            file["name"] if isinstance(file, dict) else file
+            for file in hffs.ls(repo_id)
+        ]
+
+        # split each file into repo_id, subfolder, filename
+        file_list: List[str] = []
+        for file in files:
+            rel_path = Path(file).relative_to(repo_id)
+            file_list.append(str(rel_path))
 
+        matching_files = [file for file in file_list if fnmatch.fnmatch(file, filename)]  # type: ignore
+
+        if len(matching_files) == 0:
+            raise ValueError(
+                f"No file found in {repo_id} that match {filename}\n\n"
+                f"Available Files:\n{json.dumps(file_list)}"
+            )
+
+        if len(matching_files) > 1:
+            raise ValueError(
+                f"Multiple files found in {repo_id} matching {filename}\n\n"
+                f"Available Files:\n{json.dumps(files)}"
+            )
+
+        (matching_file,) = matching_files
+
+        subfolder = str(Path(matching_file).parent)
+        filename = Path(matching_file).name
+
+        local_dir = "."
+
+        # download the file
+        hf_hub_download(
+            repo_id=repo_id,
+            local_dir=local_dir,
+            filename=filename,
+            subfolder=subfolder,
+            local_dir_use_symlinks=local_dir_use_symlinks,
+        )
+
+        model_path = os.path.join(local_dir, filename)
+
+        return cls(
+            model_path=model_path,
+            **kwargs,
+        )
 
 
 class LlamaState:
```
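
The "Merge filename and pattern" item in the commit message shows up above as a single `filename` argument that is treated as an `fnmatch`-style glob: the repo listing from `HfFileSystem.ls` is filtered with `fnmatch.fnmatch`, and exactly one match is required. A small self-contained sketch of that matching step, with a made-up file list for illustration:

```python
import fnmatch
from typing import List

# Hypothetical repo listing; from_pretrained builds this via HfFileSystem.ls(repo_id).
file_list: List[str] = [
    "qwen1_5-0_5b-chat-q2_k.gguf",
    "qwen1_5-0_5b-chat-q8_0.gguf",
    "README.md",
]

# An exact file name and a glob pattern go through the same fnmatch filter.
for pattern in ("qwen1_5-0_5b-chat-q8_0.gguf", "*q8_0.gguf"):
    matches = [f for f in file_list if fnmatch.fnmatch(f, pattern)]
    assert len(matches) == 1, f"{pattern!r} must match exactly one file"
    print(pattern, "->", matches[0])
```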
