-
Notifications
You must be signed in to change notification settings - Fork 24.8k
Add cusolver gesvdj and gesvdjBatched to the backend of torch.svd #48436
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
625a32f
3bc7a6e
071e8ab
8ca446d
86065fc
663b6f9
99a5e0c
bd48b35
692d67e
76bf235
a622eed
927d569
7d31052
9c6c800
5bb94a5
8056f59
29f5c7d
8cbd708
8a85768
0fb0af9
46a6027
9c08a67
9249a56
4446145
8b8c1d8
26f1e36
72edde6
7074ba8
445ba77
549738b
e6d9f09
8860cfa
478798b
7956abd
b57310a
d06fbe2
3bb3a88
a3a45b1
5bb8a9c
12bb8f3
c298256
7274295
ff2133e
932021b
697dca5
2f45e43
1c8cbd7
eedb710
ca2ddd8
ffec626
7cdb8e6
813ca13
00d260e
60b07f7
7d8f3ad
bb31bf9
f2efbaf
66ebe3b
a99df7d
3075adc
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -152,7 +152,7 @@ Tensor _inverse_helper_cuda_lib(const Tensor& self) { | |
|
||
// call cusolver gesvdj function to calculate svd | ||
template<typename scalar_t> | ||
inline static void _apply_svd_lib_gesvdj(const Tensor& self, Tensor& U, Tensor& S, Tensor& VT, Tensor& infos, bool compute_uv) { | ||
inline static void _apply_svd_lib_gesvdj(const Tensor& self, Tensor& U, Tensor& S, Tensor& VT, Tensor& infos, bool compute_uv, bool some) { | ||
using value_t = typename c10::scalar_value_type<scalar_t>::type; | ||
auto self_data = self.data_ptr<scalar_t>(); | ||
auto U_data = U.data_ptr<scalar_t>(); | ||
|
@@ -177,7 +177,7 @@ inline static void _apply_svd_lib_gesvdj(const Tensor& self, Tensor& U, Tensor& | |
auto handle = at::cuda::getCurrentCUDASolverDnHandle(); | ||
auto jobz = compute_uv ? CUSOLVER_EIG_MODE_VECTOR : CUSOLVER_EIG_MODE_NOVECTOR; | ||
at::cuda::solver::gesvdj<scalar_t>( | ||
handle, jobz, /*econ=*/ 1, m, n, | ||
handle, jobz, /*econ=*/ some ? 1 : 0, m, n, | ||
self_data + i * self_stride, | ||
m, | ||
S_data + i * S_stride, | ||
|
@@ -195,14 +195,14 @@ inline static void _apply_svd_lib_gesvdj(const Tensor& self, Tensor& U, Tensor& | |
|
||
// wrapper around _apply_svd_lib_gesvdj that handles dtype dispatch, | ||
// creates a working copy of the input, and creates V^H from the V returned by gesvdj | ||
inline static void apply_svd_lib_gesvdj(const Tensor& self, Tensor& U, Tensor& S, Tensor& VT, Tensor& infos, bool compute_uv) { | ||
inline static void apply_svd_lib_gesvdj(const Tensor& self, Tensor& U, Tensor& S, Tensor& VT, Tensor& infos, bool compute_uv, bool some) { | ||
const int64_t m = self.size(-2); | ||
const int64_t n = self.size(-1); | ||
Tensor self_working_copy = cloneBatchedColumnMajor(self); | ||
VT = VT.transpose(-2, -1); // gesvdj returns V instead of V^H | ||
|
||
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self.scalar_type(), "svd_cuda_gesvdj", [&] { | ||
_apply_svd_lib_gesvdj<scalar_t>(self_working_copy, U, S, VT, infos, compute_uv); | ||
_apply_svd_lib_gesvdj<scalar_t>(self_working_copy, U, S, VT, infos, compute_uv, some); | ||
}); | ||
|
||
VT = VT.conj(); | ||
|
@@ -280,7 +280,7 @@ std::tuple<Tensor, Tensor, Tensor> _svd_helper_cuda_lib(const Tensor& self, bool | |
if (m <= 32 && n <= 32 && batch_size > 1 && (!some || m == n)) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @ngimel asks: should this be There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It follows the same heuristic here in tensorflow There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. From the cuSOLVER's documentation, it seems that |
||
apply_svd_lib_gesvdjBatched(self, U_working_copy, S_working_copy, VT_working_copy, infos, compute_uv); | ||
} else { | ||
apply_svd_lib_gesvdj(self, U_working_copy, S_working_copy, VT_working_copy, infos, compute_uv); | ||
apply_svd_lib_gesvdj(self, U_working_copy, S_working_copy, VT_working_copy, infos, compute_uv, some); | ||
} | ||
|
||
// A device-host sync will be performed. | ||
|
Uh oh!
There was an error while loading. Please reload this page.