8000 ENH: make LinearSVC copyiable · seckcoder/scikit-learn@5bf5266 · GitHub
[go: up one dir, main page]

Skip to content

Commit 5bf5266

Browse files
committed
ENH: make LinearSVC copyiable
Fixes scikit-learn#820
1 parent fead9ac commit 5bf5266

File tree

2 files changed

+18
-2
lines changed

2 files changed

+18
-2
lines changed

sklearn/svm/base.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -755,6 +755,11 @@ def _validate_for_predict(self, X):
755755
raise ValueError(
756756
"cannot use sparse input in %r trained on dense data"
757757
% type(self).__name__)
758+
if not self.raw_coef_.flags['F_CONTIGUOUS']:
759+
warnings.warn('Coefficients are the fortran-contiguous. '
760+
'Copying them.', RuntimeWarning,
761+
stacklevel=3)
762+
self.raw_coef_ = np.asfortranarray(self.raw_coef_)
758763
return X
759764

760765
def _get_intercept_(self):

sklearn/svm/tests/test_svm.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,13 @@
33
44
TODO: remove hard coded numerical results when possible
55
"""
6+
import copy
7+
import warnings
68

79
import numpy as np
810
from numpy.testing import assert_array_equal, assert_array_almost_equal, \
911
assert_almost_equal
10-
from nose.tools import assert_raises, assert_true
12+
from nose.tools import assert_raises, assert_true, assert_equal
1113

1214
from sklearn import svm, linear_model, datasets, metrics, base
1315
from sklearn.datasets.samples_generator import make_classification
@@ -151,7 +153,7 @@ def test_precomputed():
151153
assert_almost_equal(np.mean(pred == iris.target), .99, decimal=2)
152154

153155

154-
def test_SVR():
156+
def test_svr():
155157
"""
156158
Test Support Vector Regression
157159
"""
@@ -598,6 +600,15 @@ def test_linearsvc_verbose():
598600
os.dup2(stdout, 1) # restore original stdout
599601

600602

603+
def test_linearsvc_deepcopy():
604+
rng = check_random_state(0)
605+
clf = svm.LinearSVC()
606+
clf.fit(rng.rand(10, 2), rng.randint(0, 2, size=10))
607+
with warnings.catch_warnings(record=True) as warn_queue:
608+
copy.deepcopy(clf).predict(rng.rand(2))
609+
assert_equal(len(warn_queue), 1)
610+
611+
601612
def test_svc_clone_with_callable_kernel():
602613
a = svm.SVC(kernel=lambda x, y: np.dot(x, y.T), probability=True)
603614
b = base.clone(a)

0 commit comments

Comments
 (0)
0