Skip to content

Commit

Permalink
Stable diffusion pipeline (huggingface#168)
Browse files Browse the repository at this point in the history
* add stable diffusion pipeline

* get rid of multiple if/else

* batch_size is unused

* add type hints

* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

* fix some bugs

Co-authored-by: Patrick von Platen <[email protected]>
  • Loading branch information
patil-suraj and patrickvonplaten authored Aug 14, 2022
1 parent 92b6dbb commit 5782e03
Show file tree
Hide file tree
Showing 10 changed files with 187 additions and 21 deletions.
3 changes: 2 additions & 1 deletion src/diffusers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@


if is_transformers_available():
from .pipelines import LDMTextToImagePipeline
from .pipelines import LDMTextToImagePipeline, StableDiffusionPipeline

else:
from .utils.dummy_transformers_objects import *
1 change: 1 addition & 0 deletions src/diffusers/pipelines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@

if is_transformers_available():
from .latent_diffusion import LDMTextToImagePipeline
from .stable_diffusion import StableDiffusionPipeline
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,10 @@ def __call__(

# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_kwrags = {}

extra_kwargs = {}
if accepts_eta:
extra_kwrags["eta"] = eta
extra_kwargs["eta"] = eta

for t in tqdm(self.scheduler.timesteps):
if guidance_scale == 1.0:
Expand All @@ -86,7 +87,7 @@ def __call__(
noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)

# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_kwrags)["prev_sample"]
latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs)["prev_sample"]

# scale and decode the image latents with vae
latents = 1 / 0.18215 * latents
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,16 @@ def __call__(

# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_kwrags = {}

extra_kwargs = {}
if accepts_eta:
extra_kwrags["eta"] = eta
extra_kwargs["eta"] = eta

for t in tqdm(self.scheduler.timesteps):
# predict the noise residual
noise_prediction = self.unet(latents, t)["sample"]
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_prediction, t, latents, **extra_kwrags)["prev_sample"]
latents = self.scheduler.step(noise_prediction, t, latents, **extra_kwargs)["prev_sample"]

# decode the image latents with the VAE
image = self.vqvae.decode(latents)
Expand Down
5 changes: 5 additions & 0 deletions src/diffusers/pipelines/stable_diffusion/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from ...utils import is_transformers_available


if is_transformers_available():
from .pipeline_stable_diffusion import StableDiffusionPipeline
115 changes: 115 additions & 0 deletions src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
import inspect
from typing import List, Optional, Union

import torch

from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer

from ...models import AutoencoderKL, UNet2DConditionModel
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import DDIMScheduler, PNDMScheduler


class StableDiffusionPipeline(DiffusionPipeline):
def __init__(
self,
vae: AutoencoderKL,
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
scheduler: Union[DDIMScheduler, PNDMScheduler],
):
super().__init__()
scheduler = scheduler.set_format("pt")
self.register_modules(vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler)

@torch.no_grad()
def __call__(
self,
prompt: Union[str, List[str]],
num_inference_steps: Optional[int] = 50,
guidance_scale: Optional[float] = 1.0,
eta: Optional[float] = 0.0,
generator: Optional[torch.Generator] = None,
torch_device: Optional[Union[str, torch.device]] = None,
output_type: Optional[str] = "pil",
):
if torch_device is None:
torch_device = "cuda" if torch.cuda.is_available() else "cpu"

if isinstance(prompt, str):
batch_size = 1
elif isinstance(prompt, list):
batch_size = len(prompt)
else:
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

self.unet.to(torch_device)
self.vae.to(torch_device)
self.text_encoder.to(torch_device)

# get prompt text embeddings
text_input = self.tokenizer(prompt, padding=True, truncation=True, return_tensors="pt")
text_embeddings = self.text_encoder(text_input.input_ids.to(torch_device))[0]

# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance:
max_length = text_input.input_ids.shape[-1]
uncond_input = self.tokenizer(
[""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
)
uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(torch_device))[0]

# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

# get the intial random noise
latents = torch.randn(
(batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
generator=generator,
)
latents = latents.to(torch_device)

# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_kwargs = {}
if accepts_eta:
extra_kwargs["eta"] = eta

self.scheduler.set_timesteps(num_inference_steps)

for t in tqdm(self.scheduler.timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]

# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs)["prev_sample"]

# scale and decode the image latents with vae
latents = 1 / 0.18215 * latents
image = self.vae.decode(latents)

image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()
if output_type == "pil":
image = self.numpy_to_pil(image)

return {"sample": image}
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,12 @@

class KarrasVePipeline(DiffusionPipeline):
"""
Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2].
Use Algorithm 2 and the VE column of Table 1 from [1] for reference.
Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2]. Use Algorithm 2 and
the VE column of Table 1 from [1] for reference.
[1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364
[2] Song, Yang, et al. "Score-based generative modeling through stochastic differential equations." https://arxiv.org/abs/2011.13456
[1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models."
https://arxiv.org/abs/2206.00364 [2] Song, Yang, et al. "Score-based generative modeling through stochastic
differential equations." https://arxiv.org/abs/2011.13456
"""

unet: UNet2DModel
Expand Down
20 changes: 10 additions & 10 deletions src/diffusers/schedulers/scheduling_karras_ve.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,12 @@

class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
"""
Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2].
Use Algorithm 2 and the VE column of Table 1 from [1] for reference.
Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2]. Use Algorithm 2 and
the VE column of Table 1 from [1] for reference.
[1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364
[2] Song, Yang, et al. "Score-based generative modeling through stochastic differential equations." https://arxiv.org/abs/2011.13456
[1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models."
https://arxiv.org/abs/2206.00364 [2] Song, Yang, et al. "Score-based generative modeling through stochastic
differential equations." https://arxiv.org/abs/2011.13456
"""

@register_to_config
Expand All @@ -43,10 +44,9 @@ def __init__(
tensor_format="pt",
):
"""
For more details on the parameters, see the original paper's Appendix E.:
"Elucidating the Design Space of Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364.
The grid search values used to find the optimal {s_noise, s_churn, s_min, s_max} for a specific model
are described in Table 5 of the paper.
For more details on the parameters, see the original paper's Appendix E.: "Elucidating the Design Space of
Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364. The grid search values used to find the
optimal {s_noise, s_churn, s_min, s_max} for a specific model are described in Table 5 of the paper.
Args:
sigma_min (`float`): minimum noise magnitude
Expand Down Expand Up @@ -81,8 +81,8 @@ def set_timesteps(self, num_inference_steps):

def add_noise_to_input(self, sample, sigma, generator=None):
"""
Explicit Langevin-like "churn" step of adding noise to the sample according to
a factor gamma_i ≥ 0 to reach a higher noise level sigma_hat = sigma_i + gamma_i*sigma_i.
Explicit Langevin-like "churn" step of adding noise to the sample according to a factor gamma_i ≥ 0 to reach a
higher noise level sigma_hat = sigma_i + gamma_i*sigma_i.
"""
if self.s_min <= sigma <= self.s_max:
gamma = min(self.s_churn / self.num_inference_steps, 2**0.5 - 1)
Expand Down
7 changes: 7 additions & 0 deletions src/diffusers/utils/dummy_transformers_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,10 @@ class LDMTextToImagePipeline(metaclass=DummyObject):

def __init__(self, *args, **kwargs):
requires_backends(self, ["transformers"])


class StableDiffusionPipeline(metaclass=DummyObject):
_backends = ["transformers"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["transformers"])
34 changes: 34 additions & 0 deletions tests/test_modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@
from diffusers.testing_utils import floats_tensor, slow, torch_device
from diffusers.training_utils import EMAModel

from ..src.diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline


torch.backends.cuda.matmul.allow_tf32 = False

Expand Down Expand Up @@ -839,6 +841,38 @@ def test_ldm_text2img_fast(self):
expected_slice = np.array([0.3163, 0.8670, 0.6465, 0.1865, 0.6291, 0.5139, 0.2824, 0.3723, 0.4344])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

@slow
def test_stable_diffusion(self):
ldm = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1-diffusers")

prompt = "A painting of a squirrel eating a burger"
generator = torch.manual_seed(0)
image = ldm([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=20, output_type="numpy")[
"sample"
]

image_slice = image[0, -3:, -3:, -1]

# TODO: update the expected_slice
assert image.shape == (1, 512, 512, 3)
expected_slice = np.array([0.9256, 0.9340, 0.8933, 0.9361, 0.9113, 0.8727, 0.9122, 0.8745, 0.8099])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

@slow
def test_stable_diffusion_fast(self):
ldm = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1-diffusers")

prompt = "A painting of a squirrel eating a burger"
generator = torch.manual_seed(0)
image = ldm([prompt], generator=generator, num_inference_steps=1, output_type="numpy")["sample"]

image_slice = image[0, -3:, -3:, -1]

# TODO: update the expected_slice
assert image.shape == (1, 512, 512, 3)
expected_slice = np.array([0.3163, 0.8670, 0.6465, 0.1865, 0.6291, 0.5139, 0.2824, 0.3723, 0.4344])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

@slow
def test_score_sde_ve_pipeline(self):
model_id = "google/ncsnpp-church-256"
Expand Down

0 comments on commit 5782e03

Please sign in to comment.