Skip to content

Commit

Permalink
Class-conditioned generation + img2img pipeline
Browse files Browse the repository at this point in the history
* replace text– with class–conditioning
* initial code for basic img2img method
* save initial pipeline once for all experiments
* gradient norm logging
* various fixes & improvements
  • Loading branch information
Thomas Boyer committed Jun 22, 2023
1 parent dbaf2dd commit 809e7bd
Show file tree
Hide file tree
Showing 7 changed files with 981 additions and 194 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ run_script.zsh
data/
accelerate_configs
.ipynb_checkpoints/
.initial_pipeline_save
159 changes: 88 additions & 71 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,15 @@
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration
from diffusers import (
AutoencoderKL,
DDIMScheduler,
StableDiffusionPipeline,
UNet2DConditionModel,
)
from diffusers import AutoencoderKL, DDIMScheduler, UNet2DConditionModel
from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel
from transformers import CLIPTextModel, CLIPTokenizer

import wandb
from src.args_parser import parse_args
from src.custom_pipeline_stable_diffusion_img2img import (
CustomStableDiffusionImg2ImgPipeline,
)
from src.utils_dataset import setup_dataset
from src.utils_misc import (
args_checker,
Expand All @@ -26,7 +23,7 @@
setup_xformers_memory_efficient_attention,
)
from src.utils_training import (
checkpoint_model,
save_model,
generate_samples_and_compute_metrics,
get_training_setup,
perform_training_epoch,
Expand All @@ -38,7 +35,7 @@

def main(args):
# ------------------------- Checks -------------------------
args_checker(args)
args_checker(args, logger)

# ----------------------- Accelerator ----------------------
accelerator_project_config = ProjectConfiguration(
Expand Down Expand Up @@ -68,7 +65,7 @@ def main(args):
# Make one log on every process with the configuration for debugging.
setup_logger(logger, accelerator)

# ------------------- Repository scruture ------------------
# ------------------ Repository Structure ------------------
(
image_generation_tmp_save_folder,
initial_pipeline_save_folder,
Expand All @@ -91,12 +88,9 @@ def main(args):
# Note that the actual folder to pull components from is
# initial_pipeline_save_folder/snapshots/<gibberish>/ (probably a hash?)
# hence the need to get the *true* save folder (initial_pipeline_save_path)
initial_pipeline_save_path = StableDiffusionPipeline.download(
initial_pipeline_save_path = CustomStableDiffusionImg2ImgPipeline.download(
args.pretrained_model_name_or_path,
cache_dir=initial_pipeline_save_folder,
# override some useless components
safety_checker=None,
feature_extractor=None,
)

# Load the pretrained components
Expand All @@ -106,36 +100,48 @@ def main(args):
local_files_only=True,
)
if args.learn_denoiser_from_scratch:
denoiser_model: UNet2DConditionModel = UNet2DConditionModel.from_config(
denoiser_model_config = UNet2DConditionModel.load_config(
Path(initial_pipeline_save_path, "unet", "config.json"),
)
denoiser_model: UNet2DConditionModel = UNet2DConditionModel.from_config(
denoiser_model_config,
)
else:
denoiser_model: UNet2DConditionModel = UNet2DConditionModel.from_pretrained(
initial_pipeline_save_path,
subfolder="unet",
local_files_only=True,
)
text_encoder: CLIPTextModel = CLIPTextModel.from_pretrained(
initial_pipeline_save_path,
subfolder="text_encoder",
local_files_only=True,
)
tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(
initial_pipeline_save_path,
subfolder="tokenizer",
local_files_only=True,
)

denoiser_model.time_embed_dim = 1024 # TODO: fix this hack

# ------------------ Custom Class Embeddings ---------------
match args.class_embedding_type:
case "one_hot":
raise NotImplementedError(
"Dimensions will mismatch with one-hot encoding; TODO: fix"
)
# class_embedding = torch.nn.functional.one_hot(torch.arange(nb_classes))
# ..?
# class_embedding.to(accelerator.device)
case "embedding":
class_embedding = torch.nn.Embedding(nb_classes, args.class_embedding_dim)
class_embedding.to(accelerator.device)
class_embedding.dtype = torch.float32 # hard fix for accelerate(?!)
case _:
raise ValueError(
f"Unrecognized class embedding type: {args.class_embedding_type}"
)

# ---------------- Move & Freeze Components ----------------
# Move components to device
autoencoder_model.to(accelerator.device)
denoiser_model.to(accelerator.device)
text_encoder.to(accelerator.device)

# ❄️ >>> Freeze components <<< ❄️
autoencoder_model.requires_grad_(False)
text_encoder.requires_grad_(False)

# --------------------- Noise scheduler --------------------
# --------------------- Noise Scheduler --------------------
noise_scheduler = DDIMScheduler(
num_train_timesteps=args.num_training_steps,
beta_start=args.beta_start,
Expand Down Expand Up @@ -174,20 +180,32 @@ def main(args):
eps=args.adam_epsilon,
)

# ----------------- Learning rate scheduler -----------------
# ----------------- Learning Rate Scheduler -----------------
lr_scheduler = get_scheduler(
args.lr_scheduler,
optimizer=optimizer,
num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
num_training_steps=(len(train_dataloader) * args.num_epochs),
)

# ------------------ Distributed compute ------------------
denoiser_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
denoiser_model, optimizer, train_dataloader, lr_scheduler
# ------------------ Distributed Compute ------------------
(
denoiser_model,
optimizer,
train_dataloader,
lr_scheduler,
class_embedding,
autoencoder_model,
) = accelerator.prepare(
denoiser_model,
optimizer,
train_dataloader,
lr_scheduler,
class_embedding,
autoencoder_model,
)

# --------------------- Training setup ---------------------
# --------------------- Training Setup ---------------------
if args.use_ema:
ema_unet.to(accelerator.device)

Expand All @@ -207,7 +225,7 @@ def main(args):
actual_eval_batch_sizes_for_this_process,
) = get_training_setup(args, accelerator, train_dataloader, logger, dataset)

# ----------------- Resume from checkpoint -----------------
# ----------------- Resume from Checkpoint -----------------
if args.resume_from_checkpoint:
first_epoch, resume_step, global_step = resume_from_checkpoint(
args, logger, accelerator, num_update_steps_per_epoch, global_step
Expand All @@ -220,55 +238,53 @@ def main(args):
for epoch in range(first_epoch, args.num_epochs):
# Training epoch
global_step = perform_training_epoch(
denoiser_model,
autoencoder_model,
tokenizer,
text_encoder,
num_update_steps_per_epoch,
accelerator,
epoch,
train_dataloader,
args,
first_epoch,
resume_step,
noise_scheduler,
global_step,
optimizer,
lr_scheduler,
ema_unet,
logger,
rng,
denoiser_model=denoiser_model,
autoencoder_model=autoencoder_model,
num_update_steps_per_epoch=num_update_steps_per_epoch,
accelerator=accelerator,
epoch=epoch,
train_dataloader=train_dataloader,
args=args,
first_epoch=first_epoch,
resume_step=resume_step,
noise_scheduler=noise_scheduler,
global_step=global_step,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
ema_model=ema_unet,
logger=logger,
class_embedding=class_embedding,
)

# Generate sample images for visual inspection & metrics computation
if (
epoch % args.save_images_epochs == 0 or epoch == args.num_epochs - 1
) and epoch > 0:
if epoch % args.generate_images_epochs == 0:
generate_samples_and_compute_metrics(
args,
accelerator,
denoiser_model,
ema_unet,
autoencoder_model,
text_encoder,
tokenizer,
noise_scheduler,
image_generation_tmp_save_folder,
actual_eval_batch_sizes_for_this_process,
epoch,
global_step,
args=args,
accelerator=accelerator,
denoiser_model=denoiser_model,
class_embedding=class_embedding,
ema_model=ema_unet,
noise_scheduler=noise_scheduler,
image_generation_tmp_save_folder=image_generation_tmp_save_folder,
actual_eval_batch_sizes_for_this_process=actual_eval_batch_sizes_for_this_process,
epoch=epoch,
global_step=global_step,
initial_pipeline_save_path=initial_pipeline_save_path,
nb_classes=nb_classes,
logger=logger,
dataset=dataset,
)

if (
accelerator.is_main_process
and (epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1)
and epoch % args.save_model_epochs == 0
and epoch != 0
):
checkpoint_model(
save_model(
accelerator,
denoiser_model,
autoencoder_model,
text_encoder,
tokenizer,
class_embedding,
args,
ema_unet,
noise_scheduler,
Expand All @@ -279,6 +295,7 @@ def main(args):
)

# do not start new epoch before generation & checkpointing is done
# (note that checkpointing may start a BG process?)
accelerator.wait_for_everyone()

accelerator.end_training()
Expand Down
2 changes: 1 addition & 1 deletion src/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
setup_xformers_memory_efficient_attention,
)
from .utils_training import (
checkpoint_model,
save_model,
generate_samples_and_compute_metrics,
get_training_setup,
perform_training_epoch,
Expand Down
38 changes: 36 additions & 2 deletions src/args_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,16 +119,25 @@ def parse_args():
)
parser.add_argument("--num_epochs", type=int, default=100)
parser.add_argument(
"--save_images_epochs",
"--generate_images_epochs",
type=int,
default=100,
help="How often to save images during training.",
)
parser.add_argument("--compute_fid", action="store_true")
parser.add_argument("--compute_isc", action="store_true")
parser.add_argument("--compute_kid", action="store_true")
help_msg = "How many images to generate (per class) for metrics computation. "
help_msg += (
"Only a fraction of the first batch will be logged; the rest will be lost."
)
parser.add_argument("--nb_generated_images", type=int, default=1000, help=help_msg)
parser.add_argument(
"--kid_subset_size",
type=int,
default=1000,
help="Change this if generating very few images (for testing purposes only)",
)
parser.add_argument(
"--save_model_epochs",
type=int,
Expand All @@ -138,8 +147,33 @@ def parse_args():
parser.add_argument(
"--guidance_factor",
type=float,
help="The scaling factor of the guidance ('ω' in the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf; *not* the same definition that in the Classifier-Free Diffusion Guidance paper!). Set to 1 to disable guidance.",
help="The scaling factor of the guidance ('ω' in the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf; *not* the same definition that in the Classifier-Free Diffusion Guidance paper!). Set to <= 1 to disable guidance.",
)
parser.add_argument(
"--proba_uncond",
type=float,
default=0.1,
help="The probability of sampling unconditionally instead of conditionally for the CLF.",
)
parser.add_argument(
"--class_embedding_type",
type=str,
default="embedding",
choices=["OHE", "embedding"],
help="The kind of class embedding to use.",
)
parser.add_argument(
"--class_embedding_dim",
type=int,
default=1024,
help="The dimension of the class embedding.",
)
# TODO: To be used if testing img2img while training
# parser.add_argument(
# "--denoising_starting_point",
# type=float,
# help="The starting point of the denoising schedule (between 0 and 1).",
# )
parser.add_argument(
"--gradient_accumulation_steps",
type=int,
Expand Down
Loading

0 comments on commit 809e7bd

Please sign in to comment.