[ONNX] Support float4 #151069
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/151069
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 016340a with merge base 3a90fd4. This comment was automatically generated by Dr. CI and updates every 15 minutes.
Copilot reviewed 6 out of 6 changed files in this pull request and generated no comments.
Comments suppressed due to low confidence (2)
torch/onnx/_internal/exporter/_type_casting.py:8
- The FIXME indicates that the expected shape for unpacked float4x2 tensors is not fully understood. Please add detailed documentation or tests to clarify the intended behavior once determined.
# FIXME: Figure out what the shape really means
test/onnx/exporter/test_core.py:79
- [nitpick] Consider adding tests using multi-element tensors to better validate the unpacking and conversion behavior of the float4 implementation.
tensor = _core.TorchTensor(torch.tensor([1], dtype=torch.uint8).view(torch.float4_e2m1fn_x2))
Copilot reviewed 7 out of 7 changed files in this pull request and generated no comments.
Comments suppressed due to low confidence (2)
torch/onnx/_internal/exporter/_type_casting.py:10
- The unpack_float4x2_as_uint8 function doubles the number of elements from the input tensor, which may conflict with the expected byte size when the result is later viewed as a FLOAT4. Please verify that this doubling combined with the custom numpy dtype conversion produces the intended storage size.
result_size = tensor.numel() * 2
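For context, the doubling comes from each uint8 byte packing two 4-bit FLOAT4 values, so unpacking one byte must yield two elements. A minimal pure-Python sketch of that nibble split (the helper name `unpack_nibbles` and the low-nibble-first ordering are assumptions for illustration, not the PR's actual `unpack_float4x2_as_uint8`):

```python
def unpack_nibbles(packed: bytes) -> list[int]:
    """Split each byte into its low and high 4-bit halves.

    One packed byte yields two values, so the output has
    2 * len(packed) elements -- the doubling discussed above.
    """
    out = []
    for byte in packed:
        out.append(byte & 0x0F)         # low nibble (assumed first)
        out.append((byte >> 4) & 0x0F)  # high nibble
    return out

# A single packed byte 0x21 holds the 4-bit values 1 (low) and 2 (high).
print(unpack_nibbles(b"\x21"))  # [1, 2]
```

Whether the final view as a FLOAT4 numpy dtype re-collapses these two values into one storage byte is exactly the question raised above.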
test/onnx/exporter/test_core.py:82
- The test expects a one-byte representation (b"\x01") for a tensor converted from FLOAT4, yet the unpacking function generates an array with two elements per tensor element before the final view conversion. Confirm that the custom numpy dtype for FLOAT4 effectively consolidates the two unpacked values into a single byte.
self.assertEqual(tensor.tobytes(), b"\x01")
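The one-byte expectation is consistent if serialization re-packs each pair of unpacked 4-bit values into a single byte. A hedged sketch of that inverse operation (the helper `pack_nibbles` is hypothetical, not code from this PR):

```python
def pack_nibbles(values: list[int]) -> bytes:
    """Pack pairs of 4-bit values back into single bytes (low nibble first)."""
    assert len(values) % 2 == 0, "need an even number of nibbles"
    out = bytearray()
    for lo, hi in zip(values[0::2], values[1::2]):
        out.append((lo & 0x0F) | ((hi & 0x0F) << 4))
    return bytes(out)

# Unpacking byte 0x01 gives nibbles [1, 0]; packing them restores b"\x01",
# matching the one-byte serialization the test asserts.
print(pack_nibbles([1, 0]))  # b'\x01'
```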
Force-pushed from f218a47 to 6076f8e.
Does this mean that this PR allows the exporter to generate a model with IR version 10 and float4? Is that in the ONNX spec? It looks like this is a pre-release, but it will be available in the exporter once the PR is merged.
Admittedly, the model will not conform to the spec if its IR version is 10 and it uses float4; we need IR version 11 for that. But this change still allows users to get a model they can work with. In a follow-up I can add a warning mentioning the caveats, if you agree with the idea. This PR also has the side effect of addressing microsoft/onnxscript#2187; I will isolate that in a separate PR.
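The proposed follow-up warning could be sketched roughly as below (the function name, the `FLOAT4_MIN_IR_VERSION = 11` constant, and the wording are assumptions, not code from this PR):

```python
import warnings

# Assumption: float4 types become spec-conformant at ONNX IR version 11.
FLOAT4_MIN_IR_VERSION = 11

def warn_if_float4_unsupported(ir_version: int, uses_float4: bool) -> None:
    """Emit a caveat when a model using float4 targets an IR version below 11."""
    if uses_float4 and ir_version < FLOAT4_MIN_IR_VERSION:
        warnings.warn(
            f"Model uses float4 but targets IR version {ir_version}; "
            f"it will not conform to the ONNX spec until IR version "
            f"{FLOAT4_MIN_IR_VERSION}.",
            stacklevel=2,
        )
```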
Split the changes from #151069 to address microsoft/onnxscript#2187, where the output NumPy arrays do not have the expected ml_dtypes types. Pull Request resolved: #151259. Approved by: https://github.com/titaiwangms
@pytorchbot rebase
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here.
Rebase failed due to a command error. Raised by https://github.com/pytorch/pytorch/actions/runs/15051848411
torch.float4_e2m1fn_x2 was added to PyTorch in #148791 (comment) (added last dim with size 2). Fixes #150202.