Merge pull request #87 from SagsMug/main · hongbopeng/llama-cpp-python@4ce6670

Commit 4ce6670

Merge pull request abetlen#87 from SagsMug/main
Fix TypeError in low_level chat
2 parents eb7f278 + 1b73a15 commit 4ce6670

File tree

2 files changed: +8 -7 lines changed

examples/low_level_api/common.py

Lines changed: 2 additions & 2 deletions
@@ -50,7 +50,7 @@ class GptParams:
     # If chat ended prematurely, append this to the conversation to fix it.
     # Set to "\nUser:" etc.
     # This is an alternative to input_prefix which always adds it, so it potentially duplicates "User:""
-    fix_prefix: str = " "
+    fix_prefix: str = ""
     output_postfix: str = ""
     input_echo: bool = True,

@@ -75,7 +75,7 @@ def gpt_params_parse(argv = None, params: Optional[GptParams] = None):
     parser.add_argument("--top_p", type=float, default=0.95, help="top-p samplin",dest="top_p")
     parser.add_argument("--top_k", type=int, default=40, help="top-k sampling",dest="top_k")
     parser.add_argument("--temp", type=float, default=0.80, help="temperature",dest="temp")
-    parser.add_argument("--n_predict", type=int, default=128, help="number of model parts",dest="n_predict")
+    parser.add_argument("--n_predict", type=int, default=128, help="number of tokens to predict (-1 = infinity)",dest="n_predict")
     parser.add_argument("--repeat_last_n", type=int, default=64, help="last n tokens to consider for penalize ",dest="repeat_last_n")
     parser.add_argument("--repeat_penalty", type=float, default=1.10, help="penalize repeat sequence of tokens",dest="repeat_penalty")
     parser.add_argument("-b", "--batch_size", type=int, default=8, help="batch size for prompt processing",dest="n_batch")

examples/low_level_api/low_level_api_chat_cpp.py

Lines changed: 6 additions & 5 deletions
@@ -144,6 +144,7 @@ def __init__(self, params: GptParams) -> None:

         # determine newline token
         self.llama_token_newline = self._tokenize("\n", False)
+        self.llama_token_eot = self._tokenize(" [end of text]\n", False)

         if (self.params.verbose_prompt):
             print(f"""

@@ -203,16 +204,16 @@ def _tokenize(self, prompt, bos=True):
         _n = llama_cpp.llama_tokenize(self.ctx, prompt.encode("utf8"), _arr, len(_arr), bos)
         return _arr[:_n]

-    def use_antiprompt(self):
-        return len(self.first_antiprompt) > 0
-
     def set_color(self, c):
         if (self.params.use_color):
             print(c, end="")

+    def use_antiprompt(self):
+        return len(self.first_antiprompt) > 0
+
     # generate tokens
     def generate(self):
-        while self.remaining_tokens > 0 or self.params.interactive:
+        while self.remaining_tokens > 0 or self.params.interactive or self.params.n_predict == -1:
             # predict
             if len(self.embd) > 0:
                 # infinite text generation via context swapping

@@ -313,7 +314,7 @@ def generate(self):
             # end of text token
             if len(self.embd) > 0 and self.embd[-1] == llama_cpp.llama_token_eos():
                 if (not self.params.instruct):
-                    for i in " [end of text]\n":
+                    for i in self.llama_token_eot:
                         yield i
                     break
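
The TypeError named in the commit title came from the old for i in " [end of text]\n" loop: generate() yielded single str characters, while the caller feeds each yielded value to the ctypes-backed llama_cpp.llama_token_to_str, which expects an integer token id. Tokenizing the string once in __init__ (self.llama_token_eot) and yielding those ids removes the mismatch. A minimal sketch of the failure mode (token_to_str is a hypothetical stand-in, not the library's binding):

import ctypes

def token_to_str(token: int) -> bytes:
    # Stand-in for a ctypes binding: anything but an int is rejected.
    return bytes(ctypes.c_int(token))

try:
    token_to_str(" ")             # old code path: yielded str characters
except TypeError as exc:
    print("old behaviour:", exc)  # the TypeError named in the commit title

print(token_to_str(13))           # new code path: int token ids work fine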
