Unbacked SymInt fixes for subclasses + data-dependent slice() bounds (#142062)
Related: #125914 (specifically see [comment](#125914 (comment)))
This PR addresses two failure modes involving the use of unbacked SymInts for calls to `slice()` with data-dependent bounds. These issues are encountered in practice for `narrow()` operating on the batch dim of an NJT input, but they apply to other subclasses as well; the test in this PR uses a purpose-built subclass.
Which of the two errors you hit depends on whether `torch.compile()` is called with `dynamic=True`. In practice, they only occur when the unbacked SymInts are created within the `__torch_dispatch__` implementation of a subclass, because the unbacked symbols are considered "freshly created" when the output subclass instance is handled in Dynamo.
**Error 1 (dynamic=False):**
```
LoweringException: GuardOnDataDependentSymNode: Could not guard on data-dependent expression Eq(-Min(22, Max(0, u0)) + Min(22, Max(u0 + u1, Max(0, u0))), 0) (unhinted: Eq(-Min(s0, Max(0, u0)) + Min(s0, Max(u0 + u1, Max(0, u0))), 0)). (Size-like symbols: u1, u0)
```
The expression comes from the use of `clamp()` logic for `SliceView` in Inductor:
https://github.com/pytorch/pytorch/blob/41e59754b407533b060b874c22ca4feda38bd83a/torch/_inductor/ir.py#L3014
If the (start, end) bounds for the `slice()` are statically known to be in range for the given dim (e.g. established via `torch._check()` calls), we can skip the `clamp()` logic and avoid the error; this PR implements that fix.
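For illustration, here is a rough sketch (not the PR's test, and not the NJT path itself) of how `torch._check()` calls can establish that data-dependent bounds are in range; in the subclass case the equivalent checks happen where the unbacked symbols are created, e.g. inside the subclass's `narrow()` handling. The function `f`, its arguments, and the config toggle are assumptions for the sketch:
```python
import torch

# The sketch relies on capturing .item() results as unbacked SymInts.
torch._dynamo.config.capture_scalar_outputs = True

@torch.compile(fullgraph=True)
def f(values, start_t, length_t):
    start = start_t.item()    # unbacked SymInt, e.g. u0
    length = length_t.item()  # unbacked SymInt, e.g. u1
    # Tell the symbolic shapes system that the bounds are in range for dim 0,
    # so Inductor's SliceView does not need to clamp() them (and therefore
    # does not need to guard on a data-dependent expression).
    torch._check(start >= 0)
    torch._check(length >= 0)
    torch._check(start + length <= values.size(0))
    return values.narrow(0, start, length)

out = f(torch.randn(22, 5), torch.tensor(3), torch.tensor(4))
```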
**Error 2 (dynamic=True):**
```
torch._dynamo.exc.InternalTorchDynamoError: PendingUnbackedSymbolNotFound: Pending unbacked symbols {u0} not in returned outputs NestedTensor(size=(2, s16, s1), offsets=FakeTensor(..., device='cuda:0', size=(3,), dtype=torch.int64), grad_fn=<NarrowBackwardAutogradNestedTensor0 object at 0x7f1f8603cfd0>, contiguous=True) ((s1*s16, s1, 1), s1*u0)
```
The storage offset of the values component of the returned NJT is `s1*u0`, where `s1` is known to be an integer. This PR expands the special-case logic for `constant * u0` to also handle SymInt coefficients:
https://github.com/pytorch/pytorch/blob/314e08eb52ad0e9b1c3eb6e149ec8a452e05b9c3/torch/fx/experimental/symbolic_shapes.py#L1013-L1031
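As a loose illustration of the extended matching (a minimal sketch in plain sympy, not the actual code in `symbolic_shapes.py`), the idea is that a pending unbacked symbol can still be recovered from an output expression like `s1*u0` when the coefficient is a symbol known to be integral rather than a literal constant:
```python
import sympy

# Hypothetical sketch, not PyTorch's implementation: s1 models a backed size
# symbol, u0 a pending unbacked symbol appearing in an output storage offset.
s1 = sympy.Symbol("s1", integer=True, positive=True)
u0 = sympy.Symbol("u0", integer=True, nonnegative=True)

def coeff_times_unbacked(expr, unbacked):
    """Return c if expr == c * unbacked with c known integral, else None."""
    coeff = sympy.simplify(expr / unbacked)
    # Previously only literal integer constants would be accepted here; the
    # fix also accepts symbolic coefficients that are known to be integers.
    return coeff if coeff.is_integer else None

print(coeff_times_unbacked(7 * u0, u0))   # 7   (previously handled)
print(coeff_times_unbacked(s1 * u0, u0))  # s1  (now handled too)
```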
Pull Request resolved: #142062
Approved by: https://github.com/ezyang
ghstack dependencies: #143526