📄 mpi.c
字号:
/*
    b = a * 2

    Every digit of a is shifted left by one bit; the bit shifted out of the
    top of a digit (bit DIGIT_BIT - 1) is carried into the bottom of the next
    digit. A carry out of the last digit becomes a new leading digit of 1.

    b is grown first when b->alloc < a->used + 1. Digits of b between the new
    used count and its previous used count are zeroed. b takes the sign of a.

    Returns the mp_grow error code if growing fails, otherwise MP_OKAY.
*/
int32 mp_mul_2 (mp_int * a, mp_int * b)
{
    mp_digit carry, next_carry, *src, *dst;
    int32 i, err, prev_used;

    /* the doubled value can need one more digit than the source */
    if (b->alloc < a->used + 1) {
        err = mp_grow (b, a->used + 1);
        if (err != MP_OKAY) {
            return err;
        }
    }

    prev_used = b->used;
    b->used = a->used;

    src = a->dp;
    dst = b->dp;
    carry = 0;

    for (i = 0; i < a->used; i++) {
        /* top bit of this digit is the carry into the next digit */
        next_carry = src[i] >> ((mp_digit)(DIGIT_BIT - 1));

        /* shift up and bring in the carry from the previous digit */
        dst[i] = ((src[i] << ((mp_digit)1)) | carry) & MP_MASK;

        carry = next_carry;
    }

    /* a remaining carry is a single bit, so the new top digit is 1 */
    if (carry != 0) {
        dst[i] = 1;
        ++(b->used);
    }

    /* wipe digits left over from the previous, longer value of b */
    for (i = b->used; i < prev_used; i++) {
        dst[i] = 0;
    }

    b->sign = a->sign;
    return MP_OKAY;
}
/******************************************************************************/
/*
    multiply by a digit: c = a * b

    c is grown first when c->alloc < a->used + 1. Each digit of a is
    multiplied by b in mp_word precision; the part under MP_MASK is stored
    and the part above DIGIT_BIT is carried into the next column. The final
    carry is stored as digit a->used, stale digits of c above that are
    zeroed, and mp_clamp() is applied to the result. c takes the sign of a.

    Returns the mp_grow error code if growing fails, otherwise MP_OKAY.
*/
int32 mp_mul_d(mp_int * a, mp_digit b, mp_int * c)
{
    mp_digit carry, *src, *dst;
    mp_word prod;
    int32 i, err, prev_used;

    /* room for every source digit plus one carry digit */
    if (c->alloc < a->used + 1) {
        err = mp_grow (c, a->used + 1);
        if (err != MP_OKAY) {
            return err;
        }
    }

    /* remember how many digits the destination held before */
    prev_used = c->used;

    c->sign = a->sign;

    src = a->dp;
    dst = c->dp;
    carry = 0;

    for (i = 0; i < a->used; i++) {
        /* column product plus the incoming carry */
        prod = ((mp_word) carry) + ((mp_word) src[i]) * ((mp_word) b);

        /* keep a single digit of it */
        dst[i] = (mp_digit) (prod & ((mp_word) MP_MASK));

        /* the rest moves on to the next column */
        carry = (mp_digit) (prod >> ((mp_word) DIGIT_BIT));
    }

    /* the last carry (possibly zero) becomes the top digit */
    dst[i] = carry;
    i++;

    /* zero whatever the old value of c had above the new top */
    for (; i < prev_used; i++) {
        dst[i] = 0;
    }

    c->used = a->used + 1;
    mp_clamp(c);
    return MP_OKAY;
}
/******************************************************************************/
/*
low level squaring, b = a*a, HAC pp.596-597, Algorithm 14.16
*/
#ifdef USE_SMALL_WORD
/*
    b = a * a, computed into a temporary of 2 * a->used + 1 digits that is
    then exchanged with b and released.

    For each digit i the square term a[i] * a[i] is added at position 2*i,
    then every cross product a[i] * a[j] with j > i is added twice starting
    at position 2*i + 1. Any carry left after the cross products is pushed
    further up the temporary.

    Returns the mp_init_size error code if the temporary cannot be set up,
    otherwise MP_OKAY.
*/
int32 s_mp_sqr (psPool_t *pool, mp_int * a, mp_int * b)
{
    mp_int tmp;
    mp_word acc;
    mp_digit carry, lhs, *out;
    int32 err, i, j, ndigits;

    ndigits = a->used;
    err = mp_init_size(pool, &tmp, 2*ndigits + 1);
    if (err != MP_OKAY) {
        return err;
    }

    /* start from the largest possible size; clamped at the end */
    tmp.used = 2*ndigits + 1;

    for (i = 0; i < ndigits; i++) {
        /* left hand side of every product in this pass */
        lhs = a->dp[i];

        /* square term lands on the even position 2*i */
        acc = ((mp_word) tmp.dp[2*i]) + ((mp_word) lhs) * ((mp_word) lhs);
        tmp.dp[2*i] = (mp_digit) (acc & ((mp_word) MP_MASK));
        carry = (mp_digit) (acc >> ((mp_word) DIGIT_BIT));

        /* cross products are stored from the next position upward */
        out = tmp.dp + (2*i + 1);

        for (j = i + 1; j < ndigits; j++) {
            acc = ((mp_word) lhs) * ((mp_word) a->dp[j]);

            /* doubled by adding the product twice rather than shifting */
            acc = ((mp_word) *out) + acc + acc + ((mp_word) carry);

            *out++ = (mp_digit) (acc & ((mp_word) MP_MASK));
            carry = (mp_digit) (acc >> ((mp_word) DIGIT_BIT));
        }

        /* push the remaining carry upward until it is absorbed */
        for (; carry != ((mp_digit) 0); out++) {
            acc = ((mp_word) *out) + ((mp_word) carry);
            *out = (mp_digit) (acc & ((mp_word) MP_MASK));
            carry = (mp_digit) (acc >> ((mp_word) DIGIT_BIT));
        }
    }

    mp_clamp (&tmp);
    mp_exch (&tmp, b);
    mp_clear (&tmp);
    return MP_OKAY;
}
#endif /* USE_SMALL_WORD */
/******************************************************************************/
/*
    fast squaring, b = a * a

    This is the comba method: the output is produced one column at a time.
    For each column the cross products a[i] * a[j] (i < j, i + j == column)
    are summed in an mp_word, the sum is doubled once, the carry from the
    previous column is added, and even columns additionally receive the
    square term a[column/2]^2. A further optimization is that the "* 2" of
    the cross products is applied once per column instead of once per
    product.

    Based on Algorithm 14.16 on pp.597 of HAC.
    This is the 1.0 version, but no SSE stuff

    The column buffer holds MP_WARRAY digits and nothing in this function
    checks 2 * a->used against that bound, so the caller is presumably
    expected to guarantee it -- verify at the call sites.

    pool is not used by this function.

    Returns the mp_grow error code if b cannot be grown to 2 * a->used
    digits, otherwise MP_OKAY.
*/
int32 fast_s_mp_sqr(psPool_t *pool, mp_int * a, mp_int * b)
{
    mp_digit cols[MP_WARRAY];
    mp_word carry;
    int32 prev_used, err, ncols, col, k;

    /* the square has twice as many digits as the input */
    ncols = a->used + a->used;
    if (b->alloc < ncols) {
        err = mp_grow(b, ncols);
        if (err != MP_OKAY) {
            return err;
        }
    }

    carry = 0;
    for (col = 0; col < ncols; col++) {
        mp_word acc;
        int32 lo, hi, count;

        acc = 0;

        /* highest and lowest digit index whose sum is this column */
        hi = MIN(a->used-1, col);
        lo = col - hi;

        /*
            number of (lo, hi) pairs available while lo moves up and hi
            moves down
        */
        count = MIN(a->used-lo, hi+1);

        /*
            only pairs with lo < hi are wanted; the indices close in on each
            other two at a time, so halve the distance (rounded up)
        */
        count = MIN(count, (hi-lo+1)>>1);

        for (k = 0; k < count; k++) {
            acc += ((mp_word)a->dp[lo + k])*((mp_word)a->dp[hi - k]);
        }

        /* double the cross products and add the incoming carry */
        acc = acc + acc + carry;

        /* even columns also contain a square term */
        if ((col&1) == 0) {
            acc += ((mp_word)a->dp[col>>1])*((mp_word)a->dp[col>>1]);
        }

        cols[col] = (mp_digit)(acc & MP_MASK);

        /* whatever is left above the digit carries into the next column */
        carry = acc >> ((mp_word)DIGIT_BIT);
    }

    /* copy the columns out to the destination */
    prev_used = b->used;
    b->used = a->used+a->used;

    for (col = 0; col < ncols; col++) {
        b->dp[col] = cols[col] & MP_MASK;
    }

    /* clear digits that belonged to the previous value of b */
    for (; col < prev_used; col++) {
        b->dp[col] = 0;
    }

    mp_clamp(b);
    return MP_OKAY;
}
/******************************************************************************/
/*
    computes a = 2**b

    Simple algorithm which zeroes the mp_int, grows it, then just sets one
    bit as required.

    b must not be negative. A negative b would produce a negative shift
    count (undefined behaviour in C) and, for b <= -DIGIT_BIT, a negative
    digit index, so it is rejected with MP_VAL before a is touched.

    Returns MP_VAL for negative b, the mp_grow error code if a cannot be
    grown, otherwise MP_OKAY.
*/
int32 mp_2expt (mp_int * a, int32 b)
{
    int32 res, digit;

    /*
        reject exponents that cannot be represented
    */
    if (b < 0) {
        return MP_VAL;
    }
    /*
        zero a as per default
    */
    mp_zero (a);
    /*
        index of the digit that will hold the bit
    */
    digit = b / DIGIT_BIT;
    /*
        grow a to accommodate the single bit
    */
    if ((res = mp_grow (a, digit + 1)) != MP_OKAY) {
        return res;
    }
    /*
        set the used count of where the bit will go
    */
    a->used = digit + 1;
    /*
        put the single bit in its place
    */
    a->dp[digit] = ((mp_digit)1) << (b % DIGIT_BIT);
    return MP_OKAY;
}
/******************************************************************************/
/*
    init an mp_int for a given size

    The requested size is padded with (MP_PREC * 2) - (size % MP_PREC)
    extra digits; for a non-negative size this makes the allocation a
    multiple of MP_PREC with more than MP_PREC spare digits. The digits are
    obtained with psMalloc() from pool and all set to zero; a starts with
    used == 0 and sign MP_ZPOS.

    Returns MP_MEM (leaving a->dp NULL) if the allocation fails, otherwise
    MP_OKAY.
*/
int32 mp_init_size(psPool_t *pool, mp_int * a, int32 size)
{
    mp_digit *digits;
    int idx;

    /* pad so there are always extra digits */
    size += (MP_PREC * 2) - (size % MP_PREC);

    digits = OPT_CAST(mp_digit) psMalloc(pool, sizeof (mp_digit) * size);
    a->dp = digits;
    if (digits == NULL) {
        return MP_MEM;
    }

    a->sign = MP_ZPOS;
    a->alloc = size;
    a->used = 0;

    /* clear every allocated digit */
    idx = size;
    while (idx-- > 0) {
        digits[idx] = 0;
    }
    return MP_OKAY;
}
/******************************************************************************/
/*
    low level addition, based on HAC pp.594, Algorithm 14.7

    c = a + b computed digit by digit over the used digits of both inputs.
    The sign fields are neither read nor written here.

    c is grown first when c->alloc < max(a->used, b->used) + 1. The carry
    out of the last column is stored as the top digit, stale digits of c
    above the new used count are zeroed, and mp_clamp() is applied.

    Returns the mp_grow error code if growing fails, otherwise MP_OKAY.
*/
int32 s_mp_add (mp_int * a, mp_int * b, mp_int * c)
{
    mp_int *longer;
    mp_digit carry, *pa, *pb, *pc;
    int32 prev_used, err, short_len, long_len, i;

    /* work out which input has more digits */
    if (a->used > b->used) {
        longer = a;
        long_len = a->used;
        short_len = b->used;
    } else {
        longer = b;
        long_len = b->used;
        short_len = a->used;
    }

    /* one digit more than the longer input, for the final carry */
    if (c->alloc < long_len + 1) {
        err = mp_grow (c, long_len + 1);
        if (err != MP_OKAY) {
            return err;
        }
    }

    prev_used = c->used;
    c->used = long_len + 1;

    pa = a->dp;
    pb = b->dp;
    pc = c->dp;
    carry = 0;

    /* columns where both inputs have a digit: T[i] = A[i] + B[i] + U */
    for (i = 0; i < short_len; i++) {
        pc[i] = pa[i] + pb[i] + carry;
        /* the bit above DIGIT_BIT is the carry, the rest is the digit */
        carry = pc[i] >> ((mp_digit)DIGIT_BIT);
        pc[i] &= MP_MASK;
    }

    /* remaining columns come from the longer input only: T[i] = X[i] + U */
    for (; i < long_len; i++) {
        pc[i] = longer->dp[i] + carry;
        carry = pc[i] >> ((mp_digit)DIGIT_BIT);
        pc[i] &= MP_MASK;
    }

    /* final carry becomes the top digit */
    pc[i] = carry;

    /* clear digits the old value of c had above the new used count */
    for (i = c->used; i < prev_used; i++) {
        pc[i] = 0;
    }

    mp_clamp (c);
    return MP_OKAY;
}
/******************************************************************************/
#ifdef USE_SMALL_WORD
/*
FUTURE - this is never needed, SLOW or not, because RSA exponents are
always odd.
*/
int32 mp_invmod(psPool_t *pool, mp_int * a, mp_int * b, mp_int * c)
{
mp_int x, y, u, v, A, B, C, D;
int32 res;
/*
b cannot be negative
*/
if (b->sign == MP_NEG || mp_iszero(b) == 1) {
return MP_VAL;
}
/*
if the modulus is odd we can use a faster routine instead
*/
if (mp_isodd (b) == 1) {
return fast_mp_invmod(pool, a, b, c);
}
/*
init temps
*/
if ((res = _mp_init_multi(pool, &x, &y, &u, &v,
&A, &B, &C, &D)) != MP_OKAY) {
return res;
}
/* x = a, y = b */
if ((res = mp_copy(a, &x)) != MP_OKAY) {
goto LBL_ERR;
}
if ((res = mp_copy(b, &y)) != MP_OKAY) {
goto LBL_ERR;
}
/*
2. [modified] if x,y are both even then return an error!
*/
if (mp_iseven(&x) == 1 && mp_iseven (&y) == 1) {
res = MP_VAL;
goto LBL_ERR;
}
/*
3. u=x, v=y, A=1, B=0, C=0,D=1
*/
if ((res = mp_copy(&x, &u)) != MP_OKAY) {
goto LBL_ERR;
}
if ((res = mp_copy(&y, &v)) != MP_OKAY) {
goto LBL_ERR;
}
mp_set (&A, 1);
mp_set (&D, 1);
top:
/*
4. while u is even do
*/
while (mp_iseven(&u) == 1) {
/* 4.1 u = u/2 */
if ((res = mp_div_2(&u, &u)) != MP_OKAY) {
goto LBL_ERR;
}
/* 4.2 if A or B is odd then */
if (mp_isodd (&A) == 1 || mp_isodd (&B) == 1) {
/* A = (A+y)/2, B = (B-x)/2 */
if ((res = mp_add(&A, &y, &A)) != MP_OKAY) {
goto LBL_ERR;
}
if ((res = mp_sub(&B, &x, &B)) != MP_OKAY) {
goto LBL_ERR;
}
}
/* A = A/2, B = B/2 */
if ((res = mp_div_2(&A, &A)) != MP_OKAY) {
goto LBL_ERR;
}
if ((res = mp_div_2(&B, &B)) != MP_OKAY) {
goto LBL_ERR;
}
}
/*
5. while v is even do
*/
while (mp_iseven(&v) == 1) {
/* 5.1 v = v/2 */
if ((res = mp_div_2(&v, &v)) != MP_OKAY) {
goto LBL_ERR;
}
/* 5.2 if C or D is odd then */
if (mp_isodd(&C) == 1 || mp_isodd (&D) == 1) {
/* C = (C+y)/2, D = (D-x)/2 */
if ((res = mp_add(&C, &y, &C)) != MP_OKAY) {
goto LBL_ERR;
}
if ((res = mp_sub(&D, &x, &D)) != MP_OKAY) {
goto LBL_ERR;
}
}
/* C = C/2, D = D/2 */
if ((res = mp_div_2(&C, &C)) != MP_OKAY) {
goto LBL_ERR;
}
if ((res = mp_div_2(&D, &D)) != MP_OKAY) {
goto LBL_ERR;
}
}
/*
6. if u >= v t
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -