29
29
30
30
#include "py/mpz.h"
31
31
32
- // this is only needed for mp_not_implemented, which should eventually be removed
33
- #include "py/runtime.h"
34
-
35
32
#if MICROPY_LONGINT_IMPL == MICROPY_LONGINT_IMPL_MPZ
36
33
37
34
#define DIG_SIZE (MPZ_DIG_SIZE)
@@ -199,6 +196,14 @@ STATIC mp_uint_t mpn_sub(mpz_dig_t *idig, const mpz_dig_t *jdig, mp_uint_t jlen,
199
196
return idig + 1 - oidig ;
200
197
}
201
198
199
+ STATIC mp_uint_t mpn_remove_trailing_zeros (mpz_dig_t * oidig , mpz_dig_t * idig ) {
200
+ for (-- idig ; idig >= oidig && * idig == 0 ; -- idig ) {
201
+ }
202
+ return idig + 1 - oidig ;
203
+ }
204
+
205
+ #if MICROPY_OPT_MPZ_BITWISE
206
+
202
207
/* computes i = j & k
203
208
returns number of digits in i
204
209
assumes enough memory in i; assumes normalised j, k; assumes jlen >= klen (jlen argument not needed)
@@ -211,41 +216,46 @@ STATIC mp_uint_t mpn_and(mpz_dig_t *idig, const mpz_dig_t *jdig, const mpz_dig_t
211
216
* idig = * jdig & * kdig ;
212
217
}
213
218
214
- // remove trailing zeros
215
- for (-- idig ; idig >= oidig && * idig == 0 ; -- idig ) {
216
- }
217
-
218
- return idig + 1 - oidig ;
219
+ return mpn_remove_trailing_zeros (oidig , idig );
219
220
}
220
221
221
- /* computes i = j & -k = j & (~k + 1)
222
+ #endif
223
+
224
+ /* i = -((-j) & (-k)) = ~((~j + 1) & (~k + 1)) + 1
225
+ i = (j & (-k)) = (j & (~k + 1)) = ( j & (~k + 1))
226
+ i = ((-j) & k) = ((~j + 1) & k) = ((~j + 1) & k )
227
+ computes general form:
228
+ i = (im ^ (((j ^ jm) + jc) & ((k ^ km) + kc))) + ic where Xm = Xc == 0 ? 0 : DIG_MASK
222
229
returns number of digits in i
223
- assumes enough memory in i; assumes normalised j, k
230
+ assumes enough memory in i; assumes normalised j, k; assumes length j >= length k
224
231
can have i, j, k pointing to same memory
225
232
*/
226
- STATIC mp_uint_t mpn_and_neg (mpz_dig_t * idig , const mpz_dig_t * jdig , mp_uint_t jlen , const mpz_dig_t * kdig , mp_uint_t klen ) {
233
+ STATIC mp_uint_t mpn_and_neg (mpz_dig_t * idig , const mpz_dig_t * jdig , mp_uint_t jlen , const mpz_dig_t * kdig , mp_uint_t klen ,
234
+ mpz_dbl_dig_t carryi , mpz_dbl_dig_t carryj , mpz_dbl_dig_t carryk ) {
227
235
mpz_dig_t * oidig = idig ;
228
- mpz_dbl_dig_t carry = 1 ;
236
+ mpz_dig_t imask = (0 == carryi ) ? 0 : DIG_MASK ;
237
+ mpz_dig_t jmask = (0 == carryj ) ? 0 : DIG_MASK ;
238
+ mpz_dig_t kmask = (0 == carryk ) ? 0 : DIG_MASK ;
229
239
230
- for (; jlen > 0 && klen > 0 ; -- jlen , -- klen , ++ idig , ++ jdig , ++ kdig ) {
231
- carry += * kdig ^ DIG_MASK ;
232
- * idig = (* jdig & carry ) & DIG_MASK ;
233
- carry >>= DIG_SIZE ;
240
+ for (; jlen > 0 ; ++ idig , ++ jdig ) {
241
+ carryj += * jdig ^ jmask ;
242
+ carryk += (-- klen <= -- jlen ) ? (* kdig ++ ^ kmask ) : kmask ;
243
+ carryi += ((carryj & carryk ) ^ imask ) & DIG_MASK ;
244
+ * idig = carryi & DIG_MASK ;
245
+ carryk >>= DIG_SIZE ;
246
+ carryj >>= DIG_SIZE ;
247
+ carryi >>= DIG_SIZE ;
234
248
}
235
249
236
- for (; jlen > 0 ; -- jlen , ++ idig , ++ jdig ) {
237
- carry += DIG_MASK ;
238
- * idig = (* jdig & carry ) & DIG_MASK ;
239
- carry >>= DIG_SIZE ;
240
- }
241
-
242
- // remove trailing zeros
243
- for (-- idig ; idig >= oidig && * idig == 0 ; -- idig ) {
250
+ if (0 != carryi ) {
251
+ * idig ++ = carryi ;
244
252
}
245
253
246
- return idig + 1 - oidig ;
254
+ return mpn_remove_trailing_zeros ( oidig , idig ) ;
247
255
}
248
256
257
+ #if MICROPY_OPT_MPZ_BITWISE
258
+
249
259
/* computes i = j | k
250
260
returns number of digits in i
251
261
assumes enough memory in i; assumes normalised j, k; assumes jlen >= klen
@@ -267,6 +277,74 @@ STATIC mp_uint_t mpn_or(mpz_dig_t *idig, const mpz_dig_t *jdig, mp_uint_t jlen,
267
277
return idig - oidig ;
268
278
}
269
279
280
+ #endif
281
+
282
+ /* i = -((-j) | (-k)) = ~((~j + 1) | (~k + 1)) + 1
283
+ i = -(j | (-k)) = -(j | (~k + 1)) = ~( j | (~k + 1)) + 1
284
+ i = -((-j) | k) = -((~j + 1) | k) = ~((~j + 1) | k ) + 1
285
+ computes general form:
286
+ i = ~(((j ^ jm) + jc) | ((k ^ km) + kc)) + 1 where Xm = Xc == 0 ? 0 : DIG_MASK
287
+ returns number of digits in i
288
+ assumes enough memory in i; assumes normalised j, k; assumes length j >= length k
289
+ can have i, j, k pointing to same memory
290
+ */
291
+
292
+ #if MICROPY_OPT_MPZ_BITWISE
293
+
294
+ STATIC mp_uint_t mpn_or_neg (mpz_dig_t * idig , const mpz_dig_t * jdig , mp_uint_t jlen , const mpz_dig_t * kdig , mp_uint_t klen ,
295
+ mpz_dbl_dig_t carryj , mpz_dbl_dig_t carryk ) {
296
+ mpz_dig_t * oidig = idig ;
297
+ mpz_dbl_dig_t carryi = 1 ;
298
+ mpz_dig_t jmask = (0 == carryj ) ? 0 : DIG_MASK ;
299
+ mpz_dig_t kmask = (0 == carryk ) ? 0 : DIG_MASK ;
300
+
301
+ for (; jlen > 0 ; ++ idig , ++ jdig ) {
302
+ carryj += * jdig ^ jmask ;
303
+ carryk += (-- klen <= -- jlen ) ? (* kdig ++ ^ kmask ) : kmask ;
304
+ carryi += ((carryj | carryk ) ^ DIG_MASK ) & DIG_MASK ;
305
+ * idig = carryi & DIG_MASK ;
306
+ carryk >>= DIG_SIZE ;
307
+ carryj >>= DIG_SIZE ;
308
+ carryi >>= DIG_SIZE ;
309
+ }
310
+
311
+ if (0 != carryi ) {
312
+ * idig ++ = carryi ;
313
+ }
314
+
315
+ return mpn_remove_trailing_zeros (oidig , idig );
316
+ }
317
+
318
+ #else
319
+
320
+ STATIC mp_uint_t mpn_or_neg (mpz_dig_t * idig , const mpz_dig_t * jdig , mp_uint_t jlen , const mpz_dig_t * kdig , mp_uint_t klen ,
321
+ mpz_dbl_dig_t carryi , mpz_dbl_dig_t carryj , mpz_dbl_dig_t carryk ) {
322
+ mpz_dig_t * oidig = idig ;
323
+ mpz_dig_t imask = (0 == carryi ) ? 0 : DIG_MASK ;
324
+ mpz_dig_t jmask = (0 == carryj ) ? 0 : DIG_MASK ;
325
+ mpz_dig_t kmask = (0 == carryk ) ? 0 : DIG_MASK ;
326
+
327
+ for (; jlen > 0 ; ++ idig , ++ jdig ) {
328
+ carryj += * jdig ^ jmask ;
329
+ carryk += (-- klen <= -- jlen ) ? (* kdig ++ ^ kmask ) : kmask ;
330
+ carryi += ((carryj | carryk ) ^ imask ) & DIG_MASK ;
331
+ * idig = carryi & DIG_MASK ;
332
+ carryk >>= DIG_SIZE ;
333
+ carryj >>= DIG_SIZE ;
334
+ carryi >>= DIG_SIZE ;
335
+ }
336
+
337
+ if (0 != carryi ) {
338
+ * idig ++ = carryi ;
339
+ }
340
+
341
+ return mpn_remove_trailing_zeros (oidig , idig );
342
+ }
343
+
344
+ #endif
345
+
346
+ #if MICROPY_OPT_MPZ_BITWISE
347
+
270
348
/* computes i = j ^ k
271
349
returns number of digits in i
272
350
assumes enough memory in i; assumes normalised j, k; assumes jlen >= klen
@@ -285,11 +363,39 @@ STATIC mp_uint_t mpn_xor(mpz_dig_t *idig, const mpz_dig_t *jdig, mp_uint_t jlen,
285
363
* idig = * jdig ;
286
364
}
287
365
288
- // remove trailing zeros
289
- for (-- idig ; idig >= oidig && * idig == 0 ; -- idig ) {
366
+ return mpn_remove_trailing_zeros (oidig , idig );
367
+ }
368
+
369
+ #endif
370
+
371
+ /* i = (-j) ^ (-k) = ~(j - 1) ^ ~(k - 1) = (j - 1) ^ (k - 1)
372
+ i = -(j ^ (-k)) = -(j ^ ~(k - 1)) = ~(j ^ ~(k - 1)) + 1 = (j ^ (k - 1)) + 1
373
+ i = -((-j) ^ k) = -(~(j - 1) ^ k) = ~(~(j - 1) ^ k) + 1 = ((j - 1) ^ k) + 1
374
+ computes general form:
375
+ i = ((j - 1 + jc) ^ (k - 1 + kc)) + ic
376
+ returns number of digits in i
377
+ assumes enough memory in i; assumes normalised j, k; assumes length j >= length k
378
+ can have i, j, k pointing to same memory
379
+ */
380
+ STATIC mp_uint_t mpn_xor_neg (mpz_dig_t * idig , const mpz_dig_t * jdig , mp_uint_t jlen , const mpz_dig_t * kdig , mp_uint_t klen ,
381
+ mpz_dbl_dig_t carryi , mpz_dbl_dig_t carryj , mpz_dbl_dig_t carryk ) {
382
+ mpz_dig_t * oidig = idig ;
383
+
384
+ for (; jlen > 0 ; ++ idig , ++ jdig ) {
385
+ carryj += * jdig + DIG_MASK ;
386
+ carryk += (-- klen <= -- jlen ) ? (* kdig ++ + DIG_MASK ) : DIG_MASK ;
387
+ carryi += (carryj ^ carryk ) & DIG_MASK ;
388
+ * idig = carryi & DIG_MASK ;
389
+ carryk >>= DIG_SIZE ;
390
+ carryj >>= DIG_SIZE ;
391
+ carryi >>= DIG_SIZE ;
290
392
}
291
393
292
- return idig + 1 - oidig ;
394
+ if (0 != carryi ) {
395
+ * idig ++ = carryi ;
396
+ }
397
+
398
+ return mpn_remove_trailing_zeros (oidig , idig );
293
399
}
294
400
295
401
/* computes i = i * d1 + d2
@@ -1097,81 +1203,106 @@ void mpz_sub_inpl(mpz_t *dest, const mpz_t *lhs, const mpz_t *rhs) {
1097
1203
can have dest, lhs, rhs the same
1098
1204
*/
1099
1205
void mpz_and_inpl (mpz_t * dest , const mpz_t * lhs , const mpz_t * rhs ) {
1100
- if (lhs -> neg == rhs -> neg ) {
1101
- if (lhs -> neg == 0 ) {
1102
- // make sure lhs has the most digits
1103
- if (lhs -> len < rhs -> len ) {
1104
- const mpz_t * temp = lhs ;
1105
- lhs = rhs ;
1106
- rhs = temp ;
1107
- }
1108
- // do the and'ing
1109
- mpz_need_dig (dest , rhs -> len );
1110
- dest -> len = mpn_and (dest -> dig , lhs -> dig , rhs -> dig , rhs -> len );
1111
- dest -> neg = 0 ;
1112
- } else {
1113
- // TODO both args are negative
1114
- mp_not_implemented ("bignum and with negative args" );
1115
- }
1206
+ // make sure lhs has the most digits
1207
+ if (lhs -> len < rhs -> len ) {
1208
+ const mpz_t * temp = lhs ;
1209
+ lhs = rhs ;
1210
+ rhs = temp ;
1211
+ }
1212
+
1213
+ #if MICROPY_OPT_MPZ_BITWISE
1214
+
1215
+ if ((0 == lhs -> neg ) && (0 == rhs -> neg )) {
1216
+ mpz_need_dig (dest , lhs -> len );
1217
+ dest -> len = mpn_and (dest -> dig , lhs -> dig , rhs -> dig , rhs -> len );
1218
+ dest -> neg = 0 ;
1116
1219
} else {
1117
- // args have different sign
1118
- // make sure lhs is the positive arg
1119
- if (rhs -> neg == 0 ) {
1120
- const mpz_t * temp = lhs ;
1121
- lhs = rhs ;
1122
- rhs = temp ;
1123
- }
1124
1220
mpz_need_dig (dest , lhs -> len + 1 );
1125
- dest -> len = mpn_and_neg (dest -> dig , lhs<
D96B
/span>-> dig , lhs -> len , rhs -> dig , rhs -> len );
1126
- assert ( dest -> len <= dest -> alloc );
1127
- dest -> neg = 0 ;
1221
+ dest -> len = mpn_and_neg (dest -> dig , lhs -> dig , lhs -> len , rhs -> dig , rhs -> len ,
1222
+ lhs -> neg == rhs -> neg , 0 != lhs -> neg , 0 != rhs -> neg );
1223
+ dest -> neg = lhs -> neg & rhs -> neg ;
1128
1224
}
1225
+
1226
+ #else
1227
+
1228
+ mpz_need_dig (dest , lhs -> len + (lhs -> neg || rhs -> neg ));
1229
+ dest -> len = mpn_and_neg (dest -> dig , lhs -> dig , lhs -> len , rhs -> dig , rhs -> len ,
1230
+ (lhs -> neg == rhs -> neg ) ? lhs -> neg : 0 , lhs -> neg , rhs -> neg );
1231
+ dest -> neg = lhs -> neg & rhs -> neg ;
1232
+
1233
+ #endif
1129
1234
}
1130
1235
1131
1236
/* computes dest = lhs | rhs
1132
1237
can have dest, lhs, rhs the same
1133
1238
*/
1134
1239
void mpz_or_inpl (mpz_t * dest , const mpz_t * lhs , const mpz_t * rhs ) {
1135
- if (mpn_cmp (lhs -> dig , lhs -> len , rhs -> dig , rhs -> len ) < 0 ) {
1240
+ // make sure lhs has the most digits
1241
+ if (lhs -> len < rhs -> len ) {
1136
1242
const mpz_t * temp = lhs ;
1137
1243
lhs = rhs ;
1138
1244
rhs = temp ;
1139
1245
}
1140
1246
1141
- if (lhs -> neg == rhs -> neg ) {
1247
+ #if MICROPY_OPT_MPZ_BITWISE
1248
+
1249
+ if ((0 == lhs -> neg ) && (0 == rhs -> neg )) {
1142
1250
mpz_need_dig (dest , lhs -> len );
1143
1251
dest -> len = mpn_or (dest -> dig , lhs -> dig , lhs -> len , rhs -> dig , rhs -> len );
1252
+ dest -> neg = 0 ;
1144
1253
} else {
1145
- mpz_need_dig (dest , lhs -> len );
1146
- // TODO
1147
- mp_not_implemented ( "bignum or with negative args" );
1148
- // dest->len = mpn_or_neg(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len) ;
1254
+ mpz_need_dig (dest , lhs -> len + 1 );
1255
+ dest -> len = mpn_or_neg ( dest -> dig , lhs -> dig , lhs -> len , rhs -> dig , rhs -> len ,
1256
+ 0 != lhs -> neg , 0 != rhs -> neg );
1257
+ dest -> neg = 1 ;
1149
1258
}
1150
1259
1151
- dest -> neg = lhs -> neg ;
1260
+ #else
1261
+
1262
+ mpz_need_dig (dest , lhs -> len + (lhs -> neg || rhs -> neg ));
1263
+ dest -> len = mpn_or_neg (dest -> dig , lhs -> dig , lhs -> len , rhs -> dig , rhs -> len ,
1264
+ (lhs -> neg || rhs -> neg ), lhs -> neg , rhs -> neg );
1265
+ dest -> neg = lhs -> neg | rhs -> neg ;
1266
+
1267
+ #endif
1152
1268
}
1153
1269
1154
1270
/* computes dest = lhs ^ rhs
1155
1271
can have dest, lhs, rhs the same
1156
1272
*/
1157
1273
void mpz_xor_inpl (mpz_t * dest , const mpz_t * lhs , const mpz_t * rhs ) {
1158
- if (mpn_cmp (lhs -> dig , lhs -> len , rhs -> dig , rhs -> len ) < 0 ) {
1274
+ // make sure lhs has the most digits
1275
+ if (lhs -> len < rhs -> len ) {
1159
1276
const mpz_t * temp = lhs ;
1160
1277
lhs = rhs ;
1161
1278
rhs = temp ;
1162
1279
}
1163
1280
1281
+ #if MICROPY_OPT_MPZ_BITWISE
1282
+
1164
1283
if (lhs -> neg == rhs -> neg ) {
1165
1284
mpz_need_dig (dest , lhs -> len );
1166
- dest -> len = mpn_xor (dest -> dig , lhs -> dig , lhs -> len , rhs -> dig , rhs -> len );
1285
+ if (lhs -> neg == 0 ) {
1286
+ dest -> len = mpn_xor (dest -> dig , lhs -> dig , lhs -> len , rhs -> dig , rhs -> len );
1287
+ } else {
1288
+ dest -> len = mpn_xor_neg (dest -> dig , lhs -> dig , lhs -> len , rhs -> dig , rhs -> len , 0 , 0 , 0 );
1289
+ }
1290
+ dest -> neg = 0 ;
1167
1291
} else {
1168
- mpz_need_dig (dest , lhs -> len );
1169
- // TODO
1170
- mp_not_implemented ( "bignum xor with negative args" );
1171
- // dest->len = mpn_xor_neg(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len) ;
1292
+ mpz_need_dig (dest , lhs -> len + 1 );
1293
+ dest -> len = mpn_xor_neg ( dest -> dig , lhs -> dig , lhs -> len , rhs -> dig , rhs -> len , 1 ,
1294
+ 0 == lhs -> neg , 0 == rhs -> neg );
1295
+ dest -> neg = 1 ;
1172
1296
}
1173
1297
1174
- dest -> neg = 0 ;
1298
+ #else
1299
+
1300
+ mpz_need_dig (dest , lhs -> len + (lhs -> neg || rhs -> neg ));
1301
+ dest -> len = mpn_xor_neg (dest -> dig , lhs -> dig , lhs -> len , rhs -> dig , rhs -> len ,
1302
+ (lhs -> neg != rhs -> neg ), 0 == lhs -> neg , 0 == rhs -> neg );
1303
+ dest -> neg = lhs -> neg ^ rhs -> neg ;
1304
+
1305
+ #endif
1175
1306
}
1176
1307
1177
1308
/* computes dest = lhs * rhs
0 commit comments