8000 Lint and add patch · pytorch/pytorch@28d289e · GitHub
[go: up one dir, main page]

Skip to content

Commit 28d289e

Browse files
committed
Lint and add patch
1 parent 0280a45 commit 28d289e

File tree

2 files changed

+51
-11
lines changed

2 files changed

+51
-11
lines changed

torch/_export/utils.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,10 @@
1111
from collections.abc import Iterable
1212
from contextlib import contextmanager
1313
from inspect import ismethod, Parameter
14-
from sortedcontainers import SortedDict, SortedList
1514
from typing import Any, Callable, Optional, TYPE_CHECKING, Union
1615

16+
from sortedcontainers import SortedDict, SortedList
17+
1718
import torch
1819
from torch._guards import detect_fake_mode
1920
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
@@ -265,6 +266,8 @@ def _rename_without_collisions(
265266
"""
266267
Renames nodes to avoid name collisions, with suffixing.
267268
name_map: map from original name to new name
269+
find_available: map prefix to available suffix
270+
used_names: cache of used names
268271
orig_name: mapping key
269272
name: candidate name (potentially suffixed, e.g. mul_2)
270273
is_placeholder: if the node is a placeholder, avoid detecting suffix
@@ -283,7 +286,7 @@ def _rename_without_collisions(
283286
key = name
284287
else:
285288
key = prefix
286-
289+
287290
new_name = name
288291
if new_name in used_names:
289292
new_name = f"{key}_{find_available[key][0] + 1}"
@@ -294,8 +297,9 @@ def _rename_without_collisions(
294297
find_available[prefix].add(int(n))
295298

296299
while len(find_available[prefix]) >= 2 and (
297-
find_available[prefix][0]+1 == find_available[prefix][1] or
298-
find_available[prefix][0] == find_available[prefix][1]):
300+
find_available[prefix][0] + 1 == find_available[prefix][1]
301+
or find_available[prefix][0] == find_available[prefix][1]
302+
):
299303
find_available[prefix].pop(0)
300304

301305
name_map[orig_name] = new_name
@@ -896,6 +900,7 @@ def _name_hoo_subgraph_placeholders(gm: torch.fx.GraphModule) -> None:
896900
Different HOO subgraph types have different input schemas, so we first enumerate them
897901
and gather the top-level named placeholder nodes.
898902
"""
903+
899904
def build_cache(name, find_available, used_names):
900905
used_names.add(name)
901906
match = re.match(r"(.*)_(\d+)", name)
@@ -906,10 +911,13 @@ def build_cache(name, find_available, used_names):
906911

907912
if match and prefix not in find_available:
908913
find_available[prefix] = SortedList([0])
909-
914+
910915
if match:
911916
find_available[prefix].add(int(n))
912-
while len(find_available[prefix]) >= 2 and find_available[prefix][0]+1 == find_available[prefix][1]:
917+
while (
918+
len(find_available[prefix]) >= 2
919+
and find_available[prefix][0] + 1 == find_available[prefix][1]
920+
):
913921
find_available[prefix].pop(0)
914922

915923
# gather all HOO subgraphs and their top-level named placeholder nodes
@@ -943,7 +951,9 @@ def build_cache(name, find_available, used_names):
943951
node.name = node.target = hoo_phs[i].name
944952
build_cache(node.name, find_available, used_names)
945953
else: # non-placeholder, check for collisions
946-
node.name = _rename_without_collisions(name_map, find_available, used_names, node.name, node.name)
954+
node.name = _rename_without_collisions(
955+
name_map, find_available, used_names, node.name, node.name
956+
)
947957

948958
# recurse and recompile
949959
_name_hoo_subgraph_placeholders(subgraph)
@@ -1062,8 +1072,9 @@ def _extract_pytree_key(x):
10621072
for node in gm.graph.nodes:
10631073
if node.op == "placeholder":
10641074
continue
1065-
_rename_without_collisions(name_map, find_available,
1066-
used_names, node.name, node.name)
1075+
_rename_without_collisions(
1076+
name_map, find_available, used_names, node.name, node.name
1077+
)
10671078

10681079
# assign new node names
10691080
for node in gm.graph.nodes:

torch/export/exported_program.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,16 @@
55
import dataclasses
66
import functools
77
import operator
8+
import re
89
import types
910
import warnings
1011
from collections import defaultdict, namedtuple
1112
from collections.abc import Iterator
1213
from contextlib import contextmanager
1314
from typing import Any, Callable, final, Optional, TYPE_CHECKING, Union
1415

16+
from sortedcontainers import SortedDict, SortedList
17+
1518
from torch._guards import tracing, TracingContext
1619
from torch._higher_order_ops.utils import autograd_not_implemented
1720
from torch._library.fake_class_registry import FakeScriptObject
@@ -617,12 +620,38 @@ def update_arg(old_ar A3DB g, new_ph):
617620
for old_ph, new_ph in zip(old_placeholders, new_placeholders):
618621
new_ph.name = new_ph.target = old_ph.name
619622

623+
def build_cache(name, find_available, used_names):
624+
used_names.add(name)
625+
match = re.match(r"(.*)_(\d+)", name)
626+
if match:
627+
prefix, n = match.group(1), match.group(2)
628+
if name not in find_available:
629+
find_available[name] = SortedList([0])
630+
631+
if match and prefix not in find_available:
632+
find_available[prefix] = SortedList([0])
633+
634+
if match:
635+
find_available[prefix].add(int(n))
636+
while (
637+
len(find_available[prefix]) >= 2
638+
and find_available[prefix][0] + 1 == find_available[prefix][1]
639+
):
640+
find_available[prefix].pop(0)
641+
620642
# handle name collisions with newly decomposed graph nodes
621-
name_map = {ph.name: ph.name for ph in new_placeholders}
643+
name_map = {}
644+
find_available = SortedDict()
645+
used_names = set()
646+
for ph in new_placeholders:
647+
name_map[ph.name] = ph.name
648+
build_cache(ph.name, find_available, used_names)
622649
for node in gm.graph.nodes:
623650
if node.op == "placeholder":
624651
continue
625-
node.name = _rename_without_collisions(name_map, node.name, node.name)
652+
node.name = _rename_without_collisions(
653+
name_map, find_available, used_names, node.name, node.name
654+
)
626655

627656
# propagate names to higher order op subgraphs
628657
_name_hoo_subgraph_placeholders(gm)

0 commit comments

Comments
 (0)
0