v0.0.3 feat(logger): added wandb and implemented tensorboard logger (default) and wandb logger, which can be configured at `algorithms.yaml/policy` (#47)
StepNeverStop committed Sep 13, 2021
1 parent 3cb1f1c commit f3ebb82
Showing 19 changed files with 197 additions and 103 deletions.
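
For context, the new option is added under the `policy` anchor in `rls/configs/algorithms.yaml` (full diff of that file below). A minimal sketch of just the relevant keys, with the rest of the section omitted:

    policy: &policy
      # ... other keys unchanged ...
      logger_types:        # any combination of "none", "tensorboard", "wandb"
        - "tensorboard"    # default
        # - "wandb"        # additionally log to Weights & Biases

Note that the `IndependentMA` wrapper (diff below) asserts that 'wandb' is not among the configured logger types when a SARL algorithm is run on a multi-agent task.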
3 changes: 2 additions & 1 deletion .gitignore
@@ -108,4 +108,5 @@ test.py
.vscode/
data/
videos/
unitylog.txt
unitylog.txt
wandb/
2 changes: 1 addition & 1 deletion rls/_metadata.py
@@ -8,7 +8,7 @@
# We follow Semantic Versioning (https://semver.org/)
_MAJOR_VERSION = '0'
_MINOR_VERSION = '0'
_PATCH_VERSION = '2'
_PATCH_VERSION = '3'

# Example: '0.4.2'
__version__ = '.'.join([_MAJOR_VERSION, _MINOR_VERSION, _PATCH_VERSION])
4 changes: 0 additions & 4 deletions rls/algorithms/base/base.py
@@ -50,7 +50,3 @@ def resume(self):
@abstractmethod
def still_learn(self):
pass

@abstractmethod
def write_recorder_summaries(self):
pass
3 changes: 2 additions & 1 deletion rls/algorithms/base/marl_off_policy.py
@@ -114,7 +114,8 @@ def _train(self, BATCH_DICT):
raise NotImplementedError

def _after_train(self):
self._write_train_summaries(self._cur_train_step, self.summaries)
self._write_log(summaries=self.summaries,
step_type='step')
if self._should_save_model(self._cur_train_step):
self.save()
self._cur_train_step += 1
61 changes: 20 additions & 41 deletions rls/algorithms/base/marl_policy.py
@@ -12,6 +12,7 @@
from rls.algorithms.base.policy import Policy
from rls.common.specs import Data, EnvAgentSpec, SensorSpec
from rls.utils.converter import to_tensor
from rls.utils.loggers import Log_REGISTER
from rls.utils.np_utils import int2one_hot


@@ -27,12 +28,9 @@ def __init__(self,
self.agent_specs = agent_specs
self.n_agents_percopy = len(agent_specs)
self.agent_ids = list(self.agent_specs.keys())
self.obs_specs = {id: agent_spec.obs_spec for id,
agent_spec in agent_specs.items()}
self.is_continuouss = {
id: agent_spec.is_continuous for id, agent_spec in agent_specs.items()}
self.a_dims = {id: agent_spec.a_dim for id,
agent_spec in agent_specs.items()}
self.obs_specs = {id: agent_spec.obs_spec for id, agent_spec in agent_specs.items()}
self.is_continuouss = {id: agent_spec.is_continuous for id, agent_spec in agent_specs.items()}
self.a_dims = {id: agent_spec.a_dim for id, agent_spec in agent_specs.items()}

self.state_spec = state_spec
self.share_params = share_params
@@ -57,8 +55,16 @@ def __init__(self,
if self.agent_specs[self.agent_ids[i]] == self.agent_specs[id]:
self.model_ids[i] = id
break
self.agent_writers = {id: self._create_writer(
self.log_dir + f'_{id}') for id in self.agent_ids}

def _build_loggers(self):
return [
Log_REGISTER[logger_type](
log_dir=self.log_dir,
ids=['model'] + self.agent_ids,
training_name=self._training_name, # wandb
)
for logger_type in self._logger_types
]

def _preprocess_obs(self, obs: Dict):
for i, id in enumerate(self.agent_ids):
@@ -69,8 +75,7 @@ def _preprocess_obs(self, obs: Dict):
else:
other = self._pre_acts[id]
if self._obs_with_agent_id:
_id_onehot = int2one_hot(
np.full(self.n_copys, i), self.n_agents_percopy)
_id_onehot = int2one_hot(np.full(self.n_copys, i), self.n_agents_percopy)
if other is not None:
other = np.concatenate((
other,
@@ -95,26 +100,21 @@ def random_action(self):
self._acts_info = {}
for id in self.agent_ids:
if self.is_continuouss[id]:
actions[id] = np.random.uniform(-1.0,
1.0, (self.n_copys, self.a_dims[id]))
actions[id] = np.random.uniform(-1.0, 1.0, (self.n_copys, self.a_dims[id]))
else:
actions[id] = np.random.randint(
0, self.a_dims[id], self.n_copys)
actions[id] = np.random.randint(0, self.a_dims[id], self.n_copys)
self._acts_info[id] = Data(action=actions[id])
self._pre_acts = actions
return actions

def episode_reset(self):
self._pre_acts = {}
for id in self.agent_ids:
self._pre_acts[id] = np.zeros(
(self.n_copys, self.a_dims[id])) if self.is_continuouss[id] else np.zeros(self.n_copys)
self._pre_acts[id] = np.zeros((self.n_copys, self.a_dims[id])) if self.is_continuouss[id] else np.zeros(self.n_copys)
self.rnncs, self.rnncs_ = {}, {}
for id in self.agent_ids:
self.rnncs[id] = to_tensor(self._initial_rnncs(
batch=self.n_copys), device=self.device)
self.rnncs_[id] = to_tensor(self._initial_rnncs(
batch=self.n_copys), device=self.device)
self.rnncs[id] = to_tensor(self._initial_rnncs(batch=self.n_copys), device=self.device)
self.rnncs_[id] = to_tensor(self._initial_rnncs(batch=self.n_copys), device=self.device)

def episode_step(self,
obs,
@@ -143,28 +143,7 @@ def episode_step(self,
for k in self.rnncs[id].keys():
self.rnncs[id][k][idxs] = 0.

def write_recorder_summaries(self, summaries):
if 'model' in summaries.keys():
super()._write_train_summaries(self._cur_episode,
summaries=summaries.pop('model'), writer=self.writer)
for id, summary in summaries.items():
super()._write_train_summaries(self._cur_episode,
summaries=summary, writer=self.agent_writers[id])

# customed

def _train(self, BATCH_DICT):
raise NotImplementedError

def _write_train_summaries(self,
cur_train_step: Union[int, t.Tensor],
summaries: Dict) -> NoReturn:
'''
write summaries showing in tensorboard.
'''
if 'model' in summaries.keys():
super()._write_train_summaries(cur_train_step,
summaries=summaries.pop('model'), writer=self.writer)
for id, summary in summaries.items():
super()._write_train_summaries(cur_train_step,
summaries=summary, writer=self.agent_writers[id])
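
With this change the per-agent SummaryWriters are gone: one list of loggers, built with ids=['model'] + self.agent_ids, receives both the shared 'model' summaries and the per-agent ones. A hypothetical example of the structure a multi-agent trainer might pass (agent ids, tags, and values are invented; how each logger splits the dict by id lives in rls/utils/loggers, which is not part of this diff):

    summaries = {
        'model': {'LOSS/q_loss': 0.42},          # shared scalars, previously written via self.writer
        'agent_0': {'REWARD/ep_return': 12.0},   # per-agent scalars, previously via agent_writers[id]
        'agent_1': {'REWARD/ep_return': 9.5},
    }
    policy.write_log(summaries=summaries, step_type='episode')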
55 changes: 31 additions & 24 deletions rls/algorithms/base/policy.py
@@ -34,6 +34,7 @@ def __init__(self,
save2single_file=False,
n_step_value=4,
gamma=0.999,
logger_types=['none'],
decay_lr=False,
normalize_vector_obs=False,
obs_with_pre_action=False,
@@ -63,7 +64,8 @@ def __init__(self,
'''
self.n_copys = n_copys
self._is_save = is_save
self.base_dir = base_dir
self._base_dir = base_dir
self._training_name = os.path.split(self._base_dir)[-1]
self.device = device
logger.info(colorize(f"PyTorch Tensor Device: {self.device}"))
self._max_train_step = max_train_step
@@ -75,6 +77,7 @@ def __init__(self,

self._save2single_file = save2single_file
self.gamma = gamma
self._logger_types = logger_types
self._n_step_value = n_step_value
self._decay_lr = decay_lr # TODO: implement
self._normalize_vector_obs = normalize_vector_obs # TODO: implement
@@ -96,7 +99,6 @@ def __init__(self,

if self._is_save:
check_or_create(self.cp_dir, 'checkpoints(models)')
self.writer = self._create_writer(self.log_dir) # TODO: Annotation

self._cur_interact_step = t.tensor(0).long().to(self.device)
self._cur_train_step = t.tensor(0).long().to(self.device)
@@ -111,6 +113,7 @@ def __init__(self,
}

self._buffer = self._build_buffer()
self._loggers = self._build_loggers() if self._is_save else list()

def __call__(self, obs):
raise NotImplementedError
@@ -166,7 +169,7 @@ def resume(self, base_dir: Optional[str] = None) -> Dict:
"""
check whether chekpoint and model be within cp_dir, if in it, restore otherwise initialize randomly.
"""
cp_dir = os.path.join(base_dir or self.base_dir, 'model')
cp_dir = os.path.join(base_dir or self._base_dir, 'model')
if self._save2single_file:
ckpt_path = os.path.join(cp_dir, 'checkpoint.pth')
if os.path.exists(ckpt_path):
@@ -183,43 +186,47 @@ def resume(self, base_dir: Optional[str] = None) -> Dict:
model_path = os.path.join(cp_dir, f'{k}.pth')
if os.path.exists(model_path):
if hasattr(v, 'load_state_dict'):
self._trainer_modules[k].load_state_dict(
t.load(model_path))
self._trainer_modules[k].load_state_dict(t.load(model_path))
else:
getattr(self, k).fill_(t.load(model_path))
logger.info(
colorize(f'Resume model from {model_path} SUCCESSFULLY.', color='green'))
logger.info(colorize(f'Resume model from {model_path} SUCCESSFULLY.', color='green'))

@property
def still_learn(self):
return self._should_learn_cond_train_step(self._cur_train_step) \
and self._should_learn_cond_frame_step(self._cur_frame_step) \
and self._should_learn_cond_train_episode(self._cur_episode)

def write_recorder_summaries(self, summaries):
raise NotImplementedError
def write_log(self,
log_step: Union[int, t.Tensor] = None,
summaries: Dict = {},
step_type: str = None):
self._write_log(log_step, summaries, step_type)

# customed

def _build_buffer(self):
raise NotImplementedError

def _create_writer(self, log_dir: str) -> SummaryWriter:
if self._is_save:
check_or_create(log_dir, 'logs(summaries)')
return SummaryWriter(log_dir)
def _build_loggers(self):
raise NotImplementedError

def _write_train_summaries(self,
cur_train_step: Union[int, t.Tensor],
summaries: Dict = {},
writer: Optional[SummaryWriter] = None) -> NoReturn:
'''
write summaries showing in tensorboard.
'''
if self._is_save:
writer = writer or self.writer
for k, v in summaries.items():
writer.add_scalar(k, v, global_step=cur_train_step)
def _write_log(self,
log_step: Union[int, t.Tensor] = None,
summaries: Dict = {},
step_type: str = None):
assert step_type is not None or log_step is not None, 'assert step_type is not None or log_step is not None'
if log_step is None:
if step_type == 'step':
log_step = self._cur_train_step
elif step_type == 'episode':
log_step = self._cur_episode
elif step_type == 'frame':
log_step = self._cur_frame_step
else:
raise NotImplementedError("step_type must be in ['step', 'episode', 'frame'] for now.")
for logger in self._loggers:
logger.write(summaries=summaries, step=log_step)

def _initial_rnncs(self, batch: int, rnn_units: int = None, keys: Optional[List[str]] = None) -> Dict[str, np.ndarray]:
rnn_units = rnn_units or self.memory_net_params['rnn_units']
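
The `_create_writer`/`_write_train_summaries` pair is thus replaced by `_build_loggers`/`_write_log`, and when no explicit `log_step` is given the step is resolved from `step_type`. A small usage sketch (tags and values are illustrative only):

    policy.write_log(summaries={'LOSS/loss': 0.1}, step_type='step')         # logged at self._cur_train_step
    policy.write_log(summaries={'EVAL/return': 200.0}, step_type='episode')  # logged at self._cur_episode
    policy.write_log(log_step=10_000, summaries={'EVAL/return': 200.0})      # explicit step takes precedence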
4 changes: 2 additions & 2 deletions rls/algorithms/base/sarl_off_policy.py
@@ -103,8 +103,8 @@ def _train(self, BATCH):
raise NotImplementedError

def _after_train(self):
self._write_train_summaries(
self._cur_train_step, self.summaries, self.writer)
self._write_log(summaries=self.summaries,
step_type='step')
if self._should_save_model(self._cur_train_step):
self.save()
self._cur_train_step += 1
6 changes: 4 additions & 2 deletions rls/algorithms/base/sarl_on_policy.py
@@ -7,6 +7,7 @@
import torch as t

from rls.algorithms.base.sarl_policy import SarlPolicy
from rls.common.decorator import iton
from rls.common.specs import Data
from rls.utils.np_utils import int2one_hot

@@ -75,12 +76,13 @@ def _before_train(self, BATCH):
self.summaries.update(crsty_summaries)
return BATCH

@iton
def _train(self, BATCH):
raise NotImplementedError

def _after_train(self):
self._write_train_summaries(
self._cur_train_step, self.summaries, self.writer)
self._write_log(summaries=self.summaries,
step_type='step')
if self._should_save_model(self._cur_train_step):
self.save()
self._cur_train_step += 1
13 changes: 10 additions & 3 deletions rls/algorithms/base/sarl_policy.py
@@ -13,6 +13,7 @@
from rls.common.specs import Data, EnvAgentSpec
from rls.nn.modules import CuriosityModel
from rls.utils.converter import to_tensor
from rls.utils.loggers import Log_REGISTER
from rls.utils.np_utils import int2one_hot
from rls.utils.vector_runing_average import (DefaultRunningAverage,
SimpleRunningAverage)
@@ -115,10 +116,16 @@ def episode_step(self,
def learn(self, BATCH: Data):
raise NotImplementedError

def write_recorder_summaries(self, summaries):
self._write_train_summaries(self._cur_episode, summaries, self.writer)

# customed

def _train(self, BATCH):
raise NotImplementedError

def _build_loggers(self):
return [
Log_REGISTER[logger_type](
log_dir=self.log_dir,
training_name=self._training_name, # wandb
)
for logger_type in self._logger_types
]
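
Log_REGISTER itself (rls/utils/loggers.py) is not part of this diff. Judging only from how `_build_loggers` and `_write_log` use it here, a registered logger is constructed with `log_dir`, optionally `ids`, and `training_name`, and exposes a `write(summaries=..., step=...)` method. A minimal sketch of what the tensorboard entry might look like; the class name, defaults, and nested-summaries handling below are assumptions:

    from typing import Dict, List, Optional

    from torch.utils.tensorboard import SummaryWriter


    class TensorboardLogger:
        def __init__(self, log_dir: str, ids: Optional[List[str]] = None,
                     training_name: str = '', *args, **kwargs):
            # one SummaryWriter per id, mirroring the per-agent writers removed above;
            # single-agent policies pass no ids, so fall back to just 'model'.
            ids = ids or ['model']
            self._writers = {id: SummaryWriter(f'{log_dir}_{id}') for id in ids}

        def write(self, summaries: Dict, step: int) -> None:
            # single-agent summaries arrive as a flat {tag: scalar} dict;
            # treat them as belonging to the 'model' writer.
            if not all(isinstance(v, dict) for v in summaries.values()):
                summaries = {'model': summaries}
            for id, summary in summaries.items():
                for tag, value in summary.items():
                    self._writers[id].add_scalar(tag, value, global_step=step)


    # Hypothetical registry entry, matching the Log_REGISTER[logger_type] lookups above.
    Log_REGISTER = {'tensorboard': TensorboardLogger}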
2 changes: 1 addition & 1 deletion rls/algorithms/multi/vdn.py
@@ -126,7 +126,7 @@ def _train(self, BATCH_DICT):
q_target_next_max = q_target.max(-1, keepdim=True)[0]

q_target_next_choose_maxs.append(q_target_next_max) # N * [T, B, 1]

q_evals = t.stack(q_evals, -1) # [T, B, 1, N]
q_target_next_choose_maxs = t.stack(q_target_next_choose_maxs, -1) # [T, B, 1, N]
q_eval_tot = self.mixer(q_evals, BATCH_DICT['global'].obs,
11 changes: 9 additions & 2 deletions rls/algorithms/wrapper/IndependentMA.py
@@ -31,6 +31,8 @@ def __init__(self,
logger.info(colorize(
'using SARL algorithm to train Multi-Agent task, model has been changed to independent-SARL automatically.'))

assert 'wandb' not in algo_args.logger_types, "assert 'wandb' not in algo_args.logger_types"

self.models = {}
for id in self._agent_ids:
_algo_args = deepcopy(algo_args)
@@ -97,9 +99,14 @@ def resume(self, base_dir: Optional[str] = None) -> Dict:
def still_learn(self):
return all(model.still_learn for model in self.models.values())

def write_recorder_summaries(self, summaries: Dict[str, Dict]) -> NoReturn:
def write_log(self,
log_step: Union[int, t.Tensor] = None,
summaries: Dict[str, Dict] = {},
step_type: str = None):
'''
write summaries showing in tensorboard.
'''
for id in self._agent_ids:
self.models[id].write_recorder_summaries(summaries=summaries[id])
self.models[id].write_log(log_step=log_step,
summaries=summaries[id],
step_type=step_type)
8 changes: 6 additions & 2 deletions rls/configs/algorithms.yaml
@@ -5,6 +5,10 @@ policy: &policy
save2single_file: false
n_step_value: 4
gamma: 0.99
logger_types:
# - "none"
- "tensorboard"
# - "wandb"
decay_lr: false
normalize_vector_obs: false

@@ -49,13 +53,13 @@ marl_policy: &marl_policy
sarl_on_policy: &sarl_on_policy
<<: *sarl_policy
epochs: 4 # train multiple times per agent step
chunk_length: 4 # n-step or rnn length
chunk_length: 4 # rnn length
batch_size: 64
sample_allow_repeat: true

sarl_off_policy: &sarl_off_policy
<<: *sarl_policy
chunk_length: 4 # n-step or rnn length
chunk_length: 4 # rnn length
epochs: 1 # train multiple times per agent step
train_times: 1
batch_size: 64