8000 [primTorch] Implement NLL loss reference by rdspring1 · Pull Request #81128 · pytorch/pytorch · GitHub
[go: up one dir, main page]

Skip to content

[primTorch] Implement NLL loss reference #81128

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 29 commits into from

Conversation

rdspring1
Copy link
Contributor

Add Reference:

  • nll_loss

Depends on:

@facebook-github-bot
Copy link
Contributor
facebook-github-bot commented Jul 8, 2022

🔗 Helpful links

❌ 6 New Failures

As of commit 370bc60 (more details on the Dr. CI page):

Expand to see more
  • 6/6 failures introduced in this PR

🕵️ 6 new failures recognized by patterns

The following CI failures do not appear to be due to upstream breakages

See GitHub Actions build pull / linux-bionic-cuda11.6-py3.7-gcc7 / test (default, 1, 4, linux.4xlarge.nvidia.gpu) (1/6)

Step: "Test" (full log | diagnosis details | 🔁 rerun)

2022-07-08T20:40:56.0459425Z RuntimeError: test_ops failed!
2022-07-08T20:40:53.4625808Z Generated XML report: test-reports/python-unittest/test_ops/TEST-TestCommonCUDA-20220708192405.xml
2022-07-08T20:40:53.9824697Z Generated XML report: test-reports/python-unittest/test_ops/TEST-TestCompositeComplianceCUDA-20220708192405.xml
2022-07-08T20:40:54.1438402Z Generated XML report: test-reports/python-unittest/test_ops/TEST-TestMathBitsCUDA-20220708192405.xml
2022-07-08T20:40:54.2217954Z Generated XML report: test-reports/python-unittest/test_ops/TEST-TestFakeTensorNonErroringCUDA-20220708192405.xml
2022-07-08T20:40:54.3220849Z Generated XML report: test-reports/python-unittest/test_ops/TEST-TestTagsCUDA-20220708192405.xml
2022-07-08T20:40:56.0453604Z Traceback (most recent call last):
2022-07-08T20:40:56.0454025Z   File "test/run_test.py", line 945, in <module>
2022-07-08T20:40:56.0456820Z     main()
2022-07-08T20:40:56.0457581Z   File "test/run_test.py", line 923, in main
2022-07-08T20:40:56.0459112Z     raise RuntimeError(err_message)
2022-07-08T20:40:56.0459425Z RuntimeError: test_ops failed!
2022-07-08T20:40:56.7955507Z 
2022-07-08T20:40:56.7956225Z real	77m0.930s
2022-07-08T20:40:56.7956822Z user	75m51.609s
2022-07-08T20:40:56.7957080Z sys	1m17.947s
2022-07-08T20:40:56.8009522Z ##[error]Process completed with exit code 1.
2022-07-08T20:40:56.8050673Z Prepare all required actions
2022-07-08T20:40:56.8051101Z Getting action download info
2022-07-08T20:40:57.0010552Z ##[group]Run ./.github/actions/get-workflow-job-id
2022-07-08T20:40:57.0010866Z with:
2022-07-08T20:40:57.0011349Z   github-token: ***

See GitHub Actions build pull / win-vs2019-cpu-py3 / test (default, 2, 2, windows.4xlarge) (2/6)

Step: "Test" (full log | diagnosis details | 🔁 rerun)

2022-07-08T19:56:24.2430512Z RuntimeError: test_ops failed!
2022-07-08T19:56:23.4579702Z Generated XML report: test-reports\python-unittest\test_ops\TEST-TestCompositeComplianceCPU-20220708193438.xml
2022-07-08T19:56:23.4580250Z Generated XML report: test-reports\python-unittest\test_ops\TEST-TestFakeTensorNonErroringCPU-20220708193438.xml
2022-07-08T19:56:23.4580733Z Generated XML report: test-reports\python-unittest\test_ops\TEST-TestMathBitsCPU-20220708193438.xml
2022-07-08T19:56:23.4581204Z Generated XML report: test-reports\python-unittest\test_ops\TEST-TestRefsOpsInfoCPU-20220708193438.xml
2022-07-08T19:56:23.4581656Z Generated XML report: test-reports\python-unittest\test_ops\TEST-TestTagsCPU-20220708193438.xml
2022-07-08T19:56:24.2429078Z Traceback (most recent call last):
2022-07-08T19:56:24.2429421Z   File "run_test.py", line 945, in <module>
2022-07-08T19:56:24.2429816Z     main()
2022-07-08T19:56:24.2430042Z   File "run_test.py", line 923, in main
2022-07-08T19:56:24.2430301Z     raise RuntimeError(err_message)
2022-07-08T19:56:24.2430512Z RuntimeError: test_ops failed!
2022-07-08T19:56:24.5498157Z 
2022-07-08T19:56:24.5498829Z (base) C:\actions-runner\_work\pytorch\pytorch\test>if ERRORLEVEL 1 goto fail 
2022-07-08T19:56:24.5500493Z 
2022-07-08T19:56:24.5500744Z (base) C:\actions-runner\_work\pytorch\pytorch\test>exit /b 1 
2022-07-08T19:56:24.5559168Z ##[error]Process completed with exit code 1.
2022-07-08T19:56:24.5705095Z Prepare all required actions
2022-07-08T19:56:24.57055
8000
11Z Getting action download info
2022-07-08T19:56:24.7653426Z Download action repository 'nick-fields/retry@71062288b76e2b6214ebde0e673ce0de1755740a' (SHA:71062288b76e2b6214ebde0e673ce0de1755740a)
2022-07-08T19:56:24.8893033Z ##[group]Run ./.github/actions/get-workflow-job-id
2022-07-08T19:56:24.8893268Z with:

See GitHub Actions build pull / linux-bionic-py3.7-clang9 / test (dynamo, 1, 2, linux.2xlarge) (3/6)

Step: "Test" (full log | diagnosis details | 🔁 rerun)

2022-07-08T19:48:30.5813591Z RuntimeError: test_ops failed!
2022-07-08T19:48:29.3346474Z Generated XML report: test-reports/python-unittest/test_ops/TEST-TestCompositeComplianceCPU-20220708192707.xml
2022-07-08T19:48:29.6525440Z Generated XML report: test-reports/python-unittest/test_ops/TEST-TestFakeTensorNonErroringCPU-20220708192707.xml
2022-07-08T19:48:29.7665588Z Generated XML report: test-reports/python-unittest/test_ops/TEST-TestMathBitsCPU-20220708192707.xml
2022-07-08T19:48:29.7789287Z Generated XML report: test-reports/python-unittest/test_ops/TEST-TestRefsOpsInfoCPU-20220708192707.xml
2022-07-08T19:48:29.8328756Z Generated XML report: test-reports/python-unittest/test_ops/TEST-TestTagsCPU-20220708192707.xml
2022-07-08T19:48:30.5809091Z Traceback (most recent call last):
2022-07-08T19:48:30.5809409Z   File "test/run_test.py", line 945, in <module>
2022-07-08T19:48:30.5810900Z     main()
2022-07-08T19:48:30.5811087Z   File "test/run_test.py", line 923, in main
2022-07-08T19:48:30.5813369Z     raise RuntimeError(err_message)
2022-07-08T19:48:30.5813591Z RuntimeError: test_ops failed!
2022-07-08T19:48:30.8404555Z 
2022-07-08T19:48:30.8404854Z real	21m34.532s
2022-07-08T19:48:30.8405303Z user	57m1.239s
2022-07-08T19:48:30.8405604Z sys	2m14.622s
2022-07-08T19:48:30.8437550Z ##[error]Process completed with exit code 1.
2022-07-08T19:48:30.8471564Z Prepare all required actions
2022-07-08T19:48:30.8471876Z Getting action download info
2022-07-08T19:48:31.0549718Z ##[group]Run ./.github/actions/get-workflow-job-id
2022-07-08T19:48:31.0549942Z with:
2022-07-08T19:48:31.0550336Z   github-token: ***

See GitHub Actions build pull / linux-bionic-py3.7-clang9 / test (crossref, 1, 2, linux.2xlarge) (4/6)

Step: "Test" (full log | diagnosis details | 🔁 rerun)

2022-07-08T19:56:22.9728990Z RuntimeError: test_ops failed!
2022-07-08T19:56:21.8603314Z Generated XML report: test-reports/python-unittest/test_ops/TEST-TestCompositeComplianceCPU-20220708192702.xml
2022-07-08T19:56:21.9040747Z Generated XML report: test-reports/python-unittest/test_ops/TEST-TestFakeTensorNonErroringCPU-20220708192702.xml
2022-07-08T19:56:22.0181635Z Generated XML report: test-reports/python-unittest/test_ops/TEST-TestMathBitsCPU-20220708192702.xml
2022-07-08T19:56:22.0304218Z Generated XML report: test-reports/python-unittest/test_ops/TEST-TestRefsOpsInfoCPU-20220708192702.xml
2022-07-08T19:56:22.3531252Z Generated XML report: test-reports/python-unittest/test_ops/TEST-TestTagsCPU-20220708192702.xml
2022-07-08T19:56:22.9724652Z Traceback (most recent call last):
2022-07-08T19:56:22.9724931Z   File "test/run_test.py", line 945, in <module>
2022-07-08T19:56:22.9726843Z     main()
2022-07-08T19:56:22.9727047Z   File "test/run_test.py", line 923, in main
2022-07-08T19:56:22.9728766Z     raise RuntimeError(err_message)
2022-07-08T19:56:22.9728990Z RuntimeError: test_ops failed!
2022-07-08T19:56:23.2846827Z 
2022-07-08T19:56:23.2847443Z real	29m27.117s
2022-07-08T19:56:23.2847729Z user	80m51.217s
2022-07-08T19:56:23.2847901Z sys	3m37.581s
2022-07-08T19:56:23.2879662Z ##[error]Process completed with exit code 1.
2022-07-08T19:56:23.2915362Z Prepare all required actions
2022-07-08T19:56:23.2915670Z Getting action download info
2022-07-08T19:56:23.4898311Z ##[group]Run ./.github/actions/get-workflow-job-id
2022-07-08T19:56:23.4898528Z with:
2022-07-08T19:56:23.4898925Z   github-token: ***

See GitHub Actions build pull / linux-bionic-py3.7-clang9 / test (default, 1, 2, linux.2xlarge) (5/6)

Step: "Test" (full log | diagnosis details | 🔁 rerun)

2022-07-08T19:48:05.0944462Z RuntimeError: test_ops failed!
2022-07-08T19:48:04.1499153Z Generated XML report: test-reports/python-unittest/test_ops/TEST-TestCompositeComplianceCPU-20220708192659.xml
2022-07-08T19:48:04.1932507Z Generated XML report: test-reports/python-unittest/test_ops/TEST-TestFakeTensorNonErroringCPU-20220708192659.xml
2022-07-08T19:48:04.3062902Z Generated XML report: test-reports/python-unittest/test_ops/TEST-TestMathBitsCPU-20220708192659.xml
2022-07-08T19:48:04.3185374Z Generated XML report: test-reports/python-unittest/test_ops/TEST-TestRefsOpsInfoCPU-20220708192659.xml
2022-07-08T19:48:04.3729409Z Generated XML report: test-reports/python-unittest/test_ops/TEST-TestTagsCPU-20220708192659.xml
2022-07-08T19:48:05.0939704Z Traceback (most recent call last):
2022-07-08T19:48:05.0940139Z   File "test/run_test.py", line 945, in <module>
2022-07-08T19:48:05.0941807Z     main()
2022-07-08T19:48:05.0942147Z   File "test/run_test.py", line 923, in main
2022-07-08T19:48:05.0944059Z     raise RuntimeError(err_message)
2022-07-08T19:48:05.0944462Z RuntimeError: test_ops failed!
2022-07-08T19:48:05.3824038Z 
2022-07-08T19:48:05.3824354Z real	21m12.879s
2022-07-08T19:48:05.3824562Z user	55m55.299s
2022-07-08T19:48:05.3824800Z sys	3m21.488s
2022-07-08T19:48:05.3857479Z ##[error]Process completed with exit code 1.
2022-07-08T19:48:05.3904239Z Prepare all required actions
2022-07-08T19:48:05.3904546Z Getting action download info
2022-07-08T19:48:05.6652772Z ##[group]Run ./.github/actions/get-workflow-job-id
2022-07-08T19:48:05.6652994Z with:
2022-07-08T19:48:05.6653320Z   github-token: ***

See GitHub Actions build pull / linux-focal-py3.7-gcc7 / test (default, 1, 2, linux.2xlarge) (6/6)

Step: "Test" (full log | diagnosis details | 🔁 rerun)

2022-07-08T19:54:15.6609434Z RuntimeError: test_ops failed!
2022-07-08T19:54:14.7824832Z Generated XML report: test-reports/python-unittest/test_ops/TEST-TestCompositeComplianceCPU-20220708193147.xml
2022-07-08T19:54:14.8260945Z Generated XML report: test-reports/python-unittest/test_ops/TEST-TestFakeTensorNonErroringCPU-20220708193147.xml
2022-07-08T19:54:14.9391172Z Generated XML report: test-reports/python-unittest/test_ops/TEST-TestMathBitsCPU-20220708193147.xml
2022-07-08T19:54:14.9517068Z Generated XML report: test-reports/python-unittest/test_ops/TEST-TestRefsOpsInfoCPU-20220708193147.xml
2022-07-08T19:54:15.0055807Z Generated XML report: test-reports/python-unittest/test_ops/TEST-TestTagsCPU-20220708193147.xml
2022-07-08T19:54:15.6605369Z Traceback (most recent call last):
2022-07-08T19:54:15.6605669Z   File "test/run_test.py", line 945, in <module>
2022-07-08T19:54:15.6607361Z     main()
2022-07-08T19:54:15.6607594Z   File "test/run_test.py", line 923, in main
2022-07-08T19:54:15.6609193Z     raise RuntimeError(err_message)
2022-07-08T19:54:15.6609434Z RuntimeError: test_ops failed!
2022-07-08T19:54:15.9168465Z 
2022-07-08T19:54:15.9169029Z real	22m34.790s
2022-07-08T19:54:15.9169363Z user	42m36.024s
2022-07-08T19:54:15.9169547Z sys	1m51.022s
2022-07-08T19:54:15.9207580Z ##[error]Process completed with exit code 1.
2022-07-08T19:54:15.9249800Z Prepare all required actions
2022-07-08T19:54:15.9250093Z Getting action download info
2022-07-08T19:54:16.1310483Z ##[group]Run ./.github/actions/get-workflow-job-id
2022-07-08T19:54:16.1310710Z with:
2022-07-08T19:54:16.1311038Z   github-token: ***

This comment was automatically generated by Dr. CI (expand for details).

Please report bugs/suggestions to the (internal) Dr. CI Users group.

Click here to manually regenerate this comment.

@IvanYashchuk
Copy link
Collaborator

#79820 was merged. Is advanced indexing still a blocker? What exactly doesn't work?

@rdspring1
Copy link
Contributor Author

@github-actions
Copy link
Contributor

Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

@github-actions github-actions bot added the Stale label Sep 24, 2022
@pytorch-bot
Copy link
pytorch-bot bot commented Sep 25, 2022

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/81128

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 3cd82ab:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@rdspring1 rdspring1 marked this pull request as ready for review September 26, 2022 21:32
else:
result = _nll_loss_nd(input, target, weight, reduction, ignore_index)
return torch.reshape(result, out_size)
else:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a comment describing what this else branch is for.

From a code organization and readability standpoing these branches seem a little odd. Maybe we can explain them better?

In particular -- can input be zero or one dimension? If so, how do we interpret that? The documentation for suggests that input should have at least two dimensions. And why are inputs with three or four dimensions special?

Finally, prefer putting shorter branches which short-circuit first. That typically lets code have fewer indentation levels:

# shortcircuits if foo because...
if foo:
  return x

# implicit else branch doesn't have to be indented
...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I refactored _nll_loss_nd to handle 1-3 dimensions. If there are more than 3 dimensions, the k-dimension is flattened to create a 3D tensor. The Aten implementation used a 4D case for image inputs.

    # The _nll_loss_nd helper function handles the most common cases.
    # ndim == 1 (Single Example)
    #   => Batch Size: 1, Input: (C), Target: ()
    # ndim == 2 (k = 1)
    #   => Batch Size: N, Input: (N, C), Target: (N)
    # ndim == 3 (k > 1)
    #   => Batch Size: N, Input: (N, C, K), Target: (N, K)
    # ndim > 3
    #   => reshape the input and target to the 3-D case

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the 4D case interesting to model here?

@albanD albanD added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label Oct 3, 2022
@facebook-github-bot
Copy link
Contributor

/easycla

As part of the transition to the PyTorch Foundation, this project now requires contributions be covered under the new CLA. See #85559 for additional details.

This comment will trigger a new check of this PR. If you are already covered, you will simply see a new "EasyCLA" check that passes. If you are not covered, a bot will leave a new comment with a link to sign.

@linux-foundation-easycla
Copy link
linux-foundation-easycla bot commented Oct 3, 2022

CLA Signed

The committers listed above are authorized under a signed CLA.

Comment on lines +467 to +470
utils.check(
isinstance(target, FakeTensor) or bool(class_check.item()),
lambda: "A target class is out-of-bounds and not the ignore index.",
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comment this out for now until we have a debug mode for data-dependent checks.

Copy link
Collaborator
@mruberry mruberry left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool! Let's just update the data-dependent check per @IvanYashchuk's comment

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Oct 14, 2022


@register_decomposition(torch.ops.aten.nll_loss)
def nll_loss(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

try wrapping with type promotion decorator

@rdspring1
Copy link
Contributor Author

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@github-actions
Copy link
Contributor

Hey @rdspring1.
You've committed this PR, but it does not have both a 'release notes: ...' and 'topics: ...' label. Please add one of each to the PR. The 'release notes: ...' label should represent the part of PyTorch that this PR changes (fx, autograd, distributed, etc) and the 'topics: ...' label should represent the kind of PR it is (not user facing, new feature, bug fix, perf improvement, etc). The list of valid labels can be found here for the 'release notes: ...' and here for the 'topics: ...'.
For changes that are 'topic: not user facing' there is no need for a release notes label.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ciflow/trunk Trigger trunk jobs on your pull request cla signed Merged module: primTorch open source Stale triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module
Projects
None yet
Development

Successfully merging this pull request may close these issues.

7 participants
0