[PT-D][Checkpoint] Update import and update docstring for distributed checkpoint by wz337 · Pull Request #89256 · pytorch/pytorch

Closed · wants to merge 1 commit
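
This PR updates test imports and docstrings to reference torch.distributed.checkpoint instead of the older torch.distributed._shard.checkpoint path. A minimal before/after sketch of the import change (module and symbol names are taken from the diff below; which symbols you actually import depends on your use case):

    # old import path (pre-PR)
    from torch.distributed._shard.checkpoint import FileSystemWriter, save_state_dict

    # new import path referenced by this PR
    from torch.distributed.checkpoint import FileSystemWriter, save_state_dict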
125 changes: 63 additions & 62 deletions test/distributed/checkpoint/test_checkpoint.py
@@ -2,9 +2,9 @@

import sys
from typing import Optional, List, cast
from torch.distributed._shard.checkpoint.storage import WriteResult
from torch.distributed.checkpoint.storage import WriteResult

from torch.distributed._shard.checkpoint import (
from torch.distributed.checkpoint import (
StorageReader,
StorageWriter,
CheckpointException,
@@ -63,6 +63,7 @@
)
sys.exit(0)


class TestModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
@@ -121,34 +122,44 @@ def test_default_metadata(self) -> None:
)

state_dict = {
'sharded': sharded_tensor.rand(spec, (10, 10, )),
'replicated': torch.rand(4, device=device),
'bytes': [1, 2, 3, 4],
"sharded": sharded_tensor.rand(
spec,
(
10,
10,
),
),
"replicated": torch.rand(4, device=device),
"bytes": [1, 2, 3, 4],
}

metadata = _create_default_local_metadata(state_dict)
self.assertTrue('bytes' in metadata.state_dict_metadata)
self.assertIsInstance(metadata.state_dict_metadata['bytes'], BytesStorageMetadata)
self.assertTrue("bytes" in metadata.state_dict_metadata)
self.assertIsInstance(
metadata.state_dict_metadata["bytes"], BytesStorageMetadata
)

self.assertTrue('replicated' in metadata.state_dict_metadata)
self.assertIsInstance(metadata.state_dict_metadata['replicated'], TensorStorageMetadata)
md = metadata.state_dict_metadata['replicated']
self.assertEqual(md.size, state_dict['replicated'].size())
self.assertTrue("replicated" in metadata.state_dict_metadata)
self.assertIsInstance(
metadata.state_dict_metadata["replicated"], TensorStorageMetadata
)
md = metadata.state_dict_metadata["replicated"]
self.assertEqual(md.size, state_dict["replicated"].size())
self.assertEqual(md.properties.dtype, torch.float32)
self.assertEqual(1, len(md.chunks))

self.assertTrue('sharded' in metadata.state_dict_metadata)
self.assertIsInstance(metadata.state_dict_metadata['sharded'], TensorStorageMetadata)
md = metadata.state_dict_metadata['sharded']
self.assertTrue("sharded" in metadata.state_dict_metadata)
self.assertIsInstance(
metadata.state_dict_metadata["sharded"], TensorStorageMetadata
)
md = metadata.state_dict_metadata["sharded"]
self.assertEqual(md.properties.dtype, torch.float32)
self.assertEqual(md.size, state_dict['sharded'].size())
self.assertEqual(md.size, state_dict["sharded"].size())
self.assertEqual(2, len(md.chunks))


class TestStorageBase:
def __init__(
self,
fail_conf
):
def __init__(self, fail_conf):
self.fail_conf = fail_conf
self.rank = 0 if not dist.is_initialized() else dist.get_rank()

@@ -164,16 +175,16 @@ def _fail_rank_async(self, name, result=None):
ranks = self._get_ranks(name)
fut = Future()
if ranks is not None and self.rank in ranks:
fut.set_exception(ValueError(f"async rank fail {self.rank} for {name}"))
fut.set_exception(
ValueError(f"async rank fail {self.rank} for {name}")
)
else:
fut.set_result(result)
return fut


class FaultyStorageWriter(TestStorageBase, StorageWriter):
def __init__(
self,
fail_conf
):
def __init__(self, fail_conf):
super(FaultyStorageWriter, self).__init__(fail_conf)

def init(self, is_coordinator: bool) -> None:
@@ -188,23 +199,19 @@ def prepare_global_plan(self, plans: List[SavePlan]) -> List[SavePlan]:
return plans

def write_data(
self,
plan: SavePlan,
planner: SavePlanner
self, plan: SavePlan, planner: SavePlanner
) -> Future[List[WriteResult]]:
self._fail_rank("fail_write_data")
return self._fail_rank_async("fail_write_data_async", [])

def finish(self, metadata: Metadata, results: List[List[WriteResult]]) -> None:
def finish(
self, metadata: Metadata, results: List[List[WriteResult]]
) -> None:
self._fail_rank("fail_finish")


class FaultyStorageReader(TestStorageBase, StorageReader):
def __init__(
self,
metadata,
fail_conf
):
def __init__(self, metadata, fail_conf):
super(FaultyStorageReader, self).__init__(fail_conf)
self.metadata = metadata

@@ -219,35 +226,32 @@ def prepare_global_plan(self, plans: List[LoadPlan]) -> List[LoadPlan]:
self._fail_rank("fail_prepare_global_plan")
return plans

def read_data(
self,
plan: LoadPlan,
planner: LoadPlanner
) -> Future[None]:
def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]:
self._fail_rank("fail_read_data")
return self._fail_rank_async("fail_read_data_async")

def read_metadata(self) -> Metadata:
self._fail_rank("fail_read_metadata")
return self.metadata


class TestDistributedFailure(ShardedTensorTestBase):
def get_spec(self):
return ChunkShardingSpec(
dim=0,
placements=[
f"rank:{r}/cuda:{r}" for r in range(dist.get_world_size())
]
],
)

@with_comms(init_rpc=False)
@skip_if_lt_x_gpu(2)
@requires_nccl()
def test_dummy_writer_works(self) -> None:
state_dict = {
'sharded': sharded_tensor.rand(self.get_spec(), 20, 20),
'replicated': torch.rand(10, 10),
'bytes': [1, 2, 3, 4]
"sharded": sharded_tensor.rand(self.get_spec(), 20, 20),
"replicated": torch.rand(10, 10),
"bytes": [1, 2, 3, 4],
}

save_state_dict(state_dict, FaultyStorageWriter({}))
@@ -257,9 +261,9 @@ def test_dummy_writer_works(self) -> None:
@requires_nccl()
def test_dummy_reader_works(self) -> None:
state_dict = {
'sharded': sharded_tensor.rand(self.get_spec(), 20, 20),
'replicated': torch.rand(10, 10),
'bytes': [1, 2, 3, 4]
"sharded": sharded_tensor.rand(self.get_spec(), 20, 20),
"replicated": torch.rand(10, 10),
"bytes": [1, 2, 3, 4],
}
metadata = _create_default_local_metadata(state_dict)

@@ -283,8 +287,10 @@ def _test_dist_failure(self, callback, kwargs):

failed_ranks = e.failures.keys()
for rank in bad_ranks:
self.assertTrue(rank in failed_ranks, msg=f"{rank} was supposed to fail but was fine")

self.assertTrue(
rank in failed_ranks,
msg=f"{rank} was supposed to fail was fine",
)

def _test_save(self, state_dict, coordinator=0, **kwargs):
no_dist = not dist.is_initialized()
@@ -296,6 +302,7 @@ def _save():
coordinator_rank=coordinator,
no_dist=no_dist,
)

self._test_dist_failure(_save, kwargs)

def _test_load(self, state_dict, coordinator=0, **kwargs):
@@ -317,9 +324,9 @@ def _load():
@requires_nccl()
def test_save_error_handling(self) -> None:
state_dict = {
'sharded': sharded_tensor.rand(self.get_spec(), 20, 20),
'replicated': torch.rand(10, 10),
'bytes': [1, 2, 3, 4]
"sharded": sharded_tensor.rand(self.get_spec(), 20, 20),
"replicated": torch.rand(10, 10),
"bytes": [1, 2, 3, 4],
}

self._test_save(state_dict, fail_init=[0])
@@ -334,10 +341,7 @@ def test_save_error_handling(self) -> None:
self._test_save(state_dict, coordinator=1, fail_finish=[1])

def test_save_error_handling_no_dist(self) -> None:
state_dict = {
'replicated': torch.rand(10, 10),
'bytes': [1, 2, 3, 4]
}
state_dict = {"replicated": torch.rand(10, 10), "bytes": [1, 2, 3, 4]}

self.assertFalse(dist.is_initialized())

@@ -354,9 +358,9 @@ def test_save_error_handling_no_dist(self) -> None:
@requires_nccl()
def test_load_error_handling(self) -> None:
state_dict = {
'sharded': sharded_tensor.rand(self.get_spec(), 20, 20),
'replicated': torch.rand(10, 10),
'bytes': [1, 2, 3, 4]
"sharded": sharded_tensor.rand(self.get_spec(), 20, 20),
"replicated": torch.rand(10, 10),
"bytes": [1, 2, 3, 4],
}

self._test_load(state_dict)
@@ -373,12 +377,8 @@ def test_load_error_handling(self) -> None:
self._test_load(state_dict, coordinator=3, fail_read_data_async=[2])
self._test_load(state_dict, coordinator=1, fail_prepare_global_plan=[1])


def test_load_error_handling_no_dist(self) -> None:
state_dict = {
'replicated': torch.rand(10, 10),
'bytes': [1, 2, 3, 4]
}
state_dict = {"replicated": torch.rand(10, 10), "bytes": [1, 2, 3, 4]}
self._test_load(state_dict)
self._test_load(state_dict, fail_init=[0])
self._test_load(state_dict, fail_read_metadata=[0])
@@ -387,5 +387,6 @@ def test_load_error_handling_no_dist(self) -> None:
self._test_load(state_dict, fail_read_data=[0])
self._test_load(state_dict, fail_read_data_async=[0])


if __name__ == "__main__":
run_tests()
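
Editor's note on the fault-injection pattern in the tests above: FaultyStorageWriter and FaultyStorageReader take a fail_conf dict mapping a failure-point name (for example "fail_write_data") to the ranks that should raise there, and _test_dist_failure asserts that the resulting CheckpointException reports exactly those ranks. A hedged single-process sketch of that pattern, reusing the names defined in test_checkpoint.py above (the particular failure point chosen here is illustrative):

    # assumes the imports and helper classes from test_checkpoint.py above
    state_dict = {"replicated": torch.rand(10, 10), "bytes": [1, 2, 3, 4]}
    try:
        save_state_dict(
            state_dict,
            FaultyStorageWriter({"fail_write_data": [0]}),  # rank 0 raises in write_data
            no_dist=True,
        )
    except CheckpointException as e:
        # with no process group initialized, the helper reports itself as rank 0
        assert 0 in e.failures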
2 changes: 1 addition & 1 deletion test/distributed/fsdp/test_distributed_checkpoint.py
@@ -5,7 +5,7 @@

import torch
from torch import distributed as dist
from torch.distributed._shard.checkpoint import (
from torch.distributed.checkpoint import (
FileSystemReader,
FileSystemWriter,
load_state_dict,
4 changes: 2 additions & 2 deletions torch/distributed/checkpoint/state_dict_loader.py
@@ -59,9 +59,9 @@ def load_state_dict(
>>> my_model = MyModule()
>>> optimizer = Adagrad(my_model.parameters())
>>> model_state_dict = my_model.state_dict()
>>> fs_storage_loader = torch.distributed._shard.checkpoint.FileSystemLoader("/checkpoint/1")
>>> fs_storage_loader = torch.distributed.checkpoint.FileSystemLoader("/checkpoint/1")

>>> torch.distributed._shard.checkpoint.load_state_dict(
>>> torch.distributed.checkpoint.load_state_dict(
>>> state_dict=model_state_dict,
>>> storage_reader=fs_storage_loader,
>>> )
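
For reference, a load sketch using the new namespace shown in the docstring above. This is a hedged example, not the PR's code: the model and checkpoint path are placeholders, it would run inside an initialized process group (or with no_dist=True for a single process, as the tests above do), and it uses FileSystemReader, the reader class the tests in this PR import (the docstring's FileSystemLoader name is left as-is above):

    import torch.distributed.checkpoint as dist_cp

    # my_model is a placeholder torch.nn.Module instance
    state_dict = my_model.state_dict()
    dist_cp.load_state_dict(
        state_dict=state_dict,
        storage_reader=dist_cp.FileSystemReader("/checkpoint/1"),
    )
    my_model.load_state_dict(state_dict)  # dist_cp.load_state_dict fills state_dict in place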
4 changes: 2 additions & 2 deletions torch/distributed/checkpoint/state_dict_saver.py
@@ -59,8 +59,8 @@ def save_state_dict(

>>> model_state_dict = my_model.state_dict()

>>> fs_storage_writer = torch.distributed._shard.checkpoint.FileSystemWriter("/checkpoint/1")
>>> torch.distributed._shard.checkpoint.save_state_dict(
>>> fs_storage_writer = torch.distributed.checkpoint.FileSystemWriter("/checkpoint/1")
>>> torch.distributed.checkpoint.save_state_dict(
>>> state_dict=model_state_dict,
>>> storage_writer=fs_storage_writer,
>>> )
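
And the matching save sketch under the new namespace, again hedged: the model and path are placeholders, and as above you would either run this inside an initialized process group or pass no_dist=True:

    import torch.distributed.checkpoint as dist_cp

    # my_model is a placeholder torch.nn.Module instance
    dist_cp.save_state_dict(
        state_dict=my_model.state_dict(),
        storage_writer=dist_cp.FileSystemWriter("/checkpoint/1"),
    )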