8000 switch matrix multiplication order linalg_svd_jvp · pytorch/pytorch@8b78279 · GitHub
[go: up one dir, main page]

Skip to content

Commit 8b78279

Browse files
committed
switch matrix multiplication order linalg_svd_jvp
Signed-off-by: redwrasse <mail@redwrasse.io>
1 parent 3c8c509 commit 8b78279

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

torch/csrc/autograd/FunctionsManual.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3452,8 +3452,11 @@ std::tuple<Tensor, Tensor, Tensor> linalg_svd_jvp(
34523452
const auto V = Vh.mH();
34533453

34543454
// dP = U^H dA V
3455-
auto dP = m >= n ? at::matmul(U.mH(), at::matmul(dA, V))
3456-
: at::matmul(at::matmul(U.mH(), dA), V);
3455+
// U^H (dA V) is O(km(n + k))
3456+
// (U^H dA) V is O(kn(m + k))
3457+
// So prefer U^H (dA V) if m < n
3458+
auto dP = m < n ? at::matmul(U.mH(), at::matmul(dA, V))
3459+
: at::matmul(at::matmul(U.mH(), dA), V);
34573460

34583461
auto dS =
34593462
is_complex ? at::real(dP.diagonal(0, -2, -1)) : dP.diagonal(0, -2, -1);

0 commit comments

Comments
 (0)
0