Skip to content

Commit

Permalink
Bring Flax attention naming in sync with PyTorch (huggingface#2511)
Browse files Browse the repository at this point in the history
Bring flax attention naming in sync with PyTorch.
  • Loading branch information
pcuenca authored Mar 1, 2023
1 parent eadf0e2 commit e4a9fb3
Showing 1 changed file with 10 additions and 6 deletions.
16 changes: 10 additions & 6 deletions src/diffusers/models/attention_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import jax.numpy as jnp


class FlaxAttentionBlock(nn.Module):
class FlaxCrossAttention(nn.Module):
r"""
A Flax multi-head attention module as described in: https://arxiv.org/abs/1706.03762
Expand Down Expand Up @@ -118,10 +118,10 @@ class FlaxBasicTransformerBlock(nn.Module):

def setup(self):
# self attention (or cross_attention if only_cross_attention is True)
self.attn1 = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
self.attn1 = FlaxCrossAttention(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
# cross attention
self.attn2 = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
self.ff = FlaxGluFeedForward(dim=self.dim, dropout=self.dropout, dtype=self.dtype)
self.attn2 = FlaxCrossAttention(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
self.ff = FlaxFeedForward(dim=self.dim, dropout=self.dropout, dtype=self.dtype)
self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
self.norm2 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
self.norm3 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
Expand Down Expand Up @@ -242,10 +242,14 @@ def __call__(self, hidden_states, context, deterministic=True):
return hidden_states


class FlaxGluFeedForward(nn.Module):
class FlaxFeedForward(nn.Module):
r"""
Flax module that encapsulates two Linear layers separated by a gated linear unit activation from:
Flax module that encapsulates two Linear layers separated by a non-linearity. It is the counterpart of PyTorch's
[`FeedForward`] class, with the following simplifications:
- The activation function is currently hardcoded to a gated linear unit from:
https://arxiv.org/abs/2002.05202
- `dim_out` is equal to `dim`.
- The number of hidden dimensions is hardcoded to `dim * 4` in [`FlaxGELU`].
Parameters:
dim (:obj:`int`):
Expand Down

0 comments on commit e4a9fb3

Please sign in to comment.