[dynamic shapes] prims_common non_overlapping_and_dense (#160462) · pytorch/pytorch@ae88f5a · GitHub
[go: up one dir, main page]

Skip to content

Commit ae88f5a

Browse files
pianpwk authored and facebook-github-bot committed
[dynamic shapes] prims_common non_overlapping_and_dense (#160462)
Summary: Pull Request resolved: #160462 Test Plan: test_dynamic_shapes Rollback Plan: Differential Revision: D80120333
1 parent 25d0d8b commit ae88f5a

File tree

2 files changed

+135
-48
lines changed

2 files changed

+135
-48
lines changed

test/test_dynamic_shapes.py

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -861,7 +861,7 @@ def test_mul_int_oo_nan(self):
861861
s2 = create_symint(shape_env, 5, duck=False)
862862
bool(s0 * (s1 // s0) == s2)
863863

864-
def test_non_overlapping_and_dense(self):
864+
def test_non_overlapping_and_dense_backed(self):
865865
shape_env = ShapeEnv()
866866
a0 = create_symint(shape_env, 5)
867867
r = torch.empty_strided((a0, 7), (1, a0), device="meta")
@@ -896,6 +896,64 @@ def test_non_overlapping_and_dense_unbacked(self):
896896
)
897897
)
898898

899+
def test_prims_non_overlapping_and_dense(self):
900+
shape_env = ShapeEnv()
901+
cf = torch._prims_common.is_non_overlapping_and_dense
902+
903+
# backed case
904+
a0 = create_symint(shape_env, 5)
905+
self.assertTrue(cf(torch.empty_strided((a0, 7), (1, a0), device="meta")))
906+
907+
# unbacked
908+
u0 = shape_env.create_unbacked_symint()
909+
torch._check_is_size(u0)
910+
self.assertTrue(cf(torch.empty_strided((u0, 2), (2, 1), device="meta")))
911+
self.assertTrue(cf(torch.empty_strided((2, u0), (1, 2), device="meta")))
912+
self.assertTrue(cf(torch.empty_strided((u0,), (1,), device="meta")))
913+
self.assertTrue(cf(torch.empty_strided((1,), (u0,), device="meta")))
914+
915+
Max = torch.sym_max
916+
self.assertTrue(
917+
cf(
918+
torch.empty_strided(
919+
(2, 3, 1, u0),
920+
(3 * Max(1, u0), Max(1, u0), Max(1, u0), 1),
921+
device="meta",
922+
)
923+
)
924+
)
925+
self.assertFalse(
926+
cf(
927+
torch.empty_strided(
928+
(2, 3, 1, u0),
929+
(Max(1, u0), Max(1, u0), 1, 3 * Max(1, u0)),
930+
device="meta",
931+
)
932+
)
933+
)
934+
935+
# return False on arbitrary strides
936+
u1 = shape_env.create_unbacked_symint()
937+
torch._check_is_size(u1)
938+
self.assertFalse(
939+
cf(
940+
torch.empty_strided(
941+
(2 * u0, u0, 1),
942+
(u1, u0, u0 + u1),
943+
device="meta",
944+
)
945+
)
946+
)
947+
self.assertFalse(
948+
cf(
949+
torch.empty_strided(
950+
(2, 3, u0),
951+
(u1, 3, 1),
952+
device="meta",
953+
)
954+
)
955+
)
956+
899957
def test_sympy_optimized_add_binary_search(self):
900958
import sympy
901959

torch/_prims_common/__init__.py

Lines changed: 76 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -254,14 +254,14 @@ def check_all_strides(
254254
return _check_strides_helper(a, b, only_cuda=only_cuda, significant_only=False)
255255

256256

257-
# This function is equivalent to compute_contiguous() from TensorImpl.cpp
258-
def is_contiguous(a: TensorLikeType, false_if_dde=False) -> bool:
257+
def check_contiguous_sizes_strides(sizes, strides, false_if_dde=False):
259258
"""
260-
Tests whether a tensor is contiguous or not.
261-
262-
Tensors are contiguous when they have no elements,
263-
one element, or when they have "nested" strides.
259+
Performs an equality check between actual stride & expected stride (based on composed sizes),
260+
handling contiguous stride representations:
261+
e.g. torch.empty(u0, u1, u2).contiguous().stride() -> (Max(1, u1) * Max(1, u2), Max(1, u2), 1)
262+
and we'd like to treat this equal to (u1 * u2, u2, 1) for comparison purposes.
264263
"""
264+
265265
from torch.fx.experimental.symbolic_shapes import (
266266
guard_or_false,
267267
guard_or_true,
@@ -272,13 +272,10 @@ def is_contiguous(a: TensorLikeType, false_if_dde=False) -> bool:
272272
maybe_guard_or_false = guard_or_false if false_if_dde else guard_size_oblivious
273273
maybe_guard_or_true = guard_or_true if false_if_dde else guard_size_oblivious
274274

275-
if maybe_guard_or_false(a.numel() < 2):
276-
return True
277-
278275
expected_stride = 1
279276
expected_stride_max = 1
280277

281-
for x, y in reversed(tuple(zip(a.shape, a.stride()))):
278+
for x, y in reversed(tuple(zip(sizes, strides))):
282279
# Skips checking strides when a dimension has length 1.
283280
if maybe_guard_or_false(x == 1):
284281
continue
@@ -299,6 +296,27 @@ def is_contiguous(a: TensorLikeType, false_if_dde=False) -> bool:
299296
return True
300297

301298

299+
# This function is equivalent to compute_contiguous() from TensorImpl.cpp
300+
def is_contiguous(a: TensorLikeType, false_if_dde=False) -> bool:
301+
"""
302+
Tests whether a tensor is contiguous or not.
303+
304+
Tensors are contiguous when they have no elements,
305+
one element, or when they have "nested" strides.
306+
"""
307+
from torch.fx.experimental.symbolic_shapes import (
308+
guard_or_false,
309+
guard_size_oblivious,
310+
)
311+
312+
maybe_guard_or_false = guard_or_false if false_if_dde else guard_size_oblivious
313+
314+
if maybe_guard_or_false(a.numel() < 2):
315+
return True
316+
317+
return check_contiguous_sizes_strides(a.shape, a.stride(), false_if_dde=false_if_dde)
318+
319+
302320
# This function is equivalent to compute_channels_last_contiguous_2d() in TensorImpl.cpp
303321
def is_channels_last_contiguous_2d(a: Tensor, false_if_dde=False) -> bool:
304322
# NHWC or not channels last 2D contiguous
@@ -438,32 +456,27 @@ def is_channels_last_contiguous_or_false(a: Tensor) -> bool:
438456
) or is_channels_last_contiguous_or_false_3d(a)
439457

440458

441-
def is_non_overlapping_and_dense(a: Tensor) -> bool:
459+
def _is_non_overlapping_and_dense_or_false(sizes, strides) -> bool:
442460
"""
443-
True when a tensor is non-overlapping and dense.
461+
Helper function for is_non_overlapping_and_dense.
462+
For unbacked sizes & strides, returns True only if symbolically non-overlapping & dense,
463+
and False otherwise.
444464
445-
A tensor is non-overlapping and dense when there exists a permutation of
446-
its dimensions that is contiguous.
465+
e.g. sizes: [u0, u1], strides: [u2, u3]
466+
this may be non-overlapping & dense at runtime, for values {u0: 4, u1: 4, u2: 4, u3: 1},
467+
but isn't true for all values.
447468
"""
469+
from torch.fx.experimental.symbolic_shapes import guard_or_false, guard_or_true
470+
from torch.utils._sympy.functions import Max
448471

449-
from torch.fx.experimental.symbolic_shapes import (
450-
guard_or_false,
451-
guard_size_oblivious,
452-
)
453-
454-
if a.is_sparse:
455-
return False
456-
457-
# Short-circuits if the tensor is already contiguous or channels-last contiguous
458-
if is_contiguous_or_false(a) or is_channels_last_contiguous_or_false(a):
472+
# Short-circuits for 0/1-element tensors
473+
if guard_or_false(prod(sizes) < 2): # type: ignore[operator]
459474
return True
460475

461-
# The following is equivalent to compute_non_overlapping_and_dense in TensorImpl.cpp
462-
463476
# Short-circuits for tensors of rank one, which are
464477
# non-overlapping and "dense" if their stride is one
465-
if a.ndim == 1:
466-
return a.stride()[0] == 1
478+
if len(sizes) == 1:
479+
return guard_or_false(strides[0] == 1)
467480

468481
# Checks that there exists a permutation of the strides s.t. the tensor would be contiguous
469482
# Sorts (length, stride) pairs by stride
@@ -476,33 +489,49 @@ class K(NamedTuple):
476489
stride: int
477490

478491
def __lt__(self, other):
479-
return guard_size_oblivious(self.stride < other.stride)
480-
481-
def __gt__(self, other):
482-
return guard_size_oblivious(self.stride > other.stride)
492+
# for backed symbols, this is practically a < operation
493+
# for unbacked, we return True if < is statically known,
494+
# then try to answer this symbolically, with stride ordering semantics
495+
# (e.g. u0 < u0 is False, u0 < u1 is False with no axioms, u0 < 2 * u0 is True)
496+
return (
497+
guard_or_false(
498+
self.stride < other.stride
499+
) # checks statically known inequality
500+
or (
501+
(
502+
guard_or_false(self.stride == 0)
503+
or guard_or_false(other.stride % self.stride == 0)
504+
)
505+
and guard_or_true(self.stride != other.stride)
506+
) # checks symbolic inequality (e.g. u0 < 2048 * u0)
507+
)
483508

484-
def __le__(self, other):
485-
return guard_size_oblivious(self.stride <= other.stride)
509+
lengths_and_strides = sorted(map(K, sizes, strides))
486510

487-
def __ge__(self, other):
488-
return guard_size_oblivious(self.stride >= other.stride)
511+
# verify that sorted order was imposed (checks the "non-overlapping condition")
512+
for i, j in zip(lengths_and_strides[:-1], lengths_and_strides[1:]):
513+
if guard_or_false(i.stride == 0) or guard_or_false(j.stride % i.stride != 0):
514+
return False
489515

490-
def __eq__(self, other):
491-
return guard_size_oblivious(self.stride == other.stride)
516+
# verify actual strides match the expected (composed sizes)
517+
sizes = [x.size for x in lengths_and_strides][::-1]
518+
strides = [x.stride for x in lengths_and_strides][::-1]
519+
return check_contiguous_sizes_strides(sizes, strides, false_if_dde=True)
492520

493-
lengths_and_strides = sorted(map(K, a.shape, a.stride()))
494521

495-
expected_stride = 1
496-
for length, stride in lengths_and_strides:
497-
if guard_or_false(length == 1):
498-
continue
522+
def is_non_overlapping_and_dense(a: Tensor) -> bool:
523+
"""
524+
True when a tensor is non-overlapping and dense.
499525
500-
if guard_size_oblivious(stride != expected_stride):
501-
return False
526+
A tensor is non-overlapping and dense when there exists a permutation of
527+
its dimensions that is contiguous.
528+
"""
529+
from torch.fx.experimental.symbolic_shapes import guard_or_false
502530

503-
expected_stride *= length
531+
if a.is_sparse:
532+
return False
504533

505-
return True
534+
return _is_non_overlapping_and_dense_or_false(a.shape, a.stride())
506535

507536

508537
# NOTE: Based on the implementation in TensorIterator.cpp, but note that

0 commit comments

Comments (0)