8000 fix: Update from_pretrained defaults to match hf_hub_download · ducky777/llama-cpp-python@e6d6260 · GitHub
[go: up one dir, main page]

Skip to content

Commit e6d6260

Browse files
committed
fix: Update from_pretrained defaults to match hf_hub_download
1 parent dd22010 commit e6d6260

File tree

1 file changed

+17
-5
lines changed

1 file changed

+17
-5
lines changed

llama_cpp/llama.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1885,8 +1885,9 @@ def from_pretrained(
18851885
cls,
18861886
repo_id: str,
18871887
filename: Optional[str],
1888-
local_dir: Optional[Union[str, os.PathLike[str]]] = ".",
1888+
local_dir: Optional[Union[str, os.PathLike[str]]] = None,
18891889
local_dir_use_symlinks: Union[bool, Literal["auto"]] = "auto",
1890+
cache_dir: Optional[Union[str, os.PathLike[str]]] = None,
18901891
**kwargs: Any,
18911892
) -> "Llama":
18921893
"""Create a Llama model from a pretrained model name or path.
@@ -1945,18 +1946,29 @@ def from_pretrained(
19451946
subfolder = str(Path(matching_file).parent)
19461947
filename = Path(matching_file).name
19471948

1948-
local_dir = "."
1949-
19501949
# download the file
19511950
hf_hub_download(
19521951
repo_id=repo_id,
1953-
local_dir=local_dir,
19541952
filename=filename,
19551953
subfolder=subfolder,
1954+
local_dir=local_dir,
19561955
local_dir_use_symlinks=local_dir_use_symlinks,
1956+
cache_dir=cache_dir,
19571957
)
19581958

1959-
model_path = os.path.join(local_dir, filename)
1959+
if local_dir is None:
1960+
model_path = hf_hub_download(
1961+
repo_id=repo_id,
1962+
filename=filename,
1963+
subfolder=subfolder,
1964+
local_dir=local_dir,
1965+
local_dir_use_symlinks=local_dir_use_symlinks,
1966+
cache_dir=cache_dir,
1967+
local_files_only=True,
1968+
1969+
)
1970+
else:
1971+
model_path = os.path.join(local_dir, filename)
19601972

19611973
return cls(
19621974
model_path=model_path,

0 commit comments

Comments
 (0)
0