[2023.06.06] commit-1 · DarriusL/DRL-ExampleCode@a8b0242 · GitHub

Commit a8b0242

[2023.06.06] commit-1
1. Updated [OnPolicyMemory] and [OnPolicySystem] to accommodate experience acquisition by multiple actors (multiple exploration episodes per training step). 2. Fixed a bug in ActorCritic where the Monte Carlo advantage calculation could not be used.
1 parent 9290415 commit a8b0242

19 files changed: +166 −14 lines changed

README.md

Lines changed: 1 addition & 1 deletion

@@ -136,8 +136,8 @@ a2c
 
 ```shell
 python executor.py -cfg='./config/a2c/a2c_shared_nstep_cartpole_onbatch.json' --mode='train'
+python executor.py -cfg='./config/a2c/a2c_shared_mc_cartpole_mc.json' --mode='train'
 python executor.py -cfg='./config/a2c/a2c_unshared_gae_cartpole_onbatch.json' --mode='train'
 python executor.py -cfg='./cache/data/a2c/cartpole/[-opt-]/config.json' --mode='test'
-
 ```
 

agent/algorithm/actor_critic.py

Lines changed: 1 addition & 1 deletion

@@ -19,7 +19,7 @@ def __init__(self, algorithm_cfg) -> None:
         glb_var.get_value('var_reporter').add('Value loss coefficient', self.value_loss_var);
 
         #cal advs method
-        if self.n_step_returns is not None and self.lbd is not None:
+        if self.n_step_returns is None and self.lbd is None:
            self._cal_advs_and_v_tgts = self._cal_mc_advs_and_v_tgts;
         elif self.n_step_returns is not None and self.lbd is None:
            #use n-step
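
The inverted condition above is the bug fix described in the commit message: Monte Carlo advantage estimation is now selected when neither n_step_returns nor lbd is configured (as in the new MC config, where both are null); previously that branch was unreachable. For reference, a generic sketch of what a Monte Carlo advantage/value-target computation looks like (illustrative only, not the repo's _cal_mc_advs_and_v_tgts):

import torch

def mc_advs_and_v_tgts(rewards, v_preds, gamma):
    '''Illustrative Monte Carlo estimate: v_tgt_t = sum_k gamma^k * r_{t+k}, adv_t = v_tgt_t - V(s_t).'''
    v_tgts = torch.zeros_like(rewards)
    future_ret = 0.0
    for t in reversed(range(len(rewards))):
        future_ret = rewards[t] + gamma * future_ret
        v_tgts[t] = future_ret
    advs = v_tgts - v_preds
    return advs, v_tgts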

agent/algorithm/dqn.py

Lines changed: 1 addition & 1 deletion

@@ -96,7 +96,7 @@ def init_net(self, net_cfg, optim_cfg, lr_schedule_cfg, in_dim, out_dim, max_epo
         super().init_net(net_cfg, optim_cfg, lr_schedule_cfg, in_dim, out_dim, max_epoch);
         self.q_target_net = get_net(net_cfg, in_dim, out_dim, device = glb_var.get_value('device'));
         #Initialize q_target_net with q_net
-        self.net_updater.net_param_copy(self.q_net, self.q_target_net);
+        net_util.net_param_copy(self.q_net, self.q_target_net)
         self.net_updater.set_net(self.q_net, self.q_target_net);
         self.q_eval_net = self.q_target_net;
 

agent/algorithm/ppo.py

Lines changed: 15 additions & 0 deletions

@@ -0,0 +1,15 @@
+# @Time : 2023.06.06
+# @Author : Darrius Lei
+# @Email : darrius.lei@outlook.com
+from agent.algorithm import reinforce
+from agent import alg_util
+from lib import glb_var
+
+logger = glb_var.get_value('log')
+
+class Reinforce(reinforce.Reinforce):
+    def __init__(self, algorithm_cfg) -> None:
+        super().__init__(algorithm_cfg);
+        self.clip_var_var_shedule = alg_util.VarScheduler(algorithm_cfg['clip_var_cfg']);
+        self.clip_var = self.clip_var_var_shedule.var_start;
+        glb_var.get_value('var_reporter').add('Clip coefficient', self.clip_var)
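
Only the clip-coefficient schedule is wired up in this new file; the PPO surrogate loss itself is not part of this commit. As a reminder of what clip_var (the reported "Clip coefficient") is for, here is a generic sketch of the PPO clipped objective, not the repo's implementation:

import torch

def ppo_clip_loss(log_probs_new, log_probs_old, advs, clip_var):
    '''Generic PPO clipped surrogate loss (illustrative); clip_var plays the role of epsilon.'''
    ratio = torch.exp(log_probs_new - log_probs_old)
    surrogate = ratio * advs
    clipped_surrogate = torch.clamp(ratio, 1.0 - clip_var, 1.0 + clip_var) * advs
    # maximize the clipped surrogate -> minimize its negative mean
    return -torch.min(surrogate, clipped_surrogate).mean()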

agent/algorithm/reinforce.py

Lines changed: 1 addition & 0 deletions

@@ -27,6 +27,7 @@ def __init__(self, algorithm_cfg) -> None:
         if algorithm_cfg['entropy_reg_var_cfg'] is not None:
             self.entorpy_reg_var_shedule = alg_util.VarScheduler(algorithm_cfg['entropy_reg_var_cfg']);
             self.entorpy_reg_var = self.entorpy_reg_var_shedule.var_start;
+            glb_var.get_value('var_reporter').add('Entropy regularization coefficient', self.entorpy_reg_var)
         else:
             self.entorpy_reg_var_shedule = None;
 

agent/memory/onpolicy.py

Lines changed: 13 additions & 2 deletions

@@ -24,8 +24,15 @@ def __init__(self, memory_cfg) -> None:
         self.is_episodic_exp = True;
         #Experience data that needs to be stored
         self.exp_keys = ['states', 'actions', 'rewards', 'next_states', 'dones'];
+        #multiple expolre
+        self.explore_times = 1;
+        self.explore_cnt = 0;
         self.reset();
 
+    def multiple_explore(self, times):
+        ''''''
+        self.explore_times = times;
+
     def train(self):
         pass
 
@@ -40,6 +47,7 @@ def reset(self):
         self.exp_latest = [None] * len(self.exp_keys);
         self.cur_exp = {k: [] for k in self.exp_keys};
         self.stock = 0;
+        self.explore_cnt = 0;
 
     def update(self, state, action, reward, next_state, done):
         '''Add experience to experience memory
@@ -51,8 +59,11 @@ def update(self, state, action, reward, next_state, done):
 
         self.stock += 1;
         if done:
-            for key in self.exp_keys:
-                getattr(self, key).append(self.cur_exp[key]);
+            self.explore_cnt += 1;
+            if self.explore_cnt >= self.explore_times:
+                self.explore_cnt = 0;
+                for key in self.exp_keys:
+                    getattr(self, key).append(self.cur_exp[key]);
 
     def _batch_to_tensor(self, batch):
         '''Convert a batch to a format for torch training
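
With these changes, completed episodes are only committed to the memory's experience lists after explore_times episodes have finished, so one training step can consume experience from several exploration runs. A minimal driver sketch of the intended use (the actual OnPolicySystem code is not shown in this commit; agent.act/agent.train and the env API here are placeholders):

def collect_and_train(agent, memory, env, explore_times_per_train):
    # Tell the memory how many episodes to accumulate before committing them.
    memory.multiple_explore(explore_times_per_train)
    for _ in range(explore_times_per_train):
        state, done = env.reset(), False
        while not done:
            action = agent.act(state)
            next_state, reward, done = env.step(action)
            memory.update(state, action, reward, next_state, done)
            state = next_state
    # The training step now sees explore_times_per_train episodes of experience.
    agent.train()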

agent/net/net_util.py

Lines changed: 5 additions & 5 deletions

@@ -115,14 +115,18 @@ def get_mlpnet(hid_layers, activation_fn, in_dim, out_dim):
     ]
     return torch.nn.Sequential(*layers);
 
+def net_param_copy(src, tgt):
+    '''Copy network parameters from src to tgt'''
+    tgt.load_state_dict(src.state_dict());
+
 class NetUpdater():
     '''for updating the network'''
     def __init__(self, net_update_cfg) -> None:
         util.set_attr(self, net_update_cfg, except_type = dict);
         self.epoch = 0;
         #generate net update policy
         if self.name.lower() == 'replace':
-            self.updater = self.net_param_copy;
+            self.updater = net_param_copy;
         elif self.name.lower() == 'polyak':
             self.updater = self.net_param_polyak_update;
         else:
@@ -134,10 +138,6 @@ def set_net(self, src, tgt):
         self.src_net = src;
         self.tgt_net = tgt;
 
-    def net_param_copy(self, src, tgt):
-        '''Copy network parameters from src to tgt'''
-        tgt.load_state_dict(src.state_dict());
-
     def net_param_polyak_update(self, src, tgt):
         '''Polyak updata policy
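
net_param_copy is now a module-level function, so callers such as dqn.py can hard-copy parameters before NetUpdater.set_net has been called, while the 'replace' update policy simply points at the same function. A usage sketch (q_net, q_target_net and net_updater stand for the objects used in dqn.py; any two torch modules with the same architecture would do):

from agent.net import net_util

net_util.net_param_copy(q_net, q_target_net)   # one-off hard copy at initialization
net_updater.set_net(q_net, q_target_net)       # later 'replace'/'polyak' updates go through the updater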

config/a2c/a2c_shared_mc_cartpole_mc.json

Lines changed: 57 additions & 0 deletions

@@ -0,0 +1,57 @@
+{
+    "agent_cfg":{
+        "algorithm_cfg":{
+            "name":"A2C",
+            "var_schedule_cfg":null,
+            "gamma":0.99,
+            "rets_mean_baseline":false,
+            "policy_loss_var":1,
+            "value_loss_var":0.6,
+            "entropy_reg_var_cfg":{
+                "name":"fixed",
+                "var_start":0.01,
+                "var_end":0.01,
+                "star_epoch":0,
+                "end_epoch":0
+            },
+            "n_step_returns":null,
+            "lbd":null
+        },
+        "net_cfg":{
+            "name":"SharedMLPNet",
+            "body_hid_layers":[32],
+            "body_out_dim":16,
+            "hid_layers_activation":"Selu",
+            "output_hid_layers":[16]
+        },
+        "optimizer_cfg":{
+            "name":"adam",
+            "lr":1e-3,
+            "weight_decay": 1e-08,
+            "betas": [
+                0.9,
+                0.999
+            ]
+        },
+        "lr_schedule_cfg":null,
+        "memory_cfg":{
+            "name":"OnPolicy"
+        },
+        "max_epoch":10000,
+        "explore_times_per_train":1,
+        "train_exp_size":1,
+        "batch_learn_times_per_train":4
+    },
+    "env":{
+        "name":"CartPole",
+        "solved_total_reward":99900,
+        "finish_total_reward":100000,
+        "survival_T":100000
+    },
+    "model_path":null,
+    "valid":{
+        "valid_step":100,
+        "valid_times":5,
+        "not_improve_finish_step":5
+    }
+}
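
This new config pairs the shared-network A2C with Monte Carlo advantage estimation (n_step_returns and lbd are both null) and the episodic OnPolicy memory. It is trained with the command added to the README above:

python executor.py -cfg='./config/a2c/a2c_shared_mc_cartpole_mc.json' --mode='train'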

config/a2c/a2c_shared_nstep_cartpole_onbatch.json

Lines changed: 1 addition & 0 deletions

@@ -38,6 +38,7 @@
             "name":"OnPolicyBatch"
         },
         "max_epoch":10000,
+        "explore_times_per_train":1,
         "train_exp_size":64,
         "batch_learn_times_per_train":4
     },

config/a2c/a2c_unshared_gae_cartpole_onbatch.json

Lines changed: 1 addition & 0 deletions

@@ -54,6 +54,7 @@
             "name":"OnPolicyBatch"
         },
         "max_epoch":10000,
+        "explore_times_per_train":1,
         "train_exp_size":64,
         "batch_learn_times_per_train":4
     },
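
explore_times_per_train is the config-side switch for the new multi-episode collection in OnPolicyMemory; a value of 1 keeps the previous single-episode behaviour. The [OnPolicySystem] wiring is not shown in this commit, but presumably reduces to something like:

# Assumed wiring; OnPolicySystem itself is not part of this excerpt.
memory.multiple_explore(agent_cfg['explore_times_per_train'])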

0 commit comments