[pytree][1/N] change pytree usages to implementation agnostic: torch.distributed · pytorch/pytorch@1f8026c · GitHub

Commit 1f8026c

[pytree][1/N] change pytree usages to implementation agnostic: torch.distributed
ghstack-source-id: 34a09e0 Pull Request resolved: #144332
1 parent d774a47 commit 1f8026c

31 files changed: +71 -96 lines
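For orientation, a minimal sketch of the migration pattern this commit applies throughout torch.distributed, assuming the implementation-agnostic torch.utils.pytree namespace introduced by this PR stack (older builds expose equivalent helpers under torch.utils._pytree):

# Before: code imported a specific implementation
#     import torch.utils.pytree.python as pytree
#     flat = pytree.tree_leaves(obj)

# After: import only the helpers that are needed from the dispatching namespace
from torch.utils.pytree import tree_leaves, tree_map

nested = {"a": 1, "b": (2, 3)}
print(tree_leaves(nested))                 # [1, 2, 3]
print(tree_map(lambda x: x * 10, nested))  # {'a': 10, 'b': (20, 30)}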

benchmarks/dynamo/distributed.py

Lines changed: 2 additions & 2 deletions
@@ -5,10 +5,10 @@
 
 import torch
 import torch._dynamo as dynamo
-import torch.utils.pytree.python as pytree
 from torch._dynamo.testing import reduce_to_scalar_loss
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.profiler import profile, ProfilerActivity, record_function
+from torch.utils.pytree import tree_map
 
 
 try:
@@ -62,7 +62,7 @@ def move_tensor(maybe_tensor):
             return maybe_tensor.to(dev_rank)
         return maybe_tensor
 
-    inputs = pytree.tree_map(move_tensor, inputs)
+    inputs = tree_map(move_tensor, inputs)
 
     if args.fsdp:
         model = apply_fsdp(
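An illustrative sketch of the device-move idiom above with the new top-level import; the device argument here is a stand-in for the benchmark's dev_rank:

import torch
from torch.utils.pytree import tree_map

def move_tensor(maybe_tensor, device="cpu"):
    if isinstance(maybe_tensor, torch.Tensor):
        return maybe_tensor.to(device)
    return maybe_tensor

inputs = {"x": torch.randn(2, 2), "lengths": [3, 5], "pair": (torch.zeros(3), "tag")}
inputs = tree_map(move_tensor, inputs)  # tensors are moved; other leaves pass through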

test/distributed/_composable/fsdp/test_fully_shard_extensions.py

Lines changed: 4 additions & 8 deletions
@@ -11,7 +11,6 @@
 import torch
 import torch.distributed as dist
 import torch.nn as nn
-import torch.utils.pytree.python as pytree
 from torch.autograd.grad_mode import _unsafe_preserve_version_counter
 from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
 from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy
@@ -25,6 +24,7 @@
 )
 from torch.testing._internal.common_utils import run_tests
 from torch.testing._internal.two_tensor import TwoTensor
+from torch.utils.pytree import tree_map_only
 
 
 def two_tensor_fsdp_pre_all_gather_v1(
@@ -144,13 +144,9 @@ def unwrap(x: cls):
             assert pad_in_pre_all_gather == x._pad_in_pre_all_gather
             return x._data
 
-        out = func(
-            *pytree.tree_map_only(cls, unwrap, args),
-            **pytree.tree_map_only(cls, unwrap, kwargs),
-        )
-        return pytree.tree_map_only(
-            torch.Tensor, lambda x: cls(x, pad_in_pre_all_gather), out
-        )
+        args, kwargs = tree_map_only(cls, unwrap, (args, kwargs))
+        out = func(*args, **kwargs)
+        return tree_map_only(torch.Tensor, lambda x: cls(x, pad_in_pre_all_gather), out)
 
     def __tensor_flatten__(self):
         return ["_data"], None

test/distributed/checkpoint/fsdp/test_fsdp_dsd.py

Lines changed: 1 addition & 1 deletion
@@ -28,7 +28,7 @@
 from torch.testing._internal.common_fsdp import FSDPTest, MLP
 from torch.testing._internal.common_utils import run_tests
 from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir
-from torch.utils.pytree.python import tree_all_only
+from torch.utils.pytree import tree_all_only
 
 
 class TestFullyShardWithDistributedStateDict(FSDPTest):

test/distributed/checkpoint/test_state_dict.py

Lines changed: 1 addition & 1 deletion
@@ -48,7 +48,7 @@
     with_comms,
 )
 from torch.testing._internal.distributed.common_state_dict import VerifyStateDictMixin
-from torch.utils.pytree.python import tree_all, tree_all_only
+from torch.utils.pytree import tree_all, tree_all_only
 
 
 if not dist.is_available():
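A hedged example of the predicate helpers the two checkpoint tests now import from the top-level namespace; the state dict below is made up for illustration:

import torch
from torch.utils.pytree import tree_all, tree_all_only

sd = {"weight": torch.randn(4, 4), "bias": torch.zeros(4), "step": 10}
on_cpu = tree_all_only(torch.Tensor, lambda t: t.device.type == "cpu", sd)  # True
non_null = tree_all(lambda leaf: leaf is not None, sd)                      # True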

test/distributed/pipelining/test_stage.py

Lines changed: 1 addition & 1 deletion
@@ -25,7 +25,7 @@
     parametrize,
     skip_but_pass_in_sandcastle_if,
 )
-from torch.utils.pytree.python import tree_map_only
+from torch.utils.pytree import tree_map_only
 
 
 d_hid = 512

test/distributed/tensor/test_dtensor_ops.py

Lines changed: 4 additions & 5 deletions
@@ -7,7 +7,6 @@
 import torch
 import torch.distributed as dist
 import torch.testing._internal.common_methods_invocations as common_ops
-import torch.utils.pytree.python as pytree
 from torch.distributed._tensor import DeviceMesh, DTensor
 from torch.overrides import resolve_name
 from torch.testing._internal.common_device_type import (
@@ -20,7 +19,7 @@
     DTensorConverter,
     DTensorOpTestBase,
 )
-from torch.utils.pytree.python import tree_map
+from torch.utils.pytree import tree_leaves, tree_map
 
 
 # rewrite common size variables to sth can be sharded evenly
@@ -535,8 +534,8 @@ def test():
         self.check_dtensor_func(test, op)
 
     def assert_ref_dtensor_equal(self, dtensor_rs, rs):
-        flat_dtensor_rs = pytree.tree_leaves(dtensor_rs)
-        flat_rs = pytree.tree_leaves(rs)
+        flat_dtensor_rs = tree_leaves(dtensor_rs)
+        flat_rs = tree_leaves(rs)
         self.assertEqual(len(flat_dtensor_rs), len(flat_rs))
         for dtensor_r, r in zip(flat_dtensor_rs, flat_rs):
             if not isinstance(r, torch.Tensor):
@@ -600,7 +599,7 @@ def to_replicate(e: object) -> object:
         # we need to skip tests containing tensors of zero elements for now.
        # see issue: https://github.com/pytorch/tau/issues/470
         # TODO remove this once issue above fixed.
-        flat_args = pytree.tree_leaves(dtensor_rs)
+        flat_args = tree_leaves(dtensor_rs)
         if any(
             isinstance(e, torch.Tensor) and e.numel() == 0
             for e in flat_args
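A sketch of the flatten-and-compare pattern used by assert_ref_dtensor_equal above, with plain tensors standing in for DTensors:

import torch
from torch.utils.pytree import tree_leaves

reference = (torch.ones(2), {"aux": torch.zeros(3)})
result = (torch.ones(2), {"aux": torch.zeros(3)})
flat_ref, flat_res = tree_leaves(reference), tree_leaves(result)
assert len(flat_ref) == len(flat_res)
for r, o in zip(flat_ref, flat_res):
    torch.testing.assert_close(o, r)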

test/distributed/tensor/test_pointwise_ops.py

Lines changed: 3 additions & 3 deletions
@@ -6,7 +6,6 @@
 from unittest import skip
 
 import torch
-import torch.utils.pytree.python as pytree
 from torch import Tensor
 from torch.distributed._tensor import DeviceMesh, distribute_tensor, DTensor
 from torch.distributed._tensor.placement_types import (
@@ -20,6 +19,7 @@
     DTensorOpTestBase,
     skip_unless_torch_gpu,
 )
+from torch.utils.pytree import tree_map
 
 
 def no_op():
@@ -48,7 +48,7 @@ def f(x):
         )
         return x
 
-    return pytree.tree_map(f, [val])[0]
+    return tree_map(f, val)
 
 
 def deepcopy_convert_from_dtensor(val: Any) -> Any:
@@ -64,7 +64,7 @@ def f(x):
             return x.full_tensor()
         return x
 
-    return pytree.tree_map(f, [val])[0]
+    return tree_map(f, val)
 
 
 class DistElementwiseOpsTest(DTensorOpTestBase):
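Why dropping the list wrapper is safe: tree_map already recurses into an arbitrarily nested val, so wrapping it in a one-element list and indexing the result back out was redundant. A small sketch with made-up values:

import torch
from torch.utils.pytree import tree_map

def f(x):
    return x + 1 if isinstance(x, torch.Tensor) else x

val = {"t": torch.zeros(2), "name": "relu"}
old_style = tree_map(f, [val])[0]  # previous pattern: wrap, map, unwrap
new_style = tree_map(f, val)       # equivalent direct call
torch.testing.assert_close(old_style["t"], new_style["t"])
assert old_style["name"] == new_style["name"]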

test/distributed/tensor/test_view_ops.py

Lines changed: 2 additions & 2 deletions
@@ -6,7 +6,6 @@
 
 import torch
 import torch.distributed as dist
-import torch.utils.pytree.python as pytree
 from torch import rand, randn, Tensor
 from torch.distributed._tensor import (
     DeviceMesh,
@@ -32,6 +31,7 @@
     DTensorTestBase,
     with_comms,
 )
+from torch.utils.pytree import tree_leaves
 
 
 class TestViewOps(DTensorTestBase):
@@ -139,7 +139,7 @@ def call_dt_test(self, op, args, kwargs, device_mesh: DeviceMesh):
         dim_map = dim_maps[op]
         rules = dim_map(*args, **kwargs)
         outputs = op(*args, **kwargs)
-        flat_args = pytree.arg_tree_leaves(*args)
+        flat_args = tree_leaves(args)
         in_shape = flat_args[0].shape
 
         no_shard_dims = set()
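Why the swap from arg_tree_leaves(*args) to tree_leaves(args) is equivalent at this call site: the positional-args tuple is itself a pytree and no kwargs are involved, so flattening the tuple yields the same leaves. Example values below are illustrative:

import torch
from torch.utils.pytree import tree_leaves

args = (torch.randn(4, 8), (0, 1), -1)
flat_args = tree_leaves(args)
in_shape = flat_args[0].shape  # the first leaf is the input tensor, as in the test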

torch/distributed/_functional_collectives.py

Lines changed: 1 addition & 6 deletions
@@ -9,16 +9,11 @@
 import torch.distributed.distributed_c10d as c10d
 from torch.distributed.device_mesh import DeviceMesh
 from torch.fx.experimental.proxy_tensor import get_proxy_mode
+from torch.utils.pytree import tree_map_only
 
 from . import _functional_collectives_impl as fun_col_impl
 
 
-try:
-    from torch.utils._cxx_pytree import tree_map_only
-except ImportError:
-    from torch.utils.pytree.python import tree_map_only  # type: ignore[no-redef]
-
-
 if torch._running_with_deploy():
 
     def is_torchdynamo_compiling():
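A short sketch of what the deleted block did and what replaces it: the try/except previously preferred the C++ pytree and fell back to the Python one, while the torch.utils.pytree namespace is assumed to make that implementation choice internally, so callers import the helper once:

# Removed fallback:
#     try:
#         from torch.utils._cxx_pytree import tree_map_only
#     except ImportError:
#         from torch.utils.pytree.python import tree_map_only  # type: ignore[no-redef]

# Replacement, a single import of the dispatching helper:
from torch.utils.pytree import tree_map_only

converted = tree_map_only(int, float, {"a": 1, "b": "x"})  # {'a': 1.0, 'b': 'x'}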

torch/distributed/_shard/common_op_utils.py

Lines changed: 5 additions & 5 deletions
@@ -2,7 +2,7 @@
 from typing import Optional
 
 import torch
-import torch.utils.pytree.python as pytree
+from torch.utils.pytree import tree_map_
 
 
 def _basic_validation(op, args=(), kwargs=None):
@@ -22,8 +22,8 @@ def is_distributed_tensor(e):
         if isinstance(e, ShardedTensor):
             has_distributed_tensor = True
 
-    pytree.tree_map_(is_distributed_tensor, args)
-    pytree.tree_map_(is_distributed_tensor, kwargs)
+    tree_map_(is_distributed_tensor, args)
+    tree_map_(is_distributed_tensor, kwargs)
 
     if not has_distributed_tensor:
         raise TypeError(
@@ -44,8 +44,8 @@ def validate_pg(e):
             )
         cur_pg = e._process_group
 
-    pytree.tree_map_(validate_pg, args)
-    pytree.tree_map_(validate_pg, kwargs)
+    tree_map_(validate_pg, args)
+    tree_map_(validate_pg, kwargs)
 
 
 def _register_default_op(op, decorator):
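A sketch of the side-effect traversal idiom used above: tree_map_ visits every leaf purely for its side effects and returns the input tree unchanged. The helper below loosely mirrors the flag-setting pattern and is not the module's real code:

import torch
from torch.utils.pytree import tree_map_

def has_any_tensor(args, kwargs) -> bool:
    found = False

    def visit(leaf):
        nonlocal found
        if isinstance(leaf, torch.Tensor):
            found = True

    tree_map_(visit, args)    # traversal only; no new tree is built
    tree_map_(visit, kwargs)
    return found

assert has_any_tensor((torch.ones(1), {"k": 3}), {"flag": True})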

0 commit comments
