torch.linalg.eigh fails on CPU · Issue #145801 · pytorch/pytorch · GitHub

Closed
atalman opened this issue Jan 28, 2025 · 7 comments
Labels: module: linear algebra, module: regression, module: third_party, triaged

atalman commented Jan 28, 2025

🐛 Describe the bug

Following up on issue #94772: torch.linalg.eigh has been failing on CPU since the PyTorch 2.4.0 release.

Minimal test (requires fc_layer_tensor.pt.zip):

import torch

t = torch.load('fc_layer_tensor.pt', weights_only=True, map_location='cpu').flatten()
torch.linalg.eigh(torch.outer(t, t))

Output:

python3 test5.py 
/home/ubuntu/test5.py:12: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  fc_layer.weight.grad = loaded_tensor = torch.load('fc_layer_tensor.pt')

Intel oneMKL ERROR: Parameter 8 was incorrect on entry to SSYEVD.
Traceback (most recent call last):
  File "/home/ubuntu/test5.py", line 15, in <module>
    evals_adagrad, evecs_adagrad = torch.linalg.eigh(precond_adagrad.cpu())
RuntimeError: false INTERNAL ASSERT FAILED at "../aten/src/ATen/native/BatchLinearAlgebra.cpp":1538, please report a bug to PyTorch. linalg.eigh: Argument 8 has illegal value. Most certainly there is a bug in the implementation calling the backend library.

Full test

import torch
from torchvision import datasets, transforms

SEED = 123
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

batch_size = 512
num_classes = 10
num_features = 28**2
loss_fn = torch.nn.CrossEntropyLoss()

tforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
dataset = datasets.MNIST("~/data/", download=False, train=True, transform=tforms)
train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)

fc_layer = torch.nn.Linear(in_features=num_features, out_features=num_classes, bias=False).to(DEVICE)

for batch_ix, (inputs, targets) in enumerate(train_loader):

    inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)

    fc_layer.weight.grad = None
    logits = fc_layer(inputs.view(inputs.shape[0], -1))
    loss = loss_fn(logits, targets)
    loss.backward()

    vec_grad = torch.flatten(fc_layer.weight.grad)
    precond_adagrad = torch.outer(vec_grad, vec_grad)

    # CPU computation works fine
    evals_adagrad, evecs_adagrad = torch.linalg.eigh(precond_adagrad.cpu())

    # But eigh computation on GPU fails
    evals_adagrad, evecs_adagrad = torch.linalg.eigh(precond_adagrad)

Versions

2.7.0

cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @malfet @jianyuh @nikitaved @pearu @mruberry @walterddr @xwang233 @lezcano

atalman added this to the 2.6.1 milestone Jan 28, 2025
malfet added the module: regression, module: linear algebra, and high priority labels Jan 28, 2025
malfet added the module: error checking label Jan 28, 2025
malfet commented Jan 28, 2025

It might be an error checking issue, as the input tensor dimensions are way too high...

nikitaved commented Jan 28, 2025

The dimension of 8k by 8k should be fine. Could this be related to the fact that eigh is run on a rank-1 matrix? Although the complaint is about lwork, and even with this shape int32 should not overflow... @atalman, is the use case just the low-rank projection operators?

malfet commented Jan 31, 2025

Simplified the example (and also realized one doesn't really need any specific weights)

% python3 -c "import torch;x=torch.rand(7840);y=torch.linalg.eigh(torch.outer(x, x))"

malfet commented Jan 31, 2025

The smallest size at which it starts to fail is 2895 (on my Mac):

% python3 -c "import torch;x=torch.rand(2894);y=torch.linalg.eigh(torch.outer(x, x))"; echo $?
0
% python3 -c "import torch;x=torch.rand(2895);y=torch.linalg.eigh(torch.outer(x, x))"; echo $?
** On entry to SSYEVD, parameter number  8 had an illegal value
Traceback (most recent call last):
  File "<string>", line 1, in <module>
RuntimeError: false INTERNAL ASSERT FAILED at "/Users/nshulga/git/pytorch/pytorch/aten/src/ATen/native/BatchLinearAlgebra.cpp":1604, please report a bug to PyTorch. linalg.eigh: Argument 8 has illegal value. Most certainly there is a bug in the implementation calling the backend library.
1

malfet commented Feb 1, 2025

Error is coming from https://github.com/Reference-LAPACK/lapack/blob/a00531096fff76e49bfd86260885c32070b1afcd/SRC/ssyevd.f#L258

         IF( LWORK.LT.LWMIN .AND. .NOT.LQUERY ) THEN
            INFO = -8

malfet commented Feb 1, 2025

Hmm, it looks like when one uses a workspace query to get lwork, one needs to round the result up by one ULP, because it is returned as a floating-point value.

lapackSyevd<scalar_t, value_t>(jobz, uplo, n, vectors_data, lda, values_data,
&lwork_query, lwork, &rwork_query, lrwork, &iwork_query, liwork, infos_data);
lwork = std::max<int>(1, real_impl<scalar_t, value_t>(lwork_query));

I.e. when the query value is computed as an integer but then returned as a float, it needs to be rounded up by one ULP.

Adding debug prints shows:

% python3 -c "import torch;print(torch.linalg.eigh(torch.rand(2894, 2894)))"
lapackSyevd: jobz=V, uplo=L, n=2894, lwork=-1, liwork=-1
lwork_query returned 1.67678e+07
lapackSyevd: jobz=V, uplo=L, n=2894, lwork=16767837, liwork=14473
% python3 -c "import torch;print(torch.linalg.eigh(torch.rand(2895, 2895)))"
lapackSyevd: jobz=V, uplo=L, n=2895, lwork=-1, liwork=-1
lwork_query returned 1.67794e+07
lapackSyevd: jobz=V, uplo=L, n=2895, lwork=16779420, liwork=14478
** On entry to SSYEVD, parameter number  8 had an illegal value
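
The two lwork values in the debug output can be reproduced without LAPACK. The sketch below is illustrative only: it assumes the query reports exactly the documented JOBZ = 'V' minimum, 1 + 6*N + 2*N**2, and that the value passes through a float32 with round-to-nearest. The helper names `to_float32` and `lwmin_syevd` are hypothetical and are not PyTorch or LAPACK APIs.

```python
import struct


def to_float32(value: int) -> float:
    """Round an integer to the nearest float32 (returned as a Python float)."""
    return struct.unpack("<f", struct.pack("<f", float(value)))[0]


def lwmin_syevd(n: int) -> int:
    """Documented minimum LWORK for SSYEVD with JOBZ = 'V' and N > 1."""
    return 1 + 6 * n + 2 * n * n


for n in (2894, 2895):
    required = lwmin_syevd(n)
    # Assumption: the query result is stored in WORK(1) as float32,
    # then truncated back to int by the caller.
    recovered = int(to_float32(required))
    print(n, required, recovered, recovered >= required)
# 2894 16767837 16767837 True
# 2895 16779421 16779420 False
```

For N = 2895 the required size is odd and larger than 2**24 = 16777216, where consecutive float32 values are 2 apart, so it cannot be stored exactly and here rounds down by one.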

malfet self-assigned this Feb 1, 2025
malfet added the module: third_party and triaged labels and removed the module: error checking, high priority, and triage review labels Feb 3, 2025
malfet commented Feb 3, 2025

From SSYEVD documentation:

     WORK is REAL array, dimension (MAX(1,LWORK))
     On exit, if INFO = 0, WORK(1) returns the optimal LWORK.

     LWORK is INTEGER
     The dimension of the array WORK.
     If N <= 1,               LWORK must be at least 1.
     If JOBZ = 'N' and N > 1, LWORK must be at least 2*N+1.
     If JOBZ = 'V' and N > 1, LWORK must be at least
                                           1 + 6*N + 2*N**2.

     If LWORK = -1, then a workspace query is assumed; the routine
     only calculates the optimal sizes of the WORK and IWORK
     arrays, returns these values as the first entries of the WORK
     and IWORK arrays, and no error message related to LWORK or
     LIWORK is issued by XERBLA.

I.e. in query mode the LWORK size is returned as a floating-point value, which, when cast back to an integer, is sometimes smaller than the size actually required, and that results in error code -8.

Will file a bug against OpenBLAS, but in the meantime will work around the issue by allocating a larger array.
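
The explanation above also accounts for the exact threshold reported earlier in the thread. As a rough check, the following sketch scans matrix sizes and applies the quoted `LWORK.LT.LWMIN` test to the value recovered from a float32 round trip. It assumes the backend reports exactly the documented minimum for JOBZ = 'V' and rounds to nearest; real backends may behave differently. `syevd_lwork_info` and `first_failing_n` are hypothetical names used only for illustration.

```python
import struct


def to_float32(value: int) -> float:
    """Round an integer to the nearest float32 (returned as a Python float)."""
    return struct.unpack("<f", struct.pack("<f", float(value)))[0]


def lwmin_syevd(n: int) -> int:
    """Documented minimum LWORK for SSYEVD with JOBZ = 'V' and N > 1."""
    return 1 + 6 * n + 2 * n * n


def syevd_lwork_info(lwork: int, n: int, lquery: bool = False) -> int:
    """Mirror of the quoted check: INFO = -8 if LWORK < LWMIN outside query mode."""
    return -8 if (lwork < lwmin_syevd(n) and not lquery) else 0


def first_failing_n(limit: int = 5000):
    """Smallest n whose float32 round-tripped workspace size is rejected."""
    for n in range(2, limit):
        lwork = int(to_float32(lwmin_syevd(n)))
        if syevd_lwork_info(lwork, n) != 0:
            return n
    return None


print(first_failing_n())  # prints 2895 under the assumptions above
```

Every size up to 2894 needs fewer than 2**24 elements, so the float32 value is exact; 2895 is the first size that crosses that boundary.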

Raymo111 pushed a commit that referenced this issue Feb 20, 2025
Work-query APIs return floating-point values that could lose precision when converted back to int. Solve this by using `nextafter` and `ceil`.
Add regression test

Fixes #145801

Pull Request resolved: #146456
Approved by: https://github.com/malfet
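
The `nextafter` plus `ceil` idea from the commit message can be illustrated in Python. This is a sketch of the rounding strategy only, not the actual C++ patch: `next_float32_up` emulates a float32 `nextafter` toward +infinity by incrementing the bit pattern, and `safe_lwork` is a hypothetical helper name.

```python
import math
import struct


def to_float32(value: int) -> float:
    """Round an integer to the nearest float32 (returned as a Python float)."""
    return struct.unpack("<f", struct.pack("<f", float(value)))[0]


def next_float32_up(x: float) -> float:
    """Next float32 above x; valid for non-negative finite x below the float32 maximum."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return struct.unpack("<f", struct.pack("<I", bits + 1))[0]


def safe_lwork(query_value: float) -> int:
    """Bump the queried size by one float32 step, then round up to an integer."""
    return max(1, math.ceil(next_float32_up(query_value)))


required = 1 + 6 * 2895 + 2 * 2895**2   # 16779421, documented minimum for N = 2895
query = to_float32(required)            # what a float32 WORK(1) can hold
print(int(query), safe_lwork(query))    # 16779420 16779422
```

Round-to-nearest loses at most half a float32 step, so adding a full step always covers the true size, at the cost of a slightly larger workspace.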
majing921201 pushed a commit to majing921201/pytorch that referenced this issue Mar 4, 2025