From 61bfe8aea8cd83738fbd8c3dd4c836c3e90e5e59 Mon Sep 17 00:00:00 2001 From: Andreas Mueller Date: Wed, 31 Aug 2016 18:00:55 -0400 Subject: [PATCH 1/2] make more explicit which checks are run --- sklearn/tests/test_common.py | 22 ++++++++++++---------- sklearn/utils/estimator_checks.py | 5 +++++ 2 files changed, 17 insertions(+), 10 deletions(-) diff --git a/sklearn/tests/test_common.py b/sklearn/tests/test_common.py index 9c42951e60057..ad9a713a16993 100644 --- a/sklearn/tests/test_common.py +++ b/sklearn/tests/test_common.py @@ -33,7 +33,8 @@ check_class_weight_balanced_linear_classifier, check_transformer_n_iter, check_non_transformer_estimators_n_iter, - check_get_params_invariance) + check_get_params_invariance, + _set_test_name) def test_all_estimator_no_base_class(): @@ -55,7 +56,7 @@ def test_all_estimators(): for name, Estimator in estimators: # some can just not be sensibly default constructed - yield check_parameters_default_constructible, name, Estimator + yield _set_test_name(check_parameters_default_constructible, name), name, Estimator def test_non_meta_estimators(): @@ -70,9 +71,9 @@ def test_non_meta_estimators(): if issubclass(Estimator, ProjectedGradientNMF): # The ProjectedGradientNMF class is deprecated with ignore_warnings(): - yield check, name, Estimator + yield _set_test_name(check, name), name, Estimator else: - yield check, name, Estimator + yield _set_test_name(check, name), name, Estimator def test_configure(): @@ -114,7 +115,7 @@ def test_class_weight_balanced_linear_classifiers(): issubclass(clazz, LinearClassifierMixin))] for name, Classifier in linear_classifiers: - yield check_class_weight_balanced_linear_classifier, name, Classifier + yield _set_test_name(check_class_weight_balanced_linear_classifier, name), name, Classifier @ignore_warnings @@ -196,7 +197,7 @@ def test_non_transformer_estimators_n_iter(): else: # Multitask models related to ENet cannot handle # if y is mono-output. - yield (check_non_transformer_estimators_n_iter, + yield (_set_test_name(check_non_transformer_estimators_n_iter, name), name, estimator, 'Multi' in name) @@ -218,9 +219,10 @@ def test_transformer_n_iter(): if isinstance(estimator, ProjectedGradientNMF): # The ProjectedGradientNMF class is deprecated with ignore_warnings(): - yield check_transformer_n_iter, name, estimator + yield _set_test_name(check_transformer_n_iter, name), name, estimator else: - yield check_transformer_n_iter, name, estimator + yield _set_test_name(check_transformer_n_iter, name), name, estimator + def test_get_params_invariance(): # Test for estimators that support get_params, that @@ -234,6 +236,6 @@ def test_get_params_invariance(): # If class is deprecated, ignore deprecated warnings if hasattr(Estimator.__init__, "deprecated_original"): with ignore_warnings(): - yield check_get_params_invariance, name, Estimator + yield _set_test_name(check_get_params_invariance, name), name, Estimator else: - yield check_get_params_invariance, name, Estimator + yield _set_test_name(check_get_params_invariance, name), name, Estimator diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index db27550d973bb..5dbf4240862d5 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -75,6 +75,11 @@ "GradientBoostingClassifier", "GradientBoostingRegressor"] +def _set_test_name(function, name): + function.description = "sklearn.tests.test_common.{0}({1})".format(function.__name__, name) + return function + + def _yield_non_meta_checks(name, Estimator): yield check_estimators_dtypes yield check_fit_score_takes_y From 7fc41768b10682d2ceb04d3a7db32e3950d11842 Mon Sep 17 00:00:00 2001 From: Andreas Mueller Date: Wed, 31 Aug 2016 22:37:23 -0400 Subject: [PATCH 2/2] pep8 --- sklearn/tests/test_common.py | 23 +++++++++++++++-------- sklearn/utils/estimator_checks.py | 3 ++- 2 files changed, 17 insertions(+), 9 deletions(-) diff --git a/sklearn/tests/test_common.py b/sklearn/tests/test_common.py index ad9a713a16993..374fc0774c993 100644 --- a/sklearn/tests/test_common.py +++ b/sklearn/tests/test_common.py @@ -56,7 +56,8 @@ def test_all_estimators(): for name, Estimator in estimators: # some can just not be sensibly default constructed - yield _set_test_name(check_parameters_default_constructible, name), name, Estimator + yield (_set_test_name(check_parameters_default_constructible, name), + name, Estimator) def test_non_meta_estimators(): @@ -115,7 +116,8 @@ def test_class_weight_balanced_linear_classifiers(): issubclass(clazz, LinearClassifierMixin))] for name, Classifier in linear_classifiers: - yield _set_test_name(check_class_weight_balanced_linear_classifier, name), name, Classifier + yield _set_test_name(check_class_weight_balanced_linear_classifier, + name), name, Classifier @ignore_warnings @@ -197,8 +199,9 @@ def test_non_transformer_estimators_n_iter(): else: # Multitask models related to ENet cannot handle # if y is mono-output. - yield (_set_test_name(check_non_transformer_estimators_n_iter, name), - name, estimator, 'Multi' in name) + yield (_set_test_name( + check_non_transformer_estimators_n_iter, name), + name, estimator, 'Multi' in name) def test_transformer_n_iter(): @@ -219,9 +222,11 @@ def test_transformer_n_iter(): if isinstance(estimator, ProjectedGradientNMF): # The ProjectedGradientNMF class is deprecated with ignore_warnings(): - yield _set_test_name(check_transformer_n_iter, name), name, estimator + yield _set_test_name( + check_transformer_n_iter, name), name, estimator else: - yield _set_test_name(check_transformer_n_iter, name), name, estimator + yield _set_test_name( + check_transformer_n_iter, name), name, estimator def test_get_params_invariance(): @@ -236,6 +241,8 @@ def test_get_params_invariance(): # If class is deprecated, ignore deprecated warnings if hasattr(Estimator.__init__, "deprecated_original"): with ignore_warnings(): - yield _set_test_name(check_get_params_invariance, name), name, Estimator + yield _set_test_name( + check_get_params_invariance, name), name, Estimator else: - yield _set_test_name(check_get_params_invariance, name), name, Estimator + yield _set_test_name( + check_get_params_invariance, name), name, Estimator diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index 5dbf4240862d5..5c031c881addc 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -76,7 +76,8 @@ def _set_test_name(function, name): - function.description = "sklearn.tests.test_common.{0}({1})".format(function.__name__, name) + function.description = ("sklearn.tests.test_common.{0}({1})".format( + function.__name__, name)) return function