|
1 | 1 | from typing import Optional, Union
|
2 | 2 |
|
3 | 3 | import torch
|
4 |
| - |
5 | 4 | import torch._prims as prims
|
6 | 5 | import torch._prims_common as utils
|
7 | 6 | import torch._refs as refs
|
|
24 | 23 | _make_elementwise_unary_reference,
|
25 | 24 | )
|
26 | 25 |
|
| 26 | +from torch._subclasses.fake_tensor import FakeTensor |
| 27 | + |
27 | 28 | __all__ = [
|
28 | 29 | "celu",
|
29 | 30 | "dropout",
|
@@ -442,13 +443,18 @@ def _nll_loss_nd(
|
442 | 443 | flat_target = torch.flatten(target)
|
443 | 444 | ignore_classes_mask = torch.eq(flat_target, ignore_index)
|
444 | 445 |
|
445 |
| - # TODO: This check does not work with FakeTensor inputs |
446 |
| - """ |
447 | 446 | num_classes = input.shape[1] if input.ndim > 1 else input.shape[0]
|
448 |
| - valid_classes_mask = torch.logical_and((flat_target >= 0), (flat_target < num_classes)) |
449 |
| - if not torch.all(torch.logical_or(ignore_classes_mask, valid_classes_mask)): |
450 |
| - raise ValueError("Target class is out-of-bounds and not ignore index") |
451 |
| - """ |
| 447 | + valid_classes_mask = torch.logical_and( |
| 448 | + (flat_target >= 0), (flat_target < num_classes) |
| 449 | + ) |
| 450 | + class_check = torch.all(torch.logical_or(ignore_classes_mask, valid_classes_mask)) |
| 451 | + |
| 452 | + # TODO: This check does not work with FakeTensor inputs |
| 453 | + # Explicit cast for class_check to bool; See Issue #78071 |
| 454 | + utils.check( |
| 455 | + isinstance(target, FakeTensor) or bool(class_check.item()), |
| 456 | + lambda: "A target class is out-of-bounds and not the ignore index.", |
| 457 | + ) |
452 | 458 |
|
453 | 459 | ignore_class_weight = torch.scalar_tensor(0, dtype=input.dtype, device=input.device)
|
454 | 460 | class_weight = (
|
|
0 commit comments