move the fallback logic to inductor · pytorch/pytorch@4023e26

Commit 4023e26

jianyizh authored and pytorchmergebot committed

move the fallback logic to inductor

1 parent 3df9506 commit 4023e26

File tree

3 files changed: +22 -3 lines changed

torch/_decomp/decompositions.py

Lines changed: 0 additions & 2 deletions

@@ -1236,8 +1236,6 @@ def embedding_dense_backward(
     padding_idx: int,
     scale_grad_by_freq: bool,
 ):
-    if grad_output.is_xpu:
-        return NotImplemented
     computation_dtype, result_dtype = utils.elementwise_dtypes(
         grad_output, type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
     )
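The guard removed here is the XPU early-return that previously lived in the shared decomposition; it reappears below in inductor's own table. For context, the core of this decomposition is a scatter-accumulate of output gradients into the weight gradient. The following is a minimal sketch of that computation, not PyTorch's actual implementation: embedding_dense_backward_sketch is a hypothetical helper, and it ignores the padding_idx and scale_grad_by_freq arguments.

import torch

def embedding_dense_backward_sketch(
    grad_output: torch.Tensor,  # shape: indices.shape + (embedding_dim,)
    indices: torch.Tensor,
    num_weights: int,
) -> torch.Tensor:
    grad_weight = grad_output.new_zeros((num_weights, grad_output.shape[-1]))
    # Rows of grad_output are accumulated into the rows of grad_weight
    # selected by indices; accumulate=True handles repeated indices.
    return grad_weight.index_put_(
        (indices.reshape(-1),),
        grad_output.reshape(-1, grad_output.shape[-1]),
        accumulate=True,
    )

# Illustrative check against autograd (padding_idx/scale_grad_by_freq unused):
emb = torch.nn.Embedding(10, 3)
idx = torch.tensor([0, 2, 2, 5])
emb(idx).sum().backward()
assert torch.allclose(
    emb.weight.grad, embedding_dense_backward_sketch(torch.ones(4, 3), idx, 10)
)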

torch/_inductor/decomposition.py

Lines changed: 19 additions & 0 deletions

@@ -20,6 +20,7 @@
 from torch._decomp.decompositions import (
     _grid_sampler_2d as decomp_grid_sampler_2d,
     _index_add,
+    embedding_dense_backward as decomp_embedding_dense_backward,
     pw_cast_for_opmath,
 )
 from torch._decomp.decompositions_for_rng import extra_random_decomps

@@ -110,6 +111,7 @@
     aten._softmax_backward_data,
     aten.clamp_max,
     aten.clamp_min,
+    aten.embedding_dense_backward,  # we fall back on xpu
     aten.index_add,  # we conditionally call this decomp
     aten.glu,  # inductor lowers this directly
     aten.select_scatter,  # need to be in the ATen graph in order for it to work with the re-inplacing pass

@@ -133,6 +135,23 @@ def register_decomposition(
     return decomp.register_decomposition(ops, decompositions)


+@register_decomposition([aten.embedding_dense_backward])
+def _embedding_dense_backward(
+    grad_output: torch.Tensor,
+    indices: torch.Tensor,
+    num_weights: int,
+    padding_idx: int,
+    scale_grad_by_freq: bool,
+):
+    if grad_output.is_xpu:
+        return NotImplemented
+    # decomp_func = decompositions.pop(op.overloads()[0], None)
+    # We can write a util function to update decomp table if we have more ops to fallback.
+    return decomp_embedding_dense_backward(
+        grad_output, indices, num_weights, padding_idx, scale_grad_by_freq
+    )
+
+
 # TODO: for now, inductor doesn't handle asserts
 # because the condition is symbol -> tensor in the graph.
 @register_decomposition([aten._assert_async.msg])
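Returning NotImplemented from a registered decomposition is the existing signal that the decomposition declines and the original ATen op should stay in the graph; the new inductor registration uses it to keep aten.embedding_dense_backward intact on XPU while delegating to the shared decomposition everywhere else. Below is a minimal sketch of how a decomposition table is consulted during tracing, using the existing make_fx and get_decompositions helpers; f, the tensor shapes, and the argument values are illustrative.

import torch
from torch._decomp import get_decompositions
from torch.fx.experimental.proxy_tensor import make_fx

aten = torch.ops.aten
decomp_table = get_decompositions([aten.embedding_dense_backward])

def f(grad_output, indices):
    # num_weights=10, padding_idx=-1, scale_grad_by_freq=False
    return aten.embedding_dense_backward(grad_output, indices, 10, -1, False)

grad = torch.randn(4, 3)
idx = torch.tensor([0, 2, 2, 5])
# On CPU the decomposition applies, so the traced graph contains the
# decomposed ops; a decomposition that returns NotImplemented (as the
# inductor one above does on XPU) leaves the original op in the graph.
gm = make_fx(f, decomposition_table=decomp_table)(grad, idx)
print(gm.graph)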

torch/_inductor/lowering.py

Lines changed: 3 additions & 1 deletion

@@ -2612,7 +2612,9 @@ def is_aligned(x):
 make_fallback(aten._pdist_forward)  # Has decomp. Needs benchmarks
 make_fallback(aten.soft_margin_loss_backward, warn=False)  # py_impl?
 make_fallback(aten._fused_rms_norm, warn=False)  # (MPS-only and faster than decomp)
-make_fallback(aten.embedding_dense_backward, warn=False)  # (XPU-only and faster than decomp)
+make_fallback(
+    aten.embedding_dense_backward, warn=False
+)  # (XPU-only and faster than decomp)


 # 1.5) Easy or Impossible
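Together with the inductor decomposition above, this make_fallback registration gives inductor an eager lowering to land on when the decomposition returns NotImplemented: non-XPU inputs get decomposed, XPU inputs run the native kernel. A rough, hypothetical end-to-end sketch of code that would exercise this path through the backward pass (assuming a working torch.compile/inductor setup):

import torch

emb = torch.nn.Embedding(10, 3)

@torch.compile  # inductor is the default backend
def loss(x):
    return emb(x).sum()

# The backward graph contains aten.embedding_dense_backward; on CPU/CUDA
# it is decomposed, while on an XPU device the decomposition returns
# NotImplemented and inductor falls back to the eager kernel.
loss(torch.tensor([1, 2, 4])).backward()
print(emb.weight.grad.shape)  # torch.Size([10, 3])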
