Skip to content

Commit

Permalink
refactor - introduce MatrixMultipleLayer to replace the Lambda layers…
Browse files Browse the repository at this point in the history
… that use tf.matmul
  • Loading branch information
ksachdeva committed Sep 30, 2019
1 parent cf2fd23 commit d63ccbb
Showing 1 changed file with 21 additions and 5 deletions.
26 changes: 21 additions & 5 deletions lib/FSANET_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,21 @@ def call(self, inputs):
def compute_output_shape(self, input_shape):
return (input_shape[0], input_shape[-1])

class MatrixMultiplyLayer(Layer):
def __init__(self, **kwargs):
super(MatrixMultiplyLayer,self).__init__(**kwargs)
self.trainable = False

def call(self, inputs):
x1, x2 = inputs
# TODO: add some asserts on the inputs
# it is expected the shape of inputs are
# arranged to be able to perform the matrix multiplication
return tf.matmul(x1,x2)

def compute_output_shape(self, input_shapes):
return (input_shapes[0][0],input_shapes[0][1], input_shapes[1][-1])


class BaseFSANet(object):
def __init__(self, image_size,num_classes,stage_num,lambda_d, S_set):
Expand Down Expand Up @@ -283,11 +298,12 @@ def ssr_S_model_build(self, num_primcaps, m_dim):

feat_pre_concat = Concatenate()([feat_s1_preS,feat_s2_preS,feat_s3_preS])
SL_matrix = Dense(int(num_primcaps/3)*m_dim,activation='sigmoid')(feat_pre_concat)
SL_matrix = Reshape((int(num_primcaps/3),m_dim))(SL_matrix)

S_matrix_s1 = Lambda(lambda x: tf.matmul(x[0],x[1]),name='S_matrix_s1')([SL_matrix,SR_matrix_s1])
S_matrix_s2 = Lambda(lambda x: tf.matmul(x[0],x[1]),name='S_matrix_s2')([SL_matrix,SR_matrix_s2])
S_matrix_s3 = Lambda(lambda x: tf.matmul(x[0],x[1]),name='S_matrix_s3')([SL_matrix,SR_matrix_s3])
SL_matrix = Reshape((int(num_primcaps/3),m_dim))(SL_matrix)


S_matrix_s1 = MatrixMultiplyLayer(name="S_matrix_s1")([SL_matrix,SR_matrix_s1])
S_matrix_s2 = MatrixMultiplyLayer(name='S_matrix_s2')([SL_matrix,SR_matrix_s2])
S_matrix_s3 = MatrixMultiplyLayer(name='S_matrix_s3')([SL_matrix,SR_matrix_s3])

# Very important!!! Without this training won't converge.
# norm_S = Lambda(lambda x: K.tile(K.sum(x,axis=-1,keepdims=True),(1,1,64)))(S_matrix)
Expand Down

0 comments on commit d63ccbb

Please sign in to comment.