8000 MNT: Use GEMV in enet_coordinate_descent (Pt. 1) (#11896) · scikit-learn/scikit-learn@8d7ce8e · GitHub
[go: up one dir, main page]

Skip to content

Commit 8d7ce8e

Browse files
jakirkhamrth
authored andcommitted
MNT: Use GEMV in enet_coordinate_descent (Pt. 1) (#11896)
1 parent 1c84b81 commit 8d7ce8e

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

sklearn/linear_model/cd_fast.pyx

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,14 +159,18 @@ def enet_coordinate_descent(floating[::1] w,
159159
# fused types version of BLAS functions
160160
if floating is float:
161161
dtype = np.float32
162+
gemv = sgemv
162163
dot = sdot
163164
axpy = saxpy
164165
asum = sasum
166+
copy = scopy
165167
else:
166168
dtype = np.float64
169+
gemv = dgemv
167170
dot = ddot
168171
axpy = daxpy
169172
asum = dasum
173+
copy = dcopy
170174

171175
# get the data information into easy vars
172176
cdef unsigned int n_samples = X.shape[0]
@@ -205,8 +209,11 @@ def enet_coordinate_descent(floating[::1] w,
205209

206210
with nogil:
207211
# R = y - np.dot(X, w)
208-
for i in range(n_samples):
209-
R[i] = y[i] - dot(n_features, &X[i, 0], n_samples, &w[0], 1)
212+
copy(n_samples, &y[0], 1, &R[0], 1)
213+
gemv(CblasColMajor, CblasNoTrans,
214+
n_samples, n_features, -1.0, &X[0, 0], n_samples,
215+
&w[0], 1,
216+
1.0, &R[0], 1)
210217

211218
# tol *= np.dot(y, y)
212219
tol *= dot(n_samples, &y[0], 1, &y[0], 1)

0 commit comments

Comments
 (0)
0