From 457b80fa8d1221d2508aaa41c08013a06214c500 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Mon, 8 Oct 2018 13:02:04 -0400 Subject: [PATCH 01/11] Added check for idempotence of fit() --- sklearn/utils/estimator_checks.py | 59 +++++++++++++++++++++++++++++++ 1 file changed, 59 insertions(+) diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index 54369033a75d3..e4cdfdcf4bccd 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -25,6 +25,7 @@ from sklearn.utils.testing import assert_false from sklearn.utils.testing import assert_in from sklearn.utils.testing import assert_array_equal +from sklearn.utils.testing import assert_array_almost_equal from sklearn.utils.testing import assert_allclose from sklearn.utils.testing import assert_allclose_dense_sparse from sklearn.utils.testing import assert_warns_message @@ -266,6 +267,7 @@ def _yield_all_checks(name, estimator): yield check_set_params yield check_dict_unchanged yield check_dont_overwrite_parameters + yield check_fit_idempotent def check_estimator(Estimator): @@ -2345,3 +2347,60 @@ def check_outliers_fit_predict(name, estimator_orig): for contamination in [-0.5, 2.3]: estimator.set_params(contamination=contamination) assert_raises(ValueError, estimator.fit_predict, X) + + +def check_fit_idempotent(name, estimator_orig): + # Check that est.fit(X) is the same as est.fit(X).fit(X). Ideally we would + # check that the estimated parameters during training (e.g. coefs_) are + # the same, but having a universal comparison function for those + # attributes is difficult and full of edge cases. So instead we check that + # predict(), predict_proba(), decision_function() and transform() return + # the same results. 
+ + est = clone(estimator_orig) + + np.random.seed(0) + X = np.random.normal(loc=100, size=(100, 2)) + if is_regressor(estimator_orig): + y = np.random.normal(size=100) + else: + y = np.random.randint(low=0, high=2, size=100) + if est.__class__.__name__.startswith('MultiTask'): + y = np.stack([y, y], axis=1) + + X_train, X_test, y_train, _ = train_test_split(X, y) + + if 'random_state' in est.get_params().keys(): + est.set_params(random_state=0) + if 'warm_start' in est.get_params().keys(): + est.set_params(warm_start=False) + + est.fit(X_train, y_train) + + if hasattr(est, 'predict'): + pred_1 = est.predict(X_test) + if hasattr(est, 'predict_proba'): + pred_proba_1 = est.predict_proba(X_test) + if hasattr(est, 'decision_function'): + decision_1 = est.decision_function(X_test) + if hasattr(est, 'transform'): + transform_1 = est.transform(X_test) + + # Fit again + est.fit(X_train, y_train) + + if hasattr(est, 'predict'): + pred_2 = est.predict(X_test) + assert_array_almost_equal(pred_1, pred_2) + if hasattr(est, 'predict_proba'): + pred_proba_2 = est.predict_proba(X_test) + assert_array_almost_equal(pred_proba_1, pred_proba_2) + if hasattr(est, 'decision_function'): + decision_2 = est.decision_function(X_test) + assert_array_almost_equal(decision_1, decision_2) + if hasattr(est, 'transform'): + transform_2 = est.transform(X_test) + if sparse.issparse(transform_1): + transform_1 = transform_1.toarray() + transform_2 = transform_2.toarray() + assert np.allclose(transform_1, transform_2) From c578b222ad477063135b42b70aac1df16fe1ca2f Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Mon, 8 Oct 2018 17:35:15 -0400 Subject: [PATCH 02/11] Fixed test for estimator with pairwise kernel or metric --- sklearn/utils/estimator_checks.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index e4cdfdcf4bccd..db6826912ff82 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ 
-2369,6 +2369,9 @@ def check_fit_idempotent(name, estimator_orig): y = np.stack([y, y], axis=1) X_train, X_test, y_train, _ = train_test_split(X, y) + # some estimators expect a square matrix + X_train = pairwise_estimator_convert_X(X_train, est) + X_test = pairwise_estimator_convert_X(X_train, est) if 'random_state' in est.get_params().keys(): est.set_params(random_state=0) From 7365d0954c001dfc81a23acbd9e73635c583d5b0 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 10 Oct 2018 10:35:42 -0400 Subject: [PATCH 03/11] Used column_stack instead of stack, should fix the tests --- sklearn/utils/estimator_checks.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index db6826912ff82..24070cbf46dfb 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -2366,7 +2366,7 @@ def check_fit_idempotent(name, estimator_orig): else: y = np.random.randint(low=0, high=2, size=100) if est.__class__.__name__.startswith('MultiTask'): - y = np.stack([y, y], axis=1) + y = np.column_stack([y, y]) X_train, X_test, y_train, _ = train_test_split(X, y) # some estimators expect a square matrix @@ -2406,4 +2406,4 @@ def check_fit_idempotent(name, estimator_orig): if sparse.issparse(transform_1): transform_1 = transform_1.toarray() transform_2 = transform_2.toarray() - assert np.allclose(transform_1, transform_2) + assert_array_almost_equal(transform_1, transform_2) From b8edccbf9c64a015cc1db06fb68b71d3aef50fc6 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Mon, 15 Oct 2018 11:41:55 -0400 Subject: [PATCH 04/11] Refactored test --- sklearn/utils/estimator_checks.py | 70 +++++++++++++------------------ 1 file changed, 28 insertions(+), 42 deletions(-) diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index 24070cbf46dfb..b7d31d77029b4 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -25,7 
+25,6 @@ from sklearn.utils.testing import assert_false from sklearn.utils.testing import assert_in from sklearn.utils.testing import assert_array_equal -from sklearn.utils.testing import assert_array_almost_equal from sklearn.utils.testing import assert_allclose from sklearn.utils.testing import assert_allclose_dense_sparse from sklearn.utils.testing import assert_warns_message @@ -2357,53 +2356,40 @@ def check_fit_idempotent(name, estimator_orig): # predict(), predict_proba(), decision_function() and transform() return # the same results. - est = clone(estimator_orig) + check_methods = ["predict", "transform", "decision_function", + "predict_proba"] + rng = np.random.RandomState(0) + rng.randint(100) - np.random.seed(0) - X = np.random.normal(loc=100, size=(100, 2)) + estimator = clone(estimator_orig) + set_random_state(estimator) + if 'warm_start' in estimator.get_params().keys(): + estimator.set_params(warm_start=False) + + n_samples = 100 + X = rng.normal(loc=100, size=(n_samples, 2)) if is_regressor(estimator_orig): - y = np.random.normal(size=100) + y = rng.normal(size=n_samples) else: - y = np.random.randint(low=0, high=2, size=100) - if est.__class__.__name__.startswith('MultiTask'): - y = np.column_stack([y, y]) - + y = rng.randint(low=0, high=2, size=n_samples) + y = multioutput_estimator_convert_y_2d(estimator, y) X_train, X_test, y_train, _ = train_test_split(X, y) # some estimators expect a square matrix - X_train = pairwise_estimator_convert_X(X_train, est) - X_test = pairwise_estimator_convert_X(X_train, est) - - if 'random_state' in est.get_params().keys(): - est.set_params(random_state=0) - if 'warm_start' in est.get_params().keys(): - est.set_params(warm_start=False) + X_train = pairwise_estimator_convert_X(X_train, estimator) + X_test = pairwise_estimator_convert_X(X_train, estimator) - est.fit(X_train, y_train) + # Fit for the first time + estimator.fit(X_train, y_train) - if hasattr(est, 'predict'): - pred_1 = est.predict(X_test) - if 
hasattr(est, 'predict_proba'): - pred_proba_1 = est.predict_proba(X_test) - if hasattr(est, 'decision_function'): - decision_1 = est.decision_function(X_test) - if hasattr(est, 'transform'): - transform_1 = est.transform(X_test) + result = dict() + for method in check_methods: + if hasattr(estimator, method): + result[method] = getattr(estimator, method)(X_test) # Fit again - est.fit(X_train, y_train) - - if hasattr(est, 'predict'): - pred_2 = est.predict(X_test) - assert_array_almost_equal(pred_1, pred_2) - if hasattr(est, 'predict_proba'): - pred_proba_2 = est.predict_proba(X_test) - assert_array_almost_equal(pred_proba_1, pred_proba_2) - if hasattr(est, 'decision_function'): - decision_2 = est.decision_function(X_test) - assert_array_almost_equal(decision_1, decision_2) - if hasattr(est, 'transform'): - transform_2 = est.transform(X_test) - if sparse.issparse(transform_1): - transform_1 = transform_1.toarray() - transform_2 = transform_2.toarray() - assert_array_almost_equal(transform_1, transform_2) + estimator.fit(X_train, y_train) + + for method in check_methods: + if hasattr(estimator, method): + new_result = getattr(estimator, method)(X_test) + assert_allclose_dense_sparse(result[method], new_result) From 88a81aff3765ca57855b763cac675a83ad814d47 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Mon, 15 Oct 2018 11:46:35 -0400 Subject: [PATCH 05/11] Removed useless line --- sklearn/utils/estimator_checks.py | 1 - 1 file changed, 1 deletion(-) diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index b7d31d77029b4..0c929beeb5858 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -2359,7 +2359,6 @@ def check_fit_idempotent(name, estimator_orig): check_methods = ["predict", "transform", "decision_function", "predict_proba"] rng = np.random.RandomState(0) - rng.randint(100) estimator = clone(estimator_orig) set_random_state(estimator) From 7cffbb54a33086801ba7144aa8e0e686a95fa33e Mon Sep 17 
00:00:00 2001 From: Nicolas Hug Date: Mon, 15 Oct 2018 12:38:59 -0400 Subject: [PATCH 06/11] addressed comment: X_train -> X_test --- sklearn/utils/estimator_checks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index 0c929beeb5858..148639102b34b 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -2375,7 +2375,7 @@ def check_fit_idempotent(name, estimator_orig): X_train, X_test, y_train, _ = train_test_split(X, y) # some estimators expect a square matrix X_train = pairwise_estimator_convert_X(X_train, estimator) - X_test = pairwise_estimator_convert_X(X_train, estimator) + X_test = pairwise_estimator_convert_X(X_test, estimator) # Fit for the first time estimator.fit(X_train, y_train) From 0bca0354ddf6b03ee55c25d176e7de42480df944 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 16 Oct 2018 10:04:03 -0400 Subject: [PATCH 07/11] should fix test --- sklearn/utils/estimator_checks.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index 148639102b34b..ddc8afad956e7 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -2372,7 +2372,8 @@ def check_fit_idempotent(name, estimator_orig): else: y = rng.randint(low=0, high=2, size=n_samples) y = multioutput_estimator_convert_y_2d(estimator, y) - X_train, X_test, y_train, _ = train_test_split(X, y) + X_train, X_test, y_train, _ = train_test_split(X, y, train_size=.5, + random_state=rng) # some estimators expect a square matrix X_train = pairwise_estimator_convert_X(X_train, estimator) X_test = pairwise_estimator_convert_X(X_test, estimator) From 275b4f9ef896919602dbfc382f7373cb591100d1 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 18 Oct 2018 15:43:43 -0400 Subject: [PATCH 08/11] set test_size instead of train_size to avoid deprecation warning --- 
sklearn/utils/estimator_checks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index ddc8afad956e7..6a65aa8aa065b 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -2372,7 +2372,7 @@ def check_fit_idempotent(name, estimator_orig): else: y = rng.randint(low=0, high=2, size=n_samples) y = multioutput_estimator_convert_y_2d(estimator, y) - X_train, X_test, y_train, _ = train_test_split(X, y, train_size=.5, + X_train, X_test, y_train, _ = train_test_split(X, y, test_size=.5, random_state=rng) # some estimators expect a square matrix X_train = pairwise_estimator_convert_X(X_train, estimator) From 53b7f55ac126e4307b95d33ab1222fbefc2ec06b Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Mon, 22 Oct 2018 10:58:56 -0400 Subject: [PATCH 09/11] dict -> {} --- sklearn/utils/estimator_checks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index 6a65aa8aa065b..15e87be974f50 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -2381,7 +2381,7 @@ def check_fit_idempotent(name, estimator_orig): # Fit for the first time estimator.fit(X_train, y_train) - result = dict() + result = {} for method in check_methods: if hasattr(estimator, method): result[method] = getattr(estimator, method)(X_test) From ee8a0c9091708db08b26feb9ac6d8d3d31995edc Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 24 Oct 2018 14:59:55 -0400 Subject: [PATCH 10/11] Used _safe_split for splitting data --- sklearn/utils/estimator_checks.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index 15e87be974f50..5b2b3acf7284e 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -54,6 +54,8 @@ from sklearn.exceptions import 
DataConversionWarning from sklearn.exceptions import SkipTestWarning from sklearn.model_selection import train_test_split +from sklearn.model_selection import ShuffleSplit +from sklearn.model_selection._validation import _safe_split from sklearn.metrics.pairwise import (rbf_kernel, linear_kernel, pairwise_distances) @@ -2367,16 +2369,16 @@ def check_fit_idempotent(name, estimator_orig): n_samples = 100 X = rng.normal(loc=100, size=(n_samples, 2)) + X = pairwise_estimator_convert_X(X, estimator) if is_regressor(estimator_orig): y = rng.normal(size=n_samples) else: y = rng.randint(low=0, high=2, size=n_samples) y = multioutput_estimator_convert_y_2d(estimator, y) - X_train, X_test, y_train, _ = train_test_split(X, y, test_size=.5, - random_state=rng) - # some estimators expect a square matrix - X_train = pairwise_estimator_convert_X(X_train, estimator) - X_test = pairwise_estimator_convert_X(X_test, estimator) + + train, test = next(ShuffleSplit(test_size=.2, random_state=rng).split(X)) + X_train, y_train = _safe_split(estimator, X, y, train) + X_test, y_test = _safe_split(estimator, X, y, test, train) # Fit for the first time estimator.fit(X_train, y_train) From 19b68ea6234bd557f58f2906d829ac345bb40993 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Sun, 28 Oct 2018 19:00:53 -0400 Subject: [PATCH 11/11] Added whatsnew entry --- doc/whats_new/v0.21.rst | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/doc/whats_new/v0.21.rst b/doc/whats_new/v0.21.rst index 735563ddc5b43..860dc82af5a76 100644 --- a/doc/whats_new/v0.21.rst +++ b/doc/whats_new/v0.21.rst @@ -90,3 +90,9 @@ Changes to estimator checks --------------------------- These changes mostly affect library developers. + +- Add ``check_fit_idempotent`` to + :func:`~utils.estimator_checks.check_estimator`, which checks that + when `fit` is called twice with the same data, the output of + `predict`, `predict_proba`, `transform`, and `decision_function` does not + change. :issue:`12328` by :user:`Nicolas Hug`