[reland][complex32, jiterator] cos, sinh, cosh, tanh (#78718) (#78718) · pytorch/pytorch@87d1361

Commit 87d1361

khushi-411 authored and facebook-github-bot committed
[reland][complex32, jiterator] cos, sinh, cosh, tanh (#78718) (#78718)
Summary:
Ref: #78458
Follows: #74537 and #74748

cc kshitij12345 anjali411 :)

Pull Request resolved: #78718
Approved by: https://github.com/anjali411, https://github.com/kshitij12345

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/5b32c34450d69d677d2e721059e381134303b3fa

Reviewed By: osalpekar
Differential Revision: D37010871
Pulled By: osalpekar
fbshipit-source-id: 1f63f299bb25140d5538be6c602bdd67f6bf8ddd
1 parent 47cb201 commit 87d1361
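
For context (an editorial note, not part of the commit message above): the user-visible effect of this change is that torch.cos, torch.sinh, torch.cosh and torch.tanh now accept complex32 (torch.chalf) tensors on CUDA. A minimal sketch of what that enables, assuming a CUDA device and a build that includes this commit:

    import torch

    # torch.chalf is an alias for torch.complex32.
    x = torch.tensor([1 + 1j, 2 - 0.5j], dtype=torch.chalf, device="cuda")
    for fn in (torch.cos, torch.sinh, torch.cosh, torch.tanh):
        y = fn(x)  # before this commit, these CUDA kernels did not dispatch on ComplexHalf
        print(fn.__name__, y.dtype)  # prints e.g. "cos torch.complex32"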

File tree

2 files changed: +156 −16 lines
  aten/src/ATen/native/cuda/UnaryGeometricKernels.cu
  torch/testing/_internal/common_methods_invocations.py

aten/src/ATen/native/cuda/UnaryGeometricKernels.cu

Lines changed: 116 additions & 8 deletions

@@ -133,48 +133,156 @@ void sin_kernel_cuda(TensorIteratorBase& iter) {
   }
 }

+const char cos_name[] = "cos";
 void cos_kernel_cuda(TensorIteratorBase& iter) {
-  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
+  auto common_dtype = iter.common_dtype();
+  if(at::isComplexType(common_dtype)) {
+#if AT_USE_JITERATOR
+    static const auto cos_string = jiterator_stringify(
+        template <typename T>
+        T cos(T a) {
+          return std::cos(a);
+        }
+    );
+    AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, common_dtype, "cos_name", [&]() {
+      jitted_gpu_kernel<
+          /*name=*/ cos_name,
+          /*return_dtype=*/ scalar_t,
+          /*common_dtype=*/ scalar_t,
+          /*arity=*/ 1>(iter, cos_string);
+    });
+#else
+    AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, common_dtype, "cos_name", [&]() {
+      gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
+        using opmath_t = at::opmath_type<scalar_t>;
+        return ::cos(static_cast<opmath_t>(a));
+      });
+    });
+#endif
+  } else {
+    AT_DISPATCH_FLOATING_TYPES_AND2(
       ScalarType::Half, ScalarType::BFloat16,
-      iter.common_dtype(), "cos_cuda",
+      common_dtype, "cos_cuda",
       [&]() {
         gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
           return ::cos(a);
         });
       });
+  }
 }

+const char sinh_name[] = "sinh";
 void sinh_kernel_cuda(TensorIteratorBase& iter) {
-  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
+  auto common_dtype = iter.common_dtype();
+  if(at::isComplexType(common_dtype)) {
+#if AT_USE_JITERATOR
+    static const auto sinh_string = jiterator_stringify(
+        template <typename T>
+        T sinh(T a) {
+          return std::sinh(a);
+        }
+    );
+    AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, common_dtype, "sinh_name", [&]() {
+      jitted_gpu_kernel<
+          /*name=*/ sinh_name,
+          /*return_dtype=*/ scalar_t,
+          /*common_dtype=*/ scalar_t,
+          /*arity=*/ 1>(iter, sinh_string);
+    });
+#else
+    AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, common_dtype, "sinh_name", [&]() {
+      gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
+        using opmath_t = at::opmath_type<scalar_t>;
+        return ::sinh(static_cast<opmath_t>(a));
+      });
+    });
+#endif
+  } else {
+    AT_DISPATCH_FLOATING_TYPES_AND2(
       ScalarType::Half, ScalarType::BFloat16,
-      iter.common_dtype(), "sinh_cuda",
+      common_dtype, "sinh_cuda",
       [&]() {
         gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
           return ::sinh(a);
         });
       });
+  }
 }

+const char cosh_name[] = "cosh";
 void cosh_kernel_cuda(TensorIteratorBase& iter) {
-  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
+  auto common_dtype = iter.common_dtype();
+  if(at::isComplexType(common_dtype)) {
+#if AT_USE_JITERATOR
+    static const auto cosh_string = jiterator_stringify(
+        template <typename T>
+        T cosh(T a) {
+          return std::cosh(a);
+        }
+    );
+    AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, common_dtype, "cosh_name", [&]() {
+      jitted_gpu_kernel<
+          /*name=*/ cosh_name,
+          /*return_dtype=*/ scalar_t,
+          /*common_dtype=*/ scalar_t,
+          /*arity=*/ 1>(iter, cosh_string);
+    });
+#else
+    AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, common_dtype, "cosh_name", [&]() {
+      gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
+        using opmath_t = at::opmath_type<scalar_t>;
+        return ::cosh(static_cast<opmath_t>(a));
+      });
+    });
+#endif
+  } else {
+    AT_DISPATCH_FLOATING_TYPES_AND2(
       ScalarType::Half, ScalarType::BFloat16,
-      iter.common_dtype(), "cosh_cuda",
+      common_dtype, "cosh_cuda",
       [&]() {
         gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
           return ::cosh(a);
         });
       });
+  }
 }

+const char tanh_name[] = "tanh";
 void tanh_kernel_cuda(TensorIteratorBase& iter) {
-  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
+  auto common_dtype = iter.common_dtype();
+  if(at::isComplexType(common_dtype)) {
+#if AT_USE_JITERATOR
+    static const auto tanh_string = jiterator_stringify(
+        template <typename T>
+        T tanh(T a) {
+          return std::tanh(a);
+        }
+    );
+    AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, common_dtype, "tanh_name", [&]() {
+      jitted_gpu_kernel<
+          /*name=*/ tanh_name,
+          /*return_dtype=*/ scalar_t,
+          /*common_dtype=*/ scalar_t,
+          /*arity=*/ 1>(iter, tanh_string);
+    });
+#else
+    AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, common_dtype, "tanh_name", [&]() {
+      gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
+        using opmath_t = at::opmath_type<scalar_t>;
+        return ::tanh(static_cast<opmath_t>(a));
+      });
+    });
+#endif
+  } else {
+    AT_DISPATCH_FLOATING_TYPES_AND2(
       ScalarType::Half, ScalarType::BFloat16,
-      iter.common_dtype(), "tanh_cuda",
+      common_dtype, "tanh_cuda",
       [&]() {
         gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
           return ::tanh(a);
         });
       });
+  }
 }

 void acosh_kernel_cuda(TensorIteratorBase& iter) {
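
A note on the structure above (an editorial gloss, not part of the diff): when the iterator's common dtype is complex, each kernel prefers a jiterator string kernel compiled at runtime (the AT_USE_JITERATOR branch) and otherwise falls back to a gpu_kernel lambda that upcasts each element to at::opmath_type<scalar_t>, i.e. complex<float> when scalar_t is complex<Half>, before evaluating. A rough Python-level sanity check of the resulting precision, assuming a CUDA build that includes this change; the tensor size and the "loose match" expectation are illustrative, not taken from the commit:

    import torch

    if torch.cuda.is_available():
        x = torch.randn(1024, dtype=torch.complex64, device="cuda")
        xh = x.to(torch.complex32)                  # half-precision complex input
        for fn in (torch.cos, torch.sinh, torch.cosh, torch.tanh):
            got = fn(xh).to(torch.complex64)        # new complex32 CUDA path
            ref = fn(x)                             # complex64 reference
            # complex32 carries only ~3 decimal digits, so expect a loose match
            print(fn.__name__, "max abs error:", (got - ref).abs().max().item())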

torch/testing/_internal/common_methods_invocations.py

Lines changed: 40 additions & 8 deletions

@@ -11168,7 +11168,7 @@ def error_inputs_mean(op_info, device, **kwargs):
     UnaryUfuncInfo('cos',
                    ref=np.cos,
                    dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
-                   dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
+                   dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16),
                    assert_autodiffed=True,
                    handles_large_floats=False,
                    supports_forward_ad=True,
@@ -11185,11 +11185,17 @@ def error_inputs_mean(op_info, device, **kwargs):
                        DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
                                     device_type='cpu',
                                     dtypes=[torch.cfloat, torch.cdouble], active_if=IS_MACOS),
+                       # AssertionError: Tensor-likes are not close!
+                       # Greatest absolute difference: nan at index (700,) (up to 1e-05 allowed)
+                       # Greatest relative difference: nan at index (700,) (up to 0.001 allowed)
+                       DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', 'test_reference_numerics_large',
+                                    device_type='cuda',
+                                    dtypes=(torch.chalf,), active_if=IS_WINDOWS),
                    )),
     UnaryUfuncInfo('cosh',
                    ref=np_unary_ufunc_integer_promotion_wrapper(np.cosh),
                    dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
-                   dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
+                   dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16),
                    assert_autodiffed=True,
                    supports_forward_ad=True,
                    supports_fwgrad_bwgrad=True,
@@ -11209,6 +11215,12 @@ def error_inputs_mean(op_info, device, **kwargs):
                        DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
                                     device_type='cpu',
                                     dtypes=[torch.cfloat, torch.cdouble], active_if=IS_MACOS),
+                       # AssertionError: Tensor-likes are not close!
+                       # Greatest absolute difference: nan at index (6000,) (up to 1e-05 allowed)
+                       # Greatest relative difference: nan at index (6000,) (up to 0.001 allowed)
+                       DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', 'test_reference_numerics_large',
+                                    device_type='cuda',
+                                    dtypes=(torch.chalf,), active_if=IS_WINDOWS),
                    )),
     OpInfo('cov',
            dtypes=all_types_and_complex_and(torch.bfloat16),
@@ -15268,10 +15280,6 @@ def error_inputs_mean(op_info, device, **kwargs):
                    ref=np.sin,
                    dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
                    dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16),
-                   # TODO: Add torch.chalf backward dtype support. Currently, we get:
-                   # AssertionError: The supported dtypes for sin on device type cuda are incorrect!
-                   # The following dtypes did not work in backward but are listed by the OpInfo: {torch.complex32}.
-                   backward_dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16),
                    assert_autodiffed=True,
                    handles_large_floats=False,
                    supports_sparse=True,
@@ -15320,7 +15328,7 @@ def error_inputs_mean(op_info, device, **kwargs):
     UnaryUfuncInfo('sinh',
                    ref=np_unary_ufunc_integer_promotion_wrapper(np.sinh),
                    dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
-                   dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
+                   dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16),
                    assert_autodiffed=True,
                    supports_forward_ad=True,
                    supports_fwgrad_bwgrad=True,
@@ -15341,6 +15349,18 @@ def error_inputs_mean(op_info, device, **kwargs):
                                     device_type='cpu', dtypes=[torch.int8]),
                        DecorateInfo(unittest.skip("Skipped! sparse backward not supported"),
                                     'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'),
+                       # RuntimeError: "nonzero_cuda" not implemented for 'ComplexHalf'
+                       DecorateInfo(unittest.expectedFailure, 'TestSparseCSR', 'test_sparse_csr_consistency',
+                                    dtypes=(torch.chalf,)),
+                       # RuntimeError: "nonzero_cuda" not implemented for 'ComplexHalf'
+                       DecorateInfo(unittest.expectedFailure, 'TestSparseCSR', 'test_sparse_csr_unary_inplace',
+                                    dtypes=(torch.chalf,)),
+                       # RuntimeError: "nonzero_cuda" not implemented for 'ComplexHalf'
+                       DecorateInfo(unittest.expectedFailure, 'TestSparseCSR', 'test_sparse_csr_unary_out',
+                                    dtypes=(torch.chalf,)),
+                       # RuntimeError: "add_out_op2_sparse_csr" not implemented for 'ComplexHalf'
+                       DecorateInfo(unittest.expectedFailure, 'TestSparseCSR', 'test_zero_to_zero_correspondence_unary',
+                                    dtypes=(torch.chalf,)),
                    )),
     UnaryUfuncInfo('sign',
                    ref=reference_sign,
@@ -15668,7 +15688,7 @@ def error_inputs_mean(op_info, device, **kwargs):
                    aliases=('nn.functional.tanh',),
                    decorators=(precisionOverride({torch.bfloat16: 1e-2}),),
                    dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
-                   dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
+                   dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16),
                    assert_autodiffed=True,
                    assert_jit_shape_analysis=True,
                    supports_forward_ad=True,
@@ -15687,6 +15707,18 @@ def error_inputs_mean(op_info, device, **kwargs):
                        DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_jit_alias_remapping'),
                        DecorateInfo(unittest.skip("Skipped! sparse backward not supported"),
                                     'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'),
+                       # RuntimeError: "nonzero_cuda" not implemented for 'ComplexHalf'
+                       DecorateInfo(unittest.expectedFailure, 'TestSparseCSR', 'test_sparse_csr_consistency',
+                                    dtypes=(torch.chalf,)),
+                       # RuntimeError: "nonzero_cuda" not implemented for 'ComplexHalf'
+                       DecorateInfo(unittest.expectedFailure, 'TestSparseCSR', 'test_sparse_csr_unary_inplace',
+                                    dtypes=(torch.chalf,)),
+                       # RuntimeError: "nonzero_cuda" not implemented for 'ComplexHalf'
+                       DecorateInfo(unittest.expectedFailure, 'TestSparseCSR', 'test_sparse_csr_unary_out',
+                                    dtypes=(torch.chalf,)),
+                       # RuntimeError: "add_out_op2_sparse_csr" not implemented for 'ComplexHalf'
+                       DecorateInfo(unittest.expectedFailure, 'TestSparseCSR', 'test_zero_to_zero_correspondence_unary',
+                                    dtypes=(torch.chalf,)),
                    ),
                    # tan(j * pi/2 * odd_number) is nan
                    reference_numerics_filter=NumericsFilter(