Add cusolver gesvdj and gesvdjBatched to the backend of torch.svd by xwang233 · Pull Request #48436 · pytorch/pytorch · GitHub

Add cusolver gesvdj and gesvdjBatched to the backend of torch.svd #48436

Closed
xwang233 wants to merge 60 commits

Changes from all commits

60 commits
625a32f
cusolver parallel stream launch macro
xwang233 Nov 25, 2020
3bc7a6e
fix unused test_inverse statements
xwang233 Nov 25, 2020
071e8ab
move pivot allocation outside of for loop
xwang233 Nov 25, 2020
8ca446d
test_inverse test batch_size = 2 (parallel cusolver path)
xwang233 Nov 25, 2020
86065fc
rename variable
xwang233 Nov 25, 2020
663b6f9
update macro
xwang233 Nov 26, 2020
99a5e0c
[WIP] temp
xwang233 Dec 1, 2020
bd48b35
Merge remote-tracking branch 'upstream/viable/strict' into cusolver-svd
xwang233 Dec 1, 2020
692d67e
cusolver gesvd impl done [benchmark slower than magma]
xwang233 Dec 4, 2020
76bf235
Merge remote-tracking branch 'upstream/viable/strict' into cusolver-svd
xwang233 Dec 4, 2020
a622eed
skip test_linalg.test_norm_extreme_values
xwang233 Dec 4, 2020
927d569
gesvdj+parallel stream [benchmark looks very good]
xwang233 Dec 4, 2020
7d31052
enable on CPU
xwang233 Dec 4, 2020
9c6c800
vt conj
xwang233 Dec 4, 2020
5bb94a5
lint
xwang233 Dec 5, 2020
8056f59
precision override for float32
xwang233 Dec 5, 2020
29f5c7d
gesvdjBatched
xwang233 Dec 5, 2020
8cbd708
remove gesvd code
xwang233 Dec 6, 2020
8a85768
refactor cusolver inverse heuristic code
xwang233 Dec 6, 2020
0fb0af9
Merge remote-tracking branch 'upstream/viable/strict' into cusolver-svd
xwang233 Dec 6, 2020
46a6027
remove gesvd code in cudasolver.{h,cpp}
xwang233 Dec 6, 2020
9c08a67
[doc] cusolver gesvdj and batched
xwang233 Dec 6, 2020
9249a56
test decorator
xwang233 Dec 7, 2020
4446145
Merge remote-tracking branch 'upstream/viable/strict' into cusolver-svd
xwang233 Dec 7, 2020
8b8c1d8
Merge remote-tracking branch 'upstream/master' into cusolver-svd
xwang233 Jan 5, 2021
26f1e36
reword test skipping
xwang233 Jan 5, 2021
72edde6
add at::parallel_for for parallel stream launch
xwang233 Jan 5, 2021
7074ba8
Merge remote-tracking branch 'upstream/viable/strict' into cusolver-svd
xwang233 Jan 5, 2021
445ba77
add a cuda guard for `at::parallel_for`
xwang233 Jan 5, 2021
549738b
lint
xwang233 Jan 5, 2021
e6d9f09
revert at::parallel_for changes
xwang233 Jan 6, 2021
8860cfa
remove CUDA_PARALLEL_STREAM_LAUNCH
xwang233 Jan 6, 2021
478798b
Merge remote-tracking branch 'upstream/master' into cusolver-svd
xwang233 Jan 9, 2021
7956abd
test decorators
xwang233 Jan 9, 2021
b57310a
transpose
xwang233 Jan 10, 2021
d06fbe2
[Action Required] wrong test: svd is not unique
xwang233 Jan 10, 2021
3bb3a88
precision override
xwang233 Jan 10, 2021
a3a45b1
doc change
xwang233 Jan 10, 2021
5bb8a9c
comments
xwang233 Jan 10, 2021
12bb8f3
semi colon
xwang233 Jan 10, 2021
c298256
comments
xwang233 Jan 10, 2021
7274295
Merge remote-tracking branch 'upstream/master' into cusolver-svd
xwang233 Jan 10, 2021
ff2133e
lint
xwang233 Jan 11, 2021
932021b
comments
xwang233 Jan 11, 2021
697dca5
Merge remote-tracking branch 'upstream/master' into cusolver-svd
xwang233 Jan 11, 2021
2f45e43
Merge remote-tracking branch 'upstream/master' into cusolver-svd
xwang233 Jan 13, 2021
1c8cbd7
comments
xwang233 Jan 13, 2021
eedb710
comment svd non unique
xwang233 Jan 13, 2021
ca2ddd8
doc
xwang233 Jan 14, 2021
ffec626
test abs of singular vectors
xwang233 Jan 14, 2021
7cdb8e6
Merge remote-tracking branch 'upstream/master' into cusolver-svd
xwang233 Jan 14, 2021
813ca13
fix "some" in gesvdj path
xwang233 Jan 14, 2021
00d260e
delete unused jobchar
xwang233 Jan 14, 2021
60b07f7
Merge remote-tracking branch 'upstream/master' into cusolver-svd
xwang233 Jan 14, 2021
7d8f3ad
remove partial indexing for comparison
xwang233 Jan 20, 2021
bb31bf9
[doc] :math:
xwang233 Jan 20, 2021
f2efbaf
comment, column-major VT
xwang233 Jan 20, 2021
66ebe3b
Merge remote-tracking branch 'upstream/master' into cusolver-svd
xwang233 Jan 20, 2021
a99df7d
Merge remote-tracking branch 'upstream/master' into cusolver-svd
xwang233 Jan 22, 2021
3075adc
try liblapack_static, may need `if cuda >= 10.1`
xwang233 Jan 22, 2021
1 change: 1 addition & 0 deletions aten/src/ATen/CMakeLists.txt
@@ -325,6 +325,7 @@ if(USE_CUDA AND NOT USE_ROCM)
${CUDA_TOOLKIT_ROOT_DIR}/lib64/libcublas_static.a
${CUDA_TOOLKIT_ROOT_DIR}/lib64/libcufft_static_nocallback.a
${CUDA_TOOLKIT_ROOT_DIR}/lib64/libcusolver_static.a
${CUDA_TOOLKIT_ROOT_DIR}/lib64/liblapack_static.a # needed for libcusolver_static
)
else()
list(APPEND ATen_CUDA_DEPENDENCY_LIBS
190 changes: 190 additions & 0 deletions aten/src/ATen/cuda/CUDASolver.cpp
@@ -145,6 +145,196 @@ void getrs<c10::complex<float>>(
info));
}


template<>
void gesvdj<float>(
cusolverDnHandle_t handle, cusolverEigMode_t jobz, int econ, int m, int n, float* A, int lda, float* S, float* U,
int ldu, float *V, int ldv, int *info, gesvdjInfo_t params
) {
int lwork;
TORCH_CUSOLVER_CHECK(cusolverDnSgesvdj_bufferSize(handle, jobz, econ, m, n, A, lda, S, U, ldu, V, ldv, &lwork, params));

auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
auto dataPtr = allocator.allocate(sizeof(float)*lwork);

TORCH_CUSOLVER_CHECK(cusolverDnSgesvdj(
handle, jobz, econ, m, n, A, lda, S, U, ldu, V, ldv,
static_cast<float*>(dataPtr.get()),
lwork, info, params));
}

template<>
void gesvdj<double>(
cusolverDnHandle_t handle, cusolverEigMode_t jobz, int econ, int m, int n, double* A, int lda, double* S, double* U,
int ldu, double *V, int ldv, int *info, gesvdjInfo_t params
) {
int lwork;
TORCH_CUSOLVER_CHECK(cusolverDnDgesvdj_bufferSize(handle, jobz, econ, m, n, A, lda, S, U, ldu, V, ldv, &lwork, params));

auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
auto dataPtr = allocator.allocate(sizeof(double)*lwork);

TORCH_CUSOLVER_CHECK(cusolverDnDgesvdj(
handle, jobz, econ, m, n, A, lda, S, U, ldu, V, ldv,
static_cast<double*>(dataPtr.get()),
lwork, info, params));
}

template<>
void gesvdj<c10::complex<float>>(
cusolverDnHandle_t handle, cusolverEigMode_t jobz, int econ, int m, int n, c10::complex<float>* A, int lda, float* S, c10::complex<float>* U,
int ldu, c10::complex<float> *V, int ldv, int *info, gesvdjInfo_t params
) {
int lwork;
TORCH_CUSOLVER_CHECK(cusolverDnCgesvdj_bufferSize(
handle, jobz, econ, m, n,
reinterpret_cast<cuComplex*>(A),
lda, S,
reinterpret_cast<cuComplex*>(U),
ldu,
reinterpret_cast<cuComplex*>(V),
ldv, &lwork, params));

auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
auto dataPtr = allocator.allocate(sizeof(cuComplex)*lwork);

TORCH_CUSOLVER_CHECK(cusolverDnCgesvdj(
handle, jobz, econ, m, n,
reinterpret_cast<cuComplex*>(A),
lda, S,
reinterpret_cast<cuComplex*>(U),
ldu,
reinterpret_cast<cuComplex*>(V),
ldv,
static_cast<cuComplex*>(dataPtr.get()),
lwork, info, params));
}

template<>
void gesvdj<c10::complex<double>>(
cusolverDnHandle_t handle, cusolverEigMode_t jobz, int econ, int m, int n, c10::complex<double>* A, int lda, double* S, c10::complex<double>* U,
int ldu, c10::complex<double> *V, int ldv, int *info, gesvdjInfo_t params
) {
int lwork;
TORCH_CUSOLVER_CHECK(cusolverDnZgesvdj_bufferSize(
handle, jobz, econ, m, n,
reinterpret_cast<cuDoubleComplex*>(A),
lda, S,
reinterpret_cast<cuDoubleComplex*>(U),
ldu,
reinterpret_cast<cuDoubleComplex*>(V),
ldv, &lwork, params));

auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
auto dataPtr = allocator.allocate(sizeof(cuDoubleComplex)*lwork);

TORCH_CUSOLVER_CHECK(cusolverDnZgesvdj(
handle, jobz, econ, m, n,
reinterpret_cast<cuDoubleComplex*>(A),
lda, S,
reinterpret_cast<cuDoubleComplex*>(U),
ldu,
reinterpret_cast<cuDoubleComplex*>(V),
ldv,
static_cast<cuDoubleComplex*>(dataPtr.get()),
lwork, info, params));
}


template<>
void gesvdjBatched<float>(
cusolverDnHandle_t handle, cusolverEigMode_t jobz, int m, int n, float* A, int lda, float* S, float* U,
int ldu, float *V, int ldv, int *info, gesvdjInfo_t params, int batchSize
) {
int lwork;
TORCH_CUSOLVER_CHECK(cusolverDnSgesvdjBatched_bufferSize(handle, jobz, m, n, A, lda, S, U, ldu, V, ldv, &lwork, params, batchSize));

auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
auto dataPtr = allocator.allocate(sizeof(float)*lwork);

TORCH_CUSOLVER_CHECK(cusolverDnSgesvdjBatched(
handle, jobz, m, n, A, lda, S, U, ldu, V, ldv,
static_cast<float*>(dataPtr.get()),
lwork, info, params, batchSize));
}

template<>
void gesvdjBatched<double>(
cusolverDnHandle_t handle, cusolverEigMode_t jobz, int m, int n, double* A, int lda, double* S, double* U,
int ldu, double *V, int ldv, int *info, gesvdjInfo_t params, int batchSize
) {
int lwork;
TORCH_CUSOLVER_CHECK(cusolverDnDgesvdjBatched_bufferSize(handle, jobz, m, n, A, lda, S, U, ldu, V, ldv, &lwork, params, batchSize));

auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
auto dataPtr = allocator.allocate(sizeof(double)*lwork);

TORCH_CUSOLVER_CHECK(cusolverDnDgesvdjBatched(
handle, jobz, m, n, A, lda, S, U, ldu, V, ldv,
static_cast<double*>(dataPtr.get()),
lwork, info, params, batchSize));
}

template<>
void gesvdjBatched<c10::complex<float>>(
cusolverDnHandle_t handle, cusolverEigMode_t jobz, int m, int n, c10::complex<float>* A, int lda, float* S, c10::complex<float>* U,
int ldu, c10::complex<float> *V, int ldv, int *info, gesvdjInfo_t params, int batchSize
) {
int lwork;
TORCH_CUSOLVER_CHECK(cusolverDnCgesvdjBatched_bufferSize(
handle, jobz, m, n,
reinterpret_cast<cuComplex*>(A),
lda, S,
reinterpret_cast<cuComplex*>(U),
ldu,
reinterpret_cast<cuComplex*>(V),
ldv, &lwork, params, batchSize));

auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
auto dataPtr = allocator.allocate(sizeof(cuComplex)*lwork);

TORCH_CUSOLVER_CHECK(cusolverDnCgesvdjBatched(
handle, jobz, m, n,
reinterpret_cast<cuComplex*>(A),
lda, S,
reinterpret_cast<cuComplex*>(U),
ldu,
reinterpret_cast<cuComplex*>(V),
ldv,
static_cast<cuComplex*>(dataPtr.get()),
lwork, info, params, batchSize));
}

template<>
void gesvdjBatched<c10::complex<double>>(
cusolverDnHandle_t handle, cusolverEigMode_t jobz, int m, int n, c10::complex<double>* A, int lda, double* S, c10::complex<double>* U,
int ldu, c10::complex<double> *V, int ldv, int *info, gesvdjInfo_t params, int batchSize
) {
int lwork;
TORCH_CUSOLVER_CHECK(cusolverDnZgesvdjBatched_bufferSize(
handle, jobz, m, n,
reinterpret_cast<cuDoubleComplex*>(A),
lda, S,
reinterpret_cast<cuDoubleComplex*>(U),
ldu,
reinterpret_cast<cuDoubleComplex*>(V),
ldv, &lwork, params, batchSize));

auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
auto dataPtr = allocator.allocate(sizeof(cuDoubleComplex)*lwork);

TORCH_CUSOLVER_CHECK(cusolverDnZgesvdjBatched(
handle, jobz, m, n,
reinterpret_cast<cuDoubleComplex*>(A),
lda, S,
reinterpret_cast<cuDoubleComplex*>(U),
ldu,
reinterpret_cast<cuDoubleComplex*>(V),
ldv,
static_cast<cuDoubleComplex*>(dataPtr.get()),
lwork, info, params, batchSize));
}

} // namespace solver
} // namespace cuda
} // namespace at
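
Each wrapper above follows cuSOLVER's two-phase calling convention: first query the required workspace size with the _bufferSize variant, then allocate that much scratch memory (here through the CUDA caching allocator) and run the actual decomposition. Below is a minimal standalone sketch of the same protocol against raw cuSOLVER, outside ATen; the matrix contents, sizes, and the error-check macro are illustrative only, and device-memory frees are omitted for brevity.

#include <cstdio>
#include <cuda_runtime.h>
#include <cusolverDn.h>

#define CHECK_CUSOLVER(expr)                                      \
  do {                                                            \
    cusolverStatus_t st = (expr);                                 \
    if (st != CUSOLVER_STATUS_SUCCESS) {                          \
      std::printf("cusolver error %d\n", static_cast<int>(st));   \
      return 1;                                                   \
    }                                                             \
  } while (0)

int main() {
  const int m = 3, n = 2;                  // one 3x2 matrix, column-major
  float hA[m * n] = {1, 2, 3, 4, 5, 6};

  cusolverDnHandle_t handle;
  CHECK_CUSOLVER(cusolverDnCreate(&handle));
  gesvdjInfo_t params;
  CHECK_CUSOLVER(cusolverDnCreateGesvdjInfo(&params));

  float *dA, *dS, *dU, *dV, *dWork;
  int *dInfo;
  cudaMalloc(&dA, sizeof(float) * m * n);
  cudaMalloc(&dS, sizeof(float) * n);      // min(m, n) singular values
  cudaMalloc(&dU, sizeof(float) * m * m);  // econ=0: full m x m U
  cudaMalloc(&dV, sizeof(float) * n * n);  // econ=0: full n x n V (not V^T)
  cudaMalloc(&dInfo, sizeof(int));
  cudaMemcpy(dA, hA, sizeof(float) * m * n, cudaMemcpyHostToDevice);

  // Phase 1: query the workspace size.
  int lwork = 0;
  CHECK_CUSOLVER(cusolverDnSgesvdj_bufferSize(
      handle, CUSOLVER_EIG_MODE_VECTOR, /*econ=*/0, m, n,
      dA, m, dS, dU, m, dV, n, &lwork, params));
  cudaMalloc(&dWork, sizeof(float) * lwork);

  // Phase 2: compute the SVD in the allocated workspace.
  CHECK_CUSOLVER(cusolverDnSgesvdj(
      handle, CUSOLVER_EIG_MODE_VECTOR, /*econ=*/0, m, n,
      dA, m, dS, dU, m, dV, n, dWork, lwork, dInfo, params));

  float hS[n];
  cudaMemcpy(hS, dS, sizeof(float) * n, cudaMemcpyDeviceToHost);
  std::printf("singular values: %f %f\n", hS[0], hS[1]);

  cusolverDnDestroyGesvdjInfo(params);
  cusolverDnDestroy(handle);
  return 0;
}
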
35 changes: 35 additions & 0 deletions aten/src/ATen/cuda/CUDASolver.h
@@ -42,6 +42,41 @@ template<>
void getrs<c10::complex<float>>(CUDASOLVER_GETRS_ARGTYPES(c10::complex<float>));


#define CUDASOLVER_GESVDJ_ARGTYPES(Dtype, Vtype) \
cusolverDnHandle_t handle, cusolverEigMode_t jobz, int econ, int m, int n, Dtype* A, int lda, Vtype* S, Dtype* U, \
int ldu, Dtype *V, int ldv, int *info, gesvdjInfo_t params

template<class Dtype, class Vtype>
void gesvdj(CUDASOLVER_GESVDJ_ARGTYPES(Dtype, Vtype)) {
TORCH_INTERNAL_ASSERT(false, "at::cuda::solver::gesvdj: not implemented for ", typeid(Dtype).name());
}
template<>
void gesvdj<float>(CUDASOLVER_GESVDJ_ARGTYPES(float, float));
template<>
void gesvdj<double>(CUDASOLVER_GESVDJ_ARGTYPES(double, double));
template<>
void gesvdj<c10::complex<float>>(CUDASOLVER_GESVDJ_ARGTYPES(c10::complex<float>, float));
template<>
void gesvdj<c10::complex<double>>(CUDASOLVER_GESVDJ_ARGTYPES(c10::complex<double>, double));


#define CUDASOLVER_GESVDJ_BATCHED_ARGTYPES(Dtype, Vtype) \
cusolverDnHandle_t handle, cusolverEigMode_t jobz, int m, int n, Dtype* A, int lda, Vtype* S, Dtype* U, \
int ldu, Dtype *V, int ldv, int *info, gesvdjInfo_t params, int batchSize

template<class Dtype, class Vtype>
void gesvdjBatched(CUDASOLVER_GESVDJ_BATCHED_ARGTYPES(Dtype, Vtype)) {
TORCH_INTERNAL_ASSERT(false, "at::cuda::solver::gesvdj: not implemented for ", typeid(Dtype).name());
}
template<>
void gesvdjBatched<float>(CUDASOLVER_GESVDJ_BATCHED_ARGTYPES(float, float));
template<>
void gesvdjBatched<double>(CUDASOLVER_GESVDJ_BATCHED_ARGTYPES(double, double));
template<>
void gesvdjBatched<c10::complex<float>>(CUDASOLVER_GESVDJ_BATCHED_ARGTYPES(c10::complex<float>, float));
template<>
void gesvdjBatched<c10::complex<double>>(CUDASOLVER_GESVDJ_BATCHED_ARGTYPES(c10::complex<double>, double));

} // namespace solver
} // namespace cuda
} // namespace at
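
The header uses a common dispatch idiom: a primary template that fails loudly via TORCH_INTERNAL_ASSERT, plus explicit specializations for the supported dtypes, so templated driver code can call gesvdj/gesvdjBatched by name and an unsupported dtype yields a clear runtime error naming the dtype. A minimal self-contained sketch of the same idiom, with illustrative names:

#include <cstdio>
#include <cstdlib>
#include <typeinfo>

// Primary template: reached only for dtypes without a specialization.
template <class T>
void scale(T* /*data*/, int /*n*/, T /*alpha*/) {
  std::fprintf(stderr, "scale: not implemented for %s\n", typeid(T).name());
  std::abort();
}

// Explicit specialization: the only supported dtype in this sketch.
template <>
void scale<float>(float* data, int n, float alpha) {
  for (int i = 0; i < n; ++i) data[i] *= alpha;
}

int main() {
  float x[3] = {1.f, 2.f, 3.f};
  scale(x, 3, 2.f);                             // resolves to the float specialization
  std::printf("%g %g %g\n", x[0], x[1], x[2]);  // prints: 2 4 6
  return 0;
}
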
49 changes: 15 additions & 34 deletions aten/src/ATen/native/LinearAlgebraUtils.h
@@ -239,7 +239,14 @@ static inline std::tuple<std::vector<int64_t>,
}

// Function to generate empty tensors of required size, strides and dtype for the SVD operation
static inline std::tuple<Tensor, Tensor, Tensor> _create_U_S_VT(const Tensor& input, bool some, bool compute_uv) {
static inline std::tuple<Tensor, Tensor, Tensor> _create_U_S_VT(const Tensor& input, bool some, bool compute_uv,
const bool svd_use_cusolver=false) {

// U, S, VT are initialized as empty tensors.
// For CPU LAPACK and GPU MAGMA backend, the tensors are initialized on CPU.
// For GPU cuSOLVER backend, the tensors are initialized on GPU.
const auto usvt_device = svd_use_cusolver ? at::kCUDA : at::kCPU;

auto sizes = input.sizes().vec();
int64_t m = input.size(-2), n = input.size(-1);

@@ -251,47 +258,21 @@ static inline std::tuple<Tensor, Tensor, Tensor> _create_U_S_VT(const Tensor& in
strides[input.dim() - 1] = m;
strides[input.dim() - 2] = 1;

Tensor U_empty;
if (!input.is_cuda()) {
U_empty = at::empty_strided(sizes, strides, input.options());
} else {
// NB: U_empty is an empty tensor created on the CPU intentionally, because magma_(d/s)gesdd
// (which is the driver routine for the divide and conquer SVD operation)
// takes in arrays on the CPU as input. This routine is a hybrid CPU-GPU routine that
// moves the inputs between devices internally.
U_empty = at::empty_strided(sizes, strides, input.options().device(at::kCPU));
Contributor review comment:

I think this comment was useful for understanding why we want to allocate these tensors on the CPU when we use MAGMA (at least, it was helpful to me when I first looked at this code), so it may be worth resurrecting it and putting it close to usvt_device = ... above.

}
Tensor U_empty = at::empty_strided(sizes, strides, input.options().device(usvt_device));
U_empty.zero_();

// VT should be a column-major or a batch of column-major matrices
sizes[input.dim() - 2] = n;
sizes[input.dim() - 1] = n;
strides = at::detail::defaultStrides(sizes);
strides[input.dim() - 1] = n;
strides[input.dim() - 2] = 1;
Tensor VT_empty;
if (!input.is_cuda()) {
VT_empty = at::empty_strided(sizes, strides, input.options());
} else {
// NB: VT_empty is an empty tensor created on the CPU intentionally, because magma_(d/s)gesdd
// (which is the driver routine for the divide and conquer SVD operation)
// takes in arrays on the CPU as input. This routine is a hybrid CPU-GPU routine that
// moves the inputs between devices internally.
VT_empty = at::empty_strided(sizes, strides, input.options().device(at::kCPU));
}
// VT should be a column-major or a batch of column-major matrices
Tensor VT_empty = at::zeros(sizes, input.options().device(usvt_device));
VT_empty.transpose_(-2, -1);
Contributor review comment:

I think the code is correct but the comment is wrong. Moreover, it contradicts the comment at line 264, which says // VT should be a column-major or a batch of column-major matrices.
at::zeros returns a row-major tensor, and transpose_ turns it into a column-major one, which is what LAPACK/MAGMA (and, I assume, cuSOLVER) expect; a small sketch after this diff verifies this.
See also #45821 (comment) for a more in-depth explanation on the subject.


sizes.pop_back();
sizes[input.dim() - 2] = std::min(m, n);
Tensor S_empty;
ScalarType dtype = toValueType(typeMetaToScalarType(input.dtype()));
if (!input.is_cuda()) {
S_empty = at::empty(sizes, input.options().dtype(dtype));
} else {
// NB: S_empty is an empty tensor created on the CPU intentionally, because magma_(d/s)gesdd
// (which is the driver routine for the divide and conquer SVD operation)
// takes in arrays on the CPU as input. This routine is a hybrid CPU-GPU routine that
// moves the inputs between devices internally.
S_empty = at::empty(sizes, input.options().dtype(dtype).device(at::kCPU));
}
Tensor S_empty = at::empty(sizes, input.options().dtype(dtype).device(usvt_device));

return std::tuple<Tensor, Tensor, Tensor>(U_empty, S_empty, VT_empty);
}

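
To make the reviewer's point concrete: at::zeros allocates a contiguous row-major tensor, and the in-place transpose_ swaps the last two strides, yielding the column-major layout that LAPACK-style backends expect. A small sketch, with an arbitrary 4x4 shape:

#include <ATen/ATen.h>
#include <iostream>

int main() {
  auto vt = at::zeros({4, 4}, at::kFloat);
  std::cout << vt.strides() << std::endl;        // [4, 1]: row-major
  vt.transpose_(-2, -1);
  std::cout << vt.strides() << std::endl;        // [1, 4]: column-major
  std::cout << vt.is_contiguous() << std::endl;  // 0: now a transposed view
  return 0;
}
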
10 changes: 9 additions & 1 deletion aten/src/ATen/native/cuda/BatchLinearAlgebra.cu
@@ -2200,7 +2200,7 @@ AT_ERROR("svd: MAGMA library not found in "
#endif
}

std::tuple<Tensor, Tensor, Tensor> _svd_helper_cuda(const Tensor& self, bool some, bool compute_uv) {
std::tuple<Tensor, Tensor, Tensor> _svd_helper_cuda_legacy(const Tensor& self, bool some, bool compute_uv) {
std::vector<int64_t> infos(batchCount(self), 0);
int64_t m = self.size(-2), n = self.size(-1);
int64_t k = std::min(m, n);
@@ -2256,6 +2256,14 @@ std::tuple<Tensor, Tensor, Tensor> _svd_helper_cuda(const Tensor& self, bool som
return std::make_tuple(U_working_copy, S_working_copy, VT_working_copy);
}

std::tuple<Tensor, Tensor, Tensor> _svd_helper_cuda(const Tensor& self, bool some, bool compute_uv) {
#ifdef USE_CUSOLVER
return _svd_helper_cuda_lib(self, some, compute_uv);
#else
return _svd_helper_cuda_legacy(self, some, compute_uv);
#endif
}

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ lu_solve ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

template <typename scalar_t>
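
With this change, the backend is selected at build time: _svd_helper_cuda routes to the cuSOLVER gesvdj path when PyTorch is compiled with USE_CUSOLVER, and falls back to the MAGMA path otherwise. A hedged end-to-end sketch of the C++ surface this serves; shapes and tolerances are arbitrary, and note that torch.svd returns V (so A ≈ U diag(S) Vᵀ), not Vᵀ itself:

#include <ATen/ATen.h>
#include <tuple>

int main() {
  // A batch of eight 5x3 matrices on the GPU.
  auto a = at::randn({8, 5, 3}, at::TensorOptions().device(at::kCUDA));

  at::Tensor u, s, v;
  std::tie(u, s, v) = at::svd(a, /*some=*/true, /*compute_uv=*/true);

  // Reconstruct and check: a ~= u @ diag(s) @ v^T (float32 tolerances).
  auto recon = u.matmul(at::diag_embed(s)).matmul(v.transpose(-2, -1));
  TORCH_CHECK(at::allclose(a, recon, /*rtol=*/1e-4, /*atol=*/1e-5));
  return 0;
}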