Skip to content

Commit

Permalink
Merge pull request scikit-learn#7401 from ogrisel/sgdclf-predict_prob…
Browse files Browse the repository at this point in the history
…a-delegation

[MRG] FIX 7155: GridSearchCV predict_proba delegation to SGDClassifier
  • Loading branch information
ogrisel authored Sep 13, 2016
2 parents 940f194 + 3dffa08 commit 44f1ad4
Show file tree
Hide file tree
Showing 6 changed files with 140 additions and 45 deletions.
12 changes: 6 additions & 6 deletions sklearn/grid_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,7 +426,7 @@ def score(self, X, y=None):
ChangedBehaviorWarning)
return self.scorer_(self.best_estimator_, X, y)

@if_delegate_has_method(delegate='estimator')
@if_delegate_has_method(delegate=('best_estimator_', 'estimator'))
def predict(self, X):
"""Call predict on the estimator with the best found parameters.
Expand All @@ -442,7 +442,7 @@ def predict(self, X):
"""
return self.best_estimator_.predict(X)

@if_delegate_has_method(delegate='estimator')
@if_delegate_has_method(delegate=('best_estimator_', 'estimator'))
def predict_proba(self, X):
"""Call predict_proba on the estimator with the best found parameters.
Expand All @@ -458,7 +458,7 @@ def predict_proba(self, X):
"""
return self.best_estimator_.predict_proba(X)

@if_delegate_has_method(delegate='estimator')
@if_delegate_has_method(delegate=('best_estimator_', 'estimator'))
def predict_log_proba(self, X):
"""Call predict_log_proba on the estimator with the best found parameters.
Expand All @@ -474,7 +474,7 @@ def predict_log_proba(self, X):
"""
return self.best_estimator_.predict_log_proba(X)

@if_delegate_has_method(delegate='estimator')
@if_delegate_has_method(delegate=('best_estimator_', 'estimator'))
def decision_function(self, X):
"""Call decision_function on the estimator with the best found parameters.
Expand All @@ -490,7 +490,7 @@ def decision_function(self, X):
"""
return self.best_estimator_.decision_function(X)

@if_delegate_has_method(delegate='estimator')
@if_delegate_has_method(delegate=('best_estimator_', 'estimator'))
def transform(self, X):
"""Call transform on the estimator with the best found parameters.
Expand All @@ -506,7 +506,7 @@ def transform(self, X):
"""
return self.best_estimator_.transform(X)

@if_delegate_has_method(delegate='estimator')
@if_delegate_has_method(delegate=('best_estimator_', 'estimator'))
def inverse_transform(self, Xt):
"""Call inverse_transform on the estimator with the best found parameters.
Expand Down
12 changes: 6 additions & 6 deletions sklearn/model_selection/_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,7 +426,7 @@ def _check_is_fitted(self, method_name):
else:
check_is_fitted(self, 'best_estimator_')

@if_delegate_has_method(delegate='estimator')
@if_delegate_has_method(delegate=('best_estimator_', 'estimator'))
def predict(self, X):
"""Call predict on the estimator with the best found parameters.
Expand All @@ -443,7 +443,7 @@ def predict(self, X):
self._check_is_fitted('predict')
return self.best_estimator_.predict(X)

@if_delegate_has_method(delegate='estimator')
@if_delegate_has_method(delegate=('best_estimator_', 'estimator'))
def predict_proba(self, X):
"""Call predict_proba on the estimator with the best found parameters.
Expand All @@ -460,7 +460,7 @@ def predict_proba(self, X):
self._check_is_fitted('predict_proba')
return self.best_estimator_.predict_proba(X)

@if_delegate_has_method(delegate='estimator')
@if_delegate_has_method(delegate=('best_estimator_', 'estimator'))
def predict_log_proba(self, X):
"""Call predict_log_proba on the estimator with the best found parameters.
Expand All @@ -477,7 +477,7 @@ def predict_log_proba(self, X):
self._check_is_fitted('predict_log_proba')
return self.best_estimator_.predict_log_proba(X)

@if_delegate_has_method(delegate='estimator')
@if_delegate_has_method(delegate=('best_estimator_', 'estimator'))
def decision_function(self, X):
"""Call decision_function on the estimator with the best found parameters.
Expand All @@ -494,7 +494,7 @@ def decision_function(self, X):
self._check_is_fitted('decision_function')
return self.best_estimator_.decision_function(X)

@if_delegate_has_method(delegate='estimator')
@if_delegate_has_method(delegate=('best_estimator_', 'estimator'))
def transform(self, X):
"""Call transform on the estimator with the best found parameters.
Expand All @@ -511,7 +511,7 @@ def transform(self, X):
self._check_is_fitted('transform')
return self.best_estimator_.transform(X)

@if_delegate_has_method(delegate='estimator')
@if_delegate_has_method(delegate=('best_estimator_', 'estimator'))
def inverse_transform(self, Xt):
"""Call inverse_transform on the estimator with the best found params.
Expand Down
39 changes: 36 additions & 3 deletions sklearn/model_selection/tests/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
from sklearn.metrics import roc_auc_score
from sklearn.preprocessing import Imputer
from sklearn.pipeline import Pipeline
from sklearn.linear_model import SGDClassifier


# Neither of the following two estimators inherit from BaseEstimator,
Expand Down Expand Up @@ -967,11 +968,13 @@ def test_grid_search_failing_classifier():
refit=False, error_score=0.0)
assert_warns(FitFailedWarning, gs.fit, X, y)
n_candidates = len(gs.cv_results_['params'])

# Ensure that grid scores were set to zero as required for those fits
# that are expected to fail.
get_cand_scores = lambda i: np.array(list(
gs.cv_results_['split%d_test_score' % s][i]
for s in range(gs.n_splits_)))
def get_cand_scores(i):
    # Gather the test score of candidate ``i`` from every CV split, in
    # split order, and pack them into a single array.
    split_scores = [gs.cv_results_['split%d_test_score' % split_idx][i]
                    for split_idx in range(gs.n_splits_)]
    return np.array(split_scores)

assert all((np.all(get_cand_scores(cand_i) == 0.0)
for cand_i in range(n_candidates)
if gs.cv_results_['param_parameter'][cand_i] ==
Expand Down Expand Up @@ -1028,3 +1031,33 @@ def test_parameters_sampler_replacement():
sampler = ParameterSampler(params_distribution, n_iter=7)
samples = list(sampler)
assert_equal(len(samples), 7)


def test_stochastic_gradient_loss_param():
    # Check that ``predict_proba`` is exposed by the search object according
    # to the loss selected through ``param_grid``, not only according to the
    # loss of the unfitted estimator passed to the constructor.
    X = np.arange(20).reshape(5, -1)
    y = [0, 0, 1, 1, 1]

    # Case 1: the base estimator uses loss='hinge' but the grid only
    # contains loss='log'.
    search = GridSearchCV(estimator=SGDClassifier(loss='hinge'),
                          param_grid={'loss': ['log']})

    # Before fitting, `predict_proba` is not available because the
    # unfitted estimator has loss='hinge'.
    assert_false(hasattr(search, "predict_proba"))
    search.fit(X, y)
    # After fitting, both probability methods must be callable.
    search.predict_proba(X)
    search.predict_log_proba(X)

    # Case 2: the grid only contains loss='hinge', so `predict_proba` must
    # stay unavailable both before and after fitting.
    search = GridSearchCV(estimator=SGDClassifier(loss='hinge'),
                          param_grid={'loss': ['hinge']})
    assert_false(hasattr(search, "predict_proba"))
    search.fit(X, y)
    assert_false(hasattr(search, "predict_proba"))
3 changes: 2 additions & 1 deletion sklearn/tests/test_metaestimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ def __init__(self, name, construct, skip_methods=(),
skip_methods=['transform', 'inverse_transform', 'score']),
DelegatorData('BaggingClassifier', BaggingClassifier,
skip_methods=['transform', 'inverse_transform', 'score',
'predict_proba', 'predict_log_proba', 'predict'])
'predict_proba', 'predict_log_proba',
'predict'])
]


Expand Down
63 changes: 35 additions & 28 deletions sklearn/utils/metaestimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,22 @@ class _IffHasAttrDescriptor(object):
"""Implements a conditional property using the descriptor protocol.
Using this class to create a decorator will raise an ``AttributeError``
if the ``attribute_name`` is not present on the base object.
if none of the delegates (specified in ``delegate_names``) is an attribute
of the base object or the first found delegate does not have an attribute
``attribute_name``.
This allows ducktyping of the decorated method based on ``attribute_name``.
This allows ducktyping of the decorated method based on
``delegate.attribute_name``. Here ``delegate`` is the first item in
``delegate_names`` for which ``hasattr(object, delegate) is True``.
See https://docs.python.org/3/howto/descriptor.html for an explanation of
descriptors.
"""
def __init__(self, fn, attribute_name):
def __init__(self, fn, delegate_names, attribute_name):
self.fn = fn
self.get_attribute = attrgetter(attribute_name)
self.delegate_names = delegate_names
self.attribute_name = attribute_name

# update the docstring of the descriptor
update_wrapper(self, fn)

Expand All @@ -32,7 +38,17 @@ def __get__(self, obj, type=None):
if obj is not None:
# delegate only on instances, not the classes.
# this is to allow access to the docstrings.
self.get_attribute(obj)
for delegate_name in self.delegate_names:
try:
delegate = attrgetter(delegate_name)(obj)
except AttributeError:
continue
else:
getattr(delegate, self.attribute_name)
break
else:
attrgetter(self.delegate_names[-1])(obj)

# lambda, but not partial, allows help() to work with update_wrapper
out = lambda *args, **kwargs: self.fn(obj, *args, **kwargs)
# update the docstring of the returned function
Expand All @@ -46,27 +62,18 @@ def if_delegate_has_method(delegate):
This enables ducktyping by hasattr returning True according to the
sub-estimator.
>>> from sklearn.utils.metaestimators import if_delegate_has_method
>>>
>>>
>>> class MetaEst(object):
... def __init__(self, sub_est):
... self.sub_est = sub_est
...
... @if_delegate_has_method(delegate='sub_est')
... def predict(self, X):
... return self.sub_est.predict(X)
...
>>> class HasPredict(object):
... def predict(self, X):
... return X.sum(axis=1)
...
>>> class HasNoPredict(object):
... pass
...
>>> hasattr(MetaEst(HasPredict()), 'predict')
True
>>> hasattr(MetaEst(HasNoPredict()), 'predict')
False
Parameters
----------
delegate : string, list of strings or tuple of strings
Name of the sub-estimator that can be accessed as an attribute of the
base object. If a list or a tuple of names are provided, the first
sub-estimator that is an attribute of the base object will be used.
"""
return lambda fn: _IffHasAttrDescriptor(fn, '%s.%s' % (delegate, fn.__name__))
if isinstance(delegate, list):
delegate = tuple(delegate)
if not isinstance(delegate, tuple):
delegate = (delegate,)

return lambda fn: _IffHasAttrDescriptor(fn, delegate,
attribute_name=fn.__name__)
56 changes: 55 additions & 1 deletion sklearn/utils/tests/test_metaestimators.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from nose.tools import assert_true, assert_false
from sklearn.utils.metaestimators import if_delegate_has_method
from nose.tools import assert_true


class Prefix(object):
Expand All @@ -24,3 +24,57 @@ def test_delegated_docstring():
in str(MockMetaEstimator.func.__doc__))
assert_true("This is a mock delegated function"
in str(MockMetaEstimator().func.__doc__))


class MetaEst(object):
    """Mock meta-estimator whose ``predict`` is delegated to ``sub_est``."""

    def __init__(self, sub_est, better_sub_est=None):
        # Both candidates are stored as plain attributes; only ``sub_est``
        # is named as a delegate by this base class.
        self.better_sub_est = better_sub_est
        self.sub_est = sub_est

    # Exposed only when the ``sub_est`` attribute itself has ``predict``.
    @if_delegate_has_method(delegate='sub_est')
    def predict(self):
        return None


class MetaEstTestTuple(MetaEst):
    """Mock meta-estimator declaring its delegates as a tuple of names."""

    # Delegate names are given as a tuple: ``sub_est`` first, then
    # ``better_sub_est``.
    @if_delegate_has_method(delegate=('sub_est', 'better_sub_est'))
    def predict(self):
        return None


class MetaEstTestList(MetaEst):
    """Mock meta-estimator declaring its delegates as a list of names."""

    # Same delegate names as the tuple variant, but passed as a list.
    @if_delegate_has_method(delegate=['sub_est', 'better_sub_est'])
    def predict(self):
        return None


class HasPredict(object):
    """Mock sub-estimator that exposes a ``predict`` method."""

    def predict(self):
        # Only the presence of the method matters; it does nothing.
        return None


class HasNoPredict(object):
    """Mock sub-estimator that does not define ``predict``."""


def test_if_delegate_has_method():
    # Each case pairs a meta-estimator instance with whether ``predict`` is
    # expected to be visible through ``hasattr``.
    cases = [
        # Single delegate name.
        (MetaEst(HasPredict()), True),
        (MetaEst(HasNoPredict()), False),
        # Tuple of delegate names.
        (MetaEstTestTuple(HasNoPredict(), HasNoPredict()), False),
        (MetaEstTestTuple(HasPredict(), HasNoPredict()), True),
        (MetaEstTestTuple(HasNoPredict(), HasPredict()), False),
        # List of delegate names.
        (MetaEstTestList(HasNoPredict(), HasPredict()), False),
        (MetaEstTestList(HasPredict(), HasPredict()), True),
    ]
    for meta_est, should_expose in cases:
        check = assert_true if should_expose else assert_false
        check(hasattr(meta_est, 'predict'))

0 comments on commit 44f1ad4

Please sign in to comment.