@@ -55,29 +55,51 @@ static int printe(const char * fmt, ...) {
 class Opt {
   public:
     int init(int argc, const char ** argv) {
+        ctx_params           = llama_context_default_params();
+        model_params         = llama_model_default_params();
+        context_size_default = ctx_params.n_batch;
+        ngl_default          = model_params.n_gpu_layers;
+        common_params_sampling sampling;
+        temperature_default = sampling.temp;
+
+        if (argc < 2) {
+            printe("Error: No arguments provided.\n");
+            print_help();
+            return 1;
+        }
+
         // Parse arguments
         if (parse(argc, argv)) {
             printe("Error: Failed to parse arguments.\n");
-            help();
+            print_help();
             return 1;
         }

         // If help is requested, show help and exit
-        if (help_) {
-            help();
+        if (help) {
+            print_help();
             return 2;
         }

+        ctx_params.n_batch        = context_size >= 0 ? context_size : context_size_default;
+        model_params.n_gpu_layers = ngl >= 0 ? ngl : ngl_default;
+        temperature               = temperature >= 0 ? temperature : temperature_default;
+
         return 0;  // Success
     }

+    llama_context_params ctx_params;
+    llama_model_params   model_params;
     std::string model_;
-    std::string user_;
-    int         context_size_ = -1, ngl_ = -1;
-    bool        verbose_      = false;
+    std::string user;
+    int         context_size = -1, ngl = -1;
+    float       temperature  = -1;
+    bool        verbose      = false;

   private:
-    bool help_ = false;
+    int   context_size_default = -1, ngl_default = -1;
+    float temperature_default  = -1;
+    bool  help                 = false;

     bool parse_flag(const char ** argv, int i, const char * short_opt, const char * long_opt) {
         return strcmp(argv[i], short_opt) == 0 || strcmp(argv[i], long_opt) == 0;
@@ -89,25 +111,40 @@ class Opt {
         }

         option_value = std::atoi(argv[++i]);
+
+        return 0;
+    }
+
+    int handle_option_with_value(int argc, const char ** argv, int & i, float & option_value) {
+        if (i + 1 >= argc) {
+            return 1;
+        }
+
+        option_value = std::atof(argv[++i]);
+
         return 0;
     }

     int parse(int argc, const char ** argv) {
         bool options_parsing = true;
         for (int i = 1, positional_args_i = 0; i < argc; ++i) {
             if (options_parsing && (strcmp(argv[i], "-c") == 0 || strcmp(argv[i], "--context-size") == 0)) {
-                if (handle_option_with_value(argc, argv, i, context_size_) == 1) {
+                if (handle_option_with_value(argc, argv, i, context_size) == 1) {
                     return 1;
                 }
             } else if (options_parsing && (strcmp(argv[i], "-n") == 0 || strcmp(argv[i], "--ngl") == 0)) {
-                if (handle_option_with_value(argc, argv, i, ngl_) == 1) {
+                if (handle_option_with_value(argc, argv, i, ngl) == 1) {
+                    return 1;
+                }
+            } else if (options_parsing && strcmp(argv[i], "--temp") == 0) {
+                if (handle_option_with_value(argc, argv, i, temperature) == 1) {
                     return 1;
                 }
             } else if (options_parsing &&
                        (parse_flag(argv, i, "-v", "--verbose") || parse_flag(argv, i, "-v", "--log-verbose"))) {
-                verbose_ = true;
+                verbose = true;
             } else if (options_parsing && parse_flag(argv, i, "-h", "--help")) {
-                help_ = true;
+                help = true;
                 return 0;
             } else if (options_parsing && strcmp(argv[i], "--") == 0) {
                 options_parsing = false;
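Aside on the overload added above: parse() now calls handle_option_with_value() for both the integer options (-c, -n) and the new --temp option, and C++ overload resolution picks the std::atoi or std::atof version from the type of the destination member, so the call sites stay identical. A minimal standalone sketch of that pattern follows; the option names and main() here are illustrative only, not the full run.cpp parser.

#include <cstdio>
#include <cstdlib>
#include <cstring>

// Integer-valued option: consume the next argv entry with std::atoi.
static int handle_option_with_value(int argc, const char ** argv, int & i, int & out) {
    if (i + 1 >= argc) { return 1; }
    out = std::atoi(argv[++i]);
    return 0;
}

// Float-valued option: same shape, but std::atof and a float destination.
static int handle_option_with_value(int argc, const char ** argv, int & i, float & out) {
    if (i + 1 >= argc) { return 1; }
    out = std::atof(argv[++i]);
    return 0;
}

int main(int argc, const char ** argv) {
    int   ngl         = -1;
    float temperature = -1;
    for (int i = 1; i < argc; ++i) {
        if (std::strcmp(argv[i], "--ngl") == 0) {
            if (handle_option_with_value(argc, argv, i, ngl)) { return 1; }          // int overload
        } else if (std::strcmp(argv[i], "--temp") == 0) {
            if (handle_option_with_value(argc, argv, i, temperature)) { return 1; }  // float overload
        }
    }
    std::printf("ngl=%d temp=%.2f\n", ngl, temperature);
    return 0;
}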
@@ -120,16 +157,16 @@ class Opt {
                 model_ = argv[i];
             } else if (positional_args_i == 1) {
                 ++positional_args_i;
-                user_ = argv[i];
+                user = argv[i];
             } else {
-                user_ += " " + std::string(argv[i]);
+                user += " " + std::string(argv[i]);
             }
         }

         return 0;
     }

-    void help() const {
+    void print_help() const {
         printf(
             "Description:\n"
             "  Runs a llm\n"
@@ -142,6 +179,8 @@ class Opt {
             "      Context size (default: %d)\n"
             "  -n, --ngl <value>\n"
             "      Number of GPU layers (default: %d)\n"
+            "  --temp <value>\n"
+            "      Temperature (default: %.1f)\n"
             "  -v, --verbose, --log-verbose\n"
             "      Set verbosity level to infinity (i.e. log all messages, useful for debugging)\n"
             "  -h, --help\n"
@@ -170,7 +209,7 @@ class Opt {
             "  llama-run file://some-file3.gguf\n"
             "  llama-run --ngl 999 some-file4.gguf\n"
             "  llama-run --ngl 999 some-file5.gguf Hello World\n",
-            llama_context_default_params().n_batch, llama_model_default_params().n_gpu_layers);
+            context_size_default, ngl_default, temperature_default);
     }
 };

@@ -495,12 +534,12 @@ class LlamaData {
             return 1;
         }

-        context = initialize_context(model, opt.context_size_);
+        context = initialize_context(model, opt);
         if (!context) {
             return 1;
         }

-        sampler = initialize_sampler();
+        sampler = initialize_sampler(opt);
         return 0;
     }

@@ -619,14 +658,12 @@ class LlamaData {
     // Initializes the model and returns a unique pointer to it
     llama_model_ptr initialize_model(Opt & opt) {
        ggml_backend_load_all();
-        llama_model_params model_params = llama_model_default_params();
-        model_params.n_gpu_layers       = opt.ngl_ >= 0 ? opt.ngl_ : model_params.n_gpu_layers;
         resolve_model(opt.model_);
         printe(
             "\r%*s"
             "\rLoading model",
             get_terminal_width(), " ");
-        llama_model_ptr model(llama_load_model_from_file(opt.model_.c_str(), model_params));
+        llama_model_ptr model(llama_load_model_from_file(opt.model_.c_str(), opt.model_params));
         if (!model) {
             printe("%s: error: unable to load model from file: %s\n", __func__, opt.model_.c_str());
         }
@@ -636,10 +673,8 @@ class LlamaData {
     }

     // Initializes the context with the specified parameters
-    llama_context_ptr initialize_context(const llama_model_ptr & model, const int n_ctx) {
-        llama_context_params ctx_params = llama_context_default_params();
-        ctx_params.n_ctx                = ctx_params.n_batch = n_ctx >= 0 ? n_ctx : ctx_params.n_batch;
-        llama_context_ptr context(llama_new_context_with_model(model.get(), ctx_params));
+    llama_context_ptr initialize_context(const llama_model_ptr & model, const Opt & opt) {
+        llama_context_ptr context(llama_new_context_with_model(model.get(), opt.ctx_params));
         if (!context) {
             printe("%s: error: failed to create the llama_context\n", __func__);
         }
@@ -648,10 +683,10 @@ class LlamaData {
     }

     // Initializes and configures the sampler
-    llama_sampler_ptr initialize_sampler() {
+    llama_sampler_ptr initialize_sampler(const Opt & opt) {
         llama_sampler_ptr sampler(llama_sampler_chain_init(llama_sampler_chain_default_params()));
         llama_sampler_chain_add(sampler.get(), llama_sampler_init_min_p(0.05f, 1));
-        llama_sampler_chain_add(sampler.get(), llama_sampler_init_temp(0.8f));
+        llama_sampler_chain_add(sampler.get(), llama_sampler_init_temp(opt.temperature));
         llama_sampler_chain_add(sampler.get(), llama_sampler_init_dist(LLAMA_DEFAULT_SEED));

         return sampler;
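For reference, the sampler hunk above replaces the hard-coded llama_sampler_init_temp(0.8f) with the temperature resolved in Opt::init() (the --temp value if given, otherwise the common_params_sampling default). A rough sketch of the same chain built as a free function, assuming the llama.h sampler API used in the hunk; make_sampler is an illustrative name, not part of run.cpp.

#include "llama.h"

// Build the sampler chain from the hunk above, with the temperature supplied
// by the caller instead of a hard-coded constant.
static llama_sampler * make_sampler(float temperature) {
    llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());
    llama_sampler_chain_add(chain, llama_sampler_init_min_p(0.05f, 1));           // min-p filtering
    llama_sampler_chain_add(chain, llama_sampler_init_temp(temperature));         // e.g. 0.8f, or the parsed --temp value
    llama_sampler_chain_add(chain, llama_sampler_init_dist(LLAMA_DEFAULT_SEED));  // final probabilistic pick
    return chain;  // caller releases it with llama_sampler_free()
}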
@@ -798,9 +833,9 @@ static int apply_chat_template_with_error_handling(LlamaData & llama_data, const
 }

 // Helper function to handle user input
-static int handle_user_input(std::string & user_input, const std::string & user_) {
-    if (!user_.empty()) {
-        user_input = user_;
+static int handle_user_input(std::string & user_input, const std::string & user) {
+    if (!user.empty()) {
+        user_input = user;
         return 0;  // No need for interactive input
     }

@@ -832,17 +867,17 @@ static bool is_stdout_a_terminal() {
 }

 // Function to tokenize the prompt
-static int chat_loop(LlamaData & llama_data, const std::string & user_) {
+static int chat_loop(LlamaData & llama_data, const std::string & user) {
     int prev_len = 0;
     llama_data.fmtted.resize(llama_n_ctx(llama_data.context.get()));
     static const bool stdout_a_terminal = is_stdout_a_terminal();
     while (true) {
         // Get user input
         std::string user_input;
-        while (handle_user_input(user_input, user_)) {
+        while (handle_user_input(user_input, user)) {
         }

-        add_message("user", user_.empty() ? user_input : user_, llama_data);
+        add_message("user", user.empty() ? user_input : user, llama_data);
         int new_len;
         if (apply_chat_template_with_error_handling(llama_data, true, new_len) < 0) {
             return 1;
@@ -854,7 +889,7 @@ static int chat_loop(LlamaData & llama_data, const std::string & user_) {
             return 1;
         }

-        if (!user_.empty()) {
+        if (!user.empty()) {
             break;
         }

@@ -869,7 +904,7 @@ static int chat_loop(LlamaData & llama_data, const std::string & user_) {

 static void log_callback(const enum ggml_log_level level, const char * text, void * p) {
     const Opt * opt = static_cast<Opt *>(p);
-    if (opt->verbose_ || level == GGML_LOG_LEVEL_ERROR) {
+    if (opt->verbose || level == GGML_LOG_LEVEL_ERROR) {
         printe("%s", text);
     }
 }
@@ -890,11 +925,11 @@ int main(int argc, const char ** argv) {
     }

     if (!is_stdin_a_terminal()) {
-        if (!opt.user_.empty()) {
-            opt.user_ += "\n\n";
+        if (!opt.user.empty()) {
+            opt.user += "\n\n";
         }

-        opt.user_ += read_pipe_data();
+        opt.user += read_pipe_data();
     }

     llama_log_set(log_callback, &opt);
@@ -903,7 +938,7 @@ int main(int argc, const char ** argv) {
         return 1;
     }

-    if (chat_loop(llama_data, opt.user_)) {
+    if (chat_loop(llama_data, opt.user)) {
         return 1;
     }
