FIX add support for multilabel classification in RidgeClassifier* (#19869) · scikit-learn/scikit-learn@90e04a5 · GitHub

Commit 90e04a5

glemaitre, agramfort, and jeremiedbb authored

FIX add support for multilabel classification in RidgeClassifier* (#19869)

Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
Co-authored-by: Jérémie du Boisberranger <34657725+jeremiedbb@users.noreply.github.com>
1 parent 90bef46 commit 90e04a5

File tree: 5 files changed, +145 -76 lines changed


doc/modules/multiclass.rst

Lines changed: 1 addition & 0 deletions
@@ -102,6 +102,7 @@ can provide additional strategies beyond what is built-in:
 - :class:`neural_network.MLPClassifier`
 - :class:`neighbors.RadiusNeighborsClassifier`
 - :class:`ensemble.RandomForestClassifier`
+- :class:`linear_model.RidgeClassifier`
 - :class:`linear_model.RidgeClassifierCV`

doc/whats_new/v1.1.rst

Lines changed: 9 additions & 0 deletions
@@ -85,6 +85,15 @@ Changelog
   :class:`impute.KNNImputer`, :class:`impute.IterativeImputer`, and
   :class:`impute.MissingIndicator`. :pr:`21078` by `Thomas Fan`_.
 
+- |Fix| Fix a bug in :class:`linear_model.RidgeClassifierCV` where the method
+  `predict` was performing an `argmax` on the scores obtained from
+  `decision_function` instead of returning the multilabel indicator matrix.
+  :pr:`19869` by :user:`Guillaume Lemaitre <glemaitre>`.
+
+- |Enhancement| :class:`linear_model.RidgeClassifier` is now supporting
+  multilabel classification.
+  :pr:`19869` by :user:`Guillaume Lemaitre <glemaitre>`.
+
 :mod:`sklearn.metrics`
 ......................
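In practice, these two entries mean that a multilabel indicator matrix can be passed directly to `RidgeClassifier` / `RidgeClassifierCV`, and that `predict` returns an indicator matrix of the same shape instead of an argmax over the decision scores. A minimal sketch, assuming a scikit-learn build that includes this change (the dataset parameters are illustrative):

import numpy as np
from sklearn.datasets import make_multilabel_classification
from sklearn.linear_model import RidgeClassifier

# Y is a 0/1 indicator matrix of shape (n_samples, n_classes).
X, Y = make_multilabel_classification(n_samples=100, n_classes=3, random_state=0)

clf = RidgeClassifier().fit(X, Y)
Y_pred = clf.predict(X)

# predict now returns the multilabel indicator matrix, not a class vector.
print(Y_pred.shape)  # (100, 3)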

sklearn/linear_model/_base.py

Lines changed: 10 additions & 10 deletions
@@ -392,15 +392,15 @@ def decision_function(self, X):
 
         Parameters
         ----------
-        X : array-like or sparse matrix, shape (n_samples, n_features)
-            Samples.
+        X : {array-like, sparse matrix} of shape (n_samples, n_features)
+            The data matrix for which we want to get the confidence scores.
 
         Returns
         -------
-        array, shape=(n_samples,) if n_classes == 2 else (n_samples, n_classes)
-            Confidence scores per (sample, class) combination. In the binary
-            case, confidence score for self.classes_[1] where >0 means this
-            class would be predicted.
+        scores : ndarray of shape (n_samples,) or (n_samples, n_classes)
+            Confidence scores per `(n_samples, n_classes)` combination. In the
+            binary case, confidence score for `self.classes_[1]` where >0 means
+            this class would be predicted.
         """
         check_is_fitted(self)
 
@@ -414,13 +414,13 @@ def predict(self, X):
 
         Parameters
         ----------
-        X : array-like or sparse matrix, shape (n_samples, n_features)
-            Samples.
+        X : {array-like, sparse matrix} of shape (n_samples, n_features)
+            The data matrix for which we want to get the predictions.
 
         Returns
         -------
-        C : array, shape [n_samples]
-            Predicted class label per sample.
+        y_pred : ndarray of shape (n_samples,)
+            Vector containing the class labels for each sample.
         """
         scores = self.decision_function(X)
         if len(scores.shape) == 1:
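The shapes described in these docstrings can be checked directly; a small sketch with illustrative data (the dataset parameters are not from this commit):

import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import RidgeClassifier

# Binary case: decision_function returns shape (n_samples,) and a score > 0
# means classes_[1] is predicted.
X, y = make_classification(n_samples=50, n_classes=2, random_state=0)
clf = RidgeClassifier().fit(X, y)
scores = clf.decision_function(X)
assert scores.shape == (50,)
assert np.array_equal(clf.predict(X), clf.classes_[(scores > 0).astype(int)])

# Multiclass case: shape (n_samples, n_classes); predict takes the argmax.
Xm, ym = make_classification(n_samples=50, n_classes=3, n_informative=6, random_state=0)
clf_m = RidgeClassifier().fit(Xm, ym)
assert clf_m.decision_function(Xm).shape == (50, 3)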

sklearn/linear_model/_ridge.py

Lines changed: 103 additions & 60 deletions
@@ -28,6 +28,7 @@
 from ..utils import check_consistent_length
 from ..utils import compute_sample_weight
 from ..utils import column_or_1d
+from ..utils.validation import check_is_fitted
 from ..utils.validation import _check_sample_weight
 from ..preprocessing import LabelBinarizer
 from ..model_selection import GridSearchCV
@@ -1010,7 +1011,93 @@ def fit(self, X, y, sample_weight=None):
         return super().fit(X, y, sample_weight=sample_weight)
 
 
-class RidgeClassifier(LinearClassifierMixin, _BaseRidge):
+class _RidgeClassifierMixin(LinearClassifierMixin):
+    def _prepare_data(self, X, y, sample_weight, solver):
+        """Validate `X` and `y` and binarize `y`.
+
+        Parameters
+        ----------
+        X : {ndarray, sparse matrix} of shape (n_samples, n_features)
+            Training data.
+
+        y : ndarray of shape (n_samples,)
+            Target values.
+
+        sample_weight : float or ndarray of shape (n_samples,), default=None
+            Individual weights for each sample. If given a float, every sample
+            will have the same weight.
+
+        solver : str
+            The solver used in `Ridge` to know which sparse format to support.
+
+        Returns
+        -------
+        X : {ndarray, sparse matrix} of shape (n_samples, n_features)
+            Validated training data.
+
+        y : ndarray of shape (n_samples,)
+            Validated target values.
+
+        sample_weight : ndarray of shape (n_samples,)
+            Validated sample weights.
+
+        Y : ndarray of shape (n_samples, n_classes)
+            The binarized version of `y`.
+        """
+        accept_sparse = _get_valid_accept_sparse(sparse.issparse(X), solver)
+        X, y = self._validate_data(
+            X,
+            y,
+            accept_sparse=accept_sparse,
+            multi_output=True,
+            y_numeric=False,
+        )
+
+        self._label_binarizer = LabelBinarizer(pos_label=1, neg_label=-1)
+        Y = self._label_binarizer.fit_transform(y)
+        if not self._label_binarizer.y_type_.startswith("multilabel"):
+            y = column_or_1d(y, warn=True)
+
+        sample_weight = _check_sample_weight(sample_weight, X, dtype=X.dtype)
+        if self.class_weight:
+            sample_weight = sample_weight * compute_sample_weight(self.class_weight, y)
+        return X, y, sample_weight, Y
+
+    def predict(self, X):
+        """Predict class labels for samples in `X`.
+
+        Parameters
+        ----------
+        X : {array-like, sparse matrix} of shape (n_samples, n_features)
+            The data matrix for which we want to predict the targets.
+
+        Returns
+        -------
+        y_pred : ndarray of shape (n_samples,) or (n_samples, n_outputs)
+            Vector or matrix containing the predictions. In binary and
+            multiclass problems, this is a vector containing `n_samples`. In
+            a multilabel problem, it returns a matrix of shape
+            `(n_samples, n_outputs)`.
+        """
+        check_is_fitted(self, attributes=["_label_binarizer"])
+        if self._label_binarizer.y_type_.startswith("multilabel"):
+            # Threshold such that the negative label is -1 and positive label
+            # is 1 to use the inverse transform of the label binarizer fitted
+            # during fit.
+            scores = 2 * (self.decision_function(X) > 0) - 1
+            return self._label_binarizer.inverse_transform(scores)
+        return super().predict(X)
+
+    @property
+    def classes_(self):
+        """Classes labels."""
+        return self._label_binarizer.classes_
+
+    def _more_tags(self):
+        return {"multilabel": True}
+
+
+class RidgeClassifier(_RidgeClassifierMixin, _BaseRidge):
     """Classifier using Ridge regression.
 
     This classifier first converts the target values into ``{-1, 1}`` and
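The thresholding in the new `predict` works because the label binarizer is fitted with `neg_label=-1, pos_label=1`, so `2 * (scores > 0) - 1` lands exactly on the labels it can invert. A standalone sketch of that round trip, using made-up scores:

import numpy as np
from sklearn.preprocessing import LabelBinarizer

Y = np.array([[0, 1, 1], [1, 0, 0]])  # multilabel indicator matrix
lb = LabelBinarizer(pos_label=1, neg_label=-1).fit(Y)

# Pretend these are decision_function scores for the two samples.
scores = np.array([[-0.3, 0.8, 0.1], [2.0, -1.5, -0.2]])

# Positive scores become +1, non-positive scores become -1, matching the
# labels the binarizer was fitted with; inverse_transform maps back to 0/1.
thresholded = 2 * (scores > 0) - 1
print(lb.inverse_transform(thresholded))
# [[0 1 1]
#  [1 0 0]]

This is why `predict` can reuse the binarizer fitted by `_prepare_data` without storing any separate threshold.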
@@ -1096,7 +1183,7 @@ class RidgeClassifier(LinearClassifierMixin, _BaseRidge):
           .. versionadded:: 0.17
              Stochastic Average Gradient descent solver.
           .. versionadded:: 0.19
-            SAGA solver.
+             SAGA solver.
 
         - 'lbfgs' uses L-BFGS-B algorithm implemented in
           `scipy.optimize.minimize`. It can be used only when `positive`
@@ -1203,42 +1290,18 @@ def fit(self, X, y, sample_weight=None):
             will have the same weight.
 
             .. versionadded:: 0.17
-               *sample_weight* support to Classifier.
+               *sample_weight* support to RidgeClassifier.
 
         Returns
         -------
         self : object
             Instance of the estimator.
         """
-        _accept_sparse = _get_valid_accept_sparse(sparse.issparse(X), self.solver)
-        X, y = self._validate_data(
-            X, y, accept_sparse=_accept_sparse, multi_output=True, y_numeric=False
-        )
-        sample_weight = _check_sample_weight(sample_weight, X, dtype=X.dtype)
-
-        self._label_binarizer = LabelBinarizer(pos_label=1, neg_label=-1)
-        Y = self._label_binarizer.fit_transform(y)
-        if not self._label_binarizer.y_type_.startswith("multilabel"):
-            y = column_or_1d(y, warn=True)
-        else:
-            # we don't (yet) support multi-label classification in Ridge
-            raise ValueError(
-                "%s doesn't support multi-label classification"
-                % (self.__class__.__name__)
-            )
-
-        if self.class_weight:
-            # modify the sample weights with the corresponding class weight
-            sample_weight = sample_weight * compute_sample_weight(self.class_weight, y)
+        X, y, sample_weight, Y = self._prepare_data(X, y, sample_weight, self.solver)
 
         super().fit(X, Y, sample_weight=sample_weight)
         return self
 
-    @property
-    def classes_(self):
-        """Classes labels."""
-        return self._label_binarizer.classes_
-
 
 def _check_gcv_mode(X, gcv_mode):
     possible_gcv_modes = [None, "auto", "svd", "eigen"]
@@ -2145,7 +2208,7 @@ class RidgeCV(MultiOutputMixin, RegressorMixin, _BaseRidgeCV):
     """
 
 
-class RidgeClassifierCV(LinearClassifierMixin, _BaseRidgeCV):
+class RidgeClassifierCV(_RidgeClassifierMixin, _BaseRidgeCV):
     """Ridge classifier with built-in cross-validation.
 
     See glossary entry for :term:`cross-validation estimator`.
@@ -2318,46 +2381,26 @@ def fit(self, X, y, sample_weight=None):
         self : object
             Fitted estimator.
         """
-        X, y = self._validate_data(
-            X,
-            y,
-            accept_sparse=["csr", "csc", "coo"],
-            multi_output=True,
-            y_numeric=False,
-        )
-        sample_weight = _check_sample_weight(sample_weight, X, dtype=X.dtype)
-
-        self._label_binarizer = LabelBinarizer(pos_label=1, neg_label=-1)
-        Y = self._label_binarizer.fit_transform(y)
-        if not self._label_binarizer.y_type_.startswith("multilabel"):
-            y = column_or_1d(y, warn=True)
-
-        if self.class_weight:
-            # modify the sample weights with the corresponding class weight
-            sample_weight = sample_weight * compute_sample_weight(self.class_weight, y)
-
+        # `RidgeClassifierCV` does not use the "sag" or "saga" solvers and thus
+        # supports csr, csc, and coo sparse matrices. By using solver="eigen" we
+        # force the acceptance of all sparse formats.
+        X, y, sample_weight, Y = self._prepare_data(X, y, sample_weight, solver="eigen")
+
+        # If cv is None, gcv mode will be used and we use the binarized Y
+        # since y will not be binarized in the _RidgeGCV estimator.
+        # If cv is not None, a GridSearchCV with RidgeClassifier
+        # estimators is used, where y will be binarized. Thus, we pass y
+        # instead of the binarized Y.
         target = Y if self.cv is None else y
-        _BaseRidgeCV.fit(self, X, target, sample_weight=sample_weight)
+        super().fit(X, target, sample_weight=sample_weight)
         return self
 
-    @property
-    def classes_(self):
-        """Classes labels."""
-        return self._label_binarizer.classes_
-
     def _more_tags(self):
         return {
             "multilabel": True,
             "_xfail_checks": {
                 "check_sample_weights_invariance": (
                     "zero sample_weight is not equivalent to removing samples"
                 ),
-                # FIXME: see
-                # https://github.com/scikit-learn/scikit-learn/issues/19858
-                # to track progress to resolve this issue
-                "check_classifiers_multilabel_output_format_predict": (
-                    "RidgeClassifierCV.predict outputs an array of shape (25,) "
-                    "instead of (25, 5)"
-                ),
             },
         }
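Seen from the user side, both branches of `target = Y if self.cv is None else y` accept a multilabel indicator matrix and yield predictions of the same shape. A minimal sketch, assuming a scikit-learn build that includes this change (the alphas and dataset parameters are illustrative):

from sklearn.datasets import make_multilabel_classification
from sklearn.linear_model import RidgeClassifierCV

X, Y = make_multilabel_classification(n_samples=100, n_classes=4, random_state=0)

# cv=None takes the efficient leave-one-out (gcv) path on the binarized targets;
# an integer cv wraps RidgeClassifier estimators in a grid search instead.
for cv in (None, 3):
    clf = RidgeClassifierCV(alphas=[0.1, 1.0, 10.0], cv=cv).fit(X, Y)
    assert clf.predict(X).shape == Y.shape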

sklearn/linear_model/tests/test_ridge.py

Lines changed: 22 additions & 6 deletions
@@ -1396,12 +1396,6 @@ def test_ridge_regression_check_arguments_validity(
     assert_allclose(out, true_coefs, rtol=0, atol=atol)
 
 
-def test_ridge_classifier_no_support_multilabel():
-    X, y = make_multilabel_classification(n_samples=10, random_state=0)
-    with pytest.raises(ValueError):
-        RidgeClassifier().fit(X, y)
-
-
 @pytest.mark.parametrize(
     "solver", ["svd", "sparse_cg", "cholesky", "lsqr", "sag", "saga", "lbfgs"]
 )

@@ -1515,6 +1509,28 @@ def test_ridge_sag_with_X_fortran():
     Ridge(solver="sag").fit(X, y)
 
 
+@pytest.mark.parametrize(
+    "Classifier, params",
+    [
+        (RidgeClassifier, {}),
+        (RidgeClassifierCV, {"cv": None}),
+        (RidgeClassifierCV, {"cv": 3}),
+    ],
+)
+def test_ridgeclassifier_multilabel(Classifier, params):
+    """Check that multilabel classification is supported and gives meaningful
+    results."""
+    X, y = make_multilabel_classification(n_classes=1, random_state=0)
+    y = y.reshape(-1, 1)
+    Y = np.concatenate([y, y], axis=1)
+    clf = Classifier(**params).fit(X, Y)
+    Y_pred = clf.predict(X)
+
+    assert Y_pred.shape == Y.shape
+    assert_array_equal(Y_pred[:, 0], Y_pred[:, 1])
+    Ridge(solver="sag").fit(X, y)
+
+
 @pytest.mark.parametrize("solver", ["auto", "lbfgs"])
 @pytest.mark.parametrize("fit_intercept", [True, False])
 @pytest.mark.parametrize("alpha", [1e-3, 1e-2, 0.1, 1.0])
