⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 hemingyao.cpp

📁 算法与设计中实现Stassen矩阵相乘算法
💻 CPP
字号:
#include <stdio.h> 
#include<time.h>
#include<cstdlib>
#define mytype int//矩阵元素的数据类型 
#define myinputmode "%d"//矩阵元素的输入格式 
#define myprintmode "%4d"//矩阵元素的输出格式 


/**************************************** 
函数名:is2 
参数:m为长整型整数 
功能:检测m是否是2的正整数次幂 
返回值:返回布尔型变量 
true则表示m为2的正整数次幂 
false则表示m不是2的正整数次幂 
****************************************/ 
bool is2(long m) 
{ 
if(m<0)return false; 
if(m>=2) 
{ 
if((m%2)==0) return is2(m/2); 
else return false; 
} 
else 
{ 
if(m==1)return true; 
else return false; 
} 
return false; 
} 

///////////////////////////////////////// 
/**************************************** 
函数名:printmatrix 
参数:M为指向数组的指针,数组中存储着矩阵 
m长整型,是数组M所存矩阵的维数 
name字符型数组,是需要进行数据输入的矩阵的名字 
功能:矩阵数据输出显示的函数,将矩阵元素一一显示一在屏幕上 
返回值:无 
****************************************/ 
void printmatrix(mytype * M,long m,char *name) 
{ 
long i,j; 
printf("\nMatrix %s:\n",name); 
for(i=0;i<m;i++) 
{ 
for(j=0;j<m;j++) 
{ 
printf(myprintmode,M[i*m+j]); 
} 
printf("\n"); 
} 
} 
///////////////////////////////////////// 
/**************************************** 
函数名:Matrix_add_sub 
参数:A,B为指向数组的指针,数组中存储着矩阵 
C为指向数组的指针,用来存储运算结果 
m长整型,是数组A、B、C所存矩阵的维数 
add为布尔型变量,为true则C=A+B,为false则C=A-B 
功能:根据add值对A、B进行加减运算并将结果存入C 
返回值:无 
****************************************/ 
void Matrix_add_sub(mytype * A,mytype * B,mytype * C,long m,bool add) 
{ 
long i; 
for(i=0;i<m*m;i++) 
{ 
if(add) 
C[i]=A[i]+B[i]; 
else 
C[i]=A[i]-B[i]; 
} 
} 
///////////////////////////////////////// 
/**************************************** 
函数名:GetHalfValue 
参数:B为指向数组的指针,数组中存储着矩阵。其中B是指向m维矩阵中的一个元素。 
A为指向数组的指针,用来接收B中的四分之一数据 
m长整型,是数组B所指矩阵的维数 
功能:从B所在位置向左和向右取矩阵的m/2维的子矩阵(子矩阵中包括B所指元素)并存入A 
返回值:无 
****************************************/ 
void GetHalfValue(mytype * A,mytype * B,long m) 
{ 
long i,j; 
for(i=0;i<m/2;i++) 
{ 
for(j=0;j<m/2;j++) 
{ 
A[i*m/2+j]=B[i*m+j]; 
} 
} 
} 
///////////////////////////////////////// 
/**************************************** 
函数名:UpdateHalfValue 
参数:B为指向数组的指针,数组中存储着矩阵。其中B是指向m维矩阵中的一个元素。 
A为指向数组的指针,存储着一个m/2维矩阵 
m长整型,是数组B所指矩阵的维数 
功能:把A矩阵所有元素存入从B所在位置向左和向右的m/2维的子矩阵(子矩阵中包括B所指元素) 
返回值:无 
****************************************/ 
void UpdateHalfValue(mytype * A,mytype * B,long m) 
{ 
long i,j; 
for(i=0;i<m/2;i++) 
{ 
for(j=0;j<m/2;j++) 
{ 
B[i*m+j]=A[i*m/2+j]; 
} 
} 
} 
///////////////////////////////////////// 
/**************************************** 
函数名:Matrix_multiplication 
参数:A,B为指向数组的指针,数组中存储着矩阵。 
C为指向数组的指针,用来存储计算结果 
m长整型,是指针A、B所指矩阵的维数 
功能:用分而治之算法和Strassen方法计算A与B的乘积并存入C 
返回值:无 
****************************************/ 
void Matrix_multiplication(mytype * A,mytype * B,mytype * C,long m) 
{ 
if(m>2)//当矩阵维数大于2时 
{ 
//将矩阵A、B分为四个小矩阵,分别为A1、A2、A3、A4、B1、B2、B3、B4 
mytype *A1=new mytype[m*m/4],*A2=new mytype[m*m/4],*A3=new mytype[m*m/4],*A4=new mytype[m*m/4],*B1=new mytype[m*m/4],*B2=new mytype[m*m/4],*B3=new mytype[m*m/4],*B4=new mytype[m*m/4],*C1=new mytype[m*m/4],*C2=new mytype[m*m/4],*C3=new mytype[m*m/4],*C4=new mytype[m*m/4]; 
GetHalfValue(A1,&A[0],m); 
GetHalfValue(A2,&A[m/2],m); 
GetHalfValue(A3,&A[m*m/2],m); 
GetHalfValue(A4,&A[m*m/2+m/2],m); 
GetHalfValue(B1,&B[0],m); 
GetHalfValue(B2,&B[m/2],m); 
GetHalfValue(B3,&B[m*m/2],m); 
GetHalfValue(B4,&B[m*m/2+m/2],m); 
//利用Strassen方法计算D、E、F、G、H、I、J 
mytype *D=new mytype[m*m/4],*E=new mytype[m*m/4],*F=new mytype[m*m/4],*G=new mytype[m*m/4],*H=new mytype[m*m/4],*I=new mytype[m*m/4],*J=new mytype[m*m/4]; 
mytype *temp1=new mytype[m*m/4],*temp2=new mytype[m*m/4]; 
//D=A1(B2-B4) 
Matrix_add_sub(B2,B4,temp1,m/2,false); 
Matrix_multiplication(A1,temp1,D,m/2); 
//E=A4(B3-B1) 
Matrix_add_sub(B3,B1,temp1,m/2,false); 
Matrix_multiplication(A4,temp1,E,m/2); 
//F=(A3+A4)B1 
Matrix_add_sub(A3,A4,temp1,m/2,true); 
Matrix_multiplication(temp1,B1,F,m/2); 
//G=(A1+A2)B4 
Matrix_add_sub(A1,A2,temp1,m/2,true); 
Matrix_multiplication(temp1,B4,G,m/2); 
//H=(A3-A1)(B1+B2) 
Matrix_add_sub(A3,A1,temp1,m/2,false); 
Matrix_add_sub(B1,B2,temp2,m/2,true); 
Matrix_multiplication(temp1,temp2,H,m/2); 
//I=(A2-A4)(B3+B4) 
Matrix_add_sub(A2,A4,temp1,m/2,false); 
Matrix_add_sub(B3,B4,temp2,m/2,true); 
Matrix_multiplication(temp1,temp2,I,m/2); 
//J=(A1+A4)(B1+B4) 
Matrix_add_sub(A1,A4,temp1,m/2,true); 
Matrix_add_sub(B1,B4,temp2,m/2,true); 
Matrix_multiplication(temp1,temp2,J,m/2); 
//利用Strassen方法计算C1、C2、C3、C4 
//C1=E+I+J-G 
Matrix_add_sub(E,I,temp1,m/2,true); 
Matrix_add_sub(J,G,temp2,m/2,false); 
Matrix_add_sub(temp1,temp2,C1,m/2,true); 
//C2=D+G 
Matrix_add_sub(D,G,C2,m/2,true); 
//C3=E+F 
Matrix_add_sub(E,F,C3,m/2,true); 
//C4=D+H+J-F 
Matrix_add_sub(D,H,temp1,m/2,true); 
Matrix_add_sub(J,F,temp2,m/2,false); 
Matrix_add_sub(temp1,temp2,C4,m/2,true); 
//将计算结果存入数组C 
UpdateHalfValue(C1,&C[0],m); 
UpdateHalfValue(C2,&C[m/2],m); 
UpdateHalfValue(C3,&C[m*m/2],m); 
UpdateHalfValue(C4,&C[m*m/2+m/2],m); 
//释放内存 
delete[] A1,A2,A3,A4,B1,B2,B3,B4,C1,C2,C3,C4,D,E,F,G,H,I,J,temp1,temp2; 
} 
else 
{ 
//当矩阵维数小于2时用Strassen方法计算矩阵乘积 
mytype D,E,F,G,H,I,J; 
//D=A1(B2-B4) 
D=A[0]*(B[1]-B[3]); 
//E=A4(B3-B1) 
E=A[3]*(B[2]-B[0]); 
//F=(A3+A4)B1 
F=(A[2]+A[3])*B[0]; 
//G=(A1+A2)B4 
G=(A[0]+A[1])*B[3]; 
//H=(A3-A1)(B1+B2) 
H=(A[2]-A[0])*(B[0]+B[1]); 
//I=(A2-A4)(B3+B4) 
I=(A[1]-A[3])*(B[2]+B[3]); 
//J=(A1+A4)(B1+B4) 
J=(A[0]+A[3])*(B[0]+B[3]); 
//C1=E+I+J-G 
C[0]=E+I+J-G; 
//C2=D+G 
C[1]=D+G; 
//C3=E+F 
C[2]=E+F; 
//C4=D+H+J-F 
C[3]=D+H+J-F; 
} 
} 
///////////////////////////////////////// 
/*直接计算*/

void Matrix_multiplication1(mytype * A,mytype * B,mytype * C,long m) 
{	
	int i,j;
	i=0;
	while(i<m*m)
		C[i++]=0;
	i=0;
	while(i<m)
	{
		j=0;
		while(j<m)
		{
			int k=0;
			while(k<m)
			{
				C[i*m+j]+=A[i*m+k]*B[k*m+j];
				k++;
			}
			++j;
		}
		++i;
	}

}
/////////////////////////////////////////////////////
int main() 
{ 
long n,k,start1,finish1,start2,finish2; 
int i,j;
//提示输入n维矩阵的维数 
printf("Please input the dimension of the Matrix.(n):"); 
//获得用户输入的n维矩阵维数 
scanf("%d",&k); 
n=k;

	while(n&(n-1))		//检查维数是否是2的幂,不是则转换为2的幂
		n++;
						 
mytype *A=new mytype[n*n]; //开辟空间存储用来存储n维矩阵元素
mytype *B=new mytype[n*n]; 
mytype *C=new mytype[n*n]; 
srand((unsigned)time(NULL));
for(i=0;i<k;++i)
for(j=0;j<k;++j)
{	*(A+i*n+j)=mytype(rand()%10);
	*(B+i*n+j)=mytype(rand()%10);
}

for(i=0;i<k;++i)	//将不是2的幂的矩阵转换为2的幂的矩阵,补充的地方用0
for(j=k;j<n;++j)
{	*(A+i*n+j)=0;
	*(B+i*n+j)=0;}


for(i=k;i<n;++i)	
for(j=0;j<n;++j)
{	*(A+i*n+j)=0;
	*(B+i*n+j)=0;}
start1=clock();
if(n>1)//矩阵维数大于1则用分而治之算法计算 
Matrix_multiplication(A,B,C,n); 
else//矩阵维数为1则直接计算 
*C=(*A)*(*B); 
finish1=clock();
//输出矩阵A、B、C 
printmatrix(A,n,"A"); 
printmatrix(B,n,"B"); 
printf("算法一的结果:\n");
printmatrix(C,n,"C"); 
printf("算法二的结果:\n");
start2=clock();
Matrix_multiplication1(A,B,C,n); 
finish2=clock();
printmatrix(C,n,"C"); 
printf("算法一所用时间:%d\n",finish1-start1);
printf("算法二所用时间:%d\n",finish2-start2);
//释放内存 
delete[] A,B,C; 
getchar();getchar(); 
return 1; 
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -