📄 fft.txt
字号:
#include<iostream>
#include<vector>
#include<iomanip>
#include<sstream>
#include<string>
#include<map>
#include<stack>
#include<fstream>
#include<cmath>
#include<algorithm>
#include<list>
using namespace std;
int ci;
const double pi=3.141592653589793238462643383279502884;
int n;
struct Pair
{
double x,y;
Pair(){x=0;y=0;}
Pair(double a,double b){x=a;y=b;}
Pair operator*(const Pair &b){return Pair(x*b.x-y*b.y,x*b.y+y*b.x);}
Pair operator+(const Pair &b){return Pair(x+b.x,y+b.y);}
Pair operator-(const Pair &b){return Pair(x-b.x,y-b.y);}
};
void fft(vector<Pair>&a,vector<Pair>&b)
{
for(int i=0;i<n;i++)
{
int temp=i;
int ti=0;
for(int j=0;j<ci;j++)
{
ti<<=1;
ti^=temp&1;
temp>>=1;
}
a[ti]=b[i];
}
int m=1;
for(int i=0;i<ci;i++)
{
m<<=1;
Pair wm(cos(2*pi/m),sin(2*pi/m));
for(int k=0;k<n;k+=m)
{
Pair w(1,0);
int p=m>>1;
for(int j=0;j<p;j++)
{
Pair t=w*a[k+j+p];
Pair u=a[k+j];
a[k+j]=u+t;
a[k+j+p]=u-t;
w=w*wm;
}
}
}
}
void defft(vector<Pair>&a,vector<Pair>&b)
{
for(int i=0;i<n;i++)
{
int temp=i;
int ti=0;
for(int j=0;j<ci;j++)
{
ti<<=1;
ti^=temp&1;
temp>>=1;
}
a[ti]=b[i];
}
int m=1;
for(int i=0;i<ci;i++)
{
m<<=1;
Pair wm(cos(2*pi/m),-sin(2*pi/m));
for(int k=0;k<n;k+=m)
{
Pair w(1,0);
int p=m>>1;
for(int j=0;j<p;j++)
{
Pair t=w*a[k+j+p];
Pair u=a[k+j];
a[k+j]=u+t;
a[k+j+p]=u-t;
w=w*wm;
}
}
}
for(int i=0;i<n;i++)
a[i].x/=n,a[i].y/=n;
}
char arr[100000],brr[100000];
int main()
{
while(scanf("%s %s",arr,brr)!=EOF)
{
int la=strlen(arr);
int lb=strlen(brr);
int da=max(la,lb);
da++;
da>>=1;
ci=1;
while(da)
{
da>>=1;
ci++;
}
n=1<<ci;
vector<Pair>a(n),b(n),c(n),d(n),e(n);
int i;
int index=0;
for(i=la-2;i>0;i-=2)
{
c[index++].x=atoi(arr+i);
arr[i]=0;
}
c[index++].x=atoi(arr);
index=0;
for(i=lb-2;i>0;i-=2)
{
d[index++].x=atoi(brr+i);
brr[i]=0;
}
d[index++].x=atoi(brr);
fft(a,c);
fft(b,d);
for(i=0;i<n;i++)e[i]=a[i]*b[i];
defft(a,e);
vector<int>val(n);
for(i=0;i<n;i++)
val[i]=(int)((a[i].x*10+5)/10);
for(int i=0;i<n-1;i++)
{
val[i+1]+=val[i]/100;
val[i]%=100;
}
for(i=n-1;i>=0;i--)
if(val[i])
break;
if(i<0)
{printf("0\n");continue;}
printf("%d",val[i]);
for(i--;i>=0;i--)
printf("%02d",val[i]);
printf("\n");
}
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -