8000 Small improvements to NJT matrix multiplies by michael-diggin · Pull Request #146405 · pytorch/pytorch · GitHub
[go: up one dir, main page]

Skip to content

Small improvements to NJT matrix multiplies #146405

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 5 commits into from

Conversation

michael-diggin
Copy link
Contributor

Fixes #146404

Adds changes to the matmul and matmul_backward operation for nested jagged tensors, to support back propagation when the output is a regular strided tensor.
This required adding support for the nested matmul operation to work when the nested tensor wasn't 'self', i.e
A@B where A isn't nested but B is.

The operation schemas had to be updated to reflect that either input can be a strided tensor instead (and the gradient), so an extra assertion is added in an edge case where neither input is nested.

Unit tests are also added.

< 8000 div class="AvatarStack-body" > @michael-diggin
Copy link
pytorch-bot bot commented Feb 4, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/146405

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEVs

There are 1 currently active SEVs. If your PR is affected, please view them below:

✅ No Failures

As of commit 2b2d432 with merge base 57b1fc3 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@michael-diggin
Copy link
Contributor Author

@pytorchbot label "topic: not user facing"

Copy link
Contributor
@soulitzer soulitzer left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR! Added some small comments

Btw, there's some existing infra around testing that we can use that will automatically test forward, backward, compile, etc. (See the tests under TestNestedTensorOpInfo)

To test the new inputs that you are supporting here, instead of adding new one-off tests here, you can add the inputs that you would like to test to sample_inputs_matmul in torch/testing/_internal/opinfo/definitions/nested.py

@soulitzer soulitzer added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label Feb 5, 2025
@michael-diggin
Copy link
Contributor Author

Thanks for the PR! Added some small comments

Btw, there's some existing infra around testing that we can use that will automatically test forward, backward, compile, etc. (See the tests under TestNestedTensorOpInfo)

To test the new inputs that you are supporting here, instead of adding new one-off tests here, you can add the inputs that you would like to test to sample_inputs_matmul in torch/testing/_internal/opinfo/definitions/nested.py

Thanks for the review @soulitzer! I've made changes based on the comments, and updated the tests to use the existing test infra/setup which is much nicer.

Copy link
Contributor
@soulitzer soulitzer left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for quick updates! Just have another nit

@michael-diggin
Copy link
Contributor Author

Thanks for quick updates! Just have another nit

Thanks for the quick review @soulitzer! Just committed that small change. Could you rerun the workflows when you get a chance? As I think that's needed before I can merge.

Copy link
Contributor
@jbschlosser jbschlosser left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Awesome, thanks for the PR! looks solid on my end, thanks for updating the generated sample inputs to cover this case :)

@jbschlosser jbschlosser added topic: bug fixes topic category release notes: nested tensor Changes that have a direct impact on nested tensors and removed topic: not user facing topic category labels Feb 5, 2025
@soulitzer
Copy link
Contributor

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Feb 6, 2025
@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ciflow/trunk Trigger trunk jobs on your pull request Merged open source release notes: nested tensor Changes that have a direct impact on nested tensors topic: bug fixes topic category triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Can't back prop through NJT matrix multiplication when output is strided tensor
5 participants
0