8000 Enable qint8 and quint8 add for AArch64 using ACL directly by fadara01 · Pull Request #148653 · pytorch/pytorch · GitHub
[go: up one dir, main page]

Skip to content

Enable qint8 and quint8 add for AArch64 using ACL directly #148653

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 7 commits into from

Conversation

fadara01
Copy link
Collaborator
@fadara01 fadara01 commented Mar 6, 2025

Stack from ghstack (oldest at bottom):

This enables qint8 and quint8 add for AArch64 through Arm Compute Library (ACL) directly.
Relative performance improvement using OMP_NUM_THREADS=1 is ~15x, using OMP_NUM_THREADS=32 it’s ~5.4x.

Co-authored-by: David Svantesson david.svantesson-yeung@arm.com
cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @malfet @snadampal @milpuz01

[ghstack-poisoned]
Copy link
pytorch-bot bot commented Mar 6, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/148653

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 8be99ae with merge base 46f096b (image):

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added module: cpu CPU specific problem (e.g., perf, algorithm) release notes: quantization release notes category labels Mar 6, 2025
fadara01 added a commit that referenced this pull request Mar 6, 2025
This enables qint8 and quint8 add for AArch64 through Arm Compute Library (ACL) directly.
Relative performance improvement using OMP_NUM_THREADS=1 is ~15x, using OMP_NUM_THREADS=32 it’s ~5.4x.

Co-authored-by: David Svantesson <david.svantesson-yeung@arm.com>
Change-Id: I2d55977b1ac9e19e67058dc31d7b88fc386f1b5f

ghstack-source-id: b01e2e7
Pull Request resolved: #148653
@fadara01
Copy link
Collaborator Author
fadara01 commented Mar 6, 2025

@pytorchbot label "module: arm"

@pytorch-bot pytorch-bot bot added the module: arm Related to ARM architectures builds of PyTorch. Includes Apple M1 label Mar 6, 2025
@fadara01
Copy link
Collaborator Author
fadara01 commented Mar 6, 2025

@pytorchbot label "ciflow/linux-aarch64"

@pytorch-bot pytorch-bot bot added the ciflow/linux-aarch64 linux aarch64 CI workflow label Mar 6, 2025
Update
8000
[ghstack-poisoned]
fadara01 added a commit that referenced this pull request Mar 6, 2025
This enables qint8 and quint8 add for AArch64 through Arm Compute Library (ACL) directly.
Relative performance improvement using OMP_NUM_THREADS=1 is ~15x, using OMP_NUM_THREADS=32 it’s ~5.4x.

Co-authored-by: David Svantesson <david.svantesson-yeung@arm.com>
Change-Id: I2d55977b1ac9e19e67058dc31d7b88fc386f1b5f

ghstack-source-id: 2453ea5
Pull Request resolved: #148653
@fadara01
Copy link
Collaborator Author

The two failures here have be due to a flakey trunk. Please see how all the tests are passing in the original PR, this ghstack PR was created from (verbatim)

[ghstack-poisoned]
fadara01 added a commit that referenced this pull request Mar 12, 2025
This enables qint8 and quint8 add for AArch64 through Arm Compute Library (ACL) directly.
Relative performance improvement using OMP_NUM_THREADS=1 is ~15x, using OMP_NUM_THREADS=32 it’s ~5.4x.

Co-authored-by: David Svantesson <david.svantesson-yeung@arm.com>
Change-Id: I2d55977b1ac9e19e67058dc31d7b88fc386f1b5f

ghstack-source-id: dd16790
Pull Request resolved: #148653
[ghstack-poisoned]
fadara01 added a commit that referenced this pull request Mar 12, 2025
This enables qint8 and quint8 add for AArch64 through Arm Compute Library (ACL) directly.
Relative performance improvement using OMP_NUM_THREADS=1 is ~15x, using OMP_NUM_THREADS=32 it’s ~5.4x.

Co-authored-by: David Svantesson <david.svantesson-yeung@arm.com>

ghstack-source-id: 407339a
Pull Request resolved: #148653
[ghstack-poisoned]
fadara01 added a commit that referenced this pull request Mar 12, 2025
This enables qint8 and quint8 add for AArch64 through Arm Compute Library (ACL) directly.
Relative performance improvement using OMP_NUM_THREADS=1 is ~15x, using OMP_NUM_THREADS=32 it’s ~5.4x.

Co-authored-by: David Svantesson <david.svantesson-yeung@arm.com>

ghstack-source-id: 840e43d
Pull Request resolved: #148653
[ghstack-poisoned]
fadara01 added a commit that referenced this pull request Mar 12, 2025
This enables qint8 and quint8 add for AArch64 through Arm Compute Library (ACL) directly.
Relative performance improvement using OMP_NUM_THREADS=1 is ~15x, using OMP_NUM_THREADS=32 it’s ~5.4x.

Co-authored-by: David Svantesson <david.svantesson-yeung@arm.com>

ghstack-source-id: 4b9fe53
Pull Request resolved: #148653
[ghstack-poisoned]
fadara01 added a commit that referenced this pull request Mar 17, 2025
This enables qint8 and quint8 add for AArch64 through Arm Compute Library (ACL) directly.
Relative performance improvement using OMP_NUM_THREADS=1 is ~15x, using OMP_NUM_THREADS=32 it’s ~5.4x.

Co-authored-by: David Svantesson <david.svantesson-yeung@arm.com>

ghstack-source-id: 882f236
Pull Request resolved: #148653
@fadara01
Copy link
Collaborator Author

Failure above is unrelated

ERROR: No matching distribution found for torchtune==0.6.0.dev20250115

@malfet
Copy link
Contributor
malfet commented Mar 18, 2025

@pytorchbot merge -f "Seems fine, provided it gets removed before next release and integrated thru OneDNN"

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

fadara01 added a commit to fadara01/pytorch that referenced this pull request Mar 18, 2025
…48653)

This enables qint8 and quint8 add for AArch64 through Arm Compute Library (ACL) directly.
Relative performance improvement using OMP_NUM_THREADS=1 is ~15x, using OMP_NUM_THREADS=32 it’s ~5.4x.

Co-authored-by: David Svantesson <david.svantesson-yeung@arm.com>

Pull Request resolved: pytorch#148653
Approved by: https://github.com/malfet
ghstack dependencies: pytorch#148585

(cherry picked from commit 6c2db8f)
pytorch-bot bot pushed a commit that referenced this pull request Mar 27, 2025
This enables qint8 and quint8 add for AArch64 through Arm Compute Library (ACL) directly.
Relative performance improvement using OMP_NUM_THREADS=1 is ~15x, using OMP_NUM_THREADS=32 it’s ~5.4x.

Co-authored-by: David Svantesson <david.svantesson-yeung@arm.com>

Pull Request resolved: #148653
Approved by: https://github.com/malfet
ghstack dependencies: #148585

(cherry picked from commit 6c2db8f)
malfet added a commit that referenced this pull request Mar 28, 2025
…ough ACL directly. (#149435)

* Enable fast qlinear static/dynamic path for AArch64 through ACL directly (#148585)

This enables a fast path for eager mode static/dynamic quantization for AArch64 through Arm Compute Library (ACL) directly.

Context: PRs #126687, #139887 enabled an optimized implementation for `qlinear` and `qlinear_dynamic` for aarch64 through `ideep → oneDNN → ACL` which improved performance by ~10x compared to the previous implementation.
However, the current `qlinear` and `qlinear_dynamic` path (`ideep → oneDNN → ACL`) suffers from high overhead due to the API friction between the stateless oneDNN API and the stateful ACL low-precision GEMM (`lowp_gemm`) API - for example, ACL's `lowp_gemm` objects cache information like weights reduction or weights in optimized memory format which oneDNN does not allow due to its stateless nature.
Hence, ACL currently runs a (redundant) sum of columns and pre-transposition (to the gemm kerne's optimal format) for each GEMM operation.
This PR addresses the sub-optimalities above by integrating ACL directly with `qlinear` and `qlinear_dynamic`.

- **For `qlinear_dynamic` (dynamically quantized matmuls):**

This PR yields an ****average speedup** (averaged over context_lengths of 2^3 up to 2^9) of ~ **50%** for `bert-base-uncased`, `bert-large-uncased`, `roberta-base`, `distilbert-base-uncased`** with 16 threads on a Neoverse-V1 (with transformers==4.48) for the benchmarking script below:
```
# SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliate <open-source-office@arm.com>
# SPDX-License-Identifier: BSD-3-Clause
import torch
from transformers import AutoModel, AutoConfig
import time
import numpy as np
from argparse import ArgumentParser

class ModelArgumentParser(ArgumentParser):
    def __init__(self) -> None:
        super().__init__(description="huggingface model")
        self.add_argument("--context_length",
                            help="context length - number of input tokens",
                            type=int,
                            default=64
        )
        self.add_argument("--model",
                            help="model checkpoint - i.e. 'bert-base-uncased'",
                            type=str,
                            default=None)
        self.add_argument("--iters",
                          help="benchmark iterations",
                          default=500)

if __name__ == "__main__":
    parser = ModelArgumentParser()
    args = parser.parse_args()
    model_name = args.model
    config = AutoConfig.from_pretrained(model_name)
    batch_size = 1
    model = AutoModel.from_pretrained(model_name)
    model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
    model.eval()
    inputs = torch.randint(config.vocab_size, (batch_size, args.context_length), dtype=torch.long, device="cpu")
    times = []
    with torch.no_grad():
        # warmup
        for _ in range(10):
            model(inputs)
        # benchmark
        for _ in range(args.iters):
            s = time.time_ns()
            model(inputs)
            times.append((time.time_ns() - s) / 1e6)

    print("Model = ", model_name)
    print("Context Length = ", args.context_length)
    print("Min (ms) = ", min(times))
    print("Mean (ms) = ", np.mean(times))
```

- **For `qlinear` (statically quantized matmuls):**

This PR yields an **average speedup of 2x for signed activations (`s8s8s8`) and 95x for unsigned activations (u8s8u8)** on a Neoverse-V1 with 16 threads for the benchmarking script below.
The averages are over for all combinations of `M = [8, 16, ..., 512]`, `K = [768, 1024, 2048, 4096]`, `N = [768, 1024, 2048, 4096]`.
The astronomical speedup for unsigned activation is because oneDNN v3.7 does not have an optimized implementation for `u8s8u8` on AArch64.

```
# SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliate <open-source-office@arm.com>
# SPDX-License-Identifier: BSD-3-Clause
import torch
import torch.nn as nn
from torch.quantization import QConfig
from torch.ao.quantization.observer import HistogramObserver, default_weight_observer
import torch
import torch.nn as nn
import numpy as np
import random
from argparse import ArgumentParser
import time

class ModelArgumentParser(ArgumentParser):
    def __init__(self) -> None:
        super().__init__()
        self.add_argument("--M",
                            help="M dimension",
                            type=int,
                            default=64
        )
        self.add_argument("--K",
                            help="K dimension",
                            type=int,
                            default=64
        )
        self.add_argument("--N",
                            help="N dimension",
                            type=int,
                            default=64
        )
        self.add_argument("--signed_input",
                            help="Use (signed) torch.qint8 for inputs instead of (unsigned) torch.quint8",
                            action="store_true"
        )
        self.add_argument("--seed",
                          help="Random seed",
                          type=int,
                          default=42
        )
        self.add_argument("--iters",
                          help="benchmark iterations",
                          default=500)

def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

class LinearModel(nn.Module):
    def __init__(self, K, N):
        super(LinearModel, self).__init__()
        self.quant = torch.quantization.QuantStub()
        self.fc = nn.Linear(K, N)
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.fc(x)
        x = self.dequant(x)
        return x

def quantize_model(model, args):
    qconfig = QConfig(
            activation=HistogramObserver.with_args(reduce_range=False,
            dtype=torch.qint8 if args.signed_input else torch.quint8),
            weight=default_weight_observer,
    )
    # Prepare the model for static quantization
    # Specify quantization configurations
    model.qconfig = qconfig
    model_prepared = torch.quantization.prepare(model_fp32)

    # Calibrate the model with sample inputs
    # Example input data for calibration
    with torch.no_grad():
        sample_data = torch.randn(args.M, args.K)
        model_prepared(sample_data)
    # Convert the prepared model to a quantized model
    model_quantized = torch.quantization.convert(model_prepared)
    return model_quantized

if __name__ == "__main__":
    parser = ModelArgumentParser()
    args = parser.parse_args()

    set_seed(args.seed)
    model_fp32 = LinearModel(args.K, args.N)
    model_quantized = quantize_model(model_fp32, args)

    inputs = torch.randn(args.M, args.K)
    times = []
    with torch.no_grad():
        # warmup
        for _ in range(10):
            model_quantized(inputs)
        # benchmark
        for _ in range(args.iters):
            s = time.time_ns()
            model_quantized(inputs)
            times.append((time.time_ns() - s) / 1e6)

    print("M,K,N,signed = ", args.M, args.K, args.N, args.signed_input)
    print("Min Times (ms) = ", min(times))
    print("Mean Times (ms) = ", np.mean(times))
```

Pull Request resolved: #148585
Approved by: https://github.com/malfet

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
(cherry picked from commit 08a644a)

* Enable qint8 and quint8 add for AArch64 using ACL directly (#148653)

This enables qint8 and quint8 add for AArch64 through Arm Compute Library (ACL) directly.
Relative performance improvement using OMP_NUM_THREADS=1 is ~15x, using OMP_NUM_THREADS=32 it’s ~5.4x.

Co-authored-by: David Svantesson <david.svantesson-yeung@arm.com>

Pull Request resolved: #148653
Approved by: https://github.com/malfet
ghstack dependencies: #148585

(cherry picked from commit 6c2db8f)

* [Build] Guard per-op headers in ACLUtils.cpp (#149417)

To fix internal build failures, where per-op headers are not generated.
We really should have lint for something like that.

Test Plan: CI

Reviewed By: izaitsevfb

Differential Revision: D71406882

Pull Request resolved: #149417
Approved by: https://github.com/Skylion007, https://github.com/izaitsevfb

(cherry picked from commit 5db3a4a)

---------

Co-authored-by: Nikita Shulga <nshulga@meta.com>
@github-actions github-actions bot deleted the gh/fadara01/7/head branch April 22, 2025 02:14
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
arm priority ciflow/linux-aarch64 linux aarch64 CI workflow Merged module: arm Related to ARM architectures builds of PyTorch. Includes Apple M1 module: cpu CPU specific problem (e.g., perf, algorithm) open source release notes: quantization release notes category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants
0