ENH: einsum: Specialize contiguous reduction, add SSE prefetching · numpy/numpy@260824f · GitHub
[go: up one dir, main page]

Skip to content

Commit 260824f

Browse files
committed
ENH: einsum: Specialize contiguous reduction, add SSE prefetching
Also fix some compiler warnings. The biggest performance improvement was from adding SSE prefetching.
1 parent 8598315 commit 260824f

File tree

1 file changed

+239
-6
lines changed

1 file changed

+239
-6
lines changed

numpy/core/src/multiarray/einsum.c.src

Lines changed: 239 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -169,9 +169,9 @@ static void
169169
#define _SUMPROD_NOP nop
170170
# endif
171171
npy_@temp@ re, im, tmp;
172+
int i;
172173
re = ((npy_@temp@ *)dataptr[0])[0];
173174
im = ((npy_@temp@ *)dataptr[0])[1];
174-
int i;
175175
for (i = 1; i < _SUMPROD_NOP; ++i) {
176176
tmp = re * ((npy_@temp@ *)dataptr[i])[0] -
177177
im * ((npy_@temp@ *)dataptr[i])[1];
@@ -202,7 +202,8 @@ static void
202202
npy_@name@ *data0 = (npy_@name@ *)dataptr[0];
203203
npy_@name@ *data_out = (npy_@name@ *)dataptr[1];
204204

205-
NPY_EINSUM_DBG_PRINTF("@name@_sum_of_products_contig_one (%d)\n", (int)count);
205+
NPY_EINSUM_DBG_PRINTF("@name@_sum_of_products_contig_one (%d)\n",
206+
(int)count);
206207

207208
/* This is placed before the main loop to make small counts faster */
208209
finish_after_unrolled_loop:
@@ -268,7 +269,8 @@ static void
268269
__m128 a, b;
269270
#endif
270271

271-
NPY_EINSUM_DBG_PRINTF("@name@_sum_of_products_contig_two (%d)\n", (int)count);
272+
NPY_EINSUM_DBG_PRINTF("@name@_sum_of_products_contig_two (%d)\n",
273+
(int)count);
272274

273275
/* This is placed before the main loop to make small counts faster */
274276
finish_after_unrolled_loop:
@@ -592,6 +594,9 @@ finish_after_unrolled_loop:
592594
while (count >= 8) {
593595
count -= 8;
594596

597+
_mm_prefetch(data0 + 512, _MM_HINT_T0);
598+
_mm_prefetch(data1 + 512, _MM_HINT_T0);
599+
595600
/**begin repeat2
596601
* #i = 0, 4#
597602
*/
@@ -623,6 +628,9 @@ finish_after_unrolled_loop:
623628
while (count >= 8) {
624629
count -= 8;
625630

631+
_mm_prefetch(data0 + 512, _MM_HINT_T0);
632+
_mm_prefetch(data1 + 512, _MM_HINT_T0);
633+
626634
/**begin repeat2
627635
* #i = 0, 2, 4, 6#
628636
*/
@@ -652,6 +660,9 @@ finish_after_unrolled_loop:
652660
count -= 8;
653661

654662
#if EINSUM_USE_SSE1 && @float32@
663+
_mm_prefetch(data0 + 512, _MM_HINT_T0);
664+
_mm_prefetch(data1 + 512, _MM_HINT_T0);
665+
655666
/**begin repeat2
656667
* #i = 0, 4#
657668
*/
@@ -663,6 +674,9 @@ finish_after_unrolled_loop:
663674
accum_sse = _mm_add_ps(accum_sse, a);
664675
/**end repeat2**/
665676
#elif EINSUM_USE_SSE2 && @float64@
677+
_mm_prefetch(data0 + 512, _MM_HINT_T0);
678+
_mm_prefetch(data1 + 512, _MM_HINT_T0);
679+
666680
/**begin repeat2
667681
* #i = 0, 2, 4, 6#
668682
*/
@@ -943,7 +957,7 @@ static void
943957
/**end repeat2**/
944958
}
945959

946-
#else
960+
#else /* @nop@ > 3 || @complex@ */
947961

948962
static void
949963
@name@_sum_of_products_contig_@noplabel@(int nop, char **dataptr,
@@ -971,9 +985,9 @@ static void
971985
# define _SUMPROD_NOP nop
972986
# endif
973987
npy_@temp@ re, im, tmp;
988+
int i;
974989
re = ((npy_@temp@ *)dataptr[0])[0];
975990
im = ((npy_@temp@ *)dataptr[0])[1];
976-
int i;
977991
for (i = 1; i < _SUMPROD_NOP; ++i) {
978992
tmp = re * ((npy_@temp@ *)dataptr[i])[0] -
979993
im * ((npy_@temp@ *)dataptr[i])[1];
@@ -994,7 +1008,186 @@ static void
9941008
}
9951009
}
9961010

1011+
#endif /* functions for various @nop@ */
1012+
1013+
#if @nop@ == 1
1014+
1015+
/*
 * Reduction of one contiguous input into a single output element:
 * adds up `count` consecutive elements starting at dataptr[0] and
 * accumulates the total into the one element at dataptr[1].
 *
 *   nop     - not referenced in this body.
 *   dataptr - dataptr[0] is the input run, dataptr[1] the output element
 *             (read, added to, and written back).
 *   strides - not referenced in this body.
 *   count   - number of input elements to sum.
 *
 * Control flow: the switch under finish_after_unrolled_loop consumes
 * 0..7 remaining elements, writes the result and returns.  When no case
 * matches, execution continues past the switch into an 8-way unrolled
 * loop (SSE or scalar), which then jumps back to the label so the
 * switch can finish the tail.
 *
 * NOTE(review): the from/to conversion macros used below are defined by
 * the template header outside this view; presumably they convert between
 * the storage type and the accumulator type -- confirm there.
 */
static void
1016+
@name@_sum_of_products_contig_outstride0_one(int nop, char **dataptr,
1017+
npy_intp *strides, npy_intp count)
1018+
{
1019+
/*
 * Complex types keep separate real and imaginary accumulators and view
 * the input as scalars indexed in (re, im) pairs; all other types use a
 * single accumulator.
 */
#if @complex@
1020+
npy_@temp@ accum_re = 0, accum_im = 0;
1021+
npy_@temp@ *data0 = (npy_@temp@ *)dataptr[0];
1022+
#else
1023+
npy_@temp@ accum = 0;
1024+
npy_@name@ *data0 = (npy_@name@ *)dataptr[0];
1025+
#endif
1026+
1027+
/*
 * accum_sse holds per-lane partial sums (zero-initialised); `a` is
 * scratch for the lane shuffles used when collapsing accum_sse.
 */
#if EINSUM_USE_SSE1 && @float32@
1028+
__m128 a, accum_sse = _mm_setzero_ps();
1029+
#elif EINSUM_USE_SSE2 && @float64@
1030+
__m128d a, accum_sse = _mm_setzero_pd();
1031+
#endif
1032+
1033+
1034+
NPY_EINSUM_DBG_PRINTF("@name@_sum_of_products_contig_outstride0_one (%d)\n",
1035+
(int)count);
1036+
1037+
/* This is placed before the main loop to make small counts faster */
1038+
finish_after_unrolled_loop:
1039+
switch (count) {
1040+
/*
 * The generated cases run from 7 down to 1 with no break between them,
 * so entering at `case count` falls through every smaller case and adds
 * data0[count-1] down to data0[0].
 */
/**begin repeat2
1041+
* #i = 6, 5, 4, 3, 2, 1, 0#
1042+
*/
1043+
case @i@+1:
1044+
#if !@complex@
1045+
accum += @from@(data0[@i@]);
1046+
#else /* complex */
1047+
accum_re += data0[2*@i@+0];
1048+
accum_im += data0[2*@i@+1];
1049+
#endif
1050+
/**end repeat2**/
1051+
/*
 * Reached by fall-through, or directly when count == 0: fold the
 * accumulated sum into the output element and finish.
 */
case 0:
1052+
#if @complex@
1053+
((npy_@temp@ *)dataptr[1])[0] += accum_re;
1054+
((npy_@temp@ *)dataptr[1])[1] += accum_im;
1055+
#else
1056+
*((npy_@name@ *)dataptr[1]) = @to@(accum +
1057+
@from@(*((npy_@name@ *)dataptr[1])));
1058+
#endif
1059+
return;
1060+
}
1061+
1062+
/*
 * Only reached when no case above matched.  Every matched case ends in
 * the return, so the scalar accumulators have not been modified by the
 * switch on the way here.
 */
#if EINSUM_USE_SSE1 && @float32@
1063+
/* Use aligned instructions if possible */
1064+
if (EINSUM_IS_SSE_ALIGNED(data0)) {
1065+
/* Unroll the loop by 8 */
1066+
while (count >= 8) {
1067+
count -= 8;
1068+
1069+
/*
 * Prefetch hint for the address 512 elements past data0.
 * NOTE(review): nothing here bounds that address by count -- confirm
 * against the intrinsic's documentation that hinting past the end of
 * the buffer is harmless.
 */
_mm_prefetch(data0 + 512, _MM_HINT_T0);
1070+
1071+
/**begin repeat2
1072+
* #i = 0, 4#
1073+
*/
1074+
/*
1075+
* NOTE: This accumulation changes the order, so will likely
1076+
* produce slightly different results.
1077+
*/
1078+
accum_sse = _mm_add_ps(accum_sse, _mm_load_ps(data0+@i@));
1079+
/**end repeat2**/
1080+
data0 += 8;
1081+
}
1082+
1083+
/* Add the four SSE values and put in accum */
1084+
a = _mm_shuffle_ps(accum_sse, accum_sse, _MM_SHUFFLE(2,3,0,1));
1085+
accum_sse = _mm_add_ps(a, accum_sse);
1086+
a = _mm_shuffle_ps(accum_sse, accum_sse, _MM_SHUFFLE(1,0,3,2));
1087+
accum_sse = _mm_add_ps(a, accum_sse);
1088+
/* Plain store, not an add: whatever accum held is replaced by the vector total. */
_mm_store_ss(&accum, accum_sse);
1089+
1090+
/* Finish off the loop */
1091+
goto finish_after_unrolled_loop;
1092+
}
1093+
#elif EINSUM_USE_SSE2 && @float64@
1094+
/* Use aligned instructions if possible */
1095+
if (EINSUM_IS_SSE_ALIGNED(data0)) {
1096+
/* Unroll the loop by 8 */
1097+
while (count >= 8) {
1098+
count -= 8;
1099+
1100+
/* Same unbounded 512-element look-ahead prefetch hint as the float32 path. */
_mm_prefetch(data0 + 512, _MM_HINT_T0);
1101+
1102+
/**begin repeat2
1103+
* #i = 0, 2, 4, 6#
1104+
*/
1105+
/*
1106+
* NOTE: This accumulation changes the order, so will likely
1107+
* produce slightly different results.
1108+
*/
1109+
accum_sse = _mm_add_pd(accum_sse, _mm_load_pd(data0+@i@));
1110+
/**end repeat2**/
1111+
data0 += 8;
1112+
}
1113+
1114+
/* Add the two SSE2 values and put in accum */
1115+
a = _mm_shuffle_pd(accum_sse, accum_sse, _MM_SHUFFLE2(0,1));
1116+
accum_sse = _mm_add_pd(a, accum_sse);
1117+
/* Plain store, not an add: whatever accum held is replaced by the vector total. */
_mm_store_sd(&accum, accum_sse);
1118+
1119+
/* Finish off the loop */
1120+
goto finish_after_unrolled_loop;
1121+
}
1122+
#endif
1123+
1124+
/*
 * General path, taken when the aligned branch above is compiled out or
 * its alignment test fails: unaligned SSE loads where SSE is enabled
 * for this type, plain scalar adds otherwise.
 */
/* Unroll the loop by 8 */
1125+
while (count >= 8) {
1126+
count -= 8;
1127+
1128+
#if EINSUM_USE_SSE1 && @float32@
1129+
_mm_prefetch(data0 + 512, _MM_HINT_T0);
1130+
1131+
/**begin repeat2
1132+
* #i = 0, 4#
1133+
*/
1134+
/*
1135+
* NOTE: This accumulation changes the order, so will likely
1136+
* produce slightly different results.
1137+
*/
1138+
accum_sse = _mm_add_ps(accum_sse, _mm_loadu_ps(data0+@i@));
1139+
/**end repeat2**/
1140+
#elif EINSUM_USE_SSE2 && @float64@
1141+
_mm_prefetch(data0 + 512, _MM_HINT_T0);
1142+
1143+
/**begin repeat2
1144+
* #i = 0, 2, 4, 6#
1145+
*/
1146+
/*
1147+
* NOTE: This accumulation changes the order, so will likely
1148+
* produce slightly different results.
1149+
*/
1150+
accum_sse = _mm_add_pd(accum_sse, _mm_loadu_pd(data0+@i@));
1151+
/**end repeat2**/
1152+
#else
1153+
/**begin repeat2
1154+
* #i = 0, 1, 2, 3, 4, 5, 6, 7#
1155+
*/
1156+
# if !@complex@
1157+
accum += @from@(data0[@i@]);
1158+
# else /* complex */
1159+
accum_re += data0[2*@i@+0];
1160+
accum_im += data0[2*@i@+1];
1161+
# endif
1162+
/**end repeat2**/
1163+
#endif
1164+
1165+
/* Step past the 8 elements just consumed; complex elements occupy two scalars each. */
#if !@complex@
1166+
data0 += 8;
1167+
#else
1168+
data0 += 8*2;
9971169
#endif
1170+
}
1171+
1172+
/* Collapse the vector partial sums into accum before the scalar tail runs. */
#if EINSUM_USE_SSE1 && @float32@
1173+
/* Add the four SSE values and put in accum */
1174+
a = _mm_shuffle_ps(accum_sse, accum_sse, _MM_SHUFFLE(2,3,0,1));
1175+
accum_sse = _mm_add_ps(a, accum_sse);
1176+
a = _mm_shuffle_ps(accum_sse, accum_sse, _MM_SHUFFLE(1,0,3,2));
1177+
accum_sse = _mm_add_ps(a, accum_sse);
1178+
_mm_store_ss(&accum, accum_sse);
1179+
#elif EINSUM_USE_SSE2 && @float64@
1180+
/* Add the two SSE2 values and put in accum */
1181+
a = _mm_shuffle_pd(accum_sse, accum_sse, _MM_SHUFFLE2(0,1));
1182+
accum_sse = _mm_add_pd(a, accum_sse);
1183+
_mm_store_sd(&accum, accum_sse);
1184+
#endif
1185+
1186+
/* Finish off the loop */
1187+
goto finish_after_unrolled_loop;
1188+
}
1189+
1190+
#endif /* @nop@ == 1 */
9981191

9991192
static void
10001193
@name@_sum_of_products_outstride0_@noplabel@(int nop, char **dataptr,
@@ -1062,9 +1255,9 @@ static void
10621255
#define _SUMPROD_NOP nop
10631256
# endif
10641257
npy_@temp@ re, im, tmp;
1258+
int i;
10651259
re = ((npy_@temp@ *)dataptr[0])[0];
10661260
im = ((npy_@temp@ *)dataptr[0])[1];
1067-
int i;
10681261
for (i = 1; i < _SUMPROD_NOP; ++i) {
10691262
tmp = re * ((npy_@temp@ *)dataptr[i])[0] -
10701263
im * ((npy_@temp@ *)dataptr[i])[1];
@@ -1347,6 +1540,37 @@ bool_sum_of_products_outstride0_@noplabel@(int nop, char **dataptr,
13471540
typedef void (*sum_of_products_fn)(int, char **, npy_intp *, npy_intp);
13481541

13491542
/* These tables need to match up with the type enum */
1543+
/*
 * Lookup table of the contiguous-input, single-output-element reduction
 * kernels (<name>_sum_of_products_contig_outstride0_one), with one slot
 * generated per type named in the repeat block below, in that order.
 *
 * A slot holds the kernel's address when the type's `use` flag is 1 and
 * NULL when it is 0 (bool, object, string, unicode, void, datetime,
 * timedelta).  get_sum_of_products_function indexes this table by
 * type_num for nop == 1 with a contiguous input and zero output stride;
 * on a NULL entry it does not return and carries on with the code that
 * follows that check.
 */
static sum_of_products_fn
1544+
_contig_outstride0_unary_specialization_table[NPY_NTYPES] = {
1545+
/**begin repeat
1546+
* #name = bool,
1547+
* byte, ubyte,
1548+
* short, ushort,
1549+
* int, uint,
1550+
* long, ulong,
1551+
* longlong, ulonglong,
1552+
* float, double, longdouble,
1553+
* cfloat, cdouble, clongdouble,
1554+
* object, string, unicode, void,
1555+
* datetime, timedelta, half#
1556+
* #use = 0,
1557+
* 1, 1,
1558+
* 1, 1,
1559+
* 1, 1,
1560+
* 1, 1,
1561+
* 1, 1,
1562+
* 1, 1, 1,
1563+
* 1, 1, 1,
1564+
* 0, 0, 0, 0,
1565+
* 0, 0, 1#
1566+
*/
1567+
#if @use@
1568+
&@name@_sum_of_products_contig_outstride0_one,
1569+
#else
1570+
NULL,
1571+
#endif
1572+
/**end repeat**/
1573+
}; /* End of _contig_outstride0_unary_specialization_table */
13501574

13511575
static sum_of_products_fn _binary_specialization_table[NPY_NTYPES][5] = {
13521576
/**begin repeat
@@ -1503,6 +1727,15 @@ get_sum_of_products_function(int nop, int type_num,
15031727
return NULL;
15041728
}
15051729

1730+
/* contiguous reduction */
1731+
if (nop == 1 && fixed_strides[0] == itemsize && fixed_strides[1] == 0) {
1732+
sum_of_products_fn ret =
1733+
_contig_outstride0_unary_specialization_table[type_num];
1734+
if (ret != NULL) {
1735+
return ret;
1736+
}
1737+
}
1738+
15061739
/* nop of 2 has more specializations */
15071740
if (nop == 2) {
15081741
/* Encode the zero/contiguous strides */

0 commit comments

Comments
 (0)
0