[mps] Implement support for sinc() operator (inductor and eager). (#146539) · pytorch/pytorch@46390e9 · GitHub

Commit 46390e9

dcci authored and malfet committed
[mps] Implement support for sinc() operator (inductor and eager). (#146539)
Pull Request resolved: #146539
Approved by: https://github.com/malfet, https://github.com/jansel
Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
1 parent a14c780 commit 46390e9
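
With this change, torch.sinc / torch.special.sinc runs natively on MPS in eager mode and lowers through inductor, per the files below. A minimal sanity check, assuming a Mac with an available MPS device (tolerances are the assert_close defaults):

import torch

# Assumes torch.backends.mps.is_available() on this machine.
x = torch.linspace(-2.0, 2.0, 9, device="mps")
y = torch.special.sinc(x)

# Compare against the CPU kernel.
torch.testing.assert_close(y.cpu(), torch.special.sinc(x.cpu()))
print(y)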

File tree

7 files changed: +47 -2 lines changed

aten/src/ATen/native/mps/kernels/UnaryKernel.metal

Lines changed: 20 additions & 0 deletions
@@ -108,3 +108,23 @@ INSTANTIATE_UNARY_KERNELS2(float, long);
 
 INSTANTIATE_UNARY_KERNELS_VEC2(short, short);
 INSTANTIATE_UNARY_KERNELS_VEC2(float, float);
+
+template <typename T0, typename T1>
+kernel void sinc_kernel(
+    device T0* output [[buffer(0)]],
+    constant T1* input [[buffer(1)]],
+    uint index [[thread_position_in_grid]]) {
+  output[index] = T0(sinc(static_cast<float>(input[index])));
+}
+
+#define INSTANTIATE_SINC_KERNEL(DTYPE0, DTYPE1)                                \
+  template [[host_name("sinc_" #DTYPE0 "_" #DTYPE1)]] kernel void sinc_kernel( \
+      device DTYPE0* output [[buffer(0)]],                                     \
+      constant DTYPE1* input [[buffer(1)]],                                    \
+      uint id [[thread_position_in_grid]]);
+
+#if __METAL_VERSION__ >= 310
+INSTANTIATE_SINC_KERNEL(bfloat, bfloat);
+#endif
+INSTANTIATE_SINC_KERNEL(half, half);
+INSTANTIATE_SINC_KERNEL(float, float);
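
Note that the kernel body computes in float32 regardless of the instantiated I/O types (the static_cast<float> above), so half and bfloat outputs are rounded from a float32 result. A quick way to observe this, assuming an MPS device (the tolerances are illustrative):

import torch

# Half inputs should track a float32 reference closely, since the
# Metal kernel upcasts to float before evaluating sinc.
x = torch.linspace(-2.0, 2.0, 17, dtype=torch.float16, device="mps")
ref = torch.special.sinc(x.float()).half()
torch.testing.assert_close(torch.special.sinc(x), ref, atol=1e-3, rtol=1e-3)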

aten/src/ATen/native/mps/operations/UnaryKernel.mm

Lines changed: 6 additions & 0 deletions
@@ -8,6 +8,7 @@
 #else
 #include <ATen/ops/erfinv_native.h>
 #include <ATen/ops/exp_native.h>
+#include <ATen/ops/sinc_native.h>
 #include <ATen/ops/tanh_native.h>
 #endif
 
@@ -74,6 +75,11 @@ static void exec_unary_kernel(const Tensor& self, const Tensor& output_, const s
 TORCH_IMPL_FUNC(exp_out_mps)(const Tensor& self, const Tensor& output_) {
   exec_unary_kernel(self, output_, "exp");
 }
+
+TORCH_IMPL_FUNC(sinc_out_mps)(const Tensor& self, const Tensor& output_) {
+  exec_unary_kernel(self, output_, "sinc");
+}
+
 TORCH_IMPL_FUNC(tanh_out_mps)(const Tensor& self, const Tensor& output_) {
   exec_unary_kernel(self, output_, "tanh");
 }
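
Both new entry points funnel through the existing exec_unary_kernel helper with the op name "sinc". Presumably that string is combined with dtype suffixes to select one of the host_name instantiations from UnaryKernel.metal; the mapping below is an illustrative assumption, not the actual helper:

# Hypothetical sketch of the name lookup (assumption; compare the
# host_name("sinc_" #DTYPE0 "_" #DTYPE1) instantiations above):
def metal_kernel_name(op: str, out_dtype: str, in_dtype: str) -> str:
    return f"{op}_{out_dtype}_{in_dtype}"

assert metal_kernel_name("sinc", "float", "float") == "sinc_float_float"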

aten/src/ATen/native/native_functions.yaml

Lines changed: 1 addition & 0 deletions
@@ -5388,6 +5388,7 @@
   structured_inherits: TensorIteratorBase
   dispatch:
     CPU, CUDA: sinc_out
+    MPS: sinc_out_mps
   tags: pointwise
 
 - func: sinh(Tensor self) -> Tensor
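
Since the entry being extended is the structured sinc.out overload, the MPS dispatch covers both the functional and out= forms. For example (assuming an MPS device):

import torch

x = torch.rand(8, device="mps")
out = torch.empty_like(x)
torch.sinc(x, out=out)  # routed to sinc_out_mps
assert torch.equal(out, torch.sinc(x))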

c10/metal/special_math.h

Lines changed: 10 additions & 0 deletions
@@ -451,5 +451,15 @@ inline float digamma(T0 x) {
   }
 }
 
+template <typename T>
+T sinc(T a) {
+  if (a == static_cast<T>(0)) {
+    return static_cast<T>(1);
+  }
+  constexpr T pi = static_cast<T>(M_PI_F);
+  T product = pi * a;
+  return static_cast<T>(::metal::sin(product) / product);
+}
+
 } // namespace metal
 } // namespace c10
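
The helper implements the normalized sinc that torch.sinc documents: sin(pi*a) / (pi*a), with the removable singularity at a == 0 defined as 1. A plain-Python rendering of the same logic, for reference:

import math

def sinc_ref(a: float) -> float:
    # Mirrors c10::metal::sinc above: normalized sinc with the
    # a == 0 singularity patched to 1.
    if a == 0.0:
        return 1.0
    product = math.pi * a
    return math.sin(product) / product

assert sinc_ref(0.0) == 1.0
assert abs(sinc_ref(1.0)) < 1e-15  # sin(pi)/pi is ~0 in float arithmetic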

test/inductor/test_mps_basic.py

Lines changed: 3 additions & 0 deletions
@@ -104,6 +104,9 @@ def test_pointwise_polygamma(self):
     def test_pointwise_digamma(self):
         self.common(torch.special.digamma, (torch.rand(128, 128),), check_lowp=False)
 
+    def test_pointwise_sinc(self):
+        self.common(torch.special.sinc, (torch.rand(128, 128),), check_lowp=False)
+
     def test_pointwise_zeta(self):
         self.common(
             torch.special.zeta,
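
self.common compiles the callable and checks it against eager (check_lowp=False skips the reduced-precision variants). An equivalent standalone check of what the new test asserts would look roughly like this:

import torch

fn = torch.compile(torch.special.sinc)  # inductor is the default backend
x = torch.rand(128, 128, device="mps")
torch.testing.assert_close(fn(x), torch.special.sinc(x))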

test/test_mps.py

Lines changed: 3 additions & 2 deletions
@@ -487,6 +487,7 @@ def mps_ops_modifier(ops):
         'square',
         'stack',
         'stft',
+        'sinc',
         'sum',
         'sum_to_size',
         'tan',
@@ -767,7 +768,6 @@ def mps_ops_modifier(ops):
         '_segment_reduce_lengths': None,
         '_segment_reducelengths': None,
         '_segment_reduceoffsets': None,
-        'sinc': None,
         'sparse.mm': None,
         'sparse.mmreduce': None,
         'special.airy_ai': None,
@@ -842,8 +842,9 @@ def mps_ops_modifier(ops):
         'atan2': [torch.int64],
         'angle': [torch.int64],
 
-        # zeta isn't supported for integral types
+        # Operations not supported for integral types
         'special.zeta': [torch.bool, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
+        'sinc': [torch.bool, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8, torch.complex64],
 
         # GEMM on MPS is not supported for integral types
         'nn.functional.linear': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
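
The dtype lists above mark sinc as unsupported on MPS for boolean, integral, and complex64 inputs. Where integral inputs are needed, casting to a floating dtype first should sidestep the limitation (assuming an MPS device):

import torch

ix = torch.arange(-3, 4, device="mps")        # int64 on MPS
y = torch.special.sinc(ix.to(torch.float32))  # cast first, then sinc
print(y)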

torch/_inductor/codegen/mps.py

Lines changed: 4 additions & 0 deletions
@@ -230,6 +230,10 @@ def signbit(x: CSEVariable) -> str:
     def sin(x: CSEVariable) -> str:
         return f"metal::precise::sin({x})"
 
+    @staticmethod
+    def sinc(x: CSEVariable) -> str:
+        return f"c10::metal::sinc({x})"
+
     @staticmethod
     def cos(x: CSEVariable) -> str:
         return f"metal::precise::cos({x})"
