[Trace PyDispatcher] Capture Vmapped autograd function as graph by yanboliang · Pull Request #146288 · pytorch/pytorch


Open · wants to merge 3 commits into base: gh/yanboliang/62/base

Conversation

[ghstack-poisoned]
pytorch-bot bot commented Feb 3, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/146288

Note: Links to docs will display an error until the docs builds have been completed.

❌ 5 New Failures

As of commit 96cc5cc with merge base fa48757:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

yanboliang added a commit that referenced this pull request Feb 3, 2025
@yanboliang yanboliang added the topic: not user facing label Feb 3, 2025
[ghstack-poisoned]
yanboliang added a commit that referenced this pull request Feb 3, 2025
- self.install_guards(GuardBuilder.FUNCTION_MATCH)
+ self.install_guards(GuardBuilder.TYPE_MATCH)
+ func_source = AttrSource(self.source, "__func__")
+ install_guard(func_source.make_guard(GuardBuilder.ID_MATCH))
Contributor Author (yanboliang)

Not sure if this is a bug in guarding the apply method of an autograd function. The original FUNCTION_MATCH triggers an ID check failure (the error stack is below), but it works correctly when changed to TYPE_MATCH on apply plus ID_MATCH on apply.__func__.

Traceback (most recent call last):
  File "/data/users/ybliang/debug/debug2.py", line 34, in <module>
    print(fn(x))
  File "/home/ybliang/local/pytorch/torch/_dynamo/eval_frame.py", line 570, in _fn
    return fn(*args, **kwargs)
  File "/home/ybliang/local/pytorch/torch/_dynamo/convert_frame.py", line 1400, in __call__
    return self._torchdynamo_orig_callable(
  File "/home/ybliang/local/pytorch/torch/_dynamo/convert_frame.py", line 565, in __call__
    return _compile(
  File "/home/ybliang/local/pytorch/torch/_dynamo/convert_frame.py", line 997, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/home/ybliang/local/pytorch/torch/_utils_internal.py", line 95, in wrapper_function
    return function(*args, **kwargs)
  File "/home/ybliang/local/pytorch/torch/_dynamo/convert_frame.py", line 726, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
  File "/home/ybliang/local/pytorch/torch/_dynamo/convert_frame.py", line 862, in _compile_inner
    check_fn = CheckFunctionManager(
  File "/home/ybliang/local/pytorch/torch/_dynamo/guards.py", line 2466, in __init__
    raise AssertionError(f"Guard check failed: {reasons}")
AssertionError: Guard check failed: 0/0: ___check_obj_id(G['Foo'].apply, 139997975477056) 
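For context on why the ID-based guard on the bound apply is fragile, here is a standalone Python sketch; it illustrates general Python descriptor behavior (an assumption about the failure mode, not a claim about PyTorch internals): accessing a classmethod through the class builds a fresh bound-method object on every access, while __func__ always refers to the same underlying function object, so only __func__ is stable enough for an ID_MATCH guard.

# Minimal illustration: identity checks on a bound classmethod compare
# short-lived bound-method objects; __func__ is stable across accesses.
class Foo:
    @classmethod
    def apply(cls, x):
        return x

a = Foo.apply
b = Foo.apply
print(a is b)                    # False: each attribute access creates a new bound method
print(a.__func__ is b.__func__)  # True: the underlying function object is shared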

@yanboliang yanboliang requested a review from zou3519 February 3, 2025 06:25
[ghstack-poisoned]
yanboliang added a commit that referenced this pull request Feb 3, 2025
@@ -341,6 +341,17 @@ def call_function(
]:
with torch._dynamo.side_effects.allow_side_effects_under_checkpoint(tx):
return super().call_function(tx, args, kwargs)
elif self.fn is torch._functorch.autograd_function.vmapify_autograd_function:
Contributor

just wondering, why didn't you do something like:

UserDefinedFunction(torch._functorch.autograd_function.vmapify_autograd_function).call_function(tx, args)

The current solution in this PR looks fine to me though

Contributor Author (yanboliang)

Yeah, the default inlining approach is better, but I ran into a few unsupported Dynamo features while handling the following case:

Generated = type(
    name,
    (torch.autograd.Function,),
    {
        "forward": staticmethod(forward),
        "backward": staticmethod(backward),
        "jvp": staticmethod(jvp),
        "setup_context": staticmethod(setup_context),
        "generate_vmap_rule": True,
    },
)

This includes issues like constructing NestedUserFunctionVariable without a source, among others. Since I’d like to keep this PR focused on tracing the vmapped autograd function rather than addressing broader issues, I decided to go with this approach for now.

That said, I’m happy to revisit this later and migrate it to the inlining approach as a follow-up. I’ll add a TODO comment here to track it.

Contributor

TODO sounds fine

Comment on lines +627 to +628
def as_proxy(self):
    return self.fn_cls
Contributor

Why does this have an as_proxy? We shouldn't be putting autograd.Functions into the graph

Contributor Author (yanboliang)

Yes, this is a typo; we don't actually use it. I'll remove it.

Contributor @zou3519 left a comment

I added suggestions for more testing. Code seems reasonable to me

# though this constraint could be relaxed in the future.
if (
    name == "apply"
    and self.fn_cls.__name__.startswith("Vmapped")
Contributor

We should have some more robust way of identifying a vmapped autograd function. Since these are generated in Dynamo now, could we set a flag when constructing the AutogradFunctionVariable?

Contributor Author @yanboliang Feb 4, 2025

Good point! Setting a flag during construction is definitely a more robust approach. I’ll update it.

Comment on lines +806 to +808
# 1. If the autograd function is not vmapified:
# - We can directly handle it by either treating it as allow_in_graph or
# wrapping it as an AutogradFunctionApplyVariable HOP.
Contributor

I don't think we ever allow_in_graph an autograd.Function unless the user has explicitly used allow_in_graph?

Contributor Author @yanboliang Feb 4, 2025

Yes, here it refers to users explicitly using the allow_in_graph decorator. I'll update the comment to clarify this.
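(A hedged sketch of what "explicitly using allow_in_graph" refers to; applying torch._dynamo.allow_in_graph directly to the autograd.Function class, the class name Double, and the shapes are illustrative assumptions, not code from this PR.)

import torch

class Double(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x * 2

    @staticmethod
    def backward(ctx, grad):
        return grad * 2

# Explicit opt-in: ask Dynamo to keep Double in the graph as-is instead of
# tracing and wrapping its forward/backward (assumed usage pattern).
torch._dynamo.allow_in_graph(Double)

@torch.compile(backend="eager")
def fn(x):
    return Double.apply(x)

print(fn(torch.randn(3, requires_grad=True)))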

Comment on lines +156 to +158
@torch.compile(backend=eager, fullgraph=True)
def fn(x):
    return torch.vmap(Foo.apply)(x)
Contributor

Can you do a more complicated (non-pointwise) test case with double vmap? Maybe use double vmap on the LinearFunction?

Contributor Author (yanboliang)

Will do!
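A hedged sketch of the kind of double-vmap test being requested, loosely modeled on the LinearFunction example from the "Extending PyTorch" tutorial; the class body, shapes, and in_dims below are illustrative assumptions rather than the test added in this PR, and the fullgraph compile assumes the support this PR introduces.

import torch

class LinearFunction(torch.autograd.Function):
    generate_vmap_rule = True

    @staticmethod
    def forward(input, weight, bias):
        # per-sample semantics: input (in_features,), weight (out, in), bias (out,)
        return input @ weight.t() + bias

    @staticmethod
    def setup_context(ctx, inputs, output):
        input, weight, _ = inputs
        ctx.save_for_backward(input, weight)

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        grad_input = grad_output @ weight
        grad_weight = grad_output.unsqueeze(-1) * input.unsqueeze(-2)  # outer product
        grad_bias = grad_output
        return grad_input, grad_weight, grad_bias

@torch.compile(backend="eager", fullgraph=True)
def fn(x, w, b):
    inner = torch.vmap(LinearFunction.apply, in_dims=(0, None, None))  # map over rows of x
    return torch.vmap(inner, in_dims=(None, 0, 0))(x, w, b)            # map over a batch of weights/biases

x = torch.randn(4, 3)
w = torch.randn(2, 5, 3)
b = torch.randn(2, 5)
print(fn(x, w, b).shape)  # expected: torch.Size([2, 4, 5])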

# - The original autograd function (called when functorch transforms are active):
# - Since we already wrap the vmapped autograd function as an AutogradFunctionApplyVariable HOP,
# and the vmapped autograd function calls the original autograd function, we simply inline them.
if name == "apply" and not torch._C._are_functorch_transforms_active():
Contributor

Er, what happens if functorch transforms are active but you have a regular autograd.Function? Something like:

def f(x, y):
    z = Foo.apply(y)
    return x * z

x = torch.randn(3)
y = torch.randn([])
vmap(f, (0, None))(x, y)

Contributor Author (yanboliang)

It goes into the else branch down below, which inlines the apply method. This is because we only want call_apply on the vmapped autograd function and to capture fwd/bwd graphs there. During tracing of the vmapped autograd function, it triggers a call to the original regular autograd function's apply method, which we simply inline.

@@ -130,6 +135,184 @@ def fn(x, y):
# No recompile
self.assertEqual(counter.frame_count, 1)

def test_vmapped_autograd_function(self):
Contributor

For testing, we have a lot of autograd.Functions we can test. A good comprehensive way to do this is to copy-paste test_vmap_exhaustive (but only for the autograd_function_db) and add torch.compile(backend="eager") testing to it.

The vmap tests are able to generate inputs with various in_dims, so that better exercises the logic. They're also able to generate vmap(vmap(...)) tests.
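(A hedged sketch of what that exhaustive test could look like; the import paths, the @ops decorator usage, and especially the simplified in_dims handling are assumptions. The real test_vmap_exhaustive machinery generates many in_dims combinations and nested vmap(vmap(...)) calls rather than only mapping dim 0 of the first input.)

import torch
from torch.testing._internal.autograd_function_db import autograd_function_db
from torch.testing._internal.common_device_type import instantiate_device_type_tests, ops
from torch.testing._internal.common_utils import TestCase, run_tests

class TestCompileVmapAutogradFunction(TestCase):
    @ops(autograd_function_db, allowed_dtypes=(torch.float32,))
    def test_compile_vmap_autograd_function(self, device, dtype, op):
        for sample in op.sample_inputs(device, dtype, requires_grad=False):
            if not isinstance(sample.input, torch.Tensor) or sample.input.ndim == 0:
                continue  # simplification: only vmap over dim 0 of a tensor first input

            def fn(x):
                return op.op(x, *sample.args, **sample.kwargs)

            def vmapped(x):
                return torch.vmap(fn)(x)

            try:
                expected = vmapped(sample.input)  # eager reference
            except RuntimeError:
                continue  # skip samples that eager vmap itself does not support

            compiled = torch.compile(vmapped, backend="eager", fullgraph=True)
            self.assertEqual(compiled(sample.input), expected)

instantiate_device_type_tests(TestCompileVmapAutogradFunction, globals())

if __name__ == "__main__":
    run_tests()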

Comment on lines +644 to +645
# 3. `AutogradFunctionApplyVariable` requires `parent_source` to be non-None,
# though this constraint could be relaxed in the future.
Contributor

How difficult is it to relax this constraint? Otherwise the source we're generating here is incorrect, right?

Contributor Author (yanboliang)

I don’t think it’s very difficult, though I haven’t looked into it too deeply. I just want to keep this PR focused on its intended scope, but I’ll address it as a follow-up.

The source we generate here is correct—the issue is that we can’t generate a guard from it. The problem arises because it’s trying to generate guards on torch._functorch.autograd_function.vmapped_xxx.apply, which leads to errors during guard evaluation. This happens because the vmapped autograd function is created on the fly during compilation, and we don’t materialize it.

One possible solution could be adding a new guard specifically for objects generated dynamically.

Contributor

sure

Contributor @zou3519 left a comment

cool I think I am just looking for the double vmap tests


Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

@github-actions github-actions bot added the Stale label May 13, 2025