Demonstrate the resolution of a regression problem using k-nearest neighbors
regression, interpolating the target with both barycenter (distance) and
constant (uniform) weights.
"""

# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause

# %%
# Generate sample data
# --------------------
# Here we generate a few data points to train the model on. We also generate
# data over the whole range of the training data, so that we can visualize how
# the model would react over that entire region. A quick check of these ranges
# follows the generation step below.
import matplotlib.pyplot as plt
import numpy as np

from sklearn import neighbors

rng = np.random.RandomState(0)
X_train = np.sort(5 * rng.rand(40, 1), axis=0)
X_test = np.linspace(0, 5, 500)[:, np.newaxis]
y = np.sin(X_train).ravel()

# Add noise to targets, drawing from the same seeded generator for reproducibility
y[::5] += 1 * (0.5 - rng.rand(8))
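
# %%
# A quick sanity check (an illustrative addition): the sorted training inputs
# lie roughly between 0 and 5, and the test grid covers that same range
# densely, so the predictions below span the whole training domain.
print(X_train.shape, X_test.shape)
print(X_train.min(), X_train.max(), X_test.min(), X_test.max())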

# %%
# Fit regression model
# --------------------
# Here we train a model and visualize how the `uniform` and `distance`
# weights in prediction affect the predicted values.
n_neighbors = 5

for i, weights in enumerate(["uniform", "distance"]):
    knn = neighbors.KNeighborsRegressor(n_neighbors, weights=weights)
    y_ = knn.fit(X_train, y).predict(X_test)

    plt.subplot(2, 1, i + 1)
    plt.scatter(X_train, y, color="darkorange", label="data")
    plt.plot(X_test, y_, color="navy", label="prediction")
    plt.axis("tight")
    plt.legend()
    plt.title("KNeighborsRegressor (k = %i, weights = '%s')" % (n_neighbors, weights))
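
# %%
# To make the `weights` parameter concrete, here is a minimal illustrative
# sketch (an addition, using a hypothetical query point at 2.5): with
# `weights="distance"` the prediction is the barycenter of the k nearest
# training targets, weighted by the inverse of their distances to the query.
query = np.array([[2.5]])  # hypothetical query point chosen for illustration
knn_dist = neighbors.KNeighborsRegressor(n_neighbors, weights="distance")
distances, indices = knn_dist.fit(X_train, y).kneighbors(query)
# Inverse-distance weights; assumes no training point coincides exactly with
# the query (otherwise the estimator falls back to the exact-match neighbor).
inv_dist = 1.0 / distances.ravel()
manual_prediction = np.sum(inv_dist * y[indices.ravel()]) / np.sum(inv_dist)
# The hand-computed barycenter should match the estimator's own prediction.
print(manual_prediction, knn_dist.predict(query)[0])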