⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 算法(strassen和strassen混合算法).cpp

📁 strassen算法的扩展
💻 CPP
字号:
//作者:建麟    email blacken1008@163.com
#include<iostream.h>
#include<math.h>
#include<stdlib.h>
#include<stdio.h>
#include<iomanip.h>

//*****************declarations of the functions defined below
void Allot(float **&temp,int k); //allocate a k*k float matrix (k rows of k floats) and store it in temp
void Output(float **a,int n);//print an n*n matrix to cout
float **matrixSub(float **b,float **c,int n);//return a new n*n matrix holding b-c
void MakeMAndK(int n);//compute the m and k of n=m*2^k and store them in the globals m and k
void Separate(float **x,float ***A,int n);//split an n*n matrix x into four (n/2)*(n/2) quadrants A[0..3]
float **Strassen(float **a,float **b,int n);//Strassen algorithm for multiplying two n*n matrices
float **matrixAdd(float **b,float **c,int n);//return a new n*n matrix holding b+c
float **matrixSepMul(float **b,float **c,int k,int n);//blocked multiplication: cut into a k*k grid of blocks and multiply the blocks with Strassen (used when m is not 1)
float **matrixMul(float **b,float **c,int n);//classical triple-loop multiplication of two n*n matrices

int k,m;             //global variables holding the k and m of n=m*2^k (written by MakeMAndK)
//*****************the main method
//Demo driver: squares the 8*8 Hilbert matrix (a[j][i]=1/(i+j+1)) with Strassen(),
//then squares the 12*12 Hilbert matrix with matrixSepMul(), printing both results.
//The matrices are not freed; the process exits right afterwards.
int main()
{
	int i,j;
	float **a;
	float **b;
	//allocate the two input matrices
	Allot(a,8);
	Allot(b,12);
	//initialize the matrix n=8
	for(j=0;j<8;j++)
		for(i=0;i<8;i++)
		{
			a[j][i]=1/(float)(i+j+1);
		}
	cout<<"8*8 matrix multiplication Result:"<<endl;

	//8 is a power of two, so Strassen() can be used directly
	Output(Strassen(a,a,8),8);

	//initialize the matrix n=12
	for(j=0;j<12;j++)
		for(i=0;i<12;i++)
		{
			b[j][i]=1/(float)(i+j+1);
		}
	//12=3*2^2: MakeMAndK stores m=3 and k=2 in the globals
	MakeMAndK(12);
	cout<<"12*12 matrix multiplication Result:"<<endl;

	//12 is not a power of two: split it into an m*m grid of (12/m)*(12/m) blocks
	Output(matrixSepMul(b,b,m,12),12);
#ifdef _WIN32
	//"pause" is a cmd.exe built-in; other platforms have no such command
	system("pause");
#endif
	return 0;
}
//Classical triple-loop product of two n*n matrices.
//Returns a newly allocated (via Allot) n*n matrix r with r[i][j] = sum over t of
//b[i][t]*c[t][j], summed in increasing t order; b and c are only read.
float **matrixMul(float **b,float **c,int n)
{
	float **prod;
	int row,col,t;
	Allot(prod,n);
	for(row=0;row<n;++row)
	{
		for(col=0;col<n;++col)
		{
			//each cell is reset and then accumulated in place
			prod[row][col]=0;
			for(t=0;t<n;++t)
				prod[row][col]+=b[row][t]*c[t][col];
		}
	}
	return prod;
}
float **matrixSepMul(float **b,float **c,int k,int n)  
{
	int i,j,x,y;
//initialize the varible
	float ****B,****C,****A,**a;
	Allot(a,n);
	B=new float***[k];
	for(i=0;i<k;i++)
		B[i]=new float**[k];
	for(j=0;j<k;j++)
		for(i=0;i<k;i++)
			Allot(B[j][i],n/k);
	C=new float***[k];
	for(i=0;i<k;i++)
		C[i]=new float**[k];
	for(j=0;j<k;j++)
		for(i=0;i<k;i++)
			Allot(C[j][i],n/k);
	A=new float***[k];
	for(i=0;i<k;i++)
		A[i]=new float**[k];
	for(j=0;j<k;j++)
		for(i=0;i<k;i++)
			Allot(A[j][i],n/k);
	for(i=0;i<k;i++)                         
		for(j=0;j<k;j++)
			for(x=0;x<n/k;x++)
				for(y=0;y<n/k;y++)
					A[i][j][x][y]=0;
//separate b into with 2^k pieces in B,and the same to c
	for(i=0;i<k;i++)                         
		for(j=0;j<k;j++)
			for(x=0;x<n/k;x++)
				for(y=0;y<n/k;y++)
				{
					B[i][j][x][y]=b[i*n/k+x][j*n/k+y];
				}   
	for(i=0;i<k;i++)                         
		for(j=0;j<k;j++)
			for(x=0;x<n/k;x++)
				for(y=0;y<n/k;y++)
				{
					C[i][j][x][y]=c[i*n/k+x][j*n/k+y];
				}   
//each piece use the strassen algorithm with traditional algorithm
	 for(i=0;i<k;i++)
        for(j=0;j<k;j++)
		{
			for(x=0;x<k;x++)
			A[i][j]=matrixAdd(A[i][j],Strassen(B[i][x],C[x][j],n/k),n/k);
		}
//combination of pieces A into a 
	for(i=0;i<k;i++)                         
		for(j=0;j<k;j++)
			for(x=0;x<n/k;x++)
				for(y=0;y<n/k;y++)
				{
					a[i*n/k+x][j*n/k+y]=A[i][j][x][y];
				}   
	return a;
}
//Element-wise difference of two n*n matrices.
//Returns a newly allocated n*n matrix d with d[i][j] = b[i][j] - c[i][j];
//b and c are left unchanged.
float **matrixSub(float **b,float **c,int n)
{
	float **diff=new float*[n];
	int r=0;
	while(r<n)
	{
		diff[r]=new float[n];
		++r;
	}
	for(r=0;r<n;++r)
	{
		float *out=diff[r];
		const float *lhs=b[r];
		const float *rhs=c[r];
		int col;
		for(col=0;col<n;++col)
			out[col]=lhs[col]-rhs[col];
	}
	return diff;
}
//Element-wise sum of two n*n matrices.
//Returns a newly allocated n*n matrix s with s[i][j] = b[i][j] + c[i][j];
//b and c are left unchanged.
float **matrixAdd(float **b,float **c,int n)
{
	float **sum=new float*[n];
	for(int row=0;row<n;++row)
	{
		sum[row]=new float[n];
		for(int col=0;col<n;++col)
			sum[row][col]=b[row][col]+c[row][col];
	}
	return sum;
}

//Factors n as n = m*2^k (m odd) and stores the results in the globals m and k.
//Example: n=12 gives m=3, k=2.
//n==0 is handled up front: 0%2==0 always holds, so the halving loop would never
//terminate. It yields k=0, m=0.
void MakeMAndK(int n)
{
	k=0;
	m=0;
	if(n==0)
		return;
	while(n%2==0)
	{
		n=n/2;
		k++;
	}
	m=n;
}
//Allocates a k*k float matrix: temp receives a new[]-allocated array of k row
//pointers, each pointing to its own new[]-allocated row of k floats.
//The elements are left uninitialised.
void Allot(float **&temp,int k)
{
	temp=new float*[k];
	for(float **cursor=temp;cursor!=temp+k;++cursor)
		*cursor=new float[k];
}
//Splits the n*n matrix x into its four (n/2)*(n/2) quadrants:
//  A[0] = X11 (upper-left)    A[1] = X12 (upper-right)
//  A[2] = X21 (lower-left)    A[3] = X22 (lower-right)
//Each A[q] must already point to an (n/2)*(n/2) matrix (e.g. from Allot); x is only read.
//n is expected to be even; for odd n the last row and column of x are not copied,
//so nothing is written outside the (n/2)*(n/2) quadrants.
void Separate(float **x,float ***A,int n)
{
	int i,j,h=n/2;
	for(i=0;i<h;i++)
		for(j=0;j<h;j++)
		{
			A[0][i][j]=x[i][j];
			//upper-right quadrant: same rows, columns shifted by h
			A[1][i][j]=x[i][j+h];
			//lower-left quadrant: rows shifted by h, same columns
			A[2][i][j]=x[i+h][j];
			A[3][i][j]=x[i+h][j+h];
		}
}
//Frees a square matrix made of `size` new[]-allocated rows plus its new[]-allocated
//row array (the layout produced by Allot, matrixMul, matrixAdd and matrixSub).
static void StrassenRelease(float **mat,int size)
{
	int r;
	for(r=0;r<size;r++)
		delete[] mat[r];
	delete[] mat;
}

//Returns a newly allocated h*h copy of the block of src whose top-left corner
//is at (row0,col0).
static float **StrassenBlock(float **src,int row0,int col0,int h)
{
	int i,j;
	float **blk;
	Allot(blk,h);
	for(i=0;i<h;i++)
		for(j=0;j<h;j++)
			blk[i][j]=src[row0+i][col0+j];
	return blk;
}

//Multiplies two n*n matrices with Strassen's algorithm: 7 recursive products of
//(n/2)*(n/2) quadrants instead of the classical 8.
//Quadrants: X11 upper-left, X12 upper-right, X21 lower-left, X22 lower-right.
//  P1=A11(B12-B22)         P2=(A11+A12)B22         P3=(A21+A22)B11
//  P4=A22(B21-B11)         P5=(A11+A22)(B11+B22)
//  P6=(A12-A22)(B21+B22)   P7=(A11-A21)(B11+B12)
//  C11=P5+P4-P2+P6   C12=P1+P2   C21=P3+P4   C22=P5+P1-P3-P7
//Sizes of 2 or less, and odd sizes (which cannot be halved exactly), use the
//classical matrixMul() instead.
//a and b are only read; returns a newly allocated n*n matrix owned by the caller.
float **Strassen(float **a,float **b,int n)
{
	int i,j,h;
	float **c,**t1,**t2;
	float **a11,**a12,**a21,**a22,**b11,**b12,**b21,**b22;
	float **p1,**p2,**p3,**p4,**p5,**p6,**p7;
	if(n<=2||n%2!=0)
		return matrixMul(a,b,n);
	h=n/2;
	//quadrants are copied with explicit offsets so their naming is unambiguous
	a11=StrassenBlock(a,0,0,h);
	a12=StrassenBlock(a,0,h,h);
	a21=StrassenBlock(a,h,0,h);
	a22=StrassenBlock(a,h,h,h);
	b11=StrassenBlock(b,0,0,h);
	b12=StrassenBlock(b,0,h,h);
	b21=StrassenBlock(b,h,0,h);
	b22=StrassenBlock(b,h,h,h);
	//recursion: each temporary sum/difference is released right after its product
	t1=matrixSub(b12,b22,h);
	p1=Strassen(a11,t1,h);
	StrassenRelease(t1,h);
	t1=matrixAdd(a11,a12,h);
	p2=Strassen(t1,b22,h);
	StrassenRelease(t1,h);
	t1=matrixAdd(a21,a22,h);
	p3=Strassen(t1,b11,h);
	StrassenRelease(t1,h);
	t1=matrixSub(b21,b11,h);
	p4=Strassen(a22,t1,h);
	StrassenRelease(t1,h);
	t1=matrixAdd(a11,a22,h);
	t2=matrixAdd(b11,b22,h);
	p5=Strassen(t1,t2,h);
	StrassenRelease(t1,h);
	StrassenRelease(t2,h);
	t1=matrixSub(a12,a22,h);
	//B21+B22: both operands come from b
	t2=matrixAdd(b21,b22,h);
	p6=Strassen(t1,t2,h);
	StrassenRelease(t1,h);
	StrassenRelease(t2,h);
	t1=matrixSub(a11,a21,h);
	t2=matrixAdd(b11,b12,h);
	p7=Strassen(t1,t2,h);
	StrassenRelease(t1,h);
	StrassenRelease(t2,h);
	//combination of result quadrants
	Allot(c,n);
	for(i=0;i<h;i++)
		for(j=0;j<h;j++)
		{
			c[i][j]=p5[i][j]+p4[i][j]-p2[i][j]+p6[i][j];
			c[i][j+h]=p1[i][j]+p2[i][j];
			c[i+h][j]=p3[i][j]+p4[i][j];
			c[i+h][j+h]=p5[i][j]+p1[i][j]-p3[i][j]-p7[i][j];
		}
	{
		float **scratch[15]={a11,a12,a21,a22,b11,b12,b21,b22,p1,p2,p3,p4,p5,p6,p7};
		for(i=0;i<15;i++)
			StrassenRelease(scratch[i],h);
	}
	return c;
}
//Prints the n*n matrix a to cout, one row per line. Each element is written in
//fixed-point notation with 6 decimals, padded to a minimum width of 7 and
//followed by a space. Leaves cout in fixed-point mode.
void Output(float **a,int n)
{
	int row=0;
	cout.setf(ios::fixed);
	while(row<n)
	{
		float *line=a[row];
		int col;
		for(col=0;col<n;++col)
		{
			cout<<setw(7)<<setprecision(6)<<line[col]<<' ';
		}
		cout<<'\n';
		++row;
	}
	cout.flush();
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -