Skip to content

Commit

Permalink
Adapt code to use Tid2013 dataset and fix show_images util function.
Browse files Browse the repository at this point in the history
  • Loading branch information
Ricardo Ocampo committed Oct 6, 2019
1 parent 607044e commit 5d5e467
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 32 deletions.
104 changes: 73 additions & 31 deletions notebooks/train-diqa-base.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,24 @@
"outputs": [],
"source": [
"import imageio\n",
"import matplotlib.pyplot as plt\n",
"import numpy\n",
"import random\n",
"import pickle\n",
"import warnings\n",
"import tensorflow as tf\n",
"from tensorflow.python.keras.layers import Conv2D, GlobalAveragePooling2D, Dense\n",
"from utils import show_images, gaussian_filter, image_normalization, rescale, read_image"
"from tensorflow_core.python.keras.layers.pooling import GlobalAveragePooling2D\n",
"from tensorflow_core.python.layers.convolutional import Conv2D\n",
"from tensorflow_core.python.layers.core import Dense\n",
"from notebooks.utils import show_images, gaussian_filter, image_normalization, rescale, read_image\n",
"import imquality.datasets"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"pycharm": {
"is_executing": false,
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"print(f'tensorflow version {tf.__version__}')"
Expand All @@ -41,8 +45,40 @@
"source": [
"## Dataset\n",
"\n",
"The dataset that we are going to use to train and test this algorithm is [TID2013](http://www.ponomarenko.info/tid2013.htm).\n",
"It is comprised of 25 reference images, and 24 different distortions with 5 severy levels each."
"The dataset that we are going to use to train and test this algorithm is [LiveIQA](https://live.ece.utexas.edu/research/quality/subjective.htm).\n",
"It is comprised of 30 reference images, and 5 different distortions with 5 severity levels each.\n",
"\n",
"The first thing we need to do is to download the dataset. For this, I have created a couple of builders\n",
"for Image Quality datasets in the [image-quality](https://github.com/ocampor/image-quality) package. The builders\n",
"are an interface defined by tensorflow in [tensorflow-datasets](https://www.tensorflow.org/datasets) package. \n",
"\n",
"This process is going to take a couple of minutes because the dataset size is around 700 megabytes."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"pycharm": {
"is_executing": false,
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"builder = imquality.datasets.LiveIQA()\n",
"builder.download_and_prepare()"
]
},
{
"cell_type": "markdown",
"metadata": {
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"After downloading and preparing the data, we can turn the builder as a dataset and shuffle it. "
]
},
{
Expand All @@ -56,17 +92,20 @@
},
"outputs": [],
"source": [
"def get_image_url(image_idx: int, distortion: int, severity: int, base_uri=None) -> str:\n",
" if base_uri is None:\n",
" base_uri = 'https://data.ocampor.ai/image-quality/tid2013'\n",
" if severity == 0:\n",
" image_type = 'reference_images'\n",
" image_path = f'i{image_idx:02}.bmp'\n",
" else:\n",
" image_type = 'distorted_images'\n",
" image_path = f'i{image_idx:02}_{distortion:02}_{severity}.bmp'\n",
" \n",
" return f'{base_uri}/{image_type}/{image_path}'"
"ds = builder.as_dataset()['train']\n",
"ds = ds.shuffle(1024).batch(1).prefetch(1)"
]
},
{
"cell_type": "markdown",
"metadata": {
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"The output is a generator; therefore, we cannot access it unless we iterate in a for loop. In order to display an\n",
"image, I am iterating once to extract a sample. You can iterate this several times to understand the dataset."
]
},
{
Expand All @@ -80,11 +119,14 @@
},
"outputs": [],
"source": [
"for idx in (2, 8, 21):\n",
" images = []\n",
" for distortion in (3, 8, 10, 11):\n",
" images.append(imageio.imread(get_image_url(idx, distortion, 5)))\n",
" show_images(images)"
"for features in ds.take(1):\n",
" distorted_image = features['distorted_image']\n",
" reference_image = features['reference_image']\n",
" dmos = tf.round(features['dmos'][0], 2)\n",
" distortion = features['distortion'][0]\n",
" print(f'The distortion of the image is {dmos} with'\n",
" f' a distortion {distortion}')\n",
" show_images([reference_image, distorted_image])"
]
},
{
Expand Down Expand Up @@ -163,14 +205,14 @@
},
"outputs": [],
"source": [
"results = []\n",
"for severity in (1, 3, 5):\n",
" I = tf.convert_to_tensor(imageio.imread(get_image_url(2, 11, severity)))\n",
" I_d = image_preprocess(I)\n",
"for features in ds.take(1):\n",
" distorted_image = features['distorted_image']\n",
" reference_image = features['reference_image']\n",
" I_d = image_preprocess(distorted_image)\n",
" I_d = tf.image.grayscale_to_rgb(I_d)\n",
" results.append(image_normalization(I_d, 0, 1))\n",
" I_d = image_normalization(I_d, 0, 1)\n",
"\n",
"show_images(results)"
"show_images([reference_image, I_d])"
]
},
{
Expand Down
7 changes: 6 additions & 1 deletion notebooks/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import math
import typing

import tensorflow as tf
from matplotlib import pyplot as plt
Expand Down Expand Up @@ -60,9 +61,13 @@ def read_image(filename: str, **kwargs) -> tf.Tensor:
return tf.image.decode_image(stream, **kwargs)


def show_images(images: list, **kwargs):
def show_images(images: typing.List[tf.Tensor], **kwargs):
fig, axs = plt.subplots(1, len(images), figsize=(19, 10))
for image, ax in zip(images, axs):
assert image.get_shape().ndims in (3, 4), 'The tensor must be of dimension 3 or 4'
if image.get_shape().ndims == 4:
image = tf.squeeze(image)

_ = ax.imshow(image, **kwargs)
ax.axis('off')
fig.tight_layout()

0 comments on commit 5d5e467

Please sign in to comment.