@@ -254,14 +254,14 @@ def check_all_strides(
254254 return _check_strides_helper (a , b , only_cuda = only_cuda , significant_only = False )
255255
256256
257- # This function is equivalent to compute_contiguous() from TensorImpl.cpp
258- def is_contiguous (a : TensorLikeType , false_if_dde = False ) -> bool :
257+ def check_contiguous_sizes_strides (sizes , strides , false_if_dde = False ):
259258 """
260- Tests whether a tensor is contiguous or not.
261-
262- Tensors are contiguous when they have no elements,
263- one element, or when they have "nested" strides .
259+ Performs an equality check between actual stride & expected stride (based on composed sizes),
260+ handling contiguous stride representations:
261+ e.g. torch.empty(u0, u1, u2). contiguous().stride() -> (Max(1, u1) * Max(1, u2), Max(1, u2), 1)
262+ and we'd like to treat this equal to (u1 * u2, u2, 1) for comparison purposes .
264263 """
264+
265265 from torch .fx .experimental .symbolic_shapes import (
266266 guard_or_false ,
267267 guard_or_true ,
@@ -272,13 +272,10 @@ def is_contiguous(a: TensorLikeType, false_if_dde=False) -> bool:
272272 maybe_guard_or_false = guard_or_false if false_if_dde else guard_size_oblivious
273273 maybe_guard_or_true = guard_or_true if false_if_dde else guard_size_oblivious
274274
275- if maybe_guard_or_false (a .numel () < 2 ):
276- return True
277-
278275 expected_stride = 1
279276 expected_stride_max = 1
280277
281- for x , y in reversed (tuple (zip (a . shape , a . stride () ))):
278+ for x , y in reversed (tuple (zip (sizes , strides ))):
282279 # Skips checking strides when a dimension has length 1.
283280 if maybe_guard_or_false (x == 1 ):
284281 continue
@@ -299,6 +296,29 @@ def is_contiguous(a: TensorLikeType, false_if_dde=False) -> bool:
299296 return True
300297
301298
# This function is equivalent to compute_contiguous() from TensorImpl.cpp
def is_contiguous(a: TensorLikeType, false_if_dde=False) -> bool:
    """
    Report whether the tensor ``a`` is contiguous.

    A tensor with fewer than two elements is reported as contiguous right
    away. Any other tensor is contiguous exactly when
    check_contiguous_sizes_strides accepts its shape and strides.

    ``false_if_dde`` selects the guard used for the element-count test
    (guard_or_false when set, guard_size_oblivious otherwise) and is passed
    through unchanged to check_contiguous_sizes_strides.
    """
    from torch.fx.experimental.symbolic_shapes import (
        guard_or_false,
        guard_size_oblivious,
    )

    # Choose how the (possibly symbolic) comparison below gets evaluated.
    if false_if_dde:
        evaluate = guard_or_false
    else:
        evaluate = guard_size_oblivious

    # Zero- and one-element tensors need no stride inspection.
    trivially_contiguous = evaluate(a.numel() < 2)
    if not trivially_contiguous:
        return check_contiguous_sizes_strides(
            a.shape, a.stride(), false_if_dde=false_if_dde
        )
    return True
321+
302322# This function is equivalent to compute_channels_last_contiguous_2d() in TensorImpl.cpp
303323def is_channels_last_contiguous_2d (a : Tensor , false_if_dde = False ) -> bool :
304324 # NHWC or not channels last 2D contiguous
@@ -438,32 +458,27 @@ def is_channels_last_contiguous_or_false(a: Tensor) -> bool:
438458 ) or is_channels_last_contiguous_or_false_3d (a )
439459
440460
441- def is_non_overlapping_and_dense ( a : Tensor ) -> bool :
461+ def _is_non_overlapping_and_dense_or_false ( sizes , strides ) -> bool :
442462 """
443- True when a tensor is non-overlapping and dense.
463+ Helper function for is_non_overlapping_and_dense.
464+ For unbacked sizes & strides, returns True only if symbolically non-overlapping & dense,
465+ and False otherwise.
444466
445- A tensor is non-overlapping and dense when there exists a permutation of
446- its dimensions that is contiguous.
467+ e.g. sizes: [u0, u1], strides: [u2, u3]
468+ this may be non-overlapping & dense at runtime, for values {u0: 4, u1: 4, u2: 4, u3: 1},
469+ but isn't true for all values.
447470 """
471+ from torch .fx .experimental .symbolic_shapes import guard_or_false , guard_or_true
472+ from torch .utils ._sympy .functions import Max
448473
449- from torch .fx .experimental .symbolic_shapes import (
450- guard_or_false ,
451- guard_size_oblivious ,
452- )
453-
454- if a .is_sparse :
455- return False
456-
457- # Short-circuits if the tensor is already contiguous or channels-last contiguous
458- if is_contiguous_or_false (a ) or is_channels_last_contiguous_or_false (a ):
474+ # Short-circuits for 0/1-element tensors
475+ if guard_or_false (prod (sizes ) < 2 ): # type: ignore[operator]
459476 return True
460477
461- # The following is equivalent to compute_non_overlapping_and_dense in TensorImpl.cpp
462-
463478 # Short-circuits for tensors of rank one, which are
464479 # non-overlapping and "dense" if their stride is one
465- if a . ndim == 1 :
466- return a . stride () [0 ] == 1
480+ if len ( sizes ) == 1 :
481+ return guard_or_false ( strides [0 ] == 1 )
467482
468483 # Checks that there exists a permutation of the strides s.t. the tensor would be contiguous
469484 # Sorts (length, stride) pairs by stride
@@ -476,33 +491,44 @@ class K(NamedTuple):
476491 stride : int
477492
478493 def __lt__ (self , other ):
479- return guard_size_oblivious (self .stride < other .stride )
480-
481- def __gt__ (self , other ):
482- return guard_size_oblivious (self .stride > other .stride )
483-
484- def __le__ (self , other ):
485- return guard_size_oblivious (self .stride <= other .stride )
494+ # for backed symbols, this is practically a < operation
495+ # for unbacked, we return True if < is statically known,
496+ # then try to answer this symbolically, with stride ordering semantics
497+ # (e.g. u0 < u0 is False, u0 < u1 is False with no axioms, u0 < 2 * u0 is True)
498+ return (
499+ guard_or_false (
500+ self .stride < other .stride
501+ ) # checks statically known inequality
502+ or (
503+ (
504+ guard_or_false (self .stride == 0 )
505+ or guard_or_false (other .stride % self .stride == 0 )
506+ )
507+ and guard_or_true (self .stride != other .stride )
508+ ) # checks symbolic inequality (e.g. u0 < 2048 * u0)
509+ )
486510
487- def __ge__ (self , other ):
488- return guard_size_oblivious (self .stride >= other .stride )
511+ lengths_and_strides = sorted (map (K , sizes , strides ))
489512
490- def __eq__ (self , other ):
491- return guard_size_oblivious (self .stride == other .stride )
513+ # verify actual strides match the expected (composed sizes)
514+ sizes = [x .size for x in lengths_and_strides ][::- 1 ]
515+ strides = [x .stride for x in lengths_and_strides ][::- 1 ]
516+ return check_contiguous_sizes_strides (sizes , strides , false_if_dde = True )
492517
493- lengths_and_strides = sorted (map (K , a .shape , a .stride ()))
494518
495- expected_stride = 1
496- for length , stride in lengths_and_strides :
497- if guard_or_false (length == 1 ):
498- continue
def is_non_overlapping_and_dense(a: Tensor) -> bool:
    """
    True when a tensor is non-overlapping and dense.

    A tensor is non-overlapping and dense when there exists a permutation of
    its dimensions that is contiguous.

    Args:
        a: the tensor to inspect.

    Returns:
        False for sparse tensors; otherwise the result of
        _is_non_overlapping_and_dense_or_false applied to the tensor's
        shape and strides.
    """
    # Sparse tensors are rejected up front, before any stride inspection.
    if a.is_sparse:
        return False

    return _is_non_overlapping_and_dense_or_false(a.shape, a.stride())
506532
507533
508534# NOTE: Based on the implementation in TensorIterator.cpp, but note that
0 commit comments