ENH better input validation for prediction in SVC, LinearSVC. · pfdevilliers/scikit-learn@163d777 · GitHub
[go: up one dir, main page]

Skip to content

Commit 163d777

Browse files
committed
ENH better input validation for prediction in SVC, LinearSVC.
1 parent c210c0b commit 163d777

File tree

2 files changed

+23
-25
lines changed

2 files changed

+23
-25
lines changed

sklearn/svm/base.py

Lines changed: 15 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,7 @@
66
from . import libsvm, liblinear
77
from . import libsvm_sparse
88
from ..base import BaseEstimator
9-
from ..utils import array2d, atleast2d_or_csr
9+
from ..utils import atleast2d_or_csr
1010
from ..utils.extmath import safe_sparse_dot
1111

1212

@@ -259,17 +259,6 @@ def predict(self, X):
259259
C : array, shape = [n_samples]
260260
"""
261261
X = self._validate_for_predict(X)
262-
n_samples, n_features = X.shape
263-
264-
if self.kernel == "precomputed":
265-
if X.shape[1] != self.shape_fit_[0]:
266-
raise ValueError("X.shape[1] = %d should be equal to %d, "
267-
"the number of samples at training time" %
268-
(X.shape[1], self.shape_fit_[0]))
269-
elif n_features != self.shape_fit_[1]:
270-
raise ValueError("X.shape[1] = %d should be equal to %d, "
271-
"the number of features at training time" %
272-
(n_features, self.shape_fit_[1]))
273262
predict = self._sparse_predict if self._sparse else self._dense_predict
274263
return predict(X)
275264

@@ -354,7 +343,7 @@ def predict_proba(self, X):
354343
datasets.
355344
"""
356345
if not self.probability:
357-
raise ValueError(
346+
raise NotImplementedError(
358347
"probability estimates must be enabled to use this method")
359348

360349
if self.impl not in ('c_svc', 'nu_svc'):
@@ -461,7 +450,7 @@ def decision_function(self, X):
461450
raise NotImplementedError("decision_function not supported for"
462451
" sparse SVM")
463452

464-
X = array2d(X, dtype=np.float64, order="C")
453+
X = self._validate_for_predict(X)
465454

466455
C = 0.0 # C is not useful here
467456

@@ -494,6 +483,17 @@ def _validate_for_predict(self, X):
494483
raise ValueError(
495484
"cannot use sparse input in %r trained on dense data"
496485
% type(self).__name__)
486+
n_samples, n_features = X.shape
487+
488+
if self.kernel == "precomputed":
489+
if X.shape[1] != self.shape_fit_[0]:
490+
raise ValueError("X.shape[1] = %d should be equal to %d, "
491+
"the number of samples at training time" %
492+
(X.shape[1], self.shape_fit_[0]))
493+
elif n_features != self.shape_fit_[1]:
494+
raise ValueError("X.shape[1] = %d should be equal to %d, "
495+
"the number of features at training time" %
496+
(n_features, self.shape_fit_[1]))
497497
return X
498498

499499
@property
@@ -657,7 +657,6 @@ def predict(self, X):
657657
C : array, shape = [n_samples]
658658
"""
659659
X = self._validate_for_predict(X)
660-
self._check_n_features(X)
661660

662661
C = 0.0 # C is not useful here
663662

@@ -681,7 +680,6 @@ def decision_function(self, X):
681680
in the model.
682681
"""
683682
X = self._validate_for_predict(X)
684-
self._check_n_features(X)
685683

686684
C = 0.0 # C is not useful here
687685

@@ -716,6 +714,7 @@ def _validate_for_predict(self, X):
716714
'Copying them.', RuntimeWarning,
717715
stacklevel=3)
718716
self.raw_coef_ = np.asfortranarray(self.raw_coef_)
717+
self._check_n_features(X)
719718
return X
720719

721720
def _get_intercept_(self):

sklearn/tests/test_common.py

Lines changed: 8 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,7 @@
1515
from sklearn.datasets import load_iris, load_boston
1616
from sklearn.metrics import zero_one_score
1717
from sklearn.lda import LDA
18-
from sklearn.svm.base import BaseLibSVM, BaseLibLinear
18+
from sklearn.svm.base import BaseLibSVM
1919

2020
# import "special" estimators
2121
from sklearn.grid_search import GridSearchCV
@@ -84,28 +84,27 @@ def test_classifiers_train():
8484
assert_greater(zero_one_score(y, y_pred), 0.78)
8585

8686
# raises error on malformed input for predict
87-
if isinstance(clf, BaseLibSVM) or isinstance(clf, BaseLibLinear):
88-
# TODO: libsvm decision functions, input validation
89-
continue
9087
assert_raises(ValueError, clf.predict, X.T)
9188
if hasattr(clf, "decision_function"):
9289
try:
93-
# raises error on malformed input for decision_function
94-
assert_raises(ValueError, clf.decision_function, X.T)
9590
# decision_function agrees with predict:
9691
decision = clf.decision_function(X)
9792
assert_equal(decision.shape, (n_samples, n_labels))
98-
assert_array_equal(np.argmax(decision, axis=1), y_pred)
93+
if not isinstance(clf, BaseLibSVM):
94+
# 1on1 of LibSVM works differently
95+
assert_array_equal(np.argmax(decision, axis=1), y_pred)
96+
# raises error on malformed input for decision_function
97+
assert_raises(ValueError, clf.decision_function, X.T)
9998
except NotImplementedError:
10099
pass
101100
if hasattr(clf, "predict_proba"):
102101
try:
103-
# raises error on malformed input for predict_proba
104-
assert_raises(ValueError, clf.predict_proba, X.T)
105102
# predict_proba agrees with predict:
106103
y_prob = clf.predict_proba(X)
107104
assert_equal(y_prob.shape, (n_samples, n_labels))
108105
assert_array_equal(np.argmax(y_prob, axis=1), y_pred)
106+
# raises error on malformed input for predict_proba
107+
assert_raises(ValueError, clf.predict_proba, X.T)
109108
except NotImplementedError:
110109
pass
111110

0 commit comments

Comments
 (0)
0