[Inductor-CPU] Disable auto-tuning for templated int8 WoQ GEMM for small M to fix perf regression #148502
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/148502. ✅ You can merge normally (1 unrelated failure). As of commit b08a6a3 with merge base 98458e5, one job is marked as unstable, possibly due to flakiness on trunk. This comment was automatically generated by Dr. CI and updates every 15 minutes.
@leslie-fang-intel, please advise if we should disable the templated GEMM kernel for the int8 WoQ case on machines that don't support the AMX ISA. Thanks!
Could you please share the performance comparisons on the micro and end-to-end benchmarks with various small M? I'm wondering why the template GEMM with AVX512 is slower than the AMX counterpart. The GEMM with small M is memory-intensive, not compute-intensive, so AMX cannot help much here...
It's faster, @jgong5; the idea behind this PR is to choose the ATen op `_weight_int8pack_mm` for small `M`. In this PR, I deliberately (re-)enabled AMX GEMM even for small `M`. I guess another approach would've been to disable the AVX512 GEMM micro-kernel altogether for int8 WoQ with BF16 activation, so that it wouldn't be chosen even on machines that lack AMX support (I haven't collected perf data on such machines, though). Thanks!
How much of a perf gap did we see for a specific GEMM between the ATen kernel and the AVX512 micro GEMM? For both the benchmark data and the model runtime data.
Hi @leslie-fang-intel, the rationale for this PR can be demonstrated simply by benchmarking an LLM with int8 WoQ enabled and disabled, and comparing next-token generation time. With max-autotune enabled, the only change is that templated GEMMs are used instead of `_weight_int8pack_mm`. The data you asked for is thus not necessary to prove the source of the regression, so can we follow the standard operating procedure of fixing a regression first, so that at least the regression with respect to the default (non-max-autotune) behavior is addressed? Thanks!
Force-pushed from 6753078 to 9fbd241.
Hi @leslie-fang-intel @jgong5, I modified this PR to disable auto-tuning for the int8 WoQ GEMM case for small `M`.
It looks like you have converted this PR to a draft and are currently collecting performance data. Please re-request a review when you believe the PR is ready.
I understand the problem of differing performance you saw between the ATen kernel and the templated AVX512 GEMM kernel during auto-tuning vs. end-to-end runs. It is definitely something we need to fix. But I am concerned about the way you address the problem: it is more of a temporary workaround, and this workaround promotes an even worse kernel (the AMX GEMM kernel) for small "M". If we force compilation to use the template GEMMs, we would have a regression with your fix. Can you share the perf numbers I requested so we can understand the problem better?
Discovered something surprising - the ATen kernel
@jgong5, yes it's a workaround.
Sorry, is there any plan to force use of templated GEMMs even if their ATen counterpart is more performant? Thanks!
Users can set `max_autotune_gemm_backends` to restrict auto-tuning to the templated GEMMs (excluding the ATen backend).
Thanks for the info, @leslie-fang-intel! The revision wouldn't run into that problem, but I'm currently looking into revising the auto-tuning implementation to avoid needing this workaround.
I'm fine if you disable templated codegen for this particular case and always use the ATen kernel in this PR. But feel free to improve the GEMM template in it too.
return create_epilogue_with_attr(
    buf, "mul", other=realize_inputs(expand(scale, layout.size))
)
def _use_autotuning() -> bool: |
We have a check function `use_cpp_gemm_template` that does similar things. It is called in the same function below.
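A minimal sketch of the gating idea under discussion, with hypothetical names and threshold (not the actual Inductor code; the real check would also need to handle symbolic shapes and the surrounding lowering context):

```python
# Hypothetical sketch of a gating helper like _use_autotuning(); the name,
# threshold, and signature are illustrative assumptions, not PyTorch code.
SMALL_M_CUTOFF = 32  # assumed cut-off below which auto-tuning is skipped


def _use_autotuning_sketch(m) -> bool:
    # Skip auto-tuning only when M is statically known and small, so the ATen
    # _weight_int8pack_mm choice is used; otherwise keep auto-tuning enabled.
    if isinstance(m, int) and m <= SMALL_M_CUTOFF:
        return False
    return True
```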
Summary
Described in #148494 - this PR fixes a regression (compared to the default Inductor-CPU behavior of not using max-autotune) for the templated int8 WoQ GEMM (with BF16 activation) with a small M dimension by disabling auto-tuning for small `M`, so that the ATen `_weight_int8pack_mm` kernel is used. The practically significant case is next-token generation in LLMs, where `M` is small.
Turning off auto-tuning for small `M` is a workaround. Ideally, we should improve the auto-tuning infra so that the templated AVX512 GEMM for int8 WoQ is not chosen when `_weight_int8pack_mm` would be faster E2E.

Details
During auto-tuning, the AVX512 GEMM micro-kernel is chosen for small `M` because it is faster during auto-tuning benchmarking, yet it performs worse E2E. This is expected: during benchmarking it is called several times in a loop on the same inputs, so it can exploit cache locality, an advantage that does not carry over to the end-to-end run. The same behavior isn't observed for its ATen counterpart `_weight_int8pack_mm`, which performs worse during auto-tuning benchmarking but better E2E. `_weight_int8pack_mm`, too, would have benefited from better cache locality for its inputs had it been benchmarked for a longer period; even so, the templated GEMM's latency during benchmarking would still have been lower.

Measurements attached: `_weight_int8pack_mm` latency during auto-tuning benchmarking; `_weight_int8pack_mm` latency E2E.
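As a rough illustration of the cache-locality effect described above (a sketch only: shapes, dtypes, and iteration counts are arbitrary, and this is not the auto-tuner's actual benchmarking harness), one can compare a cold call of `_weight_int8pack_mm` with repeated calls on the same inputs:

```python
import time
import torch

M, K, N = 4, 4096, 4096  # small M, as in next-token generation
x = torch.randn(M, K, dtype=torch.bfloat16)
w = torch.randint(-128, 127, (N, K), dtype=torch.int8)
scales = torch.randn(N, dtype=torch.bfloat16)

# "Cold" timing: closer to E2E usage, where inputs change between calls.
t0 = time.perf_counter()
torch._weight_int8pack_mm(x, w, scales)
cold = time.perf_counter() - t0

# "Warm" timing: repeated calls on the same inputs, as in auto-tuning
# benchmarking, where inputs are likely already resident in cache.
for _ in range(10):
    torch._weight_int8pack_mm(x, w, scales)
t0 = time.perf_counter()
for _ in range(100):
    torch._weight_int8pack_mm(x, w, scales)
warm = (time.perf_counter() - t0) / 100

print(f"cold: {cold * 1e6:.1f} us, warm: {warm * 1e6:.1f} us")
```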
UTs
python test/inductor/test_cpu_select_algorithm.py -v -k test_int8_woq_mm
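A minimal repro sketch, assuming arbitrary shapes; whether the templated GEMM or the ATen `_weight_int8pack_mm` kernel ends up being chosen depends on the auto-tuning decision this PR changes:

```python
import torch

# Small M, as in next-token generation; K/N are arbitrary illustrative sizes.
M, K, N = 4, 4096, 4096
x = torch.randn(M, K, dtype=torch.bfloat16)
w = torch.randint(-128, 127, (N, K), dtype=torch.int8)
scales = torch.randn(N, dtype=torch.bfloat16)


def woq_gemm(x, w, scales):
    # int8 weight-only-quantized GEMM with BF16 activation
    return torch._weight_int8pack_mm(x, w, scales)


eager_out = woq_gemm(x, w, scales)
compiled_out = torch.compile(woq_gemm, mode="max-autotune")(x, w, scales)
print((eager_out - compiled_out).abs().max())
```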
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov