[dynamic shapes] guard_or_false for _reshape_view_helper, utils._infer_size for wildcard dims by pianpwk · Pull Request #150127 · pytorch/pytorch · GitHub

[dynamic shapes] guard_or_false for _reshape_view_helper, utils._infer_size for wildcard dims #150127


Closed
pianpwk wants to merge 35 commits

Conversation

Contributor
@pianpwk pianpwk commented Mar 27, 2025

For reshape/view: removes the fast paths for 0 elements and the checks for dimensions to skip. Modifies the loop accumulating input elements so that running out of dimensions raises a UserError, graph-breaking for compile and erroring out for export.
For infer_size: assumes that if the user passes us an unbacked symbol, it's probably not -1.

Will think about changes in https://docs.google.com/document/d/1WYx6EZwVDXtBnWyrzoecgGWdiK0V3XZKftfpWwQ5i3E/edit?tab=t.0#heading=h.22k54zym11qp in a later PR
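The guard_or_false idiom in the title can be sketched in plain Python. This is a toy stand-in, not the real helper from torch.fx.experimental.symbolic_shapes; the `Unbacked` and `statically_known` names below are hypothetical. The idea: where a normal guard would raise a data-dependent error on an unbacked symbol, guard_or_false falls back to False, so the reshape/view helper simply skips the fast path instead of erroring.

```python
# Toy model of the guard_or_false idiom (simplified stand-in, not the
# real torch.fx.experimental.symbolic_shapes implementation).

class Unbacked:
    """Stands in for an unbacked SymInt: its concrete value is unknown."""
    pass

def statically_known(expr):
    # True/False if the expression can be decided without runtime data,
    # None if it depends on an unbacked symbol.
    if isinstance(expr, Unbacked):
        return None
    return bool(expr)

def guard_or_false(expr):
    """Return the statically known truth value, or False when undecidable.

    Instead of raising a data-dependent guard error, callers assume the
    fast-path condition is False and fall through to the general path."""
    known = statically_known(expr)
    return known if known is not None else False

# A backed (known) condition behaves normally:
assert guard_or_false(3 == 3) is True
# An unbacked condition no longer errors; it just skips the fast path:
assert guard_or_false(Unbacked()) is False
```

The trade-off: a fast path that would have applied at runtime may be skipped, but compilation no longer fails on conditions that cannot be decided symbolically.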

pytorch-bot bot commented Mar 27, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/150127

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 2bca726 with merge base bc6c0bc:

BROKEN TRUNK - The following job failed but was already failing on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

```python
ep = export(f, (torch.tensor(6),))
ep.module()(torch.tensor(6))
with self.assertRaisesRegex(RuntimeError, r"Runtime assertion failed for .* u0 .* 6"):
    ep.module()(torch.tensor(5))
```
Contributor Author
@pianpwk pianpwk Mar 27, 2025


This exports fine now without check_is_size; we just assume it's >= 0, and it later specializes to 6.

```python
y = torch.ones(5, 5)
ep = export(Foo(), (x, y))
ep.module()(x, y)
ep.module()(torch.tensor(0), y)
```
Contributor Author


Tests that we can reshape [5+u0, 5] -> [5, u0 + 1] with a wildcard dim.
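As a rough sketch of what wildcard-dim inference does (a simplified stand-in for the real utils._infer_size; the function name and assertions below are illustrative, not PyTorch's actual code), the -1 dimension is solved from the total element count:

```python
from math import prod

def infer_size(shape, numel):
    """Resolve a single -1 wildcard in `shape` so that prod(shape) == numel.

    Simplified sketch of wildcard inference; the real implementation
    lives in torch's prims/refs utilities and also handles symbolic ints."""
    wildcard = [i for i, d in enumerate(shape) if d == -1]
    if not wildcard:
        # No wildcard: the shape must already account for every element.
        assert prod(shape) == numel, "shape is incompatible with numel"
        return list(shape)
    assert len(wildcard) == 1, "only one dimension can be inferred"
    known = prod(d for d in shape if d != -1)
    assert known != 0 and numel % known == 0, "cannot infer wildcard dim"
    out = list(shape)
    out[wildcard[0]] = numel // known
    return out

# Reshaping a (5+u0, 5) tensor to (5, -1): with u0 == 0 the input has
# 25 elements, so the wildcard resolves to 5.
assert infer_size((5, -1), 25) == [5, 5]
```

The PR's change to this logic is the assumption mentioned above: an unbacked symbol passed by the user is taken to be a real (non-negative) dimension rather than a possible -1 wildcard.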

```python
    error_type,
    "Could not reshape a tensor with shape .*u0, u1.* as a tensor with shape .*u2, u3.*",
):
    export(Foo(), (xs,))
```
Contributor Author
@pianpwk pianpwk Mar 27, 2025


Very similar to the M_v0 test I deleted above, just (u0, u1) -> (u2, u3).

Adding the checks below works, and we specialize to the type of reshape that happens. Sadly, more complicated relations are hard to torch._check into success (e.g. if the tensor was size [4, 6, 8, 3]). I would like to blame the symbolic shapes axioms for that...

```python
):
    export(N(), (t,), strict=strict)
    export(N(), (t,))
```
Contributor Author
@pianpwk pianpwk Mar 27, 2025


Here we're trying to reshape (u0, u1) -> (u0, u2). I thought the torch._check added in _reshape_view_helper, asserting u0*u1 == u0*u2, would help this succeed, but it seems symbolic shapes isn't strong enough for this (it can't figure out that u1 == u2).

M_v0, M_v1, and M_v2 all end up with the same error here.
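For intuition on why a solver can't derive u1 == u2 from u0*u1 == u0*u2: if u0 may be 0, the product equality carries no information about u1 and u2, so u0 cannot be cancelled. A tiny illustration with concrete values (hypothetical helper, not PyTorch code):

```python
def implies_equal(u0, u1, u2):
    """Given that u0*u1 == u0*u2 holds, does u1 == u2 follow?

    Returns None when the premise doesn't even hold; otherwise reports
    whether the conclusion is true for these concrete values."""
    if u0 * u1 != u0 * u2:
        return None  # premise false, nothing to conclude
    return u1 == u2

# With u0 != 0 the cancellation is valid:
assert implies_equal(3, 4, 4) is True
# With u0 == 0 the premise holds (0 == 0) yet the dims differ, so the
# symbolic solver is right not to conclude u1 == u2 in general:
assert implies_equal(0, 4, 7) is False
```

Unless u0 is known to be nonzero (e.g. via an explicit torch._check(u0 != 0) or a >= 1 bound), refusing to simplify here is sound behavior, not a solver bug.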

```python
torch._check(items[2] != -1)
# Could not guard on data-dependent expression u2 >= 0
torch._check(items[2] >= 0)
# Could not guard on data-dependent expression Eq(u1, u2)
torch._check(items[2] == r.shape[1])
```
Contributor Author
@pianpwk pianpwk Mar 27, 2025


Now all you need is this check; the others are irrelevant.

```python
xs = torch.tensor([4, 4])
ep = export(Foov3(), (xs,))
ep.module()(xs)
ep.module()(torch.tensor([5, 5]))
```
Contributor Author


This also works, because there's only one way to do this, and via the added torch._check(a.numel() == shape_numel) we assume in good faith that the op is valid.

Maybe we shouldn't allow this, though; it could be a typo from the user, and downstream this check could lead to complications that are hard to debug.
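The good-faith assumption described above boils down to a single runtime check that element counts match, instead of proving per-dimension equalities between unbacked symbols. A minimal sketch (hypothetical helper name, not the actual torch code):

```python
from math import prod

def check_reshape_valid(input_shape, target_shape):
    """Sketch of the torch._check(a.numel() == shape_numel) idea:
    rather than proving per-dimension relations between unbacked
    symbols, assume the reshape is valid iff the total element counts
    agree, and defer that single equality to a runtime assert."""
    assert prod(input_shape) == prod(target_shape), (
        f"cannot reshape {input_shape} into {target_shape}"
    )

check_reshape_valid((2, 8), (4, 4))   # 16 == 16: passes silently
try:
    check_reshape_valid((2, 8), (5, 5))
except AssertionError:
    pass  # 16 != 25: caught only at runtime, as discussed above
```

This is exactly the debuggability concern raised in the comment: a typo in the target shape is no longer rejected at export time, only when the runtime assert fires.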

@pianpwk pianpwk changed the title [dynamic shapes] guard_or_false for _reshape_view_helper [dynamic shapes] guard_or_false for _reshape_view_helper, utils._infer_size for wildcard dims Mar 27, 2025
@pianpwk pianpwk marked this pull request as ready for review March 28, 2025 00:06
@malfet malfet dismissed laithsakka’s stale review April 18, 2025 22:31

This change broke a number of export tests; can we please be a bit more careful with the reviews?

@laithsakka
Contributor

The failure seems related, but the fix seems easy. Just verify it's expected:

```
- Eq(9380*u1, 0)
?    ^^^^
+ Eq(2*u1, 10)
```

@pianpwk
Contributor Author
pianpwk commented Apr 22, 2025

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team


@malfet
Contributor
malfet commented Apr 22, 2025

@pytorchbot revert -m "Caused TestDynamoTimed.test_dynamo_timed to fail on macOS, see https://github.com/pytorch/pytorch/actions/runs/14584536979/job/40908019050" -c nosignal

@pytorchmergebot
Collaborator

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

pytorchmergebot added a commit that referenced this pull request Apr 22, 2025
…ls._infer_size for wildcard dims (#150127)"

This reverts commit a02eae8.

Reverted #150127 on behalf of https://github.com/malfet due to Caused TestDynamoTimed.test_dynamo_timed to fail on macOS, see https://github.com/pytorch/pytorch/actions/runs/14584536979/job/40908019050 ([comment](#150127 (comment)))
@pytorchmergebot
Collaborator

@pianpwk your PR has been successfully reverted.

@pianpwk
Contributor Author
pianpwk commented Apr 22, 2025

@pytorchbot rebase

@pytorchmergebot
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot
Collaborator

Rebase failed: `git -C /home/runner/work/pytorch/pytorch rebase refs/remotes/origin/viable/strict pull/150127/head` returned non-zero exit code 1

```
Rebasing (1/28)
Rebasing (2/28)
Rebasing (3/28)
Rebasing (4/28)
Auto-merging test/export/test_export.py
CONFLICT (content): Merge conflict in test/export/test_export.py
Auto-merging torch/_refs/__init__.py
error: could not apply ed8ed350d5d... assume numel = prod(shape)
hint: Resolve all conflicts manually, mark them as resolved with
hint: "git add/rm <conflicted_files>", then run "git rebase --continue".
hint: You can instead skip this commit: run "git rebase --skip".
hint: To abort and get back to the state before "git rebase", run "git rebase --abort".
hint: Disable this message with "git config set advice.mergeConflict false"
Could not apply ed8ed350d5d... assume numel = prod(shape)
```

Raised by https://github.com/pytorch/pytorch/actions/runs/14600107590

@pianpwk
Contributor Author
pianpwk commented Apr 23, 2025

Will land; CI looks good, and I was told by @masnesral that the previous failure was unrelated (a flaky test).

@pianpwk
Contributor Author
pianpwk commented Apr 23, 2025

@pytorchmergebot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team


Labels

ci-no-td (Do not run TD on this PR) · ciflow/trunk (Trigger trunk jobs on your pull request) · Merged · release notes: export · Reverted