@@ -1593,7 +1593,12 @@ void sample_grammar(FileFormat file_format, int32_t n_vocab, llama_token_data_ar
     for (const auto & reject : rejects) {
         candidates->data[reject.index].logit = -INFINITY;
     }
-
+
+    auto first = candidates->data;
+    auto last = first + candidates->size;
+    last = std::remove_if(first, last,
+        [&](const llama_token_data & tk){ return tk.logit == -INFINITY; });
+    candidates->size = last - first;
 }

 void sample_guidance(struct llama_context * ctx, struct llama_context * guidance_ctx, int n_vocab, float scale)
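
Note on the new lines at 1597-1601: this is the standard C++ erase-remove idiom. Instead of leaving rejected tokens sitting in the candidate array with -INFINITY logits, sample_grammar now compacts the array so downstream samplers only ever see tokens the grammar still allows. A minimal self-contained sketch of the same idiom, with struct shapes matching llama.h (the function name cull_rejected is made up for illustration):

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>

typedef int32_t llama_token;
struct llama_token_data       { llama_token id; float logit; float p; };
struct llama_token_data_array { llama_token_data * data; size_t size; bool sorted; };

// Drop every candidate whose logit the grammar forced to -INFINITY.
void cull_rejected(llama_token_data_array * candidates) {
    auto first = candidates->data;
    auto last  = first + candidates->size;
    // remove_if shifts the surviving elements to the front and returns the
    // new logical end; the buffer itself is untouched, no memory is freed.
    last = std::remove_if(first, last,
        [](const llama_token_data & tk) { return tk.logit == -INFINITY; });
    candidates->size = last - first;
}

Since std::remove_if only shuffles elements forward, shrinking the size field is all the "erase" that is needed here.
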
@@ -1643,15 +1648,30 @@ const std::vector<samplers> & sampler_order, llama_grammar * grammar, float dyna

     llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };

-    if (grammar != nullptr) {
-        sample_grammar(file_format, n_vocab, &candidates_p, grammar);
-    }
-
     // dry always first as logits cannot be resorted
     sample_dry(n_ctx, dry_penalty_last_n, dry_multiplier, dry_base, dry_allowed_length, dry_sequence_breakers, &candidates_p);
-
+
     // prefilter to top 3k tokens for improved speed
+    bool use_grammar = grammar != nullptr;
+    size_t n_pre_cull = candidates_p.size;
+
     sample_top_k(&candidates_p, 3000);
+
+    if (use_grammar) {
+
+        (debugmode == 1 && printf("\nGrammar sampling %zu candidates.\n", candidates_p.size));
+        sample_grammar(file_format, n_vocab, &candidates_p, grammar);
+        (debugmode == 1 && printf("\nGrammar returned %zu candidates.\n", candidates_p.size));
+
+        // if top_k 3000 doesn't contain a valid candidate for this grammar, try again pre-cull
+        if (candidates_p.size <= 0) {
+            candidates_p.size = n_pre_cull;
+            (debugmode == 1 && printf("\nRe-sampling grammar with %zu pre-cull tokens.\n", candidates_p.size));
+            sample_grammar(file_format, n_vocab, &candidates_p, grammar);
+            (debugmode == 1 && printf("\nGrammar returned %zu candidates.\n", candidates_p.size));
+            sample_top_k(&candidates_p, 3000);
+        }
+    }

     if (mirostat == 1 || mirostat == 2)
     {
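
Two things make this reordering work. First, grammar checking is the expensive step here (each candidate token has to be decoded and matched against the grammar stacks), so running it after sample_top_k(&candidates_p, 3000) bounds the work at 3000 tokens instead of the full vocabulary. Second, the pre-cull fallback leans on an implementation detail worth spelling out: top-k culls by sorting the buffer in place and shrinking the size field only, so assigning candidates_p.size = n_pre_cull makes the culled tail visible again. A toy demonstration of that assumption (toy_array and toy_top_k are stand-ins, not the real koboldcpp functions):

#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <vector>

struct toy_array { float * data; size_t size; };

// Stand-in for sample_top_k: partially sorts in place by value, then
// shrinks the logical size; never reallocates or clears the tail.
static void toy_top_k(toy_array * a, size_t k) {
    std::partial_sort(a->data, a->data + std::min(k, a->size), a->data + a->size,
                      [](float x, float y) { return x > y; });
    a->size = std::min(k, a->size);
}

int main() {
    std::vector<float> logits = {0.1f, 2.0f, -1.0f, 0.7f, 1.5f};
    toy_array a = { logits.data(), logits.size() };
    size_t n_pre_cull = a.size;
    toy_top_k(&a, 2);          // a.size == 2, but the buffer still holds 5 entries
    a.size = n_pre_cull;       // the culled tail becomes visible again
    printf("%zu candidates restored\n", a.size);   // prints 5
    return 0;
}

If top-k ever started reallocating or clearing the truncated tail, the fallback would silently read garbage, so this invariant is worth documenting next to the real code.
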
@@ -1745,7 +1765,6 @@ static void grammar_accept_token(FileFormat file_format, int32_t n_vocab, struct
     const auto decoded = decode_utf8(piece.c_str(), grammar->partial_utf8);
     const auto & code_points = decoded.first;
     for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
-        auto prev_stacks = grammar->stacks;
         llama_grammar_accept(grammar, *it);
     }
     grammar->partial_utf8 = decoded.second;
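
The line deleted at 1748 appears to be dead code: grammar->stacks is a vector of element-pointer vectors in llama.cpp's grammar implementation, so prev_stacks was a full deep copy taken on every decoded code point, and nothing read it after llama_grammar_accept advanced the stacks in place. A toy sketch of the cost being removed (stand-in types, not the real grammar structs):

#include <vector>

// Stand-in for the grammar stack type; the real entries are
// const llama_grammar_element pointers.
using stacks_t = std::vector<std::vector<const char *>>;

void advance_in_place(stacks_t & stacks) { /* mutates stacks in place */ }

void accept_code_points(stacks_t & stacks, int n_code_points) {
    for (int i = 0; i < n_code_points; ++i) {
        // Before this change, each iteration started with the equivalent of:
        //   stacks_t prev_stacks = stacks;   // deep copy of every inner
        //                                    // vector, never read again
        advance_in_place(stacks);
    }
}
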
@@ -3941,6 +3960,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
     }

     if (grammar != nullptr) {
+        (debugmode == 1 && printf("\nGrammar attempting to accept token...\n"));
         grammar_accept_token(file_format, n_vocab, grammar, id);
     }

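
The debug lines added throughout this commit, including the one above, use short-circuit && as a one-line conditional: printf only evaluates when debugmode == 1. A minimal equivalence check:

#include <cstdio>

int debugmode = 1;   // stand-in for koboldcpp's debug flag

int main() {
    // The form used in the diff: printf runs only when the left operand is
    // true, and the surrounding parentheses make the expression a statement.
    (debugmode == 1 && printf("\nGrammar attempting to accept token...\n"));

    // Equivalent spelled out as an if-statement:
    if (debugmode == 1) {
        printf("\nGrammar attempting to accept token...\n");
    }
    return 0;
}
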