[complex32, jiterator] sin, asin (#77606) · pytorch/pytorch@6f4d200 · GitHub

Commit 6f4d200

khushi-411 authored and pytorchmergebot committed
[complex32, jiterator] sin, asin (#77606)
Follows #74537 and #74748

cc @kshitij12345
Pull Request resolved: #77606
Approved by: https://github.com/kshitij12345, https://github.com/ngimel
1 parent 4ea176e commit 6f4d200

File tree

2 files changed, +94 -13 lines changed


aten/src/ATen/native/cuda/UnaryGeometricKernels.cu

Lines changed: 63 additions & 11 deletions
@@ -7,6 +7,7 @@
 #include <ATen/native/DispatchStub.h>
 #include <ATen/native/TensorIterator.h>
 #include <ATen/native/cuda/Math.cuh>
+#include <ATen/OpMathType.h>
 
 namespace at { namespace native {
 
@@ -21,15 +22,39 @@ void acos_kernel_cuda(TensorIteratorBase& iter) {
   });
 }
 
+const char asin_name[] = "asin";
 void asin_kernel_cuda(TensorIteratorBase& iter) {
-  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
-      ScalarType::Half, ScalarType::BFloat16,
-      iter.common_dtype(), "asin_cuda",
-      [&]() {
-        gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
-          return ::asin(a);
-        });
-      });
+  auto common_dtype = iter.common_dtype();
+  if(at::isComplexType(common_dtype)) {
+#if AT_USE_JITERATOR
+    static const auto asin_string = jiterator_stringify(
+        template <typename T>
+        T asin(T a) {
+          return std::asin(a);
+        }
+    );
+    AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, common_dtype, "asin_name", [&]() {
+      jitted_gpu_kernel<
+          /*name=*/ asin_name,
+          /*return_dtype=*/ scalar_t,
+          /*common_dtype=*/ scalar_t,
+          /*arity=*/ 1>(iter, asin_string);
+    });
+#else
+    AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, common_dtype, "asin_name", [&]() {
+      gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
+        using opmath_t = at::opmath_type<scalar_t>;
+        return ::asin(static_cast<opmath_t>(a));
+      });
+    });
+#endif
+  } else {
+    AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, common_dtype, "asin_cuda", [&]() {
+      gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
+        return ::asin(a);
+      });
+    });
+  }
 }
 
 const char atan_name[] = "atan";
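
In the rewritten asin kernel above, complex inputs (including the new complex32/chalf) are routed through the jiterator when AT_USE_JITERATOR is set, with a gpu_kernel fallback that upcasts each element via at::opmath_type; real dtypes keep the previous dispatch. A minimal Python-level usage sketch of what this enables, assuming a CUDA build with experimental chalf support:

import torch

# Build a complex32 (chalf) tensor on CUDA by narrowing from complex64.
x = torch.randn(4, dtype=torch.cfloat, device="cuda").to(torch.chalf)

# After this change, asin on chalf dispatches through the complex branch above.
y = torch.asin(x)
print(y.dtype)  # torch.complex32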
@@ -70,15 +95,42 @@ void atan_kernel_cuda(TensorIteratorBase& iter) {
   }
 }
 
+const char sin_name[] = "sin";
 void sin_kernel_cuda(TensorIteratorBase& iter) {
-  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
-      ScalarType::Half, ScalarType::BFloat16,
-      iter.common_dtype(), "sin_cuda",
+  auto common_dtype = iter.common_dtype();
+  if(at::isComplexType(common_dtype)) {
+#if AT_USE_JITERATOR
+    static const auto sin_string = jiterator_stringify(
+        template <typename T>
+        T sin(T a) {
+          return std::sin(a);
+        }
+    );
+    AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, common_dtype, "sin_name", [&]() {
+      jitted_gpu_kernel<
+          /*name=*/ sin_name,
+          /*return_dtype=*/ scalar_t,
+          /*common_dtype=*/ scalar_t,
+          /*arity=*/ 1>(iter, sin_string);
+    });
+#else
+    AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, common_dtype, "sin_name", [&]() {
+      gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
+        using opmath_t = at::opmath_type<scalar_t>;
+        return ::sin(static_cast<opmath_t>(a));
+      });
+    });
+#endif
+  } else {
+    AT_DISPATCH_FLOATING_TYPES_AND2(
+        ScalarType::Half, ScalarType::BFloat16,
+        common_dtype, "sin_cuda",
       [&]() {
         gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
           return ::sin(a);
         });
       });
+  }
 }
 
 void cos_kernel_cuda(TensorIteratorBase& iter) {
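
The non-jiterator #else branches above compute each complex32 element in its opmath type (complex<float>) before the result is narrowed back on store. A rough Python analogue of that upcast-compute-downcast idea, offered only as an illustration and assuming a CUDA device is available:

import torch

x = torch.randn(4, dtype=torch.cfloat, device="cuda").to(torch.chalf)

# Explicit upcast to complex64, evaluate, then narrow back to chalf.
ref = x.to(torch.cfloat).sin().to(torch.chalf)

# The kernel's fallback path performs a comparable per-element upcast internally.
out = torch.sin(x)
print((out.to(torch.cfloat) - ref.to(torch.cfloat)).abs().max())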

torch/testing/_internal/common_methods_invocations.py

Lines changed: 31 additions & 2 deletions
@@ -10370,7 +10370,8 @@ def error_inputs_mean(op_info, device, **kwargs):
                    supports_forward_ad=True,
                    supports_fwgrad_bwgrad=True,
                    dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
-                   dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
+                   dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16),
+                   backward_dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16),
                    assert_autodiffed=True,
                    decorators=[
                        DecorateInfo(
@@ -10391,6 +10392,18 @@ def error_inputs_mean(op_info, device, **kwargs):
                                    active_if=IS_WINDOWS),
                        DecorateInfo(unittest.skip("Skipped! sparse backward not supported"),
                                    'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'),
+                       # RuntimeError: "nonzero_cuda" not implemented for 'ComplexHalf'
+                       DecorateInfo(unittest.expectedFailure, 'TestSparseCSR', 'test_sparse_csr_consistency',
+                                    dtypes=(torch.chalf,)),
+                       # RuntimeError: "nonzero_cuda" not implemented for 'ComplexHalf'
+                       DecorateInfo(unittest.expectedFailure, 'TestSparseCSR', 'test_sparse_csr_unary_inplace',
+                                    dtypes=(torch.chalf,)),
+                       # RuntimeError: "nonzero_cuda" not implemented for 'ComplexHalf'
+                       DecorateInfo(unittest.expectedFailure, 'TestSparseCSR', 'test_sparse_csr_unary_out',
+                                    dtypes=(torch.chalf,)),
+                       # RuntimeError: "add_out_op2_sparse_csr" not implemented for 'ComplexHalf'
+                       DecorateInfo(unittest.expectedFailure, 'TestSparseCSR', 'test_zero_to_zero_correspondence_unary',
+                                    dtypes=(torch.chalf,)),
                    )),
     # NOTE: derivative for inplace asinh is not implemented
     UnaryUfuncInfo('asinh',
@@ -14773,7 +14786,11 @@ def error_inputs_mean(op_info, device, **kwargs):
     UnaryUfuncInfo('sin',
                    ref=np.sin,
                    dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
-                   dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
+                   dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16),
+                   # TODO: Add torch.chalf backward dtype support. Currently, we get:
+                   # AssertionError: The supported dtypes for sin on device type cuda are incorrect!
+                   # The following dtypes did not work in backward but are listed by the OpInfo: {torch.complex32}.
+                   backward_dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16),
                    assert_autodiffed=True,
                    handles_large_floats=False,
                    supports_sparse=True,
@@ -14790,6 +14807,18 @@ def error_inputs_mean(op_info, device, **kwargs):
                                    dtypes=(torch.cfloat, torch.cdouble,), device_type='cpu', active_if=IS_WINDOWS),
                        DecorateInfo(unittest.skip("Skipped! sparse backward not supported"),
                                    'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'),
+                       # RuntimeError: "nonzero_cuda" not implemented for 'ComplexHalf'
+                       DecorateInfo(unittest.expectedFailure, 'TestSparseCSR', 'test_sparse_csr_consistency',
+                                    dtypes=(torch.chalf,)),
+                       # RuntimeError: "nonzero_cuda" not implemented for 'ComplexHalf'
+                       DecorateInfo(unittest.expectedFailure, 'TestSparseCSR', 'test_sparse_csr_unary_inplace',
+                                    dtypes=(torch.chalf,)),
+                       # RuntimeError: "nonzero_cuda" not implemented for 'ComplexHalf'
+                       DecorateInfo(unittest.expectedFailure, 'TestSparseCSR', 'test_sparse_csr_unary_out',
+                                    dtypes=(torch.chalf,)),
+                       # RuntimeError: "add_out_op2_sparse_csr" not implemented for 'ComplexHalf'
+                       DecorateInfo(unittest.expectedFailure, 'TestSparseCSR', 'test_zero_to_zero_correspondence_unary',
+                                    dtypes=(torch.chalf,)),
                    ),
                    decorators=(precisionOverride({torch.bfloat16: 1e-2}),)),
     UnaryUfuncInfo('sinc',
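
The expectedFailure decorators added for torch.chalf point at missing ComplexHalf coverage elsewhere, per the comments in the diff ("nonzero_cuda" / "add_out_op2_sparse_csr" not implemented). A small sketch of the kind of failure they mark, assuming a CUDA device:

import torch

x = torch.randn(4, dtype=torch.cfloat, device="cuda").to(torch.chalf)
try:
    # The sparse-CSR tests rely on nonzero() for chalf inputs.
    torch.nonzero(x)
except RuntimeError as err:
    print(err)  # e.g. "nonzero_cuda" not implemented for 'ComplexHalf'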

0 commit comments