Skip to content

Commit

Permalink
[tune] TF2.0 TensorBoard support (ray-project#5547)
Browse files: browse the repository at this point in the history
* Fix TensorBoard logging issue with TensorFlow 2.0

* tf2 support
  • Loading branch information
idthanm authored and ericl committed Aug 27, 2019
1 parent d206963 commit 52a6a1b
Showing 1 changed file with 60 additions and 23 deletions.
83 changes: 60 additions & 23 deletions python/ray/tune/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
logger = logging.getLogger(__name__)

tf = None
use_tf150_api = True
VALID_SUMMARY_TYPES = [int, float, np.float32, np.float64, np.int32]


class Logger(object):
Expand Down Expand Up @@ -135,34 +135,71 @@ def update_config(self, config):
cloudpickle.dump(self.config, f)


def to_tf_values(result, path):
if use_tf150_api:
type_list = [int, float, np.float32, np.float64, np.int32]
def tf2_compat_logger(config, logdir):
global tf
if "RLLIB_TEST_NO_TF_IMPORT" in os.environ:
logger.warning("Not importing TensorFlow for test purposes")
tf = None
raise RuntimeError("Not importing TensorFlow for test purposes")
else:
type_list = [int, float]
import tensorflow as tf
use_tf2_api = (distutils.version.LooseVersion(tf.__version__) >=
distutils.version.LooseVersion("1.14.0"))
if use_tf2_api:
tf = tf.compat.v2 # setting this for 1.14
return TF2Logger(config, logdir)
else:
return TFLogger(config, logdir)


class TF2Logger(Logger):
    """TensorBoard logger that writes through ``tf.summary``.

    Uses the module-level ``tf`` global, which must already refer to a
    TensorFlow namespace providing ``summary.create_file_writer`` and
    ``summary.scalar`` by the time ``_init`` runs.
    """

    def _init(self):
        """Create the summary file writer for ``self.logdir``."""
        self._file_writer = tf.summary.create_file_writer(self.logdir)

    def on_result(self, result):
        """Write the loggable values of ``result`` as scalar summaries.

        The step is ``result[TIMESTEPS_TOTAL]`` when that entry is present
        and truthy, otherwise ``result[TRAINING_ITERATION]``. Only values
        whose exact type is listed in ``VALID_SUMMARY_TYPES`` are written;
        tags are the flattened key prefixed with ``ray/tune/``.
        """
        # Entries that are not useful to log.
        skipped_keys = (
            "config", "pid", "timestamp", TIME_TOTAL_S, TRAINING_ITERATION)
        with tf.device("/CPU:0"), self._file_writer.as_default():
            step = result.get(TIMESTEPS_TOTAL)
            if not step:
                step = result[TRAINING_ITERATION]

            metrics = result.copy()
            for key in skipped_keys:
                metrics.pop(key, None)

            prefix = ["ray", "tune"]
            flattened = flatten_dict(metrics, delimiter="/")
            for name, value in flattened.items():
                # Exact type match on purpose: subclasses of the listed
                # types are skipped.
                if type(value) not in VALID_SUMMARY_TYPES:
                    continue
                tag = "/".join(prefix + [name])
                tf.summary.scalar(tag, value, step=step)
        self._file_writer.flush()

    def flush(self):
        """Flush the underlying summary file writer."""
        self._file_writer.flush()

    def close(self):
        """Close the underlying summary file writer."""
        self._file_writer.close()


def to_tf_values(result, path):
flat_result = flatten_dict(result, delimiter="/")
values = [
tf.Summary.Value(tag="/".join(path + [attr]), simple_value=value)
for attr, value in flat_result.items() if type(value) in type_list
for attr, value in flat_result.items()
if type(value) in VALID_SUMMARY_TYPES
]
return values


class TFLogger(Logger):
def _init(self):
try:
global tf, use_tf150_api
if "RLLIB_TEST_NO_TF_IMPORT" in os.environ:
logger.warning("Not importing TensorFlow for test purposes")
tf = None
else:
import tensorflow
tf = tensorflow
use_tf150_api = (distutils.version.LooseVersion(tf.VERSION) >=
distutils.version.LooseVersion("1.5.0"))
except ImportError:
logger.warning("Couldn't import TensorFlow - "
"disabling TensorBoard logging.")
logger.info(
"Initializing TFLogger instead of TF2Logger. We recommend "
"migrating to TF2.0. This class will be removed in the future.")
self._file_writer = tf.summary.FileWriter(self.logdir)

def on_result(self, result):
Expand Down Expand Up @@ -220,7 +257,7 @@ def close(self):
self._file.close()


DEFAULT_LOGGERS = (JsonLogger, CSVLogger, TFLogger)
DEFAULT_LOGGERS = (JsonLogger, CSVLogger, tf2_compat_logger)


class UnifiedLogger(Logger):
Expand Down Expand Up @@ -250,9 +287,9 @@ def _init(self):
for cls in self._logger_cls_list:
try:
self._loggers.append(cls(self.config, self.logdir))
except Exception:
logger.warning("Could not instantiate {} - skipping.".format(
str(cls)))
except Exception as exc:
logger.warning("Could not instantiate {}: {}.".format(
cls.__name__, str(exc)))
self._log_syncer = get_log_syncer(
self.logdir,
remote_dir=self.logdir,
Expand Down

0 comments on commit 52a6a1b

Please sign in to comment.