Introduce IQ4_NL_4_4 format and its neon implementation by FanShupei · Pull Request #10196 · ggml-org/llama.cpp · GitHub
Introduce IQ4_NL_4_4 format and its neon implementation #10196

Closed
wants to merge 5 commits
Changes from 1 commit
[ggml-aarch64] impl the same logic as the ASM version in q4_0_4_4 gemm/gemv
FanShupei committed Nov 10, 2024
commit c7a54d1f2bae0966eb0ad8c9a43c18475e7ccdc4
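For context, here is a minimal sketch of the interleaved block layouts these kernels operate on. The type and field names follow ggml's block_q4_0x4 / block_q8_0x4 convention, but the sizes are inferred from the loads in the diff below (four 16-byte q4 loads, eight 16-byte q8 loads, four f16 scales each), so treat this as an assumption rather than the exact upstream definitions.

#include <stdint.h>

typedef uint16_t ggml_half;   // raw IEEE f16 bits

typedef struct {
    ggml_half d[4];           // one f16 scale per interleaved column
    uint8_t   qs[64];         // 4 columns x 32 weights, two 4-bit values per byte
} block_q4_0x4;               // matches the 0..48 byte offsets loaded below

typedef struct {
    ggml_half d[4];           // one f16 scale per interleaved row
    int8_t    qs[128];        // 4 rows x 32 int8 activations
} block_q8_0x4;               // matches the 0..112 byte offsets loaded below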
202 changes: 177 additions & 25 deletions ggml/src/ggml-aarch64.c
@@ -667,14 +667,31 @@ void ggml_gemv_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void *
float * res_ptr = s;

for (int x = 0; x < nc / ncols_interleaved; x++) {
// %x[nc] : loop control

const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb);

float32x4_t sumf = vdupq_n_f32(0);
// v29 = sumf

for (int l = 0; l < nb; l++) {
// x21 : loop control

// x22 = a_ptr[l].qs
// %x[b_ptr] = b_ptr[l].qs

int8x16_t a_0 = vld1q_s8(a_ptr[l].qs + 0);
int8x16_t a_1 = vld1q_s8(a_ptr[l].qs + 16);
// (v27, v25) = (a_0, a_1)

uint8x16_t b_0 = vld1q_u8(b_ptr[l].qs + 0);
uint8x16_t b_1 = vld1q_u8(b_ptr[l].qs + 16);
uint8x16_t b_2 = vld1q_u8(b_ptr[l].qs + 32);
uint8x16_t b_3 = vld1q_u8(b_ptr[l].qs + 48);
// (v28, v24, v23, v22) = (b_0, b_1, b_2, b_3)

float16x4_t b_d_half = vld1_f16((const float16_t *)b_ptr[l].d);
// v20 = b_d_half

int8x16_t b_0_hi = vreinterpretq_s8_u8(b_0 & 0xF0);
int8x16_t b_0_lo = vreinterpretq_s8_u8(b_0 << 4);
@@ -684,11 +701,13 @@ void ggml_gemv_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void *
int8x16_t b_2_lo = vreinterpretq_s8_u8(b_2 << 4);
int8x16_t b_3_hi = vreinterpretq_s8_u8(b_3 & 0xF0);
int8x16_t b_3_lo = vreinterpretq_s8_u8(b_3 << 4);

int8x16_t a_0 = vld1q_s8(a_ptr[l].qs + 0);
int8x16_t a_1 = vld1q_s8(a_ptr[l].qs + 16);
// (v16, v28) = (b_0_lo, b_0_hi)
// (v19, v24) = (b_1_lo, b_1_hi)
// (v18, v23) = (b_2_lo, b_2_hi)
// (v17, v22) = (b_3_lo, b_3_hi)

int32x4_t sumi = vdupq_n_s32(0);
// v26 = sumi
sumi = vdotq_laneq_s32(sumi, b_0_lo, a_0, 0);
sumi = vdotq_laneq_s32(sumi, b_0_hi, a_1, 0);
sumi = vdotq_laneq_s32(sumi, b_1_lo, a_0, 1);
@@ -697,15 +716,21 @@ void ggml_gemv_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void *
sumi = vdotq_laneq_s32(sumi, b_2_hi, a_1, 2);
sumi = vdotq_laneq_s32(sumi, b_3_lo, a_0, 3);
sumi = vdotq_laneq_s32(sumi, b_3_hi, a_1, 3);

float32x4_t a_d = vcvt_f32_f16(vld1_dup_f16((const float16_t *)&a_ptr[l].d));
float32x4_t b_d = vcvt_f32_f16(vld1_f16((const float16_t *)b_ptr[l].d));
// v21 = a_d

float32x4_t b_d = vcvt_f32_f16(b_d_half);
// v16 = b_d

float32x4_t d = a_d * b_d;
// v16 = d

sumf = vmlaq_f32(sumf, d, vcvtq_n_f32_s32(sumi, 4));
}

vst1q_f32(res_ptr + x * 4, sumf);
// %x[res_ptr] = res_ptr + x * 4
}
return;
}
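A scalar sketch of the 4-bit unpacking trick used in the gemv above (my own illustration, not part of the patch): both nibbles are expanded into the top half of a signed byte, so every product in the dot product carries an extra factor of 16, which the final vcvtq_n_f32_s32(sumi, 4) removes when converting to float. This assumes the repacking step has already turned the q4_0 nibbles into signed two's-complement values.

// Scalar model of one column's contribution in ggml_gemv_q4_0_4x4_q8_0.
// qs: 16 packed bytes (32 signed 4-bit weights) of one column
// a_lo / a_hi: the int8 activations multiplied against the low / high nibbles
static float q4_column_dot_sketch(const uint8_t qs[16], const int8_t a_lo[16],
                                  const int8_t a_hi[16], float d) {
    int32_t sumi = 0;
    for (int i = 0; i < 16; i++) {
        int8_t lo = (int8_t)(qs[i] << 4);    // low  nibble moved to bits 4..7 (value * 16)
        int8_t hi = (int8_t)(qs[i] & 0xF0);  // high nibble already in bits 4..7 (value * 16)
        sumi += lo * a_lo[i] + hi * a_hi[i]; // every term is 16x too large
    }
    return d * (float)sumi / 16.0f;          // vcvtq_n_f32_s32(sumi, 4) performs this /16
}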
@@ -1174,7 +1199,7 @@ void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * restrict s, size_t bs, const void
sumi = vdotq_laneq_s32(sumi, b_2_hi, a_1, 2);
sumi = vdotq_laneq_s32(sumi, b_3_lo, a_0, 3);
sumi = vdotq_laneq_s32(sumi, b_3_hi, a_1, 3);

float32x4_t a_d = vcvt_f32_f16(vld1_dup_f16((const float16_t *)&a_ptr[l].d));
float32x4_t b_d = vcvt_f32_f16(vld1_f16((const float16_t *)b_ptr[l].d));
float32x4_t d = a_d * b_d;
@@ -1236,7 +1261,97 @@ void ggml_gemm_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void *

#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
if (ggml_cpu_has_neon()) {
for (int y = 0; y < nr / 4; y++) {
#define UNROLL_FACTOR 4
int y = 0;
for (; y + UNROLL_FACTOR <= nr / 4; y += UNROLL_FACTOR) {
const block_q8_0x4 * a_ptr[UNROLL_FACTOR];
for (int z = 0; z < UNROLL_FACTOR; z++) {
a_ptr[z] = (const block_q8_0x4 *) vy + ((y + z) * nb);
}

for (int x = 0; x < nc / ncols_interleaved; x++) {
const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb);

float32x4_t sumf[UNROLL_FACTOR][4];
for (int z = 0; z < UNROLL_FACTOR; z ++) {
for (int m = 0; m < 4; m++) {
sumf[z][m] = vdupq_n_f32(0);
}
}
// (v15, v19, v18, v14) = sumf[0][0, 1, 2, 3]
// (v11, v13, v23, v16) = sumf[1][0, 1, 2, 3]
// (v27, v7, v0, v4 ) = sumf[2][0, 1, 2, 3]
// (v5, v21, v8, v1 ) = sumf[3][0, 1, 2, 3]

for (int l = 0; l < nb; l++) {
// x24 : loop control

// x28 = b_ptr[l].qs
// (x25, x23, x22, x21) = a_ptr[0, 1, 2, 3][l].qs

int8x16_t b_hi[4], b_lo[4];
for (int k = 0; k < 4; k++) {
uint8x16_t b = vld1q_u8(b_ptr[l].qs + 16 * k);
b_hi[k] = vreinterpretq_s8_u8(b & 0xF0);
b_lo[k] = vreinterpretq_s8_u8(b << 4);
}
// (v12, v3) = (b_lo[0], b_hi[0])
// (v31, v22) = (b_lo[1], b_hi[1])
// (v6, v27) = (b_lo[2], b_hi[2])
// (v28, v30) = (b_lo[3], b_hi[3])

float32x4_t b_d = vcvt_f32_f16(vld1_f16((const float16_t *)b_ptr[l].d));
// v17 = b_d

// unrolled in the ASM version
for (int z = 0; z < UNROLL_FACTOR; z++) {
int32x4_t sumi[4];
for (int m = 0; m < 4; m++) {
sumi[m] = vdupq_n_s32(0);
}
// (v10, v29, v9, v20) = sumi[0, 1, 2, 3] (z = 0)
// (v9, v29, v20, v2) = sumi[0, 1, 2, 3] (z = 1)
// (v20, v10, v26, v2) = sumi[0, 1, 2, 3] (z = 2)
// (v26, v10, v2, v19) = sumi[0, 1, 2, 3] (z = 3)

for (int k = 0; k < 4; k++) {
int8x16_t a0 = vld1q_s8(a_ptr[z][l].qs + 16 * k + 0);
sumi[0] = vdotq_laneq_s32(sumi[0], b_lo[k], a0, 0);
sumi[1] = vdotq_laneq_s32(sumi[1], b_lo[k], a0, 1);
sumi[2] = vdotq_laneq_s32(sumi[2], b_lo[k], a0, 2);
sumi[3] = vdotq_laneq_s32(sumi[3], b_lo[k], a0, 3);
}
for (int k = 0; k < 4; k++) {
int8x16_t a1 = vld1q_s8(a_ptr[z][l].qs + 16 * k + 64);
sumi[0] = vdotq_laneq_s32(sumi[0], b_hi[k], a1, 0);
sumi[1] = vdotq_laneq_s32(sumi[1], b_hi[k], a1, 1);
sumi[2] = vdotq_laneq_s32(sumi[2], b_hi[k], a1, 2);
sumi[3] = vdotq_laneq_s32(sumi[3], b_hi[k], a1, 3);
}

float32x4_t a_d = vcvt_f32_f16(vld1_f16((const float16_t *)a_ptr[z][l].d));
// (v2, v26, v29, v20) = a_d (z = 0, 1, 2, 3)

sumf[z][0] = vmlaq_f32(sumf[z][0], vmulq_laneq_f32(b_d, a_d, 0), vcvtq_n_f32_s32(sumi[0], 4));
sumf[z][1] = vmlaq_f32(sumf[z][1], vmulq_laneq_f32(b_d, a_d, 1), vcvtq_n_f32_s32(sumi[1], 4));
sumf[z][2] = vmlaq_f32(sumf[z][2], vmulq_laneq_f32(b_d, a_d, 2), vcvtq_n_f32_s32(sumi[2], 4));
sumf[z][3] = vmlaq_f32(sumf[z][3], vmulq_laneq_f32(b_d, a_d, 3), vcvtq_n_f32_s32(sumi[3], 4));
}

}

for (int z = 0; z < UNROLL_FACTOR; z++) {
for (int m = 0; m < 4; m++) {
vst1q_f32(s + ((y + z) * 4 + m) * bs + x * 4, sumf[z][m]);
}
}
}
}
#undef UNROLL_FACTOR

for (; y < nr / 4; y++) {
// x10 : loop control

const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
for (int x = 0; x < nc / ncols_interleaved; x++) {
const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb);
@@ -1245,32 +1360,68 @@ void ggml_gemm_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void *
for (int m = 0; m < 4; m++) {
sumf[m] = vdupq_n_f32(0);
}

// (v15, v19, v18, v14) = sumf[0, 1, 2, 3]

for (int l = 0; l < nb; l++) {
// x21 : loop control

// x25 = a_ptr[l].qs
// x24 = b_ptr[l].qs

int8x16_t a_0[4], a_1[4];
a_0[0] = vld1q_s8(a_ptr[l].qs + 0);
a_0[1] = vld1q_s8(a_ptr[l].qs + 16);
a_0[2] = vld1q_s8(a_ptr[l].qs + 32);
a_0[3] = vld1q_s8(a_ptr[l].qs + 48);
a_1[0] = vld1q_s8(a_ptr[l].qs + 64);
a_1[1] = vld1q_s8(a_ptr[l].qs + 80);
a_1[2] = vld1q_s8(a_ptr[l].qs + 96);
a_1[3] = vld1q_s8(a_ptr[l].qs + 112);
// (v5, v26) = (a_0[0], a_1[0])
// (v2, v25) = (a_0[1], a_1[1])
// (v31, v24) = (a_0[2], a_1[2])
// (v27, v16) = (a_0[3], a_1[3])

uint8x16_t b_0 = vld1q_u8(b_ptr[l].qs + 0);
uint8x16_t b_1 = vld1q_u8(b_ptr[l].qs + 16);
uint8x16_t b_2 = vld1q_u8(b_ptr[l].qs + 32);
uint8x16_t b_3 = vld1q_u8(b_ptr[l].qs + 48);
// (v7, v3, v13, v28) = (b_0, b_1, b_2, b_3)

int8x16_t b_lo[4], b_hi[4];
b_hi[0] = vreinterpretq_s8_u8(b_0 & 0xF0);
b_lo[0] = vreinterpretq_s8_u8(b_0 << 4);
b_hi[1] = vreinterpretq_s8_u8(b_1 & 0xF0);
b_lo[1] = vreinterpretq_s8_u8(b_1 << 4);
b_hi[2] = vreinterpretq_s8_u8(b_2 & 0xF0);
b_lo[2] = vreinterpretq_s8_u8(b_2 << 4);
b_hi[3] = vreinterpretq_s8_u8(b_3 & 0xF0);
b_lo[3] = vreinterpretq_s8_u8(b_3 << 4);
// (v20, v7) = (b_lo[0], b_hi[0])
// (v17, v3) = (b_lo[1], b_hi[1])
// (v22, v13) = (b_lo[2], b_hi[2])
// (v9, v28) = (b_lo[3], b_hi[3])

float32x4_t a_d = vcvt_f32_f16(vld1_f16((const float16_t *)a_ptr[l].d));
// v12 = a_d
float32x4_t b_d = vcvt_f32_f16(vld1_f16((const float16_t *)b_ptr[l].d));
// v21 = b_d

int32x4_t sumi_0 = vdupq_n_s32(0);
int32x4_t sumi_1 = vdupq_n_s32(0);
int32x4_t sumi_2 = vdupq_n_s32(0);
int32x4_t sumi_3 = vdupq_n_s32(0);
// (v4, v1, v0, v30) = (sumi_0, sumi_1, sumi_2, sumi_3)

for (int k = 0; k < 4; k++) {
int8x16_t a_0 = vld1q_s8(a_ptr[l].qs + 16 * k + 0);
int8x16_t a_1 = vld1q_s8(a_ptr[l].qs + 16 * k + 64);

uint8x16_t b = vld1q_u8(b_ptr[l].qs + 16 * k);
int8x16_t b_hi = vreinterpretq_s8_u8(b & 0xF0);
int8x16_t b_lo = vreinterpretq_s8_u8(b << 4);

sumi_0 = vdotq_laneq_s32(sumi_0, b_lo, a_0, 0);
sumi_1 = vdotq_laneq_s32(sumi_1, b_lo, a_0, 1);
sumi_2 = vdotq_laneq_s32(sumi_2, b_lo, a_0, 2);
sumi_3 = vdotq_laneq_s32(sumi_3, b_lo, a_0, 3);
sumi_0 = vdotq_laneq_s32(sumi_0, b_hi, a_1, 0);
sumi_1 = vdotq_laneq_s32(sumi_1, b_hi, a_1, 1);
sumi_2 = vdotq_laneq_s32(sumi_2, b_hi, a_1, 2);
sumi_3 = vdotq_laneq_s32(sumi_3, b_hi, a_1, 3);
sumi_0 = vdotq_laneq_s32(sumi_0, b_lo[k], a_0[k], 0);
sumi_1 = vdotq_laneq_s32(sumi_1, b_lo[k], a_0[k], 1);
sumi_2 = vdotq_laneq_s32(sumi_2, b_lo[k], a_0[k], 2);
sumi_3 = vdotq_laneq_s32(sumi_3, b_lo[k], a_0[k], 3);
sumi_0 = vdotq_laneq_s32(sumi_0, b_hi[k], a_1[k], 0);
sumi_1 = vdotq_laneq_s32(sumi_1, b_hi[k], a_1[k], 1);
sumi_2 = vdotq_laneq_s32(sumi_2, b_hi[k], a_1[k], 2);
sumi_3 = vdotq_laneq_s32(sumi_3, b_hi[k], a_1[k], 3);
}

sumf[0] = vmlaq_f32(sumf[0], vmulq_laneq_f32(b_d, a_d, 0), vcvtq_n_f32_s32(sumi_0, 4));
Expand All @@ -1279,6 +1430,7 @@ void ggml_gemm_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void *
sumf[3] = vmlaq_f32(sumf[3], vmulq_laneq_f32(b_d, a_d, 3), vcvtq_n_f32_s32(sumi_3, 4));
}

// NOTE: the asm version has additional code to handle the case where `nr` is not a multiple of 4
for (int m = 0; m < 4; m++) {
vst1q_f32(s + (y * 4 + m) * bs + x * 4, sumf[m]);
}
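For readers mapping the intrinsics back to the math in the 4x4 tile above, this is a scalar model of the vdotq_laneq_s32 calls (a paraphrase of the Arm SDOT-by-lane semantics, not code from the patch): one call accumulates, into four 32-bit lanes, the dot products of four interleaved weight columns with the same lane-selected group of four activations.

// Scalar model of: sumi = vdotq_laneq_s32(sumi, b, a, lane)
// b holds 4 columns x 4 int8 weights, a holds 4 rows x 4 int8 activations;
// 'lane' picks which row's 4 activations are reused for all 4 columns.
static void vdotq_laneq_s32_model(int32_t sumi[4], const int8_t b[16],
                                  const int8_t a[16], int lane) {
    for (int col = 0; col < 4; col++) {
        for (int j = 0; j < 4; j++) {
            sumi[col] += (int32_t)b[4 * col + j] * (int32_t)a[4 * lane + j];
        }
    }
}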
@@ -3230,7 +3382,7 @@ void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * restrict s, size_t bs, const void
for (int m = 0; m < 4; m++) {
sumf[m] = vdupq_n_f32(0);
}

for (int l = 0; l < nb; l++) {
float32x4_t a_d = vcvt_f32_f16(vld1_f16((const float16_t *)a_ptr[l].d));
float32x4_t b_d = vcvt_f32_f16(vld1_f16((const float16_t *)b_ptr[l].d));
Expand All @@ -3244,7 +3396,7 @@ void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * restrict s, size_t bs, const void
int8x16_t a_0 = vld1q_s8(a_ptr[l].qs + 16 * k + 0);
int8x16_t a_1 = vld1q_s8(a_ptr[l].qs + 16 * k + 64);

uint8x16_t b = vld1q_u8(b_ptr[l].qs + 16 * k);
int8x16_t b_hi = vqtbl1q_s8(kvalues, b >> 4);
int8x16_t b_lo = vqtbl1q_s8(kvalues, b & 0xF);
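The iq4_nl kernels decode a nibble differently from the q4_0 ones: instead of shifting it into the top half of a signed byte, vqtbl1q_s8 uses it as an index into a 16-entry non-uniform codebook. A scalar sketch follows; the table values are what I believe ggml defines as kvalues_iq4nl, so treat them as an assumption rather than a value taken from this patch.

// Scalar sketch of the table-lookup dequant done by vqtbl1q_s8(kvalues, ...).
// Assumed codebook (ggml's kvalues_iq4nl); not copied from this diff.
static const int8_t kvalues_iq4nl_sketch[16] = {
    -127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113,
};

static void iq4nl_unpack_byte_sketch(uint8_t packed, int8_t *lo, int8_t *hi) {
    *lo = kvalues_iq4nl_sketch[packed & 0x0F]; // low-nibble index  -> codebook value
    *hi = kvalues_iq4nl_sketch[packed >> 4];   // high-nibble index -> codebook value
}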
