update · pytorch/pytorch@8fa443f · GitHub

Commit 8fa443f

update
1 parent 406eac8 commit 8fa443f

File tree

1 file changed: +90 -0 lines changed

torch/distributed/tensor/_ops/_matrix_ops.py

Lines changed: 90 additions & 0 deletions
@@ -894,6 +894,96 @@ def scaled_dot_product_cudnn_attention_backward_strategy(
     )
 
 
+@register_op_strategy(
+    aten._scaled_dot_product_fused_attention_overrideable.default,
+    schema_info=RuntimeSchemaInfo(6),
+)
+def scaled_dot_product_fused_attention_overrideable_strategy(
+    op_schema: OpSchema,
+) -> OpStrategy:
+    # NOTE: currently we only support a few simple strategies to enable tensor parallelism
+    mesh = op_schema.get_mesh_from_args()
+    return_debug_mask = len(op_schema.args_schema) >= 7 and op_schema.args_schema[6]
+    q_input_strategy = op_schema.args_schema[0]
+    assert isinstance(q_input_strategy, OpStrategy)
+    # assuming q/k/v have the same shape
+
+    has_attn_bias = (
+        len(op_schema.args_schema) >= 4 and op_schema.args_schema[3] is not None
+    )
+
+    single_mesh_dim_strategies: list[PlacementList] = []
+
+    # placement list stores placements of [outputs, inputs]
+    # in the sdpa case, we have 2 valid tensor outputs (output, logsumexp) and 3 or 4 tensor inputs
+    # first we can always accept full replication for both inputs and outputs
+    all_replicate: PlacementList = [
+        # outputs
+        Replicate(),
+        Replicate(),
+        None,  # cum_seq_q
+        None,  # cum_seq_k
+        None,  # max_q
+        None,  # max_k
+        None,  # philox_seed
+        None,  # philox_offset
+        None,  # debug_attn_mask
+        # inputs
+        Replicate(),
+        Replicate(),
+        Replicate(),
+    ]
+    if has_attn_bias:
+        all_replicate.append(Replicate())  # attn_bias
+    single_mesh_dim_strategies.append(all_replicate)
+
+    # second we can accept the sharding pattern of tensor parallelism, which
+    # shards on the num-of-heads dim
+    qkv_sharding = Shard(1)  # num head dim
+    output_sharding = Shard(1)  # num head dim
+    logsumexp_sharding = Shard(1)  # num head dim
+
+    num_heads_dim_sharding: PlacementList = [
+        output_sharding,
+        logsumexp_sharding,
+        None,  # cum_seq_q
+        None,  # cum_seq_k
+        None,  # max_q
+        None,  # max_k
+        None,  # philox_seed
+        None,  # philox_offset
+        None,  # debug_attn_mask
+        qkv_sharding,
+        qkv_sharding,
+        qkv_sharding,
+    ]
+    if has_attn_bias:
+        num_heads_dim_sharding.append(Shard(1))
+    single_mesh_dim_strategies.append(num_heads_dim_sharding)
+
+    # Context Parallelism: shards on the sequence dim
+    single_mesh_dim_strategies.append(
+        [
+            Shard(2),  # output
+            Shard(2),  # logsumexp
+            None,  # cum_seq_q
+            None,  # cum_seq_k
+            None,  # max_q
+            None,  # max_k
+            None,  # philox_seed
+            None,  # philox_offset
+            None,  # debug_attn_mask
+            Shard(2),  # q
+            Shard(2),  # k
+            Shard(2),  # v
+        ]
+    )
+
+    return expand_to_full_mesh_op_strategy(
+        mesh, op_schema, single_mesh_dim_strategies, input_index=9
+    )
+
+
 @register_op_strategy(
     aten._scaled_dot_product_fused_attention_overrideable_backward.default
 )
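
For context, a minimal usage sketch (not part of the commit) of the tensor-parallel placement this strategy is written to accept: q/k/v distributed as DTensors sharded on the num-heads dim, i.e. Shard(1). The two-GPU mesh, the tensor shapes, and the torchrun launch are assumptions for illustration; which fused SDPA variant is actually dispatched (flash, efficient, cudnn, or the overrideable op) depends on the backend in use.

# Minimal sketch, assuming 2 GPUs and a `torchrun --nproc-per-node=2` launch.
import torch
import torch.nn.functional as F
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor

mesh = init_device_mesh("cuda", (2,))  # 1-D mesh over 2 ranks (assumption)
torch.manual_seed(0)  # same global tensor on every rank before sharding


def make_qkv():
    # [batch, num_heads, seq_len, head_dim]; num_heads is split across the mesh
    return distribute_tensor(
        torch.randn(4, 8, 128, 64, device="cuda", dtype=torch.bfloat16),
        mesh,
        [Shard(1)],
    )


q, k, v = make_qkv(), make_qkv(), make_qkv()

# With a num-heads sharding strategy registered for the dispatched SDPA op,
# the output keeps the Shard(1) placement instead of being redistributed to
# full replication.
out = F.scaled_dot_product_attention(q, k, v)
print(out.placements)  # expected: (Shard(dim=1),)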
