From 3aa6d47b17d47f458eb5c555581581611ae6fa04 Mon Sep 17 00:00:00 2001 From: erwanlecarpentier Date: Mon, 19 Nov 2018 15:53:08 +0100 Subject: [PATCH] FrozenLake environment enhanced with static reachable state distribution provider --- dyna_gym/envs/ns_frozen_lake.py | 37 +++++++++++++++++++++------------ dyna_gym/utils/distribution.py | 3 +-- 2 files changed, 25 insertions(+), 15 deletions(-) diff --git a/dyna_gym/envs/ns_frozen_lake.py b/dyna_gym/envs/ns_frozen_lake.py index 409b47eb..0704db9e 100644 --- a/dyna_gym/envs/ns_frozen_lake.py +++ b/dyna_gym/envs/ns_frozen_lake.py @@ -61,7 +61,7 @@ def get_position(s): p = s[0] else: p = s - assert(isinstance(p, np.int64) or isinstance(p, int)) + assert (isinstance(p, np.int64) or isinstance(p, int)), 'Error: position type is not int: type={}'.format(type(p)) return p class NSFrozenLakeEnv(Env): @@ -111,7 +111,7 @@ def __init__(self, desc=None, map_name="random", map_size=(3,5), is_slippery=Tru self.L_r = 10 # reward function Lipschitz constant self.action_space = spaces.Discrete(self.nA) - self.pos_space = np.array(range(self.nS)) + self.pos_space = np.arange(self.nS) self.observation_space = spaces.Discrete(self.nS) self.is_slippery = is_slippery @@ -182,18 +182,18 @@ def generate_transition_matrix(self): def transition_probability_distribution(self, s, t, a): p = get_position(s) - assert(p < self.nS) - assert(t < self.nT) - assert(a < self.nA) + assert p < self.nS, 'Error: position bigger than nS: p={} nS={}'.format(p, nS) + assert t < self.nT, 'Error: time bigger than nT: t={} nT={}'.format(t, nT) + assert a < self.nA, 'Error: action bigger than nA: a={} nA={}'.format(a, nA) return self.T[p, a, t] def transition_probability(self, s_p, s, t, a): p = get_position(s) p_p = get_position(s_p) - assert(p_p < self.nS) - assert(p < self.nS) - assert(t < self.nT) - assert(a < self.nA) + assert p_p < self.nS, 'Error: position bigger than nS: p_p={} nS={}'.format(p_p, nS) + assert p < self.nS, 'Error: position bigger than nS: p={} nS={}'.format(p, nS) + assert t < self.nT, 'Error: time bigger than nT: t={} nT={}'.format(t, nT) + assert a < self.nA, 'Error: action bigger than nA: a={} nA={}'.format(a, nA) return self.T[p, a, t, p_p] def is_terminal(self, s): @@ -205,19 +205,30 @@ def is_terminal(self, s): def get_time(self): return self.state[1] - def get_state_space_at_time(self, t): - return [(x, t) for x in self.pos_space] + def static_reachable_states(self, s, a): + """ + Return an array of the reachable states. + Static means that no time increment is performed. + """ + rs = self.reachable_states(s[0], a) + srs = np.zeros(shape=sum(rs), dtype=tuple) + idx = 0 + for i in range(len(rs)): + if rs[i] == 1: + srs[idx] = (i, s[1]) + idx += 1 + return srs def equality_operator(self, s1, s2): return (s1 == s2) def transition(self, s, a, is_model_dynamic=True): - ''' + """ Transition operator, return the resulting state, reward and a boolean indicating whether the termination criterion is reached or not. The boolean is_model_dynamic indicates whether the temporal transition is applied to the state vector or not. - ''' + """ p, t = s d = self.transition_probability_distribution(p, t, a) p_p = categorical_sample(d, self.np_random) diff --git a/dyna_gym/utils/distribution.py b/dyna_gym/utils/distribution.py index ebbfbde2..86d61810 100644 --- a/dyna_gym/utils/distribution.py +++ b/dyna_gym/utils/distribution.py @@ -66,7 +66,6 @@ def worst_dist(v, w0, c): Ae = np.reshape(np.concatenate((np.ones(n),np.zeros(n))), newshape=(1,2*n)) be = np.asarray([1]) - res = linprog(obj, A_eq=Ae, b_eq=be, A_ub=A, b_ub=b)#, method='interior-point') - + res = linprog(obj, A_eq=Ae, b_eq=be, A_ub=A, b_ub=b) x = res.x[:n] return clean_distribution(x)