import torch
import torch.distributed as dist
+import torch.nn.functional as F
from torch.distributed._shard.sharded_tensor import Shard
from torchrec.distributed.types import (
    ParameterSharding,
@@ -26,11 +27,14 @@ def shards_all_to_all(
    device: torch.device,
    changed_sharding_params: Dict[str, ParameterSharding],
    env: ShardingEnv,
+    max_dim_0: int,
+    max_dim_1: int,
    extend_shard_name: Callable[[str], str] = lambda x: x,
-) -> Tuple[List[Tuple[str, int]], torch.Tensor]:
+) -> Tuple[List[Tuple[str, List[int]]], torch.Tensor]:
    """
    Performs an all-to-all communication to redistribute shards across ranks based on new sharding parameters.
-    Assumes ranks are ordered in ParameterSharding.ranks.
+    Assumes ranks are ordered in ParameterSharding.ranks. Implements padding for concatenating, sending and
+    receiving tensors of different sizes in dim 0 or 1.

    Args:
        module (ShardedModule[Any, Any, Any, Any]): The module containing sharded tensors to be redistributed.
@@ -46,9 +50,13 @@ def shards_all_to_all(

        extend_shard_name (Callable[[str], str], optional): A function to extend shard names to the full name in state_dict.

+        max_dim_0 (int): The maximum dimension size of dim 0 across all tables in the module.
+
+        max_dim_1 (int): The maximum dimension size of dim 1 across all tables in the module.
+
    Returns:
-        Tuple[List[Tuple[str, int]], torch.Tensor]: A tuple containing:
-            - A list of shard name and the corresponding shard_size in dim 1 that were sent to the current rank.
+        Tuple[List[Tuple[str, List[int]]], torch.Tensor]: A tuple containing:
+            - A list of shard name and the corresponding shard_size in dim 0 & 1 that were sent to the current rank.
              This is a flattened and pruned nested list, which orders the shards names and sizes by rank, then shard order.
            - The tensor containing all shards received by the current rank after the all-to-all operation.
    """
@@ -65,7 +73,7 @@ def shards_all_to_all(
    output_splits_per_rank = [[0] * world_size for _ in range(world_size)]

    # 0 by default, as current rank may be recieving 0 shards
-    num_embeddings_received = 0
+    num_embeddings_received_l = []
    output_tensor_tensor_count = 0
    shard_names_to_lengths_by_src_rank = [[] for _ in range(world_size)]
    local_table_to_input_tensor_by_dst_rank = [[] for _ in range(world_size)]
@@ -86,29 +94,21 @@ def shards_all_to_all(
            src_rank = src_ranks[i]

            shard_size = sharded_t.metadata().shards_metadata[i].shard_sizes
-            shard_size_dim_1 = shard_size[1]
-            input_splits_per_rank[src_rank][dst_rank] += shard_size_dim_1
-            output_splits_per_rank[dst_rank][src_rank] += shard_size_dim_1
+            input_splits_per_rank[src_rank][dst_rank] += max_dim_0
+            output_splits_per_rank[dst_rank][src_rank] += max_dim_0
            if src_rank == rank:
                local_shards = sharded_t.local_shards()
                assert len(local_shards) == 1
-                local_table_to_input_tensor_by_dst_rank[dst_rank].append(
-                    sharded_t.local_shards()[0].tensor
+                cur_t = pad_tensor_to_max_dims(
+                    sharded_t.local_shards()[0].tensor, max_dim_0, max_dim_1
                )
+                local_table_to_input_tensor_by_dst_rank[dst_rank].append(cur_t)
            if dst_rank == rank:
                shard_names_to_lengths_by_src_rank[src_rank].append(
-                    (shard_name, shard_size_dim_1)
+                    (shard_name, shard_size)
                )
-                # NOTE: Only need to update num_embeddings_received to be the
-                # num_embeddings of shards if this rank is actually recieving
-                # any tensors
-                if num_embeddings_received == 0:
-                    num_embeddings_received = shard_size[0]
-                else:
-                    # TODO: for 2D and row-wise, shard_sizes in dim 0 may be variable
-                    # For now, assume that shard_sizes in dim 0 are all the same
-                    assert num_embeddings_received == shard_size[0]
-                output_tensor_tensor_count += shard_size[1]
+                num_embeddings_received_l.append(shard_size[1])
+                output_tensor_tensor_count += max_dim_0

    local_input_splits = input_splits_per_rank[rank]
    local_output_splits = output_splits_per_rank[rank]
@@ -121,16 +121,13 @@ def shards_all_to_all(
                    local_input_tensor,
                    shard_info,
                ),
-                dim=1,
+                dim=0,
            )

-    # Transposing the Tensors - because we are concatenating them along dimension 1
-    # This is because dim 0 size may be different for different shards
-    # whereas dim 1 size is the same for all shards as dim 1 size = num_embeddings per table
+    max_embedding_size = max_dim_1
    local_output_tensor = torch.empty(
-        [output_tensor_tensor_count, num_embeddings_received], device=device
+        [output_tensor_tensor_count, max_embedding_size], device=device
    )
-    local_input_tensor = local_input_tensor.T.contiguous()

    assert sum(local_output_splits) == len(local_output_tensor)
    assert sum(local_input_splits) == len(local_input_tensor)
@@ -153,22 +150,23 @@ def shards_all_to_all(

def update_state_dict_post_resharding(
    state_dict: Dict[str, ShardedTensor],
-    ordered_shard_names_and_lengths: List[Tuple[str, int]],
+    ordered_shard_names_and_lengths: List[Tuple[str, List[int]]],
    output_tensor: torch.Tensor,
    new_sharding_params: Dict[str, ParameterSharding],
    curr_rank: int,
+    max_dim_0: int,
    extend_shard_name: Callable[[str], str] = lambda x: x,
) -> Dict[str, ShardedTensor]:
    """
    Updates and returns the given state_dict with new placements and
    local_shards based on the output tensor of the AllToAll collective.
+    Removes padding from the output tensor in dim 0 and 1 if necessary.

    Args:
        state_dict (Dict[str, Any]): The state dict to be updated with new shard placements and local shards.

-        shard_names_by_src_rank (List[Tuple[str, int]]): A list of shard name and the corresponding shard_size in dim 1
-            that were sent to the current rank. This is a flattened and pruned nested list, which orders the shards names and
-            sizes by rank, then shard order.
+        ordered_shard_names_and_lengths (List[Tuple[str, List[int]]]): A list of shard name and the corresponding shard_size.
+            This is a flattened and pruned nested list, which orders the shards names and sizes by rank, then shard order.

        output_tensor (torch.Tensor): The tensor containing the output data from the AllToAll operation.

@@ -177,6 +175,10 @@ def update_state_dict_post_resharding(

        curr_rank (int): The current rank of the process in the distributed environment.

+        max_dim_0 (int): The maximum dimension size of dim 0 across all tables in the module. Only dim 0
+            is needed here to slice the output tensor correctly, as removing the padding will only reference
+            the original shard sizes stored in ordered_shard_names_and_lengths.
+
        extend_shard_name (Callable[[str], str], optional): A function to extend shard names to the full name in state_dict.

    Returns:
@@ -187,10 +189,12 @@ def update_state_dict_post_resharding(
    shard_name_to_local_output_tensor: Dict[str, torch.Tensor] = {}

    for shard_name, shard_size in ordered_shard_names_and_lengths:
-        end_slice_index = slice_index + shard_size
-        shard_name_to_local_output_tensor[shard_name] = output_tensor[
-            slice_index:end_slice_index
-        ].T
+        end_slice_index = slice_index + max_dim_0
+        cur_t = output_tensor[slice_index:end_slice_index]
+        cur_t = pad_tensor_to_max_dims(
+            cur_t, shard_size[0], shard_size[1], remove_padding=True
+        )
+        shard_name_to_local_output_tensor[shard_name] = cur_t
        slice_index = end_slice_index

    for shard_name, param in new_sharding_params.items():
@@ -234,3 +238,57 @@ def update_module_sharding_plan(
    for table_name, param_sharding in changed_sharding_params.items():
        current_plan[table_name] = param_sharding
    return
+
+
+def get_largest_dim_sizes(
+    state_dict: Dict[str, ShardedTensor],
+) -> Tuple[int, int]:
+    """
+    Returns the largest dimension size of dim 0 and 1 across all tables in a module.
+
+    Args:
+        state_dict (Dict[str, ShardedTensor]): The state dict containing the sharded tensors.
+
+    Returns:
+        Tuple[int, int]: The largest dim 0 size and the largest dim 1 size across all shards in the state_dict.
+    """
+    max_dim_0 = 0
+    max_dim_1 = 0
+    for sharded_t in state_dict.values():
+        for shard in sharded_t.metadata().shards_metadata:
+            max_dim_0 = max(max_dim_0, shard.shard_sizes[0])
+            max_dim_1 = max(max_dim_1, shard.shard_sizes[1])
+
+    return max_dim_0, max_dim_1
+
+
+def pad_tensor_to_max_dims(
+    t: torch.Tensor,
+    expected_dim_0: int,
+    expected_dim_1: int,
+    remove_padding: bool = False,
+) -> torch.Tensor:
+    """
+    Pads a tensor on the right and bottom with zeros, or trims it when the expected dims are smaller (remove_padding).
+
+    Args:
+        t (torch.Tensor): The tensor to be padded or trimmed.
+        expected_dim_0 (int): The target size of dim 0; rows of zeros are added at the bottom.
+        expected_dim_1 (int): The target size of dim 1; columns of zeros are added on the right.
+
+    Returns:
+        torch.Tensor: The padded (or trimmed) tensor.
+    """
+    pad_right = expected_dim_1 - t.size(1)
+    pad_bottom = expected_dim_0 - t.size(0)
+    return F.pad(
+        input=t,
+        pad=(
+            0,
+            pad_right,
+            0,
+            pad_bottom,
+        ),  # right and bottom
+        mode="constant",
+        value=0,
+    )
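
For reference, a minimal sketch of the padding round-trip this patch introduces, run outside the distributed flow with hypothetical shapes and per-module maxima (assumes pad_tensor_to_max_dims from this file is in scope): a shard is zero-padded up to (max_dim_0, max_dim_1) before the all-to-all, and the same helper is called afterwards with remove_padding=True, relying on negative pad values in F.pad to trim the received slice back to the original sizes recorded in ordered_shard_names_and_lengths.

import torch

# A shard of size [2, 3]; per-module maxima chosen arbitrarily for illustration.
t = torch.arange(6, dtype=torch.float32).reshape(2, 3)
max_dim_0, max_dim_1 = 4, 8

# Zero-pad bottom/right up to [max_dim_0, max_dim_1] before concatenation and all-to-all.
padded = pad_tensor_to_max_dims(t, max_dim_0, max_dim_1)
assert padded.shape == (4, 8)

# After the all-to-all, the original shard sizes come from ordered_shard_names_and_lengths;
# passing them with remove_padding=True produces negative pads that trim back to [2, 3].
restored = pad_tensor_to_max_dims(padded, 2, 3, remove_padding=True)
assert torch.equal(restored, t)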