[dynamic shapes] prims_common non_overlapping_and_dense #160462
Changes from all commits:
@@ -254,14 +254,14 @@ def check_all_strides(
     return _check_strides_helper(a, b, only_cuda=only_cuda, significant_only=False)


-# This function is equivalent to compute_contiguous() from TensorImpl.cpp
-def is_contiguous(a: TensorLikeType, false_if_dde=False) -> bool:
+def check_contiguous_sizes_strides(sizes, strides, false_if_dde=False):
     """
-    Tests whether a tensor is contiguous or not.
-
-    Tensors are contiguous when they have no elements,
-    one element, or when they have "nested" strides.
+    Performs an equality check between actual stride & expected stride (based on composed sizes),
+    handling contiguous stride representations:
+    e.g. torch.empty(u0, u1, u2).contiguous().stride() -> (Max(1, u1) * Max(1, u2), Max(1, u2), 1)
+    and we'd like to treat this equal to (u1 * u2, u2, 1) for comparison purposes.
     """
     from torch.fx.experimental.symbolic_shapes import (
         guard_or_false,
         guard_or_true,
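For intuition on what "composed sizes" means here, a minimal plain-Python sketch (hypothetical helper, concrete integers only; the real check reasons about SymInt expressions through guard_or_false/guard_or_true instead of evaluating them):

```python
# Expected strides of a contiguous layout, composed right-to-left and clamping each
# size with max(1, size), mirroring the Max(1, u)-style strides that contiguous()
# produces for unbacked sizes.
def expected_contiguous_strides(sizes):
    strides = []
    acc = 1
    for s in reversed(sizes):
        strides.append(acc)
        acc *= max(1, s)
    return tuple(reversed(strides))

print(expected_contiguous_strides((2, 3, 4)))  # (12, 4, 1)
print(expected_contiguous_strides((2, 0, 4)))  # (4, 4, 1): the zero-sized dim is clamped to 1
```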
@@ -272,13 +272,10 @@ def is_contiguous(a: TensorLikeType, false_if_dde=False) -> bool:
     maybe_guard_or_false = guard_or_false if false_if_dde else guard_size_oblivious
     maybe_guard_or_true = guard_or_true if false_if_dde else guard_size_oblivious

-    if maybe_guard_or_false(a.numel() < 2):
-        return True
-
     expected_stride = 1
     expected_stride_max = 1
-    for x, y in reversed(tuple(zip(a.shape, a.stride()))):
+    for x, y in reversed(tuple(zip(sizes, strides))):
         # Skips checking strides when a dimension has length 1.
         if maybe_guard_or_false(x == 1):
             continue
@@ -299,6 +296,29 @@ def is_contiguous(a: TensorLikeType, false_if_dde=False) -> bool:
     return True


+# This function is equivalent to compute_contiguous() from TensorImpl.cpp
+def is_contiguous(a: TensorLikeType, false_if_dde=False) -> bool:
+    """
+    Tests whether a tensor is contiguous or not.
+
+    Tensors are contiguous when they have no elements,
+    one element, or when they have "nested" strides.
+    """
+    from torch.fx.experimental.symbolic_shapes import (
+        guard_or_false,
+        guard_size_oblivious,
+    )
+
+    maybe_guard_or_false = guard_or_false if false_if_dde else guard_size_oblivious
+
+    if maybe_guard_or_false(a.numel() < 2):
+        return True
+
+    return check_contiguous_sizes_strides(
+        a.shape, a.stride(), false_if_dde=false_if_dde
+    )
+
+
 # This function is equivalent to compute_channels_last_contiguous_2d() in TensorImpl.cpp
 def is_channels_last_contiguous_2d(a: Tensor, false_if_dde=False) -> bool:
     # NHWC or not channels last 2D contiguous
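A quick usage sketch of the re-layered is_contiguous on concrete (backed) tensors, where the guard helpers simply evaluate their arguments (this assumes the helpers remain importable from torch._prims_common):

```python
import torch
from torch._prims_common import is_contiguous  # assumed import path for this file's helpers

x = torch.empty(2, 3, 4)
assert is_contiguous(x)                      # strides (12, 4, 1) compose from the sizes
assert not is_contiguous(x.transpose(0, 2))  # strides (1, 4, 12) do not
assert is_contiguous(torch.empty(0, 5))      # fewer than two elements short-circuits to True
```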
@@ -438,32 +458,27 @@ def is_channels_last_contiguous_or_false(a: Tensor) -> bool:
     ) or is_channels_last_contiguous_or_false_3d(a)


-def is_non_overlapping_and_dense(a: Tensor) -> bool:
+def _is_non_overlapping_and_dense_or_false(sizes, strides) -> bool:
     """
-    True when a tensor is non-overlapping and dense.
+    Helper function for is_non_overlapping_and_dense.
+    For unbacked sizes & strides, returns True only if symbolically non-overlapping & dense,
+    and False otherwise.

-    A tensor is non-overlapping and dense when there exists a permutation of
-    its dimensions that is contiguous.
+    e.g. sizes: [u0, u1], strides: [u2, u3]
+    this may be non-overlapping & dense at runtime, for values {u0: 4, u1: 4, u2: 4, u3: 1},
+    but isn't true for all values.
     """
+    from torch.fx.experimental.symbolic_shapes import guard_or_false, guard_or_true
+    from torch.utils._sympy.functions import Max

-    from torch.fx.experimental.symbolic_shapes import (
-        guard_or_false,
-        guard_size_oblivious,
-    )
-
-    if a.is_sparse:
-        return False
-
-    # Short-circuits if the tensor is already contiguous or channels-last contiguous
-    if is_contiguous_or_false(a) or is_channels_last_contiguous_or_false(a):
+    # Short-circuits for 0/1-element tensors
+    if guard_or_false(prod(sizes) < 2):  # type: ignore[operator]
         return True

-    # The following is equivalent to compute_non_overlapping_and_dense in TensorImpl.cpp
-
     # Short-circuits for tensors of rank one, which are
     # non-overlapping and "dense" if their stride is one
-    if a.ndim == 1:
-        return a.stride()[0] == 1
+    if len(sizes) == 1:
+        return guard_or_false(strides[0] == 1)

     # Checks that there exists a permutation of the strides s.t. the tensor would be contiguous
     # Sorts (length, stride) pairs by stride
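To make the property concrete, here is a plain-integer sketch of the same 'some permutation of the dimensions is contiguous' test; the real helper swaps these direct comparisons for guard_or_false/guard_or_true so that anything unprovable for unbacked SymInts degrades to False:

```python
from math import prod

# Hypothetical concrete-value analogue of _is_non_overlapping_and_dense_or_false.
def non_overlapping_and_dense_concrete(sizes, strides):
    if prod(sizes) < 2:   # 0/1-element tensors are trivially dense
        return True
    if len(sizes) == 1:   # rank-1: dense iff the single stride is 1
        return strides[0] == 1
    expected = 1
    for size, stride in sorted(zip(sizes, strides), key=lambda p: p[1]):
        if size == 1:     # length-1 dims place no constraint
            continue
        if stride != expected:
            return False
        expected *= size
    return True

assert non_overlapping_and_dense_concrete((4, 4), (4, 1))      # row-major contiguous
assert non_overlapping_and_dense_concrete((4, 4), (1, 4))      # a permutation is contiguous
assert not non_overlapping_and_dense_concrete((4, 4), (8, 2))  # gaps between elements
```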
@@ -476,33 +491,44 @@ class K(NamedTuple):
         stride: int

         def __lt__(self, other):
-            return guard_size_oblivious(self.stride < other.stride)
-
-        def __gt__(self, other):
-            return guard_size_oblivious(self.stride > other.stride)
-
-        def __le__(self, other):
-            return guard_size_oblivious(self.stride <= other.stride)
-
-        def __ge__(self, other):
-            return guard_size_oblivious(self.stride >= other.stride)
-
-        def __eq__(self, other):
-            return guard_size_oblivious(self.stride == other.stride)
+            # for backed symbols, this is practically a < operation
+            # for unbacked, we return True if < is statically known,
+            # then try to answer this symbolically, with stride ordering semantics
+            # (e.g. u0 < u0 is False, u0 < u1 is False with no axioms, u0 < 2 * u0 is True)
+            return (
+                guard_or_false(
+                    self.stride < other.stride
+                )  # checks statically known inequality
+                or (
+                    (
+                        guard_or_false(self.stride == 0)
+                        or guard_or_false(other.stride % self.stride == 0)
+                    )
+                    and guard_or_true(self.stride != other.stride)
+                )  # checks symbolic inequality (e.g. u0 < 2048 * u0)
+            )

-    lengths_and_strides = sorted(map(K, a.shape, a.stride()))
+    lengths_and_strides = sorted(map(K, sizes, strides))

-    expected_stride = 1
-    for length, stride in lengths_and_strides:
-        if guard_or_false(length == 1):
-            continue
-
-        if guard_size_oblivious(stride != expected_stride):
-            return False
-
-        expected_stride *= length
-
-    return True
+    # verify actual strides match the expected (composed sizes)
+    sizes = [x.size for x in lengths_and_strides][::-1]
+    strides = [x.stride for x in lengths_and_strides][::-1]
+    return check_contiguous_sizes_strides(sizes, strides, false_if_dde=True)
+
+
+def is_non_overlapping_and_dense(a: Tensor) -> bool:
+    """
+    True when a tensor is non-overlapping and dense.
+
+    A tensor is non-overlapping and dense when there exists a permutation of
+    its dimensions that is contiguous.
+    """
+    from torch.fx.experimental.symbolic_shapes import guard_or_false
+
+    if a.is_sparse:
+        return False
+
+    return _is_non_overlapping_and_dense_or_false(a.shape, a.stride())


 # NOTE: Based on the implementation in TensorIterator.cpp, but note that

Contributor (on `K.__lt__`, marked as resolved): please add a comment explaining why it's OK if, in the worst case for unbacked symbols, we picked a wrong order.

Contributor (on the re-added `is_non_overlapping_and_dense` wrapper): is this for backward compatibility? Maybe it's OK.
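On that resolved comment: the chosen ordering is ultimately validated by check_contiguous_sizes_strides, so (under that reading of the code) a tie-break that guesses wrong for unbacked strides can only turn a would-be True into a conservative False, never certify an overlapping layout. A small concrete sketch of the argument, using hypothetical stand-in values in plain Python:

```python
def validates(pairs):
    # pairs ordered innermost-first: each stride must equal the product of the sizes seen so far
    expected = 1
    for size, stride in pairs:
        if size == 1:
            continue
        if stride != expected:
            return False
        expected *= size
    return True

# Concrete stand-ins for unbacked sizes/strides, e.g. u0, u1 with strides u2, u3.
right = [(4, 1), (4, 4)]      # ordering by ascending stride
wrong = [(4, 4), (4, 1)]      # the order a failed symbolic comparison might pick
assert validates(right)       # the dense layout is certified
assert not validates(wrong)   # a wrong order only degrades to a conservative False
```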
Review comment: `is_non_overlapping_and_dense` -> `is_non_overlapping_and_dense_or_false`?
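For reference, a usage sketch of the public wrapper on concrete tensors (assuming it stays exported from torch._prims_common under its current name):

```python
import torch
from torch._prims_common import is_non_overlapping_and_dense

t = torch.empty(4, 4)
assert is_non_overlapping_and_dense(t)              # contiguous
assert is_non_overlapping_and_dense(t.t())          # a permutation of its dims is contiguous
assert not is_non_overlapping_and_dense(t[:, ::2])  # slicing leaves gaps between elements
```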