Unbacked SymInt fixes for slice() on subclasses · pytorch/pytorch@9ba5171 · GitHub
[go: up one dir, main page]

Skip to content

Commit 9ba5171

Browse files
committed
Unbacked SymInt fixes for slice() on subclasses
ghstack-source-id: bd3275d
Pull Request resolved: #142062
1 parent 9012e7a commit 9ba5171

File tree

3 files changed

+132
-6
lines changed

3 files changed

+132
-6
lines changed

test/inductor/test_unbacked_symints.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,95 @@ def fn(value, mask):
294294
expected = fn(*example_inputs)
295295
torch.testing.assert_close(actual, expected)
296296

297+
@dynamo_config.patch({"capture_scalar_outputs": True})
298+
@parametrize("dynamic", [False, True, None])
299+
def test_unbacked_slice_on_subclass(self, device, dynamic):
300+
from torch.testing._internal.common_subclass import WrapperTensor
301+
from torch.utils._pytree import tree_map
302+
303+
# NB: the error we're testing for only triggers when unbacked SymInts
304+
# are created within a subclass's torch_dispatch, because they're not seen
305+
# by Dynamo and thus are considered freshly-created when the subclass instance
306+
# return value of the torch_dispatch is handled.
307+
# Subclass forwards everything along to the single underlying dense tensor
308+
# component, except for slice(), which it handles via data-dependent bounds access
309+
class CustomSliceSubclass(WrapperTensor):
310+
@classmethod
311+
def get_wrapper_properties(cls, t, slice_bounds=None):
312+
return t, {}
313+
314+
def __init__(self, t, slice_bounds=None):
315+
self.t = t
316+
self.slice_bounds = slice_bounds
317+
318+
def __repr__(self):
319+
t_repr = repr(self.t)
320+
slice_bounds_repr = repr(self.slice_bounds)
321+
return f"CustomSliceSubclass({t_repr}, {slice_bounds_repr})"
322+
323+
def __tensor_flatten__(self):
324+
return ["t", "slice_bounds"], None
325+
326+
@classmethod
327+
def __tensor_unflatten__(
328+
cls, inner_tensors, meta, outer_size, outer_stride
329+
):
330+
t = inner_tensors["t"]
331+
slice_bounds = inner_tensors["slice_bounds"]
332+
return cls(t, slice_bounds)
333+
334+
@classmethod
335+
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
336+
if func is torch.ops.aten.slice.Tensor:
337+
inp = args[0]
338+
339+
start = inp.slice_bounds[0].item()
340+
torch._check_is_size(start)
341+
torch._check(start <= inp.size(0))
342+
343+
length = (args[0].slice_bounds[1] - args[0].slice_bounds[0]).item()
344+
torch._check_is_size(length)
345+
torch._check(start + length <= inp.size(0))
346+
347+
return CustomSliceSubclass(
348+
func(args[0].t, dim=0, start=start, end=(start + length)),
349+
slice_bounds=args[0].slice_bounds,
350+
)
351+
352+
if not all(issubclass(cls, t) for t in types):
353+
return NotImplemented
354+
355+
if kwargs is None:
356+
kwargs = {}
357+
358+
def unwrap(e):
359+
return e.t if isinstance(e, CustomSliceSubclass) else e
360+
361+
def wrap(e):
362+
return CustomSliceSubclass(e) if isinstance(e, torch.Tensor) else e
363+
364+
rs = tree_map(
365+
wrap,
366+
func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {})),
367+
)
368+
return rs
369+
370+
def fn(t, start, length):
371+
return torch.ops.aten.slice.Tensor(
372+
t, dim=0, start=start, end=start + length
373+
)
374+
375+
t = make_tensor(22, 5, dtype=torch.float32, device=device)
376+
sub = CustomSliceSubclass(t, slice_bounds=torch.tensor([2, 5], device=t.device))
377+
start = 2
378+
length = 3
379+
ragged_idx = 1
380+
example_inputs = (sub, start, length)
381+
382+
actual = torch.compile(fn, dynamic=dynamic, fullgraph=True)(*example_inputs)
383+
expected = fn(*example_inputs)
384+
torch.testing.assert_close(actual.t, expected.t)
385+
297386

298387
instantiate_device_type_tests(TestUnbackedSymints, globals(), allow_xpu=True)
299388

torch/_inductor/ir.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3009,14 +3009,22 @@ def normalize_start_end(cls, x, dim, start, end): # type: ignore[no-untyped-def
30093009
dim_size = x.get_size()[dim]
30103010

30113011
if any(free_unbacked_symbols(x) for x in (start, end, dim_size)):
3012-
3013-
def clamp(x, lower, upper): # type: ignore[no-untyped-def]
3014-
return sympy.Min(sympy.Max(x, lower), upper)
3015-
3012+
min_func = sympy.Min
3013+
max_func = sympy.Max
30163014
else:
3015+
min_func = sizevars.evaluate_min
3016+
max_func = sizevars.evaluate_max
30173017

3018-
def clamp(x, lower, upper): # type: ignore[no-untyped-def]
3019-
return sizevars.evaluate_min(sizevars.evaluate_max(x, lower), upper)
3018+
def clamp(x, lower, upper): # type: ignore[no-untyped-def]
3019+
clamped_lower = (
3020+
x if sizevars.statically_known_geq(x, lower) else max_func(x, lower)
3021+
)
3022+
clamped_full = (
3023+
clamped_lower
3024+
if sizevars.statically_known_leq(clamped_lower, upper)
3025+
else min_func(clamped_lower, upper)
3026+
)
3027+
return clamped_full
30203028

30213029
def clamp_wrap(val, lower, upper, default): # type: ignore[no-untyped-def]
30223030
if val is None:

torch/fx/experimental/symbolic_shapes.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1027,6 +1027,35 @@ def free_unbacked_symbols_with_path(
10271027
assert isinstance(real, int)
10281028
shape_env.set_unbacked_var_to_val(rhs, real // int(lhs))
10291029
pending.remove(rhs)
1030+
# as previous, but for unbacked SymInt * backed SymInt e.g. s1*u0
1031+
elif (
1032+
isinstance(a, torch.SymInt)
1033+
and isinstance(s := a.node._expr, sympy.Mul)
1034+
and len(s.args) == 2
1035+
and isinstance(lhs := s.args[0], sympy.Symbol)
1036+
and isinstance(rhs := s.args[1], sympy.Symbol)
1037+
and ((rhs in pending) ^ (lhs in pending))
1038+
and (
1039+
(rhs in a.node.shape_env.var_to_val)
1040+
^ (lhs in a.node.shape_env.var_to_val)
1041+
)
1042+
):
1043+
unbacked, backed = (lhs, rhs) if lhs in pending else (rhs, lhs)
1044+
# NB: We need a SymInt to pass to DivideByKey.
1045+
# TODO: Is it really necessary to construct the SymInt here or can we get it
1046+
# from somewhere else?
1047+
key = DivideByKey(
1048+
a.node.shape_env.create_symintnode(
1049+
backed,
1050+
hint=int(a.node.shape_env.var_to_val[backed]),
1051+
source=a.node.shape_env.var_to_sources.get(backed, [None])[0],
1052+
)
1053+
)
1054+
r[unbacked] = path + (key,)
1055+
if real is not None:
1056+
assert isinstance(real, int)
1057+
shape_env.set_unbacked_var_to_val(unbacked, CleanDiv(real, backed))
1058+
pending.remove(unbacked)
10301059
# The annoyance here arises from the fact that SymBool is
10311060
# allocated by allocating a SymInt and then testing if it's equal
10321061
# to one. So you have a complicated binding site logic for this.

0 commit comments

Comments
 (0)
0