@@ -8519,7 +8519,11 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
 void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
     assert(n % QK_K == 0);
+#ifdef __ARM_FEATURE_MATMUL_INT8
+    assert((nrc == 2) || (nrc == 1));
+#else
     assert(nrc == 1);
+#endif
     UNUSED(nrc);
     UNUSED(bx);
     UNUSED(by);
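For reference, each `block_q6_K` super-block packs 256 weights as a low nibble in `ql` plus two high bits in `qh`, with sixteen int8 group scales and an fp16 super-block scale `d`; the logical weight is `(low4 | high2 << 4) - 32`. The vectorized path in the next hunk rebuilds exactly these values, but defers the `-32` offset to a per-super-block bias correction. Below is a minimal scalar sketch of the per-super-block dot product, assuming the standard llama.cpp `block_q6_K`/`block_q8_K` definitions; the helper name and the `d8` parameter (the q8_K super-block scale) are illustrative, not part of the patch:

```c
// Scalar reference for one q6_K super-block (QK_K == 256) against one q8_K
// super-block. Each 6-bit weight is (ql nibble | two qh bits << 4) - 32,
// scaled by an int8 scale per 16 weights and the fp16 super-block scale d.
static float q6k_block_dot_ref(const block_q6_K * x, const int8_t * q8, float d8) {
    int32_t sum = 0;
    for (int g = 0; g < 2; ++g) {              // two 128-weight halves
        const uint8_t * ql = x->ql     + 64*g; // low 4 bits
        const uint8_t * qh = x->qh     + 32*g; // high 2 bits, 4 weights per byte
        const int8_t  * sc = x->scales +  8*g; // one scale per 16 weights
        const int8_t  * q  = q8        + 128*g;
        for (int l = 0; l < 32; ++l) {
            const int is = l/16;
            const int q1 = (int)((ql[l +  0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
            const int q2 = (int)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
            const int q3 = (int)((ql[l +  0] >>  4) | (((qh[l] >> 4) & 3) << 4)) - 32;
            const int q4 = (int)((ql[l + 32] >>  4) | (((qh[l] >> 6) & 3) << 4)) - 32;
            sum += sc[is + 0]*q1*q[l +  0] + sc[is + 2]*q2*q[l + 32]
                 + sc[is + 4]*q3*q[l + 64] + sc[is + 6]*q4*q[l + 96];
        }
    }
    return GGML_FP16_TO_FP32(x->d) * d8 * sum;
}
```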
@@ -8530,6 +8534,197 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
 
     const int nb = n / QK_K;
 
+#if defined(__ARM_FEATURE_MATMUL_INT8)
+    if (nrc == 2) {
+        const block_q6_K * GGML_RESTRICT x0 = x;
+        const block_q6_K * GGML_RESTRICT x1 = (const block_q6_K *) ((const uint8_t *)vx + bx);
+        const block_q8_K * GGML_RESTRICT y0 = y;
+        const block_q8_K * GGML_RESTRICT y1 = (const block_q8_K *) ((const uint8_t *)vy + by);
+
+        float32x4_t vfsum = vdupq_n_f32(0.0f);
+
+        for (int i = 0; i < nb; ++i, ++x0, ++x1, ++y0, ++y1) {
+            const uint8_t * GGML_RESTRICT ql0 = x0->ql;
+            const uint8_t * GGML_RESTRICT ql1 = x1->ql;
+            const uint8_t * GGML_RESTRICT qh0 = x0->qh;
+            const uint8_t * GGML_RESTRICT qh1 = x1->qh;
+            const int8_t  * GGML_RESTRICT qy0 = y0->qs;
+            const int8_t  * GGML_RESTRICT qy1 = y1->qs;
+
+            const uint8x16_t mone = vdupq_n_u8(0x30);
+            const uint8x16_t m4b  = vdupq_n_u8(0x0f);
+
+            int32x4_t visum = vdupq_n_s32(0);
+
+            // process 8 blocks per iteration, 16 blocks in total
+            for (int j = 0; j < 2; ++j, qh0 += 32, ql0 += 64, qh1 += 32, ql1 += 64) {
+                int8x16_t vx0[8], vx1[8];
+
+                // de-quantize vx0[8]
+                {
+                    const uint8x16x2_t qh_bits = vld1q_u8_x2(qh0);
+                    const uint8x16x4_t ql_bits = vld1q_u8_x4(ql0);
+
+                    uint8x16_t q6h_0 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 4));
+                    uint8x16_t q6h_1 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 4));
+                    uint8x16_t q6h_2 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 2));
+                    uint8x16_t q6h_3 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 2));
+
+                    vx0[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[0], m4b), q6h_0));
+                    vx0[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[1], m4b), q6h_1));
+                    vx0[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[2], m4b), q6h_2));
+                    vx0[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[3], m4b), q6h_3));
+
+                    q6h_0 = vandq_u8(mone, qh_bits.val[0]);
+                    q6h_1 = vandq_u8(mone, qh_bits.val[1]);
+                    q6h_2 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[0], 2));
+                    q6h_3 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[1], 2));
+
+                    vx0[4] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[0], 4), q6h_0));
+                    vx0[5] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[1], 4), q6h_1));
+                    vx0[6] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[2], 4), q6h_2));
+                    vx0[7] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[3], 4), q6h_3));
+                }
+
+                // de-quantize vx1[8]
+                {
+                    const uint8x16x2_t qh_bits = vld1q_u8_x2(qh1);
+                    const uint8x16x4_t ql_bits = vld1q_u8_x4(ql1);
+
+                    uint8x16_t q6h_0 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 4));
+                    uint8x16_t q6h_1 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 4));
+                    uint8x16_t q6h_2 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 2));
+                    uint8x16_t q6h_3 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 2));
+
+                    vx1[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[0], m4b), q6h_0));
+                    vx1[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[1], m4b), q6h_1));
+                    vx1[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[2], m4b), q6h_2));
+                    vx1[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[3], m4b), q6h_3));
+
+                    q6h_0 = vandq_u8(mone, qh_bits.val[0]);
+                    q6h_1 = vandq_u8(mone, qh_bits.val[1]);
+                    q6h_2 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[0], 2));
+                    q6h_3 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[1], 2));
+
+                    vx1[4] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[0], 4), q6h_0));
+                    vx1[5] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[1], 4), q6h_1));
+                    vx1[6] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[2], 4), q6h_2));
+                    vx1[7] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[3], 4), q6h_3));
+                }
+
+                // process 16 elements (one block sharing a scale) per iteration
+                // - vx = concat(ql, qh) - 32
+                // - r1,r2,r3,r4 = smmla(vx, vy)
+                for (int k = 0; k < 8; ++k) {
+                    const int blk = j * 8 + k;
+
+                    const int8x16_t vy0 = vld1q_s8(qy0);
+                    const int8x16_t vy1 = vld1q_s8(qy1);
+                    qy0 += 16;
+                    qy1 += 16;
+
+                    const int32x4_t block_scale = {
+                        x0->scales[blk],
+                        x0->scales[blk],
+                        x1->scales[blk],
+                        x1->scales[blk],
+                    };
+
+                    // calculate four results at once with outer product
+                    const int8x16_t vx_l = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(vx0[k]), vreinterpretq_s64_s8(vx1[k])));
+                    const int8x16_t vx_h = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(vx0[k]), vreinterpretq_s64_s8(vx1[k])));
+                    const int8x16_t vy_l = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(vy0), vreinterpretq_s64_s8(vy1)));
+                    const int8x16_t vy_h = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(vy0), vreinterpretq_s64_s8(vy1)));
+                    int32x4_t vr = vdupq_n_s32(0);
+                    vr = vmmlaq_s32(vr, vx_l, vy_l);
+                    vr = vmmlaq_s32(vr, vx_h, vy_h);
+
+                    // apply block scale; this will NOT overflow:
+                    // block_scale * sum_256(int6*int8) <= 2^(8+8+6+8) = 30 bits
+                    visum = vmlaq_s32(visum, vr, block_scale);
+                }
+            }
+
+            // adjust bias, apply superblock scale
+            {
+                int32_t bias[4];
+#ifdef __ARM_FEATURE_SVE
+                const svbool_t pg16_8 = svptrue_pat_b16(SV_VL8);
+                const svbool_t pg8_8  = svptrue_pat_b8(SV_VL8);
+                const svint16_t y0_q8sums_0 = svld1_s16(pg16_8, y0->bsums);
+                const svint16_t y0_q8sums_1 = svld1_s16(pg16_8, y0->bsums + 8);
+                const svint16_t y1_q8sums_0 = svld1_s16(pg16_8, y1->bsums);
+                const svint16_t y1_q8sums_1 = svld1_s16(pg16_8, y1->bsums + 8);
+                const svint16_t x0_q6scales_0 = svunpklo_s16(svld1_s8(pg8_8, x0->scales));
+                const svint16_t x0_q6scales_1 = svunpklo_s16(svld1_s8(pg8_8, x0->scales + 8));
+                const svint16_t x1_q6scales_0 = svunpklo_s16(svld1_s8(pg8_8, x1->scales));
+                const svint16_t x1_q6scales_1 = svunpklo_s16(svld1_s8(pg8_8, x1->scales + 8));
+                const svint64_t zero = svdup_n_s64(0);
+                bias[0] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y0_q8sums_0, x0_q6scales_0),
+                                                                               svdot_s64(zero, y0_q8sums_1, x0_q6scales_1)));
+                bias[1] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y1_q8sums_0, x0_q6scales_0),
+                                                                               svdot_s64(zero, y1_q8sums_1, x0_q6scales_1)));
+                bias[2] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y0_q8sums_0, x1_q6scales_0),
+                                                                               svdot_s64(zero, y0_q8sums_1, x1_q6scales_1)));
+                bias[3] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y1_q8sums_0, x1_q6scales_0),
+                                                                               svdot_s64(zero, y1_q8sums_1, x1_q6scales_1)));
+#else
+                // NEON has no int16 dot product, so fall back to separate multiplies and adds
+                const int16x8x2_t q8sums0 = vld1q_s16_x2(y0->bsums);
+                const int16x8x2_t q8sums1 = vld1q_s16_x2(y1->bsums);
+
+                int8x16_t scales_s8 = vld1q_s8(x0->scales);
+                const int16x8x2_t q6scales0 = {{vmovl_s8(vget_low_s8(scales_s8)), vmovl_s8(vget_high_s8(scales_s8))}};
+                scales_s8 = vld1q_s8(x1->scales);
+                const int16x8x2_t q6scales1 = {{vmovl_s8(vget_low_s8(scales_s8)), vmovl_s8(vget_high_s8(scales_s8))}};
+
+                int32x4_t prod;
+                prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[0]), vget_low_s16 (q6scales0.val[0])),
+                                           vmull_s16(vget_high_s16(q8sums0.val[0]), vget_high_s16(q6scales0.val[0]))),
+                                 vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[1]), vget_low_s16 (q6scales0.val[1])),
+                                           vmull_s16(vget_high_s16(q8sums0.val[1]), vget_high_s16(q6scales0.val[1]))));
+                bias[0] = vaddvq_s32(prod);
+                prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[0]), vget_low_s16 (q6scales0.val[0])),
+                                           vmull_s16(vget_high_s16(q8sums1.val[0]), vget_high_s16(q6scales0.val[0]))),
+                                 vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[1]), vget_low_s16 (q6scales0.val[1])),
+                                           vmull_s16(vget_high_s16(q8sums1.val[1]), vget_high_s16(q6scales0.val[1]))));
+                bias[1] = vaddvq_s32(prod);
+                prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[0]), vget_low_s16 (q6scales1.val[0])),
+                                           vmull_s16(vget_high_s16(q8sums0.val[0]), vget_high_s16(q6scales1.val[0]))),
+                                 vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[1]), vget_low_s16 (q6scales1.val[1])),
+                                           vmull_s16(vget_high_s16(q8sums0.val[1]), vget_high_s16(q6scales1.val[1]))));
+                bias[2] = vaddvq_s32(prod);
+                prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[0]), vget_low_s16 (q6scales1.val[0])),
+                                           vmull_s16(vget_high_s16(q8sums1.val[0]), vget_high_s16(q6scales1.val[0]))),
+                                 vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[1]), vget_low_s16 (q6scales1.val[1])),
+                                           vmull_s16(vget_high_s16(q8sums1.val[1]), vget_high_s16(q6scales1.val[1]))));
+                bias[3] = vaddvq_s32(prod);
+
+#endif
+                const int32x4_t vibias = vmulq_n_s32(vld1q_s32(bias), 32);
+
+                const float32x4_t superblock_scale = {
+                    GGML_FP16_TO_FP32(x0->d) * y0->d,
+                    GGML_FP16_TO_FP32(x0->d) * y1->d,
+                    GGML_FP16_TO_FP32(x1->d) * y0->d,
+                    GGML_FP16_TO_FP32(x1->d) * y1->d,
+                };
+
+                visum = vsubq_s32(visum, vibias);
+                vfsum = vmlaq_f32(vfsum, vcvtq_f32_s32(visum), superblock_scale);
+            }
+        }
+
+        // vfsum = ABCD -> ACBD
+        // AC -> s, BD -> (s+bs)
+        vfsum = vzip1q_f32(vfsum, vextq_f32(vfsum, vfsum, 2));
+        vst1_f32(s,      vget_low_f32 (vfsum));
+        vst1_f32(s + bs, vget_high_f32(vfsum));
+
+        return;
+    }
+#endif
+
 #ifdef __ARM_FEATURE_SVE
     const int vector_length = ggml_cpu_get_sve_cnt()*8;
     float sum = 0;
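The heart of the added path is `vmmlaq_s32` (SMMLA), which treats each 128-bit operand as a 2x8 int8 matrix and accumulates the 2x2 int32 product `a * b^T`. Because `vzip1q_s64`/`vzip2q_s64` pack bytes 0-7 (respectively 8-15) of the two rows into one register, the two `vmmlaq_s32` calls above leave `vr = {x0.y0, x0.y1, x1.y0, x1.y1}` for each 16-element block. A scalar model of the instruction's semantics, as a sketch (the helper name is ours, not part of the patch):

```c
#include <stdint.h>

// Scalar model of vmmlaq_s32(acc, a, b): a and b are 2x8 int8 matrices
// (row-major), and acc accumulates the 2x2 int32 product a * b^T,
// i.e. acc[2*i + j] += dot(row i of a, row j of b).
static void smmla_ref(int32_t acc[4], const int8_t a[16], const int8_t b[16]) {
    for (int i = 0; i < 2; ++i) {
        for (int j = 0; j < 2; ++j) {
            int32_t dot = 0;
            for (int k = 0; k < 8; ++k) {
                dot += (int32_t) a[8*i + k] * (int32_t) b[8*j + k];
            }
            acc[2*i + j] += dot;
        }
    }
}
```

With `a = {x0[0..7], x1[0..7]}` and `b = {y0[0..7], y1[0..7]}` for the low half (and likewise for the high half), the four accumulator lanes are exactly the four row-column dot products that the final `vzip1q_f32` shuffle scatters to `s` and `s + bs`.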
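Why subtracting `32 * bias` is correct: the 6-bit codes are stored with a +32 offset, so each logical weight is `v - 32` with `v` in `[0, 63]`. The inner loop accumulates only `sum_blk(scale_blk * sum_16(v * q8))`, and since q8_K already carries the per-16-element sums in `bsums`, the missing offset term factors out per super-block:

    dot = d6 * d8 * ( sum_blk(scale_blk * sum_16(v * q8)) - 32 * sum_blk(scale_blk * bsums_blk) )

which is exactly `vsubq_s32(visum, vibias)` followed by the multiply with `superblock_scale`.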