Scheduler Flops refactor (#152708) · pytorch/pytorch@da0b89b · GitHub

Commit da0b89b

exclamaforte authored and pytorchmergebot committed
Scheduler Flops refactor (#152708)
This refactors `estimate_flops` and `get_estimated_runtime` on scheduler nodes:

1. New function on `BaseSchedulerNode`: `estimate_flops`. It now works with all types of IR nodes, not just `ExternalKernels`.
2. Extends `get_estimated_runtime` to work with non-`ExternalKernels`.

Prelude to: #149697

Testing: new unit tests cover the functionality.

Pull Request resolved: #152708
Approved by: https://github.com/xmfan, https://github.com/eellison
1 parent 073b025 commit da0b89b
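
For orientation before the diff: the message above names the two scheduler-level entry points, but the hunks below only show the supporting FX helpers and their tests. The following is a minimal sketch of how the refactored API might be consumed; the exact signatures (`estimate_flops()` returning `Optional[int]`, `get_estimated_runtime()` returning a float) are assumptions inferred from the commit message, and `summarize_node` is a hypothetical helper, not part of this PR.

from typing import Optional

def summarize_node(node) -> str:
    # `node` is assumed to be a BaseSchedulerNode. Per the commit message,
    # estimate_flops now works for all IR node types, not just ExternalKernels.
    flops: Optional[int] = node.estimate_flops()
    runtime = node.get_estimated_runtime()  # likewise extended to non-ExternalKernels
    if flops is None:
        return f"{node.get_name()}: runtime ~{runtime:.3g} (flops unknown)"
    return f"{node.get_name()}: {flops} flops, runtime ~{runtime:.3g}"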

File tree

4 files changed: +318 additions, -61 deletions
Lines changed: 112 additions & 0 deletions
@@ -0,0 +1,112 @@
# Owner(s): ["module: inductor"]


import torch
import torch.utils.flop_counter
from torch._dynamo.utils import counters
from torch._inductor.ir import FixedLayout
from torch._inductor.utils import fresh_inductor_cache
from torch.testing._internal.common_cuda import SM70OrLater
from torch.testing._internal.common_device_type import (
    dtypes,
    instantiate_device_type_tests,
    skipCUDAIf,
)
from torch.testing._internal.common_utils import parametrize, run_tests, TestCase


def FlopCounterMode(*args, **kwargs):
    return torch.utils.flop_counter.FlopCounterMode(*args, **kwargs, display=False)


def get_total_flops(mode):
    return sum(v for _, v in mode.flop_counts["Global"].items())


def random_tensor(size, dtype, **kwargs):
    if dtype in [torch.half, torch.bfloat16, torch.float, torch.double]:
        return torch.randn(size, dtype=dtype, **kwargs)
    elif dtype in [torch.uint8, torch.int8, torch.short, torch.int, torch.long]:
        return torch.randint(0, 100, size, dtype=dtype, **kwargs)
    else:
        raise ValueError("Unsupported data type")


def cT(device, dtype):
    def T(*shape, requires_grad=False):
        return random_tensor(
            shape, requires_grad=requires_grad, device=device, dtype=dtype
        )

    return T


class TestScheduler(TestCase):
    @dtypes(torch.float, torch.float16)
    @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
    @parametrize(
        "options",
        [
            {
                "max_autotune": True,
                "max_autotune_gemm_backends": "TRITON",
                "force_disable_caches": True,
            },
            {
                "max_autotune": True,
                "max_autotune_gemm_backends": "TRITON,ATEN",
                "force_disable_caches": True,
            },
        ],
    )
    def test_flop_counter_op(self, device, dtype, options):
        if device == "cpu":
            return
        if (
            options["max_autotune_gemm_backends"] == "TRITON"
            and torch.cuda.is_available()
            and not torch._inductor.utils.use_triton_template(
                FixedLayout(torch.device("cuda"), torch.float16, [400, 800])
            )
        ):
            return
        T = cT(device, dtype)

        def composite(x, y, z):
            tmp = torch.mm(x + 10, y / 12)
            return torch.mm(tmp, z)

        def composite_relu(x, y):
            tmp = torch.mm(x, y)
            return torch.relu(tmp)

        test_cases = [
            (torch.mm, [T(4, 5), T(5, 6)], {}),
            (torch.add, [T(4, 5), T(4, 5)], {}),
            (composite, [T(5, 4), T(4, 3), T(3, 12)], {}),
            (composite_relu, [T(5, 4), T(4, 3)], {}),
        ]
        for op, example_inputs, kwargs in test_cases:
            comp = torch.compile(op, options=options)
            # The next two lines are required; otherwise the flops will be
            # cached from previous runs of this function.
            torch._dynamo.reset()
            with fresh_inductor_cache():
                # Actually run to set the counters.
                comp(*example_inputs, **kwargs)
                with FlopCounterMode() as mode:
                    comp(*example_inputs, **kwargs)
            reference_flops = get_total_flops(mode)

            self.assertEqual(
                reference_flops,
                counters["inductor"]["flop_count"],
                msg=f"op = {op} reference flops = {reference_flops} != counters {counters['inductor']['flop_count']}",
            )
            if op != torch.add:
                self.assertNotEqual(reference_flops, 0, msg=f"op = {op} is 0 flops")
            counters["inductor"]["flop_count"] = 0


instantiate_device_type_tests(TestScheduler, globals())

if __name__ == "__main__":
    run_tests()
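
For reference, the numbers these assertions compare against follow the standard matmul formula: an [m, k] @ [k, n] matmul does one multiply and one add per accumulated output element, i.e. 2*m*k*n flops, while pointwise ops such as torch.add have no flop formula registered and count as 0 (hence the `op != torch.add` carve-out above). A quick sanity check of the test cases:

def mm_flops(m: int, k: int, n: int) -> int:
    # One multiply + one add for each of the m*n outputs, over the shared dim k.
    return 2 * m * k * n

assert mm_flops(4, 5, 6) == 240                       # the torch.mm case
assert mm_flops(5, 4, 3) + mm_flops(5, 3, 12) == 480  # composite: two chained mms
assert mm_flops(5, 4, 3) == 120                       # composite_relu: relu adds 0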

test/inductor/test_utils.py

Lines changed: 107 additions & 0 deletions
@@ -3,8 +3,10 @@
 from sympy import Symbol, sympify
 
 import torch
+from torch._inductor.fx_utils import count_flops_fx, countable_fx
 from torch._inductor.test_case import run_tests, TestCase
 from torch._inductor.utils import sympy_str, sympy_subs
+from torch._inductor.virtualized import V
 
 
 class TestUtils(TestCase):
@@ -81,6 +83,111 @@ def test_sympy_str(self):
         self.assertEqual(sympy_str(sympify("a-b")), "a - b")
         self.assertEqual(sympy_str(sympify("a+-b")), "a - b")
 
+    def test_flops_fx(self):
+        def create_fx_node(
+            aten: torch._ops.OpOverloadPacket, args, kwargs
+        ) -> tuple[torch.fx.Node, torch.fx.Node]:
+            node1 = torch.fx.Node(
+                graph=torch.fx.Graph(),
+                name="",
+                op="call_function",
+                target=aten,
+                args=args,
+                kwargs=kwargs,
+            )
+            name: str = aten.overloads()[0]
+            op_overload: torch._ops.OpOverload = getattr(aten, name)
+            node2 = torch.fx.Node(
+                graph=torch.fx.Graph(),
+                name="",
+                op="call_function",
+                target=op_overload,
+                args=args,
+                kwargs=kwargs,
+            )
+            return node1, node2
+
+        with V.set_fake_mode(
+            torch._subclasses.FakeTensorMode(allow_non_fake_inputs=True)
+        ):
+            trues = [
+                (
+                    torch.ops.aten.addmm,
+                    (torch.Tensor(4, 4), torch.Tensor(4, 5), torch.Tensor(5, 4)),
+                    {},
+                ),
+                (
+                    torch.ops.aten.bmm,
+                    (torch.Tensor(10, 4, 5), torch.Tensor(10, 5, 4)),
+                    {},
+                ),
+                (torch.ops.aten.mm, (torch.Tensor(2, 3), torch.Tensor(3, 2)), {}),
+                (
+                    torch.ops.aten.convolution,
+                    (
+                        torch.Tensor(2, 3, 3),
+                        torch.Tensor(2, 2, 2),
+                        torch.Tensor(2),
+                        (1, 1),
+                        (0, 0),
+                        (1, 1),
+                        True,
+                        (0, 0),
+                        1,
+                    ),
+                    {},
+                ),
+                (
+                    torch.ops.aten._convolution,
+                    (
+                        torch.Tensor(2, 2, 2),
+                        torch.Tensor(2, 2, 2),
+                        torch.Tensor(2),
+                        (1,),
+                        (0,),
+                        (1,),
+                        True,
+                        (0,),
+                        1,
+                        False,
+                        True,
+                        False,
+                    ),
+                    {},
+                ),
+            ]
+            # we don't support pointwise ops
+            falses = [
+                (
+                    torch.ops.aten.add,
+                    (torch.Tensor(1, 2, 3), torch.Tensor(1, 2, 3)),
+                    {},
+                ),
+                (
+                    torch.ops.aten.mul,
+                    (torch.Tensor(1, 2, 3), torch.Tensor(1, 2, 3)),
+                    {},
+                ),
+            ]
+            for t, args, kwargs in trues:
+                fx_node_1, fx_node_2 = create_fx_node(t, args, kwargs)
+                self.assertTrue(
+                    countable_fx(fx_node_1), f"Expected true {t}: {fx_node_1}"
+                )
+                self.assertTrue(
+                    countable_fx(fx_node_2), f"Expected true {t}: {fx_node_2}"
+                )
+                self.assertNotEqual(count_flops_fx(fx_node_1), None)
+                self.assertNotEqual(count_flops_fx(fx_node_2), None)
+            for f, args, kwargs in falses:
+                fx_node_1, fx_node_2 = create_fx_node(f, args, kwargs)
+                self.assertFalse(
+                    countable_fx(fx_node_1), f"Expected false {f}: {fx_node_1}"
+                )
+                self.assertFalse(
+                    countable_fx(fx_node_2), f"Expected false {f}: {fx_node_2}"
+                )
 
 if __name__ == "__main__":
     run_tests()
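
The trues/falses split in this test mirrors `torch.utils.flop_counter.flop_registry`, the table that `countable_fx` consults: it maps op overload packets for matmul- and convolution-family ops (plus attention variants) to flop formulas, and pointwise ops like `aten.add` and `aten.mul` simply have no entry. A quick way to inspect what is countable (a sketch, not part of the diff):

from torch.utils.flop_counter import flop_registry

# Keys are OpOverloadPackets, so countable_fx reduces to a membership test.
for packet in sorted(flop_registry, key=str):
    print(packet)  # e.g. aten.addmm, aten.bmm, aten.convolution, aten.mm, ...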

torch/_inductor/fx_utils.py

Lines changed: 33 additions & 0 deletions
@@ -8,6 +8,7 @@
 import torch
 import torch.fx
 from torch._dispatch.python import enable_python_dispatcher
+from torch._subclasses.fake_tensor import FakeTensorMode
 from torch.fx.experimental.symbolic_shapes import (
     compute_unbacked_bindings,
     rebind_unbacked,
@@ -17,6 +18,7 @@
 from torch.utils import _pytree as pytree
 from torch.utils._ordered_set import OrderedSet
 from torch.utils._pytree import tree_map
+from torch.utils.flop_counter import flop_registry
 
 from .virtualized import V
 
@@ -250,3 +252,34 @@ def realizes_inputs(node: torch.fx.Node) -> bool:
 
     # Otherwise, assume node isn't realized
     return False
+
+
+def count_flops_fx(node: torch.fx.Node) -> Optional[int]:
+    if isinstance(node.target, str):
+        return None
+    with FakeTensorMode(allow_non_fake_inputs=True):
+        success, args, kwargs = get_fake_args_kwargs(node)
+
+        if success:
+            with torch.utils.flop_counter.FlopCounterMode(
+                display=False
+            ) as flop_counter_mode:
+                node.target(*args, **kwargs)
+
+            counted_flops = flop_counter_mode.get_total_flops()
+            return counted_flops
+    return None
+
+
+def countable_fx(node: torch.fx.Node) -> bool:
+    """
+    Whether or not we can count the flops of an FX node.
+    """
+    assert isinstance(node, torch.fx.Node)
+    if not hasattr(node, "target"):
+        return False
+    target = node.target
+    if not hasattr(target, "overloadpacket"):
+        return target in flop_registry
+    packet = target.overloadpacket
+    return packet in flop_registry
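
The unit test above exercises these helpers on hand-built FX nodes; on a traced graph they compose as below. This is a minimal sketch under assumptions: `make_fx` is used purely for illustration, and the expected output presumes the 2*m*k*n matmul formula; `count_flops_fx` returns None whenever a node's arguments cannot be fakeified.

import torch
from torch._inductor.fx_utils import count_flops_fx, countable_fx
from torch.fx.experimental.proxy_tensor import make_fx

def f(a, b):
    return torch.relu(a @ b)

gm = make_fx(f, tracing_mode="fake")(torch.randn(4, 5), torch.randn(5, 6))
for node in gm.graph.nodes:
    if node.op == "call_function" and countable_fx(node):
        # Only ops whose overload packet is in flop_registry land here;
        # relu is pointwise, so only the mm node should be counted.
        print(node.target, count_flops_fx(node))  # expect: aten.mm.default 240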
