8000 10% performance boost on ARM · psy2013GitHub/llama.cpp@113a9e8 · GitHub
[go: up one dir, main page]

Skip to content

Commit 113a9e8

Browse files
committed
10% performance boost on ARM
1 parent 404fac0 commit 113a9e8

File tree

1 file changed

+6
-21
lines changed

1 file changed

+6
-21
lines changed

ggml.c

Lines changed: 6 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1360,31 +1360,16 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
13601360
const int8x16_t v1_1hs = vsubq_s8(v1_1h, s8b);
13611361

13621362
// dot product into int16x8_t
1363-
const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
1364-
const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
1363+
int32x4_t p_0 = vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls);
1364+
int32x4_t p_1 = vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls);
13651365

1366-
const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs));
1367-
const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs));
1368-
1369-
const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls));
1370-
const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls));
1371-
1372-
const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
1373-
const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
1374-
1375-
const int16x8_t pl_0 = vaddq_s16(pl0l, pl0h);
1376-
const int16x8_t ph_0 = vaddq_s16(ph0l, ph0h);
1377-
1378-
const int16x8_t pl_1 = vaddq_s16(pl1l, pl1h);
1379-
const int16x8_t ph_1 = vaddq_s16(ph1l, ph1h);
1380-
1381-
const int16x8_t p_0 = vaddq_s16(pl_0, ph_0);
1382-
const int16x8_t p_1 = vaddq_s16(pl_1, ph_1);
1366+
p_0 = vdotq_s32(p_0, v0_0hs, v1_0hs);
1367+
p_1 = vdotq_s32(p_1, v0_1hs, v1_1hs);
13831368

13841369
// scalar
13851370
#if defined(__ARM_FEATURE_QRDMX)
1386-
sum0 += d0_0*d1_0*vaddvq_s16(p_0);
1387-
sum1 += d0_1*d1_1*vaddvq_s16(p_1);
1371+
sum0 += d0_0*d1_0*vaddvq_s32(p_0);
1372+
sum1 += d0_1*d1_1*vaddvq_s32(p_1);
13881373
#else
13891374
sum0 += d0_0*d1_0*(vgetq_lane_s16(p_0, 0) + vgetq_lane_s16(p_0, 1) + vgetq_lane_s16(p_0, 2) + vgetq_lane_s16(p_0, 3) + vgetq_lane_s16(p_0, 4) + vgetq_lane_s16(p_0, 5) + vgetq_lane_s16(p_0, 6) + vgetq_lane_s16(p_0, 7));
13901375
sum1 += d0_1*d1_1*(vgetq_lane_s16(p_1, 0) + vgetq_lane_s16(p_1, 1) + vgetq_lane_s16(p_1, 2) + vgetq_lane_s16(p_1, 3) + vgetq_lane_s16(p_1, 4) + vgetq_lane_s16(p_1, 5) + vgetq_lane_s16(p_1, 6) + 37C3 vgetq_lane_s16(p_1, 7));

0 commit comments

Comments
 (0)
0