8000 address comments from @thomasjpfan · scikit-learn/scikit-learn@096b444 · GitHub
[go: up one dir, main page]

Skip to content

Commit 096b444

Browse files
committed
address comments from @thomasjpfan
1 parent 6cafc91 commit 096b444

File tree

1 file changed

+11
-10
lines changed

1 file changed

+11
-10
lines changed

sklearn/linear_model/ridge.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -961,8 +961,8 @@ class _X_operator(sparse.linalg.LinearOperator):
961961
"""
962962

963963
def __init__(self, X, X_mean, sqrt_sw):
964-
self.n_samples, self.n_features = X.shape
965-
super().__init__(X.dtype, (self.n_samples, self.n_features + 1))
964+
n_samples, n_features = X.shape
965+
super().__init__(X.dtype, (n_samples, n_features + 1))
966966
self.X = X
967967
self.X_mean = X_mean
968968
self.sqrt_sw = sqrt_sw
@@ -991,15 +991,16 @@ class _Xt_operator(sparse.linalg.LinearOperator):
991991
"""
992992

993993
def __init__(self, X, X_mean, sqrt_sw):
994-
self.n_samples, self.n_features = X.shape
995-
super().__init__(X.dtype, (self.n_features + 1, self.n_samples))
994+
n_samples, n_features = X.shape
995+
super().__init__(X.dtype, (n_features + 1, n_samples))
996996
self.X = X
997997
self.X_mean = X_mean
998998
self.sqrt_sw = sqrt_sw
999999

10001000
def _matvec(self, v):
10011001
v = v.ravel()
1002-
res = np.empty(self.n_features + 1)
1002+
n_features = self.shape[0]
1003+
res = np.empty(n_features)
10031004
res[:-1] = (
10041005
safe_sparse_dot(self.X.T, v, dense_output=True) -
10051006
(self.X_mean * self.sqrt_sw.dot(v))
@@ -1008,7 +1009,8 @@ def _matvec(self, v):
10081009
return res
10091010

10101011
def _matmat(self, v):
1011-
res = np.empty((self.n_features + 1, v.shape[1]))
1012+
n_features = self.shape[0]
1013+
res = np.empty((n_features, v.shape[1]))
10121014
res[:-1] = (
10131015
safe_sparse_dot(self.X.T, v, dense_output=True) -
10141016
self.X_mean[:, None] * self.sqrt_sw.dot(v)
@@ -1119,8 +1121,9 @@ def _compute_gram(self, X, sqrt_sw):
11191121
X_mean *= n_samples / sqrt_sw.dot(sqrt_sw)
11201122
X_mX = sqrt_sw[:, None] * safe_sparse_dot(
11211123
X_mean, X.T, dense_output=True)
1122-
X_mX_m = np.empty((n_samples, n_samples), dtype=X.dtype)
1123-
X_mX_m[:, :] = np.dot(X_mean, X_mean)
1124+
X_mX_m = np.full((n_samples, n_samples),
1125+
fill_value=np.dot(X_mean, X_mean),
1126+
dtype=X.dtype)
11241127
X_mX_m *= sqrt_sw
11251128
X_mX_m *= sqrt_sw[:, None]
11261129
return (safe_sparse_dot(X, X.T, dense_output=True) + X_mX_m
@@ -1275,7 +1278,6 @@ def _solve_covariance_sparse_no_intercept(
12751278
Used when we have a decomposition of X^T.X
12761279
(n_features < n_samples and X is sparse), and not fitting an intercept.
12771280
"""
1278-
n_samples, n_features = X.shape
12791281
w = 1 / (s + alpha)
12801282
A = (V * w).dot(V.T)
12811283
AXy = A.dot(safe_sparse_dot(X.T, y, dense_output=True))
@@ -1294,7 +1296,6 @@ def _solve_covariance_sparse_intercept(
12941296
(n_features < n_samples and X is sparse),
12951297
and we are fitting an intercept.
12961298
"""
1297-
n_samples, n_features = X.shape
12981299
# the vector [0, 0, ..., 0, 1]
12991300
# is the eigenvector of X^TX which
13001301
# corresponds to the intercept; we cancel the regularization on

0 commit comments

Comments
 (0)
0