Skip to content

Commit

Permalink
Moved tpu_config.RunConfig check to inside of use_tpu block
Browse files Browse the repository at this point in the history
  • Loading branch information
terrytangyuan committed Jun 23, 2017
1 parent 2336cdf commit 4470dbe
Showing 1 changed file with 3 additions and 5 deletions.
8 changes: 3 additions & 5 deletions tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,27 +146,25 @@ def end(self, session):
class TpuEstimator(estimator_lib.Estimator):
"""Estimator with TPU support.
The only difference is a wrapped model_fn is set in the constructor.
The only difference is a wrapped model_fn is set in the constructor.
"""

def __init__(self,
model_fn=None,
model_dir=None,
config=None,
params=None,
use_tpu=True):
if use_tpu:
if not isinstance(config, tpu_config.RunConfig):
raise ValueError('`config` must be `tpu_config.RunConfig`')
model_function = wrapped_model_fn(model_fn, config)
else:
model_function = model_fn

super(TpuEstimator, self).__init__(
model_fn=model_function,
model_dir=model_dir,
config=config,
params=params)
if not isinstance(config, tpu_config.RunConfig):
raise ValueError('`config` must be `tpu_config.RunConfig`')

def _create_global_step(self, graph):
"""Creates a global step suitable for TPUs.
Expand Down

0 comments on commit 4470dbe

Please sign in to comment.