BUG fixed sample weights linearregression · scikit-learn/scikit-learn@c1c35d0 · GitHub
[go: up one dir, main page]

Skip to content

Commit c1c35d0

Browse files
author
giorgiop
committed
BUG fixed sample weights linearregression
1 parent 861ac13 commit c1c35d0

File tree

2 files changed

+19
-14
lines changed

2 files changed

+19
-14
lines changed

sklearn/linear_model/base.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -436,14 +436,14 @@ def fit(self, X, y, sample_weight=None):
436436
sample_weight).ndim > 1):
437437
sample_weight = column_or_1d(sample_weight, warn=True)
438438

439-
X, y, X_mean, y_mean, X_std = self._center_data(
440-
X, y, self.fit_intercept, self.normalize, self.copy_X,
441-
sample_weight=sample_weight)
442-
443439
if sample_weight is not None:
444440
# Sample weight can be implemented via a simple rescaling.
445441
X, y = _rescale_data(X, y, sample_weight)
446442

443+
X, y, X_mean, y_mean, X_std = self._center_data(
444+
X, y, self.fit_intercept, self.normalize, self.copy_X,
445+
sample_weight=None)
446+
447447
if sp.issparse(X):
448448
if y.ndim < 2:
449449
out = sparse_lsqr(X, y)

sklearn/linear_model/tests/test_base.py

+15-10
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from sklearn.utils.testing import assert_greater
1717
from sklearn.datasets.samples_generator import make_sparse_uncorrelated
1818
from sklearn.datasets.samples_generator import make_regression
19+
from sklearn.metrics import r2_score
1920

2021

2122
def test_linear_regression():
@@ -44,28 +45,33 @@ def test_linear_regression():
4445

4546
def test_linear_regression_sample_weights():
4647
rng = np.random.RandomState(0)
48+
n_samples, n_features = 6, 50 # under-determined system (fewer samples than features)
4749

48-
for n_samples, n_features in ((6, 5), (5, 10)):
50+
for fit_intercept in [True, False]:
4951
y = rng.randn(n_samples)
5052
X = rng.randn(n_samples, n_features)
5153
sample_weight = 1.0 + rng.rand(n_samples)
5254

53-
clf = LinearRegression()
54-
clf.fit(X, y, sample_weight)
55-
coefs1 = clf.coef_
55+
reg = LinearRegression(fit_intercept=fit_intercept)
56+
reg.fit(X, y, sample_weight=sample_weight)
57+
coefs1 = reg.coef_
58+
intercept1 = reg.intercept_
5659

57-
assert_equal(clf.coef_.shape, (X.shape[1], ))
58-
assert_greater(clf.score(X, y), 0.9)
59-
assert_array_almost_equal(clf.predict(X), y)
60+
assert_equal(reg.coef_.shape, (X.shape[1], ))
61+
assert_greater(reg.score(X, y, sample_weight=sample_weight), 0.9)
62+
assert_greater(r2_score(y, reg.predict(X),
63+
sample_weight=sample_weight), 0.9) # same as above
6064

6165
# Sample weight can be implemented via a simple rescaling
6266
# for the square loss.
6367
scaled_y = y * np.sqrt(sample_weight)
6468
scaled_X = X * np.sqrt(sample_weight)[:, np.newaxis]
65-
clf.fit(X, y)
66-
coefs2 = clf.coef_
69+
reg.fit(scaled_X, scaled_y)
70+
coefs2 = reg.coef_
71+
intercept2 = reg.intercept_
6772

6873
assert_array_almost_equal(coefs1, coefs2)
74+
assert_array_almost_equal(intercept1, intercept2)
6975

7076

7177
def test_raises_value_error_if_sample_weights_greater_than_1d():
@@ -321,4 +327,3 @@ def test_rescale_data():
321327
rescaled_y2 = y * np.sqrt(sample_weight)
322328
assert_array_almost_equal(rescaled_X, rescaled_X2)
323329
assert_array_almost_equal(rescaled_y, rescaled_y2)
324-

0 commit comments

Comments
 (0)
0