8000 review glemaitre · scikit-learn/scikit-learn@e649ab3 · GitHub
[go: up one dir, main page]

Skip to content

Commit e649ab3

Browse files
committed
review glemaitre
1 parent 4fde439 commit e649ab3

File tree

3 files changed

+15
-21
lines changed

3 files changed

+15
-21
lines changed

doc/whats_new/v0.23.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@ parameters, may produce different models from the previous version. This often
2323
occurs due to changes in the modelling logic (bug fixes or enhancements), or in
2424
random sampling procedures.
2525

26-
- models come here
26+
- :class:`svm.SVC` and :class:`svm.SVR` when `kernel` is callable and input is
27+
not castable as float array. |Fix|
2728

2829
Details are listed in the changelog below.
2930

sklearn/svm/_base.py

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -177,12 +177,12 @@ def fit(self, X, y, sample_weight=None):
177177
"boolean masks (use `indices=True` in CV)."
178178
% (sample_weight.shape, X.shape))
179179

180-
kernel = self.kernel
181-
if callable(kernel):
182-
kernel = 'precomputed'
180+
kernel = 'precomputed' if callable(self.kernel) else self.kernel
183181

184182
if kernel == 'precomputed':
185-
self._gamma = 0. # unused but needs to be a float
183+
# unused but needs to be a float for cython code that ignores
184+
# it anyway
185+
self._gamma = 0.
186186
elif isinstance(self.gamma, str):
187187
if self.gamma == 'scale':
188188
# var = E[X^2] - E[X]^2 if sparse
@@ -207,10 +207,7 @@ def fit(self, X, y, sample_weight=None):
207207
fit(X, y, sample_weight, solver_type, kernel, random_seed=seed)
208208
# see comment on the other call to np.iinfo in this file
209209

210-
if hasattr(X, 'shape'):
211-
self.shape_fit_ = X.shape
212-
else:
213-
self.shape_fit_ = (_num_samples(X),)
210+
self.shape_fit_ = X.shape if hasattr(X, "shape") else (n_samples, )
214211

215212
# In binary case, we need to flip the sign of coef, intercept and
216213
# decision function. Use self._intercept_ and self._dual_coef_
@@ -467,18 +464,16 @@ def _validate_for_predict(self, X):
467464
raise ValueError(
468465
"cannot use sparse input in %r trained on dense data"
469466
% type(self).__name__)
470-
if not callable(self.kernel):
471-
n_features = X.shape[1]
472467

473468
if self.kernel == "precomputed":
474469
if X.shape[1] != self.shape_fit_[0]:
475470
raise ValueError("X.shape[1] = %d should be equal to %d, "
476471
"the number of samples at training time" %
477472
(X.shape[1], self.shape_fit_[0]))
478-
elif not callable(self.kernel) and n_features != self.shape_fit_[1]:
473+
elif not callable(self.kernel) and X.shape[1] != self.shape_fit_[1]:
479474
raise ValueError("X.shape[1] = %d should be equal to %d, "
480475
"the number of features at training time" %
481-
(n_features, self.shape_fit_[1]))
476+
(X.shape[1], self.shape_fit_[1]))
482477
return X
483478

484479
@property

sklearn/svm/tests/test_svm.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1281,14 +1281,12 @@ def string_kernel(X1, X2):
12811281
assert svc1.score(data, y) == svc3.score(K, y)
12821282
assert svc1.score(data, y) == svc2.score(X, y)
12831283
if hasattr(svc1, 'decision_function'): # classifier
1284-
assert_array_almost_equal(svc1.decision_function(data),
1285-
svc2.decision_function(X))
1286-
assert_array_almost_equal(svc1.decision_function(data),
1287-
svc3.decision_function(K))
1284+
assert_allclose(svc1.decision_function(data),
1285+
svc2.decision_function(X))
1286+
assert_allclose(svc1.decision_function(data),
1287+
svc3.decision_function(K))
12881288
assert_array_equal(svc1.predict(data), svc2.predict(X))
12891289
assert_array_equal(svc1.predict(data), svc3.predict(K))
12901290
else: # regressor
1291-
assert_array_almost_equal(svc1.predict(data),
1292-
svc2.predict(X))
1293-
assert_array_almost_equal(svc1.predict(data),
1294-
svc3.predict(K))
1291+
assert_allclose(svc1.predict(data), svc2.predict(X))
1292+
assert_allclose(svc1.predict(data), svc3.predict(K))

0 commit comments

Comments
 (0)
0