Inconsistent NaN handling by cholesky between CPU and CUDA · Issue #64818 · pytorch/pytorch
@Balandat

🐛 Bug

torch.linalg.cholesky and torch.linalg.cholesky_ex handle inputs that contain NaN values differently on the CPU and on CUDA devices.

To Reproduce

>> import torch
>> x = torch.full((2, 2), float("nan"))

>> torch.linalg.cholesky(x)

torch.linalg.cholesky: The factorization could not be completed because the input is not positive-definite (the leading minor of order 1 is not positive-definite).
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-2-7e5e41995557> in <module>
----> 1 torch.linalg.cholesky(x)
RuntimeError: torch.linalg.cholesky: The factorization could not be completed because the input is not positive-definite (the leading minor of order 1 is not positive-definite).

>> torch.linalg.cholesky(x.to(torch.device("cuda")))

tensor([[nan, 0.],
        [nan, nan]], device='cuda:0')

>> torch.linalg.cholesky_ex(x)

torch.return_types.linalg_cholesky_ex(
L=tensor([[nan, 0.],
        [nan, nan]]),
info=tensor(1, dtype=torch.int32))

>> torch.linalg.cholesky_ex(x.to(torch.device("cuda")))

torch.return_types.linalg_cholesky_ex(
L=tensor([[nan, 0.],
        [nan, nan]], device='cuda:0'),
info=tensor(0, device='cuda:0', dtype=torch.int32))

Expected behavior

The behavior of both torch.linalg.cholesky and torch.linalg.cholesky_ex should be consistent between CPU and CUDA devices. Ideally, both would raise a NaN-specific error here rather than a "not positive-definite" error, which is confusing for NaN input, but that is somewhat orthogonal to this issue.
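
Until the two backends agree, one possible user-side workaround is to check for NaNs explicitly before factorizing, so that the outcome does not depend on the device. The following is only a minimal sketch: `checked_cholesky` is a hypothetical helper name, not part of PyTorch, and the choice of `ValueError` for NaN input is an assumption made for illustration.

```python
import torch


def checked_cholesky(A: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: Cholesky factorization that rejects NaN input on any device."""
    # Check for NaNs explicitly so the result does not depend on the backend.
    if torch.isnan(A).any():
        raise ValueError("checked_cholesky: input contains NaN values")
    # cholesky_ex reports failure through `info` instead of raising.
    L, info = torch.linalg.cholesky_ex(A)
    if (info != 0).any():
        raise RuntimeError("checked_cholesky: input is not positive-definite")
    return L
```

With the `x` from the reproduction above, `checked_cholesky(x)` and `checked_cholesky(x.to(torch.device("cuda")))` would both raise the same `ValueError`. Note that the explicit NaN check forces a device synchronization on CUDA, which may matter in performance-sensitive code.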

Environment

This occurs on both PyTorch 1.9.0 (+cu102) and PyTorch master (FB-internal Linux setup).

Additional context

This was initially raised in cornellius-gp/gpytorch#1747

cc @jianyuh @nikitaved @pearu @mruberry @walterddr @IvanYashchuk @xwang233 @lezcano

Labels

module: NaNs and Infs (problems related to NaN and Inf handling in floating point)
module: linear algebra (issues related to specialized linear algebra operations in PyTorch; includes matrix multiply matmul)
triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
