[torchrec][pt-d][model store] introduce LocalShardsWrapper for DTensor · pytorch/pytorch@0acd09a · GitHub

Commit 0acd09a

iamzainhuda authored and pytorchmergebot committed
[torchrec][pt-d][model store] introduce LocalShardsWrapper for DTensor (#129150)
Summary:
Same as D57688538, recreated because of GH issues.

This diff introduces LocalShardsWrapper, which is crucial to migrating the TRec state dict representation from ShardedTensor to DTensor, along with the changes needed in PT-D and ModelStore to support this. It allows us to extend DTensor to support multiple shards on a rank, as well as empty shards on a rank, as needed by TRec sharding logic. This diff also extends LocalShardsWrapper to be used in conjunction with DTensor in checkpointing cases (ModelStore and DCP). See D54375878 for how it is used.

**LocalShardsWrapper supports the following torch ops:**
+ torch.ops._c10d_functional.all_gather_into_tensor.default
+ aten._to_copy.default
+ aten.view.default
+ aten.equal.default
+ aten.detach.default

with extensibility to add more as required by use cases.

See https://docs.google.com/document/d/16Ptl50mGFJW2cljdF2HQ6FwsiA0scwbAbjx_4dhabJw/edit?usp=drivesdk for more info regarding the design and approach.

NOTE: This version of LocalShardsWrapper does not support empty shards; that is added in the next diff enabling CW (D57063512).

Test Plan:
`buck test mode/opt -c python.package_style=inplace aiplatform/modelstore/client/tests_gpu:dist_checkpoint_save_load_with_stateful_tests -- --print-passing-details`
`buck2 test 'fbcode//mode/dev-nosan' fbcode//torchrec/distributed/tests:test_tensor_configs -- --print-passing-details`
Sandcastle

Reviewed By: XilunWu, wanchaol

Differential Revision: D58570479

Pull Request resolved: #129150
Approved by: https://github.com/XilunWu
1 parent 31c9e3d commit 0acd09a
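
The core idea behind the wrapper is that a single rank can hold several column-wise shards of one logical tensor, described by per-shard offsets, with the wrapper's logical shape being the concatenation along dim 1. A minimal sketch of that layout with plain tensors, mirroring the shape math in `LocalShardsWrapper.__new__` in the diff below (variable names here are illustrative, not part of the diff):

```python
import torch

# Two column-wise shards of a logical 4x4 tensor living on one rank.
full = torch.arange(16.0).reshape(4, 4)
shard_a, shard_b = full[:, :2].contiguous(), full[:, 2:].contiguous()

# Per-shard offsets into the logical tensor, as the wrapper expects:
# shard_a starts at (row 0, col 0), shard_b at (row 0, col 2).
local_shards = [shard_a, shard_b]
local_offsets = [(0, 0), (0, 2)]

# The wrapper's logical shape is the concat along dim 1 (column-wise sharding).
cat_shape = list(local_shards[0].size())
for s in local_shards[1:]:
    cat_shape[1] += s.size()[1]
assert cat_shape == [4, 4]
```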

File tree

6 files changed: +394 additions, -60 deletions


torch/distributed/_checkpointable.py

Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
from typing import Any, Protocol, runtime_checkable

import torch


@runtime_checkable
class _Checkpointable(Protocol):  # noqa: PYI046
    """
    Interface for checkpointable objects.
    Implemented as a protocol, implicit subtyping is supported so subclasses do not need to inherit this explicitly.
    This is to allow arbitrary objects/tensor subclasses to hook into DCP seamlessly through implementing the interface.
    """

    def __create_write_items__(self, fqn: str, object: Any):
        """
        Return a list of WriteItems based on object's contents.
        """
        raise NotImplementedError(
            "_Checkpointable._create_write_items is not implemented"
        )

    def __create_chunk_list__(self):
        """
        Return a list of `ChunkStorageMetadata` based on object's contents.
        """
        raise NotImplementedError(
            "_Checkpointable._create_chunk_list is not implemented"
        )

    def __get_tensor_shard__(self, index) -> torch.Tensor:
        """
        Return a 'torch.Tensor' shard based on 'MetadataIndex'.
        """
        raise NotImplementedError(
            "_Checkpointable._get_tensor_shard is not implemented"
        )
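
Because `_Checkpointable` is a `runtime_checkable` `Protocol`, any object that defines the three dunders is implicitly checkpointable; no explicit inheritance is needed. A minimal sketch under that assumption (the `_MyShards` class and its trivial return values are hypothetical, for illustration only):

```python
from typing import Any, List

import torch

# Module added by this commit (torch/distributed/_checkpointable.py).
from torch.distributed._checkpointable import _Checkpointable


class _MyShards:
    """Hypothetical object that implicitly satisfies _Checkpointable."""

    def __init__(self, shards: List[torch.Tensor]) -> None:
        self._shards = shards

    def __create_write_items__(self, fqn: str, object: Any):
        # A real implementation would return a List[WriteItem]; elided here.
        return []

    def __create_chunk_list__(self):
        # A real implementation would return a List[ChunkStorageMetadata].
        return []

    def __get_tensor_shard__(self, index) -> torch.Tensor:
        return self._shards[0]


obj = _MyShards([torch.zeros(2, 2)])
# runtime_checkable protocols support structural isinstance checks.
assert isinstance(obj, _Checkpointable)
```
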
Lines changed: 315 additions & 0 deletions
@@ -0,0 +1,315 @@
# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, List, Tuple

import torch
from torch.distributed.checkpoint.metadata import (
    ChunkStorageMetadata,
    MetadataIndex,
    TensorProperties,
    TensorStorageMetadata,
)
from torch.distributed.checkpoint.planner import (
    TensorWriteData,
    WriteItem,
    WriteItemType,
)

aten = (
    torch.ops.aten
)  # pyre-ignore[5]: Globally accessible variable `aten` has no type specified.


class LocalShardsWrapper(torch.Tensor):  # pyre-ignore[13]: pyre is bad at __new__
    """
    A wrapper class to hold local shards of a DTensor.
    This class is used largely for checkpointing purposes and implicitly subtypes
    the _Checkpointable protocol.
    """

    __slots__ = ["_local_shards", "_storage_meta"]
    _local_shards: List[torch.Tensor]
    _storage_meta: TensorStorageMetadata

    @staticmethod
    def __new__(
        cls, local_shards: List[torch.Tensor], local_offsets: List[Tuple[int, ...]]
    ) -> "LocalShardsWrapper":
        assert len(local_shards) > 0
        assert len(local_shards) == len(local_offsets)
        assert all(
            tensor.device == local_shards[0].device for tensor in local_shards[1:]
        )

        # we calculate the total tensor size by "concat" on second tensor dimension
        cat_tensor_shape = list(local_shards[0].size())
        if len(local_shards) > 1:  # column-wise sharding
            for shard in local_shards[1:]:
                cat_tensor_shape[1] += shard.size()[1]

        wrapper_properties = TensorProperties.create_from_tensor(local_shards[0])
        wrapper_shape = torch.Size(cat_tensor_shape)
        chunks_meta = [
            ChunkStorageMetadata(
                offsets=torch.Size(offset),
                sizes=shard.size(),
            )
            for shard, offset in zip(local_shards, local_offsets)
        ]

        r = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
            cls,
            torch.Size(cat_tensor_shape),
        )
        r._local_shards = local_shards
        r._storage_meta = TensorStorageMetadata(
            properties=wrapper_properties,
            size=wrapper_shape,
            chunks=chunks_meta,
        )

        return r

    # necessary for ops dispatching from this subclass to its local shards
    @classmethod
    # pyre-fixme[3]: Return type must be annotated.
    # pyre-fixme[2]: Parameter must be annotated.
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}

        dispatcher = {
            torch.ops._c10d_functional.all_gather_into_tensor.default: cls.handle_all_gather_into_tensor,
            torch.ops._c10d_functional.wait_tensor.default: cls.handle_wait_tensor,
            aten._to_copy.default: cls.handle_to_copy,
            aten.view.default: cls.handle_view,
            aten.equal.default: cls.handle_equal,
            aten.detach.default: cls.handle_detach,
            aten.clone.default: cls.handle_clone,
        }

        if func in dispatcher:
            return dispatcher[func](
                args, kwargs
            )  # pyre-ignore [29] - `Variable[_VT]` is not a function.
        else:
            raise NotImplementedError(
                f"{func} is not supported for LocalShardsWrapper!"
            )

    @staticmethod
    # pyre-fixme[3]: Return type must be annotated.
    # pyre-fixme[2]: Parameter must be annotated.
    def handle_all_gather_into_tensor(args, kwargs):
        dim = args[0].local_sizes()[0][1]
        cat_tensor = torch.cat(
            [t.view(-1) for t in args[0].local_shards()], dim=0
        ).view(-1, dim)
        return torch.ops._c10d_functional.all_gather_into_tensor.default(
            cat_tensor, *args[1:], **kwargs
        )

    @staticmethod
    # pyre-fixme[3]: Return type must be annotated.
    # pyre-fixme[2]: Parameter must be annotated.
    def handle_wait_tensor(args, kwargs):
        return torch.ops._c10d_functional.wait_tensor(args[0])

    @staticmethod
    # pyre-fixme[3]: Return type must be annotated.
    # pyre-fixme[2]: Parameter must be annotated.
    def handle_to_copy(args, kwargs):
        res_shards_list = [
            aten._to_copy.default(shard, *args[1:], **kwargs)
            for shard in args[0].local_shards()
        ]
        return LocalShardsWrapper(res_shards_list, args[0].local_offsets())

    @staticmethod
    # pyre-fixme[3]: Return type must be annotated.
    # pyre-fixme[2]: Parameter must be annotated.
    def handle_view(args, kwargs):
        # TODO, do we need to change the shape of associated offsets?
        res_shards_list = [
            aten.view.default(shard, args[1], **kwargs)
            for shard in args[0].local_shards()
        ]
        return LocalShardsWrapper(res_shards_list, args[0].local_offsets())

    @staticmethod
    # pyre-fixme[3]: Return type must be annotated.
    # pyre-fixme[2]: Parameter must be annotated.
    def handle_equal(args, kwargs):
        """
        LocalShardsWrapper equal impl also checks for equality of storage metadata
        and the order of shards
        """
        a, b = args[0], args[1]
        if len(a.local_shards()) != len(b.local_shards()):
            return False
        if not all(
            aten.equal.default(x, y) for x, y in zip(a.local_shards(), b.local_shards())
        ):
            return False
        if not a.storage_metadata() == b.storage_metadata():
            return False
        return True

    @staticmethod
    # pyre-fixme[3]: Return type must be annotated.
    # pyre-fixme[2]: Parameter must be annotated.
    def handle_detach(args, kwargs):
        self_ls = args[0]
        detached_local_shards = [
            aten.detach.default(shard) for shard in self_ls.local_shards()
        ]
        self_ls._local_shards = detached_local_shards
        self_ls._storage_meta.properties.requires_grad = False
        return self_ls

    @staticmethod
    # pyre-fixme[3]: Return type must be annotated.
    # pyre-fixme[2]: Parameter must be annotated.
    def handle_clone(args, kwargs):
        self_ls = args[0]
        desired_memory_format = kwargs.get("memory_format", None)
        if desired_memory_format and desired_memory_format != torch.preserve_format:
            raise NotImplementedError(
                f"{desired_memory_format} is not supported for LocalShardsWrapper!"
            )
        cloned_local_shards = [
            shard.clone(memory_format=desired_memory_format)
            for shard in self_ls._local_shards
        ]
        return LocalShardsWrapper(cloned_local_shards, self_ls.local_offsets())

    @property
    def device(self) -> torch._C.device:  # type: ignore[override]
        return self._local_shards[0].device

    @property
    def is_meta(self) -> bool:  # type: ignore[override]
        return self._local_shards[0].is_meta

    # pyre-ignore[14]
    def is_pinned(self) -> bool:  # type: ignore[override]
        return self._storage_meta.properties.pin_memory

    # pyre-ignore[14]
    def requires_grad_(self, requires_grad: bool = True) -> "LocalShardsWrapper":
        self._storage_meta.properties.requires_grad = requires_grad
        [shard.requires_grad_(requires_grad) for shard in self._local_shards]
        return self

    def local_shards(self) -> List[torch.Tensor]:
        """
        Returns a list of :class:`torch.Tensor` corresponding to the
        local shards for this rank. Returns an empty list if the current rank
        does not host any shards for this Tensor.
        """
        return self._local_shards

    def local_sizes(self) -> List[torch.Size]:
        """
        Returns a list of :class:`torch.Size` corresponding to the
        local sizes for the shards on this rank. Returns an empty list if the current rank
        does not host any shards for this Tensor.
        """
        return [chunk.sizes for chunk in self._storage_meta.chunks]

    def local_offsets(self) -> List[torch.Size]:
        """
        Returns a list of :class:`torch.Size` corresponding to the
        local offsets for the shards on this rank. Returns an empty list if the current rank
        does not host any shards for this Tensor.
        """
        return [chunk.offsets for chunk in self._storage_meta.chunks]

    @property
    def local_chunks(self) -> List[ChunkStorageMetadata]:
        """
        Returns a :class:`List[ChunkStorageMetadata]` object corresponding to the
        metadata for each tensor shard
        """
        return self._storage_meta.chunks

    def storage_metadata(self) -> TensorStorageMetadata:
        """
        Returns a :class:`TensorStorageMetadata` object corresponding to the
        metadata for the local tensor on current rank
        """
        return self._storage_meta

    def __create_write_items__(
        self, fqn: str, object: Any
    ) -> List[WriteItem]:  # pyre-ignore[2]
        """
        For compatibility with DCP, we support creation of WriteItems
        such that they can be saved properly.
        """
        return [
            WriteItem(
                index=MetadataIndex(fqn, chunks.offsets),
                type=WriteItemType.SHARD,
                tensor_data=TensorWriteData(
                    chunk=ChunkStorageMetadata(
                        offsets=chunks.offsets,
                        sizes=chunks.sizes,
                    ),
                    properties=self._storage_meta.properties,
                    size=object.size(),
                ),
            )
            for tensor, chunks in zip(self.local_shards(), self.local_chunks)
        ]

    def __create_chunk_list__(self) -> List[ChunkStorageMetadata]:
        """
        For compatibility with DCP, we support creation of chunk lists
        such that they can be saved properly.
        """
        return self._storage_meta.chunks

    def __get_tensor_shard__(self, index: MetadataIndex) -> torch.Tensor:
        """
        For compatibility with DCP, we support finding shard based on index
        Return a 'torch.Tensor' shard based on 'MetadataIndex'.
        """
        # Fast lookup path
        if index.index is not None:
            if (
                len(self._local_shards) > index.index
                and self._storage_meta.chunks[index.index].offsets == index.offset
            ):
                return self._local_shards[index.index]

        if index.offset is not None:
            for shard, chunk in zip(self._local_shards, self._storage_meta.chunks):
                if chunk.offsets == index.offset:
                    return shard

        raise ValueError(
            f"Could not find shard at '{index.offset}' for FQN: '{index.fqn}'"
        )

    def _get_tensor_size_bytes(self) -> int:
        object_size = 0
        for shard in self.local_shards():
            object_size += shard.nelement() * shard.element_size()
        return object_size

    # pyre-fixme[3]: Return type must be annotated.
    def __hash__(self):
        return id(self)

    # pyre-fixme[14]: `__repr__` overrides method defined in `torch._tensor.Tensor` inconsistently.
    # pyre-fixme[3]: Return type must be annotated.
    def __repr__(self):
        return f"LocalShardsWrapper:{self._local_shards} {self._storage_meta}"

    def __str__(self) -> str:
        return f"LocalShardsWrapper:{self._local_shards} {self._storage_meta}"

torch/distributed/_tensor/api.py

Lines changed: 32 additions & 0 deletions
@@ -532,6 +532,38 @@ def placements(self) -> Sequence[Placement]:
         """
         return self._spec.placements
 
+    def __create_write_items__(self, fqn: str, object: Any):
+        from torch.distributed.checkpoint.planner_helpers import (
+            _create_write_items_for_dtensor,
+        )
+
+        if hasattr(self._local_tensor, "__create_write_items__"):
+            return self._local_tensor.__create_write_items__(fqn, object)  # type: ignore[attr-defined]
+        elif isinstance(self._local_tensor, torch.Tensor):
+            return [_create_write_items_for_dtensor(fqn, object)]
+        else:
+            raise RuntimeError("Unsupported tensor type!")
+
+    def __create_chunk_list__(self):
+        from torch.distributed.checkpoint.planner_helpers import (
+            _create_chunk_from_dtensor,
+        )
+
+        if hasattr(self._local_tensor, "__create_chunk_list__"):
+            return self._local_tensor.__create_chunk_list__()  # type: ignore[attr-defined]
+        elif isinstance(self._local_tensor, torch.Tensor):
+            return [_create_chunk_from_dtensor(self)]
+        else:
+            raise RuntimeError("Unsupported tensor type!")
+
+    def __get_tensor_shard__(self, index):
+        if hasattr(self._local_tensor, "__get_tensor_shard__"):
+            return self._local_tensor.__get_tensor_shard__(index)  # type: ignore[attr-defined]
+        elif isinstance(self._local_tensor, torch.Tensor):
+            return self.to_local()
+        else:
+            raise RuntimeError("Unsupported tensor type!")
+
 
 def distribute_tensor(
     tensor: torch.Tensor,
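All three DTensor methods added above follow the same duck-typed delegation: use the local tensor's `_Checkpointable` hook when it exists, otherwise fall back to the plain-tensor planner helpers. A minimal, self-contained sketch of that pattern with hypothetical stand-in classes (no process group or DTensor instance required):

```python
from typing import Any, List


class _PlainLocal:
    """Stand-in for a regular local tensor with no checkpoint hooks."""


class _WrappedLocal:
    """Stand-in for a local tensor that implements the _Checkpointable hooks."""

    def __create_chunk_list__(self) -> List[str]:
        return ["chunk-metadata-from-wrapper"]


def create_chunk_list(local_tensor: Any) -> List[str]:
    # Mirrors DTensor.__create_chunk_list__: delegate if the hook exists,
    # otherwise fall back to the generic single-chunk path.
    if hasattr(local_tensor, "__create_chunk_list__"):
        return local_tensor.__create_chunk_list__()
    return ["single-chunk-from-planner-helper"]


print(create_chunk_list(_PlainLocal()))    # fallback path
print(create_chunk_list(_WrappedLocal()))  # delegated to the wrapper
```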