@@ -254,14 +254,14 @@ def check_all_strides(
     return _check_strides_helper(a, b, only_cuda=only_cuda, significant_only=False)


-# This function is equivalent to compute_contiguous() from TensorImpl.cpp
-def is_contiguous(a: TensorLikeType, false_if_dde=False) -> bool:
+def check_contiguous_sizes_strides(sizes, strides, false_if_dde=False):
     """
-    Tests whether a tensor is contiguous or not.
-
-    Tensors are contiguous when they have no elements,
-    one element, or when they have "nested" strides.
+    Performs an equality check between the actual strides and the expected strides
+    (composed from the sizes), handling contiguous stride representations:
+    e.g. torch.empty(u0, u1, u2).contiguous().stride() -> (Max(1, u1) * Max(1, u2), Max(1, u2), 1)
+    and we'd like to treat this as equal to (u1 * u2, u2, 1) for comparison purposes.
     """
+
     from torch.fx.experimental.symbolic_shapes import (
         guard_or_false,
         guard_or_true,
@@ -272,13 +272,10 @@ def is_contiguous(a: TensorLikeType, false_if_dde=False) -> bool:
     maybe_guard_or_false = guard_or_false if false_if_dde else guard_size_oblivious
     maybe_guard_or_true = guard_or_true if false_if_dde else guard_size_oblivious

-    if maybe_guard_or_false(a.numel() < 2):
-        return True
-
     expected_stride = 1
     expected_stride_max = 1

-    for x, y in reversed(tuple(zip(a.shape, a.stride()))):
+    for x, y in reversed(tuple(zip(sizes, strides))):
         # Skips checking strides when a dimension has length 1.
         if maybe_guard_or_false(x == 1):
             continue
@@ -299,6 +296,27 @@ def is_contiguous(a: TensorLikeType, false_if_dde=False) -> bool:
     return True


+# This function is equivalent to compute_contiguous() from TensorImpl.cpp
+def is_contiguous(a: TensorLikeType, false_if_dde=False) -> bool:
+    """
+    Tests whether a tensor is contiguous or not.
+
+    Tensors are contiguous when they have no elements,
+    one element, or when they have "nested" strides.
+    """
+    from torch.fx.experimental.symbolic_shapes import (
+        guard_or_false,
+        guard_size_oblivious,
+    )
+
+    maybe_guard_or_false = guard_or_false if false_if_dde else guard_size_oblivious
+
+    if maybe_guard_or_false(a.numel() < 2):
+        return True
+
+    return check_contiguous_sizes_strides(a.shape, a.stride(), false_if_dde=false_if_dde)
+
+
 # This function is equivalent to compute_channels_last_contiguous_2d() in TensorImpl.cpp
 def is_channels_last_contiguous_2d(a: Tensor, false_if_dde=False) -> bool:
     # NHWC or not channels last 2D contiguous
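As a concrete, non-symbolic illustration of the "nested" strides the docstring refers to: a contiguous tensor's strides are exactly the composed trailing sizes, while a permuted view of the same storage is not contiguous. A rough sketch assuming ordinary eager torch:

```python
import torch

a = torch.empty(2, 3, 4)  # strides (12, 4, 1): nested products of trailing sizes
t = a.transpose(0, 2)     # same storage, sizes (4, 3, 2), strides (1, 4, 12)

print(a.is_contiguous())  # True
print(t.is_contiguous())  # False: strides no longer follow the nested ordering
```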
@@ -438,32 +456,27 @@ def is_channels_last_contiguous_or_false(a: Tensor) -> bool:
     ) or is_channels_last_contiguous_or_false_3d(a)


-def is_non_overlapping_and_dense(a: Tensor) -> bool:
+def _is_non_overlapping_and_dense_or_false(sizes, strides) -> bool:
     """
-    True when a tensor is non-overlapping and dense.
+    Helper function for is_non_overlapping_and_dense.
+    For unbacked sizes & strides, returns True only if symbolically non-overlapping & dense,
+    and False otherwise.

-    A tensor is non-overlapping and dense when there exists a permutation of
-    its dimensions that is contiguous.
+    e.g. sizes: [u0, u1], strides: [u2, u3]
+    may be non-overlapping & dense at runtime (say for values {u0: 4, u1: 4, u2: 4, u3: 1}),
+    but that doesn't hold for all values.
     """
+    from torch.fx.experimental.symbolic_shapes import guard_or_false, guard_or_true
+    from torch.utils._sympy.functions import Max

-    from torch.fx.experimental.symbolic_shapes import (
-        guard_or_false,
-        guard_size_oblivious,
-    )
-
-    if a.is_sparse:
-        return False
-
-    # Short-circuits if the tensor is already contiguous or channels-last contiguous
-    if is_contiguous_or_false(a) or is_channels_last_contiguous_or_false(a):
+    # Short-circuits for 0/1-element tensors
+    if guard_or_false(prod(sizes) < 2):  # type: ignore[operator]
         return True

-    # The following is equivalent to compute_non_overlapping_and_dense in TensorImpl.cpp
-
     # Short-circuits for tensors of rank one, which are
     # non-overlapping and "dense" if their stride is one
-    if a.ndim == 1:
-        return a.stride()[0] == 1
+    if len(sizes) == 1:
+        return guard_or_false(strides[0] == 1)

     # Checks that there exists a permutation of the strides s.t. the tensor would be contiguous
     # Sorts (length, stride) pairs by stride
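The property being established is the one from the original docstring: some permutation of the dimensions must be contiguous. Here is the plain-integer analogue of the sort-then-recompose check (roughly what the pre-existing backed path did, on hypothetical concrete values and without the guard_or_* handling or the zero-stride/overlap subtleties added below):

```python
# Hypothetical concrete sizes/strides of a permuted-but-dense layout:
sizes, strides = (4, 2, 3), (1, 12, 4)

# Sort (size, stride) pairs by stride, then replay the contiguous recurrence.
pairs = sorted(zip(sizes, strides), key=lambda p: p[1])
expected = 1
dense = True
for size, stride in pairs:
    if size == 1:
        continue  # length-1 dims impose no constraint
    if stride != expected:
        dense = False
        break
    expected *= size

print(dense)  # True: reordered as sizes (2, 3, 4) with strides (12, 4, 1), it is contiguous
```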
@@ -476,33 +489,49 @@ class K(NamedTuple):
         stride: int

         def __lt__(self, other):
-            return guard_size_oblivious(self.stride < other.stride)
-
-        def __gt__(self, other):
-            return guard_size_oblivious(self.stride > other.stride)
+            # for backed symbols, this is practically a < operation
+            # for unbacked, we return True if < is statically known;
+            # otherwise we try to answer it symbolically, with stride ordering semantics
+            # (e.g. u0 < u0 is False, u0 < u1 is False with no axioms, u0 < 2 * u0 is True)
+            return (
+                guard_or_false(
+                    self.stride < other.stride
+                )  # checks statically known inequality
+                or (
+                    (
+                        guard_or_false(self.stride == 0)
+                        or guard_or_false(other.stride % self.stride == 0)
+                    )
+                    and guard_or_true(self.stride != other.stride)
+                )  # checks symbolic inequality (e.g. u0 < 2048 * u0)
+            )

-        def __le__(self, other):
-            return guard_size_oblivious(self.stride <= other.stride)
+    lengths_and_strides = sorted(map(K, sizes, strides))

-        def __ge__(self, other):
-            return guard_size_oblivious(self.stride >= other.stride)
+    # verify that sorted order was imposed (checks the "non-overlapping" condition)
+    for i, j in zip(lengths_and_strides[:-1], lengths_and_strides[1:]):
+        if guard_or_false(i.stride == 0) or guard_or_false(j.stride % i.stride != 0):
+            return False

-        def __eq__(self, other):
-            return guard_size_oblivious(self.stride == other.stride)
+    # verify that the actual strides match the expected (composed) strides
+    sizes = [x.size for x in lengths_and_strides][::-1]
+    strides = [x.stride for x in lengths_and_strides][::-1]
+    return check_contiguous_sizes_strides(sizes, strides, false_if_dde=True)

-    lengths_and_strides = sorted(map(K, a.shape, a.stride()))

-    expected_stride = 1
-    for length, stride in lengths_and_strides:
-        if guard_or_false(length == 1):
-            continue
+def is_non_overlapping_and_dense(a: Tensor) -> bool:
+    """
+    True when a tensor is non-overlapping and dense.

-        if guard_size_oblivious(stride != expected_stride):
-            return False
+    A tensor is non-overlapping and dense when there exists a permutation of
+    its dimensions that is contiguous.
+    """
+    from torch.fx.experimental.symbolic_shapes import guard_or_false

-        expected_stride *= length
+    if a.is_sparse:
+        return False

-    return True
+    return _is_non_overlapping_and_dense_or_false(a.shape, a.stride())


 # NOTE: Based on the implementation in TensorIterator.cpp, but note that
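The rewritten K.__lt__ leans on guard_or_false / guard_or_true rather than guard_size_oblivious, so sorting never raises a data-dependent error. A rough sketch of those primitives' behavior on unbacked symbols, using a bare ShapeEnv the way torch's shape tests do (assumed construction; the printed values reflect the guard_or_* contract described in the diff's comments, not a verified run):

```python
from torch.fx.experimental.symbolic_shapes import ShapeEnv, guard_or_false, guard_or_true

shape_env = ShapeEnv()
u0 = shape_env.create_unbacked_symint()
u1 = shape_env.create_unbacked_symint()

# guard_or_false: True only when the predicate is statically provable, else False.
print(guard_or_false(u0 < u0))  # False: statically disprovable
print(guard_or_false(u0 < u1))  # False: undecidable with no axioms

# guard_or_true: False only when the predicate is statically disprovable, else True.
print(guard_or_true(u0 != u1))  # True: inequality cannot be ruled out
```

The divisibility branch (other.stride % self.stride == 0 together with stride inequality) is what lets expressions like u0 < 2 * u0 be ordered symbolically, per the comment in the diff.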