
[ONNX] Produce correct dtypes for bf16/f8 in IR TorchTensor #151259


Closed

Conversation

justinchuby
Collaborator

Split the changes from #151069 to address microsoft/onnxscript#2187, where the output np arrays do not have the expected ml_dtypes dtypes.

pytorch-bot bot commented Apr 14, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/151259

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 3cf10f7 with merge base 46ce8f7:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the release notes: onnx label (torch.onnx related changes that should show up in the release notes) on Apr 14, 2025
@justinchuby justinchuby added the module: onnx (Related to torch.onnx) and topic: bug fixes (topic category) labels on Apr 14, 2025
@justinchuby justinchuby force-pushed the justinchu/fix-bfloat16-optimize branch from 52c6def to 3f0e0cb on April 14, 2025 19:50
@justinchuby justinchuby added the ciflow/trunk label (Trigger trunk jobs on your pull request) on Apr 15, 2025
justinchuby added a commit to microsoft/onnxscript that referenced this pull request Apr 15, 2025
Bring changes from pytorch/pytorch#151259 to correctly support bfloat16 and float8* types.
@@ -15,17 +16,17 @@ class TorchTensorTest(common_utils.TestCase):
     @common_utils.parametrize(
         "dtype, np_dtype",
         [
-            (torch.bfloat16, np.uint16),
Collaborator

I don't understand why this is wrong and still passed the test?

Collaborator

Is the test still legit?

Collaborator Author

The test is changed along with the implementation, so they are updated together to reflect the new behavior.
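
For context, a self-contained check of the new behavior might look like the sketch below (hypothetical code, not the actual TorchTensorTest; the updated + side of the hunk is not shown in this thread). It assumes torch, numpy, and the ml_dtypes package are available:

import ml_dtypes
import numpy as np
import torch

# Presumed new expectation: the numpy array produced for a torch.bfloat16
# tensor carries the ml_dtypes.bfloat16 dtype rather than plain uint16.
tensor = torch.tensor([1.0, 2.0], dtype=torch.bfloat16)
array = tensor.view(torch.uint16).numpy(force=True).view(ml_dtypes.bfloat16)
assert array.dtype == ml_dtypes.bfloat16
np.testing.assert_array_equal(array.astype(np.float32), [1.0, 2.0])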

@@ -116,15 +116,17 @@ def __init__(self, tensor: torch.Tensor, name: str | None = None):
     def numpy(self) -> npt.NDArray:
         self.raw: torch.Tensor
         if self.dtype == ir.DataType.BFLOAT16:
-            return self.raw.view(torch.uint16).numpy(force=True)
+            return (
+                self.raw.view(torch.uint16).numpy(force=True).view(self.dtype.numpy())
+            )
Collaborator

Can you explain this workaround? Looks like it's getting around the issue that numpy does not support bfloat16?

Collaborator Author

Sure! Since there is no bfloat16 in numpy, converting a tensor directly from torch would fail. Thus we view it as uint16 first in torch, get the numpy representation, and then re-view it with the ml_dtypes type we get from the ONNX IR dtype.numpy() call, producing the array ONNX IR expects (which has dtype ml_dtypes.bfloat16).
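
A minimal sketch of the round trip described above, assuming torch, numpy, and ml_dtypes are installed (illustrative, not the exact PR code):

import ml_dtypes
import numpy as np
import torch

t = torch.tensor([1.5, -2.25], dtype=torch.bfloat16)

# numpy has no native bfloat16, so t.numpy() would fail. Instead:
# 1. reinterpret the bf16 bits as uint16 inside torch (both are 16-bit),
# 2. cross into numpy with that bit pattern,
# 3. re-view the uint16 bits as ml_dtypes.bfloat16, the dtype ONNX IR expects.
arr = t.view(torch.uint16).numpy(force=True).view(ml_dtypes.bfloat16)

print(arr.dtype)  # bfloat16
print(arr)        # [1.5 -2.25]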

Collaborator

And uint16 is closer to bfloat16 in terms of the number of bits? Maybe worth mentioning somewhere.

Collaborator Author

Will add a comment. Yes, they are both 16-bit dtypes.
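
The same pattern presumably extends to the float8 formats named in the PR title, with uint8 as the 8-bit carrier type. A hedged sketch (assumes a torch build with float8_e4m3fn support and the ml_dtypes package; not the actual PR code):

import ml_dtypes
import torch

# Hypothetical float8 analogue: reinterpret the 8-bit float bits as uint8
# in torch, cross into numpy, then re-view as the matching ml_dtypes type.
t = torch.tensor([0.5, 1.0]).to(torch.float8_e4m3fn)
arr = t.view(torch.uint8).numpy(force=True).view(ml_dtypes.float8_e4m3fn)

print(arr.dtype)  # float8_e4m3fn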

@justinchuby
Collaborator Author

@pytorchbot merge

@justinchuby
Collaborator Author

Will add comments separately since this is ready to merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.

timocafe pushed a commit to timocafe/pytorch that referenced this pull request Apr 16, 2025
…151259)

Split the changes from pytorch#151069 to address microsoft/onnxscript#2187, where the output np arrays do not have the correct ml_dtypes types as expected.
Pull Request resolved: pytorch#151259
Approved by: https://github.com/titaiwangms
amathewc pushed a commit to amathewc/pytorch that referenced this pull request Apr 17, 2025
…151259)

Split the changes from pytorch#151069 to address microsoft/onnxscript#2187, where the output np arrays do not have the correct ml_dtypes types as expected.
Pull Request resolved: pytorch#151259
Approved by: https://github.com/titaiwangms
Labels
ciflow/trunk (Trigger trunk jobs on your pull request), Merged, module: onnx (Related to torch.onnx), open source, release notes: onnx (torch.onnx related changes that should show up in the release notes), topic: bug fixes (topic category)
4 participants