py/mpz: Simplify handling of borrow and quo adjustment in mpn_div. · neverhover/micropython@9766fdd · GitHub
[go: up one dir, main page]

Skip to content

Commit 9766fdd

Browse files
committed
py/mpz: Simplify handling of borrow and quo adjustment in mpn_div.
The motivation behind this patch is to remove unreachable code in mpn_div. This unreachable code was added some time ago in 9a21d2e, when a loop in mpn_div was copied and adjusted to work when mpz_dig_t was exactly half of the size of mpz_dbl_dig_t (a common case). The loop was copied correctly but it wasn't noticed at the time that the final part of the calculation of num-quo*den could be optimised, and hence unreachable code was left for a case that never occurred. The observation for the optimisation is that the initial value of quo in mpn_div is either exact or too large (never too small), and therefore the subtraction of quo*den from num may subtract exactly enough or too much (but never too little). Using this observation the part of the algorithm that handles the borrow value can be simplified, and most importantly this eliminates the unreachable code. The new code has been tested with DIG_SIZE=3 and DIG_SIZE=4 by dividing all possible combinations of non-negative integers with between 0 and 3 (inclusive) mpz digits.
1 parent c7cb1df commit 9766fdd

File tree

2 files changed

+48
-70
lines changed

2 files changed

+48
-70
lines changed

py/mpz.c

Lines changed: 44 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -537,83 +537,57 @@ STATIC void mpn_div(mpz_dig_t *num_dig, size_t *num_len, const mpz_dig_t *den_di
537537
// not to overflow the borrow variable. And the shifting of
538538
// borrow needs some special logic (it's a shift right with
539539
// round up).
540-
541-
if (DIG_SIZE < 8 * sizeof(mpz_dbl_dig_t) / 2) {
542-
const mpz_dig_t *d = den_dig;
543-
mpz_dbl_dig_t d_norm = 0;
544-
mpz_dbl_dig_signed_t borrow = 0;
545-
546-
for (mpz_dig_t *n = num_dig - den_len; n < num_dig; ++n, ++d) {
547-
d_norm = ((mpz_dbl_dig_t)*d << norm_shift) | (d_norm >> DIG_SIZE);
548-
borrow += (mpz_dbl_dig_t)*n - (mpz_dbl_dig_t)quo * (d_norm & DIG_MASK); // will overflow if DIG_SIZE >= 8*sizeof(mpz_dbl_dig_t)/2
549-
*n = borrow & DIG_MASK;
550-
borrow >>= DIG_SIZE;
551-
}
552-
borrow += *num_dig; // will overflow if DIG_SIZE >= 8*sizeof(mpz_dbl_dig_t)/2
553-
*num_dig = borrow & DIG_MASK;
554-
borrow >>= DIG_SIZE;
555-
556-
// adjust quotient if it is too big
557-
for (; borrow != 0; --quo) {
558-
d = den_dig;
559-
d_norm = 0;
560-
mpz_dbl_dig_t carry = 0;
561-
for (mpz_dig_t *n = num_dig - den_len; n < num_dig; ++n, ++d) {
562-
d_norm = ((mpz_dbl_dig_t)*d << norm_shift) | (d_norm >> DIG_SIZE);
563-
carry += (mpz_dbl_dig_t)*n + (d_norm & DIG_MASK);
564-
*n = carry & DIG_MASK;
565-
carry >>= DIG_SIZE;
566-
}
567-
carry += *num_dig;
568-
*num_dig = carry & DIG_MASK;
569-
carry >>= DIG_SIZE;
570-
571-
borrow += carry;
572-
}
573-
} else { // DIG_SIZE == 8 * sizeof(mpz_dbl_dig_t) / 2
574-
const mpz_dig_t *d = den_dig;
575-
mpz_dbl_dig_t d_norm = 0;
576-
mpz_dbl_dig_t borrow = 0;
577-
578-
for (mpz_dig_t *n = num_dig - den_len; n < num_dig; ++n, ++d) {
579-
d_norm = ((mpz_dbl_dig_t)*d << norm_shift) | (d_norm >> DIG_SIZE);
580-
mpz_dbl_dig_t x = (mpz_dbl_dig_t)quo * (d_norm & DIG_MASK);
581-
if (x >= *n || *n - x <= borrow) {
582-
borrow += (mpz_dbl_dig_t)x - (mpz_dbl_dig_t)*n;
583-
*n = (-borrow) & DIG_MASK;
584-
borrow = (borrow >> DIG_SIZE) + ((borrow & DIG_MASK) == 0 ? 0 : 1); // shift-right with round-up
585-
} else {
586-
*n = ((mpz_dbl_dig_t)*n - (mpz_dbl_dig_t)x - (mpz_dbl_dig_t)borrow) & DIG_MASK;
587-
borrow = 0;
588-
}
589-
}
590-
if (borrow >= *num_dig) {
591-
borrow -= (mpz_dbl_dig_t)*num_dig;
592-
*num_dig = (-borrow) & DIG_MASK;
540+
//
541+
const mpz_dig_t *d = den_dig;
542+
mpz_dbl_dig_t d_norm = 0;
543+
mpz_dbl_dig_t borrow = 0;
544+
for (mpz_dig_t *n = num_dig - den_len; n < num_dig; ++n, ++d) {
545+
d_norm = ((mpz_dbl_dig_t)*d << norm_shift) | (d_norm >> DIG_SIZE);
546+
mpz_dbl_dig_t x = (mpz_dbl_dig_t)quo * (d_norm & DIG_MASK);
547+
#if DIG_SIZE < MPZ_DBL_DIG_SIZE / 2
548+
borrow += (mpz_dbl_dig_t)*n - x; // will overflow if DIG_SIZE >= MPZ_DBL_DIG_SIZE/2
549+
*n = borrow & DIG_MASK;
550+
borrow = (mpz_dbl_dig_signed_t)borrow >> DIG_SIZE;
551+
#else // DIG_SIZE == MPZ_DBL_DIG_SIZE / 2
552+
if (x >= *n || *n - x <= borrow) {
553+
borrow += x - (mpz_dbl_dig_t)*n;
554+
*n = (-borrow) & DIG_MASK;
593555
borrow = (borrow >> DIG_SIZE) + ((borrow & DIG_MASK) == 0 ? 0 : 1); // shift-right with round-up
594556
} else {
595-
*num_dig = (*num_dig - borrow) & DIG_MASK;
557+
*n = ((mpz_dbl_dig_t)*n - x - borrow) & DIG_MASK;
596558
borrow = 0;
597559
}
560+
#endif
561+
}
598562

599-
// adjust quotient if it is too big
600-
for (; borrow != 0; --quo) {
601-
d = den_dig;
602-
d_norm = 0;
603-
mpz_dbl_dig_t carry = 0;
604-
for (mpz_dig_t *n = num_dig - den_len; n < num_dig; ++n, ++d) {
605-
d_norm = ((mpz_dbl_dig_t)*d << norm_shift) | (d_norm >> DIG_SIZE);
606-
carry += (mpz_dbl_dig_t)*n + (d_norm & DIG_MASK);
607-
*n = carry & DIG_MASK;
608-
carry >>= DIG_SIZE;
609-
}
610-
carry += (mpz_dbl_dig_t)*num_dig;
611-
*num_dig = carry & DIG_MASK;
612-
carry >>= DIG_SIZE;
563+
#if DIG_SIZE < MPZ_DBL_DIG_SIZE / 2
564+
// Borrow was negative in the above for-loop, make it positive for next if-block.
565+
borrow = -borrow;
566+
#endif
613567

614-
//assert(borrow >= carry); // enable this to check the logic
615-
borrow -= carry;
568+
// At this point we have either:
569+
//
570+
// 1. quo was the correct value and the most-sig-digit of num is exactly
571+
// cancelled by borrow (borrow == *num_dig). In this case there is
572+
// nothing more to do.
573+
//
574+
// 2. quo was too large, we subtracted too many den from num, and the
575+
// most-sig-digit of num is 1 less than borrow (borrow == *num_dig + 1).
576+
// In this case we must reduce quo and add back den to num until the
577+
// carry from this operation cancels out the borrow.
578+
//
579+
borrow -= *num_dig;
580+
for (; borrow != 0; --quo) {
581+
d = den_dig;
582+
d_norm = 0;
583+
mpz_dbl_dig_t carry = 0;
584+
for (mpz_dig_t *n = num_dig - den_len; n < num_dig; ++n, ++d) {
585+
d_norm = ((mpz_dbl_dig_t)*d << norm_shift) | (d_norm >> DIG_SIZE);
586+
carry += (mpz_dbl_dig_t)*n + (d_norm & DIG_MASK);
587+
*n = carry & DIG_MASK;
588+
carry >>= DIG_SIZE;
616589
}
590+
borrow -= carry;
617591
}
618592

619593
// store this digit of the quotient

py/mpz.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,18 +55,22 @@
5555
#endif
5656

5757
#if MPZ_DIG_SIZE > 16
58+
#define MPZ_DBL_DIG_SIZE (64)
5859
typedef uint32_t mpz_dig_t;
5960
typedef uint64_t mpz_dbl_dig_t;
6061
typedef int64_t mpz_dbl_dig_signed_t;
6162
#elif MPZ_DIG_SIZE > 8
63+
#define MPZ_DBL_DIG_SIZE (32)
6264
typedef uint16_t mpz_dig_t;
6365
typedef uint32_t mpz_dbl_dig_t;
6466
typedef int32_t mpz_dbl_dig_signed_t;
6567
#elif MPZ_DIG_SIZE > 4
68+
#define MPZ_DBL_DIG_SIZE (16)
6669
typedef uint8_t mpz_dig_t;
6770
typedef uint16_t mpz_dbl_dig_t;
6871
typedef int16_t mpz_dbl_dig_signed_t;
6972
#else
73+
#define MPZ_DBL_DIG_SIZE (8)
7074
typedef uint8_t mpz_dig_t;
7175
typedef uint8_t mpz_dbl_dig_t;
7276
typedef int8_t mpz_dbl_dig_signed_t;

0 commit comments

Comments
 (0)
0