Add Jinja template support by ochafik · Pull Request #11016 · ggml-org/llama.cpp · GitHub

Add Jinja template support #11016

Merged · 47 commits · Jan 21, 2025
Changes from 1 commit

Commits (47)
abd274a
Copy minja from https://github.com/google/minja/commit/58f0ca6dd74bcb…
Dec 30, 2024
e5113e8
Add --jinja and --chat-template-file flags
Dec 30, 2024
80138d9
Add missing <optional> include
Dec 30, 2024
06b5159
Avoid print in get_hf_chat_template.py
Dec 30, 2024
ce48584
No designated initializers yet
Dec 30, 2024
389d79b
Try and work around msvc++ non-macro max resolution quirk
Dec 30, 2024
238b968
Update test_chat_completion.py
Dec 30, 2024
cb72cf1
Merge remote-tracking branch 'origin/master' into jinja
Jan 13, 2025
78861a3
Wire LLM_KV_TOKENIZER_CHAT_TEMPLATE_N in llama_model_chat_template
Jan 13, 2025
1aac99a
Refactor test-chat-template
Jan 13, 2025
7c84ebc
Test templates w/ minja
Jan 13, 2025
18f257b
Fix deprecation
Jan 13, 2025
8dd4f33
Add --jinja to llama-run
Jan 13, 2025
c04c50e
Merge remote-tracking branch 'origin/master' into jinja
Jan 13, 2025
a6afb27
Update common_chat_format_example to use minja template wrapper
Jan 13, 2025
b4083e4
Test chat_template in e2e test
Jan 13, 2025
b7e2171
Update utils.py
Jan 13, 2025
a57bb94
Update test_chat_completion.py
Jan 13, 2025
4daae0b
Update run.cpp
Jan 13, 2025
1b3bb7e
Update arg.cpp
ochafik Jan 14, 2025
3ed670b
Merge remote-tracking branch 'origin/master' into jinja
Jan 14, 2025
b75d062
Refactor common_chat_* functions to accept minja template + use_jinja…
Jan 18, 2025
40db789
Merge remote-tracking branch 'origin/master' into jinja
Jan 18, 2025
81c0d43
Attempt to fix linkage of LLAMA_CHATML_TEMPLATE
Jan 18, 2025
d5fa351
Revert LLAMA_CHATML_TEMPLATE refactor
Jan 18, 2025
ee1e10e
Normalize newlines in test-chat-templates for windows tests
Jan 18, 2025
e63520f
Forward decl minja::chat_template to avoid eager json dep
Jan 18, 2025
33322e8
Flush stdout in chat template before potential crash
Jan 18, 2025
5074e6f
Fix copy elision warning
Jan 18, 2025
fc60802
Rm unused optional include
Jan 18, 2025
0e74c9d
Add missing optional include to server.cpp
Jan 18, 2025
e3c475c
Disable jinja test that has a cryptic windows failure
Jan 18, 2025
cc50356
minja: fix vigogne (https://github.com/google/minja/pull/22)
Jan 18, 2025
153e852
Apply suggestions from code review
ochafik Jan 20, 2025
db9dd0c
Finish suggested renamings
Jan 20, 2025
c9e8fdd
Move chat_templates inside server_context + remove mutex
Jan 20, 2025
8c84aef
Update --chat-template-file w/ recent change to --chat-template
Jan 20, 2025
154bfaa
Refactor chat template validation
Jan 20, 2025
099f983
Merge remote-tracking branch 'origin/master' into jinja
Jan 20, 2025
54a669e
Guard against missing eos/bos tokens (null token otherwise throws in …
Jan 20, 2025
8348c60
Warn against missing eos / bos tokens when jinja template references …
Jan 20, 2025
ee475d2
rename: common_chat_template[s]
Jan 20, 2025
8a7c89e
reinstate assert on chat_templates.template_default
Jan 20, 2025
8347da9
Update minja to https://github.com/google/minja/commit/b8437df626ac6c…
Jan 20, 2025
ff2cce5
Update minja to https://github.com/google/minja/pull/25
Jan 21, 2025
9d8ebd6
Update minja from https://github.com/google/minja/pull/27
Jan 21, 2025
cbb9b81
rm unused optional header
Jan 21, 2025
Wire LLM_KV_TOKENIZER_CHAT_TEMPLATE_N in llama_model_chat_template
ochafik committed Jan 13, 2025
commit 78861a3eb2f8583115cba378caad95b34c274b9c
16 changes: 2 additions & 14 deletions common/common.cpp
@@ -1822,17 +1822,6 @@ std::string common_chat_format_example(const struct llama_model * model,
     return common_chat_apply_template(model, tmpl, msgs, true);
 }
 
-static std::string _llama_model_meta_val_str(const struct llama_model * model, const char * key) {
-    int32_t tlen = llama_model_meta_val_str(model, key, nullptr, 0);
-    if (tlen > 0) {
-        std::vector<char> curr_tmpl_buf(tlen + 1, 0);
-        if (llama_model_meta_val_str(model, key, curr_tmpl_buf.data(), curr_tmpl_buf.size()) == tlen) {
-            return std::string(curr_tmpl_buf.data(), tlen);
-        }
-    }
-    return "";
-}
-
 llama_chat_templates llama_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override)
 {
     auto vocab = llama_model_get_vocab(model);
@@ -1841,9 +1830,8 @@ llama_chat_templates llama_chat_templates_from_model(const struct llama_model *
     std::string default_template_src = chat_template_override;
     std::string tool_use_template_src = chat_template_override;
     if (chat_template_override.empty()) {
-        // TODO:
-        default_template_src = _llama_model_meta_val_str(model, "tokenizer.chat_template");
-        tool_use_template_src = _llama_model_meta_val_str(model, "tokenizer.chat_template.tool_use");
+        default_template_src = llama_model_chat_template(model, /* name */ nullptr);
+        tool_use_template_src = llama_model_chat_template(model, /* name */ "tool_use");
     }
     if (default_template_src.empty() || default_template_src == "chatml") {
         if (!tool_use_template_src.empty()) {
4 changes: 2 additions & 2 deletions examples/run/run.cpp
@@ -713,11 +713,11 @@ static void add_message(const char * role, const std::string & text, LlamaData &
 // Function to apply the chat template and resize `formatted` if needed
 static int apply_chat_template(LlamaData & llama_data, const bool append) {
     int result = llama_chat_apply_template(
-        llama_model_chat_template(llama_data.model.get()), llama_data.messages.data(), llama_data.messages.size(), append,
+        llama_model_chat_template(llama_data.model.get(), /* name */ nullptr), llama_data.messages.data(), llama_data.messages.size(), append,
         append ? llama_data.fmtted.data() : nullptr, append ? llama_data.fmtted.size() : 0);
     if (append && result > static_cast<int>(llama_data.fmtted.size())) {
         llama_data.fmtted.resize(result);
-        result = llama_chat_apply_template(llama_model_chat_template(llama_data.model.get()), llama_data.messages.data(),
+        result = llama_chat_apply_template(llama_model_chat_template(llama_data.model.get(), /* name */ nullptr), llama_data.messages.data(),
                                            llama_data.messages.size(), append, llama_data.fmtted.data(),
                                            llama_data.fmtted.size());
     }
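For reference, the size-then-retry pattern in apply_chat_template() above can also be written directly against the public API. A minimal sketch, assuming a loaded model and a std::vector of llama_chat_message; the helper name, the initial buffer size, and the add-assistant flag are illustrative and not part of this diff:

#include <string>
#include <vector>
#include "llama.h"

// Render a chat with the model's default template, growing the buffer if the
// first call reports that more space is needed.
static std::string render_chat(const struct llama_model * model,
                               const std::vector<llama_chat_message> & messages) {
    const char * tmpl = llama_model_chat_template(model, /* name */ nullptr);
    if (tmpl == nullptr) {
        return std::string();  // no template in the model metadata; caller picks a fallback
    }

    std::vector<char> buf(1024);
    int32_t n = llama_chat_apply_template(tmpl, messages.data(), messages.size(),
                                          /* add_ass */ true, buf.data(), (int32_t) buf.size());
    if (n > (int32_t) buf.size()) {
        buf.resize(n);  // the first pass only reported the required size
        n = llama_chat_apply_template(tmpl, messages.data(), messages.size(),
                                      /* add_ass */ true, buf.data(), (int32_t) buf.size());
    }
    return n < 0 ? std::string() : std::string(buf.data(), n);
}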
2 changes: 1 addition & 1 deletion examples/simple-chat/simple-chat.cpp
@@ -161,7 +161,7 @@ int main(int argc, char ** argv) {
             break;
         }
 
-        const char * tmpl = llama_model_chat_template(model);
+        const char * tmpl = llama_model_chat_template(model, /* name */ nullptr);
 
         // add the user input to the message list and format it
        messages.push_back({"user", strdup(user.c_str())});
2 changes: 1 addition & 1 deletion include/llama.h
@@ -503,7 +503,7 @@ extern "C" {
     LLAMA_API uint64_t llama_model_size(const struct llama_model * model);
 
     // Get the default chat template. Returns nullptr if not available
-    LLAMA_API const char * llama_model_chat_template(const struct llama_model * model);
+    LLAMA_API const char * llama_model_chat_template(const struct llama_model * model, const char * name);
 
     // Returns the total number of parameters in the model
     LLAMA_API uint64_t llama_model_n_params(const struct llama_model * model);
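A caller can now ask for a named template variant in addition to the default. A quick sketch of the extended signature in use (the printing helper is illustrative; only llama_model_chat_template itself comes from this change):

#include <cstdio>
#include "llama.h"

static void print_chat_templates(const struct llama_model * model) {
    // nullptr selects the default template ("tokenizer.chat_template" in the GGUF metadata).
    const char * default_tmpl = llama_model_chat_template(model, /* name */ nullptr);
    // A non-null name selects a variant, e.g. the tool-use template if the model ships one.
    const char * tool_tmpl = llama_model_chat_template(model, "tool_use");

    printf("default : %s\n", default_tmpl ? default_tmpl : "(none)");
    printf("tool_use: %s\n", tool_tmpl ? tool_tmpl : "(none)");
}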
6 changes: 4 additions & 2 deletions src/llama-arch.cpp
@@ -179,6 +179,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_TOKENIZER_HF_JSON,         "tokenizer.huggingface.json" },
     { LLM_KV_TOKENIZER_RWKV,            "tokenizer.rwkv.world" },
     { LLM_KV_TOKENIZER_CHAT_TEMPLATE,   "tokenizer.chat_template" },
+    { LLM_KV_TOKENIZER_CHAT_TEMPLATE_N, "tokenizer.chat_template.%s" },
     { LLM_KV_TOKENIZER_FIM_PRE_ID,      "tokenizer.ggml.fim_pre_token_id" },
     { LLM_KV_TOKENIZER_FIM_SUF_ID,      "tokenizer.ggml.fim_suf_token_id" },
     { LLM_KV_TOKENIZER_FIM_MID_ID,      "tokenizer.ggml.fim_mid_token_id" },
@@ -1443,10 +1444,11 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
     {LLM_TENSOR_CONVNEXT_GAMMA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
 };
 
-LLM_KV::LLM_KV(llm_arch arch) : arch(arch) {}
+LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {}
 
 std::string LLM_KV::operator()(llm_kv kv) const {
-    return ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch));
+    return suffix ? ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch), suffix)
+                  : ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch));
 }
 
 std::string LLM_TN_IMPL::str() const {
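The net effect is one GGUF metadata key per template name: the default stays under "tokenizer.chat_template", while a named variant lives under a suffixed key. A simplified sketch of the naming scheme (not the actual LLM_KV/::format code path):

#include <cstdio>
#include <string>

// Build the metadata key for a given (possibly null) template name.
static std::string chat_template_key(const char * name) {
    std::string key = "tokenizer.chat_template";
    if (name != nullptr) {
        key += ".";
        key += name;  // e.g. "tokenizer.chat_template.tool_use"
    }
    return key;
}

int main() {
    printf("%s\n", chat_template_key(nullptr).c_str());     // tokenizer.chat_template
    printf("%s\n", chat_template_key("tool_use").c_str());  // tokenizer.chat_template.tool_use
}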
4 changes: 3 additions & 1 deletion src/llama-arch.h
@@ -177,6 +177,7 @@ enum llm_kv {
     LLM_KV_TOKENIZER_HF_JSON,
     LLM_KV_TOKENIZER_RWKV,
     LLM_KV_TOKENIZER_CHAT_TEMPLATE,
+    LLM_KV_TOKENIZER_CHAT_TEMPLATE_N,
     LLM_KV_TOKENIZER_FIM_PRE_ID,
     LLM_KV_TOKENIZER_FIM_SUF_ID,
     LLM_KV_TOKENIZER_FIM_MID_ID,
@@ -335,9 +336,10 @@ enum llm_tensor_layer {
 };
 
 struct LLM_KV {
-    LLM_KV(llm_arch arch);
+    LLM_KV(llm_arch arch, const char * suffix = nullptr);
 
     llm_arch arch;
+    const char * suffix;
 
     std::string operator()(llm_kv kv) const;
 };
6 changes: 4 additions & 2 deletions src/llama-model.cpp
@@ -3912,8 +3912,10 @@ uint64_t llama_model_size(const struct llama_model * model) {
     return model->size();
 }
 
-const char * llama_model_chat_template(const struct llama_model * model) {
-    const auto & it = model->gguf_kv.find(LLM_KV(model->arch)(LLM_KV_TOKENIZER_CHAT_TEMPLATE));
+const char * llama_model_chat_template(const struct llama_model * model, const char * name) {
+    const auto key = name ? LLM_KV(model->arch, name)(LLM_KV_TOKENIZER_CHAT_TEMPLATE_N)
+                          : LLM_KV(model->arch)(LLM_KV_TOKENIZER_CHAT_TEMPLATE);
+    const auto & it = model->gguf_kv.find(key);
     if (it == model->gguf_kv.end()) {
         return nullptr;
     }
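On the caller side, a nullptr return simply means the model carries no template under that name. A small defensive wrapper, mirroring how common.cpp above falls back when the default template is missing (the wrapper and its fallback argument are illustrative):

#include <string>
#include "llama.h"

// Return the requested template source, or a caller-provided fallback if the
// GGUF metadata does not contain it.
static std::string chat_template_or(const struct llama_model * model,
                                    const char * name,
                                    const std::string & fallback) {
    const char * tmpl = llama_model_chat_template(model, name);
    return tmpl ? std::string(tmpl) : fallback;
}

// Usage (assuming `model` is a loaded llama_model *):
//   std::string default_src  = chat_template_or(model, nullptr, "chatml");
//   std::string tool_use_src = chat_template_or(model, "tool_use", "");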