8000 Refactor class weight selection · pytorch/pytorch@e0554f2 · GitHub
[go: up one dir, main page]

Skip to content

Commit e0554f2

Browse files
committed
Refactor class weight selection
1 parent c71d746 commit e0554f2

File tree

1 file changed

+10
-16
lines changed

1 file changed

+10
-16
lines changed

torch/_refs/nn/functional/__init__.py

Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -451,22 +451,16 @@ def _nll_loss_nd(
451451
raise ValueError("Target class is out-of-bounds and not ignore index")
452452
"""
453453

454-
# TODO Add comment for expansion
455-
if weight is None:
456-
default_class_weight = torch.scalar_tensor(
457-
1, dtype=input.dtype, device=input.device
458-
)
459-
current_weight = torch.where(
460-
ignore_classes_mask,
461-
ignore_class_weight,
462-
default_class_weight,
463-
)
464-
else:
465-
current_weight = torch.where(
466-
ignore_classes_mask,
467-
ignore_class_weight.expand_as(flat_target),
468-
weight[flat_target],
469-
)
454+
class_weight = (
455+
torch.scalar_tensor(1, dtype=input.dtype, device=input.device)
456+
if weight is None
457+
else weight[flat_target]
458+
)
459+
current_weight = torch.where(
460+
ignore_classes_mask,
461+
ignore_class_weight,
462+
class_weight,
463+
)
470464

471465
# TODO Add comments for each case
472466
if input.ndim == 1:

0 commit comments

Comments
 (0)
0