From 309d63fbc1e06f5f47de145ae2cc3bf18b4e8e21 Mon Sep 17 00:00:00 2001 From: Keavnn Date: Fri, 3 Sep 2021 22:10:58 +0800 Subject: [PATCH] v5.1.10 perf: optimized `dreamer`-related code. (#34) 1. renamed `iTensor_oNumpy` to `iton` 2. optimized `auto_format.py` 3. added a general `oplr_params` parameter for initializing optimizers --- auto_format.py | 16 +++-- rls/_metadata.py | 2 +- rls/algorithms/base/marl_off_policy.py | 4 +- rls/algorithms/base/policy.py | 4 +- rls/algorithms/base/sarl_off_policy.py | 4 +- rls/algorithms/multi/maddpg.py | 10 +-- rls/algorithms/multi/masac.py | 12 ++-- rls/algorithms/multi/qplex.py | 4 +- rls/algorithms/multi/qtran.py | 4 +- rls/algorithms/multi/vdn.py | 8 +-- rls/algorithms/single/a2c.py | 12 ++-- rls/algorithms/single/ac.py | 10 +-- rls/algorithms/single/averaged_dqn.py | 8 +-- rls/algorithms/single/bootstrappeddqn.py | 8 +-- rls/algorithms/single/c51.py | 8 +-- rls/algorithms/single/dddqn.py | 8 +-- rls/algorithms/single/ddpg.py | 10 +-- rls/algorithms/single/ddqn.py | 4 +- rls/algorithms/single/dpg.py | 10 +-- rls/algorithms/single/dqn.py | 8 +-- rls/algorithms/single/hierarchical/aoc.py | 12 ++-- rls/algorithms/single/hierarchical/ioc.py | 16 ++--- rls/algorithms/single/hierarchical/oc.py | 14 ++-- rls/algorithms/single/hierarchical/ppoc.py | 12 ++-- rls/algorithms/single/iqn.py | 8 +-- rls/algorithms/single/maxsqn.py | 10 +-- .../single/modelbased/dreamer_v1.py | 13 ++-- rls/algorithms/single/modelbased/planet.py | 67 +++++++------------ rls/algorithms/single/npg.py | 10 +-- rls/algorithms/single/pg.py | 8 +-- rls/algorithms/single/ppo.py | 31 +++------ rls/algorithms/single/qrdqn.py | 8 +-- rls/algorithms/single/rainbow.py | 8 +-- rls/algorithms/single/sac.py | 12 ++-- rls/algorithms/single/sac_v.py | 14 ++-- rls/algorithms/single/sql.py | 8 +-- rls/algorithms/single/tac.py | 12 ++-- rls/algorithms/single/td3.py | 10 +-- rls/algorithms/single/trpo.py | 4 +- rls/common/decorator.py | 2 +- rls/common/specs.py | 19 ++---- rls/configs/algorithms.yaml | 18 ++--- rls/nn/modules/icm.py | 2 +- rls/nn/utils.py | 18 ++--- 44 files changed, 224 insertions(+), 266 deletions(-) diff --git a/auto_format.py b/auto_format.py index fa2ee3a..424341c 100644 --- a/auto_format.py +++ b/auto_format.py @@ -13,13 +13,17 @@ def get_args(): help='.py file path that need to be formatted.') parser.add_argument('-d', '--file-dir', type=str, default=None, help='.py dictionary that need to be formatted.') + parser.add_argument('--ignore-pep', default=False, action='store_true', + help='whether to skip autopep8 formatting') return parser.parse_args() -def autopep8(file_path): +def autopep8(file_path, ignore_pep): isort.file(file_path) - os.system(f"autopep8 -i {file_path}") - print(f'autopep8 file: {file_path} SUCCESSFULLY.') + print(f'isort file: {file_path} SUCCESSFULLY.') + if not ignore_pep: + os.system(f"autopep8 -j 0 -i {file_path} --max-line-length 200") + print(f'autopep8 file: {file_path} SUCCESSFULLY.') if __name__ == '__main__': @@ -27,7 +31,7 @@ args = get_args() if args.file_path: - autopep8(args.file_path) + autopep8(args.file_path, args.ignore_pep) if args.file_dir: py_files = [] @@ -35,6 +39,6 @@ def autopep8(file_path): py_files.extend(glob.glob(root + "/*.py")) for path in py_files: - autopep8(path) + autopep8(path, args.ignore_pep) - print('autopep8 finished.') + print('auto-format finished.') diff --git a/rls/_metadata.py b/rls/_metadata.py index b640f46..d25dc09 100644 --- a/rls/_metadata.py +++ b/rls/_metadata.py @@ -8,7 +8,7 @@ # We follow Semantic
Versioning (https://semver.org/) _MAJOR_VERSION = '5' _MINOR_VERSION = '1' -_PATCH_VERSION = '9' +_PATCH_VERSION = '10' # Example: '0.4.2' __version__ = '.'.join([_MAJOR_VERSION, _MINOR_VERSION, _PATCH_VERSION]) diff --git a/rls/algorithms/base/marl_off_policy.py b/rls/algorithms/base/marl_off_policy.py index a144dd7..4aaa915 100644 --- a/rls/algorithms/base/marl_off_policy.py +++ b/rls/algorithms/base/marl_off_policy.py @@ -8,7 +8,7 @@ import torch as t from rls.algorithms.base.marl_policy import MarlPolicy -from rls.common.decorator import iTensor_oNumpy +from rls.common.decorator import iton from rls.common.specs import Data from rls.common.yaml_ops import load_config from rls.utils.np_utils import int2one_hot @@ -107,7 +107,7 @@ def _before_train(self, BATCH_DICT): self.summaries = {} return BATCH_DICT - @iTensor_oNumpy + @iton def _train(self, BATCH_DICT): raise NotImplementedError diff --git a/rls/algorithms/base/policy.py b/rls/algorithms/base/policy.py index ac314fe..311842f 100644 --- a/rls/algorithms/base/policy.py +++ b/rls/algorithms/base/policy.py @@ -37,7 +37,7 @@ def __init__(self, decay_lr=False, normalize_vector_obs=False, obs_with_pre_action=False, - optim_params=dict(), + oplr_params=dict(), rep_net_params={ 'vector_net_params': { 'h_dim': 16, @@ -80,7 +80,7 @@ def __init__(self, self._normalize_vector_obs = normalize_vector_obs # TODO: implement self._obs_with_pre_action = obs_with_pre_action self._rep_net_params = dict(rep_net_params) - self._optim_params = dict(optim_params) + self._oplr_params = dict(oplr_params) super().__init__() diff --git a/rls/algorithms/base/sarl_off_policy.py b/rls/algorithms/base/sarl_off_policy.py index 29b2754..1345253 100644 --- a/rls/algorithms/base/sarl_off_policy.py +++ b/rls/algorithms/base/sarl_off_policy.py @@ -7,7 +7,7 @@ import torch as t from rls.algorithms.base.sarl_policy import SarlPolicy -from rls.common.decorator import iTensor_oNumpy +from rls.common.decorator import iton from rls.common.specs import Data from rls.common.when import Every from rls.common.yaml_ops import load_config @@ -98,7 +98,7 @@ def _before_train(self, BATCH): self.summaries.update(crsty_summaries) return BATCH - @iTensor_oNumpy + @iton def _train(self, BATCH): raise NotImplementedError diff --git a/rls/algorithms/multi/maddpg.py b/rls/algorithms/multi/maddpg.py index 1f7f474..2005030 100644 --- a/rls/algorithms/multi/maddpg.py +++ b/rls/algorithms/multi/maddpg.py @@ -9,7 +9,7 @@ from torch import distributions as td from rls.algorithms.base.marl_off_policy import MultiAgentOffPolicy -from rls.common.decorator import iTensor_oNumpy +from rls.common.decorator import iton from rls.common.specs import Data from rls.nn.models import ActorDct, ActorDPG, MACriticQvalueOne from rls.nn.modules.wrappers import TargetTwin @@ -66,8 +66,8 @@ def __init__(self, self.a_dims.values()), network_settings=network_settings['q']), self.ployak).to(self.device) - self.actor_oplr = OPLR(list(self.actors.values()), actor_lr) - self.critic_oplr = OPLR(list(self.critics.values()), critic_lr) + self.actor_oplr = OPLR(list(self.actors.values()), actor_lr, **self._oplr_params) + self.critic_oplr = OPLR(list(self.critics.values()), critic_lr, **self._oplr_params) # TODO: 添加动作类型判断 self.noised_actions = {id: Noise_action_REGISTER[noise_action](**noise_params) @@ -85,7 +85,7 @@ def episode_reset(self): for noised_action in self.noised_actions.values(): noised_action.reset() - @iTensor_oNumpy + @iton def select_action(self, obs: Dict): acts_info = {} actions = {} @@ -106,7 +106,7 @@ def 
select_action(self, obs: Dict): actions[aid] = action return actions, acts_info - @iTensor_oNumpy + @iton def _train(self, BATCH_DICT): ''' TODO: Annotation diff --git a/rls/algorithms/multi/masac.py b/rls/algorithms/multi/masac.py index 142a68c..5b6bd41 100644 --- a/rls/algorithms/multi/masac.py +++ b/rls/algorithms/multi/masac.py @@ -10,7 +10,7 @@ from torch import distributions as td from rls.algorithms.base.marl_off_policy import MultiAgentOffPolicy -from rls.common.decorator import iTensor_oNumpy +from rls.common.decorator import iton from rls.common.specs import Data from rls.nn.models import ActorCts, ActorDct, MACriticQvalueOne from rls.nn.modules.wrappers import TargetTwin @@ -79,13 +79,13 @@ def __init__(self, network_settings=network_settings['q']), self.ployak).to(self.device) self.critics2[id] = deepcopy(self.critics[id]) - self.actor_oplr = OPLR(list(self.actors.values()), actor_lr) + self.actor_oplr = OPLR(list(self.actors.values()), actor_lr, **self._oplr_params) self.critic_oplr = OPLR( - list(self.critics.values())+list(self.critics2.values()), critic_lr) + list(self.critics.values())+list(self.critics2.values()), critic_lr, **self._oplr_params) if self.auto_adaption: self.log_alpha = t.tensor(0., requires_grad=True).to(self.device) - self.alpha_oplr = OPLR(self.log_alpha, alpha_lr) + self.alpha_oplr = OPLR(self.log_alpha, alpha_lr, **self._oplr_params) self._trainer_modules.update(alpha_oplr=self.alpha_oplr) else: self.log_alpha = t.tensor(alpha).log().to(self.device) @@ -106,7 +106,7 @@ def __init__(self, def alpha(self): return self.log_alpha.exp() - @iTensor_oNumpy + @iton def select_action(self, obs: Dict): acts_info = {} actions = {} @@ -128,7 +128,7 @@ def select_action(self, obs: Dict): actions[aid] = action return actions, acts_info - @iTensor_oNumpy + @iton def _train(self, BATCH_DICT): ''' TODO: Annotation diff --git a/rls/algorithms/multi/qplex.py b/rls/algorithms/multi/qplex.py index a0eb348..4f5d553 100644 --- a/rls/algorithms/multi/qplex.py +++ b/rls/algorithms/multi/qplex.py @@ -4,7 +4,7 @@ import torch as t from rls.algorithms.multi.vdn import VDN -from rls.common.decorator import iTensor_oNumpy +from rls.common.decorator import iton from rls.nn.mixers import Mixer_REGISTER from rls.nn.modules.wrappers import TargetTwin from rls.utils.torch_utils import n_step_return @@ -37,7 +37,7 @@ def _build_mixer(self): **self._mixer_settings) ).to(self.device) - @iTensor_oNumpy + @iton def _train(self, BATCH_DICT): summaries = {} reward = BATCH_DICT[self.agent_ids[0]].reward # [T, B, 1] diff --git a/rls/algorithms/multi/qtran.py b/rls/algorithms/multi/qtran.py index dc0df9e..3e8ec46 100644 --- a/rls/algorithms/multi/qtran.py +++ b/rls/algorithms/multi/qtran.py @@ -4,7 +4,7 @@ import torch as t from rls.algorithms.multi.vdn import VDN -from rls.common.decorator import iTensor_oNumpy +from rls.common.decorator import iton from rls.nn.mixers import Mixer_REGISTER from rls.nn.modules.wrappers import TargetTwin from rls.utils.torch_utils import n_step_return @@ -42,7 +42,7 @@ def _build_mixer(self): **self._mixer_settings) ).to(self.device) - @iTensor_oNumpy + @iton def _train(self, BATCH_DICT): summaries = {} reward = BATCH_DICT[self.agent_ids[0]].reward # [T, B, 1] diff --git a/rls/algorithms/multi/vdn.py b/rls/algorithms/multi/vdn.py index bfd10f6..abd7522 100644 --- a/rls/algorithms/multi/vdn.py +++ b/rls/algorithms/multi/vdn.py @@ -5,7 +5,7 @@ import torch as t from rls.algorithms.base.marl_off_policy import MultiAgentOffPolicy -from rls.common.decorator import 
iTensor_oNumpy +from rls.common.decorator import iton from rls.common.specs import Data from rls.nn.mixers import Mixer_REGISTER from rls.nn.models import CriticDueling @@ -61,7 +61,7 @@ def __init__(self, self.mixer = self._build_mixer() - self.oplr = OPLR(tuple(self.q_nets.values())+(self.mixer,), lr) + self.oplr = OPLR(tuple(self.q_nets.values())+(self.mixer,), lr, **self._oplr_params) self._trainer_modules.update( {f"model_{id}": self.q_nets[id] for id in set(self.model_ids)}) self._trainer_modules.update(mixer=self.mixer, @@ -79,7 +79,7 @@ def _build_mixer(self): **self._mixer_settings) ).to(self.device) - @iTensor_oNumpy # TODO: optimization + @iton # TODO: optimization def select_action(self, obs): acts_info = {} actions = {} @@ -97,7 +97,7 @@ def select_action(self, obs): acts_info[aid] = Data(action=action) return actions, acts_info - @iTensor_oNumpy + @iton def _train(self, BATCH_DICT): summaries = {} reward = BATCH_DICT[self.agent_ids[0]].reward # [T, B, 1] diff --git a/rls/algorithms/single/a2c.py b/rls/algorithms/single/a2c.py index 66fb905..4e0d1e6 100644 --- a/rls/algorithms/single/a2c.py +++ b/rls/algorithms/single/a2c.py @@ -6,7 +6,7 @@ from torch import distributions as td from rls.algorithms.base.sarl_on_policy import SarlOnPolicy -from rls.common.decorator import iTensor_oNumpy +from rls.common.decorator import iton from rls.common.specs import Data from rls.nn.models import ActorDct, ActorMuLogstd, CriticValue from rls.nn.utils import OPLR @@ -52,15 +52,15 @@ def __init__(self, rep_net_params=self._rep_net_params, network_settings=network_settings['critic']).to(self.device) - self.actor_oplr = OPLR(self.actor, actor_lr) - self.critic_oplr = OPLR(self.critic, critic_lr) + self.actor_oplr = OPLR(self.actor, actor_lr, **self._oplr_params) + self.critic_oplr = OPLR(self.critic, critic_lr, **self._oplr_params) self._trainer_modules.update(actor=self.actor, critic=self.critic, actor_oplr=self.actor_oplr, critic_oplr=self.critic_oplr) - @iTensor_oNumpy + @iton def select_action(self, obs): output = self.actor(obs, cell_state=self.cell_state) # [B, A] self.next_cell_state = self.actor.get_cell_state() @@ -78,7 +78,7 @@ def select_action(self, obs): acts_info.update(cell_state=self.cell_state) return action, acts_info - @iTensor_oNumpy + @iton def _get_value(self, obs, cell_state=None): value = self.critic(obs, cell_state=self.cell_state) return value @@ -93,7 +93,7 @@ def _preprocess_BATCH(self, BATCH): # [T, B, *] init_value=value) return BATCH - @iTensor_oNumpy + @iton def _train(self, BATCH): v = self.critic(BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, 1] td_error = BATCH.discounted_reward - v # [T, B, 1] diff --git a/rls/algorithms/single/ac.py b/rls/algorithms/single/ac.py index b8aa2d3..3660365 100644 --- a/rls/algorithms/single/ac.py +++ b/rls/algorithms/single/ac.py @@ -6,7 +6,7 @@ from torch import distributions as td from rls.algorithms.base.sarl_off_policy import SarlOffPolicy -from rls.common.decorator import iTensor_oNumpy +from rls.common.decorator import iton from rls.common.specs import Data from rls.nn.models import ActorDct, ActorMuLogstd, CriticQvalueOne from rls.nn.utils import OPLR @@ -47,15 +47,15 @@ def __init__(self, action_dim=self.a_dim, network_settings=network_settings['critic']).to(self.device) - self.actor_oplr = OPLR(self.actor, actor_lr) - self.critic_oplr = OPLR(self.critic, critic_lr) + self.actor_oplr = OPLR(self.actor, actor_lr, **self._oplr_params) + self.critic_oplr = OPLR(self.critic, critic_lr, **self._oplr_params) 
self._trainer_modules.update(actor=self.actor, critic=self.critic, actor_oplr=self.actor_oplr, critic_oplr=self.critic_oplr) - @iTensor_oNumpy + @iton def select_action(self, obs): output = self.actor(obs, cell_state=self.cell_state) # [B, *] self.next_cell_state = self.actor.get_cell_state() @@ -82,7 +82,7 @@ def random_action(self): self.n_copys, 1./self.a_dim)) # [B,] return actions - @iTensor_oNumpy + @iton def _train(self, BATCH): q = self.critic(BATCH.obs, BATCH.action, begin_mask=BATCH.begin_mask) # [T, B, 1] diff --git a/rls/algorithms/single/averaged_dqn.py b/rls/algorithms/single/averaged_dqn.py index b9439d1..6a89ab1 100644 --- a/rls/algorithms/single/averaged_dqn.py +++ b/rls/algorithms/single/averaged_dqn.py @@ -8,7 +8,7 @@ import torch as t from rls.algorithms.base.sarl_off_policy import SarlOffPolicy -from rls.common.decorator import iTensor_oNumpy +from rls.common.decorator import iton from rls.common.specs import Data from rls.nn.models import CriticQvalueAll from rls.nn.utils import OPLR @@ -55,11 +55,11 @@ def __init__(self, sync_params(target_q_net, self.q_net) self.target_nets.append(target_q_net) - self.oplr = OPLR(self.q_net, lr) + self.oplr = OPLR(self.q_net, lr, **self._oplr_params) self._trainer_modules.update(model=self.q_net, oplr=self.oplr) - @iTensor_oNumpy + @iton def select_action(self, obs): q_values = self.q_net(obs, cell_state=self.cell_state) # [B, *] self.next_cell_state = self.q_net.get_cell_state() @@ -74,7 +74,7 @@ def select_action(self, obs): actions = q_values.argmax(-1) # 不取平均也可以 [B, ] return actions, Data(action=actions) - @iTensor_oNumpy + @iton def _train(self, BATCH): q = self.q_net(BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, *] q_next = 0 diff --git a/rls/algorithms/single/bootstrappeddqn.py b/rls/algorithms/single/bootstrappeddqn.py index 666e9f1..3b2b021 100644 --- a/rls/algorithms/single/bootstrappeddqn.py +++ b/rls/algorithms/single/bootstrappeddqn.py @@ -6,7 +6,7 @@ from torch import distributions as td from rls.algorithms.base.sarl_off_policy import SarlOffPolicy -from rls.common.decorator import iTensor_oNumpy +from rls.common.decorator import iton from rls.common.specs import Data from rls.nn.models import CriticQvalueBootstrap from rls.nn.modules.wrappers import TargetTwin @@ -49,7 +49,7 @@ def __init__(self, head_num=self.head_num, network_settings=network_settings)).to(self.device) - self.oplr = OPLR(self.q_net, lr) + self.oplr = OPLR(self.q_net, lr, **self._oplr_params) self._trainer_modules.update(model=self.q_net, oplr=self.oplr) @@ -57,7 +57,7 @@ def episode_reset(self): super().episode_reset() self.now_head = np.random.randint(self.head_num) - @iTensor_oNumpy + @iton def select_action(self, obs): q_values = self.q_net(obs, cell_state=self.cell_state) # [H, B, A] self.next_cell_state = self.q_net.get_cell_state() @@ -69,7 +69,7 @@ def select_action(self, obs): actions = q_values[self.now_head].argmax(-1) return actions, Data(action=actions) - @iTensor_oNumpy + @iton def _train(self, BATCH): q = self.q_net(BATCH.obs, begin_mask=BATCH.begin_mask).mean( 0) # [H, T, B, A] => [T, B, A] diff --git a/rls/algorithms/single/c51.py b/rls/algorithms/single/c51.py index 3c329b9..d07a41b 100644 --- a/rls/algorithms/single/c51.py +++ b/rls/algorithms/single/c51.py @@ -5,7 +5,7 @@ import torch as t from rls.algorithms.base.sarl_off_policy import SarlOffPolicy -from rls.common.decorator import iTensor_oNumpy +from rls.common.decorator import iton from rls.common.specs import Data from rls.nn.models import C51Distributional from 
rls.nn.modules.wrappers import TargetTwin @@ -52,11 +52,11 @@ def __init__(self, action_dim=self.a_dim, atoms=self._atoms, network_settings=network_settings)).to(self.device) - self.oplr = OPLR(self.q_net, lr) + self.oplr = OPLR(self.q_net, lr, **self._oplr_params) self._trainer_modules.update(model=self.q_net, oplr=self.oplr) - @iTensor_oNumpy + @iton def select_action(self, obs): feat = self.q_net(obs, cell_state=self.cell_state) # [B, A, N] self.next_cell_state = self.q_net.get_cell_state() @@ -68,7 +68,7 @@ def select_action(self, obs): actions = q.argmax(-1) # [B,] return actions, Data(action=actions) - @iTensor_oNumpy + @iton def _train(self, BATCH): q_dist = self.q_net( BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, A, N] diff --git a/rls/algorithms/single/dddqn.py b/rls/algorithms/single/dddqn.py index 7986c97..3f6c9ad 100644 --- a/rls/algorithms/single/dddqn.py +++ b/rls/algorithms/single/dddqn.py @@ -5,7 +5,7 @@ import torch as t from rls.algorithms.base.sarl_off_policy import SarlOffPolicy -from rls.common.decorator import iTensor_oNumpy +from rls.common.decorator import iton from rls.common.specs import Data from rls.nn.models import CriticDueling from rls.nn.modules.wrappers import TargetTwin @@ -47,11 +47,11 @@ def __init__(self, output_shape=self.a_dim, network_settings=network_settings)).to(self.device) - self.oplr = OPLR(self.q_net, lr) + self.oplr = OPLR(self.q_net, lr, **self._oplr_params) self._trainer_modules.update(model=self.q_net, oplr=self.oplr) - @iTensor_oNumpy + @iton def select_action(self, obs): q_values = self.q_net(obs, cell_state=self.cell_state) # [B, A] self.next_cell_state = self.q_net.get_cell_state() @@ -62,7 +62,7 @@ def select_action(self, obs): actions = q_values.argmax(-1) # [B,] return actions, Data(action=actions) - @iTensor_oNumpy + @iton def _train(self, BATCH): q = self.q_net(BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, A] next_q = self.q_net( diff --git a/rls/algorithms/single/ddpg.py b/rls/algorithms/single/ddpg.py index 9017024..976e13e 100644 --- a/rls/algorithms/single/ddpg.py +++ b/rls/algorithms/single/ddpg.py @@ -6,7 +6,7 @@ from torch import distributions as td from rls.algorithms.base.sarl_off_policy import SarlOffPolicy -from rls.common.decorator import iTensor_oNumpy +from rls.common.decorator import iton from rls.common.specs import Data from rls.nn.models import ActorDct, ActorDPG, CriticQvalueOne from rls.nn.modules.wrappers import TargetTwin @@ -70,8 +70,8 @@ def __init__(self, network_settings=network_settings['q']), self.ployak).to(self.device) - self.actor_oplr = OPLR(self.actor, actor_lr) - self.critic_oplr = OPLR(self.critic, critic_lr) + self.actor_oplr = OPLR(self.actor, actor_lr, **self._oplr_params) + self.critic_oplr = OPLR(self.critic, critic_lr, **self._oplr_params) self._trainer_modules.update(actor=self.actor, critic=self.critic, actor_oplr=self.actor_oplr, @@ -82,7 +82,7 @@ def episode_reset(self): if self.is_continuous: self.noised_action.reset() - @iTensor_oNumpy + @iton def select_action(self, obs): output = self.actor(obs, cell_state=self.cell_state) # [B, A] self.next_cell_state = self.actor.get_cell_state() @@ -97,7 +97,7 @@ def select_action(self, obs): actions = pi if self._is_train_mode else mu return actions, Data(action=actions) - @iTensor_oNumpy + @iton def _train(self, BATCH): if self.is_continuous: action_target = self.actor.t( diff --git a/rls/algorithms/single/ddqn.py b/rls/algorithms/single/ddqn.py index 09b2f91..9b09375 100644 --- a/rls/algorithms/single/ddqn.py +++ 
b/rls/algorithms/single/ddqn.py @@ -7,7 +7,7 @@ import torch as t from rls.algorithms.single.dqn import DQN -from rls.common.decorator import iTensor_oNumpy +from rls.common.decorator import iton from rls.utils.torch_utils import n_step_return @@ -21,7 +21,7 @@ class DDQN(DQN): def __init__(self, **kwargs): super().__init__(**kwargs) - @iTensor_oNumpy + @iton def _train(self, BATCH): q = self.q_net(BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, A] q_next = self.q_net( diff --git a/rls/algorithms/single/dpg.py b/rls/algorithms/single/dpg.py index 418c8f8..5b4e231 100644 --- a/rls/algorithms/single/dpg.py +++ b/rls/algorithms/single/dpg.py @@ -6,7 +6,7 @@ from torch import distributions as td from rls.algorithms.base.sarl_off_policy import SarlOffPolicy -from rls.common.decorator import iTensor_oNumpy +from rls.common.decorator import iton from rls.common.specs import Data from rls.nn.models import ActorDct, ActorDPG, CriticQvalueOne from rls.nn.noised_actions import (ClippedNormalNoisedAction, @@ -60,8 +60,8 @@ def __init__(self, action_dim=self.a_dim, network_settings=network_settings['q']).to(self.device) - self.actor_oplr = OPLR(self.actor, actor_lr) - self.critic_oplr = OPLR(self.critic, critic_lr) + self.actor_oplr = OPLR(self.actor, actor_lr, **self._oplr_params) + self.critic_oplr = OPLR(self.critic, critic_lr, **self._oplr_params) self._trainer_modules.update(actor=self.actor, critic=self.critic, actor_oplr=self.actor_oplr, @@ -72,7 +72,7 @@ def episode_reset(self): if self.is_continuous: self.noised_action.reset() - @iTensor_oNumpy + @iton def select_action(self, obs): output = self.actor(obs, cell_state=self.cell_state) # [B, A] self.next_cell_state = self.actor.get_cell_state() @@ -87,7 +87,7 @@ def select_action(self, obs): actions = pi if self._is_train_mode else mu return actions, Data(action=actions) - @iTensor_oNumpy + @iton def _train(self, BATCH): if self.is_continuous: action_target = self.actor( diff --git a/rls/algorithms/single/dqn.py b/rls/algorithms/single/dqn.py index edb7101..03c9877 100644 --- a/rls/algorithms/single/dqn.py +++ b/rls/algorithms/single/dqn.py @@ -7,7 +7,7 @@ import torch as t from rls.algorithms.base.sarl_off_policy import SarlOffPolicy -from rls.common.decorator import iTensor_oNumpy +from rls.common.decorator import iton from rls.common.specs import Data from rls.nn.models import CriticQvalueAll from rls.nn.modules.wrappers import TargetTwin @@ -44,11 +44,11 @@ def __init__(self, rep_net_params=self._rep_net_params, output_shape=self.a_dim, network_settings=network_settings)).to(self.device) - self.oplr = OPLR(self.q_net, lr) + self.oplr = OPLR(self.q_net, lr, **self._oplr_params) self._trainer_modules.update(model=self.q_net) self._trainer_modules.update(oplr=self.oplr) - @iTensor_oNumpy + @iton def select_action(self, obs): q_values = self.q_net(obs, cell_state=self.cell_state) # [B, *] self.next_cell_state = self.q_net.get_cell_state() @@ -59,7 +59,7 @@ def select_action(self, obs): actions = q_values.argmax(-1) # [B,] return actions, Data(action=actions) - @iTensor_oNumpy + @iton def _train(self, BATCH): q = self.q_net(BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, 1] q_next = self.q_net.t( diff --git a/rls/algorithms/single/hierarchical/aoc.py b/rls/algorithms/single/hierarchical/aoc.py index 259fe49..1f6cc5a 100644 --- a/rls/algorithms/single/hierarchical/aoc.py +++ b/rls/algorithms/single/hierarchical/aoc.py @@ -6,7 +6,7 @@ from torch import distributions as td from rls.algorithms.base.sarl_on_policy import SarlOnPolicy -from 
rls.common.decorator import iTensor_oNumpy +from rls.common.decorator import iton from rls.common.specs import Data from rls.nn.models import AocShare from rls.nn.utils import OPLR @@ -75,9 +75,9 @@ def __init__(self, if self.is_continuous: self.log_std = t.as_tensor(np.full( (self.options_num, self.a_dim), -0.5)).requires_grad_().to(self.device) # [P, A] - self.oplr = OPLR([self.net, self.log_std], lr) + self.oplr = OPLR([self.net, self.log_std], lr, **self._oplr_params) else: - self.oplr = OPLR(self.net, lr) + self.oplr = OPLR(self.net, lr, **self._oplr_params) self._trainer_modules.update(model=self.net, oplr=self.oplr) @@ -98,7 +98,7 @@ def episode_step(self, self.options = self.new_options self.oc_mask = t.zeros_like(self.oc_mask) - @iTensor_oNumpy + @iton def select_action(self, obs): # [B, P], [B, P, A], [B, P] (q, pi, beta) = self.net(obs, cell_state=self.cell_state) @@ -140,7 +140,7 @@ def select_action(self, obs): acts_info.update(cell_state=self.cell_state) return action, acts_info - @iTensor_oNumpy + @iton def _get_value(self, obs, options, cell_state=None): (q, _, _) = self.net(obs, cell_state=cell_state) # [B, P] value = (q * options).sum(-1, keepdim=True) # [B, 1] @@ -185,7 +185,7 @@ def learn(self, BATCH: Data): if sum(kls)/len(kls) > self.kl_stop: break - @iTensor_oNumpy + @iton def _train(self, BATCH): # [T, B, P], [T, B, P, A], [T, B, P] (q, pi, beta) = self.net(BATCH.obs, begin_mask=BATCH.begin_mask) diff --git a/rls/algorithms/single/hierarchical/ioc.py b/rls/algorithms/single/hierarchical/ioc.py index 3c3cb97..9a56779 100644 --- a/rls/algorithms/single/hierarchical/ioc.py +++ b/rls/algorithms/single/hierarchical/ioc.py @@ -6,7 +6,7 @@ from torch import distributions as td from rls.algorithms.base.sarl_off_policy import SarlOffPolicy -from rls.common.decorator import iTensor_oNumpy +from rls.common.decorator import iton from rls.common.specs import Data from rls.nn.models import CriticQvalueAll, OcIntraOption from rls.nn.modules.wrappers import TargetTwin @@ -77,15 +77,15 @@ def __init__(self, self.log_std = t.as_tensor(np.full( (self.options_num, self.a_dim), -0.5)).requires_grad_().to(self.device) # [P, A] self.intra_option_oplr = OPLR( - [self.intra_option_net, self.log_std], intra_option_lr, clipvalue=5.) + [self.intra_option_net, self.log_std], intra_option_lr, **self._oplr_params) else: self.intra_option_oplr = OPLR( - self.intra_option_net, intra_option_lr, clipvalue=5.) + self.intra_option_net, intra_option_lr, **self._oplr_params) - self.q_oplr = OPLR(self.q_net, q_lr, clipvalue=5.) + self.q_oplr = OPLR(self.q_net, q_lr, **self._oplr_params) self.termination_oplr = OPLR( - self.termination_net, termination_lr, clipvalue=5.) - self.interest_oplr = OPLR(self.interest_net, interest_lr, clipvalue=5.) 
+ self.termination_net, termination_lr, **self._oplr_params) + self.interest_oplr = OPLR(self.interest_net, interest_lr, **self._oplr_params) self._trainer_modules.update(q_net=self.q_net, intra_option_net=self.intra_option_net, @@ -106,7 +106,7 @@ def episode_step(self, super().episode_step(obs, env_rets, begin_mask) self.options = self.new_options - @iTensor_oNumpy + @iton def select_action(self, obs): q = self.q_net(obs, cell_state=self.cell_state) # [B, P] self.next_cell_state = self.q_net.get_cell_state() @@ -147,7 +147,7 @@ def _preprocess_BATCH(self, BATCH): # [T, B, *] BATCH.options = int2one_hot(BATCH.options, self.options_num) return BATCH - @iTensor_oNumpy + @iton def _train(self, BATCH): q = self.q_net(BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, P] q_next = self.q_net.t( diff --git a/rls/algorithms/single/hierarchical/oc.py b/rls/algorithms/single/hierarchical/oc.py index 88be7ac..393f214 100644 --- a/rls/algorithms/single/hierarchical/oc.py +++ b/rls/algorithms/single/hierarchical/oc.py @@ -6,7 +6,7 @@ from torch import distributions as td from rls.algorithms.base.sarl_off_policy import SarlOffPolicy -from rls.common.decorator import iTensor_oNumpy +from rls.common.decorator import iton from rls.common.specs import Data from rls.nn.models import CriticQvalueAll, OcIntraOption from rls.nn.modules.wrappers import TargetTwin @@ -83,13 +83,13 @@ def __init__(self, self.log_std = t.as_tensor(np.full( (self.options_num, self.a_dim), -0.5)).requires_grad_().to(self.device) # [P, A] self.intra_option_oplr = OPLR( - [self.intra_option_net, self.log_std], intra_option_lr, clipvalue=5.) + [self.intra_option_net, self.log_std], intra_option_lr, **self._oplr_params) else: self.intra_option_oplr = OPLR( - self.intra_option_net, intra_option_lr, clipvalue=5.) - self.q_oplr = OPLR(self.q_net, q_lr, clipvalue=5.) + self.intra_option_net, intra_option_lr, **self._oplr_params) + self.q_oplr = OPLR(self.q_net, q_lr, **self._oplr_params) self.termination_oplr = OPLR( - self.termination_net, termination_lr, clipvalue=5.) 
+ self.termination_net, termination_lr, **self._oplr_params) self._trainer_modules.update(q_net=self.q_net, intra_option_net=self.intra_option_net, @@ -110,7 +110,7 @@ def episode_step(self, super().episode_step(obs, env_rets, begin_mask) self.options = self.new_options - @iTensor_oNumpy + @iton def select_action(self, obs): q = self.q_net(obs, cell_state=self.cell_state) # [B, P] self.next_cell_state = self.q_net.get_cell_state() @@ -158,7 +158,7 @@ def _preprocess_BATCH(self, BATCH): # [T, B, *] BATCH.options = int2one_hot(BATCH.options, self.options_num) return BATCH - @iTensor_oNumpy + @iton def _train(self, BATCH): q = self.q_net(BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, P] q_next = self.q_net.t( diff --git a/rls/algorithms/single/hierarchical/ppoc.py b/rls/algorithms/single/hierarchical/ppoc.py index f78d9c0..c29d288 100644 --- a/rls/algorithms/single/hierarchical/ppoc.py +++ b/rls/algorithms/single/hierarchical/ppoc.py @@ -6,7 +6,7 @@ from torch import distributions as td from rls.algorithms.base.sarl_on_policy import SarlOnPolicy -from rls.common.decorator import iTensor_oNumpy +from rls.common.decorator import iton from rls.common.specs import Data from rls.nn.models import PpocShare from rls.nn.utils import OPLR @@ -76,9 +76,9 @@ def __init__(self, if self.is_continuous: self.log_std = t.as_tensor(np.full( (self.options_num, self.a_dim), -0.5)).requires_grad_().to(self.device) # [P, A] - self.oplr = OPLR([self.net, self.log_std], lr) + self.oplr = OPLR([self.net, self.log_std], lr, **self._oplr_params) else: - self.oplr = OPLR(self.net, lr) + self.oplr = OPLR(self.net, lr, **self._oplr_params) self._trainer_modules.update(model=self.net, oplr=self.oplr) @@ -99,7 +99,7 @@ def episode_step(self, self.options = self.new_options self.oc_mask = t.zeros_like(self.oc_mask) - @iTensor_oNumpy + @iton def select_action(self, obs): # [B, P], [B, P, A], [B, P], [B, P] (q, pi, beta, o) = self.net(obs, cell_state=self.cell_state) @@ -145,7 +145,7 @@ def select_action(self, obs): acts_info.update(cell_state=self.cell_state) return action, acts_info - @iTensor_oNumpy + @iton def _get_value(self, obs, options, cell_state=None): (q, _, _, _) = self.net(obs, cell_state=cell_state) # [T, B, P] value = (q * options).sum(-1, keepdim=True) # [T, B, 1] @@ -190,7 +190,7 @@ def learn(self, BATCH: Data): if sum(kls)/len(kls) > self.kl_stop: break - @iTensor_oNumpy + @iton def _train(self, BATCH): # [T, B, P], [T, B, P, A], [T, B, P], [T, B, P] (q, pi, beta, o) = self.net(BATCH.obs, begin_mask=BATCH.begin_mask) diff --git a/rls/algorithms/single/iqn.py b/rls/algorithms/single/iqn.py index 708e543..aff31b2 100644 --- a/rls/algorithms/single/iqn.py +++ b/rls/algorithms/single/iqn.py @@ -5,7 +5,7 @@ import torch as t from rls.algorithms.base.sarl_off_policy import SarlOffPolicy -from rls.common.decorator import iTensor_oNumpy +from rls.common.decorator import iton from rls.common.specs import Data from rls.nn.models import IqnNet from rls.nn.modules.wrappers import TargetTwin @@ -57,11 +57,11 @@ def __init__(self, action_dim=self.a_dim, quantiles_idx=self.quantiles_idx, network_settings=network_settings)).to(self.device) - self.oplr = OPLR(self.q_net, lr) + self.oplr = OPLR(self.q_net, lr, **self._oplr_params) self._trainer_modules.update(model=self.q_net, oplr=self.oplr) - @iTensor_oNumpy + @iton def select_action(self, obs): _, select_quantiles_tiled = self._generate_quantiles( # [N*B, X] batch_size=self.n_copys, @@ -92,7 +92,7 @@ def _generate_quantiles(self, batch_size, quantiles_num): batch_size, 
quantiles_num, 1) # [N*B, 1] => [B, N, 1] return _quantiles, _quantiles_tiled # [B, N, 1], [N*B, X] - @iTensor_oNumpy + @iton def _train(self, BATCH): time_step = BATCH.reward.shape[0] batch_size = BATCH.reward.shape[1] diff --git a/rls/algorithms/single/maxsqn.py b/rls/algorithms/single/maxsqn.py index c7aaac8..708777d 100644 --- a/rls/algorithms/single/maxsqn.py +++ b/rls/algorithms/single/maxsqn.py @@ -8,7 +8,7 @@ from torch import distributions as td from rls.algorithms.base.sarl_off_policy import SarlOffPolicy -from rls.common.decorator import iTensor_oNumpy +from rls.common.decorator import iton from rls.common.specs import Data from rls.nn.models import CriticQvalueAll from rls.nn.modules.wrappers import TargetTwin @@ -56,11 +56,11 @@ def __init__(self, self.ployak).to(self.device) self.critic2 = deepcopy(self.critic) - self.critic_oplr = OPLR([self.critic, self.critic2], q_lr) + self.critic_oplr = OPLR([self.critic, self.critic2], q_lr, **self._oplr_params) if self.auto_adaption: self.log_alpha = t.tensor(0., requires_grad=True).to(self.device) - self.alpha_oplr = OPLR(self.log_alpha, alpha_lr) + self.alpha_oplr = OPLR(self.log_alpha, alpha_lr, **self._oplr_params) self._trainer_modules.update(alpha_oplr=self.alpha_oplr) else: self.log_alpha = t.tensor(alpha).log().to(self.device) @@ -74,7 +74,7 @@ def __init__(self, def alpha(self): return self.log_alpha.exp() - @iTensor_oNumpy + @iton def select_action(self, obs): q = self.critic(obs, cell_state=self.cell_state) # [B, A] self.next_cell_state = self.critic.get_cell_state() @@ -87,7 +87,7 @@ def select_action(self, obs): actions = pi = cate_dist.sample() # [B,] return actions, Data(action=actions) - @iTensor_oNumpy + @iton def _train(self, BATCH): q1 = self.critic(BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, A] q2 = self.critic2( diff --git a/rls/algorithms/single/modelbased/dreamer_v1.py b/rls/algorithms/single/modelbased/dreamer_v1.py index bd539e7..72642a0 100644 --- a/rls/algorithms/single/modelbased/dreamer_v1.py +++ b/rls/algorithms/single/modelbased/dreamer_v1.py @@ -8,7 +8,7 @@ from torch import distributions as td from rls.algorithms.base.sarl_off_policy import SarlOffPolicy -from rls.common.decorator import iTensor_oNumpy +from rls.common.decorator import iton from rls.common.specs import Data from rls.nn.dreamer import ActionDecoder, DenseModel, RecurrentStateSpaceModel from rls.nn.dreamer.utils import FreezeParameters, compute_return @@ -119,11 +119,10 @@ def __init__(self, _modules.append(self.pcont_decoder) self.model_oplr = OPLR( - _modules, model_lr, optimizer_params=self._optim_params, clipnorm=100) - self.actor_oplr = OPLR(self.actor, actor_lr, - optimizer_params=self._optim_params, clipnorm=100) + _modules, model_lr, **self._oplr_params) + self.actor_oplr = OPLR(self.actor, actor_lr, **self._oplr_params) self.critic_oplr = OPLR( - self.critic, critic_lr, optimizer_params=self._optim_params, clipnorm=100) + self.critic, critic_lr, **self._oplr_params) self._trainer_modules.update(obs_encoder=self.obs_encoder, obs_decoder=self.obs_decoder, reward_predictor=self.reward_predictor, @@ -156,7 +155,7 @@ def _dreamer_build_critic(self): 1, **self._network_settings['critic']).to(self.device) - @iTensor_oNumpy + @iton def select_action(self, obs): if self._is_visual: obs = obs.visual.visual_0 @@ -192,7 +191,7 @@ def _exploration(self, action: t.Tensor) -> t.Tensor: action[..., index] = 1 return action - @iTensor_oNumpy + @iton def _train(self, BATCH): T, B = BATCH.action.shape[:2] if self._is_visual: diff --git 
a/rls/algorithms/single/modelbased/planet.py b/rls/algorithms/single/modelbased/planet.py index 9daee52..180263b 100644 --- a/rls/algorithms/single/modelbased/planet.py +++ b/rls/algorithms/single/modelbased/planet.py @@ -8,7 +8,7 @@ from torch import distributions as td from rls.algorithms.base.sarl_off_policy import SarlOffPolicy -from rls.common.decorator import iTensor_oNumpy +from rls.common.decorator import iton from rls.common.specs import Data from rls.nn.dreamer import DenseModel, RecurrentStateSpaceModel from rls.nn.utils import OPLR @@ -46,8 +46,7 @@ def __init__(self, assert self.use_rnn == False, 'assert self.use_rnn == False' - if self.obs_spec.has_visual_observation and len( - self.obs_spec.visual_dims) == 1 and not self.obs_spec.has_vector_observation: + if self.obs_spec.has_visual_observation and len(self.obs_spec.visual_dims) == 1 and not self.obs_spec.has_vector_observation: visual_dim = self.obs_spec.visual_dims[0] # TODO: optimize this assert visual_dim[0] == visual_dim[1] == 64, 'visual dimension must be [64, 64, *]' @@ -91,9 +90,9 @@ def __init__(self, 1, **network_settings['reward']).to(self.device) - self.model_oplr = OPLR( - [self.obs_encoder, self.rssm, - self.obs_decoder, self.reward_predictor], model_lr, optimizer_params=self._optim_params, clipnorm=100) + self.model_oplr = OPLR([self.obs_encoder, self.rssm, + self.obs_decoder, self.reward_predictor], + model_lr, **self._oplr_params) self._trainer_modules.update(obs_encoder=self.obs_encoder, obs_decoder=self.obs_decoder, reward_predictor=self.reward_predictor, @@ -111,7 +110,7 @@ def _dreamer_build_rssm(self): self.obs_encoder.h_dim, **self._network_settings['rssm']).to(self.device) - @iTensor_oNumpy + @iton def select_action(self, obs): if self._is_visual: obs = obs.visual.visual_0 @@ -120,60 +119,42 @@ def select_action(self, obs): # Compute starting state for planning # while taking information from current observation (posterior) embedded_obs = self.obs_encoder(obs) # [B, *] - state_posterior = self.rssm.posterior( - self.cell_state['hx'], embedded_obs) # dist # [B, *] + state_posterior = self.rssm.posterior(self.cell_state['hx'], embedded_obs) # dist # [B, *] # Initialize action distribution - mean = t.zeros((self.cem_horizon, self.n_copys, - self.a_dim)) # [H, B, A] - stddev = t.ones((self.cem_horizon, self.n_copys, - self.a_dim)) # [H, B, A] - action_dist = td.Normal(mean, stddev) + mean = t.zeros((self.cem_horizon, 1, self.n_copys, self.a_dim)) # [H, 1, B, A] + stddev = t.ones((self.cem_horizon, 1, self.n_copys, self.a_dim)) # [H, 1, B, A] # Iteratively improve action distribution with CEM for itr in range(self.cem_iter_nums): - # [N, H, B, A] - action_candidates = action_dist.sample((self.cem_candidates,)) - action_candidates = action_candidates.swapaxes(0, 1) - action_candidates = action_candidates.reshape( - self.cem_horizon, -1, self.a_dim) # [H, N*B, A] + action_candidates = mean + stddev * t.randn(self.cem_horizon, self.n_copys, self.cem_candidates, self.a_dim) # [H, N, B, A] + action_candidates = action_candidates.reshape(self.cem_horizon, -1, self.a_dim) # [H, N*B, A] # Initialize reward, state, and rnn hidden state # These are for parallel exploration - total_predicted_reward = t.zeros( - (self.cem_candidates*self.n_copys, 1)) # [N*B, 1] + total_predicted_reward = t.zeros((self.cem_candidates*self.n_copys, 1)) # [N*B, 1] - state = state_posterior.sample( - (self.cem_candidates,)) # [N, B, *] + state = state_posterior.sample((self.cem_candidates,)) # [N, B, *] state = state.view(-1, 
state.shape[-1]) # [N*B, *] - rnn_hidden = self.cell_state['hx'].repeat( - (self.cem_candidates, 1)) # [B, *] => [N*B, *] + rnn_hidden = self.cell_state['hx'].repeat((self.cem_candidates, 1)) # [B, *] => [N*B, *] # Compute total predicted reward by open-loop prediction using prior for _t in range(self.cem_horizon): - next_state_prior, rnn_hidden = \ - self.rssm.prior(state, action_candidates[_t], rnn_hidden) + next_state_prior, rnn_hidden = self.rssm.prior(state, t.tanh(action_candidates[_t]), rnn_hidden) state = next_state_prior.sample() # [N*B, *] post_feat = t.cat([state, rnn_hidden], -1) # [N*B, *] - # [N*B, 1] - total_predicted_reward += self.reward_predictor(post_feat).mean + total_predicted_reward += self.reward_predictor(post_feat).mean # [N*B, 1] # update action distribution using top-k samples - total_predicted_reward = total_predicted_reward.view( - self.cem_candidates, self.n_copys, 1) # [N, B, 1] - top_indexes = total_predicted_reward.argsort(dim=0, descending=True)[ - :self.cem_tops] # [N', B, 1] - action_candidates = action_candidates.view( - self.cem_horizon, self.cem_candidates, self.n_copys, -1) # [H, N, B, A] - top_action_candidates = action_candidates[:, top_indexes, t.arange( - self.n_copys).reshape(self.n_copys, 1), t.arange(self.a_dim)] # [H, N', B, A] - mean = top_action_candidates.mean(dim=1) # [H, B, A] - stddev = (top_action_candidates - mean.unsqueeze(1) - ).abs().sum(dim=1) / (self.cem_tops - 1) # [H, B, A] - action_dist = td.Normal(mean, stddev) + total_predicted_reward = total_predicted_reward.view(self.cem_candidates, self.n_copys, 1) # [N, B, 1] + _, top_indexes = total_predicted_reward.topk(self.cem_tops, dim=0, largest=True, sorted=False) # [N', B, 1] + action_candidates = action_candidates.view(self.cem_horizon, self.cem_candidates, self.n_copys, -1) # [H, N, B, A] + top_action_candidates = action_candidates[:, top_indexes, t.arange(self.n_copys).reshape(self.n_copys, 1), t.arange(self.a_dim)] # [H, N', B, A] + mean = top_action_candidates.mean(dim=1, keepdim=True) # [H, 1, B, A] + stddev = top_action_candidates.std(dim=1, unbiased=False, keepdim=True) # [H, 1, B, A] # Return only first action (replan each state based on new observation) - actions = t.tanh(mean[0]) # [B, A] + actions = t.tanh(mean[0].squeeze(0)) # [B, A] actions = self._exploration(actions) _, self.next_cell_state['hx'] = self.rssm.prior(state_posterior.sample(), actions, @@ -189,7 +170,7 @@ def _exploration(self, action: t.Tensor) -> t.Tensor: noise = t.randn(*action.shape) * sigma return t.clamp(action + noise, -1, 1) - @iTensor_oNumpy + @iton def _train(self, BATCH): T, B = BATCH.action.shape[:2] if self._is_visual: diff --git a/rls/algorithms/single/npg.py b/rls/algorithms/single/npg.py index 1200ae6..1a71bfa 100644 --- a/rls/algorithms/single/npg.py +++ b/rls/algorithms/single/npg.py @@ -6,7 +6,7 @@ from torch import distributions as td from rls.algorithms.base.sarl_on_policy import SarlOnPolicy -from rls.common.decorator import iTensor_oNumpy +from rls.common.decorator import iton from rls.common.specs import Data from rls.nn.models import ActorDct, ActorMuLogstd, CriticValue from rls.nn.utils import OPLR @@ -65,12 +65,12 @@ def __init__(self, rep_net_params=self._rep_net_params, network_settings=network_settings['critic']).to(self.device) - self.critic_oplr = OPLR(self.critic, critic_lr) + self.critic_oplr = OPLR(self.critic, critic_lr, **self._oplr_params) self._trainer_modules.update(actor=self.actor, critic=self.critic, critic_oplr=self.critic_oplr) - @iTensor_oNumpy + @iton 
def select_action(self, obs): output = self.actor(obs, cell_state=self.cell_state) # [B, A] self.next_cell_state = self.actor.get_cell_state() @@ -97,7 +97,7 @@ def select_action(self, obs): acts_info.update(logp_all=logp_all) return action, acts_info - @iTensor_oNumpy + @iton def _get_value(self, obs, cell_state=None): value = self.critic(obs, cell_state=cell_state) # [B, 1] return value @@ -123,7 +123,7 @@ def _preprocess_BATCH(self, BATCH): # [T, B, *] normalize=True) return BATCH - @iTensor_oNumpy + @iton def _train(self, BATCH): output = self.actor( BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, A] diff --git a/rls/algorithms/single/pg.py b/rls/algorithms/single/pg.py index b61fa05..a03ce40 100644 --- a/rls/algorithms/single/pg.py +++ b/rls/algorithms/single/pg.py @@ -6,7 +6,7 @@ from torch import distributions as td from rls.algorithms.base.sarl_on_policy import SarlOnPolicy -from rls.common.decorator import iTensor_oNumpy +from rls.common.decorator import iton from rls.common.specs import Data from rls.nn.models import ActorDct, ActorMuLogstd from rls.nn.utils import OPLR @@ -40,12 +40,12 @@ def __init__(self, rep_net_params=self._rep_net_params, output_shape=self.a_dim, network_settings=network_settings['actor_discrete']).to(self.device) - self.oplr = OPLR(self.net, lr) + self.oplr = OPLR(self.net, lr, **self._oplr_params) self._trainer_modules.update(model=self.net, oplr=self.oplr) - @iTensor_oNumpy + @iton def select_action(self, obs): output = self.net(obs, cell_state=self.cell_state) # [B, A] self.next_cell_state = self.net.get_cell_state() @@ -73,7 +73,7 @@ def _preprocess_BATCH(self, BATCH): # [T, B, *] normalize=True) return BATCH - @iTensor_oNumpy + @iton def _train(self, BATCH): # [B, T, *] output = self.net( BATCH.obs, begin_mask=BATCH.begin_mask) # [B, T, A] diff --git a/rls/algorithms/single/ppo.py b/rls/algorithms/single/ppo.py index 90f5b1d..37dea70 100644 --- a/rls/algorithms/single/ppo.py +++ b/rls/algorithms/single/ppo.py @@ -8,7 +8,7 @@ from torch import distributions as td from rls.algorithms.base.sarl_on_policy import SarlOnPolicy -from rls.common.decorator import iTensor_oNumpy +from rls.common.decorator import iton from rls.common.specs import Data from rls.nn.models import (ActorCriticValueCts, ActorCriticValueDct, ActorDct, ActorMuLogstd, CriticValue) @@ -38,7 +38,6 @@ def __init__(self, share_net: bool = True, actor_lr: float = 3e-4, critic_lr: float = 1e-3, - max_grad_norm: float = 0.5, kl_reverse: bool = False, kl_target: float = 0.02, kl_target_cutoff: float = 2, @@ -88,7 +87,6 @@ def __init__(self, self.kl_coef = kl_coef self.extra_coef = extra_coef self.vf_coef = vf_coef - self.max_grad_norm = max_grad_norm self.use_duel_clip = use_duel_clip self.duel_epsilon = duel_epsilon @@ -116,11 +114,7 @@ def __init__(self, rep_net_params=self._rep_net_params, output_shape=self.a_dim, network_settings=network_settings['share']['discrete']).to(self.device) - if self.max_grad_norm is not None: - self.oplr = OPLR(self.net, lr, clipnorm=self.max_grad_norm) - else: - self.oplr = OPLR(self.net, lr) - + self.oplr = OPLR(self.net, lr, **self._oplr_params) self._trainer_modules.update(model=self.net, oplr=self.oplr) else: @@ -137,21 +131,14 @@ def __init__(self, self.critic = CriticValue(self.obs_spec, rep_net_params=self._rep_net_params, network_settings=network_settings['critic']).to(self.device) - if self.max_grad_norm is not None: - self.actor_oplr = OPLR( - self.actor, actor_lr, clipnorm=self.max_grad_norm) - self.critic_oplr = OPLR( - self.critic, critic_lr, 
clipnorm=self.max_grad_norm) - else: - self.actor_oplr = OPLR(self.actor, actor_lr) - self.critic_oplr = OPLR(self.critic, critic_lr) - + self.actor_oplr = OPLR(self.actor, actor_lr, **self._oplr_params) + self.critic_oplr = OPLR(self.critic, critic_lr, **self._oplr_params) self._trainer_modules.update(actor=self.actor, critic=self.critic, actor_oplr=self.actor_oplr, critic_oplr=self.critic_oplr) - @iTensor_oNumpy + @iton def select_action(self, obs): if self.is_continuous: if self.share_net: @@ -188,7 +175,7 @@ def select_action(self, obs): acts_info.update(cell_state=self.cell_state) return action, acts_info - @iTensor_oNumpy + @iton def _get_value(self, obs, cell_state=None): if self.share_net: if self.is_continuous: @@ -257,7 +244,7 @@ def _train(self, BATCH): return summaries, kl - @iTensor_oNumpy + @iton def train_share(self, BATCH): if self.is_continuous: # [T, B, A], [T, B, A], [T, B, 1] @@ -332,7 +319,7 @@ def train_share(self, BATCH): ['LEARNING_RATE/lr', self.oplr.lr] ]), kl - @iTensor_oNumpy + @iton def train_actor(self, BATCH): if self.is_continuous: # [T, B, A], [T, B, A] @@ -382,7 +369,7 @@ def train_actor(self, BATCH): ['LEARNING_RATE/actor_lr', self.actor_oplr.lr] ]), kl - @iTensor_oNumpy + @iton def train_critic(self, BATCH): value = self.critic( BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, 1] diff --git a/rls/algorithms/single/qrdqn.py b/rls/algorithms/single/qrdqn.py index 54a2de3..c18bbc7 100644 --- a/rls/algorithms/single/qrdqn.py +++ b/rls/algorithms/single/qrdqn.py @@ -5,7 +5,7 @@ import torch as t from rls.algorithms.base.sarl_off_policy import SarlOffPolicy -from rls.common.decorator import iTensor_oNumpy +from rls.common.decorator import iton from rls.common.specs import Data from rls.nn.models import QrdqnDistributional from rls.nn.modules.wrappers import TargetTwin @@ -51,11 +51,11 @@ def __init__(self, action_dim=self.a_dim, nums=self.nums, network_settings=network_settings)).to(self.device) - self.oplr = OPLR(self.q_net, lr) + self.oplr = OPLR(self.q_net, lr, **self._oplr_params) self._trainer_modules.update(model=self.q_net, oplr=self.oplr) - @iTensor_oNumpy + @iton def select_action(self, obs): q_values = self.q_net(obs, cell_state=self.cell_state) # [B, A, N] self.next_cell_state = self.q_net.get_cell_state() @@ -67,7 +67,7 @@ def select_action(self, obs): actions = q.argmax(-1) # [B,] return actions, Data(action=actions) - @iTensor_oNumpy + @iton def _train(self, BATCH): q_dist = self.q_net( BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, A, N] diff --git a/rls/algorithms/single/rainbow.py b/rls/algorithms/single/rainbow.py index 50355cd..d1d5a14 100644 --- a/rls/algorithms/single/rainbow.py +++ b/rls/algorithms/single/rainbow.py @@ -5,7 +5,7 @@ import torch as t from rls.algorithms.base.sarl_off_policy import SarlOffPolicy -from rls.common.decorator import iTensor_oNumpy +from rls.common.decorator import iton from rls.common.specs import Data from rls.nn.models import RainbowDueling from rls.nn.modules.wrappers import TargetTwin @@ -61,11 +61,11 @@ def __init__(self, action_dim=self.a_dim, atoms=self._atoms, network_settings=network_settings)).to(self.device) - self.oplr = OPLR(self.rainbow_net, lr) + self.oplr = OPLR(self.rainbow_net, lr, **self._oplr_params) self._trainer_modules.update(model=self.rainbow_net, oplr=self.oplr) - @iTensor_oNumpy + @iton def select_action(self, obs): q_values = self.rainbow_net( obs, cell_state=self.cell_state) # [B, A, N] @@ -78,7 +78,7 @@ def select_action(self, obs): actions = q.argmax(-1) # [B,] return actions, 
Data(action=actions) - @iTensor_oNumpy + @iton def _train(self, BATCH): q_dist = self.rainbow_net( BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, A, N] diff --git a/rls/algorithms/single/sac.py b/rls/algorithms/single/sac.py index 6caea8f..6fd8b66 100644 --- a/rls/algorithms/single/sac.py +++ b/rls/algorithms/single/sac.py @@ -8,7 +8,7 @@ from torch import distributions as td from rls.algorithms.base.sarl_off_policy import SarlOffPolicy -from rls.common.decorator import iTensor_oNumpy +from rls.common.decorator import iton from rls.common.specs import Data from rls.nn.models import ActorCts, ActorDct, CriticQvalueAll, CriticQvalueOne from rls.nn.modules.wrappers import TargetTwin @@ -82,8 +82,8 @@ def __init__(self, output_shape=self.a_dim, network_settings=network_settings['actor_discrete']).to(self.device) - self.actor_oplr = OPLR(self.actor, actor_lr) - self.critic_oplr = OPLR([self.critic, self.critic2], critic_lr) + self.actor_oplr = OPLR(self.actor, actor_lr, **self._oplr_params) + self.critic_oplr = OPLR([self.critic, self.critic2], critic_lr, **self._oplr_params) if self.auto_adaption: self.log_alpha = t.tensor(0., requires_grad=True).to(self.device) @@ -105,7 +105,7 @@ def __init__(self, def alpha(self): return self.log_alpha.exp() - @iTensor_oNumpy + @iton def select_action(self, obs): if self.is_continuous: mu, log_std = self.actor( @@ -128,7 +128,7 @@ def _train(self, BATCH): td_error, summaries = self._train_discrete(BATCH) return td_error, summaries - @iTensor_oNumpy + @iton def _train_continuous(self, BATCH): q1 = self.critic(BATCH.obs, BATCH.action, begin_mask=BATCH.begin_mask) # [T, B, 1] @@ -220,7 +220,7 @@ def _train_continuous(self, BATCH): ]) return (td_error1 + td_error2) / 2, summaries - @iTensor_oNumpy + @iton def _train_discrete(self, BATCH): q1_all = self.critic( BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, A] diff --git a/rls/algorithms/single/sac_v.py b/rls/algorithms/single/sac_v.py index c661929..4f7671b 100644 --- a/rls/algorithms/single/sac_v.py +++ b/rls/algorithms/single/sac_v.py @@ -8,7 +8,7 @@ from torch import distributions as td from rls.algorithms.base.sarl_off_policy import SarlOffPolicy -from rls.common.decorator import iTensor_oNumpy +from rls.common.decorator import iton from rls.common.specs import Data from rls.nn.models import (ActorCts, ActorDct, CriticQvalueAll, CriticQvalueOne, CriticValue) @@ -88,13 +88,13 @@ def __init__(self, network_settings=network_settings['q']).to(self.device) self.q_net2 = deepcopy(self.q_net) - self.actor_oplr = OPLR(self.actor, actor_lr) + self.actor_oplr = OPLR(self.actor, actor_lr, **self._oplr_params) self.critic_oplr = OPLR( - [self.q_net, self.q_net2, self.v_net], critic_lr) + [self.q_net, self.q_net2, self.v_net], critic_lr, **self._oplr_params) if self.auto_adaption: self.log_alpha = t.tensor(0., requires_grad=True).to(self.device) - self.alpha_oplr = OPLR(self.log_alpha, alpha_lr) + self.alpha_oplr = OPLR(self.log_alpha, alpha_lr, **self._oplr_params) self._trainer_modules.update(alpha_oplr=self.alpha_oplr) else: self.log_alpha = t.tensor(alpha).log().to(self.device) @@ -113,7 +113,7 @@ def __init__(self, def alpha(self): return self.log_alpha.exp() - @iTensor_oNumpy + @iton def select_action(self, obs): if self.is_continuous: mu, log_std = self.actor( @@ -136,7 +136,7 @@ def _train(self, BATCH): td_error, summaries = self._train_discrete(BATCH) return td_error, summaries - @iTensor_oNumpy + @iton def _train_continuous(self, BATCH): v = self.v_net(BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, 1] 
v_target = self.v_net.t( @@ -238,7 +238,7 @@ def _train_continuous(self, BATCH): ]) return (td_error1 + td_error2) / 2, summaries - @iTensor_oNumpy + @iton def _train_discrete(self, BATCH): v = self.v_net(BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, 1] v_target = self.v_net.t( diff --git a/rls/algorithms/single/sql.py b/rls/algorithms/single/sql.py index ce5f9f8..99aae7f 100644 --- a/rls/algorithms/single/sql.py +++ b/rls/algorithms/single/sql.py @@ -6,7 +6,7 @@ from torch import distributions as td from rls.algorithms.base.sarl_off_policy import SarlOffPolicy -from rls.common.decorator import iTensor_oNumpy +from rls.common.decorator import iton from rls.common.specs import Data from rls.nn.models import CriticQvalueAll from rls.nn.modules.wrappers import TargetTwin @@ -39,11 +39,11 @@ def __init__(self, network_settings=network_settings), self.ployak).to(self.device) - self.oplr = OPLR(self.q_net, lr) + self.oplr = OPLR(self.q_net, lr, **self._oplr_params) self._trainer_modules.update(model=self.q_net, oplr=self.oplr) - @iTensor_oNumpy + @iton def select_action(self, obs): q_values = self.q_net(obs, cell_state=self.cell_state) # [B, A] self.next_cell_state = self.q_net.get_cell_state() @@ -59,7 +59,7 @@ def _get_v(self, q): keepdim=True).log() # [B, 1] or [T, B, 1] return v - @iTensor_oNumpy + @iton def _train(self, BATCH): q = self.q_net(BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, A] q_next = self.q_net.t( diff --git a/rls/algorithms/single/tac.py b/rls/algorithms/single/tac.py index b74d61e..ab011cf 100644 --- a/rls/algorithms/single/tac.py +++ b/rls/algorithms/single/tac.py @@ -8,7 +8,7 @@ from torch import distributions as td from rls.algorithms.base.sarl_off_policy import SarlOffPolicy -from rls.common.decorator import iTensor_oNumpy +from rls.common.decorator import iton from rls.common.specs import Data from rls.nn.models import ActorCts, ActorDct, CriticQvalueOne from rls.nn.modules.wrappers import TargetTwin @@ -75,12 +75,12 @@ def __init__(self, self.target_entropy = 0.98 * \ (-self.a_dim if self.is_continuous else np.log(self.a_dim)) - self.actor_oplr = OPLR(self.actor, actor_lr) - self.critic_oplr = OPLR([self.critic, self.critic2], critic_lr) + self.actor_oplr = OPLR(self.actor, actor_lr, **self._oplr_params) + self.critic_oplr = OPLR([self.critic, self.critic2], critic_lr, **self._oplr_params) if self.auto_adaption: self.log_alpha = t.tensor(0., requires_grad=True).to(self.device) - self.alpha_oplr = OPLR(self.log_alpha, alpha_lr) + self.alpha_oplr = OPLR(self.log_alpha, alpha_lr, **self._oplr_params) self._trainer_modules.update(alpha_oplr=self.alpha_oplr) else: self.log_alpha = t.tensor(alpha).log().to(self.device) @@ -98,7 +98,7 @@ def __init__(self, def alpha(self): return self.log_alpha.exp() - @iTensor_oNumpy + @iton def select_action(self, obs): if self.is_continuous: mu, log_std = self.actor( @@ -114,7 +114,7 @@ def select_action(self, obs): actions = pi if self._is_train_mode else mu return actions, Data(action=actions) - @iTensor_oNumpy + @iton def _train(self, BATCH): if self.is_continuous: target_mu, target_log_std = self.actor( diff --git a/rls/algorithms/single/td3.py b/rls/algorithms/single/td3.py index 5ca13fe..1da5c31 100644 --- a/rls/algorithms/single/td3.py +++ b/rls/algorithms/single/td3.py @@ -8,7 +8,7 @@ from torch import distributions as td from rls.algorithms.base.sarl_off_policy import SarlOffPolicy -from rls.common.decorator import iTensor_oNumpy +from rls.common.decorator import iton from rls.common.specs import Data from rls.nn.models 
import ActorDct, ActorDPG, CriticQvalueOne from rls.nn.modules.wrappers import TargetTwin @@ -66,8 +66,8 @@ def __init__(self, self.ployak).to(self.device) self.critic2 = deepcopy(self.critic) - self.actor_oplr = OPLR(self.actor, actor_lr) - self.critic_oplr = OPLR([self.critic, self.critic2], critic_lr) + self.actor_oplr = OPLR(self.actor, actor_lr, **self._oplr_params) + self.critic_oplr = OPLR([self.critic, self.critic2], critic_lr, **self._oplr_params) self._trainer_modules.update(actor=self.actor, critic=self.critic, critic2=self.critic2, @@ -79,7 +79,7 @@ def episode_reset(self): if self.is_continuous: self.noised_action.reset() - @iTensor_oNumpy + @iton def select_action(self, obs): output = self.actor(obs, cell_state=self.cell_state) # [B, A] self.next_cell_state = self.actor.get_cell_state() @@ -94,7 +94,7 @@ def select_action(self, obs): actions = pi if self._is_train_mode else mu return actions, Data(action=actions) - @iTensor_oNumpy + @iton def _train(self, BATCH): for _ in range(self.delay_num): if self.is_continuous: diff --git a/rls/algorithms/single/trpo.py b/rls/algorithms/single/trpo.py index 5e8e123..fcf494d 100644 --- a/rls/algorithms/single/trpo.py +++ b/rls/algorithms/single/trpo.py @@ -6,7 +6,7 @@ from torch import distributions as td from rls.algorithms.single.npg import NPG -from rls.common.decorator import iTensor_oNumpy +from rls.common.decorator import iton from rls.common.specs import Data from rls.nn.models import ActorDct, ActorMuLogstd, CriticValue from rls.nn.utils import OPLR @@ -40,7 +40,7 @@ def __init__(self, self._backtrack_iters = backtrack_iters self._backtrack_coeff = backtrack_coeff - @iTensor_oNumpy + @iton def _train(self, BATCH): output = self.actor( BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, A] diff --git a/rls/common/decorator.py b/rls/common/decorator.py index 5a94ccd..6d7cdcf 100644 --- a/rls/common/decorator.py +++ b/rls/common/decorator.py @@ -21,7 +21,7 @@ def wrapper(self): return wrapper -def iTensor_oNumpy(func, dtype=t.float32, device='cpu'): +def iton(func, dtype=t.float32, device='cpu'): def wrapper(*args, **kwargs): if args and hasattr(args[0], 'device'): diff --git a/rls/common/specs.py b/rls/common/specs.py index 0c87ad0..d6668aa 100644 --- a/rls/common/specs.py +++ b/rls/common/specs.py @@ -146,18 +146,9 @@ def get(self, name, value=None): else: return value - # TODO: remove - def unpack(self) -> Iterator: - for i in range(len(self)): - yield self[i] - @staticmethod - def pack(ds: List, func: Callable = lambda x: np.asarray(x)): - ''' - TODO: Annotation - ''' - params = {} - for k, v in ds[0].__dict__.items(): - d = [getattr(rds, k) for rds in ds] - params[k] = Data.pack(d, func) if isinstance(v, Data) else func(d) - return ds[0].__class__(**params) +class DictCls(dict): + + def __getattr__(self, name): + assert name not in self.keys(), 'assert name not in self.keys()' + return [v.get(name) if isinstance(v, dict) else getattr(v, name) for k, v in self.items()] diff --git a/rls/configs/algorithms.yaml b/rls/configs/algorithms.yaml index ab755d4..ae9da67 100644 --- a/rls/configs/algorithms.yaml +++ b/rls/configs/algorithms.yaml @@ -10,7 +10,12 @@ policy: &policy obs_with_pre_action: false - optim_params: &optim_params {} + oplr_params: + optim_params: + eps: 1.e-4 + grad_params: + grad_max_norm: 100. + grad_clip_value: 100. # ----- could be overrided in specific algorithms, i.e. dqn, so as to using different type of visual net, memory net. 
rep_net_params: &rep_net_params @@ -310,7 +315,6 @@ ppo: lambda_: 0.97 actor_lr: 3.0e-4 critic_lr: 1.0e-3 - max_grad_norm: ~ # duel clip use_duel_clip: false @@ -605,8 +609,6 @@ qplex: planet: <<: *sarl_off_policy - optim_params: - eps: 1.e-4 train_times: 1 train_interval: 1 @@ -638,17 +640,15 @@ planet: hidden_units: 64 dist: "mse" rssm: - hidden_units: 200 + hidden_units: 64 std_act: "softplus" reward: - layers: 3 - hidden_units: 300 + layers: 2 + hidden_units: 64 dist: "mse" dreamer: &dreamer <<: *sarl_off_policy - optim_params: - eps: 1.e-4 train_times: 1 train_interval: 1 diff --git a/rls/nn/modules/icm.py b/rls/nn/modules/icm.py index 22dbe56..5e1f1ea 100644 --- a/rls/nn/modules/icm.py +++ b/rls/nn/modules/icm.py @@ -2,7 +2,7 @@ import torch as t from torch.nn import Linear, Sequential, Tanh -from rls.common.decorator import iTensor_oNumpy +from rls.common.decorator import iton from rls.nn.activations import Act_REGISTER, default_act from rls.nn.represent_nets import RepresentationNetwork from rls.nn.utils import OPLR diff --git a/rls/nn/utils.py b/rls/nn/utils.py index 5927dc4..ebf4554 100644 --- a/rls/nn/utils.py +++ b/rls/nn/utils.py @@ -23,9 +23,8 @@ def __init__(self, optimizer: str = 'adam', *, scheduler_params: Dict = {}, - optimizer_params: Dict = {}, - clipvalue: Optional[float] = None, - clipnorm: Optional[float] = None): + optim_params: Dict = {}, + grad_params: Dict = {}): self.params = [] if not isinstance(models, (list, tuple)): models = [models] @@ -37,23 +36,20 @@ def __init__(self, self.params.append(model) self.optimizer = OP_REGISTER[optimizer]( - self.params, lr, **optimizer_params) + self.params, lr, **optim_params) self.lr_scheduler = LR_REGISTER[scheduler]( self.optimizer, **scheduler_params) - self.clipnorm = clipnorm - self.clipvalue = clipvalue - self._hooks = [] - if self.clipnorm: + if 'grad_max_norm' in grad_params.keys(): self._hooks.append( lambda: t.nn.utils.clip_grad_norm_( - self.params, max_norm=self.clipnorm) + self.params, max_norm=grad_params['grad_max_norm']) ) - if self.clipvalue: + if 'grad_clip_value' in grad_params.keys(): self._hooks.append( lambda: t.nn.utils.clip_grad_value_( - self.params, clip_value=self.clipvalue) + self.params, clip_value=grad_params['grad_clip_value']) ) @property
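
Note on item 1: `iton` is short for "input-to-tensor, output-to-numpy"; this patch only renames the decorator, the body in rls/common/decorator.py is otherwise unchanged. For readers unfamiliar with it, the sketch below is a simplified approximation of the behaviour such a decorator provides (the real implementation also handles the library's Data containers and per-policy dtype/device choices); it is illustrative, not a copy of the repo code.

import functools

import numpy as np
import torch as t


def iton(func, dtype=t.float32, device='cpu'):
    # Sketch only: promote numpy inputs to tensors before the wrapped method
    # runs, and detach tensor outputs back to numpy afterwards.
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # args[0] is usually `self`; reuse its device when it exposes one.
        dev = args[0].device if args and hasattr(args[0], 'device') else device

        def to_tensor(x):
            return t.as_tensor(x, dtype=dtype, device=dev) if isinstance(x, np.ndarray) else x

        def to_numpy(x):
            return x.detach().cpu().numpy() if isinstance(x, t.Tensor) else x

        args = [to_tensor(a) for a in args]
        kwargs = {k: to_tensor(v) for k, v in kwargs.items()}
        output = func(*args, **kwargs)
        if isinstance(output, tuple):
            return tuple(to_numpy(o) for o in output)
        return to_numpy(output)

    return wrapper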
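
Note on item 3: `oplr_params` in rls/configs/algorithms.yaml now carries two sub-dicts, `optim_params` (forwarded to the torch optimizer) and `grad_params` (turned into gradient-clipping hooks inside OPLR, see the rls/nn/utils.py hunk above). The snippet below is a minimal, self-contained illustration of that split; `build_step_fn` and the bare Adam optimizer are stand-ins invented for the example, not part of the repo.

import torch as t


def build_step_fn(model, lr, optim_params=None, grad_params=None):
    # Sketch of how the two sub-dicts of `oplr_params` are consumed.
    optim_params = optim_params or {}
    grad_params = grad_params or {}
    params = list(model.parameters())
    optimizer = t.optim.Adam(params, lr, **optim_params)  # e.g. eps=1.e-4

    hooks = []
    if 'grad_max_norm' in grad_params:
        hooks.append(lambda: t.nn.utils.clip_grad_norm_(
            params, max_norm=grad_params['grad_max_norm']))
    if 'grad_clip_value' in grad_params:
        hooks.append(lambda: t.nn.utils.clip_grad_value_(
            params, clip_value=grad_params['grad_clip_value']))

    def step(loss):
        optimizer.zero_grad()
        loss.backward()
        for hook in hooks:  # clip gradients before the optimizer step
            hook()
        optimizer.step()

    return step


# Usage mirroring the new YAML defaults:
#   oplr_params:
#     optim_params: {eps: 1.e-4}
#     grad_params: {grad_max_norm: 100., grad_clip_value: 100.}
# step = build_step_fn(model, 1e-3,
#                      optim_params={'eps': 1e-4},
#                      grad_params={'grad_max_norm': 100., 'grad_clip_value': 100.})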