[Inductor-CPU] Disable auto-tuning for templated int8 WoQ GEMM for small M to fix perf regression #148502


Conversation

sanchitintel
Collaborator
@sanchitintel sanchitintel commented Mar 4, 2025

Summary

As described in #148494, this PR fixes a regression (relative to the default Inductor-CPU behavior of not using max-autotune) for the templated int8 WoQ GEMM with BF16 activation and a small M dimension: auto-tuning is disabled for small M so that the ATen _weight_int8pack_mm kernel is used.
The case that matters in practice is next-token generation in LLMs, where M is small.

Turning off auto-tuning for small M is a workaround. Ideally, we should improve the auto-tuning infrastructure so that the templated AVX512 GEMM for int8 WoQ is not chosen when _weight_int8pack_mm would be faster end-to-end (E2E).

Details

During auto-tuning, the AVX512 GEMM micro-kernel is chosen for small M: it is faster in the auto-tuning benchmark, yet performs worse E2E. This is expected, since the benchmark calls the kernel repeatedly on the same inputs in a loop, letting it exploit cache locality for those inputs. The ATen counterpart _weight_int8pack_mm shows the opposite pattern: it performs worse during auto-tuning benchmarking but better E2E. It, too, would have benefited from cache locality for inputs had it been benchmarked for a longer period, but even then the templated GEMM's benchmark latency would still have been lower. The table below compares the latencies; a small timing sketch after the table illustrates the cache effect.

| M | N | K | Templated GEMM latency (auto-tuning benchmarking) | Templated GEMM latency (E2E) | _weight_int8pack_mm latency (auto-tuning benchmarking) | _weight_int8pack_mm latency (E2E) | E2E latency ratio: templated GEMM / _weight_int8pack_mm |
|---|---|---|---|---|---|---|---|
| 1 | 4096 | 4096 | 31.2 us | 91.1 us | 108.7 us | 76.07 us | 1.19 |
| 1 | 1024 | 4096 | 16.1 us | 33.36 us | 52.9 us | 24.275 us | 1.37 |
| 1 | 14336 | 4096 | 112.8 us | 274.16 us | 335.3 us | 233.197 us | 1.17 |
| 1 | 4096 | 14336 | 128.1 us | 280.76 us | 330 us | 237.797 us | 1.18 |
| 1 | 4096 | 128256 | 1.642 ms | 2.16 ms | 2.118 ms | 2.034 ms | 1.06 |
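
To make the cache-locality argument concrete, here is a minimal, standalone timing sketch. It is illustrative only: it uses a plain BF16 torch.mm rather than the WoQ kernel, and the shapes and iteration counts are arbitrary.

```python
# Illustrative only (not Inductor's benchmarking code): timing a GEMM
# repeatedly on one cache-hot weight (as auto-tuning does) vs. rotating
# through several weights, which is closer to an E2E run over many layers.
import time
import torch

M, N, K = 1, 4096, 4096
x = torch.randn(M, K, dtype=torch.bfloat16)

def bench_us(weights, iters=200):
    torch.mm(x, weights[0].t())  # warm-up
    t0 = time.perf_counter()
    for i in range(iters):
        torch.mm(x, weights[i % len(weights)].t())
    return (time.perf_counter() - t0) / iters * 1e6

hot = [torch.randn(N, K, dtype=torch.bfloat16)]                     # one reused weight
cold = [torch.randn(N, K, dtype=torch.bfloat16) for _ in range(8)]  # rotating weights

print(f"single reused weight: {bench_us(hot):.1f} us/call")
print(f"rotating weights:     {bench_us(cold):.1f} us/call")
```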

UTs

python test/inductor/test_cpu_select_algorithm.py -v -k test_int8_woq_mm
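
Beyond the unit test, a hedged repro sketch along these lines exercises the same path manually (the shapes, dtypes, and the direct call to torch._weight_int8pack_mm are illustrative assumptions, not code from this PR):

```python
# Sketch: int8 WoQ GEMM with BF16 activation compiled under max-autotune.
# Small M (here M=1) corresponds to next-token generation in LLMs.
import torch

M, K, N = 1, 4096, 4096
x = torch.randn(M, K, dtype=torch.bfloat16)              # BF16 activation
w = torch.randint(-128, 127, (N, K), dtype=torch.int8)   # int8 weight
scales = torch.rand(N, dtype=torch.bfloat16)             # per-channel scales (assumed dtype)

def woq_mm(x, w, scales):
    return torch._weight_int8pack_mm(x, w, scales)

compiled = torch.compile(woq_mm, mode="max-autotune")
print(compiled(x, w, scales).shape)  # torch.Size([1, 4096])
```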

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov

pytorch-bot bot commented Mar 4, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/148502

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit b08a6a3 with merge base 98458e5:

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@sanchitintel sanchitintel added the ciflow/trunk and topic: not user facing labels Mar 4, 2025
@sanchitintel
Collaborator Author

@leslie-fang-intel, please advise whether we should disable the templated GEMM kernel for the int8 WoQ case on machines that don't support the AMX ISA. Thanks!

Collaborator
@jgong5 jgong5 left a comment


Could you please share performance comparisons for the micro-benchmarks and the end-to-end benchmarks with various small M? I'm wondering why the templated GEMM with AVX512 is slower than its AMX counterpart. A GEMM with small M is memory-bound, not compute-bound, so AMX cannot help much here...

@sanchitintel
Collaborator Author
sanchitintel commented Mar 5, 2025

> why the template GEMM with AVX512 is slower than AMX counterpart

It's faster

@jgong5, the idea behind this PR is to choose ATen op _weight_int8pack_mm for small M during auto-tuning because the ATen kernel performs better than the templated AVX512 GEMM end-to-end, although the templated AVX512 GEMM performs better during auto-tuning benchmarking. We had discussed this issue last year as well.

In this PR, I deliberately (re-)enabled AMX GEMM even for small M because it's quite slow, so the ATen kernel (_weight_int8pack_mm) would be chosen over it during auto-tuning.

I guess another approach would've been to disable the AVX512 GEMM micro-kernel altogether for int8 WoQ with BF16 activation, so that it wouldn't be chosen even on machines that lack AMX support (I haven't collected perf data on such machines, though).
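
As a conceptual picture of why this works (this is not Inductor's actual selection code; the candidate names and timings below are made up): auto-tuning benchmarks every candidate and keeps the fastest, so if the only template candidate offered for small M is a slow AMX one, the ATen fallback wins the selection.

```python
# Conceptual illustration of autotune selection; not Inductor's real code.
from typing import Callable, Dict

def pick_kernel(candidates: Dict[str, Callable[[], float]]) -> str:
    # Benchmark every candidate and keep the one with the lowest latency.
    timings = {name: bench() for name, bench in candidates.items()}
    return min(timings, key=timings.get)

# Hypothetical numbers: with only the slow AMX template in the mix for
# small M, the ATen kernel wins the selection.
print(pick_kernel({
    "aten._weight_int8pack_mm": lambda: 108.7,  # us
    "cpp_template_amx": lambda: 500.0,          # us
}))  # -> aten._weight_int8pack_mm
```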

Thanks!

@leslie-fang-intel
Collaborator

One reason for this behavior could be different cache behavior with respect to the input data (auto-tuning benchmarking vs. the E2E model run). A secondary reason could be that the ATen kernel is generic, whereas many codegened GEMMs are created; the codegened kernels are less i-cache friendly (there is only one _weight_int8pack_mm but many codegened kernels), and the i-cache misses are not amortized by the speedup the codegened kernels bring in this case.

How much of a perf gap did we see for a specific GEMM between the ATen kernel and the AVX512 micro GEMM, for both the benchmark data and the model-runtime data?

@sanchitintel
Collaborator Author
sanchitintel commented Mar 5, 2025

Hi @leslie-fang-intel, the rationale for this PR can be demonstrated simply by benchmarking an int8 WoQ LLM with max-autotune enabled and disabled, and comparing next-token generation time. With max-autotune enabled, the only change is that templated GEMMs are used instead of _weight_int8pack_mm, so they are the only differentiating factor.

The data you asked for is thus not necessary to establish the source of the regression, so can we follow the standard operating procedure of fixing a regression first, so that at least the regression relative to max-autotune being disabled (the default Inductor-CPU behavior) is fixed? Would it be okay to move the debugging discussion to the linked issue #148494?

Thanks!

@sanchitintel sanchitintel requested a review from jgong5 March 5, 2025 11:21
@sanchitintel sanchitintel requested review from mingfeima and removed request for chunyuan-w March 5, 2025 11:46
@sanchitintel sanchitintel changed the title from "[Inductor-CPU] Let AMX ISA be chosen for int8 WoQ GEMM for any input shapes" to "[Inductor-CPU] Fix perf regression for templated int8 WoQ GEMM for small M dimension" Mar 5, 2025
@sanchitintel sanchitintel force-pushed the int8_woq_gemm_use_amx_isa_if_available branch from 6753078 to 9fbd241 Compare March 5, 2025 19:30
linux-foundation-easycla bot commented Mar 5, 2025

CLA Signed


The committers listed above are authorized under a signed CLA.

@sanchitintel
Collaborator Author
sanchitintel commented Mar 5, 2025

Hi @leslie-fang-intel @jgong5, I modified this PR to disable auto-tuning for the int8 WoQ GEMM case when M < 32 (the templated GEMM is slower for M < 32, irrespective of whether the AVX512 or AMX micro-kernel is used; we may need to refine this heuristic, perhaps by also considering K and N, so I'll gather data to support it, although end-to-end performance may differ from the micro-benchmark data). For M >= 32, templated GEMM kernels based on the AMX micro-kernel would be used on machines that support the AMX ISA.
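
A minimal sketch of the kind of small-M guard being described (the function name, threshold constant, and placement are hypothetical; the actual change in torch/_inductor may look different):

```python
# Hypothetical illustration of the small-M guard; not the literal code from this PR.
_SMALL_M_THRESHOLD = 32  # below this, fall back to the ATen _weight_int8pack_mm kernel

def should_autotune_int8_woq_gemm(m: int) -> bool:
    # For M < 32 the templated GEMM (whether the AVX512 or AMX micro-kernel
    # is used) loses end-to-end, so skip auto-tuning for those shapes.
    return m >= _SMALL_M_THRESHOLD
```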

@sanchitintel sanchitintel marked this pull request as draft March 5, 2025 23:35
Collaborator
@leslie-fang-intel leslie-fang-intel left a comment


It looks like you have converted this PR to a draft and are currently collecting performance data. Please re-request a review when you believe the PR is ready.

@jgong5
Collaborator
jgong5 commented Mar 6, 2025

> > why the template GEMM with AVX512 is slower than AMX counterpart
>
> It's faster
>
> @jgong5, the idea behind this PR is to choose ATen op _weight_int8pack_mm for small M during auto-tuning because the ATen kernel performs better than the templated AVX512 GEMM end-to-end, although the templated AVX512 GEMM performs better during auto-tuning benchmarking. We had discussed this issue last year as well.
>
> In this PR, I deliberately (re-)enabled AMX GEMM even for small M because it's quite slow, so the ATen kernel (_weight_int8pack_mm) would be chosen over it during auto-tuning.
>
> I guess another approach would've been to disable the AVX512 GEMM micro-kernel altogether for int8 WoQ with BF16 activation, so that it wouldn't be chosen even on machines that lack AMX support (I haven't collected perf data on such machines, though).
>
> Thanks!

I understand the problem of differing performance that you saw between the ATen kernel and the templated AVX512 GEMM kernel during auto-tuning vs. end-to-end runs. It is definitely something we need to fix. But I am concerned about the way you address the problem: it is more of a temporary workaround, and it promotes an even worse kernel (the AMX GEMM kernel) for small M. If we force compilation to use the template GEMMs, we would see a regression with your fix. Can you share the perf numbers I requested so we understand the problem better?

@sanchitintel
Collaborator Author
sanchitintel commented Mar 6, 2025

> Can you share the perf numbers I requested so we understand the problem better?

Discovered something surprising: the ATen kernel _weight_int8pack_mm performs better E2E than it does during auto-tuning benchmarking! Its E2E performance is also better than that of the templated GEMM with the AVX512 micro-kernel.
The trend is similar for other small Ms (I've benchmarked M = 1 to 8 so far); I'll post that data soon.

| M | N | K | Templated GEMM latency (auto-tuning benchmarking) | Templated GEMM latency (E2E) | _weight_int8pack_mm latency (auto-tuning benchmarking) | _weight_int8pack_mm latency (E2E) | E2E latency ratio: templated GEMM / _weight_int8pack_mm |
|---|---|---|---|---|---|---|---|
| 1 | 4096 | 4096 | 31.2 us | 91.1 us | 108.7 us | 76.07 us | 1.19 |
| 1 | 1024 | 4096 | 16.1 us | 33.36 us | 52.9 us | 24.275 us | 1.37 |
| 1 | 14336 | 4096 | 112.8 us | 274.16 us | 335.3 us | 233.197 us | 1.17 |
| 1 | 4096 | 14336 | 128.1 us | 280.76 us | 330 us | 237.797 us | 1.18 |
| 1 | 4096 | 128256 | 1.642 ms | 2.16 ms | 2.118 ms | 2.034 ms | 1.06 |

> But I am concerned about the way you address the problem: it is more of a temporary workaround, and it promotes an even worse kernel (the AMX GEMM kernel) for small M.

@jgong5, yes, it's a workaround: _weight_int8pack_mm would be used instead of the templated GEMM for small M.
Also, yesterday I changed this PR to disable auto-tuning altogether for the int8 WoQ case with small M.
Previously, this PR deliberately used the poorly performing AMX micro-kernel based templated GEMM during auto-tuning, so that the templated GEMM wouldn't be chosen after auto-tuning.

> If we force compilation to use the template GEMMs, we would see a regression with your fix.

Sorry, is there any plan to force use of templated GEMMs even if their ATen counterpart is more performant?
If not, then we wouldn't have had a regression.

Thanks!

@sanchitintel sanchitintel changed the title from "[Inductor-CPU] Fix perf regression for templated int8 WoQ GEMM for small M dimension" to "[Inductor-CPU] Disable auto-tuning for templated int8 WoQ GEMM for small M to fix perf regression" Mar 6, 2025
@leslie-fang-intel
Collaborator

> Sorry, is there any plan to force use of templated GEMMs even if their ATen counterpart is more performant?
> If not, then we wouldn't have had a regression.

Users can set max_autotune_gemm_backends="CPP" to force its use.
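
For reference, a minimal sketch of forcing the C++ template backend (set via torch._inductor.config; the surrounding usage is illustrative):

```python
# Restrict Inductor's GEMM auto-tuning to the C++ template backend only,
# so the ATen kernel no longer participates in the selection.
import torch
import torch._inductor.config as inductor_config

inductor_config.max_autotune = True
inductor_config.max_autotune_gemm_backends = "CPP"

# Afterwards, compile the model as usual, e.g.:
# compiled_model = torch.compile(model, mode="max-autotune")
```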

@sanchitintel
Collaborator Author
sanchitintel commented Mar 7, 2025

> Users can set max_autotune_gemm_backends="CPP" to force its use.

Thanks for the info, @leslie-fang-intel! The current revision wouldn't run into that problem, but I'm looking into revising the auto-tuning implementation so that this workaround isn't needed.

Collaborator
@jgong5 jgong5 left a comment


I'm fine if you disable templated codegen for this particular case and always use the ATen kernel in this PR. But feel free to improve the GEMM template in this PR too.

return create_epilogue_with_attr(
buf, "mul", other=realize_inputs(expand(scale, layout.size))
)
def _use_autotuning() -> bool:
Collaborator


We have a check function use_cpp_gemm_template that does similar things. It is called in the same function below.
