Enable qint8 and quint8 add for AArch64 using ACL directly #148653
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/148653
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (1 Unrelated Failure)
As of commit 8be99ae with merge base 46f096b:
UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
This enables qint8 and quint8 add for AArch64 through Arm Compute Library (ACL) directly. Relative performance improvement using OMP_NUM_THREADS=1 is ~15x; using OMP_NUM_THREADS=32 it's ~5.4x.

Co-authored-by: David Svantesson <david.svantesson-yeung@arm.com>
Change-Id: I2d55977b1ac9e19e67058dc31d7b88fc386f1b5f
ghstack-source-id: b01e2e7
Pull Request resolved: #148653
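For context, a minimal sketch (not from this PR; shapes, scales, and zero points are arbitrary) of the eager-mode quantized add that this change targets:

```
import torch

# Quantize two fp32 tensors to qint8 (the PR covers quint8 as well).
x = torch.quantize_per_tensor(torch.randn(2048, 2048), 0.1, 0, torch.qint8)
y = torch.quantize_per_tensor(torch.randn(2048, 2048), 0.1, 0, torch.qint8)

# Elementwise quantized add with an explicit output scale and zero point;
# per the description above, AArch64 builds with ACL should dispatch this to ACL directly.
out = torch.ops.quantized.add(x, y, 0.2, 0)
print(out.dtype, out.shape)
```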
@pytorchbot label "module: arm" |
@pytorchbot label "ciflow/linux-aarch64" |
The two failures here have to be due to a flaky trunk. Please see how all the tests are passing in the original PR that this ghstack PR was created from (verbatim).
Failure above is unrelated.
@pytorchbot merge -f "Seems fine, provided it gets removed before next release and integrated thru OneDNN"
Merge started. Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Learn more about merging in the wiki.
Questions? Feedback? Please reach out to the PyTorch DevX Team
Enable qint8 and quint8 add for AArch64 using ACL directly (#148653)

This enables qint8 and quint8 add for AArch64 through Arm Compute Library (ACL) directly. Relative performance improvement using OMP_NUM_THREADS=1 is ~15x; using OMP_NUM_THREADS=32 it's ~5.4x.

Co-authored-by: David Svantesson <david.svantesson-yeung@arm.com>
Pull Request resolved: pytorch#148653
Approved by: https://github.com/malfet
ghstack dependencies: pytorch#148585
(cherry picked from commit 6c2db8f)
Enable fast qlinear static/dynamic path for AArch64 through ACL directly (#149435)

* Enable fast qlinear static/dynamic path for AArch64 through ACL directly (#148585)

This enables a fast path for eager mode static/dynamic quantization for AArch64 through Arm Compute Library (ACL) directly.

Context: PRs #126687 and #139887 enabled an optimized implementation for `qlinear` and `qlinear_dynamic` for aarch64 through `ideep → oneDNN → ACL`, which improved performance by ~10x compared to the previous implementation. However, the current `qlinear` and `qlinear_dynamic` path (`ideep → oneDNN → ACL`) suffers from high overhead due to the API friction between the stateless oneDNN API and the stateful ACL low-precision GEMM (`lowp_gemm`) API. For example, ACL's `lowp_gemm` objects cache information like weights reduction or weights in optimized memory format, which oneDNN does not allow due to its stateless nature. Hence, ACL currently runs a (redundant) sum of columns and pre-transposition (to the GEMM kernel's optimal format) for each GEMM operation. This PR addresses the sub-optimalities above by integrating ACL directly with `qlinear` and `qlinear_dynamic`.

- **For `qlinear_dynamic` (dynamically quantized matmuls):** This PR yields an **average speedup** (averaged over context lengths of 2^3 up to 2^9) of **~50%** for `bert-base-uncased`, `bert-large-uncased`, `roberta-base`, and `distilbert-base-uncased` with 16 threads on a Neoverse-V1 (with transformers==4.48) for the benchmarking script below:

```
# SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliate <open-source-office@arm.com>
# SPDX-License-Identifier: BSD-3-Clause
import time
from argparse import ArgumentParser

import numpy as np
import torch
from transformers import AutoConfig, AutoModel


class ModelArgumentParser(ArgumentParser):
    def __init__(self) -> None:
        super().__init__(description="huggingface model")
        self.add_argument("--context_length",
                          help="context length - number of input tokens",
                          type=int, default=64)
        self.add_argument("--model",
                          help="model checkpoint - i.e. 'bert-base-uncased'",
                          type=str, default=None)
        self.add_argument("--iters", help="benchmark iterations", default=500)


if __name__ == "__main__":
    parser = ModelArgumentParser()
    args = parser.parse_args()

    model_name = args.model
    config = AutoConfig.from_pretrained(model_name)
    batch_size = 1
    model = AutoModel.from_pretrained(model_name)
    model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
    model.eval()

    inputs = torch.randint(config.vocab_size,
                           (batch_size, args.context_length),
                           dtype=torch.long, device="cpu")
    times = []
    with torch.no_grad():
        # warmup
        for _ in range(10):
            model(inputs)
        # benchmark
        for _ in range(args.iters):
            s = time.time_ns()
            model(inputs)
            times.append((time.time_ns() - s) / 1e6)

    print("Model = ", model_name)
    print("Context Length = ", args.context_length)
    print("Min (ms) = ", min(times))
    print("Mean (ms) = ", np.mean(times))
```

- **For `qlinear` (statically quantized matmuls):** This PR yields an **average speedup of 2x for signed activations (`s8s8s8`) and 95x for unsigned activations (`u8s8u8`)** on a Neoverse-V1 with 16 threads for the benchmarking script below. The averages are over all combinations of `M = [8, 16, ..., 512]`, `K = [768, 1024, 2048, 4096]`, `N = [768, 1024, 2048, 4096]`. The astronomical speedup for unsigned activations is because oneDNN v3.7 does not have an optimized implementation for `u8s8u8` on AArch64.
```
# SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliate <open-source-office@arm.com>
# SPDX-License-Identifier: BSD-3-Clause
import random
import time
from argparse import ArgumentParser

import numpy as np
import torch
import torch.nn as nn
from torch.ao.quantization.observer import HistogramObserver, default_weight_observer
from torch.quantization import QConfig


class ModelArgumentParser(ArgumentParser):
    def __init__(self) -> None:
        super().__init__()
        self.add_argument("--M", help="M dimension", type=int, default=64)
        self.add_argument("--K", help="K dimension", type=int, default=64)
        self.add_argument("--N", help="N dimension", type=int, default=64)
        self.add_argument("--signed_input",
                          help="Use (signed) torch.qint8 for inputs instead of (unsigned) torch.quint8",
                          action="store_true")
        self.add_argument("--seed", help="Random seed", type=int, default=42)
        self.add_argument("--iters", help="benchmark iterations", default=500)


def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


class LinearModel(nn.Module):
    def __init__(self, K, N):
        super(LinearModel, self).__init__()
        self.quant = torch.quantization.QuantStub()
        self.fc = nn.Linear(K, N)
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.fc(x)
        x = self.dequant(x)
        return x


def quantize_model(model, args):
    qconfig = QConfig(
        activation=HistogramObserver.with_args(
            reduce_range=False,
            dtype=torch.qint8 if args.signed_input else torch.quint8),
        weight=default_weight_observer,
    )
    # Prepare the model for static quantization with the specified configuration
    model.qconfig = qconfig
    model_prepared = torch.quantization.prepare(model)
    # Calibrate the model with sample inputs
    with torch.no_grad():
        sample_data = torch.randn(args.M, args.K)
        model_prepared(sample_data)
    # Convert the prepared model to a quantized model
    model_quantized = torch.quantization.convert(model_prepared)
    return model_quantized


if __name__ == "__main__":
    parser = ModelArgumentParser()
    args = parser.parse_args()

    set_seed(args.seed)
    model_fp32 = LinearModel(args.K, args.N)
    model_quantized = quantize_model(model_fp32, args)

    inputs = torch.randn(args.M, args.K)
    times = []
    with torch.no_grad():
        # warmup
        for _ in range(10):
            model_quantized(inputs)
        # benchmark
        for _ in range(args.iters):
            s = time.time_ns()
            model_quantized(inputs)
            times.append((time.time_ns() - s) / 1e6)

    print("M,K,N,signed = ", args.M, args.K, args.N, args.signed_input)
    print("Min Times (ms) = ", min(times))
    print("Mean Times (ms) = ", np.mean(times))
```

Pull Request resolved: #148585
Approved by: https://github.com/malfet
Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
(cherry picked from commit 08a644a)

* Enable qint8 and quint8 add for AArch64 using ACL directly (#148653)

This enables qint8 and quint8 add for AArch64 through Arm Compute Library (ACL) directly. Relative performance improvement using OMP_NUM_THREADS=1 is ~15x; using OMP_NUM_THREADS=32 it's ~5.4x.

Co-authored-by: David Svantesson <david.svantesson-yeung@arm.com>
Pull Request resolved: #148653
Approved by: https://github.com/malfet
ghstack dependencies: #148585
(cherry picked from commit 6c2db8f)

* [Build] Guard per-op headers in ACLUtils.cpp (#149417)

To fix internal build failures, where per-op headers are not generated. We really should have a lint for something like that.

Test Plan: CI
Reviewed By: izaitsevfb
Differential Revision: D71406882
Pull Request resolved: #149417
Approved by: https://github.com/Skylion007, https://github.com/izaitsevfb
(cherry picked from commit 5db3a4a)

---------

Co-authored-by: Nikita Shulga <nshulga@meta.com>
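As a smaller companion to the benchmark scripts above, a minimal sketch (mine, not from the PR; sizes are arbitrary) that exercises the dynamic quantized linear path on a single layer:

```
import torch

# Dynamically quantize one Linear layer; on AArch64 builds with ACL enabled,
# running it should exercise the qlinear_dynamic path optimized above.
fp32_layer = torch.nn.Linear(1024, 1024)
qlayer = torch.quantization.quantize_dynamic(fp32_layer, {torch.nn.Linear}, dtype=torch.qint8)

out = qlayer(torch.randn(8, 1024))
print(out.shape)
```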
Stack from ghstack (oldest at bottom):
This enables qint8 and quint8 add for AArch64 through Arm Compute Library (ACL) directly.
Relative performance improvement using OMP_NUM_THREADS=1 is ~15x; using OMP_NUM_THREADS=32 it's ~5.4x.
Co-authored-by: David Svantesson <david.svantesson-yeung@arm.com>
cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @malfet @snadampal @milpuz01
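For anyone wanting to sanity-check the threading numbers above, a rough timing sketch (my own, not from the PR; `torch.set_num_threads` stands in for `OMP_NUM_THREADS`, and shapes/scales are arbitrary):

```
import time
import torch

def bench_qadd(threads, iters=200, n=2048):
    torch.set_num_threads(threads)  # stand-in for setting OMP_NUM_THREADS
    x = torch.quantize_per_tensor(torch.randn(n, n), 0.1, 0, torch.qint8)
    y = torch.quantize_per_tensor(torch.randn(n, n), 0.1, 0, torch.qint8)
    for _ in range(10):  # warmup
        torch.ops.quantized.add(x, y, 0.2, 0)
    start = time.time_ns()
    for _ in range(iters):
        torch.ops.quantized.add(x, y, 0.2, 0)
    return (time.time_ns() - start) / iters / 1e6  # ms per iteration

for t in (1, 32):
    print(f"threads={t}: {bench_qadd(t):.3f} ms/iter")
```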