 EPISODES = 300

-# this is DQN Agent for the Cartpole
+# DQN Agent for the Cartpole
 # it uses Neural Network to approximate q function
 # and replay memory & target q network
 class DQNAgent:
     def __init__(self, state_size, action_size):
         # if you want to see Cartpole learning, then change to True
-        self.render = True
+        self.render = False
+        self.load_model = False

         # get size of state and action
         self.state_size = state_size
@@ -37,17 +38,23 @@ def __init__(self, state_size, action_size):
         # create main model and target model
         self.model = self.build_model()
         self.target_model = self.build_model()
-        # copy the model to target model
-        # --> initialize the target model so that the parameters of model & target model to be same
+
+        # initialize target model
         self.update_target_model()

+        if self.load_model:
+            self.model.load_weights("./save_model/cartpole_dqn.h5")
+
     # approximate Q function using Neural Network
     # state is input and Q Value of each action is output of network
     def build_model(self):
         model = Sequential()
-        model.add(Dense(24, input_dim=self.state_size, activation='relu', kernel_initializer='he_uniform'))
-        model.add(Dense(24, activation='relu', kernel_initializer='he_uniform'))
-        model.add(Dense(self.action_size, activation='linear', kernel_initializer='he_uniform'))
+        model.add(Dense(24, input_dim=self.state_size, activation='relu',
+                        kernel_initializer='he_uniform'))
+        model.add(Dense(24, activation='relu',
+                        kernel_initializer='he_uniform'))
+        model.add(Dense(self.action_size, activation='linear',
+                        kernel_initializer='he_uniform'))
         model.summary()
         model.compile(loss='mse', optimizer=Adam(lr=self.learning_rate))
         return model
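For reference, update_target_model() is unchanged by this diff and not shown here. A minimal sketch of what it is assumed to do, using the standard Keras weight-copy calls:

    def update_target_model(self):
        # assumed implementation: copy the online network's weights into the
        # target network so that the two networks start out identical
        self.target_model.set_weights(self.model.get_weights())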
@@ -92,27 +99,20 @@ def train_replay(self):
         target_val = self.target_model.predict(update_target)

         for i in range(self.batch_size):
-            # like Q Learning, get maximum Q value at s'
-            # But from target model
+            # Q Learning: get maximum Q value at s' from target model
             if done[i]:
                 target[i][action[i]] = reward[i]
             else:
-                target[i][action[i]] = reward[i] + self.discount_factor * np.amax(target_val[i])
+                target[i][action[i]] = reward[i] + self.discount_factor * (
+                    np.amax(target_val[i]))

         # and do the model fit!
-        self.model.fit(update_input, target, batch_size=self.batch_size, epochs=1, verbose=0)
-
-    # load the saved model
-    def load_model(self, name):
-        self.model.load_weights(name)
-
-    # save the model which is under training
-    def save_model(self, name):
-        self.model.save_weights(name)
+        self.model.fit(update_input, target, batch_size=self.batch_size,
+                       epochs=1, verbose=0)


 if __name__ == "__main__":
-    # In case of CartPole-v1, you can play until 500 time step
+    # In case of CartPole-v1, maximum length of episode is 500
     env = gym.make('CartPole-v1')
     # get size of state and action from environment
     state_size = env.observation_space.shape[0]
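The loop in train_replay builds the usual target-network Bellman targets: for a non-terminal transition, the target for the taken action is reward + discount_factor * max_a' Q_target(s', a'); for a terminal transition it is just the reward. A vectorized NumPy equivalent, shown only as a sketch (it assumes action, reward and done have already been converted to NumPy arrays of length batch_size):

        # sketch: vectorized form of the per-sample loop in train_replay
        q_max = np.amax(target_val, axis=1)            # max_a' Q_target(s', a')
        bellman = reward + (1.0 - done.astype(float)) * self.discount_factor * q_max
        target[np.arange(self.batch_size), action] = bellman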
@@ -127,7 +127,6 @@ def save_model(self, name):
         score = 0
         state = env.reset()
         state = np.reshape(state, [1, state_size])
-        # agent.load_model("./save_model/cartpole_dqn.h5")

         while not done:
             if agent.render:
@@ -155,8 +154,8 @@ def save_model(self, name):
                 score = score if score == 500 else score + 100
                 scores.append(score)
                 episodes.append(e)
-                # pylab.plot(episodes, scores, 'b')
-                # pylab.savefig("./save_graph/cartpole_dqn.png")
+                pylab.plot(episodes, scores, 'b')
+                pylab.savefig("./save_graph/cartpole_dqn.png")
                 print("episode:", e, "  score:", score, "  memory length:",
                       len(agent.memory), "  epsilon:", agent.epsilon)
@@ -167,4 +166,4 @@ def save_model(self, name):

         # save the model
         if e % 50 == 0:
-            agent.save_model("./save_model/cartpole_dqn.h5")
+            agent.model.save_weights("./save_model/cartpole_dqn.h5")
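With the load_model(name) and save_model(name) helpers removed, resuming from a checkpoint is now controlled by the self.load_model flag in __init__. An equivalent way to load the same checkpoint from the training script, shown only as a sketch:

    agent = DQNAgent(state_size, action_size)
    agent.model.load_weights("./save_model/cartpole_dqn.h5")
    agent.update_target_model()  # keep the target network in sync with the loaded weights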