more dist ops in non strict (#147417) · pytorch/pytorch@2473876 · GitHub

Commit 2473876

avikchaudhuri authored and pytorchmergebot committed
more dist ops in non strict (#147417)

Summary: Previously we added support for `all_reduce` in non-strict export. This PR extends that support to the other non-functional collectives that are remapped in Dynamo: `all_gather`, `all_gather_into_tensor`, `all_to_all_single`, and `reduce_scatter_tensor`.

Test Plan: added unit tests

Differential Revision: D69813991

Pull Request resolved: #147417
Approved by: https://github.com/angelayi
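Concretely, a module that calls one of these legacy collectives can now be exported non-strictly. Below is a minimal sketch of the pattern the added unit tests exercise, assuming a build with torch.distributed and using the test-only fake backend (`FakeStore` from torch.testing._internal); the `Gather` module name is illustrative and not part of the PR.

import torch
from torch.export import export
from torch.testing._internal.distributed.fake_pg import FakeStore


class Gather(torch.nn.Module):
    def forward(self, x):
        ys = [torch.empty_like(x) for _ in range(2)]
        torch.distributed.all_gather(ys, x)  # legacy, in-place collective
        return ys


# The fake backend lets a single process pretend to be rank 0 of a 2-rank group.
torch.distributed.init_process_group(
    backend="fake", world_size=2, rank=0, store=FakeStore()
)
try:
    ep = export(Gather(), (torch.randn(2),), strict=False)  # non-strict export
    print(ep.graph)  # the legacy call should appear as a functional collective op
finally:
    torch.distributed.destroy_process_group()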
1 parent 3946767 commit 2473876

File tree: 4 files changed, +156 −22 lines

- test/export/test_export.py
- torch/_export/non_strict_utils.py
- torch/distributed/_functional_collectives.py
- torch/distributed/distributed_c10d.py

test/export/test_export.py

Lines changed: 78 additions & 10 deletions

@@ -11953,6 +11953,20 @@ def forward(self, x):
         ]
         self.assertEqual(len(shift_op), 1)
 
+    @contextmanager
+    def distributed_env(self, world_size):
+        try:
+            torch.distributed.init_process_group(
+                backend="fake",
+                world_size=world_size,
+                rank=0,
+                store=FakeStore(),
+            )
+            yield
+
+        finally:
+            torch.distributed.destroy_process_group()
+
     @unittest.skipIf(IS_MACOS, "Distributed not packaged in macos")
     def test_distributed_all_reduce(self):
         class Foo(torch.nn.Module):
@@ -11965,21 +11979,75 @@ def forward(self, x):
                 torch.distributed.all_reduce(y)
                 return y
 
-        try:
-            torch.distributed.init_process_group(
-                backend="fake",
-                world_size=2,
-                rank=0,
-                store=FakeStore(),
-            )
-
+        with self.distributed_env(world_size=2):
             m = Foo()
             ep = export(m, (torch.randn(4, 4),))
             inp = (torch.randn(4, 4),)
             self.assertTrue(torch.allclose(ep.module()(*inp), m(*inp)))
 
-        finally:
-            torch.distributed.destroy_process_group()
+    @unittest.skipIf(IS_MACOS, "Distributed not packaged in macos")
+    def test_distributed_all_gather(self):
+        class Foo(torch.nn.Module):
+            def forward(self, x):
+                ys = [torch.empty_like(x) for _ in range(2)]
+                torch.distributed.all_gather(ys, x)
+                return ys
+
+        with self.distributed_env(world_size=2):
+            m = Foo()
+            ep = export(m, (torch.randn(2),))
+            inp = (torch.randn(2),)
+            self.assertTrue(
+                torch.allclose(a, b) for a, b in zip(ep.module()(*inp), m(*inp))
+            )
+
+    @unittest.skipIf(IS_MACOS, "Distributed not packaged in macos")
+    def test_distributed_all_gather_into_tensor(self):
+        class Foo(torch.nn.Module):
+            def forward(self, x):
+                y = torch.empty(2 * 2)
+                torch.distributed.all_gather_into_tensor(y, x)
+                return y
+
+        with self.distributed_env(world_size=2):
+            m = Foo()
+            ep = export(m, (torch.randn(2),))
+            inp = (torch.randn(2),)
+            self.assertTrue(torch.allclose(ep.module()(*inp), m(*inp)))
+
+    @unittest.skipIf(IS_MACOS, "Distributed not packaged in macos")
+    def test_distributed_all_to_all_single(self):
+        class Foo(torch.nn.Module):
+            def forward(self, x):
+                y = torch.empty(4)
+                torch.distributed.all_to_all_single(y, x)
+                return y
+
+        with self.distributed_env(world_size=4):
+            m = Foo()
+            ep = export(m, (torch.randn(4),))
+            nodes = ep.graph.find_nodes(
+                op="call_function",
+                target=torch.ops._c10d_functional.all_to_all_single.default,
+            )
+            self.assertEqual(len(nodes), 1)
+
+    @unittest.skipIf(IS_MACOS, "Distributed not packaged in macos")
+    def test_distributed_reduce_scatter_tensor(self):
+        class Foo(torch.nn.Module):
+            def forward(self, x):
+                y = torch.empty(2)
+                torch.distributed.reduce_scatter_tensor(y, x)
+                return y
+
+        with self.distributed_env(world_size=2):
+            m = Foo()
+            ep = export(m, (torch.randn(2 * 2),))
+            nodes = ep.graph.find_nodes(
+                op="call_function",
+                target=torch.ops._c10d_functional.reduce_scatter_tensor.default,
+            )
+            self.assertEqual(len(nodes), 1)
 
 
     @unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't support")

torch/_export/non_strict_utils.py

Lines changed: 18 additions & 11 deletions

@@ -619,21 +619,28 @@ class _NonStrictTorchFunctionHandler(torch.overrides.TorchFunctionMode):
     """
 
     def _override(self, func, args, kwargs):
-        if torch.distributed.is_available() and func is torch.distributed.all_reduce:
-            # Redirect to a corresponding functional collective, following Dynamo.
-            # See torch/distributed/_functional_collectives.py for details.
+        if torch.distributed.is_available():
             from torch.distributed._functional_collectives import (
-                all_reduce_inplace,
                 REDUCE_OP_TO_STR,
+                traceable_collective_remaps,
             )
 
-            # see CollectiveFunctionRewriteVariable for remapping logic
-            signature = inspect.signature(func)
-            kwargs = dict(signature.bind(*args, **kwargs).arguments)
-            args = ()
-            if "op" in kwargs:
-                kwargs["op"] = REDUCE_OP_TO_STR[kwargs["op"]]
-            return all_reduce_inplace, args, kwargs
+            if func in traceable_collective_remaps:
+                # Redirect to a corresponding functional collective, following Dynamo.
+                # See torch/distributed/_functional_collectives.py for details.
+                # The following is an adaptation of CollectiveFunctionRewriteVariable.
+                mapped_func = traceable_collective_remaps[func]
+                signature = inspect.signature(func)
+                kwargs = dict(signature.bind(*args, **kwargs).arguments)
+                args = ()
+                if func in (
+                    torch.distributed.all_reduce,
+                    torch.distributed.reduce_scatter_tensor,
+                    torch.distributed._reduce_scatter_base,
+                ):
+                    if "op" in kwargs:
+                        kwargs["op"] = REDUCE_OP_TO_STR[kwargs["op"]]
+                return mapped_func, args, kwargs
         if func is torch.tensor:
             # Redirect to Python implementation of torch.tensor for data with symints.
             # NOTE(avik): We don't unconditionally redirect to this implementation
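The remapping above mirrors Dynamo's CollectiveFunctionRewriteVariable: look up the functional replacement, rebind the legacy call's arguments, and stringify the reduce op where one exists. Below is a standalone sketch of that lookup-and-rebind step, assuming a build with torch.distributed available; the variable names are illustrative and this is not the handler itself.

import inspect

import torch
import torch.distributed as dist
from torch.distributed._functional_collectives import (
    REDUCE_OP_TO_STR,
    traceable_collective_remaps,
)

func = dist.all_reduce
mapped_func = traceable_collective_remaps[func]  # the functional replacement
# Bind the legacy call's positional args into kwargs so they can be forwarded as-is.
bound = inspect.signature(func).bind(torch.ones(2), op=dist.ReduceOp.SUM)
kwargs = dict(bound.arguments)
kwargs["op"] = REDUCE_OP_TO_STR[kwargs["op"]]  # ReduceOp.SUM -> "sum"
print(mapped_func.__name__, list(kwargs))  # all_reduce_inplace ['tensor', 'op']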

torch/distributed/_functional_collectives.py

Lines changed: 1 addition & 1 deletion

@@ -1052,7 +1052,7 @@ def _reduce_scatter_tensor_coalesced_native_meta(
 def all_gather_tensor_inplace(
     output_tensor: torch.Tensor,
     input_tensor: torch.Tensor,
-    group,  # TODO add a type,
+    group=None,  # TODO add a type,
     async_op: bool = False,
     tag: str = "",
     gather_dim: int = 0,
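The one-line change gives `all_gather_tensor_inplace` the same `group=None` default as the legacy `all_gather_into_tensor` it stands in for. That default matters because `signature.bind()` in the handler above forwards only the arguments that were actually passed; a small sketch of that behavior (the placeholder strings are illustrative stand-ins for real tensors).

import inspect

import torch.distributed as dist

# bind() records only the arguments actually passed, not defaults, so a legacy call
# like all_gather_into_tensor(out, inp) forwards no "group" at all and the functional
# remap target must supply its own group=None default.
sig = inspect.signature(dist.all_gather_into_tensor)
bound = sig.bind("out_placeholder", "in_placeholder")  # placeholders, not real tensors
print(dict(bound.arguments))
# {'output_tensor': 'out_placeholder', 'input_tensor': 'in_placeholder'}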

torch/distributed/distributed_c10d.py

Lines changed: 59 additions & 0 deletions

@@ -3703,6 +3703,20 @@ def all_gather(tensor_list, tensor, group=None, async_op=False):
         [tensor([1.+1.j, 2.+2.j], device='cuda:1'), tensor([3.+3.j, 4.+4.j], device='cuda:1')] # Rank 1
 
     """
+    # Dynamo has built-in logic to map legacy distributed ops to functional collectives.
+    # Let's redirect to a torch function mode that can mimic this logic outside Dynamo
+    # (e.g., non-strict export implements such a torch function mode).
+    relevant_args = (tensor,)
+    if has_torch_function(relevant_args):
+        return handle_torch_function(
+            all_gather,
+            relevant_args,
+            tensor_list,
+            tensor,
+            group=group,
+            async_op=async_op,
+        )
+
     _check_tensor_list(tensor_list, "tensor_list")
     _check_single_tensor(tensor, "tensor")
     _ensure_all_tensors_same_dtype(tensor_list, tensor)
@@ -3779,6 +3793,20 @@ def all_gather_into_tensor(output_tensor, input_tensor, group=None, async_op=False):
         The Gloo backend does not support this API.
 
     """
+    # Dynamo has built-in logic to map legacy distributed ops to functional collectives.
+    # Let's redirect to a torch function mode that can mimic this logic outside Dynamo
+    # (e.g., non-strict export implements such a torch function mode).
+    relevant_args = (input_tensor,)
+    if has_torch_function(relevant_args):
+        return handle_torch_function(
+            all_gather_into_tensor,
+            relevant_args,
+            output_tensor,
+            input_tensor,
+            group=group,
+            async_op=async_op,
+        )
+
     _check_single_tensor(input_tensor, "input_tensor")
     _check_single_tensor(output_tensor, "output_tensor")
     if _rank_not_in_group(group):
@@ -4224,6 +4252,21 @@ def reduce_scatter_tensor(output, input, op=ReduceOp.SUM, group=None, async_op=False):
         The Gloo backend does not support this API.
 
     """
+    # Dynamo has built-in logic to map legacy distributed ops to functional collectives.
+    # Let's redirect to a torch function mode that can mimic this logic outside Dynamo
+    # (e.g., non-strict export implements such a torch function mode).
+    relevant_args = (input,)
+    if has_torch_function(relevant_args):
+        return handle_torch_function(
+            reduce_scatter_tensor,
+            relevant_args,
+            output,
+            input,
+            op=op,
+            group=group,
+            async_op=async_op,
+        )
+
     _check_single_tensor(output, "output")
     _check_single_tensor(input, "input")
 
@@ -4382,6 +4425,22 @@ def all_to_all_single(
         tensor([3+3j, 7+7j, 11+11j, 15+15j]) # Rank 2
         tensor([4+4j, 8+8j, 12+12j, 16+16j]) # Rank 3
     """
+    # Dynamo has built-in logic to map legacy distributed ops to functional collectives.
+    # Let's redirect to a torch function mode that can mimic this logic outside Dynamo
+    # (e.g., non-strict export implements such a torch function mode).
+    relevant_args = (input,)
+    if has_torch_function(relevant_args):
+        return handle_torch_function(
+            all_to_all_single,
+            relevant_args,
+            output,
+            input,
+            output_split_sizes=output_split_sizes,
+            input_split_sizes=input_split_sizes,
+            group=group,
+            async_op=async_op,
+        )
+
     if _rank_not_in_group(group):
         _warn_not_in_group("all_to_all_single")
         return
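These preambles are what make the legacy collectives visible to the torch-function protocol in the first place. The minimal sketch below, not PyTorch's own handler, shows that once `has_torch_function`/`handle_torch_function` run first, any active `TorchFunctionMode` (such as non-strict export's `_NonStrictTorchFunctionHandler`) can observe or rewrite the call; `SpyCollectives` is a made-up name, and the interception here simply short-circuits the op so no process group is needed.

import torch
import torch.distributed as dist
from torch.overrides import TorchFunctionMode


class SpyCollectives(TorchFunctionMode):
    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if dist.is_available() and func is dist.all_gather_into_tensor:
            # Non-strict export remaps to a functional collective here;
            # this sketch just reports the interception and skips the op.
            print("intercepted", func.__name__)
            return None  # short-circuit: no process group is needed for this demo
        return func(*args, **kwargs)


out, inp = torch.empty(4), torch.randn(2)
with SpyCollectives():
    dist.all_gather_into_tensor(out, inp)  # now dispatches to __torch_function__ first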

0 commit comments
