[primTorch] Implement NLL loss reference by rdspring1 · Pull Request #81128 · pytorch/pytorch · GitHub

[primTorch] Implement NLL loss reference #81128


Closed

rdspring1 wants to merge 29 commits into pytorch:master from rdspring1:ref_nll_loss
Changes from all commits
Commits
29 commits
2256d36
Initial nll_loss implementation
rdspring1 Jun 29, 2022
b1393e5
fixup
rdspring1 Jun 29, 2022
f72db25
Disable validate_view_consistency check
rdspring1 Jun 29, 2022
055e0e2
Merge 1d and 2d nll_loss functions
rdspring1 Jun 29, 2022
96cc303
Add target class check - disabled because of FakeTensor
rdspring1 Jun 29, 2022
370bc60
refactor helper function
rdspring1 Jul 8, 2022
612ce91
Merge branch 'master' of github.com:rdspring1/pytorch into ref_nll_loss
rdspring1 Sep 25, 2022
e7a3ae4
Merge branch 'master' of github.com:rdspring1/pytorch into ref_nll_loss
rdspring1 Sep 27, 2022
44702b8
Address comments - rnd 1
rdspring1 Sep 27, 2022
c71d746
fixup
rdspring1 Sep 27, 2022
e0554f2
Refactor class weight selection
rdspring1 Sep 28, 2022
6aa6b62
Add comments
rdspring1 Sep 28, 2022
dde53e3
Replace 4-D case for image inputs with general 3-D case
rdspring1 Sep 28, 2022
4df9971
Merge branch 'master' of github.com:rdspring1/pytorch into ref_nll_loss
rdspring1 Sep 28, 2022
39883b6
add comments
rdspring1 Sep 28, 2022
1a635cd
Merge branch 'master' of github.com:rdspring1/pytorch into ref_nll_loss
rdspring1 Sep 28, 2022
590866b
Add class check
rdspring1 Sep 28, 2022
1b88f57
Add FakeTensor Issue
rdspring1 Sep 28, 2022
c59279e
add zero-dim check
rdspring1 Sep 28, 2022
e6d01e4
Merge branch 'master' of github.com:rdspring1/pytorch into ref_nll_loss
rdspring1 Sep 30, 2022
f2c9c3f
Update comments
rdspring1 Sep 30, 2022
10b85ff
fixup
rdspring1 Sep 30, 2022
96a6142
Merge branch 'master' of github.com:rdspring1/pytorch into ref_nll_loss
rdspring1 Oct 3, 2022
6cbdf01
lint
rdspring1 Oct 3, 2022
746a60e
Merge branch 'master' of github.com:rdspring1/pytorch into ref_nll_loss
rdspring1 Oct 11, 2022
e1eb641
Merge branch 'master' of github.com:rdspring1/pytorch into ref_nll_loss
rdspring1 Oct 16, 2022
ef5719e
PR comments
rdspring1 Oct 16, 2022
76bfc80
update test args
rdspring1 Oct 16, 2022
3cd82ab
add type promotion wrapper
rdspring1 Oct 17, 2022
157 changes: 156 additions & 1 deletion torch/_refs/nn/functional/__init__.py
@@ -1,7 +1,6 @@
from typing import Callable, Optional, Union

import torch

import torch._prims as prims
import torch._prims_common as utils
import torch._refs as refs
@@ -25,6 +24,8 @@
_make_elementwise_unary_reference,
)

from torch._subclasses.fake_tensor import FakeTensor

__all__ = [
"celu",
"dropout",
@@ -36,6 +37,7 @@
"l1_loss",
"margin_ranking_loss",
"mish",
"nll_loss",
"mse_loss",
"poisson_nll_loss",
"prelu",
@@ -435,6 +437,159 @@ def hinge_embedding_loss(
return _apply_loss_reduction(loss, reduction)


def _nll_loss_nd(
input: TensorLikeType,
target: TensorLikeType,
weight: Optional[TensorLikeType],
reduction: str,
ignore_index: int,
) -> TensorLikeType:
utils.check(
input.ndim > 0 and input.ndim <= 3,
lambda: f"Expected input dimension to be either [1, 2, 3] but recieved {input.ndim}.",
)

utils.check(
(input.ndim == 1) or (input.shape[0] == target.shape[0]),
lambda: f"Expected input batch size {input.shape[0]} to match target batch size {target.shape[0]}.",
)

_check_reduction_value(reduction)

flat_target = torch.flatten(target)
ignore_classes_mask = torch.eq(flat_target, ignore_index)

# TODO: Enable data-dependent checks with debug mode
# TODO: This check does not work with FakeTensor inputs; See Issue #85834
# Explicit cast for class_check to bool; See Issue #78071
"""
Collaborator:

These data-dependent checks are a pain.

Would you file an issue so we can discuss this as a group?

Contributor Author:

I opened issue #85834.

num_classes = input.shape[1] if input.ndim > 1 else input.shape[0]
valid_classes_mask = torch.logical_and(
(flat_target >= 0), (flat_target < num_classes)
)
class_check = torch.all(torch.logical_or(ignore_classes_mask, valid_classes_mask))
utils.check(
isinstance(target, FakeTensor) or bool(class_check.item()),
lambda: "A target class is out-of-bounds and not the ignore index.",
)
Comment on lines +471 to +474

Collaborator:

Comment this out for now until we have a debug mode for data-dependent checks.

"""

ignore_class_weight = torch.scalar_tensor(0, dtype=input.dtype, device=input.device)
class_weight = (
torch.scalar_tensor(1, dtype=input.dtype, device=input.device)
if weight is None
else weight[flat_target]
)
current_weight = torch.where(
ignore_classes_mask,
ignore_class_weight,
class_weight,
)

if input.ndim == 1:
# implicit batch size = 1
# input (1 batch size, C classes)
loss = -input[target] * current_weight
elif input.ndim == 2:
# input (N batch size, C classes)
batch_size = input.shape[0]
loss = -input[torch.arange(batch_size), target] * current_weight
else:
# 3D case (N batch size, C classes, K dimensions)
# input (N batch size, C classes, K)
batch_size = input.shape[0]
extent = input.shape[2]
numel = batch_size * extent
indices = torch.arange(numel)
bdx = indices // extent
kdx = indices % extent
loss = -input[bdx, flat_target, kdx] * current_weight
loss = torch.reshape(loss, target.shape)

if reduction == "none":
return loss
elif reduction == "sum":
return torch.sum(loss)
else:
# calculate weighted mean of the loss function
return torch.sum(loss) / torch.sum(current_weight)


@register_decomposition(torch.ops.aten.nll_loss)
@elementwise_type_promotion_wrapper(
type_promoting_args=("input",),
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
@out_wrapper()
def nll_loss(
Collaborator:

try wrapping with type promotion decorator

input: TensorLikeType,
target: TensorLikeType,
weight: Optional[TensorLikeType] = None,
size_average: Optional[bool] = None,
ignore_index: int = -100,
reduce: Optional[bool] = None,
reduction: str = "mean",
) -> TensorLikeType:
"""
Reference implementation of torch.nn.functional.nll_loss
"""
utils.check(
input.ndim > 0,
lambda: f"Expected input tensor to have 1 or more dimensions (got {input.ndim})",
)

# TODO: raise exception instead of converting value
# msg = "size_average and reduce args are deprecated, please use reduction argument."
# Convert these options for consistency with the eager mode
if size_average is not None or reduce is not None:
reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce)

# The expected behavior when the target and input have zero elements:
# reduction = 'none' --- tensor([])
# reduction = 'sum' --- tensor(0.)
# reduction = 'mean' --- tensor(nan)
# Mean reduction on empty tensors produces NaN. See the discussion in
# https://github.com/pytorch/pytorch/pull/64572#issuecomment-926504162
if input.numel() == 0 and target.numel() == 0:
if reduction == "none":
return torch.zeros_like(target)
elif reduction == "sum":
return torch.empty_like(target)
else:
return torch.full_like(target, float("nan"))

# The _nll_loss_nd helper function handles the most common cases.
# ndim == 1 (Single Example)
# => Batch Size: 1, Input: (C), Target: ()
# ndim == 2 (k = 1)
# => Batch Size: N, Input: (N, C), Target: (N)
# ndim == 3 (k > 1)
# => Batch Size: N, Input: (N, C, K), Target: (N, K)
if input.ndim <= 3:
return _nll_loss_nd(input, target, weight, reduction, ignore_index)

# For ndim > 3, we reshape the input and target to 3-D case.
# Input (N batch-size, C classes, k-dimensions)
# Target (N batch-size, k-dimensions)
utils.check(
input.ndim > 0 and target.ndim > 0 and target.shape[1:] == input.shape[2:],
lambda: f"Expected target shape {out_size} but got {target.shape}",
)

batch_size = input.shape[0]
num_classes = input.shape[1]
out_size = [batch_size] + list(target.shape[1:])

input = torch.reshape(input, [batch_size, num_classes, -1])
target = torch.reshape(target, [batch_size, -1])
if reduction != "none":
return _nll_loss_nd(input, target, weight, reduction, ignore_index)
else:
result = _nll_loss_nd(input, target, weight, reduction, ignore_index)
# reshape flattened inner-dim to original k-dimensions
return torch.reshape(result, out_size)


# TODO: This ref supports int reduction and out kwarg to be compatible with ATen:
# https://github.com/pytorch/pytorch/issues/83931
# TODO: Could be rewritten to support complex:
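A note on the 3-D branch of _nll_loss_nd above: it avoids torch.gather by flattening the target and deriving a batch index and a spatial index for every element with integer division and modulo. Below is a minimal standalone sketch of that indexing technique; the shapes and variable names are illustrative only and are not taken from the PR.

import torch

# Hypothetical shapes: N batches, C classes, K spatial positions.
N, C, K = 2, 5, 3
inp = torch.randn(N, C, K)
target = torch.randint(0, C, (N, K))

# Flatten the target and build one batch index and one spatial index per element,
# mirroring the bdx/kdx computation in _nll_loss_nd.
flat_target = torch.flatten(target)   # shape (N * K,)
indices = torch.arange(N * K)
bdx = indices // K                    # batch index of each flattened element
kdx = indices % K                     # spatial index of each flattened element

loss = torch.reshape(-inp[bdx, flat_target, kdx], target.shape)

# The advanced indexing selects -inp[n, target[n, k], k] for every (n, k).
expected = torch.empty(N, K)
for n in range(N):
    for k in range(K):
        expected[n, k] = -inp[n, target[n, k], k]
assert torch.equal(loss, expected)

Keeping the selection on a flat index set also lets the per-element class weights (weight[flat_target] in the reference) line up with the selected log-probabilities without extra reshaping.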
17 changes: 17 additions & 0 deletions torch/testing/_internal/common_methods_invocations.py
@@ -16862,6 +16862,23 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
torch_opinfo_name="nn.functional.hinge_embedding_loss",
supports_nvfuser=False,
),
PythonRefInfo(
"_refs.nn.functional.nll_loss",
torch_opinfo_name="nn.functional.nll_loss",
# The corresponding PyTorch op doesn't support out. But the ref is
# registered as a decomp and ATen has an out variant.
supports_out=True,
supports_nvfuser=False,
# For simpler indexing, we flatten the target indices, then reshape the result tensor.
# This creates inconsistent view state with the reference impl.
validate_view_consistency=False,
skips=(
# RuntimeError: It appears that you're trying to get value out of a tracing tensor - erroring out!
DecorateInfo(
unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor', device_type="cuda"
),
),
),
PythonRefInfo(
"_refs.nn.functional.huber_loss",
torch_opinfo_name="nn.functional.huber_loss",
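For a quick manual check outside the OpInfo machinery, the new reference can be compared against the eager operator directly; the PythonRefInfo entry above automates exactly this kind of comparison across dtypes and devices. A small sketch follows, where the shapes, the weight tensor, and the ignore_index value are illustrative, not taken from the test suite.

import torch
import torch._refs.nn.functional as ref_nn_functional

# Hypothetical 3-D case: batch of 4, 5 classes, 6 spatial positions.
inp = torch.log_softmax(torch.randn(4, 5, 6), dim=1)
target = torch.randint(0, 5, (4, 6))
weight = torch.rand(5)

for reduction in ("none", "sum", "mean"):
    ref_out = ref_nn_functional.nll_loss(
        inp, target, weight=weight, ignore_index=2, reduction=reduction
    )
    eager_out = torch.nn.functional.nll_loss(
        inp, target, weight=weight, ignore_index=2, reduction=reduction
    )
    torch.testing.assert_close(ref_out, eager_out)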