8000 Bugfix: cross-platform method to find shared lib · coderonion/llama-cpp-python@4da5faa · GitHub
[go: up one dir, main page]

Skip to content

Commit 4da5faa

Browse files
committed
Bugfix: cross-platform method to find shared lib
1 parent b936756 commit 4da5faa

File tree

1 file changed

+58
-60
lines changed

1 file changed

+58
-60
lines changed

llama_cpp/llama_cpp.py

Lines changed: 58 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,15 @@
1212
)
1313

1414
import pathlib
15+
from itertools import chain
1516

1617
# Load the library
17-
libfile = pathlib.Path(__file__).parent / "libllama.so"
18-
lib = ctypes.CDLL(str(libfile))
19-
18+
# TODO: fragile, should fix
19+
_base_path = pathlib.Path(__file__).parent
20+
(_lib_path,) = chain(
21+
_base_path.glob("*.so"), _base_path.glob("*.dylib"), _base_path.glob("*.dll")
22+
)
23+
_lib = ctypes.CDLL(str(_lib_path))
2024

2125
# C types
2226
llama_context_p = c_void_p
@@ -60,12 +64,12 @@ class llama_context_params(Structure):
6064

6165

6266
def llama_context_default_params() -> llama_context_params:
63-
params = lib.llama_context_default_params()
67+
params = _lib.llama_context_default_params()
6468
return params
6569

6670

67-
lib.llama_context_default_params.argtypes = []
68-
lib.llama_context_default_params.restype = llama_context_params
71+
_lib.llama_context_default_params.argtypes = []
72+
_lib.llama_context_default_params.restype = llama_context_params
6973

7074

7175
# Various functions for loading a ggml llama model.
@@ -74,32 +78,32 @@ def llama_context_default_params() -> llama_context_params:
7478
def llama_init_from_file(
7579
path_model: bytes, params: llama_context_params
7680
) -> llama_context_p:
77-
return lib.llama_init_from_file(path_model, params)
81+
return _lib.llama_init_from_file(path_model, params)
7882

7983

80-
lib.llama_init_from_file.argtypes = [c_char_p, llama_context_params]
81-
lib.llama_init_from_file.restype = llama_context_p
84+
_lib.llama_init_from_file.argtypes = [c_char_p, llama_context_params]
85+
_lib.llama_init_from_file.restype = llama_context_p
8286

8387

8488
# Frees all allocated memory
8589
def llama_free(ctx: llama_context_p):
86-
lib.llama_free(ctx)
90+
_lib.llama_free(ctx)
8791

8892

89-
lib.llama_free.argtypes = [llama_context_p]
90-
lib.llama_free.restype = None
93+
_lib.llama_free.argtypes = [llama_context_p]
94+
_lib.llama_free.restype = None
9195

9296

9397
# TODO: not great API - very likely to change
9498
# Returns 0 on success
9599
def llama_model_quantize(
96100
fname_inp: bytes, fname_out: bytes, itype: c_int, qk: c_int
97101
) -> c_int:
98-
return lib.llama_model_quantize(fname_inp, fname_out, itype, qk)
102+
return _lib.llama_model_quantize(fname_inp, fname_out, itype, qk)
99103

100104

101-
lib.llama_model_quantize.argtypes = [c_char_p, c_char_p, c_int, c_int]
102-
lib.llama_model_quantize.restype = c_int
105+
_lib.llama_model_quantize.argtypes = [c_char_p, c_char_p, c_int, c_int]
106+
_lib.llama_model_quantize.restype = c_int
103107

104108

105109
# Run the llama inference to obtain the logits and probabilities for the next token.
@@ -113,11 +117,11 @@ def llama_eval(
113117
n_past: c_int,
114118
n_threads: c_int,
115119
) -> c_int:
116-
return lib.llama_eval(ctx, tokens, n_tokens, n_past, n_threads)
120+
return _lib.llama_eval(ctx, tokens, n_tokens, n_past, n_threads)
117121

118122

119-
lib.llama_eval.arg 6D4E types = [llama_context_p, llama_token_p, c_int, c_int, c_int]
120-
lib.llama_eval.restype = c_int
123+
_lib.llama_eval.argtypes = [llama_context_p, llama_token_p, c_int, c_int, c_int]
124+
_lib.llama_eval.restype = c_int
121125

122126

123127
# Convert the provided text into tokens.
@@ -132,32 +136,27 @@ def llama_tokenize(
132136
n_max_tokens: c_int,
133137
add_bos: c_bool,
134138
) -> c_int:
135-
"""Convert the provided text into tokens.
136-
The tokens pointer must be large enough to hold the resulting tokens.
137-
Returns the number of tokens on success, no more than n_max_tokens
138-
Returns a negative number on failure - the number of tokens that would have been returned
139-
"""
140-
return lib.llama_tokenize(ctx, text, tokens, n_max_tokens, add_bos)
139+
return _lib.llama_tokenize(ctx, text, tokens, n_max_tokens, add_bos)
141140

142141

143-
lib.llama_tokenize.argtypes = [llama_context_p, c_char_p, llama_token_p, c_int, c_bool]
144-
lib.llama_tokenize.restype = c_int
142+
_lib.llama_tokenize.argtypes = [llama_context_p, c_char_p, llama_token_p, c_int, c_bool]
143+
_lib.llama_tokenize.restype = c_int
145144

146145

147146
def llama_n_vocab(ctx: llama_context_p) -> c_int:
148-
return lib.llama_n_vocab(ctx)
147+
return _lib.llama_n_vocab(ctx)
149148

150149

151-
lib.llama_n_vocab.argtypes = [llama_context_p]
152-
lib.llama_n_vocab.restype = c_int
150+
_lib.llama_n_vocab.argtypes = [llama_context_p]
151+
_lib.llama_n_vocab.restype = c_int
153152

154153

155154
def llama_n_ctx(ctx: llama_context_p) -> c_int:
156-
return lib.llama_n_ctx(ctx)
155+
return _lib.llama_n_ctx(ctx)
157156

158157

159-
lib.llama_n_ctx.argtypes = [llama_context_p]
160-
lib.llama_n_ctx.restype = c_int
158+
_lib.llama_n_ctx.argtypes = [llama_context_p]
159+
_lib.llama_n_ctx.restype = c_int
161160

162161

163162
# Token logits obtained from the last call to llama_eval()
@@ -166,48 +165,48 @@ def llama_n_ctx(ctx: llama_context_p) -> c_int:
166165
# Rows: n_tokens
167166
# Cols: n_vocab
168167
def llama_get_logits(ctx: llama_context_p):
169-
return lib.llama_get_logits(ctx)
168+
return _lib.llama_get_logits(ctx)
170169

171170

172-
lib.llama_get_logits.argtypes = [llama_context_p]
173-
lib.llama_get_logits.restype = POINTER(c_float)
171+
_lib.llama_get_logits.argtypes = [llama_context_p]
172+
_lib.llama_get_logits.restype = POINTER(c_float)
174173

175174

176175
# Get the embeddings for the input
177176
# shape: [n_embd] (1-dimensional)
178177
def llama_get_embeddings(ctx: llama_context_p):
179-
return lib.llama_get_embeddings(ctx)
178+
return _lib.llama_get_embeddings(ctx)
180179

181180

182-
lib.llama_get_embeddings.argtypes = [llama_context_p]
183-
lib.llama_get_embeddings.restype = POINTER(c_float)
181+
_lib.llama_get_embeddings.argtypes = [llama_context_p]
182+
_lib.llama_get_embeddings.restype = POINTER(c_float)
184183

185184

186185
# Token Id -> String. Uses the vocabulary in the provided context
187186
def llama_token_to_str(ctx: llama_context_p, token: int) -> bytes:
188-
return lib.llama_token_to_str(ctx, token)
187+
return _lib.llama_token_to_str(ctx, token)
189188

190189

191-
lib.llama_token_to_str.argtypes = [llama_context_p, llama_token]
192-
lib.llama_token_to_str.restype = c_char_p
190+
_lib.llama_token_to_str.argtypes = [llama_context_p, llama_token]
191+
_lib.llama_token_to_str.restype = c_char_p
193192

194193
# Special tokens
195194

196195

197196
def llama_token_bos() -> llama_token:
198-
return lib.llama_token_bos()
197+
return _lib.llama_token_bos()
199198

200199

201-
lib.llama_token_bos.argtypes = []
202-
lib.llama_token_bos.restype = llama_token
200+
_lib.llama_token_bos.argtypes = []
201+
_lib.llama_token_bos.restype = llama_token
203202

204203

205204
def llama_token_eos() -> llama_token:
206-
return lib.llama_token_eos()
205+
return _lib.llama_token_eos()
207206

208207

209-
lib.llama_token_eos.argtypes = []
210-
lib.llama_token_eos.restype = llama_token
208+
_lib.llama_token_eos.argtypes = []
209+
_lib.llama_token_eos.restype = llama_token
211210

212211

213212
# TODO: improve the last_n_tokens interface ?
@@ -220,12 +219,12 @@ def llama_sample_top_p_top_k(
220219
temp: c_double,
221220
repeat_penalty: c_double,
222221
) -> llama_token:
223-
return lib.llama_sample_top_p_top_k(
222+
return _lib.llama_sample_top_p_top_k(
224223
ctx, last_n_tokens_data, last_n_tokens_size, top_k, top_p, temp, repeat_penalty
225224
)
226225

227226

228-
lib.llama_sample_top_p_top_k.argtypes = [
227+
_lib.llama_sample_top_p_top_k.argtypes = [
229228
llama_context_p,
230229
llama_token_p,
231230
c_int,
@@ -234,33 +233,32 @@ def llama_sample_top_p_top_k(
234233
c_double,
235234
c_double,
236235
]
237-
lib.llama_sample_top_p_top_k.restype = llama_token
236+
_lib.llama_sample_top_p_top_k.restype = llama_token
238237

239238

240239
# Performance information
241240

242241

243242
def llama_print_timings(ctx: llama_context_p):
244-
lib.llama_print_timings(ctx)
243+
_lib.llama_print_timings(ctx)
245244

246245

247-
lib.llama_print_timings.argtypes = [llama_context_p]
248-
lib.llama_print_timings.restype = None
246+
_lib.llama_print_timings.argtypes = [llama_context_p]
247+
_lib.llama_print_timings.restype = None
249248

250249

251250
def llama_reset_timings(ctx: llama_context_p):
252-
lib.llama_reset_timings(ctx)
251+
_lib.llama_reset_timings(ctx)
253252

254253

255-
lib.llama_reset_timings.argtypes = [llama_context_p]
256-
lib.llama_reset_timings.restype = None
254+
_lib.llama_reset_timings.argtypes = [llama_context_p]
255+
_lib.llama_reset_timings.restype = None
257256

258257

259258
# Print system information
260259
def llama_print_system_info() -> bytes:
261-
"""Print system informaiton"""
262-
return lib.llama_print_system_info()
260+
return _lib.llama_print_system_info()
263261

264262

265-
lib.llama_print_system_info.argtypes = []
266-
lib.llama_print_system_info.restype = c_char_p
263+
_lib.llama_print_system_info.argtypes = []
264+
_lib.llama_print_system_info.restype = c_char_p

0 commit comments

Comments
 (0)
0