Showing 35 changed files with 5,913 additions and 0 deletions.
@@ -1,3 +1,5 @@
data/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
@@ -0,0 +1,6 @@
from gym.envs.registration import register

register(
    id='OnlyLong-v0',
    entry_point='drltr.envs:OnlyLongEnv',
)
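Once this registration module is imported, the custom environment is available through the standard Gym API. A minimal usage sketch (not part of the commit), assuming the register() call above runs when the drltr.envs package is imported and that OnlyLongEnv follows the usual gym.Env interface:

import gym
import drltr.envs  # assumed to execute the register() call above

env = gym.make('OnlyLong-v0')                # looks up the registered entry point
obs = env.reset()
obs, reward, done, info = env.step(env.action_space.sample())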
Empty file.
@@ -0,0 +1,81 @@
import numpy as np
import tensorflow as tf

from collections import OrderedDict

from .base_agent import BaseAgent
from drltr.policies.MLP_policy import MLPPolicyAC
from drltr.critics.bootstrapped_continuous_critic import BootstrappedContinuousCritic
from drltr.infrastructure.replay_buffer import ReplayBuffer
from drltr.infrastructure.utils import *


class ACAgent(BaseAgent):
    def __init__(self, sess, env, agent_params):
        super(ACAgent, self).__init__()

        self.env = env
        self.sess = sess
        self.agent_params = agent_params

        self.gamma = self.agent_params['gamma']
        self.standardize_advantages = self.agent_params['standardize_advantages']

        self.actor = MLPPolicyAC(sess,
                                 self.agent_params['ac_dim'],
                                 self.agent_params['ob_dim'],
                                 self.agent_params['n_layers'],
                                 self.agent_params['size'],
                                 discrete=self.agent_params['discrete'],
                                 learning_rate=self.agent_params['learning_rate'],
                                 )
        self.critic = BootstrappedContinuousCritic(sess, self.agent_params)

        self.replay_buffer = ReplayBuffer()

    def estimate_advantage(self, ob_no, next_ob_no, re_n, terminal_n):

        # TODO Implement the following pseudocode:
        # 1) query the critic with ob_no, to get V(s)
        # 2) query the critic with next_ob_no, to get V(s')
        # 3) estimate the Q value as Q(s, a) = r(s, a) + gamma*V(s')
        # HINT: Remember to cut off the V(s') term (ie set it to 0) at terminal states (ie terminal_n=1)
        # 4) calculate advantage (adv_n) as A(s, a) = Q(s, a) - V(s)

        V_s = self.critic.forward(ob_no)
        next_V_s = self.critic.forward(next_ob_no) * (1 - terminal_n)
        Q_val = re_n + self.gamma * next_V_s
        adv_n = Q_val - V_s

        if self.standardize_advantages:
            adv_n = (adv_n - np.mean(adv_n)) / (np.std(adv_n) + 1e-8)
        return adv_n

    def train(self, ob_no, ac_na, re_n, next_ob_no, terminal_n):

        # TODO Implement the following pseudocode:
        # for agent_params['num_critic_updates_per_agent_update'] steps,
        #     update the critic

        # advantage = estimate_advantage(...)

        # for agent_params['num_actor_updates_per_agent_update'] steps,
        #     update the actor

        for i in range(self.agent_params['num_critic_updates_per_agent_update']):
            critic_loss = self.critic.update(ob_no, next_ob_no, re_n, terminal_n)

        adv_n = self.estimate_advantage(ob_no, next_ob_no, re_n, terminal_n)

        for i in range(self.agent_params['num_actor_updates_per_agent_update']):
            actor_loss = self.actor.update(ob_no, ac_na, adv_n)

        loss = OrderedDict()
        loss['Critic_Loss'] = critic_loss  # put final critic loss here
        loss['Actor_Loss'] = actor_loss  # put final actor loss here
        return loss

    def add_to_replay_buffer(self, paths):
        self.replay_buffer.add_rollouts(paths)

    def sample(self, batch_size):
        return self.replay_buffer.sample_recent_data(batch_size)
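For reference, the advantage math in estimate_advantage is just a few array operations once the critic's V(s) and V(s') estimates are in hand. A standalone numpy sketch (not part of the commit) with made-up numbers, using the same gamma-discounted bootstrap and standardization:

import numpy as np

gamma = 0.99                              # illustrative discount factor
re_n       = np.array([1.0, 0.5, 2.0])    # rewards r(s, a)
V_s        = np.array([3.0, 2.5, 1.0])    # critic estimates V(s)
next_V_s   = np.array([2.8, 2.0, 4.0])    # critic estimates V(s')
terminal_n = np.array([0.0, 0.0, 1.0])    # last transition ends the episode

next_V_s = next_V_s * (1 - terminal_n)    # cut off bootstrapping at terminal states
q_val = re_n + gamma * next_V_s           # Q(s, a) = r(s, a) + gamma * V(s')
adv_n = q_val - V_s                       # A(s, a) = Q(s, a) - V(s)
adv_n = (adv_n - adv_n.mean()) / (adv_n.std() + 1e-8)  # standardize advantages
print(adv_n)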
@@ -0,0 +1,15 @@
import numpy as np
import tensorflow as tf

class BaseAgent(object):
    def __init__(self, **kwargs):
        super(BaseAgent, self).__init__(**kwargs)

    def train(self):
        raise NotImplementedError

    def add_to_replay_buffer(self, paths):
        raise NotImplementedError

    def sample(self, batch_size):
        raise NotImplementedError
@@ -0,0 +1,41 @@
import numpy as np
import tensorflow as tf
import time

from .base_agent import BaseAgent
from drltr.policies.MLP_policy import *
from drltr.infrastructure.replay_buffer import ReplayBuffer
from drltr.infrastructure.utils import *

class BCAgent(BaseAgent):
    def __init__(self, sess, env, agent_params):
        super(BCAgent, self).__init__()

        # init vars
        self.env = env
        self.sess = sess
        self.agent_params = agent_params

        # actor/policy
        self.actor = MLPPolicySL(sess,
                                 self.agent_params['ac_dim'],
                                 self.agent_params['ob_dim'],
                                 self.agent_params['n_layers'],
                                 self.agent_params['size'],
                                 discrete=self.agent_params['discrete'],
                                 learning_rate=self.agent_params['learning_rate'],
                                 )

        # replay buffer
        self.replay_buffer = ReplayBuffer(self.agent_params['max_replay_buffer_size'])

    def train(self, ob_no, ac_na, re_n, next_ob_no, terminal_n):
        # training a BC agent refers to updating its actor using
        # the given observations and corresponding action labels
        self.actor.update(ob_no, ac_na)

    def add_to_replay_buffer(self, paths):
        self.replay_buffer.add_rollouts(paths)

    def sample(self, batch_size):
        return self.replay_buffer.sample_random_data(batch_size)
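The comment in train() summarizes the behavioral-cloning loop: store expert rollouts, sample (observation, action) pairs, and regress the policy onto the expert actions. A rough sketch of how a training script might drive this agent, assuming the replay buffer's sample_random_data returns the usual five arrays; the names expert_paths, n_iter, and batch_size are illustrative, not taken from the repo:

# Hypothetical driver loop for BCAgent; expert_paths, n_iter, batch_size are made up.
agent = BCAgent(sess, env, agent_params)
agent.add_to_replay_buffer(expert_paths)      # rollouts containing expert (ob, ac) labels

for itr in range(n_iter):
    ob_no, ac_na, re_n, next_ob_no, terminal_n = agent.sample(batch_size)
    agent.train(ob_no, ac_na, re_n, next_ob_no, terminal_n)  # supervised update on (ob, ac)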
@@ -0,0 +1,145 @@
import tensorflow as tf
import numpy as np

from drltr.infrastructure.dqn_utils import MemoryOptimizedReplayBuffer, PiecewiseSchedule
from drltr.policies.argmax_policy import ArgMaxPolicy
from drltr.critics.dqn_critic import DQNCritic
from drltr.infrastructure.replay_buffer import ReplayBuffer


class DQNAgent(object):
    def __init__(self, sess, env, agent_params):

        self.env = env
        self.sess = sess
        self.agent_params = agent_params
        self.batch_size = agent_params['batch_size']
        self.last_obs = self.env.reset()

        self.num_actions = agent_params['ac_dim']
        self.learning_starts = agent_params['learning_starts']
        self.learning_freq = agent_params['learning_freq']
        self.target_update_freq = agent_params['target_update_freq']

        self.replay_buffer_idx = None
        self.exploration = agent_params['exploration_schedule']
        self.optimizer_spec = agent_params['optimizer_spec']

        self.critic = DQNCritic(sess, agent_params, self.optimizer_spec)
        self.actor = ArgMaxPolicy(sess, self.critic)

        lander = agent_params['env_name'] == 'LunarLander-v2'
        self.replay_buffer = MemoryOptimizedReplayBuffer(
            agent_params['replay_buffer_size'], agent_params['frame_history_len'], lander=lander)
        # self.replay_buffer = ReplayBuffer(agent_params['replay_buffer_size'])

        self.t = 0
        self.num_param_updates = 0

    def add_to_replay_buffer(self, paths):
        pass

    def step_env(self):

        """
        Step the env and store the transition.
        At the end of this block of code, the simulator should have been
        advanced one step, and the replay buffer should contain one more transition.
        Note that self.last_obs must always point to the new latest observation.
        """

        # TODO store the latest observation into the replay buffer
        # HINT: see replay buffer's function store_frame
        self.replay_buffer_idx = self.replay_buffer.store_frame(self.last_obs)

        eps = self.exploration.value(self.t)
        # TODO use epsilon greedy exploration when selecting action
        # HINT: take random action
        # with probability eps (see np.random.random())
        # OR if your current step number (see self.t) is less than self.learning_starts
        perform_random_action = (np.random.random() < eps) or (self.t < self.learning_starts)

        if perform_random_action:
            action = self.env.action_space.sample()
        else:
            # TODO query the policy to select action
            # HINT: you cannot use "self.last_obs" directly as input
            # into your network, since it needs to be processed to include context
            # from previous frames.
            # Check out the replay buffer, which has a function called
            # encode_recent_observation that will take the latest observation
            # that you pushed into the buffer and compute the corresponding
            # input that should be given to a Q network by appending some
            # previous frames.
            enc_last_obs = self.replay_buffer.encode_recent_observation()
            enc_last_obs = enc_last_obs[None, :]

            # TODO query the policy with enc_last_obs to select action
            action = self.actor.get_action(enc_last_obs)
            action = action[0]

        # TODO take a step in the environment using the action from the policy
        # HINT1: remember that self.last_obs must always point to the newest/latest observation
        # HINT2: remember the following useful function that you've seen before:
        # obs, reward, done, info = env.step(action)
        self.last_obs, reward, done, info = self.env.step(action)

        # TODO store the result of taking this action into the replay buffer
        # HINT1: see replay buffer's store_effect function
        # HINT2: one of the arguments you'll need to pass in is self.replay_buffer_idx from above
        self.replay_buffer.store_effect(self.replay_buffer_idx, action, reward, done)

        # TODO if taking this step resulted in done, reset the env (and the latest observation)
        if done:
            self.last_obs = self.env.reset()

    def sample(self, batch_size):
        if self.replay_buffer.can_sample(self.batch_size):
            return self.replay_buffer.sample(batch_size)
        else:
            return [], [], [], [], []

    def train(self, ob_no, ac_na, re_n, next_ob_no, terminal_n):

        """
        Here, you should train the DQN agent.
        This consists of training the critic, as well as periodically updating the target network.
        """

        loss = 0.0
        if (self.t > self.learning_starts and
                self.t % self.learning_freq == 0 and
                self.replay_buffer.can_sample(self.batch_size)):

            # TODO populate all placeholders necessary for calculating the critic's total_error
            # HINT: obs_t_ph, act_t_ph, rew_t_ph, obs_tp1_ph, done_mask_ph
            feed_dict = {
                self.critic.learning_rate: self.optimizer_spec.lr_schedule.value(self.t),
                self.critic.obs_t_ph: ob_no,
                self.critic.act_t_ph: ac_na,
                self.critic.rew_t_ph: re_n,
                self.critic.obs_tp1_ph: next_ob_no,
                self.critic.done_mask_ph: terminal_n,
            }

            # TODO: create a LIST of tensors to run in order to
            # train the critic as well as get the resulting total_error
            tensors_to_run = [self.critic.total_error,
                              self.critic.train_fn]
            loss, _ = self.sess.run(tensors_to_run, feed_dict=feed_dict)
            # Note: remember that the critic's total_error value is what you
            # created to compute the Bellman error in a batch,
            # and the critic's train function performs a gradient step
            # and updates the network parameters to reduce that total_error.

            # TODO: use sess.run to periodically update the critic's target function
            # HINT: see update_target_fn
            if self.num_param_updates % self.target_update_freq == 0:
                self.sess.run(self.critic.update_target_fn, feed_dict=feed_dict)

            self.num_param_updates += 1

        self.t += 1
        return loss
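step_env(), sample(), and train() are meant to be called from an outer training loop elsewhere in the repo (not shown in this commit excerpt). A rough sketch of that interplay, with made-up loop bounds, just to show the intended call order; train() is effectively a no-op until self.t passes learning_starts and lines up with learning_freq:

# Hypothetical outer loop; total_steps is illustrative only.
agent = DQNAgent(sess, env, agent_params)

for step in range(total_steps):
    agent.step_env()   # act epsilon-greedily and store the transition in the replay buffer
    ob_no, ac_na, re_n, next_ob_no, terminal_n = agent.sample(agent.batch_size)
    loss = agent.train(ob_no, ac_na, re_n, next_ob_no, terminal_n)  # critic update + periodic target sync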