Skip to content

Commit

Permalink
Add noise to unmasked latents
Browse files Browse the repository at this point in the history
  • Loading branch information
andreasjansson committed Aug 25, 2022
1 parent 17f55f7 commit ce6b29d
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 3 deletions.
18 changes: 16 additions & 2 deletions image_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
"""
From https://github.com/huggingface/diffusers/pull/241
"""

def __init__(
self,
vae: AutoencoderKL,
Expand Down Expand Up @@ -111,7 +112,12 @@ def __call__(

if init_image is not None:
init_latents_orig, latents, init_timestep = self.latents_from_init_image(
init_image, prompt_strength, offset, num_inference_steps, batch_size, generator
init_image,
prompt_strength,
offset,
num_inference_steps,
batch_size,
generator,
)
else:
latents = torch.randn(
Expand All @@ -137,6 +143,8 @@ def __call__(
if accepts_eta:
extra_step_kwargs["eta"] = eta

mask_noise = torch.randn(latents.shape, generator=generator, device=self.device)

t_start = max(num_inference_steps - init_timestep + offset, 0)
for i, t in tqdm(enumerate(self.scheduler.timesteps[t_start:])):
# expand the latents if we are doing classifier free guidance
Expand All @@ -161,8 +169,14 @@ def __call__(
"prev_sample"
]

# replace the unmasked part with original latents, with added noise
if mask is not None:
latents = init_latents_orig * mask + latents * (1 - mask)
timesteps = self.scheduler.timesteps[t_start + i]
timesteps = torch.tensor(
[timesteps] * batch_size, dtype=torch.long, device=self.device
)
noisy_init_latents = self.scheduler.add_noise(init_latents_orig, mask_noise, timesteps)
latents = noisy_init_latents * mask + latents * (1 - mask)

# scale and decode the image latents with vae
latents = 1 / 0.18215 * latents
Expand Down
2 changes: 1 addition & 1 deletion predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def predict(
description="Inital image to generate variations of", default=None
),
mask: Path = Input(
description="Black and white image to use as mask for inpainting over init_image. Black pixels are inpainted and white pixels are preserved",
description="Black and white image to use as mask for inpainting over init_image. Black pixels are inpainted and white pixels are preserved. Experimental feature, tends to work better with prompt strength of 0.5-0.7",
default=None,
),
prompt_strength: float = Input(
Expand Down

0 comments on commit ce6b29d

Please sign in to comment.