Skip to content

Commit

Permalink
unet semantic segmentation with pytorch
Browse files Browse the repository at this point in the history
  • Loading branch information
yeyun111 committed Sep 7, 2017
1 parent 931836c commit 205f109
Show file tree
Hide file tree
Showing 5 changed files with 447 additions and 0 deletions.
2 changes: 2 additions & 0 deletions random_bonus/unet_segmentation/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
## 基于PyTorch实现U-Net图像分割
TO BE UPDATED
71 changes: 71 additions & 0 deletions random_bonus/unet_segmentation/argparser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import os
import argparse
import torch.optim as optim

OPTIMIZERS = {
'adadelta': optim.Adadelta,
'adam': optim.Adam,
'rmsprop': optim.RMSprop,
'sgd': optim.SGD
}


def parse_args():
parser = argparse.ArgumentParser(
description='Simple Demo of Image Segmentation with U-Net',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)

# general options
parser.add_argument('mode',
help='train/test')
parser.add_argument('dataroot',
help='Directory containing training images in "images" and "segmentations" or test images')
parser.add_argument('--cpu',
help='Set to CPU mode', action='store_true')
parser.add_argument('--color_labels',
help='Colors of labels in segmentation image',
type=str, default='(0,0,0),(255,255,255)')

# training options
parser.add_argument('--img-dir',
help='Directory under [dataroot] containing images',
type=str, default='images')
parser.add_argument('--seg-dir',
help='Directory under [dataroot] containing segmentations',
type=str, default='segmentations')
parser.add_argument('--epochs',
help='Num of training epochs',
type=int, default=20)
parser.add_argument('--batch-size',
help='Batch size',
type=int, default=4)
parser.add_argument('--optimizer',
help='Optimizer: Adadelta/Adam/RMSprop/SGD',
type=str, default='SGD')
parser.add_argument('--lr',
help='Learning rate, for Adadelta it is the base learning rate',
type=float, default=0.01)
parser.add_argument('--lr-policy',
help='Learning rate policy, example:"5:0.0005,10:0.0001,18:1e-5"',
type=str, default='')
parser.add_argument('--no-batchnorm',
help='Do NOT use batch normalization', action='store_true')
parser.add_argument('--print-interval',
help='Print info after each specified iterations',
type=int, default=20)

# test options
parser.add_argument('--model',
help='Path to pre-trained model',
type=str, default='')
parser.add_argument('--output-dir',
help='Directory for output results',
type=str, default='')

args = parser.parse_args()
args.dataroot = args.dataroot.rstrip(os.sep)
args.color_labels = eval('[{}]'.format(args.color_labels))
args.optimizer = OPTIMIZERS[args.optimizer.lower()]
args.lr_policy = eval('{{{}}}'.format(args.lr_policy)) if args.lr_policy else {}

return args
179 changes: 179 additions & 0 deletions random_bonus/unet_segmentation/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
import logging
import os
import sys
import numpy
from PIL import Image
import torch
import torchvision
import torch.utils.data
from torch.autograd import Variable
from argparser import parse_args
import utils
import networks


def train(args):
# set logger
logging_dir = 'train-{}'.format(utils.get_datetime_string())
os.mkdir('{}'.format(logging_dir))
logging.basicConfig(
level=logging.INFO,
filename='{}/log.txt'.format(logging_dir),
format='%(asctime)s %(message)s',
filemode='w'
)

console = logging.StreamHandler()
console.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s %(message)s')
console.setFormatter(formatter)
logging.getLogger('').addHandler(console)

# initialize loader
train_set = utils.SegmentationImageFolder(os.sep.join([args.dataroot, 'train']),
image_folder=args.img_dir,
segmentation_folder=args.seg_dir,
labels=args.color_labels,
image_size=(64, 128))
val_set = utils.SegmentationImageFolder(os.sep.join([args.dataroot, 'val']),
image_folder=args.img_dir,
segmentation_folder=args.seg_dir,
labels=args.color_labels,
image_size=(64, 128))
train_loader = torch.utils.data.DataLoader(train_set, batch_size=args.batch_size, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_set, batch_size=args.batch_size, shuffle=True)

# initialize model, input channels need to be calculated by hand
model = networks.UNet([32, 64, 128, 256, 512], 3, 2)
if not args.cpu:
model.cuda()

criterion = utils.CrossEntropyLoss2D()

# optimizer & lr policy
lr = args.lr
optimizer = args.optimizer(model.parameters(), lr=lr)
logging.info('| Learning Rate\t| Initialized learning rate: {}'.format(lr))

# train
for epoch in range(args.epochs):
model.train()
# update lr if lr_policy is defined
if epoch in args.lr_policy:
lr = args.lr_policy[epoch]
optimizer = args.optimizer(model.parameters(), lr=lr)
logging.info('| Learning Rate\t| Epoch: {}\t| Change learning rate to {}'.format(epoch, lr))

# iterate all samples
losses = utils.AverageMeter()
for i_batch, (img, seg) in enumerate(train_loader):

img = Variable(img)
seg = Variable(seg)

if not args.cpu:
img = img.cuda()
seg = seg.cuda()

# compute output
output = model(img)
loss = criterion(output, seg)
losses.update(loss.data[0])

# compute gradient and do SGD step
optimizer.zero_grad()
loss.backward()
optimizer.step()

if i_batch % args.print_interval == 0:
logging.info(
'| Epoch: {}/{}\t'
'| Iteration: {}/{}\t'
'| Training loss: {}'.format(
epoch, args.epochs,
i_batch, len(train_loader),
losses.avg
)
)
losses = utils.AverageMeter()

model.eval()
losses = utils.AverageMeter()
for i_batch, (img, seg) in enumerate(val_loader):

img = Variable(img)
seg = Variable(seg)

if not args.cpu:
img = img.cuda()
seg = seg.cuda()

# compute output
output = model(img)
loss = criterion(output, seg)
losses.update(loss.data[0], float(img.size(0))/float(args.batch_size))

logging.info(
'| Epoch: {}/{}\t'
'| Validation loss: {}'.format(
epoch, args.epochs,
losses.avg
)
)

model_weights_path = '{}/epoch-{}.pth'.format(logging_dir, epoch+1)
torch.save(model.state_dict(), model_weights_path)
logging.info('| Checkpoint\t| {} is saved for epoch {}'.format(model_weights_path, epoch+1))


def test(args):
if not args.model:
print('Need a pretrained model!')
return

# check if output dir exists
output_dir = args.output_dir if args.output_dir else 'test-{}'.format(utils.get_datetime_string())
if not os.path.exists(output_dir):
os.mkdir(output_dir)

# load model
model = networks.UNet([32, 64, 128, 256, 512], 3, 2)
model.load_state_dict(torch.load(args.model))
model = model.eval()

if not args.cpu:
model.cuda()

# iterate all images with one by one
transform = torchvision.transforms.ToTensor()
for filename in [x for x in os.listdir(args.dataroot)]:
filepath = os.sep.join([args.dataroot, filename])
with open(filepath, 'r') as f:
img = Image.open(f)
img = img.resize((128, 256))
img = transform(img)
img = img.view(1, *img.shape)
img = Variable(img)
if not args.cpu:
img = img.cuda()
output = model(img)
_, c, h, w = output.data.shape
output_argmax = numpy.argmax(output.data.numpy()[0], axis=0)
out_img = numpy.zeros((h, w, 3), dtype=numpy.uint8)
for i, color in enumerate(args.color_labels):
out_img[output_argmax == i] = numpy.array(args.color_labels[i], dtype=numpy.uint8)
out_img = Image.fromarray(out_img)
seg_filepath = os.sep.join([output_dir, filename[:filename.rfind('.')]+'.png'])
out_img.save(seg_filepath)
print('{} is exported!'.format(seg_filepath))


if __name__ == '__main__':

args = parse_args()
if args.mode == 'train':
train(args)
elif args.mode == 'test':
test(args)
else:
print('Wrong input!')
66 changes: 66 additions & 0 deletions random_bonus/unet_segmentation/networks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class UNetConvBlock(nn.Module):
def __init__(self, input_nch, output_nch, kernel_size=3, activation=F.leaky_relu, use_bn=True, same_conv=True):
super(UNetConvBlock, self).__init__()
padding = kernel_size // 2 if same_conv else 0 # only support odd kernel
self.conv0 = nn.Conv2d(input_nch, output_nch, kernel_size, padding=padding)
self.conv1 = nn.Conv2d(output_nch, output_nch, kernel_size, padding=padding)
self.act = activation
self.batch_norm = nn.BatchNorm2d(output_nch) if use_bn else None

def forward(self, x):
x = self.conv0(x)
if self.batch_norm:
x = self.batch_norm(x)
x = self.act(x)
x = self.conv1(x)
if self.batch_norm:
x = self.batch_norm(x)
return self.act(x)


class UNet(nn.Module):
def __init__(self, conv_channels, input_nch=3, output_nch=2, use_bn=True):
super(UNet, self).__init__()
self.n_stages = len(conv_channels)
# define convolution blocks
down_convs = []
up_convs = []

in_nch = input_nch
for i, out_nch in enumerate(conv_channels):
down_convs.append(UNetConvBlock(in_nch, out_nch, use_bn=use_bn))
up_conv_in_ch = 2 * out_nch if i < self.n_stages - 1 else out_nch # first up conv with equal channels
up_conv_out_ch = out_nch if i == 0 else in_nch # last up conv with channels equal to labels
up_convs.insert(0, UNetConvBlock(up_conv_in_ch, up_conv_out_ch, use_bn=use_bn))
in_nch = out_nch

self.down_convs = nn.ModuleList(down_convs)
self.up_convs = nn.ModuleList(up_convs)

# define output convolution
self.out_conv = nn.Conv2d(conv_channels[0], output_nch, 1)

def forward(self, x):
# conv & downsampling
down_sampled_fmaps = []
for i in range(self.n_stages-1):
x = self.down_convs[i](x)
x = F.max_pool2d(x, 2, 2)
down_sampled_fmaps.insert(0, x)

# center convs
x = self.down_convs[self.n_stages-1](x)
x = self.up_convs[0](x)

# conv & upsampling
for i, down_sampled_fmap in enumerate(down_sampled_fmaps):
x = torch.cat([x, down_sampled_fmap], 1)
x = self.up_convs[i+1](x)
x = F.upsample(x, scale_factor=2)

return self.out_conv(x)
Loading

0 comments on commit 205f109

Please sign in to comment.