Fix one_hot inconsistent errors after compile by zeshengzong · Pull Request #146466 · pytorch/pytorch · GitHub

Fix one_hot inconsistent errors after compile #146466


Open

zeshengzong wants to merge 3 commits into main from fix/aten/one_hot

Conversation

zeshengzong (Contributor) commented Feb 5, 2025

Fixes #146274

Test Result

>>> import torch
>>> f = torch.nn.functional.one_hot
>>> a = torch.arange(0, 5) % 3  # [0,1,2,0,1]
>>> num_classes = 0
>>> torch.nn.functional.one_hot(a,num_classes)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: Class values must be smaller than num_classes.

>>> torch.compile(torch.nn.functional.one_hot)(a,num_classes)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/zong/code/pytorch/torch/_dynamo/eval_frame.py", line 570, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/zong/code/pytorch/torch/_dynamo/external_utils.py", line 48, in inner
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
RuntimeError: Class values must be smaller than num_classes.
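
For comparison, with a valid num_classes both the eager and compiled runs agree (illustrative output, same session as above):

>>> torch.nn.functional.one_hot(a, 3).shape
torch.Size([5, 3])
>>> torch.compile(torch.nn.functional.one_hot)(a, 3).shape
torch.Size([5, 3])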

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames @bdhirsh

pytorch-bot bot commented Feb 5, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/146466

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 1 Unrelated Failure

As of commit fa04893 with merge base de68ddc:

NEW FAILURE - The following job has failed:

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

zeshengzong (Contributor Author) commented:

@pytorchbot label "topic: not user facing"

@zeshengzong zeshengzong marked this pull request as ready for review February 5, 2025 06:26
@pytorch-bot pytorch-bot bot added the topic: not user facing label Feb 5, 2025
zou3519 (Contributor) left a comment


Looks good if the tests pass, but I wonder if the reason the error checks weren't there before is that they aren't compile-friendly.

@cpuhrsch cpuhrsch added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Feb 8, 2025
@cpuhrsch cpuhrsch requested a review from bdhirsh February 8, 2025 01:37
zeshengzong (Contributor Author) commented:

Should compile's behavior be consistent with eager mode for parameter checks? There are other issues about this kind of difference, like #144183. But if no error is raised, users may not be aware that they are getting a wrong result. Thanks! :D

bdhirsh (Contributor) commented Feb 10, 2025

@zeshengzong can you add a test? The simplest place would probably be in test_repros.py - you can include a basic test that tries to compile one_hot with an invalid num_classes and asserts that it raises an error.
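
A minimal sketch of what such a test might look like (hypothetical test name and placement; it assumes the error message shown in the eager run above):

import torch
from torch._dynamo.test_case import TestCase, run_tests

class OneHotReproTest(TestCase):
    def test_one_hot_invalid_num_classes(self):
        def fn(x):
            # num_classes=0 is invalid: class values can never be smaller than it
            return torch.nn.functional.one_hot(x, num_classes=0)

        x = torch.arange(0, 5) % 3
        # Eager raises today; after the fix, compile should raise the same
        # error instead of silently returning a wrong result.
        with self.assertRaisesRegex(RuntimeError, "smaller than num_classes"):
            fn(x)
        with self.assertRaisesRegex(RuntimeError, "smaller than num_classes"):
            torch.compile(fn, backend="eager")(x)

if __name__ == "__main__":
    run_tests()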

bdhirsh (Contributor) commented Feb 10, 2025

(The extra checks here look compile-friendly - the decomp itself is already not very compile-friendly, since it induces an h2d sync through the .item() call - but the check itself is just an assert on top of that.)
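
For context, a rough Python paraphrase of the eager logic under discussion (the real implementation is C++ in ATen; names and structure here are illustrative, not the actual patch):

import torch

def one_hot_sketch(self, num_classes=-1):
    if num_classes == -1:
        # This .item() reads a scalar back from the device - the sync
        # mentioned above. It exists with or without the new checks.
        num_classes = int(self.max().item()) + 1
    if num_classes <= 0:
        raise RuntimeError("num_classes should be positive")
    # The new check is a plain assert layered on top of that value:
    if int(self.max().item()) >= num_classes:
        raise RuntimeError("Class values must be smaller than num_classes.")
    return torch.nn.functional.one_hot(self, num_classes)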

zeshengzong (Contributor Author) commented:

Test case added, thanks!

zeshengzong (Contributor Author) commented:

@bdhirsh Hello, is there a way to test for FakeTensor in ATen? We might need to skip these checks for it. Thanks!
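
For reference, the FakeTensor constraint can be observed from the Python side (illustrative only; the ATen code below would need the C++ equivalent of such a guard):

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode():
    t = torch.empty(5, dtype=torch.long)  # a FakeTensor: shape/dtype only, no data
    try:
        t.max().item()  # data-dependent: a fake tensor has no values to read
    except Exception as e:
        print(type(e).__name__)  # a data-dependent-output error is expected here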

num_classes = self.max().item().toLong() + 1;
} else {
check_num_classes(self, num_classes);
Contributor commented on this diff:

@zeshengzong the test failures are probably coming from the fact that there are places where your code now calls .item() where it previously did not.

Under compile, we shouldn't try to raise the error here if computing it requires an h2d sync (via the item() call). So the first thing you'll have to do is tweak this condition to only run in the same cases where item() was used before.

From looking at the code, I think we can always do the "num_classes should be positive" check, but we can only do the "Class values must be smaller than num_classes" check if the num_classes field has already been computed.
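
One possible reading of that suggestion, as a Python stand-in for the C++ (can_read_values is a hypothetical helper, not an existing API; the real patch would need the ATen equivalent):

import torch
from torch._subclasses.fake_tensor import FakeTensor

def can_read_values(t):
    # Hypothetical guard: rule out tensors whose data cannot be read,
    # e.g. fake/meta tensors seen during compile tracing.
    return not isinstance(t, FakeTensor) and not t.is_meta

def one_hot_checked(self, num_classes=-1):
    if num_classes != -1 and num_classes <= 0:
        # Value-independent check: needs no tensor data, always safe.
        raise RuntimeError("num_classes should be positive")
    if num_classes == -1:
        # Pre-existing sync point: infer the class count from the data.
        num_classes = int(self.max().item()) + 1
    elif can_read_values(self):
        # Data-dependent check, run only where reading values is allowed.
        if int(self.max().item()) >= num_classes:
            raise RuntimeError("Class values must be smaller than num_classes.")
    return torch.nn.functional.one_hot(self, num_classes)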

@zeshengzong zeshengzong force-pushed the fix/aten/one_hot branch 2 times, most recently from 26fc162 to 0bfac06 on February 14, 2025 07:57
zeshengzong (Contributor Author) commented:

@pytorchbot rebase -b main

pytorchmergebot (Collaborator) commented:

@pytorchbot started a rebase job onto refs/remotes/origin/main. Check the current status here

pytorchmergebot (Collaborator) commented:

Successfully rebased fix/aten/one_hot onto refs/remotes/origin/main, please pull locally before adding more changes (for example, via git checkout fix/aten/one_hot && git pull --rebase)

shaoyuyoung (Contributor) commented:

Hi, any updates on this fix?

Labels
module: dynamo, open source, topic: not user facing, triaged
Development

Successfully merging this pull request may close these issues.

torch.nn.functional.one_hot has inconsistent behavior between eager and torch.compile when num_classes=0
7 participants