@@ -87,6 +87,7 @@ def __init__(
        overlap_broadcast=False,
        grad_scal_cfg: Config = None,
        zero_cfg: Config = None,
+       use_fp16: bool = True,
    ):
        # DynamicGradScaler related args
        initial_scale = grad_scal_cfg.fp16.initial_scale
@@ -104,6 +105,7 @@ def __init__(

        super().__init__(optim=optimizer)

+       self.use_fp16 = use_fp16
        self._dtype = self.optim.param_groups[0]["params"][0].dtype
        self._cpu_offload = cpu_offload
        self._zero_local_rank = gpc.get_local_rank(ParallelMode.ZERO1)
@@ -125,14 +127,18 @@ def __init__(
        self._reduce_bucket_size = reduce_bucket_size

        # gradient scaler
-       self.grad_scaler = DynamicGradScaler(
-           initial_scale=initial_scale,
-           min_scale=min_scale,
-           growth_factor=growth_factor,
-           backoff_factor=backoff_factor,
-           growth_interval=growth_interval,
-           hysteresis=hysteresis,
-           max_scale=max_scale,
+       self.grad_scaler = (
+           DynamicGradScaler(
+               initial_scale=initial_scale,
+               min_scale=min_scale,
+               growth_factor=growth_factor,
+               backoff_factor=backoff_factor,
+               growth_interval=growth_interval,
+               hysteresis=hysteresis,
+               max_scale=max_scale,
+           )
+           if self.use_fp16
+           else None
        )
        self._found_overflow = torch.cuda.FloatTensor([0], device=get_current_device())
@@ -176,11 +182,14 @@ def __init__(
                for param in params:
                    self._param_store.set_param_to_rank(param, rank)

+           # flatten the reordered tensors
            # move to cpu to make room to create the flat tensor
+           # Even for fp32 training, we still flatten the tensors,
+           # which does not increase GPU memory usage
+           # and improves the efficiency of broadcasting.
            for param in group_params:
                param.data = param.data.cpu()

-           # flatten the reordered tensors
            for rank in range(self._zero_world_size):
                # No flat fp16 buffer is allocated if the process has no parameters.
                if rank not in self.param_group_no_params_ranks[group_id]:
@@ -194,19 +203,25 @@ def __init__(
            # create a copy of fp32 weights of the parameters for which this rank is responsible
            # No flat fp32 buffer is allocated if the process has no parameters.
            if self.param_group_has_params[group_id]:
-               fp16_flat_current_rank = self._param_store.get_flat_fp16_param_by_rank_group(
-                   self._zero_local_rank, group_id
-               )
-               fp32_flat_current_rank = fp16_flat_current_rank.float()
-               device = "cpu" if self._cpu_offload else get_current_device()
-               fp32_flat_current_rank = fp32_flat_current_rank.to(device)
-               fp32_flat_current_rank.requires_grad = True
-               self._fp32_flat_param_groups_of_current_rank[group_id] = fp32_flat_current_rank
-
-               # need to replace the params in the `params` field in the optimizer
-               # so that when the optimizer calls step(), it only updates the tensors
-               # managed by this data parallel rank
-               param_group["params"] = [fp32_flat_current_rank]
+               if self.use_fp16:
+                   fp16_flat_current_rank = self._param_store.get_flat_fp16_param_by_rank_group(
+                       self._zero_local_rank, group_id
+                   )
+                   fp32_flat_current_rank = fp16_flat_current_rank.float()
+                   device = "cpu" if self._cpu_offload else get_current_device()
+                   fp32_flat_current_rank = fp32_flat_current_rank.to(device)
+                   fp32_flat_current_rank.requires_grad = True
+                   self._fp32_flat_param_groups_of_current_rank[group_id] = fp32_flat_current_rank
+
+                   # need to replace the params in the `params` field in the optimizer
+                   # so that when the optimizer calls step(), it only updates the tensors
+                   # managed by this data parallel rank
+                   param_group["params"] = [fp32_flat_current_rank]
+               else:
+                   # use fp32
+                   param_group["params"] = self._param_store.get_fp16_params_by_rank_group(
+                       self._zero_local_rank, group_id
+                   )

            # set reduction state
            for param in self._fp16_param_groups[group_id]:
@@ -243,7 +258,10 @@ def dtype(self):

    @property
    def loss_scale(self):
-       return self.grad_scaler.scale
+       if self.grad_scaler is None:
+           return 1
+       else:
+           return self.grad_scaler.scale

    @property
    def num_param_groups(self):
@@ -533,7 +551,8 @@ def _step(self, closure=None):
            norm_groups.append(norm_group)

        loss_scale = float(self.loss_scale.item())  # backup
-       self.grad_scaler.update(found_inf)
+       if self.grad_scaler:
+           self.grad_scaler.update(found_inf)
        # update loss scale if overflow occurs
        if found_inf:
            if gpc.is_rank_for_log():
@@ -552,21 +571,30 @@ def _step(self, closure=None):
                continue
            gradients = self._grad_store.get_averaged_gradients_by_group(group_id)

-           # create flat gradient for the flat fp32 params
-           fp16_avg_grads = gradients
-           flat_fp16_avg_grads = flatten(fp16_avg_grads)
+           if self.use_fp16:
+               # create flat gradient for the flat fp32 params
+               fp16_avg_grads = gradients
+               flat_fp16_avg_grads = flatten(fp16_avg_grads)

-           dtype = self._fp32_flat_param_groups_of_current_rank[group_id].dtype
-           flat_fp32_avg_grads = flat_fp16_avg_grads.to(dtype)
+               dtype = self._fp32_flat_param_groups_of_current_rank[group_id].dtype
+               flat_fp32_avg_grads = flat_fp16_avg_grads.to(dtype)

-           param_shape = self._fp32_flat_param_groups_of_current_rank[group_id].shape
-           assert (
-               param_shape == flat_fp32_avg_grads.shape
-           ), f"fp32 param and grad have different shape {param_shape} vs {flat_fp32_avg_grads.shape}"
+               param_shape = self._fp32_flat_param_groups_of_current_rank[group_id].shape
+               assert (
+                   param_shape == flat_fp32_avg_grads.shape
+               ), f"fp32 param and grad have different shape {param_shape} vs {flat_fp32_avg_grads.shape}"
+
+               single_grad_partition_groups.append(flat_fp32_avg_grads)
+               device = self._fp32_flat_param_groups_of_current_rank[group_id].device
+               self._fp32_flat_param_groups_of_current_rank[group_id].grad = flat_fp32_avg_grads.to(device)
+           else:
+               assert len(gradients) == len(self.optim.param_groups[group_id]["params"]), (
+                   len(gradients),
+                   len(self.optim.param_groups[group_id]["params"]),
+               )
+               for g, p in zip(gradients, self.optim.param_groups[group_id]["params"]):
+                   p.grad = g

-           single_grad_partition_groups.append(flat_fp32_avg_grads)
-           device = self._fp32_flat_param_groups_of_current_rank[group_id].device
-           self._fp32_flat_param_groups_of_current_rank[group_id].grad = flat_fp32_avg_grads.to(device)
            self._grad_store._averaged_gradients[group_id] = []
            self._grad_store._averaged_gradients[group_id] = []
@@ -576,8 +604,9 @@ def _step(self, closure=None):
        global_norm = sum(norm_groups) ** 0.5

        # the following operations are performed only on the rank to which parameters are assigned.
-       if len(single_grad_partition_groups) != 0:
-           self._unscale_and_clip_grads(single_grad_partition_groups, global_norm, loss_scale)
+       if self.use_fp16:
+           if len(single_grad_partition_groups) != 0:
+               self._unscale_and_clip_grads(single_grad_partition_groups, global_norm, loss_scale)

        timer("cal_norm").stop()
        # update the parameters
@@ -588,15 +617,16 @@ def _step(self, closure=None):
        if self.has_params:
            self.optim.step()
            # release the fp32 grad
-           release_param_grad(self._fp32_flat_param_groups_of_current_rank.values())
-           # update fp16 partition updated by the current rank
-           for group_id in range(len(self._fp16_param_groups)):
-               if self.param_group_has_params[group_id]:
-                   fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(
-                       rank=self._zero_local_rank, group_id=group_id
-                   )
-                   fp32_param = self._fp32_flat_param_groups_of_current_rank[group_id]
-                   fp16_param.data.copy_(fp32_param)
+           if self.use_fp16:
+               release_param_grad(self._fp32_flat_param_groups_of_current_rank.values())
+               # update fp16 partition updated by the current rank
+               for group_id in range(len(self._fp16_param_groups)):
+                   if self.param_group_has_params[group_id]:
+                       fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(
+                           rank=self._zero_local_rank, group_id=group_id
+                       )
+                       fp32_param = self._fp32_flat_param_groups_of_current_rank[group_id]
+                       fp16_param.data.copy_(fp32_param)

        # TODO: support broadcast overlap
        self.broadcast_params(overlap=False)
@@ -614,8 +644,6 @@ def broadcast_params(self, overlap=False):
                # The following operations are performed only on the rank to which parameters are assigned.
                if rank not in self.param_group_no_params_ranks[group_id]:
                    fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(rank=rank, group_id=group_id)
-                   # grank = gpc.get_ranks_in_group(group_type)[rank]  # need to convert to the global rank
-                   # assert grank == rank, f"{grank} == {rank}"
                    g_rank = gpc.get_ranks_in_group(self._broadcast_parallel_mode)[rank]
                    handle = dist.broadcast(
                        fp16_param, src=g_rank, group=gpc.get_group(ParallelMode.ZERO1), async_op=True
@@ -667,48 +695,52 @@ def clip_grad_norm(self, model, max_norm):

    def state_dict(self):
        states = {}
-       grad_scaler = self.grad_scaler.state_dict()
-       states["grad_scaler"] = grad_scaler
        optim_states = self.optim.state_dict()
        states["base_optim_states"] = optim_states

-       flat_fp32_weights = {}
-       for group_id, param in self._fp32_flat_param_groups_of_current_rank.items():
-           if self._zero_local_rank not in self.param_group_no_params_ranks[group_id]:
-               assert param.grad is None
-               flat_fp32_weights[group_id] = param
-       states["flat_fp32_weights"] = flat_fp32_weights
+       if self.use_fp16:
+           grad_scaler = self.grad_scaler.state_dict()
+           states["grad_scaler"] = grad_scaler
+
+           flat_fp32_weights = {}
+           for group_id, param in self._fp32_flat_param_groups_of_current_rank.items():
+               if self._zero_local_rank not in self.param_group_no_params_ranks[group_id]:
+                   assert param.grad is None
+                   flat_fp32_weights[group_id] = param
+           states["flat_fp32_weights"] = flat_fp32_weights
        states["zero_devide_optim_plan"] = self.params_per_rank_id_dict

        return states

    def load_state_dict(self, states):
        # TODO: Need to take into account the change in the number of DP.
-       assert "grad_scaler" in states, "Not found grad_scaler state!"
-       grad_scaler = states["grad_scaler"]
-       self.grad_scaler.load_state_dict(grad_scaler)
        optim_states = states["base_optim_states"]
        self.optim.load_state_dict(optim_states)

-       # load fp32 model weight.
-       flat_fp32_weights = states["flat_fp32_weights"]
-       assert set(flat_fp32_weights.keys()) == set(self._fp32_flat_param_groups_of_current_rank)
-       for group_id, param in flat_fp32_weights.items():
-           if self._zero_local_rank not in self.param_group_no_params_ranks[group_id]:
-               self_param = self._fp32_flat_param_groups_of_current_rank[group_id]
-               assert (
-                   self_param.shape == param.shape
-               ), f"The loaded parameter shape is inconsistent, {self_param.shape} != {param.shape}"
-               self_param.data.copy_(param.data)
-
-       # Load the fp16 model weights.
-       for group_id in range(len(self._fp16_param_groups)):
-           if self._zero_local_rank not in self.param_group_no_params_ranks[group_id]:
-               fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(
-                   rank=self._zero_local_rank, group_id=group_id
-               )
-               fp32_param = self._fp32_flat_param_groups_of_current_rank[group_id]
-               fp16_param.data.copy_(fp32_param)
+       if self.use_fp16:
+           assert "grad_scaler" in states, "Not found grad_scaler state!"
+           grad_scaler = states["grad_scaler"]
+           self.grad_scaler.load_state_dict(grad_scaler)
+
+           # load fp32 model weight.
+           flat_fp32_weights = states["flat_fp32_weights"]
+           assert set(flat_fp32_weights.keys()) == set(self._fp32_flat_param_groups_of_current_rank)
+           for group_id, param in flat_fp32_weights.items():
+               if self._zero_local_rank not in self.param_group_no_params_ranks[group_id]:
+                   self_param = self._fp32_flat_param_groups_of_current_rank[group_id]
+                   assert (
+                       self_param.shape == param.shape
+                   ), f"The loaded parameter shape is inconsistent, {self_param.shape} != {param.shape}"
+                   self_param.data.copy_(param.data)
+
+           # Load the fp16 model weights.
+           for group_id in range(len(self._fp16_param_groups)):
+               if self._zero_local_rank not in self.param_group_no_params_ranks[group_id]:
+                   fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(
+                       rank=self._zero_local_rank, group_id=group_id
+                   )
+                   fp32_param = self._fp32_flat_param_groups_of_current_rank[group_id]
+                   fp16_param.data.copy_(fp32_param)

        if "zero_devide_optim_plan" in states:
            self.params_per_rank_id_dict = states["zero_devide_optim_plan"]
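
A minimal usage sketch of the new flag, for review context only. The wrapper class name (`HybridZeroOptimizer`) and the `grad_scal_cfg`/`zero_cfg` objects below are assumptions inferred from the constructor arguments in this diff, not something the diff itself confirms.

```python
import torch

# Real PyTorch pieces: a toy fp32 model and a base optimizer.
model = torch.nn.Linear(8, 8)
base_optim = torch.optim.AdamW(model.parameters(), lr=1e-4)

# Hypothetical names: the ZeRO wrapper class and the two Config objects are
# assumed from the constructor shown above, not defined in this diff.
optim = HybridZeroOptimizer(
    base_optim,
    grad_scal_cfg=grad_scal_cfg,  # assumed Config carrying the fp16.* scaler settings
    zero_cfg=zero_cfg,            # assumed Config carrying the ZeRO settings
    use_fp16=False,               # new flag from this diff: pure fp32 path
)

# With use_fp16=False the diff creates no DynamicGradScaler, so loss_scale is 1
# and step() attaches the averaged gradients directly to the fp32 parameters.
assert optim.loss_scale == 1
```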