Add support for differentiable LR in SGD + test v2.0 by EmmettBicker · Pull Request #143510 · pytorch/pytorch · GitHub
Add support for differentiable LR in SGD + test v2.0 #143510


Closed
wants to merge 10 commits into from

Conversation

EmmettBicker
Contributor

Second PR in a larger project to broaden support for differentiable optimizers with @janeyx99! The first PR hit an issue near the end, so this is the follow-up on that work. See #143122 for the development up to this point.
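
For readers new to the feature, here is a minimal, hypothetical sketch of what a differentiable learning rate enables (a toy quadratic objective, not this PR's actual test code): the LR is a tensor that stays in the autograd graph, so a meta-objective can be differentiated with respect to it.

```
import torch

# Toy illustration (not from this PR): keep the learning rate in the autograd
# graph so gradients can flow back to it through the parameter update.
param = torch.randn(3, requires_grad=True)
lr = torch.tensor(0.1, requires_grad=True)

loss = (param ** 2).sum()
# create_graph=True keeps the gradient itself differentiable.
(grad,) = torch.autograd.grad(loss, param, create_graph=True)

# The SGD-style update this PR makes differentiable: param - lr * grad.
updated = param - lr * grad

# A meta-objective evaluated after the step; backprop reaches the LR.
meta_loss = (updated ** 2).sum()
meta_loss.backward()
print(lr.grad)  # non-None: the learning rate received a gradient
```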

pytorch-bot bot commented Dec 18, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/143510

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit ae57bfe with merge base 80a4239:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@mikaylagawarecki added the triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) label on Dec 18, 2024
EmmettBicker and others added 2 commits December 18, 2024 18:57
removed default = False on differentiables

lint

Addressed several of Jane's comments, still not ready to merge

Addressed more comments

Renamed kwargs to inner_kwargs, changed x, y to be simpler, reordered the test var definitions to be more logical, put lr back into inner_kwargs to make the function more adaptable for future enhancement

Add differentiable flag to functional_sgd.py

Streamlined tester function + easier support for more kwargs

Renamed + fixed test_differentiable_lr

Functional refactoring (last time I added default param in sgd() it broke CI, fingers crossed!)

lint

Update test/optim/test_optim.py

Co-authored-by: Jane (Yuan) Xu <31798555+janeyx99@users.noreply.github.com>

Add comment + attempt to under differentiable adding in torch/optim/sgd.py

Will revert this commit if we choose to do
```
if differentiable and isinstance(lr, Tensor):
```
instead of
```
if isinstance(lr, Tensor):
```

Add newline to revert earlier change

updated to use cpu-scalar addcmul + made name more generalizable

@janeyx99
Contributor

@pytorchbot merge

pytorch-bot bot added the ciflow/trunk (Trigger trunk jobs on your pull request) label on Dec 19, 2024
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

@pytorchmergebot
Collaborator

Merge failed

Reason: 1 mandatory check(s) failed. The first few are:

Dig deeper by viewing the failures on hud

Details for Dev Infra team. Raised by workflow job.

Failing merge rule: Core Maintainers

```
@@ -354,6 +354,8 @@ def _single_tensor_sgd(
        if isinstance(lr, Tensor) and lr.requires_grad:
            param.addcmul_(grad, lr, value=-1)
        else:
            # CPU scalar tensors w/out grad works but isn't supported in typehints
            lr = cast(float, lr)
```
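
For context, a rough sketch of the dispatch this hunk introduces; the `sgd_update_` helper and the `add_` fallback below are paraphrased for illustration, not quoted from torch/optim/sgd.py:

```
import torch
from torch import Tensor

def sgd_update_(param: Tensor, grad: Tensor, lr) -> None:
    # A tensor LR that requires grad must stay in the autograd graph, so the
    # update goes through addcmul_; a plain float (or non-grad scalar tensor)
    # can take the ordinary scalar-alpha path.
    if isinstance(lr, Tensor) and lr.requires_grad:
        # param <- param + (-1) * grad * lr, with lr tracked by autograd
        param.addcmul_(grad, lr, value=-1)
    else:
        param.add_(grad, alpha=-float(lr))

# Usage: after the update, param carries a grad_fn reaching back to lr_t.
p, g = torch.randn(3), torch.randn(3)
lr_t = torch.tensor(0.01, requires_grad=True)
sgd_update_(p, g, lr_t)
```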
Contributor

It should almost always already be a float; we don't want to do this in the sgd code, but we should enforce that it's always a float in the tests.

Contributor Author
@EmmettBicker Dec 19, 2024

Okay! I added it because the CI failed, which I think is because functional uses really strict typing. That said, the cast felt pretty wrong when I added it, especially because I hadn't seen any casts in any of the code I've read. What do you think we should do about the strict typing?
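
As a side note on the cast: `typing.cast` is purely a static-typing hint with no runtime effect, so the line does not actually convert the tensor. A small illustrative snippet (not from the PR):

```
import torch
from typing import cast

lr = torch.tensor(0.5)   # a CPU scalar tensor without grad
lr = cast(float, lr)     # no-op at runtime: only tells the type checker to treat lr as float
print(type(lr))          # still <class 'torch.Tensor'>
```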

Contributor Author
@EmmettBicker Dec 19, 2024

I was thinking there's a chance we could do just `if isinstance(lr, Tensor)`, because the previous error of "found at least two devices, mps:0 and cpu!" occurred on MPS because my previous PR wasn't merged on this branch yet. In the last PR I allowed addcmul to take in CPU scalars in the TensorIterator, and that wasn't there before, so it used to flag any CPU scalar before it even got to the implementation. We also added testing of the scalar to the preexisting addcmul test in the addcmul PR, and I believe it tested all the devices with the scalar and it seemed to work on all of them, but I could be terribly mistaken.

I'm also a bit concerned that any tensor LRs would go into the addcmul path and there might be some behavior supported with add that isn't with addcmul, but it would fix this typing issue.
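
To make the "cpu-scalar addcmul" reference concrete, a small device-agnostic illustration; whether a CPU scalar LR is accepted when the parameter lives on another device is exactly what the earlier addcmul PR addressed, so this sketch stays on CPU:

```
import torch

param = torch.randn(3)
grad = torch.randn(3)
lr = torch.tensor(0.01)  # 0-dim CPU scalar tensor

# Same shape of update as the differentiable path: param += -1 * grad * lr
param.addcmul_(grad, lr, value=-1)
```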

Contributor

What happens when you remove the type hint at the definition of this function?

Generally, jit script is too rigid with this; we shouldn't change our code semantics to make typing for jit script happy. Also, the addcmul path on MPS probably would not work even with your change, because your previous PR is CUDA-only.

@janeyx99
Contributor

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

@pytorchmergebot
Collaborator

Merge failed

Reason: 1 mandatory check(s) failed. The first few are:

Dig deeper by viewing the failures on hud

Details for Dev Infra team. Raised by workflow job.

Failing merge rule: Core Maintainers

@EmmettBicker
Contributor Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

@EmmettBicker
Contributor Author

Woohoo!

Labels
ciflow/trunk (Trigger trunk jobs on your pull request) · Merged · open source · release notes: optim · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)