8000 Add cuSOLVER path for torch.linalg.qr · IvanYashchuk/pytorch@ee90471 · GitHub
[go: up one dir, main page]

Skip to content

Commit ee90471

Browse files
committed
Add cuSOLVER path for torch.linalg.qr
Using cuSOLVER path with `pytest test/test_ops.py -k 'linalg_qr' --durations=5` cuts the runtime for these tests by 1 minute locally. Ref. pytorch#51552 ghstack-source-id: 2f98cde Pull Request resolved: pytorch#56256
1 parent e6f4e6c commit ee90471

File tree

3 files changed

+13
-3
lines changed

3 files changed

+13
-3
lines changed

aten/src/ATen/native/BatchLinearAlgebra.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1777,7 +1777,7 @@ void linalg_qr_out_helper(const Tensor& input, const Tensor& Q, const Tensor& R,
17771777
orgqr_stub(input.device().type(), const_cast<Tensor&>(Q), tau);
17781778
}
17791779

1780-
std::tuple<Tensor, Tensor> _linalg_qr_helper_cpu(const Tensor& input, std::string mode) {
1780+
std::tuple<Tensor, Tensor> _linalg_qr_helper_default(const Tensor& input, std::string mode) {
17811781
bool compute_q, reduced_mode;
17821782
std::tie(compute_q, reduced_mode) = _parse_qr_mode(mode);
17831783
auto m = input.size(-2);

aten/src/ATen/native/cuda/BatchLinearAlgebra.cu

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2133,7 +2133,7 @@ AT_ERROR("qr: MAGMA library not found in "
21332133
#endif
21342134
}
21352135

2136-
std::tuple<Tensor,Tensor> _linalg_qr_helper_cuda(const Tensor& self, std::string mode) {
2136+
std::tuple<Tensor, Tensor> linalg_qr_helper_magma(const Tensor& self, std::string mode) {
21372137
bool compute_q, reduced;
21382138
std::tie(compute_q, reduced) = _parse_qr_mode(mode);
21392139

@@ -2178,6 +2178,16 @@ std::tuple<Tensor,Tensor> _linalg_qr_helper_cuda(const Tensor& self, std::string
21782178
return std::make_tuple(q_working_copy, r_working_copy);
21792179
}
21802180

2181+
std::tuple<Tensor, Tensor> _linalg_qr_helper_cuda(const Tensor& input, std::string mode) {
2182+
#if defined(USE_CUSOLVER)
2183+
// _linalg_qr_helper_default is a generic function that is implemented using
2184+
// geqrf_stub and orgqr_stub. It dispatches to cuSOLVER for CUDA inputs if USE_CUSOLVER is defined
2185+
return _linalg_qr_helper_default(input, mode);
2186+
#else
2187+
return linalg_qr_helper_magma(input, mode);
2188+
#endif
2189+
}
2190+
21812191
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ symeig ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
21822192

21832193
template <typename scalar_t>

aten/src/ATen/native/native_functions.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8882,7 +8882,7 @@
88828882
- func: _linalg_qr_helper(Tensor self, str mode) -> (Tensor, Tensor)
88838883
variants: function
88848884
dispatch:
8885-
CPU: _linalg_qr_helper_cpu
8885+
CPU: _linalg_qr_helper_default
88868886
CUDA: _linalg_qr_helper_cuda
88878887

88888888
- func: linalg_matrix_power(Tensor self, int n) -> Tensor

0 commit comments

Comments
 (0)
0