From 7448c04c5b7aa9a694da1b893ea3b7409a3c5fe9 Mon Sep 17 00:00:00 2001 From: adrinjalali Date: Tue, 16 Jul 2024 12:57:23 +0200 Subject: [PATCH 1/4] MNT Fix E721 linting issues to do type comparisons with is --- .../plot_tweedie_regression_insurance_claims.py | 2 +- sklearn/cluster/_optics.py | 2 +- sklearn/cluster/tests/test_dbscan.py | 2 +- sklearn/linear_model/tests/test_ridge.py | 4 ++-- sklearn/metrics/pairwise.py | 2 +- sklearn/model_selection/_split.py | 2 +- sklearn/model_selection/tests/test_validation.py | 8 ++++---- sklearn/utils/estimator_checks.py | 2 +- sklearn/utils/tests/test_validation.py | 2 +- sklearn/utils/validation.py | 2 +- 10 files changed, 14 insertions(+), 14 deletions(-) diff --git a/examples/linear_model/plot_tweedie_regression_insurance_claims.py b/examples/linear_model/plot_tweedie_regression_insurance_claims.py index 31a91fb37c766..b18702bdef2b6 100644 --- a/examples/linear_model/plot_tweedie_regression_insurance_claims.py +++ b/examples/linear_model/plot_tweedie_regression_insurance_claims.py @@ -79,7 +79,7 @@ def load_mtpl2(n_samples=None): df["ClaimAmount"] = df["ClaimAmount"].fillna(0) # unquote string fields - for column_name in df.columns[df.dtypes.values == object]: + for column_name in df.columns[[t is object for t in df.dtypes.values]]: df[column_name] = df[column_name].str.strip("'") return df.iloc[:n_samples] diff --git a/sklearn/cluster/_optics.py b/sklearn/cluster/_optics.py index b2a0c4d642a00..70eee67b0a98b 100755 --- a/sklearn/cluster/_optics.py +++ b/sklearn/cluster/_optics.py @@ -324,7 +324,7 @@ def fit(self, X, y=None): Returns a fitted instance of self. """ dtype = bool if self.metric in PAIRWISE_BOOLEAN_FUNCTIONS else float - if dtype == bool and X.dtype != bool: + if dtype is bool and X.dtype != bool: msg = ( "Data will be converted to boolean for" f" metric {self.metric}, to avoid this warning," diff --git a/sklearn/cluster/tests/test_dbscan.py b/sklearn/cluster/tests/test_dbscan.py index d42cc2b17d518..556f89312d2fc 100644 --- a/sklearn/cluster/tests/test_dbscan.py +++ b/sklearn/cluster/tests/test_dbscan.py @@ -291,7 +291,7 @@ def test_input_validation(): def test_pickle(): obj = DBSCAN() s = pickle.dumps(obj) - assert type(pickle.loads(s)) == obj.__class__ + assert type(pickle.loads(s)) is obj.__class__ def test_boundaries(): diff --git a/sklearn/linear_model/tests/test_ridge.py b/sklearn/linear_model/tests/test_ridge.py index 167ce0bac4cba..9be28cac141b1 100644 --- a/sklearn/linear_model/tests/test_ridge.py +++ b/sklearn/linear_model/tests/test_ridge.py @@ -1020,7 +1020,7 @@ def _test_ridge_cv(sparse_container): ridge_cv.predict(X) assert len(ridge_cv.coef_.shape) == 1 - assert type(ridge_cv.intercept_) == np.float64 + assert type(ridge_cv.intercept_) is np.float64 cv = KFold(5) ridge_cv.set_params(cv=cv) @@ -1028,7 +1028,7 @@ def _test_ridge_cv(sparse_container): ridge_cv.predict(X) assert len(ridge_cv.coef_.shape) == 1 - assert type(ridge_cv.intercept_) == np.float64 + assert type(ridge_cv.intercept_) is np.float64 @pytest.mark.parametrize( diff --git a/sklearn/metrics/pairwise.py b/sklearn/metrics/pairwise.py index f8b163813d6d6..3408a840ef336 100644 --- a/sklearn/metrics/pairwise.py +++ b/sklearn/metrics/pairwise.py @@ -2388,7 +2388,7 @@ def pairwise_distances( dtype = bool if metric in PAIRWISE_BOOLEAN_FUNCTIONS else "infer_float" - if dtype == bool and (X.dtype != bool or (Y is not None and Y.dtype != bool)): + if dtype is bool and (X.dtype != bool or (Y is not None and Y.dtype != bool)): msg = "Data was converted to boolean for metric %s" % metric warnings.warn(msg, DataConversionWarning) diff --git a/sklearn/model_selection/_split.py b/sklearn/model_selection/_split.py index bfd741eee5811..af35f903e4832 100644 --- a/sklearn/model_selection/_split.py +++ b/sklearn/model_selection/_split.py @@ -2935,7 +2935,7 @@ def _build_repr(self): value = getattr(self, key, None) if value is None and hasattr(self, "cvargs"): value = self.cvargs.get(key, None) - if len(w) and w[0].category == FutureWarning: + if len(w) and w[0].category is FutureWarning: # if the parameter is deprecated, don't show it continue finally: diff --git a/sklearn/model_selection/tests/test_validation.py b/sklearn/model_selection/tests/test_validation.py index 33d4d366bf17a..911a3bac2d672 100644 --- a/sklearn/model_selection/tests/test_validation.py +++ b/sklearn/model_selection/tests/test_validation.py @@ -586,10 +586,10 @@ def custom_scorer(clf, X, y): ) # Make sure all the arrays are of np.ndarray type - assert type(cv_results["test_r2"]) == np.ndarray - assert type(cv_results["test_neg_mean_squared_error"]) == np.ndarray - assert type(cv_results["fit_time"]) == np.ndarray - assert type(cv_results["score_time"]) == np.ndarray + assert type(cv_results["test_r2"]) is np.ndarray + assert type(cv_results["test_neg_mean_squared_error"]) is np.ndarray + assert type(cv_results["fit_time"]) is np.ndarray + assert type(cv_results["score_time"]) is np.ndarray # Ensure all the times are within sane limits assert np.all(cv_results["fit_time"] >= 0) diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index 422a23bb5ef72..d3e10616c17b2 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -1501,7 +1501,7 @@ def _apply_on_subsets(func, X): result_by_batch = [func(batch.reshape(1, n_features)) for batch in X] # func can output tuple (e.g. score_samples) - if type(result_full) == tuple: + if type(result_full) is tuple: result_full = result_full[0] result_by_batch = list(map(lambda x: x[0], result_by_batch)) diff --git a/sklearn/utils/tests/test_validation.py b/sklearn/utils/tests/test_validation.py index 5bde51ae514d9..c567cafbac624 100644 --- a/sklearn/utils/tests/test_validation.py +++ b/sklearn/utils/tests/test_validation.py @@ -1341,7 +1341,7 @@ def test_check_scalar_invalid( include_boundaries=include_boundaries, ) assert str(raised_error.value) == str(err_msg) - assert type(raised_error.value) == type(err_msg) + assert type(raised_error.value) is type(err_msg) _psd_cases_valid = { diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index af9fdb4a79cba..612d93f1b21aa 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -879,7 +879,7 @@ def is_sparse(dtype): ) if all(isinstance(dtype_iter, np.dtype) for dtype_iter in dtypes_orig): dtype_orig = np.result_type(*dtypes_orig) - elif pandas_requires_conversion and any(d == object for d in dtypes_orig): + elif pandas_requires_conversion and any(d is object for d in dtypes_orig): # Force object if any of the dtypes is an object dtype_orig = object From 8705dc32aaaa7932a989c0b55b76c050a29185cb Mon Sep 17 00:00:00 2001 From: adrinjalali Date: Tue, 30 Jul 2024 13:49:22 +0200 Subject: [PATCH 2/4] apply Guillaume's suggestions --- sklearn/cluster/_optics.py | 2 +- sklearn/metrics/pairwise.py | 4 +++- sklearn/model_selection/tests/test_validation.py | 8 ++++---- sklearn/utils/estimator_checks.py | 2 +- sklearn/utils/tests/test_validation.py | 2 +- 5 files changed, 10 insertions(+), 8 deletions(-) diff --git a/sklearn/cluster/_optics.py b/sklearn/cluster/_optics.py index 816b46b0d4984..347c33869aaf4 100755 --- a/sklearn/cluster/_optics.py +++ b/sklearn/cluster/_optics.py @@ -327,7 +327,7 @@ def fit(self, X, y=None): Returns a fitted instance of self. """ dtype = bool if self.metric in PAIRWISE_BOOLEAN_FUNCTIONS else float - if dtype is bool and X.dtype != bool: + if dtype is bool and X.dtype is not bool: msg = ( "Data will be converted to boolean for" f" metric {self.metric}, to avoid this warning," diff --git a/sklearn/metrics/pairwise.py b/sklearn/metrics/pairwise.py index c463d5dfc1cdb..2ee58f891df6d 100644 --- a/sklearn/metrics/pairwise.py +++ b/sklearn/metrics/pairwise.py @@ -2433,7 +2433,9 @@ def pairwise_distances( dtype = bool if metric in PAIRWISE_BOOLEAN_FUNCTIONS else "infer_float" - if dtype is bool and (X.dtype != bool or (Y is not None and Y.dtype != bool)): + if dtype is bool and ( + X.dtype is not bool or (Y is not None and Y.dtype is not bool) + ): msg = "Data was converted to boolean for metric %s" % metric warnings.warn(msg, DataConversionWarning) diff --git a/sklearn/model_selection/tests/test_validation.py b/sklearn/model_selection/tests/test_validation.py index 2578a86fa8103..4ff69fe1a1c9e 100644 --- a/sklearn/model_selection/tests/test_validation.py +++ b/sklearn/model_selection/tests/test_validation.py @@ -586,10 +586,10 @@ def custom_scorer(clf, X, y): ) # Make sure all the arrays are of np.ndarray type - assert type(cv_results["test_r2"]) is np.ndarray - assert type(cv_results["test_neg_mean_squared_error"]) is np.ndarray - assert type(cv_results["fit_time"]) is np.ndarray - assert type(cv_results["score_time"]) is np.ndarray + assert isinstance(cv_results["test_r2"], np.ndarray) + assert isinstance(cv_results["test_neg_mean_squared_error"], np.ndarray) + assert isinstance(cv_results["fit_time"], np.ndarray) + assert isinstance(cv_results["score_time"], np.ndarray) # Ensure all the times are within sane limits assert np.all(cv_results["fit_time"] >= 0) diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index 7421259ab2aab..ea7e8f52c1dca 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -1502,7 +1502,7 @@ def _apply_on_subsets(func, X): result_by_batch = [func(batch.reshape(1, n_features)) for batch in X] # func can output tuple (e.g. score_samples) - if type(result_full) is tuple: + if isinstance(result_full, tuple): result_full = result_full[0] result_by_batch = list(map(lambda x: x[0], result_by_batch)) diff --git a/sklearn/utils/tests/test_validation.py b/sklearn/utils/tests/test_validation.py index 363670ef893c0..3afe02ffeef6f 100644 --- a/sklearn/utils/tests/test_validation.py +++ b/sklearn/utils/tests/test_validation.py @@ -1339,7 +1339,7 @@ def test_check_scalar_invalid( include_boundaries=include_boundaries, ) assert str(raised_error.value) == str(err_msg) - assert type(raised_error.value) is type(err_msg) + assert isinstance(raised_error.value, type(err_msg)) _psd_cases_valid = { From 488a191e0f5fe65425494fcfe936c883ac87809b Mon Sep 17 00:00:00 2001 From: adrinjalali Date: Wed, 31 Jul 2024 12:17:09 +0200 Subject: [PATCH 3/4] revert numpy type comparison with is --- sklearn/cluster/_optics.py | 2 +- sklearn/metrics/pairwise.py | 4 +--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/sklearn/cluster/_optics.py b/sklearn/cluster/_optics.py index 347c33869aaf4..816b46b0d4984 100755 --- a/sklearn/cluster/_optics.py +++ b/sklearn/cluster/_optics.py @@ -327,7 +327,7 @@ def fit(self, X, y=None): Returns a fitted instance of self. """ dtype = bool if self.metric in PAIRWISE_BOOLEAN_FUNCTIONS else float - if dtype is bool and X.dtype is not bool: + if dtype is bool and X.dtype != bool: msg = ( "Data will be converted to boolean for" f" metric {self.metric}, to avoid this warning," diff --git a/sklearn/metrics/pairwise.py b/sklearn/metrics/pairwise.py index 2ee58f891df6d..c463d5dfc1cdb 100644 --- a/sklearn/metrics/pairwise.py +++ b/sklearn/metrics/pairwise.py @@ -2433,9 +2433,7 @@ def pairwise_distances( dtype = bool if metric in PAIRWISE_BOOLEAN_FUNCTIONS else "infer_float" - if dtype is bool and ( - X.dtype is not bool or (Y is not None and Y.dtype is not bool) - ): + if dtype is bool and (X.dtype != bool or (Y is not None and Y.dtype != bool)): msg = "Data was converted to boolean for metric %s" % metric warnings.warn(msg, DataConversionWarning) From 13ce565c2824eb3b9e22a98f379323d0415fbdad Mon Sep 17 00:00:00 2001 From: adrinjalali Date: Tue, 6 Aug 2024 12:45:41 +0200 Subject: [PATCH 4/4] fix another is vs == issue --- sklearn/utils/validation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index 0747db1919438..8a8c12506216e 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -926,7 +926,7 @@ def is_sparse(dtype): ) if all(isinstance(dtype_iter, np.dtype) for dtype_iter in dtypes_orig): dtype_orig = np.result_type(*dtypes_orig) - elif pandas_requires_conversion and any(d is object for d in dtypes_orig): + elif pandas_requires_conversion and any(d == object for d in dtypes_orig): # Force object if any of the dtypes is an object dtype_orig = object