forked from caicloud/tensorflow-tutorial
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
perhapszzy
authored and
perhapszzy
committed
Jan 1, 2018
1 parent
ab35000
commit caa1095
Showing
9 changed files
with
1,586 additions
and
30 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
224 changes: 224 additions & 0 deletions
224
...Learning_with_TensorFlow/1.4.0/Chapter07/.ipynb_checkpoints/7. 数据集基本使用方法-checkpoint.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,224 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"metadata": { | ||
"collapsed": true | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"import tempfile\n", | ||
"import tensorflow as tf" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"#### 1. 从数组创建数据集。" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"1\n", | ||
"4\n", | ||
"9\n", | ||
"25\n", | ||
"64\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"input_data = [1, 2, 3, 5, 8]\n", | ||
"dataset = tf.data.Dataset.from_tensor_slices(input_data)\n", | ||
"\n", | ||
"# 定义迭代器。\n", | ||
"iterator = dataset.make_one_shot_iterator()\n", | ||
"\n", | ||
"# get_next() 返回代表一个输入数据的张量。\n", | ||
"x = iterator.get_next()\n", | ||
"y = x * x\n", | ||
"\n", | ||
"with tf.Session() as sess:\n", | ||
" for i in range(len(input_data)):\n", | ||
" print(sess.run(y))\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"#### 2. 读取文本文件里的数据。" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"File1, line1.\n", | ||
"File1, line2.\n", | ||
"File2, line1.\n", | ||
"File2, line2.\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"# 创建文本文件作为本例的输入。\n", | ||
"with open(\"./test1.txt\", \"w\") as file:\n", | ||
" file.write(\"File1, line1.\\n\") \n", | ||
" file.write(\"File1, line2.\\n\")\n", | ||
"with open(\"./test2.txt\", \"w\") as file:\n", | ||
" file.write(\"File2, line1.\\n\") \n", | ||
" file.write(\"File2, line2.\\n\")\n", | ||
"\n", | ||
"# 从文本文件创建数据集。这里可以提供多个文件。\n", | ||
"input_files = [\"./test1.txt\", \"./test2.txt\"]\n", | ||
"dataset = tf.data.TextLineDataset(input_files)\n", | ||
"\n", | ||
"# 定义迭代器。\n", | ||
"iterator = dataset.make_one_shot_iterator()\n", | ||
"\n", | ||
"# 这里get_next()返回一个字符串类型的张量,代表文件中的一行。\n", | ||
"x = iterator.get_next() \n", | ||
"with tf.Session() as sess:\n", | ||
" for i in range(4):\n", | ||
" print(sess.run(x))\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"#### 3. 解析TFRecord文件里的数据。读取文件为本章第一节创建的文件。" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"7\n", | ||
"3\n", | ||
"4\n", | ||
"6\n", | ||
"1\n", | ||
"8\n", | ||
"1\n", | ||
"0\n", | ||
"9\n", | ||
"8\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"# 解析一个TFRecord的方法。\n", | ||
"def parser(record):\n", | ||
" features = tf.parse_single_example(\n", | ||
" record,\n", | ||
" features={\n", | ||
" 'image_raw':tf.FixedLenFeature([],tf.string),\n", | ||
" 'pixels':tf.FixedLenFeature([],tf.int64),\n", | ||
" 'label':tf.FixedLenFeature([],tf.int64)\n", | ||
" })\n", | ||
" decoded_images = tf.decode_raw(features['image_raw'],tf.uint8)\n", | ||
" retyped_images = tf.cast(decoded_images, tf.float32)\n", | ||
" images = tf.reshape(retyped_images, [784])\n", | ||
" labels = tf.cast(features['label'],tf.int32)\n", | ||
" #pixels = tf.cast(features['pixels'],tf.int32)\n", | ||
" return images, labels\n", | ||
"\n", | ||
"# 从TFRecord文件创建数据集。这里可以提供多个文件。\n", | ||
"input_files = [\"output.tfrecords\"]\n", | ||
"dataset = tf.data.TFRecordDataset(input_files)\n", | ||
"\n", | ||
"# map()函数表示对数据集中的每一条数据进行调用解析方法。\n", | ||
"dataset = dataset.map(parser)\n", | ||
"\n", | ||
"# 定义遍历数据集的迭代器。\n", | ||
"iterator = dataset.make_one_shot_iterator()\n", | ||
"\n", | ||
"# 读取数据,可用于进一步计算\n", | ||
"image, label = iterator.get_next()\n", | ||
"\n", | ||
"with tf.Session() as sess:\n", | ||
" for i in range(10):\n", | ||
" x, y = sess.run([image, label]) \n", | ||
" print(y)\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"#### 4. 使用initializable_iterator来动态初始化数据集。" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 5, | ||
"metadata": { | ||
"collapsed": true | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"# 从TFRecord文件创建数据集,具体文件路径是一个placeholder,稍后再提供具体路径。\n", | ||
"input_files = tf.placeholder(tf.string)\n", | ||
"dataset = tf.data.TFRecordDataset(input_files)\n", | ||
"dataset = dataset.map(parser)\n", | ||
"\n", | ||
"# 定义遍历dataset的initializable_iterator。\n", | ||
"iterator = dataset.make_initializable_iterator()\n", | ||
"image, label = iterator.get_next()\n", | ||
"\n", | ||
"with tf.Session() as sess:\n", | ||
" # 首先初始化iterator,并给出input_files的值。\n", | ||
" sess.run(iterator.initializer,\n", | ||
" feed_dict={input_files: [\"output.tfrecords\"]})\n", | ||
" # 遍历所有数据一个epoch。当遍历结束时,程序会抛出OutOfRangeError。\n", | ||
" while True:\n", | ||
" try:\n", | ||
" x, y = sess.run([image, label])\n", | ||
" except tf.errors.OutOfRangeError:\n", | ||
" break \n" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 2", | ||
"language": "python", | ||
"name": "python2" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 2 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython2", | ||
"version": "2.7.13" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 1 | ||
} |
Oops, something went wrong.