Description
There has been a long-standing hack in dynamo around support for DTensor: a few primitive functions (listed above) accept opaque python types (`DTensorSpec`/`Placement`/`DeviceMesh`) and therefore cannot go in the dynamo graph, so they have been given hardcoded support in dynamo.
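For concreteness, a minimal sketch of the kind of call involved (single-process gloo setup purely for illustration; import paths may differ slightly across PyTorch versions):

```python
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Replicate

# Single-process "world" just so the example is self-contained.
dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1)
mesh = init_device_mesh("cpu", (1,))  # DeviceMesh: an opaque python type from dynamo's perspective

def f(x: torch.Tensor) -> torch.Tensor:
    # from_local / redistribute / to_local take DeviceMesh / Placement arguments
    # that cannot be represented as graph inputs today, so dynamo special-cases them.
    dt = DTensor.from_local(x, mesh, [Replicate()])
    dt = dt.redistribute(mesh, [Replicate()])
    return dt.to_local()

out = torch.compile(f, backend="eager")(torch.randn(4, 4))
```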
This is bad for several reasons:
(1) it is brittle (these functions aren't supported in all cases; a recent internal example is a model where `.to_local()` causes extra graph breaks / recompiles)
(2) it is an invariant violation (dynamo shouldn't really need to know anything about DTensor)
(3) it prevents @jamesjwu 's AOTDispatcher warm cache from kicking in (the hacks we use to handle these functions in dynamo are not easily pickleable by FX and we therefore cache miss on them). This will be even more critical if we want any sort of pre-compilation to work with distributed.
Now that we have a `flat_apply` HOP that can support non-tensor/symint primitives (thanks @StrongerXi and @zou3519), it should be possible to have dynamo support these functions more generically:
(1) these functions all desugar into a custom `autograd.Function`, which we support in dynamo
(2) the `autograd.Function` here accepts custom python types, which we can handle through the `flat_apply` HOP (see the sketch after this list).
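To make the second point concrete, here is a toy sketch of the pattern (not DTensor's actual internals; the class and helper names below are made up): an `autograd.Function` that takes an opaque python argument, plus the pytree flatten/unflatten trick that the `flat_apply` HOP is built around, so that only graphable leaves need to cross the graph boundary.

```python
import torch
import torch.utils._pytree as pytree

# Illustrative stand-in for DTensor's opaque argument types (DTensorSpec /
# Placement / DeviceMesh); the name and fields here are made up.
class FakeSpec:
    def __init__(self, mesh_shape, placements):
        self.mesh_shape = mesh_shape
        self.placements = placements

# Register the type as a pytree node so it can be flattened into graphable
# leaves plus static context -- the property that flat_apply relies on.
pytree.register_pytree_node(
    FakeSpec,
    lambda spec: ([], (spec.mesh_shape, spec.placements)),  # flatten: no tensor leaves
    lambda leaves, ctx: FakeSpec(ctx[0], ctx[1]),           # unflatten from context
)

class FromLocal(torch.autograd.Function):
    """Toy analogue of the autograd.Function a DTensor primitive desugars into."""

    @staticmethod
    def forward(ctx, local_tensor, spec):
        ctx.spec = spec          # opaque python object kept on ctx, not in the graph
        return local_tensor * 1  # the real function would wrap this into a DTensor

    @staticmethod
    def backward(ctx, grad_out):
        return grad_out, None

def call_with_flattened_args(fn, *args):
    # The essence of what the flat_apply HOP does: flatten a pytree containing
    # opaque python types into (leaves, treespec), pass only the graphable
    # leaves across the graph boundary, and unflatten inside the call.
    leaves, treespec = pytree.tree_flatten(args)
    return fn(*pytree.tree_unflatten(leaves, treespec))

x = torch.randn(4, requires_grad=True)
out = call_with_flattened_args(FromLocal.apply, x, FakeSpec((2,), ("replicate",)))
out.sum().backward()
```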
One wrinkle that needs some figuring out, though, is that this `flat_apply` should "disappear" as part of AOTDispatcher tracing, since the DTensor subclass will desugar these arguments. We need to make sure this works properly.
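One way to check that last point, as a rough sketch (assuming the `f` from the first snippet; the backend wiring here is just one way to peek at the AOT graphs, not part of the proposal): use a custom `aot_autograd` backend and look for any surviving `flat_apply` nodes in the traced forward/backward graphs.

```python
import torch
from functorch.compile import make_boxed_func
from torch._dynamo.backends.common import aot_autograd

seen_flat_apply = []

def inspect_graph(gm, example_inputs):
    # Record whether any flat_apply calls survived AOTDispatcher tracing;
    # they should be desugared away by the DTensor subclass.
    seen_flat_apply.append(
        any("flat_apply" in str(node.target) for node in gm.graph.nodes)
    )
    return make_boxed_func(gm.forward)

backend = aot_autograd(fw_compiler=inspect_graph, bw_compiler=inspect_graph)
compiled_f = torch.compile(f, backend=backend)  # `f` from the sketch above
compiled_f(torch.randn(4, 4))
print("flat_apply in AOT graphs:", any(seen_flat_apply))
```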
cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @amjames