env_offlineReward.py
import pickle

import dm_env
from dm_env import specs
import numpy as np


def _read_transitions(filename):
    """Loads the pickled offline transitions from `filename`."""
    with open(filename, 'rb') as pickle_file:
        transitions = pickle.load(pickle_file)
    return transitions
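
# Data format note (an assumption, inferred from how CatEnv consumes the file below):
# the pickle named by `data_filename` (default 'SARS.pkl') is expected to hold an
# iterable of 4-tuples, e.g. (state, action, reward, next_state); only the reward and
# the final 55-dim composition vector are used to build the reward lookup table.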


class CatEnv(dm_env.Environment):
    """Composition environment whose rewards come from offline (S, A, R, S') data."""

    def __init__(self, max_episode_len=10, data_filename='SARS.pkl'):
        transitions_raw = _read_transitions(data_filename)
        self.states = {}  # key: 55-dim state vector (as a tuple), value: reward.
        self.initial_states = set()
        for _, _, reward, state in transitions_raw:
            # Numpy arrays aren't hashable, so convert states to tuples.
            # We maximize reward but want minimum energy --> flip the sign.
            self.states[tuple(np.array(state, dtype=np.float32))] = float(-reward)
            # Single-element states are the valid starting points.
            if np.sum(state) == 1:
                self.initial_states.add(tuple(state))
        self.initial_states = [np.array(t, dtype=np.float32) for t in self.initial_states]
        self.dim = len(self.initial_states[0])
        self.max_episode_len = max_episode_len
        self.curr_state = None
        self.episode_len = 0
        self._reset_next_step = True
        # Penalty for going off the behavior data, and the terminal state SiC.
        self.penalty_reward = -100.0
        self.terminal_state = np.zeros(self.dim, dtype=np.float32)
        self.terminal_state[6] = 1.0   # C
        self.terminal_state[42] = 1.0  # Si
        self.cumulative_reward = 0.0

    def reset(self) -> dm_env.TimeStep:
        self._reset_next_step = False
        rand_start_idx = np.random.choice(len(self.initial_states))
        self.curr_state = np.copy(self.initial_states[rand_start_idx])
        self.episode_len = 0
        self.cumulative_reward = 0.0  # Start each episode with a fresh running total.
        return dm_env.TimeStep(dm_env.StepType.FIRST,
                               self.states[tuple(self.curr_state)],
                               1.0,
                               self.curr_state)

    def step(self, int_action):
        if self._reset_next_step:
            return self.reset()
        self.episode_len += 1

        # Decode the integer action into an additive change to the state vector.
        action = np.zeros(self.dim, dtype=np.float32)
        if int_action < self.dim:
            # Add one unit of element `int_action`.
            action[int_action] = 1.0
        elif int_action < self.dim + 3:
            # Delete the 0th, 1st or 2nd element present in curr_state. If the deletion
            # index exceeds the number of elements present, delete the last one.
            to_delete = int(min(np.sum(self.curr_state) - 1, int_action - self.dim))
            idx_to_remove = np.where(self.curr_state != 0)[0][to_delete]
            action[idx_to_remove] = -1.0
        else:
            # This is our "end" action; it's just a zero-vector no-op.
            pass

        # A zero action means the agent asks for the episode to end.
        if np.sum(action) == 0 or self.episode_len == self.max_episode_len:
            self._reset_next_step = True
            print('Final state: ', self.episode_len, np.where(self.curr_state)[0],
                  self.cumulative_reward)
            return dm_env.TimeStep(dm_env.StepType.LAST, 0.0, 1.0, self.curr_state)

        # Take the action.
        self.curr_state += action

        # Check whether the action is valid. If not, end the episode with a penalty.
        invalid = False
        if np.sum(self.curr_state) > 3.0 or np.sum(self.curr_state) <= 0:
            invalid = True
        elif np.sum(action) < 0 and np.any(self.curr_state < 0):
            invalid = True
        elif np.any(self.curr_state == 2.0):
            invalid = True
        if invalid:
            self._reset_next_step = True
            reward = self.penalty_reward
            self.cumulative_reward += reward
            return dm_env.TimeStep(dm_env.StepType.LAST, reward, 1.0, self.curr_state)

        # The action is valid, but we may not have offline data for the resulting state.
        if tuple(self.curr_state) not in self.states:
            reward = self.penalty_reward
            self.cumulative_reward += reward
            return dm_env.TimeStep(dm_env.StepType.MID, reward, 1.0, self.curr_state)

        # We're arriving at a valid state with a known reward.
        reward = self.states[tuple(self.curr_state)]
        self.cumulative_reward += reward

        # Terminal condition: reached the target SiC state.
        if np.array_equal(self.curr_state, self.terminal_state):
            self._reset_next_step = True
            print('Final state: ', self.episode_len, np.where(self.curr_state)[0],
                  self.cumulative_reward)
            return dm_env.TimeStep(dm_env.StepType.LAST, reward, 1.0, self.curr_state)
        return dm_env.TimeStep(dm_env.StepType.MID, reward, 1.0, self.curr_state)

    def action_spec(self):
        # Actions: add one of the `dim` (55) elements, remove one element (the first,
        # second, or third one present), or terminate. Some of these actions are
        # invalid depending on the state; step() handles those cases.
        return specs.DiscreteArray(self.dim + 3 + 1)

    def observation_spec(self):
        return specs.BoundedArray(
            shape=(self.dim,),
            dtype=np.float32,
            minimum=0,
            maximum=1,
        )
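

# A minimal usage sketch (assumption: a compatible 'SARS.pkl' file exists in the working
# directory). It runs a single episode with uniformly random actions, purely to exercise
# the environment API; it is not part of the original module and not a sensible policy.
if __name__ == '__main__':
    env = CatEnv(max_episode_len=10, data_filename='SARS.pkl')
    spec = env.action_spec()
    timestep = env.reset()
    episode_return = 0.0
    while not timestep.last():
        action = np.random.randint(spec.num_values)
        timestep = env.step(action)
        episode_return += timestep.reward
    print('Episode return:', episode_return)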