diff --git a/torch/_refs/nn/functional/__init__.py b/torch/_refs/nn/functional/__init__.py index bd146e96c494..58d3c82445ff 100644 --- a/torch/_refs/nn/functional/__init__.py +++ b/torch/_refs/nn/functional/__init__.py @@ -1,7 +1,6 @@ from typing import Callable, Optional, Union import torch - import torch._prims as prims import torch._prims_common as utils import torch._refs as refs @@ -25,6 +24,8 @@ _make_elementwise_unary_reference, ) +from torch._subclasses.fake_tensor import FakeTensor + __all__ = [ "celu", "dropout", @@ -36,6 +37,7 @@ "l1_loss", "margin_ranking_loss", "mish", + "nll_loss", "mse_loss", "poisson_nll_loss", "prelu", @@ -435,6 +437,159 @@ def hinge_embedding_loss( return _apply_loss_reduction(loss, reduction) +def _nll_loss_nd( + input: TensorLikeType, + target: TensorLikeType, + weight: Optional[TensorLikeType], + reduction: str, + ignore_index: int, +) -> TensorLikeType: + utils.check( + input.ndim > 0 and input.ndim <= 3, + lambda: f"Expected input dimension to be either [1, 2, 3] but recieved {input.ndim}.", + ) + + utils.check( + (input.ndim == 1) or (input.shape[0] == target.shape[0]), + lambda: f"Expected input batch size {input.shape[0]} to match target batch size {target.shape[0]}.", + ) + + _check_reduction_value(reduction) + + flat_target = torch.flatten(target) + ignore_classes_mask = torch.eq(flat_target, ignore_index) + + # TODO: Enable data-dependent checks with debug mode + # TODO: This check does not work with FakeTensor inputs; See Issue #85834 + # Explicit cast for class_check to bool; See Issue #78071 + """ + num_classes = input.shape[1] if input.ndim > 1 else input.shape[0] + valid_classes_mask = torch.logical_and( + (flat_target >= 0), (flat_target < num_classes) + ) + class_check = torch.all(torch.logical_or(ignore_classes_mask, valid_classes_mask)) + utils.check( + isinstance(target, FakeTensor) or bool(class_check.item()), + lambda: "A target class is out-of-bounds and not the ignore index.", + ) + """ + + ignore_class_weight = torch.scalar_tensor(0, dtype=input.dtype, device=input.device) + class_weight = ( + torch.scalar_tensor(1, dtype=input.dtype, device=input.device) + if weight is None + else weight[flat_target] + ) + current_weight = torch.where( + ignore_classes_mask, + ignore_class_weight, + class_weight, + ) + + if input.ndim == 1: + # implicit batch size = 1 + # input (1 batch size, C classes) + loss = -input[target] * current_weight + elif input.ndim == 2: + # input (N batch size, C classes) + batch_size = input.shape[0] + loss = -input[torch.arange(batch_size), target] * current_weight + else: + # 3D case (N batch size, C classe, K dimensions) + # input (N batch size, C classes, K) + batch_size = input.shape[0] + extent = input.shape[2] + numel = batch_size * extent + indices = torch.arange(numel) + bdx = indices // extent + kdx = indices % extent + loss = -input[bdx, flat_target, kdx] * current_weight + loss = torch.reshape(loss, target.shape) + + if reduction == "none": + return loss + elif reduction == "sum": + return torch.sum(loss) + else: + # calculate weighted mean of the loss function + return torch.sum(loss) / torch.sum(current_weight) + + +@register_decomposition(torch.ops.aten.nll_loss) +@elementwise_type_promotion_wrapper( + type_promoting_args=("input",), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +@out_wrapper() +def nll_loss( + input: TensorLikeType, + target: TensorLikeType, + weight: Optional[TensorLikeType] = None, + size_average: Optional[bool] = None, + ignore_index: int = -100, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.nll_loss + """ + utils.check( + input.ndim > 0, + lambda: f"Expected input tensor to have 1 or more dimensions (got {input.ndim})", + ) + + # TODO: raise exception instead of converting value + # msg = "size_average and reduce args are deprecated, please use reduction argument." + # Convert these options for consistency with the eager mode + if size_average is not None or reduce is not None: + reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce) + + # The expected behavior when the target and input have zero elements: + # reduction = 'none' --- tensor([]) + # reduction = 'sum' --- tensor(0.) + # reduction = 'mean' --- tensor(nan) + # Mean reduction on empty tensors produces NaN. See the discussion in + # https://github.com/pytorch/pytorch/pull/64572#issuecomment-926504162 + if input.numel() == 0 and target.numel() == 0: + if reduction == "none": + return torch.zeros_like(target) + elif reduction == "sum": + return torch.empty_like(target) + else: + return torch.full_like(target, float("nan")) + + # The _nll_loss_nd helper function handles the most common cases. + # ndim == 1 (Single Example) + # => Batch Size: 1, Input: (C), Target: () + # ndim == 2 (k = 1) + # => Batch Size: N, Input: (N, C), Target: (N) + # ndim == 3 (k > 1) + # => Batch Size: N, Input: (N, C, K), Target: (N, K) + if input.ndim <= 3: + return _nll_loss_nd(input, target, weight, reduction, ignore_index) + + # For ndim > 3, we reshape the input and target to 3-D case. + # Input (N batch-size, C classes, k-dimensions) + # Target (N batch-size, k-dimensions) + utils.check( + input.ndim > 0 and target.ndim > 0 and target.shape[1:] == input.shape[2:], + lambda: f"Expected target shape {out_size} but got {target.shape}", + ) + + batch_size = input.shape[0] + num_classes = input.shape[1] + out_size = [batch_size] + list(target.shape[1:]) + + input = torch.reshape(input, [batch_size, num_classes, -1]) + target = torch.reshape(target, [batch_size, -1]) + if reduction != "none": + return _nll_loss_nd(input, target, weight, reduction, ignore_index) + else: + result = _nll_loss_nd(input, target, weight, reduction, ignore_index) + # reshape flattened inner-dim to original k-dimensions + return torch.reshape(result, out_size) + + # TODO: This ref supports int reduction and out kwarg to be compatible with ATen: # https://github.com/pytorch/pytorch/issues/83931 # TODO: Could be rewritten to support complex: diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 350263976096..fae5f203e697 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -16862,6 +16862,23 @@ def reference_flatten(input, start_dim=0, end_dim=-1): torch_opinfo_name="nn.functional.hinge_embedding_loss", supports_nvfuser=False, ), + PythonRefInfo( + "_refs.nn.functional.nll_loss", + torch_opinfo_name="nn.functional.nll_loss", + # The corresponding PyTorch op doesn't support out. But the ref is + # registered as a decomp and ATen has an out variant. + supports_out=True, + supports_nvfuser=False, + # For simpler indexing, we flatten target indices, then reshape the result tensor. + # This creates inconsistent view state with reference impl. + validate_view_consistency=False, + skips=( + # RuntimeError: It appears that you're trying to get value out of a tracing tensor - erroring out! + DecorateInfo( + unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor', device_type="cuda" + ), + ), + ), PythonRefInfo( "_refs.nn.functional.huber_loss", torch_opinfo_name="nn.functional.huber_loss",