pinv: forward/backward AD which is Frechet-defined in a rank-preserving neighborhood. by nikitaved · Pull Request #66092 · pytorch/pytorch · GitHub

pinv: forward/backward AD which is Frechet-defined in a rank-preserving neighborhood. #66092


Closed
wants to merge 12 commits

Conversation

nikitaved
Collaborator
@nikitaved nikitaved commented Oct 4, 2021

Fixes #65911. Also enables complex support/tests for linalg_pinv in OpInfo.

cc @ezyang @albanD @zou3519 @gqchen @pearu @nikitaved @soulitzer @lezcano @Varal7 @jianyuh @mruberry @walterddr @IvanYashchuk @xwang233
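As a rough illustration of what the new support enables, here is a minimal sketch (not the PR's own tests); it assumes torch.linalg.pinv and torch.autograd.forward_ad, and uses a real-valued full-rank input for simplicity even though the PR also covers complex inputs:

```python
import torch
import torch.autograd.forward_ad as fwAD

# Minimal sketch (not the PR's tests): exercise reverse- and forward-mode AD
# of torch.linalg.pinv on a full-rank 5x3 matrix in double precision.
x = torch.randn(5, 3, dtype=torch.double, requires_grad=True)

# Reverse mode: numerical vs. analytical Jacobian check.
torch.autograd.gradcheck(torch.linalg.pinv, (x,))

# Forward mode: JVP of pinv at x along a random tangent, via dual tensors.
with fwAD.dual_level():
    tangent = torch.randn_like(x)
    dual = fwAD.make_dual(x.detach(), tangent)
    primal, jvp = fwAD.unpack_dual(torch.linalg.pinv(dual))

print(jvp.shape)  # torch.Size([3, 5])
```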

@nikitaved nikitaved added the labels module: autograd (related to torch.autograd, and the autograd engine in general), module: linear algebra (issues related to specialized linear algebra operations in PyTorch; includes matrix multiply matmul), complex_autograd, and ci/slow-gradcheck on Oct 4, 2021
@pytorch-probot
pytorch-probot bot commented Oct 4, 2021
CI Flow Status

⚛️ CI Flow

Ruleset - Version: v1
Ruleset - File: https://github.com/pytorch/pytorch/blob/2b5431eed4d51d4b7566af98116cbc3c649a0978/.github/generated-ciflow-ruleset.json
PR ciflow labels: ciflow/default

Workflows | Labels (bold = enabled) | Status
Triggered Workflows
linux-bionic-py3.6-clang9 ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/noarch, ciflow/xla ✅ triggered
linux-vulkan-bionic-py3.6-clang9 ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/vulkan ✅ triggered
linux-xenial-cuda11.3-py3.6-gcc7 ciflow/all, ciflow/cuda, ciflow/default, ciflow/linux ✅ triggered
linux-xenial-py3.6-clang7-asan ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/sanitizers ✅ triggered
linux-xenial-py3.6-clang7-onnx ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/onnx ✅ triggered
linux-xenial-py3.6-gcc5.4 ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux ✅ triggered
linux-xenial-py3.6-gcc7-bazel-test ciflow/all, ciflow/bazel, ciflow/cpu, ciflow/default, ciflow/linux ✅ triggered
win-vs2019-cpu-py3 ciflow/all, ciflow/cpu, ciflow/default, ciflow/win ✅ triggered
win-vs2019-cuda11.3-py3 ciflow/all, ciflow/cuda, ciflow/default, ciflow/win ✅ triggered
Skipped Workflows
libtorch-linux-xenial-cuda10.2-py3.6-gcc7 ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux 🚫 skipped
libtorch-linux-xenial-cuda11.3-py3.6-gcc7 ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux 🚫 skipped
linux-bionic-cuda10.2-py3.9-gcc7 ciflow/all, ciflow/cuda, ciflow/linux, ciflow/slow 🚫 skipped
linux-xenial-cuda10.2-py3.6-gcc7 ciflow/all, ciflow/cuda, ciflow/linux, ciflow/slow 🚫 skipped
parallelnative-linux-xenial-py3.6-gcc5.4 ciflow/all, ciflow/cpu, ciflow/linux 🚫 skipped
periodic-libtorch-linux-xenial-cuda11.1-py3.6-gcc7 ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux, ciflow/scheduled 🚫 skipped
periodic-linux-xenial-cuda10.2-py3-gcc7-slow-gradcheck ciflow/all, ciflow/cuda, ciflow/linux, ciflow/scheduled, ciflow/slow, ciflow/slow-gradcheck 🚫 skipped
periodic-linux-xenial-cuda11.1-py3.6-gcc7 ciflow/all, ciflow/cuda, ciflow/linux, ciflow/scheduled 🚫 skipped
periodic-win-vs2019-cuda11.1-py3 ciflow/all, ciflow/cuda, ciflow/scheduled, ciflow/win 🚫 skipped
puretorch-linux-xenial-py3.6-gcc5.4 ciflow/all, ciflow/cpu, ciflow/linux 🚫 skipped

You can add a comment to the PR and tag @pytorchbot with the following commands:
# ciflow rerun, "ciflow/default" will always be added automatically
@pytorchbot ciflow rerun

# ciflow rerun with additional labels "-l <ciflow/label_name>", which is equivalent to adding these labels manually and trigger the rerun
@pytorchbot ciflow rerun -l ciflow/scheduled -l ciflow/slow

For more information, please take a look at the CI Flow Wiki.

Comment on lines +8950 to +8951
# Only large tensors show issues with implicit backward used prior to
# explicit backward implementation.
Collaborator Author
@nikitaved nikitaved Oct 4, 2021


A note: large tensors of low rank. In my environment I had to create a rank-1 30x30 matrix to see issues with repeated zero singular values in the backward of SVD.
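For context, a hypothetical repro along those lines (a sketch only; the seed, shapes, and normalization are illustrative, not taken from the PR):

```python
import torch

# Hypothetical repro sketch: a rank-1 30x30 matrix, i.e. 29 repeated zero
# singular values in its SVD, which is the regime where the old SVD-based
# implicit backward ran into trouble.
torch.manual_seed(0)
u = torch.randn(30, 1, dtype=torch.double)
v = torch.randn(30, 1, dtype=torch.double)
a = (u / u.norm()) @ (v / v.norm()).t()   # rank 1; the single nonzero singular value is 1
a.requires_grad_(True)

p = torch.linalg.pinv(a)
p.sum().backward()                        # exercises the explicit pinv backward
print(torch.isfinite(a.grad).all())       # expected: tensor(True)
```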

@ezyang ezyang removed their request for review October 4, 2021 23:16
@ezyang
Contributor
ezyang commented Oct 4, 2021

Not sure the appropriate FB reviewer has been tagged yet.

@facebook-github-bot
Contributor
facebook-github-bot commented Oct 5, 2021

🔗 Helpful links

💊 CI failures summary and remediations

As of commit 2b5431e (more details on the Dr. CI page):


  • 1/1 failures introduced in this PR

🕵️ 1 new failure recognized by patterns

The following CI failures do not appear to be due to upstream breakages:

See CircleCI build pytorch_xla_linux_bionic_py3_6_clang9_test (1/1)

Step: "Test" (full log | diagnosis details | 🔁 rerun)

Oct 08 11:26:22 RuntimeError: tensorflow/compil...'rendezvous_test.0': Connection reset by peer (14)
Oct 08 11:26:22 Exception in device=CPU:1: tensorflow/compiler/xla/xla_client/mesh_service.cc:364 : Failed to meet rendezvous 'rendezvous_test.0': Connection reset by peer (14)
Oct 08 11:26:22 Traceback (most recent call last):
Oct 08 11:26:22   File "/opt/conda/lib/python3.6/site-packages/torch_xla-1.10-py3.6-linux-x86_64.egg/torch_xla/distributed/xla_multiprocessing.py", line 329, in _mp_start_fn
Oct 08 11:26:22     _start_fn(index, pf_cfg, fn, args)
Oct 08 11:26:22   File "/opt/conda/lib/python3.6/site-packages/torch_xla-1.10-py3.6-linux-x86_64.egg/torch_xla/distributed/xla_multiprocessing.py", line 323, in _start_fn
Oct 08 11:26:22     fn(gindex, *args)
Oct 08 11:26:22   File "/var/lib/jenkins/workspace/xla/test/test_mp_rendezvous.py", line 22, in _mp_fn
Oct 08 11:26:22     replicas=replicas)
Oct 08 11:26:22   File "/opt/conda/lib/python3.6/site-packages/torch_xla-1.10-py3.6-linux-x86_64.egg/torch_xla/core/xla_model.py", line 875, in rendezvous
Oct 08 11:26:22     return torch_xla._XLAC._xla_rendezvous(get_ordinal(), tag, payload, replicas)
Oct 08 11:26:22 RuntimeError: tensorflow/compiler/xla/xla_client/mesh_service.cc:364 : Failed to meet rendezvous 'rendezvous_test.0': Connection reset by peer (14)
Oct 08 11:26:23 Traceback (most recent call last):
Oct 08 11:26:23   File "/var/lib/jenkins/workspace/xla/test/test_mp_rendezvous.py", line 35, in <module>
Oct 08 11:26:23     xmp.spawn(_mp_fn, args=())
Oct 08 11:26:23   File "/opt/conda/lib/python3.6/site-packages/torch_xla-1.10-py3.6-linux-x86_64.egg/torch_xla/distributed/xla_multiprocessing.py", line 394, in spawn
Oct 08 11:26:23     start_method=start_method)
Oct 08 11:26:23   File "/opt/conda/lib/python3.6/site-packages/torch/multiprocessing/spawn.py", line 188, in start_processes
Oct 08 11:26:23     while not context.join():
Oct 08 11:26:23   File "/opt/conda/lib/python3.6/site-packages/torch/multiprocessing/spawn.py", line 144, in join
Oct 08 11:26:23     exit_code=exitcode
Oct 08 11:26:23 torch.multiprocessing.spawn.ProcessExitedException: process 3 terminated with exit code 17

XLA failure

Job pytorch_xla_linux_bionic_py3_6_clang9_test is failing. Please create an issue with a title prefixed by [PT_BREAK] in pytorch/xla and link to this PR. If you have questions, please reach out to @ailzhang / @dlibenzi / @JackCaoG.


This comment was automatically generated by Dr. CI.

Please report bugs/suggestions to the (internal) Dr. CI Users group.


@nikitaved nikitaved changed the title from "pinv: forward/backward AD which is Frechet-differentiable in a rank-preserving neighborhood." to "pinv: forward/backward AD which is Frechet-defined in a rank-preserving neighborhood." on Oct 5, 2021
Collaborator
@lezcano lezcano left a comment


Nice! Thanks both for finding the slick formula for this backward and the rather compact implementation!
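For reference (the patch itself is not quoted in this thread), the classical rank-preserving differential of the pseudoinverse, from which such a backward can be derived, goes back to Golub and Pereyra:

$$
\mathrm{d}A^{+} = -A^{+}\,\mathrm{d}A\,A^{+}
  + A^{+} A^{+\mathrm{H}}\,\mathrm{d}A^{\mathrm{H}}\,(I - A A^{+})
  + (I - A^{+} A)\,\mathrm{d}A^{\mathrm{H}}\,A^{+\mathrm{H}} A^{+},
$$

valid whenever the perturbation keeps the rank constant; the VJP used in a backward follows by taking the adjoint of this linear map in $\mathrm{d}A$.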

Comment on lines 2120 to 2121
# Note that by making the columns of `a` and `b` orthonormal we make sure
# that the product matrix `a @ b.t()` has condition number 1.
Collaborator


Nice! This saves us a lot of pain in future debugging.

Now, this note is slightly incorrect. The resulting matrix will have singular values 0 and 1, so the condition number will be infinite! Perhaps you mean that it has condition number 1 when restricted to its image?

Collaborator Author
@nikitaved nikitaved Oct 5, 2021


Yes, exactly: condition number 1 when restricted to the image, so that pinv is stable.
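A small sketch of that construction (illustrative shapes and dtype, not the exact sample-input helper from the PR):

```python
import torch

# With orthonormal columns in `a` and `b`, the nonzero singular values of
# a @ b.t() are all exactly 1, so the matrix has condition number 1 when
# restricted to its image, which keeps pinv numerically stable there.
m, n, rank = 30, 20, 3
a, _ = torch.linalg.qr(torch.randn(m, rank, dtype=torch.double))
b, _ = torch.linalg.qr(torch.randn(n, rank, dtype=torch.double))
t = a @ b.t()

s = torch.linalg.svdvals(t)
print(s[:rank])   # all ~1.0; the remaining min(m, n) - rank singular values are ~0
```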

sample_inputs_func=sample_inputs_linalg_pinv_singular,
# Only large tensors show issues with implicit backward used prior to
# explicit backward implementation.
decorators=[slowTest, skipCUDAIfNoMagmaAndNoCusolver, skipCUDAIfRocm, skipCPUIfNoLapack],
Collaborator


Is the slowTest decorator working as expected here?

Collaborator Author


Yes!

Collaborator


@albanD It will apply the slowTest decorator to EVERY test generated by this OpInfo
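A toy illustration of that behavior (not the actual OpInfo machinery; all names here are made up):

```python
# Toy model: a per-op `decorators` list is applied to every test generated for
# the op, so a marker like `slowTest` ends up wrapping each of them.
def slow_test(fn):                      # stand-in for the real slowTest marker
    fn._slow = True
    return fn

def generate_tests(op_decorators):
    def grad_test(): ...
    def gradgrad_test(): ...
    tests = [grad_test, gradgrad_test]
    for dec in op_decorators:           # every generated test gets every decorator
        tests = [dec(t) for t in tests]
    return tests

tests = generate_tests([slow_test])
print(all(getattr(t, "_slow", False) for t in tests))   # True
```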

@mruberry
Collaborator
mruberry commented Oct 6, 2021

Cool! Do you have before/after perf numbers for the autograd, @nikitaved?

@nikitaved
Collaborator Author

@mruberry, I did run some benchmarks and, surprisingly, this PR also improves performance.
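The numbers below were gathered with a timing loop roughly like the following (a sketch only; the exact script is not part of this thread, and the helper name pinv_fwd_bwd, the shapes, and the use of torch.utils.benchmark are illustrative):

```python
import torch
import torch.utils.benchmark as benchmark

# Hypothetical harness: time forward + backward of linalg.pinv per shape.
def pinv_fwd_bwd(x):
    y = torch.linalg.pinv(x)
    y.sum().backward()
    x.grad = None           # don't let gradients accumulate across runs

for shape in [(10, 10), (1000, 10, 10), (100, 100), (1000, 100, 100)]:
    x = torch.randn(*shape, dtype=torch.float32, requires_grad=True)
    timer = benchmark.Timer(
        stmt="pinv_fwd_bwd(x)",
        globals={"pinv_fwd_bwd": pinv_fwd_bwd, "x": x},
    )
    print(shape, timer.blocked_autorange())
```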

This PR, cpu float32:

shape: (10, 10), device: cpu, dtype: torch.float32
29.8 µs ± 167 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

shape: (1000, 10, 10), device: cpu, dtype: torch.float32
797 µs ± 7.44 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

shape: (100, 100), device: cpu, dtype: torch.float32
273 µs ± 3.76 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

shape: (1000, 100, 100), device: cpu, dtype: torch.float32
56.4 ms ± 1.6 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

shape: (1000, 1000), device: cpu, dtype: torch.float32
11.7 ms ± 96.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

shape: (10, 1000, 1000), device: cpu, dtype: torch.float32
159 ms ± 1.26 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

Master, cpu float32:

shape: (10, 10), device: cpu, dtype: torch.float32
86.3 µs ± 3.55 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

shape: (1000, 10, 10), device: cpu, dtype: torch.float32
2.23 ms ± 398 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

shape: (100, 100), device: cpu, dtype: torch.float32
535 µs ± 33.9 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

shape: (1000, 100, 100), device: cpu, dtype: torch.float32
174 ms ± 6.25 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

shape: (1000, 1000), device: cpu, dtype: torch.float32
26.9 ms ± 1.58 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

shape: (10, 1000, 1000), device: cpu, dtype: torch.float32
392 ms ± 30.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

This PR, cuda float32:

shape: (10, 10), device: cuda, dtype: torch.float32
111 µs ± 3.58 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

shape: (1000, 10, 10), device: cuda, dtype: torch.float32
332 µs ± 998 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)

shape: (100, 100), device: cuda, dtype: torch.float32
111 µs ± 772 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

shape: (1000, 100, 100), device: cuda, dtype: torch.float32
7.25 ms ± 11.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

shape: (1000, 1000), device: cuda, dtype: torch.float32
3.21 ms ± 2.74 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

shape: (10, 1000, 1000), device: cuda, dtype: torch.float32
29.8 ms ± 21.2 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

Master, cuda float32:

shape: (10, 10), device: cuda, dtype: torch.float32
282 µs ± 15 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

shape: (1000, 10, 10), device: cuda, dtype: torch.float32
565 µs ± 3.39 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

shape: (100, 100), device: cuda, dtype: torch.float32
312 µs ± 32.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

shape: (1000, 100, 100), device: cuda, dtype: torch.float32
11.8 ms ± 41.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

shape: (1000, 1000), device: cuda, dtype: torch.float32
4.72 ms ± 32.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

shape: (10, 1000, 1000), device: cuda, dtype: torch.float32
42.4 ms ± 26.8 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

@lezcano
Collaborator
lezcano commented Oct 8, 2021

Faster and correct! There's no better combination than that :)

Collaborator
@albanD albanD left a comment


Looks good. Can you fix the last lint (EDIT: oh, it looks like the job itself failed...) and I'll merge this.

@facebook-github-bot
Contributor

@albanD has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

Labels
cla signed, complex_autograd, module: autograd (related to torch.autograd, and the autograd engine in general), module: linear algebra (issues related to specialized linear algebra operations in PyTorch; includes matrix multiply matmul), open source
Development

Successfully merging this pull request may close these issues.

pinv could be differentiable on a wider range of inputs
7 participants