Commit 7068ff5

rename file
gasvn committed Jun 14, 2023
1 parent edf6abf commit 7068ff5
Showing 3 changed files with 187 additions and 187 deletions.
2 changes: 1 addition & 1 deletion sam2edit_lora.py
@@ -21,7 +21,7 @@
from collections import defaultdict
from diffusers import StableDiffusionControlNetPipeline
from diffusers import ControlNetModel, UniPCMultistepScheduler
-from utils.stable_diffusion_controlnet_inpaint_ref import (
+from utils.stable_diffusion_controlnet_inpaint import (
StableDiffusionControlNetInpaintPipeline,
)

186 changes: 178 additions & 8 deletions utils/stable_diffusion_controlnet_inpaint.py
@@ -31,7 +31,8 @@
randn_tensor,
replace_example_docstring,
)
-from diffusers.loaders import LoraLoaderMixin
+from diffusers.loaders import LoraLoaderMixin, TextualInversionLoaderMixin
+from utils.stable_diffusion_reference import StableDiffusionReferencePipeline

logger = logging.get_logger(__name__) # pylint: disable=invalid-name

@@ -257,7 +258,12 @@ def prepare_controlnet_conditioning_image(
return controlnet_conditioning_image


-class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline, LoraLoaderMixin):
+class StableDiffusionControlNetInpaintPipeline(
+    DiffusionPipeline,
+    LoraLoaderMixin,
+    StableDiffusionReferencePipeline,
+    TextualInversionLoaderMixin,
+):
"""
Inspired by: https://github.com/haofanwang/ControlNet-for-Diffusers/
"""
@@ -1009,6 +1015,25 @@ def __call__(
controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
alignment_ratio=None,
guess_mode: bool = False,
ref_image: Union[
torch.FloatTensor,
PIL.Image.Image,
List[torch.FloatTensor],
List[PIL.Image.Image],
] = None,
ref_mask: Union[
torch.FloatTensor,
PIL.Image.Image,
List[torch.FloatTensor],
List[PIL.Image.Image],
] = None,
ref_controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
ref_prompt: Union[str, List[str]] = None,
attention_auto_machine_weight: float = 1.0,
gn_auto_machine_weight: float = 1.0,
style_fidelity: float = 0.5,
reference_attn: bool = True,
reference_adain: bool = True,
):
r"""
Function invoked when calling the pipeline for generation.
@@ -1086,6 +1111,22 @@ def __call__(
guess_mode (`bool`, *optional*, defaults to `False`):
In this mode, the ControlNet encoder will try its best to recognize the content of the input image even if
you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
ref_image (`torch.FloatTensor`, `PIL.Image.Image`):
The Reference Control input condition. Reference Control uses this input condition to generate guidance for
the UNet. If the type is `torch.FloatTensor`, it is passed to Reference Control as is; a `PIL.Image.Image`
is also accepted.
attention_auto_machine_weight (`float`):
Weight for using the reference query in the self-attention context.
If attention_auto_machine_weight=1.0, the reference query is used in every self-attention layer's context.
gn_auto_machine_weight (`float`):
Weight for using reference AdaIN. If gn_auto_machine_weight=2.0, all reference AdaIN plugins are used.
style_fidelity (`float`):
Style fidelity of ref_uncond_xt. If style_fidelity=1.0, the control (reference) dominates;
if style_fidelity=0.0, the prompt dominates; values in between balance the two.
reference_attn (`bool`):
Whether to use the reference query in the self-attention context.
reference_adain (`bool`):
Whether to use reference AdaIN.
Examples:
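A minimal call sketch for the new reference-only arguments (the checkpoint names, file paths, and the choice of a segmentation ControlNet are illustrative assumptions, not taken from this commit):

import torch
from PIL import Image
from diffusers import ControlNetModel, UniPCMultistepScheduler
from utils.stable_diffusion_controlnet_inpaint import StableDiffusionControlNetInpaintPipeline

controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-seg", torch_dtype=torch.float16)
pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting", controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)

result = pipe(
    prompt="a sofa in a living room",
    image=Image.open("scene.png"),                        # image to edit
    mask_image=Image.open("mask.png"),                    # white = region to repaint
    controlnet_conditioning_image=Image.open("seg.png"),  # per-ControlNet condition
    ref_image=Image.open("reference.png"),                # appearance reference
    ref_prompt="a sofa",
    style_fidelity=0.5,
    reference_attn=True,
    reference_adain=True,
).images[0]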
@@ -1115,6 +1156,14 @@ def __call__(
negative_prompt_embeds,
controlnet_conditioning_scale,
)
if ref_image is not None: # for ref_only mode
self.check_ref_input(reference_attn, reference_adain)
if ref_mask is not None:
ref_mask = prepare_mask_image(ref_mask)
ref_mask = F.interpolate(
ref_mask,
size=(height // self.vae_scale_factor, width // self.vae_scale_factor),
)
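For scale: with the usual vae_scale_factor of 8, this interpolation pools a pixel-space mask down to the latent grid. A standalone sketch:

import torch
import torch.nn.functional as F

ref_mask = torch.ones(1, 1, 512, 512)                    # mask at pixel resolution
latent_mask = F.interpolate(ref_mask, size=(512 // 8, 512 // 8))
print(latent_mask.shape)                                 # torch.Size([1, 1, 64, 64])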

# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
@@ -1147,8 +1196,17 @@ def __call__(
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
)
if ref_image is not None:
ref_prompt_embeds = self._encode_prompt(
ref_prompt,
device,
num_images_per_prompt * 2,
do_classifier_free_guidance,
negative_prompt="longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
prompt_embeds=None,
)
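The factor of 2 in num_images_per_prompt * 2 reads as a batch-alignment trick (our assumption, not documented in the diff): _encode_prompt stacks negative and positive embeddings when classifier-free guidance is on, and the reference latents prepared below are likewise duplicated for CFG, so doubling here keeps the reference UNet pass's text embeddings and latents on matching batch sizes.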

-# 4. Prepare mask, image, and controlnet_conditioning_image
+# 4. Prepare mask, image, and controlnet_conditioning_image + ref_img
image = prepare_image(image)

mask_image = prepare_mask_image(mask_image)
@@ -1187,6 +1245,37 @@ def __call__(

masked_image = image * (mask_image < 0.5)

if ref_image is not None: # for ref_only mode
# Preprocess reference image
# from controlnet_aux import LineartDetector
# processor = LineartDetector.from_pretrained("lllyasviel/Annotators")
ref_ori = ref_image
ref_image = self.prepare_ref_image(
image=ref_image,
width=width,
height=height,
batch_size=batch_size * num_images_per_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=prompt_embeds.dtype,
)

ref_control_image = prepare_controlnet_conditioning_image(
controlnet_conditioning_image=ref_ori,
width=width,
height=height,
batch_size=batch_size * num_images_per_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=self.controlnet.dtype,
do_classifier_free_guidance=do_classifier_free_guidance,
)
ref_controlnet_conditioning_image = controlnet_conditioning_image.copy()
ref_controlnet_conditioning_image[-1] = ref_control_image
# ref_controlnet_conditioning_scale = controlnet_conditioning_scale.copy()
# ref_controlnet_conditioning_scale[0] = 1.0 # disable the first sam controlnet
# ref_controlnet_conditioning_scale[-1] = 0.2
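The .copy() above implies controlnet_conditioning_image is a list in the multi-ControlNet case, one entry per ControlNet, with only the last slot swapped for the reference pass. A toy restatement (the tensors are hypothetical stand-ins):

import torch

sam_cond, aux_cond, ref_control_image = (torch.zeros(1, 3, 512, 512) for _ in range(3))
conds = [sam_cond, aux_cond]          # one conditioning image per ControlNet
ref_conds = conds.copy()              # shallow copy, as in the diff
ref_conds[-1] = ref_control_image     # only the last ControlNet sees the reference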

# 5. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
@@ -1206,7 +1295,7 @@

noise = latents

-if self.unet.config.in_channels != 4:
+if self.unet.config.in_channels != 4: # inpainting base model
mask_image_latents = self.prepare_mask_latents(
mask_image,
batch_size * num_images_per_prompt,
@@ -1227,7 +1316,7 @@
generator,
do_classifier_free_guidance,
)
-if self.unet.config.in_channels == 4:
+if self.unet.config.in_channels == 4: # non-inpainting base model
init_masked_image_latents = self.prepare_masked_image_latents(
image,
batch_size * num_images_per_prompt,
@@ -1248,9 +1337,42 @@
mask_image = mask_image.to(latents.device).type_as(latents)
mask_image = 1 - mask_image

if ref_image is not None: # for ref_only mode
ref_image_latents = self.prepare_ref_latents(
ref_image,
batch_size * num_images_per_prompt,
prompt_embeds.dtype,
device,
generator,
do_classifier_free_guidance,
)

# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

if ref_image is not None: # for ref_only mode
# Modify self attention and group norm
self.uc_mask = (
torch.Tensor(
[1] * batch_size * num_images_per_prompt
+ [0] * batch_size * num_images_per_prompt
)
.type_as(ref_image_latents)
.bool()
)
self.attention_auto_machine_weight = attention_auto_machine_weight
self.gn_auto_machine_weight = gn_auto_machine_weight
self.do_classifier_free_guidance = do_classifier_free_guidance
self.style_fidelity = style_fidelity
self.ref_mask = ref_mask
attn_modules, gn_modules = self.redefine_ref_model(
self.unet, reference_attn, reference_adain, model_type="unet"
)

control_attn_modules, control_gn_modules = self.redefine_ref_model(
self.controlnet, reference_attn, False, model_type="controlnet"
)
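redefine_ref_model and change_module_mode come from utils/stable_diffusion_reference.py; the sketch below is a conceptual reconstruction of the reference-only mechanism (an assumption about that file, not its actual code): patched self-attention modules cache hidden states during a "write" pass over the reference latents, then splice the cache into the attention context during the "read" (denoising) pass, with uc_mask and style_fidelity deciding how strongly the unconditional rows follow the reference.

import torch

class RefAttnPatch:
    """Conceptual stand-in for a patched self-attention forward."""

    def __init__(self, attn):
        self.attn = attn    # the original attention module
        self.bank = []      # cached reference hidden states
        self.mode = "read"

    def __call__(self, hidden_states, encoder_hidden_states=None, **kwargs):
        if self.mode == "write":
            self.bank.append(hidden_states)  # stash features from the reference pass
        elif self.bank:
            # read: attend over [own states; cached reference states]
            encoder_hidden_states = torch.cat([hidden_states, *self.bank], dim=1)
            self.bank = []
        return self.attn(hidden_states, encoder_hidden_states, **kwargs)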

# 8. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar:
@@ -1263,7 +1385,7 @@
non_inpainting_latent_model_input = self.scheduler.scale_model_input(
non_inpainting_latent_model_input, t
)
-if self.unet.config.in_channels != 4:
+if self.unet.config.in_channels != 4: # inpainting base model
inpainting_latent_model_input = torch.cat(
[
non_inpainting_latent_model_input,
@@ -1275,6 +1397,55 @@
else:
inpainting_latent_model_input = non_inpainting_latent_model_input

if ref_image is not None: # for ref_only mode
# ref only part
noise = randn_tensor(
ref_image_latents.shape,
generator=generator,
device=ref_image_latents.device,
dtype=ref_image_latents.dtype,
)
ref_xt = self.scheduler.add_noise(
ref_image_latents,
noise,
t.reshape(
1,
),
)
ref_xt = self.scheduler.scale_model_input(ref_xt, t)

MODE = "write"
self.change_module_mode(
MODE, control_attn_modules, control_gn_modules
)

(
ref_down_block_res_samples,
ref_mid_block_res_sample,
) = self.controlnet(
ref_xt,
t,
encoder_hidden_states=ref_prompt_embeds,
controlnet_cond=ref_controlnet_conditioning_image,
conditioning_scale=ref_controlnet_conditioning_scale,
guess_mode=guess_mode,
return_dict=False,
)

self.change_module_mode(MODE, attn_modules, gn_modules)
self.unet(
ref_xt,
t,
encoder_hidden_states=ref_prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs,
down_block_additional_residuals=ref_down_block_res_samples,
mid_block_additional_residual=ref_mid_block_res_sample,
return_dict=False,
)
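Summarizing the per-step flow (a restatement of the surrounding code, not new behavior): the clean reference latents are freshly noised to the current timestep, ControlNet and UNet run over them in "write" mode purely to populate the attention/AdaIN caches (the write-pass UNet output is discarded), and the switch to "read" mode just below lets the real noise prediction attend to those caches.

# per step i, timestep t:
#   ref_xt = add_noise(ref_image_latents, fresh_noise, t)
#   "write": controlnet/unet on ref_xt -> fill attention/AdaIN banks, output ignored
#   "read":  controlnet/unet on the working latents -> noise_pred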

# predict the noise residual
MODE = "read" # change to read mode for following noise_pred
self.change_module_mode(MODE, attn_modules, gn_modules)
down_block_res_samples, mid_block_res_sample = self.controlnet(
non_inpainting_latent_model_input,
t,
@@ -1284,7 +1455,6 @@ def __call__(
guess_mode=guess_mode,
return_dict=False,
)

# predict the noise residual
noise_pred = self.unet(
inpainting_latent_model_input,
@@ -1320,7 +1490,7 @@ def __call__(
# print(i, len(timesteps))
# masking for non-inpainting models
init_latents_proper = self.scheduler.add_noise(
-init_masked_image_latents, noise, t
+init_masked_image_latents, noise, timesteps[i + 1]
)
latents = (init_latents_proper * mask_image) + (
latents * (1 - mask_image)
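The single functional change in this last hunk deserves a note: after scheduler.step at loop index i, latents already sit at noise level timesteps[i + 1], so the preserved region must be re-noised to that same level before blending; the old code used t (the current step) and over-noised it. Since mask_image was inverted earlier (1 = region to keep), the blend is effectively:

# latents = add_noise(known_latents, timesteps[i + 1]) * keep_mask + latents * (1 - keep_mask)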