Skip to content

Commit

Permalink
FrozenLake environment enhanced with static reachable state distribution provider
Browse files Browse the repository at this point in the history
  • Loading branch information
erwanlecarpentier committed Nov 19, 2018
1 parent c1aaecc commit 3aa6d47
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 15 deletions.
37 changes: 24 additions & 13 deletions dyna_gym/envs/ns_frozen_lake.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def get_position(s):
p = s[0]
else:
p = s
assert(isinstance(p, np.int64) or isinstance(p, int))
assert (isinstance(p, np.int64) or isinstance(p, int)), 'Error: position type is not int: type={}'.format(type(p))
return p

class NSFrozenLakeEnv(Env):
Expand Down Expand Up @@ -111,7 +111,7 @@ def __init__(self, desc=None, map_name="random", map_size=(3,5), is_slippery=Tru
self.L_r = 10 # reward function Lipschitz constant

self.action_space = spaces.Discrete(self.nA)
self.pos_space = np.array(range(self.nS))
self.pos_space = np.arange(self.nS)
self.observation_space = spaces.Discrete(self.nS)

self.is_slippery = is_slippery
Expand Down Expand Up @@ -182,18 +182,18 @@ def generate_transition_matrix(self):

def transition_probability_distribution(self, s, t, a):
p = get_position(s)
assert(p < self.nS)
assert(t < self.nT)
assert(a < self.nA)
assert p < self.nS, 'Error: position bigger than nS: p={} nS={}'.format(p, nS)
assert t < self.nT, 'Error: time bigger than nT: t={} nT={}'.format(t, nT)
assert a < self.nA, 'Error: action bigger than nA: a={} nA={}'.format(a, nA)
return self.T[p, a, t]

def transition_probability(self, s_p, s, t, a):
p = get_position(s)
p_p = get_position(s_p)
assert(p_p < self.nS)
assert(p < self.nS)
assert(t < self.nT)
assert(a < self.nA)
assert p_p < self.nS, 'Error: position bigger than nS: p_p={} nS={}'.format(p_p, nS)
assert p < self.nS, 'Error: position bigger than nS: p={} nS={}'.format(p, nS)
assert t < self.nT, 'Error: time bigger than nT: t={} nT={}'.format(t, nT)
assert a < self.nA, 'Error: action bigger than nA: a={} nA={}'.format(a, nA)
return self.T[p, a, t, p_p]

def is_terminal(self, s):
Expand All @@ -205,19 +205,30 @@ def is_terminal(self, s):
def get_time(self):
return self.state[1]

def get_state_space_at_time(self, t):
return [(x, t) for x in self.pos_space]
def static_reachable_states(self, s, a):
"""
Return an array of the reachable states.
Static means that no time increment is performed.
"""
rs = self.reachable_states(s[0], a)
srs = np.zeros(shape=sum(rs), dtype=tuple)
idx = 0
for i in range(len(rs)):
if rs[i] == 1:
srs[idx] = (i, s[1])
idx += 1
return srs

def equality_operator(self, s1, s2):
return (s1 == s2)

def transition(self, s, a, is_model_dynamic=True):
'''
"""
Transition operator, return the resulting state, reward and a boolean indicating
whether the termination criterion is reached or not.
The boolean is_model_dynamic indicates whether the temporal transition is applied
to the state vector or not.
'''
"""
p, t = s
d = self.transition_probability_distribution(p, t, a)
p_p = categorical_sample(d, self.np_random)
Expand Down
3 changes: 1 addition & 2 deletions dyna_gym/utils/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ def worst_dist(v, w0, c):
Ae = np.reshape(np.concatenate((np.ones(n),np.zeros(n))), newshape=(1,2*n))
be = np.asarray([1])

res = linprog(obj, A_eq=Ae, b_eq=be, A_ub=A, b_ub=b)#, method='interior-point')

res = linprog(obj, A_eq=Ae, b_eq=be, A_ub=A, b_ub=b)
x = res.x[:n]
return clean_distribution(x)

0 comments on commit 3aa6d47

Please sign in to comment.