POC for mixed prec optim frontend by janeyx99 · Pull Request #146640 · pytorch/pytorch

POC for mixed prec optim frontend #146640


Draft · wants to merge 4 commits into base: gh/janeyx99/222/base

Conversation

@janeyx99 janeyx99 (Contributor) commented Feb 6, 2025

This PR is a prototype of what a frontend for requesting mixed precision in torch.optim could look like, via set_dtype_policy in optimizer.py.

It is not meant to be landable, but to start a discussion about what people want or would like to see, and to surface anything I haven't considered yet.

This currently only works with Adam(W)!

A toy script showing how to use it:
```
import torch

model = torch.nn.Sequential(
    torch.nn.Linear(2, 3),
    torch.nn.Sigmoid(),
    torch.nn.Linear(3, 1),
    torch.nn.Sigmoid(),
)
model.to("cuda")

optim = torch.optim.AdamW(model.named_parameters(), foreach=False)
# Each entry maps an optimizer state name to a callable that returns the
# dtype that state should be stored in.
mp_policy = {
    "exp_avg": lambda _: torch.bfloat16,
    "exp_avg_sq": lambda _: torch.bfloat16,
    "max_exp_avg_sq": lambda _: torch.bfloat16,
}
optim.set_dtype_policy(mp_policy)

i = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5, 0.6], device="cuda").reshape(3, 2)
l = model(i).sum()
l.backward()

optim.step()
```
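
To make the proposed semantics concrete, here is a minimal sketch of how an optimizer could consult such a policy when lazily initializing state. This is not the PR's actual implementation; the helper name `init_state_tensor` and the assumption that each policy callable receives the parameter are illustrative only:
```
import torch

def init_state_tensor(policy, state_name, param):
    # NOTE: hypothetical helper; assumes the policy callable takes the
    # parameter and returns a dtype. The real PR may differ.
    dtype_fn = policy.get(state_name)
    # Fall back to the parameter's own dtype when no policy entry applies.
    dtype = dtype_fn(param) if dtype_fn is not None else param.dtype
    # Allocate the optimizer state in the requested precision.
    return torch.zeros_like(param, dtype=dtype)

# Example: exp_avg is kept in bfloat16; exp_avg_sq has no policy entry,
# so it falls back to the float32 parameter dtype.
p = torch.randn(3, 2, requires_grad=True)
policy = {"exp_avg": lambda _: torch.bfloat16}
assert init_state_tensor(policy, "exp_avg", p).dtype == torch.bfloat16
assert init_state_tensor(policy, "exp_avg_sq", p).dtype == torch.float32
```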


@janeyx99 janeyx99 requested a review from albanD as a code owner February 6, 2025 21:31
pytorch-bot bot commented Feb 6, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/146640

Note: Links to docs will display an error until the docs builds have been completed.

❌ 10 New Failures

As of commit 59010b7 with merge base 83bb921:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

janeyx99 added a commit that referenced this pull request Feb 6, 2025
ghstack-source-id: b3b5d46
Pull Request resolved: #146640
@janeyx99 janeyx99 marked this pull request as draft February 6, 2025 21:33
janeyx99 added a commit that referenced this pull request Feb 8, 2025
ghstack-source-id: 251b9c3
Pull Request resolved: #146640
```
torch._check(
    start.dtype == end.dtype,
    lambda: f"expected dtype {start.dtype} for `end`, but got dtype {end.dtype}",
)
# torch._check(
```
janeyx99 (Contributor, Author) commented:

We should uncomment this once #146749 is fixed
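
For context, torch._check asserts a condition and evaluates the message lambda lazily, only when the check fails. A minimal standalone illustration of the pattern (the tensor names mirror the snippet above but the values are made up):
```
import torch

start = torch.zeros(2, dtype=torch.float32)
end = torch.zeros(2, dtype=torch.float32)

# Passes silently here; if the dtypes differed, torch._check would raise,
# calling the lambda to build the error message only at that point.
torch._check(
    start.dtype == end.dtype,
    lambda: f"expected dtype {start.dtype} for `end`, but got dtype {end.dtype}",
)
```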

janeyx99 added a commit that referenced this pull request Feb 12, 2025
ghstack-source-id: 9f0e566
Pull Request resolved: #146640
@janeyx99 janeyx99 mentioned this pull request Feb 22, 2025
desai0007 pushed a commit to desai0007/test-repo-pytorch that referenced this pull request Feb 26, 2025
ghstack-source-id: 85f4266
Pull Request resolved: pytorch/pytorch#146640

Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

@github-actions github-actions bot added the Stale label Apr 23, 2025