[cudagraphs] Fix issue in collecting static_input_idxs by anijain2305 · Pull Request #152287 · pytorch/pytorch · GitHub

[cudagraphs] Fix issue in collecting static_input_idxs #152287


Closed
wants to merge 7 commits
4 changes: 2 additions & 2 deletions test/dynamo/test_subclasses.py
@@ -2135,9 +2135,9 @@ def inner_compile(
     extern_node_serializer: Optional[Callable[[list[Any]], Any]] = None,
 ):
     if dynamic:
-        self.assertEqual(static_input_idxs, [0, 1, 2, 3, 4])
+        self.assertEqual(static_input_idxs, [2, 3, 4])
Contributor:
This test looks strictly more correct now than previously. For sanity, this is the signature of the AOT graph in this test:

def forward(self, arg0_1: "Sym(s25)", arg1_1: "f32[s25][1]cpu", arg2_1: "f32[s25][1]cpu", arg3_1: "f32[s25][1]cpu", arg4_1: "Sym(s25)", arg5_1: "f32[s25][1]cpu"):

Where indices [2,3] correspond to the two static tensor inputs that mapped to the static TwoTensor subclass.

One thing that is wrong in this test, though, is that:

(1) in the dynamic shapes variant of this test, we have extra SymInt graph args that correspond to the symbolic sizes of the subclass

(2) we are marking those inputs as static indices as well, which is happening here: https://github.com/pytorch/pytorch/blob/main/torch/_functorch/_aot_autograd/subclass_utils.py#L308

This seems wrong. It might turn out not to cause too many problems if inductor has logic to properly filter out SymInts from the "static input indices" list later (given that integers have no memory address and get burned into cudagraphs anyway). But we should probably fix it either way. cc @mlazos
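A minimal sketch of the kind of SymInt filtering this comment has in mind (a hypothetical helper, not the actual inductor code): drop any static index whose placeholder does not carry a tensor value, since integer inputs have no storage to pin for cudagraphs.

```python
import torch


def drop_symint_static_idxs(
    gm: torch.fx.GraphModule, static_input_idxs: list[int]
) -> list[int]:
    # Placeholders in graph order; static_input_idxs index into this list.
    placeholders = list(gm.graph.find_nodes(op="placeholder"))
    # Keep only indices whose "val" metadata is an actual (fake) tensor; SymInt
    # placeholders get burned into the cudagraph rather than pinned in memory.
    return [
        i
        for i in static_input_idxs
        if isinstance(placeholders[i].meta.get("val"), torch.Tensor)
    ]
```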

Contributor:

This makes sense, I can take a look at the issue.

     else:
-        self.assertEqual(static_input_idxs, [0, 1, 2])
+        self.assertEqual(static_input_idxs, [1, 2])
     return gm

 compiler = functools.partial(compile_fx, inner_compile=inner_compile)
34 changes: 34 additions & 0 deletions test/inductor/test_cudagraph_trees.py
@@ -2411,6 +2411,40 @@ def fn(x, y):
         self.run_static_input_param_test(fn, 4)
         self.assertEqual(counters["inductor"]["cudagraph_skips"], 0)
+
+    @torch._dynamo.config.patch("error_on_recompile", True)
+    @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
+    def test_no_rerecord_with_mark_static_address(self):
+        class Mod(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = nn.Linear(2, 2)
+
+            def forward(self, x):
+                return self.linear(x)
+
+        mod = Mod().cuda()
+
+        def fn_eager(x, marked_static_y):
+            return torch.cos(x) + mod(marked_static_y)
+
+        with torch.device("cuda"):
+            fn_compiled = torch.compile(fn_eager, mode="reduce-overhead")
+
+            # y is marked static
+            y = torch.randn(2, 2)
+            torch._dynamo.mark_static_address(y)
+
+            # Changing the pointer of x should not lead to re-records
+            for _ in range(5):
+                x = torch.randn(2, 2, requires_grad=True)
+                res = fn_compiled(x, y)
+                res.sum().backward()
+                x.grad = None
+                mod.linear.weight.grad = None
+                mod.linear.bias.grad = None
+            # One forward and one backward
+            self.assertEqual(self.get_manager().new_graph_id().id, 2)

     def test_tensor_constant_mutation(self):
         class Foo(torch.nn.Module):
             def __init__(self) -> None:
17 changes: 12 additions & 5 deletions torch/_functorch/aot_autograd.py
@@ -1028,18 +1028,19 @@ def _try_get_metadata_from_dynamo(
     seen_sources = set()

     aot_autograd_arg_pos_to_source = []
+    static_input_indices = []
     # Collect the new inputs lifted by aotdispatch
-    for name in param_keys:
+    for i, name in enumerate(param_keys):
         assert name in param_name_to_source, f"{name} not found."
         source = param_name_to_source[name]
         assert source not in seen_sources, source
         seen_sources.add(source)
         aot_autograd_arg_pos_to_source.append(source)
+        static_input_indices.append(i)

     # Collect the dynamo graph inputs
     # TODO(mlazos): Revisit if this is still needed. With Dynamo install ID
     # matched tensors back into the Fx graph, this might not be necessary.
-    static_input_indices = []
     for pos, node in enumerate(mod.graph.find_nodes(op="placeholder")):
         assert hasattr(node, "_dynamo_source")
         source = node._dynamo_source
@@ -1048,16 +1049,22 @@ def _try_get_metadata_from_dynamo(
         aot_autograd_arg_pos_to_source.append(source)
         source_name = source.name() if source else str(source)

+        # input[i] in dynamo is now:
+        # input[i + len(extra_params)] in AOT,
+        # where extra_params are the params/buffers that dynamo baked into
+        # the OutputGraph
+        actual_pos = pos + len(param_keys)
+
         if "tensor_dict" in node.meta and node.meta["tensor_dict"].get(
             "_dynamo_static_input_type", None
         ):
             static_inputs_log.debug(
-                "Adding static input pos %s for source %s", pos, source_name
+                "Adding static input pos %s for source %s", actual_pos, source_name
             )
-            static_input_indices.append(pos)
+            static_input_indices.append(actual_pos)
         else:
             static_inputs_log.debug(
-                "Non-static input pos %s for source %s", pos, source_name
+                "Non-static input pos %s for source %s", actual_pos, source_name
             )

     assert full_args_num == len(aot_autograd_arg_pos_to_source)
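To make the index bookkeeping above concrete, here is a tiny standalone illustration with made-up values (not taken from the PR): the params/buffers lifted by dynamo occupy the first AOT argument positions and are all treated as static, so a dynamo placeholder marked static at position `pos` lands at `pos + len(param_keys)` in the AOT graph.

```python
# Illustration only: two params/buffers lifted by dynamo into the AOT
# calling convention occupy AOT argument positions 0 and 1 and are static.
param_keys = ["linear.weight", "linear.bias"]

# Suppose the dynamo graph has three placeholders and only placeholder 1
# was marked static (e.g. via torch._dynamo.mark_static_address).
dynamo_static_positions = [1]

# Lifted params come first, then the dynamo placeholders shifted by len(param_keys).
static_input_indices = list(range(len(param_keys)))  # [0, 1]
static_input_indices += [
    pos + len(param_keys) for pos in dynamo_static_positions  # 1 -> 3
]
assert static_input_indices == [0, 1, 3]
```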
2 changes: 1 addition & 1 deletion torch/_inductor/compile_fx.py
@@ -212,7 +212,7 @@ def get_static_input_idxs(num_fixed: int) -> list[int]:
     if not context or not context.fw_metadata:
         return fixed

-    return fixed + context.fw_metadata.static_input_indices
+    return context.fw_metadata.static_input_indices


 def record_original_output_strides(gm: GraphModule) -> None:
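With the AOT-side change above, `fw_metadata.static_input_indices` already starts with the indices of the lifted params, so prepending `fixed` here would count them twice; hence the one-line change. A small sketch with illustrative values (not real metadata):

```python
# Illustrative values only, mirroring the change in get_static_input_idxs.
num_fixed = 2
fixed = list(range(num_fixed))                 # [0, 1] -- lifted params/buffers
fw_static_input_indices = [0, 1, 3]            # now already contains the fixed idxs

old_result = fixed + fw_static_input_indices   # [0, 1, 0, 1, 3] -- duplicates
new_result = fw_static_input_indices           # [0, 1, 3]
assert new_result == sorted(set(old_result))
```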