Commit d3258fe

add quickdraw model files

vcarbune committed Dec 1, 2017
1 parent 0b2bc49

Showing 4 changed files with 633 additions and 0 deletions.
2 changes: 2 additions & 0 deletions tutorials/rnn/README.md
@@ -4,10 +4,12 @@ and use them are available in the tutorials.

* [RNN Tutorial](http://tensorflow.org/tutorials/recurrent/)
* [Sequence-to-Sequence Tutorial](http://tensorflow.org/tutorials/seq2seq/)
* [RNN Tutorial for Drawing Classification](http://tensorflow.org/tutorials/recurrent_quickdraw)

Here is a short overview of what is in this directory.

File | What's in it?
--- | ---
`ptb/` | PTB language model, see the [RNN Tutorial](http://tensorflow.org/tutorials/recurrent/)
`translate/` | Translation model, see the [Sequence-to-Sequence Tutorial](http://tensorflow.org/tutorials/seq2seq/)
`quickdraw/` | Quick, Draw! model, see the [RNN Tutorial for Drawing Classification](http://tensorflow.org/tutorials/recurrent_quickdraw)
42 changes: 42 additions & 0 deletions tutorials/rnn/quickdraw/BUILD
@@ -0,0 +1,42 @@
# Description:
#   Example classification model on the Quick, Draw! dataset.

package(default_visibility = ["//visibility:public"])

licenses(["notice"]) # Apache 2.0

exports_files(["LICENSE"])

py_binary(
    name = "train_model",
    srcs = [
        "train_model.py",
    ],
    srcs_version = "PY2AND3",
    deps = [
        "//third_party/py/tensorflow",
    ],
)

py_binary(
    name = "create_dataset",
    srcs = [
        "create_dataset.py",
    ],
    deps = [
        "//third_party/py/numpy",
        "//third_party/py/tensorflow",
    ],
)

filegroup(
    name = "all_files",
    srcs = glob(
        ["**/*"],
        exclude = [
            "**/METADATA",
            "**/OWNERS",
        ],
    ),
    visibility = ["//third_party/tensorflow:__subpackages__"],
)
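
For reference, these targets would typically be built and run with Bazel from the workspace root; a hypothetical invocation (the exact package path depends on where the workspace places this directory), using the flags defined in create_dataset.py below:

  bazel run //tutorials/rnn/quickdraw:create_dataset -- \
    --ndjson_path=${HOME}/ndjson \
    --output_path=${HOME}/tfrecord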
211 changes: 211 additions & 0 deletions tutorials/rnn/quickdraw/create_dataset.py
@@ -0,0 +1,211 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

r"""Creates training and eval data from Quickdraw NDJSON files.
This tool reads the NDJSON files from https://quickdraw.withgoogle.com/data
and converts them into tensorflow.Example stored in TFRecord files.
The tensorflow example will contain 3 features:
shape - contains the shape of the sequence [length, dim] where dim=3.
class_index - the class index of the class for the example.
ink - a length * dim vector of the ink.
It creates disjoint training and evaluation sets.
python create_dataset.py \
--ndjson_path ${HOME}/ndjson \
--output_path ${HOME}/tfrecord
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import json
import os
import random
import sys
import numpy as np
import tensorflow as tf


def parse_line(ndjson_line):
  """Parse an ndjson line and return ink (as np array) and classname."""
  sample = json.loads(ndjson_line)
  class_name = sample["word"]
  if not class_name:
    print("Empty classname")
    return None, None
  inkarray = sample["drawing"]
  if not inkarray:
    print("Empty inkarray")
    return None, None
  stroke_lengths = [len(stroke[0]) for stroke in inkarray]
  total_points = sum(stroke_lengths)
  np_ink = np.zeros((total_points, 3), dtype=np.float32)
  current_t = 0
  for stroke in inkarray:
    if len(stroke[0]) != len(stroke[1]):
      print("Inconsistent number of x and y coordinates.")
      return None, None
    for i in [0, 1]:
      np_ink[current_t:(current_t + len(stroke[0])), i] = stroke[i]
    current_t += len(stroke[0])
    np_ink[current_t - 1, 2] = 1  # stroke_end
  # Preprocessing.
  # 1. Size normalization.
  lower = np.min(np_ink[:, 0:2], axis=0)
  upper = np.max(np_ink[:, 0:2], axis=0)
  scale = upper - lower
  scale[scale == 0] = 1
  np_ink[:, 0:2] = (np_ink[:, 0:2] - lower) / scale
  # 2. Compute deltas.
  np_ink[1:, 0:2] -= np_ink[0:-1, 0:2]
  np_ink = np_ink[1:, :]
  return np_ink, class_name
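

# A minimal worked example (illustrative comment, not executed here): for the
# hypothetical one-stroke input '{"word": "line", "drawing": [[[0, 10], [5, 5]]]}',
# parse_line scales x into [0, 1], converts points to deltas, drops the first
# point, and returns (np.array([[1., 0., 1.]], dtype=np.float32), "line").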


def convert_data(trainingdata_dir,
                 observations_per_class,
                 output_file,
                 classnames,
                 output_shards=10,
                 offset=0):
  """Convert training data from ndjson files into tf.Example in tf.Record.

  Args:
    trainingdata_dir: path to the directory containing the training data.
      The training data is stored in that directory as ndjson files.
    observations_per_class: the number of items to load per class.
    output_file: path where to write the output.
    classnames: array with classnames - is auto created if not passed in.
    output_shards: the number of shards to write the output in.
    offset: the number of items to skip at the beginning of each file.

  Returns:
    classnames: the class names as strings. classnames[classes[i]] is the
      textual representation of the class of the i-th data point.
  """

  def _pick_output_shard():
    return random.randint(0, output_shards - 1)

  file_handles = []
  # Open all input files.
  for filename in sorted(tf.gfile.ListDirectory(trainingdata_dir)):
    if not filename.endswith(".ndjson"):
      print("Skipping", filename)
      continue
    file_handles.append(
        tf.gfile.GFile(os.path.join(trainingdata_dir, filename), "r"))
    if offset:  # Fast forward all files to skip the offset.
      count = 0
      for _ in file_handles[-1]:
        count += 1
        if count == offset:
          break

  # Each shard is named e.g. "training.tfrecord-00000-of-00010".
  writers = []
  for i in range(output_shards):
    writers.append(
        tf.python_io.TFRecordWriter("%s-%05i-of-%05i" % (output_file, i,
                                                         output_shards)))

  # Interleave the classes by shuffling a list of file indices in which each
  # file appears observations_per_class times.
  reading_order = list(range(len(file_handles))) * observations_per_class
  random.shuffle(reading_order)

  for c in reading_order:
    line = file_handles[c].readline()
    ink = None
    while ink is None:
      ink, class_name = parse_line(line)
      if ink is None:
        print("Couldn't parse ink from '" + line + "'.")
        line = file_handles[c].readline()
    if class_name not in classnames:
      classnames.append(class_name)
    features = {}
    features["class_index"] = tf.train.Feature(int64_list=tf.train.Int64List(
        value=[classnames.index(class_name)]))
    features["ink"] = tf.train.Feature(float_list=tf.train.FloatList(
        value=ink.flatten()))
    features["shape"] = tf.train.Feature(int64_list=tf.train.Int64List(
        value=ink.shape))
    f = tf.train.Features(feature=features)
    example = tf.train.Example(features=f)
    writers[_pick_output_shard()].write(example.SerializeToString())

  # Close all files.
  for w in writers:
    w.close()
  for f in file_handles:
    f.close()
  # Write the class list.
  with tf.gfile.GFile(output_file + ".classes", "w") as f:
    for class_name in classnames:
      f.write(class_name + "\n")
  return classnames


def main(argv):
  del argv
  classnames = convert_data(
      FLAGS.ndjson_path,
      FLAGS.train_observations_per_class,
      os.path.join(FLAGS.output_path, "training.tfrecord"),
      classnames=[],
      output_shards=FLAGS.output_shards,
      offset=0)
  # Eval data starts where the training data stopped, so the two sets drawn
  # from each ndjson file are disjoint.
  convert_data(
      FLAGS.ndjson_path,
      FLAGS.eval_observations_per_class,
      os.path.join(FLAGS.output_path, "eval.tfrecord"),
      classnames=classnames,
      output_shards=FLAGS.output_shards,
      offset=FLAGS.train_observations_per_class)


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.register("type", "bool", lambda v: v.lower() == "true")
parser.add_argument(
"--ndjson_path",
type=str,
default="",
help="Directory where the ndjson files are stored.")
parser.add_argument(
"--output_path",
type=str,
default="",
help="Directory where to store the output TFRecord files.")
parser.add_argument(
"--train_observations_per_class",
type=int,
default=10000,
help="How many items per class to load for training.")
parser.add_argument(
"--eval_observations_per_class",
type=int,
default=1000,
help="How many items per class to load for evaluation.")
parser.add_argument(
"--output_shards",
type=int,
default=10,
help="Number of shards for the output.")

FLAGS, unparsed = parser.parse_known_args()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
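
For reference, a minimal sketch of reading one generated shard back (not part of this commit; it assumes the same TF 1.x tf.python_io API used above and the shard name produced by the writer's naming pattern):

  import numpy as np
  import tensorflow as tf

  record_file = "training.tfrecord-00000-of-00010"
  for record in tf.python_io.tf_record_iterator(record_file):
    example = tf.train.Example.FromString(record)
    features = example.features.feature
    class_index = features["class_index"].int64_list.value[0]
    shape = features["shape"].int64_list.value
    ink_flat = features["ink"].float_list.value
    # Restore the [length, 3] ink matrix from the flattened floats.
    ink = np.array(ink_flat, dtype=np.float32).reshape(list(shape))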