📄 strassen.cpp
字号:
#include<stdio.h>
struct matrix
{
float m[32][32];
};
void Divide(matrix &d,matrix &d11,matrix &d12,matrix &d21,matrix &d22,int n)
/*将一个大矩阵拆分成四个小矩阵的函数*/
{
int i,j;
for(i=1;i<=n;i++)
for(j=1;j<=n;j++)
{
d11.m[i][j]=d.m[i][j];
d12.m[i][j]=d.m[i][j+n];
d21.m[i][j]=d.m[i+n][j];
d22.m[i][j]=d.m[i+n][j+n];
}
}
matrix Merge(matrix a11,matrix a12,matrix a21,matrix a22,int n)
/*将四个小矩阵合并成一个大矩阵的函数*/
{
int i,j;
matrix a;
for(i=1;i<=n;i++)
for(j=1;j<=n;j++)
{
a.m[i][j]=a11.m[i][j];
a.m[i][j+n]=a12.m[i][j];
a.m[i+n][j]=a21.m[i][j];
a.m[i+n][j+n]=a22.m[i][j];
}
return a;
}
matrix AdhocMatrixMultiply(matrix x,matrix y) /*阶数为2的矩阵乘法函数*/
{
float m1,m2,m3,m4,m5,m6,m7;
matrix z;
m1=(x.m[1][1]+x.m[1][2])*y.m[1][1];
m2=x.m[1][2]*(y.m[2][1]-y.m[1][1]);
m3=(x.m[2][1]+x.m[2][2])*y.m[2][2];
m4=x.m[2][1]*(y.m[1][2]-y.m[2][2]);
m5=(x.m[1][2]+x.m[2][1])*(y.m[1][1]+y.m[2][2]);
m6=(x.m[2][1]-x.m[1][1])*(y.m[1][1]+y.m[1][2]);
m7=(x.m[1][2]-x.m[2][2])*(y.m[2][2]+y.m[2][1]);
z.m[1][1]=m1+m2;
z.m[1][2]=m5-m1+m4-m6;
z.m[2][1]=m5-m3+m2-m7;
z.m[2][2]=m3+m4;
return z;
}
matrix MatrixPlus(matrix f,matrix g,int n) /*矩阵加法函数*/
{
int i,j;
matrix h;
for(i=1;i<=n;i++)
for(j=1;j<=n;j++)
h.m[i][j]=f.m[i][j]+g.m[i][j];
return h;
}
matrix MatrixMinus(matrix f,matrix g,int n) /*矩阵减法函数*/
{
int i,j;
matrix h;
for(i=1;i<=n;i++)
for(j=1;j<=n;j++)
h.m[i][j]=f.m[i][j]-g.m[i][j];
return h;
}
matrix MatrixMultiply(matrix a,matrix b,int n) /*矩阵乘法函数*/
{
int k,t,i,j;
matrix a11,a12,a21,a22;
matrix b11,b12,b21,b22;
matrix c11,c12,c21,c22,c;
matrix m1,m2,m3,m4,m5,m6,m7;
k=n;
if(k==2)
{
c=AdhocMatrixMultiply(a,b);
return c;
}
else if(k==3) //阶数为3时采用传统法
{
for(i=1;i<=n;i++)
for(j=1;j<=n;j++)
{
c.m[i][j]=0;
for(t=1;t<=n;t++)
{
c.m[i][j]+=a.m[i][t]*b.m[t][j];
}
}
return c;
}
else
{
k=n/2;
Divide(a,a11,a12,a21,a22,k); //拆分A、B、C矩阵
Divide(b,b11,b12,b21,b22,k);
Divide(c,c11,c12,c21,c22,k);
m1=MatrixMultiply(MatrixPlus(a11,a12,n/2),b11,k);
m2=MatrixMultiply(a12,MatrixMinus(b21,b11,k),k);
m3=MatrixMultiply(MatrixPlus(a21,a22,k),b22,k);
m4=MatrixMultiply(a21,MatrixMinus(b12,b22,k),k);
m5=MatrixMultiply(MatrixPlus(a12,a21,k),MatrixPlus(b11,b22,k),k);
m6=MatrixMultiply(MatrixMinus(a21,a11,k),MatrixPlus(b11,b12,k),k);
m7=MatrixMultiply(MatrixMinus(a12,a22,k),MatrixPlus(b22,b21,k),k);
c11=MatrixPlus(m1,m2,k);
c12=MatrixPlus(MatrixMinus(m5,m1,k),MatrixMinus(m4,m6,k),k);
c21=MatrixPlus(MatrixMinus(m5,m3,k),MatrixMinus(m2,m7,k),k);
c22=MatrixPlus(m3,m4,k);
c=Merge(c11,c12,c21,c22,k); //合并C矩阵
return c;
}
}
void main()
{
int i,j,n;
matrix A,B,C={0};
while(n!=0)
{
printf("请输入矩阵的阶数N:\n");
scanf("%d",&n);
if(n==0) break;
printf("矩阵A:\n");
for(i=1;i<=n;i++)
for(j=1;j<=n;j++)
{A.m[i][j]=1.0/(i+j-1); //此出可更改scanf("%f",&A.m[i][j]);用户可自行输入数据
printf("%8f%c",A.m[i][j],j==n?'\n':' ');}
printf("矩阵B:\n");
for(i=1;i<=n;i++)
for(j=1;j<=n;j++)
{B.m[i][j]=1.0/(i+j-1); //此出可更改scanf("%f",&B.m[i][j]);用户可自行输入数据
printf("%8f%c",B.m[i][j],j==n?'\n':' ');}
if(n==1) C.m[1][1]=A.m[1][1]*B.m[1][1]; //矩阵阶数为1时的特殊处理
else C=MatrixMultiply(A,B,n);
printf("矩阵C=矩阵A*矩阵B:\n");
for(i=1;i<=n;i++)
for(j=1;j<=n;j++)
printf("%8f%c",C.m[i][j],j==n?'\n':' ');
}
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -