[complex32, jiterator] tan, atan (#77802) · pytorch/pytorch@1d2e988

Commit 1d2e988

khushi-411 authored and facebook-github-bot committed

[complex32, jiterator] tan, atan (#77802)
Summary: Follows #74537 and #74748. cc kshitij12345

Pull Request resolved: #77802
Approved by: https://github.com/kshitij12345, https://github.com/ngimel

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/a136408adad60e486612686c4a3ba4eb3bf73775

Reviewed By: seemethere

Differential Revision: D36610717

fbshipit-source-id: 317004880c2b9f2e9ad6750b176e312c490c8880
1 parent 8dccea7 commit 1d2e988
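This commit extends the CUDA tan and atan kernels to ComplexHalf (torch.chalf, i.e. torch.complex32), using the jiterator when available and a plain gpu_kernel fallback otherwise. A minimal usage sketch, not part of the commit, assuming a CUDA device and a build with complex32 support:

```python
# Hypothetical illustration of what the new dispatch enables.
import torch

# randn may not support chalf directly, so create complex64 and downcast.
x = torch.randn(4, dtype=torch.complex64, device="cuda").to(torch.chalf)

y_tan = torch.tan(x)    # routed through the new ComplexHalf branch in tan_kernel_cuda
y_atan = torch.atan(x)  # likewise for atan_kernel_cuda

print(y_tan.dtype, y_atan.dtype)  # torch.complex32 torch.complex32
```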

File tree

2 files changed: +85 −6 lines changed

aten/src/ATen/native/cuda/UnaryGeometricKernels.cu

Lines changed: 58 additions & 4 deletions
@@ -32,15 +32,42 @@ void asin_kernel_cuda(TensorIteratorBase& iter) {
   });
 }
 
+const char atan_name[] = "atan";
 void atan_kernel_cuda(TensorIteratorBase& iter) {
-  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
+  auto common_dtype = iter.common_dtype();
+  if (at::isComplexType(common_dtype)) {
+#if AT_USE_JITERATOR
+    static const auto atan_string = jiterator_stringify(
+        template <typename T>
+        T atan(T a) {
+          return std::atan(a);
+        }
+    );
+    AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, common_dtype, "atan_name", [&]() {
+      jitted_gpu_kernel<
+          /*name=*/ atan_name,
+          /*return_dtype=*/ scalar_t,
+          /*common_dtype=*/ scalar_t,
+          /*arity=*/ 1>(iter, atan_string);
+    });
+#else
+    AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, common_dtype, "atan_name", [&]() {
+      gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
+        using opmath_t = at::opmath_type<scalar_t>;
+        return ::atan(static_cast<opmath_t>(a));
+      });
+    });
+#endif
+  } else {
+    AT_DISPATCH_FLOATING_TYPES_AND2(
       ScalarType::Half, ScalarType::BFloat16,
-      iter.common_dtype(), "atan_cuda",
+      common_dtype, "atan_cuda",
       [&]() {
         gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
           return ::atan(a);
         });
       });
+  }
 }
 
 void sin_kernel_cuda(TensorIteratorBase& iter) {

@@ -131,15 +158,42 @@ void atanh_kernel_cuda(TensorIteratorBase& iter) {
   });
 }
 
+const char tan_name[] = "tan";
 void tan_kernel_cuda(TensorIteratorBase& iter) {
-  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
+  auto common_dtype = iter.common_dtype();
+  if (at::isComplexType(common_dtype)) {
+#if AT_USE_JITERATOR
+    static const auto tan_string = jiterator_stringify(
+        template <typename T>
+        T tan(T a) {
+          return std::tan(a);
+        }
+    );
+    AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, common_dtype, "tan_name", [&]() {
+      jitted_gpu_kernel<
+          /*name=*/ tan_name,
+          /*return_dtype=*/ scalar_t,
+          /*common_dtype=*/ scalar_t,
+          /*arity=*/ 1>(iter, tan_string);
+    });
+#else
+    AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, common_dtype, "tan_name", [&]() {
+      gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
+        using opmath_t = at::opmath_type<scalar_t>;
+        return ::tan(static_cast<opmath_t>(a));
+      });
+    });
+#endif
+  } else {
+    AT_DISPATCH_FLOATING_TYPES_AND2(
       ScalarType::Half, ScalarType::BFloat16,
-      iter.common_dtype(), "tan_cuda",
+      common_dtype, "tan_cuda",
       [&]() {
         gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
           return ::tan(a);
         });
       });
+  }
 }
 
 REGISTER_DISPATCH(acos_stub, &acos_kernel_cuda);
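In the #else fallback above, ComplexHalf values are cast to at::opmath_type<scalar_t> (complex<float>) before calling ::tan / ::atan, so a chalf result should track a complex64 reference computed explicitly and cast back down. A small sketch of that check, not part of the commit, assuming a CUDA device and complex32 support in the build:

```python
# Illustrative consistency check (assumes CUDA is available and complex32 is enabled).
import torch

x = torch.randn(8, dtype=torch.complex64, device="cuda").to(torch.chalf)

out = torch.tan(x)                                       # new ComplexHalf kernel
ref = torch.tan(x.to(torch.complex64)).to(torch.chalf)   # explicit upcast/downcast reference

# The two should agree up to half-precision rounding.
max_err = (out.to(torch.complex64) - ref.to(torch.complex64)).abs().max()
print(max_err)
```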

torch/testing/_internal/common_methods_invocations.py

Lines changed: 27 additions & 2 deletions
@@ -10341,7 +10341,7 @@ def error_inputs_mean(op_info, device, **kwargs):
                    aliases=('arctan', ),
                    ref=np.arctan,
                    dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
-                   dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
+                   dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16),
                    assert_autodiffed=True,
                    supports_forward_ad=True,
                    supports_fwgrad_bwgrad=True,

@@ -10367,6 +10367,18 @@ def error_inputs_mean(op_info, device, **kwargs):
                                 active_if=IS_WINDOWS),
                        DecorateInfo(unittest.skip("Skipped! sparse backward not supported"),
                                     'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'),
+                       # RuntimeError: "nonzero_cuda" not implemented for 'ComplexHalf'
+                       DecorateInfo(unittest.expectedFailure, 'TestSparseCSR', 'test_sparse_csr_consistency',
+                                    dtypes=(torch.chalf,)),
+                       # same reason as above
+                       DecorateInfo(unittest.expectedFailure, 'TestSparseCSR', 'test_sparse_csr_unary_inplace',
+                                    dtypes=(torch.chalf,)),
+                       # same reason as above
+                       DecorateInfo(unittest.expectedFailure, 'TestSparseCSR', 'test_sparse_csr_unary_out',
+                                    dtypes=(torch.chalf,)),
+                       # RuntimeError: "add_out_op2_sparse_csr" not implemented for 'ComplexHalf'
+                       DecorateInfo(unittest.expectedFailure, 'TestSparseCSR', 'test_zero_to_zero_correspondence_unary',
+                                    dtypes=(torch.chalf,)),
                    )),
     BinaryUfuncInfo('atan2',
                     aliases=('arctan2',),

@@ -14995,7 +15007,8 @@ def error_inputs_mean(op_info, device, **kwargs):
     UnaryUfuncInfo('tan',
                    ref=np.tan,
                    dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
-                   dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
+                   dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16),
+                   backward_dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16),
                    assert_autodiffed=True,
                    supports_forward_ad=True,
                    supports_fwgrad_bwgrad=True,

@@ -15019,6 +15032,18 @@ def error_inputs_mean(op_info, device, **kwargs):
                                 active_if=TEST_WITH_ROCM),
                        DecorateInfo(unittest.skip("Skipped! sparse backward not supported"),
                                     'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'),
+                       # RuntimeError: "nonzero_cuda" not implemented for 'ComplexHalf'
+                       DecorateInfo(unittest.expectedFailure, 'TestSparseCSR', 'test_sparse_csr_consistency',
+                                    dtypes=(torch.chalf,)),
+                       # same reason as above
+                       DecorateInfo(unittest.expectedFailure, 'TestSparseCSR', 'test_sparse_csr_unary_inplace',
+                                    dtypes=(torch.chalf,)),
+                       # same reason as above
+                       DecorateInfo(unittest.expectedFailure, 'TestSparseCSR', 'test_sparse_csr_unary_out',
+                                    dtypes=(torch.chalf,)),
+                       # RuntimeError: "add_out_op2_sparse_csr" not implemented for 'ComplexHalf'
+                       DecorateInfo(unittest.expectedFailure, 'TestSparseCSR', 'test_zero_to_zero_correspondence_unary',
+                                    dtypes=(torch.chalf,)),
                    ),
         # tan(pi/2 * odd_number) is nan
         reference_numerics_filter=NumericsFilter(
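The OpInfo changes above advertise torch.chalf for atan and tan on CUDA and mark the known sparse CSR gaps as expected failures. One way to inspect the updated entries, sketched here against PyTorch's internal (unstable) op_db structure rather than any supported public API:

```python
# Sketch: querying the internal OpInfo database; attribute names are internal details.
import torch
from torch.testing._internal.common_methods_invocations import op_db

for name in ("atan", "tan"):
    op = next(o for o in op_db if o.name == name)
    # After this change, torch.chalf should be listed among the CUDA dtypes.
    print(name, torch.chalf in op.dtypesIfCUDA)
```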
