Commit

[StableDiffusionInpaintPipeline] fix batch_size for mask and masked latents (huggingface#1279)

fix batch_size for mask and masked latents
patil-suraj authored Nov 14, 2022
1 parent c9b3463 commit a8d0977
Showing 2 changed files with 42 additions and 1 deletion.
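
When num_images_per_prompt > 1, the pipeline duplicates the text embeddings and samples batch_size * num_images_per_prompt latents, but the mask and masked-image latents were still being prepared for only batch_size samples, so concatenating them with the latents in the denoising loop failed with a batch-dimension mismatch. A minimal sketch of the shape arithmetic behind the fix (simplified standalone tensors, not the pipeline's actual code):

    import torch

    batch_size = 1              # one prompt
    num_images_per_prompt = 2   # two images requested for that prompt
    latent_h = latent_w = 16    # e.g. 128 // vae_scale_factor

    # The denoising loop works on batch_size * num_images_per_prompt samples.
    effective_batch = batch_size * num_images_per_prompt
    latents = torch.randn(effective_batch, 4, latent_h, latent_w)

    # Before the fix these two tensors kept only batch_size samples.
    mask = torch.randn(batch_size, 1, latent_h, latent_w)
    masked_image_latents = torch.randn(batch_size, 4, latent_h, latent_w)

    # Tiling them to the effective batch size makes the channel-wise concat work.
    mask = mask.repeat(effective_batch // mask.shape[0], 1, 1, 1)
    masked_image_latents = masked_image_latents.repeat(
        effective_batch // masked_image_latents.shape[0], 1, 1, 1
    )

    # The inpainting UNet takes 4 + 1 + 4 = 9 input channels per sample.
    unet_input = torch.cat([latents, mask, masked_image_latents], dim=1)
    assert unet_input.shape == (effective_batch, 9, latent_h, latent_w)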
@@ -536,7 +536,7 @@ def __call__(
         mask, masked_image_latents = self.prepare_mask_latents(
             mask,
             masked_image,
-            batch_size,
+            batch_size * num_images_per_prompt,
             height,
             width,
             text_embeddings.dtype,
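
For context, prepare_mask_latents uses the batch argument it receives to tile the resized mask and the VAE-encoded masked image, and doubles them again when classifier-free guidance is enabled. A rough, hedged sketch of that tiling logic, assuming already-prepared latent-resolution tensors rather than the method's real inputs:

    import torch

    def prepare_mask_latents_sketch(mask, masked_image_latents, batch_size, do_classifier_free_guidance):
        # Tile both tensors so their batch dimension matches the requested batch size.
        mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
        masked_image_latents = masked_image_latents.repeat(
            batch_size // masked_image_latents.shape[0], 1, 1, 1
        )
        # Duplicate for the unconditional + conditional passes of classifier-free guidance.
        if do_classifier_free_guidance:
            mask = torch.cat([mask] * 2)
            masked_image_latents = torch.cat([masked_image_latents] * 2)
        return mask, masked_image_latents

Passing plain batch_size here left both tensors one tile short whenever num_images_per_prompt > 1, which is what the one-line change above corrects.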
41 changes: 41 additions & 0 deletions tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py
@@ -215,6 +215,47 @@ def test_stable_diffusion_inpaint(self):
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
         assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
 
+    def test_stable_diffusion_inpaint_with_num_images_per_prompt(self):
+        device = "cpu"
+        unet = self.dummy_cond_unet_inpaint
+        scheduler = PNDMScheduler(skip_prk_steps=True)
+        vae = self.dummy_vae
+        bert = self.dummy_text_encoder
+        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+        image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0]
+        init_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((128, 128))
+        mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((128, 128))
+
+        # make sure here that pndm scheduler skips prk
+        sd_pipe = StableDiffusionInpaintPipeline(
+            unet=unet,
+            scheduler=scheduler,
+            vae=vae,
+            text_encoder=bert,
+            tokenizer=tokenizer,
+            safety_checker=None,
+            feature_extractor=None,
+        )
+        sd_pipe = sd_pipe.to(device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        prompt = "A painting of a squirrel eating a burger"
+        generator = torch.Generator(device=device).manual_seed(0)
+        images = sd_pipe(
+            [prompt],
+            generator=generator,
+            guidance_scale=6.0,
+            num_inference_steps=2,
+            output_type="np",
+            image=init_image,
+            mask_image=mask_image,
+            num_images_per_prompt=2,
+        ).images
+
+        # check if the output is a list of 2 images
+        assert len(images) == 2
+
     @unittest.skipIf(torch_device != "cuda", "This test requires a GPU")
     def test_stable_diffusion_inpaint_fp16(self):
         """Test that stable diffusion inpaint_legacy works with fp16"""
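
The new unit test exercises the fix with the tiny dummy components; the same behaviour can also be checked end to end against a full checkpoint. A hedged example, assuming the runwayml/stable-diffusion-inpainting weights, a CUDA device, and placeholder PIL images in place of a real photo and mask:

    import torch
    from PIL import Image
    from diffusers import StableDiffusionInpaintPipeline

    # Placeholder inputs; in practice these would be a real image and a painted mask.
    init_image = Image.new("RGB", (512, 512), color="gray")
    mask_image = Image.new("L", (512, 512), color=255)

    pipe = StableDiffusionInpaintPipeline.from_pretrained(
        "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16
    ).to("cuda")

    # With the fix, asking for several images per prompt no longer trips a
    # batch-size mismatch between the latents and the mask/masked-image latents.
    images = pipe(
        prompt="A painting of a squirrel eating a burger",
        image=init_image,
        mask_image=mask_image,
        num_images_per_prompt=2,
    ).images
    assert len(images) == 2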
