8000 [MPS] Add inductor support for spherical_bessel_j0. (#147650) · pytorch/pytorch@6a5e391 · GitHub
[go: up one dir, main page]

Skip to content

Commit 6a5e391

Browse files
dccipytorchmergebot
authored andcommitted
[MPS] Add inductor support for spherical_bessel_j0. (#147650)
Counterpart to my previous patch that added support for the op in eager. Pull Request resolved: #147650 Approved by: https://github.com/jansel
1 parent f9c117f commit 6a5e391

File tree

2 files changed

+9
-0
lines changed

2 files changed

+9
-0
lines changed

test/inductor/test_mps_basic.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,11 @@ def test_pointwise_zeta(self):
114114
check_lowp=False,
115115
)
116116

117+
def test_pointwise_spherical_bessel_j0(self):
118+
self.common(
119+
torch.special.spherical_bessel_j0, (torch.rand(128, 128),), check_lowp=False
120+
)
121+
117122
def test_broadcast(self):
118123
self.common(torch.add, (torch.rand(32, 1024), torch.rand(1024)))
119124

torch/_inductor/codegen/mps.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,10 @@ def pow(a: CSEVariable, b: CSEVariable) -> str:
363363
def zeta(a: CSEVariable, b: CSEVariable) -> str:
364364
return f"c10::metal::zeta({a}, {b})"
365365

366+
@staticmethod
367+
def spherical_bessel_j0(x: CSEVariable) -> str:
368+
return f"c10::metal::spherical_bessel_j0({x})"
369+
366370

367371
MetalOverrides._initialize_pointwise_overrides("mps")
368372

0 commit comments

Comments
 (0)
0