8000 Fix array type signatures · coderonion/llama-cpp-python@1545b22 · GitHub
[go: up one dir, main page]

Skip to content

Commit 1545b22

Browse files
committed
Fix array type signatures
1 parent 4b9eb5c commit 1545b22

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

llama_cpp/llama_cpp.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ def llama_model_quantize(
116116
# Returns 0 on success
117117
def llama_eval(
118118
ctx: llama_context_p,
119-
tokens: llama_token_p,
119+
tokens: ctypes.Array[llama_token],
120120
n_tokens: c_int,
121121
n_past: c_int,
122122
n_threads: c_int,
@@ -136,7 +136,7 @@ def llama_eval(
136136
def llama_tokenize(
137137
ctx: llama_context_p,
138138
text: bytes,
139-
tokens: llama_token_p,
139+
tokens: ctypes.Array[llama_token],
140140
n_max_tokens: c_int,
141141
add_bos: c_bool,
142142
) -> c_int:
@@ -176,7 +176,7 @@ def llama_n_embd(ctx: llama_context_p) -> c_int:
176176
# Can be mutated in order to change the probabilities of the next token
177177
# Rows: n_tokens
178178
# Cols: n_vocab
179-
def llama_get_logits(ctx: llama_context_p):
179+
def llama_get_logits(ctx: llama_context_p) -> ctypes.Array[c_float]:
180180
return _lib.llama_get_logits(ctx)
181181

182182

@@ -186,7 +186,7 @@ def llama_get_logits(ctx: llama_context_p):
186186

187187
# Get the embeddings for the input
188188
# shape: [n_embd] (1-dimensional)
189-
def llama_get_embeddings(ctx: llama_context_p):
189+
def llama_get_embeddings(ctx: llama_context_p) -> ctypes.Array[c_float]:
190190
return _lib.llama_get_embeddings(ctx)
191191

192192

@@ -224,7 +224,7 @@ def llama_token_eos() -> llama_token:
224224
# TODO: improve the last_n_tokens interface ?
225225
def llama_sample_top_p_top_k(
226226
ctx: llama_context_p,
227-
last_n_tokens_data: llama_token_p,
227+
last_n_tokens_data: ctypes.Array[llama_token],
228228
last_n_tokens_size: c_int,
229229
top_k: c_int,
230230
top_p: c_float,

0 commit comments

Comments
 (0)
0