Skip to content

Commit 6e8ebda

Browse files
aivision2020 authored and jnothman committed
[MRG+1] FIX bug where ransac is running too many iterations (scikit-learn#8271)
1 parent 00da9cc commit 6e8ebda

File tree

3 files changed

+28
-18
lines changed

3 files changed

+28
-18
lines changed

doc/whats_new.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,9 @@ Bug fixes
195195
sparse input.
196196
:issue:`8259` by :user:`Aman Dalmia <dalmia>`.
197197

198+
- Fixed a bug where :func:`sklearn.linear_model.RANSACRegressor.fit` may run until
199+
``max_trials`` if it finds a large inlier group early. :issue:`8251` by :user:`aivision2020`.
200+
198201
- Fixed a bug where :func:`sklearn.datasets.make_moons` gives an
199202
incorrect result when ``n_samples`` is odd.
200203
:issue:`8198` by :user:`Josh Levy <levy5674>`.

sklearn/linear_model/ransac.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -342,7 +342,10 @@ def fit(self, X, y, sample_weight=None):
342342

343343
n_samples, _ = X.shape
344344

345-
for self.n_trials_ in range(1, self.max_trials + 1):
345+
self.n_trials_ = 0
346+
max_trials = self.max_trials
347+
while self.n_trials_ < max_trials:
348+
self.n_trials_ += 1
346349

347350
if (self.n_skips_no_inliers_ + self.n_skips_invalid_data_ +
348351
self.n_skips_invalid_model_) > self.max_skips:
@@ -416,13 +419,14 @@ def fit(self, X, y, sample_weight=None):
416419
X_inlier_best = X_inlier_subset
417420
y_inlier_best = y_inlier_subset
418421

422+
max_trials = min(
423+
max_trials,
424+
_dynamic_max_trials(n_inliers_best, n_samples,
425+
min_samples, self.stop_probability))
426+
419427
# break if sufficient number of inliers or score is reached
420-
if (n_inliers_best >= self.stop_n_inliers
421-
or score_best >= self.stop_score
422-
or self.n_trials_
423-
>= _dynamic_max_trials(n_inliers_best, n_samples,
424-
min_samples,
425-
self.stop_probability)):
428+
if n_inliers_best >= self.stop_n_inliers or \
429+
score_best >= self.stop_score:
426430
break
427431

428432
# if none of the iterations met the required criteria

sklearn/linear_model/tests/test_ransac.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,9 @@
2222
data = np.column_stack([X, y])
2323

2424
# Add some faulty data
25-
outliers = np.array((10, 30, 200))
26-
data[outliers[0], :] = (1000, 1000)
27-
data[outliers[1], :] = (-1000, -1000)
28-
data[outliers[2], :] = (-100, -50)
25+
rng = np.random.RandomState(1000)
26+
outliers = np.unique(rng.randint(len(X), size=200))
27+
data[outliers, :] += 50 + rng.rand(len(outliers), 2) * 10
2928

3029
X = data[:, 0][:, np.newaxis]
3130
y = data[:, 1]
@@ -90,13 +89,16 @@ def test_ransac_max_trials():
9089
random_state=0)
9190
assert_raises(ValueError, ransac_estimator.fit, X, y)
9291

93-
ransac_estimator = RANSACRegressor(base_estimator, min_samples=2,
94-
residual_threshold=5, max_trials=11,
95-
random_state=0)
96-
assert getattr(ransac_estimator, 'n_trials_', None) is None
97-
ransac_estimator.fit(X, y)
98-
assert_equal(ransac_estimator.n_trials_, 2)
99-
92+
# there is a 1e-9 chance it will take these many trials. No good reason
93+
# 1e-2 isn't enough, can still happen
94+
# 2 is what RANSAC defines as min_samples (X.shape[1] + 1)
95+
max_trials = _dynamic_max_trials(
96+
len(X) - len(outliers), X.shape[0], 2, 1 - 1e-9)
97+
ransac_estimator = RANSACRegressor(base_estimator, min_samples=2)
98+
for i in range(50):
99+
ransac_estimator.set_params(min_samples=2, random_state=i)
100+
ransac_estimator.fit(X, y)
101+
assert_less(ransac_estimator.n_trials_, max_trials + 1)
100102

101103
def test_ransac_stop_n_inliers():
102104
base_estimator = LinearRegression()
@@ -383,6 +385,7 @@ def test_ransac_residual_metric():
383385
assert_array_almost_equal(ransac_estimator0.predict(X),
384386
ransac_estimator2.predict(X))
385387

388+
386389
def test_ransac_residual_loss():
387390
loss_multi1 = lambda y_true, y_pred: np.sum(np.abs(y_true - y_pred), axis=1)
388391
loss_multi2 = lambda y_true, y_pred: np.sum((y_true - y_pred) ** 2, axis=1)

0 commit comments

Comments
 (0)