llama : adds llama-grammar memoization stacks (#4218) #9833 · Nexesenex/croco.cpp@e78bed5 · GitHub
[go: up one dir, main page]

Skip to content

Commit e78bed5

Browse files
llama : adds llama-grammar memoization stacks (ggml-org#4218) ggml-org#9833
Grammar memo Co-Authored-By: Clarissa Miranda <80654285+clarissamiranda@users.noreply.github.com>
1 parent 3992df7 commit e78bed5

File tree

3 files changed

+139
-31
lines changed

3 files changed

+139
-31
lines changed

examples/gbnf-validator/gbnf-validator.cpp

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,19 +11,15 @@
1111
static bool llama_grammar_validate(struct llama_grammar * grammar, const std::string & input_str, size_t & error_pos, std::string & error_msg) {
1212
const auto cpts = unicode_cpts_from_utf8(input_str);
1313

14-
const llama_grammar_rules & rules = llama_grammar_get_rules (grammar);
15-
llama_grammar_stacks & stacks_cur = llama_grammar_get_stacks(grammar);
14+
auto & stacks_cur = llama_grammar_get_stacks(grammar);
1615

1716
size_t pos = 0;
1817
for (const auto & cpt : cpts) {
19-
const llama_grammar_stacks stacks_prev = llama_grammar_get_stacks(grammar); // copy
20-
21-
llama_grammar_accept(rules, stacks_prev, cpt, stacks_cur);
18+
llama_grammar_accept(grammar, cpt);
2219

2320
if (stacks_cur.empty()) {
2421
error_pos = pos;
2522
error_msg = "Unexpected character '" + unicode_cpt_to_utf8(cpt) + "'";
26-
stacks_cur = stacks_prev;
2723
return false;
2824
}
2925
++pos;
@@ -82,7 +78,8 @@ int main(int argc, char** argv) {
8278

8379
llama_grammar * grammar = llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root");
8480
if (grammar == nullptr) {
85-
throw std::runtime_error("Failed to initialize llama_grammar");
81+
fprintf(stdout, "Failed to initialize llama_grammar\n");
82+
return 1;
8683
}
8784
// Read the input file
8885
std::string input_str;

src/llama-grammar.cpp

Lines changed: 117 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -682,6 +682,101 @@ static bool llama_grammar_match_partial_char(
682682
return !is_positive_char;
683683
}
684684

685+
// transforms a grammar pushdown stack into N possible stacks, all ending
686+
// at a character range (terminal element)
687+
// additionally memoizes the stack to its possible stacks by mapping
688+
// < llama_grammar_stack, llama_grammar_stacks >
689+
690+
static void llama_grammar_advance_stack_memo(
691+
const llama_grammar_rules & rules,
692+
const llama_grammar_stack & stack,
693+
llama_grammar_stacks & new_stacks,
694+
llama_grammar_stacks_cache & stacks_cache);
695+
696+
static void llama_grammar_advance_stack_memo_impl(
697+
const llama_grammar_rules & rules,
698+
const llama_grammar_stack & stack,
699+
llama_grammar_stacks & new_stacks,
700+
llama_grammar_stacks_cache & stacks_cache) {
701+
if (stack.empty()) {
702+
if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) {
703+
new_stacks.emplace_back(stack);
704+
}
705+
return;
706+
}
707+
708+
const llama_grammar_element * pos = stack.back();
709+
710+
switch (pos->type) {
711+
case LLAMA_GRETYPE_RULE_REF: {
712+
const size_t rule_id = static_cast<size_t>(pos->value);
713+
const llama_grammar_element * subpos = rules[rule_id].data();
714+
do {
715+
// init new stack without the top (pos)
716+
llama_grammar_stack new_stack(stack.begin(), stack.end() - 1);
717+
if (!llama_grammar_is_end_of_sequence(pos + 1)) {
718+
// if this rule ref is followed by another element, add that to stack
719+
new_stack.push_back(pos + 1);
720+
}
721+
if (!llama_grammar_is_end_of_sequence(subpos)) {
722+
// if alternate is nonempty, add to stack
723+
new_stack.push_back(subpos);
724+
}
725+
llama_grammar_advance_stack_memo(rules, new_stack, new_stacks, stacks_cache);
726+
while (!llama_grammar_is_end_of_sequence(subpos)) {
727+
// scan to end of alternate def
728+
subpos++;
729+
}
730+
if (subpos->type == LLAMA_GRETYPE_ALT) {
731+
// there's another alternate def of this rule to process
732+
subpos++;
733+
} else {
734+
break;
735+
}
736+
} while (true);
737+
break;
738+
}
739+
case LLAMA_GRETYPE_CHAR:
740+
case LLAMA_GRETYPE_CHAR_NOT:
741+
case LLAMA_GRETYPE_CHAR_ANY:
742+
if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) {
743+
// only add the stack if it's not a duplicate of one we already have
744+
new_stacks.emplace_back(stack);
745+
}
746+
break;
747+
default:
748+
// end of alternate (LLAMA_GRETYPE_END, LLAMA_GRETYPE_ALT) or middle of char range
749+
// (LLAMA_GRETYPE_CHAR_ALT, LLAMA_GRETYPE_CHAR_RNG_UPPER); stack should never be left on
750+
// those
751+
GGML_ABORT("fatal error");
752+
}
753+
}
754+
755+
static void llama_grammar_advance_stack_memo(
756+
const llama_grammar_rules & rules,
757+
const llama_grammar_stack & stack,
758+
llama_grammar_stacks & new_stacks,
759+
llama_grammar_stacks_cache & stacks_cache) {
760+
761+
llama_grammar_stacks advanced_stacks;
762+
// Look if stack is already in memory
763+
auto it = stacks_cache.find(stack);
764+
if (it != stacks_cache.end()) {
765+
advanced_stacks = it->second;
766+
} else {
767+
// Advance stacks with memoization
768+
llama_grammar_advance_stack_memo_impl(rules, stack, advanced_stacks, stacks_cache);
769+
stacks_cache.insert(make_pair(stack, advanced_stacks));
770+
}
771+
// Add the advanced stacks to new_stacks avoiding duplicates
772+
for (const auto & new_stack : advanced_stacks) {
773+
if (std::find(new_stacks.begin(), new_stacks.end(), new_stack) == new_stacks.end()) {
774+
new_stacks.emplace_back(new_stack);
775+
}
776+
}
777+
778+
}
779+
685780
// transforms a grammar pushdown stack into N possible stacks, all ending
686781
// at a character range (terminal element)
687782
static void llama_grammar_advance_stack(
@@ -822,15 +917,11 @@ llama_grammar_stacks & llama_grammar_get_stacks(struct llama_grammar * grammar)
822917
return grammar->stacks;
823918
}
824919

825-
void llama_grammar_accept(
826-
const llama_grammar_rules & rules,
827-
const llama_grammar_stacks & stacks,
828-
const uint32_t chr,
829-
llama_grammar_stacks & stacks_new) {
830-
stacks_new.clear();
831-
stacks_new.reserve(stacks.size());
920+
void llama_grammar_accept(struct llama_grammar * grammar, uint32_t chr) {
921+
llama_grammar_stacks stacks_new;
922+
stacks_new.reserve(grammar->stacks.size());
832923

833-
for (const auto & stack : stacks) {
924+
for (const auto & stack : grammar->stacks) {
834925
if (stack.empty()) {
835926
continue;
836927
}
@@ -844,9 +935,11 @@ void llama_grammar_accept(
844935
if (!llama_grammar_is_end_of_sequence(pos)) {
845936
new_stack.push_back(pos);
846937
}
847-
llama_grammar_advance_stack(rules, new_stack, stacks_new);
938+
llama_grammar_advance_stack_memo(grammar->rules, new_stack, stacks_new, grammar->stacks_cache);
848939
}
849940
}
941+
942+
grammar->stacks = std::move(stacks_new);
850943
}
851944

852945
llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
@@ -938,14 +1031,15 @@ struct llama_grammar * llama_grammar_init_impl(
9381031

9391032
// loop over alternates of start rule to build initial stacks
9401033
llama_grammar_stacks stacks;
1034+
llama_grammar_stacks_cache stacks_cache;
9411035
pos = vec_rules[start_rule_index].data();
9421036
do {
9431037
llama_grammar_stack stack;
9441038
if (!llama_grammar_is_end_of_sequence(pos)) {
9451039
// if alternate is nonempty, add to stack
9461040
stack.push_back(pos);
9471041
}
948-
llama_grammar_advance_stack(vec_rules, stack, stacks);
1042+
llama_grammar_advance_stack_memo(vec_rules, stack, stacks, stacks_cache);
9491043
while (!llama_grammar_is_end_of_sequence(pos)) {
9501044
// scan to end of alternate def
9511045
pos++;
@@ -961,7 +1055,7 @@ struct llama_grammar * llama_grammar_init_impl(
9611055
// Important: vec_rules has to be moved here, not copied, because stacks contains
9621056
// pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
9631057
// then the pointers would be invalidated when the local vec_rules goes out of scope.
964-
return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, };
1058+
return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), std::move(stacks_cache), {}, };
9651059
}
9661060

9671061
struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root) {
@@ -1016,14 +1110,15 @@ struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab,
10161110

10171111
// loop over alternates of start rule to build initial stacks
10181112
llama_grammar_stacks stacks;
1113+
llama_grammar_stacks_cache stacks_cache;
10191114
pos = vec_rules[start_rule_index].data();
10201115
do {
10211116
llama_grammar_stack stack;
10221117
if (!llama_grammar_is_end_of_sequence(pos)) {
10231118
// if alternate is nonempty, add to stack
10241119
stack.push_back(pos);
10251120
}
1026-
llama_grammar_advance_stack(vec_rules, stack, stacks);
1121+
llama_grammar_advance_stack_memo(vec_rules, stack, stacks, stacks_cache);
10271122
while (!llama_grammar_is_end_of_sequence(pos)) {
10281123
// scan to end of alternate def
10291124
pos++;
@@ -1039,7 +1134,7 @@ struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab,
10391134
// Important: vec_rules has to be moved here, not copied, because stacks contains
10401135
// pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
10411136
// then the pointers would be invalidated when the local vec_rules goes out of scope.
1042-
return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, };
1137+
return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), std::move(stacks_cache), {}, };
10431138
}
10441139

10451140
void llama_grammar_free_impl(struct llama_grammar * grammar) {
@@ -1051,15 +1146,21 @@ void llama_grammar_free_impl(struct llama_grammar * grammar) {
10511146
}
10521147

10531148
struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & grammar) {
1054-
llama_grammar * result = new llama_grammar { grammar.vocab, grammar.rules, grammar.stacks, grammar.partial_utf8, };
1149+
llama_grammar * result = new llama_grammar {
1150+
grammar.vocab,
1151+
grammar.rules,
1152+
grammar.stacks,
1153+
grammar.stacks_cache,
1154+
grammar.partial_utf8,
1155+
};
10551156

10561157
// redirect elements in stacks to point to new rules
10571158
for (size_t is = 0; is < result->stacks.size(); is++) {
10581159
for (size_t ie = 0; ie < result->stacks[is].size(); ie++) {
10591160
for (size_t ir0 = 0; ir0 < grammar.rules.size(); ir0++) {
10601161
for (size_t ir1 = 0; ir1 < grammar.rules[ir0].size(); ir1++) {
10611162
if (grammar.stacks[is][ie] == &grammar.rules[ir0][ir1]) {
1062-
result->stacks[is][ie] = &result->rules[ir0][ir1];
1163+
result->stacks[is][ie] = &result->rules[ir0][ir1];
10631164
}
10641165
}
10651166
}
@@ -1126,11 +1227,8 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
11261227
const auto decoded = decode_utf8(piece, grammar.partial_utf8);
11271228
const auto & code_points = decoded.first;
11281229

1129-
llama_grammar_stacks stacks_new;
1130-
11311230
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
1132-
llama_grammar_accept(grammar.rules, grammar.stacks, *it, stacks_new);
1133-
grammar.stacks = std::move(stacks_new);
1231+
llama_grammar_accept(&grammar, *it);
11341232
}
11351233

11361234
grammar.partial_utf8 = decoded.second;

src/llama-grammar.h

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include "llama-impl.h"
44

55
#include <map>
6+
#include <unordered_map>
67

78
struct llama_vocab;
89

@@ -58,18 +59,27 @@ using llama_grammar_rules = std::vector<llama_grammar_rule>;
5859
using llama_grammar_stacks = std::vector<llama_grammar_stack>;
5960
using llama_grammar_candidates = std::vector<llama_grammar_candidate>;
6061

62+
struct VectorPointerHash {
63+
size_t operator()(const llama_grammar_stack & v) const {
64+
size_t seed = v.size();
65+
for (const auto* ptr : v) {
66+
seed ^= std::hash<const llama_grammar_element*>()(ptr) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
67+
}
68+
return seed;
69+
}
70+
};
71+
72+
using llama_grammar_stacks_cache = std::unordered_map<llama_grammar_stack, llama_grammar_stacks, VectorPointerHash>;
73+
74+
// TODO: remove, needed for tests atm
6175
const llama_grammar_rules & llama_grammar_get_rules (const struct llama_grammar * grammar);
6276
llama_grammar_stacks & llama_grammar_get_stacks( struct llama_grammar * grammar);
6377

6478
// takes a set of possible pushdown stacks on a grammar, which are required to
6579
// be positioned at a character range (see `llama_grammar_advance_stack`), and
6680
// produces the N possible stacks if the given char is accepted at those
6781
// positions
68-
void llama_grammar_accept(
69-
const llama_grammar_rules & rules,
70-
const llama_grammar_stacks & stacks,
71-
uint32_t chr,
72-
llama_grammar_stacks & stacks_new);
82+
void llama_grammar_accept(struct llama_grammar * grammar, uint32_t chr);
7383

7484
std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_stack(
7585
const llama_grammar_rules & rules,
@@ -113,6 +123,9 @@ struct llama_grammar {
113123
const llama_grammar_rules rules; // TODO: shared ptr
114124
llama_grammar_stacks stacks;
115125

126+
// cache N possible stacks from a stack
127+
llama_grammar_stacks_cache stacks_cache;
128+
116129
// buffer for partially generated UTF-8 sequence from accepted tokens
117130
llama_partial_utf8 partial_utf8;
118131
};

0 commit comments

Comments
 (0)
0