Add a MNIST CNN example
koz4k committed Oct 28, 2017
1 parent ba0f70b commit 1039f44
Showing 4 changed files with 228 additions and 0 deletions.
29 changes: 29 additions & 0 deletions examples/mnist-cnn/LICENSE
@@ -0,0 +1,29 @@
BSD 3-Clause License

Copyright (c) 2017,
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.

* Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 changes: 29 additions & 0 deletions examples/mnist-cnn/README.rst
@@ -0,0 +1,29 @@
MNIST with CNN
--------------

This example illustrates how to implement a custom ``Synthesizer``.
The code is mostly copied from the official PyTorch MNIST example:
https://github.com/pytorch/examples/blob/master/mnist/main.py

The classification model is the same CNN as in the original example, with batch
normalization added after every layer and a DNI backward interface inserted
between the last convolutional layer and the first fully-connected layer
(before the activation).
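
A minimal sketch of that insertion point (``TinyNet`` is a hypothetical,
stripped-down version of the ``Net`` class in ``main.py`` below, without batch
normalization and dropout)::

    import torch.nn as nn
    import torch.nn.functional as F
    import dni

    class TinyNet(nn.Module):
        def __init__(self, synthesizer):
            super(TinyNet, self).__init__()
            self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
            self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
            # wraps the synthesizer; the layers below it receive
            # synthetic gradients during training
            self.backward_interface = dni.BackwardInterface(synthesizer)
            self.fc1 = nn.Linear(320, 50)
            self.fc2 = nn.Linear(50, 10)

        def forward(self, x, context=None):
            x = F.relu(F.max_pool2d(self.conv1(x), 2))
            x = F.max_pool2d(self.conv2(x), 2)
            if self.training:
                # context is a one-hot label batch under cDNI, or None
                with dni.synthesizer_context(context):
                    x = self.backward_interface(x)
            x = F.relu(x)
            x = x.view(-1, 320)
            x = F.relu(self.fc1(x))
            return F.log_softmax(self.fc2(x))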

The synthesizer is a CNN with three convolutional layers, padded so that the
sizes of the feature maps are kept constant, with ReLU activations.
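
For reference, the contract a synthesizer satisfies here, as a minimal sketch
distilled from ``main.py`` below (``TinySynthesizer`` is a hypothetical name
used only for illustration): it is an ``nn.Module`` whose
``forward(trigger, context)`` returns a tensor with the same shape as the
trigger, and it is attached to the model through ``dni.BackwardInterface``::

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torch.autograd import Variable
    import dni

    class TinySynthesizer(nn.Module):
        def __init__(self):
            super(TinySynthesizer, self).__init__()
            # padding=2 with kernel_size=5 keeps the feature map size unchanged
            self.hidden = nn.Conv2d(20, 20, kernel_size=5, padding=2)
            self.output = nn.Conv2d(20, 20, kernel_size=5, padding=2)
            # zero-initialize the last layer, as in the DNI paper
            nn.init.constant(self.output.weight, 0)

        def forward(self, trigger, context):
            # context is a one-hot label batch under cDNI, or None (ignored here)
            return self.output(F.relu(self.hidden(trigger)))

    # shape check: the synthetic gradient matches the trigger's shape
    trigger = Variable(torch.randn(4, 20, 4, 4))
    assert TinySynthesizer()(trigger, None).size() == trigger.size()

    # attached to the model as in main.py
    backward_interface = dni.BackwardInterface(TinySynthesizer())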

To install requirements::

    $ pip install -r requirements.txt

To train with regular backpropagation::

    $ python main.py

To train with DNI (no label conditioning)::

    $ python main.py --dni

To train with cDNI (label conditioning)::

    $ python main.py --dni --context
167 changes: 167 additions & 0 deletions examples/mnist-cnn/main.py
@@ -0,0 +1,167 @@
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
import dni

# Training settings
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                    help='input batch size for training (default: 64)')
parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                    help='input batch size for testing (default: 1000)')
parser.add_argument('--epochs', type=int, default=10, metavar='N',
                    help='number of epochs to train (default: 10)')
parser.add_argument('--lr', type=float, default=0.001, metavar='LR',
                    help='learning rate (default: 0.001)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                    help='how many batches to wait before logging training status')
parser.add_argument('--dni', action='store_true', default=False,
                    help='enable DNI')
parser.add_argument('--context', action='store_true', default=False,
                    help='enable context (label conditioning) in DNI')
args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()

torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)


kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=args.batch_size, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])),
    batch_size=args.test_batch_size, shuffle=True, **kwargs)


def one_hot(indexes, n_classes):
    # build a one-hot encoding of the label batch, used as the cDNI context
    result = torch.FloatTensor(indexes.size() + (n_classes,))
    if args.cuda:
        result = result.cuda()
    result.zero_()
    indexes_rank = len(indexes.size())
    result.scatter_(
        dim=indexes_rank,
        index=indexes.data.unsqueeze(dim=indexes_rank),
        value=1
    )
    return Variable(result)


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv1_bn = nn.BatchNorm2d(10)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_bn = nn.BatchNorm2d(20)
        self.conv2_drop = nn.Dropout2d()
        if args.dni:
            # DNI sits between the convolutional and fully-connected parts
            self.backward_interface = dni.BackwardInterface(ConvSynthesizer())
        self.fc1 = nn.Linear(320, 50)
        self.fc1_bn = nn.BatchNorm1d(50)
        self.fc2 = nn.Linear(50, 10)
        self.fc2_bn = nn.BatchNorm1d(10)

    def forward(self, x, y=None):
        x = F.relu(F.max_pool2d(self.conv1_bn(self.conv1(x)), 2))
        x = F.max_pool2d(self.conv2_drop(self.conv2_bn(self.conv2(x))), 2)
        if args.dni and self.training:
            if args.context:
                context = one_hot(y, 10)
            else:
                context = None
            with dni.synthesizer_context(context):
                # the layers below receive a synthetic gradient from this point
                x = self.backward_interface(x)
        x = F.relu(x)
        x = x.view(-1, 320)
        x = F.relu(self.fc1_bn(self.fc1(x)))
        x = F.dropout(x, training=self.training)
        x = self.fc2_bn(self.fc2(x))
        return F.log_softmax(x)


class ConvSynthesizer(nn.Module):
    def __init__(self):
        super(ConvSynthesizer, self).__init__()
        self.input_trigger = nn.Conv2d(20, 20, kernel_size=5, padding=2)
        self.input_context = nn.Linear(10, 20)
        self.hidden = nn.Conv2d(20, 20, kernel_size=5, padding=2)
        self.output = nn.Conv2d(20, 20, kernel_size=5, padding=2)
        # zero-initialize the last layer, as in the paper
        nn.init.constant(self.output.weight, 0)

    def forward(self, trigger, context):
        x = self.input_trigger(trigger)
        if context is not None:
            # add the label context as a per-channel bias on the feature maps
            x += (
                self.input_context(context).unsqueeze(2)
                .unsqueeze(3)
                .expand_as(x)
            )
        x = self.hidden(F.relu(x))
        return self.output(F.relu(x))


model = Net()
if args.cuda:
    model.cuda()

optimizer = optim.Adam(model.parameters(), lr=args.lr)

def train(epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        if args.cuda:
            data, target = data.cuda(), target.cuda()
        data, target = Variable(data), Variable(target)
        optimizer.zero_grad()
        # labels are passed to the model so that cDNI can condition on them
        output = model(data, target)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.data[0]))

def test():
    model.eval()
    test_loss = 0
    correct = 0
    for data, target in test_loader:
        if args.cuda:
            data, target = data.cuda(), target.cuda()
        data, target = Variable(data, volatile=True), Variable(target)
        output = model(data)
        test_loss += F.nll_loss(output, target, size_average=False).data[0]  # sum up batch loss
        pred = output.data.max(1, keepdim=True)[1]  # get the index of the max log-probability
        correct += pred.eq(target.data.view_as(pred)).cpu().sum()

    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))


for epoch in range(1, args.epochs + 1):
    train(epoch)
    test()
3 changes: 3 additions & 0 deletions examples/mnist-cnn/requirements.txt
@@ -0,0 +1,3 @@
git+git://github.com/koz4k/dni-pytorch.git#egg=dni-pytorch
torch
torchvision
