8000 Update on "[cudagraphs] Fix issue in collecting static_input_idxs" · pytorch/pytorch@8ccbcc7 · GitHub
[go: up one dir, main page]

Skip to content

Commit 8ccbcc7

Browse files
committed
Update on "[cudagraphs] Fix issue in collecting static_input_idxs"
related to #152275 cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov [ghstack-poisoned]
2 parents a592e8c + 16a8989 commit 8ccbcc7

File tree

3 files changed

+33
-19
lines changed

3 files changed

+33
-19
lines changed

test/inductor/test_inductor_freezing.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Owner(s): ["module: inductor"]
22
import contextlib
3+
import copy
34
import functools
45
import importlib
56
import itertools
@@ -376,19 +377,33 @@ def foo(mod, x):
376377
mod(x)
377378

378379
def test_static_indices_cudagraph(self):
379-
mod = torch.nn.Linear(2, 2).to(self.device)
380+
if self.device != "cuda":
381+
return
382+
383+
mod1 = torch.nn.Sequential(
384+
torch.nn.Linear(2, 2).to(self.device), torch.nn.Linear(2, 2).to(self.device)
385+
)
386+
mod2 = copy.deepcopy(mod1)
380387

381-
def fn(x):
382-
return mod(x) + x
388+
def fn(x, y, mod):
389+
x.add_(1)
390+
getattr(mod, "0").bias.add_(2)
391+
getattr(mod, "1").weight.add_(3)
392+
return mod(x) + y
383393

384-
x = torch.randn(2, 2, device=self.device)
394+
x1 = torch.randn(2, 2, device=self.device)
395+
y1 = torch.randn(2, 2, device=self.device)
396+
x2 = x1.clone()
397+
y2 = y1.clone()
385398

386399
opt_fn = torch.compile(fn, mode="reduce-overhead")
387400

388401
with torch.no_grad():
389-
ref = fn(x)
390-
res = opt_fn(x)
402+
ref = fn(x1, y1, mod1)
403+
res = opt_fn(x2, y2, mod2)
391404
self.assertEqual(ref, res)
405+
self.assertEqual(x1, x2)
406+
self.assertEqual(y1, y2)
392407

393408
def test_rng_op(self):
394409
@torch.compile()

torch/_functorch/aot_autograd.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
from torch._dynamo.utils import (
2020
CompileEventLogger,
2121
dynamo_timed,
22-
is_parameter_freezing,
2322
preserve_rng_state,
2423
set_feature_use,
2524
)
@@ -1038,10 +1037,7 @@ def _try_get_metadata_from_dynamo(
10381037
seen_sources.add(source)
10391038
aot_autograd_arg_pos_to_source.append(source)
10401039

1041-
# For freezing, the params are not lifted in the inductor Fx graph, so
1042-
# don't mark the params as static.
1043-
if not is_parameter_freezing():
1044-
static_input_indices.append(i)
1040+
static_input_indices.append(i)
10451041

10461042
# Collect the dynamo graph inputs
10471043
# TODO(mlazos): Revisit if this is still needed. With Dynamo install ID
@@ -1057,12 +1053,8 @@ def _try_get_metadata_from_dynamo(
10571053
# input[i] in dynamo is now:
10581054
# input[i + len(extra_params)] in AOT,
10591055
# where extra_params are the params/buffers that dynamo baked into the
1060-
# OutputGraph. The special case is freezing, where the params are not
1061-
# lifted.
1062-
if is_parameter_freezing():
1063-
actual_pos = pos
1064-
else:
1065-
actual_pos = pos + len(param_keys)
1056+
# OutputGraph
1057+
actual_pos = pos + len(param_keys)
10661058

10671059
if "tensor_dict" in node.meta and node.meta["tensor_dict"].get(
10681060
"_dynamo_static_input_type", None

torch/_inductor/freezing.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,14 +52,21 @@ def replace_params_with_constants(
5252
in (MutationType.MUTATED_IN_GRAPH, MutationType.MUTATED_OUT_GRAPH)
5353
]
5454

55+
static_indices_new = []
56+
static_indices_offset = 0
5557
for i, (real_input, node) in enumerate(zip(flat_params, fake_inp_nodes)):
5658
if i in mutated_inps or i in aliased_input_args:
5759
preserved_arg_indices.append(i)
58-
continue
59-
replace_node_with_constant(gm, node, real_input)
60+
if i in fw_metadata.static_input_indices:
61+
new_static_index = i - static_indices_offset
62+
static_indices_new.append(new_static_index)
63+
else:
64+
replace_node_with_constant(gm, node, real_input)
65+
static_indices_offset += 1
6066
# add on non param inputs
6167
preserved_arg_indices.extend(range(len(flat_params), len(params)))
6268
# is this necessary ?
69+
fw_metadata.static_input_indices = static_indices_new
6370
gm.recompile()
6471
return preserved_arg_indices
6572

0 commit comments

Comments
 (0)