-
Notifications
You must be signed in to change notification settings - Fork 25.7k
[primTorch] Implement NLL loss reference #81128
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
🔗 Helpful links
❌ 6 New FailuresAs of commit 370bc60 (more details on the Dr. CI page): Expand to see more
🕵️ 6 new failures recognized by patternsThe following CI failures do not appear to be due to upstream breakages
|
#79820 was merged. Is advanced indexing still a blocker? What exactly doesn't work? |
https://github.com/pytorch/pytorch/pull/81128/files#diff-93c7b95139f636278cc494028e322a2c3c3c9ba1e83b2adb35d54ccabed5b47aR431 |
Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as |
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/81128
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit 3cd82ab: This comment was automatically generated by Dr. CI and updates every 15 minutes. |
else: | ||
result = _nll_loss_nd(input, target, weight, reduction, ignore_index) | ||
return torch.reshape(result, out_size) | ||
else: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add a comment describing what this else branch is for.
From a code organization and readability standpoing these branches seem a little odd. Maybe we can explain them better?
In particular -- can input be zero or one dimension? If so, how do we interpret that? The documentation for suggests that input should have at least two dimensions. And why are inputs with three or four dimensions special?
Finally, prefer putting shorter branches which short-circuit first. That typically lets code have fewer indentation levels:
# shortcircuits if foo because...
if foo:
return x
# implicit else branch doesn't have to be indented
...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I refactored _nll_loss_nd to handle 1-3 dimensions. If there are more than 3 dimensions, the k-dimension is flattened to create a 3D tensor. The Aten implementation used a 4D case for image inputs.
# The _nll_loss_nd helper function handles the most common cases.
# ndim == 1 (Single Example)
# => Batch Size: 1, Input: (C), Target: ()
# ndim == 2 (k = 1)
# => Batch Size: N, Input: (N, C), Target: (N)
# ndim == 3 (k > 1)
# => Batch Size: N, Input: (N, C, K), Target: (N, K)
# ndim > 3
# => reshape the input and target to the 3-D case
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is the 4D case interesting to model here?
/easycla As part of the transition to the PyTorch Foundation, this project now requires contributions be covered under the new CLA. See #85559 for additional details. This comment will trigger a new check of this PR. If you are already covered, you will simply see a new "EasyCLA" check that passes. If you are not covered, a bot will leave a new comment with a link to sign. |
utils.check( | ||
isinstance(target, FakeTensor) or bool(class_check.item()), | ||
lambda: "A target class is out-of-bounds and not the ignore index.", | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Comment this out for now until we have a debug mode for data-dependent checks.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Cool! Let's just update the data-dependent check per @IvanYashchuk's comment
|
||
|
||
@register_decomposition(torch.ops.aten.nll_loss) | ||
def nll_loss( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
try wrapping with type promotion decorator
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
Hey @rdspring1. |
Add Reference:
Depends on: