8000 fixed old test and linearregression weights · scikit-learn/scikit-learn@66633b8 · GitHub
[go: up one dir, main page]

Skip to content

Commit 66633b8

Browse files
author
giorgiop
committed
fixed old test and linearregression weights
1 parent 861ac13 commit 66633b8

File tree

2 files changed

+16
-14
lines changed

2 files changed

+16
-14
lines changed

sklearn/linear_model/base.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -436,14 +436,14 @@ def fit(self, X, y, sample_weight=None):
436436
sample_weight).ndim > 1):
437437
sample_weight = column_or_1d(sample_weight, warn=True)
438438

439-
X, y, X_mean, y_mean, X_std = self._center_data(
440-
X, y, self.fit_intercept, self.normalize, self.copy_X,
441-
sample_weight=sample_weight)
442-
443439
if sample_weight is not None:
444440
# Sample weight can be implemented via a simple rescaling.
445441
X, y = _rescale_data(X, y, sample_weight)
446442

443+
X, y, X_mean, y_mean, X_std = self._center_data(
444+
X, y, self.fit_intercept, self.normalize, self.copy_X,
445+
sample_weight=None)
446+
447447
if sp.issparse(X):
448448
if y.ndim < 2:
449449
out = sparse_lsqr(X, y)

sklearn/linear_model/tests/test_base.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from sklearn.utils.testing import assert_greater
1717
from sklearn.datasets.samples_generator import make_sparse_uncorrelated
1818
from sklearn.datasets.samples_generator import make_regression
19+
from sklearn.metrics import r2_score
1920

2021

2122
def test_linear_regression():
@@ -44,26 +45,28 @@ def test_linear_regression():
4445

4546
def test_linear_regression_sample_weights():
4647
rng = np.random.RandomState(0)
48+
n_samples, n_features = 6, 50 # over-determined system
4749

48-
for n_samples, n_features in ((6, 5), (5, 10)):
50+
for fit_intercept in [True, False]:
4951
y = rng.randn(n_samples)
5052
X = rng.randn(n_samples, n_features)
5153
sample_weight = 1.0 + rng.rand(n_samples)
5254

53-
clf = LinearRegression()
54-
clf.fit(X, y, sample_weight)
55-
coefs1 = clf.coef_
55+
reg = LinearRegression(fit_intercept=fit_intercept)
56+
reg.fit(X, y, sample_weight=sample_weight)
57+
coefs1 = reg.coef_
5658

57-
assert_equal(clf.coef_.shape, (X.shape[1], ))
58-
assert_greater(clf.score(X, y), 0.9)
59-
assert_array_almost_equal(clf.predict(X), y)
59+
assert_equal(reg.coef_.shape, (X.shape[1], ))
60+
assert_greater(reg.score(X, y, sample_weight=sample_weight), 0.9)
61+
assert_greater(r2_score(y, reg.predict(X),
62+
sample_weight=sample_weight), 0.9) # same as above
6063

6164
# Sample weight can be implemented via a simple rescaling
6265
# for the square loss.
6366
scaled_y = y * np.sqrt(sample_weight)
6467
scaled_X = X * np.sqrt(sample_weight)[:, np.newaxis]
65-
clf.fit(X, y)
66-
coefs2 = clf.coef_
68+
reg.fit(scaled_X, scaled_y)
69+
coefs2 = reg.coef_
6770

6871
assert_array_almost_equal(coefs1, coefs2)
6972

@@ -321,4 +324,3 @@ def test_rescale_data():
321324
rescaled_y2 = y * np.sqrt(sample_weight)
322325
assert_array_almost_equal(rescaled_X, rescaled_X2)
323326
assert_array_almost_equal(rescaled_y, rescaled_y2)
324-

0 commit comments

Comments
 (0)
0