[dynamic shapes] guard_or_false for _reshape_view_helper, utils._infer_size for wildcard dims #150127
Closed
Changes from all commits · 35 commits
7510a77  init (pianpwk)
4630374  assume >= 0 (pianpwk)
e076ca4  test (pianpwk)
ed8ed35  assume numel = prod(shape) (pianpwk)
fafaa47  test (pianpwk)
fe734e8  workaround (pianpwk)
11a25a7  simple case (pianpwk)
8a48fad  weird test (pianpwk)
ac40d4c  lint (pianpwk)
207d3fb  switch to guard_or_true (pianpwk)
1790f35  Update __init__.py (pianpwk)
1dd4c2b  Update __init__.py (pianpwk)
9d35a27  lint (pianpwk)
cab5f4a  Merge branch 'main' of https://github.com/pytorch/pytorch into pianpw… (pianpwk)
f5b10c4  stash (pianpwk)
2eb56d0  Merge branch 'main' of https://github.com/pytorch/pytorch into pianpw… (pianpwk)
9604307  reduce to only changing fast path (pianpwk)
00031a1  Update fx.experimental.rst (pianpwk)
983174e  Update fx.experimental.rst (pianpwk)
7c8d113  lint (pianpwk)
20749a3  Merge branch 'main' of https://github.com/pytorch/pytorch into pianpw… (pianpwk)
5885633  add loop back (pianpwk)
18496a9  Update __init__.py (pianpwk)
af74f34  Update test_export.py (pianpwk)
69c1abe  try (pianpwk)
2ac1303  Merge branch 'main' of https://github.com/pytorch/pytorch into pianpw… (pianpwk)
ffe4273  comments (pianpwk)
fa11fa1  Update __init__.py (pianpwk)
85fd59e  Update __init__.py (pianpwk)
6d322e6  Update __init__.py (pianpwk)
07c6ac9  Merge branch 'main' of https://github.com/pytorch/pytorch into pianpw… (pianpwk)
5fa8ce9  test masked linear (pianpwk)
c85b5ce  lint (pianpwk)
b0db707  Merge branch 'main' of https://github.com/pytorch/pytorch into pianpw… (pianpwk)
2bca726  Merge branch 'main' of https://github.com/pytorch/pytorch into pianpw… (pianpwk)
test_export.py
@@ -4301,7 +4301,7 @@ class M_v0(torch.nn.Module):
            def forward(self, t):
                items = [t[i].item() for i in range(t.numel())]
                r = torch.randn([items[0], items[1]])
-               # Could not guard on data-dependent expression Eq(u2, -1)
+               # Could not guard on data-dependent expression Ne(Mod(u1, u2), 0)
                return r.view(items[0], items[2])

        M = M_v0
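For reference, the fix the new error message suggests is one of the two divisibility checks, inserted before the failing call. A sketch (hypothetical class name; either branch of the check can be committed to):

    class M_fixed(torch.nn.Module):  # illustrative, not in the PR
        def forward(self, t):
            items = [t[i].item() for i in range(t.numel())]
            r = torch.randn([items[0], items[1]])
            # commit to one branch of the undecidable guard Ne(Mod(u1, u2), 0)
            torch._check((items[1] % items[2]) == 0)
            return r.view(items[0], items[2])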
@@ -4310,69 +4310,23 @@ def forward(self, t):
            "The following call raised this error(.*\n)+"
            f".*{re.escape('return r.view(items[0], items[2])')}(.*\n)+"
            "To fix the error, insert one of the following checks before this call.*:\n"
-           f".*{re.escape('torch._check(items[2] == (-1))')}.*\n"
-           f".*{re.escape('torch._check(items[2] != (-1))')}(.*\n)+"
-           f".*{re.escape('(These suggested fixes were derived by replacing `u2` with items[2] in Eq(u2, -1) and its negation.)')}",
+           f".*{re.escape('torch._check((items[1] % items[2]) != 0)')}.*\n"
+           f".*{re.escape('torch._check((items[1] % items[2]) == 0)')}(.*\n)+"
+           f".*{re.escape('(These suggested fixes were derived by replacing `u1` with items[1]')}"
+           f".*{re.escape('or r.shape[1], `u2` with items[2] in Ne(Mod(u1, u2), 0) and its negation.')}",
        ):
            export(N(), (t,), strict=strict)

-       class M_v1(torch.nn.Module):
-           def forward(self, t):
-               items = [t[i].item() for i in range(t.numel())]
-               r = torch.randn([items[0], items[1]])
-               # Could not guard on data-dependent expression Eq(u2, -1)
-               torch._check(items[2] != -1)
-               # Could not guard on data-dependent expression u2 >= 0
-               # TODO(pianpwk): this isn't the suggested fixes.
-               # fix issue with % being interpreted as PythonMod instead of Mod
-               torch._check(items[1] == items[2])
-               return r.view(items[0], items[2])
-
-       M = M_v1
-       with self.assertRaisesRegex(
-           error_type,
-           "The following call raised this error(.*\n)+"
-           f".*{re.escape('return r.view(items[0], items[2])')}(.*\n)+"
-           "To fix the error, insert one of the following checks before this call.*:\n"
-           f".*{re.escape('You can add either: torch._check_is_size(u2) or torch._check(u2>=0) Note: torch._check_is_size(u2) could prevent data dependent errors that happen in a guard_size_oblivious(..) context by opting into guard_size_oblivious reasoning. See documentation on guard_size_oblivious for more details: https://pytorch.org/docs/stable/generated/torch.fx.experimental.symbolic_shapes.guard_size_oblivious.html')}.*\n"
-           f".*{re.escape('torch._check(items[2] < 0)')}(.*\n)+"
-           f".*{re.escape('(These suggested fixes were derived by replacing `u2` with items[2] in u2 >= 0 and its negation.)')}",
-       ):
-           export(N(), (t,), strict=strict)
-
-       class M_v2(torch.nn.Module):
-           def forward(self, t):
-               items = [t[i].item() for i in range(t.numel())]
-               r = torch.randn([items[0], items[1]])
-               # Could not guard on data-dependent expression Eq(u2, -1)
-               torch._check(items[2] != -1)
-               # Could not guard on data-dependent expression u2 >= 0
-               torch._check(items[2] >= 0)
-               # Could not guard on data-dependent expression Eq(u1, u2)
-               return r.view(items[0], items[2])
-
-       M = M_v2
-       with self.assertRaisesRegex(
-           error_type,
-           "The following call raised this error(.*\n)+"
-           f".*{re.escape('return r.view(items[0], items[2])')}(.*\n)+"
-           "To fix the error, insert one of the following checks before this call.*:\n"
-           f".*{re.escape('torch._check(items[2] == items[1])')}.*\n"
-           f".*{re.escape('torch._check(items[2] != items[1])')}(.*\n)+"
-           f".*{re.escape('(These suggested fixes were derived by replacing `u1` with items[1] or r.shape[1], `u2` with items[2] in Eq(u2, u1) and its negation.)')}",
-       ):
-           export(N(), (t,), strict=strict)
-
        class M_v3(torch.nn.Module):
            def forward(self, t):
                items = [t[i].item() for i in range(t.numel())]
                r = torch.randn([items[0], items[1]])
                # Could not guard on data-dependent expression Eq(u2, -1)
                torch._check(items[2] != -1)
                # Could not guard on data-dependent expression u2 >= 0
                torch._check(items[2] >= 0)
                # Could not guard on data-dependent expression Eq(u1, u2)
                torch._check(items[2] == r.shape[1])
                return r.view(items[0], items[2])

        M = M_v3
        export(N(), (t,), strict=strict)

    def test_suggested_fixes_for_data_dependent_errors_puzzlers(self):
@@ -4484,6 +4438,29 @@ def forward(self, x, offsets_t, fixes):
            fixes=[],  # nothing to fix!
        )

+   def test_simple_unbacked_view(self):
+       class Foo(torch.nn.Module):
+           def forward(self, x):
+               u0 = x.item()
+               y = torch.empty(5, u0)
+               return y.view(u0, 5)  # [5, u0] -> [u0, 5]
+
+       ep = export(Foo(), (torch.tensor([9]),))
+       self.assertEqual(ep.module()(torch.tensor([8])).size(0), 8)
+       self.assertEqual(ep.module()(torch.tensor([5])).size(0), 5)
+
+       class Foov2(torch.nn.Module):
+           def forward(self, xs):
+               xsl = xs.tolist()
+               a, b = xsl
+               x = torch.zeros(a)
+               return x.reshape(b)
+
+       xs = torch.tensor([4, 4])
+       ep = export(Foov2(), (xs,))
+       self.assertEqual(ep.module()(xs).size(0), 4)
+       self.assertEqual(ep.module()(torch.tensor([5, 5])).size(0), 5)
+
    def test_no_suggested_fixes_for_data_dependent_errors(self):
        # suggested fixes for data-dependent errors only work in non-strict mode
        strict = False
@@ -7549,22 +7526,19 @@ def forward(self, xs, y):
            len([node for node in gm.graph.nodes if node.op == "placeholder"]), 2
        )

-   def test_check_is_size_error(self):
+   def test_no_check_is_size_error(self):
        class Module(torch.nn.Module):
            def forward(self, x):
                a = x.item()
                # We cannot automatically infer a is a size here because view
                # accepts -1
                return torch.randn(24).view(a, 4)

        f = Module()
-       if is_non_strict_test(self._testMethodName):
-           error = torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode
-       else:
-           error = torch._dynamo.exc.UserError
-       error_msg = r"Could not guard on data-dependent expression"
-       with self.assertRaisesRegex(error, error_msg):
-           _ = export(f, (torch.tensor(6),))
+       ep = export(f, (torch.tensor(6),))
+       ep.module()(torch.tensor(6))
+       with self.assertRaisesRegex(
+           RuntimeError, r"Runtime assertion failed for .* u.* 6"
+       ):
+           ep.module()(torch.tensor(5))

Review comment: exports fine now without check_is_size, we just assume it's >= 0 and it later specializes to 6.
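The removed M_v1 expectation quoted the old guidance verbatim: "You can add either: torch._check_is_size(u2) or torch._check(u2>=0)". As a sketch, that pre-change workaround looked like this (whether it fully unblocked a given view depended on which guard fired next):

    class Module(torch.nn.Module):
        def forward(self, x):
            a = x.item()
            # old hint: assert a is a size (a >= 0), opting into
            # size-oblivious reasoning for this unbacked symbol
            torch._check_is_size(a)
            return torch.randn(24).view(a, 4)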
    def test_is_non_negative_check_function(self):
        import sympy as sp

@@ -13446,7 +13420,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
            node.target == torch.ops.aten._assert_scalar.default
            for node in ep.graph.nodes
        ].count(True)
-       self.assertEqual(num_asserts, 1)
+       self.assertEqual(num_asserts, 2)
        with self.assertRaises(RuntimeError):
            ep.module()(torch.randn(4, 2))
torch/_refs/__init__.py
@@ -3717,7 +3717,8 @@ def repeat(a: Tensor, *repeat_shape) -> Tensor:

def _reshape_view_helper(a: TensorLikeType, *shape, allow_copy: bool) -> TensorLikeType:
-   from torch.fx.experimental.symbolic_shapes import guard_size_oblivious, sym_eq
+   from torch._dynamo.exc import UserError, UserErrorType
+   from torch.fx.experimental.symbolic_shapes import guard_or_false, guard_or_true

    # Creates a valid shape
    shape = utils.extract_shape_from_varargs(shape, validate=False)
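The swapped-in helpers change failure behavior from raising to picking a branch. Paraphrasing their documented semantics in torch.fx.experimental.symbolic_shapes:

    from torch.fx.experimental.symbolic_shapes import guard_or_false, guard_or_true

    # guard_or_false(expr): return bool(expr) when it is statically decidable;
    #   when expr is data-dependent and undecidable, return False instead of
    #   raising GuardOnDataDependentSymNode.
    # guard_or_true(expr): same, but fall back to True.
    # guard_size_oblivious(expr), by contrast, raises on undecidable
    #   expressions, which is what produced the "Could not guard on
    #   data-dependent expression" errors the old tests expected.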
@@ -3726,7 +3727,7 @@ def _reshape_view_helper(a: TensorLikeType, *shape, allow_copy: bool) -> TensorL
        shape = utils.infer_size(shape, a.numel())

    # Special-cases tensors with no elements
-   if guard_size_oblivious(a.numel() == 0):
+   if guard_or_false(a.numel() == 0):
        return as_strided(a, shape, utils.make_contiguous_strides_for(shape))

    # Special-cases reshaping zero dim tensors
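The title's utils._infer_size change concerns the wildcard path above, where a -1 dim is inferred from numel. A minimal sketch of the contract (hypothetical helper name, not the real implementation), assuming the PR's behavior of not guarding on non-negativity or divisibility for unbacked dims:

    def infer_size_sketch(shape, numel):
        known = 1
        wild = None
        for i, d in enumerate(shape):
            if d == -1:
                wild = i  # at most one wildcard is permitted
            else:
                known *= d
        if wild is None:
            return list(shape)
        out = list(shape)
        # assumed behavior for unbacked dims: take numel // known without
        # guarding on numel % known == 0 or on the result being >= 0
        out[wild] = numel // known
        return out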
@@ -3762,6 +3763,12 @@ def _reshape_view_helper(a: TensorLikeType, *shape, allow_copy: bool) -> TensorL
        return torch.as_strided(a, [dim0, dim1], [dim1, 1])

    # Handles general case: a 1+D tensor reshaped into a distinct 1+D shape
+   shape_numel = reduce(operator.mul, shape, 1)
+   torch._check(
+       a.numel() == shape_numel,
+       f"Could not reshape a tensor with shape {a.shape} as a tensor with shape {shape}!",
+   )
+   deferred: list[Callable[[], bool]] = []

    # NOTE [Reshape Algorithm]
    # This algorithm works by attempting to greedily construct the desired dimensions in
@@ -3794,16 +3801,30 @@ def _reshape_view_helper(a: TensorLikeType, *shape, allow_copy: bool) -> TensorL
            continue

        # Skips dimensions that are already the correct length
-       if guard_size_oblivious(length == a_.shape[idx]):
+       if guard_or_false(length == a_.shape[idx]):
            idx = idx + 1
            continue

        # Gathers enough original dimensions such that this new dimension can be created
        # Note that this accumulation will terminate because we've verified a and the shape
        # specify the same number of elements above
+       def maybe_throw_dde():
+           # NOTE: if you've hit a data-dependent error here, it's because in trying to accumulate input
+           # tensor dimensions to match the target shape (length), we've hit data-dependent errors testing
+           # divisibility (accum % length != 0), and have deferred raising them, in the hope that we'd
+           # figure out a valid reshape later in the loop.
+           # But we failed, either by running out of dimensions, or we couldn't figure out the strides,
+           # and we've decided to re-raise to either graph break out, or provide the exact guard so the user
+           # can torch._check() to avoid this.
+           for f in deferred:
+               f()
+
        accum = a_.shape[idx]
        end = idx
-       while guard_size_oblivious(accum % length != 0):
+       while guard_or_true(accum % length != 0):
+           deferred.append(lambda: bool(accum % length != 0))
            if end == a_.ndim - 1:
+               maybe_throw_dde()
            end = end + 1
            accum = accum * a_.shape[end]
        if end != idx:

Review comment on maybe_throw_dde: nit: maybe_throw_bypassed_dde
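The deferred-raise pattern in isolation, as a self-contained sketch (illustrative only, not the PyTorch code): undecidable divisibility checks are recorded instead of raised, and only forced once the reshape has genuinely failed, so the user sees the exact guard to torch._check():

    from typing import Callable

    def greedy_accumulate(dims: list[int], length: int) -> int:
        deferred: list[Callable[[], bool]] = []

        def maybe_throw_dde() -> None:
            # with symbolic ints, re-evaluating a deferred guard re-raises
            # the original data-dependent error; with plain ints it is a no-op
            for f in deferred:
                f()

        accum, end = dims[0], 0
        while accum % length != 0:  # guard_or_true(...) in the real code
            deferred.append(lambda: bool(accum % length != 0))
            if end == len(dims) - 1:
                maybe_throw_dde()  # out of dims: surface the deferred guards
                raise ValueError("cannot accumulate dims to match length")
            end += 1
            accum *= dims[end]
        return end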
@@ -3817,13 +3838,15 @@ def _reshape_view_helper(a: TensorLikeType, *shape, allow_copy: bool) -> TensorL
            if allow_copy:
                return prims.reshape(a, shape)

+           maybe_throw_dde()
            msg = f"Cannot view a tensor with shape {a.shape} and strides {a.stride()} as a tensor with shape {shape}!"
            raise ValueError(msg)

        a_ = flatten(a_, idx, end)

-       # Splits the (possibly flattened) dimension to create the desired dim length
-       if guard_size_oblivious(accum != length):
+       # Splits the (possibly flattened) dimension to create the desired dim length.
+       # guard_or_true is safe due to the tail unsqueeze routine.
+       if guard_or_true(accum != length):
            a_ = prims.split_dim(a_, idx, length)

        idx = idx + 1
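End-to-end effect of the change, mirroring the new test_simple_unbacked_view (values taken from the test):

    import torch
    from torch.export import export

    class Foo(torch.nn.Module):
        def forward(self, x):
            u0 = x.item()
            y = torch.empty(5, u0)
            return y.view(u0, 5)  # [5, u0] -> [u0, 5], no torch._check needed

    ep = export(Foo(), (torch.tensor([9]),))
    print(ep.module()(torch.tensor([8])).shape)  # torch.Size([8, 5])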