tokenizer : special token handling by staviq · Pull Request #3538 · ggml-org/llama.cpp
tokenizer : special token handling #3538

Merged
merged 11 commits on Oct 17, 2023
formatting, remove copying iterator on delete
staviq committed Oct 11, 2023
commit f7b1205a515ca24cd8af1ecec94b697f275400db
llama.cpp: 112 changes (23 additions & 89 deletions)
@@ -2268,11 +2268,11 @@ static void llm_load_vocab(
             const auto & id = t.second;
 
             // Count all non-normal tokens in the vocab while iterating
-            if( vocab.id_to_token[id].type != LLAMA_TOKEN_TYPE_NORMAL )
+            if (vocab.id_to_token[id].type != LLAMA_TOKEN_TYPE_NORMAL)
                 special_tokens_count_by_type++;
 
             // Skip single character tokens
-            if( token.length() > 1 )
+            if (token.length() > 1)
             {
                 bool is_tokenizable = false;
 
@@ -2284,17 +2284,16 @@ static void llm_load_vocab(
                     const auto right = token.substr(i);
 
                     // check if we didnt partition in the middle of a utf sequence
-                    auto utf = utf8_len( left.at( left.length() -1 ) );
+                    auto utf = utf8_len(left.at(left.length() - 1));
 
-                    if( utf == 1 )
+                    if (utf == 1)
                     {
                         if (vocab.token_to_id.find( left ) != vocab.token_to_id.end() &&
                             vocab.token_to_id.find( right ) != vocab.token_to_id.end() )
                         {
                             is_tokenizable = true;
                             break;
                         }
-
                         i++;
                     }
                     else
@@ -2314,21 +2313,20 @@ static void llm_load_vocab(
                 for (unsigned i = 0; i < token.length();)
                 {
                     utf8_str_len++;
-                    i += utf8_len( token.at(i) );
+                    i += utf8_len(token.at(i));
                 }
 
                 // And skip the ones which are one character
                 if (utf8_str_len > 1)
                 {
                     // At this point what we have left are special tokens only
-
                     vocab.special_tokens_cache[token] = id;
 
                     // Count manually found special tokens
                     special_tokens_count_from_verification ++;
 
                     // If this manually found special token is not marked as such, flag a mismatch
-                    if( vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_NORMAL )
+                    if (vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_NORMAL)
                         special_tokens_definition_mismatch = true;
                 }
             }
@@ -2337,7 +2335,7 @@ static void llm_load_vocab(
 
         if( special_tokens_definition_mismatch || special_tokens_count_from_verification != special_tokens_count_by_type )
        {
-            fprintf(stderr, "%s: WARNING: Mismatch in special tokens definition ( %u/%zu vs %u/%zu ).\n",
+            fprintf(stderr, "warning: %s: Mismatch in special tokens definition ( %u/%zu vs %u/%zu ).\n",
                 __func__,
                 special_tokens_count_from_verification, vocab.id_to_token.size(),
                 special_tokens_count_by_type, vocab.id_to_token.size()
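
Note: the hunks above belong to the vocabulary self-check this PR adds to llm_load_vocab. Every multi-character token is test-split at each UTF-8 boundary; if no split produces two halves that are both vocab entries, normal tokenization can never produce the token, so it is cached as special, and a warning is printed when that verdict disagrees with the token-type metadata. A minimal standalone sketch of the heuristic (utf8_len mirrors llama.cpp's lookup-table helper; the map type and loop shape are illustrative simplifications):

#include <cstddef>
#include <string>
#include <unordered_map>

// number of bytes in a UTF-8 sequence, determined from its first byte
static size_t utf8_len(char src) {
    const size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
    return lookup[static_cast<unsigned char>(src) >> 4];
}

// a token is "tokenizable" if some split at a UTF-8 boundary yields two
// halves that are both themselves vocab entries; callers only pass tokens
// with length > 1, matching the hunks above
static bool is_tokenizable(const std::string & token,
                           const std::unordered_map<std::string, int> & token_to_id) {
    for (size_t i = utf8_len(token[0]); i < token.length(); i += utf8_len(token[i])) {
        if (token_to_id.count(token.substr(0, i)) &&
            token_to_id.count(token.substr(i))) {
            return true;
        }
    }
    return false;
}

A marker such as <|im_start|> typically fails every split, since its pieces are not vocab entries themselves, so it lands in special_tokens_cache even when the model metadata forgot to mark it as special.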
@@ -6608,89 +6606,71 @@ struct fragment_buffer_variant{
     const uint64_t length;
 };
 
-#define PRETOKENIZERDEBUG
+// #define PRETOKENIZERDEBUG
 
 static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<fragment_buffer_variant> & buffer)
 {
     // for each special token
-    for( const auto & st: vocab.special_tokens_cache )
+    for (const auto & st: vocab.special_tokens_cache)
     {
         const auto & special_token = st.first;
         const auto & special_id = st.second;
 
         // for each text fragment
         //for (auto & fragment: buffer)
         std::forward_list<fragment_buffer_variant>::iterator it = buffer.begin();
         while (it != buffer.end())
         {
             auto & fragment = (*it);
             // if a fragment is text ( not yet processed )
-            if( fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT )
+            if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT)
             {
                 auto * raw_text = &(fragment.raw_text);
                 auto raw_text_base_offset = fragment.offset;
                 auto raw_text_base_length = fragment.length;
 
                 // loop over the text
-                while(true)
+                while (true)
                 {
                     // find the first occurence of a given special token in this fragment
                     // passing offset argument only limit the "search area" but match coordinates
                     // are still relative to the source full raw_text
                     auto match = raw_text->find( special_token, raw_text_base_offset );
-
                     // no occurences found, stop processing this fragment for a given special token
-                    if (match == std::string::npos)
-                    {
-                        break;
-                    }
+                    if (match == std::string::npos) break;
 
                     // check if match is within bounds of offset <-> length
-                    if( match + special_token.length() > raw_text_base_offset + raw_text_base_length )
-                    {
-                        // match is out of bounds
-                        break;
-                    }
+                    if (match + special_token.length() > raw_text_base_offset + raw_text_base_length) break;
 
 #ifdef PRETOKENIZERDEBUG
                     fprintf(stderr,"FF: (%ld %ld %ld) '%s'\n", raw_text->length(), raw_text_base_offset, raw_text_base_length, raw_text->substr(raw_text_base_offset, raw_text_base_length).c_str());
 #endif
-
-                    auto source = std::distance( buffer.begin(), it );
+                    auto source = std::distance(buffer.begin(), it);
 
                     // if match is further than base offset
                     // then we have some text to the left of it
-                    if( match > raw_text_base_offset )
+                    if (match > raw_text_base_offset)
                     {
                         // left
                         //buffer.emplace_after(it, raw_text->substr(0, match));
                         const int64_t left_reminder_offset = raw_text_base_offset + 0;
                         const int64_t left_reminder_length = match - raw_text_base_offset;
                         buffer.emplace_after(it, (*raw_text), left_reminder_offset, left_reminder_length);
 
 #ifdef PRETOKENIZERDEBUG
                         fprintf(stderr,"FL: (%ld %ld) '%s'\n", left_reminder_offset, left_reminder_length, raw_text->substr(left_reminder_offset, left_reminder_length).c_str());
 #endif
-
                         it++;
                     }
-
                     // special token
                     buffer.emplace_after(it, special_id);
                     it++;
 
-
                     // right
                     if (match + special_token.length() < raw_text_base_offset + raw_text_base_length)
                     {
-                        /*
-                        | |
-                        -------------------------------------------------------------------------
-                        . |ttttt| |
-                        */
-                        //buffer.emplace_after(it, raw_text->substr(match + special_token.length()));
                         const int64_t right_reminder_offset = match + special_token.length();
-                        const int64_t right_reminder_length = raw_text_base_length - ( ( match - raw_text_base_offset ) + special_token.length() );
+                        const int64_t right_reminder_length = raw_text_base_length - ((match - raw_text_base_offset) + special_token.length());
                         buffer.emplace_after(it, (*raw_text), right_reminder_offset, right_reminder_length);
 
 #ifdef PRETOKENIZERDEBUG
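
The partitioning above, reduced to its core: scan a raw-text fragment for one special token, emit any text to the left as a new raw-text fragment, emit the match itself as a ready token id, and rescan the remainder on the right. The real code works in place on a std::forward_list of (offset, length) views into the shared source string to avoid copying; the sketch below uses illustrative names and returns substring copies for clarity:

#include <cstddef>
#include <string>
#include <vector>

struct Fragment {
    bool        is_token;  // true: carries a resolved token id
    int         token_id;  // valid when is_token is true
    std::string text;      // valid when is_token is false
};

static std::vector<Fragment> split_on_special(const std::string & text,
                                              const std::string & special,
                                              int special_id) {
    std::vector<Fragment> out;
    size_t base = 0;
    while (true) {
        const size_t match = text.find(special, base);
        if (match == std::string::npos) break;

        // text to the left of the match stays raw text
        if (match > base) {
            out.push_back({false, 0, text.substr(base, match - base)});
        }
        // the match itself becomes a ready token id
        out.push_back({true, special_id, std::string()});

        // repeat for the right side
        base = match + special.length();
    }
    // trailing text after the last match, if any
    if (base < text.length()) {
        out.push_back({false, 0, text.substr(base)});
    }
    return out;
}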
@@ -6699,66 +6679,25 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<
 
                         it++;
 
-                        if (source == 0)
-                        {
-                            // TODO? It might not be needed to store/restore the iterator like this
-                            // but this gives me the peace of mind I'm not causing some
-                            // accidental undefined behaviour.
-                            auto it_backup = std::distance( buffer.begin(), it );
-
-                            buffer.erase_after(buffer.before_begin());
-
-                            it = std::next( buffer.begin(), it_backup-1 );
-                        }
-                        else
-                        {
-                            auto it_backup = std::distance( buffer.begin(), it );
-
-                            //auto prev = std::prev( buffer.begin(), -(source-1) );
-                            auto prev = std::next( buffer.begin(), (source-1) );
-                            buffer.erase_after(prev);
-
-                            it = std::next( buffer.begin(), it_backup-1 );
-                        }
-                        //it = std::prev( it, 1 );
+                        if (source == 0) buffer.erase_after(buffer.before_begin());
+                        else buffer.erase_after(std::next(buffer.begin(), (source-1)));
 
                         // repeat for the right side
-                        raw_text_base_offset = right_reminder_offset; //match + special_token.length();
-                        raw_text_base_length = right_reminder_length; //right_reminder_length - ( ( match + special_token.length() ) - raw_text_base_offset );
-                        //raw_text = &((*it).raw_text);
-
+                        raw_text_base_offset = right_reminder_offset;
+                        raw_text_base_length = right_reminder_length;
 #ifdef PRETOKENIZERDEBUG
                         fprintf(stderr,"RR: (%ld %ld) '%s'\n", raw_text_base_offset, raw_text_base_length, raw_text->substr(raw_text_base_offset, raw_text_base_length).c_str());
 #endif
-
                     }
                     else
                     {
-                        if (source == 0)
-                        {
-                            auto it_backup = std::distance( buffer.begin(), it );
-
-                            buffer.erase_after(buffer.before_begin());
-
-                            it = std::next( buffer.begin(), it_backup-1 );
-                        }
-                        else
-                        {
-                            auto it_backup = std::distance( buffer.begin(), it );
-
-                            //auto prev = std::prev( buffer.begin(), -(source) );
-                            auto prev = std::next( buffer.begin(), (source-1) );
-                            buffer.erase_after(prev);
-
-                            it = std::next( buffer.begin(), it_backup-1 );
-                        }
-                        //it = std::prev( it, 1 );
-
+                        if (source == 0) buffer.erase_after(buffer.before_begin());
+                        else buffer.erase_after(std::next(buffer.begin(), (source-1)));
                         break;
                     }
                 }
             }
 
             it++;
         }
     }
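
On the commit title's "remove copying iterator on delete": the deleted it_backup / std::distance / std::next round trips guarded against iterator invalidation when erasing the consumed source fragment. That guard is unnecessary, because std::forward_list::erase_after only invalidates iterators to the element actually erased, and at both erase sites 'it' has already advanced past the source fragment. A self-contained check of that property (plain standard-library behavior, not llama.cpp code):

#include <cassert>
#include <forward_list>
#include <iterator>

int main() {
    std::forward_list<int> list = {1, 2, 3, 4};

    auto it = std::next(list.begin(), 2);  // points at 3

    // erase the head (the element after before_begin());
    // 'it' is unaffected because its node was not erased
    list.erase_after(list.before_begin());
    assert(*it == 3);

    // erase the element after 'it' (the 4); 'it' itself stays valid
    list.erase_after(it);
    assert(*it == 3 && std::next(it) == list.end());
    return 0;
}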
@@ -6781,12 +6720,9 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
     }
 
     std::forward_list<fragment_buffer_variant> fragment_buffer;
 
     fragment_buffer.emplace_front( raw_text, 0, raw_text.length() );
-
-    if (special) {
-        tokenizer_st_partition( vocab, fragment_buffer );
-    }
+    if (special) tokenizer_st_partition( vocab, fragment_buffer );
 
     switch (vocab.type) {
         case LLAMA_VOCAB_TYPE_SPM:
@@ -6806,7 +6742,6 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
 #ifdef PRETOKENIZERDEBUG
                     fprintf(stderr,"TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
 #endif
-
                     llm_tokenizer_spm tokenizer(vocab);
                     llama_escape_whitespace(raw_text);
                     tokenizer.tokenize(raw_text, output);
@@ -6828,7 +6763,6 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
 #ifdef PRETOKENIZERDEBUG
                     fprintf(stderr,"TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
 #endif
-
                     llm_tokenizer_bpe tokenizer(vocab);
                     tokenizer.tokenize(raw_text, output);
                 }
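
Downstream of these last hunks, each vocab branch walks the fragment buffer: raw-text fragments go through the regular SPM/BPE tokenizer, while token fragments (the special ids emplaced by tokenizer_st_partition) bypass it entirely. Continuing the earlier sketch, reusing its Fragment struct, with a stand-in tokenize_text for llm_tokenizer_spm / llm_tokenizer_bpe:

#include <string>
#include <vector>

// stand-in for the regular tokenizer: appends the ids for 'text' to 'output'
static void tokenize_text(const std::string & text, std::vector<int> & output);

static std::vector<int> tokenize_with_specials(const std::vector<Fragment> & fragments) {
    std::vector<int> output;
    for (const auto & fragment : fragments) {
        if (fragment.is_token) {
            output.push_back(fragment.token_id);   // resolved special id, emitted as-is
        } else {
            tokenize_text(fragment.text, output);  // regular tokenizer path
        }
    }
    return output;
}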