From c407bf47fa34bb23a3ac39764ec15569d4fc1d94 Mon Sep 17 00:00:00 2001 From: Mohit Date: Fri, 25 Aug 2023 15:55:25 +0530 Subject: [PATCH 1/4] extend test cases of test-common.py --- sklearn/preprocessing/tests/test_common.py | 32 ++++++++++++++-------- 1 file changed, 21 insertions(+), 11 deletions(-) diff --git a/sklearn/preprocessing/tests/test_common.py b/sklearn/preprocessing/tests/test_common.py index 9ebef6c000050..80901390aeba9 100644 --- a/sklearn/preprocessing/tests/test_common.py +++ b/sklearn/preprocessing/tests/test_common.py @@ -2,7 +2,6 @@ import numpy as np import pytest -from scipy import sparse from sklearn.base import clone from sklearn.datasets import load_iris @@ -22,6 +21,14 @@ scale, ) from sklearn.utils._testing import assert_allclose, assert_array_equal +from sklearn.utils.fixes import ( + BSR_CONTAINERS, + COO_CONTAINERS, + CSC_CONTAINERS, + CSR_CONTAINERS, + DOK_CONTAINERS, + LIL_CONTAINERS, +) iris = load_iris() @@ -45,8 +52,19 @@ def _get_valid_samples_by_column(X, col): (RobustScaler(with_centering=False), robust_scale, True, False, []), ], ) +@pytest.mark.parametrize( + "sparse_container", + [ + CSR_CONTAINERS, + CSC_CONTAINERS, + COO_CONTAINERS, + LIL_CONTAINERS, + DOK_CONTAINERS, + BSR_CONTAINERS, + ], +) def test_missing_value_handling( - est, func, support_sparse, strictly_positive, omit_kwargs + est, func, support_sparse, strictly_positive, omit_kwargs, sparse_container ): # check that the preprocessing method let pass nan rng = np.random.RandomState(42) @@ -113,15 +131,7 @@ def test_missing_value_handling( Xt_dense = est_dense.fit(X_train).transform(X_test) Xt_inv_dense = est_dense.inverse_transform(Xt_dense) - for sparse_constructor in ( - sparse.csr_matrix, - sparse.csc_matrix, - sparse.bsr_matrix, - sparse.coo_matrix, - sparse.dia_matrix, - sparse.dok_matrix, - sparse.lil_matrix, - ): + for sparse_constructor in sparse_container: # check that the dense and sparse inputs lead to the same results # precompute the matrix to avoid catching side warnings X_train_sp = sparse_constructor(X_train) From 32939bf4ebc656d357cfa42140ca2bdc8c716f25 Mon Sep 17 00:00:00 2001 From: Mohit Date: Tue, 12 Sep 2023 14:46:06 +0530 Subject: [PATCH 2/4] fix: removed the loop --- sklearn/preprocessing/tests/test_common.py | 39 ++++++++++------------ 1 file changed, 18 insertions(+), 21 deletions(-) diff --git a/sklearn/preprocessing/tests/test_common.py b/sklearn/preprocessing/tests/test_common.py index 80901390aeba9..15ebd8cf73380 100644 --- a/sklearn/preprocessing/tests/test_common.py +++ b/sklearn/preprocessing/tests/test_common.py @@ -54,14 +54,12 @@ def _get_valid_samples_by_column(X, col): ) @pytest.mark.parametrize( "sparse_container", - [ - CSR_CONTAINERS, - CSC_CONTAINERS, - COO_CONTAINERS, - LIL_CONTAINERS, - DOK_CONTAINERS, - BSR_CONTAINERS, - ], + CSR_CONTAINERS + + CSC_CONTAINERS + + COO_CONTAINERS + + LIL_CONTAINERS + + DOK_CONTAINERS + + BSR_CONTAINERS, ) def test_missing_value_handling( est, func, support_sparse, strictly_positive, omit_kwargs, sparse_container @@ -131,21 +129,20 @@ def test_missing_value_handling( Xt_dense = est_dense.fit(X_train).transform(X_test) Xt_inv_dense = est_dense.inverse_transform(Xt_dense) - for sparse_constructor in sparse_container: # check that the dense and sparse inputs lead to the same results # precompute the matrix to avoid catching side warnings - X_train_sp = sparse_constructor(X_train) - X_test_sp = sparse_constructor(X_test) - with warnings.catch_warnings(): - warnings.simplefilter("ignore", PendingDeprecationWarning) - warnings.simplefilter("error", RuntimeWarning) - Xt_sp = est_sparse.fit(X_train_sp).transform(X_test_sp) - - assert_allclose(Xt_sp.A, Xt_dense) - with warnings.catch_warnings(): - warnings.simplefilter("ignore", PendingDeprecationWarning) - warnings.simplefilter("error", RuntimeWarning) - Xt_inv_sp = est_sparse.inverse_transform(Xt_sp) + X_train_sp = sparse_container(X_train) + X_test_sp = sparse_container(X_test) + with warnings.catch_warnings(): + warnings.simplefilter("ignore", PendingDeprecationWarning) + warnings.simplefilter("error", RuntimeWarning) + Xt_sp = est_sparse.fit(X_train_sp).transform(X_test_sp) + + assert_allclose(Xt_sp.A, Xt_dense) + with warnings.catch_warnings(): + warnings.simplefilter("ignore", PendingDeprecationWarning) + warnings.simplefilter("error", RuntimeWarning) + Xt_inv_sp = est_sparse.inverse_transform(Xt_sp) assert_allclose(Xt_inv_sp.A, Xt_inv_dense) From 5dea1212e1775e23e893fd9ebafc80f039458c17 Mon Sep 17 00:00:00 2001 From: Mohit Date: Wed, 13 Sep 2023 18:25:57 +0530 Subject: [PATCH 3/4] added suggestion --- sklearn/preprocessing/tests/test_common.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/preprocessing/tests/test_common.py b/sklearn/preprocessing/tests/test_common.py index 15ebd8cf73380..d86a38149bb91 100644 --- a/sklearn/preprocessing/tests/test_common.py +++ b/sklearn/preprocessing/tests/test_common.py @@ -138,13 +138,13 @@ def test_missing_value_handling( warnings.simplefilter("error", RuntimeWarning) Xt_sp = est_sparse.fit(X_train_sp).transform(X_test_sp) - assert_allclose(Xt_sp.A, Xt_dense) + assert_allclose(Xt_sp.toarray(), Xt_dense) with warnings.catch_warnings(): warnings.simplefilter("ignore", PendingDeprecationWarning) warnings.simplefilter("error", RuntimeWarning) Xt_inv_sp = est_sparse.inverse_transform(Xt_sp) - assert_allclose(Xt_inv_sp.A, Xt_inv_dense) + assert_allclose(Xt_inv_sp.toarray(), Xt_inv_dense) @pytest.mark.parametrize( From 1b74ae6036b3f7051952d58b549b539e4b2519c8 Mon Sep 17 00:00:00 2001 From: Yao Xiao <108576690+Charlie-XIAO@users.noreply.github.com> Date: Sat, 30 Sep 2023 16:37:56 +0800 Subject: [PATCH 4/4] Resolve conversations: still use for loop to avoid recomputation of dense --- sklearn/preprocessing/tests/test_common.py | 45 +++++++++++----------- 1 file changed, 23 insertions(+), 22 deletions(-) diff --git a/sklearn/preprocessing/tests/test_common.py b/sklearn/preprocessing/tests/test_common.py index d86a38149bb91..09f702f64ce23 100644 --- a/sklearn/preprocessing/tests/test_common.py +++ b/sklearn/preprocessing/tests/test_common.py @@ -26,6 +26,7 @@ COO_CONTAINERS, CSC_CONTAINERS, CSR_CONTAINERS, + DIA_CONTAINERS, DOK_CONTAINERS, LIL_CONTAINERS, ) @@ -52,17 +53,8 @@ def _get_valid_samples_by_column(X, col): (RobustScaler(with_centering=False), robust_scale, True, False, []), ], ) -@pytest.mark.parametrize( - "sparse_container", - CSR_CONTAINERS - + CSC_CONTAINERS - + COO_CONTAINERS - + LIL_CONTAINERS - + DOK_CONTAINERS - + BSR_CONTAINERS, -) def test_missing_value_handling( - est, func, support_sparse, strictly_positive, omit_kwargs, sparse_container + est, func, support_sparse, strictly_positive, omit_kwargs ): # check that the preprocessing method let pass nan rng = np.random.RandomState(42) @@ -129,20 +121,29 @@ def test_missing_value_handling( Xt_dense = est_dense.fit(X_train).transform(X_test) Xt_inv_dense = est_dense.inverse_transform(Xt_dense) + for sparse_container in ( + BSR_CONTAINERS + + COO_CONTAINERS + + CSC_CONTAINERS + + CSR_CONTAINERS + + DIA_CONTAINERS + + DOK_CONTAINERS + + LIL_CONTAINERS + ): # check that the dense and sparse inputs lead to the same results # precompute the matrix to avoid catching side warnings - X_train_sp = sparse_container(X_train) - X_test_sp = sparse_container(X_test) - with warnings.catch_warnings(): - warnings.simplefilter("ignore", PendingDeprecationWarning) - warnings.simplefilter("error", RuntimeWarning) - Xt_sp = est_sparse.fit(X_train_sp).transform(X_test_sp) - - assert_allclose(Xt_sp.toarray(), Xt_dense) - with warnings.catch_warnings(): - warnings.simplefilter("ignore", PendingDeprecationWarning) - warnings.simplefilter("error", RuntimeWarning) - Xt_inv_sp = est_sparse.inverse_transform(Xt_sp) + X_train_sp = sparse_container(X_train) + X_test_sp = sparse_container(X_test) + with warnings.catch_warnings(): + warnings.simplefilter("ignore", PendingDeprecationWarning) + warnings.simplefilter("error", RuntimeWarning) + Xt_sp = est_sparse.fit(X_train_sp).transform(X_test_sp) + + assert_allclose(Xt_sp.toarray(), Xt_dense) + with warnings.catch_warnings(): + warnings.simplefilter("ignore", PendingDeprecationWarning) + warnings.simplefilter("error", RuntimeWarning) + Xt_inv_sp = est_sparse.inverse_transform(Xt_sp) assert_allclose(Xt_inv_sp.toarray(), Xt_inv_dense)