forked from huggingface/diffusers
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Stable diffusion pipeline (huggingface#168)
* add stable diffusion pipeline * get rid of multiple if/else * batch_size is unused * add type hints * Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py * fix some bugs Co-authored-by: Patrick von Platen <[email protected]>
- Loading branch information
1 parent
92b6dbb
commit 5782e03
Showing
10 changed files
with
187 additions
and
21 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
from ...utils import is_transformers_available | ||
|
||
|
||
if is_transformers_available(): | ||
from .pipeline_stable_diffusion import StableDiffusionPipeline |
115 changes: 115 additions & 0 deletions
115
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
import inspect | ||
from typing import List, Optional, Union | ||
|
||
import torch | ||
|
||
from tqdm.auto import tqdm | ||
from transformers import CLIPTextModel, CLIPTokenizer | ||
|
||
from ...models import AutoencoderKL, UNet2DConditionModel | ||
from ...pipeline_utils import DiffusionPipeline | ||
from ...schedulers import DDIMScheduler, PNDMScheduler | ||
|
||
|
||
class StableDiffusionPipeline(DiffusionPipeline): | ||
def __init__( | ||
self, | ||
vae: AutoencoderKL, | ||
text_encoder: CLIPTextModel, | ||
tokenizer: CLIPTokenizer, | ||
unet: UNet2DConditionModel, | ||
scheduler: Union[DDIMScheduler, PNDMScheduler], | ||
): | ||
super().__init__() | ||
scheduler = scheduler.set_format("pt") | ||
self.register_modules(vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler) | ||
|
||
@torch.no_grad() | ||
def __call__( | ||
self, | ||
prompt: Union[str, List[str]], | ||
num_inference_steps: Optional[int] = 50, | ||
guidance_scale: Optional[float] = 1.0, | ||
eta: Optional[float] = 0.0, | ||
generator: Optional[torch.Generator] = None, | ||
torch_device: Optional[Union[str, torch.device]] = None, | ||
output_type: Optional[str] = "pil", | ||
): | ||
if torch_device is None: | ||
torch_device = "cuda" if torch.cuda.is_available() else "cpu" | ||
|
||
if isinstance(prompt, str): | ||
batch_size = 1 | ||
elif isinstance(prompt, list): | ||
batch_size = len(prompt) | ||
else: | ||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") | ||
|
||
self.unet.to(torch_device) | ||
self.vae.to(torch_device) | ||
self.text_encoder.to(torch_device) | ||
|
||
# get prompt text embeddings | ||
text_input = self.tokenizer(prompt, padding=True, truncation=True, return_tensors="pt") | ||
text_embeddings = self.text_encoder(text_input.input_ids.to(torch_device))[0] | ||
|
||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) | ||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` | ||
# corresponds to doing no classifier free guidance. | ||
do_classifier_free_guidance = guidance_scale > 1.0 | ||
# get unconditional embeddings for classifier free guidance | ||
if do_classifier_free_guidance: | ||
max_length = text_input.input_ids.shape[-1] | ||
uncond_input = self.tokenizer( | ||
[""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt" | ||
) | ||
uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(torch_device))[0] | ||
|
||
# For classifier free guidance, we need to do two forward passes. | ||
# Here we concatenate the unconditional and text embeddings into a single batch | ||
# to avoid doing two forward passes | ||
text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) | ||
|
||
# get the intial random noise | ||
latents = torch.randn( | ||
(batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size), | ||
generator=generator, | ||
) | ||
latents = latents.to(torch_device) | ||
|
||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature | ||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. | ||
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 | ||
# and should be between [0, 1] | ||
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) | ||
extra_kwargs = {} | ||
if accepts_eta: | ||
extra_kwargs["eta"] = eta | ||
|
||
self.scheduler.set_timesteps(num_inference_steps) | ||
|
||
for t in tqdm(self.scheduler.timesteps): | ||
# expand the latents if we are doing classifier free guidance | ||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents | ||
|
||
# predict the noise residual | ||
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"] | ||
|
||
# perform guidance | ||
if do_classifier_free_guidance: | ||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) | ||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) | ||
|
||
# compute the previous noisy sample x_t -> x_t-1 | ||
latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs)["prev_sample"] | ||
|
||
# scale and decode the image latents with vae | ||
latents = 1 / 0.18215 * latents | ||
image = self.vae.decode(latents) | ||
|
||
image = (image / 2 + 0.5).clamp(0, 1) | ||
image = image.cpu().permute(0, 2, 3, 1).numpy() | ||
if output_type == "pil": | ||
image = self.numpy_to_pil(image) | ||
|
||
return {"sample": image} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters