[primTorch] Implement NLL loss reference (#81128) · pytorch/pytorch@847ded6 · GitHub
Commit 847ded6

rdspring1 authored and pytorchmergebot committed

[primTorch] Implement NLL loss reference (#81128)

Add Reference:
- nll_loss

Depends on:
- expand #79820
- advanced indexing

Pull Request resolved: #81128
Approved by: https://github.com/mruberry

1 parent 78e2289 commit 847ded6

File tree

2 files changed: 173 additions, 1 deletion
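For orientation, here is a minimal usage sketch of the new reference (not part of the commit; it assumes a PyTorch build that includes this change, and the import alias refs_nnf is arbitrary). The ref is expected to match the eager torch.nn.functional.nll_loss:

    import torch
    import torch._refs.nn.functional as refs_nnf

    # Log-probabilities for N=4 samples over C=5 classes, plus integer targets.
    log_probs = torch.log_softmax(torch.randn(4, 5), dim=1)
    target = torch.tensor([0, 3, 1, 4])

    eager = torch.nn.functional.nll_loss(log_probs, target, reduction="mean")
    ref = refs_nnf.nll_loss(log_probs, target, reduction="mean")
    torch.testing.assert_close(ref, eager)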

torch/_refs/nn/functional/__init__.py

156 additions, 1 deletion
@@ -1,7 +1,6 @@
 from typing import Callable, Optional, Union
 
 import torch
-
 import torch._prims as prims
 import torch._prims_common as utils
 import torch._refs as refs
@@ -25,6 +24,8 @@
     _make_elementwise_unary_reference,
 )
 
+from torch._subclasses.fake_tensor import FakeTensor
+
 __all__ = [
     "celu",
     "dropout",
@@ -36,6 +37,7 @@
     "l1_loss",
     "margin_ranking_loss",
     "mish",
+    "nll_loss",
     "mse_loss",
     "poisson_nll_loss",
     "prelu",
@@ -435,6 +437,159 @@ def hinge_embedding_loss(
     return _apply_loss_reduction(loss, reduction)
 
 
+def _nll_loss_nd(
+    input: TensorLikeType,
+    target: TensorLikeType,
+    weight: Optional[TensorLikeType],
+    reduction: str,
+    ignore_index: int,
+) -> TensorLikeType:
+    utils.check(
+        input.ndim > 0 and input.ndim <= 3,
+        lambda: f"Expected input dimension to be either [1, 2, 3] but received {input.ndim}.",
+    )
+
+    utils.check(
+        (input.ndim == 1) or (input.shape[0] == target.shape[0]),
+        lambda: f"Expected input batch size {input.shape[0]} to match target batch size {target.shape[0]}.",
+    )
+
+    _check_reduction_value(reduction)
+
+    flat_target = torch.flatten(target)
+    ignore_classes_mask = torch.eq(flat_target, ignore_index)
+
+    # TODO: Enable data-dependent checks with debug mode
+    # TODO: This check does not work with FakeTensor inputs; see issue #85834
+    # Explicit cast of class_check to bool; see issue #78071
+    """
+    num_classes = input.shape[1] if input.ndim > 1 else input.shape[0]
+    valid_classes_mask = torch.logical_and(
+        (flat_target >= 0), (flat_target < num_classes)
+    )
+    class_check = torch.all(torch.logical_or(ignore_classes_mask, valid_classes_mask))
+    utils.check(
+        isinstance(target, FakeTensor) or bool(class_check.item()),
+        lambda: "A target class is out-of-bounds and not the ignore index.",
+    )
+    """
+
+    ignore_class_weight = torch.scalar_tensor(0, dtype=input.dtype, device=input.device)
+    class_weight = (
+        torch.scalar_tensor(1, dtype=input.dtype, device=input.device)
+        if weight is None
+        else weight[flat_target]
+    )
+    current_weight = torch.where(
+        ignore_classes_mask,
+        ignore_class_weight,
+        class_weight,
+    )
+
+    if input.ndim == 1:
+        # implicit batch size = 1
+        # input (1 batch size, C classes)
+        loss = -input[target] * current_weight
+    elif input.ndim == 2:
+        # input (N batch size, C classes)
+        batch_size = input.shape[0]
+        loss = -input[torch.arange(batch_size), target] * current_weight
+    else:
+        # 3D case (N batch size, C classes, K dimensions)
+        # input (N batch size, C classes, K)
+        batch_size = input.shape[0]
+        extent = input.shape[2]
+        numel = batch_size * extent
+        indices = torch.arange(numel)
+        bdx = indices // extent
+        kdx = indices % extent
+        loss = -input[bdx, flat_target, kdx] * current_weight
+        loss = torch.reshape(loss, target.shape)
+
+    if reduction == "none":
+        return loss
+    elif reduction == "sum":
+        return torch.sum(loss)
+    else:
+        # calculate the weighted mean of the loss function
+        return torch.sum(loss) / torch.sum(current_weight)
+
+
+@register_decomposition(torch.ops.aten.nll_loss)
+@elementwise_type_promotion_wrapper(
+    type_promoting_args=("input",),
+    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
+)
+@out_wrapper()
+def nll_loss(
+    input: TensorLikeType,
+    target: TensorLikeType,
+    weight: Optional[TensorLikeType] = None,
+    size_average: Optional[bool] = None,
+    ignore_index: int = -100,
+    reduce: Optional[bool] = None,
+    reduction: str = "mean",
+) -> TensorLikeType:
+    """
+    Reference implementation of torch.nn.functional.nll_loss
+    """
+    utils.check(
+        input.ndim > 0,
+        lambda: f"Expected input tensor to have 1 or more dimensions (got {input.ndim})",
+    )
+
+    # TODO: raise an exception instead of converting the value
+    # msg = "size_average and reduce args are deprecated, please use reduction argument."
+    # Convert these options for consistency with eager mode
+    if size_average is not None or reduce is not None:
+        reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce)
+
+    # The expected behavior when the target and input have zero elements:
+    #   reduction = 'none' --- tensor([])
+    #   reduction = 'sum'  --- tensor(0.)
+    #   reduction = 'mean' --- tensor(nan)
+    # Mean reduction on empty tensors produces NaN. See the discussion in
+    # https://github.com/pytorch/pytorch/pull/64572#issuecomment-926504162
+    if input.numel() == 0 and target.numel() == 0:
+        if reduction == "none":
+            return torch.zeros_like(target)
+        elif reduction == "sum":
+            return torch.empty_like(target)
+        else:
+            return torch.full_like(target, float("nan"))
+
+    # The _nll_loss_nd helper function handles the most common cases.
+    # ndim == 1 (Single Example)
+    #   => Batch Size: 1, Input: (C), Target: ()
+    # ndim == 2 (k = 1)
+    #   => Batch Size: N, Input: (N, C), Target: (N)
+    # ndim == 3 (k > 1)
+    #   => Batch Size: N, Input: (N, C, K), Target: (N, K)
+    if input.ndim <= 3:
+        return _nll_loss_nd(input, target, weight, reduction, ignore_index)
+
+    # For ndim > 3, we reshape the input and target to the 3-D case.
+    # Input (N batch-size, C classes, k-dimensions)
+    # Target (N batch-size, k-dimensions)
+    batch_size = input.shape[0]
+    num_classes = input.shape[1]
+    out_size = [batch_size] + list(target.shape[1:])
+
+    utils.check(
+        input.ndim > 0 and target.ndim > 0 and target.shape[1:] == input.shape[2:],
+        lambda: f"Expected target shape {out_size} but got {target.shape}",
+    )
+
+    input = torch.reshape(input, [batch_size, num_classes, -1])
+    target = torch.reshape(target, [batch_size, -1])
+    if reduction != "none":
+        return _nll_loss_nd(input, target, weight, reduction, ignore_index)
+    else:
+        result = _nll_loss_nd(input, target, weight, reduction, ignore_index)
+        # reshape flattened inner-dim to original k-dimensions
+        return torch.reshape(result, out_size)
+
+
 # TODO: This ref supports int reduction and out kwarg to be compatible with ATen:
 # https://github.com/pytorch/pytorch/issues/83931
 # TODO: Could be rewritten to support complex:
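The 3-D branch of _nll_loss_nd gathers one class score per (batch, spatial) position by flattening the grid and recovering the batch/spatial indices with integer division and modulo. A standalone sketch of that advanced-indexing trick (illustration only, not from the commit), checked against an explicit loop:

    import torch

    N, C, K = 2, 3, 4
    inp = torch.randn(N, C, K)            # (batch, classes, spatial)
    tgt = torch.randint(0, C, (N, K))     # class index per (batch, spatial) position

    indices = torch.arange(N * K)
    bdx = indices // K                    # batch index of each flattened position
    kdx = indices % K                     # spatial index of each flattened position
    gathered = inp[bdx, tgt.flatten(), kdx].reshape(N, K)

    # Same gather via an explicit loop, for comparison.
    expected = torch.stack(
        [torch.stack([inp[n, tgt[n, k], k] for k in range(K)]) for n in range(N)]
    )
    torch.testing.assert_close(gathered, expected)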

torch/testing/_internal/common_methods_invocations.py

17 additions, 0 deletions
@@ -16871,6 +16871,23 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
         torch_opinfo_name="nn.functional.hinge_embedding_loss",
         supports_nvfuser=False,
     ),
+    PythonRefInfo(
+        "_refs.nn.functional.nll_loss",
+        torch_opinfo_name="nn.functional.nll_loss",
+        # The corresponding PyTorch op doesn't support out. But the ref is
+        # registered as a decomp and ATen has an out variant.
+        supports_out=True,
+        supports_nvfuser=False,
+        # For simpler indexing, we flatten target indices, then reshape the result tensor.
+        # This creates inconsistent view state with the reference impl.
+        validate_view_consistency=False,
+        skips=(
+            # RuntimeError: It appears that you're trying to get value out of a tracing tensor - erroring out!
+            DecorateInfo(
+                unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor', device_type="cuda"
+            ),
+        ),
+    ),
     PythonRefInfo(
         "_refs.nn.functional.huber_loss",
         torch_opinfo_name="nn.functional.huber_loss",
