implement ema warmup from @crowsonkb (lucidrains#140)
nousr authored Jun 4, 2022
1 parent 22cc613 commit 64c2f9c
Showing 1 changed file with 36 additions and 3 deletions.
39 changes: 36 additions & 3 deletions dalle2_pytorch/trainer.py
@@ -175,12 +175,34 @@ def save_diffusion_model(save_path, model, optimizer, scaler, config, image_embe
 # exponential moving average wrapper
 
 class EMA(nn.Module):
+    """
+    Implements exponential moving average shadowing for your model.
+
+    Utilizes an inverse decay schedule to manage longer term training runs.
+    By adjusting the power, you can control how fast EMA will ramp up to your specified beta.
+
+    @crowsonkb's notes on EMA Warmup:
+
+        If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are
+        good values for models you plan to train for a million or more steps (reaches decay
+        factor 0.999 at 31.6K steps, 0.9999 at 1M steps), gamma=1, power=3/4 for models
+        you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 at
+        215.4k steps).
+
+    Args:
+        inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
+        power (float): Exponential factor of EMA warmup. Default: 1.
+        min_value (float): The minimum EMA decay rate. Default: 0.
+    """
     def __init__(
         self,
         model,
         beta = 0.9999,
-        update_after_step = 1000,
+        update_after_step = 10000,
         update_every = 10,
+        inv_gamma = 1.0,
+        power = 2/3,
+        min_value = 0.0,
     ):
         super().__init__()
         self.beta = beta
@@ -190,6 +212,10 @@ def __init__(
         self.update_every = update_every
         self.update_after_step = update_after_step
 
+        self.inv_gamma = inv_gamma
+        self.power = power
+        self.min_value = min_value
+
         self.register_buffer('initted', torch.Tensor([False]))
         self.register_buffer('step', torch.tensor([0]))
 
@@ -201,6 +227,11 @@ def copy_params_from_model_to_ema(self):
         for ma_param, current_param in zip(list(self.ema_model.parameters()), list(self.online_model.parameters())):
             ma_param.data.copy_(current_param.data)
 
+    def get_current_decay(self):
+        epoch = max(0, self.step.item() - self.update_after_step - 1)
+        value = 1 - (1 + epoch / self.inv_gamma) ** - self.power
+        return 0. if epoch < 0 else min(self.beta, max(self.min_value, value))
+
     def update(self):
         step = self.step.item()
         self.step += 1
@@ -220,14 +251,16 @@ def update(self):
 
     @torch.no_grad()
     def update_moving_average(self, ma_model, current_model):
+        current_decay = self.get_current_decay()
+
         for current_params, ma_params in zip(list(current_model.parameters()), list(ma_model.parameters())):
             difference = ma_params.data - current_params.data
-            difference.mul_(1.0 - self.beta)
+            difference.mul_(1.0 - current_decay)
             ma_params.sub_(difference)
 
         for current_buffer, ma_buffer in zip(list(current_model.buffers()), list(ma_model.buffers())):
             difference = ma_buffer - current_buffer
-            difference.mul_(1.0 - self.beta)
+            difference.mul_(1.0 - current_decay)
             ma_buffer.sub_(difference)
 
     def __call__(self, *args, **kwargs):
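A quick standalone check of the warmup numbers quoted in the docstring: the helper below simply re-states the formula from get_current_decay (ignoring the update_after_step offset), so it is a sketch for inspection, not part of the trainer's API. With inv_gamma=1 and power=2/3 the decay should sit near 0.99 at 1K steps, pass 0.999 near 31.6K steps, and hit the beta cap of 0.9999 around 1M steps.

def warmup_decay(step, inv_gamma = 1.0, power = 2/3, min_value = 0.0, beta = 0.9999):
    # same schedule as get_current_decay, without the update_after_step offset
    value = 1 - (1 + step / inv_gamma) ** -power
    return min(beta, max(min_value, value))

for step in (1_000, 31_623, 1_000_000):
    print(step, round(warmup_decay(step), 4))   # ~0.99, ~0.999, 0.9999 (capped by beta)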

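The in-place arithmetic in update_moving_average is the standard EMA lerp, just written so the shadow weights are updated without allocating a fresh result tensor: subtracting (1 - decay) * (ma - current) from ma yields decay * ma + (1 - decay) * current. A minimal sketch of that equivalence, with arbitrary example tensors:

import torch

decay = 0.999
ma = torch.randn(4)                      # stand-in EMA (shadow) weights
current = torch.randn(4)                 # stand-in online weights

expected = decay * ma + (1 - decay) * current   # textbook EMA update

difference = ma - current
difference.mul_(1.0 - decay)
updated = ma - difference                # the commit does this in-place via ma_params.sub_(difference)

assert torch.allclose(updated, expected)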
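For context, a hypothetical training-loop sketch of how the wrapper might be driven. The model, optimizer, batch, and step count are placeholders; only the EMA keyword arguments come from this diff, and the sketch assumes update() is intended to be called once per optimizer step and that the wrapper keeps its own copy of the model internally.

import torch
from torch import nn
from dalle2_pytorch.trainer import EMA

model = nn.Linear(16, 16)                # placeholder online model
ema = EMA(
    model,
    beta = 0.9999,
    update_after_step = 10000,           # shadowing only starts after this many steps
    update_every = 10,
    inv_gamma = 1.0,
    power = 2/3
)

opt = torch.optim.Adam(model.parameters(), lr = 3e-4)
num_train_steps = 100                    # placeholder

for _ in range(num_train_steps):
    x = torch.randn(8, 16)               # stand-in batch
    loss = (model(x) - x).pow(2).mean()
    loss.backward()
    opt.step()
    opt.zero_grad()
    ema.update()                         # advances the step counter and warmup schedule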