FIX adjust inliner criteria in RANSACRegressor (#19499) · rth/scikit-learn@0b45ac5 · GitHub

Commit 0b45ac5

Authored by gregorystrubel (Gregory Strubel), jjerphan, and glemaitre
FIX adjust inliner criteria in RANSACRegressor (scikit-learn#19499)
Co-authored-by: Gregory Strubel <greg@Air-de-Ali.local>
Co-authored-by: Julien Jerphanion <git@jjerphan.xyz>
Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com>
1 parent 6f7ae91 commit 0b45ac5

4 files changed: +29 -7 lines changed

doc/modules/linear_model.rst

Lines changed: 2 additions & 2 deletions
@@ -1265,8 +1265,8 @@ Each iteration performs the following steps:
    whether the estimated model is valid (see ``is_model_valid``).
 3. Classify all data as inliers or outliers by calculating the residuals
    to the estimated model (``base_estimator.predict(X) - y``) - all data
-   samples with absolute residuals smaller than the ``residual_threshold``
-   are considered as inliers.
+   samples with absolute residuals smaller than or equal to the
+   ``residual_threshold`` are considered as inliers.
 4. Save fitted model as best model if number of inlier samples is
    maximal. In case the current estimated model has the same number of
    inliers, it is only considered as the best model if it has better score.
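The effect of the changed comparison in step 3 can be sketched in plain Python. The helper `classify_inliers` and the sample residuals are hypothetical and only mirror the comparison described in the documentation hunk; this is not scikit-learn code.

```python
def classify_inliers(residuals, residual_threshold, inclusive=True):
    """Return a list of booleans marking the samples treated as inliers.

    Hypothetical helper: inclusive=True mirrors the new `<=` rule,
    inclusive=False mirrors the old strict `<` rule.
    """
    if inclusive:
        return [abs(r) <= residual_threshold for r in residuals]
    return [abs(r) < residual_threshold for r in residuals]


# Illustrative residuals; the second one lies exactly on the threshold.
residuals = [0.25, 0.5, -0.75]
old_mask = classify_inliers(residuals, 0.5, inclusive=False)
new_mask = classify_inliers(residuals, 0.5, inclusive=True)
print(old_mask)  # [True, False, False]
print(new_mask)  # [True, True, False]
```

Only the sample sitting exactly on the threshold changes class; samples strictly inside or outside are unaffected.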

doc/whats_new/v1.0.rst

Lines changed: 5 additions & 0 deletions
@@ -456,6 +456,11 @@ Changelog
   :pr:`19426` by :user:`Alexandre Gramfort <agramfort>` and
   :user:`Maria Telenczuk <maikia>`.

+- |Fix| Points with residuals equal to ``residual_threshold`` are now considered
+  as inliers for :class:`linear_model.RANSACRegressor`. This allows fitting
+  a model perfectly on some datasets when `residual_threshold=0`.
+  :pr:`19499` by :user:`Gregory Strubel <gregorystrubel>`.
+
 - |Efficiency| The implementation of `fit` for `PolynomialFeatures` transformer
   is now faster. This is especially noticeable on large sparse input.
   :pr:`19734` by :user:`Fred Robinson <frrad>`.

sklearn/linear_model/_ransac.py

Lines changed: 3 additions & 2 deletions
@@ -95,7 +95,8 @@ class RANSACRegressor(
     residual_threshold : float, default=None
         Maximum residual for a data sample to be classified as an inlier.
         By default the threshold is chosen as the MAD (median absolute
-        deviation) of the target values `y`.
+        deviation) of the target values `y`. Points whose residuals are
+        strictly equal to the threshold are considered as inliers.

     is_data_valid : callable, default=None
         This function is called with the randomly selected data before the
@@ -434,7 +435,7 @@ def fit(self, X, y, sample_weight=None):
             residuals_subset = loss_function(y, y_pred)

             # classify data into inliers and outliers
-            inlier_mask_subset = residuals_subset < residual_threshold
+            inlier_mask_subset = residuals_subset <= residual_threshold
             n_inliers_subset = np.sum(inlier_mask_subset)

             # less inliers -> skip current random sample
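The docstring above says the default threshold is the MAD of `y`. A minimal sketch of why that default interacts badly with a strict comparison on constant targets follows. It assumes the plain definition of MAD (median of absolute deviations from the median) and a model that fits perfectly, so all residuals are zero; `median_absolute_deviation` is a hypothetical helper, not a scikit-learn function.

```python
from statistics import median


def median_absolute_deviation(values):
    """MAD: median of the absolute deviations from the median."""
    center = median(values)
    return median([abs(v - center) for v in values])


# Constant targets, as in a perfectly horizontal line.
y = [0.0] * 100
threshold = median_absolute_deviation(y)

# Assumed residuals of a model that fits the data exactly.
residuals = [0.0] * 100

n_inliers_strict = sum(r < threshold for r in residuals)      # old rule
n_inliers_inclusive = sum(r <= threshold for r in residuals)  # new rule
print(threshold, n_inliers_strict, n_inliers_inclusive)  # 0.0 0 100
```

With a zero threshold the strict rule rejects every sample, while the inclusive rule accepts all of them.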

sklearn/linear_model/tests/test_ransac.py

Lines changed: 19 additions & 3 deletions
@@ -168,14 +168,14 @@ def test_ransac_predict():
     assert_array_equal(ransac_estimator.predict(X), np.zeros(100))


-def test_ransac_resid_thresh_no_inliers():
-    # When residual_threshold=0.0 there are no inliers and a
+def test_ransac_residuals_threshold_no_inliers():
+    # When residual_threshold=nan there are no inliers and a
     # ValueError with a message should be raised
     base_estimator = LinearRegression()
     ransac_estimator = RANSACRegressor(
         base_estimator,
         min_samples=2,
-        residual_threshold=0.0,
+        residual_threshold=float("nan"),
         random_state=0,
         max_trials=5,
     )
@@ -597,6 +597,22 @@ def test_ransac_final_model_fit_sample_weight():
     assert_allclose(ransac.estimator_.coef_, final_model.coef_, atol=1e-12)


+def test_perfect_horizontal_line():
+    """Check that we can fit a line where all samples are inliers.
+    Non-regression test for:
+    https://github.com/scikit-learn/scikit-learn/issues/19497
+    """
+    X = np.arange(100)[:, None]
+    y = np.zeros((100,))
+
+    base_estimator = LinearRegression()
+    ransac_estimator = RANSACRegressor(base_estimator, random_state=0)
+    ransac_estimator.fit(X, y)
+
+    assert_allclose(ransac_estimator.estimator_.coef_, 0.0)
+    assert_allclose(ransac_estimator.estimator_.intercept_, 0.0)
+
+
 # TODO: Remove in v1.2
 @pytest.mark.parametrize(
     "old_loss, new_loss",
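The "no inliers" test switches its threshold from `0.0` to NaN, presumably because a zero threshold can now yield inliers under `<=`. The sketch below uses made-up residual values to show the IEEE 754 behaviour this relies on: any ordered comparison against NaN is false, so no residual can pass the threshold.

```python
nan = float("nan")

# Illustrative residuals, including an exact zero.
residuals = [0.0, 1.0, 1e9]

# Under the new inclusive rule, a zero threshold accepts the exact fit ...
mask_zero_threshold = [r <= 0.0 for r in residuals]

# ... while a NaN threshold accepts nothing, whatever the residual is.
mask_nan_threshold = [r <= nan for r in residuals]

print(mask_zero_threshold)  # [True, False, False]
print(mask_nan_threshold)   # [False, False, False]
```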
