Add new test ID and test env info to the benchmark run. (tensorflow#4426)

* Add new test ID and test env info to the benchmark run.

* Fix test.

* Fix lint

* Address review comment.
qlzh727 authored Jun 1, 2018
1 parent 47c5642 commit d2d6ab4
Showing 9 changed files with 124 additions and 12 deletions.
official/boosted_trees/train_higgs.py (3 additions, 1 deletion)

@@ -210,7 +210,8 @@ def train_boosted_trees(flags_obj):
   benchmark_logger.log_run_info(
       model_name="boosted_trees",
       dataset_name="higgs",
-      run_params=run_params)
+      run_params=run_params,
+      test_id=flags_obj.benchmark_test_id)

   # Though BoostedTreesClassifier is under tf.estimator, faster in-memory
   # training is yet provided as a contrib library.
@@ -244,6 +245,7 @@ def main(_):
 def define_train_higgs_flags():
   """Add tree related flags as well as training/eval configuration."""
   flags_core.define_base(stop_threshold=False, batch_size=False, num_gpu=False)
+  flags_core.define_benchmark()
   flags.adopt_module_key_flags(flags_core)

   flags.DEFINE_integer(
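Beyond the extra argument to log_run_info, the second hunk is what makes the new flag reachable for this model at all: train_higgs.py had not been calling flags_core.define_benchmark(), so --benchmark_test_id (short name -bti) only becomes a valid flag here once that call is added. A hypothetical invocation (other required flags omitted for brevity):

  python official/boosted_trees/train_higgs.py --benchmark_test_id higgs_bt_run1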
official/recommendation/ncf_main.py (2 additions, 1 deletion)

@@ -247,7 +247,8 @@ def run_ncf(_):
   benchmark_logger.log_run_info(
       model_name="recommendation",
       dataset_name=FLAGS.dataset,
-      run_params=run_params)
+      run_params=run_params,
+      test_id=FLAGS.benchmark_test_id)

   # Training and evaluation cycle
   def train_input_fn():
official/resnet/resnet_run_loop.py (3 additions, 1 deletion)

@@ -395,8 +395,10 @@ def resnet_main(
       'synthetic_data': flags_obj.use_synthetic_data,
       'train_epochs': flags_obj.train_epochs,
   }
+
   benchmark_logger = logger.get_benchmark_logger()
-  benchmark_logger.log_run_info('resnet', dataset_name, run_params)
+  benchmark_logger.log_run_info('resnet', dataset_name, run_params,
+                                test_id=flags_obj.benchmark_test_id)

   train_hooks = hooks_helper.get_train_hooks(
       flags_obj.hooks,
official/transformer/transformer_main.py (2 additions, 1 deletion)

@@ -440,7 +440,8 @@ def run_transformer(flags_obj):
   benchmark_logger.log_run_info(
       model_name="transformer",
       dataset_name="wmt_translate_ende",
-      run_params=params.__dict__)
+      run_params=params.__dict__,
+      test_id=flags_obj.benchmark_test_id)

   # Train and evaluate transformer model
   estimator = tf.estimator.Estimator(
official/utils/flags/_benchmark.py (8 additions)

@@ -43,6 +43,14 @@ def define_benchmark(benchmark_log_dir=True, bigquery_uploader=True):
help=help_wrap("The type of benchmark logger to use. Defaults to using "
"BaseBenchmarkLogger which logs to STDOUT. Different "
"loggers will require other flags to be able to work."))
flags.DEFINE_string(
name="benchmark_test_id", short_name="bti", default=None,
help=help_wrap("The unique test ID of the benchmark run. It could be the "
"combination of key parameters. It is hardware "
"independent and could be used compare the performance "
"between different test runs. This flag is designed for "
"human consumption, and does not have any impact within "
"the system."))

if benchmark_log_dir:
flags.DEFINE_string(
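Every model that calls flags_core.define_benchmark() inherits the new flag with no further wiring. A minimal sketch of reading it from an absl app (the script name and flag value here are illustrative, not part of this commit):

  # example_main.py (hypothetical script)
  from absl import app
  from absl import flags

  from official.utils.flags import core as flags_core

  flags_core.define_benchmark()

  def main(_):
    flags_obj = flags.FLAGS
    # benchmark_test_id defaults to None unless --benchmark_test_id / -bti
    # is passed on the command line.
    print("benchmark_test_id:", flags_obj.benchmark_test_id)

  if __name__ == "__main__":
    app.run(main)

Invoked as, e.g., python example_main.py -bti wide_deep_batch40_run1, this prints the ID that the benchmark loggers will embed in the run info.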
official/utils/logs/cloud_lib.py (new file, 31 additions)

@@ -0,0 +1,31 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Utilities that interact with cloud service.
+"""
+
+import requests
+
+GCP_METADATA_URL = "http://metadata/computeMetadata/v1/instance/hostname"
+GCP_METADATA_HEADER = {"Metadata-Flavor": "Google"}
+
+
+def on_gcp():
+  """Detect whether the current running environment is on GCP"""
+  try:
+    response = requests.get(GCP_METADATA_URL, headers=GCP_METADATA_HEADER)
+    return response.status_code == 200
+  except requests.exceptions.RequestException:
+    return False
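One behavior worth flagging: the requests.get() call has no timeout, so in an environment where the "metadata" hostname resolves but the server never answers, on_gcp() can block instead of falling through to False. A defensive variant (a sketch only, not part of this commit; timeout_secs is a hypothetical parameter) would bound the probe:

  import requests

  GCP_METADATA_URL = "http://metadata/computeMetadata/v1/instance/hostname"
  GCP_METADATA_HEADER = {"Metadata-Flavor": "Google"}

  def on_gcp(timeout_secs=1.0):
    """Detect GCP by probing the metadata server with a bounded wait."""
    try:
      response = requests.get(
          GCP_METADATA_URL, headers=GCP_METADATA_HEADER, timeout=timeout_secs)
      return response.status_code == 200
    except requests.exceptions.RequestException:
      # Timeout is a subclass of RequestException, so a slow probe
      # returns False rather than hanging.
      return False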
official/utils/logs/cloud_lib_test.py (new file, 48 additions)

@@ -0,0 +1,48 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tests for cloud_lib."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import unittest
+
+import mock
+import requests
+
+from official.utils.logs import cloud_lib
+
+
+class CloudLibTest(unittest.TestCase):
+
+  @mock.patch("requests.get")
+  def test_on_gcp(self, mock_requests_get):
+    mock_response = mock.MagicMock()
+    mock_requests_get.return_value = mock_response
+    mock_response.status_code = 200
+
+    self.assertEqual(cloud_lib.on_gcp(), True)
+
+  @mock.patch("requests.get")
+  def test_not_on_gcp(self, mock_requests_get):
+    mock_requests_get.side_effect = requests.exceptions.ConnectionError()
+
+    self.assertEqual(cloud_lib.on_gcp(), False)
+
+
+if __name__ == "__main__":
+  unittest.main()
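With the repository root on PYTHONPATH and the mock package available (a separate install on Python 2), the new suite can be run directly; the invocation below is illustrative:

  python -m official.utils.logs.cloud_lib_test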
official/utils/logs/logger.py (25 additions, 7 deletions)

@@ -36,13 +36,17 @@
 import tensorflow as tf
 from tensorflow.python.client import device_lib

+from official.utils.logs import cloud_lib
+
 METRIC_LOG_FILE_NAME = "metric.log"
 BENCHMARK_RUN_LOG_FILE_NAME = "benchmark_run.log"
 _DATE_TIME_FORMAT_PATTERN = "%Y-%m-%dT%H:%M:%S.%fZ"
+GCP_TEST_ENV = "GCP"
 RUN_STATUS_SUCCESS = "success"
 RUN_STATUS_FAILURE = "failure"
 RUN_STATUS_RUNNING = "running"

+
 FLAGS = flags.FLAGS

 # Don't use it directly. Use get_benchmark_logger to access a logger.
@@ -141,9 +145,10 @@ def log_metric(self, name, value, unit=None, global_step=None, extras=None):
     if metric:
       tf.logging.info("Benchmark metric: %s", metric)

-  def log_run_info(self, model_name, dataset_name, run_params):
+  def log_run_info(self, model_name, dataset_name, run_params, test_id=None):
     tf.logging.info("Benchmark run: %s",
-                    _gather_run_info(model_name, dataset_name, run_params))
+                    _gather_run_info(model_name, dataset_name, run_params,
+                                     test_id))

   def on_finish(self, status):
     pass
@@ -183,7 +188,7 @@ def log_metric(self, name, value, unit=None, global_step=None, extras=None):
       tf.logging.warning("Failed to dump metric to log file: "
                          "name %s, value %s, error %s", name, value, e)

-  def log_run_info(self, model_name, dataset_name, run_params):
+  def log_run_info(self, model_name, dataset_name, run_params, test_id=None):
     """Collect most of the TF runtime information for the local env.

     The schema of the run info follows official/benchmark/datastore/schema.
@@ -193,8 +198,10 @@ def log_run_info(self, model_name, dataset_name, run_params):
       dataset_name: string, the name of dataset for training and evaluation.
       run_params: dict, the dictionary of parameters for the run, it could
         include hyperparameters or other params that are important for the run.
+      test_id: string, the unique name of the test run by the combination of key
+        parameters, eg batch size, num of GPU. It is hardware independent.
     """
-    run_info = _gather_run_info(model_name, dataset_name, run_params)
+    run_info = _gather_run_info(model_name, dataset_name, run_params, test_id)

     with tf.gfile.GFile(os.path.join(
         self._logging_dir, BENCHMARK_RUN_LOG_FILE_NAME), "w") as f:
@@ -251,7 +258,7 @@ def log_metric(self, name, value, unit=None, global_step=None, extras=None):
           self._run_id,
           [metric]))

-  def log_run_info(self, model_name, dataset_name, run_params):
+  def log_run_info(self, model_name, dataset_name, run_params, test_id=None):
     """Collect most of the TF runtime information for the local env.

     The schema of the run info follows official/benchmark/datastore/schema.
@@ -261,8 +268,10 @@ def log_run_info(self, model_name, dataset_name, run_params):
       dataset_name: string, the name of dataset for training and evaluation.
       run_params: dict, the dictionary of parameters for the run, it could
         include hyperparameters or other params that are important for the run.
+      test_id: string, the unique name of the test run by the combination of key
+        parameters, eg batch size, num of GPU. It is hardware independent.
     """
-    run_info = _gather_run_info(model_name, dataset_name, run_params)
+    run_info = _gather_run_info(model_name, dataset_name, run_params, test_id)
     # Starting new thread for bigquery upload in case it might take long time
     # and impact the benchmark and performance measurement. Starting a new
     # thread might have potential performance impact for model that run on CPU.
@@ -288,12 +297,13 @@ def on_finish(self, status):
         status))


-def _gather_run_info(model_name, dataset_name, run_params):
+def _gather_run_info(model_name, dataset_name, run_params, test_id):
   """Collect the benchmark run information for the local environment."""
   run_info = {
       "model_name": model_name,
       "dataset": {"name": dataset_name},
       "machine_config": {},
+      "test_id": test_id,
       "run_date": datetime.datetime.utcnow().strftime(
           _DATE_TIME_FORMAT_PATTERN)}
   _collect_tensorflow_info(run_info)
@@ -302,6 +312,7 @@ def _gather_run_info(model_name, dataset_name, run_params):
   _collect_cpu_info(run_info)
   _collect_gpu_info(run_info)
   _collect_memory_info(run_info)
+  _collect_test_environment(run_info)
   return run_info


@@ -403,6 +414,13 @@ def _collect_memory_info(run_info):
     tf.logging.warn("'psutil' not imported. Memory info will not be logged.")


+def _collect_test_environment(run_info):
+  """Detect the local environment, eg GCE, AWS or DGX, etc."""
+  if cloud_lib.on_gcp():
+    run_info["test_environment"] = GCP_TEST_ENV
+  # TODO(scottzhu): Add more testing env detection for other platform
+
+
 def _parse_gpu_model(physical_device_desc):
   # Assume all the GPU connected are same model
   for kv in physical_device_desc.split(","):
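Taken together, the logger.py changes thread test_id through all three logger implementations and tag GCP-hosted runs. A rough sketch of the resulting benchmark_run.log record (field values are illustrative; machine_config and the other fields are populated by the _collect_* helpers, and test_environment appears only when cloud_lib.on_gcp() returns True):

  {
    "model_name": "resnet",
    "dataset": {"name": "imagenet"},
    "machine_config": {"...": "..."},
    "test_id": "resnet50_synth_batch256",
    "test_environment": "GCP",
    "run_date": "2018-06-01T12:34:56.789012Z"
  }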
official/wide_deep/wide_deep.py (2 additions, 1 deletion)

@@ -246,7 +246,8 @@ def eval_input_fn():
   }

   benchmark_logger = logger.get_benchmark_logger()
-  benchmark_logger.log_run_info('wide_deep', 'Census Income', run_params)
+  benchmark_logger.log_run_info('wide_deep', 'Census Income', run_params,
+                                test_id=flags_obj.benchmark_test_id)

   loss_prefix = LOSS_PREFIX.get(flags_obj.model_type, '')
   train_hooks = hooks_helper.get_train_hooks(
