[inductor][dynamo] Include operator name in size/stride/alignment assertion · pytorch/pytorch@725bbb6 · GitHub

Commit 725bbb6

karthickai authored and pytorchmergebot committed
[inductor][dynamo] Include operator name in size/stride/alignment assertion (#152353)
Fixes #151930

This PR updates the `assert_size_stride` and `assert_alignment` functions in [guards.cpp](https://github.com/pytorch/pytorch/blob/main/torch/csrc/dynamo/guards.cpp) to accept an optional `op_name` argument and include it in their error messages. The corresponding type stubs in [guards.pyi](https://github.com/pytorch/pytorch/blob/main/torch/_C/_dynamo/guards.pyi) are updated to match the new signatures.

In [inductor/ir.py](https://github.com/pytorch/pytorch/blob/main/torch/_inductor/ir.py), the operator name is extracted from the FX graph and passed into `codegen_size_asserts` and `codegen_alignment_asserts`, so the assertions generated in Triton code include the op name for easier debugging.

Unit tests were added in [test_torchinductor.py](https://github.com/pytorch/pytorch/blob/main/test/inductor/test_torchinductor.py):
- Verified that both successful and failing assertion cases include the operator name.
- Verified that the generated Triton code contains the op name inside the asserts.

Pull Request resolved: #152353
Approved by: https://github.com/jansel
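For illustration, with the custom op used in the new unit tests (torch.ops.test.foo, buffer buf2 of shape (16, 32)), the wrapper code generated by Inductor changes roughly as follows. This is a sketch based on the FileCheck expectations in the tests below, not verbatim compiler output:

    # before this PR
    assert_size_stride(buf2, (16, 32), (32, 1))
    assert_alignment(buf2, 16)
    # after this PR: the offending op is named in the assert and in any failure message
    assert_size_stride(buf2, (16, 32), (32, 1), 'torch.ops.test.foo.default')
    assert_alignment(buf2, 16, 'torch.ops.test.foo.default')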
1 parent f5e0806 commit 725bbb6

File tree

7 files changed: 170 additions, 20 deletions

- benchmarks/dynamo/pr_time_benchmarks/expected_results.csv
- test/distributed/test_functional_api.py
- test/inductor/test_mkldnn_pattern_matcher.py
- test/inductor/test_torchinductor.py
- torch/_C/_dynamo/guards.pyi
- torch/_inductor/ir.py
- torch/csrc/dynamo/guards.cpp


benchmarks/dynamo/pr_time_benchmarks/expected_results.csv

Lines changed: 1 addition & 1 deletion
@@ -10,7 +10,7 @@ add_loop_inductor,compile_time_instruction_count,29370000000,0.015
-add_loop_inductor_dynamic_gpu,compile_time_instruction_count,44480000000,0.025
+add_loop_inductor_dynamic_gpu,compile_time_instruction_count,44130000000,0.025

test/distributed/test_functional_api.py

Lines changed: 7 additions & 0 deletions
@@ -715,6 +715,13 @@ def run_with_backward():
         _, codes = run_and_get_code(run_with_backward)
         for code in codes:
+            assert_keywords = ["assert_size_stride", "assert_alignment"]
+            filtered_lines = [
+                line
+                for line in code.splitlines()
+                if not any(assert_key in line for assert_key in assert_keywords)
+            ]
+            code = "\n".join(filtered_lines)
             FileCheck().check_count(
                 "_c10d_functional.all_to_all_single.default", 1, exactly=True
             ).check_count("_c10d_functional.wait_tensor.default", 1, exactly=True).run(
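Both this test and the mkldnn pattern-matcher test below strip the new assertion lines before running their checks, presumably because the op names now embedded in assert_size_stride/assert_alignment calls would otherwise match the operator strings being counted. A minimal standalone sketch of the idiom (the generated code string here is hypothetical):

    # Hypothetical stand-in for code returned by run_and_get_code.
    generated = "\n".join([
        "buf0 = torch.ops._c10d_functional.all_to_all_single.default(arg0)",
        "assert_size_stride(buf0, (16,), (1,), 'torch.ops._c10d_functional.all_to_all_single.default')",
    ])
    assert_keywords = ["assert_size_stride", "assert_alignment"]
    filtered = "\n".join(
        line for line in generated.splitlines()
        if not any(key in line for key in assert_keywords)
    )
    # After filtering, the op string appears exactly once, so exact-count checks still hold.
    assert filtered.count("all_to_all_single.default") == 1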

test/inductor/test_mkldnn_pattern_matcher.py

Lines changed: 8 additions & 0 deletions
@@ -231,6 +231,14 @@ def _test_code_common(
             torch.compile(mod, fullgraph=True, dynamic=check_dynamic),
             *clone_inputs,
         )
+        assert_keywords = ["assert_size_stride", "assert_alignment"]
+        filtered_lines = [
+            line
+            for line in source_code.splitlines()
+            if not any(assert_key in line for assert_key in assert_keywords)
+        ]
+        source_code = "\n".join(filtered_lines)
+
         for op in include_ops:
             self.assertIn(op, source_code)
         if num_include_ops is not None:

test/inductor/test_torchinductor.py

Lines changed: 85 additions & 7 deletions
@@ -30,6 +30,7 @@
 import torch._dynamo.config as dynamo_config
 import torch._inductor.aoti_eager
 import torch.nn as nn
+from torch._C._dynamo.guards import assert_alignment, assert_size_stride
 from torch._dispatch.python import enable_python_dispatcher
 from torch._dynamo.debug_utils import aot_graph_input_parser
 from torch._dynamo.device_interface import get_interface_for_device
@@ -1410,9 +1411,10 @@ def fn(a, b):
         )
         _, code = run_and_get_code(fn, x, y)
         code = " ".join(code)
-        self.assertEqual(
-            code.count("view_dtype" if config.cpp_wrapper else "aten.view"), 3
-        )
+        if config.cpp_wrapper:
+            self.assertEqual(code.count("view_dtype"), 3)
+        else:
+            self.assertEqual(code.count("aten.view"), 9)
 
     def test_add_complex5(self):
         def fn(a, b, alpha):
@@ -11882,6 +11884,82 @@ def fn(x):
             check_lowp=False,
         )
 
+    @requires_gpu()
+    @skip_if_not_triton
+    @skip_if_cpp_wrapper("skip cpp_wrapper tests")
+    @config.patch(implicit_fallbacks=True)
+    def test_generated_code_has_size_stride_assert(self):
+        def foo(x):
+            return 3 * x
+
+        def foo_meta(x):
+            return torch.empty_like(x)
+
+        define_custom_op_for_test("foo", foo, foo_meta)
+
+        def fn(x):
+            a = torch.nn.functional.relu(x)
+            b = torch.ops.test.foo(a)
+            return b
+
+        a = torch.randn((16, 32), device=self.device)
+
+        _, code = run_and_get_code(
+            torch.compile(fn),
+            a,
+        )
+        if not is_dynamic_shape_enabled():
+            FileCheck().check(
+                "assert_size_stride(buf2, (16, 32), (32, 1), 'torch.ops.test.foo.default')"
+            ).run(code[0])
+
+    @requires_gpu()
+    @skip_if_not_triton
+    @skip_if_cpp_wrapper("skip cpp_wrapper tests")
+    @config.patch(implicit_fallbacks=True)
+    def test_generated_code_has_alignment_assert(self):
+        def foo(x):
+            return 3 * x
+
+        def foo_meta(x):
+            return torch.empty_like(x)
+
+        define_custom_op_for_test("foo", foo, foo_meta)
+
+        def fn(x):
+            a = torch.nn.functional.relu(x)
+            b = torch.ops.test.foo(a)
+            return b
+
+        a = torch.randn((16, 32), device=self.device)
+
+        _, code = run_and_get_code(
+            torch.compile(fn),
+            a,
+        )
+        if not is_dynamic_shape_enabled():
+            FileCheck().check(
+                "assert_alignment(buf2, 16, 'torch.ops.test.foo.default')"
+            ).run(code[0])
+
+    def test_assert_size_stride_op_name_pass(self):
+        tensor = torch.empty((16, 32))
+        assert_size_stride(tensor, (16, 32), (32, 1), "torch.ops.dummy.op_name")
+
+    def test_assert_size_stride_op_name_fail(self):
+        tensor = torch.empty((16, 32))
+        with self.assertRaisesRegex(AssertionError, "torch.ops.dummy.op_name"):
+            assert_size_stride(tensor, (32, 64), (32, 1), "torch.ops.dummy.op_name")
+
+    def test_assert_alignment_op_name_pass(self):
+        tensor = torch.empty((16, 32))
+        assert_alignment(tensor, 16, "torch.ops.dummy.op_name")
+
+    def test_assert_alignment_op_name_fail(self):
+        tensor = torch.empty((16, 32))
+        with self.assertRaisesRegex(AssertionError, "torch.ops.dummy.op_name"):
+            assert_alignment(tensor, 0, "torch.ops.dummy.op_name")
+
     @torch._dynamo.config.patch(capture_dynamic_output_shape_ops=True)
     @torch._inductor.config.patch(implicit_fallbacks=True)
     def test_custom_op_unbacked_symints(self):
@@ -13014,12 +13092,12 @@ def f(x):
         code = run_and_get_triton_code(f, x)
 
         if is_dynamic_shape_enabled():
-            FileCheck().check("assert_size_stride(buf1, (s77, s27), (s27, 1))").check(
-                "assert_size_stride(buf2, (s77, s27), (s27, 1))"
+            FileCheck().check("assert_size_stride(buf1, (s77, s27), (s27, 1)").check(
+                "assert_size_stride(buf2, (s77, s27), (s27, 1)"
             ).run(code)
         else:
-            FileCheck().check("assert_size_stride(buf1, (16, 32), (32, 1))").check(
-                "assert_size_stride(buf2, (16, 32), (32, 1))"
+            FileCheck().check("assert_size_stride(buf1, (16, 32), (32, 1)").check(
+                "assert_size_stride(buf2, (16, 32), (32, 1)"
             ).run(code)
 
     @requires_cuda

torch/_C/_dynamo/guards.pyi

Lines changed: 6 additions & 0 deletions
@@ -176,6 +176,12 @@ def assert_size_stride(
     item: torch.Tensor,
     size: torch.types._size,
     stride: torch.types._size,
+    op_name: str | None = None,
+): ...
+def assert_alignment(
+    item: torch.Tensor,
+    alignment: int,
+    op_name: str | None = None,
 ): ...
 def check_obj_id(obj: object, expected: int) -> bool: ...
 def check_type_id(obj: object, expected: int) -> bool: ...
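A short sketch of how the updated stubs are exercised from Python: op_name is optional, so existing call sites are unaffected, and passing it simply enriches any failure message (names follow the unit tests above):

    import torch
    from torch._C._dynamo.guards import assert_alignment, assert_size_stride

    t = torch.empty((16, 32))
    assert_size_stride(t, (16, 32), (32, 1))                             # old call shape still works
    assert_size_stride(t, (16, 32), (32, 1), "torch.ops.dummy.op_name")  # passes, op name recorded for errors
    assert_alignment(t, 16, "torch.ops.dummy.op_name")                   # passes: storage_offset is 0 for a fresh tensor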

torch/_inductor/ir.py

Lines changed: 20 additions & 4 deletions
@@ -5772,26 +5772,42 @@ def codegen_kwargs(self, skip_out=False): # type: ignore[no-untyped-def]
         ]
         return kwargs
 
+    def get_op_name(self) -> str:
+        if self.fx_node is not None:
+            target = self.fx_node.target
+            op_namespace = getattr(target, "__module__", "unknown_namespace")
+            op_namespace = op_namespace.replace("._ops.", ".ops.")
+            op_namespace = op_namespace.rsplit(".", 1)[0]
+            op_name = f"{op_namespace}.{target}"
+        else:
+            op_name = "unknown_op"
+        return op_name
+
     def codegen_size_asserts(self, wrapper) -> None:  # type: ignore[no-untyped-def]
         if config.size_asserts and not V.graph.cpp_wrapper:
             # comparing strides for 0 size tensor is tricky. Ignore them for now.
             if sympy_product(self.get_size()) == 0:
                 return
             size = V.graph.wrapper_code.codegen_shape_tuple(self.get_size())
             stride = V.graph.wrapper_code.codegen_shape_tuple(self.get_stride())
-
+            op_name = self.get_op_name()
             wrapper.writeline(
-                f"assert_size_stride({self.get_name()}, {size}, {stride})"
+                f"assert_size_stride({self.get_name()}, {size}, {stride}, {op_name!r})"
             )
 
     def codegen_alignment_asserts(self, wrapper) -> None:  # type: ignore[no-untyped-def]
         if config.alignment_asserts and not V.graph.cpp_wrapper:
             name = self.get_name()
             aligned = name not in V.graph.unaligned_buffers
+            op_name = self.get_op_name()
             if aligned:
-                wrapper.writeline(f"assert_alignment({name}, {GPU_ALIGN_BYTES})")
+                wrapper.writeline(
+                    f"assert_alignment({name}, {GPU_ALIGN_BYTES}, {op_name!r})"
+                )
             else:
-                wrapper.writeline(f"# buffer {name} is assumed to be not aligned")
+                wrapper.writeline(
+                    f"# buffer {name} (op: {op_name}) is assumed to be not aligned"
+                )
 
     def get_group_stride(self):  # type: ignore[no-untyped-def]
         """

torch/csrc/dynamo/guards.cpp

Lines changed: 43 additions & 8 deletions
@@ -844,21 +844,38 @@ static PyObject* assert_size_stride(PyObject* dummy, PyObject* args) {
   PyObject* item = nullptr;
   PyObject* size = nullptr;
   PyObject* stride = nullptr;
-  if (!PyArg_ParseTuple(args, "OOO", &item, &size, &stride)) {
+  const char* op_name = nullptr;
+
+  if (!PyArg_ParseTuple(args, "OOO|s", &item, &size, &stride, &op_name)) {
     return nullptr;
   }
   if (!THPVariable_CheckExact(item) && !THPVariable_Check(item)) {
-    PyErr_SetString(PyExc_TypeError, "expected Tensor()");
+    std::stringstream msg;
+    msg << "expected Tensor()";
+    if (op_name) {
+      msg << " for op: " << op_name;
+    }
+    PyErr_SetString(PyExc_TypeError, msg.str().c_str());
     return nullptr;
   }
   if (!PyTuple_CheckExact(size) || !PyTuple_CheckExact(stride)) {
-    PyErr_SetString(PyExc_TypeError, "expected tuple()");
+    std::stringstream msg;
+    msg << "expected tuple()";
+    if (op_name) {
+      msg << " for op: " << op_name;
+    }
+    PyErr_SetString(PyExc_TypeError, msg.str().c_str());
     return nullptr;
   }
   at::Tensor tensor = THPVariable_Unpack(item);
   int64_t ndim = tensor.ndimension();
   if (PyTuple_GET_SIZE(size) != ndim || PyTuple_GET_SIZE(stride) != ndim) {
-    PyErr_SetString(PyExc_AssertionError, "wrong number of dimensions");
+    std::stringstream msg;
+    msg << "wrong number of dimensions" << ndim;
+    if (op_name) {
+      msg << " for op: " << op_name;
+    }
+    PyErr_SetString(PyExc_AssertionError, msg.str().c_str());
     return nullptr;
   }
 
@@ -887,6 +904,9 @@ static PyObject* assert_size_stride(PyObject* dummy, PyObject* args) {
   }
 
   if (num_errors) {
+    if (op_name) {
+      msg << "\nError in op: " << op_name;
+    }
     msg << "\nThis error most often comes from a incorrect fake (aka meta) kernel for a custom op.";
     msg << "\nUse torch.library.opcheck to test your custom op.";
     msg << "\nSee https://pytorch.org/docs/stable/library.html#torch.library.opcheck";
@@ -904,15 +924,27 @@ static PyObject* assert_alignment(PyObject* dummy, PyObject* args) {
   */
   PyObject* item = nullptr;
   unsigned long alignment = 0;
-  if (!PyArg_ParseTuple(args, "Ok", &item, &alignment)) {
+  const char* op_name = nullptr;
+
+  if (!PyArg_ParseTuple(args, "Ok|s", &item, &alignment, &op_name)) {
    return nullptr;
   }
   if (!THPVariable_CheckExact(item) && !THPVariable_Check(item)) {
-    PyErr_SetString(PyExc_TypeError, "expected Tensor()");
+    std::stringstream msg;
+    msg << "expected Tensor()";
+    if (op_name) {
+      msg << " for op: " << op_name;
+    }
+    PyErr_SetString(PyExc_TypeError, msg.str().c_str());
     return nullptr;
   }
   if (alignment == 0) {
-    PyErr_SetString(PyExc_AssertionError, "alignment can not be 0");
+    std::stringstream msg;
+    msg << "alignment cannot be 0";
+    if (op_name) {
+      msg << " in op: " << op_name;
+    }
+    PyErr_SetString(PyExc_AssertionError, msg.str().c_str());
     return nullptr;
   }
 
@@ -922,7 +954,10 @@ static PyObject* assert_alignment(PyObject* dummy, PyObject* args) {
   size_t itemsize = tensor.itemsize();
   if (storage_offset * itemsize % alignment != 0) {
     std::stringstream msg;
-    msg << "Expect the tensor to be " << alignment
+    if (op_name) {
+      msg << "\nError in op: " << op_name;
+    }
+    msg << "\nExpect the tensor to be " << alignment
         << " bytes aligned. Fail due to storage_offset=" << storage_offset
         << " itemsize=" << itemsize;
     PyErr_SetString(PyExc_AssertionError, msg.str().c_str());
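On the failure path, the op name is appended to the assertion message alongside the existing fake-kernel troubleshooting hints. A Python-level sketch of the resulting behavior, mirroring test_assert_size_stride_op_name_fail above:

    import torch
    from torch._C._dynamo.guards import assert_size_stride

    t = torch.empty((16, 32))
    try:
        assert_size_stride(t, (32, 64), (32, 1), "torch.ops.dummy.op_name")  # wrong size on purpose
    except AssertionError as e:
        # The message now names the offending op in addition to the opcheck hints.
        assert "Error in op: torch.ops.dummy.op_name" in str(e)
        assert "torch.library.opcheck" in str(e)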

0 commit comments