Add support for Multi-ControlNet to StableDiffusionControlNetPipeline (huggingface#2627)

* support for List[ControlNetModel] on init()

* Add support for multiple ControlNetCondition

* rename conditioning_scale to scale

* scaling bugfix

* Manually merge `MultiControlNet` huggingface#2621

Co-authored-by: Patrick von Platen <[email protected]>

* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py

Co-authored-by: Patrick von Platen <[email protected]>

* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py

Co-authored-by: Patrick von Platen <[email protected]>

* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py

Co-authored-by: Patrick von Platen <[email protected]>

* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py

Co-authored-by: Patrick von Platen <[email protected]>

* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py

Co-authored-by: Patrick von Platen <[email protected]>

* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py

Co-authored-by: Patrick von Platen <[email protected]>

* cleanups
- don't expose ControlNetCondition
- move scaling to ControlNetModel

* fix `make style` errors

* remove ControlNetCondition to reduce code diff

* refactoring image/cond_scale

* add explanation for `images`

* Add docstrings

* all fast tests passed

* Add a slow test

* nit

* Apply suggestions from code review

* small precision fix

* nits

MultiControlNet -> MultiControlNetModel - Matches existing naming a bit
closer

MultiControlNetModel inherits from the model utils class - Don't have to
re-write the fp16 test

Skip tests that save multi controlnet pipeline - Clearer than changing
test body

Don't auto-batch the number of input images to the number of controlnets.
We generally like to require the user to pass the expected number of
inputs. This simplifies the processing code a bit more.

Use existing image pre-processing code a bit more. We can rely on the
existing image pre-processing code and keep the inference loop a bit
simpler.
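For reference, a minimal usage sketch of the new Multi-ControlNet API (checkpoint names and the `pose_image`/`canny_image` inputs are illustrative assumptions, not part of this commit):

import torch
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline

# Hypothetical checkpoints; any compatible ControlNets work.
controlnet_pose = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-openpose", torch_dtype=torch.float16)
controlnet_canny = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)

# Passing a list of ControlNets wraps them in MultiControlNetModel internally.
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    controlnet=[controlnet_pose, controlnet_canny],
    torch_dtype=torch.float16,
).to("cuda")

# One conditioning image and one scale per ControlNet, in matching order.
# `pose_image` and `canny_image` are assumed to be pre-computed PIL images.
result = pipe(
    "a man standing in a futuristic city",
    image=[pose_image, canny_image],
    controlnet_conditioning_scale=[1.0, 0.5],
).images[0]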

---------

Co-authored-by: Patrick von Platen <[email protected]>
Co-authored-by: William Berman <[email protected]>
3 people authored Mar 13, 2023
1 parent 4ae54b3 commit d9b8adc
Showing 3 changed files with 384 additions and 31 deletions.
src/diffusers/models/controlnet.py (5 additions & 0 deletions)
@@ -389,6 +389,7 @@ def forward(
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        controlnet_cond: torch.FloatTensor,
        conditioning_scale: float = 1.0,
        class_labels: Optional[torch.Tensor] = None,
        timestep_cond: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
@@ -492,6 +493,10 @@ def forward(

        mid_block_res_sample = self.controlnet_mid_block(sample)

        # 6. scaling
        down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
        mid_block_res_sample *= conditioning_scale

        if not return_dict:
            return (down_block_res_samples, mid_block_res_sample)
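Per the commit notes, scaling moved out of the pipeline loop and into `ControlNetModel.forward()`. A standalone sketch of the effect, with hypothetical residual shapes:

import torch

conditioning_scale = 0.5  # illustrative value
down_block_res_samples = [torch.ones(1, 320, 64, 64), torch.ones(1, 320, 32, 32)]  # hypothetical shapes
mid_block_res_sample = torch.ones(1, 1280, 8, 8)

# Same rescaling that forward() now applies before returning.
down_block_res_samples = [s * conditioning_scale for s in down_block_res_samples]
mid_block_res_sample = mid_block_res_sample * conditioning_scale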

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py
@@ -14,14 +14,18 @@


import inspect
import os
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import PIL.Image
import torch
from torch import nn
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer

from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
from ...models.controlnet import ControlNetOutput
from ...models.modeling_utils import ModelMixin
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
    PIL_INTERPOLATION,
@@ -85,6 +89,63 @@
"""


class MultiControlNetModel(ModelMixin):
    r"""
    Multiple `ControlNetModel` wrapper class for Multi-ControlNet

    This module is a wrapper for multiple instances of the `ControlNetModel`. The `forward()` API is designed to be
    compatible with `ControlNetModel`.

    Args:
        controlnets (`List[ControlNetModel]`):
            Provides additional conditioning to the unet during the denoising process. You must set multiple
            `ControlNetModel`s as a list.
    """

    def __init__(self, controlnets: Union[List[ControlNetModel], Tuple[ControlNetModel]]):
        super().__init__()
        self.nets = nn.ModuleList(controlnets)

    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        controlnet_cond: List[torch.Tensor],
        conditioning_scale: List[float],
        class_labels: Optional[torch.Tensor] = None,
        timestep_cond: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,
    ) -> Union[ControlNetOutput, Tuple]:
        for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)):
            down_samples, mid_sample = controlnet(
                sample,
                timestep,
                encoder_hidden_states,
                image,
                scale,
                class_labels,
                timestep_cond,
                attention_mask,
                cross_attention_kwargs,
                return_dict,
            )

            # merge samples
            if i == 0:
                down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
            else:
                down_block_res_samples = [
                    samples_prev + samples_curr
                    for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
                ]
                mid_block_res_sample += mid_sample

        return down_block_res_samples, mid_block_res_sample
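The wrapper's `forward()` runs each net on its own conditioning image and accumulates the residuals elementwise, so the combined guidance is the scale-weighted sum of the individual ControlNet outputs. A hedged sketch of calling it directly (assumes `cn_a`/`cn_b` are loaded `ControlNetModel` instances and the tensors have compatible shapes):

multi = MultiControlNetModel([cn_a, cn_b])
down_res, mid_res = multi(
    sample=latent_model_input,
    timestep=t,
    encoder_hidden_states=prompt_embeds,
    controlnet_cond=[cond_a, cond_b],      # one conditioning tensor per net
    conditioning_scale=[1.0, 0.5],         # per-net scale
    return_dict=False,
)
# down_res[k] == 1.0 * down_a[k] + 0.5 * down_b[k] at every resolution k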


class StableDiffusionControlNetPipeline(DiffusionPipeline):
    r"""
    Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance.
@@ -103,8 +164,10 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
        controlnet ([`ControlNetModel`] or `List[ControlNetModel]`):
            Provides additional conditioning to the unet during the denoising process. If you set multiple ControlNets
            as a list, the outputs from each ControlNet are added together to create one combined additional
            conditioning.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
@@ -122,7 +185,7 @@ def __init__(
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
        scheduler: KarrasDiffusionSchedulers,
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPFeatureExtractor,
@@ -146,6 +209,9 @@ def __init__(
                " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
            )

        if isinstance(controlnet, (list, tuple)):
            controlnet = MultiControlNetModel(controlnet)

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
@@ -432,6 +498,7 @@ def check_inputs(
        negative_prompt=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
        controlnet_conditioning_scale=1.0,
    ):
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
@@ -470,6 +537,41 @@ def check_inputs(
                f" {negative_prompt_embeds.shape}."
            )

        # Check `image`

        if isinstance(self.controlnet, ControlNetModel):
            self.check_image(image, prompt, prompt_embeds)
        elif isinstance(self.controlnet, MultiControlNetModel):
            if not isinstance(image, list):
                raise TypeError("For multiple controlnets: `image` must be type `list`")

            if len(image) != len(self.controlnet.nets):
                raise ValueError(
                    "For multiple controlnets: `image` must have the same length as the number of controlnets."
                )

            for image_ in image:
                self.check_image(image_, prompt, prompt_embeds)
        else:
            assert False

        # Check `controlnet_conditioning_scale`

        if isinstance(self.controlnet, ControlNetModel):
            if not isinstance(controlnet_conditioning_scale, float):
                raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
        elif isinstance(self.controlnet, MultiControlNetModel):
            if isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
                self.controlnet.nets
            ):
                raise ValueError(
                    "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
                    " the same length as the number of controlnets"
                )
        else:
            assert False
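For illustration, the new checks reject malformed Multi-ControlNet inputs up front (assuming a hypothetical `pipe` built with two ControlNets):

pipe(prompt, image=canny_image)                  # TypeError: `image` must be type `list`
pipe(prompt, image=[img_a, img_b, img_c])        # ValueError: `image` length must match the number of controlnets
pipe(prompt, image=[img_a, img_b],
     controlnet_conditioning_scale=[1.0])        # ValueError: scale list length must match the number of controlnets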

    def check_image(self, image, prompt, prompt_embeds):
        image_is_pil = isinstance(image, PIL.Image.Image)
        image_is_tensor = isinstance(image, torch.Tensor)
        image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
@@ -501,7 +603,9 @@ def check_inputs(
                f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
            )
    def prepare_image(
        self, image, width, height, batch_size, num_images_per_prompt, device, dtype, do_classifier_free_guidance
    ):
        if not isinstance(image, torch.Tensor):
            if isinstance(image, PIL.Image.Image):
                image = [image]
@@ -529,6 +633,9 @@ def prepare_image(

        image = image.to(device=device, dtype=dtype)

        if do_classifier_free_guidance:
            image = torch.cat([image] * 2)

        return image
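The classifier-free-guidance duplication now lives inside `prepare_image()` instead of `__call__()`: under CFG the latent batch is doubled (unconditional + conditional halves), so the control image batch must be doubled to match. A shape sketch with assumed sizes:

import torch

image = torch.rand(2, 3, 512, 512)    # batch_size * num_images_per_prompt = 2 (assumed)
latents = torch.rand(2, 4, 64, 64)

latent_model_input = torch.cat([latents] * 2)  # (4, 4, 64, 64) under CFG
image = torch.cat([image] * 2)                 # (4, 3, 512, 512), matching batch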

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
@@ -550,7 +657,10 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype
        return latents

    def _default_height_width(self, height, width, image):
        # NOTE: It is possible that a list of images have different
        # dimensions for each image, so just checking the first image
        # is not _exactly_ correct, but it is simple.
        while isinstance(image, list):
            image = image[0]

        if height is None:
@@ -571,6 +681,18 @@ def _default_height_width(self, height, width, image):

        return height, width
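With Multi-ControlNet, `image` may now be a nested list, so the single `isinstance` check became a `while` loop that drills down to the first leaf image to read default dimensions. A small sketch (`pil_a`/`pil_b` assumed PIL images):

image = [[pil_a], [pil_b]]       # two controlnets, one image each
while isinstance(image, list):
    image = image[0]             # -> pil_a
# height/width then default to pil_a's dimensions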

    # override DiffusionPipeline
    def save_pretrained(
        self,
        save_directory: Union[str, os.PathLike],
        safe_serialization: bool = False,
        variant: Optional[str] = None,
    ):
        if isinstance(self.controlnet, ControlNetModel):
            super().save_pretrained(save_directory, safe_serialization, variant)
        else:
            raise NotImplementedError("Currently, `save_pretrained()` is not implemented for Multi-ControlNet.")

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
@@ -593,7 +715,7 @@ def __call__(
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
    ):
        r"""
        Function invoked when calling the pipeline for generation.
@@ -602,10 +724,14 @@ def __call__(
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
                instead.
            image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`,
                `List[List[torch.FloatTensor]]`, or `List[List[PIL.Image.Image]]`):
                The ControlNet input condition. ControlNet uses this input condition to generate guidance to the UNet.
                If the type is specified as `torch.FloatTensor`, it is passed to the ControlNet as is. `PIL.Image.Image`
                can also be accepted as an image. The dimensions of the output image default to `image`'s dimensions.
                If height and/or width are passed, `image` is resized according to them. If multiple ControlNets are
                specified in init, images must be passed as a list such that each element of the list can be correctly
                batched for input to a single ControlNet.
            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
@@ -658,10 +784,10 @@ def __call__(
                A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
                `self.processor` in
                [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
            controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
                The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
                to the residual in the original unet. If multiple ControlNets are specified in init, you can set the
                corresponding scale as a list.

        Examples:

        Returns:
@@ -676,7 +802,15 @@ def __call__(

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            image,
            height,
            width,
            callback_steps,
            negative_prompt,
            prompt_embeds,
            negative_prompt_embeds,
            controlnet_conditioning_scale,
        )

        # 2. Define call parameters
@@ -693,6 +827,9 @@ def __call__(
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        if isinstance(self.controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
            controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(self.controlnet.nets)

        # 3. Encode input prompt
        prompt_embeds = self._encode_prompt(
            prompt,
@@ -705,18 +842,37 @@ def __call__(
        )

        # 4. Prepare image
        if isinstance(self.controlnet, ControlNetModel):
            image = self.prepare_image(
                image=image,
                width=width,
                height=height,
                batch_size=batch_size * num_images_per_prompt,
                num_images_per_prompt=num_images_per_prompt,
                device=device,
                dtype=self.controlnet.dtype,
                do_classifier_free_guidance=do_classifier_free_guidance,
            )
        elif isinstance(self.controlnet, MultiControlNetModel):
            images = []

            for image_ in image:
                image_ = self.prepare_image(
                    image=image_,
                    width=width,
                    height=height,
                    batch_size=batch_size * num_images_per_prompt,
                    num_images_per_prompt=num_images_per_prompt,
                    device=device,
                    dtype=self.controlnet.dtype,
                    do_classifier_free_guidance=do_classifier_free_guidance,
                )

                images.append(image_)

            image = images
        else:
            assert False

        # 5. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
@@ -746,20 +902,16 @@ def __call__(
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # controlnet(s) inference
                down_block_res_samples, mid_block_res_sample = self.controlnet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    controlnet_cond=image,
                    conditioning_scale=controlnet_conditioning_scale,
                    return_dict=False,
                )

                # predict the noise residual
                noise_pred = self.unet(
                    latent_model_input,
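The rest of the UNet call is elided in this view; for context, a hedged sketch of how these residuals are consumed downstream, using the `down_block_additional_residuals` / `mid_block_additional_residual` keywords the existing `UNet2DConditionModel` exposes:

noise_pred = self.unet(
    latent_model_input,
    t,
    encoder_hidden_states=prompt_embeds,
    cross_attention_kwargs=cross_attention_kwargs,
    down_block_additional_residuals=down_block_res_samples,
    mid_block_additional_residual=mid_block_res_sample,
).sample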
(diff for the third changed file not shown)
