From 2f8773131ebd15ddbf3524462c02d88834f4cb8e Mon Sep 17 00:00:00 2001 From: rprkh Date: Sun, 27 Nov 2022 20:56:22 +0530 Subject: [PATCH 1/3] include parameter validation for graphical_lasso --- sklearn/covariance/_graph_lasso.py | 16 ++++++++++++++++ sklearn/tests/test_public_functions.py | 1 + 2 files changed, 17 insertions(+) diff --git a/sklearn/covariance/_graph_lasso.py b/sklearn/covariance/_graph_lasso.py index 564d3d21dc681..7f9ea26860cdd 100644 --- a/sklearn/covariance/_graph_lasso.py +++ b/sklearn/covariance/_graph_lasso.py @@ -25,6 +25,7 @@ ) from ..utils.fixes import delayed from ..utils._param_validation import Interval, StrOptions +from ..utils._param_validation import validate_params # mypy error: Module 'sklearn.linear_model' has no attribute '_cd_fast' from ..linear_model import _cd_fast as cd_fast # type: ignore @@ -78,6 +79,21 @@ def alpha_max(emp_cov): return np.max(np.abs(A)) +@validate_params( + { + "emp_cov": ["array-like"], + "alpha": [Interval(Real, 0, None, closed="right")], + "cov_init": ["array-like", None], + "mode": [StrOptions({"cd", "lars"})], + "tol": [Interval(Real, 0, None, closed="right")], + "enet_tol": [Interval(Real, 0, None, closed="right")], + "max_iter": [Interval(Integral, 0, None, closed="left")], + "verbose": ["verbose"], + "return_costs": ["boolean"], + "eps": [Interval(Real, 0, None, closed="left")], + "return_n_iter": ["boolean"] + } +) # The g-lasso algorithm def graphical_lasso( emp_cov, diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index 85cd0638a5ef3..a0e285ca4b58f 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -10,6 +10,7 @@ PARAM_VALIDATION_FUNCTION_LIST = [ "sklearn.cluster.kmeans_plusplus", + "sklearn.covariance.graphical_lasso", "sklearn.svm.l1_min_c", "sklearn.metrics.accuracy_score", ] From 7a9af8582c94ddacbaf62c199a634192da312197 Mon Sep 17 00:00:00 2001 From: rprkh Date: Sun, 27 Nov 2022 21:38:12 +0530 Subject: [PATCH 2/3] fix lint and alpha error --- sklearn/covariance/_graph_lasso.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/covariance/_graph_lasso.py b/sklearn/covariance/_graph_lasso.py index 7f9ea26860cdd..1862c70ff8553 100644 --- a/sklearn/covariance/_graph_lasso.py +++ b/sklearn/covariance/_graph_lasso.py @@ -82,7 +82,7 @@ def alpha_max(emp_cov): @validate_params( { "emp_cov": ["array-like"], - "alpha": [Interval(Real, 0, None, closed="right")], + "alpha": [Interval(Real, 0, None, closed="both")], "cov_init": ["array-like", None], "mode": [StrOptions({"cd", "lars"})], "tol": [Interval(Real, 0, None, closed="right")], @@ -91,7 +91,7 @@ def alpha_max(emp_cov): "verbose": ["verbose"], "return_costs": ["boolean"], "eps": [Interval(Real, 0, None, closed="left")], - "return_n_iter": ["boolean"] + "return_n_iter": ["boolean"], } ) # The g-lasso algorithm From ee9885fbe73d6c3fefb832431c2d4f2ca3a5690e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=A9mie=20du=20Boisberranger?= <34657725+jeremiedbb@users.noreply.github.com> Date: Wed, 19 Apr 2023 16:13:36 +0200 Subject: [PATCH 3/3] fix docstring param description --- sklearn/covariance/_graph_lasso.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/covariance/_graph_lasso.py b/sklearn/covariance/_graph_lasso.py index ad3cd1f033311..afe21fa3a02f1 100644 --- a/sklearn/covariance/_graph_lasso.py +++ b/sklearn/covariance/_graph_lasso.py @@ -243,7 +243,7 @@ def graphical_lasso( Parameters ---------- - emp_cov : ndarray of shape (n_features, n_features) + emp_cov : array-like of shape (n_features, n_features) Empirical covariance from which to compute the covariance estimate. alpha : float