@@ -1,6 +1,15 @@
 import ctypes
 
-from ctypes import c_int, c_float, c_double, c_char_p, c_void_p, c_bool, POINTER, Structure
+from ctypes import (
+    c_int,
+    c_float,
+    c_double,
+    c_char_p,
+    c_void_p,
+    c_bool,
+    POINTER,
+    Structure,
+)
 
 import pathlib
 
@@ -13,26 +22,32 @@
 llama_token = c_int
 llama_token_p = POINTER(llama_token)
 
+
 class llama_token_data(Structure):
     _fields_ = [
-        ('id', llama_token),  # token id
-        ('p', c_float),  # probability of the token
-        ('plog', c_float),  # log probability of the token
+        ("id", llama_token),  # token id
+        ("p", c_float),  # probability of the token
+        ("plog", c_float),  # log probability of the token
     ]
 
+
 llama_token_data_p = POINTER(llama_token_data)
 
+
 class llama_context_params(Structure):
     _fields_ = [
-        ('n_ctx', c_int),  # text context
-        ('n_parts', c_int),  # -1 for default
-        ('seed', c_int),  # RNG seed, 0 for random
-        ('f16_kv', c_bool),  # use fp16 for KV cache
-        ('logits_all', c_bool),  # the llama_eval() call computes all logits, not just the last one
-
-        ('vocab_only', c_bool),  # only load the vocabulary, no weights
+        ("n_ctx", c_int),  # text context
+        ("n_parts", c_int),  # -1 for default
+        ("seed", c_int),  # RNG seed, 0 for random
+        ("f16_kv", c_bool),  # use fp16 for KV cache
+        (
+            "logits_all",
+            c_bool,
+        ),  # the llama_eval() call computes all logits, not just the last one
+        ("vocab_only", c_bool),  # only load the vocabulary, no weights
    ]
 
+
 llama_context_params_p = POINTER(llama_context_params)
 
 llama_context_p = c_void_p
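
For readers new to the ctypes idiom being reflowed above: each _fields_ entry exposes a C struct member as a typed Python attribute, so llama_context_params behaves like a plain mutable record once it crosses the FFI boundary. A minimal sketch, assuming the collapsed top of the file gives lib.llama_context_default_params a restype of llama_context_params; the field values are illustrative, not the library's defaults:

    params = llama_context_default_params()  # struct returned by value from C
    params.n_ctx = 512    # attribute writes mutate the underlying C struct
    params.seed = 1337    # 0 would ask the library to pick a random seed
    print(params.f16_kv)  # members read back as ordinary Python bools/ints
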
@@ -74,7 +89,15 @@ class llama_context_params(Structure):
 lib.llama_token_eos.argtypes = []
 lib.llama_token_eos.restype = llama_token
 
-lib.llama_sample_top_p_top_k.argtypes = [llama_context_p, llama_token_p, c_int, c_int, c_double, c_double, c_double]
+lib.llama_sample_top_p_top_k.argtypes = [
+    llama_context_p,
+    llama_token_p,
+    c_int,
+    c_int,
+    c_double,
+    c_double,
+    c_double,
+]
 lib.llama_sample_top_p_top_k.restype = llama_token
 
 lib.llama_print_timings.argtypes = [llama_context_p]
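
The argtypes/restype declarations in this hunk are what let the thin Python wrappers further down pass plain ints and floats: without them ctypes assumes C int for every argument, which would silently mangle the three c_double sampling parameters. A toy illustration of the same conversion rule, aimed at libm rather than libllama (assumes a Unix system where ctypes.util.find_library can locate it):

    import ctypes, ctypes.util

    libm = ctypes.CDLL(ctypes.util.find_library("m"))
    libm.pow.argtypes = [ctypes.c_double, ctypes.c_double]
    libm.pow.restype = ctypes.c_double
    print(libm.pow(2.0, 10.0))  # 1024.0; without the declarations the doubles
                                # would be passed and returned as mangled ints
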
@@ -86,45 +109,71 @@ class llama_context_params(Structure):
 lib.llama_print_system_info.argtypes = []
 lib.llama_print_system_info.restype = c_char_p
 
+
 # Python functions
 def llama_context_default_params() -> llama_context_params:
     params = lib.llama_context_default_params()
     return params
 
-def llama_init_from_file(path_model: bytes, params: llama_context_params) -> llama_context_p:
+
+def llama_init_from_file(
+    path_model: bytes, params: llama_context_params
+) -> llama_context_p:
     """Various functions for loading a ggml llama model.
     Allocate (almost) all memory needed for the model.
-    Return NULL on failure """
+    Return NULL on failure"""
     return lib.llama_init_from_file(path_model, params)
 
+
 def llama_free(ctx: llama_context_p):
     """Free all allocated memory"""
     lib.llama_free(ctx)
 
-def llama_model_quantize(fname_inp: bytes, fname_out: bytes, itype: c_int, qk: c_int) -> c_int:
+
+def llama_model_quantize(
+    fname_inp: bytes, fname_out: bytes, itype: c_int, qk: c_int
+) -> c_int:
     """Returns 0 on success"""
     return lib.llama_model_quantize(fname_inp, fname_out, itype, qk)
 
-def llama_eval(ctx: llama_context_p, tokens: llama_token_p, n_tokens: c_int, n_past: c_int, n_threads: c_int) -> c_int:
+
+def llama_eval(
+    ctx: llama_context_p,
+    tokens: llama_token_p,
+    n_tokens: c_int,
+    n_past: c_int,
+    n_threads: c_int,
+) -> c_int:
     """Run the llama inference to obtain the logits and probabilities for the next token.
     tokens + n_tokens is the provided batch of new tokens to process
     n_past is the number of tokens to use from previous eval calls
     Returns 0 on success"""
     return lib.llama_eval(ctx, tokens, n_tokens, n_past, n_threads)
 
-def llama_tokenize(ctx: llama_context_p, text: bytes, tokens: llama_token_p, n_max_tokens: c_int, add_bos: c_bool) -> c_int:
+
+def llama_tokenize(
+    ctx: llama_context_p,
+    text: bytes,
+    tokens: llama_token_p,
+    n_max_tokens: c_int,
+    add_bos: c_bool,
+) -> c_int:
     """Convert the provided text into tokens.
     The tokens pointer must be large enough to hold the resulting tokens.
     Returns the number of tokens on success, no more than n_max_tokens
-    Returns a negative number on failure - the number of tokens that would have been returned"""
+    Returns a negative number on failure - the number of tokens that would have been returned
+    """
     return lib.llama_tokenize(ctx, text, tokens, n_max_tokens, add_bos)
 
+
 def llama_n_vocab(ctx: llama_context_p) -> c_int:
     return lib.llama_n_vocab(ctx)
 
+
 def llama_n_ctx(ctx: llama_context_p) -> c_int:
     return lib.llama_n_ctx(ctx)
 
+
 def llama_get_logits(ctx: llama_context_p):
     """Token logits obtained from the last call to llama_eval()
     The logits for the last token are stored in the last row
@@ -133,25 +182,42 @@ def llama_get_logits(ctx: llama_context_p):
     Cols: n_vocab"""
     return lib.llama_get_logits(ctx)
 
+
 def llama_token_to_str(ctx: llama_context_p, token: int) -> bytes:
     """Token Id -> String. Uses the vocabulary in the provided context"""
     return lib.llama_token_to_str(ctx, token)
 
+
 def llama_token_bos() -> llama_token:
     return lib.llama_token_bos()
 
+
 def llama_token_eos() -> llama_token:
     return lib.llama_token_eos()
 
-def llama_sample_top_p_top_k(ctx: llama_context_p, last_n_tokens_data: llama_token_p, last_n_tokens_size: c_int, top_k: c_int, top_p: c_double, temp: c_double, repeat_penalty: c_double) -> llama_token:
-    return lib.llama_sample_top_p_top_k(ctx, last_n_tokens_data, last_n_tokens_size, top_k, top_p, temp, repeat_penalty)
+
+def llama_sample_top_p_top_k(
+    ctx: llama_context_p,
+    last_n_tokens_data: llama_token_p,
+    last_n_tokens_size: c_int,
+    top_k: c_int,
+    top_p: c_double,
+    temp: c_double,
+    repeat_penalty: c_double,
+) -> llama_token:
+    return lib.llama_sample_top_p_top_k(
+        ctx, last_n_tokens_data, last_n_tokens_size, top_k, top_p, temp, repeat_penalty
+    )
+
 
 def llama_print_timings(ctx: llama_context_p):
     lib.llama_print_timings(ctx)
 
+
 def llama_reset_timings(ctx: llama_context_p):
     lib.llama_reset_timings(ctx)
 
+
 def llama_print_system_info() -> bytes:
     """Print system information"""
     return lib.llama_print_system_info()
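
Taken together, the wrappers compose into a short generation loop. The sketch below is not part of the commit: the model path, thread count, and sampling values are placeholders, it reuses the prompt buffer where a real client would keep a proper window of recent tokens for the repeat penalty, and it assumes the collapsed part of the file wires up argtypes/restype for the loader just as the visible hunks do for everything else:

    params = llama_context_default_params()
    ctx = llama_init_from_file(b"./models/7B/ggml-model.bin", params)  # placeholder path

    max_tokens = 512
    tokens = (llama_token * max_tokens)()  # ctypes array, passable as llama_token_p
    n = llama_tokenize(ctx, b"Q: Name the planets. A:", tokens, max_tokens, True)

    n_past = 0
    for _ in range(32):
        llama_eval(ctx, tokens, n, n_past, 4)  # 4 threads, arbitrary choice
        n_past += n
        tok = llama_sample_top_p_top_k(ctx, tokens, n, 40, 0.95, 0.80, 1.30)
        if tok == llama_token_eos():
            break
        print(llama_token_to_str(ctx, tok).decode(), end="", flush=True)
        tokens[0] = tok  # feed the sampled token back as the next batch of one
        n = 1

    llama_print_timings(ctx)
    llama_free(ctx)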