MNT n_features_in through the multiclass module (#20193) · scikit-learn/scikit-learn@6bfaced · GitHub
[go: up one dir, main page]

Skip to content

Commit 6bfaced

Browse files
authored
MNT n_features_in through the multiclass module (#20193)
1 parent 3a23e26 commit 6bfaced

File tree

4 files changed

+56
-30
lines changed

4 files changed

+56
-30
lines changed

sklearn/multiclass.py

Lines changed: 54 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -50,13 +50,13 @@
5050
from .utils._tags import _safe_tags
5151
from .utils.validation import _num_samples
5252
from .utils.validation import check_is_fitted
53-
from .utils.validation import check_X_y, check_array
53+
from .utils.validation import column_or_1d
54+
from .utils.validation import _assert_all_finite
5455
from .utils.multiclass import (_check_partial_fit_first_call,
5556
check_classification_targets,
5657
_ovr_decision_function)
5758
from .utils.metaestimators import _safe_split, if_delegate_has_method
5859
from .utils.fixes import delayed
59-
from .exceptions import NotFittedError
6060

6161
from joblib import Parallel
6262

@@ -114,24 +114,28 @@ def _check_estimator(estimator):
114114
class _ConstantPredictor(BaseEstimator):
115115

116116
def fit(self, X, y):
117+
self._check_n_features(X, reset=True)
117118
self.y_ = y
118119
return self
119120

120121
def predict(self, X):
121122
check_is_fitted(self)
123+
self._check_n_features(X, reset=True)
122124

123-
return np.repeat(self.y_, X.shape[0])
125+
return np.repeat(self.y_, _num_samples(X))
124126

125127
def decision_function(self, X):
126128
check_is_fitted(self)
129+
self._check_n_features(X, reset=True)
127130

128-
return np.repeat(self.y_, X.shape[0])
131+
return np.repeat(self.y_, _num_samples(X))
129132

130133
def predict_proba(self, X):
131134
check_is_fitted(self)
135+
self._check_n_features(X, reset=True)
132136

133137
return np.repeat([np.hstack([1 - self.y_, self.y_])],
134-
X.shape[0], axis=0)
138+
_num_samples(X), axis=0)
135139

136140

137141
class OneVsRestClassifier(MultiOutputMixin, ClassifierMixin,
@@ -219,6 +223,12 @@ class OneVsRestClassifier(MultiOutputMixin, ClassifierMixin,
219223
multilabel_ : boolean
220224
Whether a OneVsRestClassifier is a multilabel classifier.
221225
226+
n_features_in_ : int
227+
Number of features seen during :term:`fit`. Only defined if the
228+
underlying estimator exposes such an attribute when fit.
229+
230+
.. versionadded:: 0.24
231+
222232
Examples
223233
--------
224234
>>> import numpy as np
@@ -282,6 +292,9 @@ def fit(self, X, y):
282292
self.label_binarizer_.classes_[i]])
283293
for i, column in enumerate(columns))
284294

295+
if hasattr(self.estimators_[0], "n_features_in_"):
296+
self.n_features_in_ = self.estimators_[0].n_features_in_
297+
285298
return self
286299

287300
@if_delegate_has_method('estimator')
@@ -338,6 +351,9 @@ def partial_fit(self, X, y, classes=None):
338351
delayed(_partial_fit_binary)(estimator, X, column)
339352
for estimator, column in zip(self.estimators_, columns))
340353

354+
if hasattr(self.estimators_[0], "n_features_in_"):
355+
self.n_features_in_ = self.estimators_[0].n_features_in_
356+
341357
return self
342358

343359
def predict(self, X):
@@ -504,19 +520,6 @@ def _more_tags(self):
504520
def _first_estimator(self):
505521
return self.estimators_[0]
506522

507-
@property
508-
def n_features_in_(self):
509-
# For consistency with other estimators we raise an AttributeError so
510-
# that hasattr() fails if the OVR estimator isn't fitted.
511-
try:
512-
check_is_fitted(self)
513-
except NotFittedError as nfe:
514-
raise AttributeError(
515-
"{} object has no n_features_in_ attribute."
516-
.format(self.__class__.__name__)
517-
) from nfe
518-
return self.estimators_[0].n_features_in_
519-
520523

521524
def _fit_ovo_binary(estimator, X, y, i, j):
522525
"""Fit a single binary estimator (one-vs-one)."""
@@ -525,7 +528,7 @@ def _fit_ovo_binary(estimator, X, y, i, j):
525528
y_binary = np.empty(y.shape, int)
526529
y_binary[y == i] = 0
527530
y_binary[y == j] = 1
528-
indcond = np.arange(X.shape[0])[cond]
531+
indcond = np.arange(_num_samples(X))[cond]
529532
return _fit_binary(estimator,
530533
_safe_split(estimator, X, None, indices=indcond)[0],
531534
y_binary, classes=[i, j]), indcond
@@ -593,6 +596,12 @@ class OneVsOneClassifier(MetaEstimatorMixin, ClassifierMixin, BaseEstimator):
593596
(renaming of 0.25) and onward, `pairwise_indices_` will use the
594597
pairwise estimator tag instead.
595598
599+
n_features_in_ : int
600+
Number of features seen during :term:`fit`. Only defined if the
601+
underlying estimator exposes such an attribute when fit.
602+
603+
.. versionadded:: 0.24
604+
596605
Examples
597606
--------
598607
>>> from sklearn.datasets import load_iris
@@ -626,6 +635,7 @@ def fit(self, X, y):
626635
-------
627636
self
628637
"""
638+
# We need to validate the data because we do a safe_indexing later.
629639
X, y = self._validate_data(X, y, accept_sparse=['csr', 'csc'],
630640
force_all_finite=False)
631641
check_classification_targets(y)
@@ -642,6 +652,9 @@ def fit(self, X, y):
642652

643653
self.estimators_ = estimators_indices[0]
644654

655+
if hasattr(self.estimators_[0], "n_features_in_"):
656+
self.n_features_in_ = self.estimators_[0].n_features_in_
657+
645658
pairwise = _is_pairwise(self)
646659
self.pairwise_indices_ = (
647660
estimators_indices[1] if pairwise else None)
@@ -686,8 +699,9 @@ def partial_fit(self, X, y, classes=None):
686699
"must be subset of {1}".format(np.unique(y),
687700
self.classes_))
688701

689-
X, y = check_X_y(X, y, accept_sparse=['csr', 'csc'],
690-
force_all_finite=False)
702+
X, y = self._validate_data(
703+
X, y, accept_sparse=['csr', 'csc'], force_all_finite=False,
704+
reset=_check_partial_fit_first_call(self, classes))
691705
check_classification_targets(y)
692706
combinations = itertools.combinations(range(self.n_classes_), 2)
693707
self.estimators_ = Parallel(
@@ -699,6 +713,9 @@ def partial_fit(self, X, y, classes=None):
699713

700714
self.pairwise_indices_ = None
701715

716+
if hasattr(self.estimators_[0], "n_features_in_"):
717+
self.n_features_in_ = self.estimators_[0].n_features_in_
718+
702719
return self
703720

704721
def predict(self, X):
@@ -832,6 +849,12 @@ class OutputCodeClassifier(MetaEstimatorMixin, ClassifierMixin, BaseEstimator):
832849
code_book_ : numpy array of shape [n_classes, code_size]
833850
Binary array containing the code of each class.
834851
852+
n_features_in_ : int
853+
Number of features seen during :term:`fit`. Only defined if the
854+
underlying estimator exposes such an attribute when fit.
855+
856+
.. versionadded:: 0.24
857+
835858
Examples
836859
--------
837860
>>> from sklearn.multiclass import OutputCodeClassifier
@@ -886,7 +909,9 @@ def fit(self, X, y):
886909
-------
887910
self
888911
"""
889-
X, y = self._validate_data(X, y, accept_sparse=True)
912+
y = column_or_1d(y, warn=True)
913+
_assert_all_finite(y)
914+
890915
if self.code_size <= 0:
891916
raise ValueError("code_size should be greater than 0, got {0}"
892917
"".format(self.code_size))
@@ -897,6 +922,9 @@ def fit(self, X, y):
897922

898923
self.classes_ = np.unique(y)
899924
n_classes = self.classes_.shape[0]
925+
if n_classes == 0:
926+
raise ValueError("OutputCodeClassifier can not be fit when no "
927+
"class is present.")
900928
code_size_ = int(n_classes * self.code_size)
901929

902930
# FIXME: there are more elaborate methods than generating the codebook
@@ -912,12 +940,15 @@ def fit(self, X, y):
912940
classes_index = {c: i for i, c in enumerate(self.classes_)}
913941

914942
Y = np.array([self.code_book_[classes_index[y[i]]]
915-
for i in range(X.shape[0])], dtype=int)
943+
for i in range(_num_samples(y))], dtype=int)
916944

917945
self.estimators_ = Parallel(n_jobs=self.n_jobs)(
918946
delayed(_fit_binary)(self.estimator, X, Y[:, i])
919947
for i in range(Y.shape[1]))
920948

949+
if hasattr(self.estimators_[0], "n_features_in_"):
950+
self.n_features_in_ = self.estimators_[0].n_features_in_
951+
921952
return self
922953

923954
def predict(self, X):
@@ -934,7 +965,6 @@ def predict(self, X):
934965
Predicted multi-class targets.
935966
"""
936967
check_is_fitted(self)
937-
X = check_array(X, accept_sparse=True)
938968
Y = np.array([_predict_binary(e, X) for e in self.estimators_]).T
939969
pred = euclidean_distances(Y, self.code_book_).argmin(axis=1)
940970
return self.classes_[pred]

sklearn/tests/test_common.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,6 @@ def test_search_cv(estimator, check, request):
262262
# check_classifiers_train would need to be updated with the error message
263263
N_FEATURES_IN_AFTER_FIT_MODULES_TO_IGNORE = {
264264
'model_selection',
265-
'multiclass',
266265
'multioutput',
267266
'pipeline',
268267
}

sklearn/tests/test_docstring_parameters.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,6 @@ def _construct_searchcv_instance(SearchCV):
192192
'linear_model',
193193
'manifold',
194194
'model_selection',
195-
'multiclass',
196195
'multioutput',
197196
'naive_bayes',
198197
'neighbors',
@@ -219,8 +218,7 @@ def test_fit_docstring_attributes(name, Estimator):
219218
'CountVectorizer', 'DictVectorizer', 'FeatureUnion',
220219
'GaussianRandomProjection',
221220
'MultiOutputClassifier', 'MultiOutputRegressor',
222-
'NoSampleWeightWrapper', 'OneVsOneClassifier',
223-
'OutputCodeClassifier', 'Pipeline', 'RFE', 'RFECV',
221+
'NoSampleWeightWrapper', 'Pipeline', 'RFE', 'RFECV',
224222
'RegressorChain', 'SelectFromModel',
225223
'SparseCoder', 'SparseRandomProjection',
226224
'SpectralBiclustering', 'StackingClassifier',

sklearn/tests/test_metaestimators.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -219,8 +219,7 @@ def _generate_meta_estimator_instances_with_pipeline():
219219
"IterativeImputer",
220220
"MultiOutputClassifier",
221221
"MultiOutputRegressor",
222-
"OneVsOneClassifier",
223-
"OutputCodeClassifier",
222+
"OneVsOneClassifier", # input validation can't be avoided
224223
"RANSACRegressor",
225224
"RFE",
226225
"RFECV",

0 commit comments

Comments
 (0)
0