Merge pull request sthalles#19 from sthalles/tf1.7
Tf1.7
sthalles authored May 8, 2018
2 parents 38b3bd1 + 142cb88 commit c83d4b6
Showing 5 changed files with 122 additions and 82 deletions.
27 changes: 19 additions & 8 deletions README.md
@@ -8,7 +8,7 @@ For a complete documentation of this implementation, check out the [blog post](h

- Python 3.x
- Numpy
- Tensorflow 1.4.0
- Tensorflow 1.7.0

## Downloads

@@ -19,14 +19,14 @@ Download the model checkpoints and dataset.
* [Option 1](https://mega.nz/#F!LlFCSaBB!1L_EoepUwhrHw4lHv1HRaA)
* [Option 2](http://www.mediafire.com/?wx7h526chc4ar)

## Training and Eval
Place the checkpoint files inside `./dataset/tfrecords`. If the folder **does not** exist, create it.
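
For example, from the repository root:

```
mkdir -p dataset/tfrecords
```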

Before training, create a folder named ```checkpoints/``` inside the ```resnet/``` directory. Download the pre-trained [resnet-50](https://arxiv.org/abs/1603.05027) or [resnet-101](https://arxiv.org/abs/1603.05027) models, and place the files inside ```checkpoints/```.
## Training and Eval

To train this model run:
Once you have the training and validation *TFRecords* files, just run the command below. Before running Deeplab_v3, the code will look for the proper ResNet checkpoints inside ```./resnet/checkpoints```; if the folder does not exist, they will first be **downloaded** (see `download_resnet_checkpoint_if_necessary` in *preprocessing/read_data.py* below).

```
python train.py --starting_learning_rate=0.00001 --batch_norm_decay=0.997 --gpu_id=0 --resnet_model=resnet_v2_50
python train.py --starting_learning_rate=0.00001 --batch_norm_decay=0.997 --crop_size=513 --gpu_id=0 --resnet_model=resnet_v2_50
```

Check out the *train.py* file for more input argument options. Each run produces a folder inside the *tboard_logs* directory (create it if it is not there).
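
To monitor training, you can point TensorBoard at that directory (assuming TensorBoard was installed alongside TensorFlow):

```
tensorboard --logdir=tboard_logs
```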
@@ -37,11 +37,22 @@ To evaluate the model, run the *test.py* file passing to it the *model_id* param
```
python test.py --model_id=16645
```

## Retraining

To use a different dataset, you just need to modify the ```CreateTfRecord.ipynb``` notebook inside the ```dataset/``` folder to suit your needs.

Also, be aware that Deeplab_v3 originally performs random crops of size *513x513* on the input images. This behavior can be configured through the *crop_size* hyper-parameter in **train.py**.
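
For example, to train with smaller crops (the value 321 below is only illustrative):

```
python train.py --crop_size=321 --gpu_id=0 --resnet_model=resnet_v2_50
```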

## Datasets

To create the dataset, first make sure you have the [Pascal VOC 2012](http://host.robots.ox.ac.uk/pascal/VOC/voc2012/) and the [Semantic Boundaries Dataset and Benchmark](http://home.bharathh.info/pubs/codes/SBD/download.html) datasets downloaded.
To create the dataset, first make sure you have the [Pascal VOC 2012](http://host.robots.ox.ac.uk/pascal/VOC/voc2012/) and/or the [Semantic Boundaries Dataset and Benchmark](http://home.bharathh.info/pubs/codes/SBD/download.html) datasets downloaded.

**Note: You do not need both datasets.**
- If you just want to test the code with one of the datasets (say the SBD), run the notebook normally, and it should work.

Afterwards, head to ```dataset/``` and run the ```CreateTfRecord.ipynb``` notebook.

After, head to ```dataset/``` and run the ```CreateTfRecord.ipynb``` notebook. The ```custom_train.txt``` file contains the name of the images selected for training. This file is designed to use the Pascal VOC 2012 set as a **TESTING** set. Therefore, it doesn't contain any images from the VOC 2012 val dataset. For more info, see the **Training** section of [Deeplab Image Semantic Segmentation Network](https://sthalles.github.io/deep_segmentation_network/).
The ```custom_train.txt``` file contains the names of the images selected for training. This file is designed to use the Pascal VOC 2012 set as a **TESTING** set. Therefore, it doesn't contain any images from the VOC 2012 val dataset. For more info, see the **Training** section of [Deeplab Image Semantic Segmentation Network](https://sthalles.github.io/deep_segmentation_network/).
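
The file is expected to list one image basename per line (the notebook appends the `.jpg`/`.png` extensions itself); a hypothetical excerpt, with names taken from the VOC naming scheme:

```
2007_000032
2007_000039
2007_000063
```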

Note: you can skip this step and directly download the datasets used in this experiment; see the **Downloads** section.

@@ -50,6 +61,6 @@
- Pixel accuracy: ~91%
- Mean Accuracy: ~82%
- Mean Intersection over Union (mIoU): ~74%
- Frequency weighed Intersection over Union: ~86.
- Frequency weighted Intersection over Union: ~86%

![Results](https://github.com/sthalles/sthalles.github.io/blob/master/assets/deep_segmentation_network/results1.png)
94 changes: 44 additions & 50 deletions dataset/CreateTfRecord.ipynb
@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -11,12 +11,19 @@
"import os\n",
"import scipy.io as spio\n",
"from matplotlib import pyplot as plt\n",
"from scipy.misc import imread"
"from imageio import imread"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Obs: If you only have one of the datasets (does not matter which one), just run all the notebook's cells and it will work just fine."
]
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -30,13 +37,13 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# define base paths for pascal augmented VOC images\n",
"# download: http://home.bharathh.info/pubs/codes/SBD/download.html\n",
"base_dataset_dir_aug_voc = '<path-to-aug-voc>/benchmark_RELEASE/dataset'\n",
"base_dataset_dir_aug_voc = '<pascal/augmented/VOC/images/path>/benchmark_RELEASE/dataset'\n",
"images_folder_name_aug_voc = \"img/\"\n",
"annotations_folder_name_aug_voc = \"cls/\"\n",
"images_dir_aug_voc = os.path.join(base_dataset_dir_aug_voc, images_folder_name_aug_voc)\n",
@@ -45,7 +52,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -60,25 +67,17 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Total number of training images: 8252\n"
]
}
],
"outputs": [],
"source": [
"images_filename_list = get_files_list(base_dataset_dir_aug_voc, images_folder_name_aug_voc, annotations_folder_name_aug_voc, \"custom_train.txt\")\n",
"print(\"Total number of training images:\", len(images_filename_list))"
]
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -90,30 +89,24 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"train set size: 7427\n",
"val set size: 825\n"
]
}
],
"outputs": [],
"source": [
"print(\"train set size:\", len(train_images_filename_list))\n",
"print(\"val set size:\", len(val_images_filename_list))"
]
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"TRAIN_DATASET_DIR=\"./dataset/\"\n",
"TRAIN_DATASET_DIR=\"./tfrecords/\"\n",
"if not os.path.exists(TRAIN_DATASET_DIR):\n",
" os.mkdir(TRAIN_DATASET_DIR)\n",
" \n",
"TRAIN_FILE = 'train.tfrecords'\n",
"VALIDATION_FILE = 'validation.tfrecords'\n",
"train_writer = tf.python_io.TFRecordWriter(os.path.join(TRAIN_DATASET_DIR,TRAIN_FILE))\n",
@@ -122,7 +115,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -135,7 +128,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -148,27 +141,36 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def create_tfrecord_dataset(filename_list, writer):\n",
"\n",
" # create training tfrecord\n",
" read_imgs_counter = 0\n",
" for i, image_name in enumerate(filename_list):\n",
"\n",
" try:\n",
" image_np = imread(os.path.join(images_dir_aug_voc, image_name.strip() + \".jpg\"))\n",
" except FileNotFoundError:\n",
" # read from Pascal VOC path\n",
" image_np = imread(os.path.join(images_dir_voc, image_name.strip() + \".jpg\"))\n",
" \n",
" try:\n",
" # read from Pascal VOC path\n",
" image_np = imread(os.path.join(images_dir_voc, image_name.strip() + \".jpg\"))\n",
" except FileNotFoundError:\n",
" print(\"File:\",image_name.strip(),\"not found.\")\n",
" continue\n",
" try:\n",
" annotation_np = read_annotation_from_mat_file(annotations_dir_aug_voc, image_name)\n",
" except FileNotFoundError:\n",
" # read from Pascal VOC path\n",
" annotation_np = imread(os.path.join(annotations_dir_voc, image_name.strip() + \".png\"))\n",
" \n",
" try:\n",
" annotation_np = imread(os.path.join(annotations_dir_voc, image_name.strip() + \".png\"))\n",
" except FileNotFoundError:\n",
" print(\"File:\",image_name.strip(),\"not found.\")\n",
" continue\n",
" \n",
" read_imgs_counter += 1\n",
" image_h = image_np.shape[0]\n",
" image_w = image_np.shape[1]\n",
"\n",
@@ -183,7 +185,7 @@
"\n",
" writer.write(example.SerializeToString())\n",
" \n",
" print(\"End of TfRecord. Total of image written:\", i)\n",
" print(\"End of TfRecord. Total of image written:\", read_imgs_counter)\n",
" writer.close()"
]
},
@@ -199,17 +201,9 @@
},
{
"cell_type": "code",
"execution_count": 18,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"End of TfRecord. Total of image written: 824\n"
]
}
],
"outputs": [],
"source": [
"# create validation dataset\n",
"create_tfrecord_dataset(val_images_filename_list, val_writer)"
@@ -239,7 +233,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.2"
"version": "3.6.5"
}
},
"nbformat": 4,
4 changes: 2 additions & 2 deletions network.py
@@ -21,7 +21,7 @@ def atrous_spatial_pyramid_pooling(net, scope, depth=256, reuse=None):
feature_map_size = tf.shape(net)

# apply global average pooling
image_level_features = tf.reduce_mean(net, [1, 2], name='image_level_global_pool', keep_dims=True)
image_level_features = tf.reduce_mean(net, [1, 2], name='image_level_global_pool', keepdims=True)
image_level_features = slim.conv2d(image_level_features, depth, [1, 1], scope="image_level_conv_1x1",
activation_fn=None)
image_level_features = tf.image.resize_bilinear(image_level_features, (feature_map_size[1], feature_map_size[2]))
@@ -45,7 +45,7 @@ def deeplab_v3(inputs, args, is_training, reuse):
# mean subtraction normalization
inputs = inputs - [_R_MEAN, _G_MEAN, _B_MEAN]

# inputs has shape [batch, 513, 513, 3]
# inputs has shape [batch, crop_size, crop_size, 3] (originally [batch, 513, 513, 3])
with slim.arg_scope(resnet_utils.resnet_arg_scope(args.l2_regularizer, is_training,
args.batch_norm_decay,
args.batch_norm_epsilon)):
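
The `keep_dims` → `keepdims` rename above is the TF 1.7 compatibility fix in *network.py*: newer TF 1.x releases deprecate the `keep_dims` spelling. A minimal sketch of the new argument (assuming a TF 1.x release that already ships `keepdims`):

```
import tensorflow as tf

# Global average pooling over height and width, keeping the tensor rank.
# Newer TF 1.x spells the argument keepdims; keep_dims is the deprecated alias.
net = tf.placeholder(tf.float32, [None, 33, 33, 256])
pooled = tf.reduce_mean(net, [1, 2], name='image_level_global_pool', keepdims=True)
print(pooled.shape)  # (?, 1, 1, 256)
```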
50 changes: 39 additions & 11 deletions preprocessing/read_data.py
@@ -1,7 +1,10 @@
import tensorflow as tf
from preprocessing.inception_preprocessing import apply_with_random_selector, distort_color
import urllib
import tarfile
import os

def random_flip_image_and_annotation(image_tensor, annotation_tensor, shapes):
def random_flip_image_and_annotation(image_tensor, annotation_tensor):
"""Accepts image tensor and annotation tensor and returns randomly flipped tensors of both.
The function performs random flip of image and annotation tensors with probability of 1/2
The flip is performed or not performed for image and annotation consistently, so that
@@ -41,10 +44,10 @@ def random_flip_image_and_annotation(image_tensor, annotation_tensor, shapes):
true_fn=lambda: tf.image.flip_left_right(annotation_tensor),
false_fn=lambda: annotation_tensor)

return randomly_flipped_img, tf.reshape(randomly_flipped_annotation, original_shape), shapes
return randomly_flipped_img, tf.reshape(randomly_flipped_annotation, original_shape)


def rescale_image_and_annotation_by_factor(image, annotation, shapes, nin_scale=0.5, max_scale=2):
def rescale_image_and_annotation_by_factor(image, annotation, nin_scale=0.5, max_scale=2):
# We apply data augmentation by randomly scaling the input images (from 0.5 to 2.0)
# and randomly left-right flipping during training.
input_shape = tf.shape(image)[0:2]
@@ -63,24 +66,49 @@ def rescale_image_and_annotation_by_factor(image, annotation, shapes, nin_scale=
annotation = tf.image.resize_images(annotation, scaled_input_shape,
method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

return image, annotation, shapes
return image, annotation


def scale_image_with_crop_padding(image, annotation, shapes):
image_croped = tf.image.resize_image_with_crop_or_pad(image,513,513)
def download_resnet_checkpoint_if_necessary(resnet_checkpoints_path, resnet_model_name):
"""
Check whether the resnet checkpoints are already downloaded; if not, download them.
:param resnet_checkpoints_path: string: path where the resnet checkpoint files should be found
:param resnet_model_name: one of resnet_v2_50 or resnet_v2_101
:return: None
"""
if not os.path.exists(resnet_checkpoints_path):
# create the path and download the resnet checkpoints
os.mkdir(resnet_checkpoints_path)

filename = resnet_model_name + "_2017_04_14.tar.gz"

url = "http://download.tensorflow.org/models/" + filename
full_file_path = os.path.join(resnet_checkpoints_path, filename)
urllib.request.urlretrieve(url, full_file_path)
thetarfile = tarfile.open(full_file_path, "r:gz")
thetarfile.extractall(path=resnet_checkpoints_path)
thetarfile.close()
print("Resnet:", resnet_model_name, "successfully downloaded.")
else:
print("ResNet checkpoints file successfully found.")


def scale_image_with_crop_padding(image, annotation, crop_size):

image_croped = tf.image.resize_image_with_crop_or_pad(image,crop_size,crop_size)

# Shift all the classes by one -- to be able to differentiate
# between zeros representing padded values and zeros representing
# a particular semantic class.
annotation_shifted_classes = annotation + 1

cropped_padded_annotation = tf.image.resize_image_with_crop_or_pad(annotation_shifted_classes,513,513)
cropped_padded_annotation = tf.image.resize_image_with_crop_or_pad(annotation_shifted_classes,crop_size,crop_size)

mask_out_number=255
annotation_additional_mask_out = tf.to_int32(tf.equal(cropped_padded_annotation, 0)) * (mask_out_number+1)
cropped_padded_annotation = cropped_padded_annotation + annotation_additional_mask_out - 1

return image_croped, tf.squeeze(cropped_padded_annotation), shapes
return image_croped, tf.squeeze(cropped_padded_annotation)

def tf_record_parser(record):
keys_to_features = {
@@ -103,9 +131,9 @@ def tf_record_parser(record):
annotation = tf.reshape(annotation, (height,width,1), name="annotation_reshape")
annotation = tf.to_int32(annotation)

return tf.to_float(image), annotation, (height, width)
return tf.to_float(image), annotation

def distort_randomly_image_color(image_tensor, annotation_tensor, shapes):
def distort_randomly_image_color(image_tensor, annotation_tensor):
"""Accepts image tensor of (width, height, 3) and returns color distorted image.
The function performs random brightness, saturation, hue, contrast change as it is performed
for inception model training in TF-Slim (you can find the link below in comments). All the
@@ -139,4 +167,4 @@ def distort_randomly_image_color(image_tensor, annotation_tensor, shapes):

img_float_distorted_original_range = distorted_image * 255

return img_float_distorted_original_range, annotation_tensor, shapes
return img_float_distorted_original_range, annotation_tensor
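
The class-shifting trick in `scale_image_with_crop_padding` above is easy to verify on a toy annotation. A minimal sketch, assuming TF 1.x (the class ids below are hypothetical):

```
import tensorflow as tf

# Toy 1-D "annotation": class 0 is background, 3 is a real class.
annotation = tf.constant([0, 3, 0, 0])
shifted = annotation + 1                  # [1, 4, 1, 1] -- no real pixel is 0 anymore
padded = tf.pad(shifted, [[0, 2]])        # [1, 4, 1, 1, 0, 0] -- zeros now mean padding only
mask_out = tf.to_int32(tf.equal(padded, 0)) * 256   # 256 = mask_out_number + 1
restored = padded + mask_out - 1          # padding becomes 255, real pixels are restored

with tf.Session() as sess:
    print(sess.run(restored))             # [  0   3   0   0 255 255]
```

Real pixels recover their original class, while padded pixels end up at 255 (`mask_out_number`), which the loss can later ignore.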