-
Notifications
You must be signed in to change notification settings - Fork 24.2k
Fix one_hot inconsistent errors after compile #146466
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/146466
Note: Links to docs will display an error until the docs builds have been completed. ❌ 1 New Failure, 1 Unrelated FailureAs of commit fa04893 with merge base de68ddc ( NEW FAILURE - The following job has failed:
UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
c95e6c3
to
27f709b
Compare
@pytorchbot label "topic: not user facing" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
looks good if the tests pass but i wonder if the reason why the error checks weren't there was because they aren't compile friendly
Is the behavior of compile should consistent with eager mode about param check? There're some other issues about this kind of difference, like #144183. But if not raise error, users may not be aware about wrong result they get. Thanks! :D |
@zeshengzong can you add a test? The simplest place would probably be in |
(the extra checks here look compile friendly - the decomp itself looks like it is already not very compile friendly, since it induces a h2d sync through the .item() call - but the check itself is just an assert on top of that) |
27f709b
to
c5b5f25
Compare
Test case added, thanks! |
@bdhirsh Hello, is there a way to test |
num_classes = self.max().item().toLong() + 1; | ||
num_classes = self.max().item().toLong() + 1; | ||
} else { | ||
check_num_classes(self, num_classes); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@zeshengzong the test failures are probably coming from the fact that there are places where you code now calls .item()
where we used to not.
Under compile, we shouldn't try to raise the error here if computing the error requires a h2d sync (via the item()
call). So the first thing you're going to have to do is tweak this condition to only run in the same cases that item()
was used before.
From looking at the code, I think.... we can always do the num_classes should be positive
check, but we can only do the Class values must be smaller than num_classes
check if the num_classes
field has already been computed
26fc162
to
0bfac06
Compare
@pytorchbot rebase -b main |
@pytorchbot started a rebase job onto refs/remotes/origin/main. Check the current status here |
Successfully rebased |
0bfac06
to
fa04893
Compare
Hi, any updates on this fix? |
Fixes #146274
Test Result
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames @bdhirsh