⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 mpi.c

📁 matrix ssl代码
💻 C
📖 第 1 页 / 共 5 页
字号:
int32 mp_mul_2 (mp_int * a, mp_int * b)
{
	/*
		b = a * 2, done as a one-bit left shift across the digits of a.

		a : source value (only read, unless it is the same object as b)
		b : destination; grown via mp_grow() when it has fewer than
		    a->used + 1 digits allocated

		Returns MP_OKAY on success, otherwise whatever mp_grow() returned.
	 */
	int32		i, err, prev_used;
	mp_digit	carry_in, carry_out;

	/* doubling may need one digit more than the source uses */
	if (b->alloc < a->used + 1) {
		err = mp_grow (b, a->used + 1);
		if (err != MP_OKAY) {
			return err;
		}
	}

	/* remember the old digit count so stale high digits can be wiped below */
	prev_used = b->used;
	b->used = a->used;

	carry_in = 0;
	for (i = 0; i < a->used; i++) {
		/* the top bit of this digit is the carry into the next digit */
		carry_out = a->dp[i] >> ((mp_digit)(DIGIT_BIT - 1));

		/* shift up by one, merge the carry from below, mask with MP_MASK */
		b->dp[i] = ((a->dp[i] << ((mp_digit)1)) | carry_in) & MP_MASK;

		carry_in = carry_out;
	}

	/* a carry out of the last digit becomes a new top digit equal to 1 */
	if (carry_in != 0) {
		b->dp[b->used] = 1;
		++(b->used);
	}

	/* clear digits the destination used before but that were not rewritten */
	for (i = b->used; i < prev_used; i++) {
		b->dp[i] = 0;
	}

	b->sign = a->sign;
	return MP_OKAY;
}

/******************************************************************************/
/*
	multiply by a digit
 */
int32 mp_mul_d(mp_int * a, mp_digit b, mp_int * c)
{
	/*
		c = a * b where b is a single digit.

		a : multiplicand
		b : single-digit multiplier
		c : destination; grown via mp_grow() when it has fewer than
		    a->used + 1 digits allocated

		Returns MP_OKAY on success, otherwise whatever mp_grow() returned.
	 */
	mp_word		prod;
	mp_digit	carry;
	int32		i, err, prev_used;

	/* the product can be one digit longer than a */
	if (c->alloc < a->used + 1) {
		err = mp_grow (c, a->used + 1);
		if (err != MP_OKAY) {
			return err;
		}
	}

	/* keep the destination's previous digit count for the cleanup loop */
	prev_used = c->used;

	/* the result takes the sign of a */
	c->sign = a->sign;

	carry = 0;
	for (i = 0; i < a->used; i++) {
		/* double-width product of this digit plus the carry from below */
		prod = ((mp_word) carry) + ((mp_word) a->dp[i]) * ((mp_word) b);

		/* low part (masked with MP_MASK) is the output digit */
		c->dp[i] = (mp_digit) (prod & ((mp_word) MP_MASK));

		/* the part above DIGIT_BIT bits carries into the next column */
		carry = (mp_digit) (prod >> ((mp_word) DIGIT_BIT));
	}

	/* here i == a->used: store the final carry as the top digit */
	c->dp[i] = carry;

	/* zero whatever the destination held above the new top digit */
	for (i = i + 1; i < prev_used; i++) {
		c->dp[i] = 0;
	}

	/* provisional digit count; mp_clamp() is then applied to c */
	c->used = a->used + 1;
	mp_clamp(c);

	return MP_OKAY;
}

/******************************************************************************/
/*
	low level squaring, b = a*a, HAC pp.596-597, Algorithm 14.16
 */
#ifdef USE_SMALL_WORD
int32 s_mp_sqr (psPool_t *pool, mp_int * a, mp_int * b)
{
	/*
		b = a * a, computed into a temporary and then exchanged into b.

		pool : passed through to mp_init_size() for the temporary
		a    : value to square
		b    : receives the result via mp_exch()

		Returns MP_OKAY on success, otherwise whatever mp_init_size()
		returned.
	 */
	mp_int		tmp;
	mp_word		acc;
	mp_digit	carry, lhs;
	int32		err, i, j, k, n;

	n = a->used;
	err = mp_init_size(pool, &tmp, 2*n + 1);
	if (err != MP_OKAY) {
		return err;
	}

	/* start from the largest possible digit count */
	tmp.used = 2*n + 1;

	for (i = 0; i < n; i++) {
		/* square term a[i]*a[i] is added into column 2*i */
		acc = ((mp_word) tmp.dp[2*i]) +
			((mp_word)a->dp[i])*((mp_word)a->dp[i]);
		tmp.dp[2*i] = (mp_digit) (acc & ((mp_word) MP_MASK));
		carry = (mp_digit)(acc >> ((mp_word) DIGIT_BIT));

		/* fixed left-hand factor for the cross products of this row */
		lhs = a->dp[i];

		/* cross products start one column above the square term */
		k = 2*i + 1;

		for (j = i + 1; j < n; j++, k++) {
			/* single cross product a[i]*a[j] */
			acc = ((mp_word)lhs) * ((mp_word)a->dp[j]);

			/* it is counted twice (added twice rather than multiplied by 2) */
			acc = ((mp_word) tmp.dp[k]) + acc + acc + ((mp_word) carry);

			tmp.dp[k] = (mp_digit) (acc & ((mp_word) MP_MASK));
			carry = (mp_digit)(acc >> ((mp_word) DIGIT_BIT));
		}

		/* push any remaining carry into the higher columns */
		for (; carry != ((mp_digit) 0); k++) {
			acc = ((mp_word) tmp.dp[k]) + ((mp_word) carry);
			tmp.dp[k] = (mp_digit) (acc & ((mp_word) MP_MASK));
			carry = (mp_digit)(acc >> ((mp_word) DIGIT_BIT));
		}
	}

	/* clamp the temporary, exchange it with b, then clear the temporary */
	mp_clamp (&tmp);
	mp_exch (&tmp, b);
	mp_clear (&tmp);
	return MP_OKAY;
}
#endif /* USE_SMALL_WORD */

/******************************************************************************/
/*
	fast squaring

	This is the comba method where the columns of the product are computed
	first then the carries are computed.  This has the effect of making a very
	simple inner loop that is executed the most

	W2 represents the outer products and W the inner.

	A further optimization is made because the inner products are of the
	form "A * B * 2".  The *2 part does not need to be computed until the
	end, which is good because 64-bit shifts are slow!

	Based on Algorithm 14.16 on pp.597 of HAC.

	This is the 1.0 version, but no SSE stuff
*/
int32 fast_s_mp_sqr(psPool_t *pool, mp_int * a, mp_int * b)
{
	/*
		b = a * a, column by column: each output column is accumulated in a
		double-width word, doubled, given the square term on even columns,
		and stored to a local digit array before being copied into b.

		pool : not referenced directly in this function body
		a    : value to square
		b    : destination; grown via mp_grow() when it has fewer than
		       2 * a->used digits allocated

		Returns MP_OKAY on success, otherwise whatever mp_grow() returned.

		NOTE(review): cols[] has MP_WARRAY entries and is written at indexes
		0 .. 2*a->used - 1; nothing in this function checks that
		2*a->used <= MP_WARRAY. Presumably the caller guarantees it -
		confirm.
	 */
	mp_digit	cols[MP_WARRAY];
	mp_word		carry;
	int32		err, ndigits, col, prev_used;

	/* the result has up to twice as many digits as the input */
	ndigits = a->used + a->used;
	if (b->alloc < ndigits) {
		err = mp_grow(b, ndigits);
		if (err != MP_OKAY) {
			return err;
		}
	}

	carry = 0;
	for (col = 0; col < ndigits; col++) {
		mp_word		acc;
		int32		lo, hi, count, k;

		acc = 0;

		/* lo + hi == col: the index pair where this column starts */
		hi = MIN(a->used-1, col);
		lo = col - hi;

		/* how many index pairs exist in this column in total */
		count = MIN(a->used-lo, hi+1);

		/*
			only pairs with lo < hi are summed here; the two indexes close
			on each other by two per step, so the gap is halved
		 */
		count = MIN(count, (hi-lo+1)>>1);

		for (k = 0; k < count; k++) {
			acc += ((mp_word)a->dp[lo + k])*((mp_word)a->dp[hi - k]);
		}

		/* each cross product occurs twice; then add the incoming carry */
		acc = acc + acc + carry;

		/* even columns also contain the square term a[col/2]^2 */
		if ((col&1) == 0) {
			acc += ((mp_word)a->dp[col>>1])*((mp_word)a->dp[col>>1]);
		}

		/* low part is this column's digit, the rest carries upward */
		cols[col] = (mp_digit)(acc & MP_MASK);
		carry = acc >> ((mp_word)DIGIT_BIT);
	}

	/* publish the result into b */
	prev_used = b->used;
	b->used = a->used+a->used;

	for (col = 0; col < ndigits; col++) {
		b->dp[col] = cols[col] & MP_MASK;
	}

	/* wipe digits b used before that lie above the new result */
	for (; col < prev_used; col++) {
		b->dp[col] = 0;
	}

	mp_clamp(b);
	return MP_OKAY;
}

/******************************************************************************/
/*
	computes a = 2**b 

	Simple algorithm which zeroes the mp_int, grows it, then sets the one
	bit that is required.
 */
int32 mp_2expt (mp_int * a, int32 b)
{
	/*
		a = 2**b.

		a : destination; zeroed with mp_zero() first, then grown with
		    mp_grow() to b / DIGIT_BIT + 1 digits
		b : bit position to set

		Returns MP_OKAY on success, otherwise whatever mp_grow() returned
		(a has already been zeroed in that case).

		NOTE(review): b is not range-checked here; a negative b would
		presumably yield a negative digit index and shift count - confirm
		that callers only pass b >= 0.
	 */
	int32		status;

	/* start from zero */
	mp_zero (a);

	/* make room for the digit that will hold the bit */
	status = mp_grow (a, b / DIGIT_BIT + 1);
	if (status == MP_OKAY) {
		/* the digit holding the bit becomes the top digit */
		a->used = b / DIGIT_BIT + 1;

		/* set bit (b mod DIGIT_BIT) inside that digit */
		a->dp[b / DIGIT_BIT] = ((mp_digit)1) << (b % DIGIT_BIT);
	}

	return status;
}

/******************************************************************************/
/*
	init an mp_int for a given size
 */
int32 mp_init_size(psPool_t *pool, mp_int * a, int32 size)
{
	/*
		Initialise a with zeroed digit storage for at least `size` digits.

		pool : memory pool handed to psMalloc()
		a    : mp_int to initialise; its dp, used, alloc and sign fields
		       are overwritten
		size : requested number of digits (padded before allocating)

		Returns MP_OKAY on success. Returns MP_MEM when psMalloc() yields
		NULL; in that case a->dp is NULL and the other fields are left
		untouched.
	 */

	/*
		The counter uses the same type as `size` (int32, like every other
		index in this file) so it can always reach the loop bound, even
		where plain int is narrower than int32.
	 */
	int32		x;

	/* pad size so there are always extra digits */
	size += (MP_PREC * 2) - (size % MP_PREC);

	/* allocate the digit array */
	a->dp = OPT_CAST(mp_digit) psMalloc(pool, sizeof (mp_digit) * size);
	if (a->dp == NULL) {
		return MP_MEM;
	}
	a->used  = 0;
	a->alloc = size;
	a->sign  = MP_ZPOS;

	/* zero every allocated digit */
	for (x = 0; x < size; x++) {
		a->dp[x] = 0;
	}
	return MP_OKAY;
}

/******************************************************************************/
/*
	low level addition, based on HAC pp.594, Algorithm 14.7
 */
int32 s_mp_add (mp_int * a, mp_int * b, mp_int * c)
{
	/*
		Digit-wise addition c = a + b; the sign fields are not read or
		written here.

		a, b : operands
		c    : destination; grown via mp_grow() when it has fewer than
		       max(a->used, b->used) + 1 digits allocated

		Returns MP_OKAY on success, otherwise whatever mp_grow() returned.
	 */
	mp_int		*longer;
	mp_digit	carry, sum;
	int32		prev_used, err, shared, total, i;

	/* work out which operand has more digits; "longer" points at it */
	if (a->used > b->used) {
		shared = b->used;
		total  = a->used;
		longer = a;
	} else {
		shared = a->used;
		total  = b->used;
		longer = b;
	}

	/* the sum can be one digit longer than the longer operand */
	if (c->alloc < total + 1) {
		err = mp_grow (c, total + 1);
		if (err != MP_OKAY) {
			return err;
		}
	}

	/* remember the old digit count, then set the provisional new one */
	prev_used = c->used;
	c->used = total + 1;

	carry = 0;

	/* columns where both operands have a digit */
	for (i = 0; i < shared; i++) {
		/* T[i] = A[i] + B[i] + carry */
		sum = a->dp[i] + b->dp[i] + carry;

		/* whatever lies above DIGIT_BIT bits is the next carry */
		carry = sum >> ((mp_digit)DIGIT_BIT);

		/* store the digit with the carry part masked off */
		c->dp[i] = sum & MP_MASK;
	}

	/* remaining columns come from the longer operand only */
	for (; i < total; i++) {
		sum = longer->dp[i] + carry;
		carry = sum >> ((mp_digit)DIGIT_BIT);
		c->dp[i] = sum & MP_MASK;
	}

	/* here i == total: the final carry becomes the top digit */
	c->dp[i] = carry;

	/* clear digits the destination used before that lie above the sum */
	for (i = c->used; i < prev_used; i++) {
		c->dp[i] = 0;
	}

	mp_clamp (c);
	return MP_OKAY;
}

/******************************************************************************/
#ifdef USE_SMALL_WORD
/*
	FUTURE - this is never needed, SLOW or not, because RSA exponents are
	always odd.
*/
int32 mp_invmod(psPool_t *pool, mp_int * a, mp_int * b, mp_int * c)
{
	mp_int		x, y, u, v, A, B, C, D;
	int32			res;

/*
	b cannot be negative
 */
	if (b->sign == MP_NEG || mp_iszero(b) == 1) {
		return MP_VAL;
	}

/*
	if the modulus is odd we can use a faster routine instead
 */
	if (mp_isodd (b) == 1) {
		return fast_mp_invmod(pool, a, b, c);
	}

/*
	init temps
 */
	if ((res = _mp_init_multi(pool, &x, &y, &u, &v,
			&A, &B, &C, &D)) != MP_OKAY) {
		return res;
	}

	/* x = a, y = b */
	if ((res = mp_copy(a, &x)) != MP_OKAY) {
		goto LBL_ERR;
	}
	if ((res = mp_copy(b, &y)) != MP_OKAY) {
		goto LBL_ERR;
	}

/*
	2. [modified] if x,y are both even then return an error!
 */
	if (mp_iseven(&x) == 1 && mp_iseven (&y) == 1) {
		res = MP_VAL;
		goto LBL_ERR;
	}

/*
	3. u=x, v=y, A=1, B=0, C=0,D=1
 */
	if ((res = mp_copy(&x, &u)) != MP_OKAY) {
		goto LBL_ERR;
	}
	if ((res = mp_copy(&y, &v)) != MP_OKAY) {
		goto LBL_ERR;
	}
	mp_set (&A, 1);
	mp_set (&D, 1);

top:
/*
	4.  while u is even do
 */
	while (mp_iseven(&u) == 1) {
		/* 4.1 u = u/2 */
		if ((res = mp_div_2(&u, &u)) != MP_OKAY) {
		goto LBL_ERR;
		}
		/* 4.2 if A or B is odd then */
		if (mp_isodd (&A) == 1 || mp_isodd (&B) == 1) {
			/* A = (A+y)/2, B = (B-x)/2 */
			if ((res = mp_add(&A, &y, &A)) != MP_OKAY) {
				goto LBL_ERR;
			}
			if ((res = mp_sub(&B, &x, &B)) != MP_OKAY) {
				goto LBL_ERR;
			}
		}
		/* A = A/2, B = B/2 */
		if ((res = mp_div_2(&A, &A)) != MP_OKAY) {
			goto LBL_ERR;
		}
		if ((res = mp_div_2(&B, &B)) != MP_OKAY) {
			goto LBL_ERR;
		}
	}

/*
	5.  while v is even do
 */
	while (mp_iseven(&v) == 1) {
		/* 5.1 v = v/2 */
		if ((res = mp_div_2(&v, &v)) != MP_OKAY) {
		goto LBL_ERR;
		}
		/* 5.2 if C or D is odd then */
		if (mp_isodd(&C) == 1 || mp_isodd (&D) == 1) {
			/* C = (C+y)/2, D = (D-x)/2 */
			if ((res = mp_add(&C, &y, &C)) != MP_OKAY) {
				goto LBL_ERR;
			}
			if ((res = mp_sub(&D, &x, &D)) != MP_OKAY) {
				goto LBL_ERR;
			}
		}
		/* C = C/2, D = D/2 */
		if ((res = mp_div_2(&C, &C)) != MP_OKAY) {
			goto LBL_ERR;
		}
		if ((res = mp_div_2(&D, &D)) != MP_OKAY) {
			goto LBL_ERR;
		}
	}

/*
	6.  if u >= v t

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -