8000 py/mpz: Fix overflow of borrow in mpn_div. · micropython/micropython@0a59938 · GitHub
[go: up one dir, main page]

Skip to content

Commit 0a59938

Browse files
committed
py/mpz: Fix overflow of borrow in mpn_div.
For certain operands to mpn_div, the existing code path for `DIG_SIZE == MPZ_DBL_DIG_SIZE / 2` had a bug in it where borrow could still overflow in the `(x >= *n || *n - x <= borrow)` branch, ie `borrow + x - (mpz_dbl_dig_t)*n` overflows the borrow variable. In such cases the subsequent right-shift of borrow would not bring in the overflow bit, leading to an error in the result. An example division that had overflow when MPZ_DIG_SIZE = 16 is `(2 ** 48 - 1) ** 2 // (2 ** 48 - 1)`. This is fixed in this commit by simplifying the code and handling the low digits of borrow first, and then the upper bits (to shift down) separately. There is no longer a distinction between `DIG_SIZE < MPZ_DBL_DIG_SIZE / 2` and `DIG_SIZE == MPZ_DBL_DIG_SIZE / 2`. This commit also simplifies the second part of the calculation so that borrow does not need to be negated (instead the code just works knowing that borrow is negative and using + instead of - in calculations involving borrow). Fixes #6777. Signed-off-by: Damien George <damien@micropython.org>
1 parent 9dedcf1 commit 0a59938

File tree

2 files changed

+20
-39
lines changed

2 files changed

+20
-39
lines changed

py/mpz.c

Lines changed: 16 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -531,60 +531,37 @@ STATIC void mpn_div(mpz_dig_t *num_dig, size_t *num_len, const mpz_dig_t *den_di
531531
quo /= lead_den_digit;
532532

533533
// Multiply quo by den and subtract from num to get remainder.
534-
// We have different code here to handle different compile-time
535-
// configurations of mpz:
536-
//
537-
// 1. DIG_SIZE is stricly less than half the number of bits
538-
// available in mpz_dbl_dig_t. In this case we can use a
539-
// slightly more optimal (in time and space) routine that
540-
// uses the extra bits in mpz_dbl_dig_signed_t to store a
541-
// sign bit.
542-
//
543-
// 2. DIG_SIZE is exactly half the number of bits available in
544-
// mpz_dbl_dig_t. In this (common) case we need to be careful
545-
// not to overflow the borrow variable. And the shifting of
546-
// borrow needs some special logic (it's a shift right with
547-
// round up).
548-
//
534+
// Must be careful with overflow of the borrow variable. Both
535+
// borrow and low_digs are signed values and need signed right-shift,
536+
// but x is unsigned and may take a full-range value.
549537
const mpz_dig_t *d = den_dig;
550538
mpz_dbl_dig_t d_norm = 0;
551-
mpz_dbl_dig_t borrow = 0;
539+
mpz_dbl_dig_signed_t borrow = 0;
552540
for (mpz_dig_t *n = num_dig - den_len; n < num_dig; ++n, ++d) {
541+
// Get the next digit in (den).
553542
d_norm = ((mpz_dbl_dig_t)*d << norm_shift) | (d_norm >> DIG_SIZE);
543+
// Multiply the next digit in (quo * den).
554544
mpz_dbl_dig_t x = (mpz_dbl_dig_t)quo * (d_norm & DIG_MASK);
555-
#if DIG_SIZE < MPZ_DBL_DIG_SIZE / 2
556-
borrow += (mpz_dbl_dig_t)*n - x; // will overflow if DIG_SIZE >= MPZ_DBL_DIG_SIZE/2
557-
*n = borrow & DIG_MASK;
558-
borrow = (mpz_dbl_dig_signed_t)borrow >> DIG_SIZE;
559-
#else // DIG_SIZE == MPZ_DBL_DIG_SIZE / 2
560-
if (x >= *n || *n - x <= borrow) {
561-
borrow += x - (mpz_dbl_dig_t)*n;
562-
*n = (-borrow) & DIG_MASK;
563-
borrow = (borrow >> DIG_SIZE) + ((borrow & DIG_MASK) == 0 ? 0 : 1); // shift-right with round-up
564-
} else {
565-
*n = ((mpz_dbl_dig_t)*n - x - borrow) & DIG_MASK;
566-
borrow = 0;
567-
}
568-
#endif
545+
// Compute the low DIG_MASK bits of the next digit in (num - quo * den)
546+
mpz_dbl_dig_signed_t low_digs = (borrow & DIG_MASK) + *n - (x & DIG_MASK);
547+
// Store the digit result for (num).
548+
*n = low_digs & DIG_MASK;
549+
// Compute the borrow, shifted right before summing to avoid overflow.
550+
borrow = (borrow >> DIG_SIZE) - (x >> DIG_SIZE) + (low_digs >> DIG_SIZE);
569551
}
570552

571-
#if DIG_SIZE < MPZ_DBL_DIG_SIZE / 2
572-
// Borrow was negative in the above for-loop, make it positive for next if-block.
573-
borrow = -borrow;
574-
#endif
575-
576553
// At this point we have either:
577554
//
578555
// 1. quo was the correct value and the most-sig-digit of num is exactly
579-
// cancelled by borrow (borrow == *num_dig). In this case there is
556+
// cancelled by borrow (borrow + *num_dig == 0). In this case there is
580557
// nothing more to do.
581558
//
582559
// 2. quo was too large, we subtracted too many den from num, and the
583-
// most-sig-digit of num is 1 less than borrow (borrow == *num_dig + 1).
560+
// most-sig-digit of num is less than needed (borrow + *num_dig < 0).
584561
// In this case we must reduce quo and add back den to num until the
585562
// carry from this operation cancels out the borrow.
586563
//
587-
borrow -= *num_dig;
564+
borrow += *num_dig;
588565
for (; borrow != 0; --quo) {
589566
d = den_dig;
590567
d_norm = 0;
@@ -595,7 +572,7 @@ STATIC void mpn_div(mpz_dig_t *num_dig, size_t *num_len, const mpz_dig_t *den_di
595572
*n = carry & DIG_MASK;
596573
carry >>= DIG_SIZE;
597574
}
598-
borrow -= carry;
575+
borrow += carry;
599576
}
600577

601578
// store this digit of the quotient

tests/basics/int_big_div.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,7 @@
88
print((x + 1) // x)
99
x = 0x86c60128feff5330
1010
print((x + 1) // x)
11+
12+
# these check edge cases where borrow overflows
13+
print((2 ** 48 - 1) ** 2 // (2 ** 48 - 1))
14+
print((2 ** 256 - 2 ** 32) ** 2 // (2 ** 256 - 2 ** 32))

0 commit comments

Comments
 (0)
0