Skip to content

Commit

Permalink
Merge pull request caicloud#90 from perhapszzy/master
Browse files Browse the repository at this point in the history
add more examples
  • Loading branch information
perhapszzy authored Jan 1, 2018
2 parents 37412de + caa1095 commit b303814
Show file tree
Hide file tree
Showing 9 changed files with 1,586 additions and 30 deletions.
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ Deep_Learning_with_TensorFlow/datasets/flower_processed_data.npy
Deep_Learning_with_TensorFlow/1.4.0/Chapter05/5. MNIST\346\234\200\344\275\263\345\256\236\350\267\265/MNIST_model/*
Deep_Learning_with_TensorFlow/1.4.0/Chapter07/output.tfrecords
Deep_Learning_with_TensorFlow/1.4.0/Chapter07/data.tfrecords*
Deep_Learning_with_TensorFlow/1.4.0/Chapter07/output_test.tfrecords
Deep_Learning_with_TensorFlow/1.4.0/Chapter07/test1.txt
Deep_Learning_with_TensorFlow/1.4.0/Chapter07/test2.txt
Deep_Learning_with_TensorFlow/1.4.0/Chapter08/sin.png
Deep_Learning_with_TensorFlow/1.4.0/Chapter10/log/*
Deep_Learning_with_TensorFlow/1.4.0/Chapter11/log/*
.DS_Store
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@
"Extracting ../../datasets/MNIST_data/train-labels-idx1-ubyte.gz\n",
"Extracting ../../datasets/MNIST_data/t10k-images-idx3-ubyte.gz\n",
"Extracting ../../datasets/MNIST_data/t10k-labels-idx1-ubyte.gz\n",
"TFRecord文件已保存。\n"
"TFRecord训练文件已保存。\n",
"TFRecord测试文件已保存。\n"
]
}
],
Expand All @@ -45,27 +46,43 @@
"def _bytes_feature(value):\n",
" return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))\n",
"\n",
"# 读取mnist数据。\n",
"# 将数据转化为tf.train.Example格式。\n",
"def _make_example(pixels, label, image):\n",
" image_raw = image.tostring()\n",
" example = tf.train.Example(features=tf.train.Features(feature={\n",
" 'pixels': _int64_feature(pixels),\n",
" 'label': _int64_feature(np.argmax(label)),\n",
" 'image_raw': _bytes_feature(image_raw)\n",
" }))\n",
" return example\n",
"\n",
"# 读取mnist训练数据。\n",
"mnist = input_data.read_data_sets(\"../../datasets/MNIST_data\",dtype=tf.uint8, one_hot=True)\n",
"images = mnist.train.images\n",
"labels = mnist.train.labels\n",
"pixels = images.shape[1]\n",
"num_examples = mnist.train.num_examples\n",
"\n",
"# 输出TFRecord文件的地址。\n",
"filename = \"output.tfrecords\"\n",
"writer = tf.python_io.TFRecordWriter(filename)\n",
"for index in range(num_examples):\n",
" image_raw = images[index].tostring()\n",
"# 输出包含训练数据的TFRecord文件。\n",
"with tf.python_io.TFRecordWriter(\"output.tfrecords\") as writer:\n",
" for index in range(num_examples):\n",
" example = _make_example(pixels, labels[index], images[index])\n",
" writer.write(example.SerializeToString())\n",
"print(\"TFRecord训练文件已保存。\")\n",
"\n",
" example = tf.train.Example(features=tf.train.Features(feature={\n",
" 'pixels': _int64_feature(pixels),\n",
" 'label': _int64_feature(np.argmax(labels[index])),\n",
" 'image_raw': _bytes_feature(image_raw)\n",
" }))\n",
" writer.write(example.SerializeToString())\n",
"writer.close()\n",
"print \"TFRecord文件已保存。\""
"# 读取mnist测试数据。\n",
"images_test = mnist.test.images\n",
"labels_test = mnist.test.labels\n",
"pixels_test = images_test.shape[1]\n",
"num_examples_test = mnist.test.num_examples\n",
"\n",
"# 输出包含测试数据的TFRecord文件。\n",
"with tf.python_io.TFRecordWriter(\"output_test.tfrecords\") as writer:\n",
" for index in range(num_examples_test):\n",
" example = _make_example(\n",
" pixels_test, labels_test[index], images_test[index])\n",
" writer.write(example.SerializeToString())\n",
"print(\"TFRecord测试文件已保存。\")"
]
},
{
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,224 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import tempfile\n",
"import tensorflow as tf"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### 1. 从数组创建数据集。"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1\n",
"4\n",
"9\n",
"25\n",
"64\n"
]
}
],
"source": [
"input_data = [1, 2, 3, 5, 8]\n",
"dataset = tf.data.Dataset.from_tensor_slices(input_data)\n",
"\n",
"# 定义迭代器。\n",
"iterator = dataset.make_one_shot_iterator()\n",
"\n",
"# get_next() 返回代表一个输入数据的张量。\n",
"x = iterator.get_next()\n",
"y = x * x\n",
"\n",
"with tf.Session() as sess:\n",
" for i in range(len(input_data)):\n",
" print(sess.run(y))\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### 2. 读取文本文件里的数据。"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"File1, line1.\n",
"File1, line2.\n",
"File2, line1.\n",
"File2, line2.\n"
]
}
],
"source": [
"# 创建文本文件作为本例的输入。\n",
"with open(\"./test1.txt\", \"w\") as file:\n",
" file.write(\"File1, line1.\\n\") \n",
" file.write(\"File1, line2.\\n\")\n",
"with open(\"./test2.txt\", \"w\") as file:\n",
" file.write(\"File2, line1.\\n\") \n",
" file.write(\"File2, line2.\\n\")\n",
"\n",
"# 从文本文件创建数据集。这里可以提供多个文件。\n",
"input_files = [\"./test1.txt\", \"./test2.txt\"]\n",
"dataset = tf.data.TextLineDataset(input_files)\n",
"\n",
"# 定义迭代器。\n",
"iterator = dataset.make_one_shot_iterator()\n",
"\n",
"# 这里get_next()返回一个字符串类型的张量,代表文件中的一行。\n",
"x = iterator.get_next() \n",
"with tf.Session() as sess:\n",
" for i in range(4):\n",
" print(sess.run(x))\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### 3. 解析TFRecord文件里的数据。读取文件为本章第一节创建的文件。"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"7\n",
"3\n",
"4\n",
"6\n",
"1\n",
"8\n",
"1\n",
"0\n",
"9\n",
"8\n"
]
}
],
"source": [
"# 解析一个TFRecord的方法。\n",
"def parser(record):\n",
" features = tf.parse_single_example(\n",
" record,\n",
" features={\n",
" 'image_raw':tf.FixedLenFeature([],tf.string),\n",
" 'pixels':tf.FixedLenFeature([],tf.int64),\n",
" 'label':tf.FixedLenFeature([],tf.int64)\n",
" })\n",
" decoded_images = tf.decode_raw(features['image_raw'],tf.uint8)\n",
" retyped_images = tf.cast(decoded_images, tf.float32)\n",
" images = tf.reshape(retyped_images, [784])\n",
" labels = tf.cast(features['label'],tf.int32)\n",
" #pixels = tf.cast(features['pixels'],tf.int32)\n",
" return images, labels\n",
"\n",
"# 从TFRecord文件创建数据集。这里可以提供多个文件。\n",
"input_files = [\"output.tfrecords\"]\n",
"dataset = tf.data.TFRecordDataset(input_files)\n",
"\n",
"# map()函数表示对数据集中的每一条数据进行调用解析方法。\n",
"dataset = dataset.map(parser)\n",
"\n",
"# 定义遍历数据集的迭代器。\n",
"iterator = dataset.make_one_shot_iterator()\n",
"\n",
"# 读取数据,可用于进一步计算\n",
"image, label = iterator.get_next()\n",
"\n",
"with tf.Session() as sess:\n",
" for i in range(10):\n",
" x, y = sess.run([image, label]) \n",
" print(y)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### 4. 使用initializable_iterator来动态初始化数据集。"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# 从TFRecord文件创建数据集,具体文件路径是一个placeholder,稍后再提供具体路径。\n",
"input_files = tf.placeholder(tf.string)\n",
"dataset = tf.data.TFRecordDataset(input_files)\n",
"dataset = dataset.map(parser)\n",
"\n",
"# 定义遍历dataset的initializable_iterator。\n",
"iterator = dataset.make_initializable_iterator()\n",
"image, label = iterator.get_next()\n",
"\n",
"with tf.Session() as sess:\n",
" # 首先初始化iterator,并给出input_files的值。\n",
" sess.run(iterator.initializer,\n",
" feed_dict={input_files: [\"output.tfrecords\"]})\n",
" # 遍历所有数据一个epoch。当遍历结束时,程序会抛出OutOfRangeError。\n",
" while True:\n",
" try:\n",
" x, y = sess.run([image, label])\n",
" except tf.errors.OutOfRangeError:\n",
" break \n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.13"
}
},
"nbformat": 4,
"nbformat_minor": 1
}
Loading

0 comments on commit b303814

Please sign in to comment.