Support torch.linalg.trace
#62714
Conversation
💊 CI failures summary and remediations
As of commit 9622b99 (more details on the Dr. CI page):
🕵️ 1 new failure recognized by patterns. The following CI failures do not appear to be due to upstream breakages:
Job | Step | Action
---|---|---
Run clang-format | | 🔁 rerun
This comment was automatically generated by Dr. CI.
Please report bugs/suggestions to the (internal) Dr. CI Users group.
Hey @asi1024, just checking in on this PR because it's still marked as "draft." Is it ready for a review?
@mruberry Sorry for my late response. I will mark "ready for review" after adding tests and documentation!
Hey, @asi1024, thanks for your contribution! I left a suggestion to use the CompositeImplicitAutograd dispatch key that would allow us to remove the _backward function, trimming down unnecessary code. After that, I think the PR should be good to go.
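To illustrate the reasoning behind that suggestion (this is only a sketch, not the PR's actual code): an operation composed entirely of existing differentiable ops gets its backward from autograd for free, which is exactly what CompositeImplicitAutograd expresses. For example, assuming trace were written as diagonal-plus-sum:

import torch

# Hedged sketch: a composite trace built from differentiable ops needs no
# hand-written linalg_trace_backward; autograd differentiates through it.
def trace_composite(A, offset=0):
    return A.diagonal(offset=offset, dim1=-2, dim2=-1).sum(-1)

A = torch.randn(3, 3, requires_grad=True)
trace_composite(A).backward()
print(A.grad)  # identity matrix: d(trace(A))/dA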
python_module: linalg
variants: method, function
dispatch:
  CPU, CUDA: linalg_trace
This line is overwritten by CompositeExplicitAutograd: linalg_trace. The code in the linalg_trace function is independent of the device, so the CPU, CUDA specialization is not needed here and CompositeExplicitAutograd is the correct choice of dispatch key.
CPU, CUDA: linalg_trace
On second thought, using CompositeImplicitAutograd should be better; then the backward function is not needed.
CPU, CUDA: linalg_trace
CompositeExplicitAutograd: linalg_trace

- func: linalg_trace_backward(Tensor grad, int[] sizes, int offset) -> Tensor
Could you please remove this entry from native_functions.yaml? Most of the backward functions in PyTorch are placed in torch/csrc/autograd/FunctionsManual.cpp and torch/csrc/autograd/FunctionsManual.h, so let's move linalg_trace_backward there from ReduceOps.cpp.
aten/src/ATen/native/ReduceOps.cpp
Outdated
@@ -1112,6 +1125,13 @@ void impl_func_prod(
  }
}

Tensor prod(const Tensor& self, int64_t dim, bool keepdim, c10::optional<ScalarType> opt_dtype) {
Is this change needed in this PR? The function prod(Tensor self, int dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor should be autogenerated with structured_delegate: prod.int_out.
Codecov Report
@@            Coverage Diff             @@
##           master   #62714      +/-   ##
==========================================
+ Coverage   66.37%   66.46%   +0.08%
==========================================
  Files         738      727      -11
  Lines       94170    93581     -589
==========================================
- Hits        62510    62200     -310
+ Misses      31660    31381     -279
Left a few points regarding docs / testing.
torch/linalg/__init__.py
Outdated
trace = _add_docstr(_linalg.linalg_trace, r"""
trace(input, offset=0) -> Tensor

Returns the sum of the elements of the diagonal.
What diagonal? Given that we have the parameter offset, this should probably read:
Returns the sum of the elements of a diagonal.
Followed by an explanation of how the offset parameter chooses a diagonal.
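For instance, such an explanation could be illustrated with something like the following sketch, using torch.diagonal, whose offset argument behaves the same way (positive offsets select diagonals above the main one, negative offsets below):

import torch

A = torch.arange(9.).reshape(3, 3)
A.diagonal(offset=0).sum()   # main diagonal: 0 + 4 + 8 = 12
A.diagonal(offset=1).sum()   # first diagonal above the main one: 1 + 5 = 6
A.diagonal(offset=-1).sum()  # first diagonal below the main one: 3 + 7 = 10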
test/test_linalg.py
Outdated
y = torch.linalg.trace(x)
xn = np.array(x.cpu().numpy()).reshape(shape)
yn = np.trace(xn, axis1=-2, axis2=-1)
yn = torch.from_numpy(np.asarray(yn))
The call to torch.from_numpy here and a few lines below is not necessary, as assertEqual is able to compare tensors and numpy arrays. Even more, not calling it is often faster. Same below.
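A sketch of the simplified comparison (assuming the usual self.assertEqual from PyTorch's test framework, which accepts a tensor on one side and an ndarray on the other; check_trace is a hypothetical helper, not code from this PR):

import numpy as np
import torch

def check_trace(self, x):
    y = torch.linalg.trace(x)                           # result from this PR's op
    yn = np.trace(x.cpu().numpy(), axis1=-2, axis2=-1)  # NumPy reference
    self.assertEqual(y, yn)                             # no torch.from_numpy round-trip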
test/test_linalg.py
Outdated
xn = np.array(x.cpu().numpy()).reshape(shape)
yn = np.trace(xn, axis1=-2, axis2=-1)
This might work without all the explicit castings by simply replacing the two lines above with:
yn = np.trace(x.cpu(), axis1=-2, axis2=-1)
Same below.
@lezcano Thank you for your reviews! Could you take another look?
@lezcano Now all CIs have passed! PTAL!
Yeah, this looks good to me. I just found yet another thing (ugh, sorry). Sorry for that! I believe this is the last missing thing! :D
@lezcano The CI failures look unrelated to this PR. Could you take another look?
As mentioned, this LGTM. We now just need to wait for @mruberry to have a look. He's been a bit busy lately, but let's hope he finds some time soon :)
trace = _add_docstr(_linalg.linalg_trace, r"""
trace(input, *, offset=0, out=None) -> Tensor

Computes the trace of a matrix.
This is really well written.
inputs = (
    ((S, S), 0),
    ((S, M), 0),
    ((S, S), 1),
Add a sample with a negative offset and a comment explaining the format of these tuples
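Something along these lines, for example (purely illustrative; S and M are the symbolic sizes already used by these tests):

inputs = (
    # each entry is (shape of the input matrix, diagonal offset)
    ((S, S), 0),    # square matrix, main diagonal
    ((S, M), 0),    # rectangular matrix
    ((S, S), 1),    # first super-diagonal
    ((S, S), -1),   # first sub-diagonal (negative offset)
)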
@@ -9979,6 +9993,14 @@ def ref_pairwise_distance(input1, input2):
       dtypes=floating_and_complex_types(),
       sample_inputs_func=sample_inputs_linalg_slogdet,
       decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack],),
OpInfo('linalg.trace',
       ref=np.trace,
Nice reference
@@ -7764,6 +7764,28 @@ def test_tensordot(self, device):
    an = torch.from_numpy(np.tensordot(np.zeros((), dtype=np.float32), np.zeros((), dtype=np.float32), 0))
    self.assertEqual(a, an)

def test_linalg_trace(self, device):
Nice test
@@ -2574,6 +2574,20 @@ def sample_inputs_trace(self, device, dtype, requires_grad, **kwargs):
                     low=None, high=None,
                     requires_grad=requires_grad))),)

def sample_inputs_linalg_trace(self, device, dtype, requires_grad, **kwargs):
    inputs = (
        ((S, S), 0),
What happens on empty tensors? We should probably add a case for them. What about a batched sample input too, e.g. (S, S, S)?
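Concretely, that would amount to extending the sample tuple with entries such as (illustrative only):

    ((0, 0), 0),      # empty matrix
    ((S, S, S), 0),   # batched input; trace over the last two dimensions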
The trace function implemented in this PR returns different values from numpy.trace for 3-dim inputs: numpy.trace reduces with axis1=0, axis2=1, whereas the array API specifies reducing with axis1=-2, axis2=-1.
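A quick NumPy demonstration of the mismatch (not part of the PR, just for clarity):

import numpy as np

a = np.arange(24).reshape(2, 3, 4)
np.trace(a).shape                       # (4,)  -- NumPy default: axis1=0, axis2=1
np.trace(a, axis1=-2, axis2=-1).shape   # (2,)  -- array API: last two axes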
Is it possible to compare them against a lambda with our same defaults that calls into np.trace?
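For example, the OpInfo ref could be a small wrapper like this (a sketch; the keyword names assume the op's signature matches the docs above):

ref=lambda a, offset=0: np.trace(a, offset=offset, axis1=-2, axis2=-1)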
def test_linalg_trace(self, device):
    inputs = [
        {'shape': (1, 1), 'offsets': [0]},
        {'shape': (10, 1), 'offsets': [0, -9]},
What happens if offset is an absurd number, like 100?
RuntimeError will be raised if offset is out of range. I will add a test for this case!
Make it an ErrorInput
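A hedged sketch of what that could look like, using the ErrorInput / SampleInput helpers already used by the OpInfo tests (the exact error_regex depends on the final error message, and error_inputs_linalg_trace is a hypothetical name):

def error_inputs_linalg_trace(op_info, device, **kwargs):
    t = torch.randn(5, 5, device=device)
    # offset far outside the matrix: the op is expected to raise a RuntimeError
    yield ErrorInput(SampleInput(t, kwargs={'offset': 100}),
                     error_type=RuntimeError,
                     error_regex='offset')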
@@ -7764,6 +7764,28 @@ def test_tensordot(self, device):
    an = torch.from_numpy(np.tensordot(np.zeros((), dtype=np.float32), np.zeros((), dtype=np.float32), 0))
    self.assertEqual(a, an)

def test_linalg_trace(self, device):
    inputs = [
        {'shape': (1, 1), 'offsets': [0]},
Adding the empty case here would be interesting, too
@@ -6503,6 +6503,14 @@
  device_check: NoCheck
  device_guard: False

- func: linalg_trace.out(Tensor self, *, int offset=0, Tensor(a!) out) -> Tensor(a!)
Nice schemas
aten/src/ATen/native/ReduceOps.cpp
Outdated
// see https://github.com/pytorch/pytorch/pull/47305,
Tensor linalg_trace(const Tensor& self, int64_t offset) {
  TORCH_CHECK(self.dim() >= 2,
              "self should have at least 2 dimensions, but has ", self.dim(), " dimensions instead");
The user-documented name is input (per your docs below), and these warnings should start with the name of the operation, like this:
torch.linalg.trace(): input should have at least...
It might be nice to change the user-facing name of this argument to A, which is the name we use throughout torch.linalg.
@mruberry Updated tests. PTAL!
Looks like this PR hasn't been updated in a while, so we're going to go ahead and mark this as Stale.
Fixes #62255 (cc/ @mruberry, @rgommers, @emcastillo, @kmaehashi)
This PR adds support for torch.linalg.trace for compatibility with NumPy's interface and the Python array API standard.
TODO:
cc @jianyuh @nikitaved @pearu @mruberry @walterddr @IvanYashchuk @xwang233 @lezcano