DOC Add link to plot_regression.py (#29232) · scikit-learn/scikit-learn@1e3c7be · GitHub

Commit 1e3c7be
DOC Add link to plot_regression.py (#29232)
Co-authored-by: adrinjalali <adrin.jalali@gmail.com>
1 parent 56dbfd0 commit 1e3c7be

File tree: 2 files changed, +16 -9 lines changed

examples/neighbors/plot_regression.py

+12 -9 lines changed

@@ -6,41 +6,44 @@
 Demonstrate the resolution of a regression problem
 using a k-Nearest Neighbor and the interpolation of the
 target using both barycenter and constant weights.
-
 """
 
 # Authors: The scikit-learn developers
 # SPDX-License-Identifier: BSD-3-Clause
 
-
 # %%
 # Generate sample data
 # --------------------
+# Here we generate a few data points to use to train the model. We also generate
+# data in the whole range of the training data to visualize how the model would
+# react in that whole region.
 import matplotlib.pyplot as plt
 import numpy as np
 
 from sklearn import neighbors
 
-np.random.seed(0)
-X = np.sort(5 * np.random.rand(40, 1), axis=0)
-T = np.linspace(0, 5, 500)[:, np.newaxis]
-y = np.sin(X).ravel()
+rng = np.random.RandomState(0)
+X_train = np.sort(5 * rng.rand(40, 1), axis=0)
+X_test = np.linspace(0, 5, 500)[:, np.newaxis]
+y = np.sin(X_train).ravel()
 
 # Add noise to targets
 y[::5] += 1 * (0.5 - np.random.rand(8))
 
 # %%
 # Fit regression model
 # --------------------
+# Here we train a model and visualize how `uniform` and `distance`
+# weights in prediction affect predicted values.
 n_neighbors = 5
 
 for i, weights in enumerate(["uniform", "distance"]):
     knn = neighbors.KNeighborsRegressor(n_neighbors, weights=weights)
-    y_ = knn.fit(X, y).predict(T)
+    y_ = knn.fit(X_train, y).predict(X_test)
 
     plt.subplot(2, 1, i + 1)
-    plt.scatter(X, y, color="darkorange", label="data")
-    plt.plot(T, y_, color="navy", label="prediction")
+    plt.scatter(X_train, y, color="darkorange", label="data")
+    plt.plot(X_test, y_, color="navy", label="prediction")
     plt.axis("tight")
     plt.legend()
     plt.title("KNeighborsRegressor (k = %i, weights = '%s')" % (n_neighbors, weights))

sklearn/neighbors/_regression.py

+4 -0 lines changed

@@ -43,6 +43,10 @@ class KNeighborsRegressor(KNeighborsMixin, RegressorMixin, NeighborsBase):
 
         Uniform weights are used by default.
 
+        See the following example for a demonstration of the impact of
+        different weighting schemes on predictions:
+        :ref:`sphx_glr_auto_examples_neighbors_plot_regression.py`.
+
     algorithm : {'auto', 'ball_tree', 'kd_tree', 'brute'}, default='auto'
         Algorithm used to compute the nearest neighbors:
 
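For context, a small sketch (not part of the commit, toy data chosen here for illustration) of the two weighting schemes that the new cross-reference documents:

import numpy as np
from sklearn.neighbors import KNeighborsRegressor

rng = np.random.RandomState(0)
X = np.sort(5 * rng.rand(40, 1), axis=0)
y = np.sin(X).ravel()

# weights="uniform": every neighbor contributes equally to the prediction.
uniform = KNeighborsRegressor(n_neighbors=5, weights="uniform").fit(X, y)
# weights="distance": closer neighbors receive larger (inverse-distance) weights.
distance = KNeighborsRegressor(n_neighbors=5, weights="distance").fit(X, y)

X_query = np.linspace(0, 5, 5)[:, np.newaxis]
print(uniform.predict(X_query))
print(distance.predict(X_query))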

0 commit comments