Add relative and absolute tolerances for matrix_rank, pinv #54151


Closed
IvanYashchuk opened this issue Mar 17, 2021 · 14 comments
Labels

  • enhancement: Not as big of a feature, but technically not a bug. Should be easy to fix
  • module: linear algebra: Issues related to specialized linear algebra operations in PyTorch; includes matrix multiply (matmul)
  • module: numpy: Related to numpy support, and also numpy compatibility of our operators
  • triaged: This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

IvanYashchuk (Collaborator) commented Mar 17, 2021

🚀 Feature

Both torch.linalg.matrix_rank and torch.linalg.pinv calculate the singular values of the provided matrix and truncate them based on a specified tolerance (the argument is called rcond for torch.linalg.pinv and tol for torch.linalg.matrix_rank).
The currently implemented behavior for setting the tolerance, and the default tolerance values, follow NumPy.
However, NumPy is inconsistent both in its default values and in whether it treats the provided tolerance as relative or absolute.

| | default for the argument | default tolerance | truncation criteria (if default) | truncation criteria (if specified) |
| --- | --- | --- | --- | --- |
| matrix_rank | None | eps * max(rows, cols) | tol * max(singular_values) | tol |
| pinv | 1e-15 | 1e-15 | tol * max(singular_values) | tol * max(singular_values) |
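
To make the inconsistency concrete, here is a small NumPy sketch (the matrix is chosen purely for illustration): the same numeric tolerance acts as an absolute cutoff in matrix_rank but as a relative one in pinv.

import numpy as np

a = np.diag([1000.0, 0.5])  # singular values: 1000 and 0.5

# matrix_rank treats a user-specified tol as absolute: 0.5 > 1e-3, so the rank is 2
print(np.linalg.matrix_rank(a, tol=1e-3))   # 2

# pinv treats rcond as relative: cutoff = 1e-3 * 1000 = 1, so 0.5 is zeroed out
print(np.linalg.pinv(a, rcond=1e-3)[1, 1])  # 0.0 instead of 1 / 0.5 = 2.0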

The proposal is to implement a unified way to specify the absolute or relative tolerances for the truncation of singular values, as follows:

def matrix_rank_or_pinv(input, *, atol=0, rtol=default_rtol):
    ...
    singular_values = ...  # compute singular values of input
    truncation_criteria = max(atol, rtol * max(singular_values))
    keep_mask = singular_values > truncation_criteria  # boolean mask of singular values to keep
    ...

Possible choices of default_rtol:

  • NumPy uses eps * max(rows, cols) for matrix_rank and 1e-15 for pinv
  • TensorFlow uses the same default as NumPy for matrix_rank but 10 * eps * max(rows, cols) for pinv
  • JAX uses the same defaults as TensorFlow
  • Julia uses eps * min(rows, cols) both for pinv and matrix_rank

Using max(atol, rtol * ...) as the truncation criterion follows math.isclose.
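
For concreteness, a minimal runnable sketch of the proposed semantics for matrix_rank; rank_with_tolerances is a hypothetical name, and torch.linalg.svdvals is assumed for computing the singular values:

import torch

def rank_with_tolerances(input, *, atol=0.0, rtol=None):
    s = torch.linalg.svdvals(input)  # singular values, in descending order
    if rtol is None:
        # default relative tolerance: machine eps scaled by the matrix size
        rtol = torch.finfo(input.dtype).eps * max(input.shape[-2], input.shape[-1])
    cutoff = max(atol, rtol * s.max().item())
    return int((s > cutoff).sum())  # number of singular values kept

print(rank_with_tolerances(torch.eye(3), atol=0.5))  # 3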

Backwards compatibility / NumPy compatibility:

def matrix_rank(input, tol=None):
    if tol is None:
        return matrix_rank(input, atol=0, rtol=eps * max(rows, cols))
    else:
        return matrix_rank(input, atol=tol, rtol=0)

def pinv(input, rcond=1e-15):
    return pinv(input, atol=0, rtol=rcond)

cc @jianyuh @nikitaved @pearu @mruberry @heitorschueroff @walterddr @IvanYashchuk @rgommers

IvanYashchuk added the module: numpy and module: linear algebra labels Mar 17, 2021
vadimkantorov (Contributor) commented

Things are probably going to be worse for fp16 or other short-bitwidth types (#35666). So at least the thresholds/epsilons should be configurable...

IvanYashchuk changed the title from "Add relative and absolute tolerances for matrix_rank, pinv" to "Add relative and absolute tolerances for matrix_rank, pinv, lstsq" Mar 17, 2021
IvanYashchuk changed the title back to "Add relative and absolute tolerances for matrix_rank, pinv" Mar 17, 2021
agolynski added the enhancement and triaged labels Mar 18, 2021
mruberry (Collaborator) commented

Just to make sure I understand this issue:

  • NumPy's matrix_rank() and pinv() have differently named kwargs that control the value below which singular values are set to zero (my understanding is they are not "truncated" but set to zero)
  • the behavior of these kwargs is wacky
  • it would make more sense to use atol, which if set would specify a number to test against, and rtol, which if set would specify a number to be computed using the maximal singular value
  • if both atol and rtol are set, then any singular value that fails either criterion would be set to zero
  • the historic rcond and tol kwargs could be supported just as they are, but we could encourage users to use atol and rtol instead

Does that sound right, @IvanYashchuk?

IvanYashchuk (Collaborator, Author) commented

That's a great summary! All the points are correct.

vadimkantorov (Contributor) commented

If the default cut-offs are for fp32 and do not adapt automagically to fp16 inputs, then the docs should probably suggest cut-offs for fp16 inputs and clearly indicate that the user is recommended to call the function with the cut-off corresponding to the dtype.

IvanYashchuk (Collaborator, Author) commented

fp16 inputs are not supported on either CPU or CUDA, and I doubt they will be in the near future. Even if they were supported, it wouldn't be an issue: eps in the first message stands for "machine precision"; the actual value for a given dtype can be obtained with torch.finfo(torch.float16).eps, so the default relative tolerance would adapt to the dtype of the input.
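
For illustration, a minimal sketch of how the machine eps, and hence a default rtol of eps * max(rows, cols), would adapt to the dtype:

import torch

for dtype in (torch.float16, torch.float32, torch.float64):
    eps = torch.finfo(dtype).eps
    # default rtol for a hypothetical 1024 x 1024 matrix of this dtype
    print(dtype, eps, eps * 1024)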

mruberry (Collaborator) commented Mar 23, 2021

This sounds like a good UX change to me. We'll have to be a little careful in the docs about how we promote atol/rtol over the historic arguments and how we clarify the default behavior.

I also think changing the default cutoff value for these operations is OK if it simplifies the UX.

lezcano (Collaborator) commented Jul 14, 2021

Summary of the offline discussion. All of this applies to pinv, matrix_rank, and lstsq.

  • The name in the documentation will be rtol (consistent with the array API), but we will also support rcond. The support for rcond will not be documented.
  • We will add atol (and silently support acond) for linalg.matrix_rank, as it is the only one for which we can do this.
  • The default for rtol will be eps * max(rows, cols) for all three operations, to conform with the array API. Once this is done, we will be able to remove the warnings in these operations saying that "we may change the default rtol value in the future".
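
Assuming the API lands as summarized above, usage might look like the following sketch (keyword names and defaults per the bullets; not a confirmed final signature):

import torch

A = torch.randn(5, 5, dtype=torch.float64)
A[-1] = A[0] + A[1]  # make A rank-deficient

print(torch.linalg.matrix_rank(A, atol=1e-10, rtol=0.0))  # absolute cutoff only
print(torch.linalg.matrix_rank(A, rtol=1e-8))             # relative cutoff only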

nikitaved (Collaborator) commented Aug 12, 2021

The current default tolerance values could be too conservative for larger random Gaussian matrices of floating-point type, for example:

In [25]: a = torch.rand(1, 1, 1024, 1024).cuda()

In [26]: a[..., -1, :] = 0

In [27]: a[..., :, -1] = 0

In [28]: torch.linalg.matrix_rank(a)
Out[28]: tensor([[1019]], device='cuda:0')

In [29]: a.svd()[1].abs().topk(k=5, largest=False)
Out[29]: 
torch.return_types.topk(
values=tensor([[[0.0000, 0.0011, 0.0157, 0.0202, 0.0401]]], device='cuda:0'),
indices=tensor([[[1023, 1022, 1021, 1020, 1019]]], device='cuda:0'))

Given how important Gaussian matrices are in the context of machine learning, it looks like matrix_rank cannot be trusted in general. A note for the practitioner: always use SVD :) So, I guess, it is worth adding a note suggesting the use of SVD in case of doubt.
As a side note, lstsq uses a threshold based on the largest singular value irrespective of the size of the input.

IvanYashchuk (Collaborator, Author) commented

torch.rand samples from the uniform distribution, not a Gaussian. It's likely not full rank even before you add zeros on the diagonal.

The default relative tolerance for matrix_rank (eps * max(rows, cols)) is common and won't be changed for something else, because choosing the truncation criterion is not an easy decision. Therefore, in this proposal and in the implementation, it is possible to choose between an absolute tolerance that doesn't depend on the singular values and a relative tolerance. Possibly useful reading is Section 6, "How small is small?", of https://doi.org/10.1137/0905030. There, and in some other references, it is suggested to use a truncation criterion of eps times some growing function of the number of rows and columns.

> As a side note, lstsq uses a threshold based on the largest singular value irrespective of the size of the input.

This is not the case according to the documentation:

> If rcond=None, rcond is set to the machine precision of the dtype of A times max(m, n)

lezcano (Collaborator) commented Aug 12, 2021

The current default tolerance for matrix_rank is \sigma_0 * max(m, n) * eps. This is what's causing this weird behaviour. When this issue is solved, it'll be max(m, n) * eps, which will return the correct results for these random matrices.

The default values for matrix_rank proposed in this issue only start being problematic for matrices of shape 10k x 10k or 100k x 100k. If you are using such matrices in single precision in your code, you'll need to be very careful about how any linalg operation handles them regardless, so I don't think a note in the docs is necessary.

About using SVD: well, matrix_rank uses svdvals behind the scenes, so I'm afraid you won't be able to do much better on your own with SVD than with the current matrix_rank and a reasonable atol / rtol (when this issue is implemented, of course).

As a side note, I don't think that a matrix in R^{1024 x 1024} with entries distributed as U(0, 1) will have a singular value close to zero, although it may look close to singular judging by its determinant :)

nikitaved (Collaborator) commented Aug 12, 2021

> torch.rand samples from the uniform distribution, not a Gaussian. It's likely not full rank even before you add zeros on the diagonal.

I agree, fixed to randn, but the point is still valid. Nevertheless, I disagree that it is not full rank, which you can see from the smallest singular values (prior to zeroing out the last column and row). Generating a rank-deficient matrix with random element-wise procedures should be hard, as such matrices have measure zero in the probability space.

Another example:

In [4]: a = torch.rand(1024, 1024, device='cuda')

In [5]: torch.linalg.matrix_rank(a)
Out[5]: tensor(1020, device='cuda:0')

In [6]: a.svd()[1].abs().topk(k=5, largest=False)
Out[6]: 
torch.return_types.topk(
values=tensor([0.0040, 0.0126, 0.0253, 0.0422, 0.0651], device='cuda:0'),
indices=tensor([1023, 1022, 1021, 1020, 1019], device='cuda:0'))

nikitaved (Collaborator) commented Aug 12, 2021

> About using SVD: well, matrix_rank uses svdvals behind the scenes, so I'm afraid you won't be able to do much better on your own with SVD than with the current matrix_rank and a reasonable atol / rtol (when this issue is implemented, of course).

Since it uses SVD, I might skip trying to understand what these atol/rtol are and just implement my own custom thresholding function, in a potentially non-BC-breaking way. Some LAPACK methods, by the way, allow passing a custom thresholding function for singular values.
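
For instance, a sketch of such custom thresholding on top of svdvals; custom_rank and the keep callable are hypothetical names:

import torch

def custom_rank(A, keep):
    # keep: callable mapping the singular values to a boolean mask of values to retain
    s = torch.linalg.svdvals(A)
    return int(keep(s).sum())

a = torch.rand(1024, 1024)
# e.g. keep everything above 1e-3 times the largest singular value
print(custom_rank(a, lambda s: s > 1e-3 * s.max()))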

lezcano (Collaborator) commented Aug 12, 2021

For lstsq, all drivers behave similarly when it comes to rtol: singular values are truncated whenever they are smaller than \sigma_0 * rtol. For why this is the case for all the drivers, see data-apis/array-api#216.

While that behaviour is certainly not consistent with the one this PR proposes, as one multiplies the rtol you give it by \sigma_0 and the other uses its value raw, the defaults are not far from each other. For example, for square Gaussian random matrices, the largest singular value concentrates (on average) around sqrt(m) + sqrt(n). As such, perhaps a better default would be (sqrt(m) + sqrt(n)) * eps, but it starts getting a bit too technical, perhaps? Also, there's the consistency with the other libraries and all that...
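
A quick empirical check of that concentration claim (a rough sketch; for an i.i.d. standard Gaussian matrix, the largest singular value is close to sqrt(m) + sqrt(n) for large m and n):

import torch

m = n = 1024
A = torch.randn(m, n)
print(torch.linalg.svdvals(A).max().item())  # observed largest singular value
print(m ** 0.5 + n ** 0.5)                   # predicted concentration point, 64.0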

IvanYashchuk (Collaborator, Author) commented

> The current default tolerance for matrix_rank is \sigma_0 * max(m, n) * eps. This is what's causing this weird behaviour. When this issue is solved, it'll be max(m, n) * eps, which will return the correct results for these random matrices.

@lezcano, what is the correct result? The current default behavior for matrix_rank is to discard singular values below \sigma_0 * max(m, n) * eps, so the result is correct. The proposal in this issue is not to modify that: the same default behavior will stay, and it will also be added for pinv. If someone wants to discard singular values using an absolute tolerance, they'd need to use atol=; it's absolute because it doesn't depend on the matrix values.

wconstab pushed a commit that referenced this issue Oct 20, 2021
Summary:
This pull request introduces new keyword arguments for `torch.linalg.matrix_rank` and `torch.linalg.pinv`: `atol` and `rtol`.

Currently, only the tensor overload has default values for either `atol` or `rtol`; the float overload requires both arguments to be specified.

FC compatibility: #63102 (comment)

Fixes #54151. Fixes #66618.

cc jianyuh nikitaved pearu mruberry walterddr IvanYashchuk xwang233 Lezcano

Pull Request resolved: #63102

Reviewed By: H-Huang

Differential Revision: D31641456

Pulled By: mruberry

fbshipit-source-id: 4c765508ab1657730703e42975fc8c0d0a60eb7c