Commit 4050143 · Dpaste20/llama-cpp-python · GitHub

Fix: types

1 parent 66e28eb · commit 4050143

File tree: 1 file changed, +13 −8 lines

llama_cpp/llama_cpp.py (13 additions & 8 deletions)
@@ -141,6 +141,11 @@ class llama_context_params(Structure):
 LLAMA_FTYPE_MOSTLY_Q5_0 = ctypes.c_int(8) # except 1d tensors
 LLAMA_FTYPE_MOSTLY_Q5_1 = ctypes.c_int(9) # except 1d tensors

+# Misc
+c_float_p = POINTER(c_float)
+c_uint8_p = POINTER(c_uint8)
+c_size_t_p = POINTER(c_size_t)
+
 # Functions
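Note: a minimal standalone sketch of the pattern this hunk introduces. ctypes.POINTER() returns a cached pointer type, so binding it once to a module-level name like c_float_p is purely for readability in the argtypes/restype tables below. Nothing here depends on llama.cpp; it is plain ctypes.

from ctypes import POINTER, c_float, cast

# Module-level alias, as in the commit.
c_float_p = POINTER(c_float)

# POINTER() caches its result, so the alias is the same type object.
assert c_float_p is POINTER(c_float)

# A float array is compatible with c_float_p and can be read through it.
buf = (c_float * 4)(0.5, 1.5, 2.5, 3.5)
ptr = cast(buf, c_float_p)
print(ptr[2])   # 2.5
print(ptr[:4])  # slicing a ctypes pointer copies into a Python list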
@@ -257,7 +262,7 @@ def llama_copy_state_data(ctx: llama_context_p, dest: Array[c_uint8]) -> c_size_t:
     return _lib.llama_copy_state_data(ctx, dest)


-_lib.llama_copy_state_data.argtypes = [llama_context_p, POINTER(c_uint8)]
+_lib.llama_copy_state_data.argtypes = [llama_context_p, c_uint8_p]
 _lib.llama_copy_state_data.restype = c_size_t
@@ -269,7 +274,7 @@ def llama_set_state_data(
     return _lib.llama_set_state_data(ctx, src)


-_lib.llama_set_state_data.argtypes = [llama_context_p, POINTER(c_uint8)]
+_lib.llama_set_state_data.argtypes = [llama_context_p, c_uint8_p]
 _lib.llama_set_state_data.restype = c_size_t
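Note: a hedged usage sketch for the two state functions above, assuming a live ctx from the surrounding bindings and that llama_get_state_size is bound earlier in this file, as in upstream llama.cpp (it does not appear in this diff). The (c_uint8 * n)() buffer is compatible with the c_uint8_p argtype.

from ctypes import c_uint8

# Hypothetical illustration; `ctx` would come from the module's init call.
n = llama_get_state_size(ctx)

state = (c_uint8 * n)()
llama_copy_state_data(ctx, state)  # snapshot the context state into the buffer

# ... evaluate some tokens, then roll back to the snapshot:
llama_set_state_data(ctx, state)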
@@ -291,7 +296,7 @@ def llama_load_session_file(
     c_char_p,
     llama_token_p,
     c_size_t,
-    POINTER(c_size_t),
+    c_size_t_p,
 ]
 _lib.llama_load_session_file.restype = c_size_t
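Note: the c_size_t_p slot above is an out-parameter. A hedged sketch of how a caller might fill it with ctypes.byref, assuming the wrapper takes (ctx, path, tokens_out, n_token_capacity, n_token_count_out) as in upstream llama.cpp; the path and capacity below are illustrative.

from ctypes import c_size_t, byref

n_max = 512
tokens = (llama_token * n_max)()  # llama_token comes from this module
n_loaded = c_size_t(0)

ok = llama_load_session_file(
    ctx, b"session.bin", tokens, c_size_t(n_max), byref(n_loaded)
)
# On success the callee wrote the token count through the pointer.
print(n_loaded.value)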
@@ -340,7 +345,7 @@ def llama_eval(
 def llama_tokenize(
     ctx: llama_context_p,
     text: bytes,
-    tokens,  # type: Array[llama_token]
+    tokens: Array[llama_token],
     n_max_tokens: c_int,
     add_bos: c_bool,
 ) -> c_int:
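Note: with the annotation fixed, the caller-allocated array pattern for this wrapper looks like the sketch below, assuming ctx and llama_token from the surrounding module. In upstream llama.cpp the return value is the number of tokens written, negative if the buffer is too small.

# Hypothetical usage of the llama_tokenize wrapper above.
n_max_tokens = 256
tokens = (llama_token * n_max_tokens)()
n = llama_tokenize(ctx, b"Hello, world", tokens, n_max_tokens, True)
ids = list(tokens[:n])  # copy the written prefix into a Python list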
@@ -385,7 +390,7 @@ def llama_get_logits(ctx: llama_context_p):


 _lib.llama_get_logits.argtypes = [llama_context_p]
-_lib.llama_get_logits.restype = POINTER(c_float)
+_lib.llama_get_logits.restype = c_float_p


 # Get the embeddings for the input
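Note: since the restype is a float pointer, the return value can be indexed or sliced directly. A hedged sketch, assuming llama_n_vocab is bound elsewhere in this file as in upstream llama.cpp:

# Hypothetical read of the logits for the last evaluated token.
logits_ptr = llama_get_logits(ctx)  # an LP_c_float instance
n_vocab = llama_n_vocab(ctx)
logits = logits_ptr[:n_vocab]       # copies n_vocab floats into a list
best = max(range(n_vocab), key=lambda i: logits[i])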
@@ -395,7 +400,7 @@ def llama_get_embeddings(ctx: llama_context_p):


 _lib.llama_get_embeddings.argtypes = [llama_context_p]
-_lib.llama_get_embeddings.restype = POINTER(c_float)
+_lib.llama_get_embeddings.restype = c_float_p


 # Token Id -> String. Uses the vocabulary in the provided context
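Note: same restype fix as for the logits. As a general ctypes point, a foreign function with no explicit restype defaults to returning a C int, which truncates pointer results on 64-bit platforms. A standalone POSIX-only illustration with libc (nothing llama-specific):

from ctypes import CDLL, c_char_p, c_int

libc = CDLL(None)  # POSIX: load symbols from the running process
strchr = libc.strchr
strchr.argtypes = [c_char_p, c_int]
strchr.restype = c_char_p              # without this, the char* comes back as int
print(strchr(b"llama.cpp", ord(".")))  # b'.cpp'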
@@ -614,7 +619,7 @@ def llama_sample_token_mirostat(
     c_float,
     c_float,
     c_int,
-    POINTER(c_float),
+    c_float_p,
 ]
 _lib.llama_sample_token_mirostat.restype = llama_token
@@ -639,7 +644,7 @@ def llama_sample_token_mirostat_v2(
     llama_token_data_array_p,
     c_float,
     c_float,
-    POINTER(c_float),
+    c_float_p,
 ]
 _lib.llama_sample_token_mirostat_v2.restype = llama_token
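Note: both mirostat samplers keep their mu state in a float the caller owns, which is why the final argtype is c_float_p. A hedged sketch against the v2 wrapper, assuming ctx and a candidates array prepared as llama_token_data_array_p elsewhere in these bindings; tau/eta values are illustrative.

from ctypes import c_float, byref

tau, eta = c_float(5.0), c_float(0.1)
mu = c_float(2.0 * 5.0)  # common initialization in upstream llama.cpp: 2 * tau

token = llama_sample_token_mirostat_v2(ctx, candidates_p, tau, eta, byref(mu))
# The sampler updates mu in place; pass the same c_float on the next call.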