8000 Merge pull request #17994 from Qiyu8/einsum-dot · numpy/numpy@12d99b5 · GitHub
[go: up one dir, main page]

Skip to content

Commit 12d99b5

Browse files
authored
Merge pull request #17994 from Qiyu8/einsum-dot
SIMD: Optimize the performance of einsum's submodule dot .
2 parents e4feb70 + 4c7b3d6 commit 12d99b5

File tree

1 file changed

+45
-141
lines changed

1 file changed

+45
-141
lines changed

numpy/core/src/multiarray/einsum_sumprod.c.src

Lines changed: 45 additions & 141 deletions
Original file line numberDiff line numberDiff line change
@@ -589,164 +589,68 @@ finish_after_unrolled_loop:
589589
goto finish_after_unrolled_loop;
590590
}
591591

592-
static void
592+
static NPY_GCC_OPT_3 void
593593
@name@_sum_of_products_contig_contig_outstride0_two(int nop, char **dataptr,
594594
npy_intp const *NPY_UNUSED(strides), npy_intp count)
595595
{
596596
@type@ *data0 = (@type@ *)dataptr[0];
597597
@type@ *data1 = (@type@ *)dataptr[1];
598598
@temptype@ accum = 0;
599599

600-
#if EINSUM_USE_SSE1 && @float32@
601-
__m128 a, accum_sse = _mm_setzero_ps();
602-
#elif EINSUM_USE_SSE2 && @float64@
603-
__m128d a, accum_sse = _mm_setzero_pd();
604-
#endif
605-
606600
NPY_EINSUM_DBG_PRINT1("@name@_sum_of_products_contig_contig_outstride0_two (%d)\n",
607601
(int)count);
608-
609-
/* This is placed before the main loop to make small counts faster */
610-
finish_after_unrolled_loop:
611-
switch (count) {
612-
/**begin repeat2
613-
* #i = 6, 5, 4, 3, 2, 1, 0#
614-
*/
615-
case @i@+1:
616-
accum += @from@(data0[@i@]) * @from@(data1[@i@]);
617-
/**end repeat2**/
618-
case 0:
619-
*(@type@ *)dataptr[2] = @to@(@from@(*(@type@ *)dataptr[2]) + accum);
620-
return;
621-
}
622-
623-
#if EINSUM_USE_SSE1 && @float32@
602+
#if @NPYV_CHK@ // NPYV check for @type@
624603
/* Use aligned instructions if possible */
625-
if (EINSUM_IS_SSE_ALIGNED(data0) && EINSUM_IS_SSE_ALIGNED(data1)) {
626-
/* Unroll the 10000 loop by 8 */
627-
while (count >= 8) {
628-
count -= 8;
629-
630-
_mm_prefetch(data0 + 512, _MM_HINT_T0);
631-
_mm_prefetch(data1 + 512, _MM_HINT_T0);
604+
const int is_aligned = EINSUM_IS_ALIGNED(data0) && EINSUM_IS_ALIGNED(data1);
605+
const int vstep = npyv_nlanes_@sfx@;
606+
npyv_@sfx@ vaccum = npyv_zero_@sfx@();
632607

633-
/**begin repeat2
634-
* #i = 0, 4#
635-
*/
636-
/*
637-
* NOTE: This accumulation changes the order, so will likely
638-
* produce slightly different results.
608+
/**begin repeat2
609+
* #cond = if(is_aligned), else#
610+
* #ld = loada, load#
611+
* #st = storea, store#
612+
*/
613+
@cond@ {
614+
const npy_intp vstepx4 = vstep * 4;
615+
for (; count >= vstepx4; count -= vstepx4, data0 += vstepx4, data1 += vstepx4) {
616+
/**begin repeat3
617+
* #i = 0, 1, 2, 3#
639618
*/
640-
a = _mm_mul_ps(_mm_load_ps(data0+@i@), _mm_load_ps(data1+@i@));
641-
accum_sse = _mm_add_ps(accum_sse, a);
642-
/**end repeat2**/
643-
data0 += 8;
644-
data1 += 8;
619+
npyv_@sfx@ a@i@ = npyv_@ld@_@sfx@(data0 + vstep * @i@);
620+
npyv_@sfx@ b@i@ = npyv_@ld@_@sfx@(data1 + vstep * @i@);
621+
/**end repeat3**/
622+
npyv_@sfx@ ab3 = npyv_muladd_@sfx@(a3, b3, vaccum);
623+
npyv_@sfx@ ab2 = npyv_muladd_@sfx@(a2, b2, ab3);
624+
npyv_@sfx@ ab1 = npyv_muladd_@sfx@(a1, b1, ab2);
625+
vaccum = npyv_muladd_@sfx@(a0, b0, ab1);
645626
}
646-
647-
/* Add the four SSE values and put in accum */
648-
a = _mm_shuffle_ps(accum_sse, accum_sse, _MM_SHUFFLE(2,3,0,1));
649-
accum_sse = _mm_add_ps(a, accum_sse);
650-
a = _mm_shuffle_ps(accum_sse, accum_sse, _MM_SHUFFLE(1,0,3,2));
651-
accum_sse = _mm_add_ps(a, accum_sse);
652-
_mm_store_ss(&accum, accum_sse);
653-
654-
/* Finish off the loop */
655-
goto finish_after_unrolled_loop;
656627
}
657-
#elif EINSUM_USE_SSE2 && @float64@
658-
/* Use aligned instructions if possible */
659-
if (EINSUM_IS_SSE_ALIGNED(data0) && EINSUM_IS_SSE_ALIGNED(data1)) {
660-
/* Unroll the loop by 8 */
661-
while (count >= 8) {
662-
count -= 8;
663-
664-
_mm_prefetch(data0 + 512, _MM_HINT_T0);
665-
_mm_prefetch(data1 + 512, _MM_HINT_T0);
666-
667-
/**begin repeat2
668-
* #i = 0, 2, 4, 6#
669-
*/
670-
/*
671-
* NOTE: This accumulation changes the order, so will likely
672-
* produce slightly different results.
673-
*/
674-
a = _mm_mul_pd(_mm_load_pd(data0+@i@), _mm_load_pd(data1+@i@));
675-
accum_sse = _mm_add_pd(accum_sse, a);
676-
/**end repeat2**/
677-
data0 += 8;
678-
data1 += 8;
679-
}
680-
681-
/* Add the two SSE2 values and put in accum */
682-
a = _mm_shuffle_pd(accum_sse, accum_sse, _MM_SHUFFLE2(0,1));
683-
accum_sse = _mm_add_pd(a, accum_sse);
684-
_mm_store_sd(&accum, accum_sse);
685-
686-
/* Finish off the loop */
687-
goto finish_after_unrolled_loop;
628+
/**end repeat2**/
629+
for (; count > 0; count -= vstep, data0 += vstep, data1 += vstep) {
630+
npyv_@sfx@ a = npyv_load_tillz_@sfx@(data0, count);
631+
npyv_@sfx@ b = npyv_load_tillz_@sfx@(data1, count);
632+
vaccum = npyv_muladd_@sfx@(a, b, vaccum);
688633
}
689-
#endif
690-
691-
/* Unroll the loop by 8 */
692-
while (count >= 8) {
693-
count -= 8;
694-
695-
#if EINSUM_USE_SSE1 && @float32@
696-
_mm_prefetch(data0 + 512, _MM_HINT_T0);
697-
_mm_prefetch(data1 + 512, _MM_HINT_T0);
698-
699-
/**begin repeat2
700-
* #i = 0, 4#
701-
*/
702-
/*
703-
* NOTE: This accumulation changes the order, so will likely
704-
* produce slightly different results.
705-
*/
706-
a = _mm_mul_ps(_mm_loadu_ps(data0+@i@), _mm_loadu_ps(data1+@i@));
707-
accum_sse = _mm_add_ps(accum_sse, a);
708-
/**end repeat2**/
709-
#elif EINSUM_USE_SSE2 && @float64@
710-
_mm_prefetch(data0 + 512, _MM_HINT_T0);
711-
_mm_prefetch(data1 + 512, _MM_HINT_T0);
712-
713-
/**begin repeat2
714-
* #i = 0, 2, 4, 6#
715-
*/
716-
/*
717-
* NOTE: This accumulation changes the order, so will likely
718-
* produce slightly different results.
719-
*/
720-
a = _mm_mul_pd(_mm_loadu_pd(data0+@i@), _mm_loadu_pd(data1+@i@));
721-
accum_sse = _mm_add_pd(accum_sse, a);
722-
/**end repeat2**/
634+
accum = npyv_sum_@sfx@(vaccum);
635+
npyv_cleanup();
723636
#else
724-
/**begin repeat2
725-
* #i = 0, 1, 2, 3, 4, 5, 6, 7#
726-
*/
727-
accum += @from@(data0[@i@]) * @from@(data1[@i@]);
728-
/**end repeat2**/
729-
#endif
730-
data0 += 8;
731-
data1 += 8;
637+
#ifndef NPY_DISABLE_OPTIMIZATION
638+
for (; count >= 4; count -= 4, data0 += 4, data1 += 4) {
639+
/**begin repeat2
640+
* #i = 0, 1, 2, 3#
641+
*/
642+
const @type@ ab@i@ = @from@(data0[@i@]) * @from@(data1[@i@]);
643+
/**end repeat2**/
644+
accum += ab0 + ab1 + ab2 + ab3;
732645
}
733-
734-
#if EINSUM_USE_SSE1 && @float32@
735-
/* Add the four SSE values and put in accum */
736-
a = _mm_shuffle_ps(accum_sse, accum_sse, _MM_SHUFFLE(2,3,0,1));
737-
accum_sse = _mm_add_ps(a, accum_sse);
738-
a = _mm_shuffle_ps(accum_sse, accum_sse, _MM_SHUFFLE(1,0,3,2));
739-
accum_sse = _mm_add_ps(a, accum_sse);
740-
_mm_store_ss(&accum, accum_sse);
741-
#elif EINSUM_USE_SSE2 && @float64@
742-
/* Add the two SSE2 values and put in accum */
743-
a = _mm_shuffle_pd(accum_sse, accum_sse, _MM_SHUFFLE2(0,1));
744-
accum_sse = _mm_add_pd(a, accum_sse);
745-
_mm_store_sd(&accum, accum_sse);
746-
#endif
747-
748-
/* Finish off the loop */
749-
goto finish_after_unrolled_loop;
646+
#endif // !NPY_DISABLE_OPTIMIZATION
647+
for (; count > 0; --count, ++data0, ++data1) {
648+
const @type@ a = @from@(*data0);
649+
const @type@ b = @from@(*data1);
650+
accum += a * b;
651+
}
652+
#endif // NPYV check for @type@
653+
*(@type@ *)dataptr[2] = @to@(@from@(*(@type@ *)dataptr[2]) + accum);
750654
}
751655

752656
static void

0 commit comments

Comments
 (0)
0