[jiterator] neg: complex (#75123) · pytorch/pytorch@8ba4463 · GitHub
[go: up one dir, main page]

Skip to content

Commit 8ba4463

Browse files
khushi-411 authored and facebook-github-bot committed
[jiterator] neg: complex (#75123)
Summary: Follows: #74748
cc kshitij12345!

Pull Request resolved: #75123
Approved by: https://github.com/kshitij12345, https://github.com/anjali411

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/9855f1271c33a394be299c0fe131b24753268b7b

Reviewed By: b0noI

Differential Revision: D35550114

fbshipit-source-id: eafbcba8b5f119e9c1fa2b2ecd0a1eb4316c1a17
1 parent ccb8d25 commit 8ba4463

File tree

1 file changed

+27
-1
lines changed

1 file changed

+27
-1
lines changed

aten/src/ATen/native/cuda/UnarySignKernels.cu

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,38 @@ void logical_not_kernel_cuda(TensorIteratorBase& iter) {
2424
}
2525

2626
// NB: Ignores the negative bit on tensors
27+
const char neg_name[] = "neg_kernel";
2728
void neg_kernel_cuda(TensorIteratorBase& iter) {
28-
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "neg_cuda", [&]() {
29+
auto dtype = iter.dtype();
30+
if (at::isComplexType(dtype)) {
31+
#if AT_USE_JITERATOR()
32+
static const auto neg_string = jiterator_stringify(
33+
template <typename T>
34+
T neg_kernel(T a) {
35+
return -a;
36+
}
37+
); // neg_string
38+
AT_DISPATCH_COMPLEX_TYPES(dtype, "neg_cuda", [&]() {
39+
jitted_gpu_kernel<
40+
/*name=*/ neg_name,
41+
/*return_dtype=*/ scalar_t,
42+
/*common_dtype=*/ scalar_t,
43+
/*arity=*/ 1>(iter, neg_string);
44+
});
45+
#else
46+
AT_DISPATCH_COMPLEX_TYPES(dtype, "neg_cuda", [&]() {
47+
gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
48+
return -a;
49+
});
50+
});
51+
#endif
52+
} else {
53+
AT_DISPATCH_ALL_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, dtype, "neg_cuda", [&]() {
2954
gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
3055
return -a;
3156
});
3257
});
58+
}
3359
}
3460

3561
void sign_kernel_cuda(TensorIteratorBase& iter){

0 commit comments

Comments
 (0)
0