From 6c05509add6d9b373451be1c60e8f0709f801ba6 Mon Sep 17 00:00:00 2001 From: Jilong Wang Date: Sat, 14 Jan 2023 17:33:38 +0800 Subject: [PATCH] Reconstruct LossAggregator and fix some typos in config files (#100) * fix Gait3D configs typo * Use ModuleDict to reconstruct LossAggregator * fix typo --- configs/baseline/baseline_Gait3D.yaml | 2 +- configs/smplgait/smplgait.yaml | 2 +- opengait/modeling/loss_aggregator.py | 16 +++++++++++----- 3 files changed, 13 insertions(+), 7 deletions(-) diff --git a/configs/baseline/baseline_Gait3D.yaml b/configs/baseline/baseline_Gait3D.yaml index 858f060..36f365b 100644 --- a/configs/baseline/baseline_Gait3D.yaml +++ b/configs/baseline/baseline_Gait3D.yaml @@ -83,7 +83,7 @@ trainer_cfg: enable_float16: true # half_percesion float for memory reduction and speedup fix_BN: false log_iter: 100 - with_test: 10000 + with_test: true restore_ckpt_strict: true restore_hint: 0 save_iter: 10000 diff --git a/configs/smplgait/smplgait.yaml b/configs/smplgait/smplgait.yaml index 2ae433c..889489a 100644 --- a/configs/smplgait/smplgait.yaml +++ b/configs/smplgait/smplgait.yaml @@ -85,7 +85,7 @@ trainer_cfg: enable_float16: true # half_percesion float for memory reduction and speedup fix_BN: false log_iter: 100 - with_test: 10000 + with_test: true restore_ckpt_strict: true restore_hint: 0 save_iter: 10000 diff --git a/opengait/modeling/loss_aggregator.py b/opengait/modeling/loss_aggregator.py index f5d72b3..a3f5982 100644 --- a/opengait/modeling/loss_aggregator.py +++ b/opengait/modeling/loss_aggregator.py @@ -1,13 +1,14 @@ """The loss aggregator.""" import torch +import torch.nn as nn from . import losses from utils import is_dict, get_attr_from, get_valid_args, is_tensor, get_ddp_module from utils import Odict from utils import get_msg_mgr -class LossAggregator(): +class LossAggregator(nn.Module): """The loss aggregator. This class is used to aggregate the losses. @@ -18,16 +19,21 @@ class LossAggregator(): Attributes: losses: A dict of losses. """ - def __init__(self, loss_cfg) -> None: """ Initialize the loss aggregator. + LossAggregator can be indexed like a regular Python dictionary, + but modules it contains are properly registered, and will be visible by all Module methods. + All parameters registered in losses can be accessed by the method 'self.parameters()', + thus they can be trained properly. + Args: loss_cfg: Config of losses. List for multiple losses. """ - self.losses = {loss_cfg['log_prefix']: self._build_loss_(loss_cfg)} if is_dict(loss_cfg) \ - else {cfg['log_prefix']: self._build_loss_(cfg) for cfg in loss_cfg} + super().__init__() + self.losses = nn.ModuleDict({loss_cfg['log_prefix']: self._build_loss_(loss_cfg)} if is_dict(loss_cfg) \ + else {cfg['log_prefix']: self._build_loss_(cfg) for cfg in loss_cfg}) def _build_loss_(self, loss_cfg): """Build the losses from loss_cfg. @@ -41,7 +47,7 @@ def _build_loss_(self, loss_cfg): loss = get_ddp_module(Loss(**valid_loss_arg).cuda()) return loss - def __call__(self, training_feats): + def forward(self, training_feats): """Compute the sum of all losses. The input is a dict of features. The key is the name of loss and the value is the feature and label. If the key not in