Status: Open
Labels: module: NaNs and Infs; module: linear algebra; triaged
Description
🐛 Bug
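torch.linalg.cholesky and torch.linalg.cholesky_ex handle inputs that contain NaN values differently on CPU and CUDA devices. On CPU, cholesky raises a positive-definiteness error and cholesky_ex returns info = 1. On CUDA, cholesky returns a NaN-filled factor without raising, and cholesky_ex returns info = 0, i.e. it reports success.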
To Reproduce
```python
>>> import torch
>>> x = torch.full((2, 2), float("nan"))
>>> torch.linalg.cholesky(x)
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-2-7e5e41995557> in <module>
----> 1 torch.linalg.cholesky(x)

RuntimeError: torch.linalg.cholesky: The factorization could not be completed because the input is not positive-definite (the leading minor of order 1 is not positive-definite).
>>> torch.linalg.cholesky(x.to(torch.device("cuda")))
tensor([[nan, 0.],
        [nan, nan]], device='cuda:0')
>>> torch.linalg.cholesky_ex(x)
torch.return_types.linalg_cholesky_ex(
L=tensor([[nan, 0.],
        [nan, nan]]),
info=tensor(1, dtype=torch.int32))
>>> torch.linalg.cholesky_ex(x.to(torch.device("cuda")))
torch.return_types.linalg_cholesky_ex(
L=tensor([[nan, 0.],
        [nan, nan]], device='cuda:0'),
info=tensor(0, device='cuda:0', dtype=torch.int32))
```
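Until the backends agree, a caller of cholesky_ex can derive a device-independent failure flag by also checking the input for non-finite values. This is a minimal sketch under that assumption: cholesky_ex_consistent is a hypothetical helper, not a PyTorch API, and it relies only on torch.linalg.cholesky_ex, torch.isfinite, and standard tensor reductions. It leaves PyTorch's own info untouched and returns a separate boolean mask with one entry per matrix in the batch.

```python
import torch


def cholesky_ex_consistent(A: torch.Tensor):
    """Hypothetical wrapper around torch.linalg.cholesky_ex.

    Returns (L, info, failed). L and info come straight from PyTorch; failed is
    a bool tensor with one entry per matrix that is True when the backend
    reported a non-positive-definite minor OR the input had any NaN/Inf entry.
    """
    L, info = torch.linalg.cholesky_ex(A)
    # Collapse the last two dims so .all(-1) yields one flag per matrix.
    finite = torch.isfinite(A).flatten(-2).all(-1)
    # A CUDA info of 0 on NaN input no longer counts as success.
    failed = (info != 0) | ~finite
    return L, info, failed


# Usage (CPU here; the flag is computed the same way on CUDA).
x = torch.full((2, 2), float("nan"))
print(cholesky_ex_consistent(x)[2])  # tensor(True)
```

Returning the flag as a tensor rather than a Python bool keeps the check asynchronous on CUDA; the cost is one extra elementwise pass over A.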
Expected behavior
The behavior of both torch.linalg.cholesky and torch.linalg.cholesky_ex should be consistent between CPU and CUDA devices. Ideally, both would raise a NaN-specific error here rather than a positive-definiteness error, since the latter is confusing, but that is somewhat orthogonal to this issue.
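The NaN-specific error suggested above can be approximated in user code by validating the input before dispatching to either backend. A minimal sketch, assuming a finite-value check is an acceptable precondition: cholesky_checked and NonFiniteInputError are hypothetical names, not part of torch.linalg. Because the check runs before the factorization, NaN (or Inf) inputs fail the same way on CPU and CUDA, while finite but non-positive-definite inputs still get PyTorch's usual error.

```python
import torch


class NonFiniteInputError(ValueError):
    """Hypothetical error type for Cholesky inputs containing NaN or +/-Inf."""


def cholesky_checked(A: torch.Tensor) -> torch.Tensor:
    """Hypothetical wrapper: fail on non-finite input before either backend runs."""
    # Reducing to a Python bool makes the outcome independent of the device;
    # on CUDA it also forces a synchronization.
    if not torch.isfinite(A).all():
        raise NonFiniteInputError(
            "cholesky_checked: input contains NaN or Inf values"
        )
    # Finite input: PyTorch raises its usual error if A is not positive-definite.
    return torch.linalg.cholesky(A)


# Usage: the all-NaN input now fails with a NaN-specific error.
try:
    cholesky_checked(torch.full((2, 2), float("nan")))
except NonFiniteInputError as err:
    print(err)  # cholesky_checked: input contains NaN or Inf values
```

Raising eagerly requires that synchronization on CUDA; returning a mask instead, as cholesky_ex does with info, avoids it at the price of the caller checking the result.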
Environment
This occurs both on PyTorch 1.9.0 (+cu102) and on PyTorch master (FB-internal Linux setup).
Additional context
This was initially raised in cornellius-gp/gpytorch#1747
cc @jianyuh @nikitaved @pearu @mruberry @walterddr @IvanYashchuk @xwang233 @lezcano