clean up and sync with book from policy iteration to qlearning · rlcode/reinforcement-learning@9a0d98b

Commit 9a0d98b

clean up and sync with book from policy iteration to qlearning
1 parent 454f77c commit 9a0d98b

5 files changed, +21 -31 lines changed

5 files changed

+21
-31
lines changed

1-grid-world/1-policy-iteration/environment.py

Lines changed: 1 addition & 3 deletions
@@ -75,8 +75,6 @@ def _build_canvas(self):
 
         return canvas
 
-    # (rectangle, triangle1, triangle2, circle)
-
     def load_images(self):
         up = PhotoImage(Image.open("../img/up.png").resize((13, 13)))
         right = PhotoImage(Image.open("../img/right.png").resize((13, 13)))
@@ -99,7 +97,7 @@ def reset(self):
         self.agent.value_table = [[0.0] * WIDTH for _ in range(HEIGHT)]
         self.agent.policy_table = ([[[0.25, 0.25, 0.25, 0.25]] * WIDTH
                                     for _ in range(HEIGHT)])
-        self.policy_table[2][2] = []
+        self.agent.policy_table[2][2] = []
         x, y = self.canvas.coords(self.rectangle)
         self.canvas.move(self.rectangle, UNIT / 2 - x, UNIT / 2 - y)
 

1-grid-world/1-policy-iteration/policy_iteration.py

Lines changed: 10 additions & 14 deletions
@@ -2,8 +2,6 @@
 import random
 from environment import GraphicDisplay, Env
 
-DISCOUNT_FACTOR = 0.9
-
 
 class PolicyIteration:
     def __init__(self, env):
@@ -15,6 +13,7 @@ def __init__(self, env):
                               for _ in range(env.height)]
         # setting terminal state
         self.policy_table[2][2] = []
+        self.discount_factor = 0.9
 
     def policy_evaluation(self):
         next_value_table = [[0.00] * self.env.width
@@ -32,8 +31,8 @@ def policy_evaluation(self):
                 next_state = self.env.state_after_action(state, action)
                 reward = self.env.get_reward(state, action)
                 next_value = self.get_value(next_state)
-                value += self.get_policy(state, action) * \
-                         (reward + DISCOUNT_FACTOR * next_value)
+                value += (self.get_policy(state)[action] *
+                          (reward + self.discount_factor * next_value))
 
             next_value_table[state[0]][state[1]] = round(value, 2)
 
@@ -46,15 +45,15 @@ def policy_improvement(self):
                 continue
             value = -99999
             max_index = []
-            result = [0.0, 0.0, 0.0, 0.0] # initialize the policy
+            result = [0.0, 0.0, 0.0, 0.0]  # initialize the policy
 
             # for every actions, calculate
             # [reward + (discount factor) * (next state value function)]
             for index, action in enumerate(self.env.possible_actions):
                 next_state = self.env.state_after_action(state, action)
                 reward = self.env.get_reward(state, action)
                 next_value = self.get_value(next_state)
-                temp = reward + DISCOUNT_FACTOR * next_value
+                temp = reward + self.discount_factor * next_value
 
                 # We normally can't pick multiple actions in greedy policy.
                 # but here we allow multiple actions with same max values
@@ -75,6 +74,7 @@ def policy_improvement(self):
 
         self.policy_table = next_policy
 
+    # get action according to the current policy
     def get_action(self, state):
         random_pick = random.randrange(100) / 100
 
@@ -84,21 +84,17 @@ def get_action(self, state):
         for index, value in enumerate(policy):
             policy_sum += value
             if random_pick < policy_sum:
-                return self.env.possible_actions[index]
+                return index
 
-    def get_policy(self, state, action=None):
-        # if no action is given, then return the probabilities of all actions
-        if action is None:
-            return self.policy_table[state[0]][state[1]]
+    # get policy of specific state
+    def get_policy(self, state):
         if state == [2, 2]:
             return 0.0
-        action_index = self.env.possible_actions.index(action)
-        return self.policy_table[state[0]][state[1]][action_index]
+        return self.policy_table[state[0]][state[1]]
 
     def get_value(self, state):
         return round(self.value_table[state[0]][state[1]], 2)
 
-
 if __name__ == "__main__":
     env = Env()
     policy_iteration = PolicyIteration(env)
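
The policy_evaluation change indexes the action probability as get_policy(state)[action] and uses the instance-level self.discount_factor; the backup it computes is the Bellman expectation update, v(s) <- sum_a pi(a|s) * (r + gamma * v(s')). A minimal self-contained sketch of that backup follows; the (reward, next-state value) pairs are made up for illustration, whereas the real code reads them from Env.get_reward() and the current value_table:

# Sketch of one Bellman expectation backup, as performed per state in
# policy_evaluation. The transition numbers are invented for illustration.
discount_factor = 0.9
policy = [0.25, 0.25, 0.25, 0.25]  # get_policy(state) for the uniform initial policy
transitions = [(0.0, 0.4), (0.0, 0.8), (-1.0, 0.0), (1.0, 0.0)]  # (reward, V(next state))

value = 0.0
for action, (reward, next_value) in enumerate(transitions):
    # v(s) <- sum over actions of pi(a|s) * (r + gamma * v(s'))
    value += policy[action] * (reward + discount_factor * next_value)

print(round(value, 2))  # 0.27 for the made-up numbers above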

1-grid-world/2-value-iteration/value_iteration.py

Lines changed: 5 additions & 9 deletions
@@ -1,20 +1,17 @@
 # -*- coding: utf-8 -*-
-import random
 from environment import GraphicDisplay, Env
 
-DISCOUNT_FACTOR = 0.9
-
-
 class ValueIteration:
     def __init__(self, env):
         self.env = env
         # 2-d list for the value function
         self.value_table = [[0.0] * env.width for _ in range(env.height)]
+        self.discount_factor = 0.9
 
     # get next value function table from the current value function table
     def value_iteration(self):
-        next_value_table = [[0.0] * \
-                            self.env.width for _ in range(self.env.height)]
+        next_value_table = [[0.0] * self.env.width for _ in
+                            range(self.env.height)]
         for state in self.env.get_all_states():
             if state == [2, 2]:
                 next_value_table[state[0]][state[1]] = 0.0
@@ -25,14 +22,13 @@ def value_iteration(self):
                 next_state = self.env.state_after_action(state, action)
                 reward = self.env.get_reward(state, action)
                 next_value = self.get_value(next_state)
-                value_list.append((reward + DISCOUNT_FACTOR * next_value))
+                value_list.append((reward + self.discount_factor * next_value))
             # return the maximum value(it is the optimality equation!!)
             next_value_table[state[0]][state[1]] = round(max(value_list), 2)
         self.value_table = next_value_table
 
     # get action according to the current value function table
     def get_action(self, state):
-
         action_list = []
         max_value = -99999
 
@@ -46,7 +42,7 @@ def get_action(self, state):
             next_state = self.env.state_after_action(state, action)
             reward = self.env.get_reward(state, action)
             next_value = self.get_value(next_state)
-            value = (reward + DISCOUNT_FACTOR * next_value)
+            value = (reward + self.discount_factor * next_value)
 
             if value > max_value:
                 action_list.clear()
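
Where policy evaluation averages over the policy, value_iteration applies the Bellman optimality backup: it collects r + gamma * v(s') for every action and keeps only the maximum, which is what the value_list loop above does. A minimal sketch with made-up transition numbers (the real code obtains them from Env.get_reward() and the value table):

# Sketch of one Bellman optimality backup, as performed per state in
# value_iteration. The transition numbers are invented for illustration.
discount_factor = 0.9
transitions = [(0.0, 0.4), (0.0, 0.8), (-1.0, 0.0), (1.0, 0.0)]  # (reward, V(next state))

# v(s) <- max over actions of (r + gamma * v(s'))
value_list = [reward + discount_factor * next_value
              for reward, next_value in transitions]
print(round(max(value_list), 2))  # 1.0: the single best action decides the value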

1-grid-world/4-sarsa/sarsa_agent.py

Lines changed: 1 addition & 1 deletion
@@ -11,7 +11,7 @@ def __init__(self, actions):
         self.actions = actions
         self.learning_rate = 0.01
         self.discount_factor = 0.9
-        self.epsilon = 0.9
+        self.epsilon = 0.1
         self.q_table = defaultdict(lambda: [0.0, 0.0, 0.0, 0.0])
 
     # with sample <s, a, r, s', a'>, learns new q function
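
Lowering epsilon from 0.9 to 0.1 means the SARSA agent now explores on roughly 10% of steps and acts greedily on the rest, instead of the reverse. For context, a generic epsilon-greedy selector over a list of Q-values looks like the sketch below; epsilon_greedy is an illustrative helper, not this file's get_action:

import random

def epsilon_greedy(q_values, epsilon=0.1):
    # With probability epsilon pick a random action (explore),
    # otherwise pick an action with the highest Q-value (exploit).
    if random.random() < epsilon:
        return random.randrange(len(q_values))
    max_q = max(q_values)
    best = [i for i, q in enumerate(q_values) if q == max_q]
    return random.choice(best)  # break ties randomly

print(epsilon_greedy([0.0, 0.3, 0.1, 0.0]))  # usually 1 when epsilon = 0.1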

1-grid-world/5-q-learning/q_learning_agent.py

Lines changed: 4 additions & 4 deletions
@@ -9,15 +9,15 @@ def __init__(self, actions):
         self.actions = actions
         self.learning_rate = 0.01
         self.discount_factor = 0.9
-        self.epsilon = 0.9
+        self.epsilon = 0.1
         self.q_table = defaultdict(lambda: [0.0, 0.0, 0.0, 0.0])
 
     # update q function with sample <s, a, r, s'>
     def learn(self, state, action, reward, next_state):
-        q_1 = self.q_table[state][action]
+        current_q = self.q_table[state][action]
         # using Bellman Optimality Equation to update q function
-        q_2 = reward + self.discount_factor * max(self.q_table[next_state])
-        self.q_table[state][action] += self.learning_rate * (q_2 - q_1)
+        new_q = reward + self.discount_factor * max(self.q_table[next_state])
+        self.q_table[state][action] += self.learning_rate * (current_q - new_q)
 
     # get action for the state according to the q function table
     # agent pick action of epsilon-greedy policy
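
Note the order of the error term in learn(): the removed line used (q_2 - q_1), i.e. TD target minus current estimate, while the added line uses (current_q - new_q), which flips the sign of the update. The conventional tabular Q-learning update moves Q(s, a) toward the target, as in the generic sketch below (alpha and gamma stand in for learning_rate and discount_factor; this is the textbook form, not this file's code):

# Conventional tabular Q-learning update, for comparison with learn() above.
def q_learning_update(q_table, state, action, reward, next_state,
                      alpha=0.01, gamma=0.9):
    current_q = q_table[state][action]
    # TD target from the Bellman optimality equation
    new_q = reward + gamma * max(q_table[next_state])
    # move the estimate toward the target: note the (new_q - current_q) order
    q_table[state][action] += alpha * (new_q - current_q)

q = {(0, 0): [0.0, 0.0, 0.0, 0.0], (0, 1): [0.0, 0.0, 0.0, 0.0]}
q_learning_update(q, (0, 0), 1, 1.0, (0, 1))
print(q[(0, 0)][1])  # 0.01 = alpha * (1.0 + gamma * 0 - 0)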

0 commit comments
