Commit e15f913 · pytorch/pytorch

yushangdi authored and pytorchmergebot committed
[inductor] Add unbacked symints binding in ShapeProp (#144605)
Summary: ShapeProp doesn't know how to propagate unbacked symints. Patch it up so that it propagates them the same way PropagateUnbackedSymInts does.

Test Plan:
```
buck run mode/dev-nosan fbcode//caffe2/test:fx -- -r test_shape_prop_unbacked_sym
```

Differential Revision: D68050073

Pull Request resolved: #144605

Approved by: https://github.com/guowentian, https://github.com/pianpwk
1 parent 3c55669 · commit e15f913
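For context (an illustrative sketch, not part of the commit): "unbacked" symints are symbolic sizes that cannot be computed from the inputs alone, such as the row count of `torch.nonzero(x)`. The snippet below mirrors the test added in this PR and shows the flow the patch is meant to fix; the module `M` and the example input are taken from that test.

```python
import torch
from torch._dynamo.utils import detect_fake_mode
from torch.fx.passes import shape_prop


class M(torch.nn.Module):
    def forward(self, x: torch.Tensor):
        # nonzero has a data-dependent output size, so its fake result
        # carries an unbacked symint rather than a concrete length.
        return torch.nonzero(x)


# Export to an FX GraphModule; placeholder nodes carry fake "val" metadata.
gm = torch.export.export(M(), (torch.tensor([1, 0, 1, 0]),)).module()
fake_inputs = [
    node.meta.get("val") for node in gm.graph.nodes if node.op == "placeholder"
]
fake_mode = detect_fake_mode(fake_inputs)

# With this patch, propagation rebinds the fresh unbacked symbols and records
# their bindings in node.meta instead of leaving them pending on the shape env.
shape_prop.ShapeProp(gm=gm, fake_mode=fake_mode).propagate(*fake_inputs)
assert len(fake_mode.shape_env.pending_fresh_unbacked_symbols) == 0
```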

File tree: 2 files changed (+31, -0 lines)


test/test_fx.py

Lines changed: 19 additions & 0 deletions
```diff
@@ -1793,6 +1793,25 @@ def forward(self, x):
             if node.op in {'placeholder'}:
                 self.assertEqual(node.meta['tensor_meta'].memory_format, torch.channels_last_3d)
 
+    def test_shape_prop_unbacked_sym(self):
+        from torch._dynamo.utils import detect_fake_mode
+
+        class M(torch.nn.Module):
+            def forward(self, x: torch.Tensor):
+                return torch.nonzero(x)
+
+        inp = (torch.tensor([1, 0, 1, 0]),)
+        gm = torch.export.export(M(), inp).module()
+        fake_inputs = [
+            node.meta.get("val")
+            for node in gm.graph.nodes
+            if node.op == "placeholder"
+        ]
+        inp = fake_inputs
+        fake_mode = detect_fake_mode(inp)
+        shape_prop.ShapeProp(gm=gm, fake_mode=fake_mode).propagate(*inp)
+        self.assertEqual(len(fake_mode.shape_env.pending_fresh_unbacked_symbols), 0)
+
     def test_nn_module_stack(self):
         class SubModule(torch.nn.Module):
             def __init__(self) -> None:
```
torch/fx/passes/shape_prop.py

Lines changed: 12 additions & 0 deletions
```diff
@@ -154,6 +154,11 @@ def __init__(self, gm, fake_mode=None):
         self.real_module = self.module
 
     def run_node(self, n: Node) -> Any:
+        from torch.fx.experimental.symbolic_shapes import (
+            compute_unbacked_bindings,
+            rebind_unbacked,
+        )
+
         try:
             if self.fake_module is not None:
                 # Hacky swap. Alternatively, we could do this with overriding
@@ -163,6 +168,7 @@ def run_node(self, n: Node) -> Any:
             if self.fake_mode is not None:
                 with self.fake_mode, enable_python_dispatcher():
                     result = super().run_node(n)
+                    rebind_unbacked(self.fake_mode.shape_env, n, result)
             else:
                 result = super().run_node(n)
         finally:
@@ -187,6 +193,12 @@ def extract_tensor_meta(obj):
         if found_tensor:
             n.meta["tensor_meta"] = meta
 
+        if self.fake_mode:
+            if (shape_env := self.fake_mode.shape_env) and (
+                symbol_to_path := compute_unbacked_bindings(shape_env, result)
+            ):
+                n.meta["unbacked_bindings"] = symbol_to_path
+
         n.meta["type"] = type(result)
         return result
 
```
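As a follow-up illustration (not part of the commit): after `propagate` runs, any node whose fake result introduced fresh unbacked symbols should now carry the computed bindings in `node.meta["unbacked_bindings"]`, which downstream passes can consult. A minimal inspection sketch, reusing the `gm` and propagation step from the example above:

```python
# Hedged sketch: list the unbacked bindings ShapeProp recorded on each node.
# Assumes `gm` has already been propagated with a fake mode as shown earlier.
for node in gm.graph.nodes:
    bindings = node.meta.get("unbacked_bindings")
    if bindings:
        # Maps each fresh unbacked symbol (e.g. u0) to the path within the
        # node's result where it occurs, as computed by compute_unbacked_bindings.
        print(node.name, bindings)
```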
