Add max_pool3d backward pass for MPS #157498
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/157498
Note: Links to docs will display an error until the docs builds have been completed.
⏳ No Failures, 75 Pending as of commit beb403b with merge base 8a47f9d. UNSTABLE: the following job is marked as unstable, possibly due to flakiness on trunk.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Attention! native_functions.yaml was changed. If you are adding a new function or defaulted argument to native_functions.yaml, you cannot use it from pre-existing Python frontend code until our FC window passes (two weeks). Split your PR into two PRs: one that adds the new C++ functionality, and one that makes use of it from Python, and land them two weeks apart. See https://github.com/pytorch/pytorch/wiki/PyTorch's-Python-Frontend-Backward-and-Forward-Compatibility-Policy#forwards-compatibility-fc for more info.

Attention! One of the PyTorch C-stable API files was changed. You MUST NOT change existing function declarations in it, as this header defines a stable C ABI. If you need to change the signature of a function, introduce a new v2 version of the function and modify code generation to target the new version of the function.
test/test_mps.py (outdated)

```python
# This is the mismatching element
self.assertEqual(input_cpu.grad[0, 0, 2, 1], 3.1816, atol=atol, rtol=rtol)
self.assertEqual(input_mps.grad[0, 0, 2, 1], 3.1758, atol=atol, rtol=rtol)
```
To me, this seems like an abnormally large difference. The elements getting summed from the output grad here are: [0.1816, 0.5708, 0.6470, 0.7988, 0.1284, 0.0864, 0.1357, 0.2300, 0.3999]. Summing just nine numbers of similar magnitudes seems like it should not give such a high error, even if the order of summation differs.
Interestingly, torch.sum() gives the same result whether it's run on MPS or CPU tensors, even if the order changes. But that result is different from what max_pool3d_backward gives on both MPS and CPU.
```python
>>> a = torch.tensor([0.1816, 0.5708, 0.6470, 0.7988, 0.1284, 0.0864, 0.1357, 0.2300, 0.3999], dtype=torch.float16)
>>> a[torch.randperm(9)].contiguous().sum()
tensor(3.1797, dtype=torch.float16)
>>> a[torch.randperm(9)].contiguous().sum()
tensor(3.1797, dtype=torch.float16)
>>> a.to('mps')[torch.randperm(9)].contiguous().sum()
tensor(3.1797, device='mps:0', dtype=torch.float16)
>>> a.to('mps')[torch.randperm(9)].contiguous().sum()
tensor(3.1797, device='mps:0', dtype=torch.float16)
```
Maybe torch.sum() sorts the values. I tried forcing my own random summation ordering like this:
```python
import torch
import time

vals = [0.1816, 0.5708, 0.6470, 0.7988, 0.1284, 0.0864, 0.1357, 0.2300, 0.3999]
torch.manual_seed(time.time())
dtype = torch.float16

for device in ['cpu', 'mps']:
    results = []
    for _ in range(10):
        a = [torch.tensor(v, dtype=dtype, device=device) for v in vals]
        r = torch.tensor(0, dtype=dtype, device=device)
        for idx in torch.randperm(len(a)):
            r += a[idx]
        results.append(r)
    results = torch.stack(results)
    print(f'{device}: {results}')
```

And below is what I get when I run it. I actually do see significantly different results with different orderings:

```
cpu: tensor([3.1777, 3.1797, 3.1797, 3.1777, 3.1797, 3.1797, 3.1797, 3.1797, 3.1797,
             3.1797], dtype=torch.float16)
mps: tensor([3.1797, 3.1777, 3.1797, 3.1797, 3.1797, 3.1797, 3.1797, 3.1777, 3.1777,
             3.1797], device='mps:0', dtype=torch.float16)
```
Perhaps both the CPU and MPS impls of max_pool3d_backward are within acceptable precision after all, and maybe we just need to increase the tolerance in test_output_grad_match.
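For illustration, here is a standalone sketch of the kind of comparison involved; the shapes, kernel parameters, and tolerance values below are my own guesses, not what test_output_grad_match actually uses:

```python
import torch
import torch.nn.functional as F

# Standalone sketch of the CPU-vs-MPS gradient comparison (illustrative values only).
x_cpu = torch.randn(1, 1, 4, 4, 4).to(torch.float16).requires_grad_(True)
x_mps = x_cpu.detach().to("mps").requires_grad_(True)

out_cpu = F.max_pool3d(x_cpu, kernel_size=3, stride=1)
out_mps = F.max_pool3d(x_mps, kernel_size=3, stride=1)

grad = torch.randn(out_cpu.shape).to(torch.float16)
out_cpu.backward(grad)
out_mps.backward(grad.to("mps"))

# fp16 accumulation order can differ between the two backends, so the gradient
# comparison needs a looser tolerance than the float32 defaults.
torch.testing.assert_close(x_mps.grad.cpu(), x_cpu.grad, atol=6e-3, rtol=1e-3)
```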
Ok actually, I think this is all within acceptable precision after all.
A float16 number has 10 bits of mantissa, 5 bits of exponent, and 1 bit for the sign. If the sign bit is positive, then with mantissa $m$ and exponent $e$ written in base 10, the number the float16 format represents is $(1 + m/1024) \cdot 2^e$.

Consider adding two numbers $a$ and $b$ with arbitrary mantissas, and say their exponents are $e_a = 1$ (so $2 \le a < 4$) and $e_b = -3$ (so $0.125 \le b < 0.25$). Assume the result has the same exponent as $a$. Since the exponents differ by 4, we effectively have to truncate the 4 rightmost bits of $b$'s mantissa, which introduces a maximum error on the order of $(2^4 / 1024) \cdot 2^{-3} \approx 0.002$.

The error is nearly the same if $e_b = -2$ (so $0.25 \le b < 0.5$), where the 3 rightmost bits are truncated, giving a maximum error on the order of $(2^3 / 1024) \cdot 2^{-2} \approx 0.002$. Same for $e_b = -1$.

So if we're adding up nine different numbers that all have exponents -3, -2, or -1, and they sum to a number with exponent 1, we would expect a maximum error of several times 0.002. In my comments above, summing those particular nine numbers in different ways gave results that ranged between 3.1816 and 3.1758, a difference of $0.0058 \approx 2.9 \times 0.002$.
So I think that's within the acceptable bounds, and we can safely just increase the error tolerance used in test_output_grad_match for the case of max_pool3d_backward with float16.
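To put a number on that, here is a small experiment of my own (not from the PR) that sums the same nine values in float16 in ascending and descending order; the two orderings can land on different float16 results, each within the few-thousandths error derived above:

```python
import torch

vals = [0.1816, 0.5708, 0.6470, 0.7988, 0.1284, 0.0864, 0.1357, 0.2300, 0.3999]

def fp16_sum(order):
    # Accumulate one element at a time, rounding to float16 after every add,
    # the way a serial reduction over the pooling window would.
    acc = torch.tensor(0.0, dtype=torch.float16)
    for v in order:
        acc = acc + torch.tensor(v, dtype=torch.float16)
    return acc.item()

print(sum(vals))                             # float64 reference, about 3.1786
print(fp16_sum(sorted(vals)))                # ascending order
print(fp16_sum(sorted(vals, reverse=True)))  # descending order
# The two orderings can land on different float16 values, each within a few
# thousandths of the reference, which is the size of error derived above.
```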
I've updated the tolerance values. If this is alright, I'll remove the test I added here as well before merging.
I decided to get rid of this test now because it was making CI fail, even though it passes when I run it locally. I only added it in order to understand why CPU and MPS can give such different results, and I always intended to remove it afterward.
```cpp
    size_prod *= grad_input_sizes[dim];
}
...
AtomicType<T>::atomic_add(
```
Is AtomicType::atomic_add deterministic? My guess is no, in which case I think I should mark this op nondeterministic for torch.use_deterministic_algorithms. I think we would only need the nondeterministic alert to be raised if the input requires grad and the stride is less than kernel size in any of the dimensions.
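For intuition on that condition, here is a small CPU-only illustration of my own: with stride less than kernel size the pooling windows overlap, so one grad_input element can receive many contributions, and those are exactly the additions whose order an atomic add would not fix.

```python
import torch
import torch.nn.functional as F

# With stride < kernel_size the pooling windows overlap, so a single input
# position can be the argmax of several output positions; its gradient is then
# the sum of several grad_output values (the additions an MPS kernel would do
# with AtomicType<T>::atomic_add).
x = torch.zeros(1, 1, 3, 3, 3)
x[0, 0, 1, 1, 1] = 10.0          # make the center the max of every window
x.requires_grad_(True)

out = F.max_pool3d(x, kernel_size=2, stride=1)   # 2x2x2 windows, stride 1
out.backward(torch.ones_like(out))

# All 8 overlapping windows selected the center element, so it accumulates 8
# contributions; with stride >= kernel_size each element gets at most one.
print(x.grad[0, 0, 1, 1, 1])     # tensor(8.)
```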
If there is a need for it, I could write a deterministic alternative that doesn't use atomic add. The CUDA impl doesn't have this yet either.
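For what it's worth, one way such a deterministic alternative could be formulated is gather-style: each grad_input element sums, in a fixed order, the grad_output entries whose argmax selected it, rather than output positions scattering into grad_input with atomic adds. A rough PyTorch-level sketch of that idea (mine, not the kernel in this PR):

```python
import torch
import torch.nn.functional as F

# Gather-style reference for max_pool3d backward: every input position sums the
# grad_output values whose argmax picked it, in a fixed order, so no atomic
# adds are needed. (O(inputs * outputs) work; only meant to show the idea.)
x = torch.randn(1, 1, 4, 4, 4, requires_grad=True)
out, idx = F.max_pool3d(x, kernel_size=2, stride=1, return_indices=True)
grad_out = torch.randn_like(out)

flat_idx = idx.view(1, 1, -1)                  # argmax flat index per output elem
go = grad_out.view(1, 1, -1)
positions = torch.arange(x[0, 0].numel())      # every flat input position
# mask[n, c, o, p] is True when output element o selected input position p.
mask = flat_idx.unsqueeze(-1) == positions
grad_in = (go.unsqueeze(-1) * mask).sum(dim=2).view_as(x)

out.backward(grad_out)                         # autograd's scatter-style result
print(torch.allclose(grad_in, x.grad))         # True
```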
@pytorchbot merge -f "Lint + MPS are green"
Merge started. Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Stack from ghstack (oldest at bottom):
Add max_pool3d backward pass for MPS #157498

Note on backward precision over fp16:
A float16 number has 10 bits of mantissa, 5 bits of exponent, and 1 bit for the sign. If the sign bit is positive, then with mantissa $m$ and exponent $e$ written in base 10, the number the float16 format represents is $(1 + m/1024) \cdot 2^e$. (source)

Consider adding two numbers $a$ and $b$ with arbitrary mantissas, and say their exponents are $e_a = 1$ (so $2 \le a < 4$) and $e_b = -3$ (so $0.125 \le b < 0.25$). Assume the result has the same exponent as $a$. Since the exponents differ by 4, we effectively have to truncate the 4 rightmost bits of $b$'s mantissa, which introduces a maximum error on the order of $(2^4 / 1024) \cdot 2^{-3} \approx 0.002$.

The error is nearly the same if $e_b = -2$ (so $0.25 \le b < 0.5$), where the 3 rightmost bits are truncated, giving a maximum error on the order of $(2^3 / 1024) \cdot 2^{-2} \approx 0.002$. Same for $e_b = -1$.

So if we're adding up nine different numbers that all have exponents -3, -2, or -1, and they sum to a number with exponent 1, we would expect a maximum error of several times 0.002. In my comments above, summing those particular nine numbers in different ways gave results that ranged between 3.1816 and 3.1758, a difference of $0.0058 \approx 2.9 \times 0.002$.
That's within the acceptable bounds, and we can safely just increase the error tolerance used in test_output_grad_match for the case of max_pool3d_backward with float16.
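As one more sanity check (my addition, not part of the PR description): the per-addition bound above is exactly one float16 ULP in $[2, 4)$, and the observed spread is about three of them.

```python
# The maximum truncation error from the derivation above, for e_b = -3:
per_add_error = (2 ** 4 / 1024) * 2 ** -3
print(per_add_error)              # 0.001953125 == 2**-9, one float16 ULP in [2, 4)

# Spread between the CPU and MPS float16 results observed in the test:
spread = 3.1816 - 3.1758
print(spread / per_add_error)     # about 3, i.e. the results differ by ~3 ULPs
```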