10000 fixup · pytorch/pytorch@c71d746 · GitHub
[go: up one dir, main page]

Skip to content

Commit c71d746

Browse files
committed
fixup
1 parent 44702b8 commit c71d746

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

torch/_refs/nn/functional/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -428,7 +428,7 @@ def _nll_loss_nd(
428428
ignore_index: int,
429429
) -> TensorLikeType:
430430
utils.check(
431-
input.ndim < 4 and input.ndim != 3,
431+
input.ndim <= 4 and input.ndim != 3,
432432
lambda: f"Expected input dimension to be either [1, 2, 4] but recieved {input.ndim}.",
433433
)
434434

@@ -513,7 +513,7 @@ def nll_loss(
513513
# TODO Can input be zero or one dimension? If so, how do we interpret that?
514514
# The documentation for suggests that input should have at least two dimensions.
515515
# Why are inputs with three or four dimensions special?
516-
if input.ndim < 4 and input.ndim != 3:
516+
if input.ndim <= 4 and input.ndim != 3:
517517
return _nll_loss_nd(input, target, weight, reduction, ignore_index)
518518

519519
# TODO Add comment for this case
@@ -523,7 +523,7 @@ def nll_loss(
523523
out_size = [batch_size] + list(input.shape[2:])
524524

525525
utils.check(
526-
target.shape[1:] != input.shape[2:],
526+
target.shape[1:] == input.shape[2:],
527527
lambda: f"Expected target shape {out_size} but got {target.shape}",
528528
)
529529

0 commit comments

Comments
 (0)
0