Controlnet ssd 1b support (huggingface#5779)
* Add SSD-1B support for controlnet model

* Add conditioning_channels into ControlNet init from unet

* Fix black formatting

* Isort fixes

* Adds SSD-1B controlnet pipeline test with UNetMidBlock2D as mid block

* Overrides failing ssd-1b tests

* Fixes tests after main branch update

* Fixes code quality checks

---------

Co-authored-by: Marko Kostiv <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>
3 people authored Nov 29, 2023
1 parent ddd8bd5 commit 6a4aad4
Showing 2 changed files with 204 additions and 27 deletions.
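
At a high level, the commit teaches ControlNetModel.from_unet to mirror a UNet whose mid block is the attention-free UNetMidBlock2D (the SSD-1B layout) and to accept the number of conditioning-image channels. A minimal usage sketch, assuming the segmind/SSD-1B checkpoint on the Hub (the model ID is an illustration, not part of this commit):

    from diffusers import ControlNetModel, UNet2DConditionModel

    # Assumed checkpoint; any UNet whose config sets
    # mid_block_type="UNetMidBlock2D" exercises the new code path.
    unet = UNet2DConditionModel.from_pretrained("segmind/SSD-1B", subfolder="unet")

    # from_unet now copies unet.config.mid_block_type and forwards the new
    # conditioning_channels argument (both changes appear in the diff below).
    controlnet = ControlNetModel.from_unet(unet, conditioning_channels=3)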
71 changes: 44 additions & 27 deletions src/diffusers/models/controlnet.py
@@ -30,12 +30,7 @@
 )
 from .embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps
 from .modeling_utils import ModelMixin
-from .unet_2d_blocks import (
-    CrossAttnDownBlock2D,
-    DownBlock2D,
-    UNetMidBlock2DCrossAttn,
-    get_down_block,
-)
+from .unet_2d_blocks import CrossAttnDownBlock2D, DownBlock2D, UNetMidBlock2D, UNetMidBlock2DCrossAttn, get_down_block
 from .unet_2d_condition import UNet2DConditionModel


@@ -191,6 +186,7 @@ def __init__(
             "CrossAttnDownBlock2D",
             "DownBlock2D",
         ),
+        mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
         only_cross_attention: Union[bool, Tuple[bool]] = False,
         block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
         layers_per_block: int = 2,
@@ -409,20 +405,35 @@ def __init__(
         controlnet_block = zero_module(controlnet_block)
         self.controlnet_mid_block = controlnet_block
 
-        self.mid_block = UNetMidBlock2DCrossAttn(
-            transformer_layers_per_block=transformer_layers_per_block[-1],
-            in_channels=mid_block_channel,
-            temb_channels=time_embed_dim,
-            resnet_eps=norm_eps,
-            resnet_act_fn=act_fn,
-            output_scale_factor=mid_block_scale_factor,
-            resnet_time_scale_shift=resnet_time_scale_shift,
-            cross_attention_dim=cross_attention_dim,
-            num_attention_heads=num_attention_heads[-1],
-            resnet_groups=norm_num_groups,
-            use_linear_projection=use_linear_projection,
-            upcast_attention=upcast_attention,
-        )
+        if mid_block_type == "UNetMidBlock2DCrossAttn":
+            self.mid_block = UNetMidBlock2DCrossAttn(
+                transformer_layers_per_block=transformer_layers_per_block[-1],
+                in_channels=mid_block_channel,
+                temb_channels=time_embed_dim,
+                resnet_eps=norm_eps,
+                resnet_act_fn=act_fn,
+                output_scale_factor=mid_block_scale_factor,
+                resnet_time_scale_shift=resnet_time_scale_shift,
+                cross_attention_dim=cross_attention_dim,
+                num_attention_heads=num_attention_heads[-1],
+                resnet_groups=norm_num_groups,
+                use_linear_projection=use_linear_projection,
+                upcast_attention=upcast_attention,
+            )
+        elif mid_block_type == "UNetMidBlock2D":
+            self.mid_block = UNetMidBlock2D(
+                in_channels=block_out_channels[-1],
+                temb_channels=time_embed_dim,
+                num_layers=0,
+                resnet_eps=norm_eps,
+                resnet_act_fn=act_fn,
+                output_scale_factor=mid_block_scale_factor,
+                resnet_groups=norm_num_groups,
+                resnet_time_scale_shift=resnet_time_scale_shift,
+                add_attention=False,
+            )
+        else:
+            raise ValueError(f"unknown mid_block_type : {mid_block_type}")
 
     @classmethod
     def from_unet(
@@ -431,6 +442,7 @@ def from_unet(
         controlnet_conditioning_channel_order: str = "rgb",
         conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
         load_weights_from_unet: bool = True,
+        conditioning_channels: int = 3,
     ):
         r"""
         Instantiate a [`ControlNetModel`] from [`UNet2DConditionModel`].
@@ -477,8 +489,10 @@ def from_unet(
             upcast_attention=unet.config.upcast_attention,
             resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
             projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
+            mid_block_type=unet.config.mid_block_type,
             controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
             conditioning_embedding_out_channels=conditioning_embedding_out_channels,
+            conditioning_channels=conditioning_channels,
         )
 
         if load_weights_from_unet:
@@ -797,13 +811,16 @@ def forward(
 
         # 4. mid
         if self.mid_block is not None:
-            sample = self.mid_block(
-                sample,
-                emb,
-                encoder_hidden_states=encoder_hidden_states,
-                attention_mask=attention_mask,
-                cross_attention_kwargs=cross_attention_kwargs,
-            )
+            if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
+                sample = self.mid_block(
+                    sample,
+                    emb,
+                    encoder_hidden_states=encoder_hidden_states,
+                    attention_mask=attention_mask,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                )
+            else:
+                sample = self.mid_block(sample, emb)
 
         # 5. Control net blocks
 
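Note the dispatch in forward above: the cross-attention arguments are passed only when the mid block advertises has_cross_attention, while UNetMidBlock2D is called with just the hidden states and time embedding. A small construction sketch of the new option (toy sizes borrowed from the fast test below, not a recommended configuration):

    from diffusers import ControlNetModel

    controlnet = ControlNetModel(
        block_out_channels=(32, 64),
        down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
        conditioning_embedding_out_channels=(16, 32),
        cross_attention_dim=64,
        mid_block_type="UNetMidBlock2D",  # option added in this commit
    )

    # The attention-free mid block does not advertise has_cross_attention,
    # so forward() takes the sample-and-embedding-only branch.
    assert not getattr(controlnet.mid_block, "has_cross_attention", False)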
160 changes: 160 additions & 0 deletions tests/pipelines/controlnet/test_controlnet_sdxl.py
@@ -28,6 +28,7 @@
     StableDiffusionXLControlNetPipeline,
     UNet2DConditionModel,
 )
+from diffusers.models.unet_2d_blocks import UNetMidBlock2D
 from diffusers.pipelines.controlnet.pipeline_controlnet import MultiControlNetModel
 from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.testing_utils import enable_full_determinism, load_image, require_torch_gpu, slow, torch_device
@@ -817,3 +818,162 @@ def test_depth(self):
         original_image = images[0, -3:, -3:, -1].flatten()
         expected_image = np.array([0.4399, 0.5112, 0.5478, 0.4314, 0.472, 0.4823, 0.4647, 0.4957, 0.4853])
         assert np.allclose(original_image, expected_image, atol=1e-04)
+
+
+class StableDiffusionSSD1BControlNetPipelineFastTests(StableDiffusionXLControlNetPipelineFastTests):
+    def test_controlnet_sdxl_guess(self):
+        device = "cpu"
+
+        components = self.get_dummy_components()
+
+        sd_pipe = self.pipeline_class(**components)
+        sd_pipe = sd_pipe.to(device)
+
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(device)
+        inputs["guess_mode"] = True
+
+        output = sd_pipe(**inputs)
+        image_slice = output.images[0, -3:, -3:, -1]
+        expected_slice = np.array(
+            [0.6831671, 0.5702532, 0.5459845, 0.6299793, 0.58563006, 0.6033695, 0.4493941, 0.46132287, 0.5035841]
+        )
+
+        # make sure that it's equal
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-4
+
+    def test_controlnet_sdxl_lcm(self):
+        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
+
+        components = self.get_dummy_components(time_cond_proj_dim=256)
+        sd_pipe = StableDiffusionXLControlNetPipeline(**components)
+        sd_pipe.scheduler = LCMScheduler.from_config(sd_pipe.scheduler.config)
+        sd_pipe = sd_pipe.to(torch_device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(device)
+        output = sd_pipe(**inputs)
+        image = output.images
+
+        image_slice = image[0, -3:, -3:, -1]
+
+        assert image.shape == (1, 64, 64, 3)
+        expected_slice = np.array([0.6850, 0.5135, 0.5545, 0.7033, 0.6617, 0.5971, 0.4165, 0.5480, 0.5070])
+
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+
+    def test_conditioning_channels(self):
+        unet = UNet2DConditionModel(
+            block_out_channels=(32, 64),
+            layers_per_block=2,
+            sample_size=32,
+            in_channels=4,
+            out_channels=4,
+            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
+            up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
+            mid_block_type="UNetMidBlock2D",
+            # SD2-specific config below
+            attention_head_dim=(2, 4),
+            use_linear_projection=True,
+            addition_embed_type="text_time",
+            addition_time_embed_dim=8,
+            transformer_layers_per_block=(1, 2),
+            projection_class_embeddings_input_dim=80,  # 6 * 8 + 32
+            cross_attention_dim=64,
+            time_cond_proj_dim=None,
+        )
+
+        controlnet = ControlNetModel.from_unet(unet, conditioning_channels=4)
+        assert type(controlnet.mid_block) == UNetMidBlock2D
+        assert controlnet.conditioning_channels == 4
+
+    def get_dummy_components(self, time_cond_proj_dim=None):
+        torch.manual_seed(0)
+        unet = UNet2DConditionModel(
+            block_out_channels=(32, 64),
+            layers_per_block=2,
+            sample_size=32,
+            in_channels=4,
+            out_channels=4,
+            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
+            up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
+            mid_block_type="UNetMidBlock2D",
+            # SD2-specific config below
+            attention_head_dim=(2, 4),
+            use_linear_projection=True,
+            addition_embed_type="text_time",
+            addition_time_embed_dim=8,
+            transformer_layers_per_block=(1, 2),
+            projection_class_embeddings_input_dim=80,  # 6 * 8 + 32
+            cross_attention_dim=64,
+            time_cond_proj_dim=time_cond_proj_dim,
+        )
+        torch.manual_seed(0)
+        controlnet = ControlNetModel(
+            block_out_channels=(32, 64),
+            layers_per_block=2,
+            in_channels=4,
+            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
+            conditioning_embedding_out_channels=(16, 32),
+            mid_block_type="UNetMidBlock2D",
+            # SD2-specific config below
+            attention_head_dim=(2, 4),
+            use_linear_projection=True,
+            addition_embed_type="text_time",
+            addition_time_embed_dim=8,
+            transformer_layers_per_block=(1, 2),
+            projection_class_embeddings_input_dim=80,  # 6 * 8 + 32
+            cross_attention_dim=64,
+        )
+        torch.manual_seed(0)
+        scheduler = EulerDiscreteScheduler(
+            beta_start=0.00085,
+            beta_end=0.012,
+            steps_offset=1,
+            beta_schedule="scaled_linear",
+            timestep_spacing="leading",
+        )
+        torch.manual_seed(0)
+        vae = AutoencoderKL(
+            block_out_channels=[32, 64],
+            in_channels=3,
+            out_channels=3,
+            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
+            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
+            latent_channels=4,
+        )
+        torch.manual_seed(0)
+        text_encoder_config = CLIPTextConfig(
+            bos_token_id=0,
+            eos_token_id=2,
+            hidden_size=32,
+            intermediate_size=37,
+            layer_norm_eps=1e-05,
+            num_attention_heads=4,
+            num_hidden_layers=5,
+            pad_token_id=1,
+            vocab_size=1000,
+            # SD2-specific config below
+            hidden_act="gelu",
+            projection_dim=32,
+        )
+        text_encoder = CLIPTextModel(text_encoder_config)
+        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+        text_encoder_2 = CLIPTextModelWithProjection(text_encoder_config)
+        tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+        components = {
+            "unet": unet,
+            "controlnet": controlnet,
+            "scheduler": scheduler,
+            "vae": vae,
+            "text_encoder": text_encoder,
+            "tokenizer": tokenizer,
+            "text_encoder_2": text_encoder_2,
+            "tokenizer_2": tokenizer_2,
+            "feature_extractor": None,
+            "image_encoder": None,
+        }
+        return components
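
The fast-test class reuses StableDiffusionXLControlNetPipeline unchanged; only the UNet and ControlNet configs differ. For orientation, a hedged end-to-end sketch under the same assumed checkpoint as above (the ControlNet here is freshly initialized rather than trained, so the output is meaningless noise; this only demonstrates that the forward pass wires up):

    import torch
    from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline, UNet2DConditionModel

    # "segmind/SSD-1B" is an assumed Hub checkpoint, used for illustration only.
    unet = UNet2DConditionModel.from_pretrained("segmind/SSD-1B", subfolder="unet")
    pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
        "segmind/SSD-1B", controlnet=ControlNetModel.from_unet(unet)
    )

    image = pipe(
        "a photo",  # placeholder prompt
        image=torch.zeros(1, 3, 1024, 1024),  # placeholder conditioning image
        num_inference_steps=2,
    ).images[0]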
