torch.nn.functional.one_hot has inconsistent behavior between eager and torch.compile when num_classes=0 #146274

Open
meetmul opened this issue Feb 2, 2025 · 4 comments · May be fixed by #146466
Labels
actionable · module: fakeTensor · module: pt2-dispatcher · oncall: pt2 · triaged

Comments

meetmul commented Feb 2, 2025

🐛 Describe the bug

When num_classes=0, torch.nn.functional.one_hot throws "Class values must be smaller than num_classes." under eager, but returns an empty tensor under torch.compile.

import torch

f = torch.nn.functional.one_hot
a = torch.arange(0, 5) % 3  # [0, 1, 2, 0, 1]
num_classes = 0

try:
    f(a, num_classes)  # eager: raises
except Exception as e:
    print("Error on eager: ", str(e))

res = torch.compile(f)(a, num_classes)  # compiled: silently returns an empty tensor
print("Output under torch.compile: ", res)

Error logs

Error on eager: Class values must be smaller than num_classes.
Output under torch.compile: tensor([], size=(5, 0), dtype=torch.int64)

Versions

[pip3] numpy==1.26.2
[pip3] nvidia-cublas-cu12==12.4.5.8
[pip3] nvidia-cuda-cupti-cu12==12.4.127
[pip3] nvidia-cuda-nvrtc-cu12==12.4.127
[pip3] nvidia-cuda-runtime-cu12==12.4.127
[pip3] nvidia-cudnn-cu12==9.1.0.70
[pip3] nvidia-cufft-cu12==11.2.1.3
[pip3] nvidia-curand-cu12==10.3.5.147
[pip3] nvidia-cusolver-cu12==11.6.1.9
[pip3] nvidia-cusparse-cu12==12.3.1.170
[pip3] nvidia-nccl-cu12==2.21.5
[pip3] nvidia-nvjitlink-cu12==12.4.127
[pip3] nvidia-nvtx-cu12==12.4.127
[pip3] optree==0.13.1
[pip3] torch==2.5.1
[pip3] triton==3.1.0
[conda] numpy 1.26.2 pypi_0 pypi
[conda] nvidia-cublas-cu12 12.4.5.8 pypi_0 pypi
[conda] nvidia-cuda-cupti-cu12 12.4.127 pypi_0 pypi
[conda] nvidia-cuda-nvrtc-cu12 12.4.127 pypi_0 pypi
[conda] nvidia-cuda-runtime-cu12 12.4.127 pypi_0 pypi
[conda] nvidia-cudnn-cu12 9.1.0.70 pypi_0 pypi
[conda] nvidia-cufft-cu12 11.2.1.3 pypi_0 pypi
[conda] nvidia-curand-cu12 10.3.5.147 pypi_0 pypi
[conda] nvidia-cusolver-cu12 11.6.1.9 pypi_0 pypi
[conda] nvidia-cusparse-cu12 12.3.1.170 pypi_0 pypi
[conda] nvidia-nccl-cu12 2.21.5 pypi_0 pypi
[conda] nvidia-nvjitlink-cu12 12.4.127 pypi_0 pypi
[conda] nvidia-nvtx-cu12 12.4.127 pypi_0 pypi
[conda] optree 0.13.1 pypi_0 pypi
[conda] torch 2.5.1 pypi_0 pypi
[conda] triton 3.1.0 pypi_0 pypi

cc @chauhang @penguinwu @eellison @zou3519 @bdhirsh @yf225

anijain2305 commented

cc @bdhirsh, seems like an aot_eager issue. I think this is the decomposition, but I could be wrong.

anijain2305 added the triaged and module: pt2-dispatcher labels on Feb 3, 2025
bdhirsh commented Feb 3, 2025

hmm @anijain2305 it looks like when we run with backend="eager":

(1) dynamo traces out a graph with one_hot in it
(2) when we interpret the graph at runtime, we raise the error.

I think this is probably bad though (dynamo should be raising the error at compile time, not at runtime), since dynamo does things like unrolling torch.no_grad context managers, which relies on the invariant that the runtime code will not throw.
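
A minimal sketch of that behavior, assuming the setup from the original repro and the PyTorch 2.5.1 environment reported above (the logs don't show the exact exception type, so Exception is caught broadly):

import torch

a = torch.arange(0, 5) % 3  # [0, 1, 2, 0, 1]

# With backend="eager", dynamo traces one_hot into a graph; per the comment
# above, the "Class values must be smaller than num_classes." error only
# surfaces when the traced graph is interpreted at runtime.
compiled = torch.compile(torch.nn.functional.one_hot, backend="eager")
try:
    compiled(a, 0)
except Exception as e:
    print("Raised while running the traced graph: ", e)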

bdhirsh commented Feb 3, 2025

I can repro with just a fake tensor; it looks like the one_hot decomp doesn't raise when run under FakeTensorMode:

import torch

num_classes = 0

with torch._subclasses.FakeTensorMode():
    a = torch.arange(0, 5) % 3  # [0, 1, 2, 0, 1]
    torch.nn.functional.one_hot(a, num_classes)  # no error under FakeTensorMode

a = torch.arange(0, 5) % 3  # [0, 1, 2, 0, 1]
torch.nn.functional.one_hot(a, num_classes)  # raises "Class values must be smaller than num_classes."

bdhirsh commented Feb 4, 2025

Looks like someone manually wrote a custom "compile-friendly" version in the C++ code here, but the error checking wasn't included in that path: https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/Onehot.cpp#L23

The easiest fix is probably to take the error-checking code from a few lines further down and add it to this code path.
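
In Python terms, a rough sketch of the check that appears to be missing from the compile-friendly path (checked_one_hot is a hypothetical illustrative wrapper, not part of PyTorch; the actual fix would be a TORCH_CHECK-style validation in Onehot.cpp):

import torch

def checked_one_hot(x: torch.Tensor, num_classes: int) -> torch.Tensor:
    # Hypothetical user-level guard mirroring the eager validation that the
    # "compile-friendly" C++ branch skips. Note that x.max() is data-dependent,
    # so the real fix belongs in the ATen kernel, not in user code.
    if num_classes != -1 and x.numel() > 0 and int(x.max()) >= num_classes:
        raise RuntimeError("Class values must be smaller than num_classes.")
    return torch.nn.functional.one_hot(x, num_classes)

a = torch.arange(0, 5) % 3
checked_one_hot(a, 0)  # raises, matching eager behavior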
