8000 Unbacked SymInt fixes for subclasses + data-dependent slice() bounds … · pytorch/pytorch@fc03c62 · GitHub
[go: up one dir, main page]

Skip to content

Commit fc03c62

Browse files
jbschlosserpytorchmergebot
authored andcommitted
Unbacked SymInt fixes for subclasses + data-dependent slice() bounds (#142062)
Related: #125914 (specifically see [comment](#125914 (comment))) This PR addresses two broken things involving the usage of unbacked SymInts for calls to `slice()` with data-dependent bounds. These issues are encountered in practice for `narrow()` operating on the batch dim with an NJT input, but apply to other subclasses as well. The test in this PR uses a purpose-built subclass. There are two different issues here, depending on whether `torch.compile()` is called with `dynamic=True`. In practice, these only occur when the unbacked SymInts are created within the torch_dispatch implementation of a subclass, because the unbacked symbols are considered "freshly created" when the output subclass instance is handled in Dynamo. **Error 1 (dynamic=False):** ``` LoweringException: GuardOnDataDependentSymNode: Could not guard on data-dependent expression Eq(-Min(22, Max(0, u0)) + Min(22, Max(u0 + u1, Max(0, u0))), 0) (unhinted: Eq(-Min(s0, Max(0, u0)) + Min(s0, Max(u0 + u1, Max(0, u0))), 0)). (Size-like symbols: u1, u0) ``` The expression comes from the use of `clamp()` logic for `SliceView` in Inductor: https://github.com/pytorch/pytorch/blob/41e59754b407533b060b874c22ca4feda38bd83a/torch/_inductor/ir.py#L3014 If the (start, end) bounds for the `slice()` are statically known to be in range for the given dim (e.g. provided via `torch._check()` calls), we can avoid this `clamp()` logic and the error. This PR implements this fix. **Error 2 (dynamic=True):** ``` torch._dynamo.exc.InternalTorchDynamoError: PendingUnbackedSymbolNotFound: Pending unbacked symbols {u0} not in returned outputs NestedTensor(size=(2, s16, s1), offsets=FakeTensor(..., device='cuda:0', size=(3,), dtype=torch.int64), grad_fn=<NarrowBackwardAutogradNestedTensor0 object at 0x7f1f8603cfd0>, contiguous=True) ((s1*s16, s1, 1), s1*u0) ``` The storage offset of the values component of the returned NJT is `s1*u0` where `s1` is known to be an integer. This PR expands the special logic handling the `constant * u0` case to handle SymInts as well: https://github.com/pytorch/pytorch/blob/314e08eb52ad0e9b1c3eb6e149ec8a452e05b9c3/torch/fx/experimental/symbolic_shapes.py#L1013-L1031 Pull Request resolved: #142062 Approved by: https://github.com/ezyang ghstack dependencies: #143526
1 parent 0b2c479 commit fc03c62

File tree

2 files changed

+30
-8
lines changed

2 files changed

+30
-8
lines changed

test/inductor/test_unbacked_symints.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
instantiate_device_type_tests,
1414
skipGPUIf,
1515
)
16-
from torch.testing._internal.common_utils import decorateIf, IS_LINUX, parametrize
16+
from torch.testing._internal.common_utils import IS_LINUX, parametrize
1717
from torch.testing._internal.inductor_utils import (
1818
GPU_TYPE,
1919
HAS_CUDA,
@@ -295,7 +295,6 @@ def fn(value, mask):
295295
torch.testing.assert_close(actual, expected)
296296

297297
@dynamo_config.patch({"capture_scalar_outputs": True})
298-
@decorateIf(unittest.expectedFailure, lambda params: params["dynamic"])
299298
@parametrize("dynamic", [False, True, None])
300299
def test_unbacked_slice_on_subclass(self, device, dynamic):
301300
from torch.testing._internal.common_subclass import WrapperTensor

torch/fx/experimental/symbolic_shapes.py

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -933,7 +933,7 @@ def get(self, o: Any) -> Any:
933933

934934
@dataclass(frozen=True)
935935
class DivideByKey:
936-
divisor: int
936+
divisor: Union[int, SymInt]
937937

938938
def __str__(self) -> str:
939939
return f".__floordiv__({self.divisor})"
@@ -1049,16 +1049,39 @@ def free_unbacked_symbols_with_path(
10491049
isinstance(a, torch.SymInt)
10501050
and isinstance(s := a.node._expr, sympy.Mul)
10511051
and len(s.args) == 2
1052-
and isinstance(lhs := s.args[0], sympy.Integer)
1052+
and isinstance(lhs := s.args[0], (sympy.Integer, sympy.Symbol))
10531053
and isinstance(rhs := s.args[1], sympy.Symbol)
1054-
and rhs in pending
1054+
# support exactly one unbacked for now
1055+
and ((rhs in pending) ^ (lhs in pending))
1056+
# support constant coefficient or backed symbolic coefficient
1057+
and (
1058+
isinstance(coeff := lhs if lhs not in pending else rhs, sympy.Integer)
1059+
or coeff in a.node.shape_env.var_to_val
1060+
)
10551061
):
1062+
1063+
def _symint_wrap(s: sympy.Symbol) -> SymInt:
1064+
return a.node.shape_env.create_symintnode(
1065+
s,
1066+
hint=int(a.node.shape_env.var_to_val[s]),
1067+
source=a.node.shape_env.var_to_sources.get(s, [None])[0],
1068+
)
1069+
1070+
unbacked = lhs if lhs in pending else rhs
1071+
divisor: Union[int, SymInt] = (
1072+
int(coeff) if isinstance(coeff, sympy.Integer) else _symint_wrap(coeff)
1073+
)
10561074
# TODO: DivideByKey needs to test divisibility at runtime!
1057-
r[rhs] = path + (DivideByKey(int(lhs)),)
1075+
r[unbacked] = path + (DivideByKey(divisor),)
10581076
if real is not None:
10591077
assert isinstance(real, int)
1060-
shape_env.set_unbacked_var_to_val(rhs, real // int(lhs))
1061-
pending.remove(rhs)
1078+
val = (
1079+
real // int(coeff)
1080+
if isinstance(coeff, sympy.Integer)
1081+
else CleanDiv(real, coeff)
1082+
)
1083+
shape_env.set_unbacked_var_to_val(unbacked, val)
1084+
pending.remove(unbacked)
10621085
# The annoyance here arises from the fact that SymBool is
10631086
# allocated by allocating a SymInt and then testing if it's equal
10641087
# to one. So you have a complicated binding site logic for this.

0 commit comments

Comments
 (0)
0