Fix EMA for multi-gpu training in the unconditional example (huggingface#1930)

* improve EMA

* style

* one EMA model

* quality

* fix tests

* fix test

* Apply suggestions from code review

Co-authored-by: Pedro Cuenca <[email protected]>

* reorganise the unconditional script

* backwards compatibility

* default to init values for some args

* fix ort script

* issubclass => isinstance

* update state_dict

* docstr

* doc

* Apply suggestions from code review

Co-authored-by: Pedro Cuenca <[email protected]>

* use .to if device is passed

* deprecate device

* make flake happy

* fix typo

Co-authored-by: patil-suraj <[email protected]>
Co-authored-by: Pedro Cuenca <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>
4 people authored Jan 19, 2023
1 parent f354dd9 commit 7c82a16
Showing 6 changed files with 314 additions and 216 deletions.
113 changes: 2 additions & 111 deletions examples/text_to_image/train_text_to_image.py
@@ -14,13 +14,12 @@
# See the License for the specific language governing permissions and

import argparse
-import copy
import logging
import math
import os
import random
from pathlib import Path
-from typing import Iterable, Optional
+from typing import Optional

import numpy as np
import torch
@@ -36,6 +35,7 @@
from datasets import load_dataset
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from diffusers.optimization import get_scheduler
+from diffusers.training_utils import EMAModel
from diffusers.utils import check_min_version
from diffusers.utils.import_utils import is_xformers_available
from huggingface_hub import HfFolder, Repository, whoami
@@ -305,115 +305,6 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token:
}


-# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
-class EMAModel:
-    """
-    Exponential Moving Average of models weights
-    """
-
-    def __init__(self, parameters: Iterable[torch.nn.Parameter], decay=0.9999):
-        parameters = list(parameters)
-        self.shadow_params = [p.clone().detach() for p in parameters]
-
-        self.collected_params = None
-
-        self.decay = decay
-        self.optimization_step = 0
-
-    @torch.no_grad()
-    def step(self, parameters):
-        parameters = list(parameters)
-
-        self.optimization_step += 1
-
-        # Compute the decay factor for the exponential moving average.
-        value = (1 + self.optimization_step) / (10 + self.optimization_step)
-        one_minus_decay = 1 - min(self.decay, value)
-
-        for s_param, param in zip(self.shadow_params, parameters):
-            if param.requires_grad:
-                s_param.sub_(one_minus_decay * (s_param - param))
-            else:
-                s_param.copy_(param)
-
-        torch.cuda.empty_cache()
-
-    def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
-        """
-        Copy current averaged parameters into given collection of parameters.
-        Args:
-            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
-                updated with the stored moving averages. If `None`, the
-                parameters with which this `ExponentialMovingAverage` was
-                initialized will be used.
-        """
-        parameters = list(parameters)
-        for s_param, param in zip(self.shadow_params, parameters):
-            param.data.copy_(s_param.data)
-
-    def to(self, device=None, dtype=None) -> None:
-        r"""Move internal buffers of the ExponentialMovingAverage to `device`.
-        Args:
-            device: like `device` argument to `torch.Tensor.to`
-        """
-        # .to() on the tensors handles None correctly
-        self.shadow_params = [
-            p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device)
-            for p in self.shadow_params
-        ]
-
-    def state_dict(self) -> dict:
-        r"""
-        Returns the state of the ExponentialMovingAverage as a dict.
-        This method is used by accelerate during checkpointing to save the ema state dict.
-        """
-        # Following PyTorch conventions, references to tensors are returned:
-        # "returns a reference to the state and not its copy!" -
-        # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict
-        return {
-            "decay": self.decay,
-            "optimization_step": self.optimization_step,
-            "shadow_params": self.shadow_params,
-            "collected_params": self.collected_params,
-        }
-
-    def load_state_dict(self, state_dict: dict) -> None:
-        r"""
-        Loads the ExponentialMovingAverage state.
-        This method is used by accelerate during checkpointing to save the ema state dict.
-        Args:
-            state_dict (dict): EMA state. Should be an object returned
-                from a call to :meth:`state_dict`.
-        """
-        # deepcopy, to be consistent with module API
-        state_dict = copy.deepcopy(state_dict)
-
-        self.decay = state_dict["decay"]
-        if self.decay < 0.0 or self.decay > 1.0:
-            raise ValueError("Decay must be between 0 and 1")
-
-        self.optimization_step = state_dict["optimization_step"]
-        if not isinstance(self.optimization_step, int):
-            raise ValueError("Invalid optimization_step")
-
-        self.shadow_params = state_dict["shadow_params"]
-        if not isinstance(self.shadow_params, list):
-            raise ValueError("shadow_params must be a list")
-        if not all(isinstance(p, torch.Tensor) for p in self.shadow_params):
-            raise ValueError("shadow_params must all be Tensors")
-
-        self.collected_params = state_dict["collected_params"]
-        if self.collected_params is not None:
-            if not isinstance(self.collected_params, list):
-                raise ValueError("collected_params must be a list")
-            if not all(isinstance(p, torch.Tensor) for p in self.collected_params):
-                raise ValueError("collected_params must all be Tensors")
-            if len(self.collected_params) != len(self.shadow_params):
-                raise ValueError("collected_params and shadow_params must have the same length")


def main():
    args = parse_args()
    logging_dir = os.path.join(args.output_dir, args.logging_dir)
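Note: the warm-up schedule in the removed step() above is worth spelling out. The effective decay is min(self.decay, (1 + t) / (10 + t)) with t the optimization step, so the shadow weights track the live weights almost directly early in training and only settle into long-horizon averaging once the 0.9999 cap is reached. A quick illustration (plain Python, not part of the diff):

# Effective EMA decay under the warm-up schedule from step() above.
decay = 0.9999
for t in (1, 10, 100, 1000, 100000):
    value = (1 + t) / (10 + t)
    print(t, min(decay, value))
# t=1      -> 0.1818...  shadow weights follow the model closely at first
# t=100    -> 0.9181...
# t=100000 -> 0.9999     cap reached: long-horizon averaging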
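For orientation, here is a minimal sketch of how a training loop can drive an EMA helper with the interface shown above (step / copy_to / to / state_dict). It is written against the shared diffusers.training_utils.EMAModel that this commit switches to; the shared class may accept constructor arguments beyond the parameters-and-decay pair used here, so treat the exact signatures as assumptions:

import torch
from diffusers.training_utils import EMAModel  # the shared helper this commit adopts

model = torch.nn.Linear(4, 4)  # stand-in for the UNet in the real script
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
ema = EMAModel(model.parameters(), decay=0.9999)  # assumed: mirrors the removed local class
# ema.to(device="cpu")  # the to() interface above can park shadow weights off-GPU

for _ in range(100):
    loss = model(torch.randn(8, 4)).pow(2).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    ema.step(model.parameters())  # refresh shadow weights after each optimizer step

ema.copy_to(model.parameters())  # write the averaged weights into the live model
state = ema.state_dict()         # accelerate uses these hooks when checkpointing
ema.load_state_dict(state)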
3 changes: 3 additions & 0 deletions examples/unconditional_image_generation/README.md
@@ -39,6 +39,7 @@ accelerate launch train_unconditional.py \
--train_batch_size=16 \
--num_epochs=100 \
--gradient_accumulation_steps=1 \
+--use_ema \
--learning_rate=1e-4 \
--lr_warmup_steps=500 \
--mixed_precision=no \
@@ -63,6 +64,7 @@ accelerate launch train_unconditional.py \
--train_batch_size=16 \
--num_epochs=100 \
--gradient_accumulation_steps=1 \
+--use_ema \
--learning_rate=1e-4 \
--lr_warmup_steps=500 \
--mixed_precision=no \
@@ -150,6 +152,7 @@ accelerate launch train_unconditional_ort.py \
--dataset_name="huggan/flowers-102-categories" \
--resolution=64 \
--output_dir="ddpm-ema-flowers-64" \
+--use_ema \
--train_batch_size=16 \
--num_epochs=1 \
--gradient_accumulation_steps=1 \
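Since the fix targets EMA under multi-GPU training, it may help to recall how these scripts are typically launched on several GPUs (the third command above is truncated in this diff view). The --multi_gpu and --num_processes flags below are standard accelerate launcher options, not part of this diff, and the script flag set is abridged:

accelerate launch --multi_gpu --num_processes=2 train_unconditional.py \
  --dataset_name="huggan/flowers-102-categories" \
  --resolution=64 \
  --use_ema \
  --output_dir="ddpm-ema-flowers-64"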
(The remaining 4 changed files are not shown.)
