Skip to content

Commit

Permalink
logging hook support added
Browse files Browse the repository at this point in the history
  • Loading branch information
Branden Chan committed Apr 15, 2019
1 parent 73e6e0c commit abfb572
Show file tree
Hide file tree
Showing 17 changed files with 2,082 additions and 1 deletion.
Binary file added .DS_Store
Binary file not shown.
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -114,3 +114,7 @@ dmypy.json

# Pyre type checker
.pyre/

.idea
runs
samples
86 changes: 86 additions & 0 deletions last_bert_run
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
FIRST LONG BERT

python3 run_pretraining.py \
--input_file=${BERT_BASE_DIR}/training_data/tf_examples.tfrecord1 \
--output_dir=${BERT_BASE_DIR}/${RUN_DIR} \
--do_train=True \
--do_eval=True \
--bert_config_file=${BERT_BASE_DIR}/bert_config.json \
--train_batch_size=1024 \
--max_seq_length=128 \
--max_predictions_per_seq=20 \
--num_train_steps=500000 \
--num_warmup_steps=10 \
--learning_rate=1e-4 \
--use_tpu=True \
--tpu_name=bert\
--save_checkpoints_steps=3500\
--iterations_per_loop=500\
--save_summary_steps=500


4000 * 1024
35K batches in 8:30 hrs
About 4k batches per hour
4 mill samples per hour

--------------------------------------------------------------
--------------------------------------------------------------

Our Small Set

approx 4million samples in one .tfrecords file


python3 run_pretraining.py \
--input_file=${BERT_BASE_DIR}/training_data/tf_examples.tfrecord1 \
--output_dir=${BERT_BASE_DIR}/${RUN_DIR} \
--do_train=True \
--do_eval=True \
--bert_config_file=${BERT_BASE_DIR}/bert_config.json \
--train_batch_size=1024 \
--max_seq_length=128 \
--max_predictions_per_seq=20 \
--num_train_steps=30000 \
--num_warmup_steps=1000 \
--learning_rate=2e-5 \
--use_tpu=True \
--tpu_name=bert\
--save_checkpoints_steps=500\
--iterations_per_loop=50\
--save_summary_steps=50


--------------------------------------------------------------
--------------------------------------------------------------

TOY SET

python create_pretraining_data.py \
--input_file=./samples/sample_text.txt \
--output_file=./samples/sample_text.tfrecord \
--vocab_file=./samples/vocab.txt \
--do_lower_case=True \
--max_seq_length=128 \
--max_predictions_per_seq=20 \
--masked_lm_prob=0.15 \
--random_seed=12345 \
--dupe_factor=20

python3 run_pretraining.py \
--input_file=${BERT_BASE_DIR}/training_data/sample_text.tfrecord \
--output_dir=${BERT_BASE_DIR}/${RUN_DIR} \
--do_train=True \
--do_eval=True \
--bert_config_file=${BERT_BASE_DIR}/bert_config.json \
--train_batch_size=32 \
--max_seq_length=128 \
--max_predictions_per_seq=20 \
--num_train_steps=1000 \
--num_warmup_steps=100 \
--learning_rate=2e-5 \
--use_tpu=True \
--tpu_name=bert\
--save_checkpoints_steps=200\
--iterations_per_loop=10\
--save_summary_steps=10
Empty file added logs/__init__.py
Empty file.
34 changes: 34 additions & 0 deletions logs/cloud_lib.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Utilities that interact with cloud service.
"""

import requests

GCP_METADATA_URL = "http://metadata/computeMetadata/v1/instance/hostname"
GCP_METADATA_HEADER = {"Metadata-Flavor": "Google"}


def on_gcp():
"""Detect whether the current running environment is on GCP."""
try:
# Timeout in 5 seconds, in case the test environment has connectivity issue.
# There is not default timeout, which means it might block forever.
response = requests.get(
GCP_METADATA_URL, headers=GCP_METADATA_HEADER, timeout=5)
return response.status_code == 200
except requests.exceptions.RequestException:
return False
48 changes: 48 additions & 0 deletions logs/cloud_lib_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Tests for cloud_lib."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import unittest

import mock
import requests

from official.utils.logs import cloud_lib


class CloudLibTest(unittest.TestCase):

@mock.patch("requests.get")
def test_on_gcp(self, mock_requests_get):
mock_response = mock.MagicMock()
mock_requests_get.return_value = mock_response
mock_response.status_code = 200

self.assertEqual(cloud_lib.on_gcp(), True)

@mock.patch("requests.get")
def test_not_on_gcp(self, mock_requests_get):
mock_requests_get.side_effect = requests.exceptions.ConnectionError()

self.assertEqual(cloud_lib.on_gcp(), False)


if __name__ == "__main__":
unittest.main()
58 changes: 58 additions & 0 deletions logs/guidelines.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Logging in official models

This library adds logging functions that print or save tensor values. Official models should define all common hooks
(using hooks helper) and a benchmark logger.

1. **Training Hooks**

Hooks are a TensorFlow concept that define specific actions at certain points of the execution. We use them to obtain and log
tensor values during training.

hooks_helper.py provides an easy way to create common hooks. The following hooks are currently defined:
* LoggingTensorHook: Logs tensor values
* ProfilerHook: Writes a timeline json that can be loaded into chrome://tracing.
* ExamplesPerSecondHook: Logs the number of examples processed per second.
* LoggingMetricHook: Similar to LoggingTensorHook, except that the tensors are logged in a format defined by our data
anaylsis pipeline.


2. **Benchmarks**

The benchmark logger provides useful functions for logging environment information, and evaluation results.
The module also contains a context which is used to update the status of the run.

Example usage:

```
from absl import app as absl_app
from official.utils.logs import hooks_helper
from official.utils.logs import logger
def model_main(flags_obj):
estimator = ...
benchmark_logger = logger.get_benchmark_logger()
benchmark_logger.log_run_info(...)
train_hooks = hooks_helper.get_train_hooks(...)
for epoch in range(10):
estimator.train(..., hooks=train_hooks)
eval_results = estimator.evaluate(...)
# Log a dictionary of metrics
benchmark_logger.log_evaluation_result(eval_results)
# Log an individual metric
benchmark_logger.log_metric(...)
def main(_):
with logger.benchmark_context(flags.FLAGS):
model_main(flags.FLAGS)
if __name__ == "__main__":
# define flags
absl_app.run(main)
```
130 changes: 130 additions & 0 deletions logs/hooks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Hook that counts examples per second every N steps or seconds."""


from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf # pylint: disable=g-bad-import-order

from official.utils.logs import logger


class ExamplesPerSecondHook(tf.estimator.SessionRunHook):
"""Hook to print out examples per second.
Total time is tracked and then divided by the total number of steps
to get the average step time and then batch_size is used to determine
the running average of examples per second. The examples per second for the
most recent interval is also logged.
"""

def __init__(self,
batch_size,
every_n_steps=None,
every_n_secs=None,
warm_steps=0,
metric_logger=None):
"""Initializer for ExamplesPerSecondHook.
Args:
batch_size: Total batch size across all workers used to calculate
examples/second from global time.
every_n_steps: Log stats every n steps.
every_n_secs: Log stats every n seconds. Exactly one of the
`every_n_steps` or `every_n_secs` should be set.
warm_steps: The number of steps to be skipped before logging and running
average calculation. warm_steps steps refers to global steps across all
workers, not on each worker
metric_logger: instance of `BenchmarkLogger`, the benchmark logger that
hook should use to write the log. If None, BaseBenchmarkLogger will
be used.
Raises:
ValueError: if neither `every_n_steps` or `every_n_secs` is set, or
both are set.
"""

if (every_n_steps is None) == (every_n_secs is None):
raise ValueError("exactly one of every_n_steps"
" and every_n_secs should be provided.")

self._logger = metric_logger or logger.BaseBenchmarkLogger()

self._timer = tf.estimator.SecondOrStepTimer(
every_steps=every_n_steps, every_secs=every_n_secs)

self._step_train_time = 0
self._total_steps = 0
self._batch_size = batch_size
self._warm_steps = warm_steps
# List of examples per second logged every_n_steps.
self.current_examples_per_sec_list = []

def begin(self):
"""Called once before using the session to check global step."""
self._global_step_tensor = tf.compat.v1.train.get_global_step()
if self._global_step_tensor is None:
raise RuntimeError(
"Global step should be created to use StepCounterHook.")

def before_run(self, run_context): # pylint: disable=unused-argument
"""Called before each call to run().
Args:
run_context: A SessionRunContext object.
Returns:
A SessionRunArgs object or None if never triggered.
"""
return tf.estimator.SessionRunArgs(self._global_step_tensor)

def after_run(self, run_context, run_values): # pylint: disable=unused-argument
"""Called after each call to run().
Args:
run_context: A SessionRunContext object.
run_values: A SessionRunValues object.
"""
global_step = run_values.results

if self._timer.should_trigger_for_step(
global_step) and global_step > self._warm_steps:
elapsed_time, elapsed_steps = self._timer.update_last_triggered_step(
global_step)
if elapsed_time is not None:
self._step_train_time += elapsed_time
self._total_steps += elapsed_steps

# average examples per second is based on the total (accumulative)
# training steps and training time so far
average_examples_per_sec = self._batch_size * (
self._total_steps / self._step_train_time)
# current examples per second is based on the elapsed training steps
# and training time per batch
current_examples_per_sec = self._batch_size * (
elapsed_steps / elapsed_time)
# Logs entries to be read from hook during or after run.
self.current_examples_per_sec_list.append(current_examples_per_sec)
self._logger.log_metric(
"average_examples_per_sec", average_examples_per_sec,
global_step=global_step)

self._logger.log_metric(
"current_examples_per_sec", current_examples_per_sec,
global_step=global_step)
Loading

0 comments on commit abfb572

Please sign in to comment.