fix a3c to use local model · rlcode/reinforcement-learning@f757c86 · GitHub

Commit f757c86

fix a3c to use local model
1 parent 6226217 commit f757c86

File tree

1 file changed: 41 additions, 8 deletions


3-atari/1-breakout/breakout_a3c.py

Lines changed: 41 additions & 8 deletions
@@ -60,7 +60,7 @@ def train(self):
             agent.start()
 
         while True:
-            time.sleep(60*5)
+            time.sleep(60*10)
             self.save_model("./save_model/breakout_a3c")
 
     # approximate policy and value using Neural Network
@@ -71,7 +71,6 @@ def build_model(self):
         input = Input(shape=self.state_size)
         conv = Conv2D(16, (8, 8), strides=(4, 4), activation='relu')(input)
         conv = Conv2D(32, (4, 4), strides=(2, 2), activation='relu')(conv)
-        conv = Conv2D(32, (3, 3), strides=(1, 1), activation='relu')(conv)
         conv = Flatten()(conv)
         fc = Dense(256, activation='relu')(conv)
         policy = Dense(self.action_size, activation='softmax')(fc)
@@ -80,8 +79,8 @@ def build_model(self):
         actor = Model(inputs=input, outputs=policy)
         critic = Model(inputs=input, outputs=value)
 
-        actor.predict(np.random.rand(1, 84, 84, 4))
-        critic.predict(np.random.rand(1, 84, 84, 4))
+        actor._make_predict_function()
+        critic._make_predict_function()
 
         actor.summary()
         critic.summary()
@@ -163,6 +162,8 @@ def __init__(self, action_size, state_size, model, sess, optimizer, discount_fac
 
         self.states, self.actions, self.rewards = [],[],[]
 
+        self.local_actor, self.local_critic = self.build_localmodel()
+
         self.avg_p_max = 0
         self.avg_loss = 0
 
@@ -209,6 +210,11 @@ def run(self):
                 elif action == 1: real_action = 2
                 else: real_action = 3
 
+                if dead:
+                    action = 0
+                    real_action = 1
+                    dead = False
+
                 next_observe, reward, done, info = env.step(real_action)
                 # pre-process the observation --> history
                 next_state = pre_processing(next_observe, observe)
@@ -232,13 +238,13 @@ def run(self):
                 if dead:
                     history = np.stack((next_state, next_state, next_state, next_state), axis=2)
                     history = np.reshape([history], (1, 84, 84, 4))
-                    dead = False
                 else:
                     history = next_history
 
                 #
                 if self.t >= self.t_max or done:
-                    self.train_t(done)
+                    self.train_model(done)
+                    self.update_localmodel()
                     self.t = 0
 
                 # if done, plot the score over episodes
@@ -271,7 +277,7 @@ def discount_rewards(self, rewards, done):
         return discounted_rewards
 
     # update policy network and value network every episode
-    def train_t(self, done):
+    def train_model(self, done):
         discounted_rewards = self.discount_rewards(self.rewards, done)
 
         states = np.zeros((len(self.states), 84, 84, 4))
@@ -289,9 +295,36 @@ def train_t(self, done):
         self.optimizer[1]([states, discounted_rewards])
         self.states, self.actions, self.rewards = [], [], []
 
+    def build_localmodel(self):
+        input = Input(shape=self.state_size)
+        conv = Conv2D(16, (8, 8), strides=(4, 4), activation='relu')(input)
+        conv = Conv2D(32, (4, 4), strides=(2, 2), activation='relu')(conv)
+        conv = Flatten()(conv)
+        fc = Dense(256, activation='relu')(conv)
+        policy = Dense(self.action_size, activation='softmax')(fc)
+        value = Dense(1, activation='linear')(fc)
+
+        actor = Model(inputs=input, outputs=policy)
+        critic = Model(inputs=input, outputs=value)
+
+        actor._make_predict_function()
+        critic._make_predict_function()
+
+        actor.set_weights(self.actor.get_weights())
+        critic.set_weights(self.critic.get_weights())
+
+        actor.summary()
+        critic.summary()
+
+        return actor, critic
+
+    def update_localmodel(self):
+        self.local_actor.set_weights(self.actor.get_weights())
+        self.local_critic.set_weights(self.critic.get_weights())
+
     def get_action(self, history):
         history = np.float32(history / 255.)
-        policy = self.actor.predict(history)[0]
+        policy = self.local_actor.predict(history)[0]
 
         policy = policy - np.finfo(np.float32).epsneg
 
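For context on what the diff is doing: A3C keeps one shared global actor/critic and gives every worker thread its own local copy. The worker selects actions with the local networks, and after each train_model() call pushes an update to the global networks, update_localmodel() pulls the global weights back down with set_weights(get_weights()). The sketch below illustrates only that sync loop; it is not the repository's code. The tiny Dense network, the build_actor helper, and the "nudge the weights" update step are illustrative assumptions standing in for the Breakout conv nets and the real gradient update, and it assumes the same Keras 2.x API (including the private _make_predict_function()) that the diff relies on.

# Minimal sketch of the global/local A3C weight-sync pattern this commit adds.
# Assumptions: Keras 2.x; build_actor, the tiny Dense net, and the dummy
# update step are stand-ins, not the repository's Breakout implementation.
import numpy as np
from keras.layers import Dense, Input
from keras.models import Model


def build_actor(state_size, action_size):
    inp = Input(shape=(state_size,))
    fc = Dense(16, activation='relu')(inp)
    policy = Dense(action_size, activation='softmax')(fc)
    actor = Model(inputs=inp, outputs=policy)
    # Pre-build the predict graph so the model can be called from a worker thread.
    actor._make_predict_function()
    return actor


state_size, action_size = 4, 3
global_actor = build_actor(state_size, action_size)

# Worker side: build a local copy and initialise it from the global weights
# (what build_localmodel does in the diff).
local_actor = build_actor(state_size, action_size)
local_actor.set_weights(global_actor.get_weights())

state = np.random.rand(1, state_size).astype(np.float32)
for step in range(3):
    # Act with the LOCAL model (what get_action does after this commit).
    policy = local_actor.predict(state)[0].astype(np.float64)
    policy /= policy.sum()  # renormalize so np.random.choice accepts the float32 output
    action = np.random.choice(action_size, p=policy)

    # Stand-in for train_model: the repo computes policy/value gradients from
    # the rollout and applies them to the GLOBAL network; here the global
    # weights are only nudged so the sync below has a visible effect.
    global_actor.set_weights([w + 0.01 * np.random.randn(*w.shape).astype(w.dtype)
                              for w in global_actor.get_weights()])

    # After each update, pull the global weights back into the local copy
    # (what update_localmodel does in the diff).
    local_actor.set_weights(global_actor.get_weights())
    print(step, action)

In the commit itself, the local copies are created once in the worker's __init__ via build_localmodel() and refreshed after every training step, which is also why get_action() now reads from self.local_actor instead of the shared actor.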
0 commit comments
