Support torch.linalg.trace by asi1024 · Pull Request #62714 · pytorch/pytorch · GitHub

Support torch.linalg.trace #62714


Closed
wants to merge 23 commits into from

Conversation

asi1024 (Contributor) commented Aug 4, 2021

Fixes #62255 (cc/ @mruberry, @rgommers, @emcastillo, @kmaehashi)

This PR adds support for torch.linalg.trace for compatibility with NumPy's interface and the Python array API standard.

>>> torch.linalg.trace(torch.arange(18).reshape(2, 3, 3))
tensor([12, 14, 16])
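
For reference, the same result can be reproduced with NumPy's np.trace under its default axes (shown only as a sanity check, not part of the PR):

>>> import numpy as np
>>> np.trace(np.arange(18).reshape(2, 3, 3))
array([12, 14, 16])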

TODO:

  • Add documentation
  • Add tests

cc @jianyuh @nikitaved @pearu @mruberry @walterddr @IvanYashchuk @xwang233 @lezcano

facebook-github-bot (Contributor) commented Aug 4, 2021


💊 CI failures summary and remediations

As of commit 9622b99 (more details on the Dr. CI page):


  • 2/2 failures introduced in this PR

🕵️ 1 new failure recognized by patterns

The following CI failures do not appear to be due to upstream breakages:

See GitHub Actions build win-vs2019-cuda11.3-py3 / test (default, 2, 2, windows.8xlarge.nvidia.gpu) (1/1)

Step: "Test" (full log | diagnosis details | 🔁 rerun)

2022-02-25T11:20:02.4745605Z FAIL [0.019s]: test_sparse_addmm_cpu_bfloat16 (__main__.TestSparseCPU)
2022-02-25T11:20:02.4603374Z   test_sparse_zeros_tanh_cuda_float64 (__main__.TestSparseUnaryUfuncsCUDA) ... ok (0.000s)
2022-02-25T11:20:02.4623876Z   test_sparse_zeros_tanh_cuda_int16 (__main__.TestSparseUnaryUfuncsCUDA) ... ok (0.000s)
2022-02-25T11:20:02.4644090Z   test_sparse_zeros_tanh_cuda_int32 (__main__.TestSparseUnaryUfuncsCUDA) ... ok (0.000s)
2022-02-25T11:20:02.4664636Z   test_sparse_zeros_tanh_cuda_int64 (__main__.TestSparseUnaryUfuncsCUDA) ... ok (0.000s)
2022-02-25T11:20:02.4684392Z   test_sparse_zeros_tanh_cuda_int8 (__main__.TestSparseUnaryUfuncsCUDA) ... ok (0.000s)
2022-02-25T11:20:02.4704217Z   test_sparse_zeros_tanh_cuda_uint8 (__main__.TestSparseUnaryUfuncsCUDA) ... ok (0.000s)
2022-02-25T11:20:02.4724154Z   test_sparse_zeros_trunc_cuda_float32 (__main__.TestSparseUnaryUfuncsCUDA) ... ok (0.016s)
2022-02-25T11:20:02.4743513Z   test_sparse_zeros_trunc_cuda_float64 (__main__.TestSparseUnaryUfuncsCUDA) ... ok (0.000s)
2022-02-25T11:20:02.4744595Z 
2022-02-25T11:20:02.4744990Z ======================================================================
2022-02-25T11:20:02.4745605Z FAIL [0.019s]: test_sparse_addmm_cpu_bfloat16 (__main__.TestSparseCPU)
2022-02-25T11:20:02.4746326Z ----------------------------------------------------------------------
2022-02-25T11:20:02.4747129Z Traceback (most recent call last):
2022-02-25T11:20:02.4749094Z   File "C:\actions-runner\_work\pytorch\pytorch\build\win_tmp\build\torch\testing\_internal\common_device_type.py", line 376, in instantiated_test
2022-02-25T11:20:02.4750339Z     result = test(self, **param_kwargs)
2022-02-25T11:20:02.4751375Z   File "C:\actions-runner\_work\pytorch\pytorch\build\win_tmp\build\torch\testing\_internal\common_utils.py", line 2951, in wrapped
2022-02-25T11:20:02.4752299Z     f(self, *args, **kwargs, coalesced=False)
2022-02-25T11:20:02.4753010Z   File "test_sparse.py", line 1275, in test_sparse_addmm
2022-02-25T11:20:02.4753575Z     test_shape(7, 8, 9, 20, False, (1, 1))
2022-02-25T11:20:02.4754352Z   File "test_sparse.py", line 1264, in test_shape
2022-02-25T11:20:02.4754943Z     self.assertEqual(Y, Y_dense)

1 failure not recognized by patterns:

Job: GitHub Actions Lint / clang-format, Step: Run clang-format

This comment was automatically generated by Dr. CI.

mruberry (Collaborator) commented:

Hey @asi1024, just checking in on this PR because it's still marked as "draft." Is it ready for a review?

nateanl mentioned this pull request Aug 17, 2021
asi1024 (Contributor, Author) commented Aug 30, 2021

@mruberry Sorry for my late response. I will mark it as "ready for review" after adding tests and documentation!

asi1024 marked this pull request as ready for review September 16, 2021 09:52
asi1024 changed the title from [WIP] Support torch.linalg.trace to Support torch.linalg.trace Sep 16, 2021
IvanYashchuk (Collaborator) left a comment:

Hey, @asi1024, thanks for your contribution! I left a suggestion to use the CompositeImplicitAutograd dispatch key, which would allow us to remove the _backward function and trim down unnecessary code. After that, I think the PR should be good to go.

python_module: linalg
variants: method, function
dispatch:
CPU, CUDA: linalg_trace
Collaborator:

This line is overwritten by CompositeExplicitAutograd: linalg_trace. The code in the linalg_trace function is device-independent, so the CPU, CUDA specialization is not needed here; CompositeExplicitAutograd is the correct choice of dispatch key.

Suggested change
- CPU, CUDA: linalg_trace
+ CompositeExplicitAutograd: linalg_trace

Collaborator:

On second thought, using CompositeImplicitAutograd should be better; then the backward function is not needed.

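To illustrate the idea behind CompositeImplicitAutograd (a minimal Python sketch, not the PR's C++ implementation): if the op is expressed purely in terms of existing differentiable ops, autograd derives the backward automatically and no explicit linalg_trace_backward is needed.

import torch

def trace_composite(a, offset=0):
    # Batched trace built from existing differentiable ops
    # (array-API-style reduction over the last two dimensions).
    return torch.diagonal(a, offset=offset, dim1=-2, dim2=-1).sum(-1)

x = torch.randn(3, 3, requires_grad=True)
trace_composite(x).backward()
print(x.grad)  # the identity matrix: d(trace(A))/dA = I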

- func: linalg_trace_backward(Tensor grad, int[] sizes, int offset) -> Tensor
Collaborator:

Could you please remove this entry from native_functions.yaml?
Most of the backward functions in PyTorch live in torch/csrc/autograd/FunctionsManual.cpp and torch/csrc/autograd/FunctionsManual.h, so let's move linalg_trace_backward there from ReduceOps.cpp.
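
For reference, a rough Python sketch of what such a backward computes (illustrative only; the PR's actual code is C++, and the names here are hypothetical): the gradient of a (batched) trace scatters the upstream gradient onto the selected diagonal of a zero tensor with the input's shape.

import torch

def linalg_trace_backward_sketch(grad, sizes, offset=0):
    # grad has the shape of the trace output (the batch dims of the input);
    # the result has the original input's shape, with grad broadcast along
    # the chosen diagonal and zeros elsewhere.
    out = grad.new_zeros(sizes)
    diag = torch.diagonal(out, offset=offset, dim1=-2, dim2=-1)  # writable view
    diag += grad.unsqueeze(-1)
    return out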

@@ -1112,6 +1125,13 @@ void impl_func_prod(
}
}

Tensor prod(const Tensor& self, int64_t dim, bool keepdim, c10::optional<ScalarType> opt_dtype) {
Collaborator:

Is this change needed in this PR?
The function prod(Tensor self, int dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor should be autogenerated with

structured_delegate: prod.int_out

ezyang removed their request for review September 16, 2021 14:35

codecov bot commented Sep 16, 2021

Codecov Report

Merging #62714 (32e5ee4) into master (0dc9872) will increase coverage by 0.08%.
The diff coverage is 100.00%.

❗ Current head 32e5ee4 differs from pull request most recent head 7092683. Consider uploading reports for the commit 7092683 to get more accurate results.

@@            Coverage Diff             @@
##           master   #62714      +/-   ##
==========================================
+ Coverage   66.37%   66.46%   +0.08%     
==========================================
  Files         738      727      -11     
  Lines       94170    93581     -589     
==========================================
- Hits        62510    62200     -310     
+ Misses      31660    31381     -279     

soulitzer removed their request for review September 16, 2021 16:53
albanD removed their request for review September 16, 2021 17:47
heitorschueroff added the module: linear algebra and triaged labels Sep 16, 2021
lezcano (Collaborator) left a comment:

Left a few points regarding docs / testing.

trace = _add_docstr(_linalg.linalg_trace, r"""
trace(input, offset=0) -> Tensor

Returns the sum of the elements of the diagonal.
Collaborator:

What diagonal? Given that we have the parameter offset, this should probably read:

Returns the sum of the elements of a diagonal.

Followed by an explanation of how the offset parameter chooses a diagonal.
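
For example (an illustrative snippet using torch.diagonal, which follows the same offset convention): offset=0 selects the main diagonal, a positive offset a diagonal above it, and a negative offset one below it.

>>> import torch
>>> x = torch.arange(9).reshape(3, 3)
>>> torch.diagonal(x, offset=0).sum(), torch.diagonal(x, offset=1).sum(), torch.diagonal(x, offset=-1).sum()
(tensor(12), tensor(6), tensor(10))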

y = torch.linalg.trace(x)
xn = np.array(x.cpu().numpy()).reshape(shape)
yn = np.trace(xn, axis1=-2, axis2=-1)
yn = torch.from_numpy(np.asarray(yn))
Collaborator:


The call to torch.from_numpy here and a few lines below is not necessary, as assertEqual is able to compare tensors and numpy arrays. What's more, not calling it is often faster. Same below.

Comment on lines 8017 to 8018
xn = np.array(x.cpu().numpy()).reshape(shape)
yn = np.trace(xn, axis1=-2, axis2=-1)
Collaborator:

This might work without all the explicit casts by simply doing:

Suggested change
- xn = np.array(x.cpu().numpy()).reshape(shape)
- yn = np.trace(xn, axis1=-2, axis2=-1)
+ yn = np.trace(x.cpu(), axis1=-2, axis2=-1)

Same below.
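
A quick check of the suggested form (illustrative; np.trace converts a CPU torch.Tensor via the array protocol, so no explicit numpy() round-trip is required):

>>> import torch
>>> import numpy as np
>>> x = torch.arange(18).reshape(2, 3, 3)
>>> np.trace(x.cpu(), axis1=-2, axis2=-1)
array([12, 39])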

asi1024 (Contributor, Author) commented Sep 29, 2021

@lezcano Thank you for your reviews! Could you take another look?

asi1024 (Contributor, Author) commented Feb 14, 2022

@lezcano Now all CIs have passed! PTAL!

lezcano (Collaborator) commented Feb 14, 2022

Yeah, this looks good to me. I just found yet another thing (ugh, sorry):
we're missing the relevant entry in the docs in docs/source/linalg.rst! It could probably go between diagonal and det.

Sorry for that! I believe this is the last missing piece! :D
Otherwise, we can just wait for @mruberry to review.

asi1024 (Contributor, Author) commented Feb 21, 2022

@lezcano The CI failures look unrelated to this PR. Could you take another look?

lezcano (Collaborator) commented Feb 21, 2022

As mentioned, this LGTM. We now just need to wait for @mruberry to have a look. He's been a bit busy lately, but let's hope he finds some time soon :)

trace = _add_docstr(_linalg.linalg_trace, r"""
trace(input, *, offset=0, out=None) -> Tensor

Computes the trace of a matrix.
Collaborator:

This is really well written.

inputs = (
((S, S), 0),
((S, M), 0),
((S, S), 1),
Collaborator:

Add a sample with a negative offset and a comment explaining the format of these tuples
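
Something along these lines would address both points (a hypothetical sketch; S and M stand for the module's usual small test sizes, and the values here are placeholders):

# (shape of the input tensor, offset passed to torch.linalg.trace)
S, M = 5, 10
inputs = (
    ((S, S), 0),
    ((S, M), 0),
    ((S, S), 1),
    ((S, S), -1),  # negative offset
)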

@@ -9979,6 +9993,14 @@ def ref_pairwise_distance(input1, input2):
dtypes=floating_and_complex_types(),
sample_inputs_func=sample_inputs_linalg_slogdet,
decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack],),
OpInfo('linalg.trace',
ref=np.trace,
Collaborator:

Nice reference

@@ -7764,6 +7764,28 @@ def test_tensordot(self, device):
an = torch.from_numpy(np.tensordot(np.zeros((), dtype=np.float32), np.zeros((), dtype=np.float32), 0))
self.assertEqual(a, an)

def test_linalg_trace(self, device):
Collaborator:

Nice test

@@ -2574,6 +2574,20 @@ def sample_inputs_trace(self, device, dtype, requires_grad, **kwargs):
low=None, high=None,
requires_grad=requires_grad))),)

def sample_inputs_linalg_trace(self, device, dtype, requires_grad, **kwargs):
inputs = (
((S, S), 0),
mruberry (Collaborator) commented Feb 22, 2022:

What happens on empty tensors? We should probably add a case for them. What about a batched sample input too? (S, S, S)?

Contributor (Author):

The trace function implemented in this PR returns different values from numpy.trace for 3-dim inputs: numpy.trace reduces over axis1=0, axis2=1, whereas the array API specifies reducing over axis1=-2, axis2=-1.
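
A small demonstration of the difference (illustrative):

>>> import torch
>>> import numpy as np
>>> x = torch.arange(18).reshape(2, 3, 3)
>>> np.trace(x.numpy())                      # NumPy default: axis1=0, axis2=1
array([12, 14, 16])
>>> np.trace(x.numpy(), axis1=-2, axis2=-1)  # array API: last two axes
array([12, 39])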

Collaborator:

Is it possible to compare them against a lambda with our same defaults that then calls into np.trace?
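
Something like this could serve as the reference in the OpInfo entry (a sketch only, not from the PR):

ref=lambda a, offset=0: np.trace(a, offset=offset, axis1=-2, axis2=-1)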

def test_linalg_trace(self, device):
inputs = [
{'shape': (1, 1), 'offsets': [0]},
{'shape': (10, 1), 'offsets': [0, -9]},
Collaborator:

What happens if offset is an absurd number, like 100?

Contributor (Author):

A RuntimeError will be raised if offset is out of range. I will add a test for this case!

Collaborator:

> A RuntimeError will be raised if offset is out of range. I will add a test for this case!

Make it an ErrorInput.
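
A rough sketch of what that could look like (assuming the OpInfo ErrorInput/SampleInput helpers and their usual error_type/error_regex keywords; the function name and the regex are illustrative, not from the PR):

import torch
from torch.testing._internal.common_methods_invocations import ErrorInput, SampleInput

def error_inputs_linalg_trace(op_info, device, **kwargs):
    t = torch.ones((3, 3), device=device)
    # An out-of-range offset selects a diagonal that does not exist.
    yield ErrorInput(SampleInput(t, kwargs={'offset': 100}),
                     error_type=RuntimeError,
                     error_regex='offset')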

@@ -7764,6 +7764,28 @@ def test_tensordot(self, device):
an = torch.from_numpy(np.tensordot(np.zeros((), dtype=np.float32), np.zeros((), dtype=np.float32), 0))
self.assertEqual(a, an)

def test_linalg_trace(self, device):
inputs = [
{'shape': (1, 1), 'offsets': [0]},
Collaborator:

Adding the empty case here would be interesting, too

@@ -6503,6 +6503,14 @@
device_check: NoCheck
device_guard: False

- func: linalg_trace.out(Tensor self, *, int offset=0, Tensor(a!) out) -> Tensor(a!)
Collaborator:

Nice schemas

// see https://github.com/pytorch/pytorch/pull/47305,
Tensor linalg_trace(const Tensor& self, int64_t offset) {
TORCH_CHECK(self.dim() >= 2,
"self should have at least 2 dimensions, but has ", self.dim(), " dimensions instead");
Collaborator:

The user-documented name is input (per your docs below), and these error messages should start with the name of the operation, like this:

torch.linalg.trace(): input should have at least...

Collaborator:

It might be nice to change the user-facing name of this argument to A, which is the name we use throughout torch.linalg

mruberry self-requested a review February 22, 2022 19:21
mruberry (Collaborator) left a comment:

Hey @asi1024! Overall this looks really good but there are a few comments inline from me and @lezcano that still need to be addressed. Just ping me when they're done and we'll get this merged.

asi1024 (Contributor, Author) commented Feb 26, 2022

@mruberry Updated tests. PTAL!

rgommers added the module: python array api label Apr 13, 2022
github-actions bot (Contributor) commented:

Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

github-actions bot added the Stale label Jun 12, 2022
github-actions bot closed this Jul 12, 2022
Labels: cla signed, module: linear algebra, module: python array api, open source, Stale, triaged

Successfully merging this pull request may close these issues: Support torch.linalg.trace

9 participants