[Tests] Improve unet / vae tests (huggingface#1018)
* improve tests

* up

* finish

* upload

* add init

* up

* finish vae

* finish

* reduce loading time with device_map (see the sketch after the change summary below)

* remove device_map from CPU

* up
patrickvonplaten authored Oct 28, 2022
1 parent ab079f2 commit a80480f
Showing 10 changed files with 506 additions and 188 deletions.
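
The "reduce loading time with device_map" entry boils down to passing device_map="auto" to from_pretrained in the slow tests, which hands weight placement to accelerate so checkpoints land on the target device without a full intermediate CPU copy; the later "remove device_map from CPU" entry drops it again for CPU-only paths. A minimal sketch of the pattern, using one of the checkpoints exercised in the diff below:

    import torch
    from diffusers import UNet2DModel

    # device_map="auto" asks accelerate to place the weights at load time,
    # which is what the commit message credits for the reduced loading time.
    model = UNet2DModel.from_pretrained("google/ncsnpp-celebahq-256", device_map="auto")
    model.to("cuda" if torch.cuda.is_available() else "cpu").eval()
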
2 changes: 2 additions & 0 deletions setup.py
@@ -94,6 +94,7 @@
"modelcards>=0.1.4",
"numpy",
"onnxruntime",
"parameterized",
"pytest",
"pytest-timeout",
"pytest-xdist",
@@ -181,6 +182,7 @@ def run(self):
"accelerate",
"datasets",
"onnxruntime",
"parameterized",
"pytest",
"pytest-timeout",
"pytest-xdist",
1 change: 1 addition & 0 deletions src/diffusers/dependency_versions_table.py
@@ -18,6 +18,7 @@
"modelcards": "modelcards>=0.1.4",
"numpy": "numpy",
"onnxruntime": "onnxruntime",
"parameterized": "parameterized",
"pytest": "pytest",
"pytest-timeout": "pytest-timeout",
"pytest-xdist": "pytest-xdist",
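
Both setup.py and the dependency table gain parameterized, which supplies the @parameterized.expand decorator used throughout the reworked UNet tests further down. A minimal, self-contained sketch of the pattern (class name and values here are illustrative, not taken from the diff):

    import unittest

    from parameterized import parameterized


    class ExampleTests(unittest.TestCase):
        # Each list entry becomes its own generated test case; the decorator
        # appends an index (and a parameter-derived suffix) to the test name.
        @parameterized.expand([
            [33, 4],
            [47, 0.55],
        ])
        def test_values(self, seed, timestep):
            self.assertGreaterEqual(seed, 0)
            self.assertGreater(timestep, 0)
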
2 changes: 1 addition & 1 deletion src/diffusers/utils/__init__.py
@@ -40,7 +40,7 @@


if is_torch_available():
-from .testing_utils import floats_tensor, load_image, parse_flag_from_env, slow, torch_device
+from .testing_utils import floats_tensor, load_image, parse_flag_from_env, require_torch_gpu, slow, torch_device


logger = get_logger(__name__)
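
require_torch_gpu is newly re-exported from testing_utils so the fp16 integration tests below can skip themselves on machines without CUDA. Its definition is not part of this diff; a sketch of how such a guard is commonly written (an assumption about the shape, not the diffusers source):

    import unittest

    import torch

    torch_device = "cuda" if torch.cuda.is_available() else "cpu"


    def require_torch_gpu(test_case):
        # Skip the decorated test or class unless a CUDA device is available.
        return unittest.skipUnless(torch_device == "cuda", "test requires a GPU")(test_case)
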
Empty file added tests/models/__init__.py
Empty file.
@@ -28,7 +28,7 @@ class UnetModel1DTests(unittest.TestCase):
@slow
def test_unet_1d_maestro(self):
model_id = "harmonai/maestro-150k"
-model = UNet1DModel.from_pretrained(model_id, subfolder="unet")
+model = UNet1DModel.from_pretrained(model_id, subfolder="unet", device_map="auto")
model.to(torch_device)

sample_size = 65536
248 changes: 196 additions & 52 deletions tests/test_models_unet_2d.py → tests/models/test_models_unet_2d.py
@@ -21,9 +21,10 @@
import torch

from diffusers import UNet2DConditionModel, UNet2DModel
-from diffusers.utils import floats_tensor, slow, torch_device
+from diffusers.utils import floats_tensor, require_torch_gpu, slow, torch_device
+from parameterized import parameterized

-from .test_modeling_common import ModelTesterMixin
+from ..test_modeling_common import ModelTesterMixin


torch.backends.cuda.matmul.allow_tf32 = False
@@ -66,28 +67,6 @@ def prepare_init_args_and_inputs_for_common(self):
return init_dict, inputs_dict


# TODO(Patrick) - Re-add this test after having correctly added the final VE checkpoints
# def test_output_pretrained(self):
# model = UNet2DModel.from_pretrained("fusing/ddpm_dummy_update", subfolder="unet")
# model.eval()
#
# torch.manual_seed(0)
# if torch.cuda.is_available():
# torch.cuda.manual_seed_all(0)
#
# noise = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
# time_step = torch.tensor([10])
#
# with torch.no_grad():
# output = model(noise, time_step).sample
#
# output_slice = output[0, -1, -3:, -3:].flatten()
# fmt: off
# expected_output_slice = torch.tensor([0.2891, -0.1899, 0.2595, -0.6214, 0.0968, -0.2622, 0.4688, 0.1311, 0.0053])
# fmt: on
# self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))


class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
model_class = UNet2DModel

@@ -170,7 +149,9 @@ def test_from_pretrained_accelerate_wont_change_results(self):
torch.cuda.empty_cache()
gc.collect()

model_normal_load, _ = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True)
model_normal_load, _ = UNet2DModel.from_pretrained(
"fusing/unet-ldm-dummy-update", output_loading_info=True, device_map="auto"
)
model_normal_load.to(torch_device)
model_normal_load.eval()
arr_normal_load = model_normal_load(noise, time_step)["sample"]
@@ -309,31 +290,6 @@ def test_gradient_checkpointing(self):
self.assertTrue(torch.allclose(param.grad.data, named_params_2[name].grad.data, atol=5e-5))


# TODO(Patrick) - Re-add this test after having cleaned up LDM
# def test_output_pretrained_spatial_transformer(self):
# model = UNetLDMModel.from_pretrained("fusing/unet-ldm-dummy-spatial")
# model.eval()
#
# torch.manual_seed(0)
# if torch.cuda.is_available():
# torch.cuda.manual_seed_all(0)
#
# noise = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
# context = torch.ones((1, 16, 64), dtype=torch.float32)
# time_step = torch.tensor([10] * noise.shape[0])
#
# with torch.no_grad():
# output = model(noise, time_step, context=context)
#
# output_slice = output[0, -1, -3:, -3:].flatten()
# fmt: off
# expected_output_slice = torch.tensor([61.3445, 56.9005, 29.4339, 59.5497, 60.7375, 34.1719, 48.1951, 42.6569, 25.0890])
# fmt: on
#
# self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
#


class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
model_class = UNet2DModel

@@ -383,7 +339,9 @@ def prepare_init_args_and_inputs_for_common(self):

@slow
def test_from_pretrained_hub(self):
model, loading_info = UNet2DModel.from_pretrained("google/ncsnpp-celebahq-256", output_loading_info=True)
model, loading_info = UNet2DModel.from_pretrained(
"google/ncsnpp-celebahq-256", output_loading_info=True, device_map="auto"
)
self.assertIsNotNone(model)
self.assertEqual(len(loading_info["missing_keys"]), 0)

@@ -397,7 +355,7 @@ def test_from_pretrained_hub(self):

@slow
def test_output_pretrained_ve_mid(self):
model = UNet2DModel.from_pretrained("google/ncsnpp-celebahq-256")
model = UNet2DModel.from_pretrained("google/ncsnpp-celebahq-256", device_map="auto")
model.to(torch_device)

torch.manual_seed(0)
@@ -449,3 +407,189 @@ def test_output_pretrained_ve_large(self):
def test_forward_with_norm_groups(self):
# not required for this model
pass


@slow
class UNet2DConditionModelIntegrationTests(unittest.TestCase):
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
torch.cuda.empty_cache()

def get_latents(self, seed=0, shape=(4, 4, 64, 64), fp16=False):
batch_size, channels, height, width = shape
generator = torch.Generator(device=torch_device).manual_seed(seed)
dtype = torch.float16 if fp16 else torch.float32
image = torch.randn(batch_size, channels, height, width, device=torch_device, generator=generator, dtype=dtype)

return image

def get_unet_model(self, fp16=False, model_id="CompVis/stable-diffusion-v1-4"):
revision = "fp16" if fp16 else None
torch_dtype = torch.float16 if fp16 else torch.float32

model = UNet2DConditionModel.from_pretrained(
model_id, subfolder="unet", torch_dtype=torch_dtype, revision=revision, device_map="auto"
)
model.to(torch_device).eval()

return model

def get_encoder_hidden_states(self, seed=0, shape=(4, 77, 768), fp16=False):
generator = torch.Generator(device=torch_device).manual_seed(seed)
dtype = torch.float16 if fp16 else torch.float32
return torch.randn(shape, device=torch_device, generator=generator, dtype=dtype)

@parameterized.expand(
[
# fmt: off
[33, 4, [-0.4424, 0.1510, -0.1937, 0.2118, 0.3746, -0.3957, 0.0160, -0.0435]],
[47, 0.55, [-0.1508, 0.0379, -0.3075, 0.2540, 0.3633, -0.0821, 0.1719, -0.0207]],
[21, 0.89, [-0.6479, 0.6364, -0.3464, 0.8697, 0.4443, -0.6289, -0.0091, 0.1778]],
[9, 1000, [0.8888, -0.5659, 0.5834, -0.7469, 1.1912, -0.3923, 1.1241, -0.4424]],
# fmt: on
]
)
def test_compvis_sd_v1_4(self, seed, timestep, expected_slice):
model = self.get_unet_model(model_id="CompVis/stable-diffusion-v1-4")
latents = self.get_latents(seed)
encoder_hidden_states = self.get_encoder_hidden_states(seed)

with torch.no_grad():
sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample

assert sample.shape == latents.shape

output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
expected_output_slice = torch.tensor(expected_slice)

assert torch.allclose(output_slice, expected_output_slice, atol=1e-4)

@parameterized.expand(
[
# fmt: off
[83, 4, [-0.2323, -0.1304, 0.0813, -0.3093, -0.0919, -0.1571, -0.1125, -0.5806]],
[17, 0.55, [-0.0831, -0.2443, 0.0901, -0.0919, 0.3396, 0.0103, -0.3743, 0.0701]],
[8, 0.89, [-0.4863, 0.0859, 0.0875, -0.1658, 0.9199, -0.0114, 0.4839, 0.4639]],
[3, 1000, [-0.5649, 0.2402, -0.5518, 0.1248, 1.1328, -0.2443, -0.0325, -1.0078]],
# fmt: on
]
)
@require_torch_gpu
def test_compvis_sd_v1_4_fp16(self, seed, timestep, expected_slice):
model = self.get_unet_model(model_id="CompVis/stable-diffusion-v1-4", fp16=True)
latents = self.get_latents(seed, fp16=True)
encoder_hidden_states = self.get_encoder_hidden_states(seed, fp16=True)

with torch.no_grad():
sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample

assert sample.shape == latents.shape

output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
expected_output_slice = torch.tensor(expected_slice)

assert torch.allclose(output_slice, expected_output_slice, atol=1e-4)

@parameterized.expand(
[
# fmt: off
[33, 4, [-0.4430, 0.1570, -0.1867, 0.2376, 0.3205, -0.3681, 0.0525, -0.0722]],
[47, 0.55, [-0.1415, 0.0129, -0.3136, 0.2257, 0.3430, -0.0536, 0.2114, -0.0436]],
[21, 0.89, [-0.7091, 0.6664, -0.3643, 0.9032, 0.4499, -0.6541, 0.0139, 0.1750]],
[9, 1000, [0.8878, -0.5659, 0.5844, -0.7442, 1.1883, -0.3927, 1.1192, -0.4423]],
# fmt: on
]
)
def test_compvis_sd_v1_5(self, seed, timestep, expected_slice):
model = self.get_unet_model(model_id="runwayml/stable-diffusion-v1-5")
latents = self.get_latents(seed)
encoder_hidden_states = self.get_encoder_hidden_states(seed)

with torch.no_grad():
sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample

assert sample.shape == latents.shape

output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
expected_output_slice = torch.tensor(expected_slice)

assert torch.allclose(output_slice, expected_output_slice, atol=1e-4)

@parameterized.expand(
[
# fmt: off
[83, 4, [-0.2695, -0.1669, 0.0073, -0.3181, -0.1187, -0.1676, -0.1395, -0.5972]],
[17, 0.55, [-0.1290, -0.2588, 0.0551, -0.0916, 0.3286, 0.0238, -0.3669, 0.0322]],
[8, 0.89, [-0.5283, 0.1198, 0.0870, -0.1141, 0.9189, -0.0150, 0.5474, 0.4319]],
[3, 1000, [-0.5601, 0.2411, -0.5435, 0.1268, 1.1338, -0.2427, -0.0280, -1.0020]],
# fmt: on
]
)
@require_torch_gpu
def test_compvis_sd_v1_5_fp16(self, seed, timestep, expected_slice):
model = self.get_unet_model(model_id="runwayml/stable-diffusion-v1-5", fp16=True)
latents = self.get_latents(seed, fp16=True)
encoder_hidden_states = self.get_encoder_hidden_states(seed, fp16=True)

with torch.no_grad():
sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample

assert sample.shape == latents.shape

output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
expected_output_slice = torch.tensor(expected_slice)

assert torch.allclose(output_slice, expected_output_slice, atol=1e-4)

@parameterized.expand(
[
# fmt: off
[33, 4, [-0.7639, 0.0106, -0.1615, -0.3487, -0.0423, -0.7972, 0.0085, -0.4858]],
[47, 0.55, [-0.6564, 0.0795, -1.9026, -0.6258, 1.8235, 1.2056, 1.2169, 0.9073]],
[21, 0.89, [0.0327, 0.4399, -0.6358, 0.3417, 0.4120, -0.5621, -0.0397, -1.0430]],
[9, 1000, [0.1600, 0.7303, -1.0556, -0.3515, -0.7440, -1.2037, -1.8149, -1.8931]],
# fmt: on
]
)
def test_compvis_sd_inpaint(self, seed, timestep, expected_slice):
model = self.get_unet_model(model_id="runwayml/stable-diffusion-inpainting")
latents = self.get_latents(seed, shape=(4, 9, 64, 64))
encoder_hidden_states = self.get_encoder_hidden_states(seed)

with torch.no_grad():
sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample

assert sample.shape == (4, 4, 64, 64)

output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
expected_output_slice = torch.tensor(expected_slice)

assert torch.allclose(output_slice, expected_output_slice, atol=1e-4)

@parameterized.expand(
[
# fmt: off
[83, 4, [-0.1047, -1.7227, 0.1067, 0.0164, -0.5698, -0.4172, -0.1388, 1.1387]],
[17, 0.55, [0.0975, -0.2856, -0.3508, -0.4600, 0.3376, 0.2930, -0.2747, -0.7026]],
[8, 0.89, [-0.0952, 0.0183, -0.5825, -0.1981, 0.1131, 0.4668, -0.0395, -0.3486]],
[3, 1000, [0.4790, 0.4949, -1.0732, -0.7158, 0.7959, -0.9478, 0.1105, -0.9741]],
# fmt: on
]
)
@require_torch_gpu
def test_compvis_sd_inpaint_fp16(self, seed, timestep, expected_slice):
model = self.get_unet_model(model_id="runwayml/stable-diffusion-inpainting", fp16=True)
latents = self.get_latents(seed, shape=(4, 9, 64, 64), fp16=True)
encoder_hidden_states = self.get_encoder_hidden_states(seed, fp16=True)

with torch.no_grad():
sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample

assert sample.shape == (4, 4, 64, 64)

output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
expected_output_slice = torch.tensor(expected_slice)

assert torch.allclose(output_slice, expected_output_slice, atol=1e-4)
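
All of the integration tests above carry the @slow decorator, which testing_utils typically gates on an environment flag (the same module exports parse_flag_from_env), so they are skipped in a default test run. A hedged example of opting in locally, assuming the usual RUN_SLOW flag; the shell equivalent would be RUN_SLOW=1 python -m pytest tests/models/test_models_unet_2d.py:

    import os

    import pytest

    # Set the opt-in flag before pytest imports the test modules; without it
    # the @slow-decorated tests are skipped.
    os.environ["RUN_SLOW"] = "1"
    pytest.main(["-q", "tests/models/test_models_unet_2d.py"])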