Inductor logging + analysis of torch.profile · pytorch/pytorch@57716b8 · GitHub

Commit 57716b8

Inductor logging + analysis of torch.profile
1 parent e9e1aac commit 57716b8

19 files changed: +1883 −55 lines

test/inductor/test_analysis.py

+704
Large diffs are not rendered by default.
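
The 704-line test_analysis.py diff is collapsed above. As background for the commit title, the sketch below (not taken from the collapsed file; the model and output path are illustrative assumptions) shows how a torch.profiler trace, the "torch.profile" artifact that the analysis consumes, is typically produced and exported:

import torch
from torch.profiler import profile, ProfilerActivity

# Profile a small CPU-only workload so the sketch runs anywhere.
model = torch.nn.Linear(512, 512)
x = torch.randn(64, 512)

with profile(activities=[ProfilerActivity.CPU]) as prof:
    model(x)

# The chrome-trace JSON written here is the kind of input that
# performance-analysis scripts can parse offline.
prof.export_chrome_trace("/tmp/example_trace.json")
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=5))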
+87
@@ -0,0 +1,87 @@
# Owner(s): ["module: inductor"]

import torch
import torch.utils.flop_counter
from torch._inductor.debug import DebugContext
from torch._inductor.graph import GraphLowering
from torch._inductor.virtualized import V
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing._internal.common_cuda import SM70OrLater
from torch.testing._internal.common_device_type import (
    dtypes,
    instantiate_device_type_tests,
    skipCUDAIf,
)
from torch.testing._internal.common_utils import run_tests, TestCase


def FlopCounterMode(*args, **kwargs):
    return torch.utils.flop_counter.FlopCounterMode(*args, **kwargs, display=False)


def get_total_flops(mode):
    return sum(v for _, v in mode.flop_counts["Global"].items())


def random_tensor(size, dtype, **kwargs):
    if dtype in [torch.half, torch.bfloat16, torch.float, torch.double]:
        return torch.randn(size, dtype=dtype, **kwargs)
    elif dtype in [torch.uint8, torch.int8, torch.short, torch.int, torch.long]:
        return torch.randint(0, 100, size, dtype=dtype, **kwargs)
    else:
        raise ValueError("Unsupported data type")


def cT(device, dtype):
    def T(*shape, requires_grad=False):
        return random_tensor(
            shape, requires_grad=requires_grad, device=device, dtype=dtype
        )

    return T


class TestScheduler(TestCase):
    @dtypes(torch.float, torch.double)
    @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
    def test_flop_counter_op(self, device, dtype):
        T = cT(device, dtype)

        def composite(x, y, z):
            tmp = torch.mm(x + 10, y / 12)
            return torch.mm(tmp, z)

        def composite_relu(x, y):
            tmp = torch.mm(x, y)
            return torch.relu(tmp)

        test_cases = [
            (torch.mm, [T(4, 5), T(5, 6)], {}),
            (torch.add, [T(4, 5), T(4, 5)], {}),
            (composite, [T(5, 4), T(4, 3), T(3, 12)], {}),
            (composite_relu, [T(5, 4), T(4, 3)], {}),
        ]
        for op, example_inputs, kwargs in test_cases:
            comp = torch.compile(op)
            with FlopCounterMode() as mode:
                comp(*example_inputs, **kwargs)
            gm = make_fx(op)(*example_inputs, **kwargs)
            reference_flops = get_total_flops(mode)

            graph = GraphLowering(gm)

            with V.set_graph_handler(graph), V.set_debug_handler(DebugContext()):
                graph.run(*example_inputs, **kwargs)
                graph.init_wrapper_code()
                graph._update_scheduler()
                scheduler_flops = 0
                for node in graph.scheduler.nodes:
                    flops = node.estimate_flops()
                    scheduler_flops += flops if flops is not None else 0
            self.assertEqual(reference_flops, scheduler_flops, msg=f"op = {op}")


instantiate_device_type_tests(TestScheduler, globals())

if __name__ == "__main__":
    run_tests()
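
The test above compares the Inductor scheduler's estimate_flops() against FlopCounterMode. As a quick sanity check of the reference convention assumed here (PyTorch's flop counter attributes 2 * M * K * N FLOPs to an (M, K) @ (K, N) matmul), this minimal CPU-only sketch reproduces the first test case by hand:

import torch
from torch.utils.flop_counter import FlopCounterMode

# The (4, 5) @ (5, 6) matmul from test_cases should report 2 * 4 * 5 * 6 = 240.
with FlopCounterMode(display=False) as mode:
    torch.mm(torch.randn(4, 5), torch.randn(5, 6))

total = sum(mode.flop_counts["Global"].values())
assert total == 240, total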

test/inductor/test_utils.py

+107
@@ -3,8 +3,10 @@
 from sympy import Symbol, sympify
 
 import torch
+from torch._inductor.fx_utils import count_flops_fx, countable_fx
 from torch._inductor.test_case import run_tests, TestCase
 from torch._inductor.utils import sympy_str, sympy_subs
+from torch._inductor.virtualized import V
 
 
 class TestUtils(TestCase):
@@ -81,6 +83,111 @@ def test_sympy_str(self):
         self.assertEqual(sympy_str(sympify("a-b")), "a - b")
         self.assertEqual(sympy_str(sympify("a+-b")), "a - b")
 
+    def test_flops_fx(self):
+        def create_fx_node(
+            aten: torch._ops.OpOverloadPacket, args, kwargs
+        ) -> tuple[torch.fx.Node, torch.fx.Node]:
+            node1 = torch.fx.Node(
+                graph=torch.fx.Graph(),
+                name="",
+                op="call_function",
+                target=aten,
+                args=args,
+                kwargs=kwargs,
+            )
+            name: str = aten.overloads()[0]
+            op_overload: torch._ops.OpOverload = getattr(aten, name)
+            node2 = torch.fx.Node(
+                graph=torch.fx.Graph(),
+                name="",
+                op="call_function",
+                target=op_overload,
+                args=args,
+                kwargs=kwargs,
+            )
+            return node1, node2
+
+        with V.set_fake_mode(
+            torch._subclasses.FakeTensorMode(allow_non_fake_inputs=True)
+        ):
+            trues = [
+                (
+                    torch.ops.aten.addmm,
+                    (torch.Tensor(4, 4), torch.Tensor(4, 5), torch.Tensor(5, 4)),
+                    {},
+                ),
+                (
+                    torch.ops.aten.bmm,
+                    (torch.Tensor(10, 4, 5), torch.Tensor(10, 5, 4)),
+                    {},
+                ),
+                (torch.ops.aten.mm, (torch.Tensor(2, 3), torch.Tensor(3, 2)), {}),
+                (
+                    torch.ops.aten.convolution,
+                    (
+                        torch.Tensor(2, 3, 3),
+                        torch.Tensor(2, 2, 2),
+                        torch.Tensor(2),
+                        (1, 1),
+                        (0, 0),
+                        (1, 1),
+                        True,
+                        (0, 0),
+                        1,
+                    ),
+                    {},
+                ),
+                (
+                    torch.ops.aten._convolution,
+                    (
+                        torch.Tensor(2, 2, 2),
+                        torch.Tensor(2, 2, 2),
+                        torch.Tensor(2),
+                        (1,),
+                        (0,),
+                        (1,),
+                        True,
+                        (0,),
+                        1,
+                        False,
+                        True,
+                        False,
+                    ),
+                    {},
+                ),
+            ]
+            # we don't support pointwise ops
+            falses = [
+                (
+                    torch.ops.aten.add,
+                    (torch.Tensor(1, 2, 3), torch.Tensor(1, 2, 3)),
+                    {},
+                ),
+                (
+                    torch.ops.aten.mul,
+                    (torch.Tensor(1, 2, 3), torch.Tensor(1, 2, 3)),
+                    {},
+                ),
+            ]
+            for t, args, kwargs in trues:
+                fx_node_1, fx_node_2 = create_fx_node(t, args, kwargs)
+                self.assertTrue(
+                    countable_fx(fx_node_1), f"Expected true {t}: {fx_node_1}"
+                )
+                self.assertTrue(
+                    countable_fx(fx_node_2), f"Expected true {t}: {fx_node_2}"
+                )
+                self.assertNotEqual(count_flops_fx(fx_node_1), None)
+                self.assertNotEqual(count_flops_fx(fx_node_2), None)
+            for f, args, kwargs in falses:
+                fx_node_1, fx_node_2 = create_fx_node(f, args, kwargs)
+                self.assertFalse(
+                    countable_fx(fx_node_1), f"Expected false {f}: {fx_node_1}"
+                )
+                self.assertFalse(
+                    countable_fx(fx_node_2), f"Expected false {f}: {fx_node_2}"
+                )
+
 
 if __name__ == "__main__":
     run_tests()

test/profiler/test_profiler.py

+59
@@ -27,6 +27,7 @@
 import torch.optim
 import torch.utils.data
 from torch._C._profiler import _ExperimentalConfig, _ExtraFields_PyCall
+from torch._inductor.ir import FixedLayout
 from torch.autograd.profiler import KinetoStepTracker, profile as _profile
 from torch.autograd.profiler_legacy import profile as _profile_legacy
 from torch.profiler import (
@@ -2998,6 +2999,64 @@ def validate_json(prof):
         assert "Overload Name" in key_averages.table()
         validate_json(prof)
 
+    @unittest.skipIf(not torch.cuda.is_available(), "CUDA is required")
+    # this checks whether a Triton-only backend can be used for max autotune
+    @unittest.skipIf(
+        torch.cuda.is_available()
+        and not torch._inductor.utils.use_triton_template(
+            FixedLayout(torch.device("cuda"), torch.float16, [400, 800])
+        ),
+        "Solo triton backend not possible",
+    )
+    def test_profiler_debug_autotuner(self):
+        """
+        This test makes sure that profiling events are present when the kernel is run using the DebugAutotuner.
+        """
+        in1 = torch.randn((400, 600), device="cuda", dtype=torch.float16)
+        in2 = torch.randn((600, 800), device="cuda", dtype=torch.float16)
+
+        def mm():
+            return torch.mm(in1, in2)
+
+        pb_mm = torch.compile(
+            mm,
+            options={
+                "benchmark_kernel": True,
+                "max_autotune": True,
+                "max_autotune_gemm_backends": "TRITON",
+                "profile_bandwidth": True,
+            },
+        )
+        comp_mm = torch.compile(
+            mm,
+            options={
+                "benchmark_kernel": True,
+                "max_autotune": True,
+                "max_autotune_gemm_backends": "TRITON",
+            },
+        )
+
+        with profile() as prof1:
+            pb_mm()
+        with profile() as prof2:
+            comp_mm()
+
+        def names(prof):
+            return {
+                ev.name
+                for ev in prof.events()
+                if "mm" in ev.name or "triton" in ev.name
+            }
+
+        trace1 = "/tmp/trace1_pb.json"
+        trace2 = "/tmp/trace2_nopb.json"
+        prof1.export_chrome_trace(trace1)
+        prof2.export_chrome_trace(trace2)
+
+        n1 = names(prof1)
+        n2 = names(prof2)
+        self.assertEqual(n1, n2)
+
 
 if __name__ == "__main__":
     run_tests()
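
Because the new test writes both chrome traces to /tmp, the same comparison can also be done offline. A hedged sketch (assuming the trace paths above, and that the exported files use the standard Kineto chrome-trace layout with a top-level "traceEvents" list):

import json


def kernel_names(trace_path):
    # Collect event names that look like matmul or Triton kernels,
    # mirroring the filter used by names() in the test above.
    with open(trace_path) as f:
        events = json.load(f)["traceEvents"]
    return {
        e.get("name", "")
        for e in events
        if "mm" in e.get("name", "") or "triton" in e.get("name", "")
    }


n1 = kernel_names("/tmp/trace1_pb.json")
n2 = kernel_names("/tmp/trace2_nopb.json")
assert n1 == n2, n1.symmetric_difference(n2)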

test/test_flop_counter.py

+1
@@ -854,5 +854,6 @@ def test_scaled_mm(self):
 
         self.assertExpectedInline(get_total_flops(mode), """860160""")
 
+
 if __name__ == "__main__":
     run_tests()

torch/_inductor/analysis/README.md

+2
@@ -0,0 +1,2 @@
# `torch._inductor.analysis`
Contains scripts for inductor performance analysis.

torch/_inductor/analysis/__init__.py

Whitespace-only changes.

0 commit comments
