Revert "[dynamic shapes] guard_or_false for _reshape_view_helper, uti… · pytorch/pytorch@97d97ae · GitHub

Commit 97d97ae

Revert "[dynamic shapes] guard_or_false for _reshape_view_helper, utils._infer_size for wildcard dims (#150127)"
This reverts commit 1dd2033. Reverted #150127 on behalf of https://github.com/clee2000 due to maybe caused export test to fail? export/test_draft_export.py::TestDraftExport::test_masked_linear [GH job link](https://github.com/pytorch/pytorch/actions/runs/14538768138/job/40794985504) [HUD commit link](https://hud.pytorch.org/pytorch/pytorch/commit/1dd2033c0a1de460ee2bad8d64c36a0344886071), bad TD ([comment](#150127 (comment)))
1 parent bd77c3e commit 97d97ae

File tree: 3 files changed, +77 -79 lines

test/export/test_export.py

Lines changed: 65 additions & 39 deletions
@@ -4306,7 +4306,7 @@ class M_v0(torch.nn.Module):
             def forward(self, t):
                 items = [t[i].item() for i in range(t.numel())]
                 r = torch.randn([items[0], items[1]])
-                # Could not guard on data-dependent expression Ne(Mod(u1, u2), 0)
+                # Could not guard on data-dependent expression Eq(u2, -1)
                 return r.view(items[0], items[2])

         M = M_v0
@@ -4315,23 +4315,69 @@ def forward(self, t):
             "The following call raised this error(.*\n)+"
             f".*{re.escape('return r.view(items[0], items[2])')}(.*\n)+"
             "To fix the error, insert one of the following checks before this call.*:\n"
-            f".*{re.escape('torch._check((items[1] % items[2]) != 0)')}.*\n"
-            f".*{re.escape('torch._check((items[1] % items[2]) == 0)')}(.*\n)+"
-            f".*{re.escape('(These suggested fixes were derived by replacing `u1` with items[1]')}"
-            f".*{re.escape('or r.shape[1], `u2` with items[2] in Ne(Mod(u1, u2), 0) and its negation.')}",
+            f".*{re.escape('torch._check(items[2] == (-1))')}.*\n"
+            f".*{re.escape('torch._check(items[2] != (-1))')}(.*\n)+"
+            f".*{re.escape('(These suggested fixes were derived by replacing `u2` with items[2] in Eq(u2, -1) and its negation.)')}",
         ):
             export(N(), (t,), strict=strict)

         class M_v1(torch.nn.Module):
             def forward(self, t):
                 items = [t[i].item() for i in range(t.numel())]
                 r = torch.randn([items[0], items[1]])
-                # TODO(pianpwk): this isn't the suggested fixes.
-                # fix issue with % being interpreted as PythonMod instead of Mod
-                torch._check(items[1] == items[2])
+                # Could not guard on data-dependent expression Eq(u2, -1)
+                torch._check(items[2] != -1)
+                # Could not guard on data-dependent expression u2 >= 0
                 return r.view(items[0], items[2])

         M = M_v1
+        with self.assertRaisesRegex(
+            error_type,
+            "The following call raised this error(.*\n)+"
+            f".*{re.escape('return r.view(items[0], items[2])')}(.*\n)+"
+            "To fix the error, insert one of the following checks before this call.*:\n"
+            f".*{re.escape('You can add either: torch._check_is_size(u2) or torch._check(u2>=0) Note: torch._check_is_size(u2) could prevent data dependent errors that happen in a guard_size_oblivious(..) context by opting into guard_size_oblivious reasoning. See documentation on guard_size_oblivious for more details: https://pytorch.org/docs/stable/generated/torch.fx.experimental.symbolic_shapes.guard_size_oblivious.html')}.*\n"
+            f".*{re.escape('torch._check(items[2] < 0)')}(.*\n)+"
+            f".*{re.escape('(These suggested fixes were derived by replacing `u2` with items[2] in u2 >= 0 and its negation.)')}",
+        ):
+            export(N(), (t,), strict=strict)
+
+        class M_v2(torch.nn.Module):
+            def forward(self, t):
+                items = [t[i].item() for i in range(t.numel())]
+                r = torch.randn([items[0], items[1]])
+                # Could not guard on data-dependent expression Eq(u2, -1)
+                torch._check(items[2] != -1)
+                # Could not guard on data-dependent expression u2 >= 0
+                torch._check(items[2] >= 0)
+                # Could not guard on data-dependent expression Eq(u1, u2)
+                return r.view(items[0], items[2])
+
+        M = M_v2
+        with self.assertRaisesRegex(
+            error_type,
+            "The following call raised this error(.*\n)+"
+            f".*{re.escape('return r.view(items[0], items[2])')}(.*\n)+"
+            "To fix the error, insert one of the following checks before this call.*:\n"
+            f".*{re.escape('torch._check(items[2] == items[1])')}.*\n"
+            f".*{re.escape('torch._check(items[2] != items[1])')}(.*\n)+"
+            f".*{re.escape('(These suggested fixes were derived by replacing `u1` with items[1] or r.shape[1], `u2` with items[2] in Eq(u2, u1) and its negation.)')}",
+        ):
+            export(N(), (t,), strict=strict)
+
+        class M_v3(torch.nn.Module):
+            def forward(self, t):
+                items = [t[i].item() for i in range(t.numel())]
+                r = torch.randn([items[0], items[1]])
+                # Could not guard on data-dependent expression Eq(u2, -1)
+                torch._check(items[2] != -1)
+                # Could not guard on data-dependent expression u2 >= 0
+                torch._check(items[2] >= 0)
+                # Could not guard on data-dependent expression Eq(u1, u2)
+                torch._check(items[2] == r.shape[1])
+                return r.view(items[0], items[2])
+
+        M = M_v3
         export(N(), (t,), strict=strict)

     def test_suggested_fixes_for_data_dependent_errors_puzzlers(self):
@@ -4443,29 +4489,6 @@ def forward(self, x, offsets_t, fixes):
             fixes=[],  # nothing to fix!
         )

-    def test_simple_unbacked_view(self):
-        class Foo(torch.nn.Module):
-            def forward(self, x):
-                u0 = x.item()
-                y = torch.empty(5, u0)
-                return y.view(u0, 5)  # [5, u0] -> [u0, 5]
-
-        ep = export(Foo(), (torch.tensor([9]),))
-        self.assertEqual(ep.module()(torch.tensor([8])).size(0), 8)
-        self.assertEqual(ep.module()(torch.tensor([5])).size(0), 5)
-
-        class Foov2(torch.nn.Module):
-            def forward(self, xs):
-                xsl = xs.tolist()
-                a, b = xsl
-                x = torch.zeros(a)
-                return x.reshape(b)
-
-        xs = torch.tensor([4, 4])
-        ep = export(Foov2(), (xs,))
-        self.assertEqual(ep.module()(xs).size(0), 4)
-        self.assertEqual(ep.module()(torch.tensor([5, 5])).size(0), 5)
-
     def test_no_suggested_fixes_for_data_dependent_errors(self):
         # suggested fixes for data-dependent errors only work in non-strict mode
         strict = False
@@ -7388,19 +7411,22 @@ def forward(self, xs, y):
             len([node for node in gm.graph.nodes if node.op == "placeholder"]), 2
         )

-    def test_no_check_is_size_error(self):
+    def test_check_is_size_error(self):
         class Module(torch.nn.Module):
             def forward(self, x):
                 a = x.item()
+                # We cannot automatically infer a is a size here because view
+                # accepts -1
                 return torch.randn(24).view(a, 4)

         f = Module()
-        ep = export(f, (torch.tensor(6),))
-        ep.module()(torch.tensor(6))
-        with self.assertRaisesRegex(
-            RuntimeError, r"Runtime assertion failed for .* u.* 6"
-        ):
-            ep.module()(torch.tensor(5))
+        if is_non_strict_test(self._testMethodName):
+            error = torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode
+        else:
+            error = torch._dynamo.exc.UserError
+        error_msg = r"Could not guard on data-dependent expression"
+        with self.assertRaisesRegex(error, error_msg):
+            _ = export(f, (torch.tensor(6),))

     def test_is_non_negative_check_function(self):
         import sympy as sp
@@ -13244,7 +13270,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             node.target == torch.ops.aten._assert_scalar.default
             for node in ep.graph.nodes
         ].count(True)
-        self.assertEqual(num_asserts, 2)
+        self.assertEqual(num_asserts, 1)
         with self.assertRaises(RuntimeError):
            ep.module()(torch.randn(4, 2))

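For reference, a minimal standalone sketch of the pattern the final M_v3 variant above lands on: every data-dependent guard on the unbacked size is resolved with an explicit torch._check before the view call. The module name and input values here are made up for illustration, and exact export behavior varies across PyTorch versions; this is not part of the diff.

import torch
from torch.export import export


class M(torch.nn.Module):
    def forward(self, t):
        items = [t[i].item() for i in range(t.numel())]
        r = torch.randn([items[0], items[1]])
        # Each check below resolves one of the data-dependent guards the test
        # expects for u2 = items[2]: Eq(u2, -1), u2 >= 0, and Eq(u1, u2).
        torch._check(items[2] != -1)
        torch._check(items[2] >= 0)
        torch._check(items[2] == r.shape[1])
        return r.view(items[0], items[2])


ep = export(M(), (torch.tensor([3, 4, 4]),))
print(ep.module()(torch.tensor([3, 4, 4])).shape)  # torch.Size([3, 4])
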
torch/_prims_common/__init__.py

Lines changed: 6 additions & 11 deletions
@@ -924,29 +924,24 @@ def infer_size(shape: ShapeType, numel: int) -> tuple[int, ...]:
     Infers the size of a dim with size -1, if it exists.
     Also checks that new shape is compatible with the number of elements.
     """
-    from torch.fx.experimental.symbolic_shapes import definitely_true, guard_or_false
-
     dim = None
     newsize = 1
     for i, d in enumerate(shape):
-        if guard_or_false(d == -1):
+        if d == -1:
             torch._check(dim is None, lambda: "only one dimension can be inferred")
             dim = i
-        else:
-            torch._check(
-                d >= 0,
-                lambda: (
-                    f"invalid shape dimension {d}. If this was symbolic, it was assumed to not be -1."
-                    "If this was meant to be inferred, please explicitly pass in -1."
-                ),
-            )
+        elif d >= 0:
             newsize *= d
+        else:
+            torch._check(False, lambda: f"invalid shape dimension {d}")
     if dim is None:
         torch._check(
             numel == newsize,
             lambda: f"shape '{list(shape)}' is invalid for input of size {numel}",
         )
     else:
+        from torch.fx.experimental.symbolic_shapes import definitely_true
+
         torch._check(
             newsize != 0,
             lambda: (
torch/_refs/__init__.py

Lines changed: 6 additions & 29 deletions
@@ -3717,8 +3717,7 @@ def repeat(a: Tensor, *repeat_shape) -> Tensor:


 def _reshape_view_helper(a: TensorLikeType, *shape, allow_copy: bool) -> TensorLikeType:
-    from torch._dynamo.exc import UserError, UserErrorType
-    from torch.fx.experimental.symbolic_shapes import guard_or_false, guard_or_true
+    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious, sym_eq

     # Creates a valid shape
     shape = utils.extract_shape_from_varargs(shape, validate=False)
@@ -3727,7 +3726,7 @@ def _reshape_view_helper(a: TensorLikeType, *shape, allow_copy: bool) -> TensorLikeType:
     shape = utils.infer_size(shape, a.numel())

     # Special-cases tensors with no elements
-    if guard_or_false(a.numel() == 0):
+    if guard_size_oblivious(a.numel() == 0):
         return as_strided(a, shape, utils.make_contiguous_strides_for(shape))

     # Special-cases reshaping zero dim tensors
@@ -3763,12 +3762,6 @@ def _reshape_view_helper(a: TensorLikeType, *shape, allow_copy: bool) -> TensorLikeType:
         return torch.as_strided(a, [dim0, dim1], [dim1, 1])

     # Handles general case: a 1+D tensor reshaped into a distinct 1+D shape
-    shape_numel = reduce(operator.mul, shape, 1)
-    torch._check(
-        a.numel() == shape_numel,
-        f"Could not reshape a tensor with shape {a.shape} as a tensor with shape {shape}!",
-    )
-    deferred: list[Callable[[], bool]] = []

     # NOTE [Reshape Algorithm]
     # This algorithm works by attempting to greedily construct the desired dimensions in
@@ -3801,30 +3794,16 @@ def maybe_throw_dde():
             continue

         # Skips dimensions that are already the correct length
-        if guard_or_false(length == a_.shape[idx]):
+        if guard_size_oblivious(length == a_.shape[idx]):
             idx = idx + 1
             continue

         # Gathers enough original dimensions such that this new dimension can be created
         # Note that this accumulation will terminate because we've verified a and the shape
         # specify the same number of elements above
-        def maybe_throw_dde():
-            # NOTE: if you've hit a data-dependent error here, it's because in trying to accumulate input
-            # tensor dimensions to match the target shape (length), we've hit data-dependent errors testing
-            # divisibility (accum % length != 0), and have deferred raising them, in the hope that we'd
-            # figure out a valid reshape later in the loop.
-            # But we failed, either by running out of dimensions, or we couldn't figure out the strides,
-            # and we've decided to re-raise to either graph break out, or provide the exact guard so the user
-            # can torch._check() to avoid this.
-            for f in deferred:
-                f()
-
         accum = a_.shape[idx]
         end = idx
-        while guard_or_true(accum % length != 0):
-            deferred.append(lambda: bool(accum % length != 0))
-            if end == a_.ndim - 1:
-                maybe_throw_dde()
+        while guard_size_oblivious(accum % length != 0):
             end = end + 1
             accum = accum * a_.shape[end]
         if end != idx:
@@ -3838,15 +3817,13 @@ def maybe_throw_dde():
                 if allow_copy:
                     return prims.reshape(a, shape)

-                maybe_throw_dde()
                 msg = f"Cannot view a tensor with shape {a.shape} and strides {a.stride()} as a tensor with shape {shape}!"
                 raise ValueError(msg)

             a_ = flatten(a_, idx, end)

-        # Splits the (possibly flattened) dimension to create the desired dim length.
-        # guard_or_true is safe due to the tail unsqueeze routine.
-        if guard_or_true(accum != length):
+        # Splits the (possibly flattened) dimension to create the desired dim length
+        if guard_size_oblivious(accum != length):
             a_ = prims.split_dim(a_, idx, length)

         idx = idx + 1

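The NOTE [Reshape Algorithm] referenced in the context above greedily gathers input dimensions until their product is divisible by the next target length, flattens that group, and then splits it. A plain-integer simulation of that bookkeeping, ignoring strides and the zero-element / zero-dim special cases handled earlier in the helper (illustrative only, not the library code):

from math import prod


def simulate_reshape_steps(in_shape, out_shape):
    """Trace the intermediate shapes produced by the greedy flatten/split walk."""
    assert prod(in_shape) == prod(out_shape), "element counts must match"
    a = list(in_shape)
    steps = [tuple(a)]
    idx = 0
    for length in out_shape:
        if length == 1:
            a.insert(idx, 1)  # the real helper unsqueezes here
        elif a[idx] != length:
            accum, end = a[idx], idx
            while accum % length != 0:  # gather dims until divisible by length
                end += 1
                accum *= a[end]
            if end != idx:
                a[idx:end + 1] = [accum]  # flatten the gathered dims
            if accum != length:
                a[idx:idx + 1] = [length, accum // length]  # split off length
        idx += 1
        steps.append(tuple(a))
    return steps


print(simulate_reshape_steps([5, 9], [9, 5]))  # [(5, 9), (9, 5), (9, 5)]
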
0 commit comments
