8000 fix ompcv on old scipy versions · raghavrv/scikit-learn@211f204 · GitHub
[go: up one dir, main page]

Skip to content

Commit 211f204

Browse files
committed
fix ompcv on old scipy versions
1 parent 5610dba commit 211f204

File tree

3 files changed

+13
-15
lines changed

3 files changed

+13
-15
lines changed

sklearn/linear_model/omp.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import scipy
2222
solve_triangular_args = {}
2323
if LooseVersion(scipy.__version__) >= LooseVersion('0.12'):
24+
# check_finite=False is an optimization available only in scipy >=0.12
2425
solve_triangular_args = {'check_finite': False}
2526

2627

@@ -89,7 +90,13 @@ def _cholesky_omp(X, y, n_nonzero_coefs, tol=None, copy_X=True,
8990
indices = np.arange(X.shape[1]) # keeping track of swapping
9091

9192
max_features = X.shape[1] if tol is not None else n_nonzero_coefs
92-
L = np.empty((max_features, max_features), dtype=X.dtype)
93+
if solve_triangular_args:
94+
# new scipy, don't need to initialize because check_finite=False
95+
L = np.empty((max_features, max_features), dtype=X.dtype)
96+
else:
97+
# old scipy, we need the garbage upper triangle to be non-Inf
98+
L = np.zeros((max_features, max_features), dtype=X.dtype)
99+
93100
L[0, 0] = 1.
94101
if return_path:
95102
coefs = np.empty_like(L)
@@ -373,8 +380,7 @@ def orthogonal_mp(X, y, n_nonzero_coefs=None, tol=None, precompute=False,
373380
for k in range(y.shape[1]):
374381
out = _cholesky_omp(
375382
X, y[:, k], n_nonzero_coefs, tol,
376-
copy_X=copy_X, return_path=return_path
377-
)
383+
copy_X=copy_X, return_path=return_path)
378384
if return_path:
379385
_, idx, coefs, n_iter = out
380386
coef = coef[:, :, :len(idx)]
@@ -504,8 +510,7 @@ def orthogonal_mp_gram(Gram, Xy, n_nonzero_coefs=None, tol=None,
504510
Gram, Xy[:, k], n_nonzero_coefs,
505511
norms_squared[k] if tol is not None else None, tol,
506512
copy_Gram=copy_Gram, copy_Xy=copy_Xy,
507-
return_path=return_path
508-
)
513+
return_path=return_path)
509514
if return_path:
510515
_, idx, coefs, n_iter = out
511516
coef = coef[:, :, :len(idx)]

sklearn/linear_model/tests/test_omp.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
from sklearn.utils.testing import assert_array_almost_equal
1111
from sklearn.utils.testing import assert_warns
1212
from sklearn.utils.testing import ignore_warnings
13-
from sklearn.utils.testing import check_skip_travis
1413

1514

1615
from sklearn.linear_model import (orthogonal_mp, orthogonal_mp_gram,
@@ -172,17 +171,15 @@ def test_omp_path():
172171

173172

174173
def test_omp_return_path_prop_with_gram():
175-
path = orthogonal_mp(X, y, n_nonzero_coefs=5, return_path=True,
176-
precompute=True)
174+
path = orthogonal_mp(X, y, n_nonzero_coefs=5, return_path=True,
175+
precompute=True)
177176
last = orthogonal_mp(X, y, n_nonzero_coefs=5, return_path=False,
178-
precompute=True)
177+
precompute=True)
179178
assert_equal(path.shape, (n_features, n_targets, 5))
180179
assert_array_almost_equal(path[:, :, -1], last)
181180

182181

183182
def test_omp_cv():
184-
# FIXME: This test is unstable on Travis, see issue #3190 for more detail.
185-
check_skip_travis()
186183
y_ = y[:, 0]
187184
gamma_ = gamma[:, 0]
188185
ompcv = OrthogonalMatchingPursuitCV(normalize=True, fit_intercept=False,

sklearn/utils/estimator_checks.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
from sklearn.utils.testing import set_random_state
2424
from sklearn.utils.testing import assert_greater
2525
from sklearn.utils.testing import SkipTest
26-
from sklearn.utils.testing import check_skip_travis
2726
from sklearn.utils.testing import ignore_warnings
2827

2928
from sklearn.base import clone, ClassifierMixin
@@ -62,9 +61,6 @@ def _boston_subset(n_samples=200):
6261
def set_fast_parameters(estimator):
6362
# speed up some estimators
6463
params = estimator.get_params()
65-
if estimator.__class__.__name__ == 'OrthogonalMatchingPursuitCV':
66-
# FIXME: This test is unstable on Travis, see issue #3190.
67-
check_skip_travis()
6864
if ("n_iter" in params
6965
and estimator.__class__.__name__ != "TSNE"):
7066
estimator.set_params(n_iter=5)

0 commit comments

Comments
 (0)
0