load model in __init__ · rlcode/reinforcement-learning@4b8da42 · GitHub

Commit 4b8da42

load model in __init__

1 parent f504c84 commit 4b8da42

File tree: 1 file changed, +23 −24 lines changed

2-cartpole/1-dqn/cartpole_dqn.py

Lines changed: 23 additions & 24 deletions
@@ -11,13 +11,14 @@
 EPISODES = 300
 
 
-# this is DQN Agent for the Cartpole
+# DQN Agent for the Cartpole
 # it uses Neural Network to approximate q function
 # and replay memory & target q network
 class DQNAgent:
     def __init__(self, state_size, action_size):
         # if you want to see Cartpole learning, then change to True
-        self.render = True
+        self.render = False
+        self.load_model = False
 
         # get size of state and action
         self.state_size = state_size
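This hunk replaces the hard-coded self.render = True with two flags. A minimal usage sketch, assuming you edit the attributes by hand in cartpole_dqn.py (they are plain booleans, not constructor arguments):

# To watch an already-trained agent instead of training from scratch,
# flip both flags inside DQNAgent.__init__:
self.render = True        # show the CartPole window while running
self.load_model = True    # restore ./save_model/cartpole_dqn.h5 (see the next hunk)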
@@ -37,17 +38,23 @@ def __init__(self, state_size, action_size):
         # create main model and target model
         self.model = self.build_model()
         self.target_model = self.build_model()
-        # copy the model to target model
-        # --> initialize the target model so that the parameters of model & target model to be same
+
+        # initialize target model
         self.update_target_model()
 
+        if self.load_model:
+            self.model.load_weights("./save_model/cartpole_dqn.h5")
+
     # approximate Q function using Neural Network
     # state is input and Q Value of each action is output of network
     def build_model(self):
         model = Sequential()
-        model.add(Dense(24, input_dim=self.state_size, activation='relu', kernel_initializer='he_uniform'))
-        model.add(Dense(24, activation='relu', kernel_initializer='he_uniform'))
-        model.add(Dense(self.action_size, activation='linear', kernel_initializer='he_uniform'))
+        model.add(Dense(24, input_dim=self.state_size, activation='relu',
+                        kernel_initializer='he_uniform'))
+        model.add(Dense(24, activation='relu',
+                        kernel_initializer='he_uniform'))
+        model.add(Dense(self.action_size, activation='linear',
+                        kernel_initializer='he_uniform'))
         model.summary()
         model.compile(loss='mse', optimizer=Adam(lr=self.learning_rate))
         return model
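update_target_model() is called here but not shown in this diff; in Keras it is usually a one-line copy of the online network's parameters into the target network. A hedged sketch of what that method presumably does (the actual implementation in the repository may differ):

def update_target_model(self):
    # copy the current model's weights into the target model
    # (standard Keras calls: get_weights / set_weights)
    self.target_model.set_weights(self.model.get_weights())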
@@ -92,27 +99,20 @@ def train_replay(self):
         target_val = self.target_model.predict(update_target)
 
         for i in range(self.batch_size):
-            # like Q Learning, get maximum Q value at s'
-            # But from target model
+            # Q Learning: get maximum Q value at s' from target model
            if done[i]:
                 target[i][action[i]] = reward[i]
             else:
-                target[i][action[i]] = reward[i] + self.discount_factor * np.amax(target_val[i])
+                target[i][action[i]] = reward[i] + self.discount_factor * (
+                    np.amax(target_val[i]))
 
         # and do the model fit!
-        self.model.fit(update_input, target, batch_size=self.batch_size, epochs=1, verbose=0)
-
-    # load the saved model
-    def load_model(self, name):
-        self.model.load_weights(name)
-
-    # save the model which is under training
-    def save_model(self, name):
-        self.model.save_weights(name)
+        self.model.fit(update_input, target, batch_size=self.batch_size,
+                       epochs=1, verbose=0)
 
 
 if __name__ == "__main__":
-    # In case of CartPole-v1, you can play until 500 time step
+    # In case of CartPole-v1, maximum length of episode is 500
     env = gym.make('CartPole-v1')
     # get size of state and action from environment
     state_size = env.observation_space.shape[0]
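The reworked target assignment is the standard DQN backup: for a terminal transition the target is just the reward, otherwise it is reward + discount_factor * max over the target network's Q values at the next state. A small self-contained sketch of that arithmetic with made-up numbers (not taken from the repository):

import numpy as np

discount_factor = 0.99
reward = 1.0
target_val_i = np.array([0.8, 1.3])   # fake target-network Q values for s' (two CartPole actions)

# non-terminal transition: bootstrap from the best action under the target network
target_non_terminal = reward + discount_factor * np.amax(target_val_i)   # 1.0 + 0.99 * 1.3 = 2.287

# terminal transition: no bootstrapping, the target is the reward alone
target_terminal = reward   # 1.0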
@@ -127,7 +127,6 @@ def save_model(self, name):
         score = 0
         state = env.reset()
         state = np.reshape(state, [1, state_size])
-        # agent.load_model("./save_model/cartpole_dqn.h5")
 
         while not done:
             if agent.render:
@@ -155,8 +154,8 @@ def save_model(self, name):
                 score = score if score == 500 else score + 100
                 scores.append(score)
                 episodes.append(e)
-                #pylab.plot(episodes, scores, 'b')
-                # pylab.savefig("./save_graph/cartpole_dqn.png")
+                pylab.plot(episodes, scores, 'b')
+                pylab.savefig("./save_graph/cartpole_dqn.png")
                 print("episode:", e, " score:", score, " memory length:",
                       len(agent.memory), " epsilon:", agent.epsilon)
 
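Enabling pylab.savefig assumes the ./save_graph directory already exists; matplotlib raises FileNotFoundError when asked to write into a missing folder. If that directory is not tracked in the repository, a one-time guard like the following (an assumption, not part of this commit) avoids the error:

import os
os.makedirs("./save_graph", exist_ok=True)   # create the output folder once, before the first savefig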

@@ -167,4 +166,4 @@ def save_model(self, name):
 
         # save the model
         if e % 50 == 0:
-            agent.save_model("./save_model/cartpole_dqn.h5")
+            agent.model.save_weights("./save_model/cartpole_dqn.h5")
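With the save_model / load_model helper methods removed, persistence now goes straight through Keras save_weights / load_weights. A minimal round-trip sketch using the same path and the DQNAgent, state_size, and action_size names defined in this script (new_agent is just an illustrative name):

# during training, as in the loop above
agent.model.save_weights("./save_model/cartpole_dqn.h5")

# later, restore the weights into a fresh agent with the same architecture
new_agent = DQNAgent(state_size, action_size)
new_agent.model.load_weights("./save_model/cartpole_dqn.h5")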

Comments (0)