Don't hardcode support for DTensor to_local/from_local/redistribute into dynamo #152829
Labels
module: dynamo
oncall: distributed
oncall: pt2
triaged
There has been a long-standing hack in dynamo around support for DTensor: a few primitive functions (`to_local`, `from_local`, and `redistribute`, as named in the title) accept opaque Python types (`DTensorSpec`/`Placement`/`DeviceMesh`) and therefore cannot go in the dynamo graph, so they have hardcoded support in dynamo.
This is bad for several reasons:
(1) It is brittle: these functions aren't supported in all cases. A recent internal example is `.to_local()` on a model causing extra graph breaks / recompiles.
(2) It is an invariant violation: dynamo shouldn't really need to know anything about DTensor.
(3) It prevents @jamesjwu's AOTDispatcher warm cache from kicking in: the hacks we use to handle these functions in dynamo are not easily pickleable by FX, and we therefore cache miss on them. This will be even more critical if we want any sort of pre-compilation to work with distributed.
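To make point (3) concrete, here is a minimal stdlib-only sketch of why a callable that bakes in opaque arguments tends to defeat pickle-based caching. It does not reproduce dynamo's actual handling; `FakePlacement`, `fake_to_local`, and `bake_in_opaque_args` are hypothetical names invented for illustration, and the assumption is only that the hack amounts to a locally defined function closing over its non-tensor arguments.

```python
import dataclasses
import pickle


@dataclasses.dataclass(frozen=True)
class FakePlacement:
    """Hypothetical stand-in for an opaque, non-tensor argument."""
    dim: int


def fake_to_local(x, placement):
    """Hypothetical primitive that takes an opaque argument."""
    return [v * (placement.dim + 1) for v in x]


def bake_in_opaque_args(fn, *opaque_args):
    """Close over the opaque arguments and return a one-argument callable."""
    def fn_with_baked_args(x):
        return fn(x, *opaque_args)
    return fn_with_baked_args


def is_picklable(obj):
    """Return True if the stdlib pickler can serialize obj."""
    try:
        pickle.dumps(obj)
        return True
    except (pickle.PicklingError, AttributeError, TypeError):
        return False


# A locally defined closure cannot be pickled by reference.
baked = bake_in_opaque_args(fake_to_local, FakePlacement(dim=0))
closure_ok = is_picklable(baked)

# The same call described as plain data (target name + flattened arguments).
explicit_call = ("fake_to_local", dataclasses.asdict(FakePlacement(dim=0)))
explicit_ok = is_picklable(explicit_call)
```

In this toy setup the closure fails to pickle while the plain-data description succeeds, which is the kind of difference that separates a cache miss from a cache hit.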
Now that we have a `flat_apply` HOP that can support non-tensor/symint primitives (thanks @StrongerXi and @zou3519), it should be possible to have dynamo support these functions more generically:
(1) These functions all desugar into a custom `autograd.Function`, which we support in dynamo.
(2) The `autograd.Function` here accepts custom Python types, which we can handle through the `flat_apply` HOP.
One difference that needs some figuring out, though, is that this `flat_apply` should "disappear" as part of AOTDispatcher tracing, since the DTensor subclass will desugar these arguments. We need to make sure this works properly.
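The general idea behind handling custom Python types this way can be sketched in plain Python: flatten structured arguments into a list of primitive leaves plus a spec, then rebuild them right before calling the real function. This is only an illustration of the flatten/unflatten pattern, not the signature or implementation of the real `flat_apply` HOP; `FakeShard`, `fake_redistribute`, and `flat_apply_sketch` are hypothetical names.

```python
from dataclasses import dataclass, fields, is_dataclass


@dataclass(frozen=True)
class FakeShard:
    """Hypothetical stand-in for a Placement-like opaque type."""
    dim: int


def flatten(obj):
    """Flatten nested tuples/lists/dataclasses into (leaves, spec)."""
    if is_dataclass(obj) and not isinstance(obj, type):
        leaves, child_specs = [], []
        for f in fields(obj):
            sub_leaves, sub_spec = flatten(getattr(obj, f.name))
            leaves.extend(sub_leaves)
            child_specs.append((f.name, sub_spec))
        return leaves, ("dataclass", type(obj), child_specs)
    if isinstance(obj, (tuple, list)):
        leaves, child_specs = [], []
        for item in obj:
            sub_leaves, sub_spec = flatten(item)
            leaves.extend(sub_leaves)
            child_specs.append(sub_spec)
        return leaves, ("seq", type(obj), child_specs)
    return [obj], ("leaf",)


def unflatten(leaves, spec):
    """Rebuild the original structure from leaves and the spec."""
    it = iter(leaves)

    def build(s):
        if s[0] == "leaf":
            return next(it)
        _, cls, child_specs = s
        if s[0] == "dataclass":
            return cls(**{name: build(sub) for name, sub in child_specs})
        return cls(build(sub) for sub in child_specs)

    return build(spec)


def flat_apply_sketch(func, spec, *flat_args):
    """Call func on arguments rebuilt from flat leaves (toy version)."""
    return func(*unflatten(flat_args, spec))


def fake_redistribute(values, placements):
    """Hypothetical primitive: offset values by the sum of shard dims."""
    offset = sum(p.dim for p in placements)
    return [v + offset for v in values]


args = ([1.0, 2.0], (FakeShard(0), FakeShard(1)))
flat_args, spec = flatten(args)
result = flat_apply_sketch(fake_redistribute, spec, *flat_args)
```

Here the call site only ever sees primitive leaves (`flat_args`), while the spec carries the structure needed to rebuild the `FakeShard` objects. In the proposal above, the analogous wrapper would additionally need to vanish during AOTDispatcher tracing, which this sketch does not model.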
cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @amjames