forked from tensorflow/models
Showing 4 changed files with 633 additions and 0 deletions.
@@ -0,0 +1,42 @@ (a new Bazel BUILD file)

# Description:
#   Example classification model on Quick, Draw! dataset.

package(default_visibility = ["//visibility:public"])

licenses(["notice"])  # Apache 2.0

exports_files(["LICENSE"])

py_binary(
    name = "train_model",
    srcs = [
        "train_model.py",
    ],
    srcs_version = "PY2AND3",
    deps = [
        "//third_party/py/tensorflow",
    ],
)

py_binary(
    name = "create_dataset",
    srcs = [
        "create_dataset.py",
    ],
    deps = [
        "//third_party/py/numpy",
        "//third_party/py/tensorflow",
    ],
)

filegroup(
    name = "all_files",
    srcs = glob(
        ["**/*"],
        exclude = [
            "**/METADATA",
            "**/OWNERS",
        ],
    ),
    visibility = ["//third_party/tensorflow:__subpackages__"],
)
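For reference, the create_dataset target can be run through Bazel. This is a hypothetical invocation (the flag values are placeholders matching the usage example in the script's docstring; the package path is omitted because it is not shown in this diff):

  bazel run :create_dataset -- \
    --ndjson_path=${HOME}/ndjson \
    --output_path=${HOME}/tfrecord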
@@ -0,0 +1,211 @@ (create_dataset.py, a new file)
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

r"""Creates training and eval data from Quickdraw NDJSON files. | ||
This tool reads the NDJSON files from https://quickdraw.withgoogle.com/data | ||
and converts them into tensorflow.Example stored in TFRecord files. | ||
The tensorflow example will contain 3 features: | ||
shape - contains the shape of the sequence [length, dim] where dim=3. | ||
class_index - the class index of the class for the example. | ||
ink - a length * dim vector of the ink. | ||
It creates disjoint training and evaluation sets. | ||
python create_dataset.py \ | ||
--ndjson_path ${HOME}/ndjson \ | ||
--output_path ${HOME}/tfrecord | ||
""" | ||

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import json
import os
import random
import sys

import numpy as np
import tensorflow as tf


def parse_line(ndjson_line):
  """Parse an ndjson line and return ink (as np array) and classname."""
  sample = json.loads(ndjson_line)
  class_name = sample["word"]
  if not class_name:
    print("Empty classname")
    return None, None
  inkarray = sample["drawing"]
  if not inkarray:
    print("Empty inkarray")
    return None, None
  stroke_lengths = [len(stroke[0]) for stroke in inkarray]
  total_points = sum(stroke_lengths)
  np_ink = np.zeros((total_points, 3), dtype=np.float32)
  current_t = 0
  for stroke in inkarray:
    if len(stroke[0]) != len(stroke[1]):
      print("Inconsistent number of x and y coordinates.")
      return None, None
    for i in [0, 1]:
      np_ink[current_t:(current_t + len(stroke[0])), i] = stroke[i]
    current_t += len(stroke[0])
    np_ink[current_t - 1, 2] = 1  # stroke_end
  # Preprocessing.
  # 1. Size normalization: scale x and y into [0, 1].
  lower = np.min(np_ink[:, 0:2], axis=0)
  upper = np.max(np_ink[:, 0:2], axis=0)
  scale = upper - lower
  scale[scale == 0] = 1
  np_ink[:, 0:2] = (np_ink[:, 0:2] - lower) / scale
  # 2. Compute deltas between consecutive points and drop the first
  #    (absolute) point so only the offsets remain.
  np_ink[1:, 0:2] -= np_ink[0:-1, 0:2]
  np_ink = np_ink[1:, :]
  return np_ink, class_name
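
# A worked example with hypothetical input (not taken from the dataset):
# a single three-point stroke {"word": "cat", "drawing": [[[0, 5, 10],
# [0, 5, 0]]]} yields raw ink [[0, 0, 0], [5, 5, 0], [10, 0, 1]]. Size
# normalization maps it to [[0, 0, 0], [0.5, 1, 0], [1, 0, 1]], and delta
# encoding (dropping the first, absolute point) returns
# [[0.5, 1.0, 0.0], [0.5, -1.0, 1.0]] with shape (2, 3).
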
def convert_data(trainingdata_dir,
                 observations_per_class,
                 output_file,
                 classnames,
                 output_shards=10,
                 offset=0):
  """Convert training data from ndjson files into tf.Example in tf.Record.

  Args:
    trainingdata_dir: path to the directory containing the training data.
      The training data is stored in that directory as ndjson files.
    observations_per_class: the number of items to load per class.
    output_file: path where to write the output.
    classnames: array with classnames - is auto created if not passed in.
    output_shards: the number of shards to write the output in.
    offset: the number of items to skip at the beginning of each file.

  Returns:
    classnames: the class names as strings. classnames[classes[i]] is the
      textual representation of the class of the i-th data point.
  """

  def _pick_output_shard():
    return random.randint(0, output_shards - 1)

  file_handles = []
  # Open all input files.
  for filename in sorted(tf.gfile.ListDirectory(trainingdata_dir)):
    if not filename.endswith(".ndjson"):
      print("Skipping", filename)
      continue
    file_handles.append(
        tf.gfile.GFile(os.path.join(trainingdata_dir, filename), "r"))
    if offset:  # Fast forward all files to skip the offset.
      count = 0
      for _ in file_handles[-1]:
        count += 1
        if count == offset:
          break

  writers = []
  for i in range(output_shards):  # Use the parameter, not the global FLAGS.
    writers.append(
        tf.python_io.TFRecordWriter("%s-%05i-of-%05i" % (output_file, i,
                                                         output_shards)))

  # Visit the input files in random interleaved order so every output shard
  # contains a mix of all classes. list() is required under Python 3, where
  # a range object cannot be multiplied.
  reading_order = list(range(len(file_handles))) * observations_per_class
  random.shuffle(reading_order)

  for c in reading_order:
    ink = None
    # Advance to the next line on a parse failure instead of retrying the
    # same line forever.
    while ink is None:
      line = file_handles[c].readline()
      ink, class_name = parse_line(line)
      if ink is None:
        print("Couldn't parse ink from '" + line + "'.")
    if class_name not in classnames:
      classnames.append(class_name)
    features = {}
    features["class_index"] = tf.train.Feature(int64_list=tf.train.Int64List(
        value=[classnames.index(class_name)]))
    features["ink"] = tf.train.Feature(float_list=tf.train.FloatList(
        value=ink.flatten()))
    features["shape"] = tf.train.Feature(int64_list=tf.train.Int64List(
        value=ink.shape))
    f = tf.train.Features(feature=features)
    example = tf.train.Example(features=f)
    writers[_pick_output_shard()].write(example.SerializeToString())

  # Close all files.
  for w in writers:
    w.close()
  for f in file_handles:
    f.close()
  # Write the class list.
  with tf.gfile.GFile(output_file + ".classes", "w") as f:
    for class_name in classnames:
      f.write(class_name + "\n")
  return classnames
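
# The ".classes" file written above maps class indices back to names; a
# hypothetical read-back (variable names match those in convert_data):
#
#   with tf.gfile.GFile(output_file + ".classes", "r") as f:
#     classnames = [line.rstrip("\n") for line in f]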


def main(argv):
  del argv
  classnames = convert_data(
      FLAGS.ndjson_path,
      FLAGS.train_observations_per_class,
      os.path.join(FLAGS.output_path, "training.tfrecord"),
      classnames=[],
      output_shards=FLAGS.output_shards,
      offset=0)
  # Offset the eval set by the training set size so the two are disjoint.
  convert_data(
      FLAGS.ndjson_path,
      FLAGS.eval_observations_per_class,
      os.path.join(FLAGS.output_path, "eval.tfrecord"),
      classnames=classnames,
      output_shards=FLAGS.output_shards,
      offset=FLAGS.train_observations_per_class)


if __name__ == "__main__":
  parser = argparse.ArgumentParser()
  parser.register("type", "bool", lambda v: v.lower() == "true")
  parser.add_argument(
      "--ndjson_path",
      type=str,
      default="",
      help="Directory where the ndjson files are stored.")
  parser.add_argument(
      "--output_path",
      type=str,
      default="",
      help="Directory where to store the output TFRecord files.")
  parser.add_argument(
      "--train_observations_per_class",
      type=int,
      default=10000,
      help="How many items per class to load for training.")
  parser.add_argument(
      "--eval_observations_per_class",
      type=int,
      default=1000,
      help="How many items per class to load for evaluation.")
  parser.add_argument(
      "--output_shards",
      type=int,
      default=10,
      help="Number of shards for the output.")

  FLAGS, unparsed = parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
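For context, a minimal sketch of how records produced by this tool could be read back. The feature names and shapes mirror the writer above, but the snippet itself is an assumption (TF 1.x API), not part of this commit:

import tensorflow as tf

def parse_tfexample_fn(example_proto):
  """Decode one serialized tf.Example written by create_dataset.py."""
  # "ink" has variable length, so read it as a sparse feature and densify;
  # "shape" and "class_index" have fixed sizes.
  feature_spec = {
      "ink": tf.VarLenFeature(dtype=tf.float32),
      "shape": tf.FixedLenFeature([2], dtype=tf.int64),
      "class_index": tf.FixedLenFeature([1], dtype=tf.int64),
  }
  parsed = tf.parse_single_example(example_proto, feature_spec)
  ink = tf.sparse_tensor_to_dense(parsed["ink"])
  # Restore the [length, 3] layout from the flattened vector.
  ink = tf.reshape(ink, tf.cast(parsed["shape"], tf.int32))
  return ink, parsed["class_index"]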