⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 conv&corr.cpp

📁 使用fft实现的快速卷积
💻 CPP
📖 第 1 页 / 共 2 页
字号:
    int w=0;  // 2的幂数,即FFT迭代次数,2的w次方=WLcount
	int h=0;  // 2的幂数,即FFT迭代次数,2的h次方=HLcount

	int temp;
	if (log(Wcount)/log(2)-int(log(Wcount)/log(2))==0)
      temp = log(Wcount)/log(2);
	else
	  temp = log(Wcount)/log(2)+1;
	w = temp;
	WLcount = 1<<w;	

	if (log(Hcount)/log(2)-int(log(Hcount)/log(2))==0)
      temp = log(Hcount)/log(2);
	else
	  temp = log(Hcount)/log(2)+1;
	h = temp;
	HLcount = 1<<h;	

	// 分配运算所需存储器
    complex<double> *X1, *X2, *FD1, *FD2, *FD12, *TD12;

	X1 = new complex<double>[WLcount * HLcount];  //补齐后的序列1
	X2 = new complex<double>[WLcount * HLcount];  //补齐后的序列2
	FD1 = new complex<double>[WLcount * HLcount];   //序列1的傅立叶变换结果
	FD2 = new complex<double>[WLcount * HLcount];   //序列2的傅立叶变换结果
	FD12 = new complex<double>[WLcount * HLcount];   //序列1,2的频域相乘结果
	TD12 = new complex<double>[WLcount * HLcount];   //序列1,2的傅立叶反变换结果
	
    //将序列补齐为WLcount * HLcount长度
	complex<double> *X, *Y;
	X = new complex<double>[M1 * N1];  //临时存储器
	Y = new complex<double>[M2 * N2];
	
	// 将时域点写入X,Y
	memcpy(X, TD1, sizeof(complex<double>) * M1 * N1);
	memcpy(Y, TD2, sizeof(complex<double>) * M2 * N2);

	// 循环变量
	int	i, j;

    for (i=0; i<N1; i++)    //拷贝序列1内容   
    {
        for (j=0; j<M1; j++)
        {
		   X1[i * WLcount + j] = complex<double>(X[i * M1 + j].real(), X[i * M1 + j].imag()); 
		}
    }

    for (i=0; i<N1; i++)    //序列1补0
    {
        for (j=M1; j<WLcount; j++)
		{
		   X1[i * WLcount + j] = complex<double>(0, 0);
		}
    }
	for (i=N1; i<HLcount; i++)    //序列1补0
    {
        for (j=0; j<WLcount; j++)
		{
		   X1[i * WLcount + j] = complex<double>(0, 0);
		}
    }

	for (i=0; i<N2; i++)    //拷贝序列2内容   
    {
        for (j=0; j<M2; j++)
        {
		   X2[i * WLcount + j] = complex<double>(Y[i * M2 + j].real(), Y[i * M2 + j].imag()); 
		}
    }

    for (i=0; i<N2; i++)    //序列2补0
    {
        for (j=M2; j<WLcount; j++)
		{
		   X2[i * WLcount + j] = complex<double>(0, 0);
		}
    }
	for (i=N2; i<HLcount; i++)    //序列2补0
    {
        for (j=0; j<WLcount; j++)
		{
		   X2[i * WLcount + j] = complex<double>(0, 0);
		}
    }

    // 释放内存
	delete X;
	delete Y;

    //序列1的FFT
	FFT2(X1, FD1, WLcount, HLcount);

	//序列2的FFT
    FFT2(X2, FD2, WLcount, HLcount);

    //序列1,2的频域相乘(图像频域序列相乘是否就可看作一维序列相乘?)
    for (i=0; i<WLcount * HLcount; i++)    //序列1,2相乘
    {
        FD12[i] = complex<double>(FD1[i].real()*FD2[i].real()-FD1[i].imag()*FD2[i].imag(), FD1[i].real()*FD2[i].imag()+FD1[i].imag()*FD2[i].real());
    }

	//序列1,2的频域相乘的IFFT
    IFFT2(FD12, TD12, WLcount, HLcount);

	//TD12中的前(M1 + M2 - 1) * (N1 + N2 - 1)项为真正卷积结果写入TDout
	for (i=0; i<Hcount; i++)       
    {
        for (j=0; j<Wcount; j++)
        {
		   TDout[i * Wcount + j] = complex<double>(TD12[i * WLcount + j].real(), TD12[i * WLcount + j].imag()); 
		}
    }
	
	// 释放内存
	delete X1;
	delete X2;
	delete FD1;
	delete FD2;
	delete FD12;
    delete TD12;
}

/*************************************************************************
 *
 * 函数名称:
 *   CORR2()
 *
 * 参数:
 *   complex<double> * TD1	    - 指向时域序列数组1的指针
 *   complex<double> * TD2	    - 指向时域序列数组2的指针
 *   complex<double> * TDout	- 指向时域结果数组的指针
 *   M1,N1					    - 图像1的宽度、高度
 *   M2,N2                      - 图像2的宽度、高度
 *
 * 说明:
 *   该函数利用FFT实现快速相关。
 *
 ************************************************************************/
void CORR2(complex<double> * TD1, complex<double> * TD2, complex<double> * TDout, int M1, int N1, int M2, int N2)
{
	// x方向相关结果长度
	int	Wcount = M1 + M2 - 1;
	// y方向相关结果长度
	int	Hcount = N1 + N2 - 1;	

	// 便于使用FFT,把x,y方向都要扩展为2的幂
	long WLcount, HLcount;
    int w=0;  // 2的幂数,即FFT迭代次数,2的w次方=WLcount
	int h=0;  // 2的幂数,即FFT迭代次数,2的h次方=HLcount

	int temp;
	if (log(Wcount)/log(2)-int(log(Wcount)/log(2))==0)
      temp = log(Wcount)/log(2);
	else
	  temp = log(Wcount)/log(2)+1;
	w = temp;
	WLcount = 1<<w;	

	if (log(Hcount)/log(2)-int(log(Hcount)/log(2))==0)
      temp = log(Hcount)/log(2);
	else
	  temp = log(Hcount)/log(2)+1;
	h = temp;
	HLcount = 1<<h;		

    // 分配运算所需存储器
    complex<double> *X1, *X2, *FD1, *FD2, *FD12, *TD12;

	X1 = new complex<double>[WLcount * HLcount];  //补齐后的序列1
	X2 = new complex<double>[WLcount * HLcount];  //补齐后的序列2
	FD1 = new complex<double>[WLcount * HLcount];   //序列1的傅立叶变换结果
	FD2 = new complex<double>[WLcount * HLcount];   //序列2的傅立叶变换结果
	FD12 = new complex<double>[WLcount * HLcount];   //序列1,2的频域相乘结果
	TD12 = new complex<double>[WLcount * HLcount];   //序列1,2的傅立叶反变换结果
	
    //将序列补齐为WLcount * HLcount长度
	complex<double> *X, *Y;
	X = new complex<double>[M1 * N1];  //临时存储器
	Y = new complex<double>[M2 * N2];
	
	// 将时域点写入X,Y
	memcpy(X, TD1, sizeof(complex<double>) * M1 * N1);
	memcpy(Y, TD2, sizeof(complex<double>) * M2 * N2);

    // 循环变量
	int	i, j;

    for (i=0; i<N1; i++)    //拷贝序列1内容   
    {
        for (j=0; j<M1; j++)
        {
		   X1[i * WLcount + j] = complex<double>(X[i * M1 + j].real(), X[i * M1 + j].imag()); 
		}
    }

    for (i=0; i<N1; i++)    //序列1补0
    {
        for (j=M1; j<WLcount; j++)
		{
		   X1[i * WLcount + j] = complex<double>(0, 0);
		}
    }
	for (i=N1; i<HLcount; i++)    //序列1补0
    {
        for (j=0; j<WLcount; j++)
		{
		   X1[i * WLcount + j] = complex<double>(0, 0);
		}
    }

	for (i=0; i<N2; i++)    //拷贝序列2内容   
    {
        for (j=0; j<M2; j++)
        {
		   X2[i * WLcount + j] = complex<double>(Y[i * M2 + j].real(), Y[i * M2 + j].imag()); 
		}
    }

    for (i=0; i<N2; i++)    //序列2补0
    {
        for (j=M2; j<WLcount; j++)
		{
		   X2[i * WLcount + j] = complex<double>(0, 0);
		}
    }
	for (i=N2; i<HLcount; i++)    //序列2补0
    {
        for (j=0; j<WLcount; j++)
		{
		   X2[i * WLcount + j] = complex<double>(0, 0);
		}
    }

	// 释放内存
	delete X;
	delete Y;

    //序列1的FFT
	FFT2(X1, FD1, WLcount, HLcount);

	//序列2的FFT
    FFT2(X2, FD2, WLcount, HLcount);

	//求序列2的FFT结果的共轭
	for(i=0; i<WLcount * HLcount; i++)
	{
		FD2[i] = complex<double> (FD2[i].real(), -FD2[i].imag());
	}

    //序列1,2的频域相乘(图像频域序列相乘是否就可看作一维序列相乘?)
    for (i=0; i<WLcount * HLcount; i++)    //序列1,2相乘
    {
        FD12[i] = complex<double>(FD1[i].real()*FD2[i].real()-FD1[i].imag()*FD2[i].imag(), FD1[i].real()*FD2[i].imag()+FD1[i].imag()*FD2[i].real());
    }

	//序列1,2的频域相乘的IFFT
    IFFT2(FD12, TD12, WLcount, HLcount);

	//TD12中的前(M1 + M2 - 1) * (N1 + N2 - 1)项为真正相关结果写入TDout
	for (i=0; i<Hcount; i++)       
    {
        for (j=0; j<Wcount; j++)
        {
		   TDout[i * Wcount + j] = complex<double>(TD12[i * WLcount + j].real(), TD12[i * WLcount + j].imag()); 
		}
    }
	
	// 释放内存
	delete X1;
	delete X2;
	delete FD1;
	delete FD2;
	delete FD12;
    delete TD12;
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -