Skip to content

Commit

Permalink
Amend notes on eager compatibility for Estimator
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 200581494
  • Loading branch information
martinwicke authored and tensorflower-gardener committed Jun 14, 2018
1 parent a4cadda commit e1b0ceb
Show file tree
Hide file tree
Showing 6 changed files with 65 additions and 8 deletions.
14 changes: 14 additions & 0 deletions tensorflow/python/estimator/canned/baseline.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,13 @@ def input_fn_eval: # returns x, y (where y represents label's class index).
* if `weight_column` is not `None`, a feature with
`key=weight_column` whose value is a `Tensor`.
@compatibility(eager)
Estimators can be used while eager execution is enabled. Note that `input_fn`
and all hooks are executed inside a graph context, so they have to be written
to be compatible with graph mode. Note that `input_fn` code using `tf.data`
generally works in both graph and eager modes.
@end_compatibility
"""

def __init__(self,
Expand Down Expand Up @@ -313,6 +320,13 @@ def input_fn_eval: # returns x, y (where y is the label).
* if `weight_column` is not `None`, a feature with
`key=weight_column` whose value is a `Tensor`.
@compatibility(eager)
Estimators can be used while eager execution is enabled. Note that `input_fn`
and all hooks are executed inside a graph context, so they have to be written
to be compatible with graph mode. Note that `input_fn` code using `tf.data`
generally works in both graph and eager modes.
@end_compatibility
"""

def __init__(self,
Expand Down
20 changes: 18 additions & 2 deletions tensorflow/python/estimator/canned/boosted_trees.py
Original file line number Diff line number Diff line change
Expand Up @@ -714,7 +714,15 @@ def _create_regression_head(label_dimension, weight_column=None):

@estimator_export('estimator.BoostedTreesClassifier')
class BoostedTreesClassifier(estimator.Estimator):
"""A Classifier for Tensorflow Boosted Trees models."""
"""A Classifier for Tensorflow Boosted Trees models.
@compatibility(eager)
Estimators can be used while eager execution is enabled. Note that `input_fn`
and all hooks are executed inside a graph context, so they have to be written
to be compatible with graph mode. Note that `input_fn` code using `tf.data`
generally works in both graph and eager modes.
@end_compatibility
"""

def __init__(self,
feature_columns,
Expand Down Expand Up @@ -832,7 +840,15 @@ def _model_fn(features, labels, mode, config):

@estimator_export('estimator.BoostedTreesRegressor')
class BoostedTreesRegressor(estimator.Estimator):
"""A Regressor for Tensorflow Boosted Trees models."""
"""A Regressor for Tensorflow Boosted Trees models.
@compatibility(eager)
Estimators can be used while eager execution is enabled. Note that `input_fn`
and all hooks are executed inside a graph context, so they have to be written
to be compatible with graph mode. Note that `input_fn` code using `tf.data`
generally works in both graph and eager modes.
@end_compatibility
"""

def __init__(self,
feature_columns,
Expand Down
10 changes: 8 additions & 2 deletions tensorflow/python/estimator/canned/dnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,10 @@ def input_fn_predict: # returns x, None
Loss is calculated by using softmax cross entropy.
@compatibility(eager)
Estimators are not compatible with eager execution.
Estimators can be used while eager execution is enabled. Note that `input_fn`
and all hooks are executed inside a graph context, so they have to be written
to be compatible with graph mode. Note that `input_fn` code using `tf.data`
generally works in both graph and eager modes.
@end_compatibility
"""

Expand Down Expand Up @@ -418,7 +421,10 @@ def input_fn_predict: # returns x, None
Loss is calculated by using mean squared error.
@compatibility(eager)
Estimators are not compatible with eager execution.
Estimators can be used while eager execution is enabled. Note that `input_fn`
and all hooks are executed inside a graph context, so they have to be written
to be compatible with graph mode. Note that `input_fn` code using `tf.data`
generally works in both graph and eager modes.
@end_compatibility
"""

Expand Down
10 changes: 8 additions & 2 deletions tensorflow/python/estimator/canned/dnn_linear_combined.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,10 @@ def input_fn_predict: # returns x, None
Loss is calculated by using softmax cross entropy.
@compatibility(eager)
Estimators are not compatible with eager execution.
Estimators can be used while eager execution is enabled. Note that `input_fn`
and all hooks are executed inside a graph context, so they have to be written
to be compatible with graph mode. Note that `input_fn` code using `tf.data`
generally works in both graph and eager modes.
@end_compatibility
"""

Expand Down Expand Up @@ -473,7 +476,10 @@ def input_fn_predict: # returns x, None
Loss is calculated by using mean squared error.
@compatibility(eager)
Estimators are not compatible with eager execution.
Estimators can be used while eager execution is enabled. Note that `input_fn`
and all hooks are executed inside a graph context, so they have to be written
to be compatible with graph mode. Note that `input_fn` code using `tf.data`
generally works in both graph and eager modes.
@end_compatibility
"""

Expand Down
10 changes: 8 additions & 2 deletions tensorflow/python/estimator/canned/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,10 @@ def input_fn_eval: # returns x, y (where y represents label's class index).
Loss is calculated by using softmax cross entropy.
@compatibility(eager)
Estimators are not compatible with eager execution.
Estimators can be used while eager execution is enabled. Note that `input_fn`
and all hooks are executed inside a graph context, so they have to be written
to be compatible with graph mode. Note that `input_fn` code using `tf.data`
generally works in both graph and eager modes.
@end_compatibility
"""

Expand Down Expand Up @@ -370,7 +373,10 @@ def input_fn_eval: # returns x, y
Loss is calculated by using mean squared error.
@compatibility(eager)
Estimators are not compatible with eager execution.
Estimators can be used while eager execution is enabled. Note that `input_fn`
and all hooks are executed inside a graph context, so they have to be written
to be compatible with graph mode. Note that `input_fn` code using `tf.data`
generally works in both graph and eager modes.
@end_compatibility
"""

Expand Down
9 changes: 9 additions & 0 deletions tensorflow/python/estimator/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,15 @@ class Estimator(object):
None of `Estimator`'s methods can be overridden in subclasses (its
constructor enforces this). Subclasses should use `model_fn` to configure
the base class, and may add methods implementing specialized functionality.
@compatbility(eager)
Calling methods of `Estimator` will work while eager execution is enabled.
However, the `model_fn` and `input_fn` is not executed eagerly, `Estimator`
will switch to graph model before calling all user-provided functions (incl.
hooks), so their code has to be compatible with graph mode execution. Note
that `input_fn` code using `tf.data` generally works in both graph and eager
modes.
@end_compatibility
"""

def __init__(self, model_fn, model_dir=None, config=None, params=None,
Expand Down

0 comments on commit e1b0ceb

Please sign in to comment.