Update base for Update on "[FSDP][dtensor] use _StridedShard to repre… · pytorch/pytorch@d7a07ce · GitHub

Commit d7a07ce

Update base for Update on "[FSDP][dtensor] use _StridedShard to represent nested sharding for correct full_tensor() result"
Fixes issue #129229, #129206

**Summary**
1. Have `FSDP` choose the `_StridedShard` placement for FSDP+TP sharding.
2. Add a new property `num_shards_map` to `DTensorSpec` denoting how many shards each tensor dimension has. This is necessary for constructing the `_StridedShard` placement when we call `distribute_tensor(dtensor_tp, dp_device_mesh, [Shard(0)])`: the `split_factor` argument is simply the number of shards on that sharded tensor dim.
3. Add a parity test to FSDP ensuring that FSDP+TP sharding (i.e. strided) and plain TP sharding (i.e. non-strided) produce the same `full_tensor()` result.
4. Re-enable the tests that were disabled in #129519 and remove the relevant code.

**Test**
`pytest test/distributed/_composable/fsdp/test_fully_shard_training.py`
`pytest test/distributed/_composable/fsdp/test_fully_shard_state_dict.py`
`pytest test/distributed/checkpoint/fsdp/test_fsdp_dsd.py`
`pytest test/distributed/_composable/fsdp/test_fully_shard_init.py`

cc H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o zhaojuanmao mrshenli rohan-varma chauhang

[ghstack-poisoned]
1 parent 7b666b8 commit d7a07ce
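
To make the layout concrete: with FSDP+TP the parameter is TP-sharded first and FSDP then shards each TP-local piece on dim-0, so the global dim-0 blocks end up interleaved across dp ranks rather than contiguous. The sketch below only illustrates that index math; it is a minimal single-process sketch with illustrative sizes (a 2×2 `("dp", "tp")` mesh, 2 rows per shard), not code from this PR.

```python
# Minimal sketch of the dim-0 layout that _StridedShard encodes (illustrative
# sizes; none of this is FSDP code). TP shards the tensor first, then FSDP
# shards each TP-local piece, so global row blocks are ordered tp-major.
dp_size, tp_size = 2, 2      # ("dp", "tp") mesh, world_size = 4
rows_per_shard = 2           # each rank holds a [2, 2] local shard

for dp_rank in range(dp_size):
    for tp_rank in range(tp_size):
        # Block index asserted by test_fsdp_tp_meta_compute in the diff below.
        strided_block = tp_rank * dp_size + dp_rank
        # What a plain [Shard(0), Shard(0)] placement would imply instead.
        contiguous_block = dp_rank * tp_size + tp_rank
        print(
            f"(dp={dp_rank}, tp={tp_rank}): rows "
            f"{strided_block * rows_per_shard}-{(strided_block + 1) * rows_per_shard - 1} "
            f"(plain Shard(0) order would pick block {contiguous_block})"
        )
```

This is exactly the `shard_idx_on_dim_0 = tp_rank * dp_size + dp_rank` relation asserted by the new `test_fsdp_tp_meta_compute` in the diff below.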

File tree

5 files changed (+445, -5 lines)


test/distributed/_tensor/test_utils.py

Lines changed: 82 additions & 0 deletions
@@ -127,6 +127,30 @@ def test_compute_local_shape_and_global_offset_2D(self):
                 global_tensor[dim0_start:dim0_end, dim1_start:dim1_end],
             )
 
+    @with_comms
+    def test_fsdp_tp_meta_compute(self):
+        # FSDP + TP sharding
+        tp_size = 2
+        dp_size = self.world_size // tp_size
+        global_mesh = init_device_mesh(
+            self.device_type, (dp_size, tp_size), mesh_dim_names=("dp", "tp")
+        )
+        # local shard shape is [2, 2]
+        global_tensor_shape = torch.Size([2 * self.world_size, 2])
+        placements = [_StridedShard(0, split_factor=tp_size), Shard(0)]
+
+        local_shape, global_offset = compute_local_shape_and_global_offset(
+            global_tensor_shape, global_mesh, placements
+        )
+        assert global_mesh.get_coordinate is not None
+        dp_rank = global_mesh.get_local_rank("dp")
+        tp_rank = global_mesh.get_local_rank("tp")
+        shard_idx_on_dim_0 = tp_rank * dp_size + dp_rank
+        expected_local_shape = (2, 2)
+        expected_global_offset = (shard_idx_on_dim_0 * 2, 0)
+        self.assertEqual(local_shape, expected_local_shape)
+        self.assertEqual(global_offset, expected_global_offset)
+
 
 class TestStridedSharding(DTensorTestBase):
     @property
@@ -266,6 +290,64 @@ def test_2d_mesh_strided_sharding(self):
         )
         self.assertEqual(full_tensor, x)
 
+    @with_comms
+    def test_2d_mesh_2d_tensor_strided_sharding(self):
+        # Test 2: 2-d tensor over 2-d mesh
+        mesh_2d = init_device_mesh(
+            self.device_type, (2, self.world_size // 2), mesh_dim_names=("dim0", "dim1")
+        )
+        mesh_dim0_size = mesh_2d["dim0"].size()
+        mesh_dim1_size = mesh_2d["dim1"].size()
+        mesh_dim0_local_rank = mesh_2d["dim0"].get_local_rank(mesh_dim=0)
+        mesh_dim1_local_rank = mesh_2d["dim1"].get_local_rank(mesh_dim=0)
+        x = torch.arange(2 * self.world_size, device=self.device_type).reshape(2, -1)
+        """
+        strided sharding:
+            rank 0: [[0], [4]]
+            rank 1: [[2], [6]]
+            rank 2: [[1], [5]]
+            rank 3: [[3], [7]]
+        """
+        split_factor = 2
+        # shard on mesh dim-0
+        shard_placement_dim0 = _StridedShard(1, split_factor=split_factor)
+        tensor_list, _ = shard_placement_dim0._split_tensor(x, mesh_dim0_size)
+        shard_x = tensor_list[mesh_dim0_local_rank]
+        expected_shard_dim0 = (
+            torch.tensor([[0, 2], [4, 6]], device=self.device_type)
+            if mesh_dim0_local_rank == 0
+            else torch.tensor([[1, 3], [5, 7]], device=self.device_type)
+        )
+        self.assertEqual(shard_x, expected_shard_dim0)
+
+        # shard on mesh dim-1
+        shard_placement_dim1 = _StridedShard(1, split_factor=1)  # same as Shard(1)
+        tensor_list, _ = shard_placement_dim1._split_tensor(shard_x, mesh_dim1_size)
+        shard_x = tensor_list[mesh_dim1_local_rank]
+        expected_shard_dim1 = [
+            torch.tensor(value, device=self.device_type)
+            for value in [[[0], [4]], [[2], [6]], [[1], [5]], [[3], [7]]]
+        ][self.rank]
+        self.assertEqual(shard_x, expected_shard_dim1)
+
+        # shard_to_replicate on mesh dim-1
+        full_tensor = shard_placement_dim1._to_replicate_tensor(
+            shard_x,
+            mesh_2d,
+            mesh_dim=1,
+            current_logical_shape=list(expected_shard_dim0.shape),
+        )
+        self.assertEqual(full_tensor, expected_shard_dim0)
+
+        # shard_to_replicate on mesh dim-0
+        full_tensor = shard_placement_dim0._to_replicate_tensor(
+            full_tensor,
+            mesh_2d,
+            mesh_dim=0,
+            current_logical_shape=list(x.shape),
+        )
+        self.assertEqual(full_tensor, x)
+
 
 class Test2DStridedLocalShard(DTensorTestBase):
     @property
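
For readers without a 4-rank setup, the interleaved layout in the docstring above (`[[0], [4]]` on rank 0, `[[2], [6]]` on rank 1, ...) can be reproduced single-process with plain tensor ops. The `strided_split` helper below is hypothetical: it imitates the chunk interleaving that the test asserts and is not the actual `_StridedShard._split_tensor` code path.

```python
import torch

def strided_split(t, dim, num_shards, split_factor):
    # Hypothetical helper: split `t` into num_shards * split_factor chunks on
    # `dim`, then give shard i every num_shards-th chunk starting at i. This
    # reproduces the interleaved layout asserted in the test above.
    chunks = t.tensor_split(num_shards * split_factor, dim=dim)
    return [torch.cat(chunks[i::num_shards], dim=dim) for i in range(num_shards)]

x = torch.arange(8).reshape(2, 4)  # [[0, 1, 2, 3], [4, 5, 6, 7]]

# mesh dim-0: _StridedShard(1, split_factor=2)
dim0_shards = strided_split(x, dim=1, num_shards=2, split_factor=2)
assert dim0_shards[0].tolist() == [[0, 2], [4, 6]]
assert dim0_shards[1].tolist() == [[1, 3], [5, 7]]

# mesh dim-1: _StridedShard(1, split_factor=1), i.e. a plain Shard(1)
final = [strided_split(s, dim=1, num_shards=2, split_factor=1) for s in dim0_shards]
# rank r on the 2x2 mesh maps to (dim0_rank, dim1_rank) = (r // 2, r % 2)
assert [final[r // 2][r % 2].tolist() for r in range(4)] == [
    [[0], [4]], [[2], [6]], [[1], [5]], [[3], [7]],
]
print("single-process simulation matches the per-rank shards in the test")
```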
