diff --git a/sklearn/covariance/_graph_lasso.py b/sklearn/covariance/_graph_lasso.py index 6b6116ecce040..5711cd410a638 100644 --- a/sklearn/covariance/_graph_lasso.py +++ b/sklearn/covariance/_graph_lasso.py @@ -14,6 +14,7 @@ import numpy as np from scipy import linalg from joblib import Parallel +from ..utils._param_validation import Interval, HasMethods, StrOptions, validate_params from . import empirical_covariance, EmpiricalCovariance, log_likelihood @@ -78,6 +79,21 @@ def alpha_max(emp_cov): return np.max(np.abs(A)) +@validate_params( + { + "emp_cov": [np.ndarray], + "alpha": [Interval(Real, 0, None, closed="left"), None], + "cov_init": ["array-like"], + "mode": [StrOptions( {"lars", "cd"} )], + "tol": [Interval(Real, 0, None, closed="left"), None], + "enet_tol": [Interval(Real, 0, None, closed="left"), None], + "max_iter": [Interval(Integral, 0, None, closed="left"), None], + "verbose": ["boolean"], + "return_costs": ["boolean"], + "eps": [Interval(Real, 0, None, closed="left"), None], + "return_n_iter": ["boolean"] + } +) # The g-lasso algorithm def graphical_lasso( emp_cov, diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index 0630decfd233e..3e7213b95c8fc 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -122,6 +122,7 @@ def _check_function_param_validation( "sklearn.model_selection.train_test_split", "sklearn.random_projection.johnson_lindenstrauss_min_dim", "sklearn.svm.l1_min_c", + "covariance.graphical_lasso", ]