Update on "torch.sgn for complex tensors" · pytorch/pytorch@74a456d · GitHub

Commit 74a456d
Update on "torch.sgn for complex tensors"
[ghstack-poisoned]
1 parent d8d243e commit 74a456d

File tree

10 files changed: +26, -23 lines

aten/src/ATen/core/aten_interned_strings.h

Lines changed: 1 addition & 0 deletions
@@ -605,6 +605,7 @@ _(aten, selu) \
 _(aten, set) \
 _(aten, sigmoid) \
 _(aten, sign) \
+_(aten, sgn) \
 _(aten, sin) \
 _(aten, sinh) \
 _(aten, size) \

aten/src/ATen/native/cpu/zmath_std.h

Lines changed: 0 additions & 12 deletions
@@ -43,18 +43,6 @@ inline double zabs <std::complex<double>, double> (std::complex<double> z) {
   return std::abs(z);
 }
 
-// template<>
-// inline std::complex<float> sgn_impl (std::complex<float> z) {
-//   auto angle = std::arg(z);
-//   return std::complex<float>(std::cos(angle), std::sin(angle));
-// }
-
-// template<>
-// inline std::complex<double> sgn_impl (std::complex<double> z) {
-//   auto angle = std::arg(z);
-//   return std::complex<double>(std::cos(angle), std::sin(angle));
-// }
-
 template<>
 inline std::complex<float> angle_impl <std::complex<float>> (std::complex<float> z) {
   return std::complex<float>(std::arg(z), 0.0);
aten/src/ATen/native/cuda/UnarySignKernels.cu

Lines changed: 8 additions & 5 deletions
@@ -7,6 +7,7 @@
 #include <ATen/native/DispatchStub.h>
 #include <ATen/native/TensorIterator.h>
 #include <ATen/native/cuda/Math.cuh>
+#include <ATen/native/cuda/zmath.cuh>
 
 namespace at { namespace native {
 
@@ -46,14 +47,16 @@ void sign_kernel_cuda(TensorIterator& iter){
 }
 
 template<typename T>
-__host__ __device__ static inline c10::complex<T> sgn_wrapper(c10::complex<T> v) {
-  return v.sgn();
+__host__ __device__ static inline thrust::complex<T> sgn_wrapper(thrust::complex<T> v) {
+  T angle = thrust::arg(v);
+  return thrust::complex<T>(::cos(angle), ::sin(angle));
 }
 
-void sign_kernel_cuda(TensorIterator& iter){
+void sgn_kernel_cuda(TensorIterator& iter){
   AT_DISPATCH_COMPLEX_TYPES(iter.dtype(), "sgn_cuda", [&]() {
-    gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
-      return sgn_wrapper(a);
+    using thrust_t = typename ztype_cuda<scalar_t>::thrust_t;
+    gpu_kernel(iter, []GPU_LAMBDA(thrust_t a) -> thrust_t {
+      return sgn_wrapper(a);
     });
   });
 }
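
Besides renaming the copy-pasted sign_kernel_cuda to sgn_kernel_cuda, the wrapper now operates on thrust::complex (ztype_cuda maps the dispatched scalar_t to its thrust counterpart) and builds the result from thrust::arg plus cos/sin instead of calling a method on c10::complex. A quick CPU/CUDA parity sketch, assuming a CUDA-enabled build (illustration only, not from the commit):

    import torch

    if torch.cuda.is_available():
        z = torch.randn(16, dtype=torch.complex64)
        cpu_out = z.sgn()
        cuda_out = z.cuda().sgn().cpu()
        # Both backends compute sgn(z) = z / |z| elementwise.
        assert torch.allclose(cpu_out.angle(), cuda_out.angle())
        assert torch.allclose(cpu_out.abs(), cuda_out.abs())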

aten/src/ATen/native/native_functions.yaml

Lines changed: 1 addition & 0 deletions
@@ -253,6 +253,7 @@
 - func: sgn.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
   dispatch:
     CPU: sgn_out
+    CUDA: sgn_out
 
 - func: real(Tensor self) -> Tensor
   use_c10_dispatcher: full
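
This wires the CUDA backend into the out= overload, so the same Python call now reaches sgn_out on either device. A hedged usage sketch (not from the commit):

    import torch

    z = torch.randn(8, dtype=torch.complex64)      # add device='cuda' to hit the new path
    out = torch.empty_like(z)
    torch.sgn(z, out=out)                          # dispatches to sgn_out
    assert torch.allclose(out.angle(), z.angle())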

benchmarks/operator_benchmark/pt/unary_test.py

Lines changed: 1 addition & 1 deletion
@@ -94,7 +94,7 @@ def forward(self):
     ['sigmoid', torch.sigmoid],
     ['sigmoid_', torch.sigmoid_],
     ['sign', torch.sign],
-    ['sgn', torch.sign],
+    ['sgn', torch.sgn],
     ['sin', torch.sin],
     ['sin_', torch.sin_],
     ['sinh', torch.sinh],

docs/source/name_inference.rst

Lines changed: 2 additions & 0 deletions
@@ -197,6 +197,8 @@ If you don't see an operation listed here, but it would help your use case, plea
 :meth:`Tensor.sigmoid_`,None
 ":meth:`Tensor.sign`, :func:`torch.sign`",:ref:`keeps_input_names-doc`
 :meth:`Tensor.sign_`,None
+":meth:`Tensor.sgn`, :func:`torch.sgn`",:ref:`keeps_input_names-doc`
+:meth:`Tensor.sgn_`,None
 ":meth:`Tensor.sin`, :func:`torch.sin`",:ref:`keeps_input_names-doc`
 :meth:`Tensor.sin_`,None
 ":meth:`Tensor.sinh`, :func:`torch.sinh`",:ref:`keeps_input_names-doc`

docs/source/tensors.rst

Lines changed: 2 additions & 0 deletions
@@ -473,6 +473,8 @@ view of a storage and defines numeric operations on it.
 .. automethod:: sigmoid_
 .. automethod:: sign
 .. automethod:: sign_
+.. automethod:: sgn
+.. automethod:: sgn_
 .. automethod:: sin
 .. automethod:: sin_
 .. automethod:: sinh

test/test_torch.py

Lines changed: 7 additions & 5 deletions
@@ -11132,14 +11132,14 @@ def test_sign(self, device):
     def test_sgn(self, device, dtype):
         x = torch.randn(100, dtype=dtype)
         angle = x.angle()
-        cos_angle = angle.cos()
-        sin_angle = angle.sin()
-        expected = cos_angle + 1j * sin_angle
-        self.assertEqual(x.sgn(), expected)
+        out = x.sgn()
+        self.assertEqual(out.angle(), angle)
+        self.assertEqual(out.abs(), torch.ones_like(x).real)
 
         x_out = torch.empty_like(x)
         torch.sgn(x, out=x_out)
-        self.assertEqual(x_out, expected)
+        self.assertEqual(x_out.angle(), angle)
+        self.assertEqual(x_out.abs(), torch.ones_like(x).real)
 
     def test_logical_any(self, device):
         x = torch.zeros([2, 3, 400], dtype=torch.uint8, device=device)
@@ -14615,6 +14615,8 @@ def test_helper(x, y, memory_format):
     lambda x, y: x.sigmoid_(),
     lambda x, y: x.sign(),
     lambda x, y: x.sign_(),
+    lambda x, y: x.sgn(),
+    lambda x, y: x.sgn_(),
     lambda x, y: x.sin(),
     lambda x, y: x.sin_(),
     lambda x, y: x.sinh(),

tools/autograd/derivatives.yaml

Lines changed: 3 additions & 0 deletions
@@ -883,6 +883,9 @@
 - name: sign(Tensor self) -> Tensor
   self: zeros_like(grad, at::MemoryFormat::Preserve)
 
+- name: sgn(Tensor self) -> Tensor
+  self: grad / at::abs(self) # g_x u0 + i g_y v1, u0 = v1 = abs(z)
+
 - name: sin(Tensor self) -> Tensor
   self: grad * self.cos()
 
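A hedged gloss on the new sgn entry (my reading of the inline comment, not text from the commit): decompose the incoming gradient as g = g_x + i·g_y; the comment's "g_x u0 + i g_y v1, u0 = v1 = abs(z)" then describes a per-component scaling by 1/|z|, which the formula applies as grad / at::abs(self):

    \[
      \operatorname{sgn}(z) = \frac{z}{\lvert z \rvert},
      \qquad
      \text{backward:}\quad
      \frac{g_x}{\lvert z \rvert} + i\,\frac{g_y}{\lvert z \rvert}
      = \frac{g}{\lvert z \rvert}.
    \]

Note this keeps only the diagonal 1/|z| scaling; the conjugate-dependent term of the full derivative of z/|z| does not appear in this intermediate revision.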

torch/_overrides.py

Lines changed: 1 addition & 0 deletions
@@ -609,6 +609,7 @@ def get_testing_overrides():
     torch.selu: lambda input, inplace=False: -1,
     torch.sigmoid: lambda input, out=None: -1,
     torch.sign: lambda input, out=None: -1,
+    torch.sgn: lambda input, out=None: -1,
     torch.sin: lambda input, out=None: -1,
     torch.sinh: lambda input, out=None: -1,
     torch.slogdet: lambda input: -1,

0 commit comments