From 7068ff5b16fa5502f82b3e1c2a36e77ee0a4bed3 Mon Sep 17 00:00:00 2001
From: shgao
Date: Wed, 14 Jun 2023 12:57:47 +0800
Subject: [PATCH] rename file

---
 sam2edit_lora.py                              |   2 +-
 utils/stable_diffusion_controlnet_inpaint.py  | 186 +++++++++++++++++-
 ...stable_diffusion_controlnet_inpaint_v1.py} | 186 +-----------------
 3 files changed, 187 insertions(+), 187 deletions(-)
 rename utils/{stable_diffusion_controlnet_inpaint_ref.py => stable_diffusion_controlnet_inpaint_v1.py} (87%)

diff --git a/sam2edit_lora.py b/sam2edit_lora.py
index 0c275dd..f59233e 100644
--- a/sam2edit_lora.py
+++ b/sam2edit_lora.py
@@ -21,7 +21,7 @@ from collections import defaultdict
 from diffusers import StableDiffusionControlNetPipeline
 from diffusers import ControlNetModel, UniPCMultistepScheduler
 
-from utils.stable_diffusion_controlnet_inpaint_ref import (
+from utils.stable_diffusion_controlnet_inpaint import (
     StableDiffusionControlNetInpaintPipeline,
 )
 
diff --git a/utils/stable_diffusion_controlnet_inpaint.py b/utils/stable_diffusion_controlnet_inpaint.py
index 2e5a906..43bd291 100644
--- a/utils/stable_diffusion_controlnet_inpaint.py
+++ b/utils/stable_diffusion_controlnet_inpaint.py
@@ -31,7 +31,8 @@
     randn_tensor,
     replace_example_docstring,
 )
-from diffusers.loaders import LoraLoaderMixin
+from diffusers.loaders import LoraLoaderMixin, TextualInversionLoaderMixin
+from utils.stable_diffusion_reference import StableDiffusionReferencePipeline
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
@@ -257,7 +258,12 @@ def prepare_controlnet_conditioning_image(
     return controlnet_conditioning_image
 
 
-class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline, LoraLoaderMixin):
+class StableDiffusionControlNetInpaintPipeline(
+    DiffusionPipeline,
+    LoraLoaderMixin,
+    StableDiffusionReferencePipeline,
+    TextualInversionLoaderMixin,
+):
     """
     Inspired by: https://github.com/haofanwang/ControlNet-for-Diffusers/
     """
@@ -1009,6 +1015,25 @@ def __call__(
         controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
         alignment_ratio=None,
         guess_mode: bool = False,
+        ref_image: Union[
+            torch.FloatTensor,
+            PIL.Image.Image,
+            List[torch.FloatTensor],
+            List[PIL.Image.Image],
+        ] = None,
+        ref_mask: Union[
+            torch.FloatTensor,
+            PIL.Image.Image,
+            List[torch.FloatTensor],
+            List[PIL.Image.Image],
+        ] = None,
+        ref_controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
+        ref_prompt: Union[str, List[str]] = None,
+        attention_auto_machine_weight: float = 1.0,
+        gn_auto_machine_weight: float = 1.0,
+        style_fidelity: float = 0.5,
+        reference_attn: bool = True,
+        reference_adain: bool = True,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
@@ -1086,6 +1111,22 @@ def __call__(
             guess_mode (`bool`, *optional*, defaults to `False`):
                 In this mode, the ControlNet encoder will try best to recognize the content of the input image even if
                 you remove all prompts. The `guidance_scale` between 3.0 and 5.0 is recommended.
+            ref_image (`torch.FloatTensor`, `PIL.Image.Image`):
+                The Reference Control input condition. Reference Control uses this input condition to generate guidance to Unet. If
+                the type is specified as `torch.FloatTensor`, it is passed to Reference Control as is. `PIL.Image.Image` can
+                also be accepted as an image.
+            attention_auto_machine_weight (`float`):
+                Weight of using reference query for self attention's context.
+                If attention_auto_machine_weight=1.0, use reference query for all self attention's context.
+            gn_auto_machine_weight (`float`):
+                Weight of using reference adain. If gn_auto_machine_weight=2.0, use all reference adain plugins.
+            style_fidelity (`float`):
+                Style fidelity of ref_uncond_xt. If style_fidelity=1.0, the control is more important;
+                if style_fidelity=0.0, the prompt is more important; otherwise the two are balanced.
+            reference_attn (`bool`):
+                Whether to use reference query for self attention's context.
+            reference_adain (`bool`):
+                Whether to use reference adain.
 
         Examples:
 
@@ -1115,6 +1156,14 @@ def __call__(
             negative_prompt_embeds,
             controlnet_conditioning_scale,
         )
+        if ref_image is not None:  # for ref_only mode
+            self.check_ref_input(reference_attn, reference_adain)
+        if ref_mask is not None:
+            ref_mask = prepare_mask_image(ref_mask)
+            ref_mask = F.interpolate(
+                ref_mask,
+                size=(height // self.vae_scale_factor, width // self.vae_scale_factor),
+            )
 
         # 2. Define call parameters
         if prompt is not None and isinstance(prompt, str):
@@ -1147,8 +1196,17 @@ def __call__(
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
         )
+        if ref_image is not None:
+            ref_prompt_embeds = self._encode_prompt(
+                ref_prompt,
+                device,
+                num_images_per_prompt * 2,
+                do_classifier_free_guidance,
+                negative_prompt="longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
+                prompt_embeds=None,
+            )
 
-        # 4. Prepare mask, image, and controlnet_conditioning_image
+        # 4. Prepare mask, image, and controlnet_conditioning_image + ref_img
         image = prepare_image(image)
 
         mask_image = prepare_mask_image(mask_image)
@@ -1187,6 +1245,37 @@ def __call__(
 
         masked_image = image * (mask_image < 0.5)
 
+        if ref_image is not None:  # for ref_only mode
+            # Preprocess reference image
+            # from controlnet_aux import LineartDetector
+            # processor = LineartDetector.from_pretrained("lllyasviel/Annotators")
+            ref_ori = ref_image
+            ref_image = self.prepare_ref_image(
+                image=ref_image,
+                width=width,
+                height=height,
+                batch_size=batch_size * num_images_per_prompt,
+                num_images_per_prompt=num_images_per_prompt,
+                device=device,
+                dtype=prompt_embeds.dtype,
+            )
+
+            ref_control_image = prepare_controlnet_conditioning_image(
+                controlnet_conditioning_image=ref_ori,
+                width=width,
+                height=height,
+                batch_size=batch_size * num_images_per_prompt,
+                num_images_per_prompt=num_images_per_prompt,
+                device=device,
+                dtype=self.controlnet.dtype,
+                do_classifier_free_guidance=do_classifier_free_guidance,
+            )
+            ref_controlnet_conditioning_image = controlnet_conditioning_image.copy()
+            ref_controlnet_conditioning_image[-1] = ref_control_image
+            # ref_controlnet_conditioning_scale = controlnet_conditioning_scale.copy()
+            # ref_controlnet_conditioning_scale[0] = 1.0  # disable the first sam controlnet
+            # ref_controlnet_conditioning_scale[-1] = 0.2
+
         # 5. Prepare timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
         timesteps = self.scheduler.timesteps
@@ -1206,7 +1295,7 @@ def __call__(
 
         noise = latents
 
-        if self.unet.config.in_channels != 4:
+        if self.unet.config.in_channels != 4:  # inpainting base model
             mask_image_latents = self.prepare_mask_latents(
                 mask_image,
                 batch_size * num_images_per_prompt,
@@ -1227,7 +1316,7 @@ def __call__(
             generator,
             do_classifier_free_guidance,
         )
-        if self.unet.config.in_channels == 4:
+        if self.unet.config.in_channels == 4:  # non-inpainting base model
            init_masked_image_latents = self.prepare_masked_image_latents(
                 image,
                 batch_size * num_images_per_prompt,
@@ -1248,9 +1337,42 @@ def __call__(
             mask_image = mask_image.to(latents.device).type_as(latents)
             mask_image = 1 - mask_image
 
+        if ref_image is not None:  # for ref_only mode
+            ref_image_latents = self.prepare_ref_latents(
+                ref_image,
+                batch_size * num_images_per_prompt,
+                prompt_embeds.dtype,
+                device,
+                generator,
+                do_classifier_free_guidance,
+            )
+
         # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
 
+        if ref_image is not None:  # for ref_only mode
+            # Modify self attention and group norm
+            self.uc_mask = (
+                torch.Tensor(
+                    [1] * batch_size * num_images_per_prompt
+                    + [0] * batch_size * num_images_per_prompt
+                )
+                .type_as(ref_image_latents)
+                .bool()
+            )
+            self.attention_auto_machine_weight = attention_auto_machine_weight
+            self.gn_auto_machine_weight = gn_auto_machine_weight
+            self.do_classifier_free_guidance = do_classifier_free_guidance
+            self.style_fidelity = style_fidelity
+            self.ref_mask = ref_mask
+            attn_modules, gn_modules = self.redefine_ref_model(
+                self.unet, reference_attn, reference_adain, model_type="unet"
+            )
+
+            control_attn_modules, control_gn_modules = self.redefine_ref_model(
+                self.controlnet, reference_attn, False, model_type="controlnet"
+            )
+
         # 8. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         with self.progress_bar(total=num_inference_steps) as progress_bar:
@@ -1263,7 +1385,7 @@ def __call__(
                 non_inpainting_latent_model_input = self.scheduler.scale_model_input(
                     non_inpainting_latent_model_input, t
                 )
-                if self.unet.config.in_channels != 4:
+                if self.unet.config.in_channels != 4:  # inpainting base model
                     inpainting_latent_model_input = torch.cat(
                         [
                             non_inpainting_latent_model_input,
@@ -1275,6 +1397,55 @@ def __call__(
                 else:
                     inpainting_latent_model_input = non_inpainting_latent_model_input
 
+                if ref_image is not None:  # for ref_only mode
+                    # ref only part
+                    noise = randn_tensor(
+                        ref_image_latents.shape,
+                        generator=generator,
+                        device=ref_image_latents.device,
+                        dtype=ref_image_latents.dtype,
+                    )
+                    ref_xt = self.scheduler.add_noise(
+                        ref_image_latents,
+                        noise,
+                        t.reshape(
+                            1,
+                        ),
+                    )
+                    ref_xt = self.scheduler.scale_model_input(ref_xt, t)
+
+                    MODE = "write"
+                    self.change_module_mode(
+                        MODE, control_attn_modules, control_gn_modules
+                    )
+
+                    (
+                        ref_down_block_res_samples,
+                        ref_mid_block_res_sample,
+                    ) = self.controlnet(
+                        ref_xt,
+                        t,
+                        encoder_hidden_states=ref_prompt_embeds,
+                        controlnet_cond=ref_controlnet_conditioning_image,
+                        conditioning_scale=ref_controlnet_conditioning_scale,
+                        guess_mode=guess_mode,
+                        return_dict=False,
+                    )
+
+                    self.change_module_mode(MODE, attn_modules, gn_modules)
+                    self.unet(
+                        ref_xt,
+                        t,
+                        encoder_hidden_states=ref_prompt_embeds,
+                        cross_attention_kwargs=cross_attention_kwargs,
+                        down_block_additional_residuals=ref_down_block_res_samples,
+                        mid_block_additional_residual=ref_mid_block_res_sample,
+                        return_dict=False,
+                    )
+
+                # predict the noise residual
+                MODE = "read"  # change to read mode for following noise_pred
+                self.change_module_mode(MODE, attn_modules, gn_modules)
                 down_block_res_samples, mid_block_res_sample = self.controlnet(
                     non_inpainting_latent_model_input,
                     t,
@@ -1284,7 +1455,6 @@ def __call__(
                     guess_mode=guess_mode,
                     return_dict=False,
                 )
 
-                # predict the noise residual
                 noise_pred = self.unet(
                     inpainting_latent_model_input,
@@ -1320,7 +1490,7 @@ def __call__(
                     # print(i, len(timesteps))
                     # masking for non-inpainting models
                     init_latents_proper = self.scheduler.add_noise(
-                        init_masked_image_latents, noise, t
+                        init_masked_image_latents, noise, timesteps[i + 1]
                     )
                     latents = (init_latents_proper * mask_image) + (
                         latents * (1 - mask_image)

diff --git a/utils/stable_diffusion_controlnet_inpaint_ref.py b/utils/stable_diffusion_controlnet_inpaint_v1.py
similarity index 87%
rename from utils/stable_diffusion_controlnet_inpaint_ref.py
rename to utils/stable_diffusion_controlnet_inpaint_v1.py
index 43bd291..2e5a906 100644
--- a/utils/stable_diffusion_controlnet_inpaint_ref.py
+++ b/utils/stable_diffusion_controlnet_inpaint_v1.py
@@ -31,8 +31,7 @@
     randn_tensor,
     replace_example_docstring,
 )
-from diffusers.loaders import LoraLoaderMixin, TextualInversionLoaderMixin
-from utils.stable_diffusion_reference import StableDiffusionReferencePipeline
+from diffusers.loaders import LoraLoaderMixin
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
@@ -258,12 +257,7 @@ def prepare_controlnet_conditioning_image(
     return controlnet_conditioning_image
 
 
-class StableDiffusionControlNetInpaintPipeline(
-    DiffusionPipeline,
-    LoraLoaderMixin,
-    StableDiffusionReferencePipeline,
-    TextualInversionLoaderMixin,
-):
+class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline, LoraLoaderMixin):
     """
     Inspired by: https://github.com/haofanwang/ControlNet-for-Diffusers/
     """
@@ -1015,25 +1009,6 @@ def __call__(
         controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
         alignment_ratio=None,
         guess_mode: bool = False,
-        ref_image: Union[
-            torch.FloatTensor,
-            PIL.Image.Image,
-            List[torch.FloatTensor],
-            List[PIL.Image.Image],
-        ] = None,
-        ref_mask: Union[
-            torch.FloatTensor,
-            PIL.Image.Image,
-            List[torch.FloatTensor],
-            List[PIL.Image.Image],
-        ] = None,
-        ref_controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
-        ref_prompt: Union[str, List[str]] = None,
-        attention_auto_machine_weight: float = 1.0,
-        gn_auto_machine_weight: float = 1.0,
-        style_fidelity: float = 0.5,
-        reference_attn: bool = True,
-        reference_adain: bool = True,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
@@ -1111,22 +1086,6 @@ def __call__(
             guess_mode (`bool`, *optional*, defaults to `False`):
                 In this mode, the ControlNet encoder will try best to recognize the content of the input image even if
                 you remove all prompts. The `guidance_scale` between 3.0 and 5.0 is recommended.
-            ref_image (`torch.FloatTensor`, `PIL.Image.Image`):
-                The Reference Control input condition. Reference Control uses this input condition to generate guidance to Unet. If
-                the type is specified as `torch.FloatTensor`, it is passed to Reference Control as is. `PIL.Image.Image` can
-                also be accepted as an image.
-            attention_auto_machine_weight (`float`):
-                Weight of using reference query for self attention's context.
-                If attention_auto_machine_weight=1.0, use reference query for all self attention's context.
-            gn_auto_machine_weight (`float`):
-                Weight of using reference adain. If gn_auto_machine_weight=2.0, use all reference adain plugins.
-            style_fidelity (`float`):
-                Style fidelity of ref_uncond_xt. If style_fidelity=1.0, the control is more important;
-                if style_fidelity=0.0, the prompt is more important; otherwise the two are balanced.
-            reference_attn (`bool`):
-                Whether to use reference query for self attention's context.
-            reference_adain (`bool`):
-                Whether to use reference adain.
 
         Examples:
 
@@ -1156,14 +1115,6 @@ def __call__(
             negative_prompt_embeds,
             controlnet_conditioning_scale,
         )
-        if ref_image is not None:  # for ref_only mode
-            self.check_ref_input(reference_attn, reference_adain)
-        if ref_mask is not None:
-            ref_mask = prepare_mask_image(ref_mask)
-            ref_mask = F.interpolate(
-                ref_mask,
-                size=(height // self.vae_scale_factor, width // self.vae_scale_factor),
-            )
 
         # 2. Define call parameters
         if prompt is not None and isinstance(prompt, str):
@@ -1196,17 +1147,8 @@ def __call__(
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
         )
-        if ref_image is not None:
-            ref_prompt_embeds = self._encode_prompt(
-                ref_prompt,
-                device,
-                num_images_per_prompt * 2,
-                do_classifier_free_guidance,
-                negative_prompt="longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
-                prompt_embeds=None,
-            )
 
-        # 4. Prepare mask, image, and controlnet_conditioning_image + ref_img
+        # 4. Prepare mask, image, and controlnet_conditioning_image
         image = prepare_image(image)
 
         mask_image = prepare_mask_image(mask_image)
@@ -1245,37 +1187,6 @@ def __call__(
 
         masked_image = image * (mask_image < 0.5)
 
-        if ref_image is not None:  # for ref_only mode
-            # Preprocess reference image
-            # from controlnet_aux import LineartDetector
-            # processor = LineartDetector.from_pretrained("lllyasviel/Annotators")
-            ref_ori = ref_image
-            ref_image = self.prepare_ref_image(
-                image=ref_image,
-                width=width,
-                height=height,
-                batch_size=batch_size * num_images_per_prompt,
-                num_images_per_prompt=num_images_per_prompt,
-                device=device,
-                dtype=prompt_embeds.dtype,
-            )
-
-            ref_control_image = prepare_controlnet_conditioning_image(
-                controlnet_conditioning_image=ref_ori,
-                width=width,
-                height=height,
-                batch_size=batch_size * num_images_per_prompt,
-                num_images_per_prompt=num_images_per_prompt,
-                device=device,
-                dtype=self.controlnet.dtype,
-                do_classifier_free_guidance=do_classifier_free_guidance,
-            )
-            ref_controlnet_conditioning_image = controlnet_conditioning_image.copy()
-            ref_controlnet_conditioning_image[-1] = ref_control_image
-            # ref_controlnet_conditioning_scale = controlnet_conditioning_scale.copy()
-            # ref_controlnet_conditioning_scale[0] = 1.0  # disable the first sam controlnet
-            # ref_controlnet_conditioning_scale[-1] = 0.2
-
         # 5. Prepare timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
         timesteps = self.scheduler.timesteps
@@ -1295,7 +1206,7 @@ def __call__(
 
         noise = latents
 
-        if self.unet.config.in_channels != 4:  # inpainting base model
+        if self.unet.config.in_channels != 4:
             mask_image_latents = self.prepare_mask_latents(
                 mask_image,
                 batch_size * num_images_per_prompt,
@@ -1316,7 +1227,7 @@ def __call__(
             generator,
             do_classifier_free_guidance,
         )
-        if self.unet.config.in_channels == 4:  # non-inpainting base model
+        if self.unet.config.in_channels == 4:
             init_masked_image_latents = self.prepare_masked_image_latents(
                 image,
                 batch_size * num_images_per_prompt,
@@ -1337,42 +1248,9 @@ def __call__(
             mask_image = mask_image.to(latents.device).type_as(latents)
             mask_image = 1 - mask_image
 
-        if ref_image is not None:  # for ref_only mode
-            ref_image_latents = self.prepare_ref_latents(
-                ref_image,
-                batch_size * num_images_per_prompt,
-                prompt_embeds.dtype,
-                device,
-                generator,
-                do_classifier_free_guidance,
-            )
-
         # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
 
-        if ref_image is not None:  # for ref_only mode
-            # Modify self attention and group norm
-            self.uc_mask = (
-                torch.Tensor(
-                    [1] * batch_size * num_images_per_prompt
-                    + [0] * batch_size * num_images_per_prompt
-                )
-                .type_as(ref_image_latents)
-                .bool()
-            )
-            self.attention_auto_machine_weight = attention_auto_machine_weight
-            self.gn_auto_machine_weight = gn_auto_machine_weight
-            self.do_classifier_free_guidance = do_classifier_free_guidance
-            self.style_fidelity = style_fidelity
-            self.ref_mask = ref_mask
-            attn_modules, gn_modules = self.redefine_ref_model(
-                self.unet, reference_attn, reference_adain, model_type="unet"
-            )
-
-            control_attn_modules, control_gn_modules = self.redefine_ref_model(
-                self.controlnet, reference_attn, False, model_type="controlnet"
-            )
-
         # 8. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         with self.progress_bar(total=num_inference_steps) as progress_bar:
@@ -1385,7 +1263,7 @@ def __call__(
                 non_inpainting_latent_model_input = self.scheduler.scale_model_input(
                     non_inpainting_latent_model_input, t
                 )
-                if self.unet.config.in_channels != 4:  # inpainting base model
+                if self.unet.config.in_channels != 4:
                     inpainting_latent_model_input = torch.cat(
                         [
                             non_inpainting_latent_model_input,
@@ -1397,55 +1275,6 @@ def __call__(
                 else:
                     inpainting_latent_model_input = non_inpainting_latent_model_input
 
-                if ref_image is not None:  # for ref_only mode
-                    # ref only part
-                    noise = randn_tensor(
-                        ref_image_latents.shape,
-                        generator=generator,
-                        device=ref_image_latents.device,
-                        dtype=ref_image_latents.dtype,
-                    )
-                    ref_xt = self.scheduler.add_noise(
-                        ref_image_latents,
-                        noise,
-                        t.reshape(
-                            1,
-                        ),
-                    )
-                    ref_xt = self.scheduler.scale_model_input(ref_xt, t)
-
-                    MODE = "write"
-                    self.change_module_mode(
-                        MODE, control_attn_modules, control_gn_modules
-                    )
-
-                    (
-                        ref_down_block_res_samples,
-                        ref_mid_block_res_sample,
-                    ) = self.controlnet(
-                        ref_xt,
-                        t,
-                        encoder_hidden_states=ref_prompt_embeds,
-                        controlnet_cond=ref_controlnet_conditioning_image,
-                        conditioning_scale=ref_controlnet_conditioning_scale,
-                        guess_mode=guess_mode,
-                        return_dict=False,
-                    )
-
-                    self.change_module_mode(MODE, attn_modules, gn_modules)
-                    self.unet(
-                        ref_xt,
-                        t,
-                        encoder_hidden_states=ref_prompt_embeds,
-                        cross_attention_kwargs=cross_attention_kwargs,
-                        down_block_additional_residuals=ref_down_block_res_samples,
-                        mid_block_additional_residual=ref_mid_block_res_sample,
-                        return_dict=False,
-                    )
-
-                # predict the noise residual
-                MODE = "read"  # change to read mode for following noise_pred
-                self.change_module_mode(MODE, attn_modules, gn_modules)
                 down_block_res_samples, mid_block_res_sample = self.controlnet(
                     non_inpainting_latent_model_input,
                     t,
@@ -1455,6 +1284,7 @@ def __call__(
                     guess_mode=guess_mode,
                     return_dict=False,
                 )
 
+                # predict the noise residual
                 noise_pred = self.unet(
                     inpainting_latent_model_input,
@@ -1490,7 +1320,7 @@ def __call__(
                     # print(i, len(timesteps))
                     # masking for non-inpainting models
                     init_latents_proper = self.scheduler.add_noise(
-                        init_masked_image_latents, noise, timesteps[i + 1]
+                        init_masked_image_latents, noise, t
                     )
                     latents = (init_latents_proper * mask_image) + (
                         latents * (1 - mask_image)
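
Usage sketch (illustration, not part of the patch): after this rename, the reference-only ("ref_only") arguments live on StableDiffusionControlNetInpaintPipeline.__call__ in utils/stable_diffusion_controlnet_inpaint.py. A minimal call under stated assumptions might look like the following; the base-model and ControlNet checkpoints and all file names are placeholders chosen for illustration, and the ref_* keyword arguments simply mirror the signature added above.

    import torch
    from PIL import Image
    from diffusers import ControlNetModel, UniPCMultistepScheduler

    from utils.stable_diffusion_controlnet_inpaint import (
        StableDiffusionControlNetInpaintPipeline,
    )

    # Assumed checkpoints; Edit-Anything normally pairs this pipeline with its
    # own SAM-conditioned ControlNet(s) rather than the canny model used here.
    controlnets = [
        ControlNetModel.from_pretrained(
            "lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16
        ),
    ]
    pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        controlnet=controlnets,
        torch_dtype=torch.float16,
    ).to("cuda")
    pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)

    image = Image.open("input.png").convert("RGB")      # image to inpaint
    mask = Image.open("mask.png").convert("L")          # white = region to repaint
    control = Image.open("control.png").convert("RGB")  # ControlNet conditioning image
    ref = Image.open("reference.png").convert("RGB")    # reference-only guidance image

    result = pipe(
        prompt="a sofa in a bright living room",
        image=image,
        mask_image=mask,
        controlnet_conditioning_image=[control],  # list: one entry per ControlNet
        num_inference_steps=30,
        # Reference-only arguments introduced by this file:
        ref_image=ref,
        ref_prompt="a sofa",
        attention_auto_machine_weight=1.0,
        gn_auto_machine_weight=1.0,
        style_fidelity=0.5,
        reference_attn=True,
        reference_adain=True,
    ).images[0]
    result.save("output.png")

The conditioning images are passed as a list because the ref_only path copies that list and swaps the reference control image into its last slot (ref_controlnet_conditioning_image[-1] = ref_control_image in the diff), i.e. multi-ControlNet style inputs are the shape this code expects.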