FIX Fix handling of binary_only tag in check_estimator (#17812) · glemaitre/scikit-learn@f6c77e9 · GitHub
[go: up one dir, main page]

Skip to content

Commit f6c77e9

Browse files
brcharron authored and glemaitre committed
FIX Fix handling of binary_only tag in check_estimator (scikit-learn#17812)
Co-authored-by: Bruno Charron <bruno@charron.email>
1 parent 3a0af83 commit f6c77e9

File tree

3 files changed

+37
-49
lines changed

3 files changed

+37
-49
lines changed

doc/whats_new/v0.23.rst

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,14 @@ Changelog
9090
- |Fix| Fixed a bug in :func:`metrics.mean_squared_error` where the
9191
average of multiple RMSE values was incorrectly calculated as the root of the
9292
average of multiple MSE values.
93-
:pr:`17309` by :user:`Swier Heeres <swierh>`
93+
:pr:`17309` by :user:`Swier Heeres <swierh>`.
94+
95+
:mod:`sklearn.utils`
96+
....................
97+
98+
- |Fix| Fix :func:`utils.estimator_checks.check_estimator` so that all test
99+
cases support the `binary_only` estimator tag.
100+
:pr:`17812` by :user:`Bruno Charron <brcharron>`.
94101

95102
.. _changes_0_23_1:
96103

sklearn/utils/estimator_checks.py

Lines changed: 17 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -719,15 +719,12 @@ def check_estimator_sparse_data(name, estimator_orig):
719719
X[X < .8] = 0
720720
X = _pairwise_estimator_convert_X(X, estimator_orig)
721721
X_csr = sparse.csr_matrix(X)
722-
tags = estimator_orig._get_tags()
723-
if tags['binary_only']:
724-
y = (2 * rng.rand(40)).astype(np.int)
725-
else:
726-
y = (4 * rng.rand(40)).astype(np.int)
722+
y = (4 * rng.rand(40)).astype(int)
727723
# catch deprecation warnings
728724
with ignore_warnings(category=FutureWarning):
729725
estimator = clone(estimator_orig)
730726
y = _enforce_estimator_tags_y(estimator, y)
727+
tags = estimator_orig._get_tags()
731728
for matrix_format, X in _generate_sparse_matrix(X_csr):
732729
# catch deprecation warnings
733730
with ignore_warnings(category=FutureWarning):
@@ -825,10 +822,7 @@ def check_sample_weights_list(name, estimator_orig):
825822
n_samples = 30
826823
X = _pairwise_estimator_convert_X(rnd.uniform(size=(n_samples, 3)),
827824
estimator_orig)
828-
if estimator._get_tags()['binary_only']:
829-
y = np.arange(n_samples) % 2
830-
else:
831-
y = np.arange(n_samples) % 3
825+
y = np.arange(n_samples) % 3
832826
y = _enforce_estimator_tags_y(estimator, y)
833827
sample_weight = [3] * n_samples
834828
# Test that estimators don't raise any exception
@@ -905,10 +899,7 @@ def check_dtype_object(name, estimator_orig):
905899
X = _pairwise_estimator_convert_X(rng.rand(40, 10), estimator_orig)
906900
X = X.astype(object)
907901
tags = estimator_orig._get_tags()
908-
if tags['binary_only']:
909-
y = (X[:, 0] * 2).astype(np.int)
910-
else:
911-
y = (X[:, 0] * 4).astype(np.int)
902+
y = (X[:, 0] * 4).astype(int)
912903
estimator = clone(estimator_orig)
913904
y = _enforce_estimator_tags_y(estimator, y)
914905

@@ -1007,9 +998,7 @@ def check_dont_overwrite_parameters(name, estimator_orig):
1007998
rnd = np.random.RandomState(0)
1008999
X = 3 * rnd.uniform(size=(20, 3))
10091000
X = _pairwise_estimator_convert_X(X, estimator_orig)
1010-
y = X[:, 0].astype(np.int)
1011-
if estimator._get_tags()['binary_only']:
1012-
y[y == 2] = 1
1001+
y = X[:, 0].astype(int)
10131002
y = _enforce_estimator_tags_y(estimator, y)
10141003

10151004
if hasattr(estimator, "n_components"):
@@ -1060,8 +1049,6 @@ def check_fit2d_predict1d(name, estimator_orig):
10601049
X = _pairwise_estimator_convert_X(X, estimator_orig)
10611050
y = X[:, 0].astype(np.int)
10621051
tags = estimator_orig._get_tags()
1063-
if tags['binary_only']:
1064-
y[y == 2] = 1
10651052
estimator = clone(estimator_orig)
10661053
y = _enforce_estimator_tags_y(estimator, y)
10671054

@@ -1109,9 +1096,7 @@ def check_methods_subset_invariance(name, estimator_orig):
11091096
rnd = np.random.RandomState(0)
11101097
X = 3 * rnd.uniform(size=(20, 3))
11111098
X = _pairwise_estimator_convert_X(X, estimator_orig)
1112-
y = X[:, 0].astype(np.int)
1113-
if estimator_orig._get_tags()['binary_only']:
1114-
y[y == 2] = 1
1099+
y = X[:, 0].astype(int)
11151100
estimator = clone(estimator_orig)
11161101
y = _enforce_estimator_tags_y(estimator, y)
11171102

@@ -1383,10 +1368,7 @@ def check_fit_score_takes_y(name, estimator_orig):
13831368
n_samples = 30
13841369
X = rnd.uniform(size=(n_samples, 3))
13851370
X = _pairwise_estimator_convert_X(X, estimator_orig)
1386-
if estimator_orig._get_tags()['binary_only']:
1387-
y = np.arange(n_samples) % 2
1388-
else:
1389-
y = np.arange(n_samples) % 3
1371+
y = np.arange(n_samples) % 3
13901372
estimator = clone(estimator_orig)
13911373
y = _enforce_estimator_tags_y(estimator, y)
13921374
set_random_state(estimator)
@@ -1416,8 +1398,6 @@ def check_estimators_dtypes(name, estimator_orig):
14161398
X_train_int_64 = X_train_32.astype(np.int64)
14171399
X_train_int_32 = X_train_32.astype(np.int32)
14181400
y = X_train_int_64[:, 0]
1419-
if estimator_orig._get_tags()['binary_only']:
1420-
y[y == 2] = 1
14211401
y = _enforce_estimator_tags_y(estimator_orig, y)
14221402

14231403
methods = ["predict", "transform", "decision_function", "predict_proba"]
@@ -1596,6 +1576,7 @@ def check_estimators_partial_fit_n_features(name, estimator_orig):
15961576
estimator = clone(estimator_orig)
15971577
X, y = make_blobs(n_samples=50, random_state=1)
15981578
X -= X.min()
1579+
y = _enforce_estimator_tags_y(estimator_orig, y)
15991580

16001581
try:
16011582
if is_classifier(estimator):
@@ -2062,11 +2043,7 @@ def check_classifiers_multilabel_representation_invariance(name,
20622043
def check_estimators_fit_returns_self(name, estimator_orig,
20632044
readonly_memmap=False):
20642045
"""Check if self is returned when calling fit"""
2065-
if estimator_orig._get_tags()['binary_only']:
2066-
n_centers = 2
2067-
else:
2068-
n_centers = 3
2069-
X, y = make_blobs(random_state=0, n_samples=21, centers=n_centers)
2046+
X, y = make_blobs(random_state=0, n_samples=21)
20702047
# some want non-negative input
20712048
X -= X.min()
20722049
X = _pairwise_estimator_convert_X(X, estimator_orig)
@@ -2108,10 +2085,7 @@ def check_supervised_y_2d(name, estimator_orig):
21082085
X = _pairwise_estimator_convert_X(
21092086
rnd.uniform(size=(n_samples, 3)), estimator_orig
21102087
)
2111-
if tags['binary_only']:
2112-
y = np.arange(n_samples) % 2
2113-
else:
2114-
y = np.arange(n_samples) % 3
2088+
y = np.arange(n_samples) % 3
21152089
y = _enforce_estimator_tags_y(estimator_orig, y)
21162090
estimator = clone(estimator_orig)
21172091
set_random_state(estimator)
@@ -2436,11 +2410,7 @@ def check_class_weight_balanced_linear_classifier(name, Classifier):
24362410

24372411
@ignore_warnings(category=FutureWarning)
24382412
def check_estimators_overwrite_params(name, estimator_orig):
2439-
if estimator_orig._get_tags()['binary_only']:
2440-
n_centers = 2
2441-
else:
2442-
n_centers = 3
2443-
X, y = make_blobs(random_state=0, n_samples=21, centers=n_centers)
2413+
X, y = make_blobs(random_state=0, n_samples=21)
24442414
# some want non-negative input
24452415
X -= X.min()
24462416
X = _pairwise_estimator_convert_X(X, estimator_orig, kernel=rbf_kernel)
@@ -2511,7 +2481,8 @@ def check_no_attributes_set_in_init(name, estimator_orig):
25112481
def check_sparsify_coefficients(name, estimator_orig):
25122482
X = np.array([[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1],
25132483
[-1, -2], [2, 2], [-2, -2]])
2514-
y = [1, 1, 1, 2, 2, 2, 3, 3, 3]
2484+
y = np.array([1, 1, 1, 2, 2, 2, 3, 3, 3])
2485+
y = _enforce_estimator_tags_y(estimator_orig, y)
25152486
est = clone(estimator_orig)
25162487

25172488
est.fit(X, y)
@@ -2535,7 +2506,7 @@ def check_classifier_data_not_an_array(name, estimator_orig):
25352506
X = np.array([[3, 0], [0, 1], [0, 2], [1, 1], [1, 2], [2, 1],
25362507
[0, 3], [1, 0], [2, 0], [4, 4], [2, 3], [3, 2]])
25372508
X = _pairwise_estimator_convert_X(X, estimator_orig)
2538-
y = [1, 1, 1, 2, 2, 2, 1, 1, 1, 2, 2, 2]
2509+
y = np.array([1, 1, 1, 2, 2, 2, 1, 1, 1, 2, 2, 2])
25392510
y = _enforce_estimator_tags_y(estimator_orig, y)
25402511
for obj_type in ["NotAnArray", "PandasDataframe"]:
25412512
check_estimators_data_not_an_array(name, estimator_orig, X, y,
@@ -2682,6 +2653,9 @@ def _enforce_estimator_tags_y(estimator, y):
26822653
# Create strictly positive y. The minimal increment above 0 is 1, as
26832654
# y could be of integer dtype.
26842655
y += 1 + abs(y.min())
2656+
# Estimators with a `binary_only` tag only accept up to two unique y values
2657+
if estimator._get_tags()["binary_only"] and y.size > 0:
2658+
y = np.where(y == y.flat[0], y, y.flat[0] + 1)
26852659
# Estimators in mono_output_task_error raise ValueError if y is of 1-D
26862660
# Convert into a 2-D y for those estimators.
26872661
if estimator._get_tags()["multioutput_only"]:

sklearn/utils/tests/test_estimator_checks.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@
3434
from sklearn.linear_model import MultiTaskElasticNet, LogisticRegression
3535
from sklearn.svm import SVC
3636
from sklearn.neighbors import KNeighborsRegressor
37-
from sklearn.tree import DecisionTreeClassifier
3837
from sklearn.utils.validation import check_array
3938
from sklearn.utils import all_estimators
4039

@@ -306,11 +305,19 @@ def predict(self, X):
306305
return np.array([self.value_] * X.shape[0])
307306

308307

309-
class UntaggedBinaryClassifier(DecisionTreeClassifier):
308+
class UntaggedBinaryClassifier(SGDClassifier):
310309
# Toy classifier that only supports binary classification, will fail tests.
311-
def fit(self, X, y, sample_weight=None):
312-
super().fit(X, y, sample_weight)
313-
if np.all(self.n_classes_ > 2):
310+
def fit(self, X, y, coef_init=None, intercept_init=None,
311+
sample_weight=None):
312+
super().fit(X, y, coef_init, intercept_init, sample_weight)
313+
if len(self.classes_) > 2:
314+
raise ValueError('Only 2 classes are supported')
315+
return self
316+
317+
def partial_fit(self, X, y, classes=None, sample_weight=None):
318+
super().partial_fit(X=X, y=y, classes=classes,
319+
sample_weight=sample_weight)
320+
if len(self.classes_) > 2:
314321
raise ValueError('Only 2 classes are supported')
315322
return self
316323

0 commit comments

Comments
 (0)
0