diff --git a/image_to_image.py b/image_to_image.py
index e6cb2fe..125c77b 100644
--- a/image_to_image.py
+++ b/image_to_image.py
@@ -39,6 +39,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
     """
     From https://github.com/huggingface/diffusers/pull/241
     """
+
     def __init__(
         self,
         vae: AutoencoderKL,
@@ -111,7 +112,12 @@ def __call__(
 
         if init_image is not None:
             init_latents_orig, latents, init_timestep = self.latents_from_init_image(
-                init_image, prompt_strength, offset, num_inference_steps, batch_size, generator
+                init_image,
+                prompt_strength,
+                offset,
+                num_inference_steps,
+                batch_size,
+                generator,
             )
         else:
             latents = torch.randn(
@@ -137,6 +143,8 @@
         if accepts_eta:
             extra_step_kwargs["eta"] = eta
 
+        mask_noise = torch.randn(latents.shape, generator=generator, device=self.device)
+
         t_start = max(num_inference_steps - init_timestep + offset, 0)
         for i, t in tqdm(enumerate(self.scheduler.timesteps[t_start:])):
             # expand the latents if we are doing classifier free guidance
@@ -161,8 +169,14 @@
                 "prev_sample"
             ]
 
+            # replace the unmasked part with original latents, with added noise
             if mask is not None:
-                latents = init_latents_orig * mask + latents * (1 - mask)
+                timesteps = self.scheduler.timesteps[t_start + i]
+                timesteps = torch.tensor(
+                    [timesteps] * batch_size, dtype=torch.long, device=self.device
+                )
+                noisy_init_latents = self.scheduler.add_noise(init_latents_orig, mask_noise, timesteps)
+                latents = noisy_init_latents * mask + latents * (1 - mask)
 
         # scale and decode the image latents with vae
         latents = 1 / 0.18215 * latents
diff --git a/predict.py b/predict.py
index 188577e..c7f91ef 100644
--- a/predict.py
+++ b/predict.py
@@ -52,7 +52,7 @@ def predict(
         description="Inital image to generate variations of", default=None
     ),
     mask: Path = Input(
-        description="Black and white image to use as mask for inpainting over init_image. Black pixels are inpainted and white pixels are preserved",
+        description="Black and white image to use as mask for inpainting over init_image. Black pixels are inpainted and white pixels are preserved. Experimental feature, tends to work better with prompt strength of 0.5-0.7",
         default=None,
     ),
     prompt_strength: float = Input(
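
Context on the key change in image_to_image.py: the old code pasted the clean init latents back into the masked region after every denoising step, so the preserved region sat at a different noise level than the rest of the latent tensor. The patch instead re-noises init_latents_orig to the current timestep with the scheduler's add_noise() before compositing, using a single mask_noise tensor sampled once before the loop. Below is a minimal standalone sketch of that compositing step; the UNet prediction and scheduler.step() call are elided, and the DDIMScheduler choice, latent shape, and t_start value are illustrative assumptions, not taken from the diff.

import torch
from diffusers import DDIMScheduler

# Illustrative setup; the real pipeline builds all of this inside __call__.
scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
num_inference_steps = 50
scheduler.set_timesteps(num_inference_steps)

batch_size = 1
latent_shape = (batch_size, 4, 64, 64)  # assumed latent shape for 512x512 output

generator = torch.Generator().manual_seed(0)
init_latents_orig = torch.randn(latent_shape, generator=generator)  # stand-in for the encoded init image
latents = torch.randn(latent_shape, generator=generator)  # stand-in for the evolving denoising state
mask = torch.zeros(latent_shape)  # 1 = preserve, 0 = inpaint
mask[..., : latent_shape[-1] // 2] = 1.0  # preserve the left half, as an example

# Sampled once before the loop, as in the patch, so the preserved region
# is re-noised with the same noise tensor at every step.
mask_noise = torch.randn(latent_shape, generator=generator)

t_start = 10  # illustrative; the pipeline derives this from prompt_strength
for i, t in enumerate(scheduler.timesteps[t_start:]):
    # ... UNet noise prediction and scheduler.step(...) would update `latents` here ...

    # The change from the diff: noise the original latents to the current
    # timestep before compositing, instead of pasting them in clean.
    current_t = int(scheduler.timesteps[t_start + i])
    timesteps = torch.full((batch_size,), current_t, dtype=torch.long)
    noisy_init_latents = scheduler.add_noise(init_latents_orig, mask_noise, timesteps)
    latents = noisy_init_latents * mask + latents * (1 - mask)

Reusing one mask_noise tensor rather than sampling fresh noise each step appears to be deliberate: it keeps the preserved region consistent across iterations while still matching the noise schedule of the region being inpainted.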