forked from huggingface/diffusers
kakaobrain unCLIP (huggingface#1428)
* [wip] attention block updates
* [wip] unCLIP unet decoder and super res
* [wip] unCLIP prior transformer
* [wip] scheduler changes
* [wip] text proj utility class
* [wip] UnCLIPPipeline
* [wip] kakaobrain unCLIP convert script
* [unCLIP pipeline] fixes re: @patrickvonplaten (remove callbacks, move denoising loops into call function)
* UNCLIPScheduler re: @patrickvonplaten (revert changes to DDPMScheduler; make UNCLIPScheduler a modified DDPM scheduler with changes to support karlo)
* mask -> attention_mask re: @patrickvonplaten
* [DDPMScheduler] remove leftover change
* [docs] PriorTransformer
* [docs] UNet2DConditionModel and UNet2DModel
* [nit] UNCLIPScheduler -> UnCLIPScheduler (matches existing unclip naming better)
* [docs] SchedulingUnCLIP
* [docs] UnCLIPTextProjModel
* refactor
* finish licenses
* rename all to attention_mask and prep in models
* more renaming
* don't expose unused configs
* final renaming fixes
* remove x attn mask when not necessary
* configure kakao script to use new class embedding config
* fix copies
* [tests] UnCLIPScheduler
* finish x attn
* finish
* remove more
* rename condition blocks
* clean more
* Apply suggestions from code review
* up
* fix
* [tests] UnCLIPPipelineFastTests
* remove unused imports
* [tests] UnCLIPPipelineIntegrationTests
* correct
* make style

Co-authored-by: Patrick von Platen <[email protected]>
1 parent 402b956 · commit 2dcf64b
Showing 21 changed files with 3,594 additions and 118 deletions.
scripts/convert_kakao_brain_unclip_to_diffusers.py: 1,159 additions, 0 deletions (large diff not rendered by default).
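The converter produces a diffusers-format checkpoint that the new UnCLIPPipeline can load. Below is a minimal usage sketch; the repo id "kakaobrain/karlo-v1-alpha" and the prompt are illustrative assumptions and are not part of this diff.

import torch
from diffusers import UnCLIPPipeline

# Load a converted karlo/unCLIP checkpoint (the repo id here is an assumption).
pipe = UnCLIPPipeline.from_pretrained("kakaobrain/karlo-v1-alpha", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

# The pipeline chains the prior (text -> CLIP image embedding), the decoder, and the
# super-resolution stages end to end.
image = pipe("a photograph of a corgi wearing sunglasses").images[0]
image.save("corgi.png")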
models/prior_transformer.py (new file, 194 additions):
from dataclasses import dataclass
from typing import Optional, Union

import torch
import torch.nn.functional as F
from torch import nn

from ..configuration_utils import ConfigMixin, register_to_config
from ..modeling_utils import ModelMixin
from ..utils import BaseOutput
from .attention import BasicTransformerBlock
from .embeddings import TimestepEmbedding, Timesteps


@dataclass
class PriorTransformerOutput(BaseOutput):
    """
    Args:
        predicted_image_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
            The predicted CLIP image embedding conditioned on the CLIP text embedding input.
    """

    predicted_image_embedding: torch.FloatTensor


class PriorTransformer(ModelMixin, ConfigMixin):
    """
    The prior transformer from unCLIP is used to predict CLIP image embeddings from CLIP text embeddings. Note that the
    transformer predicts the image embeddings through a denoising diffusion process.

    This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
    implements for all the models (such as downloading or saving, etc.)

    For more details, see the original paper: https://arxiv.org/abs/2204.06125

    Parameters:
        num_attention_heads (`int`, *optional*, defaults to 32): The number of heads to use for multi-head attention.
        attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
        num_layers (`int`, *optional*, defaults to 20): The number of layers of Transformer blocks to use.
        embedding_dim (`int`, *optional*, defaults to 768): The dimension of the CLIP embeddings. Note that CLIP
            image embeddings and text embeddings are both the same dimension.
        num_embeddings (`int`, *optional*, defaults to 77): The max number of CLIP embeddings allowed, i.e. the
            length of the prompt after it has been tokenized.
        additional_embeddings (`int`, *optional*, defaults to 4): The number of additional tokens appended to the
            projected hidden_states. The actual length of the used hidden_states is `num_embeddings +
            additional_embeddings`.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
    """

    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 32,
        attention_head_dim: int = 64,
        num_layers: int = 20,
        embedding_dim: int = 768,
        num_embeddings=77,
        additional_embeddings=4,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        inner_dim = num_attention_heads * attention_head_dim
        self.additional_embeddings = additional_embeddings

        self.time_proj = Timesteps(inner_dim, True, 0)
        self.time_embedding = TimestepEmbedding(inner_dim, inner_dim)

        self.proj_in = nn.Linear(embedding_dim, inner_dim)

        self.embedding_proj = nn.Linear(embedding_dim, inner_dim)
        self.encoder_hidden_states_proj = nn.Linear(embedding_dim, inner_dim)

        self.positional_embedding = nn.Parameter(torch.zeros(1, num_embeddings + additional_embeddings, inner_dim))

        self.prd_embedding = nn.Parameter(torch.zeros(1, 1, inner_dim))

        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    inner_dim,
                    num_attention_heads,
                    attention_head_dim,
                    dropout=dropout,
                    activation_fn="gelu",
                    attention_bias=True,
                )
                for d in range(num_layers)
            ]
        )

        self.norm_out = nn.LayerNorm(inner_dim)
        self.proj_to_clip_embeddings = nn.Linear(inner_dim, embedding_dim)

        causal_attention_mask = torch.full(
            [num_embeddings + additional_embeddings, num_embeddings + additional_embeddings], float("-inf")
        )
        causal_attention_mask.triu_(1)
        causal_attention_mask = causal_attention_mask[None, ...]
        self.register_buffer("causal_attention_mask", causal_attention_mask, persistent=False)

        self.clip_mean = nn.Parameter(torch.zeros(1, embedding_dim))
        self.clip_std = nn.Parameter(torch.zeros(1, embedding_dim))

    def forward(
        self,
        hidden_states,
        timestep: Union[torch.Tensor, float, int],
        proj_embedding: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor,
        attention_mask: Optional[torch.BoolTensor] = None,
        return_dict: bool = True,
    ):
        """
        Args:
            hidden_states (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
                x_t, the currently predicted image embeddings.
            timestep (`torch.long`):
                Current denoising step.
            proj_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
                Projected embedding vector the denoising process is conditioned on.
            encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_embeddings, embedding_dim)`):
                Hidden states of the text embeddings the denoising process is conditioned on.
            attention_mask (`torch.BoolTensor` of shape `(batch_size, num_embeddings)`):
                Text mask for the text embeddings.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`models.prior_transformer.PriorTransformerOutput`] instead of a plain
                tuple.
        Returns:
            [`~models.prior_transformer.PriorTransformerOutput`] or `tuple`:
                [`~models.prior_transformer.PriorTransformerOutput`] if `return_dict` is True, otherwise a `tuple`.
                When returning a tuple, the first element is the sample tensor.
        """
        batch_size = hidden_states.shape[0]

        timesteps = timestep
        if not torch.is_tensor(timesteps):
            timesteps = torch.tensor([timesteps], dtype=torch.long, device=hidden_states.device)
        elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(hidden_states.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps * torch.ones(batch_size, dtype=timesteps.dtype, device=timesteps.device)

        timesteps_projected = self.time_proj(timesteps)

        # timesteps does not contain any weights and will always return f32 tensors
        # but time_embedding might be fp16, so we need to cast here.
        timesteps_projected = timesteps_projected.to(dtype=self.dtype)
        time_embeddings = self.time_embedding(timesteps_projected)

        proj_embeddings = self.embedding_proj(proj_embedding)
        encoder_hidden_states = self.encoder_hidden_states_proj(encoder_hidden_states)
        hidden_states = self.proj_in(hidden_states)
        prd_embedding = self.prd_embedding.to(hidden_states.dtype).expand(batch_size, -1, -1)
        positional_embeddings = self.positional_embedding.to(hidden_states.dtype)

        hidden_states = torch.cat(
            [
                encoder_hidden_states,
                proj_embeddings[:, None, :],
                time_embeddings[:, None, :],
                hidden_states[:, None, :],
                prd_embedding,
            ],
            dim=1,
        )

        hidden_states = hidden_states + positional_embeddings

        if attention_mask is not None:
            attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
            attention_mask = F.pad(attention_mask, (0, self.additional_embeddings), value=0.0)
            attention_mask = (attention_mask[:, None, :] + self.causal_attention_mask).to(hidden_states.dtype)
            attention_mask = attention_mask.repeat_interleave(self.config.num_attention_heads, dim=0)

        for block in self.transformer_blocks:
            hidden_states = block(hidden_states, attention_mask=attention_mask)

        hidden_states = self.norm_out(hidden_states)
        hidden_states = hidden_states[:, -1]
        predicted_image_embedding = self.proj_to_clip_embeddings(hidden_states)

        if not return_dict:
            return (predicted_image_embedding,)

        return PriorTransformerOutput(predicted_image_embedding=predicted_image_embedding)

    def post_process_latents(self, prior_latents):
        prior_latents = (prior_latents * self.clip_std) + self.clip_mean
        return prior_latents
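As a quick sanity check of the forward signature above, the prior can be instantiated standalone and run on random tensors. This is a sketch, not part of the commit; it uses a deliberately tiny configuration so it runs on CPU, whereas the defaults in the class above are 32 heads of 64 channels, 20 layers, 768-dim embeddings, and 77 text tokens.

import torch
from diffusers.models.prior_transformer import PriorTransformer

# Tiny configuration purely for illustration; not the karlo configuration.
prior = PriorTransformer(
    num_attention_heads=2,
    attention_head_dim=4,
    num_layers=2,
    embedding_dim=8,
    num_embeddings=7,
    additional_embeddings=4,
)

batch_size = 3
x_t = torch.randn(batch_size, 8)                         # noisy CLIP image embedding being denoised
proj = torch.randn(batch_size, 8)                        # pooled/projected CLIP text embedding
text_states = torch.randn(batch_size, 7, 8)              # per-token CLIP text hidden states
text_mask = torch.ones(batch_size, 7, dtype=torch.bool)  # text attention mask

out = prior(
    x_t,
    timestep=5,
    proj_embedding=proj,
    encoder_hidden_states=text_states,
    attention_mask=text_mask,
)
# The model returns the predicted (denoised) CLIP image embedding for this timestep.
assert out.predicted_image_embedding.shape == (batch_size, 8)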