[ONNX] Automatically convert dynamic_axes to dynamic_shapes with torch.export.Dim.AUTO #143158
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/143158
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit df32d2f with merge base 8d4926e.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
```diff
-dynamic_shapes[input_name] = {
-    k: torch.export.Dim(f"{input_name}_dim_{k}", max=99999) for k in axes
-}
+dynamic_shapes[input_name] = {k: torch.export.Dim.AUTO for k in axes}
```
May be worth checking that k is a valid value (int in the expected range)?
Done
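The check requested in review could be a small helper along these lines (the function name and error message are hypothetical, not the PR's code):

```python
def validate_axis(axis, rank):
    # Axis keys in dynamic_axes should be ints within the tensor's rank.
    if not isinstance(axis, int) or not 0 <= axis < rank:
        raise ValueError(f"Invalid axis {axis!r}: expected an int in [0, {rank})")
    return axis

ok_axis = validate_axis(1, 3)  # valid: rank-3 tensor has axes 0, 1, 2
```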
```python
):
    model = GPTJForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gptj")

    input_ids = torch.tensor(
```
nit: are these specific values critical? Or can we replace them with random ints in a given range (which will take up less space and clarify things better)?
I was concerned about flakiness, but we will see.
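The reviewer's suggestion might be sketched like this; the vocab size and the (batch, sequence) shape here are assumptions for illustration, not the test's actual values:

```python
import torch

# Replace a hard-coded token tensor with random ints in a fixed range.
# high is exclusive, so values fall in [0, vocab_size).
vocab_size = 1000
input_ids = torch.randint(low=0, high=vocab_size, size=(2, 8), dtype=torch.int64)
```

Seeding the generator (e.g. `torch.manual_seed`) would address the flakiness concern by making the random inputs reproducible across runs.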
```python
# depends on the traced model to provide the correct min and max values.
# We set max to 99999 to avoid the constraints violation error with the default int64 max.
# https://github.com/pytorch/pytorch/blob/32f585d9346e316e554c8d9bf7548af9f62141fc/test/export/test_export.py#L687
# TODO(titaiwang): Add ONNX IR pass to post-process the dynamic axes
```
Is it possible to include the information returned by AUTO dim analysis in the export-report? For example, if analysis indicates dim1 and dim2 are expected to be equal? Or, if it indicates dim1 is expected to be in range (5, 10)?
This can be a follow-up PR in which we try to identify the constraints and store them in the metadata of the ONNX IR.
```diff
@@ -144,8 +147,143 @@ def forward(self, *x):
         onnx_program = torch.onnx.export(VisionModel(), args, dynamo=True)
         onnx_testing.assert_onnx_program(onnx_program)

+    def test_onnx_export_huggingface_llm_models_with_kv_cache(self):
```
This file is meant to house end-to-end tests for "small" models whose definitions are just a few ops. A HF model should be placed elsewhere. Consider creating a file just for testing HF models to make change management easier.
Done
Created test_hf_models_e2e.py for llm testing.
```python
    list[str],
]
):
    model = transformers.GPTJForCausalLM.from_pretrained(
```
Do you know the size of it?
This is the model card. Can you tell? https://huggingface.co/hf-internal-testing/tiny-random-gptj/tree/main
@pytorchbot merge

Merge failed. Reason: Approvers from one of the following sets are needed:

@xadupre Could you approve again to see if you have the ownership now?

@pytorchbot merge

Merge failed. Reason: Approvers from one of the following sets are needed:

@pytorchbot merge

Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Align to the changes in #143158 Pull Request resolved: #144356 Approved by: https://github.com/justinchuby
With #133620 introducing Dim.AUTO, we can now automatically convert dynamic_axes to dynamic_shapes without specifying min and max. However, exporting could still crash when the same specs are shared between inputs, and there is no guarantee that the axes will be dynamic (see PR description).
Therefore, a follow-up PR should create a post-processing pass on the ONNX side to rename the dynamic shapes (s0, s1, ...) to dynamic_axes (user-specified names). This PR does:
(1) Apply torch.export.Dim.AUTO to dynamic_axes when dynamic_shapes is not provided.
(2) Convert args/kwargs to tuple inputs, which follows the generated dynamic_shapes format to avoid errors during torch.export.export.
(3) Avoid KeyError in _rename_dynamic_shapes_with_model_inputs function.
(4) Add real world case of a HF model with kv_cache to test on ONNX exporter.
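Point (2) above, folding kwargs into the positional tuple so the inputs line up with the tuple-form dynamic_shapes the exporter generates, might be sketched with a hypothetical helper like this (real exporter code must follow the model signature's parameter order; plain dict insertion order stands in for that here):

```python
def merge_args_kwargs(args, kwargs):
    # Append keyword-argument values after the positional args so the
    # flattened tuple matches a tuple-form dynamic_shapes specification.
    return tuple(args) + tuple(kwargs.values())

# Toy stand-ins for model inputs (values are illustrative).
flat = merge_args_kwargs((1, 2), {"attention_mask": 3})
```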