⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 conv&corr.cpp

📁 使用fft实现的快速卷积
💻 CPP
📖 第 1 页 / 共 2 页
字号:
/**************************************************************************
 *  文件名:CONV&CORR.cpp
 *
 *  函数:
 *
 *  CONV()				- 快速卷积
 *  CORR()				- 快速相关
 *  CONV2()				- 图像快速卷积
 *  CORR2()				- 图像快速相关
 *
 *************************************************************************/

#include "stdafx.h"
#include "CONV&CORR.h"
#include "FFT&IFFT.h"

#include <math.h>
#include <complex>
using namespace std;

// 常数π
#define PI 3.1415926535

/*************************************************************************
 *
 * 函数名称:
 *   CONV()
 *
 * 参数:
 *   complex<double> * TD1	    - 指向时域序列数组1的指针
 *   complex<double> * TD2	    - 指向时域序列数组2的指针
 *   complex<double> * TDout	- 指向时域结果数组的指针
 *   M						    - 序列1的长度
 *   N                          - 序列2的长度
 *
 * 说明:
 *   该函数利用FFT实现快速卷积。
 *
 ************************************************************************/
void CONV(complex<double> * TD1, complex<double> * TD2, complex<double> * TDout, int M, int N)
{
	// 卷积结果长度
	int	count = M+N-1;

	// 便于使用FFT,把count扩展为2的幂
	int Lcount;
    int r=0;  // 2的幂数,即FFT迭代次数,2的r次方=Lcount

	int temp;
	if (log(count)/log(2)-int(log(count)/log(2))==0)
      temp = log(count)/log(2);
	else
	  temp = log(count)/log(2)+1;
	r = temp;
	Lcount = 1<<r;	

	// 分配运算所需存储器
    complex<double> *X1, *X2, *FD1, *FD2, *FD12, *TD12;

	X1 = new complex<double>[Lcount];  //补齐后的序列1
	X2 = new complex<double>[Lcount];  //补齐后的序列2
	FD1 = new complex<double>[Lcount];   //序列1的傅立叶变换结果
	FD2 = new complex<double>[Lcount];   //序列2的傅立叶变换结果
	FD12 = new complex<double>[Lcount];   //序列1,2的频域相乘结果
	TD12 = new complex<double>[Lcount];   //序列1,2的傅立叶反变换结果
	
    //将序列补齐为Lcount长度
	complex<double> *X, *Y;
	X = new complex<double>[M];  //临时存储器
	Y = new complex<double>[N];
	
	// 将时域点写入X,Y
	memcpy(X, TD1, sizeof(complex<double>) * M);
	memcpy(Y, TD2, sizeof(complex<double>) * N);

	// 循环变量
	int	i;

    for (i=0; i<M; i++)    //拷贝序列1内容   
    {
        X1[i] = complex<double>(X[i].real(), X[i].imag());                                             
    }

    for (i=M; i<Lcount; i++)    //序列1补0
    {
        X1[i] = complex<double>(0, 0);
    }

	for (i=0; i<N; i++)    //拷贝序列2内容   
    {
        X2[i] = complex<double>(Y[i].real(), Y[i].imag());
    }

    for (i=N; i<Lcount; i++)    //序列2补0
    {
        X2[i] = complex<double>(0, 0);
    }

    // 释放内存
	delete X;
	delete Y;

    //序列1的FFT
	FFT(X1, FD1, r);

	//序列2的FFT
    FFT(X2, FD2, r);

    //序列1,2的频域相乘
    for (i=0; i<Lcount; i++)    //序列1,2相乘
    {
        FD12[i] = complex<double>(FD1[i].real()*FD2[i].real()-FD1[i].imag()*FD2[i].imag(), FD1[i].real()*FD2[i].imag()+FD1[i].imag()*FD2[i].real());
    }

	//序列1,2的频域相乘的IFFT
    IFFT(FD12, TD12, r);

	//TD12中的前M+N-1项为真正卷积结果写入TDout
    memcpy(TDout, TD12, sizeof(complex<double>)*count);
	
	// 释放内存
	delete X1;
	delete X2;
	delete FD1;
	delete FD2;
	delete FD12;
    delete TD12;
}

/*************************************************************************
 *
 * 函数名称:
 *   CORR()
 *
 * 参数:
 *   complex<double> * TD1	    - 指向时域序列数组1的指针
 *   complex<double> * TD2	    - 指向时域序列数组2的指针
 *   complex<double> * TDout	- 指向时域结果数组的指针
 *   M						    - 序列1的长度
 *   N                          - 序列2的长度
 *
 * 说明:
 *   该函数利用FFT实现快速相关。
 *
 ************************************************************************/
void CORR(complex<double> * TD1, complex<double> * TD2, complex<double> * TDout, int M, int N)
{
	// 相关结果长度
	int	count = M+N-1;

	// 便于使用FFT,把序列扩展为2的幂
	int Lcount;
    int r = 0;  // 2的幂数,即FFT迭代次数,2的r次方=Lcount

	/*int temp;
	if (M>N) temp = M;
	else temp = N;
	if (log(2*temp)/log(2)-int(log(2*temp)/log(2))==0)
      temp = log(2*temp)/log(2);
	else
	  temp = log(2*temp)/log(2)+1;
	r = temp;
	Lcount = 1<<r;*/
	
    int temp;
	if (log(count)/log(2)-int(log(count)/log(2))==0)
      temp = log(count)/log(2);
	else
	  temp = log(count)/log(2)+1;
	r = temp;
	Lcount = 1<<r;

	// 分配运算所需存储器
    complex<double> *X1, *X2, *FD1, *FD2, *FD12, *TD12;

	X1 = new complex<double>[Lcount];  //补齐后的序列1
	X2 = new complex<double>[Lcount];  //补齐后的序列2
	FD1 = new complex<double>[Lcount];   //序列1的傅立叶变换结果
	FD2 = new complex<double>[Lcount];   //序列2的傅立叶变换结果
	FD12 = new complex<double>[Lcount];   //序列1,2的频域相乘结果
	TD12 = new complex<double>[Lcount];   //序列1,2的傅立叶反变换结果
	
    //将序列补齐为Lcount长度
	complex<double> *X, *Y;
	X = new complex<double>[M];  //临时存储器
	Y = new complex<double>[N];
	
	// 将时域点写入X,Y
	memcpy(X, TD1, sizeof(complex<double>) * M);
	memcpy(Y, TD2, sizeof(complex<double>) * N);

	// 循环变量
	int	i;

    for (i=0; i<M; i++)    //拷贝序列1内容   
    {
        X1[i] = complex<double>(X[i].real(), X[i].imag());                                             
    }

    for (i=M; i<Lcount; i++)    //序列1补0
    {
        X1[i] = complex<double>(0, 0);
    }

	for (i=0; i<N; i++)    //拷贝序列2内容   
    {
        X2[i] = complex<double>(Y[i].real(), Y[i].imag());
    }

    for (i=N; i<Lcount; i++)    //序列2补0
    {
        X2[i] = complex<double>(0, 0);
    }

    // 释放内存
	delete X;
	delete Y;

    //序列1的FFT
	FFT(X1, FD1, r);

	//序列2的FFT
    FFT(X2, FD2, r);

	//求序列2的FFT结果的共轭
	for(i=0; i<Lcount; i++)
	{
		FD2[i] = complex<double> (FD2[i].real(), -FD2[i].imag());
	}

    //序列1,2的频域相乘
    for (i=0; i<Lcount; i++)    //序列1,2相乘
    {
        FD12[i] = complex<double>(FD1[i].real()*FD2[i].real()-FD1[i].imag()*FD2[i].imag(), FD1[i].real()*FD2[i].imag()+FD1[i].imag()*FD2[i].real());
    }

	//序列1,2的频域相乘的IFFT
    IFFT(FD12, TD12, r);

	//TD12中的前M+N-1项为真正相关结果写入TDout
    memcpy(TDout, TD12, sizeof(complex<double>)*count);
	
	// 释放内存
	delete X1;
	delete X2;
	delete FD1;
	delete FD2;
	delete FD12;
    delete TD12;
}

/*************************************************************************
 *
 * 函数名称:
 *   CONV2()
 *
 * 参数:
 *   complex<double> * TD1	    - 指向时域序列数组1的指针
 *   complex<double> * TD2	    - 指向时域序列数组2的指针
 *   complex<double> * TDout	- 指向时域结果数组的指针
 *   M1,N1					    - 图像1的宽度、高度
 *   M2,N2                      - 图像2的宽度、高度
 *
 * 说明:
 *   该函数利用FFT实现图像快速卷积。
 *
 ************************************************************************/
void CONV2(complex<double> * TD1, complex<double> * TD2, complex<double> * TDout, int M1, int N1, int M2, int N2)
{
    // x方向卷积长度
	int	Wcount = M1 + M2 - 1;
	// y方向卷积长度
	int	Hcount = N1 + N2 - 1;

	// 便于使用FFT,把x,y方向都要扩展为2的幂
	long WLcount, HLcount;

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -