8000 Merge pull request #4775 from arthurmensch/cd_fast_readonly_array_bra… · scikit-learn/scikit-learn@77ecf16 · GitHub
[go: up one dir, main page]

Skip to content

Commit 77ecf16

Browse files
committed
Merge pull request #4775 from arthurmensch/cd_fast_readonly_array_brainfix
[MRG+1] Read-only data compatibility for Lasso
2 parents 3a6c4dc + fead69a commit 77ecf16

File tree

8 files changed

+2731
-2557
lines changed

8 files changed

+2731
-2557
lines changed

examples/decomposition/plot_image_denoising.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@
4646

4747
###############################################################################
4848
# Load Lena image and extract patches
49-
5049
lena = lena() / 256.0
5150

5251
# downsample for higher speed

sklearn/decomposition/tests/test_dict_learning.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
import numpy as np
22

3+
34
from sklearn.utils.testing import assert_array_almost_equal
45
from sklearn.utils.testing import assert_array_equal
56
from sklearn.utils.testing import assert_equal
67
from sklearn.utils.testing import assert_true
78
from sklearn.utils.testing import assert_less
89
from sklearn.utils.testing import assert_raises
910
from sklearn.utils.testing import ignore_warnings
11+
from sklearn.utils.testing import TempMemmap
1012

1113
from sklearn.decomposition import DictionaryLearning
1214
from sklearn.decomposition import MiniBatchDictionaryLearning
@@ -60,6 +62,15 @@ def test_dict_learning_reconstruction_parallel():
6062
assert_array_almost_equal(np.dot(code, dico.components_), X, decimal=2)
6163

6264

65+
def test_dict_learning_lassocd_readonly_data():
66+
n_components = 12
67+
with TempMemmap(X) as X_read_only:
68+
dico = DictionaryLearning(n_components, transform_algorithm='lasso_cd',
69+
transform_alpha=0.001, random_state=0, n_jobs=-1)
70+
code = dico.fit(X_read_only).transform(X_read_only)
71+
assert_array_almost_equal(np.dot(code, dico.components_), X_read_only, decimal=2)
72+
73+
6374
def test_dict_learning_nonzero_coefs():
6475
n_components = 4
6576
dico = DictionaryLearning(n_components, transform_algorithm='lars',
@@ -214,4 +225,4 @@ def test_sparse_coder_estimator():
214225
code = SparseCoder(dictionary=V, transform_algorithm='lasso_lars',
215226
transform_alpha=0.001).transform(X)
216227
assert_true(not np.all(code == 0))
217-
assert_less(np.sqrt(np.sum((np.dot(code, V) - X) ** 2)), 0.1)
228+
assert_less(np.sqrt(np.sum((np.dot(code, V) - X) ** 2)), 0.1)

sklearn/linear_model/cd_fast.c

Lines changed: 2640 additions & 2531 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

sklearn/linear_model/cd_fast.pyx

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -289,8 +289,9 @@ def enet_coordinate_descent(np.ndarray[DOUBLE, ndim=1] w,
289289
@cython.cdivision(True)
290290
def sparse_enet_coordinate_descent(double[:] w,
291291
double alpha, double beta,
292-
double[:] X_data, int[:] X_indices,
293-
int[:] X_indptr, double[:] y,
292+
np.ndarray[double, ndim=1] X_data,
293+
np.ndarray[int, ndim=1] X_indices,
294+
np.ndarray[int, ndim=1] X_indptr, np.ndarray[double, ndim=1] y,
294295
double[:] X_mean, int max_iter,
295296
double tol, object rng, bint random=0,
296297
bint positive=0):
@@ -487,7 +488,9 @@ def sparse_enet_coordinate_descent(double[:] w,
487488
@cython.wraparound(False)
488489
@cython.cdivision(True)
489490
def enet_coordinate_descent_gram(double[:] w, double alpha, double beta,
490-
double[:, :] Q, double[:] q, double[:] y,
491+
np.ndarray[double, ndim=2] Q,
492+
np.ndarray[double, ndim=1] q,
493+
np.ndarray[double, ndim=1] y,
491494
int max_iter, double tol, object rng,
492495
bint random=0, bint positive=0):
493496
"""Cython version of the coordinate descent algorithm
@@ -628,8 +631,8 @@ def enet_coordinate_descent_gram(double[:] w, double alpha, double beta,
628631
@cython.wraparound(False)
629632
@cython.cdivision(True)
630633
def enet_coordinate_descent_multi_task(double[::1, :] W, double l1_reg,
631-
double l2_reg, double[::1, :] X,
632-
double[:, :] Y, int max_iter,
634+
double l2_reg, np.ndarray[double, ndim=2] X,
635+
np.ndarray[double, ndim=2] Y, int max_iter,
633636
double tol, object rng,
634637
bint random=0):
635638
"""Cython version of the coordinate descent algorithm

sklearn/linear_model/coordinate_descent.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -359,6 +359,7 @@ def enet_path(X, y, l1_ratio=0.5, eps=1e-3, n_alphas=100, alphas=None,
359359
ElasticNetCV
360360
"""
361361
X = check_array(X, 'csc', dtype=np.float64, order='F', copy=copy_X)
362+
y = check_array(y, 'csc', dtype=np.float64, order='F', copy=False, ensure_2d=False)
362363
if Xy is not None:
363364
Xy = check_array(Xy, 'csc', dtype=np.float64, order='F', copy=False,
364365
ensure_2d=False)

sklearn/linear_model/tests/test_coordinate_descent.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from sklearn.utils.testing import assert_warns
2020
from sklearn.utils.testing import ignore_warnings
2121
from sklearn.utils.testing import assert_array_equal
22+
from sklearn.utils.testing import TempMemmap
2223

2324
from sklearn.linear_model.coordinate_descent import Lasso, \
2425
LassoCV, ElasticNet, ElasticNetCV, MultiTaskLasso, MultiTaskElasticNet, \
@@ -388,6 +389,29 @@ def test_multi_task_lasso_and_enet():
388389
assert_array_almost_equal(clf.coef_[0], clf.coef_[1])
389390

390391

392+
def test_lasso_readonly_data():
393+
X = np.array([[-1], [0], [1]])
394+
Y = np.array([-1, 0, 1]) # just a straight line
395+
T = np.array([[2], [3], [4]]) # test sample
396+
with TempMemmap((X, Y)) as (X, Y):
397+
clf = Lasso(alpha=0.5)
398+
clf.fit(X, Y)
399+
pred = clf.predict(T)
400+
assert_array_almost_equal(clf.coef_, [.25])
401+
assert_array_almost_equal(pred, [0.5, 0.75, 1.])
402+
assert_almost_equal(clf.dual_gap_, 0)
403+
404+
405+
def test_multi_task_lasso_readonly_data():
406+
X, y, X_test, y_test = build_dataset()
407+
Y = np.c_[y, y]
408+
with TempMemmap((X, Y)) as (X, Y):
409+
Y = np.c_[y, y]
410+
clf = MultiTaskLasso(alpha=1, tol=1e-8).fit(X, Y)
411+
assert_true(0 < clf.dual_gap_ < 1e-5)
412+
assert_array_almost_equal(clf.coef_[0], clf.coef_[1])
413+
414+
391415
def test_enet_multitarget():
392416
n_targets = 3
393417
X, y, _, _ = build_dataset(n_samples=10, n_features=8,

sklearn/linear_model/tests/test_least_angle.py

Lines changed: 3 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,3 @@
1-
import tempfile
2-
import shutil
3-
import os.path as op
4-
import warnings
51
from nose.tools import assert_equal
62

73
import numpy as np
@@ -16,6 +12,7 @@
1612
from sklearn.utils.testing import assert_raises
1713
from sklearn.utils.testing import ignore_warnings
1814
from sklearn.utils.testing import assert_no_warnings, assert_warns
15+
from sklearn.utils.testing import TempMemmap
1916
from sklearn.utils import ConvergenceWarning
2017
from sklearn import linear_model, datasets
2118
from sklearn.linear_model.least_angle import _lars_path_residues
@@ -440,19 +437,6 @@ def test_lars_path_readonly_data():
440437
# This is a non-regression test for:
441438
# https://github.com/scikit-learn/scikit-learn/issues/4597
442439
splitted_data = train_test_split(X, y, random_state=42)
443-
temp_folder = tempfile.mkdtemp()
444-
try:
445-
fpath = op.join(temp_folder, 'data.pkl')
446-
joblib.dump(splitted_data, fpath)
447-
X_train, X_test, y_train, y_test = joblib.load(fpath, mmap_mode='r')
448-
440+
with TempMemmap(splitted_data) as (X_train, X_test, y_train, y_test):
449441
# The following should not fail despite copy=False
450-
_lars_path_residues(X_train, y_train, X_test, y_test, copy=False)
451-
finally:
452-
# try to release the mmap file handle in time to be able to delete
453-
# the temporary folder under windows
454-
del X_train, X_test, y_train, y_test
455-
try:
456-
shutil.rmtree(temp_folder)
457-
except shutil.WindowsError:
458-
warnings.warn("Could not delete temporary folder %s" % temp_folder)
442+
_lars_path_residues(X_train, y_train, X_test, y_test, copy=False)

sklearn/utils/testing.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,20 @@
2828
from urllib.request import urlopen
2929
from urllib.error import HTTPError
3030

31+
import tempfile
32+
import shutil
33+
import os.path as op
34+
import atexit
35+
36+
# WindowsError only exist on Windows
37+
try:
38+
WindowsError
39+
except NameError:
40+
WindowsError = None
41+
3142
import sklearn
3243
from sklearn.base import BaseEstimator
44+
from sklearn.externals import joblib
3345

3446
# Conveniently import all assertions in one place.
3547
from nose.tools import assert_equal
@@ -697,5 +709,36 @@ def check_skip_travis():
697709
if os.environ.get('TRAVIS') == "true":
698710
raise SkipTest("This test needs to be skipped on Travis")
699711

712+
713+
def _delete_folder(folder_path, warn=False):
714+
"""Utility function to cleanup a temporary folder if still existing.
715+
Copy from joblib.pool (for independance)"""
716+
try:
717+
if os.path.exists(folder_path):
718+
# This can fail under windows,
719+
# but will succeed when called by atexit
720+
shutil.rmtree(folder_path)
721+
except WindowsError:
722+
if warn:
723+
warnings.warn("Could not delete temporary folder %s" % folder_path)
724+
725+
726+
class TempMemmap(object):
727+
def __init__(self, data, mmap_mode='r'):
728+
self.temp_folder = tempfile.mkdtemp(prefix='sklearn_testing_')
729+
self.mmap_mode = mmap_mode
730+
self.data = data
731+
732+
def __enter__(self):
733+
fpath = op.join(self.temp_folder, 'data.pkl')
734+
joblib.dump(self.data, fpath)
735+
data_read_only = joblib.load(fpath, mmap_mode=self.mmap_mode)
736+
atexit.register(lambda: _delete_folder(self.temp_folder, warn=True))
737+
return data_read_only
738+
739+
def __exit__(self, exc_type, exc_val, exc_tb):
740+
_delete_folder(self.temp_folder)
741+
742+
700743
with_network = with_setup(check_skip_network)
701744
with_travis = with_setup(check_skip_travis)

0 commit comments

Comments
 (0)
0