8000 [Inductor] Add NaN assert to returned values from generated code (#15… · pytorch/pytorch@aec3ef1 · GitHub
[go: up one dir, main page]

Skip to content

Commit aec3ef1

Browse files
PaulZhang12pytorchmergebot
authored andcommitted
[Inductor] Add NaN assert to returned values from generated code (#154455)
Summary: It is possible to have `reinterpret_tensor` in the output of inductor codegen, e.g. `reinterpret_tensor(buf366, (1024, ), (1, ), 0)` in the return tuple. This adds assertions to all return values from inductor codegen to prevent nans from slipping through and being hard to trace. Test Plan: NaN asserts properly generated in example gemm script: vars = (buf1, primals_2, buf2, primals_1, ) for var in vars: if isinstance(var, torch.Tensor): assert not var.isnan().any().item() assert not var.isinf().any().item() Pull Request resolved: #154455 Approved by: https://github.com/eellison
1 parent dc82e91 commit aec3ef1

File tree

2 files changed

+15
-0
lines changed

2 files changed

+15
-0
lines changed

test/inductor/test_torchinductor.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15039,6 +15039,9 @@ def f(x):
1503915039
self.assertIn("aoti_torch_check_inf_and_nan", code)
1504015040
else:
1504115041
self.assertIn("# make sure graph inputs are not nan/inf", code)
15042+
self.assertRegex(code, r"return_vars = (.*)")
15043+
self.assertIn("for var in return_vars:", code)
15044+
self.assertIn("if isinstance(var, torch.Tensor):", code)
1504215045
self.assertRegex(code, r"assert not .*\.isnan\(\)\.any\(\).item\(\)")
1504315046
self.assertRegex(code, r"assert not .*\.isinf\(\)\.any\(\).item\(\)")
1504415047

torch/_inductor/codegen/wrapper.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1260,6 +1260,18 @@ def codegen_device_guard_exit(self) -> None:
12601260

12611261
def generate_return(self, output_refs: list[str]) -> None:
12621262
if output_refs:
1263+
if config.nan_asserts:
1264+
self.wrapper_call.writeline(
1265+
"return_vars = (" + ", ".join(output_refs) + ", )"
1266+
)
1267+
self.wrapper_call.writeline("for var in return_vars:")
1268+
self.wrapper_call.do_indent()
1269+
self.wrapper_call.writeline("if isinstance(var, torch.Tensor):")
1270+
self.wrapper_call.do_indent()
1271+
self.wrapper_call.writeline("assert not var.isnan().any().item()")
1272+
self.wrapper_call.writeline("assert not var.isinf().any().item()")
1273+
self.wrapper_call.do_unindent(2)
1274+
12631275
self.wrapper_call.writeline("return (" + ", ".join(output_refs) + ", )")
12641276
else:
12651277
self.wrapper_call.writeline("return ()")

0 commit comments

Comments
 (0)
0