style: rename some identifiers (#31)
StepNeverStop committed Jan 4, 2021
1 parent f6bd14f commit 9bd1ad4
Showing 5 changed files with 20 additions and 23 deletions.
21 changes: 9 additions & 12 deletions rls/algos/hierarchical/hiro.py
@@ -17,8 +17,8 @@
                              BatchExperiences,
                              NamedTupleStaticClass)
 
-LowBatchExperiences = namedtuple('LowBatchExperiences', BatchExperiences._fields + ('subgoal', 'next_subgoal'))
-HighBatchExperiences = namedtuple('HighBatchExperiences', 'obs, action, reward, done, subgoal, obs_')
+Low_BatchExperiences = namedtuple('Low_BatchExperiences', BatchExperiences._fields + ('subgoal', 'next_subgoal'))
+High_BatchExperiences = namedtuple('High_BatchExperiences', 'obs, action, reward, done, subgoal, obs_')
 
 class HIRO(Off_Policy):
     '''
@@ -170,14 +170,13 @@ def store_high_buffer(self, i):
             g.append(self._subgoals[i][right])
             d.append(self._done[i][-1])
             s_.append(self._high_s_[i][-1])
-        self.data_high.add(HighBatchExperiences(
+        self.data_high.add(High_BatchExperiences(
             np.array(s),
             np.array(a),
-            np.array(r)[:, np.newaxis],
-            np.array(d)[:, np.newaxis],
+            np.array(r),
+            np.array(d),
             np.array(g),
             np.array(s_)
-
         ))
 
     def reset(self):
@@ -266,7 +265,7 @@ def learn(self, **kwargs):
             self.write_training_summaries(self.global_step, self.summaries)
 
     @tf.function
-    def train_low(self, BATCH: LowBatchExperiences):
+    def train_low(self, BATCH: Low_BatchExperiences):
         with tf.device(self.device):
             with tf.GradientTape() as tape:
                 feat = tf.concat([BATCH.obs.vector, BATCH.subgoal], axis=-1)
@@ -323,7 +322,7 @@ def train_low(self, BATCH: LowBatchExperiences):
             ])
 
     @tf.function
-    def train_high(self, BATCH: HighBatchExperiences):
+    def train_high(self, BATCH: High_BatchExperiences):
         # BATCH.obs_ : [B, N]
         # BATCH.obs, BATCH.action [B, T, *]
         batchs = tf.shape(BATCH.obs)[0]
@@ -407,8 +406,7 @@ def no_op_store(self, exps: BatchExperiences):
         # subgoal = exps.obs.vector[:, self.fn_goal_dim:] + self._noop_subgoal - exps.obs_.vector[:, self.fn_goal_dim:]
         subgoal = np.random.uniform(-self.high_scale, self.high_scale, size=(self.n_agents, self.sub_goal_dim))
 
-        exps = exps._replace(done=exps.done)
-        dl = LowBatchExperiences(*exps, self._noop_subgoal, subgoal)._replace(reward=ir)
+        dl = Low_BatchExperiences(*exps, self._noop_subgoal, subgoal)._replace(reward=ir)
         self.data_low.add(dl)
         self._noop_subgoal = subgoal
 
@@ -426,8 +424,7 @@ def store_data(self, exps: BatchExperiences):
         ir = self.get_ir(exps.obs.vector[:, self.fn_goal_dim:], self._subgoal, exps.obs_.vector[:, self.fn_goal_dim:])
         self._new_subgoal = np.where(self._c == 1, self.get_subgoal(exps.obs_.vector).numpy(), exps.obs.vector[:, self.fn_goal_dim:] + self._subgoal - exps.obs_.vector[:, self.fn_goal_dim:])
 
-        exps = exps._replace(done=exps.done)
-        dl = LowBatchExperiences(*exps, self._subgoal, self._new_subgoal)._replace(reward=ir)
+        dl = Low_BatchExperiences(*exps, self._subgoal, self._new_subgoal)._replace(reward=ir)
         self.data_low.add(dl)
 
         self._c = np.where(self._c == 1, np.full((self.n_agents, 1), self.sub_goal_steps, np.int32), self._c - 1)
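Note on the pattern above: every *_BatchExperiences tuple in this commit is built the same way, by appending algorithm-specific fields to the shared BatchExperiences fields and filling the extra slots via positional unpacking of a base experience. A minimal, self-contained sketch of that idiom; the base tuple below is a simplified stand-in for the real rls.utils.specs.BatchExperiences, and the field order shown is only illustrative:

    from collections import namedtuple

    import numpy as np

    # Simplified stand-in for rls.utils.specs.BatchExperiences (illustration only).
    BatchExperiences = namedtuple('BatchExperiences', 'obs, action, reward, obs_, done')

    # Same construction as in hiro.py: reuse the base fields and append the subgoal slots.
    Low_BatchExperiences = namedtuple(
        'Low_BatchExperiences',
        BatchExperiences._fields + ('subgoal', 'next_subgoal'))

    exps = BatchExperiences(obs=np.zeros(4), action=0, reward=0.0, obs_=np.zeros(4), done=False)
    ir = -0.5              # stand-in for the intrinsic reward computed by get_ir()
    subgoal = np.zeros(2)
    next_subgoal = np.zeros(2)

    # Mirrors `Low_BatchExperiences(*exps, self._subgoal, self._new_subgoal)._replace(reward=ir)`:
    # unpack the base experience positionally, append the extra fields, then swap in the reward.
    dl = Low_BatchExperiences(*exps, subgoal, next_subgoal)._replace(reward=ir)
    print(dl._fields)
    # ('obs', 'action', 'reward', 'obs_', 'done', 'subgoal', 'next_subgoal')

This also shows why the removed `exps = exps._replace(done=exps.done)` line was a no-op: replacing a field with its own value returns an equivalent tuple.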
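The store_high_buffer hunk also drops the `[:, np.newaxis]` expansion, so the high-level buffer now receives reward and done as flat [B] arrays rather than [B, 1] columns. A quick shape check of the two forms:

    import numpy as np

    r = [0.1, 0.2, 0.3]
    print(np.array(r).shape)                  # (3,)   -- the form stored after this commit
    print(np.array(r)[:, np.newaxis].shape)   # (3, 1) -- the previous column-vector form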
4 changes: 2 additions & 2 deletions rls/algos/hierarchical/ioc.py
@@ -16,7 +16,7 @@
 from rls.utils.specs import (OutputNetworkType,
                              BatchExperiences)
 
-IOCBatchExperiences = namedtuple('IOCBatchExperiences', BatchExperiences._fields + ('last_options', 'options'))
+IOC_BatchExperiences = namedtuple('IOC_BatchExperiences', BatchExperiences._fields + ('last_options', 'options'))
 
 class IOC(Off_Policy):
     '''
@@ -260,7 +260,7 @@ def store_data(self, exps: BatchExperiences):
         for off-policy training, use this function to store <s, a, r, s_, done> into ReplayBuffer.
         """
         self._running_average(exps.obs.vector)
-        self.data.add(IOCBatchExperiences(*exps, self.last_options, self.options))
+        self.data.add(IOC_BatchExperiences(*exps, self.last_options, self.options))
 
     def no_op_store(self, exps: BatchExperiences):
         pass
4 changes: 2 additions & 2 deletions rls/algos/hierarchical/oc.py
@@ -17,7 +17,7 @@
 from rls.utils.specs import (OutputNetworkType,
                              BatchExperiences)
 
-OCBatchExperiences = namedtuple('OCBatchExperiences', BatchExperiences._fields + ('last_options', 'options'))
+OC_BatchExperiences = namedtuple('OC_BatchExperiences', BatchExperiences._fields + ('last_options', 'options'))
 
 class OC(Off_Policy):
     '''
@@ -263,7 +263,7 @@ def store_data(self, exps: BatchExperiences):
         for off-policy training, use this function to store <s, a, r, s_, done> into ReplayBuffer.
         """
         self._running_average(exps.obs.vector)
-        self.data.add(OCBatchExperiences(*exps, self.last_options, self.options))
+        self.data.add(OC_BatchExperiences(*exps, self.last_options, self.options))
 
     def no_op_store(self, exps: BatchExperiences):
         pass
6 changes: 3 additions & 3 deletions rls/algos/single/ac.py
@@ -15,7 +15,7 @@
 from rls.utils.specs import (OutputNetworkType,
                              BatchExperiences)
 
-ACBatchExperiences = namedtuple('ACBatchExperiences', BatchExperiences._fields + ('old_log_prob',))
+AC_BatchExperiences = namedtuple('AC_BatchExperiences', BatchExperiences._fields + ('old_log_prob',))
 
 class AC(Off_Policy):
     # off-policy actor-critic
@@ -89,11 +89,11 @@ def _get_action(self, obs, cell_state):
 
     def store_data(self, exps: BatchExperiences):
         self._running_average(exps.obs.vector)
-        self.data.add(ACBatchExperiences(*exps, self._log_prob))
+        self.data.add(AC_BatchExperiences(*exps, self._log_prob))
 
     def no_op_store(self, exps: BatchExperiences):
         self._running_average(exps.obs.vector)
-        self.data.add(ACBatchExperiences(*exps, np.ones_like(exps.reward)))
+        self.data.add(AC_BatchExperiences(*exps, np.ones_like(exps.reward)))
 
     def learn(self, **kwargs):
         self.train_step = kwargs.get('train_step')
8 changes: 4 additions & 4 deletions rls/algos/single/curl.py
@@ -29,7 +29,7 @@
                              BatchExperiences,
                              NamedTupleStaticClass)
 
-CURLBatchExperiences = namedtuple('CURLBatchExperiences', BatchExperiences._fields + ('pos',))
+CURL_BatchExperiences = namedtuple('CURL_BatchExperiences', BatchExperiences._fields + ('pos',))
 
 
 class VisualEncoder(M):
@@ -212,7 +212,7 @@ def _get_action(self, obs):
             pi = cate_dist.sample()
             return mu, pi
 
-    def _process_before_train(self, data: BatchExperiences) -> CURLBatchExperiences:
+    def _process_before_train(self, data: BatchExperiences) -> CURL_BatchExperiences:
         data = data._replace(
             obs=data.obs._replace(visual=np.transpose(data.obs.visual[:, 0].numpy(), (0, 3, 1, 2))),
             obs_=data.obs_._replace(visual=np.transpose(data.obs_.visual[:, 0].numpy(), (0, 3, 1, 2))))
@@ -221,7 +221,7 @@ def _process_before_train(self, data: BatchExperiences) -> CURLBatchExperiences:
             obs=data.obs._replace(visual=np.transpose(random_crop(data.obs.visual, self.img_size), (0, 2, 3, 1))),
             obs_=data.obs_._replace(visual=np.transpose(random_crop(data.obs_.visual, self.img_size), (0, 2, 3, 1)))
         )
-        new_data = CURLBatchExperiences(*data, pos)
+        new_data = CURL_BatchExperiences(*data, pos)
         return NamedTupleStaticClass.data_convert(self.data_convert, new_data)
 
     def _target_params_update(self):
@@ -253,7 +253,7 @@ def _train(self, BATCH: BatchExperiences, isw, cell_state):
         return td_error, summaries
 
     @tf.function
-    def train(self, BATCH: CURLBatchExperiences, isw, cell_state):
+    def train(self, BATCH: CURL_BatchExperiences, isw, cell_state):
         with tf.device(self.device):
             with tf.GradientTape(persistent=True) as tape:
                 vis_feat = self.encoder(visual_s)
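For context on the curl.py hunks: _process_before_train moves the stored NHWC frames to NCHW, random-crops them, and transposes back to NHWC; the `pos` field it appends is an extra augmented view for CURL's contrastive objective, produced by a line outside this hunk. A rough, self-contained sketch of the visible transpose/crop round-trip; the random_crop below is a hypothetical stand-in for the repository's own helper, and the image sizes are made up:

    import numpy as np

    def random_crop(imgs_nchw: np.ndarray, out_size: int) -> np.ndarray:
        """Crop each channel-first image at a random location (stand-in helper)."""
        b, c, h, w = imgs_nchw.shape
        ys = np.random.randint(0, h - out_size + 1, size=b)
        xs = np.random.randint(0, w - out_size + 1, size=b)
        out = np.empty((b, c, out_size, out_size), dtype=imgs_nchw.dtype)
        for i, (y, x) in enumerate(zip(ys, xs)):
            out[i] = imgs_nchw[i, :, y:y + out_size, x:x + out_size]
        return out

    batch = np.random.rand(4, 84, 84, 3).astype(np.float32)    # NHWC, as held in the buffer
    nchw = np.transpose(batch, (0, 3, 1, 2))                    # to NCHW before cropping
    anchor = np.transpose(random_crop(nchw, 64), (0, 2, 3, 1))  # back to NHWC for the encoder
    pos = np.transpose(random_crop(nchw, 64), (0, 2, 3, 1))     # an independently cropped positive view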
