@@ -1360,31 +1360,16 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
1360
1360
const int8x16_t v1_1hs = vsubq_s8 (v1_1h , s8b );
1361
1361
1362
1362
// dot product into int16x8_t
1363
- const int16x8_t pl0l = vmull_s8 ( vget_low_s8 ( v0_0ls ), vget_low_s8 ( v1_0ls ) );
1364
- const int16x8_t pl0h = vmull_s8 ( vget_high_s8 ( v0_0ls ), vget_high_s8 ( v1_0ls ) );
1363
+ int32x4_t p_0 = vdotq_s32 ( vdupq_n_s32 ( 0 ), v0_0ls , v1_0ls );
1364
+ int32x4_t p_1 = vdotq_s32 ( vdupq_n_s32 ( 0 ), v0_1ls , v1_1ls );
1365
1365
1366
- const int16x8_t ph0l = vmull_s8 (vget_low_s8 (v0_0hs ), vget_low_s8 (v1_0hs ));
1367
- const int16x8_t ph0h = vmull_s8 (vget_high_s8 (v0_0hs ), vget_high_s8 (v1_0hs ));
1368
-
1369
- const int16x8_t pl1l = vmull_s8 (vget_low_s8 (v0_1ls ), vget_low_s8 (v1_1ls ));
1370
- const int16x8_t pl1h = vmull_s8 (vget_high_s8 (v0_1ls ), vget_high_s8 (v1_1ls ));
1371
-
1372
- const int16x8_t ph1l = vmull_s8 (vget_low_s8 (v0_1hs ), vget_low_s8 (v1_1hs ));
1373
- const int16x8_t ph1h = vmull_s8 (vget_high_s8 (v0_1hs ), vget_high_s8 (v1_1hs ));
1374
-
1375
- const int16x8_t pl_0 = vaddq_s16 (pl0l , pl0h );
1376
- const int16x8_t ph_0 = vaddq_s16 (ph0l , ph0h );
1377
-
1378
- const int16x8_t pl_1 = vaddq_s16 (pl1l , pl1h );
1379
- const int16x8_t ph_1 = vaddq_s16 (ph1l , ph1h );
1380
-
1381
- const int16x8_t p_0 = vaddq_s16 (pl_0 , ph_0 );
1382
- const int16x8_t p_1 = vaddq_s16 (pl_1 , ph_1 );
1366
+ p_0 = vdotq_s32 (p_0 , v0_0hs , v1_0hs );
1367
+ p_1 = vdotq_s32 (p_1 , v0_1hs , v1_1hs );
1383
1368
1384
1369
// scalar
1385
1370
#if defined(__ARM_FEATURE_QRDMX )
1386
- sum0 += d0_0 * d1_0 * vaddvq_s16 (p_0 );
1387
- sum1 += d0_1 * d1_1 * vaddvq_s16 (p_1 );
1371
+ sum0 += d0_0 * d1_0 * vaddvq_s32 (p_0 );
1372
+ sum1 += d0_1 * d1_1 * vaddvq_s32 (p_1 );
1388
1373
#else
1389
1374
sum0 += d0_0 * d1_0 * (vgetq_lane_s16 (p_0 , 0 ) + vgetq_lane_s16 (p_0 , 1 ) + vgetq_lane_s16 (p_0 , 2 ) + vgetq_lane_s16 (p_0 , 3 ) + vgetq_lane_s16 (p_0 , 4 ) + vgetq_lane_s16 (p_0 , 5 ) + vgetq_lane_s16 (p_0 , 6 ) + vgetq_lane_s16 (p_0 , 7 ));
1390
1375
sum1 += d0_1 * d1_1 * (vgetq_lane_s16 (p_1 , 0 ) + vgetq_lane_s16 (p_1 , 1 ) + vgetq_lane_s16 (p_1 , 2 ) + vgetq_lane_s16 (p_1 , 3 ) + vgetq_lane_s16 (p_1 , 4 ) + vgetq_lane_s16 (p_1 , 5 ) + vgetq_lane_s16 (p_1 , 6 ) +
37C3
vgetq_lane_s16 (p_1 , 7 ));
0 commit comments