8000 Add cusolver gesvdj and gesvdjBatched to the backend of torch.svd by xwang233 · Pull Request #48436 · pytorch/pytorch · GitHub


Closed · wants to merge 60 commits

Conversation

xwang233 (Collaborator) commented Nov 25, 2020

This PR adds cusolver gesvdj and gesvdjBatched to the backend of torch.svd.

I've tested the performance with CUDA 11.1 on RTX 2070, V100, and A100. The cusolver gesvdj and gesvdjBatched backends outperform MAGMA in all square-matrix cases, so the cusolver backend will replace the MAGMA backend when available.

When both matrix dimensions are no greater than 32, gesvdjBatched is used. Otherwise, gesvdj is used.
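The dispatch rule above can be sketched in plain Python (a hypothetical illustration; the PR implements this heuristic in ATen's C++ CUDA code, and the function name here is made up):

```python
def choose_cusolver_svd_backend(m: int, n: int) -> str:
    """Hypothetical sketch of the dispatch heuristic described above.

    When both matrix dimensions are no greater than 32, the batched
    Jacobi solver (gesvdjBatched) is used; otherwise gesvdj is used.
    """
    if m <= 32 and n <= 32:
        return "gesvdjBatched"
    return "gesvdj"


# Example dispatch decisions:
print(choose_cusolver_svd_backend(16, 16))  # gesvdjBatched
print(choose_cusolver_svd_backend(32, 64))  # gesvdj
print(choose_cusolver_svd_backend(64, 64))  # gesvdj
```

This matches the benchmark below, where the batched solver's advantage shows up only for small matrices in large batches.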

Detailed benchmark is available at https://github.com/xwang233/code-snippet/tree/master/linalg/svd.

Some relevant code and discussions

See also #42666 #47953

Closes #50516

xwang233 (Collaborator, Author) commented Nov 25, 2020

Benchmark for A100 https://github.com/xwang233/code-snippet/tree/master/linalg/svd/A100

Benchmark for V100 https://github.com/xwang233/code-snippet/tree/master/linalg/svd/V100

Benchmark on RTX 2070 super + E5 2680 v3

  • before: magma
  • after_1: cusolver gesvd only
  • after_2: cusolver gesvdj only
  • after_3: gesvdj + gesvdjBatched (when both dims <= 32) [this PR]
(benchmark chart omitted)

Times are in ms (10^-3 s).
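A minimal sketch of how such per-shape timings can be collected (hypothetical; it uses numpy's CPU SVD as a stand-in, whereas the linked benchmark runs torch.svd on CUDA and must synchronize the device around the timed region):

```python
import time

import numpy as np


def time_svd_ms(batch_shape, n, iters=5):
    # Time the SVD of a batch of n-by-n float32 matrices;
    # returns the mean time per iteration in milliseconds.
    a = np.random.randn(*batch_shape, n, n).astype(np.float32)
    np.linalg.svd(a)  # warm-up run, excluded from timing
    start = time.perf_counter()
    for _ in range(iters):
        np.linalg.svd(a)
    return (time.perf_counter() - start) / iters * 1e3


# e.g. a batch of four 8x8 matrices:
print(f"{time_svd_ms((4,), 8):.3f} ms")
```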

| batch | size | dtype | cpu | before | after_1 | after_2 | after_3 |
|---|---|---|---|---|---|---|---|
| [] | 2 | torch.float32 | 0.025 | 1.624 | 0.206 | 0.215 | 0.255 |
| [] | 4 | torch.float32 | 0.027 | 1.381 | 0.299 | 0.312 | 0.359 |
| [] | 8 | torch.float32 | 0.035 | 1.364 | 0.453 | 0.452 | 0.502 |
| [] | 16 | torch.float32 | 0.064 | 1.429 | 0.778 | 0.464 | 0.614 |
| [] | 32 | torch.float32 | 0.344 | 2.244 | 1.493 | 0.613 | 0.661 |
| [] | 64 | torch.float32 | 0.814 | 2.268 | 3.543 | 1.205 | 1.262 |
| [] | 128 | torch.float32 | 2.442 | 8.966 | 8.623 | 2.940 | 2.959 |
| [] | 256 | torch.float32 | 9.296 | 28.048 | 25.734 | 7.912 | 8.040 |
| [] | 512 | torch.float32 | 40.841 | 71.162 | 102.131 | 22.941 | 23.547 |
| [] | 1024 | torch.float32 | 183.968 | 273.447 | 579.854 | 99.649 | 98.451 |
| [1] | 2 | torch.float32 | 0.036 | 1.942 | 0.231 | 0.201 | 0.204 |
| [1] | 4 | torch.float32 | 0.039 | 2.317 | 0.335 | 0.309 | 0.341 |
| [1] | 8 | torch.float32 | 0.047 | 2.272 | 0.507 | 0.411 | 0.422 |
| [1] | 16 | torch.float32 | 0.082 | 1.607 | 0.874 | 0.504 | 0.524 |
| [1] | 32 | torch.float32 | 0.294 | 2.372 | 1.632 | 0.611 | 0.573 |
| [1] | 64 | torch.float32 | 0.777 | 2.815 | 3.732 | 1.213 | 1.202 |
| [1] | 128 | torch.float32 | 2.328 | 9.200 | 9.064 | 2.901 | 2.915 |
| [1] | 256 | torch.float32 | 9.093 | 29.403 | 26.150 | 8.111 | 8.239 |
| [1] | 512 | torch.float32 | 39.886 | 75.919 | 102.661 | 23.375 | 23.989 |
| [1] | 1024 | torch.float32 | 168.504 | 242.375 | 582.987 | 98.094 | 99.713 |
| [2] | 2 | torch.float32 | 0.031 | 3.358 | 0.322 | 0.316 | 0.098 |
| [2] | 4 | torch.float32 | 0.036 | 2.935 | 0.497 | 0.531 | 0.156 |
| [2] | 8 | torch.float32 | 0.045 | 2.935 | 0.804 | 0.686 | 0.196 |
| [2] | 16 | torch.float32 | 0.101 | 3.027 | 1.472 | 0.935 | 0.270 |
| [2] | 32 | torch.float32 | 0.603 | 4.691 | 2.862 | 1.137 | 0.353 |
| [2] | 64 | torch.float32 | 1.670 | 5.471 | 6.917 | 2.325 | 2.355 |
| [2] | 128 | torch.float32 | 4.995 | 17.448 | 17.218 | 5.814 | 5.831 |
| [2] | 256 | torch.float32 | 18.183 | 54.771 | 50.414 | 15.998 | 16.267 |
| [2] | 512 | torch.float32 | 79.924 | 145.471 | 200.012 | 46.597 | 47.683 |
| [2] | 1024 | torch.float32 | 334.309 | 494.471 | 1162.145 | 196.642 | 195.095 |
| [4] | 2 | torch.float32 | 0.029 | 5.619 | 0.546 | 0.548 | 0.119 |
| [4] | 4 | torch.float32 | 0.044 | 5.669 | 0.939 | 1.030 | 0.178 |
| [4] | 8 | torch.float32 | 0.074 | 5.727 | 1.575 | 1.245 | 0.217 |
| [4] | 16 | torch.float32 | 0.186 | 5.936 | 2.873 | 1.769 | 0.308 |
| [4] | 32 | torch.float32 | 1.127 | 9.218 | 5.634 | 2.205 | 0.372 |
| [4] | 64 | torch.float32 | 3.011 | 9.099 | 13.716 | 4.545 | 4.822 |
| [4] | 128 | torch.float32 | 8.998 | 36.183 | 34.052 | 11.562 | 11.599 |
| [4] | 256 | torch.float32 | 35.854 | 114.162 | 100.016 | 32.560 | 32.481 |
| [4] | 512 | torch.float32 | 162.641 | 280.715 | 394.425 | 93.996 | 95.387 |
| [8] | 2 | torch.float32 | 0.036 | 11.721 | 1.036 | 0.988 | 0.102 |
| [8] | 4 | torch.float32 | 0.061 | 11.661 | 1.745 | 1.900 | 0.143 |
| [8] | 8 | torch.float32 | 0.121 | 11.655 | 3.071 | 2.628 | 0.229 |
| [8] | 16 | torch.float32 | 0.352 | 11.931 | 5.755 | 3.385 | 0.276 |
| [8] | 32 | torch.float32 | 1.889 | 15.173 | 11.331 | 4.233 | 0.356 |
| [8] | 64 | torch.float32 | 5.414 | 19.499 | 27.344 | 9.108 | 9.322 |
| [8] | 128 | torch.float32 | 17.462 | 70.354 | 67.321 | 23.154 | 23.425 |
| [8] | 256 | torch.float32 | 72.542 | 234.389 | 200.375 | 64.568 | 65.086 |
| [16] | 2 | torch.float32 | 0.049 | 22.706 | 1.945 | 1.884 | 0.101 |
| [16] | 4 | torch.float32 | 0.097 | 22.740 | 3.349 | 3.793 | 0.160 |
| [16] | 8 | torch.float32 | 0.220 | 22.971 | 5.863 | 5.057 | 0.201 |
| [16] | 16 | torch.float32 | 0.685 | 23.412 | 11.126 | 6.664 | 0.281 |
| [16] | 32 | torch.float32 | 3.780 | 29.633 | 22.174 | 8.225 | 0.355 |
| [16] | 64 | torch.float32 | 11.145 | 40.003 | 53.875 | 17.867 | 18.709 |
| [16] | 128 | torch.float32 | 34.763 | 138.227 | 134.162 | 45.664 | 46.196 |
| [32] | 2 | torch.float32 | 0.073 | 45.338 | 3.819 | 3.669 | 0.118 |
| [32] | 4 | torch.float32 | 0.175 | 45.438 | 6.822 | 7.309 | 0.160 |
| [32] | 8 | torch.float32 | 0.425 | 47.594 | 11.824 | 10.198 | 0.201 |
| [32] | 16 | torch.float32 | 1.342 | 48.721 | 22.410 | 13.252 | 0.283 |
| [32] | 32 | torch.float32 | 7.234 | 61.844 | 44.549 | 16.419 | 0.364 |
| [32] | 64 | torch.float32 | 22.023 | 78.596 | 107.863 | 35.681 | 37.525 |
| [64] | 2 | torch.float32 | 0.125 | 92.296 | 7.515 | 7.292 | 0.127 |
| [64] | 4 | torch.float32 | 0.327 | 91.349 | 13.519 | 14.703 | 0.201 |
| [64] | 8 | torch.float32 | 0.841 | 93.665 | 23.592 | 20.017 | 0.250 |
| [64] | 16 | torch.float32 | 2.658 | 95.347 | 44.793 | 26.501 | 0.325 |
| [64] | 32 | torch.float32 | 15.419 | 122.030 | 89.137 | 33.202 | 0.536 |
| [128] | 2 | torch.float32 | 0.229 | 186.144 | 15.096 | 14.582 | 0.151 |
| [128] | 4 | torch.float32 | 0.628 | 186.209 | 26.771 | 29.967 | 0.286 |
| [128] | 8 | torch.float32 | 1.646 | 184.716 | 47.386 | 39.732 | 0.390 |
| [128] | 16 | torch.float32 | 5.277 | 191.366 | 89.578 | 52.256 | 0.542 |
| [256] | 2 | torch.float32 | 0.439 | 361.648 | 29.828 | 28.889 | 0.195 |
| [256] | 4 | torch.float32 | 1.234 | 371.878 | 53.513 | 59.414 | 0.423 |
| [256] | 8 | torch.float32 | 3.279 | 365.996 | 94.705 | 79.853 | 0.632 |
| [512] | 2 | torch.float32 | 0.847 | 737.991 | 59.543 | 57.562 | 0.301 |
| [512] | 4 | torch.float32 | 2.444 | 738.817 | 106.619 | 118.540 | 0.730 |
| [1024] | 2 | torch.float32 | 1.686 | 1475.765 | 119.638 | 115.037 | 0.491 |

xwang233 (Collaborator, Author) commented Nov 25, 2020

reserved 2

dr-ci bot commented Nov 25, 2020

💊 CI failures summary and remediations

As of commit 3075adc (more details on the Dr. CI page):


  • 1/1 failures possibly* introduced in this PR
    • 1/1 non-CircleCI failure(s)

ci.pytorch.org: 1 failed


This comment was automatically generated by Dr. CI.

xwang233 changed the title from "[WIP] Enable cusolver/cublas backend for torch.svd" to "Add cusolver gesvdj and gesvdjBatched to the backend of torch.svd" on Dec 7, 2020
xwang233 (Collaborator, Author) commented

@heitorschueroff This PR is ready to go.

facebook-github-bot (Contributor) left a comment


@heitorschueroff has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

heitorschueroff (Contributor) commented

> @heitorschueroff This PR is ready to go.

Thank you @xwang233, I've imported it to phabricator.

heitorschueroff (Contributor) commented

@xwang233 and @ngimel, the manywheel build has a lot of CUDA-related linker errors; are these related to this PR?

ngimel (Collaborator) commented Jan 21, 2021

Yeah, looks like it. They are for the lascl function (scale a matrix by a scalar); it's unclear which library it should come from.

xwang233 (Collaborator, Author) commented

I tried this on CUDA 10.2 with python setup.py install and it finished fine. Could this be a problem with the manywheel builder? I will submit an internal issue for this linker problem.

ngimel (Collaborator) commented Jan 21, 2021

@seemethere, @malfet any ideas on why manywheel build is failing? What's the best way to repro it?

malfet (Contributor) commented Jan 22, 2021

@xwang233, @ngimel a big difference between the regular builds and the manywheel one is that it compiles code for all GPU architectures and links against the CUDA libraries statically. Although, it's totally possible that cuSolver_static is stale.

xwang233 (Collaborator, Author) commented Jan 22, 2021

Thanks @malfet. We were able to reproduce the exact error message of undefined reference to 'cudsh_{s,d,c,z}lascl_'. For the static linking, we found that not only -lcusolver_static but also -llapack_static needs to be added; liblapack_static.a is distributed in the same CUDA lib path.

FYI, the relevant doc is here: https://docs.nvidia.com/cuda/cusolver/index.html#static-link-lapack

malfet (Contributor) commented Jan 22, 2021

@xwang233 cool, can you please add the above-mentioned dependency to

```cmake
set_property(
    TARGET caffe2::cublas PROPERTY INTERFACE_LINK_LIBRARIES
    "${CUDA_TOOLKIT_ROOT_DIR}/lib64/libcublas_static.a")
```

(guarded with an appropriate `if(CUDA_VERSION VERSION_GREATER_EQUAL 10.2)` check if needed)

heitorschueroff (Contributor) commented

Thank you @xwang233 for adding the static dependencies; that fixed the previous errors. However, there is a new ROCm failure in the test test_broadcast_double_backwards_gpu. @mruberry I know ROCm can fail sometimes; is this something we should be concerned with?

mruberry (Collaborator) commented

> Thank you @xwang233 for adding the static dependencies, that fixed the previous errors. However, there is a new ROCM failure in the test test_broadcast_double_backwards_gpu. @mruberry I know ROCM can fail sometimes, is this something we should be concerned with?

That is a very strange ROCm error. I don't think it's related.

ngimel (Collaborator) commented Jan 22, 2021

It's been showing up on other PRs as well.

facebook-github-bot (Contributor) left a comment


@heitorschueroff has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

facebook-github-bot (Contributor) commented

@heitorschueroff merged this pull request in 186c3da.

Labels
cla signed · Merged · module: cuda (Related to torch.cuda, and CUDA support in general) · module: internals (Related to internal abstractions in c10 and ATen) · open source · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)