8000 add type promotion wrapper · pytorch/pytorch@3cd82ab · GitHub
[go: up one dir, main page]

Skip to content

Commit 3cd82ab

Browse files
committed
add type promotion wrapper
1 parent 76bfc80 commit 3cd82ab<
8000
div>

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

torch/_refs/nn/functional/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -516,6 +516,10 @@ def _nll_loss_nd(
516516

517517

518518
@register_decomposition(torch.ops.aten.nll_loss)
519+
@elementwise_type_promotion_wrapper(
520+
type_promoting_args=("input",),
521+
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
522+
)
519523
@out_wrapper()
520524
def nll_loss(
521525
input: TensorLikeType,

0 commit comments

Comments
 (0)
0