@@ -894,6 +894,96 @@ def scaled_dot_product_cudnn_attention_backward_strategy(
    )


+@register_op_strategy(
+    aten._scaled_dot_product_fused_attention_overrideable.default,
+    schema_info=RuntimeSchemaInfo(6),
+)
+def scaled_dot_product_fused_attention_overrideable_strategy(
+    op_schema: OpSchema,
+) -> OpStrategy:
+    # NOTE: currently we only support some simple strategies to support tensor parallelism
+    mesh = op_schema.get_mesh_from_args()
+    return_debug_mask = len(op_schema.args_schema) >= 7 and op_schema.args_schema[6]
+    q_input_strategy = op_schema.args_schema[0]
+    assert isinstance(q_input_strategy, OpStrategy)
+    # assuming q/k/v have the same shape
+
+    has_attn_bias = (
+        len(op_schema.args_schema) >= 4 and op_schema.args_schema[3] is not None
+    )
+
+    single_mesh_dim_strategies: list[PlacementList] = []
+
+    # placement list stores placements of [outputs, inputs]
+    # in the sdpa case, we have 3 valid tensor outputs and 3 or 4 tensor inputs
+    # first we can always accept full replication for both inputs and outputs
+    all_replicate: PlacementList = [
+        # outputs
+        Replicate(),
+        Replicate(),
+        None,  # cum_seq_q
+        None,  # cum_seq_k
+        None,  # max_q
+        None,  # max_k
+        None,  # philox_seed
+        None,  # philox_offset
+        None,  # debug_attn_mask
+        # inputs
+        Replicate(),
+        Replicate(),
+        Replicate(),
+    ]
+    if has_attn_bias:
+        all_replicate.append(Replicate())  # attn_bias
+    single_mesh_dim_strategies.append(all_replicate)
+
+    # second we can accept the sharding pattern of tensor parallelism, which
+    # shards on the num_heads dim
+    qkv_sharding = Shard(1)  # num head dim
+    output_sharding = Shard(1)  # num head dim
+    logsumexp_sharding = Shard(1)  # num head dim
+
+    num_heads_dim_sharding: PlacementList = [
+        output_sharding,
+        logsumexp_sharding,
+        None,  # cum_seq_q
+        None,  # cum_seq_k
+        None,  # max_q
+        None,  # max_k
+        None,  # philox_seed
+        None,  # philox_offset
+        None,  # debug_attn_mask
+        qkv_sharding,
+        qkv_sharding,
+        qkv_sharding,
+    ]
+    if has_attn_bias:
+        num_heads_dim_sharding.append(Shard(1))
+    single_mesh_dim_strategies.append(num_heads_dim_sharding)
+
+    # Context Parallelism: shards on the sequence dim
+    single_mesh_dim_strategies.append(
+        [
+            Shard(2),  # output
+            Shard(2),  # logsumexp
+            None,  # cum_seq_q
+            None,  # cum_seq_k
+            None,  # max_q
+            None,  # max_k
+            None,  # philox_seed
+            None,  # philox_offset
+            None,  # debug_attn_mask
+            Shard(2),  # q
+            Shard(2),  # k
+            Shard(2),  # v
+        ]
+    )
+
+    return expand_to_full_mesh_op_strategy(
+        mesh, op_schema, single_mesh_dim_strategies, input_index=9
+    )
+
+
 @register_op_strategy(
     aten._scaled_dot_product_fused_attention_overrideable_backward.default
 )
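
For reference, a minimal standalone sketch (not part of the PR) of the placement-list convention this strategy relies on: each single-mesh-dim entry lists the nine outputs of the fused op first (output, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask) and then the q/k/v inputs, which is why the strategy is expanded with input_index=9. The import path assumes a recent PyTorch where Shard and Replicate are exported from torch.distributed.tensor.

# Illustration only; mirrors the replicate and tensor-parallel entries built above.
from torch.distributed.tensor import Replicate, Shard

NUM_OUTPUTS = 9  # output, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k,
                 # philox_seed, philox_offset, debug_attn_mask

# Full replication: both tensor outputs and q/k/v are Replicate(); the
# bookkeeping outputs (cum_seq_*, max_*, philox_*, debug mask) carry no placement.
replicate_entry = [Replicate(), Replicate(), *([None] * 7)] + [Replicate()] * 3

# Tensor parallelism: the two tensor outputs and q/k/v shard the num_heads dim (dim 1).
tp_entry = [Shard(1), Shard(1), *([None] * 7)] + [Shard(1)] * 3

for entry in (replicate_entry, tp_entry):
    outputs, inputs = entry[:NUM_OUTPUTS], entry[NUM_OUTPUTS:]
    assert len(outputs) == NUM_OUTPUTS and len(inputs) == 3  # q, k, v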