Commit abfb572 by Branden Chan, committed Apr 15, 2019 (1 parent: 73e6e0c)
Showing 17 changed files with 2,082 additions and 1 deletion.
@@ -114,3 +114,7 @@ dmypy.json
 
 # Pyre type checker
 .pyre/
+
+.idea
+runs
+samples
@@ -0,0 +1,86 @@

FIRST LONG BERT
python3 run_pretraining.py \
  --input_file=${BERT_BASE_DIR}/training_data/tf_examples.tfrecord1 \
  --output_dir=${BERT_BASE_DIR}/${RUN_DIR} \
  --do_train=True \
  --do_eval=True \
  --bert_config_file=${BERT_BASE_DIR}/bert_config.json \
  --train_batch_size=1024 \
  --max_seq_length=128 \
  --max_predictions_per_seq=20 \
  --num_train_steps=500000 \
  --num_warmup_steps=10 \
  --learning_rate=1e-4 \
  --use_tpu=True \
  --tpu_name=bert \
  --save_checkpoints_steps=3500 \
  --iterations_per_loop=500 \
  --save_summary_steps=500

Observed throughput: 35K batches in 8:30 hrs, about 4K batches per hour.
At batch size 1024 that is roughly 4000 * 1024, i.e. about 4 million samples per hour.
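
A quick sanity check of that arithmetic (plain Python, not part of the repo):

```
# Back-of-the-envelope throughput check using the numbers logged above.
batches = 35_000       # batches completed
hours = 8.5            # wall-clock time (8:30 hrs)
batch_size = 1024      # samples per batch

batches_per_hour = batches / hours                  # ~4118
samples_per_hour = batches_per_hour * batch_size    # ~4.2 million
print(f"{batches_per_hour:,.0f} batches/hr, {samples_per_hour / 1e6:.1f}M samples/hr")
```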

--------------------------------------------------------------
--------------------------------------------------------------

Our Small Set

Approx. 4 million samples in one .tfrecord file.

python3 run_pretraining.py \
  --input_file=${BERT_BASE_DIR}/training_data/tf_examples.tfrecord1 \
  --output_dir=${BERT_BASE_DIR}/${RUN_DIR} \
  --do_train=True \
  --do_eval=True \
  --bert_config_file=${BERT_BASE_DIR}/bert_config.json \
  --train_batch_size=1024 \
  --max_seq_length=128 \
  --max_predictions_per_seq=20 \
  --num_train_steps=30000 \
  --num_warmup_steps=1000 \
  --learning_rate=2e-5 \
  --use_tpu=True \
  --tpu_name=bert \
  --save_checkpoints_steps=500 \
  --iterations_per_loop=50 \
  --save_summary_steps=50

--------------------------------------------------------------
--------------------------------------------------------------

TOY SET

python create_pretraining_data.py \
  --input_file=./samples/sample_text.txt \
  --output_file=./samples/sample_text.tfrecord \
  --vocab_file=./samples/vocab.txt \
  --do_lower_case=True \
  --max_seq_length=128 \
  --max_predictions_per_seq=20 \
  --masked_lm_prob=0.15 \
  --random_seed=12345 \
  --dupe_factor=20

python3 run_pretraining.py \
  --input_file=${BERT_BASE_DIR}/training_data/sample_text.tfrecord \
  --output_dir=${BERT_BASE_DIR}/${RUN_DIR} \
  --do_train=True \
  --do_eval=True \
  --bert_config_file=${BERT_BASE_DIR}/bert_config.json \
  --train_batch_size=32 \
  --max_seq_length=128 \
  --max_predictions_per_seq=20 \
  --num_train_steps=1000 \
  --num_warmup_steps=100 \
  --learning_rate=2e-5 \
  --use_tpu=True \
  --tpu_name=bert \
  --save_checkpoints_steps=200 \
  --iterations_per_loop=10 \
  --save_summary_steps=10
Empty file.
@@ -0,0 +1,34 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Utilities that interact with cloud services."""

import requests

GCP_METADATA_URL = "http://metadata/computeMetadata/v1/instance/hostname"
GCP_METADATA_HEADER = {"Metadata-Flavor": "Google"}


def on_gcp():
  """Detect whether the current running environment is on GCP."""
  try:
    # Time out in 5 seconds, in case the test environment has connectivity
    # issues. There is no default timeout, which means it might block forever.
    response = requests.get(
        GCP_METADATA_URL, headers=GCP_METADATA_HEADER, timeout=5)
    return response.status_code == 200
  except requests.exceptions.RequestException:
    return False
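
A minimal usage sketch (the calling code is hypothetical, not part of the commit):

```
# Hypothetical caller: gate cloud-specific behavior on the metadata check.
from official.utils.logs import cloud_lib

if cloud_lib.on_gcp():
    print("Running on GCP; metadata server answered within 5 seconds.")
else:
    print("Not on GCP, or the metadata server was unreachable.")
```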
@@ -0,0 +1,48 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Tests for cloud_lib."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import unittest

import mock
import requests

from official.utils.logs import cloud_lib


class CloudLibTest(unittest.TestCase):

  @mock.patch("requests.get")
  def test_on_gcp(self, mock_requests_get):
    mock_response = mock.MagicMock()
    mock_requests_get.return_value = mock_response
    mock_response.status_code = 200

    self.assertEqual(cloud_lib.on_gcp(), True)

  @mock.patch("requests.get")
  def test_not_on_gcp(self, mock_requests_get):
    mock_requests_get.side_effect = requests.exceptions.ConnectionError()

    self.assertEqual(cloud_lib.on_gcp(), False)


if __name__ == "__main__":
  unittest.main()
@@ -0,0 +1,58 @@
# Logging in official models

This library adds logging functions that print or save tensor values. Official models should define all common hooks
(using the hooks helper) and a benchmark logger.

1. **Training Hooks**

Hooks are a TensorFlow concept for running specific actions at certain points of execution. We use them to obtain and log
tensor values during training.

hooks_helper.py provides an easy way to create common hooks. The following hooks are currently defined (a sketch of a
typical call follows the list):

* LoggingTensorHook: Logs tensor values.
* ProfilerHook: Writes a timeline JSON file that can be loaded into chrome://tracing.
* ExamplesPerSecondHook: Logs the number of examples processed per second.
* LoggingMetricHook: Similar to LoggingTensorHook, except that the tensors are logged in a format defined by our data
analysis pipeline.
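
For illustration only, a call might look like the sketch below; the exact keyword arguments `get_train_hooks` accepts are assumptions here, forwarded to the individual hook constructors:

```
from official.utils.logs import hooks_helper

# Hypothetical invocation: hook names as listed above; model_dir and
# batch_size are assumed pass-through arguments for the hook constructors.
train_hooks = hooks_helper.get_train_hooks(
    ["LoggingTensorHook", "ExamplesPerSecondHook"],
    model_dir="/tmp/model",
    batch_size=128)
```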

2. **Benchmarks**

The benchmark logger provides useful functions for logging environment information and evaluation results.
The module also contains a context manager that is used to update the status of the run.

Example usage:

```
from absl import app as absl_app

from official.utils.logs import hooks_helper
from official.utils.logs import logger


def model_main(flags_obj):
  estimator = ...

  benchmark_logger = logger.get_benchmark_logger()
  benchmark_logger.log_run_info(...)

  train_hooks = hooks_helper.get_train_hooks(...)

  for epoch in range(10):
    estimator.train(..., hooks=train_hooks)
    eval_results = estimator.evaluate(...)

    # Log a dictionary of metrics
    benchmark_logger.log_evaluation_result(eval_results)

    # Log an individual metric
    benchmark_logger.log_metric(...)


def main(_):
  with logger.benchmark_context(flags.FLAGS):
    model_main(flags.FLAGS)


if __name__ == "__main__":
  # define flags
  absl_app.run(main)
```
@@ -0,0 +1,130 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Hook that counts examples per second every N steps or seconds."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf  # pylint: disable=g-bad-import-order

from official.utils.logs import logger


class ExamplesPerSecondHook(tf.estimator.SessionRunHook):
  """Hook to print out examples per second.

  Total time is tracked and divided by the total number of steps to get the
  average step time; batch_size is then used to determine the running average
  of examples per second. The examples per second for the most recent interval
  is also logged.
  """

  def __init__(self,
               batch_size,
               every_n_steps=None,
               every_n_secs=None,
               warm_steps=0,
               metric_logger=None):
    """Initializer for ExamplesPerSecondHook.

    Args:
      batch_size: Total batch size across all workers used to calculate
        examples/second from global time.
      every_n_steps: Log stats every n steps.
      every_n_secs: Log stats every n seconds. Exactly one of
        `every_n_steps` or `every_n_secs` should be set.
      warm_steps: The number of steps to be skipped before logging and running
        average calculation. warm_steps refers to global steps across all
        workers, not steps on each worker.
      metric_logger: instance of `BenchmarkLogger`, the benchmark logger that
        the hook should use to write the log. If None, a BaseBenchmarkLogger
        will be used.

    Raises:
      ValueError: if neither `every_n_steps` nor `every_n_secs` is set, or if
        both are set.
    """

    if (every_n_steps is None) == (every_n_secs is None):
      raise ValueError("Exactly one of every_n_steps"
                       " and every_n_secs should be provided.")

    self._logger = metric_logger or logger.BaseBenchmarkLogger()

    self._timer = tf.estimator.SecondOrStepTimer(
        every_steps=every_n_steps, every_secs=every_n_secs)

    self._step_train_time = 0
    self._total_steps = 0
    self._batch_size = batch_size
    self._warm_steps = warm_steps
    # List of examples per second logged every_n_steps.
    self.current_examples_per_sec_list = []

  def begin(self):
    """Called once before using the session to check the global step."""
    self._global_step_tensor = tf.compat.v1.train.get_global_step()
    if self._global_step_tensor is None:
      raise RuntimeError(
          "Global step should be created to use StepCounterHook.")

  def before_run(self, run_context):  # pylint: disable=unused-argument
    """Called before each call to run().

    Args:
      run_context: A SessionRunContext object.

    Returns:
      A SessionRunArgs object requesting the global step.
    """
    return tf.estimator.SessionRunArgs(self._global_step_tensor)

  def after_run(self, run_context, run_values):  # pylint: disable=unused-argument
    """Called after each call to run().

    Args:
      run_context: A SessionRunContext object.
      run_values: A SessionRunValues object.
    """
    global_step = run_values.results

    if self._timer.should_trigger_for_step(
        global_step) and global_step > self._warm_steps:
      elapsed_time, elapsed_steps = self._timer.update_last_triggered_step(
          global_step)
      if elapsed_time is not None:
        self._step_train_time += elapsed_time
        self._total_steps += elapsed_steps

        # Average examples per second is based on the total (cumulative)
        # training steps and training time so far.
        average_examples_per_sec = self._batch_size * (
            self._total_steps / self._step_train_time)
        # Current examples per second is based on the elapsed training steps
        # and training time for the most recent interval.
        current_examples_per_sec = self._batch_size * (
            elapsed_steps / elapsed_time)
        # Log entries to be read from the hook during or after the run.
        self.current_examples_per_sec_list.append(current_examples_per_sec)
        self._logger.log_metric(
            "average_examples_per_sec", average_examples_per_sec,
            global_step=global_step)

        self._logger.log_metric(
            "current_examples_per_sec", current_examples_per_sec,
            global_step=global_step)
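
A minimal wiring sketch for this hook (the estimator, `my_model_fn`, and `my_input_fn` are hypothetical, and the import path `official.utils.logs.hooks` is assumed for this file):

```
import tensorflow as tf

from official.utils.logs import hooks  # assumed import path for this module

# Log examples/sec every 100 global steps, skipping the first 500 warm-up steps.
examples_sec_hook = hooks.ExamplesPerSecondHook(
    batch_size=128,      # global batch size across all workers
    every_n_steps=100,
    warm_steps=500)

estimator = tf.estimator.Estimator(model_fn=my_model_fn)  # my_model_fn: hypothetical
estimator.train(input_fn=my_input_fn, hooks=[examples_sec_hook])  # my_input_fn: hypothetical
```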