
Batched SVD using cuSolver #14175


Closed

jjbouza opened this issue Nov 19, 2018 · 26 comments
Assignees: vishwakftw
Labels: feature (A request for a proper, new feature.) · todo (Not as important as medium or high priority tasks, but we will work on these.) · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module.)

Comments

@jjbouza commented Nov 19, 2018

It seems several people are working on batch-mode linear algebra routines, e.g. #11796 and #14071 are active.

Any plans for adding a batch mode SVD? This would be useful for certain implementations of group equivariant networks.

I'm not completely familiar with the PyTorch codebase, but if I'm not mistaken the usual backend for linear algebra computations on the GPU is MAGMA. I don't think MAGMA implements a batched SVD operation, but cuSolver does for small matrices (at most 32x32). For larger matrices we can fall back to the current approach.
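For concreteness, here is a rough sketch of the API shape of cuSolver's batched Jacobi SVD (a sketch only: single precision, no error checking, and the wrapper name and buffer layout are mine, not PyTorch's):

```cpp
// Sketch: factor `batch` matrices, each m x n with m, n <= 32, packed
// contiguously in d_A (column-major, one matrix every m*n floats).
#include <cuda_runtime.h>
#include <cusolverDn.h>

void batched_svd_sketch(float* d_A, float* d_S, float* d_U, float* d_V,
                        int m, int n, int batch) {
  cusolverDnHandle_t handle;
  cusolverDnCreate(&handle);

  gesvdjInfo_t params;
  cusolverDnCreateGesvdjInfo(&params);

  // Query, then allocate, the shared device workspace.
  int lwork = 0;
  cusolverDnSgesvdjBatched_bufferSize(handle, CUSOLVER_EIG_MODE_VECTOR,
                                      m, n, d_A, m, d_S, d_U, m, d_V, n,
                                      &lwork, params, batch);
  float* d_work;
  int* d_info;
  cudaMalloc(&d_work, sizeof(float) * lwork);
  cudaMalloc(&d_info, sizeof(int) * batch);

  // One call computes A_i = U_i * diag(S_i) * V_i^T for all i.
  cusolverDnSgesvdjBatched(handle, CUSOLVER_EIG_MODE_VECTOR,
                           m, n, d_A, m, d_S, d_U, m, d_V, n,
                           d_work, lwork, d_info, params, batch);

  cudaFree(d_work);
  cudaFree(d_info);
  cusolverDnDestroyGesvdjInfo(params);
  cusolverDnDestroy(handle);
}
```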

If no one else is planning on working on this I can take a look at it. The correct way to do this would be to model something like #9949, right?

I realize several others have made similar suggestions: #10172, #4689. Those issues don't seem active, however.

@soumith added the todo label Nov 19, 2018
@soumith (Member) commented Nov 19, 2018

This seems like a pretty good thing to add. @vishwakftw can you mentor @jjbouza to get this done?

@vishwakftw (Contributor) commented

@soumith sure, I would be happy to mentor @jjbouza. One question: how will we achieve batched SVD on the GPU for matrices larger than 32x32?

@jjbouza (Author) commented Nov 20, 2018

Hello @vishwakftw, thanks for the help.

The easy solution is to fall back to CPU for matrices larger than 32x32.

There might be some other options, though. For example, using CUDA streams I think we could parallelize the regular (non-batched) SVD operation provided by cuSolver. AFAIK this is what was done before cuSolver provided the batched SVD natively; for an example, see the first answer here.

I would need to test and benchmark this approach though.
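Roughly, the streams idea would look like this (an untested sketch: one handle and one workspace per stream, single precision, error checking and per-matrix info handling simplified):

```cpp
// Sketch: overlap independent single-matrix SVDs by round-robining them
// across CUDA streams, each stream owning its own cuSolver handle.
#include <cuda_runtime.h>
#include <cusolverDn.h>
#include <vector>

void streamed_svd_sketch(float* d_A, float* d_S, float* d_U, float* d_VT,
                         int m, int n, int batch, int n_streams) {
  std::vector<cudaStream_t> streams(n_streams);
  std::vector<cusolverDnHandle_t> handles(n_streams);
  for (int i = 0; i < n_streams; ++i) {
    cudaStreamCreate(&streams[i]);
    cusolverDnCreate(&handles[i]);
    cusolverDnSetStream(handles[i], streams[i]);  // bind handle to stream
  }

  int lwork = 0;
  cusolverDnSgesvd_bufferSize(handles[0], m, n, &lwork);

  // Each stream gets its own workspace so concurrent calls don't clash.
  std::vector<float*> d_work(n_streams);
  std::vector<int*> d_info(n_streams);
  for (int i = 0; i < n_streams; ++i) {
    cudaMalloc(&d_work[i], sizeof(float) * lwork);
    cudaMalloc(&d_info[i], sizeof(int));
  }

  for (int b = 0; b < batch; ++b) {
    int s = b % n_streams;
    // gesvd requires m >= n; 'A' asks for full U (m x m) and V^T (n x n).
    cusolverDnSgesvd(handles[s], 'A', 'A', m, n,
                     d_A + (size_t)b * m * n, m,
                     d_S + (size_t)b * n,
                     d_U + (size_t)b * m * m, m,
                     d_VT + (size_t)b * n * n, n,
                     d_work[s], lwork, /*rwork=*/nullptr, d_info[s]);
  }
  cudaDeviceSynchronize();

  for (int i = 0; i < n_streams; ++i) {
    cudaFree(d_work[i]); cudaFree(d_info[i]);
    cusolverDnDestroy(handles[i]); cudaStreamDestroy(streams[i]);
  }
}
```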

@vishwakftw (Contributor) commented Nov 20, 2018

I think it wouldn't be wise to penalize users by asking them to fall back to CPU for matrices larger than 32x32; I'd instead advocate for a uniform API across CPU and GPU.

Your idea of utilizing cuSolver is a good one. On a related note, I think we can use MAGMA for SVD on batches with matrices larger than 32x32. The code would be very similar to the LAPACK code that batches in a loop, which you should be able to find in aten/src/ATen/native/BatchLinearAlgebra.cpp. As far as the codebase goes, I don't think cuSolver has been used extensively.

@jjbouza (Author) commented Nov 20, 2018

Agreed, using MAGMA in a for-loop on the GPU would work. I can benchmark the cuSolver CUDA-streams approach against this. First I'm going to implement the batched cuSolver routine for matrices up to 32x32. I'll model it on your batched inverse implementation, so I'll let you know if I have any questions.

@vishwakftw (Contributor) commented
Great, that sounds good to me! Please feel free to ping here or on Slack :)

@jjbouza (Author) commented Nov 22, 2018

Hey @vishwakftw, like you said, it doesn't appear that the codebase has any support for cuSolver. To add this I'm going to need to generate the cuSolver handles and pass them to the cuSolver calls.

I've been looking through the cuBLAS code to get an idea of how to do this, so I just want to make sure I've got it straight. Here's the idea:

Add a THCState_getCurrentcuSolverHandle function to aten/src/THC/THCGeneral to generate or access the cuSolver handle for a THCState (this would be an analog to the THCState_getCurrentBlasHandle function)

Add a getCurrentCUDASolverHandle function to aten/src/ATen/cuda/CUDAContext that calls the above THCState_getCurrentcuSolverHandle (this would be an analog to the at::cuda::getCurrentCUDABlasHandle function)

Does this look like the right pattern for cuSolver integration? Thanks

@vishwakftw (Contributor) commented

Hey @jjbouza. This plan sounds good to me!

@fmassa (Member) commented Nov 22, 2018

One quick comment: if possible, try not to add new things to THC, but instead do them directly in ATen.

@vishwakftw (Contributor) commented

@jjbouza, just as @fmassa has suggested, please add the functionality to ATen instead of THC. This means you will have to move the caching logic for the cuSolver handle from THC to ATen. You can still use functions from THC in ATen, so this should not be an issue.
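For reference, the cached handle in ATen might end up looking something like this (a sketch with hypothetical names, modeled on the cuBLAS handle logic; not actual tree code):

```cpp
// Sketch: lazily created, thread-local cuSolver handle (hypothetical name).
#include <cusolverDn.h>

namespace at { namespace cuda {

cusolverDnHandle_t getCurrentCUDASolverHandle() {
  // One handle per thread, created on first use. A real implementation
  // would also bind it to the current CUDA stream via cusolverDnSetStream,
  // handle multiple devices, and destroy the handle at thread exit.
  static thread_local cusolverDnHandle_t handle = nullptr;
  if (handle == nullptr) {
    cusolverDnCreate(&handle);
  }
  return handle;
}

}} // namespace at::cuda
```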

@vishwakftw (Contributor) commented

@jjbouza any updates?

@jjbouza (Author) commented Dec 28, 2018

@vishwakftw I have been slowly making progress. I should have something soon.

@daniyar-niantic commented

Would the batched SVD still be differentiable?

@vishwakftw (Contributor) commented

Yes, it should be. It is effectively batch_count SVDs performed at once, so I don't see why it shouldn't be differentiable. Of course, there is always the edge case of ill-conditioned matrices.
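To make that concrete: the backward is just the standard per-matrix SVD gradient applied to each slice. For a square, full-rank A = U Σ Vᵀ it reads (the textbook formula, e.g. from Townsend's "Differentiating the SVD" notes, not PyTorch-specific code):

```latex
\bar{A} = U \left[ \left(F \circ (U^\top \bar{U} - \bar{U}^\top U)\right) \Sigma
        + \operatorname{diag}(\bar{\Sigma})
        + \Sigma \left(F \circ (V^\top \bar{V} - \bar{V}^\top V)\right) \right] V^\top,
\qquad
F_{ij} = \begin{cases} \dfrac{1}{\sigma_j^2 - \sigma_i^2} & i \neq j \\ 0 & i = j \end{cases}
```

The F_ij entries blow up when two singular values (nearly) coincide, which is exactly where the ill-conditioned edge case bites; batching doesn't change that, since each slice is differentiated independently.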

@Balandat (Contributor) commented

This would be great to have!

@vishwakftw (Contributor) commented

cc: @jjbouza.

@KinglittleQ (Contributor) commented Feb 19, 2019

Hey, guys. I've implemented a batched version of SVD using cuSolver as a standalone package, including forward and backward functions. It's not perfect yet and only supports torch.CudaFloatTensor as input, but it may be helpful.

Torch-batch-svd: https://github.com/KinglittleQ/torch-batch-svd

@vishwakftw (Contributor) commented

This is awesome! Thanks for doing this. For inputs of size greater than 32 x 32, we can probably perform the SVD computation in a loop.

For CPU, this has to be performed in a loop anyway, so I think this is a pretty good implementation!

Do you mind sending in a PR for this? I'd be happy to clarify any questions you may have about porting this to PyTorch / ATen.

@KinglittleQ (Contributor) commented

I can try to port it into PyTorch if you tell me where I should add the forward and backward functions. But I'm not sure I'll have time to complete it, as I have other work to do right now.

@vishwakftw (Contributor) commented

Is it fine if I complete the port on your behalf?

@KinglittleQ (Contributor) commented

Of course, it's OK.

@vishwakftw (Contributor) commented

Thank you, I'll ping you when I send in the pull request.

@vishwakftw self-assigned this Mar 4, 2019
@SaiK95 commented May 14, 2019

@vishwakftw Hi, can you tell me if there's been a resolution to this thread? Or what the best way to perform batched SVD on small tensors (smaller than 32x32) is?

@vishwakftw (Contributor) commented

Hi @SaiK95. Sorry I haven’t been able to get to this in the past few months due to college. I’ll get to it in the summer, within a month.

Currently, the best and only possible way to do it would be to run a loop.
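In ATen C++ terms, that loop would look roughly like this (a sketch; batch_svd is a hypothetical helper, and the Python version is the same idea with torch.svd and torch.stack):

```cpp
// Sketch: batched SVD by looping the existing single-matrix svd over
// the leading dimension and stacking the per-slice results.
#include <ATen/ATen.h>
#include <tuple>
#include <vector>

std::tuple<at::Tensor, at::Tensor, at::Tensor> batch_svd(const at::Tensor& A) {
  // A has shape (batch, m, n); each slice A[i] is factored independently.
  std::vector<at::Tensor> Us, Ss, Vs;
  for (int64_t i = 0; i < A.size(0); ++i) {
    at::Tensor U, S, V;
    std::tie(U, S, V) = at::svd(A[i]);
    Us.push_back(U); Ss.push_back(S); Vs.push_back(V);
  }
  return std::make_tuple(at::stack(Us), at::stack(Ss), at::stack(Vs));
}
```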

@SaiK95 commented May 14, 2019

@vishwakftw No problem, thanks for the quick reply!

@pietern added the feature and triaged labels May 14, 2019
@vishwakftw (Contributor) commented

Closing this since the feature is available on master. The current implementation uses sequential MAGMA calls in a for-loop.
