Update api to allow for easier interactive mode · coderonion/llama-cpp-python@a4a1bbe · GitHub


Commit a4a1bbe

committed
Update api to allow for easier interactive mode
1 parent eef627c commit a4a1bbe

File tree

1 file changed: +76 −32 lines changed


llama_cpp/llama.py

Lines changed: 76 additions & 32 deletions
@@ -63,6 +63,11 @@ def __init__(
         self.params.embedding = embedding
 
         self.last_n_tokens_size = last_n_tokens_size
+        self.last_n_tokens_data = deque(
+            [llama_cpp.llama_token(0)] * self.last_n_tokens_size,
+            maxlen=self.last_n_tokens_size,
+        )
+        self.tokens_consumed = 0
         self.n_batch = n_batch
 
         self.n_threads = n_threads or multiprocessing.cpu_count()
@@ -115,6 +120,67 @@ def detokenize(self, tokens: List[llama_cpp.llama_token]) -> bytes:
             output += llama_cpp.llama_token_to_str(self.ctx, token)
         return output
 
+    def reset(self):
+        """Reset the model state."""
+        self.last_n_tokens_data.extend(
+            [llama_cpp.llama_token(0)] * self.last_n_tokens_size
+        )
+        self.tokens_consumed = 0
+
+    def eval(self, tokens: Sequence[llama_cpp.llama_token]):
+        """Evaluate a list of tokens.
+
+        Args:
+            tokens: The list of tokens to evaluate.
+        """
+        assert self.ctx is not None
+        n_ctx = int(llama_cpp.llama_n_ctx(self.ctx))
+        for i in range(0, len(tokens), self.n_batch):
+            batch = tokens[i : min(len(tokens), i + self.n_batch)]
+            n_past = min(n_ctx - len(batch), self.tokens_consumed)
+            return_code = llama_cpp.llama_eval(
+                ctx=self.ctx,
+                tokens=(llama_cpp.llama_token * len(batch))(*batch),
+                n_tokens=llama_cpp.c_int(len(batch)),
+                n_past=llama_cpp.c_int(n_past),
+                n_threads=llama_cpp.c_int(self.n_threads),
+            )
+            if int(return_code) != 0:
+                raise RuntimeError(f"llama_eval returned {return_code}")
+            self.last_n_tokens_data.extend(batch)
+            self.tokens_consumed += len(batch)
+
+    def sample(
+        self,
+        top_k: int,
+        top_p: float,
+        temp: float,
+        repeat_penalty: float,
+    ):
+        """Sample a token from the model.
+
+        Args:
+            top_k: The top-k sampling parameter.
+            top_p: The top-p sampling parameter.
+            temp: The temperature parameter.
+            repeat_penalty: The repeat penalty parameter.
+
+        Returns:
+            The sampled token.
+        """
+        assert self.ctx is not None
+        return llama_cpp.llama_sample_top_p_top_k(
+            ctx=self.ctx,
+            last_n_tokens_data=(llama_cpp.llama_token * self.last_n_tokens_size)(
+                *self.last_n_tokens_data
+            ),
+            last_n_tokens_size=llama_cpp.c_int(self.last_n_tokens_size),
+            top_k=llama_cpp.c_int(top_k),
+            top_p=llama_cpp.c_float(top_p),
+            temp=llama_cpp.c_float(temp),
+            repeat_penalty=llama_cpp.c_float(repeat_penalty),
+        )
+
     def generate(
         self,
         tokens: Sequence[llama_cpp.llama_token],
@@ -125,7 +191,7 @@ def generate(
     ) -> Generator[
         llama_cpp.llama_token, Optional[Sequence[llama_cpp.llama_token]], None
     ]:
-        """Generate tokens.
+        """Create a generator of tokens from a prompt.
 
         Examples:
             >>> llama = Llama("models/ggml-7b.bin")
149215
top_p = 0.0
150216
top_k = 1
151217
assert self.ctx is not None
152-
n_ctx = int(llama_cpp.llama_n_ctx(self.ctx))
153-
n_tokens = 0
154-
last_n_tokens = deque(
155-
[llama_cpp.llama_token(0)] * self.last_n_tokens_size,
156-
maxlen=self.last_n_tokens_size,
157-
)
218+
self.reset()
158219
while True:
159-
for i in range(0, len(tokens), self.n_batch):
160-
batch = tokens[i : min(len(tokens), i + self.n_batch)]
161-
n_past = min(n_ctx - len(batch), n_tokens)
162-
return_code = llama_cpp.llama_eval(
163-
ctx=self.ctx,
164-
tokens=(llama_cpp.llama_token * len(batch))(*batch),
165-
n_tokens=llama_cpp.c_int(len(batch)),
166-
n_past=llama_cpp.c_int(n_past),
167-
n_threads=llama_cpp.c_int(self.n_threads),
168-
)
169-
if int(return_code) != 0:
170-
raise RuntimeError(f"llama_eval returned {return_code}")
171-
last_n_tokens.extend(batch)
172-
n_tokens += len(batch)
173-
token = llama_cpp.llama_sample_top_p_top_k(
174-
ctx=self.ctx,
175-
last_n_tokens_data=(llama_cpp.llama_token * self.last_n_tokens_size)(
176-
*last_n_tokens
177-
),
178-
last_n_tokens_size=llama_cpp.c_int(self.last_n_tokens_size),
179-
top_k=llama_cpp.c_int(top_k),
180-
top_p=llama_cpp.c_float(top_p),
181-
temp=llama_cpp.c_float(temp),
182-
repeat_penalty=llama_cpp.c_float(repeat_penalty),
220+
self.eval(tokens)
221+
token = self.sample(
222+
top_k=top_k,
223+
top_p=top_p,
224+
temp=temp,
225+
repeat_penalty=repeat_penalty,
183226
)
184227
tokens_or_none = yield token
185228
tokens = [token]
@@ -197,7 +240,8 @@ def create_embedding(self, input: str) -> Embedding:
         """
         assert self.ctx is not None
         tokens = self.tokenize(input.encode("utf-8"))
-        next(self.generate(tokens, top_k=0, top_p=0.0, temp=1.0, repeat_penalty=1.0))
+        self.reset()
+        self.eval(tokens)
        n_tokens = len(tokens)
         embedding = llama_cpp.llama_get_embeddings(self.ctx)[
             : llama_cpp.llama_n_embd(self.ctx)
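
The new reset(), eval(), and sample() methods expose the steps that generate() previously performed internally, so a caller can drive prompt evaluation and per-token sampling itself for interactive use. The sketch below is not part of the commit; it only assumes the method signatures shown in the diff above. The model path is a placeholder and the sampling values are illustrative.

from llama_cpp.llama import Llama

# Placeholder path to a ggml-format model file supported by llama.cpp.
llm = Llama(model_path="./models/ggml-model.bin")

# Evaluate the prompt once, then alternate sample()/eval() one token at a time.
llm.reset()
prompt_tokens = llm.tokenize(b"User: Hello, how are you?\nAssistant:")
llm.eval(prompt_tokens)

generated = []
for _ in range(32):  # fixed token budget instead of stop-token handling
    token = llm.sample(top_k=40, top_p=0.95, temp=0.8, repeat_penalty=1.1)
    generated.append(token)
    llm.eval([token])  # feed the sampled token back before sampling again

print(llm.detokenize(generated).decode("utf-8", errors="ignore"))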
