@@ -197,13 +197,14 @@ def __init__(
197
197
# kv_overrides is the original python dict
198
198
self .kv_overrides = kv_overrides
199
199
if kv_overrides is not None :
200
-
201
200
# _kv_overrides_array is a ctypes.Array of llama_model_kv_override Structs
202
- kvo_array_len = len (kv_overrides ) + 1 # for sentinel element
203
- self ._kv_overrides_array = (llama_cpp .llama_model_kv_override * kvo_array_len )()
201
+ kvo_array_len = len (kv_overrides ) + 1 # for sentinel element
202
+ self ._kv_overrides_array = (
203
+ llama_cpp .llama_model_kv_override * kvo_array_len
204
+ )()
204
205
205
206
for i , (k , v ) in enumerate (kv_overrides .items ()):
206
- self ._kv_overrides_array [i ].key = k .encode (' utf-8' );
207
+ self ._kv_overrides_array [i ].key = k .encode (" utf-8" )
207
208
if isinstance (v , int ):
208
209
self ._kv_overrides_array [i ].tag = llama_cpp .LLAMA_KV_OVERRIDE_INT
209
210
self ._kv_overrides_array [i ].value .int_value = v
@@ -216,7 +217,9 @@ def __init__(
216
217
else :
217
218
raise ValueError (f"Unknown value type for { k } : { v } " )
218
219
219
- self ._kv_overrides_array [- 1 ].key = b'\0 ' # ensure sentinel element is zeroed
220
+ self ._kv_overrides_array [
221
+ - 1
222
+ ].key = b"\0 " # ensure sentinel element is zeroed
220
223
self .model_params .kv_overrides = self ._kv_overrides_array
221
224
222
225
self .n_batch = min (n_ctx , n_batch ) # ???
@@ -326,15 +329,17 @@ def __init__(
326
329
(n_ctx , self ._n_vocab ), dtype = np .single
327
330
)
328
331
329
- self ._mirostat_mu = ctypes .c_float (2.0 * 5.0 ) # TODO: Move this to sampling context
332
+ self ._mirostat_mu = ctypes .c_float (
333
+ 2.0 * 5.0
334
+ ) # TODO: Move this to sampling context
330
335
331
336
try :
332
337
self .metadata = self ._model .metadata ()
333
338
except Exception as e :
334
339
self .metadata = {}
335
340
if self .verbose :
336
341
print (f"Failed to load metadata: { e } " , file = sys .stderr )
337
-
342
+
338
343
if self .verbose :
339
344
print (f"Model metadata: { self .metadata } " , file = sys .stderr )
340
345
@@ -534,7 +539,7 @@ def sample(
534
539
candidates = self ._candidates ,
535
540
tau = mirostat_tau ,
536
541
eta = mirostat_eta ,
537
- mu = ctypes .pointer (self ._mirostat_mu )
542
+ mu = ctypes .pointer (self ._mirostat_mu ),
538
543
)
539
544
else :
540
545
self ._ctx .sample_top_k (candidates = self ._candidates , k = top_k , min_keep = 1 )
0 commit comments