Implement prompt batch processing as in main.cpp · coderonion/llama-cpp-python@e24c581 · GitHub


Commit e24c581

Implement prompt batch processing as in main.cpp
1 parent a28cb92 commit e24c581

File tree

1 file changed: +21 -8 lines changed


llama_cpp/llama.py

Lines changed: 21 additions & 8 deletions
@@ -19,6 +19,9 @@ def __init__(
     ):
         self.model_path = model_path
 
+        self.last_n = 64
+        self.max_chunk_size = 32
+
         self.params = llama_cpp.llama_context_default_params()
         self.params.n_ctx = n_ctx
         self.params.n_parts = n_parts
@@ -59,21 +62,32 @@ def __call__(
             self.ctx, prompt.encode("utf-8"), self.tokens, self.params.n_ctx, True
         )
 
-        if prompt_tokens + max_tokens > self.params.n_ctx:
+        if prompt_tokens + max_tokens > llama_cpp.llama_n_ctx(self.ctx):
             raise ValueError(
                 f"Requested tokens exceed context window of {self.params.n_ctx}"
             )
 
-        for i in range(prompt_tokens):
-            llama_cpp.llama_eval(
-                self.ctx, (llama_cpp.c_int * 1)(self.tokens[i]), 1, i, self.n_threads
+        # Process prompt in chunks to avoid running out of memory
+        for i in range(0, prompt_tokens, self.max_chunk_size):
+            chunk = self.tokens[i : min(prompt_tokens, i + self.max_chunk_size)]
+            rc = llama_cpp.llama_eval(
+                self.ctx,
+                (llama_cpp.llama_token * len(chunk))(*chunk),
+                len(chunk),
+                max(0, i - 1),
+                self.n_threads,
             )
+            if rc != 0:
+                raise RuntimeError(f"Failed to evaluate prompt: {rc}")
 
         for i in range(max_tokens):
+            tokens_seen = prompt_tokens + completion_tokens
+            last_n_tokens = [0] * max(0, self.last_n - tokens_seen) + [self.tokens[j] for j in range(max(tokens_seen - self.last_n, 0), tokens_seen)]
+
             token = llama_cpp.llama_sample_top_p_top_k(
                 self.ctx,
-                self.tokens,
-                prompt_tokens + completion_tokens,
+                (llama_cpp.llama_token * len(last_n_tokens))(*last_n_tokens),
+                len(last_n_tokens),
                 top_k=top_k,
                 top_p=top_p,
                 temp=temperature,
@@ -82,7 +96,6 @@ def __call__(
             if token == llama_cpp.llama_token_eos():
                 finish_reason = "stop"
                 break
-            # text += llama_cpp.llama_token_to_str(self.ctx, token).decode("utf-8")
             text += llama_cpp.llama_token_to_str(self.ctx, token)
             self.tokens[prompt_tokens + i] = token
             completion_tokens += 1
@@ -96,7 +109,7 @@ def __call__(
 
             llama_cpp.llama_eval(
                 self.ctx,
-                (llama_cpp.c_int * 1)(self.tokens[prompt_tokens + i]),
+                (llama_cpp.llama_token * 1)(self.tokens[prompt_tokens + i]),
                 1,
                 prompt_tokens + completion_tokens,
                 self.n_threads,
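The chunked prompt evaluation above follows the batch processing in llama.cpp's main.cpp: instead of one llama_eval call per prompt token, tokens are fed in slices of at most max_chunk_size, with the number of previously evaluated tokens passed as the past-token count. A minimal stand-alone sketch of that loop, assuming a hypothetical eval_fn in place of llama_cpp.llama_eval and using the chunk offset i as the past-token count (the diff itself passes max(0, i - 1)):

# Illustrative sketch only; eval_fn and eval_prompt_in_chunks are hypothetical
# names, not part of the llama_cpp API.
def eval_prompt_in_chunks(tokens, eval_fn, max_chunk_size=32, n_threads=4):
    """Feed `tokens` to `eval_fn` in slices of at most `max_chunk_size` tokens.

    `eval_fn(chunk, n_past, n_threads)` is assumed to return 0 on success,
    mirroring the return convention of llama_cpp.llama_eval.
    """
    n = len(tokens)
    for i in range(0, n, max_chunk_size):
        chunk = tokens[i : min(n, i + max_chunk_size)]
        rc = eval_fn(chunk, i, n_threads)  # i = tokens already evaluated
        if rc != 0:
            raise RuntimeError(f"Failed to evaluate prompt chunk at offset {i}: {rc}")


if __name__ == "__main__":
    calls = []
    ok = lambda chunk, n_past, n_threads: calls.append((len(chunk), n_past)) or 0
    eval_prompt_in_chunks(list(range(70)), ok, max_chunk_size=32)
    print(calls)  # [(32, 0), (32, 32), (6, 64)]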

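The sampling change replaces the full token buffer passed to llama_sample_top_p_top_k with a fixed-length window of the last_n most recent tokens, zero-padded on the left while fewer than last_n tokens have been seen. A small stand-alone sketch of that window construction (the helper name and defaults here are illustrative, not llama_cpp API):

# Illustrative sketch only; build_last_n_window is a hypothetical helper.
def build_last_n_window(tokens, tokens_seen, last_n=64):
    """Return exactly `last_n` token ids: the most recent ones, left-padded
    with 0 while fewer than `last_n` tokens have been seen."""
    pad = [0] * max(0, last_n - tokens_seen)
    recent = [tokens[j] for j in range(max(tokens_seen - last_n, 0), tokens_seen)]
    return pad + recent


if __name__ == "__main__":
    toks = list(range(100, 110))                    # 10 tokens seen so far
    print(build_last_n_window(toks, 10, last_n=8))  # the last 8 tokens: [102, ..., 109]
    print(build_last_n_window(toks, 3, last_n=8))   # [0, 0, 0, 0, 0, 100, 101, 102]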
0 commit comments
