Update on "torch.sgn for complex tensors" · pytorch/pytorch@30512e2 · GitHub
[go: up one dir, main page]

Skip to content

Commit 30512e2

Browse files
committed
Update on "torch.sgn for complex tensors"
Resolves #36323 by adding `torch.sgn` for complex tensors. `torch.sgn` returns `x / abs(x)` for `x != 0` and returns `0 + 0j` for `x == 0`. Also updates the backward definitions of `torch.div` and `torch.abs`. Differential Revision: [D23460526](https://our.internmc.facebook.com/intern/diff/D23460526) [ghstack-poisoned]
1 parent 0092a14 commit 30512e2

File tree

3 files changed

+0
-39
lines changed

3 files changed

+0
-39
lines changed

torch/csrc/autograd/FunctionsManual.cpp

Lines changed: 0 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -629,38 +629,6 @@ Tensor evenly_distribute_backward(Tensor grad, const Tensor & input, const Tenso
629629
}
630630
}
631631

632-
633-
Tensor index_select_backward(Tensor grad, int64_t dim, Tensor indices, IntArrayRef sizes, bool keepdim) {
634-
if (!keepdim && sizes.size() > 0) {
635-
grad = grad.unsqueeze(dim);
636-
indices = indices.unsqueeze(dim);
637-
}
638-
return at::zeros(sizes, grad.options()).scatter_(dim, indices, grad);
639-
}
640-
641-
Tensor slice_backward(Tensor grad, IntArrayRef input_sizes, int64_t dim, int64_t start, int64_t end, int64_t step) {
642-
auto grad_input = at::zeros(input_sizes, grad.options());
643-
grad_input.slice(dim, start, end, step).copy_(grad);
644-
return grad_input;
645-
}
646-
647-
Tensor select_backward(Tensor grad, IntArrayRef input_sizes, int64_t dim, int64_t index) {
648-
auto grad_input = at::zeros(input_sizes, grad.options());
649-
grad_input.select(dim, index).copy_(grad);
650-
return grad_input;
651-
}
652-
653-
Tensor trace_backward(const Tensor & grad, IntArrayRef sizes) {
654-
if (sizes.size() != 2) {
655-
throw std::runtime_error("expected matrix input");
656-
}
657-
658-
auto grad_input = at::zeros(sizes[0] * sizes[1], grad.options());
659-
auto indices = at::arange(0, grad_input.numel(), sizes[1] + 1, grad.options().dtype(at::kLong));
660-
grad_input.index_fill_(0, indices, grad);
661-
return grad_input.view(sizes);
662-
}
663-
664632
Tensor var_backward(const Tensor & grad, const Tensor & self, bool unbiased) {
665633
return (2.0 / (self.numel() - unbiased)) * grad * (self - self.mean());
666634
}

torch/csrc/autograd/FunctionsManual.h

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,10 +76,6 @@ at::Tensor sum_tensorlist(at::TensorList tl);
7676
at::Tensor repeat_backward(at::Tensor grad, int64_t input_dims, at::IntArrayRef repeats);
7777
at::Tensor _fused_dropout_backward(at::Tensor grad, at::Tensor mask, double p1m);
7878
at::Tensor evenly_distribute_backward(at::Tensor grad, const at::Tensor & input, const at::Tensor & value);
79-
at::Tensor index_select_backward(at::Tensor grad, int64_t dim, at::Tensor indices, at::IntArrayRef sizes, bool keepdim);
80-
at::Tensor slice_backward(at::Tensor grad, at::IntArrayRef input_sizes, int64_t dim, int64_t start, int64_t end, int64_t step);
81-
at::Tensor select_backward(at::Tensor grad, at::IntArrayRef input_sizes, int64_t dim, int64_t index);
82-
at::Tensor trace_backward(const at::Tensor & grad, at::IntArrayRef sizes);
8379
at::Tensor sgn_backward(Tensor result, Tensor grad, Tensor self);
8480
at::Tensor var_backward(const at::Tensor & grad, const at::Tensor & self, bool unbiased);
8581
at::Tensor var_backward(at::Tensor grad, const at::Tensor & self, at::IntArrayRef dim, bool unbiased, bool keepdim);

torch/testing/_internal/common_nn.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3966,7 +3966,6 @@ def padding3d_circular(input, pad):
39663966
target_fn=lambda: torch.randn(15, 10).gt(0).double(),
39673967
reference_fn=lambda i, t, m: -(t * i.log() + (1 - t) * (1 - i).log()).sum() /
39683968
(i.numel() if get_reduction(m) else 1),
3969-
check_gradgrad=False,
39703969
check_bfloat16=TEST_WITH_ROCM,
39713970
),
39723971
dict(
@@ -3978,7 +3977,6 @@ def padding3d_circular(input, pad):
39783977
reference_fn=lambda i, t, m: -((t * i.log() + (1 - t) * (1 - i).log()) * get_weight(m)).sum() /
39793978
(i.numel() if get_reduction(m) else 1),
39803979
desc='weights',
3981-
check_gradgrad=False,
39823980
check_bfloat16=TEST_WITH_ROCM,
39833981
),
39843982
dict(
@@ -4328,7 +4326,6 @@ def padding3d_circular(input, pad):
43284326
reference_fn=lambda i, t, m: -((t * i.log() + (1 - t) * (1 - i).log()) * get_weight(m)).sum() /
43294327
(i.numel() if get_reduction(m) == 'mean' else 1),
43304328
desc='scalar_weights',
4331-
check_gradgrad=False,
43324329
check_bfloat16=TEST_WITH_ROCM,
43334330
),
43344331
dict(

0 commit comments

Comments
 (0)
0