@@ -111,6 +111,13 @@ class RANSACRegressor(BaseEstimator, MetaEstimatorMixin, RegressorMixin):
111
111
max_trials : int, optional
112
112
Maximum number of iterations for random sample selection.
113
113
114
+ max_skips : int, optional
115
+ Maximum number of iterations that can be skipped due to finding zero
116
+ inliers or invalid data defined by ``is_data_valid`` or invalid models
117
+ defined by ``is_model_valid``.
118
+
119
+ .. versionadded:: 0.19
120
+
114
121
stop_n_inliers : int, optional
115
122
Stop iteration if at least this number of inliers are found.
116
123
@@ -168,6 +175,23 @@ class RANSACRegressor(BaseEstimator, MetaEstimatorMixin, RegressorMixin):
168
175
inlier_mask_ : bool array of shape [n_samples]
169
176
Boolean mask of inliers classified as ``True``.
170
177
178
+ n_skips_no_inliers_ : int
179
+ Number of iterations skipped due to finding zero inliers.
180
+
181
+ .. versionadded:: 0.19
182
+
183
+ n_skips_invalid_data_ : int
184
+ Number of iterations skipped due to invalid data defined by
185
+ ``is_data_valid``.
186
+
187
+ .. versionadded:: 0.19
188
+
189
+ n_skips_invalid_model_ : int
190
+ Number of iterations skipped due to an invalid model defined by
191
+ ``is_model_valid``.
192
+
193
+ .. versionadded:: 0.19
194
+
171
195
References
172
196
----------
173
197
.. [1] https://en.wikipedia.org/wiki/RANSAC
@@ -177,7 +201,7 @@ class RANSACRegressor(BaseEstimator, MetaEstimatorMixin, RegressorMixin):
177
201
178
202
def __init__ (self , base_estimator = None , min_samples = None ,
179
203
residual_threshold = None , is_data_valid = None ,
180
- is_model_valid = None , max_trials = 100 ,
204
+ is_model_valid = None , max_trials = 100 , max_skips = np . inf ,
181
205
stop_n_inliers = np .inf , stop_score = np .inf ,
182
206
stop_probability = 0.99 , residual_metric = None ,
183
207
loss = 'absolute_loss' , random_state = None ):
@@ -188,6 +212,7 @@ def __init__(self, base_estimator=None, min_samples=None,
188
212
self .is_data_valid = is_data_valid
189
213
self .is_model_valid = is_model_valid
190
214
self .max_trials = max_trials
215
+ self .max_skips = max_skips
191
216
self .stop_n_inliers = stop_n_inliers
192
217
self .stop_score = stop_score
193
218
self .stop_probability = stop_probability
@@ -301,11 +326,14 @@ def fit(self, X, y, sample_weight=None):
301
326
if sample_weight is not None :
302
327
sample_weight = np .asarray (sample_weight )
303
328
304
- n_inliers_best = 0
305
- score_best = np .inf
329
+ n_inliers_best = 1
330
+ score_best = - np .inf
306
331
inlier_mask_best = None
307
332
X_inlier_best = None
308
333
y_inlier_best = None
334
+ self .n_skips_no_inliers_ = 0
335
+ self .n_skips_invalid_data_ = 0
336
+ self .n_skips_invalid_model_ = 0
309
337
310
338
# number of data samples
311
339
n_samples = X .shape [0 ]
@@ -315,6 +343,10 @@ def fit(self, X, y, sample_weight=None):
315
343
316
344
for self .n_trials_ in range (1 , self .max_trials + 1 ):
317
345
346
+ if (self .n_skips_no_inliers_ + self .n_skips_invalid_data_ +
347
+ self .n_skips_invalid_model_ ) > self .max_skips :
348
+ break
349
+
318
350
# choose random sample set
319
351
subset_idxs = sample_without_replacement (n_samples , min_samples ,
320
352
random_state = random_state )
@@ -324,6 +356,7 @@ def fit(self, X, y, sample_weight=None):
324
356
# check if random sample set is valid
325
357
if (self .is_data_valid is not None
326
358
and not self .is_data_valid (X_subset , y_subset )):
359
+ self .n_skips_invalid_data_ += 1
327
360
continue
328
361
329
362
# fit model for current random sample set
@@ -336,6 +369,7 @@ def fit(self, X, y, sample_weight=None):
336
369
# check if estimated model is valid
337
370
if (self .is_model_valid is not None and not
338
371
self .is_model_valid (base_estimator , X_subset , y_subset )):
372
+ self .n_skips_invalid_model_ += 1
339
373
continue
340
374
341
375
# residuals of all data for current random sample model
@@ -356,11 +390,8 @@ def fit(self, X, y, sample_weight=None):
356
390
357
391
# less inliers -> skip current random sample
358
392
if n_inliers_subset < n_inliers_best :
393
+ self .n_skips_no_inliers_ += 1
359
394
continue
360
- if n_inliers_subset == 0 :
361
- raise ValueError ("No inliers found, possible cause is "
362
- "setting residual_threshold ({0}) too low." .format (
363
- self .residual_threshold ))
364
395
365
396
# extract inlier data set
366
397
inlier_idxs_subset = sample_idxs [inlier_mask_subset ]
@@ -395,12 +426,28 @@ def fit(self, X, y, sample_weight=None):
395
426
396
427
# if none of the iterations met the required criteria
397
428
if inlier_mask_best is None :
398
- raise ValueError (
399
- "RANSAC could not find valid consensus set, because"
400
- " either the `residual_threshold` rejected all the samples or"
401
- " `is_data_valid` and `is_model_valid` returned False for all"
402
- " `max_trials` randomly " "chosen sub-samples. Consider "
403
- "relaxing the " "constraints." )
429
+ if ((self .n_skips_no_inliers_ + self .n_skips_invalid_data_ +
430
+ self .n_skips_invalid_model_ ) > self .max_skips ):
431
+ raise ValueError (
432
+ "RANSAC skipped more iterations than `max_skips` without"
433
+ " finding a valid consensus set. Iterations were skipped"
434
+ " because each randomly chosen sub-sample failed the"
435
+ " passing criteria. See estimator attributes for"
436
+ " diagnostics (n_skips*)." )
437
+ else :
438
+ raise ValueError (
439
+ "RANSAC could not find a valid consensus set. All"
440
+ " `max_trials` iterations were skipped because each"
441
+ " randomly chosen sub-sample failed the passing criteria."
442
+ " See estimator attributes for diagnostics (n_skips*)." )
443
+ else :
444
+ if (self .n_skips_no_inliers_ + self .n_skips_invalid_data_ +
445
+ self .n_skips_invalid_model_ ) > self .max_skips :
446
+ warnings .warn ("RANSAC found a valid consensus set but exited"
447
+ " early due to skipping more iterations than"
448
+ " `max_skips`. See estimator attributes for"
449
+ " diagnostics (n_skips*)." ,
450
+ UserWarning )
404
451
405
452
# estimate final model using all inliers
406
453
base_estimator .fit (X_inlier_best , y_inlier_best )
0 commit comments