Dynamo x autograd.Function supports setup_context (#124802) · pytorch/pytorch@ce503c1 · GitHub

Commit ce503c1

yanboliang authored and pytorchmergebot committed
Dynamo x autograd.Function supports setup_context (#124802)
Fixes part of #118397

Pull Request resolved: #124802
Approved by: https://github.com/zou3519
1 parent a866bff commit ce503c1
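For context, here is the user-facing pattern this commit enables: an autograd.Function written in the setup_context style (forward takes no ctx; a separate setup_context saves state) can now be traced by Dynamo without graph breaks. A minimal sketch, using an illustrative MyLinear rather than any code from this PR:

import torch

class MyLinear(torch.autograd.Function):
    @staticmethod
    def forward(input, weight):
        # setup_context-style forward: no ctx argument here.
        return input @ weight.t()

    @staticmethod
    def setup_context(ctx, inputs, output):
        input, weight = inputs
        ctx.save_for_backward(input, weight)

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        return grad_output @ weight, grad_output.t() @ input

# fullgraph=True mirrors the nopython=True used in the updated test below.
@torch.compile(backend="eager", fullgraph=True)
def fn(x, w):
    return MyLinear.apply(x, w)

x = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
w = torch.randn(3, 2, dtype=torch.double, requires_grad=True)
fn(x, w).sum().backward()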

File tree

7 files changed: +79, -20 lines


test/dynamo/test_autograd_function.py

Lines changed: 2 additions & 2 deletions
@@ -253,11 +253,11 @@ def test_autograd_function_has_graph_break(self):
 
     def test_linear_setup_context(self):
         model = ModuleLinear()
-        opt_model = torch._dynamo.optimize("eager")(model)
+        opt_model = torch._dynamo.optimize("eager", nopython=True)(model)
         input = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
         weight = torch.randn(3, 2, dtype=torch.double, requires_grad=True)
-        optim_result = opt_model(input, weight)
         eager_result = model(input, weight)
+        optim_result = opt_model(input, weight)
         self.assertEqual(optim_result, eager_result)
 
     def test_materialize_grad(self):
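The switch to nopython=True is what makes this a real end-to-end check: any graph break now raises instead of silently falling back to eager, so the test only passes if the setup_context path is actually traced. For reference (not part of the diff), the updated line is roughly equivalent to the torch.compile spelling:

# Roughly equivalent to torch._dynamo.optimize("eager", nopython=True)(model)
opt_model = torch.compile(model, backend="eager", fullgraph=True)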

torch/_dynamo/trace_rules.py

Lines changed: 1 addition & 0 deletions
@@ -3205,6 +3205,7 @@ def _module_dir(m: types.ModuleType):
     "torch._dynamo.comptime",
     "torch._dynamo.polyfill",
     "torch._functorch.vmap",
+    "torch._functorch.autograd_function",
     "torch._library.custom_ops",
     "torch._functorch.eager_transforms",
     "torch._inductor.test_operators",

torch/_dynamo/variables/higher_order_ops.py

Lines changed: 1 addition & 2 deletions
@@ -1628,13 +1628,12 @@ def bwd(ctx, grad, x):
         fwd_src = AttrSource(self.parent_source, member="forward")
         ctx = AutogradFunctionContextVariable.create(tx, args, kwargs)
         if isinstance(self.fwd_graph, types.FunctionType):
-            fwd_fn = UserFunctionVariable(self.fwd_graph, source=fwd_src)
+            fwd_fn = UserFunctionVariable(self.fwd_graph)
             fwd_args = [ctx, *args]
         elif isinstance(self.fwd_graph, types.MethodType):
             fwd_fn = UserMethodVariable(
                 self.fwd_graph.__func__,
                 UserDefinedClassVariable(self.fwd_graph.__class__),
-                source=fwd_src,
             )
             fwd_args = [fwd_fn.obj, ctx, *args]
         else:

torch/_dynamo/variables/misc.py

Lines changed: 54 additions & 11 deletions
@@ -357,14 +357,19 @@ def visit(node):
             and torch.is_grad_enabled()
             and config.capture_autograd_function
         ):
-            # Note - this is the same check used in autograd/function.py, except inverted.
-            # If we want to support functorch transforms here, we will need to enable this.
-            if (
-                self.fn_cls.setup_context
-                != torch.autograd.function._SingleLevelFunction.setup_context
-            ):
-                unimplemented(
-                    "NYI - autograd.Function with custom setup_context method"
+            from torch._functorch.autograd_function import (
+                autograd_function_forward_rewritten,
+            )
+            from torch.autograd.function import _is_setup_context_defined
+
+            forward_fn = self.fn_cls.forward
+
+            is_setup_ctx_defined = _is_setup_context_defined(self.fn_cls.setup_context)
+            if is_setup_ctx_defined:
+                # If setup_context is defined, we generate a new forward function which includes
+                # the original forward and setup_context function, and trace the new forward function.
+                forward_fn = autograd_function_forward_rewritten(
+                    self.fn_cls.forward, self.fn_cls.setup_context
                 )
 
             vjp_fn = self.fn_cls.vjp  # type: ignore[attr-defined]
@@ -383,12 +388,25 @@ def visit(node):
                 tx.import_source(self.fn_cls.__module__), self.fn_cls.__name__
             )
 
-            return AutogradFunctionApplyVariable(
-                self.fn_cls.forward,
+            val = AutogradFunctionApplyVariable(
+                forward_fn,
                 self.fn_cls.backward,
                 source,
                 source=AttrSource(source, member="apply"),
             ).call_function(tx, args, kwargs)
+            # Inside of AutogradFunctionApplyVariable.call_function, we use sourceless variable wrapping
+            # the forward function, as we don't want to generate guards for new_forward.__closure__
+            # if forward is rewritten by autograd_function_forward_rewritten.
+            # But we still need to generate correct guards for the original forward and setup_context
+            # functions, so we have to add guards manually.
+            if self.source:
+                fwd_src = AttrSource(self.source, "forward")
+                install_guard(fwd_src.make_guard(GuardBuilder.FUNCTION_MATCH))
+                if is_setup_ctx_defined:
+                    setup_ctx_src = AttrSource(self.source, "setup_context")
+                    install_guard(setup_ctx_src.make_guard(GuardBuilder.FUNCTION_MATCH))
+
+            return val
 
         if self.source:
             source = AttrSource(self.source, "forward")
@@ -443,7 +461,32 @@ def call_method(
             return self.call_apply(tx, args, kwargs)
 
         else:
-            unimplemented(f"Unsupported method: {name}")
+            from .. import trace_rules
+
+            source = AttrSource(self.source, name) if self.source is not None else None
+            try:
+                obj = inspect.getattr_static(self.fn_cls, name)
+            except AttributeError:
+                obj = None
+
+            if isinstance(obj, staticmethod):
+                func = obj.__get__(self.fn_cls)
+                if source is not None:
+                    return (
+                        trace_rules.lookup(func)
+                        .create_with_source(func, source=source)
+                        .call_function(tx, args, kwargs)
+                    )
+                else:
+                    return trace_rules.lookup(func)(func).call_function(
+                        tx, args, kwargs
+                    )
+            elif isinstance(obj, classmethod):
+                return variables.UserMethodVariable(
+                    obj.__func__, self, source=source
+                ).call_function(tx, args, kwargs)
+            else:
+                unimplemented(f"Unsupported method: {name}")
 
 
 @dataclasses.dataclass
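Besides the setup_context rewrite and the manual guards explained in the inline comment above, the new else branch in call_method lets compiled code call extra static or class methods defined on an autograd.Function subclass instead of immediately hitting unimplemented. A hypothetical sketch of the kind of pattern this targets (MyFn and its _scale helper are illustrative, not from the PR, and whether a particular call takes this branch depends on how Dynamo resolves the attribute):

import torch

class MyFn(torch.autograd.Function):
    @staticmethod
    def forward(x):
        return x.sin()

    @staticmethod
    def setup_context(ctx, inputs, output):
        (x,) = inputs
        ctx.save_for_backward(x)

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return grad_output * x.cos()

    @staticmethod
    def _scale(x):
        # Extra helper on the Function class, not part of the autograd protocol.
        return x * 2

@torch.compile(backend="eager")
def f(x):
    # Calling MyFn._scale inside the compiled region is the kind of call the new branch handles.
    return MyFn.apply(MyFn._scale(x))

f(torch.randn(4, requires_grad=True))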

torch/_dynamo/variables/user_defined.py

Lines changed: 7 additions & 4 deletions
@@ -2,6 +2,7 @@
 
 import collections
 import contextlib
+import enum
 import functools
 import importlib
 import inspect
@@ -107,7 +108,7 @@ def can_constant_fold_through(self):
 
     def var_getattr(self, tx, name: str) -> "VariableTracker":
         from .. import trace_rules
-        from . import ConstantVariable
+        from . import ConstantVariable, EnumVariable
         from .builder import VariableBuilder
 
         if name == "__name__":
@@ -144,14 +145,16 @@ def var_getattr(self, tx, name: str) -> "VariableTracker":
         if self.value is collections.OrderedDict and name == "fromkeys":
             return super().var_getattr(tx, name)
 
-        if name in getattr(self.value, "__dict__", {}) or (
+        if ConstantVariable.is_literal(obj):
+            return ConstantVariable.create(obj)
+        elif isinstance(obj, enum.Enum):
+            return EnumVariable(obj)
+        elif name in getattr(self.value, "__dict__", {}) or (
             self.value.__module__.startswith("torch.")
             or self.value.__module__ == "torch"
         ):
             if source:
                 return VariableBuilder(tx, source)(obj)
-            elif ConstantVariable.is_literal(obj):
-                return ConstantVariable.create(obj)
 
         return super().var_getattr(tx, name)
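The user_defined.py change affects reads of class attributes during tracing: literal values are turned into constants and enum members into EnumVariable before the existing torch-module branch is consulted. A small illustrative sketch (Color and Config are made-up names, not from the PR):

import enum

import torch

class Color(enum.Enum):
    RED = 1
    GREEN = 2

class Config:
    mode = Color.RED   # enum member stored on the class
    scale = 3          # plain literal stored on the class

@torch.compile(backend="eager")
def f(x):
    if Config.mode is Color.RED:
        return x * Config.scale
    return x

f(torch.randn(4))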

torch/_functorch/autograd_function.py

Lines changed: 9 additions & 0 deletions
@@ -682,6 +682,15 @@ def reductify_leaf(
     return grad_input
 
 
+def autograd_function_forward_rewritten(original_forward, original_setup_context):
+    def new_forward(ctx, *args, **kwargs):
+        output = original_forward(*args, **kwargs)
+        original_setup_context(ctx, args, output)
+        return output
+
+    return new_forward
+
+
 class AutogradFunctionApply(HigherOrderOperator):
     def __init__(self):
         super().__init__("autograd_function_apply")
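The rewritten forward simply runs the original forward and then setup_context against the same ctx, so Dynamo only has to trace a single function. A rough behavioral sketch, reusing the hypothetical MyLinear, x, and w from the example near the top of this page (the Ctx class below is a stand-in for illustration, not a real autograd context):

from torch._functorch.autograd_function import autograd_function_forward_rewritten

class Ctx:
    # Minimal stand-in that only supports save_for_backward.
    def save_for_backward(self, *tensors):
        self.saved = tensors

new_fwd = autograd_function_forward_rewritten(MyLinear.forward, MyLinear.setup_context)
ctx = Ctx()
out = new_fwd(ctx, x, w)
# Equivalent to: out = MyLinear.forward(x, w); MyLinear.setup_context(ctx, (x, w), out)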

torch/autograd/function.py

Lines changed: 5 additions & 1 deletion
@@ -561,7 +561,7 @@ def bind_default_args(func, *args, **kwargs):
 
             return bound_args.args
 
-        is_setup_ctx_defined = cls.setup_context != _SingleLevelFunction.setup_context
+        is_setup_ctx_defined = _is_setup_context_defined(cls.setup_context)
         if is_setup_ctx_defined:
             args = bind_default_args(cls.forward, *args, **kwargs)
 
@@ -585,6 +585,10 @@ def _compiled_autograd_key(ctx):
     return (ctx._autograd_function_id,)
 
 
+def _is_setup_context_defined(fn):
+    return fn != _SingleLevelFunction.setup_context
+
+
 def once_differentiable(fn):
     @functools.wraps(fn)
     def wrapper(ctx, *args):
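The new module-level helper just gives a name to the check apply was already doing: setup_context counts as defined when the subclass overrides the default inherited from _SingleLevelFunction. A small sketch of what it reports (OldStyle and NewStyle are illustrative classes, not from the PR):

import torch
from torch.autograd.function import _is_setup_context_defined

class OldStyle(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x * 2

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output * 2

class NewStyle(torch.autograd.Function):
    @staticmethod
    def forward(x):
        return x * 2

    @staticmethod
    def setup_context(ctx, inputs, output):
        pass

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output * 2

print(_is_setup_context_defined(OldStyle.setup_context))  # False: inherited default
print(_is_setup_context_defined(NewStyle.setup_context))  # True: overridden here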
