8000 Update on "torch.sgn for complex tensors" · pytorch/pytorch@0fed8c8 · GitHub
[go: up one dir, main page]

Skip to content

Commit 0fed8c8

Browse files
committed
Update on "torch.sgn for complex tensors"
resolves #36323 by adding `torch.sgn` for complex tensors. `torch.sgn` returns `x/abs(x)` for `x != 0` and returns `0 + 0j` for `x==0` also updates the backward definition of `torch.div`, `torch.abs` Differential Revision: [D23460526](https://our.internmc.facebook.com/intern/diff/D23460526) [ghstack-poisoned]
1 parent b308f0f commit 0fed8c8

File tree

1 file changed

+0
-3
lines changed

1 file changed

+0
-3
lines changed

torch/testing/_internal/common_nn.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3966,7 +3966,6 @@ def padding3d_circular(input, pad):
39663966
target_fn=lambda: torch.randn(15, 10).gt(0).double(),
39673967
reference_fn=lambda i, t, m: -(t * i.log() + (1 - t) * (1 - i).log()).sum() /
39683968
(i.numel() if get_reduction(m) else 1),
3969-
check_gradgrad=False,
39703969
check_bfloat16=TEST_WITH_ROCM,
39713970
),
39723971
dict(
@@ -3978,7 +3977,6 @@ def padding3d_circular(input, pad):
39783977
reference_fn=lambda i, t, m: -((t * i.log() + (1 - t) * (1 - i).log()) * get_weight(m)).sum() /
39793978
(i.numel() if get_reduction(m) else 1),
39803979
desc='weights',
3981-
check_gradgrad=False,
39823980
check_bfloat16=TEST_WITH_ROCM,
39833981
),
39843982
dict(
@@ -4328,7 +4326,6 @@ def padding3d_circular(input, pad):
43284326
reference_fn=lambda i, t, m: -((t * i.log() + (1 - t) * (1 - i).log()) * get_weight(m)).sum() /
43294327
(i.numel() if get_reduction(m) == 'mean' else 1),
43304328
desc='scalar_weights',
4331-
check_gradgrad=False,
43324329
check_bfloat16=TEST_WITH_ROCM,
43334330
),
43344331
dict(

0 commit comments

Comments
 (0)
0