[ONNX] Automatically convert dynamic_axes to dynamic_shapes with torch.export.Dim.AUTO #143158

Conversation

titaiwangms
Collaborator
@titaiwangms titaiwangms commented Dec 13, 2024

With #133620 introducing Dim.AUTO, we can now automatically convert dynamic_axes to dynamic_shapes without specifying min and max. However, export can still crash when the same spec is shared between inputs, and there is no guarantee that the converted axes will actually be dynamic (see PR description).

Therefore, a follow-up PR should add a post-processing pass on the ONNX side to recover the missed dynamic axes and rename the dynamic shapes (s0, s1, ...) to the dynamic_axes names set by the user.

This PR does:
(1) Apply torch.export.Dim.AUTO to dynamic_axes when dynamic_shapes is not provided (see the sketch below).
(2) Convert args/kwargs to tuple inputs, following the generated dynamic_shapes format, to avoid errors during torch.export.export.
(3) Avoid a KeyError in the _rename_dynamic_shapes_with_model_inputs function.
(4) Add a real-world test case of an HF model with kv_cache to the ONNX exporter tests.
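As a rough illustration of item (1), here is a minimal sketch of the conversion; the helper name convert_dynamic_axes_to_dynamic_shapes is hypothetical and this is not the exact code in this PR:

# Minimal sketch, not the implementation in this PR: map every axis listed in a
# TorchScript-style dynamic_axes spec to torch.export.Dim.AUTO so torch.export
# can infer the constraints from the traced model.
import torch


def convert_dynamic_axes_to_dynamic_shapes(dynamic_axes):
    # dynamic_axes values are either a list of axis indices or an {axis: name}
    # dict; iterating either one yields the axis indices.
    dynamic_shapes = {}
    for input_name, axes in dynamic_axes.items():
        dynamic_shapes[input_name] = {axis: torch.export.Dim.AUTO for axis in axes}
    return dynamic_shapes


# Example with the usual TorchScript-exporter style specification:
dynamic_axes = {"input_ids": {0: "batch_size", 1: "sequence_length"}}
dynamic_shapes = convert_dynamic_axes_to_dynamic_shapes(dynamic_axes)
# -> {"input_ids": {0: Dim.AUTO, 1: Dim.AUTO}}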

pytorch-bot bot commented Dec 13, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/143158

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit df32d2f with merge base 8d4926e:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the release notes: onnx torch.onnx related changes that should show up in the release notes label Dec 13, 2024
@titaiwangms titaiwangms added the topic: improvements topic category label Dec 13, 2024
-dynamic_shapes[input_name] = {
-    k: torch.export.Dim(f"{input_name}_dim_{k}", max=99999) for k in axes
-}
+dynamic_shapes[input_name] = {k: torch.export.Dim.AUTO for k in axes}
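For context, a hedged usage sketch of what this change enables from the user side, assuming the dynamo-based export path (dynamo=True); the model and output file name are illustrative:

import torch


class TinyModel(torch.nn.Module):
    def forward(self, x):
        return x + 1


# dynamic_axes is the TorchScript-exporter style spec; with dynamo=True the
# exporter now maps axis 0 of "x" to torch.export.Dim.AUTO instead of a named
# Dim with max=99999.
onnx_program = torch.onnx.export(
    TinyModel(),
    (torch.randn(2, 3),),
    input_names=["x"],
    dynamic_axes={"x": {0: "batch_size"}},
    dynamo=True,
)
onnx_program.save("tiny_model.onnx")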


May be worth checking that k is a valid value (int in the expected range)?

Collaborator Author


Done

):
model = GPTJForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gptj")

input_ids = torch.tensor(


nit: are these specific values critical? Or can we replace them with random ints in a given range (which will take up less space and clarify things better)?

Collaborator Author


I was concerned about flakiness, but we will see.

# depends on the traced model to provide the correct min and max values.
# We set max to 99999 to avoid the constraints violation error with the default int64 max.
# https://github.com/pytorch/pytorch/blob/32f585d9346e316e554c8d9bf7548af9f62141fc/test/export/test_export.py#L687
# TODO(titaiwang): Add ONNX IR pass to post-process the dynamic axes
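For reference, a minimal sketch (model and names are illustrative) of the two dynamic_shapes styles the removed comment contrasts when calling torch.export.export directly:

import torch


class MatMul(torch.nn.Module):
    def forward(self, x, y):
        return x @ y


example_args = (torch.randn(2, 3), torch.randn(3, 5))

# Old approach: a named Dim with max=99999 to avoid the constraints violation
# error seen with the default int64 max, as the removed comment explains.
old_spec = {"x": {0: torch.export.Dim("x_dim_0", max=99999)}, "y": None}

# New approach: Dim.AUTO lets export infer the range from the traced model. Note
# it may also silently specialize an axis to static, which is why a
# post-processing ONNX pass is planned as a follow-up.
new_spec = {"x": {0: torch.export.Dim.AUTO}, "y": None}

exported = torch.export.export(MatMul(), example_args, dynamic_shapes=new_spec)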


Is it possible to include the information returned by AUTO dim analysis in the export-report? For example, if analysis indicates dim1 and dim2 are expected to be equal? Or, if it indicates dim1 is expected to be in range (5, 10)?

Collaborator Author


This can be a follow-up PR in which we try to identify the constraints and store them in the metadata of the ONNX IR.

@@ -144,8 +147,143 @@ def forward(self, *x):
onnx_program = torch.onnx.export(VisionModel(), args, dynamo=True)
onnx_testing.assert_onnx_program(onnx_program)

def test_onnx_export_huggingface_llm_models_with_kv_cache(self):
Collaborator


This file is meant to house end-to-end tests for "small" models whose definitions are just a few ops. An HF model should be placed elsewhere. Consider creating a file just for testing HF models to make change management easier.

Collaborator Author


Done

Collaborator Author


Created test_hf_models_e2e.py for LLM testing.

list[str],
]
):
model = transformers.GPTJForCausalLM.from_pretrained(
Collaborator


Do you know the size of it?


@titaiwangms
Collaborator Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Dec 16, 2024
@pytorchmergebot
Collaborator

Merge failed

Reason: Approvers from one of the following sets are needed:

  • superuser (pytorch/metamates)
  • Core Reviewers (mruberry, lezcano, Skylion007, ngimel, peterbell10, ...)
  • Core Maintainers (soumith, gchanan, ezyang, dzhulgakov, malfet, ...)
Details for Dev Infra team (raised by workflow job)

Failing merge rule: Core Maintainers

@titaiwangms
Collaborator Author

@xadupre Could you approve again to see if you have ownership now?

@ezyang ezyang added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Dec 17, 2024
@titaiwangms titaiwangms linked an issue Dec 17, 2024 that may be closed by this pull request
@titaiwangms
Collaborator Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge failed

Reason: Approvers from one of the following sets are needed:

  • superuser (pytorch/metamates)
  • Core Reviewers (mruberry, lezcano, Skylion007, ngimel, peterbell10, ...)
  • Core Maintainers (soumith, gchanan, ezyang, dzhulgakov, malfet, ...)
Details for Dev Infra team (raised by workflow job)

Failing merge rule: Core Maintainers

@titaiwangms
Collaborator Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.

Labels
ciflow/trunk Trigger trunk jobs on your pull request
Merged
open source
release notes: onnx torch.onnx related changes that should show up in the release notes
topic: improvements topic category
triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
Development

Successfully merging this pull request may close these issues.

[ONNX] _rename_dynamic_shapes_with_model_inputs is not robust
8 participants