@@ -111,6 +111,74 @@ gemv(int typenum, enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE trans,
111
111
}
112
112
113
113
114
+ /*
115
+ * Helper: dispatch to appropriate cblas_?syrk for typenum.
116
+ */
117
+ static void
118
+ syrk (int typenum , enum CBLAS_ORDER order , enum CBLAS_TRANSPOSE trans ,
119
+ int n , int k ,
120
+ PyArrayObject * A , int lda , PyArrayObject * R )
121
+ {
122
+ const void * Adata = PyArray_DATA (A );
123
+ void * Rdata = PyArray_DATA (R );
124
+ int ldc = PyArray_DIM (R , 1 ) > 1 ? PyArray_DIM (R , 1 ) : 1 ;
125
+
126
+ npy_intp i ;
127
+ npy_intp j ;
128
+
129
+ switch (typenum ) {
130
+ case NPY_DOUBLE :
131
+ cblas_dsyrk (order , CblasUpper , trans , n , k , 1. ,
132
+ Adata , lda , 0. , Rdata , ldc );
133
+
134
+ for (i = 0 ; i < n ; i ++ )
135
+ {
136
+ for (j = i + 1 ; j < n ; j ++ )
137
+ {
138
+ * ((npy_double * )PyArray_GETPTR2 (R , j , i )) = * ((npy_double * )PyArray_GETPTR2 (R , i , j ));
139
+ }
140
+ }
141
+ break ;
142
+ case NPY_FLOAT :
143
+ cblas_ssyrk (order , CblasUpper , trans , n , k , 1.f ,
144
+ Adata , lda , 0.f , Rdata , ldc );
145
+
146
+ for (i = 0 ; i < n ; i ++ )
147
+ {
148
+ for (j = i + 1 ; j < n ; j ++ )
149
+ {
150
+ * ((npy_float * )PyArray_GETPTR2 (R , j , i )) = * ((npy_float * )PyArray_GETPTR2 (R , i , j ));
151
+ }
152
+ }
153
+ break ;
154
+ case NPY_CDOUBLE :
155
+ cblas_zsyrk (order , CblasUpper , trans , n , k , oneD ,
156
+ Adata , lda , zeroD , Rdata , ldc );
157
+
158
+ for (i = 0 ; i < n ; i ++ )
159
+ {
160
+ for (j = i + 1 ; j < n ; j ++ )
161
+ {
162
+ * ((npy_cdouble * )PyArray_GETPTR2 (R , j , i )) = * ((npy_cdouble * )PyArray_GETPTR2 (R , i , j ));
163
+ }
164
+ }
165
+ break ;
166
+ case NPY_CFLOAT :
167
+ cblas_csyrk (order , CblasUpper , trans , n , k , oneF ,
168
+ Adata , lda , zeroF , Rdata , ldc );
169
+
170
+ for (i = 0 ; i < n ; i ++ )
171
+ {
172
+ for (j = i + 1 ; j < n ; j ++ )
173
+ {
174
+ * ((npy_cfloat * )PyArray_GETPTR2 (R , j , i )) = * ((npy_cfloat * )PyArray_GETPTR2 (R , i , j ));
175
+ }
176
+ }
177
+ break ;
178
+ }
179
+ }
180
+
181
+
114
182
typedef enum {_scalar , _column , _row , _matrix } MatrixShape ;
115
183
116
184
@@ -647,7 +715,34 @@ cblas_matrixproduct(int typenum, PyArrayObject *ap1, PyArrayObject *ap2,
647
715
Trans2 = CblasTrans ;
648
716
ldb = (PyArray_DIM (ap2 , 0 ) > 1 ? PyArray_DIM (ap2 , 0 ) : 1 );
649
717
}
650
- gemm (typenum , Order , Trans1 , Trans2 , L , N , M , ap1 , lda , ap2 , ldb , ret );
718
+
719
+ /*
720
+ * Use syrk if we have a case of a matrix times its transpose.
721
+ * Otherwise, use gemm for all other cases.
722
+ */
723
+ if (
724
+ (PyArray_BYTES (ap1 ) == PyArray_BYTES (ap2 )) &&
725
+ (PyArray_DIM (ap1 , 0 ) == PyArray_DIM (ap2 , 1 )) &&
726
+ (PyArray_DIM (ap1 , 1 ) == PyArray_DIM (ap2 , 0 )) &&
727
+ (PyArray_STRIDE (ap1 , 0 ) == PyArray_STRIDE (ap2 , 1 )) &&
728
+ (PyArray_STRIDE (ap1 , 1 ) == PyArray_STRIDE (ap2 , 0 )) &&
729
+ ((Trans1 == CblasTrans ) ^ (Trans2 == CblasTrans )) &&
730
+ ((Trans1 == CblasNoTrans ) ^ (Trans2 == CblasNoTrans ))
731
+ )
732
+ {
733
+ if (Trans1 == CblasNoTrans )
734
+ {
735
+ syrk (typenum , Order , Trans1 , N , M , ap1 , lda , ret );
736
+ }
737
+ else
738
+ {
739
+ syrk (typenum , Order , Trans1 , N , M , ap2 , ldb , ret );
740
+ }
741
+ }
742
+ else
743
+ {
744
+ gemm (typenum , Order , Trans1 , Trans2 , L , N , M , ap1 , lda , ap2 , ldb , ret );
745
+ }
651
746
NPY_END_ALLOW_THREADS ;
652
747
}
653
748
0 commit comments