Skip to content

Commit

Permalink
finished code for dilationv2 mobilenet, needs training
Browse files Browse the repository at this point in the history
  • Loading branch information
MSiam committed Jan 2, 2018
1 parent ada319f commit 9c5ec7d
Show file tree
Hide file tree
Showing 4 changed files with 194 additions and 16 deletions.
29 changes: 29 additions & 0 deletions config/experiments_config/dilationv2_mobilenet_train.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Training configuration for the Dilation v2 MobileNet experiment.

# Directories arguments
data_dir: "full_cityscapes_res"
exp_dir: "dilationv2_mobilenet"
out_dir: "dilationv2_mobilenet"

# Data arguments
img_height: 512
img_width: 1024
num_channels: 3
num_classes: 20

# Train arguments
num_epochs: 200
batch_size: 4
shuffle: True
data_mode: "experiment_v2"   # v2 mode: the loader downsamples label maps by 8 to match the net's output stride
save_every: 10               # checkpoint every N epochs
test_every: 5                # run validation every N epochs
max_to_keep: 2               # number of checkpoints retained on disk
weighted_loss: True          # use per-class weighting in the loss

# Models arguments
learning_rate: 0.0001
weight_decay: 0.000005
pretrained_path: "pretrained_weights/mobilenet_v1.pkl"   # encoder init weights

# Misc arguments
verbose: False

131 changes: 131 additions & 0 deletions models/dilationv2_mobilenet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
from models.basic.basic_model import BasicModel
from models.encoders.VGG import VGG16
from models.encoders.mobilenet import MobileNet
from layers.convolution import conv2d_transpose, conv2d, atrous_conv2d, depthwise_separable_conv2d
import numpy as np
import tensorflow as tf
from utils.misc import _debug
import pdb

class DilationV2MobileNet(BasicModel):
    """
    Dilation v2 segmentation network with MobileNet as the encoder.

    The MobileNet encoder is taken up to ``conv4_1``; the later stride-2
    stages are replaced by dilated (atrous) convolutions with rates 2 and 4
    interleaved with stride-1 depthwise-separable convolutions, so spatial
    resolution is preserved from that point on.  The final 1x1 convolution
    produces per-class logits at 1/8 of the input resolution
    (``targets_size = 8``); there is no decoder/upsampling path.
    """

    def __init__(self, args, phase=0):
        """
        :param args: experiment arguments (batch_size, weight_decay,
                     pretrained_path, ...) forwarded to BasicModel.
        :param phase: 0 for training phase (BasicModel convention).
        """
        super().__init__(args, phase=phase)
        # init encoder
        self.encoder = None
        self.wd = self.args.weight_decay

        # NOTE(review): the attributes below are FCN8s-style skip/upsample
        # placeholders copied from a sibling model; they are never assigned
        # in this class.  Kept for interface parity with the other models —
        # confirm nothing inherited reads them before removing.
        self.upscore2 = None
        self.score_feed1 = None
        self.fuse_feed1 = None
        self.upscore4 = None
        self.score_feed2 = None
        self.fuse_feed2 = None
        self.upscore8 = None
        # Output stride of the network: labels are fed at 1/8 resolution.
        self.targets_size = 8

    def build(self):
        """Build the full graph: inputs, network, output, train ops, summaries."""
        print("\nBuilding the MODEL...")
        self.init_input()
        self.init_network()
        self.init_output()
        self.init_train()
        self.init_summaries()
        print("The Model is built successfully\n")

    def init_input(self):
        """
        Create the input placeholders.

        ``x_pl`` holds full-resolution images; ``y_pl`` holds label maps
        already downsampled by ``targets_size`` (the data loader performs
        the resize), matching the logits' 1/8 spatial resolution.
        """
        with tf.name_scope('input'):
            self.x_pl = tf.placeholder(tf.float32,
                                       [self.args.batch_size, self.params.img_height, self.params.img_width, 3])
            self.y_pl = tf.placeholder(tf.int32, [self.args.batch_size, self.params.img_height // self.targets_size,
                                                  self.params.img_width // self.targets_size])
            # Single debug print (the original printed the same shapes twice).
            print('X_batch shape ', self.x_pl.get_shape().as_list(), ' ', self.y_pl.get_shape().as_list())

            self.curr_learning_rate = tf.placeholder(tf.float32)
            if self.params.weighted_loss:
                # NOTE(review): weights are allocated at full image resolution
                # while y_pl is at 1/8 resolution — confirm the loss code
                # resizes one of the two before they are combined.
                self.wghts = np.zeros((self.args.batch_size, self.params.img_height, self.params.img_width),
                                      dtype=np.float32)
            self.is_training = tf.placeholder(tf.bool)

    def init_network(self):
        """
        Building the Network here.

        Layer names ('conv_ds_*_dil', 'conv_1c_1x1_dil') are part of the
        checkpoint variable namespace — do not rename them.
        :return:
        """
        # Init MobileNet as an encoder
        self.encoder = MobileNet(x_input=self.x_pl, num_classes=self.params.num_classes,
                                 pretrained_path=self.args.pretrained_path,
                                 train_flag=self.is_training, width_multipler=1.0, weight_decay=self.args.weight_decay)

        # Build Encoding part
        self.encoder.build()

        # Build Decoding part: dilated convolutions replace the strided
        # MobileNet stages so resolution stays at 1/8 from conv4_1 onward.
        with tf.name_scope('dilation_2'):
            # First dilated stage (rate 2) in place of the stride-2 conv.
            self.conv4_2 = atrous_conv2d('conv_ds_7_dil', self.encoder.conv4_1,
                                         num_filters=512, kernel_size=(3, 3), padding='SAME',
                                         activation=tf.nn.relu, dilation_rate=2,
                                         batchnorm_enabled=True, is_training=self.is_training,
                                         l2_strength=self.wd)
            _debug(self.conv4_2)
            # Five stride-1 depthwise-separable convs mirroring MobileNet's
            # conv_ds_8 .. conv_ds_12 stage, but without downsampling.
            self.conv5_1 = depthwise_separable_conv2d('conv_ds_8_dil', self.conv4_2,
                                                      width_multiplier=self.encoder.width_multiplier,
                                                      num_filters=512, kernel_size=(3, 3), padding='SAME',
                                                      stride=(1, 1), activation=tf.nn.relu,
                                                      batchnorm_enabled=True, is_training=self.is_training,
                                                      l2_strength=self.wd)
            _debug(self.conv5_1)
            self.conv5_2 = depthwise_separable_conv2d('conv_ds_9_dil', self.conv5_1,
                                                      width_multiplier=self.encoder.width_multiplier,
                                                      num_filters=512, kernel_size=(3, 3), padding='SAME',
                                                      stride=(1, 1), activation=tf.nn.relu,
                                                      batchnorm_enabled=True, is_training=self.is_training,
                                                      l2_strength=self.wd)
            _debug(self.conv5_2)
            self.conv5_3 = depthwise_separable_conv2d('conv_ds_10_dil', self.conv5_2,
                                                      width_multiplier=self.encoder.width_multiplier,
                                                      num_filters=512, kernel_size=(3, 3), padding='SAME',
                                                      stride=(1, 1), activation=tf.nn.relu,
                                                      batchnorm_enabled=True, is_training=self.is_training,
                                                      l2_strength=self.wd)
            _debug(self.conv5_3)
            self.conv5_4 = depthwise_separable_conv2d('conv_ds_11_dil', self.conv5_3,
                                                      width_multiplier=self.encoder.width_multiplier,
                                                      num_filters=512, kernel_size=(3, 3), padding='SAME',
                                                      stride=(1, 1), activation=tf.nn.relu,
                                                      batchnorm_enabled=True, is_training=self.is_training,
                                                      l2_strength=self.wd)
            _debug(self.conv5_4)
            self.conv5_5 = depthwise_separable_conv2d('conv_ds_12_dil', self.conv5_4,
                                                      width_multiplier=self.encoder.width_multiplier,
                                                      num_filters=512, kernel_size=(3, 3), padding='SAME',
                                                      stride=(1, 1), activation=tf.nn.relu,
                                                      batchnorm_enabled=True, is_training=self.is_training,
                                                      l2_strength=self.wd)
            _debug(self.conv5_5)
            # Second dilated stage (rate 4) in place of the last stride-2 conv.
            self.conv5_6 = atrous_conv2d('conv_ds_13_dil', self.conv5_5,
                                         num_filters=1024, kernel_size=(3, 3), padding='SAME',
                                         activation=tf.nn.relu, dilation_rate=4,
                                         batchnorm_enabled=True, is_training=self.is_training,
                                         l2_strength=self.wd)
            _debug(self.conv5_6)
            self.conv6_1 = depthwise_separable_conv2d('conv_ds_14_dil', self.conv5_6,
                                                      width_multiplier=self.encoder.width_multiplier,
                                                      num_filters=1024, kernel_size=(3, 3), padding='SAME',
                                                      stride=(1, 1), activation=tf.nn.relu,
                                                      batchnorm_enabled=True, is_training=self.is_training,
                                                      l2_strength=self.wd)
            _debug(self.conv6_1)
            # Pooling is removed.
            # 1x1 classifier head: per-class logits at 1/8 input resolution.
            self.score_fr = conv2d('conv_1c_1x1_dil', self.conv6_1, num_filters=self.params.num_classes,
                                   l2_strength=self.wd,
                                   kernel_size=(1, 1))

            _debug(self.score_fr)
            self.logits = self.score_fr

5 changes: 4 additions & 1 deletion run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,12 @@

#6- Dilation v1 MobileNet Test

#7- Dilation v2 MobileNet Train (config: config/experiments_config/dilationv2_mobilenet_train.yaml)
python3 main.py --load_config=dilationv2_mobilenet_train.yaml train Train DilationV2MobileNet

###################################### ShuffleNet #################################################
#1- FCN8s ShuffleNet Train Coarse+Fine
python3 main.py --load_config=fcn8s_shufflenet_traincoarse.yaml train Train FCN8sShuffleNet
#python3 main.py --load_config=fcn8s_shufflenet_traincoarse.yaml train Train FCN8sShuffleNet
#python3 main.py --load_config=fcn8s_shufflenet_train.yaml train Train FCN8sShuffleNet

#2- FCN8s ShuffleNet Test
Expand Down
45 changes: 30 additions & 15 deletions train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

matplotlib.use('Agg')
import matplotlib.pyplot as plt
import cv2
#import cv2

from utils.img_utils import decode_labels
from utils.seg_dataloader import SegDataLoader
Expand Down Expand Up @@ -81,6 +81,16 @@ def __init__(self, args, sess, train_model, test_model):
self.num_iterations_validation_per_epoch = None
self.load_train_data_h5()
self.generator = self.train_h5_generator
elif self.args.data_mode == "experiment_v2":
self.targets_resize= 8
self.train_data = None
self.train_data_len = None
self.val_data = None
self.val_data_len = None
self.num_iterations_training_per_epoch = None
self.num_iterations_validation_per_epoch = None
self.load_train_data(v2=True)
self.generator = self.train_generator
elif self.args.data_mode == "experiment":
self.train_data = None
self.train_data_len = None
Expand Down Expand Up @@ -115,16 +125,6 @@ def __init__(self, args, sess, train_model, test_model):
# self.debug_y= misc.imread('/data/menna/cityscapes/gtFine/val/lindau/lindau_000048_000019_gtFine_labelIds.png')
# self.debug_x= np.expand_dims(misc.imresize(self.debug_x, (512,1024)), axis=0)
# self.debug_y= np.expand_dims(misc.imresize(self.debug_y, (512,1024)), axis=0)

# torch_data= torchfile.load('/data/menna/TFSegmentation/out_networks_layers/dict_out.t7')
# stat= torchfile.load('/data/menna/cityscape/512_1024/stat.t7')
# pdb.set_trace()
# torch_data= torchfile.load('/data/menna/cityscape/512_1024/data.t7')
# self.debug_x= np.expand_dims(torch_data[b'testData'][b'data'][0,:,:,:].transpose(1,2,0), axis=0)
# self.debug_y= np.expand_dims(torch_data[b'testData'][b'labels'][0,:,:], axis=0)
# np.save('data/debug/debug_x.npy', self.debug_x)
# np.save('data/debug/debug_y.npy', self.debug_y)

self.debug_x = np.load('data/debug/debug_x.npy')
self.debug_y = np.load('data/debug/debug_y.npy')
print("Debugging photo loaded")
Expand Down Expand Up @@ -270,13 +270,21 @@ def add_summary(self, step, summaries_dict=None, summaries_merged=None):
self.summary_writer.add_summary(summaries_merged, step)

@timeit
def load_train_data(self):
def load_train_data(self, v2=False):
print("Loading Training data..")
self.train_data = {'X': np.load(self.args.data_dir + "X_train.npy"),
'Y': np.load(self.args.data_dir + "Y_train.npy")}
if v2:
out_shape= (self.train_data['Y'].shape[1]//self.targets_resize,
self.train_data['Y'].shape[2]//self.targets_resize)
yy= np.zeros((self.train_data['Y'].shape[0],out_shape[0],out_shape[1]), dtype=self.train_data['Y'].dtype)
for y in range(self.train_data['Y'].shape[0]):
yy[y,...]= misc.imresize(self.train_data['Y'][y,...], out_shape, interp='nearest')
self.train_data['Y']=yy
self.train_data_len = self.train_data['X'].shape[0]
self.num_iterations_training_per_epoch = (
self.train_data_len + self.args.batch_size - 1) // self.args.batch_size

self.num_iterations_training_per_epoch = (self.train_data_len + self.args.batch_size - 1) // self.args.batch_size

print("Train-shape-x -- " + str(self.train_data['X'].shape) + " " + str(self.train_data_len))
print("Train-shape-y -- " + str(self.train_data['Y'].shape))
print("Num of iterations on training data in one epoch -- " + str(self.num_iterations_training_per_epoch))
Expand All @@ -285,6 +293,14 @@ def load_train_data(self):
print("Loading Validation data..")
self.val_data = {'X': np.load(self.args.data_dir + "X_val.npy"),
'Y': np.load(self.args.data_dir + "Y_val.npy")}
if v2:
out_shape= (self.val_data['Y'].shape[1]//self.targets_resize,
self.val_data['Y'].shape[2]//self.targets_resize)
yy= np.zeros((self.val_data['Y'].shape[0],out_shape[0],out_shape[1]), dtype=self.train_data['Y'].dtype)
for y in range(self.val_data['Y'].shape[0]):
yy[y,...]= misc.imresize(self.val_data['Y'][y,...], out_shape, interp='nearest')
self.val_data['Y']=yy

self.val_data_len = self.val_data['X'].shape[0] - self.val_data['X'].shape[0] % self.args.batch_size
self.num_iterations_validation_per_epoch = (
self.val_data_len + self.args.batch_size - 1) // self.args.batch_size
Expand Down Expand Up @@ -409,7 +425,6 @@ def train_h5_generator(self):

def train(self):
print("Training mode will begin NOW ..")

curr_lr = self.train_model.args.learning_rate
for cur_epoch in range(self.train_model.global_epoch_tensor.eval(self.sess) + 1, self.args.num_epochs + 1, 1):

Expand Down

0 comments on commit 9c5ec7d

Please sign in to comment.