[export] Refactor pt2 save/load by angelayi · Pull Request #152495 · pytorch/pytorch · GitHub
Open · angelayi wants to merge 1 commit into main from angelayi/export_save
Conversation

@angelayi (Contributor) commented Apr 30, 2025

Refactor pt2 archive saving to consolidate the formats used by torch.export.save and torch._inductor.package.package_aoti.

This PR adds the following functions, which torch.export.save and AOTI packaging call into:

package_pt2(
    f: FileLike,
    *,
    exported_programs: Optional[Union[ExportedProgram, dict[str, ExportedProgram]]] = None,
    aoti_files: Optional[Union[list[str], dict[str, list[str]]]] = None,
    extra_files: Optional[dict[str, Any]] = None,
) -> FileLike

@dataclass
class PT2ArchiveContents:
    exported_programs: dict[str, ExportedProgram]
    aoti_runners: dict[str, AOTICompiledModel]
    extra_files: dict[str, Any]

load_pt2(f: FileLike) -> PT2ArchiveContents

Power users can call these APIs directly if they want to bundle multiple exported programs, AOTI files, or extra metadata.
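Since package_pt2 accepts either a single ExportedProgram or a name-to-program dict, the writer presumably normalizes to a dict up front. A minimal pure-Python sketch of that normalization, with the assumption (not shown in this PR) that a bare program is keyed under a hypothetical default name "model":

```python
from typing import Any, Optional, Union


def normalize_exported_programs(
    exported_programs: Optional[Union[Any, dict[str, Any]]],
    default_name: str = "model",  # hypothetical default key, not from the PR
) -> dict[str, Any]:
    # package_pt2 accepts one program or a name->program dict; normalizing
    # to a dict up front keeps the archive-writing code path uniform.
    if exported_programs is None:
        return {}
    if isinstance(exported_programs, dict):
        return dict(exported_programs)
    return {default_name: exported_programs}


# Two bundled models end up as two named entries in the archive.
print(sorted(normalize_exported_programs({"model1": object(), "model2": object()})))
```

The same shape would apply to the aoti_files argument, which is likewise a Union of a flat list and a per-model dict.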

This is what the pt2 archive looks like (spec):

├── archive_format
├── version
├── .data
├── data
│   ├── aotinductor
│   │   └── model1
│   │       ├── model1.cpp
│   │       ├── model1.so  # currently AOTI automatically moves weights in here, TODO to move it out
│   │       ├── cg7domx3woam3nnliwud7yvtcencqctxkvvcafuriladwxw4nfiv.cubin
│   │       └── cubaaxppb6xmuqdm4bej55h2pftbce3bjyyvljxbtdfuolmv45ex.cubin
│   ├── weights
│   │   ├── model1.pt  # TODO to dedup weights between model1/model2
│   │   └── model2.pt
│   ├── constants
│   │   ├── model1.pt
│   │   └── model2.pt
│   └── sample_inputs
│       ├── model1.pt
│       └── model2.pt
├── extra
│   └── user_metadata.txt
└── models
    ├── model1.json
    └── model2.json
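The layout above can be exercised without torch at all. A minimal sketch, assuming the pt2 archive is a plain zip container (paths follow the spec tree above; the file contents here are placeholders, not the real serialized formats):

```python
import io
import zipfile


def sketch_package_pt2(buffer: io.BytesIO) -> None:
    # Writes the directory layout from the spec tree into a zip container.
    # The real package_pt2 serializes ExportedPrograms and AOTI artifacts;
    # placeholder bytes stand in for them here.
    with zipfile.ZipFile(buffer, "w") as zf:
        zf.writestr("archive_format", "pt2")
        zf.writestr("version", "0")
        zf.writestr("models/model1.json", "{}")
        zf.writestr("data/weights/model1.pt", b"")
        zf.writestr("extra/user_metadata.txt", "hello")


def sketch_load_pt2(buffer: io.BytesIO) -> dict:
    # Mirrors the extra_files field of PT2ArchiveContents: returns the
    # user-supplied extra files keyed by their name under extra/.
    with zipfile.ZipFile(buffer) as zf:
        assert zf.read("archive_format").decode() == "pt2"
        return {
            name.removeprefix("extra/"): zf.read(name).decode()
            for name in zf.namelist()
            if name.startswith("extra/")
        }


buf = io.BytesIO()
sketch_package_pt2(buf)
print(sketch_load_pt2(buf))
```

Round-tripping through an in-memory buffer like this is also how f: FileLike can be either a path or a file object in the proposed API.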

Future todos:

  • unbundle the weights -- instead of .pt, we can use bin files, which will also allow us to dedup weights if we store multiple models
  • update aoti_compile_and_package to also save the exported program
  • integrate TNR with this packaging flow

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov

pytorch-bot bot commented Apr 30, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/152495

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 7 Unrelated Failures

As of commit 98cdf80 with merge base ce317cd:

NEW FAILURE - The following job has failed:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@angelayi angelayi force-pushed the angelayi/export_save branch from 75ee02b to 98cdf80 Compare May 1, 2025 16:20
@angelayi angelayi marked this pull request as ready for review May 1, 2025 16:46
@facebook-github-bot (Contributor) commented:
@angelayi has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label May 1, 2025
@angelayi angelayi requested a review from SherlockNoMad May 1, 2025 22:25
@isaaccorley commented:

@angelayi this is perfect! Can't wait to use this

@desertfire (Contributor) left a comment:

CI test failure is real.

@@ -1876,7 +1875,7 @@ def _pad_to_alignment(raw_bytes: bytes) -> bytes:
         magic_number = 0
     else:
         magic_number = cast(
-            int, torch.randint(0, torch.iinfo(torch.int64).max, (1,)).item()
+            "int", torch.randint(0, torch.iinfo(torch.int64).max, (1,)).item()
Contributor:

why this change?

*,
expected_opset_version: Optional[dict[str, int]] = None,
run_single_threaded: bool = False,
num_runners: int = 1,
Contributor:

Need to pick up changes in #152093

@desertfire desertfire self-requested a review May 13, 2025 14:56