Skip to content

Commit

Permalink
Add .features and .classifier to vision models. (learnables#172)
Browse files Browse the repository at this point in the history
* Add .features and .classifier to vision models.

* Add docs, fix minor bugs.
  • Loading branch information
seba-1511 authored Aug 26, 2020
1 parent 5c4e91f commit 0c9729b
Showing 1 changed file with 95 additions and 59 deletions.
154 changes: 95 additions & 59 deletions learn2learn/vision/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,20 @@
**Description**
A set of commonly used models for meta-learning vision tasks.
For simplicity, all models' `forward` conform to the following API:
~~~python
def forward(self, x):
x = self.features(x)
x = self.classifier(x)
return x
~~~
"""

import torch
import learn2learn as l2l

from scipy.stats import truncnorm
from torch import nn


def truncated_normal_(tensor, mean=0.0, std=1.0):
Expand All @@ -24,28 +33,29 @@ def fc_init_(module):
if hasattr(module, 'weight') and module.weight is not None:
truncated_normal_(module.weight.data, mean=0.0, std=0.01)
if hasattr(module, 'bias') and module.bias is not None:
nn.init.constant_(module.bias.data, 0.0)
torch.nn.init.constant_(module.bias.data, 0.0)
return module


def maml_init_(module):
nn.init.xavier_uniform_(module.weight.data, gain=1.0)
nn.init.constant_(module.bias.data, 0.0)
torch.nn.init.xavier_uniform_(module.weight.data, gain=1.0)
torch.nn.init.constant_(module.bias.data, 0.0)
return module


class LinearBlock(nn.Module):
class LinearBlock(torch.nn.Module):

def __init__(self, input_size, output_size):
super(LinearBlock, self).__init__()
self.relu = nn.ReLU()
self.normalize = nn.BatchNorm1d(output_size,
affine=True,
momentum=0.999,
eps=1e-3,
track_running_stats=False,
)
self.linear = nn.Linear(input_size, output_size)
self.relu = torch.nn.ReLU()
self.normalize = torch.nn.BatchNorm1d(
output_size,
affine=True,
momentum=0.999,
eps=1e-3,
track_running_stats=False,
)
self.linear = torch.nn.Linear(input_size, output_size)
fc_init_(self.linear)

def forward(self, x):
Expand All @@ -55,7 +65,7 @@ def forward(self, x):
return x


class ConvBlock(nn.Module):
class ConvBlock(torch.nn.Module):

def __init__(self,
in_channels,
Expand All @@ -66,28 +76,32 @@ def __init__(self,
super(ConvBlock, self).__init__()
stride = (int(2 * max_pool_factor), int(2 * max_pool_factor))
if max_pool:
self.max_pool = nn.MaxPool2d(kernel_size=stride,
stride=stride,
ceil_mode=False,
)
self.max_pool = torch.nn.MaxPool2d(
kernel_size=stride,
stride=stride,
ceil_mode=False,
)
stride = (1, 1)
else:
self.max_pool = lambda x: x
self.normalize = nn.BatchNorm2d(out_channels,
affine=True,
# eps=1e-3,
# momentum=0.999,
# track_running_stats=False,
)
nn.init.uniform_(self.normalize.weight)
self.relu = nn.ReLU()

self.conv = nn.Conv2d(in_channels,
out_channels,
kernel_size,
stride=stride,
padding=1,
bias=True)
self.normalize = torch.nn.BatchNorm2d(
out_channels,
affine=True,
# eps=1e-3,
# momentum=0.999,
# track_running_stats=False,
)
torch.nn.init.uniform_(self.normalize.weight)
self.relu = torch.nn.ReLU()

self.conv = torch.nn.Conv2d(
in_channels,
out_channels,
kernel_size,
stride=stride,
padding=1,
bias=True,
)
maml_init_(self.conv)

def forward(self, x):
Expand All @@ -98,7 +112,7 @@ def forward(self, x):
return x


class ConvBase(nn.Sequential):
class ConvBase(torch.nn.Sequential):

# NOTE:
# Omniglot: hidden=64, channels=1, no max_pool
Expand Down Expand Up @@ -126,10 +140,10 @@ def __init__(self,
super(ConvBase, self).__init__(*core)


class OmniglotFC(nn.Sequential):
class OmniglotFC(torch.nn.Module):
"""
[[Source]]()
[[Source]](https://github.com/learnables/learn2learn/blob/master/learn2learn/vision/models.py)
**Description**
Expand All @@ -155,23 +169,30 @@ class OmniglotFC(nn.Sequential):
"""

def __init__(self, input_size, output_size, sizes=None):
super(OmniglotFC, self).__init__()
if sizes is None:
sizes = [256, 128, 64, 64]
layers = [LinearBlock(input_size, sizes[0]), ]
for s_i, s_o in zip(sizes[:-1], sizes[1:]):
layers.append(LinearBlock(s_i, s_o))
layers.append(fc_init_(nn.Linear(sizes[-1], output_size)))
super(OmniglotFC, self).__init__(*layers)
layers = torch.nn.Sequential(*layers)
self.features = torch.nn.Sequential(
l2l.nn.Flatten(),
layers,
)
self.classifier = fc_init_(torch.nn.Linear(sizes[-1], output_size))
self.input_size = input_size

def forward(self, x):
return super(OmniglotFC, self).forward(x.view(-1, self.input_size))
x = self.features(x)
x = self.classifier(x)
return x


class OmniglotCNN(nn.Module):
class OmniglotCNN(torch.nn.Module):
"""
[Source]()
[Source](https://github.com/learnables/learn2learn/blob/master/learn2learn/vision/models.py)
**Description**
Expand Down Expand Up @@ -204,21 +225,26 @@ def __init__(self, output_size=5, hidden_size=64, layers=4):
channels=1,
max_pool=False,
layers=layers)
self.linear = nn.Linear(hidden_size, output_size, bias=True)
self.linear.weight.data.normal_()
self.linear.bias.data.mul_(0.0)
self.features = torch.nn.Sequential(
l2l.nn.Lambda(lambda x: x.view(-1, 1, 28, 28)),
self.base,
l2l.nn.Lambda(lambda x: x.mean(dim=[2, 3])),
l2l.nn.Flatten(),
)
self.classifier = torch.nn.Linear(hidden_size, output_size, bias=True)
self.classifier.weight.data.normal_()
self.classifier.bias.data.mul_(0.0)

def forward(self, x):
x = self.base(x.view(-1, 1, 28, 28))
x = x.mean(dim=[2, 3])
x = self.linear(x)
x = self.features(x)
x = self.classifier(x)
return x


class MiniImagenetCNN(nn.Module):
class MiniImagenetCNN(torch.nn.Module):
"""
[[Source]]()
[[Source]](https://github.com/learnables/learn2learn/blob/master/learn2learn/vision/models.py)
**Description**
Expand All @@ -244,17 +270,27 @@ class MiniImagenetCNN(nn.Module):

def __init__(self, output_size, hidden_size=32, layers=4):
super(MiniImagenetCNN, self).__init__()
self.base = ConvBase(output_size=hidden_size,
hidden=hidden_size,
channels=3,
max_pool=True,
layers=layers,
max_pool_factor=4 // layers)
self.linear = nn.Linear(25 * hidden_size, output_size, bias=True)
maml_init_(self.linear)
base = ConvBase(
output_size=hidden_size,
hidden=hidden_size,
channels=3,
max_pool=True,
layers=layers,
max_pool_factor=4 // layers,
)
self.features = torch.nn.Sequential(
base,
l2l.nn.Flatten(),
)
self.classifier = torch.nn.Linear(
25 * hidden_size,
output_size,
bias=True,
)
maml_init_(self.classifier)
self.hidden_size = hidden_size

def forward(self, x):
x = self.base(x)
x = self.linear(x.view(-1, 25 * self.hidden_size))
x = self.features(x)
x = self.classifier(x)
return x

0 comments on commit 0c9729b

Please sign in to comment.