@@ -24,6 +24,7 @@ struct quantize_stats_params {
24
24
bool verbose = false ;
25
25
bool per_layer_stats = false ;
26
26
bool print_histogram = false ;
27
+ bool reference = false ;
27
28
std::vector<std::string> include_layers;
28
29
std::vector<std::string> exclude_layers;
29
30
std::vector<enum ggml_type> include_types;
@@ -49,6 +50,8 @@ void quantize_stats_print_usage(int /*argc*/, char ** argv) {
49
50
fprintf (stderr, " -h, --help show this help message and exit\n " );
50
51
fprintf (stderr, " -m FNAME, --model FNAME\n " );
51
52
fprintf (stderr, " model path (default: %s)\n " , params.model .c_str ());
53
+ fprintf (stderr, " -r, --reference\n " );
54
+ fprintf (stderr, " use reference implementation (default: false)\n " );
52
55
fprintf (stderr, " -v, --verbose\n " );
53
56
fprintf (stderr, " verbose output (default: false)\n " );
54
57
fprintf (stderr, " -p, --per-layer-stats\n " );
@@ -135,6 +138,7 @@ void test_roundtrip_on_layer(
135
138
std::string & name,
136
139
bool print_layer_stats,
137
140
const quantize_fns_t & qfns,
141
+ bool use_reference,
138
142
const ggml_tensor * layer,
139
143
float * input_scratch,
140
144
char *quantized_scratch,
@@ -156,7 +160,11 @@ void test_roundtrip_on_layer(
156
160
input_scratch = ggml_get_data_f32 (layer) + offset;
157
161
}
158
162
159
- qfns.quantize_row_q (input_scratch, quantized_scratch, chunk_size);
163
+ if (use_reference) {
164
+ qfns.quantize_row_q_reference (input_scratch, quantized_scratch, chunk_size);
165
+ } else {
166
+ qfns.quantize_row_q (input_scratch, quantized_scratch, chunk_size);
167
+ }
160
168
qfns.dequantize_row_q (quantized_scratch, output_scratch, chunk_size);
161
169
162
170
update_error_stats (chunk_size, input_scratch, output_scratch, total_error);
@@ -184,6 +192,8 @@ int main(int argc, char ** argv) {
184
192
if (arg == " -h" || arg == " --help" ) {
185
193
quantize_stats_print_usage (argc, argv);
186
194
exit (0 );
195
+ } else if (arg == " -r" || arg == " --reference" ) {
196
+ params.reference = true ;
187
197
} else if (arg == " -v" ) {
188
198
params.verbose = true ;
189
199
} else if (arg == " -p" || arg == " --per-layer-stats" ) {
@@ -320,6 +330,7 @@ int main(int argc, char ** argv) {
320
330
layer_name,
321
331
params.per_layer_stats ,
322
332
qfns,
333
+ params.reference ,
323
334
kv_tensor.second ,
324
335
input_scratch.data (),
325
336
quantized_scratch.data (),
0 commit comments