Unbacked SymInt fixes for subclasses + data-dependent slice() bounds (non-dynamic) by jbschlosser · Pull Request #143526 · pytorch/pytorch · GitHub
Unbacked SymInt fixes for subclasses + data-dependent slice() bounds (non-dynamic) #143526


Closed
wants to merge 1 commit into from
92 changes: 91 additions & 1 deletion test/inductor/test_unbacked_symints.py
@@ -13,7 +13,7 @@
    instantiate_device_type_tests,
    skipGPUIf,
)
-from torch.testing._internal.common_utils import IS_LINUX, parametrize
+from torch.testing._internal.common_utils import decorateIf, IS_LINUX, parametrize
from torch.testing._internal.inductor_utils import (
    GPU_TYPE,
    HAS_CUDA,
@@ -294,6 +294,96 @@ def fn(value, mask):
        expected = fn(*example_inputs)
        torch.testing.assert_close(actual, expected)

    @dynamo_config.patch({"capture_scalar_outputs": True})
    @decorateIf(unittest.expectedFailure, lambda params: params["dynamic"])
    @parametrize("dynamic", [False, True, None])
    def test_unbacked_slice_on_subclass(self, device, dynamic):
        from torch.testing._internal.common_subclass import WrapperTensor
        from torch.utils._pytree import tree_map

        # NB: the error we're testing for only triggers when unbacked SymInts
        # are created within a subclass's torch_dispatch, because they're not seen
        # by Dynamo and thus are considered freshly-created when the subclass instance
        # return value of the torch_dispatch is handled.
        # Subclass forwards everything along to the single underlying dense tensor
        # component, except for slice(), which it handles via data-dependent bounds access
        class CustomSliceSubclass(WrapperTensor):
            @classmethod
            def get_wrapper_properties(cls, t, slice_bounds=None):
                return t, {}

            def __init__(self, t, slice_bounds=None):
                self.t = t
                self.slice_bounds = slice_bounds

            def __repr__(self):
                t_repr = repr(self.t)
                slice_bounds_repr = repr(self.slice_bounds)
                return f"CustomSliceSubclass({t_repr}, {slice_bounds_repr})"

            def __tensor_flatten__(self):
                return ["t", "slice_bounds"], None

            @classmethod
            def __tensor_unflatten__(
                cls, inner_tensors, meta, outer_size, outer_stride
            ):
                t = inner_tensors["t"]
                slice_bounds = inner_tensors["slice_bounds"]
                return cls(t, slice_bounds)

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                if func is torch.ops.aten.slice.Tensor:
                    inp = args[0]

                    start = inp.slice_bounds[0].item()
                    torch._check_is_size(start)
                    torch._check(start <= inp.size(0))

                    length = (args[0].slice_bounds[1] - args[0].slice_bounds[0]).item()
                    torch._check_is_size(length)
                    torch._check(start + length <= inp.size(0))

                    return CustomSliceSubclass(
                        func(args[0].t, dim=0, start=start, end=(start + length)),
                        slice_bounds=args[0].slice_bounds,
                    )

                if not all(issubclass(cls, t) for t in types):
                    return NotImplemented

                if kwargs is None:
                    kwargs = {}

                def unwrap(e):
                    return e.t if isinstance(e, CustomSliceSubclass) else e

                def wrap(e):
                    return CustomSliceSubclass(e) if isinstance(e, torch.Tensor) else e

                rs = tree_map(
                    wrap,
                    func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {})),
                )
                return rs

        def fn(t, start, length):
            return torch.ops.aten.slice.Tensor(
                t, dim=0, start=start, end=start + length
            )

        t = make_tensor(22, 5, dtype=torch.float32, device=device)
        sub = CustomSliceSubclass(t, slice_bounds=torch.tensor([2, 5], device=t.device))
        start = 2
        length = 3
        ragged_idx = 1
        example_inputs = (sub, start, length)

        actual = torch.compile(fn, dynamic=dynamic, fullgraph=True)(*example_inputs)
        expected = fn(*example_inputs)
        torch.testing.assert_close(actual.t, expected.t)


instantiate_device_type_tests(TestUnbackedSymints, globals(), allow_xpu=True)

20 changes: 14 additions & 6 deletions torch/_inductor/ir.py
@@ -2952,14 +2952,22 @@ def normalize_start_end(cls, x, dim, start, end): # type: ignore[no-untyped-def
        dim_size = x.get_size()[dim]

        if any(free_unbacked_symbols(x) for x in (start, end, dim_size)):
-
-            def clamp(x, lower, upper):  # type: ignore[no-untyped-def]
-                return sympy.Min(sympy.Max(x, lower), upper)
-
+            min_func = sympy.Min
+            max_func = sympy.Max
        else:
+            min_func = sizevars.evaluate_min
+            max_func = sizevars.evaluate_max

-            def clamp(x, lower, upper):  # type: ignore[no-untyped-def]
-                return sizevars.evaluate_min(sizevars.evaluate_max(x, lower), upper)
+        def clamp(x, lower, upper):  # type: ignore[no-untyped-def]
+            clamped_lower = (
+                x if sizevars.statically_known_geq(x, lower) else max_func(x, lower)
+            )
+            clamped_full = (
+                clamped_lower
+                if sizevars.statically_known_leq(clamped_lower, upper)
+                else min_func(clamped_lower, upper)
+            )
+            return clamped_full

        def clamp_wrap(val, lower, upper, default):  # type: ignore[no-untyped-def]
            if val is None:
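
For context on the ir.py hunk above: when start, end, or dim_size contain unbacked SymInts there are no hints to evaluate against, so the clamp stays symbolic via sympy.Min/sympy.Max instead of sizevars.evaluate_min/evaluate_max, and the statically_known_geq/statically_known_leq queries avoid building Min/Max expressions at all when a bound is already provable. Below is a minimal standalone sketch of that clamp pattern; it is not the PR's code, and prove_geq/prove_leq are illustrative stand-ins for Inductor's sizevars queries.

import sympy

# Illustrative stand-ins for sizevars.statically_known_geq/statically_known_leq:
# return True only when the relation is provable without guarding on data.
def prove_geq(x, lower):
    return sympy.simplify(sympy.sympify(x) - lower).is_nonnegative is True

def prove_leq(x, upper):
    return sympy.simplify(sympy.sympify(upper) - x).is_nonnegative is True

def clamp(x, lower, upper, min_func=sympy.Min, max_func=sympy.Max):
    # Only materialize Max/Min expressions when a bound cannot be proven
    # statically; this keeps expressions over unbacked symbols simple.
    clamped_lower = x if prove_geq(x, lower) else max_func(x, lower)
    return (
        clamped_lower
        if prove_leq(clamped_lower, upper)
        else min_func(clamped_lower, upper)
    )

u0 = sympy.Symbol("u0", integer=True, nonnegative=True)  # unbacked-size-like symbol
print(clamp(u0, 0, 10))  # lower bound provable, upper is not -> Min(10, u0)
print(clamp(3, 0, 10))   # both bounds provable -> 3
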
7 changes: 6 additions & 1 deletion torch/fx/passes/runtime_assert.py
@@ -172,7 +172,12 @@ def _node_metadata_hook(
        node.args,
    )
    try:
-        node.meta[val_key] = node.target(*fake_args)  # type: ignore[operator]
+        target = node.target
+        if node.op == "call_method":
+            assert isinstance(node.target, str)
+            target = getattr(fake_args[0], node.target)
+            fake_args = fake_args[1:]
+        node.meta[val_key] = target(*fake_args)  # type: ignore[operator]
    except NotImplementedError:
        # This can happen when attempting to reify a symbol with an unsupported call_function node,
        # e.g. with NestedTensors + sym_size.int via match_symbol().
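
For the runtime_assert.py hunk: an FX call_method node stores the method name as a string in node.target, with the receiver as the first argument, so reifying its value needs a getattr lookup rather than calling node.target directly, which is what the added branch does. A minimal sketch of that dispatch convention follows; eval_fx_target is a hypothetical helper for illustration, not part of the PR.

import torch
import torch.fx

def eval_fx_target(node: torch.fx.Node, args):
    # call_method: node.target is a method name; args[0] is the receiver object.
    if node.op == "call_method":
        assert isinstance(node.target, str)
        return getattr(args[0], node.target)(*args[1:])
    # call_function: node.target is already a callable.
    return node.target(*args)

# Usage example: a traced graph containing both node kinds.
class M(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x).add(1)  # call_function (relu) + call_method (add)

gm = torch.fx.symbolic_trace(M())
for node in gm.graph.nodes:
    print(node.op, node.target)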