Knowledge-distilled, smaller versions of Stable Diffusion: an unofficial implementation of the approach described in BK-SDM.
These distillation-trained models produce images of similar quality to the full-sized Stable-Diffusion model while being significantly faster and smaller.
- data.py contains scripts to download the training data.
- distill_training.py trains the U-net using the methods described in the paper. It may need additional configuration depending on the model type you want to train (sd_small/sd_tiny), the batch size, hyperparameters, etc.
The basic training code was sourced from the Hugging Face 🤗 diffusers library.
Knowledge distillation can be loosely compared to a teacher guiding a student step by step. A large teacher model is first trained on large amounts of data; a smaller student model is then trained on a smaller dataset, with the objective of imitating the teacher's outputs in addition to the usual training objective on that dataset.
For the Knowledge-Distillation training, we used SG161222/Realistic_Vision_V4.0's U-net as the teacher model with a subset of recastai/LAION-art-EN-improved-captions as training data.
The final training loss is a weighted sum of three terms: the MSE loss between the noise predicted by the teacher U-net and the noise predicted by the student U-net, the MSE loss between the noise actually added and the noise predicted by the student U-net, and the sum of MSE losses between the outputs of the teacher and student U-nets after every block. The two knowledge-distillation terms are scaled by configurable weights.
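Total loss:
$$\mathcal{L} = \mathcal{L}_{\text{Task}} + \lambda_{\text{OutKD}}\,\mathcal{L}_{\text{OutKD}} + \lambda_{\text{FeatKD}}\,\mathcal{L}_{\text{FeatKD}}$$
Task loss (i.e., MSE loss between the added noise and the noise predicted by the student U-net):
$$\mathcal{L}_{\text{Task}} = \mathbb{E}_{z,\epsilon,y,t}\left[\lVert \epsilon - \epsilon_S(z_t, y, t) \rVert_2^2\right]$$
Knowledge-distillation output loss (i.e., MSE loss between the final outputs of the teacher and student U-nets):
$$\mathcal{L}_{\text{OutKD}} = \mathbb{E}_{z,\epsilon,y,t}\left[\lVert \epsilon_T(z_t, y, t) - \epsilon_S(z_t, y, t) \rVert_2^2\right]$$
Feature-level knowledge-distillation loss (i.e., MSE loss between the outputs of each block in the U-net):
$$\mathcal{L}_{\text{FeatKD}} = \mathbb{E}_{z,\epsilon,y,t}\left[\sum_{l} \lVert f_T^{l}(z_t, y, t) - f_S^{l}(z_t, y, t) \rVert_2^2\right]$$
Here $z_t$ is the noised latent at timestep $t$, $y$ the text conditioning, $\epsilon$ the added noise, $\epsilon_T$ and $\epsilon_S$ the noise predicted by the teacher and student U-nets, $f_T^{l}$ and $f_S^{l}$ the outputs of block $l$, and $\lambda_{\text{OutKD}}$ and $\lambda_{\text{FeatKD}}$ correspond to --output_weight and --feature_weight.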
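The three loss terms described above can be sketched in PyTorch as follows. This is a minimal illustration, not the code in distill_training.py: the names register_feature_hooks and distillation_loss are hypothetical, the sketch assumes each hooked block returns a single tensor, and it assumes student and teacher blocks are paired one-to-one with matching output shapes. F.mse_loss averages over elements, so each term matches the squared L2 norm above up to a constant factor.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def register_feature_hooks(blocks, store):
    """Append each block's output to `store` on every forward pass.

    Assumption: each block returns a single tensor. Real diffusers U-net
    blocks may return tuples, which would need unpacking first.
    """
    handles = []
    for block in blocks:
        handles.append(
            block.register_forward_hook(
                lambda module, inputs, output: store.append(output)
            )
        )
    return handles


def distillation_loss(noise, student_pred, teacher_pred,
                      student_feats, teacher_feats,
                      output_weight=0.5, feature_weight=0.5):
    """Weighted sum of task, output-level KD and feature-level KD losses."""
    # Task loss: student prediction vs. the noise that was actually added.
    task_loss = F.mse_loss(student_pred.float(), noise.float())
    # Output-level KD: student prediction vs. (frozen) teacher prediction.
    output_kd = F.mse_loss(student_pred.float(), teacher_pred.detach().float())
    # Feature-level KD: MSE summed over the paired block outputs.
    feature_kd = sum(
        F.mse_loss(s.float(), t.detach().float())
        for s, t in zip(student_feats, teacher_feats)
    )
    total = task_loss + output_weight * output_kd + feature_weight * feature_kd
    return total, task_loss, output_kd, feature_kd
```

In a training step, one would register hooks on the teacher's and the student's blocks with two separate lists, run both U-nets on the same noisy latents (the teacher under torch.no_grad()), pass the predictions and collected features to distillation_loss, and clear both lists before the next step.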
Number of parameters (reported using torchinfo):
- Normal Stable Diffusion U-net: 859,520,964
- SD_Small U-net: 579,384,964
- SD_Tiny U-net: 323,384,964
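To put these counts in perspective, the reduction relative to the normal U-net can be worked out directly. The dictionary keys below are just labels chosen for this sketch; the numbers are the torchinfo counts reported above.

```python
# U-net parameter counts as reported above (torchinfo).
PARAMS = {
    "sd_base": 859_520_964,
    "sd_small": 579_384_964,
    "sd_tiny": 323_384_964,
}


def reduction_vs_base(name):
    """Fraction of the normal U-net's parameters removed by a distilled variant."""
    return 1 - PARAMS[name] / PARAMS["sd_base"]


for name in ("sd_small", "sd_tiny"):
    print(f"{name}: {reduction_vs_base(name):.1%} fewer parameters")
# sd_small: 32.6% fewer parameters
# sd_tiny: 62.4% fewer parameters
```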
```python
import torch
from diffusers import DiffusionPipeline
from diffusers import DPMSolverMultistepScheduler
from torch import Generator

path = 'segmind/small-sd'  # Path to the appropriate model type

# Insert your prompt below.
prompt = "Faceshot Portrait of pretty young (18-year-old) Caucasian wearing a high neck sweater, (masterpiece, extremely detailed skin, photorealistic, heavy shadow, dramatic and cinematic lighting, key light, fill light), sharp focus, BREAK epicrealism"

# Insert negative prompt below. We recommend using this negative prompt for best results.
negative_prompt = "(deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime:1.4), text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck"

torch.set_grad_enabled(False)
torch.backends.cudnn.benchmark = True

# The code below runs on a GPU. For CPU inference, pass 'cpu' as the device
# everywhere and set the dtype to torch.float32.
with torch.inference_mode():
    gen = Generator("cuda")
    gen.manual_seed(1674753452)
    pipe = DiffusionPipeline.from_pretrained(path, torch_dtype=torch.float16, safety_checker=None, requires_safety_checker=False)
    pipe.to('cuda')
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
    pipe.unet.to(device='cuda', dtype=torch.float16, memory_format=torch.channels_last)
    img = pipe(prompt=prompt, negative_prompt=negative_prompt, width=512, height=512, num_inference_steps=25, guidance_scale=7, num_images_per_prompt=1, generator=gen).images[0]
    img.save("image.png")
```
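The comment in the example above asks you to switch the device and dtype by hand for CPU inference. A small helper can make that choice automatically; select_device_and_dtype is a hypothetical name, and the sketch simply mirrors the example's choices (float16 on CUDA, float32 on CPU, where half precision is often slow or unsupported for many ops).

```python
import torch


def select_device_and_dtype():
    """Return (device, dtype) for inference: float16 on CUDA, float32 on CPU."""
    if torch.cuda.is_available():
        return "cuda", torch.float16
    return "cpu", torch.float32


device, dtype = select_device_and_dtype()
# Hypothetical usage with the example above:
# gen = Generator(device)
# pipe = DiffusionPipeline.from_pretrained(path, torch_dtype=dtype,
#                                          safety_checker=None, requires_safety_checker=False)
# pipe.to(device)
# pipe.unet.to(device=device, dtype=dtype, memory_format=torch.channels_last)
```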
Training instructions are similar to those of the diffusers text-to-image fine-tuning script, apart from a few extra parameters:
- --distill_level: one of "sd_small" or "sd_tiny", depending on which type of model is to be trained.
- --output_weight: a floating-point number by which the output-level KD loss is scaled.
- --feature_weight: a floating-point number by which the feature-level KD loss is scaled.

Also, snr_gamma has been removed.
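One way such flags could be declared with argparse is sketched below. This is a hypothetical reconstruction for illustration only: the function name add_distill_args is invented, the flags are marked required here for simplicity, and the actual parser in distill_training.py may use different defaults or help text.

```python
import argparse


def add_distill_args(parser):
    # Hypothetical declarations of the three extra training flags.
    parser.add_argument("--distill_level", choices=["sd_small", "sd_tiny"], required=True,
                        help="Which compressed U-net architecture to train.")
    parser.add_argument("--output_weight", type=float, required=True,
                        help="Scale applied to the output-level KD loss.")
    parser.add_argument("--feature_weight", type=float, required=True,
                        help="Scale applied to the feature-level KD loss.")
    return parser


args = add_distill_args(argparse.ArgumentParser()).parse_args(
    ["--distill_level=sd_small", "--output_weight=0.5", "--feature_weight=0.5"]
)
print(args.distill_level, args.output_weight, args.feature_weight)
```

With a parser like this, the flag spelling on the command line must match exactly, which is why the example command uses --feature_weight with an underscore.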
An example:
```bash
export MODEL_NAME="SG161222/Realistic_Vision_V4.0"
export DATASET_NAME="fantasyfish/laion-art"

accelerate launch --mixed_precision="fp16" distill_training.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --dataset_name=$DATASET_NAME \
  --use_ema \
  --resolution=512 --center_crop --random_flip \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --gradient_checkpointing \
  --max_train_steps=15000 \
  --distill_level="sd_small" \
  --output_weight=0.5 \
  --feature_weight=0.5 \
  --learning_rate=1e-05 \
  --max_grad_norm=1 \
  --lr_scheduler="constant" --lr_warmup_steps=0 \
  --output_dir="sd-laion-art"
```
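The batch-related flags in this command combine multiplicatively. The sketch below works out the effective batch size and the number of training samples consumed, assuming (as in the diffusers text-to-image script this code is based on) that --max_train_steps counts optimizer updates, and assuming a single GPU unless num_processes is changed; samples_seen is a hypothetical helper name.

```python
def samples_seen(train_batch_size, grad_accum_steps, max_train_steps, num_processes=1):
    """Return (effective batch size, total samples consumed) for a training run."""
    effective_batch = train_batch_size * grad_accum_steps * num_processes
    return effective_batch, effective_batch * max_train_steps


# Values from the example command, on a single GPU.
print(samples_seen(1, 4, 15000))  # (4, 60000)
```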
- The trained "sd-small" version of the model is available at this Hugging Face 🤗 repo
- The trained "sd-tiny" version of the model is available at this Hugging Face 🤗 repo
- A version of the "sd-tiny" model fine-tuned on portrait images is available at this Hugging Face 🤗 repo
Below are some images generated with the sd-tiny model fine-tuned on portrait images.
Link to the model -> Hugging Face 🤗 repo
- Up to 100% faster inference
- Up to 30% lower VRAM footprint
- Faster DreamBooth and LoRA training
- The distilled models are at an early stage, and their outputs may not be of production quality yet.
- These models may not be the best general-purpose models. They work best when fine-tuned or LoRA-trained on specific concepts or styles.
- The distilled models are not yet very good at composability or multi-concept prompts.
- SDXL distilled models and code.
- Further fine-tuned SD-1.5 base models for better composability and generalization.
- Apply TensorRT and/or AITemplate for further acceleration.
- Explore quantization-aware training (QAT) during the distillation process.
```bibtex
@article{kim2023architectural,
  title={On Architectural Compression of Text-to-Image Diffusion Models},
  author={Kim, Bo-Kyeong and Song, Hyoung-Kyu and Castells, Thibault and Choi, Shinkook},
  journal={arXiv preprint arXiv:2305.15798},
  year={2023},
  url={https://arxiv.org/abs/2305.15798}
}

@article{Kim_2023_ICMLW,
  title={BK-SDM: Architecturally Compressed Stable Diffusion for Efficient Text-to-Image Generation},
  author={Kim, Bo-Kyeong and Song, Hyoung-Kyu and Castells, Thibault and Choi, Shinkook},
  journal={ICML Workshop on Efficient Systems for Foundation Models (ES-FoMo)},
  year={2023},
  url={https://openreview.net/forum?id=bOVydU0XKC}
}
```