Skip to content

Commit

Permalink
[ip-adapter] refactor prepare_ip_adapter_image_embeds and skip load…
Browse files Browse the repository at this point in the history
… `image_encoder` (huggingface#7016)

* add

Co-authored-by: Sayak Paul <[email protected]>

---------

Co-authored-by: yiyixuxu <yixu310@gmail,com>
Co-authored-by: Sayak Paul <[email protected]>
  • Loading branch information
3 people authored Mar 1, 2024
1 parent f4fc750 commit 06b01ea
Show file tree
Hide file tree
Showing 25 changed files with 769 additions and 138 deletions.
33 changes: 33 additions & 0 deletions docs/source/en/using-diffusers/ip_adapter.md
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,39 @@ export_to_gif(frames, "gummy_bear.gif")
> [!TIP]
> While calling `load_ip_adapter()`, pass `low_cpu_mem_usage=True` to speed up the loading time.
All the pipelines supporting IP-Adapter accept a `ip_adapter_image_embeds` argument. If you need to run the IP-Adapter multiple times with the same image, you can encode the image once and save the embedding to the disk.

```py
image_embeds = pipeline.prepare_ip_adapter_image_embeds(
ip_adapter_image=image,
ip_adapter_image_embeds=None,
device="cuda",
num_images_per_prompt=1,
do_classifier_free_guidance=True,
)

torch.save(image_embeds, "image_embeds.ipadpt")
```

Load the image embedding and pass it to the pipeline as `ip_adapter_image_embeds`

> [!TIP]
> ComfyUI image embeddings for IP-Adapters are fully compatible in Diffusers and should work out-of-box.
```py
image_embeds = torch.load("image_embeds.ipadpt")
images = pipeline(
prompt="a polar bear sitting in a chair drinking a milkshake",
ip_adapter_image_embeds=image_embeds,
negative_prompt="deformed, ugly, wrong proportion, low res, bad anatomy, worst quality, low quality",
num_inference_steps=100,
generator=generator,
).images
```

> [!TIP]
> If you use IP-Adapter with `ip_adapter_image_embedding` instead of `ip_adapter_image`, you can choose not to load an image encoder by passing `image_encoder_folder=None` to `load_ip_adapter()`.
## Specific use cases

IP-Adapter's image prompting and compatibility with other adapters and models makes it a versatile tool for a variety of use cases. This section covers some of the more popular applications of IP-Adapter, and we can't wait to see what you come up with!
Expand Down
6 changes: 4 additions & 2 deletions examples/community/pipeline_animatediff_controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -799,8 +799,10 @@ def __call__(
ip_adapter_image (`PipelineImageInput`, *optional*):
Optional image input to work with IP Adapters.
ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
Pre-generated image embeddings for IP-Adapter. If not
provided, embeddings are computed from the `ip_adapter_image` input argument.
Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of IP-adapters.
Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should contain the negative image embedding
if `do_classifier_free_guidance` is set to `True`.
If not provided, embeddings are computed from the `ip_adapter_image` input argument.
conditioning_frames (`List[PipelineImageInput]`, *optional*):
The ControlNet input condition to provide guidance to the `unet` for generation. If multiple ControlNets
are specified, images must be passed as a list such that each element of the list can be correctly
Expand Down
6 changes: 4 additions & 2 deletions examples/community/pipeline_animatediff_img2video.py
Original file line number Diff line number Diff line change
Expand Up @@ -798,8 +798,10 @@ def __call__(
ip_adapter_image: (`PipelineImageInput`, *optional*):
Optional image input to work with IP Adapters.
ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
Pre-generated image embeddings for IP-Adapter. If not
provided, embeddings are computed from the `ip_adapter_image` input argument.
Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of IP-adapters.
Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should contain the negative image embedding
if `do_classifier_free_guidance` is set to `True`.
If not provided, embeddings are computed from the `ip_adapter_image` input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated video. Choose between `torch.FloatTensor`, `PIL.Image` or
`np.array`.
Expand Down
51 changes: 37 additions & 14 deletions src/diffusers/loaders/ip_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

from pathlib import Path
from typing import Dict, List, Union
from typing import Dict, List, Optional, Union

import torch
from huggingface_hub.utils import validate_hf_hub_args
Expand Down Expand Up @@ -52,11 +52,12 @@ def load_ip_adapter(
pretrained_model_name_or_path_or_dict: Union[str, List[str], Dict[str, torch.Tensor]],
subfolder: Union[str, List[str]],
weight_name: Union[str, List[str]],
image_encoder_folder: Optional[str] = "image_encoder",
**kwargs,
):
"""
Parameters:
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
pretrained_model_name_or_path_or_dict (`str` or `List[str]` or `os.PathLike` or `List[os.PathLike]` or `dict` or `List[dict]`):
Can be either:
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
Expand All @@ -65,7 +66,18 @@ def load_ip_adapter(
with [`ModelMixin.save_pretrained`].
- A [torch state
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
subfolder (`str` or `List[str]`):
The subfolder location of a model file within a larger model repository on the Hub or locally.
If a list is passed, it should have the same length as `weight_name`.
weight_name (`str` or `List[str]`):
The name of the weight file to load. If a list is passed, it should have the same length as
`weight_name`.
image_encoder_folder (`str`, *optional*, defaults to `image_encoder`):
The subfolder location of the image encoder within a larger model repository on the Hub or locally.
Pass `None` to not load the image encoder. If the image encoder is located in a folder inside `subfolder`,
you only need to pass the name of the folder that contains image encoder weights, e.g. `image_encoder_folder="image_encoder"`.
If the image encoder is located in a folder other than `subfolder`, you should pass the path to the folder that contains image encoder weights,
for example, `image_encoder_folder="different_subfolder/image_encoder"`.
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
is not used.
Expand All @@ -87,8 +99,6 @@ def load_ip_adapter(
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
allowed by Git.
subfolder (`str`, *optional*, defaults to `""`):
The subfolder location of a model file within a larger model repository on the Hub or locally.
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
Speed up model loading only loading the pretrained weights and not initializing the weights. This also
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
Expand Down Expand Up @@ -184,16 +194,29 @@ def load_ip_adapter(

# load CLIP image encoder here if it has not been registered to the pipeline yet
if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None:
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
logger.info(f"loading image_encoder from {pretrained_model_name_or_path_or_dict}")
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
pretrained_model_name_or_path_or_dict,
subfolder=Path(subfolder, "image_encoder").as_posix(),
low_cpu_mem_usage=low_cpu_mem_usage,
).to(self.device, dtype=self.dtype)
self.register_modules(image_encoder=image_encoder)
if image_encoder_folder is not None:
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
logger.info(f"loading image_encoder from {pretrained_model_name_or_path_or_dict}")
if image_encoder_folder.count("/") == 0:
image_encoder_subfolder = Path(subfolder, image_encoder_folder).as_posix()
else:
image_encoder_subfolder = Path(image_encoder_folder).as_posix()

image_encoder = CLIPVisionModelWithProjection.from_pretrained(
pretrained_model_name_or_path_or_dict,
subfolder=image_encoder_subfolder,
low_cpu_mem_usage=low_cpu_mem_usage,
).to(self.device, dtype=self.dtype)
self.register_modules(image_encoder=image_encoder)
else:
raise ValueError(
"`image_encoder` cannot be loaded because `pretrained_model_name_or_path_or_dict` is a state dict."
)
else:
raise ValueError("`image_encoder` cannot be None when using IP Adapters.")
logger.warning(
"image_encoder is not loaded since `image_encoder_folder=None` passed. You will not be able to use `ip_adapter_image` when calling the pipeline with IP-Adapter."
"Use `ip_adapter_image_embedding` to pass pre-geneated image embedding instead."
)

# create feature extractor if it has not been registered to the pipeline yet
if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is None:
Expand Down
38 changes: 32 additions & 6 deletions src/diffusers/pipelines/animatediff/pipeline_animatediff.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,7 @@ def encode_image(self, image, device, num_images_per_prompt, output_hidden_state

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
def prepare_ip_adapter_image_embeds(
self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
):
if ip_adapter_image_embeds is None:
if not isinstance(ip_adapter_image, list):
Expand All @@ -394,13 +394,23 @@ def prepare_ip_adapter_image_embeds(
[single_negative_image_embeds] * num_images_per_prompt, dim=0
)

if self.do_classifier_free_guidance:
if do_classifier_free_guidance:
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
single_image_embeds = single_image_embeds.to(device)

image_embeds.append(single_image_embeds)
else:
image_embeds = ip_adapter_image_embeds
image_embeds = []
for single_image_embeds in ip_adapter_image_embeds:
if do_classifier_free_guidance:
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
single_negative_image_embeds = single_negative_image_embeds.repeat(num_images_per_prompt, 1, 1)
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
else:
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
image_embeds.append(single_image_embeds)

return image_embeds

# Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents
Expand Down Expand Up @@ -494,6 +504,16 @@ def check_inputs(
"Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
)

if ip_adapter_image_embeds is not None:
if not isinstance(ip_adapter_image_embeds, list):
raise ValueError(
f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
)
elif ip_adapter_image_embeds[0].ndim != 3:
raise ValueError(
f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D"
)

# Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents
def prepare_latents(
self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
Expand Down Expand Up @@ -612,8 +632,10 @@ def __call__(
ip_adapter_image: (`PipelineImageInput`, *optional*):
Optional image input to work with IP Adapters.
ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
Pre-generated image embeddings for IP-Adapter. If not
provided, embeddings are computed from the `ip_adapter_image` input argument.
Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of IP-adapters.
Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should contain the negative image embedding
if `do_classifier_free_guidance` is set to `True`.
If not provided, embeddings are computed from the `ip_adapter_image` input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated video. Choose between `torch.FloatTensor`, `PIL.Image` or
`np.array`.
Expand Down Expand Up @@ -717,7 +739,11 @@ def __call__(

if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
image_embeds = self.prepare_ip_adapter_image_embeds(
ip_adapter_image, ip_adapter_image_embeds, device, batch_size * num_videos_per_prompt
ip_adapter_image,
ip_adapter_image_embeds,
device,
batch_size * num_videos_per_prompt,
self.do_classifier_free_guidance,
)

# 4. Prepare timesteps
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -448,7 +448,7 @@ def encode_image(self, image, device, num_images_per_prompt, output_hidden_state

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
def prepare_ip_adapter_image_embeds(
self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
):
if ip_adapter_image_embeds is None:
if not isinstance(ip_adapter_image, list):
Expand All @@ -472,13 +472,23 @@ def prepare_ip_adapter_image_embeds(
[single_negative_image_embeds] * num_images_per_prompt, dim=0
)

if self.do_classifier_free_guidance:
if do_classifier_free_guidance:
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
single_image_embeds = single_image_embeds.to(device)

image_embeds.append(single_image_embeds)
else:
image_embeds = ip_adapter_image_embeds
image_embeds = []
for single_image_embeds in ip_adapter_image_embeds:
if do_classifier_free_guidance:
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
single_negative_image_embeds = single_negative_image_embeds.repeat(num_images_per_prompt, 1, 1)
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
else:
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
image_embeds.append(single_image_embeds)

return image_embeds

# Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents
Expand Down Expand Up @@ -523,6 +533,8 @@ def check_inputs(
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
ip_adapter_image=None,
ip_adapter_image_embeds=None,
callback_on_step_end_tensor_inputs=None,
):
if strength < 0 or strength > 1:
Expand Down Expand Up @@ -567,6 +579,21 @@ def check_inputs(
if video is not None and latents is not None:
raise ValueError("Only one of `video` or `latents` should be provided")

if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
raise ValueError(
"Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
)

if ip_adapter_image_embeds is not None:
if not isinstance(ip_adapter_image_embeds, list):
raise ValueError(
f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
)
elif ip_adapter_image_embeds[0].ndim != 3:
raise ValueError(
f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D"
)

def get_timesteps(self, num_inference_steps, timesteps, strength, device):
# get the original timestep using init_timestep
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
Expand Down Expand Up @@ -765,8 +792,10 @@ def __call__(
ip_adapter_image: (`PipelineImageInput`, *optional*):
Optional image input to work with IP Adapters.
ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
Pre-generated image embeddings for IP-Adapter. If not
provided, embeddings are computed from the `ip_adapter_image` input argument.
Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of IP-adapters.
Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should contain the negative image embedding
if `do_classifier_free_guidance` is set to `True`.
If not provided, embeddings are computed from the `ip_adapter_image` input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated video. Choose between `torch.FloatTensor`, `PIL.Image` or
`np.array`.
Expand Down Expand Up @@ -814,6 +843,8 @@ def __call__(
negative_prompt_embeds=negative_prompt_embeds,
video=video,
latents=latents,
ip_adapter_image=ip_adapter_image,
ip_adapter_image_embeds=ip_adapter_image_embeds,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
)

Expand Down Expand Up @@ -855,7 +886,11 @@ def __call__(

if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
image_embeds = self.prepare_ip_adapter_image_embeds(
ip_adapter_image, ip_adapter_image_embeds, device, batch_size * num_videos_per_prompt
ip_adapter_image,
ip_adapter_image_embeds,
device,
batch_size * num_videos_per_prompt,
self.do_classifier_free_guidance,
)

# 4. Prepare timesteps
Expand Down
Loading

0 comments on commit 06b01ea

Please sign in to comment.