📄 integer.cpp
字号:
PositiveSubtract(sum, *this, b);
}
else
{
if (b.NotNegative())
PositiveSubtract(sum, b, *this);
else
{
PositiveAdd(sum, *this, b);
sum.sign = Integer::NEGATIVE;
}
}
return sum;
}
// In-place addition (*this += t).
// The register is first grown to hold at least t.reg.size words; the helper
// that is invoked afterwards is selected from the signs of both operands.
Integer& Integer::operator+=(const Integer& t)
{
    reg.CleanGrow(t.reg.size);

    const bool selfNonNegative = NotNegative();
    const bool otherNonNegative = t.NotNegative();

    if (selfNonNegative && otherNonNegative)
        PositiveAdd(*this, *this, t);           // (+) + (+)
    else if (selfNonNegative)
        PositiveSubtract(*this, *this, t);      // (+) + (-)
    else if (otherNonNegative)
        PositiveSubtract(*this, t, *this);      // (-) + (+): operands swapped
    else
    {
        PositiveAdd(*this, *this, t);           // (-) + (-)
        sign = Integer::NEGATIVE;               // result is explicitly marked negative
    }
    return *this;
}
// Returns *this - b as a new Integer.
// The result is pre-sized to the larger of the two operand registers and the
// helper call is chosen from the signs of both operands.
Integer Integer::Minus(const Integer& b) const
{
    Integer difference((word)0, STDMAX(reg.size, b.reg.size));

    const bool selfNonNegative = NotNegative();
    const bool otherNonNegative = b.NotNegative();

    if (selfNonNegative && otherNonNegative)
        PositiveSubtract(difference, *this, b);     // (+) - (+)
    else if (selfNonNegative)
        PositiveAdd(difference, *this, b);          // (+) - (-)
    else if (otherNonNegative)
    {
        PositiveAdd(difference, *this, b);          // (-) - (+)
        difference.sign = Integer::NEGATIVE;
    }
    else
        PositiveSubtract(difference, b, *this);     // (-) - (-): operands swapped

    return difference;
}
// In-place subtraction (*this -= t).
// Grows the register to at least t.reg.size words, then dispatches on the
// signs of both operands.
Integer& Integer::operator-=(const Integer& t)
{
    reg.CleanGrow(t.reg.size);

    const bool selfNonNegative = NotNegative();
    const bool otherNonNegative = t.NotNegative();

    if (selfNonNegative && otherNonNegative)
        PositiveSubtract(*this, *this, t);      // (+) - (+)
    else if (selfNonNegative)
        PositiveAdd(*this, *this, t);           // (+) - (-)
    else if (otherNonNegative)
    {
        PositiveAdd(*this, *this, t);           // (-) - (+)
        sign = Integer::NEGATIVE;
    }
    else
        PositiveSubtract(*this, t, *this);      // (-) - (-): operands swapped

    return *this;
}
// In-place left shift by n bits.
// The shift is split into a whole-word part (n / WORD_BITS) and a sub-word
// part (n % WORD_BITS); the register is grown beforehand to make room.
Integer& Integer::operator<<=(unsigned int n)
{
    const unsigned int usedWords = WordCount();
    const unsigned int wholeWords = n / WORD_BITS;
    const unsigned int leftoverBits = n % WORD_BITS;

    reg.CleanGrow(RoundupSize(usedWords + bitsToWords(n)));

    // Move by whole words first, then shift the moved region by the remaining bits.
    ShiftWordsLeftByWords(reg, usedWords + wholeWords, wholeWords);
    ShiftWordsLeftByBits(reg + wholeWords, usedWords + bitsToWords(leftoverBits), leftoverBits);
    return *this;
}
// In-place right shift by n bits.
// Shifts by whole words first and then, if any words remain, by the leftover
// bit count. A negative value whose word count drops to zero is reset to Zero().
Integer& Integer::operator>>=(unsigned int n)
{
    const unsigned int usedWords = WordCount();
    const unsigned int wholeWords = n / WORD_BITS;
    const unsigned int leftoverBits = n % WORD_BITS;

    ShiftWordsRightByWords(reg, usedWords, wholeWords);
    if (usedWords > wholeWords)
        ShiftWordsRightByBits(reg, usedWords - wholeWords, leftoverBits);

    const bool isNegativeZero = IsNegative() && WordCount() == 0;
    if (isNegativeZero)     // avoid -0
        *this = Zero();
    return *this;
}
// Multiplies a and b without looking at their signs; product is always marked POSITIVE.
// product.reg is reallocated here, and a scratch block of aWords + bWords words
// is handed to AsymmetricMultiply.
void PositiveMultiply(Integer &product, const Integer &a, const Integer &b)
{
    const unsigned aWords = RoundupSize(a.WordCount());
    const unsigned bWords = RoundupSize(b.WordCount());

    product.reg.CleanNew(RoundupSize(aWords + bWords));
    product.sign = Integer::POSITIVE;

    SecWordBlock scratch(aWords + bWords);
    AsymmetricMultiply(product.reg, scratch, a.reg, aWords, b.reg, bWords);
}
// Signed multiplication: computes the product via PositiveMultiply, then
// negates it when exactly one of the operands is negative.
void Multiply(Integer &product, const Integer &a, const Integer &b)
{
    PositiveMultiply(product, a, b);

    const bool signsDiffer = (a.NotNegative() != b.NotNegative());
    if (signsDiffer)
        product.Negate();
}
// Returns *this * b as a new Integer (thin wrapper around Multiply).
Integer Integer::Times(const Integer &b) const
{
    Integer result;
    Multiply(result, *this, b);
    return result;
}
/*
void PositiveDivide(Integer &remainder, Integer &quotient,
const Integer &dividend, const Integer &divisor)
{
remainder.reg.CleanNew(divisor.reg.size);
remainder.sign = Integer::POSITIVE;
quotient.reg.New(0);
quotient.sign = Integer::POSITIVE;
unsigned i=dividend.BitCount();
while (i--)
{
word overflow = ShiftWordsLeftByBits(remainder.reg, remainder.reg.size, 1);
remainder.reg[0] |= dividend[i];
if (overflow || remainder >= divisor)
{
Subtract(remainder.reg, remainder.reg, divisor.reg, remainder.reg.size);
quotient.SetBit(i);
}
}
}
*/
void PositiveDivide(Integer &remainder, Integer "ient,
const Integer &a, const Integer &b)
{
unsigned aSize = a.WordCount();
unsigned bSize = b.WordCount();
if (!bSize)
throw Integer::DivideByZero();
if (a.PositiveCompare(b) == -1)
{
remainder = a;
remainder.sign = Integer::POSITIVE;
quotient = Integer::Zero();
return;
}
aSize += aSize%2; // round up to next even number
bSize += bSize%2;
remainder.reg.CleanNew(RoundupSize(bSize));
remainder.sign = Integer::POSITIVE;
quotient.reg.CleanNew(RoundupSize(aSize-bSize+2));
quotient.sign = Integer::POSITIVE;
SecWordBlock T(aSize+2*bSize+4);
Divide(remainder.reg, quotient.reg, T, a.reg, aSize, b.reg, bSize);
}
void Integer::Divide(Integer &remainder, Integer "ient, const Integer ÷nd, const Integer &divisor)
{
PositiveDivide(remainder, quotient, dividend, divisor);
if (dividend.IsNegative())
{
quotient.Negate();
if (remainder.NotZero())
{
--quotient;
remainder = divisor.AbsoluteValue() - remainder;
}
}
if (divisor.IsNegative())
quotient.Negate();
}
// Divides a by 2^n, writing the quotient to q and the remainder to r.
//
// q is a shifted right by n bits. r is built from the low bitsToWords(n) words
// of a, with the top copied word masked down to n % WORD_BITS bits. r is marked
// POSITIVE; when a is negative and r is non-zero, q is decremented and r is
// replaced by Power2(n) - r.
void Integer::DivideByPowerOf2(Integer &r, Integer &q, const Integer &a, unsigned int n)
{
    q = a;
    q >>= n;

    const unsigned int wordCount = bitsToWords(n);
    if (wordCount <= a.WordCount())
    {
        r.reg.Resize(RoundupSize(wordCount));
        CopyWords(r.reg, a.reg, wordCount);
        SetWords(r.reg+wordCount, 0, r.reg.size-wordCount);
        if (n % WORD_BITS != 0)
        {
            // Build the modulus in the word type: a plain `1` is a signed int,
            // and shifting it by up to WORD_BITS-1 can reach or pass its sign bit.
            r.reg[wordCount-1] %= (word(1) << (n % WORD_BITS));
        }
    }
    else
    {
        // n covers more words than a uses: the remainder is all of a's words.
        r.reg.Resize(RoundupSize(a.WordCount()));
        CopyWords(r.reg, a.reg, r.reg.size);
    }
    r.sign = POSITIVE;

    if (a.IsNegative() && r.NotZero())
    {
        --q;
        r = Power2(n) - r;
    }
}
// Returns the quotient of *this divided by b; the remainder computed by
// Integer::Divide is discarded.
Integer Integer::DividedBy(const Integer &b) const
{
    Integer rem;
    Integer quot;
    Integer::Divide(rem, quot, *this, b);
    return quot;
}
// Returns the remainder of *this divided by b; the quotient computed by
// Integer::Divide is discarded.
Integer Integer::Modulo(const Integer &b) const
{
    Integer rem;
    Integer quot;
    Integer::Divide(rem, quot, *this, b);
    return rem;
}
void Integer::Divide(word &remainder, Integer "ient, const Integer ÷nd, word divisor)
{
if (!divisor)
throw Integer::DivideByZero();
assert(divisor);
if ((divisor & (divisor-1)) == 0) // divisor is a power of 2
{
quotient = dividend >> (BitPrecision(divisor)-1);
remainder = dividend.reg[0] & (divisor-1);
return;
}
unsigned int i = dividend.WordCount();
quotient.reg.CleanNew(RoundupSize(i));
remainder = 0;
while (i--)
{
quotient.reg[i] = word(MAKE_DWORD(dividend.reg[i], remainder) / divisor);
remainder = word(MAKE_DWORD(dividend.reg[i], remainder) % divisor);
}
if (dividend.NotNegative())
quotient.sign = POSITIVE;
else
{
quotient.sign = NEGATIVE;
if (remainder)
{
--quotient;
remainder = divisor - remainder;
}
}
}
// Returns the quotient of *this divided by the single word b; the remainder
// computed by Integer::Divide is discarded.
Integer Integer::DividedBy(word b) const
{
    word rem;
    Integer quot;
    Integer::Divide(rem, quot, *this, b);
    return quot;
}
// Returns *this modulo a single-word divisor.
// Throws Integer::DivideByZero when divisor is 0.
// For a negative *this with a non-zero raw remainder, divisor - remainder is returned.
word Integer::Modulo(word divisor) const
{
    if (!divisor)
        throw Integer::DivideByZero();
    assert(divisor);

    word remainder;
    const bool isPowerOfTwo = ((divisor & (divisor-1)) == 0);

    if (isPowerOfTwo)
    {
        // Only the low bits of the lowest word matter.
        remainder = reg[0] & (divisor-1);
    }
    else if (divisor <= 5)
    {
        // Small divisors: sum all words and reduce once.
        // NOTE(review): presumably relies on the word base being congruent
        // to 1 modulo 3 and 5 - confirm for the configured word size.
        dword total = 0;
        for (unsigned int idx = WordCount(); idx-- > 0; )
            total += reg[idx];
        remainder = word(total % divisor);
    }
    else
    {
        // General case: fold in one word at a time, most significant first.
        remainder = 0;
        for (unsigned int idx = WordCount(); idx-- > 0; )
            remainder = word(MAKE_DWORD(reg[idx], remainder) % divisor);
    }

    if (IsNegative() && remainder)
        remainder = divisor - remainder;
    return remainder;
}
// Flips the sign of *this, leaving a zero value untouched.
void Integer::Negate()
{
    if (!(*this))   // don't flip sign if *this==0
        return;
    sign = Sign(1-sign);
}
int Integer::PositiveCompare(const Integer& t) const
{
unsigned size = WordCount(), tSize = t.WordCount();
if (size == tSize)
return CryptoPP::Compare(reg, t.reg, size);
else
return size > tSize ? 1 : -1;
}
int Integer::Compare(const Integer& t) const
{
if (NotNegative())
{
if (t.NotNegative())
return PositiveCompare(t);
else
return 1;
}
else
{
if (t.NotNegative())
return -1;
else
return -PositiveCompare(t);
}
}
// Integer square root by iteration from an overestimate.
// Returns Zero() when *this is not positive.
Integer Integer::SquareRoot() const
{
    if (!IsPositive())
        return Zero();

    // overestimate square root
    Integer previous;
    Integer estimate = Power2((BitCount()+1)/2);
    assert(estimate*estimate >= *this);

    // Iterate while the estimate keeps decreasing; the last value before it
    // stops decreasing is returned.
    do
    {
        previous = estimate;
        estimate = (previous + *this/previous) >> 1;
    } while (estimate < previous);

    return previous;
}
bool Integer::IsSquare() const
{
Integer r = SquareRoot();
return *this == r.Squared();
}
// True when the value occupies exactly one word and that word equals 1
// (the sign field is not inspected).
bool Integer::IsUnit() const
{
    if (WordCount() != 1)
        return false;
    return reg[0] == 1;
}
// Returns *this when IsUnit() holds, otherwise Zero().
Integer Integer::MultiplicativeInverse() const
{
    if (IsUnit())
        return *this;
    return Zero();
}
// Returns (x * y) % m, computing the full product before reducing.
Integer a_times_b_mod_c(const Integer &x, const Integer& y, const Integer& m)
{
    const Integer product = x*y;
    return product%m;
}
// Returns x raised to e, evaluated by a ModularArithmetic ring built on m.
Integer a_exp_b_mod_c(const Integer &x, const Integer& e, const Integer& m)
{
    ModularArithmetic ring(m);
    return ring.Exponentiate(x, e);
}
// Returns the Gcd of a and b as computed by a temporary EuclideanDomainOf<Integer>.
Integer Integer::Gcd(const Integer &a, const Integer &b)
{
return EuclideanDomainOf<Integer>().Gcd(a, b);
}
// Computes an inverse of *this modulo m. m is asserted to be non-negative.
// Returns Zero() on the paths where no inverse is found.
Integer Integer::InverseMod(const Integer &m) const
{
assert(m.NotNegative());
// Operand is negative or not below m: reduce it with % first, then recurse.
if (IsNegative() || *this>=m)
return (*this%m).InverseMod(m);
if (m.IsEven())
{
if (!m || IsEven())
return Zero(); // no inverse
if (*this == One())
return One();
// Recurse with the roles of operand and modulus swapped, then derive the
// result from u; a zero u is passed through as "no inverse".
Integer u = m.InverseMod(*this);
return !u ? Zero() : (m*(*this-u)+1)/(*this);
}
// m is not even here: use the word-level AlmostInverse, then divide by 2^k
// modulo m, where k is the value AlmostInverse returned.
SecBlock<word> T(m.reg.size * 4);
Integer r((word)0, m.reg.size);
unsigned k = AlmostInverse(r.reg, T, reg, reg.size, m.reg, m.reg.size);
DivideByPower2Mod(r.reg, r.reg, k, m.reg, m.reg.size);
return r;
}
// Single-word variant of InverseMod.
// Alternately reduces g0 (starting at mod) and g1 (starting at *this % mod)
// against each other while accumulating the coefficients v0 and v1.
// Returns v1 when g1 reaches 1, mod - v0 when g0 reaches 1, and 0 when either
// remainder hits zero first.
word Integer::InverseMod(const word mod) const
{
word g0 = mod, g1 = *this % mod;
word v0 = 0, v1 = 1;
word y;
while (g1)
{
if (g1 == 1)
return v1;
// Reduce g0 by g1 and fold the quotient into v0.
y = g0 / g1;
g0 = g0 % g1;
v0 += y * v1;
if (!g0)
break;
if (g0 == 1)
return mod-v0;
// Reduce g1 by g0 and fold the quotient into v1.
y = g1 / g0;
g1 = g1 % g0;
v1 += y * v0;
}
return 0;
}
// ********************************************************
// Decoding constructor: reads a BER sequence from bt that holds an OID followed
// by the modulus. Calls BERDecodeError() when the OID is not ASN1::prime_field().
ModularArithmetic::ModularArithmetic(BufferedTransformation &bt)
{
BERSequenceDecoder seq(bt);
OID oid(seq);
if (oid != ASN1::prime_field())
BERDecodeError();
modulus.BERDecode(seq);
seq.MessageEnd();
// Size the result register to match the decoded modulus.
result.reg.Resize(modulus.reg.size);
}
// Writes this ring to bt as a DER sequence: the ASN1::prime_field() OID
// followed by the modulus.
void ModularArithmetic::DEREncode(BufferedTransformation &bt) const
{
DERSequenceEncoder seq(bt);
ASN1::prime_field().DEREncode(seq);
modulus.DEREncode(seq);
seq.MessageEnd();
}
// Encodes element a to out as an octet string of MaxElementByteLength() bytes.
void ModularArithmetic::DEREncodeElement(BufferedTransformation &out, const Element &a) const
{
a.DEREncodeAsOctetString(out, MaxElementByteLength());
}
// Decodes element a from in as an octet string, passing MaxElementByteLength()
// as the length argument.
void ModularArithmetic::BERDecodeElement(BufferedTransformation &in, Element &a) const
{
a.BERDecodeAsOctetString(in, MaxElementByteLength());
}
// Halves a modulo the modulus.
// When a's register has the same size as the modulus register, the word-level
// DivideByPower2Mod writes into `result`; otherwise the value is computed with
// Integer arithmetic into `result1` (adding the modulus first when a is odd).
const Integer& ModularArithmetic::Half(const Integer &a) const
{
    if (a.reg.size != modulus.reg.size)
        return result1 = (a.IsEven() ? (a >> 1) : ((a+modulus) >> 1));

    CryptoPP::DivideByPower2Mod(result.reg.ptr, a.reg, 1, modulus.reg, a.reg.size);
    return result;
}
// Returns a + b reduced by at most one subtraction of the modulus.
// Fast path (both registers sized like the modulus): word-level add into
// `result`. Otherwise: Integer arithmetic into `result1`.
const Integer& ModularArithmetic::Add(const Integer &a, const Integer &b) const
{
    const bool sameWidth = (a.reg.size==modulus.reg.size && b.reg.size==modulus.reg.size);
    if (!sameWidth)
    {
        result1 = a+b;
        if (result1 >= modulus)
            result1 -= modulus;
        return result1;
    }

    // Subtract the modulus when the add reported a non-zero return value or the
    // sum is not below the modulus.
    const bool needsReduction =
        CryptoPP::Add(result.reg.ptr, a.reg, b.reg, a.reg.size)
        || Compare(result.reg, modulus.reg, a.reg.size) >= 0;
    if (needsReduction)
        CryptoPP::Subtract(result.reg.ptr, result.reg, modulus.reg, a.reg.size);
    return result;
}
// In-place modular addition: a = a + b, followed by at most one subtraction
// of the modulus. Returns a.
Integer& ModularArithmetic::Accumulate(Integer &a, const Integer &b) const
{
    const bool sameWidth = (a.reg.size==modulus.reg.size && b.reg.size==modulus.reg.size);
    if (!sameWidth)
    {
        // Mixed register sizes: fall back to Integer arithmetic.
        a+=b;
        if (a>=modulus)
            a-=modulus;
        return a;
    }

    const bool needsReduction =
        CryptoPP::Add(a.reg, a.reg, b.reg, a.reg.size)
        || Compare(a.reg, modulus.reg, a.reg.size) >= 0;
    if (needsReduction)
        CryptoPP::Subtract(a.reg, a.reg, modulus.reg, a.reg.size);
    return a;
}
// Returns a - b, adding the modulus back once when the subtraction goes below zero.
// Fast path (both registers sized like the modulus) writes into `result`;
// the fallback uses Integer arithmetic and writes into `result1`.
const Integer& ModularArithmetic::Subtract(const Integer &a, const Integer &b) const
{
    const bool sameWidth = (a.reg.size==modulus.reg.size && b.reg.size==modulus.reg.size);
    if (!sameWidth)
    {
        result1 = a-b;
        if (result1.IsNegative())
            result1 += modulus;
        return result1;
    }

    // A non-zero return from the word-level subtract triggers the add-back.
    if (CryptoPP::Subtract(result.reg.ptr, a.reg, b.reg, a.reg.size))
        CryptoPP::Add(result.reg.ptr, result.reg, modulus.reg, a.reg.size);
    return result;
}
// In-place modular subtraction: a = a - b, adding the modulus back once when
// the subtraction goes below zero. Returns a.
Integer& ModularArithmetic::Reduce(Integer &a, const Integer &b) const
{
    const bool sameWidth = (a.reg.size==modulus.reg.size && b.reg.size==modulus.reg.size);
    if (!sameWidth)
    {
        // Mixed register sizes: fall back to Integer arithmetic.
        a-=b;
        if (a.IsNegative())
            a+=modulus;
        return a;
    }

    if (CryptoPP::Subtract(a.reg, a.reg, b.reg, a.reg.size))
        CryptoPP::Add(a.reg, a.reg, modulus.reg, a.reg.size);
    return a;
}
// Additive inverse of a. A zero a is returned unchanged; otherwise the modulus
// is copied into `result` and a's words are subtracted from it.
const Integer& ModularArithmetic::Inverse(const Integer &a) const
{
if (!a)
return a;
CopyWords(result.reg.ptr, modulus.reg, modulus.reg.size);
// Only a.reg.size words are subtracted; a non-zero return (presumably the
// borrow) is carried into the remaining upper words with Decrement.
if (CryptoPP::Subtract(result.reg.ptr, result.reg, a.reg, a.reg.size))
Decrement(result.reg.ptr+a.reg.size, 1, modulus.reg.size-a.reg.size);
return result;
}
// Cascade exponentiation of x by e1 and y by e2.
// An odd modulus is routed through a temporary MontgomeryRepresentation
// (convert in, exponentiate, convert out); any other modulus uses the
// AbstractRing<Integer> implementation.
Integer ModularArithmetic::CascadeExponentiate(const Integer &x, const Integer &e1, const Integer &y, const Integer &e2) const
{
    if (!modulus.IsOdd())
        return AbstractRing<Integer>::CascadeExponentiate(x, e1, y, e2);

    MontgomeryRepresentation mont(modulus);
    return mont.ConvertOut(mont.CascadeExponentiate(mont.ConvertIn(x), e1, mont.ConvertIn(y), e2));
}
// Raises base to each of the exponentsCount exponents, writing into results.
// An odd modulus is routed through a temporary MontgomeryRepresentation, with
// every result converted back afterwards; any other modulus uses the
// AbstractRing<Integer> implementation.
void ModularArithmetic::SimultaneousExponentiate(Integer *results, const Integer &base, const Integer *exponents, unsigned int exponentsCount) const
{
    if (!modulus.IsOdd())
    {
        AbstractRing<Integer>::SimultaneousExponentiate(results, base, exponents, exponentsCount);
        return;
    }

    MontgomeryRepresentation mont(modulus);
    mont.SimultaneousExponentiate(results, mont.ConvertIn(base), exponents, exponentsCount);
    for (unsigned int idx = 0; idx < exponentsCount; idx++)
        results[idx] = mont.ConvertOut(results[idx]);
}
// Builds the representation for modulus m: u is allocated with
// modulus.reg.size words and the workspace with 5*modulus.reg.size words,
// then u is filled in by RecursiveInverseModPower2 from the modulus words.
MontgomeryRepresentation::MontgomeryRepresentation(const Integer &m) // modulus must be odd
: ModularArithmetic(m),
u((word)0, modulus.reg.size),
workspace(5*modulus.reg.size)
{
assert(modulus.IsOdd());
RecursiveInverseModPower2(u.reg, workspace, modulus.reg, modulus.reg.size);
}
// Multiplies a and b and passes the product through MontgomeryReduce.
// Uses the shared workspace: the first 2*modWords words hold the product, the
// rest is handed to the callees as scratch. The answer lands in `result`.
const Integer& MontgomeryRepresentation::Multiply(const Integer &a, const Integer &b) const
{
    const unsigned int modWords = modulus.reg.size;
    word *const scratch = workspace.ptr;
    word *const out = result.reg.ptr;
    assert(a.reg.size<=modWords && b.reg.size<=modWords);

    AsymmetricMultiply(scratch, scratch+2*modWords, a.reg, a.reg.size, b.reg, b.reg.size);
    // Zero the part of the 2*modWords product area the multiply did not cover.
    SetWords(scratch+a.reg.size+b.reg.size, 0, 2*modWords-a.reg.size-b.reg.size);
    MontgomeryReduce(out, scratch+2*modWords, scratch, modulus.reg, u.reg, modWords);
    return result;
}
// Squares a and passes the square through MontgomeryReduce.
// Uses the shared workspace the same way as Multiply; the answer lands in `result`.
const Integer& MontgomeryRepresentation::Square(const Integer &a) const
{
    const unsigned int modWords = modulus.reg.size;
    word *const scratch = workspace.ptr;
    word *const out = result.reg.ptr;
    assert(a.reg.size<=modWords);

    RecursiveSquare(scratch, scratch+2*modWords, a.reg, a.reg.size);
    // Zero the part of the 2*modWords area above the 2*a.reg.size-word square.
    SetWords(scratch+2*a.reg.size, 0, 2*modWords-2*a.reg.size);
    MontgomeryReduce(out, scratch+2*modWords, scratch, modulus.reg, u.reg, modWords);
    return result;
}
// Converts a out of this representation: a is copied into a zero-padded
// 2*modWords-word buffer and run through MontgomeryReduce once.
// Returns a copy of `result`.
Integer MontgomeryRepresentation::ConvertOut(const Integer &a) const
{
    const unsigned int modWords = modulus.reg.size;
    word *const scratch = workspace.ptr;
    word *const out = result.reg.ptr;
    assert(a.reg.size<=modWords);

    CopyWords(scratch, a.reg, a.reg.size);
    SetWords(scratch+a.reg.size, 0, 2*modWords-a.reg.size);
    MontgomeryReduce(out, scratch+2*modWords, scratch, modulus.reg, u.reg, modWords);
    return result;
}
// Multiplicative inverse of a within this representation; the answer is left
// in `result`. Steps: one MontgomeryReduce of a, AlmostInverse on the reduced
// value, then a power-of-two correction based on the k that AlmostInverse returned.
const Integer& MontgomeryRepresentation::MultiplicativeInverse(const Integer &a) const
{
// return (EuclideanMultiplicativeInverse(a, modulus)<<(2*WORD_BITS*modulus.reg.size))%modulus;
word *const T = workspace.ptr;
word *const R = result.reg.ptr;
const unsigned int N = modulus.reg.size;
assert(a.reg.size<=N);
// Load a into the low words of T, zero-padded to 2*N words, and reduce into R.
CopyWords(T, a.reg, a.reg.size);
SetWords(T+a.reg.size, 0, 2*N-a.reg.size);
MontgomeryReduce(R, T+2*N, T, modulus.reg, u.reg, N);
// R is both input and output here; T is reused as scratch.
unsigned k = AlmostInverse(R, T, R, N, modulus.reg, N);
// cout << "k=" << k << " N*32=" << 32*N << endl;
// Correct by the difference between k and N*WORD_BITS: divide by the excess
// power of two, or multiply by the shortfall.
if (k>N*WORD_BITS)
DivideByPower2Mod(R, R, k-N*WORD_BITS, modulus.reg, N);
else
MultiplyByPower2Mod(R, R, N*WORD_BITS-k, modulus.reg, N);
return result;
}
NAMESPACE_END
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -