📄 mpi.c
字号:
/* first calculate the digit at 2*ix calculate double precision result */ r = ((mp_word) t.dp[2*ix]) + ((mp_word)a->dp[ix])*((mp_word)a->dp[ix]);/* store lower part in result */ t.dp[ix+ix] = (mp_digit) (r & ((mp_word) MP_MASK));/* get the carry */ u = (mp_digit)(r >> ((mp_word) DIGIT_BIT));/* left hand side of A[ix] * A[iy] */ tmpx = a->dp[ix];/* alias for where to store the results */ tmpt = t.dp + (2*ix + 1); for (iy = ix + 1; iy < pa; iy++) {/* first calculate the product */ r = ((mp_word)tmpx) * ((mp_word)a->dp[iy]);/* now calculate the double precision result, note we use addition instead of *2 since it's easier to optimize */ r = ((mp_word) *tmpt) + r + r + ((mp_word) u);/* store lower part */ *tmpt++ = (mp_digit) (r & ((mp_word) MP_MASK)); /* get carry */ u = (mp_digit)(r >> ((mp_word) DIGIT_BIT)); } /* propagate upwards */ while (u != ((mp_digit) 0)) { r = ((mp_word) *tmpt) + ((mp_word) u); *tmpt++ = (mp_digit) (r & ((mp_word) MP_MASK)); u = (mp_digit)(r >> ((mp_word) DIGIT_BIT)); } } mp_clamp (&t); mp_exch (&t, b); mp_clear (&t); return MP_OKAY;}#endif /* USE_SMALL_WORD *//******************************************************************************//* fast squaring This is the comba method where the columns of the product are computed first then the carries are computed. This has the effect of making a very simple inner loop that is executed the most W2 represents the outer products and W the inner. A further optimizations is made because the inner products are of the form "A * B * 2". The *2 part does not need to be computed until the end which is good because 64-bit shifts are slow! Based on Algorithm 14.16 on pp.597 of HAC. 
This is the 1.0 version, but no SSE stuff
*/
/*
	Stores the square of a in b.

	pool	not referenced by this function
	a		value to square
	b		destination; grown to 2 * a->used digits when its allocation is
			smaller. The columns are accumulated in a local buffer and only
			copied to b at the end.

	Returns MP_OKAY on success, the error code from mp_grow() if growing b
	fails, or MP_VAL if the result would need more than MP_WARRAY digits
	(the size of the local column buffer).
*/
int32 fast_s_mp_sqr(psPool_t *pool, mp_int * a, mp_int * b)
{
	int32		olduse, res, pa, ix, iz;
	mp_digit	W[MP_WARRAY], *tmpx;
	mp_word		carry;

/*
	Number of output digits. W[] is a fixed size stack buffer indexed by
	0..pa-1 below, so refuse inputs that would run past its end.
 */
	pa = a->used + a->used;
	if (pa > MP_WARRAY) {
		return MP_VAL;
	}

/*
	Grow the destination as required
 */
	if (b->alloc < pa) {
		if ((res = mp_grow(b, pa)) != MP_OKAY) {
			return res;
		}
	}

/*
	Produce the output digits one column at a time
 */
	carry = 0;
	for (ix = 0; ix < pa; ix++) {
		int32		tx, ty, iy;
		mp_word		acc;
		mp_digit	*tmpy;

/*
		Offsets of the two digits that start this column
 */
		ty = MIN(a->used-1, ix);
		tx = ix - ty;

/*
		tmpx walks upwards from a->dp[tx], tmpy walks downwards from a->dp[ty]
 */
		tmpx = a->dp + tx;
		tmpy = a->dp + ty;

/*
		Number of products in the column, limited to the pairs with tx < ty:
		the two indices approach each other at a rate of two per step, so
		only half the distance (rounded up) is walked. The tx == ty term is
		added separately below.
 */
		iy = MIN(a->used-tx, ty+1);
		iy = MIN(iy, (ty-tx+1)>>1);

		acc = 0;
		for (iz = 0; iz < iy; iz++) {
			acc += ((mp_word)*tmpx++)*((mp_word)*tmpy--);
		}

/*
		Each cross product appears twice in the square; double the sum and
		add the carry from the previous column
 */
		acc = acc + acc + carry;

/*
		Even columns also contain the square term a[ix/2] * a[ix/2]
 */
		if ((ix&1) == 0) {
			acc += ((mp_word)a->dp[ix>>1])*((mp_word)a->dp[ix>>1]);
		}

/*
		Low digit is the result for this column, the rest is the next carry
 */
		W[ix] = (mp_digit)(acc & MP_MASK);
		carry = acc >> ((mp_word)DIGIT_BIT);
	}

/*
	Copy the columns into the destination. a->used is read again here
	(rather than before the loop) exactly as the column loop left it.
 */
	olduse = b->used;
	b->used = a->used+a->used;
	{
		mp_digit	*tmpb;

		tmpb = b->dp;
		for (ix = 0; ix < pa; ix++) {
			*tmpb++ = W[ix] & MP_MASK;
		}

/*
		Clear digits left over from the previous value of b
 */
		for (; ix < olduse; ix++) {
			*tmpb++ = 0;
		}
	}
	mp_clamp(b);
	return MP_OKAY;
}

/******************************************************************************/
/*
	computes a = 2**b

	Simple algorithm which zeroes the int32, grows it then just sets one bit
	as required.
*/int32 mp_2expt (mp_int * a, int32 b){ int32 res;/* zero a as per default */ mp_zero (a);/* grow a to accomodate the single bit */ if ((res = mp_grow (a, b / DIGIT_BIT + 1)) != MP_OKAY) { return res; }/* set the used count of where the bit will go */ a->used = b / DIGIT_BIT + 1;/* put the single bit in its place */ a->dp[b / DIGIT_BIT] = ((mp_digit)1) << (b % DIGIT_BIT); return MP_OKAY;}/******************************************************************************//* init an mp_init for a given size */int32 mp_init_size(psPool_t *pool, mp_int * a, int32 size){ int x;/* pad size so there are always extra digits */ size += (MP_PREC * 2) - (size % MP_PREC); /* alloc mem */ a->dp = OPT_CAST(mp_digit) psMalloc(pool, sizeof (mp_digit) * size); if (a->dp == NULL) { return MP_MEM; } a->used = 0; a->alloc = size; a->sign = MP_ZPOS;/* zero the digits */ for (x = 0; x < size; x++) { a->dp[x] = 0; } return MP_OKAY;}/******************************************************************************//* low level addition, based on HAC pp.594, Algorithm 14.7 */int32 s_mp_add (mp_int * a, mp_int * b, mp_int * c){ mp_int *x; int32 olduse, res, min, max;/* find sizes, we let |a| <= |b| which means we have to sort them. 
"x" will point to the input with the most digits */ if (a->used > b->used) { min = b->used; max = a->used; x = a; } else { min = a->used; max = b->used; x = b; } /* init result */ if (c->alloc < max + 1) { if ((res = mp_grow (c, max + 1)) != MP_OKAY) { return res; } }/* get old used digit count and set new one */ olduse = c->used; c->used = max + 1; { register mp_digit u, *tmpa, *tmpb, *tmpc; register int32 i; /* alias for digit pointers */ /* first input */ tmpa = a->dp; /* second input */ tmpb = b->dp; /* destination */ tmpc = c->dp; /* zero the carry */ u = 0; for (i = 0; i < min; i++) {/* Compute the sum at one digit, T[i] = A[i] + B[i] + U */ *tmpc = *tmpa++ + *tmpb++ + u;/* U = carry bit of T[i] */ u = *tmpc >> ((mp_digit)DIGIT_BIT);/* take away carry bit from T[i] */ *tmpc++ &= MP_MASK; }/* now copy higher words if any, that is in A+B if A or B has more digits add those in */ if (min != max) { for (; i < max; i++) { /* T[i] = X[i] + U */ *tmpc = x->dp[i] + u; /* U = carry bit of T[i] */ u = *tmpc >> ((mp_digit)DIGIT_BIT); /* take away carry bit from T[i] */ *tmpc++ &= MP_MASK; } } /* add carry */ *tmpc++ = u;/* clear digits above oldused */ for (i = c->used; i < olduse; i++) { *tmpc++ = 0; } } mp_clamp (c); return MP_OKAY;}/******************************************************************************/#ifdef USE_SMALL_WORD/* FUTURE - this is never needed, SLOW or not, because RSA exponents are always odd. 
*/int32 mp_invmod(psPool_t *pool, mp_int * a, mp_int * b, mp_int * c){ mp_int x, y, u, v, A, B, C, D; int32 res;/* b cannot be negative */ if (b->sign == MP_NEG || mp_iszero(b) == 1) { return MP_VAL; }/* if the modulus is odd we can use a faster routine instead */ if (mp_isodd (b) == 1) { return fast_mp_invmod(pool, a, b, c); }/* init temps */ if ((res = _mp_init_multi(pool, &x, &y, &u, &v, &A, &B, &C, &D)) != MP_OKAY) { return res; } /* x = a, y = b */ if ((res = mp_copy(a, &x)) != MP_OKAY) { goto LBL_ERR; } if ((res = mp_copy(b, &y)) != MP_OKAY) { goto LBL_ERR; }/* 2. [modified] if x,y are both even then return an error! */ if (mp_iseven(&x) == 1 && mp_iseven (&y) == 1) { res = MP_VAL; goto LBL_ERR; }/* 3. u=x, v=y, A=1, B=0, C=0,D=1 */ if ((res = mp_copy(&x, &u)) != MP_OKAY) { goto LBL_ERR; } if ((res = mp_copy(&y, &v)) != MP_OKAY) { goto LBL_ERR; } mp_set (&A, 1); mp_set (&D, 1);top:/* 4. while u is even do */ while (mp_iseven(&u) == 1) { /* 4.1 u = u/2 */ if ((res = mp_div_2(&u, &u)) != MP_OKAY) { goto LBL_ERR; } /* 4.2 if A or B is odd then */ if (mp_isodd (&A) == 1 || mp_isodd (&B) == 1) { /* A = (A+y)/2, B = (B-x)/2 */ if ((res = mp_add(&A, &y, &A)) != MP_OKAY) { goto LBL_ERR; } if ((res = mp_sub(&B, &x, &B)) != MP_OKAY) { goto LBL_ERR; } } /* A = A/2, B = B/2 */ if ((res = mp_div_2(&A, &A)) != MP_OKAY) { goto LBL_ERR; } if ((res = mp_div_2(&B, &B)) != MP_OKAY) { goto LBL_ERR; } }/* 5. while v is even do */ while (mp_iseven(&v) == 1) { /* 5.1 v = v/2 */ if ((res = mp_div_2(&v, &v)) != MP_OKAY) { goto LBL_ERR; } /* 5.2 if C or D is odd then */ if (mp_isodd(&C) == 1 || mp_isodd (&D) == 1) { /* C = (C+y)/2, D = (D-x)/2 */ if ((res = mp_add(&C, &y, &C)) != MP_OKAY) { goto LBL_ERR; } if ((res = mp_sub(&D, &x, &D)) != MP_OKAY) { goto LBL_ERR; } } /* C = C/2, D = D/2 */ if ((res = mp_div_2(&C, &C)) != MP_OKAY) { goto LBL_ERR; } if ((res = mp_div_2(&D, &D)) != MP_OKAY) { goto LBL_ERR; } }/* 6. 
if u >= v then */ if (mp_cmp(&u, &v) != MP_LT) { /* u = u - v, A = A - C, B = B - D */ if ((res = mp_sub(&u, &v, &u)) != MP_OKAY) { goto LBL_ERR; } if ((res = mp_sub(&A, &C, &A)) != MP_OKAY) { goto LBL_ERR; } if ((res = mp_sub(&B, &D, &B)) != MP_OKAY) { goto LBL_ERR; } } else { /* v - v - u, C = C - A, D = D - B */ if ((res = mp_sub(&v, &u, &v)) != MP_OKAY) { goto LBL_ERR; } if ((res = mp_sub(&C, &A, &C)) != MP_OKAY) { goto LBL_ERR; } if ((res = mp_sub(&D, &B, &D)) != MP_OKAY) { goto LBL_ERR; } }/* if not zero goto step 4 */ if (mp_iszero(&u) == 0) goto top;/* now a = C, b = D, gcd == g*v *//* if v != 1 then there is no inverse */ if (mp_cmp_d(&v, 1) != MP_EQ) { res = MP_VAL; goto LBL_ERR; }/* if its too low */ while (mp_cmp_d(&C, 0) == MP_LT) { if ((res = mp_add(&C, b, &C)) != MP_OKAY) { goto LBL_ERR; } }/* too big */ while (mp_cmp_mag(&C, b) != MP_LT) { if ((res = mp_sub(&C, b, &C)) != MP_OKAY) { goto LBL_ERR; } }/* C is now the inverse */ mp_exch(&C, c); res = MP_OKAY;LBL_ERR:_mp_clear_multi(&x, &y, &u, &v, &A, &B, &C, &D); return res;}#endif /* USE_SMALL_WORD *//******************************************************************************//* * Computes the modular inverse via binary extended euclidean algorithm, * that is c = 1/a mod b * * Based on slow invmod except this is optimized for the case where b is * odd as per HAC Note 14.64 on pp. 610 */int32 fast_mp_invmod(psPool_t *pool, mp_int * a, mp_int * b, mp_int * c){ mp_int x, y, u, v, B, D; int32 res, neg;/* 2. [modified] b must be odd */ if (mp_iseven (b) == 1) { return MP_VAL; }/* init all our temps */ if ((res = _mp_init_multi(pool, &x, &y, &u, &v, &B, &D, NULL, NULL)) != MP_OKAY) { return res; }/* x == modulus, y == value to invert */ if ((res = mp_copy(b, &x)) != MP_OKAY) { goto LBL_ERR; }/* we need y = |a| */ if ((res = mp_mod(pool, a, b, &y)) != MP_OKAY) { goto LBL_ERR; }/* 3. 
u=x, v=y, A=1, B=0, C=0,D=1 */ if ((res = mp_copy(&x, &u)) != MP_OKAY) { goto LBL_ERR; } if ((res = mp_copy(&y, &v)) != MP_OKAY) { goto LBL_ERR; } mp_set(&D, 1);top:/* 4. while u is even do*/ while (mp_iseven(&u) == 1) { /* 4.1 u = u/2 */ if ((res = mp_div_2(&u, &u)) != MP_OKAY) { goto LBL_ERR; } /* 4.2 if B is odd then */ if (mp_isodd(&B) == 1) { if ((res = mp_sub(&B, &x, &B)) != MP_OKAY) { goto LBL_ERR; } } /* B = B/2 */ if ((res = mp_div_2(&B, &B)) != MP_OKAY) { goto LBL_ERR; } }/* 5. while v is even do */ while (mp_iseven(&v) == 1) { /* 5.1 v = v/2 */ if ((res = mp_div_2(&v, &v)) != MP_OKAY) { goto LBL_ERR; } /* 5.2 if D is odd then */ if (mp_isodd(&D) == 1) { /* D = (D-x)/2 */ if ((res = mp_sub(&D, &x, &D)) != MP_OKAY) { goto LBL_ERR; } } /* D = D/2 */ if ((res = mp_div_2(&D, &D)) != MP_OKAY) { goto LBL_ERR; } }/* 6. if u >= v then */ if (mp_cmp(&u, &v) != MP_LT) { /* u = u - v, B = B - D */ if ((res = mp_sub(&u, &v, &u)) != MP_OKAY) { goto LBL_ERR; } if ((res = mp_sub(&B, &D, &B)) != MP_OKAY) { goto LBL_ERR; } } else { /* v - v - u, D = D - B */ if ((res = mp_sub(&v, &u, &v)) != MP_OKAY) { goto LBL_ERR; } if ((res = mp_sub(&D, &B, &D)) != MP_OKAY) { goto LBL_ERR; } }/* if not zero goto step 4 */ if (mp_iszero(&u) == 0) { goto top; }/* now a = C, b = D, gcd == g*v *//* if v != 1 then there is no inverse */ if (mp_cmp_d(&v, 1) !=
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -