@@ -6995,7 +6995,11 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    assert(n % QK_K == 0);
+ #ifdef __ARM_FEATURE_MATMUL_INT8
+     assert((nrc == 2) || (nrc == 1));
+ #else
    assert(nrc == 1);
+ #endif
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
@@ -7012,6 +7016,146 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
    uint32_t utmp[4];

+ #if defined(__ARM_FEATURE_MATMUL_INT8)
+     if (nrc == 2) {
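+         // i8mm path: process two q4_K rows (x0, x1) against two q8_K rows (y0, y1) at once,
+         // accumulating a 2x2 tile of dot products per super-block via vmmlaq_s32 (SMMLA)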
+         const block_q4_K * GGML_RESTRICT x0 = x;
+         const block_q4_K * GGML_RESTRICT x1 = (const block_q4_K *) ((const uint8_t *)vx + bx);
+         const block_q8_K * GGML_RESTRICT y0 = y;
+         const block_q8_K * GGML_RESTRICT y1 = (const block_q8_K *) ((const uint8_t *)vy + by);
+
+         const uint8x16_t m4b = vdupq_n_u8(0x0f);
+
+         float32x4_t vfsum = vdupq_n_f32(0.0f);
+
+         for (int i = 0; i < nb; ++i, ++x0, ++x1, ++y0, ++y1) {
+             const uint8_t * GGML_RESTRICT qx0 = x0->qs;
+             const uint8_t * GGML_RESTRICT qx1 = x1->qs;
+             const int8_t * GGML_RESTRICT qy0 = y0->qs;
+             const int8_t * GGML_RESTRICT qy1 = y1->qs;
+
+             // decode scales and mins
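+             // q4_K packs eight 6-bit sub-block scales and eight 6-bit mins into 12 bytes;
+             // kmask1/kmask2/kmask3 unpack them into byte-sized values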
+             int8_t x0_scales[8], x1_scales[8];
+             int16x8_t x0_mins, x1_mins;
+             {
+                 uint32_t scales_mins[3];
+                 memcpy(scales_mins, x0->scales, 12);
+                 const uint32_t mins_0_3 = scales_mins[1] & kmask1;
+                 const uint32_t mins_4_7 = ((scales_mins[2] >> 4) & kmask2) | (((scales_mins[1] >> 6) & kmask3) << 4);
+                 const uint32x2_t mins = {mins_0_3, mins_4_7};
+                 x0_mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins)));
+                 uint32_t scales[2];
+                 scales[0] = scales_mins[0] & kmask1; // scales 0~3
+                 scales[1] = (scales_mins[2] & kmask2) | (((scales_mins[0] >> 6) & kmask3) << 4); // scales 4~7
+                 memcpy(x0_scales, scales, 8);
+             }
+             {
+                 uint32_t scales_mins[3];
+                 memcpy(scales_mins, x1->scales, 12);
+                 const uint32_t mins_0_3 = scales_mins[1] & kmask1;
+                 const uint32_t mins_4_7 = ((scales_mins[2] >> 4) & kmask2) | (((scales_mins[1] >> 6) & kmask3) << 4);
+                 const uint32x2_t mins = {mins_0_3, mins_4_7};
+                 x1_mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins)));
+                 uint32_t scales[2];
+                 scales[0] = scales_mins[0] & kmask1; // scales 0~3
+                 scales[1] = (scales_mins[2] & kmask2) | (((scales_mins[0] >> 6) & kmask3) << 4); // scales 4~7
+                 memcpy(x1_scales, scales, 8);
+             }
+
+             int32x4_t visum = {0};
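+             // visum lanes accumulate the 2x2 int32 tile {dot(x0,y0), dot(x0,y1), dot(x1,y0), dot(x1,y1)}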
+
+             // process 64 data points per iteration, 256 data points in total
+             for (int j = 0; j < QK_K / 64; ++j, qx0 += 32, qx1 += 32, qy0 += 64, qy1 += 64) {
+                 const int8x16x4_t vy0 = vld1q_s8_x4(qy0);
+                 const int8x16x4_t vy1 = vld1q_s8_x4(qy1);
+
+                 int8x16_t vx0[4], vx1[4];
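+                 // low nibbles hold quants 0..31 of this 64-value chunk, high nibbles hold quants 32..63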
+                 {
+                     const uint8x16x2_t vv = vld1q_u8_x2(qx0);
+                     vx0[0] = vreinterpretq_s8_u8(vandq_u8(vv.val[0], m4b));
+                     vx0[1] = vreinterpretq_s8_u8(vandq_u8(vv.val[1], m4b));
+                     vx0[2] = vreinterpretq_s8_u8(vshrq_n_u8(vv.val[0], 4));
+                     vx0[3] = vreinterpretq_s8_u8(vshrq_n_u8(vv.val[1], 4));
+                 }
+                 {
+                     const uint8x16x2_t vv = vld1q_u8_x2(qx1);
+                     vx1[0] = vreinterpretq_s8_u8(vandq_u8(vv.val[0], m4b));
+                     vx1[1] = vreinterpretq_s8_u8(vandq_u8(vv.val[1], m4b));
+                     vx1[2] = vreinterpretq_s8_u8(vshrq_n_u8(vv.val[0], 4));
+                     vx1[3] = vreinterpretq_s8_u8(vshrq_n_u8(vv.val[1], 4));
+                 }
+
+                 // process 32 data points (share same block scale) per iteration
+                 for (int k = 0; k < 2; ++k) {
+                     const int blk = j * 2 + k;
+                     const int32x4_t block_scale = {
+                         x0_scales[blk],
+                         x0_scales[blk],
+                         x1_scales[blk],
+                         x1_scales[blk],
+                     };
+
+                     int32x4_t vr = {0};
+                     for (int l = 0; l < 2; ++l) {
+                         const int idx = k * 2 + l;
+                         const int64x2_t vx0_s64 = vreinterpretq_s64_s8(vx0[idx]);
+                         const int64x2_t vx1_s64 = vreinterpretq_s64_s8(vx1[idx]);
+                         const int64x2_t vy0_s64 = vreinterpretq_s64_s8(vy0.val[idx]);
+                         const int64x2_t vy1_s64 = vreinterpretq_s64_s8(vy1.val[idx]);
+                         const int8x16_t vx_l = vreinterpretq_s8_s64(vzip1q_s64(vx0_s64, vx1_s64));
+                         const int8x16_t vx_h = vreinterpretq_s8_s64(vzip2q_s64(vx0_s64, vx1_s64));
+                         const int8x16_t vy_l = vreinterpretq_s8_s64(vzip1q_s64(vy0_s64, vy1_s64));
+                         const int8x16_t vy_h = vreinterpretq_s8_s64(vzip2q_s64(vy0_s64, vy1_s64));
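+                         // each zipped operand holds one 8-byte row from x0/x1 (resp. y0/y1);
+                         // SMMLA multiplies the 2x8 int8 matrices and accumulates the 2x2 tile
+                         // {dot(x0,y0), dot(x0,y1), dot(x1,y0), dot(x1,y1)} into vr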
+                         vr = vmmlaq_s32(vr, vx_l, vy_l);
+                         vr = vmmlaq_s32(vr, vx_h, vy_h);
+                     }
+                     // apply block scale, will NOT overflow
+                     // block_scale * sum_256(int4*int8) <= 2^(8+8+4+8) = 28 bits
+                     visum = vmlaq_s32(visum, vr, block_scale);
+                 }
+             }
+
+             // adjust bias, apply superblock scale
+             {
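+                 // q4_K dequantizes as d*scale*q - dmin*min, so the min contribution is subtracted:
+                 // bias = per-sub-block mins times the matching 32-wide sums of q8 values (bsums paired up)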
+                 int32_t bias[4];
+                 // no obvious uplift from sve sdot-16, just use neon mul add
+                 const int16x8_t y0_sums = vpaddq_s16(vld1q_s16(y0->bsums), vld1q_s16(y0->bsums+8));
+                 const int16x8_t y1_sums = vpaddq_s16(vld1q_s16(y1->bsums), vld1q_s16(y1->bsums+8));
+                 bias[0] = vaddvq_s32(vaddq_s32(vmull_s16(vget_low_s16(y0_sums), vget_low_s16(x0_mins)),
+                                                vmull_s16(vget_high_s16(y0_sums), vget_high_s16(x0_mins))));
+                 bias[1] = vaddvq_s32(vaddq_s32(vmull_s16(vget_low_s16(y1_sums), vget_low_s16(x0_mins)),
+                                                vmull_s16(vget_high_s16(y1_sums), vget_high_s16(x0_mins))));
+                 bias[2] = vaddvq_s32(vaddq_s32(vmull_s16(vget_low_s16(y0_sums), vget_low_s16(x1_mins)),
+                                                vmull_s16(vget_high_s16(y0_sums), vget_high_s16(x1_mins))));
+                 bias[3] = vaddvq_s32(vaddq_s32(vmull_s16(vget_low_s16(y1_sums), vget_low_s16(x1_mins)),
+                                                vmull_s16(vget_high_s16(y1_sums), vget_high_s16(x1_mins))));
+                 const float32x4_t dmins = {
+                     GGML_FP16_TO_FP32(x0->dmin) * y0->d,
+                     GGML_FP16_TO_FP32(x0->dmin) * y1->d,
+                     GGML_FP16_TO_FP32(x1->dmin) * y0->d,
+                     GGML_FP16_TO_FP32(x1->dmin) * y1->d,
+                 };
+                 vfsum = vmlsq_f32(vfsum, vcvtq_f32_s32(vld1q_s32(bias)), dmins);
+
+                 const float32x4_t superblock_scale = {
+                     GGML_FP16_TO_FP32(x0->d) * y0->d,
+                     GGML_FP16_TO_FP32(x0->d) * y1->d,
+                     GGML_FP16_TO_FP32(x1->d) * y0->d,
+                     GGML_FP16_TO_FP32(x1->d) * y1->d,
+                 };
+                 vfsum = vmlaq_f32(vfsum, vcvtq_f32_s32(visum), superblock_scale);
+             }
+         }
+
+         // vfsum = ABCD -> ACBD
+         // AC -> s, BD -> (s+bs)
+         vfsum = vzip1q_f32(vfsum, vextq_f32(vfsum, vfsum, 2));
+         vst1_f32(s, vget_low_f32 (vfsum));
+         vst1_f32(s + bs, vget_high_f32(vfsum));
+
+         return;
+     }
+ #endif
+
#ifdef __ARM_FEATURE_SVE
    float sumf = 0;
    for (int i = 0; i < nb; ++i) {