8000 FIX Fixes KNeighborsRegressor.predict with array-likes (#22687) · scikit-learn/scikit-learn@742d39c · GitHub
[go: up one dir, main page]

Skip to content

Commit 742d39c

Browse files
FIX Fixes KNeighborsRegressor.predict with array-likes (#22687)
Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com>
1 parent 723b707 commit 742d39c

File tree

3 files changed

+26
-2
lines changed

3 files changed

+26
-2
lines changed

doc/whats_new/v1.1.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -664,6 +664,11 @@ Changelog
664664
instead of `__init__`. :pr:`21430` by :user:`Desislava Vasileva <DessyVV>` and
665665
:user:`Lucy Jimenez <LucyJimenez>`.
666666

667+
- |Fix| :func:`neighbors.KNeighborsRegressor.predict` now works properly when
668+
given an array-like input if `KNeighborsRegressor` is first constructed with a
669+
callable passed to the `weights` parameter. :pr:`22687` by
670+
:user:`Meekail Zain <micky774>`
671+
667672
:mod:`sklearn.neural_network`
668673
.............................
669674

sklearn/neighbors/_regression.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ def predict(self, X):
233233
if weights is None:
234234
y_pred = np.mean(_y[neigh_ind], axis=1)
235235
else:
236-
y_pred = np.empty((X.shape[0], _y.shape[1]), dtype=np.float64)
236+
y_pred = np.empty((neigh_dist.shape[0], _y.shape[1]), dtype=np.float64)
237237
denom = np.sum(weights, axis=1)
238238

239239
for j in range(_y.shape[1]):

sklearn/neighbors/tests/test_neighbors.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,10 @@
3030
)
3131
from sklearn.model_selection import cross_val_score
3232
from sklearn.model_selection import train_test_split
33-
from sklearn.neighbors import VALID_METRICS_SPARSE
33+
from sklearn.neighbors import (
34+
VALID_METRICS_SPARSE,
35+
KNeighborsRegressor,
36+
)
3437
from sklearn.neighbors._base import (
3538
_is_sorted_by_data,
3639
_check_precomputed,
@@ -2096,3 +2099,19 @@ def test_radius_neighbors_brute_backend(
20962099
def test_valid_metrics_has_no_duplicate():
20972100
for val in neighbors.VALID_METRICS.values():
20982101
assert len(val) == len(set(val))
2102+
2103+
2104+
def test_regressor_predict_on_arraylikes():
2105+
"""Ensures that `predict` works for array-likes when `weights` is a callable.
2106+
2107+
Non-regression test for #22687.
2108+
"""
2109+
X = [[5, 1], [3, 1], [4, 3], [0, 3]]
2110+
y = [2, 3, 5, 6]
2111+
2112+
def _weights(dist):
2113+
return np.ones_like(dist)
2114+
2115+
est = KNeighborsRegressor(n_neighbors=1, algorithm="brute", weights=_weights)
2116+
est.fit(X, y)
2117+
assert_allclose(est.predict([[0, 2.5]]), [6])

0 commit comments

Comments
 (0)
0