⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 poly.cpp

📁 miracl-大数运算库,大家使用有什么问题请多多提意见
💻 CPP
📖 第 1 页 / 共 2 页
字号:
/*
 * C++ class to implement a polynomial type and to allow 
 * arithmetic on polynomials whose elements are from
 * the finite field mod p
 *
 * WARNING: This class has been cobbled together for a specific use with
 * the MIRACL library. It is not complete, and may not work in other 
 * applications
 *
 * See Knuth The Art of Computer Programming Vol.2, Chapter 4.6 
 */

#include "poly.h"

#include <iostream>
#include <vector>
using namespace std;

Poly::Poly(const ZZn& c,int p) : start(NULL)
{ // build the single monomial c*x^p, starting from an empty term list
    addterm(c,p);
}

Poly::Poly(Variable &x) : start(NULL)
{ // the polynomial "x" itself, i.e. 1*x^1; the Variable argument only selects this overload
    addterm((ZZn)1,1);
}

Poly operator-(const Poly& a)
{ // additive inverse: copy a, then scale every coefficient by -1 with no power shift
    Poly negated(a);
    negated.multerm((ZZn)-1,0);
    return negated;
}

Poly operator*(const ZZn& c,Variable x)
{ // c*x  ->  the monomial c*x^1
    return Poly(c,1);
}

Poly pow(Variable x,int n)
{ // x^n  ->  the monomial 1*x^n
    return Poly((ZZn)1,n);
}

BOOL operator==(const Poly& a,const Poly& b)
{ // equal exactly when the difference is the zero polynomial
    Poly delta=a-b;
    return iszero(delta) ? TRUE : FALSE;
}

BOOL operator!=(const Poly& a,const Poly& b)
{ // unequal exactly when the difference is a non-zero polynomial
    Poly delta=a-b;
    return iszero(delta) ? FALSE : TRUE;
}

void setpolymod(const Poly& p) 
{ 
    int n,m;
    Poly h;
    term *ptr;
    big *f,*rf;
    n=degree(p);
    if (n<FFT_BREAK_EVEN) return;
    h=reverse(p);
    h=invmodxn(h,n);
    h=reverse(h);   // h=RECIP(f)
    m=degree(h);
    if (m<n-1) h=mulxn(h,n-1-m);

    f=(big *)mr_alloc(n+1,sizeof(big));
    rf=(big *)mr_alloc(n+1,sizeof(big));
 
    ptr=p.start;
    while (ptr!=NULL)
    {
       f[ptr->n]=getbig(ptr->an);
       ptr=ptr->next;
    }   
    ptr=h.start;
    while (ptr!=NULL)
    {
       rf[ptr->n]=getbig(ptr->an);
       ptr=ptr->next;
    }   
 
    mr_polymod_set(n,rf,f);

    mr_free(rf);
    mr_free(f);
}

Poly::Poly(const Poly& p)
{ // deep copy: re-add each term of p in list order, passing the last
  // inserted term back to addterm as a position hint
    start=NULL;
    term *tail=NULL;
    for (term *src=p.start;src!=NULL;src=src->next)
        tail=addterm(src->an,src->n,tail);
}


Poly::~Poly()
{ // release every node of the singly linked term list, head first
    while (start!=NULL)
    {
        term *doomed=start;
        start=start->next;
        delete doomed;
    }
}

ZZn Poly::coeff(int power)  const
{ // coefficient of x^power, or 0 when no such term is stored
    for (term *t=start;t!=NULL;t=t->next)
    {
        if (t->n==power) return t->an;
    }
    return (ZZn)0;
}

ZZn Poly::F(const ZZn& x) const
{
// Evaluate this polynomial at x with Horner's rule. The term list is walked
// from the highest power down, and missing powers are skipped with pow(x,gap).
    if (start==NULL) return (ZZn)0;        // the zero polynomial

    term *cur=start;
    ZZn acc=cur->an;
    for (;cur->next!=NULL;cur=cur->next)
    {
        int gap=cur->n-cur->next->n;       // powers are stored in descending order
        ZZn step;
        if (gap==1) step=x;
        else        step=pow(x,gap);
        acc=acc*step+cur->next->an;
    }
    acc*=pow(x,cur->n);                    // the lowest stored power may exceed 0
    return acc;
}

ZZn Poly:: min() const
{ // coefficient of the last (lowest-power) stored term; 0 for the zero polynomial
    if (start==NULL) return (ZZn)0;
    term *tail;
    for (tail=start;tail->next!=NULL;tail=tail->next) {}
    return tail->an;
}

Poly compose(const Poly& g,const Poly& b,const Poly& m)
{ // compose polynomials modulo m
  // assume G(x) = G3x^3 + G2x^2 + G1x^1 +G0
  // Calculate G(B(x)) mod m = G3.(B(x))^3 + G2.(B(x))^2 + ... mod m
    Poly c;
    if (g.start==NULL) return c;           // G == 0, so G(B) == 0; no table needed

    int d=degree(g);
    std::vector<Poly> table(d+1);          // table[i] = B(x)^i mod m, released automatically
    table[0].addterm((ZZn)1,0);
    for (int i=1;i<=d;i++) table[i]=(table[i-1]*b)%m;

    for (term *ptr=g.start;ptr!=NULL;ptr=ptr->next)
        c+=ptr->an*table[ptr->n];

    // Reduction mod m is linear, so reducing the finished sum once equals
    // reducing after every term. It also avoids the copies operator% makes
    // on each call. Only table[0] (the constant 1) can exceed m's degree, and
    // only when m is a constant.
    c%=m;
    return c;
}

Poly compose(const Poly& g,const Poly& b)
{ // compose polynomials
  // assume G(x) = G3x^3 + G2x^2 + G1x^1 +G0
  // Calculate G(B(x)) = G3.(B(x))^3 + G2.(B(x))^2 ....   
    Poly c;
    if (g.start==NULL) return c;           // G == 0, so G(B) == 0; no table needed

    int d=degree(g);
    std::vector<Poly> table(d+1);          // table[i] = B(x)^i, released automatically
    table[0].addterm((ZZn)1,0);
    for (int i=1;i<=d;i++) table[i]=table[i-1]*b;

    for (term *ptr=g.start;ptr!=NULL;ptr=ptr->next)
        c+=ptr->an*table[ptr->n];
    return c;
}

// Remainder of x modulo m.
// When deg(m), or deg(x)-deg(m), is below FFT_BREAK_EVEN this is just x%m.
// Otherwise the coefficients of x are copied into a MIRACL big array and
// reduced by mr_poly_rem. That routine works from the modulus installed by
// setpolymod(m); if it reports failure, the modulus is reinstalled and the
// reduction retried.
Poly reduce(const Poly &x,const Poly &m)
{
    Poly r;
    int i,d;
    ZZn t;
    big *G,*R;
    term *ptr,*pos=NULL;
    int degm=degree(m);
    int n=degree(x);

    // small problem: schoolbook division is used instead
    if (degm < FFT_BREAK_EVEN || n-degm < FFT_BREAK_EVEN)
    {
        r=x%m;
        return r;
    }
    // G[0..n]: working copy of x's coefficients, backed by one memalloc block
    // (assumed zero-initialized, so powers absent from x read as 0 -- TODO confirm)
    G=(big *)mr_alloc(n+1,sizeof(big));
    char *memg=(char *)memalloc(n+1);
    for (i=0;i<=n;i++) G[i]=mirvar_mem(memg,i);
    // R[0..degm-1]: receives the remainder's coefficients
    R=(big *)mr_alloc(degm,sizeof(big));
    char *memr=(char *)memalloc(degm);
    for (i=0;i<degm;i++) R[i]=mirvar_mem(memr,i);

    ptr=x.start;
    while (ptr!=NULL)
    {
        copy(getbig(ptr->an),G[ptr->n]);
        ptr=ptr->next; 
    }
    if (!mr_poly_rem(n,G,R))
    {  // reset the modulus - things have changed
       // (a zero return is treated as "installed modulus is stale")
        setpolymod(m);
        mr_poly_rem(n,G,R);
    }
 
    r.clear();

    // rebuild r from the highest power down, skipping zero coefficients;
    // pos is the previous insertion point, handed back to addterm as a hint
    for (d=degm-1;d>=0;d--)
    {
        t=R[d];
        if (t.iszero()) continue;
        pos=r.addterm(t,d,pos);
    }
    // release the scratch buffers
    memkill(memr,degm);

    mr_free(R);
    memkill(memg,n+1);
    mr_free(G);

    return r;
}

Poly modmult(const Poly &x,const Poly &y,const Poly &m)
{ /* x*y mod m: form the full product, then hand it to reduce() */
    return reduce(x*y,m);
}

// Product a*b.
// If the smaller of the two degrees is at least FFT_BREAK_EVEN, the product is
// computed by MIRACL's mr_poly_mul (or mr_poly_sqr when a and b are the same
// object). Otherwise a schoolbook product over the sparse term lists is used.
// Squaring is detected only by address, so a*copy_of_a takes the general path.
Poly operator*(const Poly& a,const Poly& b)
{
    int i,d,dega,degb,deg;
    BOOL squaring;
    ZZn t;
    Poly prod;
    term *iptr,*pos;
    term *ptr=b.start;

    squaring=FALSE;
    if (&a==&b) squaring=TRUE;

    dega=degree(a);
    deg=dega;
    if (!squaring)
    {
        degb=degree(b);
        if (degb<dega) deg=degb;
    }
    else degb=dega;
    if (deg>=FFT_BREAK_EVEN)      /* deg is the smaller degree, so this path needs both degrees >= FFT_BREAK_EVEN */
    { // use fast methods 
        big *A,*B,*C;
        deg=dega+degb;     // degree of product
   
        // A and B hold pointers aliasing the coefficients of a and b, indexed by
        // power. Slots for absent powers are left as mr_alloc returned them,
        // presumably treated as zero by mr_poly_mul/mr_poly_sqr -- TODO confirm.
        A=(big *)mr_alloc(dega+1,sizeof(big));
        if (!squaring) B=(big *)mr_alloc(degb+1,sizeof(big));        
        // C[0..deg] owns fresh storage for the product's coefficients
        C=(big *)mr_alloc(deg+1,sizeof(big));
        char *memc=(char *)memalloc(deg+1);
        for (i=0;i<=deg;i++) C[i]=mirvar_mem(memc,i);
        ptr=a.start;
        while (ptr!=NULL)
        {
            A[ptr->n]=getbig(ptr->an);
            ptr=ptr->next;
        }

        if (!squaring)
        {
            ptr=b.start;
            while (ptr!=NULL)
            {
                B[ptr->n]=getbig(ptr->an);
                ptr=ptr->next;
            }
            mr_poly_mul(dega,A,degb,B,C);
        }
        else mr_poly_sqr(dega,A,C);
        // copy the non-zero product coefficients into prod, highest power first
        pos=NULL;
        for (d=deg;d>=0;d--)
        {
            t=C[d];
            if (t.iszero()) continue;
            pos=prod.addterm(t,d,pos);
        }
        memkill(memc,deg+1);
        mr_free(C);
        mr_free(A);
        if (!squaring) mr_free(B);

        return prod;
    }

    if (squaring)
    { // squaring
        pos=NULL;
        while (ptr!=NULL)
        { // diagonal terms: a_i^2 x^(2i)
            pos=prod.addterm(ptr->an*ptr->an,ptr->n+ptr->n,pos);
            ptr=ptr->next;
        }
        ptr=b.start;
        while (ptr!=NULL)
        { // above the diagonal: each cross product a_i*a_j (i>j) counted twice
            iptr=ptr->next;
            pos=NULL;
            while (iptr!=NULL)
            {
                t=ptr->an*iptr->an;
                pos=prod.addterm(t+t,ptr->n+iptr->n,pos);
                iptr=iptr->next;
            }
            ptr=ptr->next; 
        }
    }
    else while (ptr!=NULL)
    { // general schoolbook: every term of b times every term of a
        pos=NULL;
        iptr=a.start;
        while (iptr!=NULL)
        {
            pos=prod.addterm(ptr->an*iptr->an,ptr->n+iptr->n,pos);
            iptr=iptr->next;
        }
        ptr=ptr->next;
    }

    return prod;
}

Poly& Poly::operator%=(const Poly&v)
{ // in place: *this = *this mod v, by schoolbook long division
    if (degree(*this)<degree(v)) return *this;      // already reduced

    term *lead_v=v.start;
    ZZn neg_inv=-((ZZn)1/lead_v->an);              // -1/lc(v), one inversion for the whole division

    // Each pass adds (-lc(this)/lc(v)) * x^shift * v, which cancels the current
    // leading term. The list head is then re-read as the new leading term.
    for (term *lead=start; lead!=NULL && lead->n>=lead_v->n; lead=start)
    {
        ZZn scale=lead->an*neg_inv;
        int shift=lead->n-lead_v->n;
        term *hint=NULL;
        for (term *vt=v.start; vt!=NULL; vt=vt->next)
            hint=addterm(vt->an*scale,vt->n+shift,hint);
    }
    return *this;
}

Poly operator%(const Poly& u,const Poly&v)
{ // remainder of u divided by v; u itself is left untouched
    Poly rem(u);
    rem%=v;
    return rem;
}

Poly operator/(const Poly& u,const Poly&v)
{ // quotient of u by v (the remainder is discarded); v must be non-zero
    Poly q,r=u;
    term *vptr=v.start;
    term *ptr,*pos,*qpos=NULL;

    // deg(u) < deg(v), or u == 0: the quotient is 0 and no inversion is needed
    if (r.start==NULL || r.start->n<vptr->n) return q;

    // Invert lc(v) once for the whole division, as operator%= does, instead of
    // performing a field division for every quotient term.
    ZZn inv=(ZZn)1/vptr->an;

    term *rptr=r.start;
    while (rptr!=NULL && rptr->n>=vptr->n)
    {
        ZZn pq=rptr->an*inv;                  // next quotient coefficient
        ZZn npq=-pq;
        int power=rptr->n-vptr->n;
  // quotient: powers strictly decrease, so reuse the last insertion point
        qpos=q.addterm(pq,power,qpos);
  // r -= pq * x^power * v, computed term by term instead of copying v;
  // this cancels r's current leading term
        pos=NULL;
        for (ptr=v.start;ptr!=NULL;ptr=ptr->next)
            pos=r.addterm(ptr->an*npq,ptr->n+power,pos);
        rptr=r.start;
    }
    return q;
}

Poly diff(const Poly& f)
{ // formal derivative: sum of n*a_n*x^(n-1) over the terms of f
   Poly d;
   term *pos=NULL;
   for (term *ptr=f.start;ptr!=NULL;ptr=ptr->next)
   {
       // The constant term differentiates to 0. Skip it explicitly so an x^-1
       // term is never requested, whether or not addterm discards zero
       // coefficients.
       if (ptr->n==0) continue;
       pos=d.addterm(ptr->an*ptr->n,ptr->n-1,pos);
   }

   return d;
}

Poly gcd(const Poly& f,const Poly& g)
{ // monic greatest common divisor by the Euclidean algorithm; gcd(0,0) = 0
    Poly a=f, b=g;
    for (;;)
    {
        if (b.start==NULL)
        {
            // a may also be zero (f == g == 0): there is no leading coefficient
            // to normalise by, so the zero polynomial is returned as-is
            if (a.start!=NULL) a.multerm((ZZn)1/a.start->an,0);
            return a;
        }
        a%=b;
        if (a.start==NULL)
        {
            b.multerm((ZZn)1/b.start->an,0);   // make the result monic
            return b;
        }
        b%=a;
    }
}

Poly pow(const Poly& f,int k)
{
    Poly u;
    int w,e,b;

    if (k==0)
    {
        u.addterm((ZZn)1,0);
        return u;
    }
    u=f;
    if (k==1) return u;

    e=k;
    b=0; while (k>1) {k>>=1; b++; }
    w=(1<<b);
    e-=w; w/=2;
    while (w>0)
    {
        u=(u*u);
        if (e>=w)
        {
           e-=w;
           u=(u*f);
        }
        w/=2; 
    }
    return u;

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -