fix - report the usage of constant for map dimensions instead of hard-coded value

- had started the branch before a revision in the original repo, so reporting the fix back here
ksachdeva committed Sep 30, 2019
1 parent ff13db6 commit 03cf175
Showing 1 changed file with 19 additions and 18 deletions.
37 changes: 19 additions & 18 deletions lib/FSANET_model.py
@@ -33,8 +33,7 @@
 np.random.seed(2 ** 10)
 
 # Custom layers
-# Note - we use Lambda layers so that the model (weights)
-# can be converted to various other formats. Usage of Lambda layers prevent the convertion
+# Note - Usage of Lambda layers prevent the convertion
 # and the optimizations by the underlying math engine (tensorflow in this case)
 
 @register_keras_custom_object
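
For context: a Lambda layer wraps an opaque Python closure, which format converters (TFLite, ONNX, CoreML) generally cannot serialize or optimize, while a registered custom Layer exposes an explicit config. A minimal sketch of the distinction, assuming TensorFlow 2.x and using tf.keras.utils.register_keras_serializable as a stand-in for the repo's register_keras_custom_object decorator:

    # Minimal sketch (assumes TensorFlow 2.x); the Scale layer below is a
    # hypothetical example, not a layer from this repository.
    import tensorflow as tf

    # Opaque closure: converters cannot inspect or re-serialize this reliably.
    lambda_layer = tf.keras.layers.Lambda(lambda x: x * 0.5)

    # Registered custom layer: config is explicit, so the model can be saved,
    # reloaded, and handed to converters and graph optimizers.
    @tf.keras.utils.register_keras_serializable(package="demo")
    class Scale(tf.keras.layers.Layer):
        def __init__(self, factor=0.5, **kwargs):
            super().__init__(**kwargs)
            self.factor = factor

        def call(self, inputs):
            return inputs * self.factor

        def get_config(self):
            config = super().get_config()
            config.update({"factor": self.factor})
            return config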
@@ -235,6 +234,7 @@ def __init__(self, image_size,num_classes,stage_num,lambda_d, S_set):
         self.m_dim = S_set[4]
 
         self.F_shape = int(self.num_capsule/3)*self.dim_capsule
+        self.map_xy_size = int(8*image_size/64)
 
         self.is_fc_model = False
         self.is_noS_model = False
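
The new constant reproduces the previously hard-coded value of 8 when image_size is 64, and scales with other input resolutions. A quick check of the formula (the image sizes here are hypothetical):

    # Quick check of the new formula with hypothetical image sizes.
    for image_size in (32, 64, 128):
        map_xy_size = int(8 * image_size / 64)
        print(image_size, map_xy_size)  # 32 -> 4, 64 -> 8, 128 -> 16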
@@ -344,23 +344,23 @@ def _process_input(stage_index, stage_num, num_classes, input_s_pre):
         return Model(inputs=[input_s1_pre,input_s2_pre,input_s3_pre],outputs=[pred_s1,pred_s2,pred_s3,delta_s1,delta_s2,delta_s3,local_s1,local_s2,local_s3], name=name_F)
 
     def ssr_feat_S_model_build(self, m_dim):
-        input_preS = Input((8,8,64))
+        input_preS = Input((self.map_xy_size,self.map_xy_size,64))
 
         if self.is_varS_model:
             feat_preS = MomentsLayer()(input_preS)
         else:
             feat_preS = Conv2D(1,(1,1),padding='same',activation='sigmoid')(input_preS)
 
         feat_preS = Reshape((-1,))(feat_preS)
-        SR_matrix = Dense(m_dim*(8*8+8*8+8*8),activation='sigmoid')(feat_preS)
-        SR_matrix = Reshape((m_dim,(8*8+8*8+8*8)))(SR_matrix)
+        SR_matrix = Dense(m_dim*(self.map_xy_size*self.map_xy_size*3),activation='sigmoid')(feat_preS)
+        SR_matrix = Reshape((m_dim,(self.map_xy_size*self.map_xy_size*3)))(SR_matrix)
 
         return Model(inputs=input_preS,outputs=[SR_matrix,feat_preS],name='feat_S_model')
 
     def ssr_S_model_build(self, num_primcaps, m_dim):
-        input_s1_preS = Input((8,8,64))
-        input_s2_preS = Input((8,8,64))
-        input_s3_preS = Input((8,8,64))
+        input_s1_preS = Input((self.map_xy_size,self.map_xy_size,64))
+        input_s2_preS = Input((self.map_xy_size,self.map_xy_size,64))
+        input_s3_preS = Input((self.map_xy_size,self.map_xy_size,64))
 
         feat_S_model = self.ssr_feat_S_model_build(m_dim)
 
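With map_xy_size = s, this submodel flattens the s×s map, projects it to m_dim*(s*s*3) units, and reshapes to (m_dim, s*s*3), so the old 8*8+8*8+8*8 and the new s*s*3 agree when s = 8. A standalone shape check (a sketch with s = 8 and m_dim = 5; the Conv2D branch is used for brevity):

    # Standalone shape check of the feat_S path (sketch: s=8, m_dim=5).
    from tensorflow.keras.layers import Input, Conv2D, Reshape, Dense
    from tensorflow.keras.models import Model

    s, m_dim = 8, 5
    input_preS = Input((s, s, 64))
    feat_preS = Conv2D(1, (1, 1), padding='same', activation='sigmoid')(input_preS)
    feat_preS = Reshape((-1,))(feat_preS)                       # (None, s*s)
    SR_matrix = Dense(m_dim * (s * s * 3), activation='sigmoid')(feat_preS)
    SR_matrix = Reshape((m_dim, s * s * 3))(SR_matrix)          # (None, m_dim, s*s*3)
    print(Model(inputs=input_preS, outputs=SR_matrix).output_shape)  # (None, 5, 192)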
@@ -382,9 +382,9 @@ def ssr_S_model_build(self, num_primcaps, m_dim):
         norm_S_s2 = MatrixNormLayer(tile_count=64)(S_matrix_s2)
         norm_S_s3 = MatrixNormLayer(tile_count=64)(S_matrix_s3)
 
-        feat_s1_pre = Reshape((8*8,64))(input_s1_preS)
-        feat_s2_pre = Reshape((8*8,64))(input_s2_preS)
-        feat_s3_pre = Reshape((8*8,64))(input_s3_preS)
+        feat_s1_pre = Reshape((self.map_xy_size*self.map_xy_size,64))(input_s1_preS)
+        feat_s2_pre = Reshape((self.map_xy_size*self.map_xy_size,64))(input_s2_preS)
+        feat_s3_pre = Reshape((self.map_xy_size*self.map_xy_size,64))(input_s3_preS)
 
         feat_pre_concat = Concatenate(axis=1)([feat_s1_pre, feat_s2_pre, feat_s3_pre])
 
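Each stream becomes (s*s, 64) after the reshape, and concatenating on axis 1 gives (3*s*s, 64) rows, which is exactly the s*s*3 term sized by the Dense layer above. A quick numeric check (sketch, s = 8):

    # Concatenation shape check (sketch, s=8).
    import numpy as np

    s = 8
    streams = [np.zeros((1, s * s, 64)) for _ in range(3)]  # three reshaped maps
    print(np.concatenate(streams, axis=1).shape)  # (1, 192, 64)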
@@ -400,14 +400,15 @@ def ssr_S_model_build(self, num_primcaps, m_dim):
 
         return Model(inputs=[input_s1_preS, input_s2_preS, input_s3_preS],outputs=primcaps, name='ssr_S_model')
 
-    def ssr_noS_model_build(self, **kwargs):
-        input_s1_preS = Input((8,8,64))
-        input_s2_preS = Input((8,8,64))
-        input_s3_preS = Input((8,8,64))
+    def ssr_noS_model_build(self, **kwargs):
+
+        input_s1_preS = Input((self.map_xy_size,self.map_xy_size,64))
+        input_s2_preS = Input((self.map_xy_size,self.map_xy_size,64))
+        input_s3_preS = Input((self.map_xy_size,self.map_xy_size,64))
 
-        primcaps_s1 = Reshape((8*8,64))(input_s1_preS)
-        primcaps_s2 = Reshape((8*8,64))(input_s2_preS)
-        primcaps_s3 = Reshape((8*8,64))(input_s3_preS)
+        primcaps_s1 = Reshape((self.map_xy_size*self.map_xy_size,64))(input_s1_preS)
+        primcaps_s2 = Reshape((self.map_xy_size*self.map_xy_size,64))(input_s2_preS)
+        primcaps_s3 = Reshape((self.map_xy_size*self.map_xy_size,64))(input_s3_preS)
 
         primcaps = Concatenate(axis=1)([primcaps_s1,primcaps_s2,primcaps_s3])
         return Model(inputs=[input_s1_preS, input_s2_preS, input_s3_preS],outputs=primcaps, name='ssr_S_model')
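
A consequence worth noting (inferred from the formula, not stated in the commit): the primary-capsule count in this noS path now tracks the input resolution, so S_set's num_primcaps must match 3*s*s. A standalone build of the path as it reads after this change (sketch, with map_xy_size fixed at 8):

    # Standalone build of the noS path after this commit (sketch, s=8).
    from tensorflow.keras.layers import Input, Reshape, Concatenate
    from tensorflow.keras.models import Model

    s = 8
    inputs = [Input((s, s, 64)) for _ in range(3)]
    primcaps = Concatenate(axis=1)([Reshape((s * s, 64))(x) for x in inputs])
    print(Model(inputs=inputs, outputs=primcaps, name='ssr_S_model').output_shape)
    # -> (None, 192, 64); with image_size=128 (s=16) it would be (None, 768, 64)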
