Add support for batch size to `--perplexity` by glinscott · Pull Request #407 · ggml-org/llama.cpp · GitHub
Merged
10 commits merged on Apr 13, 2023
Add support to batch size for perplexity
glinscott committed Mar 22, 2023
commit 9ea43d4d9124d6a05ba1027dd05d65c5ffdfeae7
main.cpp: 25 changes (16 additions & 9 deletions)
@@ -90,17 +90,26 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
     int count = 0;
     double nll = 0.0;
     int seq_count = tokens.size() / params.n_ctx;
+    int n_vocab = llama_n_vocab(ctx);
 
-    fprintf(stderr, "%s : calculating perplexity over %d chunks\n", __func__, seq_count);
+    fprintf(stderr, "%s : calculating perplexity over %d chunks, batch_size=%d\n", __func__, seq_count, params.n_batch);
 
     for (int i = 0; i < seq_count; ++i) {
         int start = i * params.n_ctx;
         int end = start + params.n_ctx - 1;
-        std::vector<llama_token> embd(tokens.begin() + start, tokens.begin() + end);
+
+        std::vector<float> logits;
+        int num_batches = (params.n_ctx + params.n_batch - 1) / params.n_batch;
         auto start_t = std::chrono::high_resolution_clock::now();
-        if (llama_eval(ctx, embd.data(), embd.size(), 0, params.n_threads)) {
-            fprintf(stderr, "%s : failed to eval\n", __func__);
-            return;
+        for (int j = 0; j < num_batches; ++j) {
+            int batch_start = start + j * params.n_batch;
+            int batch_size = std::min(end - batch_start, params.n_batch);
+            if (llama_eval(ctx, tokens.data() + batch_start, batch_size, j * params.n_batch, params.n_threads)) {
+                fprintf(stderr, "%s : failed to eval\n", __func__);
+                return;
+            }
+            auto batch_logits = llama_get_logits(ctx);
+            logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
         }
         auto end_t = std::chrono::high_resolution_clock::now();
         if (i == 0) {
@@ -120,13 +129,11 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
         // last 256 tokens. Then, we split the input up into context window size chunks to
         // process the entire prompt.
 
-        auto logits = llama_get_logits(ctx);
         for (int j = params.n_ctx / 2; j < params.n_ctx - 1; ++j) {
             // Calculate probability of next token, given the previous ones.
-            int n_vocab = llama_n_vocab(ctx);
             std::vector<float> tok_logits(
-                logits + j * n_vocab,
-                logits + (j + 1) * n_vocab);
+                logits.begin() + j * n_vocab,
+                logits.begin() + (j + 1) * n_vocab);
             double prob = softmax(tok_logits)[tokens[start + j + 1]];
             nll += -std::log(prob);
             ++count;
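For readers skimming the diff, below is a minimal, self-contained sketch of the chunk-and-batch structure this commit introduces. It is illustrative only, not the PR's code: `fake_eval` stands in for `llama_eval` + `llama_get_logits`, and `n_vocab`, `n_ctx`, `n_batch`, and the dummy token stream are placeholders for what the real model and `gpt_params` would supply.

```cpp
// Sketch of the batched perplexity loop: evaluate each n_ctx-sized chunk in
// n_batch-sized pieces, concatenate the per-token logits, then score the
// second half of the chunk. fake_eval() is a stand-in for the real model.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

static const int n_vocab = 32000; // placeholder vocabulary size

// Stand-in for llama_eval + llama_get_logits: one row of logits per token.
static std::vector<float> fake_eval(const std::vector<int> & tokens,
                                    int batch_start, int batch_size) {
    (void) tokens; (void) batch_start;
    return std::vector<float>(batch_size * n_vocab, 0.0f); // uniform logits
}

static std::vector<double> softmax(const std::vector<float> & logits) {
    double max_l = *std::max_element(logits.begin(), logits.end());
    std::vector<double> probs(logits.size());
    double sum = 0.0;
    for (size_t i = 0; i < logits.size(); ++i) {
        probs[i] = std::exp(logits[i] - max_l);
        sum += probs[i];
    }
    for (double & p : probs) p /= sum;
    return probs;
}

int main() {
    const int n_ctx   = 512;               // context window per chunk
    const int n_batch = 128;               // tokens per eval call
    std::vector<int> tokens(4 * n_ctx, 1); // dummy tokenized corpus

    int    count = 0;
    double nll   = 0.0;
    int seq_count   = (int) tokens.size() / n_ctx;
    int num_batches = (n_ctx + n_batch - 1) / n_batch; // ceil(n_ctx / n_batch)

    for (int i = 0; i < seq_count; ++i) {
        int start = i * n_ctx;
        int end   = start + n_ctx - 1;

        // Evaluate the chunk in n_batch-sized pieces, appending the logits of
        // every evaluated token into one flat buffer.
        std::vector<float> logits;
        for (int j = 0; j < num_batches; ++j) {
            int batch_start = start + j * n_batch;
            int batch_size  = std::min(end - batch_start, n_batch);
            auto batch_logits = fake_eval(tokens, batch_start, batch_size);
            logits.insert(logits.end(), batch_logits.begin(), batch_logits.end());
        }

        // Score only the second half of the chunk so every scored token has at
        // least n_ctx/2 tokens of context behind it.
        for (int j = n_ctx / 2; j < n_ctx - 1; ++j) {
            std::vector<float> tok_logits(logits.begin() + j * n_vocab,
                                          logits.begin() + (j + 1) * n_vocab);
            double prob = softmax(tok_logits)[tokens[start + j + 1]];
            nll += -std::log(prob);
            ++count;
        }
        printf("chunk %d/%d, running ppl = %.4f\n", i + 1, seq_count, std::exp(nll / count));
    }
    return 0;
}
```

The batching only changes how each chunk's logits are produced; the same tokens are scored, so the reported perplexity should match the single-eval path while each eval call sees at most `n_batch` tokens. Assuming the existing `-b`/`--batch_size` option is what populates `params.n_batch`, running `--perplexity` with a smaller `-b` would trade extra eval calls for a smaller per-eval working set.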