[primTorch] Implement NLL loss reference #81128
Closed
Changes from all commits (29 commits)
All commits by rdspring1:

2256d36  Initial nll_loss implementation
b1393e5  fixup
f72db25  Disable validate_view_consistency check
055e0e2  Merge 1d and 2d nll_loss functions
96cc303  Add target class check - disabled because of FakeTensor
370bc60  refactor helper function
612ce91  Merge branch 'master' of github.com:rdspring1/pytorch into ref_nll_loss
e7a3ae4  Merge branch 'master' of github.com:rdspring1/pytorch into ref_nll_loss
44702b8  Address comments - rnd 1
c71d746  fixup
e0554f2  Refactor class weight selection
6aa6b62  Add comments
dde53e3  Replace 4-D case for image inputs with general 3-D case
4df9971  Merge branch 'master' of github.com:rdspring1/pytorch into ref_nll_loss
39883b6  add comments
1a635cd  Merge branch 'master' of github.com:rdspring1/pytorch into ref_nll_loss
590866b  Add class check
1b88f57  Add FakeTensor Issue
c59279e  add zero-dim check
e6d01e4  Merge branch 'master' of github.com:rdspring1/pytorch into ref_nll_loss
f2c9c3f  Update comments
10b85ff  fixup
96a6142  Merge branch 'master' of github.com:rdspring1/pytorch into ref_nll_loss
6cbdf01  lint
746a60e  Merge branch 'master' of github.com:rdspring1/pytorch into ref_nll_loss
e1eb641  Merge branch 'master' of github.com:rdspring1/pytorch into ref_nll_loss
ef5719e  PR comments
76bfc80  update test args
3cd82ab  add type promotion wrapper
@@ -1,7 +1,6 @@
from typing import Callable, Optional, Union

import torch

import torch._prims as prims
import torch._prims_common as utils
import torch._refs as refs
@@ -25,6 +24,8 @@
    _make_elementwise_unary_reference,
)

from torch._subclasses.fake_tensor import FakeTensor

__all__ = [
    "celu",
    "dropout",
@@ -36,6 +37,7 @@
    "l1_loss",
    "margin_ranking_loss",
    "mish",
    "nll_loss",
    "mse_loss",
    "poisson_nll_loss",
    "prelu",
@@ -435,6 +437,159 @@ def hinge_embedding_loss(
    return _apply_loss_reduction(loss, reduction)


def _nll_loss_nd(
    input: TensorLikeType,
    target: TensorLikeType,
    weight: Optional[TensorLikeType],
    reduction: str,
    ignore_index: int,
) -> TensorLikeType:
    utils.check(
        input.ndim > 0 and input.ndim <= 3,
        lambda: f"Expected input dimension to be either [1, 2, 3] but received {input.ndim}.",
    )

    utils.check(
        (input.ndim == 1) or (input.shape[0] == target.shape[0]),
        lambda: f"Expected input batch size {input.shape[0]} to match target batch size {target.shape[0]}.",
    )

    _check_reduction_value(reduction)

    flat_target = torch.flatten(target)
    ignore_classes_mask = torch.eq(flat_target, ignore_index)

    # TODO: Enable data-dependent checks with debug mode
    # TODO: This check does not work with FakeTensor inputs; see issue #85834
    # Explicit cast of class_check to bool; see issue #78071
    """
    num_classes = input.shape[1] if input.ndim > 1 else input.shape[0]
    valid_classes_mask = torch.logical_and(
        (flat_target >= 0), (flat_target < num_classes)
    )
    class_check = torch.all(torch.logical_or(ignore_classes_mask, valid_classes_mask))
    utils.check(
        isinstance(target, FakeTensor) or bool(class_check.item()),
        lambda: "A target class is out-of-bounds and not the ignore index.",
    )
    """

[Review comment on lines +471 to +474: Comment this out for now until we have a debug mode for data-dependent checks.]
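For context, a minimal eager-mode sketch of what the disabled check verifies (the tensor values below are made up for illustration): every target entry must be either the ignore index or a valid class, and the .item() call is what makes the check data-dependent.

import torch

num_classes = 3
flat_target = torch.tensor([0, 2, -100, 5])  # 5 is out of bounds
ignore_classes_mask = torch.eq(flat_target, -100)
valid_classes_mask = torch.logical_and(flat_target >= 0, flat_target < num_classes)
class_check = torch.all(torch.logical_or(ignore_classes_mask, valid_classes_mask))
print(bool(class_check.item()))  # False -> utils.check would raise here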
||
ignore_class_weight = torch.scalar_tensor(0, dtype=input.dtype, device=input.device) | ||
class_weight = ( | ||
torch.scalar_tensor(1, dtype=input.dtype, device=input.device) | ||
if weight is None | ||
else weight[flat_target] | ||
) | ||
current_weight = torch.where( | ||
ignore_classes_mask, | ||
ignore_class_weight, | ||
class_weight, | ||
) | ||
|
||
if input.ndim == 1: | ||
rdspring1 marked this conversation as resolved.
Show resolved
Hide resolved
        # implicit batch size = 1
        # input (1 batch size, C classes)
        loss = -input[target] * current_weight
    elif input.ndim == 2:
        # input (N batch size, C classes)
        batch_size = input.shape[0]
        loss = -input[torch.arange(batch_size), target] * current_weight
    else:
        # 3-D case (N batch size, C classes, K dimensions)
        # input (N batch size, C classes, K)
        batch_size = input.shape[0]
        extent = input.shape[2]
        numel = batch_size * extent
        indices = torch.arange(numel)
        bdx = indices // extent
        kdx = indices % extent
        loss = -input[bdx, flat_target, kdx] * current_weight
        loss = torch.reshape(loss, target.shape)

    if reduction == "none":
        return loss
    elif reduction == "sum":
        return torch.sum(loss)
    else:
        # calculate the weighted mean of the loss function
        return torch.sum(loss) / torch.sum(current_weight)
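The 3-D branch above picks one log-probability per (batch, inner) element by flattening the (N, K) index space into paired batch/inner indices. A standalone sketch (names and shapes assumed for illustration) confirming that the flattened indexing agrees with a plain gather over the class dimension:

import torch

N, C, K = 2, 3, 4
input = torch.randn(N, C, K)
target = torch.randint(0, C, (N, K))

flat_target = torch.flatten(target)  # (N*K,)
indices = torch.arange(N * K)
bdx = indices // K  # batch index of each flattened element
kdx = indices % K   # inner index of each flattened element
loss = torch.reshape(-input[bdx, flat_target, kdx], (N, K))

# Same selection expressed as a gather over the class dimension:
expected = -torch.gather(input, 1, target.unsqueeze(1)).squeeze(1)
torch.testing.assert_close(loss, expected)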

[Review comment: try wrapping with type promotion decorator]

@register_decomposition(torch.ops.aten.nll_loss)
@elementwise_type_promotion_wrapper(
    type_promoting_args=("input",),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
@out_wrapper()
def nll_loss(
    input: TensorLikeType,
    target: TensorLikeType,
    weight: Optional[TensorLikeType] = None,
    size_average: Optional[bool] = None,
    ignore_index: int = -100,
    reduce: Optional[bool] = None,
    reduction: str = "mean",
) -> TensorLikeType:
""" | ||
Reference implementation of torch.nn.functional.nll_loss | ||
""" | ||
utils.check( | ||
input.ndim > 0, | ||
lambda: f"Expected input tensor to have 1 or more dimensions (got {input.ndim})", | ||
) | ||
|
||
# TODO: raise exception instead of converting value | ||
# msg = "size_average and reduce args are deprecated, please use reduction argument." | ||
# Convert these options for consistency with the eager mode | ||
if size_average is not None or reduce is not None: | ||
reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce) | ||
|
||
# The expected behavior when the target and input have zero elements: | ||
# reduction = 'none' --- tensor([]) | ||
# reduction = 'sum' --- tensor(0.) | ||
# reduction = 'mean' --- tensor(nan) | ||
# Mean reduction on empty tensors produces NaN. See the discussion in | ||
# https://github.com/pytorch/pytorch/pull/64572#issuecomment-926504162 | ||
if input.numel() == 0 and target.numel() == 0: | ||
if reduction == "none": | ||
return torch.zeros_like(target) | ||
elif reduction == "sum": | ||
return torch.empty_like(target) | ||
else: | ||
return torch.full_like(target, float("nan")) | ||
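A quick eager-mode illustration of the empty-tensor behavior documented in the comment above (shapes assumed for illustration):

import torch
import torch.nn.functional as F

input = torch.empty(0, 4)  # zero examples, four classes
target = torch.empty(0, dtype=torch.long)
print(F.nll_loss(input, target, reduction="none"))  # tensor([])
print(F.nll_loss(input, target, reduction="sum"))   # tensor(0.)
print(F.nll_loss(input, target, reduction="mean"))  # tensor(nan)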

    # The _nll_loss_nd helper function handles the most common cases:
    # ndim == 1 (single example)
    #   => batch size: 1, input: (C), target: ()
    # ndim == 2 (k = 1)
    #   => batch size: N, input: (N, C), target: (N)
    # ndim == 3 (k > 1)
    #   => batch size: N, input: (N, C, K), target: (N, K)
    if input.ndim <= 3:
        return _nll_loss_nd(input, target, weight, reduction, ignore_index)

    # For ndim > 3, we reshape the input and target to the 3-D case.
    # Input (N batch size, C classes, k dimensions)
    # Target (N batch size, k dimensions)
    batch_size = input.shape[0]
    num_classes = input.shape[1]
    out_size = [batch_size] + list(target.shape[1:])

    # out_size is referenced by the error message below, so define it first
    utils.check(
        input.ndim > 0 and target.ndim > 0 and target.shape[1:] == input.shape[2:],
        lambda: f"Expected target shape {out_size} but got {target.shape}",
    )

    input = torch.reshape(input, [batch_size, num_classes, -1])
    target = torch.reshape(target, [batch_size, -1])
    if reduction != "none":
        return _nll_loss_nd(input, target, weight, reduction, ignore_index)
    else:
        result = _nll_loss_nd(input, target, weight, reduction, ignore_index)
        # reshape the flattened inner dimensions back to the original k dimensions
        return torch.reshape(result, out_size)


# TODO: This ref supports int reduction and out kwarg to be compatible with ATen:
# https://github.com/pytorch/pytorch/issues/83931
# TODO: Could be rewritten to support complex:
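As an illustrative smoke test (not part of the PR), the reference can be compared against eager mode. This assumes the ref is importable as torch._refs.nn.functional.nll_loss, which is where this diff places it; the 4-D input exercises the ndim > 3 reshape path:

import torch
import torch.nn.functional as F
import torch._refs.nn.functional as refs_nn

input = torch.log_softmax(torch.randn(8, 5, 2, 3), dim=1)
target = torch.randint(0, 5, (8, 2, 3))
weight = torch.rand(5)

for reduction in ("none", "sum", "mean"):
    ref = refs_nn.nll_loss(input, target, weight=weight, reduction=reduction)
    eager = F.nll_loss(input, target, weight=weight, reduction=reduction)
    torch.testing.assert_close(ref, eager)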
Review discussion:

Reviewer: These data-dependent checks are a pain. Would you file an issue so we can discuss this as a group?

rdspring1: I opened issue #85834.
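A minimal sketch of why the class check trips over FakeTensor (issue #85834): fake tensors carry shapes and dtypes but no data, so a data-dependent call like .item() cannot return a concrete value. The exact exception raised is an assumption here, made only for illustration:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode():
    target = torch.randint(0, 5, (4,))  # a FakeTensor: shape/dtype only, no values
    class_check = torch.all((target >= 0) & (target < 5))
    try:
        bool(class_check.item())  # data-dependent: needs concrete values
    except Exception as exc:
        print(type(exc).__name__)  # expected to fail under fake tensors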