Skip to content

Commit

Permalink
[WIP] masked_latent_inputs for inpainting pipeline (huggingface#4819)
Browse files Browse the repository at this point in the history
* add

---------

Co-authored-by: yiyixuxu <yixu310@gmail.com>
  • Loading branch information
yiyixuxu and yiyixuxu authored Sep 1, 2023
1 parent d8b6f5d commit 5c404f2
Show file tree
Hide file tree
Showing 6 changed files with 117 additions and 17 deletions.
12 changes: 10 additions & 2 deletions src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
Original file line number Diff line number Diff line change
Expand Up @@ -872,7 +872,11 @@ def prepare_latents(

if return_image_latents or (latents is None and not is_strength_max):
image = image.to(device=device, dtype=dtype)
image_latents = self._encode_vae_image(image=image, generator=generator)

if image.shape[1] == 4:
image_latents = image
else:
image_latents = self._encode_vae_image(image=image, generator=generator)

if latents is None:
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
Expand Down Expand Up @@ -907,7 +911,11 @@ def prepare_mask_latents(
mask = mask.to(device=device, dtype=dtype)

masked_image = masked_image.to(device=device, dtype=dtype)
masked_image_latents = self._encode_vae_image(masked_image, generator=generator)

if masked_image.shape[1] == 4:
masked_image_latents = masked_image
else:
masked_image_latents = self._encode_vae_image(masked_image, generator=generator)

# duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
if mask.shape[0] < batch_size:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,11 @@ def prepare_mask_latents(
mask = mask.to(device=device, dtype=dtype)

masked_image = masked_image.to(device=device, dtype=dtype)
masked_image_latents = self._encode_vae_image(masked_image, generator=generator)

if masked_image.shape[1] == 4:
masked_image_latents = masked_image
else:
masked_image_latents = self._encode_vae_image(masked_image, generator=generator)

# duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
if mask.shape[0] < batch_size:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -554,7 +554,7 @@ def check_inputs(
if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")

if height % 8 != 0 or width % 8 != 0:
if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

if (callback_steps is None) or (
Expand Down Expand Up @@ -622,7 +622,11 @@ def prepare_latents(

if return_image_latents or (latents is None and not is_strength_max):
image = image.to(device=device, dtype=dtype)
image_latents = self._encode_vae_image(image=image, generator=generator)

if image.shape[1] == 4:
image_latents = image
else:
image_latents = self._encode_vae_image(image=image, generator=generator)

if latents is None:
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
Expand Down Expand Up @@ -670,7 +674,11 @@ def prepare_mask_latents(
mask = mask.to(device=device, dtype=dtype)

masked_image = masked_image.to(device=device, dtype=dtype)
masked_image_latents = self._encode_vae_image(masked_image, generator=generator)

if masked_image.shape[1] == 4:
masked_image_latents = masked_image
else:
masked_image_latents = self._encode_vae_image(masked_image, generator=generator)

# duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
if mask.shape[0] < batch_size:
Expand Down Expand Up @@ -715,6 +723,7 @@ def __call__(
prompt: Union[str, List[str]] = None,
image: PipelineImageInput = None,
mask_image: PipelineImageInput = None,
masked_image_latents: torch.FloatTensor = None,
height: Optional[int] = None,
width: Optional[int] = None,
strength: float = 1.0,
Expand Down Expand Up @@ -914,12 +923,6 @@ def __call__(
init_image = self.image_processor.preprocess(image, height=height, width=width)
init_image = init_image.to(dtype=torch.float32)

mask = self.mask_processor.preprocess(mask_image, height=height, width=width)

masked_image = init_image * (mask < 0.5)

mask_condition = mask.clone()

# 6. Prepare latent variables
num_channels_latents = self.vae.config.latent_channels
num_channels_unet = self.unet.config.in_channels
Expand Down Expand Up @@ -947,8 +950,15 @@ def __call__(
latents, noise = latents_outputs

# 7. Prepare mask latent variables
mask_condition = self.mask_processor.preprocess(mask_image, height=height, width=width)

if masked_image_latents is None:
masked_image = init_image * (mask_condition < 0.5)
else:
masked_image = masked_image_latents

mask, masked_image_latents = self.prepare_mask_latents(
mask,
mask_condition,
masked_image,
batch_size * num_images_per_prompt,
height,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -762,10 +762,16 @@ def prepare_mask_latents(

mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask

masked_image_latents = None
if masked_image is not None and masked_image.shape[1] == 4:
masked_image_latents = masked_image
else:
masked_image_latents = None

if masked_image is not None:
masked_image = masked_image.to(device=device, dtype=dtype)
masked_image_latents = self._encode_vae_image(masked_image, generator=generator)
if masked_image_latents is None:
masked_image = masked_image.to(device=device, dtype=dtype)
masked_image_latents = self._encode_vae_image(masked_image, generator=generator)

if masked_image_latents.shape[0] < batch_size:
if not batch_size % masked_image_latents.shape[0] == 0:
raise ValueError(
Expand Down Expand Up @@ -890,6 +896,7 @@ def __call__(
prompt_2: Optional[Union[str, List[str]]] = None,
image: PipelineImageInput = None,
mask_image: PipelineImageInput = None,
masked_image_latents: torch.FloatTensor = None,
height: Optional[int] = None,
width: Optional[int] = None,
strength: float = 0.9999,
Expand Down Expand Up @@ -1152,7 +1159,9 @@ def denoising_value_valid(dnv):

mask = self.mask_processor.preprocess(mask_image, height=height, width=width)

if init_image.shape[1] == 4:
if masked_image_latents is not None:
masked_image = masked_image_latents
elif init_image.shape[1] == 4:
# if images are in latent space, we can't mask it
masked_image = None
else:
Expand Down
37 changes: 37 additions & 0 deletions tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,43 @@ def test_stable_diffusion_inpaint_strength_zero_test(self):
with self.assertRaises(ValueError):
sd_pipe(**inputs).images

def test_stable_diffusion_inpaint_mask_latents(self):
    # Verifies that calling the pipeline with pre-encoded latents (`image`
    # given as latents plus an explicit `masked_image_latents`) yields the
    # same images, within 1e-2, as calling it with the raw image and mask.
    device = "cpu"
    pipe = self.pipeline_class(**self.get_dummy_components()).to(device)
    pipe.set_progress_bar_config(disable=None)

    # Reference run: `image` and `mask_image` come straight from the dummy
    # inputs and `masked_image_latents` is left unset.
    reference_inputs = self.get_dummy_inputs(device)
    reference_inputs["strength"] = 0.9
    reference_images = pipe(**reference_inputs).images

    # Latent run: preprocess the same dummy image and mask by hand, then
    # encode both the image and its masked version with the VAE.
    latent_inputs = self.get_dummy_inputs(device)
    pixel_image = pipe.image_processor.preprocess(latent_inputs["image"]).to(pipe.device)
    pixel_mask = pipe.mask_processor.preprocess(latent_inputs["mask_image"]).to(pipe.device)
    masked_pixels = pixel_image * (pixel_mask < 0.5)

    scaling = pipe.vae.config.scaling_factor
    encode_rng = torch.Generator(device=device).manual_seed(0)
    image_latents = pipe.vae.encode(pixel_image).latent_dist.sample(generator=encode_rng) * scaling
    # One draw is deliberately thrown away between the two encodes.
    # NOTE(review): presumably this stands in for the noise tensor the
    # pipeline samples between encoding the image and encoding the masked
    # image during the reference run -- confirm against `prepare_latents`.
    torch.randn((1, 4, 32, 32), generator=encode_rng)
    masked_latents = pipe.vae.encode(masked_pixels).latent_dist.sample(generator=encode_rng) * scaling

    # The generator handed to the pipeline is re-seeded and advanced by one
    # draw first.
    # NOTE(review): presumably this replaces the image-encoding draw that the
    # pipeline skips when `image` already has 4 channels -- verify.
    run_rng = torch.Generator(device=device).manual_seed(0)
    torch.randn((1, 4, 32, 32), generator=run_rng)

    latent_inputs["image"] = image_latents
    latent_inputs["mask_image"] = pixel_mask
    latent_inputs["masked_image_latents"] = masked_latents
    latent_inputs["strength"] = 0.9
    latent_inputs["generator"] = run_rng
    latent_images = pipe(**latent_inputs).images

    max_abs_diff = np.abs(reference_images - latent_images).max()
    assert max_abs_diff < 1e-2


class StableDiffusionSimpleInpaintPipelineFastTests(StableDiffusionInpaintPipelineFastTests):
pipeline_class = StableDiffusionInpaintPipeline
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -499,3 +499,35 @@ def test_stable_diffusion_xl_img2img_negative_conditions(self):
np.abs(image_slice_with_no_neg_conditions.flatten() - image_slice_with_neg_conditions.flatten()).max()
> 1e-4
)

def test_stable_diffusion_xl_inpaint_mask_latents(self):
    # Verifies that calling the pipeline with pre-encoded latents (`image`
    # given as latents plus an explicit `masked_image_latents`) yields the
    # same images, within 1e-2, as calling it with the raw image and mask.
    device = "cpu"
    pipe = self.pipeline_class(**self.get_dummy_components()).to(device)
    pipe.set_progress_bar_config(disable=None)

    # Reference run: `image` and `mask_image` come straight from the dummy
    # inputs and `masked_image_latents` is left unset.
    reference_inputs = self.get_dummy_inputs(device)
    reference_inputs["strength"] = 0.9
    reference_images = pipe(**reference_inputs).images

    # Latent run: preprocess the same dummy image and mask by hand, then
    # encode both through the pipeline's own `_encode_vae_image` helper.
    latent_inputs = self.get_dummy_inputs(device)
    pixel_image = pipe.image_processor.preprocess(latent_inputs["image"]).to(pipe.device)
    pixel_mask = pipe.mask_processor.preprocess(latent_inputs["mask_image"]).to(pipe.device)
    masked_pixels = pixel_image * (pixel_mask < 0.5)

    encode_rng = torch.Generator(device=device).manual_seed(0)
    image_latents = pipe._encode_vae_image(pixel_image, generator=encode_rng)
    # One draw is deliberately thrown away between the two encodes.
    # NOTE(review): presumably this stands in for the noise tensor the
    # pipeline samples between encoding the image and encoding the masked
    # image during the reference run -- confirm against `prepare_latents`.
    torch.randn((1, 4, 32, 32), generator=encode_rng)
    masked_latents = pipe._encode_vae_image(masked_pixels, generator=encode_rng)

    # The generator handed to the pipeline is re-seeded and advanced by one
    # draw first.
    # NOTE(review): presumably this replaces the image-encoding draw that the
    # pipeline skips when `image` is already in latent space -- verify.
    run_rng = torch.Generator(device=device).manual_seed(0)
    torch.randn((1, 4, 32, 32), generator=run_rng)

    latent_inputs["image"] = image_latents
    latent_inputs["mask_image"] = pixel_mask
    latent_inputs["masked_image_latents"] = masked_latents
    latent_inputs["strength"] = 0.9
    latent_inputs["generator"] = run_rng
    latent_images = pipe(**latent_inputs).images

    max_abs_diff = np.abs(reference_images - latent_images).max()
    assert max_abs_diff < 1e-2

0 comments on commit 5c404f2

Please sign in to comment.