torch.tril introduces NaNs on MPS when matrix contained Infs (when diagonal is negative) · Issue #149813 · pytorch/pytorch · GitHub

Closed
twoertwein opened this issue Mar 22, 2025 · 1 comment
Labels
- module: correctness (silent) - issue that returns an incorrect result silently
- module: mps - Related to Apple Metal Performance Shaders framework
- module: NaNs and Infs - Problems related to NaN and Inf handling in floating point
- triaged - This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

twoertwein (Contributor) commented Mar 22, 2025

🐛 Describe the bug

# bug
(Pdb) torch.tril(torch.full((3, 3), float("inf"), device="mps"), diagonal=-1)
tensor([[nan, nan, nan],
        [inf, nan, nan],
        [inf, inf, nan]], device='mps:0')

# working examples
# works with non-infs
(Pdb) torch.tril(torch.full((3, 3), 1.0, device="mps"), diagonal=-1)
tensor([[0., 0., 0.],
        [1., 0., 0.],
        [1., 1., 0.]], device='mps:0')
# works on the cpu
(Pdb) torch.tril(torch.full((3, 3), float("inf"), device="cpu"), diagonal=-1)
tensor([[0., 0., 0.],
        [inf, 0., 0.],
        [inf, inf, 0.]])
# works for diagonal=0
(Pdb) torch.tril(torch.full((3, 3), float("inf"), device="mps"), diagonal=0)
tensor([[inf, 0., 0.],
        [inf, inf, 0.],
        [inf, inf, inf]], device='mps:0')
# works for positive diagonal
(Pdb) torch.tril(torch.full((3, 3), float("inf"), device="mps"), diagonal=1)
tensor([[inf, inf, 0.],
        [inf, inf, inf],
        [inf, inf, inf]], device='mps:0')

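Temporary workaround: transpose the matrix, apply triu with the negated (now positive) diagonal, and transpose back; alternatively, move the tensor to the CPU, apply tril there, and move the result back to MPS.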
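The transpose workaround relies on the identity that tril(x, k) equals the transpose of triu(x transposed, -k): transposing swaps the row and column indices, so the condition j - i <= k becomes j' - i' >= -k, which is what triu keeps. A negative tril diagonal therefore becomes a positive triu diagonal, the case reported to work on MPS. The helper names below are hypothetical, and routing only negative diagonals on MPS through the transpose path is an assumption based on the behaviour described in this issue.

```python
import torch


def tril_via_triu(x: torch.Tensor, diagonal: int = 0) -> torch.Tensor:
    # Transpose, keep the upper triangle with the negated diagonal, transpose back.
    # tril keeps j - i <= k; after transposing, the same entries satisfy j' - i' >= -k.
    return torch.triu(x.transpose(-2, -1), diagonal=-diagonal).transpose(-2, -1)


def tril_via_cpu(x: torch.Tensor, diagonal: int = 0) -> torch.Tensor:
    # Alternative workaround: run tril on the CPU and move the result back.
    return torch.tril(x.cpu(), diagonal=diagonal).to(x.device)


def tril_safe(x: torch.Tensor, diagonal: int = 0) -> torch.Tensor:
    # Hypothetical wrapper: only the negative-diagonal case was reported broken on MPS.
    if x.device.type == "mps" and diagonal < 0:
        return tril_via_triu(x, diagonal)
    return torch.tril(x, diagonal=diagonal)


device = "mps" if torch.backends.mps.is_available() else "cpu"
x = torch.full((3, 3), float("inf"), device=device)
print(tril_safe(x, diagonal=-1))
```

The transpose route stays on the device and returns a view-based result, while the CPU round trip costs two copies but avoids the MPS kernel entirely.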

Versions

torch was installed with uv; the collect_env.py script breaks because python -m pip fails in that environment.

PyTorch 2.6.0
Python 3.12.9
Mac M2
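Because collect_env.py could not run here, a minimal substitute is to print the relevant versions directly from Python. This is a sketch, not a replacement for collect_env.py; environment_summary is a hypothetical helper that uses only torch.__version__, torch.backends.mps, and the standard platform module, so it does not depend on pip.

```python
import platform

import torch


def environment_summary() -> dict:
    # Collect the details this report lists by hand, without invoking pip.
    return {
        "torch": torch.__version__,
        "python": platform.python_version(),
        "macos": platform.mac_ver()[0],  # empty string on non-macOS systems
        "machine": platform.machine(),   # e.g. "arm64" on Apple silicon
        "mps_built": torch.backends.mps.is_built(),
        "mps_available": torch.backends.mps.is_available(),
    }


for key, value in environment_summary().items():
    print(f"{key}: {value}")
```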

cc @kulinseth @albanD @malfet @DenisVieriu97 @jhavukainen

Isalia20 added the module: mps label Mar 23, 2025
malfet added the triaged, module: NaNs and Infs, and module: correctness (silent) labels Mar 23, 2025
pytorchbot pushed a commit that referenced this issue Apr 1, 2025
Fixes #149813

Pull Request resolved: #149866
Approved by: https://github.com/malfet

(cherry picked from commit ba46643)
malfet pushed a commit that referenced this issue Apr 1, 2025
[MPS] tril op not handling infs correctly (#149866)

Fixes #149813

Pull Request resolved: #149866
Approved by: https://github.com/malfet

(cherry picked from commit ba46643)

Co-authored-by: Isalia20 <irakli.salia854@gmail.com>
ZainRizvi (Contributor) commented Apr 4, 2025

Verified that the fix works on the current release-candidate build and that the bug still reproduces on torch==2.7.0.dev20250312:

(release2.7) ~/test/release2.7/.venv/lib/python3.12/site-packages/torch/lib python
Python 3.12.5 (main, Aug 14 2024, 04:32:18) [Clang 18.1.8 ] on darwin
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
~/test/release2.7/.venv/lib/python3.12/site-packages/torch/_subclasses/functional_tensor.py:276: UserWarning: Failed to initialize NumPy: No module named 'numpy' (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/utils/tensor_numpy.cpp:81.)
  cpu = _conversion_method_template(device=torch.device("cpu"))
>>> torch.tril(torch.full((3, 3), float("inf"), device="mps"), diagonal=-1)
tensor([[0., 0., 0.],
        [inf, 0., 0.],
        [inf, inf, 0.]], device='mps:0')
>>> torch.tril(torch.full((3, 3), 1.0, device="mps"), diagonal=-1)
tensor([[0., 0., 0.],
        [1., 0., 0.],
        [1., 1., 0.]], device='mps:0')
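A regression check along the lines of this verification can compare the accelerator result against the CPU result for every diagonal offset, using a matrix that mixes inf and -inf. tril_matches_cpu is a hypothetical helper and is not taken from #149866. torch.equal returns False if any NaN appears, since NaN never compares equal to anything.

```python
import torch


def tril_matches_cpu(device: str, size: int = 4) -> bool:
    # Rows alternate between +inf and -inf so every kept entry is non-finite.
    x_cpu = torch.full((size, size), float("inf"))
    x_cpu[1::2] = float("-inf")
    x_dev = x_cpu.to(device)
    for k in range(-size, size + 1):
        got = torch.tril(x_dev, diagonal=k).cpu()
        want = torch.tril(x_cpu, diagonal=k)
        # torch.equal is False if shapes or values differ, including any NaN.
        if not torch.equal(got, want):
            return False
    return True


if __name__ == "__main__":
    device = "mps" if torch.backends.mps.is_available() else "cpu"
    print(device, tril_matches_cpu(device))
```

On an affected build such as 2.6.0 with MPS, this check would be expected to return False for the negative offsets.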

amathewc pushed a commit to amathewc/pytorch that referenced this issue Apr 17, 2025