forked from yeyun111/dlcv_for_beginners
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
U-Net semantic segmentation with PyTorch
- Loading branch information
Showing
5 changed files
with
447 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
## 基于PyTorch实现U-Net图像分割 | ||
TO BE UPDATED |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
import os | ||
import argparse | ||
import torch.optim as optim | ||
|
||
# Lower-case optimizer name -> torch.optim optimizer class.
# parse_args() lower-cases the --optimizer value and looks it up here.
OPTIMIZERS = dict(
    adadelta=optim.Adadelta,
    adam=optim.Adam,
    rmsprop=optim.RMSprop,
    sgd=optim.SGD,
)
|
||
|
||
def parse_args():
    """Parse and post-process the command-line options of the U-Net demo.

    Reads ``sys.argv`` through ``argparse`` and then normalizes a few fields:

    * ``dataroot``     -- trailing path separators are stripped.
    * ``color_labels`` -- the comma-separated tuple string is parsed into a
      list of tuples, e.g. ``[(0, 0, 0), (255, 255, 255)]``.
    * ``optimizer``    -- the (case-insensitive) name is replaced by the
      matching class from ``OPTIMIZERS``.
    * ``lr_policy``    -- the ``"epoch:lr,epoch:lr"`` string is parsed into a
      dict; an empty string yields ``{}``.

    Returns:
        argparse.Namespace: the parsed and normalized arguments.

    Exits via ``parser.error`` (usage message, exit status 2) when
    ``--color_labels`` / ``--lr-policy`` are not valid Python literals or when
    ``--optimizer`` names an unknown optimizer.
    """
    # Local import so this function does not depend on edits to the file header.
    import ast

    parser = argparse.ArgumentParser(
        description='Simple Demo of Image Segmentation with U-Net',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # general options
    parser.add_argument('mode',
                        help='train/test')
    parser.add_argument('dataroot',
                        help='Directory containing training images in "images" and "segmentations" or test images')
    parser.add_argument('--cpu',
                        help='Set to CPU mode', action='store_true')
    parser.add_argument('--color_labels',
                        help='Colors of labels in segmentation image',
                        type=str, default='(0,0,0),(255,255,255)')

    # training options
    parser.add_argument('--img-dir',
                        help='Directory under [dataroot] containing images',
                        type=str, default='images')
    parser.add_argument('--seg-dir',
                        help='Directory under [dataroot] containing segmentations',
                        type=str, default='segmentations')
    parser.add_argument('--epochs',
                        help='Num of training epochs',
                        type=int, default=20)
    parser.add_argument('--batch-size',
                        help='Batch size',
                        type=int, default=4)
    parser.add_argument('--optimizer',
                        help='Optimizer: Adadelta/Adam/RMSprop/SGD',
                        type=str, default='SGD')
    parser.add_argument('--lr',
                        help='Learning rate, for Adadelta it is the base learning rate',
                        type=float, default=0.01)
    parser.add_argument('--lr-policy',
                        help='Learning rate policy, example:"5:0.0005,10:0.0001,18:1e-5"',
                        type=str, default='')
    parser.add_argument('--no-batchnorm',
                        help='Do NOT use batch normalization', action='store_true')
    parser.add_argument('--print-interval',
                        help='Print info after each specified iterations',
                        type=int, default=20)

    # test options
    parser.add_argument('--model',
                        help='Path to pre-trained model',
                        type=str, default='')
    parser.add_argument('--output-dir',
                        help='Directory for output results',
                        type=str, default='')

    args = parser.parse_args()
    args.dataroot = args.dataroot.rstrip(os.sep)

    # ast.literal_eval only accepts Python literals, so a crafted option value
    # cannot execute arbitrary code the way eval() would.
    try:
        args.color_labels = ast.literal_eval('[{}]'.format(args.color_labels))
    except (ValueError, SyntaxError):
        parser.error('--color_labels must look like "(0,0,0),(255,255,255)", got {!r}'.format(
            args.color_labels))

    # Report an unknown optimizer as a usage error instead of a bare KeyError.
    optimizer_name = args.optimizer.lower()
    if optimizer_name not in OPTIMIZERS:
        parser.error('unknown optimizer {!r}; choose from: {}'.format(
            args.optimizer, ', '.join(sorted(OPTIMIZERS))))
    args.optimizer = OPTIMIZERS[optimizer_name]

    if args.lr_policy:
        try:
            args.lr_policy = ast.literal_eval('{{{}}}'.format(args.lr_policy))
        except (ValueError, SyntaxError):
            parser.error('--lr-policy must look like "5:0.0005,10:0.0001", got {!r}'.format(
                args.lr_policy))
    else:
        args.lr_policy = {}

    return args
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,179 @@ | ||
import logging | ||
import os | ||
import sys | ||
import numpy | ||
from PIL import Image | ||
import torch | ||
import torchvision | ||
import torch.utils.data | ||
from torch.autograd import Variable | ||
from argparser import parse_args | ||
import utils | ||
import networks | ||
|
||
|
||
def _loss_to_float(loss):
    """Return a scalar loss as a plain Python number on old and new PyTorch."""
    # Newer PyTorch returns 0-dim loss tensors that must be read with .item();
    # indexing them with [0] fails there. Older releases only support .data[0].
    if hasattr(loss, 'item'):
        return loss.item()
    return loss.data[0]


def train(args):
    """Train the U-Net on ``<dataroot>/train`` and validate on ``<dataroot>/val``.

    Creates a ``train-<datetime>`` directory in the current working directory,
    logs to ``log.txt`` inside it (and to the console), and saves the model
    ``state_dict`` there as ``epoch-<n>.pth`` after every epoch.

    Args:
        args: namespace produced by ``parse_args``. Fields read here:
            ``dataroot``, ``img_dir``, ``seg_dir``, ``color_labels``,
            ``batch_size``, ``cpu``, ``lr``, ``optimizer`` (an optimizer
            class), ``lr_policy`` (dict mapping epoch index -> learning rate),
            ``epochs`` and ``print_interval``.
    """
    # set logger
    logging_dir = 'train-{}'.format(utils.get_datetime_string())
    os.mkdir(logging_dir)
    logging.basicConfig(
        level=logging.INFO,
        filename='{}/log.txt'.format(logging_dir),
        format='%(asctime)s %(message)s',
        filemode='w'
    )

    # mirror every log record to the console as well
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s %(message)s')
    console.setFormatter(formatter)
    logging.getLogger('').addHandler(console)

    # initialize loader
    train_set = utils.SegmentationImageFolder(os.sep.join([args.dataroot, 'train']),
                                              image_folder=args.img_dir,
                                              segmentation_folder=args.seg_dir,
                                              labels=args.color_labels,
                                              image_size=(64, 128))
    val_set = utils.SegmentationImageFolder(os.sep.join([args.dataroot, 'val']),
                                            image_folder=args.img_dir,
                                            segmentation_folder=args.seg_dir,
                                            labels=args.color_labels,
                                            image_size=(64, 128))
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=args.batch_size, shuffle=True)
    val_loader = torch.utils.data.DataLoader(val_set, batch_size=args.batch_size, shuffle=True)

    # initialize model: 3 input channels, 2 output classes
    # NOTE(review): args.no_batchnorm is parsed but not forwarded to UNet here,
    # and the class count is fixed at 2 regardless of len(args.color_labels).
    model = networks.UNet([32, 64, 128, 256, 512], 3, 2)
    if not args.cpu:
        model.cuda()

    criterion = utils.CrossEntropyLoss2D()

    # optimizer & lr policy
    lr = args.lr
    optimizer = args.optimizer(model.parameters(), lr=lr)
    logging.info('| Learning Rate\t| Initialized learning rate: {}'.format(lr))

    # train
    for epoch in range(args.epochs):
        model.train()
        # update lr if lr_policy is defined
        if epoch in args.lr_policy:
            lr = args.lr_policy[epoch]
            # Change the rate in place: building a new optimizer here would
            # throw away accumulated state (momentum buffers, Adam moments...).
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
            logging.info('| Learning Rate\t| Epoch: {}\t| Change learning rate to {}'.format(epoch, lr))

        # iterate all samples
        losses = utils.AverageMeter()
        for i_batch, (img, seg) in enumerate(train_loader):

            img = Variable(img)
            seg = Variable(seg)

            if not args.cpu:
                img = img.cuda()
                seg = seg.cuda()

            # compute output
            output = model(img)
            loss = criterion(output, seg)
            losses.update(_loss_to_float(loss))

            # compute gradient and take one optimizer step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if i_batch % args.print_interval == 0:
                logging.info(
                    '| Epoch: {}/{}\t'
                    '| Iteration: {}/{}\t'
                    '| Training loss: {}'.format(
                        epoch, args.epochs,
                        i_batch, len(train_loader),
                        losses.avg
                    )
                )
                # restart the running average for the next print interval
                losses = utils.AverageMeter()

        # validation pass over the whole val set
        model.eval()
        losses = utils.AverageMeter()
        for i_batch, (img, seg) in enumerate(val_loader):

            img = Variable(img)
            seg = Variable(seg)

            if not args.cpu:
                img = img.cuda()
                seg = seg.cuda()

            # compute output
            output = model(img)
            loss = criterion(output, seg)
            # second argument scales a (possibly smaller) last batch by its
            # size relative to a full batch
            losses.update(_loss_to_float(loss), float(img.size(0))/float(args.batch_size))

        logging.info(
            '| Epoch: {}/{}\t'
            '| Validation loss: {}'.format(
                epoch, args.epochs,
                losses.avg
            )
        )

        # checkpoint file names are 1-based (epoch+1)
        model_weights_path = '{}/epoch-{}.pth'.format(logging_dir, epoch+1)
        torch.save(model.state_dict(), model_weights_path)
        logging.info('| Checkpoint\t| {} is saved for epoch {}'.format(model_weights_path, epoch+1))
|
||
|
||
def test(args):
    """Segment every image file in ``args.dataroot`` and save color-coded PNGs.

    Loads the weights at ``args.model`` into a U-Net, runs each file found
    directly under ``args.dataroot`` through it, takes the per-pixel argmax
    over the output channels and paints class ``i`` with
    ``args.color_labels[i]``. Results are written to ``args.output_dir`` (or a
    new ``test-<datetime>`` directory) as ``<input name without extension>.png``.

    Args:
        args: namespace produced by ``parse_args``. Fields read here:
            ``model``, ``output_dir``, ``cpu``, ``dataroot``, ``color_labels``.

    Returns:
        None. Prints a message and returns early when ``args.model`` is empty.
    """
    if not args.model:
        print('Need a pretrained model!')
        return

    # check if output dir exists; makedirs also creates missing parents
    output_dir = args.output_dir if args.output_dir else 'test-{}'.format(utils.get_datetime_string())
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # load model
    # NOTE(review): args.no_batchnorm is not forwarded to UNet here either.
    model = networks.UNet([32, 64, 128, 256, 512], 3, 2)
    # Always deserialize onto the CPU so a checkpoint saved from a CUDA model
    # can still be loaded with --cpu; the model is moved to the GPU below.
    state_dict = torch.load(args.model, map_location=lambda storage, loc: storage)
    model.load_state_dict(state_dict)
    model = model.eval()

    if not args.cpu:
        model.cuda()

    # process the images one by one
    transform = torchvision.transforms.ToTensor()
    for filename in os.listdir(args.dataroot):
        filepath = os.path.join(args.dataroot, filename)
        if not os.path.isfile(filepath):
            # sub-directories cannot be opened as images
            continue

        # Image files are binary; text mode breaks decoding on Python 3.
        with open(filepath, 'rb') as f:
            # The model is built with 3 input channels, so force RGB.
            img = Image.open(f).convert('RGB')
            # PIL's resize takes (width, height).
            # NOTE(review): train() passes image_size=(64, 128) to the dataset;
            # confirm this size matches what the model was trained on.
            img = img.resize((128, 256))

        img = transform(img)
        img = img.unsqueeze(0)  # add the batch dimension
        img = Variable(img)
        if not args.cpu:
            img = img.cuda()
        output = model(img)

        # Move to host memory first: .numpy() is not available on CUDA tensors.
        scores = output.data.cpu().numpy()[0]
        output_argmax = numpy.argmax(scores, axis=0)
        h, w = output_argmax.shape
        out_img = numpy.zeros((h, w, 3), dtype=numpy.uint8)
        for i, color in enumerate(args.color_labels):
            out_img[output_argmax == i] = numpy.array(color, dtype=numpy.uint8)
        out_img = Image.fromarray(out_img)

        # splitext keeps the full name when the file has no extension
        # (rfind('.') == -1 used to chop off the last character).
        seg_filepath = os.path.join(output_dir, os.path.splitext(filename)[0] + '.png')
        out_img.save(seg_filepath)
        print('{} is exported!'.format(seg_filepath))
|
||
|
||
if __name__ == '__main__':
    # Script entry point: run the routine selected by the positional
    # ``mode`` argument, or report an unsupported mode.
    args = parse_args()
    handler = {'train': train, 'test': test}.get(args.mode)
    if handler is None:
        print('Wrong input!')
    else:
        handler(args)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
|
||
|
||
class UNetConvBlock(nn.Module):
    """Two stacked convolutions, each followed by optional batch norm and the activation.

    Args:
        input_nch: number of input channels of the first convolution.
        output_nch: number of output channels of both convolutions.
        kernel_size: square kernel size used by both convolutions.
        activation: callable applied after each convolution (and batch norm).
        use_bn: when True, apply batch normalization after each convolution.
        same_conv: when True, pad by ``kernel_size // 2`` so that odd kernels
            keep the spatial size; when False no padding is applied.
    """

    def __init__(self, input_nch, output_nch, kernel_size=3, activation=F.leaky_relu, use_bn=True, same_conv=True):
        super(UNetConvBlock, self).__init__()
        # "same" padding is only exact for odd kernel sizes
        if same_conv:
            pad = kernel_size // 2
        else:
            pad = 0
        self.conv0 = nn.Conv2d(input_nch, output_nch, kernel_size, padding=pad)
        self.conv1 = nn.Conv2d(output_nch, output_nch, kernel_size, padding=pad)
        self.act = activation
        # NOTE(review): a single BatchNorm2d instance is applied after both
        # convolutions, so they share affine parameters and running statistics.
        # Splitting it would change the state_dict keys of saved checkpoints.
        if use_bn:
            self.batch_norm = nn.BatchNorm2d(output_nch)
        else:
            self.batch_norm = None

    def forward(self, x):
        """Apply conv -> (batch norm) -> activation twice and return the result."""
        for conv in (self.conv0, self.conv1):
            x = conv(x)
            if self.batch_norm is not None:
                x = self.batch_norm(x)
            x = self.act(x)
        return x
|
||
|
||
class UNet(nn.Module):
    """U-Net built from ``UNetConvBlock`` stages with max-pool downsampling.

    Args:
        conv_channels: non-empty sequence with the output channel count of each
            encoder stage, e.g. ``[32, 64, 128, 256, 512]``. Its length is the
            number of stages.
        input_nch: number of channels of the input image.
        output_nch: number of channels of the final 1x1 convolution
            (one per class).
        use_bn: forwarded to every ``UNetConvBlock``.

    Raises:
        ValueError: if ``conv_channels`` is empty.

    The input height and width should be divisible by ``2 ** (n_stages - 1)``;
    otherwise the pooled and upsampled feature maps differ in size and cannot
    be concatenated.
    """

    def __init__(self, conv_channels, input_nch=3, output_nch=2, use_bn=True):
        super(UNet, self).__init__()
        if len(conv_channels) == 0:
            raise ValueError('conv_channels must contain at least one stage')
        self.n_stages = len(conv_channels)

        # define convolution blocks
        down_convs = []
        up_convs = []  # built shallow -> deep, reversed below

        in_nch = input_nch
        for i, out_nch in enumerate(conv_channels):
            down_convs.append(UNetConvBlock(in_nch, out_nch, use_bn=use_bn))
            # Every decoder block except the deepest one receives its input
            # concatenated with a skip feature map, hence twice the channels.
            up_conv_in_ch = 2 * out_nch if i < self.n_stages - 1 else out_nch
            # A decoder block reduces to the channel count of the next
            # shallower stage; the shallowest keeps conv_channels[0], which is
            # what out_conv consumes.
            up_conv_out_ch = out_nch if i == 0 else in_nch
            up_convs.append(UNetConvBlock(up_conv_in_ch, up_conv_out_ch, use_bn=use_bn))
            in_nch = out_nch
        up_convs.reverse()  # deepest decoder block first, as forward() expects

        self.down_convs = nn.ModuleList(down_convs)
        self.up_convs = nn.ModuleList(up_convs)

        # define output convolution (1x1, maps features to class scores)
        self.out_conv = nn.Conv2d(conv_channels[0], output_nch, 1)

    def forward(self, x):
        """Run the encoder/decoder and return the ``out_conv`` result."""
        # F.upsample is deprecated in favor of F.interpolate (same default
        # nearest-neighbor mode); fall back on releases that predate it.
        upsample = getattr(F, 'interpolate', None) or F.upsample

        # conv & downsampling; the pooled maps double as skip connections
        skips = []
        for i in range(self.n_stages - 1):
            x = self.down_convs[i](x)
            x = F.max_pool2d(x, 2, 2)
            skips.append(x)

        # center convs
        x = self.down_convs[self.n_stages - 1](x)
        x = self.up_convs[0](x)

        # conv & upsampling, consuming the skips from deepest to shallowest
        for i, skip in enumerate(reversed(skips)):
            x = torch.cat([x, skip], 1)
            x = self.up_convs[i + 1](x)
            x = upsample(x, scale_factor=2)

        return self.out_conv(x)
Oops, something went wrong.