@@ -5,35 +5,19 @@ enum LlamaError: Error {
     case couldNotInitializeContext
 }
 
-func llama_batch_clear(_ batch: inout llama_batch) {
-    batch.n_tokens = 0
-}
-
-func llama_batch_add(_ batch: inout llama_batch, _ id: llama_token, _ pos: llama_pos, _ seq_ids: [llama_seq_id], _ logits: Bool) {
-    batch.token[Int(batch.n_tokens)] = id
-    batch.pos[Int(batch.n_tokens)] = pos
-    batch.n_seq_id[Int(batch.n_tokens)] = Int32(seq_ids.count)
-    for i in 0..<seq_ids.count {
-        batch.seq_id[Int(batch.n_tokens)]![Int(i)] = seq_ids[i]
-    }
-    batch.logits[Int(batch.n_tokens)] = logits ? 1 : 0
-
-    batch.n_tokens += 1
-}
-
 actor LlamaContext {
     private var model: OpaquePointer
     private var context: OpaquePointer
     private var vocab: OpaquePointer
     private var sampling: UnsafeMutablePointer<llama_sampler>
-    private var batch: llama_batch
+    private var batch: OpaquePointer
     private var tokens_list: [llama_token]
     var is_done: Bool = false
 
     /// This variable is used to store temporarily invalid cchars
     private var temporary_invalid_cchars: [CChar]
 
-    var n_len: Int32 = 1024
+    var n_len: Int32 = 128
     var n_cur: Int32 = 0
 
     var n_decode: Int32 = 0
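
The two helpers removed above existed only so Swift code could write `llama_batch` fields directly; with `batch` now an `OpaquePointer`, all batch mutation has to go through the C API instead. A minimal before/after sketch of the calling pattern (`token` and `pos` are placeholder values):

    // before: Swift-side helpers poking at llama_batch fields
    llama_batch_clear(&batch)
    llama_batch_add(&batch, token, pos, [0], true)

    // after: the batch is an opaque handle, mutated only via llama_batch_ext_*
    llama_batch_ext_clear(batch)
    llama_batch_ext_add_text(batch, token, pos, [llama_seq_id(0)], 1, true)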
@@ -42,7 +26,7 @@ actor LlamaContext {
         self.model = model
         self.context = context
         self.tokens_list = []
-        self.batch = llama_batch_init(512, 0, 1)
+        self.batch = llama_batch_ext_init(512, 1)
         self.temporary_invalid_cchars = []
         let sparams = llama_sampler_chain_default_params()
         self.sampling = llama_sampler_chain_init(sparams)
@@ -53,7 +37,7 @@ actor LlamaContext {
 
     deinit {
         llama_sampler_free(sampling)
-        llama_batch_free(batch)
+        llama_batch_ext_free(batch)
         llama_model_free(model)
         llama_free(context)
         llama_backend_free()
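
Since the batch is now allocated by the library, `llama_batch_ext_init(512, 1)` in the initializer is paired with the `llama_batch_ext_free(batch)` added to `deinit`. Judging from the call sites in this diff, the two arguments are the token capacity and the maximum number of sequence ids per token; the embedding-size argument that sat in the middle of the old `llama_batch_init(512, 0, 1)` is gone. A lifetime sketch under that assumption:

    // assumed: 512 = token capacity, 1 = max sequence ids per token
    let batch: OpaquePointer = llama_batch_ext_init(512, 1)
    defer { llama_batch_ext_free(batch) }   // opaque handle, must be freed explicitly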
@@ -111,7 +95,7 @@ actor LlamaContext {
     }
 
     func get_n_tokens() -> Int32 {
-        return batch.n_tokens;
+        return llama_batch_ext_get_n_tokens(batch)
     }
 
     func completion_init(text: String) {
@@ -133,25 +117,25 @@ actor LlamaContext {
             print(String(cString: token_to_piece(token: id) + [0]))
         }
 
-        llama_batch_clear(&batch)
+        llama_batch_ext_clear(batch)
 
         for i1 in 0..<tokens_list.count {
             let i = Int(i1)
-            llama_batch_add(&batch, tokens_list[i], Int32(i), [0], false)
+            llama_batch_ext_add_text(batch, tokens_list[i], Int32(i), [llama_seq_id(0)], 1, false)
         }
-        batch.logits[Int(batch.n_tokens) - 1] = 1 // true
+        llama_batch_ext_set_output_last(batch)
 
-        if llama_decode(context, batch) != 0 {
-            print("llama_decode() failed")
+        if llama_decode_ext(context, batch) != 0 {
+            print("llama_decode_ext() failed")
         }
 
-        n_cur = batch.n_tokens
+        n_cur = llama_batch_ext_get_n_tokens(batch)
     }
 
     func completion_loop() -> String {
         var new_token_id: llama_token = 0
 
-        new_token_id = llama_sampler_sample(sampling, context, batch.n_tokens - 1)
+        new_token_id = llama_sampler_sample(sampling, context, llama_batch_ext_get_n_tokens(batch) - 1)
 
         if llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_len {
             print("\n")
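
With these changes, `completion_init` follows a single decode pattern: clear the batch, append every prompt token with logits disabled, flag only the final position for output, and submit everything in one `llama_decode_ext` call. Condensed from the hunk above, using the same names as the surrounding code:

    llama_batch_ext_clear(batch)
    for (i, id) in tokens_list.enumerated() {
        // false: no logits needed at intermediate prompt positions
        llama_batch_ext_add_text(batch, id, Int32(i), [llama_seq_id(0)], 1, false)
    }
    llama_batch_ext_set_output_last(batch)   // request logits for the last token only
    if llama_decode_ext(context, batch) != 0 {
        print("llama_decode_ext() failed")
    }
    n_cur = llama_batch_ext_get_n_tokens(batch)   // n_tokens is now read via an accessor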
@@ -178,13 +162,13 @@ actor LlamaContext {
         print(new_token_str)
         // tokens_list.append(new_token_id)
 
-        llama_batch_clear(&batch)
-        llama_batch_add(&batch, new_token_id, n_cur, [0], true)
+        llama_batch_ext_clear(batch)
+        llama_batch_ext_add_text(batch, new_token_id, n_cur, [llama_seq_id(0)], 1, true)
 
         n_decode += 1
         n_cur += 1
 
-        if llama_decode(context, batch) != 0 {
+        if llama_decode_ext(context, batch) != 0 {
             print("failed to evaluate llama!")
         }
 
@@ -201,21 +185,21 @@ actor LlamaContext {
         for _ in 0..<nr {
             // bench prompt processing
 
-            llama_batch_clear(&batch)
+            llama_batch_ext_clear(batch)
 
             let n_tokens = pp
 
             for i in 0..<n_tokens {
-                llama_batch_add(&batch, 0, Int32(i), [0], false)
+                llama_batch_ext_add_text(batch, 0, Int32(i), [llama_seq_id(0)], 1, false)
             }
-            batch.logits[Int(batch.n_tokens) - 1] = 1 // true
+            llama_batch_ext_set_output_last(batch)
 
             llama_kv_self_clear(context)
 
             let t_pp_start = DispatchTime.now().uptimeNanoseconds / 1000;
 
-            if llama_decode(context, batch) != 0 {
-                print("llama_decode() failed during prompt")
+            if llama_decode_ext(context, batch) != 0 {
+                print("llama_decode_ext() failed during prompt")
             }
             llama_synchronize(context)
 
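
The prompt-processing benchmark keeps the same shape: one batch of `pp` dummy tokens (id 0) in sequence 0, output requested only for the last position, and a single decode. `llama_synchronize` matters for the timing, since `llama_decode_ext` may return before the backend has finished; dividing `uptimeNanoseconds` by 1000 puts the timestamps in microseconds. A sketch of the measurement, where `t_pp_end` and the speed formula are hypothetical additions (the diff only shows the start timestamp):

    let t_pp_start = DispatchTime.now().uptimeNanoseconds / 1000   // µs
    if llama_decode_ext(context, batch) != 0 {
        print("llama_decode_ext() failed during prompt")
    }
    llama_synchronize(context)   // wait for pending backend work before stopping the clock
    let t_pp_end = DispatchTime.now().uptimeNanoseconds / 1000     // hypothetical end timestamp
    let pp_speed = Double(pp) * 1_000_000 / Double(t_pp_end - t_pp_start)   // tokens/s, hypothetical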
@@ -228,14 +212,14 @@ actor LlamaContext {
             let t_tg_start = DispatchTime.now().uptimeNanoseconds / 1000;
 
             for i in 0..<tg {
-                llama_batch_clear(&batch)
+                llama_batch_ext_clear(batch)
 
                 for j in 0..<pl {
-                    llama_batch_add(&batch, 0, Int32(i), [Int32(j)], true)
+                    llama_batch_ext_add_text(batch, 0, Int32(i), [llama_seq_id(Int32(j))], 1, true)
                 }
 
-                if llama_decode(context, batch) != 0 {
-                    print("llama_decode() failed during text generation")
+                if llama_decode_ext(context, batch) != 0 {
+                    print("llama_decode_ext() failed during text generation")
                 }
                 llama_synchronize(context)
             }
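
The text-generation benchmark exercises the multi-sequence path: each step adds the same dummy token at position `i` once per parallel stream, with `j` as the sequence id and logits requested for every entry, so a single `llama_decode_ext` call advances all `pl` streams at once. One generation step, condensed from the hunk above:

    llama_batch_ext_clear(batch)
    for j in 0..<pl {
        // one entry per parallel sequence, all marked for logit output
        llama_batch_ext_add_text(batch, 0, Int32(i), [llama_seq_id(Int32(j))], 1, true)
    }
    if llama_decode_ext(context, batch) != 0 {
        print("llama_decode_ext() failed during text generation")
    }
    llama_synchronize(context)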