[DRAFT][Reshape] Guard-free reshape for contiguous tensors to avoid data dependent errors. by laithsakka · Pull Request #148742 · pytorch/pytorch · GitHub

Closed · wants to merge 9 commits
40 changes: 40 additions & 0 deletions test/test_dynamic_shapes.py
@@ -2827,6 +2827,46 @@ def test_guards_float_div(self):
        self.assertTrue(shape_env.evaluate_guards_expression(guards, [hint_int(s0)]))
        self.assertFalse(shape_env.evaluate_guards_expression(guards, [hint_int(s1)]))

    def test_unbacked_reshape(self):
        # reshape u0 -> (u1, u2, u2)
        @torch.compile(fullgraph=True)
        def func(x, y, z):
            t = torch.reshape(x, (y.size()[0], z.size()[0], z.size()[0]))
            return t

        x = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8])
        y = torch.rand(2)
        z = torch.rand(2)

        torch._dynamo.decorators.mark_unbacked(x, 0)
        torch._dynamo.decorators.mark_unbacked(y, 0)
        torch._dynamo.decorators.mark_unbacked(z, 0)

        self.assertEqual(
            func(x, y, z), torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
        )
        self.assertEqual(
            func(torch.ones(50), torch.ones(2), torch.ones(5)), torch.ones(2, 5, 5)
        )

        # reshape (u0, u1) -> (u3, u3+u4)
        @torch.compile(fullgraph=True)
        def func2(x, y, z):
            t = torch.reshape(x, (y.size()[0], z.size()[0] + y.size()[0]))
            return t

        x = torch.ones(4, 4)
        y = torch.rand(2)
        z = torch.rand(6)

        torch._dynamo.decorators.mark_unbacked(x, 0)
        torch._dynamo.decorators.mark_unbacked(x, 1)

        torch._dynamo.decorators.mark_unbacked(y, 0)
        torch._dynamo.decorators.mark_unbacked(z, 0)

        self.assertEqual(func2(x, y, z), torch.ones(2, 8))

    def test_remove_symbols_without_guarding(self):
        from torch._functorch.partitioners import _remove_symbols_without_guarding

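For readers trying this out: a minimal usage sketch modeled on the tests above, using only APIs that appear in this diff (torch.compile, mark_unbacked). The function name g and the concrete sizes are illustrative, and the sketch assumes this PR's guard-free contiguous path so that fullgraph compilation succeeds.

import torch

# Illustrative sketch (g, x, y are hypothetical names). Both sizes are
# unbacked, so the compiled reshape cannot consult their concrete values
# when constructing the output shape.
@torch.compile(fullgraph=True)
def g(x, y):
    return torch.reshape(x, (y.size()[0], y.size()[0]))

x = torch.arange(9.0)
y = torch.ones(3)
torch._dynamo.decorators.mark_unbacked(x, 0)
torch._dynamo.decorators.mark_unbacked(y, 0)
print(g(x, y).shape)  # torch.Size([3, 3])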
8 changes: 5 additions & 3 deletions torch/_prims_common/__init__.py
@@ -930,10 +930,12 @@ def infer_size(shape: ShapeType, numel: int) -> tuple[int, ...]:
        if d == -1:
            torch._check(dim is None, lambda: "only one dimension can be inferred")
            dim = i
-        elif d >= 0:
-            newsize *= d
-        else:
-            torch._check(False, lambda: f"invalid shape dimension {d}")
+        torch._check(
+            d >= 0,
+            lambda: f"shape '{list(shape)}' is invalid for input of size {numel}",
+        )
+        newsize *= d

    if dim is None:
        torch._check(
            numel == newsize,
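Why fold the elif/else pair into a single torch._check: branching on d >= 0 means deciding the sign of a possibly unbacked d at trace time, while one check can simply be recorded as a runtime assertion. A plain-Python sketch of the resulting control flow, with assert standing in for torch._check (infer_size_sketch is a hypothetical name, for illustration only):

# Sketch of the size-inference logic after this change; real code uses
# torch._check so the d >= 0 condition becomes a recorded assertion
# rather than a data-dependent branch.
def infer_size_sketch(shape, numel):
    dim, newsize = None, 1
    for i, d in enumerate(shape):
        if d == -1:
            assert dim is None, "only one dimension can be inferred"
            dim = i
        else:
            # One non-branching validity check replaces the old elif/else pair.
            assert d >= 0, f"shape '{list(shape)}' is invalid for input of size {numel}"
            newsize *= d
    out = list(shape)
    if dim is None:
        assert numel == newsize
    else:
        assert newsize != 0 and numel % newsize == 0
        out[dim] = numel // newsize
    return tuple(out)

print(infer_size_sketch((2, -1, 3), 12))  # (2, 2, 3)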
23 changes: 14 additions & 9 deletions torch/_refs/__init__.py
@@ -3696,7 +3696,11 @@ def repeat(a: Tensor, *repeat_shape) -> Tensor:


def _reshape_view_helper(a: TensorLikeType, *shape, allow_copy: bool) -> TensorLikeType:
-    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious, sym_eq
+    from torch.fx.experimental.symbolic_shapes import (
+        guard_size_oblivious,
+        statically_known_true,
+        sym_eq,
+    )

    # Creates a valid shape
    shape = utils.extract_shape_from_varargs(shape, validate=False)
@@ -3731,14 +3735,15 @@ def _reshape_view_helper(a: TensorLikeType, *shape, allow_copy: bool) -> TensorLikeType:
        return _a

    if a.is_contiguous():
-        # Special-cases for nd_to_1d
-        if len(shape) == 1 and a.ndim > 1:
-            return torch.as_strided(a, [a.numel()], [1])
-        # Special-cases for 1d_to_2d
-        if len(shape) == 2 and a.ndim == 1:
-            dim0 = shape[0]
-            dim1 = shape[1]
-            return torch.as_strided(a, [dim0, dim1], [dim1, 1])
+        if len(shape) >= 1 and a.ndim >= 1:
+            if statically_known_true(sym_eq(shape, a.shape)):
+                return prims.view_of(a)
+
+            strides = [1]
+            for x in reversed(shape[1:]):
Contributor:
should this be shape[:-1]?

Contributor Author:

haha maybe, I am new-ish to Python; no fancy slicing in C++. I will try it.

Contributor Author:

You mean why not [1:][::-1]? I want to remove the first element, then reverse.

+                strides.append(strides[-1] * x)
+            strides.reverse()
+            return torch.as_strided(a, shape, strides)

    # Handles general case: a 1+D tensor reshaped into a distinct 1+D shape

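On the review question above: reversed(shape[1:]) is correct because, for a contiguous layout, the stride of dimension i is the product of all dimensions to its right, so the first dimension never contributes to any stride. A standalone sketch that checks the loop against PyTorch (contiguous_strides is a hypothetical helper name, for illustration):

import torch

def contiguous_strides(shape):
    # Mirrors the loop in this PR: drop the first dim, walk the rest
    # right-to-left accumulating the running product, then flip back
    # into dimension order. For a 1-D shape this returns [1].
    strides = [1]
    for x in reversed(shape[1:]):
        strides.append(strides[-1] * x)
    strides.reverse()
    return strides

shape = [2, 3, 4]
assert contiguous_strides(shape) == list(torch.ones(shape).stride())
print(contiguous_strides(shape))  # [12, 4, 1]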