liusj17 committed Dec 13, 2019 (0 parents, commit c5054a7)
Showing 26 changed files with 3,401 additions and 0 deletions.
@@ -0,0 +1,103 @@
# Network Slimming (PyTorch)

This repository contains an official PyTorch implementation of the following paper:
[Learning Efficient Convolutional Networks Through Network Slimming](http://openaccess.thecvf.com/content_iccv_2017/html/Liu_Learning_Efficient_Convolutional_ICCV_2017_paper.html) (ICCV 2017).
[Zhuang Liu](https://liuzhuang13.github.io/), [Jianguo Li](https://sites.google.com/site/leeplus/), [Zhiqiang Shen](http://zhiqiangshen.com/), [Gao Huang](http://www.cs.cornell.edu/~gaohuang/), [Shoumeng Yan](https://scholar.google.com/citations?user=f0BtDUQAAAAJ&hl=en), [Changshui Zhang](http://bigeye.au.tsinghua.edu.cn/english/Introduction.html).

Original implementation: [slimming](https://github.com/liuzhuang13/slimming) in Torch.
The code is based on [pytorch-slimming](https://github.com/foolwood/pytorch-slimming). We add support for ResNet and DenseNet.

Citation:
```
@InProceedings{Liu_2017_ICCV,
  author = {Liu, Zhuang and Li, Jianguo and Shen, Zhiqiang and Huang, Gao and Yan, Shoumeng and Zhang, Changshui},
  title = {Learning Efficient Convolutional Networks Through Network Slimming},
  booktitle = {The IEEE International Conference on Computer Vision (ICCV)},
  month = {Oct},
  year = {2017}
}
```

## Dependencies
torch v0.4.1 or v1.0.1, torchvision v0.2.0

## Channel Selection Layer
We introduce a `channel selection` layer to help prune ResNet and DenseNet. The layer is easy to implement: it stores a parameter `indexes`, which is initialized to an all-ones vector. During pruning, the entries that correspond to pruned channels are set to 0.

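
The idea behind the channel selection layer can be illustrated with a small plain-Python sketch. This is not the repository's `channel_selection` module: the helper `select_channels` and the nested-list feature maps are hypothetical stand-ins, used only to show how an all-ones `indexes` vector with some entries zeroed picks out the surviving channels.

```python
# Plain-Python illustration of a channel selection mask.
# `select_channels` is a hypothetical helper, not the repository's module.

def select_channels(feature_maps, indexes):
    """Keep only the channels whose entry in `indexes` is non-zero."""
    if len(feature_maps) != len(indexes):
        raise ValueError("one index entry is required per channel")
    return [fmap for fmap, keep in zip(feature_maps, indexes) if keep != 0]


# Four channels, each a tiny 1x2 feature map.
features = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [0.7, 0.8]]

# Initialised to all ones: every channel passes through unchanged.
indexes = [1.0] * len(features)
unpruned = select_channels(features, indexes)

# Pruning zeroes the entries of the dropped channels (here channels 1 and 3).
indexes[1] = 0.0
indexes[3] = 0.0
kept = select_channels(features, indexes)
print(kept)
```

Because the selection happens after the batch normalization layer, the BN layer itself can keep all of its channels while the following convolution only sees the selected ones.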
## Baseline

The `dataset` argument specifies which dataset to use: `cifar10` or `cifar100`. The `arch` argument specifies the architecture to use: `vgg`, `resnet`, or `densenet`. The depth is chosen to match the networks used in the paper.
```shell
python main.py --dataset cifar10 --arch vgg --depth 19
python main.py --dataset cifar10 --arch resnet --depth 164
python main.py --dataset cifar10 --arch densenet --depth 40
```

## Train with Sparsity

```shell
python main.py -sr --s 0.0001 --dataset cifar10 --arch vgg --depth 19
python main.py -sr --s 0.00001 --dataset cifar10 --arch resnet --depth 164
python main.py -sr --s 0.00001 --dataset cifar10 --arch densenet --depth 40
```

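
In the paper, sparsity is induced by an L1 penalty on the batch normalization scaling factors. The training script is not shown here, so the following is only a sketch under the assumption that `--s` sets the penalty weight and `-sr` switches the penalty on. The function `add_l1_subgradient` is a hypothetical name; it adds the subgradient `s * sign(w)` to each scaling factor's gradient.

```python
# Sketch of the L1 sparsity term on BN scaling factors (plain Python).
# Assumption: `s` plays the role of the `--s` command-line value.

def sign(x):
    """Return -1, 0 or 1 depending on the sign of x."""
    return (x > 0) - (x < 0)


def add_l1_subgradient(bn_scales, grads, s):
    """Add the subgradient of s * |w| to the gradient of each scaling factor."""
    return [g + s * sign(w) for w, g in zip(bn_scales, grads)]


scales = [0.5, -0.2, 0.0]      # BN scaling factors (gamma)
grads = [0.01, 0.01, 0.01]     # gradients from the ordinary training loss
new_grads = add_l1_subgradient(scales, grads, s=1e-4)
print(new_grads)
```

The extra term pushes every non-zero scaling factor toward zero by a constant amount per step, which is what makes small factors safe to prune later.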
## Prune

```shell
python vggprune.py --dataset cifar10 --depth 19 --percent 0.7 --model [PATH TO THE MODEL] --save [DIRECTORY TO STORE RESULT]
python resprune.py --dataset cifar10 --depth 164 --percent 0.4 --model [PATH TO THE MODEL] --save [DIRECTORY TO STORE RESULT]
python denseprune.py --dataset cifar10 --depth 40 --percent 0.4 --model [PATH TO THE MODEL] --save [DIRECTORY TO STORE RESULT]
```
The pruned model will be named `pruned.pth.tar`.

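
The `--percent` value is applied globally across all batch normalization layers rather than per layer. The sketch below mirrors that logic in plain Python, following the pruning script in this commit: sort all scaling-factor magnitudes, take the value at position `int(total * percent)` as the threshold, and keep only channels strictly above it. The names `global_threshold` and `layer_masks` and the toy values are illustrative assumptions.

```python
# Plain-Python sketch of global threshold selection for channel pruning.

def global_threshold(layer_scales, percent):
    """Return the magnitude at position int(total * percent) of the sorted scales."""
    magnitudes = sorted(abs(w) for layer in layer_scales for w in layer)
    return magnitudes[int(len(magnitudes) * percent)]


def layer_masks(layer_scales, threshold):
    """Build one 0/1 mask per layer; a channel survives only if |w| > threshold."""
    return [[1.0 if abs(w) > threshold else 0.0 for w in layer]
            for layer in layer_scales]


# Two toy BN layers with 3 and 4 channels.
scales = [[0.9, 0.05, 0.4], [0.01, 0.7, 0.3, 0.02]]

thre = global_threshold(scales, percent=0.4)
masks = layer_masks(scales, thre)
cfg = [int(sum(m)) for m in masks]   # remaining channels per layer
print(thre, masks, cfg)
```

Because the comparison is strict, the channel that defines the threshold is pruned as well, so the fraction removed can be slightly above `percent`. A single global threshold also means one layer can lose far more channels than another.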
## Fine-tune

```shell
python main.py --refine [PATH TO THE PRUNED MODEL] --dataset cifar10 --arch vgg --depth 19 --epochs 160
```

## Results

The results are fairly close to those of the original paper, which were produced with Torch. Note that, due to different random seeds, there may be up to ~0.5%/1.5% fluctuation on the CIFAR-10/100 datasets across runs, in our experience.
### CIFAR10
| CIFAR10-Vgg | Baseline | Sparsity (1e-4) | Prune (70%) | Fine-tune-160 (70%) |
| :---------------: | :------: | :--------------------------: | :-----------------: | :-------------------: |
| Top1 Accuracy (%) | 93.77 | 93.30 | 32.54 | 93.78 |
| Parameters | 20.04M | 20.04M | 2.25M | 2.25M |

| CIFAR10-Resnet-164 | Baseline | Sparsity (1e-5) | Prune (40%) | Fine-tune-160 (40%) | Prune (60%) | Fine-tune-160 (60%) |
| :---------------: | :------: | :--------------------------: | :-----------------: | :-------------------: | :----------------: | :--------------------: |
| Top1 Accuracy (%) | 94.75 | 94.76 | 94.58 | 95.05 | 47.73 | 93.81 |
| Parameters | 1.71M | 1.73M | 1.45M | 1.45M | 1.12M | 1.12M |

| CIFAR10-Densenet-40 | Baseline | Sparsity (1e-5) | Prune (40%) | Fine-tune-160 (40%) | Prune (60%) | Fine-tune-160 (60%) |
| :---------------: | :------: | :--------------------------: | :-----------------: | :-------------------: | :--------------------: | :-----------------: |
| Top1 Accuracy (%) | 94.11 | 94.17 | 94.16 | 94.32 | 89.46 | 94.22 |
| Parameters | 1.07M | 1.07M | 0.69M | 0.69M | 0.49M | 0.49M |

### CIFAR100
| CIFAR100-Vgg | Baseline | Sparsity (1e-4) | Prune (50%) | Fine-tune-160 (50%) |
| :---------------: | :------: | :--------------------------: | :-----------------: | :-------------------: |
| Top1 Accuracy (%) | 72.12 | 72.05 | 5.31 | 73.32 |
| Parameters | 20.04M | 20.04M | 4.93M | 4.93M |

| CIFAR100-Resnet-164 | Baseline | Sparsity (1e-5) | Prune (40%) | Fine-tune-160 (40%) | Prune (60%) | Fine-tune-160 (60%) |
| :---------------: | :------: | :--------------------------: | :-----------------: | :-------------------: | :--------------------: | :-----------------: |
| Top1 Accuracy (%) | 76.79 | 76.87 | 48.0 | 77.36 | --- | --- |
| Parameters | 1.73M | 1.73M | 1.49M | 1.49M | --- | --- |

Note: When pruning 60% of the channels for resnet164-cifar100 with this implementation, some layers are sometimes pruned entirely, which causes an error. However, we also provide a [mask implementation](https://github.com/Eric-mingjie/network-slimming/tree/master/mask-impl) that applies a mask to the scaling factors in the BN layers. With the mask implementation, the network obtained by pruning 60% of the channels in resnet164-cifar100 can still be trained.

| CIFAR100-Densenet-40 | Baseline | Sparsity (1e-5) | Prune (40%) | Fine-tune-160 (40%) | Prune (60%) | Fine-tune-160 (60%) |
| :---------------: | :------: | :--------------------------: | :-----------------: | :-------------------: | :--------------------: | :-----------------: |
| Top1 Accuracy (%) | 73.27 | 73.29 | 67.67 | 73.76 | 19.18 | 73.19 |
| Parameters | 1.10M | 1.10M | 0.71M | 0.71M | 0.50M | 0.50M |

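
The percentage in each column header is the fraction of channels pruned, which is not the same as the fraction of parameters removed. As a small worked example, the snippet below computes the parameter reduction for three rows of the tables above. The helper `reduction` is a hypothetical name and the inputs are the table values in millions of parameters.

```python
# Parameter reduction computed from the reported parameter counts (in millions).

def reduction(before_m, after_m):
    """Fraction of parameters removed when going from before_m to after_m."""
    return 1.0 - after_m / before_m


rows = {
    "CIFAR10 VGG, 70% channels pruned": (20.04, 2.25),
    "CIFAR10 DenseNet-40, 60% channels pruned": (1.07, 0.49),
    "CIFAR100 VGG, 50% channels pruned": (20.04, 4.93),
}

results = {name: reduction(before, after) for name, (before, after) in rows.items()}
for name, value in results.items():
    print("{}: {:.1%} fewer parameters".format(name, value))
```

For VGG the parameter reduction is larger than the channel percentage, since removing a channel shrinks both the layer that produces it and the layer that consumes it.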
## Prune mobilev2-ssdlite
For details, see the [README](./mbv2-ssdlite/README.md).

## Contact
sunmj15 at gmail.com
liuzhuangthu at gmail.com
@@ -0,0 +1,203 @@
import os
import argparse
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
from torchvision import datasets, transforms
from models import *


# Prune settings
parser = argparse.ArgumentParser(description='PyTorch Slimming CIFAR prune')
parser.add_argument('--dataset', type=str, default='cifar100',
                    help='training dataset (default: cifar100)')
parser.add_argument('--test-batch-size', type=int, default=256, metavar='N',
                    help='input batch size for testing (default: 256)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--depth', type=int, default=40,
                    help='depth of the densenet')
parser.add_argument('--percent', type=float, default=0.5,
                    help='scale sparse rate (default: 0.5)')
parser.add_argument('--model', default='', type=str, metavar='PATH',
                    help='path to the model (default: none)')
parser.add_argument('--save', default='', type=str, metavar='PATH',
                    help='path to save pruned model (default: none)')

args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()

# os.makedirs('') raises an error, so skip directory creation when --save is empty
if args.save and not os.path.exists(args.save):
    os.makedirs(args.save)

model = densenet(depth=args.depth, dataset=args.dataset)

if args.cuda:
    model.cuda()
if args.model:
    if os.path.isfile(args.model):
        print("=> loading checkpoint '{}'".format(args.model))
        checkpoint = torch.load(args.model)
        args.start_epoch = checkpoint['epoch']
        best_prec1 = checkpoint['best_prec1']
        model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint '{}' (epoch {}) Prec1: {:f}"
              .format(args.model, checkpoint['epoch'], best_prec1))
    else:
        # The parser defines --model, not --resume
        print("=> no checkpoint found at '{}'".format(args.model))

# Count the total number of BN channels in the model
total = 0
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        total += m.weight.data.shape[0]

# Gather the absolute values of all BN scaling factors into one vector
bn = torch.zeros(total)
index = 0
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        size = m.weight.data.shape[0]
        bn[index:(index+size)] = m.weight.data.abs().clone()
        index += size

# The global pruning threshold is the value at the requested percentile
y, i = torch.sort(bn)
thre_index = int(total * args.percent)
thre = y[thre_index]

pruned = 0
cfg = []
cfg_mask = []
for k, m in enumerate(model.modules()):
    if isinstance(m, nn.BatchNorm2d):
        weight_copy = m.weight.data.abs().clone()
        # Keep the mask on the same device as the weights so the script also runs without CUDA
        mask = weight_copy.gt(float(thre)).float()
        pruned = pruned + mask.shape[0] - torch.sum(mask)
        m.weight.data.mul_(mask)
        m.bias.data.mul_(mask)
        cfg.append(int(torch.sum(mask)))
        cfg_mask.append(mask.clone())
        print('layer index: {:d} \t total channel: {:d} \t remaining channel: {:d}'.
            format(k, mask.shape[0], int(torch.sum(mask))))
    elif isinstance(m, nn.MaxPool2d):
        cfg.append('M')

pruned_ratio = pruned/total

print('Pre-processing Successful!')

# Quick test of the model after the pre-processing step, in which the pruned BN scales are simply set to zero
def test(model):
    kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
    if args.dataset == 'cifar10':
        test_loader = torch.utils.data.DataLoader(
            datasets.CIFAR10('./data.cifar10', train=False, transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])),
            batch_size=args.test_batch_size, shuffle=False, **kwargs)
    elif args.dataset == 'cifar100':
        test_loader = torch.utils.data.DataLoader(
            datasets.CIFAR100('./data.cifar100', train=False, transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])),
            batch_size=args.test_batch_size, shuffle=False, **kwargs)
    else:
        raise ValueError("No valid dataset is given.")
    model.eval()
    correct = 0
    # torch.no_grad() replaces the removed `volatile=True` flag
    with torch.no_grad():
        for data, target in test_loader:
            if args.cuda:
                data, target = data.cuda(), target.cuda()
            output = model(data)
            pred = output.max(1, keepdim=True)[1]  # get the index of the max log-probability
            # .item() keeps `correct` a Python int, so the accuracy below is a true float division
            correct += pred.eq(target.view_as(pred)).sum().item()

    print('\nTest set: Accuracy: {}/{} ({:.1f}%)\n'.format(
        correct, len(test_loader.dataset), 100. * correct / len(test_loader.dataset)))
    return correct / float(len(test_loader.dataset))

acc = test(model)

print("Cfg:")
print(cfg)

# Build the slimmed network from the per-layer channel configuration
newmodel = densenet(depth=args.depth, dataset=args.dataset, cfg=cfg)

if args.cuda:
    newmodel.cuda()

num_parameters = sum([param.nelement() for param in newmodel.parameters()])
savepath = os.path.join(args.save, "prune.txt")
with open(savepath, "w") as fp:
    fp.write("Configuration: \n"+str(cfg)+"\n")
    fp.write("Number of parameters: \n"+str(num_parameters)+"\n")
    fp.write("Test accuracy: \n"+str(acc))

old_modules = list(model.modules())
new_modules = list(newmodel.modules())

layer_id_in_cfg = 0
start_mask = torch.ones(3)
end_mask = cfg_mask[layer_id_in_cfg]
first_conv = True

for layer_id in range(len(old_modules)):
    m0 = old_modules[layer_id]
    m1 = new_modules[layer_id]
    if isinstance(m0, nn.BatchNorm2d):
        idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))
        if idx1.size == 1:
            idx1 = np.resize(idx1, (1,))

        if isinstance(old_modules[layer_id + 1], channel_selection):
            # If the next layer is a channel selection layer, the current batch normalization layer is not pruned.
            m1.weight.data = m0.weight.data.clone()
            m1.bias.data = m0.bias.data.clone()
            m1.running_mean = m0.running_mean.clone()
            m1.running_var = m0.running_var.clone()

            # We need to set the mask parameter `indexes` for the channel selection layer.
            m2 = new_modules[layer_id + 1]
            m2.indexes.data.zero_()
            m2.indexes.data[idx1.tolist()] = 1.0

            layer_id_in_cfg += 1
            start_mask = end_mask.clone()
            if layer_id_in_cfg < len(cfg_mask):
                end_mask = cfg_mask[layer_id_in_cfg]
            continue

    elif isinstance(m0, nn.Conv2d):
        if first_conv:
            # We don't change the first convolution layer.
            m1.weight.data = m0.weight.data.clone()
            first_conv = False
            continue
        if isinstance(old_modules[layer_id - 1], channel_selection):
            idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
            idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))
            print('In shape: {:d}, Out shape {:d}.'.format(idx0.size, idx1.size))
            if idx0.size == 1:
                idx0 = np.resize(idx0, (1,))
            if idx1.size == 1:
                idx1 = np.resize(idx1, (1,))

            # If the previous layer is a channel selection layer, we only select the input channels and
            # leave the number of output channels of the current convolutional layer unchanged.
            w1 = m0.weight.data[:, idx0.tolist(), :, :].clone()
            m1.weight.data = w1.clone()
            continue

    elif isinstance(m0, nn.Linear):
        idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
        if idx0.size == 1:
            idx0 = np.resize(idx0, (1,))

        # Index with a list, as in the convolution branch above
        m1.weight.data = m0.weight.data[:, idx0.tolist()].clone()
        m1.bias.data = m0.bias.data.clone()

torch.save({'cfg': cfg, 'state_dict': newmodel.state_dict()}, os.path.join(args.save, 'pruned.pth.tar'))

print(newmodel)
model = newmodel
test(model)
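
The weight-copy loop above turns each 0/1 mask into a list of kept channel indices and then slices the input-channel dimension of the next layer's weights. A minimal plain-Python sketch of that step follows. The helpers `kept_indices` and `slice_input_channels` are hypothetical, and a nested list stands in for a weight tensor of shape `[out][in]`.

```python
# Plain-Python sketch: mask -> kept indices -> sliced weights.

def kept_indices(mask):
    """Positions of the non-zero mask entries, i.e. the surviving channels."""
    return [i for i, keep in enumerate(mask) if keep != 0]


def slice_input_channels(weight, idx_in):
    """Keep only the selected input columns of a weight laid out as [out][in]."""
    return [[row[i] for i in idx_in] for row in weight]


start_mask = [1.0, 0.0, 1.0]           # channel 1 was pruned upstream
weight = [[1, 2, 3],
          [4, 5, 6]]                   # 2 outputs, 3 inputs

idx0 = kept_indices(start_mask)
pruned_weight = slice_input_channels(weight, idx0)
print(idx0, pruned_weight)

# A single surviving channel still yields a one-element list here.
single = kept_indices([0.0, 1.0])
```

In the NumPy version, `np.squeeze` collapses a single index to a 0-d array, which is why the script restores the shape with `np.resize(idx, (1,))`. The list comprehension in this sketch does not need that extra step.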