[Trace PyDispatcher] Capture Vmapped autograd function as graph #146288
base: gh/yanboliang/62/base
Changes from all commits
@@ -3,7 +3,12 @@
import torch
import torch._dynamo.test_case
-from torch._dynamo.testing import CompileCounter, EagerAndRecordGraphs, normalize_gm
+from torch._dynamo.testing import (
+    CompileCounter,
+    CompileCounterWithBackend,
+    EagerAndRecordGraphs,
+    normalize_gm,
+)
from torch.testing._internal.common_cuda import TEST_CUDA
@@ -130,6 +135,184 @@ def fn(x, y):
        # No recompile
        self.assertEqual(counter.frame_count, 1)

    def test_vmapped_autograd_function(self):
        eager = EagerAndRecordGraphs()

        class Foo(torch.autograd.Function):
            generate_vmap_rule = True

            @staticmethod
            def forward(x):
                return x * 2

            @staticmethod
            def setup_context(ctx, inputs, output):
                pass

            @staticmethod
            def backward(ctx, grad):
                return grad * 2

        @torch.compile(backend=eager, fullgraph=True)
        def fn(x):
            return torch.vmap(Foo.apply)(x)
Review comment (on lines +156 to +158): Can you do a more complicated (non-pointwise) test case with double vmap? Maybe use double vmap on the LinearFunction?

Reply: Will do!

(A sketch of such a double-vmap check appears after the second test below.)

        x = torch.randn(2, 3, requires_grad=True)
        self.assertEqual(fn(x), torch.vmap(Foo.apply)(x))
zou3519 marked this conversation as resolved.
        graph = eager.graphs[0]
        actual = normalize_gm(graph.print_readable(False))
        self.assertExpectedInline(
            actual,
            """\
class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "f32[2, 3]"):
        l_x_ = L_x_

        lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None

        _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(2, 'error'); _vmap_increment_nesting = None

        a: "f32[3]" = torch._C._functorch._add_batch_dim(l_x_, 0, 1); l_x_ = None

        _are_functorch_transforms_active = torch._C._are_functorch_transforms_active(); _are_functorch_transforms_active = None

        _are_functorch_transforms_active_1 = torch._C._are_functorch_transforms_active(); _are_functorch_transforms_active_1 = None

        child: "f32[3]" = torch._C._functorch.unwrap_if_dead(a); a = None

        _unwrap_batched = torch._C._functorch._unwrap_batched(child, 1); child = None
        getitem: "f32[2, 3]" = _unwrap_batched[0]; _unwrap_batched = None

        pop_dynamic_layer_stack = torch._C._functorch.pop_dynamic_layer_stack()

        _are_functorch_transforms_active_2 = torch._C._are_functorch_transforms_active(); _are_functorch_transforms_active_2 = None

        function_ctx = torch.autograd.function.FunctionCtx(); function_ctx = None
        fwd_body_0 = self.fwd_body_0
        bwd_body_0 = self.bwd_body_0
        autograd_function_apply = torch.ops.higher_order.autograd_function_apply(fwd_body_0, bwd_body_0, getitem, args_tensor_mask = [True], non_differentiable_idx = []); fwd_body_0 = bwd_body_0 = getitem = None
        outputs: "f32[2, 3]" = autograd_function_apply[0]; autograd_function_apply = None

        push_dynamic_layer_stack = torch._C._functorch.push_dynamic_layer_stack(pop_dynamic_layer_stack); pop_dynamic_layer_stack = push_dynamic_layer_stack = None

        result: "f32[3]" = torch._C._functorch._add_batch_dim(outputs, 0, 1); outputs = None

        _remove_batch_dim: "f32[2, 3]" = torch._C._functorch._remove_batch_dim(result, 1, 2, 0); result = None

        _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None
        return (_remove_batch_dim,)

    class fwd_body_0(torch.nn.Module):
        def forward(self, function_ctx : torch.autograd.function.Function, getitem: "f32[2, 3]"):
            _set_grad_enabled = torch._C._set_grad_enabled(False); _set_grad_enabled = None
            _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(2, 'error'); _vmap_increment_nesting = None

            _add_batch_dim: "f32[3]" = torch._C._functorch._add_batch_dim(getitem, 0, 1)

            batched_outputs: "f32[3]" = _add_batch_dim * 2; _add_batch_dim = None

            _unwrap_batched = torch._C._functorch._unwrap_batched(batched_outputs, 1); batched_outputs = None
            outputs: "f32[2, 3]" = _unwrap_batched[0]
            getitem_2 = _unwrap_batched[1]; _unwrap_batched = getitem_2 = None

            _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None
            _vmap_increment_nesting_1 = torch._C._functorch._vmap_increment_nesting(2, 'error'); _vmap_increment_nesting_1 = None

            inp: "f32[3]" = torch._C._functorch._add_batch_dim(getitem, 0, 1); getitem = inp = None
            _add_batch_dim_2: "f32[3]" = torch._C._functorch._add_batch_dim(outputs, 0, 1); _add_batch_dim_2 = None

            _vmap_decrement_nesting_1 = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting_1 = None
            _set_grad_enabled_1 = torch._C._set_grad_enabled(True); _set_grad_enabled_1 = None
            return ((outputs, 0), [])

    class bwd_body_0(torch.nn.Module):
        def forward(self, function_ctx : torch.autograd.function.Function, outputs: "f32[2, 3]", const_unused : int):
            _set_grad_enabled = torch._C._set_grad_enabled(False); _set_grad_enabled = None
            _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(2, 'error'); _vmap_increment_nesting = None

            _add_batch_dim: "f32[3]" = torch._C._functorch._add_batch_dim(outputs, 0, 1); outputs = None

            batched_outputs: "f32[3]" = _add_batch_dim * 2; _add_batch_dim = None

            _unwrap_batched = torch._C._functorch._unwrap_batched(batched_outputs, 1); batched_outputs = None
            grad_ins: "f32[2, 3]" = _unwrap_batched[0]
            getitem_1 = _unwrap_batched[1]; _unwrap_batched = getitem_1 = None

            _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None

            lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None

            _vmap_increment_nesting_1 = torch._C._functorch._vmap_increment_nesting(2, 'error'); _vmap_increment_nesting_1 = None

            _add_batch_dim_1: "f32[3]" = torch._C._functorch._add_batch_dim(grad_ins, 0, 1); grad_ins = None

            batched_outputs_1: "f32[3]" = _add_batch_dim_1.sum_to_size((3,)); _add_batch_dim_1 = None

            _remove_batch_dim: "f32[2, 3]" = torch._C._functorch._remove_batch_dim(batched_outputs_1, 1, 2, 0); batched_outputs_1 = None

            _vmap_decrement_nesting_1 = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting_1 = None
            _set_grad_enabled_1 = torch._C._set_grad_enabled(True); _set_grad_enabled_1 = None
            return (_remove_batch_dim,)
""",  # NOQA: B950
        )

    def test_vmapped_autograd_function_fwd_and_bwd(self):
        cnt = CompileCounterWithBackend("aot_eager")

        class LinearFunction(torch.autograd.Function):
            generate_vmap_rule = True

            @staticmethod
            def forward(input, weight, bias):
                output = input.mm(weight.t())
                if bias is not None:
                    output += bias.unsqueeze(0).expand_as(output)
                return output

            @staticmethod
            def setup_context(ctx, inputs, output):
                input, weight, bias = inputs
                ctx.save_for_backward(input, weight, bias)

            @staticmethod
            def backward(ctx, grad_output):
                input, weight, bias = ctx.saved_tensors
                grad_input = grad_weight = grad_bias = None
                if ctx.needs_input_grad[0]:
                    grad_input = grad_output.mm(weight)
                if ctx.needs_input_grad[1]:
                    grad_weight = grad_output.t().mm(input)
                if bias is not None and ctx.needs_input_grad[2]:
                    grad_bias = grad_output.sum(0)

                return grad_input, grad_weight, grad_bias

        def fn(input, weight, bias=None):
            return torch.vmap(LinearFunction.apply)(input, weight, bias)

        input1 = torch.randn(4, 2, 2, dtype=torch.double, requires_grad=True)
        input2 = input1.clone().detach().requires_grad_(True)
        weight1 = torch.randn(4, 3, 2, dtype=torch.double, requires_grad=True)
        weight2 = weight1.clone().detach().requires_grad_(True)
        bias1 = torch.randn(4, 3, dtype=torch.double, requires_grad=True)
        bias2 = bias1.clone().detach().requires_grad_(True)

        compiled_fn = torch.compile(backend=cnt, fullgraph=True)(fn)

        output1 = fn(input1, weight1, bias1)
        output1.sum().backward()

        output2 = compiled_fn(input2, weight2, bias2)
        output2.sum().backward()

        self.assertEqual(output1, output2)
        self.assertEqual(input1.grad, input2.grad)
        self.assertEqual(weight1.grad, weight2.grad)
        self.assertEqual(bias1.grad, bias2.grad)
        self.assertEqual(cnt.frame_count, 1)
        self.assertEqual(cnt.op_count, 25)
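As requested in the review comment on the first test, a double-vmap case could be built on the LinearFunction above. A rough sketch of such a parity check (not part of the PR; the batch sizes, dtype, and backend choice are illustrative assumptions):

```python
def double_vmap_fn(input, weight, bias):
    # two nested vmaps, one per leading batch dimension of every argument
    return torch.vmap(torch.vmap(LinearFunction.apply))(input, weight, bias)

# shapes: two batch dims (5, 4), then the per-sample mm shapes
input = torch.randn(5, 4, 2, 3, dtype=torch.double, requires_grad=True)
weight = torch.randn(5, 4, 6, 3, dtype=torch.double, requires_grad=True)
bias = torch.randn(5, 4, 6, dtype=torch.double, requires_grad=True)

expected = double_vmap_fn(input, weight, bias)
actual = torch.compile(double_vmap_fn, backend="eager", fullgraph=True)(input, weight, bias)
torch.testing.assert_close(actual, expected)
```

Whether `fullgraph=True` holds for the nested case is exactly what such a test would exercise.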
|
||
|
||
if __name__ == "__main__": | ||
from torch._dynamo.test_case import run_tests | ||
|
@@ -787,7 +787,9 @@ def build_key_value(i, k, v):
            and value == getattr(value.__self__, "apply", None)
        ):
            # handle aliased autograd function `apply` calls
-           self.install_guards(GuardBuilder.FUNCTION_MATCH)
+           self.install_guards(GuardBuilder.TYPE_MATCH)
+           func_source = AttrSource(self.source, "__func__")
+           install_guard(func_source.make_guard(GuardBuilder.ID_MATCH))
Review comment: Not sure if this is a bug in guarding the …
            return GetAttrVariable(
                AutogradFunctionVariable(
                    value.__self__, source=AttrSource(self.source, member="__self__")
@@ -341,6 +341,17 @@
        ]:
            with torch._dynamo.side_effects.allow_side_effects_under_checkpoint(tx):
                return super().call_function(tx, args, kwargs)
        elif self.fn is torch._functorch.autograd_function.vmapify_autograd_function:
Review comment: Just wondering, why didn't you do something like `UserDefinedFunction(torch._functorch.autograd_function.vmapify_autograd_function).call_function(tx, args)`? The current solution in this PR looks fine to me though.

Reply: Yeah, the default inlining approach is better, but I ran into a few unsupported Dynamo features while handling the following case: pytorch/torch/_functorch/autograd_function.py, lines 502 to 511 in e68f508. This includes issues like constructing … That said, I'm happy to revisit this later and migrate it to the inlining approach as a follow-up. I'll add a TODO comment here to track it.

Review comment: TODO sounds fine.
            assert isinstance(args[0], variables.AutogradFunctionVariable)
            new_autograd_fn = (
                torch._functorch.autograd_function.vmapify_autograd_function(
                    args[0].fn_cls,
                    args[1].as_python_constant(),
                    args[2].as_python_constant(),
                    args[3].as_python_constant(),
                )
            )
            return variables.AutogradFunctionVariable(new_autograd_fn)
        return super().call_function(tx, args, kwargs)
@@ -624,6 +624,64 @@ def __init__(self, fn_cls, **kwargs) -> None:
        super().__init__(**kwargs)
        self.fn_cls = fn_cls

    def as_proxy(self):
        return self.fn_cls
Review comment (on lines +627 to +628): Why does this have an as_proxy? We shouldn't be putting autograd.Functions into the graph.

Reply: Yes, this is a typo; we don't actually use it and will remove it.
    def python_type(self):
        return type(self.fn_cls)

    def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
        from torch._functorch.autograd_function import (
            autograd_function_forward_rewritten,
        )

        from .builder import SourcelessBuilder, VariableBuilder
        from .higher_order_ops import AutogradFunctionApplyVariable

        # Special handling for the vmapped autograd function because:
        # 1. We cannot guard against the vmapped autograd function, as it is generated on the fly.
        # 2. Skipping this guard is acceptable since we already guard on `id(Generated)`.
        # 3. `AutogradFunctionApplyVariable` requires `parent_source` to be non-None,
        #    though this constraint could be relaxed in the future.
Review comment (on lines +644 to +645): How difficult is it to relax this constraint? Otherwise the source we're generating here is incorrect, right?

Reply: I don't think it's very difficult, though I haven't looked into it too deeply. I just want to keep this PR focused on its intended scope, but I'll address it as a follow-up. The source we generate here is correct—the issue is that we can't generate a guard from it. The problem arises because it's trying to generate guards on … One possible solution could be adding a new guard specifically for objects generated dynamically.

Review comment: Sure.
        if (
            name == "apply"
            and self.fn_cls.__name__.startswith("Vmapped")
Review comment: We should have some more robust way of identifying a vmapped autograd function. Since these are generated in Dynamo now, could we set a flag when constructing the AutogradFunctionVariable?

Reply: Good point! Setting a flag during construction is definitely a more robust approach. I'll update it.
            and not torch._C._are_functorch_transforms_active()
        ):
            forward_fn = autograd_function_forward_rewritten(
                self.fn_cls.forward, self.fn_cls.setup_context
            )

            source = self.source
            if source is None:
                source = AttrSource(
                    tx.import_source(self.fn_cls.__module__), self.fn_cls.__name__
                )

            val = AutogradFunctionApplyVariable(
                forward_fn,
                self.fn_cls.backward,
                source,
                source=AttrSource(source, member="apply"),
            )
            return val

        # General case.
        try:
            attr_value = getattr(self.fn_cls, name)
            source = self.source
            if source is None:
                source = AttrSource(
                    tx.import_source(self.fn_cls.__module__), self.fn_cls.__name__
                )
            if source:
                attr_source = AttrSource(source, name)
                return VariableBuilder(tx, attr_source)(attr_value)
            else:
                return SourcelessBuilder.create(tx, attr_value)
        except AttributeError:
            unimplemented(f"getattr({self.fn_cls}, {name})")

    def call_apply(self, tx: "InstructionTranslator", args, kwargs):
        requires_grad = False
@@ -744,7 +802,17 @@ def call_method(
        from ..trace_rules import is_callable_allowed
        from .builder import wrap_fx_proxy

-       if name == "apply":
+       # There are two cases to handle the apply method of an autograd function:
+       # 1. If the autograd function is not vmapified:
+       #    - We can directly handle it by either treating it as allow_in_graph or
+       #      wrapping it as an AutogradFunctionApplyVariable HOP.
Review comment (on lines +806 to +808): I don't think we ever allow_in_graph an autograd.Function unless the user has explicitly used allow_in_graph?

Reply: Yes, here it refers to users explicitly using the allow_in_graph decorator. I'll update the comment to clarify this.
+       # 2. If the autograd function is vmapified, there are two types to consider within the same process:
+       #    - The vmapped autograd function (name starts with "Vmapped"):
+       #      - We treat it as allow_in_graph or wrap it as an AutogradFunctionApplyVariable HOP.
+       #    - The original autograd function (called when functorch transforms are active):
+       #      - Since we already wrap the vmapped autograd function as an AutogradFunctionApplyVariable HOP,
+       #        and the vmapped autograd function calls the original autograd function, we simply inline them.
+       if name == "apply" and not torch._C._are_functorch_transforms_active():
Review comment: Er, what happens if functorch transforms are active but you have a regular autograd.Function? Something like: …

Reply: It goes into the …
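For reference, a hedged, eager-only illustration of the situation asked about above — a plain, user-defined autograd.Function whose `apply` runs while the vmap transform is active. The class name and shapes are hypothetical; how Dynamo routes this pattern when compiled is what the reply above addresses:

```python
import torch

class MySin(torch.autograd.Function):
    # a user-defined autograd.Function, not one of the dynamically generated "Vmapped*" classes
    generate_vmap_rule = True

    @staticmethod
    def forward(x):
        return x.sin()

    @staticmethod
    def setup_context(ctx, inputs, output):
        ctx.save_for_backward(inputs[0])

    @staticmethod
    def backward(ctx, grad):
        (x,) = ctx.saved_tensors
        return grad * x.cos()

def f(x):
    # MySin.apply is reached while the functorch vmap transform is active
    return MySin.apply(x) * 2

x = torch.randn(4, 3)
out = torch.vmap(f)(x)  # eager vmap; the question is how Dynamo handles this region under torch.compile
```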
            if is_callable_allowed(self.fn_cls):
                trampoline_autograd_apply = produce_trampoline_autograd_apply(
                    self.fn_cls
@@ -763,6 +831,7 @@ def call_method(
        elif name == "backward":
            return self.call_backward(tx, args, kwargs)
        else:
            # Simply inline these methods.
            from .. import trace_rules

            source = AttrSource(self.source, name) if self.source is not None else None
@@ -1000,6 +1069,12 @@ def as_python_constant(self):
        except AttributeError:
            raise NotImplementedError(f"{self} is not a constant") from None

    def call_obj_hasattr(self, tx: "InstructionTranslator", name):
        if isinstance(self.obj, AutogradFunctionVariable) and self.name == "apply":
            return variables.ConstantVariable.create(
                hasattr(self.obj.fn_cls.apply, name)
            )

    def const_getattr(self, tx: "InstructionTranslator", name):
        if not isinstance(self.obj, variables.NNModuleVariable):
            raise NotImplementedError
Review comment: For testing, we have a lot of autograd.Functions we can test. A good, comprehensive way to do this is to copy-paste the following and add torch.compile(backend="eager") testing to it:
pytorch/test/functorch/test_vmap.py, line 4341 in 01554c7
The vmap tests are able to generate inputs with various in_dims, so that better exercises the logic. They're also able to generate vmap(vmap(...)) tests.
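A rough sketch of that kind of coverage — eager-vs-compiled parity over different `in_dims` and nested vmap. The function, helper name, and shapes below are illustrative assumptions, not taken from test_vmap.py:

```python
import torch

class Double(torch.autograd.Function):
    generate_vmap_rule = True

    @staticmethod
    def forward(x):
        return x * 2

    @staticmethod
    def setup_context(ctx, inputs, output):
        pass

    @staticmethod
    def backward(ctx, grad):
        return grad * 2

def check_compiled_matches_eager(fn, *args):
    # compile with the eager backend and compare against plain eager execution
    expected = fn(*args)
    actual = torch.compile(fn, backend="eager", fullgraph=True)(*args)
    torch.testing.assert_close(actual, expected)

x = torch.randn(2, 3)
check_compiled_matches_eager(lambda t: torch.vmap(Double.apply)(t), x)             # default in_dims
check_compiled_matches_eager(lambda t: torch.vmap(Double.apply, in_dims=1)(t), x)  # non-default in_dims
check_compiled_matches_eager(
    lambda t: torch.vmap(torch.vmap(Double.apply))(t), torch.randn(2, 3, 4)        # vmap(vmap(...))
)
```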