Skip to content

Commit b7b06b7

Browse files
committedNov 14, 2018
Added option for SELU activation
1 parent 6f04915 commit b7b06b7

File tree

3 files changed

+83
-75
lines changed

3 files changed

+83
-75
lines changed
 

‎experiments/maml.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
parser.add_argument('--meta-lr', default=0.005, type=float)
3232
parser.add_argument('--meta-batch-size', default=32, type=int)
3333
parser.add_argument('--order', default=1, type=int)
34+
parser.add_argument('--activation', default='relu', type=str)
3435

3536
args = parser.parse_args()
3637

@@ -52,7 +53,7 @@
5253
meta_batches_per_epoch = 100
5354

5455
param_str = f'{args.dataset}_order={args.order}_n={args.n}_k={args.k}_metabatch={args.meta_batch_size}_' \
55-
f'train_steps={args.inner_train_steps}_val_steps={args.inner_val_steps}'
56+
f'train_steps={args.inner_train_steps}_val_steps={args.inner_val_steps}_act={args.activation}'
5657
print(param_str)
5758

5859

@@ -112,7 +113,6 @@ def prepare_meta_batch_(batch):
112113
# MAML kwargs
113114
inner_train_steps=args.inner_val_steps,
114115
inner_lr=args.inner_lr,
115-
num_input_channels=num_input_channels,
116116
device=device,
117117
order=args.order,
118118
),

‎few_shot/datasets.py

+37-45
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,11 @@
1313

1414
class OmniglotDataset(Dataset):
1515
def __init__(self, subset):
16+
"""Dataset class representing Omniglot dataset
17+
18+
# Arguments:
19+
subset: Whether the dataset represents the background or evaluation set
20+
"""
1621
if subset not in ('background', 'evaluation'):
1722
raise(ValueError, 'subset must be one of (background, evaluation)')
1823
self.subset = subset
@@ -49,53 +54,16 @@ def __len__(self):
4954
def num_classes(self):
5055
return len(self.df['class_name'].unique())
5156

52-
def build_n_shot_task(self, k, n=1, query=1):
53-
"""
54-
This method builds a k-way n-shot classification task. It returns a support set of n audio samples each from k
55-
unique speakers. In addition it will return a query sample. Downstream models will attempt to match the query
56-
sample to the correct samples in the support set.
57-
:param k: Number of unique speakers to include in this task
58-
:param n: Number of audio samples to include from each speaker
59-
:param query: Number of query samples
60-
:return:
61-
"""
62-
if k >= self.num_classes():
63-
raise(ValueError, 'k must be smaller than the number of unique speakers in this dataset!')
64-
65-
if k <= 1:
66-
raise(ValueError, 'k must be greater than or equal to one!')
67-
68-
query = self.df.sample(query)
69-
query_samples = self[query['id'].values[0]]
70-
# Add batch dimension
71-
query_samples = (query_samples[0][np.newaxis, :, :], query_samples[1])
72-
73-
is_query_character = self.df['class_id'] == query['class_id'].values[0]
74-
not_query_sample = ~self.df.index.isin(query['id'].values)
75-
correct_samples = self.df[is_query_character & not_query_sample].sample(n)
76-
77-
# Sample k-1 speakers
78-
other_support_set_characters = np.random.choice(
79-
self.df[~is_query_character]['class_id'].unique(), k-1, replace=False)
80-
81-
other_support_samples = []
82-
for i in range(k-1):
83-
is_same_speaker = self.df['class_id'] == other_support_set_characters[i]
84-
other_support_samples.append(
85-
self.df[~is_query_character & is_same_speaker].sample(n)
86-
)
87-
support_set = pd.concat([correct_samples]+other_support_samples)
88-
support_set_samples = tuple(np.stack(i) for i in zip(*[self[i] for i in support_set.index]))
89-
90-
return query_samples, support_set_samples
91-
9257
@staticmethod
9358
def index_subset(subset):
94-
"""
95-
Index a subset by looping through all of it's files and recording their speaker ID, filepath and length.
96-
:param subset: Name of the subset
97-
:return: A list of dicts containing information about all the audio files in a particular subset of the
98-
LibriSpeech dataset
59+
"""Index a subset by looping through all of its files and recording relevant information.
60+
61+
# Arguments
62+
subset: Name of the subset
63+
64+
# Returns
65+
A list of dicts containing information about all the image files in a particular subset of the
66+
Omniglot dataset dataset
9967
"""
10068
images = []
10169
print('Indexing {}...'.format(subset))
@@ -127,6 +95,11 @@ def index_subset(subset):
12795

12896
class MiniImageNet(Dataset):
12997
def __init__(self, subset):
98+
"""Dataset class representing miniImageNet dataset
99+
100+
# Arguments:
101+
subset: Whether the dataset represents the background or evaluation set
102+
"""
130103
if subset not in ('background', 'evaluation'):
131104
raise(ValueError, 'subset must be one of (background, evaluation)')
132105
self.subset = subset
@@ -168,6 +141,15 @@ def num_classes(self):
168141

169142
@staticmethod
170143
def index_subset(subset):
144+
"""Index a subset by looping through all of its files and recording relevant information.
145+
146+
# Arguments
147+
subset: Name of the subset
148+
149+
# Returns
150+
A list of dicts containing information about all the image files in a particular subset of the
151+
miniImageNet dataset
152+
"""
171153
images = []
172154
print('Indexing {}...'.format(subset))
173155
# Quick first pass to find total for tqdm bar
@@ -196,6 +178,16 @@ def index_subset(subset):
196178

197179
class DummyDataset(Dataset):
198180
def __init__(self, samples_per_class=10, n_classes=10, n_features=1):
181+
"""Dummy dataset for debugging/testing purposes
182+
183+
A sample from the DummyDataset has (n_features + 1) features. The first feature is the index of the sample
184+
in the data and the remaining features are the class index.
185+
186+
# Arguments
187+
samples_per_class: Number of samples per class in the dataset
188+
n_classes: Number of distinct classes in the dataset
189+
n_features: Number of extra features each sample should have.
190+
"""
199191
self.samples_per_class = samples_per_class
200192
self.n_classes = n_classes
201193
self.n_features = n_features

‎few_shot/models.py

+44-28
Original file line numberDiff line numberDiff line change
@@ -37,67 +37,83 @@ def forward(self, input):
3737
return nn.functional.avg_pool2d(input, kernel_size=input.size()[2:]).view(-1, input.size(1))
3838

3939

40-
def conv_block(in_channels, out_channels):
40+
def conv_block(in_channels, out_channels, activation='relu'):
4141
"""Returns a Module that performs 3x3 convolution, ReLu activation, 2x2 max pooling.
4242
4343
# Arguments
4444
in_channels:
4545
out_channels
4646
"""
47-
return nn.Sequential(
48-
nn.Conv2d(in_channels, out_channels, 3, padding=1),
49-
nn.BatchNorm2d(out_channels),
50-
nn.ReLU(),
51-
nn.MaxPool2d(kernel_size=2, stride=2)
52-
)
53-
54-
55-
def functional_conv_block(x, weights, biases, bn_weights, bn_biases):
47+
if activation == 'relu':
48+
return nn.Sequential(
49+
nn.Conv2d(in_channels, out_channels, 3, padding=1),
50+
nn.BatchNorm2d(out_channels),
51+
nn.ReLU(),
52+
nn.MaxPool2d(kernel_size=2, stride=2)
53+
)
54+
elif activation == 'selu':
55+
return nn.Sequential(
56+
nn.Conv2d(in_channels, out_channels, 3, padding=1),
57+
nn.SELU(),
58+
nn.MaxPool2d(kernel_size=2, stride=2)
59+
)
60+
else:
61+
raise ValueError('Unsupported activation.')
62+
63+
64+
def functional_conv_block(x, weights, biases, bn_weights, bn_biases, activation: str = 'relu'):
5665
"""Performs 3x3 convolution, ReLu activation, 2x2 max pooling in a functional fashion."""
5766
x = F.conv2d(x, weights, biases, padding=1)
58-
x = F.batch_norm(x, running_mean=None, running_var=None, weight=bn_weights, bias=bn_biases, training=True)
59-
x = F.relu(x)
67+
if activation == 'relu':
68+
x = F.batch_norm(x, running_mean=None, running_var=None, weight=bn_weights, bias=bn_biases, training=True)
69+
x = F.relu(x)
70+
elif activation == 'selu':
71+
x = F.selu(x)
72+
else:
73+
raise ValueError('Unsupported activation.')
6074
x = F.max_pool2d(x, kernel_size=2, stride=2)
6175
return x
6276

6377

6478
##########
6579
# Models #
6680
##########
67-
def get_few_shot_encoder(num_input_channels=1):
81+
def get_few_shot_encoder(num_input_channels=1, activation: str = 'relu'):
6882
"""Creates a few shot encoder as used in Matching and Prototypical Networks
6983
7084
# Arguments:
7185
num_input_channels: Number of color channels the model expects input data to contain. Omniglot = 1,
7286
miniImageNet = 3
87+
activation: Whether to use ReLu activation + batchnorm or SELU on its own
7388
"""
7489
return nn.Sequential(
75-
conv_block(num_input_channels, 64),
76-
conv_block(64, 64),
77-
conv_block(64, 64),
78-
conv_block(64, 64),
90+
conv_block(num_input_channels, 64, activation),
91+
conv_block(64, 64, activation),
92+
conv_block(64, 64, activation),
93+
conv_block(64, 64, activation),
7994
Flatten(),
8095
)
8196

8297

8398
class FewShotClassifier(nn.Module):
84-
def __init__(self, num_input_channels: int, k_way: int, final_layer_size: int = 64):
99+
def __init__(self, num_input_channels: int, k_way: int, final_layer_size: int = 64, activation: str = 'relu'):
85100
"""Creates a few shot classifier as used in MAML.
86101
87102
This network should be identical to the one created by `get_few_shot_encoder` but with a
88103
clasification layer on top.
89104
90-
# Arguments:
91-
num_input_channels: Number of color channels the model expects input data to contain. Omniglot = 1,
92-
miniImageNet = 3
93-
k_way: Number of classes the model will discriminate between
94-
final_layer_size: 64 for Omniglot, 1600 for miniImageNet
105+
# Arguments:
106+
num_input_channels: Number of color channels the model expects input data to contain. Omniglot = 1,
107+
miniImageNet = 3
108+
k_way: Number of classes the model will discriminate between
109+
final_layer_size: 64 for Omniglot, 1600 for miniImageNet
110+
activation: Whether to use ReLu activation + batchnorm or SELU on its own
95111
"""
96112
super(FewShotClassifier, self).__init__()
97-
self.conv1 = conv_block(num_input_channels, 64)
98-
self.conv2 = conv_block(64, 64)
99-
self.conv3 = conv_block(64, 64)
100-
self.conv4 = conv_block(64, 64)
113+
self.conv1 = conv_block(num_input_channels, 64, activation)
114+
self.conv2 = conv_block(64, 64, activation)
115+
self.conv3 = conv_block(64, 64, activation)
116+
self.conv4 = conv_block(64, 64, activation)
101117

102118
self.logits = nn.Linear(final_layer_size, k_way)
103119

@@ -116,7 +132,7 @@ def functional_forward(self, x, weights):
116132

117133
for block in [1, 2, 3, 4]:
118134
x = functional_conv_block(x, weights[f'conv{block}.0.weight'], weights[f'conv{block}.0.bias'],
119-
weights[f'conv{block}.1.weight'], weights[f'conv{block}.1.bias'])
135+
weights.get(f'conv{block}.1.weight'), weights.get(f'conv{block}.1.bias'))
120136

121137
x = x.view(x.size(0), -1)
122138

0 commit comments

Comments
 (0)
Please sign in to comment.