POC for mixed prec optim frontend #146640
base: gh/janeyx99/222/base
Conversation
[ghstack-poisoned]
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/146640
Note: Links to docs will display an error until the docs builds have been completed.
❌ 10 New Failures
As of commit 59010b7 with merge base 83bb921:
NEW FAILURES - The following jobs have failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
This PR is a prototype of what a frontend for requesting mixed precision in torch.optim could look like, via set_dtype_policy in optimizer.py. It is not meant to be landable, but to start some discussion about what people want/would like to see, and to ask whether there are things I haven't considered yet.

This currently only works with Adam(W)!

A toy script showing how to use it:

```
import torch

model = torch.nn.Sequential(
    torch.nn.Linear(2, 3),
    torch.nn.Sigmoid(),
    torch.nn.Linear(3, 1),
    torch.nn.Sigmoid(),
)
model.to("cuda")

optim = torch.optim.AdamW(model.named_parameters(), foreach=False)
mp_policy = {
    "exp_avg": lambda _: torch.bfloat16,
    "exp_avg_sq": lambda _: torch.bfloat16,
    "max_exp_avg_sq": lambda _: torch.bfloat16,
}
optim.set_dtype_policy(mp_policy)

i = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5, 0.6], device="cuda").reshape(3, 2)
l = model(i).sum()
l.backward()
optim.step()
```

[ghstack-poisoned]
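Not part of the PR text, but as a quick sanity check on the toy script above, one could inspect the optimizer state after `optim.step()`. This is only a sketch under assumptions: it assumes the prototype keeps per-parameter state keyed by parameter tensor under the usual Adam(W) keys (`exp_avg`, `exp_avg_sq`), as stock `torch.optim` does, and that the policy's dtypes are applied when the state is created. `max_exp_avg_sq` only exists with `amsgrad=True`, so it is not checked here.

```
# Sketch only: assumes the prototype stores per-parameter state under the usual
# Adam(W) keys in optim.state, with the policy's dtypes applied after step().
for p in model.parameters():
    state = optim.state[p]
    assert state["exp_avg"].dtype == torch.bfloat16
    assert state["exp_avg_sq"].dtype == torch.bfloat16
    # The parameters themselves stay in their original dtype (float32 here);
    # only the optimizer state is affected by the policy.
    assert p.dtype == torch.float32
```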
torch/_meta_registrations.py
Outdated
    start.dtype == end.dtype,
    lambda: f"expected dtype {start.dtype} for `end`, but got dtype {end.dtype}",
)
# torch._check(
We should uncomment this once #146749 is fixed
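For readers less familiar with this internal API: `torch._check(cond, message)` raises when `cond` is false, with `message` supplied as a callable so the string is only built on failure. Below is a minimal, self-contained illustration of the same dtype check shown in the diff; the `_example_check` wrapper is hypothetical and not from the file.

```
import torch

# Hypothetical wrapper, for illustration only; mirrors the dtype check in the diff.
def _example_check(start: torch.Tensor, end: torch.Tensor) -> None:
    torch._check(
        start.dtype == end.dtype,
        lambda: f"expected dtype {start.dtype} for `end`, but got dtype {end.dtype}",
    )

_example_check(torch.zeros(2), torch.zeros(2))            # passes silently
# _example_check(torch.zeros(2), torch.zeros(2).half())   # would raise a RuntimeError
```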
ghstack-source-id: 85f4266
Pull Request resolved: pytorch/pytorch#146640
Looks like this PR hasn't been updated in a while, so we're going to go ahead and mark this as `Stale`.