[2023.06.13] commit-2 · DarriusL/DRL-ExampleCode@355b281 · GitHub

Commit 355b281

[2023.06.13] commit-2
- Fixed a bug where A3C performance with the shared network in the framework improved slowly.
- Fixed a bug during training with the A3C non-shared network.
- Fixed an issue where the process terminated abnormally.
1 parent 4d0f985 commit 355b281

File tree

5 files changed: +126 -25 lines changed

README.md

Lines changed: 6 additions & 0 deletions
@@ -156,6 +156,12 @@ python executor.py -cfg='./config/ppo/a2c_ppo_unshared_gae_cartpole_onbatch.json
 python executor.py -cfg='./cache/data/ppo_a2c/cartpole/[-opt-]/config.json' --mode='test'
 ```

+- a3c
+
+```shell
+python executor.py -cfg='./config/a3c/a3c_shared_nstep_cartpole_onbatch.json' --mode='train'
+```
+


 ## Refrence

agent/algorithm/actor_critic.py

Lines changed: 20 additions & 9 deletions
@@ -70,6 +70,7 @@ def init_net(self, net_cfg, optim_cfg, lr_schedule_cfg, in_dim, out_dim, max_epo
                 optimizer = optimizer,
                 lr_schedule = lr_schedule
             ));
+            self.optim_net = acnet;
             glb_var.get_value('var_reporter').add('lr', self.optimizer.param_groups[0]["lr"]);
         else:
             util.set_attr(self, dict(
@@ -78,6 +79,7 @@ def init_net(self, net_cfg, optim_cfg, lr_schedule_cfg, in_dim, out_dim, max_epo
                 optimizers = optimizer,
                 lr_schedules = lr_schedule
             ));
+            self.optim_nets = acnet;
             glb_var.get_value('var_reporter').add('actor-lr', self.optimizers[0].param_groups[0]["lr"]);
             glb_var.get_value('var_reporter').add('critic-lr', self.optimizers[-1].param_groups[0]["lr"]);

@@ -123,7 +125,7 @@ def __set_shared_param(net, shared_net):
             if self.is_ac_shared:
                 __set_shared_param(self.acnet, self.shared_net);
             else:
-                for net, shared_net in zip(self.acnet, self.shared_nets):
+                for net, shared_net in zip(self.acnets, self.shared_nets):
                     __set_shared_param(net, shared_net);
         else:
             pass;
@@ -221,6 +223,21 @@ def train_step(self, batch):
         v_preds = self._cal_v(batch['states']);
         advs, v_tgts = self._cal_advs_and_v_tgts(batch, v_preds);
         self._train_main(batch, advs, v_tgts);
+
+    def _optim_net(self, loss, net, optimizer = None):
+        ''''''
+        def __optim(loss, net, optimizer):
+            loss.backward();
+            torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm = 0.5);
+            self._set_shared_grads();
+            optimizer.step();
+        if optimizer is None:
+            optimizer = self.optimizer;
+        if not self.is_asyn:
+            __optim(loss, net, optimizer);
+        else:
+            with glb_var.get_value('lock'):
+                __optim(loss, net, optimizer);

     def _train_main(self, batch, advs, v_tgts):
         ''''''
@@ -232,17 +249,11 @@ def _train_main(self, batch, advs, v_tgts):
         if self.is_ac_shared:
             loss = policy_loss + value_loss;
             self.optimizer.zero_grad();
-            loss.backward();
-            torch.nn.utils.clip_grad_norm_(self.acnet.parameters(), max_norm = 0.5);
-            self._set_shared_grads();
-            self.optimizer.step();
+            self._optim_net(loss, self.acnet, self.optimizer);
         else:
             for net, optimzer, loss in zip(self.acnets, self.optimizers, [policy_loss, value_loss]):
                 optimzer.zero_grad();
-                loss.backward();
-                torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm = 0.5);
-                self._set_shared_grads();
-                optimzer.step();
+                self._optim_net(loss, net, optimzer);
             loss = policy_loss + value_loss;
         logger.debug(f'ActorCritic Total loss:{loss.item()}');
         if hasattr(torch.cuda, 'empty_cache'):
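The key change above is the new `_optim_net` helper: the backward pass, gradient clipping, the copy of local gradients to the shared network, and the optimizer step are funneled through one place and, in the asynchronous case, serialized with the global lock stored in `glb_var`. Below is a minimal, self-contained sketch of that update pattern, assuming a toy linear model, an Adam optimizer, and a `threading.Lock` stand-in for the shared lock; these names are illustrative, not the repository's objects.

```python
import threading

import torch

lock = threading.Lock()        # stand-in for glb_var.get_value('lock')
net = torch.nn.Linear(4, 2)    # stand-in for the local actor-critic net
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)

def optim_step(loss, net, optimizer, is_asyn=True):
    '''Backward, clip, and step, optionally serialized across workers.'''
    def _optim():
        loss.backward()
        # Clip gradients before they would be copied to the shared parameters
        # (the real code calls self._set_shared_grads() at this point).
        torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=0.5)
        optimizer.step()
    if is_asyn:
        with lock:             # one worker updates the shared weights at a time
            _optim()
    else:
        _optim()

loss = net(torch.randn(8, 4)).pow(2).mean()
optimizer.zero_grad()
optim_step(loss, net, optimizer)
```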

config/a3c/a3c_shared_nstep_cartpole_onbatch.json

Lines changed: 1 addition & 1 deletion
@@ -14,7 +14,7 @@
                 "star_epoch":0,
                 "end_epoch":0
             },
-            "n_step_returns":64,
+            "n_step_returns":32,
             "lbd":null
         },
         "net_cfg":{
Lines changed: 75 additions & 0 deletions
@@ -0,0 +1,75 @@
+{
+    "agent_cfg":{
+        "algorithm_cfg":{
+            "name":"A3C",
+            "var_schedule_cfg":null,
+            "gamma":0.99,
+            "rets_mean_baseline":false,
+            "policy_loss_var":1,
+            "value_loss_var":0.7,
+            "entropy_reg_var_cfg":{
+                "name":"fixed",
+                "var_start":0.01,
+                "var_end":0.01,
+                "star_epoch":0,
+                "end_epoch":0
+            },
+            "n_step_returns":null,
+            "lbd":0.95
+        },
+        "net_cfg":{
+            "actor_net_cfg":{
+                "name":"MLPNet",
+                "hid_layers":[32],
+                "hid_layers_activation":"Selu"
+            },
+            "critic_net_cfg":{
+                "name":"MLPNet",
+                "hid_layers":[32],
+                "hid_layers_activation":"Selu"
+            }
+        },
+        "optimizer_cfg":{
+            "actor_optim_cfg":{
+                "name":"adam",
+                "lr":1e-3,
+                "weight_decay": 1e-08,
+                "betas": [
+                    0.9,
+                    0.999
+                ]
+            },
+            "critic_optim_cfg":{
+                "name":"adam",
+                "lr":1e-3,
+                "weight_decay": 1e-08,
+                "betas": [
+                    0.9,
+                    0.999
+                ]
+            }
+        },
+        "lr_schedule_cfg":null,
+        "memory_cfg":{
+            "name":"OnPolicyBatch"
+        },
+        "max_epoch":10000,
+        "explore_times_per_train":1,
+        "train_exp_size":64,
+        "batch_learn_times_per_train":4,
+        "asyn_num":3
+    },
+    "env":{
+        "name":"CartPole",
+        "solved_total_reward":99900,
+        "finish_total_reward":100000,
+        "survival_T":100000
+    },
+    "model_path":null,
+    "is_gpu_available":false,
+    "valid":{
+        "valid_step":100,
+        "valid_times":5,
+        "not_improve_finish_step":5
+    }
+}
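The 75-line config added here (its file path is not preserved in this extract) uses separate actor and critic MLPs and sets "n_step_returns" to null with "lbd" at 0.95, i.e. advantages presumably computed with GAE using λ = 0.95 and γ = 0.99. For reference, here is a generic GAE sketch under the usual recursion; it is illustrative only and not the repository's `_cal_advs_and_v_tgts`.

```python
import torch

def gae_advantages(rewards, values, next_values, dones, gamma=0.99, lbd=0.95):
    """Generalized Advantage Estimation:
    delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_t) - V(s_t)
    A_t     = delta_t + gamma * lbd * (1 - done_t) * A_{t+1}
    """
    T = rewards.shape[0]
    advs = torch.zeros(T)
    running = 0.0
    for t in reversed(range(T)):
        delta = rewards[t] + gamma * next_values[t] * (1 - dones[t]) - values[t]
        running = delta + gamma * lbd * (1 - dones[t]) * running
        advs[t] = running
    return advs

# Tiny smoke test with random value estimates and no terminal states.
T = 5
print(gae_advantages(torch.rand(T), torch.rand(T), torch.rand(T), torch.zeros(T)))
```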

room/system/onpolicy.py

Lines changed: 24 additions & 15 deletions
@@ -195,7 +195,7 @@ def init_sys(self, rank, shared_alg, optimzer):
         );
         self.agent.algorithm.set_shared_net(shared_alg);

-    def train(self, lock, stop_event, cnt, rank, shared_alg, optimzer):
+    def train(self, lock, stop_event, cnt, end_cnt, rank, shared_alg, optimzer):
         ''''''
         self.init_sys(rank, shared_alg, optimzer);
         for epoch in range(self.agent.max_epoch):
@@ -213,9 +213,11 @@ def train(self, lock, stop_event, cnt, rank, shared_alg, optimzer):
             self.agent.algorithm.update();
             with lock:
                 cnt.value += 1;
-        logger.info(f'Process {self.rank} end.')
+        logger.info(f'Process {self.rank} end.');
+        with lock:
+            end_cnt.value += 1;

-    def valid(self, lock, stop_event, cnt, rank, shared_alg, optimzer):
+    def valid(self, lock, stop_event, cnt, end_cnt, rank, shared_alg, optimzer):
         ''''''
         self.init_sys(rank, shared_alg, optimzer)
         while True:
@@ -227,17 +229,22 @@ def valid(self, lock, stop_event, cnt, rank, shared_alg, optimzer):
             if self._valid_epoch(cnt_value):
                 stop_event.set();
                 break;
+            if end_cnt == self.rank:
+                break;
             time.sleep(60);
-        #plot rets
-        util.single_plot(
-            np.arange(len(self.rets_mean_valid)) + 1,
-            self.rets_mean_valid,
-            'valid_times', 'mean_rets', self.save_path + '/mean_rets.png');
-        #plot total rewards
-        util.single_plot(
-            np.arange(len(self.total_rewards_valid)) + 1,
-            self.total_rewards_valid,
-            'valid_times', 'rewards', self.save_path + '/rewards.png');
+        logger.info(f'Saved Model Information:\nSolved: [{self.best_solved}] - Mean total rewards: [{self.max_total_rewards}]'
+                    f'\nSaved path:{self.save_path}');
+        if end_cnt != self.rank:
+            #plot rets
+            util.single_plot(
+                np.arange(len(self.rets_mean_valid)) + 1,
+                self.rets_mean_valid,
+                'valid_times', 'mean_rets', self.save_path + '/mean_rets.png');
+            #plot total rewards
+            util.single_plot(
+                np.arange(len(self.total_rewards_valid)) + 1,
+                self.total_rewards_valid,
+                'valid_times', 'rewards', self.save_path + '/rewards.png');


 class OnPolicyAsynSystem(OnPolicySystem):
@@ -257,19 +264,21 @@ def train(self):
         subvalidsystem = copy.deepcopy(subtrainsystems[-1]);
         del subtrainsystems[-1];
         cnt = torch.multiprocessing.Value('i', 0);
+        end_cnt = torch.multiprocessing.Value('i', 0);
         lock = torch.multiprocessing.Lock();
+        glb_var.set_value('lock', lock);
         stop_event = torch.multiprocessing.Event();
         processes = [];
         for rank, sys in enumerate(subtrainsystems):
             p = torch.multiprocessing.Process(
                 target = sys.train,
-                args = (lock, stop_event, cnt, rank, self.agent.algorithm, optimizer)
+                args = (lock, stop_event, cnt, end_cnt, rank, self.agent.algorithm, optimizer)
             );
             p.start();
             processes.append(p);
         p_valid = torch.multiprocessing.Process(
             target = subvalidsystem.valid,
-            args = (lock, stop_event, cnt, rank + 1, self.agent.algorithm, optimizer)
+            args = (lock, stop_event, cnt, end_cnt, rank + 1, self.agent.algorithm, optimizer)
         );
         p_valid.start();
         processes.append(p_valid);
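The fix for abnormal termination adds a shared end_cnt counter: each training process increments it under the lock when it finishes, and the validation process stops waiting once every trainer has ended. A minimal, self-contained sketch of that coordination pattern with torch.multiprocessing is below; the worker body is a placeholder, and the real system additionally passes stop_event, cnt, and the shared algorithm to each process.

```python
import time

import torch.multiprocessing as mp

def worker(lock, end_cnt, rank):
    # ... the real training loop runs here ...
    with lock:
        end_cnt.value += 1   # signal "this trainer has ended"
    print(f'Process {rank} end.')

def validator(end_cnt, num_workers):
    # Poll the shared counter instead of waiting forever on dead trainers.
    while True:
        if end_cnt.value == num_workers:
            break
        time.sleep(0.1)
    print('Validator exits once all trainers are done.')

if __name__ == '__main__':
    num_workers = 3
    lock = mp.Lock()
    end_cnt = mp.Value('i', 0)   # shared integer counter
    trainers = [mp.Process(target=worker, args=(lock, end_cnt, r))
                for r in range(num_workers)]
    valid = mp.Process(target=validator, args=(end_cnt, num_workers))
    for p in trainers + [valid]:
        p.start()
    for p in trainers + [valid]:
        p.join()
```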
