Unbacked SymInt fixes for subclasses + data-dependent slice() bounds (non-dynamic) · pytorch/pytorch@1494342 · GitHub
Commit 1494342
Unbacked SymInt fixes for subclasses + data-dependent slice() bounds (non-dynamic)
[ghstack-poisoned]
1 parent e885225 commit 1494342

3 files changed: +111 −8 lines changed

test/inductor/test_unbacked_symints.py

Lines changed: 91 additions & 1 deletion
@@ -13,7 +13,7 @@
     instantiate_device_type_tests,
     skipGPUIf,
 )
-from torch.testing._internal.common_utils import IS_LINUX, parametrize
+from torch.testing._internal.common_utils import decorateIf, IS_LINUX, parametrize
 from torch.testing._internal.inductor_utils import (
     GPU_TYPE,
     HAS_CUDA,
@@ -294,6 +294,96 @@ def fn(value, mask):
         expected = fn(*example_inputs)
         torch.testing.assert_close(actual, expected)
 
+    @dynamo_config.patch({"capture_scalar_outputs": True})
+    @decorateIf(unittest.expectedFailure, lambda params: params["dynamic"])
+    @parametrize("dynamic", [False, True, None])
+    def test_unbacked_slice_on_subclass(self, device, dynamic):
+        from torch.testing._internal.common_subclass import WrapperTensor
+        from torch.utils._pytree import tree_map
+
+        # NB: the error we're testing for only triggers when unbacked SymInts
+        # are created within a subclass's torch_dispatch, because they're not seen
+        # by Dynamo and thus are considered freshly-created when the subclass instance
+        # return value of the torch_dispatch is handled.
+        # Subclass forwards everything along to the single underlying dense tensor
+        # component, except for slice(), which it handles via data-dependent bounds access
+        class CustomSliceSubclass(WrapperTensor):
+            @classmethod
+            def get_wrapper_properties(cls, t, slice_bounds=None):
+                return t, {}
+
+            def __init__(self, t, slice_bounds=None):
+                self.t = t
+                self.slice_bounds = slice_bounds
+
+            def __repr__(self):
+                t_repr = repr(self.t)
+                slice_bounds_repr = repr(self.slice_bounds)
+                return f"CustomSliceSubclass({t_repr}, {slice_bounds_repr})"
+
+            def __tensor_flatten__(self):
+                return ["t", "slice_bounds"], None
+
+            @classmethod
+            def __tensor_unflatten__(
+                cls, inner_tensors, meta, outer_size, outer_stride
+            ):
+                t = inner_tensors["t"]
+                slice_bounds = inner_tensors["slice_bounds"]
+                return cls(t, slice_bounds)
+
+            @classmethod
+            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
+                if func is torch.ops.aten.slice.Tensor:
+                    inp = args[0]
+
+                    start = inp.slice_bounds[0].item()
+                    torch._check_is_size(start)
+                    torch._check(start <= inp.size(0))
+
+                    length = (args[0].slice_bounds[1] - args[0].slice_bounds[0]).item()
+                    torch._check_is_size(length)
+                    torch._check(start + length <= inp.size(0))
+
+                    return CustomSliceSubclass(
+                        func(args[0].t, dim=0, start=start, end=(start + length)),
+                        slice_bounds=args[0].slice_bounds,
+                    )
+
+                if not all(issubclass(cls, t) for t in types):
+                    return NotImplemented
+
+                if kwargs is None:
+                    kwargs = {}
+
+                def unwrap(e):
+                    return e.t if isinstance(e, CustomSliceSubclass) else e
+
+                def wrap(e):
+                    return CustomSliceSubclass(e) if isinstance(e, torch.Tensor) else e
+
+                rs = tree_map(
+                    wrap,
+                    func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {})),
+                )
+                return rs
+
+        def fn(t, start, length):
+            return torch.ops.aten.slice.Tensor(
+                t, dim=0, start=start, end=start + length
+            )
+
+        t = make_tensor(22, 5, dtype=torch.float32, device=device)
+        sub = CustomSliceSubclass(t, slice_bounds=torch.tensor([2, 5], device=t.device))
+        start = 2
+        length = 3
+        ragged_idx = 1
+        example_inputs = (sub, start, length)
+
+        actual = torch.compile(fn, dynamic=dynamic, fullgraph=True)(*example_inputs)
+        expected = fn(*example_inputs)
+        torch.testing.assert_close(actual.t, expected.t)
+
 
 instantiate_device_type_tests(TestUnbackedSymints, globals(), allow_xpu=True)
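Note: the test exercises the standard unbacked-SymInt pattern: .item() on a tensor produces an unbacked SymInt under capture_scalar_outputs, and torch._check_is_size / torch._check supply the bounds the compiler cannot deduce on its own. A minimal standalone sketch of that pattern without the subclass (the function name is illustrative, not from this commit):

import torch

# Allow .item() to produce unbacked SymInts instead of graph-breaking.
torch._dynamo.config.capture_scalar_outputs = True

@torch.compile(fullgraph=True, dynamic=False)
def data_dependent_slice(t, bounds):
    start = bounds[0].item()      # unbacked SymInt: value unknown at trace time
    torch._check_is_size(start)   # tells the compiler start >= 0
    length = (bounds[1] - bounds[0]).item()
    torch._check_is_size(length)
    torch._check(start + length <= t.size(0))  # upper bound needed by slice()
    return torch.ops.aten.slice.Tensor(t, dim=0, start=start, end=start + length)

t = torch.randn(22, 5)
out = data_dependent_slice(t, torch.tensor([2, 5]))
assert out.shape == (3, 5)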

torch/_inductor/ir.py

Lines changed: 14 additions & 6 deletions
@@ -2952,14 +2952,22 @@ def normalize_start_end(cls, x, dim, start, end):  # type: ignore[no-untyped-def]
         dim_size = x.get_size()[dim]
 
         if any(free_unbacked_symbols(x) for x in (start, end, dim_size)):
-
-            def clamp(x, lower, upper):  # type: ignore[no-untyped-def]
-                return sympy.Min(sympy.Max(x, lower), upper)
-
+            min_func = sympy.Min
+            max_func = sympy.Max
         else:
+            min_func = sizevars.evaluate_min
+            max_func = sizevars.evaluate_max
 
-            def clamp(x, lower, upper):  # type: ignore[no-untyped-def]
-                return sizevars.evaluate_min(sizevars.evaluate_max(x, lower), upper)
+        def clamp(x, lower, upper):  # type: ignore[no-untyped-def]
+            clamped_lower = (
+                x if sizevars.statically_known_geq(x, lower) else max_func(x, lower)
+            )
+            clamped_full = (
+                clamped_lower
+                if sizevars.statically_known_leq(clamped_lower, upper)
+                else min_func(clamped_lower, upper)
+            )
+            return clamped_full
 
         def clamp_wrap(val, lower, upper, default):  # type: ignore[no-untyped-def]
             if val is None:
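Note: the new clamp only falls back to a symbolic Min/Max when the comparison cannot be decided statically, so sizes stay free of needless Min/Max expressions. A self-contained sketch of the same short-circuit pattern (the statically_known_* helpers below are stand-ins for Inductor's sizevars queries, not its actual implementation):

import sympy

# Stand-ins for sizevars.statically_known_geq/leq: ask sympy whether the
# relation evaluates to a definite True under the symbol's assumptions.
def statically_known_geq(x, lower):
    return sympy.Ge(x, lower) is sympy.true

def statically_known_leq(x, upper):
    return sympy.Le(x, upper) is sympy.true

def clamp(x, lower, upper, min_func=sympy.Min, max_func=sympy.Max):
    # Skip Max when x >= lower is provable; skip Min when <= upper is provable.
    clamped_lower = x if statically_known_geq(x, lower) else max_func(x, lower)
    return (
        clamped_lower
        if statically_known_leq(clamped_lower, upper)
        else min_func(clamped_lower, upper)
    )

u = sympy.Symbol("u0", integer=True, nonnegative=True)
print(clamp(u, 0, 10))  # Min(10, u0): the lower clamp resolved statically
print(clamp(5, 0, 10))  # 5: both clamps resolved statically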

torch/fx/passes/runtime_assert.py

Lines changed: 6 additions & 1 deletion
@@ -172,7 +172,12 @@ def _node_metadata_hook(
             node.args,
         )
         try:
-            node.meta[val_key] = node.target(*fake_args)  # type: ignore[operator]
+            target = node.target
+            if node.op == "call_method":
+                assert isinstance(node.target, str)
+                target = getattr(fake_args[0], node.target)
+                fake_args = fake_args[1:]
+            node.meta[val_key] = target(*fake_args)  # type: ignore[operator]
         except NotImplementedError:
             # This can happen when attempting to reify a symbol with an unsupported call_function node,
             # e.g. with NestedTensors + sym_size.int via match_symbol().
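Note: an FX call_method node stores the method name as a string in node.target and passes the receiver as the first argument, which is why the hook must bind the method via getattr before calling it. A minimal demonstration (not the pass itself):

import torch
import torch.fx as fx

class M(torch.nn.Module):
    def forward(self, x):
        return x.size(0)  # tensor method call -> a call_method node

gm = fx.symbolic_trace(M())
for node in gm.graph.nodes:
    if node.op == "call_method":
        assert isinstance(node.target, str)  # 'size', not a callable
        # Mirror the hook: bind the method to the first (receiver) argument,
        # then call with the remaining args.
        x = torch.randn(3, 4)
        bound = getattr(x, node.target)
        print(bound(*node.args[1:]))  # 3, i.e. x.size(0)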
