14
14
from ..utils import column_or_1d , check_X_y
15
15
from ..utils import compute_class_weight
16
16
from ..utils .extmath import safe_sparse_dot
17
- from ..utils .validation import check_is_fitted
17
+ from ..utils .validation import check_is_fitted , _num_samples
18
18
from ..utils .multiclass import check_classification_targets
19
19
from ..externals import six
20
20
from ..exceptions import ConvergenceWarning
@@ -144,7 +144,10 @@ def fit(self, X, y, sample_weight=None):
144
144
raise TypeError ("Sparse precomputed kernels are not supported." )
145
145
self ._sparse = sparse and not callable (self .kernel )
146
146
147
- X , y = check_X_y (X , y , dtype = np .float64 , order = 'C' , accept_sparse = 'csr' )
147
+ if callable (self .kernel ):
148
+ check_consistent_length (X , y )
149
+ else :
150
+ X , y = check_X_y (X , y , dtype = np .float64 , order = 'C' , accept_sparse = 'csr' )
148
151
y = self ._validate_targets (y )
149
152
150
153
sample_weight = np .asarray ([]
@@ -153,15 +156,16 @@ def fit(self, X, y, sample_weight=None):
153
156
solver_type = LIBSVM_IMPL .index (self ._impl )
154
157
155
158
# input validation
156
- if solver_type != 2 and X .shape [0 ] != y .shape [0 ]:
159
+ n_samples = _num_samples (X )
160
+ if solver_type != 2 and n_samples != y .shape [0 ]:
157
161
raise ValueError ("X and y have incompatible shapes.\n " +
158
162
"X has %s samples, but y has %s." %
159
- (X . shape [ 0 ] , y .shape [0 ]))
163
+ (n_samples , y .shape [0 ]))
160
164
161
- if self .kernel == "precomputed" and X . shape [ 0 ] != X .shape [1 ]:
165
+ if self .kernel == "precomputed" and n_samples != X .shape [1 ]:
162
166
raise ValueError ("X.shape[0] should be equal to X.shape[1]" )
163
167
164
- if sample_weight .shape [0 ] > 0 and sample_weight .shape [0 ] != X . shape [ 0 ] :
168
+ if sample_weight .shape [0 ] > 0 and sample_weight .shape [0 ] != n_samples :
165
169
raise ValueError ("sample_weight and X have incompatible shapes: "
166
170
"%r vs %r\n "
167
171
"Note: Sparse matrices cannot be indexed w/"
@@ -210,7 +214,10 @@ def fit(self, X, y, sample_weight=None):
210
214
fit (X , y , sample_weight , solver_type , kernel , random_seed = seed )
211
215
# see comment on the other call to np.iinfo in this file
212
216
213
- self .shape_fit_ = X .shape
217
+ if hasattr (X , 'shape' ):
218
+ self .shape_fit_ = X .shape
219
+ else :
220
+ self .shape_fit_ = (_num_samples (X ), )
214
221
215
222
# In binary case, we need to flip the sign of coef, intercept and
216
223
# decision function. Use self._intercept_ and self._dual_coef_ internally.
@@ -324,7 +331,6 @@ def predict(self, X):
324
331
return predict (X )
325
332
326
333
def _dense_predict (self , X ):
327
- n_samples , n_features = X .shape
328
334
X = self ._compute_kernel (X )
329
335
if X .ndim == 1 :
330
336
X = check_array (X , order = 'C' )
@@ -450,7 +456,8 @@ def _sparse_decision_function(self, X):
450
456
def _validate_for_predict (self , X ):
451
457
check_is_fitted (self , 'support_' )
452
458
453
- X = check_array (X , accept_sparse = 'csr' , dtype = np .float64 , order = "C" )
459
+ if not callable (self .kernel ):
460
+ X = check_array (X , accept_sparse = 'csr' , dtype = np .float64 , order = "C" )
454
461
if self ._sparse and not sp .isspmatrix (X ):
455
462
X = sp .csr_matrix (X )
456
463
if self ._sparse :
@@ -460,14 +467,15 @@ def _validate_for_predict(self, X):
460
467
raise ValueError (
461
468
"cannot use sparse input in %r trained on dense data"
462
469
% type (self ).__name__ )
463
- n_samples , n_features = X .shape
470
+ if not callable (self .kernel ):
471
+ n_features = X .shape [1 ]
464
472
465
473
if self .kernel == "precomputed" :
466
474
if X .shape [1 ] != self .shape_fit_ [0 ]:
467
475
raise ValueError ("X.shape[1] = %d should be equal to %d, "
468
476
"the number of samples at training time" %
469
477
(X .shape [1 ], self .shape_fit_ [0 ]))
470
- elif n_features != self .shape_fit_ [1 ]:
478
+ elif not callable ( self . kernel ) and n_features != self .shape_fit_ [1 ]:
471
479
raise ValueError ("X.shape[1] = %d should be equal to %d, "
472
480
"the number of features at training time" %
473
481
(n_features , self .shape_fit_ [1 ]))
0 commit comments