fft.br
来自「用于GPU通用计算的编程语言BrookGPU 0.4」· BR 代码 · 共 497 行
BR
497 行
#ifdef _WIN32#pragma warning(disable:4275)#endif#include <complex>#include <cmath>#include "main.h"// NOTE:// Please generate the dataset and the twiddle factors by running the data.cc program in this directory// and then run this program. This program currently works for power of 2 lengths. // References:// [1] Alan V. Oppenheim and Ronald W. Schafer, ``Discrete-Time Signal Processing'', Prentice-Hall, 1989// [2] Thomas H. Cormen, Charles E. Leiserson, Ronald L. Rivest and Clifford Stein,// ``Introduction to Algorithms'', MIT Press, 2001.// NOTE:// We implement the Decimation in Frequency Algorithm of the FFT described in [1]#include "main.h"#include <stdio.h>#include <stdlib.h>#include <math.h>int debug_fft=0;#ifdef USE_FFTW#include <fftw3.h>#endif//c version from FFT book#ifndef M_PI#define M_PI 3.1415926536f#endif#define SWAP(a,b) tempr=(a); (a) = (b); (b) = temprtypedef int mycomplex;typedef int myplan;#define mycomplex fftwf_complex#define myplan fftwf_planint DOTIMES=128;#ifdef USE_FFTWmycomplex * getMaxData () { return (mycomplex*)fftwf_malloc(sizeof(fftwf_complex) * 4096*4096);}mycomplex * fftwoutput = getMaxData();void FftwTransform(float2 *data, unsigned long N, int isign, char cast) { myplan p; mycomplex *in; mycomplex *output=fftwoutput; unsigned long i; if (!cast) { in = (mycomplex*)fftwf_malloc(sizeof(fftwf_complex) * N); for (i=0;i<N;++i) { in[i][0]=data[i].x; in[i][1]=data[i].y; } }else { in = (mycomplex*)data; } p = fftwf_plan_dft_1d(N, in, output, FFTW_FORWARD, FFTW_ESTIMATE); for (i=0;i<(unsigned long)DOTIMES;++i) { fftwf_execute(p); /* repeat as needed */ } fftwf_destroy_plan(p); if (debug_fft) { printf("cpu version\n"); for (i=0;i<N;++i) { float x = output[i][0]; float y = output[i][1]; printf ("{%.2f, %.2f}",x,y); } printf ("\n"); } if (!cast) { fftwf_free(in); }}void FftwTransform2d(float2 *data, unsigned long N, unsigned long M, int isign, char cast) { myplan p; mycomplex *in; mycomplex *output=fftwoutput; unsigned long i,j; if (!cast) { in = (mycomplex*)fftwf_malloc(sizeof(fftwf_complex) * N*M); for (i=0;i<N*M;++i) { in[i][0]=data[i].x; in[i][1]=data[i].y; } }else { in = (mycomplex*)data; } start = GetTime(); p = fftwf_plan_dft_2d(N, M, in, output, FFTW_FORWARD, FFTW_ESTIMATE); fftwf_execute(p); /* repeat as needed */ fftwf_execute(p); /* repeat as needed */ fftwf_destroy_plan(p); stop = GetTime(); if (debug_fft) { printf("cpu version\n"); for (i=0;i<N;++i) { for (j=0;j<M;++j) { float x = output[i*M+j][0]; float y = output[i*M+j][1]; printf ("{%.2f, %.2f}",x,y); } printf("\n"); } printf ("\n"); } if (!cast) { fftwf_free(in); }}#else#define fftwTransform FftwTransform#define fftwTransform2d FftwTransform2dvoid fftwTransform(float2 *data, unsigned long N, int isign, char cast) {}void fftwTransform2d(float2 *data, unsigned long N, unsigned long M, int isign, char cast) {}#endifvoid four1(float data[], unsigned long nn, int isign) { unsigned long n, mmax, m, j, istep, i; double wtemp, wr, wpr, wpi, wi, theta; float tempr, tempi; n =nn<<1; j=1; for (i=1;i<n;i+=2) { if (j>i) { SWAP(data[j],data[i]); SWAP(data[j+1],data[i+1]); } m= n>> 1; while (m>=2&& j >m) { j-=m; m >>=1; } j+=m; } mmax = 2; while (n>mmax) { istep =mmax<<1; theta=isign*(2*M_PI/mmax); wtemp = sin (0.5*theta); wpr = cos(theta);//-2.0*wtemp*wtemp; wpi = sin(theta); wr = 1.0; wi = 0.0; for (m=1;m<mmax;m+=2) { for (i=m;i<=n;i+=istep) { j=i+mmax; tempr=((float ) wr)*data[j]-((float ) wi)*data[j+1]; tempi=((float ) wr)*data[j+1]+((float) wi)*data[j]; data[j]=data[i]-tempr; data[j+1]=data[i+1]-tempi; data[i]+=tempr; data[i+1]+=tempi; } wr = (wtemp=wr)*wpr-wi*wpi+wr; wi = wi*wpr+wtemp*wpi+wi; } mmax=istep; } printf ("Base Case FFT \n"); for (i=0;i<nn;++i) { printf("{%f, %f}",data[2*i],data[2*i+1]); } printf("\n");}kernel void streamInit(out float2 a<>){ a.x = 0.0; a.y = 0.0;}void init (float2 * a) { a->x = a->y = 0.0;}kernel void streamCopy(out float2 a<>, float2 b<>){ a.x = b.x; a.y = b.y;}void copy (float2 * a, float2 * b) { a->x = b->x; a->y = b->y;}void __printf_cpu_inner(float inx, float iny, float outx, float outy) { printf("%g %g -> %g %g\n",inx,iny,outx,outy);}void __print_cpu_inner(float inx, float iny, float outx, float outy) { printf("%g %g && %g %g\n",inx,iny,outx,outy);}kernel void mult(float2 a <>, float2 b <>, out float2 c <>){ c.x = a.x*b.x - a.y*b.y; c.y = a.x*b.y + a.y*b.x;}float2 cmult (float2 a, float2 b) { float2 c; c.x = a.x*b.x - a.y*b.y; c.y = a.x*b.y + a.y*b.x; return c;}////////////////////////////////////////////////////FIXME// DFT()// Implements the Perfect Shuffle algorithmkernel voidDoDFT (float2 s[], float2 W[], float2 StrideN, out float4 s_prime<>) { float index = (indexof s_prime).x; float2 temp,r,t,twiddle; twiddle = W[index-fmod(index,StrideN.x)]; r = s[index]; t = s[index+round(StrideN.y/2)]; s_prime.xy = r+t ; mult ((r-t), twiddle, temp); s_prime.zw = temp;} kernel void flattenS (out float2 s<>, float4 s_prime<>){ if (round (fmod((indexof(s)).x,2))==1.0f) s = s_prime.zw; else s = s_prime.xy;}// Utilities and the BitReverse() procedure// To compute 2**xint TwoPowerX(int nNumber) { // Achieve the multiplication by left-shifting return (1<<nNumber);}// Procedure to reverse the bits. // Example: // INPUTS: nNumberToReverse = 110; nNumberOfBits = 3;// OUTPUT: nReversedNumber = 011// CAUTION: Make sure that you pass atleast the minimum number of bits to represent the // number. For reversing 6 you need atleast 3 bits, if only 2 bits is passed then// we get junk results.int BitReverse(int nNumberToReverse, int nNumberOfBits) { int nBitIndex; int nReversedNumber = 0; for (nBitIndex = nNumberOfBits-1; nBitIndex >= 0; --nBitIndex) { if ((1 == nNumberToReverse >> nBitIndex)) { nReversedNumber += TwoPowerX(nNumberOfBits-1-nBitIndex); nNumberToReverse -= TwoPowerX(nBitIndex); } } return(nReversedNumber);}// BitReverseStream()// INPUT: s, nDataLength. // OUTPUT: r. `r' contains the bit-reversed streamtypedef int streamcmplx;#define streamcmplx const brook::stream &float * bitReversedIndices (int logN) { int i,N = (1<< logN); float * s_array; s_array = (float*)calloc(N,sizeof(float)); for (i=0;i<N;++i) { s_array[i]=(float)BitReverse(i,logN); } return s_array;}void BitReverseStream(streamcmplx r, streamcmplx s, int nDataLength) { int nIndex, B = 1; float2 *s_array; s_array = (float2 *) calloc(nDataLength, sizeof(float2)); // This assumes that nDataLength is a power of 2 // Number of Bits while ((nDataLength >> B) > 0) B++; B--; // Now 2**B gives the nDataLength // copy the stream to the array streamWrite(s, s_array); for(nIndex = 1; nIndex < nDataLength-1; ++nIndex) { int nReversedIndex = BitReverse(nIndex, B); float2 temp; // Need to swap only half of the index // and no need to swap the first and last if (nReversedIndex < nIndex) continue; // Swap by references init(&temp); copy(&temp, s_array+nIndex); copy(s_array+nIndex, s_array+nReversedIndex); copy(s_array+nReversedIndex, &temp); } // copy the bit-reversed stream into r streamRead(r, s_array); // do the painful deallocation free(s_array); }float2 * getData(int N, int M) { int seed = 214098127; int i,j; float2 * ret; ret = (float2*)malloc(N*M*sizeof(float2)); for (i=0;i<N;++i) { for (j=0;j<M;++j) { seed+=24101491; ret[i*M+j]=float2((float)(seed%4096),0); seed%=9420409; } } return ret;}float2 * data = getData(4096,4096);float2 * getCompactData(int N, int M) { int seed = 214098127; int i,j; float2 * ret; ret = (float2*)malloc(N*M*sizeof(float2)); for (i=0;i<N;++i) { for (j=0;j<M;++j) { seed+=24101491; ret[i*M+j].x=(float)(seed%4096); seed%=9420409; seed+=24101491; ret[i*M+j].y=(float)(seed%4096); seed%=9420409; } } return ret;}float2 * compactdata = getData(4096,4096);float2 * getTwiddleFactor(int N) { int i; float2 *ret; float theta = 2*M_PI/N; // float wtemp = sin(0.5*theta); ret = (float2 *)malloc(sizeof(float2)*N); ret[0].x=1;ret[0].y=0; ret[1].x=cos(theta); ret[1].y = sin(theta); for (i=2;i<N;++i) { ret[i]=cmult(ret[i-1],ret[1]); } return ret;}kernel void myGather(out float2 s_out<>, float indices<>,float2 s[]) { s_out=s[indices];}void fft1d(int logN){ int N = (1<<logN); float2 s<N>; float2 s_out<N>; float2 r <N>; float2 t <N>; float2 W<N>; // passed as the out variable in the DFT float4 s_prime<(N/2)>; // variables to get around the strange ways of Brook // to get around the references in `flattening' the stream float2 temp<N>; // Read the data // NOTE: Twiddle factors // Reference: Pg 601 OS (1/e) Fig 9.18 // Imagine Beginners Guide pg 23 // Please note that nPass and nBits denote the same but there // are two different variables for cleaner coding and easier // maintenance. int nPass, nBits; int nPassCounter; int nDataLength; float2 *dat = getData(1,N); streamRead(s,dat); printf("The stream length is %d\n", N); streamPrint(s); FftwTransform(dat,N,1,1); free (dat); start = GetTimeMillis(); dat = getTwiddleFactor(N); streamRead(W,dat); free(dat); nDataLength = N; // Just print the data for reference printf("\n"); // handle the base case if(1 == nDataLength) streamSwap(s_out,s); // do nothing else { // Number of passes = log2(nDataLength) // Implementation Notes: // There is an adjustment for nDataLength due to the fact that // ceil(10.0) will give 11 not 10. Thus if are dealing with // nDataLength of powers of 2 then we have to subtract one // to get the correct number of passes nPass = logN; nBits = nPass; nPassCounter = 0; while(nPassCounter < nPass) { // Please remember that t = s[nDataLength/2 .. nDataLength) // Do a `Perfect Shuffle' int nStride = (1<<nPassCounter); // DFT merge call. The heart of this program! DoDFT(s, W,float2((float) nStride, (float) N), s_prime); // To support the `in-place' perfect-shuffle algorithm if (1) { flattenS(s, s_prime); } // update the PassCounter ++nPassCounter; } // while() // Note that 's' now has a bitreversed version of the FFT. // Since FFT is an in-place computation 's' contains the // the FFT(s) itself. // NOTE: streamBitReverse() isn't yet implemented. Just an emulation if (1) { float *rawindices=bitReversedIndices(logN); static float indices<N>; streamRead(indices,rawindices); free(rawindices); myGather(s_out,indices,s); }// printf("The scrambled FFT of the given data is\n");// streamPrint(s); stop = GetTimeMillis(); printf("\nThe FFT of the given data is\n"); if (debug_fft) streamPrint(s_out); printf ("Done in %f seconds\n",(float)(stop-start)); printf("\n"); // scanf("%d",&N); } // end of else construct return ;}#define PrintResults(sstart, sstop, name) \ printf("%9d %6d %6.2f (* %s *)\\n", \ M, (int) (sstop - sstart), \ (float)(DOTIMES*10.*(logN+logM)*M*N) / (float) (sstop - sstart), name);void doFFTW (int logN) { int logM = logN; int N = (1<<logN); int M = (1<<logM); // float2 * output=0; // compute2dFFT(data,output,logN,logM,N,M,FFT_2D); FftwTransform2d(data,N,M,1,1); PrintResults(start,stop,"FFTW");}
⌨️ 快捷键说明
复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?