1111from collections .abc import Iterable
1212from contextlib import contextmanager
1313from inspect import ismethod , Parameter
14- from sortedcontainers import SortedDict , SortedList
1514from typing import Any , Callable , Optional , TYPE_CHECKING , Union
1615
16+ from sortedcontainers import SortedDict , SortedList
17+
1718import torch
1819from torch ._guards import detect_fake_mode
1920from torch ._subclasses .fake_tensor import FakeTensor , FakeTensorMode
@@ -265,6 +266,8 @@ def _rename_without_collisions(
265266 """
266267 Renames nodes to avoid name collisions, with suffixing.
267268 name_map: map from original name to new name
269+ find_available: map prefix to available suffix
270+ used_names: cache of used names
268271 orig_name: mapping key
269272 name: candidate name (potentially suffixed, e.g. mul_2)
270273 is_placeholder: if the node is a placeholder, avoid detecting suffix
@@ -283,7 +286,7 @@ def _rename_without_collisions(
283286 key = name
284287 else :
285288 key = prefix
286-
289+
287290 new_name = name
288291 if new_name in used_names :
289292 new_name = f"{ key } _{ find_available [key ][0 ] + 1 } "
@@ -294,8 +297,9 @@ def _rename_without_collisions(
294297 find_available [prefix ].add (int (n ))
295298
296299 while len (find_available [prefix ]) >= 2 and (
297- find_available [prefix ][0 ]+ 1 == find_available [prefix ][1 ] or
298- find_available [prefix ][0 ] == find_available [prefix ][1 ]):
300+ find_available [prefix ][0 ] + 1 == find_available [prefix ][1 ]
301+ or find_available [prefix ][0 ] == find_available [prefix ][1 ]
302+ ):
299303 find_available [prefix ].pop (0 )
300304
301305 name_map [orig_name ] = new_name
@@ -896,6 +900,7 @@ def _name_hoo_subgraph_placeholders(gm: torch.fx.GraphModule) -> None:
896900 Different HOO subgraph types have different input schemas, so we first enumerate them
897901 and gather the top-level named placeholder nodes.
898902 """
903+
899904 def build_cache (name , find_available , used_names ):
900905 used_names .add (name )
901906 match = re .match (r"(.*)_(\d+)" , name )
@@ -906,10 +911,13 @@ def build_cache(name, find_available, used_names):
906911
907912 if match and prefix not in find_available :
908913 find_available [prefix ] = SortedList ([0 ])
909-
914+
910915 if match :
911916 find_available [prefix ].add (int (n ))
912- while len (find_available [prefix ]) >= 2 and find_available [prefix ][0 ]+ 1 == find_available [prefix ][1 ]:
917+ while (
918+ len (find_available [prefix ]) >= 2
919+ and find_available [prefix ][0 ] + 1 == find_available [prefix ][1 ]
920+ ):
913921 find_available [prefix ].pop (0 )
914922
915923 # gather all HOO subgraphs and their top-level named placeholder nodes
@@ -943,7 +951,9 @@ def build_cache(name, find_available, used_names):
943951 node .name = node .target = hoo_phs [i ].name
944952 build_cache (node .name , find_available , used_names )
945953 else : # non-placeholder, check for collisions
946- node .name = _rename_without_collisions (name_map , find_available , used_names , node .name , node .name )
954+ node .name = _rename_without_collisions (
955+ name_map , find_available , used_names , node .name , node .name
956+ )
947957
948958 # recurse and recompile
949959 _name_hoo_subgraph_placeholders (subgraph )
@@ -1062,8 +1072,9 @@ def _extract_pytree_key(x):
10621072 for node in gm .graph .nodes :
10631073 if node .op == "placeholder" :
10641074 continue
1065- _rename_without_collisions (name_map , find_available ,
1066- used_names , node .name , node .name )
1075+ _rename_without_collisions (
1076+ name_map , find_available , used_names , node .name , node .name
1077+ )
10671078
10681079 # assign new node names
10691080 for node in gm .graph .nodes :
0 commit comments