feed_forward.py
from tensorflow import keras
import tensorflow.keras.backend as K

class FeedForward(keras.layers.Layer):
    """Position-wise feed-forward layer from "Attention Is All You Need".

    See: https://arxiv.org/pdf/1706.03762.pdf
    """
    def __init__(self,
                 units,
                 activation='relu',
                 use_bias=True,
                 kernel_initializer='glorot_normal',
                 bias_initializer='zeros',
                 kernel_regularizer=None,
                 bias_regularizer=None,
                 kernel_constraint=None,
                 bias_constraint=None,
                 **kwargs):
        """Initialize the layer.

        :param units: Dimension of the hidden units.
        :param activation: Activation for the first linear transformation.
        :param use_bias: Whether to use bias terms.
        :param kernel_initializer: Initializer for kernels.
        :param bias_initializer: Initializer for biases.
        :param kernel_regularizer: Regularizer for kernels.
        :param bias_regularizer: Regularizer for biases.
        :param kernel_constraint: Constraint for kernels.
        :param bias_constraint: Constraint for biases.
        :param kwargs: Extra keyword arguments passed to keras.layers.Layer.
        """
        # Call the base initializer first so tf.keras attribute tracking is
        # set up before any attributes are assigned.
        super(FeedForward, self).__init__(**kwargs)
        self.supports_masking = True
        self.units = units
        self.activation = keras.activations.get(activation)
        self.use_bias = use_bias
        self.kernel_initializer = keras.initializers.get(kernel_initializer)
        self.bias_initializer = keras.initializers.get(bias_initializer)
        self.kernel_regularizer = keras.regularizers.get(kernel_regularizer)
        self.bias_regularizer = keras.regularizers.get(bias_regularizer)
        self.kernel_constraint = keras.constraints.get(kernel_constraint)
        self.bias_constraint = keras.constraints.get(bias_constraint)
        self.W1, self.b1 = None, None
        self.W2, self.b2 = None, None

    def get_config(self):
        config = {
            'units': self.units,
            'activation': keras.activations.serialize(self.activation),
            'use_bias': self.use_bias,
            'kernel_initializer': keras.initializers.serialize(self.kernel_initializer),
            'bias_initializer': keras.initializers.serialize(self.bias_initializer),
            'kernel_regularizer': keras.regularizers.serialize(self.kernel_regularizer),
            'bias_regularizer': keras.regularizers.serialize(self.bias_regularizer),
            'kernel_constraint': keras.constraints.serialize(self.kernel_constraint),
            'bias_constraint': keras.constraints.serialize(self.bias_constraint),
        }
        base_config = super(FeedForward, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
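
    # Usage hint (not in the original file): because get_config() fully
    # describes the layer, a saved model that uses it can be reloaded with
    # keras.models.load_model(path, custom_objects={'FeedForward': FeedForward}).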

    def compute_output_shape(self, input_shape):
        # The second projection maps back to the input width, so the output
        # shape equals the input shape.
        return input_shape

    def compute_mask(self, inputs, input_mask=None):
        # Any incoming mask passes through unchanged.
        return input_mask

    def build(self, input_shape):
        feature_dim = input_shape[-1]
        # First projection: feature_dim -> units (with optional bias).
        self.W1 = self.add_weight(
            shape=(feature_dim, self.units),
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            constraint=self.kernel_constraint,
            name='{}_W1'.format(self.name),
        )
        if self.use_bias:
            self.b1 = self.add_weight(
                shape=(self.units,),
                initializer=self.bias_initializer,
                regularizer=self.bias_regularizer,
                constraint=self.bias_constraint,
                name='{}_b1'.format(self.name),
            )
        # Second projection: units -> feature_dim, restoring the input width.
        self.W2 = self.add_weight(
            shape=(self.units, feature_dim),
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            constraint=self.kernel_constraint,
            name='{}_W2'.format(self.name),
        )
        if self.use_bias:
            self.b2 = self.add_weight(
                shape=(feature_dim,),
                initializer=self.bias_initializer,
                regularizer=self.bias_regularizer,
                constraint=self.bias_constraint,
                name='{}_b2'.format(self.name),
            )
        super(FeedForward, self).build(input_shape)

    def call(self, x, mask=None):
        # h = activation(x @ W1 + b1)
        h = K.dot(x, self.W1)
        if self.use_bias:
            h = K.bias_add(h, self.b1)
        if self.activation is not None:
            h = self.activation(h)
        # y = h @ W2 + b2
        y = K.dot(h, self.W2)
        if self.use_bias:
            y = K.bias_add(y, self.b2)
        return y
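

# ---------------------------------------------------------------------------
# Minimal usage sketch (assumed; not part of the original file). It runs a
# random batch through the layer to confirm the output shape matches the
# input, then round-trips the config to check serialization. The shapes and
# unit count below are illustrative choices, not values from the source.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import tensorflow as tf

    layer = FeedForward(units=64, activation='relu')
    x = tf.random.normal((2, 10, 32))   # (batch, sequence, features)
    y = layer(x)
    print(y.shape)                      # (2, 10, 32): same shape as the input

    # get_config()/from_config() should reconstruct an equivalent layer.
    clone = FeedForward.from_config(layer.get_config())
    print(clone.units)                  # 64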