Server Example Refactor and Improvements #1570
Changes from 1 commit
There is now just embd (and last_n_tokens). The input can also be of any length, in which case it is truncated the same way it normally would be.
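For reference, a rough standalone sketch of that strategy (not the PR's exact code): reuse the longest prefix of the new prompt that has already been evaluated, and truncate an over-long prompt by keeping the first n_keep tokens plus the most recent half of what remains. Plain ints stand in for llama_token and the numbers are made up for illustration.

// Sketch only: prefix reuse + truncation with plain ints standing in for llama_token.
#include <cstdio>
#include <vector>

using token = int;

// Length of the shared prefix of two token sequences (same idea as common_part()).
static size_t common_part(const std::vector<token> & a, const std::vector<token> & b) {
    size_t i = 0;
    while (i < a.size() && i < b.size() && a[i] == b[i]) {
        i++;
    }
    return i;
}

// Truncate a prompt that does not fit: keep the first n_keep tokens and the
// last (n_ctx - n_keep)/2 tokens, dropping the middle.
static std::vector<token> truncate_prompt(std::vector<token> prompt, int n_ctx, int n_keep) {
    if (prompt.size() < (size_t) n_ctx) {
        return prompt;
    }
    const int n_left = (n_ctx - n_keep) / 2;
    std::vector<token> out(prompt.begin(), prompt.begin() + n_keep);
    out.insert(out.end(), prompt.end() - n_left, prompt.end());
    return out;
}

int main() {
    const int n_ctx  = 8;
    const int n_keep = 2;

    std::vector<token> evaluated = {1, 2, 3, 4, 5};                 // tokens already in the KV cache
    std::vector<token> prompt    = {1, 2, 3, 9, 9, 9, 9, 9, 9, 9};  // new request

    prompt = truncate_prompt(prompt, n_ctx, n_keep);  // now fits in the context
    size_t n_past = common_part(evaluated, prompt);   // tokens we can skip re-evaluating

    printf("prompt size after truncation: %zu, reused prefix: %zu\n", prompt.size(), n_past);
    return 0;
}

In the actual server code, embd plays the role of the evaluated token list and n_past is set from common_part(), as the diff below shows.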
@@ -14,6 +14,12 @@ struct server_params
bool verbose = false;
};

static size_t common_part(const std::vector<llama_token> & a, const std::vector<llama_token> & b) {
size_t i;
for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++);
return i;
}

struct llama_server_context
{
bool stream = false;

@@ -28,10 +34,7 @@ struct llama_server_context

std::vector<llama_token> embd;
std::vector<llama_token> last_n_tokens;
std::vector<llama_token> processed_tokens;
std::vector<llama_token> embd_inp;

std::vector<llama_token> last_prompt_tokens;
llama_context *ctx = nullptr;
gpt_params params;

@@ -55,11 +58,10 @@ struct llama_server_context
generated_text.reserve(params.n_ctx);

stopping_word = "";

//processed_tokens.clear();
embd_inp.clear();
n_remain = 0;
n_past = 0;
n_consumed = 0;
last_n_tokens.clear();
}

bool loadModel(const gpt_params &params_)

@@ -80,177 +82,159 @@ struct llama_server_context
bool loadPrompt() {
params.prompt.insert(0, 1, ' '); // always add a first space
std::vector<llama_token> prompt_tokens = ::llama_tokenize(ctx, params.prompt, true);
if (prompt_tokens == last_prompt_tokens)
{
embd.clear();

if (params.n_keep < 0) {
params.n_keep = (int)prompt_tokens.size();
}
// compare the evaluated prompt with the new prompt
for (n_past = 0; n_past < prompt_tokens.size() - 1 && n_past < processed_tokens.size(); n_past++) {
if (prompt_tokens[n_past] != processed_tokens[n_past]) {
break;
}
params.n_keep = std::min(params.n_ctx - 4, params.n_keep);

// if input prompt is too big, truncate like normal
if (prompt_tokens.size() >= (size_t)params.n_ctx) {
const int n_left = (params.n_ctx - params.n_keep)/2;
std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep);
new_tokens.insert(new_tokens.end(), prompt_tokens.end() - n_left, prompt_tokens.end());
prompt_tokens = new_tokens;
}
processed_tokens.resize(n_past);
if (prompt_tokens.size() > n_past) {
embd_inp.insert(embd_inp.end(), prompt_tokens.begin() + n_past, prompt_tokens.end());

// compare the evaluated prompt with the new prompt
n_past = common_part(embd, prompt_tokens);
embd = prompt_tokens;
if (n_past == prompt_tokens.size()) {
// we have to evaluate at least 1 token to generate logits.
n_past--;
}
last_prompt_tokens = prompt_tokens;
has_next_token = true;
return true;
}

void beginCompletion()
{
if(n_remain == 0) {
// number of tokens to keep when resetting context
if (params.n_keep < 0 || params.n_keep > (int)embd_inp.size())
{
params.n_keep = (int)embd_inp.size();
}
}
// number of tokens to keep when resetting context

n_remain = params.n_predict;
llama_set_rng_seed(ctx, params.seed);
}

llama_token nextToken() {
llama_token result = -1;
if (embd.size() > 0)

if (embd.size() >= (size_t)params.n_ctx) {
// Reset context
const int n_left = (params.n_ctx - params.n_keep)/2;

std::vector<llama_token> new_tokens(embd.begin(), embd.begin() + params.n_keep);
new_tokens.insert(new_tokens.end(), embd.end() - n_left, embd.end());
embd = new_tokens;
n_past = params.n_keep;
}

while (n_past < embd.size())
{
if (n_past + embd.size() > (size_t)params.n_ctx)
int n_eval = (int)embd.size() - n_past;
if (n_eval > params.n_batch)
{
// Reset context
const int n_left = n_past - params.n_keep;
n_past = std::max(1, params.n_keep);
//processed_tokens.erase(processed_tokens.begin() + n_past, processed_tokens.end());
embd.insert(embd.begin(), last_n_tokens.begin() + params.n_ctx - n_left / 2 - embd.size(), last_n_tokens.end() - embd.size());
n_eval = params.n_batch;
}
for (int i = 0; i < (int)embd.size(); i += params.n_batch)
if (llama_eval(ctx, &embd[n_past], n_eval, n_past, params.n_threads))
{
int n_eval = (int)embd.size() - i;
if (n_eval > params.n_batch)
{
n_eval = params.n_batch;
}
if (llama_eval(ctx, &embd[i], n_eval, n_past, params.n_threads))
{
fprintf(stderr, "%s : failed to eval\n", __func__);
has_next_token = false;
return result;
}
n_past += n_eval;
fprintf(stderr, "%s : failed to eval\n", __func__);
has_next_token = false;
return result;
}
n_past += n_eval;
}
embd.clear();
if (embd_inp.size() <= n_consumed)

// out of user input, sample next token
const float temp = params.temp;
const int32_t top_k = params.top_k <= 0 ? llama_n_vocab(ctx) : params.top_k;
const float top_p = params.top_p;
const float tfs_z = params.tfs_z;
const float typical_p = params.typical_p;
const int32_t repeat_last_n = params.repeat_last_n < 0 ? params.n_ctx : params.repeat_last_n;
const float repeat_penalty = params.repeat_penalty;
const float alpha_presence = params.presence_penalty;
const float alpha_frequency = params.frequency_penalty;
const int mirostat = params.mirostat;
const float mirostat_tau = params.mirostat_tau;
const float mirostat_eta = params.mirostat_eta;
const bool penalize_nl = params.penalize_nl;
llama_token id = 0;
{
// out of user input, sample next token
const float temp = params.temp;
const int32_t top_k = params.top_k <= 0 ? llama_n_vocab(ctx) : params.top_k;
const float top_p = params.top_p;
const float tfs_z = params.tfs_z;
const float typical_p = params.typical_p;
const int32_t repeat_last_n = params.repeat_last_n < 0 ? params.n_ctx : params.repeat_last_n;
const float repeat_penalty = params.repeat_penalty;
const float alpha_presence = params.presence_penalty;
const float alpha_frequency = params.frequency_penalty;
const int mirostat = params.mirostat;
const float mirostat_tau = params.mirostat_tau;
const float mirostat_eta = params.mirostat_eta;
const bool penalize_nl = params.penalize_nl;
llama_token id = 0;
auto logits = llama_get_logits(ctx);

auto n_vocab = llama_n_vocab(ctx);

// Apply params.logit_bias map
for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++)
{
auto logits = llama_get_logits(ctx);
auto n_vocab = llama_n_vocab(ctx);
logits[it->first] += it->second;
}

// Apply params.logit_bias map
for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++)
{
logits[it->first] += it->second;
}
std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab);
for (llama_token token_id = 0; token_id < n_vocab; token_id++)
{
candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
Review thread on this line:

- remove the […]
- I took a look at the code around this and I'm not sure what the intended alternate solution is, since the llama_token_data is being assembled in place in the for loop and emplace_back won't be too happy with just slapping the variables in.
- Weirdly, I get a build error when I strip the constructor out, which is why I asked without looking into it more. I've looked into the build error, and I'm still confused. […] That's with VS2022. Not sure what the issue is, since the struct definitely takes those 3 arguments.
- Oh yea, we had that issue before. Just change it to […] and wait for a better world...
- Do we need to push the objects?
- Ah, I misunderstood. So throw the constructor back in, but we use push_back() to say we're sad and hope for brighter days ahead. @SlyEcho Do we have a constructed set of llama_token_data to index from as is? It's all running up after the logit_bias changes, so I assume that's why it's rebuilding the array. Or are you saying just build them one to one, like […]
- I guess technically we could keep it in a field in […]. That is assuming none of the […].
- But currently it is the same in main.cpp, and do we want to deviate too much? In the long run the sampling code should be pulled out into its own file (it is not actually LLaMA related, either; any kind of LLM implemented in ggml should be able to use it).
- I vote that if it's good enough for main.cpp and isn't adding significant runtime to a given token response overall, it's probably fine. And I agree: probably better to make it recognizable for anyone else who comes wandering by if new sampling methods get added or (ideally) the sampling code gets rolled out somewhere else.
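To make the compiler issue being discussed concrete, here is a self-contained sketch. It uses a stand-in aggregate with the same shape as llama_token_data (id, logit, p) rather than the real definition, so it builds without llama.h; the point is what emplace_back can and cannot do with an aggregate before C++20.

// Sketch of the emplace_back vs. push_back discussion, with a stand-in
// aggregate mirroring the shape of llama_token_data.
#include <vector>

struct token_data {   // aggregate: no user-declared constructor
    int   id;
    float logit;
    float p;
};

int main() {
    std::vector<token_data> candidates;
    candidates.reserve(3);

    // Fine everywhere: build the aggregate explicitly, then copy/move it in.
    candidates.push_back(token_data{0, 1.5f, 0.0f});

    // Also fine: emplace_back with an already-built temporary.
    candidates.emplace_back(token_data{1, 0.5f, 0.0f});

    // Before C++20 this does NOT compile: emplace_back forwards the arguments to a
    // constructor token_data(int, float, float), which an aggregate does not have.
    // C++20 parenthesized aggregate initialization makes it legal on conforming compilers.
    // candidates.emplace_back(2, 0.25f, 0.0f);

    return 0;
}

Whether MSVC accepts the bare emplace_back form depends on the language standard it is building with, which is presumably why the thread settles on keeping the braced initializer.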
}

std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab);
for (llama_token token_id = 0; token_id < n_vocab; token_id++)
{
candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
}
llama_token_data_array candidates_p = {candidates.data(), candidates.size(), false};

// Apply penalties
float nl_logit = logits[llama_token_nl()];
auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), params.n_ctx);
llama_sample_repetition_penalty(ctx, &candidates_p,
last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
last_n_repeat, repeat_penalty);
llama_sample_frequency_and_presence_penalties(ctx, &candidates_p,
last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
last_n_repeat, alpha_frequency, alpha_presence);
if (!penalize_nl)
{
logits[llama_token_nl()] = nl_logit;
}

llama_token_data_array candidates_p = {candidates.data(), candidates.size(), false};

// Apply penalties
float nl_logit = logits[llama_token_nl()];
auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), params.n_ctx);
llama_sample_repetition_penalty(ctx, &candidates_p,
last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
last_n_repeat, repeat_penalty);
llama_sample_frequency_and_presence_penalties(ctx, &candidates_p,
last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
last_n_repeat, alpha_frequency, alpha_presence);
if (!penalize_nl)
if (temp <= 0)
{
// Greedy sampling
id = llama_sample_token_greedy(ctx, &candidates_p);
}
else
{
if (mirostat == 1)
{
logits[llama_token_nl()] = nl_logit;
static float mirostat_mu = 2.0f * mirostat_tau;
const int mirostat_m = 100;
llama_sample_temperature(ctx, &candidates_p, temp);
id = llama_sample_token_mirostat(ctx, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
}

if (temp <= 0)
else if (mirostat == 2)
{
// Greedy sampling
id = llama_sample_token_greedy(ctx, &candidates_p);
static float mirostat_mu = 2.0f * mirostat_tau;
llama_sample_temperature(ctx, &candidates_p, temp);
id = llama_sample_token_mirostat_v2(ctx, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu);
}
else
{
if (mirostat == 1)
{
static float mirostat_mu = 2.0f * mirostat_tau;
const int mirostat_m = 100;
llama_sample_temperature(ctx, &candidates_p, temp);
id = llama_sample_token_mirostat(ctx, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
}
else if (mirostat == 2)
{
static float mirostat_mu = 2.0f * mirostat_tau;
llama_sample_temperature(ctx, &candidates_p, temp);
id = llama_sample_token_mirostat_v2(ctx, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu);
}
else
{
// Temperature sampling
llama_sample_tail_free(ctx, &candidates_p, tfs_z, 1);
llama_sample_typical(ctx, &candidates_p, typical_p, 1);
llama_sample_top_p(ctx, &candidates_p, top_p, 1);
llama_sample_top_k(ctx, &candidates_p, top_k, 1);
llama_sample_temperature(ctx, &candidates_p, temp);
id = llama_sample_token(ctx, &candidates_p);
}
}
last_n_tokens.erase(last_n_tokens.begin());
last_n_tokens.push_back(id);
processed_tokens.push_back(id);
num_tokens_predicted++;
}

// add it to the context
embd.push_back(id);
result = id;
// decrement remaining sampling budget
--n_remain;
}
else
{
// some user input remains from prompt or interaction, forward it to processing
while (embd_inp.size() > n_consumed)
{
embd.push_back(embd_inp[n_consumed]);
last_n_tokens.erase(last_n_tokens.begin());
last_n_tokens.push_back(embd_inp[n_consumed]);
processed_tokens.push_back(embd_inp[n_consumed]);
++n_consumed;
if ((int)embd.size() >= params.n_batch)
{
break;
// Temperature sampling
llama_sample_tail_free(ctx, &candidates_p, tfs_z, 1);
llama_sample_typical(ctx, &candidates_p, typical_p, 1);
llama_sample_top_p(ctx, &candidates_p, top_p, 1);
llama_sample_top_k(ctx, &candidates_p, top_k, 1);
llama_sample_temperature(ctx, &candidates_p, temp);
id = llama_sample_token(ctx, &candidates_p);
}
}
last_n_tokens.erase(last_n_tokens.begin());
last_n_tokens.push_back(id);
num_tokens_predicted++;
}

// add it to the context
embd.push_back(id);
result = id;
// decrement remaining sampling budget
--n_remain;

if (!embd.empty() && embd.back() == llama_token_eos()) {
stopping_word = llama_token_to_str(ctx, embd.back());
has_next_token = false;
Review thread on the common_part() loop:

- warning: statement should be inside braces [readability-braces-around-statements]
- What statement? The `;`??
- for loops with `;` tend to be error prone, I would suggest `{}` instead. (basically agree with clang-tidy here)
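For anyone skimming the thread, a minimal sketch of the two spellings being debated, with plain ints standing in for llama_token. Behaviour is identical; the braced form just makes the empty loop body explicit, which is what clang-tidy is asking for.

#include <vector>

// As written in the PR: the loop body is the bare `;` at the end of the line.
static size_t common_part_terse(const std::vector<int> & a, const std::vector<int> & b) {
    size_t i;
    for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++);
    return i;
}

// With `{}` as the reviewers suggest: harder to misread, and clang-tidy is happy.
static size_t common_part_braced(const std::vector<int> & a, const std::vector<int> & b) {
    size_t i;
    for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {}
    return i;
}

int main() {
    std::vector<int> a = {1, 2, 3, 4};
    std::vector<int> b = {1, 2, 9};
    return (common_part_terse(a, b) == common_part_braced(a, b)) ? 0 : 1; // both return 2
}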