8000 [mps] Add a shader for spherical_bessel_j0. (#146771) · pytorch/pytorch@91c4bf3 · GitHub
[go: up one dir, main page]

Skip to content

Commit 91c4bf3

Browse files
dccipytorchmergebot
authored andcommitted
[mps] Add a shader for spherical_bessel_j0. (#146771)
In preparation for adding the operation to inductor/eager. Adapted from the CUDA version of the shader. Pull Request resolved: #146771 Approved by: https://github.com/malfet
1 parent 0e83e7d commit 91c4bf3

File tree

1 file changed

+26
-0
lines changed

1 file changed

+26
-0
lines changed

c10/metal/special_math.h

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -477,5 +477,31 @@ inline float2 sinc(float2 inp) {
477477
return float2(re, im) / a2;
478478
}
479479

480+
template <typename T>
481+
inline T spherical_bessel_j0(T x) {
482+
if (::metal::isinf(x))
483+
return T(0.0);
484+
T x2 = x * x;
485+
T k1 = static_cast<T>(-1.0);
486+
T k2 = static_cast<T>(1.0);
487+
488+
if (::metal::abs(x) < T(0.5)) {
489+
return T(1.0) +
490+
x2 *
491+
(k1 / T(6.0) +
492+
x2 *
493+
(k2 / T(120.0) +
494+
x2 *
495+
(k1 / T(5040.0) +
496+
x2 *
497+
(k2 / T(362880.0) +
498+
x2 *
499+
(k1 / T(39916800.0) +
500+
x2 * (k2 / T(6227020800.0)))))));
501+
}
502+
503+
return ::metal::sin(x) / x;
504+
}
505+
480506
} // namespace metal
481507
} // namespace c10

0 commit comments

Comments
 (0)
0