ggml: aarch64: Implement SVE F32 kernels for vector functions (#13843) · ggml-org/llama.cpp@1b8fb81 · GitHub

Commit 1b8fb81

ggml: aarch64: Implement SVE F32 kernels for vector functions (#13843)
* F32-Mamba-SVE
* F32-Mamba-SVE
* Resolve test errors-1
* Resolve test errors-2
* F32-vec-SVE
* F32-vec-SVE
* F32-vec-SVE
1 parent 53ae306 commit 1b8fb81

File tree

4 files changed: +513 -138 lines changed
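
The diff below adds an SVE branch (GGML_F32xt) alongside the existing AVX512 and NEON ones. Because the SVE vector length is implementation defined, the kernels ask the hardware for the number of 32-bit lanes at runtime via svcntw() instead of hard-coding a width. As orientation, here is a minimal, self-contained sketch of that pattern written for this page rather than taken from the commit; the function name and the plain FMA loop are illustrative only.

#include <stdint.h>
#if defined(__ARM_FEATURE_SVE)
#include <arm_sve.h>

// y[i] += a[i] * x[i] over n floats, with a predicated tail (illustrative only).
static void fma_f32_sve(float * y, const float * a, const float * x, int64_t n) {
    const int64_t vl = (int64_t) svcntw();         // 32-bit lanes per SVE register, known only at runtime
    for (int64_t i = 0; i < n; i += vl) {
        svbool_t    pg = svwhilelt_b32_s64(i, n);  // mask off lanes past n
        svfloat32_t va = svld1_f32(pg, a + i);
        svfloat32_t vx = svld1_f32(pg, x + i);
        svfloat32_t vy = svld1_f32(pg, y + i);
        vy = svmla_f32_x(pg, vy, va, vx);          // vy += va * vx
        svst1_f32(pg, y + i, vy);
    }
}
#endif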

ggml/src/ggml-cpu/ops.cpp

Lines changed: 143 additions & 72 deletions
@@ -7641,8 +7641,8 @@ static void ggml_compute_forward_ssm_scan_f32(
             const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
             const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
             const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
-            float * y = ( float *) (( char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
-            float * s = ( float *) (( char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s}
+            float * y = ( float *) (( char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
+            float * s = ( float *) (( char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s}

             // use the output as the source for the next token-wise iterations
             if (i2 > 0) { s0 = s; }
@@ -8070,6 +8070,14 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
     #define GGML_F32X_MUL GGML_F32x16_MUL
     #define GGML_F32X_FMA GGML_F32x16_FMA
     #define WKV_VECTOR_SIZE 16
+#elif defined(__ARM_FEATURE_SVE) && defined(__aarch64__)
+    #define GGML_F32X GGML_F32xt
+    #define GGML_F32X_SET1 GGML_F32xt_SET1
+    #define GGML_F32X_LOAD GGML_F32xt_LOAD
+    #define GGML_F32X_STORE GGML_F32xt_STORE
+    #define GGML_F32X_MUL GGML_F32xt_MUL
+    #define GGML_F32X_FMA GGML_F32xt_FMA
+    #define WKV_VECTOR_SIZE 8
 #elif defined(__ARM_NEON) && defined(__aarch64__)
     #define GGML_F32X GGML_F32x4
     #define GGML_F32X_SET1 GGML_F32x4_SET1
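
The GGML_F32xt_* names that the new branch maps onto are defined in the other files touched by this commit (ggml's SIMD mapping header), which are not shown on this page. Purely to illustrate the abstraction, a plausible SVE-backed macro family could look like the sketch below; the names, argument order, and use of an all-true predicate are assumptions made here, not the commit's actual definitions.

#if defined(__ARM_FEATURE_SVE)
#include <arm_sve.h>

// Hypothetical width-agnostic F32 macro family backed by SVE intrinsics.
typedef svfloat32_t F32X_T;                                               // sizeless vector of f32
#define F32X_SET1(x)         svdup_n_f32(x)                               // broadcast a scalar
#define F32X_LOAD(p)         svld1_f32(svptrue_b32(), (p))                // full-width load
#define F32X_STORE(p, v)     svst1_f32(svptrue_b32(), (p), (v))           // full-width store
#define F32X_MUL(a, b)       svmul_f32_x(svptrue_b32(), (a), (b))
#define F32X_FMA(acc, a, b)  svmla_f32_x(svptrue_b32(), (acc), (a), (b))  // acc + a*b
#endif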
@@ -8080,8 +8088,14 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
     #define WKV_VECTOR_SIZE 4
 #endif

+    int wkv_vector_size;
 #ifdef WKV_VECTOR_SIZE
-    const int64_t vec_count = head_size / WKV_VECTOR_SIZE;
+    #if defined(__ARM_FEATURE_SVE)
+        wkv_vector_size = svcntw();
+    #else
+        wkv_vector_size = WKV_VECTOR_SIZE;
+    #endif
+    const int64_t vec_count = head_size / wkv_vector_size;

     for (int64_t t = 0; t < T; t++) {
         size_t t_offset = t * t_stride;
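
Note that WKV_VECTOR_SIZE stays at a compile-time value of 8 (a 256-bit assumption) so the existing #ifdef machinery still works, but the stride actually used is whatever svcntw() reports on the running CPU. A small standalone check of that value, not part of llama.cpp, could look like this:

#include <stdio.h>
#include <stdint.h>
#if defined(__ARM_FEATURE_SVE)
#include <arm_sve.h>
#endif

int main(void) {
    const int64_t head_size = 64;               // example value only
#if defined(__ARM_FEATURE_SVE)
    const int64_t width = (int64_t) svcntw();   // runtime SVE lane count: 4 on 128-bit hardware, 16 on 512-bit
#else
    const int64_t width = 4;                    // e.g. a NEON-style fixed width
#endif
    printf("width = %lld, vec_count = %lld\n",
           (long long) width, (long long) (head_size / width));
    return 0;
}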
@@ -8111,7 +8125,7 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
             GGML_F32X time_decay_vec = GGML_F32X_SET1(time_decay_val);

             for (int64_t j = 0; j < vec_count; j++) {
-                size_t base_j = j * WKV_VECTOR_SIZE;
+                size_t base_j = j * wkv_vector_size;
                 size_t t_h_j_offset = t_h_offset + base_j;
                 size_t h_2d_i_j_offset = h_2d_i_offset + base_j;

@@ -8136,7 +8150,7 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
             }

             // Handle remaining elements, this will not be used.
-            for (int64_t j = vec_count * WKV_VECTOR_SIZE; j < head_size; j++) {
+            for (int64_t j = vec_count * wkv_vector_size; j < head_size; j++) {
                 size_t t_h_j_offset = t_h_offset + j;
                 size_t h_2d_i_j_offset = h_2d_i_offset + j;
                 float v_val = v[t_h_j_offset];
@@ -8272,6 +8286,14 @@ static void ggml_compute_forward_gla_f32(
     #define GGML_F32X_MUL GGML_F32x16_MUL
     #define GGML_F32X_FMA GGML_F32x16_FMA
     #define GLA_VECTOR_SIZE 16
+#elif defined(__ARM_FEATURE_SVE) && defined(__aarch64__)
+    #define GGML_F32X GGML_F32xt
+    #define GGML_F32X_SET1 GGML_F32xt_SET1
+    #define GGML_F32X_LOAD GGML_F32xt_LOAD
+    #define GGML_F32X_STORE GGML_F32xt_STORE
+    #define GGML_F32X_MUL GGML_F32xt_MUL
+    #define GGML_F32X_FMA GGML_F32xt_FMA
+    #define GLA_VECTOR_SIZE 8
 #elif defined(__ARM_NEON) && defined(__aarch64__)
     #define GGML_F32X GGML_F32x4
     #define GGML_F32X_SET1 GGML_F32x4_SET1
@@ -8282,8 +8304,14 @@ static void ggml_compute_forward_gla_f32(
     #define GLA_VECTOR_SIZE 4
 #endif

+    int gla_vector_size;
 #ifdef GLA_VECTOR_SIZE
-    const int64_t vec_count = head_size / GLA_VECTOR_SIZE;
+    #if defined(__ARM_FEATURE_SVE)
+        gla_vector_size = svcntw();
+    #else
+        gla_vector_size = GLA_VECTOR_SIZE;
+    #endif
+    const int64_t vec_count = head_size / gla_vector_size;

     for (int64_t t = 0; t < T; t++) {
         size_t t_offset = t * t_stride;
@@ -8310,7 +8338,7 @@ static void ggml_compute_forward_gla_f32(
             GGML_F32X g_vec = GGML_F32X_SET1(g_val);

             for (int64_t j = 0; j < vec_count; j++) {
-                size_t base_j = j * GLA_VECTOR_SIZE;
+                size_t base_j = j * gla_vector_size;
                 size_t t_h_j_offset = t_h_offset + base_j;
                 size_t h_2d_i_j_offset = h_2d_i_offset + base_j;

@@ -8334,7 +8362,7 @@ static void ggml_compute_forward_gla_f32(
             }

             // Handle remaining elements, this will not be used.
-            for (int64_t j = vec_count * GLA_VECTOR_SIZE; j < head_size; j++) {
+            for (int64_t j = vec_count * gla_vector_size; j < head_size; j++) {
                 size_t t_h_j_offset = t_h_offset + j;
                 size_t h_2d_i_j_offset = h_2d_i_offset + j;
                 float v_val = v[t_h_j_offset];
@@ -8443,83 +8471,126 @@ static void ggml_compute_forward_rwkv_wkv7_f32(
     int64_t h_stride_2d = head_size * head_size;

 #if defined(GGML_SIMD)
-    for (int64_t t = 0; t < T; t++) {
-        int64_t t_offset = t * t_stride;
-        int64_t state_offset = head_size * C * (t / (T / n_seqs));
-        float * state_cur = state + state_offset;
-        float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;
-
-        for (int64_t h = h_start; h < h_end; h++) {
-            int64_t h_offset = h * h_stride;
-            int64_t t_h_offset = t_offset + h_offset;
-            int64_t h_2d_offset = h * h_stride_2d;
-
-            for (int64_t ii = 0; ii < head_size; ii++) {
-                int64_t t_h_i_offset = t_h_offset + ii;
-                int64_t h_2d_i_offset = h_2d_offset + ii * h_stride;
-
-                GGML_F32_VEC v_vec = GGML_F32_VEC_SET1(v[t_h_i_offset]);
+    #if defined(__ARM_FEATURE_SVE)
+        // scalar Route to scalar implementation //TODO: Write SVE code
+        for (int64_t t = 0; t < T; t++) {
+            int64_t t_offset = t * t_stride;
+            int64_t state_offset = head_size * C * (t / (T / n_seqs));
+            float * state_cur = state + state_offset;
+            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;
+
+            for (int64_t h = h_start; h < h_end; h++) {
+                int64_t h_offset = h * h_stride;
+                int64_t t_h_offset = t_offset + h_offset;
+                int64_t h_2d_offset = h * h_stride_2d;
+
+                for (int64_t i = 0; i < head_size; i++) {
+                    int64_t t_h_i_offset = t_h_offset + i;
+                    int64_t h_2d_i_offset = h_2d_offset + i * h_stride;
+
+                    float v_val = v[t_h_i_offset];
+
+                    float sa = 0, result = 0;
+                    for (int64_t j = 0; j < head_size; j++) {
+                        sa += a[t_h_offset + j] * state_prev[h_2d_i_offset + j];
+                    }

-                float sa = 0;
-                {
-                    GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
-                    GGML_F32_VEC ax[GGML_F32_ARR];
-                    GGML_F32_VEC ay[GGML_F32_ARR];
-                    for (int64_t j = 0; j < head_size; j += GGML_F32_STEP) {
-                        for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
-                            ax[kk] = GGML_F32_VEC_LOAD(&a[t_h_offset + j + kk * GGML_F32_EPR]);
-                            ay[kk] = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_offset + j + kk * GGML_F32_EPR]);
-                            sum[kk] = GGML_F32_VEC_FMA(sum[kk], ax[kk], ay[kk]);
-                        }
+                    for (int64_t j = 0; j < head_size; j++) {
+                        int64_t t_h_j_offset = t_h_offset + j;
+                        int64_t h_2d_i_j_offset = h_2d_i_offset + j;
+
+                        float r_val = r[t_h_j_offset];
+                        float w_val = w[t_h_j_offset];
+                        float k_val = k[t_h_j_offset];
+                        float b_val = b[t_h_j_offset];
+                        float kv_val = v_val * k_val;
+                        float prev_state_val = state_prev[h_2d_i_j_offset];
+                        state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
+                        result += state_cur[h_2d_i_j_offset] * r_val;
                     }
-                    GGML_F32_VEC_REDUCE(sa, sum);
+                    dst_data[t_h_i_offset] = result;
                 }
+            }
+        }
+    #else
+        for (int64_t t = 0; t < T; t++) {
+            int64_t t_offset = t * t_stride;
+            int64_t state_offset = head_size * C * (t / (T / n_seqs));
+            float * state_cur = state + state_offset;
+            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;
+
+            for (int64_t h = h_start; h < h_end; h++) {
+                int64_t h_offset = h * h_stride;
+                int64_t t_h_offset = t_offset + h_offset;
+                int64_t h_2d_offset = h * h_stride_2d;
+
+                for (int64_t ii = 0; ii < head_size; ii++) {
+                    int64_t t_h_i_offset = t_h_offset + ii;
+                    int64_t h_2d_i_offset = h_2d_offset + ii * h_stride;
+
+                    GGML_F32_VEC v_vec = GGML_F32_VEC_SET1(v[t_h_i_offset]);
+
+                    float sa = 0;
+                    {
+                        GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
+                        GGML_F32_VEC ax[GGML_F32_ARR];
+                        GGML_F32_VEC ay[GGML_F32_ARR];
+                        for (int64_t j = 0; j < head_size; j += GGML_F32_STEP) {
+                            for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
+                                ax[kk] = GGML_F32_VEC_LOAD(&a[t_h_offset + j + kk * GGML_F32_EPR]);
+                                ay[kk] = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_offset + j + kk * GGML_F32_EPR]);
+                                sum[kk] = GGML_F32_VEC_FMA(sum[kk], ax[kk], ay[kk]);
+                            }
+                        }
+                        GGML_F32_VEC_REDUCE(sa, sum);
+                    }

-                GGML_F32_VEC sa_vec = GGML_F32_VEC_SET1(sa);
+                    GGML_F32_VEC sa_vec = GGML_F32_VEC_SET1(sa);

-                int64_t j = 0;
-                GGML_F32_VEC result_vec[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
-                for (; j < head_size; j += GGML_F32_STEP) {
-                    for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
-                        int64_t t_h_j_offset = t_h_offset + j + kk * GGML_F32_EPR;
-                        int64_t h_2d_i_j_offset = h_2d_i_offset + j + kk * GGML_F32_EPR;
+                    int64_t j = 0;
+                    GGML_F32_VEC result_vec[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
+                    for (; j < head_size; j += GGML_F32_STEP) {
+                        for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
+                            int64_t t_h_j_offset = t_h_offset + j + kk * GGML_F32_EPR;
+                            int64_t h_2d_i_j_offset = h_2d_i_offset + j + kk * GGML_F32_EPR;

-                        GGML_F32_VEC r_vec = GGML_F32_VEC_LOAD(&r[t_h_j_offset]);
-                        GGML_F32_VEC w_vec = GGML_F32_VEC_LOAD(&w[t_h_j_offset]);
-                        GGML_F32_VEC k_vec = GGML_F32_VEC_LOAD(&k[t_h_j_offset]);
-                        GGML_F32_VEC b_vec = GGML_F32_VEC_LOAD(&b[t_h_j_offset]);
+                            GGML_F32_VEC r_vec = GGML_F32_VEC_LOAD(&r[t_h_j_offset]);
+                            GGML_F32_VEC w_vec = GGML_F32_VEC_LOAD(&w[t_h_j_offset]);
+                            GGML_F32_VEC k_vec = GGML_F32_VEC_LOAD(&k[t_h_j_offset]);
+                            GGML_F32_VEC b_vec = GGML_F32_VEC_LOAD(&b[t_h_j_offset]);

-                        k_vec = GGML_F32_VEC_MUL(v_vec, k_vec);
+                            k_vec = GGML_F32_VEC_MUL(v_vec, k_vec);

-                        GGML_F32_VEC state_vec = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_j_offset]);
-                        // kv + s * decay + sa * b
-                        state_vec = GGML_F32_VEC_FMA(k_vec, state_vec, w_vec);
-                        state_vec = GGML_F32_VEC_FMA(state_vec, sa_vec, b_vec);
-                        GGML_F32_VEC_STORE(&state_cur[h_2d_i_j_offset], state_vec);
+                            GGML_F32_VEC state_vec = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_j_offset]);
+                            // kv + s * decay + sa * b
+                            state_vec = GGML_F32_VEC_FMA(k_vec, state_vec, w_vec);
+                            state_vec = GGML_F32_VEC_FMA(state_vec, sa_vec, b_vec);
+                            GGML_F32_VEC_STORE(&state_cur[h_2d_i_j_offset], state_vec);

-                        result_vec[kk] = GGML_F32_VEC_FMA(result_vec[kk], state_vec, r_vec);
+                            result_vec[kk] = GGML_F32_VEC_FMA(result_vec[kk], state_vec, r_vec);
+                        }
+                    }
+                    GGML_F32_VEC_REDUCE(dst_data[t_h_i_offset], result_vec);
+
+                    // There shouldn't be left-overs though.
+                    for (; j < head_size; j++) {
+                        int64_t t_h_j_offset = t_h_offset + j;
+                        int64_t h_2d_i_j_offset = h_2d_i_offset + j;
+
+                        float r_val = r[t_h_j_offset];
+                        float w_val = w[t_h_j_offset];
+                        float k_val = k[t_h_j_offset];
+                        float b_val = b[t_h_j_offset];
+                        float kv_val = v[t_h_i_offset] * k_val;
+
+                        float prev_state_val = state_prev[h_2d_i_j_offset];
+                        state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
+                        dst_data[t_h_i_offset] += state_cur[h_2d_i_j_offset] * r_val;
                     }
-                }
-                GGML_F32_VEC_REDUCE(dst_data[t_h_i_offset], result_vec);
-
-                // There shouldn't be left-overs though.
-                for (; j < head_size; j++) {
-                    int64_t t_h_j_offset = t_h_offset + j;
-                    int64_t h_2d_i_j_offset = h_2d_i_offset + j;
-
-                    float r_val = r[t_h_j_offset];
-                    float w_val = w[t_h_j_offset];
-                    float k_val = k[t_h_j_offset];
-                    float b_val = b[t_h_j_offset];
-                    float kv_val = v[t_h_i_offset] * k_val;
-
-                    float prev_state_val = state_prev[h_2d_i_j_offset];
-                    state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
-                    dst_data[t_h_i_offset] += state_cur[h_2d_i_j_offset] * r_val;
                 }
             }
         }
-    }
+    #endif
 #else
     for (int64_t t = 0; t < T; t++) {
         int64_t t_offset = t * t_stride;
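
For rwkv_wkv7 the commit deliberately routes SVE builds to the scalar loop and leaves the vector version as a TODO. As a rough idea of what that TODO could become, here is a hedged sketch of the inner dot product sa written with predicated SVE intrinsics; it is not part of the commit and the helper name is made up.

#if defined(__ARM_FEATURE_SVE)
#include <arm_sve.h>
#include <stdint.h>

// sa = sum over j of a[j] * state_prev[j], as in the scalar fallback above (sketch only).
static inline float wkv7_sa_dot_sve(const float * a, const float * state_prev, int64_t head_size) {
    svfloat32_t sum = svdup_n_f32(0.0f);
    const int64_t vl = (int64_t) svcntw();
    for (int64_t j = 0; j < head_size; j += vl) {
        svbool_t    pg = svwhilelt_b32_s64(j, head_size);  // predicate also covers a ragged tail
        svfloat32_t va = svld1_f32(pg, a + j);
        svfloat32_t vs = svld1_f32(pg, state_prev + j);
        sum = svmla_f32_m(pg, sum, va, vs);                // inactive lanes keep their old value
    }
    return svaddv_f32(svptrue_b32(), sum);                 // horizontal reduction across lanes
}
#endif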

0 commit comments