fallback embedding bwd on xpu only · pytorch/pytorch@6bd1e22 · GitHub
[go: up one dir, main page]

Skip to content

Commit 6bd1e22

Browse files
committed
fallback embedding bwd on xpu only
1 parent a756c50 commit 6bd1e22

File tree

4 files changed

+18
-1
lines changed

4 files changed

+18
-1
lines changed

torch/_decomp/decompositions.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1236,6 +1236,8 @@ def embedding_dense_backward(
12361236
padding_idx: int,
12371237
scale_grad_by_freq: bool,
12381238
):
1239+
if grad_output.is_xpu:
1240+
return NotImplemented
12391241
computation_dtype, result_dtype = utils.elementwise_dtypes(
12401242
grad_output, type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
12411243
)
@@ -1474,7 +1476,7 @@ def _addmm_activation(
14741476
):
14751477
out = addmm(self, mat1, mat2, beta, alpha)
14761478
if use_gelu:
1477-
if self.is_cuda:
1479+
if self.is_cuda or self.is_xpu:
14781480
return aten.gelu(out, approximate="tanh")
14791481
else:
14801482
return aten.gelu(out)

torch/_inductor/lowering.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2814,6 +2814,7 @@ def is_aligned(x):
28142814

28152815
# index_reduce requires fallback when use_scatter_fallback(...) returns True
28162816
make_fallback(aten.index_reduce)
2817+
make_fallback(aten.embedding_dense_backward)
28172818

28182819

28192820
# Register with type_promotion_kind None.

torch/_meta_registrations.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7101,6 +7101,18 @@ def _check_for_unsupported_isin_dtype(dtype):
71017101
)
71027102

71037103

7104+
@register_meta(aten.embedding_dense_backward)
def meta_embedding_dense_backward(
    grad_output,
    indices,
    num_weights,
    padding_idx,
    scale_grad_by_freq,
):
    """Meta (shape-only) kernel for ``aten.embedding_dense_backward``.

    Produces an uninitialized weight-gradient tensor of shape
    ``(num_weights, embedding_dim)``; ``new_empty`` carries over the dtype
    and device of ``grad_output``. ``indices``, ``padding_idx`` and
    ``scale_grad_by_freq`` do not affect the output shape, so they are
    accepted but unused here.
    """
    # Last dimension of the incoming gradient is the embedding width.
    embedding_dim = grad_output.size(-1)
    return grad_output.new_empty((num_weights, embedding_dim))
7114+
7115+
71047116
@register_meta(aten._embedding_bag_backward)
71057117
def meta_embedding_bag_backward(
71067118
grad,

torch/_prims_common/wrappers.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,8 @@ def maybe_check_copy_devices(out):
307307
result = fn(*args, is_out=(out is not None), **kwargs) # type: ignore[arg-type]
308308
else:
309309
result = fn(*args, **kwargs)
310+
if result is NotImplemented:
311+
return NotImplemented
310312
assert (
311313
(isinstance(result, TensorLike) and is_tensor)
312314
or (

0 commit comments

Comments (0)