FEA Add metadata routing to GraphicalLassoCV (#27566) · ssec-jhu/scikit-learn@8ad102b · GitHub
[go: up one dir, main page]

Skip to content

Commit 8ad102b

Browse files
FEA Add metadata routing to GraphicalLassoCV (scikit-learn#27566)
Co-authored-by: Adrin Jalali <adrin.jalali@gmail.com>
1 parent 77f8731 commit 8ad102b

File tree

5 files changed

+123
-21
lines changed

5 files changed

+123
-21
lines changed

doc/metadata_routing.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,7 @@ Meta-estimators and functions supporting metadata routing:
276276

277277
- :class:`sklearn.calibration.CalibratedClassifierCV`
278278
- :class:`sklearn.compose.ColumnTransformer`
279+
- :class:`sklearn.covariance.GraphicalLassoCV`
279280
- :class:`sklearn.ensemble.VotingClassifier`
280281
- :class:`sklearn.ensemble.VotingRegressor`
281282
- :class:`sklearn.ensemble.BaggingClassifier`
@@ -313,7 +314,6 @@ Meta-estimators and functions supporting metadata routing:
313314
Meta-estimators and tools not supporting metadata routing yet:
314315

315316
- :class:`sklearn.compose.TransformedTargetRegressor`
316-
- :class:`sklearn.covariance.GraphicalLassoCV`
317317
- :class:`sklearn.ensemble.AdaBoostClassifier`
318318
- :class:`sklearn.ensemble.AdaBoostRegressor`
319319
- :class:`sklearn.ensemble.StackingClassifier`

doc/whats_new/v1.5.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,10 @@ more details.
7979
:class:`model_selection.GridSearchCV` object or the underlying scorer.
8080
:pr:`27560` by :user:`Omar Salman <OmarManzoor>`.
8181

82+
- |Feature| :class:`GraphicalLassoCV` now supports metadata routing in its
83+
`fit` method and routes metadata to the CV splitter.
84+
:pr:`27566` by :user:`Omar Salman <OmarManzoor>`.
85+
8286
- |Feature| :class:`linear_model.RANSACRegressor` now supports metadata routing
8387
in its ``fit``, ``score`` and ``predict`` methods and routes metadata to its
8488
underlying estimator's ``fit``, ``score`` and ``predict`` methods.

sklearn/covariance/_graph_lasso.py

Lines changed: 50 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,15 @@
2121
from ..linear_model import _cd_fast as cd_fast # type: ignore
2222
from ..linear_model import lars_path_gram
2323
from ..model_selection import check_cv, cross_val_score
24+
from ..utils import Bunch
2425
from ..utils._param_validation import Interval, StrOptions, validate_params
25-
from ..utils.metadata_routing import _RoutingNotSupportedMixin
26+
from ..utils.metadata_routing import (
27+
MetadataRouter,
28+
MethodMapping,
29+
_raise_for_params,
30+
_routing_enabled,
31+
process_routing,
32+
)
2633
from ..utils.parallel import Parallel, delayed
2734
from ..utils.validation import (
2835
_is_arraylike_not_scalar,
@@ -721,7 +728,7 @@ def graphical_lasso_path(
721728
return covariances_, precisions_
722729

723730

724-
class GraphicalLassoCV(_RoutingNotSupportedMixin, BaseGraphicalLasso):
731+
class GraphicalLassoCV(BaseGraphicalLasso):
725732
"""Sparse inverse covariance w/ cross-validated choice of the l1 penalty.
726733
727734
See glossary entry for :term:`cross-validation estimator`.
@@ -942,7 +949,7 @@ def __init__(
942949
self.n_jobs = n_jobs
943950

944951
@_fit_context(prefer_skip_nested_validation=True)
945-
def fit(self, X, y=None):
952+
def fit(self, X, y=None, **params):
946953
"""Fit the GraphicalLasso covariance model to X.
947954
948955
Parameters
@@ -953,12 +960,25 @@ def fit(self, X, y=None):
953960
y : Ignored
954961
Not used, present for API consistency by convention.
955962
963+
**params : dict, default=None
964+
Parameters to be passed to the CV splitter and the
965+
cross_val_score function.
966+
967+
.. versionadded:: 1.5
968+
Only available if `enable_metadata_routing=True`,
969+
which can be set by using
970+
``sklearn.set_config(enable_metadata_routing=True)``.
971+
See :ref:`Metadata Routing User Guide <metadata_routing>` for
972+
more details.
973+
956974
Returns
957975
-------
958976
self : object
959977
Returns the instance itself.
960978
"""
961979
# Covariance does not make sense for a single feature
980+
_raise_for_params(params, self, "fit")
981+
962982
X = self._validate_data(X, ensure_min_features=2)
963983
if self.assume_centered:
964984
self.location_ = np.zeros(X.shape[1])
@@ -991,6 +1011,11 @@ def fit(self, X, y=None):
9911011
alpha_0 = 1e-2 * alpha_1
9921012
alphas = np.logspace(np.log10(alpha_0), np.log10(alpha_1), n_alphas)[::-1]
9931013

1014+
if _routing_enabled():
1015+
routed_params = process_routing(self, "fit", **params)
1016+
else:
1017+
routed_params = Bunch(splitter=Bunch(split={}))
1018+
9941019
t0 = time.time()
9951020
for i in range(n_refinements):
9961021
with warnings.catch_warnings():
@@ -1015,7 +1040,7 @@ def fit(self, X, y=None):
10151040
verbose=inner_verbose,
10161041
eps=self.eps,
10171042
)
1018-
for train, test in cv.split(X, y)
1043+
for train, test in cv.split(X, y, **routed_params.splitter.split)
10191044
)
10201045

10211046
# Little dance to transform the list into what we need
@@ -1081,6 +1106,7 @@ def fit(self, X, y=None):
10811106
cv=cv,
10821107
n_jobs=self.n_jobs,
10831108
verbose=inner_verbose,
1109+
params=params,
10841110
)
10851111
)
10861112
grid_scores = np.array(grid_scores)
@@ -1108,3 +1134,23 @@ def fit(self, X, y=None):
11081134
eps=self.eps,
11091135
)
11101136
return self
1137+
1138+
def get_metadata_routing(self):
1139+
"""Get metadata routing of this object.
1140+
1141+
Please check :ref:`User Guide <metadata_routing>` on how the routing
1142+
mechanism works.
1143+
1144+
.. versionadded:: 1.5
1145+
1146+
Returns
1147+
-------
1148+
routing : MetadataRouter
1149+
A :class:`~sklearn.utils.metadata_routing.MetadataRouter` encapsulating
1150+
routing information.
1151+
"""
1152+
router = MetadataRouter(owner=self.__class__.__name__).add(
1153+
splitter=check_cv(self.cv),
1154+
method_mapping=MethodMapping().add(callee="split", caller="fit"),
1155+
)
1156+
return router

sklearn/covariance/tests/test_graphical_lasso.py

Lines changed: 61 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
graphical_lasso,
1717
)
1818
from sklearn.datasets import make_sparse_spd_matrix
19+
from sklearn.model_selection import GroupKFold
1920
from sklearn.utils import check_random_state
2021
from sklearn.utils._testing import (
2122
_convert_container,
@@ -254,12 +255,71 @@ def test_graphical_lasso_cv_scores():
254255
X
255256
)
256257

258+
_assert_graphical_lasso_cv_scores(
259+
cov=cov,
260+
n_splits=splits,
261+
n_refinements=n_refinements,
262+
n_alphas=n_alphas,
263+
)
264+
265+
266+
# TODO(1.5): remove in 1.5
267+
def test_graphical_lasso_cov_init_deprecation():
268+
"""Check that we raise a deprecation warning if providing `cov_init` in
269+
`graphical_lasso`."""
270+
rng, dim, n_samples = np.random.RandomState(0), 20, 100
271+
prec = make_sparse_spd_matrix(dim, alpha=0.95, random_state=0)
272+
cov = linalg.inv(prec)
273+
X = rng.multivariate_normal(np.zeros(dim), cov, size=n_samples)
274+
275+
emp_cov = empirical_covariance(X)
276+
with pytest.warns(FutureWarning, match="cov_init parameter is deprecated"):
277+
graphical_lasso(emp_cov, alpha=0.1, cov_init=emp_cov)
278+
279+
280+
@pytest.mark.usefixtures("enable_slep006")
281+
def test_graphical_lasso_cv_scores_with_routing(global_random_seed):
282+
"""Check that `GraphicalLassoCV` internally dispatches metadata to
283+
the splitter.
284+
"""
285+
splits = 5
286+
n_alphas = 5
287+
n_refinements = 3
288+
true_cov = np.array(
289+
[
290+
[0.8, 0.0, 0.2, 0.0],
291+
[0.0, 0.4, 0.0, 0.0],
292+
[0.2, 0.0, 0.3, 0.1],
293+
[0.0, 0.0, 0.1, 0.7],
294+
]
295+
)
296+
rng = np.random.RandomState(global_random_seed)
297+
X = rng.multivariate_normal(mean=[0, 0, 0, 0], cov=true_cov, size=300)
298+
n_samples = X.shape[0]
299+
groups = rng.randint(0, 5, n_samples)
300+
params = {"groups": groups}
301+
cv = GroupKFold(n_splits=splits)
302+
cv.set_split_request(groups=True)
303+
304+
cov = GraphicalLassoCV(cv=cv, alphas=n_alphas, n_refinements=n_refinements).fit(
305+
X, **params
306+
)
307+
308+
_assert_graphical_lasso_cv_scores(
309+
cov=cov,
310+
n_splits=splits,
311+
n_refinements=n_refinements,
312+
n_alphas=n_alphas,
313+
)
314+
315+
316+
def _assert_graphical_lasso_cv_scores(cov, n_splits, n_refinements, n_alphas):
257317
cv_results = cov.cv_results_
258318
# alpha and one for each split
259319

260320
total_alphas = n_refinements * n_alphas + 1
261321
keys = ["alphas"]
262-
split_keys = [f"split{i}_test_score" for i in range(splits)]
322+
split_keys = [f"split{i}_test_score" for i in range(n_splits)]
263323
for key in keys + split_keys:
264324
assert key in cv_results
265325
assert len(cv_results[key]) == total_alphas
@@ -270,17 +330,3 @@ def test_graphical_lasso_cv_scores():
270330

271331
assert_allclose(cov.cv_results_["mean_test_score"], expected_mean)
272332
assert_allclose(cov.cv_results_["std_test_score"], expected_std)
273-
274-
275-
# TODO(1.5): remove in 1.5
276-
def test_graphical_lasso_cov_init_deprecation():
277-
"""Check that we raise a deprecation warning if providing `cov_init` in
278-
`graphical_lasso`."""
279-
rng, dim, n_samples = np.random.RandomState(0), 20, 100
280-
prec = make_sparse_spd_matrix(dim, alpha=0.95, random_state=0)
281-
cov = linalg.inv(prec)
282-
X = rng.multivariate_normal(np.zeros(dim), cov, size=n_samples)
283-
284-
emp_cov = empirical_covariance(X)
285-
with pytest.warns(FutureWarning, match="cov_init parameter is deprecated"):
286-
graphical_lasso(emp_cov, alpha=0.1, cov_init=emp_cov)

sklearn/tests/test_metaestimators_metadata_routing.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -356,6 +356,13 @@ def enable_slep006():
356356
"cv_name": "cv",
357357
"cv_routing_methods": ["fit"],
358358
},
359+
{
360+
"metaestimator": GraphicalLassoCV,
361+
"X": X,
362+
"y": y,
363+
"cv_name": "cv",
364+
"cv_routing_methods": ["fit"],
365+
},
359366
]
360367
"""List containing all metaestimators to be tested and their settings
361368
@@ -397,7 +404,6 @@ def enable_slep006():
397404
UNSUPPORTED_ESTIMATORS = [
398405
AdaBoostClassifier(),
399406
AdaBoostRegressor(),
400-
GraphicalLassoCV(),
401407
RFE(ConsumingClassifier()),
402408
RFECV(ConsumingClassifier()),
403409
SelfTrainingClassifier(ConsumingClassifier()),

0 commit comments

Comments
 (0)

Footer

© 2025 GitHub, Inc.
0