[Graph Partition] reorder for minimal number of partitions (#151968) · pytorch/pytorch@797768c · GitHub


Commit 797768c

BoyuanFeng authored and pytorchmergebot committed
[Graph Partition] reorder for minimal number of partitions (#151968)
This PR adds an optimal reordering for minimizing the number of partitions.

## Optimal reordering for minimizing #partitions

A BFS topological sort minimizes the number of partitions (ignoring peak memory for now):

1. For each node, compute `node_to_indegree: dict[node, int]`.
2. Maintain 2 queues: `cudagraphable_nodes` and `non_cudagraphable_nodes`. Iterate through all nodes and add each node to one of these 2 queues once `node_to_indegree[node] == 0`.
3. While `non_cudagraphable_nodes` is not empty: pop 1 node, schedule it, update the indegree of all its successors, and add each successor to one of the queues once `node_to_indegree[successor] == 0`.
4. While `cudagraphable_nodes` is not empty: pop 1 node, schedule it, update the indegree of all its successors, and add each successor to one of the queues once `node_to_indegree[successor] == 0`.
5. Repeat steps 3 & 4 until all nodes have been scheduled.

We call this strategy `reorder_for_minimizing_partition`.

**Q: Why is this optimal?**

Suppose it is not optimal. Then there is a counterexample with 2 non-cudagraphable regions:

```
[non_cudagraphable1, cudagraphable2, non_cudagraphable3]
```

that can be reordered into only 1 non-cudagraphable region:

```
[non_cudagraphable1, non_cudagraphable3, cudagraphable2]
```

The reorder being legal means non_cudagraphable3 does not depend on cudagraphable2. So after non_cudagraphable1 is scheduled, both non_cudagraphable3 and cudagraphable2 have indegree 0, and step 3 would already have scheduled non_cudagraphable3 before cudagraphable2. The counterexample therefore cannot exist, and the BFS is optimal for minimizing the number of partitions.

## Minimize peak memory

`reorder_for_peak_memory` currently uses `topological_sort_dfs`, `topological_sort_lpmf`, and `topological_sort_bfs`, where the latter 2 are BFS-based. ILP brings only small benefits and hardly scales beyond 100 nodes, according to @xuanzhang816, so ILP is not used for peak-memory reordering in inductor.

Heuristic strategy:
- Run `reorder_for_peak_memory` to obtain the default order.
- Run `reorder_for_minimizing_partition` and get the result as `list[tuple[partition, bool]]`, where `partition: list[BaseSchedulerNode]` and the bool marks whether the partition is cudagraphable.
- If the reorder increases peak memory too much, fall back to the default order.

Pull Request resolved: #151968
Approved by: https://github.com/eellison
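To make the two-queue BFS and the ~10% peak-memory guard concrete, here is a minimal, self-contained sketch. It is an illustration under assumptions: `Node`, its `cudagraphable` flag, and the `estimate_peak_memory` callback are stand-ins, not the inductor's `BaseSchedulerNode`, `Scheduler.should_partition`, or `torch/_inductor/memory.py` APIs; the actual implementation is in the `torch/_inductor/scheduler.py` diff below.

```python
from __future__ import annotations

import collections
from dataclasses import dataclass, field
from typing import Callable


@dataclass(eq=False)  # identity-based hashing so nodes can be dict keys
class Node:
    name: str
    cudagraphable: bool
    preds: list["Node"] = field(default_factory=list)
    succs: list["Node"] = field(default_factory=list)


def reorder_for_minimizing_partition(nodes: list[Node]) -> list[Node]:
    """Kahn-style BFS: drain all ready non-cudagraphable nodes first, then all
    ready cudagraphable nodes, so cudagraphable runs are merged where possible."""
    indegree = {n: len(n.preds) for n in nodes}
    cuda_q: collections.deque[Node] = collections.deque()
    non_cuda_q: collections.deque[Node] = collections.deque()

    def enqueue(n: Node) -> None:
        (cuda_q if n.cudagraphable else non_cuda_q).append(n)

    for n in nodes:
        if indegree[n] == 0:
            enqueue(n)

    schedule: list[Node] = []

    def drain(q: collections.deque[Node]) -> None:
        # pop ready nodes, schedule them, and release their successors
        while q:
            n = q.popleft()
            schedule.append(n)
            for s in n.succs:
                indegree[s] -= 1
                if indegree[s] == 0:
                    enqueue(s)

    while non_cuda_q or cuda_q:
        drain(non_cuda_q)  # step 3: flush non-cudagraphable nodes
        drain(cuda_q)      # step 4: flush cudagraphable nodes

    assert len(schedule) == len(nodes), "dependency graph has a cycle"
    return schedule


def maybe_reorder_for_minimizing_partition(
    nodes: list[Node],
    estimate_peak_memory: Callable[[list[Node]], int],
) -> list[Node]:
    """Accept the reorder only if peak memory stays within ~10% of the baseline."""
    reordered = reorder_for_minimizing_partition(nodes)
    if estimate_peak_memory(reordered) < estimate_peak_memory(nodes) * 1.1:
        return reordered
    return nodes


if __name__ == "__main__":
    # the counterexample shape from the message above: only c2, c3 depend on c1
    c1 = Node("non_cudagraphable1", cudagraphable=False)
    c2 = Node("cudagraphable2", cudagraphable=True)
    c3 = Node("non_cudagraphable3", cudagraphable=False)
    c1.succs, c2.preds, c3.preds = [c2, c3], [c1], [c1]
    order = reorder_for_minimizing_partition([c1, c2, c3])
    print([n.name for n in order])
    # -> ['non_cudagraphable1', 'non_cudagraphable3', 'cudagraphable2']
```

As in the commit, the guard keeps the partition-minimizing order only when its estimated peak memory is below 1.1x the default order's estimate; otherwise the default order wins.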
1 parent a77a447 commit 797768c

File tree

3 files changed: +222 -21 lines changed


test/inductor/test_cudagraph_trees.py

Lines changed: 85 additions & 2 deletions
@@ -2686,8 +2686,8 @@ def forward(self, x):
             loss.backward()
             optimizer.step()
 
-        # 2 graph partitions lead to 2 fwd cudagraphs and 2 bwd cudagraphs
-        self.assertEqual(self.get_manager().new_graph_id().id, 4)
+        # 2 graph partitions lead to 2 fwd cudagraphs and 1 bwd cudagraphs
+        self.assertEqual(self.get_manager().new_graph_id().id, 3)
 
     @torch._inductor.config.patch("graph_partition", True)
     def test_graph_partition_cpu_only(self):
@@ -3088,6 +3088,89 @@ def run(shape_x, shape_y):
 
         self.assertEqual(self.get_manager().new_graph_id().id, 3)
 
+    @torch._inductor.config.patch("graph_partition", True)
+    def test_graph_partition_reorder_cpu_and_gpu(self):
+        def f(x_cuda, y_cpu, z_cuda, weight_cuda, weight_cpu):
+            x_cuda0 = x_cuda + 1
+            x_cuda1 = x_cuda0 @ weight_cuda
+            x_cuda2 = 2 * (x_cuda1 + x_cuda)
+
+            y_cpu0 = y_cpu + 1
+            y_cpu1 = y_cpu0 @ weight_cpu
+
+            z_cuda0 = z_cuda + 1
+            z_cuda1 = z_cuda0 @ weight_cuda
+            z_cuda2 = 2 * (z_cuda1 + z_cuda)
+
+            return x_cuda2, y_cpu1, z_cuda2
+
+        x_cuda = torch.randn(3, 3, device="cuda")
+        y_cpu = torch.randn(3, 3, device="cpu")
+        z_cuda = torch.randn(3, 3, device="cuda")
+        weight_cuda = torch.randn(3, 3, device="cuda")
+        weight_cpu = torch.randn(3, 3, device="cpu")
+
+        eager_out = f(x_cuda, y_cpu, z_cuda, weight_cuda, weight_cpu)
+
+        compiled_f = torch.compile(f, mode="reduce-overhead")
+        for _ in range(3):
+            compiled_out = compiled_f(
+                x_cuda, y_cpu, z_cuda, weight_cuda, weight_cpu
+            )
+            self.assertEqual(eager_out, compiled_out)
+
+        # reorder merges ops on cuda into 1 graph partition
+        self.assertEqual(self.get_manager().new_graph_id().id, 1)
+
+    @torch._inductor.config.patch("graph_partition", True)
+    def test_graph_partition_reorder_cpu_and_gpu_interleave(self):
+        def f(x_cuda, y_cpu, z_cuda, weight_cuda, weight_cpu):
+            # partition 1 on cuda, no dependency
+            x_cuda0 = x_cuda + 1
+            x_cuda1 = x_cuda0 @ weight_cuda
+            x_cuda2 = 2 * (x_cuda1 + x_cuda)
+
+            # partition 2 on cpu w/ dependency on partition 1
+            y_cpu0 = y_cpu + 1
+            x_cuda2_cpu = x_cuda2.cpu()  # adds dependency on gpu computations
+            y_cpu1 = y_cpu0 @ weight_cpu + x_cuda2_cpu
+
+            # partition 3 on cuda w/o dependency
+            z_cuda0 = z_cuda + 1
+            z_cuda1 = z_cuda0 @ weight_cuda
+            z_cuda2 = 2 * (z_cuda1 + z_cuda)
+
+            # partition 4 on cpu w/o dependency
+            y_cpu2 = y_cpu + 5
+            y_cpu3 = y_cpu2 @ weight_cpu
+
+            # partition 5 on cuda w/o dependency
+            u_cuda0 = z_cuda + 3
+            u_cuda1 = u_cuda0 @ weight_cuda
+            u_cuda2 = 2 * (u_cuda0 + u_cuda1)
+
+            return x_cuda2, y_cpu1, z_cuda2, y_cpu3, u_cuda2
+
+        x_cuda = torch.randn(3, 3, device="cuda")
+        y_cpu = torch.randn(3, 3, device="cpu")
+        z_cuda = torch.randn(3, 3, device="cuda")
+        weight_cuda = torch.randn(3, 3, device="cuda")
+        weight_cpu = torch.randn(3, 3, device="cpu")
+
+        eager_out = f(x_cuda, y_cpu, z_cuda, weight_cuda, weight_cpu)
+
+        compiled_f = torch.compile(f, mode="reduce-overhead")
+        for _ in range(3):
+            compiled_out = compiled_f(
+                x_cuda, y_cpu, z_cuda, weight_cuda, weight_cpu
+            )
+            self.assertEqual(eager_out, compiled_out)
+
+        # the optimal order is
+        # [[partition 4 on cpu], [partition 1,3,5 on cuda], [partition 2 on cpu]]
+        # since partition 2 depends on partition 1. So we have 1 cudagraph in total.
+        self.assertEqual(self.get_manager().new_graph_id().id, 1)
+
     @config.patch(implicit_fallbacks=True)
     @torch._inductor.config.patch("graph_partition", True)
     def test_graph_partition_reorder_custom_op_with_no_dependency(self):

torch/_inductor/memory.py

Lines changed: 42 additions & 19 deletions
@@ -22,6 +22,13 @@
 torch_log = logging.getLogger(__name__)
 
 
+@dataclasses.dataclass
+class PeakMemoryResult:
+    order: list[BaseSchedulerNode]
+    peak_memory: int
+    method: str
+
+
 @dataclasses.dataclass
 class MemoryPlanningInfoForBuffer:
     size_alloc: int = 0
@@ -578,6 +585,35 @@ def visit(n: BaseSchedulerNode) -> None:
     return result
 
 
+def prepare_planning_info(
+    nodes: list[BaseSchedulerNode],
+    name_to_buf: dict[str, SchedulerBuffer],
+    name_to_fused_node: dict[str, BaseSchedulerNode],
+    graph_inputs: OrderedSet[str],
+    graph_outputs: OrderedSet[str],
+) -> tuple[int, dict[str, FreeableInputBuffer]]:
+    """
+    Prepare planning info. As nodes are scheduled one at a time, these help
+    keep track of when a buffer can be freed, and when a node can be scheduled
+
+    Returns:
+        int: peak memory estimation
+        dict[str, FreeableInputBuffer]: name to freeable input buffer
+    """
+    name_to_freeable_input_buf = get_freeable_input_buf(nodes, graph_inputs)
+    assign_memory_planning_info_for_scheduler_buffers(nodes, name_to_buf)
+    assign_memory_planning_info_for_scheduler_nodes(
+        nodes, name_to_fused_node, name_to_buf, name_to_freeable_input_buf
+    )
+
+    # the default
+    estimated_peak_memory, _ = estimate_peak_memory(
+        nodes, name_to_freeable_input_buf, graph_outputs
+    )
+
+    return estimated_peak_memory, name_to_freeable_input_buf
+
+
 def reorder_for_peak_memory(
     nodes: list[BaseSchedulerNode],
     name_to_buf: dict[str, SchedulerBuffer],
@@ -597,29 +633,16 @@ def reorder_for_peak_memory(
 
     torch_log.info("Reordering for peak memory -- %d nodes", len(nodes))
 
-    @dataclasses.dataclass
-    class PeakMemoryResult:
-        order: list[BaseSchedulerNode]
-        peak_memory: int
-        method: str
-
-    # preparation -- as nodes are scheduled one at a time, these help
-    # keep track of when a buffer can be freed, and when a node can be scheduled
-    name_to_freeable_input_buf: dict[str, FreeableInputBuffer] = get_freeable_input_buf(
-        nodes, graph_inputs
-    )
-    assign_memory_planning_info_for_scheduler_buffers(nodes, name_to_buf)
-    assign_memory_planning_info_for_scheduler_nodes(
-        nodes, name_to_fused_node, name_to_buf, name_to_freeable_input_buf
+    estimated_peak_memory, name_to_freeable_input_buf = prepare_planning_info(
+        nodes,
+        name_to_buf,
+        name_to_fused_node,
+        graph_inputs,
+        graph_outputs,
     )
 
     # keep track of the peak memory estimates of different methods
     peak_memory_diff_methods: list[PeakMemoryResult] = []
-
-    # the default
-    estimated_peak_memory, _ = estimate_peak_memory(
-        nodes, name_to_freeable_input_buf, graph_outputs
-    )
     peak_memory_diff_methods.append(
         PeakMemoryResult(nodes, estimated_peak_memory, "baseline")
     )

torch/_inductor/scheduler.py

Lines changed: 95 additions & 0 deletions
@@ -2104,6 +2104,7 @@ def _init(self, nodes: list[ir.Operation]) -> None:
         self.process_grouped_nodes()
 
         if torch._inductor.config.graph_partition:
+            self.nodes = self.maybe_reorder_for_minimizing_partition(self.nodes)
             self.nodes = self.reorder_for_partition_with_simple_dependency(self.nodes)
 
         self.compute_last_usage()
@@ -4282,6 +4283,100 @@ def get_graph_partition_signature(
 
         return signatures[::-1]
 
+    def reorder_for_minimizing_partition(
+        self,
+        nodes: list[BaseSchedulerNode],
+    ) -> list[BaseSchedulerNode]:
+        """
+        Reorder nodes to minimize the number of partitions via a bfs
+        topological sort. This is the optimal reordering such that the
+        number of partitions cannot be reduced further. This may be
+        sub-optimal for other metrics such as peak memory. This does not
+        change relative orders of two cudagraphable nodes, nor the
+        relative order of two non_cudagraphable nodes.
+        """
+        node_to_indegree: dict[BaseSchedulerNode, int] = dict()
+        cudagraphable_nodes: collections.deque[BaseSchedulerNode] = collections.deque()
+        non_cudagraphable_nodes: collections.deque[BaseSchedulerNode] = (
+            collections.deque()
+        )
+
+        def insert_pending_nodes(node: BaseSchedulerNode) -> None:
+            if self.should_partition(node):
+                non_cudagraphable_nodes.append(node)
+            else:
+                cudagraphable_nodes.append(node)
+
+        def update_indegree(node: BaseSchedulerNode) -> None:
+            for succ_node in node.mpi_node.succ_nodes:
+                assert node_to_indegree[succ_node] > 0
+                node_to_indegree[succ_node] -= 1
+                if node_to_indegree[succ_node] == 0:
+                    insert_pending_nodes(succ_node)
+
+        for node in nodes:
+            node_to_indegree[node] = len(node.mpi_node.pred_nodes)
+            if node_to_indegree[node] == 0:
+                insert_pending_nodes(node)
+
+        schedule: list[BaseSchedulerNode] = []
+        num_iters: int = 0
+        while num_iters < len(nodes) and (
+            non_cudagraphable_nodes or cudagraphable_nodes
+        ):
+            while non_cudagraphable_nodes:
+                node = non_cudagraphable_nodes.popleft()
+                schedule.append(node)
+                update_indegree(node)
+
+            while cudagraphable_nodes:
+                node = cudagraphable_nodes.popleft()
+                schedule.append(node)
+                update_indegree(node)
+
+            num_iters += 1
+
+        if num_iters > len(nodes):
+            raise RuntimeError(
+                """
+                Failed to schedule, while loop ran too long when
+                reordering for minimizing the num of partitions
+                """
+            )
+
+        return schedule
+
+    def maybe_reorder_for_minimizing_partition(
+        self,
+        nodes: list[BaseSchedulerNode],
+    ) -> list[BaseSchedulerNode]:
+        """
+        Reorder nodes to minimize the number of partitions if this only slightly
+        increases peak memory.
+        """
+        from .memory import estimate_peak_memory, prepare_planning_info
+
+        graph_outputs = OrderedSet(V.graph.get_output_names())
+
+        default_peak_memory, name_to_freeable_input_buf = prepare_planning_info(
+            nodes,
+            self.name_to_buf,
+            self.name_to_fused_node,
+            OrderedSet(V.graph.graph_inputs.keys()),
+            graph_outputs,
+        )
+
+        reordered_nodes = self.reorder_for_minimizing_partition(nodes)
+        reorder_peak_memory, _ = estimate_peak_memory(
+            reordered_nodes, name_to_freeable_input_buf, graph_outputs
+        )
+
+        # 1.1 here means 10% extra peak memory budget which is quite arbitrary
+        if reorder_peak_memory < default_peak_memory * 1.1:
+            return reordered_nodes
+
+        return nodes
+
     def reorder_for_partition_with_simple_dependency(
         self, nodes: list[BaseSchedulerNode]
     ) -> list[BaseSchedulerNode]:
0 commit comments