forked from cts198859/deeprl_signal_control
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_env.py
39 lines (32 loc) · 1.05 KB
/
test_env.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import gym
import numpy as np
class GymEnv:
    """Wrap one gym environment behind a list-per-agent style interface.

    Observations are rescaled per dimension with the bounds of the
    environment's observation space, and every episode is cut off after
    ``T`` steps.  Actions are read as ``action[0]``; observations and
    rewards are returned inside single-element containers.
    NOTE(review): the single-element lists look like they mimic a
    multi-agent environment API with exactly one agent -- confirm against
    the caller.

    Attributes:
        n_a: number of discrete actions (``action_space.n``).
        n_a_ls: ``[n_a]``.
        n_s: first dimension of the observation space shape.
        n_s_ls: ``[n_s]``.
        agent: the string ``'iqld'`` (meaning is defined by the caller).
        T: maximum number of steps per episode.
        t: number of steps taken since the last ``reset()``.
        seed: the wrapped environment's own ``seed`` callable.
    """

    def __init__(self, name):
        """Create the gym environment registered under ``name``."""
        env = gym.make(name)
        self._env = env
        self.n_a = env.action_space.n
        self.n_a_ls = [self.n_a]
        self.n_s = env.observation_space.shape[0]
        self.n_s_ls = [self.n_s]
        self.agent = 'iqld'
        self.T = 1000
        self.t = 0
        self._s_mean, self._s_scale = self._scale_params(
            env.observation_space.low, env.observation_space.high)
        # Seeding is delegated directly to the wrapped environment.
        self.seed = env.seed

    @staticmethod
    def _scale_params(low, high):
        """Return ``(mean, scale)`` arrays used to normalise observations.

        For a bounded dimension ``mean`` is the midpoint of ``[low, high]``
        and ``scale`` is half its width.  A dimension whose midpoint or
        half-width is not finite (infinite bounds, or bounds so large that
        the subtraction overflows) or whose half-width is zero gets
        ``mean = 0`` and ``scale = 1``, i.e. it is passed through unscaled
        instead of turning into NaN or a constant zero.
        """
        low = np.array(low)
        high = np.array(high)
        # Overflow / inf - inf are expected for unbounded dimensions and
        # are repaired right below, so keep numpy quiet about them.
        with np.errstate(over='ignore', invalid='ignore'):
            mean = .5 * (low + high)
            scale = .5 * (high - low)
        unusable = ~np.isfinite(mean) | ~np.isfinite(scale) | (scale == 0)
        mean[unusable] = 0
        scale[unusable] = 1
        return mean, scale

    def _scale_ob(self, ob):
        """Normalise one raw observation with the stored mean and scale."""
        return (np.array(ob) - self._s_mean) / self._s_scale

    def reset(self):
        """Reset the environment and the step counter.

        Returns:
            A one-element list holding the scaled initial observation.
        """
        self.t = 0
        return [self._scale_ob(self._env.reset())]

    def step(self, action):
        """Apply ``action[0]`` to the wrapped environment.

        Returns:
            A 4-tuple ``(obs, rewards, done, reward)``: ``obs`` is an array
            holding the single scaled, flattened observation, ``rewards``
            is a one-element array with the step reward, ``done`` is the
            environment's flag or ``True`` once ``T`` steps have been
            taken, and ``reward`` is the same step reward as a scalar.
        """
        # The environment's ``info`` value is intentionally not returned.
        ob, r, done, info = self._env.step(action[0])
        self.t += 1
        if self.t >= self.T:
            # Enforce the episode length limit.
            done = True
        return np.array([self._scale_ob(np.ravel(ob))]), np.array([r]), done, r

    def terminate(self):
        """Do nothing; present so callers can invoke ``terminate()``."""
        return