From 3529efd06e18e3a06c485c12de0c8a780465d6f6 Mon Sep 17 00:00:00 2001 From: amallya2 Date: Tue, 20 Feb 2018 19:28:55 -0600 Subject: [PATCH] Initial commit --- checkpoints/.keep | 0 data/.keep | 0 src/best_models/densenet121_binary.txt | 14 + src/dataset.py | 145 ++++++++ src/main.py | 447 +++++++++++++++++++++++++ src/modnets/__init__.py | 3 + src/modnets/densenet.py | 129 +++++++ src/modnets/layers.py | 229 +++++++++++++ src/modnets/resnet.py | 168 ++++++++++ src/modnets/vgg.py | 76 +++++ src/networks.py | 330 ++++++++++++++++++ src/pack.py | 187 +++++++++++ src/scripts/run_baselines.sh | 53 +++ src/scripts/run_packing.sh | 12 + src/scripts/run_piggyback_training.sh | 50 +++ src/utils.py | 24 ++ 16 files changed, 1867 insertions(+) create mode 100644 checkpoints/.keep create mode 100644 data/.keep create mode 100644 src/best_models/densenet121_binary.txt create mode 100644 src/dataset.py create mode 100644 src/main.py create mode 100644 src/modnets/__init__.py create mode 100644 src/modnets/densenet.py create mode 100644 src/modnets/layers.py create mode 100644 src/modnets/resnet.py create mode 100644 src/modnets/vgg.py create mode 100644 src/networks.py create mode 100644 src/pack.py create mode 100755 src/scripts/run_baselines.sh create mode 100755 src/scripts/run_packing.sh create mode 100755 src/scripts/run_piggyback_training.sh create mode 100644 src/utils.py diff --git a/checkpoints/.keep b/checkpoints/.keep new file mode 100644 index 0000000..e69de29 diff --git a/data/.keep b/data/.keep new file mode 100644 index 0000000..e69de29 diff --git a/src/best_models/densenet121_binary.txt b/src/best_models/densenet121_binary.txt new file mode 100644 index 0000000..235913e --- /dev/null +++ b/src/best_models/densenet121_binary.txt @@ -0,0 +1,14 @@ +################################################################################ +# DenseNet-121 Errors: +################################################################################ +# CUBS: 19.24 +# Stanford Cars: 
10.62 +# Flowers: 4.91 +# Wikiart: 29.33 +# Sketch: 20.05 +################################################################################ +cubs_cropped: ../checkpoints/cubs_cropped/final/densenet121_binarizer_maskscale1e-2-none_lr1e-4-1e-4_decay15-15_2.pt.pt +stanford_cars_cropped: ../checkpoints/stanford_cars_cropped/final/densenet121_binarizer_maskscale1e-2-none_lr1e-4-1e-4_decay15-15_2.pt.pt +flowers: ../checkpoints/flowers/final/densenet121_binarizer_maskscale1e-2-none_lr1e-4-1e-4_decay15-15_3.pt.pt +wikiart: ../checkpoints/wikiart/final/densenet121_binarizer_maskscale1e-2-none_lr1e-4-1e-4_decay15-15_1.pt.pt +sketches: ../checkpoints/sketches/final/densenet121_binarizer_maskscale1e-2-none_lr1e-4-1e-4_decay15-15_1.pt.pt diff --git a/src/dataset.py b/src/dataset.py new file mode 100644 index 0000000..7439ed0 --- /dev/null +++ b/src/dataset.py @@ -0,0 +1,145 @@ +import collections +import glob +import os + +import numpy as np +from PIL import Image + +import torch +import torch.backends.cudnn as cudnn +import torch.nn as nn +import torch.nn.parallel +import torch.optim as optim +import torch.utils.data as data +import torchvision.datasets as datasets +import torchvision.models as models +import torchvision.transforms as transforms + + +def train_loader(path, batch_size, num_workers=4, pin_memory=False, normalize=None): + if normalize is None: + normalize = transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + return data.DataLoader( + datasets.ImageFolder(path, + transforms.Compose([ + transforms.Scale(256), + transforms.RandomSizedCrop(224), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + normalize, + ])), + batch_size=batch_size, + shuffle=True, + num_workers=num_workers, + pin_memory=pin_memory) + + +def test_loader(path, batch_size, num_workers=4, pin_memory=False, normalize=None): + if normalize is None: + normalize = transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + return 
data.DataLoader( + datasets.ImageFolder(path, + transforms.Compose([ + transforms.Scale(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + normalize, + ])), + batch_size=batch_size, + shuffle=False, + num_workers=num_workers, + pin_memory=pin_memory) + + +def test_loader_caffe(path, batch_size, num_workers=4, pin_memory=False): + """Legacy loader for caffe. Used with models loaded from caffe.""" + # Returns images in 256 x 256 to subtract given mean of same size. + return data.DataLoader( + datasets.ImageFolder(path, + transforms.Compose([ + Scale((256, 256)), + # transforms.CenterCrop(224), + transforms.ToTensor(), + ])), + batch_size=batch_size, + shuffle=False, + num_workers=num_workers, + pin_memory=pin_memory) + + +def train_loader_cropped(path, batch_size, num_workers=4, pin_memory=False): + normalize = transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + return data.DataLoader( + datasets.ImageFolder(path, + transforms.Compose([ + Scale((224, 224)), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + normalize, + ])), + batch_size=batch_size, + shuffle=True, + num_workers=num_workers, + pin_memory=pin_memory) + + +def test_loader_cropped(path, batch_size, num_workers=4, pin_memory=False): + normalize = transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + return data.DataLoader( + datasets.ImageFolder(path, + transforms.Compose([ + Scale((224, 224)), + transforms.ToTensor(), + normalize, + ])), + batch_size=batch_size, + shuffle=False, + num_workers=num_workers, + pin_memory=pin_memory) + + +# Note: This might not be needed anymore given that this functionality exists in +# the newer PyTorch versions. +class Scale(object): + """Rescale the input PIL.Image to the given size. + Args: + size (sequence or int): Desired output size. If size is a sequence like + (w, h), output size will be matched to this. If size is an int, + smaller edge of the image will be matched to this number. 
+ i.e, if height > width, then image will be rescaled to + (size * height / width, size) + interpolation (int, optional): Desired interpolation. Default is + ``PIL.Image.BILINEAR`` + """ + + def __init__(self, size, interpolation=Image.BILINEAR): + assert isinstance(size, int) or (isinstance( + size, collections.Iterable) and len(size) == 2) + self.size = size + self.interpolation = interpolation + + def __call__(self, img): + """ + Args: + img (PIL.Image): Image to be scaled. + Returns: + PIL.Image: Rescaled image. + """ + if isinstance(self.size, int): + w, h = img.size + if (w <= h and w == self.size) or (h <= w and h == self.size): + return img + if w < h: + ow = self.size + oh = int(self.size * h / w) + return img.resize((ow, oh), self.interpolation) + else: + oh = self.size + ow = int(self.size * w / h) + return img.resize((ow, oh), self.interpolation) + else: + return img.resize(self.size, self.interpolation) diff --git a/src/main.py b/src/main.py new file mode 100644 index 0000000..57bdb9b --- /dev/null +++ b/src/main.py @@ -0,0 +1,447 @@ +"""Main entry point for doing all stuff.""" +from __future__ import division, print_function + +import argparse +import json +import warnings + +import torch +import torch.nn as nn +import torch.optim as optim +import torchnet as tnt +from torch.autograd import Variable +from tqdm import tqdm + +import dataset +import networks as net +import utils as utils + + +# To prevent PIL warnings. 
+warnings.filterwarnings("ignore") + +FLAGS = argparse.ArgumentParser() +FLAGS.add_argument('--arch', + choices=['vgg16', 'vgg16bn', 'resnet50', + 'densenet121', 'resnet50_diff'], + help='Architectures') +FLAGS.add_argument('--source', type=str, default='', + help='Location of the init file for resnet50_diff') +FLAGS.add_argument('--finetune_layers', + choices=['all', 'fc', 'classifier'], default='all', + help='Which layers to finetune, fc only works with vgg') +FLAGS.add_argument('--mode', + choices=['finetune', 'eval', 'check'], + help='Run mode') +FLAGS.add_argument('--num_outputs', type=int, default=-1, + help='Num outputs for dataset') +# Optimization options. +FLAGS.add_argument('--lr', type=float, + help='Learning rate for parameters, used for baselines') +FLAGS.add_argument('--lr_decay_every', type=int, + help='Step decay every this many epochs') +FLAGS.add_argument('--lr_mask', type=float, + help='Learning rate for mask') +FLAGS.add_argument('--lr_mask_decay_every', type=int, + help='Step decay every this many epochs') +FLAGS.add_argument('--mask_adam', action='store_true', default=False, + help='Use adam instead of sgdm for masks') +FLAGS.add_argument('--lr_classifier', type=float, + help='Learning rate for classifier') +FLAGS.add_argument('--lr_classifier_decay_every', type=int, + help='Step decay every this many epochs') + +FLAGS.add_argument('--lr_decay_factor', type=float, + help='Multiply lr by this much every step of decay') +FLAGS.add_argument('--finetune_epochs', type=int, + help='Number of initial finetuning epochs') +FLAGS.add_argument('--batch_size', type=int, default=32, + help='Batch size') +FLAGS.add_argument('--weight_decay', type=float, default=0.0, + help='Weight decay') +FLAGS.add_argument('--train_bn', action='store_true', default=False, + help='train batch norm or not') +# Masking options. 
+FLAGS.add_argument('--mask_init', default='1s', + choices=['1s', 'uniform', 'weight_based_1s'], + help='Type of mask init') +FLAGS.add_argument('--mask_scale', type=float, default=1e-2, + help='Mask initialization scaling') +FLAGS.add_argument('--mask_scale_gradients', type=str, default='none', + choices=['none', 'average', 'individual'], + help='Scale mask gradients by weights') +FLAGS.add_argument('--threshold_fn', + choices=['binarizer', 'ternarizer'], + help='Type of thresholding function') +# Paths. +FLAGS.add_argument('--dataset', type=str, default='', + help='Name of dataset') +FLAGS.add_argument('--train_path', type=str, default='', + help='Location of train data') +FLAGS.add_argument('--test_path', type=str, default='', + help='Location of test data') +FLAGS.add_argument('--save_prefix', type=str, default='../checkpoints/', + help='Location to save model') +FLAGS.add_argument('--loadname', type=str, default='', + help='Location to save model') +# Other. +FLAGS.add_argument('--cuda', action='store_true', default=True, + help='use CUDA') +FLAGS.add_argument('--no_mask', action='store_true', default=False, + help='Used for running baselines, does not use any masking') + + +class Manager(object): + """Handles training and pruning.""" + + def __init__(self, args, model): + self.args = args + self.cuda = args.cuda + self.model = model + + # Set up data loader, criterion, and pruner. 
+ if 'cropped' in args.train_path: + train_loader = dataset.train_loader_cropped + test_loader = dataset.test_loader_cropped + else: + train_loader = dataset.train_loader + test_loader = dataset.test_loader + self.train_data_loader = train_loader( + args.train_path, args.batch_size, pin_memory=args.cuda) + self.test_data_loader = test_loader( + args.test_path, args.batch_size, pin_memory=args.cuda) + self.criterion = nn.CrossEntropyLoss() + + def eval(self): + """Performs evaluation.""" + self.model.eval() + error_meter = None + + print('Performing eval...') + for batch, label in tqdm(self.test_data_loader, desc='Eval'): + if self.cuda: + batch = batch.cuda() + batch = Variable(batch, volatile=True) + + output = self.model(batch) + + # Init error meter. + if error_meter is None: + topk = [1] + if output.size(1) > 5: + topk.append(5) + error_meter = tnt.meter.ClassErrorMeter(topk=topk) + error_meter.add(output.data, label) + + errors = error_meter.value() + print('Error: ' + ', '.join('@%s=%.2f' % + t for t in zip(topk, errors))) + + if 'train_bn' in self.args: + if self.args.train_bn: + self.model.train() + else: + self.model.train_nobn() + else: + print('args does not have train_bn flag, probably in eval-only mode.') + return errors + + def do_batch(self, optimizer, batch, label): + """Runs model for one batch.""" + if self.cuda: + batch = batch.cuda() + label = label.cuda() + batch = Variable(batch) + label = Variable(label) + + # Set grads to 0. + self.model.zero_grad() + + # Do forward-backward. + output = self.model(batch) + self.criterion(output, label).backward() + + # Scale gradients by average weight magnitude. 
+ if self.args.mask_scale_gradients != 'none': + for module in self.model.shared.modules(): + if 'ElementWise' in str(type(module)): + abs_weights = module.weight.data.abs() + if self.args.mask_scale_gradients == 'average': + module.mask_real.grad.data.div_(abs_weights.mean()) + elif self.args.mask_scale_gradients == 'individual': + module.mask_real.grad.data.div_(abs_weights) + + # Set batchnorm grads to 0, if required. + if not self.args.train_bn: + for module in self.model.shared.modules(): + if 'BatchNorm' in str(type(module)): + if module.weight.grad is not None: + module.weight.grad.data.fill_(0) + if module.bias.grad is not None: + module.bias.grad.data.fill_(0) + + # Update params. + optimizer.step() + + def do_epoch(self, epoch_idx, optimizer): + """Trains model for one epoch.""" + for batch, label in tqdm(self.train_data_loader, desc='Epoch: %d ' % (epoch_idx)): + self.do_batch(optimizer, batch, label) + + if self.args.threshold_fn == 'binarizer': + print('Num 0ed out parameters:') + for idx, module in enumerate(self.model.shared.modules()): + if 'ElementWise' in str(type(module)): + num_zero = module.mask_real.data.lt(5e-3).sum() + total = module.mask_real.data.numel() + print(idx, num_zero, total) + elif self.args.threshold_fn == 'ternarizer': + print('Num -1, 0ed out parameters:') + for idx, module in enumerate(self.model.shared.modules()): + if 'ElementWise' in str(type(module)): + num_neg = module.mask_real.data.lt(0).sum() + num_zero = module.mask_real.data.lt(5e-3).sum() - num_neg + total = module.mask_real.data.numel() + print(idx, num_neg, num_zero, total) + print('-' * 20) + + def save_model(self, epoch, best_accuracy, errors, savename): + """Saves model to file.""" + # Prepare the ckpt. + ckpt = { + 'args': self.args, + 'epoch': epoch, + 'accuracy': best_accuracy, + 'errors': errors, + 'model': self.model, + } + + # Save to file. 
+ torch.save(ckpt, savename) + + def train(self, epochs, optimizer, save=True, savename='', best_accuracy=0): + """Performs training.""" + best_accuracy = best_accuracy + error_history = [] + + if self.args.cuda: + self.model = self.model.cuda() + + self.eval() + + for idx in range(epochs): + epoch_idx = idx + 1 + print('Epoch: %d' % (epoch_idx)) + + optimizer.update_lr(epoch_idx) + if self.args.train_bn: + self.model.train() + else: + self.model.train_nobn() + self.do_epoch(epoch_idx, optimizer) + errors = self.eval() + error_history.append(errors) + accuracy = 100 - errors[0] # Top-1 accuracy. + + # Save performance history and stats. + with open(savename + '.json', 'w') as fout: + json.dump({ + 'error_history': error_history, + 'args': vars(self.args), + }, fout) + + # Save best model, if required. + if save and accuracy > best_accuracy: + print('Best model so far, Accuracy: %0.2f%% -> %0.2f%%' % + (best_accuracy, accuracy)) + best_accuracy = accuracy + self.save_model(epoch_idx, best_accuracy, errors, savename) + + # Make sure masking didn't change any weights. + if not self.args.no_mask: + self.check() + print('Finished finetuning...') + print('Best error/accuracy: %0.2f%%, %0.2f%%' % + (100 - best_accuracy, best_accuracy)) + print('-' * 16) + + def check(self): + """Makes sure that the trained model weights match those of the pretrained model.""" + print('Making sure filter weights have not changed.') + if self.args.arch == 'vgg16': + pretrained = net.ModifiedVGG16(original=True) + elif self.args.arch == 'vgg16bn': + pretrained = net.ModifiedVGG16BN(original=True) + elif self.args.arch == 'resnet50': + pretrained = net.ModifiedResNet(original=True) + elif self.args.arch == 'densenet121': + pretrained = net.ModifiedDenseNet(original=True) + elif self.args.arch == 'resnet50_diff': + pretrained = net.ResNetDiffInit(self.args.source, original=True) + else: + raise ValueError('Architecture %s not supported.' 
% + (self.args.arch)) + + for module, module_pretrained in zip(self.model.shared.modules(), pretrained.shared.modules()): + if 'ElementWise' in str(type(module)) or 'BatchNorm' in str(type(module)): + weight = module.weight.data.cpu() + weight_pretrained = module_pretrained.weight.data.cpu() + # Using small threshold of 1e-8 for any floating point inconsistencies. + # Note that threshold per element is even smaller as the 1e-8 threshold + # is for sum of absolute differences. + assert (weight - weight_pretrained).abs().sum() < 1e-8, \ + 'module %s failed check' % (module) + if module.bias is not None: + bias = module.bias.data.cpu() + bias_pretrained = module_pretrained.bias.data.cpu() + assert (bias - bias_pretrained).abs().sum() < 1e-8 + if 'BatchNorm' in str(type(module)): + rm = module.running_mean.cpu() + rm_pretrained = module_pretrained.running_mean.cpu() + assert (rm - rm_pretrained).abs().sum() < 1e-8 + rv = module.running_var.cpu() + rv_pretrained = module_pretrained.running_var.cpu() + assert (rv - rv_pretrained).abs().sum() < 1e-8 + print('Passed checks...') + + +class Optimizers(object): + """Handles a list of optimizers.""" + + def __init__(self, args): + self.optimizers = [] + self.lrs = [] + self.decay_every = [] + self.args = args + + def add(self, optimizer, lr, decay_every): + """Adds optimizer to list.""" + self.optimizers.append(optimizer) + self.lrs.append(lr) + self.decay_every.append(decay_every) + + def step(self): + """Makes all optimizers update their params.""" + for optimizer in self.optimizers: + optimizer.step() + + def update_lr(self, epoch_idx): + """Update learning rate of every optimizer.""" + for optimizer, init_lr, decay_every in zip(self.optimizers, self.lrs, self.decay_every): + optimizer = utils.step_lr( + epoch_idx, init_lr, decay_every, + self.args.lr_decay_factor, optimizer) + + +def main(): + """Do stuff.""" + args = FLAGS.parse_args() + + # Set default train and test path if not provided as input. 
+ utils.set_dataset_paths(args) + + # Load the required model. + if args.arch == 'vgg16': + model = net.ModifiedVGG16(mask_init=args.mask_init, + mask_scale=args.mask_scale, + threshold_fn=args.threshold_fn, + original=args.no_mask) + elif args.arch == 'vgg16bn': + model = net.ModifiedVGG16BN(mask_init=args.mask_init, + mask_scale=args.mask_scale, + threshold_fn=args.threshold_fn, + original=args.no_mask) + elif args.arch == 'resnet50': + model = net.ModifiedResNet(mask_init=args.mask_init, + mask_scale=args.mask_scale, + threshold_fn=args.threshold_fn, + original=args.no_mask) + elif args.arch == 'densenet121': + model = net.ModifiedDenseNet(mask_init=args.mask_init, + mask_scale=args.mask_scale, + threshold_fn=args.threshold_fn, + original=args.no_mask) + elif args.arch == 'resnet50_diff': + assert args.source + model = net.ResNetDiffInit(args.source, + mask_init=args.mask_init, + mask_scale=args.mask_scale, + threshold_fn=args.threshold_fn, + original=args.no_mask) + else: + raise ValueError('Architecture %s not supported.' % (args.arch)) + + # Add and set the model dataset. + model.add_dataset(args.dataset, args.num_outputs) + model.set_dataset(args.dataset) + if args.cuda: + model = model.cuda() + + # Initialize with weight based method, if necessary. + if not args.no_mask and args.mask_init == 'weight_based_1s': + print('Are you sure you want to try this?') + assert args.mask_scale_gradients == 'none' + assert not args.mask_scale + for idx, module in enumerate(model.shared.modules()): + if 'ElementWise' in str(type(module)): + weight_scale = module.weight.data.abs().mean() + module.mask_real.data.fill_(weight_scale) + + # Create the manager object. + manager = Manager(args, model) + + # Perform necessary mode operations. + if args.mode == 'finetune': + if args.no_mask: + # No masking will be done, used to run baselines of + # Classifier-Only and Individual Networks. + # Checks. 
+ assert args.lr and args.lr_decay_every + assert not args.lr_mask and not args.lr_mask_decay_every + assert not args.lr_classifier and not args.lr_classifier_decay_every + print('No masking, running baselines.') + + # Get optimizer with correct params. + if args.finetune_layers == 'all': + params_to_optimize = model.parameters() + elif args.finetune_layers == 'classifier': + for param in model.shared.parameters(): + param.requires_grad = False + params_to_optimize = model.classifier.parameters() + + # optimizer = optim.Adam(params_to_optimize, lr=args.lr) + optimizer = optim.SGD(params_to_optimize, lr=args.lr, + momentum=0.9, weight_decay=args.weight_decay) + optimizers = Optimizers(args) + optimizers.add(optimizer, args.lr, args.lr_decay_every) + manager.train(args.finetune_epochs, optimizers, + save=True, savename=args.save_prefix) + else: + # Masking will be done. + # Checks. + assert not args.lr and not args.lr_decay_every + assert args.lr_mask and args.lr_mask_decay_every + assert args.lr_classifier and args.lr_classifier_decay_every + print('Performing masking.') + + optimizer_masks = optim.Adam( + model.shared.parameters(), lr=args.lr_mask) + optimizer_classifier = optim.Adam( + model.classifier.parameters(), lr=args.lr_classifier) + + optimizers = Optimizers(args) + optimizers.add(optimizer_masks, args.lr_mask, + args.lr_mask_decay_every) + optimizers.add(optimizer_classifier, args.lr_classifier, + args.lr_classifier_decay_every) + manager.train(args.finetune_epochs, optimizers, + save=True, savename=args.save_prefix) + elif args.mode == 'eval': + # Just run the model on the eval set. 
+ manager.eval() + elif args.mode == 'check': + manager.check() + + +if __name__ == '__main__': + main() diff --git a/src/modnets/__init__.py b/src/modnets/__init__.py new file mode 100644 index 0000000..3eb6b80 --- /dev/null +++ b/src/modnets/__init__.py @@ -0,0 +1,3 @@ +from .vgg import * +from .resnet import * +from .densenet import * diff --git a/src/modnets/densenet.py b/src/modnets/densenet.py new file mode 100644 index 0000000..a6abc93 --- /dev/null +++ b/src/modnets/densenet.py @@ -0,0 +1,129 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.model_zoo as model_zoo +from collections import OrderedDict + +import modnets.layers as nl + +__all__ = ['DenseNet', 'densenet121'] + + +def densenet121(mask_init='1s', mask_scale=1e-2, threshold_fn='binarizer', **kwargs): + r"""Densenet-121 model from + `"Densely Connected Convolutional Networks" `_ + """ + model = DenseNet(mask_init, mask_scale, threshold_fn, + num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16), **kwargs) + return model + + +class _DenseLayer(nn.Sequential): + def __init__(self, num_input_features, growth_rate, bn_size, drop_rate, mask_init, mask_scale, threshold_fn): + super(_DenseLayer, self).__init__() + self.add_module('norm.1', nn.BatchNorm2d(num_input_features)), + self.add_module('relu.1', nn.ReLU(inplace=True)), + self.add_module('conv.1', nl.ElementWiseConv2d(num_input_features, bn_size * + growth_rate, + mask_init=mask_init, mask_scale=mask_scale, + threshold_fn=threshold_fn, + kernel_size=1, stride=1, bias=False)), + self.add_module('norm.2', nn.BatchNorm2d(bn_size * growth_rate)), + self.add_module('relu.2', nn.ReLU(inplace=True)), + self.add_module('conv.2', nl.ElementWiseConv2d(bn_size * growth_rate, growth_rate, + mask_init=mask_init, mask_scale=mask_scale, + threshold_fn=threshold_fn, + kernel_size=3, stride=1, padding=1, bias=False)), + self.drop_rate = drop_rate + + def forward(self, x): + new_features = super(_DenseLayer, 
self).forward(x) + if self.drop_rate > 0: + new_features = F.dropout( + new_features, p=self.drop_rate, training=self.training) + return torch.cat([x, new_features], 1) + + +class _DenseBlock(nn.Sequential): + def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate, mask_init, mask_scale, threshold_fn): + super(_DenseBlock, self).__init__() + for i in range(num_layers): + layer = _DenseLayer(num_input_features + i * + growth_rate, growth_rate, bn_size, drop_rate, + mask_init, mask_scale, threshold_fn) + self.add_module('denselayer%d' % (i + 1), layer) + + +class _Transition(nn.Sequential): + def __init__(self, num_input_features, num_output_features, mask_init, mask_scale, threshold_fn): + super(_Transition, self).__init__() + self.add_module('norm', nn.BatchNorm2d(num_input_features)) + self.add_module('relu', nn.ReLU(inplace=True)) + self.add_module('conv', nl.ElementWiseConv2d(num_input_features, num_output_features, + kernel_size=1, stride=1, bias=False, + mask_init=mask_init, mask_scale=mask_scale, + threshold_fn=threshold_fn)) + self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2)) + + +class DenseNet(nn.Module): + r"""Densenet-BC model class, based on + `"Densely Connected Convolutional Networks" `_ + + Args: + growth_rate (int) - how many filters to add each layer (`k` in paper) + block_config (list of 4 ints) - how many layers in each pooling block + num_init_features (int) - the number of filters to learn in the first convolution layer + bn_size (int) - multiplicative factor for number of bottle neck layers + (i.e. 
bn_size * k features in the bottleneck layer) + drop_rate (float) - dropout rate after each dense layer + num_classes (int) - number of classification classes + """ + + def __init__(self, mask_init, mask_scale, threshold_fn, growth_rate=32, + block_config=(6, 12, 24, 16), num_init_features=64, bn_size=4, + drop_rate=0, num_classes=1000): + + super(DenseNet, self).__init__() + + # First convolution + self.features = nn.Sequential(OrderedDict([ + ('conv0', nl.ElementWiseConv2d(3, num_init_features, + mask_init=mask_init, + mask_scale=mask_scale, + threshold_fn=threshold_fn, + kernel_size=7, + stride=2, padding=3, bias=False)), + ('norm0', nn.BatchNorm2d(num_init_features)), + ('relu0', nn.ReLU(inplace=True)), + ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)), + ])) + + # Each denseblock + num_features = num_init_features + for i, num_layers in enumerate(block_config): + block = _DenseBlock(num_layers=num_layers, num_input_features=num_features, + bn_size=bn_size, growth_rate=growth_rate, drop_rate=drop_rate, + mask_init=mask_init, mask_scale=mask_scale, threshold_fn=threshold_fn) + self.features.add_module('denseblock%d' % (i + 1), block) + num_features = num_features + num_layers * growth_rate + if i != len(block_config) - 1: + trans = _Transition( + num_input_features=num_features, num_output_features=num_features // 2, + mask_init=mask_init, mask_scale=mask_scale, threshold_fn=threshold_fn) + self.features.add_module('transition%d' % (i + 1), trans) + num_features = num_features // 2 + + # Final batch norm + self.features.add_module('norm5', nn.BatchNorm2d(num_features)) + + # Linear layer + self.classifier = nn.Linear(num_features, num_classes) + + def forward(self, x): + features = self.features(x) + out = F.relu(features, inplace=True) + out = F.avg_pool2d(out, kernel_size=7, stride=1).view( + features.size(0), -1) + out = self.classifier(out) + return out diff --git a/src/modnets/layers.py b/src/modnets/layers.py new file mode 100644 index 
0000000..1d966ed --- /dev/null +++ b/src/modnets/layers.py @@ -0,0 +1,229 @@ +"""Contains novel layer definitions.""" +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd import Variable +from torch.nn.modules.utils import _pair +from torch.nn.parameter import Parameter + +DEFAULT_THRESHOLD = 5e-3 + + +class Binarizer(torch.autograd.Function): + """Binarizes {0, 1} a real valued tensor.""" + + def __init__(self, threshold=DEFAULT_THRESHOLD): + super(Binarizer, self).__init__() + self.threshold = threshold + + def forward(self, inputs): + outputs = inputs.clone() + outputs[inputs.le(self.threshold)] = 0 + outputs[inputs.gt(self.threshold)] = 1 + return outputs + + def backward(self, gradOutput): + return gradOutput + + +class Ternarizer(torch.autograd.Function): + """Ternarizes {-1, 0, 1} a real valued tensor.""" + + def __init__(self, threshold=DEFAULT_THRESHOLD): + super(Ternarizer, self).__init__() + self.threshold = threshold + + def forward(self, inputs): + outputs = inputs.clone() + outputs.fill_(0) + outputs[inputs < 0] = -1 + outputs[inputs > self.threshold] = 1 + return outputs + + def backward(self, gradOutput): + return gradOutput + + +class ElementWiseConv2d(nn.Module): + """Modified conv with masks for weights.""" + + def __init__(self, in_channels, out_channels, kernel_size, stride=1, + padding=0, dilation=1, groups=1, bias=True, + mask_init='1s', mask_scale=1e-2, + threshold_fn='binarizer', threshold=None): + super(ElementWiseConv2d, self).__init__() + kernel_size = _pair(kernel_size) + stride = _pair(stride) + padding = _pair(padding) + dilation = _pair(dilation) + self.mask_scale = mask_scale + self.mask_init = mask_init + + if threshold is None: + threshold = DEFAULT_THRESHOLD + self.info = { + 'threshold_fn': threshold_fn, + 'threshold': threshold, + } + + if in_channels % groups != 0: + raise ValueError('in_channels must be divisible by groups') + if out_channels % groups != 0: + raise 
ValueError('out_channels must be divisible by groups') + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.dilation = dilation + self.transposed = False + self.output_padding = _pair(0) + self.groups = groups + + # weight and bias are no longer Parameters. + self.weight = Variable(torch.Tensor( + out_channels, in_channels // groups, *kernel_size), requires_grad=False) + if bias: + self.bias = Variable(torch.Tensor( + out_channels), requires_grad=False) + else: + self.register_parameter('bias', None) + + # Initialize real-valued mask weights. + self.mask_real = self.weight.data.new(self.weight.size()) + if mask_init == '1s': + self.mask_real.fill_(mask_scale) + elif mask_init == 'uniform': + self.mask_real.uniform_(-1 * mask_scale, mask_scale) + # mask_real is now a trainable parameter. + self.mask_real = Parameter(self.mask_real) + + # Initialize the thresholder. + if threshold_fn == 'binarizer': + print('Calling binarizer with threshold:', threshold) + self.threshold_fn = Binarizer(threshold=threshold) + elif threshold_fn == 'ternarizer': + print('Calling ternarizer with threshold:', threshold) + self.threshold_fn = Ternarizer(threshold=threshold) + + def forward(self, input): + # Get binarized/ternarized mask from real-valued mask. + mask_thresholded = self.threshold_fn(self.mask_real) + # Mask weights with above mask. + weight_thresholded = mask_thresholded * self.weight + # Perform conv using modified weight. 
+ return F.conv2d(input, weight_thresholded, self.bias, self.stride, + self.padding, self.dilation, self.groups) + + def __repr__(self): + s = ('{name} ({in_channels}, {out_channels}, kernel_size={kernel_size}' + ', stride={stride}') + if self.padding != (0,) * len(self.padding): + s += ', padding={padding}' + if self.dilation != (1,) * len(self.dilation): + s += ', dilation={dilation}' + if self.output_padding != (0,) * len(self.output_padding): + s += ', output_padding={output_padding}' + if self.groups != 1: + s += ', groups={groups}' + if self.bias is None: + s += ', bias=False' + s += ')' + return s.format(name=self.__class__.__name__, **self.__dict__) + + def _apply(self, fn): + for module in self.children(): + module._apply(fn) + + for param in self._parameters.values(): + if param is not None: + # Variables stored in modules are graph leaves, and we don't + # want to create copy nodes, so we have to unpack the data. + param.data = fn(param.data) + if param._grad is not None: + param._grad.data = fn(param._grad.data) + + for key, buf in self._buffers.items(): + if buf is not None: + self._buffers[key] = fn(buf) + + self.weight.data = fn(self.weight.data) + if self.bias is not None and self.bias.data is not None: + self.bias.data = fn(self.bias.data) + + +class ElementWiseLinear(nn.Module): + """Modified linear layer.""" + + def __init__(self, in_features, out_features, bias=True, + mask_init='1s', mask_scale=1e-2, + threshold_fn='binarizer', threshold=None): + super(ElementWiseLinear, self).__init__() + self.in_features = in_features + self.out_features = out_features + self.threshold_fn = threshold_fn + self.mask_scale = mask_scale + self.mask_init = mask_init + + if threshold is None: + threshold = DEFAULT_THRESHOLD + self.info = { + 'threshold_fn': threshold_fn, + 'threshold': threshold, + } + + # weight and bias are no longer Parameters. 
def forward(self, input):
    """Run the linear layer with its weights gated by the thresholded mask.

    The real-valued ``mask_real`` is passed through ``threshold_fn`` and
    multiplied element-wise into ``weight``; the product is used as the
    weight of a standard linear transform (``bias`` is passed through as-is).
    """
    masked_weight = self.threshold_fn(self.mask_real) * self.weight
    return F.linear(input, masked_weight, self.bias)
class BasicBlock(nn.Module):
    """Residual block made of two masked 3x3 convolutions.

    Args:
        inplanes: number of input channels.
        planes: number of channels produced by both convolutions.
        mask_init, mask_scale, threshold_fn: mask settings forwarded
            unchanged to ``conv3x3`` for both convolutions.
        stride: stride of the first convolution (the second uses 1).
        downsample: optional module applied to the input before the
            residual addition; ``None`` means the input is added as-is.
    """
    # Output channels are planes * expansion; read by ResNet._make_layer.
    expansion = 1

    def __init__(self, inplanes, planes, mask_init, mask_scale, threshold_fn, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, mask_init,
                             mask_scale, threshold_fn, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        # conv3x3 takes (..., mask_init, mask_scale, threshold_fn); the last
        # two were previously passed in swapped order.
        self.conv2 = conv3x3(planes, planes, mask_init,
                             mask_scale, threshold_fn)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))."""
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out
class ResNet(nn.Module):
    """ResNet whose convolutions are mask-carrying ``nl.ElementWiseConv2d``.

    Args:
        block: residual block class; must expose an ``expansion`` attribute
            and accept ``(inplanes, planes, mask_init, mask_scale,
            threshold_fn, stride, downsample)``.
        layers: sequence of four ints, the number of blocks per stage.
        mask_init, mask_scale, threshold_fn: forwarded to every
            ``nl.ElementWiseConv2d`` that is created.
        num_classes: output size of the final ``nn.Linear``.
    """

    def __init__(self, block, layers, mask_init, mask_scale, threshold_fn, num_classes=1000):
        # Channel count fed to the next block; advanced by _make_layer.
        self.inplanes = 64
        super(ResNet, self).__init__()
        # NOTE(review): networks.py copies pretrained weights by zipping
        # modules() of this model with the torchvision one, so the order in
        # which submodules are assigned below presumably has to stay as is.
        self.conv1 = nl.ElementWiseConv2d(
            3, 64, kernel_size=7, stride=2, padding=3, bias=False,
            mask_init=mask_init, mask_scale=mask_scale, threshold_fn=threshold_fn)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(
            block, 64, layers[0], mask_init, mask_scale, threshold_fn)
        self.layer2 = self._make_layer(
            block, 128, layers[1], mask_init, mask_scale, threshold_fn, stride=2)
        self.layer3 = self._make_layer(
            block, 256, layers[2], mask_init, mask_scale, threshold_fn, stride=2)
        self.layer4 = self._make_layer(
            block, 512, layers[3], mask_init, mask_scale, threshold_fn, stride=2)
        self.avgpool = nn.AvgPool2d(7, stride=1)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        # Default initialisation: normal(0, sqrt(2 / (k_h * k_w * out_channels)))
        # for masked conv weights, and weight=1 / bias=0 for batch norm.
        for m in self.modules():
            if isinstance(m, nl.ElementWiseConv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, mask_init, mask_scale, threshold_fn, stride=1):
        """Build one stage of ``blocks`` residual blocks as an nn.Sequential.

        Only the first block receives ``stride`` and the optional projection
        shortcut; the rest use the block's default stride.  Side effect:
        ``self.inplanes`` is updated to ``planes * block.expansion``.
        """
        downsample = None
        # A 1x1 masked conv + BN shortcut is needed whenever the first block
        # changes the spatial size or the channel count.
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nl.ElementWiseConv2d(
                    self.inplanes, planes * block.expansion,
                    kernel_size=1, stride=stride, bias=False,
                    mask_init=mask_init, mask_scale=mask_scale, threshold_fn=threshold_fn),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, mask_init,
                            mask_scale, threshold_fn, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes,
                                mask_init, mask_scale, threshold_fn))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Stem -> four stages -> average pool -> flatten -> ``fc``."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        # Flatten everything but the batch dimension for the linear layer.
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x
def make_layers(cfg, mask_init, mask_scale, threshold_fn, batch_norm=False):
    """Translate a VGG configuration list into the feature extractor.

    Each entry of ``cfg`` is either the string ``'M'`` (a 2x2, stride-2 max
    pool) or an int giving the output channels of a masked 3x3 convolution
    with padding 1.  Every convolution is followed by an in-place ReLU, with
    a BatchNorm2d in between when ``batch_norm`` is true.  The first
    convolution takes 3 input channels.

    Returns:
        nn.Sequential holding the layers in configuration order.
    """
    modules = []
    prev_channels = 3
    for entry in cfg:
        if entry == 'M':
            modules.append(nn.MaxPool2d(kernel_size=2, stride=2))
            continue

        modules.append(nl.ElementWiseConv2d(
            prev_channels, entry, kernel_size=3, padding=1,
            mask_init=mask_init, mask_scale=mask_scale,
            threshold_fn=threshold_fn))
        if batch_norm:
            modules.append(nn.BatchNorm2d(entry))
        modules.append(nn.ReLU(inplace=True))
        prev_channels = entry
    return nn.Sequential(*modules)
class View(nn.Module):
    """Module wrapper around ``Tensor.view`` so a reshape can sit in nn.Sequential.

    The target shape is fixed at construction time, e.g. ``View(-1, 25088)``.
    """

    def __init__(self, *shape):
        super(View, self).__init__()
        # Kept as the tuple of positional arguments given by the caller.
        self.shape = shape

    def forward(self, input):
        """Return ``input`` viewed with the stored shape."""
        target_shape = self.shape
        return input.view(*target_shape)
def add_dataset(self, dataset, num_outputs):
    """Register a classifier head for ``dataset`` unless one already exists.

    A new head is an ``nn.Linear`` mapping the 4096 shared features to
    ``num_outputs`` classes; it is appended to ``self.classifiers`` at the
    same position as ``dataset`` in ``self.datasets``.  Known dataset names
    are left untouched.
    """
    if dataset in self.datasets:
        return
    self.datasets.append(dataset)
    self.classifiers.append(nn.Linear(4096, num_outputs))
def forward(self, x):
    """Compute shared features, flatten them per sample, and classify.

    ``self.classifier`` is the head chosen by ``set_dataset``; it starts out
    as ``None``, so ``set_dataset`` has to be called before the first
    forward pass.
    """
    features = self.shared(x)
    flat = features.view(features.size(0), -1)
    return self.classifier(flat)
class ModifiedResNet(ModifiedVGG16):
    """ResNet-50 backbone with one classifier head per dataset.

    With ``original=True`` the plain torchvision ResNet-50 is used; otherwise
    the masked ``modnets.resnet50`` is built and the pretrained torchvision
    weights are copied into it.  Everything except ``fc`` becomes
    ``self.shared``; ``fc`` becomes the 'imagenet' entry of
    ``self.classifiers``.
    """

    def __init__(self, mask_init='1s', mask_scale=1e-2, threshold_fn='binarizer',
                 make_model=True, original=False):
        # The parent must not build a VGG; only its bookkeeping is reused.
        super(ModifiedResNet, self).__init__(make_model=False)
        if make_model:
            self.make_model(mask_init, mask_scale, threshold_fn, original)

    def make_model(self, mask_init, mask_scale, threshold_fn, original):
        """Creates the model."""
        if original:
            resnet = models.resnet50(pretrained=True)
            imagenet_fc = resnet.fc
            print('Creating model: No mask layers.')
        else:
            # Get the one with masks and pretrained model.
            resnet = modnets.resnet50(mask_init, mask_scale, threshold_fn)
            resnet_pretrained = models.resnet50(pretrained=True)
            # Copy weights from the pretrained to the modified model.  The
            # two networks are walked in lockstep, so this relies on both
            # registering their submodules in the same order.
            for module, module_pretrained in zip(resnet.modules(), resnet_pretrained.modules()):
                if 'ElementWise' in str(type(module)):
                    module.weight.data.copy_(module_pretrained.weight.data)
                    # Explicit None test: truth-testing a multi-element
                    # tensor is ambiguous and raises.
                    if module.bias is not None:
                        module.bias.data.copy_(module_pretrained.bias.data)
                elif 'BatchNorm' in str(type(module)):
                    module.weight.data.copy_(module_pretrained.weight.data)
                    module.bias.data.copy_(module_pretrained.bias.data)
                    module.running_mean.copy_(module_pretrained.running_mean)
                    module.running_var.copy_(module_pretrained.running_var)
            # The loop above never touches nn.Linear, so the masked model's
            # own fc is still randomly initialised; use the pretrained head.
            imagenet_fc = resnet_pretrained.fc
            print('Creating model: Mask layers created.')

        self.datasets, self.classifiers = [], nn.ModuleList()

        # Create the shared feature generator.
        self.shared = nn.Sequential()
        for name, module in resnet.named_children():
            if name != 'fc':
                self.shared.add_module(name, module)

        # Add the default imagenet classifier.
        self.datasets.append('imagenet')
        self.classifiers.append(imagenet_fc)

        # model.set_dataset() has to be called explicitly, else model won't work.
        self.classifier = None

    def add_dataset(self, dataset, num_outputs):
        """Adds a new dataset to the classifier.

        New heads map the 2048 pooled ResNet-50 features to ``num_outputs``;
        a dataset name that is already registered is ignored.
        """
        if dataset not in self.datasets:
            self.datasets.append(dataset)
            self.classifiers.append(nn.Linear(2048, num_outputs))
class ResNetDiffInit(ModifiedResNet):
    """ResNet50 with non-ImageNet initialization.

    ``source`` is the path of a whole pickled ResNet model that replaces the
    torchvision ImageNet weights as the starting point.  The dataset name of
    the bundled classifier is derived from the path: 'places' if the path
    contains that word, else 'imagenet' if it contains that, else the
    generic label 'pretrained'.
    """

    def __init__(self, source, mask_init='1s', mask_scale=1e-2, threshold_fn='binarizer',
                 make_model=True, original=False):
        super(ResNetDiffInit, self).__init__(make_model=False)
        if make_model:
            self.make_model(source, mask_init, mask_scale,
                            threshold_fn, original)

    def make_model(self, source, mask_init, mask_scale, threshold_fn, original):
        """Creates the model."""
        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code; only pass checkpoints from a trusted origin.
        if original:
            resnet = torch.load(source)
            print('Loading model:', source)
        else:
            # Get the one with masks and pretrained model.
            resnet = modnets.resnet50(mask_init, mask_scale, threshold_fn)
            resnet_pretrained = torch.load(source)
            # Copy weights from the pretrained to the modified model.
            for module, module_pretrained in zip(resnet.modules(), resnet_pretrained.modules()):
                if 'ElementWise' in str(type(module)):
                    module.weight.data.copy_(module_pretrained.weight.data)
                    # Explicit None test: truth-testing a multi-element
                    # tensor is ambiguous and raises.
                    if module.bias is not None:
                        module.bias.data.copy_(module_pretrained.bias.data)
                elif 'BatchNorm' in str(type(module)):
                    module.weight.data.copy_(module_pretrained.weight.data)
                    module.bias.data.copy_(module_pretrained.bias.data)
                    module.running_mean.copy_(module_pretrained.running_mean)
                    module.running_var.copy_(module_pretrained.running_var)
            print('Creating model: Mask layers created.')

        self.datasets, self.classifiers = [], nn.ModuleList()

        # Create the shared feature generator.
        self.shared = nn.Sequential()
        for name, module in resnet.named_children():
            if name != 'fc':
                self.shared.add_module(name, module)

        # Add the default classifier.  A name is always appended so that
        # self.datasets[i] keeps matching self.classifiers[i]; previously an
        # unrecognised path added a classifier without a name, which made
        # set_dataset() select the wrong head for every later dataset.
        if 'places' in source:
            self.datasets.append('places')
        elif 'imagenet' in source:
            self.datasets.append('imagenet')
        else:
            self.datasets.append('pretrained')
        if original:
            self.classifiers.append(resnet.fc)
        else:
            self.classifiers.append(resnet_pretrained.fc)

        # model.set_dataset() has to be called explicitly, else model won't work.
        self.classifier = None
def _pack_masks(args):
    """Extract masks and classifiers from the checkpoints in --packlist.

    Each non-empty, non-comment line of the pack list is
    ``<dataset>:<checkpoint path>``.  For every checkpoint the thresholded
    mask of each 'ElementWise' module in ``model.shared`` is stored under
    that module's index, together with the model's active classifier.  The
    result is saved to ``args.maskloc`` and per-layer statistics are written
    to a '.txt' file next to it.

    Raises:
        ValueError: if the same dataset is listed twice.
    """
    import os

    assert args.packlist and args.maskloc
    dataset2masks = {}
    dataset2classifiers = {}
    net_type = None
    counter = 1

    # Location to output stats: '<maskloc without extension>.txt'.  Slicing
    # off two characters only worked when maskloc ended in 'pt'.
    stats_path = os.path.splitext(args.maskloc)[0] + '.txt'

    # Load models one by one and store their masks.
    with open(stats_path, 'w') as fout, open(args.packlist, 'r') as fin:
        for line in fin:
            entry = line.strip()
            if not entry or entry.startswith('#'):
                continue
            # Split on the first colon only so paths may contain colons.
            dataset, loadname = entry.split(':', 1)
            dataset, loadname = dataset.strip(), loadname.strip()

            # Can't have same dataset twice.
            if dataset in dataset2masks:
                raise ValueError('Repeated datasets as input...')
            print('Loading model #%d for dataset "%s"' % (counter, dataset))
            counter += 1
            # NOTE(security): torch.load unpickles the file; only pack
            # checkpoints from a trusted origin.
            ckpt = torch.load(loadname)
            model = ckpt['model']
            # Ensure all inputs are for same model type.
            if net_type is None:
                net_type = str(type(model))
            else:
                assert net_type == str(type(model)), '%s != %s' % (
                    net_type, str(type(model)))

            # Gather masks and store in dictionary.
            fout.write('Dataset: %s\n' % (dataset))
            total_params, neg_params, zerod_params = [], [], []
            masks = {}
            for module_idx, module in enumerate(model.shared.modules()):
                if 'ElementWise' not in str(type(module)):
                    continue
                mask = module.threshold_fn(module.mask_real)
                mask = mask.data.cpu()

                # Make sure mask values are in {0, 1} or {-1, 0, 1}.
                # int() accepts both plain numbers and 0-dim tensors.
                num_zero = int(mask.eq(0).sum())
                num_one = int(mask.eq(1).sum())
                num_mone = int(mask.eq(-1).sum())
                total = mask.numel()
                threshold_type = module.threshold_fn.__class__.__name__
                if threshold_type == 'Binarizer':
                    assert num_mone == 0
                    assert num_zero + num_one == total
                elif threshold_type == 'Ternarizer':
                    assert num_mone + num_zero + num_one == total
                # An unsigned byte cannot hold -1, so masks with negative
                # entries are stored as signed 8-bit; binary masks keep the
                # original ByteTensor format.
                packed_type = torch.CharTensor if num_mone else torch.ByteTensor
                masks[module_idx] = mask.type(packed_type)

                # Count total and zerod out params.
                total_params.append(total)
                zerod_params.append(num_zero)
                neg_params.append(num_mone)
                fout.write('%d\t%.2f%%\t%.2f%%\n' % (
                    module_idx,
                    neg_params[-1] / total_params[-1] * 100,
                    zerod_params[-1] / total_params[-1] * 100))
            print('Check Passed: Masks only have binary/ternary values.')
            dataset2masks[dataset] = masks
            dataset2classifiers[dataset] = model.classifier

            fout.write('Total -1: %d/%d = %.2f%%\n' % (
                sum(neg_params), sum(total_params), sum(neg_params) / sum(total_params) * 100))
            fout.write('Total 0: %d/%d = %.2f%%\n' % (
                sum(zerod_params), sum(total_params), sum(zerod_params) / sum(total_params) * 100))
            fout.write('-' * 20 + '\n')

    # Save masks to file.
    torch.save({
        'dataset2masks': dataset2masks,
        'dataset2classifiers': dataset2classifiers,
    }, args.maskloc)


def _evaluate_masks(args):
    """Apply the packed masks for --dataset to a vanilla model and evaluate.

    Raises:
        ValueError: if the mask file has no entry for ``args.dataset``.
    """
    assert args.arch and args.maskloc and args.dataset

    # Set default train and test path if not provided as input.
    utils.set_dataset_paths(args)

    # Load masks and classifier for this dataset.
    # NOTE(security): torch.load unpickles the file; only load trusted files.
    info = torch.load(args.maskloc)
    if args.dataset not in info['dataset2masks']:
        raise ValueError('%s not found in masks.' % (args.dataset))
    masks = info['dataset2masks'][args.dataset]
    classifier = info['dataset2classifiers'][args.dataset]

    # Create the vanilla model and apply masking.
    model = None
    if args.arch == 'vgg16':
        model = net.ModifiedVGG16(original=True)
    elif args.arch == 'vgg16bn':
        model = net.ModifiedVGG16BN(original=True)
    elif args.arch == 'resnet50':
        model = net.ModifiedResNet(original=True)
    elif args.arch == 'densenet121':
        model = net.ModifiedDenseNet(original=True)
    elif args.arch == 'resnet50_diff':
        assert args.source
        model = net.ResNetDiffInit(args.source, original=True)
    model.eval()

    print('Applying masks.')
    # Masks are keyed by the module's position in shared.modules(), which is
    # assumed to be the same for the masked and the vanilla architecture.
    for module_idx, module in enumerate(model.shared.modules()):
        if module_idx in masks:
            mask = masks[module_idx]
            module.weight.data[mask.eq(0)] = 0
            module.weight.data[mask.eq(-1)] *= -1
    print('Applied masks.')

    # Override model.classifier with saved one.
    model.add_dataset(args.dataset, classifier.weight.size(0))
    model.set_dataset(args.dataset)
    model.classifier = classifier
    if args.cuda:
        model = model.cuda()

    # Create the manager and run eval.
    manager = Manager(args, model)
    manager.eval()


def main():
    """Parse the command line and run either the 'pack' or the 'eval' mode."""
    args = FLAGS.parse_args()

    if args.mode == 'pack':
        _pack_masks(args)
    elif args.mode == 'eval':
        _evaluate_masks(args)
#!/bin/bash
# Stores only the masks after training.
# Usage:
#   ./scripts/run_packing.sh <gpu_id> <arch_name> <mask_type>
#   ./scripts/run_packing.sh 0 vgg16 binary
#
# The relative paths below assume the working directory is src/.

GPU_ID=$1
NAME=$2
TYPE=$3

# pack.py writes both the mask file and its stats file here; create the
# directory first so the run does not fail on a fresh checkout.
PACK_DIR=../checkpoints/packed
mkdir -p "$PACK_DIR"

CUDA_VISIBLE_DEVICES="$GPU_ID" python pack.py --mode pack \
  --packlist "best_models/${NAME}_${TYPE}.txt" \
  --maskloc "${PACK_DIR}/${NAME}_${TYPE}.pt"
def step_lr(epoch, base_lr, lr_decay_every, lr_decay_factor, optimizer):
    """Set the step-decayed learning rate on every parameter group.

    The rate is ``base_lr * lr_decay_factor ** floor((epoch - 1) /
    lr_decay_every)``, i.e. epochs are counted from 1 and the first decay
    happens at epoch ``lr_decay_every + 1``.  The new rate is printed and
    the same ``optimizer`` object is returned.
    """
    num_decays = np.floor((epoch - 1) / lr_decay_every)
    new_lr = base_lr * np.power(lr_decay_factor, num_decays)
    for group in optimizer.param_groups:
        group['lr'] = new_lr
    print('Set lr to ', new_lr)
    return optimizer