[Kandinsky 3.0] Follow-up TODOs (huggingface#5944)
clean-up Kandinsky 3.0
yiyixuxu authored Dec 1, 2023
1 parent 0f55c17 commit b41f809
Showing 9 changed files with 744 additions and 279 deletions.
4 changes: 2 additions & 2 deletions src/diffusers/models/__init__.py
@@ -42,7 +42,7 @@
_import_structure["unet_2d"] = ["UNet2DModel"]
_import_structure["unet_2d_condition"] = ["UNet2DConditionModel"]
_import_structure["unet_3d_condition"] = ["UNet3DConditionModel"]
_import_structure["unet_kandi3"] = ["Kandinsky3UNet"]
_import_structure["unet_kandinsky3"] = ["Kandinsky3UNet"]
_import_structure["unet_motion_model"] = ["MotionAdapter", "UNetMotionModel"]
_import_structure["unet_spatio_temporal_condition"] = ["UNetSpatioTemporalConditionModel"]
_import_structure["vq_model"] = ["VQModel"]
@@ -72,7 +72,7 @@
from .unet_2d import UNet2DModel
from .unet_2d_condition import UNet2DConditionModel
from .unet_3d_condition import UNet3DConditionModel
-from .unet_kandi3 import Kandinsky3UNet
+from .unet_kandinsky3 import Kandinsky3UNet
from .unet_motion_model import MotionAdapter, UNetMotionModel
from .unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel
from .vq_model import VQModel
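
Note: the rename only touches the private module file (unet_kandi3.py → unet_kandinsky3.py); the exported class name registered in _import_structure stays the same, so public imports are unaffected. A minimal sketch of the unchanged usage, assuming a diffusers build that includes this commit:

from diffusers import Kandinsky3UNet           # still resolved lazily through _import_structure
from diffusers.models import Kandinsky3UNet    # equivalent public path; only the backing file was renamed
# from diffusers.models.unet_kandi3 import ...  # old private module path; gone after this commit
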
49 changes: 6 additions & 43 deletions src/diffusers/models/attention_processor.py
@@ -16,7 +16,7 @@

import torch
import torch.nn.functional as F
-from torch import einsum, nn
+from torch import nn

from ..utils import USE_PEFT_BACKEND, deprecate, logging
from ..utils.import_utils import is_xformers_available
@@ -109,15 +109,17 @@ def __init__(
residual_connection: bool = False,
_from_deprecated_attn_block: bool = False,
processor: Optional["AttnProcessor"] = None,
+out_dim: int = None,
):
super().__init__()
-self.inner_dim = dim_head * heads
+self.inner_dim = out_dim if out_dim is not None else dim_head * heads
self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
self.upcast_attention = upcast_attention
self.upcast_softmax = upcast_softmax
self.rescale_output_factor = rescale_output_factor
self.residual_connection = residual_connection
self.dropout = dropout
+self.out_dim = out_dim if out_dim is not None else query_dim

# we make use of this private variable to know whether this class is loaded
# with an deprecated state dict so that we can convert it on the fly
@@ -126,7 +128,7 @@ def __init__(
self.scale_qk = scale_qk
self.scale = dim_head**-0.5 if self.scale_qk else 1.0

-self.heads = heads
+self.heads = out_dim // dim_head if out_dim is not None else heads
# for slice_size > 0 the attention score computation
# is split across the batch axis to save memory
# You can set slice_size with `set_attention_slice`
@@ -193,7 +195,7 @@ def __init__(
self.add_v_proj = linear_cls(added_kv_proj_dim, self.inner_dim)

self.to_out = nn.ModuleList([])
-self.to_out.append(linear_cls(self.inner_dim, query_dim, bias=out_bias))
+self.to_out.append(linear_cls(self.inner_dim, self.out_dim, bias=out_bias))
self.to_out.append(nn.Dropout(dropout))

# set attention processor
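
Note: the new optional out_dim argument decouples the attention width from query_dim. When it is set, inner_dim and the output projection follow out_dim and the head count is derived as out_dim // dim_head, which is what the Kandinsky 3 UNet's attention layers rely on. A small sketch of the intended wiring, assuming the Attention class as patched above:

from diffusers.models.attention_processor import Attention

# Previous behaviour (out_dim omitted): inner_dim = dim_head * heads, output projects back to query_dim.
attn = Attention(query_dim=320, heads=8, dim_head=40)
assert attn.inner_dim == 320 and attn.to_out[0].out_features == 320

# With out_dim set: inner_dim and the output projection follow out_dim,
# and the effective head count becomes out_dim // dim_head (the heads argument is ignored).
attn = Attention(query_dim=320, heads=8, dim_head=40, out_dim=640)
assert attn.inner_dim == 640 and attn.heads == 16 and attn.to_out[0].out_features == 640
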
@@ -2219,44 +2221,6 @@ def __call__(
return hidden_states


-# TODO(Yiyi): This class should not exist, we can replace it with a normal attention processor I believe
-# this way torch.compile and co. will work as well
-class Kandi3AttnProcessor:
-    r"""
-    Default kandinsky3 proccesor for performing attention-related computations.
-    """
-
-    @staticmethod
-    def _reshape(hid_states, h):
-        b, n, f = hid_states.shape
-        d = f // h
-        return hid_states.unsqueeze(-1).reshape(b, n, h, d).permute(0, 2, 1, 3)
-
-    def __call__(
-        self,
-        attn,
-        x,
-        context,
-        context_mask=None,
-    ):
-        query = self._reshape(attn.to_q(x), h=attn.num_heads)
-        key = self._reshape(attn.to_k(context), h=attn.num_heads)
-        value = self._reshape(attn.to_v(context), h=attn.num_heads)
-
-        attention_matrix = einsum("b h i d, b h j d -> b h i j", query, key)
-
-        if context_mask is not None:
-            max_neg_value = -torch.finfo(attention_matrix.dtype).max
-            context_mask = context_mask.unsqueeze(1).unsqueeze(1)
-            attention_matrix = attention_matrix.masked_fill(~(context_mask != 0), max_neg_value)
-        attention_matrix = (attention_matrix * attn.scale).softmax(dim=-1)
-
-        out = einsum("b h i j, b h j d -> b h i d", attention_matrix, value)
-        out = out.permute(0, 2, 1, 3).reshape(out.shape[0], out.shape[2], -1)
-        out = attn.to_out[0](out)
-        return out
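
Note: deleting Kandi3AttnProcessor follows the TODO above: its einsum attention duplicates what the stock processors already compute, and using the standard path keeps Kandinsky 3 compatible with torch.compile. A hedged sketch of the equivalence, restated with the SDPA call that AttnProcessor2_0 relies on (the helper name and shapes are illustrative, not library API):

import torch
import torch.nn.functional as F

def kandinsky3_attention_via_sdpa(query, key, value, context_mask=None):
    # query/key/value: (batch, heads, seq, dim_head); context_mask: (batch, key_seq), 0 = padded token.
    attn_mask = None
    if context_mask is not None:
        # Boolean mask broadcast over heads and query positions; True means "attend",
        # mirroring masked_fill(~(context_mask != 0), max_neg_value) in the deleted processor.
        attn_mask = (context_mask != 0)[:, None, None, :]
    # SDPA's default 1 / sqrt(dim_head) scaling matches attn.scale = dim_head**-0.5 when scale_qk is enabled.
    return F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask)

q = k = v = torch.randn(2, 8, 77, 64)
mask = torch.ones(2, 77, dtype=torch.bool)
out = kandinsky3_attention_via_sdpa(q, k, v, context_mask=mask)   # -> (2, 8, 77, 64)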


LORA_ATTENTION_PROCESSORS = (
LoRAAttnProcessor,
LoRAAttnProcessor2_0,
@@ -2282,7 +2246,6 @@ def __call__(
LoRAXFormersAttnProcessor,
IPAdapterAttnProcessor,
IPAdapterAttnProcessor2_0,
-Kandi3AttnProcessor,
)
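
Note: with the bespoke processor dropped from these registries, Kandinsky 3 attention goes through the same processor plumbing as every other model. A hedged sketch of overriding processors the standard way, assuming Kandinsky3UNet exposes the usual attn_processors / set_attn_processor API after this clean-up (the repo id and subfolder are illustrative):

from diffusers import Kandinsky3UNet
from diffusers.models.attention_processor import AttnProcessor2_0

# Illustrative checkpoint location; substitute whatever Kandinsky 3 weights you actually use.
unet = Kandinsky3UNet.from_pretrained("kandinsky-community/kandinsky-3", subfolder="unet")
unet.set_attn_processor(AttnProcessor2_0())   # route every attention block through the stock SDPA processor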

AttentionProcessor = Union[
