llama-run : include temperature option (#10899) · mglambda/llama.cpp@1e2eee3 · GitHub

Commit 1e2eee3

ericcurtin authored and mglambda committed

llama-run : include temperature option (ggml-org#10899)

This commit updates the `examples/run/README.md` file to include a new option for setting the temperature and updates the `run.cpp` file to parse this option.

Signed-off-by: Eric Curtin <ecurtin@redhat.com>

1 parent 99bb723 · commit 1e2eee3

File tree (2 files changed: +75 -38 lines)

  examples/run/README.md
  examples/run/run.cpp

examples/run/README.md (2 additions & 0 deletions)

@@ -19,6 +19,8 @@ Options:
       Context size (default: 2048)
   -n, --ngl <value>
       Number of GPU layers (default: 0)
+  --temp <value>
+      Temperature (default: 0.8)
   -v, --verbose, --log-verbose
       Set verbosity level to infinity (i.e. log all messages, useful for debugging)
   -h, --help
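As an illustration of the new flag (the model file and prompt below are placeholders, not taken from the commit), an invocation might look like:

    llama-run --temp 0.2 some-model.gguf "Why is the sky blue?"

Lower values make sampling more deterministic and higher values more random; when --temp is omitted, the default of 0.8 is used.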

examples/run/run.cpp (73 additions & 38 deletions)

@@ -55,29 +55,51 @@ static int printe(const char * fmt, ...) {
 class Opt {
   public:
     int init(int argc, const char ** argv) {
+        ctx_params = llama_context_default_params();
+        model_params = llama_model_default_params();
+        context_size_default = ctx_params.n_batch;
+        ngl_default = model_params.n_gpu_layers;
+        common_params_sampling sampling;
+        temperature_default = sampling.temp;
+
+        if (argc < 2) {
+            printe("Error: No arguments provided.\n");
+            print_help();
+            return 1;
+        }
+
         // Parse arguments
         if (parse(argc, argv)) {
             printe("Error: Failed to parse arguments.\n");
-            help();
+            print_help();
             return 1;
         }

         // If help is requested, show help and exit
-        if (help_) {
-            help();
+        if (help) {
+            print_help();
             return 2;
         }

+        ctx_params.n_batch = context_size >= 0 ? context_size : context_size_default;
+        model_params.n_gpu_layers = ngl >= 0 ? ngl : ngl_default;
+        temperature = temperature >= 0 ? temperature : temperature_default;
+
         return 0; // Success
     }

+    llama_context_params ctx_params;
+    llama_model_params model_params;
     std::string model_;
-    std::string user_;
-    int context_size_ = -1, ngl_ = -1;
-    bool verbose_ = false;
+    std::string user;
+    int context_size = -1, ngl = -1;
+    float temperature = -1;
+    bool verbose = false;

   private:
-    bool help_ = false;
+    int context_size_default = -1, ngl_default = -1;
+    float temperature_default = -1;
+    bool help = false;

     bool parse_flag(const char ** argv, int i, const char * short_opt, const char * long_opt) {
         return strcmp(argv[i], short_opt) == 0 || strcmp(argv[i], long_opt) == 0;
@@ -89,25 +111,40 @@ class Opt {
         }

         option_value = std::atoi(argv[++i]);
+
+        return 0;
+    }
+
+    int handle_option_with_value(int argc, const char ** argv, int & i, float & option_value) {
+        if (i + 1 >= argc) {
+            return 1;
+        }
+
+        option_value = std::atof(argv[++i]);
+
         return 0;
     }

     int parse(int argc, const char ** argv) {
         bool options_parsing = true;
         for (int i = 1, positional_args_i = 0; i < argc; ++i) {
             if (options_parsing && (strcmp(argv[i], "-c") == 0 || strcmp(argv[i], "--context-size") == 0)) {
-                if (handle_option_with_value(argc, argv, i, context_size_) == 1) {
+                if (handle_option_with_value(argc, argv, i, context_size) == 1) {
                     return 1;
                 }
             } else if (options_parsing && (strcmp(argv[i], "-n") == 0 || strcmp(argv[i], "--ngl") == 0)) {
-                if (handle_option_with_value(argc, argv, i, ngl_) == 1) {
+                if (handle_option_with_value(argc, argv, i, ngl) == 1) {
+                    return 1;
+                }
+            } else if (options_parsing && strcmp(argv[i], "--temp") == 0) {
+                if (handle_option_with_value(argc, argv, i, temperature) == 1) {
                     return 1;
                 }
             } else if (options_parsing &&
                        (parse_flag(argv, i, "-v", "--verbose") || parse_flag(argv, i, "-v", "--log-verbose"))) {
-                verbose_ = true;
+                verbose = true;
             } else if (options_parsing && parse_flag(argv, i, "-h", "--help")) {
-                help_ = true;
+                help = true;
                 return 0;
             } else if (options_parsing && strcmp(argv[i], "--") == 0) {
                 options_parsing = false;
@@ -120,16 +157,16 @@ class Opt {
                 model_ = argv[i];
             } else if (positional_args_i == 1) {
                 ++positional_args_i;
-                user_ = argv[i];
+                user = argv[i];
             } else {
-                user_ += " " + std::string(argv[i]);
+                user += " " + std::string(argv[i]);
             }
         }

         return 0;
     }

-    void help() const {
+    void print_help() const {
         printf(
             "Description:\n"
             "  Runs a llm\n"
@@ -142,6 +179,8 @@ class Opt {
             "      Context size (default: %d)\n"
             "  -n, --ngl <value>\n"
             "      Number of GPU layers (default: %d)\n"
+            "  --temp <value>\n"
+            "      Temperature (default: %.1f)\n"
             "  -v, --verbose, --log-verbose\n"
             "      Set verbosity level to infinity (i.e. log all messages, useful for debugging)\n"
             "  -h, --help\n"
@@ -170,7 +209,7 @@ class Opt {
             "  llama-run file://some-file3.gguf\n"
             "  llama-run --ngl 999 some-file4.gguf\n"
             "  llama-run --ngl 999 some-file5.gguf Hello World\n",
-            llama_context_default_params().n_batch, llama_model_default_params().n_gpu_layers);
+            context_size_default, ngl_default, temperature_default);
     }
 };

@@ -495,12 +534,12 @@ class LlamaData {
             return 1;
         }

-        context = initialize_context(model, opt.context_size_);
+        context = initialize_context(model, opt);
         if (!context) {
             return 1;
         }

-        sampler = initialize_sampler();
+        sampler = initialize_sampler(opt);
         return 0;
     }

@@ -619,14 +658,12 @@ class LlamaData {
     // Initializes the model and returns a unique pointer to it
     llama_model_ptr initialize_model(Opt & opt) {
         ggml_backend_load_all();
-        llama_model_params model_params = llama_model_default_params();
-        model_params.n_gpu_layers = opt.ngl_ >= 0 ? opt.ngl_ : model_params.n_gpu_layers;
         resolve_model(opt.model_);
         printe(
             "\r%*s"
             "\rLoading model",
             get_terminal_width(), " ");
-        llama_model_ptr model(llama_load_model_from_file(opt.model_.c_str(), model_params));
+        llama_model_ptr model(llama_load_model_from_file(opt.model_.c_str(), opt.model_params));
         if (!model) {
             printe("%s: error: unable to load model from file: %s\n", __func__, opt.model_.c_str());
         }
@@ -636,10 +673,8 @@ class LlamaData {
     }

     // Initializes the context with the specified parameters
-    llama_context_ptr initialize_context(const llama_model_ptr & model, const int n_ctx) {
-        llama_context_params ctx_params = llama_context_default_params();
-        ctx_params.n_ctx = ctx_params.n_batch = n_ctx >= 0 ? n_ctx : ctx_params.n_batch;
-        llama_context_ptr context(llama_new_context_with_model(model.get(), ctx_params));
+    llama_context_ptr initialize_context(const llama_model_ptr & model, const Opt & opt) {
+        llama_context_ptr context(llama_new_context_with_model(model.get(), opt.ctx_params));
         if (!context) {
             printe("%s: error: failed to create the llama_context\n", __func__);
         }
@@ -648,10 +683,10 @@ class LlamaData {
     }

     // Initializes and configures the sampler
-    llama_sampler_ptr initialize_sampler() {
+    llama_sampler_ptr initialize_sampler(const Opt & opt) {
         llama_sampler_ptr sampler(llama_sampler_chain_init(llama_sampler_chain_default_params()));
         llama_sampler_chain_add(sampler.get(), llama_sampler_init_min_p(0.05f, 1));
-        llama_sampler_chain_add(sampler.get(), llama_sampler_init_temp(0.8f));
+        llama_sampler_chain_add(sampler.get(), llama_sampler_init_temp(opt.temperature));
         llama_sampler_chain_add(sampler.get(), llama_sampler_init_dist(LLAMA_DEFAULT_SEED));

         return sampler;
@@ -798,9 +833,9 @@ static int apply_chat_template_with_error_handling(LlamaData & llama_data, const
 }

 // Helper function to handle user input
-static int handle_user_input(std::string & user_input, const std::string & user_) {
-    if (!user_.empty()) {
-        user_input = user_;
+static int handle_user_input(std::string & user_input, const std::string & user) {
+    if (!user.empty()) {
+        user_input = user;
         return 0; // No need for interactive input
     }

@@ -832,17 +867,17 @@ static bool is_stdout_a_terminal() {
 }

 // Function to tokenize the prompt
-static int chat_loop(LlamaData & llama_data, const std::string & user_) {
+static int chat_loop(LlamaData & llama_data, const std::string & user) {
     int prev_len = 0;
     llama_data.fmtted.resize(llama_n_ctx(llama_data.context.get()));
     static const bool stdout_a_terminal = is_stdout_a_terminal();
     while (true) {
         // Get user input
         std::string user_input;
-        while (handle_user_input(user_input, user_)) {
+        while (handle_user_input(user_input, user)) {
         }

-        add_message("user", user_.empty() ? user_input : user_, llama_data);
+        add_message("user", user.empty() ? user_input : user, llama_data);
         int new_len;
         if (apply_chat_template_with_error_handling(llama_data, true, new_len) < 0) {
             return 1;
@@ -854,7 +889,7 @@ static int chat_loop(LlamaData & llama_data, const std::string & user_) {
             return 1;
         }

-        if (!user_.empty()) {
+        if (!user.empty()) {
             break;
         }

@@ -869,7 +904,7 @@ static int chat_loop(LlamaData & llama_data, const std::string & user_) {

 static void log_callback(const enum ggml_log_level level, const char * text, void * p) {
     const Opt * opt = static_cast<Opt *>(p);
-    if (opt->verbose_ || level == GGML_LOG_LEVEL_ERROR) {
+    if (opt->verbose || level == GGML_LOG_LEVEL_ERROR) {
         printe("%s", text);
     }
 }
@@ -890,11 +925,11 @@ int main(int argc, const char ** argv) {
     }

     if (!is_stdin_a_terminal()) {
-        if (!opt.user_.empty()) {
-            opt.user_ += "\n\n";
+        if (!opt.user.empty()) {
+            opt.user += "\n\n";
         }

-        opt.user_ += read_pipe_data();
+        opt.user += read_pipe_data();
     }

     llama_log_set(log_callback, &opt);
@@ -903,7 +938,7 @@ int main(int argc, const char ** argv) {
         return 1;
     }

-    if (chat_loop(llama_data, opt.user_)) {
+    if (chat_loop(llama_data, opt.user)) {
         return 1;
     }

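The run.cpp changes route the parsed value into the sampler chain, replacing the previously hardcoded 0.8f. Below is a minimal standalone sketch of that pattern against the llama.cpp C API; the helper name make_sampler and the raw-pointer return are illustrative simplifications, not code from the commit:

    #include "llama.h"

    // Build a sampler chain whose temperature comes from the caller
    // (e.g. a value parsed from --temp) instead of a hardcoded constant.
    static llama_sampler * make_sampler(float temperature) {
        llama_sampler * sampler = llama_sampler_chain_init(llama_sampler_chain_default_params());
        llama_sampler_chain_add(sampler, llama_sampler_init_min_p(0.05f, 1));          // drop low-probability tokens
        llama_sampler_chain_add(sampler, llama_sampler_init_temp(temperature));        // user-controlled randomness
        llama_sampler_chain_add(sampler, llama_sampler_init_dist(LLAMA_DEFAULT_SEED)); // sample from the distribution
        return sampler; // caller releases it with llama_sampler_free()
    }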

0 commit comments