8000 TST check multilabel common check for supported estimators (#19859) · scikit-learn/scikit-learn@a44e9a8 · GitHub
[go: up one dir, main page]

Skip to content

Commit a44e9a8

Browse files
glemaitrejjerphan
andauthored
TST check multilabel common check for supported estimators (#19859)
Co-authored-by: Julien Jerphanion <git@jjerphan.xyz>
1 parent f812e2a commit a44e9a8

File tree

7 files changed

+464
-19
lines changed

7 files changed

+464
-19
lines changed

sklearn/ensemble/_forest.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ class calls the ``fit`` method of each sub-estimator on random samples
5151
from joblib import Parallel
5252

5353
from ..base import is_classifier
54-
from ..base import ClassifierMixin, RegressorMixin, MultiOutputMixin
54+
from ..base import ClassifierMixin, MultiOutputMixin, RegressorMixin
5555
from ..metrics import accuracy_score, r2_score
5656
from ..preprocessing import OneHotEncoder
5757
from ..tree import (
@@ -1052,6 +1052,9 @@ def _compute_partial_dependence_recursion(self, grid, target_features):
10521052

10531053
return averaged_predictions
10541054

1055+
def _more_tags(self):
1056+
return {"multilabel": True}
1057+
10551058

10561059
class RandomForestClassifier(ForestClassifier):
10571060
"""

sklearn/linear_model/_ridge.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from ._base import LinearClassifierMixin, LinearModel
2222
from ._base import _deprecate_normalize, _rescale_data
2323
from ._sag import sag_solver
24-
from ..base import RegressorMixin, MultiOutputMixin, is_classifier
24+
from ..base import MultiOutputMixin, RegressorMixin, is_classifier
2525
from ..utils.extmath import safe_sparse_dot
2626
from ..utils.extmath import row_norms
2727
from ..utils import check_array
@@ -2319,9 +2319,17 @@ def classes_(self):
23192319

23202320
def _more_tags(self):
23212321
return {
2322+
"multilabel": True,
23222323
"_xfail_checks": {
23232324
"check_sample_weights_invariance": (
23242325
"zero sample_weight is not equivalent to removing samples"
23252326
),
2326-
}
2327+
# FIXME: see
2328+
# https://github.com/scikit-learn/scikit-learn/issues/19858
2329+
# to track progress to resolve this issue
2330+
"check_classifiers_multilabel_output_format_predict": (
2331+
"RidgeClassifierCV.predict outputs an array of shape (25,) "
2332+
"instead of (25, 5)"
2333+
),
2334+
},
23272335
}

sklearn/neighbors/_classification.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,9 @@ def predict_proba(self, X):
287287

288288
return probabilities
289289

290+
def _more_tags(self):
291+
return {"multilabel": True}
292+
290293

291294
class RadiusNeighborsClassifier(RadiusNeighborsMixin, ClassifierMixin, NeighborsBase):
292295
"""Classifier implementing a vote among neighbors within a given radius
@@ -651,3 +654,6 @@ def predict_proba(self, X):
651654
probabilities = probabilities[0]
652655

653656
return probabilities
657+
658+
def _more_tags(self):
659+
return {"multilabel": True}

sklearn/neural_network/_multilayer_perceptron.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,19 @@
66
# Jiyuan Qian
77
# License: BSD 3 clause
88

9+
from tkinter.tix import Tree
910
import numpy as np
1011

1112
from abc import ABCMeta, abstractmethod
1213
import warnings
1314

1415
import scipy.optimize
1516

16-
from ..base import BaseEstimator, ClassifierMixin, RegressorMixin
17+
from ..base import (
18+
BaseEstimator,
19+
ClassifierMixin,
20+
RegressorMixin,
21+
)
1722
from ..base import is_classifier
1823
from ._base import ACTIVATIONS, DERIVATIVES, LOSS_FUNCTIONS
1924
from ._stochastic_optimizers import SGDOptimizer, AdamOptimizer
@@ -1246,6 +1251,9 @@ def predict_proba(self, X):
12461251
else:
12471252
return y_pred
12481253

1254+
def _more_tags(self):
1255+
return {"multilabel": Tree}
1256+
12491257

12501258
class MLPRegressor(RegressorMixin, BaseMultilayerPerceptron):
12511259
"""Multi-layer Perceptron regressor.

sklearn/tree/_classes.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1021,6 +1021,9 @@ def predict_log_proba(self, X):
10211021
def n_features_(self):
10221022
return self.n_features_in_
10231023

1024+
def _more_tags(self):
1025+
return {"multilabel": True}
1026+
10241027

10251028
class DecisionTreeRegressor(RegressorMixin, BaseDecisionTree):
10261029
"""A decision tree regressor.

sklearn/utils/estimator_checks.py

Lines changed: 183 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from ._testing import assert_array_almost_equal
2020
from ._testing import assert_allclose
2121
from ._testing import assert_allclose_dense_sparse
22+
from ._testing import assert_array_less
2223
from ._testing import set_random_state
2324
from ._testing import SkipTest
2425
from ._testing import ignore_warnings
@@ -141,6 +142,9 @@ def _yield_classifier_checks(classifier):
141142
yield check_classifiers_regression_target
142143
if tags["multilabel"]:
143144
yield check_classifiers_multilabel_representation_invariance
145+
yield check_classifiers_multilabel_output_format_predict
146+
yield check_classifiers_multilabel_output_format_predict_proba
147+
yield check_classifiers_multilabel_output_format_decision_function
144148
if not tags["no_validation"]:
145149
yield check_supervised_y_no_nan
146150
if not tags["multioutput_only"]:
@@ -651,7 +655,7 @@ def _set_checking_parameters(estimator):
651655
estimator.set_params(strategy="stratified")
652656

653657
# Speed-up by reducing the number of CV or splits for CV estimators
654-
loo_cv = ["RidgeCV"]
658+
loo_cv = ["RidgeCV", "RidgeClassifierCV"]
655659
if name not in loo_cv and hasattr(estimator, "cv"):
656660
estimator.set_params(cv=3)
657661
if hasattr(estimator, "n_splits"):
@@ -2258,18 +2262,18 @@ def check_outliers_train(name, estimator_orig, readonly_memmap=True):
22582262
estimator.fit(X)
22592263

22602264

2261-
@ignore_warnings(category=(FutureWarning))
2265+
@ignore_warnings(category=FutureWarning)
22622266
def check_classifiers_multilabel_representation_invariance(name, classifier_orig):
2263-
22642267
X, y = make_multilabel_classification(
22652268
n_samples=100,
2266-
n_features=20,
2269+
n_features=2,
22672270
n_classes=5,
22682271
n_labels=3,
22692272
length=50,
22702273
allow_unlabeled=True,
22712274
random_state=0,
22722275
)
2276+
X = scale(X)
22732277

22742278
X_train, y_train = X[:80], y[:80]
22752279
X_test = X[80:]
@@ -2299,6 +2303,181 @@ def check_classifiers_multilabel_representation_invariance(name, classifier_orig
22992303
assert type(y_pred) == type(y_pred_list_of_lists)
23002304

23012305

2306+
@ignore_warnings(category=FutureWarning)
2307+
def check_classifiers_multilabel_output_format_predict(name, classifier_orig):
2308+
"""Check the output of the `predict` method for classifiers supporting
2309+
multilabel-indicator targets."""
2310+
classifier = clone(classifier_orig)
2311+
set_random_state(classifier)
2312+
2313+
n_samples, test_size, n_outputs = 100, 25, 5
2314+
X, y = make_multilabel_classification(
2315+
n_samples=n_samples,
2316+
n_features=2,
2317+
n_classes=n_outputs,
2318+
n_labels=3,
2319+
length=50,
2320+
allow_unlabeled=True,
2321+
random_state=0,
2322+
)
2323+
X = scale(X)
2324+
2325+
X_train, X_test = X[:-test_size], X[-test_size:]
2326+
y_train, y_test = y[:-test_size], y[-test_size:]
2327+
classifier.fit(X_train, y_train)
2328+
2329+
response_method_name = "predict"
2330+
predict_method = getattr(classifier, response_method_name, None)
2331+
if predict_method is None:
2332+
raise SkipTest(f"{name} does not have a {response_method_name} method.")
2333+
2334+
y_pred = predict_method(X_test)
2335+
2336+
# y_pred.shape -> y_test.shape with the same dtype
2337+
assert isinstance(y_pred, np.ndarray), (
2338+
f"{name}.predict is expected to output a NumPy array. Got "
2339+
f"{type(y_pred)} instead."
2340+
)
2341+
assert y_pred.shape == y_test.shape, (
2342+
f"{name}.predict outputs a NumPy array of shape {y_pred.shape} "
2343+
f"instead of {y_test.shape}."
2344+
)
2345+
assert y_pred.dtype == y_test.dtype, (
2346+
f"{name}.predict does not output the same dtype than the targets. "
2347+
f"Got {y_pred.dtype} instead of {y_test.dtype}."
2348+
)
2349+
2350+
2351+
@ignore_warnings(category=FutureWarning)
2352+
def check_classifiers_multilabel_output_format_predict_proba(name, classifier_orig):
2353+
"""Check the output of the `predict_proba` method for classifiers supporting
2354+
multilabel-indicator targets."""
2355+
classifier = clone(classifier_orig)
2356+
set_random_state(classifier)
2357+
2358+
n_samples, test_size, n_outputs = 100, 25, 5
2359+
X, y = make_multilabel_classification(
2360+
n_samples=n_samples,
2361+
n_features=2,
2362+
n_classes=n_outputs,
2363+
n_labels=3,
2364+
length=50,
2365+
allow_unlabeled=True,
2366+
random_state=0,
2367+
)
2368+
X = scale(X)
2369+
2370+
X_train, X_test = X[:-test_size], X[-test_size:]
2371+
y_train = y[:-test_size]
2372+
classifier.fit(X_train, y_train)
2373+
2374+
response_method_name = "predict_proba"
2375+
predict_proba_method = getattr(classifier, response_method_name, None)
2376+
if predict_proba_method is None:
2377+
raise SkipTest(f"{name} does not have a {response_method_name} method.")
2378+
2379+
y_pred = predict_proba_method(X_test)
2380+
2381+
# y_pred.shape -> 2 possibilities:
2382+
# - list of length n_outputs of shape (n_samples, 2);
2383+
# - ndarray of shape (n_samples, n_outputs).
2384+
# dtype should be floating
2385+
if isinstance(y_pred, list):
2386+
assert len(y_pred) == n_outputs, (
2387+
f"When {name}.predict_proba returns a list, the list should "
2388+
"be of length n_outputs and contain NumPy arrays. Got length "
2389+
f"of {len(y_pred)} instead of {n_outputs}."
2390+
)
2391+
for pred in y_pred:
2392+
assert pred.shape == (test_size, 2), (
2393+
f"When {name}.predict_proba returns a list, this list "
2394+
"should contain NumPy arrays of shape (n_samples, 2). Got "
2395+
f"NumPy arrays of shape {pred.shape} instead of "
2396+
f"{(test_size, 2)}."
2397+
)
2398+
assert pred.dtype.kind == "f", (
2399+
f"When {name}.predict_proba returns a list, it should "
2400+
"contain NumPy arrays with floating dtype. Got "
2401+
f"{pred.dtype} instead."
2402+
)
2403+
# check that we have the correct probabilities
2404+
err_msg = (
2405+
f"When {name}.predict_proba returns a list, each NumPy "
2406+
"array should contain probabilities for each class and "
2407+
"thus each row should sum to 1 (or close to 1 due to "
2408+
"numerical errors)."
2409+
)
2410+
assert_allclose(pred.sum(axis=1), 1, err_msg=err_msg)
2411+
elif isinstance(y_pred, np.ndarray):
2412+
assert y_pred.shape == (test_size, n_outputs), (
2413+
f"When {name}.predict_proba returns a NumPy array, the "
2414+
f"expected shape is (n_samples, n_outputs). Got {y_pred.shape}"
2415+
f" instead of {(test_size, n_outputs)}."
2416+
)
2417+
assert y_pred.dtype.kind == "f", (
2418+
f"When {name}.predict_proba returns a NumPy array, the "
2419+
f"expected data type is floating. Got {y_pred.dtype} instead."
2420+
)
2421+
err_msg = (
2422+
f"When {name}.predict_proba returns a NumPy array, this array "
2423+
"is expected to provide probabilities of the positive class "
2424+
"and should therefore contain values between 0 and 1."
2425+
)
2426+
assert_array_less(0, y_pred, err_msg=err_msg)
2427+
assert_array_less(y_pred, 1, err_msg=err_msg)
2428+
else:
2429+
raise ValueError(
2430+
f"Unknown returned type {type(y_pred)} by {name}."
2431+
"predict_proba. A list or a Numpy array is expected."
2432+
)
2433+
2434+
2435+
@ignore_warnings(category=FutureWarning)
2436+
def check_classifiers_multilabel_output_format_decision_function(name, classifier_orig):
2437+
"""Check the output of the `decision_function` method for classifiers supporting
2438+
multilabel-indicator targets."""
2439+
classifier = clone(classifier_orig)
2440+
set_random_state(classifier)
2441+
2442+
n_samples, test_size, n_outputs = 100, 25, 5
2443+
X, y = make_multilabel_classification(
2444+
n_samples=n_samples,
2445+
n_features=2,
2446+
n_classes=n_outputs,
2447+
n_labels=3,
2448+
length=50,
2449+
allow_unlabeled=True,
2450+
random_state=0,
2451+
)
2452+
X = scale(X)
2453+
2454+
X_train, X_test = X[:-test_size], X[-test_size:]
2455+
y_train = y[:-test_size]
2456+
classifier.fit(X_train, y_train)
2457+
2458+
response_method_name = "decision_function"
2459+
decision_function_method = getattr(classifier, response_method_name, None)
2460+
if decision_function_method is None:
2461+
raise SkipTest(f"< BAC9 span class=pl-s1>{name} does not have a {response_method_name} method.")
2462+
2463+
y_pred = decision_function_method(X_test)
2464+
2465+
# y_pred.shape -> y_test.shape with floating dtype
2466+
assert isinstance(y_pred, np.ndarray), (
2467+
f"{name}.decision_function is expected to output a NumPy array."
2468+
f" Got {type(y_pred)} instead."
2469+
)
2470+
assert y_pred.shape == (test_size, n_outputs), (
2471+
f"{name}.decision_function is expected to provide a NumPy array "
2472+
f"of shape (n_samples, n_outputs). Got {y_pred.shape} instead of "
2473+
f"{(test_size, n_outputs)}."
2474+
)
2475+
assert y_pred.dtype.kind == "f", (
2476+
f"{name}.decision_function is expected to output a floating dtype."
2477+
f" Got {y_pred.dtype} instead."
2478+
)
2479+
2480+
23022481
@ignore_warnings(category=FutureWarning)
23032482
def check_estimators_fit_returns_self(name, estimator_orig, readonly_memmap=False):
23042483
"""Check if self is returned when calling fit."""

0 commit comments

Comments
 (0)
0