diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..9160aec --- /dev/null +++ b/.gitignore @@ -0,0 +1,4 @@ +# Ignore virtual environment files +venv/ +venv38/ +*.pyc diff --git a/README.md b/README.md new file mode 100644 index 0000000..0cb99e5 --- /dev/null +++ b/README.md @@ -0,0 +1,148 @@ +## simple diffusion: End-to-end diffusion for high resolution images +### Unofficial PyTorch Implementation + +**Simple diffusion: End-to-end diffusion for high resolution images** +[Emiel Hoogeboom](https://arxiv.org/search/cs?searchtype=author&query=Hoogeboom,+E), [Jonathan Heek](https://arxiv.org/search/cs?searchtype=author&query=Heek,+J), [Tim Salimans](https://arxiv.org/search/cs?searchtype=author&query=Salimans,+T) +https://arxiv.org/abs/2301.11093 + +### Requirements +* All testing and development was conducted on 4x 16GB NVIDIA V100 GPUs +* 64-bit Python 3.8 and PyTorch 2.1 (or later). See [https://pytorch.org](https://pytorch.org/) for PyTorch install instructions. + +For convenience, a `requirements.txt` file is included to install the required dependencies in an environment of your choice. + +### Usage + +The code for training a diffusion model is self-contained in the `simpleDiffusion` class. Set-up and preparation is included in the `train.py` file: + + from diffusion.unet import UNet2D + from diffusion.simple_diffusion import simpleDiffusion + + from datasets import load_dataset + from torchvision import transforms + import torch + from diffusers.optimization import get_cosine_schedule_with_warmup + + + class TrainingConfig: + image_size = 128 # the generated image resolution + train_batch_size = 4 + num_epochs = 100 + gradient_accumulation_steps = 1 + learning_rate = 5e-5 + lr_warmup_steps = 10000 + save_image_epochs = 100 + mixed_precision = "fp16" # `no` for float32, `fp16` for automatic mixed precision + output_dir = "ddpm-butterflies-128" # the model name locally and on the HF Hub + overwrite_output_dir = True # overwrite the old model when re-running the notebook + seed = 0 + + + def main(): + + config = TrainingConfig + + dataset_name = "huggan/smithsonian_butterflies_subset" + + dataset = load_dataset(dataset_name, split="train") + + preprocess = transforms.Compose( + [ + transforms.Resize((config.image_size, config.image_size)), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) + ] + ) + + def transform(examples): + images = [preprocess(image.convert("RGB")) for image in examples["image"]] + return {"images": images} + + dataset.set_transform(transform) + + train_loader = torch.utils.data.DataLoader( + dataset, + batch_size=config.train_batch_size, + shuffle=True, + ) + + unet = UNet2D( + sample_size=config.image_size, # the target image resolution + in_channels=3, # the number of input channels, 3 for RGB images + out_channels=3, # the number of output channels + layers_per_block=2, # how many ResNet layers to use per UNet block + block_out_channels=(128, 128, 256, 256, 512, 512), # the number of output channels for each UNet block + down_block_types=( + "DownBlock2D", + "DownBlock2D", + "DownBlock2D", + "DownBlock2D", + "AttnDownBlock2D", + "DownBlock2D", + ), + up_block_types=( + "UpBlock2D", + "AttnUpBlock2D", + "UpBlock2D", + "UpBlock2D", + "UpBlock2D", + "UpBlock2D", + ), + ) + + optimizer = torch.optim.Adam(unet.parameters(), lr=config.learning_rate) + lr_scheduler = get_cosine_schedule_with_warmup( + optimizer, + num_warmup_steps=config.lr_warmup_steps, + 
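+        # one scheduler step is taken per batch, so the total step count below is
+        # batches-per-epoch * epochs (no adjustment for gradient accumulation)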
num_training_steps=len(train_loader) * config.num_epochs, + ) + + diffusion_model = simpleDiffusion( + unet=unet, + image_size=config.image_size + ) + + diffusion_model.train_loop( + config=config, + optimizer=optimizer, + train_dataloader=train_loader, + lr_scheduler=lr_scheduler + ) + + if __name__ == '__main__': + main() +Multiple versions of the U-Net architecture are available (UNet2DModel, ADM), with U-ViT and others planning to be included in the future. + +### Multi-GPU Training +The `simpleDiffusion` class is equipped with HuggingFace's [Accelerator](https://huggingface.co/docs/accelerate/en/index) wrapper for distributed training. Multi-GPU training is easily done via: +`accelerate launch --multi-gpu train.py` + +### Citations + + @inproceedings{Hoogeboom2023simpleDE, + title = {simple diffusion: End-to-end diffusion for high resolution images}, + author = {Emiel Hoogeboom and Jonathan Heek and Tim Salimans}, + year = {2023} + } + + @InProceedings{pmlr-v139-nichol21a, + title = {Improved Denoising Diffusion Probabilistic Models}, + author = {Nichol, Alexander Quinn and Dhariwal, Prafulla}, + booktitle = {Proceedings of the 38th International Conference on Machine Learning}, + pages = {8162--8171}, + year = {2021}, + editor = {Meila, Marina and Zhang, Tong}, + volume = {139}, + series = {Proceedings of Machine Learning Research}, + month = {18--24 Jul}, + publisher = {PMLR}, + pdf = {http://proceedings.mlr.press/v139/nichol21a/nichol21a.pdf}, + url = {https://proceedings.mlr.press/v139/nichol21a.html} + } + + @inproceedings{Hang2023EfficientDT, + title = {Efficient Diffusion Training via Min-SNR Weighting Strategy}, + author = {Tiankai Hang and Shuyang Gu and Chen Li and Jianmin Bao and Dong Chen and Han Hu and Xin Geng and Baining Guo}, + year = {2023} + } diff --git a/diffusion/__init__.py b/diffusion/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/diffusion/fp16_util.py b/diffusion/fp16_util.py new file mode 100644 index 0000000..41586b2 --- /dev/null +++ b/diffusion/fp16_util.py @@ -0,0 +1,83 @@ +""" +Helpers to train with 16-bit precision. + +Reference: +Nichols, J., & Dhariwal, P. (2021). Improved Denoising Diffusion +Probabilistic Models. Retrieved from https://arxiv.org/abs/2102.09672 + +The code is adapted from the official implementation at: +https://github.com/openai/improved-diffusion/tree/main/improved_diffusion +""" + +import torch.nn as nn +from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors + + +def convert_module_to_f16(l): + """ + Convert primitive modules to float16. + """ + if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): + l.weight.data = l.weight.data.half() + l.bias.data = l.bias.data.half() + + +def convert_module_to_f32(l): + """ + Convert primitive modules to float32, undoing convert_module_to_f16(). + """ + if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): + l.weight.data = l.weight.data.float() + l.bias.data = l.bias.data.float() + + +def make_master_params(model_params): + """ + Copy model parameters into a (differently-shaped) list of full-precision + parameters. + """ + master_params = _flatten_dense_tensors( + [param.detach().float() for param in model_params] + ) + master_params = nn.Parameter(master_params) + master_params.requires_grad = True + return [master_params] + + +def model_grads_to_master_grads(model_params, master_params): + """ + Copy the gradients from the model parameters into the master parameters + from make_master_params(). 
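+
+    A sketch of the intended fp16 usage, combining the helpers in this module:
+    zero_grad(model_params); loss.backward(); model_grads_to_master_grads(
+    model_params, master_params); take the optimizer step on the fp32 master
+    params from make_master_params(); then master_params_to_model_params(...)
+    to copy the updated weights back into the half-precision model.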
+ """ + master_params[0].grad = _flatten_dense_tensors( + [param.grad.data.detach().float() for param in model_params] + ) + + +def master_params_to_model_params(model_params, master_params): + """ + Copy the master parameter data back into the model parameters. + """ + # Without copying to a list, if a generator is passed, this will + # silently not copy any parameters. + model_params = list(model_params) + + for param, master_param in zip( + model_params, unflatten_master_params(model_params, master_params) + ): + param.detach().copy_(master_param) + + +def unflatten_master_params(model_params, master_params): + """ + Unflatten the master parameters to look like model_params. + """ + return _unflatten_dense_tensors(master_params[0].detach(), tuple(tensor for tensor in model_params)) + + +def zero_grad(model_params): + for param in model_params: + # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group + if param.grad is not None: + param.grad.detach_() + param.grad.zero_() \ No newline at end of file diff --git a/diffusion/nn.py b/diffusion/nn.py new file mode 100644 index 0000000..0eed005 --- /dev/null +++ b/diffusion/nn.py @@ -0,0 +1,180 @@ +""" +Various utilities for neural networks. + +Reference: +Nichols, J., & Dhariwal, P. (2021). Improved Denoising Diffusion +Probabilistic Models. Retrieved from https://arxiv.org/abs/2102.09672 + +The code is adapted from the official implementation at: +https://github.com/openai/improved-diffusion/tree/main/improved_diffusion +""" + +import math + +import torch as th +import torch.nn as nn + + +# PyTorch 1.7 has SiLU, but we support PyTorch 1.5. +class SiLU(nn.Module): + def forward(self, x): + return x * th.sigmoid(x) + + +class GroupNorm32(nn.GroupNorm): + def forward(self, x): + return super().forward(x.float()).type(x.dtype) + + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def linear(*args, **kwargs): + """ + Create a linear module. + """ + return nn.Linear(*args, **kwargs) + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. + """ + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def update_ema(target_params, source_params, rate=0.99): + """ + Update target parameters to be closer to those of source parameters using + an exponential moving average. + + :param target_params: the target parameter sequence. + :param source_params: the source parameter sequence. + :param rate: the EMA rate (closer to 1 means slower). + """ + for targ, src in zip(target_params, source_params): + targ.detach().mul_(rate).add_(src, alpha=1 - rate) + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def scale_module(module, scale): + """ + Scale the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().mul_(scale) + return module + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. 
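+    For example, an [N x C x H x W] tensor is reduced to an [N]-shaped tensor
+    of per-sample means.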
+ """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def normalization(channels): + """ + Make a standard normalization layer. + + :param channels: number of input channels. + :return: an nn.Module for normalization. + """ + return GroupNorm32(32, channels) + + +def timestep_embedding(timesteps, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + if timesteps.dim() == 0: + timesteps = timesteps.unsqueeze(0) + + half = dim // 2 + freqs = th.exp( + -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half + ).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = th.cat([th.cos(args), th.sin(args)], dim=-1) + if dim % 2: + embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + +def checkpoint(func, inputs, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. + """ + if flag: + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +class CheckpointFunction(th.autograd.Function): + @staticmethod + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + with th.no_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + def backward(ctx, *output_grads): + ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] + with th.enable_grad(): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. + shallow_copies = [x.view_as(x) for x in ctx.input_tensors] + output_tensors = ctx.run_function(*shallow_copies) + input_grads = th.autograd.grad( + output_tensors, + ctx.input_tensors + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return (None, None) + input_grads \ No newline at end of file diff --git a/diffusion/simple_diffusion.py b/diffusion/simple_diffusion.py new file mode 100644 index 0000000..3c53b71 --- /dev/null +++ b/diffusion/simple_diffusion.py @@ -0,0 +1,336 @@ +""" +This file contains an implementation of the pseudocode from the paper +"Simple Diffusion: End-to-End Diffusion for High Resolution Images" +by Emiel Hoogeboom, Tim Salimans, and Jonathan Ho. + +Reference: +Hoogeboom, E., Salimans, T., & Ho, J. (2023). +Simple Diffusion: End-to-End Diffusion for High Resolution Images. 
+Retrieved from https://arxiv.org/abs/2301.11093 +""" + +import torch +import torch.nn as nn +from torch.special import expm1 +import math +from accelerate import Accelerator +import os +from tqdm import tqdm +from ema_pytorch import EMA +import matplotlib.pyplot as plt + +# helper +def log(t, eps = 1e-20): + return torch.log(t.clamp(min = eps)) + +class simpleDiffusion(nn.Module): + def __init__( + self, + unet, + image_size, + noise_size=64, + pred_param='v', + schedule='shifted_cosine', + steps=512 + ): + super().__init__() + + # Training objective + assert pred_param in ['v', 'eps'], "Invalid prediction parameterization. Must be 'v' or 'eps'" + self.pred_param = pred_param + + # Sampling schedule + assert schedule in ['cosine', 'shifted_cosine'], "Invalid schedule. Must be 'cosine' or 'shifted_cosine'" + self.schedule = schedule + self.noise_d = noise_size + self.image_d = image_size + + # Model + assert isinstance(unet, nn.Module), "Model must be an instance of torch.nn.Module." + self.model = unet + + num_params = sum(p.numel() for p in self.model.parameters()) + print(f"Number of parameters: {num_params}") + + # Steps + self.steps = steps + + def diffuse(self, x, alpha_t, sigma_t): + """ + Function to diffuse the input tensor x to a timepoint t with the given alpha_t and sigma_t. + + Args: + x (torch.Tensor): The input tensor to diffuse. + alpha_t (torch.Tensor): The alpha value at timepoint t. + sigma_t (torch.Tensor): The sigma value at timepoint t. + + Returns: + z_t (torch.Tensor): The diffused tensor at timepoint t. + eps_t (torch.Tensor): The noise tensor at timepoint t. + """ + eps_t = torch.randn_like(x) + + z_t = alpha_t * x + sigma_t * eps_t + + return z_t, eps_t + + def logsnr_schedule_cosine(self, t, logsnr_min=-15, logsnr_max=15): + """ + Function to compute the logSNR schedule at timepoint t with cosine: + + logSNR(t) = -2 * log (tan (pi * t / 2)) + + Taking into account boundary effects, the logSNR value at timepoint t is computed as: + + logsnr_t = -2 * log(tan(t_min + t * (t_max - t_min))) + + Args: + t (int): The timepoint t. + logsnr_min (int): The minimum logSNR value. + logsnr_max (int): The maximum logSNR value. + + Returns: + logsnr_t (float): The logSNR value at timepoint t. + """ + logsnr_max = logsnr_max + math.log(self.noise_d / self.image_d) + logsnr_min = logsnr_min + math.log(self.noise_d / self.image_d) + t_min = math.atan(math.exp(-0.5 * logsnr_max)) + t_max = math.atan(math.exp(-0.5 * logsnr_min)) + + logsnr_t = -2 * log(torch.tan(torch.tensor(t_min + t * (t_max - t_min)))) + + return logsnr_t + + def logsnr_schedule_cosine_shifted(self, t): + """ + Function to compute the logSNR schedule at timepoint t with shifted cosine: + + logSNR_shifted(t) = logSNR(t) + 2 * log(noise_d / image_d) + + Args: + t (int): The timepoint t. + image_d (int): The image dimension. + noise_d (int): The noise dimension. + + Returns: + logsnr_t_shifted (float): The logSNR value at timepoint t. + """ + logsnr_t = self.logsnr_schedule_cosine(t) + logsnr_t_shifted = logsnr_t + 2 * math.log(self.noise_d / self.image_d) + + return logsnr_t_shifted + + def clip(self, x): + """ + Function to clip the input tensor x to the range [-1, 1]. + + Args: + x (torch.Tensor): The input tensor to clip. + + Returns: + x (torch.Tensor): The clipped tensor. + """ + return torch.clamp(x, -1, 1) + + @torch.no_grad() + def ddpm_sampler_step(self, z_t, pred, logsnr_t, logsnr_s): + """ + Function to perform a single step of the DDPM sampler. 
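+
+        Concretely (matching the computation below), with
+        c = -expm1(logsnr_t - logsnr_s), alpha = sqrt(sigmoid(logsnr)) and
+        sigma = sqrt(sigmoid(-logsnr)), the step returns
+
+            mu = alpha_s * (z_t * (1 - c) / alpha_t + c * x_pred)
+            variance = sigma_s**2 * c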
+ + Args: + z_t (torch.Tensor): The diffused tensor at timepoint t. + pred (torch.Tensor): The predicted value from the model (v or eps). + logsnr_t (float): The logSNR value at timepoint t. + logsnr_s (float): The logSNR value at the sampling timepoint s. + + Returns: + z_s (torch.Tensor): The diffused tensor at sampling timepoint s. + """ + c = -expm1(logsnr_t - logsnr_s) + alpha_t = torch.sqrt(torch.sigmoid(logsnr_t)) + alpha_s = torch.sqrt(torch.sigmoid(logsnr_s)) + sigma_t = torch.sqrt(torch.sigmoid(-logsnr_t)) + sigma_s = torch.sqrt(torch.sigmoid(-logsnr_s)) + + if self.pred_param == 'v': + x_pred = alpha_t * z_t - sigma_t * pred + elif self.pred_param == 'eps': + x_pred = (z_t - sigma_t * pred) / alpha_t + + x_pred = self.clip(x_pred) + + mu = alpha_s * (z_t * (1 - c) / alpha_t + c * x_pred) + variance = (sigma_s ** 2) * c + + return mu, variance + + @torch.no_grad() + def sample(self, x): + """ + Standard DDPM sampling procedure. Begun by sampling z_T ~ N(0, 1) + and then repeatedly sampling z_s ~ p(z_s | z_t) + + Args: + x_shape (tuple): The shape of the input tensor. + + Returns: + x_pred (torch.Tensor): The predicted tensor. + """ + z_t = torch.randn(x.shape).to(x.device) + + # Steps T -> 1 + for t in reversed(range(1, self.steps+1)): + u_t = t / self.steps + u_s = (t - 1) / self.steps + + if self.schedule == 'cosine': + logsnr_t = self.logsnr_schedule_cosine(u_t) + logsnr_s = self.logsnr_schedule_cosine(u_s) + elif self.schedule == 'shifted_cosine': + logsnr_t = self.logsnr_schedule_cosine_shifted(u_t) + logsnr_s = self.logsnr_schedule_cosine_shifted(u_s) + + logsnr_t = logsnr_t.to(x.device) + logsnr_s = logsnr_s.to(x.device) + + pred = self.model(z_t, logsnr_t) + mu, variance = self.ddpm_sampler_step(z_t, pred, torch.tensor(logsnr_t), torch.tensor(logsnr_s)) + z_t = mu + torch.randn_like(mu) * torch.sqrt(variance) + + # Final step + if self.schedule == 'cosine': + logsnr_1 = self.logsnr_schedule_cosine(1/self.steps) + logsnr_0 = self.logsnr_schedule_cosine(0) + elif self.schedule == 'shifted_cosine': + logsnr_1 = self.logsnr_schedule_cosine_shifted(1/self.steps) + logsnr_0 = self.logsnr_schedule_cosine_shifted(0) + + logsnr_1 = logsnr_1.to(x.device) + logsnr_0 = logsnr_0.to(x.device) + + pred = self.model(z_t, logsnr_1) + x_pred, _ = self.ddpm_sampler_step(z_t, pred, torch.tensor(logsnr_1), torch.tensor(logsnr_0)) + + x_pred = self.clip(x_pred) + + # Convert x_pred to the range [0, 1] + x_pred = (x_pred + 1) / 2 + + return x_pred + + def loss(self, x): + """ + A function to compute the loss of the model. The loss is computed as the mean squared error + between the predicted noise tensor and the true noise tensor. Various prediction parameterizations + imply various weighting schemes as outlined in Kingma et al. (2023) + + Args: + x (torch.Tensor): The input tensor. + + Returns: + loss (torch.Tensor): The loss value. 
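+
+        Weighting used below (min-SNR-5): SNR = exp(logsnr_t) clamped to 5,
+        with weight = 1 / (1 + SNR) for 'v' prediction and 1 / SNR for 'eps'
+        prediction, applied to the squared error in eps-space.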
+ """ + t = torch.rand(x.shape[0]) + + if self.schedule == 'cosine': + logsnr_t = self.logsnr_schedule_cosine(t) + elif self.schedule == 'shifted_cosine': + logsnr_t = self.logsnr_schedule_cosine_shifted(t) + + logsnr_t = logsnr_t.to(x.device) + alpha_t = torch.sqrt(torch.sigmoid(logsnr_t)).view(-1, 1, 1, 1).to(x.device) + sigma_t = torch.sqrt(torch.sigmoid(-logsnr_t)).view(-1, 1, 1, 1).to(x.device) + z_t, eps_t = self.diffuse(x, alpha_t, sigma_t) + pred = self.model(z_t, logsnr_t) + + if self.pred_param == 'v': + eps_pred = sigma_t * z_t + alpha_t * pred + else: + eps_pred = pred + + # Apply min-SNR weighting (https://arxiv.org/pdf/2303.09556) + snr = torch.exp(logsnr_t).clamp_(max = 5) + if self.pred_param == 'v': + weight = 1 / (1 + snr) + else: + weight = 1 / snr + + weight = weight.view(-1, 1, 1, 1) + + loss = torch.mean(weight * (eps_pred - eps_t) ** 2) + + return loss + + def train_loop(self, config, optimizer, train_dataloader, lr_scheduler): + """ + A function to train the model. + + Args: + optimizer (torch.optim.Optimizer): The optimizer to use for training. + """ + # Initialize accelerator + accelerator = Accelerator( + mixed_precision=config.mixed_precision, + gradient_accumulation_steps=config.gradient_accumulation_steps, + project_dir=os.path.join(config.output_dir, "logs"), + ) + if accelerator.is_main_process: + if config.output_dir is not None: + os.makedirs(config.output_dir, exist_ok=True) + accelerator.init_trackers("train_example") + + model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + self.model, optimizer, train_dataloader, lr_scheduler + ) + + # Create an EMA model + ema = EMA( + model, + beta=0.9999, + update_after_step=100, + update_every=10 + ) + + global_step = 0 + + for epoch in range(config.num_epochs): + progress_bar = tqdm(total=len(train_dataloader)) + progress_bar.set_description(f"Epoch {epoch}") + + for step, batch in enumerate(train_dataloader): + x = batch["images"] + + with accelerator.accumulate(model): + loss = self.loss(x) + loss = loss.to(next(model.parameters()).dtype) + accelerator.backward(loss) + accelerator.clip_grad_norm_(model.parameters(), max_norm=1.0) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # Update EMA model parameters + ema.update() + + progress_bar.update(1) + logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step} + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + global_step += 1 + + # After each epoch you optionally sample some demo images + if accelerator.is_main_process: + self.model = accelerator.unwrap_model(model) + self.model.eval() + + # Make directory for saving images + os.makedirs(os.path.join(config.output_dir, "images"), exist_ok=True) + + if epoch % config.save_image_epochs == 0 or epoch == config.num_epochs - 1: + sample = self.sample(x[0].unsqueeze(0)) + sample = sample.detach().cpu().numpy().transpose(0, 2, 3, 1) + image_path = os.path.join(config.output_dir, "images", f"sample_{epoch}.png") + plt.imsave(image_path, sample[0]) + + # Save the EMA model to disk + torch.save(ema.ema_model.module.state_dict(), 'ema_model.pth') \ No newline at end of file diff --git a/diffusion/unet.py b/diffusion/unet.py new file mode 100644 index 0000000..308bdb4 --- /dev/null +++ b/diffusion/unet.py @@ -0,0 +1,766 @@ +""" +This file contains an implementation of the ADM U-Net from the paper +"Improved Denoising Diffusion Probabilistic Models" by Nichols and +Dhariwal. + +Reference: +Nichols, J., & Dhariwal, P. 
(2021). Improved Denoising Diffusion +Probabilistic Models. Retrieved from https://arxiv.org/abs/2102.09672 + +The code is adapted from the official implementation at: +https://github.com/openai/improved-diffusion/tree/main/improved_diffusion + +Other UNet implementations are also included for comparison as well. +""" + +from abc import abstractmethod + +import math + +import numpy as np +import torch as th +import torch.nn as nn +import torch.nn.functional as F + +from diffusers import UNet2DModel + +from .fp16_util import convert_module_to_f16, convert_module_to_f32 +from .nn import ( + checkpoint, + conv_nd, + linear, + avg_pool_nd, + zero_module, + normalization, + timestep_embedding, +) + + +class AttentionPool2d(nn.Module): + """ + Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py + """ + + def __init__( + self, + spacial_dim: int, + embed_dim: int, + num_heads_channels: int, + output_dim: int = None, + ): + super().__init__() + self.positional_embedding = nn.Parameter( + th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5 + ) + self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1) + self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1) + self.num_heads = embed_dim // num_heads_channels + self.attention = QKVAttention(self.num_heads) + + def forward(self, x): + b, c, *_spatial = x.shape + x = x.reshape(b, c, -1) # NC(HW) + x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1) + x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1) + x = self.qkv_proj(x) + x = self.attention(x) + x = self.c_proj(x) + return x[:, :, 0] + + +class TimestepBlock(nn.Module): + """ + Any module where forward() takes timestep embeddings as a second argument. + """ + + @abstractmethod + def forward(self, x, emb): + """ + Apply the module to `x` given `emb` timestep embeddings. + """ + + +class TimestepEmbedSequential(nn.Sequential, TimestepBlock): + """ + A sequential module that passes timestep embeddings to the children that + support it as an extra input. + """ + + def forward(self, x, emb): + for layer in self: + if isinstance(layer, TimestepBlock): + x = layer(x, emb) + else: + x = layer(x) + return x + + +class Upsample(nn.Module): + """ + An upsampling layer with an optional convolution. + + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + if use_conv: + self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.dims == 3: + x = F.interpolate( + x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" + ) + else: + x = F.interpolate(x, scale_factor=2, mode="nearest") + if self.use_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + """ + A downsampling layer with an optional convolution. + + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + downsampling occurs in the inner-two dimensions. 
+ """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + stride = 2 if dims != 3 else (1, 2, 2) + if use_conv: + self.op = conv_nd( + dims, self.channels, self.out_channels, 3, stride=stride, padding=1 + ) + else: + assert self.channels == self.out_channels + self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) + + def forward(self, x): + assert x.shape[1] == self.channels + return self.op(x) + + +class ResBlock(TimestepBlock): + """ + A residual block that can optionally change the number of channels. + + :param channels: the number of input channels. + :param emb_channels: the number of timestep embedding channels. + :param dropout: the rate of dropout. + :param out_channels: if specified, the number of out channels. + :param use_conv: if True and out_channels is specified, use a spatial + convolution instead of a smaller 1x1 convolution to change the + channels in the skip connection. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param use_checkpoint: if True, use gradient checkpointing on this module. + :param up: if True, use this block for upsampling. + :param down: if True, use this block for downsampling. + """ + + def __init__( + self, + channels, + emb_channels, + dropout, + out_channels=None, + use_conv=False, + use_scale_shift_norm=False, + dims=2, + use_checkpoint=False, + up=False, + down=False, + ): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_checkpoint = use_checkpoint + self.use_scale_shift_norm = use_scale_shift_norm + + self.in_layers = nn.Sequential( + normalization(channels), + nn.SiLU(), + conv_nd(dims, channels, self.out_channels, 3, padding=1), + ) + + self.updown = up or down + + if up: + self.h_upd = Upsample(channels, False, dims) + self.x_upd = Upsample(channels, False, dims) + elif down: + self.h_upd = Downsample(channels, False, dims) + self.x_upd = Downsample(channels, False, dims) + else: + self.h_upd = self.x_upd = nn.Identity() + + self.emb_layers = nn.Sequential( + nn.SiLU(), + linear( + emb_channels, + 2 * self.out_channels if use_scale_shift_norm else self.out_channels, + ), + ) + self.out_layers = nn.Sequential( + normalization(self.out_channels), + nn.SiLU(), + nn.Dropout(p=dropout), + zero_module( + conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) + ), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = conv_nd( + dims, channels, self.out_channels, 3, padding=1 + ) + else: + self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) + + def forward(self, x, emb): + """ + Apply the block to a Tensor, conditioned on a timestep embedding. + + :param x: an [N x C x ...] Tensor of features. + :param emb: an [N x emb_channels] Tensor of timestep embeddings. + :return: an [N x C x ...] Tensor of outputs. 
+ """ + return checkpoint( + self._forward, (x, emb), self.parameters(), self.use_checkpoint + ) + + def _forward(self, x, emb): + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + emb_out = self.emb_layers(emb).type(h.dtype) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] + if self.use_scale_shift_norm: + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + scale, shift = th.chunk(emb_out, 2, dim=1) + h = out_norm(h) * (1 + scale) + shift + h = out_rest(h) + else: + h = h + emb_out + h = self.out_layers(h) + return self.skip_connection(x) + h + + +class AttentionBlock(nn.Module): + """ + An attention block that allows spatial positions to attend to each other. + + Originally ported from here, but adapted to the N-d case. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. + """ + + def __init__( + self, + channels, + num_heads=1, + num_head_channels=-1, + use_checkpoint=False, + use_new_attention_order=False, + ): + super().__init__() + self.channels = channels + if num_head_channels == -1: + self.num_heads = num_heads + else: + assert ( + channels % num_head_channels == 0 + ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" + self.num_heads = channels // num_head_channels + self.use_checkpoint = use_checkpoint + self.norm = normalization(channels) + self.qkv = conv_nd(1, channels, channels * 3, 1) + if use_new_attention_order: + # split qkv before split heads + self.attention = QKVAttention(self.num_heads) + else: + # split heads before split qkv + self.attention = QKVAttentionLegacy(self.num_heads) + + self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) + + def forward(self, x): + return checkpoint(self._forward, (x,), self.parameters(), True) + + def _forward(self, x): + b, c, *spatial = x.shape + x = x.reshape(b, c, -1).to(self.norm.weight.dtype) + qkv = self.qkv(self.norm(x)) + h = self.attention(qkv) + h = self.proj_out(h) + return (x + h).reshape(b, c, *spatial) + + +def count_flops_attn(model, _x, y): + """ + A counter for the `thop` package to count the operations in an + attention operation. + Meant to be used like: + macs, params = thop.profile( + model, + inputs=(inputs, timestamps), + custom_ops={QKVAttention: QKVAttention.count_flops}, + ) + """ + b, c, *spatial = y[0].shape + num_spatial = int(np.prod(spatial)) + # We perform two matmuls with the same number of ops. + # The first computes the weight matrix, the second computes + # the combination of the value vectors. + matmul_ops = 2 * b * (num_spatial ** 2) * c + model.total_ops += th.DoubleTensor([matmul_ops]) + + +class QKVAttentionLegacy(nn.Module): + """ + A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + + :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. 
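+
+        Note: q and k are each scaled by 1 / sqrt(sqrt(ch)), which is equivalent
+        to the usual 1 / sqrt(ch) softmax scaling but numerically friendlier in
+        float16 (see the inline comment below).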
+ """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", q * scale, k * scale + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum("bts,bcs->bct", weight, v) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class QKVAttention(nn.Module): + """ + A module which performs QKV attention and splits in a different order. + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + + :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.chunk(3, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", + (q * scale).view(bs * self.n_heads, ch, length), + (k * scale).view(bs * self.n_heads, ch, length), + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length)) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class UNetModel(nn.Module): + """ + The full UNet model with attention and timestep embedding. + + :param in_channels: channels in the input Tensor. + :param model_channels: base channel count for the model. + :param out_channels: channels in the output Tensor. + :param num_res_blocks: number of residual blocks per downsample. + :param attention_resolutions: a collection of downsample rates at which + attention will take place. May be a set, list, or tuple. + For example, if this contains 4, then at 4x downsampling, attention + will be used. + :param dropout: the dropout probability. + :param channel_mult: channel multiplier for each level of the UNet. + :param conv_resample: if True, use learned convolutions for upsampling and + downsampling. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param num_classes: if specified (as an int), then this model will be + class-conditional with `num_classes` classes. + :param use_checkpoint: use gradient checkpointing to reduce memory usage. + :param num_heads: the number of attention heads in each attention layer. + :param num_heads_channels: if specified, ignore num_heads and instead use + a fixed channel width per attention head. + :param num_heads_upsample: works with num_heads to set a different number + of heads for upsampling. Deprecated. + :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. + :param resblock_updown: use residual blocks for up/downsampling. + :param use_new_attention_order: use a different attention pattern for potentially + increased efficiency. 
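+    :param img_resolution: the resolution of the input/output images.
+    :param dropout_from: downsample rate at which dropout starts being applied;
+        blocks at smaller downsample rates use dropout = 0.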
+ """ + + def __init__( + self, + img_resolution, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + dropout_from=16, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + num_classes=None, + use_checkpoint=False, + use_fp16=False, + num_heads=1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + ): + super().__init__() + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + self.img_resolution = img_resolution + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.num_classes = num_classes + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + if self.num_classes is not None: + self.label_emb = nn.Embedding(num_classes, time_embed_dim) + + ch = input_ch = int(channel_mult[0] * model_channels) + self.input_blocks = nn.ModuleList( + [TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))] + ) + self._feature_size = ch + input_block_chans = [ch] + ds = 1 + for level, (mult, num_blocks) in enumerate(zip(channel_mult, num_res_blocks)): + if ds < dropout_from: + dropout = 0 + + for _ in range(num_blocks): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=int(mult * model_channels), + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = int(mult * model_channels) + if ds in attention_resolutions: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + + self.output_blocks = nn.ModuleList([]) + for level, (mult, num_blocks) in list(enumerate(zip(channel_mult, num_res_blocks)))[::-1]: + for i in range(num_blocks + 1): + if 
ds < dropout_from: + dropout = 0 + ich = input_block_chans.pop() + layers = [ + ResBlock( + ch + ich, + time_embed_dim, + dropout, + out_channels=int(model_channels * mult), + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = int(model_channels * mult) + if ds in attention_resolutions: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads_upsample, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ) + ) + if level and i == num_blocks: + out_ch = ch + layers.append( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + up=True, + ) + if resblock_updown + else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) + ) + ds //= 2 + self.output_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + zero_module(conv_nd(dims, input_ch, out_channels, 3, padding=1)), + ) + + def convert_to_fp16(self): + """ + Convert the torso of the model to float16. + """ + self.input_blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + self.output_blocks.apply(convert_module_to_f16) + + def convert_to_fp32(self): + """ + Convert the torso of the model to float32. + """ + self.input_blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + self.output_blocks.apply(convert_module_to_f32) + + def forward(self, x, timesteps, y=None): + """ + Apply the model to an input batch. + + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :param y: an [N] Tensor of labels, if class-conditional. + :return: an [N x C x ...] Tensor of outputs. + """ + assert (y is not None) == ( + self.num_classes is not None + ), "must specify y if and only if the model is class-conditional" + + hs = [] + emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + + if self.num_classes is not None: + assert y.shape == (x.shape[0],) + emb = emb + self.label_emb(y) + + h = x.type(self.dtype) + for module in self.input_blocks: + h = module(h, emb) + hs.append(h) + h = self.middle_block(h, emb) + for module in self.output_blocks: + h = th.cat([h, hs.pop()], dim=1) + h = module(h, emb) + h = h.type(x.dtype) + return self.out(h) + +#---------------------------------------------------------------------------- +# End-to-end compression model as described in the paper +# "Simple Diffusion: End-to-End Diffusion for High Resolution Images" + +class simpleUNet(nn.Module): + def __init__(self, + img_resolution, # Image resolution at input/output. + in_channels, # Number of color channels at input. + out_channels, # Number of color channels at output. + + model_channels = 192, # Base multiplier for the number of channels. + channel_mult = [1,2,3,4], # Per-resolution multipliers for the number of channels. + num_res_blocks = [2,2,2,2], # Number of residual blocks per resolution. + attention_resolutions = [32,16,8], # List of resolutions with self-attention. + dropout = 0.10, # List of resolutions with self-attention. + dropout_from = 16, # Start applying dropout from this downsample onwards. + + downsample = 4, # Downsample factor for the initial convolution layer. + + fp16 = False, # Use float16 precision. 
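+        # Note: channel_mult and num_res_blocks are zipped per resolution level
+        # inside UNetModel, so they should have one entry per level each.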
+ ): + super().__init__() + + # Initial convolution layer to downsample the input image by a factor provided + # Assumes that the input image is square and divisible by the downsample factor + self.conv_down = nn.Conv2d( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=downsample, + stride=downsample, + bias=False, + ) + + self.conv_up = self.conv_up = nn.ConvTranspose2d( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=downsample, + stride=downsample, + bias=False, + ) + + # ADM backbone + self.unet = UNetModel( + img_resolution=img_resolution, + in_channels=in_channels, + out_channels=out_channels, + model_channels=model_channels, + channel_mult=channel_mult, + num_res_blocks=num_res_blocks, + attention_resolutions=attention_resolutions, + dropout=dropout, + dropout_from=dropout_from, + use_fp16=fp16, + ) + + def forward(self, x, noise_labels): + # Downsample the input image + x = self.conv_down(x) + + # Pass through the UNet model + x = self.unet(x, noise_labels) + + # Upsample the output image + x = self.conv_up(x) + + return x + + +#---------------------------------------------------------------------------- +# Adaptation of HuggingFace's UNet2DModel to use with the +# simpleDiffusion pipeline + +class UNet2D(UNet2DModel): + def __init__(self, sample_size, in_channels, out_channels, **kwargs): + super().__init__(sample_size, in_channels, out_channels, **kwargs) + + def forward(self, x, noise_labels): + x = super().forward(x, noise_labels, return_dict=False) + return x[0] \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..a90233d --- /dev/null +++ b/requirements.txt @@ -0,0 +1,132 @@ +accelerate==0.31.0 +accelerator==2024.6.11 +aiohttp==3.9.5 +aiosignal==1.3.1 +asttokens==2.4.1 +async-timeout==4.0.3 +attrs==23.2.0 +backcall==0.2.0 +beartype==0.18.5 +bottle==0.12.25 +certifi==2024.6.2 +charset-normalizer==3.3.2 +colorama==0.4.6 +comet-ml==3.43.2 +comm==0.2.2 +configobj==5.0.8 +contourpy==1.1.1 +cycler==0.12.1 +datasets==2.20.0 +debugpy==1.8.1 +decorator==5.1.1 +diffusers==0.29.1 +dill==0.3.8 +dulwich==0.22.1 +einops==0.8.0 +ema-pytorch==0.4.8 +everett==3.1.0 +exceptiongroup==1.2.1 +executing==2.0.1 +filelock==3.14.0 +fonttools==4.53.0 +frozenlist==1.4.1 +fsspec==2024.5.0 +huggingface-hub==0.23.3 +idna==3.7 +importlib_metadata==7.1.0 +importlib_resources==6.4.0 +iniconfig==2.0.0 +intel-openmp==2021.4.0 +ipykernel==6.29.4 +ipython==8.13.0 +jedi==0.19.1 +Jinja2==3.1.4 +jsonschema==4.22.0 +jsonschema-specifications==2023.12.1 +jupyter_client==8.6.2 +jupyter_core==5.7.2 +kiwisolver==1.4.5 +lz4==4.3.3 +markdown-it-py==3.0.0 +MarkupSafe==2.1.5 +matplotlib==3.7.5 +matplotlib-inline==0.1.7 +mdurl==0.1.2 +mkl==2021.4.0 +mpmath==1.3.0 +multidict==6.0.5 +multiprocess==0.70.16 +nest-asyncio==1.6.0 +networkx==3.1 +numpy==1.24.4 +nvidia-cublas-cu12==12.1.3.1 +nvidia-cuda-cupti-cu12==12.1.105 +nvidia-cuda-nvrtc-cu12==12.1.105 +nvidia-cuda-runtime-cu12==12.1.105 +nvidia-cudnn-cu12==8.9.2.26 +nvidia-cufft-cu12==11.0.2.54 +nvidia-curand-cu12==10.3.2.106 +nvidia-cusolver-cu12==11.4.5.107 +nvidia-cusparse-cu12==12.1.0.106 +nvidia-nccl-cu12==2.20.5 +nvidia-nvjitlink-cu12==12.5.40 +nvidia-nvtx-cu12==12.1.105 +packaging==24.0 +pandas==2.0.3 +parso==0.8.4 +patch-ng==1.17.4 +pexpect==4.9.0 +pickleshare==0.7.5 +pillow==10.3.0 +pkgutil_resolve_name==1.3.10 +platformdirs==4.2.2 +pluggy==1.5.0 +prompt_toolkit==3.0.47 +psutil==5.9.8 +ptyprocess==0.7.0 +pure-eval==0.2.2 +pyarrow==16.1.0 +pyarrow-hotfix==0.6 
+Pygments==2.18.0 +pyparsing==3.1.2 +pytest==8.2.2 +python-box==6.1.0 +python-dateutil==2.9.0.post0 +pytorch-fid==0.3.0 +pytz==2024.1 +PyWavelets==1.4.1 +PyYAML==6.0.1 +pyzmq==26.0.3 +referencing==0.35.1 +regex==2024.5.15 +requests==2.32.3 +requests-toolbelt==1.0.0 +rich==13.7.1 +rpds-py==0.18.1 +safetensors==0.4.3 +scipy==1.10.1 +semantic-version==2.10.0 +sentry-sdk==2.6.0 +setproctitle==1.3.3 +simplejson==3.19.2 +six==1.16.0 +stack-data==0.6.3 +sympy==1.12.1 +tbb==2021.12.0 +tomli==2.0.1 +torch==2.3.1 +torchvision==0.18.1 +tornado==6.4.1 +tqdm==4.66.4 +traitlets==5.14.3 +triton==2.3.1 +typing_extensions==4.12.1 +tzdata==2024.1 +urllib3==2.2.1 +waitress==3.0.0 +wcwidth==0.2.13 +wrapt==1.16.0 +wurlitzer==3.1.1 +xxhash==3.4.1 +yarl==1.9.4 +zipp==3.19.2 diff --git a/train.py b/train.py new file mode 100644 index 0000000..5ecf7ab --- /dev/null +++ b/train.py @@ -0,0 +1,94 @@ +""" +A script for training a diffusion model on the Smithsonian Butterflies dataset following the simpleDiffusion paradigm. + +This script uses the UNet2D model from the diffusion library and the simpleDiffusion model from the simple_diffusion library. +""" + +from diffusion.unet import UNet2D, simpleUNet +from diffusion.simple_diffusion import simpleDiffusion + +from datasets import load_dataset +from torchvision import transforms +import torch +from diffusers.optimization import get_cosine_schedule_with_warmup + + +class TrainingConfig: + image_size = 128 # the generated image resolution + train_batch_size = 4 + num_epochs = 200 + gradient_accumulation_steps = 1 + learning_rate = 5e-5 + lr_warmup_steps = 10000 + save_image_epochs = 50 + mixed_precision = "fp16" # `no` for float32, `fp16` for automatic mixed precision + output_dir = "ddpm-butterflies-128" # the model name locally and on the HF Hub + overwrite_output_dir = True # overwrite the old model when re-running the notebook + seed = 0 + + +def main(): + + config = TrainingConfig + + dataset_name = "huggan/smithsonian_butterflies_subset" + + dataset = load_dataset(dataset_name, split="train") + + preprocess = transforms.Compose( + [ + transforms.Resize((config.image_size, config.image_size)), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) + ] + ) + + def transform(examples): + images = [preprocess(image.convert("RGB")) for image in examples["image"]] + return {"images": images} + + dataset.set_transform(transform) + + train_loader = torch.utils.data.DataLoader( + dataset, + batch_size=config.train_batch_size, + shuffle=True, + ) + + unet = simpleUNet( + img_resolution=config.image_size, + in_channels=3, + out_channels=3, + model_channels=192, + channel_mult=(1, 1, 2, 2, 4, 4), + num_res_blocks=(2, 2, 2, 2, 2, 2), + attention_resolutions=(16, 8), + dropout=0.1, + dropout_from=16, + downsample=1, + fp16=True + ) + + optimizer = torch.optim.Adam(unet.parameters(), lr=config.learning_rate) + lr_scheduler = get_cosine_schedule_with_warmup( + optimizer, + num_warmup_steps=config.lr_warmup_steps, + num_training_steps=len(train_loader) * config.num_epochs, + ) + + diffusion_model = simpleDiffusion( + unet=unet, + image_size=config.image_size + ) + + diffusion_model.train_loop( + config=config, + optimizer=optimizer, + train_dataloader=train_loader, + lr_scheduler=lr_scheduler + ) + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/unet_configurations.txt b/unet_configurations.txt new file mode 100644 index 0000000..b41c881 --- /dev/null +++ b/unet_configurations.txt @@ 
-0,0 +1,39 @@ +UNet2D: + unet = UNet2D( + sample_size=config.image_size, # the target image resolution + in_channels=3, # the number of input channels, 3 for RGB images + out_channels=3, # the number of output channels + layers_per_block=2, # how many ResNet layers to use per UNet block + block_out_channels=(128, 128, 256, 256, 512, 512), # the number of output channels for each UNet block + down_block_types=( + "DownBlock2D", + "DownBlock2D", + "DownBlock2D", + "DownBlock2D", + "AttnDownBlock2D", + "DownBlock2D", + ), + up_block_types=( + "UpBlock2D", + "AttnUpBlock2D", + "UpBlock2D", + "UpBlock2D", + "UpBlock2D", + "UpBlock2D", + ), + ) + +simpleUNet: + unet = simpleUNet( + img_resolution=config.image_size, + in_channels=3, + out_channels=3, + model_channels=192, + channel_mult=(1, 1, 2, 2, 4, 4), + num_res_blocks=(2, 2, 2, 2, 2, 2), + attention_resolutions=(16, 8), + dropout=0.1, + dropout_from=16, + downsample=1, + fp16=True + ) \ No newline at end of file
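+
+Either block above can be swapped into train.py; the resulting unet is then
+handed to the simpleDiffusion wrapper, e.g.:
+
+    diffusion_model = simpleDiffusion(unet=unet, image_size=config.image_size)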