8000 [dynamic shapes] prims_common non_overlapping_and_dense (#160462) · pytorch/pytorch@5670291 · GitHub
[go: up one dir, main page]

Skip to content

Commit 5670291

Browse files
pianpwk authored and facebook-github-bot committed
[dynamic shapes] prims_common non_overlapping_and_dense (#160462)
Summary: Pull Request resolved: #160462 Test Plan: test_dynamic_shapes Rollback Plan: Reviewed By: laithsakka Differential Revision: D80120333
1 parent 16ada80 commit 5670291

File tree

2 files changed

+133
-49
lines changed

2 files changed

+133
-49
lines changed

test/test_dynamic_shapes.py

Lines changed: 59 additions & 1 deletion
867
Original file line numberDiff line numberDiff line change
@@ -861,7 +861,7 @@ def test_mul_int_oo_nan(self):
861861
s2 = create_symint(shape_env, 5, duck=False)
862862
bool(s0 * (s1 // s0) == s2)
863863

864-
def test_non_overlapping_and_dense(self):
864+
def test_non_overlapping_and_dense_backed(self):
865865
shape_env = ShapeEnv()
866866
a0 = create_symint(shape_env, 5)
867
r = torch.empty_strided((a0, 7), (1, a0), device="meta")
@@ -896,6 +896,64 @@ def test_non_overlapping_and_dense_unbacked(self):
896896
)
897897
)
898898

899+
def test_prims_non_overlapping_and_dense(self):
900+
shape_env = ShapeEnv()
901+
cf = torch._prims_common.is_non_overlapping_and_dense
902+
903+
# backed case
904+
a0 = create_symint(shape_env, 5)
905+
self.assertTrue(cf(torch.empty_strided((a0, 7), (1, a0), device="meta")))
906+
907+
# unbacked
908+
u0 = shape_env.create_unbacked_symint()
909+
torch._check_is_size(u0)
910+
self.assertTrue(cf(torch.empty_strided((u0, 2), (2, 1), device="meta")))
911+
self.assertTrue(cf(torch.empty_strided((2, u0), (1, 2), device="meta")))
912+
self.assertTrue(cf(torch.empty_strided((u0,), (1,), device="meta")))
913+
self.assertTrue(cf(torch.empty_strided((1,), (u0,), device="meta")))
914+
915+
Max = torch.sym_max
916+
self.assertTrue(
917+
cf(
918+
torch.empty_strided(
919+
(2, 3, 1, u0),
920+
(3 * Max(1, u0), Max(1, u0), Max(1, u0), 1),
921+
device="meta",
922+
)
923+
)
924+
)
925+
self.assertFalse(
926+
cf(
927+
torch.empty_strided(
928+
(2, 3, 1, u0),
929+
(Max(1, u0), Max(1, u0), 1, 3 * Max(1, u0)),
930+
device="meta",
931+
)
932+
)
933+
)
934+
935+
# return False on arbitrary strides
936+
u1 = shape_env.create_unbacked_symint()
937+
torch._check_is_size(u1)
938+
self.assertFalse(
939+
cf(
940+
torch.empty_strided(
941+
(2 * u0, u0, 1),
942+
(u1, u0, u0 + u1),
943+
device="meta",
944+
)
945+
)
946+
)
947+
self.assertFalse(
948+
cf(
949+
torch.empty_strided(
950+
(2, 3, u0),
951+
(u1, 3, 1),
952+
device="meta",
953+
)
954+
)
955+
)
956+
899957
def test_sympy_optimized_add_binary_search(self):
900958
import sympy
901959

torch/_prims_common/__init__.py

Lines changed: 74 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -254,14 +254,14 @@ def check_all_strides(
254254
return _check_strides_helper(a, b, only_cuda=only_cuda, significant_only=False)
255255

256256

257-
# This function is equivalent to compute_contiguous() from TensorImpl.cpp
258-
def is_contiguous(a: TensorLikeType, false_if_dde=False) -> bool:
257+
def check_contiguous_sizes_strides(sizes, strides, false_if_dde=False):
259258
"""
260-
Tests whether a tensor is contiguous or not.
261-
262-
Tensors are contiguous when they have no elements,
263-
one element, or when they have "nested" strides.
259+
Performs an equality check between actual stride & expected stride (based on composed sizes),
260+
handling contiguous stride representations:
261+
e.g. torch.empty(u0, u1, u2).contiguous().stride() -> (Max(1, u1) * Max(1, u2), Max(1, u2), 1)
262+
and we'd like to treat this equal to (u1 * u2, u2, 1) for comparison purposes.
264263
"""
264+
265265
from torch.fx.experimental.symbolic_shapes import (
266266
guard_or_false,
267267
guard_or_true,
@@ -272,13 +272,10 @@ def is_contiguous(a: TensorLikeType, false_if_dde=False) -> bool:
272272
maybe_guard_or_false = guard_or_false if false_if_dde else guard_size_oblivious
273273
maybe_guard_or_true = guard_or_true if false_if_dde else guard_size_oblivious
274274

275-
if maybe_guard_or_false(a.numel() < 2):
276-
return True
277-
278275
expected_stride = 1
279276
expected_stride_max = 1
280277

281-
for x, y in reversed(tuple(zip(a.shape, a.stride()))):
278+
for x, y in reversed(tuple(zip(sizes, strides))):
282279
# Skips checking strides when a dimension has length 1.
283280
if maybe_guard_or_false(x == 1):
284281
continue
@@ -299,6 +296,29 @@ def is_contiguous(a: TensorLikeType, false_if_dde=False) -> bool:
299296
return True
300297

301298

299+
# This function is equivalent to compute_contiguous() from TensorImpl.cpp
300+
def is_contiguous(a: TensorLikeType, false_if_dde=False) -> bool:
301+
"""
302+
Tests whether a tensor is contiguous or not.
303+
304+
Tensors are contiguous when they have no elements,
305+
one element, or when they have "nested" strides.
306+
"""
307+
from torch.fx.experimental.symbolic_shapes import (
308+
guard_or_false,
309+
guard_size_oblivious,
310+
)
311+
312+
maybe_guard_or_false = guard_or_false if false_if_dde else guard_size_oblivious
313+
314+
if maybe_guard_or_false(a.numel() < 2):
315+
return True
316+
317+
return check_contiguous_sizes_strides(
318+
a.shape, a.stride(), false_if_dde=false_if_dde
319+
)
320+
321+
302322
# This function is equivalent to compute_channels_last_contiguous_2d() in TensorImpl.cpp
303323
def is_channels_last_contiguous_2d(a: Tensor, false_if_dde=False) -> bool:
304324
# NHWC or not channels last 2D contiguous
@@ -438,32 +458,27 @@ def is_channels_last_contiguous_or_false(a: Tensor) -> bool:
438458
) or is_channels_last_contiguous_or_false_3d(a)
439459

440460

441-
def is_non_overlapping_and_dense(a: Tensor) -> bool:
461+
def _is_non_overlapping_and_dense_or_false(sizes, strides) -> bool:
442462
"""
443-
True when a tensor is non-overlapping and dense.
463+
Helper function for is_non_overlapping_and_dense.
464+
For unbacked sizes & strides, returns True only if symbolically non-overlapping & dense,
465+
and False otherwise.
444466
445-
A tensor is non-overlapping and dense when there exists a permutation of
446-
its dimensions that is contiguous.
467+
e.g. sizes: [u0, u1], strides: [u2, u3]
468+
this may be non-overlapping & dense at runtime, for values {u0: 4, u1: 4, u2: 4, u3: 1},
469+
but isn't true for all values.
447470
"""
471+
from torch.fx.experimental.symbolic_shapes import guard_or_false, guard_or_true
472+
from torch.utils._sympy.functions import Max
448473

449-
from torch.fx.experimental.symbolic_shapes import (
450-
guard_or_false,
451-
guard_size_oblivious,
452-
)
453-
454-
if a.is_sparse:
455-
return False
456-
457-
# Short-circuits if the tensor is already contiguous or channels-last contiguous
458-
if is_contiguous_or_false(a) or is_channels_last_contiguous_or_false(a):
474+
# Short-circuits for 0/1-element tensors
475+
if guard_or_false(prod(sizes) < 2): # type: ignore[operator]
459476
return True
460477

461-
# The following is equivalent to compute_non_overlapping_and_dense in TensorImpl.cpp
462-
463478
# Short-circuits for tensors of rank one, which are
464479
# non-overlapping and "dense" if their stride is one
465-
if a.ndim == 1:
466-
return a.stride()[0] == 1
480+
if len(sizes) == 1:
481+
return guard_or_false(strides[0] == 1)
467482

468483
# Checks that there exists a permutation of the strides s.t. the tensor would be contiguous
469484
# Sorts (length, stride) pairs by stride
@@ -476,33 +491,44 @@ class K(NamedTuple):
476491
stride: int
477492

478493
def __lt__(self, other):
479-
return guard_size_oblivious(self.stride < other.stride)
480-
481-
def __gt__(self, other):
482-
return guard_size_oblivious(self.stride > other.stride)
483-
484-
def __le__(self, other):
485-
return guard_size_oblivious(self.stride <= other.stride)
494+
# for backed symbols, this is practically a < operation
495+
# for unbacked, we return True if < is statically known,
496+
# then try to answer this symbolically, with stride ordering semantics
497+
# (e.g. u0 < u0 is False, u0 < u1 is False with no axioms, u0 < 2 * u0 is True)
498+
return (
499+
guard_or_false(
500+
self.stride < other.stride
501+
) # checks statically known inequality
502+
or (
503+
(
504+
guard_or_false(self.stride == 0)
505+
or guard_or_false(other.stride % self.stride == 0)
506+
)
507+
and guard_or_true(self.stride != other.stride)
508+
) # checks symbolic inequality (e.g. u0 < 2048 * u0)
509+
)
486510

487-
def __ge__(self, other):
488-
return guard_size_oblivious(self.stride >= other.stride)
511+
lengths_and_strides = sorted(map(K, sizes, strides))
489512

490-
def __eq__(self, other):
491-
return guard_size_oblivious(self.stride == other.stride)
513+
# verify actual strides match the expected (composed sizes)
514+
sizes = [x.size for x in lengths_and_strides][::-1]
515+
strides = [x.stride for x in lengths_and_strides][::-1]
516+
return check_contiguous_sizes_strides(sizes, strides, false_if_dde=True)
492517

493-
lengths_and_strides = sorted(map(K, a.shape, a.stride()))
494518

495-
expected_stride = 1
496-
for length, stride in lengths_and_strides:
497-
if guard_or_false(length == 1):
498-
continue
519+
def is_non_overlapping_and_dense(a: Tensor) -> bool:
520+
"""
521+
True when a tensor is non-overlapping and dense.
499522
500-
if guard_size_oblivious(stride != expected_stride):
501-
return False
523+
A tensor is non-overlapping and dense when there exists a permutation of
524+
its dimensions that is contiguous.
525+
"""
526+
from torch.fx.experimental.symbolic_shapes import guard_or_false
502527

503-
expected_stride *= length
528+
if a.is_sparse:
529+
return False
504530

505-
return True
531+
return _is_non_overlapping_and_dense_or_false(a.shape, a.stride())
506532

507533

508534
# NOTE: Based on the implementation in TensorIterator.cpp, but note that

0 commit comments

Comments (0)