10000 Initial nll_loss implementation · pytorch/pytorch@2256d36 · GitHub
[go: up one dir, main page]

Skip to content

Commit 2256d36

Browse files
committed
Initial nll_loss implementation
1 parent 79a502f commit 2256d36

File tree

2 files changed

+179
-0
lines changed

2 files changed

+179
-0
lines changed

torch/_refs/nn/functional/__init__.py

Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
"hinge_embedding_loss",
3434
"margin_ranking_loss",
3535
"mish",
36+
"nll_loss",
3637
"relu",
3738
"selu",
3839
"softplus",
@@ -349,6 +350,173 @@ def hinge_embedding_loss(
349350
return _apply_loss_reduction(loss, reduction)
350351

351352

353+
def _get_string_reduction_arg(
354+
size_average: Optional[bool], reduce: Optional[bool]
355+
) -> str:
356+
if size_average is None:
357+
size_average = True
358+
if reduce is None:
359+
reduce = True
360+
if size_average and reduce:
361+
ret = "mean"
362+
elif reduce:
363+
ret = "sum"
364+
else:
365+
ret = "none"
366+
return ret
367+
368+
369+
def _nll_loss_2d(
    input: TensorLikeType,
    target: TensorLikeType,
    weight: Optional[TensorLikeType],
    reduction: str,
    ignore_index: int,
) -> TensorLikeType:
    """NLL loss for a 4-D input (N, C, H, W) and a 3-D target (N, H, W).

    Args:
        input: log-probabilities, shape (N, C, H, W).
        target: class indices, shape (N, H, W).
        weight: optional per-class rescaling weights (indexed by class), or None.
        reduction: "none", "sum", or "mean".
        ignore_index: target value whose positions contribute zero loss and
            are excluded from the "mean" denominator.

    Returns:
        Per-element losses shaped like ``target`` for "none", otherwise a scalar.
    """
    if input.ndim != 4:
        msg = "Expected 4 dimensions for input but got {}."
        raise ValueError(msg.format(input.ndim))

    if target.ndim != 3:
        # Fixed: message previously said "input" and reported input.ndim.
        msg = "Expected 3 dimensions for target but got {}."
        raise ValueError(msg.format(target.ndim))

    _check_reduction_value(reduction)

    current_target = torch.reshape(target, [-1])
    ignore_classes_mask = torch.eq(current_target, ignore_index)
    # Effective per-position weight: 0 on ignored positions, otherwise 1 or
    # the user-supplied class weight.
    if weight is None:
        current_weight = torch.where(
            ignore_classes_mask,
            torch.scalar_tensor(0, dtype=input.dtype, device=input.device),
            torch.scalar_tensor(1, dtype=input.dtype, device=input.device),
        )
    else:
        ignore_class_weight = torch.scalar_tensor(
            0, dtype=input.dtype, device=input.device
        ).expand_as(current_target)
        current_weight = torch.where(
            ignore_classes_mask, ignore_class_weight, weight[current_target]
        )

    batch_size = input.shape[0]
    height = input.shape[2]
    width = input.shape[3]
    extent = height * width
    numel = batch_size * extent

    # Decompose flat index into (batch, height, width) coordinates so the
    # target class can be gathered per spatial position.
    # Fixed: indices must live on input's device, or CUDA inputs would be
    # indexed with CPU tensors.
    flat_idx = torch.arange(numel, device=input.device)
    bdx = flat_idx // extent
    hdx = (flat_idx - (bdx * extent)) // width
    wdx = flat_idx % width

    loss = -input[bdx, current_target, hdx, wdx] * current_weight
    loss = torch.reshape(loss, target.shape)

    if reduction == "none":
        return loss
    elif reduction == "sum":
        return torch.sum(loss)
    else:
        # "mean": weighted average over non-ignored positions (NaN if all
        # positions are ignored, matching eager nll_loss).
        return torch.sum(loss) / torch.sum(current_weight)
421+
422+
423+
def _nll_loss_1d(
    input: TensorLikeType,
    target: TensorLikeType,
    weight: Optional[TensorLikeType],
    reduction: str,
    ignore_index: int,
) -> TensorLikeType:
    """NLL loss for a 1-D input (C,) or 2-D input (N, C).

    Args:
        input: log-probabilities, shape (C,) or (N, C).
        target: class indices, shape () / (N,) (extra trailing dims are
            collapsed to the first column).
        weight: optional per-class rescaling weights (indexed by class), or None.
        reduction: "none", "sum", or "mean".
        ignore_index: target value whose positions contribute zero loss and
            are excluded from the "mean" denominator.

    Returns:
        Per-element losses for "none", otherwise a scalar.
    """
    if input.ndim < 1:
        msg = "Expected 1 or more dimension for input but got {}."
        raise ValueError(msg.format(input.ndim))

    if input.ndim != 1 and input.shape[0] != target.shape[0]:
        msg = "Expected input batch size ({}) to match target batch size ({})."
        raise ValueError(msg.format(input.shape[0], target.shape[0]))

    _check_reduction_value(reduction)

    if target.ndim <= 1:
        # Fixed: this line was corrupted by HTML residue in the source
        # ("current_target = < B41A span class=pl-s1>target").
        current_target = target
    else:
        current_target = target[:, 0]

    ignore_classes_mask = torch.eq(current_target, ignore_index)
    # Effective per-sample weight: 0 on ignored samples, otherwise 1 or the
    # user-supplied class weight.
    if weight is None:
        current_weight = torch.where(
            ignore_classes_mask,
            torch.scalar_tensor(0, dtype=input.dtype, device=input.device),
            torch.scalar_tensor(1, dtype=input.dtype, device=input.device),
        )
    else:
        ignore_class_weight = torch.scalar_tensor(
            0, dtype=input.dtype, device=input.device
        ).expand_as(current_target)
        current_weight = torch.where(
            ignore_classes_mask, ignore_class_weight, weight[current_target]
        )

    batch_size = input.shape[0]
    if input.ndim == 1:
        loss = -input[current_target] * current_weight
    else:
        # Fixed: arange indices must live on input's device for CUDA inputs.
        batch_idx = torch.arange(batch_size, device=input.device)
        loss = -input[batch_idx, current_target] * current_weight

    if reduction == "none":
        return loss
    elif reduction == "sum":
        return torch.sum(loss)
    else:
        # "mean": weighted average over non-ignored samples (NaN if all
        # samples are ignored, matching eager nll_loss).
        return torch.sum(loss) / torch.sum(current_weight)
472+
473+
474+
def nll_loss(
    input: TensorLikeType,
    target: TensorLikeType,
    weight: Optional[TensorLikeType] = None,
    size_average: Optional[bool] = None,
    ignore_index: int = -100,
    reduce: Optional[bool] = None,
    reduction: str = "mean",
) -> TensorLikeType:
    """Reference implementation of torch.nn.functional.nll_loss.

    Dispatches to the 1-D implementation for (C,) / (N, C) inputs, the 2-D
    implementation for (N, C, H, W) inputs, and flattens any other rank into
    the 2-D case.
    """
    if size_average is not None or reduce is not None:
        # TODO raise exception instead of converting value
        # msg = "size_average and reduce args are deprecated, please use reduction argument."
        reduction = _get_string_reduction_arg(size_average, reduce)

    if input.ndim == 1 or input.ndim == 2:
        return _nll_loss_1d(input, target, weight, reduction, ignore_index)
    elif input.ndim == 4:
        return _nll_loss_2d(input, target, weight, reduction, ignore_index)
    else:
        # input ndim is == 3 or > 4: flatten trailing dims into the 2-D case.
        batch_size = input.shape[0]
        num_classes = input.shape[1]
        out_size = [batch_size] + list(input.shape[2:])

        if target.shape[1:] != input.shape[2:]:
            msg = "Expected target size {} but got {}"
            raise ValueError(msg.format(out_size, target.shape))

        # support empty batches, see #15870
        if input.numel() > 0:
            input = torch.reshape(input, [batch_size, num_classes, 1, -1])
        else:
            input = torch.reshape(input, [batch_size, num_classes, 0, 0])

        if target.numel() > 0:
            target = torch.reshape(target, [batch_size, 1, -1])
        else:
            target = torch.reshape(target, [batch_size, 0, 0])

        # Fixed: the two branches were swapped. A "sum"/"mean" reduction
        # yields a scalar and must be returned as-is (reshaping a scalar to
        # out_size would fail), while "none" yields a loss shaped like the
        # flattened target and must be reshaped back to the caller's shape.
        if reduction != "none":
            return _nll_loss_2d(input, target, weight, reduction, ignore_index)
        else:
            result = _nll_loss_2d(input, target, weight, reduction, ignore_index)
            return torch.reshape(result, out_size)
518+
519+
352520
# tanhshrink does not use _make_elementwise_unary_reference because it does not support out
353521
@elementwise_unary_scalar_wrapper
354522
@elementwise_type_promotion_wrapper(

torch/testing/_internal/common_methods_invocations.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20505,6 +20505,17 @@ def __init__(
2050520505
torch_opinfo_name="nn.functional.hinge_embedding_loss",
2050620506
supports_nvfuser=False,
2050720507
),
20508+
PythonRefInfo(
20509+
"_refs.nn.functional.nll_loss",
20510+
torch_opinfo_name="nn.functional.nll_loss",
20511+
supports_nvfuser=False,
20512+
skips=(
20513+
# RuntimeError: It appears that you're trying to get value out of a tracing tensor - erroring out!
20514+
DecorateInfo(
20515+
unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor', device_type="cuda"
20516+
),
20517+
),
20518+
),
2050820519
ElementwiseUnaryPythonRefInfo(
2050920520
"_refs.nn.functional.tanhshrink",
2051020521
torch_opinfo_name="nn.functional.tanhshrink",

0 commit comments

Comments
 (0)
0