⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 sixwftat.cpp

📁 任意精度计算的实现
💻 CPP
字号:
#include "ap.h"


// The "six-step" NTT using WFTA short transforms, but without the transpose
// or scramble steps (for convolution only).
//
// data   - n1 * n2 elements treated as an n1 x n2 matrix stored row by row
//          (element A(j,k) is data[j * n2 + k]); transformed in place
// w      - passed unchanged to every wfta() call; presumably the WFTA tables
//          for the short lengths n1 and n2 -- TODO confirm
// pr     - base of the twiddle root; presumably a primitive root modulo
//          modint::modulus, so that wt is a primitive (n1*n2)-th root of
//          unity -- verify against caller
// isign  - > 0 selects wt = pr^(modulus - 1 - (modulus - 1) / (n1*n2)),
//          otherwise wt = pr^((modulus - 1) / (n1*n2)).  Only the twiddle
//          factors depend on isign; the wfta() calls are the same either way.
// n1, n2 - matrix dimensions; any relative order is accepted
//
// Steps: n2 column transforms of length n1, multiply A(j,k) by wt^(j*k),
// then n1 row transforms of length n2.
void sixstepwftatrans (modint data[], modint *w[], modint pr, int isign, int n1, int n2)
{
    int j, k;
    modint wt, tmp, tmp2, *p1, *p2;

    // (modulus - 1) / n1 / n2 has the same value as (modulus - 1) / (n1 * n2)
    // but never forms the product n1 * n2, which could overflow int
    if (isign > 0)
        wt = pow (pr, modint::modulus - 1 - (modint::modulus - 1) / n1 / n2);
    else
        wt = pow (pr, (modint::modulus - 1) / n1 / n2);

    modint *b = new modint[n1];         // scratch copy of one column

    // treat the input data as a n1 x n2 matrix

    // first do n2 transforms of length n1 in columns

    for (k = 0, p1 = data; k < n2; k++, p1++)
    {
        for (j = 0, p2 = p1; j < n1; j++, p2 += n2)
            b[j] = *p2;

        wfta (b, w, n1);

        for (j = 0, p2 = p1; j < n1; j++, p2 += n2)
            *p2 = b[j];
    }

    delete[] b;                         // only needed for the column pass

    // then multiply the matrix A_jk by wt^(j*k) = exp(isign * -2 pi i j k / nn)
    //
    // Row 0 and column 0 have factor 1 and are skipped.  The factor is
    // symmetric in j and k, so each power computed while walking row j is also
    // applied to the mirrored element A(k,j) if that element exists.  Every
    // element of rows j >= min(n1, n2) is reached through such mirrored
    // updates.  Pointers only ever advance onto elements that are accessed,
    // so nothing outside data[0 .. n1*n2) is touched even when n1 > n2.

    int nmin = (n1 < n2 ? n1 : n2);

    tmp = wt;                           // wt^j for the current row j
    for (j = 1, p1 = data + n2; j < nmin; j++, p1 += n2)
    {
        tmp2 = pow (tmp, j);            // wt^(j*j)
        p1[j] *= tmp2;                  // A(j,j)
        p2 = p1 + j;                    // walks down column j

        // both A(j,k) and A(k,j) lie inside the matrix
        for (k = j + 1; k < nmin; k++)
        {
            tmp2 *= tmp;                // wt^(j*k)
            p2 += n2;                   // A(k,j)
            p1[k] *= tmp2;
            *p2 *= tmp2;
        }
        // n2 > n1: the rest of row j has no mirror
        for (; k < n2; k++)
        {
            tmp2 *= tmp;
            p1[k] *= tmp2;
        }
        // n1 > n2: the rest of column j has no mirror
        for (; k < n1; k++)
        {
            tmp2 *= tmp;
            p2 += n2;
            *p2 *= tmp2;
        }
        tmp *= wt;
    }

    // last do n1 transforms of length n2 in rows

    for (k = 0, p1 = data; k < n1; k++, p1 += n2)
        wfta (p1, w, n2);
}

// Reverse-order "six-step" NTT using WFTA short transforms, without the
// transpose or scramble steps (for convolution only).  It performs the steps
// of sixstepwftatrans() in the opposite order: row transforms, twiddle
// multiply, then column transforms.  It is presumably used to undo
// sixstepwftatrans() in a convolution -- TODO confirm against callers.
//
// data   - n1 * n2 elements treated as an n1 x n2 matrix stored row by row
//          (element A(j,k) is data[j * n2 + k]); transformed in place
// w      - passed unchanged to every wfta() call; presumably the WFTA tables
//          for the short lengths n1 and n2 -- TODO confirm
// pr     - base of the twiddle root; presumably a primitive root modulo
//          modint::modulus -- verify against caller
// isign  - > 0 selects wt = pr^(modulus - 1 - (modulus - 1) / (n1*n2)),
//          otherwise wt = pr^((modulus - 1) / (n1*n2)).  Only the twiddle
//          factors depend on isign; the wfta() calls are the same either way.
// n1, n2 - matrix dimensions; any relative order is accepted
void isixstepwftatrans (modint data[], modint *w[], modint pr, int isign, int n1, int n2)
{
    int j, k;
    modint wt, tmp, tmp2, *p1, *p2;

    // (modulus - 1) / n1 / n2 has the same value as (modulus - 1) / (n1 * n2)
    // but never forms the product n1 * n2, which could overflow int
    if (isign > 0)
        wt = pow (pr, modint::modulus - 1 - (modint::modulus - 1) / n1 / n2);
    else
        wt = pow (pr, (modint::modulus - 1) / n1 / n2);

    // treat the input data as a n1 x n2 matrix

    // first do n1 transforms of length n2 in rows

    for (k = 0, p1 = data; k < n1; k++, p1 += n2)
        wfta (p1, w, n2);

    // then multiply the matrix A_jk by wt^(j*k) = exp(isign * -2 pi i j k / nn)
    //
    // Row 0 and column 0 have factor 1 and are skipped.  The factor is
    // symmetric in j and k, so each power computed while walking row j is also
    // applied to the mirrored element A(k,j) if that element exists.  Every
    // element of rows j >= min(n1, n2) is reached through such mirrored
    // updates.  Pointers only ever advance onto elements that are accessed,
    // so nothing outside data[0 .. n1*n2) is touched even when n1 > n2.

    int nmin = (n1 < n2 ? n1 : n2);

    tmp = wt;                           // wt^j for the current row j
    for (j = 1, p1 = data + n2; j < nmin; j++, p1 += n2)
    {
        tmp2 = pow (tmp, j);            // wt^(j*j)
        p1[j] *= tmp2;                  // A(j,j)
        p2 = p1 + j;                    // walks down column j

        // both A(j,k) and A(k,j) lie inside the matrix
        for (k = j + 1; k < nmin; k++)
        {
            tmp2 *= tmp;                // wt^(j*k)
            p2 += n2;                   // A(k,j)
            p1[k] *= tmp2;
            *p2 *= tmp2;
        }
        // n2 > n1: the rest of row j has no mirror
        for (; k < n2; k++)
        {
            tmp2 *= tmp;
            p1[k] *= tmp2;
        }
        // n1 > n2: the rest of column j has no mirror
        for (; k < n1; k++)
        {
            tmp2 *= tmp;
            p2 += n2;
            *p2 *= tmp2;
        }
        tmp *= wt;
    }

    // last do n2 transforms of length n1 in columns

    modint *b = new modint[n1];         // scratch copy of one column

    for (k = 0, p1 = data; k < n2; k++, p1++)
    {
        for (j = 0, p2 = p1; j < n1; j++, p2 += n2)
            b[j] = *p2;

        wfta (b, w, n1);

        for (j = 0, p2 = p1; j < n1; j++, p2 += n2)
            *p2 = b[j];
    }

    delete[] b;
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -