WIP::ENH:SIMD Improve the performance of comparison operators by seiko2plus · Pull Request #16960 · numpy/numpy · GitHub

Closed
wants to merge 9 commits into from
ENH:NPYV add non-contiguous load/store intrinsics for all vector types
seiko2plus committed Aug 3, 2020
commit 1fa7b8155c9a4ce80d545ff9354c9863f6f4a11d
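For orientation (this sketch is not part of the diff): every npyv_loadn_*/npyv_storen_* intrinsic below reads or writes lane i at ptr[stride * i], with the stride measured in elements rather than bytes. A minimal scalar reference of that contract for the 32-bit case, using hypothetical helper names and plain stdint.h types:

#include <stdint.h>
#include <stddef.h>

// Hypothetical scalar reference: lane i maps to ptr[stride * i];
// nlanes is the lane count of the vector register being emulated.
static void loadn_u32_scalar_ref(uint32_t *dst, const uint32_t *ptr,
                                 ptrdiff_t stride, int nlanes)
{
    for (int i = 0; i < nlanes; ++i) {
        dst[i] = ptr[stride * i];      // strided gather, element-sized step
    }
}
static void storen_u32_scalar_ref(uint32_t *ptr, ptrdiff_t stride,
                                  const uint32_t *src, int nlanes)
{
    for (int i = 0; i < nlanes; ++i) {
        ptr[stride * i] = src[i];      // strided scatter, element-sized step
    }
}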
143 changes: 143 additions & 0 deletions numpy/core/src/common/simd/avx2/memory.h
@@ -67,4 +67,147 @@ NPYV_IMPL_AVX2_MEM_INT(npy_int64, s64)
#define npyv_storeh_f32(PTR, VEC) _mm_storeu_ps(PTR, _mm256_extractf128_ps(VEC, 1))
#define npyv_storeh_f64(PTR, VEC) _mm_storeu_pd(PTR, _mm256_extractf128_pd(VEC, 1))

/***************************
* Non-contiguous Load
***************************/
//// 8
NPY_FINLINE npyv_u8 npyv_loadn_u8(const npy_uint8 *ptr, int stride)
{
const __m256i steps = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
const __m256i idx = _mm256_mullo_epi32(_mm256_set1_epi32(stride), steps);
const __m256i cut32 = _mm256_set1_epi32(0xFF);
const __m256i sort_odd = _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7);
__m256i a = _mm256_i32gather_epi32((const int*)ptr, idx, 1);
__m256i b = _mm256_i32gather_epi32((const int*)(ptr + stride*8), idx, 1);
__m256i c = _mm256_i32gather_epi32((const int*)(ptr + stride*16), idx, 1);
__m256i d = _mm256_i32gather_epi32((const int*)((ptr-3/*overflow guard*/) + stride*24), idx, 1);
a = _mm256_and_si256(a, cut32);
b = _mm256_and_si256(b, cut32);
c = _mm256_and_si256(c, cut32);
d = _mm256_srli_epi32(d, 24);
a = _mm256_packus_epi32(a, b);
c = _mm256_packus_epi32(c, d);
return _mm256_permutevar8x32_epi32(_mm256_packus_epi16(a, c), sort_odd);
}
NPY_FINLINE npyv_s8 npyv_loadn_s8(const npy_int8 *ptr, int stride)
{ return npyv_loadn_u8((const npy_uint8 *)ptr, stride); }
//// 16
NPY_FINLINE npyv_u16 npyv_loadn_u16(const npy_uint16 *ptr, int stride)
{
const __m256i steps = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
const __m256i idx = _mm256_mullo_epi32(_mm256_set1_epi32(stride), steps);
const __m256i cut32 = _mm256_set1_epi32(0xFFFF);
__m256i a = _mm256_i32gather_epi32((const int*)ptr, idx, 2);
__m256i b = _mm256_i32gather_epi32((const int*)((ptr-1/*overflow guard*/) + stride*8), idx, 2);
a = _mm256_and_si256(a, cut32);
b = _mm256_srli_epi32(b, 16);
return npyv256_shuffle_odd(_mm256_packus_epi32(a, b));
}
NPY_FINLINE npyv_s16 npyv_loadn_s16(const npy_int16 *ptr, int stride)
{ return npyv_loadn_u16((const npy_uint16 *)ptr, stride); }
//// 32
NPY_FINLINE npyv_u32 npyv_loadn_u32(const npy_uint32 *ptr, int stride)
{
const __m256i steps = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
const __m256i idx = _mm256_mullo_epi32(_mm256_set1_epi32(stride), steps);
return _mm256_i32gather_epi32((const int*)ptr, idx, 4);
}
NPY_FINLINE npyv_s32 npyv_loadn_s32(const npy_int32 *ptr, int stride)
{ return npyv_loadn_u32((const npy_uint32*)ptr, stride); }
NPY_FINLINE npyv_f32 npyv_loadn_f32(const float *ptr, int stride)
{ return _mm256_castsi256_ps(npyv_loadn_u32((const npy_uint32*)ptr, stride)); }
//// 64
NPY_FINLINE npyv_u64 npyv_loadn_u64(const npy_uint64 *ptr, int stride)
{
const __m128i steps = _mm_setr_epi32(0, 1, 2, 3);
const __m128i idx = _mm_mullo_epi32(_mm_set1_epi32(stride), steps);
return _mm256_i32gather_epi64((const void*)ptr, idx, 8);
}
NPY_FINLINE npyv_s64 npyv_loadn_s64(const npy_int64 *ptr, int stride)
{ return npyv_loadn_u64((const npy_uint64*)ptr, stride); }
NPY_FINLINE npyv_f64 npyv_loadn_f64(const double *ptr, int stride)
{ return _mm256_castsi256_pd(npyv_loadn_u64((const npy_uint64*)ptr, stride)); }

/***************************
* Non-contiguous Store
***************************/
//// 8
NPY_FINLINE void npyv_storen_u8(npy_uint8 *ptr, int stride, npyv_u8 a)
{
__m128i a0 = _mm256_castsi256_si128(a);
__m128i a1 = _mm256_extracti128_si256(a, 1);
#define NPYV_IMPL_AVX2_STOREN8(VEC, EI, I) \
{ \
unsigned e = (unsigned)_mm_extract_epi32(VEC, EI); \
ptr[stride*(I+0)] = (npy_uint8)e; \
ptr[stride*(I+1)] = (npy_uint8)(e >> 8); \
ptr[stride*(I+2)] = (npy_uint8)(e >> 16); \
ptr[stride*(I+3)] = (npy_uint8)(e >> 24); \
}
NPYV_IMPL_AVX2_STOREN8(a0, 0, 0)
NPYV_IMPL_AVX2_STOREN8(a0, 1, 4)
NPYV_IMPL_AVX2_STOREN8(a0, 2, 8)
NPYV_IMPL_AVX2_STOREN8(a0, 3, 12)
NPYV_IMPL_AVX2_STOREN8(a1, 0, 16)
NPYV_IMPL_AVX2_STOREN8(a1, 1, 20)
NPYV_IMPL_AVX2_STOREN8(a1, 2, 24)
NPYV_IMPL_AVX2_STOREN8(a1, 3, 28)
}
NPY_FINLINE void npyv_storen_s8(npy_int8 *ptr, int stride, npyv_s8 a)
{ npyv_storen_u8((npy_uint8*)ptr, stride, a); }
//// 16
NPY_FINLINE void npyv_storen_u16(npy_uint16 *ptr, int stride, npyv_u16 a)
{
__m128i a0 = _mm256_castsi256_si128(a);
__m128i a1 = _mm256_extracti128_si256(a, 1);
#define NPYV_IMPL_AVX2_STOREN16(VEC, EI, I) \
{ \
unsigned e = (unsigned)_mm_extract_epi32(VEC, EI); \
ptr[stride*(I+0)] = (npy_uint16)e; \
ptr[stride*(I+1)] = (npy_uint16)(e >> 16); \
}
NPYV_IMPL_AVX2_STOREN16(a0, 0, 0)
NPYV_IMPL_AVX2_STOREN16(a0, 1, 2)
NPYV_IMPL_AVX2_STOREN16(a0, 2, 4)
NPYV_IMPL_AVX2_STOREN16(a0, 3, 6)
NPYV_IMPL_AVX2_STOREN16(a1, 0, 8)
NPYV_IMPL_AVX2_STOREN16(a1, 1, 10)
NPYV_IMPL_AVX2_STOREN16(a1, 2, 12)
NPYV_IMPL_AVX2_STOREN16(a1, 3, 14)
}
NPY_FINLINE void npyv_storen_s16(npy_int16 *ptr, int stride, npyv_s16 a)
{ npyv_storen_u16((npy_uint16*)ptr, stride, a); }
//// 32
NPY_FINLINE void npyv_storen_s32(npy_int32 *ptr, int stride, npyv_s32 a)
{
__m128i a0 = _mm256_castsi256_si128(a);
__m128i a1 = _mm256_extracti128_si256(a, 1);
ptr[stride * 0] = _mm_cvtsi128_si32(a0);
ptr[stride * 1] = _mm_extract_epi32(a0, 1);
ptr[stride * 2] = _mm_extract_epi32(a0, 2);
ptr[stride * 3] = _mm_extract_epi32(a0, 3);
ptr[stride * 4] = _mm_cvtsi128_si32(a1);
ptr[stride * 5] = _mm_extract_epi32(a1, 1);
ptr[stride * 6] = _mm_extract_epi32(a1, 2);
ptr[stride * 7] = _mm_extract_epi32(a1, 3);
}
NPY_FINLINE void npyv_storen_u32(npy_uint32 *ptr, int stride, npyv_u32 a)
{ npyv_storen_s32((npy_int32*)ptr, stride, a); }
NPY_FINLINE void npyv_storen_f32(float *ptr, int stride, npyv_f32 a)
{ npyv_storen_s32((npy_int32*)ptr, stride, _mm256_castps_si256(a)); }
//// 64
NPY_FINLINE void npyv_storen_f64(double *ptr, int stride, npyv_f64 a)
{
__m128d a0 = _mm256_castpd256_pd128(a);
__m128d a1 = _mm256_extractf128_pd(a, 1);
_mm_storel_pd(ptr + stride * 0, a0);
_mm_storeh_pd(ptr + stride * 1, a0);
_mm_storel_pd(ptr + stride * 2, a1);
_mm_storeh_pd(ptr + stride * 3, a1);
}
NPY_FINLINE void npyv_storen_u64(npy_uint64 *ptr, int stride, npyv_u64 a)
{ npyv_storen_f64((double*)ptr, stride, _mm256_castsi256_pd(a)); }
NPY_FINLINE void npyv_storen_s64(npy_int64 *ptr, int stride, npyv_s64 a)
{ npyv_storen_f64((double*)ptr, stride, _mm256_castsi256_pd(a)); }

#endif // _NPY_SIMD_AVX2_MEMORY_H
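A hedged usage sketch (not part of the patch) of the strided AVX2 helpers above: a strided element-wise add built on npyv_loadn_f32/npyv_storen_f32. It assumes the NPYV headers are in scope and that npyv_add_f32 and the npyv_nlanes_f32 lane-count macro are available from the rest of the NPYV API.

static void strided_add_f32(float *dst, int dst_stride,
                            const float *src, int src_stride, int len)
{
    int i = 0;
    const int vstep = npyv_nlanes_f32;   // 8 lanes per vector on AVX2
    for (; i <= len - vstep; i += vstep) {
        npyv_f32 a = npyv_loadn_f32(src + src_stride * i, src_stride);
        npyv_f32 b = npyv_loadn_f32(dst + dst_stride * i, dst_stride);
        npyv_storen_f32(dst + dst_stride * i, dst_stride, npyv_add_f32(a, b));
    }
    for (; i < len; ++i) {               // scalar tail for the remainder
        dst[dst_stride * i] += src[src_stride * i];
    }
}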
200 changes: 200 additions & 0 deletions numpy/core/src/common/simd/avx512/memory.h
@@ -91,4 +91,204 @@ NPYV_IMPL_AVX512_MEM_INT(npy_int64, s64)
#define npyv_storeh_f32(PTR, VEC) _mm256_storeu_ps(PTR, npyv512_higher_ps256(VEC))
#define npyv_storeh_f64(PTR, VEC) _mm256_storeu_pd(PTR, npyv512_higher_pd256(VEC))

// non-contiguous load
//// 8
NPY_FINLINE npyv_u8 npyv_loadn_u8(const npy_uint8 *ptr, int stride)
{
const __m512i steps = npyv_set_u32(
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15
);
const __m512i idx = _mm512_mullo_epi32(steps, _mm512_set1_epi32(stride));
__m512i a = _mm512_i32gather_epi32(idx, (const void*)ptr, 1);
__m512i b = _mm512_i32gather_epi32(idx, (const void*)(ptr + stride*16), 1);
__m512i c = _mm512_i32gather_epi32(idx, (const void*)(ptr + stride*32), 1);
__m512i d = _mm512_i32gather_epi32(idx, (const void*)((ptr-3/*overflow guard*/)+stride*48), 1);
#ifdef NPY_HAVE_AVX512BW
const __m512i cut32 = _mm512_set1_epi32(0xFF);
a = _mm512_and_si512(a, cut32);
b = _mm512_and_si512(b, cut32);
c = _mm512_and_si512(c, cut32);
d = _mm512_srli_epi32(d, 24);
a = _mm512_packus_epi32(a, b);
c = _mm512_packus_epi32(c, d);
return npyv512_shuffle_odd32(_mm512_packus_epi16(a, c));
#else
__m128i af = _mm512_cvtepi32_epi8(a);
__m128i bf = _mm512_cvtepi32_epi8(b);
__m128i cf = _mm512_cvtepi32_epi8(c);
__m128i df = _mm512_cvtepi32_epi8(_mm512_srli_epi32(d, 24));
return npyv512_combine_si256(
_mm256_inserti128_si256(_mm256_castsi128_si256(af), bf, 1),
_mm256_inserti128_si256(_mm256_castsi128_si256(cf), df, 1)
);
#endif // !NPY_HAVE_AVX512BW
}
NPY_FINLINE npyv_s8 npyv_loadn_s8(const npy_int8 *ptr, int stride)
{ return npyv_loadn_u8((const npy_uint8*)ptr, stride); }
//// 16
NPY_FINLINE npyv_u16 npyv_loadn_u16(const npy_uint16 *ptr, int stride)
{
const __m512i steps = npyv_set_s32(
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15
);
const __m512i idx = _mm512_mullo_epi32(steps, _mm512_set1_epi32(stride));
__m512i a = _mm512_i32gather_epi32(idx, (const void*)ptr, 2);
__m512i b = _mm512_i32gather_epi32(idx, (const void*)((ptr-1/*overflow guard*/)+stride*16), 2);
#ifdef NPY_HAVE_AVX512BW
const __m512i perm = npyv_set_u16(
0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30,
33, 35, 37, 39, 41, 43, 45, 47, 49, 51, 53, 55, 57, 59, 61, 63
);
return _mm512_permutex2var_epi16(a, perm, b);
#else
__m256i af = _mm512_cvtepi32_epi16(a);
__m256i bf = _mm512_cvtepi32_epi16(_mm512_srli_epi32(b, 16));
return npyv512_combine_si256(af, bf);
#endif
}
NPY_FINLINE npyv_s16 npyv_loadn_s16(const npy_int16 *ptr, int stride)
{ return npyv_loadn_u16((const npy_uint16*)ptr, stride); }
//// 32
NPY_FINLINE npyv_u32 npyv_loadn_u32(const npy_uint32 *ptr, int stride)
{
const __m512i steps = npyv_set_s32(
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15
);
const __m512i idx = _mm512_mullo_epi32(steps, _mm512_set1_epi32(stride));
return _mm512_i32gather_epi32(idx, (const int*)ptr, 4);
}
NPY_FINLINE npyv_s32 npyv_loadn_s32(const npy_int32 *ptr, int stride)
{ return npyv_loadn_u32((const npy_uint32*)ptr, stride); }
NPY_FINLINE npyv_f32 npyv_loadn_f32(const float *ptr, int stride)
{ return _mm512_castsi512_ps(npyv_loadn_u32((const npy_uint32*)ptr, stride)); }
//// 64
NPY_FINLINE npyv_u64 npyv_loadn_u64(const npy_uint64 *ptr, int stride)
{
const __m256i steps = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
const __m256i idx = _mm256_mullo_epi32(_mm256_set1_epi32(stride), steps);
return _mm512_i32gather_epi64(idx, (const void*)ptr, 8);
}
NPY_FINLINE npyv_s64 npyv_loadn_s64(const npy_int64 *ptr, int stride)
{ return npyv_loadn_u64((const npy_uint64*)ptr, stride); }
NPY_FINLINE npyv_f64 npyv_loadn_f64(const double *ptr, int stride)
{ return _mm512_castsi512_pd(npyv_loadn_u64((const npy_uint64*)ptr, stride)); }

// non-contiguous store
//// 8
NPY_FINLINE void npyv_storen_u8(npy_uint8 *ptr, int stride, npyv_u8 a)
{
// GIT:WARN Buggy, needs a fix
// TODO: the overflow guard makes the stores of elements [45:48] overlap
//       for small strides (-3/-2/-1/1/2/3)
const __m512i steps = _mm512_setr_epi32(
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15
);
const __m512i idx = _mm512_mullo_epi32(steps, _mm512_set1_epi32(stride));
__m512i m0 = _mm512_i32gather_epi32(idx, (const void*)ptr, 1);
__m512i m1 = _mm512_i32gather_epi32(idx, (const void*)(ptr + stride*16), 1);
__m512i m2 = _mm512_i32gather_epi32(idx, (const void*)(ptr + stride*32), 1);
__m512i m3 = _mm512_i32gather_epi32(idx, (const void*)((ptr-3/*overflow guard*/)+stride*48), 1);
#if 0 // def NPY_HAVE_AVX512VBMI
// NOTE: experimental
const __m512i perm = npyv_set_u8(
64, 1, 2, 3, 65, 5, 6, 7, 66, 9, 10, 11, 67, 13, 14, 15,
68, 17, 18, 19, 69, 21, 22, 23, 70, 25, 26, 27, 71, 29, 30, 31,
72, 33, 34, 35, 73, 37, 38, 39, 74, 41, 42, 43, 75, 45, 46, 47,
76, 49, 50, 51, 77, 53, 54, 55, 78, 57, 58, 59, 79, 61, 62, 63
);
const __m512i perm_ofg = _mm512_ror_epi32(perm, 8);
__m512i a1 = _mm512_castsi128_si512(_mm512_extracti64x2_epi64(a, 1));
__m512i a2 = _mm512_castsi128_si512(_mm512_extracti64x2_epi64(a, 2));
__m512i a3 = _mm512_castsi128_si512(_mm512_extracti64x2_epi64(a, 3));
__m512i s0 = _mm512_permutex2var_epi8(m0, perm, a);
__m512i s1 = _mm512_permutex2var_epi8(m1, perm, a1);
__m512i s2 = _mm512_permutex2var_epi8(m2, perm, a2);
__m512i s3 = _mm512_permutex2var_epi8(_mm512_rol_epi32(m3, 8), perm_ofg, a3);
#else
#if 0 // def NPY_HAVE_AVX512DQ
__m512i a0 = _mm512_cvtepu8_epi32(_mm512_castsi512_si128(a));
__m512i a1 = _mm512_cvtepu8_epi32(_mm512_extracti64x2_epi64(a, 1));
__m512i a2 = _mm512_cvtepu8_epi32(_mm512_extracti64x2_epi64(a, 2));
__m512i a3 = _mm512_cvtepu8_epi32(_mm512_extracti64x2_epi64(a, 3));
a3 = _mm512_slli_epi32(a3, 24);
#else
__m256i low = _mm512_extracti64x4_epi64(a, 0);
__m256i high = _mm512_extracti64x4_epi64(a, 1);
__m512i a0 = _mm512_cvtepu8_epi32(_mm256_castsi256_si128(low));
__m512i a1 = _mm512_cvtepu8_epi32(_mm256_extracti128_si256(low, 1));
__m512i a2 = _mm512_cvtepu8_epi32(_mm256_castsi256_si128(high));
__m512i a3 = _mm512_cvtepu8_epi32(_mm256_extracti128_si256(high, 1));
a3 = _mm512_slli_epi32(a3, 24);
#endif // NPY_HAVE_AVX512DQ
#ifdef NPY_HAVE_AVX512BW
__m512i s0 = _mm512_mask_blend_epi8(0x1111111111111111, m0, a0);
__m512i s1 = _mm512_mask_blend_epi8(0x1111111111111111, m1, a1);
__m512i s2 = _mm512_mask_blend_epi8(0x1111111111111111, m2, a2);
__m512i s3 = _mm512_mask_blend_epi8(0x8888888888888888, m3, a3);
#else
const __m512i maskl = _mm512_set1_epi32(0x000000FF);
const __m512i maskh = _mm512_set1_epi32(0xFF000000);
__m512i s0 = npyv_select_u8(maskl, a0, m0);
__m512i s1 = npyv_select_u8(maskl, a1, m1);
__m512i s2 = npyv_select_u8(maskl, a2, m2);
__m512i s3 = npyv_select_u8(maskh, a3, m3);
#endif // NPY_HAVE_AVX512BW
#endif // AVX512VBMI
_mm512_i32scatter_epi32((int*)ptr, idx, s0, 1);
_mm512_i32scatter_epi32((int*)(ptr + stride*16), idx, s1, 1);
_mm512_i32scatter_epi32((int*)((ptr-3/*overflow guard*/)+ stride*48), idx, s3, 1);
_mm512_i32scatter_epi32((int*)(ptr + stride*32), idx, s2, 1);
}
NPY_FINLINE void npyv_storen_s8(npy_int8 *ptr, int stride, npyv_s8 a)
{ npyv_storen_u8((npy_uint8*)ptr, stride, a); }
//// 16
NPY_FINLINE void npyv_storen_u16(npy_uint16 *ptr, int stride, npyv_u16 a)
{
const __m512i steps = npyv_set_s32(
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15
);
const __m512i idx = _mm512_mullo_epi32(steps, _mm512_set1_epi32(stride));
__m512i m0 = _mm512_i32gather_epi32(idx, (const void*)ptr, 2);
__m512i m1 = _mm512_i32gather_epi32(idx, (const void*)((ptr-1/*overflow guard*/)+stride*16), 2);

__m512i a0 = _mm512_cvtepu16_epi32(_mm512_castsi512_si256(a));
__m512i a1 = _mm512_cvtepu16_epi32(_mm512_extracti64x4_epi64(a, 1));
a1 = _mm512_slli_epi32(a1, 16);
#ifdef NPY_HAVE_AVX512BW
__m512i s0 = _mm512_mask_blend_epi16(0x55555555, m0, a0);
__m512i s1 = _mm512_mask_blend_epi16(0xAAAAAAAA, m1, a1);
#else
const __m512i mask = _mm512_set1_epi32(0x0000FFFF);
__m512i s0 = npyv_select_u16(mask, a0, m0);
__m512i s1 = npyv_select_u16(mask, m1, a1);
#endif // NPY_HAVE_AVX512BW
_mm512_i32scatter_epi32((int*)ptr, idx, s0, 2);
_mm512_i32scatter_epi32((int*)((ptr-1/*overflow guard*/)+stride*16), idx, s1, 2);
}
NPY_FINLINE void npyv_storen_s16(npy_int16 *ptr, int stride, npyv_s16 a)
{ npyv_storen_u16((npy_uint16*)ptr, stride, a); }
//// 32
NPY_FINLINE void npyv_storen_u32(npy_uint32 *ptr, int stride, npyv_u32 a)
{
const __m512i steps = _mm512_setr_epi32(
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15
);
const __m512i idx = _mm512_mullo_epi32(steps, _mm512_set1_epi32(stride));
_mm512_i32scatter_epi32((int*)ptr, idx, a, 4);
}
NPY_FINLINE void npyv_storen_s32(npy_int32 *ptr, int stride, npyv_s32 a)
{ npyv_storen_u32((npy_uint32*)ptr, stride, a); }
NPY_FINLINE void npyv_storen_f32(float *ptr, int stride, npyv_f32 a)
{ npyv_storen_u32((npy_uint32*)ptr, stride, _mm512_castps_si512(a)); }
//// 64
NPY_FINLINE void npyv_storen_u64(npy_uint64 *ptr, int stride, npyv_u64 a)
{
const __m256i steps = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
const __m256i idx = _mm256_mullo_epi32(_mm256_set1_epi32(stride), steps);
_mm512_i32scatter_epi64((void*)ptr, idx, a, 8);
}
NPY_FINLINE void npyv_storen_s64(npy_int64 *ptr, int stride, npyv_s64 a)
{ npyv_storen_u64((npy_uint64*)ptr, stride, a); }
NPY_FINLINE void npyv_storen_f64(double *ptr, int stride, npyv_f64 a)
{ npyv_storen_u64((npy_uint64*)ptr, stride, _mm512_castpd_si512(a)); }

#endif // _NPY_SIMD_AVX512_MEMORY_H
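One note on the recurring /*overflow guard*/ offsets in the 8- and 16-bit paths of both files: the last 32-bit gather would otherwise read past the final element, so its base pointer is biased backwards by 3 bytes (one element, i.e. 2 bytes and a 16-bit shift, in the 16-bit case) and the wanted value is recovered from the top of each little-endian dword with a right shift. A standalone illustration of the idea for a single 8-bit element, under the assumption that at least 3 valid bytes precede it:

#include <stdint.h>
#include <string.h>

// Not from the patch: emulate one lane of the guarded gather. The 4-byte read
// ends exactly at the wanted byte, so it cannot run past the end of the data;
// the byte then sits in bits [24:32] of the little-endian dword.
static uint8_t load_u8_guarded(const uint8_t *elem /* == &ptr[stride * i] */)
{
    uint32_t dword;
    memcpy(&dword, elem - 3, sizeof(dword));   // unaligned 4-byte read, like the gather
    return (uint8_t)(dword >> 24);             // recover elem[0]
}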