Add support for differentiable LR in SGD + test v2.0 #143510
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/143510
Note: Links to docs will display an error until the docs builds have been completed.
✅ No failures as of commit ae57bfe with merge base 80a4239.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Commits:
- Add support for differentiable LR in SGD + test
- removed default = False on differentiables
- lint
- Addressed several of Jane's comments, still not ready to merge
- Addressed more comments
- Renamed kwargs to inner_kwargs, changed x, y to be simpler, reordered test var definitions to be more logical, put lr back into inner_kwargs to make the function more adaptable for future enhancement
- Add differentiable flag to functional_sgd.py
- Streamlined tester function + easier support for more kwargs
- Renamed + fixed test_differentiable_lr
- Functional refactoring (last time I added a default param in sgd() it broke CI, fingers crossed!)
- lint
- Update test/optim/test_optim.py (Co-authored-by: Jane (Yuan) Xu <31798555+janeyx99@users.noreply.github.com>)
- Add comment + attempt to understand differentiable adding in torch/optim/sgd.py. Will revert this commit if we choose to do `if differentiable and isinstance(lr, Tensor):` instead of `if isinstance(lr, Tensor):`
- Add newline to revert earlier change
- updated to use cpu-scalar addcmul + made name more generalizable
@pytorchbot merge

Merge started: your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.

Merge failed. Reason: 1 mandatory check(s) failed. Dig deeper by viewing the failures on hud.
torch/optim/sgd.py (Outdated)

```diff
@@ -354,6 +354,8 @@ def _single_tensor_sgd(
         if isinstance(lr, Tensor) and lr.requires_grad:
             param.addcmul_(grad, lr, value=-1)
         else:
+            # CPU scalar tensors w/out grad works but isn't supported in typehints
+            lr = cast(float, lr)
```
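For readers skimming the hunk above: the tensor-LR branch keeps the learning rate inside the autograd graph via addcmul, while the other branch is the ordinary scaled add. Below is a minimal, out-of-place sketch of that distinction (illustrative only, with a hypothetical helper name, not the optimizer source; the in-place variants in sgd.py behave analogously). The review comments that follow are about the `cast(float, lr)` line, i.e., what to do on the non-tensor branch when the type hints are stricter than the runtime behavior.

```python
import torch
from torch import Tensor

def apply_sgd_update(param: Tensor, grad: Tensor, lr) -> Tensor:
    # Hypothetical helper for illustration, not the real _single_tensor_sgd.
    if isinstance(lr, Tensor) and lr.requires_grad:
        # param - lr * grad, with lr staying in the autograd graph
        return param.addcmul(grad, lr, value=-1)
    # plain float (or non-differentiable scalar tensor) path
    return param.add(grad, alpha=-float(lr))

p = torch.zeros(3)
g = torch.ones(3)
lr_tensor = torch.tensor(0.1, requires_grad=True)

out = apply_sgd_update(p, g, lr_tensor)
out.sum().backward()
print(lr_tensor.grad)  # gradient with respect to the learning rate itself
```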
It should almost always already be a float; we don't want to do this in the sgd code, but we should enforce that it's always a float in the tests.
Okay! I added it because the CI failed, which I think is because the functional path uses really strict typing. That said, the cast felt pretty wrong when I added it, especially because I hadn't seen any casts in any of the code I've read. What do you think we should do about the strict typing?
The failed test is here: https://github.com/pytorch/pytorch/actions/runs/12403464090/job/34628316613?pr=143510
I was thinking there's a chance we could do just `if isinstance(lr, Tensor)`, because the previous error of `found at least two devices, mps:0 and cpu!` occurred on MPS because my previous PR wasn't merged into this branch yet. In the last PR I allowed addcmul to take in CPU scalars in the TensorIterator, which wasn't there before, so it used to flag any CPU scalar before it even got to the implementation. We also added testing of the scalar to the pre-existing addcmul test in the addcmul PR, and I believe it tested all the devices with the scalar and it seemed to work on all of them, but I could be terribly mistaken.

I'm also somewhat concerned that any tensor LRs would go into the addcmul, and there might be some behavior supported with add that isn't with addcmul, but it would fix this typing issue.
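To make the device point concrete, here is a hedged sketch of the cpu-scalar case being discussed. It assumes the earlier addcmul PR is present in the build (and, per the review below, that support is CUDA-only); without it, the `found at least two devices` error mentioned above is the expected outcome.

```python
import torch

if torch.cuda.is_available():
    param = torch.randn(4, device="cuda")
    grad = torch.randn(4, device="cuda")
    lr = torch.tensor(0.01)  # 0-dim CPU scalar tensor, requires_grad=False

    # param <- param - lr * grad; the CPU scalar is accepted by the CUDA kernel
    # only with the cpu-scalar addcmul support referenced in this thread.
    param.addcmul_(grad, lr, value=-1)
```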
What happens when you remove the type hint at the definition of this function?
Generally, jit script is too rigid with this; we shouldn't change our code semantics to make typing for jit script happy. Also, the addcmul path on MPS probably would not work even with your change, because your previous PR is CUDA only.
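As a toy illustration of the strict-typing point (a hypothetical function, not the real functional SGD): a jit-scripted function enforces its annotations at call time, so an lr annotated as float rejects a Tensor even when eager code would handle it fine.

```python
import torch

@torch.jit.script
def scaled_step(param: torch.Tensor, grad: torch.Tensor, lr: float) -> torch.Tensor:
    # Eager PyTorch would happily broadcast a 0-dim tensor here, but the
    # scripted signature only accepts a Python float.
    return param - lr * grad

p, g = torch.zeros(2), torch.ones(2)
scaled_step(p, g, 0.1)                  # OK: matches the annotation
# scaled_step(p, g, torch.tensor(0.1))  # RuntimeError: expected float, got Tensor
```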
@pytorchbot merge

Merge started: your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.

Merge failed. Reason: 1 mandatory check(s) failed. Dig deeper by viewing the failures on hud.
@pytorchbot merge

Merge started: your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Woohoo!
Second PR in a larger project to broaden support for differentiable optimizers with @janeyx99! The first one had an issue near the end, so this is the second PR on that subject. See #143122 for the development up until this point.
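As a rough sketch of what the feature enables (hedged: this mirrors the cloned-parameter pattern used by the differentiable-optimizer tests and assumes a build containing this PR), gradients can flow back through an SGD step into the learning rate itself:

```python
import torch

lr = torch.tensor(0.1, requires_grad=True)      # learning rate as a differentiable tensor
p = torch.randn(3, requires_grad=True).clone()  # non-leaf copy so the in-place step stays in the graph
p.grad = torch.randn(3)                         # stand-in gradient from some inner loss

opt = torch.optim.SGD([p], lr=lr, differentiable=True)
opt.step()                                      # p <- p - lr * p.grad, tracked by autograd

meta_loss = p.sum()                             # any loss on the updated parameter
meta_loss.backward()
print(lr.grad)                                  # d(meta_loss)/d(lr), the capability this PR targets
```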