From 548cafc513cb859d92d5530453efa61e01fecd3d Mon Sep 17 00:00:00 2001
From: Christian Lorentzen
Date: Sun, 27 Mar 2022 13:41:40 +0200
Subject: [PATCH 1/5] ENH use more blas functions in cd solvers

---
 sklearn/linear_model/_cd_fast.pyx | 55 ++++++++++++++++++++++++-------
 1 file changed, 44 insertions(+), 11 deletions(-)

diff --git a/sklearn/linear_model/_cd_fast.pyx b/sklearn/linear_model/_cd_fast.pyx
index c64a464a7da9e..20e187f34bab6 100644
--- a/sklearn/linear_model/_cd_fast.pyx
+++ b/sklearn/linear_model/_cd_fast.pyx
@@ -17,10 +17,22 @@ from cython cimport floating
 import warnings
 from ..exceptions import ConvergenceWarning

-from ..utils._cython_blas cimport (_axpy, _dot, _asum, _ger, _gemv, _nrm2,
-                                   _copy, _scal)
-from ..utils._cython_blas cimport RowMajor, ColMajor, Trans, NoTrans
-
+from ..utils._cython_blas cimport (
+    ColMajor,
+    NoTrans,
+    Trans,
+    # BLAS Level 1
+    _asum,
+    _axpy,
+    _copy,
+    _dot,
+    _nrm2,
+    _scal,
+    # BLAS Level 2
+    _ger,
+    _gemv,
+    # BLAS Level 3
+)
 from ..utils._random cimport our_rand_r

@@ -119,7 +131,9 @@ def enet_coordinate_descent(floating[::1] w,
     cdef unsigned int n_features = X.shape[1]

     # compute norms of the columns of X
-    cdef floating[::1] norm_cols_X = np.square(X).sum(axis=0)
+    # norm_cols_X = np.square(X).sum(axis=0)
+    # the following avoids large intermediate memory allocation
+    cdef floating[::1] norm_cols_X = np.einsum("ij,ij->j", X, X)

     # initial value of the residuals
     cdef floating[::1] R = np.empty(n_samples, dtype=dtype)
@@ -657,9 +671,7 @@ def enet_coordinate_descent_gram(floating[::1] w,
             dual_norm_XtA = abs_max(n_features, XtA_ptr)

             # temp = np.sum(w * H)
-            tmp = 0.0
-            for ii in range(n_features):
-                tmp += w[ii] * H[ii]
+            tmp = _dot(n_features, &w[0], 1, &H[0], 1)
             R_norm2 = y_norm2 + tmp - 2.0 * q_dot_w

             # w_norm2 = np.dot(w, w)
@@ -705,6 +717,11 @@ def enet_coordinate_descent_multi_task(
         0.5 * norm(Y - X W.T, 2)^2 + l1_reg ||W.T||_21 + 0.5 * l2_reg norm(W.T, 2)^2

+    Parameters
+    ----------
+    W : F-condiguous ndarray of shape (n_tasks, n_features)
+    X : F-condiguous ndarray of shape (n_samples, n_features)
+    Y : F-condiguous ndarray of shape (n_samples, n_tasks)
     """

     if floating is float:
@@ -791,7 +808,7 @@ def enet_coordinate_descent_multi_task(
                 #       &X[0, ii], 1,
                 #       &w_ii[0], 1, &R[0, 0], n_tasks)
                 # Using Blas Level1 and for loop to avoid slower threads
-                # for such small vectors
+                # for such small vectors (of size n_tasks)
                 for jj in range(n_tasks):
                     if w_ii[jj] != 0:
                         _axpy(n_samples, w_ii[jj], X_ptr + ii * n_samples, 1,
@@ -843,22 +860,38 @@ def enet_coordinate_descent_multi_task(
             # the tolerance: check the duality gap as ultimate stopping
             # criterion

+            # Using numpy:
             # XtA = np.dot(X.T, R) - l2_reg * W.T
+            #
+            # Using BLAS Level 3:
+            # floating[:, ::1] XtA = np.zeros((n_features, n_tasks), order="F")
+            # # np.dot(R.T, X) - l2_reg * W
+            # _copy(n_features * n_tasks, &W[0, 0], 1, &XtA[0, 0], 1)
+            # _gemm(ColMajor, Trans, NoTrans, n_tasks, n_features, n_samples,
+            #       1.0, &R[0, 0], n_samples, &X[0, 0], n_samples, -l2_reg,
+            #       &XtA[0, 0], n_tasks)
+            #
+            # Using BLAS Level 2:
+            # for jj in range(n_tasks):
+            #     # XtA[:, jj] = X.T @ R[:, jj] - l2_reg * W[jj, :]
+            #     _gemv(ColMajor, Trans, n_samples, n_features, 1.0, &X[0, 0],
+            #           n_samples, &R[0, jj], 1, -l2_reg, &XtA[jj, 0], n_tasks)
+            # Using BLAS Level 1:
             for ii in range(n_features):
                 for jj in range(n_tasks):
                     XtA[ii, jj] = _dot(
                         n_samples, X_ptr + ii * n_samples, 1, &R[0, jj], 1
                     ) - l2_reg * W[jj, ii]

             # dual_norm_XtA = np.max(np.sqrt(np.sum(XtA ** 2, axis=1)))
             dual_norm_XtA = 0.0
             for ii in range(n_features):
                 # np.sqrt(np.sum(XtA ** 2, axis=1))
+                # sum is over tasks
                 XtA_axis1norm = _nrm2(n_tasks, &XtA[ii, 0], 1)
                 if XtA_axis1norm > dual_norm_XtA:
                     dual_norm_XtA = XtA_axis1norm

-            # TODO: use squared L2 norm directly
             # R_norm = linalg.norm(R, ord='fro')
             # w_norm = linalg.norm(W, ord='fro')
             R_norm = _nrm2(n_samples * n_tasks, &R[0, 0], 1)
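For intuition, a minimal standalone NumPy sketch (made-up shapes, not part of the patch) of the two facts this commit leans on: `np.einsum("ij,ij->j", X, X)` equals the squared column norms without materializing `np.square(X)`, and the commented BLAS Level 1/2/3 formulations of `XtA` all compute the same values, up to the `(n_features, n_tasks)` vs. `(n_tasks, n_features)` layout that the later commits settle on:

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples, n_features, n_tasks = 50, 10, 3
X = np.asfortranarray(rng.standard_normal((n_samples, n_features)))
R = np.asfortranarray(rng.standard_normal((n_samples, n_tasks)))
W = rng.standard_normal((n_tasks, n_features))
l2_reg = 0.5

# Squared column norms without the large np.square(X) intermediate.
assert np.allclose(np.einsum("ij,ij->j", X, X), np.square(X).sum(axis=0))

# Level 3: one matrix-matrix product (the role played by _gemm).
XtA_gemm = R.T @ X - l2_reg * W
# Level 2: one matrix-vector product per task (the role played by _gemv).
XtA_gemv = np.stack([X.T @ R[:, jj] - l2_reg * W[jj, :] for jj in range(n_tasks)])
# Level 1: one dot product per (feature, task) pair (the role played by _dot).
XtA_dot = np.empty((n_tasks, n_features))
for ii in range(n_features):
    for jj in range(n_tasks):
        XtA_dot[jj, ii] = X[:, ii] @ R[:, jj] - l2_reg * W[jj, ii]

assert np.allclose(XtA_gemm, XtA_gemv)
assert np.allclose(XtA_gemm, XtA_dot)
```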
From 7623065e09430a906911a83056a6288160ba00dd Mon Sep 17 00:00:00 2001
From: Christian Lorentzen
Date: Mon, 28 Mar 2022 22:12:57 +0200
Subject: [PATCH 2/5] ENH use gemm for R.T @ X

---
 sklearn/linear_model/_cd_fast.pyx | 36 +++++++++++++++----------------
 1 file changed, 18 insertions(+), 18 deletions(-)

diff --git a/sklearn/linear_model/_cd_fast.pyx b/sklearn/linear_model/_cd_fast.pyx
index 20e187f34bab6..27003e2f4a008 100644
--- a/sklearn/linear_model/_cd_fast.pyx
+++ b/sklearn/linear_model/_cd_fast.pyx
@@ -32,6 +32,7 @@ from ..utils._cython_blas cimport (
     _ger,
     _gemv,
     # BLAS Level 3
+    _gemm,
 )
 from ..utils._random cimport our_rand_r

@@ -735,7 +736,7 @@ def enet_coordinate_descent_multi_task(
     cdef unsigned int n_tasks = Y.shape[1]

     # to store XtA
-    cdef floating[:, ::1] XtA = np.zeros((n_features, n_tasks), dtype=dtype)
+    cdef floating[::1, :] XtA = np.zeros((n_tasks, n_features), dtype=dtype, order="F")
     cdef floating XtA_axis1norm
     cdef floating dual_norm_XtA

@@ -861,34 +862,33 @@ def enet_coordinate_descent_multi_task(
             # criterion

             # Using numpy:
-            # XtA = np.dot(X.T, R) - l2_reg * W.T
+            # XtA = np.dot(R.T, X) - l2_reg * W
             #
             # Using BLAS Level 3:
-            # floating[:, ::1] XtA = np.zeros((n_features, n_tasks), order="F")
-            # # np.dot(R.T, X) - l2_reg * W
-            # _copy(n_features * n_tasks, &W[0, 0], 1, &XtA[0, 0], 1)
-            # _gemm(ColMajor, Trans, NoTrans, n_tasks, n_features, n_samples,
-            #       1.0, &R[0, 0], n_samples, &X[0, 0], n_samples, -l2_reg,
-            #       &XtA[0, 0], n_tasks)
-            #
+            # np.dot(R.T, X)
+            _gemm(ColMajor, Trans, NoTrans, n_tasks, n_features, n_samples, 1.0,
+                  &R[0, 0], n_samples, &X[0, 0], n_samples, 0.0, &XtA[0, 0],
+                  n_tasks)
+            # XtA -= l2_reg * W
+            _axpy(n_features * n_tasks, -l2_reg, &W[0, 0], 1 ,&XtA[0, 0], 1)
             # Using BLAS Level 2:
             # for jj in range(n_tasks):
-            #     # XtA[:, jj] = X.T @ R[:, jj] - l2_reg * W[jj, :]
+            #     # XtA[jj, :] = X.T @ R[:, jj] - l2_reg * W[jj, :]
             #     _gemv(ColMajor, Trans, n_samples, n_features, 1.0, &X[0, 0],
             #           n_samples, &R[0, jj], 1, -l2_reg, &XtA[jj, 0], n_tasks)
             # Using BLAS Level 1:
-            for ii in range(n_features):
-                for jj in range(n_tasks):
-                    XtA[ii, jj] = _dot(
-                        n_samples, X_ptr + ii * n_samples, 1, &R[0, jj], 1
-                    ) - l2_reg * W[jj, ii]
+            # for ii in range(n_features):
+            #     for jj in range(n_tasks):
+            #         XtA[jj, ii] = _dot(
+            #             n_samples, X_ptr + ii * n_samples, 1, &R[0, jj], 1
+            #         ) - l2_reg * W[jj, ii]

-            # dual_norm_XtA = np.max(np.sqrt(np.sum(XtA ** 2, axis=1)))
+            # dual_norm_XtA = np.max(np.sqrt(np.sum(XtA ** 2, axis=0)))
             dual_norm_XtA = 0.0
             for ii in range(n_features):
-                # np.sqrt(np.sum(XtA ** 2, axis=1))
+                # np.sqrt(np.sum(XtA ** 2, axis=0))
                 # sum is over tasks
-                XtA_axis1norm = _nrm2(n_tasks, &XtA[ii, 0], 1)
+                XtA_axis1norm = _nrm2(n_tasks, &XtA[0, ii], 1)
                 if XtA_axis1norm > dual_norm_XtA:
                     dual_norm_XtA = XtA_axis1norm
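The gemm-plus-axpy split above can be sketched with SciPy's BLAS wrappers instead of the internal `_cython_blas` helpers (an illustrative stand-in, not the code under review; `scipy.linalg.blas.dgemm` plays the role of `_gemm`):

```python
import numpy as np
from scipy.linalg import blas

rng = np.random.default_rng(0)
n_samples, n_features, n_tasks = 50, 10, 3
X = np.asfortranarray(rng.standard_normal((n_samples, n_features)))
R = np.asfortranarray(rng.standard_normal((n_samples, n_tasks)))
W = np.asfortranarray(rng.standard_normal((n_tasks, n_features)))
l2_reg = 0.5

# One gemm call: XtA = 1.0 * R.T @ X, shape (n_tasks, n_features).
XtA = blas.dgemm(alpha=1.0, a=R, b=X, trans_a=True)
# One axpy-style update folds in the ridge term: XtA -= l2_reg * W.
XtA -= l2_reg * W
assert np.allclose(XtA, R.T @ X - l2_reg * W)

# With the transposed (n_tasks, n_features) layout, the per-feature norm
# sums over axis 0 (the tasks), hence the axis=1 -> axis=0 comment changes.
dual_norm_XtA = np.sqrt((XtA ** 2).sum(axis=0)).max()
```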
From ddcbacb599f37da4fb5c37c2bcc0368f20d8d773 Mon Sep 17 00:00:00 2001
From: Christian Lorentzen
Date: Tue, 29 Mar 2022 18:45:11 +0200
Subject: [PATCH 3/5] CLN remove comments and reinstate TODO

---
 sklearn/linear_model/_cd_fast.pyx | 15 ++-------------
 1 file changed, 2 insertions(+), 13 deletions(-)

diff --git a/sklearn/linear_model/_cd_fast.pyx b/sklearn/linear_model/_cd_fast.pyx
index 27003e2f4a008..dd3c53d5db6f3 100644
--- a/sklearn/linear_model/_cd_fast.pyx
+++ b/sklearn/linear_model/_cd_fast.pyx
@@ -863,25 +863,13 @@ def enet_coordinate_descent_multi_task(
             # Using numpy:
             # XtA = np.dot(R.T, X) - l2_reg * W
-            #
             # Using BLAS Level 3:
-            # np.dot(R.T, X)
+            # XtA = np.dot(R.T, X)
             _gemm(ColMajor, Trans, NoTrans, n_tasks, n_features, n_samples, 1.0,
                   &R[0, 0], n_samples, &X[0, 0], n_samples, 0.0, &XtA[0, 0],
                   n_tasks)
             # XtA -= l2_reg * W
             _axpy(n_features * n_tasks, -l2_reg, &W[0, 0], 1 ,&XtA[0, 0], 1)
-            # Using BLAS Level 2:
-            # for jj in range(n_tasks):
-            #     # XtA[jj, :] = X.T @ R[:, jj] - l2_reg * W[jj, :]
-            #     _gemv(ColMajor, Trans, n_samples, n_features, 1.0, &X[0, 0],
-            #           n_samples, &R[0, jj], 1, -l2_reg, &XtA[jj, 0], n_tasks)
-            # Using BLAS Level 1:
-            # for ii in range(n_features):
-            #     for jj in range(n_tasks):
-            #         XtA[jj, ii] = _dot(
-            #             n_samples, X_ptr + ii * n_samples, 1, &R[0, jj], 1
-            #         ) - l2_reg * W[jj, ii]

             # dual_norm_XtA = np.max(np.sqrt(np.sum(XtA ** 2, axis=0)))
             dual_norm_XtA = 0.0
@@ -892,6 +880,7 @@ def enet_coordinate_descent_multi_task(
                 if XtA_axis1norm > dual_norm_XtA:
                     dual_norm_XtA = XtA_axis1norm

+            # TODO: use squared L2 norm directly
             # R_norm = linalg.norm(R, ord='fro')
             # w_norm = linalg.norm(W, ord='fro')
             R_norm = _nrm2(n_samples * n_tasks, &R[0, 0], 1)

From fb9c04b5aaa3c2ccbcf697f0da01ee66fd1c6bb2 Mon Sep 17 00:00:00 2001
From: Christian Lorentzen
Date: Tue, 29 Mar 2022 18:46:16 +0200
Subject: [PATCH 4/5] MNT rename XtA_axis1norm to XtA_axis0norm

---
 sklearn/linear_model/_cd_fast.pyx | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/sklearn/linear_model/_cd_fast.pyx b/sklearn/linear_model/_cd_fast.pyx
index dd3c53d5db6f3..07a9d85c61348 100644
--- a/sklearn/linear_model/_cd_fast.pyx
+++ b/sklearn/linear_model/_cd_fast.pyx
@@ -737,7 +737,7 @@ def enet_coordinate_descent_multi_task(

     # to store XtA
     cdef floating[::1, :] XtA = np.zeros((n_tasks, n_features), dtype=dtype, order="F")
-    cdef floating XtA_axis1norm
+    cdef floating XtA_axis0norm
     cdef floating dual_norm_XtA

     # initial value of the residuals
@@ -876,9 +876,9 @@ def enet_coordinate_descent_multi_task(
             for ii in range(n_features):
                 # np.sqrt(np.sum(XtA ** 2, axis=0))
                 # sum is over tasks
-                XtA_axis1norm = _nrm2(n_tasks, &XtA[0, ii], 1)
-                if XtA_axis1norm > dual_norm_XtA:
-                    dual_norm_XtA = XtA_axis1norm
+                XtA_axis0norm = _nrm2(n_tasks, &XtA[0, ii], 1)
+                if XtA_axis0norm > dual_norm_XtA:
+                    dual_norm_XtA = XtA_axis0norm

             # TODO: use squared L2 norm directly
             # R_norm = linalg.norm(R, ord='fro')
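The reinstated TODO points at a cheap follow-up: the duality gap only ever needs `R_norm2` and `w_norm2`, so the square root taken by `_nrm2` could be skipped by calling `_dot` on the flattened array instead. A sketch in NumPy terms (illustrative only, not part of the series):

```python
import numpy as np

rng = np.random.default_rng(0)
R = rng.standard_normal((50, 3))  # stand-in for the (n_samples, n_tasks) residuals

# Current code: Frobenius norm via _nrm2, squared again where R_norm2 is needed.
R_norm = np.linalg.norm(R, ord="fro")  # _nrm2(n_samples * n_tasks, &R[0, 0], 1)
# TODO target: the squared norm directly, one pass and no square root.
R_norm2 = R.ravel() @ R.ravel()        # _dot(n_samples * n_tasks, &R[0, 0], 1, &R[0, 0], 1)
assert np.isclose(R_norm ** 2, R_norm2)
```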
From c73926966f54061db45eaf512f7c016a05162940 Mon Sep 17 00:00:00 2001
From: Christian Lorentzen
Date: Wed, 13 Apr 2022 20:51:32 +0200
Subject: [PATCH 5/5] CLN review comments

---
 sklearn/linear_model/_cd_fast.pyx | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/sklearn/linear_model/_cd_fast.pyx b/sklearn/linear_model/_cd_fast.pyx
index 07a9d85c61348..64fa261fb2f37 100644
--- a/sklearn/linear_model/_cd_fast.pyx
+++ b/sklearn/linear_model/_cd_fast.pyx
@@ -720,9 +720,9 @@ def enet_coordinate_descent_multi_task(

     Parameters
     ----------
-    W : F-condiguous ndarray of shape (n_tasks, n_features)
-    X : F-condiguous ndarray of shape (n_samples, n_features)
-    Y : F-condiguous ndarray of shape (n_samples, n_tasks)
+    W : F-contiguous ndarray of shape (n_tasks, n_features)
+    X : F-contiguous ndarray of shape (n_samples, n_features)
+    Y : F-contiguous ndarray of shape (n_samples, n_tasks)
     """

     if floating is float:
@@ -862,9 +862,9 @@ def enet_coordinate_descent_multi_task(
             # criterion

             # Using numpy:
-            # XtA = np.dot(R.T, X) - l2_reg * W
+            #   XtA = np.dot(R.T, X) - l2_reg * W
             # Using BLAS Level 3:
-            # XtA = np.dot(R.T, X)
+            #   XtA = np.dot(R.T, X)
             _gemm(ColMajor, Trans, NoTrans, n_tasks, n_features, n_samples, 1.0,
                   &R[0, 0], n_samples, &X[0, 0], n_samples, 0.0, &XtA[0, 0],
                   n_tasks)
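A small NumPy check (illustrative only) of why the F-contiguous `(n_tasks, n_features)` layout adopted in patches 2 and 4 pays off: each feature's task vector `XtA[:, ii]` is then contiguous in memory, so the `_nrm2(n_tasks, &XtA[0, ii], 1)` call in the dual-norm loop runs with unit stride:

```python
import numpy as np

n_tasks, n_features = 3, 10
XtA = np.zeros((n_tasks, n_features), order="F")

# Column-major storage: consecutive tasks of one feature sit next to
# each other, so a column slice is a contiguous, stride-1 vector.
assert XtA.strides[0] == XtA.itemsize
assert XtA[:, 5].flags["C_CONTIGUOUS"]
```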