|
6 | 6 | Demonstrate the resolution of a regression problem
|
7 | 7 | using a k-Nearest Neighbor and the interpolation of the
|
8 | 8 | target using both barycenter and constant weights.
|
9 |
| -
|
10 | 9 | """
|
11 | 10 |
|
12 | 11 | # Author: Alexandre Gramfort <alexandre.gramfort@inria.fr>
|
13 | 12 | # Fabian Pedregosa <fabian.pedregosa@inria.fr>
|
14 | 13 | #
|
15 | 14 | # License: BSD 3 clause (C) INRIA
|
16 | 15 |
|
17 |
| - |
18 | 16 | # %%
|
19 | 17 | # Generate sample data
|
20 | 18 | # --------------------
|
| 19 | +# Here we generate a few data points to use to train the model. We also generate |
| 20 | +# data in the whole range of the training data to visualize how the model would |
| 21 | +# react in that whole region. |
21 | 22 | import matplotlib.pyplot as plt
|
22 | 23 | import numpy as np
|
23 | 24 |
|
24 | 25 | from sklearn import neighbors
|
25 | 26 |
|
# Use a dedicated, seeded RNG so every random draw in this example is
# reproducible from run to run.
rng = np.random.RandomState(0)

# 40 training points in [0, 5), sorted so the scatter/line plots read
# left-to-right along the x axis.
X_train = np.sort(5 * rng.rand(40, 1), axis=0)

# 500 evenly spaced query points spanning the full training range,
# shaped (n_samples, n_features) as scikit-learn estimators expect.
X_test = np.linspace(0, 5, 500)[:, np.newaxis]

# Noise-free targets: sine of the training inputs, flattened to 1-D.
y = np.sin(X_train).ravel()

# Add noise to every 5th target. Bug fix: draw the noise from the seeded
# `rng` rather than the unseeded global `np.random`, so the example is
# fully deterministic (the seed above otherwise has no effect here).
y[::5] += 1 * (0.5 - rng.rand(8))
|
# %%
# Fit regression model
# --------------------
# Train one regressor per weighting scheme and plot how `uniform` and
# `distance` weights affect the predicted curve over the test grid.
n_neighbors = 5

for plot_row, weight_mode in enumerate(["uniform", "distance"]):
    # Fit on the noisy training points, then evaluate on the dense grid.
    model = neighbors.KNeighborsRegressor(n_neighbors, weights=weight_mode)
    model.fit(X_train, y)
    y_ = model.predict(X_test)

    # One stacked subplot per weighting scheme.
    plt.subplot(2, 1, plot_row + 1)
    plt.scatter(X_train, y, color="darkorange", label="data")
    plt.plot(X_test, y_, color="navy", label="prediction")
    plt.axis("tight")
    plt.legend()
    plt.title("KNeighborsRegressor (k = %i, weights = '%s')" % (n_neighbors, weight_mode))
|
|
0 commit comments