Fix session loading and saving in low level example chat · fabregas201307/llama-cpp-python@2c0d9b1 · GitHub

Commit 2c0d9b1

Fix session loading and saving in low level example chat
1 parent ed66a46 commit 2c0d9b1

File tree

1 file changed (+14, −14)

examples/low_level_api/low_level_api_chat_cpp.py

Lines changed: 14 additions & 14 deletions
```diff
@@ -112,16 +112,17 @@ def __init__(self, params: GptParams) -> None:
 
         if (path.exists(self.params.path_session)):
             _session_tokens = (llama_cpp.llama_token * (self.params.n_ctx))()
-            _n_token_count_out = llama_cpp.c_int()
+            _n_token_count_out = llama_cpp.c_size_t()
             if (llama_cpp.llama_load_session_file(
                 self.ctx,
                 self.params.path_session.encode("utf8"),
                 _session_tokens,
                 self.params.n_ctx,
                 ctypes.byref(_n_token_count_out)
-            ) != 0):
+            ) != 1):
                 print(f"error: failed to load session file '{self.params.path_session}'", file=sys.stderr)
                 return
+            _n_token_count_out = _n_token_count_out.value
             self.session_tokens = _session_tokens[:_n_token_count_out]
             print(f"loaded a session with prompt size of {_n_token_count_out} tokens", file=sys.stderr)
         else:
```
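The load-side fix is three small ctypes corrections: the out-parameter must be a `c_size_t` (the C API writes the token count through a `size_t *`), success is signalled by a return value of 1 (true), so the old `!= 0` failure check could never fire, and the wrapper has to be unwrapped with `.value` before it can be used as a slice index (slicing with the raw ctypes object raises `TypeError`). Below is a minimal standalone sketch of the corrected pattern, assuming the same `llama_cpp` bindings used in the diff and an already-initialized context; `try_load_session` is a hypothetical helper, not part of the commit:

```python
import ctypes
import llama_cpp

def try_load_session(ctx, path: str, n_ctx: int):
    """Return the tokens stored in a session file, or None on failure."""
    tokens = (llama_cpp.llama_token * n_ctx)()  # output buffer with n_ctx capacity
    n_out = llama_cpp.c_size_t()                # receives the number of tokens read
    ok = llama_cpp.llama_load_session_file(
        ctx,
        path.encode("utf8"),
        tokens,
        n_ctx,
        ctypes.byref(n_out),
    )
    if ok != 1:                                 # 1 (true) means success
        return None
    return tokens[:n_out.value]                 # unwrap c_size_t before slicing
```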
```diff
@@ -135,19 +136,21 @@ def __init__(self, params: GptParams) -> None:
             raise RuntimeError(f"error: prompt is too long ({len(self.embd_inp)} tokens, max {self.params.n_ctx - 4})")
 
         # debug message about similarity of saved session, if applicable
-        n_matching_session_tokens = 0
+        self.n_matching_session_tokens = 0
         if len(self.session_tokens) > 0:
             for id in self.session_tokens:
-                if n_matching_session_tokens >= len(self.embd_inp) or id != self.embd_inp[n_matching_session_tokens]:
+                if self.n_matching_session_tokens >= len(self.embd_inp) or id != self.embd_inp[self.n_matching_session_tokens]:
                     break
-                n_matching_session_tokens += 1
+                self.n_matching_session_tokens += 1
 
-            if n_matching_session_tokens >= len(self.embd_inp):
+            if self.n_matching_session_tokens >= len(self.embd_inp):
                 print(f"session file has exact match for prompt!")
-            elif n_matching_session_tokens < (len(self.embd_inp) / 2):
-                print(f"warning: session file has low similarity to prompt ({n_matching_session_tokens} / {len(self.embd_inp)} tokens); will mostly be reevaluated")
+            elif self.n_matching_session_tokens < (len(self.embd_inp) / 2):
+                print(f"warning: session file has low similarity to prompt ({self.n_matching_session_tokens} / {len(self.embd_inp)} tokens); will mostly be reevaluated")
             else:
-                print(f"session file matches {n_matching_session_tokens} / {len(self.embd_inp)} tokens of prompt")
+                print(f"session file matches {self.n_matching_session_tokens} / {len(self.embd_inp)} tokens of prompt")
+
+        self.need_to_save_session = len(self.params.path_session) > 0 and self.n_matching_session_tokens < (len(self.embd_inp) * 3 / 4)
 
         # number of tokens to keep when resetting context
         if (self.params.n_keep < 0 or self.params.n_keep > len(self.embd_inp) or self.params.instruct):
```
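The similarity check is a plain longest-common-prefix count between the saved session tokens and the freshly tokenized prompt; storing the count on `self` keeps it available after `__init__` returns. The same logic as a self-contained function (an illustration; the function name is hypothetical):

```python
def count_matching_prefix(session_tokens, prompt_tokens):
    """How many leading tokens of the saved session match the new prompt."""
    n = 0
    for tok in session_tokens:
        if n >= len(prompt_tokens) or tok != prompt_tokens[n]:
            break
        n += 1
    return n

# Only the matched prefix can be reused without re-evaluation:
assert count_matching_prefix([1, 2, 3, 9], [1, 2, 3, 4, 5]) == 3
```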
```diff
@@ -232,9 +235,6 @@ def __init__(self, params: GptParams) -> None:
 """, file=sys.stderr)
         self.set_color(util.CONSOLE_COLOR_PROMPT)
 
-        self.need_to_save_session = len(self.params.path_session) > 0 and n_matching_session_tokens < (len(self.embd_inp) * 3 / 4)
-
-
     # tokenize a prompt
     def _tokenize(self, prompt, bos=True):
         _arr = (llama_cpp.llama_token * ((len(prompt) + 1) * 4))()
```
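This hunk only deletes the old copy of the `need_to_save_session` assignment, which after the rename above would have referenced the no-longer-defined local `n_matching_session_tokens`. The heuristic it relocates, shown as a hypothetical standalone sketch: a session is worth (re)saving only when a path is configured and less than three quarters of the prompt was reusable.

```python
def needs_session_save(path_session: str, n_matching: int, n_prompt: int) -> bool:
    """Skip re-saving the session when most of the prompt already matched."""
    return len(path_session) > 0 and n_matching < (n_prompt * 3 / 4)

assert needs_session_save("session.bin", 10, 100) is True   # only 10% reused
assert needs_session_save("session.bin", 90, 100) is False  # 90% reused
assert needs_session_save("", 0, 100) is False              # no session path set
```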
```diff
@@ -302,7 +302,7 @@ def generate(self):
                 ) != 0):
                     raise Exception("Failed to llama_eval!")
 
-                if len(self.embd) > 0 and not len(self.params.path_session) > 0:
+                if len(self.embd) > 0 and len(self.params.path_session) > 0:
                     self.session_tokens.extend(self.embd)
                     self.n_session_consumed = len(self.session_tokens)
 
```
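The old condition was an operator-precedence bug: `not` binds the whole comparison, so `not len(s) > 0` is true exactly when `s` is empty, meaning session tokens were only accumulated when no session path was configured. A two-line demonstration (illustration only):

```python
path_session = "session.bin"
# `not len(x) > 0` parses as `not (len(x) > 0)`, i.e. "x is empty":
assert (not len(path_session) > 0) == (len(path_session) == 0)
```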

```diff
@@ -319,7 +319,7 @@ def generate(self):
                 llama_cpp.llama_save_session_file(
                     self.ctx,
                     self.params.path_session.encode("utf8"),
-                    self.session_tokens,
+                    (llama_cpp.llama_token * len(self.session_tokens))(*self.session_tokens),
                     len(self.session_tokens)
                 )
 
```
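On the save side, `llama_save_session_file` expects a C array of `llama_token`, while `self.session_tokens` is a plain Python list, so the list has to be marshalled into a contiguous ctypes buffer first; `(llama_cpp.llama_token * n)(*tokens)` allocates the array and copies the values in one step. A sketch of the same pattern under the binding assumptions above (`save_session` is a hypothetical helper):

```python
import llama_cpp

def save_session(ctx, path: str, session_tokens) -> None:
    # Marshal the Python list into a contiguous C array of llama_token.
    arr = (llama_cpp.llama_token * len(session_tokens))(*session_tokens)
    llama_cpp.llama_save_session_file(
        ctx,
        path.encode("utf8"),
        arr,
        len(session_tokens),
    )
```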
