# layers.py (forked from lindermanlab/S5)
from flax import linen as nn
import jax


class SequenceLayer(nn.Module):
""" Defines a single S5 layer, with S5 SSM, nonlinearity,
dropout, batch/layer norm, etc.
Args:
ssm (nn.Module): the SSM to be used (i.e. S5 ssm)
dropout (float32): dropout rate
d_model (int32): this is the feature size of the layer inputs and outputs
we usually refer to this size as H
activation (string): Type of activation function to use
training (bool): whether in training mode or not
prenorm (bool): apply prenorm if true or postnorm if false
batchnorm (bool): apply batchnorm if true or layernorm if false
bn_momentum (float32): the batchnorm momentum if batchnorm is used
step_rescale (float32): allows for uniformly changing the timescale parameter,
e.g. after training on a different resolution for
the speech commands benchmark
"""
    ssm: nn.Module
    dropout: float
    d_model: int
    activation: str = "gelu"
    training: bool = True
    prenorm: bool = False
    batchnorm: bool = False
    bn_momentum: float = 0.90
    step_rescale: float = 1.0

    def setup(self):
        """Initializes the SSM, batch/layer norm and dropout."""
        self.seq = self.ssm(step_rescale=self.step_rescale)

        if self.activation in ["full_glu"]:
            self.out1 = nn.Dense(self.d_model)
            self.out2 = nn.Dense(self.d_model)
        elif self.activation in ["half_glu1", "half_glu2"]:
            self.out2 = nn.Dense(self.d_model)

        if self.batchnorm:
            self.norm = nn.BatchNorm(use_running_average=not self.training,
                                     momentum=self.bn_momentum, axis_name='batch')
        else:
            self.norm = nn.LayerNorm()

        self.drop = nn.Dropout(
            self.dropout,
            broadcast_dims=[0],
            deterministic=not self.training,
        )
    def __call__(self, x):
        """
        Compute the LxH output of the S5 layer given an LxH input.
        Args:
            x (float32): input sequence (L, d_model)
        Returns:
            output sequence (float32): (L, d_model)
        """
        skip = x
        if self.prenorm:
            x = self.norm(x)

        # Apply the SSM
        x = self.seq(x)

        if self.activation in ["full_glu"]:
            x = self.drop(nn.gelu(x))
            x = self.out1(x) * jax.nn.sigmoid(self.out2(x))
            x = self.drop(x)
        elif self.activation in ["half_glu1"]:
            x = self.drop(nn.gelu(x))
            x = x * jax.nn.sigmoid(self.out2(x))
            x = self.drop(x)
        elif self.activation in ["half_glu2"]:
            # Only apply GELU to the gate input
            x1 = self.drop(nn.gelu(x))
            x = x * jax.nn.sigmoid(self.out2(x1))
            x = self.drop(x)
        elif self.activation in ["gelu"]:
            x = self.drop(nn.gelu(x))
        else:
            raise NotImplementedError(
                "Activation: {} not implemented".format(self.activation))

        # Residual connection, followed by postnorm if prenorm was not used
        x = skip + x
        if not self.prenorm:
            x = self.norm(x)
        return x
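

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file): a minimal example of wiring
# SequenceLayer to an SSM module. The real S5 SSM is constructed elsewhere in
# this repo; `DummySSM` below is a hypothetical stand-in used only to
# illustrate the expected interface: the `ssm` field receives a constructor
# (e.g. a functools.partial) that accepts `step_rescale` and produces a module
# mapping an (L, d_model) sequence to (L, d_model).
if __name__ == "__main__":
    import functools
    import jax.numpy as jnp

    class DummySSM(nn.Module):
        """Hypothetical placeholder with the same call signature as the S5 SSM."""
        d_model: int
        step_rescale: float = 1.0

        @nn.compact
        def __call__(self, x):
            # A single dense map stands in for the real state-space recurrence.
            return nn.Dense(self.d_model)(x)

    d_model, L = 16, 32
    layer = SequenceLayer(
        ssm=functools.partial(DummySSM, d_model=d_model),
        dropout=0.1,
        d_model=d_model,
        activation="half_glu1",
        training=False,   # deterministic dropout, no 'dropout' RNG needed
        batchnorm=False,  # LayerNorm branch, so no 'batch' axis_name is required
    )

    x = jnp.ones((L, d_model))
    variables = layer.init(jax.random.PRNGKey(0), x)
    y = layer.apply(variables, x)
    print(y.shape)  # (L, d_model)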