Merge pull request #2 from SolenoidWGT/fp32_zero · InternLM/InternLM@53fc50b · GitHub

Commit 53fc50b

Merge pull request #2 from SolenoidWGT/fp32_zero
feat(optim): add support for fp32 zero
2 parents 40f24d0 + 72b27b0 commit 53fc50b

File tree

2 files changed: +115, -80 lines

internlm/solver/optimizer/hybrid_zero_optim.py

Lines changed: 111 additions & 79 deletions
@@ -87,6 +87,7 @@ def __init__(
         overlap_broadcast=False,
         grad_scal_cfg: Config = None,
         zero_cfg: Config = None,
+        use_fp16: bool = True,
     ):
         # DynamicGradScaler related args
         initial_scale = grad_scal_cfg.fp16.initial_scale
@@ -104,6 +105,7 @@ def __init__(
 
         super().__init__(optim=optimizer)
 
+        self.use_fp16 = use_fp16
         self._dtype = self.optim.param_groups[0]["params"][0].dtype
         self._cpu_offload = cpu_offload
         self._zero_local_rank = gpc.get_local_rank(ParallelMode.ZERO1)
@@ -125,14 +127,18 @@ def __init__(
         self._reduce_bucket_size = reduce_bucket_size
 
         # gradient scaler
-        self.grad_scaler = DynamicGradScaler(
-            initial_scale=initial_scale,
-            min_scale=min_scale,
-            growth_factor=growth_factor,
-            backoff_factor=backoff_factor,
-            growth_interval=growth_interval,
-            hysteresis=hysteresis,
-            max_scale=max_scale,
+        self.grad_scaler = (
+            DynamicGradScaler(
+                initial_scale=initial_scale,
+                min_scale=min_scale,
+                growth_factor=growth_factor,
+                backoff_factor=backoff_factor,
+                growth_interval=growth_interval,
+                hysteresis=hysteresis,
+                max_scale=max_scale,
+            )
+            if self.use_fp16
+            else None
         )
         self._found_overflow = torch.cuda.FloatTensor([0], device=get_current_device())
 
@@ -176,11 +182,14 @@ def __init__(
                 for param in params:
                     self._param_store.set_param_to_rank(param, rank)
 
+            # flatten the reordered tensors
             # move to cpu to make room to create the flat tensor
+            # Even for fp32 training, we will still flattend the tensor,
+            # which will not increase the use of GPU memory,
+            # and can improve the efficiency of broadcasting.
             for param in group_params:
                 param.data = param.data.cpu()
 
-            # flatten the reordered tensors
             for rank in range(self._zero_world_size):
                 # No flat fp16 buffer is allocated if the process has no parameters.
                 if rank not in self.param_group_no_params_ranks[group_id]:
@@ -194,19 +203,25 @@ def __init__(
             # create a copy of fp32 weights of the parameters for which this rank is responsible
             # No flat fp32 buffer is allocated if the process has no parameters.
             if self.param_group_has_params[group_id]:
-                fp16_flat_current_rank = self._param_store.get_flat_fp16_param_by_rank_group(
-                    self._zero_local_rank, group_id
-                )
-                fp32_flat_current_rank = fp16_flat_current_rank.float()
-                device = "cpu" if self._cpu_offload else get_current_device()
-                fp32_flat_current_rank = fp32_flat_current_rank.to(device)
-                fp32_flat_current_rank.requires_grad = True
-                self._fp32_flat_param_groups_of_current_rank[group_id] = fp32_flat_current_rank
-
-                # need to replace the params in the `params` field in the optimizer
-                # so that when the optimizer calls step(), it only updates the tensors
-                # managed by this data parallel rank
-                param_group["params"] = [fp32_flat_current_rank]
+                if self.use_fp16:
+                    fp16_flat_current_rank = self._param_store.get_flat_fp16_param_by_rank_group(
+                        self._zero_local_rank, group_id
+                    )
+                    fp32_flat_current_rank = fp16_flat_current_rank.float()
+                    device = "cpu" if self._cpu_offload else get_current_device()
+                    fp32_flat_current_rank = fp32_flat_current_rank.to(device)
+                    fp32_flat_current_rank.requires_grad = True
+                    self._fp32_flat_param_groups_of_current_rank[group_id] = fp32_flat_current_rank
+
+                    # need to replace the params in the `params` field in the optimizer
+                    # so that when the optimizer calls step(), it only updates the tensors
+                    # managed by this data parallel rank
+                    param_group["params"] = [fp32_flat_current_rank]
+                else:
+                    # use fp32
+                    param_group["params"] = self._param_store.get_fp16_params_by_rank_group(
+                        self._zero_local_rank, group_id
+                    )
 
             # set reduction state
             for param in self._fp16_param_groups[group_id]:
@@ -243,7 +258,10 @@ def dtype(self):
 
     @property
     def loss_scale(self):
-        return self.grad_scaler.scale
+        if self.grad_scaler is None:
+            return 1
+        else:
+            return self.grad_scaler.scale
 
     @property
     def num_param_groups(self):
@@ -533,7 +551,8 @@ def _step(self, closure=None):
             norm_groups.append(norm_group)
 
         loss_scale = float(self.loss_scale.item()) # backup
-        self.grad_scaler.update(found_inf)
+        if self.grad_scaler:
+            self.grad_scaler.update(found_inf)
         # update loss scale if overflow occurs
         if found_inf:
             if gpc.is_rank_for_log():
@@ -552,21 +571,30 @@ def _step(self, closure=None):
                 continue
             gradients = self._grad_store.get_averaged_gradients_by_group(group_id)
 
-            # create flat gradient for the flat fp32 params
-            fp16_avg_grads = gradients
-            flat_fp16_avg_grads = flatten(fp16_avg_grads)
+            if self.use_fp16:
+                # create flat gradient for the flat fp32 params
+                fp16_avg_grads = gradients
+                flat_fp16_avg_grads = flatten(fp16_avg_grads)
 
-            dtype = self._fp32_flat_param_groups_of_current_rank[group_id].dtype
-            flat_fp32_avg_grads = flat_fp16_avg_grads.to(dtype)
+                dtype = self._fp32_flat_param_groups_of_current_rank[group_id].dtype
+                flat_fp32_avg_grads = flat_fp16_avg_grads.to(dtype)
 
-            param_shape = self._fp32_flat_param_groups_of_current_rank[group_id].shape
-            assert (
-                param_shape == flat_fp32_avg_grads.shape
-            ), f"fp32 param and grad have different shape {param_shape} vs {flat_fp32_avg_grads.shape}"
+                param_shape = self._fp32_flat_param_groups_of_current_rank[group_id].shape
+                assert (
+                    param_shape == flat_fp32_avg_grads.shape
+                ), f"fp32 param and grad have different shape {param_shape} vs {flat_fp32_avg_grads.shape}"
+
+                single_grad_partition_groups.append(flat_fp32_avg_grads)
+                device = self._fp32_flat_param_groups_of_current_rank[group_id].device
+                self._fp32_flat_param_groups_of_current_rank[group_id].grad = flat_fp32_avg_grads.to(device)
+            else:
+                assert len(gradients) == len(self.optim.param_groups[group_id]["params"]), (
+                    len(gradients),
+                    len(self.optim.param_groups[group_id]["params"]),
+                )
+                for g, p in zip(gradients, self.optim.param_groups[group_id]["params"]):
+                    p.grad = g
 
-            single_grad_partition_groups.append(flat_fp32_avg_grads)
-            device = self._fp32_flat_param_groups_of_current_rank[group_id].device
-            self._fp32_flat_param_groups_of_current_rank[group_id].grad = flat_fp32_avg_grads.to(device)
             self._grad_store._averaged_gradients[group_id] = []
             self._grad_store._averaged_gradients[group_id] = []
 
@@ -576,8 +604,9 @@ def _step(self, closure=None):
         global_norm = sum(norm_groups) ** 0.5
 
         # the following operations are performed only on the rank to which parameters are assigned.
-        if len(single_grad_partition_groups) != 0:
-            self._unscale_and_clip_grads(single_grad_partition_groups, global_norm, loss_scale)
+        if self.use_fp16:
+            if len(single_grad_partition_groups) != 0:
+                self._unscale_and_clip_grads(single_grad_partition_groups, global_norm, loss_scale)
 
         timer("cal_norm").stop()
         # update the parameters
@@ -588,15 +617,16 @@ def _step(self, closure=None):
         if self.has_params:
             self.optim.step()
             # release the fp32 grad
-            release_param_grad(self._fp32_flat_param_groups_of_current_rank.values())
-            # update fp16 partition updated by the current rank
-            for group_id in range(len(self._fp16_param_groups)):
-                if self.param_group_has_params[group_id]:
-                    fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(
-                        rank=self._zero_local_rank, group_id=group_id
-                    )
-                    fp32_param = self._fp32_flat_param_groups_of_current_rank[group_id]
-                    fp16_param.data.copy_(fp32_param)
+            if self.use_fp16:
+                release_param_grad(self._fp32_flat_param_groups_of_current_rank.values())
+                # update fp16 partition updated by the current rank
+                for group_id in range(len(self._fp16_param_groups)):
+                    if self.param_group_has_params[group_id]:
+                        fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(
+                            rank=self._zero_local_rank, group_id=group_id
+                        )
+                        fp32_param = self._fp32_flat_param_groups_of_current_rank[group_id]
+                        fp16_param.data.copy_(fp32_param)
 
         # TODO: support broadcast overlap
         self.broadcast_params(overlap=False)
@@ -614,8 +644,6 @@ def broadcast_params(self, overlap=False):
                 # The following operations are performed only on the rank to which parameters are assigned.
                 if rank not in self.param_group_no_params_ranks[group_id]:
                     fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(rank=rank, group_id=group_id)
-                    # grank = gpc.get_ranks_in_group(group_type)[rank] # need to convert to the global rank
-                    # assert grank == rank, f"{grank} == {rank}"
                     g_rank = gpc.get_ranks_in_group(self._broadcast_parallel_mode)[rank]
                     handle = dist.broadcast(
                         fp16_param, src=g_rank, group=gpc.get_group(ParallelMode.ZERO1), async_op=True
@@ -667,48 +695,52 @@ def clip_grad_norm(self, model, max_norm):
 
     def state_dict(self):
        states = {}
-        grad_scaler = self.grad_scaler.state_dict()
-        states["grad_scaler"] = grad_scaler
        optim_states = self.optim.state_dict()
        states["base_optim_states"] = optim_states
 
-        flat_fp32_weights = {}
-        for group_id, param in self._fp32_flat_param_groups_of_current_rank.items():
-            if self._zero_local_rank not in self.param_group_no_params_ranks[group_id]:
-                assert param.grad is None
-                flat_fp32_weights[group_id] = param
-        states["flat_fp32_weights"] = flat_fp32_weights
+        if self.use_fp16:
+            grad_scaler = self.grad_scaler.state_dict()
+            states["grad_scaler"] = grad_scaler
+
+            flat_fp32_weights = {}
+            for group_id, param in self._fp32_flat_param_groups_of_current_rank.items():
+                if self._zero_local_rank not in self.param_group_no_params_ranks[group_id]:
+                    assert param.grad is None
+                    flat_fp32_weights[group_id] = param
+            states["flat_fp32_weights"] = flat_fp32_weights
        states["zero_devide_optim_plan"] = self.params_per_rank_id_dict
 
        return states
 
    def load_state_dict(self, states):
        # TODO: Need to take into account the change in the number of DP.
-        assert "grad_scaler" in states, "Not found grad_scaler state!"
-        grad_scaler = states["grad_scaler"]
-        self.grad_scaler.load_state_dict(grad_scaler)
        optim_states = states["base_optim_states"]
        self.optim.load_state_dict(optim_states)
 
-        # load fp32 model weight.
-        flat_fp32_weights = states["flat_fp32_weights"]
-        assert set(flat_fp32_weights.keys()) == set(self._fp32_flat_param_groups_of_current_rank)
-        for group_id, param in flat_fp32_weights.items():
-            if self._zero_local_rank not in self.param_group_no_params_ranks[group_id]:
-                self_param = self._fp32_flat_param_groups_of_current_rank[group_id]
-                assert (
-                    self_param.shape == param.shape
-                ), f"The loaded parameter shape is inconsistent, {self_param.shape} != {param.shape}"
-                self_param.data.copy_(param.data)
-
-        # Load the fp16 model weights.
-        for group_id in range(len(self._fp16_param_groups)):
-            if self._zero_local_rank not in self.param_group_no_params_ranks[group_id]:
-                fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(
-                    rank=self._zero_local_rank, group_id=group_id
-                )
-                fp32_param = self._fp32_flat_param_groups_of_current_rank[group_id]
-                fp16_param.data.copy_(fp32_param)
+        if self.use_fp16:
+            assert "grad_scaler" in states, "Not found grad_scaler state!"
+            grad_scaler = states["grad_scaler"]
+            self.grad_scaler.load_state_dict(grad_scaler)
+
+            # load fp32 model weight.
+            flat_fp32_weights = states["flat_fp32_weights"]
+            assert set(flat_fp32_weights.keys()) == set(self._fp32_flat_param_groups_of_current_rank)
+            for group_id, param in flat_fp32_weights.items():
+                if self._zero_local_rank not in self.param_group_no_params_ranks[group_id]:
+                    self_param = self._fp32_flat_param_groups_of_current_rank[group_id]
+                    assert (
+                        self_param.shape == param.shape
+                    ), f"The loaded parameter shape is inconsistent, {self_param.shape} != {param.shape}"
+                    self_param.data.copy_(param.data)
+
+            # Load the fp16 model weights.
+            for group_id in range(len(self._fp16_param_groups)):
+                if self._zero_local_rank not in self.param_group_no_params_ranks[group_id]:
+                    fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(
+                        rank=self._zero_local_rank, group_id=group_id
+                    )
+                    fp32_param = self._fp32_flat_param_groups_of_current_rank[group_id]
+                    fp16_param.data.copy_(fp32_param)
 
        if "zero_devide_optim_plan" in states:
            self.params_per_rank_id_dict = states["zero_devide_optim_plan"]
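
The hunks above all key off a single use_fp16 flag: when it is True, the optimizer keeps a flat fp32 master copy of the fp16 parameters and runs a DynamicGradScaler; when it is False, the scaler becomes None, loss_scale reports 1, and the averaged fp32 gradients are attached directly to the original parameters before the inner optimizer steps. The snippet below is a minimal, self-contained sketch of that branching pattern only, not the InternLM implementation: the names ToyZeroOptimizer, master, and inner are hypothetical, and ZeRO partitioning, bucketing, broadcasting, and overflow handling are omitted.

import torch

class ToyZeroOptimizer:
    """Sketch of a wrapper that either keeps fp32 master weights (use_fp16=True)
    or hands fp32 gradients straight to the inner optimizer (use_fp16=False)."""

    def __init__(self, params, use_fp16=True, lr=1e-3):
        self.use_fp16 = use_fp16
        self.params = list(params)
        # no loss scaling is needed on the pure-fp32 path
        self.loss_scale = 1024.0 if use_fp16 else 1.0
        if use_fp16:
            # fp32 master copy of the (possibly fp16) working weights
            self.master = [p.detach().clone().float().requires_grad_(True) for p in self.params]
            self.inner = torch.optim.SGD(self.master, lr=lr)
        else:
            # fp32 training: optimize the working weights in place
            self.inner = torch.optim.SGD(self.params, lr=lr)

    def step(self):
        if self.use_fp16:
            # unscale the working-precision grads into the fp32 master grads
            for m, p in zip(self.master, self.params):
                m.grad = p.grad.detach().float() / self.loss_scale
            self.inner.step()
            # copy the updated master weights back into the working weights
            for m, p in zip(self.master, self.params):
                p.data.copy_(m.data)
        else:
            # no scaler and no master copy: grads already live on the fp32 params
            self.inner.step()

# usage, mirroring the use_fp16=False path added by this commit
w = torch.randn(4, requires_grad=True)
opt = ToyZeroOptimizer([w], use_fp16=False)
(w ** 2).sum().backward()
opt.step()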

train.py

Lines changed: 4 additions & 1 deletion
@@ -282,7 +282,10 @@ def initialize_optimizer(model: nn.Module):
     )
 
     optimizer = HybridZeroOptimizer(
-        naive_optimizer, grad_scal_cfg=gpc.config.grad_scaler, zero_cfg=gpc.config.hybrid_zero_optimizer
+        naive_optimizer,
+        grad_scal_cfg=gpc.config.grad_scaler,
+        zero_cfg=gpc.config.hybrid_zero_optimizer,
+        use_fp16=gpc.config.model.dtype is torch.float32,
     )
 
     beta2_scheduler = Beta2Scheduler(optimizer=naive_optimizer, **gpc.config.beta2_scheduler)
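
In train.py the new flag is derived from the configured model dtype. As a rough sketch of how such a flag is commonly wired up, where a hypothetical cfg dict stands in for gpc.config and the dtype test shown is one general convention rather than a claim about this commit's exact expression:

import torch

# hypothetical stand-in for gpc.config.model
cfg = {"model": {"dtype": torch.float16}}

# keep the loss scaler and fp32 master weights only for reduced-precision training
use_fp16 = cfg["model"]["dtype"] is not torch.float32

# the flag would then be forwarded to the optimizer wrapper, e.g.
# HybridZeroOptimizer(..., use_fp16=use_fp16)
print(use_fp16)  # True for float16/bfloat16 configs, False for float32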
