[2023.06.08] commit-1 · DarriusL/DRL-ExampleCode@8ea3569 · GitHub

Commit 8ea3569

[2023.06.08] commit-1

1. Completed the A2C (PPO) test.
2. Optimized the code structure of REINFORCE (PPO).
3. Optimized env.openai_gym.

1 parent 6c74443 · commit 8ea3569

File tree

8 files changed: 196 additions, 27 deletions


README.md

Lines changed: 15 additions & 5 deletions
@@ -104,7 +104,8 @@ options:
 
 ### qiuck start
 
-reinforce
+- reinforce
+
 
 ```shell
 python executor.py -cfg='./config/reinforce/reinforce_cartpole_mc.json' --mode='train'
@@ -113,15 +114,17 @@ python executor.py -cfg='./config/reinforce/reinforce_entropyreg_cartpole_onbatc
 python executor.py -cfg='./cache/data/reinforce/cartpole/[-opt-]/config.json' --mode='test'
 ```
 
-sarsa
+- sarsa
+
 
 ```shell
 python executor.py -cfg='./config/sarsa/sarsa_cartpole_onbatch.json' --mode='train'
 python executor.py -cfg='./config/sarsa/sarsa_cartpole_mc.json' --mode='train'
 python executor.py -cfg='./cache/data/sarsa/cartpole/[-opt-]/config.json' --mode='test'
 ```
 
-dqn
+- dqn
+
 
 ```shell
 python executor.py -cfg='./config/dqn/dqn_cartpole_off.json' --mode='train'
@@ -130,7 +133,8 @@ python executor.py -cfg='./config/dqn/doubledqn_cartpole_off.json' --mode='train
 python executor.py -cfg='./config/dqn/doubledqn_cartpole_per.json' --mode='train'
 ```
 
-a2c
+- a2c
+
 
 ```shell
 python executor.py -cfg='./config/a2c/a2c_shared_nstep_cartpole_onbatch.json' --mode='train'
@@ -139,11 +143,17 @@ python executor.py -cfg='./config/a2c/a2c_unshared_gae_cartpole_onbatch.json' --
 python executor.py -cfg='./cache/data/a2c/cartpole/[-opt-]/config.json' --mode='test'
 ```
 
-ppo
+- ppo
+
+notes: When A2C (PPO) uses the n-step advantage estimate, the model parameters may become NaN because of vanishing or exploding gradients, so the model is restricted to the GAE calculation.
 
 ```shell
 python executor.py -cfg='./config/ppo/reinforce_ppo_cartpole_mc.json' --mode='train'
 python executor.py -cfg='./config/ppo/reinforce_ppo_cartpole_onbatch.json' --mode='train'
+
+python executor.py -cfg='./config/ppo/a2c_ppo_shared_gae_cartpole_onbatch.json' --mode='train'
+python executor.py -cfg='./config/ppo/a2c_ppo_unshared_gae_cartpole_onbatch.json' --mode='train'
+python executor.py -cfg='./cache/data/ppo_a2c/cartpole/[-opt-]/config.json' --mode='test'
 ```
 
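For context on the GAE restriction in the README note above: the generalized advantage estimate replaces the raw n-step bootstrapped advantage with a λ-weighted sum of TD residuals, trading a little bias for lower-variance advantage targets. The sketch below is a minimal, standalone illustration of that recursion (plain PyTorch; it is not the repository's alg_util implementation, and the tensor names are assumptions):

```python
import torch

def gae_advantages(rewards, dones, v_preds, next_v_pred, gamma=0.99, lbd=0.95):
    '''Minimal GAE sketch. rewards, dones, v_preds: 1-D tensors of length T
    from an on-policy rollout; next_v_pred: value of the state after step T.'''
    T = rewards.shape[0]
    advs = torch.zeros(T)
    # V(s_2 .. s_{T+1}) aligned with each step's successor state
    v_next = torch.cat([v_preds[1:], next_v_pred.reshape(1)])
    not_done = 1.0 - dones.float()
    gae = 0.0
    for t in reversed(range(T)):
        # TD residual: delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_t) - V(s_t)
        delta = rewards[t] + gamma * v_next[t] * not_done[t] - v_preds[t]
        # lambda-weighted accumulation, reset at episode boundaries
        gae = delta + gamma * lbd * not_done[t] * gae
        advs[t] = gae
    return advs
```

The gamma and lbd defaults here mirror the values in the new PPO-A2C configs (0.99 and 0.95).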

agent/algorithm/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -30,6 +30,8 @@ def get_alg(alg_cfg):
         return ActorCritic(alg_cfg);
     elif alg_cfg['name'].lower() == 'ppo_reinforce':
         return ppo.Reinforce(alg_cfg);
+    elif alg_cfg['name'].lower() in ['ppo_a2c']:
+        return ppo.ActorCritic(alg_cfg);
     else:
         logger.error(f'Type of algorithm [{alg_cfg["name"]}] is not supported.\nPlease replace or add by yourself.')
         raise callback.CustomException('NetCfgTypeError');
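
As a usage illustration of the new dispatch branch, the sketch below shows how one of the new configs would reach it; the JSON path and the algorithm_cfg layout come from this commit, while the standalone loading code is only a hypothetical stand-in for what executor.py does:

```python
import json

# Load one of the new PPO-A2C configs added in this commit.
with open('./config/ppo/a2c_ppo_shared_gae_cartpole_onbatch.json') as f:
    cfg = json.load(f)

alg_cfg = cfg['agent_cfg']['algorithm_cfg']   # {"name": "PPO_A2C", ...}

# get_alg lower-cases the name, so 'PPO_A2C' matches the new ['ppo_a2c'] branch
# and returns ppo.ActorCritic(alg_cfg).
assert alg_cfg['name'].lower() in ['ppo_a2c']
```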

agent/algorithm/actor_critic.py

Lines changed: 7 additions & 7 deletions
@@ -105,8 +105,6 @@ def _cal_mc_advs_and_v_tgts(self, batch, v_preds):
         '''Estimate Q using Monte Carlo simulations and use this to calculate advantages
         '''
         #v_preds:[batch]
-        #advs and v_tgt don't need to accumulate grad
-        v_preds = v_preds.detach()
         #rets:[batch]
         #Mixed trajectory, cannot use [fast]
         rets = alg_util.cal_returns(batch['rewards'], batch['dones'], self.gamma, fast = False);
@@ -118,8 +116,6 @@ def _cal_mc_advs_and_v_tgts(self, batch, v_preds):
     def _cal_nstep_advs_and_v_tgts(self, batch, v_preds):
         '''Using temporal difference learning to estimate Q and then calculate the advantage'''
         #v_preds:[batch]
-        #advs and v_tgt don't need to accumulate grad
-        v_preds = v_preds.detach()
         with torch.no_grad():
             #is a value
             next_v_pred = self._cal_v(batch['states'][-1]);
@@ -132,8 +128,6 @@ def _cal_nstep_advs_and_v_tgts(self, batch, v_preds):
     def _cal_gae_advs_and_v_tgts(self, batch, v_preds):
         '''Calculate GAE and estimate v_tgt'''
         #v_preds:[batch]
-        #advs and v_tgt don't need to accumulate grad
-        v_preds = v_preds.detach()
         with torch.no_grad():
             #[1]
             next_v_pred = self._cal_v(batch['states'][-1].unsqueeze(0));
@@ -171,9 +165,15 @@ def update(self):
 
 
     def train_step(self, batch):
+        ''''''
+        with torch.no_grad():
+            v_preds = self._cal_v(batch['states']);
+            advs, v_tgts = self._cal_advs_and_v_tgts(batch, v_preds);
+        self._train_main(batch, advs, v_tgts);
+
+    def _train_main(self, batch, advs, v_tgts):
         ''''''
         v_preds = self._cal_v(batch['states']);
-        advs, v_tgts = self._cal_advs_and_v_tgts(batch, v_preds);
         policy_loss = self._cal_policy_loss(batch, advs);
         self._check_nan(policy_loss);
         value_loss = self._cal_value_loss(v_preds, v_tgts);
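
The refactor above separates target computation from the optimization step: train_step now produces advantages and value targets under torch.no_grad() (which is why the per-estimator detach() calls are removed), and _train_main re-runs the critic with gradients enabled for the losses. A generic sketch of that pattern (plain PyTorch stand-ins, not the repository's classes):

```python
import torch
import torch.nn as nn

critic = nn.Linear(4, 1)          # stand-in value network for 4-dim states
states = torch.randn(8, 4)        # small batch of states
v_tgts = torch.randn(8)           # pretend value targets (e.g. from GAE)

# 1) Constants for the update: no autograd graph is built here.
with torch.no_grad():
    v_preds_const = critic(states).squeeze(-1)   # used only to form advs / v_tgts

# 2) Training pass: recompute values with grad so the loss can backpropagate,
#    mirroring _train_main recomputing v_preds after train_step's no_grad block.
v_preds = critic(states).squeeze(-1)
value_loss = nn.functional.mse_loss(v_preds, v_tgts)
value_loss.backward()
```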

agent/algorithm/ppo.py

Lines changed: 21 additions & 13 deletions
@@ -4,7 +4,7 @@
 from agent.algorithm import reinforce, alg_util, actor_critic
 from agent.net import *
 from agent.memory import *
-from lib import glb_var
+from lib import glb_var, callback
 import copy, torch
 
 logger = glb_var.get_value('log')
@@ -16,7 +16,7 @@ def __init__(self, algorithm_cfg) -> None:
         super().__init__(algorithm_cfg);
         self.clip_var_var_shedule = alg_util.VarScheduler(algorithm_cfg['clip_var_cfg']);
         self.clip_var = self.clip_var_var_shedule.var_start;
-        glb_var.get_value('var_reporter').add('Clip coefficient', self.clip_var);
+        glb_var.get_value('var_reporter').add('Policy gradient clipping coefficient', self.clip_var);
         self.batch_spliter = get_batch_split(self.batch_split_type);
 
     def init_net(self, net_cfg, optim_cfg, lr_schedule_cfg, in_dim, out_dim, max_epoch):
@@ -89,6 +89,14 @@ def _cal_loss(self, batch, rets):
         '''
         #[batch_size, out_dim]
         action_batch_logits = self._cal_action_pd(batch['states']);
+        # if torch.any(torch.isnan(action_batch_logits)):
+        #     print(action_batch_logits);
+        #     print(batch['states']);
+        #     if self.is_ac_shared:
+        #         torch.save(self.acnet, './cache/problem.model');
+        #     else:
+        #         torch.save(self.acnets[0], './cache/problem.model');
+        #     raise RuntimeError;
         action_pd_batch = torch.distributions.Categorical(logits = action_batch_logits);
         #[batch_size]
         log_probs = action_pd_batch.log_prob(batch['actions']);
@@ -125,21 +133,18 @@ def train_step(self, batch):
         '''Train network'''
         subbatches = self.batch_spliter(batch, self.batch_num, add_origin = True);
         for subbatch in subbatches:
-            rets = self._cal_rets(subbatch);
-            loss = self._cal_loss(subbatch, rets);
-            self.optimizer.zero_grad();
-            self._check_nan(loss);
-            loss.backward();
-            self.optimizer.step();
-            if hasattr(torch.cuda, 'empty_cache'):
-                torch.cuda.empty_cache();
-            logger.debug(f'Actor loss: [{loss.item()}]');
+            super().train_step(subbatch);
 
 class ActorCritic(Reinforce, actor_critic.ActorCritic):
     def __init__(self, algorithm_cfg) -> None:
         actor_critic.ActorCritic.__init__(self, algorithm_cfg);
         Reinforce.__init__(self, algorithm_cfg);
         self.is_onpolicy = True;
+        #notes:ppo use gae for calculate advs
+        self._cal_advs_and_v_tgts = self._cal_gae_advs_and_v_tgts;
+        if self.lbd is None:
+            logger.error(f'ActorCritic(PPO) use gae to calculate advantages, but no lambda value is set.');
+            raise callback.CustomException('CfgError');
 
     def init_net(self, net_cfg, optim_cfg, lr_schedule_cfg, in_dim, out_dim, max_epoch):
         actor_critic.ActorCritic.init_net(self, net_cfg, optim_cfg, lr_schedule_cfg, in_dim, out_dim, max_epoch);
@@ -197,9 +202,12 @@ def _cal_value_loss(self, v_preds, v_tgts):
         return actor_critic.ActorCritic._cal_value_loss(self, v_preds, v_tgts);
 
     def train_step(self, batch):
-        subbatches = self.batch_spliter(batch, self.batch_num, add_origin = True);
+        with torch.no_grad():
+            v_preds = self._cal_v(batch['states']);
+            batch['advs'], batch['v_tgt'] = self._cal_advs_and_v_tgts(batch, v_preds);
+        subbatches = self.batch_spliter(batch, self.batch_num, add_origin = False);
        for subbatch in subbatches:
-            actor_critic.ActorCritic.train_step(self, subbatch);
+            actor_critic.ActorCritic._train_main(self, subbatch, subbatch['advs'], subbatch['v_tgt']);
 
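
For reference, the clipped surrogate objective that clip_var parameterizes, written out as a standalone sketch (generic PPO math rather than the repository's _cal_policy_loss; the tensor names are assumptions):

```python
import torch

def ppo_clipped_policy_loss(log_probs, old_log_probs, advs, clip_var=0.1):
    '''L_clip = -E[ min(r_t * A_t, clip(r_t, 1 - eps, 1 + eps) * A_t) ],
    where r_t = pi_theta(a_t|s_t) / pi_theta_old(a_t|s_t) and eps is clip_var.'''
    ratio = torch.exp(log_probs - old_log_probs)                       # importance ratio r_t
    surr1 = ratio * advs
    surr2 = torch.clamp(ratio, 1.0 - clip_var, 1.0 + clip_var) * advs
    return -torch.min(surr1, surr2).mean()
```

Computing v_preds, advs and v_tgt once per collected batch under no_grad and then reusing them across the random subbatches (batch_split_type "random", batch_num 4 in the new configs) follows the usual PPO recipe of several minibatch updates against a frozen old policy.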

config/ppo/a2c_ppo_shared_gae_cartpole_onbatch.json

Lines changed: 66 additions & 0 deletions
@@ -0,0 +1,66 @@
{
    "agent_cfg":{
        "algorithm_cfg":{
            "name":"PPO_A2C",
            "var_schedule_cfg":null,
            "gamma":0.99,
            "rets_mean_baseline":false,
            "policy_loss_var":1,
            "value_loss_var":0.5,
            "entropy_reg_var_cfg":{
                "name":"fixed",
                "var_start":0.01,
                "var_end":0.01,
                "star_epoch":0,
                "end_epoch":0
            },
            "clip_var_cfg":{
                "name":"fixed",
                "var_start":0.1,
                "var_end":0.1,
                "star_epoch":0,
                "end_epoch":0
            },
            "n_step_returns":null,
            "lbd":0.95,
            "batch_split_type":"random",
            "batch_num":4
        },
        "net_cfg":{
            "name":"SharedMLPNet",
            "body_hid_layers":[32],
            "body_out_dim":16,
            "hid_layers_activation":"Selu",
            "output_hid_layers":[16]
        },
        "optimizer_cfg":{
            "name":"adam",
            "lr":1e-3,
            "weight_decay": 1e-08,
            "betas": [
                0.9,
                0.999
            ]
        },
        "lr_schedule_cfg":null,
        "memory_cfg":{
            "name":"OnPolicyBatch"
        },
        "max_epoch":10000,
        "explore_times_per_train":4,
        "train_exp_size":64,
        "batch_learn_times_per_train":4
    },
    "env":{
        "name":"CartPole",
        "solved_total_reward":99900,
        "finish_total_reward":100000,
        "survival_T":100000
    },
    "model_path":null,
    "valid":{
        "valid_step":100,
        "valid_times":5,
        "not_improve_finish_step":5
    }
}
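
One hypothetical reading of the "SharedMLPNet" section of this config, to make the field names concrete: a SELU MLP body shared by a policy head and a value head. This is an illustrative sketch only, not the repository's SharedMLPNet class; CartPole's 4-dim observation and 2 actions are assumed for in_dim/out_dim:

```python
import torch
import torch.nn as nn

class SharedMLPSketch(nn.Module):
    def __init__(self, in_dim=4, out_dim=2,
                 body_hid_layers=(32,), body_out_dim=16, output_hid_layers=(16,)):
        super().__init__()
        # Shared body: in_dim -> 32 -> 16 with SELU ("hid_layers_activation": "Selu")
        layers, last = [], in_dim
        for h in list(body_hid_layers) + [body_out_dim]:
            layers += [nn.Linear(last, h), nn.SELU()]
            last = h
        self.body = nn.Sequential(*layers)
        # Output heads: 16 -> 16 -> out_dim (action logits) and 16 -> 16 -> 1 (state value)
        self.policy_head = nn.Sequential(
            nn.Linear(body_out_dim, output_hid_layers[0]), nn.SELU(),
            nn.Linear(output_hid_layers[0], out_dim))
        self.value_head = nn.Sequential(
            nn.Linear(body_out_dim, output_hid_layers[0]), nn.SELU(),
            nn.Linear(output_hid_layers[0], 1))

    def forward(self, states):
        z = self.body(states)
        return self.policy_head(z), self.value_head(z).squeeze(-1)

logits, values = SharedMLPSketch()(torch.randn(8, 4))   # [8, 2] logits, [8] values
```
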
config/ppo/a2c_ppo_unshared_gae_cartpole_onbatch.json

Lines changed: 82 additions & 0 deletions
@@ -0,0 +1,82 @@
{
    "agent_cfg":{
        "algorithm_cfg":{
            "name":"A2C",
            "var_schedule_cfg":null,
            "gamma":0.99,
            "rets_mean_baseline":false,
            "policy_loss_var":1,
            "value_loss_var":0.7,
            "entropy_reg_var_cfg":{
                "name":"fixed",
                "var_start":0.01,
                "var_end":0.01,
                "star_epoch":0,
                "end_epoch":0
            },
            "clip_var_cfg":{
                "name":"fixed",
                "var_start":0.1,
                "var_end":0.1,
                "star_epoch":0,
                "end_epoch":0
            },
            "n_step_returns":null,
            "lbd":0.95,
            "batch_split_type":"random",
            "batch_num":4
        },
        "net_cfg":{
            "actor_net_cfg":{
                "name":"MLPNet",
                "hid_layers":[32],
                "hid_layers_activation":"Selu"
            },
            "critic_net_cfg":{
                "name":"MLPNet",
                "hid_layers":[32],
                "hid_layers_activation":"Selu"
            }
        },
        "optimizer_cfg":{
            "actor_optim_cfg":{
                "name":"adam",
                "lr":1e-3,
                "weight_decay": 1e-08,
                "betas": [
                    0.9,
                    0.999
                ]
            },
            "critic_optim_cfg":{
                "name":"adam",
                "lr":1e-3,
                "weight_decay": 1e-08,
                "betas": [
                    0.9,
                    0.999
                ]
            }
        },
        "lr_schedule_cfg":null,
        "memory_cfg":{
            "name":"OnPolicyBatch"
        },
        "max_epoch":10000,
        "explore_times_per_train":4,
        "train_exp_size":64,
        "batch_learn_times_per_train":4
    },
    "env":{
        "name":"CartPole",
        "solved_total_reward":99900,
        "finish_total_reward":100000,
        "survival_T":100000
    },
    "model_path":null,
    "valid":{
        "valid_step":100,
        "valid_times":5,
        "not_improve_finish_step":5
    }
}
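
The unshared variant configures two separate MLPs and two Adam optimizers. A plain-PyTorch sketch of that split, using the values from this config (not the repository's MLPNet/optimizer factory; CartPole's 4-dim observation and 2 actions are assumed):

```python
import torch
import torch.nn as nn

actor = nn.Sequential(nn.Linear(4, 32), nn.SELU(), nn.Linear(32, 2))    # action logits
critic = nn.Sequential(nn.Linear(4, 32), nn.SELU(), nn.Linear(32, 1))   # state value

# One optimizer per network, mirroring actor_optim_cfg / critic_optim_cfg above.
actor_optim = torch.optim.Adam(actor.parameters(), lr=1e-3,
                               weight_decay=1e-08, betas=(0.9, 0.999))
critic_optim = torch.optim.Adam(critic.parameters(), lr=1e-3,
                                weight_decay=1e-08, betas=(0.9, 0.999))
```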

config/ppo/reinforce_ppo_cartpole_mc.json

Lines changed: 1 addition & 1 deletion
@@ -54,7 +54,7 @@
     },
     "model_path":null,
     "valid":{
-        "valid_step":100,
+        "valid_step":10,
         "valid_times":5,
         "not_improve_finish_step":5
     }

env/openai_gym.py

Lines changed: 2 additions & 1 deletion
@@ -87,7 +87,8 @@ def reset(self):
         '''Reset the env'''
         self.main_body.total_reward = 0;
         self.main_body.t = 0;
-        return self.main_body.env.reset();
+        state, _ = self.main_body.env.reset();
+        return state;
 
     def step(self, action):
         '''Change the env through the action'''
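
The reset change tracks the newer Gym/Gymnasium API, where env.reset() returns an (observation, info) tuple instead of a bare observation, so the wrapper now unpacks the tuple and returns only the state. A minimal sketch of the API difference (generic gym usage, not the repository's wrapper):

```python
import gym  # gym >= 0.26 (and gymnasium) use the tuple-returning API

env = gym.make('CartPole-v1')

# New-style reset: (observation, info) instead of just the observation.
state, info = env.reset()

# step() correspondingly returns five values with separate terminated/truncated
# flags; whether the wrapper also handles this is not shown in this diff.
obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
done = terminated or truncated
```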

0 commit comments