📄 binaryinteger.java
字号:
/**
 * Modulo operation: returns the residue of this value modulo {@code m}.
 * <p>
 * The result is {@code this.remainder(m)}, except that a negative
 * remainder is lifted by adding {@code m} once. This makes the result
 * non-negative as long as the remainder's magnitude is below {@code m}.
 *
 * @param m
 *            the modulus; must be strictly positive.
 * @return <tt>this mod m</tt>
 * @throws ArithmeticException
 *             if <tt>m <= 0</tt>
 * @see #remainder
 */
public BinaryInteger mod(BinaryInteger m) {
    if (m.signum > 0) {
        BinaryInteger rem = this.remainder(m);
        // remainder() can be negative; shift it up by one modulus.
        return (rem.signum < 0) ? rem.add(m) : rem;
    }
    throw new ArithmeticException("BinaryInteger: modulus not positive");
}
/**
 * Modular exponentiation: returns <tt>this<sup>exponent</sup> mod m</tt>.
 * <p>
 * A negative exponent is handled by exponentiating with its absolute value
 * and then taking the modular inverse of that result. An odd modulus is
 * handled directly by {@code oddModPow}. An even modulus is split into an
 * odd part and a power of two. The residue modulo each factor is computed
 * separately, and the two are recombined with the Chinese Remainder
 * Theorem.
 *
 * @param exponent
 *            the exponent; may be negative.
 * @param m
 *            the modulus; must be strictly positive.
 * @return <tt>this<sup>exponent</sup> mod m</tt>
 * @throws ArithmeticException
 *             if <tt>m <= 0</tt>. Presumably also thrown (via
 *             {@code modInverse}) when the exponent is negative and no
 *             inverse exists.
 * @see #modInverse
 */
public BinaryInteger modPow(BinaryInteger exponent, BinaryInteger m) {
    if (m.signum <= 0) {
        throw new ArithmeticException("BinaryInteger: modulus not positive");
    }

    // Shortcuts that need no exponentiation at all.
    if (exponent.signum == 0 || this.equals(ONE)) {
        return m.equals(ONE) ? ZERO : ONE; // x^0 and 1^e
    }
    if (this.equals(ZERO) && exponent.signum >= 0) {
        return ZERO; // 0^e with e > 0
    }
    if (this.equals(negConst[1]) && !exponent.testBit(0)) {
        return m.equals(ONE) ? ZERO : ONE; // presumably (-1)^even
    }

    boolean invertResult = exponent.signum < 0;
    BinaryInteger absExponent = invertResult ? exponent.negate() : exponent;

    // Reduce the base into [0, m) only when it is not already there.
    BinaryInteger base = (this.signum >= 0 && this.compareTo(m) < 0) ? this
            : this.mod(m);

    BinaryInteger result;
    if (!m.testBit(0)) {
        // Even modulus: m = oddPart * 2^p. Exponentiate modulo each factor
        // and recombine with the Chinese Remainder Theorem.
        int p = m.getLowestSetBit(); // largest p such that 2^p divides m
        BinaryInteger oddPart = m.shiftRight(p); // m / 2^p
        BinaryInteger twoPower = ONE.shiftLeft(p); // 2^p

        BinaryInteger oddBase = (this.signum >= 0 && this.compareTo(oddPart) < 0) ? this
                : this.mod(oddPart);
        // (this ^ e) mod oddPart; trivially 0 when oddPart is 1.
        BinaryInteger residueOdd = oddPart.equals(ONE) ? ZERO
                : oddBase.oddModPow(absExponent, oddPart);
        // (this ^ e) mod 2^p; base is congruent to this since 2^p divides m.
        BinaryInteger residueTwo = base.modPow2(absExponent, p);

        BinaryInteger invOfTwoPower = twoPower.modInverse(oddPart);
        BinaryInteger invOfOddPart = oddPart.modInverse(twoPower);
        BinaryInteger oddTerm = residueOdd.multiply(twoPower).multiply(invOfTwoPower);
        BinaryInteger twoTerm = residueTwo.multiply(oddPart).multiply(invOfOddPart);
        result = oddTerm.add(twoTerm).mod(m);
    } else {
        // Odd modulus: Montgomery-based exponentiation applies directly.
        result = base.oddModPow(absExponent, m);
    }

    return invertResult ? result.modInverse(m) : result;
}
/**
 * Exponent bit-length thresholds used by {@code oddModPow} to choose its
 * window size. {@code wbits} becomes the index of the first entry that is
 * not smaller than the exponent's bit length; the exponent 65537 is the
 * exception and always uses {@code wbits = 0}. The trailing
 * {@code Integer.MAX_VALUE} is a sentinel that guarantees the search stops.
 */
static int[] bnExpModThreshTable = new int[] {
    7, 25, 81, 241, 673, 1793,
    Integer.MAX_VALUE // sentinel
};
/**
* Returns a BinaryInteger whose value is <tt>this<sup>y</sup> mod z</tt>,
* computed by sliding-window exponentiation on Montgomery representations.
* <p>
* Assumes: {@code z} is odd and {@code this < z}. Only the magnitude of this
* value is read, and the result is always created with a positive sign, so
* the receiver is expected to be non-negative. (modPow passes an
* already-reduced, non-negative base.) {@code y} is expected to be positive:
* modPow returns early for a zero exponent and negates a negative one before
* calling.
*
* @param y the exponent
* @param z the odd modulus
* @return <tt>this<sup>y</sup> mod z</tt>
*/
private BinaryInteger oddModPow(BinaryInteger y, BinaryInteger z) {
/*
* The algorithm is adapted from Colin Plumb's C library.
*
* The window algorithm: the idea is to keep a running product of
* b1 = n^(high-order bits of exp) and then keep appending exponent bits to
* it. The following patterns apply to a 3-bit window (k = 3):
*   To append   0: square
*   To append   1: square, multiply by n^1
*   To append  10: square, multiply by n^1, square
*   To append  11: square, square, multiply by n^3
*   To append 100: square, multiply by n^1, square, square
*   To append 101: square, square, square, multiply by n^5
*   To append 110: square, square, multiply by n^3, square
*   To append 111: square, square, square, multiply by n^7
*
* Since each pattern involves only one multiply, the longer the pattern
* the better, except that a 0 (no multiplies) can be appended directly.
* We precompute a table of odd powers of n, up to n^(2^k - 1), and can
* then multiply k bits of exponent at a time. Actually, assuming random
* exponents, there is on average one zero bit between needs to multiply
* (1/2 of the time there's none, 1/4 of the time there's 1, 1/8 of the
* time there's 2, 1/16 of the time there's 3, etc.), so you have to do
* one multiply per k+1 bits of exponent.
*
* In the code below k = wbits + 1. The table has 2^wbits entries, and
* table[i] holds n^(2i+1) in Montgomery form.
*
* The loop walks down the exponent, squaring the result buffer as it
* goes. There is a wbits+1 bit lookahead buffer, buf, that is filled
* with the upcoming exponent bits. (What is read after the end of the
* exponent is unimportant, but it is filled with zero here.) When the
* most-significant bit of this buffer becomes set, i.e. (buf & tblmask) !=
* 0, we have to decide what pattern to multiply by, and when to do it.
* We decide, remember to do it in future after a suitable number of
* squarings have passed (e.g. a pattern of "100" in the buffer requires
* that we multiply by n^1 immediately; a pattern of "110" calls for
* multiplying by n^3 after one more squaring), clear the buffer, and
* continue.
*
* When we start, there is one more optimization: the result buffer is
* implicitly one, so squaring it or multiplying by it can be optimized
* away. Further, if we start with a pattern like "100" in the lookahead
* window, rather than placing n into the buffer and then starting to
* square it, we have already computed n^2 to compute the odd-powers
* table, so we can place that into the buffer and save a squaring.
*
* This means that if you have a k-bit window, to compute n^z, where z
* is the high k bits of the exponent, 1/2 of the time it requires no
* squarings. 1/4 of the time, it requires 1 squaring, ... 1/2^(k-1) of
* the time, it requires k-2 squarings. And the remaining 1/2^(k-1) of
* the time, the top k bits are a 1 followed by k-1 0 bits, so it again
* only requires k-2 squarings, not k-1. The average of these is 1. Add
* that to the one squaring we have to do to compute the table, and
* you'll see that a k-bit window saves k-2 squarings as well as
* reducing the multiplies. (It actually doesn't hurt in the case k = 1,
* either.)
*/
// Special case for exponent of one
if (y.equals(ONE))
return this;
// Special case for base of zero
if (signum == 0)
return ZERO;
// Copy of this value's magnitude (its sign is never consulted). exp and
// mod alias the magnitude arrays of y and z. All word arrays here are
// big-endian.
int[] base = (int[]) mag.clone();
int[] exp = y.mag;
int[] mod = z.mag;
int modLen = mod.length;
// Select an appropriate window size: wbits grows until ebits no longer
// exceeds bnExpModThreshTable[wbits]. The lookahead window is wbits + 1
// bits wide.
int wbits = 0;
int ebits = bitLength(exp, exp.length);
// if exponent is 65537 (0x10001), use minimum window size.
// A 17-bit exponent fits in one word, so exp[0] is its whole value
// (assuming mag carries no leading zero words).
if ((ebits != 17) || (exp[0] != 65537)) {
while (ebits > bnExpModThreshTable[wbits]) {
wbits++;
}
}
// Calculate appropriate table size: 2^wbits odd powers. tblmask is also
// the bit that signals a full lookahead buffer in the main loop.
int tblmask = 1 << wbits;
// Allocate table for precomputed odd powers of base in Montgomery form.
// (Each array allocated here is later replaced by a freshly computed one.)
int[][] table = new int[tblmask][];
for (int i = 0; i < tblmask; i++)
table[i] = new int[modLen];
// Compute the Montgomery constant. Presumably this is the negated inverse
// of the modulus' least-significant word modulo 2^32, which montReduce
// expects as its inv argument.
int inv = -MutableBigInteger.inverseMod32(mod[modLen - 1]);
// Convert base to Montgomery form. First shift left by 32 * modLen bits
// (presumably base * R, with R = 2^(32*modLen)). Then divide by the
// modulus; divide() is expected to leave the remainder in r.
int[] a = leftShift(base, base.length, modLen << 5);
MutableBigInteger q = new MutableBigInteger(), r = new MutableBigInteger(), a2 = new MutableBigInteger(
a), b2 = new MutableBigInteger(mod);
a2.divide(b2, q, r);
table[0] = r.toIntArray();
// Pad table[0] with leading zeros so its length is at least modLen
if (table[0].length < modLen) {
int offset = modLen - table[0].length;
int[] t2 = new int[modLen];
for (int i = 0; i < table[0].length; i++)
t2[i + offset] = table[0][i];
table[0] = t2;
}
// Set b to the square of the base (square, then Montgomery-reduce, so b
// stays in Montgomery form)
int[] b = squareToLen(table[0], modLen, null);
b = montReduce(b, mod, modLen, inv);
// Set t to high half of b: montReduce leaves its result in the first
// modLen words of the array it returns.
int[] t = new int[modLen];
for (int i = 0; i < modLen; i++)
t[i] = b[i];
// Fill in the table with odd powers of the base:
// table[i] = table[i - 1] * base^2, so table[i] holds base^(2i+1).
for (int i = 1; i < tblmask; i++) {
int[] prod = multiplyToLen(t, modLen, table[i - 1], modLen, null);
table[i] = montReduce(prod, mod, modLen, inv);
}
// Pre load the window that slides over the exponent. bitpos starts at
// bit (ebits - 1) mod 32 of exp[0], i.e. the exponent's top bit. wbits + 1
// bits are shifted into buf, moving to the next word whenever bitpos runs
// past bit 0. elen counts exponent words not yet exhausted.
int bitpos = 1 << ((ebits - 1) & (32 - 1));
int buf = 0;
int elen = exp.length;
int eIndex = 0;
for (int i = 0; i <= wbits; i++) {
buf = (buf << 1) | (((exp[eIndex] & bitpos) != 0) ? 1 : 0);
bitpos >>>= 1;
if (bitpos == 0) {
eIndex++;
bitpos = 1 << (32 - 1);
elen--;
}
}
// (This initial value is overwritten below before it is ever read.)
int multpos = ebits;
// The first iteration, which is hoisted out of the main loop.
// From here on ebits is the position of the exponent bit being processed,
// counting down. isone records that the accumulator b is still implicitly
// 1: squarings are skipped, and the first multiply becomes a plain copy.
ebits--;
boolean isone = true;
// Strip trailing zeros so buf becomes odd. Then mult = table[buf >>> 1]
// is base^buf, and multpos is the bit position where that multiply is due.
multpos = ebits - wbits;
while ((buf & 1) == 0) {
buf >>>= 1;
multpos++;
}
int[] mult = table[buf >>> 1];
buf = 0;
// When the window is a 1 followed only by zeros, the multiply by base^1
// falls on the top bit. b already holds base^2 from the table setup, so
// it is used as the accumulator directly and a squaring is saved (see the
// notes above).
if (multpos == ebits)
isone = false;
// The main loop: one exponent bit per iteration
while (true) {
ebits--;
// Advance the window; once the exponent is exhausted, zeros shift in
buf <<= 1;
if (elen != 0) {
buf |= ((exp[eIndex] & bitpos) != 0) ? 1 : 0;
bitpos >>>= 1;
if (bitpos == 0) {
eIndex++;
bitpos = 1 << (32 - 1);
elen--;
}
}
// Examine the window for pending multiplies. When the window's top bit
// (tblmask) is set, schedule the multiply as in the first iteration.
if ((buf & tblmask) != 0) {
multpos = ebits - wbits;
while ((buf & 1) == 0) {
buf >>>= 1;
multpos++;
}
mult = table[buf >>> 1];
buf = 0;
}
// Perform multiply. After the a/b swap, b holds the new accumulator and
// the old accumulator array becomes the scratch buffer a.
if (ebits == multpos) {
if (isone) {
b = (int[]) mult.clone();
isone = false;
} else {
t = b;
a = multiplyToLen(t, modLen, mult, modLen, a);
a = montReduce(a, mod, modLen, inv);
t = a;
a = b;
b = t;
}
}
// Check if done
if (ebits == 0)
break;
// Square the input (skipped while the accumulator is still implicitly
// 1); it uses the same a/b buffer swap as the multiply
if (!isone) {
t = b;
a = squareToLen(t, modLen, a);
a = montReduce(a, mod, modLen, inv);
t = a;
a = b;
b = t;
}
}
// Convert result out of Montgomery form and return. b goes into the low
// half of a 2 * modLen array, and one Montgomery reduction divides it by
// R. The plain residue then sits in the first modLen words.
int[] t2 = new int[2 * modLen];
for (int i = 0; i < modLen; i++)
t2[i + modLen] = b[i];
b = montReduce(t2, mod, modLen, inv);
t2 = new int[modLen];
for (int i = 0; i < modLen; i++)
t2[i] = b[i];
return new BinaryInteger(1, t2);
}
/**
 * Montgomery reduction: reduces {@code n} modulo {@code mod} and divides it
 * by 2<sup>32*mlen</sup>. Adapted from Colin Plumb's C library.
 * <p>
 * All arrays are big-endian word arrays. {@code n} is modified in place and
 * returned, and the reduced value ends up in its first {@code mlen} words.
 * {@code n} is presumably 2*mlen words wide; oddModPow's final conversion
 * passes exactly that. Confirm this for any new caller.
 *
 * @param n    the value to reduce; overwritten
 * @param mod  the odd modulus, {@code mlen} words long
 * @param mlen the modulus length in words
 * @param inv  the Montgomery constant computed in oddModPow
 *             ({@code -MutableBigInteger.inverseMod32} of mod's last word)
 * @return {@code n}, holding the reduced value in its first mlen words
 */
private static int[] montReduce(int[] n, int[] mod, int mlen, int inv) {
    int overflow = 0; // sum of the carries reported by addOne
    int offset = 0;
    // One pass per modulus word. Add a multiple of mod chosen from the
    // current low word, then propagate the resulting carry upward.
    do {
        int lowWord = n[n.length - 1 - offset];
        int carry = mulAdd(n, mod, offset, mlen, inv * lowWord);
        overflow += addOne(n, offset, mlen, carry);
    } while (++offset < mlen);

    // subN returns its borrow (0 or -1), so this loop consumes the overflow.
    while (overflow > 0) {
        overflow += subN(n, mod, mlen);
    }
    // Bring the leading mlen words below the modulus.
    while (intArrayCmpToLen(n, mod, mlen) >= 0) {
        subN(n, mod, mlen);
    }
    return n;
}
/*
 * Compares the first len words of two big-endian unsigned int arrays.
 * Returns -1, 0 or +1 as arg1 is less than, equal to, or greater than arg2
 * over that prefix. Each word is widened through LONG_MASK (presumably
 * 0xffffffffL), so words compare as unsigned 32-bit values.
 */
private static int intArrayCmpToLen(int[] arg1, int[] arg2, int len) {
    int i = 0;
    while (i < len) {
        long left = arg1[i] & LONG_MASK;
        long right = arg2[i] & LONG_MASK;
        if (left != right) {
            return (left < right) ? -1 : 1;
        }
        i++;
    }
    return 0;
}
/**
 * Subtracts {@code b} from {@code a} in place. The first {@code len} words
 * of each array are treated as big-endian unsigned magnitudes of the same
 * length.
 *
 * @return the final borrow: 0 if nothing was borrowed out of the top word,
 *         -1 otherwise
 */
private static int subN(int[] a, int[] b, int len) {
    long diff = 0;
    for (int i = len - 1; i >= 0; i--) {
        long borrow = diff >> 32; // 0 or -1 carried from the lower word
        diff = (a[i] & LONG_MASK) - (b[i] & LONG_MASK) + borrow;
        a[i] = (int) diff;
    }
    return (int) (diff >> 32);
}
/**
 * Multiplies the {@code len}-word big-endian array {@code in} by the
 * unsigned word {@code k} and adds the product into {@code out} in place.
 * <p>
 * {@code in[len - 1]} is aligned with {@code out[out.length - 1 - offset]}.
 * In other words, {@code offset} counts words up from the least-significant
 * end of {@code out}.
 *
 * @return the carry out of the most significant word written
 */
static int mulAdd(int[] out, int[] in, int offset, int len, int k) {
    final long multiplier = k & LONG_MASK;
    long carry = 0;
    int pos = out.length - 1 - offset;
    for (int j = len - 1; j >= 0; j--) {
        long product = (in[j] & LONG_MASK) * multiplier
                + (out[pos] & LONG_MASK) + carry;
        out[pos] = (int) product;
        carry = product >>> 32;
        pos--;
    }
    return (int) carry;
}
/**
 * Adds the unsigned word {@code carry} into {@code a} at the word that lies
 * {@code mlen} positions above {@code offset}. Both are counted from the
 * least-significant end of the big-endian array. Any resulting carry is
 * then rippled into at most {@code mlen} more-significant words.
 *
 * @return 0 if the carry was absorbed, or 1 if it ran off the front of the
 *         array or was still pending after mlen further words
 */
static int addOne(int[] a, int offset, int mlen, int carry) {
    int idx = a.length - 1 - mlen - offset;
    long sum = (a[idx] & LONG_MASK) + (carry & LONG_MASK);
    a[idx] = (int) sum;
    if ((sum >>> 32) == 0) {
        return 0;
    }
    for (int remaining = mlen; remaining > 0; remaining--) {
        idx--;
        if (idx < 0) {
            return 1; // carry out of the number
        }
        a[idx]++;
        if (a[idx] != 0) {
            return 0;
        }
    }
    return 1;
}
/**
 * Returns a BinaryInteger whose value is
 * <tt>this<sup>exponent</sup> mod 2<sup>p</sup></tt>. It uses
 * right-to-left square-and-multiply and truncates every intermediate value
 * to its low {@code p} bits via {@code mod2}.
 * <p>
 * For an odd base, only the lowest {@code p - 1} exponent bits are
 * examined. An odd b satisfies b<sup>2<sup>p-1</sup></sup> = 1 modulo
 * 2<sup>p</sup>, so higher exponent bits cannot change the result.
 * <p>
 * This method inherits {@code mod2}'s assumptions: this value is
 * non-negative and {@code p > 0}. {@code exponent} is expected to be
 * non-negative, because modPow negates negative exponents before calling.
 *
 * @param exponent the non-negative exponent
 * @param p the exponent of the power-of-two modulus 2<sup>p</sup>
 * @return <tt>this<sup>exponent</sup> mod 2<sup>p</sup></tt>
 */
private BinaryInteger modPow2(BinaryInteger exponent, int p) {
    BinaryInteger acc = valueOf(1);
    BinaryInteger power = this.mod2(p); // this^(2^bit) mod 2^p
    int bits = exponent.bitLength();
    if (this.testBit(0) && p - 1 < bits) {
        bits = p - 1;
    }
    for (int bit = 0; bit < bits; bit++) {
        if (exponent.testBit(bit)) {
            acc = acc.multiply(power).mod2(p);
        }
        // The square after the last examined bit would never be used.
        if (bit + 1 < bits) {
            power = power.square().mod2(p);
        }
    }
    return acc;
}
/**
* Returns a BinaryInteger whose value is this mod(2**p). Assumes that this
* BinaryInteger >= 0 and p > 0.
*/
private BinaryInteger mod2(int p) {
if (bitLength() <= p)
return this;
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -