@@ -2,9 +2,9 @@

import sys
from typing import Optional, List, cast
- from torch.distributed._shard.checkpoint.storage import WriteResult
+ from torch.distributed.checkpoint.storage import WriteResult

- from torch.distributed._shard.checkpoint import (
+ from torch.distributed.checkpoint import (
    StorageReader,
    StorageWriter,
    CheckpointException,
@@ -63,6 +63,7 @@
    )
    sys.exit(0)

+
class TestModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
@@ -121,34 +122,44 @@ def test_default_metadata(self) -> None:
        )

        state_dict = {
-             'sharded': sharded_tensor.rand(spec, (10, 10, )),
-             'replicated': torch.rand(4, device=device),
-             'bytes': [1, 2, 3, 4],
+             "sharded": sharded_tensor.rand(
+                 spec,
+                 (
+                     10,
+                     10,
+                 ),
+             ),
+             "replicated": torch.rand(4, device=device),
+             "bytes": [1, 2, 3, 4],
        }

        metadata = _create_default_local_metadata(state_dict)
-         self.assertTrue('bytes' in metadata.state_dict_metadata)
-         self.assertIsInstance(metadata.state_dict_metadata['bytes'], BytesStorageMetadata)
+         self.assertTrue("bytes" in metadata.state_dict_metadata)
+         self.assertIsInstance(
+             metadata.state_dict_metadata["bytes"], BytesStorageMetadata
+         )

-         self.assertTrue('replicated' in metadata.state_dict_metadata)
-         self.assertIsInstance(metadata.state_dict_metadata['replicated'], TensorStorageMetadata)
-         md = metadata.state_dict_metadata['replicated']
-         self.assertEqual(md.size, state_dict['replicated'].size())
+         self.assertTrue("replicated" in metadata.state_dict_metadata)
+         self.assertIsInstance(
+             metadata.state_dict_metadata["replicated"], TensorStorageMetadata
+         )
+         md = metadata.state_dict_metadata["replicated"]
+         self.assertEqual(md.size, state_dict["replicated"].size())
        self.assertEqual(md.properties.dtype, torch.float32)
        self.assertEqual(1, len(md.chunks))

-         self.assertTrue('sharded' in metadata.state_dict_metadata)
-         self.assertIsInstance(metadata.state_dict_metadata['sharded'], TensorStorageMetadata)
-         md = metadata.state_dict_metadata['sharded']
+         self.assertTrue("sharded" in metadata.state_dict_metadata)
+         self.assertIsInstance(
+             metadata.state_dict_metadata["sharded"], TensorStorageMetadata
+         )
+         md = metadata.state_dict_metadata["sharded"]
        self.assertEqual(md.properties.dtype, torch.float32)
-         self.assertEqual(md.size, state_dict['sharded'].size())
+         self.assertEqual(md.size, state_dict["sharded"].size())
        self.assertEqual(2, len(md.chunks))

+
class TestStorageBase:
-     def __init__(
-         self,
-         fail_conf
-     ):
+     def __init__(self, fail_conf):
        self.fail_conf = fail_conf
        self.rank = 0 if not dist.is_initialized() else dist.get_rank()

@@ -164,16 +175,16 @@ def _fail_rank_async(self, name, result=None):
        ranks = self._get_ranks(name)
        fut = Future()
        if ranks is not None and self.rank in ranks:
-             fut.set_exception(ValueError(f"async rank fail {self.rank} for {name}"))
+             fut.set_exception(
+                 ValueError(f"async rank fail {self.rank} for {name}")
+             )
        else:
            fut.set_result(result)
        return fut

+
class FaultyStorageWriter(TestStorageBase, StorageWriter):
-     def __init__(
-         self,
-         fail_conf
-     ):
+     def __init__(self, fail_conf):
        super(FaultyStorageWriter, self).__init__(fail_conf)

    def init(self, is_coordinator: bool) -> None:
@@ -188,23 +199,19 @@ def prepare_global_plan(self, plans: List[SavePlan]) -> List[SavePlan]:
        return plans

    def write_data(
-         self,
-         plan: SavePlan,
-         planner: SavePlanner
+         self, plan: SavePlan, planner: SavePlanner
    ) -> Future[List[WriteResult]]:
        self._fail_rank("fail_write_data")
        return self._fail_rank_async("fail_write_data_async", [])

-     def finish(self, metadata: Metadata, results: List[List[WriteResult]]) -> None:
+     def finish(
+         self, metadata: Metadata, results: List[List[WriteResult]]
+     ) -> None:
        self._fail_rank("fail_finish")


class FaultyStorageReader(TestStorageBase, StorageReader):
-     def __init__(
-         self,
-         metadata,
-         fail_conf
-     ):
+     def __init__(self, metadata, fail_conf):
        super(FaultyStorageReader, self).__init__(fail_conf)
        self.metadata = metadata

@@ -219,35 +226,32 @@ def prepare_global_plan(self, plans: List[LoadPlan]) -> List[LoadPlan]:
        self._fail_rank("fail_prepare_global_plan")
        return plans

-     def read_data(
-         self,
-         plan: LoadPlan,
-         planner: LoadPlanner
-     ) -> Future[None]:
+     def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]:
        self._fail_rank("fail_read_data")
        return self._fail_rank_async("fail_read_data_async")

    def read_metadata(self) -> Metadata:
        self._fail_rank("fail_read_metadata")
        return self.metadata

+
class TestDistributedFailure(ShardedTensorTestBase):
    def get_spec(self):
        return ChunkShardingSpec(
            dim=0,
            placements=[
                f"rank:{r}/cuda:{r}" for r in range(dist.get_world_size())
-             ]
+             ],
        )

    @with_comms(init_rpc=False)
    @skip_if_lt_x_gpu(2)
    @requires_nccl()
    def test_dummy_writer_works(self) -> None:
        state_dict = {
-             'sharded': sharded_tensor.rand(self.get_spec(), 20, 20),
-             'replicated': torch.rand(10, 10),
-             'bytes': [1, 2, 3, 4]
+             "sharded": sharded_tensor.rand(self.get_spec(), 20, 20),
+             "replicated": torch.rand(10, 10),
+             "bytes": [1, 2, 3, 4],
        }

        save_state_dict(state_dict, FaultyStorageWriter({}))
@@ -257,9 +261,9 @@ def test_dummy_writer_works(self) -> None:
    @requires_nccl()
    def test_dummy_reader_works(self) -> None:
        state_dict = {
-             'sharded': sharded_tensor.rand(self.get_spec(), 20, 20),
-             'replicated': torch.rand(10, 10),
-             'bytes': [1, 2, 3, 4]
+             "sharded": sharded_tensor.rand(self.get_spec(), 20, 20),
+             "replicated": torch.rand(10, 10),
+             "bytes": [1, 2, 3, 4],
        }
        metadata = _create_default_local_metadata(state_dict)

@@ -283,8 +287,10 @@ def _test_dist_failure(self, callback, kwargs):

            failed_ranks = e.failures.keys()
            for rank in bad_ranks:
-                 self.assertTrue(rank in failed_ranks, msg=f"{rank} was supposed to fail was fine")
-
+                 self.assertTrue(
+                     rank in failed_ranks,
+                     msg=f"{rank} was supposed to fail was fine",
+                 )

    def _test_save(self, state_dict, coordinator=0, **kwargs):
        no_dist = not dist.is_initialized()
@@ -296,6 +302,7 @@ def _save():
                coordinator_rank=coordinator,
                no_dist=no_dist,
            )
+
        self._test_dist_failure(_save, kwargs)

    def _test_load(self, state_dict, coordinator=0, **kwargs):
@@ -317,9 +324,9 @@ def _load():
    @requires_nccl()
    def test_save_error_handling(self) -> None:
        state_dict = {
-             'sharded': sharded_tensor.rand(self.get_spec(), 20, 20),
-             'replicated': torch.rand(10, 10),
-             'bytes': [1, 2, 3, 4]
+             "sharded": sharded_tensor.rand(self.get_spec(), 20, 20),
+             "replicated": torch.rand(10, 10),
+             "bytes": [1, 2, 3, 4],
        }

        self._test_save(state_dict, fail_init=[0])
@@ -334,10 +341,7 @@ def test_save_error_handling(self) -> None:
        self._test_save(state_dict, coordinator=1, fail_finish=[1])

    def test_save_error_handling_no_dist(self) -> None:
-         state_dict = {
-             'replicated': torch.rand(10, 10),
-             'bytes': [1, 2, 3, 4]
-         }
+         state_dict = {"replicated": torch.rand(10, 10), "bytes": [1, 2, 3, 4]}

        self.assertFalse(dist.is_initialized())

@@ -354,9 +358,9 @@ def test_save_error_handling_no_dist(self) -> None:
    @requires_nccl()
    def test_load_error_handling(self) -> None:
        state_dict = {
-             'sharded': sharded_tensor.rand(self.get_spec(), 20, 20),
-             'replicated': torch.rand(10, 10),
-             'bytes': [1, 2, 3, 4]
+             "sharded": sharded_tensor.rand(self.get_spec(), 20, 20),
+             "replicated": torch.rand(10, 10),
+             "bytes": [1, 2, 3, 4],
        }

        self._test_load(state_dict)
@@ -373,12 +377,8 @@ def test_load_error_handling(self) -> None:
        self._test_load(state_dict, coordinator=3, fail_read_data_async=[2])
        self._test_load(state_dict, coordinator=1, fail_prepare_global_plan=[1])

-
    def test_load_error_handling_no_dist(self) -> None:
-         state_dict = {
-             'replicated': torch.rand(10, 10),
-             'bytes': [1, 2, 3, 4]
-         }
+         state_dict = {"replicated": torch.rand(10, 10), "bytes": [1, 2, 3, 4]}
        self._test_load(state_dict)
        self._test_load(state_dict, fail_init=[0])
        self._test_load(state_dict, fail_read_metadata=[0])
@@ -387,5 +387,6 @@ def test_load_error_handling_no_dist(self) -> None:
        self._test_load(state_dict, fail_read_data=[0])
        self._test_load(state_dict, fail_read_data_async=[0])

+
if __name__ == "__main__":
    run_tests()
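A minimal sketch, not part of the patch above, of how the fault-injection hooks in this file are driven: FaultyStorageWriter takes a fail_conf dict mapping a hook name (e.g. "fail_init") to the ranks that should raise there, and save_state_dict surfaces those errors as a CheckpointException keyed by rank. It assumes the FaultyStorageWriter defined in this test file is in scope and that no process group is initialized (the single-process path exercised by test_save_error_handling_no_dist); the expected output is inferred from that test.

# Sketch only -- mirrors test_save_error_handling_no_dist, not a new API.
import torch
from torch.distributed.checkpoint import CheckpointException, save_state_dict

state_dict = {"replicated": torch.rand(10, 10), "bytes": [1, 2, 3, 4]}

try:
    # fail_conf asks rank 0 to raise ValueError inside StorageWriter.init().
    save_state_dict(
        state_dict,
        storage_writer=FaultyStorageWriter({"fail_init": [0]}),
        no_dist=True,
    )
except CheckpointException as e:
    # save_state_dict collects per-rank errors into CheckpointException.failures.
    print(sorted(e.failures.keys()))  # expected: [0]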