[2023.07.12] commit-1 · DarriusL/DRL-ExampleCode@35561ec · GitHub

Commit 35561ec

[2023.07.12] commit-1

1) Added MountainCar environment. 2) Fixed some bugs in the off-policy system.

1 parent 4266f5d commit 35561ec

6 files changed (+80, -11 lines)

agent/algorithm/dqn.py

Lines changed: 2 additions & 2 deletions

@@ -38,7 +38,7 @@ def update(self):
         glb_var.get_value('var_reporter').add('Tau', self.var);
         if self.lr_schedule is not None:
             self.lr_schedule.step();
-            glb_var.get_value('var_reporter').add('lr', self.agent.algorithm.optimizer.param_groups[0]["lr"])
+            glb_var.get_value('var_reporter').add('lr', self.optimizer.param_groups[0]["lr"])

     def _cal_loss(self, batch):
         '''Calculate MSELoss for DQN'''
@@ -73,7 +73,7 @@ def train_step(self, batch):
         batch:dict
         Convert through batch_to_tensor before passing in
         '''
-        loss = self.cal_loss(batch);
+        loss = self._cal_loss(batch);
         self.optimizer.zero_grad();
         self._check_nan(loss);
         loss.backward();
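
Both dqn.py changes fix name errors that only surface at runtime: inside the algorithm class, self already owns the optimizer (see self.optimizer.zero_grad() in the second hunk), so self.agent.algorithm.optimizer raised an AttributeError as soon as an lr schedule was configured, and cal_loss was never defined (the method is _cal_loss). A minimal, self-contained sketch of the corrected lr-reporting pattern in plain PyTorch, not the repo's classes; the net shape just mirrors MountainCar (2 state dims, 3 actions):

import torch

# Stand-in net/optimizer/scheduler mirroring the config added in this commit.
net = torch.nn.Linear(2, 3)
optimizer = torch.optim.Adam(net.parameters(), lr=1e-2,
                             weight_decay=1e-08, betas=(0.9, 0.999))
lr_schedule = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.997)

loss = net(torch.randn(1, 2)).sum()
optimizer.zero_grad()
loss.backward()
optimizer.step()
lr_schedule.step()
# The fixed line reads the live learning rate straight off the optimizer:
print(optimizer.param_groups[0]["lr"])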

config/dqn/dqn_mountaincar_off.json

Lines changed: 58 additions & 0 deletions

@@ -0,0 +1,58 @@
+{
+    "agent_cfg":{
+        "algorithm_cfg":{
+            "name":"DQN",
+            "var_schedule_cfg":{
+                "name":"linear",
+                "var_start":5.0,
+                "var_end":0.05,
+                "star_epoch":0,
+                "end_epoch":10000
+            },
+            "gamma": 0.99
+        },
+        "net_cfg":{
+            "name":"MLPNet",
+            "hid_layers":[64],
+            "hid_layers_activation":"Selu"
+        },
+        "optimizer_cfg":{
+            "name":"adam",
+            "lr":1e-2,
+            "weight_decay": 1e-08,
+            "betas": [
+                0.9,
+                0.999
+            ]
+        },
+        "lr_schedule_cfg":{
+            "name":"StepLR",
+            "step_size":10,
+            "gamma":0.997
+        },
+        "memory_cfg":{
+            "name":"OffPolicy",
+            "max_size":10000,
+            "batch_size":128,
+            "sample_add_latest":false
+        },
+        "max_epoch":20000,
+        "train_start_epoch":2,
+        "expolre_max_step":64,
+        "train_times_per_epoch":16,
+        "batch_learn_times_per_train":2
+    },
+    "env":{
+        "name":"MountainCar",
+        "solved_total_reward":-100,
+        "finish_total_reward":-80,
+        "survival_T":5000
+    },
+    "model_path":null,
+    "is_gpu_available":true,
+    "valid":{
+        "valid_step":10,
+        "valid_times":5,
+        "not_improve_finish_step":5
+    }
+}
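
The var_schedule_cfg block drives the exploration variable (reported as 'Tau' in dqn.py): with name "linear" it anneals from var_start=5.0 down to var_end=0.05 between epochs 0 and 10000. Note that "star_epoch" and "expolre_max_step" are kept exactly as committed, since they appear to be the literal key spellings this repo's config schema reads. A hypothetical sketch of the interpolation the config describes; the repo's actual schedule class may differ in naming and details:

# Hypothetical sketch only -- the repo's linear schedule class is not shown here.
def linear_var(epoch, var_start=5.0, var_end=0.05, start_epoch=0, end_epoch=10000):
    if epoch <= start_epoch:
        return var_start
    if epoch >= end_epoch:
        return var_end
    frac = (epoch - start_epoch) / (end_epoch - start_epoch)
    return var_start + frac * (var_end - var_start)

print(linear_var(0))      # 5.0
print(linear_var(5000))   # 2.525, halfway between the endpoints
print(linear_var(10000))  # 0.05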

env/base.py

Lines changed: 6 additions & 1 deletion

@@ -7,11 +7,16 @@
 logger = glb_var.get_value('log');

 def _make_env(env_cfg):
-    if env_cfg['name'].lower() in 'cartpole':
+    if env_cfg['name'].lower() == 'cartpole':
         if glb_var.get_value('mode') == 'train':
             return gym.make("CartPole-v1");
         else:
             return gym.make("CartPole-v1", render_mode="human");
+    elif env_cfg['name'].lower() == 'mountaincar':
+        if glb_var.get_value('mode') == 'train':
+            return gym.make("MountainCar-v0").env;
+        else:
+            return gym.make("MountainCar-v0", render_mode="human").env;
     else:
         logger.error(f'Type of env [{env_cfg["name"]}] is not supported.\nPlease replace or add by yourself.')
         raise callback.CustomException('NetCfgTypeError');
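
The first change here fixes a genuine bug: in on strings tests substring containment, not equality, so the old check matched any configured name that happened to be a substring of 'cartpole'. The .env accessor on the new MountainCar lines presumably unwraps gym's default TimeLimit wrapper (200 steps for MountainCar-v0) so that the wrapper's own survival_T cap governs episode length instead. A quick demonstration of the containment pitfall:

# `in` on strings is substring containment, not equality.
print('cartpole' in 'cartpole')    # True -- the case that made the old code appear to work
print('cart' in 'cartpole')        # True -- unintended match
print('' in 'cartpole')            # True -- even an empty name matched
print('mountaincar' in 'cartpole') # False -- so a second env could never branch correctly
print('cartpole' == 'cartpole')    # True -- the fixed, exact comparison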

env/openai_gym.py

Lines changed: 8 additions & 3 deletions

@@ -12,6 +12,7 @@ class Main_Body:
     env: None
     total_reward:None
     t:None
+    is_terminated:None

 class OpenaiEnv(Env):
     '''the openai environment
@@ -37,7 +38,7 @@ def __init__(self, env_cfg) -> None:
         self.train_env_data = None;
         total_reward = 0;
         t = 0;
-        self.main_body = Main_Body(env, total_reward, t);
+        self.main_body = Main_Body(env, total_reward, t, False);

     def get_state_and_action_dim(self):
         '''(state_dim, action_choice)
@@ -54,7 +55,7 @@ def get_total_reward(self):

     def is_terminated(self):
         '''Is the current environment terminated'''
-        return True if self.main_body.env.steps_beyond_terminated is not None else False;
+        return self.main_body.is_terminated;

     def _save_train_env(self):
         '''Save the training environment for recovery'''
@@ -68,7 +69,7 @@
     def train(self):
         '''set train mode
         '''
-        if (not self.is_training):
+        if not self.is_training:
             self.is_training = True;
             if glb_var.get_value('mode') == 'train':
                 self._resume_train_env();
@@ -90,16 +91,20 @@
         '''Reset the env'''
         self.main_body.total_reward = 0;
         self.main_body.t = 0;
+        self.main_body.is_terminated = False;
         state, _ = self.main_body.env.reset();
         return state;

     def step(self, action):
         '''Change the env through the action'''
+        if self.main_body.is_terminated:
+            raise RuntimeError
         self.main_body.t += 1;
         next_state, reward, done, info1, info2 = self.main_body.env.step(action);
         self.main_body.total_reward += reward;
         if self.main_body.t == self.survival_T:
             done = True;
+        self.main_body.is_terminated = done;
         return next_state, reward, done, info1, info2;

     def render(self):
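
The old is_terminated peeked at steps_beyond_terminated, an internal attribute of CartPole's implementation that MountainCar does not have. The wrapper now records the done flag itself, and step() refuses to advance a finished episode. A self-contained sketch of the same env-agnostic pattern in plain gym, not the repo's wrapper, assuming the gym>=0.26 five-tuple step API the diff also uses:

import gym

# Track termination in the caller (as this commit does) instead of reading
# CartPole-only internals such as env.steps_beyond_terminated.
env = gym.make("MountainCar-v0").env   # .env sidesteps the default 200-step TimeLimit
state, _ = env.reset()
is_terminated = False
for t in range(1, 5001):               # survival_T-style manual cap
    state, reward, done, truncated, info = env.step(env.action_space.sample())
    is_terminated = done or t == 5000  # record the flag on every step
    if is_terminated:                  # the commit's step() raises RuntimeError
        break                          # if called again past this point
env.close()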

executor.py

Lines changed: 0 additions & 1 deletion

@@ -5,7 +5,6 @@
 from lib import glb_var, callback
 from lib.callback import Logger

-#TODO:Add: A mode that can be trained on existing models
 #TODO:Notes on each algorithm
 #TODO:Add algorithm:https://openai.com/research/openai-baselines-acktr-a2c
 #TODO:a3c gpu

room/system/onpolicy.py

Lines changed: 6 additions & 4 deletions

@@ -139,10 +139,12 @@ def train(self):
         self.train_mode();
         #collect experiences
         self._explore();
-        #start to train
-        self._train_epoch(epoch);
-        #algorithm update
-        self.agent.algorithm.update();
+        #check for off policy algorithm
+        if self._check_train_point(epoch):
+            #start to train
+            self._train_epoch(epoch);
+            #algorithm update
+            self.agent.algorithm.update();
         #valid mode
         if self._check_valid_point(epoch):
             if self._valid_epoch(epoch):
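
_check_train_point is not shown in this diff. Judging by the commit message and the new config keys (train_start_epoch, train_times_per_epoch), it likely delays gradient updates until the off-policy replay buffer has collected some experience. A hypothetical shape only, which may differ from the repo's actual method:

# Hypothetical sketch -- inferred from the config keys, not taken from the repo.
def _check_train_point(self, epoch):
    '''Off-policy: skip training until enough experience has been collected.'''
    return epoch >= self.train_start_epoch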
