📄 zzx.cpp
字号:
if (&x == &b) {
lb = b;
bp = lb.rep.elts();
}
else
bp = b.rep.elts();
x.rep.SetLength(d+1);
xp = x.rep.elts();
long i, j, jmin, jmax;
ZZ t, accum;
for (i = 0; i <= d; i++) {
jmin = max(0, i-db);
jmax = min(da, i);
clear(accum);
for (j = jmin; j <= jmax; j++) {
mul(t, ap[j], bp[i-j]);
add(accum, accum, t);
}
xp[i] = accum;
}
x.normalize();
}
/*
 * Schoolbook squaring: x = a^2.
 *
 * For output coefficient k the contributing pairs (j, k-j) are symmetric,
 * so only the first half of them are multiplied and the partial sum is
 * doubled; when the number of pairs is odd the middle term is a plain
 * square and is added once.  x may be the same object as a.
 */
void PlainSqr(ZZX& x, const ZZX& a)
{
   const long na = deg(a);
   if (na < 0) {
      clear(x);
      return;
   }

   // If x and a are the same object, resizing x would clobber the input,
   // so read from a private copy in that case.
   ZZX copy;
   const ZZ *src;
   if (&x != &a) {
      src = a.rep.elts();
   }
   else {
      copy = a;
      src = copy.rep.elts();
   }

   const long top = 2*na;
   x.rep.SetLength(top+1);
   ZZ *dst = x.rep.elts();

   ZZ prod, sum;
   for (long k = 0; k <= top; k++) {
      const long lo = max(0, k-na);
      const long hi = min(na, k);
      const long npairs = hi - lo + 1;   // number of (j, k-j) pairs
      const long half = npairs >> 1;     // distinct off-diagonal pairs

      clear(sum);
      for (long j = lo; j < lo + half; j++) {
         mul(prod, src[j], src[k-j]);
         add(sum, sum, prod);
      }
      add(sum, sum, sum);

      if (npairs & 1) {
         // odd count: the middle pair is src[lo+half] squared
         sqr(prod, src[lo + half]);
         add(sum, sum, prod);
      }

      dst[k] = sum;
   }

   x.normalize();
}
/*
 * Schoolbook product of coefficient arrays.
 *
 * Writes the sa+sb-1 coefficients of a*b into xp[0 .. sa+sb-2], where
 * ap has sa entries and bp has sb entries.  Does nothing if either
 * length is 0.  xp presumably must not overlap ap or bp, since xp[i] is
 * written while later iterations still read ap and bp.
 */
static
void PlainMul(ZZ *xp, const ZZ *ap, long sa, const ZZ *bp, long sb)
{
   if (sa == 0 || sb == 0) return;

   long sx = sa+sb-1;
   long i, j, jmin, jmax;

   // Scratch values are ordinary locals rather than function-scope
   // statics, so reentrant or concurrent callers do not share (and race
   // on) the same temporaries.
   ZZ t, accum;

   for (i = 0; i < sx; i++) {
      // Valid j satisfy 0 <= j < sa and 0 <= i-j < sb.
      jmin = max(0, i-sb+1);
      jmax = min(sa-1, i);
      clear(accum);
      for (j = jmin; j <= jmax; j++) {
         mul(t, ap[j], bp[i-j]);
         add(accum, accum, t);
      }
      xp[i] = accum;
   }
}
/*
 * Karatsuba "fold": combine the low and high halves of b into T.
 *
 * For k < sb-hsa:    T[k] = b[k] + b[hsa+k]
 * For sb-hsa <= k < hsa:  T[k] = b[k]   (no high-half partner)
 */
static
void KarFold(ZZ *T, const ZZ *b, long sb, long hsa)
{
   const long nfold = sb - hsa;
   const ZZ *upper = b + hsa;
   long k;

   k = 0;
   while (k < nfold) {
      add(T[k], b[k], upper[k]);
      ++k;
   }

   for (k = nfold; k < hsa; ++k)
      T[k] = b[k];
}
/* In-place subtraction: T[k] -= b[k] for every k in [0, sb). */
static
void KarSub(ZZ *T, const ZZ *b, long sb)
{
   long k = 0;
   while (k < sb) {
      sub(T[k], T[k], b[k]);
      ++k;
   }
}
/* In-place addition: T[k] += b[k] for every k in [0, sb). */
static
void KarAdd(ZZ *T, const ZZ *b, long sb)
{
   long k = 0;
   while (k < sb) {
      add(T[k], T[k], b[k]);
      ++k;
   }
}
/*
 * Merge b into c: the first hsa entries of c are overwritten with b,
 * and entries hsa .. sb-1 of b are accumulated onto c.
 */
static
void KarFix(ZZ *c, const ZZ *b, long sb, long hsa)
{
   long k;

   for (k = 0; k < hsa; ++k)
      c[k] = b[k];

   for (k = hsa; k < sb; ++k)
      add(c[k], c[k], b[k]);
}
/* Scalar multiply: xp[k] = ap[k] * b for every k in [0, sa). */
static void PlainMul1(ZZ *xp, const ZZ *ap, long sa, const ZZ& b)
{
   long k = 0;
   while (k < sa) {
      mul(xp[k], ap[k], b);
      ++k;
   }
}
/*
 * Recursive Karatsuba multiplication of coefficient arrays.
 *
 * Writes the sa+sb-1 coefficients of a*b into c.
 *
 * a, b : input coefficient arrays of lengths sa and sb.  They are swapped
 *        on entry if needed so that sa >= sb.
 * c    : output, must have room for sa+sb-1 entries.  c must not overlap
 *        a or b: e.g. in the 2x2 case c[0] is written before a[0] and b[0]
 *        are read again.
 * stk  : scratch ZZ array.
 *        - The normal case takes 4*hsa-1 entries (T1, T2, T3) and passes
 *          the rest to its recursive calls.
 *        - The degenerate case takes hsa+sb-1 entries (T).
 *        - The sa == sb == 2 base case uses stk[0] and stk[1].
 */
static
void KarMul(ZZ *c, const ZZ *a,
long sa, const ZZ *b, long sb, ZZ *stk)
{
// normalize so that a is the longer operand
if (sa < sb) {
{ long t = sa; sa = sb; sb = t; }
{ const ZZ *t = a; a = b; b = t; }
}
// base case: b is a single coefficient -> scalar multiply
if (sb == 1) {
if (sa == 1)
mul(*c, *a, *b);
else
PlainMul1(c, a, sa, *b);
return;
}
// base case: 2x2 product with 3 multiplications,
// c1 = (a0+a1)(b0+b1) - a0*b0 - a1*b1
if (sb == 2 && sa == 2) {
mul(c[0], a[0], b[0]);
mul(c[2], a[1], b[1]);
add(stk[0], a[0], a[1]);
add(stk[1], b[0], b[1]);
mul(c[1], stk[0], stk[1]);
sub(c[1], c[1], c[0]);
sub(c[1], c[1], c[2]);
return;
}
// split point: hsa = ceil(sa/2)
long hsa = (sa + 1) >> 1;
if (hsa < sb) {
/* normal case */
long hsa2 = hsa << 1;
ZZ *T1, *T2, *T3;
// carve T1 (hsa), T2 (hsa) and T3 (2*hsa-1) off the front of stk
T1 = stk; stk += hsa;
T2 = stk; stk += hsa;
T3 = stk; stk += hsa2 - 1;
/* compute T1 = a_lo + a_hi */
KarFold(T1, a, sa, hsa);
/* compute T2 = b_lo + b_hi */
KarFold(T2, b, sb, hsa);
/* recursively compute T3 = T1 * T2 */
KarMul(T3, T1, hsa, T2, hsa, stk);
/* recursively compute a_hi * b_hi into high part of c */
/* and subtract from T3 */
KarMul(c + hsa2, a+hsa, sa-hsa, b+hsa, sb-hsa, stk);
KarSub(T3, c + hsa2, sa + sb - hsa2 - 1);
/* recursively compute a_lo*b_lo into low part of c */
/* and subtract from T3 */
KarMul(c, a, hsa, b, hsa, stk);
KarSub(T3, c, hsa2 - 1);
// the gap between the low product (2*hsa-1 entries) and c + hsa2 is
// not written by either product, so zero it before accumulating
clear(c[hsa2 - 1]);
/* finally, add T3 * X^{hsa} to c */
KarAdd(c+hsa, T3, hsa2-1);
}
else {
/* degenerate case */
// b is no longer than half of a: split only a, and multiply each
// half by the whole of b
ZZ *T;
T = stk; stk += hsa + sb - 1;
/* recursively compute b*a_hi into high part of c */
KarMul(c + hsa, a + hsa, sa - hsa, b, sb, stk);
/* recursively compute b*a_lo into T */
KarMul(T, a, hsa, b, sb, stk);
// c[0..hsa) = T[0..hsa); c[hsa..) += T[hsa..)
KarFix(c, T, hsa + sb - 1, hsa);
}
}
/*
 * Polynomial multiplication c = a*b using Karatsuba.
 *
 * - Squaring (&a == &b) is forwarded to KarSqr.
 * - c may be the same object as a or b; that input is copied first.
 * - Operands shorter than the crossover length use PlainMul.
 * - Otherwise a scratch ZZVec is sized for the recursion.
 *   - Its entry count comes from summing 4*ceil(n/2)-1 per halving level.
 *   - Its per-entry bit size comes from the operands' maximum
 *     coefficient sizes.
 */
void KarMul(ZZX& c, const ZZX& a, const ZZX& b)
{
   if (IsZero(a) || IsZero(b)) {
      clear(c);
      return;
   }

   if (&a == &b) {
      KarSqr(c, a);
      return;
   }

   vec_ZZ mem;

   const ZZ *ap, *bp;
   ZZ *cp;

   long sa = a.rep.length();
   long sb = b.rep.length();

   // Measure coefficient sizes BEFORE c is resized.  If c aliases a or b,
   // c.rep.SetLength below also lengthens that input, and MaxBits would
   // then scan the newly exposed slots too (whatever values they hold),
   // inflating the scratch size computed from it.
   long maxa = MaxBits(a);
   long maxb = MaxBits(b);

   // If the output aliases an input, read that input from a private copy
   // so resizing/writing c does not clobber it.  (a and b cannot both
   // alias c here, since &a == &b was handled above.)
   if (&a == &c) {
      mem = a.rep;
      ap = mem.elts();
   }
   else
      ap = a.rep.elts();

   if (&b == &c) {
      mem = b.rep;
      bp = mem.elts();
   }
   else
      bp = b.rep.elts();

   c.rep.SetLength(sa+sb-1);
   cp = c.rep.elts();

   long xover = 2;

   if (sa < xover || sb < xover)
      PlainMul(cp, ap, sa, bp, sb);
   else {
      /* karatsuba */

      // Total scratch entries: each recursion level on length n consumes
      // 4*ceil(n/2)-1 entries; depth counts the halvings.
      long n, hn, sp, depth;

      n = max(sa, sb);
      sp = 0;
      depth = 0;
      do {
         hn = (n+1) >> 1;
         sp += (hn << 2) - 1;
         n = hn;
         depth++;
      } while (n >= xover);

      ZZVec stk;
      stk.SetSize(sp,
         ((maxa + maxb + NumBits(min(sa, sb)) + 2*depth + 10)
          + NTL_ZZ_NBITS-1)/NTL_ZZ_NBITS);

      KarMul(cp, ap, sa, bp, sb, stk.elts());
   }

   c.normalize();
}
/*
 * Schoolbook squaring of a coefficient array.
 *
 * Writes the 2*sa-1 coefficients of a^2 into xp, where ap has sa entries.
 * Does nothing if sa is 0.  Symmetric cross terms are computed once and
 * doubled; for an odd number of pairs the middle square is added once.
 * xp presumably must not overlap ap, since xp[i] is written while later
 * iterations still read ap.
 */
void PlainSqr(ZZ* xp, const ZZ* ap, long sa)
{
   if (sa == 0) return;

   long da = sa-1;
   long d = 2*da;

   long i, j, jmin, jmax;
   long m, m2;

   // Scratch values are ordinary locals rather than function-scope
   // statics, so reentrant or concurrent callers do not share (and race
   // on) the same temporaries.
   ZZ t, accum;

   for (i = 0; i <= d; i++) {
      jmin = max(0, i-da);
      jmax = min(da, i);
      m = jmax - jmin + 1;   // number of (j, i-j) pairs
      m2 = m >> 1;           // distinct off-diagonal pairs
      jmax = jmin + m2 - 1;
      clear(accum);
      for (j = jmin; j <= jmax; j++) {
         mul(t, ap[j], ap[i-j]);
         add(accum, accum, t);
      }
      add(accum, accum, accum);
      if (m & 1) {
         sqr(t, ap[jmax + 1]);
         add(accum, accum, t);
      }

      xp[i] = accum;
   }
}
/*
 * Recursive Karatsuba squaring of a coefficient array.
 *
 * Writes the 2*sa-1 coefficients of a^2 into c.
 *
 * a   : input coefficient array of length sa.
 * c   : output, must have room for 2*sa-1 entries.  c must not overlap a:
 *       e.g. in the sa == 2 case c[0] is written before a[0] is read again.
 * stk : scratch ZZ array.
 *       - The recursive case takes hsa + (2*hsa-1) = 3*hsa-1 entries
 *         (T1, T2) and passes the rest to its recursive calls.
 *       - The sa == 3 base case uses stk[0].
 */
void KarSqr(ZZ *c, const ZZ *a, long sa, ZZ *stk)
{
// base case: single coefficient
if (sa == 1) {
sqr(*c, *a);
return;
}
// base case: (a0 + a1 X)^2 = a0^2 + 2 a0 a1 X + a1^2 X^2
if (sa == 2) {
sqr(c[0], a[0]);
sqr(c[2], a[1]);
mul(c[1], a[0], a[1]);
add(c[1], c[1], c[1]);
return;
}
// base case: (a0 + a1 X + a2 X^2)^2, expanded directly;
// c2 = 2 a0 a2 + a1^2 uses stk[0] for a1^2
if (sa == 3) {
sqr(c[0], a[0]);
mul(c[1], a[0], a[1]);
add(c[1], c[1], c[1]);
sqr(stk[0], a[1]);
mul(c[2], a[0], a[2]);
add(c[2], c[2], c[2]);
add(c[2], c[2], stk[0]);
mul(c[3], a[1], a[2]);
add(c[3], c[3], c[3]);
sqr(c[4], a[2]);
return;
}
// split point: hsa = ceil(sa/2)
long hsa = (sa + 1) >> 1;
long hsa2 = hsa << 1;
ZZ *T1, *T2;
// carve T1 (hsa) and T2 (2*hsa-1) off the front of stk
T1 = stk; stk += hsa;
T2 = stk; stk += hsa2-1;
// T1 = a_lo + a_hi
KarFold(T1, a, sa, hsa);
// T2 = (a_lo + a_hi)^2
KarSqr(T2, T1, hsa, stk);
// a_hi^2 into the high part of c, then T2 -= a_hi^2
KarSqr(c + hsa2, a+hsa, sa-hsa, stk);
KarSub(T2, c + hsa2, sa + sa - hsa2 - 1);
// a_lo^2 into the low part of c, then T2 -= a_lo^2  (T2 = 2 a_lo a_hi)
KarSqr(c, a, hsa, stk);
KarSub(T2, c, hsa2 - 1);
// c[hsa2-1] is not written by either square above, so zero it
clear(c[hsa2 - 1]);
// c += T2 * X^hsa
KarAdd(c+hsa, T2, hsa2-1);
}
/*
 * Polynomial squaring c = a^2 using Karatsuba.
 *
 * - c may be the same object as a; the input is copied first in that case.
 * - Inputs shorter than the crossover length use PlainSqr.
 * - Otherwise a scratch ZZVec is sized for the recursion.
 *   - Its entry count comes from summing 3*ceil(n/2)-1 per halving level.
 *   - Its per-entry bit size comes from the maximum coefficient size of a.
 */
void KarSqr(ZZX& c, const ZZX& a)
{
   if (IsZero(a)) {
      clear(c);
      return;
   }

   vec_ZZ mem;

   const ZZ *ap;
   ZZ *cp;

   long sa = a.rep.length();

   // Measure coefficient sizes BEFORE c is resized.  If c aliases a,
   // c.rep.SetLength below also lengthens a, and MaxBits would then scan
   // the newly exposed slots too (whatever values they hold), inflating
   // the scratch size computed from it.
   long maxa = MaxBits(a);

   // If the output aliases the input, read from a private copy so that
   // resizing/writing c does not clobber it.
   if (&a == &c) {
      mem = a.rep;
      ap = mem.elts();
   }
   else
      ap = a.rep.elts();

   c.rep.SetLength(sa+sa-1);
   cp = c.rep.elts();

   long xover = 2;

   if (sa < xover)
      PlainSqr(cp, ap, sa);
   else {
      /* karatsuba */

      // Total scratch entries: each recursion level on length n consumes
      // 3*ceil(n/2)-1 entries; depth counts the halvings.
      long n, hn, sp, depth;

      n = sa;
      sp = 0;
      depth = 0;
      do {
         hn = (n+1) >> 1;
         sp += hn+hn+hn - 1;
         n = hn;
         depth++;
      } while (n >= xover);

      ZZVec stk;
      stk.SetSize(sp,
         ((2*maxa + NumBits(sa) + 2*depth + 10)
          + NTL_ZZ_NBITS-1)/NTL_ZZ_NBITS);

      KarSqr(cp, ap, sa, stk.elts());
   }

   c.normalize();
}
NTL_END_IMPL
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -