📄 zzx1.cpp
字号:
ZZ prod;
set(prod);
zz_pBak bak;
bak.save();
long FirstTime = 1;
long i;
for (i = 0; ;i++) {
zz_p::FFTInit(i);
long p = zz_p::modulus();
if (divide(LeadCoeff(f1), p) || divide(LeadCoeff(f2), p)) continue;
zz_pX G, F1, F2;
zz_p LD;
conv(F1, f1);
conv(F2, f2);
conv(LD, ld);
GCD(G, F1, F2);
mul(G, G, LD);
if (deg(G) == 0) {
set(res);
break;
}
if (FirstTime || deg(G) < deg(g)) {
FirstTime = 0;
conv(prod, p);
BalCopy(g, G);
}
else if (deg(G) > deg(g))
continue;
else if (!CRT(g, prod, G)) {
PrimitivePart(res, g);
if (divide(f1, res) && divide(f2, res))
break;
}
}
bak.restore();
mul(d, res, c);
if (sign(LeadCoeff(d)) < 0) negate(d, d);
}
void trunc(ZZX& x, const ZZX& a, long m)
// x = a % X^m: keep only the coefficients of X^0 .. X^{m-1}.
// Output may alias input.  A negative m is rejected.
{
   if (m < 0) Error("trunc: bad args");

   if (&x != &a) {
      // Distinct objects: copy the low-order min(len(a), m) coefficients.
      const long len = min(a.rep.length(), m);
      x.rep.SetLength(len);
      ZZ* dst = x.rep.elts();
      const ZZ* src = a.rep.elts();
      for (long j = 0; j < len; j++)
         dst[j] = src[j];
      x.normalize();
      return;
   }

   // In place: only shrink when there are coefficients at or above X^m.
   if (x.rep.length() > m) {
      x.rep.SetLength(m);
      x.normalize();
   }
}
void LeftShift(ZZX& x, const ZZX& a, long n)
// x = a * X^n; a negative shift is forwarded to RightShift(x, a, -n).
// Output may alias input.
{
   if (n < 0) {
      // -n would overflow for n < -NTL_MAX_LONG.
      if (n < -NTL_MAX_LONG) Error("overflow in LeftShift");
      RightShift(x, a, -n);
      return;
   }

   if (n >= (1L << (NTL_BITS_PER_LONG-4)))
      Error("overflow in LeftShift");

   if (IsZero(a)) {
      clear(x);
      return;
   }

   const long len = a.rep.length();
   x.rep.SetLength(len + n);

   // Walk from the highest coefficient downwards so that, if x and a are the
   // same object, no source coefficient is overwritten before it is read.
   long src = len;
   while (src-- > 0)
      x.rep[src + n] = a.rep[src];

   // The n lowest coefficients of the result are zero.
   for (long dst = 0; dst < n; dst++)
      clear(x.rep[dst]);
}
void RightShift(ZZX& x, const ZZX& a, long n)
// x = a / X^n, i.e. drop the n lowest coefficients of a.  A negative shift
// is forwarded to LeftShift(x, a, -n).  Output may alias input.
{
   if (n < 0) {
      // -n would overflow for n < -NTL_MAX_LONG.
      if (n < -NTL_MAX_LONG) Error("overflow in RightShift");
      LeftShift(x, a, -n);
      return;
   }

   const long da = deg(a);
   if (n > da) {   // every coefficient is shifted out (also covers a == 0)
      clear(x);
      return;
   }

   const long newlen = da - n + 1;
   const bool inplace = (&x == &a);

   // Distinct output: size it up front.  Aliased output: shrink only after
   // the copy, so the high source coefficients are still present when read.
   if (!inplace) x.rep.SetLength(newlen);
   for (long j = 0; j < newlen; j++)
      x.rep[j] = a.rep[j + n];
   if (inplace) x.rep.SetLength(newlen);

   x.normalize();
}
void TraceVec(vec_ZZ& S, const ZZX& ff)
// For monic ff of degree n, fills S (length n) with S[0] = n and
// S[1..n-1] from Newton's identities:
//    S[k] = -( k*f_{n-k} + sum_{j=1}^{k-1} f_{n-j} * S[k-j] ).
// Raises an error if ff is not monic.
{
   if (!IsOne(LeadCoeff(ff)))
      Error("TraceVec: bad args");

   // Work on a private copy: S may be ff.rep itself, and resizing S below
   // would otherwise clobber the coefficients being read.
   ZZX f;
   f = ff;

   const long n = deg(f);
   S.SetLength(n);
   if (n == 0)
      return;

   S[0] = n;

   ZZ sum, term;
   for (long k = 1; k < n; k++) {
      mul(sum, f.rep[n-k], k);
      for (long j = 1; j < k; j++) {
         mul(term, f.rep[n-j], S[k-j]);
         add(sum, sum, term);
      }
      negate(S[k], sum);
   }
}
static
void EuclLength(ZZ& l, const ZZX& a)
// l = integer upper bound on the Euclidean (L2) norm of a's coefficient
// vector: floor(sqrt(sum a_i^2)) + 1 when that sum exceeds 1, otherwise
// the sum itself (0 or 1, whose square roots are exact).
{
   ZZ sumsq, sq;
   clear(sumsq);
   const long len = a.rep.length();
   for (long j = 0; j < len; j++) {
      sqr(sq, a.rep[j]);
      add(sumsq, sumsq, sq);
   }

   if (sumsq <= 1) {
      l = sumsq;
      return;
   }

   SqrRoot(l, sumsq);
   add(l, l, 1);   // round the truncated square root up
}
static
long ResBound(const ZZX& a, const ZZX& b)
{
if (IsZero(a) || IsZero(b))
return 0;
ZZ t1, t2, t;
EuclLength(t1, a);
EuclLength(t2, b);
power(t1, t1, deg(b));
power(t2, t2, deg(a));
mul(t, t1, t2);
return NumBits(t);
}
// resultant: rres = resultant of the integer polynomials a and b.
//
// Multi-modular algorithm: the resultant is computed modulo a sequence of
// small FFT primes (zz_p::FFTInit(i)) and the residues are combined by CRT
// into res (modulo prod).  The loop stops once prod has more than `bound`
// bits, where bound = 2 + ResBound(a, b).
//
// If deterministic == 0, a probabilistic early exit is also tried.  It runs
// once the CRT value has stopped changing, the bound exceeds 1000 bits, and
// prod covers less than a quarter of it.  The candidate is then checked
// modulo one large random prime P; if that residue leaves res unchanged,
// the loop stops early.
//
// If either input is zero, rres = 0.  The caller's zz_p and ZZ_p moduli are
// saved on entry and restored before returning.
void resultant(ZZ& rres, const ZZX& a, const ZZX& b, long deterministic)
{
if (IsZero(a) || IsZero(b)) {
clear(rres);
return;
}
zz_pBak zbak;
zbak.save();
ZZ_pBak Zbak;
Zbak.save();
// instable != 0 means the most recent CRT step changed res
// (NTL's CRT returns 1 when the combined value changed, 0 otherwise).
long instable = 1;
long bound = 2+ResBound(a, b);
// Incremented per random prime; feeds GenPrime's third argument
// (presumably its error-probability exponent).
long gp_cnt = 0;
ZZ res, prod;
clear(res);
set(prod);
long i;
for (i = 0; ; i++) {
// Enough primes combined: res is determined by the bound.
if (NumBits(prod) > bound)
break;
// Optional probabilistic shortcut: verify the stable CRT value modulo a
// single large random prime instead of accumulating many small ones.
if (!deterministic &&
!instable && bound > 1000 && NumBits(prod) < 0.25*bound) {
ZZ P;
long plen = 90 + NumBits(max(bound, NumBits(res)));
// Reject primes dividing a leading coefficient (degree would drop mod P).
do {
GenPrime(P, plen, 90 + 2*NumBits(gp_cnt++));
}
while (divide(LeadCoeff(a), P) || divide(LeadCoeff(b), P));
ZZ_p::init(P);
ZZ_pX A, B;
conv(A, a);
conv(B, b);
ZZ_p t;
resultant(t, A, B);
// CRT also multiplies prod by P.  If res changed, keep iterating;
// otherwise accept res as the answer.
if (CRT(res, prod, rep(t), P))
instable = 1;
else
break;
}
// Next small FFT prime; skip it if it divides a leading coefficient.
zz_p::FFTInit(i);
long p = zz_p::modulus();
if (divide(LeadCoeff(a), p) || divide(LeadCoeff(b), p))
continue;
zz_pX A, B;
conv(A, a);
conv(B, b);
zz_p t;
resultant(t, A, B);
// Fold the modular resultant into res (prod is updated to prod*p).
instable = CRT(res, prod, rep(t), p);
}
rres = res;
zbak.restore();
Zbak.restore();
}
// MinPolyMod: gg = minimal polynomial of a modulo f.
// Requires f monic with deg(f) >= 1 and deg(a) < deg(f).  If a == 0, gg = X.
//
// Modular minimal polynomials over small FFT primes are combined by CRT into
// g (modulo prod):
//  - An image of lower degree than the current g is discarded.
//  - An image of higher degree restarts the accumulation.
//
// Termination:
//  - When deg(g) == n = deg(f), g must equal the characteristic polynomial,
//    so the loop stops once prod exceeds 2 + CharPolyBound(a, f) bits.
//  - Otherwise, once g is stable, it is checked by testing g(a) == 0
//    (mod f) modulo one large random prime.  This check is probabilistic
//    and is the only exit when deg(g) < n.
//
// The caller's ZZ_p and zz_p moduli are saved and restored.
void MinPolyMod(ZZX& gg, const ZZX& a, const ZZX& f)
{
if (!IsOne(LeadCoeff(f)) || deg(f) < 1 || deg(a) >= deg(f))
Error("MinPolyMod: bad args");
if (IsZero(a)) {
SetX(gg);
return;
}
ZZ_pBak Zbak;
Zbak.save();
zz_pBak zbak;
zbak.save();
long n = deg(f);
// instable != 0 means the last CRT step changed g.
long instable = 1;
long gp_cnt = 0;
ZZ prod;
ZZX g;
clear(g);
set(prod);
// Computed lazily, only once g reaches full degree n.
long bound = -1;
long i;
for (i = 0; ; i++) {
if (deg(g) == n) {
if (bound < 0)
bound = 2+CharPolyBound(a, f);
if (NumBits(prod) > bound)
break;
}
// Probabilistic verification of a stable candidate g.
if (!instable &&
(deg(g) < n ||
(deg(g) == n && bound > 1000 && NumBits(prod) < 0.75*bound))) {
// guarantees 2^{-80} error probability
long plen = 90 + max( 2*NumBits(n) + NumBits(MaxBits(f)),
max( NumBits(n) + NumBits(MaxBits(a)),
NumBits(MaxBits(g)) ));
ZZ P;
GenPrime(P, plen, 90 + 2*NumBits(gp_cnt++));
ZZ_p::init(P);
ZZ_pX A, F, G;
conv(A, a);
conv(F, f);
conv(G, g);
ZZ_pXModulus FF;
build(FF, F);
ZZ_pX H;
// H = G(A) mod F over Z/PZ; zero means g annihilates a mod P.
CompMod(H, G, A, FF);
if (IsZero(H))
break;
instable = 1;
}
zz_p::FFTInit(i);
zz_pX A, F;
conv(A, a);
conv(F, f);
zz_pXModulus FF;
build(FF, F);
zz_pX G;
MinPolyMod(G, A, FF);
// Lower-degree image: discard this prime.
if (deg(G) < deg(g))
continue;
// Higher-degree image: previous accumulation is invalid; restart.
if (deg(G) > deg(g)) {
clear(g);
set(prod);
}
instable = CRT(g, prod, G);
}
gg = g;
Zbak.restore();
zbak.restore();
}
// XGCD: rr = resultant(a, b).  If rr != 0, also computes integer polynomials
// ss, tt with a*ss + b*tt = rr.  If rr == 0, only rr is set; ss and tt are
// left unchanged.  deterministic is forwarded to resultant().
//
// s and t are reconstructed by CRT from modular XGCD results over small FFT
// primes, each scaled by r mod p.  Primes dividing lc(a), lc(b) or r are
// skipped.  Once s and t stop changing, each new prime is only used to check
// a*s + b*t == r (mod p), which is cheaper than a new XGCD.  The loop ends
// when the values are stable and prod exceeds a bit bound derived from the
// sizes of r, a*s and b*t.  The caller's zz_p modulus is saved and restored.
void XGCD(ZZ& rr, ZZX& ss, ZZX& tt, const ZZX& a, const ZZX& b,
long deterministic)
{
ZZ r;
resultant(r, a, b, deterministic);
if (IsZero(r)) {
clear(rr);
return;
}
zz_pBak bak;
bak.save();
long i;
long instable = 1;
ZZ tmp;
ZZ prod;
ZZX s, t;
set(prod);
clear(s);
clear(t);
for (i = 0; ; i++) {
zz_p::FFTInit(i);
long p = zz_p::modulus();
if (divide(LeadCoeff(a), p) || divide(LeadCoeff(b), p) || divide(r, p))
continue;
zz_p R;
conv(R, r);
zz_pX D, S, T, A, B;
conv(A, a);
conv(B, b);
// Stable candidate: check the Bezout-type identity mod p.  If it holds,
// the current s, t are accepted mod p and prod is extended by p without
// a CRT step; otherwise fall back to a full modular XGCD.
if (!instable) {
conv(S, s);
conv(T, t);
zz_pX t1, t2;
mul(t1, A, S);
mul(t2, B, T);
add(t1, t1, t2);
if (deg(t1) == 0 && ConstTerm(t1) == R)
mul(prod, prod, p);
else
instable = 1;
}
if (instable) {
// Modular XGCD gives A*S + B*T = D; scaling by R targets r instead.
XGCD(D, S, T, A, B);
mul(S, S, R);
mul(T, T, R);
// CRT multiplies its modulus argument by p in place, so s is combined
// against a copy (tmp) and t against prod, both using the old modulus.
tmp = prod;
long Sinstable = CRT(s, tmp, S);
long Tinstable = CRT(t, prod, T);
instable = Sinstable || Tinstable;
}
if (!instable) {
// Bit bounds on the coefficients of a*s and b*t (and on r).
long bound1 = NumBits(min(deg(a), deg(s)) + 1)
+ MaxBits(a) + MaxBits(s);
long bound2 = NumBits(min(deg(b), deg(t)) + 1)
+ MaxBits(b) + MaxBits(t);
long bound = 4 + max(NumBits(r), max(bound1, bound2));
if (NumBits(prod) > bound)
break;
}
}
rr = r;
ss = s;
tt = t;
bak.restore();
}
void NormMod(ZZ& x, const ZZX& a, const ZZX& f, long deterministic)
// x = norm of a modulo f, computed as resultant(f, a).
// Requires f monic with deg(f) >= 1 and deg(a) < deg(f).  The norm of 0 is 0.
// deterministic is forwarded to resultant().
{
   const bool bad = deg(f) <= 0 || deg(a) >= deg(f) || !IsOne(LeadCoeff(f));
   if (bad)
      Error("norm: bad args");

   if (IsZero(a))
      clear(x);
   else
      resultant(x, f, a, deterministic);
}
void TraceMod(ZZ& res, const ZZX& a, const ZZX& f)
// res = trace of a modulo f: the inner product of a's coefficients with the
// vector produced by TraceVec(f).  Requires f monic, deg(f) >= 1 and
// deg(a) < deg(f).
{
   const bool valid = IsOne(LeadCoeff(f)) && deg(a) < deg(f) && deg(f) > 0;
   if (!valid)
      Error("trace: bad args");

   vec_ZZ powsums;
   TraceVec(powsums, f);
   InnerProduct(res, powsums, a.rep);
}
void discriminant(ZZ& d, const ZZX& a, long deterministic)
// d = (-1)^(m(m-1)/2) * resultant(a, a') / lc(a), where m = deg(a).
// The zero polynomial yields d = 0.  deterministic is forwarded to
// resultant().  The division by lc(a) must be exact.
{
   const long m = deg(a);
   if (m < 0) {
      clear(d);
      return;
   }

   ZZX da;
   diff(da, a);

   ZZ r;
   resultant(r, a, da, deterministic);
   if (!divide(r, r, LeadCoeff(a)))
      Error("discriminant: inexact division");

   // m(m-1)/2 is odd exactly when m mod 4 is 2 or 3.
   const long m4 = m % 4;
   if (m4 == 2 || m4 == 3)
      negate(r, r);

   d = r;
}
void MulMod(ZZX& x, const ZZX& a, const ZZX& b, const ZZX& f)
// x = (a * b) mod f.  Requires f monic of degree >= 1 and
// deg(a), deg(b) < deg(f).
{
   const long n = deg(f);
   if (!IsOne(LeadCoeff(f)) || n == 0 || deg(b) >= n || deg(a) >= n)
      Error("MulMod: bad args");

   ZZX prod;
   mul(prod, a, b);
   rem(x, prod, f);
}
void SqrMod(ZZX& x, const ZZX& a, const ZZX& f)
// x = a^2 mod f.  Requires f monic of degree >= 1 and deg(a) < deg(f);
// otherwise reports "SqrMod: bad args" (the routine's own name, so the
// diagnostic points at the actual failing call).
{
   if (deg(a) >= deg(f) || deg(f) == 0 || !IsOne(LeadCoeff(f)))
      Error("SqrMod: bad args");

   ZZX t;
   sqr(t, a);
   rem(x, t, f);
}
// MulByXModAux: h = (X * a) mod f, for monic f with deg(f) = n >= 1 and
// deg(a) < n.  h may alias a: coefficients are written from the top down,
// after the values they overwrite have been read.  h should not alias f;
// MulByXMod routes that case through a temporary.
static
void MulByXModAux(ZZX& h, const ZZX& a, const ZZX& f)
{
long i, n, m;
ZZ* hh;
const ZZ *aa, *ff;
ZZ t, z;
n = deg(f);
m = deg(a);
if (m >= n || n == 0 || !IsOne(LeadCoeff(f)))
Error("MulByXMod: bad args");
// X * 0 = 0
if (m < 0) {
clear(h);
return;
}
// deg(X*a) = m+1 < n: no reduction needed, just shift every coefficient up
// one place (top coefficient stays nonzero, so no normalize is needed).
if (m < n-1) {
h.rep.SetLength(m+2);
hh = h.rep.elts();
aa = a.rep.elts();
for (i = m+1; i >= 1; i--)
hh[i] = aa[i-1];
clear(hh[0]);
}
// deg(X*a) = n: since f is monic, X*a mod f = X*a - a_{n-1}*f, i.e.
// h_i = a_{i-1} + z*f_i with z = -a_{n-1} (z is captured before any write).
else {
h.rep.SetLength(n);
hh = h.rep.elts();
aa = a.rep.elts();
ff = f.rep.elts();
negate(z, aa[n-1]);
for (i = n-1; i >= 1; i--) {
mul(t, z, ff[i]);
add(hh[i], aa[i-1], t);
}
mul(hh[0], z, ff[0]);
h.normalize();
}
}
void MulByXMod(ZZX& h, const ZZX& a, const ZZX& f)
// h = (X * a) mod f.  Safe when h aliases f: MulByXModAux reads f while
// writing h, so in that case the result is built in a temporary first.
{
   if (&h != &f) {
      MulByXModAux(h, a, f);
      return;
   }

   ZZX tmp;
   MulByXModAux(tmp, a, f);
   h = tmp;
}
static
void EuclLength1(ZZ& l, const ZZX& a)
// Like EuclLength, but first enlarges the squared norm by 2*|a_0| + 1.
// This bounds the length of a with its constant term shifted by a value x
// with |x| <= 1, because |a_0 - x|^2 <= a_0^2 + 2*|a_0| + 1.  It is used for
// x - a(y) in CharPolyBound.
{
   ZZ sumsq, tmp;
   clear(sumsq);
   const long len = a.rep.length();
   for (long j = 0; j < len; j++) {
      sqr(tmp, a.rep[j]);
      add(sumsq, sumsq, tmp);
   }

   // Slack for the unit-size shift of the constant term.
   abs(tmp, ConstTerm(a));
   mul(tmp, tmp, 2);
   add(tmp, tmp, 1);
   add(sumsq, sumsq, tmp);

   if (sumsq <= 1) {
      l = sumsq;
      return;
   }

   SqrRoot(l, sumsq);
   add(l, l, 1);   // round the truncated square root up
}
long CharPolyBound(const ZZX& a, const ZZX& f)
// This computes a bound on the size of the
// coefficients of the characterstic polynomial.
// It use the relation characterization of the char poly as
// resultant_y(f(y), x-a(y)), and then interpolates this
// through complex primimitive (deg(f)+1)-roots of unity.
{
if (IsZero(a) || IsZero(f))
Error("CharPolyBound: bad args");
ZZ t1, t2, t;
EuclLength1(t1, a);
EuclLength(t2, f);
power(t1, t1, deg(f));
power(t2, t2, deg(a));
mul(t, t1, t2);
return NumBits(t);
}
void SetCoeff(ZZX& x, long i, long a)
// Sets the coefficient of X^i in x to the machine integer a.
{
   if (a == 1) {
      SetCoeff(x, i);   // dedicated overload for a coefficient of 1
      return;
   }

   // Use a local temporary rather than a function-level static ZZ.  The
   // static was shared mutable state, so concurrent or re-entrant calls
   // could race on it.
   ZZ aa;
   conv(aa, a);
   SetCoeff(x, i, aa);
}
// vectors
// Instantiate the out-of-line parts of vec_ZZX.  Judging by the macro names
// (definitions not visible here), these provide the basic vector
// implementation, equality comparison, and stream I/O, respectively.
NTL_vector_impl(ZZX,vec_ZZX)
NTL_eq_vector_impl(ZZX,vec_ZZX)
NTL_io_vector_impl(ZZX,vec_ZZX)
void CopyReverse(ZZX& x, const ZZX& a, long hi)
// x[0..hi] = reverse(a[0..hi]), with zero fill:
// x[k] = a[hi - k] when that index lies within a, and 0 otherwise.
// Input may not alias output (reverse() handles the aliased case).
{
   const long len = hi + 1;
   const long alen = a.rep.length();
   x.rep.SetLength(len);

   const ZZ* src = a.rep.elts();
   ZZ* dst = x.rep.elts();

   for (long k = 0, j = hi; k < len; k++, j--) {
      if (0 <= j && j < alen)
         dst[k] = src[j];
      else
         clear(dst[k]);
   }

   x.normalize();
}
void reverse(ZZX& x, const ZZX& a, long hi)
// x = reverse of a[0..hi], zero-filled; output may alias input.
// Requires hi >= -1.
{
   if (hi < -1) Error("reverse: bad args");

   if (&x != &a) {
      CopyReverse(x, a, hi);
      return;
   }

   // CopyReverse forbids aliasing, so go through a temporary.
   ZZX tmp;
   CopyReverse(tmp, a, hi);
   x = tmp;
}
void MulTrunc(ZZX& x, const ZZX& a, const ZZX& b, long n)
// x = (a * b) mod X^n.  The full product is formed in a temporary, so x may
// alias a or b.
{
   ZZX full;
   mul(full, a, b);
   trunc(x, full, n);
}
void SqrTrunc(ZZX& x, const ZZX& a, long n)
// x = a^2 mod X^n.  The full square is formed in a temporary, so x may
// alias a.
{
   ZZX full;
   sqr(full, a);
   trunc(x, full, n);
}
void NewtonInvTrunc(ZZX& c, const ZZX& a, long e)
// c = a^{-1} mod X^e, computed by Newton iteration.
// The constant term of a must be +1 or -1 (the units of ZZ).
// Intended for e >= 1; InvTrunc filters out e <= 0.
{
   ZZ x;

   if (ConstTerm(a) == 1)
      x = 1;
   else if (ConstTerm(a) == -1)
      x = -1;
   else
      Error("InvTrunc: non-invertible constant term");

   if (e == 1) {
      conv(c, x);
      return;
   }

   // Precision schedule: E[0] = e, E[j+1] = ceil(E[j]/2), ending at 1.
   // A separate counter leaves e untouched, so the preallocations below use
   // the final target precision.  A local vector (rather than a static one)
   // keeps the routine reentrant.
   vec_long E;
   append(E, e);
   for (long cur = e; cur > 1; ) {
      cur = (cur+1)/2;
      append(E, cur);
   }

   const long L = E.length();

   ZZX g, g0, g1, g2;

   // Reserve space for the last (largest) lifting step: g, g0, g2 have fewer
   // than e coefficients, and the product g0*g fewer than about 3e/2.
   g.rep.SetMaxLength(e);
   g0.rep.SetMaxLength(e);
   g1.rep.SetMaxLength((3*e+1)/2);
   g2.rep.SetMaxLength(e);

   conv(g, x);

   for (long i = L-1; i > 0; i--) {
      // Lift g from an inverse mod X^k to an inverse mod X^(k+l) via
      //    g <- g - g*(a*g - 1)   (mod X^(k+l)),
      // where a*g - 1 is divisible by X^k and l <= k.
      const long k = E[i];
      const long l = E[i-1] - E[i];

      trunc(g0, a, k+l);

      mul(g1, g0, g);           // g1 = a*g, which is 1 mod X^k
      RightShift(g1, g1, k);
      trunc(g1, g1, l);         // g1 = ((a*g - 1) / X^k) mod X^l

      mul(g2, g1, g);
      trunc(g2, g2, l);
      LeftShift(g2, g2, k);     // g2 = X^k * (g1*g mod X^l)

      sub(g, g, g2);
   }

   c = g;
}
void InvTrunc(ZZX& c, const ZZX& a, long e)
// c = a^{-1} mod X^e.  e must be nonnegative; e == 0 gives c = 0.
// e >= 2^(NTL_BITS_PER_LONG-4) is rejected as an overflow.
{
   if (e < 0) Error("InvTrunc: bad args");
   if (e >= (1L << (NTL_BITS_PER_LONG-4)))
      Error("overflow in InvTrunc");

   if (e == 0)
      clear(c);
   else
      NewtonInvTrunc(c, a, e);
}
// Presumably closes the NTL implementation scope opened earlier in this file
// (the matching opening macro is outside this view) — confirm.
NTL_END_IMPL
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -