8000 [dtensor] fix scaled dot product flash attention sharding · pytorch/pytorch@4439ac0 · GitHub

Commit 4439ac0

[dtensor] fix scaled dot product flash attention sharding
ghstack-source-id: 408ec85
Pull Request resolved: #148125
1 parent 2978771 commit 4439ac0

2 files changed: +9 −7 lines

test/distributed/tensor/test_attention.py

Lines changed: 3 additions & 1 deletion

@@ -42,7 +42,6 @@
 if PLATFORM_SUPPORTS_MEM_EFF_ATTENTION:
     backends.append(SDPBackend.EFFICIENT_ATTENTION)
 
-
 rotater_enum_to_str = {
     _RotateMethod.ALL_GATHER: "allgather",
     _RotateMethod.ALL_TO_ALL: "alltoall",
@@ -360,6 +359,9 @@ def _test_ring_attention_custom_transformer(self, rotater: _RotateMethod) -> None:
             self.device_type,
             torch.arange(0, self.world_size),
         )
+        # early init the DTensor RNG tracker to avoid its broadcast being captured in comm_mode
+        torch.distributed.tensor._random.manual_seed(10, device_mesh)
+
         dtype = torch.bfloat16
         bs = 2
         args = ModelArgs()
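The added test lines work around a measurement artifact: the first random op on a DTensor lazily initializes the DTensor RNG tracker, which broadcasts the seed across ranks, and if that happens inside the test's CommDebugMode region the broadcast is counted alongside the attention collectives. Below is a minimal sketch of the pattern outside the actual test harness; the torchrun/NCCL launch, the mesh setup, and the placeholder model body are assumptions, not taken from the test.

import torch
import torch.distributed as dist
import torch.distributed.tensor._random as dtensor_random
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.debug import CommDebugMode

# assumes launch via torchrun so rank/world-size env vars are set
dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
device_mesh = init_device_mesh("cuda", (dist.get_world_size(),))

# Seeding the DTensor RNG tracker here triggers its one-time seed broadcast now,
# before CommDebugMode starts counting collectives.
dtensor_random.manual_seed(10, device_mesh)

comm_mode = CommDebugMode()
with comm_mode:
    pass  # run the ring-attention model under test here
# Counts now reflect only the attention collectives, not the RNG-init broadcast.
print(comm_mode.get_total_counts())

dist.destroy_process_group()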

torch/distributed/tensor/_ops/_matrix_ops.py

Lines changed: 6 additions & 6 deletions

@@ -252,8 +252,8 @@ def scaled_dot_product_flash_attention_strategy(
             None,  # cum_seq_k
             None,  # max_q
             None,  # max_k
-            None,  # philox_seed
-            None,  # philox_offset
+            Replicate(),  # rng_state
+            None,  # unused
             Replicate(),
             Replicate(),
             Replicate(),
@@ -279,8 +279,8 @@ def scaled_dot_product_flash_attention_strategy(
             None,  # cum_seq_k
             None,  # max_q
             None,  # max_k
-            None,  # philox_seed
-            None,  # philox_offset
+            Replicate(),  # rng_state
+            None,  # unused
             debug_attn_mask_sharding,
             qkv_sharding,
             qkv_sharding,
@@ -297,8 +297,8 @@ def scaled_dot_product_flash_attention_strategy(
             None,  # cum_seq_k
             None,  # max_q
             None,  # max_k
-            None,  # philox_seed
-            None,  # philox_offset
+            Replicate(),  # rng_state
+            None,  # unused
             Shard(2),  # debugattn
             Shard(2),  # q
             Shard(2),  # k
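For orientation, each edited block above is one candidate placement list in scaled_dot_product_flash_attention_strategy: entries for the outputs of aten._scaled_dot_product_flash_attention come first, followed by entries for the q/k/v inputs, with None marking positions that carry no shardable tensor. The diff suggests the op's philox seed/offset outputs are now represented as a single rng_state tensor plus an unused slot, so the rng_state output must carry a placement (Replicate()) instead of being skipped. A rough sketch of that layout under this reading; the variable names and the exact output ordering are illustrative assumptions, not copied from _matrix_ops.py.

from torch.distributed.tensor import Replicate, Shard

# Fully replicated candidate: one entry per output of
# aten._scaled_dot_product_flash_attention, then one per input (q, k, v).
all_replicate = [
    Replicate(),  # output
    Replicate(),  # logsumexp
    None,         # cum_seq_q
    None,         # cum_seq_k
    None,         # max_q
    None,         # max_k
    Replicate(),  # rng_state: a real tensor output, so it now needs a placement
    None,         # unused slot (previously philox_offset)
    Replicate(),  # debug_attn_mask
    Replicate(),  # q
    Replicate(),  # k
    Replicate(),  # v
]

The Shard(2) entries in the last hunk are the sequence-dimension variant of the same layout, which is the sharding that ring/context-parallel attention relies on; in every candidate the rng_state slot stays Replicate(), since the RNG state is identical across ranks.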
