-
Notifications
You must be signed in to change notification settings - Fork 24.8k
[primTorch] Implement NLL loss reference #81128
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
2256d36
b1393e5
f72db25
055e0e2
96cc303
370bc60
612ce91
e7a3ae4
44702b8
c71d746
e0554f2
6aa6b62
dde53e3
4df9971
39883b6
1a635cd
590866b
1b88f57
c59279e
e6d01e4
f2c9c3f
10b85ff
96a6142
6cbdf01
746a60e
e1eb641
ef5719e
76bfc80
3cd82ab
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -366,22 +366,18 @@ def _get_string_reduction_arg( | |
return ret | ||
|
||
|
||
def _nll_loss_2d( | ||
def _nll_loss_nd( | ||
input: TensorLikeType, | ||
target: TensorLikeType, | ||
weight: Optional[TensorLikeType], | ||
reduction: str, | ||
ignore_index: int, | ||
) -> TensorLikeType: | ||
if input.ndim != 4: | ||
msg = "Expected 4 dimensions for input but got {}." | ||
if input.ndim == 3 or input.ndim > 4: | ||
msg = "Expected input dimension to be either [1, 2, 4] but recieved {}." | ||
raise ValueError(msg.format(input.ndim)) | ||
|
||
if target.ndim != 3: | ||
msg = "Expected 3 dimensions for input but got {}." | ||
raise ValueError(msg.format(input.ndim)) | ||
|
||
if input.shape[0] != target.shape[0]: | ||
if input.ndim != 1 and input.shape[0] != target.shape[0]: | ||
msg = "Expected input batch size ({}) to match target batch size ({})." | ||
raise ValueError(msg.format(input.shape[0], target.shape[0])) | ||
|
||
|
@@ -403,62 +399,22 @@ def _nll_loss_2d( | |
weight[flat_target], | ||
) | ||
|
||
batch_size = input.shape[0] | ||
height = input.shape[2] | ||
width = input.shape[3] | ||
extent = height * width | ||
numel = batch_size * extent | ||
|
||
bdx = torch.arange(numel) // extent | ||
hdx = (torch.arange(numel) - (bdx * extent)) // width | ||
wdx = torch.arange(numel) % width | ||
|
||
loss = -input[bdx, flat_target, hdx, wdx] * current_weight | ||
loss = torch.reshape(loss, target.shape) | ||
|
||
if reduction == "none": | ||
return loss | ||
elif reduction == "sum": | ||
return torch.sum(loss) | ||
else: | ||
return torch.sum(loss) / torch.sum(current_weight) | ||
|
||
|
||
def _nll_loss_1d( | ||
input: TensorLikeType, | ||
target: TensorLikeType, | ||
weight: Optional[TensorLikeType], | ||
reduction: str, | ||
ignore_index: int, | ||
) -> TensorLikeType: | ||
if input.ndim < 1 or input.ndim > 2: | ||
msg = "Expected 1 or 2 dimensions for input but got {}." | ||
raise ValueError(msg.format(input.ndim)) | ||
|
||
if input.ndim != 1 and input.shape[0] != target.shape[0]: | ||
msg = "Expected input batch size ({}) to match target batch size ({})." | ||
raise ValueError(msg.format(input.shape[0], target.shape[0])) | ||
|
||
_check_reduction_value(reduction) | ||
|
||
ignore_classes_mask = torch.eq(target, ignore_index) | ||
ignore_class_weight = torch.scalar_tensor(0, dtype=input.dtype, device=input.device) | ||
if weight is None: | ||
current_weight = torch.where( | ||
ignore_classes_mask, | ||
ignore_class_weight, | ||
torch.scalar_tensor(1, dtype=input.dtype, device=input.device), | ||
) | ||
else: | ||
current_weight = torch.where( | ||
ignore_classes_mask, ignore_class_weight.expand_as(target), weight[target] | ||
) | ||
|
||
if input.ndim == 1: | ||
rdspring1 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
loss = -input[target] * current_weight | ||
else: | ||
elif input.ndim == 2: | ||
batch_size = input.shape[0] | ||
loss = -input[torch.arange(batch_size), target] * current_weight | ||
else: | ||
batch_size = input.shape[0] | ||
height = input.shape[2] | ||
width = input.shape[3] | ||
extent = height * width | ||
numel = batch_size * extent | ||
bdx = torch.arange(numel) // extent | ||
hdx = (torch.arange(numel) - (bdx * extent)) // width | ||
wdx = torch.arange(numel) % width | ||
loss = -input[bdx, flat_target, hdx, wdx] * current_weight | ||
loss = torch.reshape(loss, target.shape) | ||
|
||
if reduction == "none": | ||
return loss | ||
|
@@ -478,15 +434,11 @@ def nll_loss( | |
reduction: str = "mean", | ||
) -> TensorLikeType: | ||
if size_average is not None or reduce is not None: | ||
# TODO raise exception instead of converting value | ||
# TODO: raise exception instead of converting value | ||
rdspring1 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# msg = "size_average and reduce args are deprecated, please use reduction argument." | ||
reduction = _get_string_reduction_arg(size_average, reduce) | ||
|
||
if input.ndim == 1 or input.ndim == 2: | ||
return _nll_loss_1d(input, target, weight, reduction, ignore_index) | ||
elif input.ndim == 4: | ||
return _nll_loss_2d(input, target, weight, reduction, ignore_index) | ||
else: | ||
if input.ndim == 3 or input.ndim > 4: | ||
# input ndim is == 3 or > 4 | ||
rdspring1 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
batch_size = input.shape[0] | ||
num_classes = input.shape[1] | ||
|
@@ -508,10 +460,12 @@ def nll_loss( | |
target = torch.reshape(target, [batch_size, 0, 0]) | ||
|
||
if reduction == "none": | ||
return _nll_loss_2d(input, target, weight, reduction, ignore_index) | ||
return _nll_loss_nd(input, target, weight, reduction, ignore_index) | ||
else: | ||
result = _nll_loss_2d(input, target, weight, reduction, ignore_index) | ||
result = _nll_loss_nd(input, target, weight, reduction, ignore_index) | ||
return torch.reshape(result, out_size) | ||
rdspring1 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
else: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add a comment describing what this else branch is for. From a code organization and readability standpoint these branches seem a little odd. Maybe we can explain them better? In particular -- can input be zero or one dimension? If so, how do we interpret that? The documentation for suggests that input should have at least two dimensions. And why are inputs with three or four dimensions special? Finally, prefer putting shorter branches which short-circuit first. That typically lets code have fewer indentation levels:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I refactored _nll_loss_nd to handle 1-3 dimensions. If there are more than 3 dimensions, the k-dimension is flattened to create a 3D tensor. The Aten implementation used a 4D case for image inputs.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is the 4D case interesting to model here? |
||
return _nll_loss_nd(input, target, weight, reduction, ignore_index) | ||
|
||
|
||
# tanhshrink does not use _make_elementwise_unary_reference because it does not support out | ||
|
Uh oh!
There was an error while loading. Please reload this page.