Fix bug in AOTI lowering · pytorch/pytorch@002dc27 · GitHub
[go: up one dir, main page]

Skip to content

Commit 002dc27

Browse files
Fix bug in AOTI lowering
ghstack-source-id: 1c81ea4 Pull Request resolved: #148364
1 parent d260d4f commit 002dc27

File tree

2 files changed

+28
-0
lines changed

2 files changed

+28
-0
lines changed

test/inductor/test_aot_inductor.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1161,6 +1161,18 @@ def forward(self, q, k, v):
11611161
path = torch._inductor.aot_compile(ep.module(), inputs)
11621162
aot_model = torch._export.aot_load(path, device=self.device)
11631163
torch.testing.assert_close(m(*inputs), aot_model(*inputs))
1164+
1165+
def test_aoti_constant_tensor(self):
1166+
class Foo(torch.nn.Module):
1167+
def __init__(self):
1168+
super().__init__()
1169+
self.a = torch.ones(4, 4)
1170+
self.b = torch.ones(4, 4)
1171+
def forward(self, x):
1172+
return torch.ops.aten.linear.default(x, self.a, self.b)
1173+
1174+
ep = torch.export.export(Foo(), (torch.ones(4, 4),), strict=False).run_decompositions({})
1175+
_ = torch._inductor.aoti_compile_and_package(ep)
11641176

11651177
def test_large_grid(self):
11661178
if self.device != GPU_TYPE:

torch/_inductor/compile_fx.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2184,6 +2184,22 @@ def bw_compiler(
21842184
trace_joint=False,
21852185
decompositions=decompositions,
21862186
)
2187+
2188+
from torch._export.utils import _detect_fake_mode_from_gm
2189+
fake_mode = _detect_fake_mode_from_gm(gm)
2190+
# aot_export_module doesn't account for constant tensor attributes
2191+
# so we end up having tensors that don't have fake vals attached.
2192+
# This can happen when upstream export is non-strict where we
2193+
# preserve the original module params/buffers. Once AOTI switches
2194+
# to ep.run_decompositions() flow to lower to post-autograd opset
2195+
# this will go away.
2196+
for node in gm.graph.nodes:
2197+
if node.op == "get_attr" and "val" not in node.meta:
2198+
target = getattr(gm, node.target)
2199+
if isinstance(target, torch.Tensor):
2200+
node.meta["val"] = fake_mode.from_tensor(target)
2201+
2202+
21872203
unlifted_gm = _unlift_graph(model_, gm, graph_signature)
21882204
if "dynamo_flat_name_to_original_fqn" in model_.meta:
21892205
unlifted_gm.meta["dynamo_flat_name_to_original_fqn"] = model_.meta[

0 commit comments

Comments (0)