Skip to content

Commit

Permalink
A bag of fixes and consistency munging:
Browse files Browse the repository at this point in the history
* Support for axis / proba reduction in Estimator for prediction.
* Support for input_fn in prediction.
* Graph actions's eval using SessionManager (before was reinitializing variables and then restoring, which was leading to a number of issues). --supervisor more.
* Don't shuffle prediction data.
* Use tuple of one elements as a non streaming metrics (to not break existing usage for now).
* Different batch size doesn't lead to signature errors anymore.
* Various other small fixes.
Change: 123069738
  • Loading branch information
ilblackdragon authored and tensorflower-gardener committed May 24, 2016
1 parent 35204bc commit c80c9de
Show file tree
Hide file tree
Showing 9 changed files with 320 additions and 218 deletions.
4 changes: 2 additions & 2 deletions tensorflow/contrib/learn/python/learn/estimators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,9 +240,9 @@ def _predict(self, x, axis=-1, batch_size=None):
input_fn=predict_data_feeder.input_builder,
feed_fn=predict_data_feeder.get_feed_dict_fn())
if self.n_classes > 1 and axis != -1:
preds = preds['predictions'].argmax(axis=axis)
preds = preds.argmax(axis=axis)
else:
preds = preds['predictions']
preds = preds
return preds

def predict(self, x, axis=1, batch_size=None):
Expand Down
266 changes: 141 additions & 125 deletions tensorflow/contrib/learn/python/learn/estimators/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from tensorflow.contrib import framework as contrib_framework
from tensorflow.contrib import layers
from tensorflow.contrib import losses
from tensorflow.contrib import metrics as metrics_lib
from tensorflow.contrib.learn.python.learn import graph_actions
from tensorflow.contrib.learn.python.learn import monitors as monitors_lib
from tensorflow.contrib.learn.python.learn.estimators import _sklearn as sklearn
Expand All @@ -51,7 +52,7 @@
# Default metrics for evaluation.
_EVAL_METRICS = {
'regression': {
'mean_squared_error': losses.sum_of_squares,
'mean_squared_error': metrics_lib.streaming_mean_squared_error,
},
'classification': {
'logistic': losses.sigmoid_cross_entropy,
Expand All @@ -74,28 +75,15 @@ class ModeKeys(object):


def _get_input_fn(x, y, batch_size):
# TODO(ipoloshukin): Remove this when refactor of data_feeder is done
if hasattr(x, 'create_graph') and hasattr(y, 'create_graph'):
def input_fn():
return x.create_graph(), y.create_graph()
return input_fn, None

df = data_feeder.setup_train_data_feeder(x, y,
n_classes=None,
batch_size=batch_size)
df = data_feeder.setup_train_data_feeder(
x, y, n_classes=None, batch_size=batch_size)
return df.input_builder, df.get_feed_dict_fn()


def _get_predict_input_fn(x, batch_size):
# TODO(ipoloshukin): Remove this when refactor of data_feeder is done
if hasattr(x, 'create_graph'):
def input_fn():
return x.create_graph()
return input_fn, None

df = data_feeder.setup_train_data_feeder(x, None,
n_classes=None,
batch_size=batch_size, epochs=1)
def _get_predict_input_fn(x, y, batch_size):
df = data_feeder.setup_train_data_feeder(
x, y, n_classes=None, batch_size=batch_size,
shuffle=False, epochs=1)
return df.input_builder, df.get_feed_dict_fn()


Expand Down Expand Up @@ -147,78 +135,6 @@ def __init__(self, model_dir=None, config=None):

self._graph = None

@property
def model_dir(self):
return self._model_dir

@abc.abstractproperty
def _get_train_ops(self, features, targets):
"""Method that builds model graph and returns trainer ops.
Expected to be overriden by sub-classes that require custom support.
Args:
features: `Tensor` or `dict` of `Tensor` objects.
targets: `Tensor` or `dict` of `Tensor` objects.
Returns:
Tuple of train `Operation` and loss `Tensor`.
"""
pass

@abc.abstractproperty
def _get_predict_ops(self, features):
"""Method that builds model graph and returns prediction ops.
Args:
features: `Tensor` or `dict` of `Tensor` objects.
Returns:
predictions: `Tensor` or `dict` of `Tensor` objects.
"""
pass

def _get_eval_ops(self, features, targets, metrics):
"""Method that builds model graph and returns evaluation ops.
Args:
features: `Tensor` or `dict` of `Tensor` objects.
targets: `Tensor` or `dict` of `Tensor` objects.
metrics: `dict` of functions that take predictions and targets.
Returns:
metrics: `dict` of `Tensor` objects.
"""
predictions = self._get_predict_ops(features)
result = {}
for name, metric in six.iteritems(metrics):
result[name] = metric(predictions, targets)
return result

def _get_feature_ops_from_example(self, examples_batch):
"""Method that returns features given the batch of examples.
This method will be used to export model into a server.
Args:
examples_batch: batch of tf.Example
Returns:
features: `Tensor` or `dict` of `Tensor` objects.
"""
raise NotImplementedError('_get_feature_ops_from_example not implemented '
'in BaseEstimator')

def _get_default_metric_functions(self):
"""Method that provides default metric operations.
This functions is intented to be overridden by sub-classes.
Returns:
`dict` of functions that take predictions and targets `Tensor` objects and
return `Tensor`.
"""
return {}

def fit(self, x, y, steps, batch_size=32, monitors=None):
"""Trains a model given training data X and y.
Expand Down Expand Up @@ -296,7 +212,7 @@ def evaluate(self,
input_fn=None,
feed_fn=None,
batch_size=32,
steps=100,
steps=None,
metrics=None,
name=None):
"""Evaluates given model with provided evaluation data.
Expand Down Expand Up @@ -325,37 +241,98 @@ def evaluate(self,
raise ValueError('Either x and y or input_fn must be None.')
if input_fn is None:
assert x is not None
input_fn, feed_fn = _get_input_fn(x, y, batch_size)
input_fn, feed_fn = _get_predict_input_fn(x, y, batch_size)
return self._evaluate_model(input_fn=input_fn,
feed_fn=feed_fn,
steps=steps,
metrics=metrics,
name=name)

def predict(self, x, axis=None, batch_size=None):
def predict(self, x=None, input_fn=None, batch_size=None):
"""Returns predictions for given features.
Args:
x: features.
axis: Axis on which to argmax. (for classification).
input_fn: Input function. If set, x must be None.
batch_size: Override default batch size.
Returns:
Numpy array of predicted classes or regression values.
"""
return self._infer_model(x=x, batch_size=batch_size, axis=axis)
return self._infer_model(x=x, input_fn=input_fn,
batch_size=batch_size)

def predict_proba(self, x, batch_size=None):
"""Returns prediction probabilities for given features (classification).
@property
def model_dir(self):
return self._model_dir

@abc.abstractproperty
def _get_train_ops(self, features, targets):
"""Method that builds model graph and returns trainer ops.
Expected to be overriden by sub-classes that require custom support.
Args:
x: features.
batch_size: Override default batch size.
features: `Tensor` or `dict` of `Tensor` objects.
targets: `Tensor` or `dict` of `Tensor` objects.
Returns:
Numpy array of predicted probabilities.
Tuple of train `Operation` and loss `Tensor`.
"""
pass

@abc.abstractproperty
def _get_predict_ops(self, features):
"""Method that builds model graph and returns prediction ops.
Args:
features: `Tensor` or `dict` of `Tensor` objects.
Returns:
predictions: `Tensor` or `dict` of `Tensor` objects.
"""
pass

def _get_eval_ops(self, features, targets, metrics):
"""Method that builds model graph and returns evaluation ops.
Args:
features: `Tensor` or `dict` of `Tensor` objects.
targets: `Tensor` or `dict` of `Tensor` objects.
metrics: `dict` of functions that take predictions and targets.
Returns:
metrics: `dict` of `Tensor` objects.
"""
return self._infer_model(x=x, batch_size=batch_size, proba=True)
predictions = self._get_predict_ops(features)
result = {}
for name, metric in six.iteritems(metrics):
result[name] = metric(predictions, targets)
return result

def _get_feature_ops_from_example(self, examples_batch):
"""Method that returns features given the batch of examples.
This method will be used to export model into a server.
Args:
examples_batch: batch of tf.Example
Returns:
features: `Tensor` or `dict` of `Tensor` objects.
"""
raise NotImplementedError('_get_feature_ops_from_example not implemented '
'in BaseEstimator')

def _get_default_metric_functions(self):
"""Method that provides default metric operations.
This functions is intented to be overridden by sub-classes.
Returns:
`dict` of functions that take predictions and targets `Tensor` objects and
return `Tensor`.
"""
return {}

def _check_inputs(self, features, targets):
if self._features_info is not None:
Expand Down Expand Up @@ -450,6 +427,7 @@ def _extract_metric_update_ops(self, eval_dict):
logging.warning(
'Ignoring metric {}. It returned a list|tuple with len {}, '
'expected 2'.format(name, len(metric_ops)))
value_ops[name] = metric_ops
else:
value_ops[name] = metric_ops

Expand All @@ -469,7 +447,7 @@ def _evaluate_model(self,
if self._config.execution_mode not in ('all', 'evaluate', 'eval_evalset'):
return

checkpoint_path = saver.latest_checkpoint(self._model_dir)
checkpoint_path = self._model_dir
eval_dir = os.path.join(self._model_dir, 'eval' if not name else
'eval_' + name)
with ops.Graph().as_default() as g:
Expand All @@ -494,39 +472,42 @@ def _evaluate_model(self,

def _infer_model(self,
x=None, input_fn=None, feed_fn=None,
batch_size=None, axis=None, proba=False):
batch_size=None):
# Converts inputs into tf.DataFrame / tf.Series.
batch_size = -1 if batch_size is None else batch_size
if x is not None:
input_fn, feed_fn = _get_predict_input_fn(x, batch_size)
input_fn, feed_fn = _get_predict_input_fn(x, None, batch_size)

checkpoint_path = saver.latest_checkpoint(self._model_dir)
with ops.Graph().as_default() as g:
random_seed.set_random_seed(self._config.tf_random_seed)
contrib_framework.create_global_step(g)
features, _ = input_fn()
predictions = self._get_predict_ops(features)
return_dict = True
if not isinstance(predictions, dict):
predictions = {'predictions': predictions}
# TODO(ipolosukhin): Support batching
predictions, return_dict = {'predictions': predictions}, False
if feed_fn is None:
return infer(checkpoint_path, predictions)
preds = {}
while True:
try:
feed_dict = feed_fn()
except StopIteration:
break
if feed_dict is None:
break
outputs = infer(checkpoint_path, predictions, feed_dict=feed_dict)
for key in outputs:
if key not in preds:
preds[key] = []
preds[key].append(outputs[key])
for key in preds:
preds[key] = np.concatenate(preds[key], axis=0)
return preds
preds = infer(checkpoint_path, predictions)
else:
preds = {}
while True:
try:
feed_dict = feed_fn()
except StopIteration:
break
if feed_dict is None:
break
outputs = infer(checkpoint_path, predictions, feed_dict=feed_dict)
for key in outputs:
if key not in preds:
preds[key] = []
preds[key].append(outputs[key])
for key in preds:
preds[key] = np.concatenate(preds[key], axis=0)
if return_dict:
return preds
return preds['predictions']


class Estimator(BaseEstimator):
Expand Down Expand Up @@ -571,6 +552,41 @@ def __init__(self,
self.learning_rate = learning_rate
self.clip_gradients = clip_gradients

def predict(self, x=None, input_fn=None, axis=None, batch_size=None):
"""Returns predictions for given features.
Args:
x: features.
input_fn: Input function. If set, x must be None.
axis: Axis on which to argmax (for classification).
Last axis is used by default.
batch_size: Override default batch size.
Returns:
Numpy array of predicted classes or regression values.
"""
predictions = self._infer_model(x=x, input_fn=input_fn,
batch_size=batch_size)
if self._classification:
for key in predictions:
cur_axis = (len(predictions[key].shape) - 1) if axis is None else axis
predictions[key] = np.argmax(predictions[key], axis=cur_axis)
return predictions

def predict_proba(self, x=None, input_fn=None, batch_size=None):
"""Returns prediction probabilities for given features (classification).
Args:
x: features.
input_fn: Input function. If set, x and y must be None.
batch_size: Override default batch size.
Returns:
Numpy array of predicted probabilities.
"""
return self._infer_model(x=x, input_fn=input_fn,
batch_size=batch_size)

def _get_train_ops(self, features, targets):
"""Method that builds model graph and returns trainer ops.
Expand Down
Loading

0 comments on commit c80c9de

Please sign in to comment.