[PixArt-Alpha] Fix PixArt-Alpha pipeline when number of images to generate is more than 1 (huggingface#5752)

* does this fix things?

* attention mask use

* attention mask order

* better masking.

* add: test

* remove mask_feature

* test

* debug

* fix: tests

* deprecate mask_feature

* add deprecation test

* add slow test

* add print statements to retrieve the assertion values.

* fix for the 1024 fast test

* fix test

* fix the remaining

* Apply suggestions from code review

* more debug

---------

Co-authored-by: Patrick von Platen <[email protected]>
sayakpaul and patrickvonplaten authored Nov 14, 2023
1 parent 16d5004 commit a5720e9
Showing 2 changed files with 138 additions and 43 deletions.
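For context, here is a minimal usage sketch of the case this commit fixes: requesting more than one image per prompt. The checkpoint ID, precision, device, and prompt below are illustrative assumptions, not part of the diff.

```python
import torch

from diffusers import PixArtAlphaPipeline

# Illustrative checkpoint, dtype, and device; adjust to your setup.
pipe = PixArtAlphaPipeline.from_pretrained(
    "PixArt-alpha/PixArt-XL-2-1024-MS", torch_dtype=torch.float16
).to("cuda")

# Requesting several images per prompt is the case that previously broke:
# the text attention mask is now repeated together with the prompt embeddings,
# so the batched transformer call receives matching shapes.
images = pipe(
    prompt="A small cactus with a happy face in the Sahara desert",
    num_images_per_prompt=4,
    num_inference_steps=20,
).images

for i, image in enumerate(images):
    image.save(f"pixart_sample_{i}.png")
```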
103 changes: 66 additions & 37 deletions src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py
@@ -27,6 +27,7 @@
from ...schedulers import DPMSolverMultistepScheduler
from ...utils import (
BACKENDS_MAPPING,
deprecate,
is_bs4_available,
is_ftfy_available,
logging,
@@ -162,8 +163,10 @@ def encode_prompt(
device: Optional[torch.device] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
prompt_attention_mask: Optional[torch.FloatTensor] = None,
negative_prompt_attention_mask: Optional[torch.FloatTensor] = None,
clean_caption: bool = False,
mask_feature: bool = True,
**kwargs,
):
r"""
Encodes the prompt into text encoder hidden states.
@@ -189,10 +192,11 @@ def encode_prompt(
string.
clean_caption (bool, defaults to `False`):
If `True`, the function will preprocess and clean the provided caption before encoding.
mask_feature: (bool, defaults to `True`):
If `True`, the function will mask the text embeddings.
"""
embeds_initially_provided = prompt_embeds is not None and negative_prompt_embeds is not None

if "mask_feature" in kwargs:
deprecation_message = "The use of `mask_feature` is deprecated. It is no longer used in any computation and that doesn't affect the end results. It will be removed in a future version."
deprecate("mask_feature", "1.0.0", deprecation_message, standard_warn=False)

if device is None:
device = self._execution_device
@@ -229,13 +233,11 @@
f" {max_length} tokens: {removed_text}"
)

attention_mask = text_inputs.attention_mask.to(device)
prompt_embeds_attention_mask = attention_mask
prompt_attention_mask = text_inputs.attention_mask
prompt_attention_mask = prompt_attention_mask.to(device)

prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)
prompt_embeds = prompt_embeds[0]
else:
prompt_embeds_attention_mask = torch.ones_like(prompt_embeds)

if self.text_encoder is not None:
dtype = self.text_encoder.dtype
@@ -250,8 +252,8 @@
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
prompt_embeds_attention_mask = prompt_embeds_attention_mask.view(bs_embed, -1)
prompt_embeds_attention_mask = prompt_embeds_attention_mask.repeat(num_images_per_prompt, 1)
prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1)
prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)

# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance and negative_prompt_embeds is None:
@@ -267,11 +269,11 @@
add_special_tokens=True,
return_tensors="pt",
)
attention_mask = uncond_input.attention_mask.to(device)
negative_prompt_attention_mask = uncond_input.attention_mask
negative_prompt_attention_mask = negative_prompt_attention_mask.to(device)

negative_prompt_embeds = self.text_encoder(
uncond_input.input_ids.to(device),
attention_mask=attention_mask,
uncond_input.input_ids.to(device), attention_mask=negative_prompt_attention_mask
)
negative_prompt_embeds = negative_prompt_embeds[0]

@@ -284,23 +286,13 @@
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1)
negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
else:
negative_prompt_embeds = None
negative_prompt_attention_mask = None

# Perform additional masking.
if mask_feature and not embeds_initially_provided:
prompt_embeds = prompt_embeds.unsqueeze(1)
masked_prompt_embeds, keep_indices = self.mask_text_embeddings(prompt_embeds, prompt_embeds_attention_mask)
masked_prompt_embeds = masked_prompt_embeds.squeeze(1)
masked_negative_prompt_embeds = (
negative_prompt_embeds[:, :keep_indices, :] if negative_prompt_embeds is not None else None
)
return masked_prompt_embeds, masked_negative_prompt_embeds

return prompt_embeds, negative_prompt_embeds
return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
@@ -329,6 +321,8 @@ def check_inputs(
callback_steps,
prompt_embeds=None,
negative_prompt_embeds=None,
prompt_attention_mask=None,
negative_prompt_attention_mask=None,
):
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
@@ -365,13 +359,25 @@ def check_inputs(
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)

if prompt_embeds is not None and prompt_attention_mask is None:
raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")

if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")

if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
raise ValueError(
"`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but"
f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
f" {negative_prompt_attention_mask.shape}."
)

# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
def _text_preprocessing(self, text, clean_caption=False):
@@ -579,14 +585,16 @@ def __call__(
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
prompt_attention_mask: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_attention_mask: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
clean_caption: bool = True,
mask_feature: bool = True,
use_resolution_binning: bool = True,
**kwargs,
) -> Union[ImagePipelineOutput, Tuple]:
"""
Function invoked when calling the pipeline for generation.
@@ -630,9 +638,12 @@ def __call__(
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
prompt_attention_mask (`torch.FloatTensor`, *optional*): Pre-generated attention mask for text embeddings.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. For PixArt-Alpha this negative prompt should be "". If not
provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):
Pre-generated attention mask for negative text embeddings.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -648,11 +659,10 @@
Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
be installed. If the dependencies are not installed, the embeddings will be created from the raw
prompt.
mask_feature (`bool` defaults to `True`): If set to `True`, the text embeddings will be masked.
use_resolution_binning:
(`bool` defaults to `True`): If set to `True`, the requested height and width are first mapped to the
closest resolutions using `ASPECT_RATIO_1024_BIN`. After the produced latents are decoded into images,
they are resized back to the requested resolution. Useful for generating non-square images.
use_resolution_binning (`bool` defaults to `True`):
If set to `True`, the requested height and width are first mapped to the closest resolutions using
`ASPECT_RATIO_1024_BIN`. After the produced latents are decoded into images, they are resized back to
the requested resolution. Useful for generating non-square images.
Examples:
@@ -661,6 +671,9 @@
If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
returned where the first element is a list with the generated images
"""
if "mask_feature" in kwargs:
deprecation_message = "The use of `mask_feature` is deprecated. It is no longer used in any computation and that doesn't affect the end results. It will be removed in a future version."
deprecate("mask_feature", "1.0.0", deprecation_message, standard_warn=False)
# 1. Check inputs. Raise error if not correct
height = height or self.transformer.config.sample_size * self.vae_scale_factor
width = width or self.transformer.config.sample_size * self.vae_scale_factor
@@ -669,7 +682,15 @@ def __call__(
height, width = self.classify_height_width_bin(height, width, ratios=ASPECT_RATIO_1024_BIN)

self.check_inputs(
prompt, height, width, negative_prompt, callback_steps, prompt_embeds, negative_prompt_embeds
prompt,
height,
width,
negative_prompt,
callback_steps,
prompt_embeds,
negative_prompt_embeds,
prompt_attention_mask,
negative_prompt_attention_mask,
)

# 2. Default height and width to transformer
@@ -688,19 +709,26 @@ def __call__(
do_classifier_free_guidance = guidance_scale > 1.0

# 3. Encode input prompt
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
(
prompt_embeds,
prompt_attention_mask,
negative_prompt_embeds,
negative_prompt_attention_mask,
) = self.encode_prompt(
prompt,
do_classifier_free_guidance,
negative_prompt=negative_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
prompt_attention_mask=prompt_attention_mask,
negative_prompt_attention_mask=negative_prompt_attention_mask,
clean_caption=clean_caption,
mask_feature=mask_feature,
)
if do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)

# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
@@ -758,6 +786,7 @@ def __call__(
noise_pred = self.transformer(
latent_model_input,
encoder_hidden_states=prompt_embeds,
encoder_attention_mask=prompt_attention_mask,
timestep=current_timestep,
added_cond_kwargs=added_cond_kwargs,
return_dict=False,
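The diff also changes the `encode_prompt` contract: it now returns the prompt and negative-prompt attention masks alongside the embeddings, pre-computed embeddings must be passed together with their masks, and `mask_feature` is kept only as a deprecated no-op keyword argument. Below is a minimal sketch of the new calling pattern, assuming `pipe` is a loaded `PixArtAlphaPipeline` (see the sketch above) and using an illustrative prompt.

```python
# `pipe` is assumed to be a loaded PixArtAlphaPipeline (see the earlier sketch).
(
    prompt_embeds,
    prompt_attention_mask,
    negative_prompt_embeds,
    negative_prompt_attention_mask,
) = pipe.encode_prompt(
    "A small cactus with a happy face in the Sahara desert",
    True,  # classifier-free guidance, so negative embeddings/masks are returned too
)

# Pre-computed embeddings must now be accompanied by their attention masks;
# otherwise the new checks in check_inputs raise a ValueError.
images = pipe(
    prompt=None,
    negative_prompt=None,
    prompt_embeds=prompt_embeds,
    prompt_attention_mask=prompt_attention_mask,
    negative_prompt_embeds=negative_prompt_embeds,
    negative_prompt_attention_mask=negative_prompt_attention_mask,
).images
```

Passing `mask_feature=...` to `encode_prompt` or `__call__` still works, but it only emits the deprecation warning shown in the diff and no longer affects the computation.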