[Trace PyDispatcher] Capture Vmapped autograd function as graph · pytorch/pytorch@00d7da9 · GitHub
Commit 00d7da9

[Trace PyDispatcher] Capture Vmapped autograd function as graph

ghstack-source-id: 52ea779
Pull Request resolved: #146288
1 parent 06559cf commit 00d7da9
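
For context, the user-facing pattern this commit teaches Dynamo to capture as a single graph is an autograd.Function with generate_vmap_rule = True applied under torch.vmap inside a compiled region. A minimal sketch distilled from the new test below (the Double class, shapes, and the default backend are illustrative, not part of the commit):

import torch

class Double(torch.autograd.Function):
    # generate_vmap_rule lets functorch derive the batching rule automatically.
    generate_vmap_rule = True

    @staticmethod
    def forward(x):
        return x * 2

    @staticmethod
    def setup_context(ctx, inputs, output):
        pass

    @staticmethod
    def backward(ctx, grad):
        return grad * 2

@torch.compile(fullgraph=True)
def fn(x):
    # With this change the vmapped apply is traced into one Dynamo graph
    # instead of failing under fullgraph=True.
    return torch.vmap(Double.apply)(x)

print(fn(torch.randn(2, 3, requires_grad=True)))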

File tree

5 files changed: +270 -3 lines changed

test/dynamo/test_python_dispatcher.py

Lines changed: 184 additions & 1 deletion
@@ -3,7 +3,12 @@
 
 import torch
 import torch._dynamo.test_case
-from torch._dynamo.testing import CompileCounter, EagerAndRecordGraphs, normalize_gm
+from torch._dynamo.testing import (
+    CompileCounter,
+    CompileCounterWithBackend,
+    EagerAndRecordGraphs,
+    normalize_gm,
+)
 from torch.testing._internal.common_cuda import TEST_CUDA
 
 
@@ -130,6 +135,184 @@ def fn(x, y):
         # No recompile
         self.assertEqual(counter.frame_count, 1)
 
+    def test_vmapped_autograd_function(self):
+        eager = EagerAndRecordGraphs()
+
+        class Foo(torch.autograd.Function):
+            generate_vmap_rule = True
+
+            @staticmethod
+            def forward(x):
+                return x * 2
+
+            @staticmethod
+            def setup_context(ctx, inputs, output):
+                pass
+
+            @staticmethod
+            def backward(ctx, grad):
+                return grad * 2
+
+        @torch.compile(backend=eager, fullgraph=True)
+        def fn(x):
+            return torch.vmap(Foo.apply)(x)
+
+        x = torch.randn(2, 3, requires_grad=True)
+        self.assertEqual(fn(x), torch.vmap(Foo.apply)(x))
+
+        graph = eager.graphs[0]
+        actual = normalize_gm(graph.print_readable(False))
+        self.assertExpectedInline(
+            actual,
+            """\
+class GraphModule(torch.nn.Module):
+    def forward(self, L_x_: "f32[2, 3]"):
+        l_x_ = L_x_
+
+        lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None
+
+        _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(2, 'error'); _vmap_increment_nesting = None
+
+        a: "f32[3]" = torch._C._functorch._add_batch_dim(l_x_, 0, 1); l_x_ = None
+
+        _are_functorch_transforms_active = torch._C._are_functorch_transforms_active(); _are_functorch_transforms_active = None
+
+        _are_functorch_transforms_active_1 = torch._C._are_functorch_transforms_active(); _are_functorch_transforms_active_1 = None
+
+        child: "f32[3]" = torch._C._functorch.unwrap_if_dead(a); a = None
+
+        _unwrap_batched = torch._C._functorch._unwrap_batched(child, 1); child = None
+        getitem: "f32[2, 3]" = _unwrap_batched[0]; _unwrap_batched = None
+
+        pop_dynamic_layer_stack = torch._C._functorch.pop_dynamic_layer_stack()
+
+        _are_functorch_transforms_active_2 = torch._C._are_functorch_transforms_active(); _are_functorch_transforms_active_2 = None
+
+        function_ctx = torch.autograd.function.FunctionCtx(); function_ctx = None
+        fwd_body_0 = self.fwd_body_0
+        bwd_body_0 = self.bwd_body_0
+        autograd_function_apply = torch.ops.higher_order.autograd_function_apply(fwd_body_0, bwd_body_0, getitem, args_tensor_mask = [True], non_differentiable_idx = []); fwd_body_0 = bwd_body_0 = getitem = None
+        outputs: "f32[2, 3]" = autograd_function_apply[0]; autograd_function_apply = None
+
+        push_dynamic_layer_stack = torch._C._functorch.push_dynamic_layer_stack(pop_dynamic_layer_stack); pop_dynamic_layer_stack = push_dynamic_layer_stack = None
+
+        result: "f32[3]" = torch._C._functorch._add_batch_dim(outputs, 0, 1); outputs = None
+
+        _remove_batch_dim: "f32[2, 3]" = torch._C._functorch._remove_batch_dim(result, 1, 2, 0); result = None
+
+        _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None
+        return (_remove_batch_dim,)
+
+    class fwd_body_0(torch.nn.Module):
+        def forward(self, function_ctx : torch.autograd.function.Function, getitem: "f32[2, 3]"):
+            _set_grad_enabled = torch._C._set_grad_enabled(False); _set_grad_enabled = None
+            _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(2, 'error'); _vmap_increment_nesting = None
+
+            _add_batch_dim: "f32[3]" = torch._C._functorch._add_batch_dim(getitem, 0, 1)
+
+            batched_outputs: "f32[3]" = _add_batch_dim * 2; _add_batch_dim = None
+
+            _unwrap_batched = torch._C._functorch._unwrap_batched(batched_outputs, 1); batched_outputs = None
+            outputs: "f32[2, 3]" = _unwrap_batched[0]
+            getitem_2 = _unwrap_batched[1]; _unwrap_batched = getitem_2 = None
+
+            _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None
+            _vmap_increment_nesting_1 = torch._C._functorch._vmap_increment_nesting(2, 'error'); _vmap_increment_nesting_1 = None
+
+            inp: "f32[3]" = torch._C._functorch._add_batch_dim(getitem, 0, 1); getitem = inp = None
+            _add_batch_dim_2: "f32[3]" = torch._C._functorch._add_batch_dim(outputs, 0, 1); _add_batch_dim_2 = None
+
+            _vmap_decrement_nesting_1 = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting_1 = None
+            _set_grad_enabled_1 = torch._C._set_grad_enabled(True); _set_grad_enabled_1 = None
+            return ((outputs, 0), [])
+
+    class bwd_body_0(torch.nn.Module):
+        def forward(self, function_ctx : torch.autograd.function.Function, outputs: "f32[2, 3]", const_unused : int):
+            _set_grad_enabled = torch._C._set_grad_enabled(False); _set_grad_enabled = None
+            _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(2, 'error'); _vmap_increment_nesting = None
+
+            _add_batch_dim: "f32[3]" = torch._C._functorch._add_batch_dim(outputs, 0, 1); outputs = None
+
+            batched_outputs: "f32[3]" = _add_batch_dim * 2; _add_batch_dim = None
+
+            _unwrap_batched = torch._C._functorch._unwrap_batched(batched_outputs, 1); batched_outputs = None
+            grad_ins: "f32[2, 3]" = _unwrap_batched[0]
+            getitem_1 = _unwrap_batched[1]; _unwrap_batched = getitem_1 = None
+
+            _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None
+
+            lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None
+
+            _vmap_increment_nesting_1 = torch._C._functorch._vmap_increment_nesting(2, 'error'); _vmap_increment_nesting_1 = None
+
+            _add_batch_dim_1: "f32[3]" = torch._C._functorch._add_batch_dim(grad_ins, 0, 1); grad_ins = None
+
+            batched_outputs_1: "f32[3]" = _add_batch_dim_1.sum_to_size((3,)); _add_batch_dim_1 = None
+
+            _remove_batch_dim: "f32[2, 3]" = torch._C._functorch._remove_batch_dim(batched_outputs_1, 1, 2, 0); batched_outputs_1 = None
+
+            _vmap_decrement_nesting_1 = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting_1 = None
+            _set_grad_enabled_1 = torch._C._set_grad_enabled(True); _set_grad_enabled_1 = None
+            return (_remove_batch_dim,)
+""", # NOQA: B950
+        )
+
+    def test_vmapped_autograd_function_fwd_and_bwd(self):
+        cnt = CompileCounterWithBackend("aot_eager")
+
+        class LinearFunction(torch.autograd.Function):
+            generate_vmap_rule = True
+
+            @staticmethod
+            def forward(input, weight, bias):
+                output = input.mm(weight.t())
+                if bias is not None:
+                    output += bias.unsqueeze(0).expand_as(output)
+                return output
+
+            @staticmethod
+            def setup_context(ctx, inputs, output):
+                input, weight, bias = inputs
+                ctx.save_for_backward(input, weight, bias)
+
+            @staticmethod
+            def backward(ctx, grad_output):
+                input, weight, bias = ctx.saved_tensors
+                grad_input = grad_weight = grad_bias = None
+                if ctx.needs_input_grad[0]:
+                    grad_input = grad_output.mm(weight)
+                if ctx.needs_input_grad[1]:
+                    grad_weight = grad_output.t().mm(input)
+                if bias is not None and ctx.needs_input_grad[2]:
+                    grad_bias = grad_output.sum(0)
+
+                return grad_input, grad_weight, grad_bias
+
+        def fn(input, weight, bias=None):
+            return torch.vmap(LinearFunction.apply)(input, weight, bias)
+
+        input1 = torch.randn(4, 2, 2, dtype=torch.double, requires_grad=True)
+        input2 = input1.clone().detach().requires_grad_(True)
+        weight1 = torch.randn(4, 3, 2, dtype=torch.double, requires_grad=True)
+        weight2 = weight1.clone().detach().requires_grad_(True)
+        bias1 = torch.randn(4, 3, dtype=torch.double, requires_grad=True)
+        bias2 = bias1.clone().detach().requires_grad_(True)
+
+        compiled_fn = torch.compile(backend=cnt, fullgraph=True)(fn)
+
+        output1 = fn(input1, weight1, bias1)
+        output1.sum().backward()
+
+        output2 = compiled_fn(input2, weight2, bias2)
+        output2.sum().backward()
+
+        self.assertEqual(output1, output2)
+        self.assertEqual(input1.grad, input2.grad)
+        self.assertEqual(weight1.grad, weight2.grad)
+        self.assertEqual(bias1.grad, bias2.grad)
+        self.assertEqual(cnt.frame_count, 1)
+        self.assertEqual(cnt.op_count, 25)
+
 
 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests

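The second test above checks numerics as well as capture: the compiled, vmapped function must match eager on outputs and gradients, with exactly one compiled frame. A compact standalone sketch of the same kind of check (the Scale class is illustrative; CompileCounterWithBackend and the aot_eager backend are the test utilities imported in the diff):

import torch
from torch._dynamo.testing import CompileCounterWithBackend

cnt = CompileCounterWithBackend("aot_eager")

class Scale(torch.autograd.Function):
    generate_vmap_rule = True

    @staticmethod
    def forward(x):
        return x * 3

    @staticmethod
    def setup_context(ctx, inputs, output):
        pass

    @staticmethod
    def backward(ctx, grad):
        return grad * 3

def fn(x):
    return torch.vmap(Scale.apply)(x)

x1 = torch.randn(4, 3, dtype=torch.double, requires_grad=True)
x2 = x1.clone().detach().requires_grad_(True)

fn(x1).sum().backward()
torch.compile(backend=cnt, fullgraph=True)(fn)(x2).sum().backward()

assert torch.allclose(x1.grad, x2.grad)
assert cnt.frame_count == 1  # a single graph, no graph breaks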
torch/_dynamo/variables/builder.py

Lines changed: 3 additions & 1 deletion
@@ -787,7 +787,9 @@ def build_key_value(i, k, v):
             and value == getattr(value.__self__, "apply", None)
         ):
             # handle aliased autograd function `apply` calls
-            self.install_guards(GuardBuilder.FUNCTION_MATCH)
+            self.install_guards(GuardBuilder.TYPE_MATCH)
+            func_source = AttrSource(self.source, "__func__")
+            install_guard(func_source.make_guard(GuardBuilder.ID_MATCH))
             return GetAttrVariable(
                 AutogradFunctionVariable(
                     value.__self__, source=AttrSource(self.source, member="__self__")

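The builder change above concerns "aliased apply" calls: references to an autograd.Function's bound apply that reach Dynamo through a variable rather than as a direct Cls.apply attribute access. The FUNCTION_MATCH guard (an identity check on the bound-method object) is replaced by a TYPE_MATCH plus an ID_MATCH on its __func__, presumably so the guard stays stable when the bound-method object itself is recreated. A hypothetical example of the aliasing pattern (the Offset class is illustrative):

import torch

class Offset(torch.autograd.Function):
    generate_vmap_rule = True

    @staticmethod
    def forward(x):
        return x + 1

    @staticmethod
    def setup_context(ctx, inputs, output):
        pass

    @staticmethod
    def backward(ctx, grad):
        return grad

apply_alias = Offset.apply  # the aliased `apply` that VariableBuilder ends up wrapping

@torch.compile(fullgraph=True)
def fn(x):
    return apply_alias(x)

print(fn(torch.randn(3)))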
torch/_dynamo/variables/functions.py

Lines changed: 11 additions & 0 deletions
@@ -341,6 +341,17 @@ def call_function(
         ]:
             with torch._dynamo.side_effects.allow_side_effects_under_checkpoint(tx):
                 return super().call_function(tx, args, kwargs)
+        elif self.fn is torch._functorch.autograd_function.vmapify_autograd_function:
+            assert isinstance(args[0], variables.AutogradFunctionVariable)
+            new_autograd_fn = (
+                torch._functorch.autograd_function.vmapify_autograd_function(
+                    args[0].fn_cls,
+                    args[1].as_python_constant(),
+                    args[2].as_python_constant(),
+                    args[3].as_python_constant(),
+                )
+            )
+            return variables.AutogradFunctionVariable(new_autograd_fn)
         return super().call_function(tx, args, kwargs)
 

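The functions.py change intercepts torch._functorch.autograd_function.vmapify_autograd_function, the internal helper that generates a "Vmapped<Name>" autograd.Function from the original class plus in_dims, batch size, and randomness; Dynamo now re-runs it at trace time on the constant arguments it has collected and wraps the generated class in an AutogradFunctionVariable so tracing can continue into it. A rough sketch of the helper's effect (a private API; the argument values below are illustrative, and the return shape is handled defensively since it may differ across versions):

import torch
from torch._functorch.autograd_function import vmapify_autograd_function

class Double(torch.autograd.Function):
    generate_vmap_rule = True

    @staticmethod
    def forward(x):
        return x * 2

    @staticmethod
    def setup_context(ctx, inputs, output):
        pass

    @staticmethod
    def backward(ctx, grad):
        return grad * 2

# Mirrors the four positional arguments read from args[0:4] in the diff above.
result = vmapify_autograd_function(Double, (0,), 2, "error")
# Depending on the version the helper returns the generated class directly
# (as the Dynamo branch assumes) or a tuple containing it.
generated = result[0] if isinstance(result, tuple) else result
print(generated.__name__)  # expected to start with "Vmapped" (see the check in misc.py)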
torch/_dynamo/variables/misc.py

Lines changed: 71 additions & 1 deletion
@@ -624,6 +624,59 @@ def __init__(self, fn_cls, **kwargs) -> None:
         super().__init__(**kwargs)
         self.fn_cls = fn_cls
 
+    def as_proxy(self):
+        return self.fn_cls
+
+    def python_type(self):
+        return type(self.fn_cls)
+
+    def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
+        from torch._functorch.autograd_function import (
+            autograd_function_forward_rewritten,
+        )
+
+        from .builder import SourcelessBuilder, VariableBuilder
+        from .higher_order_ops import AutogradFunctionApplyVariable
+
+        # Special handling for the vmapped autograd function because:
+        # 1. We cannot guard against the vmapped autograd function, as it is generated on the fly.
+        # 2. Skipping this guard is acceptable since we already guard on `id(Generated)`.
+        # 3. `AutogradFunctionApplyVariable` requires `parent_source` to be non-None,
+        #    though this constraint could be relaxed in the future.
+        if (
+            name == "apply"
+            and self.fn_cls.__name__.startswith("Vmapped")
+            and not torch._C._are_functorch_transforms_active()
+        ):
+            forward_fn = autograd_function_forward_rewritten(
+                self.fn_cls.forward, self.fn_cls.setup_context
+            )
+
+            source = self.source
+            if source is None:
+                source = AttrSource(
+                    tx.import_source(self.fn_cls.__module__), self.fn_cls.__name__
+                )
+
+            val = AutogradFunctionApplyVariable(
+                forward_fn,
+                self.fn_cls.backward,
+                source,
+                source=AttrSource(source, member="apply"),
+            )
+            return val
+
+        # General case.
+        try:
+            attr_value = getattr(self.fn_cls, name)
+            if self.source:
+                attr_source = AttrSource(self.source, name)
+                return VariableBuilder(tx, attr_source)(attr_value)
+            else:
+                return SourcelessBuilder.create(tx, attr_value)
+        except AttributeError:
+            unimplemented(f"getattr({self.fn_cls}, {name})")
+
     def call_apply(self, tx: "InstructionTranslator", args, kwargs):
         requires_grad = False
 
@@ -744,7 +797,17 @@ def call_method(
         from ..trace_rules import is_callable_allowed
         from .builder import wrap_fx_proxy
 
-        if name == "apply":
+        # There are two cases to handle the apply method of an autograd function:
+        # 1. If the autograd function is not vmapified:
+        #    - We can directly handle it by either treating it as allow_in_graph or
+        #      wrapping it as an AutogradFunctionApplyVariable HOP.
+        # 2. If the autograd function is vmapified, there are two types to consider within the same process:
+        #    - The vmapped autograd function (name starts with "Vmapped"):
+        #      - We treat it as allow_in_graph or wrap it as an AutogradFunctionApplyVariable HOP.
+        #    - The original autograd function (be called when functorch transforms are active):
+        #      - Since we already wrap the vmapped autograd function as an AutogradFunctionApplyVariable HOP,
+        #        and the vmapped autograd function calls the original autograd function, we simply inline them.
+        if name == "apply" and not torch._C._are_functorch_transforms_active():
             if is_callable_allowed(self.fn_cls):
                 trampoline_autograd_apply = produce_trampoline_autograd_apply(
                     self.fn_cls
@@ -763,6 +826,7 @@ def call_method(
         elif name == "backward":
             return self.call_backward(tx, args, kwargs)
         else:
+            # Simply inline these methods.
            from .. import trace_rules
 
             source = AttrSource(self.source, name) if self.source is not None else None
@@ -1000,6 +1064,12 @@ def as_python_constant(self):
         except AttributeError:
             raise NotImplementedError(f"{self} is not a constant") from None
 
+    def call_obj_hasattr(self, tx: "InstructionTranslator", name):
+        if isinstance(self.obj, AutogradFunctionVariable) and self.name == "apply":
+            return variables.ConstantVariable.create(
+                hasattr(self.obj.fn_cls.apply, name)
+            )
+
     def const_getattr(self, tx: "InstructionTranslator", name):
         if not isinstance(self.obj, variables.NNModuleVariable):
             raise NotImplementedError

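The new var_getattr branch feeds the Vmapped class's rewritten forward and its backward into AutogradFunctionApplyVariable, which is what emits the autograd_function_apply HOP seen in the expected test graph. autograd_function_forward_rewritten exists because that HOP traces one forward that also populates ctx, whereas new-style autograd.Functions split this across forward() and setup_context(). A rough standalone sketch of the helper's behavior (a private API; the Double class and input are illustrative):

import torch
from torch._functorch.autograd_function import autograd_function_forward_rewritten

class Double(torch.autograd.Function):
    generate_vmap_rule = True

    @staticmethod
    def forward(x):
        return x * 2

    @staticmethod
    def setup_context(ctx, inputs, output):
        pass

    @staticmethod
    def backward(ctx, grad):
        return grad * 2

# Roughly: new_forward(ctx, *args) runs Double.forward(*args), then
# Double.setup_context(ctx, args, output), and returns the output.
new_forward = autograd_function_forward_rewritten(Double.forward, Double.setup_context)
ctx = torch.autograd.function.FunctionCtx()
print(new_forward(ctx, torch.ones(3)))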
torch/_dynamo/variables/torch.py

Lines changed: 1 addition & 0 deletions
@@ -107,6 +107,7 @@
 constant_fold_functions = [
     torch._assert,
     torch._utils._get_device_index,
+    torch._C._functorch.current_level,
     torch._C._get_cublas_allow_tf32,
     torch._C._is_any_autocast_enabled,
     torch.cuda.get_device_properties,

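Adding torch._C._functorch.current_level to constant_fold_functions lets Dynamo evaluate this call at trace time and embed the resulting integer in the graph instead of emitting a runtime query; the level is fully determined by the transform nesting Dynamo is already tracking. A small sketch of the value in question (observed behavior of a private binding, not a documented contract):

import torch

def body(x):
    # Inside a single torch.vmap the dynamic layer stack has one entry, and
    # current_level() reports its id (a small positive integer, typically 1).
    print(torch._C._functorch.current_level())
    return x

torch.vmap(body)(torch.randn(2, 3))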
0 commit comments