MNT: Make error message clearer for n_neighbors (#23317) · scikit-learn/scikit-learn@5c4288d · GitHub
[go: up one dir, main page]

Skip to content

Commit 5c4288d

Browse files
bharatr21, thomasjpfan, betatim, TomDLT, glemaitre
authored
MNT: Make error message clearer for n_neighbors (#23317)
Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com> Co-authored-by: Tim Head <betatim@gmail.com> Co-authored-by: Tom Dupré la Tour <tom.duprelatour.10@gmail.com> Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com>
1 parent 9e8f10e commit 5c4288d

File tree

3 files changed

+56
-2
lines changed

3 files changed

+56
-2
lines changed

doc/whats_new/v1.4.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -474,6 +474,10 @@ Changelog
474474
when `radius` is large and `algorithm="brute"` with non-Euclidean metrics.
475475
:pr:`26828` by :user:`Omar Salman <OmarManzoor>`.
476476

477+
- |Fix| Improve error message for :class:`neighbors.LocalOutlierFactor`
478+
when it is invoked with `n_samples = n_neighbors`.
479+
:pr:`23317` by :user:`Bharat Raghunathan <Bharat123rox>`.
480+
477481
:mod:`sklearn.preprocessing`
478482
............................
479483

sklearn/neighbors/_base.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -813,9 +813,15 @@ class from an array representing our data set and ask who's
813813

814814
n_samples_fit = self.n_samples_fit_
815815
if n_neighbors > n_samples_fit:
816+
if query_is_train:
817+
n_neighbors -= 1 # ok to modify inplace because an error is raised
818+
inequality_str = "n_neighbors < n_samples_fit"
819+
else:
820+
inequality_str = "n_neighbors <= n_samples_fit"
816821
raise ValueError(
817-
"Expected n_neighbors <= n_samples, "
818-
" but n_samples = %d, n_neighbors = %d" % (n_samples_fit, n_neighbors)
822+
f"Expected {inequality_str}, but "
823+
f"n_neighbors = {n_neighbors}, n_samples_fit = {n_samples_fit}, "
824+
f"n_samples = {X.shape[0]}" # include n_samples for common tests
819825
)
820826

821827
n_jobs = effective_n_jobs(self.n_jobs)

sklearn/neighbors/tests/test_lof.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,50 @@ def test_sparse(csr_container):
255255
lof.fit_predict(X)
256256

257257

258+
def test_lof_error_n_neighbors_too_large():
259+
"""Check that we raise a proper error message when n_neighbors == n_samples.
260+
261+
Non-regression test for:
262+
https://github.com/scikit-learn/scikit-learn/issues/17207
263+
"""
264+
X = np.ones((7, 7))
265+
266+
msg = (
267+
"Expected n_neighbors < n_samples_fit, but n_neighbors = 1, "
268+
"n_samples_fit = 1, n_samples = 1"
269+
)
270+
with pytest.raises(ValueError, match=msg):
271+
lof = neighbors.LocalOutlierFactor(n_neighbors=1).fit(X[:1])
272+
273+
lof = neighbors.LocalOutlierFactor(n_neighbors=2).fit(X[:2])
274+
assert lof.n_samples_fit_ == 2
275+
276+
msg = (
277+
"Expected n_neighbors < n_samples_fit, but n_neighbors = 2, "
278+
"n_samples_fit = 2, n_samples = 2"
279+
)
280+
with pytest.raises(ValueError, match=msg):
281+
lof.kneighbors(None, n_neighbors=2)
282+
283+
distances, indices = lof.kneighbors(None, n_neighbors=1)
284+
assert distances.shape == (2, 1)
285+
assert indices.shape == (2, 1)
286+
287+
msg = (
288+
"Expected n_neighbors <= n_samples_fit, but n_neighbors = 3, "
289+
"n_samples_fit = 2, n_samples = 7"
290+
)
291+
with pytest.raises(ValueError, match=msg):
292+
lof.kneighbors(X, n_neighbors=3)
293+
294+
(
295+
distances,
296+
indices,
297+
) = lof.kneighbors(X, n_neighbors=2)
298+
assert distances.shape == (7, 2)
299+
assert indices.shape == (7, 2)
300+
301+
258302
@pytest.mark.parametrize("algorithm", ["auto", "ball_tree", "kd_tree", "brute"])
259303
@pytest.mark.parametrize("novelty", [True, False])
260304
@pytest.mark.parametrize("contamination", [0.5, "auto"])

0 commit comments

Comments
 (0)
0