Skip to content

Commit

Permalink
update expected results of slow tests (huggingface#268)
Browse files Browse the repository at this point in the history
* update expected results of slow tests

* relax sum and mean tests

* Print shapes when reporting exception

* formatting

* fix sentence

* relax test_stable_diffusion_fast_ddim for gpu fp16

* relax flaky tests on GPU

* added comment on large tolerances

* black

* format

* set scheduler seed

* added generator

* use np.isclose

* set num_inference_steps to 50

* fix dep. warning

* update expected_slice

* preprocess if image

* updated expected results

* updated expected from CI

* pass generator to VAE

* undo change back to orig

* use original

* revert back the expected on cpu

* revert back values for CPU

* more undo

* update result after using gen

* update mean

* set generator for mps

* update expected on CI server

* undo

* use new seed every time

* cpu manual seed

* reduce num_inference_steps

* style

* use generator for randn

Co-authored-by: Patrick von Platen <[email protected]>
  • Loading branch information
kashif and patrickvonplaten authored Sep 12, 2022
1 parent 25a51b6 commit f4781a0
Show file tree
Hide file tree
Showing 7 changed files with 75 additions and 47 deletions.
8 changes: 6 additions & 2 deletions src/diffusers/models/vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,7 +557,11 @@ def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[Decode
return DecoderOutput(sample=dec)

def forward(
self, sample: torch.FloatTensor, sample_posterior: bool = False, return_dict: bool = True
self,
sample: torch.FloatTensor,
sample_posterior: bool = False,
return_dict: bool = True,
generator: Optional[torch.Generator] = None,
) -> Union[DecoderOutput, torch.FloatTensor]:
r"""
Args:
Expand All @@ -570,7 +574,7 @@ def forward(
x = sample
posterior = self.encode(x).latent_dist
if sample_posterior:
z = posterior.sample()
z = posterior.sample(generator=generator)
else:
z = posterior.mode()
dec = self.decode(z).sample
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def __call__(

self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)

if not isinstance(init_image, torch.FloatTensor):
if isinstance(init_image, PIL.Image.Image):
init_image = preprocess(init_image)

# encode the init image into latents and scale the latents
Expand Down
14 changes: 8 additions & 6 deletions tests/test_models_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,11 +138,13 @@ def test_output_pretrained(self):
model.eval()
model.to(torch_device)

torch.manual_seed(0)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(0)

noise = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
noise = torch.randn(
1,
model.config.in_channels,
model.config.sample_size,
model.config.sample_size,
generator=torch.manual_seed(0),
)
noise = noise.to(torch_device)
time_step = torch.tensor([10] * noise.shape[0]).to(torch_device)

Expand All @@ -154,7 +156,7 @@ def test_output_pretrained(self):
expected_output_slice = torch.tensor([-13.3258, -20.1100, -15.9873, -17.6617, -23.0596, -17.9419, -13.3675, -16.1889, -12.3800])
# fmt: on

self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-3))


# TODO(Patrick) - Re-add this test after having cleaned up LDM
Expand Down
22 changes: 14 additions & 8 deletions tests/test_models_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,18 +87,24 @@ def test_output_pretrained(self):
image = image.to(torch_device)
with torch.no_grad():
_ = model(image, sample_posterior=True).sample

torch.manual_seed(0)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(0)

image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
generator = torch.manual_seed(0)
else:
generator = torch.Generator(device=torch_device).manual_seed(0)

image = torch.randn(
1,
model.config.in_channels,
model.config.sample_size,
model.config.sample_size,
generator=torch.manual_seed(0),
)
image = image.to(torch_device)
with torch.no_grad():
output = model(image, sample_posterior=True).sample
output = model(image, sample_posterior=True, generator=generator).sample

output_slice = output[0, -1, -3:, -3:].flatten().cpu()

# fmt: off
expected_output_slice = torch.tensor([-4.0078e-01, -3.8304e-04, -1.2681e-01, -1.1462e-01, 2.0095e-01, 1.0893e-01, -8.8248e-02, -3.0361e-01, -9.8646e-03])
expected_output_slice = torch.tensor([-0.1352, 0.0878, 0.0419, -0.0818, -0.1069, 0.0688, -0.1458, -0.4446, -0.0026])
# fmt: on
self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))
2 changes: 1 addition & 1 deletion tests/test_models_vq.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,4 +94,4 @@ def test_output_pretrained(self):
# fmt: off
expected_output_slice = torch.tensor([-0.0153, -0.4044, -0.1880, -0.5161, -0.2418, -0.4072, -0.1612, -0.0633, -0.0143])
# fmt: on
self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
26 changes: 14 additions & 12 deletions tests/test_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,7 @@ def test_stable_diffusion_ddim(self):

assert image.shape == (1, 128, 128, 3)
expected_slice = np.array([0.5112, 0.4692, 0.4715, 0.5206, 0.4894, 0.5114, 0.5096, 0.4932, 0.4755])

assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2

Expand Down Expand Up @@ -463,17 +464,18 @@ def test_score_sde_ve_pipeline(self):
sde_ve.to(torch_device)
sde_ve.set_progress_bar_config(disable=None)

torch.manual_seed(0)
image = sde_ve(num_inference_steps=2, output_type="numpy").images
generator = torch.manual_seed(0)
image = sde_ve(num_inference_steps=2, output_type="numpy", generator=generator).images

torch.manual_seed(0)
image_from_tuple = sde_ve(num_inference_steps=2, output_type="numpy", return_dict=False)[0]
generator = torch.manual_seed(0)
image_from_tuple = sde_ve(num_inference_steps=2, output_type="numpy", generator=generator, return_dict=False)[
0
]

image_slice = image[0, -3:, -3:, -1]
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]

assert image.shape == (1, 32, 32, 3)

expected_slice = np.array([0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
Expand Down Expand Up @@ -647,7 +649,7 @@ def test_stable_diffusion_inpaint(self):
bert = self.dummy_text_encoder
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

image = self.dummy_image.to(device).permute(0, 2, 3, 1)[0]
image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0]
init_image = Image.fromarray(np.uint8(image)).convert("RGB")
mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((128, 128))

Expand Down Expand Up @@ -729,8 +731,8 @@ def test_from_pretrained_save_pretrained(self):
new_ddpm.to(torch_device)

generator = torch.manual_seed(0)

image = ddpm(generator=generator, output_type="numpy").images

generator = generator.manual_seed(0)
new_image = new_ddpm(generator=generator, output_type="numpy").images

Expand All @@ -750,8 +752,8 @@ def test_from_pretrained_hub(self):
ddpm_from_hub.set_progress_bar_config(disable=None)

generator = torch.manual_seed(0)

image = ddpm(generator=generator, output_type="numpy").images

generator = generator.manual_seed(0)
new_image = ddpm_from_hub(generator=generator, output_type="numpy").images

Expand All @@ -774,8 +776,8 @@ def test_from_pretrained_hub_pass_model(self):
ddpm_from_hub_custom_model.set_progress_bar_config(disable=None)

generator = torch.manual_seed(0)

image = ddpm_from_hub_custom_model(generator=generator, output_type="numpy").images

generator = generator.manual_seed(0)
new_image = ddpm_from_hub(generator=generator, output_type="numpy").images

Expand Down Expand Up @@ -981,14 +983,14 @@ def test_score_sde_ve_pipeline(self):
sde_ve.to(torch_device)
sde_ve.set_progress_bar_config(disable=None)

torch.manual_seed(0)
image = sde_ve(num_inference_steps=300, output_type="numpy").images
generator = torch.manual_seed(0)
image = sde_ve(num_inference_steps=10, output_type="numpy", generator=generator).images

image_slice = image[0, -3:, -3:, -1]

assert image.shape == (1, 256, 256, 3)

expected_slice = np.array([0.64363, 0.5868, 0.3031, 0.2284, 0.7409, 0.3216, 0.25643, 0.6557, 0.2633])
expected_slice = np.array([0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

@slow
Expand Down
48 changes: 31 additions & 17 deletions tests/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,13 +318,14 @@ def test_full_loop_no_noise(self):

model = self.dummy_model()
sample = self.dummy_sample_deter
generator = torch.manual_seed(0)

for t in reversed(range(num_trained_timesteps)):
# 1. predict noise residual
residual = model(sample, t)

# 2. predict previous mean of sample x_t-1
pred_prev_sample = scheduler.step(residual, t, sample).prev_sample
pred_prev_sample = scheduler.step(residual, t, sample, generator=generator).prev_sample

# if t > 0:
# noise = self.dummy_sample_deter
Expand All @@ -336,7 +337,7 @@ def test_full_loop_no_noise(self):
result_sum = torch.sum(torch.abs(sample))
result_mean = torch.mean(torch.abs(sample))

assert abs(result_sum.item() - 259.0883) < 1e-2
assert abs(result_sum.item() - 258.9070) < 1e-2
assert abs(result_mean.item() - 0.3374) < 1e-3


Expand Down Expand Up @@ -657,7 +658,7 @@ def test_full_loop_no_noise(self):
class ScoreSdeVeSchedulerTest(unittest.TestCase):
# TODO adapt with class SchedulerCommonTest (scheduler needs Numpy Integration)
scheduler_classes = (ScoreSdeVeScheduler,)
forward_default_kwargs = (("seed", 0),)
forward_default_kwargs = ()

@property
def dummy_sample(self):
Expand Down Expand Up @@ -718,13 +719,19 @@ def check_over_configs(self, time_step=0, **config):
scheduler.save_config(tmpdirname)
new_scheduler = scheduler_class.from_config(tmpdirname)

output = scheduler.step_pred(residual, time_step, sample, **kwargs).prev_sample
new_output = new_scheduler.step_pred(residual, time_step, sample, **kwargs).prev_sample
output = scheduler.step_pred(
residual, time_step, sample, generator=torch.manual_seed(0), **kwargs
).prev_sample
new_output = new_scheduler.step_pred(
residual, time_step, sample, generator=torch.manual_seed(0), **kwargs
).prev_sample

assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

output = scheduler.step_correct(residual, sample, **kwargs).prev_sample
new_output = new_scheduler.step_correct(residual, sample, **kwargs).prev_sample
output = scheduler.step_correct(residual, sample, generator=torch.manual_seed(0), **kwargs).prev_sample
new_output = new_scheduler.step_correct(
residual, sample, generator=torch.manual_seed(0), **kwargs
).prev_sample

assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler correction are not identical"

Expand All @@ -743,13 +750,19 @@ def check_over_forward(self, time_step=0, **forward_kwargs):
scheduler.save_config(tmpdirname)
new_scheduler = scheduler_class.from_config(tmpdirname)

output = scheduler.step_pred(residual, time_step, sample, **kwargs).prev_sample
new_output = new_scheduler.step_pred(residual, time_step, sample, **kwargs).prev_sample
output = scheduler.step_pred(
residual, time_step, sample, generator=torch.manual_seed(0), **kwargs
).prev_sample
new_output = new_scheduler.step_pred(
residual, time_step, sample, generator=torch.manual_seed(0), **kwargs
).prev_sample

assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

output = scheduler.step_correct(residual, sample, **kwargs).prev_sample
new_output = new_scheduler.step_correct(residual, sample, **kwargs).prev_sample
output = scheduler.step_correct(residual, sample, generator=torch.manual_seed(0), **kwargs).prev_sample
new_output = new_scheduler.step_correct(
residual, sample, generator=torch.manual_seed(0), **kwargs
).prev_sample

assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler correction are not identical"

Expand Down Expand Up @@ -779,26 +792,27 @@ def test_full_loop_no_noise(self):

scheduler.set_sigmas(num_inference_steps)
scheduler.set_timesteps(num_inference_steps)
generator = torch.manual_seed(0)

for i, t in enumerate(scheduler.timesteps):
sigma_t = scheduler.sigmas[i]

for _ in range(scheduler.correct_steps):
with torch.no_grad():
model_output = model(sample, sigma_t)
sample = scheduler.step_correct(model_output, sample, **kwargs).prev_sample
sample = scheduler.step_correct(model_output, sample, generator=generator, **kwargs).prev_sample

with torch.no_grad():
model_output = model(sample, sigma_t)

output = scheduler.step_pred(model_output, t, sample, **kwargs)
output = scheduler.step_pred(model_output, t, sample, generator=generator, **kwargs)
sample, _ = output.prev_sample, output.prev_sample_mean

result_sum = torch.sum(torch.abs(sample))
result_mean = torch.mean(torch.abs(sample))

assert abs(result_sum.item() - 14379591680.0) < 1e-2
assert abs(result_mean.item() - 18723426.0) < 1e-3
assert np.isclose(result_sum.item(), 14372758528.0)
assert np.isclose(result_mean.item(), 18714530.0)

def test_step_shape(self):
kwargs = dict(self.forward_default_kwargs)
Expand All @@ -817,8 +831,8 @@ def test_step_shape(self):
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
kwargs["num_inference_steps"] = num_inference_steps

output_0 = scheduler.step_pred(residual, 0, sample, **kwargs).prev_sample
output_1 = scheduler.step_pred(residual, 1, sample, **kwargs).prev_sample
output_0 = scheduler.step_pred(residual, 0, sample, generator=torch.manual_seed(0), **kwargs).prev_sample
output_1 = scheduler.step_pred(residual, 1, sample, generator=torch.manual_seed(0), **kwargs).prev_sample

self.assertEqual(output_0.shape, sample.shape)
self.assertEqual(output_0.shape, output_1.shape)

0 comments on commit f4781a0

Please sign in to comment.