FIX sort indices in CSR matrix for SVM · scikit-learn/scikit-learn@c75dd39 · GitHub
[go: up one dir, main page]

Skip to content

Commit c75dd39

Browse files
amueller authored and larsmans committed
FIX sort indices in CSR matrix for SVM
1 parent 0cb8963 commit c75dd39

File tree

3 files changed

+40
-2
lines changed

3 files changed

+40
-2
lines changed

doc/whats_new.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,9 @@ Changelog
131131
:mod:`sklearn.metrics` for regression and classification metrics
132132
by `Arnaud Joly`_.
133133

134+
- Fixed a bug in :class:`sklearn.svm.SVC` when using csr-matrices with
135+
unsorted indices by Xinfan Meng and `Andreas Müller`_.
136+
134137
API changes summary
135138
-------------------
136139
- Renamed all occurrences of ``n_atoms`` to ``n_components`` for consistency.

sklearn/svm/base.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,7 @@ def _dense_fit(self, X, y, sample_weight, solver_type, kernel):
238238

239239
def _sparse_fit(self, X, y, sample_weight, solver_type, kernel):
240240
X.data = np.asarray(X.data, dtype=np.float64, order='C')
241+
X.sort_indices()
241242

242243
kernel_type = self._sparse_kernels.index(kernel)
243244

@@ -398,6 +399,9 @@ def _validate_for_predict(self, X):
398399
X = atleast2d_or_csr(X, dtype=np.float64, order="C")
399400
if self._sparse and not sp.isspmatrix(X):
400401
X = sp.csr_matrix(X)
402+
if self._sparse:
403+
X.sort_indices()
404+
401405
if (sp.issparse(X) and not self._sparse and
402406
not hasattr(self.kernel, '__call__')):
403407
raise ValueError(

sklearn/svm/tests/test_sparse.py

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@
55
from numpy.testing import (assert_array_almost_equal, assert_array_equal,
66
assert_equal)
77

8-
from nose.tools import assert_raises, assert_true
8+
from nose.tools import assert_raises, assert_true, assert_false
99
from nose.tools import assert_equal as nose_assert_equal
10-
from sklearn.datasets.samples_generator import make_classification
10+
from sklearn.datasets import make_classification, load_digits
1111
from sklearn.svm.tests import test_svm
1212
from sklearn.utils import ConvergenceWarning
1313
from sklearn.utils.extmath import safe_sparse_dot
@@ -69,6 +69,37 @@ def test_svc():
6969
sp_clf.predict_proba(T2), 4)
7070

7171

72+
def test_unsorted_indices():
73+
# test that the result with sorted and unsorted indices in csr is the same
74+
# we use a subset of digits as iris, blobs or make_classification didn't
75+
# show the problem
76+
digits = load_digits()
77+
X, y = digits.data[:50], digits.target[:50]
78+
X_test = sparse.csr_matrix(digits.data[50:100])
79+
80+
X_sparse = sparse.csr_matrix(X)
81+
coef_dense = svm.SVC(kernel='linear', probability=True).fit(X, y).coef_
82+
sparse_svc = svm.SVC(kernel='linear', probability=True).fit(X_sparse, y)
83+
coef_sorted = sparse_svc.coef_
84+
# make sure dense and sparse SVM give the same result
85+
assert_array_almost_equal(coef_dense, coef_sorted.toarray())
86+
87+
X_sparse_unsorted = X_sparse[np.arange(X.shape[0])]
88+
X_test_unsorted = X_test[np.arange(X_test.shape[0])]
89+
90+
# make sure we scramble the indices
91+
assert_false(X_sparse_unsorted.has_sorted_indices)
92+
assert_false(X_test_unsorted.has_sorted_indices)
93+
94+
unsorted_svc = svm.SVC(kernel='linear',
95+
probability=True).fit(X_sparse_unsorted, y)
96+
coef_unsorted = unsorted_svc.coef_
97+
# make sure unsorted indices give same result
98+
assert_array_almost_equal(coef_unsorted.toarray(), coef_sorted.toarray())
99+
assert_array_almost_equal(sparse_svc.predict_proba(X_test_unsorted),
100+
sparse_svc.predict_proba(X_test))
101+
102+
72103
def test_svc_with_custom_kernel():
73104
kfunc = lambda x, y: safe_sparse_dot(x, y.T)
74105
clf_lin = svm.SVC(kernel='linear').fit(X_sp, Y)

0 commit comments

Comments
 (0)
0