-
Notifications
You must be signed in to change notification settings - Fork 24.3k
[Inductor][CPU] Add torchao da8w8 pattern with sym quantized act & wgt #142110
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/142110
Note: Links to docs will display an error until the docs builds have been completed. ✅ You can merge normally! (2 Unrelated Failures)As of commit 604e4ec with merge base b31d3b2 ( FLAKY - The following job failed but was likely due to flakiness present on trunk:
UNSTABLE - The following job failed but was likely due to flakiness present on trunk and has been marked as unstable:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
Will rebase with the latest main-branch to avoid a CI failure that's also been happening in other PRs - |
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
pytorch#142110) ### Summary Extends pytorch#142036 for Inductor pattern-matching pattern covered for torchao API `int8_dynamic_activation_int8_weight` in the following scenario (inference-only, freezing enabled) - - int8 quantized (symmetrically) activation (per token quantized). - Statically (so, scales are also constant. But then they would have been constant even in case of dynamic quantization due to constant weights, anyway) per-channel int8 quantized (symmetrically) weights (which are also constant because freezing is enabled). The pattern that's matched is `torch._intmm` -> convert to FP32/BF16 -> [optional expand for activation scale] ->`mul` -> `mul`. We don't check if the activation is dynamically quantized or whether the weights are statically quantized, though (since the implementation won't have have any side-effects even if that wouldn't be true). In practice, it also matches the smooth-quant int8 quantized linear pattern if its output is not reshaped (if activation is 2D). ### More details oneDNN int8 matmul supports application of per-channel weight scale but not a vector activation scale, which could be applied as a post op, but is currently unsupported in ATen. Bias addition (which could be supported with an add post-op) is also unfused. The fusion pattern used in this PR is `torch._intmm` -> convert to FP32/BF16 ->`mul`, which will be replaced by oneDNN qlinear op. The speedup over eager-mode is due to 2 reasons - 1. fusion of int8xint8 -> int32 GEMM, conversion to FP32/BF16 & application of weight scale. (In case of BF16, many intermediate conversions are also avoided). 2. weight is pre-packed & cached by Inductor, so a reorder is avoided at run-time. But, in the future, the whole pattern (including application of activation scale, which would be a mul post-op) + bias could be fused if corresponding support would be enabled in ATen. ### Verification Added UT in this PR ``` python test/inductor/test_mkldnn_pattern_matcher.py -v -k test_da8w8_sym_act_sym_wgt_with_int_mm ``` #### Corresponding torchao UTs 1. int8 Smoothquant legacy API - `TORCHINDUCTOR_FREEZING=1 TORCH_COMPILE_DEBUG=1 TORCH_LOGS="+inductor" python test/integration/test_integration.py -v -k test_non_dynamically_quantizable_linear`. The difference from pytorch#139595 is that there are no reshapes of the linear output in this pattern. 2. int8 da8w8 - symmetrically quantized activation (dynamically) & statically quantized weights - ` TORCH_COMPILE_DEBUG=1 TORCH_LOGS="+inductor" TORCHINDUCTOR_FREEZING=1 python test/integration/test_integration.py -v -k test_int8_dynamic_quant_subclass_api_0_cpu` Pull Request resolved: pytorch#142110 Approved by: https://github.com/leslie-fang-intel, https://github.com/jgong5 ghstack dependencies: pytorch#142036
Stack from ghstack (oldest at bottom):
Summary
Extends #142036 for Inductor pattern-matching pattern covered for torchao API
int8_dynamic_activation_int8_weight
in the following scenario (inference-only, freezing enabled) -The pattern that's matched is
torch._intmm
-> convert to FP32/BF16 -> [optional expand for activation scale] ->mul
->mul
.We don't check if the activation is dynamically quantized or whether the weights are statically quantized, though (since the implementation won't have have any side-effects even if that wouldn't be true).
In practice, it also matches the smooth-quant int8 quantized linear pattern if its output is not reshaped (if activation is 2D).
More details
oneDNN int8 matmul supports application of per-channel weight scale but not a vector activation scale, which could be applied as a post op, but is currently unsupported in ATen. Bias addition (which could be supported with an add post-op) is also unfused.
The fusion pattern used in this PR is
torch._intmm
-> convert to FP32/BF16 ->mul
, which will be replaced by oneDNN qlinear op.The speedup over eager-mode is due to 2 reasons -
But, in the future, the whole pattern (including application of activation scale, which would be a mul post-op) + bias could be fused if corresponding support would be enabled in ATen.
Verification
Added UT in this PR
Corresponding torchao UTs< 8000 /h4>-
-
int8 Smoothquant legacy API -
TORCHINDUCTOR_FREEZING=1 TORCH_COMPILE_DEBUG=1 TORCH_LOGS="+inductor" python test/integration/test_integration.py -v -k test_non_dynamically_quantizable_linear
.The difference from [Inductor][CPU] Fuse SmoothQuant int8 linear pattern #139595 is that there are no reshapes of the linear output in this pattern.
int8 da8w8 - symmetrically quantized activation (dynamically) & statically quantized weights -
TORCH_COMPILE_DEBUG=1 TORCH_LOGS="+inductor" TORCHINDUCTOR_FREEZING=1 python test/integration/test_integration.py -v -k test_int8_dynamic_quant_subclass_api_0_cpu
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @aakhundov