Removed non-applicable checks on X when using user-defined kernels · scikit-learn/scikit-learn@4b831be · GitHub
[go: up one dir, main page]

Skip to content

Commit 4b831be

Browse files
Georgi Peev authored and Georgi Peev committed
Removed non-applicable checks on X when using user-defined kernels
1 parent 4143356 commit 4b831be

File tree

1 file changed

+19
-11
lines changed

1 file changed

+19
-11
lines changed

sklearn/svm/base.py

Lines changed: 19 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,7 @@
1414
from ..utils import column_or_1d, check_X_y
1515
from ..utils import compute_class_weight
1616
from ..utils.extmath import safe_sparse_dot
17-
from ..utils.validation import check_is_fitted
17+
from ..utils.validation import check_is_fitted, _num_samples
1818
from ..utils.multiclass import check_classification_targets
1919
from ..externals import six
2020
from ..exceptions import ConvergenceWarning
@@ -144,7 +144,10 @@ def fit(self, X, y, sample_weight=None):
144144
raise TypeError("Sparse precomputed kernels are not supported.")
145145
self._sparse = sparse and not callable(self.kernel)
146146

147-
X, y = check_X_y(X, y, dtype=np.float64, order='C', accept_sparse='csr')
147+
if callable(self.kernel):
148+
check_consistent_length(X, y)
149+
else:
150+
X, y = check_X_y(X, y, dtype=np.float64, order='C', accept_sparse='csr')
148151
y = self._validate_targets(y)
149152

150153
sample_weight = np.asarray([]
@@ -153,15 +156,16 @@ def fit(self, X, y, sample_weight=None):
153156
solver_type = LIBSVM_IMPL.index(self._impl)
154157

155158
# input validation
156-
if solver_type != 2 and X.shape[0] != y.shape[0]:
159+
n_samples = _num_samples(X)
160+
if solver_type != 2 and n_samples != y.shape[0]:
157161
raise ValueError("X and y have incompatible shapes.\n" +
158162
"X has %s samples, but y has %s." %
159-
(X.shape[0], y.shape[0]))
163+
(n_samples, y.shape[0]))
160164

161-
if self.kernel == "precomputed" and X.shape[0] != X.shape[1]:
165+
if self.kernel == "precomputed" and n_samples != X.shape[1]:
162166
raise ValueError("X.shape[0] should be equal to X.shape[1]")
163167

164-
if sample_weight.shape[0] > 0 and sample_weight.shape[0] != X.shape[0]:
168+
if sample_weight.shape[0] > 0 and sample_weight.shape[0] != n_samples:
165169
raise ValueError("sample_weight and X have incompatible shapes: "
166170
"%r vs %r\n"
167171
"Note: Sparse matrices cannot be indexed w/"
@@ -210,7 +214,10 @@ def fit(self, X, y, sample_weight=None):
210214
fit(X, y, sample_weight, solver_type, kernel, random_seed=seed)
211215
# see comment on the other call to np.iinfo in this file
212216

213-
self.shape_fit_ = X.shape
217+
if hasattr(X, 'shape'):
218+
self.shape_fit_ = X.shape
219+
else:
220+
self.shape_fit_ = (_num_samples(X), )
214221

215222
# In binary case, we need to flip the sign of coef, intercept and
216223
# decision function. Use self._intercept_ and self._dual_coef_ internally.
@@ -324,7 +331,6 @@ def predict(self, X):
324331
return predict(X)
325332

326333
def _dense_predict(self, X):
327-
n_samples, n_features = X.shape
328334
X = self._compute_kernel(X)
329335
if X.ndim == 1:
330336
X = check_array(X, order='C')
@@ -450,7 +456,8 @@ def _sparse_decision_function(self, X):
450456
def _validate_for_predict(self, X):
451457
check_is_fitted(self, 'support_')
452458

453-
X = check_array(X, accept_sparse='csr', dtype=np.float64, order="C")
459+
if not callable(self.kernel):
460+
X = check_array(X, accept_sparse='csr', dtype=np.float64, order="C")
454461
if self._sparse and not sp.isspmatrix(X):
455462
X = sp.csr_matrix(X)
456463
if self._sparse:
@@ -460,14 +467,15 @@ def _validate_for_predict(self, X):
460467
raise ValueError(
461468
"cannot use sparse input in %r trained on dense data"
462469
% type(self).__name__)
463-
n_samples, n_features = X.shape
470+
if not callable(self.kernel):
471+
n_features = X.shape[1]
464472

465473
if self.kernel == "precomputed":
466474
if X.shape[1] != self.shape_fit_[0]:
467475
raise ValueError("X.shape[1] = %d should be equal to %d, "
468476
"the number of samples at training time" %
469477
(X.shape[1], self.shape_fit_[0]))
470-
elif n_features != self.shape_fit_[1]:
478+
elif not callable(self.kernel) and n_features != self.shape_fit_[1]:
471479
raise ValueError("X.shape[1] = %d should be equal to %d, "
472480
"the number of features at training time" %
473481
(n_features, self.shape_fit_[1]))

0 commit comments

Comments
 (0)
0