[ONNX] Produce correct dtypes for bf16/f8 in IR TorchTensor #151259
Conversation
🔗 Helpful links: see artifacts and rendered test results at hud.pytorch.org/pr/151259
✅ No failures as of commit 3cf10f7 with merge base 46ce8f7.
…bfloat/float8 arrays
Force-pushed from 52c6def to 3f0e0cb (Compare)
Bring changes from pytorch/pytorch#151259 to correctly support bfloat16 and float8* types.
@@ -15,17 +16,17 @@ class TorchTensorTest(common_utils.TestCase):
    @common_utils.parametrize(
        "dtype, np_dtype",
        [
            (torch.bfloat16, np.uint16),
I don't understand why this is wrong and still passed the test?
Is the test still legit?
The test is changed along with the implementation, so they are updated together to reflect the new behavior.
@@ -116,15 +116,17 @@ def __init__(self, tensor: torch.Tensor, name: str | None = None):
     def numpy(self) -> npt.NDArray:
         self.raw: torch.Tensor
         if self.dtype == ir.DataType.BFLOAT16:
-            return self.raw.view(torch.uint16).numpy(force=True)
+            return (
+                self.raw.view(torch.uint16).numpy(force=True).view(self.dtype.numpy())
Can you explain this workaround? It looks like it's getting around the issue that numpy does not support bfloat16?
Sure! Since there is no bfloat16 in numpy, converting a tensor directly from torch would fail. So we view it as uint16 first in torch, get the numpy representation, and then re-view it with the ml_dtypes type from the ONNX IR dtype.numpy() call to get an array ONNX IR expects (which has dtype ml_dtypes.bfloat16).
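The bit-reinterpretation trick described above can be sketched with plain numpy (no torch or ml_dtypes required): bfloat16 is just the top 16 bits of an IEEE float32, so viewing the same 16-bit pattern under a different dtype never touches the bytes. The values chosen here are assumptions for illustration; all three are exactly representable in bfloat16.

```python
import numpy as np

# bfloat16 is the high 16 bits of a float32. Emulate the "view as uint16,
# then reinterpret" workaround using only numpy.
f32 = np.array([1.0, -2.5, 3.140625], dtype=np.float32)

# Reinterpret the float32 bits as uint32 and keep the high 16 bits:
# that bit pattern is the bfloat16 encoding of the same value.
bf16_bits = (f32.view(np.uint32) >> 16).astype(np.uint16)

# Round-trip: place the 16 bits back into the high half of a uint32 and
# view as float32. Values exactly representable in bfloat16 survive intact.
restored = (bf16_bits.astype(np.uint32) << 16).view(np.float32)
print(restored)  # all three values round-trip exactly
```

This is why the uint16 detour is safe: both `.view()` calls are pure reinterpretations of the underlying buffer, so no rounding or conversion ever happens.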
And uint16 is closest to bfloat16 in terms of the number of bits? Maybe worth mentioning somewhere.
Will add a comment. Yes, they are both 16-bit dtypes.
@pytorchbot merge
Will add comments separately since this is ready to merge
Merge started: your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
…#151371) Follow up of #151259 Pull Request resolved: #151371 Approved by: https://github.com/titaiwangms
…151259) Split the changes from pytorch#151069 to address microsoft/onnxscript#2187, where the output np arrays do not have the correct ml_dtypes types as expected. Pull Request resolved: pytorch#151259 Approved by: https://github.com/titaiwangms