From 28260505f57ed74d698f6ddfd0ef4ea1e4afa339 Mon Sep 17 00:00:00 2001 From: Mark Roth Date: Fri, 15 Jun 2018 12:32:59 -0400 Subject: [PATCH 1/4] Test has_fit_parameter() and fit_score_takes_y() work with @deprecated --- sklearn/utils/tests/test_estimator_checks.py | 14 ++++++++++++++ sklearn/utils/tests/test_validation.py | 8 ++++++++ 2 files changed, 22 insertions(+) diff --git a/sklearn/utils/tests/test_estimator_checks.py b/sklearn/utils/tests/test_estimator_checks.py index 53a67693843d9..bce9a8918585a 100644 --- a/sklearn/utils/tests/test_estimator_checks.py +++ b/sklearn/utils/tests/test_estimator_checks.py @@ -9,12 +9,14 @@ from sklearn.externals import joblib from sklearn.base import BaseEstimator, ClassifierMixin +from sklearn.utils import deprecated from sklearn.utils.testing import (assert_raises_regex, assert_true, assert_equal, ignore_warnings) from sklearn.utils.estimator_checks import check_estimator from sklearn.utils.estimator_checks import set_random_state from sklearn.utils.estimator_checks import set_checking_parameters from sklearn.utils.estimator_checks import check_estimators_unfitted +from sklearn.utils.estimator_checks import check_fit_score_takes_y from sklearn.utils.estimator_checks import check_no_attributes_set_in_init from sklearn.ensemble import AdaBoostClassifier, RandomForestClassifier from sklearn.linear_model import LinearRegression, SGDClassifier @@ -176,6 +178,18 @@ def transform(self, X): return sp.csr_matrix(X) +def test_check_fit_score_takes_y_works_on_deprecated_fit(): + # Tests that check_fit_score_takes_y works on a class with + # a deprecated fit method + + class TestEstimatorWithDeprecatedFitMethod(BaseEstimator): + @deprecated("Deprecated for the purpose of testing check_fit_score_takes_y") + def fit(self, X, y): + return self + + check_fit_score_takes_y("test", TestEstimatorWithDeprecatedFitMethod()) + + def test_check_estimator(): # tests that the estimator actually fails on "bad" estimators. # not a complete test of all checks, which are very extensive. diff --git a/sklearn/utils/tests/test_validation.py b/sklearn/utils/tests/test_validation.py index 076e6d88440f1..5fdd243b0d501 100644 --- a/sklearn/utils/tests/test_validation.py +++ b/sklearn/utils/tests/test_validation.py @@ -22,6 +22,7 @@ from sklearn.utils.testing import assert_allclose_dense_sparse from sklearn.utils import as_float_array, check_array, check_symmetric from sklearn.utils import check_X_y +from sklearn.utils import deprecated from sklearn.utils.mocking import MockDataFrame from sklearn.utils.estimator_checks import NotAnArray from sklearn.random_projection import sparse_random_matrix @@ -563,6 +564,13 @@ def test_has_fit_parameter(): assert_true(has_fit_parameter(SVR, "sample_weight")) assert_true(has_fit_parameter(SVR(), "sample_weight")) + class TestClassWithDeprecatedFitMethod: + @deprecated("Deprecated for the purpose of testing has_fit_parameter") + def fit(self, X, y, sample_weight=None): + pass + + assert_true(has_fit_parameter(TestClassWithDeprecatedFitMethod, "sample_weight")) + def test_check_symmetric(): arr_sym = np.array([[0, 1], [1, 2]]) From a48c0496957d49ffbf5a9d3a375893547fb60cec Mon Sep 17 00:00:00 2001 From: Mark Roth Date: Fri, 15 Jun 2018 16:05:32 -0400 Subject: [PATCH 2/4] Fix function introspection in Python 2 --- sklearn/utils/deprecation.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/sklearn/utils/deprecation.py b/sklearn/utils/deprecation.py index 5621f436d9baf..fc06f9bc84d3b 100644 --- a/sklearn/utils/deprecation.py +++ b/sklearn/utils/deprecation.py @@ -78,6 +78,9 @@ def wrapped(*args, **kwargs): return fun(*args, **kwargs) wrapped.__doc__ = self._update_doc(wrapped.__doc__) + # Add a reference to the wrapped function so that we can introspect + # on function arguments in Python 2 (already works in Python 3) + wrapped.__wrapped__ = fun return wrapped From d9bd2296d03455bc87374468ef77478f2378c556 Mon Sep 17 00:00:00 2001 From: Mark Roth Date: Fri, 15 Jun 2018 16:12:07 -0400 Subject: [PATCH 3/4] Fix style --- sklearn/utils/tests/test_estimator_checks.py | 3 ++- sklearn/utils/tests/test_validation.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/sklearn/utils/tests/test_estimator_checks.py b/sklearn/utils/tests/test_estimator_checks.py index bce9a8918585a..049dff4baa920 100644 --- a/sklearn/utils/tests/test_estimator_checks.py +++ b/sklearn/utils/tests/test_estimator_checks.py @@ -183,7 +183,8 @@ def test_check_fit_score_takes_y_works_on_deprecated_fit(): # a deprecated fit method class TestEstimatorWithDeprecatedFitMethod(BaseEstimator): - @deprecated("Deprecated for the purpose of testing check_fit_score_takes_y") + @deprecated("Deprecated for the purpose of testing " + "check_fit_score_takes_y") def fit(self, X, y): return self diff --git a/sklearn/utils/tests/test_validation.py b/sklearn/utils/tests/test_validation.py index 5fdd243b0d501..3d17d71d79820 100644 --- a/sklearn/utils/tests/test_validation.py +++ b/sklearn/utils/tests/test_validation.py @@ -569,7 +569,8 @@ class TestClassWithDeprecatedFitMethod: def fit(self, X, y, sample_weight=None): pass - assert_true(has_fit_parameter(TestClassWithDeprecatedFitMethod, "sample_weight")) + assert_true(has_fit_parameter(TestClassWithDeprecatedFitMethod, + "sample_weight")) def test_check_symmetric(): From 72bbb65c5562664cb523b6499d022d3cfac59930 Mon Sep 17 00:00:00 2001 From: Mark Roth Date: Mon, 18 Jun 2018 17:42:24 -0400 Subject: [PATCH 4/4] Replace use of assert_true with assert --- sklearn/utils/tests/test_validation.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sklearn/utils/tests/test_validation.py b/sklearn/utils/tests/test_validation.py index 3d17d71d79820..5af26ac5b978f 100644 --- a/sklearn/utils/tests/test_validation.py +++ b/sklearn/utils/tests/test_validation.py @@ -569,8 +569,9 @@ class TestClassWithDeprecatedFitMethod: def fit(self, X, y, sample_weight=None): pass - assert_true(has_fit_parameter(TestClassWithDeprecatedFitMethod, - "sample_weight")) + assert has_fit_parameter(TestClassWithDeprecatedFitMethod, + "sample_weight"), \ + "has_fit_parameter fails for class with deprecated fit method." def test_check_symmetric():