Skip to content

Commit

Permalink
Reconstruct LossAggregator and fix some typos in config files (ShiqiY…
Browse files Browse the repository at this point in the history
…u#100)

* fix Gait3D configs typo

* Use ModuleDict to reconstruct LossAggregator

* fix typo
  • Loading branch information
wj1tr0y authored Jan 14, 2023
1 parent bf8d036 commit 6c05509
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 7 deletions.
2 changes: 1 addition & 1 deletion configs/baseline/baseline_Gait3D.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ trainer_cfg:
enable_float16: true # half-precision float for memory reduction and speedup
fix_BN: false
log_iter: 100
with_test: 10000
with_test: true
restore_ckpt_strict: true
restore_hint: 0
save_iter: 10000
Expand Down
2 changes: 1 addition & 1 deletion configs/smplgait/smplgait.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ trainer_cfg:
enable_float16: true # half-precision float for memory reduction and speedup
fix_BN: false
log_iter: 100
with_test: 10000
with_test: true
restore_ckpt_strict: true
restore_hint: 0
save_iter: 10000
Expand Down
16 changes: 11 additions & 5 deletions opengait/modeling/loss_aggregator.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
"""The loss aggregator."""

import torch
import torch.nn as nn
from . import losses
from utils import is_dict, get_attr_from, get_valid_args, is_tensor, get_ddp_module
from utils import Odict
from utils import get_msg_mgr


class LossAggregator():
class LossAggregator(nn.Module):
"""The loss aggregator.
This class is used to aggregate the losses.
Expand All @@ -18,16 +19,21 @@ class LossAggregator():
Attributes:
losses: A dict of losses.
"""

def __init__(self, loss_cfg) -> None:
"""
Initialize the loss aggregator.
LossAggregator can be indexed like a regular Python dictionary,
but modules it contains are properly registered, and will be visible by all Module methods.
All parameters registered in losses can be accessed by the method 'self.parameters()',
thus they can be trained properly.
Args:
loss_cfg: Config of losses. List for multiple losses.
"""
self.losses = {loss_cfg['log_prefix']: self._build_loss_(loss_cfg)} if is_dict(loss_cfg) \
else {cfg['log_prefix']: self._build_loss_(cfg) for cfg in loss_cfg}
super().__init__()
self.losses = nn.ModuleDict({loss_cfg['log_prefix']: self._build_loss_(loss_cfg)} if is_dict(loss_cfg) \
else {cfg['log_prefix']: self._build_loss_(cfg) for cfg in loss_cfg})

def _build_loss_(self, loss_cfg):
"""Build the losses from loss_cfg.
Expand All @@ -41,7 +47,7 @@ def _build_loss_(self, loss_cfg):
loss = get_ddp_module(Loss(**valid_loss_arg).cuda())
return loss

def __call__(self, training_feats):
def forward(self, training_feats):
"""Compute the sum of all losses.
The input is a dict of features. The key is the name of loss and the value is the feature and label. If the key is not in
Expand Down

0 comments on commit 6c05509

Please sign in to comment.