bootcamp task for DTensor by XilunWu · Pull Request #148932 · pytorch/pytorch · GitHub

bootcamp task for DTensor #148932


Draft · wants to merge 1 commit into base: gh/XilunWu/125/base
18 changes: 9 additions & 9 deletions torch/distributed/tensor/_dispatch.py
@@ -390,9 +390,9 @@ def unwrap_to_op_info(
kwargs_schema[k] = v
local_kwargs[k] = v

assert compute_mesh is not None, (
f"found no DeviceMesh from dtensor args for {op_call}!"
)
assert (
compute_mesh is not None
), f"found no DeviceMesh from dtensor args for {op_call}!"
op_info = OpInfo(
compute_mesh,
OpSchema(
@@ -416,18 +416,18 @@ def unwrap_to_op_info(
def wrap(res: object, spec: OutputSpecType) -> object:
if isinstance(res, torch.Tensor):
if spec is not None:
assert isinstance(spec, DTensorSpec), (
f"output spec does not match with output! Expected DTensorSpec, got {spec}."
)
assert isinstance(
spec, DTensorSpec
), f"output spec does not match with output! Expected DTensorSpec, got {spec}."
return dtensor.DTensor(res, spec, requires_grad=res.requires_grad)
else:
# if output does not have a DTensorSpec due to specific ops, it must be a scalar tensor
assert res.ndim == 0, "output tensor should be scalar!"
return res
elif isinstance(res, (list, tuple)):
assert spec is not None and isinstance(spec, (list, tuple)), (
f"output spec does not match with output! Expected list/tuple, got {spec}."
)
assert spec is not None and isinstance(
spec, (list, tuple)
), f"output spec does not match with output! Expected list/tuple, got {spec}."
res_list = []
for e, s in zip(res, spec):
res_list.append(OpDispatcher.wrap(e, s))
69 changes: 68 additions & 1 deletion torch/distributed/tensor/_ops/_tensor_ops.py
@@ -20,6 +20,7 @@
from torch.distributed.tensor._ops._embedding_ops import _MaskPartial
from torch.distributed.tensor._ops.utils import (
expand_to_full_mesh_op_strategy,
generate_redistribute_costs,
is_tensor_dim_sharded,
is_tensor_evenly_shardable,
is_tensor_partial,
@@ -739,7 +740,6 @@ def place(vp: Placement, ip: Placement) -> Placement:

@register_prop_rule(
[
aten.split.Tensor,
aten.split_with_sizes.default,
aten.split_with_sizes_copy.default,
],
@@ -804,3 +804,70 @@ def size_split(N, i) -> list:
for _ in range(len(output_size_list))
]
return OutputSharding(output_spec_list)


@register_op_strategy(aten.split.Tensor, schema_info=RuntimeSchemaInfo(1))
def split_strategy(op_schema: OpSchema) -> TupleStrategy:
input_strategy = op_schema.args_schema[0]
assert isinstance(input_strategy, OpStrategy)

split_size_or_sections = op_schema.args_schema[1]

dim = op_schema.args_schema[2] if len(op_schema.args_schema) > 2 else 0
assert isinstance(dim, int)
dim = normalize_dim(dim, input_strategy.ndim)

def size_split(N, i) -> list:
@Skylion007 (Collaborator) commented on Mar 11, 2025:
Can you be more specific about the typing here for size_split? at least preserve the List type from the input type with a type var.
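
One possible way to tighten the annotation (a sketch only, not part of this PR; a TypeVar as suggested above could generalize it further for the list-of-sizes input case):

# Hypothetical typed variant (illustration only, not in the diff): annotating
# the return as list[int] keeps a concrete element type for callers such as
# output_size_list below.
def size_split_typed(N: int, i: int) -> list[int]:
    # Last chunk is smaller when N is not divisible by the chunk size i,
    # e.g. size_split_typed(10, 3) -> [3, 3, 3, 1].
    assert i > 0
    return [i] * (N // i) + ([N % i] if N % i != 0 else [])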

# Last chunk will be smaller if the tensor size N
# along the given dimension dim is not divisible by i.
assert i > 0
return [i] * (N // i) + ([N % i] if N % i != 0 else [])

output_size_list = (
size_split(
input_strategy.strategies[0].output_spec.shape[dim], split_size_or_sections
)
if isinstance(split_size_or_sections, int)
else split_size_or_sections
)
assert isinstance(output_size_list, Sized)

output_strategy_childs = [OpStrategy([]) for _ in range(len(output_size_list))]
for input_placement_strategy in input_strategy.strategies:
op_args_target_specs = []
redistribute_costs = []
input_spec = input_placement_strategy.output_spec

output_placements = input_spec.placements
if is_tensor_dim_sharded(input_spec, dim=dim):
# need reshard before splitting
placements_after_unshard = unshard_tensor_dim(
input_spec.placements, dim=dim
)
input_target_spec = DTensorSpec(
mesh=input_spec.mesh,
placements=placements_after_unshard,
tensor_meta=input_spec.tensor_meta,
)
op_args_target_specs.append(input_target_spec)
redistribute_costs.append(
generate_redistribute_costs(input_strategy, input_target_spec)
)
output_placements = placements_after_unshard
else:
op_args_target_specs.append(input_spec)
redistribute_costs.append([0.0 for _ in input_strategy.strategies])

for child in output_strategy_childs:
child.strategies.append(
PlacementStrategy(
output_specs=DTensorSpec(
mesh=input_spec.mesh,
placements=output_placements,
),
input_specs=op_args_target_specs,
redistribute_cost=redistribute_costs,
)
)

return TupleStrategy(output_strategy_childs)
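
For context (not part of the diff), a minimal usage sketch of the path this strategy covers, assuming a 1-D mesh over 4 ranks launched with torchrun and only the public DTensor API; the chunk sizes follow size_split(8, 3) == [3, 3, 2]:

# Sketch only: assumes 4 ranks (e.g. `torchrun --nproc-per-node 4 demo.py`)
# so init_device_mesh can build a 1-D mesh; the cpu/gloo setup is an assumption.
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor

mesh = init_device_mesh("cpu", (4,))           # 1-D mesh over 4 ranks
x = torch.arange(16.0).reshape(8, 2)
dt = distribute_tensor(x, mesh, [Shard(0)])    # sharded along dim 0

# aten.split.Tensor now has an OpStrategy: because dim 0 is sharded, the
# strategy above requests a redistribute that unshards dim 0 (its cost
# recorded via generate_redistribute_costs) before the split is applied.
chunks = torch.split(dt, 3, dim=0)             # global sizes 3, 3, 2 along dim 0
print([tuple(c.shape) for c in chunks])        # [(3, 2), (3, 2), (2, 2)]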