@@ -141,6 +141,11 @@ class llama_context_params(Structure):
141
141
LLAMA_FTYPE_MOSTLY_Q5_0 = ctypes.c_int(8) # except 1d tensors
142
142
LLAMA_FTYPE_MOSTLY_Q5_1 = ctypes.c_int(9) # except 1d tensors
143
143
144
+ # Misc
145
+ c_float_p = POINTER(c_float)
146
+ c_uint8_p = POINTER(c_uint8)
147
+ c_size_t_p = POINTER(c_size_t)
148
+
144
149
# Functions
145
150
146
151
@@ -257,7 +262,7 @@ def llama_copy_state_data(ctx: llama_context_p, dest: Array[c_uint8]) -> c_size_
257
262
return _lib.llama_copy_state_data(ctx, dest)
258
263
259
264
260
- _lib.llama_copy_state_data.argtypes = [llama_context_p, POINTER(c_uint8) ]
265
+ _lib.llama_copy_state_data.argtypes = [llama_context_p, c_uint8_p ]
261
266
_lib.llama_copy_state_data.restype = c_size_t
262
267
263
268
@@ -269,7 +274,7 @@ def llama_set_state_data(
269
274
return _lib.llama_set_state_data(ctx, src)
270
275
271
276
272
- _lib.llama_set_state_data.argtypes = [llama_context_p, POINTER(c_uint8) ]
277
+ _lib.llama_set_state_data.argtypes = [llama_context_p, c_uint8_p ]
273
278
_lib.llama_set_state_data.restype = c_size_t
274
279
275
280
@@ -291,7 +296,7 @@ def llama_load_session_file(
291
296
c_char_p,
292
297
llama_token_p,
293
298
c_size_t,
294
- POINTER(c_size_t) ,
299
+ c_size_t_p ,
295
300
]
296
301
_lib.llama_load_session_file.restype = c_size_t
297
302
@@ -340,7 +345,7 @@ def llama_eval(
340
345
def llama_tokenize(
341
346
ctx: llama_context_p,
342
347
text: bytes,
343
- tokens, # type : Array[llama_token]
348
+ tokens: Array[llama_token],
344
349
n_max_tokens: c_int,
345
350
add_bos: c_bool,
346
351
) -> c_int:
@@ -385,7 +390,7 @@ def llama_get_logits(ctx: llama_context_p):
385
390
386
391
387
392
_lib.llama_get_logits.argtypes = [llama_context_p]
388
- _lib.llama_get_logits.restype = POINTER(c_float)
393
+ _lib.llama_get_logits.restype = c_float_p
389
394
390
395
391
396
# Get the embeddings for the input
@@ -395,7 +400,7 @@ def llama_get_embeddings(ctx: llama_context_p):
395
400
396
401
397
402
_lib.llama_get_embeddings.argtypes = [llama_context_p]
398
- _lib.llama_get_embeddings.restype = POINTER(c_float)
403
+ _lib.llama_get_embeddings.restype = c_float_p
399
404
400
405
401
406
# Token Id -> String. Uses the vocabulary in the provided context
@@ -614,7 +619,7 @@ def llama_sample_token_mirostat(
614
619
c_float,
615
620
c_float,
616
621
c_int,
617
- POINTER(c_float) ,
622
+ c_float_p ,
618
623
]
619
624
_lib.llama_sample_token_mirostat.restype = llama_token
620
625
@@ -639,7 +644,7 @@ def llama_sample_token_mirostat_v2(
639
644
llama_token_data_array_p,
640
645
c_float,
641
646
c_float,
642
- POINTER(c_float) ,
647
+ c_float_p ,
643
648
]
644
649
_lib.llama_sample_token_mirostat_v2.restype = llama_token
645
650
0 commit comments