Don't hardcode support for DTensor to_local/from_local/redistribute into dynamo · Issue #152829 · pytorch/pytorch

Don't hardcode support for DTensor to_local/from_local/redistribute into dynamo #152829


Open
bdhirsh opened this issue May 5, 2025 · 1 comment
Labels
module: dynamo oncall: distributed Add this issue/PR to distributed oncall triage queue oncall: pt2 triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@bdhirsh
Contributor
bdhirsh commented May 5, 2025

There is a long-standing hack in dynamo around its support for DTensor: a few primitive functions (the to_local/from_local/redistribute in the title) accept opaque python types (DTensorSpec/Placement/DeviceMesh) and therefore cannot go in the dynamo graph, so dynamo has hardcoded support for them.
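For context, here is a minimal sketch of the three functions and the opaque python objects they take (this assumes an already-initialized default process group, e.g. under torchrun, and CUDA devices):

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Replicate, Shard

# Assumes dist.init_process_group has already run (e.g. via torchrun).
mesh = init_device_mesh("cuda", (dist.get_world_size(),))  # opaque DeviceMesh

local = torch.randn(4, 8, device="cuda")

# from_local takes the opaque DeviceMesh and Placement python objects.
dt = DTensor.from_local(local, mesh, [Shard(0)])

# redistribute is likewise parameterized by Placement objects.
replicated = dt.redistribute(mesh, [Replicate()])

# to_local returns the plain torch.Tensor shard for this rank.
out = replicated.to_local()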

This is bad for several reasons:

(1) it is brittle: these functions aren't supported in all cases (a recent internal example: calling .to_local() in a model caused extra graph breaks / recompiles)

(2) it is an invariant violation (dynamo shouldn't really need to know anything about DTensor)

(3) it prevents @jamesjwu 's AOTDispatcher warm cache from kicking in (the hacks we use to handle these functions in dynamo are not easily pickleable by FX, so we cache miss on them). This will become even more critical if we want any sort of pre-compilation to work with distributed.

Now that we have a flat_apply HOP that can support non-tensor/symint primitives (thanks @StrongerXi and @zou3519), it should be possible to have dynamo support these functions more generically:

(1) these functions all desugar into a custom autograd.Function, which we support in dynamo

(2) the autograd.Function here accepts custom python types, which we can handle through the flat_apply HOP (see the sketch below).
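To illustrate the shape of the problem, here is a toy autograd.Function (hypothetical, not the one DTensor actually uses) that takes a tensor plus an opaque python object, which is exactly the kind of argument flat_apply would need to flatten into proxy-able leaves:

from dataclasses import dataclass

import torch

@dataclass(frozen=True)
class Spec:  # hypothetical stand-in for DTensorSpec/Placement
    dim: int

class ScaleAlongDim(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor, spec: Spec) -> torch.Tensor:
        ctx.spec = spec  # opaque, non-tensor state stashed for backward
        return x * x.shape[spec.dim]

    @staticmethod
    def backward(ctx, grad_out):
        # Gradient only w.r.t. the tensor input; None for the opaque spec.
        return grad_out * grad_out.shape[ctx.spec.dim], None

x = torch.randn(3, 4, requires_grad=True)
y = ScaleAlongDim.apply(x, Spec(dim=0))  # the Spec arg is what dynamo cannot proxy today
y.sum().backward()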

One difference that needs some figuring out, though, is that this flat_apply call should "disappear" as part of AOTDispatcher tracing, since the DTensor subclass will desugar these arguments itself. We need to make sure this works properly.

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @amjames

@bdhirsh bdhirsh added oncall: distributed Add this issue/PR to distributed oncall triage queue oncall: pt2 labels May 5, 2025
@zou3519 zou3519 added triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module module: dynamo labels May 5, 2025
@StrongerXi
Contributor

I discussed this with @bdhirsh offline. I think the thing currently blocking Dynamo from tracing into from_local and to_local is that they eventually hit an autograd.Function.apply, and Dynamo requires its arguments to be proxy-able:

p_args = (
    fwd_node,
    bwd_node,
    *([arg.as_proxy() for arg in filtered_args] + list(fwd_freevars.keys())),
)

nonstrict_trace's impl added some logic that makes pytree-registered classes "proxy-able"; we can probably reuse that for autograd.Function.apply tracing:

tensor_variable = wrap_fx_proxy(
    tx=tx,
    proxy=tx.output.create_proxy(
        "call_function",
        self.value,
        *proxy_args_kwargs(args, kwargs),
    ),
)
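The key idea behind that logic (a sketch of the mechanism, not the actual nonstrict_trace implementation) is pytree registration: once a class can be flattened into proxy-able leaves plus static context, it can be reconstructed on the other side of the graph. A hypothetical MySpec class, for example:

import torch.utils._pytree as pytree

class MySpec:  # hypothetical opaque class, standing in for something like DTensorSpec
    def __init__(self, mesh_name: str, dim: int):
        self.mesh_name = mesh_name
        self.dim = dim

pytree.register_pytree_node(
    MySpec,
    # flatten: dynamic children (traceable leaves) plus static context
    lambda spec: ((spec.dim,), spec.mesh_name),
    # unflatten: rebuild the instance from children + context
    lambda children, ctx: MySpec(ctx, children[0]),
)

spec = MySpec("tp_mesh", 0)
leaves, treespec = pytree.tree_flatten(spec)
assert pytree.tree_unflatten(leaves, treespec).dim == 0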

After that, we might be able to trace through to_local entirely, but we'd still hit a graph break in from_local, at its DTensor construction, because Dynamo requires the constructor args to be proxy-able as well:

tensor_variable = wrap_fx_proxy(
    tx=tx,
    proxy=tx.output.create_proxy(
        "call_function",
        self.value,
        *proxy_args_kwargs(args, kwargs),
    ),
)

For this, we can probably reuse the same logic as above, to enable tracing through the tensor subclass constructor with more complex input types.
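For reference, here's a sketch (a hypothetical toy function, not a confirmed repro) of the compiled-region shape that exercises both paths; with fullgraph=True, the gaps described above would surface as hard errors today:

import torch
from torch.distributed.tensor import DTensor, Shard

def f(local: torch.Tensor, mesh) -> torch.Tensor:
    # from_local: DTensor construction, the second proxy-ability gap
    dt = DTensor.from_local(local, mesh, [Shard(0)])
    # to_local: hits the autograd.Function.apply path, the first gap
    return dt.to_local() + 1

compiled = torch.compile(f, fullgraph=True)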
