1
1
import ctypes
2
2
3
- from ctypes import (
4
- c_int ,
5
- c_float ,
6
- c_char_p ,
7
- c_void_p ,
8
- c_bool ,
9
- POINTER ,
10
- Structure ,
11
- )
3
+ from ctypes import c_int , c_float , c_char_p , c_void_p , c_bool , POINTER , Structure , Array
12
4
13
5
import pathlib
14
6
from itertools import chain
@@ -116,7 +108,7 @@ def llama_model_quantize(
116
108
# Returns 0 on success
117
109
def llama_eval (
118
110
ctx : llama_context_p ,
119
- tokens : ctypes . Array [llama_token ],
111
+ tokens , # type: Array[llama_token]
120
112
n_tokens : c_int ,
121
113
n_past : c_int ,
122
114
n_threads : c_int ,
@@ -136,7 +128,7 @@ def llama_eval(
136
128
def llama_tokenize (
137
129
ctx : llama_context_p ,
138
130
text : bytes ,
139
- tokens : ctypes . Array [llama_token ],
131
+ tokens , # type: Array[llama_token]
140
132
n_max_tokens : c_int ,
141
133
add_bos : c_bool ,
142
134
) -> c_int :
@@ -176,7 +168,7 @@ def llama_n_embd(ctx: llama_context_p) -> c_int:
176
168
# Can be mutated in order to change the probabilities of the next token
177
169
# Rows: n_tokens
178
170
# Cols: n_vocab
179
- def llama_get_logits (ctx : llama_context_p ) -> ctypes . Array [ c_float ] :
171
+ def llama_get_logits (ctx : llama_context_p ):
180
172
return _lib .llama_get_logits (ctx )
181
173
182
174
@@ -186,7 +178,7 @@ def llama_get_logits(ctx: llama_context_p) -> ctypes.Array[c_float]:
186
178
187
179
# Get the embeddings for the input
188
180
# shape: [n_embd] (1-dimensional)
189
- def llama_get_embeddings (ctx : llama_context_p ) -> ctypes . Array [ c_float ] :
181
+ def llama_get_embeddings (ctx : llama_context_p ):
190
182
return _lib .llama_get_embeddings (ctx )
191
183
192
184
@@ -224,7 +216,7 @@ def llama_token_eos() -> llama_token:
224
216
# TODO: improve the last_n_tokens interface ?
225
217
def llama_sample_top_p_top_k (
226
218
ctx : llama_context_p ,
227
- last_n_tokens_data : ctypes . Array [llama_token ],
219
+ last_n_tokens_data , # type: Array[llama_token]
228
220
last_n_tokens_size : c_int ,
229
221
top_k : c_int ,
230
222
top_p : c_float ,
0 commit comments