8000 Merge pull request #4352 from amueller/issue-4297-infinite-isotonic_bak · pletelli/scikit-learn@4cc0235 · GitHub
[go: up one dir, main page]

Skip to content

Commit 4cc0235

Browse files
committed
Merge pull request scikit-learn#4352 from amueller/issue-4297-infinite-isotonic_bak
[MRG + 2] Adding fix for issue scikit-learn#4297, isotonic infinite loop
2 parents 555b859 + 16075fb commit 4cc0235

File tree

2 files changed

+31
-4
lines changed

2 files changed

+31
-4
lines changed

sklearn/isotonic.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -252,8 +252,6 @@ def _build_y(self, X, y, sample_weight):
252252
"""Build the y_ IsotonicRegression."""
253253
check_consistent_length(X, y, sample_weight)
254254
X, y = [check_array(x, ensure_2d=False) for x in [X, y]]
255-
if sample_weight is not None:
256-
sample_weight = check_array(sample_weight, ensure_2d=False)
257255

258256
y = as_float_array(y)
259257
self._check_fit_data(X, y, sample_weight)
@@ -264,10 +262,16 @@ def _build_y(self, X, y, sample_weight):
264262
else:
265263
self.increasing_ = self.increasing
266264

265+
# If sample_weights is passed, removed zero-weight values and clean order
266+
if sample_weight is not None:
267+
sample_weight = check_array(sample_weight, ensure_2d=False)
268+
mask = sample_weight > 0
269+
X, y, sample_weight = X[mask], y[mask], sample_weight[mask]
270+
else:
271+
sample_weight = np.ones(len(y))
272+
267273
order = np.lexsort((y, X))
268274
order_inv = np.argsort(order)
269-
if sample_weight is None:
270-
sample_weight = np.ones(len(y))
271275
X, y, sample_weight = [astype(array[order], np.float64, copy=False)
272276
for array in [X, y, sample_weight]]
273277
unique_X, unique_y, unique_sample_weight = _make_unique(X, y, sample_weight)

sklearn/tests/test_isotonic.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,29 @@ def test_isotonic_duplicate_min_entry():
325325
all_predictions_finite = np.all(np.isfinite(ir.predict(x)))
326326
assert_true(all_predictions_finite)
327327

328+
329+
def test_isotonic_zero_weight_loop():
330+
# Test from @ogrisel's issue:
331+
# https://github.com/scikit-learn/scikit-learn/issues/4297
332+
333+
# Get deterministic RNG with seed
334+
rng = np.random.RandomState(42)
335+
336+
# Create regression and samples
337+
regression = IsotonicRegression()
338+
n_samples = 50
339+
x = np.linspace(-3, 3, n_samples)
340+
y = x + rng.uniform(size=n_samples)
341+
342+
# Get some random weights and zero out
343+
w = rng.uniform(size=n_samples)
344+
w[5:8] = 0
345+
regression.fit(x, y, sample_weight=w)
346+
347+
# This will hang in failure case.
348+
regression.fit(x, y, sample_weight=w)
349+
350+
328351
if __name__ == "__main__":
329352
import nose
330353
nose.run(argv=['', __file__])

0 commit comments

Comments
 (0)
0