Skip to content

Commit

Permalink
[PixArt] fix small nits in pixart sigma (huggingface#7767)
Browse files Browse the repository at this point in the history
fix small nits in pixart sigma
  • Loading branch information
sayakpaul authored Apr 25, 2024
1 parent 39215aa commit e963621
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 21 deletions.
9 changes: 0 additions & 9 deletions src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,15 +273,6 @@ def __init__(
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor)

# Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/utils.py
def mask_text_embeddings(self, emb, mask):
if emb.shape[0] == 1:
keep_index = mask.sum().item()
return emb[:, :, :keep_index, :], keep_index
else:
masked_feature = emb * mask[:, None, :, None]
return masked_feature, emb.shape[2]

# Adapted from diffusers.pipelines.deepfloyd_if.pipeline_if.encode_prompt
def encode_prompt(
self,
Expand Down
15 changes: 3 additions & 12 deletions src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,16 +199,7 @@ def __init__(
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor)

# copied from diffusers.pipelines.pixart_alpha.pipeline_pixart_alpha.py
def mask_text_embeddings(self, emb, mask):
if emb.shape[0] == 1:
keep_index = mask.sum().item()
return emb[:, :, :keep_index, :], keep_index
else:
masked_feature = emb * mask[:, None, :, None]
return masked_feature, emb.shape[2]

# Adapted from diffusers.pipelines.deepfloyd_if.pipeline_if.encode_prompt
# Copied from diffusers.pipelines.pixart_alpha.pipeline_pixart_alpha.PixArtAlphaPipeline.encode_prompt
def encode_prompt(
self,
prompt: Union[str, List[str]],
Expand Down Expand Up @@ -369,7 +360,7 @@ def prepare_extra_step_kwargs(self, generator, eta):
extra_step_kwargs["generator"] = generator
return extra_step_kwargs

# copied from diffusers.pipelines.pixart_alpha.pipeline_pixart_alpha.py
# Copied from diffusers.pipelines.pixart_alpha.pipeline_pixart_alpha.PixArtAlphaPipeline.check_inputs
def check_inputs(
self,
prompt,
Expand Down Expand Up @@ -462,7 +453,7 @@ def process(text: str):

return [process(t) for t in text]

# Copied from diffusers.pipelines.pixart_alpha.pipeline_pixart_alpha.PixArtAlphaPipeline._clean_caption
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption
def _clean_caption(self, caption):
caption = str(caption)
caption = ul.unquote_plus(caption)
Expand Down

0 comments on commit e963621

Please sign in to comment.