Add example with multiple readers plugged into TF (NVIDIA#58)
Signed-off-by: Janusz Lisiecki <[email protected]>
JanuszL authored and ptrendx committed Jul 24, 2018
1 parent f9c1008 commit eb4aa4b
Showing 1 changed file with 299 additions and 0 deletions.
299 changes: 299 additions & 0 deletions docs/examples/tensorflow/tensorflow-resnet50-various-readers.ipynb
@@ -0,0 +1,299 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Using Tensorflow DALI plugin: using various readers\n",
"\n",
"### Overview\n",
"\n",
"This example shows how different readers could be used to interact with Tensorflow. It shows how flexible DALI is.\n",
"\n",
"Following readers are used in this example:\n",
"\n",
"- MXNetReader\n",
"- FileReader\n",
"- TFRecordReader\n",
"\n",
"For details on how to use them please see other [examples](..)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Lets start from defining some global constants"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"# MXNet RecordIO\n",
"db_folder = \"/data/imagenet/train-480-val-256-recordio/\"\n",
"\n",
"# image dir with plain jpeg files\n",
"image_dir = \"../images\"\n",
"\n",
"# TFRecord\n",
"tfrecord = \"/data/imagenet/train-val-tfrecord-480/train-00001-of-01024\"\n",
"tfrecord_idx = \"idx_files/train-00001-of-01024.idx\"\n",
"tfrecord2idx_script = \"tfrecord2idx\"\n",
"\n",
"N = 4 # number of GPUs\n",
"BATCH_SIZE = 128 # batch size per GPU\n",
"ITERATIONS = 32\n",
"IMAGE_SIZE = 3"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Create idx file by calling `tfrecord2idx` script"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from subprocess import call\n",
"import os.path\n",
"\n",
"if not os.path.exists(\"idx_files\"):\n",
" os.mkdir(\"idx_files\")\n",
"\n",
"if not os.path.isfile(tfrecord_idx):\n",
" call([tfrecord2idx_script, tfrecord, tfrecord_idx])"
]
},
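{
"cell_type": "markdown",
"metadata": {},
"source": [
"The index file lets DALI seek directly to individual records inside the TFRecord file. As an optional sanity check, the sketch below prints its first few entries; it assumes `tfrecord2idx` produced a plain-text file of per-record offsets and sizes."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional sanity check (a sketch): peek at the generated index file.\n",
"# Assumption: each line stores the offset and size of one TFRecord entry.\n",
"if not os.path.isfile(tfrecord_idx):\n",
"    print(\"index file not found\")\n",
"else:\n",
"    with open(tfrecord_idx) as f:\n",
"        for line in f.readlines()[:5]:\n",
"            print(line.strip())"
]
},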
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let us define:\n",
"- common part of pipeline, other pipelines will inherit it"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"from nvidia.dali.pipeline import Pipeline\n",
"import nvidia.dali.ops as ops\n",
"import nvidia.dali.types as types\n",
"\n",
"class CommonPipeline(Pipeline):\n",
" def __init__(self, batch_size, num_threads, device_id):\n",
" super(CommonPipeline, self).__init__(batch_size, num_threads, device_id)\n",
"\n",
" self.decode = ops.nvJPEGDecoder(device = \"mixed\", output_type = types.RGB)\n",
" self.resize = ops.Resize(device = \"gpu\", random_resize = True,\n",
" resize_a = 256, resize_b = 480,\n",
" image_type = types.RGB,\n",
" interp_type = types.INTERP_LINEAR)\n",
" self.cmn = ops.CropMirrorNormalize(device = \"gpu\",\n",
" output_dtype = types.FLOAT,\n",
" crop = (227, 227),\n",
" image_type = types.RGB,\n",
" mean = [128., 128., 128.],\n",
" std = [1., 1., 1.])\n",
" self.uniform = ops.Uniform(range = (0.0, 1.0))\n",
"\n",
" def base_define_graph(self, inputs, labels):\n",
" images = self.decode(inputs)\n",
" images = self.resize(images)\n",
" output = self.cmn(images, crop_pos_x = self.uniform(),\n",
" crop_pos_y = self.uniform())\n",
" return (output, labels.gpu())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"- MXNetReaderPipeline"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"from nvidia.dali.pipeline import Pipeline\n",
"import nvidia.dali.ops as ops\n",
"import nvidia.dali.types as types\n",
"\n",
"class MXNetReaderPipeline(CommonPipeline):\n",
" def __init__(self, batch_size, num_threads, device_id, num_gpus):\n",
" super(MXNetReaderPipeline, self).__init__(batch_size, num_threads, device_id)\n",
" self.input = ops.MXNetReader(path = [db_folder+\"train.rec\"], index_path=[db_folder+\"train.idx\"],\n",
" random_shuffle = True, shard_id = device_id, num_shards = num_gpus)\n",
"\n",
" def define_graph(self):\n",
" images, labels = self.input(name=\"Reader\")\n",
" return self.base_define_graph(images, labels)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"- FileReadPipeline"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"class FileReadPipeline(CommonPipeline):\n",
" def __init__(self, batch_size, num_threads, device_id, num_gpus):\n",
" super(FileReadPipeline, self).__init__(batch_size, num_threads, device_id)\n",
" self.input = ops.FileReader(file_root = image_dir)\n",
"\n",
" def define_graph(self):\n",
" images, labels = self.input()\n",
" return self.base_define_graph(images, labels)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"- TFRecordPipeline"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"import nvidia.dali.tfrecord as tfrec\n",
"\n",
"class TFRecordPipeline(CommonPipeline):\n",
" def __init__(self, batch_size, num_threads, device_id, num_gpus):\n",
" super(TFRecordPipeline, self).__init__(batch_size, num_threads, device_id)\n",
" self.input = ops.TFRecordReader(path = tfrecord, \n",
" index_path = tfrecord_idx,\n",
" features = {\"image/encoded\" : tfrec.FixedLenFeature((), tfrec.string, \"\"),\n",
" \"image/class/label\": tfrec.FixedLenFeature([1], tfrec.int64, -1)\n",
" })\n",
"\n",
" def define_graph(self):\n",
" inputs = self.input()\n",
" images = inputs[\"image/encoded\"]\n",
" labels = inputs[\"image/class/label\"]\n",
" return self.base_define_graph(images, labels)"
]
},
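{
"cell_type": "markdown",
"metadata": {},
"source": [
"A note on the reader pipelines above: the `shard_id` and `num_shards` arguments of `MXNetReader` make each GPU read only its own shard of the dataset. The sketch below is illustrative plain Python (not DALI code), showing how a contiguous partition would split a dataset across `N` GPUs; the `epoch_size` value is just a placeholder."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative only (plain Python, not DALI): how shard_id / num_shards\n",
"# would partition a dataset contiguously across GPUs.\n",
"epoch_size = 1281167  # placeholder value, e.g. the ImageNet training set\n",
"for shard_id in range(N):\n",
"    begin = shard_id * epoch_size // N\n",
"    end = (shard_id + 1) * epoch_size // N\n",
"    print(\"GPU %d reads samples [%d, %d)\" % (shard_id, begin, end))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Each of these pipelines can also be built and run directly through DALI, without TensorFlow. Below is a minimal sketch for `FileReadPipeline`, assuming the standard DALI `build`/`run` API and readable images under `image_dir`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# A minimal standalone check (a sketch): run one DALI iteration without TF.\n",
"pipe = FileReadPipeline(BATCH_SIZE, 2, 0, 1)\n",
"pipe.build()\n",
"pipe_out = pipe.run()\n",
"print(pipe_out)"
]
},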
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let us create function which build serialized pipeline on demand:"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"import nvidia.dali.plugin.tf as dali_tf\n",
"\n",
"def get_batch_test_dali(batch_size, pipe_type):\n",
" pipes = [pipe_name(batch_size=batch_size, num_threads=2, device_id = device_id, num_gpus = N) for device_id in range(N)]\n",
"\n",
" serialized_pipes = [pipe.serialize() for pipe in pipes]\n",
" del pipes\n",
" daliop = dali_tf.DALIIterator()\n",
" images = []\n",
" labels = []\n",
" for d in range(N):\n",
" with tf.device('/gpu:%i' % d):\n",
" image, label = daliop(serialized_pipeline = serialized_pipes[d],\n",
" shape = [BATCH_SIZE, 3, 227, 227],\n",
" image_type = tf.int32,\n",
" label_type = tf.float32,\n",
" device_id = d)\n",
" images.append(image)\n",
" labels.append(label)\n",
"\n",
" return [images, labels]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"At the end let us test if all pipelines could be corectly build, serialized and run with TF session"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"RUN: FileReadPipeline\n",
"OK : FileReadPipeline\n",
"RUN: MXNetReaderPipeline\n",
"OK : MXNetReaderPipeline\n",
"RUN: TFRecordPipeline\n",
"OK : TFRecordPipeline\n"
]
}
],
"source": [
"pipe_types = [FileReadPipeline, MXNetReaderPipeline, TFRecordPipeline]\n",
"for pipe_name in pipe_types:\n",
" print (\"RUN: \" + pipe_name.__name__)\n",
" test_batch = get_batch_test_dali(BATCH_SIZE, pipe_name)\n",
" x = tf.placeholder(tf.float32, shape=[BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, 3], name='x')\n",
" gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.8)\n",
" config = tf.ConfigProto(gpu_options=gpu_options)\n",
"\n",
" with tf.Session(config=config) as sess:\n",
" for i in range(ITERATIONS):\n",
" sess.run(test_batch)\n",
" print(\"OK : \" + pipe_name.__name__)"
]
},
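{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a follow-up, the batches produced by DALI can be consumed by a regular TensorFlow graph. The sketch below is illustrative only: it reuses `get_batch_test_dali` with `FileReadPipeline` and the `config` object from the previous cell, and computes a per-GPU mean pixel value as a stand-in for a real model."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative sketch: consume the DALI batches in an actual TF computation;\n",
"# a per-GPU mean pixel value stands in for a real model here.\n",
"images, labels = get_batch_test_dali(BATCH_SIZE, FileReadPipeline)\n",
"means = [tf.reduce_mean(tf.cast(img, tf.float32)) for img in images]\n",
"\n",
"with tf.Session(config=config) as sess:\n",
"    print(sess.run(means))"
]
}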
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.15rc1"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
