[DTensor][random] defer DTensor RNG state sync until first random op … · pytorch/pytorch@13eadba · GitHub

Commit 13eadba

XilunWu authored and aditew01 committed
[DTensor][random] defer DTensor RNG state sync until first random op call or manual_seed call; support more flexible OffsetBasedRNGTracker init (#147025)
Resolves #146767. May also resolve #147584.

### Summary
This PR removes the RNG tracker init from the `distribute_tensor` call for the following reasons:
1. If the user does not use random ops on DTensor, there is no need to init DTensor RNG, which currently requires a CUDA device to be present.
2. This complies with the 0-communication semantic of the `src_data_rank=None` shard distribution.

In addition, `OffsetBasedRNGTracker` now only accepts a `DeviceMesh` argument to its constructor.

### Consequence
DTensor RNG initialization is deferred until the first DTensor random op call or `torch.distributed.tensor.random.manual_seed`.

### Test
`pytest test/distributed/tensor/test_random_ops.py`
`pytest test/distributed/tensor/parallel/test_tp_random_state.py`
`pytest test/distributed/tensor/parallel/test_tp_style.py`

Differential Revision: [D70201856](https://our.internmc.facebook.com/intern/diff/D70201856)

Pull Request resolved: #147025
Approved by: https://github.com/kwen2501
1 parent a119ca3 commit 13eadba
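A minimal sketch of the deferred initialization this commit introduces, mirroring the updated `test_rng_tracker_init` test below. The mesh shape, tensor size, and the `dtensor_random` import alias are illustrative assumptions, and a CUDA-backed process group is assumed to be initialized already:

```python
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, Shard
from torch.distributed.tensor import _random as dtensor_random  # same private module the tests use

# Assumption: torch.distributed is already initialized with a CUDA-capable backend.
mesh = init_device_mesh("cuda", (dist.get_world_size(),))

# distribute_tensor alone no longer constructs the RNG tracker ...
dt = distribute_tensor(torch.empty(1024, device="cuda"), mesh, [Shard(0)])
assert dtensor_random._rng_tracker is None  # assuming no DTensor random op ran earlier in this process

# ... the tracker is created (and rank 0's RNG state broadcast) only on the first
# DTensor random op, or when torch.distributed.tensor.random.manual_seed is called.
dt.uniform_(0, 1)
assert dtensor_random._rng_tracker is not None
```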

File tree

6 files changed: +75 -48 lines changed

test/distributed/tensor/parallel/test_tp_random_state.py

Lines changed: 4 additions & 2 deletions

@@ -49,7 +49,7 @@ def test_model_init(self):
         self.assertEqual(dp_rank, self.rank // tp_size)
         self.assertEqual(tp_rank, self.rank % tp_size)

-        for enable_distribute_flag in [False, True]:
+        for enable_distribute_flag in [True, False]:
             # a local model on meta device
             model = MLPModule(device="meta")
             # the col-wise parallel style shards the weight over tensor dim 0
@@ -68,7 +68,9 @@ def test_model_init(self):
             torch.cuda.manual_seed(dp_rank)

             # disable/enable parallel RNG feature
-            random._rng_tracker.distribute_region_enabled = enable_distribute_flag
+            if random._rng_tracker:
+                random._rng_tracker.distribute_region_enabled = enable_distribute_flag
+
             self.assertTrue(model_tp.net1.weight.is_meta)
             # initialize the model's local shard
             model_tp.to_empty(device=self.device_type)

test/distributed/tensor/parallel/test_tp_style.py

Lines changed: 2 additions & 0 deletions

@@ -346,6 +346,8 @@ def test_prepare_module_output(self):
     @with_comms
     def test_sequence_parallel_style(self):
         mesh = init_device_mesh(self.device_type, (self.world_size,))
+        # early init RNG tracker
+        torch.distributed.tensor._random.manual_seed(0, mesh)

         comm_mode = CommDebugMode()
         batch, N, embedding_dim = 20, 8, 12

test/distributed/tensor/test_random_ops.py

Lines changed: 29 additions & 17 deletions

@@ -100,34 +100,41 @@ def test_meta_tensor_init(self):
         meta_dtensor = distribute_tensor(
             torch.empty(*size, device="meta"), device_mesh, [Replicate()]
         )
+
+        # the tensor slice on the current rank
+        self_slice = slice(1024 * self.rank, 1024 * self.rank + 1024)
+
+        # Test 1: enable the distribute region for RNG (by default)
         self.assertTrue(meta_dtensor.is_meta)
+        # Tensor meta init
         dtensor = torch.empty_like(meta_dtensor, device=self.device_type)
-
-        # disable the distribute region for RNG
-        random._rng_tracker.distribute_region_enabled = False
         dtensor.uniform_()
+        # check `distribute_region_enabled` is set to True by default
+        self.assertTrue(random._rng_tracker.distribute_region_enabled)

         # allgather the local tensors
         local_tensor = funcol.all_gather_tensor(
             dtensor.to_local(), gather_dim=0, group=(device_mesh, 0)
         )

         # compare with local tensors from other ranks
-        self_slice = slice(1024 * self.rank, 1024 * self.rank + 1024)
         for other_rank in range(self.world_size):
-            # the RNG result on each rank differs even they're supposed
-            # to be replicated
+            # the RNG result on each rank are the same because they're replicated
             if self.rank != other_rank:
+                # other rank should have an identical local tensor
                 other_slice = slice(1024 * other_rank, 1024 * other_rank + 1024)
-                self.assertNotEqual(
+                self.assertEqual(
                     local_tensor[self_slice, :], local_tensor[other_slice, :]
                 )

-        # enable the distribute region for RNG
-        random._rng_tracker.distribute_region_enabled = True
+        # Test 2: disable the distribute region for RNG
         self.assertTrue(meta_dtensor.is_meta)
+        # Tensor meta init
         dtensor = torch.empty_like(meta_dtensor, device=self.device_type)
+        random._rng_tracker.distribute_region_enabled = False
         dtensor.uniform_()
+        # check `distribute_region_enabled` is set to False
+        self.assertTrue(not random._rng_tracker.distribute_region_enabled)

         # allgather the local tensors
         local_tensor = funcol.all_gather_tensor(
@@ -136,11 +143,11 @@ def test_meta_tensor_init(self):

         # compare with local tensors from other ranks
         for other_rank in range(self.world_size):
-            # the RNG result on each rank are the same because they're replicated
+            # the RNG result on each rank differs even they're supposed
+            # to be replicated
             if self.rank != other_rank:
-                # other rank should have an identical local tensor
                 other_slice = slice(1024 * other_rank, 1024 * other_rank + 1024)
-                self.assertEqual(
+                self.assertNotEqual(
                     local_tensor[self_slice, :], local_tensor[other_slice, :]
                 )

@@ -251,10 +258,15 @@ def test_rng_tracker_init(self):
         seed_from_rank_0 = int(object_list[0])

         device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
-        # seed synchronization happens after the first `distribute_tensor` call
-        distribute_tensor(
+        # seed synchronization now does NOT happen after the first `distribute_tensor`
+        # call
+        dt = distribute_tensor(
             torch.empty([self.world_size], device=TYPE_DEVICE), device_mesh, [Shard(0)]
         )
+        self.assertTrue(random._rng_tracker is None)
+        # seed synchronization only happens after `manual_seed` or the first DTensor
+        # random op call
+        dt.uniform_(0, 1)
         self.assertEqual(seed_from_rank_0, random._rng_tracker.get_seed("parallel-rng"))

     @with_comms
@@ -459,6 +471,9 @@ def test_deterministic_uniform_2d(self):
         for placements, shard_index in zip(placements_list, shard_index_list):
             dtensor = dtensor.redistribute(device_mesh, placements)

+            # random op call
+            dtensor.uniform_(0, 1)
+
             # check shard information is correct
             shard_coord = [
                 coordinate[mesh_dim] if mesh_dim >= 0 else 0
@@ -503,9 +518,6 @@ def test_deterministic_uniform_2d(self):

             local_shard_comb = itertools.product(*local_shard_list_on_dim)

-            # random op call
-            dtensor.uniform_(0, 1)
-
             # the local shard
             local_tensor = dtensor.to_local()
             # allgather the local tensors

torch/distributed/tensor/_api.py

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,6 @@
1414
from torch.distributed.device_mesh import _mesh_resources, DeviceMesh
1515
from torch.distributed.tensor._collective_utils import check_tensor_meta, mesh_broadcast
1616
from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
17-
from torch.distributed.tensor._random import (
18-
is_rng_supported_mesh,
19-
OffsetBasedRNGTracker,
20-
)
2117
from torch.distributed.tensor._redistribute import (
2218
Redistribute,
2319
redistribute_local_tensor,
@@ -705,13 +701,6 @@ def distribute_tensor(
705701
msg = "To use DTensor API with xla, you must install the torch_xla package!"
706702
raise ImportError(msg) from e
707703

708-
# instantiate a RNG tracker if haven't. By default DTensor uses an
709-
# OffsetBasedRNGTracker to perform random operators.
710-
# TODO: the value assignment to global variable is not the ideal solution
711-
# we can replace it in future.
712-
if not random._rng_tracker and is_rng_supported_mesh(device_mesh):
713-
random._rng_tracker = OffsetBasedRNGTracker(device_type)
714-
715704
if not tensor.is_leaf:
716705
raise RuntimeError(
717706
"`distribute_tensor` should be used to distribute leaf tensors! but found non-leaf tensor!"
@@ -1025,7 +1014,7 @@ def _dtensor_init_helper( # type: ignore[no-untyped-def]
10251014
spec = DTensorSpec(device_mesh, tuple(placements), tensor_meta=tensor_meta)
10261015

10271016
if random.is_rng_supported_mesh(device_mesh) and not random._rng_tracker:
1028-
random._rng_tracker = random.OffsetBasedRNGTracker()
1017+
random._rng_tracker = random.OffsetBasedRNGTracker(device_mesh)
10291018

10301019
assert random._rng_tracker is not None
10311020
with random._rng_tracker._distribute_region(spec):

torch/distributed/tensor/_dispatch.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ def dispatch(
198198
if not random._rng_tracker and is_rng_supported_mesh(mesh):
199199
# Default to `OffsetBasedRNGTracker` if the parallelism API
200200
# did not already construct one
201-
random._rng_tracker = random.OffsetBasedRNGTracker(mesh.device_type)
201+
random._rng_tracker = random.OffsetBasedRNGTracker(mesh)
202202

203203
first_arg, first_local_arg = cast(dtensor.DTensor, args[0]), cast(
204204
torch.Tensor, local_tensor_args[0]

torch/distributed/tensor/_random.py

Lines changed: 38 additions & 16 deletions

@@ -68,19 +68,18 @@ def manual_seed(seed: int, device_mesh: DeviceMesh) -> None:
     ``manual_seed`` will throw an error.
     Current implementation only supports a GPU device mesh.
     """
-    device_handle = _get_device_handle(device_mesh.device_type)
-    if not device_handle:
-        raise NotImplementedError(
-            f"DTensor randomness only supports cuda/cuda-like device type, but got {device_mesh.device_type}"
+    if not is_rng_supported_mesh(device_mesh):
+        warnings.warn(
+            "DTensor manual_seed() may not have complete support "
+            f"on {device_mesh.device_type} device mesh"
         )
+        return

     # instantiate a RNG tracker if haven't. By default DTensor uses an
     # OffsetBasedRNGTracker to perform random operators.
     global _rng_tracker
     if not _rng_tracker:
-        _rng_tracker = OffsetBasedRNGTracker(
-            device_mesh.device_type, run_state_sync=False
-        )
+        _rng_tracker = OffsetBasedRNGTracker(device_mesh, run_state_sync=False)

     # the current rank is in mesh
     if device_mesh.get_coordinate() is not None:
@@ -102,16 +101,16 @@ class _RNGStateTracker:
     a random op (an operator that calls RNG).
     """

-    def __init__(self, device_type: str = "cuda"):
-        self._device_type = device_type
-        self._device_handle = _get_device_handle(device_type)
+    def __init__(self, device: torch.device):
+        self._device = device
+        self._device_handle = _get_device_handle(self._device.type)
         if not (self._device_handle and self._device_handle.is_available()):
             raise RuntimeError(
-                f"{self.__class__.__name__} instantiation requires the presence of CUDA/CUDA-like device"
+                f"{self.__class__.__name__} instantiation requires the presence of "
+                f"{device.type} device but couldn't find."
             )

         self._states: dict[str, Tensor] = {}
-        self._devices = [self._device_handle.current_device()]
         self._use_distribute_region = True

     @property
@@ -159,11 +158,25 @@ class OffsetBasedRNGTracker(_RNGStateTracker):
     This subclass of ``_RNGStateTracker`` defines the default policy of how RNG states
     should be shared and synchronized among all ranks to respect the semantics of DTensor
     random operators.
+
+    note: _RNGStateTracker only supports cuda/cuda-like device
     """

-    def __init__(self, device_type: str = "cuda", run_state_sync: bool = True):
-        super().__init__(device_type)
-        rng_state = self._device_handle.get_rng_state().to(device_type)
+    def __init__(
+        self,
+        device_mesh: DeviceMesh,
+        run_state_sync: bool = True,
+    ):
+        super().__init__(_resolve_device(device_mesh=device_mesh))
+        assert self._device_handle is not None
+        # DTensor RNG tracker so far only supports CUDA/CUDA-like devices
+        if self._device.type != "cuda":
+            raise RuntimeError(
+                f"{self.__class__.__name__} instantiation requires the presence of "
+                f"CUDA/CUDA-like device. Got {self._device.type} instead."
+            )
+
+        rng_state = self._device_handle.get_rng_state().to(self._device)
         if run_state_sync:
             # synchronize RNG state using rank 0's current one
             dist.broadcast(rng_state, 0)
@@ -185,7 +198,8 @@ def _distribute_region(self, spec: DTensorSpec):
         if self.distribute_region_enabled:
             old_offset = self.get_offset("parallel-rng")
             self._set_pre_op_offset(spec)
-            with torch.random.fork_rng(self._devices, device_type=self._device_type):
+            with torch.random.fork_rng(devices=[self._device]):
+                assert self._device_handle is not None
                 self._device_handle.set_rng_state(self.rng_states["parallel-rng"])
                 try:
                     yield  # execute the region code
@@ -366,3 +380,11 @@ def _calc_shard_linear_idx(
         shard_coord_stride *= size

     return shard_linear_idx
+
+
+def _resolve_device(device_mesh: DeviceMesh) -> torch.device:
+    device_type = device_mesh.device_type
+    device_handle = _get_device_handle(device_type)
+    assert device_handle is not None
+    device_idx = device_mesh.get_rank() % device_handle.device_count()
+    return torch.device(f"{device_type}:{device_idx:d}")
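Putting the `_random.py` changes together: a user who wants the tracker up front, for example to seed it deterministically before any random op as the updated `test_sequence_parallel_style` does, can call `manual_seed` with the mesh. The snippet below is a sketch under the assumption that a CUDA device mesh is available; the mesh shape and tensor sizes are illustrative:

```python
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, Shard
import torch.distributed.tensor.random as dtensor_random  # public path named in the commit message

# Assumption: a CUDA-backed process group is already initialized.
mesh = init_device_mesh("cuda", (dist.get_world_size(),))

# Explicitly seeding constructs an OffsetBasedRNGTracker from the mesh
# (run_state_sync=False, since all ranks agree on the seed) instead of
# waiting for the first DTensor random op.
dtensor_random.manual_seed(1234, mesh)

dt = distribute_tensor(torch.empty(8, 8, device="cuda"), mesh, [Shard(0)])
dt.uniform_(0, 1)  # uses the already-initialized parallel RNG
```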
