Update on "torch.sgn for complex tensors" · pytorch/pytorch@26b367e · GitHub

Commit 26b367e

Update on "torch.sgn for complex tensors"
resolves #36323 by adding `torch.sgn` for complex tensors. `torch.sgn` returns `x/abs(x)` for `x != 0` and `0 + 0j` for `x == 0`. This PR doesn't test the correctness of the gradients; that will be done as part of auditing all the ops in the future, once we decide on the autograd behavior (JAX vs. TF) and add gradcheck.

Differential Revision: [D23460526](https://our.internmc.facebook.com/intern/diff/D23460526)

[ghstack-poisoned]
1 parent e154c5d commit 26b367e
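To make the contract concrete, here is a minimal self-contained C++ sketch using `std::complex` (the standalone `sgn` helper is illustrative, not the PyTorch implementation):

```cpp
// Sketch of the torch.sgn contract for complex inputs:
// sgn(z) = z / abs(z) for z != 0, and 0 + 0j for z == 0.
#include <cassert>
#include <complex>

std::complex<double> sgn(std::complex<double> z) {
  double a = std::abs(z);
  return a == 0.0 ? std::complex<double>(0.0, 0.0) : z / a;
}

int main() {
  assert(sgn({3.0, -4.0}) == std::complex<double>(0.6, -0.8));  // |3 - 4j| = 5
  assert(sgn({0.0, 0.0}) == std::complex<double>(0.0, 0.0));    // zero maps to zero
}
```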

File tree

2 files changed: +8 −4 lines changed

aten/src/ATen/cpu/vec256/vec256_base.h

Lines changed: 1 addition & 1 deletion
```diff
@@ -241,7 +241,7 @@ struct Vec256 {
   }

   template <typename other_t_sgn = T,
-            typename std::enable_if<c10::is_complex_t<other_t_sgn>::value, int>::type = 0>
+            typename std::enable_if<c10::is_complex<other_t_sgn>::value, int>::type = 0>
   Vec256<T> sgn() const {
     return map(at::native::sgn_impl);
   }
```
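The only functional change above is the trait rename from `c10::is_complex_t` to `c10::is_complex`. For readers unfamiliar with this `std::enable_if` gating pattern, below is a minimal self-contained sketch with stand-in types (`Vec` and the local `is_complex` are hypothetical simplifications, not the real `Vec256`):

```cpp
#include <complex>
#include <type_traits>

// Stand-in for c10::is_complex: true only for std::complex<U>.
template <typename T> struct is_complex : std::false_type {};
template <typename U> struct is_complex<std::complex<U>> : std::true_type {};

template <typename T>
struct Vec {
  T val;
  // Enabled only when the element type is complex; for real T this
  // overload simply does not exist, mirroring how Vec256<T>::sgn()
  // is gated on c10::is_complex<T>.
  template <typename U = T,
            typename std::enable_if<is_complex<U>::value, int>::type = 0>
  Vec sgn() const {
    auto a = std::abs(val);
    return {a == 0 ? T(0) : val / a};
  }
};

int main() {
  Vec<std::complex<float>> v{{3.0f, -4.0f}};
  auto s = v.sgn();          // compiles: complex element type
  // Vec<float>{1.0f}.sgn(); // would fail to compile: sgn() is SFINAE'd out
  (void)s;
}
```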

torch/csrc/autograd/FunctionsManual.cpp

Lines changed: 7 additions & 3 deletions
```diff
@@ -774,9 +774,13 @@ Tensor trace_backward(const Tensor & grad, IntArrayRef sizes) {
 }

 Tensor sgn_backward(Tensor result, Tensor grad, Tensor self) {
-  // [grad / abs(z) - Re(grad/self) * result
-  auto abs = at::abs(self);
-  return at::where(abs == 0.0, at::zeros({}, grad.options()), (grad/abs - (at::real(grad/self) * result)));
+  if (self.is_complex()) {
+    // vjp = grad / abs(self) - Re(grad/self) * result
+    auto abs = at::abs(self);
+    return at::where(abs == 0.0, at::zeros({}, grad.options()), (grad/abs - (at::real(grad/self) * result)));
+  } else {
+    return at::zeros_like(grad, at::MemoryFormat::Preserve);
+  }
 }

 Tensor var_backward(const Tensor & grad, const Tensor & self, bool unbiased) {
```
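For reference, here is a sketch of where the complex-branch formula in the comment comes from, for z ≠ 0 (at z = 0 the `at::where` clamps the output to zero, and for real inputs `sgn` is piecewise constant, so the backward is zero):

```latex
% Writing a = |z| and sgn(z) = z / a, and using d|z| = Re(conj(z) dz) / |z|:
\[
  \mathrm{d}\,\operatorname{sgn}(z)
    = \frac{\mathrm{d}z}{a} - \frac{z}{a}\cdot\frac{\mathrm{d}a}{a}
    = \frac{\mathrm{d}z}{|z|}
      - \operatorname{sgn}(z)\,\operatorname{Re}\!\left(\frac{\mathrm{d}z}{z}\right)
\]
% The last step uses conj(z)/|z|^2 = 1/z, so this matches
% grad/abs - at::real(grad/self) * result in the code.
```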

0 commit comments
