@@ -37,67 +37,83 @@ def forward(self, input):
        return nn.functional.avg_pool2d(input, kernel_size=input.size()[2:]).view(-1, input.size(1))


-def conv_block(in_channels, out_channels):
+def conv_block(in_channels, out_channels, activation='relu'):
    """Returns a Module that performs 3x3 convolution, ReLU activation, 2x2 max pooling.

    # Arguments
        in_channels: Number of input channels
        out_channels: Number of output channels
    """
-    return nn.Sequential(
-        nn.Conv2d(in_channels, out_channels, 3, padding=1),
-        nn.BatchNorm2d(out_channels),
-        nn.ReLU(),
-        nn.MaxPool2d(kernel_size=2, stride=2)
-    )
-
-
-def functional_conv_block(x, weights, biases, bn_weights, bn_biases):
+    if activation == 'relu':
+        return nn.Sequential(
+            nn.Conv2d(in_channels, out_channels, 3, padding=1),
+            nn.BatchNorm2d(out_channels),
+            nn.ReLU(),
+            nn.MaxPool2d(kernel_size=2, stride=2)
+        )
+    elif activation == 'selu':
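+        # SELU is self-normalising (Klambauer et al., 2017), so no explicit BatchNorm2d layer is stacked here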
+        return nn.Sequential(
+            nn.Conv2d(in_channels, out_channels, 3, padding=1),
+            nn.SELU(),
+            nn.MaxPool2d(kernel_size=2, stride=2)
+        )
+    else:
+        raise ValueError('Unsupported activation.')
+
+
+def functional_conv_block(x, weights, biases, bn_weights, bn_biases, activation: str = 'relu'):
    """Performs 3x3 convolution, ReLU activation, 2x2 max pooling in a functional fashion."""
    x = F.conv2d(x, weights, biases, padding=1)
-    x = F.batch_norm(x, running_mean=None, running_var=None, weight=bn_weights, bias=bn_biases, training=True)
-    x = F.relu(x)
+    if activation == 'relu':
+        x = F.batch_norm(x, running_mean=None, running_var=None, weight=bn_weights, bias=bn_biases, training=True)
+        x = F.relu(x)
+    elif activation == 'selu':
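+        # bn_weights/bn_biases are ignored here (and may be None): SELU blocks contain no batch norm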
+        x = F.selu(x)
+    else:
+        raise ValueError('Unsupported activation.')
    x = F.max_pool2d(x, kernel_size=2, stride=2)
    return x


##########
# Models #
##########
-def get_few_shot_encoder(num_input_channels=1):
+def get_few_shot_encoder(num_input_channels=1, activation: str = 'relu'):
    """Creates a few-shot encoder as used in Matching and Prototypical Networks.

    # Arguments:
        num_input_channels: Number of color channels the model expects input data to contain. Omniglot = 1,
            miniImageNet = 3
+        activation: Whether to use ReLU activation + batchnorm ('relu') or self-normalising SELU ('selu')
    """
    return nn.Sequential(
-        conv_block(num_input_channels, 64),
-        conv_block(64, 64),
-        conv_block(64, 64),
-        conv_block(64, 64),
+        conv_block(num_input_channels, 64, activation),
+        conv_block(64, 64, activation),
+        conv_block(64, 64, activation),
+        conv_block(64, 64, activation),
        Flatten(),
    )


class FewShotClassifier(nn.Module):
-    def __init__(self, num_input_channels: int, k_way: int, final_layer_size: int = 64):
+    def __init__(self, num_input_channels: int, k_way: int, final_layer_size: int = 64, activation: str = 'relu'):
        """Creates a few-shot classifier as used in MAML.

        This network should be identical to the one created by `get_few_shot_encoder` but with a
        classification layer on top.

-        # Arguments:
-            num_input_channels: Number of color channels the model expects input data to contain. Omniglot = 1,
-                miniImageNet = 3
-            k_way: Number of classes the model will discriminate between
-            final_layer_size: 64 for Omniglot, 1600 for miniImageNet
+        # Arguments:
+            num_input_channels: Number of color channels the model expects input data to contain. Omniglot = 1,
+                miniImageNet = 3
+            k_way: Number of classes the model will discriminate between
+            final_layer_size: 64 for Omniglot, 1600 for miniImageNet
+            activation: Whether to use ReLU activation + batchnorm ('relu') or self-normalising SELU ('selu')
        """
        super(FewShotClassifier, self).__init__()
-        self.conv1 = conv_block(num_input_channels, 64)
-        self.conv2 = conv_block(64, 64)
-        self.conv3 = conv_block(64, 64)
-        self.conv4 = conv_block(64, 64)
+        self.conv1 = conv_block(num_input_channels, 64, activation)
+        self.conv2 = conv_block(64, 64, activation)
+        self.conv3 = conv_block(64, 64, activation)
+        self.conv4 = conv_block(64, 64, activation)
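+        # With activation='selu' the blocks register no batch norm parameters,
+        # which is why functional_forward below fetches them with weights.get()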

        self.logits = nn.Linear(final_layer_size, k_way)

@@ -116,7 +132,7 @@ def functional_forward(self, x, weights):

        for block in [1, 2, 3, 4]:
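+            # dict.get is used because a SELU-built model registers no conv{block}.1.*
+            # batch norm parameters; the lookups then return None rather than raising KeyError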
            x = functional_conv_block(x, weights[f'conv{block}.0.weight'], weights[f'conv{block}.0.bias'],
-                                      weights[f'conv{block}.1.weight'], weights[f'conv{block}.1.bias'])
+                                      weights.get(f'conv{block}.1.weight'), weights.get(f'conv{block}.1.bias'))

        x = x.view(x.size(0), -1)

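A minimal usage sketch of the new `activation` argument. The import path, input shapes, and the call through `FewShotClassifier.forward` (defined elsewhere in this file) are assumptions for illustration, not part of this diff:

    import torch
    from few_shot.models import FewShotClassifier, get_few_shot_encoder  # assumed module path

    # SELU encoder: each block is Conv2d -> SELU -> MaxPool2d, with no batch norm
    encoder = get_few_shot_encoder(num_input_channels=3, activation='selu')
    embeddings = encoder(torch.randn(4, 3, 84, 84))  # 84 -> 42 -> 21 -> 10 -> 5, so shape (4, 1600)

    # Same blocks plus a k-way linear head, as used for MAML
    model = FewShotClassifier(num_input_channels=3, k_way=5, final_layer_size=1600, activation='selu')
    logits = model(torch.randn(4, 3, 84, 84))  # shape (4, 5)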