Fix Consistency Models UNet2DMidBlock2D Attention GroupNorm Bug (huggingface#4863)

* Add attn_groups argument to UNet2DMidBlock2D to control the internal Attention block's GroupNorm.

* Add docstring for attn_norm_num_groups in UNet2DModel.

* Since the test UNet config uses resnet_time_scale_shift == 'scale_shift', also set attn_norm_num_groups to 32.

* Add test for attn_norm_num_groups to UNet2DModelTests.

* Fix expected slices for slow tests.

* Also fix tolerances for slow tests.

---------

Co-authored-by: Sayak Paul <[email protected]>
dg845 and sayakpaul authored Sep 15, 2023
1 parent 5fd42e5 commit 4c8a05f
Showing 5 changed files with 49 additions and 7 deletions.
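
The net effect for downstream users: `UNet2DModel` gains an `attn_norm_num_groups` config argument, and the consistency-model conversion script now sets it to 32 so the mid block's `Attention` keeps its `GroupNorm` even with `resnet_time_scale_shift="scale_shift"`. A minimal usage sketch (argument values other than the two relevant ones are illustrative, not taken from the commit):

```python
from diffusers import UNet2DModel

# Previously, resnet_time_scale_shift="scale_shift" silently dropped the GroupNorm
# inside the mid block's Attention; attn_norm_num_groups restores it explicitly.
model = UNet2DModel(
    sample_size=32,
    resnet_time_scale_shift="scale_shift",
    attn_norm_num_groups=32,
)

# The mid block's Attention layer now carries a GroupNorm with 32 groups.
print(model.mid_block.attentions[0].group_norm)
```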
2 changes: 2 additions & 0 deletions scripts/convert_consistency_to_diffusers.py
@@ -27,6 +27,7 @@
"ResnetUpsampleBlock2D",
],
"resnet_time_scale_shift": "scale_shift",
"attn_norm_num_groups": 32,
"upsample_type": "resnet",
"downsample_type": "resnet",
}
@@ -52,6 +53,7 @@
"ResnetUpsampleBlock2D",
],
"resnet_time_scale_shift": "scale_shift",
"attn_norm_num_groups": 32,
"upsample_type": "resnet",
"downsample_type": "resnet",
}
6 changes: 6 additions & 0 deletions src/diffusers/models/unet_2d.py
@@ -74,6 +74,10 @@ class UNet2DModel(ModelMixin, ConfigMixin):
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
attention_head_dim (`int`, *optional*, defaults to `8`): The attention head dimension.
norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups for normalization.
attn_norm_num_groups (`int`, *optional*, defaults to `None`):
If set to an integer, a group norm layer will be created in the mid block's [`Attention`] layer with the
given number of groups. If left as `None`, the group norm layer will only be created if
`resnet_time_scale_shift` is set to `default`, and if created will have `norm_num_groups` groups.
norm_eps (`float`, *optional*, defaults to `1e-5`): The epsilon for normalization.
resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
@@ -107,6 +111,7 @@ def __init__(
act_fn: str = "silu",
attention_head_dim: Optional[int] = 8,
norm_num_groups: int = 32,
attn_norm_num_groups: Optional[int] = None,
norm_eps: float = 1e-5,
resnet_time_scale_shift: str = "default",
add_attention: bool = True,
@@ -192,6 +197,7 @@ def __init__(
resnet_time_scale_shift=resnet_time_scale_shift,
attention_head_dim=attention_head_dim if attention_head_dim is not None else block_out_channels[-1],
resnet_groups=norm_num_groups,
attn_groups=attn_norm_num_groups,
add_attention=add_attention,
)

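The fallback described in the new docstring can be summarized as a small standalone helper. This is an illustrative sketch of the resolution logic only, not a function that exists in diffusers; the actual implementation lives in `UNet2DMidBlock2D` (next file):

```python
from typing import Optional

def resolve_attn_groups(
    attn_norm_num_groups: Optional[int],
    norm_num_groups: int,
    resnet_time_scale_shift: str,
) -> Optional[int]:
    """Mirror how the GroupNorm of the mid block's Attention layer is chosen."""
    if attn_norm_num_groups is not None:
        # An explicit value always wins.
        return attn_norm_num_groups
    # Otherwise keep the old behavior: a GroupNorm only for the "default" variant.
    return norm_num_groups if resnet_time_scale_shift == "default" else None

assert resolve_attn_groups(32, 32, "scale_shift") == 32      # consistency models after this fix
assert resolve_attn_groups(None, 32, "default") == 32        # previous default behavior, unchanged
assert resolve_attn_groups(None, 32, "scale_shift") is None  # the old buggy path for converted checkpoints
```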
6 changes: 5 additions & 1 deletion src/diffusers/models/unet_2d_blocks.py
@@ -485,6 +485,7 @@ def __init__(
resnet_time_scale_shift: str = "default", # default, spatial
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
attn_groups: Optional[int] = None,
resnet_pre_norm: bool = True,
add_attention: bool = True,
attention_head_dim=1,
@@ -494,6 +495,9 @@ def __init__(
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
self.add_attention = add_attention

if attn_groups is None:
attn_groups = resnet_groups if resnet_time_scale_shift == "default" else None

# there is always at least one resnet
resnets = [
ResnetBlock2D(
@@ -526,7 +530,7 @@ def __init__(
dim_head=attention_head_dim,
rescale_output_factor=output_scale_factor,
eps=resnet_eps,
- norm_num_groups=resnet_groups if resnet_time_scale_shift == "default" else None,
+ norm_num_groups=attn_groups,
spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None,
residual_connection=True,
bias=True,
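For completeness, the new `attn_groups` argument can also be exercised on the mid block directly; the construction below is a hypothetical sketch (channel and embedding sizes are made up for illustration):

```python
from diffusers.models.unet_2d_blocks import UNet2DMidBlock2D

mid_block = UNet2DMidBlock2D(
    in_channels=64,
    temb_channels=128,
    resnet_time_scale_shift="scale_shift",
    resnet_groups=32,
    attn_groups=32,  # new argument added by this commit
    add_attention=True,
)
print(mid_block.attentions[0].group_norm.num_groups)  # 32
```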
30 changes: 30 additions & 0 deletions tests/models/test_models_unet_2d.py
@@ -74,6 +74,36 @@ def prepare_init_args_and_inputs_for_common(self):
inputs_dict = self.dummy_input
return init_dict, inputs_dict

def test_mid_block_attn_groups(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

init_dict["norm_num_groups"] = 16
init_dict["add_attention"] = True
init_dict["attn_norm_num_groups"] = 8

model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()

self.assertIsNotNone(
model.mid_block.attentions[0].group_norm, "Mid block Attention group norm should exist but does not."
)
self.assertEqual(
model.mid_block.attentions[0].group_norm.num_groups,
init_dict["attn_norm_num_groups"],
"Mid block Attention group norm does not have the expected number of groups.",
)

with torch.no_grad():
output = model(**inputs_dict)

if isinstance(output, dict):
output = output.to_tuple()[0]

self.assertIsNotNone(output)
expected_shape = inputs_dict["sample"].shape
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")


class UNetLDMModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
model_class = UNet2DModel
12 changes: 6 additions & 6 deletions tests/pipelines/consistency_models/test_consistency_models.py
@@ -216,9 +216,9 @@ def test_consistency_model_cd_multistep(self):

image_slice = image[0, -3:, -3:, -1]

- expected_slice = np.array([0.0888, 0.0881, 0.0666, 0.0479, 0.0292, 0.0195, 0.0201, 0.0163, 0.0254])
+ expected_slice = np.array([0.0146, 0.0158, 0.0092, 0.0086, 0.0000, 0.0000, 0.0000, 0.0000, 0.0058])

- assert np.abs(image_slice.flatten() - expected_slice).max() < 2e-2
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3

def test_consistency_model_cd_onestep(self):
unet = UNet2DModel.from_pretrained("diffusers/consistency_models", subfolder="diffusers_cd_imagenet64_l2")
@@ -239,9 +239,9 @@ def test_consistency_model_cd_onestep(self):

image_slice = image[0, -3:, -3:, -1]

- expected_slice = np.array([0.0340, 0.0152, 0.0063, 0.0267, 0.0221, 0.0107, 0.0416, 0.0186, 0.0217])
+ expected_slice = np.array([0.0059, 0.0003, 0.0000, 0.0023, 0.0052, 0.0007, 0.0165, 0.0081, 0.0095])

- assert np.abs(image_slice.flatten() - expected_slice).max() < 2e-2
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3

@require_torch_2
def test_consistency_model_cd_multistep_flash_attn(self):
@@ -263,7 +263,7 @@

image_slice = image[0, -3:, -3:, -1]

- expected_slice = np.array([0.1875, 0.1428, 0.1289, 0.2151, 0.2092, 0.1477, 0.1877, 0.1641, 0.1353])
+ expected_slice = np.array([0.1845, 0.1371, 0.1211, 0.2035, 0.1954, 0.1323, 0.1773, 0.1593, 0.1314])

assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3

@@ -289,6 +289,6 @@ def test_consistency_model_cd_onestep_flash_attn(self):

image_slice = image[0, -3:, -3:, -1]

- expected_slice = np.array([0.1663, 0.1948, 0.2275, 0.1680, 0.1204, 0.1245, 0.1858, 0.1338, 0.2095])
+ expected_slice = np.array([0.1623, 0.2009, 0.2387, 0.1731, 0.1168, 0.1202, 0.2031, 0.1327, 0.2447])

assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
