@@ -20,6 +20,7 @@
 from torch.distributed.tensor._ops._embedding_ops import _MaskPartial
 from torch.distributed.tensor._ops.utils import (
     expand_to_full_mesh_op_strategy,
+    generate_redistribute_costs,
     is_tensor_dim_sharded,
     is_tensor_evenly_shardable,
     is_tensor_partial,
@@ -739,7 +740,6 @@ def place(vp: Placement, ip: Placement) -> Placement:
 
 @register_prop_rule(
     [
-        aten.split.Tensor,
         aten.split_with_sizes.default,
         aten.split_with_sizes_copy.default,
     ],
@@ -804,3 +804,70 @@ def size_split(N, i) -> list:
         for _ in range(len(output_size_list))
     ]
     return OutputSharding(output_spec_list)
+
+
+@register_op_strategy(aten.split.Tensor, schema_info=RuntimeSchemaInfo(1))
+def split_strategy(op_schema: OpSchema) -> TupleStrategy:
+    input_strategy = op_schema.args_schema[0]
+    assert isinstance(input_strategy, OpStrategy)
+
+    split_size_or_sections = op_schema.args_schema[1]
+
+    dim = op_schema.args_schema[2] if len(op_schema.args_schema) > 2 else 0
+    assert isinstance(dim, int)
+    dim = normalize_dim(dim, input_strategy.ndim)
+
+    def size_split(N, i) -> list:
+        # Last chunk will be smaller if the tensor size N
+        # along the given dimension dim is not divisible by i.
+        assert i > 0
+        return [i] * (N // i) + ([N % i] if N % i != 0 else [])
+
+    output_size_list = (
+        size_split(
+            input_strategy.strategies[0].output_spec.shape[dim], split_size_or_sections
+        )
+        if isinstance(split_size_or_sections, int)
+        else split_size_or_sections
+    )
+    assert isinstance(output_size_list, Sized)
+
+    output_strategy_childs = [OpStrategy([]) for _ in range(len(output_size_list))]
+    for input_placement_strategy in input_strategy.strategies:
+        op_args_target_specs = []
+        redistribute_costs = []
+        input_spec = input_placement_strategy.output_spec
+
+        output_placements = input_spec.placements
+        if is_tensor_dim_sharded(input_spec, dim=dim):
+            # need reshard before splitting
+            placements_after_unshard = unshard_tensor_dim(
+                input_spec.placements, dim=dim
+            )
+            input_target_spec = DTensorSpec(
+                mesh=input_spec.mesh,
+                placements=placements_after_unshard,
+                tensor_meta=input_spec.tensor_meta,
+            )
+            op_args_target_specs.append(input_target_spec)
+            redistribute_costs.append(
+                generate_redistribute_costs(input_strategy, input_target_spec)
+            )
+            output_placements = placements_after_unshard
+        else:
+            op_args_target_specs.append(input_spec)
+            redistribute_costs.append([0.0 for _ in input_strategy.strategies])
+
+        for child in output_strategy_childs:
+            child.strategies.append(
+                PlacementStrategy(
+                    output_specs=DTensorSpec(
+                        mesh=input_spec.mesh,
+                        placements=output_placements,
+                    ),
+                    input_specs=op_args_target_specs,
+                    redistribute_cost=redistribute_costs,
+                )
+            )
+
+    return TupleStrategy(output_strategy_childs)
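
Note on the chunking math: size_split mirrors torch.split's semantics for an integer split size, i.e. full chunks of size i plus one smaller trailing chunk when N is not evenly divisible by i. A standalone sketch of the same helper (plain Python, torch not required):

    def size_split(N: int, i: int) -> list:
        # Full chunks of size i, plus a smaller trailing chunk of N % i
        # when N is not evenly divisible, matching torch.split.
        assert i > 0
        return [i] * (N // i) + ([N % i] if N % i != 0 else [])

    assert size_split(10, 3) == [3, 3, 3, 1]  # trailing chunk is smaller
    assert size_split(8, 4) == [4, 4]         # evenly divisible, no remainder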
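
A hypothetical end-to-end sketch of what the new strategy handles: splitting a DTensor along a dim that is currently sharded. The is_tensor_dim_sharded branch reshards the input to unshard that dim first (costed via generate_redistribute_costs), so each output chunk carries the unsharded placements. The mesh setup, tensor values, and expected placements below are assumptions for illustration, not part of this change; run with torchrun --nproc_per_node=2 on a build that includes this strategy.

    # Hypothetical illustration, not part of the diff.
    import torch
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor import Shard, distribute_tensor

    mesh = init_device_mesh("cpu", (2,))  # 1-D mesh over 2 ranks (gloo)
    x = distribute_tensor(
        torch.arange(16, dtype=torch.float32).reshape(4, 4), mesh, [Shard(0)]
    )

    # dim 0 is sharded, so sharding propagation reshards to unshard dim 0
    # before computing the chunks; outputs inherit the unsharded placements.
    a, b = torch.split(x, 2, dim=0)
    print(a.placements, b.placements)  # expected: (Replicate(),) (Replicate(),)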