@@ -529,62 +529,42 @@ void ggml_gemv_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void *
529
529
530
530
#if ! ((defined(_MSC_VER )) && ! defined(__clang__ )) && defined(__aarch64__ ) && defined(__ARM_NEON )
531
531
if (ggml_cpu_has_neon ()) {
532
- const void * b_ptr = vx ;
533
- const void * a_ptr = vy ;
534
- float * res_ptr = s ;
535
-
536
- __asm__ __volatile__(
537
- "movi v31.16b, #0x4\n"
538
- "movi v30.16b, #0xf0\n"
539
- "add %x[b_ptr], %x[b_ptr], #0x8\n"
540
- "1:" // Column loop
541
- "add x22, %x[a_ptr], #0x2\n"
542
- "movi v29.16b, #0x0\n"
543
- "mov x21, %x[nb]\n"
544
- "2:" // Block loop
545
- "ldr q28, [%x[b_ptr], #0x0]\n"
546
- "ldr q27, [x22, #0x0]\n"
547
- "movi v26.4s, #0x0\n"
548
- "sub x20, x22, #0x2\n"
549
- "ldr q25, [x22, #0x10]\n"
550
- "ldr q24, [%x[b_ptr], #0x10]\n"
551
- "sub x21, x21, #0x1\n"
552
- "add x22, x22, #0x22\n"
553
- "ldr q23, [%x[b_ptr], #0x20]\n"
554
- "ldr q22, [%x[b_ptr], #0x30]\n"
555
- "ld1r { v21.8h }, [x20]\n"
556
- "ldr q20, [%x[b_ptr], #-0x8]\n"
557
- "sshl v16.16b, v28.16b, v31.16b\n"
558
- "and v28.16b, v28.16b, v30.16b\n"
559
- "sshl v19.16b, v24.16b, v31.16b\n"
560
- "and v24.16b, v24.16b, v30.16b\n"
561
- "add %x[b_ptr], %x[b_ptr], #0x48\n"
562
- "sshl v18.16b, v23.16b, v31.16b\n"
563
- "and v23.16b, v23.16b, v30.16b\n"
564
- ".inst 0x4f9be21a // sdot v26.4s, v16.16b, v27.4b[0]\n"
565
- "sshl v17.16b, v22.16b, v31.16b\n"
566
- "and v22.16b, v22.16b, v30.16b\n"
567
- "fcvtl v21.4s, v21.4h\n"
568
- "fcvtl v16.4s, v20.4h\n"
569
- ".inst 0x4f99e39a // sdot v26.4s, v28.16b, v25.4b[0]\n"
570
- "fmul v16.4s, v16.4s, v21.4s\n"
571
- ".inst 0x4fbbe27a // sdot v26.4s, v19.16b, v27.4b[1]\n"
572
- ".inst 0x4fb9e31a // sdot v26.4s, v24.16b, v25.4b[1]\n"
573
- ".inst 0x4f9bea5a // sdot v26.4s, v18.16b, v27.4b[2]\n"
574
- ".inst 0x4f99eafa // sdot v26.4s, v23.16b, v25.4b[2]\n"
575
- ".inst 0x4fbbea3a // sdot v26.4s, v17.16b, v27.4b[3]\n"
576
- ".inst 0x4fb9eada // sdot v26.4s, v22.16b, v25.4b[3]\n"
577
- "scvtf v26.4s, v26.4s, #0x4\n"
578
- "fmla v29.4s, v26.4s, v16.4s\n"
579
- "cbnz x21, 2b\n"
580
- "sub %x[nc], %x[nc], #0x4\n"
581
- "str q29, [%x[res_ptr], #0x0]\n"
582
- "add %x[res_ptr], %x[res_ptr], #0x10\n"
583
- "cbnz %x[nc], 1b\n"
584
- : [b_ptr ] "+&r" (b_ptr ), [res_ptr ] "+&r" (res_ptr ), [nc ] "+&r" (nc )
585
- : [a_ptr ] "r" (a_ptr ), [nb ] "r" (nb )
586
- : "memory" , "v16" , "v17" , "v18" , "v19" , "v20" , "v21" , "v22" , "v23" , "v24" , "v25" , "v26" , "v27" , "v28" , "v29" , "v30" , "v31" , "x20" , "x21" , "x22"
587
- );
532
+ const block_q4_0x4 * b_ptr = (const block_q4_0x4 * )vx ;
533
+
534
+ for (int c = 0 ; c < nc ; c += ncols_interleaved ) {
535
+ const block_q8_0 * a_ptr = (const block_q8_0 * )vy ;
536
+ float32x4_t acc = vdupq_n_f32 (0 );
537
+ for (int b = 0 ; b < nb ; b += 1 ) {
538
+ int8x16_t b0 = vld1q_s8 ((const int8_t * )b_ptr -> qs );
539
+ int8x16_t b1 = vld1q_s8 ((const int8_t * )b_ptr -> qs + 16 );
540
+ int8x16_t b2 = vld1q_s8 ((const int8_t * )b_ptr -> qs + 32 );
541
+ int8x16_t b3 = vld1q_s8 ((const int8_t * )b_ptr -> qs + 48 );
542
+ float16x4_t bd = vld1_f16 ((const __fp16 * )b_ptr -> d );
543
+
544
+ int8x16_t a0 = vld1q_s8 (a_ptr -> qs );
545
+ int8x16_t a1 = vld1q_s8 (a_ptr -> qs + qk /2 );
546
+ float16x4_t ad = vld1_dup_f16 ((const __fp16 * )& a_ptr -> d );
547
+
548
+ int32x4_t ret = vdupq_n_s32 (0 );
549
+
550
+ ret = vdotq_laneq_s32 (ret , b0 << 4 , a0 , 0 );
551
+ ret = vdotq_laneq_s32 (ret , b1 << 4 , a0 , 1 );
552
+ ret = vdotq_laneq_s32 (ret , b2 << 4 , a0 , 2 );
553
+ ret = vdotq_laneq_s32 (ret , b3 << 4 , a0 , 3 );
554
+
555
+ ret = vdotq_laneq_s32 (ret , b0 & 0xf0U , a1 , 0 );
556
+ ret = vdotq_laneq_s32 (ret , b1 & 0xf0U , a1 , 1 );
557
+ ret = vdotq_laneq_s32 (ret , b2 & 0xf0U , a1 , 2 );
558
+ ret = vdotq_laneq_s32 (ret , b3 & 0xf0U , a1 , 3 );
559
+
560
+ acc = vfmaq_f32 (acc , vcvtq_n_f32_s32 (ret , 4 ),
561
+ vmulq_f32 (vcvt_f32_f16 (ad ), vcvt_f32_f16 (bd )));
562
+ a_ptr ++ ;
563
+ b_ptr ++ ;
564
+ }
565
+ vst1q_f32 (s , acc );
566
+ s += ncols_interleaved ;
567
+ }
588
568
return ;
589
569
}
590
570
#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
0 commit comments