📄 mpi.c
字号:
register mp_word *_W; register mp_digit *tmpx;/* Alias for the W[] array */ _W = W;/* Alias for the digits of x */ tmpx = x->dp;/* Copy the digits of a into W[0..a->used-1] */ for (ix = 0; ix < x->used; ix++) { *_W++ = *tmpx++; }/* Zero the high words of W[a->used..m->used*2] */ for (; ix < n->used * 2 + 1; ix++) { *_W++ = 0; } }/* Now we proceed to zero successive digits from the least significant upwards. */ for (ix = 0; ix < n->used; ix++) {/* mu = ai * m' mod b We avoid a double precision multiplication (which isn't required) by casting the value down to a mp_digit. Note this requires that W[ix-1] have the carry cleared (see after the inner loop) */ register mp_digit mu; mu = (mp_digit) (((W[ix] & MP_MASK) * rho) & MP_MASK);/* a = a + mu * m * b**i This is computed in place and on the fly. The multiplication by b**i is handled by offseting which columns the results are added to. Note the comba method normally doesn't handle carries in the inner loop In this case we fix the carry from the previous column since the Montgomery reduction requires digits of the result (so far) [see above] to work. This is handled by fixing up one carry after the inner loop. The carry fixups are done in order so after these loops the first m->used words of W[] have the carries fixed */ { register int32 iy; register mp_digit *tmpn; register mp_word *_W;/* Alias for the digits of the modulus */ tmpn = n->dp;/* Alias for the columns set by an offset of ix */ _W = W + ix;/* inner loop */ for (iy = 0; iy < n->used; iy++) { *_W++ += ((mp_word)mu) * ((mp_word)*tmpn++); } }/* Now fix carry for next digit, W[ix+1] */ W[ix + 1] += W[ix] >> ((mp_word) DIGIT_BIT); }/* Now we have to propagate the carries and shift the words downward [all those least significant digits we zeroed]. */ { register mp_digit *tmpx; register mp_word *_W, *_W1;/* Now fix rest of carries *//* alias for current word */ _W1 = W + ix;/* alias for next word, where the carry goes */ _W = W + ++ix; for (; ix <= n->used * 2 + 1; ix++) { *_W++ += *_W1++ >> ((mp_word) DIGIT_BIT); }/* copy out, A = A/b**n The result is A/b**n but instead of converting from an array of mp_word to mp_digit than calling mp_rshd we just copy them in the right order *//* alias for destination word */ tmpx = x->dp;/* alias for shifted double precision result */ _W = W + n->used; for (ix = 0; ix < n->used + 1; ix++) { *tmpx++ = (mp_digit)(*_W++ & ((mp_word) MP_MASK)); }/* zero oldused digits, if the input a was larger than m->used+1 we'll have to clear the digits */ for (; ix < olduse; ix++) { *tmpx++ = 0; } }/* Set the max used and clamp */ x->used = n->used + 1; mp_clamp(x);/* if A >= m then A = A - m */ if (mp_cmp_mag(x, n) != MP_LT) { return s_mp_sub(x, n, x); } return MP_OKAY;}/******************************************************************************//* High level addition (handles signs) */int32 mp_add (mp_int * a, mp_int * b, mp_int * c){ int32 sa, sb, res;/* Get sign of both inputs */ sa = a->sign; sb = b->sign;/* Handle two cases, not four. */ if (sa == sb) {/* Both positive or both negative. Add their magnitudes, copy the sign. */ c->sign = sa; res = s_mp_add (a, b, c); } else {/* One positive, the other negative. Subtract the one with the greater magnitude from the one of the lesser magnitude. The result gets the sign of the one with the greater magnitude. */ if (mp_cmp_mag (a, b) == MP_LT) { c->sign = sb; res = s_mp_sub (b, a, c); } else { c->sign = sa; res = s_mp_sub (a, b, c); } } return res;}/******************************************************************************//* Compare a digit. */int32 mp_cmp_d (mp_int * a, mp_digit b){/* Compare based on sign */ if (a->sign == MP_NEG) { return MP_LT; }/* Compare based on magnitude */ if (a->used > 1) { return MP_GT; }/* Compare the only digit of a to b */ if (a->dp[0] > b) { return MP_GT; } else if (a->dp[0] < b) { return MP_LT; } else { return MP_EQ; }}/******************************************************************************//* b = a/2 */int32 mp_div_2 (mp_int * a, mp_int * b){ int32 x, res, oldused;/* Copy */ if (b->alloc < a->used) { if ((res = mp_grow (b, a->used)) != MP_OKAY) { return res; } } oldused = b->used; b->used = a->used; { register mp_digit r, rr, *tmpa, *tmpb;/* Source alias */ tmpa = a->dp + b->used - 1;/* dest alias */ tmpb = b->dp + b->used - 1;/* carry */ r = 0; for (x = b->used - 1; x >= 0; x--) {/* Get the carry for the next iteration */ rr = *tmpa & 1;/* Shift the current digit, add in carry and store */ *tmpb-- = (*tmpa-- >> 1) | (r << (DIGIT_BIT - 1));/* Forward carry to next iteration */ r = rr; }/* Zero excess digits */ tmpb = b->dp + b->used; for (x = b->used; x < oldused; x++) { *tmpb++ = 0; } } b->sign = a->sign; mp_clamp (b); return MP_OKAY;}/******************************************************************************//* Computes xR**-1 == x (mod N) via Montgomery Reduction */#ifdef USE_SMALL_WORDint32 mp_montgomery_reduce (mp_int * x, mp_int * n, mp_digit rho){ int32 ix, res, digs; mp_digit mu;/* Can the fast reduction [comba] method be used? Note that unlike in mul you're safely allowed *less* than the available columns [255 per default] since carries are fixed up in the inner loop. */ digs = n->used * 2 + 1; if ((digs < MP_WARRAY) && n->used < (1 << ((CHAR_BIT * sizeof (mp_word)) - (2 * DIGIT_BIT)))) { return fast_mp_montgomery_reduce (x, n, rho); }/* Grow the input as required. */ if (x->alloc < digs) { if ((res = mp_grow (x, digs)) != MP_OKAY) { return res; } } x->used = digs; for (ix = 0; ix < n->used; ix++) {/* mu = ai * rho mod b The value of rho must be precalculated via mp_montgomery_setup() such that it equals -1/n0 mod b this allows the following inner loop to reduce the input one digit at a time */ mu = (mp_digit)(((mp_word)x->dp[ix]) * ((mp_word)rho) & MP_MASK); /* a = a + mu * m * b**i */ { register int32 iy; register mp_digit *tmpn, *tmpx, u; register mp_word r;/* alias for digits of the modulus */ tmpn = n->dp;/* alias for the digits of x [the input] */ tmpx = x->dp + ix;/* set the carry to zero */ u = 0;/* Multiply and add in place */ for (iy = 0; iy < n->used; iy++) { /* compute product and sum */ r = ((mp_word)mu) * ((mp_word)*tmpn++) + ((mp_word) u) + ((mp_word) * tmpx); /* get carry */ u = (mp_digit)(r >> ((mp_word) DIGIT_BIT)); /* fix digit */ *tmpx++ = (mp_digit)(r & ((mp_word) MP_MASK)); } /* At this point the ix'th digit of x should be zero *//* propagate carries upwards as required */ while (u) { *tmpx += u; u = *tmpx >> DIGIT_BIT; *tmpx++ &= MP_MASK; } } }/* At this point the n.used'th least significant digits of x are all zero which means we can shift x to the right by n.used digits and the residue is unchanged.*/ /* x = x/b**n.used */ mp_clamp(x); mp_rshd (x, n->used); /* if x >= n then x = x - n */ if (mp_cmp_mag (x, n) != MP_LT) { return s_mp_sub (x, n, x); } return MP_OKAY;}#endif /* USE_SMALL_WORD *//******************************************************************************//* Setups the montgomery reduction stuff. */int32 mp_montgomery_setup (mp_int * n, mp_digit * rho){ mp_digit x, b;/* fast inversion mod 2**k Based on the fact that XA = 1 (mod 2**n) => (X(2-XA)) A = 1 (mod 2**2n) => 2*X*A - X*X*A*A = 1 => 2*(1) - (1) = 1*/ b = n->dp[0]; if ((b & 1) == 0) { return MP_VAL; } x = (((b + 2) & 4) << 1) + b; /* here x*a==1 mod 2**4 */ x = (x * (2 - b * x)) & MP_MASK; /* here x*a==1 mod 2**8 */#if !defined(MP_8BIT) x = (x * (2 - b * x)) & MP_MASK; /* here x*a==1 mod 2**8 */#endif /* MP_8BIT */#if defined(MP_64BIT) || !(defined(MP_8BIT) || defined(MP_16BIT)) x *= 2 - b * x; /* here x*a==1 mod 2**32 */#endif#ifdef MP_64BIT x *= 2 - b * x; /* here x*a==1 mod 2**64 */#endif /* MP_64BIT */ /* rho = -1/m mod b */ *rho = (((mp_word) 1 << ((mp_word) DIGIT_BIT)) - x) & MP_MASK; return MP_OKAY;}/******************************************************************************//* High level subtraction (handles signs) */int32 mp_sub (mp_int * a, mp_int * b, mp_int * c){ int32 sa, sb, res; sa = a->sign; sb = b->sign; if (sa != sb) {/* Subtract a negative from a positive, OR subtract a positive from a negative. In either case, ADD their magnitudes, and use the sign of the first number. */ c->sign = sa; res = s_mp_add (a, b, c); } else {/* Subtract a positive from a positive, OR subtract a negative from a negative. First, take the difference between their magnitudes, then... */ if (mp_cmp_mag (a, b) != MP_LT) {/* Copy the sign from the first */ c->sign = sa; /* The first has a larger or equal magnitude */ res = s_mp_sub (a, b, c); } else {/* The result has the *opposite* sign from the first number. */ c->sign = (sa == MP_ZPOS) ? MP_NEG : MP_ZPOS;/* The second has a larger magnitude */ res = s_mp_sub (b, a, c); } } return res;}/******************************************************************************//* calc a value mod 2**b */int32 mp_mod_2d (mp_int * a, int32 b, mp_int * c){ int32 x, res;/* if b is <= 0 then zero the int32 */ if (b <= 0) { mp_zero (c); return MP_OKAY; }/* If the modulus is larger than the value than return */ if (b >=(int32) (a->used * DIGIT_BIT)) { res = mp_copy (a, c); return res; } /* copy */ if ((res = mp_copy (a, c)) != MP_OKAY) { return res; }/* Zero digits above the last digit of the modulus */ for (x = (b / DIGIT_BIT) + ((b % DIGIT_BIT) == 0 ? 0 : 1); x < c->used; x++) { c->dp[x] = 0; }/* Clear the digit that is not completely outside/inside the modulus */ c->dp[b / DIGIT_BIT] &= (mp_digit) ((((mp_digit) 1) << (((mp_digit) b) % DIGIT_BIT)) - ((mp_digit) 1)); mp_clamp (c); return MP_OKAY;}/******************************************************************************//* Shift right a certain amount of digits. */void mp_rshd (mp_int * a, int32 b){ int32 x;/* If b <= 0 then ignore it */ if (b <= 0) { return; }/* If b > used then simply zero it and return.*/ if (a->used <= b) { mp_zero (a); return; } { register mp_digit *bottom, *top;/* Shift the digits down */ /* bottom */ bottom = a->dp; /* top [offset into digits] */ top = a->dp + b;/* This is implemented as a sliding window where the window is b-digits long and digits from the top of the window are copied to the bottom. e.g. b-2 | b-1 | b0 | b1 | b2 | ... | bb | ----> /\ | ----> \-------------------/ ----> */ for (x = 0; x < (a->used - b); x++) { *bottom++ = *top++; }/* Zero the top digits */ for (; x < a->used; x++) { *bottom++ = 0; } }/* Remove excess digits */ a->used -= b;}/******************************************************************************//* Low level subtraction (assumes |a| > |b|), HAC pp.595 Algorithm 14.9 */int32 s_mp_sub (mp_int * a, mp_int * b, mp_int * c){ int32 olduse, res, min, max;/* Find sizes */ min = b->used; max = a->used;/* init result */ if (c->alloc < max) { if ((res = mp_grow (c, max)) != MP_OKAY) { return res; } } olduse = c->used; c->used = max; { register mp_digit u, *tmpa, *tmpb, *tmpc; register int32 i;/* alias for digit pointers */ tmpa = a->dp; tmpb = b->dp; tmpc = c->dp;/* set carry to zero */ u = 0; for (i = 0; i < min; i++) { /* T[i] = A[i] - B[i] - U */ *tmpc = *tmpa++ - *tmpb++ - u;/* U = carry bit of T[i] Note this saves performing an AND operation since if a carry does occur it will propagate all the way to the MSB. As a result a single shift is enough to get the carry */ u = *tmpc >> ((mp_digit)(CHAR_BIT * sizeof (mp_digit) - 1)); /* Clear carry from T[i] */ *tmpc++ &= MP_MASK; }/* Now copy higher words if any, e.g. if A has more digits than B */ for (; i < max; i++) { /* T[i] = A[i] - U */ *tmpc = *tmpa++ - u; /* U = carry bit of T[i] */ u = *tmpc >> ((mp_digit)(CHAR_BIT * sizeof (mp_digit) - 1)); /* Clear carry from T[i] */ *tmpc++ &= MP_MASK; }/* Clear digits above used (since we may not have grown result above) */ for (i = c->used; i < olduse; i++) { *tmpc++ = 0; } } mp_clamp (c); return MP_OKAY;}/******************************************************************************//* integer signed division. c*b + d == a [e.g. a/b, c=quotient, d=remainder] HAC pp.598 Algorithm 14.20 Note that the description in HAC is horribly incomplete. For example, it doesn't consider the case where digits are removed from 'x' in the inner loop. It also doesn't consider the case that y has fewer than three
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -