From 3a357eb759634729bf7ab807bd57775d847ebe69 Mon Sep 17 00:00:00 2001
From: lvmin
Date: Sat, 18 Feb 2023 01:31:29 -0800
Subject: [PATCH] prepare for strength mode

---
 cldm/cldm.py | 101 ++++++++++++++++++++++++++-------------------------
 1 file changed, 52 insertions(+), 49 deletions(-)

diff --git a/cldm/cldm.py b/cldm/cldm.py
index 1a86f4165f..4e02423d9a 100644
--- a/cldm/cldm.py
+++ b/cldm/cldm.py
@@ -46,34 +46,34 @@ def forward(self, x, timesteps=None, context=None, control=None, only_mid_contro
 
 class ControlNet(nn.Module):
     def __init__(
-        self,
-        image_size,
-        in_channels,
-        model_channels,
-        hint_channels,
-        num_res_blocks,
-        attention_resolutions,
-        dropout=0,
-        channel_mult=(1, 2, 4, 8),
-        conv_resample=True,
-        dims=2,
-        use_checkpoint=False,
-        use_fp16=False,
-        num_heads=-1,
-        num_head_channels=-1,
-        num_heads_upsample=-1,
-        use_scale_shift_norm=False,
-        resblock_updown=False,
-        use_new_attention_order=False,
-        use_spatial_transformer=False,    # custom transformer support
-        transformer_depth=1,              # custom transformer support
-        context_dim=None,                 # custom transformer support
-        n_embed=None,                     # custom support for prediction of discrete ids into codebook of first stage vq model
-        legacy=True,
-        disable_self_attentions=None,
-        num_attention_blocks=None,
-        disable_middle_self_attn=False,
-        use_linear_in_transformer=False,
+            self,
+            image_size,
+            in_channels,
+            model_channels,
+            hint_channels,
+            num_res_blocks,
+            attention_resolutions,
+            dropout=0,
+            channel_mult=(1, 2, 4, 8),
+            conv_resample=True,
+            dims=2,
+            use_checkpoint=False,
+            use_fp16=False,
+            num_heads=-1,
+            num_head_channels=-1,
+            num_heads_upsample=-1,
+            use_scale_shift_norm=False,
+            resblock_updown=False,
+            use_new_attention_order=False,
+            use_spatial_transformer=False,  # custom transformer support
+            transformer_depth=1,  # custom transformer support
+            context_dim=None,  # custom transformer support
+            n_embed=None,  # custom support for prediction of discrete ids into codebook of first stage vq model
+            legacy=True,
+            disable_self_attentions=None,
+            num_attention_blocks=None,
+            disable_middle_self_attn=False,
+            use_linear_in_transformer=False,
     ):
         super().__init__()
         if use_spatial_transformer:
@@ -144,21 +144,21 @@ def __init__(
         self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels)])
 
         self.input_hint_block = TimestepEmbedSequential(
-                conv_nd(dims, hint_channels, 16, 3, padding=1),
-                nn.SiLU(),
-                conv_nd(dims, 16, 16, 3, padding=1),
-                nn.SiLU(),
-                conv_nd(dims, 16, 32, 3, padding=1, stride=2),
-                nn.SiLU(),
-                conv_nd(dims, 32, 32, 3, padding=1),
-                nn.SiLU(),
-                conv_nd(dims, 32, 96, 3, padding=1, stride=2),
-                nn.SiLU(),
-                conv_nd(dims, 96, 96, 3, padding=1),
-                nn.SiLU(),
-                conv_nd(dims, 96, 256, 3, padding=1, stride=2),
-                nn.SiLU(),
-                zero_module(conv_nd(dims, 256, model_channels, 3, padding=1))
+            conv_nd(dims, hint_channels, 16, 3, padding=1),
+            nn.SiLU(),
+            conv_nd(dims, 16, 16, 3, padding=1),
+            nn.SiLU(),
+            conv_nd(dims, 16, 32, 3, padding=1, stride=2),
+            nn.SiLU(),
+            conv_nd(dims, 32, 32, 3, padding=1),
+            nn.SiLU(),
+            conv_nd(dims, 32, 96, 3, padding=1, stride=2),
+            nn.SiLU(),
+            conv_nd(dims, 96, 96, 3, padding=1),
+            nn.SiLU(),
+            conv_nd(dims, 96, 256, 3, padding=1, stride=2),
+            nn.SiLU(),
+            zero_module(conv_nd(dims, 256, model_channels, 3, padding=1))
         )
 
         self._feature_size = model_channels
@@ -186,7 +186,7 @@ def __init__(
                         num_heads = ch // num_head_channels
                         dim_head = num_head_channels
                     if legacy:
-                        #num_heads = 1
+                        # num_heads = 1
                         dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
                     if exists(disable_self_attentions):
                         disabled_sa = disable_self_attentions[level]
@@ -243,7 +243,7 @@ def __init__(
             num_heads = ch // num_head_channels
             dim_head = num_head_channels
         if legacy:
-            #num_heads = 1
+            # num_heads = 1
             dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
         self.middle_block = TimestepEmbedSequential(
             ResBlock(
@@ -261,10 +261,10 @@ def __init__(
                 num_head_channels=dim_head,
                 use_new_attention_order=use_new_attention_order,
             ) if not use_spatial_transformer else SpatialTransformer(  # always uses a self-attn
-                            ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
-                            disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
-                            use_checkpoint=use_checkpoint
-                        ),
+                ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
+                disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
+                use_checkpoint=use_checkpoint
+            ),
             ResBlock(
                 ch,
                 time_embed_dim,
@@ -311,6 +311,7 @@ def __init__(self, control_stage_config, control_key, only_mid_control, *args, *
         self.control_model = instantiate_from_config(control_stage_config)
         self.control_key = control_key
         self.only_mid_control = only_mid_control
+        self.control_scales = [1.0] * 13
 
     @torch.no_grad()
     def get_input(self, batch, k, bs=None, *args, **kwargs):
@@ -330,6 +331,8 @@ def apply_model(self, x_noisy, t, cond, *args, **kwargs):
         cond_hint = torch.cat(cond['c_concat'], 1)
 
         control = self.control_model(x=x_noisy, hint=cond_hint, timesteps=t, context=cond_txt)
+        control = [c * scale for c, scale in zip(control, self.control_scales)]
+
         eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=control, only_mid_control=self.only_mid_control)
 
         return eps
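
Usage note: the new `control_scales` attribute holds one multiplier per ControlNet output, the 12 zero-conv residuals plus the middle-block residual (13 in total), and `apply_model` now rescales each residual before it is injected into the locked UNet, so leaving every entry at 1.0 keeps the previous behaviour. A minimal caller-side sketch follows; the `make_control_scales` helper, the `strength` knob, and the 0.825 decay value are illustrative assumptions, not part of this patch.

def make_control_scales(strength=1.0, decay=None):
    """Build the 13-element list consumed by ControlLDM.control_scales."""
    if decay is None:
        # Uniform strength: every control residual is scaled by the same factor;
        # strength=1.0 reproduces the behaviour before this patch.
        return [strength] * 13
    # Depth-weighted schedule (illustrative): residuals injected into the shallow,
    # high-resolution UNet blocks are damped more than the middle-block residual,
    # which keeps the full strength.
    return [strength * (decay ** float(12 - i)) for i in range(13)]

# Example, assuming `model` is an instantiated ControlLDM:
#   model.control_scales = make_control_scales(strength=0.8)
#   model.control_scales = make_control_scales(strength=0.8, decay=0.825)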