diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..9160aec
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,4 @@
+# Ignore virtual environment files
+venv/
+venv38/
+*.pyc
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..0cb99e5
--- /dev/null
+++ b/README.md
@@ -0,0 +1,148 @@
+## simple diffusion: End-to-end diffusion for high resolution images
+### Unofficial PyTorch Implementation 
+
+**Simple diffusion: End-to-end diffusion for high resolution images**
+[Emiel Hoogeboom](https://arxiv.org/search/cs?searchtype=author&query=Hoogeboom,+E), [Jonathan Heek](https://arxiv.org/search/cs?searchtype=author&query=Heek,+J), [Tim Salimans](https://arxiv.org/search/cs?searchtype=author&query=Salimans,+T)
+https://arxiv.org/abs/2301.11093
+
+### Requirements
+* All testing and development was conducted on 4x 16GB NVIDIA V100 GPUs
+* 64-bit Python 3.8 and PyTorch 2.1 (or later). See [https://pytorch.org](https://pytorch.org/) for PyTorch install instructions.
+
+For convenience, a `requirements.txt` file is included to install the required dependencies in an environment of your choice.
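+
+For example, one possible setup using a plain virtual environment (any environment manager works equally well):
+
+    python -m venv venv
+    source venv/bin/activate
+    pip install -r requirements.txt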
+
+### Usage
+
+The code for training a diffusion model is self-contained in the `simpleDiffusion` class. Set-up and preparation are included in the `train.py` file:
+
+    from diffusion.unet import UNet2D
+    from diffusion.simple_diffusion import simpleDiffusion
+    
+    from datasets import load_dataset
+    from torchvision import transforms
+    import torch
+    from diffusers.optimization import get_cosine_schedule_with_warmup
+    
+    
+    class TrainingConfig:
+        image_size = 128  # the generated image resolution
+        train_batch_size = 4
+        num_epochs = 100
+        gradient_accumulation_steps = 1
+        learning_rate = 5e-5
+        lr_warmup_steps = 10000
+        save_image_epochs = 100
+        mixed_precision = "fp16"  # `no` for float32, `fp16` for automatic mixed precision
+        output_dir = "ddpm-butterflies-128"  # the model name locally and on the HF Hub
+        overwrite_output_dir = True  # overwrite the old model when re-running the notebook
+        seed = 0
+    
+    
+    def main():
+	    
+	    config = TrainingConfig
+
+	    dataset_name = "huggan/smithsonian_butterflies_subset"
+
+	    dataset = load_dataset(dataset_name, split="train")
+
+	    preprocess = transforms.Compose(
+	        [
+	            transforms.Resize((config.image_size, config.image_size)),
+	            transforms.RandomHorizontalFlip(),
+	            transforms.ToTensor(),
+	            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
+	        ]
+	    )
+
+	    def transform(examples):
+	        images = [preprocess(image.convert("RGB")) for image in examples["image"]]
+	        return {"images": images}
+
+	    dataset.set_transform(transform)
+
+	    train_loader = torch.utils.data.DataLoader(
+	        dataset,
+	        batch_size=config.train_batch_size,
+	        shuffle=True,
+	    )
+
+	    unet = UNet2D(
+	        sample_size=config.image_size,  # the target image resolution
+	        in_channels=3,  # the number of input channels, 3 for RGB images
+	        out_channels=3,  # the number of output channels
+	        layers_per_block=2,  # how many ResNet layers to use per UNet block
+	        block_out_channels=(128, 128, 256, 256, 512, 512),  # the number of output channels for each UNet block
+	        down_block_types=(
+	            "DownBlock2D",
+	            "DownBlock2D",
+	            "DownBlock2D",
+	            "DownBlock2D",
+	            "AttnDownBlock2D",
+	            "DownBlock2D",
+	        ),
+	        up_block_types=(
+	            "UpBlock2D",
+	            "AttnUpBlock2D",
+	            "UpBlock2D",
+	            "UpBlock2D",
+	            "UpBlock2D",
+	            "UpBlock2D",
+	        ),
+	    )
+
+	    optimizer = torch.optim.Adam(unet.parameters(), lr=config.learning_rate)
+	    lr_scheduler = get_cosine_schedule_with_warmup(
+	        optimizer,
+	        num_warmup_steps=config.lr_warmup_steps,
+	        num_training_steps=len(train_loader) * config.num_epochs,
+	    )
+
+	    diffusion_model = simpleDiffusion(
+	        unet=unet,
+	        image_size=config.image_size
+	    )
+
+	    diffusion_model.train_loop(
+	        config=config,
+	        optimizer=optimizer,
+	        train_dataloader=train_loader,
+	        lr_scheduler=lr_scheduler
+	    )
+	
+	if __name__ == '__main__':
+	    main()
+
+Multiple versions of the U-Net architecture are available (the `UNet2D` wrapper around Diffusers' `UNet2DModel`, and the ADM-style `simpleUNet`), with U-ViT and others planned for future inclusion.
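+
+For example, the ADM-style `simpleUNet` drops into the same training loop; the configuration below mirrors the one used in `train.py` (see also `unet_configurations.txt`):
+
+    from diffusion.unet import simpleUNet
+
+    unet = simpleUNet(
+        img_resolution=config.image_size,
+        in_channels=3,
+        out_channels=3,
+        model_channels=192,
+        channel_mult=(1, 1, 2, 2, 4, 4),
+        num_res_blocks=(2, 2, 2, 2, 2, 2),
+        attention_resolutions=(16, 8),
+        dropout=0.1,
+        dropout_from=16,
+        downsample=1,
+        fp16=True,
+    )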
+
+### Multi-GPU Training
+The `simpleDiffusion` class is equipped with HuggingFace's [Accelerate](https://huggingface.co/docs/accelerate/en/index) library (via the `Accelerator` wrapper) for distributed training. Multi-GPU training can be launched with:
+`accelerate launch --multi_gpu train.py`
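+
+For example, on a single node with 4 GPUs (the flags below are standard `accelerate launch` options; adjust `--num_processes` to match your hardware):
+
+    accelerate config  # optional one-time interactive setup
+    accelerate launch --multi_gpu --num_processes 4 train.py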
+
+### Citations
+
+    @inproceedings{Hoogeboom2023simpleDE,
+	    title   = {simple diffusion: End-to-end diffusion for high resolution images},
+	    author  = {Emiel Hoogeboom and Jonathan Heek and Tim Salimans},
+	    year    = {2023}
+	    }
+    
+    @InProceedings{pmlr-v139-nichol21a,
+	    title       = {Improved Denoising Diffusion Probabilistic Models},
+	    author      = {Nichol, Alexander Quinn and Dhariwal, Prafulla},
+	    booktitle   = {Proceedings of the 38th International Conference on Machine Learning},
+	    pages       = {8162--8171},
+	    year        = {2021},
+	    editor      = {Meila, Marina and Zhang, Tong},
+	    volume      = {139},
+	    series      = {Proceedings of Machine Learning Research},
+	    month       = {18--24 Jul},
+	    publisher   = {PMLR},
+	    pdf         = {http://proceedings.mlr.press/v139/nichol21a/nichol21a.pdf},
+	    url         = {https://proceedings.mlr.press/v139/nichol21a.html}
+	    }
+
+    @inproceedings{Hang2023EfficientDT,
+	    title   = {Efficient Diffusion Training via Min-SNR Weighting Strategy},
+	    author  = {Tiankai Hang and Shuyang Gu and Chen Li and Jianmin Bao and Dong Chen and Han Hu and Xin Geng and Baining Guo},
+	    year    = {2023}
+	    }
diff --git a/diffusion/__init__.py b/diffusion/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/diffusion/fp16_util.py b/diffusion/fp16_util.py
new file mode 100644
index 0000000..41586b2
--- /dev/null
+++ b/diffusion/fp16_util.py
@@ -0,0 +1,83 @@
+"""
+Helpers to train with 16-bit precision.
+
+Reference:
+Nichol, A. Q., & Dhariwal, P. (2021). Improved Denoising Diffusion
+Probabilistic Models. Retrieved from https://arxiv.org/abs/2102.09672
+
+The code is adapted from the official implementation at:
+https://github.com/openai/improved-diffusion/tree/main/improved_diffusion
+"""
+
+import torch.nn as nn
+from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
+
+
+def convert_module_to_f16(l):
+    """
+    Convert primitive modules to float16.
+    """
+    if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
+        l.weight.data = l.weight.data.half()
+        l.bias.data = l.bias.data.half()
+
+
+def convert_module_to_f32(l):
+    """
+    Convert primitive modules to float32, undoing convert_module_to_f16().
+    """
+    if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
+        l.weight.data = l.weight.data.float()
+        l.bias.data = l.bias.data.float()
+
+
+def make_master_params(model_params):
+    """
+    Copy model parameters into a (differently-shaped) list of full-precision
+    parameters.
+    """
+    master_params = _flatten_dense_tensors(
+        [param.detach().float() for param in model_params]
+    )
+    master_params = nn.Parameter(master_params)
+    master_params.requires_grad = True
+    return [master_params]
+
+
+def model_grads_to_master_grads(model_params, master_params):
+    """
+    Copy the gradients from the model parameters into the master parameters
+    from make_master_params().
+    """
+    master_params[0].grad = _flatten_dense_tensors(
+        [param.grad.data.detach().float() for param in model_params]
+    )
+
+
+def master_params_to_model_params(model_params, master_params):
+    """
+    Copy the master parameter data back into the model parameters.
+    """
+    # Without copying to a list, if a generator is passed, this will
+    # silently not copy any parameters.
+    model_params = list(model_params)
+
+    for param, master_param in zip(
+        model_params, unflatten_master_params(model_params, master_params)
+    ):
+        param.detach().copy_(master_param)
+
+
+def unflatten_master_params(model_params, master_params):
+    """
+    Unflatten the master parameters to look like model_params.
+    """
+    return _unflatten_dense_tensors(master_params[0].detach(), tuple(tensor for tensor in model_params))
+
+
+def zero_grad(model_params):
+    for param in model_params:
+        # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group
+        if param.grad is not None:
+            param.grad.detach_()
+            param.grad.zero_()
\ No newline at end of file
diff --git a/diffusion/nn.py b/diffusion/nn.py
new file mode 100644
index 0000000..0eed005
--- /dev/null
+++ b/diffusion/nn.py
@@ -0,0 +1,180 @@
+"""
+Various utilities for neural networks.
+
+Reference:
+Nichol, A. Q., & Dhariwal, P. (2021). Improved Denoising Diffusion
+Probabilistic Models. Retrieved from https://arxiv.org/abs/2102.09672
+
+The code is adapted from the official implementation at:
+https://github.com/openai/improved-diffusion/tree/main/improved_diffusion
+"""
+
+import math
+
+import torch as th
+import torch.nn as nn
+
+
+# nn.SiLU is available in modern PyTorch; this class is kept for parity with the reference implementation.
+class SiLU(nn.Module):
+    def forward(self, x):
+        return x * th.sigmoid(x)
+
+
+class GroupNorm32(nn.GroupNorm):
+    def forward(self, x):
+        return super().forward(x.float()).type(x.dtype)
+
+
+def conv_nd(dims, *args, **kwargs):
+    """
+    Create a 1D, 2D, or 3D convolution module.
+    """
+    if dims == 1:
+        return nn.Conv1d(*args, **kwargs)
+    elif dims == 2:
+        return nn.Conv2d(*args, **kwargs)
+    elif dims == 3:
+        return nn.Conv3d(*args, **kwargs)
+    raise ValueError(f"unsupported dimensions: {dims}")
+
+
+def linear(*args, **kwargs):
+    """
+    Create a linear module.
+    """
+    return nn.Linear(*args, **kwargs)
+
+
+def avg_pool_nd(dims, *args, **kwargs):
+    """
+    Create a 1D, 2D, or 3D average pooling module.
+    """
+    if dims == 1:
+        return nn.AvgPool1d(*args, **kwargs)
+    elif dims == 2:
+        return nn.AvgPool2d(*args, **kwargs)
+    elif dims == 3:
+        return nn.AvgPool3d(*args, **kwargs)
+    raise ValueError(f"unsupported dimensions: {dims}")
+
+
+def update_ema(target_params, source_params, rate=0.99):
+    """
+    Update target parameters to be closer to those of source parameters using
+    an exponential moving average.
+
+    :param target_params: the target parameter sequence.
+    :param source_params: the source parameter sequence.
+    :param rate: the EMA rate (closer to 1 means slower).
+    """
+    for targ, src in zip(target_params, source_params):
+        targ.detach().mul_(rate).add_(src, alpha=1 - rate)
+
+
+def zero_module(module):
+    """
+    Zero out the parameters of a module and return it.
+    """
+    for p in module.parameters():
+        p.detach().zero_()
+    return module
+
+
+def scale_module(module, scale):
+    """
+    Scale the parameters of a module and return it.
+    """
+    for p in module.parameters():
+        p.detach().mul_(scale)
+    return module
+
+
+def mean_flat(tensor):
+    """
+    Take the mean over all non-batch dimensions.
+    """
+    return tensor.mean(dim=list(range(1, len(tensor.shape))))
+
+
+def normalization(channels):
+    """
+    Make a standard normalization layer.
+
+    :param channels: number of input channels.
+    :return: an nn.Module for normalization.
+    """
+    return GroupNorm32(32, channels)
+
+
+def timestep_embedding(timesteps, dim, max_period=10000):
+    """
+    Create sinusoidal timestep embeddings.
+
+    :param timesteps: a 1-D Tensor of N indices, one per batch element.
+                      These may be fractional.
+    :param dim: the dimension of the output.
+    :param max_period: controls the minimum frequency of the embeddings.
+    :return: an [N x dim] Tensor of positional embeddings.
+    """
+    if timesteps.dim() == 0:
+        timesteps = timesteps.unsqueeze(0)
+
+    half = dim // 2
+    freqs = th.exp(
+        -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half
+    ).to(device=timesteps.device)
+    args = timesteps[:, None].float() * freqs[None]
+    embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
+    if dim % 2:
+        embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)
+    return embedding
+
+
+def checkpoint(func, inputs, params, flag):
+    """
+    Evaluate a function without caching intermediate activations, allowing for
+    reduced memory at the expense of extra compute in the backward pass.
+
+    :param func: the function to evaluate.
+    :param inputs: the argument sequence to pass to `func`.
+    :param params: a sequence of parameters `func` depends on but does not
+                   explicitly take as arguments.
+    :param flag: if False, disable gradient checkpointing.
+    """
+    if flag:
+        args = tuple(inputs) + tuple(params)
+        return CheckpointFunction.apply(func, len(inputs), *args)
+    else:
+        return func(*inputs)
+
+
+class CheckpointFunction(th.autograd.Function):
+    @staticmethod
+    def forward(ctx, run_function, length, *args):
+        ctx.run_function = run_function
+        ctx.input_tensors = list(args[:length])
+        ctx.input_params = list(args[length:])
+        with th.no_grad():
+            output_tensors = ctx.run_function(*ctx.input_tensors)
+        return output_tensors
+
+    @staticmethod
+    def backward(ctx, *output_grads):
+        ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
+        with th.enable_grad():
+            # Fixes a bug where the first op in run_function modifies the
+            # Tensor storage in place, which is not allowed for detach()'d
+            # Tensors.
+            shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
+            output_tensors = ctx.run_function(*shallow_copies)
+        input_grads = th.autograd.grad(
+            output_tensors,
+            ctx.input_tensors + ctx.input_params,
+            output_grads,
+            allow_unused=True,
+        )
+        del ctx.input_tensors
+        del ctx.input_params
+        del output_tensors
+        return (None, None) + input_grads
\ No newline at end of file
diff --git a/diffusion/simple_diffusion.py b/diffusion/simple_diffusion.py
new file mode 100644
index 0000000..3c53b71
--- /dev/null
+++ b/diffusion/simple_diffusion.py
@@ -0,0 +1,336 @@
+"""
+This file contains an implementation of the pseudocode from the paper 
+"Simple Diffusion: End-to-End Diffusion for High Resolution Images" 
+by Emiel Hoogeboom, Jonathan Heek, and Tim Salimans.
+
+Reference:
+Hoogeboom, E., Heek, J., & Salimans, T. (2023).
+Simple Diffusion: End-to-End Diffusion for High Resolution Images. 
+Retrieved from https://arxiv.org/abs/2301.11093
+"""
+
+import torch
+import torch.nn as nn
+from torch.special import expm1
+import math
+from accelerate import Accelerator
+import os
+from tqdm import tqdm
+from ema_pytorch import EMA
+import matplotlib.pyplot as plt
+
+# helper
+def log(t, eps = 1e-20):
+    return torch.log(t.clamp(min = eps))
+
+class simpleDiffusion(nn.Module):
+    def __init__(
+        self, 
+        unet,
+        image_size,
+        noise_size=64,
+        pred_param='v', 
+        schedule='shifted_cosine', 
+        steps=512
+    ):
+        super().__init__()
+
+        # Training objective
+        assert pred_param in ['v', 'eps'], "Invalid prediction parameterization. Must be 'v' or 'eps'"
+        self.pred_param = pred_param
+
+        # Sampling schedule
+        assert schedule in ['cosine', 'shifted_cosine'], "Invalid schedule. Must be 'cosine' or 'shifted_cosine'"
+        self.schedule = schedule
+        self.noise_d = noise_size
+        self.image_d = image_size
+
+        # Model
+        assert isinstance(unet, nn.Module), "Model must be an instance of torch.nn.Module."
+        self.model = unet
+
+        num_params = sum(p.numel() for p in self.model.parameters())
+        print(f"Number of parameters: {num_params}")
+
+        # Steps
+        self.steps = steps
+
+    def diffuse(self, x, alpha_t, sigma_t):
+        """
+        Function to diffuse the input tensor x to a timepoint t with the given alpha_t and sigma_t.
+
+        Args:
+        x (torch.Tensor): The input tensor to diffuse.
+        alpha_t (torch.Tensor): The alpha value at timepoint t.
+        sigma_t (torch.Tensor): The sigma value at timepoint t.
+
+        Returns:
+        z_t (torch.Tensor): The diffused tensor at timepoint t.
+        eps_t (torch.Tensor): The noise tensor at timepoint t.
+        """
+        eps_t = torch.randn_like(x)
+
+        z_t = alpha_t * x + sigma_t * eps_t
+
+        return z_t, eps_t
+
+    def logsnr_schedule_cosine(self, t, logsnr_min=-15, logsnr_max=15):
+        """
+        Function to compute the logSNR schedule at timepoint t with cosine:
+
+        logSNR(t) = -2 * log (tan (pi * t / 2))
+
+        Taking into account boundary effects, the logSNR value at timepoint t is computed as:
+
+        logsnr_t = -2 * log(tan(t_min + t * (t_max - t_min)))
+
+        Args:
+        t (float or torch.Tensor): The timepoint t in [0, 1].
+        logsnr_min (float): The minimum logSNR value.
+        logsnr_max (float): The maximum logSNR value.
+
+        Returns:
+        logsnr_t (float): The logSNR value at timepoint t.
+        """
+        logsnr_max = logsnr_max + math.log(self.noise_d / self.image_d)
+        logsnr_min = logsnr_min + math.log(self.noise_d / self.image_d)
+        t_min = math.atan(math.exp(-0.5 * logsnr_max))
+        t_max = math.atan(math.exp(-0.5 * logsnr_min))
+
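+        # t in [0, 1] is mapped linearly onto [t_min, t_max], so the schedule
+        # runs from logsnr_max at t = 0 down to logsnr_min at t = 1.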
+        t = torch.as_tensor(t, dtype=torch.float32)
+        logsnr_t = -2 * log(torch.tan(t_min + t * (t_max - t_min)))
+
+        return logsnr_t
+
+    def logsnr_schedule_cosine_shifted(self, t):
+        """
+        Function to compute the logSNR schedule at timepoint t with shifted cosine:
+
+        logSNR_shifted(t) = logSNR(t) + 2 * log(noise_d / image_d)
+
+        Args:
+        t (float or torch.Tensor): The timepoint t in [0, 1]. The shift is
+        determined by self.noise_d and self.image_d set at construction time.
+
+        Returns:
+        logsnr_t_shifted (float): The logSNR value at timepoint t.
+        """
+        logsnr_t = self.logsnr_schedule_cosine(t)
+        logsnr_t_shifted = logsnr_t + 2 * math.log(self.noise_d / self.image_d)
+
+        return logsnr_t_shifted
+        
+    def clip(self, x):
+        """
+        Function to clip the input tensor x to the range [-1, 1].
+
+        Args:
+        x (torch.Tensor): The input tensor to clip.
+
+        Returns:
+        x (torch.Tensor): The clipped tensor.
+        """
+        return torch.clamp(x, -1, 1)
+
+    @torch.no_grad()
+    def ddpm_sampler_step(self, z_t, pred, logsnr_t, logsnr_s):
+        """
+        Function to perform a single step of the DDPM sampler.
+
+        Args:
+        z_t (torch.Tensor): The diffused tensor at timepoint t.
+        pred (torch.Tensor): The predicted value from the model (v or eps).
+        logsnr_t (float): The logSNR value at timepoint t.
+        logsnr_s (float): The logSNR value at the sampling timepoint s.
+
+        Returns:
+        mu (torch.Tensor): The mean of p(z_s | z_t).
+        variance (torch.Tensor): The variance of p(z_s | z_t).
+        """
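+        # Variance-preserving parameterization: alpha^2 = sigmoid(logsnr) and
+        # sigma^2 = sigmoid(-logsnr), so alpha^2 + sigma^2 = 1. The coefficient
+        # c = -expm1(logsnr_t - logsnr_s) = 1 - SNR(t) / SNR(s).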
+        c = -expm1(logsnr_t - logsnr_s)
+        alpha_t = torch.sqrt(torch.sigmoid(logsnr_t))
+        alpha_s = torch.sqrt(torch.sigmoid(logsnr_s))
+        sigma_t = torch.sqrt(torch.sigmoid(-logsnr_t))
+        sigma_s = torch.sqrt(torch.sigmoid(-logsnr_s))
+
+        if self.pred_param == 'v':
+            x_pred = alpha_t * z_t - sigma_t * pred
+        elif self.pred_param == 'eps':
+            x_pred = (z_t - sigma_t * pred) / alpha_t
+
+        x_pred = self.clip(x_pred)
+
+        mu = alpha_s * (z_t * (1 - c) / alpha_t + c * x_pred)
+        variance = (sigma_s ** 2) * c
+
+        return mu, variance
+
+    @torch.no_grad()
+    def sample(self, x):
+        """
+        Standard DDPM sampling procedure: sample z_T ~ N(0, I), then
+        repeatedly sample z_s ~ p(z_s | z_t).
+
+        Args:
+        x (torch.Tensor): A tensor whose shape and device are used to draw samples.
+
+        Returns:
+        x_pred (torch.Tensor): The predicted tensor.
+        """
+        z_t = torch.randn(x.shape).to(x.device)
+
+        # Steps T -> 1
+        for t in reversed(range(1, self.steps+1)):
+            u_t = t / self.steps
+            u_s = (t - 1) / self.steps
+
+            if self.schedule == 'cosine':
+                logsnr_t = self.logsnr_schedule_cosine(u_t)
+                logsnr_s = self.logsnr_schedule_cosine(u_s)
+            elif self.schedule == 'shifted_cosine':
+                logsnr_t = self.logsnr_schedule_cosine_shifted(u_t)
+                logsnr_s = self.logsnr_schedule_cosine_shifted(u_s)
+
+            logsnr_t = logsnr_t.to(x.device)
+            logsnr_s = logsnr_s.to(x.device)
+
+            pred = self.model(z_t, logsnr_t)
+            mu, variance = self.ddpm_sampler_step(z_t, pred, logsnr_t, logsnr_s)
+            z_t = mu + torch.randn_like(mu) * torch.sqrt(variance)
+
+        # Final step
+        if self.schedule == 'cosine':
+            logsnr_1 = self.logsnr_schedule_cosine(1/self.steps)
+            logsnr_0 = self.logsnr_schedule_cosine(0)
+        elif self.schedule == 'shifted_cosine':
+            logsnr_1 = self.logsnr_schedule_cosine_shifted(1/self.steps)
+            logsnr_0 = self.logsnr_schedule_cosine_shifted(0)
+
+        logsnr_1 = logsnr_1.to(x.device)
+        logsnr_0 = logsnr_0.to(x.device)
+
+        pred = self.model(z_t, logsnr_1)
+        x_pred, _ = self.ddpm_sampler_step(z_t, pred, logsnr_1, logsnr_0)
+        
+        x_pred = self.clip(x_pred)
+
+        # Convert x_pred to the range [0, 1]
+        x_pred = (x_pred + 1) / 2
+
+        return x_pred
+    
+    def loss(self, x):
+        """
+        A function to compute the loss of the model. The loss is computed as the mean squared error
+        between the predicted noise tensor and the true noise tensor. Various prediction parameterizations
+        imply various weighting schemes as outlined in Kingma et al. (2023)
+
+        Args:
+        x (torch.Tensor): The input tensor.
+
+        Returns:
+        loss (torch.Tensor): The loss value.
+        """
+        t = torch.rand(x.shape[0])
+
+        if self.schedule == 'cosine':
+            logsnr_t = self.logsnr_schedule_cosine(t)
+        elif self.schedule == 'shifted_cosine':
+            logsnr_t = self.logsnr_schedule_cosine_shifted(t)
+
+        logsnr_t = logsnr_t.to(x.device)
+        alpha_t = torch.sqrt(torch.sigmoid(logsnr_t)).view(-1, 1, 1, 1).to(x.device)
+        sigma_t = torch.sqrt(torch.sigmoid(-logsnr_t)).view(-1, 1, 1, 1).to(x.device)
+        z_t, eps_t = self.diffuse(x, alpha_t, sigma_t)
+        pred = self.model(z_t, logsnr_t)
+
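+        # For v-prediction, v = alpha_t * eps - sigma_t * x. Combined with
+        # z_t = alpha_t * x + sigma_t * eps, the noise can be recovered as
+        # eps = sigma_t * z_t + alpha_t * v.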
+        if self.pred_param == 'v':
+            eps_pred = sigma_t * z_t + alpha_t * pred
+        else: 
+            eps_pred = pred
+
+        # Apply min-SNR weighting (https://arxiv.org/pdf/2303.09556)
+        snr = torch.exp(logsnr_t).clamp_(max = 5)
+        if self.pred_param == 'v':
+            weight = 1 / (1 + snr)
+        else:
+            weight = 1 / snr
+
+        weight = weight.view(-1, 1, 1, 1)
+
+        loss = torch.mean(weight * (eps_pred - eps_t) ** 2)
+
+        return loss
+
+    def train_loop(self, config, optimizer, train_dataloader, lr_scheduler):
+        """
+        A function to train the model.
+
+        Args:
+        config: The training configuration (see TrainingConfig in train.py).
+        optimizer (torch.optim.Optimizer): The optimizer to use for training.
+        train_dataloader (torch.utils.data.DataLoader): The training data loader.
+        lr_scheduler: The learning rate scheduler.
+        """
+        # Initialize accelerator
+        accelerator = Accelerator(
+            mixed_precision=config.mixed_precision,
+            gradient_accumulation_steps=config.gradient_accumulation_steps,
+            project_dir=os.path.join(config.output_dir, "logs"),
+        )
+        if accelerator.is_main_process:
+            if config.output_dir is not None:
+                os.makedirs(config.output_dir, exist_ok=True)
+            accelerator.init_trackers("train_example")
+
+        model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( 
+            self.model, optimizer, train_dataloader, lr_scheduler
+        )
+
+        # Create an EMA model
+        ema = EMA(
+            model,
+            beta=0.9999,
+            update_after_step=100,
+            update_every=10
+        )
+
+        global_step = 0
+
+        for epoch in range(config.num_epochs):
+            progress_bar = tqdm(total=len(train_dataloader))
+            progress_bar.set_description(f"Epoch {epoch}")
+
+            for step, batch in enumerate(train_dataloader):
+                x = batch["images"]
+
+                with accelerator.accumulate(model):
+                    loss = self.loss(x)
+                    loss = loss.to(next(model.parameters()).dtype)
+                    accelerator.backward(loss)
+                    accelerator.clip_grad_norm_(model.parameters(), max_norm=1.0)
+                    optimizer.step()
+                    lr_scheduler.step()
+                    optimizer.zero_grad()
+
+                # Update EMA model parameters
+                ema.update()
+
+                progress_bar.update(1)
+                logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
+                progress_bar.set_postfix(**logs)
+                accelerator.log(logs, step=global_step)
+                global_step += 1
+
+            # After each epoch, optionally sample some demo images
+            if accelerator.is_main_process:
+                self.model = accelerator.unwrap_model(model)
+                self.model.eval()
+
+                # Make a directory for saving the sampled images
+                os.makedirs(os.path.join(config.output_dir, "images"), exist_ok=True)
+
+                if epoch % config.save_image_epochs == 0 or epoch == config.num_epochs - 1:
+                    sample = self.sample(x[0].unsqueeze(0))
+                    sample = sample.detach().cpu().numpy().transpose(0, 2, 3, 1)
+                    image_path = os.path.join(config.output_dir, "images", f"sample_{epoch}.png")
+                    plt.imsave(image_path, sample[0])
+        
+        # Save the EMA model to disk
+        torch.save(ema.ema_model.module.state_dict(), 'ema_model.pth')
\ No newline at end of file
diff --git a/diffusion/unet.py b/diffusion/unet.py
new file mode 100644
index 0000000..308bdb4
--- /dev/null
+++ b/diffusion/unet.py
@@ -0,0 +1,766 @@
+"""
+This file contains an implementation of the ADM U-Net from the paper 
+"Improved Denoising Diffusion Probabilistic Models" by Nichol and
+Dhariwal.
+
+Reference:
+Nichol, A. Q., & Dhariwal, P. (2021). Improved Denoising Diffusion
+Probabilistic Models. Retrieved from https://arxiv.org/abs/2102.09672
+
+The code is adapted from the official implementation at:
+https://github.com/openai/improved-diffusion/tree/main/improved_diffusion
+
+Other U-Net implementations are also included for comparison.
+"""
+
+from abc import abstractmethod
+
+import math
+
+import numpy as np
+import torch as th
+import torch.nn as nn
+import torch.nn.functional as F
+
+from diffusers import UNet2DModel
+
+from .fp16_util import convert_module_to_f16, convert_module_to_f32
+from .nn import (
+    checkpoint,
+    conv_nd,
+    linear,
+    avg_pool_nd,
+    zero_module,
+    normalization,
+    timestep_embedding,
+)
+
+
+class AttentionPool2d(nn.Module):
+    """
+    Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
+    """
+
+    def __init__(
+        self,
+        spacial_dim: int,
+        embed_dim: int,
+        num_heads_channels: int,
+        output_dim: int = None,
+    ):
+        super().__init__()
+        self.positional_embedding = nn.Parameter(
+            th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5
+        )
+        self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
+        self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
+        self.num_heads = embed_dim // num_heads_channels
+        self.attention = QKVAttention(self.num_heads)
+
+    def forward(self, x):
+        b, c, *_spatial = x.shape
+        x = x.reshape(b, c, -1)  # NC(HW)
+        x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1)  # NC(HW+1)
+        x = x + self.positional_embedding[None, :, :].to(x.dtype)  # NC(HW+1)
+        x = self.qkv_proj(x)
+        x = self.attention(x)
+        x = self.c_proj(x)
+        return x[:, :, 0]
+
+
+class TimestepBlock(nn.Module):
+    """
+    Any module where forward() takes timestep embeddings as a second argument.
+    """
+
+    @abstractmethod
+    def forward(self, x, emb):
+        """
+        Apply the module to `x` given `emb` timestep embeddings.
+        """
+
+
+class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
+    """
+    A sequential module that passes timestep embeddings to the children that
+    support it as an extra input.
+    """
+
+    def forward(self, x, emb):
+        for layer in self:
+            if isinstance(layer, TimestepBlock):
+                x = layer(x, emb)
+            else:
+                x = layer(x)
+        return x
+
+
+class Upsample(nn.Module):
+    """
+    An upsampling layer with an optional convolution.
+
+    :param channels: channels in the inputs and outputs.
+    :param use_conv: a bool determining if a convolution is applied.
+    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+                 upsampling occurs in the inner-two dimensions.
+    """
+
+    def __init__(self, channels, use_conv, dims=2, out_channels=None):
+        super().__init__()
+        self.channels = channels
+        self.out_channels = out_channels or channels
+        self.use_conv = use_conv
+        self.dims = dims
+        if use_conv:
+            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
+
+    def forward(self, x):
+        assert x.shape[1] == self.channels
+        if self.dims == 3:
+            x = F.interpolate(
+                x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
+            )
+        else:
+            x = F.interpolate(x, scale_factor=2, mode="nearest")
+        if self.use_conv:
+            x = self.conv(x)
+        return x
+
+
+class Downsample(nn.Module):
+    """
+    A downsampling layer with an optional convolution.
+
+    :param channels: channels in the inputs and outputs.
+    :param use_conv: a bool determining if a convolution is applied.
+    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+                 downsampling occurs in the inner-two dimensions.
+    """
+
+    def __init__(self, channels, use_conv, dims=2, out_channels=None):
+        super().__init__()
+        self.channels = channels
+        self.out_channels = out_channels or channels
+        self.use_conv = use_conv
+        self.dims = dims
+        stride = 2 if dims != 3 else (1, 2, 2)
+        if use_conv:
+            self.op = conv_nd(
+                dims, self.channels, self.out_channels, 3, stride=stride, padding=1
+            )
+        else:
+            assert self.channels == self.out_channels
+            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
+
+    def forward(self, x):
+        assert x.shape[1] == self.channels
+        return self.op(x)
+
+
+class ResBlock(TimestepBlock):
+    """
+    A residual block that can optionally change the number of channels.
+
+    :param channels: the number of input channels.
+    :param emb_channels: the number of timestep embedding channels.
+    :param dropout: the rate of dropout.
+    :param out_channels: if specified, the number of out channels.
+    :param use_conv: if True and out_channels is specified, use a spatial
+        convolution instead of a smaller 1x1 convolution to change the
+        channels in the skip connection.
+    :param dims: determines if the signal is 1D, 2D, or 3D.
+    :param use_checkpoint: if True, use gradient checkpointing on this module.
+    :param up: if True, use this block for upsampling.
+    :param down: if True, use this block for downsampling.
+    """
+
+    def __init__(
+        self,
+        channels,
+        emb_channels,
+        dropout,
+        out_channels=None,
+        use_conv=False,
+        use_scale_shift_norm=False,
+        dims=2,
+        use_checkpoint=False,
+        up=False,
+        down=False,
+    ):
+        super().__init__()
+        self.channels = channels
+        self.emb_channels = emb_channels
+        self.dropout = dropout
+        self.out_channels = out_channels or channels
+        self.use_conv = use_conv
+        self.use_checkpoint = use_checkpoint
+        self.use_scale_shift_norm = use_scale_shift_norm
+
+        self.in_layers = nn.Sequential(
+            normalization(channels),
+            nn.SiLU(),
+            conv_nd(dims, channels, self.out_channels, 3, padding=1),
+        )
+
+        self.updown = up or down
+
+        if up:
+            self.h_upd = Upsample(channels, False, dims)
+            self.x_upd = Upsample(channels, False, dims)
+        elif down:
+            self.h_upd = Downsample(channels, False, dims)
+            self.x_upd = Downsample(channels, False, dims)
+        else:
+            self.h_upd = self.x_upd = nn.Identity()
+
+        self.emb_layers = nn.Sequential(
+            nn.SiLU(),
+            linear(
+                emb_channels,
+                2 * self.out_channels if use_scale_shift_norm else self.out_channels,
+            ),
+        )
+        self.out_layers = nn.Sequential(
+            normalization(self.out_channels),
+            nn.SiLU(),
+            nn.Dropout(p=dropout),
+            zero_module(
+                conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
+            ),
+        )
+
+        if self.out_channels == channels:
+            self.skip_connection = nn.Identity()
+        elif use_conv:
+            self.skip_connection = conv_nd(
+                dims, channels, self.out_channels, 3, padding=1
+            )
+        else:
+            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
+
+    def forward(self, x, emb):
+        """
+        Apply the block to a Tensor, conditioned on a timestep embedding.
+
+        :param x: an [N x C x ...] Tensor of features.
+        :param emb: an [N x emb_channels] Tensor of timestep embeddings.
+        :return: an [N x C x ...] Tensor of outputs.
+        """
+        return checkpoint(
+            self._forward, (x, emb), self.parameters(), self.use_checkpoint
+        )
+
+    def _forward(self, x, emb):
+        if self.updown:
+            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
+            h = in_rest(x)
+            h = self.h_upd(h)
+            x = self.x_upd(x)
+            h = in_conv(h)
+        else:
+            h = self.in_layers(x)
+        emb_out = self.emb_layers(emb).type(h.dtype)
+        while len(emb_out.shape) < len(h.shape):
+            emb_out = emb_out[..., None]
+        if self.use_scale_shift_norm:
+            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
+            scale, shift = th.chunk(emb_out, 2, dim=1)
+            h = out_norm(h) * (1 + scale) + shift
+            h = out_rest(h)
+        else:
+            h = h + emb_out
+            h = self.out_layers(h)
+        return self.skip_connection(x) + h
+
+
+class AttentionBlock(nn.Module):
+    """
+    An attention block that allows spatial positions to attend to each other.
+
+    Originally ported from here, but adapted to the N-d case.
+    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
+    """
+
+    def __init__(
+        self,
+        channels,
+        num_heads=1,
+        num_head_channels=-1,
+        use_checkpoint=False,
+        use_new_attention_order=False,
+    ):
+        super().__init__()
+        self.channels = channels
+        if num_head_channels == -1:
+            self.num_heads = num_heads
+        else:
+            assert (
+                channels % num_head_channels == 0
+            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
+            self.num_heads = channels // num_head_channels
+        self.use_checkpoint = use_checkpoint
+        self.norm = normalization(channels)
+        self.qkv = conv_nd(1, channels, channels * 3, 1)
+        if use_new_attention_order:
+            # split qkv before split heads
+            self.attention = QKVAttention(self.num_heads)
+        else:
+            # split heads before split qkv
+            self.attention = QKVAttentionLegacy(self.num_heads)
+
+        self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
+
+    def forward(self, x):
+        return checkpoint(self._forward, (x,), self.parameters(), True)
+
+    def _forward(self, x):
+        b, c, *spatial = x.shape
+        x = x.reshape(b, c, -1).to(self.norm.weight.dtype)
+        qkv = self.qkv(self.norm(x))
+        h = self.attention(qkv)
+        h = self.proj_out(h)
+        return (x + h).reshape(b, c, *spatial)
+
+
+def count_flops_attn(model, _x, y):
+    """
+    A counter for the `thop` package to count the operations in an
+    attention operation.
+    Meant to be used like:
+        macs, params = thop.profile(
+            model,
+            inputs=(inputs, timestamps),
+            custom_ops={QKVAttention: QKVAttention.count_flops},
+        )
+    """
+    b, c, *spatial = y[0].shape
+    num_spatial = int(np.prod(spatial))
+    # We perform two matmuls with the same number of ops.
+    # The first computes the weight matrix, the second computes
+    # the combination of the value vectors.
+    matmul_ops = 2 * b * (num_spatial ** 2) * c
+    model.total_ops += th.DoubleTensor([matmul_ops])
+
+
+class QKVAttentionLegacy(nn.Module):
+    """
+    A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
+    """
+
+    def __init__(self, n_heads):
+        super().__init__()
+        self.n_heads = n_heads
+
+    def forward(self, qkv):
+        """
+        Apply QKV attention.
+
+        :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
+        :return: an [N x (H * C) x T] tensor after attention.
+        """
+        bs, width, length = qkv.shape
+        assert width % (3 * self.n_heads) == 0
+        ch = width // (3 * self.n_heads)
+        q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
+        scale = 1 / math.sqrt(math.sqrt(ch))
+        weight = th.einsum(
+            "bct,bcs->bts", q * scale, k * scale
+        )  # More stable with f16 than dividing afterwards
+        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
+        a = th.einsum("bts,bcs->bct", weight, v)
+        return a.reshape(bs, -1, length)
+
+    @staticmethod
+    def count_flops(model, _x, y):
+        return count_flops_attn(model, _x, y)
+
+
+class QKVAttention(nn.Module):
+    """
+    A module which performs QKV attention and splits in a different order.
+    """
+
+    def __init__(self, n_heads):
+        super().__init__()
+        self.n_heads = n_heads
+
+    def forward(self, qkv):
+        """
+        Apply QKV attention.
+
+        :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
+        :return: an [N x (H * C) x T] tensor after attention.
+        """
+        bs, width, length = qkv.shape
+        assert width % (3 * self.n_heads) == 0
+        ch = width // (3 * self.n_heads)
+        q, k, v = qkv.chunk(3, dim=1)
+        scale = 1 / math.sqrt(math.sqrt(ch))
+        weight = th.einsum(
+            "bct,bcs->bts",
+            (q * scale).view(bs * self.n_heads, ch, length),
+            (k * scale).view(bs * self.n_heads, ch, length),
+        )  # More stable with f16 than dividing afterwards
+        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
+        a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
+        return a.reshape(bs, -1, length)
+
+    @staticmethod
+    def count_flops(model, _x, y):
+        return count_flops_attn(model, _x, y)
+
+
+class UNetModel(nn.Module):
+    """
+    The full UNet model with attention and timestep embedding.
+
+    :param in_channels: channels in the input Tensor.
+    :param model_channels: base channel count for the model.
+    :param out_channels: channels in the output Tensor.
+    :param num_res_blocks: number of residual blocks per downsample.
+    :param attention_resolutions: a collection of downsample rates at which
+        attention will take place. May be a set, list, or tuple.
+        For example, if this contains 4, then at 4x downsampling, attention
+        will be used.
+    :param dropout: the dropout probability.
+    :param dropout_from: the downsample factor from which dropout is applied;
+        blocks at smaller downsample factors use no dropout.
+    :param channel_mult: channel multiplier for each level of the UNet.
+    :param conv_resample: if True, use learned convolutions for upsampling and
+        downsampling.
+    :param dims: determines if the signal is 1D, 2D, or 3D.
+    :param num_classes: if specified (as an int), then this model will be
+        class-conditional with `num_classes` classes.
+    :param use_checkpoint: use gradient checkpointing to reduce memory usage.
+    :param num_heads: the number of attention heads in each attention layer.
+    :param num_head_channels: if specified, ignore num_heads and instead use
+                               a fixed channel width per attention head.
+    :param num_heads_upsample: works with num_heads to set a different number
+                               of heads for upsampling. Deprecated.
+    :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
+    :param resblock_updown: use residual blocks for up/downsampling.
+    :param use_new_attention_order: use a different attention pattern for potentially
+                                    increased efficiency.
+    """
+
+    def __init__(
+        self,
+        img_resolution,
+        in_channels,
+        model_channels,
+        out_channels,
+        num_res_blocks,
+        attention_resolutions,
+        dropout=0,
+        dropout_from=16,
+        channel_mult=(1, 2, 4, 8),
+        conv_resample=True,
+        dims=2,
+        num_classes=None,
+        use_checkpoint=False,
+        use_fp16=False,
+        num_heads=1,
+        num_head_channels=-1,
+        num_heads_upsample=-1,
+        use_scale_shift_norm=False,
+        resblock_updown=False,
+        use_new_attention_order=False,
+    ):
+        super().__init__()
+
+        if num_heads_upsample == -1:
+            num_heads_upsample = num_heads
+
+        self.img_resolution = img_resolution
+        self.in_channels = in_channels
+        self.model_channels = model_channels
+        self.out_channels = out_channels
+        self.num_res_blocks = num_res_blocks
+        self.attention_resolutions = attention_resolutions
+        self.dropout = dropout
+        self.channel_mult = channel_mult
+        self.conv_resample = conv_resample
+        self.num_classes = num_classes
+        self.use_checkpoint = use_checkpoint
+        self.dtype = th.float16 if use_fp16 else th.float32
+        self.num_heads = num_heads
+        self.num_head_channels = num_head_channels
+        self.num_heads_upsample = num_heads_upsample
+
+        time_embed_dim = model_channels * 4
+        self.time_embed = nn.Sequential(
+            linear(model_channels, time_embed_dim),
+            nn.SiLU(),
+            linear(time_embed_dim, time_embed_dim),
+        )
+
+        if self.num_classes is not None:
+            self.label_emb = nn.Embedding(num_classes, time_embed_dim)
+
+        ch = input_ch = int(channel_mult[0] * model_channels)
+        self.input_blocks = nn.ModuleList(
+            [TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))]
+        )
+        self._feature_size = ch
+        input_block_chans = [ch]
+        ds = 1
+        for level, (mult, num_blocks) in enumerate(zip(channel_mult, num_res_blocks)):
+            # Apply dropout only once the feature map has been downsampled by
+            # at least `dropout_from`; self.dropout keeps the configured rate.
+            dropout = 0 if ds < dropout_from else self.dropout
+            
+            for _ in range(num_blocks):
+                layers = [
+                    ResBlock(
+                        ch,
+                        time_embed_dim,
+                        dropout,
+                        out_channels=int(mult * model_channels),
+                        dims=dims,
+                        use_checkpoint=use_checkpoint,
+                        use_scale_shift_norm=use_scale_shift_norm,
+                    )
+                ]
+                ch = int(mult * model_channels)
+                if ds in attention_resolutions:
+                    layers.append(
+                        AttentionBlock(
+                            ch,
+                            use_checkpoint=use_checkpoint,
+                            num_heads=num_heads,
+                            num_head_channels=num_head_channels,
+                            use_new_attention_order=use_new_attention_order,
+                        )
+                    )
+                self.input_blocks.append(TimestepEmbedSequential(*layers))
+                self._feature_size += ch
+                input_block_chans.append(ch)
+            if level != len(channel_mult) - 1:
+                out_ch = ch
+                self.input_blocks.append(
+                    TimestepEmbedSequential(
+                        ResBlock(
+                            ch,
+                            time_embed_dim,
+                            dropout,
+                            out_channels=out_ch,
+                            dims=dims,
+                            use_checkpoint=use_checkpoint,
+                            use_scale_shift_norm=use_scale_shift_norm,
+                            down=True,
+                        )
+                        if resblock_updown
+                        else Downsample(
+                            ch, conv_resample, dims=dims, out_channels=out_ch
+                        )
+                    )
+                )
+                ch = out_ch
+                input_block_chans.append(ch)
+                ds *= 2
+                self._feature_size += ch
+
+        self.middle_block = TimestepEmbedSequential(
+            ResBlock(
+                ch,
+                time_embed_dim,
+                dropout,
+                dims=dims,
+                use_checkpoint=use_checkpoint,
+                use_scale_shift_norm=use_scale_shift_norm,
+            ),
+            AttentionBlock(
+                ch,
+                use_checkpoint=use_checkpoint,
+                num_heads=num_heads,
+                num_head_channels=num_head_channels,
+                use_new_attention_order=use_new_attention_order,
+            ),
+            ResBlock(
+                ch,
+                time_embed_dim,
+                dropout,
+                dims=dims,
+                use_checkpoint=use_checkpoint,
+                use_scale_shift_norm=use_scale_shift_norm,
+            ),
+        )
+        self._feature_size += ch
+
+        self.output_blocks = nn.ModuleList([])
+        for level, (mult, num_blocks) in list(enumerate(zip(channel_mult, num_res_blocks)))[::-1]:
+            for i in range(num_blocks + 1):
+                dropout = 0 if ds < dropout_from else self.dropout
+                ich = input_block_chans.pop()
+                layers = [
+                    ResBlock(
+                        ch + ich,
+                        time_embed_dim,
+                        dropout,
+                        out_channels=int(model_channels * mult),
+                        dims=dims,
+                        use_checkpoint=use_checkpoint,
+                        use_scale_shift_norm=use_scale_shift_norm,
+                    )
+                ]
+                ch = int(model_channels * mult)
+                if ds in attention_resolutions:
+                    layers.append(
+                        AttentionBlock(
+                            ch,
+                            use_checkpoint=use_checkpoint,
+                            num_heads=num_heads_upsample,
+                            num_head_channels=num_head_channels,
+                            use_new_attention_order=use_new_attention_order,
+                        )
+                    )
+                if level and i == num_blocks:
+                    out_ch = ch
+                    layers.append(
+                        ResBlock(
+                            ch,
+                            time_embed_dim,
+                            dropout,
+                            out_channels=out_ch,
+                            dims=dims,
+                            use_checkpoint=use_checkpoint,
+                            use_scale_shift_norm=use_scale_shift_norm,
+                            up=True,
+                        )
+                        if resblock_updown
+                        else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
+                    )
+                    ds //= 2
+                self.output_blocks.append(TimestepEmbedSequential(*layers))
+                self._feature_size += ch
+
+        self.out = nn.Sequential(
+            normalization(ch),
+            nn.SiLU(),
+            zero_module(conv_nd(dims, input_ch, out_channels, 3, padding=1)),
+        )
+
+    def convert_to_fp16(self):
+        """
+        Convert the torso of the model to float16.
+        """
+        self.input_blocks.apply(convert_module_to_f16)
+        self.middle_block.apply(convert_module_to_f16)
+        self.output_blocks.apply(convert_module_to_f16)
+
+    def convert_to_fp32(self):
+        """
+        Convert the torso of the model to float32.
+        """
+        self.input_blocks.apply(convert_module_to_f32)
+        self.middle_block.apply(convert_module_to_f32)
+        self.output_blocks.apply(convert_module_to_f32)
+
+    def forward(self, x, timesteps, y=None):
+        """
+        Apply the model to an input batch.
+
+        :param x: an [N x C x ...] Tensor of inputs.
+        :param timesteps: a 1-D batch of timesteps.
+        :param y: an [N] Tensor of labels, if class-conditional.
+        :return: an [N x C x ...] Tensor of outputs.
+        """
+        assert (y is not None) == (
+            self.num_classes is not None
+        ), "must specify y if and only if the model is class-conditional"
+
+        hs = []
+        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
+
+        if self.num_classes is not None:
+            assert y.shape == (x.shape[0],)
+            emb = emb + self.label_emb(y)
+
+        h = x.type(self.dtype)
+        for module in self.input_blocks:
+            h = module(h, emb)
+            hs.append(h)
+        h = self.middle_block(h, emb)
+        for module in self.output_blocks:
+            h = th.cat([h, hs.pop()], dim=1)
+            h = module(h, emb)
+        h = h.type(x.dtype)
+        return self.out(h)
+    
+#----------------------------------------------------------------------------
+# End-to-end compression model as described in the paper 
+# "Simple Diffusion: End-to-End Diffusion for High Resolution Images"
+
+class simpleUNet(nn.Module):
+    def __init__(self,
+        img_resolution,                         # Image resolution at input/output.
+        in_channels,                            # Number of color channels at input.
+        out_channels,                           # Number of color channels at output.
+
+        model_channels          = 192,          # Base multiplier for the number of channels.
+        channel_mult            = [1,2,3,4],    # Per-resolution multipliers for the number of channels.
+        num_res_blocks          = [2,2,2,2],    # Number of residual blocks per resolution.
+        attention_resolutions   = [32,16,8],    # List of resolutions with self-attention.
+        dropout                 = 0.10,         # Dropout probability for the residual blocks.
+        dropout_from            = 16,           # Start applying dropout from this downsample factor onwards.
+
+        downsample              = 4,            # Downsample factor for the initial convolution layer.
+
+        fp16                    = False,        # Use float16 precision.
+    ):
+        super().__init__()
+
+        # Initial convolution layer to downsample the input image by a factor provided
+        # Assumes that the input image is square and divisible by the downsample factor
+        self.conv_down = nn.Conv2d(
+            in_channels=in_channels,
+            out_channels=in_channels,
+            kernel_size=downsample,
+            stride=downsample,
+            bias=False,
+        )
+
+        self.conv_up = nn.ConvTranspose2d(
+            in_channels=in_channels,
+            out_channels=in_channels,
+            kernel_size=downsample,
+            stride=downsample,
+            bias=False,
+        )
+
+        # ADM backbone
+        self.unet = UNetModel(
+            img_resolution=img_resolution,
+            in_channels=in_channels,
+            out_channels=out_channels,
+            model_channels=model_channels,
+            channel_mult=channel_mult,
+            num_res_blocks=num_res_blocks,
+            attention_resolutions=attention_resolutions,
+            dropout=dropout,
+            dropout_from=dropout_from,
+            use_fp16=fp16,
+        )
+
+    def forward(self, x, noise_labels):
+        # Downsample the input image
+        x = self.conv_down(x)
+
+        # Pass through the UNet model
+        x = self.unet(x, noise_labels)
+
+        # Upsample the output image
+        x = self.conv_up(x)
+
+        return x
+
+
+#----------------------------------------------------------------------------
+# Adaptation of HuggingFace's UNet2DModel to use with the 
+# simpleDiffusion pipeline
+
+class UNet2D(UNet2DModel):
+    def __init__(self, sample_size, in_channels, out_channels, **kwargs):
+        super().__init__(sample_size, in_channels, out_channels, **kwargs)
+
+    def forward(self, x, noise_labels):
+        x = super().forward(x, noise_labels, return_dict=False)
+        return x[0]
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..a90233d
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,132 @@
+accelerate==0.31.0
+accelerator==2024.6.11
+aiohttp==3.9.5
+aiosignal==1.3.1
+asttokens==2.4.1
+async-timeout==4.0.3
+attrs==23.2.0
+backcall==0.2.0
+beartype==0.18.5
+bottle==0.12.25
+certifi==2024.6.2
+charset-normalizer==3.3.2
+colorama==0.4.6
+comet-ml==3.43.2
+comm==0.2.2
+configobj==5.0.8
+contourpy==1.1.1
+cycler==0.12.1
+datasets==2.20.0
+debugpy==1.8.1
+decorator==5.1.1
+diffusers==0.29.1
+dill==0.3.8
+dulwich==0.22.1
+einops==0.8.0
+ema-pytorch==0.4.8
+everett==3.1.0
+exceptiongroup==1.2.1
+executing==2.0.1
+filelock==3.14.0
+fonttools==4.53.0
+frozenlist==1.4.1
+fsspec==2024.5.0
+huggingface-hub==0.23.3
+idna==3.7
+importlib_metadata==7.1.0
+importlib_resources==6.4.0
+iniconfig==2.0.0
+intel-openmp==2021.4.0
+ipykernel==6.29.4
+ipython==8.13.0
+jedi==0.19.1
+Jinja2==3.1.4
+jsonschema==4.22.0
+jsonschema-specifications==2023.12.1
+jupyter_client==8.6.2
+jupyter_core==5.7.2
+kiwisolver==1.4.5
+lz4==4.3.3
+markdown-it-py==3.0.0
+MarkupSafe==2.1.5
+matplotlib==3.7.5
+matplotlib-inline==0.1.7
+mdurl==0.1.2
+mkl==2021.4.0
+mpmath==1.3.0
+multidict==6.0.5
+multiprocess==0.70.16
+nest-asyncio==1.6.0
+networkx==3.1
+numpy==1.24.4
+nvidia-cublas-cu12==12.1.3.1
+nvidia-cuda-cupti-cu12==12.1.105
+nvidia-cuda-nvrtc-cu12==12.1.105
+nvidia-cuda-runtime-cu12==12.1.105
+nvidia-cudnn-cu12==8.9.2.26
+nvidia-cufft-cu12==11.0.2.54
+nvidia-curand-cu12==10.3.2.106
+nvidia-cusolver-cu12==11.4.5.107
+nvidia-cusparse-cu12==12.1.0.106
+nvidia-nccl-cu12==2.20.5
+nvidia-nvjitlink-cu12==12.5.40
+nvidia-nvtx-cu12==12.1.105
+packaging==24.0
+pandas==2.0.3
+parso==0.8.4
+patch-ng==1.17.4
+pexpect==4.9.0
+pickleshare==0.7.5
+pillow==10.3.0
+pkgutil_resolve_name==1.3.10
+platformdirs==4.2.2
+pluggy==1.5.0
+prompt_toolkit==3.0.47
+psutil==5.9.8
+ptyprocess==0.7.0
+pure-eval==0.2.2
+pyarrow==16.1.0
+pyarrow-hotfix==0.6
+Pygments==2.18.0
+pyparsing==3.1.2
+pytest==8.2.2
+python-box==6.1.0
+python-dateutil==2.9.0.post0
+pytorch-fid==0.3.0
+pytz==2024.1
+PyWavelets==1.4.1
+PyYAML==6.0.1
+pyzmq==26.0.3
+referencing==0.35.1
+regex==2024.5.15
+requests==2.32.3
+requests-toolbelt==1.0.0
+rich==13.7.1
+rpds-py==0.18.1
+safetensors==0.4.3
+scipy==1.10.1
+semantic-version==2.10.0
+sentry-sdk==2.6.0
+setproctitle==1.3.3
+simplejson==3.19.2
+six==1.16.0
+stack-data==0.6.3
+sympy==1.12.1
+tbb==2021.12.0
+tomli==2.0.1
+torch==2.3.1
+torchvision==0.18.1
+tornado==6.4.1
+tqdm==4.66.4
+traitlets==5.14.3
+triton==2.3.1
+typing_extensions==4.12.1
+tzdata==2024.1
+urllib3==2.2.1
+waitress==3.0.0
+wcwidth==0.2.13
+wrapt==1.16.0
+wurlitzer==3.1.1
+xxhash==3.4.1
+yarl==1.9.4
+zipp==3.19.2
diff --git a/train.py b/train.py
new file mode 100644
index 0000000..5ecf7ab
--- /dev/null
+++ b/train.py
@@ -0,0 +1,94 @@
+"""
+A script for training a diffusion model on the Smithsonian Butterflies dataset following the simpleDiffusion paradigm.
+
+This script uses the ADM-style simpleUNet backbone from diffusion.unet and the simpleDiffusion trainer from diffusion.simple_diffusion.
+"""
+
+from diffusion.unet import UNet2D, simpleUNet
+from diffusion.simple_diffusion import simpleDiffusion
+
+from datasets import load_dataset
+from torchvision import transforms
+import torch
+from diffusers.optimization import get_cosine_schedule_with_warmup
+
+
+class TrainingConfig:
+    image_size = 128  # the generated image resolution
+    train_batch_size = 4
+    num_epochs = 200
+    gradient_accumulation_steps = 1
+    learning_rate = 5e-5
+    lr_warmup_steps = 10000
+    save_image_epochs = 50
+    mixed_precision = "fp16"  # `no` for float32, `fp16` for automatic mixed precision
+    output_dir = "ddpm-butterflies-128"  # the model name locally and on the HF Hub
+    overwrite_output_dir = True  # overwrite the old model when re-running the notebook
+    seed = 0
+
+
+def main():
+    
+    config = TrainingConfig
+
+    dataset_name = "huggan/smithsonian_butterflies_subset"
+
+    dataset = load_dataset(dataset_name, split="train")
+
+    preprocess = transforms.Compose(
+        [
+            transforms.Resize((config.image_size, config.image_size)),
+            transforms.RandomHorizontalFlip(),
+            transforms.ToTensor(),
+            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
+        ]
+    )
+
+    def transform(examples):
+        images = [preprocess(image.convert("RGB")) for image in examples["image"]]
+        return {"images": images}
+
+    dataset.set_transform(transform)
+
+    train_loader = torch.utils.data.DataLoader(
+        dataset,
+        batch_size=config.train_batch_size,
+        shuffle=True,
+    )
+
+    unet = simpleUNet(
+        img_resolution=config.image_size,
+        in_channels=3,
+        out_channels=3,
+        model_channels=192,
+        channel_mult=(1, 1, 2, 2, 4, 4),
+        num_res_blocks=(2, 2, 2, 2, 2, 2),
+        attention_resolutions=(16, 8),
+        dropout=0.1,
+        dropout_from=16,
+        downsample=1,
+        fp16=True
+    )
+
+    optimizer = torch.optim.Adam(unet.parameters(), lr=config.learning_rate)
+    lr_scheduler = get_cosine_schedule_with_warmup(
+        optimizer,
+        num_warmup_steps=config.lr_warmup_steps,
+        num_training_steps=len(train_loader) * config.num_epochs,
+    )
+
+    diffusion_model = simpleDiffusion(
+        unet=unet,
+        image_size=config.image_size
+    )
+
+    diffusion_model.train_loop(
+        config=config,
+        optimizer=optimizer,
+        train_dataloader=train_loader,
+        lr_scheduler=lr_scheduler
+    )
+
+
+if __name__ == '__main__':
+    main()
\ No newline at end of file
diff --git a/unet_configurations.txt b/unet_configurations.txt
new file mode 100644
index 0000000..b41c881
--- /dev/null
+++ b/unet_configurations.txt
@@ -0,0 +1,39 @@
+UNet2D:
+    unet = UNet2D(
+        sample_size=config.image_size,  # the target image resolution
+        in_channels=3,  # the number of input channels, 3 for RGB images
+        out_channels=3,  # the number of output channels
+        layers_per_block=2,  # how many ResNet layers to use per UNet block
+        block_out_channels=(128, 128, 256, 256, 512, 512),  # the number of output channels for each UNet block
+        down_block_types=(
+            "DownBlock2D",
+            "DownBlock2D",
+            "DownBlock2D",
+            "DownBlock2D",
+            "AttnDownBlock2D",
+            "DownBlock2D",
+        ),
+        up_block_types=(
+            "UpBlock2D",
+            "AttnUpBlock2D",
+            "UpBlock2D",
+            "UpBlock2D",
+            "UpBlock2D",
+            "UpBlock2D",
+        ),
+    )
+    
+simpleUNet:
+    unet = simpleUNet(
+        img_resolution=config.image_size,
+        in_channels=3,
+        out_channels=3,
+        model_channels=192,
+        channel_mult=(1, 1, 2, 2, 4, 4),
+        num_res_blocks=(2, 2, 2, 2, 2, 2),
+        attention_resolutions=(16, 8),
+        dropout=0.1,
+        dropout_from=16,
+        downsample=1,
+        fp16=True
+    )
\ No newline at end of file