8000 FIX Do not reset for non-fit in multiclass (#20205) · scikit-learn/scikit-learn@007da8d · GitHub
[go: up one dir, main page]

Skip to content

Commit 007da8d

Browse files
authored
FIX Do not reset for non-fit in multiclass (#20205)
1 parent a253826 commit 007da8d

File tree

1 file changed

+13
-4
lines changed

1 file changed

+13
-4
lines changed

sklearn/multiclass.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -114,25 +114,34 @@ def _check_estimator(estimator):
114114
class _ConstantPredictor(BaseEstimator):
115115

116116
def fit(self, X, y):
117-
self._check_n_features(X, reset=True)
117+
check_params = dict(force_all_finite=False, dtype=None,
118+
ensure_2d=False, accept_sparse=True)
119+
self._validate_data(X, y, reset=True,
120+
validate_separately=(check_params, check_params))
118121
self.y_ = y
119122
return self
120123

121124
def predict(self, X):
122125
check_is_fitted(self)
123-
self._check_n_features(X, reset=True)
126+
self._validate_data(X, force_all_finite=False, dtype=None,
127+
accept_sparse=True,
128+
ensure_2d=False, reset=False)
124129

125130
return np.repeat(self.y_, _num_samples(X))
126131

127132
def decision_function(self, X):
128133
check_is_fitted(self)
129-
self._check_n_features(X, reset=True)
134+
self._validate_data(X, force_all_finite=False, dtype=None,
135+
accept_sparse=True,
136+
ensure_2d=False, reset=False)
130137

131138
return np.repeat(self.y_, _num_samples(X))
132139

133140
def predict_proba(self, X):
134141
check_is_fitted(self)
135-
self._check_n_features(X, reset=True)
142+
self._validate_data(X, force_all_finite=False, dtype=None,
143+
accept_sparse=True,
144+
ensure_2d=False, reset=False)
136145

137146
return np.repeat([np.hstack([1 - self.y_, self.y_])],
138147
_num_samples(X), axis=0)

0 commit comments

Comments
 (0)
0