@@ -294,6 +294,95 @@ def fn(value, mask):
        expected = fn(*example_inputs)
        torch.testing.assert_close(actual, expected)

+    @dynamo_config.patch({"capture_scalar_outputs": True})
+    @parametrize("dynamic", [False, True, None])
+    def test_unbacked_slice_on_subclass(self, device, dynamic):
+        from torch.testing._internal.common_subclass import WrapperTensor
+        from torch.utils._pytree import tree_map
+
+        # NB: the error we're testing for only triggers when unbacked SymInts
+        # are created within a subclass's torch_dispatch, because they're not seen
+        # by Dynamo and thus are considered freshly-created when the subclass instance
+        # return value of the torch_dispatch is handled.
+        # Subclass forwards everything along to the single underlying dense tensor
+        # component, except for slice(), which it handles via data-dependent bounds access
+        class CustomSliceSubclass(WrapperTensor):
+            @classmethod
+            def get_wrapper_properties(cls, t, slice_bounds=None):
+                return t, {}
+
+            def __init__(self, t, slice_bounds=None):
+                self.t = t
+                self.slice_bounds = slice_bounds
+
+            def __repr__(self):
+                t_repr = repr(self.t)
+                slice_bounds_repr = repr(self.slice_bounds)
+                return f"CustomSliceSubclass({t_repr}, {slice_bounds_repr})"
+
+            def __tensor_flatten__(self):
+                return ["t", "slice_bounds"], None
+
+            @classmethod
+            def __tensor_unflatten__(
+                cls, inner_tensors, meta, outer_size, outer_stride
+            ):
+                t = inner_tensors["t"]
+                slice_bounds = inner_tensors["slice_bounds"]
+                return cls(t, slice_bounds)
+
+            @classmethod
+            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
+                if func is torch.ops.aten.slice.Tensor:
+                    inp = args[0]
+
+                    start = inp.slice_bounds[0].item()
+                    torch._check_is_size(start)
+                    torch._check(start <= inp.size(0))
+
+                    length = (args[0].slice_bounds[1] - args[0].slice_bounds[0]).item()
+                    torch._check_is_size(length)
+                    torch._check(start + length <= inp.size(0))
+
+                    return CustomSliceSubclass(
+                        func(args[0].t, dim=0, start=start, end=(start + length)),
+                        slice_bounds=args[0].slice_bounds,
+                    )
+
+                if not all(issubclass(cls, t) for t in types):
+                    return NotImplemented
+
+                if kwargs is None:
+                    kwargs = {}
+
+                def unwrap(e):
+                    return e.t if isinstance(e, CustomSliceSubclass) else e
+
+                def wrap(e):
+                    return CustomSliceSubclass(e) if isinstance(e, torch.Tensor) else e
+
+                rs = tree_map(
+                    wrap,
+                    func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {})),
+                )
+                return rs
+
+        def fn(t, start, length):
+            return torch.ops.aten.slice.Tensor(
+                t, dim=0, start=start, end=start + length
+            )
+
+        t = make_tensor(22, 5, dtype=torch.float32, device=device)
+        sub = CustomSliceSubclass(t, slice_bounds=torch.tensor([2, 5], device=t.device))
+        start = 2
+        length = 3
+        ragged_idx = 1
+        example_inputs = (sub, start, length)
+
+        actual = torch.compile(fn, dynamic=dynamic, fullgraph=True)(*example_inputs)
+        expected = fn(*example_inputs)
+        torch.testing.assert_close(actual.t, expected.t)
+


instantiate_device_type_tests(TestUnbackedSymints, globals(), allow_xpu=True)
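
For context (this note and sketch are not part of the diff above): the test relies on the fact that, under torch.compile with capture_scalar_outputs enabled, calling .item() inside the traced region produces an unbacked SymInt, and torch._check_is_size / torch._check are how the code supplies bounds the compiler cannot infer on its own. A minimal standalone sketch of that pattern, using hypothetical names (fn, sizes), not the subclass machinery from the diff:

import torch

# Allow .item() to be captured as an unbacked SymInt instead of causing a graph break.
torch._dynamo.config.capture_scalar_outputs = True


@torch.compile(fullgraph=True)
def fn(sizes):
    # Under compile, .item() yields an unbacked SymInt: its value is unknown at trace time.
    n = sizes[0].item()
    # Without this hint that n is a valid (non-negative) size, allocating a tensor
    # of size n can fail with a data-dependent guard error.
    torch._check_is_size(n)
    return torch.zeros(n)


print(fn(torch.tensor([3])).shape)  # torch.Size([3])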