From 602856b0955bfb3e5133e4f49ab1a60efb9d0a02 Mon Sep 17 00:00:00 2001 From: Raghav R V Date: Mon, 26 Jan 2015 21:54:55 +0530 Subject: [PATCH] TST Add test to check if estimators reset model when fit is called --- sklearn/tests/test_common.py | 3 +++ sklearn/utils/estimator_checks.py | 32 +++++++++++++++++++++++++++++++ 2 files changed, 35 insertions(+) diff --git a/sklearn/tests/test_common.py b/sklearn/tests/test_common.py index ec535eb076f9a..77a1ec4d0af9a 100644 --- a/sklearn/tests/test_common.py +++ b/sklearn/tests/test_common.py @@ -32,6 +32,7 @@ check_estimator_sparse_data, check_transformer, check_clustering, + check_fit_reset, check_clusterer_compute_labels_predict, check_regressors_int, check_regressors_train, @@ -100,6 +101,8 @@ def test_non_meta_estimators(): yield check_sparsify_coefficients, name, Estimator yield check_estimator_sparse_data, name, Estimator + # test if fit resets model + yield check_fit_reset, name, Alg def test_transformers(): diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index e8a9feab24377..24afef83e9e2a 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -1004,3 +1004,35 @@ def check_transformer_n_iter(name, estimator): assert_greater(iter_, 1) else: assert_greater(estimator.n_iter_, 1) + + +@ignore_warnings +def check_fit_reset(name, Estimator): + X1, y1 = make_blobs(n_samples=50, n_features=2, center_box=(-200, -150), + centers=2, random_state=0) + X2, y2 = make_blobs(n_samples=100, n_features=3, center_box=(-1, 1), + centers=1, random_state=1) + X3, y3 = make_blobs(n_samples=200, n_features=4, center_box=(-100, -50), + centers=5, random_state=2) + X4, y4 = make_blobs(n_samples=150, n_features=5, center_box=(50, 100), + centers=10, random_state=3) + + estimator_1 = Estimator() + estimator_2 = Estimator() + + set_fast_parameters(estimator_1) + set_fast_parameters(estimator_2) + + set_random_state(estimator_1) + set_random_state(estimator_2) + + _fit(estimator_1, X1, y1) + _fit(estimator_2, X3, y3) + assert_not_same_model(estimator_1, estimator_2) + + _fit(estimator_2, X4, y4) + assert_not_same_model(estimator_1, estimator_2) + + _fit(estimator_1, X2, y2) + _fit(estimator_2, X2, y2) + assert_same_model(estimator_1, estimator_2)