diff --git a/aten/src/ATen/cuda/llvm_complex.cpp b/aten/src/ATen/cuda/llvm_complex.cpp
index d88bdc4ce6579..735607bf6bac4 100644
--- a/aten/src/ATen/cuda/llvm_complex.cpp
+++ b/aten/src/ATen/cuda/llvm_complex.cpp
@@ -12,6 +12,7 @@
 
 #include <string>
 #include <ATen/cuda/llvm_jit_strings.h>
+#include <limits>
 
 namespace at {
 
@@ -773,6 +774,37 @@ log2(const complex<_Tp>& __x)
     return log(__x) / log(_Tp(2));
 }
 
+// reciprocal
+
+template<class _Tp>
+inline
+complex<_Tp>
+reciprocal(const complex<_Tp>& __x)
+{
+  // Handle extreme cases for numpy compatibility
+  auto both_inf = [](_Tp real, _Tp imag) {
+    return isinf(real) && isinf(imag);
+  };
+
+  auto either_inf = [](_Tp real, _Tp imag) {
+    return isinf(real) || isinf(imag);
+  };
+
+  auto either_nan = [](_Tp real, _Tp imag) {
+    return isnan(real) || isnan(imag);
+  };
+
+  if (either_nan(__x.real(), __x.imag()) || both_inf(__x.real(), __x.imag())) {
+    // If either is Nan or both are infinite, return {nan, nan}
+    return {std::numeric_limits<_Tp>::quiet_NaN(), std::numeric_limits<_Tp>::quiet_NaN()};
+  } else if (either_inf(__x.real(), __x.imag())) {
+    // If either is Inf, return {0, 0}
+    return {0, 0};
+  }
+  const complex<_Tp> one = complex<_Tp>(1.0, 0);
+  return one/__x;
+}
+
 // sqrt
 
 template<class _Tp>
diff --git a/aten/src/ATen/native/cuda/UnaryFractionKernels.cu b/aten/src/ATen/native/cuda/UnaryFractionKernels.cu
index 87aa784b7d5d3..028477d1ae071 100644
--- a/aten/src/ATen/native/cuda/UnaryFractionKernels.cu
+++ b/aten/src/ATen/native/cuda/UnaryFractionKernels.cu
@@ -2,6 +2,7 @@
 #include <ATen/Dispatch.h>
 #include <ATen/native/cuda/Loops.cuh>
 #include <ATen/native/DispatchStub.h>
+#include <ATen/native/cuda/JitLoops.cuh>
 #include <ATen/native/TensorIterator.h>
 #include <ATen/native/UnaryOps.h>
 #include <c10/cuda/CUDAMathCompat.h>
@@ -66,12 +67,12 @@ void floor_kernel_cuda(TensorIteratorBase& iter) {
 }
 
 template <typename scalar_t>
-__host__ __device__ static inline scalar_t reciprocal_wrapper(scalar_t a) {
+C10_HOST_DEVICE static inline scalar_t reciprocal_wrapper(scalar_t a) {
   return static_cast<scalar_t>(1)/a;
 }
 
 template <typename T>
-__host__ __device__ static inline c10::complex<T> reciprocal_wrapper(c10::complex<T> v) {
+C10_HOST_DEVICE static inline c10::complex<T> reciprocal_wrapper(c10::complex<T> v) {
   // Handle extreme cases for numpy compatibility
   auto both_inf = [](T real, T imag) {
     return (::isinf(real) && ::isinf(imag));
@@ -96,15 +97,62 @@ __host__ __device__ static inline c10::complex<T> reciprocal_wrapper(c10::complex<T> v) {
   return one/v;
 }
 
+const char reciprocal_name[] = "reciprocal_kernel";
 void reciprocal_kernel_cuda(TensorIteratorBase& iter) {
-  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
+  auto dtype = iter.common_dtype();
+  if (at::isComplexType(dtype)) {
+#if AT_USE_JITERATOR()
+    static const auto reciprocal_string = jiterator_stringify(
+        template <typename T>
+        T reciprocal_kernel(T v) {
+          // Handle extreme cases for numpy compatibility
+          auto both_inf = [](T real, T imag) {
+            return std::isinf(real) && std::isinf(imag);
+          };
+
+          auto either_inf = [](T real, T imag) {
+            return std::isinf(real) || std::isinf(imag);
+          };
+
+          auto either_nan = [](T real, T imag) {
+            return std::isnan(real) || std::isnan(imag);
+          };
+
+          if (either_nan(v.real(), v.imag()) || both_inf(v.real(), v.imag())) {
+            // If either is Nan or both are infinite, return {nan, nan}
+            return {std::numeric_limits<T>::quiet_NaN(), std::numeric_limits<T>::quiet_NaN()};
+          } else if (either_inf(v.real(), v.imag())) {
+            // If either is Inf, return {0, 0}
+            return {0, 0};
+          }
+          const c10::complex<T> one = c10::complex<T>(1.0, 0);
+          return one/v;
+        }
+
+    ); // reciprocal_string
+    AT_DISPATCH_COMPLEX_TYPES(dtype, "reciprocal_cuda", [&]() {
+      jitted_gpu_kernel<
+          /*name=*/ reciprocal_name,
+          /*return_dtype=*/ scalar_t,
+          /*common_dtype=*/ scalar_t,
+          /*arity=*/ 1>(iter, reciprocal_string);
+    });
+#else
+    AT_DISPATCH_COMPLEX_TYPES(dtype, "reciprocal_cuda", [&]() {
+      gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
+        return reciprocal_wrapper(a);
+      });
+    });
+#endif
+  } else {
+    AT_DISPATCH_FLOATING_TYPES_AND2(
     ScalarType::Half, ScalarType::BFloat16,
-    iter.common_dtype(), "reciprocal_cuda",
+    dtype, "reciprocal_cuda",
     [&]() {
       gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
         return reciprocal_wrapper(a);
       });
     });
+  }
 }
 
 // We manually overload nearbyint because std::nearbyint does not work with std::complex types and ROCm.