|
16 | 16 | from sklearn.utils.testing import assert_greater
|
17 | 17 | from sklearn.datasets.samples_generator import make_sparse_uncorrelated
|
18 | 18 | from sklearn.datasets.samples_generator import make_regression
|
| 19 | +from sklearn.metrics import r2_score |
19 | 20 |
|
20 | 21 |
|
21 | 22 | def test_linear_regression():
|
@@ -44,28 +45,33 @@ def test_linear_regression():
|
44 | 45 |
|
45 | 46 | def test_linear_regression_sample_weights():
|
46 | 47 | rng = np.random.RandomState(0)
|
| 48 | + n_samples, n_features = 6, 50 # over-determined system |
47 | 49 |
|
48 |
| - for n_samples, n_features in ((6, 5), (5, 10)): |
| 50 | + for fit_intercept in [True, False]: |
49 | 51 | y = rng.randn(n_samples)
|
50 | 52 | X = rng.randn(n_samples, n_features)
|
51 | 53 | sample_weight = 1.0 + rng.rand(n_samples)
|
52 | 54 |
|
53 |
| - clf = LinearRegression() |
54 |
| - clf.fit(X, y, sample_weight) |
55 |
| - coefs1 = clf.coef_ |
| 55 | + reg = LinearRegression(fit_intercept=fit_intercept) |
| 56 | + reg.fit(X, y, sample_weight=sample_weight) |
| 57 | + coefs1 = reg.coef_ |
| 58 | + intercept1 = reg.intercept_ |
56 | 59 |
|
57 |
| - assert_equal(clf.coef_.shape, (X.shape[1], )) |
58 |
| - assert_greater(clf.score(X, y), 0.9) |
59 |
| - assert_array_almost_equal(clf.predict(X), y) |
| 60 | + assert_equal(reg.coef_.shape, (X.shape[1], )) |
| 61 | + assert_greater(reg.score(X, y, sample_weight=sample_weight), 0.9) |
| 62 | + assert_greater(r2_score(y, reg.predict(X), |
| 63 | + sample_weight=sample_weight), 0.9) # same as above |
60 | 64 |
|
61 | 65 | # Sample weight can be implemented via a simple rescaling
|
62 | 66 | # for the square loss.
|
63 | 67 | scaled_y = y * np.sqrt(sample_weight)
|
64 | 68 | scaled_X = X * np.sqrt(sample_weight)[:, np.newaxis]
|
65 |
| - clf.fit(X, y) |
66 |
| - coefs2 = clf.coef_ |
| 69 | + reg.fit(scaled_X, scaled_y) |
| 70 | + coefs2 = reg.coef_ |
| 71 | + intercept2 = reg.intercept_ |
67 | 72 |
|
68 | 73 | assert_array_almost_equal(coefs1, coefs2)
|
| 74 | + assert_array_almost_equal(intercept1, intercept2) |
69 | 75 |
|
70 | 76 |
|
71 | 77 | def test_raises_value_error_if_sample_weights_greater_than_1d():
|
@@ -321,4 +327,3 @@ def test_rescale_data():
|
321 | 327 | rescaled_y2 = y * np.sqrt(sample_weight)
|
322 | 328 | assert_array_almost_equal(rescaled_X, rescaled_X2)
|
323 | 329 | assert_array_almost_equal(rescaled_y, rescaled_y2)
|
324 |
| - |
|
0 commit comments