Be more strict about converting float to double by sw · Pull Request #458 · ggml-org/llama.cpp · GitHub
[go: up one dir, main page]

Skip to content

Be more strict about converting float to double #458

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits
Mar 28, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next commit
Be more strict about converting float to double
  • Loading branch information
sw authored and ggerganov committed Mar 28, 2023
commit 54b75a77fb58292c6dde89d213e82fae5b171d68
5 changes: 3 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -124,17 +124,18 @@ if (LLAMA_ALL_WARNINGS)
-Wall
-Wextra
-Wpedantic
-Wshadow
-Wcast-qual
-Wdouble-promotion
-Wshadow
-Wstrict-prototypes
-Wpointer-arith
-Wno-unused-function
)
set(cxx_flags
-Wall
-Wextra
-Wpedantic
-Wcast-qual
-Wdouble-promotion
)
else()
# todo : msvc
Expand Down
6 changes: 4 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,10 @@ endif
#

# keep standard at C11 and C++11
CFLAGS = -I. -O3 -DNDEBUG -std=c11 -fPIC
CXXFLAGS = -I. -I./examples -O3 -DNDEBUG -std=c++11 -fPIC
CFLAGS = -I. -O3 -DNDEBUG -std=c11 -fPIC \
-Wall -Wextra -Wpedantic -Wcast-qual -Wdouble-promotion -Wshadow -Wstrict-prototypes -Wpointer-arith
CXXFLAGS = -I. -I./examples -O3 -DNDEBUG -std=c++11 -fPIC \
-Wall -Wextra -Wpedantic -Wcast-qual -Wdouble-promotion
LDFLAGS =

# OS specific
Expand Down
6 changes: 3 additions & 3 deletions examples/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -215,13 +215,13 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
fprintf(stderr, " prompt file to start generation.\n");
fprintf(stderr, " -n N, --n_predict N number of tokens to predict (default: %d, -1 = infinity)\n", params.n_predict);
fprintf(stderr, " --top_k N top-k sampling (default: %d)\n", params.top_k);
fprintf(stderr, " --top_p N top-p sampling (default: %.1f)\n", params.top_p);
fprintf(stderr, " --top_p N top-p sampling (default: %.1f)\n", (double)params.top_p);
fprintf(stderr, " --repeat_last_n N last n tokens to consider for penalize (default: %d)\n", params.repeat_last_n);
fprintf(stderr, " --repeat_penalty N penalize repeat sequence of tokens (default: %.1f)\n", params.repeat_penalty);
fprintf(stderr, " --repeat_penalty N penalize repeat sequence of tokens (default: %.1f)\n", (double)params.repeat_penalty);
fprintf(stderr, " -c N, --ctx_size N size of the prompt context (default: %d)\n", params.n_ctx);
fprintf(stderr, " --ignore-eos ignore end of stream token and continue generating\n");
fprintf(stderr, " --memory_f32 use f32 instead of f16 for memory key+value\n");
fprintf(stderr, " --temp N temperature (default: %.1f)\n", params.temp);
fprintf(stderr, " --temp N temperature (default: %.1f)\n", (double)params.temp);
fprintf(stderr, " --n_parts N number of model parts (default: -1 = determine from dimensions)\n");
fprintf(stderr, " -b N, --batch_size N batch size for prompt processing (default: %d)\n", params.n_batch);
fprintf(stderr, " --perplexity compute perplexity over the prompt\n");
Expand Down
2 changes: 1 addition & 1 deletion examples/embedding/embedding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ int main(int argc, char ** argv) {
const auto embeddings = llama_get_embeddings(ctx);

for (int i = 0; i < n_embd; i++) {
printf("%f ", embeddings[i]);
printf("%f ", (double)embeddings[i]);
}
printf("\n");
}
Expand Down
11 changes: 6 additions & 5 deletions examples/main/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,8 @@ int main(int argc, char ** argv) {
fprintf(stderr, "Input prefix: '%s'\n", params.input_prefix.c_str());
}
}
fprintf(stderr, "sampling: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n", params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty);
fprintf(stderr, "sampling: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n",
(double)params.temp, params.top_k, (double)params.top_p, params.repeat_last_n, (double)params.repeat_penalty);
fprintf(stderr, "generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
fprintf(stderr, "\n\n");

Expand Down Expand Up @@ -274,10 +275,10 @@ int main(int argc, char ** argv) {

if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
// out of user input, sample next token
const float top_k = params.top_k;
const float top_p = params.top_p;
const float temp = params.temp;
const float repeat_penalty = params.repeat_penalty;
const int top_k = params.top_k;
const double top_p = (double)params.top_p;
const double temp = (double)params.temp;
const double repeat_penalty = (double)params.repeat_penalty;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not make the params struct have doubles?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't want to change it too much, but we could alternatively make everything involved with the logits a float instead, except maybe for sum and cumsum in llama_sample_top_p_top_k.

After all, these three parameters are set by the user with 2 decimal places or so...


llama_token id = 0;

Expand Down
4 changes: 2 additions & 2 deletions examples/quantize/quantize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ int main(int argc, char ** argv) {
const int64_t t_main_end_us = ggml_time_us();

printf("\n");
printf("%s: quantize time = %8.2f ms\n", __func__, t_quantize_us/1000.0f);
printf("%s: total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0f);
printf("%s: quantize time = %8.2f ms\n", __func__, t_quantize_us/1000.0);
printf("%s: total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0);
}

return 0;
Expand Down
Loading
0