@@ -87,15 +87,37 @@ void logit_backward_kernel_cuda(TensorIteratorBase& iter, const Scalar& eps_scal
   });
 }
 
+const char tanh_backward_name[] = "tanh_backward";
 void tanh_backward_kernel_cuda(TensorIteratorBase& iter) {
-  if (isComplexType(iter.dtype())) {
-    AT_DISPATCH_COMPLEX_TYPES(iter.dtype(), "tanh_backward_complex_cuda", [&]() {
+  auto dtype = iter.dtype();
+  if (isComplexType(dtype)) {
+#if AT_USE_JITERATOR()
+    static const auto tanh_backward_string = jiterator_stringify(
+      template <typename T>
+      T tanh_backward(T a, T b) {
+        return a * std::conj(T{1.} - b * b);
+      }
+    ); // tanh_backward_string
+    AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, dtype, "tanh_backward_complex_cuda", [&]() {
+      jitted_gpu_kernel<
+          /*name=*/tanh_backward_name,
+          /*return_dtype=*/scalar_t,
+          /*common_dtype=*/scalar_t,
+          /*arity=*/2>(iter, tanh_backward_string);
+    });
+#else
+    AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, dtype, "tanh_backward_complex_cuda", [&]() {
       gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
-        return a * std::conj(scalar_t{1.} - b * b);
+        using comp_t = at::opmath_type<scalar_t>;
+        const auto one = comp_t{1.};
+        const auto comp_b = static_cast<comp_t>(b);
+        const auto comp_a = static_cast<comp_t>(a);
+        return static_cast<scalar_t>(comp_a * std::conj(one - comp_b * comp_b));
       });
     });
+#endif
   } else {
-    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "tanh_backward_cuda", [&]() {
+    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, dtype, "tanh_backward_cuda", [&]() {
       gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
         return a * (scalar_t{1.} - b * b);
       });
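For reference, the complex branch above follows PyTorch's convention for complex gradients (the Wirtinger/conjugate convention): for forward output `b = tanh(x)` and incoming gradient `a`, the kernel returns `a * conj(1 - b*b)`. Below is a minimal host-side sketch, not part of this commit, that checks this formula against a central-difference estimate of `std::tanh` at a complex point; the file name, test point, and harness are illustrative assumptions.

```cpp
// check_tanh_backward.cpp -- illustrative sketch only, not part of the commit.
// With grad_output a = 1, conj(a * conj(1 - b*b)) should match the holomorphic
// derivative tanh'(x) = 1 - tanh(x)^2 estimated by central differences.
#include <complex>
#include <cstdio>

int main() {
  using C = std::complex<double>;
  const C x{0.3, -0.7};          // arbitrary complex test point
  const C b = std::tanh(x);      // forward output, as the kernel sees it
  const C a{1.0, 0.0};           // grad_output of 1 isolates the derivative

  // What the kernel computes: a * conj(1 - b*b).
  const C grad = a * std::conj(C{1.0} - b * b);

  // Central-difference estimate of tanh'(x) along the real axis.
  const double h = 1e-6;
  const C fd = (std::tanh(x + C{h, 0.0}) - std::tanh(x - C{h, 0.0})) / C{2.0 * h};

  // With a = 1, conj(grad) should agree with the finite-difference derivative.
  std::printf("conj(grad):  (% .9f, % .9f)\n",
              std::real(std::conj(grad)), std::imag(std::conj(grad)));
  std::printf("finite diff: (% .9f, % .9f)\n", std::real(fd), std::imag(fd));
}
```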