@@ -682,6 +682,101 @@ static bool llama_grammar_match_partial_char(
682
682
return !is_positive_char;
683
683
}
684
684
685
+ // transforms a grammar pushdown stack into N possible stacks, all ending
686
+ // at a character range (terminal element)
687
+ // additionally memoizes the stack to its possible stacks by mapping
688
+ // < llama_grammar_stack, llama_grammar_stacks >
689
+
690
+ static void llama_grammar_advance_stack_memo (
691
+ const llama_grammar_rules & rules,
692
+ const llama_grammar_stack & stack,
693
+ llama_grammar_stacks & new_stacks,
694
+ llama_grammar_stacks_cache & stacks_cache);
695
+
696
+ static void llama_grammar_advance_stack_memo_impl (
697
+ const llama_grammar_rules & rules,
698
+ const llama_grammar_stack & stack,
699
+ llama_grammar_stacks & new_stacks,
700
+ llama_grammar_stacks_cache & stacks_cache) {
701
+ if (stack.empty ()) {
702
+ if (std::find (new_stacks.begin (), new_stacks.end (), stack) == new_stacks.end ()) {
703
+ new_stacks.emplace_back (stack);
704
+ }
705
+ return ;
706
+ }
707
+
708
+ const llama_grammar_element * pos = stack.back ();
709
+
710
+ switch (pos->type ) {
711
+ case LLAMA_GRETYPE_RULE_REF: {
712
+ const size_t rule_id = static_cast <size_t >(pos->value );
713
+ const llama_grammar_element * subpos = rules[rule_id].data ();
714
+ do {
715
+ // init new stack without the top (pos)
716
+ llama_grammar_stack new_stack (stack.begin (), stack.end () - 1 );
717
+ if (!llama_grammar_is_end_of_sequence (pos + 1 )) {
718
+ // if this rule ref is followed by another element, add that to stack
719
+ new_stack.push_back (pos + 1 );
720
+ }
721
+ if (!llama_grammar_is_end_of_sequence (subpos)) {
722
+ // if alternate is nonempty, add to stack
723
+ new_stack.push_back (subpos);
724
+ }
725
+ llama_grammar_advance_stack_memo (rules, new_stack, new_stacks, stacks_cache);
726
+ while (!llama_grammar_is_end_of_sequence (subpos)) {
727
+ // scan to end of alternate def
728
+ subpos++;
729
+ }
730
+ if (subpos->type == LLAMA_GRETYPE_ALT) {
731
+ // there's another alternate def of this rule to process
732
+ subpos++;
733
+ } else {
734
+ break ;
735
+ }
736
+ } while (true );
737
+ break ;
738
+ }
739
+ case LLAMA_GRETYPE_CHAR:
740
+ case LLAMA_GRETYPE_CHAR_NOT:
741
+ case LLAMA_GRETYPE_CHAR_ANY:
742
+ if (std::find (new_stacks.begin (), new_stacks.end (), stack) == new_stacks.end ()) {
743
+ // only add the stack if it's not a duplicate of one we already have
744
+ new_stacks.emplace_back (stack);
745
+ }
746
+ break ;
747
+ default :
748
+ // end of alternate (LLAMA_GRETYPE_END, LLAMA_GRETYPE_ALT) or middle of char range
749
+ // (LLAMA_GRETYPE_CHAR_ALT, LLAMA_GRETYPE_CHAR_RNG_UPPER); stack should never be left on
750
+ // those
751
+ GGML_ABORT (" fatal error" );
752
+ }
753
+ }
754
+
755
+ static void llama_grammar_advance_stack_memo (
756
+ const llama_grammar_rules & rules,
757
+ const llama_grammar_stack & stack,
758
+ llama_grammar_stacks & new_stacks,
759
+ llama_grammar_stacks_cache & stacks_cache) {
760
+
761
+ llama_grammar_stacks advanced_stacks;
762
+ // Look if stack is already in memory
763
+ auto it = stacks_cache.find (stack);
764
+ if (it != stacks_cache.end ()) {
765
+ advanced_stacks = it->second ;
766
+ } else {
767
+ // Advance stacks with memoization
768
+ llama_grammar_advance_stack_memo_impl (rules, stack, advanced_stacks, stacks_cache);
769
+ stacks_cache.insert (make_pair (stack, advanced_stacks));
770
+ }
771
+ // Add the advanced stacks to new_stacks avoiding duplicates
772
+ for (const auto & new_stack : advanced_stacks) {
773
+ if (std::find (new_stacks.begin (), new_stacks.end (), new_stack) == new_stacks.end ()) {
774
+ new_stacks.emplace_back (new_stack);
775
+ }
776
+ }
777
+
778
+ }
779
+
685
780
// transforms a grammar pushdown stack into N possible stacks, all ending
686
781
// at a character range (terminal element)
687
782
static void llama_grammar_advance_stack (
@@ -822,15 +917,11 @@ llama_grammar_stacks & llama_grammar_get_stacks(struct llama_grammar * grammar)
822
917
return grammar->stacks ;
823
918
}
824
919
825
- void llama_grammar_accept (
826
- const llama_grammar_rules & rules,
827
- const llama_grammar_stacks & stacks,
828
- const uint32_t chr,
829
- llama_grammar_stacks & stacks_new) {
830
- stacks_new.clear ();
831
- stacks_new.reserve (stacks.size ());
920
+ void llama_grammar_accept (struct llama_grammar * grammar, uint32_t chr) {
921
+ llama_grammar_stacks stacks_new;
922
+ stacks_new.reserve (grammar->stacks .size ());
832
923
833
- for (const auto & stack : stacks) {
924
+ for (const auto & stack : grammar-> stacks ) {
834
925
if (stack.empty ()) {
835
926
continue ;
836
927
}
@@ -844,9 +935,11 @@ void llama_grammar_accept(
844
935
if (!llama_grammar_is_end_of_sequence (pos)) {
845
936
new_stack.push_back (pos);
846
937
}
847
- llama_grammar_advance_stack ( rules, new_stack, stacks_new);
938
+ llama_grammar_advance_stack_memo (grammar-> rules , new_stack, stacks_new, grammar-> stacks_cache );
848
939
}
849
940
}
941
+
942
+ grammar->stacks = std::move (stacks_new);
850
943
}
851
944
852
945
llama_grammar_candidates llama_grammar_reject_candidates_for_stack (
@@ -938,14 +1031,15 @@ struct llama_grammar * llama_grammar_init_impl(
938
1031
939
1032
// loop over alternates of start rule to build initial stacks
940
1033
llama_grammar_stacks stacks;
1034
+ llama_grammar_stacks_cache stacks_cache;
941
1035
pos = vec_rules[start_rule_index].data ();
942
1036
do {
943
1037
llama_grammar_stack stack;
944
1038
if (!llama_grammar_is_end_of_sequence (pos)) {
945
1039
// if alternate is nonempty, add to stack
946
1040
stack.push_back (pos);
947
1041
}
948
- llama_grammar_advance_stack (vec_rules, stack, stacks);
1042
+ llama_grammar_advance_stack_memo (vec_rules, stack, stacks, stacks_cache );
949
1043
while (!llama_grammar_is_end_of_sequence (pos)) {
950
1044
// scan to end of alternate def
951
1045
pos++;
@@ -961,7 +1055,7 @@ struct llama_grammar * llama_grammar_init_impl(
961
1055
// Important: vec_rules has to be moved here, not copied, because stacks contains
962
1056
// pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
963
1057
// then the pointers would be invalidated when the local vec_rules goes out of scope.
964
- return new llama_grammar { vocab, std::move (vec_rules), std::move (stacks), {}, };
1058
+ return new llama_grammar { vocab, std::move (vec_rules), std::move (stacks), std::move (stacks_cache), {}, };
965
1059
}
966
1060
967
1061
struct llama_grammar * llama_grammar_init_impl (const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root) {
@@ -1016,14 +1110,15 @@ struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab,
1016
1110
1017
1111
// loop over alternates of start rule to build initial stacks
1018
1112
llama_grammar_stacks stacks;
1113
+ llama_grammar_stacks_cache stacks_cache;
1019
1114
pos = vec_rules[start_rule_index].data ();
1020
1115
do {
1021
1116
llama_grammar_stack stack;
1022
1117
if (!llama_grammar_is_end_of_sequence (pos)) {
1023
1118
// if alternate is nonempty, add to stack
1024
1119
stack.push_back (pos);
1025
1120
}
1026
- llama_grammar_advance_stack (vec_rules, stack, stacks);
1121
+ llama_grammar_advance_stack_memo (vec_rules, stack, stacks, stacks_cache );
1027
1122
while (!llama_grammar_is_end_of_sequence (pos)) {
1028
1123
// scan to end of alternate def
1029
1124
pos++;
@@ -1039,7 +1134,7 @@ struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab,
1039
1134
// Important: vec_rules has to be moved here, not copied, because stacks contains
1040
1135
// pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
1041
1136
// then the pointers would be invalidated when the local vec_rules goes out of scope.
1042
- return new llama_grammar { vocab, std::move (vec_rules), std::move (stacks), {}, };
1137
+ return new llama_grammar { vocab, std::move (vec_rules), std::move (stacks), std::move (stacks_cache), {}, };
1043
1138
}
1044
1139
1045
1140
void llama_grammar_free_impl (struct llama_grammar * grammar) {
@@ -1051,15 +1146,21 @@ void llama_grammar_free_impl(struct llama_grammar * grammar) {
1051
1146
}
1052
1147
1053
1148
struct llama_grammar * llama_grammar_clone_impl (const struct llama_grammar & grammar) {
1054
- llama_grammar * result = new llama_grammar { grammar.vocab , grammar.rules , grammar.stacks , grammar.partial_utf8 , };
1149
+ llama_grammar * result = new llama_grammar {
1150
+ grammar.vocab ,
1151
+ grammar.rules ,
1152
+ grammar.stacks ,
1153
+ grammar.stacks_cache ,
1154
+ grammar.partial_utf8 ,
1155
+ };
1055
1156
1056
1157
// redirect elements in stacks to point to new rules
1057
1158
for (size_t is = 0 ; is < result->stacks .size (); is++) {
1058
1159
for (size_t ie = 0 ; ie < result->stacks [is].size (); ie++) {
1059
1160
for (size_t ir0 = 0 ; ir0 < grammar.rules .size (); ir0++) {
1060
1161
for (size_t ir1 = 0 ; ir1 < grammar.rules [ir0].size (); ir1++) {
1061
1162
if (grammar.stacks [is][ie] == &grammar.rules [ir0][ir1]) {
1062
- result->stacks [is][ie] = &result->rules [ir0][ir1];
1163
+ result->stacks [is][ie] = &result->rules [ir0][ir1];
1063
1164
}
1064
1165
}
1065
1166
}
@@ -1126,11 +1227,8 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
1126
1227
const auto decoded = decode_utf8 (piece, grammar.partial_utf8 );
1127
1228
const auto & code_points = decoded.first ;
1128
1229
1129
- llama_grammar_stacks stacks_new;
1130
-
1131
1230
for (auto it = code_points.begin (), end = code_points.end () - 1 ; it != end; ++it) {
1132
- llama_grammar_ac
F438
cept (grammar.rules , grammar.stacks , *it, stacks_new);
1133
- grammar.stacks = std::move (stacks_new);
1231
+ llama_grammar_accept (&grammar, *it);
1134
1232
}
1135
1233
1136
1234
grammar.partial_utf8 = decoded.second ;
0 commit comments