8000 Update llama_cpp.py · coderonion/llama-cpp-python@a40476e · GitHub
[go: up one dir, main page]

Skip to content

Commit a40476e

Browse files
Update llama_cpp.py
Make shared library code more robust with some platform specific functionality and more descriptive errors when failures occur
1 parent b9a4513 commit a40476e

File tree

1 file changed

+41
-9
lines changed

1 file changed

+41
-9
lines changed

llama_cpp/llama_cpp.py

Lines changed: 41 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,49 @@
1+
import sys
2+
import os
13
import ctypes
2-
34
from ctypes import c_int, c_float, c_char_p, c_void_p, c_bool, POINTER, Structure, Array, c_uint8, c_size_t
4-
55
import pathlib
6-
from itertools import chain
76

87
# Load the library
9-
# TODO: fragile, should fix
10-
_base_path = pathlib.Path(__file__).parent
11-
(_lib_path,) = chain(
12-
_base_path.glob("*.so"), _base_path.glob("*.dylib"), _base_path.glob("*.dll")
13-
)
14-
_lib = ctypes.CDLL(str(_lib_path))
8+
def load_shared_library(lib_base_name):
9+
# Determine the file extension based on the platform
10+
if sys.platform.startswith("linux"):
11+
lib_ext = ".so"
12+
elif sys.platform == "darwin":
13+
lib_ext = ".dylib"
14+
elif sys.platform == "win32":
15+
lib_ext = ".dll"
16+
else:
17+
raise RuntimeError("Unsupported platform")
18+
19+
# Construct the paths to the possible shared library names
20+
_base_path = pathlib.Path(__file__).parent.resolve()
21+
# Searching for the library in the current directory under the name "libllama" (default name
22+
# for llamacpp) and "llama" (default name for this repo)
23+
_lib_paths = [
24+
_base_path / f"lib{lib_base_name}{lib_ext}",
25+
_base_path / f"{lib_base_name}{lib_ext}"
26+
]
27+
28+
# Add the library directory to the DLL search path on Windows (if needed)
29+
if sys.platform == "win32" and sys.version_info >= (3, 8):
30+
os.add_dll_directory(str(_base_path))
31+
32+
# Try to load the shared library, handling potential errors
33+
for _lib_path in _lib_paths:
34+
if _lib_path.exists():
35+
try:
36+
return ctypes.CDLL(str(_lib_path))
37+
except Exception as e:
38+
raise RuntimeError(f"Failed to load shared library '{_lib_path}': {e}")
39+
40+
raise FileNotFoundError(f"Shared library with base name '{lib_base_name}' not found")
41+
42+
# Specify the base name of the shared library to load
43+
lib_base_name = "llama"
44+
45+
# Load the library
46+
_lib = load_shared_library(lib_base_name)
1547

1648
# C types
1749
llama_context_p = c_void_p

0 commit comments

Comments
 (0)
0