Fix ctypes typing issue for Arrays · coderonion/llama-cpp-python@670d390 · GitHub

Commit 670d390

Fix ctypes typing issue for Arrays

1 parent 1545b22 · commit 670d390
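Why this fix: `ctypes.Array` only gained `__class_getitem__` in Python 3.9, so an annotation such as `tokens: ctypes.Array[llama_token]`, which is evaluated eagerly at import time, raises `TypeError: 'type' object is not subscriptable` on older interpreters. The commit drops those runtime annotations and records the array types as PEP 484 type comments instead. A minimal sketch of the failure mode and the workaround (the function and variable names here are illustrative, not from the commit):

import ctypes

try:
    # On Python < 3.9, ctypes.Array has no __class_getitem__, so subscripting
    # it (as an eager annotation does) raises TypeError at import time.
    ctypes.Array[ctypes.c_float]
except TypeError as exc:
    print("runtime subscription failed:", exc)

# The workaround used by this commit: leave the parameter unannotated and
# carry the type in a comment, which the interpreter never evaluates.
def scale(values, factor):  # type: (ctypes.Array[ctypes.c_float], float) -> None
    for i in range(len(values)):
        values[i] *= factor

buf = (ctypes.c_float * 3)(1.0, 2.0, 3.0)  # a concrete Array subclass
scale(buf, 2.0)
print(list(buf))  # [2.0, 4.0, 6.0]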

File tree: 1 file changed, +6 −14 lines

llama_cpp/llama_cpp.py

Lines changed: 6 additions & 14 deletions
@@ -1,14 +1,6 @@
 import ctypes
 
-from ctypes import (
-    c_int,
-    c_float,
-    c_char_p,
-    c_void_p,
-    c_bool,
-    POINTER,
-    Structure,
-)
+from ctypes import c_int, c_float, c_char_p, c_void_p, c_bool, POINTER, Structure, Array
 
 import pathlib
 from itertools import chain
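The rewritten import also adds `Array`, so the `# type: Array[llama_token]` comments below can name it unqualified. Concrete ctypes arrays are still created by multiplying an element type by a length; a small sketch, assuming `llama_token` is `c_int` as in llama.cpp of this period:

from ctypes import Array, c_int

llama_token = c_int          # assumption: matches llama.cpp's "typedef int llama_token"
TokenBuf = llama_token * 16  # a concrete, fixed-size Array subclass
tokens = TokenBuf()          # zero-initialized buffer of 16 tokens
assert isinstance(tokens, Array)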
@@ -116,7 +108,7 @@ def llama_model_quantize(
 # Returns 0 on success
 def llama_eval(
     ctx: llama_context_p,
-    tokens: ctypes.Array[llama_token],
+    tokens,  # type: Array[llama_token]
     n_tokens: c_int,
     n_past: c_int,
     n_threads: c_int,
@@ -136,7 +128,7 @@ def llama_eval(
 def llama_tokenize(
     ctx: llama_context_p,
     text: bytes,
-    tokens: ctypes.Array[llama_token],
+    tokens,  # type: Array[llama_token]
     n_max_tokens: c_int,
     add_bos: c_bool,
 ) -> c_int:
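To show what the comment-typed `tokens` parameter expects from a caller, a hedged sketch of tokenizing and then evaluating; the model path, prompt, and thread count are assumptions, and error handling is omitted:

import llama_cpp

params = llama_cpp.llama_context_default_params()
ctx = llama_cpp.llama_init_from_file(b"models/7B/ggml-model.bin", params)

n_max_tokens = 64
tokens = (llama_cpp.llama_token * n_max_tokens)()  # the Array[llama_token] buffer
n = llama_cpp.llama_tokenize(ctx, b"Hello world", tokens, n_max_tokens, True)

# The count returned by llama_tokenize becomes n_tokens for llama_eval.
llama_cpp.llama_eval(ctx, tokens, n, 0, 4)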
@@ -176,7 +168,7 @@ def llama_n_embd(ctx: llama_context_p) -> c_int:
 # Can be mutated in order to change the probabilities of the next token
 # Rows: n_tokens
 # Cols: n_vocab
-def llama_get_logits(ctx: llama_context_p) -> ctypes.Array[c_float]:
+def llama_get_logits(ctx: llama_context_p):
     return _lib.llama_get_logits(ctx)
 
 
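Dropping the return annotation does not change runtime behaviour: the underlying C function returns float *, and bindings of this era set the function's restype to POINTER(c_float), so callers index the result the same way as before. A usage sketch, continuing with the ctx from above:

logits = llama_cpp.llama_get_logits(ctx)   # indexable POINTER(c_float)
n_vocab = llama_cpp.llama_n_vocab(ctx)
row = [logits[i] for i in range(n_vocab)]  # logits for the last evaluated token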
@@ -186,7 +178,7 @@ def llama_get_logits(ctx: llama_context_p) -> ctypes.Array[c_float]:
 
 # Get the embeddings for the input
 # shape: [n_embd] (1-dimensional)
-def llama_get_embeddings(ctx: llama_context_p) -> ctypes.Array[c_float]:
+def llama_get_embeddings(ctx: llama_context_p):
     return _lib.llama_get_embeddings(ctx)
 
 
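The embeddings accessor follows the same pattern; a sketch under the same assumptions (the context must have been created with embeddings enabled, e.g. params.embedding set in this era's llama_context_params, for the buffer to be populated):

embeddings = llama_cpp.llama_get_embeddings(ctx)  # float* over an [n_embd] buffer
n_embd = llama_cpp.llama_n_embd(ctx)
vector = [embeddings[i] for i in range(n_embd)]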
@@ -224,7 +216,7 @@ def llama_token_eos() -> llama_token:
 # TODO: improve the last_n_tokens interface ?
 def llama_sample_top_p_top_k(
     ctx: llama_context_p,
-    last_n_tokens_data: ctypes.Array[llama_token],
+    last_n_tokens_data,  # type: Array[llama_token]
     last_n_tokens_size: c_int,
     top_k: c_int,
     top_p: c_float,
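The sampler takes the recent tokens as the same kind of ctypes array. A hedged call sketch: the diff above is truncated after top_p, and the trailing temp and repeat_penalty parameters (present in llama.cpp's llama_sample_top_p_top_k of this period) are given assumed values here:

n_last = 64
last_n = (llama_cpp.llama_token * n_last)()  # matches "# type: Array[llama_token]"
next_token = llama_cpp.llama_sample_top_p_top_k(
    ctx,
    last_n,
    n_last,
    40,    # top_k
    0.95,  # top_p
    0.80,  # temp (assumed value)
    1.10,  # repeat_penalty (assumed value)
)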

0 commit comments