[dynamic shapes] guard_or_false for _reshape_view_helper, utils._infer_size for wildcard dims by pianpwk · Pull Request #150127 · pytorch/pytorch · GitHub

[dynamic shapes] guard_or_false for _reshape_view_helper, utils._infer_size for wildcard dims #150127


Status: Closed. Wants to merge 35 commits.

Changes from all commits (35):
- 7510a77 init (pianpwk, Mar 27, 2025)
- 4630374 assume >= 0 (pianpwk, Mar 27, 2025)
- e076ca4 test (pianpwk, Mar 27, 2025)
- ed8ed35 assume numel = prod(shape) (pianpwk, Mar 27, 2025)
- fafaa47 test (pianpwk, Mar 27, 2025)
- fe734e8 workaround (pianpwk, Mar 27, 2025)
- 11a25a7 simple case (pianpwk, Mar 27, 2025)
- 8a48fad weird test (pianpwk, Mar 27, 2025)
- ac40d4c lint (pianpwk, Mar 27, 2025)
- 207d3fb switch to guard_or_true (pianpwk, Apr 2, 2025)
- 1790f35 Update __init__.py (pianpwk, Apr 2, 2025)
- 1dd4c2b Update __init__.py (pianpwk, Apr 2, 2025)
- 9d35a27 lint (pianpwk, Apr 2, 2025)
- cab5f4a Merge branch 'main' of https://github.com/pytorch/pytorch into pianpw… (pianpwk, Apr 3, 2025)
- f5b10c4 stash (pianpwk, Apr 11, 2025)
- 2eb56d0 Merge branch 'main' of https://github.com/pytorch/pytorch into pianpw… (pianpwk, Apr 11, 2025)
- 9604307 reduce to only changing fast path (pianpwk, Apr 11, 2025)
- 00031a1 Update fx.experimental.rst (pianpwk, Apr 12, 2025)
- 983174e Update fx.experimental.rst (pianpwk, Apr 12, 2025)
- 7c8d113 lint (pianpwk, Apr 14, 2025)
- 20749a3 Merge branch 'main' of https://github.com/pytorch/pytorch into pianpw… (pianpwk, Apr 14, 2025)
- 5885633 add loop back (pianpwk, Apr 14, 2025)
- 18496a9 Update __init__.py (pianpwk, Apr 16, 2025)
- af74f34 Update test_export.py (pianpwk, Apr 16, 2025)
- 69c1abe try (pianpwk, Apr 16, 2025)
- 2ac1303 Merge branch 'main' of https://github.com/pytorch/pytorch into pianpw… (pianpwk, Apr 17, 2025)
- ffe4273 comments (pianpwk, Apr 17, 2025)
- fa11fa1 Update __init__.py (pianpwk, Apr 17, 2025)
- 85fd59e Update __init__.py (pianpwk, Apr 17, 2025)
- 6d322e6 Update __init__.py (pianpwk, Apr 18, 2025)
- 07c6ac9 Merge branch 'main' of https://github.com/pytorch/pytorch into pianpw… (pianpwk, Apr 18, 2025)
- 5fa8ce9 test masked linear (pianpwk, Apr 19, 2025)
- c85b5ce lint (pianpwk, Apr 21, 2025)
- b0db707 Merge branch 'main' of https://github.com/pytorch/pytorch into pianpw… (pianpwk, Apr 22, 2025)
- 2bca726 Merge branch 'main' of https://github.com/pytorch/pytorch into pianpw… (pianpwk, Apr 22, 2025)
3 changes: 2 additions & 1 deletion test/export/test_draft_export.py
@@ -323,7 +323,7 @@ def forward(self, x):
self.assertEqual(
report.failures[0].failure_type, FailureType.DATA_DEPENDENT_ERROR
)
self.assertEqual(report.failures[0].data["expr"], "Eq(2*u1, 10)")
self.assertEqual(report.failures[0].data["expr"], "Eq(9380*u1, 0)")

def test_dedup_data_dependent_failure(self):
class M(torch.nn.Module):
@@ -480,6 +480,7 @@ def forward(self, x, mask, weight, bias):
return torch.nn.functional.linear(masked, weight, bias)

x = torch.zeros(10)
x[0] += 1
inp = (torch.randn(10, 8, 7), x, torch.randn(25, 7), torch.randn(25))
draft_ep = draft_export(M(), inp)
ep = export(M(), inp)
104 changes: 39 additions & 65 deletions test/export/test_export.py
@@ -4301,7 +4301,7 @@ class M_v0(torch.nn.Module):
def forward(self, t):
items = [t[i].item() for i in range(t.numel())]
r = torch.randn([items[0], items[1]])
# Could not guard on data-dependent expression Eq(u2, -1)
# Could not guard on data-dependent expression Ne(Mod(u1, u2), 0)
return r.view(items[0], items[2])

M = M_v0
@@ -4310,69 +4310,23 @@ def forward(self, t):
"The following call raised this error(.*\n)+"
f".*{re.escape('return r.view(items[0], items[2])')}(.*\n)+"
"To fix the error, insert one of the following checks before this call.*:\n"
f".*{re.escape('torch._check( 8000 items[2] == (-1))')}.*\n"
f".*{re.escape('torch._check(items[2] != (-1))')}(.*\n)+"
f".*{re.escape('(These suggested fixes were derived by replacing `u2` with items[2] in Eq(u2, -1) and its negation.)')}",
f".*{re.escape('torch._check((items[1] % items[2]) != 0)')}.*\n"
f".*{re.escape('torch._check((items[1] % items[2]) == 0)')}(.*\n)+"
f".*{re.escape('(These suggested fixes were derived by replacing `u1` with items[1]')}"
f".*{re.escape('or r.shape[1], `u2` with items[2] in Ne(Mod(u1, u2), 0) and its negation.')}",
):
export(N(), (t,), strict=strict)

class M_v1(torch.nn.Module):
def forward(self, t):
items = [t[i].item() for i in range(t.numel())]
r = torch.randn([items[0], items[1]])
# Could not guard on data-dependent expression Eq(u2, -1)
torch._check(items[2] != -1)
# Could not guard on data-dependent expression u2 >= 0
# TODO(pianpwk): this isn't the suggested fixes.
# fix issue with % being interpreted as PythonMod instead of Mod
torch._check(items[1] == items[2])
return r.view(items[0], items[2])

M = M_v1
with self.assertRaisesRegex(
error_type,
"The following call raised this error(.*\n)+"
f".*{re.escape('return r.view(items[0], items[2])')}(.*\n)+"
"To fix the error, insert one of the following checks before this call.*:\n"
f".*{re.escape('You can add either: torch._check_is_size(u2) or torch._check(u2>=0) Note: torch._check_is_size(u2) could prevent data dependent errors that happen in a guard_size_oblivious(..) context by opting into guard_size_oblivious reasoning. See documentation on guard_size_oblivious for more details: https://pytorch.org/docs/stable/generated/torch.fx.experimental.symbolic_shapes.guard_size_oblivious.html')}.*\n"
f".*{re.escape('torch._check(items[2] < 0)')}(.*\n)+"
f".*{re.escape('(These suggested fixes were derived by replacing `u2` with items[2] in u2 >= 0 and its negation.)')}",
):
export(N(), (t,), strict=strict)

class M_v2(torch.nn.Module):
def forward(self, t):
items = [t[i].item() for i in range(t.numel())]
r = torch.randn([items[0], items[1]])
# Could not guard on data-dependent expression Eq(u2, -1)
torch._check(items[2] != -1)
# Could not guard on data-dependent expression u2 >= 0
torch._check(items[2] >= 0)
# Could not guard on data-dependent expression Eq(u1, u2)
return r.view(items[0], items[2])

M = M_v2
with self.assertRaisesRegex(
error_type,
"The following call raised this error(.*\n)+"
f".*{re.escape('return r.view(items[0], items[2])')}(.*\n)+"
"To fix the error, insert one of the following checks before this call.*:\n"
f".*{re.escape('torch._check(items[2] == items[1])')}.*\n"
f".*{re.escape('torch._check(items[2] != items[1])')}(.*\n)+"
f".*{re.escape('(These suggested fixes were derived by replacing `u1` with items[1] or r.shape[1], `u2` with items[2] in Eq(u2, u1) and its negation.)')}",
):
export(N(), (t,), strict=strict)

class M_v3(torch.nn.Module):
def forward(self, t):
items = [t[i].item() for i in range(t.numel())]
r = torch.randn([items[0], items[1]])
# Could not guard on data-dependent expression Eq(u2, -1)
torch._check(items[2] != -1)
# Could not guard on data-dependent expression u2 >= 0
torch._check(items[2] >= 0)
# Could not guard on data-dependent expression Eq(u1, u2)
torch._check(items[2] == r.shape[1])
return r.view(items[0], items[2])

M = M_v3
export(N(), (t,), strict=strict)

def test_suggested_fixes_for_data_dependent_errors_puzzlers(self):
@@ -4484,6 +4438,29 @@ def forward(self, x, offsets_t, fixes):
fixes=[], # nothing to fix!
)

def test_simple_unbacked_view(self):
class Foo(torch.nn.Module):
def forward(self, x):
u0 = x.item()
y = torch.empty(5, u0)
return y.view(u0, 5) # [5, u0] -> [u0, 5]

ep = export(Foo(), (torch.tensor([9]),))
self.assertEqual(ep.module()(torch.tensor([8])).size(0), 8)
self.assertEqual(ep.module()(torch.tensor([5])).size(0), 5)

class Foov2(torch.nn.Module):
def forward(self, xs):
xsl = xs.tolist()
a, b = xsl
x = torch.zeros(a)
return x.reshape(b)

xs = torch.tensor([4, 4])
ep = export(Foov2(), (xs,))
self.assertEqual(ep.module()(xs).size(0), 4)
self.assertEqual(ep.module()(torch.tensor([5, 5])).size(0), 5)

def test_no_suggested_fixes_for_data_dependent_errors(self):
# suggested fixes for data-dependent errors only work in non-strict mode
strict = False
@@ -7549,22 +7526,19 @@ def forward(self, xs, y):
len([node for node in gm.graph.nodes if node.op == "placeholder"]), 2
)

def test_check_is_size_error(self):
def test_no_check_is_size_error(self):
class Module(torch.nn.Module):
def forward(self, x):
a = x.item()
# We cannot automatically infer a is a size here because view
# accepts -1
return torch.randn(24).view(a, 4)

f = Module()
if is_non_strict_test(self._testMethodName):
error = torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode
else:
error = torch._dynamo.exc.UserError
error_msg = r"Could not guard on data-dependent expression"
with self.assertRaisesRegex(error, error_msg):
_ = export(f, (torch.tensor(6),))
ep = export(f, (torch.tensor(6),))
ep.module()(torch.tensor(6))
with self.assertRaisesRegex(
RuntimeError, r"Runtime assertion failed for .* u.* 6"
):
ep.module()(torch.tensor(5))

pianpwk (Contributor, Author) commented on Mar 27, 2025:
Exports fine now without check_is_size; we just assume it's >= 0, and it later specializes to 6.
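
For illustration, here is a condensed, runnable version of the behavior this comment describes, mirroring the updated test above (a hedged sketch; the exact wording of the runtime assertion may differ):

```python
import torch
from torch.export import export

class Module(torch.nn.Module):
    def forward(self, x):
        a = x.item()  # unbacked SymInt; no torch._check_is_size(a) required
        # view() no longer guards on a == -1: the size is assumed >= 0 and
        # the numel check becomes a deferred runtime assertion.
        return torch.randn(24).view(a, 4)

ep = export(Module(), (torch.tensor(6),))
ep.module()(torch.tensor(6))    # ok: 24 == 6 * 4
# ep.module()(torch.tensor(5))  # raises: runtime assertion on u0 fails
```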


def test_is_non_negative_check_function(self):
import sympy as sp
@@ -13446,7 +13420,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
node.target == torch.ops.aten._assert_scalar.default
for node in ep.graph.nodes
].count(True)
self.assertEqual(num_asserts, 1)
self.assertEqual(num_asserts, 2)
with self.assertRaises(RuntimeError):
ep.module()(torch.randn(4, 2))

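Before the library diffs below, a note on the primitive they introduce: unlike `guard_size_oblivious`, `guard_or_false` and `guard_or_true` never raise on data-dependent expressions; they fall back to a fixed answer without installing a guard. A minimal sketch of the contract as this PR relies on it (the helper names here are illustrative, not PyTorch APIs):

```python
from torch.fx.experimental.symbolic_shapes import guard_or_false, guard_or_true

def numel_is_zero(numel) -> bool:
    # Returns bool(numel == 0) when statically decidable; otherwise False,
    # with no guard installed and no GuardOnDataDependentSymNode raised.
    # Safe for the empty-tensor fast path below, because the general path
    # also handles numel == 0 correctly.
    return guard_or_false(numel == 0)

def not_divisible(accum, length) -> bool:
    # Same contract with a True fallback: keep accumulating dims when
    # divisibility is unknown, and defer the real guard until later.
    return guard_or_true(accum % length != 0)
```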
17 changes: 11 additions & 6 deletions torch/_prims_common/__init__.py
@@ -924,24 +924,29 @@ def infer_size(shape: ShapeType, numel: int) -> tuple[int, ...]:
Infers the size of a dim with size -1, if it exists.
Also checks that new shape is compatible with the number of elements.
"""
from torch.fx.experimental.symbolic_shapes import definitely_true, guard_or_false

dim = None
newsize = 1
for i, d in enumerate(shape):
if d == -1:
if guard_or_false(d == -1):
torch._check(dim is None, lambda: "only one dimension can be inferred")
dim = i
elif d >= 0:
newsize *= d
else:
torch._check(False, lambda: f"invalid shape dimension {d}")
torch._check(
d >= 0,
lambda: (
f"invalid shape dimension {d}. If this was symbolic, it was assumed to not be -1."
"If this was meant to be inferred, please explicitly pass in -1."
),
)
newsize *= d
if dim is None:
torch._check(
numel == newsize,
lambda: f"shape '{list(shape)}' is invalid for input of size {numel}",
)
else:
from torch.fx.experimental.symbolic_shapes import definitely_true

torch._check(
newsize != 0,
lambda: (
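Putting the pieces together, a simplified standalone model of the updated `infer_size` control flow (a sketch under the guard_or_false contract above, not the exact PyTorch source; error messages abbreviated):

```python
import torch
from torch.fx.experimental.symbolic_shapes import guard_or_false

def infer_size_sketch(shape, numel):
    dim, newsize = None, 1
    for i, d in enumerate(shape):
        if guard_or_false(d == -1):
            # A data-dependent d is assumed NOT to be the wildcard; callers
            # who want inference must pass a literal -1.
            torch._check(dim is None, lambda: "only one dimension can be inferred")
            dim = i
        else:
            torch._check(d >= 0, lambda: f"invalid shape dimension {d}")
            newsize *= d
    if dim is None:
        torch._check(
            numel == newsize,
            lambda: f"shape '{list(shape)}' is invalid for input of size {numel}",
        )
        return tuple(shape)
    torch._check(newsize != 0, lambda: "cannot infer the wildcard dim")
    out = list(shape)
    out[dim] = numel // newsize
    return tuple(out)
```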
35 changes: 29 additions & 6 deletions torch/_refs/__init__.py
@@ -3717,7 +3717,8 @@ def repeat(a: Tensor, *repeat_shape) -> Tensor:


def _reshape_view_helper(a: TensorLikeType, *shape, allow_copy: bool) -> TensorLikeType:
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious, sym_eq
from torch._dynamo.exc import UserError, UserErrorType
from torch.fx.experimental.symbolic_shapes import guard_or_false, guard_or_true

# Creates a valid shape
shape = utils.extract_shape_from_varargs(shape, validate=False)
@@ -3726,7 +3727,7 @@ def _reshape_view_helper(a: TensorLikeType, *shape, allow_copy: bool) -> TensorLikeType:
shape = utils.infer_size(shape, a.numel())

# Special-cases tensors with no elements
if guard_size_oblivious(a.numel() == 0):
if guard_or_false(a.numel() == 0):
return as_strided(a, shape, utils.make_contiguous_strides_for(shape))

# Special-cases reshaping zero dim tensors
@@ -3762,6 +3763,12 @@ def _reshape_view_helper(a: TensorLikeType, *shape, allow_copy: bool) -> TensorLikeType:
return torch.as_strided(a, [dim0, dim1], [dim1, 1])

# Handles general case: a 1+D tensor reshaped into a distinct 1+D shape
shape_numel = reduce(operator.mul, shape, 1)
torch._check(
a.numel() == shape_numel,
f"Could not reshape a tensor with shape {a.shape} as a tensor with shape {shape}!",
)
deferred: list[Callable[[], bool]] = []

# NOTE [Reshape Algorithm]
# This algorithm works by attempting to greedily construct the desired dimensions in
@@ -3794,16 +3801,30 @@ def _reshape_view_helper(a: TensorLikeType, *shape, allow_copy: bool) -> TensorLikeType:
continue

# Skips dimensions that are already the correct length
if guard_size_oblivious(length == a_.shape[idx]):
if guard_or_false(length == a_.shape[idx]):
idx = idx + 1
continue

# Gathers enough original dimensions such that this new dimension can be created
# Note that this accumulation will terminate because we've verified a and the shape
# specify the same number of elements above
def maybe_throw_dde():
[Review comment from a Contributor] nit: maybe_throw_bypassed_dde
# NOTE: if you've hit a data-dependent error here, it's because in trying to accumulate input
# tensor dimensions to match the target shape (length), we've hit data-dependent errors testing
# divisibility (accum % length != 0), and have deferred raising them, in the hope that we'd
# figure out a valid reshape later in the loop.
# But we failed, either by running out of dimensions, or we couldn't figure out the strides,
# and we've decided to re-raise to either graph break out, or provide the exact guard so the user
# can torch._check() to avoid this.
for f in deferred:
f()

accum = a_.shape[idx]
end = idx
while guard_size_oblivious(accum % length != 0):
while guard_or_true(accum % length != 0):
deferred.append(lambda: bool(accum % length != 0))
if end == a_.ndim - 1:
maybe_throw_dde()
end = end + 1
accum = accum * a_.shape[end]
if end != idx:
@@ -3817,13 +3838,15 @@ def _reshape_view_helper(a: TensorLikeType, *shape, allow_copy: bool) -> TensorLikeType:
if allow_copy:
return prims.reshape(a, shape)

maybe_throw_dde()
msg = f"Cannot view a tensor with shape {a.shape} and strides {a.stride()} as a tensor with shape {shape}!"
raise ValueError(msg)

a_ = flatten(a_, idx, end)

# Splits the (possibly flattened) dimension to create the desired dim length
if guard_size_oblivious(accum != length):
# Splits the (possibly flattened) dimension to create the desired dim length.
# guard_or_true is safe due to the tail unsqueeze routine.
if guard_or_true(accum != length):
a_ = prims.split_dim(a_, idx, length)

idx = idx + 1
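Finally, a self-contained plain-int model of the accumulate-then-defer loop added above (hedged: names simplified, SymInt guard machinery replaced by ordinary booleans):

```python
from typing import Callable

def accumulate_dims(dims: list[int], length: int) -> tuple[int, int]:
    # Model of _reshape_view_helper's new loop: optimistically treat each
    # unknown divisibility test as "not divisible" (guard_or_true), remember
    # the expression, and only surface the deferred guards if no valid
    # reshape is found. With SymInts, calling a deferred f() re-evaluates
    # the expression and raises the original data-dependent error.
    deferred: list[Callable[[], bool]] = []

    def maybe_throw_dde() -> None:
        for f in deferred:
            f()

    accum, end = dims[0], 0
    while accum % length != 0:  # guard_or_true(...) in the real helper
        deferred.append(lambda a=accum: bool(a % length != 0))
        if end == len(dims) - 1:
            maybe_throw_dde()
            raise ValueError(f"cannot build a dim of length {length} from {dims}")
        end += 1
        accum *= dims[end]
    return accum, end

# accumulate_dims([5, 4], 10) -> (20, 1): 5 alone is not divisible by 10,
# so the loop folds in the next dim (5 * 4 = 20) and succeeds.
```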