SIMD: add NPYV fast integer division intrinsics for NEON · numpy/numpy@2da9858 · GitHub
Commit 2da9858

SIMD: add NPYV fast integer division intrinsics for NEON
1 parent 5c185cc commit 2da9858

1 file changed

numpy/core/src/common/simd/neon/arithmetic.h

Lines changed: 149 additions & 1 deletion
@@ -60,6 +60,154 @@
#define npyv_mul_f32 vmulq_f32
#define npyv_mul_f64 vmulq_f64

/***************************
 * Integer Division
 ***************************/
// See simd/intdiv.h for more clarification
// divide each unsigned 8-bit element by a precomputed divisor
NPY_FINLINE npyv_u8 npyv_divc_u8(npyv_u8 a, const npyv_u8x3 divisor)
{
    const uint8x8_t mulc_lo = vget_low_u8(divisor.val[0]);
    // high part of unsigned multiplication
    uint16x8_t mull_lo = vmull_u8(vget_low_u8(a), mulc_lo);
#if NPY_SIMD_F64
    uint16x8_t mull_hi = vmull_high_u8(a, divisor.val[0]);
    // get the high unsigned bytes
    uint8x16_t mulhi = vuzp2q_u8(vreinterpretq_u8_u16(mull_lo), vreinterpretq_u8_u16(mull_hi));
#else
    const uint8x8_t mulc_hi = vget_high_u8(divisor.val[0]);
    uint16x8_t mull_hi = vmull_u8(vget_high_u8(a), mulc_hi);
    uint8x16_t mulhi = vuzpq_u8(vreinterpretq_u8_u16(mull_lo), vreinterpretq_u8_u16(mull_hi)).val[1];
#endif
    // floor(a/d) = (mulhi + ((a-mulhi) >> sh1)) >> sh2
    uint8x16_t q = vsubq_u8(a, mulhi);
    q = vshlq_u8(q, vreinterpretq_s8_u8(divisor.val[1]));
    q = vaddq_u8(mulhi, q);
    q = vshlq_u8(q, vreinterpretq_s8_u8(divisor.val[2]));
    return q;
}
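Both branches compute the high byte of the 8x8 -> 16-bit widening product: on AArch64 (the NPY_SIMD_F64 path) via vmull_high_u8 plus vuzp2q_u8, on ARMv7 by de-interleaving the two vmull_u8 halves with vuzpq_u8. NEON has no variable right shift, so feeding vshlq_u8 through vreinterpretq_s8_u8 implies that divisor.val[1] and val[2] hold the shift counts negated. As a sanity check, here is a one-lane scalar model of the floor(a/d) recipe; the constant derivation follows the classic Granlund-Montgomery construction that simd/intdiv.h is expected to perform (my sketch, not code from that header):

#include <stdint.h>
#include <stdio.h>

// scalar model of one npyv_divc_u8 lane; valid for d >= 1
static uint8_t divc_u8_scalar(uint8_t a, uint8_t d)
{
    int l = 0;
    while ((1u << l) < d)
        l++;                                       // l = ceil(log2(d))
    // multiplier: floor(2^(8+l) / d) - 2^8 + 1, kept modulo 2^8
    const unsigned m   = (((1u << (8 + l)) / d) - 256u + 1u) & 0xffu;
    const int      sh1 = l > 0 ? 1 : 0;            // min(l, 1)
    const int      sh2 = l > 1 ? l - 1 : 0;        // max(l - 1, 0)
    // high byte of the widening product, i.e. vmull_u8 + vuzp2q_u8
    const unsigned mulhi = ((unsigned)a * m) >> 8;
    // floor(a/d) = (mulhi + ((a - mulhi) >> sh1)) >> sh2
    return (uint8_t)((mulhi + (((unsigned)a - mulhi) >> sh1)) >> sh2);
}

int main(void)
{
    for (unsigned d = 1; d < 256; d++)
        for (unsigned a = 0; a < 256; a++)
            if (divc_u8_scalar((uint8_t)a, (uint8_t)d) != (uint8_t)(a / d))
                printf("mismatch: %u / %u\n", a, d);
    return 0;
}

Exhaustively checking every (a, d) pair this way prints nothing.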
// divide each signed 8-bit element by a precomputed divisor (round towards zero)
NPY_FINLINE npyv_s8 npyv_divc_s8(npyv_s8 a, const npyv_s8x3 divisor)
{
    const int8x8_t mulc_lo = vget_low_s8(divisor.val[0]);
    // high part of signed multiplication
    int16x8_t mull_lo = vmull_s8(vget_low_s8(a), mulc_lo);
#if NPY_SIMD_F64
    int16x8_t mull_hi = vmull_high_s8(a, divisor.val[0]);
    // get the high signed bytes
    int8x16_t mulhi = vuzp2q_s8(vreinterpretq_s8_s16(mull_lo), vreinterpretq_s8_s16(mull_hi));
#else
    const int8x8_t mulc_hi = vget_high_s8(divisor.val[0]);
    int16x8_t mull_hi = vmull_s8(vget_high_s8(a), mulc_hi);
    int8x16_t mulhi = vuzpq_s8(vreinterpretq_s8_s16(mull_lo), vreinterpretq_s8_s16(mull_hi)).val[1];
#endif
    // q = ((a + mulhi) >> sh1) - XSIGN(a)
    // trunc(a/d) = (q ^ dsign) - dsign
    int8x16_t q = vshlq_s8(vaddq_s8(a, mulhi), divisor.val[1]);
    q = vsubq_s8(q, vshrq_n_s8(a, 7));
    q = vsubq_s8(veorq_s8(q, divisor.val[2]), divisor.val[2]);
    return q;
}
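The signed variant rounds toward zero in two steps: subtracting XSIGN(a) (the arithmetic shift a >> 7, i.e. 0 or -1) bumps the quotient by one for negative dividends, and (q ^ dsign) - dsign conditionally negates the result when the divisor is negative. A matching one-lane scalar model, again with hand-derived Granlund-Montgomery constants rather than the real simd/intdiv.h precompute:

#include <stdint.h>
#include <stdio.h>

// scalar model of one npyv_divc_s8 lane; valid for 2 <= |d| <= 127;
// assumes >> on negative values is an arithmetic shift, as on all
// mainstream compilers (the NEON code gets this from vshrq_n_s8)
static int8_t divc_s8_scalar(int8_t a, int8_t d)
{
    const int d1 = d < 0 ? -d : d;                 // |d|
    int sh1 = 0;
    while ((1 << (sh1 + 1)) < d1)
        sh1++;                                     // sh1 = ceil(log2(|d|)) - 1
    // multiplier floor(2^(8+sh1) / |d|) + 1, stored modulo 2^8 (always negative)
    const int m     = (1 << (8 + sh1)) / d1 + 1 - 256;
    const int dsign = d >> 7;                      // 0 or -1
    // high byte of the signed widening product (vmull_s8 + vuzp2q_s8)
    const int mulhi = (a * m) >> 8;
    // q = ((a + mulhi) >> sh1) - XSIGN(a)
    int q = (a + mulhi) >> sh1;
    q -= a >> 7;
    // trunc(a/d) = (q ^ dsign) - dsign
    return (int8_t)((q ^ dsign) - dsign);
}

int main(void)
{
    for (int d = -127; d <= 127; d++) {
        if (d >= -1 && d <= 1) continue;
        for (int a = -128; a <= 127; a++)
            if (divc_s8_scalar((int8_t)a, (int8_t)d) != a / d)  // C99 '/' truncates
                printf("mismatch: %d / %d\n", a, d);
    }
    return 0;
}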
// divide each unsigned 16-bit element by a precomputed divisor
NPY_FINLINE npyv_u16 npyv_divc_u16(npyv_u16 a, const npyv_u16x3 divisor)
{
    const uint16x4_t mulc_lo = vget_low_u16(divisor.val[0]);
    // high part of unsigned multiplication
    uint32x4_t mull_lo = vmull_u16(vget_low_u16(a), mulc_lo);
#if NPY_SIMD_F64
    uint32x4_t mull_hi = vmull_high_u16(a, divisor.val[0]);
    // get the high unsigned halfwords
    uint16x8_t mulhi = vuzp2q_u16(vreinterpretq_u16_u32(mull_lo), vreinterpretq_u16_u32(mull_hi));
#else
    const uint16x4_t mulc_hi = vget_high_u16(divisor.val[0]);
    uint32x4_t mull_hi = vmull_u16(vget_high_u16(a), mulc_hi);
    uint16x8_t mulhi = vuzpq_u16(vreinterpretq_u16_u32(mull_lo), vreinterpretq_u16_u32(mull_hi)).val[1];
#endif
    // floor(a/d) = (mulhi + ((a-mulhi) >> sh1)) >> sh2
    uint16x8_t q = vsubq_u16(a, mulhi);
    q = vshlq_u16(q, vreinterpretq_s16_u16(divisor.val[1]));
    q = vaddq_u16(mulhi, q);
    q = vshlq_u16(q, vreinterpretq_s16_u16(divisor.val[2]));
    return q;
}
// divide each signed 16-bit element by a precomputed divisor (round towards zero)
NPY_FINLINE npyv_s16 npyv_divc_s16(npyv_s16 a, const npyv_s16x3 divisor)
{
    const int16x4_t mulc_lo = vget_low_s16(divisor.val[0]);
    // high part of signed multiplication
    int32x4_t mull_lo = vmull_s16(vget_low_s16(a), mulc_lo);
#if NPY_SIMD_F64
    int32x4_t mull_hi = vmull_high_s16(a, divisor.val[0]);
    // get the high signed halfwords
    int16x8_t mulhi = vuzp2q_s16(vreinterpretq_s16_s32(mull_lo), vreinterpretq_s16_s32(mull_hi));
#else
    const int16x4_t mulc_hi = vget_high_s16(divisor.val[0]);
    int32x4_t mull_hi = vmull_s16(vget_high_s16(a), mulc_hi);
    int16x8_t mulhi = vuzpq_s16(vreinterpretq_s16_s32(mull_lo), vreinterpretq_s16_s32(mull_hi)).val[1];
#endif
    // q = ((a + mulhi) >> sh1) - XSIGN(a)
    // trunc(a/d) = (q ^ dsign) - dsign
    int16x8_t q = vshlq_s16(vaddq_s16(a, mulhi), divisor.val[1]);
    q = vsubq_s16(q, vshrq_n_s16(a, 15));
    q = vsubq_s16(veorq_s16(q, divisor.val[2]), divisor.val[2]);
    return q;
}
// divide each unsigned 32-bit element by a precomputed divisor
NPY_FINLINE npyv_u32 npyv_divc_u32(npyv_u32 a, const npyv_u32x3 divisor)
{
    const uint32x2_t mulc_lo = vget_low_u32(divisor.val[0]);
    // high part of unsigned multiplication
    uint64x2_t mull_lo = vmull_u32(vget_low_u32(a), mulc_lo);
#if NPY_SIMD_F64
    uint64x2_t mull_hi = vmull_high_u32(a, divisor.val[0]);
    // get the high unsigned words
    uint32x4_t mulhi = vuzp2q_u32(vreinterpretq_u32_u64(mull_lo), vreinterpretq_u32_u64(mull_hi));
#else
    const uint32x2_t mulc_hi = vget_high_u32(divisor.val[0]);
    uint64x2_t mull_hi = vmull_u32(vget_high_u32(a), mulc_hi);
    uint32x4_t mulhi = vuzpq_u32(vreinterpretq_u32_u64(mull_lo), vreinterpretq_u32_u64(mull_hi)).val[1];
#endif
    // floor(a/d) = (mulhi + ((a-mulhi) >> sh1)) >> sh2
    uint32x4_t q = vsubq_u32(a, mulhi);
    q = vshlq_u32(q, vreinterpretq_s32_u32(divisor.val[1]));
    q = vaddq_u32(mulhi, q);
    q = vshlq_u32(q, vreinterpretq_s32_u32(divisor.val[2]));
    return q;
}
// divide each signed 32-bit element by a precomputed divisor (round towards zero)
NPY_FINLINE npyv_s32 npyv_divc_s32(npyv_s32 a, const npyv_s32x3 divisor)
{
    const int32x2_t mulc_lo = vget_low_s32(divisor.val[0]);
    // high part of signed multiplication
    int64x2_t mull_lo = vmull_s32(vget_low_s32(a), mulc_lo);
#if NPY_SIMD_F64
    int64x2_t mull_hi = vmull_high_s32(a, divisor.val[0]);
    // get the high signed words
    int32x4_t mulhi = vuzp2q_s32(vreinterpretq_s32_s64(mull_lo), vreinterpretq_s32_s64(mull_hi));
#else
    const int32x2_t mulc_hi = vget_high_s32(divisor.val[0]);
    int64x2_t mull_hi = vmull_s32(vget_high_s32(a), mulc_hi);
    int32x4_t mulhi = vuzpq_s32(vreinterpretq_s32_s64(mull_lo), vreinterpretq_s32_s64(mull_hi)).val[1];
#endif
    // q = ((a + mulhi) >> sh1) - XSIGN(a)
    // trunc(a/d) = (q ^ dsign) - dsign
    int32x4_t q = vshlq_s32(vaddq_s32(a, mulhi), divisor.val[1]);
    q = vsubq_s32(q, vshrq_n_s32(a, 31));
    q = vsubq_s32(veorq_s32(q, divisor.val[2]), divisor.val[2]);
    return q;
}
// divide each unsigned 64-bit element by a divisor
NPY_FINLINE npyv_u64 npyv_divc_u64(npyv_u64 a, const npyv_u64x3 divisor)
{
    const uint64_t d = vgetq_lane_u64(divisor.val[0], 0);
    return npyv_set_u64(vgetq_lane_u64(a, 0) / d, vgetq_lane_u64(a, 1) / d);
}
// divide each signed 64-bit element by a divisor (round towards zero)
NPY_FINLINE npyv_s64 npyv_divc_s64(npyv_s64 a, const npyv_s64x3 divisor)
{
    const int64_t d = vgetq_lane_s64(divisor.val[0], 0);
    return npyv_set_s64(vgetq_lane_s64(a, 0) / d, vgetq_lane_s64(a, 1) / d);
}
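The 64-bit lanes fall back to scalar division because NEON's widening multiplies (vmull/vmull_high) top out at 32x32 -> 64 bits, so there is no cheap way to get the high 64 bits of a 64x64 product; extracting both lanes and dividing in plain C is the pragmatic choice.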
/***************************
 * Division
 ***************************/
@@ -148,7 +296,7 @@
NPY_FINLINE npy_uint32 npyv_sum_u32(npyv_u32 a)
{
    uint32x2_t a0 = vpadd_u32(vget_low_u32(a), vget_high_u32(a)); // re-indented; whitespace-only change (the diff's single deletion)
    return (unsigned)vget_lane_u32(vpadd_u32(a0, vget_high_u32(a)), 0);
}
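End to end, a kernel precomputes the divisor triple once and reuses it across the loop. A hypothetical usage sketch, assuming the npyv_divisor_u32 helper that the referenced simd/intdiv.h provides alongside these intrinsics, plus the usual NPYV load/store names:

#include "simd/simd.h"  // NPYV universal intrinsics; pulls in simd/intdiv.h

// divide n unsigned 32-bit elements by a runtime divisor d;
// assumes n is a multiple of npyv_nlanes_u32 and d != 0
static void divide_all_u32(npy_uint32 *dst, const npy_uint32 *src,
                           npy_intp n, npy_uint32 d)
{
    // precompute multiplier + shifts once; npyv_divc_u32 then replaces
    // per-element division with multiply/shift sequences
    const npyv_u32x3 dv = npyv_divisor_u32(d);
    for (npy_intp i = 0; i < n; i += npyv_nlanes_u32) {
        npyv_u32 a = npyv_load_u32(src + i);
        npyv_store_u32(dst + i, npyv_divc_u32(a, dv));
    }
}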
