[Trace PyDispatcher] Capture Vmapped autograd function as graph by yanboliang · Pull Request #146288 · pytorch/pytorch · GitHub

[Trace PyDispatcher] Capture Vmapped autograd function as graph #146288


Open · wants to merge 3 commits into base: gh/yanboliang/62/base
185 changes: 184 additions & 1 deletion test/dynamo/test_python_dispatcher.py
@@ -3,7 +3,12 @@

import torch
import torch._dynamo.test_case
from torch._dynamo.testing import CompileCounter, EagerAndRecordGraphs, normalize_gm
from torch._dynamo.testing import (
CompileCounter,
CompileCounterWithBackend,
EagerAndRecordGraphs,
normalize_gm,
)
from torch.testing._internal.common_cuda import TEST_CUDA


@@ -130,6 +135,184 @@ def fn(x, y):
# No recompile
self.assertEqual(counter.frame_count, 1)

def test_vmapped_autograd_function(self):
Contributor

For testing, we have a lot of autograd.Functions we can test. A good, comprehensive way to do this is to copy-paste the following and add torch.compile(backend="eager") testing to it:

"test_vmap_exhaustive",
(but only for the autograd_function_db).

The vmap tests are able to generate inputs with various in_dims, which better exercises the logic. They're also able to generate vmap(vmap(...)) tests.
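
For illustration, a rough sketch of what that compile-under-vmap coverage could look like over the autograd_function_db (the helper name, the naive dim-0 batching, and the comparison strategy are assumptions for this sketch; the real test_vmap_exhaustive machinery also varies in_dims, nests vmap, and would need skips/xfails for entries that don't compile yet):

import torch
from torch.testing._internal.autograd_function_db import autograd_function_db

def check_compiled_vmap(op, device="cpu", dtype=torch.float32, batch=4):
    # Compare eager vmap against torch.compile(backend="eager") wrapping the same vmap call.
    for sample in op.sample_inputs(device, dtype, requires_grad=False):
        args = [sample.input, *sample.args]
        # Naively add a leading batch dim to every tensor arg; leave non-tensors unbatched.
        batched = [
            a.unsqueeze(0).expand(batch, *a.shape).clone() if isinstance(a, torch.Tensor) else a
            for a in args
        ]
        in_dims = tuple(0 if isinstance(a, torch.Tensor) else None for a in args)
        expected = torch.vmap(op.op, in_dims)(*batched)
        actual = torch.compile(torch.vmap(op.op, in_dims), backend="eager", fullgraph=True)(*batched)
        torch.testing.assert_close(actual, expected)

# e.g. check_compiled_vmap(autograd_function_db[0])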

eager = EagerAndRecordGraphs()

class Foo(torch.autograd.Function):
generate_vmap_rule = True

@staticmethod
def forward(x):
return x * 2

@staticmethod
def setup_context(ctx, inputs, output):
pass

@staticmethod
def backward(ctx, grad):
return grad * 2

@torch.compile(backend=eager, fullgraph=True)
def fn(x):
return torch.vmap(Foo.apply)(x)
Comment on lines +156 to +158
Contributor

Can you do a more complicated (non-pointwise) test case with double vmap? Maybe use double vmap on the LinearFunction?

Contributor Author

Will do!
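
For reference, a minimal sketch of the requested double-vmap variant over the LinearFunction added later in this diff (the shapes, dtype, and aot_eager backend are illustrative, and it assumes the nested case is handled by this PR's tracing):

def fn(input, weight, bias):
    return torch.vmap(torch.vmap(LinearFunction.apply))(input, weight, bias)

# Two leading batch dims, one per vmap level.
input = torch.randn(5, 4, 2, 2, dtype=torch.double, requires_grad=True)
weight = torch.randn(5, 4, 3, 2, dtype=torch.double, requires_grad=True)
bias = torch.randn(5, 4, 3, dtype=torch.double, requires_grad=True)

compiled_fn = torch.compile(fn, backend="aot_eager", fullgraph=True)
torch.testing.assert_close(compiled_fn(input, weight, bias), fn(input, weight, bias))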


x = torch.randn(2, 3, requires_grad=True)
self.assertEqual(fn(x), torch.vmap(Foo.apply)(x))

graph = eager.graphs[0]
actual = normalize_gm(graph.print_readable(False))
self.assertExpectedInline(
actual,
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_x_: "f32[2, 3]"):
l_x_ = L_x_

lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None

_vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(2, 'error'); _vmap_increment_nesting = None

a: "f32[3]" = torch._C._functorch._add_batch_dim(l_x_, 0, 1); l_x_ = None

_are_functorch_transforms_active = torch._C._are_functorch_transforms_active(); _are_functorch_transforms_active = None

_are_functorch_transforms_active_1 = torch._C._are_functorch_transforms_active(); _are_functorch_transforms_active_1 = None

child: "f32[3]" = torch._C._functorch.unwrap_if_dead(a); a = None

_unwrap_batched = torch._C._functorch._unwrap_batched(child, 1); child = None
getitem: "f32[2, 3]" = _unwrap_batched[0]; _unwrap_batched = None

pop_dynamic_layer_stack = torch._C._functorch.pop_dynamic_layer_stack()

_are_functorch_transforms_active_2 = torch._C._are_functorch_transforms_active(); _are_functorch_transforms_active_2 = None

function_ctx = torch.autograd.function.FunctionCtx(); function_ctx = None
fwd_body_0 = self.fwd_body_0
bwd_body_0 = self.bwd_body_0
autograd_function_apply = torch.ops.higher_order.autograd_function_apply(fwd_body_0, bwd_body_0, getitem, args_tensor_mask = [True], non_differentiable_idx = []); fwd_body_0 = bwd_body_0 = getitem = None
outputs: "f32[2, 3]" = autograd_function_apply[0]; autograd_function_apply = None

push_dynamic_layer_stack = torch._C._functorch.push_dynamic_layer_stack(pop_dynamic_layer_stack); pop_dynamic_layer_stack = push_dynamic_layer_stack = None

result: "f32[3]" = torch._C._functorch._add_batch_dim(outputs, 0, 1); outputs = None

_remove_batch_dim: "f32[2, 3]" = torch._C._functorch._remove_batch_dim(result, 1, 2, 0); result = None

_vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None
return (_remove_batch_dim,)

class fwd_body_0(torch.nn.Module):
def forward(self, function_ctx : torch.autograd.function.Function, getitem: "f32[2, 3]"):
_set_grad_enabled = torch._C._set_grad_enabled(False); _set_grad_enabled = None
_vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(2, 'error'); _vmap_increment_nesting = None

_add_batch_dim: "f32[3]" = torch._C._functorch._add_batch_dim(getitem, 0, 1)

batched_outputs: "f32[3]" = _add_batch_dim * 2; _add_batch_dim = None

_unwrap_batched = torch._C._functorch._unwrap_batched(batched_outputs, 1); batched_outputs = None
outputs: "f32[2, 3]" = _unwrap_batched[0]
getitem_2 = _unwrap_batched[1]; _unwrap_batched = getitem_2 = None

_vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None
_vmap_increment_nesting_1 = torch._C._functorch._vmap_increment_nesting(2, 'error'); _vmap_increment_nesting_1 = None

inp: "f32[3]" = torch._C._functorch._add_batch_dim(getitem, 0, 1); getitem = inp = None
_add_batch_dim_2: "f32[3]" = torch._C._functorch._add_batch_dim(outputs, 0, 1); _add_batch_dim_2 = None

_vmap_decrement_nesting_1 = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting_1 = None
_set_grad_enabled_1 = torch._C._set_grad_enabled(True); _set_grad_enabled_1 = None
return ((outputs, 0), [])

class bwd_body_0(torch.nn.Module):
def forward(self, function_ctx : torch.autograd.function.Function, outputs: "f32[2, 3]", const_unused : int):
_set_grad_enabled = torch._C._set_grad_enabled(False); _set_grad_enabled = None
_vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(2, 'error'); _vmap_increment_nesting = None

_add_batch_dim: "f32[3]" = torch._C._functorch._add_batch_dim(outputs, 0, 1); outputs = None

batched_outputs: "f32[3]" = _add_batch_dim * 2; _add_batch_dim = None

_unwrap_batched = torch._C._functorch._unwrap_batched(batched_outputs, 1); batched_outputs = None
grad_ins: "f32[2, 3]" = _unwrap_batched[0]
getitem_1 = _unwrap_batched[1]; _unwrap_batched = getitem_1 = None

_vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None

lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None

_vmap_increment_nesting_1 = torch._C._functorch._vmap_increment_nesting(2, 'error'); _vmap_increment_nesting_1 = None

_add_batch_dim_1: "f32[3]" = torch._C._functorch._add_batch_dim(grad_ins, 0, 1); grad_ins = None

batched_outputs_1: "f32[3]" = _add_batch_dim_1.sum_to_size((3,)); _add_batch_dim_1 = None

_remove_batch_dim: "f32[2, 3]" = torch._C._functorch._remove_batch_dim(batched_outputs_1, 1, 2, 0); batched_outputs_1 = None

_vmap_decrement_nesting_1 = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting_1 = None
_set_grad_enabled_1 = torch._C._set_grad_enabled(True); _set_grad_enabled_1 = None
return (_remove_batch_dim,)
""", # NOQA: B950
)

def test_vmapped_autograd_function_fwd_and_bwd(self):
cnt = CompileCounterWithBackend("aot_eager")

class LinearFunction(torch.autograd.Function):
generate_vmap_rule = True

@staticmethod
def forward(input, weight, bias):
output = input.mm(weight.t())
if bias is not None:
output += bias.unsqueeze(0).expand_as(output)
return output

@staticmethod
def setup_context(ctx, inputs, output):
input, weight, bias = inputs
ctx.save_for_backward(input, weight, bias)

@staticmethod
def backward(ctx, grad_output):
input, weight, bias = ctx.saved_tensors
grad_input = grad_weight = grad_bias = None
if ctx.needs_input_grad[0]:
grad_input = grad_output.mm(weight)
if ctx.needs_input_grad[1]:
grad_weight = grad_output.t().mm(input)
if bias is not None and ctx.needs_input_grad[2]:
grad_bias = grad_output.sum(0)

return grad_input, grad_weight, grad_bias

def fn(input, weight, bias=None):
return torch.vmap(LinearFunction.apply)(input, weight, bias)

input1 = torch.randn(4, 2, 2, dtype=torch.double, requires_grad=True)
input2 = input1.clone().detach().requires_grad_(True)
weight1 = torch.randn(4, 3, 2, dtype=torch.double, requires_grad=True)
weight2 = weight1.clone().detach().requires_grad_(True)
bias1 = torch.randn(4, 3, dtype=torch.double, requires_grad=True)
bias2 = bias1.clone().detach().requires_grad_(True)

compiled_fn = torch.compile(backend=cnt, fullgraph=True)(fn)

output1 = fn(input1, weight1, bias1)
output1.sum().backward()

output2 = compiled_fn(input2, weight2, bias2)
output2.sum().backward()

self.assertEqual(output1, output2)
self.assertEqual(input1.grad, input2.grad)
self.assertEqual(weight1.grad, weight2.grad)
self.assertEqual(bias1.grad, bias2.grad)
self.assertEqual(cnt.frame_count, 1)
self.assertEqual(cnt.op_count, 25)


if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
4 changes: 3 additions & 1 deletion torch/_dynamo/variables/builder.py
@@ -787,7 +787,9 @@ def build_key_value(i, k, v):
and value == getattr(value.__self__, "apply", None)
):
# handle aliased autograd function `apply` calls
self.install_guards(GuardBuilder.FUNCTION_MATCH)
self.install_guards(GuardBuilder.TYPE_MATCH)
func_source = AttrSource(self.source, "__func__")
install_guard(func_source.make_guard(GuardBuilder.ID_MATCH))
Contributor Author

Not sure if this is a bug in guarding the apply method of an autograd function. The original FUNCTION_MATCH triggers an ID check failure (below is the error stack), but it works correctly when changed to TYPE_MATCH on apply plus ID_MATCH on apply.__func__.

Traceback (most recent call last):
  File "/data/users/ybliang/debug/debug2.py", line 34, in <module>
    print(fn(x))
  File "/home/ybliang/local/pytorch/torch/_dynamo/eval_frame.py", line 570, in _fn
    return fn(*args, **kwargs)
  File "/home/ybliang/local/pytorch/torch/_dynamo/convert_frame.py", line 1400, in __call__
    return self._torchdynamo_orig_callable(
  File "/home/ybliang/local/pytorch/torch/_dynamo/convert_frame.py", line 565, in __call__
    return _compile(
  File "/home/ybliang/local/pytorch/torch/_dynamo/convert_frame.py", line 997, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/home/ybliang/local/pytorch/torch/_utils_internal.py", line 95, in wrapper_function
    return function(*args, **kwargs)
  File "/home/ybliang/local/pytorch/torch/_dynamo/convert_frame.py", line 726, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
  File "/home/ybliang/local/pytorch/torch/_dynamo/convert_frame.py", line 862, in _compile_inner
    check_fn = CheckFunctionManager(
  File "/home/ybliang/local/pytorch/torch/_dynamo/guards.py", line 2466, in __init__
    raise AssertionError(f"Guard check failed: {reasons}")
AssertionError: Guard check failed: 0/0: ___check_obj_id(G['Foo'].apply, 139997975477056) 
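
One plausible explanation for the failure (an assumption, not a confirmed root cause): apply is resolved as a bound method, and Python builds a fresh bound-method object on each attribute access, so an id() guard on Foo.apply itself is not stable, while the underlying __func__ is. A plain-Python illustration:

class Foo:
    @classmethod
    def apply(cls, x):
        return x

a = Foo.apply
b = Foo.apply
print(a == b)                    # True: same underlying function and receiver
print(a is b)                    # usually False: each access creates a new bound method
print(a.__func__ is b.__func__)  # True: the underlying function object is stable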

return GetAttrVariable(
AutogradFunctionVariable(
value.__self__, source=AttrSource(self.source, member="__self__")
11 changes: 11 additions & 0 deletions torch/_dynamo/variables/functions.py
@@ -341,6 +341,17 @@ def call_function(
]:
with torch._dynamo.side_effects.allow_side_effects_under_checkpoint(tx):
return super().call_function(tx, args, kwargs)
elif self.fn is torch._functorch.autograd_function.vmapify_autograd_function:
Contributor

Just wondering: why didn't you do something like

UserDefinedFunction(torch._functorch.autograd_function.vmapify_autograd_function).call_function(tx, args)

The current solution in this PR looks fine to me, though.

Contributor Author

Yeah, the default inlining approach is better, but I ran into a few unsupported Dynamo features while handling the following case:

Generated = type(
name,
(torch.autograd.Function,),
{
"forward": staticmethod(forward),
"backward": staticmethod(backward),
"jvp": staticmethod(jvp),
"setup_context": staticmethod(setup_context),
"generate_vmap_rule": True,
},

This includes issues like constructing NestedUserFunctionVariable without a source, among others. Since I’d like to keep this PR focused on tracing the vmapped autograd function rather than addressing broader issues, I decided to go with this approach for now.

That said, I’m happy to revisit this later and migrate it to the inlining approach as a follow-up. I’ll add a TODO comment here to track it.

Contributor

TODO sounds fine

assert isinstance(args[0], variables.AutogradFunctionVariable)
new_autograd_fn = (
torch._functorch.autograd_function.vmapify_autograd_function(
args[0].fn_cls,
args[1].as_python_constant(),
args[2].as_python_constant(),
args[3].as_python_constant(),
)
)
return variables.AutogradFunctionVariable(new_autograd_fn)
return super().call_function(tx, args, kwargs)


77 changes: 76 additions & 1 deletion torch/_dynamo/variables/misc.py
@@ -624,6 +624,64 @@ def __init__(self, fn_cls, **kwargs) -> None:
super().__init__(**kwargs)
self.fn_cls = fn_cls

def as_proxy(self):
return self.fn_cls
Comment on lines +627 to +628
Contributor

Why does this have an as_proxy? We shouldn't be putting autograd.Functions into the graph

Contributor Author

Yes, this is a typo; we don't actually use it. I'll remove it.


def python_type(self):
return type(self.fn_cls)

def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
from torch._functorch.autograd_function import (
autograd_function_forward_rewritten,
)

from .builder import SourcelessBuilder, VariableBuilder
from .higher_order_ops import AutogradFunctionApplyVariable

# Special handling for the vmapped autograd function because:
# 1. We cannot guard against the vmapped autograd function, as it is generated on the fly.
# 2. Skipping this guard is acceptable since we already guard on `id(Generated)`.
# 3. `AutogradFunctionApplyVariable` requires `parent_source` to be non-None,
# though this constraint could be relaxed in the future.
Comment on lines +644 to +645
Contributor

How difficult is it to relax this constraint? Otherwise the source we're generating here is incorrect, right?

Contributor Author

I don’t think it’s very difficult, though I haven’t looked into it too deeply. I just want to keep this PR focused on its intended scope, but I’ll address it as a follow-up.

The source we generate here is correct—the issue is that we can’t generate a guard from it. The problem arises because it’s trying to generate guards on torch._functorch.autograd_function.vmapped_xxx.apply, which leads to errors during guard evaluation. This happens because the vmapped autograd function is created on the fly during compilation, and we don’t materialize it.

One possible solution could be adding a new guard specifically for objects generated dynamically.

Contributor

sure

if (
name == "apply"
and self.fn_cls.__name__.startswith("Vmapped")
Contributor

We should have some more robust way of identifying a vmapped autograd function. Since these are generated in Dynamo now, could we set a flag when constructing the AutogradFunctionVariable?

Contributor Author
@yanboliang Feb 4, 2025

Good point! Setting a flag during construction is definitely a more robust approach. I’ll update it.

and not torch._C._are_functorch_transforms_active()
):
forward_fn = autograd_function_forward_rewritten(
self.fn_cls.forward, self.fn_cls.setup_context
)

source = self.source
if source is None:
source = AttrSource(
tx.import_source(self.fn_cls.__module__), self.fn_cls.__name__
)

val = AutogradFunctionApplyVariable(
forward_fn,
self.fn_cls.backward,
source,
source=AttrSource(source, member="apply"),
)
return val

# General case.
try:
attr_value = getattr(self.fn_cls, name)
source = self.source
if source is None:
source = AttrSource(
tx.import_source(self.fn_cls.__module__), self.fn_cls.__name__
)
if source:
attr_source = AttrSource(source, name)
return VariableBuilder(tx, attr_source)(attr_value)
else:
return SourcelessBuilder.create(tx, attr_value)
except AttributeError:
unimplemented(f"getattr({self.fn_cls}, {name})")

def call_apply(self, tx: "InstructionTranslator", args, kwargs):
requires_grad = False
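
Regarding the thread above about identifying vmapped autograd functions more robustly than with __name__.startswith("Vmapped"): a purely hypothetical sketch of the flag-based idea (vmapified is not an existing AutogradFunctionVariable parameter; it only illustrates the reviewer's suggestion):

import torch
from torch._dynamo.variables.base import VariableTracker

class AutogradFunctionVariable(VariableTracker):
    def __init__(self, fn_cls, vmapified: bool = False, **kwargs) -> None:
        super().__init__(**kwargs)
        self.fn_cls = fn_cls
        # Hypothetical flag: set to True where vmapify_autograd_function is intercepted
        # in functions.py, instead of relying on the generated class's name prefix.
        self.vmapified = vmapified

    def var_getattr(self, tx, name):
        if (
            name == "apply"
            and self.vmapified  # flag check instead of __name__.startswith("Vmapped")
            and not torch._C._are_functorch_transforms_active()
        ):
            ...  # wrap as an AutogradFunctionApplyVariable, as in the diff above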

@@ -744,7 +802,17 @@ def call_method(
from ..trace_rules import is_callable_allowed
from .builder import wrap_fx_proxy

if name == "apply":
# There are two cases to handle the apply method of an autograd function:
# 1. If the autograd function is not vmapified:
# - We can directly handle it by either treating it as allow_in_graph or
# wrapping it as an AutogradFunctionApplyVariable HOP.
Comment on lines +806 to +808
Contributor

I don't think we ever allow_in_graph an autograd.Function unless the user has explicitly used allow_in_graph?

Contributor Author
@yanboliang Feb 4, 2025

Yes, here it refers to users explicitly using the allow_in_graph decorator. I'll update the comment to clarify this.

# 2. If the autograd function is vmapified, there are two types to consider within the same process:
# - The vmapped autograd function (name starts with "Vmapped"):
# - We treat it as allow_in_graph or wrap it as an AutogradFunctionApplyVariable HOP.
# - The original autograd function (be called when functorch transforms are active):
# - Since we already wrap the vmapped autograd function as an AutogradFunctionApplyVariable HOP,
# and the vmapped autograd function calls the original autograd function, we simply inline them.
if name == "apply" and not torch._C._are_functorch_transforms_active():
Contributor

Er, what happens if functorch transforms are active but you have a regular autograd.Function? Something like:

def f(x, y):
    z = Foo.apply(y)
    return x * z

x = torch.randn(3)
y = torch.randn([])
vmap(f, (0, None))(x, y)

Contributor Author

It goes into the else branch down below, which inlines the apply method. This is because we only want call_apply on the vmapped autograd function and to capture the fwd/bwd graphs there. While tracing the vmapped autograd function, it triggers a call to the original (regular) autograd function's apply method, which we should just inline.

if is_callable_allowed(self.fn_cls):
trampoline_autograd_apply = produce_trampoline_autograd_apply(
self.fn_cls
Expand All @@ -763,6 +831,7 @@ def call_method(
elif name == "backward":
return self.call_backward(tx, args, kwargs)
else:
# Simply inline these methods.
from .. import trace_rules

source = AttrSource(self.source, name) if self.source is not None else None
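
To make the scenario from the thread above concrete, a minimal sketch (backend, shapes, and the surrounding wrapper are illustrative; it assumes the behavior described in the reply, i.e., the vmapped wrapper is captured as an autograd_function_apply HOP while the original Foo.apply is inlined during its tracing):

import torch

class Foo(torch.autograd.Function):
    generate_vmap_rule = True

    @staticmethod
    def forward(x):
        return x * 2

    @staticmethod
    def setup_context(ctx, inputs, output):
        pass

    @staticmethod
    def backward(ctx, grad):
        return grad * 2

def f(x, y):
    # Foo.apply is hit while the vmap transform is active; per the reply, it is inlined.
    z = Foo.apply(y)
    return x * z

@torch.compile(backend="eager", fullgraph=True)
def g(x, y):
    return torch.vmap(f, (0, None))(x, y)

x = torch.randn(3)
y = torch.randn([])
out = g(x, y)
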
@@ -1000,6 +1069,12 @@ def as_python_constant(self):
except AttributeError:
raise NotImplementedError(f"{self} is not a constant") from None

def call_obj_hasattr(self, tx: "InstructionTranslator", name):
if isinstance(self.obj, AutogradFunctionVariable) and self.name == "apply":
return variables.ConstantVariable.create(
hasattr(self.obj.fn_cls.apply, name)
)

def const_getattr(self, tx: "InstructionTranslator", name):
if not isinstance(self.obj, variables.NNModuleVariable):
raise NotImplementedError
1 change: 1 addition & 0 deletions torch/_dynamo/variables/torch.py
@@ -107,6 +107,7 @@
constant_fold_functions = [
torch._assert,
torch._utils._get_device_index,
torch._C._functorch.current_level,
torch._C._get_cublas_allow_tf32,
torch._C._is_any_autocast_enabled,
torch.cuda.get_device_properties,