swift : adapt to new API · pqnet/llama.cpp@96ca6e8 · GitHub

Commit 96ca6e8

swift : adapt to new API

1 parent: b0db7fc

1 file changed: +24, -40

examples/llama.swiftui/llama.cpp.swift/LibLlama.swift (24 additions, 40 deletions)
@@ -5,35 +5,19 @@ enum LlamaError: Error {
     case couldNotInitializeContext
 }
 
-func llama_batch_clear(_ batch: inout llama_batch) {
-    batch.n_tokens = 0
-}
-
-func llama_batch_add(_ batch: inout llama_batch, _ id: llama_token, _ pos: llama_pos, _ seq_ids: [llama_seq_id], _ logits: Bool) {
-    batch.token   [Int(batch.n_tokens)] = id
-    batch.pos     [Int(batch.n_tokens)] = pos
-    batch.n_seq_id[Int(batch.n_tokens)] = Int32(seq_ids.count)
-    for i in 0..<seq_ids.count {
-        batch.seq_id[Int(batch.n_tokens)]![Int(i)] = seq_ids[i]
-    }
-    batch.logits  [Int(batch.n_tokens)] = logits ? 1 : 0
-
-    batch.n_tokens += 1
-}
-
 actor LlamaContext {
     private var model: OpaquePointer
     private var context: OpaquePointer
     private var vocab: OpaquePointer
     private var sampling: UnsafeMutablePointer<llama_sampler>
-    private var batch: llama_batch
+    private var batch: OpaquePointer
     private var tokens_list: [llama_token]
     var is_done: Bool = false
 
     /// This variable is used to store temporarily invalid cchars
     private var temporary_invalid_cchars: [CChar]
 
-    var n_len: Int32 = 1024
+    var n_len: Int32 = 128
     var n_cur: Int32 = 0
 
     var n_decode: Int32 = 0
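
Note: the two deleted helpers existed only because the old llama_batch was a plain C struct whose token, pos, seq_id and logits arrays Swift could write into directly. With batch now an OpaquePointer those fields are unreachable, so every mutation goes through a library call. A minimal sketch of the replacement pattern, using the llama_batch_ext functions exactly as this diff does (the _ext API is an in-progress refactor, so treat the signatures as provisional; append_token is a hypothetical name, not part of this commit):

    import llama

    // Old: poke struct fields, then bump batch.n_tokens by hand.
    // New: one library call appends the token and does the bookkeeping.
    func append_token(_ batch: OpaquePointer, _ id: llama_token, _ pos: llama_pos, _ wantLogits: Bool) {
        // A seq_ids array plus an explicit count, matching this diff's call sites;
        // the final flag marks whether logits are needed at this position.
        llama_batch_ext_add_text(batch, id, pos, [llama_seq_id(0)], 1, wantLogits)
    }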
@@ -42,7 +26,7 @@ actor LlamaContext {
         self.model = model
         self.context = context
         self.tokens_list = []
-        self.batch = llama_batch_init(512, 0, 1)
+        self.batch = llama_batch_ext_init(512, 1)
         self.temporary_invalid_cchars = []
         let sparams = llama_sampler_chain_default_params()
         self.sampling = llama_sampler_chain_init(sparams)
@@ -53,7 +37,7 @@ actor LlamaContext {
 
     deinit {
         llama_sampler_free(sampling)
-        llama_batch_free(batch)
+        llama_batch_ext_free(batch)
         llama_model_free(model)
         llama_free(context)
         llama_backend_free()
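
Note: the opaque handle makes ownership explicit: llama_batch_ext_init in the initializer pairs with llama_batch_ext_free in deinit, exactly as the old init/free pair did. The middle argument of the old llama_batch_init(512, 0, 1) (the embd size, always 0 in this example) appears to be gone from llama_batch_ext_init(512, 1). A sketch of the paired lifecycle under those assumptions (BatchOwner is a hypothetical type, not part of this commit):

    import llama

    actor BatchOwner {
        private let batch: OpaquePointer

        init() {
            // 512-token capacity, 1 sequence id per token, as in this commit
            batch = llama_batch_ext_init(512, 1)
        }

        deinit {
            // Pairs with _init; the handle owns C-side buffers,
            // so skipping this leaks and double-freeing crashes.
            llama_batch_ext_free(batch)
        }
    }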
@@ -111,7 +95,7 @@ actor LlamaContext {
     }
 
     func get_n_tokens() -> Int32 {
-        return batch.n_tokens;
+        return llama_batch_ext_get_n_tokens(batch)
     }
 
     func completion_init(text: String) {
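
Note: n_tokens was a public field of the old struct; behind an OpaquePointer it is reachable only through the accessor, which is why every former batch.n_tokens read below turns into a call. If the repetition grates, a hypothetical same-file convenience (not part of this commit) keeps call sites short:

    extension LlamaContext {
        // Sugar over the accessor used throughout this diff; works because
        // Swift's `private` is visible to extensions in the same file.
        var batch_n_tokens: Int32 { llama_batch_ext_get_n_tokens(batch) }
    }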
@@ -133,25 +117,25 @@ actor LlamaContext {
             print(String(cString: token_to_piece(token: id) + [0]))
         }
 
-        llama_batch_clear(&batch)
+        llama_batch_ext_clear(batch)
 
         for i1 in 0..<tokens_list.count {
             let i = Int(i1)
-            llama_batch_add(&batch, tokens_list[i], Int32(i), [0], false)
+            llama_batch_ext_add_text(batch, tokens_list[i], Int32(i), [llama_seq_id(0)], 1, false)
         }
-        batch.logits[Int(batch.n_tokens) - 1] = 1 // true
+        llama_batch_ext_set_output_last(batch)
 
-        if llama_decode(context, batch) != 0 {
-            print("llama_decode() failed")
+        if llama_decode_ext(context, batch) != 0 {
+            print("llama_decode_ext() failed")
         }
 
-        n_cur = batch.n_tokens
+        n_cur = llama_batch_ext_get_n_tokens(batch)
     }
 
     func completion_loop() -> String {
         var new_token_id: llama_token = 0
 
-        new_token_id = llama_sampler_sample(sampling, context, batch.n_tokens - 1)
+        new_token_id = llama_sampler_sample(sampling, context, llama_batch_ext_get_n_tokens(batch) - 1)
 
         if llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_len {
            print("\n")
@@ -178,13 +162,13 @@ actor LlamaContext {
         print(new_token_str)
         // tokens_list.append(new_token_id)
 
-        llama_batch_clear(&batch)
-        llama_batch_add(&batch, new_token_id, n_cur, [0], true)
+        llama_batch_ext_clear(batch)
+        llama_batch_ext_add_text(batch, new_token_id, n_cur, [llama_seq_id(0)], 1, true)
 
         n_decode += 1
         n_cur    += 1
 
-        if llama_decode(context, batch) != 0 {
+        if llama_decode_ext(context, batch) != 0 {
             print("failed to evaluate llama!")
         }
 
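
Note: each generation step clears the batch and submits exactly one token at position n_cur. The context's KV cache already holds every earlier token, so llama_decode_ext only needs the newcomer, this time with the output flag true because its logits feed the next llama_sampler_sample call. The step in isolation (a sketch under the same API assumptions as above):

    // One incremental decode step of the generation loop.
    llama_batch_ext_clear(batch)
    llama_batch_ext_add_text(batch, new_token_id, n_cur, [llama_seq_id(0)], 1, true) // true: we sample from this token's logits
    if llama_decode_ext(context, batch) != 0 {
        print("failed to evaluate llama!")
    }
    n_cur += 1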
@@ -201,21 +185,21 @@ actor LlamaContext {
         for _ in 0..<nr {
             // bench prompt processing
 
-            llama_batch_clear(&batch)
+            llama_batch_ext_clear(batch)
 
             let n_tokens = pp
 
             for i in 0..<n_tokens {
-                llama_batch_add(&batch, 0, Int32(i), [0], false)
+                llama_batch_ext_add_text(batch, 0, Int32(i), [llama_seq_id(0)], 1, false)
             }
-            batch.logits[Int(batch.n_tokens) - 1] = 1 // true
+            llama_batch_ext_set_output_last(batch)
 
             llama_kv_self_clear(context)
 
             let t_pp_start = DispatchTime.now().uptimeNanoseconds / 1000;
 
-            if llama_decode(context, batch) != 0 {
-                print("llama_decode() failed during prompt")
+            if llama_decode_ext(context, batch) != 0 {
+                print("llama_decode_ext() failed during prompt")
             }
             llama_synchronize(context)
 
@@ -228,14 +212,14 @@ actor LlamaContext {
             let t_tg_start = DispatchTime.now().uptimeNanoseconds / 1000;
 
             for i in 0..<tg {
-                llama_batch_clear(&batch)
+                llama_batch_ext_clear(batch)
 
                 for j in 0..<pl {
-                    llama_batch_add(&batch, 0, Int32(i), [Int32(j)], true)
+                    llama_batch_ext_add_text(batch, 0, Int32(i), [llama_seq_id(Int32(j))], 1, true)
                 }
 
-                if llama_decode(context, batch) != 0 {
-                    print("llama_decode() failed during text generation")
+                if llama_decode_ext(context, batch) != 0 {
+                    print("llama_decode_ext() failed during text generation")
                 }
                 llama_synchronize(context)
             }
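
Note: in the text-generation benchmark above, each of the pl parallel tokens is assigned its own sequence id (llama_seq_id(Int32(j))), so one decode call advances pl independent streams. Taken together the commit is a mechanical substitution; the old-to-new mapping, as far as this file exercises it (signatures provisional, per the in-progress _ext refactor):

    llama_batch_init(n, embd, n_seq)             ->  llama_batch_ext_init(n, n_seq)
    llama_batch_clear(&batch)                    ->  llama_batch_ext_clear(batch)
    llama_batch_add(&batch, tok, pos, seqs, out) ->  llama_batch_ext_add_text(batch, tok, pos, seqs, n_seqs, out)
    batch.logits[last] = 1                       ->  llama_batch_ext_set_output_last(batch)
    batch.n_tokens                               ->  llama_batch_ext_get_n_tokens(batch)
    llama_decode(ctx, batch)                     ->  llama_decode_ext(ctx, batch)
    llama_batch_free(batch)                      ->  llama_batch_ext_free(batch)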
