refactor scheduler countflops and runtime · pytorch/pytorch@da087e3 · GitHub
Commit da087e3
refactor scheduler countflops and runtime
1 parent e9e1aac commit da087e3

File tree

4 files changed: +260 -60 lines changed
Lines changed: 87 additions & 0 deletions

@@ -0,0 +1,87 @@
# Owner(s): ["module: inductor"]

import torch
import torch.utils.flop_counter
from torch._inductor.debug import DebugContext
from torch._inductor.graph import GraphLowering
from torch._inductor.virtualized import V
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing._internal.common_cuda import SM70OrLater
from torch.testing._internal.common_device_type import (
    dtypes,
    instantiate_device_type_tests,
    skipCUDAIf,
)
from torch.testing._internal.common_utils import run_tests, TestCase


def FlopCounterMode(*args, **kwargs):
    return torch.utils.flop_counter.FlopCounterMode(*args, **kwargs, display=False)


def get_total_flops(mode):
    return sum(v for _, v in mode.flop_counts["Global"].items())


def random_tensor(size, dtype, **kwargs):
    if dtype in [torch.half, torch.bfloat16, torch.float, torch.double]:
        return torch.randn(size, dtype=dtype, **kwargs)
    elif dtype in [torch.uint8, torch.int8, torch.short, torch.int, torch.long]:
        return torch.randint(0, 100, size, dtype=dtype, **kwargs)
    else:
        raise ValueError("Unsupported data type")


def cT(device, dtype):
    def T(*shape, requires_grad=False):
        return random_tensor(
            shape, requires_grad=requires_grad, device=device, dtype=dtype
        )

    return T


class TestScheduler(TestCase):
    @dtypes(torch.float, torch.double)
    @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
    def test_flop_counter_op(self, device, dtype):
        T = cT(device, dtype)

        def composite(x, y, z):
            tmp = torch.mm(x + 10, y / 12)
            return torch.mm(tmp, z)

        def composite_relu(x, y):
            tmp = torch.mm(x, y)
            return torch.relu(tmp)

        test_cases = [
            (torch.mm, [T(4, 5), T(5, 6)], {}),
            (torch.add, [T(4, 5), T(4, 5)], {}),
            (composite, [T(5, 4), T(4, 3), T(3, 12)], {}),
            (composite_relu, [T(5, 4), T(4, 3)], {}),
        ]
        for op, example_inputs, kwargs in test_cases:
            comp = torch.compile(op)
            with FlopCounterMode() as mode:
                comp(*example_inputs, **kwargs)
            gm = make_fx(op)(*example_inputs, **kwargs)
            reference_flops = get_total_flops(mode)

            graph = GraphLowering(gm)

            with V.set_graph_handler(graph), V.set_debug_handler(DebugContext()):
                graph.run(*example_inputs, **kwargs)
                graph.init_wrapper_code()
                graph._update_scheduler()
                scheduler_flops = 0
                for node in graph.scheduler.nodes:
                    flops = node.estimate_flops()
                    scheduler_flops += flops if flops is not None else 0
            self.assertEqual(reference_flops, scheduler_flops, msg=f"op = {op}")


instantiate_device_type_tests(TestScheduler, globals())

if __name__ == "__main__":
    run_tests()
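For reference, the reference-count side of this test reduces to a standalone pattern: FlopCounterMode records per-operator counts under the "Global" key, and summing them gives the total that the scheduler estimate is compared against. A minimal sketch, not part of the commit; the shapes are illustrative:

```python
import torch
from torch.utils.flop_counter import FlopCounterMode

# Count FLOPs for a single (4, 5) @ (5, 6) matmul, the same way the
# test's FlopCounterMode/get_total_flops helpers do.
x, y = torch.randn(4, 5), torch.randn(5, 6)
with FlopCounterMode(display=False) as mode:
    torch.mm(x, y)

total = sum(v for _, v in mode.flop_counts["Global"].items())
print(total)  # 2 * 4 * 5 * 6 = 240 FLOPs for the matmul
```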

test/inductor/test_utils.py

Lines changed: 107 additions & 0 deletions

@@ -3,8 +3,10 @@
 from sympy import Symbol, sympify

 import torch
+from torch._inductor.fx_utils import count_flops_fx, countable_fx
 from torch._inductor.test_case import run_tests, TestCase
 from torch._inductor.utils import sympy_str, sympy_subs
+from torch._inductor.virtualized import V


 class TestUtils(TestCase):
@@ -81,6 +83,111 @@ def test_sympy_str(self):
         self.assertEqual(sympy_str(sympify("a-b")), "a - b")
         self.assertEqual(sympy_str(sympify("a+-b")), "a - b")

+    def test_flops_fx(self):
+        def create_fx_node(
+            aten: torch._ops.OpOverloadPacket, args, kwargs
+        ) -> tuple[torch.fx.Node, torch.fx.Node]:
+            node1 = torch.fx.Node(
+                graph=torch.fx.Graph(),
+                name="",
+                op="call_function",
+                target=aten,
+                args=args,
+                kwargs=kwargs,
+            )
+            name: str = aten.overloads()[0]
+            op_overload: torch._ops.OpOverload = getattr(aten, name)
+            node2 = torch.fx.Node(
+                graph=torch.fx.Graph(),
+                name="",
+                op="call_function",
+                target=op_overload,
+                args=args,
+                kwargs=kwargs,
+            )
+            return node1, node2
+
+        with V.set_fake_mode(
+            torch._subclasses.FakeTensorMode(allow_non_fake_inputs=True)
+        ):
+            trues = [
+                (
+                    torch.ops.aten.addmm,
+                    (torch.Tensor(4, 4), torch.Tensor(4, 5), torch.Tensor(5, 4)),
+                    {},
+                ),
+                (
+                    torch.ops.aten.bmm,
+                    (torch.Tensor(10, 4, 5), torch.Tensor(10, 5, 4)),
+                    {},
+                ),
+                (torch.ops.aten.mm, (torch.Tensor(2, 3), torch.Tensor(3, 2)), {}),
+                (
+                    torch.ops.aten.convolution,
+                    (
+                        torch.Tensor(2, 3, 3),
+                        torch.Tensor(2, 2, 2),
+                        torch.Tensor(2),
+                        (1, 1),
+                        (0, 0),
+                        (1, 1),
+                        True,
+                        (0, 0),
+                        1,
+                    ),
+                    {},
+                ),
+                (
+                    torch.ops.aten._convolution,
+                    (
+                        torch.Tensor(2, 2, 2),
+                        torch.Tensor(2, 2, 2),
+                        torch.Tensor(2),
+                        (1,),
+                        (0,),
+                        (1,),
+                        True,
+                        (0,),
+                        1,
+                        False,
+                        True,
+                        False,
+                    ),
+                    {},
+                ),
+            ]
+            # we don't support pointwise ops
+            falses = [
+                (
+                    torch.ops.aten.add,
+                    (torch.Tensor(1, 2, 3), torch.Tensor(1, 2, 3)),
+                    {},
+                ),
+                (
+                    torch.ops.aten.mul,
+                    (torch.Tensor(1, 2, 3), torch.Tensor(1, 2, 3)),
+                    {},
+                ),
+            ]
+            for t, args, kwargs in trues:
+                fx_node_1, fx_node_2 = create_fx_node(t, args, kwargs)
+                self.assertTrue(
+                    countable_fx(fx_node_1), f"Expected true {t}: {fx_node_1}"
+                )
+                self.assertTrue(
+                    countable_fx(fx_node_2), f"Expected true {t}: {fx_node_2}"
+                )
+                self.assertNotEqual(count_flops_fx(fx_node_1), None)
+                self.assertNotEqual(count_flops_fx(fx_node_2), None)
+            for f, args, kwargs in falses:
+                fx_node_1, fx_node_2 = create_fx_node(f, args, kwargs)
+                self.assertFalse(
+                    countable_fx(fx_node_1), f"Expected false {f}: {fx_node_1}"
+                )
+                self.assertFalse(
+                    countable_fx(fx_node_2), f"Expected false {f}: {fx_node_2}"
+                )
+

 if __name__ == "__main__":
     run_tests()
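The two nodes returned by create_fx_node exist because an fx node's target can be either an OpOverloadPacket (torch.ops.aten.mm) or a resolved OpOverload (e.g. torch.ops.aten.mm.default), and countable_fx must accept both forms. A small sketch of the distinction, independent of the test:

```python
import torch

packet = torch.ops.aten.mm        # OpOverloadPacket: what the test passes as `aten`
name = packet.overloads()[0]      # an overload name, e.g. "default"
overload = getattr(packet, name)  # OpOverload: the resolved form

# countable_fx maps an overload back to its packet before consulting the
# flop registry; the packet is recoverable from the overload.
print(overload.overloadpacket is packet)  # True
```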

torch/_inductor/fx_utils.py

Lines changed: 30 additions & 0 deletions

@@ -8,6 +8,7 @@
 import torch
 import torch.fx
 from torch._dispatch.python import enable_python_dispatcher
+from torch._subclasses.fake_tensor import FakeTensorMode
 from torch.fx.experimental.symbolic_shapes import (
     compute_unbacked_bindings,
     rebind_unbacked,
@@ -17,6 +18,7 @@
 from torch.utils import _pytree as pytree
 from torch.utils._ordered_set import OrderedSet
 from torch.utils._pytree import tree_map
+from torch.utils.flop_counter import flop_registry

 from .virtualized import V

@@ -250,3 +252,31 @@ def realizes_inputs(node: torch.fx.Node) -> bool:

     # Otherwise, assume node isn't realized
     return False
+
+
+def count_flops_fx(node: torch.fx.Node) -> Optional[int]:
+    if isinstance(node.target, str):
+        return None
+    with FakeTensorMode(allow_non_fake_inputs=True):
+        success, args, kwargs = get_fake_args_kwargs(node)
+
+        if success:
+            with torch.utils.flop_counter.FlopCounterMode(
+                display=False
+            ) as flop_counter_mode:
+                node.target(*args, **kwargs)
+
+            counted_flops = flop_counter_mode.get_total_flops()
+            return counted_flops
+        return None
+
+
+def countable_fx(node: torch.fx.Node) -> bool:
+    assert isinstance(node, torch.fx.Node)
+    if not hasattr(node, "target"):
+        return False
+    target = node.target
+    if not hasattr(target, "overloadpacket"):
+        return target in flop_registry
+    packet = target.overloadpacket
+    return packet in flop_registry
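Taken together, countable_fx gates which fx nodes the scheduler will try to count, and count_flops_fx dry-runs the op under a fake tensor mode with FlopCounterMode to produce the number. A usage sketch mirroring the new test above (the hand-built node and its shapes are illustrative, not from the commit):

```python
import torch
from torch._inductor.fx_utils import count_flops_fx, countable_fx
from torch._inductor.virtualized import V

with V.set_fake_mode(torch._subclasses.FakeTensorMode(allow_non_fake_inputs=True)):
    node = torch.fx.Node(
        graph=torch.fx.Graph(),
        name="",
        op="call_function",
        target=torch.ops.aten.mm,
        args=(torch.Tensor(4, 5), torch.Tensor(5, 6)),
        kwargs={},
    )
    print(countable_fx(node))    # True: aten.mm is in the flop registry
    print(count_flops_fx(node))  # expected 2 * 4 * 5 * 6 = 240
```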

torch/_inductor/scheduler.py

Lines changed: 36 additions & 60 deletions

@@ -28,7 +28,7 @@
 from torch._dynamo.utils import counters, dynamo_timed
 from torch._inductor.codecache import LambdaFuture, PyCodeCache
 from torch._inductor.metrics import get_metric_table, is_metric_table_enabled
-from torch.fx.experimental.symbolic_shapes import free_symbols, free_unbacked_symbols
+from torch.fx.experimental.symbolic_shapes import free_symbols
 from torch.utils._ordered_set import OrderedSet
 from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT
 from torch.utils._triton import has_triton
@@ -39,8 +39,8 @@
 from .comm_analysis import estimate_nccl_collective_runtime
 from .dependencies import Dep, MemoryDep, StarDep, WeakDep
 from .exc import GPUTooOldForTriton, TritonMissing
+from .fx_utils import count_flops_fx, countable_fx
 from .ir import (
-    ComputedBuffer,
     get_device_type,
     GraphPartitionSignature,
     MultiOutput,
@@ -783,6 +783,21 @@ def get_buf_bytes(

         return buf_byte_accesses

+    @cache_on_self
+    def estimate_flops(self) -> int | None:
+        if self.node is None:
+            return None
+        fx_node = self.node.get_origin_node()
+        if fx_node is None:
+            return None
+        if not countable_fx(fx_node):
+            return None
+
+        flops = count_flops_fx(fx_node)
+
+        resolved_flops = V.graph.sizevars.size_hints((flops,), fallback=0)[0]
+        return resolved_flops
+
     @cache_on_self
     def get_estimated_runtime(self) -> float:
         """
@@ -823,57 +838,29 @@ def get_estimated_runtime(self) -> float:
         except Exception:
             return 0

-        if isinstance(self, ExternKernelSchedulerNode):
-            assert isinstance(self.node, ir.ExternKernel), f"{type(self.node)=}"
-            op = kernel_name_to_op.get(
-                getattr(self.node, "python_kernel_name", ""), None
+        if isinstance(self, FusedSchedulerNode):
+            flops_est: int | None = sum(
+                filter(
+                    None,
+                    (node.estimate_flops() for node in self.get_nodes()),
+                )
             )
+        else:
+            flops_est = self.estimate_flops()

-            # if there is a resolved op, dry-run using fake mode and record flop count
-            if op is not None:
-                from torch._subclasses.fake_tensor import FakeTensorMode
-                from torch.utils.flop_counter import FlopCounterMode
-
-                if any(
-                    len(free_unbacked_symbols(n.get_numel())) > 0
-                    for n in self.node.inputs
-                ):
-                    # Tensor has unbacked symints, we don't know how to estimate
-                    # runtime for that today
-                    return 0
-
-                with (
-                    FakeTensorMode() as fake_mode,
-                    FlopCounterMode(display=False) as flop_counter_mode,
-                    V.set_current_node(self.node.fx_node),
-                    V.set_fake_mode(fake_mode),
-                ):
-                    from .ir import ir_node_to_tensor
-
-                    fake_inputs = [
-                        ir_node_to_tensor(input, guard_shape=False)
-                        for input in self.node.inputs
-                    ]
-                    cls = self.node.__class__
-                    cls.process_kernel(op, *fake_inputs, **self.node.kwargs)
-
-                # TODO(xmfan): find a better heuristic to model FLOPS/latency relationship
-                factor = 1.0
-                counted_flops = flop_counter_mode.get_total_flops()
-                counted_bytes = self.get_read_write_buffers_sizes()
-                compute_time = (factor * counted_flops / gpu_flops) * 1e9
-                transfer_time = counted_bytes / gpu_memory_bandwidth
-
-                # Return estimated runtime in nanoseconds
-                return max(compute_time, transfer_time)
-
-        elif isinstance(self, FusedSchedulerNode) or isinstance(
-            self.node, ComputedBuffer
-        ):
-            # Return estimated runtime in nanoseconds (bytes / gbps)
+        if flops_est == 0 or flops_est is None:
+            # no flops estimate, so fall back to memory estimate
             return self.get_read_write_buffers_sizes() / gpu_memory_bandwidth

-        return 0
+        # TODO(xmfan): find a better heuristic to model FLOPS/latency relationship
+        factor = 1.0
+        counted_bytes = self.get_read_write_buffers_sizes()
+        counted_bytes = 0 if counted_bytes is None else counted_bytes
+        compute_time = (factor * flops_est / gpu_flops) * 1e9
+        transfer_time = counted_bytes / gpu_memory_bandwidth
+
+        # Return estimated runtime in nanoseconds
+        return max(compute_time, transfer_time)

     def get_template_node(self) -> Optional[ir.TemplateBuffer]:
         return None
@@ -987,17 +974,6 @@ def should_prune(dep: Dep) -> bool:
             node.set_read_writes(node.read_writes.remove_reads(deps_to_prune))


-# TODO(xmfan): reuse: an existing mapping for this if it exists, or formalize this into ir.py:ExternKernel
-kernel_name_to_op = {
-    "extern_kernels.convolution": torch.ops.aten.convolution,
-    "extern_kernels.mm": torch.ops.aten.mm,
-    "extern_kernels.bmm": torch.ops.aten.bmm,
-    "extern_kernels.addmm": torch.ops.aten.addmm,
-    "extern_kernels._scaled_mm": torch.ops.aten._scaled_mm,
-    "extern_kernels._scaled_grouped_mm": torch.ops.aten._scaled_grouped_mm,
-}
-
-
 class ExternKernelSchedulerNode(BaseSchedulerNode):
     def __init__(self, scheduler: Scheduler, node: ir.Operation) -> None:
         super().__init__(scheduler)
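Net effect of the scheduler change: per-node FLOP counts now come from estimate_flops() (summed over the constituents of a FusedSchedulerNode), and get_estimated_runtime() is a plain roofline estimate: the larger of compute time and transfer time wins, with a pure-memory fallback when no FLOPs can be counted. Illustrative arithmetic with made-up device peaks, not numbers from the commit:

```python
# Roofline estimate as in the refactored get_estimated_runtime().
# All device numbers below are hypothetical placeholders.
flops_est = 2 * 4096 * 4096 * 4096   # a 4096^3 matmul: ~1.37e11 FLOPs
counted_bytes = 3 * 4096 * 4096 * 2  # three fp16 buffers read/written

gpu_flops = 312e12                   # hypothetical peak, FLOP/s
gpu_memory_bandwidth = 2039          # hypothetical peak, GB/s (so bytes / GBps -> ns)

factor = 1.0
compute_time = (factor * flops_est / gpu_flops) * 1e9  # nanoseconds
transfer_time = counted_bytes / gpu_memory_bandwidth   # nanoseconds

# The node is modeled as bound by whichever side takes longer.
print(max(compute_time, transfer_time))  # ~4.4e5 ns: compute-bound here
```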

0 commit comments