Address comments - rnd 1 · pytorch/pytorch@44702b8 · GitHub

Commit 44702b8

Address comments - rnd 1
1 parent e7a3ae4 commit 44702b8

File tree

1 file changed: +45 -36 lines changed

torch/_refs/nn/functional/__init__.py

Lines changed: 45 additions & 36 deletions
@@ -427,33 +427,35 @@ def _nll_loss_nd(
     reduction: str,
     ignore_index: int,
 ) -> TensorLikeType:
-    if input.ndim == 3 or input.ndim > 4:
-        msg = "Expected input dimension to be either [1, 2, 4] but recieved {}."
-        raise ValueError(msg.format(input.ndim))
+    utils.check(
+        input.ndim < 4 and input.ndim != 3,
+        lambda: f"Expected input dimension to be either [1, 2, 4] but recieved {input.ndim}.",
+    )
 
-    if input.ndim != 1 and input.shape[0] != target.shape[0]:
-        msg = "Expected input batch size ({}) to match target batch size ({})."
-        raise ValueError(msg.format(input.shape[0], target.shape[0]))
+    utils.check(
+        (input.ndim == 1) or (input.shape[0] == target.shape[0]),
+        lambda: f"Expected input batch size {input.shape[0]} to match target batch size {target.shape[0]}.",
+    )
 
     _check_reduction_value(reduction)
 
     flat_target = torch.reshape(target, [-1])
     ignore_classes_mask = torch.eq(flat_target, ignore_index)
     ignore_class_weight = torch.scalar_tensor(0, dtype=input.dtype, device=input.device)
-    default_class_weight = torch.scalar_tensor(
-        1, dtype=input.dtype, device=input.device
-    )
 
     # TODO: This check does not work with FakeTensor inputs
     """
     num_classes = input.shape[1] if input.ndim > 1 else input.shape[0]
     valid_classes_mask = torch.logical_and((flat_target >= 0), (flat_target < num_classes))
     if not torch.all(torch.logical_or(ignore_classes_mask, valid_classes_mask)):
-        print(target, num_classes, ignore_index)
         raise ValueError("Target class is out-of-bounds and not ignore index")
     """
 
+    # TODO Add comment for expansion
     if weight is None:
+        default_class_weight = torch.scalar_tensor(
+            1, dtype=input.dtype, device=input.device
+        )
         current_weight = torch.where(
             ignore_classes_mask,
             ignore_class_weight,
@@ -466,6 +468,7 @@ def _nll_loss_nd(
             weight[flat_target],
         )
 
+    # TODO Add comments for each case
     if input.ndim == 1:
         loss = -input[target] * current_weight
     elif input.ndim == 2:
@@ -488,6 +491,7 @@ def _nll_loss_nd(
     elif reduction == "sum":
         return torch.sum(loss)
     else:
+        # TODO Add comments "mean" reduction case
         return torch.sum(loss) / torch.sum(current_weight)
 
 
@@ -506,34 +510,39 @@ def nll_loss(
         # msg = "size_average and reduce args are deprecated, please use reduction argument."
         reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce)
 
-    if input.ndim == 3 or input.ndim > 4:
-        # input ndim is == 3 or > 4
-        batch_size = input.shape[0]
-        num_classes = input.shape[1]
-        out_size = [batch_size] + list(input.shape[2:])
-
-        if target.shape[1:] != input.shape[2:]:
-            msg = "Expected target size {} but got {}"
-            raise ValueError(msg.format(out_size, target.shape))
-
-        # support empty batches, see #15870
-        if input.numel() > 0:
-            input = torch.reshape(input, [batch_size, num_classes, 1, -1])
-        else:
-            input = torch.reshape(input, [batch_size, num_classes, 0, 0])
-
-        if target.numel() > 0:
-            target = torch.reshape(target, [batch_size, 1, -1])
-        else:
-            target = torch.reshape(target, [batch_size, 0, 0])
-
-        if reduction == "none":
-            return _nll_loss_nd(input, target, weight, reduction, ignore_index)
-        else:
-            result = _nll_loss_nd(input, target, weight, reduction, ignore_index)
-            return torch.reshape(result, out_size)
+    # TODO Can input be zero or one dimension? If so, how do we interpret that?
+    # The documentation for suggests that input should have at least two dimensions.
+    # Why are inputs with three or four dimensions special?
+    if input.ndim < 4 and input.ndim != 3:
+        return _nll_loss_nd(input, target, weight, reduction, ignore_index)
+
+    # TODO Add comment for this case
+    # input.ndim == 3 or input.ndim > 4
+    batch_size = input.shape[0]
+    num_classes = input.shape[1]
+    out_size = [batch_size] + list(input.shape[2:])
+
+    utils.check(
+        target.shape[1:] != input.shape[2:],
+        lambda: f"Expected target shape {out_size} but got {target.shape}",
+    )
+
+    # support empty batches, see #15870
+    if input.numel() > 0:
+        input = torch.reshape(input, [batch_size, num_classes, 1, -1])
+    else:
+        input = torch.reshape(input, [batch_size, num_classes, 0, 0])
+
+    if target.numel() > 0:
+        target = torch.reshape(target, [batch_size, 1, -1])
     else:
+        target = torch.reshape(target, [batch_size, 0, 0])
+
+    if reduction == "none":
         return _nll_loss_nd(input, target, weight, reduction, ignore_index)
+    else:
+        result = _nll_loss_nd(input, target, weight, reduction, ignore_index)
+        return torch.reshape(result, out_size)
 
 
 # tanhshrink does not use _make_elementwise_unary_reference because it does not support out
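
Both hunks in _nll_loss_nd swap explicit "if ...: raise ValueError(...)" blocks for utils.check(condition, lambda: message) calls, where the condition now states the valid case and the message is built only when the check fails. The sketch below illustrates that pattern with a stand-in helper; the check function here is written for illustration and is not the actual utils.check from the PyTorch codebase, and the example tensors are made up.

import torch
from typing import Callable


def check(condition: bool, message_fn: Callable[[], str]) -> None:
    # The lambda defers formatting: message_fn runs only when the condition
    # is False, so passing checks pay no string-building cost.
    if not condition:
        raise ValueError(message_fn())


# Usage mirroring the batch-size check added in _nll_loss_nd
# (shapes are hypothetical).
input = torch.randn(8, 5)
target = torch.randint(0, 5, (8,))
check(
    (input.ndim == 1) or (input.shape[0] == target.shape[0]),
    lambda: f"Expected input batch size {input.shape[0]} to match target batch size {target.shape[0]}.",
)

Note that the conditions are inverted relative to the old code: the raise-style checks named the bad case, while the check-style calls assert the good case.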

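The rewritten tail of nll_loss funnels higher-dimensional inputs through a canonical 4-D layout: the trailing spatial dimensions are collapsed into one, the loss is computed, and for reduction="none" the result is reshaped back to out_size. The following shape walk-through reproduces that collapse-and-restore step against the public torch.nn.functional.nll_loss rather than the private _nll_loss_nd; the tensor shapes are hypothetical.

import torch
import torch.nn.functional as F

# Hypothetical 5-D input (N, C, d1, d2, d3): the ndim > 4 case.
input = torch.log_softmax(torch.randn(2, 5, 3, 4, 6), dim=1)
target = torch.randint(0, 5, (2, 3, 4, 6))                          # (N, d1, d2, d3)

batch_size, num_classes = input.shape[0], input.shape[1]
out_size = [batch_size] + list(input.shape[2:])                     # [2, 3, 4, 6]

# Collapse the trailing spatial dims so the 4-D code path applies.
input_4d = torch.reshape(input, [batch_size, num_classes, 1, -1])   # (2, 5, 1, 72)
target_3d = torch.reshape(target, [batch_size, 1, -1])              # (2, 1, 72)

# reduction="none" returns a per-element loss in the collapsed layout ...
loss = F.nll_loss(input_4d, target_3d, reduction="none")            # (2, 1, 72)
# ... which is then reshaped back to the caller-visible spatial shape.
loss = torch.reshape(loss, out_size)                                # (2, 3, 4, 6)

For "mean" and "sum" the result is already a scalar, which is why the diff reshapes only in the reduction == "none" case.
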
0 commit comments
