@@ -13,7 +13,7 @@
     instantiate_device_type_tests,
     skipGPUIf,
 )
-from torch.testing._internal.common_utils import IS_LINUX, parametrize
+from torch.testing._internal.common_utils import decorateIf, IS_LINUX, parametrize
 from torch.testing._internal.inductor_utils import (
     GPU_TYPE,
     HAS_CUDA,
@@ -294,6 +294,96 @@ def fn(value, mask):
         expected = fn(*example_inputs)
         torch.testing.assert_close(actual, expected)
 
+    @dynamo_config.patch({"capture_scalar_outputs": True})
+    @decorateIf(unittest.expectedFailure, lambda params: params["dynamic"])
+    @parametrize("dynamic", [False, True, None])
+    def test_unbacked_slice_on_subclass(self, device, dynamic):
+        from torch.testing._internal.common_subclass import WrapperTensor
+        from torch.utils._pytree import tree_map
+
+        # NB: the error we're testing for only triggers when unbacked SymInts
+        # are created within a subclass's torch_dispatch, because they're not seen
+        # by Dynamo and thus are considered freshly-created when the subclass instance
+        # return value of the torch_dispatch is handled.
+        # Subclass forwards everything along to the single underlying dense tensor
+        # component, except for slice(), which it handles via data-dependent bounds access
+        class CustomSliceSubclass(WrapperTensor):
+            @classmethod
+            def get_wrapper_properties(cls, t, slice_bounds=None):
+                return t, {}
+
+            def __init__(self, t, slice_bounds=None):
+                self.t = t
+                self.slice_bounds = slice_bounds
+
+            def __repr__(self):
+                t_repr = repr(self.t)
+                slice_bounds_repr = repr(self.slice_bounds)
+                return f"CustomSliceSubclass({t_repr}, {slice_bounds_repr})"
+
+            def __tensor_flatten__(self):
+                return ["t", "slice_bounds"], None
+
+            @classmethod
+            def __tensor_unflatten__(
+                cls, inner_tensors, meta, outer_size, outer_stride
+            ):
+                t = inner_tensors["t"]
+                slice_bounds = inner_tensors["slice_bounds"]
+                return cls(t, slice_bounds)
+
+            @classmethod
+            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
+                if func is torch.ops.aten.slice.Tensor:
+                    inp = args[0]
+
+                    start = inp.slice_bounds[0].item()
+                    torch._check_is_size(start)
+                    torch._check(start <= inp.size(0))
+
+                    length = (args[0].slice_bounds[1] - args[0].slice_bounds[0]).item()
+                    torch._check_is_size(length)
+                    torch._check(start + length <= inp.size(0))
+
+                    return CustomSliceSubclass(
+                        func(args[0].t, dim=0, start=start, end=(start + length)),
+                        slice_bounds=args[0].slice_bounds,
+                    )
+
+                if not all(issubclass(cls, t) for t in types):
+                    return NotImplemented
+
+                if kwargs is None:
+                    kwargs = {}
+
+                def unwrap(e):
+                    return e.t if isinstance(e, CustomSliceSubclass) else e
+
+                def wrap(e):
+                    return CustomSliceSubclass(e) if isinstance(e, torch.Tensor) else e
+
+                rs = tree_map(
+                    wrap,
+                    func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {})),
+                )
+                return rs
+
+        def fn(t, start, length):
+            return torch.ops.aten.slice.Tensor(
+                t, dim=0, start=start, end=start + length
+            )
+
+        t = make_tensor(22, 5, dtype=torch.float32, device=device)
+        sub = CustomSliceSubclass(t, slice_bounds=torch.tensor([2, 5], device=t.device))
+        start = 2
+        length = 3
+        ragged_idx = 1
+        example_inputs = (sub, start, length)
+
+        actual = torch.compile(fn, dynamic=dynamic, fullgraph=True)(*example_inputs)
+        expected = fn(*example_inputs)
+        torch.testing.assert_close(actual.t, expected.t)
+
 
 instantiate_device_type_tests(TestUnbackedSymints, globals(), allow_xpu=True)
 
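The pattern the new test exercises can be reduced to a minimal plain-tensor sketch, shown below. It is not part of the diff: the helper name data_dependent_slice and the example bounds are illustrative, and it assumes a PyTorch build where torch.compile can handle slicing with unbacked SymInt bounds. Slice bounds are read out of a tensor with .item(), which produces unbacked SymInts once capture_scalar_outputs is enabled, and torch._check_is_size / torch._check supply the range facts the compiler needs to reason about the data-dependent slice.

import torch
import torch._dynamo

# Illustrative sketch only (not from the PR): allow .item() values to be captured
# as unbacked SymInts instead of causing a graph break.
torch._dynamo.config.capture_scalar_outputs = True

def data_dependent_slice(t, bounds):
    start = bounds[0].item()
    torch._check_is_size(start)                 # start >= 0
    torch._check(start <= t.size(0))            # start within the sliced dim

    length = (bounds[1] - bounds[0]).item()
    torch._check_is_size(length)                # length >= 0
    torch._check(start + length <= t.size(0))   # slice stays in bounds

    return torch.ops.aten.slice.Tensor(t, dim=0, start=start, end=start + length)

t = torch.randn(22, 5)
bounds = torch.tensor([2, 5])
compiled = torch.compile(data_dependent_slice, fullgraph=True)
torch.testing.assert_close(compiled(t, bounds), data_dependent_slice(t, bounds))

The point of the test's subclass variant is that these .item() calls happen inside __torch_dispatch__, where Dynamo does not see them, so the resulting unbacked SymInts look freshly created when the subclass return value is processed.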