8000 fixed mistake done while reading · handshape/llama-cpp-python@4c74a82 · GitHub
[go: up one dir, main page]

Skip to content

Commit 4c74a82

Browse files
authored
fixed mistake done while reading
1 parent 5c050e8 commit 4c74a82

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

llama_cpp/llama_grammar.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -841,7 +841,7 @@ def handle_repetitions(min_times: int, max_times: int) -> None:
841841
raise RuntimeError("expecting preceding item to */+/?/{ at " + str(pos))
842842

843843

844-
previous_elements = out_elements[last_sym_start:]
844+
previous_elements:std.vector[LlamaGrammarElement] = out_elements[last_sym_start:out_elements.size()]
845845

846846
if min_times == 0:
847847
out_elements.resize(last_sym_start)
@@ -859,12 +859,12 @@ def handle_repetitions(min_times: int, max_times: int) -> None:
859859
rec_rule.resize(len(previous_elements))
860860
rec_rule_id = generate_symbol_id(state, rule_name) # type: int
861861
if i > 0 or max_times < 0:
862-
rec_rule.push_back(LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_RULE_REF, rec_rule_id))
862+
rec_rule.push_back(LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_RULE_REF, rec_rule_id if max_times < 0 else last_rec_rule_id))
863863
rec_rule.push_back(LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_ALT, 0))
864864
rec_rule.push_back(LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_END, 0))
865865
add_rule(state, rec_rule_id, rec_rule)
866-
867866
last_rec_rule_id = rec_rule_id
867+
868868
if n_opt > 0:
869869
out_elements.push_back(LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_RULE_REF, last_rec_rule_id))
870870

@@ -1058,6 +1058,7 @@ def handle_repetitions(min_times: int, max_times: int) -> None:
10581058
max_times = min_times
10591059
pos = parse_space(pos + 1, is_nested)
10601060
elif pos[0] == ",":
1061+
10611062
pos = parse_space(pos + 1, is_nested)
10621063
if is_digit_char(pos[0]):
10631064
int_end = parse_int(pos)
@@ -1281,6 +1282,7 @@ def print_rule(
12811282
# break;
12821283
# }
12831284

1285+
12841286
for i, elem in enumerate(rule[:-1]):
12851287
case = elem.type # type: llama_gretype
12861288
if case is llama_gretype.LLAMA_GRETYPE_END:

0 commit comments

Comments
 (0)
0