📄 integer.cpp
字号:
else
return bool((reg[n/WORD_BITS] >> (n % WORD_BITS)) & 1);
}
void Integer::SetBit(unsigned int n, bool value)
{
if (value)
{
reg.CleanGrow(RoundupSize(bitsToWords(n+1)));
reg[n/WORD_BITS] |= (word(1) << (n%WORD_BITS));
}
else
{
if (n/WORD_BITS < reg.size)
reg[n/WORD_BITS] &= ~(word(1) << (n%WORD_BITS));
}
}
// Returns byte n of the magnitude, counting from the least significant
// byte of reg[0]. Bytes beyond the allocated register read as 0.
byte Integer::GetByte(unsigned int n) const
{
    if (n/WORD_SIZE < reg.size)
    {
        const unsigned int shift = (n % WORD_SIZE) * 8;
        return byte(reg[n/WORD_SIZE] >> shift);
    }
    return 0;
}
void Integer::SetByte(unsigned int n, byte value)
{
reg.CleanGrow(RoundupSize(bytesToWords(n+1)));
reg[n/WORD_SIZE] &= ~(word(0xff) << 8*(n%WORD_SIZE));
reg[n/WORD_SIZE] |= (word(value) << 8*(n%WORD_SIZE));
}
// Unary minus: returns a copy of *this on which Negate() has been called.
Integer Integer::operator-() const
{
    Integer negated(*this);
    negated.Negate();
    return negated;
}
// Returns a copy of *this whose sign is forced to POSITIVE.
Integer Integer::AbsoluteValue() const
{
    Integer magnitude(*this);
    magnitude.sign = POSITIVE;
    return magnitude;
}
// Exchanges both the sign and the word register of *this with those of a.
void Integer::swap(Integer &a)
{
    std::swap(sign, a.sign);
    reg.swap(a.reg);
}
// Constructs a non-negative Integer holding value in its lowest word.
// The register is sized with RoundupSize(length) and all higher words
// are set to zero.
Integer::Integer(word value, unsigned int length)
    : reg(RoundupSize(length)), sign(POSITIVE)
{
    // Zero everything above word 0, then store the value itself.
    SetWords(reg+1, 0, reg.size-1);
    reg[0] = value;
}
// Constructs an Integer from text.
//
// The radix is selected by the LAST character of the string:
// 'h'/'H' -> 16, 'o'/'O' -> 8, 'b'/'B' -> 2, anything else -> 10.
// Every character that is a valid digit for that radix is accumulated
// from left to right; every other character (including the suffix and
// any separators) is skipped. A leading '-' negates the result.
// An empty string yields zero.
Integer::Integer(const char *str)
    : reg(2), sign(POSITIVE)
{
    SetWords(reg, 0, 2);

    unsigned length = strlen(str);
    if (length == 0)
        return;

    // Pick the radix from the trailing suffix character.
    const char suffix = str[length-1];
    word radix = 10;
    if (suffix == 'h' || suffix == 'H')
        radix = 16;
    else if (suffix == 'o' || suffix == 'O')
        radix = 8;
    else if (suffix == 'b' || suffix == 'B')
        radix = 2;

    for (const char *p = str; p != str + length; ++p)
    {
        const char ch = *p;

        // Start from an out-of-range value so unrecognized characters are skipped.
        word digit = radix;
        if (ch >= '0' && ch <= '9')
            digit = ch - '0';
        else if (ch >= 'A' && ch <= 'F')
            digit = ch - 'A' + 10;
        else if (ch >= 'a' && ch <= 'f')
            digit = ch - 'a' + 10;

        if (digit >= radix)
            continue;

        *this *= radix;
        *this += digit;
    }

    if (str[0] == '-')
        Negate();
}
// Returns CountWords() applied to the whole register, i.e. the word count
// that ByteCount() and BitCount() use to locate the top significant word.
unsigned int Integer::WordCount() const
{
    const unsigned int usedWords = CountWords(reg, reg.size);
    return usedWords;
}
unsigned int Integer::ByteCount() const
{
unsigned wordCount = WordCount();
if (wordCount)
return (wordCount-1)*WORD_SIZE + BytePrecision(reg[wordCount-1]);
else
return 0;
}
unsigned int Integer::BitCount() const
{
unsigned wordCount = WordCount();
if (wordCount)
return (wordCount-1)*WORD_BITS + BitPrecision(reg[wordCount-1]);
else
return 0;
}
void Integer::Decode(const byte *input, unsigned int inputLen, Signedness s)
{
sign = ((s==SIGNED) && (input[0] & 0x80)) ? NEGATIVE : POSITIVE;
while (inputLen>0 && input[0]==0)
{
input++;
inputLen--;
}
reg.CleanNew(RoundupSize(bytesToWords(inputLen)));
for (unsigned i=0; i<inputLen; i++)
reg[i/WORD_SIZE] |= input[inputLen-1-i] << (i%WORD_SIZE)*8;
if (sign == NEGATIVE)
{
for (unsigned i=inputLen; i<reg.size*WORD_SIZE; i++)
reg[i/WORD_SIZE] |= 0xff << (i%WORD_SIZE)*8;
TwosComplement(reg, reg.size);
}
}
// Returns the smallest byte count that Encode() needs for this value.
// The result is never less than 1. For a signed encoding one extra byte is
// added when a non-negative value has its top bit set, or when a negative
// value is below -2^(8*size-1).
unsigned int Integer::MinEncodedSize(Signedness signedness) const
{
    unsigned int size = STDMAX(1U, ByteCount());

    if (signedness != UNSIGNED)
    {
        if (NotNegative() && (GetByte(size-1) & 0x80))
            size++;
        if (IsNegative() && *this < -Power2(size*8-1))
            size++;
    }

    return size;
}
// Writes exactly outputLen bytes to output, most significant byte first,
// and returns outputLen. Unsigned or non-negative values are emitted from
// the magnitude directly; a negative value in a signed encoding is first
// turned into its two's complement form.
unsigned int Integer::Encode(byte *output, unsigned int outputLen, Signedness signedness) const
{
    byte *out = output;

    if (signedness != UNSIGNED && !NotNegative())
    {
        // 2^(8*k) + *this is the two's complement of *this in k bytes.
        Integer temp = Integer::Power2(8*STDMAX(ByteCount(), outputLen)) + *this;
        for (unsigned remaining = outputLen; remaining > 0; --remaining)
            *out++ = temp.GetByte(remaining-1);
        return outputLen;
    }

    for (unsigned remaining = outputLen; remaining > 0; --remaining)
        *out++ = GetByte(remaining-1);
    return outputLen;
}
// Writes the INTEGER tag, the length written by DERLengthEncode, and the
// minimal signed content bytes to output. Returns the total number of bytes
// written. The caller must supply a buffer large enough for all of them.
unsigned int Integer::DEREncode(byte *output) const
{
    output[0] = INTEGER;

    unsigned int contentLen = MinEncodedSize(SIGNED);
    SecByteBlock content(contentLen);
    Encode(content, contentLen, SIGNED);

    // Header = one tag byte plus however many bytes the length took.
    unsigned int headerLen = 1;
    headerLen += DERLengthEncode(contentLen, output+headerLen);

    memcpy(output+headerLen, content, contentLen);
    return headerLen+contentLen;
}
// Stream variant: puts the INTEGER tag, the length written by
// DERLengthEncode, and the minimal signed content bytes into bt.
// Returns the total number of bytes emitted.
unsigned int Integer::DEREncode(BufferedTransformation &bt) const
{
    bt.Put(INTEGER);

    unsigned int contentLen = MinEncodedSize(SIGNED);
    SecByteBlock content(contentLen);
    Encode(content, contentLen, SIGNED);

    unsigned int total = 1;                       // the tag byte
    total += DERLengthEncode(contentLen, bt);     // the length bytes
    bt.Put(content, contentLen);
    total += contentLen;                          // the content bytes
    return total;
}
// Decodes a BER INTEGER from a raw byte buffer into *this.
//
// Expects the INTEGER tag, then a definite length in short form or in long
// form with one or two length bytes, then that many content bytes, which
// are decoded as a signed value. Any other tag or length form is reported
// through BERDecodeError().
void Integer::BERDecode(const byte *input)
{
    if (*input++ != INTEGER)
        BERDecodeError();

    int bc;
    if (!(*input & 0x80))
        bc = *input++;              // short form: the byte is the length
    else
    {
        int lengthBytes = *input++ & 0x7f;
        // 0 is the indefinite-length marker, which carries no length bytes;
        // treating the next byte as a length would misparse the content.
        if (lengthBytes == 0 || lengthBytes > 2)
            BERDecodeError();
        bc = *input++;
        if (lengthBytes > 1)
            bc = (bc << 8) | *input++;
    }
    Decode(input, bc, SIGNED);
}
// Stream variant: reads the INTEGER tag, a length via BERLengthDecode and
// then that many content bytes from bt, and decodes them as a signed value.
// A missing or wrong tag, or a short read of the content, is reported
// through BERDecodeError().
void Integer::BERDecode(BufferedTransformation &bt)
{
    byte tag;
    if (!bt.Get(tag))
        BERDecodeError();
    else if (tag != INTEGER)
        BERDecodeError();

    unsigned int contentLen;
    BERLengthDecode(bt, contentLen);

    SecByteBlock content(contentLen);
    if (bt.Get(content, contentLen) != contentLen)
        BERDecodeError();

    Decode(content, contentLen, SIGNED);
}
// Sets *this to a random non-negative value of at most nbits bits.
// nbits/8 + 1 bytes are drawn from rng, the leading byte is passed through
// Crop() with nbits % 8, and the buffer is decoded as an unsigned
// big-endian number.
void Integer::Randomize(RandomNumberGenerator &rng, unsigned int nbits)
{
    const unsigned int byteCount = nbits/8 + 1;
    const unsigned int topBits = nbits % 8;

    SecByteBlock randomBytes(byteCount);
    rng.GetBlock(randomBytes, byteCount);

    if (byteCount)
        randomBytes[0] = (byte)Crop(randomBytes[0], topBits);

    Decode(randomBytes, byteCount, UNSIGNED);
}
// Sets *this to a random value in [min, max]. Requires max >= min.
// Candidates with as many bits as (max - min) are drawn repeatedly until
// one does not exceed that width; min is then added to it.
void Integer::Randomize(RandomNumberGenerator &rng, const Integer &min, const Integer &max)
{
    assert(max >= min);

    Integer span = max - min;
    const unsigned int spanBits = span.BitCount();

    for (;;)
    {
        Randomize(rng, spanBits);
        if (!(*this > span))
            break;
    }

    *this += min;
}
// Sets *this to a random value in [min, max], constrained by rnType, equiv
// and mod. Requires 0 <= equiv < mod (asserted).
//
// ANY:   with mod == 1 this is a plain Randomize(rng, min, max); otherwise
//        the result is built as min1 + k*mod for a random k, where
//        min1 = min + (equiv-min)%mod.
// PRIME: random starting points are drawn from [min, max] and handed to
//        FirstPrime together with equiv and mod.
//
// Returns true when a value has been stored in *this. Returns false when
// the ANY range is empty (max < min1) or when, on the 16th pass of the
// PRIME loop, FirstPrime reports nothing in [min, max]. An unknown rnType
// asserts and returns false.
bool Integer::Randomize(RandomNumberGenerator &rng, const Integer &min, const Integer &max, RandomNumberType rnType, const Integer &equiv, const Integer &mod)
{
assert(!equiv.IsNegative() && equiv < mod);
switch (rnType)
{
case ANY:
if (mod == One())
Randomize(rng, min, max);
else
{
// min1 is the starting value the result is offset from; the result is
// min1 plus a random multiple of mod.
Integer min1 = min + (equiv-min)%mod;
if (max < min1)
return false;
// Draw the multiplier k in [0, (max - min1) / mod], then form k*mod + min1.
Randomize(rng, Zero(), (max - min1) / mod);
*this *= mod;
*this += min1;
}
return true;
case PRIME:
int i;
i = 0;
while (1)
{
// Runs once, on the 16th pass, before a new candidate is drawn.
// NOTE(review): FirstPrime presumably advances its first argument to a
// suitable prime within the given bound and reports whether one was
// found - confirm against its definition.
if (++i==16)
{
// check if there are any suitable primes in [min, max]
Integer first = min;
if (FirstPrime(first, max, equiv, mod))
{
// if there is only one suitable prime, we're done
*this = first;
if (!FirstPrime(first, max, equiv, mod))
return true;
}
else
return false;
}
// Draw a random starting point and search from it up to an upper bound of
// min(*this + mod*PrimeSearchInterval(max), max).
Randomize(rng, min, max);
if (FirstPrime(*this, STDMIN(*this+mod*PrimeSearchInterval(max), max), equiv, mod))
return true;
}
default:
assert(false);
return false;
}
}
// Reads an Integer from a stream.
//
// Leading whitespace is skipped, then characters are consumed for as long
// as they are '-', decimal or hex digits, the suffixes 'h'/'H'/'o'/'O', or
// the separators ',' and '.'. The first character that does not match is
// put back on the stream. The collected text is handed to the
// Integer(const char *) constructor, so an empty or exhausted stream
// yields zero.
std::istream& operator>>(std::istream& in, Integer &a)
{
    // Initialized so that a failed first read (empty or bad stream) stores a
    // defined value in the buffer instead of an indeterminate one.
    char c = 0;
    unsigned int length = 0;
    SecBlock<char> str(length + 16);

    std::ws(in);

    do
    {
        in.read(&c, 1);
        str[length++] = c;
        // Keep at least one free slot for the next character.
        if (length >= str.size)
            str.Grow(length + 16);
    }
    while (in && (c=='-' || (c>='0' && c<='9') || (c>='a' && c<='f') || (c>='A' && c<='F') || c=='h' || c=='H' || c=='o' || c=='O' || c==',' || c=='.'));

    // Only un-read the terminator if the last read actually produced one.
    if (in.gcount())
        in.putback(c);

    // The last stored character is the terminator (or the failed read);
    // replace it with the string terminator.
    str[length-1] = '\0';
    a = Integer(str);

    return in;
}
std::ostream& operator<<(std::ostream& out, const Integer &a)
{
// Get relevant conversion specifications from ostream.
long f = out.flags() & std::ios::basefield; // Get base digits.
int base, block;
char suffix;
switch(f)
{
case std::ios::oct :
base = 8;
block = 8;
suffix = 'o';
break;
case std::ios::hex :
base = 16;
block = 4;
suffix = 'h';
break;
default :
base = 10;
block = 3;
suffix = '.';
}
SecBlock<char> s(a.BitCount() / (BitPrecision(base)-1) + 1);
Integer temp1=a, temp2;
unsigned i=0;
const char vec[]="0123456789ABCDEF";
if (a.IsNegative())
{
out << '-';
temp1.Negate();
}
if (!a)
out << '0';
while (!!temp1)
{
s[i++]=vec[Integer::ShortDivide(temp2, temp1, base)];
temp1=temp2;
}
while (i--)
{
out << s[i];
if (i && !(i%block))
out << ",";
}
return out << suffix;
}
// Pre-increment. A non-negative value has its magnitude incremented, and
// the register is doubled if the increment carries out of the top word.
// A negative value has its magnitude decremented, and is reset to Zero()
// if that leaves no significant words.
Integer& Integer::operator++()
{
    if (!NotNegative())
    {
        word borrow = Decrement(reg, reg.size);
        assert(!borrow);
        if (WordCount()==0)
            *this = Zero();
        return *this;
    }

    if (Increment(reg, reg.size))
    {
        // Carry out of the old register: the new low half is all zero,
        // so the carry lands in the first word of the new upper half.
        reg.CleanGrow(2*reg.size);
        reg[reg.size/2]=1;
    }
    return *this;
}
// Pre-decrement. A negative value has its magnitude incremented, and the
// register is doubled if the increment carries out of the top word.
// A non-negative value has its magnitude decremented; if that borrows,
// the result is set to -One().
Integer& Integer::operator--()
{
    if (!IsNegative())
    {
        if (Decrement(reg, reg.size))
            *this = -One();
        return *this;
    }

    if (Increment(reg, reg.size))
    {
        // Carry out of the old register goes into the first word of the
        // newly added upper half.
        reg.CleanGrow(2*reg.size);
        reg[reg.size/2]=1;
    }
    return *this;
}
// Adds the magnitudes of a and b into sum and marks sum POSITIVE.
// The overlapping low words are added with Add(); the extra high words of
// the longer operand are copied and the carry is propagated through them.
// A final carry doubles sum's register. sum.reg must already be at least as
// large as the larger operand's register.
void PositiveAdd(Integer &sum, const Integer &a, const Integer& b)
{
    word carry;

    if (a.reg.size == b.reg.size)
        carry = Add(sum.reg, a.reg, b.reg, a.reg.size);
    else
    {
        const bool aIsLonger = a.reg.size > b.reg.size;
        const Integer &longer = aIsLonger ? a : b;
        const unsigned int shortLen = aIsLonger ? b.reg.size : a.reg.size;
        const unsigned int extraLen = longer.reg.size - shortLen;

        carry = Add(sum.reg, a.reg, b.reg, shortLen);
        CopyWords(sum.reg+shortLen, longer.reg+shortLen, extraLen);
        carry = Increment(sum.reg+shortLen, extraLen, carry);
    }

    if (carry)
    {
        // The carry becomes the first word of the newly added upper half.
        sum.reg.CleanGrow(2*sum.reg.size);
        sum.reg[sum.reg.size/2] = 1;
    }

    sum.sign = Integer::POSITIVE;
}
// Computes |a| - |b| into diff, setting diff's sign from the comparison of
// the two magnitudes. Word counts are rounded up to an even number before
// use. The smaller magnitude is always subtracted from the larger one, so
// the final borrow is expected to be zero.
void PositiveSubtract(Integer &diff, const Integer &a, const Integer& b)
{
    unsigned aSize = a.WordCount();
    aSize += aSize%2;
    unsigned bSize = b.WordCount();
    bSize += bSize%2;

    if (aSize == bSize)
    {
        // Same length: an explicit comparison decides the direction.
        if (Compare(a.reg, b.reg, aSize) < 0)
        {
            Subtract(diff.reg, b.reg, a.reg, aSize);
            diff.sign = Integer::NEGATIVE;
        }
        else
        {
            Subtract(diff.reg, a.reg, b.reg, aSize);
            diff.sign = Integer::POSITIVE;
        }
        return;
    }

    // Different lengths: the longer operand is the larger magnitude.
    const bool aIsLonger = aSize > bSize;
    const Integer &minuend = aIsLonger ? a : b;
    const Integer &subtrahend = aIsLonger ? b : a;
    const unsigned lowLen = aIsLonger ? bSize : aSize;
    const unsigned highLen = (aIsLonger ? aSize : bSize) - lowLen;

    word borrow = Subtract(diff.reg, minuend.reg, subtrahend.reg, lowLen);
    CopyWords(diff.reg+lowLen, minuend.reg+lowLen, highLen);
    borrow = Decrement(diff.reg+lowLen, highLen, borrow);
    assert(!borrow);

    if (aIsLonger)
        diff.sign = Integer::POSITIVE;
    else
        diff.sign = Integer::NEGATIVE;
}
// Signed addition. Operands with equal signs have their magnitudes added
// and keep that sign; operands with different signs are handled as a
// magnitude subtraction with the non-negative operand first.
Integer operator+(const Integer &a, const Integer& b)
{
    Integer sum((word)0, STDMAX(a.reg.size, b.reg.size));

    const bool aNonNeg = a.NotNegative();
    const bool bNonNeg = b.NotNegative();

    if (aNonNeg == bNonNeg)
    {
        PositiveAdd(sum, a, b);
        if (!aNonNeg)
            sum.sign = Integer::NEGATIVE;
    }
    else if (aNonNeg)
        PositiveSubtract(sum, a, b);
    else
        PositiveSubtract(sum, b, a);

    return sum;
}
// In-place signed addition. The register is first grown to at least t's
// size so the helpers can write the result into *this. Both signs are
// sampled before any helper runs, because the helpers overwrite this->sign.
Integer& Integer::operator+=(const Integer& t)
{
    reg.CleanGrow(t.reg.size);

    const bool selfNonNeg = NotNegative();
    const bool otherNonNeg = t.NotNegative();

    if (selfNonNeg == otherNonNeg)
    {
        PositiveAdd(*this, *this, t);
        if (!selfNonNeg)
            sign = Integer::NEGATIVE;
    }
    else if (selfNonNeg)
        PositiveSubtract(*this, *this, t);
    else
        PositiveSubtract(*this, t, *this);

    return *this;
}
// Signed subtraction. Operands with different signs have their magnitudes
// added and take a's sign; operands with equal signs are handled as a
// magnitude subtraction (a - b when both are non-negative, b - a when both
// are negative).
Integer operator-(const Integer &a, const Integer& b)
{
    Integer diff((word)0, STDMAX(a.reg.size, b.reg.size));

    const bool aNonNeg = a.NotNegative();
    const bool bNonNeg = b.NotNegative();

    if (aNonNeg != bNonNeg)
    {
        PositiveAdd(diff, a, b);
        if (!aNonNeg)
            diff.sign = Integer::NEGATIVE;
    }
    else if (aNonNeg)
        PositiveSubtract(diff, a, b);
    else
        PositiveSubtract(diff, b, a);

    return diff;
}
Integer& Integer::operator-=(const Integer& t)
{
reg.CleanGrow(t.reg.size);
if (NotNegative())
{
if (t.NotNegative())
PositiveSubtract(*this, *this, t);
else
PositiveAdd(*this, *this, t);
}
else
{
if (t.NotNegative())
{
PositiveAdd(*this, *this, t);
sign = Integer::NEGATIVE;
}
else
PositiveSubtract(*this, t, *this);
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -