ENH Add metadata routing for `RANSACRegressor` (#28261) · scikit-learn/scikit-learn@e2f9530 · GitHub
[go: up one dir, main page]

Skip to content

Commit e2f9530

Browse files
ENH Add metadata routing for RANSACRegressor (#28261)
1 parent bb53bb3 commit e2f9530

9 files changed

+292
-65
lines changed

doc/metadata_routing.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,7 @@ Meta-estimators and functions supporting metadata routing:
287287
- :class:`sklearn.linear_model.LogisticRegressionCV`
288288
- :class:`sklearn.linear_model.MultiTaskElasticNetCV`
289289
- :class:`sklearn.linear_model.MultiTaskLassoCV`
290+
- :class:`sklearn.linear_model.RANSACRegressor`
290291
- :class:`sklearn.model_selection.GridSearchCV`
291292
- :class:`sklearn.model_selection.HalvingGridSearchCV`
292293
- :class:`sklearn.model_selection.HalvingRandomSearchCV`
@@ -315,6 +316,7 @@ Meta-estimators and tools not supporting metadata routing yet:
315316
- :class:`sklearn.feature_selection.RFE`
316317
- :class:`sklearn.feature_selection.RFECV`
317318
- :class:`sklearn.feature_selection.SequentialFeatureSelector`
319+
- :class:`sklearn.impute.IterativeImputer`
318320
- :class:`sklearn.linear_model.RANSACRegressor`
319321
- :class:`sklearn.linear_model.RidgeClassifierCV`
320322
- :class:`sklearn.linear_model.RidgeCV`

doc/modules/linear_model.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1536,10 +1536,10 @@ Each iteration performs the following steps:
15361536

15371537
1. Select ``min_samples`` random samples from the original data and check
15381538
whether the set of data is valid (see ``is_data_valid``).
1539-
2. Fit a model to the random subset (``base_estimator.fit``) and check
1539+
2. Fit a model to the random subset (``estimator.fit``) and check
15401540
whether the estimated model is valid (see ``is_model_valid``).
15411541
3. Classify all data as inliers or outliers by calculating the residuals
1542-
to the estimated model (``base_estimator.predict(X) - y``) - all data
1542+
to the estimated model (``estimator.predict(X) - y``) - all data
15431543
samples with absolute residuals smaller than or equal to the
15441544
``residual_threshold`` are considered as inliers.
15451545
4. Save fitted model as best model if number of inlier samples is

doc/whats_new/v1.5.rst

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,18 @@ more details.
4848
via their `fit` methods.
4949
:pr:`28432` by :user:`Adam Li <adam2392>` and :user:`Benjamin Bossan <BenjaminBossan>`.
5050

51+
Metadata Routing
52+
----------------
53+
54+
The following models now support metadata routing in one or more of their
55+
methods. Refer to the :ref:`Metadata Routing User Guide <metadata_routing>` for
56+
more details.
57+
58+
- |Feature| :class:`linear_model.RANSACRegressor` now supports metadata routing
59+
in its ``fit``, ``score`` and ``predict`` methods and routes metadata to its
60+
underlying estimator's ``fit``, ``score`` and ``predict`` methods.
61+
:pr:`28261` by :user:`Stefanie Senger <StefanieSenger>`.
62+
5163
- |Feature| :class:`ensemble.VotingClassifier` and
5264
:class:`ensemble.VotingRegressor` now support metadata routing and pass
5365
``**fit_params`` to the underlying estimators via their `fit` methods.

sklearn/linear_model/_ransac.py

Lines changed: 130 additions & 27 deletions

Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
)
1818
from ..exceptions import ConvergenceWarning
1919
from ..utils import check_consistent_length, check_random_state
20+
from ..utils._bunch import Bunch
2021
from ..utils._param_validation import (
2122
HasMethods,
2223
Interval,
@@ -25,11 +26,20 @@
2526
StrOptions,
2627
)
2728
from ..utils.metadata_routing import (
28-
_raise_for_unsupported_routing,
29-
_RoutingNotSupportedMixin,
29+
MetadataRouter,
30+
MethodMapping,
31+
_raise_for_params,
32+
_routing_enabled,
33+
process_routing,
3034
)
3135
from ..utils.random import sample_without_replacement
32-
from ..utils.validation import _check_sample_weight, check_is_fitted, has_fit_parameter
36+
from ..utils.validation import (
37+
_check_method_params,
38+
_check_sample_weight,
39+
_deprecate_positional_args,
40+
check_is_fitted,
41+
has_fit_parameter,
42+
)
3343
from ._base import LinearRegression
3444
3545
_EPSILON = np.spacing(1)
@@ -70,7 +80,6 @@ def _dynamic_max_trials(n_inliers, n_samples, min_samples, probability):
7080

7181

7282
class RANSACRegressor(
73-
_RoutingNotSupportedMixin,
7483
MetaEstimatorMixin,
7584
RegressorMixin,
7685
MultiOutputMixin,
@@ -306,7 +315,11 @@ def __init__(
306315
# RansacRegressor.estimator is not validated yet
307316
prefer_skip_nested_validation=False
308317
)
309-
def fit(self, X, y, sample_weight=None):
318+
# TODO(1.7): remove `sample_weight` from the signature after deprecation
319+
# cycle; for backwards compatibility: pop it from `fit_params` before the
320+
# `_raise_for_params` check and reinsert it after the check
321+
@_deprecate_positional_args(version="1.7")
322+
def fit(self, X, y, *, sample_weight=None, **fit_params):
310323
"""Fit estimator using RANSAC algorithm.
311324
312325
Parameters
@@ -324,6 +337,17 @@ def fit(self, X, y, sample_weight=None):
324337
325338
.. versionadded:: 0.18
326339
340+
**fit_params : dict
341+
Parameters routed to the `fit` method of the sub-estimator via the
342+
metadata routing API.
343+
344+
.. versionadded:: 1.5
345+
346+
Only available if
347+
`sklearn.set_config(enable_metadata_routing=True)` is set. See
348+
:ref:`Metadata Routing User Guide <metadata_routing>` for more
349+
details.
350+
327351
Returns
328352
-------
329353
self : object
@@ -336,10 +360,10 @@ def fit(self, X, y, sample_weight=None):
336360
`is_data_valid` and `is_model_valid` return False for all
337361
`max_trials` randomly chosen sub-samples.
338362
"""
339-
_raise_for_unsupported_routing(self, "fit", sample_weight=sample_weight)
340363
# Need to validate separately here. We can't pass multi_output=True
341364
# because that would allow y to be csr. Delay expensive finiteness
342365
# check to the estimator's own input validation.
366+
_raise_for_params(fit_params, self, "fit")
343367
check_X_params = dict(accept_sparse="csr", force_all_finite=False)
344368
check_y_params = dict(ensure_2d=False)
345369
X, y = self._validate_data(
@@ -404,12 +428,22 @@ def fit(self, X, y, sample_weight=None):
404428
estimator_name = type(estimator).__name__
405429
if sample_weight is not None and not estimator_fit_has_sample_weight:
406430
raise ValueError(
407-
"%s does not support sample_weight. Samples"
431+
"%s does not support sample_weight. Sample"
408432
" weights are only used for the calibration"
409433
" itself." % estimator_name
410434
)
435+
411436
if sample_weight is not None:
412-
sample_weight = _check_sample_weight(sample_weight, X)
437+
fit_params["sample_weight"] = sample_weight
438+
439+
if _routing_enabled():
440+
routed_params = process_routing(self, "fit", **fit_params)
441+
else:
442+
routed_params = Bunch()
443+
routed_params.estimator = Bunch(fit={}, predict={}, score={})
444+
if sample_weight is not None:
445+
sample_weight = _check_sample_weight(sample_weight, X)
446+
routed_params.estimator.fit = {"sample_weight": sample_weight}
413447

414448
n_inliers_best = 1
415449
score_best = -np.inf
@@ -451,13 +485,13 @@ def fit(self, X, y, sample_weight=None):
451485
self.n_skips_invalid_data_ += 1
452486
continue
453487

488+
# cut `fit_params` down to `subset_idxs`
489+
fit_params_subset = _check_method_params(
490+
X, params=routed_params.estimator.fit, indices=subset_idxs
491+
)
492+
454493
# fit model for current random sample set
455-
if sample_weight is None:
456-
estimator.fit(X_subset, y_subset)
457-
else:
458-
estimator.fit(
459-
X_subset, y_subset, sample_weight=sample_weight[subset_idxs]
460-
)
494+
estimator.fit(X_subset, y_subset, **fit_params_subset)
461495

462496
# check if estimated model is valid
463497
if self.is_model_valid is not None and not self.is_model_valid(
@@ -484,8 +518,17 @@ def fit(self, X, y, sample_weight=None):
484518
X_inlier_subset = X[inlier_idxs_subset]
485519
y_inlier_subset = y[inlier_idxs_subset]
486520

521+
# cut the routed `score` params down to `inlier_idxs_subset`
522+
score_params_inlier_subset = _check_method_params(
523+
X, params=routed_params.estimator.score, indices=inlier_idxs_subset
524+
)
525+
487526
# score of inlier data set
488-
score_subset = estimator.score(X_inlier_subset, y_inlier_subset)
527+
score_subset = estimator.score(
528+
X_inlier_subset,
529+
y_inlier_subset,
530+
**score_params_inlier_subset,
531+
)
489532

490533
# same number of inliers but worse score -> skip current random
491534
# sample
@@ -549,20 +592,17 @@ def fit(self, X, y, sample_weight=None):
549592
)
550593

551594
# estimate final model using all inliers
552-
if sample_weight is None:
553-
estimator.fit(X_inlier_best, y_inlier_best)
554-
else:
555-
estimator.fit(
556-
X_inlier_best,
557-
y_inlier_best,
558-
sample_weight=sample_weight[inlier_best_idxs_subset],
559-
)
595+
fit_params_best_idxs_subset = _check_method_params(
596+
X, params=routed_params.estimator.fit, indices=inlier_best_idxs_subset
597+
)
598+
599+
estimator.fit(X_inlier_best, y_inlier_best, **fit_params_best_idxs_subset)
560600

561601
self.estimator_ = estimator
562602
self.inlier_mask_ = inlier_mask_best
563603
return self
564604

565-
def predict(self, X):
605+
def predict(self, X, **params):
566606
"""Predict using the estimated model.
567607
568608
This is a wrapper for `estimator_.predict(X)`.
@@ -572,6 +612,17 @@ def predict(self, X):
572612
X : {array-like or sparse matrix} of shape (n_samples, n_features)
573613
Input data.
574614
615+
**params : dict
616+
Parameters routed to the `predict` method of the sub-estimator via
617+
the metadata routing API.
618+
619+
.. versionadded:: 1.5
620+
621+
Only available if
622+
`sklearn.set_config(enable_metadata_routing=True)` is set. See
623+
:ref:`Metadata Routing User Guide <metadata_routing>` for more
624+
details.
625+
575626
Returns
576627
-------
577628
y : array, shape = [n_samples] or [n_samples, n_targets]
@@ -584,9 +635,19 @@ def predict(self, X):
584635
accept_sparse=True,
585636
reset=False,
586637
)
587-
return self.estimator_.predict(X)
588638

589-
def score(self, X, y):
639+
_raise_for_params(params, self, "predict")
640+
641+
if _routing_enabled():
642+
predict_params = process_routing(self, "predict", **params).estimator[
643+
"predict"
644+
]
645+
else:
646+
predict_params = {}
647+
648+
return self.estimator_.predict(X, **predict_params)
649+
650+
def score(self, X, y, **params):
590651
"""Return the score of the prediction.
591652
592653
This is a wrapper for `estimator_.score(X, y)`.
@@ -599,6 +660,17 @@ def score(self, X, y):
599660
y : array-like of shape (n_samples,) or (n_samples, n_targets)
600661
Target values.
601662
663+
**params : dict
664+
Parameters routed to the `score` method of the sub-estimator via
665+
the metadata routing API.
666+
667+
.. versionadded:: 1.5
668+
669+
Only available if
670+
`sklearn.set_config(enable_metadata_routing=True)` is set. See
671+
:ref:`Metadata Routing User Guide <metadata_routing>` for more
672+
details.
673+
602674
Returns
603675
-------
604676
z : float
@@ -611,7 +683,38 @@ def score(self, X, y):
611683
accept_sparse=True,
612684
reset=False,
613685
)
614-
return self.estimator_.score(X, y)
686+
687+
_raise_for_params(params, self, "score")
688+
if _routing_enabled():
689+
score_params = process_routing(self, "score", **params).estimator["score"]
690+
else:
691+
score_params = {}
692+
693+
return self.estimator_.score(X, y, **score_params)
694+
695+
def get_metadata_routing(self):
696+
"""Get metadata routing of this object.
697+
698+
Please check :ref:`User Guide <metadata_routing>` on how the routing
699+
mechanism works.
700+
701+
.. versionadded:: 1.5
702+
703+
Returns
704+
-------
705+
routing : MetadataRouter
706+
A :class:`~sklearn.utils.metadata_routing.MetadataRouter` encapsulating
707+
routing information.
708+
"""
709+
router = MetadataRouter(owner=self.__class__.__name__).add(
710+
estimator=self.estimator,
711+
method_mapping=MethodMapping()
712+
.add(caller="fit", callee="fit")
713+
.add(caller="fit", callee="score")
714+
.add(caller="score", callee="score")
715+
.add(caller="predict", callee="predict"),
716+
)
717+
return router
615718

616719
def _more_tags(self):
617720
return {

sklearn/linear_model/tests/test_ransac.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -461,7 +461,7 @@ def test_ransac_fit_sample_weight():
461461
ransac_estimator = RANSACRegressor(random_state=0)
462462
n_samples = y.shape[0]
463463
weights = np.ones(n_samples)
464-
ransac_estimator.fit(X, y, weights)
464+
ransac_estimator.fit(X, y, sample_weight=weights)
465465
# sanity check
466466
assert ransac_estimator.inlier_mask_.shape[0] == n_samples
467467

@@ -498,7 +498,7 @@ def test_ransac_fit_sample_weight():
498498
sample_weight = np.append(sample_weight, outlier_weight)
499499
X_ = np.append(X_, outlier_X, axis=0)
500500
y_ = np.append(y_, outlier_y)
501-
ransac_estimator.fit(X_, y_, sample_weight)
501+
ransac_estimator.fit(X_, y_, sample_weight=sample_weight)
502502

503503
assert_allclose(ransac_estimator.estimator_.coef_, ref_coef_)
504504

@@ -509,15 +509,15 @@ def test_ransac_fit_sample_weight():
509509

510510
err_msg = f"{estimator.__class__.__name__} does not support sample_weight."
511511
with pytest.raises(ValueError, match=err_msg):
512-
ransac_estimator.fit(X, y, weights)
512+
ransac_estimator.fit(X, y, sample_weight=weights)
513513

514514

515515
def test_ransac_final_model_fit_sample_weight():
516516
X, y = make_regression(n_samples=1000, random_state=10)
517517
rng = check_random_state(42)
518518
sample_weight = rng.randint(1, 4, size=y.shape[0])
519519
sample_weight = sample_weight / sample_weight.sum()
520-
ransac = RANSACRegressor(estimator=LinearRegression(), random_state=0)
520+
ransac = RANSACRegressor(random_state=0)
521521
ransac.fit(X, y, sample_weight=sample_weight)
522522

523523
final_model = LinearRegression()

sklearn/tests/metadata_routing_common.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -162,14 +162,17 @@ def fit(self, X, y, sample_weight="default", metadata="default"):
162162
)
163163
return self
164164

165-
def predict(self, X, sample_weight="default", metadata="default"):
166-
pass # pragma: no cover
165+
def predict(self, X, y=None, sample_weight="default", metadata="default"):
166+
record_metadata_not_default(
167+
self, "predict", sample_weight=sample_weight, metadata=metadata
168+
)
169+
return np.zeros(shape=(len(X),))
167170

168-
# when needed, uncomment the implementation
169-
# record_metadata_not_default(
170-
# self, "predict", sample_weight=sample_weight, metadata=metadata
171-
# )
172-
# return np.zeros(shape=(len(X),))
171+
def score(self, X, y, sample_weight="default", metadata="default"):
172+
record_metadata_not_default(
173+
self, "score", sample_weight=sample_weight, metadata=metadata
174+
)
175+
return 1
173176

174177

175178
class NonConsumingClassifier(ClassifierMixin, BaseEstimator):
@@ -278,6 +281,13 @@ def decision_function(self, X, sample_weight="default", metadata="default"):
278281
)
279282
return np.zeros(shape=(len(X),))
280283

284+
# uncomment when needed
285+
# def score(self, X, y, sample_weight="default", metadata="default"):
286+
# record_metadata_not_default(
287+
# self, "score", sample_weight=sample_weight, metadata=metadata
288+
# )
289+
# return 1
290+
281291

282292
class ConsumingTransformer(TransformerMixin, BaseEstimator):
283293
"""A transformer which accepts metadata on fit and transform.

0 commit comments

Comments
 (0)
0