[ATen][CUDA] Optimize 128 bit vectorization #148320


Closed · Aidyn-A wants to merge 2 commits from the cuda_vectorize_for_sm90+ branch

Conversation

@Aidyn-A (Collaborator) commented Mar 3, 2025

Fixes #147376.
As per request: #145746 (review)
This PR prevents sm80 and older architectures from using vec8 kernels, due to long compilation times and large binary size.
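
For context, "vec8" here means kernels that load and store 8 elements per thread; for 2-byte dtypes such as half/bfloat16 that is a single 128-bit memory transaction. A minimal illustrative sketch of such a vectorized access (not the actual ATen code; `aligned_vector` and `copy_vec8` are simplified stand-ins):

```cuda
#include <cstdint>

// Simplified stand-in for ATen's aligned vector type: alignas() lets the
// compiler emit a single 128-bit (LDG.128/STG.128) access when
// sizeof(scalar_t) * vec_size == 16 bytes, e.g. 8 x half.
template <typename scalar_t, int vec_size>
struct alignas(sizeof(scalar_t) * vec_size) aligned_vector {
  scalar_t val[vec_size];
};

// Hypothetical vectorized copy: each thread moves 8 contiguous elements
// with one 128-bit load and one 128-bit store.
template <typename scalar_t>
__global__ void copy_vec8(const scalar_t* in, scalar_t* out, int64_t n) {
  using vec_t = aligned_vector<scalar_t, 8>;
  int64_t i = (blockIdx.x * static_cast<int64_t>(blockDim.x) + threadIdx.x) * 8;
  if (i + 8 <= n) {
    *reinterpret_cast<vec_t*>(out + i) = *reinterpret_cast<const vec_t*>(in + i);
  }
}
```

Every such vec_size variant is instantiated for each elementwise functor and dtype, which is what inflates compile time and binary size; hence the restriction to newer architectures in this PR.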

cc @ptrblck @msaroufim @eqy @jerryzh168 @manuelcandales @SherlockNoMad @angelayi

@Aidyn-A added the `module: cuda` (Related to torch.cuda, and CUDA support in general) and `module: core aten` (Related to change to the Core ATen opset) labels on Mar 3, 2025
@Aidyn-A requested review from malfet, atalman and ngimel on March 3, 2025 10:15
@Aidyn-A self-assigned this on Mar 3, 2025
@Aidyn-A requested review from eqy and syed-ahmed as code owners on March 3, 2025 10:15
pytorch-bot bot commented Mar 3, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/148320

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

❌ 1 New Failure, 1 Cancelled Job, 50 Pending, 2 Unrelated Failures

As of commit 00da106 with merge base ce94b21:

NEW FAILURE - The following job has failed:

CANCELLED JOB - The following job was cancelled. Please retry:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

BROKEN TRUNK - The following job failed but was present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

pytorch-bot added the `release notes: cuda` (release notes category) label on Mar 3, 2025
@@ -215,6 +250,11 @@ static inline void launch_vectorized_kernel(
if constexpr (sizeof(cpp_type) < 2) {
vec_size = std::min<uint16_t>(vec_size, 4);
}
// Since we are not compiling vec8 kernel on sm<90, we are not calling it.
auto dprop = at::cuda::getCurrentDeviceProperties();
if (dprop->major < 9) {
@Skylion007 (Collaborator) commented Mar 3, 2025

How much does it increase compilation times? This is a very hot codepath and is used for primitives including copying. Is the original PR description wrong? This won't affect compile times since it's not constexpr.

After examining where this call is used, if I am not mistaken, it's only used for vectorized copy of fp8 and lower dtypes. Why would this decrease/increase compilation time? Wouldn't those float8 kernels just not be built on lower SMs? Except SM89?

Do we actually not want to allow this on FP8-supporting consumer GPUs like SM89?! We may want to update the vec8 kernels if we are not compiling them on SM89.

Suggested change:
- if (dprop->major < 9) {
+ if (dprop->major < 9 || (dprop->major == 8 && dprop->minor >= 9)) {

Collaborator Author (Aidyn-A):

> How much does it increase compilation times?

I would like to ask @atalman about this.

> it's only used for vectorized copy of fp8 and lower dtypes

Yes, it is for float16, bfloat16 and fp8.

Collaborator:

@Aidyn-A Maybe I am mistaken, but I thought float16 and bfloat16 had their own NVIDIA primitives they use and do not go through this codepath.

@Aidyn-A (Collaborator, Author) commented Mar 3, 2025

In general, TensorIterator is used for most elementwise ops (relu, exp, log, sigmoid, sin, cos, etc.) on all dtypes. Though, not all ops go through TensorIterator.
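
For readers less familiar with the machinery, a rough sketch of that pattern (simplified; `my_op_kernel_cuda` is a hypothetical name, not an existing op): an elementwise CUDA op builds a TensorIterator and hands a per-element lambda to `gpu_kernel`, which is what eventually reaches `launch_vectorized_kernel` and the vec_size selection changed in this PR.

```cpp
#include <ATen/Dispatch.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cuda/Loops.cuh>

namespace at::native {

// Hypothetical elementwise kernel following the common ATen pattern: the op
// author only writes the per-element lambda, while gpu_kernel() picks the
// vectorization width (vec2/vec4/vec8) underneath.
void my_op_kernel_cuda(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half, at::ScalarType::BFloat16,
      iter.common_dtype(), "my_op_cuda", [&]() {
        gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a) -> scalar_t {
          return a * a;  // the elementwise body
        });
      });
}

} // namespace at::native
```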

Contributor:

@Aidyn-A triggered the Windows builds, let's see if there is an improvement.

Contributor:

Hi @Aidyn-A, I don't see an improvement over the nightly build time (Build PyTorch Binary):

This PR:
https://github.com/pytorch/pytorch/actions/runs/13677007625/job/38239672663?pr=148320

Nightly:
https://github.com/pytorch/pytorch/actions/runs/13759231430/job/38471718241

Roughly: cuda 11.8 - cuda 12.6: 3h40m-3h55m, and cuda 12.8: 4h20m

This is a run of #147455:
https://github.com/pytorch/pytorch/actions/runs/13413313017/job/37468143832?pr=147455
cuda 11.8 - cuda 12.6: 3h20m-3h45m

// vectorized memory access
if constexpr (vec_size == 8) {
// To save some build time on CUDA, we are going to utilize vec8 only on SM90+ devices.
#if defined(USE_ROCM) || ((defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900))
Collaborator:

Ah, I see. This is where the build time is changed. Please allow SM89 devices too if they support them.

Collaborator:

Just don't instantiate this kernel with vec_size == 8 on older arches, don't add 2 similar codeblocks to the kernel itself

Collaborator:

here, right

Collaborator Author (Aidyn-A):

> Just don't instantiate this kernel with vec_size == 8 on older arches, don't add 2 similar codeblocks to the kernel itself

I wish I could do that; the issue is that __CUDA_ARCH__ is available only inside kernels (device code). The place @eqy pointed to is host-side code.
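
To illustrate the constraint (a hedged sketch, not code from this PR): __CUDA_ARCH__ is only defined during the device-compilation passes, so a host-side launcher has to fall back on a runtime property query, which is what the launch_vectorized_kernel change does.

```cpp
#include <cuda_runtime.h>

__global__ void arch_guarded_kernel() {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
  // This branch exists only in the SM90+ entries of the fatbin.
#endif
}

void host_side_dispatch() {
  // Host compilation never defines __CUDA_ARCH__, so the architecture must be
  // queried at runtime, e.g. via cudaGetDeviceProperties (ATen wraps this in
  // at::cuda::getCurrentDeviceProperties()).
  cudaDeviceProp prop;
  cudaGetDeviceProperties(&prop, /*device=*/0);
  const bool allow_vec8 = prop.major >= 9;  // mirrors the host-side clamp above
  (void)allow_vec8;
  arch_guarded_kernel<<<1, 1>>>();
}
```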

Collaborator:

Are the binary size savings really worth it? This adds considerable complexity and maintenance burden to the kernel

Collaborator:

@ngimel this was raised because it was adding many (dozens?) of minutes to the Windows build time, not due to binary size concerns.

@janeyx99 added the `triaged` label (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) on Mar 3, 2025
@atalman added the `ciflow/binaries_wheel` label (Trigger binary build and upload jobs for wheel on the PR) on Mar 5, 2025
@Aidyn-A force-pushed the cuda_vectorize_for_sm90+ branch from ea89139 to 4e0accc on March 11, 2025 17:53
@Aidyn-A marked this pull request as draft on March 12, 2025 14:29
@Aidyn-A force-pushed the cuda_vectorize_for_sm90+ branch from 2b243a5 to b0c21f7 on April 26, 2025 16:45
@Aidyn-A marked this pull request as ready for review on April 28, 2025 14:50
@malfet added the `ciflow/periodic` label (Trigger jobs ran periodically on master (periodic.yml) on the PR) on May 1, 2025
@atalman (Contributor) left a comment

lgtm

@atalman added this to the 2.7.1 milestone on May 1, 2025
@Aidyn-A (Collaborator, Author) commented May 1, 2025

ROCm and CPU test failures are unrelated
@pytorchbot merge -i

pytorch-bot added the `ciflow/trunk` label (Trigger trunk jobs on your pull request) on May 1, 2025
@pytorchmergebot (Collaborator):

Merge failed

Reason: 1 job has failed, first few of them are: trunk / linux-focal-rocm-py3.10 / test (default, 1, 2, linux.rocm.gpu.2)

Details for Dev Infra team: raised by workflow job

@Aidyn-A (Collaborator, Author) commented May 2, 2025

@pytorchbot rebase

@pytorchmergebot (Collaborator):

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot force-pushed the cuda_vectorize_for_sm90+ branch from b0c21f7 to c68d7e5 on May 2, 2025 08:32
@pytorchmergebot (Collaborator):

Successfully rebased cuda_vectorize_for_sm90+ onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout cuda_vectorize_for_sm90+ && git pull --rebase)

@malfet (Contributor) commented May 2, 2025

@pytorchbot merge -f "This looks fine"

@pytorchmergebot (Collaborator):

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as a last resort, and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here

@malfet (Contributor) commented May 6, 2025

@pytorchbot cherry-pick --onto release/2.7 -c regression

pytorchbot pushed a commit that referenced this pull request May 6, 2025
Fixes #147376.
As per request: #145746 (review)
This PR prevents sm80 and older architectures from using vec8 kernels, due to long compilation times and large binary size.

Pull Request resolved: #148320
Approved by: https://github.com/eqy, https://github.com/malfet, https://github.com/atalman

(cherry picked from commit 72337bd)
@pytorchbot (Collaborator):

Cherry picking #148320

The cherry pick PR is at #152967 and it is recommended to link a regression cherry pick PR with an issue. The following tracker issues are updated:

Details for Dev Infra team: raised by workflow job

malfet pushed a commit that referenced this pull request May 7, 2025
[ATen][CUDA] Optimize 128 bit vectorization (#148320)

Fixes #147376.
As per request: #145746 (review)
This PR prevents sm80 and older architectures from using vec8 kernels, due to long compilation times and large binary size.

Pull Request resolved: #148320
Approved by: https://github.com/eqy, https://github.com/malfet, https://github.com/atalman

(cherry picked from commit 72337bd)

Co-authored-by: Aidyn-A <31858918+Aidyn-A@users.noreply.github.com>
Labels
ciflow/binaries_wheel: Trigger binary build and upload jobs for wheel on the PR
ciflow/periodic: Trigger jobs ran periodically on master (periodic.yml) on the PR
ciflow/trunk: Trigger trunk jobs on your pull request
Merged
module: core aten: Related to change to the Core ATen opset
module: cuda: Related to torch.cuda, and CUDA support in general
open source
release notes: cuda: release notes category
triaged: This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Nightly Windows builds started to time out around Jan 31, 2025
9 participants