[dynamic shapes] guard_or_false for _reshape_view_helper, utils._infer_size for wildcard dims · pytorch/pytorch@1dd2033 · GitHub

Commit 1dd2033

pianpwk authored and pytorchmergebot committed
[dynamic shapes] guard_or_false for _reshape_view_helper, utils._infer_size for wildcard dims (#150127)
For reshape/view: switches the fast paths (0-element tensors, skipping dimensions that already have the correct length) to guard_or_false. Modifies the loop that accumulates input elements so that running out of dimensions raises a UserError: a graph break for compile, and an error for export.

For infer_size: assumes that if the user passes an unbacked symbol, it is probably not -1.

Will think about the changes in https://docs.google.com/document/d/1WYx6EZwVDXtBnWyrzoecgGWdiK0V3XZKftfpWwQ5i3E/edit?tab=t.0#heading=h.22k54zym11qp in a later PR.

Pull Request resolved: #150127
Approved by: https://github.com/laithsakka
1 parent c8240e3 commit 1dd2033
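
The user-visible effect, in a minimal sketch adapted from the new test_simple_unbacked_view test below (module and input values are illustrative):

import torch
from torch.export import export

class Foo(torch.nn.Module):
    def forward(self, x):
        u0 = x.item()            # u0 is an unbacked SymInt during tracing
        y = torch.empty(5, u0)
        return y.view(u0, 5)     # [5, u0] -> [u0, 5]

# With guard_or_false/guard_or_true, the divisibility checks in
# _reshape_view_helper are deferred instead of raising a data-dependent
# guard error at trace time, so the export goes through.
ep = export(Foo(), (torch.tensor([9]),))
print(ep.module()(torch.tensor([8])).shape)  # torch.Size([8, 5])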

File tree

3 files changed: +79 additions, -77 deletions

test/export/test_export.py

Lines changed: 39 additions & 65 deletions
@@ -4317,7 +4317,7 @@ class M_v0(torch.nn.Module):
             def forward(self, t):
                 items = [t[i].item() for i in range(t.numel())]
                 r = torch.randn([items[0], items[1]])
-                # Could not guard on data-dependent expression Eq(u2, -1)
+                # Could not guard on data-dependent expression Ne(Mod(u1, u2), 0)
                 return r.view(items[0], items[2])

         M = M_v0
@@ -4326,69 +4326,23 @@ def forward(self, t):
             "The following call raised this error(.*\n)+"
             f".*{re.escape('return r.view(items[0], items[2])')}(.*\n)+"
             "To fix the error, insert one of the following checks before this call.*:\n"
-            f".*{re.escape('torch._check(items[2] == (-1))')}.*\n"
-            f".*{re.escape('torch._check(items[2] != (-1))')}(.*\n)+"
-            f".*{re.escape('(These suggested fixes were derived by replacing `u2` with items[2] in Eq(u2, -1) and its negation.)')}",
+            f".*{re.escape('torch._check((items[1] % items[2]) != 0)')}.*\n"
+            f".*{re.escape('torch._check((items[1] % items[2]) == 0)')}(.*\n)+"
+            f".*{re.escape('(These suggested fixes were derived by replacing `u1` with items[1]')}"
+            f".*{re.escape('or r.shape[1], `u2` with items[2] in Ne(Mod(u1, u2), 0) and its negation.')}",
         ):
             export(N(), (t,), strict=strict)

         class M_v1(torch.nn.Module):
             def forward(self, t):
                 items = [t[i].item() for i in range(t.numel())]
                 r = torch.randn([items[0], items[1]])
-                # Could not guard on data-dependent expression Eq(u2, -1)
-                torch._check(items[2] != -1)
-                # Could not guard on data-dependent expression u2 >= 0
+                # TODO(pianpwk): this isn't the suggested fixes.
+                # fix issue with % being interpreted as PythonMod instead of Mod
+                torch._check(items[1] == items[2])
                 return r.view(items[0], items[2])

         M = M_v1
-        with self.assertRaisesRegex(
-            error_type,
-            "The following call raised this error(.*\n)+"
-            f".*{re.escape('return r.view(items[0], items[2])')}(.*\n)+"
-            "To fix the error, insert one of the following checks before this call.*:\n"
-            f".*{re.escape('You can add either: torch._check_is_size(u2) or torch._check(u2>=0) Note: torch._check_is_size(u2) could prevent data dependent errors that happen in a guard_size_oblivious(..) context by opting into guard_size_oblivious reasoning. See documentation on guard_size_oblivious for more details: https://pytorch.org/docs/stable/generated/torch.fx.experimental.symbolic_shapes.guard_size_oblivious.html')}.*\n"
-            f".*{re.escape('torch._check(items[2] < 0)')}(.*\n)+"
-            f".*{re.escape('(These suggested fixes were derived by replacing `u2` with items[2] in u2 >= 0 and its negation.)')}",
-        ):
-            export(N(), (t,), strict=strict)
-
-        class M_v2(torch.nn.Module):
-            def forward(self, t):
-                items = [t[i].item() for i in range(t.numel())]
-                r = torch.randn([items[0], items[1]])
-                # Could not guard on data-dependent expression Eq(u2, -1)
-                torch._check(items[2] != -1)
-                # Could not guard on data-dependent expression u2 >= 0
-                torch._check(items[2] >= 0)
-                # Could not guard on data-dependent expression Eq(u1, u2)
-                return r.view(items[0], items[2])
-
-        M = M_v2
-        with self.assertRaisesRegex(
-            error_type,
-            "The following call raised this error(.*\n)+"
-            f".*{re.escape('return r.view(items[0], items[2])')}(.*\n)+"
-            "To fix the error, insert one of the following checks before this call.*:\n"
-            f".*{re.escape('torch._check(items[2] == items[1])')}.*\n"
-            f".*{re.escape('torch._check(items[2] != items[1])')}(.*\n)+"
-            f".*{re.escape('(These suggested fixes were derived by replacing `u1` with items[1] or r.shape[1], `u2` with items[2] in Eq(u2, u1) and its negation.)')}",
-        ):
-            export(N(), (t,), strict=strict)
-
-        class M_v3(torch.nn.Module):
-            def forward(self, t):
-                items = [t[i].item() for i in range(t.numel())]
-                r = torch.randn([items[0], items[1]])
-                # Could not guard on data-dependent expression Eq(u2, -1)
-                torch._check(items[2] != -1)
-                # Could not guard on data-dependent expression u2 >= 0
-                torch._check(items[2] >= 0)
-                # Could not guard on data-dependent expression Eq(u1, u2)
-                torch._check(items[2] == r.shape[1])
-                return r.view(items[0], items[2])
-
-        M = M_v3
         export(N(), (t,), strict=strict)

     def test_suggested_fixes_for_data_dependent_errors_puzzlers(self):
@@ -4500,6 +4454,29 @@ def forward(self, x, offsets_t, fixes):
             fixes=[],  # nothing to fix!
         )

+    def test_simple_unbacked_view(self):
+        class Foo(torch.nn.Module):
+            def forward(self, x):
+                u0 = x.item()
+                y = torch.empty(5, u0)
+                return y.view(u0, 5)  # [5, u0] -> [u0, 5]
+
+        ep = export(Foo(), (torch.tensor([9]),))
+        self.assertEqual(ep.module()(torch.tensor([8])).size(0), 8)
+        self.assertEqual(ep.module()(torch.tensor([5])).size(0), 5)
+
+        class Foov2(torch.nn.Module):
+            def forward(self, xs):
+                xsl = xs.tolist()
+                a, b = xsl
+                x = torch.zeros(a)
+                return x.reshape(b)
+
+        xs = torch.tensor([4, 4])
+        ep = export(Foov2(), (xs,))
+        self.assertEqual(ep.module()(xs).size(0), 4)
+        self.assertEqual(ep.module()(torch.tensor([5, 5])).size(0), 5)
+
     def test_no_suggested_fixes_for_data_dependent_errors(self):
         # suggested fixes for data-dependent errors only work in non-strict mode
         strict = False
@@ -7422,22 +7399,19 @@ def forward(self, xs, y):
             len([node for node in gm.graph.nodes if node.op == "placeholder"]), 2
         )

-    def test_check_is_size_error(self):
+    def test_no_check_is_size_error(self):
         class Module(torch.nn.Module):
             def forward(self, x):
                 a = x.item()
-                # We cannot automatically infer a is a size here because view
-                # accepts -1
                 return torch.randn(24).view(a, 4)

         f = Module()
-        if is_non_strict_test(self._testMethodName):
-            error = torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode
-        else:
-            error = torch._dynamo.exc.UserError
-        error_msg = r"Could not guard on data-dependent expression"
-        with self.assertRaisesRegex(error, error_msg):
-            _ = export(f, (torch.tensor(6),))
+        ep = export(f, (torch.tensor(6),))
+        ep.module()(torch.tensor(6))
+        with self.assertRaisesRegex(
+            RuntimeError, r"Runtime assertion failed for .* u.* 6"
+        ):
+            ep.module()(torch.tensor(5))

     def test_is_non_negative_check_function(self):
         import sympy as sp
@@ -13281,7 +13255,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
            node.target == torch.ops.aten._assert_scalar.default
            for node in ep.graph.nodes
        ].count(True)
-        self.assertEqual(num_asserts, 1)
+        self.assertEqual(num_asserts, 2)
        with self.assertRaises(RuntimeError):
            ep.module()(torch.randn(4, 2))

torch/_prims_common/__init__.py

Lines changed: 11 additions & 6 deletions
@@ -924,24 +924,29 @@ def infer_size(shape: ShapeType, numel: int) -> tuple[int, ...]:
     Infers the size of a dim with size -1, if it exists.
     Also checks that new shape is compatible with the number of elements.
     """
+    from torch.fx.experimental.symbolic_shapes import definitely_true, guard_or_false
+
     dim = None
     newsize = 1
     for i, d in enumerate(shape):
-        if d == -1:
+        if guard_or_false(d == -1):
             torch._check(dim is None, lambda: "only one dimension can be inferred")
             dim = i
-        elif d >= 0:
-            newsize *= d
         else:
-            torch._check(False, lambda: f"invalid shape dimension {d}")
+            torch._check(
+                d >= 0,
+                lambda: (
+                    f"invalid shape dimension {d}. If this was symbolic, it was assumed to not be -1."
+                    "If this was meant to be inferred, please explicitly pass in -1."
+                ),
+            )
+            newsize *= d
     if dim is None:
         torch._check(
             numel == newsize,
             lambda: f"shape '{list(shape)}' is invalid for input of size {numel}",
         )
     else:
-        from torch.fx.experimental.symbolic_shapes import definitely_true
-
         torch._check(
             newsize != 0,
             lambda: (
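
For reference, a small sketch of what utils.infer_size computes; infer_size is an internal helper in torch._prims_common, and the concrete numbers here are illustrative:

import torch
from torch._prims_common import infer_size

# Exactly one -1 ("wildcard") dimension is inferred from the element count.
print(infer_size((2, -1, 3), 24))  # (2, 4, 3)

# After this change, a symbolic (unbacked) dimension is assumed not to be -1
# and is treated as a concrete size; to have a dimension inferred, pass a
# literal -1 explicitly.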

torch/_refs/__init__.py

Lines changed: 29 additions & 6 deletions
@@ -3717,7 +3717,8 @@ def repeat(a: Tensor, *repeat_shape) -> Tensor:


 def _reshape_view_helper(a: TensorLikeType, *shape, allow_copy: bool) -> TensorLikeType:
-    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious, sym_eq
+    from torch._dynamo.exc import UserError, UserErrorType
+    from torch.fx.experimental.symbolic_shapes import guard_or_false, guard_or_true

     # Creates a valid shape
     shape = utils.extract_shape_from_varargs(shape, validate=False)
@@ -3726,7 +3727,7 @@ def _reshape_view_helper(a: TensorLikeType, *shape, allow_copy: bool) -> TensorLikeType:
         shape = utils.infer_size(shape, a.numel())

     # Special-cases tensors with no elements
-    if guard_size_oblivious(a.numel() == 0):
+    if guard_or_false(a.numel() == 0):
         return as_strided(a, shape, utils.make_contiguous_strides_for(shape))

     # Special-cases reshaping zero dim tensors
@@ -3762,6 +3763,12 @@ def _reshape_view_helper(a: TensorLikeType, *shape, allow_copy: bool) -> TensorLikeType:
         return torch.as_strided(a, [dim0, dim1], [dim1, 1])

     # Handles general case: a 1+D tensor reshaped into a distinct 1+D shape
+    shape_numel = reduce(operator.mul, shape, 1)
+    torch._check(
+        a.numel() == shape_numel,
+        f"Could not reshape a tensor with shape {a.shape} as a tensor with shape {shape}!",
+    )
+    deferred: list[Callable[[], bool]] = []

     # NOTE [Reshape Algorithm]
     # This algorithm works by attempting to greedily construct the desired dimensions in
@@ -3794,16 +3801,30 @@ def _reshape_view_helper(a: TensorLikeType, *shape, allow_copy: bool) -> TensorLikeType:
             continue

         # Skips dimensions that are already the correct length
-        if guard_size_oblivious(length == a_.shape[idx]):
+        if guard_or_false(length == a_.shape[idx]):
             idx = idx + 1
             continue

         # Gathers enough original dimensions such that this new dimension can be created
         # Note that this accumulation will terminate because we've verified a and the shape
         # specify the same number of elements above
+        def maybe_throw_dde():
+            # NOTE: if you've hit a data-dependent error here, it's because in trying to accumulate input
+            # tensor dimensions to match the target shape (length), we've hit data-dependent errors testing
+            # divisibility (accum % length != 0), and have deferred raising them, in the hope that we'd
+            # figure out a valid reshape later in the loop.
+            # But we failed, either by running out of dimensions, or we couldn't figure out the strides,
+            # and we've decided to re-raise to either graph break out, or provide the exact guard so the user
+            # can torch._check() to avoid this.
+            for f in deferred:
+                f()
+
         accum = a_.shape[idx]
         end = idx
-        while guard_size_oblivious(accum % length != 0):
+        while guard_or_true(accum % length != 0):
+            deferred.append(lambda: bool(accum % length != 0))
+            if end == a_.ndim - 1:
+                maybe_throw_dde()
             end = end + 1
             accum = accum * a_.shape[end]
         if end != idx:
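
For intuition, a concrete trace of the greedy accumulate/flatten/split scheme described in NOTE [Reshape Algorithm] above, using plain eager PyTorch with static sizes (the values are illustrative):

import torch

# view([4, 6] -> [8, 3]):
#   target length 8: accum = 4, and 4 % 8 != 0, so gather the next input dim:
#   accum = 4 * 6 = 24 and 24 % 8 == 0; flatten dims 0..1 and split off 8,
#   giving [8, 3]. The next target length, 3, already matches, so we stop.
# With unbacked lengths, each "accum % length != 0" test is now deferred via
# guard_or_true instead of raising a data-dependent guard error.
a = torch.arange(24).reshape(4, 6)
print(a.view(8, 3).shape)  # torch.Size([8, 3])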
@@ -3817,13 +3838,15 @@ def _reshape_view_helper(a: TensorLikeType, *shape, allow_copy: bool) -> TensorLikeType:
                 if allow_copy:
                     return prims.reshape(a, shape)

+                maybe_throw_dde()
                 msg = f"Cannot view a tensor with shape {a.shape} and strides {a.stride()} as a tensor with shape {shape}!"
                 raise ValueError(msg)

             a_ = flatten(a_, idx, end)

-        # Splits the (possibly flattened) dimension to create the desired dim length
-        if guard_size_oblivious(accum != length):
+        # Splits the (possibly flattened) dimension to create the desired dim length.
+        # guard_or_true is safe due to the tail unsqueeze routine.
+        if guard_or_true(accum != length):
             a_ = prims.split_dim(a_, idx, length)

         idx = idx + 1
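
A sketch of how the deferred divisibility guards surface at run time, mirroring the updated test_no_check_is_size_error above (the exact assertion message may differ):

import torch
from torch.export import export

class Module(torch.nn.Module):
    def forward(self, x):
        a = x.item()                        # unbacked SymInt
        return torch.randn(24).view(a, 4)   # 24 must factor as a * 4

f = Module()
ep = export(f, (torch.tensor(6),))
ep.module()(torch.tensor(6))    # ok: the runtime assertions hold
# ep.module()(torch.tensor(5))  # fails a runtime assertion instead of a
#                               # trace-time data-dependent guard error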

0 commit comments
