10
10
@dataclass
class Main_Body:
    """Mutable per-trajectory bookkeeping held by an environment wrapper.

    ``state`` is declared last and defaults to ``None`` so that the
    four-argument positional construction
    ``Main_Body(env, total_reward, t, False)`` keeps binding each value to
    the field it was meant for. Placing ``state`` between ``env`` and
    ``total_reward`` without a default would shift every later positional
    argument by one field and make that call raise ``TypeError``.

    Attributes:
        env: The wrapped environment object.
        total_reward: Reward accumulated over the current trajectory.
        t: Number of steps taken in the current trajectory.
        is_terminated: Whether the current trajectory has ended.
        state: Most recently stored state, or ``None`` if none was stored.
    """

    env: object
    total_reward: float
    t: int
    is_terminated: bool
    state: object = None
16
17
18
# Lowercase environment names whose states are passed through `_transpose`
# (membership is tested with `self.name.lower() in image_envs`).
image_envs = ['pong']
19
# TODO: a new version
17
20
class OpenaiEnv (Env ):
18
21
'''the openai environment
19
22
@@ -40,14 +43,21 @@ def __init__(self, env_cfg) -> None:
40
43
t = 0 ;
41
44
self .main_body = Main_Body (env , total_reward , t , False );
42
45
46
+ def _transpose (self , state ):
47
+ return state .transpose ((2 , 0 , 1 ));
48
+
43
49
def get_state_and_action_dim(self):
    '''(state_dim, action_choice)

    For environments whose lowercase name is listed in ``image_envs``,
    ``state_dim`` is the observation-space shape with its three axes
    reordered to (2, 0, 1), the same permutation applied to the states of
    those environments. For every other environment ``state_dim`` is the
    first entry of the observation-space shape.

    The shape is read from ``observation_space`` rather than from a fresh
    ``env.reset()`` observation, so querying the dimensions does not reset
    the wrapped environment as a side effect.

    Returns:
        A ``(state_dim, n_actions)`` pair, where ``n_actions`` is
        ``env.action_space.n``.
    '''
    env = self.main_body.env
    obs_shape = env.observation_space.shape
    if self.name.lower() in image_envs:
        # Same (2, 0, 1) permutation as the state transpose, applied to
        # the shape tuple instead of an actual observation.
        state_dim = (obs_shape[2], obs_shape[0], obs_shape[1])
    else:
        state_dim = obs_shape[0]
    return state_dim, env.action_space.n
47
57
48
58
def get_state(self):
    '''Get the current state as a float32 NumPy array.

    Reads the state stored on ``self.main_body`` (not an attribute of the
    wrapped environment) and converts it with ``np.asarray``.

    Returns:
        ``self.main_body.state`` converted to a ``np.float32`` array.
    '''
    current = self.main_body.state
    return np.asarray(current, dtype=np.float32)
51
61
52
62
def get_total_reward (self ):
53
63
'''Get the total rewards of the current trajectory so far'''
@@ -93,6 +103,9 @@ def reset(self):
93
103
self .main_body .t = 0 ;
94
104
self .main_body .is_terminated = False ;
95
105
state , _ = self .main_body .env .reset ();
106
+ if self .name .lower () in image_envs :
107
+ state = self ._transpose (state );
108
+ self .main_body .state = state ;
96
109
return state ;
97
110
98
111
def step (self , action ):
@@ -101,6 +114,9 @@ def step(self, action):
101
114
raise RuntimeError
102
115
self .main_body .t += 1 ;
103
116
next_state , reward , done , info1 , info2 = self .main_body .env .step (action );
117
+ if self .name .lower () in image_envs :
118
+ next_state = self ._transpose (next_state );
119
+ self .main_body .state = next_state ;
104
120
self .main_body .total_reward += reward ;
105
121
if self .main_body .t == self .survival_T :
106
122
done = True ;
0 commit comments