10000 BUG, SIMD: Fix unexpected result of uint8 division on X86 · numpy/numpy@519ab99 · GitHub
[go: up one dir, main page]

Skip to content

Commit 519ab99

Browse files
committed
BUG, SIMD: Fix unexpected result of uint8 division on X86
The bug can occur in special cases e.g. when the divisor is scalar and equal to 9 or 13 and the dividend is array contains consecutive duplicate values of 233.
1 parent 6f20549 commit 519ab99

File tree

4 files changed

+21
-18
lines changed

4 files changed

+21
-18
lines changed

numpy/core/src/common/simd/avx2/arithmetic.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -73,16 +73,16 @@
7373
// divide each unsigned 8-bit element by a precomputed divisor
7474
NPY_FINLINE npyv_u8 npyv_divc_u8(npyv_u8 a, const npyv_u8x3 divisor)
7575
{
76-
const __m256i bmask = _mm256_set1_epi32(0xFF00FF00);
76+
const __m256i bmask = _mm256_set1_epi32(0x00FF00FF);
7777
const __m128i shf1 = _mm256_castsi256_si128(divisor.val[1]);
7878
const __m128i shf2 = _mm256_castsi256_si128(divisor.val[2]);
7979
const __m256i shf1b = _mm256_set1_epi8(0xFFU >> _mm_cvtsi128_si32(shf1));
8080
const __m256i shf2b = _mm256_set1_epi8(0xFFU >> _mm_cvtsi128_si32(shf2));
8181
// high part of unsigned multiplication
82-
__m256i mulhi_odd = _mm256_mulhi_epu16(a, divisor.val[0]);
83-
__m256i mulhi_even = _mm256_mulhi_epu16(_mm256_slli_epi16(a, 8), divisor.val[0]);
82+
__m256i mulhi_even = _mm256_mullo_epi16(_mm256_and_si256(a, bmask), divisor.val[0]);
8483
mulhi_even = _mm256_srli_epi16(mulhi_even, 8);
85-
__m256i mulhi = _mm256_blendv_epi8(mulhi_even, mulhi_odd, bmask);
84+
__m256i mulhi_odd = _mm256_mullo_epi16(_mm256_srli_epi16(a, 8), divisor.val[0]);
85+
__m256i mulhi = _mm256_blendv_epi8(mulhi_odd, mulhi_even, bmask);
8686
// floor(a/d) = (mulhi + ((a-mulhi) >> sh1)) >> sh2
8787
__m256i q = _mm256_sub_epi8(a, mulhi);
8888
q = _mm256_and_si256(_mm256_srl_epi16(q, shf1), shf1b);

numpy/core/src/common/simd/avx512/arithmetic.h

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -116,12 +116,13 @@ NPY_FINLINE npyv_u8 npyv_divc_u8(npyv_u8 a, const npyv_u8x3 divisor)
116116
const __m128i shf1 = _mm512_castsi512_si128(divisor.val[1]);
117117
const __m128i shf2 = _mm512_castsi512_si128(divisor.val[2]);
118118
#ifdef NPY_HAVE_AVX512BW
119+
const __m512i bmask = _mm512_set1_epi32(0x00FF00FF);
119120
const __m512i shf1b = _mm512_set1_epi8(0xFFU >> _mm_cvtsi128_si32(shf1));
120121
const __m512i shf2b = _mm512_set1_epi8(0xFFU >> _mm_cvtsi128_si32(shf2));
121122
// high part of unsigned multiplication
122-
__m512i mulhi_odd = _mm512_mulhi_epu16(a, divisor.val[0]);
123-
__m512i mulhi_even = _mm512_mulhi_epu16(_mm512_slli_epi16(a, 8), divisor.val[0]);
123+
__m512i mulhi_even = _mm512_mullo_epi16(_mm512_and_si512(a, bmask), divisor.val[0]);
124124
mulhi_even = _mm512_srli_epi16(mulhi_even, 8);
125+
__m512i mulhi_odd = _mm512_mullo_epi16(_mm512_srli_epi16(a, 8), divisor.val[0]);
125126
__m512i mulhi = _mm512_mask_mov_epi8(mulhi_even, 0xAAAAAAAAAAAAAAAA, mulhi_odd);
126127
// floor(a/d) = (mulhi + ((a-mulhi) >> sh1)) >> sh2
127128
__m512i q = _mm512_sub_epi8(a, mulhi);
@@ -130,18 +131,18 @@ NPY_FINLINE npyv_u8 npyv_divc_u8(npyv_u8 a, const npyv_u8x3 divisor)
130131
q = _mm512_and_si512(_mm512_srl_epi16(q, shf2), shf2b);
131132
return q;
132133
#else
133-
const __m256i bmask = _mm256_set1_epi32(0xFF00FF00);
134+
const __m256i bmask = _mm256_set1_epi32(0x00FF00FF);
134135
const __m256i shf1b = _mm256_set1_epi8(0xFFU >> _mm_cvtsi128_si32(shf1));
135136
const __m256i shf2b = _mm256_set1_epi8(0xFFU >> _mm_cvtsi128_si32(shf2));
136137
const __m512i shf2bw= npyv512_combine_si256(shf2b, shf2b);
137138
const __m256i mulc = npyv512_lower_si256(divisor.val[0]);
138139
//// lower 256-bit
139140
__m256i lo_a = npyv512_lower_si256(a);
140141
// high part of unsigned multiplication
141-
__m256i mulhi_odd = _mm256_mulhi_epu16(lo_a, mulc);
142-
__m256i mulhi_even = _mm256_mulhi_epu16(_mm256_slli_epi16(lo_a, 8), mulc);
142+
__m256i mulhi_even = _mm256_mullo_epi16(_mm256_and_si256(lo_a, bmask), mulc);
143143
mulhi_even = _mm256_srli_epi16(mulhi_even, 8);
144-
__m256i mulhi = _mm256_blendv_epi8(mulhi_even, mulhi_odd, bmask);
144+
__m256i mulhi_odd = _mm256_mullo_epi16(_mm256_srli_epi16(lo_a, 8), mulc);
145+
__m256i mulhi = _mm256_blendv_epi8(mulhi_odd, mulhi_even, bmask);
145146
// floor(a/d) = (mulhi + ((a-mulhi) >> sh1)) >> sh2
146147
__m256i lo_q = _mm256_sub_epi8(lo_a, mulhi);
147148
lo_q = _mm256_and_si256(_mm256_srl_epi16(lo_q, shf1), shf1b);
@@ -151,10 +152,10 @@ NPY_FINLINE npyv_u8 npyv_divc_u8(npyv_u8 a, const npyv_u8x3 divisor)
151152
//// higher 256-bit
152153
__m256i hi_a = npyv512_higher_si256(a);
153154
// high part of unsigned multiplication
154-
mulhi_odd = _mm256_mulhi_epu16(hi_a, mulc);
155-
mulhi_even = _mm256_mulhi_epu16(_mm256_slli_epi16(hi_a, 8), mulc);
155+
mulhi_even = _mm256_mullo_epi16(_mm256_and_si256(hi_a, bmask), mulc);
156156
mulhi_even = _mm256_srli_epi16(mulhi_even, 8);
157-
mulhi = _mm256_blendv_epi8(mulhi_even, mulhi_odd, bmask);
157+
mulhi_odd = _mm256_mullo_epi16(_mm256_srli_epi16(hi_a, 8), mulc);
158+
mulhi = _mm256_blendv_epi8(mulhi_odd, mulhi_even, bmask);
158159
// floor(a/d) = (mulhi + ((a-mulhi) >> sh1)) >> sh2
159160
__m256i hi_q = _mm256_sub_epi8(hi_a, mulhi);
160161
hi_q = _mm256_and_si256(_mm256_srl_epi16(hi_q, shf1), shf1b);

numpy/core/src/common/simd/intdiv.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,14 +204,16 @@ NPY_FINLINE npyv_u8x3 npyv_divisor_u8(npy_uint8 d)
204204
sh1 = 1; sh2 = l - 1; // shift counts
205205
}
206206
npyv_u8x3 divisor;
207-
divisor.val[0] = npyv_setall_u8(m);
208207
#ifdef NPY_HAVE_SSE2 // SSE/AVX2/AVX512
208+
divisor.val[0] = npyv_setall_u16(m);
209209
divisor.val[1] = npyv_set_u8(sh1);
210210
divisor.val[2] = npyv_set_u8(sh2);
211211
#elif defined(NPY_HAVE_VSX2)
212+
divisor.val[0] = npyv_setall_u8(m);
212213
divisor.val[1] = npyv_setall_u8(sh1);
213214
divisor.val[2] = npyv_setall_u8(sh2);
214215
#elif defined(NPY_HAVE_NEON)
216+
divisor.val[0] = npyv_setall_u8(m);
215217
divisor.val[1] = npyv_reinterpret_u8_s8(npyv_setall_s8(-sh1));
216218
divisor.val[2] = npyv_reinterpret_u8_s8(npyv_setall_s8(-sh2));
217219
#else

numpy/core/src/common/simd/sse/arithmetic.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -92,14 +92,14 @@ NPY_FINLINE __m128i npyv_mul_u8(__m128i a, __m128i b)
9292
// divide each unsigned 8-bit element by a precomputed divisor
9393
NPY_FINLINE npyv_u8 npyv_divc_u8(npyv_u8 a, const npyv_u8x3 divisor)
9494
{
95-
const __m128i bmask = _mm_set1_epi32(0xFF00FF00);
95+
const __m128i bmask = _mm_set1_epi32(0x00FF00FF);
9696
const __m128i shf1b = _mm_set1_epi8(0xFFU >> _mm_cvtsi128_si32(divisor.val[1]));
9797
const __m128i shf2b = _mm_set1_epi8(0xFFU >> _mm_cvtsi128_si32(divisor.val[2]));
9898
// high part of unsigned multiplication
99-
__m128i mulhi_odd = _mm_mulhi_epu16(a, divisor.val[0]);
100-
__m128i mulhi_even = _mm_mulhi_epu16(_mm_slli_epi16(a, 8), divisor.val[0]);
99+
__m128i mulhi_even = _mm_mullo_epi16(_mm_and_si128(a, bmask), divisor.val[0]);
100+
__m128i mulhi_odd = _mm_mullo_epi16(_mm_srli_epi16(a, 8), divisor.val[0]);
101101
mulhi_even = _mm_srli_epi16(mulhi_even, 8);
102-
__m128i mulhi = npyv_select_u8(bmask, mulhi_odd, mulhi_even);
102+
__m128i mulhi = npyv_select_u8(bmask, mulhi_even, mulhi_odd);
103103
// floor(a/d) = (mulhi + ((a-mulhi) >> sh1)) >> sh2
104104
__m128i q = _mm_sub_epi8(a, mulhi);
105105
q = _mm_and_si128(_mm_srl_epi16(q, divisor.val[1]), shf1b);

0 commit comments

Comments
 (0)
0