Skip to content

Commit

Permalink
Fixes tensorflow#9654: Allow model_fn being a member function of a cl…
Browse files Browse the repository at this point in the history
…ass (tensorflow#9807)

* Fixes tensorflow#9654: Allow model_fn being a member function of a class

* Remove self from arg instead of changing a const
  • Loading branch information
terrytangyuan authored and benoitsteiner committed May 12, 2017
1 parent e7f64e3 commit eb0a59f
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 2 deletions.
7 changes: 5 additions & 2 deletions tensorflow/python/estimator/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -739,7 +739,7 @@ def _model_fn_args(fn):

def _verify_model_fn_args(model_fn, params):
"""Verifies model fn arguments."""
args = _model_fn_args(model_fn)
args = set(_model_fn_args(model_fn))
if 'features' not in args:
raise ValueError('model_fn (%s) must include features argument.' % model_fn)
if 'labels' not in args:
Expand All @@ -752,7 +752,10 @@ def _verify_model_fn_args(model_fn, params):
logging.warning('Estimator\'s model_fn (%s) includes params '
'argument, but params are not passed to Estimator.',
model_fn)
non_valid_args = list(set(args) - _VALID_MODEL_FN_ARGS)
if tf_inspect.ismethod(model_fn):
if 'self' in args:
args.remove('self')
non_valid_args = list(args - _VALID_MODEL_FN_ARGS)
if non_valid_args:
raise ValueError('model_fn (%s) has following not expected args: %s' %
(model_fn, non_valid_args))
Expand Down
11 changes: 11 additions & 0 deletions tensorflow/python/estimator/estimator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,17 @@ def model_fn(features, labels, something):
features, labels, 'something')
estimator.Estimator(model_fn=new_model_fn)

def test_if_model_fn_is_a_member_function_of_a_class(self):

class ModelFnClass(object):
def __init__(self):
estimator.Estimator(model_fn=self.model_fn)

def model_fn(self, features, labels, mode):
_, _, _ = features, labels, mode

ModelFnClass()


def dummy_input_fn():
return ({'x': constant_op.constant([[1], [1]])},
Expand Down

0 comments on commit eb0a59f

Please sign in to comment.