v5.1.10 perf: optimized dreamer-related code. (#34)
1. renamed `iTensor_oNumpy` to `iton`
2. optimized `auto_format.py`
3. added a general `oplr_params` argument for initializing optimizers
StepNeverStop committed Sep 3, 2021
1 parent 6081ec0 · commit 309d63f
Showing 44 changed files with 224 additions and 266 deletions.
16 changes: 10 additions & 6 deletions auto_format.py
@@ -13,28 +13,32 @@ def get_args():
                         help='.py file path that need to be formatted.')
     parser.add_argument('-d', '--file-dir', type=str, default=None,
                         help='.py dictionary that need to be formatted.')
+    parser.add_argument('--ignore-pep', default=False, action='store_true',
+                        help='whether format the file')
     return parser.parse_args()


-def autopep8(file_path):
+def autopep8(file_path, ignore_pep):
     isort.file(file_path)
-    os.system(f"autopep8 -i {file_path}")
-    print(f'autopep8 file: {file_path} SUCCESSFULLY.')
+    print(f'isort file: {file_path} SUCCESSFULLY.')
+    if not ignore_pep:
+        os.system(f"autopep8 -j 0 -i {file_path} --max-line-length 200")
+        print(f'autopep8 file: {file_path} SUCCESSFULLY.')


 if __name__ == '__main__':

     args = get_args()

     if args.file_path:
-        autopep8(args.file_path)
+        autopep8(args.file_path, args.ignore_pep)

     if args.file_dir:
         py_files = []
         for root, dirnames, filenames in os.walk(args.file_dir):
             py_files.extend(glob.glob(root + "/*.py"))

         for path in py_files:
-            autopep8(path)
+            autopep8(path, args.ignore_pep)

-    print('autopep8 finished.')
+    print('auto-format finished.')
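After this change, isort always runs, while autopep8 is skipped when --ignore-pep is given. A minimal usage sketch, assuming the script is importable as auto_format (the target path is just an example):

from auto_format import autopep8

# Same effect as: python auto_format.py -p rls/algorithms/single/a2c.py --ignore-pep
autopep8('rls/algorithms/single/a2c.py', ignore_pep=True)   # isort only

# Same effect as: python auto_format.py -p rls/algorithms/single/a2c.py
autopep8('rls/algorithms/single/a2c.py', ignore_pep=False)  # isort, then autopep8 -j 0 --max-line-length 200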
2 changes: 1 addition & 1 deletion rls/_metadata.py
@@ -8,7 +8,7 @@
 # We follow Semantic Versioning (https://semver.org/)
 _MAJOR_VERSION = '5'
 _MINOR_VERSION = '1'
-_PATCH_VERSION = '9'
+_PATCH_VERSION = '10'

 # Example: '0.4.2'
 __version__ = '.'.join([_MAJOR_VERSION, _MINOR_VERSION, _PATCH_VERSION])
4 changes: 2 additions & 2 deletions rls/algorithms/base/marl_off_policy.py
@@ -8,7 +8,7 @@
 import torch as t

 from rls.algorithms.base.marl_policy import MarlPolicy
-from rls.common.decorator import iTensor_oNumpy
+from rls.common.decorator import iton
 from rls.common.specs import Data
 from rls.common.yaml_ops import load_config
 from rls.utils.np_utils import int2one_hot
@@ -107,7 +107,7 @@ def _before_train(self, BATCH_DICT):
         self.summaries = {}
         return BATCH_DICT

-    @iTensor_oNumpy
+    @iton
     def _train(self, BATCH_DICT):
         raise NotImplementedError

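The rename above recurs throughout the files below; only the import and decorator sites change, and the decorator's implementation in rls/common/decorator.py is not part of the diff shown here. For orientation only, a decorator named for "input Tensor, output Numpy" could look roughly like the following sketch; the helper names, the assumed `self.device` attribute, and the handling of return values are guesses, not the repository's code.

import functools

import numpy as np
import torch as t


def iton(func):
    """Hypothetical sketch of an 'input Tensor, output Numpy' decorator.
    The real rls.common.decorator.iton may differ in signature and behaviour."""

    def to_tensor(x, device):
        return t.as_tensor(x, device=device) if isinstance(x, np.ndarray) else x

    def to_numpy(x):
        return x.detach().cpu().numpy() if isinstance(x, t.Tensor) else x

    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        device = getattr(self, 'device', 'cpu')  # assumed attribute on the policy
        args = [to_tensor(a, device) for a in args]
        kwargs = {k: to_tensor(v, device) for k, v in kwargs.items()}
        out = func(self, *args, **kwargs)
        if isinstance(out, tuple):
            return tuple(to_numpy(o) for o in out)
        return to_numpy(out)

    return wrapper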
4 changes: 2 additions & 2 deletions rls/algorithms/base/policy.py
@@ -37,7 +37,7 @@ def __init__(self,
                  decay_lr=False,
                  normalize_vector_obs=False,
                  obs_with_pre_action=False,
-                 optim_params=dict(),
+                 oplr_params=dict(),
                  rep_net_params={
                      'vector_net_params': {
                          'h_dim': 16,
@@ -80,7 +80,7 @@ def __init__(self,
         self._normalize_vector_obs = normalize_vector_obs  # TODO: implement
         self._obs_with_pre_action = obs_with_pre_action
         self._rep_net_params = dict(rep_net_params)
-        self._optim_params = dict(optim_params)
+        self._oplr_params = dict(oplr_params)

         super().__init__()

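With `oplr_params` stored on the policy, the algorithm diffs below simply forward it into every optimizer wrapper, e.g. `OPLR(self.actor, actor_lr, **self._oplr_params)`. OPLR itself is not shown in this commit, so the stand-in below is only a hypothetical sketch of why forwarding a keyword dict this way is convenient; the accepted keys (`optimizer`, `weight_decay`, `eps`) are assumptions, not OPLR's real signature.

import torch as t


class OplrSketch:
    """Hypothetical stand-in for rls.nn.utils.OPLR (optimizer plus learning-rate
    handling for one or more modules). Shown only to illustrate **oplr_params."""

    def __init__(self, models, lr, optimizer='adam', weight_decay=0.0, eps=1e-8):
        models = models if isinstance(models, (list, tuple)) else [models]
        params = [p for m in models for p in m.parameters()]
        if optimizer == 'adam':
            self.optimizer = t.optim.Adam(params, lr=lr, weight_decay=weight_decay, eps=eps)
        else:
            self.optimizer = t.optim.SGD(params, lr=lr, weight_decay=weight_decay)

    def step(self, loss):
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()


# A single config entry such as oplr_params = {'weight_decay': 1e-5} can then be
# threaded, unchanged, into every algorithm's actor/critic/alpha optimizers.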
4 changes: 2 additions & 2 deletions rls/algorithms/base/sarl_off_policy.py
@@ -7,7 +7,7 @@
 import torch as t

 from rls.algorithms.base.sarl_policy import SarlPolicy
-from rls.common.decorator import iTensor_oNumpy
+from rls.common.decorator import iton
 from rls.common.specs import Data
 from rls.common.when import Every
 from rls.common.yaml_ops import load_config
@@ -98,7 +98,7 @@ def _before_train(self, BATCH):
         self.summaries.update(crsty_summaries)
         return BATCH

-    @iTensor_oNumpy
+    @iton
     def _train(self, BATCH):
         raise NotImplementedError

10 changes: 5 additions & 5 deletions rls/algorithms/multi/maddpg.py
@@ -9,7 +9,7 @@
 from torch import distributions as td

 from rls.algorithms.base.marl_off_policy import MultiAgentOffPolicy
-from rls.common.decorator import iTensor_oNumpy
+from rls.common.decorator import iton
 from rls.common.specs import Data
 from rls.nn.models import ActorDct, ActorDPG, MACriticQvalueOne
 from rls.nn.modules.wrappers import TargetTwin
@@ -66,8 +66,8 @@ def __init__(self,
                                                self.a_dims.values()),
                                            network_settings=network_settings['q']),
                self.ployak).to(self.device)
-        self.actor_oplr = OPLR(list(self.actors.values()), actor_lr)
-        self.critic_oplr = OPLR(list(self.critics.values()), critic_lr)
+        self.actor_oplr = OPLR(list(self.actors.values()), actor_lr, **self._oplr_params)
+        self.critic_oplr = OPLR(list(self.critics.values()), critic_lr, **self._oplr_params)

         # TODO: 添加动作类型判断
         self.noised_actions = {id: Noise_action_REGISTER[noise_action](**noise_params)
@@ -85,7 +85,7 @@ def episode_reset(self):
         for noised_action in self.noised_actions.values():
             noised_action.reset()

-    @iTensor_oNumpy
+    @iton
     def select_action(self, obs: Dict):
         acts_info = {}
         actions = {}
@@ -106,7 +106,7 @@ def select_action(self, obs: Dict):
             actions[aid] = action
         return actions, acts_info

-    @iTensor_oNumpy
+    @iton
     def _train(self, BATCH_DICT):
         '''
         TODO: Annotation
12 changes: 6 additions & 6 deletions rls/algorithms/multi/masac.py
@@ -10,7 +10,7 @@
 from torch import distributions as td

 from rls.algorithms.base.marl_off_policy import MultiAgentOffPolicy
-from rls.common.decorator import iTensor_oNumpy
+from rls.common.decorator import iton
 from rls.common.specs import Data
 from rls.nn.models import ActorCts, ActorDct, MACriticQvalueOne
 from rls.nn.modules.wrappers import TargetTwin
@@ -79,13 +79,13 @@ def __init__(self,
                     network_settings=network_settings['q']),
                 self.ployak).to(self.device)
             self.critics2[id] = deepcopy(self.critics[id])
-        self.actor_oplr = OPLR(list(self.actors.values()), actor_lr)
+        self.actor_oplr = OPLR(list(self.actors.values()), actor_lr, **self._oplr_params)
         self.critic_oplr = OPLR(
-            list(self.critics.values())+list(self.critics2.values()), critic_lr)
+            list(self.critics.values())+list(self.critics2.values()), critic_lr, **self._oplr_params)

         if self.auto_adaption:
             self.log_alpha = t.tensor(0., requires_grad=True).to(self.device)
-            self.alpha_oplr = OPLR(self.log_alpha, alpha_lr)
+            self.alpha_oplr = OPLR(self.log_alpha, alpha_lr, **self._oplr_params)
             self._trainer_modules.update(alpha_oplr=self.alpha_oplr)
         else:
             self.log_alpha = t.tensor(alpha).log().to(self.device)
@@ -106,7 +106,7 @@ def __init__(self,
     def alpha(self):
         return self.log_alpha.exp()

-    @iTensor_oNumpy
+    @iton
     def select_action(self, obs: Dict):
         acts_info = {}
         actions = {}
@@ -128,7 +128,7 @@ def select_action(self, obs: Dict):
             actions[aid] = action
         return actions, acts_info

-    @iTensor_oNumpy
+    @iton
     def _train(self, BATCH_DICT):
         '''
         TODO: Annotation
4 changes: 2 additions & 2 deletions rls/algorithms/multi/qplex.py
@@ -4,7 +4,7 @@
 import torch as t

 from rls.algorithms.multi.vdn import VDN
-from rls.common.decorator import iTensor_oNumpy
+from rls.common.decorator import iton
 from rls.nn.mixers import Mixer_REGISTER
 from rls.nn.modules.wrappers import TargetTwin
 from rls.utils.torch_utils import n_step_return
@@ -37,7 +37,7 @@ def _build_mixer(self):
                                   **self._mixer_settings)
                ).to(self.device)

-    @iTensor_oNumpy
+    @iton
     def _train(self, BATCH_DICT):
         summaries = {}
         reward = BATCH_DICT[self.agent_ids[0]].reward  # [T, B, 1]
4 changes: 2 additions & 2 deletions rls/algorithms/multi/qtran.py
@@ -4,7 +4,7 @@
 import torch as t

 from rls.algorithms.multi.vdn import VDN
-from rls.common.decorator import iTensor_oNumpy
+from rls.common.decorator import iton
 from rls.nn.mixers import Mixer_REGISTER
 from rls.nn.modules.wrappers import TargetTwin
 from rls.utils.torch_utils import n_step_return
@@ -42,7 +42,7 @@ def _build_mixer(self):
                                   **self._mixer_settings)
                ).to(self.device)

-    @iTensor_oNumpy
+    @iton
     def _train(self, BATCH_DICT):
         summaries = {}
         reward = BATCH_DICT[self.agent_ids[0]].reward  # [T, B, 1]
8 changes: 4 additions & 4 deletions rls/algorithms/multi/vdn.py
@@ -5,7 +5,7 @@
 import torch as t

 from rls.algorithms.base.marl_off_policy import MultiAgentOffPolicy
-from rls.common.decorator import iTensor_oNumpy
+from rls.common.decorator import iton
 from rls.common.specs import Data
 from rls.nn.mixers import Mixer_REGISTER
 from rls.nn.models import CriticDueling
@@ -61,7 +61,7 @@ def __init__(self,

         self.mixer = self._build_mixer()

-        self.oplr = OPLR(tuple(self.q_nets.values())+(self.mixer,), lr)
+        self.oplr = OPLR(tuple(self.q_nets.values())+(self.mixer,), lr, **self._oplr_params)
         self._trainer_modules.update(
             {f"model_{id}": self.q_nets[id] for id in set(self.model_ids)})
         self._trainer_modules.update(mixer=self.mixer,
@@ -79,7 +79,7 @@ def _build_mixer(self):
                                   **self._mixer_settings)
                ).to(self.device)

-    @iTensor_oNumpy  # TODO: optimization
+    @iton  # TODO: optimization
     def select_action(self, obs):
         acts_info = {}
         actions = {}
@@ -97,7 +97,7 @@ def select_action(self, obs):
             acts_info[aid] = Data(action=action)
         return actions, acts_info

-    @iTensor_oNumpy
+    @iton
     def _train(self, BATCH_DICT):
         summaries = {}
         reward = BATCH_DICT[self.agent_ids[0]].reward  # [T, B, 1]
12 changes: 6 additions & 6 deletions rls/algorithms/single/a2c.py
@@ -6,7 +6,7 @@
 from torch import distributions as td

 from rls.algorithms.base.sarl_on_policy import SarlOnPolicy
-from rls.common.decorator import iTensor_oNumpy
+from rls.common.decorator import iton
 from rls.common.specs import Data
 from rls.nn.models import ActorDct, ActorMuLogstd, CriticValue
 from rls.nn.utils import OPLR
@@ -52,15 +52,15 @@ def __init__(self,
                                   rep_net_params=self._rep_net_params,
                                   network_settings=network_settings['critic']).to(self.device)

-        self.actor_oplr = OPLR(self.actor, actor_lr)
-        self.critic_oplr = OPLR(self.critic, critic_lr)
+        self.actor_oplr = OPLR(self.actor, actor_lr, **self._oplr_params)
+        self.critic_oplr = OPLR(self.critic, critic_lr, **self._oplr_params)

         self._trainer_modules.update(actor=self.actor,
                                      critic=self.critic,
                                      actor_oplr=self.actor_oplr,
                                      critic_oplr=self.critic_oplr)

-    @iTensor_oNumpy
+    @iton
     def select_action(self, obs):
         output = self.actor(obs, cell_state=self.cell_state)  # [B, A]
         self.next_cell_state = self.actor.get_cell_state()
@@ -78,7 +78,7 @@ def select_action(self, obs):
         acts_info.update(cell_state=self.cell_state)
         return action, acts_info

-    @iTensor_oNumpy
+    @iton
     def _get_value(self, obs, cell_state=None):
         value = self.critic(obs, cell_state=self.cell_state)
         return value
@@ -93,7 +93,7 @@ def _preprocess_BATCH(self, BATCH):  # [T, B, *]
                                  init_value=value)
         return BATCH

-    @iTensor_oNumpy
+    @iton
     def _train(self, BATCH):
         v = self.critic(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, 1]
         td_error = BATCH.discounted_reward - v  # [T, B, 1]
10 changes: 5 additions & 5 deletions rls/algorithms/single/ac.py
@@ -6,7 +6,7 @@
 from torch import distributions as td

 from rls.algorithms.base.sarl_off_policy import SarlOffPolicy
-from rls.common.decorator import iTensor_oNumpy
+from rls.common.decorator import iton
 from rls.common.specs import Data
 from rls.nn.models import ActorDct, ActorMuLogstd, CriticQvalueOne
 from rls.nn.utils import OPLR
@@ -47,15 +47,15 @@ def __init__(self,
                                       action_dim=self.a_dim,
                                       network_settings=network_settings['critic']).to(self.device)

-        self.actor_oplr = OPLR(self.actor, actor_lr)
-        self.critic_oplr = OPLR(self.critic, critic_lr)
+        self.actor_oplr = OPLR(self.actor, actor_lr, **self._oplr_params)
+        self.critic_oplr = OPLR(self.critic, critic_lr, **self._oplr_params)

         self._trainer_modules.update(actor=self.actor,
                                      critic=self.critic,
                                      actor_oplr=self.actor_oplr,
                                      critic_oplr=self.critic_oplr)

-    @iTensor_oNumpy
+    @iton
     def select_action(self, obs):
         output = self.actor(obs, cell_state=self.cell_state)  # [B, *]
         self.next_cell_state = self.actor.get_cell_state()
@@ -82,7 +82,7 @@ def random_action(self):
                                         self.n_copys, 1./self.a_dim))  # [B,]
         return actions

-    @iTensor_oNumpy
+    @iton
     def _train(self, BATCH):
         q = self.critic(BATCH.obs, BATCH.action,
                         begin_mask=BATCH.begin_mask)  # [T, B, 1]
8 changes: 4 additions & 4 deletions rls/algorithms/single/averaged_dqn.py
@@ -8,7 +8,7 @@
 import torch as t

 from rls.algorithms.base.sarl_off_policy import SarlOffPolicy
-from rls.common.decorator import iTensor_oNumpy
+from rls.common.decorator import iton
 from rls.common.specs import Data
 from rls.nn.models import CriticQvalueAll
 from rls.nn.utils import OPLR
@@ -55,11 +55,11 @@ def __init__(self,
             sync_params(target_q_net, self.q_net)
             self.target_nets.append(target_q_net)

-        self.oplr = OPLR(self.q_net, lr)
+        self.oplr = OPLR(self.q_net, lr, **self._oplr_params)
         self._trainer_modules.update(model=self.q_net,
                                      oplr=self.oplr)

-    @iTensor_oNumpy
+    @iton
     def select_action(self, obs):
         q_values = self.q_net(obs, cell_state=self.cell_state)  # [B, *]
         self.next_cell_state = self.q_net.get_cell_state()
@@ -74,7 +74,7 @@ def select_action(self, obs):
         actions = q_values.argmax(-1)  # 不取平均也可以 [B, ]
         return actions, Data(action=actions)

-    @iTensor_oNumpy
+    @iton
     def _train(self, BATCH):
         q = self.q_net(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, *]
         q_next = 0
(Diffs for the remaining changed files are not shown here.)
