diff --git a/sklearn/linear_model/cd_fast.pyx b/sklearn/linear_model/cd_fast.pyx
index a51d1bdbdbc96..f3731e3e7b894 100644
--- a/sklearn/linear_model/cd_fast.pyx
+++ b/sklearn/linear_model/cd_fast.pyx
@@ -15,6 +15,7 @@ cimport cython
 from cpython cimport bool
 from cython cimport floating
 import warnings
+from ..exceptions import ConvergenceWarning
 
 ctypedef np.float64_t DOUBLE
 ctypedef np.uint32_t UINT32_t
@@ -302,6 +303,12 @@ def enet_coordinate_descent(np.ndarray[floating, ndim=1] w,
                 if gap < tol:
                     # return if we reached desired tolerance
                     break
+        else:
+            with gil:
+                warnings.warn("Objective did not converge."
+                              " You might want to increase the number of iterations.",
+                              ConvergenceWarning)
+
 
     return w, gap, tol, n_iter + 1
 
@@ -521,6 +528,11 @@ def sparse_enet_coordinate_descent(floating [:] w,
                 if gap < tol:
                     # return if we reached desired tolerance
                     break
+        else:
+            with gil:
+                warnings.warn("Objective did not converge."
+                              " You might want to increase the number of iterations.",
+                              ConvergenceWarning)
 
     return w, gap, tol, n_iter + 1
 
@@ -675,6 +687,11 @@ def enet_coordinate_descent_gram(floating[:] w, floating alpha, floating beta,
                 if gap < tol:
                     # return if we reached desired tolerance
                     break
+        else:
+            with gil:
+                warnings.warn("Objective did not converge."
+                              " You might want to increase the number of iterations.",
+                              ConvergenceWarning)
 
     return np.asarray(w), gap, tol, n_iter + 1
 
@@ -880,5 +897,10 @@ def enet_coordinate_descent_multi_task(floating[::1, :] W, floating l1_reg,
                 if gap < tol:
                     # return if we reached desired tolerance
                     break
+        else:
+            with gil:
+                warnings.warn("Objective did not converge."
+                              " You might want to increase the number of iterations.",
+                              ConvergenceWarning)
 
     return np.asarray(W), gap, tol, n_iter + 1
diff --git a/sklearn/linear_model/tests/test_coordinate_descent.py b/sklearn/linear_model/tests/test_coordinate_descent.py
index 834d685f5b23d..90422ccf731a0 100644
--- a/sklearn/linear_model/tests/test_coordinate_descent.py
+++ b/sklearn/linear_model/tests/test_coordinate_descent.py
@@ -818,3 +818,33 @@ def test_coef_shape_not_zero():
     est_no_intercept = Lasso(fit_intercept=False)
     est_no_intercept.fit(np.c_[np.ones(3)], np.ones(3))
     assert est_no_intercept.coef_.shape == (1,)
+
+
+def test_enet_coordinate_descent():
+    """Test that a warning is issued if model does not converge"""
+    clf = Lasso()
+    n_samples = 15500
+    n_features = 500
+    X = np.ones([n_samples, n_features]) * 1e50
+    y = np.ones([n_samples])
+    assert_warns(ConvergenceWarning, clf.fit, X, y)
+
+
+def test_enet_coordinate_descent_gram():
+    """Test that a warning is issued if model does not converge"""
+    clf = Lasso(precompute=True)
+    n_samples = 15500
+    n_features = 500
+    X = np.ones([n_samples, n_features]) * 1e50
+    y = np.ones([n_samples])
+    assert_warns(ConvergenceWarning, clf.fit, X, y)
+
+def test_enet_coordinate_descent_multi_task():
+    """Test that a warning is issued if model does not converge"""
+    clf = MultiTaskLasso()
+    n_samples = 15500
+    n_features = 500
+    n_classes = 2
+    X = np.ones([n_samples, n_features]) * 1e50
+    y = np.ones([n_samples, n_classes])
+    assert_warns(ConvergenceWarning, clf.fit, X, y)
diff --git a/sklearn/linear_model/tests/test_sparse_coordinate_descent.py b/sklearn/linear_model/tests/test_sparse_coordinate_descent.py
index 6b4c09d9742e0..a3526e8d468c1 100644
--- a/sklearn/linear_model/tests/test_sparse_coordinate_descent.py
+++ b/sklearn/linear_model/tests/test_sparse_coordinate_descent.py
@@ -9,6 +9,8 @@
 from sklearn.utils.testing import assert_greater
 from sklearn.utils.testing import ignore_warnings
+from sklearn.utils.testing import assert_warns
+from sklearn.exceptions import ConvergenceWarning
 
 from sklearn.linear_model.coordinate_descent import (Lasso, ElasticNet,
                                                      LassoCV, ElasticNetCV)
 
@@ -291,3 +293,13 @@ def test_same_multiple_output_sparse_dense():
     predict_sparse = l_sp.predict(sample_sparse)
 
     assert_array_almost_equal(predict_sparse, predict_dense)
+
+
+def test_sparse_enet_coordinate_descent():
+    """Test that a warning is issued if model does not converge"""
+    clf = Lasso()
+    n_samples = 15500
+    n_features = 500
+    X = sp.csc_matrix((n_samples, n_features)) * 1e50
+    y = np.ones([n_samples])
+    assert_warns(ConvergenceWarning, clf.fit, X, y)
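
For anyone trying the patch locally, here is a minimal sketch of how the new warning surfaces to a caller. It mirrors the data construction used in the tests above (an extreme feature scale that keeps the duality gap above the tolerance) and assumes a build that includes this change; warnings.catch_warnings is used instead of assert_warns so it runs outside the test harness:

    import warnings

    import numpy as np
    from sklearn.linear_model import Lasso
    from sklearn.exceptions import ConvergenceWarning

    # Same ill-conditioned setup as the new tests: a huge constant feature
    # scale means coordinate descent cannot drive the duality gap below tol.
    X = np.ones([15500, 500]) * 1e50
    y = np.ones([15500])

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        Lasso().fit(X, y)

    # With the patched cd_fast.pyx, at least one ConvergenceWarning is expected.
    assert any(issubclass(w.category, ConvergenceWarning) for w in caught)

In real use, increasing max_iter or loosening tol on the estimator is the usual way to address the warning.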