8000 Chat llama.cpp example implementation · coderonion/llama-cpp-python@f1615f0 · GitHub
[go: up one dir, main page]

Skip to content

Commit f1615f0

Browse files
author
Mug
committed
Chat llama.cpp example implementation
1 parent 7d1977e commit f1615f0

File tree

1 file changed

+235
-0
lines changed

1 file changed

+235
-0
lines changed
Lines changed: 235 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,235 @@
1+
"""
2+
This is an example implementation of main.cpp from llama.cpp
3+
Quirks:
4+
* Its not exactly alike since this port is designed around programmatic I/O
5+
* Input is always echoed if on, so it should be turned off when using "input()"
6+
* The first antiprompt should be the userprompt like "\nUser:",
7+
because its added when n_predict is reached (aka generation ended prematurely)
8+
* n_predict can be set to -1 for unlimited length responses
9+
"""
10+
import llama_cpp
11+
12+
def toIntArray(lst):
13+
return [int(i) for i in lst]
14+
15+
# A LLaMA interactive session
16+
class LLaMAInteract:
17+
def __init__(self,
18+
primer: str="",
19+
model: str="./models/30B/ggml-model-q4_0.bin",
20+
n_ctx: int=1024,
21+
seed: int=0,
22+
n_threads: int=8,
23+
antiprompt: list[str]=[],
24+
input_echo: bool=True,
25+
n_predict: int=20,
26+
n_batch: int=8,
27+
repeat_last_n: int=64,
28+
top_k: int=50,
29+
top_p: float=1.,
30+
temp: float=1.0,
31+
repeat_penalty: float=1,
32+
) -> None:
33+
# input args
34+
self.n_threads = n_threads
35+
self.input_echo = input_echo
36+
self.n_predict = n_predict
37+
self.n_batch = n_batch
38+
self.repeat_last_n = repeat_last_n
39+
self.top_k=top_k
40+
self.top_p=top_p
41+
self.temp=temp
42+
self.repeat_penalty=repeat_penalty
43+
self.n_ctx = n_ctx
44+
self.seed = seed
45+
46+
# runtime args
47+
self.input_consumed = 0
48+
self.embd = []
49+
self.embd_inp = []
50+
self.n_past = 0
51+
self.first_antiprompt = []
52+
self.remaining_tokens = self.n_predict
53+
self.output_echo = input_echo
54+
55+
# model load
56+
self.lparams = llama_cpp.llama_context_default_params()
57+
self.lparams.n_ctx = self.n_ctx
58+
self.lparams.seed = self.seed
59+
self.ctx = llama_cpp.llama_init_from_file(model.encode("utf8"), self.lparams)
60+
61+
# determine the required inference memory per token:
62+
tmp = [0, 1, 2, 3]
63+
llama_cpp.llama_eval(self.ctx, (llama_cpp.c_int * len(tmp))(*tmp), len(tmp), 0, self.n_threads)
64+
65+
# determine newline token
66+
self.llama_token_newline = (llama_cpp.llama_token * 1)()
67+
llama_cpp.llama_tokenize(self.ctx, b"\n", self.llama_token_newline, len(self.llama_token_newline), False)
68+
self.llama_token_newline = toIntArray(self.llama_token_newline)
69+
70+
# primer feed
71+
if (len(primer) > 0):
72+
self.input(primer)
73+
self.n_keep = len(self.embd_inp)
74+
75+
# create internal context
76+
self.n_ctx = int(llama_cpp.llama_n_ctx(self.ctx))
77+
self.last_n_tokens = [0]*self.n_ctx #TODO: deque doesnt support slices
78+
79+
# determine antiprompt tokens
80+
for i in antiprompt:
81+
d_antiprompt = (llama_cpp.llama_token * (len(i) + 1))()
82+
n_antiprompt = llama_cpp.llama_tokenize(self.ctx, i.encode("utf8"), d_antiprompt, len(d_antiprompt), False)
83+
self.first_antiprompt.append(toIntArray(d_antiprompt[:n_antiprompt]))
84+
85+
# if an antiprompt is present
86+
def use_antiprompt(self):
87+
return len(self.first_antiprompt) > 0
88+
89+
def generate(self):
90+
while self.remaining_tokens > 0 or self.use_antiprompt():
91+
# predict
92+
if len(self.embd) > 0:
93+
# infinite text generation via context swapping
94+
# if we run out of context:
95+
# - take the n_keep first tokens from the original prompt (via n_past)
96+
# - take half of the last (n_ctx - n_keep) tokens and recompute the logits in a batch
97+
if (self.n_past + len(self.embd) > self.n_ctx):
98+
n_left = self.n_past - self.n_keep
99+
self.n_past = self.n_keep
100+
101+
# insert n_left/2 tokens at the start of embd from last_n_tokens
102+
_insert = self.last_n_tokens[
103+
-(int(n_left/2) - len(self.embd)):-len(self.embd)
104+
]
105+
self.embd[:len(_insert)] = _insert
106+
#TODO: Still untested
107+
108+
if (llama_cpp.llama_eval(
109+
self.ctx, (llama_cpp.llama_token * len(self.embd))(*self.embd), len(self.embd), self.n_past, self.n_threads
110+
) != 0):
111+
raise Exception("Failed to llama_eval!")
112+
113+
self.n_past += len(self.embd)
114+
self.embd = []
115+
if len(self.embd_inp) <= self.input_consumed:
116+
# out of user input, sample next token
117+
_arr = self.last_n_tokens[-min(self.repeat_last_n, self.n_past):]
118+
id = llama_cpp.llama_sample_top_p_top_k(
119+
self.ctx,
120+
(llama_cpp.llama_token * len(_arr))(*_arr),
121+
len(_arr),
122+
self.top_k,
123+
self.top_p,
124+
self.temp,
125+
self.repeat_penalty,
126+
)
127+
self.last_n_tokens.pop(0)
128+
self.last_n_tokens.append(int(id))
129+
130+
# replace end of text token with newline token when in interactive mode
131+
if (id == llama_cpp.llama_token_eos() and self.use_antiprompt()):
132+
id = self.llama_token_newline[0]
133+
# tokenize and inject first reverse prompt
134+
self.embd_inp += self.first_antiprompt[0]
135+
136+
# add it to the context
137+
self.embd.append(int(id))
138+
139+
# echo this to console
140+
self.output_echo = True
141+
142+
# decrement remaining sampling budget
143+
self.remaining_tokens -= 1
144+
else:
145+
# output to console if input echo is on
146+
self.output_echo = self.input_echo
147+
148+
# some user input remains from prompt or interaction, forward it to processing
149+
while len(self.embd_inp) > self.input_consumed:
150+
self.embd.append(int(self.embd_inp[self.input_consumed]))
151+
self.last_n_tokens.pop(0)
152+
self.last_n_tokens.append(int(self.embd_inp[self.input_consumed]))
153+
self.input_consumed += 1
154+
if len(self.embd) >= self.n_batch:
155+
break
156+
157+
# display tokens
158+
if self.output_echo:
159+
for id in self.embd:
160+
yield id
161+
162+
# if antiprompt is present, stop
163+
if (self.use_antiprompt() and len(self.embd_inp) <= self.input_consumed):
164+
for i in self.first_antiprompt:
165+
if i == self.last_n_tokens[-len(i):]:
166+
return
167+
168+
# if end of generation
169+
if len(self.embd) > 0 and self.embd[-1] == llama_cpp.llama_token_eos():
170+
break
171+
172+
# respect n_predict even if antiprompt is present
173+
if (self.use_antiprompt() and self.remaining_tokens <= 0 and self.n_predict != -1):
174+
self.embd_inp += self.first_antiprompt[0]
175+
break
176+
177+
def past(self):
178+
for id in self.last_n_tokens[-self.n_past:]:
179+
yield llama_cpp.llama_token_to_str(self.ctx, id).decode("utf-8")
180+
181+
def input(self, prompt: str):
182+
embd_arr = (llama_cpp.llama_token * (len(prompt) + 1))()
183+
n_of_tok = llama_cpp.llama_tokenize(self.ctx, prompt.encode("utf8"), embd_arr, len(embd_arr), True)
184+
self.embd_inp += toIntArray(embd_arr[:n_of_tok])
185+
186+
def output(self):
187+
self.remaining_tokens = self.n_predict
188+
for id in self.generate():
189+
yield llama_cpp.llama_token_to_str(self.ctx, id).decode("utf-8")
190+
191+
if __name__ == "__main__":
192+
from datetime import datetime
193+
194+
USER_NAME="User"
195+
AI_NAME="ChatLLaMa"
196+
197+
time_now = datetime.now()
198+
prompt = f"""Text transcript of a never ending dialog, where {USER_NAME} interacts with an AI assistant named {AI_NAME}.
199+
{AI_NAME} is helpful, kind, honest, friendly, good at writing and never fails to answer {USER_NAME}’s requests immediately and with details and precision.
200+
There are no annotations like (30 seconds passed...) or (to himself), just what {USER_NAME} and {AI_NAME} say aloud to each other.
201+
The dialog lasts for years, the entirety of it is shared below. It's 10000 pages long.
202+
The transcript only includes text, it does not include markup like HTML and Markdown.
203+
204+
{USER_NAME}: Hello, {AI_NAME}!
205+
{AI_NAME}: Hello {USER_NAME}! How may I help you today?
206+
{USER_NAME}: What time is it?
207+
{AI_NAME}: It is {time_now.strftime("%H:%M")}.
208+
{USER_NAME}: What year is it?
209+
{AI_NAME}: We are in {time_now.strftime("%Y")}.
210+
{USER_NAME}: What is a cat?
211+
{AI_NAME}: A cat is a domestic species of small carnivorous mammal. It is the only domesticated species in the family Felidae.
212+
{USER_NAME}: Name a color.
213+
{AI_NAME}: Blue
214+
{USER_NAME}:"""
215+
216+
print("Loading model...")
217+
ll = LLaMAInteract(prompt,
218+
model="./models/30B/ggml-model-q4_0.bin",
219+
n_ctx=2048,
220+
antiprompt=[f"\n{USER_NAME}:"],
221+
repeat_last_n=256,
222+
n_predict=2048,
223+
temp=0.7, top_p=0.5, top_k=40, repeat_penalty=1.17647
224+
)
225+
print("Loaded model!")
226+
227+
for i in ll.output():
228+
print(i,end="",flush=True)
229+
ll.input_echo = False
230+
231+
inp = lambda x: f" {x}\n"
232+
while True:
233+
ll.input(inp(input(' ')))
234+
for i in ll.output():
235+
print(i,end="",flush=True)

0 commit comments

Comments
 (0)
0