k-quants by ikawrakow · Pull Request #1684 · ggml-org/llama.cpp · GitHub

k-quants #1684


Merged
32 commits merged into master from ik/k_quants on Jun 5, 2023

Changes from 1 commit

Commits (32)
8673a41
Starting to add k-quantization to ggml
May 27, 2023
b4f7134
Adding Q3_K and Q8_K (de)-quantization
May 27, 2023
c93cce3
Q3_K now working on CUDA and AVX2/scalar
May 28, 2023
a3c0673
Some improvement for Q3_K on CUDA
May 28, 2023
3d8b1de
Some more CUDA optimizations for Q3_K
May 29, 2023
a0b8e9f
Adding Q4_K - scalar, AVX2, CUDA
May 29, 2023
cf221af
Adding Q6_K - scalar, AVX2, CUDA
May 29, 2023
b835d0f
Adding Q5_K - scalar, AVX2, CUDA
May 29, 2023
5c5191a
Per convention, all QX_K quantizations use Q5_K for output.weight
May 29, 2023
d537b97
Adding quantization mixes
May 29, 2023
54f808d
Quantization mixes: didn't quite get what I wanted in the last commit
May 29, 2023
a2533a7
Q4_K dot product for ARM_NEON
May 30, 2023
5ca15ce
Q6_K dot product for ARM_NEON
May 30, 2023
a197eb5
Q5_K dot product for ARM_NEON
May 30, 2023
13264fa
Adding Q3_K dot for ARM_NEON
May 30, 2023
4faa040
A very slightly faster ARM_NEON Q3_K dot
May 31, 2023
b439efb
Adding Q2_K - just CUDA for now
May 31, 2023
8516fdf
Adding scalar and AVX2 Q2_K dot
May 31, 2023
6ec7057
Adding ARM_NEON Q2_K dot
May 31, 2023
7bcc376
A slightly faster ARM_NEON Q2_K dot
Jun 1, 2023
e51ce72
Fixed bug in Q2_K CUDA dot product kernel
Jun 1, 2023
c5959d5
Don't print zeros/NaNs when no count histogram has been collected
Jun 1, 2023
9a9c5a0
A 10% faster CUDA vector dot kernel for Q3_K
Jun 1, 2023
894210a
A slightly faster Q4_K AVX2 dot product
Jun 2, 2023
abd99a8
A slightly faster ARM_NEON Q4_K dot product
Jun 3, 2023
8f5d42d
Minor
Jun 3, 2023
6ef1382
Fix quantization error test
Jun 3, 2023
0a71a4e
Fix docker build
Jun 3, 2023
431693c
Added forgotten ggml.o dependence on k_quants.h to the Makefile
Jun 4, 2023
32a5f3a
Had unintentionally committed the Makefile with -Ofast enabled
Jun 4, 2023
12d4344
ggml : rename k_quants -> ggml-quants-k, use lowercase in code
ggerganov Jun 5, 2023
af275fa
Merge branch 'master' into ik/k_quants
ggerganov Jun 5, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Adding Q4_K - scalar, AVX2, CUDA
Performance is the same or perhaps very slightly better than Q4_0 on the CPU.
On the GPU, single-token prediction is ~10% faster than Q4_0, while batch mode
(perplexity computation) is about the same.
Iwan Kawrakow committed Jun 3, 2023
commit a0b8e9f3c90e482dbe0ca82f45f585de24f1ba67
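For orientation before the diff: a Q4_K super-block groups QK_K weights into eight 32-value sub-blocks, each with its own 6-bit scale and 6-bit min, plus two fp16 super-block scales (`d`, `dmin`) that turn those 6-bit values into real scale/offset pairs. The sketch below reconstructs the layout and a CPU reference dequantization from the `block_q4_K` struct and the CUDA kernel added in this commit. It is illustrative only: it assumes `QK_K == 256`, stubs out the fp16 conversion, and the `_ref` names are mine, not functions from the PR.

```c
#include <stdint.h>

#define QK_K 256                       // assumed super-block size

typedef uint16_t ggml_fp16_t;          // placeholder; ggml stores IEEE half floats here
static float fp16_to_fp32(ggml_fp16_t h) { (void)h; return 1.0f; }  // stub for illustration

typedef struct {
    ggml_fp16_t d;                     // super-block scale for the quantized scales
    ggml_fp16_t dmin;                  // super-block scale for the quantized mins
    uint8_t scales[3*QK_K/64];         // 8 x (6-bit scale, 6-bit min) packed into 12 bytes
    uint8_t qs[QK_K/2];                // 256 x 4-bit quants, two per byte
} block_q4_K;

// Unpack the 6-bit scale/min pair of sub-block j (0..7); mirrors get_scale_min_k4 below.
static void get_scale_min_k4(int j, const uint8_t * q, uint8_t * d, uint8_t * m) {
    if (j < 4) {
        *d = q[j] & 63;  *m = q[j + 4] & 63;
    } else {
        *d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
        *m = (q[j+4] >>  4) | ((q[j-0] >> 6) << 4);
    }
}

// Reference dequantization of one super-block: y = d*scale*quant - dmin*min, per 32-value sub-block.
static void dequantize_block_q4_K_ref(const block_q4_K * x, float * y) {
    const float dall = fp16_to_fp32(x->d);
    const float dmin = fp16_to_fp32(x->dmin);
    const uint8_t * q = x->qs;
    for (int j = 0; j < QK_K/64; ++j) {            // four groups of 64 output values
        uint8_t sc, m;
        get_scale_min_k4(2*j + 0, x->scales, &sc, &m);
        const float d1 = dall * sc, m1 = dmin * m;
        get_scale_min_k4(2*j + 1, x->scales, &sc, &m);
        const float d2 = dall * sc, m2 = dmin * m;
        for (int l = 0; l < 32; ++l) {
            y[l +  0] = d1 * (q[l] & 0xF) - m1;    // low nibbles: first 32 values of the group
            y[l + 32] = d2 * (q[l] >>  4) - m2;    // high nibbles: next 32 values
        }
        q += 32; y += 64;
    }
}
```

With `QK_K == 256`, one super-block is 2·2 + 12 + 128 = 144 bytes for 256 weights, i.e. 4.5 bits per weight.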
1 change: 1 addition & 0 deletions examples/quantize/quantize.cpp
@@ -13,6 +13,7 @@ static const std::map<std::string, llama_ftype> LLAMA_FTYPE_MAP = {
{"q5_1", LLAMA_FTYPE_MOSTLY_Q5_1},
{"q8_0", LLAMA_FTYPE_MOSTLY_Q8_0},
{"q3_K", LLAMA_FTYPE_MOSTLY_Q3_K},
{"q4_K", LLAMA_FTYPE_MOSTLY_Q4_K},
};

bool try_parse_ftype(const std::string & ftype_str, llama_ftype & ftype, std::string & ftype_str_out) {
101 changes: 101 additions & 0 deletions ggml-cuda.cu
@@ -97,6 +97,14 @@ typedef struct {
} block_q3_K;
static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + 11 * QK_K / 64, "wrong q3_K block size/padding");

typedef struct {
half d; // super-block scale for quantized scales
half dmin; // super-block scale for quantized mins
uint8_t scales[3*QK_K/64]; // scales, quantized with 6 bits
uint8_t qs[QK_K/2]; // 4-bit quants
} block_q4_K;
static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2, "wrong q4_K block size/padding");

#define WARP_SIZE 32

#define CUDA_MUL_BLOCK_SIZE 256
@@ -261,6 +269,84 @@ static __device__ void vec_dot_q3_K(const void * vx, const int ib, const int iqs
result = sum * scale;
}

static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) {
if (j < 4) {
d = q[j] & 63; m = q[j + 4] & 63;
} else {
d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
}
}

static __global__ void dequantize_block_q4_K(const void * vx, float * yy) {
const block_q4_K * x = (const block_q4_K *) vx;

const int i = blockIdx.x;

//// assume 64 threads - this is very slightly better than the one below
//const int tid = threadIdx.x;
//const int il = tid/16;
//const int ir = tid%16;
//const int is = 2*il;
//const int n = 2;

// assume 32 threads
const int tid = threadIdx.x;
const int il = tid/8;
const int ir = tid%8;
const int is = 2*il;
const int n = 4;

float * y = yy + i*QK_K + 64*il + n*ir;

const float dall = x[i].d;
const float dmin = x[i].dmin;

const uint8_t * q = x[i].qs + 32*il + n*ir;

uint8_t sc, m;
get_scale_min_k4(is + 0, x[i].scales, sc, m);
const float d1 = dall * sc; const float m1 = dmin * m;
get_scale_min_k4(is + 1, x[i].scales, sc, m);
const float d2 = dall * sc; const float m2 = dmin * m;
for (int l = 0; l < n; ++l) {
y[l + 0] = d1 * (q[l] & 0xF) - m1;
y[l +32] = d2 * (q[l] >> 4) - m2;
}
}

static __device__ void vec_dot_q4_K(const void * vx, const int ib, const int iqs, const float * yy, float & result) {

const block_q4_K * x = (const block_q4_K *) vx;

// iqs is in 0...248 in steps of 8 =>
const int j = iqs / 64; // j is in 0...3
const int ir = (iqs - 64*j)/2; // ir is in 0...28 in steps of 4
const int is = 2*j; // is is in 0...6 in steps of 2

const float * y = yy + 64*j + ir;
const uint8_t * q = x[ib].qs + 32*j + ir;

const float dall = x[ib].d;
const float dmin = x[ib].dmin;

uint8_t sc, m;
get_scale_min_k4(is + 0, x[ib].scales, sc, m);
const float d1 = dall * sc;
const float m1 = dmin * m;
get_scale_min_k4(is + 1, x[ib].scales, sc, m);
const float d2 = dall * sc;
const float m2 = dmin * m;

float sum = 0;
for (int k = 0; k < 4; ++k) {
sum += y[k + 0] * (d1 * (q[k] & 0xF) - m1);
sum += y[k + 32] * (d2 * (q[k] >> 4) - m2);
}
result = sum;

}

static __device__ void convert_f16(const void * vx, const int ib, const int iqs, float & v0, float & v1){
const half * x = (const half *) vx;

@@ -405,6 +491,11 @@ static void dequantize_row_q3_K_cuda(const void * vx, float * y, const int k, cu
dequantize_block_q3_K<<<nb, 64, 0, stream>>>(vx, y);
}

static void dequantize_row_q4_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
const int nb = k / QK_K;
dequantize_block_q4_K<<<nb, 32, 0, stream>>>(vx, y);
}

static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
@@ -451,6 +542,12 @@ static void dequantize_mul_mat_vec_q3_K_cuda(const void * vx, const float * y, f
dequantize_mul_mat_vec_k<32, vec_dot_q3_K><<<nrows/2, block_dims, 0, stream>>>(vx, y, dst, ncols);
}

static void dequantize_mul_mat_vec_q4_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % QK_K == 0);
const dim3 block_dims(32, 2, 1);
dequantize_mul_mat_vec_k<32, vec_dot_q4_K><<<nrows/2, block_dims, 0, stream>>>(vx, y, dst, ncols);
}

static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
dequantize_block<32, 1, convert_f16><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
@@ -478,6 +575,8 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
return dequantize_row_q8_0_cuda;
case GGML_TYPE_Q3_K:
return dequantize_row_q3_K_cuda;
case GGML_TYPE_Q4_K:
return dequantize_row_q4_K_cuda;
case GGML_TYPE_F16:
return convert_fp16_to_fp32_cuda;
default:
@@ -499,6 +598,8 @@ static dequantize_mul_mat_vec_cuda_t ggml_get_dequantize_mul_mat_vec_cuda(ggml_t
return dequantize_mul_mat_vec_q8_0_cuda;
case GGML_TYPE_Q3_K:
return dequantize_mul_mat_vec_q3_K_cuda;
case GGML_TYPE_Q4_K:
return dequantize_mul_mat_vec_q4_K_cuda;
case GGML_TYPE_F16:
return convert_mul_mat_vec_f16_cuda;
default:
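A note on the thread geometry in the CUDA code above: `dequantize_block_q4_K` is launched with one 32-thread block per super-block (see `dequantize_row_q4_K_cuda`), and each thread handles `n = 4` quant bytes, writing 4 low-nibble and 4 high-nibble outputs. The small host-side program below simply replays that index arithmetic to show which outputs, quant bytes, and scale pairs each thread touches; the `main()` is mine and purely illustrative.

```c
#include <stdio.h>

// Replay the index arithmetic of the 32-thread dequantize_block_q4_K kernel on the host.
int main(void) {
    const int n = 4;                        // quant bytes handled per thread, as in the kernel
    for (int tid = 0; tid < 32; ++tid) {    // all 32 threads of one block (one super-block)
        const int il = tid / 8;             // 64-value group index: 0..3
        const int ir = tid % 8;             // position inside the group: 0..7
        const int is = 2 * il;              // first of the two scale/min pairs used
        printf("tid=%2d: y[%3d..%3d] and y[%3d..%3d] from qs[%3d..%3d], scales %d and %d\n",
               tid,
               64*il + n*ir,      64*il + n*ir + n - 1,
               64*il + n*ir + 32, 64*il + n*ir + 32 + n - 1,
               32*il + n*ir,      32*il + n*ir + n - 1,
               is, is + 1);
    }
    return 0;
}
```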
37 changes: 31 additions & 6 deletions ggml.c
@@ -1574,6 +1574,14 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
.vec_dot_q = ggml_vec_dot_q3_K_q8_K,
.vec_dot_type = GGML_TYPE_Q8_K,
},
[GGML_TYPE_Q4_K] = {
.dequantize_row_q = (dequantize_row_q_t) dequantize_row_q4_K,
.quantize_row_q = quantize_row_q4_K,
.quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_K_reference,
.quantize_row_q_dot = quantize_row_q8_K,
.vec_dot_q = ggml_vec_dot_q4_K_q8_K,
.vec_dot_type = GGML_TYPE_Q8_K,
},
};

// For internal test use
@@ -3454,12 +3462,13 @@ static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = {
[GGML_TYPE_Q8_0] = QK8_0,
[GGML_TYPE_Q8_1] = QK8_1,
[GGML_TYPE_Q3_K] = QK_K,
[GGML_TYPE_Q4_K] = QK_K,
[GGML_TYPE_Q8_K] = QK_K,
[GGML_TYPE_I8] = 1,
[GGML_TYPE_I16] = 1,
[GGML_TYPE_I32] = 1,
};
static_assert(GGML_TYPE_COUNT == 15, "GGML_BLCK_SIZE is outdated");
static_assert(GGML_TYPE_COUNT == 16, "GGML_BLCK_SIZE is outdated");

static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
[GGML_TYPE_F32] = sizeof(float),
@@ -3470,13 +3479,13 @@ static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
[GGML_TYPE_Q5_1] = sizeof(block_q5_1),
[GGML_TYPE_Q8_0] = sizeof(block_q8_0),
[GGML_TYPE_Q8_1] = sizeof(block_q8_1),
[GGML_TYPE_Q3_K] = sizeof(block_q3_K),
[GGML_TYPE_Q4_K] = sizeof(block_q4_K),
[GGML_TYPE_Q8_K] = sizeof(block_q8_K),
[GGML_TYPE_I8] = sizeof(int8_t),
[GGML_TYPE_I16] = sizeof(int16_t),
[GGML_TYPE_I32] = sizeof(int32_t),
};
static_assert(GGML_TYPE_COUNT == 15, "GGML_TYPE_SIZE is outdated");
static_assert(GGML_TYPE_COUNT == 16, "GGML_TYPE_SIZE is outdated");


static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = {
@@ -3489,12 +3498,13 @@ static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = {
[GGML_TYPE_Q8_0] = "q8_0",
[GGML_TYPE_Q8_1] = "q8_1",
[GGML_TYPE_Q3_K] = "q3_K",
[GGML_TYPE_Q4_K] = "q4_K",
[GGML_TYPE_Q8_K] = "q8_K",
[GGML_TYPE_I8] = "i8",
[GGML_TYPE_I16] = "i16",
[GGML_TYPE_I32] = "i32",
};
static_assert(GGML_TYPE_COUNT == 15, "GGML_TYPE_NAME is outdated");
static_assert(GGML_TYPE_COUNT == 16, "GGML_TYPE_NAME is outdated");

static bool GGML_IS_QUANTIZED[GGML_TYPE_COUNT] = {
[GGML_TYPE_F32] = false,
@@ -3505,13 +3515,13 @@ static bool GGML_IS_QUANTIZED[GGML_TYPE_COUNT] = {
[GGML_TYPE_Q5_1] = true,
[GGML_TYPE_Q8_0] = true,
[GGML_TYPE_Q8_1] = true,
[GGML_TYPE_Q3_K] = true,
[GGML_TYPE_Q4_K] = true,
[GGML_TYPE_Q8_K] = true,
[GGML_TYPE_I8] = false,
[GGML_TYPE_I16] = false,
[GGML_TYPE_I32] = false,
};
static_assert(GGML_TYPE_COUNT == 15, "GGML_IS_QUANTIZED is outdated");
static_assert(GGML_TYPE_COUNT == 16, "GGML_IS_QUANTIZED is outdated");

static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"NONE",
@@ -3819,6 +3829,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
case GGML_FTYPE_MOSTLY_Q5_1: wtype = GGML_TYPE_Q5_1; break;
case GGML_FTYPE_MOSTLY_Q8_0: wtype = GGML_TYPE_Q8_0; break;
case GGML_FTYPE_MOSTLY_Q3_K: wtype = GGML_TYPE_Q3_K; break;
case GGML_FTYPE_MOSTLY_Q4_K: wtype = GGML_TYPE_Q4_K; break;
case GGML_FTYPE_UNKNOWN: wtype = GGML_TYPE_COUNT; break;
case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16: wtype = GGML_TYPE_COUNT; break;
}
@@ -7603,6 +7614,7 @@ static void ggml_compute_forward_add(
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
{
ggml_compute_forward_add_q_f32(params, src0, src1, dst);
} break;
@@ -7907,6 +7919,7 @@ static void ggml_compute_forward_add1(
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q8_1:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
{
ggml_compute_forward_add1_q_f32(params, src0, src1, dst);
} break;
@@ -8030,6 +8043,7 @@ static void ggml_compute_forward_acc(
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q8_1:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
default:
{
GGML_ASSERT(false);
@@ -10124,6 +10138,7 @@ static void ggml_compute_forward_mul_mat(
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q8_1:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
{
ggml_compute_forward_mul_mat_q_f32(params, src0, src1, dst);
} break;
@@ -10308,6 +10323,7 @@ static void ggml_compute_forward_set(
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q8_1:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
default:
{
GGML_ASSERT(false);
@@ -10474,6 +10490,7 @@ static void ggml_compute_forward_get_rows(
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q8_1:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
{
ggml_compute_forward_get_rows_q(params, src0, src1, dst);
} break;
@@ -11021,6 +11038,7 @@ static void ggml_compute_forward_alibi(
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q8_1:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q8_K:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
@@ -11094,6 +11112,7 @@ static void ggml_compute_forward_clamp(
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q8_1:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q8_K:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
@@ -16104,6 +16123,12 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i
block_q3_K * block = (block_q3_K*)dst + start / QK_K;
result = ggml_quantize_q3_K(src + start, block, n, n, hist);
} break;
case GGML_TYPE_Q4_K:
{
GGML_ASSERT(start % QK_K == 0);
block_q4_K * block = (block_q4_K*)dst + start / QK_K;
result = ggml_quantize_q4_K(src + start, block, n, n, hist);
} break;
default:
assert(false);
}
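The `quantize_fns` entry added earlier in this file's diff routes Q4_K matrix multiplications through Q8_K-quantized activations (`vec_dot_type = GGML_TYPE_Q8_K`) and `ggml_vec_dot_q4_K_q8_K`. Below is a minimal scalar sketch of what that dot product computes. It assumes `QK_K == 256`, stubs out the fp16 conversion, and uses a simplified stand-in for the Q8_K block (just a float scale and QK_K int8 quants; the real `block_q8_K` in k_quants.h also carries per-group sums that the optimized kernels exploit), so treat it as an illustration of the math, not the PR's implementation. `block_q8_K_simple` and `vec_dot_q4_K_q8_K_ref` are names of my own.

```c
#include <stdint.h>

#define QK_K 256                      // assumed super-block size

typedef struct {
    uint16_t d, dmin;                 // fp16 super-block scales (conversion stubbed below)
    uint8_t  scales[3*QK_K/64];       // packed 6-bit sub-block scales/mins
    uint8_t  qs[QK_K/2];              // 4-bit weight quants
} block_q4_K;

typedef struct {                      // simplified stand-in for block_q8_K (assumption)
    float  d;                         // activation scale
    int8_t qs[QK_K];                  // 8-bit activation quants
} block_q8_K_simple;

static float fp16_to_fp32(uint16_t h) { (void)h; return 1.0f; }   // stub for illustration

static void get_scale_min_k4(int j, const uint8_t * q, uint8_t * d, uint8_t * m) {
    if (j < 4) { *d = q[j] & 63; *m = q[j + 4] & 63; }
    else {
        *d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
        *m = (q[j+4] >>  4) | ((q[j-0] >> 6) << 4);
    }
}

// Reference dot product over n values: dequantize both operands on the fly and accumulate.
static float vec_dot_q4_K_q8_K_ref(int n, const block_q4_K * x, const block_q8_K_simple * y) {
    float sumf = 0.0f;
    for (int i = 0; i < n/QK_K; ++i) {
        const float dall = fp16_to_fp32(x[i].d);
        const float dmin = fp16_to_fp32(x[i].dmin);
        const uint8_t * q4 = x[i].qs;
        const int8_t  * q8 = y[i].qs;
        for (int j = 0; j < QK_K/64; ++j) {
            uint8_t sc, m;
            get_scale_min_k4(2*j + 0, x[i].scales, &sc, &m);
            const float d1 = dall * sc, m1 = dmin * m;
            get_scale_min_k4(2*j + 1, x[i].scales, &sc, &m);
            const float d2 = dall * sc, m2 = dmin * m;
            for (int l = 0; l < 32; ++l) {
                sumf += y[i].d * q8[l]      * (d1 * (q4[l] & 0xF) - m1);  // low nibbles
                sumf += y[i].d * q8[l + 32] * (d2 * (q4[l] >>  4) - m2);  // high nibbles
            }
            q4 += 32; q8 += 64;
        }
    }
    return sumf;
}
```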
5 changes: 3 additions & 2 deletions ggml.h
@@ -243,10 +243,10 @@ extern "C" {
GGML_TYPE_Q8_1 = 9,
// k-quantizations
GGML_TYPE_Q3_K = 10,
//GGML_TYPE_Q4_K = 11,
GGML_TYPE_Q4_K = 11,
//GGML_TYPE_Q5_K = 12,
//GGML_TYPE_Q6_K = 13,
GGML_TYPE_Q8_K = 11,
GGML_TYPE_Q8_K = 12,
GGML_TYPE_I8,
GGML_TYPE_I16,
GGML_TYPE_I32,
@@ -271,6 +271,7 @@ extern "C" {
GGML_FTYPE_MOSTLY_Q5_0 = 8, // except 1d tensors
GGML_FTYPE_MOSTLY_Q5_1 = 9, // except 1d tensors
GGML_FTYPE_MOSTLY_Q3_K = 10, // except 1d tensors
GGML_FTYPE_MOSTLY_Q4_K = 11, // except 1d tensors
};

// available tensor operations: