Skip to content

Commit

Permalink
[CLIPGuidedStableDiffusion] support DDIM scheduler (huggingface#1190)
Browse files Browse the repository at this point in the history
add ddim in clip guided
  • Loading branch information
patil-suraj authored Nov 9, 2022
1 parent 663f0c1 commit cd77a03
Showing 1 changed file with 26 additions and 4 deletions.
30 changes: 26 additions & 4 deletions examples/community/clip_guided_stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,14 @@
from torch import nn
from torch.nn import functional as F

from diffusers import AutoencoderKL, DiffusionPipeline, LMSDiscreteScheduler, PNDMScheduler, UNet2DConditionModel
from diffusers import (
AutoencoderKL,
DDIMScheduler,
DiffusionPipeline,
LMSDiscreteScheduler,
PNDMScheduler,
UNet2DConditionModel,
)
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
from torchvision import transforms
from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextModel, CLIPTokenizer
Expand Down Expand Up @@ -56,7 +63,7 @@ def __init__(
clip_model: CLIPModel,
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
scheduler: Union[PNDMScheduler, LMSDiscreteScheduler],
scheduler: Union[PNDMScheduler, LMSDiscreteScheduler, DDIMScheduler],
feature_extractor: CLIPFeatureExtractor,
):
super().__init__()
Expand Down Expand Up @@ -123,7 +130,7 @@ def cond_fn(
# predict the noise residual
noise_pred = self.unet(latent_model_input, timestep, encoder_hidden_states=text_embeddings).sample

if isinstance(self.scheduler, PNDMScheduler):
if isinstance(self.scheduler, (PNDMScheduler, DDIMScheduler)):
alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
beta_prod_t = 1 - alpha_prod_t
# compute predicted original sample from predicted noise also called
Expand Down Expand Up @@ -176,6 +183,7 @@ def __call__(
num_inference_steps: Optional[int] = 50,
guidance_scale: Optional[float] = 7.5,
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
clip_guidance_scale: Optional[float] = 100,
clip_prompt: Optional[Union[str, List[str]]] = None,
num_cutouts: Optional[int] = 4,
Expand Down Expand Up @@ -275,6 +283,20 @@ def __call__(
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma

# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta

# check if the scheduler accepts generator
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
if accepts_generator:
extra_step_kwargs["generator"] = generator

for i, t in enumerate(self.progress_bar(timesteps_tensor)):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
Expand Down Expand Up @@ -306,7 +328,7 @@ def __call__(
)

# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents).prev_sample
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

# scale and decode the image latents with vae
latents = 1 / 0.18215 * latents
Expand Down

0 comments on commit cd77a03

Please sign in to comment.