Model CPU offload fix for BLIPDiffusion (huggingface#5174)
cpu offload fix for blip diffusion
DN6 authored Sep 25, 2023
1 parent 22b19d5 commit 92f15f5
Showing 2 changed files with 29 additions and 12 deletions.
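For context, this change makes the standard sequential CPU offload flow usable on the BLIP Diffusion pipelines. A minimal usage sketch follows; the checkpoint name, reference image path, and exact call arguments are illustrative assumptions, not taken from this commit.

    import torch
    from diffusers import BlipDiffusionPipeline
    from diffusers.utils import load_image

    pipe = BlipDiffusionPipeline.from_pretrained(
        "Salesforce/blipdiffusion", torch_dtype=torch.float16  # assumed checkpoint name
    )
    # Keep the sub-models on the CPU and let offload hooks move each one to the
    # GPU only for its forward pass, instead of calling pipe.to("cuda").
    pipe.enable_model_cpu_offload()

    reference_image = load_image("path/to/reference_dog.png")  # any RGB reference image
    image = pipe(
        prompt="sitting on a park bench",
        reference_image=reference_image,
        source_subject_category="dog",
        target_subject_category="dog",
        guidance_scale=7.5,
    ).images[0]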
21 changes: 15 additions & 6 deletions src/diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py
@@ -98,6 +98,8 @@ class BlipDiffusionPipeline(DiffusionPipeline):
Position of the context token in the text encoder.
"""

model_cpu_offload_seq = "qformer->text_encoder->unet->vae"

def __init__(
self,
tokenizer: CLIPTokenizer,
@@ -155,7 +157,9 @@ def prepare_latents(self, batch_size, num_channels, height, width, dtype, device
latents = latents * self.scheduler.init_noise_sigma
return latents

def encode_prompt(self, query_embeds, prompt):
def encode_prompt(self, query_embeds, prompt, device=None):
device = device or self._execution_device

# embeddings for prompt, with query_embeds as context
max_len = self.text_encoder.text_model.config.max_position_embeddings
max_len -= self.qformer.config.num_query_tokens
@@ -166,7 +170,7 @@ def encode_prompt(self, query_embeds, prompt):
truncation=True,
max_length=max_len,
return_tensors="pt",
).to(self.device)
).to(device)

batch_size = query_embeds.shape[0]
ctx_begin_pos = [self.config.ctx_begin_pos] * batch_size
@@ -249,11 +253,12 @@ def __call__(
Returns:
[`~pipelines.ImagePipelineOutput`] or `tuple`
"""
device = self._execution_device

reference_image = self.image_processor.preprocess(
reference_image, image_mean=self.config.mean, image_std=self.config.std, return_tensors="pt"
)["pixel_values"]
reference_image = reference_image.to(self.device)
reference_image = reference_image.to(device)

if isinstance(prompt, str):
prompt = [prompt]
@@ -271,7 +276,7 @@ def __call__(
prompt_reps=prompt_reps,
)
query_embeds = self.get_query_embeddings(reference_image, source_subject_category)
text_embeddings = self.encode_prompt(query_embeds, prompt)
text_embeddings = self.encode_prompt(query_embeds, prompt, device)
do_classifier_free_guidance = guidance_scale > 1.0
if do_classifier_free_guidance:
max_length = self.text_encoder.text_model.config.max_position_embeddings
@@ -283,7 +288,7 @@ def __call__(
return_tensors="pt",
)
uncond_embeddings = self.text_encoder(
input_ids=uncond_input.input_ids.to(self.device),
input_ids=uncond_input.input_ids.to(device),
ctx_embeddings=None,
)[0]
# For classifier free guidance, we need to do two forward passes.
@@ -300,7 +305,7 @@ def __call__(
generator=generator,
latents=latents,
dtype=self.unet.dtype,
device=self.device,
device=device,
)
# set timesteps
extra_set_kwargs = {}
@@ -330,9 +335,13 @@ def __call__(
t,
latents,
)["prev_sample"]

image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
image = self.image_processor.postprocess(image, output_type=output_type)

# Offload all models
self.maybe_free_model_hooks()

if not return_dict:
return (image,)

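The new model_cpu_offload_seq attribute declares which sub-models the base pipeline should wrap with offload hooks, and in what order. The sketch below is a simplified illustration of that mechanism using accelerate's cpu_offload_with_hook; it is not the actual DiffusionPipeline.enable_model_cpu_offload implementation, and install_offload_hooks is a hypothetical helper name.

    from accelerate import cpu_offload_with_hook

    def install_offload_hooks(pipe, device="cuda"):
        hook = None
        # Walk "qformer->text_encoder->unet->vae" in order, chaining hooks so each
        # module is moved to `device` just before it runs and the previous module
        # in the chain is pushed back to the CPU.
        for name in pipe.model_cpu_offload_seq.split("->"):
            module = getattr(pipe, name)
            _, hook = cpu_offload_with_hook(module, device, prev_module_hook=hook)
        # Calling hook.offload() afterwards returns the last module to the CPU,
        # roughly what the maybe_free_model_hooks() call at the end of __call__
        # in the diff above takes care of.
        return hook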
@@ -107,6 +107,8 @@ class BlipDiffusionControlNetPipeline(DiffusionPipeline):
Position of the context token in the text encoder.
"""

model_cpu_offload_seq = "qformer->text_encoder->unet->vae"

def __init__(
self,
tokenizer: CLIPTokenizer,
@@ -166,7 +168,9 @@ def prepare_latents(self, batch_size, num_channels, height, width, dtype, device
latents = latents * self.scheduler.init_noise_sigma
return latents

def encode_prompt(self, query_embeds, prompt):
def encode_prompt(self, query_embeds, prompt, device=None):
device = device or self._execution_device

# embeddings for prompt, with query_embeds as context
max_len = self.text_encoder.text_model.config.max_position_embeddings
max_len -= self.qformer.config.num_query_tokens
@@ -177,7 +181,7 @@ def encode_prompt(self, query_embeds, prompt):
truncation=True,
max_length=max_len,
return_tensors="pt",
).to(self.device)
).to(device)

batch_size = query_embeds.shape[0]
ctx_begin_pos = [self.config.ctx_begin_pos] * batch_size
@@ -297,11 +301,12 @@ def __call__(
Returns:
[`~pipelines.ImagePipelineOutput`] or `tuple`
"""
device = self._execution_device

reference_image = self.image_processor.preprocess(
reference_image, image_mean=self.config.mean, image_std=self.config.std, return_tensors="pt"
)["pixel_values"]
reference_image = reference_image.to(self.device)
reference_image = reference_image.to(device)

if isinstance(prompt, str):
prompt = [prompt]
@@ -319,7 +324,7 @@ def __call__(
prompt_reps=prompt_reps,
)
query_embeds = self.get_query_embeddings(reference_image, source_subject_category)
text_embeddings = self.encode_prompt(query_embeds, prompt)
text_embeddings = self.encode_prompt(query_embeds, prompt, device)
# 3. unconditional embedding
do_classifier_free_guidance = guidance_scale > 1.0
if do_classifier_free_guidance:
@@ -332,7 +337,7 @@ def __call__(
return_tensors="pt",
)
uncond_embeddings = self.text_encoder(
input_ids=uncond_input.input_ids.to(self.device),
input_ids=uncond_input.input_ids.to(device),
ctx_embeddings=None,
)[0]
# For classifier free guidance, we need to do two forward passes.
@@ -348,7 +353,7 @@ def __call__(
generator=generator,
latents=latents,
dtype=self.unet.dtype,
device=self.device,
device=device,
)
# set timesteps
extra_set_kwargs = {}
@@ -399,6 +404,9 @@ def __call__(
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
image = self.image_processor.postprocess(image, output_type=output_type)

# Offload all models
self.maybe_free_model_hooks()

if not return_dict:
return (image,)

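The repeated self.device -> device substitutions in both files are the core of the fix: once offload hooks are installed, the sub-models sit on the CPU between forward passes, so self.device (derived from the modules' current placement) can report cpu, while self._execution_device reports the device the hooks actually run on, which is why tensors created inside __call__ now target the latter. A small check, assuming a CUDA machine and the same assumed checkpoint as above, illustrates the difference:

    import torch
    from diffusers import BlipDiffusionPipeline

    pipe = BlipDiffusionPipeline.from_pretrained(
        "Salesforce/blipdiffusion", torch_dtype=torch.float16
    )
    pipe.enable_model_cpu_offload()

    print(pipe.device)             # typically cpu: the modules are parked on the CPU
    print(pipe._execution_device)  # typically cuda:0: where the hooks run each sub-model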
