Labels: module: primTorch, triaged
Description
🐛 Describe the bug
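The excerpt below, from #81128, checks whether the target values are within bounds in the `nll_loss` reference. As its TODO notes, the check does not work with FakeTensor inputs, so the `isinstance(target, FakeTensor)` short-circuit currently skips it for them.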
```python
num_classes = input.shape[1] if input.ndim > 1 else input.shape[0]
valid_classes_mask = torch.logical_and(
    (flat_target >= 0), (flat_target < num_classes)
)
class_check = torch.all(torch.logical_or(ignore_classes_mask, valid_classes_mask))
# TODO: This check does not work with FakeTensor inputs
# Explicit cast for class_check to bool; See Issue #78071
utils.check(
    isinstance(target, FakeTensor) or bool(class_check.item()),
    lambda: "A target class is out-of-bounds and not the ignore index.",
)
```

Versions
PyTorch upstream commit: a4bd89b
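For illustration, the condition that `class_check` encodes in the `nll_loss` bounds excerpt can be sketched in plain Python, outside of any tensor machinery. This is a hypothetical sketch, not the primTorch reference: the helper names `targets_in_bounds` and `num_classes_from_shape` are made up here, and `ignore_index=-100` is assumed because it is the default of `torch.nn.functional.nll_loss`. The sketch makes explicit that the check depends on the concrete target values, which a FakeTensor (metadata only, no real data) does not carry.

```python
from typing import Sequence


def num_classes_from_shape(shape: Sequence[int]) -> int:
    # Hypothetical helper mirroring the excerpt:
    # batched input (N, C, ...) -> C = shape[1]; unbatched input (C,) -> C = shape[0].
    return shape[1] if len(shape) > 1 else shape[0]


def targets_in_bounds(
    targets: Sequence[int], num_classes: int, ignore_index: int = -100
) -> bool:
    # Hypothetical helper mirroring class_check: every target must either
    # equal ignore_index (ignore_classes_mask) or lie in [0, num_classes)
    # (valid_classes_mask). Answering this requires the actual target values.
    return all(t == ignore_index or 0 <= t < num_classes for t in targets)


print(targets_in_bounds([0, 2, -100], num_classes_from_shape((3, 3))))  # True
print(targets_in_bounds([0, 3], num_classes_from_shape((2, 3))))  # False
```

Because the result is data-dependent, a tensor version has to materialize it (the excerpt uses `bool(class_check.item())`), and that is the step that cannot be evaluated when only shape and dtype metadata are available.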