8000 [dynamo] Use the new `get_unique_name_wrt` helper when applicable (#1… · pytorch/pytorch@7e0ef2c · GitHub
[go: up one dir, main page]

Skip to content

Commit 7e0ef2c

Browse files
StrongerXipytorchmergebot
authored andcommitted
[dynamo] Use the new get_unique_name_wrt helper when applicable (#146950)
This patch removes some duplicated name generation logic in Dynamo. Pull Request resolved: #146950 Approved by: https://github.com/zou3519 ghstack dependencies: #146714, #146367
1 parent f46f0e4 commit 7e0ef2c

File tree

2 files changed

+16
-29
lines changed

2 files changed

+16
-29
lines changed

torch/_dynamo/output_graph.py

Lines changed: 2 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1554,15 +1554,7 @@ def dedup_pass(self):
15541554
return dict()
15551555

15561556
def install_subgraph(self, name, sub_gm):
1557-
next_name = None
1558-
i = 0
1559-
while not next_name:
1560-
candidate = f"{name}_{i}"
1561-
if candidate in self.nn_modules:
1562-
i += 1
1563-
else:
1564-
next_name = candidate
1565-
1557+
next_name = get_unique_name_wrt(name, self.nn_modules, requires_suffix=True)
15661558
sub_gm.__name__ = next_name
15671559
sub_gm.torchdynamo_force_dynamic = False
15681560
# This graph module is not present in the user space, so it can't be
@@ -2289,14 +2281,7 @@ def create_graph_input(
22892281
TracingContext.extract_stack()
22902282
)
22912283

2292-
# unique
2293-
if name in self.input_name_to_proxy:
2294-
for i in itertools.count():
2295-
candidate_name = f"{name}_{i}"
2296-
if candidate_name not in self.input_name_to_proxy:
2297-
name = candidate_name
2298-
break
2299-
2284+
name = get_unique_name_wrt(name, self.input_name_to_proxy)
23002285
if self.input_name_to_proxy:
23012286
prev_name = next(reversed(self.input_name_to_proxy))
23022287
node = self.input_name_to_proxy[prev_name].node

torch/_dynamo/utils.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2583,23 +2583,25 @@ def get_safe_global_name(tx, root, obj):
25832583
return f"{root}_{id(obj)}_c{tx.output.compile_id}"
25842584

25852585

2586-
def get_unique_name_wrt(prefix: str, *containers) -> str:
2586+
def is_in(item: Any, *containers) -> bool:
2587+
for container in containers:
2588+
if item in container:
2589+
return True
2590+
return False
2591+
2592+
2593+
def get_unique_name_wrt(prefix: str, *containers, requires_suffix=False) -> str:
25872594
"""
25882595
Return a name that starts with `prefix` and is not in any of the
25892596
`containers` (e.g., map, set).
25902597
"""
2591-
name = prefix
2598+
if not requires_suffix and not is_in(prefix, *containers):
2599+
return prefix
2600+
25922601
for i in itertools.count():
2593-
found = False
2594-
for container in containers:
2595-
if name in container:
2596-
found = True
2597-
break
2598-
2599-
if not found:
2600-
return name
2601-
# else update and retry
2602-
name = f"{prefix}_{i}"
2602+
candidate = f"{prefix}_{i}"
2603+
if not is_in(candidate, *containers):
2604+
return candidate
26032605

26042606
raise AssertionError("unreachable")
26052607

0 commit comments

Comments
 (0)
0