8000 init by pianpwk · Pull Request #153682 · pytorch/pytorch · GitHub
[go: up one dir, main page]

Skip to content

init #153682

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft

init #153682

Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
init
  • Loading branch information
pianpwk committed May 16, 2025
commit 4ad3e4f2cdde223d73d7d586b62c08ccad8b4fa5
20 changes: 20 additions & 0 deletions test/export/test_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -11494,6 +11494,26 @@ def forward(self, x, y):
][0]
self.assertEqual(op_node.target._name, "aten::add.Tensor")

@testing.expectedFailureTrainingIRToRunDecomp
@testing.expectedFailureTrainingIRToRunDecompNonStrict
def test_unbacked_slice_forward(self):
    """Slicing with fully unbacked (data-dependent) start/end values.

    The tensor size and both slice bounds come from ``tolist()`` on the
    input tensor, so none of them are backed by static hints.  The exported
    graph must produce the same output shape as eager for negative,
    in-range, and out-of-range bounds.
    """

    class Foo(torch.nn.Module):
        def forward(self, xs):
            u0, u1, u2 = xs.tolist()
            x = torch.empty(u0)
            return x[u1:u2]

    mod = Foo()
    gm = export(mod, (torch.tensor([9, 1, 8]),)).module()

    # Closure over self/mod/gm; no need to thread `self` through explicitly.
    def check(sizes):
        inp = torch.tensor(sizes)
        self.assertEqual(mod(inp).shape, gm(inp).shape)

    check([9, -8, -1])  # negative start/end
    check([3, 5, 3])  # start past end -> empty result
    check([10, 0, -2])  # mixed positive/negative
    check([10, -1000, 1000])  # both bounds far out of range -> clamped

@testing.expectedFailureRetraceability
def test_layer_sharing(self):
N, C, H, W = 1, 2, 2, 3
Expand Down
26 changes: 25 additions & 1 deletion torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -977,10 +977,34 @@ def fn(a):
__all__.append("sym_sqrt")


def check_same_symtype(t, f):
    """Return True if ``t`` and ``f`` belong to the same scalar family.

    A "family" groups the symbolic and concrete forms of a scalar type:
    SymInt/int, SymFloat/float, SymBool/bool.  Exact same type always
    matches.  Note ``bool`` is a subclass of ``int`` in Python, so the
    bool family must be tested first and the int family must exclude
    bools — otherwise ``(True, 1)`` would wrongly be reported as the
    same symtype.
    """
    if type(t) == type(f):
        return True

    def _family(x):
        # Order matters: bool before int, since isinstance(True, int) is True.
        if isinstance(x, (SymBool, builtins.bool)):
            return "bool"
        if isinstance(x, (SymInt, builtins.int)):
            return "int"
        if isinstance(x, (SymFloat, builtins.float)):
            return "float"
        return None

    family_t = _family(t)
    return family_t is not None and family_t == _family(f)


def sym_ite(b, t, f):
    """Symbolic if-then-else: return ``t`` if ``b`` else ``f``.

    ``b`` must be a (Sym)bool, and ``t``/``f`` must be of the same
    symtype family (SymInt/int, SymFloat/float, SymBool/bool) so the
    result has a consistent type regardless of which branch is taken.
    Defers to ``__sym_ite__`` when ``b`` is symbolic so the choice is
    recorded in the symbolic expression rather than guarded on.
    """
    if overrides.has_torch_function((b, t, f)):
        return overrides.handle_torch_function(sym_ite, (b, t, f), b, t, f)
    assert isinstance(b, (SymBool, builtins.bool)) and check_same_symtype(t, f)
    if isinstance(b, SymBool):
        return b.__sym_ite__(t, f)
    return t if b else f
Expand Down
48 changes: 29 additions & 19 deletions torch/_decomp/decompositions.py
Original file line number Diff line number Diff line change
Expand Up @@ -710,10 +710,7 @@ def slice_forward(
end: Optional[int] = None,
step: int = 1,
):
from torch.fx.experimental.symbolic_shapes import (
guard_size_oblivious,
statically_known_true,
)
from torch.fx.experimental.symbolic_shapes import guard_or_none, statically_known_true, sym_or

ndim = self.dim()
if ndim == 0:
Expand All @@ -728,23 +725,36 @@ def slice_forward(
start_val = start if start is not None else 0
end_val = end if end is not None else sys.maxsize # 2^63 - 1

if guard_size_oblivious(start_val < 0):
start_val += sizes[dim]

if guard_size_oblivious(end_val < 0):
end_val += sizes[dim]
def generalize_index(idx, end=False):
a = guard_or_none(idx >= -sizes[dim])
if a is None:
lt_clause = torch.sym_ite(idx >= -sizes[dim], sizes[dim] + idx, 0)
elif a:
lt_clause = sizes[dim] + idx
else:
lt_clause = 0

if end and statically_known_true(idx == sys.maxsize):
gt_clause = sizes[dim]
elif (b := guard_or_none(idx <= sizes[dim])) is None:
gt_clause = torch.sym_ite(idx <= sizes[dim], idx, sizes[dim])
elif b:
gt_clause = idx
else:
gt_clause = sizes[dim]

if guard_size_oblivious(start_val < 0):
start_val = 0
elif guard_size_oblivious(start_val > sizes[dim]):
start_val = sizes[dim]
c = guard_or_none(idx >= 0)
if c is None:
out = torch.sym_ite(idx >= 0, gt_clause, lt_clause)
elif c:
out = gt_clause
else:
out = lt_clause
return out

if guard_size_oblivious(end_val < start_val):
end_val = start_val
elif statically_known_true(end_val == sys.maxsize) or guard_size_oblivious(
end_val > sizes[dim]
):
end_val = sizes[dim]
start_val = generalize_index(start_val)
end_val = generalize_index(end_val, end=True)
torch._check(end_val >= start_val)

storage_offset = self.storage_offset() + start_val * strides[dim]
len = end_val - start_val
Expand Down
5 changes: 5 additions & 0 deletions torch/fx/experimental/symbolic_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1284,6 +1284,10 @@
return _guard_or(a, True)


def guard_or_none(a: BoolLikeType) -> Optional[bool]:
    """Evaluate ``a`` statically, returning None when it cannot be decided.

    Unlike guard_or_true/guard_or_false, which pick a concrete fallback,
    this surfaces the "unknown" case so callers can branch on uncertainty
    (e.g. emit a sym_ite instead of committing to one side).
    """
    # None is the deliberate "could not evaluate" sentinel; _guard_or's
    # annotation only declares bool for the fallback, hence the ignore.
    return _guard_or(a, None)  # type: ignore[arg-type]

Check failure on line 1288 in torch/fx/experimental/symbolic_shapes.py

View workflow job for this annotation

GitHub Actions / lintrunner-noclang / linux-job

MYPY [arg-type]

Argument 2 to "_guard_or" has incompatible type "None"; expected "bool"


def _static_eval_sym_bool(x: SymBool) -> Optional[bool]:
assert isinstance(x, SymBool)
expr = x.node.expr
Expand Down Expand Up @@ -1869,6 +1873,7 @@
torch.utils._sympy.functions.ToFloat,
torch.utils._sympy.functions.TruncToInt,
torch.utils._sympy.functions.CeilToInt,
sympy.functions.elementary.piecewise.ExprCondPair,
)


Expand Down
12 changes: 9 additions & 3 deletions torch/fx/passes/shape_prop.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ def _extract_tensor_metadata(
"""
Extract a TensorMetadata NamedTuple describing `result`.
"""
from torch.fx.experimental.symbolic_shapes import GuardOnDataDependentSymNode

shape = result.shape
dtype = result.dtype
requires_grad = result.requires_grad
Expand All @@ -52,9 +54,13 @@ def _extract_tensor_metadata(
torch.channels_last_3d,
}
for query_format in memory_formats:
if result.is_contiguous(memory_format=query_format):
memory_format = query_format
break
try:
is_contig = result.is_contiguous(memory_format=query_format)
if is_contig:
memory_format = query_format
break
except GuardOnDataDependentSymNode:
continue

is_quantized = result.is_quantized
qparams: dict[str, Any] = {}
Expand Down
Loading
0