8000 Update on "torch.sgn for complex tensors" · pytorch/pytorch@156c622 · GitHub
[go: up one dir, main page]

Skip to content

Commit 156c622

Browse files
committed
Update on "torch.sgn for complex tensors"
resolves #36323 by adding `torch.sgn` for complex tensors. `torch.sgn` returns `x/abs(x)` for `x != 0` and returns `0 + 0j` for `x == 0`. This PR doesn't test the correctness of the gradients. That will be done as part of auditing all the ops in the future, once we decide the autograd behavior (JAX vs TF) and add gradcheck. [ghstack-poisoned]
1 parent 9354172 commit 156c622

File tree

1 file changed

+0
-10
lines changed

1 file changed

+0
-10
lines changed

aten/src/ATen/native/cuda/UnarySignKernels.cu

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -51,15 +51,6 @@ void signbit_kernel_cuda(TensorIterator& iter){
5151
});
5252
}
5353

54-
// template<typename T>
55-
// __host__ __device__ static inline thrust::complex<T> sgn_wrapper(thrust::complex<T> v) {
56-
// if (v == thrust::complex<T>(0, 0)) {
57-
// return thrust::complex<T>(0, 0);
58-
// } else {
59-
// return z / std::abs(z);
60-
// }
61-
//}
62-
6354
template<typename T>
6455
__host__ __device__ static inline c10::complex<T> sgn_wrapper(c10::complex<T> z) {
6556
if (z == c10::complex<T>(0, 0)) {
@@ -71,7 +62,6 @@ __host__ __device__ static inline c10::complex<T> sgn_wrapper(c10::complex<T> z)
7162

7263
void sgn_kernel_cuda(TensorIterator& iter){
7364
AT_DISPATCH_COMPLEX_TYPES(iter.dtype(), "sgn_cuda", [&]() {
74-
//using thrust_t = typename ztype_cuda<scalar_t>::thrust_t;
7565
gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
7666
return sgn_wrapper(a);
7767
});

0 commit comments

Comments
 (0)
0