Fix unicode in grammars (fixes #2501) by ejones · Pull Request #2553 · ggml-org/llama.cpp · GitHub

Fix unicode in grammars (fixes #2501) #2553


Merged: 4 commits, Aug 17, 2023
Changes from 1 commit
Fix unicode in grammars (fixes #2501)
ejones committed Aug 8, 2023
commit 961f2ab73f99bd78e296585bc908d1eab7dc16ad
llama.cpp: 149 changes (125 additions, 24 deletions)
@@ -2082,37 +2082,79 @@ static std::vector<llama_vocab::id> llama_tokenize(const llama_vocab & vocab, co
// grammar - internal
//

struct llama_partial_utf8 {
uint32_t value;
int n_remain;
};

struct llama_grammar {
const std::vector<std::vector<llama_grammar_element>> rules;
std::vector<std::vector<const llama_grammar_element *>> stacks;
llama_partial_utf8 partial_utf8;
};

struct llama_grammar_candidate {
size_t index;
const uint32_t * code_points;
size_t index;
const uint32_t * code_points;
llama_partial_utf8 partial_utf8;
};

// NOTE: assumes valid utf8 (but checks for overrun)
// Decodes a UTF-8 string which may end in an incomplete sequence. Otherwise, assumes valid UTF-8.
// adds a terminating 0 for use as pointer
std::vector<uint32_t> decode_utf8(const char * src) {
static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
const char * src,
llama_partial_utf8 partial_start) {
static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 };
const char * pos = src;
std::vector<uint32_t> code_points;
uint32_t value = partial_start.value;
int n_remain = partial_start.n_remain;

// continue previous decode, if applicable
while (*pos != 0 && n_remain > 0) {
uint8_t next_byte = static_cast<uint8_t>(*pos);
if ((next_byte >> 6) != 2) {
// REVIEW invalid sequence, abort
code_points.push_back(0);
return std::make_pair(std::move(code_points), llama_partial_utf8{ 0, -1 });
}
value = (value << 6) + (next_byte & 0x3F);
++pos;
--n_remain;
}

if (partial_start.n_remain > 0 && n_remain == 0) {
code_points.push_back(value);
}

// decode any subsequent utf-8 sequences, which may end in an incomplete one
while (*pos != 0) {
uint8_t first_byte = static_cast<uint8_t>(*pos);
uint8_t highbits = first_byte >> 4;
int len = lookup[highbits];
uint8_t mask = (1 << (8 - len)) - 1;
uint32_t value = first_byte & mask;
const char * end = pos + len; // may overrun!
n_remain = lookup[highbits] - 1;

if (n_remain < 0) {
// REVIEW invalid sequence, abort
code_points.clear();
code_points.push_back(0);
return std::make_pair(std::move(code_points), llama_partial_utf8{ 0, n_remain });
}

uint8_t mask = (1 << (7 - n_remain)) - 1;
value = first_byte & mask;
++pos;
for ( ; pos < end && *pos != 0; ++pos) {
while (*pos != 0 && n_remain > 0) {
value = (value << 6) + (static_cast<uint8_t>(*pos) & 0x3F);
++pos;
--n_remain;
}
if (n_remain == 0) {
code_points.push_back(value);
}
code_points.push_back(value);
}
code_points.push_back(0);
return code_points;

return std::make_pair(std::move(code_points), llama_partial_utf8{ value, n_remain });
}
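
To make the incremental decoding concrete, here is a minimal self-contained sketch (not the PR's code: it folds the two loops above into one and assumes well-formed input) showing a two-byte character split across token boundaries being completed through the carried partial state:

```cpp
// Minimal sketch: carrying an unfinished UTF-8 sequence across two strings.
// Not the PR's code; it assumes well-formed input and folds the two loops into one.
#include <cstdint>
#include <cstdio>
#include <utility>
#include <vector>

struct partial_utf8 { uint32_t value; int n_remain; };

static std::pair<std::vector<uint32_t>, partial_utf8> decode(const char * src, partial_utf8 p) {
    static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 };
    std::vector<uint32_t> cps;
    uint32_t value    = p.value;
    int      n_remain = p.n_remain;
    for (const char * pos = src; *pos != 0; ++pos) {
        uint8_t b = static_cast<uint8_t>(*pos);
        if (n_remain > 0) {
            // continuation byte of a sequence started earlier, possibly in a previous string
            value = (value << 6) + (b & 0x3F);
            if (--n_remain == 0) { cps.push_back(value); }
        } else {
            // first byte of a new sequence; the high nibble gives the total length
            n_remain = lookup[b >> 4] - 1;
            value    = b & ((1 << (7 - n_remain)) - 1);
            if (n_remain == 0) { cps.push_back(value); }
        }
    }
    return std::make_pair(std::move(cps), partial_utf8{ value, n_remain });
}

int main() {
    // U+00E9 'é' is 0xC3 0xA9; pretend the two bytes arrive in different tokens.
    auto first  = decode("\xC3", { 0, 0 });     // no code point yet, one byte still owed
    auto second = decode("\xA9", first.second); // the continuation byte completes U+00E9
    printf("after token 1: %zu code points, n_remain = %d\n", first.first.size(), first.second.n_remain);
    printf("after token 2: U+%04X\n", (unsigned) second.first[0]);
    return 0;
}
```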

// returns true iff pos points to the end of one of the definitions of a rule
@@ -2149,6 +2191,52 @@ static std::pair<bool, const llama_grammar_element *> llama_grammar_match_char(
return std::make_pair(found == is_positive_char, pos);
}

static bool llama_grammar_match_partial_char(
const llama_grammar_element * pos,
const llama_partial_utf8 partial_utf8) {

bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR;
LLAMA_ASSERT(is_positive_char || pos->type == LLAMA_GRETYPE_CHAR_NOT);

uint32_t partial_value = partial_utf8.value;
int n_remain = partial_utf8.n_remain;

if (n_remain < 0 || (n_remain == 1 && partial_value < 2)) {
return false;
}

// range of possible code points this partial UTF-8 sequence could complete to
uint32_t low = partial_value << (n_remain * 6);
uint32_t high = low | ((1 << (n_remain * 6)) - 1);

if (low == 0) {
if (n_remain == 2) {
low = 1 << 11;
} else if (n_remain == 3) {
low = 1 << 16;
}
}

do {
if (pos[1].type == LLAMA_GRETYPE_CHAR_RNG_UPPER) {
// inclusive range, e.g. [a-z]
if (pos->value <= high && low <= pos[1].value) {
return is_positive_char;
}
pos += 2;
} else {
// exact char match, e.g. [a] or "a"
if (low <= pos->value && pos->value <= high) {
return is_positive_char;
}
pos += 1;
}
} while (pos->type == LLAMA_GRETYPE_CHAR_ALT);

return !is_positive_char;
}
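
As a worked example of the range computation (values chosen for illustration, not taken from the PR): a token ending in the lone byte 0xCE, the first byte of a two-byte sequence, leaves value = 0x0E and n_remain = 1, and the check reduces to an interval overlap test:

```cpp
// Completion range of the dangling byte 0xCE (value = 0x0E, n_remain = 1), illustrative only.
#include <cstdint>
#include <cstdio>

int main() {
    uint32_t value    = 0x0E; // 0xCE & 0x3F
    int      n_remain = 1;

    uint32_t low  = value << (n_remain * 6);            // 0x380: smallest possible completion
    uint32_t high = low | ((1u << (n_remain * 6)) - 1); // 0x3BF: largest possible completion

    // against a grammar rule covering U+03B1..U+03C9 (Greek lowercase alpha to omega):
    uint32_t rule_lo = 0x3B1, rule_hi = 0x3C9;
    bool may_match = rule_lo <= high && low <= rule_hi;
    printf("completions U+%04X..U+%04X, may match rule: %s\n",
           (unsigned) low, (unsigned) high, may_match ? "yes" : "no"); // yes
    return 0;
}
```

The low == 0 clamp above handles leading bytes 0xE0 and 0xF0, which contribute no payload bits: the smallest code points a valid 3- or 4-byte sequence can encode are U+0800 and U+10000, so anything lower would be an overlong encoding.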


// transforms a grammar pushdown stack into N possible stacks, all ending
// at a character range (terminal element)
static void llama_grammar_advance_stack(
@@ -2249,19 +2337,27 @@ static std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_
std::vector<llama_grammar_candidate> rejects;

if (stack.empty()) {
// accept nothing; EOS is handled elsewhere
rejects.insert(rejects.end(), candidates.begin(), candidates.end());
for (auto tok : candidates) {
if (*tok.code_points != 0 || tok.partial_utf8.n_remain != 0) {
rejects.push_back(tok);
}
}
return rejects;
}

const llama_grammar_element * stack_pos = stack.back();

std::vector<llama_grammar_candidate> next_candidates;
for (auto tok : candidates) {
if (llama_grammar_match_char(stack_pos, tok.code_points[0]).first) {
if (tok.code_points[1] != 0) {
next_candidates.push_back({ tok.index, tok.code_points + 1 });
if (*tok.code_points == 0) {
// reached end of full codepoints in token, reject iff it ended in a partial sequence
// that cannot satisfy this position in grammar
if (tok.partial_utf8.n_remain != 0 &&
!llama_grammar_match_partial_char(stack_pos, tok.partial_utf8)) {
rejects.push_back(tok);
}
} else if (llama_grammar_match_char(stack_pos, *tok.code_points).first) {
next_candidates.push_back({ tok.index, tok.code_points + 1, tok.partial_utf8 });
} else {
rejects.push_back(tok);
}
@@ -2279,7 +2375,7 @@ static std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_

auto next_rejects = llama_grammar_reject_candidates(rules, next_stacks, next_candidates);
for (auto tok : next_rejects) {
rejects.push_back({ tok.index, tok.code_points - 1 });
rejects.push_back({ tok.index, tok.code_points - 1, tok.partial_utf8 });
}

return rejects;
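
The rejection pass now distinguishes two new situations: with an empty stack, a candidate survives only if it contributes nothing at all (no complete code points and no pending partial bytes), and when a candidate's complete code points run out, llama_grammar_match_partial_char decides whether its trailing bytes could still fit the current grammar position. A compact restatement of the empty-stack predicate, with illustrative types rather than the PR's API:

```cpp
// Restates the empty-stack rule above with toy types: when no grammar stack remains, reject
// any candidate that still carries characters, whether complete code points or a trailing
// partial UTF-8 sequence. EOS itself is handled elsewhere.
#include <cstdint>
#include <cstdio>

struct toy_candidate {
    const uint32_t * code_points; // zero-terminated decoded code points
    int              n_remain;    // continuation bytes still owed by a trailing partial sequence
};

static bool rejected_on_empty_stack(const toy_candidate & c) {
    return *c.code_points != 0 || c.n_remain != 0;
}

int main() {
    const uint32_t none[] = { 0 };
    const uint32_t text[] = { 'h', 'i', 0 };
    printf("empty token:             %d\n", rejected_on_empty_stack({ none, 0 })); // 0: kept
    printf("token with text:         %d\n", rejected_on_empty_stack({ text, 0 })); // 1: rejected
    printf("token with partial tail: %d\n", rejected_on_empty_stack({ none, 1 })); // 1: rejected
    return 0;
}
```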
@@ -2344,7 +2440,7 @@ struct llama_grammar * llama_grammar_init(
}
} while (true);

return new llama_grammar{ std::move(vec_rules), std::move(stacks) };
return new llama_grammar{ std::move(vec_rules), std::move(stacks), {} };
}

void llama_grammar_free(struct llama_grammar * grammar) {
@@ -2650,8 +2746,8 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c

const llama_token eos = llama_token_eos();

std::vector<std::vector<uint32_t>> candidates_decoded;
std::vector<llama_grammar_candidate> candidates_grammar;
std::vector<std::pair<std::vector<uint32_t>, llama_partial_utf8>> candidates_decoded;
std::vector<llama_grammar_candidate> candidates_grammar;

for (size_t i = 0; i < candidates->size; ++i) {
const llama_token id = candidates->data[i].id;
@@ -2663,8 +2759,10 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c
} else if (*str == 0) {
candidates->data[i].logit = -INFINITY;
} else {
candidates_decoded.push_back(decode_utf8(str));
candidates_grammar.push_back({ i, candidates_decoded.back().data() });
candidates_decoded.push_back(decode_utf8(str, grammar->partial_utf8));
candidates_grammar.push_back({
i, candidates_decoded.back().first.data(), candidates_decoded.back().second
});
}
}
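
A usage-level sketch of the loop above, assuming decode_utf8, llama_partial_utf8 and llama_grammar_candidate from this diff are in scope; the token bytes and the starting state are invented:

```cpp
// Sketch only (not the PR's code): how one candidate token is decoded and packaged for the
// reject pass. "ab\xC3" stands in for a token whose bytes end in the middle of a character.
llama_partial_utf8 prior = { 0, 0 };          // state carried over from already accepted tokens
auto decoded = decode_utf8("ab\xC3", prior);  // code points { 'a', 'b', 0 }, tail { 0x03, 1 }

llama_grammar_candidate cand = {
    /* index        */ 0,                     // position of this token in candidates->data
    /* code_points  */ decoded.first.data(),  // complete code points, zero-terminated
    /* partial_utf8 */ decoded.second,        // the dangling byte 0xC3: one byte still owed
};
```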

@@ -2865,11 +2963,14 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar
}

const char * str = llama_token_to_str(ctx, token);

// Note terminating 0 in decoded string
auto code_points = decode_utf8(str);
const auto decoded = decode_utf8(str, grammar->partial_utf8);
const auto & code_points = decoded.first;
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it);
}
grammar->partial_utf8 = decoded.second;
LLAMA_ASSERT(!grammar->stacks.empty());

ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
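
To see the carried state end to end (a sketch that assumes decode_utf8 and llama_partial_utf8 from this diff are in scope; the token pieces are invented): an ellipsis U+2026, encoded as 0xE2 0x80 0xA6, split across two accepted tokens leaves one byte pending after the first call and completes on the second:

```cpp
// Sketch only (not the PR's code): two accepted tokens splitting U+2026 between them.
llama_partial_utf8 partial = { 0, 0 };

auto d1 = decode_utf8("\xE2\x80", partial);  // d1.first == { 0 }: nothing complete yet
partial = d1.second;                         // { value = 0x80, n_remain = 1 }: one byte still owed

auto d2 = decode_utf8("\xA6", partial);      // d2.first == { 0x2026, 0 }: the ellipsis is complete
partial = d2.second;                         // n_remain == 0: nothing pending anymore

// llama_grammar_accept therefore only ever advances the stacks on whole code points,
// never on a half-decoded character.
```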