8000 MNT Fix E721 linting issues to do type comparisons with is (#29501) · scikit-learn/scikit-learn@68d8c2c · GitHub
[go: up one dir, main page]

Skip to content

Commit 68d8c2c

Browse files
authored
MNT Fix E721 linting issues to do type comparisons with is (#29501)
1 parent b5b24f8 commit 68d8c2c

File tree

9 files changed

+13
-13
lines changed

9 files changed

+13
-13
lines changed

examples/linear_model/plot_tweedie_regression_insurance_claims.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def load_mtpl2(n_samples=None):
8080
df["ClaimAmount"] = df["ClaimAmount"].fillna(0)
8181

8282
# unquote string fields
83-
for column_name in df.columns[df.dtypes.values == object]:
83+
for column_name in df.columns[[t is object for t in df.dtypes.values]]:
8484
df[column_name] = df[column_name].str.strip("'")
8585
return df.iloc[:n_samples]
8686

sklearn/cluster/_optics.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,7 @@ def fit(self, X, y=None):
327327
Returns a fitted instance of self.
328328
"""
329329
dtype = bool if self.metric in PAIRWISE_BOOLEAN_FUNCTIONS else float
330-
if dtype == bool and X.dtype != bool:
330+
if dtype is bool and X.dtype != bool:
331331
msg = (
332332
"Data will be converted to boolean for"
333333
f" metric {self.metric}, to avoid this warning,"

sklearn/cluster/tests/test_dbscan.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,7 @@ def test_input_validation():
291291
def test_pickle():
292292
obj = DBSCAN()
293293
s = pickle.dumps(obj)
294-
assert type(pickle.loads(s)) == obj.__class__
294+
assert type(pickle.loads(s)) is obj.__class__
295295

296296

297297
def test_boundaries():

sklearn/linear_model/tests/test_ridge.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1020,15 +1020,15 @@ def _test_ridge_cv(sparse_container):
10201020
ridge_cv.predict(X)
10211021

10221022
assert len(ridge_cv.coef_.shape) == 1
1023-
assert type(ridge_cv.intercept_) == np.float64
1023+
assert type(ridge_cv.intercept_) is np.float64
10241024

10251025
cv = KFold(5)
10261026
ridge_cv.set_params(cv=cv)
10271027
ridge_cv.fit(X, y_diabetes)
10281028
ridge_cv.predict(X)
10291029

10301030
assert len(ridge_cv.coef_.shape) == 1
1031-
assert type(ridge_cv.intercept_) == np.float64
1031+
assert type(ridge_cv.intercept_) is np.float64
10321032

10331033

10341034
@pytest.mark.parametrize(

sklearn/metrics/pairwise.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2438,7 +2438,7 @@ def pairwise_distances(
24382438

24392439
dtype = bool if metric in PAIRWISE_BOOLEAN_FUNCTIONS else "infer_float"
24402440

2441-
if dtype == bool and (X.dtype != bool or (Y is not None and Y.dtype != bool)):
2441+
if dtype is bool and (X.dtype != bool or (Y is not None and Y.dtype != bool)):
24422442
msg = "Data was converted to boolean for metric %s" % metric
24432443
warnings.warn(msg, DataConversionWarning)
24442444

sklearn/model_selection/_split.py 6D40

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2940,7 +2940,7 @@ def _build_repr(self):
29402940
value = getattr(self, key, None)
29412941
if value is None and hasattr(self, "cvargs"):
29422942
value = self.cvargs.get(key, None)
2943-
if len(w) and w[0].category == FutureWarning:
2943+
if len(w) and w[0].category is FutureWarning:
29442944
# if the parameter is deprecated, don't show it
29452945
continue
29462946
finally:

sklearn/model_selection/tests/test_validation.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -586,10 +586,10 @@ def custom_scorer(clf, X, y):
586586
)
587587

588588
# Make sure all the arrays are of np.ndarray type
589-
assert type(cv_results["test_r2"]) == np.ndarray
590-
assert type(cv_results["test_neg_mean_squared_error"]) == np.ndarray
591-
assert type(cv_results["fit_time"]) == np.ndarray
592-
assert type(cv_results["score_time"]) == np.ndarray
589+
assert isinstance(cv_results["test_r2"], np.ndarray)
590+
assert isinstance(cv_results["test_neg_mean_squared_error"], np.ndarray)
591+
assert isinstance(cv_results["fit_time"], np.ndarray)
592+
assert isinstance(cv_results["score_time"], np.ndarray)
593593

594594
# Ensure all the times are within sane limits
595595
assert np.all(cv_results["fit_time"] >= 0)

sklearn/utils/estimator_checks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1502,7 +1502,7 @@ def _apply_on_subsets(func, X):
15021502
result_by_batch = [func(batch.reshape(1, n_features)) for batch in X]
15031503

15041504
# func can output tuple (e.g. score_samples)
1505-
if type(result_full) == tuple:
1505+
if isinstance(result_full, tuple):
15061506
result_full = result_full[0]
15071507
result_by_batch = list(map(lambda x: x[0], result_by_batch))
15081508

sklearn/utils/tests/test_validation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1350,7 +1350,7 @@ def test_check_scalar_invalid(
13501350
include_boundaries=include_boundaries,
13511351
)
13521352
assert str(raised_error.value) == str(err_msg)
1353-
assert type(raised_error.value) == type(err_msg)
1353+
assert isinstance(raised_error.value, type(err_msg))
13541354

13551355

13561356
_psd_cases_valid = {

0 commit comments

Comments
 (0)
0