[RFC][DTensor] DTensor Strided Sharding: A More Flexible Way To Shard Tensors #129627
Description
Current DTensor sharding follows a pattern we call "Contiguous Sharding", because each shard contains a contiguous piece of the global view (i.e. the original tensor before sharding). Say we shard a tensor on its dimension 0 over a 1-d `DeviceMesh` of size 2: the tensor is split row-wise into 2 contiguous shards, and every rank in the mesh holds one shard as its local tensor.
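For concreteness, here is a tiny sketch of that split using plain `torch.chunk`; the 4x2 tensor shape and its values are made up purely for illustration.

```python
import torch

# Made-up global tensor: 4 rows, sharded on dim 0 over a 1-d mesh of size 2.
global_tensor = torch.arange(8).reshape(4, 2)

# Contiguous Sharding: split dim 0 into mesh_size contiguous shards.
shards = torch.chunk(global_tensor, chunks=2, dim=0)

# Rank 0 holds rows [0, 1]; rank 1 holds rows [2, 3].
local_rank0, local_rank1 = shards
assert torch.equal(local_rank0, global_tensor[:2])
assert torch.equal(local_rank1, global_tensor[2:])
```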
Strided Sharding, the new sharding pattern we want to introduce to DTensor, differs from Contiguous Sharding in the following aspects:

1. Instead of splitting the tensor into `N` shards (`N` being the mesh size), it splits the tensor into `k * N` shards. We call this `k` the `split_factor`. In the above example, the tensor would be split into 4 shards if `k=2`.
2. Each rank's local tensor is obtained by selecting `k` of those shards, assigning them to that rank, and concatenating them into one tensor. The selection is done with a stride, and this is the reason why we call it "Strided Sharding". In the above example, with `split_factor` `k=2` and stride `s=2`, the tensor is first split into 4 shards A, B, C, D; then rank 0 concatenates shards A & C and rank 1 concatenates shards B & D as their local tensors, as shown in the sketch below.
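Continuing with the same made-up 4x2 tensor, a sketch of the "split, select with stride, concat" steps for `k=2` and `s=2`:

```python
import torch

# Same toy tensor as above; mesh size N=2, split_factor k=2, stride s=2.
global_tensor = torch.arange(8).reshape(4, 2)

# Split dim 0 into k * N = 4 shards.
A, B, C, D = torch.chunk(global_tensor, 4, dim=0)

# Select every s-th shard per rank and concatenate:
# rank 0 gets A & C, rank 1 gets B & D.
local_rank0 = torch.cat([A, C], dim=0)  # rows 0 and 2 of the global tensor
local_rank1 = torch.cat([B, D], dim=0)  # rows 1 and 3 of the global tensor
assert torch.equal(local_rank0, global_tensor[[0, 2]])
assert torch.equal(local_rank1, global_tensor[[1, 3]])
```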
Motivation
Now, why should we introduce this irregular sharding pattern? Let's look at an FSDP + TP example:
mesh = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp", "tp")) # mesh: [[0, 1], [2, 3]]
In FSDP + TP, a parameter tensor is sharded over the entire 2-d `mesh`, on its first dimension (dim 0), and the sharding is applied first over `mesh["tp"]` and then over `mesh["dp"]`. However, the current DTensor placement type cannot represent this sharding pattern. Right now, the placements of the above DTensor would be `[Shard(0), Shard(0)]`, but Contiguous Sharding with those placements would chunk dim 0 first over the "dp" mesh dimension and then over "tp", while sharding first over `mesh["tp"]` and then over `mesh["dp"]` actually gives a different, interleaved layout (see the sketch below).
You'll see that this sharding pattern is exactly Strided Sharding.
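To make the mismatch concrete, the sketch below simulates the two chunking steps with plain `torch.chunk`, using a made-up 8x2 parameter and the 2x2 mesh above, and checks which rows dp rank 0 ends up owning:

```python
import torch

# Made-up parameter with 8 rows; mesh sizes match the 2x2 mesh above.
param = torch.arange(16).reshape(8, 2)
dp_size, tp_size = 2, 2

# Step 1 (TP): shard dim 0 contiguously across the "tp" mesh dimension.
tp_shards = torch.chunk(param, tp_size, dim=0)

# Step 2 (FSDP): shard each TP shard again on dim 0 across the "dp" mesh dimension.
local = {
    (dp, tp): torch.chunk(tp_shards[tp], dp_size, dim=0)[dp]
    for dp in range(dp_size)
    for tp in range(tp_size)
}

# Rows owned by dp rank 0 (collected across the tp dimension): {0, 1, 4, 5},
# a strided, non-contiguous selection of the global dim 0.
dp0_rows = torch.cat([local[(0, tp)] for tp in range(tp_size)], dim=0)
assert torch.equal(dp0_rows, param[[0, 1, 4, 5]])

# Under Contiguous Sharding with placements [Shard(0), Shard(0)],
# dp rank 0 would instead own the contiguous rows {0, 1, 2, 3}.
```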
Proposal: Placement Type
To introduce Strided Sharding to DTensor, I propose to add a new placement type `_StridedShard`. The reason we make it a private class is that, for now, the only use case of Strided Sharding comes from our composability APIs (e.g. FSDP + TP). This placement type will temporarily be used only inside DTensor internals to correctly handle the resharding logic.

The definition of the new placement type is:
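Roughly, and only as a sketch of what the definition could look like (assuming it subclasses the public `torch.distributed.tensor.Shard` placement and records the split factor; the field set is illustrative, not the final API):

```python
from dataclasses import dataclass

from torch.distributed.tensor import Shard  # public Shard placement


@dataclass(frozen=True)
class _StridedShard(Shard):
    """Shard a tensor dimension in a strided fashion: the dimension is split
    into split_factor * mesh_dim_size pieces and each rank keeps split_factor
    of them, selected with a stride and concatenated together."""

    # k in the Description section above.
    split_factor: int = 1


# e.g. the dp-dim placement in the FSDP + TP example could presumably be
# expressed as _StridedShard(0, split_factor=2).
```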
Proposal: Shard (`distribute_tensor()`)
[WIP] The implementation is identical to the "split, assign, and concat" steps in the Description section; a standalone sketch of those steps follows below.
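A minimal sketch of the three steps as one helper. The name `strided_shard_local` and its signature are hypothetical (not an existing DTensor API), and the stride is assumed to equal the number of ranks so the chunks partition evenly:

```python
import torch


def strided_shard_local(tensor, dim, rank, num_ranks, split_factor, stride):
    """Hypothetical helper: "split, assign, and concat" for one rank.

    Assumes stride == num_ranks so that the split_factor * num_ranks chunks
    partition cleanly across ranks, as in the Description example.
    """
    # 1. split: cut the sharding dim into split_factor * num_ranks chunks.
    chunks = torch.chunk(tensor, split_factor * num_ranks, dim=dim)
    # 2. assign: this rank takes every stride-th chunk starting at its index.
    picked = [chunks[rank + j * stride] for j in range(split_factor)]
    # 3. concat: the picked chunks form this rank's local tensor.
    return torch.cat(picked, dim=dim)


# Same toy numbers as in the Description: 4 rows, 2 ranks, k=2, s=2.
t = torch.arange(8).reshape(4, 2)
assert torch.equal(strided_shard_local(t, 0, 0, 2, 2, 2), t[[0, 2]])
assert torch.equal(strided_shard_local(t, 0, 1, 2, 2, 2), t[[1, 3]])
```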
Proposal: To Replicate (`full_tensor()`)
[WIP] It is the reverse operation of sharding, i.e. "allgather, split, assign, and concat". The assign step follows the index calculation formula `shard_index_new = X % split_factor * stride + X // split_factor`, where `X` is the shard index after the split step.
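A sketch of these reverse steps on the same made-up 4x2 tensor, with the allgather simulated by concatenating the two local tensors in rank order:

```python
import torch

global_tensor = torch.arange(8).reshape(4, 2)
split_factor, stride, num_ranks = 2, 2, 2

# Local tensors produced by Strided Sharding (rank 0: rows {0, 2}; rank 1: rows {1, 3}).
local_rank0, local_rank1 = global_tensor[[0, 2]], global_tensor[[1, 3]]

# 1. allgather (simulated): concatenate the local tensors in rank order.
gathered = torch.cat([local_rank0, local_rank1], dim=0)

# 2. split: cut the gathered tensor into split_factor * num_ranks shards.
shards = torch.chunk(gathered, split_factor * num_ranks, dim=0)

# 3. assign: move shard X to its original position using the formula above.
reordered = [None] * len(shards)
for X, shard in enumerate(shards):
    shard_index_new = X % split_factor * stride + X // split_factor
    reordered[shard_index_new] = shard

# 4. concat: the reordered shards reproduce the original global tensor.
assert torch.equal(torch.cat(reordered, dim=0), global_tensor)
```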
Compatibility
This Strided Sharding concept is compatible with Contiguous Sharding, which can be treated as a special case of Strided Sharding (`split_factor=1`). Therefore we should consider merging this `_StridedShard` placement type into `Shard`.
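A quick check of that special case on the same toy tensor: with `split_factor=1`, the strided selection degenerates to each rank keeping exactly one contiguous chunk, which is precisely Contiguous Sharding.

```python
import torch

global_tensor = torch.arange(8).reshape(4, 2)
num_ranks, split_factor, stride = 2, 1, 2

for rank in range(num_ranks):
    # split_factor=1: split into 1 * N shards and pick a single one per rank.
    chunks = torch.chunk(global_tensor, split_factor * num_ranks, dim=0)
    local = torch.cat([chunks[rank + j * stride] for j in range(split_factor)], dim=0)
    # Identical to the Contiguous Sharding result for that rank.
    assert torch.equal(local, torch.chunk(global_tensor, num_ranks, dim=0)[rank])
```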
cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @fegin @XilunWu @wanchaol @fduwjj @tianyu-l @wconstab @chauhang @d4l3k @lessw2020 @gnadathur @wz337