Unbacked SymInt fixes for subclasses + data-dependent slice() bounds (non-dynamic) #143526
Conversation
…(non-dynamic) [ghstack-poisoned]
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/143526
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 New Failure
As of commit 1494342 with merge base e885225: NEW FAILURE - the following job has failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@pytorchbot merge
Merge started
Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team
Merge failed
Reason: 1 mandatory check(s) failed. The first few are: Dig deeper by viewing the failures on hud
@pytorchbot merge -i
Merge started
Your change will be merged while ignoring the following 1 checks: pull / linux-focal-py3_9-clang9-xla / test (xla, 1, 1, linux.12xlarge). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team
…142062)

Related: #125914 (specifically see [comment](#125914 (comment)))

This PR addresses two issues involving the use of unbacked SymInts for calls to `slice()` with data-dependent bounds. These issues are encountered in practice for `narrow()` operating on the batch dim with an NJT input, but they apply to other subclasses as well; the test in this PR uses a purpose-built subclass.

There are two different issues here, depending on whether `torch.compile()` is called with `dynamic=True`. In practice, they only occur when the unbacked SymInts are created within the torch_dispatch implementation of a subclass, because the unbacked symbols are considered "freshly created" when the output subclass instance is handled in Dynamo.

**Error 1 (dynamic=False):**
```
LoweringException: GuardOnDataDependentSymNode: Could not guard on data-dependent expression Eq(-Min(22, Max(0, u0)) + Min(22, Max(u0 + u1, Max(0, u0))), 0) (unhinted: Eq(-Min(s0, Max(0, u0)) + Min(s0, Max(u0 + u1, Max(0, u0))), 0)). (Size-like symbols: u1, u0)
```

The expression comes from the `clamp()` logic for `SliceView` in Inductor:
https://github.com/pytorch/pytorch/blob/41e59754b407533b060b874c22ca4feda38bd83a/torch/_inductor/ir.py#L3014

If the (start, end) bounds for the `slice()` are statically known to be in range for the given dim (e.g. provided via `torch._check()` calls), we can avoid this `clamp()` logic and the error. This PR implements this fix.

**Error 2 (dynamic=True):**
```
torch._dynamo.exc.InternalTorchDynamoError: PendingUnbackedSymbolNotFound: Pending unbacked symbols {u0} not in returned outputs NestedTensor(size=(2, s16, s1), offsets=FakeTensor(..., device='cuda:0', size=(3,), dtype=torch.int64), grad_fn=<NarrowBackwardAutogradNestedTensor0 object at 0x7f1f8603cfd0>, contiguous=True) ((s1*s16, s1, 1), s1*u0)
```

The storage offset of the values component of the returned NJT is `s1*u0`, where `s1` is known to be an integer. This PR expands the special logic handling the `constant * u0` case to also handle SymInts:
https://github.com/pytorch/pytorch/blob/314e08eb52ad0e9b1c3eb6e149ec8a452e05b9c3/torch/fx/experimental/symbolic_shapes.py#L1013-L1031

Pull Request resolved: #142062
Approved by: https://github.com/ezyang
ghstack dependencies: #143526
Stack from ghstack (oldest at bottom):
Lifted non-controversial (non-dynamic) fixes from #142062. See description there for context.
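As a rough illustration of the non-dynamic scenario described under Error 1 above, here is a minimal sketch (not the PR's actual test, which uses a purpose-built subclass): data-dependent `narrow()` bounds come from `.item()` as unbacked SymInts, and `torch._check()` calls establish that those bounds are statically in range for the dim, which is the precondition under which the `SliceView` clamp can be skipped. The function name, shapes, and values are made up for illustration.

```python
import torch

# Needed so .item() produces unbacked SymInts instead of causing a graph break.
torch._dynamo.config.capture_scalar_outputs = True

@torch.compile(dynamic=False)
def narrow_with_checks(values, start_t, length_t):
    # start/length come from tensor data, so they become unbacked SymInts.
    start = start_t.item()
    length = length_t.item()
    # Assert the data-dependent bounds are in range for dim 0; with these facts
    # established, the slice bounds are statically known to be valid.
    torch._check(start >= 0)
    torch._check(start + length <= values.size(0))
    return values.narrow(0, start, length)

out = narrow_with_checks(torch.randn(22, 3), torch.tensor(4), torch.tensor(5))
```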
cc @ezyang @SherlockNoMad @EikanWang @jgong5 @wenzhe-nrv @voznesenskym @penguinwu @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @aakhundov