[2023.07.16] commit-1 · DarriusL/DRL-ExampleCode@20bc51d · GitHub
[go: up one dir, main page]

65F1
Skip to content

Commit 20bc51d

Browse files
committed
[2023.07.16] commit-1
add pong to env.
1 parent ce0434f commit 20bc51d

File tree

1 file changed

+18
-2
lines changed

1 file changed

+18
-2
lines changed

env/openai_gym.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,13 @@
1010
@dataclass
1111
class Main_Body:
1212
env: None
13+
state: None
1314
total_reward:None
1415
t:None
1516
is_terminated:None
1617

18+
image_envs = ['pong'];
19+
#TODO: a new version
1720
class OpenaiEnv(Env):
1821
'''the openai environment
1922
@@ -40,14 +43,21 @@ def __init__(self, env_cfg) -> None:
4043
t = 0;
4144
self.main_body = Main_Body(env, total_reward, t, False);
4245

46+
def _transpose(self, state):
47+
return state.transpose((2, 0, 1));
48+
4349
def get_state_and_action_dim(self):
4450
'''(state_dim, action_choice)
4551
'''
46-
return self.main_body.env.observation_space.shape[0], self.main_body.env.action_space.n;
52+
if self.name.lower() in image_envs:
53+
state_dim = self._transpose(self.main_body.env.reset()[0]).shape;
54+
else:
55+
state_dim = self.main_body.env.observation_space.shape[0];
56+
return state_dim, self.main_body.env.action_space.n;
4757

4858
def get_state(self):
4959
'''get the current state'''
50-
return np.asarray(self.main_body.env.state, dtype=np.float32);
60+
return np.asarray(self.main_body.state, dtype=np.float32);
5161

5262
def get_total_reward(self):
5363
'''Get the total rewards of the current trajectory so far'''
@@ -93,6 +103,9 @@ def reset(self):
93103
self.main_body.t = 0;
94104
self.main_body.is_terminated = False;
95105
state, _ = self.main_body.env.reset();
106+
if self.name.lower() in image_envs:
107+
state = self._transpose(state);
108+
self.main_body.state = state;
96109
return state;
97110

98111
def step(self, action):
@@ -101,6 +114,9 @@ def step(self, action):
101114
raise RuntimeError
102115
self.main_body.t += 1;
103116
next_state, reward, done, info1, info2 = self.main_body.env.step(action);
117+
if self.name.lower() in image_envs:
118+
next_state = self._transpose(next_state);
119+
self.main_body.state = next_state;
104120
self.main_body.total_reward += reward;
105121
if self.main_body.t == self.survival_T:
106122
done = True;

0 commit comments

Comments
 (0)
0