Add `max_pool3d` backward pass for MPS by kurtamohler · Pull Request #157498 · pytorch/pytorch

Conversation

@kurtamohler (Collaborator) commented Jul 2, 2025

Stack from ghstack (oldest at bottom):

Note on backward precision over fp16:

A float16 number has 10 bits of mantissa, 5 bits of exponent, and 1 sign bit. For a positive normal number with mantissa $m$ and unbiased exponent $e$ (both written in base 10), the value the float16 format represents is $(1 + m / 1024) \cdot 2^e$. (source)
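As a quick sanity check of that formula, here is a small sketch (not part of the PR; the decode_fp16 helper is made up for illustration) that pulls the sign, exponent, and mantissa fields out of a float16 bit pattern with NumPy and reconstructs the value:

import numpy as np

def decode_fp16(x):
    # Reinterpret the float16 bit pattern as an unsigned 16-bit integer
    bits = int(np.float16(x).view(np.uint16))
    sign = (bits >> 15) & 0x1
    e = ((bits >> 10) & 0x1F) - 15   # remove the exponent bias of 15
    m = bits & 0x3FF
    # The formula above, valid for normal (non-zero, non-subnormal) values
    value = (-1.0) ** sign * (1 + m / 1024) * 2.0 ** e
    return e, m, value

for x in [0.1816, 0.5708, 3.1797]:
    e, m, value = decode_fp16(x)
    print(f'{x}: e={e}, m={m}, reconstructed={value}, float16={float(np.float16(x))}')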

Consider adding two numbers $a$ and $b$ which have arbitrary mantissas, and say their exponents are $e_a = 1$ (so $2 \le a \lt 4$) and $e_b = -3$ (so $0.125 \le b \lt 0.25$). Assume that the result has the same exponent as $a$. Since the exponents differ by 4, we'll effectively need to truncate the 4 rightmost bits of $b$'s mantissa, which would introduce a maximum error on the order of $(2^4 / 1024) \cdot 2^{-3} \approx 0.002$.

The error is nearly the same if $e_b = -2$ (so $0.25 \le b \lt 0.5$), where the 3 rightmost bits are truncated, giving a maximum error on the order of $(2^3 / 1024) \cdot 2^{-2} \approx 0.002$. Same for $e_b = -1$.

So if we're adding up nine different numbers that all have exponents -3, -2, or -1, and they sum to a number with exponent 1, then we would expect a maximum error several times larger than 0.002. In my comments above, summing those particular nine numbers in different ways gave results that ranged between 3.1816 and 3.1758, a difference of $0.0058 \approx 2.9 \times 0.002$.

That's within the acceptable bounds, and we can safely just increase the error tolerance used in test_output_grad_match for the case of max_pool3d_backward with float16.
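A quick numerical cross-check of that ~0.002 figure (just a sketch, not anything in the test): near 3.18 the spacing between adjacent float16 values is $2^{-9}$, so each rounded addition to the running sum can be off by up to about one such step.

import numpy as np

# One ulp of float16 near 3.18: values in [2, 4) are spaced 2**(1 - 10) apart
print(np.spacing(np.float16(3.18)))   # ~0.001953
print(2.0 ** -9)                      # 0.001953125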

@pytorch-bot added the ciflow/inductor, ciflow/mps (Run MPS tests, subset of trunk), release notes: mps (Release notes category), and release notes: inductor (aoti) labels Jul 2, 2025
pytorch-bot commented Jul 2, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/157498

Note: Links to docs will display an error until the docs builds have been completed.

⏳ No Failures, 75 Pending

As of commit beb403b with merge base 8a47f9d:
💚 Looks good so far! There are no failures yet. 💚

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

github-actions bot (Contributor) commented Jul 2, 2025

Attention! native_functions.yaml was changed

If you are adding a new function or defaulted argument to native_functions.yaml, you cannot use it from pre-existing Python frontend code until our FC window passes (two weeks). Split your PR into two PRs, one which adds the new C++ functionality, and one that makes use of it from Python, and land them two weeks apart. See https://github.com/pytorch/pytorch/wiki/PyTorch's-Python-Frontend-Backward-and-Forward-Compatibility-Policy#forwards-compatibility-fc for more info.


Caused by:

github-actions bot (Contributor) commented Jul 2, 2025

Attention! One of the PyTorch C-stable API files was changed

You MUST NOT change existing function declarations in this file, as this header defines a stable C ABI. If you need to change the signature of a function, introduce a new v2 version of the function and modify code generation to target the new version of the function.


Caused by:

kurtamohler added a commit that referenced this pull request Jul 3, 2025
ghstack-source-id: 896823e
Pull-Request: #157498
kurtamohler added a commit that referenced this pull request Jul 3, 2025
ghstack-source-id: 9162184
Pull-Request: #157498
test/test_mps.py (Outdated)

# This is the mismatching element
self.assertEqual(input_cpu.grad[0, 0, 2, 1], 3.1816, atol=atol, rtol=rtol)
self.assertEqual(input_mps.grad[0, 0, 2, 1], 3.1758, atol=atol, rtol=rtol)
@kurtamohler (Collaborator, Author) commented Jul 3, 2025

To me, this seems like an abnormally large difference. The elements getting summed from the output grad here are: [0.1816, 0.5708, 0.6470, 0.7988, 0.1284, 0.0864, 0.1357, 0.2300, 0.3999]. Summing just nine numbers of similar magnitudes seems like it should not give such a high error, even if the order of summation differs.

@kurtamohler (Collaborator, Author) commented Jul 3, 2025

Interestingly, torch.sum() gives the same result whether it's run on MPS or CPU tensors, even if the order changes. But that result is still different from what max_pool3d_backward produces on both MPS and CPU.

>>> a = torch.tensor([0.1816, 0.5708, 0.6470, 0.7988, 0.1284, 0.0864, 0.1357, 0.2300, 0.3999], dtype=torch.float16)
>>> a[torch.randperm(9)].contiguous().sum()
tensor(3.1797, dtype=torch.float16)
>>> a[torch.randperm(9)].contiguous().sum()
tensor(3.1797, dtype=torch.float16)
>>> a.to('mps')[torch.randperm(9)].contiguous().sum()
tensor(3.1797, device='mps:0', dtype=torch.float16)
>>> a.to('mps')[torch.randperm(9)].contiguous().sum()
tensor(3.1797, device='mps:0', dtype=torch.float16)

@kurtamohler (Collaborator, Author):

Maybe torch.sum() sorts the values. I tried forcing my own random summation ordering like this:

import torch
import time

# The nine grad_output values that feed the mismatching grad_input element
vals = [0.1816, 0.5708, 0.6470, 0.7988, 0.1284, 0.0864, 0.1357, 0.2300, 0.3999]
torch.manual_seed(time.time())
dtype = torch.float16

for device in ['cpu', 'mps']:
    results = []
    for _ in range(10):
        a = [torch.tensor(v, dtype=dtype, device=device) for v in vals]
        r = torch.tensor(0, dtype=dtype, device=device)
        # Accumulate the values one at a time in a random order
        for idx in torch.randperm(len(a)):
            r += a[idx]

        results.append(r)

    results = torch.stack(results)
    print(f'{device}: {results}')

And below is what I get when I run it. I actually do see significantly different results with different orderings:

cpu: tensor([3.1777, 3.1797, 3.1797, 3.1777, 3.1797, 3.1797, 3.1797, 3.1797, 3.1797,
        3.1797], dtype=torch.float16)
mps: tensor([3.1797, 3.1777, 3.1797, 3.1797, 3.1797, 3.1797, 3.1797, 3.1777, 3.1777,
        3.1797], device='mps:0', dtype=torch.float16)

Perhaps both the CPU and MPS impls of max_pool3d_backward are within acceptable precision after all, and maybe we just need to increase the tolerance in test_output_grad_match.
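For what it's worth, here is one way such a tolerance could be estimated (fp16_sum_atol is a hypothetical helper, not the logic test_output_grad_match actually uses): allow roughly one float16 ulp of the expected value per addition that feeds the gradient element.

import numpy as np

def fp16_sum_atol(expected, n_adds):
    # n_adds rounded additions, each off by at most ~1 ulp of the running sum
    return n_adds * float(np.spacing(np.float16(expected)))

print(fp16_sum_atol(3.18, 9))   # ~0.0176, which covers the observed 0.0058 gap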

@kurtamohler (Collaborator, Author) commented Jul 3, 2025

Ok actually, I think this is all within acceptable precision after all.

A float16 number has 10 bits of mantissa, 5 bits of exponent, and 1 sign bit. For a positive normal number with mantissa $m$ and unbiased exponent $e$ (both written in base 10), the value the float16 format represents is $(1 + m / 1024) \cdot 2^e$. (source)

Consider adding two numbers $a$ and $b$ which have arbitrary mantissas, and say their exponents are $e_a = 1$ (so $2 \le a \lt 4$) and $e_b = -3$ (so $0.125 \le b \lt 0.25$). Assume that the result has the same exponent as $a$. Since the exponents differ by 4, we'll effectively need to truncate the 4 rightmost bits of $b$'s mantissa, which would introduce a maximum error on the order of $(2^4 / 1024) \cdot 2^{-3} \approx 0.002$.

The error is nearly the same if $e_b = -2$ (so $0.25 \le b \lt 0.5$), where the 3 rightmost bits are truncated, giving a maximum error on the order of $(2^3 / 1024) \cdot 2^{-2} \approx 0.002$. Same for $e_b = -1$.

So if we're adding up nine different numbers that all have exponents -3, -2, or -1, and they sum to a number with exponent 1, then we would expect a maximum error several times larger than 0.002. In my comments above, summing those particular nine numbers in different ways gave results that ranged between 3.1816 and 3.1758, a difference of $0.0058 \approx 2.9 \times 0.002$.

So I think that's within the acceptable bounds, and we can safely just increase the error tolerance used in test_output_grad_match for the case of max_pool3d_backward with float16.

@kurtamohler (Collaborator, Author):

I've updated the tolerance values. If this is alright, I'll remove the test I added here as well before merging.

@kurtamohler (Collaborator, Author):

I decided to get rid of this test now because it was making CI fail, even though it passes when I run it locally. I only added it with the intention of removing it once I understood why CPU and MPS can give such different results.

size_prod *= grad_input_sizes[dim];
}

AtomicType<T>::atomic_add(
@kurtamohler (Collaborator, Author) commented Jul 4, 2025

Is AtomicType::atomic_add deterministic? My guess is no, in which case I think I should mark this op nondeterministic for torch.use_deterministic_algorithms. I think we would only need the nondeterministic alert to be raised if the input requires grad and the stride is less than the kernel size in any of the dimensions.
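To illustrate why accumulation order can vary at all here: when the stride is smaller than the kernel size, the pooling windows overlap, so several output elements scatter their grads into the same input element (each via an atomic add in this kernel). A small CPU-only sketch, not part of the PR, that counts how many windows contribute to one input element:

import torch
import torch.nn.functional as F

x = torch.zeros(1, 1, 4, 4, 4)
x[0, 0, 2, 1, 1] = 10.0          # make this element the max of every window that contains it
x.requires_grad_(True)

out = F.max_pool3d(x, kernel_size=2, stride=1)
out.sum().backward()             # grad_output is all ones

# Equals the number of overlapping windows whose argmax is (2, 1, 1): 8 here
print(x.grad[0, 0, 2, 1, 1])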

@kurtamohler (Collaborator, Author):

If there is a need for it, I could write a deterministic alternative that doesn't use atomic add. The CUDA impl doesn't have this yet either.
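For reference, here is a rough tensor-level sketch of that gather-style idea (everything here is illustrative only: the function name, the kernel_size/stride-only signature, and the assumption of no padding or dilation are all mine, not the PR's kernel). Each grad_input element visits, in a fixed order, the output windows that could contain it and sums the grads whose recorded argmax points back at it, so no atomics are needed.

import itertools
import torch
import torch.nn.functional as F

def maxpool3d_backward_gather(x, grad_out, out_idx, k, s):
    # Deterministic because each grad_input element does its own fixed-order
    # sum instead of many output elements atomically adding into it.
    n, c, D, H, W = x.shape
    _, _, oD, oH, oW = grad_out.shape
    gx = torch.zeros_like(x)
    for b, ch, d, h, w in itertools.product(range(n), range(c), range(D), range(H), range(W)):
        flat = (d * H + h) * W + w            # flattened input index, as stored in out_idx
        acc = torch.zeros((), dtype=x.dtype)
        # output positions whose window of size k (stride s) covers (d, h, w)
        for od in range(max(0, (d - k + s) // s), min(oD, d // s + 1)):
            for oh in range(max(0, (h - k + s) // s), min(oH, h // s + 1)):
                for ow in range(max(0, (w - k + s) // s), min(oW, w // s + 1)):
                    if out_idx[b, ch, od, oh, ow] == flat:
                        acc = acc + grad_out[b, ch, od, oh, ow]
        gx[b, ch, d, h, w] = acc
    return gx

# Cross-check the sketch against autograd on CPU
x = torch.randn(1, 1, 4, 4, 4, requires_grad=True)
out, idx = F.max_pool3d(x, kernel_size=2, stride=1, return_indices=True)
go = torch.randn_like(out)
out.backward(go)
print(torch.allclose(maxpool3d_backward_gather(x, go, idx, 2, 1), x.grad))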

kurtamohler added a commit that referenced this pull request Jul 4, 2025
ghstack-source-id: 73cdd3b
Pull-Request: #157498
kurtamohler added a commit that referenced this pull request Jul 7, 2025
ghstack-source-id: 4d64ba1
Pull-Request: #157498
@malfet (Contributor) commented Jul 7, 2025

@pytorchbot merge -f "Lint + MPS are green"

@pytorchmergebot (Collaborator):

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as a last resort and instead consider -i/--ignore-current to continue the merge while ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here

