8000 FIX Fix handling of binary_only tag in check_estimator (#17812) · scikit-learn/scikit-learn@9b42b0c · GitHub
[go: up one dir, main page]

Skip to content

Commit 9b42b0c

Browse files
authored
FIX Fix handling of binary_only tag in check_estimator (#17812)
Co-authored-by: Bruno Charron <bruno@charron.email>
1 parent b78d62f commit 9b42b0c

File tree

3 files changed

+34
-46
lines changed

3 files changed

+34
-46
lines changed

doc/whats_new/v0.24.rst

Lines changed: 7 additions & 0 deletions
8000
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,13 @@ Changelog
294294
:meth:`tree.DecisionTreeRegressor.fit`, and has not effect.
295295
:pr:`17614` by :user:`Juan Carlos Alfaro Jiménez <alfaro96>`.
296296

297+
:mod:`sklearn.utils`
298+
.........................
299+
300+
- |Fix| Fix :func:`utils.estimator_checks.check_estimator` so that all test
301+
cases support the `binary_only` estimator tag.
302+
:pr:`17812` by :user:`Bruno Charron <brcharron>`.
303+
297304
Code and Documentation Contributors
298305
-----------------------------------
299306

sklearn/utils/estimator_checks.py

Lines changed: 15 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -701,15 +701,12 @@ def check_estimator_sparse_data(name, estimator_orig):
701701
X[X < .8] = 0
702702
X = _pairwise_estimator_convert_X(X, estimator_orig)
703703
X_csr = sparse.csr_matrix(X)
704-
tags = estimator_orig._get_tags()
705-
if tags['binary_only']:
706-
y = (2 * rng.rand(40)).astype(int)
707-
else:
708-
y = (4 * rng.rand(40)).astype(int)
704+
y = (4 * rng.rand(40)).astype(int)
709705
# catch deprecation warnings
710706
with ignore_warnings(category=FutureWarning):
711707
estimator = clone(estimator_orig)
712708
y = _enforce_estimator_tags_y(estimator, y)
709+
tags = estimator_orig._get_tags()
713710
for matrix_format, X in _generate_sparse_matrix(X_csr):
714711
# catch deprecation warnings
715712
with ignore_warnings(category=FutureWarning):
@@ -807,10 +804,7 @@ def check_sample_weights_list(name, estimator_orig):
807804
n_samples = 30
808805
X = _pairwise_estimator_convert_X(rnd.uniform(size=(n_samples, 3)),
809806
estimator_orig)
810-
if estimator._get_tags()['binary_only']:
811-
y = np.arange(n_samples) % 2
812-
else:
813-
y = np.arange(n_samples) % 3
807+
y = np.arange(n_samples) % 3
814808
y = _enforce_estimator_tags_y(estimator, y)
815809
sample_weight = [3] * n_samples
816810
# Test that estimators don't raise any exception
@@ -901,10 +895,7 @@ def check_dtype_object(name, estimator_orig):
901895
X = _pairwise_estimator_convert_X(rng.rand(40, 10), estimator_orig)
902896
X = X.astype(object)
903897
tags = estimator_orig._get_tags()
904-
if tags['binary_only']:
905-
y = (X[:, 0] * 2).astype(int)
906-
else:
907-
y = (X[:, 0] * 4).astype(int)
898+
y = (X[:, 0] * 4).astype(int)
908899
estimator = clone(estimator_orig)
909900
y = _enforce_estimator_tags_y(estimator, y)
910901

@@ -998,8 +989,6 @@ def check_dont_overwrite_parameters(name, estimator_orig):
998989
X = 3 * rnd.uniform(size=(20, 3))
999990
X = _pairwise_estimator_convert_X(X, estimator_orig)
1000991
y = X[:, 0].astype(int)
1001-
if estimator._get_tags()['binary_only']:
1002-
y[y == 2] = 1
1003992
y = _enforce_estimator_tags_y(estimator, y)
1004993

1005994
if hasattr(estimator, "n_components"):
@@ -1050,8 +1039,6 @@ def check_fit2d_predict1d(name, estimator_orig):
10501039
X = _pairwise_estimator_convert_X(X, estimator_orig)
10511040
y = X[:, 0].astype(int)
10521041
tags = estimator_orig._get_tags()
1053-
if tags['binary_only']:
1054-
y[y == 2] = 1
10551042
estimator = clone(estimator_orig)
10561043
y = _enforce_estimator_tags_y(estimator, y)
10571044

@@ -1100,8 +1087,6 @@ def check_methods_subset_invariance(name, estimator_orig):
11001087
X = 3 * rnd.uniform(size=(20, 3))
11011088
X = _pairwise_estimator_convert_X(X, estimator_orig)
11021089
y = X[:, 0].astype(int)
1103-
if estimator_orig._get_tags()['binary_only']:
1104-
y[y == 2] = 1
11051090
estimator = clone(estimator_orig)
11061091
y = _enforce_estimator_tags_y(estimator, y)
11071092

@@ -1373,10 +1358,7 @@ def check_fit_score_takes_y(name, estimator_orig):
13731358
n_samples = 30
13741359
X = rnd.uniform(size=(n_samples, 3))
13751360
X = _pairwise_estimator_convert_X(X, estimator_orig)
1376-
if estimator_orig._get_tags()['binary_only']:
1377-
y = np.arange(n_samples) % 2
1378-
else:
1379-
y = np.arange(n_samples) % 3
1361+
y = np.arange(n_samples) % 3
13801362
estimator = clone(estimator_orig)
13811363
y = _enforce_estimator_tags_y(estimator, y)
13821364
set_random_state(estimator)
@@ -1406,8 +1388,6 @@ def check_estimators_dtypes(name, estimator_orig):
14061388
X_train_int_64 = X_train_32.astype(np.int64)
14071389
X_train_int_32 = X_train_32.astype(np.int32)
14081390
y = X_train_int_64[:, 0]
1409-
if estimator_orig._get_tags()['binary_only']:
1410-
y[y == 2] = 1
14111391
y = _enforce_estimator_tags_y(estimator_orig, y)
14121392

14131393
methods = ["predict", "transform", "decision_function", "predict_proba"]
@@ -1581,6 +1561,7 @@ def check_estimators_partial_fit_n_features(name, estimator_orig):
15811561
estimator = clone(estimator_orig)
15821562
X, y = make_blobs(n_samples=50, random_state=1)
15831563
X -= X.min()
1564+
y = _enforce_estimator_tags_y(estimator_orig, y)
15841565

15851566
try:
15861567
if is_classifier(estimator):
@@ -2047,11 +2028,7 @@ def check_classifiers_multilabel_representation_invariance(name,
20472028
def check_estimators_fit_returns_self(name, estimator_orig,
20482029
readonly_memmap=False):
20492030
"""Check if self is returned when calling fit"""
2050-
if estimator_orig._get_tags()['binary_only']:
2051-
n_centers = 2
2052-
else:
2053-
n_centers = 3
2054-
X, y = make_blobs(random_state=0, n_samples=21, centers=n_centers)
2031+
X, y = make_blobs(random_state=0, n_samples=21)
20552032
# some want non-negative input
20562033
X -= X.min()
20572034
X = _pairwise_estimator_convert_X(X, estimator_orig)
@@ -2093,10 +2070,7 @@ def check_supervised_y_2d(name, estimator_orig):
20932070
X = _pairwise_estimator_convert_X(
20942071
rnd.uniform(size=(n_samples, 3)), estimator_orig
20952072
)
2096-
if tags['binary_only']:
2097-
y = np.arange(n_samples) % 2
2098-
else:
2099-
y = np.arange(n_samples) % 3
2073+
y = np.arange(n_samples) % 3
21002074
y = _enforce_estimator_tags_y(estimator_orig, y)
21012075
estimator = clone(estimator_orig)
21022076
set_random_state(estimator)
@@ -2414,11 +2388,7 @@ def check_class_weight_balanced_linear_classifier(name, Classifier):
24142388

24152389
@ignore_warnings(category=FutureWarning)
24162390
def check_estimators_overwrite_params(name, estimator_orig):
2417-
if estimator_orig._get_tags()['binary_only']:
2418-
n_centers = 2
2419-
else:
2420-
n_centers = 3
2421-
X, y = make_blobs(random_state=0, n_samples=21, centers=n_centers)
2391+
X, y = make_blobs(random_state=0, n_samples=21)
24222392
# some want non-negative input
24232393
X -= X.min()
24242394
X = _pairwise_estimator_convert_X(X, estimator_orig, kernel=rbf_kernel)
@@ -2489,7 +2459,8 @@ def check_no_attributes_set_in_init(name, estimator_orig):
24892459
def check_sparsify_coefficients(name, estimator_orig):
24902460
X = np.array([[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1],
24912461
[-1, -2], [2, 2], [-2, -2]])
2492-
y = [1, 1, 1, 2, 2, 2, 3, 3, 3]
2462+
y = np.array([1, 1, 1, 2, 2, 2, 3, 3, 3])
2463+
y = _enforce_estimator_tags_y(estimator_orig, y)
24932464
est = clone(estimator_orig)
24942465

24952466
est.fit(X, y)
@@ -2513,7 +2484,7 @@ def check_classifier_data_not_an_array(name, estimator_orig):
25132484
X = np.array([[3, 0], [0, 1], [0, 2], [1, 1], [1, 2], [2, 1],
25142485
[0, 3], [1, 0], [2, 0], [4, 4], [2, 3], [3, 2]])
25152486
X = _pairwise_estimator_convert_X(X, estimator_orig)
2516-
y = [1, 1, 1, 2, 2, 2, 1, 1, 1, 2, 2, 2]
2487+
y = np.array([1, 1, 1, 2, 2, 2, 1, 1, 1, 2, 2, 2])
25172488
y = _enforce_estimator_tags_y(estimator_orig, y)
25182489
for obj_type in ["NotAnArray", "PandasDataframe"]:
25192490
check_estimators_data_not_an_array(name, estimator_orig, X, y,
@@ -2649,6 +2620,9 @@ def _enforce_estimator_tags_y(estimator, y):
26492620
# Create strictly positive y. The minimal increment above 0 is 1, as
26502621
# y could be of integer dtype.
26512622
y += 1 + abs(y.min())
2623+
# Estimators with a `binary_only` tag only accept up to two unique y values
2624+
if estimator._get_tags()["binary_only"] and y.size > 0:
2625+
y = np.where(y == y.flat[0], y, y.flat[0] + 1)
26522626
# Estimators in mono_output_task_error raise ValueError if y is of 1-D
26532627
# Convert into a 2-D y for those estimators.
26542628
if estimator._get_tags()["multioutput_only"]:

sklearn/utils/tests/test_estimator_checks.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@
3434
from sklearn.linear_model import MultiTaskElasticNet, LogisticRegression
3535
from sklearn.svm import SVC, NuSVC
3636
from sklearn.neighbors import KNeighborsRegressor
37-
from sklearn.tree import DecisionTreeClassifier
3837
from sklearn.utils.validation import check_array
3938
from sklearn.utils import all_estimators
4039
from sklearn.exceptions import SkipTestWarning
@@ -307,11 +306,19 @@ def predict(self, X):
307306
return np.array([self.value_] * X.shape[0])
308307

309308

310-
class UntaggedBinaryClassifier(DecisionTreeClassifier):
309+
class UntaggedBinaryClassifier(SGDClassifier):
311310
# Toy classifier that only supports binary classification, will fail tests.
312-
def fit(self, X, y, sample_weight=None):
313-
super().fit(X, y, sample_weight)
314-
if np.all(self.n_classes_ > 2):
311+
def fit(self, X, y, coef_init=None, intercept_init=None,
312+
sample_weight=None):
313+
super().fit(X, y, coef_init, intercept_init, sample_weight)
314+
if len(self.classes_) > 2:
315+
raise ValueError('Only 2 classes are supported')
316+
return self
317+
318+
def partial_fit(self, X, y, classes=None, sample_weight=None):
319+
super().partial_fit(X=X, y=y, classes=classes,
320+
sample_weight=sample_weight)
321+
if len(self.classes_) > 2:
315322
raise ValueError('Only 2 classes are supported')
316323
return self
317324

0 commit comments

Comments
 (0)
0