[Trace PyDispatcher] Capture Vmapped autograd function as graph #146288
base: gh/yanboliang/62/base
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/146288
Note: Links to docs will display an error until the docs builds have been completed.
❌ 5 New Failures as of commit 96cc5cc with merge base fa48757. NEW FAILURES - The following jobs have failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
self.install_guards(GuardBuilder.FUNCTION_MATCH)
self.install_guards(GuardBuilder.TYPE_MATCH)
func_source = AttrSource(self.source, "__func__")
install_guard(func_source.make_guard(GuardBuilder.ID_MATCH))
Not sure if this is a bug in guarding the apply method of an autograd function. The original FUNCTION_MATCH triggers an ID check failure (below is the error stack), but it works correctly when changed to TYPE_MATCH on apply and ID_MATCH on apply.__func__.
Traceback (most recent call last):
File "/data/users/ybliang/debug/debug2.py", line 34, in <module>
print(fn(x))
File "/home/ybliang/local/pytorch/torch/_dynamo/eval_frame.py", line 570, in _fn
return fn(*args, **kwargs)
File "/home/ybliang/local/pytorch/torch/_dynamo/convert_frame.py", line 1400, in __call__
return self._torchdynamo_orig_callable(
File "/home/ybliang/local/pytorch/torch/_dynamo/convert_frame.py", line 565, in __call__
return _compile(
File "/home/ybliang/local/pytorch/torch/_dynamo/convert_frame.py", line 997, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
File "/home/ybliang/local/pytorch/torch/_utils_internal.py", line 95, in wrapper_function
return function(*args, **kwargs)
File "/home/ybliang/local/pytorch/torch/_dynamo/convert_frame.py", line 726, in compile_inner
return _compile_inner(code, one_graph, hooks, transform)
File "/home/ybliang/local/pytorch/torch/_dynamo/convert_frame.py", line 862, in _compile_inner
check_fn = CheckFunctionManager(
File "/home/ybliang/local/pytorch/torch/_dynamo/guards.py", line 2466, in __init__
raise AssertionError(f"Guard check failed: {reasons}")
AssertionError: Guard check failed: 0/0: ___check_obj_id(G['Foo'].apply, 139997975477056)
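For context, here is a plain-Python sketch of what is likely happening; this is an assumption about the failure, not something verified against the guard internals. Since apply is exposed as a classmethod, every attribute access creates a fresh bound-method object, so an id() recorded at compile time no longer matches at guard-evaluation time, while apply.__func__ and type(apply) stay stable.

class Foo:
    @classmethod
    def apply(cls, x):  # stand-in for autograd.Function.apply, assumed to be a classmethod
        return x

a, b = Foo.apply, Foo.apply
print(a is b)                    # False: each access builds a new bound-method object
print(a.__func__ is b.__func__)  # True: ID_MATCH on apply.__func__ is stable
print(type(a) is type(b))        # True: TYPE_MATCH on apply is stable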
@@ -341,6 +341,17 @@ def call_function(
]:
    with torch._dynamo.side_effects.allow_side_effects_under_checkpoint(tx):
        return super().call_function(tx, args, kwargs)
elif self.fn is torch._functorch.autograd_function.vmapify_autograd_function:
just wondering, why didn't you do something like:
UserDefinedFunction(torch._functorch.autograd_function.vmapify_autograd_function).call_function(tx, args)
The current solution in this PR looks fine to me though
Yeah, the default inlining approach is better, but I ran into a few unsupported Dynamo features while handling the following case:
pytorch/torch/_functorch/autograd_function.py
Lines 502 to 511 in e68f508
Generated = type(
    name,
    (torch.autograd.Function,),
    {
        "forward": staticmethod(forward),
        "backward": staticmethod(backward),
        "jvp": staticmethod(jvp),
        "setup_context": staticmethod(setup_context),
        "generate_vmap_rule": True,
    },
This includes issues like constructing NestedUserFunctionVariable without a source, among others. Since I’d like to keep this PR focused on tracing the vmapped autograd function rather than addressing broader issues, I decided to go with this approach for now.
That said, I’m happy to revisit this later and migrate it to the inlining approach as a follow-up. I’ll add a TODO comment here to track it.
TODO sounds fine
def as_proxy(self):
    return self.fn_cls
Why does this have an as_proxy? We shouldn't be putting autograd.Functions into the graph
Yes, this is a typo; we don't actually use it, I will remove it.
I added suggestions for more testing. Code seems reasonable to me
# though this constraint could be relaxed in the future.
if (
    name == "apply"
    and self.fn_cls.__name__.startswith("Vmapped")
We should have some more robust way of identifying a vmapped autograd function. Since these are generated in Dynamo now, could we set a flag when constructing the AutogradFunctionVariable?
Good point! Setting a flag during construction is definitely a more robust approach. I’ll update it.
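A tiny, framework-free sketch of the flag-based check being discussed; AutogradFunctionVariable and is_vmapped here only mirror the conversation and are not the real Dynamo classes or fields.

from dataclasses import dataclass

@dataclass
class AutogradFunctionVariable:
    fn_cls: type
    is_vmapped: bool = False  # set by whichever code path vmapifies the function

def should_capture_as_hop(var: AutogradFunctionVariable, name: str) -> bool:
    # Robust: rely on the construction-time flag ...
    return name == "apply" and var.is_vmapped
    # ... instead of sniffing the generated class name:
    #   name == "apply" and var.fn_cls.__name__.startswith("Vmapped")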
# 1. If the autograd function is not vmapified:
#    - We can directly handle it by either treating it as allow_in_graph or
#      wrapping it as an AutogradFunctionApplyVariable HOP.
I don't think we ever allow_in_graph an autograd.Function unless the user has explicitly used allow_in_graph?
Yes, here it refers to users explicitly using the allow_in_graph decorator. I'll update the comment to clarify this.
@torch.compile(backend=eager, fullgraph=True)
def fn(x):
    return torch.vmap(Foo.apply)(x)
Can you do a more complicated (non-pointwise) test case with double vmap? Maybe use double vmap on the LinearFunction?
Will do!
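For reference, here is one way such a test might look. This is only a sketch of the reviewer's request (the non-pointwise LinearFunction example from the autograd docs, driven through double vmap and compiled with the eager backend); the shapes and in_dims are illustrative, not the test that was actually added.

import torch

class LinearFunction(torch.autograd.Function):
    generate_vmap_rule = True

    @staticmethod
    def forward(input, weight, bias):
        output = input.mm(weight.t())
        if bias is not None:
            output = output + bias.unsqueeze(0).expand_as(output)
        return output

    @staticmethod
    def setup_context(ctx, inputs, output):
        input, weight, bias = inputs
        ctx.save_for_backward(input, weight, bias)

    @staticmethod
    def backward(ctx, grad_output):
        input, weight, bias = ctx.saved_tensors
        grad_input = grad_output.mm(weight)
        grad_weight = grad_output.t().mm(input)
        grad_bias = grad_output.sum(0) if bias is not None else None
        return grad_input, grad_weight, grad_bias

def fn(input, weight, bias):
    # vmap over two leading batch dims of `input`; weight and bias are broadcast.
    inner = torch.vmap(LinearFunction.apply, in_dims=(0, None, None))
    return torch.vmap(inner, in_dims=(0, None, None))(input, weight, bias)

input = torch.randn(4, 3, 5, 6)   # (B1, B2, N, D)
weight = torch.randn(7, 6)        # (M, D)
bias = torch.randn(7)             # (M,)

expected = fn(input, weight, bias)
actual = torch.compile(fn, backend="eager", fullgraph=True)(input, weight, bias)
torch.testing.assert_close(actual, expected)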
# - The original autograd function (called when functorch transforms are active):
#   - Since we already wrap the vmapped autograd function as an AutogradFunctionApplyVariable HOP,
#     and the vmapped autograd function calls the original autograd function, we simply inline them.
if name == "apply" and not torch._C._are_functorch_transforms_active():
Er, what happens if functorch transforms are active but you have a regular autograd.Function? Something like:

def f(x, y):
    z = Foo.apply(y)
    return x * z

x = torch.randn(3)
y = torch.randn([])
vmap(f, (0, None))(x, y)
It goes into the else branch down below, which inlines the apply method. This is because we only want to call_apply on the vmapped autograd function and capture the fwd/bwd graphs there. While tracing the vmapped autograd function, it triggers a call to the original regular autograd function's apply method, which we should just inline.
@@ -130,6 +135,184 @@ def fn(x, y):
        # No recompile
        self.assertEqual(counter.frame_count, 1)

    def test_vmapped_autograd_function(self):
For testing, we have a lot of autograd.Functions we can test. A good, comprehensive way to do this is to copy-paste the following and add torch.compile(backend="eager") testing to it:
pytorch/test/functorch/test_vmap.py
Line 4341 in 01554c7
"test_vmap_exhaustive",
The vmap tests are able to generate inputs with various in_dims, so that better exercises the logic. They're also able to generate vmap(vmap(...)) tests.
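A rough sketch, under stated assumptions, of what "add torch.compile(backend='eager') testing" could look like: run the same vmapped callable eagerly and under compile, then compare. The helper name and call pattern are illustrative; the real test_vmap.py plumbing is more involved.

import torch

def check_compiled_vmap(fn, in_dims, *args):
    vmapped = torch.vmap(fn, in_dims=in_dims)
    expected = vmapped(*args)
    actual = torch.compile(vmapped, backend="eager", fullgraph=True)(*args)
    torch.testing.assert_close(actual, expected)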
# 3. `AutogradFunctionApplyVariable` requires `parent_source` to be non-None,
#    though this constraint could be relaxed in the future.
How difficult is it to relax this constraint? Otherwise the source we're generating here is incorrect, right?
I don’t think it’s very difficult, though I haven’t looked into it too deeply. I just want to keep this PR focused on its intended scope, but I’ll address it as a follow-up.
The source we generate here is correct; the issue is that we can’t generate a guard from it. The problem arises because it tries to generate guards on torch._functorch.autograd_function.vmapped_xxx.apply, which leads to errors during guard evaluation. This happens because the vmapped autograd function is created on the fly during compilation, and we don’t materialize it.
One possible solution could be adding a new guard specifically for objects generated dynamically.
sure
cool I think I am just looking for the double vmap tests
Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as
Stack from ghstack (oldest at bottom):
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames