Implementation of padding in dynamic sharding · pytorch/torchrec@5b1dbe7 · GitHub

Commit 5b1dbe7

Implementation of padding in dynamic sharding

Differential Revision: D74150894

1 parent 3e2737e · commit 5b1dbe7
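
The change zero-pads every shard on the right and bottom to the largest dim 0 and dim 1 sizes found across the module's tables, so that shards of different shapes can be concatenated along dim 0 and exchanged with uniform all_to_all splits; the receiving rank then strips the padding using the original shard sizes it was sent alongside the tensors. A minimal sketch of that round trip, using a hypothetical shard and toy maxima rather than code from this commit:

import torch
import torch.nn.functional as F

# Hypothetical shard and maxima, for illustration only.
shard = torch.arange(6.0).reshape(3, 2)  # a 3 x 2 shard
max_dim_0, max_dim_1 = 5, 4              # largest dim 0 / dim 1 over all tables

# Pad on the right (dim 1) and bottom (dim 0) with zeros, as pad_tensor_to_max_dims does.
padded = F.pad(shard, (0, max_dim_1 - shard.size(1), 0, max_dim_0 - shard.size(0)))
assert padded.shape == (max_dim_0, max_dim_1)

# After the all_to_all, the receiver recovers the original shard by dropping
# the padded rows and columns, using the shard sizes sent with the data.
recovered = padded[: shard.size(0), : shard.size(1)]
assert torch.equal(recovered, shard)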

File tree: 2 files changed (+101, -35 lines)

torchrec/distributed/embeddingbag.py (8 additions, 0 deletions)

@@ -54,6 +54,7 @@
 from torchrec.distributed.sharding.cw_sharding import CwPooledEmbeddingSharding
 from torchrec.distributed.sharding.dp_sharding import DpPooledEmbeddingSharding
 from torchrec.distributed.sharding.dynamic_sharding import (
+    get_largest_dim_sizes,
     shards_all_to_all,
     update_module_sharding_plan,
     update_state_dict_post_resharding,

@@ -1530,13 +1531,19 @@ def update_shards(
         # Deleting all lookups
         self._lookups.clear()

+        # Get max dim size to enable padding for all_to_all
+        # TODO: optimize to only go through changed shards
+        max_dim_0, max_dim_1 = get_largest_dim_sizes(current_state)
+
         local_shard_names_by_src_rank, local_output_tensor = shards_all_to_all(
             module=self,
             state_dict=current_state,
             device=device,  # pyre-ignore
             changed_sharding_params=changed_sharding_params,
             env=env,
             extend_shard_name=self.extend_shard_name,
+            max_dim_0=max_dim_0,
+            max_dim_1=max_dim_1,
         )

         current_state = update_state_dict_post_resharding(

@@ -1546,6 +1553,7 @@ def update_shards(
             new_sharding_params=changed_sharding_params,
             curr_rank=dist.get_rank(),
             extend_shard_name=self.extend_shard_name,
+            max_dim_0=max_dim_0,
         )

         for name, param in changed_sharding_params.items():
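
For context on why both maxima are threaded through update_shards: once every shard is padded to max_dim_0 rows, the all_to_all input and output splits inside shards_all_to_all become simple multiples of max_dim_0, and the receive buffer can be preallocated as a single [total_rows, max_dim_1] tensor. A toy sketch of that bookkeeping for two ranks with hypothetical shard movements (illustrative only, not code from this commit):

import torch

world_size = 2
max_dim_0, max_dim_1 = 8, 16  # hypothetical maxima over all tables

# Hypothetical resharding plan: (src_rank, dst_rank) for each shard that moves.
moves = [(0, 1), (1, 0), (1, 0)]

input_splits_per_rank = [[0] * world_size for _ in range(world_size)]
output_splits_per_rank = [[0] * world_size for _ in range(world_size)]
for src_rank, dst_rank in moves:
    # Every padded shard contributes exactly max_dim_0 rows, regardless of its real size.
    input_splits_per_rank[src_rank][dst_rank] += max_dim_0
    output_splits_per_rank[dst_rank][src_rank] += max_dim_0

# Rank 0 receives two shards from rank 1, so its buffer holds 2 * max_dim_0 rows.
rank = 0
rows_received = sum(output_splits_per_rank[rank])
local_output_tensor = torch.empty([rows_received, max_dim_1])
assert rows_received == 2 * max_dim_0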

torchrec/distributed/sharding/dynamic_sharding.py (93 additions, 35 deletions)

@@ -11,6 +11,7 @@

 import torch
 import torch.distributed as dist
+import torch.nn.functional as F
 from torch.distributed._shard.sharded_tensor import Shard
 from torchrec.distributed.types import (
     ParameterSharding,

@@ -26,11 +27,14 @@ def shards_all_to_all(
     device: torch.device,
     changed_sharding_params: Dict[str, ParameterSharding],
     env: ShardingEnv,
+    max_dim_0: int,
+    max_dim_1: int,
     extend_shard_name: Callable[[str], str] = lambda x: x,
-) -> Tuple[List[Tuple[str, int]], torch.Tensor]:
+) -> Tuple[List[Tuple[str, List[int]]], torch.Tensor]:
     """
     Performs an all-to-all communication to redistribute shards across ranks based on new sharding parameters.
-    Assumes ranks are ordered in ParameterSharding.ranks.
+    Assumes ranks are ordered in ParameterSharding.ranks. Implements padding for concatenating, sending and
+    receiving tensors of different sizes in dim 0 or 1.

     Args:
         module (ShardedModule[Any, Any, Any, Any]): The module containing sharded tensors to be redistributed.

@@ -46,9 +50,13 @@ def shards_all_to_all(

         extend_shard_name (Callable[[str], str], optional): A function to extend shard names to the full name in state_dict.

+        max_dim_0 (int): The maximum dimension size of dim 0 across all tables in the module.
+
+        max_dim_1 (int): The maximum dimension size of dim 1 across all tables in the module.
+
     Returns:
-        Tuple[List[Tuple[str, int]], torch.Tensor]: A tuple containing:
-            - A list of shard name and the corresponding shard_size in dim 1 that were sent to the current rank.
+        Tuple[List[Tuple[str, List[int]]], torch.Tensor]: A tuple containing:
+            - A list of shard name and the corresponding shard_size in dim 0 & 1 that were sent to the current rank.
                 This is a flattened and pruned nested list, which orders the shards names and sizes by rank, then shard order.
             - The tensor containing all shards received by the current rank after the all-to-all operation.
     """

@@ -65,7 +73,7 @@ def shards_all_to_all(
     output_splits_per_rank = [[0] * world_size for _ in range(world_size)]

     # 0 by default, as current rank may be receiving 0 shards
-    num_embeddings_received = 0
+    num_embeddings_received_l = []
     output_tensor_tensor_count = 0
     shard_names_to_lengths_by_src_rank = [[] for _ in range(world_size)]
     local_table_to_input_tensor_by_dst_rank = [[] for _ in range(world_size)]
@@ -86,29 +94,21 @@ def shards_all_to_all(
             src_rank = src_ranks[i]

             shard_size = sharded_t.metadata().shards_metadata[i].shard_sizes
-            shard_size_dim_1 = shard_size[1]
-            input_splits_per_rank[src_rank][dst_rank] += shard_size_dim_1
-            output_splits_per_rank[dst_rank][src_rank] += shard_size_dim_1
+            input_splits_per_rank[src_rank][dst_rank] += max_dim_0
+            output_splits_per_rank[dst_rank][src_rank] += max_dim_0
             if src_rank == rank:
                 local_shards = sharded_t.local_shards()
                 assert len(local_shards) == 1
-                local_table_to_input_tensor_by_dst_rank[dst_rank].append(
-                    sharded_t.local_shards()[0].tensor
+                cur_t = pad_tensor_to_max_dims(
+                    sharded_t.local_shards()[0].tensor, max_dim_0, max_dim_1
                 )
+                local_table_to_input_tensor_by_dst_rank[dst_rank].append(cur_t)
             if dst_rank == rank:
                 shard_names_to_lengths_by_src_rank[src_rank].append(
-                    (shard_name, shard_size_dim_1)
+                    (shard_name, shard_size)
                 )
-                # NOTE: Only need to update num_embeddings_received to be the
-                # num_embeddings of shards if this rank is actually recieving
-                # any tensors
-                if num_embeddings_received == 0:
-                    num_embeddings_received = shard_size[0]
-                else:
-                    # TODO: for 2D and row-wise, shard_sizes in dim 0 may be variable
-                    # For now, assume that shard_sizes in dim 0 are all the same
-                    assert num_embeddings_received == shard_size[0]
-                output_tensor_tensor_count += shard_size[1]
+                num_embeddings_received_l.append(shard_size[1])
+                output_tensor_tensor_count += max_dim_0

     local_input_splits = input_splits_per_rank[rank]
     local_output_splits = output_splits_per_rank[rank]
@@ -121,16 +121,13 @@ def shards_all_to_all(
                     local_input_tensor,
                     shard_info,
                 ),
-                dim=1,
+                dim=0,
             )

-    # Transposing the Tensors - because we are concatenating them along dimension 1
-    # This is because dim 0 size may be different for different shards
-    # whereas dim 1 size is the same for all shards as dim 1 size = num_embeddings per table
+    max_embedding_size = max_dim_1
     local_output_tensor = torch.empty(
-        [output_tensor_tensor_count, num_embeddings_received], device=device
+        [output_tensor_tensor_count, max_embedding_size], device=device
     )
-    local_input_tensor = local_input_tensor.T.contiguous()

     assert sum(local_output_splits) == len(local_output_tensor)
     assert sum(local_input_splits) == len(local_input_tensor)

@@ -153,22 +150,23 @@

 def update_state_dict_post_resharding(
     state_dict: Dict[str, ShardedTensor],
-    ordered_shard_names_and_lengths: List[Tuple[str, int]],
+    ordered_shard_names_and_lengths: List[Tuple[str, List[int]]],
     output_tensor: torch.Tensor,
     new_sharding_params: Dict[str, ParameterSharding],
     curr_rank: int,
+    max_dim_0: int,
     extend_shard_name: Callable[[str], str] = lambda x: x,
 ) -> Dict[str, ShardedTensor]:
     """
     Updates and returns the given state_dict with new placements and
     local_shards based on the output tensor of the AllToAll collective.
+    Removes padding from the output tensor in dim 0 and 1 if necessary.

     Args:
         state_dict (Dict[str, Any]): The state dict to be updated with new shard placements and local shards.

-        shard_names_by_src_rank (List[Tuple[str, int]]): A list of shard name and the corresponding shard_size in dim 1
-            that were sent to the current rank. This is a flattened and pruned nested list, which orders the shards names and
-            sizes by rank, then shard order.
+        ordered_shard_names_and_lengths (List[Tuple[str, List[int]]]): A list of shard name and the corresponding shard_size.
+            This is a flattened and pruned nested list, which orders the shards names and sizes by rank, then shard order.

         output_tensor (torch.Tensor): The tensor containing the output data from the AllToAll operation.

@@ -177,6 +175,10 @@ def update_state_dict_post_resharding(

         curr_rank (int): The current rank of the process in the distributed environment.

+        max_dim_0 (int): The maximum dimension size of dim 0 across all tables in the module. Only dim 0
+            is needed here to slice the output tensor correctly, as removing the padding will only reference
+            the original shard sizes stored in ordered_shard_names_and_lengths.
+
         extend_shard_name (Callable[[str], str], optional): A function to extend shard names to the full name in state_dict.

     Returns:

@@ -187,10 +189,12 @@ def update_state_dict_post_resharding(
     shard_name_to_local_output_tensor: Dict[str, torch.Tensor] = {}

     for shard_name, shard_size in ordered_shard_names_and_lengths:
-        end_slice_index = slice_index + shard_size
-        shard_name_to_local_output_tensor[shard_name] = output_tensor[
-            slice_index:end_slice_index
-        ].T
+        end_slice_index = slice_index + max_dim_0
+        cur_t = output_tensor[slice_index:end_slice_index]
+        cur_t = pad_tensor_to_max_dims(
+            cur_t, shard_size[0], shard_size[1], remove_padding=True
+        )
+        shard_name_to_local_output_tensor[shard_name] = cur_t
         slice_index = end_slice_index

     for shard_name, param in new_sharding_params.items():
@@ -234,3 +238,57 @@ def update_module_sharding_plan(
     for table_name, param_sharding in changed_sharding_params.items():
         current_plan[table_name] = param_sharding
     return
+
+
+def get_largest_dim_sizes(
+    state_dict: Dict[str, ShardedTensor],
+) -> Tuple[int, int]:
+    """
+    Returns the largest dimension size of dim 0 and 1 across all tables in a module.
+
+    Args:
+        state_dict (Dict[str, ShardedTensor]): The state dict containing the sharded tensors.
+
+    Returns:
+        Tuple[int, int]: The largest dim 0 size and the largest dim 1 size across all shards in the state_dict.
+    """
+    max_dim_0 = 0
+    max_dim_1 = 0
+    for sharded_t in state_dict.values():
+        for shard in sharded_t.metadata().shards_metadata:
+            max_dim_0 = max(max_dim_0, shard.shard_sizes[0])
+            max_dim_1 = max(max_dim_1, shard.shard_sizes[1])
+
+    return max_dim_0, max_dim_1
+
+
+def pad_tensor_to_max_dims(
+    t: torch.Tensor,
+    expected_dim_0: int,
+    expected_dim_1: int,
+    remove_padding: bool = False,
+) -> torch.Tensor:
+    """
+    Pads a tensor on the right and bottom with zeros to reach the expected dims.
+
+    Args:
+        t (torch.Tensor): The tensor to be padded or trimmed.
+        expected_dim_0 (int): The target size of dim 0.
+        expected_dim_1 (int): The target size of dim 1.
+        remove_padding (bool): If True, the expected dims are smaller than the tensor, so the
+            resulting negative padding trims the tensor back to its original shard size.
+
+    Returns:
+        torch.Tensor: The padded (or trimmed) tensor.
+    """
+    pad_right = expected_dim_1 - t.size(1)
+    pad_bottom = expected_dim_0 - t.size(0)
+    return F.pad(
+        input=t,
+        pad=(
+            0,
+            pad_right,
+            0,
+            pad_bottom,
+        ),  # right and bottom
+        mode="constant",
+        value=0,
+    )
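
A usage note on the new helper: because F.pad accepts negative padding amounts, calling it with target dims smaller than the tensor trims instead of pads, which is what pad_tensor_to_max_dims relies on when update_state_dict_post_resharding passes remove_padding=True together with the original shard sizes. A small self-contained sketch of that behaviour (the tensors and sizes are hypothetical, not taken from this commit):

import torch
import torch.nn.functional as F

max_dim_0, max_dim_1 = 5, 4    # common padded dims used for the all_to_all
orig_dim_0, orig_dim_1 = 3, 2  # the shard's real sizes, as carried in
                               # ordered_shard_names_and_lengths

shard = torch.randn(orig_dim_0, orig_dim_1)

# Forward direction: positive pads grow the shard to the common size.
padded = F.pad(shard, (0, max_dim_1 - orig_dim_1, 0, max_dim_0 - orig_dim_0))

# Reverse direction: the same formula with the original sizes as targets yields
# negative pads, so F.pad trims the zero padding back off.
restored = F.pad(padded, (0, orig_dim_1 - padded.size(1), 0, orig_dim_0 - padded.size(0)))

assert restored.shape == (orig_dim_0, orig_dim_1)
assert torch.equal(restored, shard)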
