quantize-stats: add option to test against reference quantization · unbounded/llama.cpp@63cfa43 · GitHub
[go: up one dir, main page]

Skip to content

Commit 63cfa43

Browse files
committed
quantize-stats: add option to test against reference quantization
Expose reference quantization implementation and add option to use it for tests.
1 parent: d491507 · commit: 63cfa43

File tree

3 files changed

21 additions and 7 deletions (lines changed)

3 files changed

21 additions and 7 deletions (lines changed)

examples/quantize-stats/quantize-stats.cpp

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ struct quantize_stats_params {
2424
bool verbose = false;
2525
bool per_layer_stats = false;
2626
bool print_histogram = false;
27+
bool reference = false;
2728
std::vector<std::string> include_layers;
2829
std::vector<std::string> exclude_layers;
2930
std::vector<enum ggml_type> include_types;
@@ -49,6 +50,8 @@ void quantize_stats_print_usage(int /*argc*/, char ** argv) {
4950
fprintf(stderr, " -h, --help show this help message and exit\n");
5051
fprintf(stderr, " -m FNAME, --model FNAME\n");
5152
fprintf(stderr, " model path (default: %s)\n", params.model.c_str());
53+
fprintf(stderr, " -r, --reference\n");
54+
fprintf(stderr, " use reference implementation (default: false)\n");
5255
fprintf(stderr, " -v, --verbose\n");
5356
fprintf(stderr, " verbose output (default: false)\n");
5457
fprintf(stderr, " -p, --per-layer-stats\n");
@@ -135,6 +138,7 @@ void test_roundtrip_on_layer(
135138
std::string & name,
136139
bool print_layer_stats,
137140
const quantize_fns_t & qfns,
141+
bool use_reference,
138142
const ggml_tensor * layer,
139143
float * input_scratch,
140144
char *quantized_scratch,
@@ -156,7 +160,11 @@ void test_roundtrip_on_layer(
156160
input_scratch = ggml_get_data_f32(layer) + offset;
157161
}
158162

159-
qfns.quantize_row_q(input_scratch, quantized_scratch, chunk_size);
163+
if (use_reference) {
164+
qfns.quantize_row_q_reference(input_scratch, quantized_scratch, chunk_size);
165+
} else {
166+
qfns.quantize_row_q(input_scratch, quantized_scratch, chunk_size);
167+
}
160168
qfns.dequantize_row_q(quantized_scratch, output_scratch, chunk_size);
161169

162170
update_error_stats(chunk_size, input_scratch, output_scratch, total_error);
@@ -184,6 +192,8 @@ int main(int argc, char ** argv) {
184192
if (arg == "-h" || arg == "--help") {
185193
quantize_stats_print_usage(argc, argv);
186194
exit(0);
195+
} else if (arg == "-r" || arg == "--reference") {
196+
params.reference = true;
187197
} else if (arg == "-v") {
188198
params.verbose = true;
189199
} else if (arg == "-p" || arg == "--per-layer-stats") {
@@ -320,6 +330,7 @@ int main(int argc, char ** argv) {
320330
layer_name,
321331
params.per_layer_stats,
322332
qfns,
333+
params.reference,
323334
kv_tensor.second,
324335
input_scratch.data(),
325336
quantized_scratch.data(),

ggml.c

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6499,14 +6499,16 @@ static void ggml_compute_forward_mul_mat_f16_f32(
64996499

65006500
static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
65016501
[GGML_TYPE_Q4_0] = {
6502-
.dequantize_row_q = dequantize_row_q4_0,
6503-
.quantize_row_q = quantize_row_q4_0,
6504-
.vec_dot_q = ggml_vec_dot_q4_0,
6502+
.dequantize_row_q = dequantize_row_q4_0,
6503+
.quantize_row_q = quantize_row_q4_0,
6504+
.quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_0_reference,
6505+
.vec_dot_q = ggml_vec_dot_q4_0,
65056506
},
65066507
[GGML_TYPE_Q4_1] = {
6507-
.dequantize_row_q = dequantize_row_q4_1,
6508-
.quantize_row_q = quantize_row_q4_1,
6509-
.vec_dot_q = ggml_vec_dot_q4_1,
6508+
.dequantize_row_q = dequantize_row_q4_1,
6509+
.quantize_row_q = quantize_row_q4_1,
6510+
.quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_1_reference,
6511+
.vec_dot_q = ggml_vec_dot_q4_1,
65106512
},
65116513
};
65126514

ggml_internal.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ typedef void (*vec_dot_q_t)(const int n, float * restrict s, const void * restri
1515
typedef struct {
1616
dequantize_row_q_t dequantize_row_q;
1717
quantize_row_q_t quantize_row_q;
18+
quantize_row_q_t quantize_row_q_reference;
1819
vec_dot_q_t vec_dot_q;
1920
} quantize_fns_t;
2021

0 commit comments

Comments
 (0)
0