📄 lzz_px.cpp
字号:
}
else {
/* karatsuba */
long n, hn, sp;
n = max(sa, sb);
sp = 0;
do {
hn = (n+1) >> 1;
sp += (hn << 2) - 1;
n = hn;
} while (n >= KARX);
vec_zz_p stk;
stk.SetLength(sp);
if (use_FP) {
a_buf.SetLength(max(sa, sb));
b_buf.SetLength(max(sa, sb));
KarMul_FP(cp, ap, sa, bp, sb, stk.elts());
}
else
KarMul(cp, ap, sa, bp, sb, stk.elts());
}
c.normalize();
}
// Schoolbook squaring with double-precision accumulation: writes the
// 2*sa-1 coefficients of (sum_k aap[k] X^k)^2 into xp[0 .. 2*sa-2].
// For each output index only the lower half of the symmetric convolution
// is summed, then doubled, plus the middle square when the number of
// contributing terms is odd.
// The inputs are first converted into the file-level scratch buffer a_buf,
// which the caller must already have sized to at least sa entries.
// Each accumulated sum is handed to reduce() (defined elsewhere; presumably
// stores the sum mod p into the output). Exactness relies on the caller
// only taking this path when p^2 * KARSX fits in double precision.
void PlainSqr_FP(zz_p *xp, const zz_p *aap, long sa)
{
   if (sa == 0) return;

   const long da = sa - 1;
   const long p = zz_p::modulus();
   const double pinv = zz_p::ModulusInverse();

   double *ap = a_buf.elts();
   for (long k = 0; k < sa; k++)
      ap[k] = double(rep(aap[k]));

   for (long k = 0; k <= 2*da; k++) {
      // Indices j with ap[j]*ap[k-j] a valid product run over [lo, hi].
      const long lo = max(0, k - da);
      const long hi = min(da, k);
      const long cnt = hi - lo + 1;
      const long half = cnt >> 1;

      double acc = 0;
      for (long j = lo; j < lo + half; j++)
         acc += ap[j]*ap[k-j];
      acc += acc;   // each off-diagonal product appears twice

      if (cnt & 1) {
         // Odd count: the middle term is a square and is not doubled.
         const long mid = lo + half;
         acc += ap[mid]*ap[mid];
      }

      reduce(xp[k], acc, p, pinv);
   }
}
// Schoolbook squaring on raw coefficient arrays: writes the 2*sa-1
// coefficients of (sum_k ap[k] X^k)^2 into xp[0 .. 2*sa-2], using
// single-precision modular arithmetic (MulMod / MulMod2 / AddMod) mod p.
// xp is written while ap is still being read, so the two arrays should not
// overlap. Does nothing when sa == 0.
void PlainSqr(zz_p *xp, const zz_p *ap, long sa)
{
if (sa == 0) return;
long i, j, k, cnt;
cnt = 2*sa-1;
// Zero the output; cross products are accumulated into it below.
for (i = 0; i < cnt; i++)
clear(xp[i]);
long p = zz_p::modulus();
double pinv = zz_p::ModulusInverse();
long t1, t2;
// i tracks 2*j+1, the odd output index belonging to row j.
i = -1;
for (j = 0; j <= sa-2; j++) {
i += 2;
// Finalize even index 2*j (= i-1): every cross product ap[l]*ap[2*j-l]
// with l < j was accumulated by earlier rows, so double it and add ap[j]^2.
t1 = MulMod(rep(ap[j]), rep(ap[j]), p, pinv);
t2 = rep(xp[i-1]);
t2 = AddMod(t2, t2, p);
t2 = AddMod(t2, t1, p);
xp[i-1].LoopHole() = t2;
// Accumulate the cross products ap[j]*ap[j+1+k] into xp[2*j+1+k]
// for k = 0 .. sa-2-j (ap1 = ap+j+1, xp1 = xp+2*j+1).
cnt = sa - 1 - j;
const zz_p *ap1 = ap+(j+1);
zz_p *xp1 = xp+i;
t1 = rep(ap[j]);
// t1*pinv is computed once here and passed to MulMod2 for every term.
double tpinv = ((double) t1)*pinv;
for (k = 0; k < cnt; k++) {
t2 = MulMod2(rep(ap1[k]), t1, p, tpinv);
t2 = AddMod(t2, rep(xp1[k]), p);
xp1[k].LoopHole() = t2;
}
// Odd index 2*j+1 has now received all of its cross products (its last
// contribution came from this row), so double it in place.
t2 = rep(*xp1);
t2 = AddMod(t2, t2, p);
(*xp1).LoopHole() = t2;
}
// The top coefficient has no cross terms: it is just ap[sa-1]^2.
t1 = rep(ap[sa-1]);
t1 = MulMod(t1, t1, p, pinv);
xp[2*sa-2].LoopHole() = t1;
}
// Length threshold for squaring: below KARSX coefficients the Karatsuba
// squaring routines fall back to schoolbook squaring. PlainSqr(zz_pX&, ...)
// also divides its precision bounds by KARSX when deciding whether the
// floating-point path may be used.
#define KARSX (30)
// Karatsuba squaring on raw coefficient arrays: writes the 2*sa-1
// coefficients of a^2 into c (a has sa coefficients).
// With a = a0 + X^hsa * a1 (a0 = low hsa coefficients, a1 = remaining
// sa-hsa), it computes a0^2, a1^2 and a middle term
// T2 = T1^2 - a1^2 - a0^2, where T1 comes from KarFold (defined elsewhere;
// presumably T1 = a0 + a1, giving T2 = 2*a0*a1), then adds T2 into c at
// offset hsa.
// stk is scratch space: this level takes hsa entries for T1 and 2*hsa-1 for
// T2 and passes the rest down to the recursive calls. The caller must supply
// enough; PlainSqr(zz_pX&, ...) sizes it.
void KarSqr(zz_p *c, const zz_p *a, long sa, zz_p *stk)
{
if (sa < KARSX) {
PlainSqr(c, a, sa);
return;
}
long hsa = (sa + 1) >> 1;
long hsa2 = hsa << 1;
zz_p *T1, *T2;
T1 = stk; stk += hsa;
T2 = stk; stk += hsa2-1;
// T2 = (folded a)^2
KarFold(T1, a, sa, hsa);
KarSqr(T2, T1, hsa, stk);
// a1^2 is written directly into c[hsa2 ..] and removed from T2
// (KarSub, defined elsewhere, presumably subtracts coefficient-wise).
KarSqr(c + hsa2, a+hsa, sa-hsa, stk);
KarSub(T2, c + hsa2, sa + sa - hsa2 - 1);
// a0^2 occupies c[0 .. hsa2-2] and is likewise removed from T2.
KarSqr(c, a, hsa, stk);
KarSub(T2, c, hsa2 - 1);
// c[hsa2-1] lies between the two squares and was written by neither.
clear(c[hsa2 - 1]);
// Add the middle term at offset hsa.
KarAdd(c+hsa, T2, hsa2-1);
}
// Karatsuba squaring, floating-point variant: the same recursion as KarSqr,
// but the base case (sa < KARSX) uses PlainSqr_FP. PlainSqr_FP reads its
// converted inputs through the file-level buffer a_buf, which the top-level
// caller sizes to the full input length before calling in.
// Writes the 2*sa-1 coefficients of a^2 into c. stk is scratch space, of
// which this level reserves hsa (T1) + 2*hsa-1 (T2) entries and passes the
// remainder to the recursive calls.
void KarSqr_FP(zz_p *c, const zz_p *a, long sa, zz_p *stk)
{
if (sa < KARSX) {
PlainSqr_FP(c, a, sa);
return;
}
long hsa = (sa + 1) >> 1;
long hsa2 = hsa << 1;
zz_p *T1, *T2;
T1 = stk; stk += hsa;
T2 = stk; stk += hsa2-1;
// T2 = (folded a)^2; KarFold is defined elsewhere (presumably low + high half).
KarFold(T1, a, sa, hsa);
KarSqr_FP(T2, T1, hsa, stk);
// High-half square goes straight into c[hsa2 ..] and is removed from T2.
KarSqr_FP(c + hsa2, a+hsa, sa-hsa, stk);
KarSub(T2, c + hsa2, sa + sa - hsa2 - 1);
// Low-half square occupies c[0 .. hsa2-2] and is removed from T2.
KarSqr_FP(c, a, hsa, stk);
KarSub(T2, c, hsa2 - 1);
// c[hsa2-1] sits between the two squares and was written by neither.
clear(c[hsa2 - 1]);
// Add the middle term at offset hsa.
KarAdd(c+hsa, T2, hsa2-1);
}
// c = a^2. Dispatches to schoolbook squaring for short inputs and to
// Karatsuba squaring otherwise. Each of those has a floating-point flavour,
// used when p is small enough that sums of up to KARSX products of
// residues stay exact in double precision.
// c may alias a: in that case the input coefficients are copied first,
// because resizing c would otherwise clobber them.
void PlainSqr(zz_pX& c, const zz_pX& a)
{
   if (IsZero(a)) {
      clear(c);
      return;
   }

   const long sa = a.rep.length();

   vec_zz_p mem;
   const zz_p *ap = a.rep.elts();
   if (&a == &c) {
      mem = a.rep;
      ap = mem.elts();
   }

   c.rep.SetLength(2*sa - 1);
   zz_p *cp = c.rep.elts();

   const long p = zz_p::modulus();
   const bool use_FP = (p < NTL_SP_BOUND/KARSX) &&
                       (double(p)*double(p) < NTL_FDOUBLE_PRECISION/KARSX);

   if (sa < KARSX) {
      if (!use_FP) {
         PlainSqr(cp, ap, sa);
      }
      else {
         a_buf.SetLength(sa);
         PlainSqr_FP(cp, ap, sa);
      }
   }
   else {
      // Scratch needed by the Karatsuba recursion: 3*h - 1 entries per level,
      // where h = ceil(n/2), until the length drops below KARSX.
      // (sa >= KARSX here, so at least one level is always counted.)
      long sp = 0;
      for (long n = sa; n >= KARSX; n = (n + 1) >> 1)
         sp += 3*((n + 1) >> 1) - 1;

      vec_zz_p stk;
      stk.SetLength(sp);

      if (!use_FP) {
         KarSqr(cp, ap, sa, stk.elts());
      }
      else {
         a_buf.SetLength(sa);
         KarSqr_FP(cp, ap, sa, stk.elts());
      }
   }

   c.normalize();
}
// Classical (schoolbook) division with remainder: a = q*b + r, deg(r) < deg(b).
// Raises Error("zz_pX: division by zero") if b == 0. If deg(a) < deg(b),
// q = 0 and r = a.
// Aliasing: q may alias b, since b is then read from a private copy. r may
// alias a, in which case the reduction runs in place on r's coefficients.
void PlainDivRem(zz_pX& q, zz_pX& r, const zz_pX& a, const zz_pX& b)
{
   const long da = deg(a);
   const long db = deg(b);

   if (db < 0) Error("zz_pX: division by zero");

   if (da < db) {
      r = a;
      clear(q);
      return;
   }

   // q is resized and written while b is still being read, so if they
   // alias, b's coefficients must come from a copy.
   zz_pX lb;
   const zz_p *bp;
   if (&q == &b) {
      lb = b;
      bp = lb.rep.elts();
   }
   else
      bp = b.rep.elts();

   // Avoid the extra multiplication per step when b is monic.
   const bool LCIsOne = IsOne(bp[db]);
   zz_p LCInv;
   if (!LCIsOne)
      inv(LCInv, bp[db]);

   // Working copy of a's coefficients (or a itself when r aliases a),
   // reduced in place; its low db entries end up holding the remainder.
   vec_zz_p x;
   zz_p *xp;
   if (&r == &a)
      xp = r.rep.elts();
   else {
      x = a.rep;
      xp = x.elts();
   }

   const long dq = da - db;
   q.rep.SetLength(dq+1);
   zz_p *qp = q.rep.elts();

   const long p = zz_p::modulus();
   const double pinv = zz_p::ModulusInverse();

   for (long i = dq; i >= 0; i--) {
      zz_p t = xp[i+db];
      if (!LCIsOne)
         mul(t, t, LCInv);
      qp[i] = t;

      // Subtract t * X^i * b by adding (-t) * b[j] to each lower coefficient.
      negate(t, t);
      const long T = rep(t);
      const double Tpinv = ((double) T)*pinv;   // precomputed for MulMod2
      for (long j = db-1; j >= 0; j--) {
         long S = MulMod2(rep(bp[j]), T, p, Tpinv);
         S = AddMod(S, rep(xp[i+j]), p);
         xp[i+j].LoopHole() = S;
      }
   }

   r.rep.SetLength(db);
   if (&r != &a) {
      for (long i = 0; i < db; i++)
         r.rep[i] = xp[i];
   }
   r.normalize();
}
// Classical (schoolbook) division returning only the quotient q = a / b.
// Raises Error("zz_pX: division by zero") if b == 0; q = 0 if deg(a) < deg(b).
// q may alias a (a is copied before q is resized) or b (b is then read
// from a private copy).
void PlainDiv(zz_pX& q, const zz_pX& a, const zz_pX& b)
{
   const long da = deg(a);
   const long db = deg(b);

   if (db < 0) Error("zz_pX: division by zero");

   if (da < db) {
      clear(q);
      return;
   }

   // q is resized and written while b is still being read.
   zz_pX lb;
   const zz_p *bp;
   if (&q == &b) {
      lb = b;
      bp = lb.rep.elts();
   }
   else
      bp = b.rep.elts();

   const bool LCIsOne = IsOne(bp[db]);
   zz_p LCInv;
   if (!LCIsOne)
      inv(LCInv, bp[db]);

   // Only the top da-db+1 coefficients of a can influence the quotient, so
   // the working buffer holds just a[db .. da]; xp[k] stands for a[k+db].
   vec_zz_p x;
   x.SetLength(da+1-db);
   for (long i = db; i <= da; i++)
      x[i-db] = a.rep[i];
   zz_p *xp = x.elts();

   const long dq = da - db;
   q.rep.SetLength(dq+1);
   zz_p *qp = q.rep.elts();

   const long p = zz_p::modulus();
   const double pinv = zz_p::ModulusInverse();

   for (long i = dq; i >= 0; i--) {
      zz_p t = xp[i];
      if (!LCIsOne)
         mul(t, t, LCInv);
      qp[i] = t;

      negate(t, t);
      const long T = rep(t);
      const double Tpinv = ((double) T)*pinv;   // precomputed for MulMod2

      // Updates to coefficients below index db only affect the remainder,
      // which is not kept, so they are skipped (j >= db-i).
      const long lastj = max(0, db-i);
      for (long j = db-1; j >= lastj; j--) {
         long S = MulMod2(rep(bp[j]), T, p, Tpinv);
         S = AddMod(S, rep(xp[i+j-db]), p);
         xp[i+j-db].LoopHole() = S;
      }
   }
}
// Classical (schoolbook) reduction: r = a mod b, deg(r) < deg(b).
// Raises Error("zz_pX: division by zero") if b == 0; r = a if deg(a) < deg(b).
// r may alias a, in which case the reduction runs in place on r's
// coefficients.
void PlainRem(zz_pX& r, const zz_pX& a, const zz_pX& b)
{
   const long da = deg(a);
   const long db = deg(b);

   if (db < 0) Error("zz_pX: division by zero");

   if (da < db) {
      r = a;
      return;
   }

   const zz_p *bp = b.rep.elts();

   const bool LCIsOne = IsOne(bp[db]);
   zz_p LCInv;
   if (!LCIsOne)
      inv(LCInv, bp[db]);

   // Working copy of a (or a itself when r aliases a), reduced in place.
   vec_zz_p x;
   zz_p *xp;
   if (&r == &a)
      xp = r.rep.elts();
   else {
      x = a.rep;
      xp = x.elts();
   }

   const long dq = da - db;
   const long p = zz_p::modulus();
   const double pinv = zz_p::ModulusInverse();

   for (long i = dq; i >= 0; i--) {
      // Quotient coefficient for X^i; only needed to eliminate xp[i+db].
      zz_p t = xp[i+db];
      if (!LCIsOne)
         mul(t, t, LCInv);
      negate(t, t);
      const long T = rep(t);
      const double Tpinv = ((double) T)*pinv;   // precomputed for MulMod2
      for (long j = db-1; j >= 0; j--) {
         long S = MulMod2(rep(bp[j]), T, p, Tpinv);
         S = AddMod(S, rep(xp[i+j]), p);
         xp[i+j].LoopHole() = S;
      }
   }

   r.rep.SetLength(db);
   if (&r != &a) {
      for (long i = 0; i < db; i++)
         r.rep[i] = xp[i];
   }
   r.normalize();
}
// Scalar multiple: x = a * b for b in zz_p. Shortcuts b == 0 (x = 0) and
// b == 1 (x = a). x may alias a: the length is unchanged, so the coefficient
// pointers are taken after the resize and the update is done in place.
void mul(zz_pX& x, const zz_pX& a, zz_p b)
{
   if (IsZero(b)) {
      clear(x);
      return;
   }
   if (IsOne(b)) {
      x = a;
      return;
   }

   const long bval = rep(b);
   const long p = zz_p::modulus();
   const double pinv = zz_p::ModulusInverse();
   const double bpinv = bval*pinv;   // precomputed for MulMod2

   const long n = deg(a) + 1;
   x.rep.SetLength(n);
   const zz_p *src = a.rep.elts();
   zz_p *dst = x.rep.elts();

   for (long k = 0; k < n; k++)
      dst[k].LoopHole() = MulMod2(rep(src[k]), bval, p, bpinv);

   x.normalize();
}
// Classical Euclidean gcd using PlainRem. The result is made monic unless
// it is zero (which happens only when both a and b are zero).
void PlainGCD(zz_pX& x, const zz_pX& a, const zz_pX& b)
{
   if (IsZero(b)) {
      x = a;
   }
   else if (IsZero(a)) {
      x = b;
   }
   else {
      const long sz = max(deg(a), deg(b)) + 1;
      zz_pX r0(INIT_SIZE, sz), r1(INIT_SIZE, sz);
      r0 = a;
      r1 = b;

      // Invariant: gcd(r0, r1) == gcd(a, b); r1 != 0 on entry.
      for (;;) {
         PlainRem(r0, r0, r1);
         swap(r0, r1);
         if (IsZero(r1)) break;
      }

      x = r0;
   }

   if (IsZero(x) || IsOne(LeadCoeff(x))) return;

   // Normalize to a monic gcd.
   zz_p lcinv;
   inv(lcinv, LeadCoeff(x));
   mul(x, x, lcinv);
}
// Classical extended Euclid: computes d = gcd(a, b) together with cofactors
// s, t such that d = s*a + t*b, then scales all three so d is monic (unless
// d == 0, i.e. a == b == 0).
// s or t may alias a or b. In the trivial cases (one input zero), the
// surviving input is copied into d *before* s and t are written. In the
// general case the inputs are copied into locals before any output is
// assigned.
void PlainXGCD(zz_pX& d, zz_pX& s, zz_pX& t, const zz_pX& a, const zz_pX& b)
{
   if (IsZero(b)) {
      // Assign d first: if s or t aliases a, writing them first would
      // overwrite a before it is copied into d.
      d = a;
      set(s);
      clear(t);
   }
   else if (IsZero(a)) {
      // Same ordering concern with respect to b.
      d = b;
      clear(s);
      set(t);
   }
   else {
      long e = max(deg(a), deg(b)) + 1;

      zz_pX temp(INIT_SIZE, e), u(INIT_SIZE, e), v(INIT_SIZE, e), u0(INIT_SIZE, e), v0(INIT_SIZE, e),
            u1(INIT_SIZE, e), v1(INIT_SIZE, e), u2(INIT_SIZE, e), v2(INIT_SIZE, e), q(INIT_SIZE, e);

      // (u1, v1) and (u2, v2) are the cofactors of the current remainders u, v.
      set(u1); clear(v1);
      clear(u2); set(v2);
      u = a; v = b;

      do {
         DivRem(q, u, u, v);
         swap(u, v);
         u0 = u2;
         v0 = v2;
         mul(temp, q, u2);
         sub(u2, u1, temp);
         mul(temp, q, v2);
         sub(v2, v1, temp);
         u1 = u0;
         v1 = v0;
      } while (!IsZero(v));

      d = u;
      s = u1;
      t = v1;
   }

   if (IsZero(d)) return;
   if (IsOne(LeadCoeff(d))) return;

   // Make the gcd monic and scale the cofactors to keep d = s*a + t*b.
   zz_p z;
   inv(z, LeadCoeff(d));
   mul(d, d, z);
   mul(s, s, z);
   mul(t, t, z);
}
// x = (a * b) mod f. Requires deg(f) > 0 and deg(a), deg(b) < deg(f);
// otherwise raises Error("MulMod: bad args").
void MulMod(zz_pX& x, const zz_pX& a, const zz_pX& b, const zz_pX& f)
{
   const long n = deg(f);
   if (n == 0 || deg(a) >= n || deg(b) >= n)
      Error("MulMod: bad args");

   zz_pX prod;
   mul(prod, a, b);
   rem(x, prod, f);
}
// x = a^2 mod f. Requires deg(f) > 0 and deg(a) < deg(f); otherwise raises
// Error("SqrMod: bad args").
void SqrMod(zz_pX& x, const zz_pX& a, const zz_pX& f)
{
   const long n = deg(f);
   if (n == 0 || deg(a) >= n)
      Error("SqrMod: bad args");

   zz_pX square;
   sqr(square, a);
   rem(x, square, f);
}
// x = a^{-1} mod f, via the extended gcd of a and f.
// Requires deg(f) > 0 and deg(a) < deg(f) (else Error("InvMod: bad args")).
// Raises an Error if gcd(a, f) != 1, i.e. a is not invertible modulo f.
void InvMod(zz_pX& x, const zz_pX& a, const zz_pX& f)
{
   const long n = deg(f);
   if (n == 0 || deg(a) >= n)
      Error("InvMod: bad args");

   zz_pX g, cofactor_f;
   XGCD(g, x, cofactor_f, a, f);

   if (!IsOne(g))
      Error("zz_pX InvMod: can't compute multiplicative inverse");
}
// Non-throwing variant of InvMod. Returns 0 and sets x = a^{-1} mod f when
// gcd(a, f) == 1; otherwise returns 1 and sets x to that (nontrivial) gcd.
// Argument validation is the same as InvMod
// (Error("InvModStatus: bad args")).
long InvModStatus(zz_pX& x, const zz_pX& a, const zz_pX& f)
{
   const long n = deg(f);
   if (n == 0 || deg(a) >= n)
      Error("InvModStatus: bad args");

   zz_pX g, cofactor_f;
   XGCD(g, x, cofactor_f, a, f);

   if (IsOne(g))
      return 0;

   x = g;
   return 1;
}
static
void MulByXModAux(zz_pX& h, const zz_pX& a, const zz_pX& f)
{
long i, n, m;
zz_p* hh;
const zz_p *aa, *ff;
zz_p t, z;
n = deg(f);
m = deg(a);
if (m >= n || n == 0) Error("MulByXMod: bad args");
if (m < 0) {
clear(h);
return;
}
if (m < n-1) {
h.rep.SetLength(m+2);
hh = h.rep.elts();
aa = a.rep.elts();
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -