Skip to content

Commit

Permalink
Added checkpoint_path for Estimator.predict()
Browse files Browse the repository at this point in the history
  • Loading branch information
terrytangyuan committed Mar 24, 2017
1 parent c7b80d5 commit 89621ee
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 3 deletions.
7 changes: 5 additions & 2 deletions tensorflow/python/estimator/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ def evaluate(self, input_fn, steps=None, hooks=None, checkpoint_path=None,
checkpoint_path=checkpoint_path,
name=name)

def predict(self, input_fn, predict_keys=None, hooks=None):
def predict(self, input_fn, predict_keys=None, hooks=None, checkpoint_path=None):
"""Returns predictions for given features.
Args:
Expand All @@ -282,6 +282,8 @@ def predict(self, input_fn, predict_keys=None, hooks=None):
`None`, returns all.
hooks: List of `SessionRunHook` subclass instances. Used for callbacks
inside the prediction call.
checkpoint_path: Path of a specific checkpoint to predict. If `None`, the
latest checkpoint in `model_dir` is used.
Yields:
Evaluated values of `predictions` tensors.
Expand All @@ -295,7 +297,8 @@ def predict(self, input_fn, predict_keys=None, hooks=None):
"""
hooks = _check_hooks_type(hooks)
# Check that model has been trained.
checkpoint_path = saver.latest_checkpoint(self._model_dir)
if not checkpoint_path:
checkpoint_path = saver.latest_checkpoint(self._model_dir)
if not checkpoint_path:
raise ValueError('Could not find trained model in model_dir: {}.'.format(
self._model_dir))
Expand Down
26 changes: 25 additions & 1 deletion tensorflow/python/estimator/estimator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,12 +597,17 @@ def _model_fn(features, labels, mode):

class EstimatorPredictTest(test.TestCase):

def test_no_trained_model(self):
def test_no_trained_model_in_model_dir(self):
est = estimator.Estimator(model_fn=model_fn_global_step_incrementer)
with self.assertRaisesRegexp(ValueError,
'Could not find trained model in model_dir'):
next(est.predict(dummy_input_fn))

def test_no_trained_model_invalid_checkpoint_path(self):
est = estimator.Estimator(model_fn=model_fn_global_step_incrementer)
with self.assertRaises(ValueError):
next(est.predict(dummy_input_fn, checkpoint_path=saver.latest_checkpoint("fakedir")))

def test_tensor_predictions(self):

def _model_fn(features, labels, mode):
Expand Down Expand Up @@ -807,6 +812,25 @@ def _model_fn(features, labels, mode):
est2 = estimator.Estimator(model_fn=_model_fn, model_dir=est1.model_dir)
self.assertEqual([32.], next(est2.predict(dummy_input_fn)))

def test_predict_from_checkpoint_path(self):

def _model_fn(features, labels, mode):
_, _ = features, labels
v = variables.Variable([[16.]], name='weight')
prediction = v * 2
return model_fn_lib.EstimatorSpec(
mode,
loss=constant_op.constant(0.),
train_op=constant_op.constant(0.),
predictions=prediction)

est1 = estimator.Estimator(model_fn=_model_fn)
est1.train(dummy_input_fn, steps=1)
est2 = estimator.Estimator(model_fn=_model_fn, model_dir=est1.model_dir)
self.assertEqual([32.], next(est2.predict(
dummy_input_fn,
checkpoint_path=saver.latest_checkpoint(est1.model_dir))))

def test_scaffold_is_used(self):

def _model_fn_scaffold(features, labels, mode):
Expand Down

0 comments on commit 89621ee

Please sign in to comment.