update import and update docstring · pytorch/pytorch@ad7ea01 · GitHub

Commit ad7ea01

update import and update docstring

1 parent f057a45 commit ad7ea01

4 files changed: +68 -67 lines changed
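
The substantive change is a module rename: code that previously imported from torch.distributed._shard.checkpoint now imports from torch.distributed.checkpoint. A minimal before/after sketch, using only names that appear in the diffs below:

    # before this commit
    from torch.distributed._shard.checkpoint import StorageReader, StorageWriter, CheckpointException

    # after this commit
    from torch.distributed.checkpoint import StorageReader, StorageWriter, CheckpointException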

test/distributed/checkpoint/test_checkpoint.py

Lines changed: 63 additions & 62 deletions
@@ -2,9 +2,9 @@
 
 import sys
 from typing import Optional, List, cast
-from torch.distributed._shard.checkpoint.storage import WriteResult
+from torch.distributed.checkpoint.storage import WriteResult
 
-from torch.distributed._shard.checkpoint import (
+from torch.distributed.checkpoint import (
     StorageReader,
     StorageWriter,
     CheckpointException,
@@ -63,6 +63,7 @@
     )
     sys.exit(0)
 
+
 class TestModule(torch.nn.Module):
     def __init__(self) -> None:
         super().__init__()
@@ -121,34 +122,44 @@ def test_default_metadata(self) -> None:
         )
 
         state_dict = {
-            'sharded': sharded_tensor.rand(spec, (10, 10, )),
-            'replicated': torch.rand(4, device=device),
-            'bytes': [1, 2, 3, 4],
+            "sharded": sharded_tensor.rand(
+                spec,
+                (
+                    10,
+                    10,
+                ),
+            ),
+            "replicated": torch.rand(4, device=device),
+            "bytes": [1, 2, 3, 4],
         }
 
         metadata = _create_default_local_metadata(state_dict)
-        self.assertTrue('bytes' in metadata.state_dict_metadata)
-        self.assertIsInstance(metadata.state_dict_metadata['bytes'], BytesStorageMetadata)
+        self.assertTrue("bytes" in metadata.state_dict_metadata)
+        self.assertIsInstance(
+            metadata.state_dict_metadata["bytes"], BytesStorageMetadata
+        )
 
-        self.assertTrue('replicated' in metadata.state_dict_metadata)
-        self.assertIsInstance(metadata.state_dict_metadata['replicated'], TensorStorageMetadata)
-        md = metadata.state_dict_metadata['replicated']
-        self.assertEqual(md.size, state_dict['replicated'].size())
+        self.assertTrue("replicated" in metadata.state_dict_metadata)
+        self.assertIsInstance(
+            metadata.state_dict_metadata["replicated"], TensorStorageMetadata
+        )
+        md = metadata.state_dict_metadata["replicated"]
+        self.assertEqual(md.size, state_dict["replicated"].size())
         self.assertEqual(md.properties.dtype, torch.float32)
         self.assertEqual(1, len(md.chunks))
 
-        self.assertTrue('sharded' in metadata.state_dict_metadata)
-        self.assertIsInstance(metadata.state_dict_metadata['sharded'], TensorStorageMetadata)
-        md = metadata.state_dict_metadata['sharded']
+        self.assertTrue("sharded" in metadata.state_dict_metadata)
+        self.assertIsInstance(
+            metadata.state_dict_metadata["sharded"], TensorStorageMetadata
+        )
+        md = metadata.state_dict_metadata["sharded"]
         self.assertEqual(md.properties.dtype, torch.float32)
-        self.assertEqual(md.size, state_dict['sharded'].size())
+        self.assertEqual(md.size, state_dict["sharded"].size())
         self.assertEqual(2, len(md.chunks))
 
+
 class TestStorageBase:
-    def __init__(
-        self,
-        fail_conf
-    ):
+    def __init__(self, fail_conf):
         self.fail_conf = fail_conf
         self.rank = 0 if not dist.is_initialized() else dist.get_rank()
 
@@ -164,16 +175,16 @@ def _fail_rank_async(self, name, result=None):
         ranks = self._get_ranks(name)
         fut = Future()
         if ranks is not None and self.rank in ranks:
-            fut.set_exception(ValueError(f"async rank fail {self.rank} for {name}"))
+            fut.set_exception(
+                ValueError(f"async rank fail {self.rank} for {name}")
+            )
         else:
             fut.set_result(result)
         return fut
 
+
 class FaultyStorageWriter(TestStorageBase, StorageWriter):
-    def __init__(
-        self,
-        fail_conf
-    ):
+    def __init__(self, fail_conf):
         super(FaultyStorageWriter, self).__init__(fail_conf)
 
     def init(self, is_coordinator: bool) -> None:
@@ -188,23 +199,19 @@ def prepare_global_plan(self, plans: List[SavePlan]) -> List[SavePlan]:
         return plans
 
     def write_data(
-        self,
-        plan: SavePlan,
-        planner: SavePlanner
+        self, plan: SavePlan, planner: SavePlanner
     ) -> Future[List[WriteResult]]:
         self._fail_rank("fail_write_data")
         return self._fail_rank_async("fail_write_data_async", [])
 
-    def finish(self, metadata: Metadata, results: List[List[WriteResult]]) -> None:
+    def finish(
+        self, metadata: Metadata, results: List[List[WriteResult]]
+    ) -> None:
         self._fail_rank("fail_finish")
 
 
 class FaultyStorageReader(TestStorageBase, StorageReader):
-    def __init__(
-        self,
-        metadata,
-        fail_conf
-    ):
+    def __init__(self, metadata, fail_conf):
         super(FaultyStorageReader, self).__init__(fail_conf)
         self.metadata = metadata
 
@@ -219,35 +226,32 @@ def prepare_global_plan(self, plans: List[LoadPlan]) -> List[LoadPlan]:
         self._fail_rank("fail_prepare_global_plan")
         return plans
 
-    def read_data(
-        self,
-        plan: LoadPlan,
-        planner: LoadPlanner
-    ) -> Future[None]:
+    def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]:
         self._fail_rank("fail_read_data")
         return self._fail_rank_async("fail_read_data_async")
 
     def read_metadata(self) -> Metadata:
         self._fail_rank("fail_read_metadata")
         return self.metadata
 
+
 class TestDistributedFailure(ShardedTensorTestBase):
     def get_spec(self):
         return ChunkShardingSpec(
             dim=0,
             placements=[
                 f"rank:{r}/cuda:{r}" for r in range(dist.get_world_size())
-            ]
+            ],
         )
 
     @with_comms(init_rpc=False)
     @skip_if_lt_x_gpu(2)
     @requires_nccl()
     def test_dummy_writer_works(self) -> None:
         state_dict = {
-            'sharded': sharded_tensor.rand(self.get_spec(), 20, 20),
-            'replicated': torch.rand(10, 10),
-            'bytes': [1, 2, 3, 4]
+            "sharded": sharded_tensor.rand(self.get_spec(), 20, 20),
+            "replicated": torch.rand(10, 10),
+            "bytes": [1, 2, 3, 4],
         }
 
         save_state_dict(state_dict, FaultyStorageWriter({}))
@@ -257,9 +261,9 @@ def test_dummy_writer_works(self) -> None:
     @requires_nccl()
     def test_dummy_reader_works(self) -> None:
         state_dict = {
-            'sharded': sharded_tensor.rand(self.get_spec(), 20, 20),
-            'replicated': torch.rand(10, 10),
-            'bytes': [1, 2, 3, 4]
+            "sharded": sharded_tensor.rand(self.get_spec(), 20, 20),
+            "replicated": torch.rand(10, 10),
+            "bytes": [1, 2, 3, 4],
         }
         metadata = _create_default_local_metadata(state_dict)
 
@@ -283,8 +287,10 @@ def _test_dist_failure(self, callback, kwargs):
 
             failed_ranks = e.failures.keys()
             for rank in bad_ranks:
-                self.assertTrue(rank in failed_ranks, msg=f"{rank} was supposed to fail was fine")
-
+                self.assertTrue(
+                    rank in failed_ranks,
+                    msg=f"{rank} was supposed to fail was fine",
+                )
 
     def _test_save(self, state_dict, coordinator=0, **kwargs):
         no_dist = not dist.is_initialized()
@@ -296,6 +302,7 @@ def _save():
                 coordinator_rank=coordinator,
                 no_dist=no_dist,
             )
+
         self._test_dist_failure(_save, kwargs)
 
     def _test_load(self, state_dict, coordinator=0, **kwargs):
@@ -317,9 +324,9 @@ def _load():
     @requires_nccl()
     def test_save_error_handling(self) -> None:
         state_dict = {
-            'sharded': sharded_tensor.rand(self.get_spec(), 20, 20),
-            'replicated': torch.rand(10, 10),
-            'bytes': [1, 2, 3, 4]
+            "sharded": sharded_tensor.rand(self.get_spec(), 20, 20),
+            "replicated": torch.rand(10, 10),
+            "bytes": [1, 2, 3, 4],
         }
 
         self._test_save(state_dict, fail_init=[0])
@@ -334,10 +341,7 @@ def test_save_error_handling(self) -> None:
         self._test_save(state_dict, coordinator=1, fail_finish=[1])
 
     def test_save_error_handling_no_dist(self) -> None:
-        state_dict = {
-            'replicated': torch.rand(10, 10),
-            'bytes': [1, 2, 3, 4]
-        }
+        state_dict = {"replicated": torch.rand(10, 10), "bytes": [1, 2, 3, 4]}
 
         self.assertFalse(dist.is_initialized())
 
@@ -354,9 +358,9 @@ def test_save_error_handling_no_dist(self) -> None:
     @requires_nccl()
     def test_load_error_handling(self) -> None:
         state_dict = {
-            'sharded': sharded_tensor.rand(self.get_spec(), 20, 20),
-            'replicated': torch.rand(10, 10),
-            'bytes': [1, 2, 3, 4]
+            "sharded": sharded_tensor.rand(self.get_spec(), 20, 20),
+            "replicated": torch.rand(10, 10),
+            "bytes": [1, 2, 3, 4],
         }
 
         self._test_load(state_dict)
@@ -373,12 +377,8 @@ def test_load_error_handling(self) -> None:
         self._test_load(state_dict, coordinator=3, fail_read_data_async=[2])
         self._test_load(state_dict, coordinator=1, fail_prepare_global_plan=[1])
 
-
     def test_load_error_handling_no_dist(self) -> None:
-        state_dict = {
-            'replicated': torch.rand(10, 10),
-            'bytes': [1, 2, 3, 4]
-        }
+        state_dict = {"replicated": torch.rand(10, 10), "bytes": [1, 2, 3, 4]}
         self._test_load(state_dict)
         self._test_load(state_dict, fail_init=[0])
         self._test_load(state_dict, fail_read_metadata=[0])
@@ -387,5 +387,6 @@ def test_load_error_handling_no_dist(self) -> None:
         self._test_load(state_dict, fail_read_data=[0])
         self._test_load(state_dict, fail_read_data_async=[0])
 
+
 if __name__ == "__main__":
     run_tests()
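
For orientation, here is a minimal, hedged sketch (not part of this commit) of the failure-reporting pattern these tests exercise: an error raised inside a storage writer surfaces as a CheckpointException whose failures mapping is keyed by rank, which is what _test_dist_failure inspects via e.failures.keys(). FailingWriter and the /tmp path below are hypothetical stand-ins for the FaultyStorageWriter helper defined in the test file.

    import torch
    import torch.distributed.checkpoint as dcp

    class FailingWriter(dcp.FileSystemWriter):
        # Hypothetical helper: fail inside write_data, as FaultyStorageWriter does
        # when configured with fail_write_data.
        def write_data(self, plan, planner):
            raise ValueError("simulated storage failure")

    try:
        dcp.save_state_dict(
            state_dict={"replicated": torch.rand(10, 10), "bytes": [1, 2, 3, 4]},
            storage_writer=FailingWriter("/tmp/checkpoint_demo"),
            no_dist=True,  # single process, no process group, as in the *_no_dist tests
        )
    except dcp.CheckpointException as e:
        print(sorted(e.failures.keys()))  # ranks whose storage calls failed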

test/distributed/fsdp/test_distributed_checkpoint.py

Lines changed: 1 addition & 1 deletion
@@ -5,7 +5,7 @@
 
 import torch
 from torch import distributed as dist
-from torch.distributed._shard.checkpoint import (
+from torch.distributed.checkpoint import (
     FileSystemReader,
     FileSystemWriter,
     load_state_dict,

torch/distributed/checkpoint/state_dict_loader.py

Lines changed: 2 additions & 2 deletions
@@ -59,9 +59,9 @@ def load_state_dict(
        >>> my_model = MyModule()
        >>> optimizer = Adagrad(my_model.parameters())
        >>> model_state_dict = my_model.state_dict()
-       >>> fs_storage_loader = torch.distributed._shard.checkpoint.FileSystemLoader("/checkpoint/1")
+       >>> fs_storage_loader = torch.distributed.checkpoint.FileSystemLoader("/checkpoint/1")
 
-       >>> torch.distributed._shard.checkpoint.load_state_dict(
+       >>> torch.distributed.checkpoint.load_state_dict(
        >>>     state_dict=model_state_dict,
        >>>     storage_reader=fs_storage_loader,
        >>> )
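
A hedged, self-contained variant of the updated docstring example (not part of the commit): it uses FileSystemReader and load_state_dict from the renamed torch.distributed.checkpoint package, as imported by the FSDP test above, with a toy torch.nn.Linear standing in for MyModule, an illustrative checkpoint path, and the no_dist flag seen in the test diff for running without a process group.

    import torch
    from torch.distributed.checkpoint import FileSystemReader, load_state_dict

    my_model = torch.nn.Linear(4, 4)  # stand-in for MyModule
    model_state_dict = my_model.state_dict()

    # Read a previously written checkpoint into the existing state_dict, in place.
    load_state_dict(
        state_dict=model_state_dict,
        storage_reader=FileSystemReader("/checkpoint/1"),  # illustrative path
        no_dist=True,  # single process; omit when a process group is initialized
    )
    my_model.load_state_dict(model_state_dict)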

torch/distributed/checkpoint/state_dict_saver.py

Lines changed: 2 additions & 2 deletions
@@ -59,8 +59,8 @@ def save_state_dict(
 
        >>> model_state_dict = my_model.state_dict()
 
-       >>> fs_storage_writer = torch.distributed._shard.checkpoint.FileSystemWriter("/checkpoint/1")
-       >>> torch.distributed._shard.checkpoint.save_state_dict(
+       >>> fs_storage_writer = torch.distributed.checkpoint.FileSystemWriter("/checkpoint/1")
+       >>> torch.distributed.checkpoint.save_state_dict(
        >>>     state_dict=model_state_dict,
        >>>     storage_writer=fs_stroage_writer,
        >>> )
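
And the matching hedged save-side sketch, again outside the commit itself: FileSystemWriter and save_state_dict come from the renamed package, while the model, path, and no_dist flag are illustrative assumptions as above.

    import torch
    from torch.distributed.checkpoint import FileSystemWriter, save_state_dict

    my_model = torch.nn.Linear(4, 4)  # stand-in for my_model in the docstring
    model_state_dict = my_model.state_dict()

    # Write the checkpoint to the given directory.
    save_state_dict(
        state_dict=model_state_dict,
        storage_writer=FileSystemWriter("/checkpoint/1"),  # illustrative path
        no_dist=True,  # single process; omit when a process group is initialized
    )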

0 commit comments