12
12
)
13
13
14
14
import pathlib
15
+ from itertools import chain
15
16
16
17
# Load the library
17
- libfile = pathlib .Path (__file__ ).parent / "libllama.so"
18
- lib = ctypes .CDLL (str (libfile ))
19
-
18
+ # TODO: fragile, should fix
19
+ _base_path = pathlib .Path (__file__ ).parent
20
+ (_lib_path ,) = chain (
21
+ _base_path .glob ("*.so" ), _base_path .glob ("*.dylib" ), _base_path .glob ("*.dll" )
22
+ )
23
+ _lib = ctypes .CDLL (str (_lib_path ))
20
24
21
25
# C types
22
26
llama_context_p = c_void_p
@@ -60,12 +64,12 @@ class llama_context_params(Structure):
60
64
61
65
62
66
def llama_context_default_params () -> llama_context_params :
63
- params = lib .llama_context_default_params ()
67
+ params = _lib .llama_context_default_params ()
64
68
return params
65
69
66
70
67
- lib .llama_context_default_params .argtypes = []
68
- lib .llama_context_default_params .restype = llama_context_params
71
+ _lib .llama_context_default_params .argtypes = []
72
+ _lib .llama_context_default_params .restype = llama_context_params
69
73
70
74
71
75
# Various functions for loading a ggml llama model.
@@ -74,32 +78,32 @@ def llama_context_default_params() -> llama_context_params:
74
78
def llama_init_from_file (
75
79
path_model : bytes , params : llama_context_params
76
80
) -> llama_context_p :
77
- return lib .llama_init_from_file (path_model , params )
81
+ return _lib .llama_init_from_file (path_model , params )
78
82
79
83
80
- lib .llama_init_from_file .argtypes = [c_char_p , llama_context_params ]
81
- lib .llama_init_from_file .restype = llama_context_p
84
+ _lib .llama_init_from_file .argtypes = [c_char_p , llama_context_params ]
85
+ _lib .llama_init_from_file .restype = llama_context_p
82
86
83
87
84
88
# Frees all allocated memory
85
89
def llama_free (ctx : llama_context_p ):
86
- lib .llama_free (ctx )
90
+ _lib .llama_free (ctx )
87
91
88
92
89
- lib .llama_free .argtypes = [llama_context_p ]
90
- lib .llama_free .restype = None
93
+ _lib .llama_free .argtypes = [llama_context_p ]
94
+ _lib .llama_free .restype = None
91
95
92
96
93
97
# TODO: not great API - very likely to change
94
98
# Returns 0 on success
95
99
def llama_model_quantize (
96
100
fname_inp : bytes , fname_out : bytes , itype : c_int , qk : c_int
97
101
) -> c_int :
98
- return lib .llama_model_quantize (fname_inp , fname_out , itype , qk )
102
+ return _lib .llama_model_quantize (fname_inp , fname_out , itype , qk )
99
103
100
104
101
- lib .llama_model_quantize .argtypes = [c_char_p , c_char_p , c_int , c_int ]
102
- lib .llama_model_quantize .restype = c_int
105
+ _lib .llama_model_quantize .argtypes = [c_char_p , c_char_p , c_int , c_int ]
106
+ _lib .llama_model_quantize .restype = c_int
103
107
104
108
105
109
# Run the llama inference to obtain the logits and probabilities for the next token.
@@ -113,11 +117,11 @@ def llama_eval(
113
117
n_past : c_int ,
114
118
n_threads : c_int ,
115
119
) -> c_int :
116
- return lib .llama_eval (ctx , tokens , n_tokens , n_past , n_threads )
120
+ return _lib .llama_eval (ctx , tokens , n_tokens , n_past , n_threads )
117
121
118
122
119
- lib .llama_eval .arg
6D4E
types = [llama_context_p , llama_token_p , c_int , c_int , c_int ]
120
- lib .llama_eval .restype = c_int
123
+ _lib .llama_eval .argtypes = [llama_context_p , llama_token_p , c_int , c_int , c_int ]
124
+ _lib .llama_eval .restype = c_int
121
125
122
126
123
127
# Convert the provided text into tokens.
@@ -132,32 +136,27 @@ def llama_tokenize(
132
136
n_max_tokens : c_int ,
133
137
add_bos : c_bool ,
134
138
) -> c_int :
135
- """Convert the provided text into tokens.
136
- The tokens pointer must be large enough to hold the resulting tokens.
137
- Returns the number of tokens on success, no more than n_max_tokens
138
- Returns a negative number on failure - the number of tokens that would have been returned
139
- """
140
- return lib .llama_tokenize (ctx , text , tokens , n_max_tokens , add_bos )
139
+ return _lib .llama_tokenize (ctx , text , tokens , n_max_tokens , add_bos )
141
140
142
141
143
- lib .llama_tokenize .argtypes = [llama_context_p , c_char_p , llama_token_p , c_int , c_bool ]
144
- lib .llama_tokenize .restype = c_int
142
+ _lib .llama_tokenize .argtypes = [llama_context_p , c_char_p , llama_token_p , c_int , c_bool ]
143
+ _lib .llama_tokenize .restype = c_int
145
144
146
145
147
146
def llama_n_vocab (ctx : llama_context_p ) -> c_int :
148
- return lib .llama_n_vocab (ctx )
147
+ return _lib .llama_n_vocab (ctx )
149
148
150
149
151
- lib .llama_n_vocab .argtypes = [llama_context_p ]
152
- lib .llama_n_vocab .restype = c_int
150
+ _lib .llama_n_vocab .argtypes = [llama_context_p ]
151
+ _lib .llama_n_vocab .restype = c_int
153
152
154
153
155
154
def llama_n_ctx (ctx : llama_context_p ) -> c_int :
156
- return lib .llama_n_ctx (ctx )
155
+ return _lib .llama_n_ctx (ctx )
157
156
158
157
159
- lib .llama_n_ctx .argtypes = [llama_context_p ]
160
- lib .llama_n_ctx .restype = c_int
158
+ _lib .llama_n_ctx .argtypes = [llama_context_p ]
159
+ _lib .llama_n_ctx .restype = c_int
161
160
162
161
163
162
# Token logits obtained from the last call to llama_eval()
@@ -166,48 +165,48 @@ def llama_n_ctx(ctx: llama_context_p) -> c_int:
166
165
# Rows: n_tokens
167
166
# Cols: n_vocab
168
167
def llama_get_logits (ctx : llama_context_p ):
169
- return lib .llama_get_logits (ctx )
168
+ return _lib .llama_get_logits (ctx )
170
169
171
170
172
- lib .llama_get_logits .argtypes = [llama_context_p ]
173
- lib .llama_get_logits .restype = POINTER (c_float )
171
+ _lib .llama_get_logits .argtypes = [llama_context_p ]
172
+ _lib .llama_get_logits .restype = POINTER (c_float )
174
173
175
174
176
175
# Get the embeddings for the input
177
176
# shape: [n_embd] (1-dimensional)
178
177
def llama_get_embeddings (ctx : llama_context_p ):
179
- return lib .llama_get_embeddings (ctx )
178
+ return _lib .llama_get_embeddings (ctx )
180
179
181
180
182
- lib .llama_get_embeddings .argtypes = [llama_context_p ]
183
- lib .llama_get_embeddings .restype = POINTER (c_float )
181
+ _lib .llama_get_embeddings .argtypes = [llama_context_p ]
182
+ _lib .llama_get_embeddings .restype = POINTER (c_float )
184
183
185
184
186
185
# Token Id -> String. Uses the vocabulary in the provided context
187
186
def llama_token_to_str (ctx : llama_context_p , token : int ) -> bytes :
188
- return lib .llama_token_to_str (ctx , token )
187
+ return _lib .llama_token_to_str (ctx , token )
189
188
190
189
191
- lib .llama_token_to_str .argtypes = [llama_context_p , llama_token ]
192
- lib .llama_token_to_str .restype = c_char_p
190
+ _lib .llama_token_to_str .argtypes = [llama_context_p , llama_token ]
191
+ _lib .llama_token_to_str .restype = c_char_p
193
192
194
193
# Special tokens
195
194
196
195
197
196
def llama_token_bos () -> llama_token :
198
- return lib .llama_token_bos ()
197
+ return _lib .llama_token_bos ()
199
198
200
199
201
- lib .llama_token_bos .argtypes = []
202
- lib .llama_token_bos .restype = llama_token
200
+ _lib .llama_token_bos .argtypes = []
201
+ _lib .llama_token_bos .restype = llama_token
203
202
204
203
205
204
def llama_token_eos () -> llama_token :
206
- return lib .llama_token_eos ()
205
+ return _lib .llama_token_eos ()
207
206
208
207
209
- lib .llama_token_eos .argtypes = []
210
- lib .llama_token_eos .restype = llama_token
208
+ _lib .llama_token_eos .argtypes = []
209
+ _lib .llama_token_eos .restype = llama_token
211
210
212
211
213
212
# TODO: improve the last_n_tokens interface ?
@@ -220,12 +219,12 @@ def llama_sample_top_p_top_k(
220
219
temp : c_double ,
221
220
repeat_penalty : c_double ,
222
221
) -> llama_token :
223
- return lib .llama_sample_top_p_top_k (
222
+ return _lib .llama_sample_top_p_top_k (
224
223
ctx , last_n_tokens_data , last_n_tokens_size , top_k , top_p , temp , repeat_penalty
225
224
)
226
225
227
226
228
- lib .llama_sample_top_p_top_k .argtypes = [
227
+ _lib .llama_sample_top_p_top_k .argtypes = [
229
228
llama_context_p ,
230
229
llama_token_p ,
231
230
c_int ,
@@ -234,33 +233,32 @@ def llama_sample_top_p_top_k(
234
233
c_double ,
235
234
c_double ,
236
235
]
237
- lib .llama_sample_top_p_top_k .restype = llama_token
236
+ _lib .llama_sample_top_p_top_k .restype = llama_token
238
237
239
238
240
239
# Performance information
241
240
242
241
243
242
def llama_print_timings (ctx : llama_context_p ):
244
- lib .llama_print_timings (ctx )
243
+ _lib .llama_print_timings (ctx )
245
244
246
245
247
- lib .llama_print_timings .argtypes = [llama_context_p ]
248
- lib .llama_print_timings .restype = None
246
+ _lib .llama_print_timings .argtypes = [llama_context_p ]
247
+ _lib .llama_print_timings .restype = None
249
248
250
249
251
250
def llama_reset_timings (ctx : llama_context_p ):
252
- lib .llama_reset_timings (ctx )
251
+ _lib .llama_reset_timings (ctx )
253
252
254
253
255
- lib .llama_reset_timings .argtypes = [llama_context_p ]
256
- lib .llama_reset_timings .restype = None
254
+ _lib .llama_reset_timings .argtypes = [llama_context_p ]
255
+ _lib .llama_reset_timings .restype = None
257
256
258
257
259
258
# Print system information
260
259
def llama_print_system_info () -> bytes :
261
- """Print system informaiton"""
262
- return lib .llama_print_system_info ()
260
+ return _lib .llama_print_system_info ()
263
261
264
262
265
- lib .llama_print_system_info .argtypes = []
266
- lib .llama_print_system_info .restype = c_char_p
263
+ _lib .llama_print_system_info .argtypes = []
264
+ _lib .llama_print_system_info .restype = c_char_p
0 commit comments