force computation in opmath_t for CUDA fused optimizers #154069

Open

MeetThePatel wants to merge 4 commits into main

Conversation

MeetThePatel
Contributor

Fixes #153649

Benchmarks before and after (on an RTX 5070):

| Optimizer | Branch | Mean time per run (µs) | Median (µs) |
| --- | --- | --- | --- |
| Adam | main | 11066.4259 | 11065.9903 |
| Adam | forced opmath_t | 8368.6145 | 8367.6631 |
| SGD | main | 3679.6347 | 3678.9713 |
| SGD | forced opmath_t | 3518.2261 | 3517.4846 |
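For context, the gist of the change, as a rough sketch: keep all per-element arithmetic in `opmath_t` and only narrow back to `scalar_t` on the store, so mixed expressions no longer promote to double. Names like `r_args`, `kGradIdx`, `kExpAvgIdx`, and `ii` are illustrative placeholders here, not necessarily the exact identifiers in this PR.

```cpp
// Sketch only: the general pattern of computing in opmath_t.
using opmath_t = at::opmath_type<scalar_t>;

// Load in scalar_t, widen once to opmath_t.
const opmath_t grad = static_cast<opmath_t>(r_args[kGradIdx][ii]);
opmath_t exp_avg    = static_cast<opmath_t>(r_args[kExpAvgIdx][ii]);

// Every operand is opmath_t, so nothing promotes the expression to double
// (a bare 1.0 literal would have made the whole product double-precision).
exp_avg = beta1 * exp_avg + (opmath_t(1) - beta1) * grad;

// Narrow back to scalar_t only when writing the result out.
r_args[kExpAvgIdx][ii] = static_cast<scalar_t>(exp_avg);
```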

pytorch-bot bot commented May 21, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/154069

Note: Links to docs will display an error until the docs builds have been completed.

❌ 3 New Failures

As of commit 2f5411f with merge base 59c5fff:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the release notes: cuda release notes category label May 21, 2025
@Skylion007 Skylion007 requested review from jansel and janeyx99 and removed request for jansel May 21, 2025 22:52
@albanD albanD added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label May 22, 2025
Contributor

@janeyx99 left a comment

Thanks, left some comments

@@ -10,49 +10,51 @@ namespace at::native {

namespace {

template <typename scalar_t, int depth>
constexpr uint8_t kParamIdx = 0;
Contributor

Is there any perf impact of this or was renaming mostly for code understandability?

Contributor Author

This is just for readability (and consistency with the other fused optimizers). I thought indexing into seemingly arbitrary positions of args/r_args might be a bit hard to understand at first.
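For illustration, this is the kind of named-index pattern being discussed; only `kParamIdx = 0` appears in the hunk above, so the other names here are illustrative rather than taken from the PR:

```cpp
// Named slots into the per-tensor args / r_args arrays, instead of hard-coded
// 0/1/2/3 scattered through the kernel body.
constexpr uint8_t kParamIdx = 0;
constexpr uint8_t kGradIdx = 1;      // illustrative
constexpr uint8_t kExpAvgIdx = 2;    // illustrative
constexpr uint8_t kExpAvgSqIdx = 3;  // illustrative

// e.g. r_args[kGradIdx][ii] instead of r_args[1][ii]
```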

tl.state_steps_addresses[tensor_loc]));

const opmath_t bias_correction1 =
1 - at::native::pow_(static_cast<opmath_t>(beta1), step_count);
Contributor

Are these inner static_casts still necessary when they've already been cast on lines 170/171? I know static_cast doesn't have runtime ramifications, but I'm just curious whether we could remove some redundant code here.

Contributor Author

You're correct, I must've missed them in my final sweep of the file. I was going a bit out of order in the edits.
I'll remove them.

Contributor Author

I switched the code to have a section at the top of the function blocks to do all the casting.
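Roughly the shape of that change, as a sketch; variable names other than `beta1` and `step_count` are assumptions for illustration, not taken from the PR:

```cpp
// Cast every hyperparameter to opmath_t once, at the top of the function
// block, then use the widened copies below -- no per-use static_casts.
const auto beta1_op = static_cast<opmath_t>(beta1);
const auto beta2_op = static_cast<opmath_t>(beta2);  // illustrative

const opmath_t bias_correction1 = 1 - at::native::pow_(beta1_op, step_count);
const opmath_t bias_correction2 = 1 - at::native::pow_(beta2_op, step_count);
```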

@MeetThePatel
Contributor Author
Mismatched elements: 31 / 64 (48.4%)
Greatest absolute difference: 0.00017547607421875 at index (7, 0) (up to 3e-05 allowed)
Greatest relative difference: 1.2930149750900455e-05 at index (1, 7) (up to 1.3e-06 allowed)

To execute this test, run the following from the base repo dir:
    python test/test_cuda.py TestCudaOptimsCUDA.test_grad_scaling_autocast_fused_optimizers_AdamW_cuda_float32

I'm not as familiar with the AMP codebase, but are there any methods to mitigate the loss of precision from moving to opmath_t for the optimizer computations?

@janeyx99
Contributor
janeyx99 commented Jun 2, 2025

@MeetThePatel I haven't looked into the exact failing test cases, but we have a few approaches to make sure it's okay:

  1. use TensorTracker to realign the numbers after every step https://github.com/pytorch/pytorch/blob/main/torch/testing/_internal/common_optimizers.py#L2220 (there are some examples in test/test_optim.py)
  2. go through the exact ops that happen and mathematically compound the error ranges to verify whether the end result makes sense.

Option 1 is easier to do/verify, so I would try to incorporate TensorTracker if it isn't used already.

Labels
open source · release notes: cuda · triaged

Projects
None yet

Development

Successfully merging this pull request may close these issues.

Use opmath_t and not double compute in fused optimizers

4 participants