8000 Sample interface, new samplers, by ivanstepanovftw · Pull Request #1126 · ggml-org/llama.cpp · GitHub
[go: up one dir, main page]

Skip to content

Sample interface, new samplers, #1126

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Apr 29, 2023
Merged
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
mirostat
  • Loading branch information
ivanstepanovftw committed Apr 28, 2023
commit f01c67fe55d4c48b7903394416303aafc20e3f3b
37 changes: 29 additions & 8 deletions examples/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,24 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
break;
}
params.alpha_presence = std::stof(argv[i]);
} else if (arg == "--mirostat") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.mirostat = std::stoi(argv[i]);
} else if (arg == "--mirostat_eta") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.mirostat_eta = std::stof(argv[i]);
} else if (arg == "--mirostat_tau") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.mirostat_tau = std::stof(argv[i]);
} else if (arg == "-b" || arg == "--batch_size") {
if (++i >= argc) {
invalid_param = true;
Expand Down Expand Up @@ -264,14 +282,17 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
fprintf(stderr, " -f FNAME, --file FNAME\n");
fprintf(stderr, " prompt file to start generation.\n");
fprintf(stderr, " -n N, --n_predict N number of tokens to predict (default: %d, -1 = infinity)\n", params.n_predict);
fprintf(stderr, " --top_k N top-k sampling (default: %d)\n", params.top_k);
fprintf(stderr, " --top_p N top-p sampling (default: %.1f)\n", (double)params.top_p);
fprintf(stderr, " --tfs N tail free sampling (default: %.1f)\n", (double)params.tfs_z);
fprintf(stderr, " --typical N locally typical sampling (default: %.1f)\n", (double)params.typical_p);
fprintf(stderr, " --alpha_presence N repeat alpha presence (default: %d)\n", params.alpha_presence);
fprintf(stderr, " --alpha_frequency N repeat alpha frequency (default: %.1f)\n", (double)params.alpha_frequency);
fprintf(stderr, " --repeat_last_n N last n tokens to consider for penalize (default: %d)\n", params.repeat_last_n);
fprintf(stderr, " --repeat_penalty N penalize repeat sequence of tokens (default: %.1f)\n", (double)params.repeat_penalty);
fprintf(stderr, " --top_k N top-k sampling (default: %d, disabled: 0)\n", params.top_k);
fprintf(stderr, " --top_p N top-p sampling (default: %.1f, disabled: 1.0)\n", (double)params.top_p);
fprintf(stderr, " --tfs N tail free sampling, parameter z (default: %.1f, disabled: 1.0)\n", (double)params.tfs_z);
fprintf(stderr, " --typical N locally typical sampling, parameter p (default: %.1f, disabled: 1.0)\n", (double)params.typical_p);
fprintf(stderr, " --repeat_last_n N last n tokens to consider for penalize (default: %d, disabled: 0)\n", params.repeat_last_n);
fprintf(stderr, " --repeat_penalty N penalize repeat sequence of tokens (default: %.1f, disabled: 1.0)\n", (double)params.repeat_penalty);
fprintf(stderr, " --alpha_presence N repeat alpha presence (default: %.1f, disabled: 0.0)\n", (double)params.alpha_presence);
fprintf(stderr, " --alpha_frequency N repeat alpha frequency (default: %.1f, disabled: 0.0)\n", (double)params.alpha_frequency);
fprintf(stderr, " --mirostat N use mirostat sampling (default: %d, disabled: 0, mirostat: 1, mirostat 2.0: 2)\n", params.mirostat);
fprintf(stderr, " --mirostat_eta N mirostat learning rate (default: %.1f)\n", (double)params.mirostat_eta);
fprintf(stderr, " --mirostat_tau N mirostat target entropy (default: %.1f)\n", (double)params.mirostat_tau);
fprintf(stderr, " -c N, --ctx_size N size of the prompt context (default: %d)\n", params.n_ctx);
fprintf(stderr, " --ignore-eos ignore end of stream token and continue generating\n");
fprintf(stderr, " --memory_f32 use f32 instead of f16 for memory key+value\n");
Expand Down
17 changes: 12 additions & 5 deletions examples/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,24 @@ struct gpt_params {
int32_t seed = -1; // RNG seed
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
int32_t n_predict = 128; // new tokens to predict
int32_t repeat_last_n = 64; // last n tokens to penalize
int32_t n_parts = -1; // amount of model parts (-1 = determine from model dimensions)
int32_t n_ctx = 512; // context size
int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS)
int32_t n_keep = 0; // number of tokens to keep from initial prompt

// sampling parameters
int32_t top_k = 40;
float top_p = 0.95f;
float temp = 0.80f;
float repeat_penalty = 1.10f;
int32_t top_k = 0; // <= 0 to use vocab size
float top_p = 1.0f; // 1.0 = disabled
float tfs_z = 1.0f; // 1.0 = disabled
float typical_p = 1.0f; // 1.0 = disabled
float temp = 1.0f; // 1.0 = disabled
float repeat_penalty = 1.0f; // 1.0 = disabled
int32_t repeat_last_n = -1; // last n tokens to penalize (0 = disable penalty, -1 = context size)
float alpha_frequency = 0.0f; // 0.0 = disabled
float alpha_presence = 0.0f; // 0.0 = disabled
int mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
float mirostat_tau = 5.0f; // target entropy
float mirostat_eta = 0.1f; // learning rate

std::string model = "models/lamma-7B/ggml-model.bin"; // model path
std::string prompt = "";
Expand Down
53 changes: 27 additions & 26 deletions examples/main/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -276,8 +276,8 @@ int main(int argc, char ** argv) {
fprintf(stderr, "Input prefix: '%s'\n", params.input_prefix.c_str());
}
}
fprintf(stderr, "sampling: repeat_last_n = %d, repeat_penalty = %f, alpha_presence = %f, alpha_frequency = %f, top_k = %d, tfs_z = %f, top_p = %f, typical_p = %f, temp = %f\n",
params.repeat_last_n, params.repeat_penalty, params.alpha_presence, params.alpha_frequency, params.top_k, params.tfs_z, params.top_p, params.typical_p, params.temp);
fprintf(stderr, "sampling: repeat_last_n = %d, repeat_penalty = %f, alpha_presence = %f, alpha_frequency = %f, top_k = %d, tfs_z = %f, top_p = %f, typical_p = %f, temp = %f, mirostat = %d, mirostat_eta = %f, mirostat_tau = %f\n",
params.repeat_last_n, params.repeat_penalty, params.alpha_presence, params.alpha_frequency, params.top_k, params.tfs_z, params.top_p, params.typical_p, params.temp, params.mirostat, params.mirostat_eta, params.mirostat_tau);
fprintf(stderr, "generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
fprintf(stderr, "\n\n");

Expand Down Expand Up @@ -396,6 +396,9 @@ int main(int argc, char ** argv) {
const float repeat_penalty = params.repeat_penalty;
const float alpha_presence = params.alpha_presence;
const float alpha_frequency = params.alpha_frequency;
const int mirostat = params.mirostat;
const float mirostat_tau = params.mirostat_tau;
const float mirostat_eta = params.mirostat_eta;

// optionally save the session on first sample (for faster prompt loading next time)
if (!path_session.empty() && need_to_save_session) {
Expand All @@ -415,47 +418,45 @@ int main(int argc, char ** argv) {

std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab);
for (size_t i = 0; i < n_vocab; i++) {
for (size_t i = 0; i < (size_t) n_vocab; i++) {
candidates.emplace_back(i, logits[i], 0.0f);
}

llama_token_data_array candidates_p = { candidates.data(), candidates.size() };
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };

// Apply penalties
auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), n_ctx);
llama_sample_repetition_penalty(&candidates_p,
llama_sample_repetition_penalty(ctx, &candidates_p,
last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
last_n_repeat, repeat_penalty);
llama_sample_frequency_and_presence_penalties(&candidates_p,
llama_sample_frequency_and_presence_penalties(ctx, &candidates_p,
last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
last_n_repeat, alpha_frequency, alpha_presence);


#if 1
if (temp <= 0) {
// Greedy sampling
id = llama_sample_token_greedy(ctx, &candidates_p);
} else {
// Temperature sampling
llama_sample_top_k(&candidates_p, top_k);
llama_sample_tail_free(&candidates_p, tfs_z);
llama_sample_typical(&candidates_p, typical_p);
llama_sample_top_p(&candidates_p, top_p);

llama_sample_temperature(&candidates_p, temp);
// printf("`%d`", candidates_p.size);
id = llama_sample_token(ctx, &candidates_p);
if (mirostat == 1) {
static float mirostat_mu = 2.0f * mirostat_tau;
static int mirostat_k = 40;
const int mirostat_m = 100;
id = llama_sample_token_mirostat(ctx, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, float(n_vocab), &mirostat_k, &mirostat_mu);
} else if (mirostat == 2) {
static float mirostat_mu = 2.0f * mirostat_tau;
id = llama_sample_token_mirostat_v2(ctx, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu);
} else {
// Temperature sampling
llama_sample_top_k(ctx, &candidates_p, top_k);
llama_sample_tail_free(ctx, &candidates_p, tfs_z);
llama_sample_typical(ctx, &candidates_p, typical_p);
llama_sample_top_p(ctx, &candidates_p, top_p);
llama_sample_temperature(ctx, &candidates_p, temp);
id = llama_sample_token(ctx, &candidates_p);
}
}
#else
const float tau = 5.0f;
static float mu = 2.0f * tau;
static int k = 40;
const float eta = 0.1f;
const int m = 100;
const float N = n_vocab;
id = llama_sample_mirostat(ctx, &candidates_p, tau, eta, m, N, &k, &mu);
// id = llama_sample_mirostat_v2(ctx, &candidates_p, tau, eta, &mu);
#endif
// printf("`%d`", candidates_p.size);

last_n_tokens.erase(last_n_tokens.begin());
last_n_tokens.push_back(id);
Expand Down
Loading
0