prepare for strength mode
lllyasviel committed Feb 18, 2023
1 parent 31c4095 commit 3a357eb
Showing 1 changed file with 52 additions and 49 deletions.
101 changes: 52 additions & 49 deletions cldm/cldm.py
@@ -46,34 +46,34 @@ def forward(self, x, timesteps=None, context=None, control=None, only_mid_contro

class ControlNet(nn.Module):
def __init__(
            self,
            image_size,
            in_channels,
            model_channels,
            hint_channels,
            num_res_blocks,
            attention_resolutions,
            dropout=0,
            channel_mult=(1, 2, 4, 8),
            conv_resample=True,
            dims=2,
            use_checkpoint=False,
            use_fp16=False,
            num_heads=-1,
            num_head_channels=-1,
            num_heads_upsample=-1,
            use_scale_shift_norm=False,
            resblock_updown=False,
            use_new_attention_order=False,
            use_spatial_transformer=False,  # custom transformer support
            transformer_depth=1,  # custom transformer support
            context_dim=None,  # custom transformer support
            n_embed=None,  # custom support for prediction of discrete ids into codebook of first stage vq model
            legacy=True,
            disable_self_attentions=None,
            num_attention_blocks=None,
            disable_middle_self_attn=False,
            use_linear_in_transformer=False,
):
super().__init__()
if use_spatial_transformer:
@@ -144,21 +144,21 @@ def __init__(
self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels)])

self.input_hint_block = TimestepEmbedSequential(
            conv_nd(dims, hint_channels, 16, 3, padding=1),
            nn.SiLU(),
            conv_nd(dims, 16, 16, 3, padding=1),
            nn.SiLU(),
            conv_nd(dims, 16, 32, 3, padding=1, stride=2),
            nn.SiLU(),
            conv_nd(dims, 32, 32, 3, padding=1),
            nn.SiLU(),
            conv_nd(dims, 32, 96, 3, padding=1, stride=2),
            nn.SiLU(),
            conv_nd(dims, 96, 96, 3, padding=1),
            nn.SiLU(),
            conv_nd(dims, 96, 256, 3, padding=1, stride=2),
            nn.SiLU(),
            zero_module(conv_nd(dims, 256, model_channels, 3, padding=1))
)

self._feature_size = model_channels
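
The hint block above maps the conditioning image into the latent resolution of the UNet: three stride-2 convolutions give an 8x spatial downsampling while the channel count grows from hint_channels to model_channels, and the final conv is wrapped in zero_module() so the block starts out contributing nothing. A standalone sketch of the shape behaviour, using plain nn.Conv2d in place of conv_nd (dims=2) and example values hint_channels=3, model_channels=320 that are not taken from this diff:

import torch
import torch.nn as nn

model_channels = 320  # example value only; the real value comes from the model config
hint_encoder = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1), nn.SiLU(),
    nn.Conv2d(16, 16, 3, padding=1), nn.SiLU(),
    nn.Conv2d(16, 32, 3, padding=1, stride=2), nn.SiLU(),
    nn.Conv2d(32, 32, 3, padding=1), nn.SiLU(),
    nn.Conv2d(32, 96, 3, padding=1, stride=2), nn.SiLU(),
    nn.Conv2d(96, 96, 3, padding=1), nn.SiLU(),
    nn.Conv2d(96, 256, 3, padding=1, stride=2), nn.SiLU(),
    nn.Conv2d(256, model_channels, 3, padding=1),  # the real block zero-initializes this conv
)
hint = torch.randn(1, 3, 512, 512)
print(hint_encoder(hint).shape)  # torch.Size([1, 320, 64, 64]): 512 -> 64 is the 8x downsampling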
@@ -186,7 +186,7 @@ def __init__(
num_heads = ch // num_head_channels
dim_head = num_head_channels
if legacy:
                        # num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
if exists(disable_self_attentions):
disabled_sa = disable_self_attentions[level]
@@ -243,7 +243,7 @@ def __init__(
num_heads = ch // num_head_channels
dim_head = num_head_channels
if legacy:
            # num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
self.middle_block = TimestepEmbedSequential(
ResBlock(
@@ -261,10 +261,10 @@ def __init__(
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn
                ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
                disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
                use_checkpoint=use_checkpoint
            ),
ResBlock(
ch,
time_embed_dim,
@@ -311,6 +311,7 @@ def __init__(self, control_stage_config, control_key, only_mid_control, *args, *
self.control_model = instantiate_from_config(control_stage_config)
self.control_key = control_key
self.only_mid_control = only_mid_control
self.control_scales = [1.0] * 13

@torch.no_grad()
def get_input(self, batch, k, bs=None, *args, **kwargs):
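
The new control_scales attribute introduced in this hunk holds one multiplier per control residual: in this architecture the ControlNet returns 13 outputs, one from each of the 12 input-block zero convolutions plus one from the middle block, which is where the 13 comes from. As a minimal sketch of how a single strength knob could drive these scales (set_control_strength is a hypothetical helper, not part of this commit, and model is assumed to be an instantiated ControlLDM):

def set_control_strength(model, strength: float = 1.0) -> None:
    """Scale every ControlNet residual by the same factor (hypothetical helper)."""
    model.control_scales = [strength] * 13  # 12 input-block zero convs + 1 middle block

# usage sketch: set_control_strength(model, 0.8) before sampling to weaken the control signal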
@@ -330,6 +331,8 @@ def apply_model(self, x_noisy, t, cond, *args, **kwargs):
cond_hint = torch.cat(cond['c_concat'], 1)

control = self.control_model(x=x_noisy, hint=cond_hint, timesteps=t, context=cond_txt)
control = [c * scale for c, scale in zip(control, self.control_scales)]

eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=control, only_mid_control=self.only_mid_control)

return eps
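
The change to apply_model is a plain element-wise multiply over the list of residuals returned by the control model before they are handed to the diffusion UNet. A self-contained sketch of that step with toy tensors (the shapes are illustrative only, not the real per-block feature shapes):

import torch

control = [torch.randn(1, 320, 64, 64) for _ in range(13)]  # stand-ins for the 13 ControlNet outputs
control_scales = [1.0] * 13                                  # the default set in __init__ above
scaled = [c * scale for c, scale in zip(control, control_scales)]

assert all(torch.equal(c, s) for c, s in zip(control, scaled))  # unit scales leave the residuals unchanged

With unit scales the residuals pass through unchanged; non-uniform per-block scales are what the commit title's "strength mode" points toward.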
