[RFC][DTensor] DTensor Strided Sharding: A More Flexible Way To Shard Tensors #129627


Open
wz337 opened this issue Jun 27, 2024 · 1 comment

wz337 commented Jun 27, 2024

[WIP][RFC][DTensor] DTensor Strided Sharding: A More Flexible Way To Shard Tensors

Description

Current DTensor sharding follows a pattern we call "Contiguous Sharding", because each shard contains a contiguous piece of the global view (i.e. the original tensor before sharding). Say we shard a tensor on its dimension 0 over a 1-d DeviceMesh of size 2. The tensor is split row-wise into 2 contiguous shards, and every rank in the mesh holds one shard as its local tensor:

Contiguous Sharding:
  • A (placed on rank 0)
  • B (placed on rank 1)

Strided Sharding, the new sharding pattern we want to introduce to DTensor, differs from Contiguous Sharding in the following aspects:

  1. Unlike Contiguous Sharding, which splits the tensor into NUM_RANKS shards, Strided Sharding allows splitting the tensor into a multiple of NUM_RANKS (i.e. k * N) shards. We call this k the split_factor. In the above example, the tensor would be split into 4 shards if k=2.
  2. We still respect the DTensor design principle that each rank holds only one local tensor: we select k shards, assign them to one rank, and concatenate them into one tensor as the new local tensor. The selection is done with a stride, which is why we call it "Strided Sharding". In the above example, with split_factor k=2 and stride s=2, the tensor is first split into 4 shards:
Strided Sharding (Tensor Split & Strided Select):
  • A (placed on rank 0)
  • B (placed on rank 1)
  • C (placed on rank 0)
  • D (placed on rank 1)

Then rank 0 concatenates shards A & C and rank 1 concatenates shards B & D as their local tensors:

Strided Sharding (Tensor Concat):
  • A & C (placed on rank 0)
  • B & D (placed on rank 1)
  3. Unlike Contiguous Sharding, the resulting local tensor of Strided Sharding is no longer contiguous in the global view: A and C are not contiguous neighbors (and neither are B & D). This is the fundamental difference between Strided Sharding and Contiguous Sharding.

Motivation

Now, why should we introduce this irregular sharding pattern? Let's look at an FSDP + TP example:

  • assume we have 4 ranks and organize them into a 2-d mesh. mesh = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp", "tp")) # mesh: [[0, 1], [2, 3]]
  • a 2-d tensor is sharded over mesh, on its first dimension (dim 0).
  • in our FSDP+TP code, the tensor is first sharded over mesh["tp"] and then over mesh["dp"]. However, the current DTensor placement types cannot represent this sharding pattern. Right now, the placements of the above DTensor would be [Shard(0), Shard(0)], but the corresponding sharding result would be:
Contiguous Sharding:
  • A (placed on rank 0)
  • B (placed on rank 1)
  • C (placed on rank 2)
  • D (placed on rank 3)

while

```python
distribute_tensor(
    distribute_tensor(tensor, mesh=mesh["tp"], placements=[Shard(0)]),  # TP
    mesh=mesh["dp"],
    placements=[Shard(0)],
)  # FSDP
```

actually gives the following sharding:

FSDP+TP Sharding:
  • A (placed on rank 0)
  • C (placed on rank 1)
  • B (placed on rank 2)
  • D (placed on rank 3)

You'll see that this sharding pattern is exactly Strided Sharding.
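To see why, here is a minimal single-process sketch (not the actual FSDP/TP code paths) that simulates the two-step sharding with plain torch.chunk. The mesh layout and rank numbering follow the example above; everything else is illustrative.

```python
import torch

# Simulate the two-step FSDP+TP sharding without torch.distributed.
# The 4 row-blocks of the global tensor are labelled 0..3, i.e. A=0, B=1, C=2, D=3.
x = torch.arange(4)

# mesh [[0, 1], [2, 3]] with mesh_dim_names=("dp", "tp"):
#   TP groups are the mesh rows:    {0, 1} and {2, 3}
#   DP groups are the mesh columns: {0, 2} and {1, 3}
# so global rank r has (dp, tp) coordinates with r = 2 * dp + tp.
for rank in range(4):
    dp, tp = divmod(rank, 2)
    after_tp = torch.chunk(x, 2)[tp]         # step 1: Shard(0) over mesh["tp"]
    after_dp = torch.chunk(after_tp, 2)[dp]  # step 2: Shard(0) over mesh["dp"]
    print(f"rank {rank}: row-block {after_dp.tolist()}")

# Prints rank 0 -> [0] (A), rank 1 -> [2] (C), rank 2 -> [1] (B), rank 3 -> [3] (D),
# i.e. the Strided Sharding layout above, not the contiguous A/B/C/D layout that
# placements=[Shard(0), Shard(0)] currently denotes.
```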

Proposal: Placement Type

To introduce Strided Sharding to DTensor, I propose to add a new placement type _StridedShard. The reason we make it a private class is that, for the moment, the only use case of Strided Sharding comes from our Composability APIs (e.g. FSDP + TP). This placement type will temporarily be used only inside DTensor internals to correctly handle the resharding logic.

The definition of the new placement type is:

```python
@dataclass(frozen=True, kw_only=True)
class _StridedShard(Shard):
    split_factor: int
    stride: int
```
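
For illustration only, here is a hedged sketch of how the 1-d example from the Description might be expressed with this placement. The _StridedShard construction is hypothetical, distributing with it is part of the proposal below rather than existing behavior, and imports plus distributed initialization are omitted.

```python
# Hypothetical usage sketch (assumes an initialized 2-rank job; imports omitted).
mesh = init_device_mesh("cuda", (2,))    # 1-d mesh of size 2, as in the Description
tensor = torch.arange(8).reshape(4, 2)   # 4 row-blocks: A, B, C, D
# Split dim 0 into split_factor * NUM_RANKS = 4 shards and select with stride 2,
# so rank 0 holds A & C and rank 1 holds B & D.
dtensor = distribute_tensor(tensor, mesh, [_StridedShard(0, split_factor=2, stride=2)])
```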

Proposal: Shard (distribute_tensor())

[WIP] The implementation is identical to the "split, assign, and concat" steps in the Description section.
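A minimal single-process sketch of what the per-rank shard computation could look like, assuming the split_factor/stride semantics described above. The helper name and signature are illustrative, not actual DTensor internals.

```python
import torch

# Illustrative helper (not the actual DTensor internals): compute the local
# tensor of `rank` on a 1-d mesh of `num_ranks` devices under Strided Sharding.
def strided_shard_local_tensor(
    tensor: torch.Tensor,
    dim: int,
    rank: int,
    num_ranks: int,
    split_factor: int,
    stride: int,
) -> torch.Tensor:
    # split: cut the tensor into split_factor * num_ranks shards along `dim`
    shards = torch.chunk(tensor, split_factor * num_ranks, dim=dim)
    # assign (strided select): rank r owns shards r, r + stride, r + 2*stride, ...
    owned = [shards[rank + i * stride] for i in range(split_factor)]
    # concat: the owned shards form this rank's single local tensor
    return torch.cat(owned, dim=dim)

# Example from the Description (NUM_RANKS=2, k=2, s=2): rank 0 gets A & C.
x = torch.arange(8).reshape(4, 2)
print(strided_shard_local_tensor(x, dim=0, rank=0, num_ranks=2, split_factor=2, stride=2))
```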

Proposal: To Replicate (full_tensor())

[WIP] This is the reverse operation of sharding, i.e. "allgather, split, assign, and concat". The assign step follows the index calculation formula `shard_index_new = X % split_factor * stride + X // split_factor`, where X is the shard index after the split step.
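A single-process sketch of the split/assign/concat portion after the allgather, again with illustrative names only; `gathered` stands in for the allgather output, ordered by rank.

```python
import torch

# Illustrative reorder step after allgather (not actual DTensor code).
# Each local tensor is split back into its split_factor pieces, and the pieces
# are moved to global order using
#   shard_index_new = X % split_factor * stride + X // split_factor
def strided_full_tensor(gathered, dim, split_factor, stride):
    pieces = []
    for local in gathered:
        pieces.extend(torch.chunk(local, split_factor, dim=dim))
    reordered = [None] * len(pieces)
    for X, piece in enumerate(pieces):
        reordered[X % split_factor * stride + X // split_factor] = piece
    return torch.cat(reordered, dim=dim)

# Example from the Description: rank 0 holds A & C, rank 1 holds B & D.
A, B, C, D = (torch.full((1, 2), v) for v in (0, 1, 2, 3))
gathered = [torch.cat([A, C]), torch.cat([B, D])]
print(strided_full_tensor(gathered, dim=0, split_factor=2, stride=2))
# -> rows in global order A, B, C, D
```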

Compatibility

This Strided Sharding concept is compatible with Contiguous Sharding, which can be treated as a special case of Strided Sharding (split_factor=1). Therefore we should consider folding this _StridedShard placement type into Shard.

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @fegin @XilunWu @wanchaol @fduwjj @tianyu-l @wconstab @chauhang @d4l3k @lessw2020 @gnadathur @wz337

@bdhirsh added the oncall: distributed label on Jun 27, 2024
@XilunWu changed the title from "[WIP][RFC] DTensor Strided Sharding" to "[WIP][RFC][DTensor] DTensor Strided Sharding: A More Flexible Way To Shard Tensors" on Jun 27, 2024
@XilunWu self-assigned this on Jun 28, 2024
@fegin added the triaged label on Jun 28, 2024
tianyu-l commented Jul 3, 2024

typo?

Unlike contiguous sharding splitting the tensor into NUM_RANKS shards, Contiguous Sharding allows

->
Unlike contiguous sharding splitting the tensor into NUM_RANKS shards, Strided Sharding allows

@wz337 changed the title from "[WIP][RFC][DTensor] DTensor Strided Sharding: A More Flexible Way To Shard Tensors" to "[RFC][DTensor] DTensor Strided Sharding: A More Flexible Way To Shard Tensors" on Jan 15, 2025