Skip to content

Commit

Permalink
[Examples] add compute_snr() to training utils. (huggingface#5188)
Browse files Browse the repository at this point in the history
add compute_snr() to training utils.
  • Loading branch information
sayakpaul authored Sep 27, 2023
1 parent ba59e92 commit cdcc01b
Show file tree
Hide file tree
Showing 11 changed files with 46 additions and 256 deletions.
27 changes: 2 additions & 25 deletions examples/dreambooth/train_dreambooth.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
UNet2DConditionModel,
)
from diffusers.optimization import get_scheduler
from diffusers.training_utils import compute_snr
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available

Expand Down Expand Up @@ -224,30 +225,6 @@ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: st
raise ValueError(f"{model_class} is not supported.")


def compute_snr(timesteps, noise_scheduler):
    """
    Compute the signal-to-noise ratio (SNR) for each entry of `timesteps`.

    Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849

    Args:
        timesteps: tensor used to index `noise_scheduler.alphas_cumprod`
            (presumably integer diffusion steps -- confirm against the caller).
        noise_scheduler: object exposing an `alphas_cumprod` tensor
            (assumed to be 1-D, one entry per diffusion step).

    Returns:
        A float32 tensor with one SNR value per timestep, i.e.
        `alphas_cumprod[t] / (1 - alphas_cumprod[t])` evaluated as `(alpha / sigma) ** 2`.
    """
    alphas_cumprod = noise_scheduler.alphas_cumprod
    device = timesteps.device

    # Indexing with `timesteps` already yields one value per timestep, so the
    # result needs no additional unsqueeze/expand to line up with `timesteps`.
    alpha = (alphas_cumprod**0.5).to(device=device)[timesteps].float()
    sigma = ((1.0 - alphas_cumprod) ** 0.5).to(device=device)[timesteps].float()

    # SNR = alpha^2 / sigma^2
    return (alpha / sigma) ** 2


def parse_args(input_args=None):
parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser.add_argument(
Expand Down Expand Up @@ -1302,7 +1279,7 @@ def compute_text_embeddings(prompt):
# Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
# This is discussed in Section 4.2 of the same paper.
snr = compute_snr(timesteps, noise_scheduler)
snr = compute_snr(noise_scheduler, timesteps)
base_weight = (
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
)
Expand Down
28 changes: 2 additions & 26 deletions examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
import diffusers
from diffusers import AutoPipelineForText2Image, DDPMScheduler, UNet2DConditionModel, VQModel
from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel
from diffusers.training_utils import EMAModel, compute_snr
from diffusers.utils import check_min_version, is_wandb_available, make_image_grid
from diffusers.utils.import_utils import is_xformers_available

Expand Down Expand Up @@ -530,30 +530,6 @@ def deepspeed_zero_init_disabled_context_manager():
else:
raise ValueError("xformers is not available. Make sure it is installed correctly")

def compute_snr(timesteps):
    """
    Compute the signal-to-noise ratio (SNR) for each entry of `timesteps`.

    Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849

    Args:
        timesteps: tensor used to index `noise_scheduler.alphas_cumprod`
            (presumably integer diffusion steps -- confirm against the caller).

    Returns:
        A float32 tensor with one SNR value per timestep, evaluated as `(alpha / sigma) ** 2`.

    Note:
        `noise_scheduler` is not a parameter; it is read from the surrounding scope
        (its `alphas_cumprod` is assumed to be a 1-D tensor).
    """
    alphas_cumprod = noise_scheduler.alphas_cumprod
    device = timesteps.device

    # Indexing with `timesteps` already yields one value per timestep, so the
    # result needs no additional unsqueeze/expand to line up with `timesteps`.
    alpha = (alphas_cumprod**0.5).to(device=device)[timesteps].float()
    sigma = ((1.0 - alphas_cumprod) ** 0.5).to(device=device)[timesteps].float()

    # Compute SNR.
    return (alpha / sigma) ** 2

# `accelerate` 0.16.0 will have better support for customized saving
if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
Expand Down Expand Up @@ -800,7 +776,7 @@ def collate_fn(examples):
# Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
# This is discussed in Section 4.2 of the same paper.
snr = compute_snr(timesteps)
snr = compute_snr(noise_scheduler, timesteps)
base_weight = (
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from diffusers.loaders import AttnProcsLayers
from diffusers.models.attention_processor import LoRAAttnAddedKVProcessor
from diffusers.optimization import get_scheduler
from diffusers.training_utils import compute_snr
from diffusers.utils import check_min_version, is_wandb_available


Expand Down Expand Up @@ -419,30 +420,6 @@ def main():

unet.set_attn_processor(lora_attn_procs)

def compute_snr(timesteps):
    """
    Compute the signal-to-noise ratio (SNR) for each entry of `timesteps`.

    Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849

    Args:
        timesteps: tensor used to index `noise_scheduler.alphas_cumprod`
            (presumably integer diffusion steps -- confirm against the caller).

    Returns:
        A float32 tensor with one SNR value per timestep, evaluated as `(alpha / sigma) ** 2`.

    Note:
        `noise_scheduler` is not a parameter; it is read from the surrounding scope
        (its `alphas_cumprod` is assumed to be a 1-D tensor).
    """
    alphas_cumprod = noise_scheduler.alphas_cumprod
    device = timesteps.device

    # Indexing with `timesteps` already yields one value per timestep, so the
    # result needs no additional unsqueeze/expand to line up with `timesteps`.
    alpha = (alphas_cumprod**0.5).to(device=device)[timesteps].float()
    sigma = ((1.0 - alphas_cumprod) ** 0.5).to(device=device)[timesteps].float()

    # Compute SNR.
    return (alpha / sigma) ** 2

lora_layers = AttnProcsLayers(unet.attn_processors)

if args.allow_tf32:
Expand Down Expand Up @@ -653,7 +630,7 @@ def collate_fn(examples):
# Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
# This is discussed in Section 4.2 of the same paper.
snr = compute_snr(timesteps)
snr = compute_snr(noise_scheduler, timesteps)
base_weight = (
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from diffusers.loaders import AttnProcsLayers
from diffusers.models.attention_processor import LoRAAttnProcessor
from diffusers.optimization import get_scheduler
from diffusers.training_utils import compute_snr
from diffusers.utils import check_min_version, is_wandb_available


Expand Down Expand Up @@ -413,31 +414,6 @@ def main():
lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=2048, rank=args.rank)

prior.set_attn_processor(lora_attn_procs)

def compute_snr(timesteps):
    """
    Compute the signal-to-noise ratio (SNR) for each entry of `timesteps`.

    Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849

    Args:
        timesteps: tensor used to index `noise_scheduler.alphas_cumprod`
            (presumably integer diffusion steps -- confirm against the caller).

    Returns:
        A float32 tensor with one SNR value per timestep, evaluated as `(alpha / sigma) ** 2`.

    Note:
        `noise_scheduler` is not a parameter; it is read from the surrounding scope
        (its `alphas_cumprod` is assumed to be a 1-D tensor).
    """
    alphas_cumprod = noise_scheduler.alphas_cumprod
    device = timesteps.device

    # Indexing with `timesteps` already yields one value per timestep, so the
    # result needs no additional unsqueeze/expand to line up with `timesteps`.
    alpha = (alphas_cumprod**0.5).to(device=device)[timesteps].float()
    sigma = ((1.0 - alphas_cumprod) ** 0.5).to(device=device)[timesteps].float()

    # Compute SNR.
    return (alpha / sigma) ** 2

lora_layers = AttnProcsLayers(prior.attn_processors)

if args.allow_tf32:
Expand Down Expand Up @@ -684,7 +660,7 @@ def collate_fn(examples):
# Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
# This is discussed in Section 4.2 of the same paper.
snr = compute_snr(timesteps)
snr = compute_snr(noise_scheduler, timesteps)
base_weight = (
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
)
Expand Down
28 changes: 2 additions & 26 deletions examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
import diffusers
from diffusers import AutoPipelineForText2Image, DDPMScheduler, PriorTransformer
from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel
from diffusers.training_utils import EMAModel, compute_snr
from diffusers.utils import check_min_version, is_wandb_available, make_image_grid


Expand Down Expand Up @@ -523,30 +523,6 @@ def deepspeed_zero_init_disabled_context_manager():
ema_prior = EMAModel(ema_prior.parameters(), model_cls=PriorTransformer, model_config=ema_prior.config)
ema_prior.to(accelerator.device)

def compute_snr(timesteps):
    """
    Compute the signal-to-noise ratio (SNR) for each entry of `timesteps`.

    Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849

    Args:
        timesteps: tensor used to index `noise_scheduler.alphas_cumprod`
            (presumably integer diffusion steps -- confirm against the caller).

    Returns:
        A float32 tensor with one SNR value per timestep, evaluated as `(alpha / sigma) ** 2`.

    Note:
        `noise_scheduler` is not a parameter; it is read from the surrounding scope
        (its `alphas_cumprod` is assumed to be a 1-D tensor).
    """
    alphas_cumprod = noise_scheduler.alphas_cumprod
    device = timesteps.device

    # Indexing with `timesteps` already yields one value per timestep, so the
    # result needs no additional unsqueeze/expand to line up with `timesteps`.
    alpha = (alphas_cumprod**0.5).to(device=device)[timesteps].float()
    sigma = ((1.0 - alphas_cumprod) ** 0.5).to(device=device)[timesteps].float()

    # Compute SNR.
    return (alpha / sigma) ** 2

# `accelerate` 0.16.0 will have better support for customized saving
if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
Expand Down Expand Up @@ -832,7 +808,7 @@ def collate_fn(examples):
# Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
# This is discussed in Section 4.2 of the same paper.
snr = compute_snr(timesteps)
snr = compute_snr(noise_scheduler, timesteps)
base_weight = (
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
import diffusers
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel
from diffusers.training_utils import EMAModel, compute_snr
from diffusers.utils import check_min_version, deprecate, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available

Expand Down Expand Up @@ -524,30 +524,6 @@ def deepspeed_zero_init_disabled_context_manager():
else:
raise ValueError("xformers is not available. Make sure it is installed correctly")

def compute_snr(timesteps):
    """
    Compute the signal-to-noise ratio (SNR) for each entry of `timesteps`.

    Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849

    Args:
        timesteps: tensor used to index `noise_scheduler.alphas_cumprod`
            (presumably integer diffusion steps -- confirm against the caller).

    Returns:
        A float32 tensor with one SNR value per timestep, evaluated as `(alpha / sigma) ** 2`.

    Note:
        `noise_scheduler` is not a parameter; it is read from the surrounding scope
        (its `alphas_cumprod` is assumed to be a 1-D tensor).
    """
    alphas_cumprod = noise_scheduler.alphas_cumprod
    device = timesteps.device

    # Indexing with `timesteps` already yields one value per timestep, so the
    # result needs no additional unsqueeze/expand to line up with `timesteps`.
    alpha = (alphas_cumprod**0.5).to(device=device)[timesteps].float()
    sigma = ((1.0 - alphas_cumprod) ** 0.5).to(device=device)[timesteps].float()

    # Compute SNR.
    return (alpha / sigma) ** 2

# `accelerate` 0.16.0 will have better support for customized saving
if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
Expand Down Expand Up @@ -871,7 +847,7 @@ def collate_fn(examples):
# Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
# This is discussed in Section 4.2 of the same paper.
snr = compute_snr(timesteps)
snr = compute_snr(noise_scheduler, timesteps)
base_weight = (
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
)
Expand Down
28 changes: 2 additions & 26 deletions examples/text_to_image/train_text_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
import diffusers
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel
from diffusers.training_utils import EMAModel, compute_snr
from diffusers.utils import check_min_version, deprecate, is_wandb_available, make_image_grid
from diffusers.utils.import_utils import is_xformers_available

Expand Down Expand Up @@ -601,30 +601,6 @@ def deepspeed_zero_init_disabled_context_manager():
else:
raise ValueError("xformers is not available. Make sure it is installed correctly")

def compute_snr(timesteps):
    """
    Compute the signal-to-noise ratio (SNR) for each entry of `timesteps`.

    Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849

    Args:
        timesteps: tensor used to index `noise_scheduler.alphas_cumprod`
            (presumably integer diffusion steps -- confirm against the caller).

    Returns:
        A float32 tensor with one SNR value per timestep, evaluated as `(alpha / sigma) ** 2`.

    Note:
        `noise_scheduler` is not a parameter; it is read from the surrounding scope
        (its `alphas_cumprod` is assumed to be a 1-D tensor).
    """
    alphas_cumprod = noise_scheduler.alphas_cumprod
    device = timesteps.device

    # Indexing with `timesteps` already yields one value per timestep, so the
    # result needs no additional unsqueeze/expand to line up with `timesteps`.
    alpha = (alphas_cumprod**0.5).to(device=device)[timesteps].float()
    sigma = ((1.0 - alphas_cumprod) ** 0.5).to(device=device)[timesteps].float()

    # Compute SNR.
    return (alpha / sigma) ** 2

# `accelerate` 0.16.0 will have better support for customized saving
if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
Expand Down Expand Up @@ -951,7 +927,7 @@ def collate_fn(examples):
# Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
# This is discussed in Section 4.2 of the same paper.
snr = compute_snr(timesteps)
snr = compute_snr(noise_scheduler, timesteps)
base_weight = (
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
)
Expand Down
Loading

0 comments on commit cdcc01b

Please sign in to comment.