Skip to content

Commit

Permalink
Merge pull request scikit-learn#7317 from amueller/common_test_names
Browse files Browse the repository at this point in the history
[MRG+1] make more explicit which checks are run
  • Loading branch information
ogrisel authored Sep 5, 2016
2 parents 49fb295 + 7fc4176 commit f916449
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 11 deletions.
31 changes: 20 additions & 11 deletions sklearn/tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@
check_class_weight_balanced_linear_classifier,
check_transformer_n_iter,
check_non_transformer_estimators_n_iter,
check_get_params_invariance)
check_get_params_invariance,
_set_test_name)


def test_all_estimator_no_base_class():
Expand All @@ -55,7 +56,8 @@ def test_all_estimators():

for name, Estimator in estimators:
# some can just not be sensibly default constructed
yield check_parameters_default_constructible, name, Estimator
yield (_set_test_name(check_parameters_default_constructible, name),
name, Estimator)


def test_non_meta_estimators():
Expand All @@ -70,9 +72,9 @@ def test_non_meta_estimators():
if issubclass(Estimator, ProjectedGradientNMF):
# The ProjectedGradientNMF class is deprecated
with ignore_warnings():
yield check, name, Estimator
yield _set_test_name(check, name), name, Estimator
else:
yield check, name, Estimator
yield _set_test_name(check, name), name, Estimator


def test_configure():
Expand Down Expand Up @@ -114,7 +116,8 @@ def test_class_weight_balanced_linear_classifiers():
issubclass(clazz, LinearClassifierMixin))]

for name, Classifier in linear_classifiers:
yield check_class_weight_balanced_linear_classifier, name, Classifier
yield _set_test_name(check_class_weight_balanced_linear_classifier,
name), name, Classifier


@ignore_warnings
Expand Down Expand Up @@ -196,8 +199,9 @@ def test_non_transformer_estimators_n_iter():
else:
# Multitask models related to ENet cannot handle
# if y is mono-output.
yield (check_non_transformer_estimators_n_iter,
name, estimator, 'Multi' in name)
yield (_set_test_name(
check_non_transformer_estimators_n_iter, name),
name, estimator, 'Multi' in name)


def test_transformer_n_iter():
Expand All @@ -218,9 +222,12 @@ def test_transformer_n_iter():
if isinstance(estimator, ProjectedGradientNMF):
# The ProjectedGradientNMF class is deprecated
with ignore_warnings():
yield check_transformer_n_iter, name, estimator
yield _set_test_name(
check_transformer_n_iter, name), name, estimator
else:
yield check_transformer_n_iter, name, estimator
yield _set_test_name(
check_transformer_n_iter, name), name, estimator


def test_get_params_invariance():
# Test for estimators that support get_params, that
Expand All @@ -234,6 +241,8 @@ def test_get_params_invariance():
# If class is deprecated, ignore deprecated warnings
if hasattr(Estimator.__init__, "deprecated_original"):
with ignore_warnings():
yield check_get_params_invariance, name, Estimator
yield _set_test_name(
check_get_params_invariance, name), name, Estimator
else:
yield check_get_params_invariance, name, Estimator
yield _set_test_name(
check_get_params_invariance, name), name, Estimator
6 changes: 6 additions & 0 deletions sklearn/utils/estimator_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,12 @@
"GradientBoostingClassifier", "GradientBoostingRegressor"]


def _set_test_name(function, name):
function.description = ("sklearn.tests.test_common.{0}({1})".format(
function.__name__, name))
return function


def _yield_non_meta_checks(name, Estimator):
yield check_estimators_dtypes
yield check_fit_score_takes_y
Expand Down

0 comments on commit f916449

Please sign in to comment.