[dynamic shapes] prims_common non_overlapping_and_dense by pianpwk · Pull Request #160462 · pytorch/pytorch · GitHub
Closed
60 changes: 59 additions & 1 deletion test/test_dynamic_shapes.py
@@ -861,7 +861,7 @@ def test_mul_int_oo_nan(self):
s2 = create_symint(shape_env, 5, duck=False)
bool(s0 * (s1 // s0) == s2)

-def test_non_overlapping_and_dense(self):
def test_non_overlapping_and_dense_backed(self):
shape_env = ShapeEnv()
a0 = create_symint(shape_env, 5)
r = torch.empty_strided((a0, 7), (1, a0), device="meta")
@@ -896,6 +896,64 @@ def test_non_overlapping_and_dense_unbacked(self):
)
)

def test_prims_non_overlapping_and_dense(self):
shape_env = ShapeEnv()
cf = torch._prims_common.is_non_overlapping_and_dense
Contributor review comment: is_non_overlapping_and_dense -> is_non_overlapping_and_dense_or_false?


# backed case
a0 = create_symint(shape_env, 5)
self.assertTrue(cf(torch.empty_strided((a0, 7), (1, a0), device="meta")))

# unbacked
u0 = shape_env.create_unbacked_symint()
torch._check_is_size(u0)
self.assertTrue(cf(torch.empty_strided((u0, 2), (2, 1), device="meta")))
self.assertTrue(cf(torch.empty_strided((2, u0), (1, 2), device="meta")))
self.assertTrue(cf(torch.empty_strided((u0,), (1,), device="meta")))
self.assertTrue(cf(torch.empty_strided((1,), (u0,), device="meta")))

Max = torch.sym_max
self.assertTrue(
cf(
torch.empty_strided(
(2, 3, 1, u0),
(3 * Max(1, u0), Max(1, u0), Max(1, u0), 1),
device="meta",
)
)
)
self.assertFalse(
cf(
torch.empty_strided(
(2, 3, 1, u0),
(Max(1, u0), Max(1, u0), 1, 3 * Max(1, u0)),
device="meta",
)
)
)

# return False on arbitrary strides
u1 = shape_env.create_unbacked_symint()
torch._check_is_size(u1)
self.assertFalse(
cf(
torch.empty_strided(
(2 * u0, u0, 1),
(u1, u0, u0 + u1),
device="meta",
)
)
)
self.assertFalse(
cf(
torch.empty_strided(
(2, 3, u0),
(u1, 3, 1),
device="meta",
)
)
)

def test_sympy_optimized_add_binary_search(self):
import sympy

122 changes: 74 additions & 48 deletions torch/_prims_common/__init__.py
@@ -254,14 +254,14 @@ def check_all_strides(
return _check_strides_helper(a, b, only_cuda=only_cuda, significant_only=False)


-# This function is equivalent to compute_contiguous() from TensorImpl.cpp
-def is_contiguous(a: TensorLikeType, false_if_dde=False) -> bool:
def check_contiguous_sizes_strides(sizes, strides, false_if_dde=False):
"""
-Tests whether a tensor is contiguous or not.
-
-Tensors are contiguous when they have no elements,
-one element, or when they have "nested" strides.
Performs an equality check between actual stride & expected stride (based on composed sizes),
handling contiguous stride representations:
e.g. torch.empty(u0, u1, u2).contiguous().stride() -> (Max(1, u1) * Max(1, u2), Max(1, u2), 1)
and we'd like to treat this equal to (u1 * u2, u2, 1) for comparison purposes.
"""

from torch.fx.experimental.symbolic_shapes import (
guard_or_false,
guard_or_true,
@@ -272,13 +272,10 @@ def is_contiguous(a: TensorLikeType, false_if_dde=False) -> bool:
maybe_guard_or_false = guard_or_false if false_if_dde else guard_size_oblivious
maybe_guard_or_true = guard_or_true if false_if_dde else guard_size_oblivious

-if maybe_guard_or_false(a.numel() < 2):
-return True

expected_stride = 1
expected_stride_max = 1

-for x, y in reversed(tuple(zip(a.shape, a.stride()))):
for x, y in reversed(tuple(zip(sizes, strides))):
# Skips checking strides when a dimension has length 1.
if maybe_guard_or_false(x == 1):
continue
@@ -299,6 +296,29 @@ def is_contiguous(a: TensorLikeType, false_if_dde=False) -> bool:
return True

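To make the Max-composed stride form above concrete, here is a minimal sketch using the same unbacked-symint/meta-device setup as this PR's tests. The exact printed expression depends on ShapeEnv internals, so treat the stride comment as illustrative rather than guaranteed output:

```python
import torch
from torch.fx.experimental.symbolic_shapes import ShapeEnv

shape_env = ShapeEnv()
u0 = shape_env.create_unbacked_symint()
u1 = shape_env.create_unbacked_symint()
u2 = shape_env.create_unbacked_symint()
for u in (u0, u1, u2):
    torch._check_is_size(u)  # mark each unbacked symint as a valid size

# Contiguous strides over unbacked sizes compose with Max(1, ...) so they
# remain well-defined when a dimension could be 0 or 1, i.e. roughly
# (Max(1, u1) * Max(1, u2), Max(1, u2), 1) for shape (u0, u1, u2);
# check_contiguous_sizes_strides treats that as equal to (u1 * u2, u2, 1).
t = torch.empty(u0, u1, u2, device="meta")
print(t.contiguous().stride())
```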

# This function is equivalent to compute_contiguous() from TensorImpl.cpp
def is_contiguous(a: TensorLikeType, false_if_dde=False) -> bool:
"""
Tests whether a tensor is contiguous or not.

Tensors are contiguous when they have no elements,
one element, or when they have "nested" strides.
"""
from torch.fx.experimental.symbolic_shapes import (
guard_or_false,
guard_size_oblivious,
)

maybe_guard_or_false = guard_or_false if false_if_dde else guard_size_oblivious

if maybe_guard_or_false(a.numel() < 2):
return True

return check_contiguous_sizes_strides(
a.shape, a.stride(), false_if_dde=false_if_dde
)

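A short usage sketch of the refactored entry point, grounded in an unbacked case exercised by the new test; `false_if_dde=True` is the mode that answers False instead of raising a data-dependent error when a guard cannot be decided:

```python
import torch
import torch._prims_common as prims_common
from torch.fx.experimental.symbolic_shapes import ShapeEnv

shape_env = ShapeEnv()
u0 = shape_env.create_unbacked_symint()
torch._check_is_size(u0)

# Shape (u0, 2) with strides (2, 1): contiguity is decidable symbolically,
# so this returns True without guarding on the value of u0.
t = torch.empty_strided((u0, 2), (2, 1), device="meta")
print(prims_common.is_contiguous(t, false_if_dde=True))  # True
```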

# This function is equivalent to compute_channels_last_contiguous_2d() in TensorImpl.cpp
def is_channels_last_contiguous_2d(a: Tensor, false_if_dde=False) -> bool:
# NHWC or not channels last 2D contiguous
@@ -438,32 +458,27 @@ def is_channels_last_contiguous_or_false(a: Tensor) -> bool:
) or is_channels_last_contiguous_or_false_3d(a)


-def is_non_overlapping_and_dense(a: Tensor) -> bool:
def _is_non_overlapping_and_dense_or_false(sizes, strides) -> bool:
"""
-True when a tensor is non-overlapping and dense.
Helper function for is_non_overlapping_and_dense.
For unbacked sizes & strides, returns True only if symbolically non-overlapping & dense,
and False otherwise.

-A tensor is non-overlapping and dense when there exists a permutation of
-its dimensions that is contiguous.
e.g. sizes: [u0, u1], strides: [u2, u3]
this may be non-overlapping & dense at runtime, for values {u0: 4, u1: 4, u2: 4, u3: 1},
but isn't true for all values.
"""
from torch.fx.experimental.symbolic_shapes import guard_or_false, guard_or_true
from torch.utils._sympy.functions import Max

-from torch.fx.experimental.symbolic_shapes import (
-guard_or_false,
-guard_size_oblivious,
-)

-if a.is_sparse:
-return False

-# Short-circuits if the tensor is already contiguous or channels-last contiguous
-if is_contiguous_or_false(a) or is_channels_last_contiguous_or_false(a):
# Short-circuits for 0/1-element tensors
if guard_or_false(prod(sizes) < 2):  # type: ignore[operator]
return True

# The following is equivalent to compute_non_overlapping_and_dense in TensorImpl.cpp

# Short-circuits for tensors of rank one, which are
# non-overlapping and "dense" if their stride is one
-if a.ndim == 1:
-return a.stride()[0] == 1
if len(sizes) == 1:
return guard_or_false(strides[0] == 1)

# Checks that there exists a permutation of the strides s.t. the tensor would be contiguous
# Sorts (length, stride) pairs by stride
@@ -476,33 +491,44 @@ class K(NamedTuple):
stride: int

def __lt__(self, other):
-return guard_size_oblivious(self.stride < other.stride)

-def __gt__(self, other):
-return guard_size_oblivious(self.stride > other.stride)

-def __le__(self, other):
-return guard_size_oblivious(self.stride <= other.stride)
# for backed symbols, this is practically a < operation
# for unbacked, we return True if < is statically known; otherwise we
# try to answer this symbolically, with stride ordering semantics
# (e.g. u0 < u0 is False, u0 < u1 is False with no axioms, u0 < 2 * u0 is True)
Contributor review comment: please add a comment explaining why it's OK if, in the worst case for unbacked, we picked a wrong order.

return (
guard_or_false(
self.stride < other.stride
) # checks statically known inequality
or (
(
guard_or_false(self.stride == 0)
or guard_or_false(other.stride % self.stride == 0)
)
and guard_or_true(self.stride != other.stride)
) # checks symbolic inequality (e.g. u0 < 2048 * u0)
)

-def __ge__(self, other):
-return guard_size_oblivious(self.stride >= other.stride)
lengths_and_strides = sorted(map(K, sizes, strides))

-def __eq__(self, other):
-return guard_size_oblivious(self.stride == other.stride)
# verify actual strides match the expected (composed sizes)
sizes = [x.size for x in lengths_and_strides][::-1]
strides = [x.stride for x in lengths_and_strides][::-1]
return check_contiguous_sizes_strides(sizes, strides, false_if_dde=True)

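The ordering facts quoted in the comment above can be reproduced outside the class. Below is a hedged sketch that inlines the same guard_or_false/guard_or_true recipe as K.__lt__ (the helper name stride_lt is ours, not part of the PR):

```python
import torch
from torch.fx.experimental.symbolic_shapes import (
    ShapeEnv,
    guard_or_false,
    guard_or_true,
)

shape_env = ShapeEnv()
u0 = shape_env.create_unbacked_symint()
u1 = shape_env.create_unbacked_symint()
torch._check_is_size(u0)
torch._check_is_size(u1)

def stride_lt(a, b):
    # same recipe as K.__lt__: a statically known "<", or divisibility of
    # b by a combined with "not statically equal"
    return guard_or_false(a < b) or (
        (guard_or_false(a == 0) or guard_or_false(b % a == 0))
        and guard_or_true(a != b)
    )

print(stride_lt(u0, u0))      # False: u0 < u0 can never hold
print(stride_lt(u0, u1))      # False: unrelated symbols, no axioms
print(stride_lt(u0, 2 * u0))  # True: divisible, and not provably equal
```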
lengths_and_strides = sorted(map(K, a.shape, a.stride()))

-expected_stride = 1
-for length, stride in lengths_and_strides:
-if guard_or_false(length == 1):
-continue
def is_non_overlapping_and_dense(a: Tensor) -> bool:
Contributor review comment: Is this for BC compatibility? In my mind, calling is_non_overlapping_and_dense should throw a DDE, while calling _is_non_overlapping_and_dense_or_false should not; now this is effectively an alias for _is_non_overlapping_and_dense_or_false, meaning unbacked semantics are implicit for users calling is_non_overlapping_and_dense. Maybe that's OK.

"""
True when a tensor is non-overlapping and dense.

-if guard_size_oblivious(stride != expected_stride):
-return False
A tensor is non-overlapping and dense when there exists a permutation of
its dimensions that is contiguous.
"""
from torch.fx.experimental.symbolic_shapes import guard_or_false

-expected_stride *= length
if a.is_sparse:
return False

-return True
return _is_non_overlapping_and_dense_or_false(a.shape, a.stride())

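Finally, a sketch of the user-visible effect, mirroring the last case in the new test: fully symbolic strides now conservatively report False from the public entry point instead of raising a data-dependent guard error:

```python
import torch
from torch._prims_common import is_non_overlapping_and_dense
from torch.fx.experimental.symbolic_shapes import ShapeEnv

shape_env = ShapeEnv()
u0 = shape_env.create_unbacked_symint()
u1 = shape_env.create_unbacked_symint()
torch._check_is_size(u0)
torch._check_is_size(u1)

# Strides involving an unrelated unbacked u1 may be dense at runtime,
# but are not provably dense for all values -> False, with no DDE raised.
t = torch.empty_strided((2, 3, u0), (u1, 3, 1), device="meta")
print(is_non_overlapping_and_dense(t))  # False
```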

# NOTE: Based on the implementation in TensorIterator.cpp, but note that