[jiterator, complex32] tanh_backward : complex (#76289) · pytorch/pytorch@0ccd3ae · GitHub
Commit 0ccd3ae

khushi-411 authored and facebook-github-bot committed
[jiterator, complex32] tanh_backward : complex (#76289)
Summary: Follows #74748 and #74537. cc kshitij12345

Pull Request resolved: #76289
Approved by: https://github.com/anjali411

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/305a9cc00a58fe7f265e1c6331f197f61a2390b5

Reviewed By: osalpekar

Differential Revision: D35971220

fbshipit-source-id: bce0fb21b4d23ad8f9081a0b30a0d096829dc8c3
1 parent 7b84b96 commit 0ccd3ae

File tree

1 file changed (+26, -4 lines)


aten/src/ATen/native/cuda/BinaryMiscBackwardOpsKernels.cu

Lines changed: 26 additions & 4 deletions
@@ -87,15 +87,37 @@ void logit_backward_kernel_cuda(TensorIteratorBase& iter, const Scalar& eps_scal
   });
 }
 
+const char tanh_backward_name[] = "tanh_backward";
 void tanh_backward_kernel_cuda(TensorIteratorBase& iter) {
-  if(isComplexType(iter.dtype())) {
-    AT_DISPATCH_COMPLEX_TYPES(iter.dtype(), "tanh_backward_complex_cuda", [&]() {
+  auto dtype = iter.dtype();
+  if(isComplexType(dtype)) {
+#if AT_USE_JITERATOR()
+    static const auto tanh_backward_string = jiterator_stringify(
+      template <typename T>
+      T tanh_backward(T a, T b) {
+        return a * std::conj(T{1.} - b * b);
+      }
+    ); // tanh_backward_string
+    AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, dtype, "tanh_backward_complex_cuda", [&]() {
+      jitted_gpu_kernel<
+        /*name=*/ tanh_backward_name,
+        /*return_dtype=*/ scalar_t,
+        /*common_dtype=*/ scalar_t,
+        /*arity=*/ 2>(iter, tanh_backward_string);
+    });
+#else
+    AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, dtype, "tanh_backward_complex_cuda", [&]() {
       gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
-        return a * std::conj(scalar_t{1.} - b * b);
+        using comp_t = at::opmath_type<scalar_t>;
+        const auto one = comp_t{1.};
+        const auto comp_b = static_cast<comp_t>(b);
+        const auto comp_a = static_cast<comp_t>(a);
+        return static_cast<scalar_t>(comp_a * std::conj(one - comp_b * comp_b));
       });
     });
+#endif
   } else {
-    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "tanh_backward_cuda", [&]() {
+    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, dtype, "tanh_backward_cuda", [&]() {
       gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
         return a * (scalar_t{1.} - b * b);
       });
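For reference, a minimal Python-side sketch (not part of this commit) of the behavior the kernel above backs: the backward of torch.tanh for complex CUDA tensors, now including torch.complex32 (chalf). It assumes a CUDA-enabled PyTorch build that contains this change and the companion forward-support PRs referenced in the summary; tensor names and shapes are illustrative.

# Minimal sketch, assuming a CUDA build of PyTorch that includes #76289
# and its companion PRs; names and shapes are illustrative.
import torch

x = torch.randn(4, dtype=torch.complex64, device="cuda", requires_grad=True)
y = torch.tanh(x)

# Backward through tanh: per element the kernel computes
# grad_input = grad_output * conj(1 - tanh(x)^2), i.e. tanh_backward(grad, result).
grad_out = torch.ones_like(y)
y.backward(grad_out)

expected = grad_out * torch.conj(1 - y.detach() * y.detach())
print(torch.allclose(x.grad, expected))  # True

# Half-precision complex (chalf) now dispatches as well: to the jiterator
# kernel when AT_USE_JITERATOR() is enabled, else the opmath fallback above.
xh = x.detach().to(torch.complex32).requires_grad_()
yh = torch.tanh(xh)
yh.backward(torch.ones_like(yh))
print(xh.grad.dtype)  # torch.complex32

The conj in the formula follows PyTorch's complex autograd convention: the backward of a holomorphic function multiplies the incoming gradient by the conjugate of the analytic derivative, here 1 - tanh(x)^2.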
