@@ -7641,8 +7641,8 @@ static void ggml_compute_forward_ssm_scan_f32(
                 const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
                 const float * B = (const float *) ((const char *) src4->data +  i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
                 const float * C = (const float *) ((const char *) src5->data +  i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
-                      float * y = (      float *) ((      char *) dst->data  + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
-                      float * s = (      float *) ((      char *) dst->data  + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s}
+                      float * y = (      float *) ((      char *) dst->data  + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
+                      float * s = (      float *) ((      char *) dst->data  + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s}

                 // use the output as the source for the next token-wise iterations
                 if (i2 > 0) { s0 = s; }
@@ -8070,6 +8074,14 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
         #define GGML_F32X_MUL GGML_F32x16_MUL
         #define GGML_F32X_FMA GGML_F32x16_FMA
         #define WKV_VECTOR_SIZE 16
+    #elif defined(__ARM_FEATURE_SVE) && defined(__aarch64__)
+        #define GGML_F32X GGML_F32xt
+        #define GGML_F32X_SET1 GGML_F32xt_SET1
+        #define GGML_F32X_LOAD GGML_F32xt_LOAD
+        #define GGML_F32X_STORE GGML_F32xt_STORE
+        #define GGML_F32X_MUL GGML_F32xt_MUL
+        #define GGML_F32X_FMA GGML_F32xt_FMA
+        #define WKV_VECTOR_SIZE 8
     #elif defined(__ARM_NEON) && defined(__aarch64__)
         #define GGML_F32X GGML_F32x4
         #define GGML_F32X_SET1 GGML_F32x4_SET1
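For orientation (not part of the diff): the GGML_F32xt helpers that the new SVE branch maps onto are thin wrappers over Arm SVE ACLE intrinsics. A hedged sketch of what such wrappers could look like is below; the F32XT_* names are made up for illustration, and the real GGML_F32xt definitions in ggml's SIMD headers may differ, for example in how the predicate is threaded through:

```c
// Hypothetical sketch of SVE float32 wrappers, assuming <arm_sve.h>;
// the actual GGML_F32xt_* macros in ggml may be defined differently.
#include <arm_sve.h>

#define F32XT              svfloat32_t                               // scalable f32 vector type
#define F32XT_SET1(x)      svdup_n_f32(x)                            // broadcast a scalar
#define F32XT_LOAD(p)      svld1_f32(svptrue_b32(), (p))             // load a full vector
#define F32XT_STORE(p, v)  svst1_f32(svptrue_b32(), (p), (v))        // store a full vector
#define F32XT_MUL(a, b)    svmul_f32_x(svptrue_b32(), (a), (b))      // element-wise a * b
#define F32XT_FMA(a, b, c) svmla_f32_x(svptrue_b32(), (a), (b), (c)) // a + b * c
```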
@@ -8080,8 +8088,14 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
         #define WKV_VECTOR_SIZE 4
     #endif

+    int wkv_vector_size;
     #ifdef WKV_VECTOR_SIZE
-        const int64_t vec_count = head_size / WKV_VECTOR_SIZE;
+        #if defined(__ARM_FEATURE_SVE)
+            wkv_vector_size = svcntw();
+        #else
+            wkv_vector_size = WKV_VECTOR_SIZE;
+        #endif
+        const int64_t vec_count = head_size / wkv_vector_size;

         for (int64_t t = 0; t < T; t++) {
             size_t t_offset = t * t_stride;
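The reason for the extra wkv_vector_size variable: on AVX512 and NEON the number of f32 lanes is a compile-time constant, but an SVE vector's length is only known at run time, so the kernel queries it with svcntw() (the number of 32-bit lanes) and the compile-time WKV_VECTOR_SIZE 8 mainly serves to keep the #ifdef WKV_VECTOR_SIZE branch enabled. A minimal sketch of the same pattern, assuming head_size is a multiple of the lane count:

```c
// Minimal sketch: pick the f32 chunk width at run time on SVE, at compile
// time elsewhere; any remainder is left to a scalar tail loop.
#include <stdint.h>
#if defined(__ARM_FEATURE_SVE)
#include <arm_sve.h>
#endif

static int64_t f32_chunk_count(int64_t head_size) {
    int vec_size;
#if defined(__ARM_FEATURE_SVE)
    vec_size = (int) svcntw();   // 32-bit lanes per SVE vector: 4, 8, 16, ... depending on hardware
#else
    vec_size = 8;                // stand-in for the fixed-width WKV_VECTOR_SIZE case
#endif
    return head_size / vec_size; // e.g. head_size 64 with 256-bit SVE (8 lanes) gives 8 chunks
}
```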
@@ -8111,7 +8125,7 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
                     GGML_F32X time_decay_vec = GGML_F32X_SET1(time_decay_val);

                     for (int64_t j = 0; j < vec_count; j++) {
-                        size_t base_j = j * WKV_VECTOR_SIZE;
+                        size_t base_j = j * wkv_vector_size;
                         size_t t_h_j_offset = t_h_offset + base_j;
                         size_t h_2d_i_j_offset = h_2d_i_offset + base_j;

@@ -8136,7 +8150,7 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
                     }

                     // Handle remaining elements, this will not be used.
-                    for (int64_t j = vec_count * WKV_VECTOR_SIZE; j < head_size; j++) {
+                    for (int64_t j = vec_count * wkv_vector_size; j < head_size; j++) {
                         size_t t_h_j_offset = t_h_offset + j;
                         size_t h_2d_i_j_offset = h_2d_i_offset + j;
                         float v_val = v[t_h_j_offset];
@@ -8272,6 +8286,14 @@ static void ggml_compute_forward_gla_f32(
         #define GGML_F32X_MUL GGML_F32x16_MUL
         #define GGML_F32X_FMA GGML_F32x16_FMA
         #define GLA_VECTOR_SIZE 16
+    #elif defined(__ARM_FEATURE_SVE) && defined(__aarch64__)
+        #define GGML_F32X GGML_F32xt
+        #define GGML_F32X_SET1 GGML_F32xt_SET1
+        #define GGML_F32X_LOAD GGML_F32xt_LOAD
+        #define GGML_F32X_STORE GGML_F32xt_STORE
+        #define GGML_F32X_MUL GGML_F32xt_MUL
+        #define GGML_F32X_FMA GGML_F32xt_FMA
+        #define GLA_VECTOR_SIZE 8
     #elif defined(__ARM_NEON) && defined(__aarch64__)
         #define GGML_F32X GGML_F32x4
         #define GGML_F32X_SET1 GGML_F32x4_SET1
@@ -8282,8 +8304,14 @@ static void ggml_compute_forward_gla_f32(
         #define GLA_VECTOR_SIZE 4
     #endif

+    int gla_vector_size;
     #ifdef GLA_VECTOR_SIZE
-        const int64_t vec_count = head_size / GLA_VECTOR_SIZE;
+        #if defined(__ARM_FEATURE_SVE)
+            gla_vector_size = svcntw();
+        #else
+            gla_vector_size = GLA_VECTOR_SIZE;
+        #endif
+        const int64_t vec_count = head_size / gla_vector_size;

         for (int64_t t = 0; t < T; t++) {
             size_t t_offset = t * t_stride;
@@ -8310,7 +8338,7 @@ static void ggml_compute_forward_gla_f32(
                     GGML_F32X g_vec = GGML_F32X_SET1(g_val);

                     for (int64_t j = 0; j < vec_count; j++) {
-                        size_t base_j = j * GLA_VECTOR_SIZE;
+                        size_t base_j = j * gla_vector_size;
                         size_t t_h_j_offset = t_h_offset + base_j;
                         size_t h_2d_i_j_offset = h_2d_i_offset + base_j;

@@ -8334,7 +8362,7 @@ static void ggml_compute_forward_gla_f32(
                     }

                     // Handle remaining elements, this will not be used.
-                    for (int64_t j = vec_count * GLA_VECTOR_SIZE; j < head_size; j++) {
+                    for (int64_t j = vec_count * gla_vector_size; j < head_size; j++) {
                         size_t t_h_j_offset = t_h_offset + j;
                         size_t h_2d_i_j_offset = h_2d_i_offset + j;
                         float v_val = v[t_h_j_offset];
@@ -8443,83 +8471,126 @@ static void ggml_compute_forward_rwkv_wkv7_f32(
     int64_t h_stride_2d = head_size * head_size;

     #if defined(GGML_SIMD)
-        for (int64_t t = 0; t < T; t++) {
-            int64_t t_offset = t * t_stride;
-            int64_t state_offset = head_size * C * (t / (T / n_seqs));
-            float * state_cur = state + state_offset;
-            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;
-
-            for (int64_t h = h_start; h < h_end; h++) {
-                int64_t h_offset = h * h_stride;
-                int64_t t_h_offset = t_offset + h_offset;
-                int64_t h_2d_offset = h * h_stride_2d;
-
-                for (int64_t ii = 0; ii < head_size; ii++) {
-                    int64_t t_h_i_offset = t_h_offset + ii;
-                    int64_t h_2d_i_offset = h_2d_offset + ii * h_stride;
-
-                    GGML_F32_VEC v_vec = GGML_F32_VEC_SET1(v[t_h_i_offset]);
+        #if defined(__ARM_FEATURE_SVE)
+            // Route to scalar implementation   //TODO: Write SVE code
+            for (int64_t t = 0; t < T; t++) {
+                int64_t t_offset = t * t_stride;
+                int64_t state_offset = head_size * C * (t / (T / n_seqs));
+                float * state_cur = state + state_offset;
+                float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;
+
+                for (int64_t h = h_start; h < h_end; h++) {
+                    int64_t h_offset = h * h_stride;
+                    int64_t t_h_offset = t_offset + h_offset;
+                    int64_t h_2d_offset = h * h_stride_2d;
+
+                    for (int64_t i = 0; i < head_size; i++) {
+                        int64_t t_h_i_offset = t_h_offset + i;
+                        int64_t h_2d_i_offset = h_2d_offset + i * h_stride;
+
+                        float v_val = v[t_h_i_offset];
+
+                        float sa = 0, result = 0;
+                        for (int64_t j = 0; j < head_size; j++) {
+                            sa += a[t_h_offset + j] * state_prev[h_2d_i_offset + j];
+                        }

-                    float sa = 0;
-                    {
-                        GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
-                        GGML_F32_VEC ax[GGML_F32_ARR];
-                        GGML_F32_VEC ay[GGML_F32_ARR];
-                        for (int64_t j = 0; j < head_size; j += GGML_F32_STEP) {
-                            for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
-                                ax[kk] = GGML_F32_VEC_LOAD(&a[t_h_offset + j + kk * GGML_F32_EPR]);
-                                ay[kk] = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_offset + j + kk * GGML_F32_EPR]);
-                                sum[kk] = GGML_F32_VEC_FMA(sum[kk], ax[kk], ay[kk]);
-                            }
+                        for (int64_t j = 0; j < head_size; j++) {
+                            int64_t t_h_j_offset = t_h_offset + j;
+                            int64_t h_2d_i_j_offset = h_2d_i_offset + j;
+
+                            float r_val = r[t_h_j_offset];
+                            float w_val = w[t_h_j_offset];
+                            float k_val = k[t_h_j_offset];
+                            float b_val = b[t_h_j_offset];
+                            float kv_val = v_val * k_val;
+                            float prev_state_val = state_prev[h_2d_i_j_offset];
+                            state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
+                            result += state_cur[h_2d_i_j_offset] * r_val;
                         }
-                        GGML_F32_VEC_REDUCE(sa, sum);
+                        dst_data[t_h_i_offset] = result;
                     }
+                }
+            }
+        #else
+            for (int64_t t = 0; t < T; t++) {
+                int64_t t_offset = t * t_stride;
+                int64_t state_offset = head_size * C * (t / (T / n_seqs));
+                float * state_cur = state + state_offset;
+                float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;
+
+                for (int64_t h = h_start; h < h_end; h++) {
+                    int64_t h_offset = h * h_stride;
+                    int64_t t_h_offset = t_offset + h_offset;
+                    int64_t h_2d_offset = h * h_stride_2d;
+
+                    for (int64_t ii = 0; ii < head_size; ii++) {
+                        int64_t t_h_i_offset = t_h_offset + ii;
+                        int64_t h_2d_i_offset = h_2d_offset + ii * h_stride;
+
+                        GGML_F32_VEC v_vec = GGML_F32_VEC_SET1(v[t_h_i_offset]);
+
+                        float sa = 0;
+                        {
+                            GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
+                            GGML_F32_VEC ax[GGML_F32_ARR];
+                            GGML_F32_VEC ay[GGML_F32_ARR];
+                            for (int64_t j = 0; j < head_size; j += GGML_F32_STEP) {
+                                for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
+                                    ax[kk] = GGML_F32_VEC_LOAD(&a[t_h_offset + j + kk * GGML_F32_EPR]);
+                                    ay[kk] = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_offset + j + kk * GGML_F32_EPR]);
+                                    sum[kk] = GGML_F32_VEC_FMA(sum[kk], ax[kk], ay[kk]);
+                                }
+                            }
+                            GGML_F32_VEC_REDUCE(sa, sum);
+                        }

-                    GGML_F32_VEC sa_vec = GGML_F32_VEC_SET1(sa);
+                        GGML_F32_VEC sa_vec = GGML_F32_VEC_SET1(sa);

-                    int64_t j = 0;
-                    GGML_F32_VEC result_vec[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
-                    for (; j < head_size; j += GGML_F32_STEP) {
-                        for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
-                            int64_t t_h_j_offset = t_h_offset + j + kk * GGML_F32_EPR;
-                            int64_t h_2d_i_j_offset = h_2d_i_offset + j + kk * GGML_F32_EPR;
+                        int64_t j = 0;
+                        GGML_F32_VEC result_vec[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
+                        for (; j < head_size; j += GGML_F32_STEP) {
+                            for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
+                                int64_t t_h_j_offset = t_h_offset + j + kk * GGML_F32_EPR;
+                                int64_t h_2d_i_j_offset = h_2d_i_offset + j + kk * GGML_F32_EPR;

-                            GGML_F32_VEC r_vec = GGML_F32_VEC_LOAD(&r[t_h_j_offset]);
-                            GGML_F32_VEC w_vec = GGML_F32_VEC_LOAD(&w[t_h_j_offset]);
-                            GGML_F32_VEC k_vec = GGML_F32_VEC_LOAD(&k[t_h_j_offset]);
-                            GGML_F32_VEC b_vec = GGML_F32_VEC_LOAD(&b[t_h_j_offset]);
+                                GGML_F32_VEC r_vec = GGML_F32_VEC_LOAD(&r[t_h_j_offset]);
+                                GGML_F32_VEC w_vec = GGML_F32_VEC_LOAD(&w[t_h_j_offset]);
+                                GGML_F32_VEC k_vec = GGML_F32_VEC_LOAD(&k[t_h_j_offset]);
+                                GGML_F32_VEC b_vec = GGML_F32_VEC_LOAD(&b[t_h_j_offset]);

-                            k_vec = GGML_F32_VEC_MUL(v_vec, k_vec);
+                                k_vec = GGML_F32_VEC_MUL(v_vec, k_vec);

-                            GGML_F32_VEC state_vec = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_j_offset]);
-                            // kv + s * decay + sa * b
-                            state_vec = GGML_F32_VEC_FMA(k_vec, state_vec, w_vec);
-                            state_vec = GGML_F32_VEC_FMA(state_vec, sa_vec, b_vec);
-                            GGML_F32_VEC_STORE(&state_cur[h_2d_i_j_offset], state_vec);
+                                GGML_F32_VEC state_vec = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_j_offset]);
+                                // kv + s * decay + sa * b
+                                state_vec = GGML_F32_VEC_FMA(k_vec, state_vec, w_vec);
+                                state_vec = GGML_F32_VEC_FMA(state_vec, sa_vec, b_vec);
+                                GGML_F32_VEC_STORE(&state_cur[h_2d_i_j_offset], state_vec);

-                            result_vec[kk] = GGML_F32_VEC_FMA(result_vec[kk], state_vec, r_vec);
+                                result_vec[kk] = GGML_F32_VEC_FMA(result_vec[kk], state_vec, r_vec);
+                            }
+                        }
+                        GGML_F32_VEC_REDUCE(dst_data[t_h_i_offset], result_vec);
+
+                        // There shouldn't be left-overs though.
+                        for (; j < head_size; j++) {
+                            int64_t t_h_j_offset = t_h_offset + j;
+                            int64_t h_2d_i_j_offset = h_2d_i_offset + j;
+
+                            float r_val = r[t_h_j_offset];
+                            float w_val = w[t_h_j_offset];
+                            float k_val = k[t_h_j_offset];
+                            float b_val = b[t_h_j_offset];
+                            float kv_val = v[t_h_i_offset] * k_val;
+
+                            float prev_state_val = state_prev[h_2d_i_j_offset];
+                            state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
+                            dst_data[t_h_i_offset] += state_cur[h_2d_i_j_offset] * r_val;
                         }
-                    }
-                    GGML_F32_VEC_REDUCE(dst_data[t_h_i_offset], result_vec);
-
-                    // There shouldn't be left-overs though.
-                    for (; j < head_size; j++) {
-                        int64_t t_h_j_offset = t_h_offset + j;
-                        int64_t h_2d_i_j_offset = h_2d_i_offset + j;
-
-                        float r_val = r[t_h_j_offset];
-                        float w_val = w[t_h_j_offset];
-                        float k_val = k[t_h_j_offset];
-                        float b_val = b[t_h_j_offset];
-                        float kv_val = v[t_h_i_offset] * k_val;
-
-                        float prev_state_val = state_prev[h_2d_i_j_offset];
-                        state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
-                        dst_data[t_h_i_offset] += state_cur[h_2d_i_j_offset] * r_val;
                     }
                 }
             }
-        }
+        #endif
     #else
         for (int64_t t = 0; t < T; t++) {
             int64_t t_offset = t * t_stride;
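The wkv7 kernel routes SVE builds to the scalar loops for now (the //TODO above). For what it's worth, the inner sa accumulation is a plain dot product over head_size elements, and one possible SVE version is sketched below; this is a hypothetical illustration of how the TODO might be tackled, not code from this commit, and the helper name dot_f32_sve is made up:

```c
// Hypothetical SVE sketch of the "sa" dot product in wkv7 (not part of this
// commit, which keeps the scalar fallback). Assumes <arm_sve.h>.
#include <arm_sve.h>
#include <stdint.h>

static float dot_f32_sve(const float * x, const float * y, int64_t n) {
    svfloat32_t acc = svdup_n_f32(0.0f);
    for (int64_t i = 0; i < n; i += (int64_t) svcntw()) {
        svbool_t pg = svwhilelt_b32((uint64_t) i, (uint64_t) n);  // masks off the tail
        svfloat32_t vx = svld1_f32(pg, x + i);
        svfloat32_t vy = svld1_f32(pg, y + i);
        acc = svmla_f32_m(pg, acc, vx, vy);                       // acc += vx * vy on active lanes
    }
    return svaddv_f32(svptrue_b32(), acc);                        // horizontal sum
}

// Usage in the inner loop would then be roughly:
//     float sa = dot_f32_sve(&a[t_h_offset], &state_prev[h_2d_i_offset], head_size);
```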