8000 Add class check · pytorch/pytorch@590866b · GitHub
[go: up one dir, main page]

Skip to content

Commit 590866b

Browse files
committed
Add class check
1 parent 1a635cd commit 590866b

File tree

1 file changed

+13
-7
lines changed

1 file changed

+13
-7
lines changed

torch/_refs/nn/functional/__init__.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from typing import Optional, Union
22

33
import torch
4-
54
import torch._prims as prims
65
import torch._prims_common as utils
76
import torch._refs as refs
@@ -24,6 +23,8 @@
2423
_make_elementwise_unary_reference,
2524
)
2625

26+
from torch._subclasses.fake_tensor import FakeTensor
27+
2728
__all__ = [
2829
"celu",
2930
"dropout",
@@ -442,13 +443,18 @@ def _nll_loss_nd(
442443
flat_target = torch.flatten(target)
443444
ignore_classes_mask = torch.eq(flat_target, ignore_index)
444445

445-
# TODO: This check does not work with FakeTensor inputs
446-
"""
447446
num_classes = input.shape[1] if input.ndim > 1 else input.shape[0]
448-
valid_classes_mask = torch.logical_and((flat_target >= 0), (flat_target < num_classes))
449-
if not torch.all(torch.logical_or(ignore_classes_mask, valid_classes_mask)):
450-
raise ValueError("Target class is out-of-bounds and not ignore index")
451-
"""
447+
valid_classes_mask = torch.logical_and(
448+
(flat_target >= 0), (flat_target < num_classes)
449+
)
450+
class_check = torch.all(torch.logical_or(ignore_classes_mask, valid_classes_mask))
451+
452+
# TODO: This check does not work with FakeTensor inputs
453+
# Explicit cast for class_check to bool; See Issue #78071
454+
utils.check(
455+
isinstance(target, FakeTensor) or bool(class_check.item()),
456+
lambda: "A target class is out-of-bounds and not the ignore index.",
457+
)
452458

453459
ignore_class_weight = torch.scalar_tensor(0, dtype=input.dtype, device=input.device)
454460
class_weight = (

0 commit comments

Comments
 (0)
0