[ONNX] Support float4 by justinchuby · Pull Request #151069 · pytorch/pytorch · GitHub

[ONNX] Support float4 #151069

Open
justinchuby wants to merge 7 commits into base: main

Conversation

justinchuby
Collaborator
@justinchuby justinchuby commented Apr 11, 2025
  • Support exporting float4 models. Note: the exporter currently uses IR version 10 universally, which does not include float4 support. Once ONNX Runtime and the rest of the ecosystem move to the new IR version 11, we should bump the exporter to version 11 as well.
  • The shape of the type is set according to "add torch.float4_e2m1fn_x2 to PyTorch" #148791 (comment): a last dimension of size 2 is added.
  • Use ml_dtypes types when converting to numpy, for consistency with ONNX IR.

Fix #150202
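
For readers unfamiliar with the format, here is a minimal sketch of how a single 4-bit E2M1 code maps to a value, assuming the usual E2M1 layout (1 sign bit, 2 exponent bits, 1 mantissa bit, exponent bias 1). The helper name `decode_e2m1` is hypothetical and is not part of this PR or of PyTorch.

```python
def decode_e2m1(nibble: int) -> float:
    """Decode one 4-bit E2M1 code (assumed layout: sign | exp exp | mantissa)."""
    if not 0 <= nibble <= 0xF:
        raise ValueError("expected a 4-bit value in the range 0..15")
    sign = -1.0 if nibble & 0b1000 else 1.0
    exponent = (nibble >> 1) & 0b11
    mantissa = nibble & 0b1
    if exponent == 0:
        # Subnormal: no implicit leading one, so the magnitude is 0 or 0.5.
        magnitude = mantissa * 0.5
    else:
        # Normal: implicit leading one and an exponent bias of 1.
        magnitude = (1.0 + mantissa * 0.5) * 2.0 ** (exponent - 1)
    return sign * magnitude


# The eight non-negative codes, in code order.
POSITIVE_VALUES = [decode_e2m1(code) for code in range(8)]
print(POSITIVE_VALUES)  # [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
```

Two such codes fit in one byte, which is why the packed `torch.float4_e2m1fn_x2` dtype is described as carrying an extra last dimension of size 2.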

pytorch-bot bot commented Apr 11, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/151069

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 016340a with merge base 3a90fd4:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the release notes: onnx torch.onnx related changes that should show up in the release notes label Apr 11, 2025
@justinchuby justinchuby requested a review from Copilot April 11, 2025 00:20
@justinchuby justinchuby added module: onnx Related to torch.onnx topic: new features topic category labels Apr 11, 2025
Contributor
@Copilot Copilot AI left a comment

Copilot reviewed 6 out of 6 changed files in this pull request and generated no comments.

Comments suppressed due to low confidence (2)

torch/onnx/_internal/exporter/_type_casting.py:8

  • The FIXME indicates that the expected shape for unpacked float4x2 tensors is not fully understood. Please add detailed documentation or tests to clarify the intended behavior once determined.
# FIXME: Figure out what the shape really means

test/onnx/exporter/test_core.py:79

  • [nitpick] Consider adding tests using multi-element tensors to better validate the unpacking and conversion behavior of the float4 implementation.
tensor = _core.TorchTensor(torch.tensor([1], dtype=torch.uint8).view(torch.float4_e2m1fn_x2))

Contributor
@Copilot Copilot AI left a comment

Copilot reviewed 7 out of 7 changed files in this pull request and generated no comments.

Comments suppressed due to low confidence (2)

torch/onnx/_internal/exporter/_type_casting.py:10

  • The unpack_float4x2_as_uint8 function doubles the number of elements from the input tensor, which may conflict with the expected byte size when the result is later viewed as a FLOAT4. Please verify that this doubling combined with the custom numpy dtype conversion produces the intended storage size.
result_size = tensor.numel() * 2
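
The doubling that this comment describes can be illustrated with a small pure-Python sketch. It is not the PR's `unpack_float4x2_as_uint8`; the function name `unpack_nibbles` is hypothetical, and storing the first element in the low 4 bits is an assumption made for illustration.

```python
def unpack_nibbles(packed: bytes) -> list[list[int]]:
    """Split every byte into [low_nibble, high_nibble].

    Assumption: the first 4-bit element sits in the low bits of each byte.
    """
    return [[byte & 0x0F, byte >> 4] for byte in packed]


pairs = unpack_nibbles(b"\x01\x72")
# Each input byte yields two 4-bit codes, so the element count doubles
# and the result gains a trailing dimension of size 2.
element_count = sum(len(pair) for pair in pairs)
print(pairs)          # [[1, 0], [2, 7]]
print(element_count)  # 4
```

Whether the serialized size stays at one byte per pair depends on how the unpacked codes are packed again when the tensor is converted to bytes, which is the point the reviewer asks to verify.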

test/onnx/exporter/test_core.py:82

  • The test expects a one-byte representation (b"\x01") for a tensor converted from FLOAT4, yet the unpacking function generates an array with two elements per tensor element before the final view conversion. Confirm that the custom numpy dtype for FLOAT4 effectively consolidates the two unpacked values into a single byte.
self.assertEqual(tensor.tobytes(), b"\x01")
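
As a sketch of why a one-byte result is plausible, the following pure-Python helper packs pairs of 4-bit codes back into bytes. The name `pack_nibbles` is hypothetical and the low-nibble-first order is an assumption; the actual conversion in the PR goes through a custom numpy dtype rather than code like this.

```python
def pack_nibbles(pairs) -> bytes:
    """Pack (low, high) pairs of 4-bit codes into one byte per pair.

    Assumption: the first element of each pair goes into the low 4 bits.
    """
    out = bytearray()
    for low, high in pairs:
        if not (0 <= low <= 0xF and 0 <= high <= 0xF):
            raise ValueError("each code must fit in 4 bits")
        out.append(low | (high << 4))
    return bytes(out)


# The byte 0x01 from the test holds the codes (1, 0) under this assumption,
# and packing them again gives back exactly one byte.
packed = pack_nibbles([(1, 0)])
print(packed)  # b'\x01'
```

Unpacking a byte and packing it again returns the original byte for either nibble order, so the expected `b"\x01"` does not depend on which order the implementation picks.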

@soulitzer soulitzer added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label Apr 11, 2025
@justinchuby justinchuby added the ciflow/trunk Trigger trunk jobs on your pull request label Apr 13, 2025
@titaiwangms
Collaborator

Support exporting float4 models (note: currently we use IR version 10 universally in the exporter, which does not include float4 support. Eventually, when ONNX Runtime and the ecosystem move to support the new IR version 11, we should bump our version to 11 in the exporter as well)

Does this mean that this PR allows the exporter to generate a model with IR version 10 and float4? Is that in the ONNX spec? It looks like this is a pre-release feature, but it will already be available in the exporter if the PR is merged.

@justinchuby
Collaborator Author
justinchuby commented Apr 14, 2025

Does this mean that this PR allows the exporter to generate a model with IR version 10 and float4? Is that in the ONNX spec?

Admittedly, the model will not conform to the spec if the IR version is 10 and it contains float4; IR version 11 is required for that. But this change still lets users get a model they can work with. In a follow-up I can add a warning mentioning the caveats, if you agree with the idea.

This PR also has a side effect of addressing microsoft/onnxscript#2187. I will isolate it in a separate PR.
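
The warning proposed in this comment could look roughly like the sketch below. Everything in it is hypothetical: `warn_if_float4_unsupported` is not an exporter API, and passing data types as a list of name strings (with `"FLOAT4E2M1"` as the float4 name) is an assumption made for illustration. Only the version threshold of 11 comes from the discussion.

```python
import warnings

# Minimum IR version that includes float4, per the discussion in this thread.
FLOAT4_MIN_IR_VERSION = 11


def warn_if_float4_unsupported(ir_version: int, dtype_names) -> bool:
    """Warn when a model uses float4 but declares an older IR version.

    Returns True if a warning was issued, False otherwise.
    """
    uses_float4 = any(name == "FLOAT4E2M1" for name in dtype_names)
    if uses_float4 and ir_version < FLOAT4_MIN_IR_VERSION:
        warnings.warn(
            f"Model uses float4 but declares IR version {ir_version}; "
            f"IR version {FLOAT4_MIN_IR_VERSION} or later is required for "
            "the model to conform to the ONNX spec.",
            UserWarning,
            stacklevel=2,
        )
        return True
    return False


# An IR version 10 model that contains a float4 tensor triggers the warning.
warn_if_float4_unsupported(10, ["FLOAT", "FLOAT4E2M1"])
```

A warning rather than an error keeps the behavior described in the comment: users still get a model, but are told it does not conform to the spec.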

@justinchuby justinchuby marked this pull request as draft April 14, 2025 19:39
pytorchmergebot pushed a commit that referenced this pull request Apr 15, 2025
Split the changes from #151069 to address microsoft/onnxscript#2187, where the output np arrays do not have the correct ml_dtypes types as expected.
Pull Request resolved: #151259
Approved by: https://github.com/titaiwangms
@justinchuby justinchuby marked this pull request as ready for review April 16, 2025 03:41
@justinchuby justinchuby requested a review from Copilot April 16, 2025 03:49
Contributor
@Copilot Copilot AI left a comment

Copilot reviewed 7 out of 7 changed files in this pull request and generated no comments.

@justinchuby
Collaborator Author

@pytorchbot rebase

@pytorchmergebot
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict.

@pytorchmergebot
Collaborator

Rebase failed due to Command git -C /home/runner/work/pytorch/pytorch rebase refs/remotes/origin/viable/strict pull/151069/head returned non-zero exit code 1

Rebasing (1/6)
Auto-merging torch/onnx/_internal/exporter/_core.py
CONFLICT (content): Merge conflict in torch/onnx/_internal/exporter/_core.py
Auto-merging torch/onnx/_internal/exporter/_dispatching.py
error: could not apply 1aaa93084f7... [ONNX] Support float4
hint: Resolve all conflicts manually, mark them as resolved with
hint: "git add/rm <conflicted_files>", then run "git rebase --continue".
hint: You can instead skip this commit: run "git rebase --skip".
hint: To abort and get back to the state before "git rebase", run "git rebase --abort".
hint: Disable this message with "git config set advice.mergeConflict false"
Could not apply 1aaa93084f7... [ONNX] Support float4

Raised by https://github.com/pytorch/pytorch/actions/runs/15051848411

Labels
ciflow/trunk Trigger trunk jobs on your pull request module: onnx Related to torch.onnx open source release notes: onnx torch.onnx related changes that should show up in the release notes topic: new features topic category triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module
Development

Successfully merging this pull request may close these issues.

[ONNX] Support float4
5 participants