Convert vector to f16 for dequantize mul mat vec by JohannesGaessler · Pull Request #1913 · ggml-org/llama.cpp

Merged · 10 commits · Jun 19, 2023
Changes from 1 commit
dfloat2
JohannesGaessler committed Jun 18, 2023
commit 8ac993bd0a38e44c3d3e0c0096efd82a1918806f
105 changes: 62 additions & 43 deletions ggml-cuda.cu
@@ -50,7 +50,25 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
 } while (0)
 #endif // CUDART_VERSION >= 11
 
-typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, float & v0, float & v1);
+typedef float dfloat; // dequantize float
+typedef float2 dfloat2;
+
+static __device__ __forceinline__ void dadd_inplace(dfloat2 & a, const dfloat b) {
+    a.x += b;
+    a.y += b;
+}
+
+static __device__ __forceinline__ void dsub_inplace(dfloat2 & a, const dfloat b) {
+    a.x -= b;
+    a.y -= b;
+}
+
+static __device__ __forceinline__ void dmul_inplace(dfloat2 & a, const dfloat b) {
+    a.x *= b;
+    a.y *= b;
+}
+
+typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, dfloat2 & v);
 typedef void (*to_fp32_cuda_t)(const void * x, float * y, int k, cudaStream_t stream);
 typedef void (*dot_kernel_k_t)(const void * vx, const int ib, const int iqs, const float * y, float & v);
 typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
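
The payoff of routing all dequantization through dfloat2 and these *_inplace helpers is that the precision of the intermediate math becomes a single-point decision. A rough sketch of where such a refactor can lead (the GGML_CUDA_DMMV_F16 switch name and the half2 variant below are assumptions for illustration, not part of this commit):

// Hypothetical follow-up, NOT in this commit: swap the typedefs and helper
// bodies to f16 under a compile-time switch, leaving every dequantize
// kernel untouched. Only dmul_inplace is shown.
#include <cuda_fp16.h> // half, half2, __hmul2, __half2half2 (cc 5.3+)

#ifdef GGML_CUDA_DMMV_F16 // assumed switch name
typedef half  dfloat;
typedef half2 dfloat2;

static __device__ __forceinline__ void dmul_inplace(dfloat2 & a, const dfloat b) {
    a = __hmul2(a, __half2half2(b)); // one f16x2 multiply covers both values
}
#else
typedef float  dfloat;
typedef float2 dfloat2;

static __device__ __forceinline__ void dmul_inplace(dfloat2 & a, const dfloat b) {
    a.x *= b;
    a.y *= b;
}
#endif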
@@ -234,82 +252,81 @@ static __global__ void rms_norm_f32(const float * x, float * dst, const int ncol
 }
 }
 
-static __device__ void dequantize_q4_0(const void * vx, const int ib, const int iqs, float & v0, float & v1){
+static __device__ void dequantize_q4_0(const void * vx, const int ib, const int iqs, dfloat2 & v){
     const block_q4_0 * x = (const block_q4_0 *) vx;
 
-    const float d = x[ib].d;
+    const dfloat d = x[ib].d;
 
     const uint8_t vui = x[ib].qs[iqs];
 
-    const int8_t vi0 = vui & 0xF;
-    const int8_t vi1 = vui >> 4;
+    v.x = vui & 0xF;
+    v.y = vui >> 4;
 
-    v0 = (vi0 - 8)*d;
-    v1 = (vi1 - 8)*d;
+    dsub_inplace(v, 8.0f);
+    dmul_inplace(v, d);
 }
 
-static __device__ void dequantize_q4_1(const void * vx, const int ib, const int iqs, float & v0, float & v1){
+static __device__ void dequantize_q4_1(const void * vx, const int ib, const int iqs, dfloat2 & v){
     const block_q4_1 * x = (const block_q4_1 *) vx;
 
-    const float d = x[ib].d;
-    const float m = x[ib].m;
+    const dfloat d = x[ib].d;
+    const dfloat m = x[ib].m;
 
     const uint8_t vui = x[ib].qs[iqs];
 
-    const int8_t vi0 = vui & 0xF;
-    const int8_t vi1 = vui >> 4;
+    v.x = vui & 0xF;
+    v.y = vui >> 4;
 
-    v0 = vi0*d + m;
-    v1 = vi1*d + m;
+    dmul_inplace(v, d);
+    dadd_inplace(v, m);
 }
 
-static __device__ void dequantize_q5_0(const void * vx, const int ib, const int iqs, float & v0, float & v1){
+static __device__ void dequantize_q5_0(const void * vx, const int ib, const int iqs, dfloat2 & v){
     const block_q5_0 * x = (const block_q5_0 *) vx;
 
-    const float d = x[ib].d;
+    const dfloat d = x[ib].d;
 
     uint32_t qh;
     memcpy(&qh, x[ib].qh, sizeof(qh));
 
     const uint8_t xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10;
     const uint8_t xh_1 = ((qh >> (iqs + 12)) ) & 0x10;
 
-    const int32_t x0 = ((x[ib].qs[iqs] & 0xf) | xh_0) - 16;
-    const int32_t x1 = ((x[ib].qs[iqs] >> 4) | xh_1) - 16;
+    v.x = ((x[ib].qs[iqs] & 0xf) | xh_0);
+    v.y = ((x[ib].qs[iqs] >> 4) | xh_1);
 
-    v0 = x0*d;
-    v1 = x1*d;
+    dsub_inplace(v, 16.0f);
+    dmul_inplace(v, d);
 }
 
-static __device__ void dequantize_q5_1(const void * vx, const int ib, const int iqs, float & v0, float & v1){
+static __device__ void dequantize_q5_1(const void * vx, const int ib, const int iqs, dfloat2 & v){
     const block_q5_1 * x = (const block_q5_1 *) vx;
 
-    const float d = x[ib].d;
-    const float m = x[ib].m;
+    const dfloat d = x[ib].d;
+    const dfloat m = x[ib].m;
 
     uint32_t qh;
     memcpy(&qh, x[ib].qh, sizeof(qh));
 
     const uint8_t xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10;
     const uint8_t xh_1 = ((qh >> (iqs + 12)) ) & 0x10;
 
-    const int32_t x0 = ((x[ib].qs[iqs] & 0xf) | xh_0);
-    const int32_t x1 = ((x[ib].qs[iqs] >> 4) | xh_1);
+    v.x = ((x[ib].qs[iqs] & 0xf) | xh_0);
+    v.y = ((x[ib].qs[iqs] >> 4) | xh_1);
 
-    v0 = x0*d + m;
-    v1 = x1*d + m;
+    dmul_inplace(v, d);
+    dadd_inplace(v, m);
 }
 
-static __device__ void dequantize_q8_0(const void * vx, const int ib, const int iqs, float & v0, float & v1){
+static __device__ void dequantize_q8_0(const void * vx, const int ib, const int iqs, dfloat2 & v){
     const block_q8_0 * x = (const block_q8_0 *) vx;
 
-    const float d = x[ib].d;
+    const dfloat d = x[ib].d;
 
-    const int8_t vi0 = x[ib].qs[iqs + 0];
-    const int8_t vi1 = x[ib].qs[iqs + 1];
+    v.x = x[ib].qs[iqs + 0];
+    v.y = x[ib].qs[iqs + 1];
 
-    v0 = vi0*d;
-    v1 = vi1*d;
+    dmul_inplace(v, d);
 }
 
 //================================== k-quants
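
To make the nibble arithmetic above easy to check in isolation, here is a host-side sketch of the same q4_0 dequantization; the struct name, layout, and helper name are assumptions for illustration, only the formula (nibble - 8) * d is taken from the code above:

// Host-side reference sketch, not from the PR: dequantize one byte of a
// q4_0 block the way dequantize_q4_0 does, producing two values per call.
#include <stdint.h>

#define QK4_0 32

typedef struct {
    uint16_t d;            // scale stored as IEEE f16 (decoding omitted here)
    uint8_t  qs[QK4_0/2];  // 32 4-bit quants, two per byte
} block_q4_0_ref;          // assumed layout mirroring ggml's block_q4_0

// d is the already-decoded scale; iqs indexes the byte inside the block
static void dequantize_q4_0_ref(const block_q4_0_ref * b, int iqs, float d,
                                float * v0, float * v1) {
    const uint8_t vui = b->qs[iqs];
    *v0 = ((vui & 0xF) - 8) * d; // low nibble,  ends up at y[... + 0]
    *v1 = ((vui >>  4) - 8) * d; // high nibble, ends up at y[... + qk/2]
}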
@@ -843,11 +860,11 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * vx, const float
 }
 }
 
-static __device__ void convert_f16(const void * vx, const int ib, const int iqs, float & v0, float & v1){
+static __device__ void convert_f16(const void * vx, const int ib, const int iqs, dfloat2 & v){
     const half * x = (const half *) vx;
 
-    v0 = __half2float(x[ib + iqs + 0]);
-    v1 = __half2float(x[ib + iqs + 1]);
+    v.x = __half2float(x[ib + iqs + 0]);
+    v.y = __half2float(x[ib + iqs + 1]);
 }
 
 template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
@@ -864,9 +881,11 @@ static __global__ void dequantize_block(const void * vx, float * y, const int k)
     const int y_offset = qr == 1 ? 1 : qk/2;
 
     // dequantize
-    float & v0 = y[iybs + iqs + 0];
-    float & v1 = y[iybs + iqs + y_offset];
-    dequantize_kernel(vx, ib, iqs, v0, v1);
+    dfloat2 v;
+    dequantize_kernel(vx, ib, iqs, v);
+
+    y[iybs + iqs + 0] = v.x;
+    y[iybs + iqs + y_offset] = v.y;
 }
 
 template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
@@ -899,13 +918,13 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const half * y, f
             // process 2 vals per j iter
 
             // dequantize
-            float v0, v1;
-            dequantize_kernel(vx, ib, iqs + j/qr, v0, v1);
+            dfloat2 v;
+            dequantize_kernel(vx, ib, iqs + j/qr, v);
             // for qr = 2 the iqs needs to increase by 1 per j iter because 2 weights per data val
 
             // matrix multiplication
-            tmp += v0 * __half2float(y[iybs + iqs + j/qr + 0]);
-            tmp += v1 * __half2float(y[iybs + iqs + j/qr + y_offset]);
+            tmp += v.x * __half2float(y[iybs + iqs + j/qr + 0]);
+            tmp += v.y * __half2float(y[iybs + iqs + j/qr + y_offset]);
             // for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2
         }
     }
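
For orientation, the index arithmetic those comments rely on, written out as a standalone sketch (hypothetical helper, not in the PR): with qk = 32 and qr = 2 as for q4_0, one dequantize call fills two outputs qk/2 = 16 positions apart.

// Recover the indices used by dequantize_block for output element i,
// assuming q4_0 geometry (qk = 32 values per block, qr = 2 quants per byte).
// Example: i = 40 -> ib = 1, iqs = 4, iybs = 32, so one dequantize call
// writes y[36] (v.x) and y[36 + 16] = y[52] (v.y).
static void dmmv_indices_q4_0(int i, int * ib, int * iqs, int * iybs) {
    const int qk = 32;
    const int qr = 2;
    *ib   = i / qk;        // quant block index
    *iqs  = (i % qk) / qr; // byte index inside the block
    *iybs = i - i % qk;    // index of the block's first output value
}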