[dynamo][optimizers] Install ID_GUARDED tensors into the Fx graph · pytorch/pytorch@b496d8b · GitHub
Commit b496d8b

[dynamo][optimizers] Install ID_GUARDED tensors into the Fx graph
ghstack-source-id: a1ea341 Pull Request resolved: #147824
1 parent 86f3953 commit b496d8b

File tree

7 files changed: +38 -40 lines

test/dynamo/test_decorators.py

Lines changed: 5 additions & 30 deletions
@@ -425,41 +425,16 @@ def forward(self, a, *args):
     def _test_mark_static_address(self, guarded):
         # This test verifies that dynamo properly marks inputs as static
         # when using the mark_static_address API.
-        # On 1st compile, we expect the input to be marked as static, with guarded
-        # set depending on the `guarded` flag.
-        # On 2nd compile, we expect the input to be unmarked
-        # if inlining NN modules, we expect metadata to be present on the tensor, indicating
-        # the static address type of the input
-        # if not inlining NN modules, we expect the tensor to be present in the buffers attribute
-        # of the graph.
+        # For both inline_inbuilt_nn_modules True and False, we expect the
+        # tensor to be present in the buffers attribute of the graph.

         compiles_with_buffers = 0
         compiles = 0

         def debug_compiler(gm, _):
             nonlocal compiles_with_buffers
             nonlocal compiles
-            if torch._dynamo.config.inline_inbuilt_nn_modules:
-                input_node = [
-                    n
-                    for n in gm.graph.nodes
-                    if n.op == "placeholder" and n.name == "l_x_"
-                ]
-                self.assertEqual(len(input_node), 1)
-                input_node = input_node[0]
-                if compiles == 0:
-                    self.assertEqual(
-                        input_node.meta["tensor_dict"]["_dynamo_static_input_type"],
-                        "guarded" if guarded else "unguarded",
-                    )
-                elif compiles == 1:
-                    self.assertFalse(
-                        "_dynamo_static_input_type" in input_node.meta["tensor_dict"]
-                    )
-                else:
-                    raise RuntimeError(f"Unexpected number of compiles: {compiles}")
-            else:
-                compiles_with_buffers += len(gm._buffers) > 0
+            compiles_with_buffers += len(gm._buffers) > 0
             compiles += 1
             return gm

@@ -472,7 +447,7 @@ def fn(x):
         torch._dynamo.mark_static_address(inp, guard=guarded)

         fn(inp)
-        if not torch._dynamo.config.inline_inbuilt_nn_modules:
+        if guarded:
             self.assertEqual(compiles_with_buffers, 1)

         inp2 = torch.ones(2)

@@ -482,7 +457,7 @@ def fn(x):
         # should not be incremented
         fn(inp2)

-        if not torch._dynamo.config.inline_inbuilt_nn_modules:
+        if guarded:
             self.assertEqual(compiles_with_buffers, 1)

         self.assertEqual(compiles, 2 if guarded else 1)
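
For context (not part of the commit), a minimal sketch of the API this test exercises, assuming only the public torch._dynamo.mark_static_address entry point and a custom backend. With guard=True the input gets an ID guard and, per the assertions above, shows up as a buffer on the compiled GraphModule:

import torch

compiles_with_buffers = 0

def debug_compiler(gm, example_inputs):
    # Count compiles in which the static tensor was installed on the graph
    # module (as a buffer) rather than lifted as a placeholder input.
    global compiles_with_buffers
    compiles_with_buffers += len(gm._buffers) > 0
    return gm

@torch.compile(backend=debug_compiler)
def fn(x):
    return x + 1

inp = torch.ones(2)
torch._dynamo.mark_static_address(inp, guard=True)
fn(inp)
print(compiles_with_buffers)  # expected to be 1, mirroring the test above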

test/dynamo/test_subclasses.py

Lines changed: 2 additions & 2 deletions
@@ -1848,9 +1848,9 @@ def inner_compile(
             extern_node_serializer: Optional[Callable[[list[Any]], Any]] = None,
         ):
             if dynamic:
-                self.assertEqual(static_input_idxs, [2, 3, 4])
+                self.assertEqual(static_input_idxs, [0, 1, 2, 3, 4])
             else:
-                self.assertEqual(static_input_idxs, [1, 2])
+                self.assertEqual(static_input_idxs, [0, 1, 2])
             return gm

         compiler = functools.partial(compile_fx, inner_compile=inner_compile)

torch/_dynamo/output_graph.py

Lines changed: 12 additions & 0 deletions
@@ -1361,6 +1361,13 @@ def compile_and_call_fx_graph(self, tx, rv, root, replaced_outputs):

             tx.output.current_tracer._maybe_preserve_original_meta(tx, output_node)
         if not config.do_not_emit_runtime_asserts:
+            # There is a rare scenario where codegen_suffix adds a new entry
+            # to self.nn_modules while `root` knows only about the
+            # nn_modules at the time of its creation. This causes failures
+            # while creating the graph module because self.graph and root
+            # are out of sync. This only happens for `get_attr` nodes, so
+            # here we clean up the get_attr nodes that are unused.
+            self.remove_unused_get_attr_nodes()
             insert_deferred_runtime_asserts(
                 fx.GraphModule(root, self.graph),
                 self.shape_env,

@@ -1568,6 +1575,11 @@ def example_inputs(self) -> list[torch.Tensor]:
         result = [arg.example for arg in self.graphargs]
         return result

+    def remove_unused_get_attr_nodes(self) -> None:
+        for node in reversed(list(self.graph.nodes)):
+            if node.op == "get_attr" and len(list(node.users)) == 0:
+                self.remove_node(node)
+
     def remove_unused_graphargs(self) -> None:
         # NB: It's always OK to drop GraphArg for symbols that ended up being
         # specialized. You don't even have to make a guard for it, because
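
As a standalone illustration of the cleanup the new remove_unused_get_attr_nodes method performs, here is a sketch on a plain torch.fx graph (the real method goes through OutputGraph.remove_node, which also updates Dynamo's side tables):

import torch
from torch import fx

def remove_unused_get_attr_nodes(graph: fx.Graph) -> None:
    # Walk the graph in reverse and erase get_attr nodes with no users,
    # mirroring the new OutputGraph helper.
    for node in reversed(list(graph.nodes)):
        if node.op == "get_attr" and len(node.users) == 0:
            graph.erase_node(node)

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("b", torch.ones(2))

    def forward(self, x):
        return x + self.b

gm = fx.symbolic_trace(M())
# Simulate the out-of-sync state described above: a dangling get_attr
# that nothing in the graph consumes.
output_node = list(gm.graph.nodes)[-1]
with gm.graph.inserting_before(output_node):
    gm.graph.get_attr("b")
remove_unused_get_attr_nodes(gm.graph)
gm.recompile()
print(gm.graph)  # the dangling get_attr is gone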

torch/_dynamo/variables/builder.py

Lines changed: 15 additions & 0 deletions
@@ -1664,6 +1664,21 @@ def wrap_tensor(self, value: torch.Tensor):
                 value, self.name, source=source
             )

+        if get_static_address_type(value) == "guarded":
+            # If it's a guarded tensor, we can install the parameter directly
+            # into the Fx graph instead of lifting it as an input. Lifting
+            # offers no benefit, such as regional compilation, since we still
+            # guard on the tensor's ID. Moreover, installing it in the Fx graph
+            # eliminates the pre-graph bytecode required to extract the tensor
+            # from locals/globals, reducing overhead. This can lead to
+            # significant cost savings, especially for optimizers handling many
+            # tensors.
+            self.install_guards(GuardBuilder.ID_MATCH)
+            self.assert_not_wrapped_by_this_graph(value)
+            return self.tx.output.register_attr_or_module(
+                value, self.name, source=source
+            )
+
         if is_constant_source(source):
             self.assert_not_wrapped_by_this_graph(value)
             return self.tx.output.register_attr_or_module(
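
To make the lift-vs-install distinction in the comment concrete, a hand-built torch.fx sketch (illustrative only; Dynamo's register_attr_or_module does considerably more bookkeeping than shown here):

import torch
from torch import fx

t = torch.ones(3)

# Lifted form: the tensor arrives as a graph input, so pre-graph bytecode
# must fetch it from locals/globals and pass it in on every call.
g_lifted = fx.Graph()
x = g_lifted.placeholder("x")
p = g_lifted.placeholder("lifted_t")
g_lifted.output(g_lifted.call_function(torch.add, (x, p)))
lifted = fx.GraphModule(torch.nn.Module(), g_lifted)

# Installed form: the tensor lives on the module and is fetched via
# get_attr inside the graph, so only `x` crosses the graph boundary.
root = torch.nn.Module()
root.register_buffer("installed_t", t)
g_installed = fx.Graph()
x2 = g_installed.placeholder("x")
a = g_installed.get_attr("installed_t")
g_installed.output(g_installed.call_function(torch.add, (x2, a)))
installed = fx.GraphModule(root, g_installed)

print(lifted(torch.zeros(3), t))   # the tensor must be threaded through
print(installed(torch.zeros(3)))   # the tensor is baked into the module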

torch/_dynamo/variables/optimizer.py

Lines changed: 2 additions & 2 deletions
@@ -358,7 +358,7 @@ def wrap_tensor(self, tx: "InstructionTranslator", tensor_value):
             # mark these tensors as static for cudagraphs
             mark_static_address(tensor_value)
             source = self.tensor_to_source[tensor_value]
-            self.static_tensor_names.add(tx.output.module_key_name(source.name))
+            self.static_tensor_names.add(tx.output.module_key_name(source.name()))
         elif tensor_value in self.grad_to_source:
             source = self.grad_to_source[tensor_value]
         else:

@@ -367,7 +367,7 @@ def wrap_tensor(self, tx: "InstructionTranslator", tensor_value):

             global_name = tx.store_global_weakref_by_id(GLOBAL_KEY_PREFIX, tensor_value)
             source = GlobalWeakRefSource(global_name)
-            self.static_tensor_names.add(tx.output.module_key_name(source.name))
+            self.static_tensor_names.add(tx.output.module_key_name(source.name()))

         return VariableTracker.build(tx, tensor_value, source)
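
The two one-character changes above fix a classic Python slip: when name is a method, source.name evaluates to the bound method object rather than the string it returns, so the wrong key ends up in static_tensor_names. A minimal sketch with a hypothetical Source class (for illustration only):

class Source:
    def name(self) -> str:
        return "optimizer_state['exp_avg']"  # hypothetical key

source = Source()
print(source.name)    # <bound method Source.name of ...>, not a string
print(source.name())  # the actual name to pass to module_key_name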

torch/_functorch/aot_autograd.py

Lines changed: 2 additions & 0 deletions
@@ -1031,6 +1031,8 @@ def _try_get_metadata_from_dynamo(
         aot_autograd_arg_pos_to_source.append(source)

     # Collect the dynamo graph inputs
+    # TODO(mlazos): Revisit if this is still needed. With Dynamo install ID
+    # matched tensors back into the Fx graph, this might not be necessary.
     static_input_indices = []
     for pos, node in enumerate(mod.graph.find_nodes(op="placeholder")):
         assert hasattr(node, "_dynamo_source")
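
For reference, a simplified sketch of what this loop feeds into: collecting the positions of placeholders that Dynamo has tagged as static. The real code keys off the _dynamo_source attached to each placeholder; the tensor_dict metadata key below is the one checked in the test_decorators test above, so treat this as schematic only:

from torch import fx

def collect_static_input_indices(mod: fx.GraphModule) -> list[int]:
    # Schematic only: gather placeholder positions whose tensors Dynamo
    # marked static (guarded or unguarded) via mark_static_address.
    indices = []
    for pos, node in enumerate(mod.graph.find_nodes(op="placeholder")):
        if node.meta.get("tensor_dict", {}).get("_dynamo_static_input_type"):
            indices.append(pos)
    return indices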

torch/_inductor/utils.py

Lines changed: 0 additions & 6 deletions
@@ -2085,12 +2085,6 @@ def num_fw_fixed_arguments(dynamo_gm_num_inputs: int, aot_fw_gm_num_inputs: int)
     # AOT won't lift any parameters if we're inlining NN Modules
     # however desugaring subclasses will still add arguments
     # resulted in extra fixed inputs https://github.com/pytorch/pytorch/issues/130502
-    if (
-        torch._dynamo.config.inline_inbuilt_nn_modules
-        and not torch._dynamo.utils.is_parameter_freezing()
-    ):
-        return 0
-
     return aot_fw_gm_num_inputs - dynamo_gm_num_inputs - num_rng_seed_offset_inputs
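
With the early-return removed, the count is always the plain difference. For example (hypothetical numbers, purely to illustrate the formula): with 12 AOT forward-graph inputs, 8 Dynamo graph inputs, and 1 RNG seed/offset input, num_fw_fixed_arguments returns 12 - 8 - 1 = 3 fixed arguments.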
