feature(whl): add SIL policy by kxzxvbk · Pull Request #675 · opendilab/DI-engine · GitHub
feature(whl): add SIL policy #675

Open · wants to merge 27 commits into main
update recompute adv
whl committed Jun 22, 2023
commit eef28a3656fbf29b75e277c23db933e0a0c4cf07
18 changes: 15 additions & 3 deletions ding/policy/sil.py
@@ -27,6 +27,7 @@ class SILA2CPolicy(A2CPolicy):
         priority_IS_weight=False,
         # (int) Number of epochs to use SIL loss to update the policy.
         sil_update_per_collect=1,
+        sil_recompute_adv=True,
         learn=dict(
             update_per_collect=1,  # fixed value, this line should not be modified by users
             batch_size=64,
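The added `sil_recompute_adv` flag carries no comment in the diff, but judging from the `_forward_learn` hunk below it toggles whether SIL advantages are recomputed with the current critic. A hypothetical user-side override under that assumption (key names mirror the diff; this is a sketch, not the canonical way to assemble a DI-engine experiment config):

```python
from easydict import EasyDict

# Hypothetical override of the defaults shown above.
sil_a2c_policy_override = EasyDict(dict(
    sil_update_per_collect=1,
    # Assumption: set to False to reuse the advantages stored at collect time
    # instead of recomputing them with the current critic (see the next hunk).
    sil_recompute_adv=False,
))
```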
@@ -123,10 +124,21 @@ def _forward_learn(self, data: dict) -> Dict[str, Any]:
 
         for batch in data_sil:
             # forward
+            with torch.no_grad():
+                recomputed_value = self._learn_model.forward(data_onpolicy['obs'], mode='compute_critic')['value']
+                recomputed_next_value = self._learn_model.forward(data_onpolicy['next_obs'], mode='compute_critic')['value']
+
+            traj_flag = data_onpolicy.get('traj_flag', None)  # traj_flag indicates termination of trajectory
+            compute_adv_data = gae_data(
+                recomputed_value, recomputed_next_value, data_onpolicy['reward'], data_onpolicy['done'], traj_flag
+            )
+            recomputed_adv = gae(compute_adv_data, self._gamma, self._gae_lambda)
+
+            recomputed_returns = recomputed_value + recomputed_adv
             output = self._learn_model.forward(batch['obs'], mode='compute_actor_critic')
 
-            adv = batch['adv']
-            return_ = batch['value'] + adv
+            adv = batch['adv'] if not self._cfg.sil_recompute_adv else recomputed_adv
+            return_ = batch['value'] + adv if not self._cfg.sil_recompute_adv else recomputed_returns
             if self._adv_norm:
                 # norm adv in total train_batch
                 adv = (adv - adv.mean()) / (adv.std() + 1e-8)
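The substance of the commit is the hunk above: before the SIL update, values are recomputed with the current critic and advantages are rebuilt with GAE instead of reusing the ones stored at collect time. For reference, here is a plain-PyTorch sketch of the GAE recursion that `gae_data`/`gae` from `ding.rl_utils` are assumed to implement (single trajectory, ignoring the `traj_flag` bookkeeping the real helper handles):

```python
import torch


def recompute_gae(value, next_value, reward, done, gamma=0.99, lam=0.95):
    # All inputs are 1-D tensors of length T for one trajectory.
    adv = torch.zeros_like(reward)
    running = torch.zeros(())
    for t in reversed(range(reward.shape[0])):
        not_done = 1.0 - done[t].float()
        # TD residual computed with the freshly evaluated critic.
        delta = reward[t] + gamma * next_value[t] * not_done - value[t]
        # GAE accumulates discounted TD residuals backwards in time.
        running = delta + gamma * lam * not_done * running
        adv[t] = running
    return adv  # the diff then forms returns as value + adv
```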
@@ -394,7 +406,7 @@ def _forward_learn(self, data: Dict[str, Any]) -> Dict[str, Any]:
 
         grad_norm = torch.nn.utils.clip_grad_norm_(
             list(self._learn_model.parameters()),
-            max_norm=self.config.learn.grad_norm,
+            max_norm=self.config["learn"]["grad_norm"],
         )
         self._optimizer.step()
 
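The `max_norm` fix switches from attribute-style to key-style access on the learner config. A small sketch of why that matters, assuming the config arrives either as an `EasyDict` (as elsewhere in DI-engine) or as a plain `dict`:

```python
from easydict import EasyDict

cfg = EasyDict(dict(learn=dict(grad_norm=0.5)))
assert cfg.learn.grad_norm == cfg["learn"]["grad_norm"] == 0.5  # both styles work on EasyDict

plain_cfg = dict(learn=dict(grad_norm=0.5))
assert plain_cfg["learn"]["grad_norm"] == 0.5  # key access also works on a plain dict
# plain_cfg.learn would raise AttributeError, which is presumably the reason
# the diff prefers the dictionary-style lookup here.
```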
2 changes: 1 addition & 1 deletion dizoo/minigrid/config/minigrid_sil_ppo_config.py
@@ -53,7 +53,7 @@
         import_names=['dizoo.minigrid.envs.minigrid_env'],
     ),
     env_manager=dict(type='subprocess'),
-    policy=dict(type='ppo_sil'),
+    policy=dict(type='sil_ppo'),
 )
 minigrid_sil_ppo_create_config = EasyDict(minigrid_sil_ppo_create_config)
 create_config = minigrid_sil_ppo_create_config
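The `type` string in a create-config has to match the name the policy class is registered under, which is why `ppo_sil` is corrected to `sil_ppo` here. A sketch of what the registration side presumably looks like (the class name `SILPPOPolicy` is an assumption, mirroring `SILA2CPolicy` above; re-registering an already-taken name can fail, so treat this as illustration only):

```python
from ding.policy import PPOPolicy
from ding.utils import POLICY_REGISTRY


@POLICY_REGISTRY.register('sil_ppo')  # must match policy=dict(type='sil_ppo') in the config
class SILPPOPolicy(PPOPolicy):
    """Illustrative stub only; the real implementation lives in ding/policy/sil.py."""
    pass
```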
4 changes: 2 additions & 2 deletions dizoo/minigrid/envs/minigrid_env.py
@@ -60,7 +60,7 @@ def reset(self) -> np.ndarray:
             self._env = ObsPlusPrevActRewWrapper(self._env)
             self._init_flag = True
         if self._flat_obs:
-            self._observation_space = gym.spaces.Box(0, 1, shape=(2835, ), dytpe=np.float32)
+            self._observation_space = gym.spaces.Box(0, 1, shape=(2835, ))
         else:
             self._observation_space = self._env.observation_space
         # to be compatiable with subprocess env manager
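The removed keyword was misspelled (`dytpe`, which standard `gym` Box signatures would reject), and `gym.spaces.Box` defaults to `float32` anyway, so dropping it should leave the observation dtype unchanged; the reset code below also forces `float32` explicitly, and the same reasoning presumably applies to the reward-space change in the next hunk. A quick check, assuming a standard `gym` release:

```python
import gym
import numpy as np

obs_space = gym.spaces.Box(0, 1, shape=(2835, ))
print(obs_space.dtype)                          # float32 by default
print(obs_space.dtype == np.dtype('float32'))   # True
```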
@@ -70,7 +70,7 @@ def reset(self) -> np.ndarray:
         self._observation_space.dtype = np.dtype('float32')
         self._action_space = self._env.action_space
         self._reward_space = gym.spaces.Box(
-            low=self._env.reward_range[0], high=self._env.reward_range[1], shape=(1, ), dtype=np.float32
+            low=self._env.reward_range[0], high=self._env.reward_range[1], shape=(1, )
         )
 
         self._eval_episode_return = 0