Enable fp16 linear layers in PyTorch via ACL by renato-arantes · Pull Request #144992 · pytorch/pytorch · GitHub

Enable fp16 linear layers in PyTorch via ACL #144992


Open · wants to merge 2 commits into main

Conversation

@renato-arantes (Contributor) commented Jan 16, 2025

This pull request enables linear layers with the fp16 data type through the Arm Compute Library (ACL).

On a Graviton3 instance running with 16 threads, a linear layer applied to an input of torch.randn(2048, 4096, dtype=torch.half) takes 50+% less time to complete than the same layer with dtype=torch.float32.
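For context, a minimal sketch of the kind of workload this targets (the 2048-wide output below is an illustrative assumption, not part of the claim above; the acceleration applies on aarch64 builds of PyTorch that link oneDNN with ACL):

import torch
import torch.nn as nn

torch.set_grad_enabled(False)

# fp16 weights and activations; with this change, aarch64 builds that link ACL
# dispatch the underlying matmul to ACL fp16 kernels instead of the oneDNN
# reference implementation.
layer = nn.Linear(4096, 2048, bias=False).to(dtype=torch.half)
x = torch.randn(2048, 4096, dtype=torch.half)

y = layer(x)
print(y.dtype, y.shape)  # torch.float16, torch.Size([2048, 2048])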

cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @jerryzh168 @malfet @snadampal @milpuz01 @aditew01 @nikhil-arm @fadara01 @voznesenskym @penguinwu @EikanWang @Guobing-Chen @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @yf225 @ColinPeppler @desertfire

Signed-off-by: Renato Arantes <renato.arantes@arm.com>
pytorch-bot added the module: cpu (CPU specific problem, e.g., perf, algorithm) and release notes: linalg_frontend (release notes category) labels Jan 16, 2025
pytorch-bot bot commented Jan 16, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/144992

Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 New Failures, 3 Cancelled Jobs, 25 Unrelated Failures

As of commit fa04172 with merge base cf28d61:

NEW FAILURES - The following jobs have failed:

CANCELLED JOBS - The following jobs were cancelled. Please retry:

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

BROKEN TRUNK - The following jobs failed but were already failing on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@annop-w (Contributor) commented Jan 16, 2025

@pytorchbot label "module: arm"

pytorch-bot added the module: arm label (Related to ARM architecture builds of PyTorch. Includes Apple M1) Jan 16, 2025
malfet added the ciflow/trunk label (Trigger trunk jobs on your pull request) Jan 16, 2025
@malfet (Contributor) commented Jan 16, 2025

@renato-arantes can you add some explanation as to why you are doing this? (I suspect performance; if so, I would love to see some sort of script that one can run to measure the perf improvement before and after.)

@renato-arantes (Contributor, Author) commented Jan 17, 2025

Hi @malfet

Yes, this PR is about improving performance: it enables the path from PyTorch to ACL and therefore avoids running the fp16 reference implementation in oneDNN. On average, on an AWS c7g instance running with 16 threads, we got 8191.85 μs for fp16 and 9629.15 μs for fp32, an improvement of 1437.30 μs (fp16 is ~17% faster than fp32). Here is the script used for this benchmark:

import torch
import torch.nn as nn
import time
# Enable torch.no_grad globally
torch.set_grad_enabled(False)

# Define models as nn.Modules
class LinearModel(nn.Module):
    def __init__(self, input_size, output_size, bias):
        super().__init__()
        self.linear = nn.Linear(input_size, output_size, bias=bias)

    def forward(self, x):
        return self.linear(x)

N = 2048
K = 4096
M = 512
bias = False
dtype = torch.float16 # <<<--- change here for fp32

model = LinearModel(K, N, bias=bias).to(dtype=dtype)

# Generate random inputs
input = torch.randn(M, K, dtype=dtype)

# Number of iterations for benchmarking
num_iterations = 1000

# Warm-up function
def warmup(model, input_tensor):
    for _ in range(100):  # Warm-up phase to stabilize performance
        _ = model(input_tensor)

# Benchmark function
def benchmark(model, input_tensor):
    start_time = time.time()

    for _ in range(num_iterations):
        _ = model(input_tensor)

    end_time = time.time()
    return (end_time - start_time) / num_iterations

# Warm-up
print("Warming up...")
warmup(model, input)

# Benchmark without profiler
print("Benchmarking...")
average_time = benchmark(model, input)

print(f"Average execution time for Linear Layer: {average_time * 1e6:.2f} microseconds")

Signed-off-by: Renato Arantes <renato.arantes@arm.com>
@IvanYashchuk IvanYashchuk removed their request for review January 20, 2025 09:38
nikhil-arm previously approved these changes Jan 20, 2025
digantdesai previously approved these changes Jan 23, 2025
@digantdesai (Contributor) left a comment


Thanks, this looks good to me.

I would improve the summary a bit and add before/after performance numbers for fp16 (and probably also for bf16 if you unblocked that as well). Also link the test script in the summary :)

I left a couple of comments, let's make sure we address them before merging.

@@ -117,7 +117,7 @@ def cal_conv_generated_kernel_number(mod, input, dtype, dim=4):
 ):
     input_kernel = 1
 if output.is_contiguous(memory_format=torch.contiguous_format) or (
-    TEST_ACL and dtype == torch.bfloat16
+    TEST_ACL and (dtype == torch.bfloat16 or dtype == torch.half)
Contributor comment:

Curious, how does this work on non-ACL builds?

@@ -90,6 +90,10 @@ inline bool mkldnn_bf16_device_check_arm() {
   return cpuinfo_initialize() && cpuinfo_has_arm_bf16();
 }
 
+inline bool mkldnn_fp16_device_check_arm() {
+  return cpuinfo_initialize() && cpuinfo_has_arm_neon_fp16();
+}
Contributor comment:

I guess we don't care about aarch32; otherwise we would need to check for fp16arith.

@nikhil-arm (Collaborator):

@pytorchmergebot revert -c nosignal -m "Accuracy Test failures"

@pytorchmergebot (Collaborator):

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

pytorchmergebot added a commit that referenced this pull request Feb 15, 2025
This reverts commit 5b37249.

Reverted #144992 on behalf of https://github.com/nikhil-arm due to Accuracy Test failures.
@pytorchmergebot (Collaborator):

@renato-arantes your PR has been successfully reverted.

pytorchmergebot added the Reverted and ci-no-td (Do not run TD on this PR) labels Feb 15, 2025
@pytorch-bot pytorch-bot bot dismissed stale reviews from nikhil-arm, digantdesai, and malfet February 15, 2025 12:41

This PR was reopened (likely due to being reverted), so your approval was removed. Please request another review.

@fadara01 (Collaborator) commented Feb 15, 2025

This PR only passed the CI tests because the Arm Compute Library (ACL) version in the jammy docker image used in the CI is outdated (v24.04), while the one used in manylinux is v24.09.

ACL v24.04 with multi_isa=1 and arch=armv8a does not enable FP16, while ACL v24.09 does. Hence the CI above only tested the oneDNN reference implementation as FP16 was not enabled in ACL. Had ACL v24.09 been used, the CI would have failed (as it does in #138889 where the ACL version in jammy is up to date) since the tolerance in the tests assumes FP32 accumulation, while ACL does FP16 accumulation.

Reverting this should fix the CI failures in #138889
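To make the failure mode concrete, an illustrative sketch (not the actual CI test): comparing a half-precision linear against the same half inputs upcast to fp32 isolates the accumulation error. With fp32 accumulation the gap is roughly one ulp of the output; with fp16 accumulation it grows with K and can exceed tolerances written assuming fp32 accumulation.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
M, K, N = 512, 4096, 2048  # large K makes accumulation error visible

xh = torch.randn(M, K, dtype=torch.half)
wh = torch.randn(N, K, dtype=torch.half)

# Reference: identical (already fp16-quantized) inputs, accumulated in fp32.
ref = F.linear(xh.float(), wh.float())

# Path under test: fp16 end to end; the size of the error below depends on
# whether the backend accumulates in fp32 or fp16.
out = F.linear(xh, wh).float()

print("max abs error:", (out - ref).abs().max().item())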

@malfet (Contributor) left a comment

> since the tolerance in the tests assumes FP32 accumulation, while ACL does FP16 accumulation.

PyTorch eager mode always does accumulation in fp32, even for fp16 input dtypes. There is a PR somewhere that introduces a context manager that allows lower-precision accumulation, but in general the default codepath should do reductions in higher-precision dtypes, as defined by opmath_t.
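As a small illustration of that point (a sketch, not the failing test): summing many fp16 copies of 0.01 stays close to the true total because the reduction accumulates in a wider type, whereas a running sum rounded back to fp16 after every add stalls once the partial sum is large enough that 0.01 falls below half an fp16 ulp.

import torch

x = torch.full((10_000,), 0.01, dtype=torch.float16)

# Built-in reduction: accumulation happens in a higher-precision type,
# so the result lands near 100.
print(x.sum())

# Simulated fp16 accumulation: the result is rounded to fp16 after each add,
# and the sum stops growing well short of 100.
acc = torch.tensor(0.0, dtype=torch.float16)
for v in x:
    acc = acc + v
print(acc)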

Raymo111 pushed a commit that referenced this pull request Feb 20, 2025
This reverts commit 5b37249.

Reverted #144992 on behalf of https://github.com/nikhil-arm due to Accuracy Test failures.
pytorch-bot bot pushed a commit that referenced this pull request Feb 24, 2025
This reverts commit 5b37249.

Reverted #144992 on behalf of https://github.com/nikhil-arm due to Accuracy Test failures.
mikaylagawarecki added the triaged label (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) Feb 25, 2025
majing921201 pushed a commit to majing921201/pytorch that referenced this pull request Mar 4, 2025
Stale bot comment:

Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

@fadara01 (Collaborator) commented:

@renato-arantes what is the current status of this? Do we plan to re-land it with newer versions of oneDNN/ACL that do accumulation in FP32?

Labels
ci-no-td (Do not run TD on this PR), ciflow/inductor, ciflow/trunk (Trigger trunk jobs on your pull request), Merged, module: arm (Related to ARM architecture builds of PyTorch. Includes Apple M1), module: cpu (CPU specific problem, e.g., perf, algorithm), module: inductor, open source, release notes: linalg_frontend (release notes category), Reverted, Stale, triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)