Skip to content

Commit

Permalink
Remove contrib thread pool. (tensorflow#6175)
Browse files Browse the repository at this point in the history
* Remove contrib thread pool.

* Remove commented out contrib import.

* Fix lint issues.

* move tf.data.options higher. Tweak line breaks.
  • Loading branch information
tfboyd authored Feb 11, 2019
1 parent 27e8617 commit b6c0c7f
Show file tree
Hide file tree
Showing 6 changed files with 29 additions and 23 deletions.
20 changes: 9 additions & 11 deletions official/resnet/resnet_run_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
# pylint: disable=g-bad-import-order
from absl import flags
import tensorflow as tf
from tensorflow.contrib.data.python.ops import threadpool

from official.resnet import resnet_model
from official.utils.flags import core as flags_core
Expand Down Expand Up @@ -75,6 +74,15 @@ def process_record_dataset(dataset,
Returns:
Dataset of (image, label) pairs ready for iteration.
"""
# Defines a specific size thread pool for tf.data operations.
if datasets_num_private_threads:
options = tf.data.Options()
options.experimental_threading = tf.data.experimental.ThreadingOptions()
options.experimental_threading.private_threadpool_size = (
datasets_num_private_threads)
dataset = dataset.with_options(options)
tf.compat.v1.logging.info('datasets_num_private_threads: %s',
datasets_num_private_threads)

# Prefetches a batch at a time to smooth out the time taken to load input
# files for shuffling and processing.
Expand Down Expand Up @@ -102,16 +110,6 @@ def process_record_dataset(dataset,
# on how many devices are present.
dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)

# Defines a specific size thread pool for tf.data operations.
if datasets_num_private_threads:
tf.compat.v1.logging.info('datasets_num_private_threads: %s',
datasets_num_private_threads)
dataset = threadpool.override_threadpool(
dataset,
threadpool.PrivateThreadPool(
datasets_num_private_threads,
display_name='input_pipeline_thread_pool'))

return dataset


Expand Down
3 changes: 2 additions & 1 deletion official/utils/accelerator/tpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,8 @@ def host_call_fn(global_step, *args):
# expects [batch_size, ...] Tensors, thus reshape to introduce a batch
# dimension. These Tensors are implicitly concatenated to
# [params['batch_size']].
global_step_tensor = tf.reshape(tf.compat.v1.train.get_or_create_global_step(), [1])
global_step_tensor = tf.reshape(
tf.compat.v1.train.get_or_create_global_step(), [1])
other_tensors = [tf.reshape(metric_dict[key], [1]) for key in metric_names]

return host_call_fn, [global_step_tensor] + other_tensors
Expand Down
9 changes: 6 additions & 3 deletions official/utils/data/file_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ def _serialize_shards(df_shards, columns, pool, writer):
for example in s:
writer.write(example)


def write_to_buffer(dataframe, buffer_path, columns, expected_size=None):
"""Write a dataframe to a binary file for a dataset to consume.
Expand All @@ -169,7 +170,8 @@ def write_to_buffer(dataframe, buffer_path, columns, expected_size=None):
Returns:
The path of the buffer.
"""
if tf.io.gfile.exists(buffer_path) and tf.io.gfile.stat(buffer_path).length > 0:
if (tf.io.gfile.exists(buffer_path) and
tf.io.gfile.stat(buffer_path).length > 0):
actual_size = tf.io.gfile.stat(buffer_path).length
if expected_size == actual_size:
return buffer_path
Expand All @@ -184,7 +186,8 @@ def write_to_buffer(dataframe, buffer_path, columns, expected_size=None):

tf.io.gfile.makedirs(os.path.split(buffer_path)[0])

tf.compat.v1.logging.info("Constructing TFRecordDataset buffer: {}".format(buffer_path))
tf.compat.v1.logging.info("Constructing TFRecordDataset buffer: {}"
.format(buffer_path))

count = 0
pool = multiprocessing.Pool(multiprocessing.cpu_count())
Expand All @@ -195,7 +198,7 @@ def write_to_buffer(dataframe, buffer_path, columns, expected_size=None):
_serialize_shards(df_shards, columns, pool, writer)
count += sum([len(s) for s in df_shards])
tf.compat.v1.logging.info("{}/{} examples written."
.format(str(count).ljust(8), len(dataframe)))
.format(str(count).ljust(8), len(dataframe)))
finally:
pool.terminate()

Expand Down
6 changes: 4 additions & 2 deletions official/utils/logs/hooks_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,9 @@ def get_train_hooks(name_list, use_tpu=False, **kwargs):
return []

if use_tpu:
tf.compat.v1.logging.warning("hooks_helper received name_list `{}`, but a TPU is "
"specified. No hooks will be used.".format(name_list))
tf.compat.v1.logging.warning('hooks_helper received name_list `{}`, but a '
'TPU is specified. No hooks will be used.'
.format(name_list))
return []

train_hooks = []
Expand Down Expand Up @@ -142,6 +143,7 @@ def get_logging_metric_hook(tensors_to_log=None,
names. If not set, log _TENSORS_TO_LOG by default.
every_n_secs: `int`, the frequency for logging the metric. Default to every
10 mins.
**kwargs: a dictionary of arguments.
Returns:
Returns a LoggingMetricHook that saves tensor values in a JSON format.
Expand Down
4 changes: 2 additions & 2 deletions official/utils/misc/model_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,6 @@ def generate_synthetic_data(

def apply_clean(flags_obj):
  """Remove an existing model directory when the --clean flag is set.

  Args:
    flags_obj: A flags object exposing `clean` (bool) and `model_dir` (str);
      presumably an absl FlagValues instance — confirm against callers.
  """
  # Nothing to do unless the user asked for a clean run AND the dir exists.
  if not flags_obj.clean:
    return
  model_dir = flags_obj.model_dir
  if tf.io.gfile.exists(model_dir):
    # Adjacent string literals concatenate: message is byte-identical
    # to the original single log line.
    tf.compat.v1.logging.info(
        "--clean flag set. Removing existing model dir:"
        " {}".format(model_dir))
    tf.io.gfile.rmtree(model_dir)
10 changes: 6 additions & 4 deletions official/utils/testing/reference_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,10 +191,11 @@ def _construct_and_save_reference_files(

if correctness_function is not None:
results = correctness_function(*eval_results)
with tf.io.gfile.GFile(os.path.join(data_dir, "results.json"), "w") as f:
result_json = os.path.join(data_dir, "results.json")
with tf.io.gfile.GFile(result_json, "w") as f:
json.dump(results, f)

with tf.io.gfile.GFile(os.path.join(data_dir, "tf_version.json"), "w") as f:
tf_version_json = os.path.join(data_dir, "tf_version.json")
with tf.io.gfile.GFile(tf_version_json, "w") as f:
json.dump([tf.version.VERSION, tf.version.GIT_VERSION], f)

def _evaluate_test_case(self, name, graph, ops_to_eval, correctness_function):
Expand Down Expand Up @@ -262,7 +263,8 @@ def _evaluate_test_case(self, name, graph, ops_to_eval, correctness_function):
eval_results = [op.eval() for op in ops_to_eval]
if correctness_function is not None:
results = correctness_function(*eval_results)
with tf.io.gfile.GFile(os.path.join(data_dir, "results.json"), "r") as f:
result_json = os.path.join(data_dir, "results.json")
with tf.io.gfile.GFile(result_json, "r") as f:
expected_results = json.load(f)
self.assertAllClose(results, expected_results)

Expand Down

0 comments on commit b6c0c7f

Please sign in to comment.