bootcamp task for DTensor · pytorch/pytorch@8f0579a · GitHub

Commit 8f0579a

bootcamp task for DTensor

ghstack-source-id: 8098694
Pull Request resolved: #148932

1 parent 703176e, commit 8f0579a

2 files changed (+77, -10 lines)


torch/distributed/tensor/_dispatch.py (+9, -9)

@@ -390,9 +390,9 @@ def unwrap_to_op_info(
                 kwargs_schema[k] = v
                 local_kwargs[k] = v
 
-        assert compute_mesh is not None, (
-            f"found no DeviceMesh from dtensor args for {op_call}!"
-        )
+        assert (
+            compute_mesh is not None
+        ), f"found no DeviceMesh from dtensor args for {op_call}!"
         op_info = OpInfo(
             compute_mesh,
             OpSchema(
@@ -416,18 +416,18 @@ def unwrap_to_op_info(
     def wrap(res: object, spec: OutputSpecType) -> object:
         if isinstance(res, torch.Tensor):
             if spec is not None:
-                assert isinstance(spec, DTensorSpec), (
-                    f"output spec does not match with output! Expected DTensorSpec, got {spec}."
-                )
+                assert isinstance(
+                    spec, DTensorSpec
+                ), f"output spec does not match with output! Expected DTensorSpec, got {spec}."
                 return dtensor.DTensor(res, spec, requires_grad=res.requires_grad)
             else:
                 # if output does not have a DTensorSpec due to specific ops, it must be a scalar tensor
                 assert res.ndim == 0, "output tensor should be scalar!"
                 return res
         elif isinstance(res, (list, tuple)):
-            assert spec is not None and isinstance(spec, (list, tuple)), (
-                f"output spec does not match with output! Expected list/tuple, got {spec}."
-            )
+            assert spec is not None and isinstance(
+                spec, (list, tuple)
+            ), f"output spec does not match with output! Expected list/tuple, got {spec}."
             res_list = []
             for e, s in zip(res, spec):
                 res_list.append(OpDispatcher.wrap(e, s))
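
The _dispatch.py hunks are formatting-only: each assert keeps the same condition and the same f-string message; only the place where the statement wraps across lines changes. A minimal standalone sketch (not part of the commit; `op_call` is just a stand-in string here) confirming the two layouts raise the identical error:

compute_mesh = None
op_call = "aten.add.Tensor"  # hypothetical value, for illustration only

try:
    # old layout: the message is parenthesized
    assert compute_mesh is not None, (
        f"found no DeviceMesh from dtensor args for {op_call}!"
    )
except AssertionError as e:
    old_msg = str(e)

try:
    # new layout: the condition is parenthesized
    assert (
        compute_mesh is not None
    ), f"found no DeviceMesh from dtensor args for {op_call}!"
except AssertionError as e:
    new_msg = str(e)

assert old_msg == new_msg  # same AssertionError message either way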

torch/distributed/tensor/_ops/_tensor_ops.py (+68, -1)

@@ -20,6 +20,7 @@
 from torch.distributed.tensor._ops._embedding_ops import _MaskPartial
 from torch.distributed.tensor._ops.utils import (
     expand_to_full_mesh_op_strategy,
+    generate_redistribute_costs,
     is_tensor_dim_sharded,
     is_tensor_evenly_shardable,
     is_tensor_partial,
@@ -739,7 +740,6 @@ def place(vp: Placement, ip: Placement) -> Placement:
 
 @register_prop_rule(
     [
-        aten.split.Tensor,
         aten.split_with_sizes.default,
         aten.split_with_sizes_copy.default,
     ],
@@ -804,3 +804,70 @@ def size_split(N, i) -> list:
         for _ in range(len(output_size_list))
     ]
     return OutputSharding(output_spec_list)
+
+
+@register_op_strategy(aten.split.Tensor, schema_info=RuntimeSchemaInfo(1))
+def split_strategy(op_schema: OpSchema) -> TupleStrategy:
+    input_strategy = op_schema.args_schema[0]
+    assert isinstance(input_strategy, OpStrategy)
+
+    split_size_or_sections = op_schema.args_schema[1]
+
+    dim = op_schema.args_schema[2] if len(op_schema.args_schema) > 2 else 0
+    assert isinstance(dim, int)
+    dim = normalize_dim(dim, input_strategy.ndim)
+
+    def size_split(N, i) -> list:
+        # Last chunk will be smaller if the tensor size N
+        # along the given dimension dim is not divisible by i.
+        assert i > 0
+        return [i] * (N // i) + ([N % i] if N % i != 0 else [])
+
+    output_size_list = (
+        size_split(
+            input_strategy.strategies[0].output_spec.shape[dim], split_size_or_sections
+        )
+        if isinstance(split_size_or_sections, int)
+        else split_size_or_sections
+    )
+    assert isinstance(output_size_list, Sized)
+
+    output_strategy_childs = [OpStrategy([]) for _ in range(len(output_size_list))]
+    for input_placement_strategy in input_strategy.strategies:
+        op_args_target_specs = []
+        redistribute_costs = []
+        input_spec = input_placement_strategy.output_spec
+
+        output_placements = input_spec.placements
+        if is_tensor_dim_sharded(input_spec, dim=dim):
+            # need reshard before splitting
+            placements_after_unshard = unshard_tensor_dim(
+                input_spec.placements, dim=dim
+            )
+            input_target_spec = DTensorSpec(
+                mesh=input_spec.mesh,
+                placements=placements_after_unshard,
+                tensor_meta=input_spec.tensor_meta,
+            )
+            op_args_target_specs.append(input_target_spec)
+            redistribute_costs.append(
+                generate_redistribute_costs(input_strategy, input_target_spec)
+            )
+            output_placements = placements_after_unshard
+        else:
+            op_args_target_specs.append(input_spec)
+            redistribute_costs.append([0.0 for _ in input_strategy.strategies])
+
+        for child in output_strategy_childs:
+            child.strategies.append(
+                PlacementStrategy(
+                    output_specs=DTensorSpec(
+                        mesh=input_spec.mesh,
+                        placements=output_placements,
+                    ),
+                    input_specs=op_args_target_specs,
+                    redistribute_cost=redistribute_costs,
+                )
+            )
+
+    return TupleStrategy(output_strategy_childs)
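
This second file is the substance of the change: aten.split.Tensor moves off the old register_prop_rule list and gets a dedicated register_op_strategy entry. For each input placement strategy, split_strategy unshards the split dimension when the input is sharded along it (pricing that via generate_redistribute_costs), keeps all other placements, and emits one child OpStrategy per output chunk inside a TupleStrategy. A hedged usage sketch, not part of the commit, of how this surfaces to user code; the launch setup and "cuda" device are assumptions:

import os

import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor

# Assumes a launch via torchrun, so WORLD_SIZE is set and a default
# process group can be created; "cuda" is illustrative, "cpu" also works.
world_size = int(os.environ["WORLD_SIZE"])
mesh = init_device_mesh("cuda", (world_size,))

# Shard a (4 * world_size, 12) tensor along dim 0 across the mesh.
x = distribute_tensor(torch.randn(4 * world_size, 12), mesh, [Shard(0)])

# Split along the *unsharded* dim 1: per the strategy above, every output
# chunk should keep the Shard(0) placement, with zero redistribute cost.
cols = torch.split(x, 4, dim=1)   # 3 DTensors, each (4 * world_size, 4)

# Split along the *sharded* dim 0: this hits the is_tensor_dim_sharded
# branch, so the input is first redistributed off the split dimension
# (priced via generate_redistribute_costs) before the local split runs.
rows = torch.split(x, 2, dim=0)   # 2 * world_size DTensors, each (2, 12)

for c in cols:
    print(c.placements, c.shape)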
