@@ -122,14 +122,25 @@ cdef extern from "cblas.h":
     void dger "cblas_dger"(CBLAS_ORDER Order, int M, int N, double alpha,
                            double *X, int incX, double *Y, int incY,
                            double *A, int lda) nogil
+    void sger "cblas_sger"(CBLAS_ORDER Order, int M, int N, float alpha,
+                           float *X, int incX, float *Y, int incY,
+                           float *A, int lda) nogil
     void dgemv "cblas_dgemv"(CBLAS_ORDER Order, CBLAS_TRANSPOSE TransA,
                              int M, int N, double alpha, double *A, int lda,
                              double *X, int incX, double beta,
                              double *Y, int incY) nogil
+    void sgemv "cblas_sgemv"(CBLAS_ORDER Order, CBLAS_TRANSPOSE TransA,
+                             int M, int N, float alpha, float *A, int lda,
+                             float *X, int incX, float beta,
+                             float *Y, int incY) nogil
     double dnrm2 "cblas_dnrm2"(int N, double *X, int incX) nogil
+    float snrm2 "cblas_snrm2"(int N, float *X, int incX) nogil
     void dcopy "cblas_dcopy"(int N, double *X, int incX, double *Y,
                              int incY) nogil
+    void scopy "cblas_scopy"(int N, float *X, int incX, float *Y,
+                             int incY) nogil
     void dscal "cblas_dscal"(int N, double alpha, double *X, int incX) nogil
+    void sscal "cblas_sscal"(int N, float alpha, float *X, int incX) nogil


 @cython.boundscheck(False)
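The new single-precision prototypes mirror the existing double-precision ones; the quoted string after each name (e.g. "cblas_sger") binds the Cython identifier to the underlying C symbol. Below is a minimal sketch of how such paired externs get selected at compile time through Cython's `floating` fused type — the `NRM2` typedef and `euclidean_norm` helper are illustrative only, not part of this patch, and they assume a cblas.h on the include path:

    # sketch.pyx -- illustration only, not part of the diff
    from cython cimport floating

    cdef extern from "cblas.h":
        double dnrm2 "cblas_dnrm2"(int N, double *X, int incX) nogil
        float snrm2 "cblas_snrm2"(int N, float *X, int incX) nogil

    # Hypothetical function-pointer type over the fused `floating` type
    # (float or double); the patch uses the same idea for DOT/AXPY/ASUM.
    ctypedef floating (*NRM2)(int N, floating *X, int incX) nogil

    def euclidean_norm(floating[::1] x):
        """Return ||x||_2 using the BLAS routine matching x's precision."""
        cdef NRM2 nrm2
        if floating is float:   # resolved at compile time per specialization
            nrm2 = snrm2
        else:
            nrm2 = dnrm2
        return nrm2(x.shape[0], &x[0], 1)

For each specialization of `floating`, only the matching branch of `floating is float` survives compilation, so the selection carries no runtime cost.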
@@ -686,11 +697,11 @@ def enet_coordinate_descent_gram(floating[:] w, floating alpha, floating beta,
 @cython.boundscheck(False)
 @cython.wraparound(False)
 @cython.cdivision(True)
-def enet_coordinate_descent_multi_task(double[::1, :] W, double l1_reg,
-                                       double l2_reg,
-                                       np.ndarray[double, ndim=2, mode='fortran'] X,
-                                       np.ndarray[double, ndim=2] Y,
-                                       int max_iter, double tol, object rng,
+def enet_coordinate_descent_multi_task(floating[::1, :] W, floating l1_reg,
+                                       floating l2_reg,
+                                       np.ndarray[floating, ndim=2, mode='fortran'] X,
+                                       np.ndarray[floating, ndim=2] Y,
+                                       int max_iter, floating tol, object rng,
                                        bint random=0):
     """Cython version of the coordinate descent algorithm
        for Elastic-Net mult-task regression
@@ -700,42 +711,68 @@ def enet_coordinate_descent_multi_task(double[::1, :] W, double l1_reg,
         (1/2) * norm(y - X w, 2)^2 + l1_reg ||w||_21 + (1/2) * l2_reg norm(w, 2)^2

     """
+    # fused types version of BLAS functions
+    cdef DOT dot
+    cdef AXPY axpy
+    cdef ASUM asum
+
+    if floating is float:
+        dtype = np.float32
+        dot = sdot
+        nrm2 = snrm2
+        asum = sasum
+        copy = scopy
+        scal = sscal
+        ger = sger
+        gemv = sgemv
+    else:
+        dtype = np.float64
+        dot = ddot
+        nrm2 = dnrm2
+        asum = dasum
+        copy = dcopy
+        scal = dscal
+        ger = dger
+        gemv = dgemv
+
     # get the data information into easy vars
     cdef unsigned int n_samples = X.shape[0]
     cdef unsigned int n_features = X.shape[1]
     cdef unsigned int n_tasks = Y.shape[1]

     # to store XtA
-    cdef double[:, ::1] XtA = np.zeros((n_features, n_tasks))
-    cdef double XtA_axis1norm
-    cdef double dual_norm_XtA
+    cdef floating[:, ::1] XtA = np.zeros((n_features, n_tasks), dtype=dtype)
+    cdef floating XtA_axis1norm
+    cdef floating dual_norm_XtA

     # initial value of the residuals
-    cdef double[:, ::1] R = np.zeros((n_samples, n_tasks))
-
-    cdef double[:] norm_cols_X = np.zeros(n_features)
-    cdef double[::1] tmp = np.zeros(n_tasks, dtype=np.float)
-    cdef double[:] w_ii = np.zeros(n_tasks, dtype=np.float)
-    cdef double d_w_max
-    cdef double w_max
-    cdef double d_w_ii
-    cdef double nn
-    cdef double W_ii_abs_max
-    cdef double gap = tol + 1.0
-    cdef double d_w_tol = tol
-    cdef double ry_sum
-    cdef double l21_norm
+    cdef floating[:, ::1] R = np.zeros((n_samples, n_tasks), dtype=dtype)
+
+    cdef floating[:] norm_cols_X = np.zeros(n_features, dtype=dtype)
+    cdef floating[::1] tmp = np.zeros(n_tasks, dtype=dtype)
+    cdef floating[:] w_ii = np.zeros(n_tasks, dtype=dtype)
+    cdef floating d_w_max
+    cdef floating w_max
+    cdef floating d_w_ii
+    cdef floating nn
+    cdef floating W_ii_abs_max
+    cdef floating gap = tol + 1.0
+    cdef floating d_w_tol = tol
+    cdef floating R_norm
+    cdef floating w_norm
+    cdef floating ry_sum
+    cdef floating l21_norm
     cdef unsigned int ii
     cdef unsigned int jj
     cdef unsigned int n_iter = 0
     cdef unsigned int f_iter
     cdef UINT32_t rand_r_state_seed = rng.randint(0, RAND_R_MAX)
     cdef UINT32_t* rand_r_state = &rand_r_state_seed

-    cdef double* X_ptr = &X[0, 0]
-    cdef double* W_ptr = &W[0, 0]
-    cdef double* Y_ptr = &Y[0, 0]
-    cdef double* wii_ptr = &w_ii[0]
+    cdef floating* X_ptr = &X[0, 0]
+    cdef floating* W_ptr = &W[0, 0]
+    cdef floating* Y_ptr = &Y[0, 0]
+    cdef floating* wii_ptr = &w_ii[0]

     if l1_reg == 0:
         warnings.warn("Coordinate descent with l1_reg=0 may lead to unexpected"
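For reference, the objective quoted in the docstring at the top of this hunk is unchanged by the precision work. Writing it out, with W assumed to hold one row per task and one column per feature (as the `W_ptr + ii * n_tasks` strides used below suggest):

$$
\frac{1}{2}\,\lVert Y - X W^{\top} \rVert_{F}^{2}
\;+\; \texttt{l1\_reg}\,\lVert W \rVert_{21}
\;+\; \frac{\texttt{l2\_reg}}{2}\,\lVert W \rVert_{F}^{2},
\qquad
\lVert W \rVert_{21} \;=\; \sum_{i=1}^{n_{\text{features}}} \sqrt{\sum_{t=1}^{n_{\text{tasks}}} W_{t i}^{2}} .
$$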
@@ -751,11 +788,11 @@ def enet_coordinate_descent_multi_task(double[::1, :] W, double l1_reg,
         for ii in range(n_samples):
             for jj in range(n_tasks):
                 R[ii, jj] = Y[ii, jj] - (
-                    ddot(n_features, X_ptr + ii, n_samples, W_ptr + jj, n_tasks)
+                    dot(n_features, X_ptr + ii, n_samples, W_ptr + jj, n_tasks)
                 )

         # tol = tol * linalg.norm(Y, ord='fro') ** 2
-        tol = tol * dnrm2(n_samples * n_tasks, Y_ptr, 1) ** 2
+        tol = tol * nrm2(n_samples * n_tasks, Y_ptr, 1) ** 2

         for n_iter in range(max_iter):
             w_max = 0.0
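In NumPy terms, the two statements touched in this hunk amount to the following (a sketch only, reusing the function's variable names; only the BLAS entry points changed, not the arithmetic):

    # illustration only -- residual initialisation and tolerance rescaling
    R = Y - X @ W.T
    tol = tol * np.linalg.norm(Y, ord='fro') ** 2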
@@ -770,33 +807,33 @@ def enet_coordinate_descent_multi_task(double[::1, :] W, double l1_reg,
                     continue

                 # w_ii = W[:, ii] # Store previous value
-                dcopy(n_tasks, W_ptr + ii * n_tasks, 1, wii_ptr, 1)
+                copy(n_tasks, W_ptr + ii * n_tasks, 1, wii_ptr, 1)

                 # if np.sum(w_ii ** 2) != 0.0:  # can do better
-                if dnrm2(n_tasks, wii_ptr, 1) != 0.0:
+                if nrm2(n_tasks, wii_ptr, 1) != 0.0:
                     # R += np.dot(X[:, ii][:, None], w_ii[None, :]) # rank 1 update
-                    dger(CblasRowMajor, n_samples, n_tasks, 1.0,
+                    ger(CblasRowMajor, n_samples, n_tasks, 1.0,
                          X_ptr + ii * n_samples, 1,
                          wii_ptr, 1, &R[0, 0], n_tasks)

                 # tmp = np.dot(X[:, ii][None, :], R).ravel()
-                dgemv(CblasRowMajor, CblasTrans,
+                gemv(CblasRowMajor, CblasTrans,
                       n_samples, n_tasks, 1.0, &R[0, 0], n_tasks,
                       X_ptr + ii * n_samples, 1, 0.0, &tmp[0], 1)

                 # nn = sqrt(np.sum(tmp ** 2))
-                nn = dnrm2(n_tasks, &tmp[0], 1)
+                nn = nrm2(n_tasks, &tmp[0], 1)

                 # W[:, ii] = tmp * fmax(1. - l1_reg / nn, 0) / (norm_cols_X[ii] + l2_reg)
-                dcopy(n_tasks, &tmp[0], 1, W_ptr + ii * n_tasks, 1)
-                dscal(n_tasks, fmax(1. - l1_reg / nn, 0) / (norm_cols_X[ii] + l2_reg),
+                copy(n_tasks, &tmp[0], 1, W_ptr + ii * n_tasks, 1)
+                scal(n_tasks, fmax(1. - l1_reg / nn, 0) / (norm_cols_X[ii] + l2_reg),
                       W_ptr + ii * n_tasks, 1)

                 # if np.sum(W[:, ii] ** 2) != 0.0:  # can do better
-                if dnrm2(n_tasks, W_ptr + ii * n_tasks, 1) != 0.0:
+                if nrm2(n_tasks, W_ptr + ii * n_tasks, 1) != 0.0:
                     # R -= np.dot(X[:, ii][:, None], W[:, ii][None, :])
                     # Update residual : rank 1 update
-                    dger(CblasRowMajor, n_samples, n_tasks, -1.0,
+                    ger(CblasRowMajor, n_samples, n_tasks, -1.0,
                          X_ptr + ii * n_samples, 1, W_ptr + ii * n_tasks, 1,
                          &R[0, 0], n_tasks)

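Pulling the commented NumPy equivalents of this hunk together, the per-feature block update that the renamed BLAS calls implement looks roughly like this (sketch only; same variable names as the Cython code, W of shape (n_tasks, n_features), and nn assumed nonzero):

    import numpy as np

    def update_feature_block(ii, X, R, W, w_ii, norm_cols_X, l1_reg, l2_reg):
        # illustration only -- NumPy restatement of the copy/ger/gemv/nrm2/scal calls
        w_ii[:] = W[:, ii]                         # copy: store previous value
        if np.dot(w_ii, w_ii) != 0.0:
            R += np.outer(X[:, ii], w_ii)          # ger: rank-1 update of residual
        tmp = X[:, ii] @ R                         # gemv: X[:, ii].T dotted with R
        nn = np.linalg.norm(tmp)                   # nrm2
        # copy + scal: group soft-thresholding of the whole block W[:, ii]
        W[:, ii] = tmp * max(1.0 - l1_reg / nn, 0.0) / (norm_cols_X[ii] + l2_reg)
        if np.dot(W[:, ii], W[:, ii]) != 0.0:
            R -= np.outer(X[:, ii], W[:, ii])      # ger: put the new block back in

The only change in this hunk is that each d* BLAS call now goes through the fused-type alias chosen at the top of the function, so the same update runs in float32 or float64.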
@@ -818,7 +855,7 @@ def enet_coordinate_descent_multi_task(double[::1, :] W, double l1_reg,
                 # XtA = np.dot(X.T, R) - l2_reg * W.T
                 for ii in range(n_features):
                     for jj in range(n_tasks):
-                        XtA[ii, jj] = ddot(
+                        XtA[ii, jj] = dot(
                             n_samples, X_ptr + ii * n_samples, 1,
                             &R[0, 0] + jj, n_tasks
                             ) - l2_reg * W[jj, ii]
@@ -827,15 +864,15 @@ def enet_coordinate_descent_multi_task(double[::1, :] W, double l1_reg,
                 dual_norm_XtA = 0.0
                 for ii in range(n_features):
                     # np.sqrt(np.sum(XtA ** 2, axis=1))
-                    XtA_axis1norm = dnrm2(n_tasks, &XtA[0, 0] + ii * n_tasks, 1)
+                    XtA_axis1norm = nrm2(n_tasks, &XtA[0, 0] + ii * n_tasks, 1)
                     if XtA_axis1norm > dual_norm_XtA:
                         dual_norm_XtA = XtA_axis1norm

                 # TODO: use squared L2 norm directly
                 # R_norm = linalg.norm(R, ord='fro')
                 # w_norm = linalg.norm(W, ord='fro')
-                R_norm = dnrm2(n_samples * n_tasks, &R[0, 0], 1)
-                w_norm = dnrm2(n_features * n_tasks, W_ptr, 1)
+                R_norm = nrm2(n_samples * n_tasks, &R[0, 0], 1)
+                w_norm = nrm2(n_features * n_tasks, W_ptr, 1)
                 if (dual_norm_XtA > l1_reg):
                     const = l1_reg / dual_norm_XtA
                     A_norm = R_norm * const
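Continuing the NumPy sketch above, the two stopping-criterion hunks restate as (illustration only, same names; Frobenius norms as in the original comments):

    XtA = X.T @ R - l2_reg * W.T                       # shape (n_features, n_tasks)
    dual_norm_XtA = np.max(np.sqrt(np.sum(XtA ** 2, axis=1)))
    R_norm = np.linalg.norm(R, ord='fro')
    w_norm = np.linalg.norm(W, ord='fro')
    if dual_norm_XtA > l1_reg:
        const = l1_reg / dual_norm_XtA                 # rescale R to be dual-feasible
        A_norm = R_norm * const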
@@ -854,7 +891,7 @@ def enet_coordinate_descent_multi_task(double[::1, :] W, double l1_reg,
                 l21_norm = 0.0
                 for ii in range(n_features):
                     # np.sqrt(np.sum(W ** 2, axis=0))
-                    l21_norm += dnrm2(n_tasks, W_ptr + n_tasks * ii, 1)
+                    l21_norm += nrm2(n_tasks, W_ptr + n_tasks * ii, 1)

                 gap += l1_reg * l21_norm - const * ry_sum + \
                       0.5 * l2_reg * (1 + const ** 2) * (w_norm ** 2)
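The terms accumulated into the gap here correspond to the following (assuming, from its name, that ry_sum is the elementwise inner product of R and Y, computed outside this hunk):

$$
\text{gap} \mathrel{+}= \texttt{l1\_reg}\,\lVert W\rVert_{21}
\;-\; \text{const}\,\langle R, Y\rangle
\;+\; \tfrac{1}{2}\,\texttt{l2\_reg}\,\bigl(1+\text{const}^{2}\bigr)\,\lVert W\rVert_{F}^{2} .
$$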