Work around MPSGraph issue in backward pass of nn.ReplicationPad1d/2d #152094
Helpful links: see artifacts and rendered test results at hud.pytorch.org/pr/152094
Fixes pytorch#135447. When the third-from-last dimension is 2^16 or greater, MPSGraph returns 0 for the pad gradient. To work around this, we break the problematic dimension into chunks, each no greater than 2^16 - 1.

Test case for nn.ReplicationPad1d:

```
shape = [65739, 2, 4]
x_cpu = torch.randn(shape, device='cpu', requires_grad=True)
x_mps = x_cpu.clone().detach().to('mps').requires_grad_(True)
model = torch.nn.ReplicationPad1d((1, 1))
out_cpu = model(x_cpu)
out_mps = model(x_mps)

# backward
g_cpu = torch.randn_like(out_cpu)
g_mps = g_cpu.clone().detach().to('mps').requires_grad_(False)
out_cpu.backward(g_cpu)
out_mps.backward(g_mps)

print(f"{((x_cpu.grad - x_mps.grad.cpu()).abs() > 1e-5).sum() = }")
# Expected Output:
# ((x_cpu.grad - x_mps.grad.cpu()).abs() > 1e-5).sum() = tensor(0)
```

Test case for nn.ReplicationPad2d:

```
shape = [2, 65739, 2, 4]
x_cpu = torch.randn(shape, device='cpu', requires_grad=True)
x_mps = x_cpu.clone().detach().to('mps').requires_grad_(True)
model = torch.nn.ReplicationPad2d((1, 1, 1, 1))
out_cpu = model(x_cpu)
out_mps = model(x_mps)

# backward
g_cpu = torch.randn_like(out_cpu)
g_mps = g_cpu.clone().detach().to('mps').requires_grad_(False)
out_cpu.backward(g_cpu)
out_mps.backward(g_mps)

print(f"{((x_cpu.grad - x_mps.grad.cpu()).abs() > 1e-5).sum() = }")
# Expected Output:
# ((x_cpu.grad - x_mps.grad.cpu()).abs() > 1e-5).sum() = tensor(0)
```

These tests produce the expected output with this workaround.
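The chunking step of the workaround can be sketched in plain Python. This is a hedged illustration only, not the PR's actual Objective-C++ implementation; the helper name `sub_batch_sizes` is hypothetical.

```python
# Hedged sketch of the workaround described above. The problematic
# dimension is partitioned into sub-batches, each no larger than
# 2**16 - 1, the largest size for which MPSGraph is observed to
# compute the pad gradient correctly.
MAX_SUB_BATCH_SIZE = 2**16 - 1  # 65535

def sub_batch_sizes(dim_size: int) -> list[int]:
    """Partition dim_size into chunk sizes, each <= MAX_SUB_BATCH_SIZE."""
    sizes = []
    remaining = dim_size
    while remaining > 0:
        chunk = min(remaining, MAX_SUB_BATCH_SIZE)
        sizes.append(chunk)
        remaining -= chunk
    return sizes

# The failing shape from the 1d test case has 65739 in the problematic
# dimension; it splits into one full chunk plus a remainder.
print(sub_batch_sizes(65739))  # [65535, 204]
```

The pad-gradient op would then run per chunk and the results would be concatenated along the same dimension, which leaves the math unchanged because replication-pad gradients along other dimensions are independent across these sub-batches.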
Please fix lint; it also looks like it fails on exactly the test you are trying to add.
```
// We break the tensor into chunks where the problematic dimension is no greater than 2**16 - 1.
// This is reported in https://github.com/pytorch/pytorch/issues/135447.
// Internal radar for MPSGraph: rdar://149853787.
const int64_t max_sub_batch_size = 65535;
```
Suggested change:

```
- const int64_t max_sub_batch_size = 65535;
+ constexpr auto max_sub_batch_size = 65535;
```
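The suggestion above swaps a runtime `const` for a compile-time constant. A minimal sketch of why this is preferable, assuming the value is only ever used as a fixed limit (the `static_assert` is illustrative, not part of the PR):

```cpp
#include <cstdint>

// constexpr makes the limit usable in constant expressions, so the
// relationship to 2^16 - 1 can be checked at compile time; `auto`
// lets the compiler deduce the (int) type from the literal.
constexpr auto max_sub_batch_size = 65535;

// Compile-time check: the sub-batch limit equals 2^16 - 1.
static_assert(max_sub_batch_size == (1 << 16) - 1,
              "sub-batch limit must be 2^16 - 1");
```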
Thanks @malfet for the comments. I will follow up on these issues.