⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 strassen.cpp

📁 strassen矩阵算法以及效率分析
💻 CPP
字号:
#include<stdio.h>

struct matrix
{
float m[32][32];
};

void Divide(matrix &d,matrix &d11,matrix &d12,matrix &d21,matrix &d22,int n)
/*将一个大矩阵拆分成四个小矩阵的函数*/
{
int i,j;
for(i=1;i<=n;i++)
for(j=1;j<=n;j++)
{
d11.m[i][j]=d.m[i][j];
d12.m[i][j]=d.m[i][j+n];
d21.m[i][j]=d.m[i+n][j];
d22.m[i][j]=d.m[i+n][j+n];
}
}

matrix Merge(matrix a11,matrix a12,matrix a21,matrix a22,int n)
/*将四个小矩阵合并成一个大矩阵的函数*/
{
int i,j;
matrix a;
for(i=1;i<=n;i++)
for(j=1;j<=n;j++)
{
a.m[i][j]=a11.m[i][j];
a.m[i][j+n]=a12.m[i][j];
a.m[i+n][j]=a21.m[i][j];
a.m[i+n][j+n]=a22.m[i][j];
}
return a;
}

matrix AdhocMatrixMultiply(matrix x,matrix y) /*阶数为2的矩阵乘法函数*/
{
float m1,m2,m3,m4,m5,m6,m7;
matrix z;

m1=(x.m[1][1]+x.m[1][2])*y.m[1][1];
m2=x.m[1][2]*(y.m[2][1]-y.m[1][1]);
m3=(x.m[2][1]+x.m[2][2])*y.m[2][2];
m4=x.m[2][1]*(y.m[1][2]-y.m[2][2]);
m5=(x.m[1][2]+x.m[2][1])*(y.m[1][1]+y.m[2][2]);
m6=(x.m[2][1]-x.m[1][1])*(y.m[1][1]+y.m[1][2]);
m7=(x.m[1][2]-x.m[2][2])*(y.m[2][2]+y.m[2][1]);
z.m[1][1]=m1+m2;
z.m[1][2]=m5-m1+m4-m6;
z.m[2][1]=m5-m3+m2-m7;
z.m[2][2]=m3+m4;

return z;
}

matrix MatrixPlus(matrix f,matrix g,int n) /*矩阵加法函数*/
{
int i,j;
matrix h;
for(i=1;i<=n;i++)
for(j=1;j<=n;j++)
h.m[i][j]=f.m[i][j]+g.m[i][j];
return h;
}

matrix MatrixMinus(matrix f,matrix g,int n) /*矩阵减法函数*/
{
int i,j;
matrix h;
for(i=1;i<=n;i++)
for(j=1;j<=n;j++)
h.m[i][j]=f.m[i][j]-g.m[i][j];
return h;
}

matrix MatrixMultiply(matrix a,matrix b,int n) /*矩阵乘法函数*/
{
int k,t,i,j;
matrix a11,a12,a21,a22;
matrix b11,b12,b21,b22;
matrix c11,c12,c21,c22,c;
matrix m1,m2,m3,m4,m5,m6,m7;
k=n;
if(k==2)
{
c=AdhocMatrixMultiply(a,b);
return c;
}
else if(k==3)    //阶数为3时采用传统法
{
for(i=1;i<=n;i++)
    for(j=1;j<=n;j++)
	{
		c.m[i][j]=0;
		 for(t=1;t<=n;t++)
	{
		
		c.m[i][j]+=a.m[i][t]*b.m[t][j];
		 }
	}
		 return c;
}
else
{ 
k=n/2;
Divide(a,a11,a12,a21,a22,k); //拆分A、B、C矩阵
Divide(b,b11,b12,b21,b22,k);
Divide(c,c11,c12,c21,c22,k);

m1=MatrixMultiply(MatrixPlus(a11,a12,n/2),b11,k);
m2=MatrixMultiply(a12,MatrixMinus(b21,b11,k),k);
m3=MatrixMultiply(MatrixPlus(a21,a22,k),b22,k);
m4=MatrixMultiply(a21,MatrixMinus(b12,b22,k),k);
m5=MatrixMultiply(MatrixPlus(a12,a21,k),MatrixPlus(b11,b22,k),k);
m6=MatrixMultiply(MatrixMinus(a21,a11,k),MatrixPlus(b11,b12,k),k);
m7=MatrixMultiply(MatrixMinus(a12,a22,k),MatrixPlus(b22,b21,k),k);
c11=MatrixPlus(m1,m2,k);
c12=MatrixPlus(MatrixMinus(m5,m1,k),MatrixMinus(m4,m6,k),k);
c21=MatrixPlus(MatrixMinus(m5,m3,k),MatrixMinus(m2,m7,k),k);
c22=MatrixPlus(m3,m4,k);

c=Merge(c11,c12,c21,c22,k); //合并C矩阵
return c;
} 
}

void main()
{
int i,j,n;
matrix A,B,C={0};
while(n!=0)
{
printf("请输入矩阵的阶数N:\n");
scanf("%d",&n);
if(n==0) break;

printf("矩阵A:\n");
for(i=1;i<=n;i++)
for(j=1;j<=n;j++)
{A.m[i][j]=1.0/(i+j-1);        //此出可更改scanf("%f",&A.m[i][j]);用户可自行输入数据
printf("%8f%c",A.m[i][j],j==n?'\n':' ');}

printf("矩阵B:\n");
for(i=1;i<=n;i++)
for(j=1;j<=n;j++)
{B.m[i][j]=1.0/(i+j-1);        //此出可更改scanf("%f",&B.m[i][j]);用户可自行输入数据
printf("%8f%c",B.m[i][j],j==n?'\n':' ');}

if(n==1) C.m[1][1]=A.m[1][1]*B.m[1][1]; //矩阵阶数为1时的特殊处理 
else C=MatrixMultiply(A,B,n);

printf("矩阵C=矩阵A*矩阵B:\n");
for(i=1;i<=n;i++)
for(j=1;j<=n;j++)
printf("%8f%c",C.m[i][j],j==n?'\n':' ');

}
}


⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -