Server Example Refactor and Improvements by digiwombat · Pull Request #1570 · ggml-org/llama.cpp · GitHub
Server Example Refactor and Improvements #1570


Merged
merged 161 commits on Jun 17, 2023
Changes from 1 commit
Commits (161)
1c3fdf8
Add all generation parameters to server.cpp and allow resetting context
digiwombat May 23, 2023
2071d73
Forgot to remove some testing code.
digiwombat May 23, 2023
421e66b
Update examples/server/server.cpp
digiwombat May 23, 2023
add5f1b
Update examples/server/server.cpp
digiwombat May 23, 2023
3537ad1
Merge branch 'ggerganov:master' into master
digiwombat May 23, 2023
8d7b28c
Fixed some types in the params.
digiwombat May 23, 2023
c2b55cc
Added LoRA Loading
digiwombat May 25, 2023
48cb16a
Merge branch 'ggerganov:master' into master
digiwombat May 27, 2023
66ed19d
Corrected dashes in the help lines.
digiwombat May 27, 2023
36c86d7
Automate Context resetting and minor fixes
digiwombat May 27, 2023
d20f36b
Removed unnecessary last_prompt_token set
digiwombat May 27, 2023
fdce895
Merge branch 'ggerganov:master' into master
digiwombat May 27, 2023
e84b802
Change top_k type.
digiwombat May 28, 2023
1f40a78
Didn't see the already defined top_k var.
digiwombat May 28, 2023
51e0994
server rewrite
SlyEcho May 27, 2023
f93fe36
Add all generation parameters to server.cpp and allow resetting context
digiwombat May 23, 2023
df0e0d0
Forgot to remove some testing code.
digiwombat May 23, 2023
549291f
keep processed from the beginning
SlyEcho May 28, 2023
177868e
Changed to params/args
digiwombat May 28, 2023
e8efd75
Initial timeout code and expanded json return on completion.
digiwombat May 28, 2023
23928f2
Added generation_settings to final json object.
digiwombat May 28, 2023
2e5c5ee
Changed JSON names to match the parameter name rather than the variab…
digiwombat May 28, 2023
dda915c
Added capturing the stopping word and sending it along with the final…
digiwombat May 28, 2023
7740301
Set unspecified generation settings back to default. (Notes below)
digiwombat May 28, 2023
7186d65
seed and gen params
SlyEcho May 28, 2023
15ddc49
Merge remote-tracking branch 'slyecho/server_refactor'
digiwombat May 28, 2023
74c6f36
Editorconfig suggested fixes
SlyEcho May 28, 2023
2c9ee7a
Apply suggestions from code review
digiwombat May 28, 2023
655899d
Add ignore_eos option to generation settings.
digiwombat May 28, 2023
b38d41e
--memory_f32 flag to --memory-f32 to match common.cpp
digiwombat May 28, 2023
6c58f64
--ctx_size flag to --ctx-size to match common.cpp
digiwombat May 28, 2023
33b6957
Fixed failing to return result on stopping token.
digiwombat May 28, 2023
42cf4d8
Merge branch 'master' into master
SlyEcho May 28, 2023
03ea8f0
Fix for the regen issue.
digiwombat May 30, 2023
d6fff56
add streaming via server-sent events
May 30, 2023
3292f05
Changed to single API endpoint for streaming and non.
digiwombat May 30, 2023
38eaf2b
Removed testing fprintf calls.
digiwombat May 30, 2023
a25f830
Default streaming to false if it's not set in the request body.
digiwombat May 31, 2023
2533878
Merge branch 'master' into sse
digiwombat May 31, 2023
e6de69a
Merge pull request #3 from anon998/sse
digiwombat May 31, 2023
7a853dc
prevent the server from swallowing exceptions in debug mode
May 31, 2023
aa0788b
add --verbose flag and request logging
May 31, 2023
9197674
Merge pull request #4 from anon998/logging
digiwombat May 31, 2023
b6f536d
Cull to end of generated_text when encountering a stopping string in …
digiwombat May 31, 2023
7a8104f
add missing quote when printing stopping strings
May 31, 2023
3a079d5
stop generating when the stream is closed
May 31, 2023
9f2424a
Merge pull request #5 from anon998/stop-stream
digiwombat May 31, 2023
c1cbde8
print error when server can't bind to the interface
May 31, 2023
2c08f29
make api server use only a single thread
May 31, 2023
284bc29
reserve memory for generated_text
May 31, 2023
f1710b9
add infinite generation when n_predict is -1
May 31, 2023
aa2bbb2
fix parameter type
May 31, 2023
27911d6
fix default model alias
May 31, 2023
dd30219
buffer incomplete multi-byte characters
May 31, 2023
40e1380
print timings + build info
May 31, 2023
d58e486
default penalize_nl to false + format
May 31, 2023
3edaf6b
print timings by default
May 31, 2023
96fa480
Merge pull request #6 from anon998/fix-multibyte
digiwombat May 31, 2023
7332b41
Simple single-line server log for requests
digiwombat May 31, 2023
dda4c10
Switch to the CPPHTTPLIB logger. Verbose adds body dump as well as re…
digiwombat May 31, 2023
86337e3
Server console logs now come in one flavor: Verbose.
digiwombat May 31, 2023
1b96df2
Spacing fix. Nothing to see here.
digiwombat May 31, 2023
276fa99
Misunderstood the instructions, I think. Back to the raw JSON output …
digiwombat May 31, 2023
43d295f
filter empty stopping strings
May 31, 2023
1bd7cc6
reuse format_generation_settings for logging
May 31, 2023
497160a
remove old log function
May 31, 2023
f2e1130
Merge pull request #7 from anon998/logging-reuse
digiwombat May 31, 2023
9104fe5
Change how the token buffers work.
SlyEcho May 31, 2023
8478e59
Merge pull request #8 from SlyEcho/server_refactor
digiwombat May 31, 2023
bed308c
Apply suggestions from code review
SlyEcho May 31, 2023
342604b
Added a super simple CORS header as default for all endpoints.
digiwombat May 31, 2023
e9b1f0b
fix stopping strings
May 31, 2023
5f6e16d
Merge pull request #9 from anon998/stopping-strings
digiwombat Jun 1, 2023
f7882e2
Fixed a crash caused by erasing from empty last_n_tokens
digiwombat Jun 1, 2023
5bbc030
Add Options enpoints and Access-Control-Allow-Headers to satisfy CORS…
cirk2 Jun 1, 2023
8c6a5fc
last tokens fixes
SlyEcho Jun 1, 2023
9531ae6
Add logit bias support
SlyEcho Jun 1, 2023
797155a
Merge pull request #10 from cirk2/master
digiwombat Jun 1, 2023
af71126
Merge pull request #11 from SlyEcho/server_refactor
digiwombat Jun 1, 2023
49a18bd
remove unused parameter warning
Jun 1, 2023
6025476
default penalize_nl back to true
Jun 1, 2023
8cbc4be
clear logit_bias between requests + print
Jun 1, 2023
d29b6d5
Merge pull request #12 from anon998/clear-logit-bias
digiwombat Jun 1, 2023
0bc0477
Apply suggestions from code review
SlyEcho Jun 2, 2023
731ecc0
fix typo
Jun 2, 2023
ebfead6
remove unused variables
Jun 2, 2023
1488a0f
make functions that never return false void
Jun 2, 2023
49dce94
make types match gpt_params exactly
Jun 2, 2023
a8a9f19
small fixes
Jun 2, 2023
2932db1
avoid creating element in logit_bias accidentally
Jun 2, 2023
47efbb5
use std::isinf to check if ignore_eos is active
Jun 2, 2023
88cc7bb
Stuff with logits
SlyEcho Jun 2, 2023
abb7782
Merge branch 'master' into small-fixes
Jun 2, 2023
bebea65
Merge pull request #13 from anon998/small-fixes
digiwombat Jun 2, 2023
8f9e546
trim partial stopping strings when not streaming
Jun 2, 2023
f820740
move multibyte check to doCompletion
Jun 2, 2023
f5d5e70
Merge pull request #14 from anon998/do-completion-update
digiwombat Jun 2, 2023
1bd52c8
Merge branch 'ggerganov:master' into master
digiwombat Jun 2, 2023
3df0192
improve long input truncation
SlyEcho Jun 2, 2023
28cc0cd
Merge pull request #15 from SlyEcho/server_refactor
digiwombat Jun 2, 2023
3ff27d3
Fixed up a few things in embedding mode.
digiwombat Jun 2, 2023
41bb71b
replace invalid characters instead of crashing
Jun 2, 2023
4dd72fc
Merge pull request #16 from anon998/fix-log-json
digiwombat Jun 2, 2023
16e1c98
Removed the embedding api endpoint and associated code.
digiwombat Jun 2, 2023
7cebe2e
Merge branch 'master' of https://github.com/digiwombat/llama.cpp
digiwombat Jun 2, 2023
bcd6167
improve docs and example
SlyEcho Jun 2, 2023
de6df48
Removed embedding from README
digiwombat Jun 2, 2023
310bf61
Merge pull request #17 from SlyEcho/server_refactor
digiwombat Jun 2, 2023
5758e9f
Removed embedding from flags.
digiwombat Jun 2, 2023
e1e2be2
remove --keep from help text
Jun 2, 2023
a6ed390
update readme
Jun 2, 2023
05a5a48
make help text load faster
Jun 2, 2023
98ae2de
parse --mlock and --no-mmap + format
Jun 2, 2023
df2ecc9
Merge pull request #18 from anon998/update-readme
digiwombat Jun 2, 2023
64a0653
Merge remote-tracking branch 'upstream/master'
digiwombat Jun 7, 2023
61befcb
Apply suggestions from code review
SlyEcho Jun 8, 2023
ccd85e0
Apply suggestions from code review
SlyEcho Jun 8, 2023
a9c3477
Spaces to 4 and other code style cleanup. Notes in README.
digiwombat Jun 9, 2023
cc2b336
Missed a pair of catch statements for formatting.
digiwombat Jun 9, 2023
23a1b18
Merge branch 'ggerganov:master' into master
digiwombat Jun 9, 2023
7580427
Resolving some review comments
digiwombat Jun 9, 2023
889d904
Merge branch 'master' of https://github.com/digiwombat/llama.cpp
digiwombat Jun 9, 2023
7cdeb08
More formatting cleanup
digiwombat Jun 9, 2023
1a9141b
Remove model assign in main(). Clarified stop in README.
digiwombat Jun 9, 2023
917540c
Clarify build instructions in README.
lesaun Jun 10, 2023
d6d263f
Merge pull request #19 from lesaun/master
digiwombat Jun 10, 2023
bac0ddb
Merge branch 'ggerganov:master' into master
digiwombat Jun 10, 2023
2c00bf8
more formatting changes
SlyEcho Jun 11, 2023
9612d12
big logging update
SlyEcho Jun 11, 2023
6518f9c
build settings
SlyEcho Jun 11, 2023
eee8b28
Merge pull request #20 from SlyEcho/server_refactor
digiwombat Jun 11, 2023
4148b9b
remove void
SlyEcho Jun 12, 2023
dff11a1
json parsing improvements
SlyEcho Jun 12, 2023
13cf692
more json changes and stop info
SlyEcho Jun 12, 2023
b91200a
javascript chat update.
SlyEcho Jun 12, 2023
1510337
fix make flags propagation
SlyEcho Jun 12, 2023
fc4264d
api url
SlyEcho Jun 12, 2023
28694f7
add a simple bash script too
SlyEcho Jun 12, 2023
429ed95
move CPPHTTPLIB settings inside server
SlyEcho Jun 12, 2023
f344d09
streaming shell script
SlyEcho Jun 12, 2023
50e7c54
Merge pull request #21 from SlyEcho/server_refactor
digiwombat Jun 12, 2023
fc78910
Merge branch 'ggerganov:master' into master
digiwombat Jun 12, 2023
6d72f0f
Make chat shell script work by piping the content out of the subshell.
digiwombat Jun 12, 2023
9d564db
trim response and trim trailing space in prompt
Jun 13, 2023
9099709
Merge pull request #22 from anon998/bash-trim
digiwombat Jun 13, 2023
b8b8a6e
Add log flush
SlyEcho Jun 13, 2023
6627a02
Allow overriding the server address
SlyEcho Jun 13, 2023
1f39452
remove old verbose variable
Jun 13, 2023
99ef967
add static prefix to the other functions too
Jun 13, 2023
575cf23
remove json_indent variable
Jun 13, 2023
7df316b
fix linter warnings + make variables const
Jun 13, 2023
7a48ade
fix comment indentation
Jun 13, 2023
6075d78
Merge pull request #23 from anon998/fix-linter-warnings
digiwombat Jun 13, 2023
546f850
Update examples/server/server.cpp
SlyEcho Jun 14, 2023
bd81096
fix typo in readme + don't ignore integers
Jun 14, 2023
5e107c2
Merge pull request #24 from anon998/logit-bias
digiwombat Jun 14, 2023
f858cd6
Merge remote-tracking branch 'upstream/master'
digiwombat Jun 14, 2023
aee8595
Update README.md
digiwombat Jun 15, 2023
488c62a
Merge remote-tracking branch 'upstream/master'
digiwombat Jun 15, 2023
fb49c05
Merge branch 'ggerganov:master' into master
digiwombat Jun 16, 2023
1b4b93a
Merge branch 'ggerganov:master' into master
digiwombat Jun 17, 2023
Change how the token buffers work.
There is now just embd (and last_n_tokens).

The input can also be of any length in which case it will be truncated
like it normally would.
SlyEcho committed May 31, 2023
commit 9104fe5a7cf26a811a360ef0000c7ae195748819
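The commit message above describes the mechanism; as a minimal illustrative sketch (not part of the diff, which follows below, and which is authoritative), the prefix-reuse idea looks roughly like this in C++:

// Illustrative sketch only: reuse of the already-evaluated prefix.
// `embd` holds the tokens evaluated so far; `prompt_tokens` is the new prompt.
#include <cstddef>
#include <vector>
#include "llama.h" // for llama_token

static size_t common_part(const std::vector<llama_token> & a,
                          const std::vector<llama_token> & b) {
    size_t i = 0;
    while (i < a.size() && i < b.size() && a[i] == b[i]) {
        i++;
    }
    return i;
}

// n_past = number of tokens whose evaluation can be kept;
// only prompt_tokens[n_past..] has to be evaluated again.
static void reuse_prefix(std::vector<llama_token> & embd,
                         const std::vector<llama_token> & prompt_tokens,
                         size_t & n_past) {
    n_past = common_part(embd, prompt_tokens);
    embd   = prompt_tokens;
    if (n_past > 0 && n_past == prompt_tokens.size()) {
        n_past--; // at least one token must be evaluated to produce fresh logits
    }
}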
264 changes: 124 additions & 140 deletions examples/server/server.cpp
@@ -14,6 +14,12 @@ struct server_params
bool verbose = false;
};

static size_t common_part(const std::vector<llama_token> & a, const std::vector<llama_token> & b) {
size_t i;
for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++);
Contributor

warning: statement should be inside braces [readability-braces-around-statements]

Suggested change
for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++);
for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {;
}

Collaborator

What statement? ; ??

Collaborator

for loops with ; tend to be error prone, i would suggest {} instead. (basically agree with clangtidy here)
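As a side note to this thread, a braced variant of the loop (and a standard-library equivalent) might look like the following sketch; it is not taken from the diff, llama_token comes from llama.h, and the std::mismatch version assumes C++14's four-iterator overload:

// Sketch of the braced form discussed in this thread (same behaviour as the
// one-line for loop).
#include <algorithm>
#include <cstddef>
#include <vector>
#include "llama.h"

static size_t common_part_braced(const std::vector<llama_token> & a,
                                 const std::vector<llama_token> & b) {
    size_t i = 0;
    for (; i < a.size() && i < b.size() && a[i] == b[i]; i++) {
        // intentionally empty: the loop only advances i past the shared prefix
    }
    return i;
}

// The same computation with the standard library (assumes the C++14
// four-iterator overload of std::mismatch):
static size_t common_part_mismatch(const std::vector<llama_token> & a,
                                   const std::vector<llama_token> & b) {
    return std::mismatch(a.begin(), a.end(), b.begin(), b.end()).first - a.begin();
}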

return i;
}

struct llama_server_context
{
bool stream = false;
@@ -28,10 +34,7 @@ struct llama_server_context

std::vector<llama_token> embd;
std::vector<llama_token> last_n_tokens;
std::vector<llama_token> processed_tokens;
std::vector<llama_token> embd_inp;

std::vector<llama_token> last_prompt_tokens;
llama_context *ctx = nullptr;
gpt_params params;

@@ -55,11 +58,10 @@ struct llama_server_context
generated_text.reserve(params.n_ctx);
stopping_word = "";

//processed_tokens.clear();
embd_inp.clear();
n_remain = 0;
n_past = 0;
n_consumed = 0;
last_n_tokens.clear();
}

bool loadModel(const gpt_params &params_)
@@ -80,177 +82,159 @@ struct llama_server_context
bool loadPrompt() {
params.prompt.insert(0, 1, ' '); // always add a first space
std::vector<llama_token> prompt_tokens = ::llama_tokenize(ctx, params.prompt, true);
if (prompt_tokens == last_prompt_tokens)
{
embd.clear();

if (params.n_keep < 0) {
params.n_keep = (int)prompt_tokens.size();
}
// compare the evaluated prompt with the new prompt
for (n_past = 0; n_past < prompt_tokens.size() - 1 && n_past < processed_tokens.size(); n_past++) {
if (prompt_tokens[n_past] != processed_tokens[n_past]) {
break;
}
params.n_keep = std::min(params.n_ctx - 4, params.n_keep);

// if input prompt is too big, truncate like normal
if (prompt_tokens.size() >= (size_t)params.n_ctx) {
const int n_left = (params.n_ctx - params.n_keep)/2;
std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep);
new_tokens.insert(new_tokens.end(), prompt_tokens.end() - n_left, prompt_tokens.end());
prompt_tokens = new_tokens;
}
processed_tokens.resize(n_past);
if (prompt_tokens.size() > n_past) {
embd_inp.insert(embd_inp.end(), prompt_tokens.begin() + n_past, prompt_tokens.end());

// compare the evaluated prompt with the new prompt
n_past = common_part(embd, prompt_tokens);
embd = prompt_tokens;
if (n_past == prompt_tokens.size()) {
// we have to evaluate at least 1 token to generate logits.
n_past--;
}
last_prompt_tokens = prompt_tokens;
has_next_token = true;
return true;
}

void beginCompletion()
{
if(n_remain == 0) {
// number of tokens to keep when resetting context
if (params.n_keep < 0 || params.n_keep > (int)embd_inp.size())
{
params.n_keep = (int)embd_inp.size();
}
}
// number of tokens to keep when resetting context


n_remain = params.n_predict;
llama_set_rng_seed(ctx, params.seed);
}

llama_token nextToken() {
llama_token result = -1;
if (embd.size() > 0)

if (embd.size() >= (size_t)params.n_ctx) {
// Reset context
const int n_left = (params.n_ctx - params.n_keep)/2;

std::vector<llama_token> new_tokens(embd.begin(), embd.begin() + params.n_keep);
new_tokens.insert(new_tokens.end(), embd.end() - n_left, embd.end());
embd = new_tokens;
n_past = params.n_keep;
}

while (n_past < embd.size())
{
if (n_past + embd.size() > (size_t)params.n_ctx)
int n_eval = (int)embd.size() - n_past;
if (n_eval > params.n_batch)
{
// Reset context
const int n_left = n_past - params.n_keep;
n_past = std::max(1, params.n_keep);
//processed_tokens.erase(processed_tokens.begin() + n_past, processed_tokens.end());
embd.insert(embd.begin(), last_n_tokens.begin() + params.n_ctx - n_left / 2 - embd.size(), last_n_tokens.end() - embd.size());
n_eval = params.n_batch;
}
for (int i = 0; i < (int)embd.size(); i += params.n_batch)
if (llama_eval(ctx, &embd[n_past], n_eval, n_past, params.n_threads))
{
int n_eval = (int)embd.size() - i;
if (n_eval > params.n_batch)
{
n_eval = params.n_batch;
}
if (llama_eval(ctx, &embd[i], n_eval, n_past, params.n_threads))
{
fprintf(stderr, "%s : failed to eval\n", __func__);
has_next_token = false;
return result;
}
n_past += n_eval;
fprintf(stderr, "%s : failed to eval\n", __func__);
has_next_token = false;
return result;
}
n_past += n_eval;
}
embd.clear();
if (embd_inp.size() <= n_consumed)

// out of user input, sample next token
const float temp = params.temp;
const int32_t top_k = params.top_k <= 0 ? llama_n_vocab(ctx) : params.top_k;
const float top_p = params.top_p;
const float tfs_z = params.tfs_z;
const float typical_p = params.typical_p;
const int32_t repeat_last_n = params.repeat_last_n < 0 ? params.n_ctx : params.repeat_last_n;
const float repeat_penalty = params.repeat_penalty;
const float alpha_presence = params.presence_penalty;
const float alpha_frequency = params.frequency_penalty;
const int mirostat = params.mirostat;
const float mirostat_tau = params.mirostat_tau;
const float mirostat_eta = params.mirostat_eta;
const bool penalize_nl = params.penalize_nl;
llama_token id = 0;
{
// out of user input, sample next token
const float temp = params.temp;
const int32_t top_k = params.top_k <= 0 ? llama_n_vocab(ctx) : params.top_k;
const float top_p = params.top_p;
const float tfs_z = params.tfs_z;
const float typical_p = params.typical_p;
const int32_t repeat_last_n = params.repeat_last_n < 0 ? params.n_ctx : params.repeat_last_n;
const float repeat_penalty = params.repeat_penalty;
const float alpha_presence = params.presence_penalty;
const float alpha_frequency = params.frequency_penalty;
const int mirostat = params.mirostat;
const float mirostat_tau = params.mirostat_tau;
const float mirostat_eta = params.mirostat_eta;
const bool penalize_nl = params.penalize_nl;
llama_token id = 0;
auto logits = llama_get_logits(ctx);
auto n_vocab = llama_n_vocab(ctx);

// Apply params.logit_bias map
for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++)
{
auto logits = llama_get_logits(ctx);
auto n_vocab = llama_n_vocab(ctx);
logits[it->first] += it->second;
}

// Apply params.logit_bias map
for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++)
{
logits[it->first] += it->second;
}
std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab);
for (llama_token token_id = 0; token_id < n_vocab; token_id++)
{
candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
Collaborator

remove the llama_token_data{} constructor, saves us the move constructor.

Contributor Author

I took a look at the code around this and I'm not sure what the intended alternate solution is since the llama_token_data is being assembled in place in the for loop and emplace_back won't be too happy with just slapping the variables in.

Collaborator

emplace_back() does forward the parameters.
If it for some reason does not work, use normal push_back(); it is equivalent when constructing the object beforehand.

Contributor Author

Weirdly, I get a build error when I strip the constructor out which is why I asked without looking into it more. I've looked into the build error, and I'm still confused.

'llama_token_data::llama_token_data': no overloaded function takes 3 arguments

That's with VS2022. Not sure what the issue is since the struct definitely takes those 3 arguments.

Collaborator

oh yea, we had that issue before. Just change it to push_back() so we don't pretend to be better than we are. :)

and wait for a better world...

Collaborator

Do we need to push the objects? n_vocab is fixed, we could just use array indexing?

Contributor Author

Ah, I misunderstood. So throw the constructor back in, but we use push_back() to say we're sad and hope for brighter days ahead.

@SlyEcho Do we have a constructed set of llama_token_data to index from as is? It's all running up after the logit_bias changes so I assume that's why it's rebuilding the array. Or are you saying just build them one to one, like

candidates[token_id] = llama_token_data{token_id, logits[token_id], 0.0f}?

Collaborator

I guess technically we could keep it in a field in llama_server_context and initialize it only once. In the sampling code the loop can assign the logit value and set p = 0.0f, but no memory allocation is required any more. The token_id is constant as well.

That is assuming none of the llama_sample_* functions change it but I don't see why they would and it's easy to check.

Collaborator

But currently it is the same in main.cpp and do we want to deviate too much? In the long run the sampling code should be pulled out to its own file (it is not actually LLaMA related, either, any kind of LLM implemented in ggml should be able to use it)

Contributor Author

I vote if it's good enough for main.cpp and isn't adding significant runtime to a given token response overall, it's probably fine. And I agree. Probably better to make it recognizable for anyone else who comes wandering by if new sampling methods get added or (ideally) the sampling code gets rolled out somewhere else.
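For illustration of the idea floated above (keeping the candidates buffer as a context member and refreshing it in place), a rough sketch follows; the PR itself keeps the per-token construction used in main.cpp. The llama.h calls and the llama_token_data fields (id, logit, p) are assumed to be as they existed at the time of this PR:

// Sketch only (not the PR's approach): keep the candidates buffer as a field
// of llama_server_context, allocate it once, and refresh it for every token.
#include <cstddef>
#include <vector>
#include "llama.h"

struct candidate_buffer {
    std::vector<llama_token_data> data; // sized to n_vocab once

    void init(llama_context * ctx) {
        const int n_vocab = llama_n_vocab(ctx);
        data.resize(n_vocab);
        for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
            data[token_id] = { token_id, 0.0f, 0.0f };
        }
    }

    void refresh(llama_context * ctx) {
        const float * logits = llama_get_logits(ctx);
        for (size_t i = 0; i < data.size(); i++) {
            data[i].logit = logits[i]; // copy the fresh logits
            data[i].p     = 0.0f;      // samplers recompute the probabilities
        }
    }
};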

}

std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab);
for (llama_token token_id = 0; token_id < n_vocab; token_id++)
{
candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
}
llama_token_data_array candidates_p = {candidates.data(), candidates.size(), false};

// Apply penalties
float nl_logit = logits[llama_token_nl()];
auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), params.n_ctx);
llama_sample_repetition_penalty(ctx, &candidates_p,
last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
last_n_repeat, repeat_penalty);
llama_sample_frequency_and_presence_penalties(ctx, &candidates_p,
last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
last_n_repeat, alpha_frequency, alpha_presence);
if (!penalize_nl)
{
logits[llama_token_nl()] = nl_logit;
}

llama_token_data_array candidates_p = {candidates.data(), candidates.size(), false};

// Apply penalties
float nl_logit = logits[llama_token_nl()];
auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), params.n_ctx);
llama_sample_repetition_penalty(ctx, &candidates_p,
last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
last_n_repeat, repeat_penalty);
llama_sample_frequency_and_presence_penalties(ctx, &candidates_p,
last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
last_n_repeat, alpha_frequency, alpha_presence);
if (!penalize_nl)
if (temp <= 0)
{
// Greedy sampling
id = llama_sample_token_greedy(ctx, &candidates_p);
}
else
{
if (mirostat == 1)
{
logits[llama_token_nl()] = nl_logit;
static float mirostat_mu = 2.0f * mirostat_tau;
const int mirostat_m = 100;
llama_sample_temperature(ctx, &candidates_p, temp);
id = llama_sample_token_mirostat(ctx, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
}

if (temp <= 0)
else if (mirostat == 2)
{
// Greedy sampling
id = llama_sample_token_greedy(ctx, &candidates_p);
static float mirostat_mu = 2.0f * mirostat_tau;
llama_sample_temperature(ctx, &candidates_p, temp);
id = llama_sample_token_mirostat_v2(ctx, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu);
}
else
{
if (mirostat == 1)
{
static float mirostat_mu = 2.0f * mirostat_tau;
const int mirostat_m = 100;
llama_sample_temperature(ctx, &candidates_p, temp);
id = llama_sample_token_mirostat(ctx, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
}
else if (mirostat == 2)
{
static float mirostat_mu = 2.0f * mirostat_tau;
llama_sample_temperature(ctx, &candidates_p, temp);
id = llama_sample_token_mirostat_v2(ctx, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu);
}
else
{
// Temperature sampling
llama_sample_tail_free(ctx, &candidates_p, tfs_z, 1);
llama_sample_typical(ctx, &candidates_p, typical_p, 1);
llama_sample_top_p(ctx, &candidates_p, top_p, 1);
llama_sample_top_k(ctx, &candidates_p, top_k, 1);
llama_sample_temperature(ctx, &candidates_p, temp);
id = llama_sample_token(ctx, &candidates_p);
}
}
last_n_tokens.erase(last_n_tokens.begin());
last_n_tokens.push_back(id);
processed_tokens.push_back(id);
num_tokens_predicted++;
}

// add it to the context
embd.push_back(id);
result = id;
// decrement remaining sampling budget
--n_remain;
}
else
{
// some user input remains from prompt or interaction, forward it to processing
while (embd_inp.size() > n_consumed)
{
embd.push_back(embd_inp[n_consumed]);
last_n_tokens.erase(last_n_tokens.begin());
last_n_tokens.push_back(embd_inp[n_consumed]);
processed_tokens.push_back(embd_inp[n_consumed]);
++n_consumed;
if ((int)embd.size() >= params.n_batch)
{
break;
// Temperature sampling
llama_sample_tail_free(ctx, &candidates_p, tfs_z, 1);
llama_sample_typical(ctx, &candidates_p, typical_p, 1);
llama_sample_top_p(ctx, &candidates_p, top_p, 1);
llama_sample_top_k(ctx, &candidates_p, top_k, 1);
llama_sample_temperature(ctx, &candidates_p, temp);
id = llama_sample_token(ctx, &candidates_p);
}
}
last_n_tokens.erase(last_n_tokens.begin());
last_n_tokens.push_back(id);
num_tokens_predicted++;
}

// add it to the context
embd.push_back(id);
result = id;
// decrement remaining sampling budget
--n_remain;

if (!embd.empty() && embd.back() == llama_token_eos()) {
stopping_word = llama_token_to_str(ctx, embd.back());
has_next_token = false;