[primTorch] Need to update data-dependent check policy · Issue #85834 · pytorch/pytorch · GitHub


@rdspring1

Description


🐛 Describe the bug

The following excerpt from #81128 checks whether the target values are within bounds in the nll_loss reference:

    num_classes = input.shape[1] if input.ndim > 1 else input.shape[0]
    valid_classes_mask = torch.logical_and(
        (flat_target >= 0), (flat_target < num_classes)
    )
    class_check = torch.all(torch.logical_or(ignore_classes_mask, valid_classes_mask))

    # TODO: This check does not work with FakeTensor inputs
    # Explicit cast for class_check to bool; See Issue #78071
    utils.check(
        isinstance(target, FakeTensor) or bool(class_check.item()),
        lambda: "A target class is out-of-bounds and not the ignore index.",
    )
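The guard in the excerpt relies on short-circuit evaluation: when target is a FakeTensor, bool(class_check.item()) is never evaluated, so the data-dependent read is skipped and the bounds check does nothing. Below is a minimal pure-Python sketch of that policy, not the actual reference code. It assumes plain lists in place of tensors, a hypothetical MetaTarget class standing in for FakeTensor (a shape but no values), ignore_classes_mask meaning "target equals ignore_index" (its definition is not shown in the excerpt), and ignore_index=-100, the default of torch.nn.functional.nll_loss.

```python
from dataclasses import dataclass
from typing import Sequence, Tuple, Union


@dataclass
class MetaTarget:
    """Hypothetical stand-in for a FakeTensor: it has a shape but no values."""
    shape: Tuple[int, ...]


def targets_in_bounds(values: Sequence[int], num_classes: int, ignore_index: int) -> bool:
    # Mirrors torch.all(torch.logical_or(ignore_classes_mask, valid_classes_mask)):
    # every target is either the ignore index or a class in [0, num_classes).
    return all(t == ignore_index or 0 <= t < num_classes for t in values)


def check_targets(
    target: Union[Sequence[int], MetaTarget],
    num_classes: int,
    ignore_index: int = -100,  # assumed default, as in torch.nn.functional.nll_loss
) -> None:
    # Same short-circuit as the excerpt: a value-less (fake) target skips the
    # check, so the data-dependent step (reading the values) never runs.
    if isinstance(target, MetaTarget) or targets_in_bounds(target, num_classes, ignore_index):
        return
    raise ValueError("A target class is out-of-bounds and not the ignore index.")


# Usage
check_targets([0, 2, -100], num_classes=3)            # in bounds or ignored: passes
check_targets(MetaTarget(shape=(4,)), num_classes=3)  # no values: check is skipped
try:
    check_targets([0, 3], num_classes=3)              # 3 is out of bounds
except ValueError as e:
    print(e)  # prints "A target class is out-of-bounds and not the ignore index."
```

Because the guard short-circuits, a fake target is never validated at all; that gap is what the TODO in the excerpt records.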

Versions

PyTorch upstream commit: a4bd89b

cc @ezyang @mruberry @ngimel @lezcano @fdrocha

Metadata

Assignees: No one assigned

Labels: module: primTorch, triaged
