@@ -290,35 +290,37 @@ def foreach_all_gather_copy_out(
         out = [t.view(world_size, -1).view(torch.uint8) for t in split_with_sizes_out]
     else:
         out = [t.view(world_size, -1) for t in split_with_sizes_out]
-    torch.ops.fsdp.split_with_sizes_copy(
-        all_gather_output, all_gather_input_split_sizes, dim=1, out=out
-    )
+    with torch.autograd._unsafe_preserve_version_counter(tuple(out)):
+        torch.ops.fsdp.split_with_sizes_copy(
+            all_gather_output, all_gather_input_split_sizes, dim=1, out=out
+        )

     for fsdp_param, param_all_gather_outputs in shard_i_copy_infos:
         # Chunk-cat from the temporary to the final all-gather output tensors
         shard_dim = fsdp_param.fsdp_placement.dim
-        for param_all_gather_output, target_all_gather_output in zip(
-            param_all_gather_outputs, fsdp_param.all_gather_outputs
+
+        with torch.autograd._unsafe_preserve_version_counter(
+            tuple(fsdp_param.all_gather_outputs)
         ):
-            padded_sharded_size = (
-                fsdp_param.padded_sharded_param_size
-                if fsdp_param.sharded_state == ShardedState.SHARDED
-                else cast(
-                    torch.Tensor, fsdp_param._sharded_post_forward_param_data
-                ).size()
-            )
-            pre_param_size = list(padded_sharded_size)
-            pre_param_size[0] *= world_size
-            chunks = torch.chunk(
-                param_all_gather_output.view(pre_param_size), world_size, dim=0
-            )
-            post_param_size = list(padded_sharded_size)
-            post_param_size[shard_dim] *= world_size
-            cat_out = target_all_gather_output.view(post_param_size)
-            torch.cat(chunks, dim=shard_dim, out=cat_out)
-            torch._C._autograd._unsafe_set_version_counter(
-                target_all_gather_output, target_all_gather_output._version - 1
-            )
+            for param_all_gather_output, target_all_gather_output in zip(
+                param_all_gather_outputs, fsdp_param.all_gather_outputs
+            ):
+                padded_sharded_size = (
+                    fsdp_param.padded_sharded_param_size
+                    if fsdp_param.sharded_state == ShardedState.SHARDED
+                    else cast(
+                        torch.Tensor, fsdp_param._sharded_post_forward_param_data
+                    ).size()
+                )
+                pre_param_size = list(padded_sharded_size)
+                pre_param_size[0] *= world_size
+                chunks = torch.chunk(
+                    param_all_gather_output.view(pre_param_size), world_size, dim=0
+                )
+                post_param_size = list(padded_sharded_size)
+                post_param_size[shard_dim] *= world_size
+                cat_out = target_all_gather_output.view(post_param_size)
+                torch.cat(chunks, dim=shard_dim, out=cat_out)


 @torch.no_grad()
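For context, a minimal standalone sketch (not part of the diff) of the pattern this change adopts: wrapping in-place writes in torch.autograd._unsafe_preserve_version_counter so that the destination tensors' autograd version counters are restored when the block exits, instead of manually rewinding them afterward with torch._C._autograd._unsafe_set_version_counter. This assumes a PyTorch build whose _unsafe_preserve_version_counter accepts a tuple of tensors, which is the form the diff above relies on; the tensor names below are illustrative only.

import torch

a, b = torch.zeros(4), torch.zeros(4)
versions = (a._version, b._version)

# In-place writes (including writes through out=) normally bump each
# tensor's version counter, which autograd uses to detect mutation of
# saved tensors. The context manager records the counters on entry and
# restores them on exit.
with torch.autograd._unsafe_preserve_version_counter((a, b)):
    a.add_(1.0)
    b.copy_(torch.ones(4))

assert (a._version, b._version) == versions  # counters restored on exit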