8000 DOC Add link to plot_regression.py (#29232) · glemaitre/scikit-learn@628a9e3 · GitHub
[go: up one dir, main page]

Skip to content

Commit 628a9e3

Browse files
craeton authored and adrinjalali committed
DOC Add link to plot_regression.py (scikit-learn#29232)
Co-authored-by: adrinjalali <adrin.jalali@gmail.com>
1 parent 1f633c9 commit 628a9e3

File tree

2 files changed

+16
-9
lines changed

2 files changed

+16
-9
lines changed

examples/neighbors/plot_regression.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,43 +6,46 @@
66
Demonstrate the resolution of a regression problem
77
using a k-Nearest Neighbor and the interpolation of the
88
target using both barycenter and constant weights.
9-
109
"""
1110

1211
# Author: Alexandre Gramfort <alexandre.gramfort@inria.fr>
1312
# Fabian Pedregosa <fabian.pedregosa@inria.fr>
1413
#
1514
# License: BSD 3 clause (C) INRIA
1615

17-
1816
# %%
1917
# Generate sample data
2018
# --------------------
19+
# Here we generate a few data points to use to train the model. We also generate
20+
# data in the whole range of the training data to visualize how the model would
21+
# react in that whole region.
2122
import matplotlib.pyplot as plt
2223
import numpy as np
2324

2425
from sklearn import neighbors
2526

26-
np.random.seed(0)
27-
X = np.sort(5 * np.random.rand(40, 1), axis=0)
28-
T = np.linspace(0, 5, 500)[:, np.newaxis]
29-
y = np.sin(X).ravel()
27+
rng = np.random.RandomState(0)
28+
X_train = np.sort(5 * rng.rand(40, 1), axis=0)
29+
X_test = np.linspace(0, 5, 500)[:, np.newaxis]
30+
y = np.sin(X_train).ravel()
3031

3132
# Add noise to targets
3233
y[::5] += 1 * (0.5 - np.random.rand(8))
3334

3435
# %%
3536
# Fit regression model
3637
# --------------------
38+
# Here we train a model and visualize how `uniform` and `distance`
39+
# weights in prediction affect predicted values.
3740
n_neighbors = 5
3841

3942
for i, weights in enumerate(["uniform", "distance"]):
4043
knn = neighbors.KNeighborsRegressor(n_neighbors, weights=weights)
41-
y_ = knn.fit(X, y).predict(T)
44+
y_ = knn.fit(X_train, y).predict(X_test)
4245

4346
plt.subplot(2, 1, i + 1)
44-
plt.scatter(X, y, color="darkorange", label="data")
45-
plt.plot(T, y_, color="navy", label="prediction")
47+
plt.scatter(X_train, y, color="darkorange", label="data")
48+
plt.plot(X_test, y_, color="navy", label="prediction")
4649
plt.axis("tight")
4750
plt.legend()
4851
plt.title("KNeighborsRegressor (k = %i, weights = '%s')" % (n_neighbors, weights))

sklearn/neighbors/_regression.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,10 @@ class KNeighborsRegressor(KNeighborsMixin, RegressorMixin, NeighborsBase):
4949
5050
Uniform weights are used by default.
5151
52+
See the following example for a demonstration of the impact of
53+
different weighting schemes on predictions:
54+
:ref:`sphx_glr_auto_examples_neighbors_plot_regression.py`.
55+
5256
algorithm : {'auto', 'ball_tree', 'kd_tree', 'brute'}, default='auto'
5357
Algorithm used to compute the nearest neighbors:
5458

0 commit comments

Comments
 (0)
0