refactor - introduce AggregatedFeatureExtractionLayer and use it from various models / no more Lambda layer usage in the model.
ksachdeva committed Sep 30, 2019
1 parent 077f486 commit e7a9aac
Showing 1 changed file with 36 additions and 30 deletions.
66 changes: 36 additions & 30 deletions lib/FSANET_model.py
@@ -10,7 +10,6 @@
 from keras.layers import Dense
 from keras.layers import Conv2D
 from keras.layers import Layer
-from keras.layers import Lambda
 from keras.layers import Reshape
 from keras.layers import Multiply
 from keras.layers import Flatten
@@ -171,6 +170,37 @@ def call(self, inputs):
     def compute_output_shape(self, input_shapes):
         return input_shapes[-1]

+class AggregatedFeatureExtractionLayer(Layer):
+    def __init__(self, num_capsule, **kwargs):
+        super(AggregatedFeatureExtractionLayer,self).__init__(**kwargs)
+        self.trainable = False
+        self.num_capsule = num_capsule
+
+    def call(self, input):
+        s1_a = 0
+        s1_b = self.num_capsule//3
+        feat_s1_div = input[:,s1_a:s1_b,:]
+        s2_a = self.num_capsule//3
+        s2_b = 2*self.num_capsule//3
+        feat_s2_div = input[:,s2_a:s2_b,:]
+        s3_a = 2*self.num_capsule//3
+        s3_b = self.num_capsule
+        feat_s3_div = input[:,s3_a:s3_b,:]
+
+        return [feat_s1_div, feat_s2_div, feat_s3_div]
+
+    def compute_output_shape(self, input_shape):
+        last_dim = input_shape[-1]
+        partition = self.num_capsule//3
+        return [(input_shape[0], partition, last_dim), (input_shape[0], partition, last_dim), (input_shape[0], partition, last_dim)]
+
+    def get_config(self):
+        config = {
+            'num_capsule': self.num_capsule
+        }
+        base_config = super(AggregatedFeatureExtractionLayer, self).get_config()
+        return dict(list(base_config.items()) + list(config.items()))
+
+
 class BaseFSANet(object):
     def __init__(self, image_size,num_classes,stage_num,lambda_d, S_set):
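For context (not part of the diff): the new layer slices the capsule axis into three equal groups, one per SSR stage, and returns them as a list that Keras unpacks at the call site. A minimal usage sketch, assuming Keras 2.x and the class exactly as committed above; the sizes here are illustrative, not FSA-Net's actual configuration:

# Sketch only, not from the commit; assumes AggregatedFeatureExtractionLayer
# is defined as in the hunk above. Sizes are illustrative.
import numpy as np
from keras.layers import Input
from keras.models import Model

num_capsule, dim_capsule = 21, 16
x = Input(shape=(num_capsule, dim_capsule))
s1, s2, s3 = AggregatedFeatureExtractionLayer(num_capsule=num_capsule)(x)
m = Model(inputs=x, outputs=[s1, s2, s3])

# Each output holds one third of the capsules: (batch, num_capsule//3, dim_capsule).
feats = m.predict(np.random.rand(2, num_capsule, dim_capsule))
print([f.shape for f in feats])  # [(2, 7, 16), (2, 7, 16), (2, 7, 16)]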
@@ -410,21 +440,13 @@ def __init__(self, image_size,num_classes,stage_num,lambda_d, S_set)

     def ssr_aggregation_model_build(self, shape_primcaps):
         input_primcaps = Input(shape_primcaps)
         capsule = CapsuleLayer(self.num_capsule, self.dim_capsule, routings=self.routings, name='caps')(input_primcaps)

-        s1_a = 0
-        s1_b = self.num_capsule//3
-        feat_s1_div = Lambda(lambda x: x[:,s1_a:s1_b,:])(capsule)
-        s2_a = self.num_capsule//3
-        s2_b = 2*self.num_capsule//3
-        feat_s2_div = Lambda(lambda x: x[:,s2_a:s2_b,:])(capsule)
-        s3_a = 2*self.num_capsule//3
-        s3_b = self.num_capsule
-        feat_s3_div = Lambda(lambda x: x[:,s3_a:s3_b,:])(capsule)
+        feat_s1_div, feat_s2_div, feat_s3_div = AggregatedFeatureExtractionLayer(num_capsule=self.num_capsule)(capsule)

         feat_s1_div = Reshape((-1,))(feat_s1_div)
         feat_s2_div = Reshape((-1,))(feat_s2_div)
         feat_s3_div = Reshape((-1,))(feat_s3_div)

         return Model(inputs=input_primcaps,outputs=[feat_s1_div,feat_s2_div,feat_s3_div], name='ssr_Cap_model')
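A note on why this refactor matters (my reading; the commit message only says the Lambda layers are gone): Lambda layers wrap anonymous Python functions, which makes saved models awkward to reload and to convert to other runtimes, whereas a named Layer that implements get_config round-trips through Keras serialization. A sketch, assuming Keras 2.x, where ssr_model stands in for a hypothetical model built by ssr_aggregation_model_build:

# Sketch only, not from the commit. 'ssr_model' is a hypothetical model
# containing the custom layers; each custom class must be registered via
# custom_objects when reloading.
from keras.models import load_model

ssr_model.save('ssr_cap_model.h5')
reloaded = load_model('ssr_cap_model.h5',
                      custom_objects={'AggregatedFeatureExtractionLayer': AggregatedFeatureExtractionLayer,
                                      'CapsuleLayer': CapsuleLayer})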

@@ -471,15 +493,7 @@ def ssr_aggregation_model_build(self, shape_primcaps):
         agg_feat = NetVLAD(feature_size=64, max_samples=self.num_primcaps, cluster_size=self.num_capsule, output_dim=self.num_capsule*self.dim_capsule)(input_primcaps)
         agg_feat = Reshape((self.num_capsule,self.dim_capsule))(agg_feat)

-        s1_a = 0
-        s1_b = self.num_capsule//3
-        feat_s1_div = Lambda(lambda x: x[:,s1_a:s1_b,:])(agg_feat)
-        s2_a = self.num_capsule//3
-        s2_b = 2*self.num_capsule//3
-        feat_s2_div = Lambda(lambda x: x[:,s2_a:s2_b,:])(agg_feat)
-        s3_a = 2*self.num_capsule//3
-        s3_b = self.num_capsule
-        feat_s3_div = Lambda(lambda x: x[:,s3_a:s3_b,:])(agg_feat)
+        feat_s1_div, feat_s2_div, feat_s3_div = AggregatedFeatureExtractionLayer(num_capsule=self.num_capsule)(agg_feat)

         feat_s1_div = Reshape((-1,))(feat_s1_div)
         feat_s2_div = Reshape((-1,))(feat_s2_div)
@@ -529,15 +543,7 @@ def ssr_aggregation_model_build(self, shape_primcaps):
         metric_feat = MatMulLayer(16,type=1)(input_primcaps)
         metric_feat = MatMulLayer(3,type=2)(metric_feat)

-        s1_a = 0
-        s1_b = self.num_capsule//3
-        feat_s1_div = Lambda(lambda x: x[:,s1_a:s1_b,:])(metric_feat)
-        s2_a = self.num_capsule//3
-        s2_b = 2*self.num_capsule//3
-        feat_s2_div = Lambda(lambda x: x[:,s2_a:s2_b,:])(metric_feat)
-        s3_a = 2*self.num_capsule//3
-        s3_b = self.num_capsule
-        feat_s3_div = Lambda(lambda x: x[:,s3_a:s3_b,:])(metric_feat)
+        feat_s1_div, feat_s2_div, feat_s3_div = AggregatedFeatureExtractionLayer(num_capsule=self.num_capsule)(metric_feat)

         feat_s1_div = Reshape((-1,))(feat_s1_div)
         feat_s2_div = Reshape((-1,))(feat_s2_div)