8000 py/mpz: Fix overflow of borrow in mpn_div. · micropython/micropython@97487e6 · GitHub
[go: up one dir, main page]

Skip to content

Commit 97487e6

Browse files
committed
py/mpz: Fix overflow of borrow in mpn_div.
Signed-off-by: Damien George <damien@micropython.org>
1 parent c891190 commit 97487e6

File tree

2 files changed

+20
-39
lines changed

2 files changed

+20
-39
lines changed

py/mpz.c

Lines changed: 16 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -531,60 +531,37 @@ STATIC void mpn_div(mpz_dig_t *num_dig, size_t *num_len, const mpz_dig_t *den_di
531531
quo /= lead_den_digit;
532532

533533
// Multiply quo by den and subtract from num to get remainder.
534-
// We have different code here to handle different compile-time
535-
// configurations of mpz:
536-
//
537-
// 1. DIG_SIZE is stricly less than half the number of bits
538-
// available in mpz_dbl_dig_t. In this case we can use a
539-
// slightly more optimal (in time and space) routine that
540-
// uses the extra bits in mpz_dbl_dig_signed_t to store a
541-
// sign bit.
542-
//
543-
// 2. DIG_SIZE is exactly half the number of bits available in
544-
// mpz_dbl_dig_t. In this (common) case we need to be careful
545-
// not to overflow the borrow variable. And the shifting of
546-
// borrow needs some special logic (it's a shift right with
547-
// round up).
548-
//
534+
// Must be careful with overflow of the borrow variable. Both
535+
// borrow and low_digs are signed values and need signed right-shift,
536+
// but x is unsigned and may take a full-range value.
549537
const mpz_dig_t *d = den_dig;
550538
mpz_dbl_dig_t d_norm = 0;
551-
mpz_dbl_dig_t borrow = 0;
539+
mpz_dbl_dig_signed_t borrow = 0;
552540
for (mpz_dig_t *n = num_dig - den_len; n < num_dig; ++n, ++d) {
541+
// Get the next digit in (den).
553542
d_norm = ((mpz_dbl_dig_t)*d << norm_shift) | (d_norm >> DIG_SIZE);
543+
// Multiply the next digit in (quo * den).
554544
mpz_dbl_dig_t x = (mpz_dbl_dig_t)quo * (d_norm & DIG_MASK);
555-
#if DIG_SIZE < MPZ_DBL_DIG_SIZE / 2
556-
borrow += (mpz_dbl_dig_t)*n - x; // will overflow if DIG_SIZE >= MPZ_DBL_DIG_SIZE/2
557-
*n = borrow & DIG_MASK;
558-
borrow = (mpz_dbl_dig_signed_t)borrow >> DIG_SIZE;
559-
#else // DIG_SIZE == MPZ_DBL_DIG_SIZE / 2
560-
if (x >= *n || *n - x <= borrow) {
561-
borrow += x - (mpz_dbl_dig_t)*n;
562-
*n = (-borrow) & DIG_MASK;
563-
borrow = (borrow >> DIG_SIZE) + ((borrow & DIG_MASK) == 0 ? 0 : 1); // shift-right with round-up
564-
} else {
565-
*n = ((mpz_dbl_dig_t)*n - x - borrow) & DIG_MASK;
566-
borrow = 0;
567-
}
568-
#endif
545+
// Compute the low DIG_MASK bits of the next digit in (num - quo * den)
546+
mpz_dbl_dig_signed_t low_digs = (borrow & DIG_MASK) + *n - (x & DIG_MASK);
547+
// Store the digit result for (num).
548+
*n = low_digs & DIG_MASK;
549+
// Compute the borrow, shifted right before summing to avoid overflow.
550+
borrow = (borrow >> DIG_SIZE) - (x >> DIG_SIZE) + (low_digs >> DIG_SIZE);
569551
}
570552

571-
#if DIG_SIZE < MPZ_DBL_DIG_SIZE / 2
572-
// Borrow was negative in the above for-loop, make it positive for next if-block.
573-
borrow = -borrow;
574-
#endif
575-
576553
// At this point we have either:
577554
//
578555
// 1. quo was the correct value and the most-sig-digit of num is exactly
579-
// cancelled by borrow (borrow == *num_dig). In this case there is
556+
// cancelled by borrow (borrow + *num_dig == 0). In this case there is
580557
// nothing more to do.
581558
//
582559
// 2. quo was too large, we subtracted too many den from num, and the
583-
// most-sig-digit of num is 1 less than borrow (borrow == *num_dig + 1).
560+
// most-sig-digit of num is less than needed (borrow + *num_dig < 0).
584561
// In this case we must reduce quo and add back den to num until the
585562
// carry from this operation cancels out the borrow.
586563
//
587-
borrow -= *num_dig;
564+
borrow += *num_dig;
588565
for (; borrow != 0; --quo) {
589566
d = den_dig;
590567
d_norm = 0;
@@ -595,7 +572,7 @@ STATIC void mpn_div(mpz_dig_t *num_dig, size_t *num_len, const mpz_dig_t *den_di
595572
*n = carry & DIG_MASK;
596573
carry >>= DIG_SIZE;
597574
}
598-
borrow -= carry;
575+
borrow += carry;
599576
}
600577

601578
// store this digit of the quotient

tests/basics/int_big_div.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,7 @@
88
print((x + 1) // x)
99
x = 0x86c60128feff5330
1010
print((x + 1) // x)
11+
12+
# these check edge cases where borrow overflows
13+
print((2 ** 48 - 1) ** 2 // (2 ** 48 - 1))
14+
print((2 ** 256 - 2 ** 32) ** 2 // (2 ** 256 - 2 ** 32))

0 commit comments

Comments
 (0)
0