[MRG+2] Avoid failure in first iteration of RANSAC regression (#7914) · Pthinker/scikit-learn@d0ce0d9 · GitHub
[go: up one dir, main page]

Skip to content

Commit d0ce0d9

Browse files
mthorrell authored and jnothman committed
[MRG+2] Avoid failure in first iteration of RANSAC regression (scikit-learn#7914)
Fixes scikit-learn#7908. Adds RANSACRegressor attributes n_skips_* for diagnostics.
1 parent e874398 commit d0ce0d9

File tree

3 files changed

+152
-18
lines changed

3 files changed

+152
-18
lines changed

doc/whats_new.rst

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,17 @@ Enhancements
100100
norm 'max' the norms returned will be the same as for dense matrices.
101101
:issue:`7771` by `Ang Lu <https://github.com/luang008>`_.
102102

103+
- :class:`sklearn.linear_model.RANSACRegressor` no longer throws an error
104+
when calling ``fit`` if no inliers are found in its first iteration.
105+
Furthermore, causes of skipped iterations are tracked in newly added
106+
attributes, ``n_skips_*``.
107+
:issue:`7914` by :user:`Michael Horrell <mthorrell>`.
108+
109+
- Fix a bug where :class:`sklearn.feature_selection.SelectFdr` did not
110+
exactly implement Benjamini-Hochberg procedure. It formerly may have
111+
selected fewer features than it should.
112+
:issue:`7490` by :user:`Peng Meng <mpjlu>`.
113+
103114
- Added ability to set ``n_jobs`` parameter to :func:`pipeline.make_union`.
104115
A ``TypeError`` will be raised for any other kwargs. :issue:`8028`
105116
by :user:`Alexander Booth <alexandercbooth>`.

sklearn/linear_model/ransac.py

Lines changed: 60 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,13 @@ class RANSACRegressor(BaseEstimator, MetaEstimatorMixin, RegressorMixin):
111111
max_trials : int, optional
112112
Maximum number of iterations for random sample selection.
113113
114+
max_skips : int, optional
115+
Maximum number of iterations that can be skipped due to finding zero
116+
inliers or invalid data defined by ``is_data_valid`` or invalid models
117+
defined by ``is_model_valid``.
118+
119+
.. versionadded:: 0.19
120+
114121
stop_n_inliers : int, optional
115122
Stop iteration if at least this number of inliers are found.
116123
@@ -168,6 +175,23 @@ class RANSACRegressor(BaseEstimator, MetaEstimatorMixin, RegressorMixin):
168175
inlier_mask_ : bool array of shape [n_samples]
169176
Boolean mask of inliers classified as ``True``.
170177
178+
n_skips_no_inliers_ : int
179+
Number of iterations skipped due to finding zero inliers.
180+
181+
.. versionadded:: 0.19
182+
183+
n_skips_invalid_data_ : int
184+
Number of iterations skipped due to invalid data defined by
185+
``is_data_valid``.
186+
187+
.. versionadded:: 0.19
188+
189+
n_skips_invalid_model_ : int
190+
Number of iterations skipped due to an invalid model defined by
191+
``is_model_valid``.
192+
193+
.. versionadded:: 0.19
194+
171195
References
172196
----------
173197
.. [1] https://en.wikipedia.org/wiki/RANSAC
@@ -177,7 +201,7 @@ class RANSACRegressor(BaseEstimator, MetaEstimatorMixin, RegressorMixin):
177201

178202
def __init__(self, base_estimator=None, min_samples=None,
179203
residual_threshold=None, is_data_valid=None,
180-
is_model_valid=None, max_trials=100,
204+
is_model_valid=None, max_trials=100, max_skips=np.inf,
181205
stop_n_inliers=np.inf, stop_score=np.inf,
182206
stop_probability=0.99, residual_metric=None,
183207
loss='absolute_loss', random_state=None):
@@ -188,6 +212,7 @@ def __init__(self, base_estimator=None, min_samples=None,
188212
self.is_data_valid = is_data_valid
189213
self.is_model_valid = is_model_valid
190214
self.max_trials = max_trials
215+
self.max_skips = max_skips
191216
self.stop_n_inliers = stop_n_inliers
192217
self.stop_score = stop_score
193218
self.stop_probability = stop_probability
@@ -301,11 +326,14 @@ def fit(self, X, y, sample_weight=None):
301326
if sample_weight is not None:
302327
sample_weight = np.asarray(sample_weight)
303328

304-
n_inliers_best = 0
305-
score_best = np.inf
329+
n_inliers_best = 1
330+
score_best = -np.inf
306331
inlier_mask_best = None
307332
X_inlier_best = None
308333
y_inlier_best = None
334+
self.n_skips_no_inliers_ = 0
335+
self.n_skips_invalid_data_ = 0
336+
self.n_skips_invalid_model_ = 0
309337

310338
# number of data samples
311339
n_samples = X.shape[0]
@@ -315,6 +343,10 @@ def fit(self, X, y, sample_weight=None):
315343

316344
for self.n_trials_ in range(1, self.max_trials + 1):
317345

346+
if (self.n_skips_no_inliers_ + self.n_skips_invalid_data_ +
347+
self.n_skips_invalid_model_) > self.max_skips:
348+
break
349+
318350
# choose random sample set
319351
subset_idxs = sample_without_replacement(n_samples, min_samples,
320352
random_state=random_state)
@@ -324,6 +356,7 @@ def fit(self, X, y, sample_weight=None):
324356
# check if random sample set is valid
325357
if (self.is_data_valid is not None
326358
and not self.is_data_valid(X_subset, y_subset)):
359+
self.n_skips_invalid_data_ += 1
327360
continue
328361

329362
# fit model for current random sample set
@@ -336,6 +369,7 @@ def fit(self, X, y, sample_weight=None):
336369
# check if estimated model is valid
337370
if (self.is_model_valid is not None and not
338371
self.is_model_valid(base_estimator, X_subset, y_subset)):
372+
self.n_skips_invalid_model_ += 1
339373
continue
340374

341375
# residuals of all data for current random sample model
@@ -356,11 +390,8 @@ def fit(self, X, y, sample_weight=None):
356390

357391
# less inliers -> skip current random sample
358392
if n_inliers_subset < n_inliers_best:
393+
self.n_skips_no_inliers_ += 1
359394
continue
360-
if n_inliers_subset == 0:
361-
raise ValueError("No inliers found, possible cause is "
362-
"setting residual_threshold ({0}) too low.".format(
363-
self.residual_threshold))
364395

365396
# extract inlier data set
366397
inlier_idxs_subset = sample_idxs[inlier_mask_subset]
@@ -395,12 +426,28 @@ def fit(self, X, y, sample_weight=None):
395426

396427
# if none of the iterations met the required criteria
397428
if inlier_mask_best is None:
398-
raise ValueError(
399-
"RANSAC could not find valid consensus set, because"
400-
" either the `residual_threshold` rejected all the samples or"
401-
" `is_data_valid` and `is_model_valid` returned False for all"
402-
" `max_trials` randomly ""chosen sub-samples. Consider "
403-
"relaxing the ""constraints.")
429+
if ((self.n_skips_no_inliers_ + self.n_skips_invalid_data_ +
430+
self.n_skips_invalid_model_) > self.max_skips):
431+
raise ValueError(
432+
"RANSAC skipped more iterations than `max_skips` without"
433+
" finding a valid consensus set. Iterations were skipped"
434+
" because each randomly chosen sub-sample failed the"
435+
" passing criteria. See estimator attributes for"
436+
" diagnostics (n_skips*).")
437+
else:
438+
raise ValueError(
439+
"RANSAC could not find a valid consensus set. All"
440+
" `max_trials` iterations were skipped because each"
441+
" randomly chosen sub-sample failed the passing criteria."
442+
" See estimator attributes for diagnostics (n_skips*).")
443+
else:
444+
if (self.n_skips_no_inliers_ + self.n_skips_invalid_data_ +
445+
self.n_skips_invalid_model_) > self.max_skips:
446+
warnings.warn("RANSAC found a valid consensus set but exited"
447+
" early due to skipping more iterations than"
448+
" `max_skips`. See estimator attributes for"
449+
" diagnostics (n_skips*).",
450+
UserWarning)
404451

405452
# estimate final model using all inliers
406453
base_estimator.fit(X_inlier_best, y_inlier_best)

sklearn/linear_model/tests/test_ransac.py

Lines changed: 81 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,10 @@
88
from numpy.testing import assert_array_equal
99

1010
from sklearn.utils import check_random_state
11-
from sklearn.utils.testing import assert_raises_regexp
1211
from sklearn.utils.testing import assert_less
1312
from sklearn.utils.testing import assert_warns
1413
from sklearn.utils.testing import assert_almost_equal
14+
from sklearn.utils.testing import assert_raises_regexp
1515
from sklearn.linear_model import LinearRegression, RANSACRegressor, Lasso
1616
from sklearn.linear_model.ransac import _dynamic_max_trials
1717

@@ -152,11 +152,87 @@ def test_ransac_resid_thresh_no_inliers():
152152
# ValueError with a message should be raised
153153
base_estimator = LinearRegression()
154154
ransac_estimator = RANSACRegressor(base_estimator, min_samples=2,
155-
residual_threshold=0.0, random_state=0)
155+
residual_threshold=0.0, random_state=0,
156+
max_trials=5)
157+
158+
msg = ("RANSAC could not find a valid consensus set")
159+
assert_raises_regexp(ValueError, msg, ransac_estimator.fit, X, y)
160+
assert_equal(ransac_estimator.n_skips_no_inliers_, 5)
161+
assert_equal(ransac_estimator.n_skips_invalid_data_, 0)
162+
assert_equal(ransac_estimator.n_skips_invalid_model_, 0)
163+
164+
165+
def test_ransac_no_valid_data():
166+
def is_data_valid(X, y):
167+
return False
168+
169+
base_estimator = LinearRegression()
170+
ransac_estimator = RANSACRegressor(base_estimator,
171+
is_data_valid=is_data_valid,
172+
max_trials=5)
173+
174+
msg = ("RANSAC could not find a valid consensus set")
175+
assert_raises_regexp(ValueError, msg, ransac_estimator.fit, X, y)
176+
assert_equal(ransac_estimator.n_skips_no_inliers_, 0)
177+
assert_equal(ransac_estimator.n_skips_invalid_data_, 5)
178+
assert_equal(ransac_estimator.n_skips_invalid_model_, 0)
179+
180+
181+
def test_ransac_no_valid_model():
182+
def is_model_valid(estimator, X, y):
183+
return False
184+
185+
base_estimator = LinearRegression()
186+
ransac_estimator = RANSACRegressor(base_estimator,
187+
is_model_valid=is_model_valid,
188+
max_trials=5)
189+
190+
msg = ("RANSAC could not find a valid consensus set")
191+
assert_raises_regexp(ValueError, msg, ransac_estimator.fit, X, y)
192+
assert_equal(ransac_estimator.n_skips_no_inliers_, 0)
193+
assert_equal(ransac_estimator.n_skips_invalid_data_, 0)
194+
assert_equal(ransac_estimator.n_skips_invalid_model_, 5)
195+
196+
197+
def test_ransac_exceed_max_skips():
198+
def is_data_valid(X, y):
199+
return False
200+
201+
base_estimator = LinearRegression()
202+
ransac_estimator = RANSACRegressor(base_estimator,
203+
is_data_valid=is_data_valid,
204+
max_trials=5,
205+
max_skips=3)
206+
207+
msg = ("RANSAC skipped more iterations than `max_skips`")
208+
assert_raises_regexp(ValueError, msg, ransac_estimator.fit, X, y)
209+
assert_equal(ransac_estimator.n_skips_no_inliers_, 0)
210+
assert_equal(ransac_estimator.n_skips_invalid_data_, 4)
211+
assert_equal(ransac_estimator.n_skips_invalid_model_, 0)
212+
213+
214+
def test_ransac_warn_exceed_max_skips():
215+
global cause_skip
216+
cause_skip = False
217+
218+
def is_data_valid(X, y):
219+
global cause_skip
220+
if not cause_skip:
221+
cause_skip = True
222+
return True
223+
else:
224+
return False
225+
226+
base_estimator = LinearRegression()
227+
ransac_estimator = RANSACRegressor(base_estimator,
228+
is_data_valid=is_data_valid,
229+
max_skips=3,
230+
max_trials=5)
156231

157-
assert_raises_regexp(ValueError,
158-
"No inliers.*residual_threshold.*0\.0",
159-
ransac_estimator.fit, X, y)
232+
assert_warns(UserWarning, ransac_estimator.fit, X, y)
233+
assert_equal(ransac_estimator.n_skips_no_inliers_, 0)
234+
assert_equal(ransac_estimator.n_skips_invalid_data_, 4)
235+
assert_equal(ransac_estimator.n_skips_invalid_model_, 0)
160236

161237

162238
def test_ransac_sparse_coo():

0 commit comments

Comments
 (0)
0