Merge pull request #6932 from jakirkham/opt_dot_trans · numpy/numpy@25c8d1c · GitHub
[go: up one dir, main page]

Skip to content

Commit 25c8d1c

Browse files
committed
Merge pull request #6932 from jakirkham/opt_dot_trans
ENH: Use `syrk` to compute certain dot products more quickly and accurately
2 parents a7377d8 + 8d8a74d commit 25c8d1c

File tree

3 files changed

+118
-1
lines changed

3 files changed

+118
-1
lines changed

benchmarks/benchmarks/bench_linalg.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
class Eindot(Benchmark):
99
def setup(self):
1010
self.a = np.arange(60000.0).reshape(150, 400)
11+
self.at = self.a.T
12+
self.atc = self.a.T.copy()
1113
self.b = np.arange(240000.0).reshape(400, 600)
1214
self.c = np.arange(600)
1315
self.d = np.arange(400)
@@ -21,6 +23,18 @@ def time_einsum_ij_jk_a_b(self):
2123
def time_dot_a_b(self):
2224
np.dot(self.a, self.b)  # product of two distinct arrays: a is (150, 400), b is (400, 600) per setup
2325

26+
def time_dot_trans_a_at(self):
27+
np.dot(self.a, self.at)  # a times its transpose; self.at is self.a.T taken without a copy in setup
28+
29+
def time_dot_trans_a_atc(self):
30+
np.dot(self.a, self.atc)  # a times a copied transpose; self.atc is self.a.T.copy() in setup
31+
32+
def time_dot_trans_at_a(self):
33+
np.dot(self.at, self.a)  # transpose times a; self.at is self.a.T taken without a copy in setup
34+
35+
def time_dot_trans_atc_a(self):
36+
np.dot(self.atc, self.a)  # copied transpose times a; self.atc is self.a.T.copy() in setup
37+
2438
def time_einsum_i_ij_j(self):
2539
np.einsum('i,ij,j', self.d, self.b, self.c)
2640

doc/release/1.11.0-notes.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,14 @@ useless computations when printing a masked array.
130130
The function now uses the fallocate system call to reserve sufficient
131131
disk space on filesystems that support it.
132132

133+
``np.dot`` optimized for operations of the form ``A.T @ A`` and ``A @ A.T``
134+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
135+
Previously, ``gemm`` BLAS operations were used for all matrix products. Now,
136+
if the matrix product is between a matrix and its transpose, it will use
137+
``syrk`` BLAS operations for a performance boost.
138+
139+
**Note:** Requires the transposed and non-transposed matrices to share data, e.g. ``np.dot(a, a.T)``; if one operand is a copy, such as ``a.T.copy()``, ``gemm`` is still used.
140+
133141
Changes
134142
=======
135143

numpy/core/src/multiarray/cblasfuncs.c

Lines changed: 96 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,74 @@ gemv(int typenum, enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE trans,
111111
}
112112

113113

114+
/*
115+
* Helper: dispatch to appropriate cblas_?syrk for typenum (NPY_DOUBLE, NPY_FLOAT, NPY_CDOUBLE or NPY_CFLOAT), then copy every element above the diagonal of R to its mirrored position below the diagonal. The switch has no default case, so any other typenum leaves R untouched.
116+
*/
117+
static void
118+
syrk(int typenum, enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE trans,
119+
int n, int k, /* forwarded to cblas_?syrk; n also bounds the mirroring loops over R */
120+
PyArrayObject *A, int lda, PyArrayObject *R)
121+
{
122+
const void *Adata = PyArray_DATA(A);
123+
void *Rdata = PyArray_DATA(R);
124+
int ldc = PyArray_DIM(R, 1) > 1 ? PyArray_DIM(R, 1) : 1; /* second dimension of R, clamped to at least 1 */
125+
126+
npy_intp i;
127+
npy_intp j;
128+
129+
switch (typenum) {
130+
case NPY_DOUBLE:
131+
cblas_dsyrk(order, CblasUpper, trans, n, k, 1., /* CblasUpper requested; scalar 1. applies to the A term */
132+
Adata, lda, 0., Rdata, ldc); /* scalar 0. applies to the existing contents of R */
133+
134+
for (i = 0; i < n; i++) /* mirror: R[j][i] = R[i][j] for every j > i */
135+
{
136+
for (j = i + 1; j < n; j++)
137+
{
138+
*((npy_double*)PyArray_GETPTR2(R, j, i)) = *((npy_double*)PyArray_GETPTR2(R, i, j));
139+
}
140+
}
141+
break;
142+
case NPY_FLOAT:
143+
cblas_ssyrk(order, CblasUpper, trans, n, k, 1.f, /* single-precision variant of the NPY_DOUBLE case */
144+
Adata, lda, 0.f, Rdata, ldc);
145+
146+
for (i = 0; i < n; i++) /* mirror: R[j][i] = R[i][j] for every j > i */
147+
{
148+
for (j = i + 1; j < n; j++)
149+
{
150+
*((npy_float*)PyArray_GETPTR2(R, j, i)) = *((npy_float*)PyArray_GETPTR2(R, i, j));
151+
}
152+
}
153+
break;
154+
case NPY_CDOUBLE:
155+
cblas_zsyrk(order, CblasUpper, trans, n, k, oneD, /* oneD/zeroD are defined outside this hunk; presumably complex 1 and 0 - TODO confirm */
156+
Adata, lda, zeroD, Rdata, ldc);
157+
158+
for (i = 0; i < n; i++) /* mirror: plain copy R[j][i] = R[i][j], no conjugation */
159+
{
160+
for (j = i + 1; j < n; j++)
161+
{
162+
*((npy_cdouble*)PyArray_GETPTR2(R, j, i)) = *((npy_cdouble*)PyArray_GETPTR2(R, i, j));
163+
}
164+
}
165+
break;
166+
case NPY_CFLOAT:
167+
cblas_csyrk(order, CblasUpper, trans, n, k, oneF, /* oneF/zeroF are defined outside this hunk; presumably complex 1 and 0 - TODO confirm */
168+
Adata, lda, zeroF, Rdata, ldc);
169+
170+
for (i = 0; i < n; i++) /* mirror: plain copy R[j][i] = R[i][j], no conjugation */
171+
{
172+
for (j = i + 1; j < n; j++)
173+
{
174+
*((npy_cfloat*)PyArray_GETPTR2(R, j, i)) = *((npy_cfloat*)PyArray_GETPTR2(R, i, j));
175+
}
176+
}
177+
break;
178+
}
179+
}
180+
181+
114182
typedef enum {_scalar, _column, _row, _matrix} MatrixShape;
115183

116184

@@ -647,7 +715,34 @@ cblas_matrixproduct(int typenum, PyArrayObject *ap1, PyArrayObject *ap2,
647715
Trans2 = CblasTrans;
648716
ldb = (PyArray_DIM(ap2, 0) > 1 ? PyArray_DIM(ap2, 0) : 1);
649717
}
650-
gemm(typenum, Order, Trans1, Trans2, L, N, M, ap1, lda, ap2, ldb, ret);
718+
719+
/*
720+
* Use syrk if we have a case of a matrix times its transpose.
721+
* Otherwise, use gemm for all other cases.
722+
*/
723+
if (
724+
(PyArray_BYTES(ap1) == PyArray_BYTES(ap2)) &&
725+
(PyArray_DIM(ap1, 0) == PyArray_DIM(ap2, 1)) &&
726+
(PyArray_DIM(ap1, 1) == PyArray_DIM(ap2, 0)) &&
727+
(PyArray_STRIDE(ap1, 0) == PyArray_STRIDE(ap2, 1)) &&
728+
(PyArray_STRIDE(ap1, 1) == PyArray_STRIDE(ap2, 0)) &&
729+
((Trans1 == CblasTrans) ^ (Trans2 == CblasTrans)) &&
730+
((Trans1 == CblasNoTrans) ^ (Trans2 == CblasNoTrans))
731+
)
732+
{
733+
if (Trans1 == CblasNoTrans)
734+
{
735+
syrk(typenum, Order, Trans1, N, M, ap1, lda, ret);
736+
}
737+
else
738+
{
739+
syrk(typenum, Order, Trans1, N, M, ap2, ldb, ret);
740+
}
741+
}
742+
else
743+
{
744+
gemm(typenum, Order, Trans1, Trans2, L, N, M, ap1, lda, ap2, ldb, ret);
745+
}
651746
NPY_END_ALLOW_THREADS;
652747
}
653748

0 commit comments

Comments
 (0)
0