RuntimeError: svd_cuda: the updating process of SBDSDC did not converge (error: 1) · Issue #28293 · pytorch/pytorch · GitHub

RuntimeError: svd_cuda: the updating process of SBDSDC did not converge (error: 1) #28293


Closed
ryh95 opened this issue Oct 18, 2019 · 21 comments
Labels
module: cuda Related to torch.cuda, and CUDA support in general module: linear algebra Issues related to specialized linear algebra operations in PyTorch; includes matrix multiply matmul module: numerical-stability Problems related to numerical stability of operations triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@ryh95
ryh95 commented Oct 18, 2019

Hi,

When I run torch.svd() on a matrix on the GPU, it raises the error
RuntimeError: svd_cuda: the updating process of SBDSDC did not converge (error: 1)

However, torch.svd() has no problem with the same matrix on the CPU.

Can someone help me figure out the reason?
Thank you!

The matrix is attached
runtime_error_W.zip
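
For reference, a minimal repro sketch (the filename is hypothetical; it assumes the zip contains a tensor saved with torch.save):

    import torch

    W = torch.load("runtime_error_W.pt")  # hypothetical name for the attached matrix

    u, s, v = torch.svd(W)          # fine on the CPU
    u, s, v = torch.svd(W.cuda())   # raises: svd_cuda ... SBDSDC did not converge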

cc @vishwakftw @ssnl @jianyuh

@vishwakftw
Contributor
vishwakftw commented Oct 18, 2019

Hi @ryh95, thank you for opening the issue.

This seems to be an issue with MAGMA, possibly relating to the correctness of gesdd vs gesvd. Note: the matrix is well-conditioned, but has too many repeated singular values.
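
For anyone who wants to check this on their own matrix, a quick diagnostic sketch (assuming W is the matrix from the attached zip, run on the CPU where the SVD converges):

    u, s, v = torch.svd(W)            # singular values come back sorted, descending
    cond = s.max() / s.min()          # condition number sigma_max / sigma_min
    min_gap = (s[:-1] - s[1:]).min()  # smallest gap between adjacent singular values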

@vishwakftw vishwakftw added module: linear algebra Issues related to specialized linear algebra operations in PyTorch; includes matrix multiply matmul module: numerical-stability Problems related to numerical stability of operations labels Oct 18, 2019
@ryh95
Author
ryh95 commented Oct 19, 2019

It seems that gesdd is used for both CPU and GPU versions of torch.svd(), according to the documentation.

The implementation of SVD on CPU uses the LAPACK routine ?gesdd (a divide-and-conquer algorithm) instead of ?gesvd for speed. Analogously, the SVD on GPU uses the MAGMA routine gesdd as well.

So, why does this relate to the correctness of gesdd vs gesvd?

@pietern pietern added module: cuda Related to torch.cuda, and CUDA support in general triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module labels Oct 22, 2019
@Hippogriff

Is there a temporary solution/work-around to this problem?

@ryh95
Author
ryh95 commented Nov 26, 2019

I just move the tensor to the CPU and then use torch.svd to get the decomposition, moving the results back to the GPU if needed.

Someone has mentioned that gesvd is more accurate/robust (#25978 (comment)), so you can also move the tensor to the CPU, convert it to a NumPy array, and use SciPy if you prefer a more accurate result.
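
A minimal sketch of that CPU-fallback workaround (assuming W is a CUDA tensor):

    # Compute the SVD on the CPU, then move the factors back to W's device.
    u, s, v = torch.svd(W.cpu())
    u, s, v = u.to(W.device), s.to(W.device), v.to(W.device)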

@Hippogriff

@ryh95 I also want to backpropagate gradients through this computation. Converting it to CPU might break the computation graph (though I am not sure about this).

@Nav94
Nav94 commented Mar 25, 2020

@Hippogriff Did you manage to work around the issue? I am also running into the same issue and was thinking of computing on the CPU and moving it back to the GPU.

@Hippogriff

@Nav94 This can happen either when the matrix is severely ill-conditioned or when the singular values are very close or equal to each other.
There are the following solutions to the problem:

  1. For the ill-conditioned case, you can compute the condition number of the matrix on the CPU; if the condition number is very large, there is not much you can do. In this case, you can simply trivialize the solution.
  2. If the singular values are close to each other, you need to safeguard your backprop, that is, you need to write a new backward pass. You can use the custom_svd function below, which replaces torch's svd function.

import torch
from torch.autograd import Function


def compute_grad_V(U, S, V, grad_V):
    N = S.shape[0]
    K = svd_grad_K(S)
    # Build diag(S) on the same device as S.
    S = torch.eye(N).cuda(S.get_device()) * S.reshape((N, 1))
    inner = K.T * (V.T @ grad_V)
    inner = (inner + inner.T) / 2.0
    return 2 * U @ S @ inner @ V.T


def svd_grad_K(S):
    N = S.shape[0]
    s1 = S.view((1, N))
    s2 = S.view((N, 1))
    diff = s2 - s1
    plus = s2 + s1

    # TODO Look into it
    eps = torch.ones((N, N)) * 10 ** (-6)
    eps = eps.cuda(S.get_device())
    # Clamp the pairwise differences away from zero so 1/diff stays finite
    # even when two singular values coincide.
    max_diff = torch.max(torch.abs(diff), eps)
    sign_diff = torch.sign(diff)

    K_neg = sign_diff * max_diff

    # guard the matrix inversion
    K_neg[torch.arange(N), torch.arange(N)] = 10 ** (-6)
    K_neg = 1 / K_neg
    K_pos = 1 / plus

    ones = torch.ones((N, N)).cuda(S.get_device())
    rm_diag = ones - torch.eye(N).cuda(S.get_device())
    K = K_neg * K_pos * rm_diag
    return K


class CustomSVD(Function):
    """
    Custom SVD to deal with situations where the singular values are
    equal. In that case, if handled naively, the gradient w.r.t. the
    input goes to inf. To deal with this, we replace the entries of the
    K matrix from eq. 13 in https://arxiv.org/pdf/1509.07838.pdf with a
    high value.
    Note: only applicable to tall and square matrices; it does not give
    correct gradients for a fat matrix. Transposing the original matrix
    may be required to handle that case. Left for future work.
    """
    @staticmethod
    def forward(ctx, input):
        # Note: input is a matrix of size m x n with m >= n.
        # Note: if the above assumption is violated, the gradients
        # will be wrong.
        try:
            U, S, V = torch.svd(input, some=True)
        except RuntimeError:  # e.g. the convergence error above
            import ipdb; ipdb.set_trace()

        ctx.save_for_backward(U, S, V)
        return U, S, V

    @staticmethod
    def backward(ctx, grad_U, grad_S, grad_V):
        # NOTE: grad_U and grad_S are ignored; only the gradient flowing
        # through V is propagated back to the input.
        U, S, V = ctx.saved_tensors
        grad_input = compute_grad_V(U, S, V, grad_V)
        return grad_input


customsvd = CustomSVD.apply
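
For reference, a usage sketch (hypothetical tall input A; only the gradient flowing through V reaches the input):

    A = torch.randn(64, 32, device="cuda", requires_grad=True)  # tall: m >= n
    U, S, V = customsvd(A)
    loss = V.sum()
    loss.backward()  # gradient reaches A via compute_grad_V only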

@ryh95
Author
ryh95 commented Mar 28, 2020

Thanks for your solution! @Hippogriff
By the way, is it true that converting the tensor to the CPU will break backpropagation?

@Hippogriff

Yes, as far as I know.

@Abdelpakey

@Hippogriff
Thanks for the solution, is there any way to track your code in case you update it for the #TODO part?

@Hippogriff

@Abdelpakey I am not keeping track of this yet. There are several todos; for example, this code doesn't compute the gradient with respect to the right and left singular vectors. I will update this repo when I am done:
https://github.com/Hippogriff/SVD-Pytorch

@chenhao1umbc
chenhao1umbc commented Apr 14, 2020

I am using this, since it is not solved

    try:
        u, s, v = torch.svd(L)
    except:                     # torch.svd may have convergence issues for GPU and CPU.
        u, s, v = torch.svd(L + 1e-4*L.mean()*torch.rand(l, h))

@SebastianGrans

@chenhao1umbc Thanks for that snippet!

Is there a particular reason why you multiply by L.mean()?

@chenhao1umbc

Yes, the main idea is that the convergence issue can be solved by adding some perturbation. But the scale of L is unknown, which means L could be at 1e4 scale or 1e-4 scale. This means we cannot simply add a "small" absolute random number, but rather a relatively small one.
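
A sketch of that scale-relative jitter, using torch.rand_like so the noise matches L's shape and device; taking L.abs().mean() as the scale is my assumption (L.mean() itself can vanish for zero-centered matrices):

    jitter = 1e-4 * L.abs().mean() * torch.rand_like(L)  # relatively small at any scale of L
    u, s, v = torch.svd(L + jitter)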

@davidbau

This affects torch.pinverse also, due to its underlying svd.
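
For context, a sketch of that relationship, since torch.pinverse is computed from the SVD (simplified; a real pseudoinverse also zeroes singular values below a rank tolerance):

    A = torch.randn(64, 64, device="cuda")
    u, s, v = torch.svd(A)                   # the step that can fail to converge
    pinv = v @ torch.diag(1.0 / s) @ u.t()   # approx. torch.pinverse(A) for full-rank A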

@andreaskoepf
Collaborator
andreaskoepf commented May 29, 2021

For pytorch 1.8.1+cu101 the output of:

x = torch.randn(64, 64).cuda()
x[0,0] = float('nan')
torch.svd(x)

is:

RuntimeError: svd_cuda: For batch 0: U(65,65) is zero, singular U.

The error message is unexpected (and misleading) and comes from a special batchCheckErrors() info-tensor function overload that does not check for "svd" in the name string; see:

static inline void batchCheckErrors(const Tensor& infos, const char* name, bool allow_singular=false, int info_per_batch=1) {
  auto batch_size = infos.numel();
  auto infos_cpu = infos.to(at::kCPU);
  auto infos_data = infos_cpu.data_ptr<int>();
  for (int64_t i = 0; i < batch_size; i++) {
    auto info = infos_data[i];
    if (info < 0) {
      AT_ERROR(name, ": For batch ", i/info_per_batch, ": Argument ", -info, " has illegal value");
    } else if (!allow_singular && info > 0) {
      AT_ERROR(name, ": For batch ", i/info_per_batch, ": U(", info, ",", info, ") is zero, singular U.");
    }
  }
}

If you look above and below, both batchCheckErrors(std::vector<int64_t>& infos, ...) and void singleCheckErrors() have the case:

if (strstr(name, "svd")) {
     AT_ERROR(name, ": the updating process of SBDSDC did not converge (error: ", info, ")");
}

Earlier versions of pytorch raised "RuntimeError: svd_cuda: the updating process of SBDSDC did not converge (error: 23)" for the same code.

I guess for cuda tensors torch.svd() calls the tensor-info batchCheckErrors() overload from _svd_helper_cuda_lib():

std::tuple<Tensor, Tensor, Tensor> _svd_helper_cuda_lib(const Tensor& self, bool some, bool compute_uv) {
  const int64_t batch_size = batchCount(self);
  at::Tensor infos = at::zeros({batch_size}, self.options().dtype(at::kInt));
  const int64_t m = self.size(-2);
  const int64_t n = self.size(-1);
  const int64_t k = std::min(m, n);
  Tensor U_working_copy, S_working_copy, VT_working_copy;
  std::tie(U_working_copy, S_working_copy, VT_working_copy) = \
      _create_U_S_VT(self, some, compute_uv, /* svd_use_cusolver = */ true);
  // U, S, V working copies are already column majored now

  // heuristic for using `gesvdjBatched` over `gesvdj`
  if (m <= 32 && n <= 32 && batch_size > 1 && (!some || m == n)) {
    apply_svd_lib_gesvdjBatched(self, U_working_copy, S_working_copy, VT_working_copy, infos, compute_uv);
  } else {
    apply_svd_lib_gesvdj(self, U_working_copy, S_working_copy, VT_working_copy, infos, compute_uv, some);
  }

  // A device-host sync will be performed.
  batchCheckErrors(infos, "svd_cuda");

  if (!compute_uv) {
    VT_working_copy.zero_();
    U_working_copy.zero_();
  }
  if (some) {
    VT_working_copy = VT_working_copy.narrow(-2, 0, k);
  }

  // so far we have computed VT, but torch.svd returns V instead. Adjust accordingly.
  VT_working_copy.transpose_(-2, -1);
  return std::make_tuple(U_working_copy, S_working_copy, VT_working_copy);
}

Btw, if somebody wonders why the output is "U(65,65) is zero, singular U." for an SVD of a 64x64 matrix: it is simply a misinterpretation of the info feedback of cusolverDnSgesvdj() (see the CUDA docs). The docs say: "if info = 0, the operation is successful. if info = -i, the i-th parameter is wrong (not counting handle). if info = min(m,n)+1, gesvdj does not converge under given tolerance and maximum sweeps." So here info = min(64,64)+1 = 65 signals non-convergence, and the generic check misreads it as "U(65,65) is zero".

@IvanYashchuk
Collaborator

Thank you, @andreaskoepf, for reporting the issue of incorrect error messages! We will have it fixed in a future PyTorch release. Unfortunately, the bugfixes will not be backported to older versions.

facebook-github-bot pushed a commit that referenced this issue Oct 26, 2021
…ance (#64533)

Summary:
Fix #64237
Fix #28293
Fix #4689

See also #47953

cc ngimel jianyuh nikitaved pearu mruberry walterddr IvanYashchuk xwang233 Lezcano

Pull Request resolved: #64533

Reviewed By: albanD

Differential Revision: D31915794

Pulled By: ngimel

fbshipit-source-id: 29ea48696531ced8a48474e891a9e2d5f11e9d7a
@lhyfst
lhyfst commented Dec 11, 2021

I am using this, since it is not solved

    try:
        u, s, v = torch.svd(L)
    except:                     # torch.svd may have convergence issues for GPU and CPU.
        u, s, v = torch.svd(L + 1e-4*L.mean()*torch.rand(l, h))

What is h in torch.rand(l, h)?

@andreaskoepf
Collaborator

What is h in torch.rand(l, h)?

To add some noise you could use torch.rand_like(L).

For support & coding questions please refer to the PyTorch Forums & see the docs, e.g. torch.rand.

@lhyfst
lhyfst commented Dec 11, 2021

What is h in torch.rand(l, h)?

To add some noise you could use torch.rand_like(L).

For support & coding questions please refer to the PyTorch Forums & see the docs, e.g. torch.rand.

Thank you

@a-r-j
a-r-j commented Dec 9, 2022

I had some luck with:

from tenacity import retry, stop_after_attempt


@retry(stop=stop_after_attempt(32)) # Or some other value
def func_with_svd(L: torch.Tensor):
    try:
        u, s, v = torch.svd(L)
    except RuntimeError:        # torch.svd may have convergence issues on GPU and CPU.
        u, s, v = torch.svd(L + 1e-4*L.mean()*torch.rand_like(L))
    ...
