Commit 5af9138 · wbruna/llama.cpp

Improve GBNF performance by attempting culled grammar search first (ggml-org#1597)
* cull tokens with top_3k first before running grammar, fall back to unculled if none found
* fix errors
* fix improvement and test against concedo's GBNF
* revert non-culling changes
1 parent 1cbe716 commit 5af9138
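
In outline, the change caps the candidate list at the 3000 highest-logit tokens before the expensive grammar pass runs, and falls back to the full pre-cull list only when the grammar rejects every culled candidate. A minimal standalone sketch of that strategy follows; the names below (Candidate, grammar_ok, keep_top_k, apply_grammar) are illustrative stand-ins, not identifiers from the codebase:

#include <algorithm>
#include <cstdio>
#include <vector>

struct Candidate { int id; float logit; };

// stand-in for the grammar check; here only even token ids are "valid"
static bool grammar_ok(const Candidate & c) { return c.id % 2 == 0; }

// keep only the k highest-logit candidates (cheap prefilter)
static void keep_top_k(std::vector<Candidate> & c, size_t k) {
    if (c.size() > k) {
        std::partial_sort(c.begin(), c.begin() + k, c.end(),
            [](const Candidate & a, const Candidate & b) { return a.logit > b.logit; });
        c.resize(k);
    }
}

// drop every candidate the grammar rejects (expensive pass)
static void apply_grammar(std::vector<Candidate> & c) {
    c.erase(std::remove_if(c.begin(), c.end(),
        [](const Candidate & tk) { return !grammar_ok(tk); }), c.end());
}

int main() {
    std::vector<Candidate> full = {{1, 2.0f}, {2, 1.5f}, {3, 1.0f}}; // whole vocabulary
    std::vector<Candidate> culled = full;

    keep_top_k(culled, 3000);   // cull first, so the grammar sees <= 3000 tokens
    apply_grammar(culled);

    if (culled.empty()) {       // grammar only matches rare tokens: retry pre-cull
        culled = full;
        apply_grammar(culled);
        keep_top_k(culled, 3000);
    }

    std::printf("%zu candidates survive\n", culled.size());
    return 0;
}

The fallback preserves correctness: a grammar that only matches tokens outside the top 3000 still finds them, while the common case runs the grammar over a small candidate set instead of the whole vocabulary.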

File tree

1 file changed: +27 −7 lines

gpttype_adapter.cpp

Lines changed: 27 additions & 7 deletions
@@ -1593,7 +1593,12 @@ void sample_grammar(FileFormat file_format, int32_t n_vocab, llama_token_data_ar
     for (const auto & reject : rejects) {
         candidates->data[reject.index].logit = -INFINITY;
     }
-
+
+    auto first = candidates->data;
+    auto last = first + candidates->size;
+    last = std::remove_if(first, last,
+        [&](const llama_token_data & tk){ return tk.logit == -INFINITY; });
+    candidates->size = last - first;
 }
 
 void sample_guidance(struct llama_context * ctx, struct llama_context * guidance_ctx, int n_vocab, float scale)
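
The block added above is the classic remove-erase compaction: std::remove_if moves the surviving entries to the front and returns the new logical end, and trimming candidates->size afterwards means later samplers never rescan rejected tokens. A tiny standalone illustration (token_data here is a made-up stand-in for llama_token_data):

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

struct token_data { int id; float logit; };   // stand-in for llama_token_data

int main() {
    std::vector<token_data> candidates = {
        {0, 1.5f}, {1, -INFINITY}, {2, 0.25f}, {3, -INFINITY}
    };

    // Move surviving tokens to the front, then trim the container,
    // mirroring the compaction added to sample_grammar().
    auto last = std::remove_if(candidates.begin(), candidates.end(),
        [](const token_data & tk) { return tk.logit == -INFINITY; });
    candidates.erase(last, candidates.end());

    std::printf("%zu candidates remain\n", candidates.size());   // prints 2
    return 0;
}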
@@ -1643,15 +1648,30 @@ const std::vector<samplers> & sampler_order, llama_grammar * grammar, float dyna
 
     llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
 
-    if (grammar != nullptr) {
-        sample_grammar(file_format, n_vocab, &candidates_p, grammar);
-    }
-
     //dry always first as logits cannot be resorted
     sample_dry(n_ctx, dry_penalty_last_n, dry_multiplier, dry_base, dry_allowed_length, dry_sequence_breakers, &candidates_p);
-
+
     //prefilter to top 3k tokens for improved speed
+    bool use_grammar = grammar != nullptr;
+    size_t n_pre_cull = candidates_p.size;
+
     sample_top_k(&candidates_p, 3000);
+
+    if (use_grammar) {
+
+        (debugmode == 1 && printf("\nGrammar sampling %zu candidates.\n", candidates_p.size));
+        sample_grammar(file_format, n_vocab, &candidates_p, grammar);
+        (debugmode == 1 && printf("\nGrammar returned %zu candidates.\n", candidates_p.size));
+
+        // if top_k 3000 doesn't contain a valid candidate for this grammar, try again pre-cull
+        if (candidates_p.size <= 0) {
+            candidates_p.size = n_pre_cull;
+            (debugmode == 1 && printf("\nRe-sampling grammar with %zu pre-cull tokens.\n", candidates_p.size));
+            sample_grammar(file_format, n_vocab, &candidates_p, grammar);
+            (debugmode == 1 && printf("\nGrammar returned %zu candidates.\n", candidates_p.size));
+            sample_top_k(&candidates_p, 3000);
+        }
+    }
 
     if (mirostat == 1 || mirostat == 2)
     {
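
A note on the logging style used above: (debugmode == 1 && printf(...)) relies on the short-circuit evaluation of &&, so the printf call only executes when debugmode is 1. A standalone illustration (debugmode is assumed to be a global flag, as in this file):

#include <cstdio>

static int debugmode = 1;   // assumed global flag, as in gpttype_adapter.cpp

int main() {
    size_t n = 42;
    // printf only runs when the left-hand operand is true; the outer
    // parentheses make the whole expression a plain statement.
    (debugmode == 1 && printf("\nGrammar sampling %zu candidates.\n", n));
    return 0;
}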
@@ -1745,7 +1765,6 @@ static void grammar_accept_token(FileFormat file_format, int32_t n_vocab, struct
     const auto decoded = decode_utf8(piece.c_str(), grammar->partial_utf8);
     const auto & code_points = decoded.first;
     for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
-        auto prev_stacks = grammar->stacks;
         llama_grammar_accept(grammar, *it);
     }
     grammar->partial_utf8 = decoded.second;
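
Side note on this hunk: the deleted line copied grammar->stacks on every code point, and nothing in the loop reads prev_stacks afterwards, so this appears to be a dead per-iteration copy removed for speed.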
@@ -3941,6 +3960,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
     }
 
     if (grammar != nullptr) {
+        (debugmode == 1 && printf("\nGrammar attempting to accept token...\n"));
         grammar_accept_token(file_format, n_vocab, grammar, id);
     }
 
0 commit comments
