Skip to content

Commit

Permalink
Change EMA initialization
Browse files Browse the repository at this point in the history
  • Loading branch information
PhilippThoelke committed May 19, 2021
1 parent b8ed694 commit cd6bc5b
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions torchmdnet/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ def step(self, batch, loss_fn, stage):
loss_dy = loss_fn(deriv, batch.dy)

if stage in ['train', 'val'] and self.hparams.ema_alpha_dy < 1:
if self.ema[stage + '_dy'] is None:
self.ema[stage + '_dy'] = loss_dy.detach()
# apply exponential smoothing over batches to dy
loss_dy = self.hparams.ema_alpha_dy * loss_dy + (1 - self.hparams.ema_alpha_dy) * self.ema[stage + '_dy']
self.ema[stage + '_dy'] = loss_dy.detach()
Expand All @@ -84,6 +86,8 @@ def step(self, batch, loss_fn, stage):
loss_y = loss_fn(pred, batch.y)

if stage in ['train', 'val'] and self.hparams.ema_alpha_y < 1:
if self.ema[stage + '_y'] is None:
self.ema[stage + '_y'] = loss_y.detach()
# apply exponential smoothing over batches to y
loss_y = self.hparams.ema_alpha_y * loss_y + (1 - self.hparams.ema_alpha_y) * self.ema[stage + '_y']
self.ema[stage + '_y'] = loss_y.detach()
Expand Down Expand Up @@ -151,5 +155,5 @@ def _reset_losses_dict(self):
'train_dy': [], 'val_dy': [], 'test_dy': []}

def _reset_ema_dict(self):
self.ema = {'train_y': 0, 'val_y': 0,
'train_dy': 0, 'val_dy': 0}
self.ema = {'train_y': None, 'val_y': None,
'train_dy': None, 'val_dy': None}

0 comments on commit cd6bc5b

Please sign in to comment.