8000 [MPS] Add support for `i1e` (#149203) · pytorch/pytorch@f2221b2 · GitHub
[go: up one dir, main page]

Skip to content

Commit f2221b2

Browse files
malfetpytorchmergebot
authored andcommitted
[MPS] Add support for i1e (#149203)
Followup after #149174 Pull Request resolved: #149203 Approved by: https://github.com/dcci
1 parent f067eaf commit f2221b2

File tree

5 files changed

+57
-3
lines changed

5 files changed

+57
-3
lines changed

aten/src/ATen/native/mps/kernels/SpecialOps.metal

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ DEFINE_UNARY_FLOATING_FUNCTOR(bessel_j1_forward);
88
DEFINE_UNARY_FLOATING_FUNCTOR(i0);
99
DEFINE_UNARY_FLOATING_FUNCTOR(i0e);
1010
DEFINE_UNARY_FLOATING_FUNCTOR(i1);
11+
DEFINE_UNARY_FLOATING_FUNCTOR(i1e);
1112
DEFINE_UNARY_FLOATING_FUNCTOR(spherical_bessel_j0);
1213
DEFINE_UNARY_FLOATING_FUNCTOR(entr);
1314

@@ -51,6 +52,7 @@ struct bessel_y1_forward_functor {
5152
REGISTER_UNARY_OP(i0, DTI, DTO); \
5253
REGISTER_UNARY_OP(i0e, DTI, DTO); \
5354
REGISTER_UNARY_OP(i1, DTI, DTO); \
55+
REGISTER_UNARY_OP(i1e, DTI, DTO); \
5456
REGISTER_UNARY_OP(spherical_bessel_j0, DTI, DTO); \
5557
REGISTER_UNARY_OP(entr, DTI, DTO)
5658

aten/src/ATen/native/mps/operations/SpecialOps.mm

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,10 @@ static void i1_kernel_mps(TensorIteratorBase& iter) {
2424
lib.exec_unary_kernel(iter, "i1");
2525
}
2626

27+
static void i1e_kernel_mps(TensorIteratorBase& iter) {
28+
lib.exec_unary_kernel(iter, "i1e");
29+
}
30+
2731
static void spherical_bessel_j0_kernel_mps(TensorIteratorBase& iter) {
2832
lib.exec_unary_kernel(iter, "spherical_bessel_j0");
2933
}
@@ -51,6 +55,7 @@ static void bessel_y1_kernel_mps(TensorIteratorBase& iter) {
5155
REGISTER_DISPATCH(i0_stub, &i0_kernel_mps)
5256
REGISTER_DISPATCH(special_i0e_stub, &i0e_kernel_mps)
5357
REGISTER_DISPATCH(special_i1_stub, &i1_kernel_mps)
58+
REGISTER_DISPATCH(special_i1e_stub, &i1e_kernel_mps)
5459
REGISTER_DISPATCH(special_bessel_j0_stub, &bessel_j0_kernel_mps)
5560
REGISTER_DISPATCH(special_bessel_j1_stub, &bessel_j1_kernel_mps)
5661
REGISTER_DISPATCH(special_bessel_y0_stub, &bessel_y0_kernel_mps)

aten/src/ATen/native/native_functions.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13528,7 +13528,7 @@
1352813528
structured: True
1352913529
structured_inherits: TensorIteratorBase
1353013530
dispatch:
13531-
CPU, CUDA: special_i1e_out
13531+
CPU, CUDA, MPS: special_i1e_out
1353213532
tags: pointwise
1353313533

1353413534
- func: special_logit(Tensor self, float? eps=None) -> Tensor

c10/metal/special_math.h

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,54 @@ inline T i1(T _x) {
241241
return static_cast<T>(_x < T(0.) ? -out : out);
242242
}
243243

244+
template <typename T>
245+
T i1e(T _x) {
246+
const auto x = ::metal::fabs(_x);
247+
if (x <= 8.0) {
248+
// Chebyshev double coefficients for exp(-x) i1(x) in the interval [0,8].
249+
// Note: lim(x->0){ exp(-x) i1(x) / x } = 1/2.
250+
constexpr float coefficients[] = {
251+
9.38153738649577178388E-9f,
252+
-4.44505912879632808065E-8f,
253+
2.00329475355213526229E-7f,
254+
-8.56872026469545474066E-7f,
255+
3.47025130813767847674E-6f,
256+
-1.32731636560394358279E-5f,
257+
4.78156510755005422638E-5f,
258+
-1.61760815825896745588E-4f,
259+
5.12285956168575772895E-4f, 10000
260+
-1.51357245063125314899E-3f,
261+
4.15642294431288815669E-3f,
262+
-1.05640848946261981558E-2f,
263+
2.47264490306265168283E-2f,
264+
-5.29459812080949914269E-2f,
265+
1.02643658689847095384E-1f,
266+
-1.76416518357834055153E-1f,
267+
2.52587186443633654823E-1f};
268+
const auto y = x / 2.0 - 2.0;
269+
const auto out = chbevl(y, coefficients, 17) * x;
270+
return static_cast<T>(_x < 0. ? -out : out);
271+
}
272+
273+
// Chebyshev coefficients for exp(-x) sqrt(x) i1(x)
274+
// in the inverted interval (8, infinity].
275+
// Note: lim(x->inf){ exp(-x) sqrt(x) i1(x) } = 1/sqrt(2pi).
276+
// TODO: what's an "inverted interval"? Open on the left
277+
// and closed on the right?
278+
constexpr float coefficients[] = {
279+
-3.83538038596423702205E-9f,
280+
-2.63146884688951950684E-8f,
281+
-2.51223623787020892529E-7f,
282+
-3.88256480887769039346E-6f,
283+
-1.10588938762623716291E-4f,
284+
-9.76109749136146840777E-3f,
285+
7.78576235018280120474E-1f};
286+
287+
const auto out =
288+
chbevl(32. / x - 2., coefficients, 7) / ::metal::precise::sqrt(x);
289+
return static_cast<T>(_x < 0. ? -out : out);
290+
}
291+
244292
// gamma, lgamma
245293
template <typename T>
246294
inline float log_gamma(const T);

test/test_mps.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def mps_ops_grad_modifier(ops):
9797
'logdet': [torch.float16, torch.float32], # missing aten::lu_solve.out
9898
'aminmax': [torch.float32, torch.float16],
9999
'special.i1': [torch.float16], # "i1_backward" not implemented for 'Half'
100-
'special.i0e': None, # "special_i1e" not implemented
100+
'special.i1e': [torch.float16], # "i1e_backward" not implemented for 'Half'
101101

102102
# Correctness issues
103103
'atanh': [torch.float32],
@@ -651,7 +651,6 @@ def mps_ops_modifier(ops):
651651
'special.erfcx': None,
652652
'special.hermite_polynomial_h': None,
653653
'special.hermite_polynomial_he': None,
654-
'special.i1e': None,
655654
'special.laguerre_polynomial_l': None,
656655
'special.log_ndtr': None,
657656
'special.modified_bessel_i0': None,

0 commit comments

Comments
 (0)
0