force computation in opmath_t for CUDA fused optimizers #154069
base: main
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/154069
Note: Links to docs will display an error until the docs builds have been completed.
❌ 3 New Failures as of commit 2f5411f with merge base 59c5fff. The following jobs have failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Thanks, left some comments
```diff
@@ -10,49 +10,51 @@ namespace at::native {

 namespace {

 template <typename scalar_t, int depth>
+constexpr uint8_t kParamIdx = 0;
```
Is there any perf impact of this, or was the renaming mostly for code understandability?
This is just for readability (and consistency with the other fused optimizers). Indexing into arbitrary offsets of args/r_args seemed hard to understand at first glance.
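To make the readability point concrete, here is a minimal sketch of the named-slot pattern the rename follows. Only `kParamIdx` appears in the diff above; `kGradIdx` and `kExpAvgIdx` are hypothetical companions added for illustration, and the plain `int` constants simplify the templated `uint8_t` constants in the actual code.

```cpp
#include <array>

// Named slot indices (mirroring the PR's kParamIdx) instead of magic numbers
// when indexing a fused-optimizer argument array. kGradIdx / kExpAvgIdx are
// assumed for illustration and are not taken from the diff.
constexpr int kParamIdx  = 0;
constexpr int kGradIdx   = 1;
constexpr int kExpAvgIdx = 2;

template <int depth>
float* param_ptr(std::array<float*, depth>& args) {
  // args[kParamIdx] reads clearer than args[0] at every use site.
  return args[kParamIdx];
}
```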
```cpp
        tl.state_steps_addresses[tensor_loc]));

    const opmath_t bias_correction1 =
        1 - at::native::pow_(static_cast<opmath_t>(beta1), step_count);
```
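The idea behind the diff line above can be sketched outside the kernel as follows. This is illustrative only: the `opmath_type` alias here is a stand-in (real `at::opmath_type` maps e.g. `half` to `float`), `bias_correction1` is a free function invented for the sketch, and `std::pow` stands in for `at::native::pow_`.

```cpp
#include <cmath>

// Simplified stand-in for at::opmath_type: compute in a wider "op math" type
// than the parameter dtype. Here every scalar_t maps to float for brevity.
template <typename scalar_t>
using opmath_type = float;

// Compute the Adam bias correction entirely in opmath_t: cast the input once,
// then keep every intermediate in the wider type.
template <typename scalar_t>
opmath_type<scalar_t> bias_correction1(scalar_t beta1, int step_count) {
  using opmath_t = opmath_type<scalar_t>;
  const opmath_t b1 = static_cast<opmath_t>(beta1);
  return opmath_t(1) - std::pow(b1, opmath_t(step_count));
}
```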
Are these inner static_casts still necessary when lines 170/171 have already cast them? I know static_cast has no runtime cost, but I'm curious whether we could remove some redundant code here.
You're correct, I must've missed them in my final sweep of the file. I was going a bit out of order in the edits.
I'll remove them.
I switched the code to have a section at the top of the function blocks to do all the casting.
Compare c6d3e59 to 6f08e0f.
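The "casting section at the top of the function block" approach can be sketched like this. The function name, signature, and the simplified Adam-style update are hypothetical, meant only to show the shape of the change: all `static_cast`s grouped up front, with the rest of the body staying in `opmath_t`.

```cpp
#include <cmath>

// Illustrative Adam-style update for one element. For the sketch, opmath_t is
// fixed to double; the real code derives it from scalar_t via at::opmath_type.
template <typename scalar_t>
double adam_step(scalar_t param, scalar_t grad, scalar_t exp_avg,
                 double lr, double beta1, int step) {
  using opmath_t = double;

  // Casting section: convert every scalar_t input to opmath_t once, up front.
  const opmath_t p = static_cast<opmath_t>(param);
  const opmath_t g = static_cast<opmath_t>(grad);
  const opmath_t m = static_cast<opmath_t>(exp_avg);

  // Body: every intermediate stays in opmath_t, no further casts needed.
  const opmath_t m_new = beta1 * m + (1 - beta1) * g;
  const opmath_t bc1 = 1 - std::pow(beta1, step);
  return p - lr * (m_new / bc1);
}
```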
I'm not as familiar with the AMP codebase, but are there any methods to mitigate the loss of precision from moving to
@MeetThePatel I haven't looked into the exact failing test cases, but we have a few approaches to make sure it's okay.
1 is easier to do/verify, so I would try to incorporate TensorTracker if not already.
Fixes #153649
Benchmarks before and after: