8000 feat: Allow overriding GGUF metadata when loading model by KerfuffleV2 · Pull Request #4092 · ggml-org/llama.cpp · GitHub
[go: up one dir, main page]

Skip to content

feat: Allow overriding GGUF metadata when loading model #4092

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Dec 5, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Fix broken logic for parsing bool KV overrides
Fix issue where overrides didn't apply when key missing in GGUF metadata
Resolve merge changes
  • Loading branch information
KerfuffleV2 committed Nov 18, 2023
commit aa7cf3143be910d0c35d60023440c05a850e05c4
7 changes: 5 additions & 2 deletions common/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -699,9 +699,9 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
} else if (strncmp(sep, "bool:", 5) == 0) {
sep += 5;
kvo.tag = LLAMA_KV_OVERRIDE_BOOL;
if (std::strcmp(sep, "true")) {
if (std::strcmp(sep, "true") == 0) {
kvo.bool_value = true;
} else if (std::strcmp(sep, "false")) {
} else if (std::strcmp(sep, "false") == 0) {
kvo.bool_value = false;
} else {
fprintf(stderr, "error: Invalid boolean value for KV override: %s\n", argv[i]);
Expand Down Expand Up @@ -888,6 +888,9 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" draft model for speculative decoding (default: %s)\n", params.model.c_str());
printf(" -ld LOGDIR, --logdir LOGDIR\n");
printf(" path under which to save YAML logs (no logging if unset)\n");
printf(" --override-kv KEY=TYPE:VALUE\n");
printf(" advanced option to override model metadata by key. may be specified multiple times.\n");
printf(" types: int, float, bool. example: --override-kv tokenizer.ggml.add_bos_token=bool:false\n");
printf("\n");
#ifndef LOG_DISABLE_LOGS
log_print_usage();
Expand Down
32 changes: 13 additions & 19 deletions llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -607,7 +607,7 @@ static std::string gguf_data_to_str(enum gguf_type type, const void * data, int
}
}

static std::string gguf_kv_to_str(struct gguf_context * ctx_gguf, int i) {
static std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i) {
const enum gguf_type type = gguf_get_kv_type(ctx_gguf, i);

switch (type) {
Expand Down Expand Up @@ -1895,16 +1895,13 @@ namespace GGUFMeta {
if (try_override<T>(target, override)) {
return true;
}
if (k < 0) { return false; }
target = get_kv(ctx, k);
return true;
}

static bool set(const gguf_context * ctx, const char * key, T & target, const struct llama_model_kv_override *override = nullptr) {
const int kid = gguf_find_key(ctx, key);
if (kid < 0) {
return false;
}
return set(ctx, kid, target, override);
return set(ctx, gguf_find_key(ctx, key), target, override);
}

static bool set(const gguf_context * ctx, const std::string & key, T & target, const struct llama_model_kv_override *override = nullptr) {
Expand Down Expand Up @@ -2367,6 +2364,7 @@ static void llm_load_hparams(
llama_model_loader & ml,
llama_model & model) {
auto & hparams = model.hparams;
const gguf_context * ctx = ml.ctx_gguf;

// get metadata as string
for (int i = 0; i < gguf_get_n_kv(ctx); i++) {
Expand Down Expand Up @@ -2678,19 +2676,15 @@ static void llm_load_vocab(
}

// Handle add_bos_token and add_eos_token
std::string key = kv(LLM_KV_TOKENIZER_ADD_BOS);
int kid = gguf_find_key(ctx, key.c_str());
enum gguf_type ktype = kid < 0 ? GGUF_TYPE_COUNT : gguf_get_kv_type(ctx, kid);
vocab.special_add_bos = ktype == GGUF_TYPE_BOOL ? gguf_get_val_bool(ctx, kid) : -1;
if (ktype != GGUF_TYPE_BOOL && ktype != GGUF_TYPE_COUNT) {
LLAMA_LOG_WARN("%s: bad field type %d for '%s' - ignoring\n", __func__, ktype, key.c_str());
}
key = kv(LLM_KV_TOKENIZER_ADD_EOS);
kid = gguf_find_key(ctx, key.c_str());
ktype = kid < 0 ? GGUF_TYPE_COUNT : gguf_get_kv_type(ctx, kid);
vocab.special_add_eos = ktype == GGUF_TYPE_BOOL ? gguf_get_val_bool(ctx, kid) : -1;
if (ktype != GGUF_TYPE_BOOL && ktype != GGUF_TYPE_COUNT) {
LLAMA_LOG_WARN("%s: bad field type %d for '%s' - ignoring\n", __func__, ktype, key.c_str());
{
bool temp = true;

if (ml.get_key(LLM_KV_TOKENIZER_ADD_BOS, temp, false)) {
vocab.special_add_bos = int(temp);
}
if (ml.get_key(LLM_KV_TOKENIZER_ADD_EOS, temp, false)) {
vocab.special_add_eos = int(temp);
}
}
}

Expand Down
0