8000 [jiterator] neg: complex · pytorch/pytorch@9855f12 · GitHub < 8000 /head>
[go: up one dir, main page]

Skip to content

Commit 9855f12

Browse files
khushi-411pytorchmergebot
authored andcommitted
[jiterator] neg: complex
Follows: #74748 cc @kshitij12345! Pull Request resolved: #75123 Approved by: https://github.com/kshitij12345, https://github.com/anjali411
1 parent caa28ff commit 9855f12

File tree

1 file changed

+27
-1
lines changed

1 file changed

+27
-1
lines changed

aten/src/ATen/native/ 8000 cuda/UnarySignKernels.cu

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,38 @@ void logical_not_kernel_cuda(TensorIteratorBase& iter) {
2424
}
2525

2626
// NB: Ignores the negative bit on tensors
27+
const char neg_name[] = "neg_kernel";
2728
void neg_kernel_cuda(TensorIteratorBase& iter) {
28-
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "neg_cuda", [&]() {
29+
auto dtype = iter.dtype();
30+
if (at::isComplexType(dtype)) {
31+
#if AT_USE_JITERATOR()
32+
static const auto neg_string = jiterator_stringify(
33+
template <typename T>
34+
T neg_kernel(T a) {
35+
return -a;
36+
}
37+
); // neg_string
38+
AT_DISPATCH_COMPLEX_TYPES(dtype, "neg_cuda", [&]() {
39+
jitted_gpu_kernel<
40+
/*name=*/ neg_name,
41+
/*return_dtype=*/ scalar_t,
42+
/*common_dtype=*/ scalar_t,
43+
/*arity=*/ 1>(iter, neg_string);
44+
});
45+
#else
46+
AT_DISPATCH_COMPLEX_TYPES(dtype, "neg_cuda", [&]() {
47+
gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
48+
return -a;
49+
});
50+
});
51+
#endif
52+
} else {
53+
AT_DISPATCH_ALL_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, dtype, "neg_cuda", [&]() {
2954
gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
3055
return -a;
3156
});
3257
});
58+
}
3359
}
3460

3561
void sign_kernel_cuda(TensorIteratorBase& iter){

0 commit comments

Comments
 (0)
0