8000 py/mpz: Fix bugs with bitwise of -0 by ensuring all 0's are positive. · DvdGiessen/micropython@2c139bb · GitHub
[go: up one dir, main page]

Skip to content

Commit 2c139bb

Browse files
committed
py/mpz: Fix bugs with bitwise of -0 by ensuring all 0's are positive.
This commit makes sure that the value zero is always encoded in an mpz_t as neg=0 and len=0 (previously it was just len=0). This invariant is needed for some of the bitwise operations that operate on negative numbers, because they cannot handle -0. For example (-((1<<100)-(1<<100)))|1 was being computed as -65535, instead of 1. Fixes issue micropython#8042. Signed-off-by: Damien George <damien@micropython.org>
1 parent 05bea70 commit 2c139bb

File tree

3 files changed

+52
-11
lines changed

3 files changed

+52
-11
lines changed

py/mpz.c

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -713,6 +713,7 @@ void mpz_set(mpz_t *dest, const mpz_t *src) {
713713

714714
void mpz_set_from_int(mpz_t *z, mp_int_t val) {
715715
if (val == 0) {
716+
z->neg = 0;
716717
z->len = 0;
717718
return;
718719
}
@@ -899,10 +900,6 @@ bool mpz_is_even(const mpz_t *z) {
899900
#endif
900901

901902
int mpz_cmp(const mpz_t *z1, const mpz_t *z2) {
902-
// to catch comparison of -0 with +0
903-
if (z1->len == 0 && z2->len == 0) {
904-
return 0;
905-
}
906903
int cmp = (int)z2->neg - (int)z1->neg;
907904
if (cmp != 0) {
908905
return cmp;
@@ -1052,7 +1049,9 @@ void mpz_neg_inpl(mpz_t *dest, const mpz_t *z) {
10521049
if (dest != z) {
10531050
mpz_set(dest, z);
10541051
}
1055-
dest->neg = 1 - dest->neg;
1052+
if (dest->len) {
1053+
dest->neg = 1 - dest->neg;
1054+
}
10561055
}
10571056

10581057
/* computes dest = ~z (= -z - 1)
@@ -1148,7 +1147,7 @@ void mpz_add_inpl(mpz_t *dest, const mpz_t *lhs, const mpz_t *rhs) {
11481147
dest->len = mpn_sub(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len);
11491148
}
11501149

1151-
dest->neg = lhs->neg;
1150+
dest->neg = lhs->neg & !!dest->len;
11521151
}
11531152

11541153
/* computes dest = lhs - rhs
@@ -1172,7 +1171,9 @@ void mpz_sub_inpl(mpz_t *dest, const mpz_t *lhs, const mpz_t *rhs) {
11721171
dest->len = mpn_sub(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len);
11731172
}
11741173

1175-
if (neg) {
1174+
if (dest->len == 0) {
1175+
dest->neg = 0;
1176+
} else if (neg) {
11761177
dest->neg = 1 - lhs->neg;
11771178
} else {
11781179
dest->neg = lhs->neg;
@@ -1484,14 +1485,16 @@ void mpz_divmod_inpl(mpz_t *dest_quo, mpz_t *dest_rem, const mpz_t *lhs, const m
14841485

14851486
mpz_need_dig(dest_quo, lhs->len + 1); // +1 necessary?
14861487
memset(dest_quo->dig, 0, (lhs->len + 1) * sizeof(mpz_dig_t));
1488+
dest_quo->neg = 0;
14871489
dest_quo->len = 0;
14881490
mpz_need_dig(dest_rem, lhs->len + 1); // +1 necessary?
14891491
mpz_set(dest_rem, lhs);
14901492
mpn_div(dest_rem->dig, &dest_rem->len, rhs->dig, rhs->len, dest_quo->dig, &dest_quo->len);
1493+
dest_rem->neg &= !!dest_rem->len;
14911494

14921495
// check signs and do Python style modulo
14931496
if (lhs->neg != rhs->neg) {
1494-
dest_quo->neg = 1;
1497+
dest_quo->neg = !!dest_quo->len;
14951498
if (!mpz_is_zero(dest_rem)) {
14961499
mpz_t mpzone;
14971500
mpz_init_from_int(&mpzone, -1);

py/mpz.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ typedef int8_t mpz_dbl_dig_signed_t;
9191
#define MPZ_NUM_DIG_FOR_LL ((sizeof(long long) * 8 + MPZ_DIG_SIZE - 1) / MPZ_DIG_SIZE)
9292

9393
typedef struct _mpz_t {
94+
// Zero has neg=0, len=0. Negative zero is not allowed.
9495
size_t neg : 1;
9596
size_t fixed_dig : 1;
9697
size_t alloc : (8 * sizeof(size_t) - 2);
@@ -119,7 +120,7 @@ static inline bool mpz_is_zero(const mpz_t *z) {
119120
return z->len == 0;
120121
}
121122
static inline bool mpz_is_neg(const mpz_t *z) {
122-
return z->len != 0 && z->neg != 0;
123+
return z->neg != 0;
123124
}
124125
int mpz_cmp(const mpz_t *lhs, const mpz_t *rhs);
125126

tests/basics/int_big_zeroone.py

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# test [0,-0,1,-1] edge cases of bignum
1+
# test [0,1,-1] edge cases of bignum
22

33
long_zero = (2**64) >> 65
44
long_neg_zero = -long_zero
@@ -13,7 +13,7 @@
1313
print([c >> 1 for c in cases])
1414
print([c << 1 for c in cases])
1515

16-
# comparison of 0/-0/+0
16+
# comparison of 0
1717
print(long_zero == 0)
1818
print(long_neg_zero == 0)
1919
print(long_one - 1 == 0)
@@ -26,3 +26,40 @@
2626
print(long_neg_zero < -1)
2727
print(long_neg_zero > 1)
2828
print(long_neg_zero > -1)
29+
30+
# generate zeros that involve negative numbers
31+
large = 1 << 70
32+
large_plus_one = large + 1
33+
zeros = (
34+
large - large,
35+
-large + large,
36+
large + -large,
37+
-(large - large),
38+
large - large_plus_one + 1,
39+
-large & (large - large),
40+
-large ^ -large,
41+
-large * (large - large),
42+
(large - large) // -large,
43+
-large // -large_plus_one,
44+
-(large + large) % large,
45+
(large + large) % -large,
46+
-(large + large) % -large,
47+
)
48+
print(zeros)
49+
50+
# compute arithmetic operations that may have problems with -0
51+
# (this checks that -0 is never generated in the zeros tuple)
52+
cases = (0, 1, -1) + zeros
53+
for lhs in cases:
54+
print("-{} = {}".format(lhs, -lhs))
55+
print("~{} = {}".format(lhs, ~lhs))
56+
print("{} >> 1 = {}".format(lhs, lhs >> 1))
57+
print("{} << 1 = {}".format(lhs, lhs << 1))
58+
for rhs in cases:
59+
print("{} == {} = {}".format(lhs, rhs, lhs == rhs))
60+
print("{} + {} = {}".format(lhs, rhs, lhs + rhs))
61+
print("{} - {} = {}".format(lhs, rhs, lhs - rhs))
62+
print("{} * {} = {}".format(lhs, rhs, lhs * rhs))
63+
print("{} | {} = {}".format(lhs, rhs, lhs | rhs))
64+
print("{} & {} = {}".format(lhs, rhs, lhs & rhs))
65+
print("{} ^ {} = {}".format(lhs, rhs, lhs ^ rhs))

0 commit comments

Comments
 (0)
0