ggml: aarch64: Implement SVE F32 kernels for Mamba Model #13602
base: master
Conversation
Hi @ggerganov, please review this PR.
It's better to split the PR in 2 parts. First part with:
- Add SVE support for ggml_vec_dot_f32() function.
- Add SVE support for ggml_vec_mad_f32() function.
- Add SVE support for ggml_vec_scale_f32() function.
The second part with Mamba-specific changes.
For the first part, I need to see the improvement over the existing GGML_SIMD implementation using ARM_NEON, for example, which AFAIK should always be available when SVE is available.
#if defined(__ARM_FEATURE_SVE)
    const int sve_register_length = ggml_cpu_get_sve_cnt() * 8;
    const int ggml_f32_epr = sve_register_length / 32; // SVE128: 4, SVE256: 8, SVE512: 16
    const int ggml_f32_step = 2 * ggml_f32_epr;
    GGML_F32_VEC vx = GGML_F32_VEC_SET1(v);

    const int np = (n & ~(ggml_f32_step - 1));
    svfloat32_t ax1, ax2;
    svfloat32_t ay1, ay2;
    for (int i = 0; i < np; i += ggml_f32_step) {
        ax1 = GGML_F32_VEC_LOAD(x + i);
        ay1 = GGML_F32_VEC_LOAD(y + i);
        ay1 = GGML_F32_VEC_FMA(ax1, vx, ay1);

        GGML_F32_VEC_STORE(y + i, ay1);

        ax2 = GGML_F32_VEC_LOAD(x + i + 1*ggml_f32_epr);
        ay2 = GGML_F32_VEC_LOAD(y + i + 1*ggml_f32_epr);
        ay2 = GGML_F32_VEC_FMA(ax2, vx, ay2);

        GGML_F32_VEC_STORE(y + i + 1*ggml_f32_epr, ay2);
    }
    // leftovers
    // the maximum number of leftover elements will be less than ggml_f32_epr;
    // apply a predicated svmad on the available elements only
    if (np < n) {
        svbool_t pg = svwhilelt_b32(np, n);
        ax1 = svld1_f32(pg, x + np);
        ay1 = svld1_f32(pg, y + np);
        ay1 = svmad_f32_m(pg, ax1, vx, ay1);

        svst1_f32(pg, y + np, ay1);
    }
#else
How big is the benefit from these special-cased implementations compared to using the GGML_SIMD abstraction? If the benefit is not significant, it's better to use the existing implementation and avoid this extra code.
Hi @ggerganov,
The function ggml_vec_mad_f32() is called from the flash attention operation. Currently this operation is supported in some models like phi-3-4k, falcon-7B and Qwen2-7B. I have converted all of these to the F32 GGUF format, but none of the models use ggml_compute_forward_flash_attn_back_f32(); instead they use ggml_compute_forward_flash_attn_ext_f16(), which does not call ggml_vec_mad_f32(). For this reason I couldn't show performance results. However, I can assure you there will be a good benefit compared to the NEON version if a model uses this function, because we saw a speedup for ggml_vec_dot_f32(), which is similar to this function.
Will it be fine to proceed with pushing this function?
Hi @ggerganov,
Following up on the previous comment. Please let me know if you’re okay with proceeding.
Separate the changes in ops.cpp into a separate PR.
svfloat32_t ay1,ay2;
for ( int i = 0; i < np; i += ggml_f32_step) {
Suggested change:
- svfloat32_t ay1,ay2;
- for ( int i = 0; i < np; i += ggml_f32_step) {
+ svfloat32_t ay1;
+ svfloat32_t ay2;
+ for (int i = 0; i < np; i += ggml_f32_step) {
}
// leftovers
// maximum number of leftover elements will be less than ggml_f32_epr. Apply predicated svmad on available elements only
if(np<n) {
Suggested change:
- if(np<n) {
+ if (np < n) {
// maximum number of leftover elements will be less than ggml_f32_epr. Apply predicated svmad on available elements only
if(np<n) {
svbool_t pg = svwhilelt_b32(np, n);
ay1 = svld1_f32(pg, y+np);
Suggested change:
- ay1 = svld1_f32(pg, y+np);
+ ay1 = svld1_f32(pg, y + np);
}
// leftovers
// maximum number of leftover elements will be less than ggml_f32_epr. Apply predicated svmad on available elements only
if(np<n) {
Suggested change:
- if(np<n) {
+ if (np < n) {
svfloat32_t ax1,ax2;
svfloat32_t ay1,ay2;
for ( int i = 0; i < np; i += ggml_f32_step) {
Suggested change:
- svfloat32_t ax1,ax2;
- svfloat32_t ay1,ay2;
- for ( int i = 0; i < np; i += ggml_f32_step) {
+ svfloat32_t ax1, ax2;
+ svfloat32_t ay1, ay2;
+ for (int i = 0; i < np; i += ggml_f32_step) {
if(np2<n){
svbool_t pg = svwhilelt_b32(np2,n);
Suggested change:
- if(np2<n){
- svbool_t pg = svwhilelt_b32(np2,n);
+ if (np2 < n) {
+ svbool_t pg = svwhilelt_b32(np2, n);
// leftovers
// Since the loop above is unrolled 8 times, leftovers lie in the range [0, ggml_f32_step), and are handled in the loop below
const int np2 = (n & ~(ggml_f32_epr - 1));
for ( int i = np; i < np2; i += ggml_f32_epr) {
Suggested change:
- for ( int i = np; i < np2; i += ggml_f32_epr) {
+ for (int i = np; i < np2; i += ggml_f32_epr) {
svfloat32_t sum8 = svdup_n_f32(0.0f);
svfloat32_t ax1,ax2,ax3,ax4,ax5,ax6,ax7,ax8;
svfloat32_t ay1,ay2,ay3,ay4,ay5,ay6,ay7,ay8;
for ( int i = 0; i < np; i += ggml_f32_step) {
Suggested change:
- for ( int i = 0; i < np; i += ggml_f32_step) {
+ for (int i = 0; i < np; i += ggml_f32_step) {
#if defined(__ARM_FEATURE_SVE)

const int sve_register_length = ggml_cpu_get_sve_cnt() * 8;
Suggested change (remove the blank line):
- #if defined(__ARM_FEATURE_SVE)
-
- const int sve_register_length = ggml_cpu_get_sve_cnt() * 8;
+ #if defined(__ARM_FEATURE_SVE)
+ const int sve_register_length = ggml_cpu_get_sve_cnt() * 8;
#define GGML_F32xt svfloat32_t
#define GGML_F32xt_ZERO svdup_n_f32(0.0f)
#define GGML_F32xt_SET1(x) svdup_n_f32(x)
#define GGML_F32xt_LOAD_IMPL(pg, a, ...) svld1_f32(pg, a)
#define GGML_F32xt_LOAD(...) GGML_F32xt_LOAD_IMPL(DEFAULT_PG, __VA_ARGS__)
#define GGML_F32xt_STORE_IMPL(pg,a,b) svst1_f32(pg, a, b)
#define GGML_F32xt_STORE(...) GGML_F32xt_STORE_IMPL(DEFAULT_PG, __VA_ARGS__)
#define GGML_F32xt_FMA_IMPL(pg, a, b, c) svmad_f32_m(pg, a, b, c)
#define GGML_F32xt_FMA(...) GGML_F32xt_FMA_IMPL(DEFAULT_PG, __VA_ARGS__)
#define GGML_F32xt_ADD_IMPL(pg, a, b) svadd_f32_m(pg, a, b)
#define GGML_F32xt_ADD(...) GGML_F32xt_ADD_IMPL(DEFAULT_PG, __VA_ARGS__)
#define GGML_F32xt_MUL_IMPL(pg, a, b) svmul_f32_m(pg, a, b)
#define GGML_F32xt_MUL(...) GGML_F32xt_MUL_IMPL(DEFAULT_PG, __VA_ARGS__)
#define GGML_F32xt_REDUCE_ONE_IMPL(pg, a) svaddv(pg, a)
#define GGML_F32xt_REDUCE_ONE(...) GGML_F32xt_REDUCE_ONE_IMPL(DEFAULT_PG, __VA_ARGS__)
#define GGML_F32xt_REDUCE_IMPL(pg, res, sum1, sum2, sum3, sum4, sum5, sum6, sum7, sum8) \
Vertical-align these
This PR adds SVE kernel support for the F32 data type, specific to the Mamba model, on the ARM architecture.
Major code changes:
- Add SVE support for the ggml_vec_dot_f32() function.
- Add SVE support for the ggml_vec_mad_f32() function.
- Add SVE support for the ggml_vec_scale_f32() function.
- Mamba-specific changes in ops.cpp.
Performance
This PR improves performance by ~1.3x compared to the previous NEON-based implementation.
Model: falcon-mamba-7B-F32.gguf
Command: ./build/bin/llama-bench -m falcon-mamba-7B-F32.gguf -t 8,16,32,64 -p 128,1024 -n 0
Perplexity
There is no change in model accuracy as a result of this PR.
Command: ./build/bin/llama-perplexity -s 0 -np 128 -t 64 -m falcon-mamba-7B-F32.gguf -c 128 -b 128 --chunks 16 -f scripts/wikitext-2-raw/wiki.test.raw
Contributor: Vineel Abhinav Gottala
cc: @Vithulep