8000 [Inductor] Fallback embedding when sparse is True (#150659) · pytorch/pytorch@d6887f4 · GitHub
[go: up one dir, main page]

Skip to content

Commit d6887f4

Browse files
leslie-fang-intelpytorchmergebot
authored andcommitted
[Inductor] Fallback embedding when sparse is True (#150659)
**Summary** Fix issue: #150656, fallback `embedding` when sparse is True. **Test Plan** ``` python -u -m pytest -s -v test/inductor/test_torchinductor.py -k test_embedding_sparse ``` Pull Request resolved: #150659 Approved by: https://github.com/jansel
1 parent 2e23768 commit d6887f4

File tree

3 files changed

+19
-0
lines changed

3 files changed

+19
-0
lines changed

test/inductor/test_torchinductor.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5360,6 +5360,19 @@ def test_embedding(self):
53605360
(torch.randint(10, [2, 8]),),
53615361
)
53625362

5363+
def test_embedding_sparse(self):
5364+
# Fix https://github.com/pytorch/pytorch/issues/150656
5365+
def fn(weight, indices):
5366+
return F.embedding(indices, weight, sparse=True)
5367+
5368+
indices = torch.randint(10, (2, 3))
5369+
weight = torch.randn(10, 3, requires_grad=True)
5370+
5371+
self.common(
5372+
fn,
5373+
(weight, indices),
5374+
)
5375+
53635376
def test_mean(self):
53645377
def fn(x):
53655378
return (

test/inductor/test_torchinductor_codegen_dynamic_shapes.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ def run(*ex, **kwargs):
137137
"test_mul_index_expr_dynamic_shapes": TestFailure(("cpu",)),
138138
"test_flip_cat_dynamic_shapes": TestFailure(("cpu",)),
139139
"test_pad_single_dynamic_shapes": TestFailure(("cpu",)),
140+
"test_embedding_sparse_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
140141
#
141142
# Failed to find for loop/triton kernel:
142143
#

torch/_inductor/lowering.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3376,6 +3376,11 @@ def fn(idx):
33763376

33773377
@register_lowering(aten.embedding, type_promotion_kind=None)
33783378
def embedding(weight, indices, padding_idx=-1, scale_grad_by_freq=False, sparse=False):
3379+
if sparse:
3380+
return fallback_handler(aten.embedding.default)(
3381+
weight, indices, padding_idx, scale_grad_by_freq, sparse
3382+
)
3383+
33793384
assert not sparse
33803385
assert isinstance(weight, TensorBox)
33813386
assert isinstance(indices, TensorBox)

0 commit comments

Comments
 (0)
0