8000 Add cuSOLVER path for torch.linalg.qr · IvanYashchuk/pytorch@62ded95 · GitHub
[go: up one dir, main page]

Skip to content

Commit 62ded95

Browse files
committed
Add cuSOLVER path for torch.linalg.qr
Using cuSOLVER path with `pytest test/test_ops.py -k 'linalg_qr' --durations=5` cuts the runtime for these tests by 1 minute locally. Ref. pytorch#51552 ghstack-source-id: e94b357 Pull Request resolved: pytorch#56256
1 parent 8f10bad commit 62ded95

File tree

3 files changed

+13
-3
lines changed

3 files changed

+13
-3
lines changed

aten/src/ATen/native/BatchLinearAlgebra.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1790,7 +1790,7 @@ void linalg_qr_out_helper(const Tensor& input, const Tensor& Q, const Tensor& R,
17901790
orgqr_stub(input.device().type(), const_cast<Tensor&>(Q), tau);
17911791
}
17921792

1793-
std::tuple<Tensor, Tensor> _linalg_qr_helper_cpu(const Tensor& input, std::string mode) {
1793+
std::tuple<Tensor, Tensor> _linalg_qr_helper_default(const Tensor& input, std::string mode) {
17941794
bool compute_q, reduced_mode;
17951795
std::tie(compute_q, reduced_mode) = _parse_qr_mode(mode);
17961796
auto m = input.size(-2);

aten/src/ATen/native/cuda/BatchLinearAlgebra.cu

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2133,7 +2133,7 @@ AT_ERROR("qr: MAGMA library not found in "
21332133
#endif
21342134
}
21352135

2136-
std::tuple<Tensor,Tensor> _linalg_qr_helper_cuda(const Tensor& self, std::string mode) {
2136+
std::tuple<Tensor, Tensor> linalg_qr_helper_magma(const Tensor& self, std::string mode) {
21372137
bool compute_q, reduced;
21382138
std::tie(compute_q, reduced) = _parse_qr_mode(mode);
21392139

@@ -2178,6 +2178,16 @@ std::tuple<Tensor,Tensor> _linalg_qr_helper_cuda(const Tensor& self, std::string
21782178
return std::make_tuple(q_working_copy, r_working_copy);
21792179
}
21802180

2181+
std::tuple<Tensor, Tensor> _linalg_qr_helper_cuda(const Tensor& input, std::string mode) {
2182+
#if defined(USE_CUSOLVER)
2183+
// _linalg_qr_helper_default is a generic function that is implemented using
2184+
// geqrf_stub and orgqr_stub. It dispatches to cuSOLVER for CUDA inputs if USE_CUSOLVER is defined
2185+
return _linalg_qr_helper_default(input, mode);
2186+
#else
2187+
return linalg_qr_helper_magma(input, mode);
2188+
#endif
2189+
}
2190+
21812191
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ symeig ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
21822192

21832193
template <typename scalar_t>

aten/src/ATen/native/native_functions.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8902,7 +8902,7 @@
89028902
- func: _linalg_qr_helper(Tensor self, str mode) -> (Tensor, Tensor)
89038903
variants: function
89048904
dispatch:
8905-
CPU: _linalg_qr_helper_cpu
8905+
CPU: _linalg_qr_helper_default
89068906
CUDA: _linalg_qr_helper_cuda
89078907

89088908
- func: linalg_matrix_power(Tensor self, int n) -> Tensor

0 commit comments

Comments
 (0)
0