⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 makewfta.cpp

📁 任意精度计算的实现
💻 CPP
📖 第 1 页 / 共 2 页
字号:
    modint w, c, s, i;

    w = root (pr, 8);

    c = rcos (w);
    s = isin (w);

    i = w * w;

    w8[0] = iw8[0] = w8[1] = iw8[1] = w8[2] = iw8[2] = w8[3] = iw8[3] = 1;
    w8[4] = iw8[4] = c;
    w8[5] = -(iw8[5] = i);
    w8[6] = -(iw8[6] = i);
    w8[7] = -(iw8[7] = s);
}

// Fills the multiplier tables of the 9-point WFTA building block.
//
// pr  - primitive root for the current modulus; root (pr, 9) supplies the
//       9th root of unity r from which all entries are derived.
// w9  - receives the 11 forward-transform multipliers (indices 0..10).
// iw9 - receives the 11 inverse-transform multipliers (indices 0..10).
//
// Entries 0..5 are built from rcos() values and are identical in both
// tables. Entries 6..10 are built from isin() values; the inverse table
// stores the value and the forward table stores its negation.
void createw9table (modint pr, modint w9[], modint iw9[])
{
    modint two (2), three (3);
    int k;

    // Powers of the 9th root of unity that the formulas below need
    modint r1 = root (pr, 9);
    modint r2 = r1 * r1;
    modint r3 = r2 * r1;
    modint r4 = r2 * r2;

    modint cos1 = rcos (r1), sin1 = isin (r1);
    modint cos2 = rcos (r2), sin2 = isin (r2);
    modint cos4 = rcos (r4), sin4 = isin (r4);
    modint sin3 = isin (r3);

    // rcos-based multipliers: shared by the forward and inverse tables
    w9[0] = iw9[0] = 1;
    w9[1] = iw9[1] = w9[2] = iw9[2] = modint (1) / two;
    w9[3] = iw9[3] = (cos1 - two * cos2 + cos4) / three;
    w9[4] = iw9[4] = (cos1 + cos2 - two * cos4) / three;
    w9[5] = iw9[5] = (two * cos1 - cos2 - cos4) / three;

    // isin-based multipliers: inverse table gets +value ...
    iw9[6] = sin3;
    iw9[7] = sin3;
    iw9[8] = (sin1 + two * sin2 + sin4) / three;
    iw9[9] = (sin2 - sin1 + two * sin4) / three;
    iw9[10] = (two * sin1 + sin2 - sin4) / three;

    // ... and the forward table gets -value
    for (k = 6; k <= 10; k++)
        w9[k] = -iw9[k];
}

// Fills the multiplier tables of the 16-point WFTA building block.
//
// pr   - primitive root for the current modulus; root (pr, 16) supplies the
//        16th root of unity r from which all entries are derived.
// w16  - receives the 18 forward-transform multipliers (indices 0..17).
// iw16 - receives the 18 inverse-transform multipliers (indices 0..17).
//
// Entries 0..9 are 1 or built from rcos() values and are identical in both
// tables. Entries 10..17 are r^4 or built from isin() values; the inverse
// table stores the value and the forward table stores its negation.
void createw16table (modint pr, modint w16[], modint iw16[])
{
    int k;

    // Powers of the 16th root of unity that the formulas below need
    modint r1 = root (pr, 16);
    modint r2 = r1 * r1;
    modint r3 = r2 * r1;
    modint unit = r2 * r2;      // r^4, a 4th root of unity

    modint cos1 = rcos (r1), sin1 = isin (r1);
    modint cos2 = rcos (r2), sin2 = isin (r2);
    modint cos3 = rcos (r3), sin3 = isin (r3);

    // Multipliers shared by the forward and inverse tables
    for (k = 0; k <= 4; k++)
        w16[k] = iw16[k] = 1;

    w16[5] = iw16[5] = cos2;
    w16[6] = iw16[6] = cos2;
    w16[7] = iw16[7] = cos3;
    w16[8] = iw16[8] = cos3 + cos1;
    w16[9] = iw16[9] = cos3 - cos1;

    // Sign-flipped multipliers: inverse table gets +value ...
    iw16[10] = iw16[11] = iw16[12] = unit;
    iw16[13] = iw16[14] = sin2;
    iw16[15] = sin3;
    iw16[16] = sin1 - sin3;
    iw16[17] = sin1 + sin3;

    // ... and the forward table gets -value
    for (k = 10; k <= 17; k++)
        w16[k] = -iw16[k];
}

// Builds the forward/inverse W factor tables for every small transform
// length used as a WFTA building block (2, 3, 4, 5, 7, 8, 9 and 16).
// The length-N tables are written to wt[N] and iwt[N].
//
// pr  - primitive root for the current modulus, forwarded unchanged to each
//       createwNtable() (those functions are defined elsewhere in this file).
// wt  - indexed by transform length; wt[N] must already point to storage
//       large enough for createwNtable() (sizes are not visible here --
//       TODO confirm against the table declarations).
// iwt - same as wt, for the inverse-transform tables.
void createswtables (modint pr, modint *wt[], modint *iwt[])
{
    createw2table (pr, wt[2], iwt[2]);
    createw3table (pr, wt[3], iwt[3]);
    createw4table (pr, wt[4], iwt[4]);
    createw5table (pr, wt[5], iwt[5]);
    createw7table (pr, wt[7], iwt[7]);
    createw8table (pr, wt[8], iwt[8]);
    createw9table (pr, wt[9], iwt[9]);
    createw16table (pr, wt[16], iwt[16]);
}

// Assembles the m[i]-entry multiplier tables of composite algorithm i from
// the small-length tables in the global arrays wt[] / iwt[].
//
// d  - receives the forward multipliers (m[i] entries).
// id - receives the inverse multipliers (m[i] entries).
// i  - algorithm index into the global tables n[], m[], s[] and ms[].
//
// The factor lengths s[i][0], s[i][1], ... are consumed until their product
// reaches n[i]. For factor f, entry j of that factor's table is multiplied
// into consecutive runs of `block` entries, where block = m[i] divided by
// ms[i][0] * ... * ms[i][f]. Assuming m[i] equals the product of all
// ms[i][f], the result is the Kronecker product of the factor tables, with
// the first factor varying slowest.
//
// NOTE(review): this reads the global wt/iwt, not the like-named parameters
// of createwtables(); main() passes the globals, so they coincide there.
void createwtable (modint d[], modint id[], int i)
{
    int nfactors, len, f, block, pos, j, rep;

    // Count how many factors s[i][0], s[i][1], ... multiply up to n[i]
    len = n[i];
    nfactors = 0;
    while (len > 1)
        len /= s[i][nfactors++];

    // Start from the identity: every multiplier is 1
    for (pos = 0; pos < m[i]; pos++)
    {
        id[pos] = 1;
        d[pos] = id[pos];
    }

    block = m[i];

    for (f = 0; f < nfactors; f++)
    {
        modint *fw = wt[s[i][f]];
        modint *ifw = iwt[s[i][f]];
        int count = ms[i][f];

        block /= count;

        // Sweep the whole table; entry j of this factor covers `block`
        // consecutive positions, and the pattern repeats until the end
        pos = 0;
        while (pos < m[i])
            for (j = 0; j < count; j++)
                for (rep = 0; rep < block; rep++, pos++)
                {
                    d[pos] *= fw[j];
                    id[pos] *= ifw[j];
                }
    }
}

// Builds every W factor table for the modulus currently in effect.
//
// pr  - primitive root for the current modulus.
// wt  - small-length forward tables, filled by createswtables().
// iwt - small-length inverse tables, filled by createswtables().
// w   - per-algorithm forward tables; w[t] is assembled when p[t] is set.
// iw  - per-algorithm inverse tables; iw[t] is assembled when p[t] is set.
//
// Afterwards every inverse table iw[t], for all NALGS algorithms, is
// divided by its transform length n[t]. The inverse transform therefore
// needs no separate scaling step.
//
// NOTE(review): the scaling pass runs only after all tables are assembled
// and is not filtered by p[t]. Presumably some iw[t] share storage with
// iwt[] entries, so scaling earlier would corrupt createwtable() inputs --
// confirm against the table declarations.
void createwtables (modint pr, modint *wt[], modint *iwt[], modint *w[], modint *iw[])
{
    int alg, idx;

    createswtables (pr, wt, iwt);

    // Assemble the composite tables (only algorithms with a permutation table)
    for (alg = 0; alg < NALGS; alg++)
    {
        if (!p[alg])
            continue;

        createwtable (w[alg], iw[alg], alg);
    }

    // Fold the 1/n normalization of the inverse transform into its table
    for (alg = 0; alg < NALGS; alg++)
    {
        modint *table = iw[alg];

        for (idx = 0; idx < m[alg]; idx++)
            table[idx] /= n[alg];
    }
}

int main (void)
{
    int l, t, u;

    ofstream o ("wftatab1"CEXT);

    if (o.fail ())
    {
        cerr << "Unable to open file wftatab1"CEXT << endl;
        return 1;
    }

    createpermutetables ();

    o << "// Permutation tables" << endl << "int ";

    for (t = 0; t < NALGS; t++)
        if (p[t])
        {
            o << "p" << n[t] << "[] = {";
            for (u = 0; u < n[t]; u++)
            {
                o << p[t][u] << (u == n[t] - 1 ? "}" : ", ");
                if ((u & 15) == 15 && u != n[t] - 1) o << endl << "    ";
            }
            o << (t == NALGS - 1 ? ";" : ", ") << endl << "    ";
        }

    o << endl << "int ";

    for (t = 0; t < NALGS; t++)
        if (ip[t])
        {
            o << "ip" << n[t] << "[] = {";
            for (u = 0; u < n[t]; u++)
            {
                o << ip[t][u] << (u == n[t] - 1 ? "}" : ", ");
                if ((u & 15) == 15 && u != n[t] - 1) o << endl << "    ";
            }
            o << (t == NALGS - 1 ? ";" : ", ") << endl << "    ";
        }

    o << endl << "int *p" << "[] = {";

    for (t = 0; t < NALGS; t++)
        if (p[t])
            o << "p" << n[t] << (t == NALGS - 1 ? "}" : ", ");
        else
            o << 0 << (t == NALGS - 1 ? "}" : ", ");

    o << ";" << endl;

    o << endl << "int *ip" << "[] = {";

    for (t = 0; t < NALGS; t++)
        if (ip[t])
            o << "ip" << n[t] << (t == NALGS - 1 ? "}" : ", ");
        else
            o << 0 << (t == NALGS - 1 ? "}" : ", ");

    o << ";" << endl;

    for (l = 0; l < 3; l++)
    {
        o.close ();

        switch (l)
        {
            case 0: o.open ("wftatab2"CEXT); break;
            case 1: o.open ("wftatab4"CEXT); break;
            case 2: o.open ("wftatab6"CEXT); break;
        }

        if (o.fail ())
        {
            cerr << "Unable to open file wftatab" << 2 * l + 2 << CEXT << endl;
            return 1;
        }

        o << "#include \"ap.h\"" << endl << endl;

        setmodulus (wftamoduli[l]);
        createwtables (wftaprimitiveroots[l], wt, iwt, w, iw);

        o << endl << "// W factor tables for WFTA for modulus " << l << endl << "rawtype ";

        for (t = 0; t < NALGS; t++)
        {
            o << "w" << n[t] << "_" << l << "[] = {";
            for (u = 0; u < m[t]; u++)
            {
                o << w[t][u] << (u == m[t] - 1 ? "}" : ", ");
                if ((u & 15) == 15 && u != m[t] - 1) o << endl << "         ";
            }
            o << (t == NALGS - 1 ? ";" : ", ") << endl << "         ";
        }

        o << endl << "modint *w_" << l << "[] = {";

        for (t = 0; t < NALGS; t++)
            o << "(modint *) w" << n[t] << "_" << l << (t == NALGS - 1 ? "}" : ", ");

        o << ";" << endl;

        o.close ();

        switch (l)
        {
            case 0: o.open ("wftatab3"CEXT); break;
            case 1: o.open ("wftatab5"CEXT); break;
            case 2: o.open ("wftatab7"CEXT); break;
        }

        if (o.fail ())
        {
            cerr << "Unable to open file wftatab" << 2 * l + 3 << CEXT << endl;
            return 1;
        }

        o << "#include \"ap.h\"" << endl << endl;

        o << endl << "// Inverse W factor tables for WFTA for modulus " << l << endl << "rawtype ";

        for (t = 0; t < NALGS; t++)
        {
            o << "iw" << n[t] << "_" << l << "[] = {";
            for (u = 0; u < m[t]; u++)
            {
                o << iw[t][u] << (u == m[t] - 1 ? "}" : ", ");
                if ((u & 15) == 15 && u != m[t] - 1) o << endl << "         ";
            }
            o << (t == NALGS - 1 ? ";" : ", ") << endl << "         ";
        }

        o << endl << "modint *iw_" << l << "[] = {";

        for (t = 0; t < NALGS; t++)
            o << "(modint *) iw" << n[t] << "_" << l << (t == NALGS - 1 ? "}" : ", ");

        o << ";" << endl;
    }

    clearmodulus ();

    return 0;
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -