py/mpz: Complete implementation of mpz_{and,or,xor} for negative args. · micropython/micropython@2e2e15c · GitHub
[go: up one dir, main page]

Skip to content

Commit 2e2e15c

Browse files
dcurrie authored and dpgeorge committed
py/mpz: Complete implementation of mpz_{and,or,xor} for negative args.
For these 3 bitwise operations there are now fast functions for positive-only arguments, and general functions for arbitrary sign arguments (the fast functions are the existing implementation). By default the fast functions are not used (to save space) and instead the general functions are used for all operations. Enable MICROPY_OPT_MPZ_BITWISE to use the fast functions for positive arguments.
1 parent 5f3e005 commit 2e2e15c

File tree

8 files changed

+592
-90
lines changed

8 files changed

+592
-90
lines changed

py/mpconfig.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -342,6 +342,12 @@
342342
#define MICROPY_OPT_CACHE_MAP_LOOKUP_IN_BYTECODE (0)
343343
#endif
344344

345+
// Whether to use fast versions of bitwise operations (and, or, xor) when the
346+
// arguments are both positive. Increases Thumb2 code size by about 250 bytes.
347+
#ifndef MICROPY_OPT_MPZ_BITWISE
348+
#define MICROPY_OPT_MPZ_BITWISE (0)
349+
#endif
350+
345351
/*****************************************************************************/
346352
/* Python internal features */
347353

py/mpz.c

Lines changed: 199 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,6 @@
2929

3030
#include "py/mpz.h"
3131

32-
// this is only needed for mp_not_implemented, which should eventually be removed
33-
#include "py/runtime.h"
34-
3532
#if MICROPY_LONGINT_IMPL == MICROPY_LONGINT_IMPL_MPZ
3633

3734
#define DIG_SIZE (MPZ_DIG_SIZE)
@@ -199,6 +196,14 @@ STATIC mp_uint_t mpn_sub(mpz_dig_t *idig, const mpz_dig_t *jdig, mp_uint_t jlen,
199196
return idig + 1 - oidig;
200197
}
201198

199+
STATIC mp_uint_t mpn_remove_trailing_zeros(mpz_dig_t *oidig, mpz_dig_t *idig) {
200+
for (--idig; idig >= oidig && *idig == 0; --idig) {
201+
}
202+
return idig + 1 - oidig;
203+
}
204+
205+
#if MICROPY_OPT_MPZ_BITWISE
206+
202207
/* computes i = j & k
203208
returns number of digits in i
204209
assumes enough memory in i; assumes normalised j, k; assumes jlen >= klen (jlen argument not needed)
@@ -211,41 +216,46 @@ STATIC mp_uint_t mpn_and(mpz_dig_t *idig, const mpz_dig_t *jdig, const mpz_dig_t
211216
*idig = *jdig & *kdig;
212217
}
213218

214-
// remove trailing zeros
215-
for (--idig; idig >= oidig && *idig == 0; --idig) {
216-
}
217-
218-
return idig + 1 - oidig;
219+
return mpn_remove_trailing_zeros(oidig, idig);
219220
}
220221

221-
/* computes i = j & -k = j & (~k + 1)
222+
#endif
223+
224+
/* i = -((-j) & (-k)) = ~((~j + 1) & (~k + 1)) + 1
225+
i = (j & (-k)) = (j & (~k + 1)) = ( j & (~k + 1))
226+
i = ((-j) & k) = ((~j + 1) & k) = ((~j + 1) & k )
227+
computes general form:
228+
i = (im ^ (((j ^ jm) + jc) & ((k ^ km) + kc))) + ic where Xm = Xc == 0 ? 0 : DIG_MASK
222229
returns number of digits in i
223-
assumes enough memory in i; assumes normalised j, k
230+
assumes enough memory in i; assumes normalised j, k; assumes length j >= length k
224231
can have i, j, k pointing to same memory
225232
*/
226-
STATIC mp_uint_t mpn_and_neg(mpz_dig_t *idig, const mpz_dig_t *jdig, mp_uint_t jlen, const mpz_dig_t *kdig, mp_uint_t klen) {
233+
STATIC mp_uint_t mpn_and_neg(mpz_dig_t *idig, const mpz_dig_t *jdig, mp_uint_t jlen, const mpz_dig_t *kdig, mp_uint_t klen,
234+
mpz_dbl_dig_t carryi, mpz_dbl_dig_t carryj, mpz_dbl_dig_t carryk) {
227235
mpz_dig_t *oidig = idig;
228-
mpz_dbl_dig_t carry = 1;
236+
mpz_dig_t imask = (0 == carryi) ? 0 : DIG_MASK;
237+
mpz_dig_t jmask = (0 == carryj) ? 0 : DIG_MASK;
238+
mpz_dig_t kmask = (0 == carryk) ? 0 : DIG_MASK;
229239

230-
for (; jlen > 0 && klen > 0; --jlen, --klen, ++idig, ++jdig, ++kdig) {
231-
carry += *kdig ^ DIG_MASK;
232-
*idig = (*jdig & carry) & DIG_MASK;
233-
carry >>= DIG_SIZE;
240+
for (; jlen > 0; ++idig, ++jdig) {
241+
carryj += *jdig ^ jmask;
242+
carryk += (--klen <= --jlen) ? (*kdig++ ^ kmask) : kmask;
243+
carryi += ((carryj & carryk) ^ imask) & DIG_MASK;
244+
*idig = carryi & DIG_MASK;
245+
carryk >>= DIG_SIZE;
246+
carryj >>= DIG_SIZE;
247+
carryi >>= DIG_SIZE;
234248
}
235249

236-
for (; jlen > 0; --jlen, ++idig, ++jdig) {
237-
carry += DIG_MASK;
238-
*idig = (*jdig & carry) & DIG_MASK;
239-
carry >>= DIG_SIZE;
240-
}
241-
242-
// remove trailing zeros
243-
for (--idig; idig >= oidig && *idig == 0; --idig) {
250+
if (0 != carryi) {
251+
*idig++ = carryi;
244252
}
245253

246-
return idig + 1 - oidig;
254+
return mpn_remove_trailing_zeros(oidig, idig);
247255
}
248256

257+
#if MICROPY_OPT_MPZ_BITWISE
258+
249259
/* computes i = j | k
250260
returns number of digits in i
251261
assumes enough memory in i; assumes normalised j, k; assumes jlen >= klen
@@ -267,6 +277,74 @@ STATIC mp_uint_t mpn_or(mpz_dig_t *idig, const mpz_dig_t *jdig, mp_uint_t jlen,
267277
return idig - oidig;
268278
}
269279

280+
#endif
281+
282+
/* i = -((-j) | (-k)) = ~((~j + 1) | (~k + 1)) + 1
283+
i = -(j | (-k)) = -(j | (~k + 1)) = ~( j | (~k + 1)) + 1
284+
i = -((-j) | k) = -((~j + 1) | k) = ~((~j + 1) | k ) + 1
285+
computes general form:
286+
i = ~(((j ^ jm) + jc) | ((k ^ km) + kc)) + 1 where Xm = Xc == 0 ? 0 : DIG_MASK
287+
returns number of digits in i
288+
assumes enough memory in i; assumes normalised j, k; assumes length j >= length k
289+
can have i, j, k pointing to same memory
290+
*/
291+
292+
#if MICROPY_OPT_MPZ_BITWISE
293+
294+
STATIC mp_uint_t mpn_or_neg(mpz_dig_t *idig, const mpz_dig_t *jdig, mp_uint_t jlen, const mpz_dig_t *kdig, mp_uint_t klen,
295+
mpz_dbl_dig_t carryj, mpz_dbl_dig_t carryk) {
296+
mpz_dig_t *oidig = idig;
297+
mpz_dbl_dig_t carryi = 1;
298+
mpz_dig_t jmask = (0 == carryj) ? 0 : DIG_MASK;
299+
mpz_dig_t kmask = (0 == carryk) ? 0 : DIG_MASK;
300+
301+
for (; jlen > 0; ++idig, ++jdig) {
302+
carryj += *jdig ^ jmask;
303+
carryk += (--klen <= --jlen) ? (*kdig++ ^ kmask) : kmask;
304+
carryi += ((carryj | carryk) ^ DIG_MASK) & DIG_MASK;
305+
*idig = carryi & DIG_MASK;
306+
carryk >>= DIG_SIZE;
307+
carryj >>= DIG_SIZE;
308+
carryi >>= DIG_SIZE;
309+
}
310+
311+
if (0 != carryi) {
312+
*idig++ = carryi;
313+
}
314+
315+
return mpn_remove_trailing_zeros(oidig, idig);
316+
}
317+
318+
#else
319+
320+
STATIC mp_uint_t mpn_or_neg(mpz_dig_t *idig, const mpz_dig_t *jdig, mp_uint_t jlen, const mpz_dig_t *kdig, mp_uint_t klen,
321+
mpz_dbl_dig_t carryi, mpz_dbl_dig_t carryj, mpz_dbl_dig_t carryk) {
322+
mpz_dig_t *oidig = idig;
323+
mpz_dig_t imask = (0 == carryi) ? 0 : DIG_MASK;
324+
mpz_dig_t jmask = (0 == carryj) ? 0 : DIG_MASK;
325+
mpz_dig_t kmask = (0 == carryk) ? 0 : DIG_MASK;
326+
327+
for (; jlen > 0; ++idig, ++jdig) {
328+
carryj += *jdig ^ jmask;
329+
carryk += (--klen <= --jlen) ? (*kdig++ ^ kmask) : kmask;
330+
carryi += ((carryj | carryk) ^ imask) & DIG_MASK;
331+
*idig = carryi & DIG_MASK;
332+
carryk >>= DIG_SIZE;
333+
carryj >>= DIG_SIZE;
334+
carryi >>= DIG_SIZE;
335+
}
336+
337+
if (0 != carryi) {
338+
*idig++ = carryi;
339+
}
340+
341+
return mpn_remove_trailing_zeros(oidig, idig);
342+
}
343+
344+
#endif
345+
346+
#if MICROPY_OPT_MPZ_BITWISE
347+
270348
/* computes i = j ^ k
271349
returns number of digits in i
272350
assumes enough memory in i; assumes normalised j, k; assumes jlen >= klen
@@ -285,11 +363,39 @@ STATIC mp_uint_t mpn_xor(mpz_dig_t *idig, const mpz_dig_t *jdig, mp_uint_t jlen,
285363
*idig = *jdig;
286364
}
287365

288-
// remove trailing zeros
289-
for (--idig; idig >= oidig && *idig == 0; --idig) {
366+
return mpn_remove_trailing_zeros(oidig, idig);
367+
}
368+
369+
#endif
370+
371+
/* i = (-j) ^ (-k) = ~(j - 1) ^ ~(k - 1) = (j - 1) ^ (k - 1)
372+
i = -(j ^ (-k)) = -(j ^ ~(k - 1)) = ~(j ^ ~(k - 1)) + 1 = (j ^ (k - 1)) + 1
373+
i = -((-j) ^ k) = -(~(j - 1) ^ k) = ~(~(j - 1) ^ k) + 1 = ((j - 1) ^ k) + 1
374+
computes general form:
375+
i = ((j - 1 + jc) ^ (k - 1 + kc)) + ic
376+
returns number of digits in i
377+
assumes enough memory in i; assumes normalised j, k; assumes length j >= length k
378+
can have i, j, k pointing to same memory
379+
*/
380+
STATIC mp_uint_t mpn_xor_neg(mpz_dig_t *idig, const mpz_dig_t *jdig, mp_uint_t jlen, const mpz_dig_t *kdig, mp_uint_t klen,
381+
mpz_dbl_dig_t carryi, mpz_dbl_dig_t carryj, mpz_dbl_dig_t carryk) {
382+
mpz_dig_t *oidig = idig;
383+
384+
for (; jlen > 0; ++idig, ++jdig) {
385+
carryj += *jdig + DIG_MASK;
386+
carryk += (--klen <= --jlen) ? (*kdig++ + DIG_MASK) : DIG_MASK;
387+
carryi += (carryj ^ carryk) & DIG_MASK;
388+
*idig = carryi & DIG_MASK;
389+
carryk >>= DIG_SIZE;
390+
carryj >>= DIG_SIZE;
391+
carryi >>= DIG_SIZE;
290392
}
291393

292-
return idig + 1 - oidig;
394+
if (0 != carryi) {
395+
*idig++ = carryi;
396+
}
397+
398+
return mpn_remove_trailing_zeros(oidig, idig);
293399
}
294400

295401
/* computes i = i * d1 + d2
@@ -1097,81 +1203,106 @@ void mpz_sub_inpl(mpz_t *dest, const mpz_t *lhs, const mpz_t *rhs) {
10971203
can have dest, lhs, rhs the same
10981204
*/
10991205
void mpz_and_inpl(mpz_t *dest, const mpz_t *lhs, const mpz_t *rhs) {
1100-
if (lhs->neg == rhs->neg) {
1101-
if (lhs->neg == 0) {
1102-
// make sure lhs has the most digits
1103-
if (lhs->len < rhs->len) {
1104-
const mpz_t *temp = lhs;
1105-
lhs = rhs;
1106-
rhs = temp;
1107-
}
1108-
// do the and'ing
1109-
mpz_need_dig(dest, rhs->len);
1110-
dest->len = mpn_and(dest->dig, lhs->dig, rhs->dig, rhs->len);
1111-
dest->neg = 0;
1112-
} else {
1113-
// TODO both args are negative
1114-
mp_not_implemented("bignum and with negative args");
1115-
}
1206+
// make sure lhs has the most digits
1207+
if (lhs->len < rhs->len) {
1208+
const mpz_t *temp = lhs;
1209+
lhs = rhs;
1210+
rhs = temp;
1211+
}
1212+
1213+
#if MICROPY_OPT_MPZ_BITWISE
1214+
1215+
if ((0 == lhs->neg) && (0 == rhs->neg)) {
1216+
mpz_need_dig(dest, lhs->len);
1217+
dest->len = mpn_and(dest->dig, lhs->dig, rhs->dig, rhs->len);
1218+
dest->neg = 0;
11161219
} else {
1117-
// args have different sign
1118-
// make sure lhs is the positive arg
1119-
if (rhs->neg == 0) {
1120-
const mpz_t *temp = lhs;
1121-
lhs = rhs;
1122-
rhs = temp;
1123-
}
11241220
mpz_need_dig(dest, lhs->len + 1);
1125-
dest->len = mpn_and_neg(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len);
1126-
assert(dest->len <= dest->alloc);
1127-
dest->neg = 0;
1221+
dest->len = mpn_and_neg(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len,
1222+
lhs->neg == rhs->neg, 0 != lhs->neg, 0 != rhs->neg);
1223+
dest->neg = lhs->neg & rhs->neg;
11281224
}
1225+
1226+
#else
1227+
1228+
mpz_need_dig(dest, lhs->len + (lhs->neg || rhs->neg));
1229+
dest->len = mpn_and_neg(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len,
1230+
(lhs->neg == rhs->neg) ? lhs->neg : 0, lhs->neg, rhs->neg);
1231+
dest->neg = lhs->neg & rhs->neg;
1232+
1233+
#endif
11291234
}
11301235

11311236
/* computes dest = lhs | rhs
11321237
can have dest, lhs, rhs the same
11331238
*/
11341239
void mpz_or_inpl(mpz_t *dest, const mpz_t *lhs, const mpz_t *rhs) {
1135-
if (mpn_cmp(lhs->dig, lhs->len, rhs->dig, rhs->len) < 0) {
1240+
// make sure lhs has the most digits
1241+
if (lhs->len < rhs->len) {
11361242
const mpz_t *temp = lhs;
11371243
lhs = rhs;
11381244
rhs = temp;
11391245
}
11401246

1141-
if (lhs->neg == rhs->neg) {
1247+
#if MICROPY_OPT_MPZ_BITWISE
1248+
1249+
if ((0 == lhs->neg) && (0 == rhs->neg)) {
11421250
mpz_need_dig(dest, lhs->len);
11431251
dest->len = mpn_or(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len);
1252+
dest->neg = 0;
11441253
} else {
1145-
mpz_need_dig(dest, lhs->len);
1146-
// TODO
1147-
mp_not_implemented("bignum or with negative args");
1148-
// dest->len = mpn_or_neg(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len);
1254+
mpz_need_dig(dest, lhs->len + 1);
1255+
dest->len = mpn_or_neg(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len,
1256+
0 != lhs->neg, 0 != rhs->neg);
1257+
dest->neg = 1;
11491258
}
11501259

1151-
dest->neg = lhs->neg;
1260+
#else
1261+
1262+
mpz_need_dig(dest, lhs->len + (lhs->neg || rhs->neg));
1263+
dest->len = mpn_or_neg(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len,
1264+
(lhs->neg || rhs->neg), lhs->neg, rhs->neg);
1265+
dest->neg = lhs->neg | rhs->neg;
1266+
1267+
#endif
11521268
}
11531269

11541270
/* computes dest = lhs ^ rhs
11551271
can have dest, lhs, rhs the same
11561272
*/
11571273
void mpz_xor_inpl(mpz_t *dest, const mpz_t *lhs, const mpz_t *rhs) {
1158-
if (mpn_cmp(lhs->dig, lhs->len, rhs->dig, rhs->len) < 0) {
1274+
// make sure lhs has the most digits
1275+
if (lhs->len < rhs->len) {
11591276
const mpz_t *temp = lhs;
11601277
lhs = rhs;
11611278
rhs = temp;
11621279
}
11631280

1281+
#if MICROPY_OPT_MPZ_BITWISE
1282+
11641283
if (lhs->neg == rhs->neg) {
11651284
mpz_need_dig(dest, lhs->len);
1166-
dest->len = mpn_xor(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len);
1285+
if (lhs->neg == 0) {
1286+
dest->len = mpn_xor(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len);
1287+
} else {
1288+
dest->len = mpn_xor_neg(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len, 0, 0, 0);
1289+
}
1290+
dest->neg = 0;
11671291
} else {
1168-
mpz_need_dig(dest, lhs->len);
1169-
// TODO
1170-
mp_not_implemented("bignum xor with negative args");
1171-
// dest->len = mpn_xor_neg(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len);
1292+
mpz_need_dig(dest, lhs->len + 1);
1293+
dest->len = mpn_xor_neg(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len, 1,
1294+
0 == lhs->neg, 0 == rhs->neg);
1295+
dest->neg = 1;
11721296
}
11731297

1174-
dest->neg = 0;
1298+
#else
1299+
1300+
mpz_need_dig(dest, lhs->len + (lhs->neg || rhs->neg));
1301+
dest->len = mpn_xor_neg(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len,
1302+
(lhs->neg != rhs->neg), 0 == lhs->neg, 0 == rhs->neg);
1303+
dest->neg = lhs->neg ^ rhs->neg;
1304+
1305+
#endif
11751306
}
11761307

11771308
/* computes dest = lhs * rhs

stmhal/mpconfigport.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
#define MICROPY_FLOAT_IMPL (MICROPY_FLOAT_IMPL_FLOAT)
4747
#define MICROPY_OPT_COMPUTED_GOTO (1)
4848
#define MICROPY_OPT_CACHE_MAP_LOOKUP_IN_BYTECODE (0)
49+
#define MICROPY_OPT_MPZ_BITWISE (1)
4950

5051
// fatfs configuration used in ffconf.h
5152
#define MICROPY_FATFS_ENABLE_LFN (1)

0 commit comments

Comments
 (0)
0