From 0d58adf07414dfc92fb6d4bbf0a2ab7c07b64605 Mon Sep 17 00:00:00 2001 From: adossantosalfam Date: Thu, 12 Jan 2023 23:03:52 +0100 Subject: [PATCH] this is my work on grphical_lasso --- sklearn/covariance/_graph_lasso.py | 16 ++++++++++++++++ sklearn/tests/test_public_functions.py | 1 + 2 files changed, 17 insertions(+) diff --git a/sklearn/covariance/_graph_lasso.py b/sklearn/covariance/_graph_lasso.py index 6b6116ecce040..5711cd410a638 100644 --- a/sklearn/covariance/_graph_lasso.py +++ b/sklearn/covariance/_graph_lasso.py @@ -14,6 +14,7 @@ import numpy as np from scipy import linalg from joblib import Parallel +from ..utils._param_validation import Interval, HasMethods, StrOptions, validate_params from . import empirical_covariance, EmpiricalCovariance, log_likelihood @@ -78,6 +79,21 @@ def alpha_max(emp_cov): return np.max(np.abs(A)) +@validate_params( + { + "emp_cov": [np.ndarray], + "alpha": [Interval(Real, 0, None, closed="left"), None], + "cov_init": ["array-like"], + "mode": [StrOptions( {"lars", "cd"} )], + "tol": [Interval(Real, 0, None, closed="left"), None], + "enet_tol": [Interval(Real, 0, None, closed="left"), None], + "max_iter": [Interval(Integral, 0, None, closed="left"), None], + "verbose": ["boolean"], + "return_costs": ["boolean"], + "eps": [Interval(Real, 0, None, closed="left"), None], + "return_n_iter": ["boolean"] + } +) # The g-lasso algorithm def graphical_lasso( emp_cov, diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index 0630decfd233e..3e7213b95c8fc 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -122,6 +122,7 @@ def _check_function_param_validation( "sklearn.model_selection.train_test_split", "sklearn.random_projection.johnson_lindenstrauss_min_dim", "sklearn.svm.l1_min_c", + "covariance.graphical_lasso", ]