📄 strassen.cpp
字号:
/*作者:朱友勇(江南大学),于2006.10.2晚完成!花了我两个晚上的时间编程测试!完全正确,放心使用!*/
/*传统方法与Strassen方法相结合的矩阵相乘算法,可以求出任意两个偶数阶矩阵的乘积!*/
#include<iostream.h>
#include<math.h>
#define M 20
#define N 100
int kk,mm;
//全局变量,任一个偶数n可以表示为n=mm*pow(2,kk),其中mm为奇数。
typedef struct matri
{
int m[32][32];
}matrix;
matrix S[M][M],T[M][M],cc[M][M];
int last[N][N]; //用于存储两个偶数阶矩阵相乘的结果,也就是要求的答案
void MakeMAndK(int n) //计算mm和kk的值
{
kk=0;
mm=0;
while(n%2==0)
{
n=n/2;
kk++;
}
mm=n;
}
void DivideMM(matrix A,matrix B)
/*将一个大矩阵拆分成mm*mm个小矩阵的函数*/
{
int i,j,ii,jj,dis;
dis=pow(2,kk);
for(ii=1;ii<=mm;ii++)
for(jj=1;jj<=mm;jj++)
for(i=1;i<=dis;i++)
for(j=1;j<=dis;j++)
{
S[ii][jj].m[i][j]=A.m[i+(ii-1)*dis][j+(jj-1)*dis];
T[ii][jj].m[i][j]=B.m[i+(ii-1)*dis][j+(jj-1)*dis];
cc[ii][jj].m[i][j]=0;
}
}
void Divide(matrix &d,matrix &d11,matrix &d12,matrix &d21,matrix &d22,int n)
/*将一个n阶矩阵拆分成四个小矩阵的函数.其中n必须是2的幂*/
{
int i,j;
for(i=1;i<=n;i++)
for(j=1;j<=n;j++)
{
d11.m[i][j]=d.m[i][j];
d12.m[i][j]=d.m[i][j+n];
d21.m[i][j]=d.m[i+n][j];
d22.m[i][j]=d.m[i+n][j+n];
}
}
matrix Merge(matrix a11,matrix a12,matrix a21,matrix a22,int n)
/*将四个小矩阵合并成一个大矩阵的函数*/
{
int i,j;
matrix a;
for(i=1;i<=n;i++)
for(j=1;j<=n;j++)
{
a.m[i][j]=a11.m[i][j];
a.m[i][j+n]=a12.m[i][j];
a.m[i+n][j]=a21.m[i][j];
a.m[i+n][j+n]=a22.m[i][j];
}
return a;
}
matrix AdhocMatrixMultiply(matrix x,matrix y)
/*阶数为2的矩阵乘法函数*/
{
int m1,m2,m3,m4,m5,m6,m7;
matrix z;
m1=(x.m[1][1]+x.m[1][2])*y.m[1][1];
m2=x.m[1][2]*(y.m[2][1]-y.m[1][1]);
m3=(x.m[2][1]+x.m[2][2])*y.m[2][2];
m4=x.m[2][1]*(y.m[1][2]-y.m[2][2]);
m5=(x.m[1][2]+x.m[2][1])*(y.m[1][1]+y.m[2][2]);
m6=(x.m[2][1]-x.m[1][1])*(y.m[1][1]+y.m[1][2]);
m7=(x.m[1][2]-x.m[2][2])*(y.m[2][2]+y.m[2][1]);
z.m[1][1]=m1+m2;
z.m[1][2]=m5-m1+m4-m6;
z.m[2][1]=m5-m3+m2-m7;
z.m[2][2]=m3+m4;
return z;
}
matrix MatrixPlus(matrix f,matrix g,int n) /*矩阵加法函数*/
{
int i,j;
matrix h;
for(i=1;i<=n;i++)
for(j=1;j<=n;j++)
h.m[i][j]=f.m[i][j]+g.m[i][j];
return h;
}
matrix MatrixMinus(matrix f,matrix g,int n) /*矩阵减法函数*/
{
int i,j;
matrix h;
for(i=1;i<=n;i++)
for(j=1;j<=n;j++)
h.m[i][j]=f.m[i][j]-g.m[i][j];
return h;
}
matrix MatrixMultiply(matrix a,matrix b,int n)
/*必须是2的幂阶矩阵乘法函数*/
{
int k;
matrix a11,a12,a21,a22;
matrix b11,b12,b21,b22;
matrix c11,c12,c21,c22,c;
matrix m1,m2,m3,m4,m5,m6,m7;
k=n;
if(k==2)
{
c=AdhocMatrixMultiply(a,b);
return c;
}
else
{
k=n/2;
Divide(a,a11,a12,a21,a22,k); //拆分A、B、C矩阵
Divide(b,b11,b12,b21,b22,k);
Divide(c,c11,c12,c21,c22,k);
m1=MatrixMultiply(MatrixPlus(a11,a12,n/2),b11,k);
m2=MatrixMultiply(a12,MatrixMinus(b21,b11,k),k);
m3=MatrixMultiply(MatrixPlus(a21,a22,k),b22,k);
m4=MatrixMultiply(a21,MatrixMinus(b12,b22,k),k);
m5=MatrixMultiply(MatrixPlus(a12,a21,k),MatrixPlus(b11,b22,k),k);
m6=MatrixMultiply(MatrixMinus(a21,a11,k),MatrixPlus(b11,b12,k),k);
m7=MatrixMultiply(MatrixMinus(a12,a22,k),MatrixPlus(b22,b21,k),k);
c11=MatrixPlus(m1,m2,k);
c12=MatrixPlus(MatrixMinus(m5,m1,k),MatrixMinus(m4,m6,k),k);
c21=MatrixPlus(MatrixMinus(m5,m3,k),MatrixMinus(m2,m7,k),k);
c22=MatrixPlus(m3,m4,k);
c=Merge(c11,c12,c21,c22,k); //合并C矩阵
return c;
}
}
void Multiply()
/*矩阵相乘的传统方法*/
{
int i,j,k,dis;
dis=pow(2,kk);
for(i=1;i<=mm;i++)
for(j=1;j<=mm;j++)
for(k=1;k<=mm;k++)
cc[i][j]=MatrixPlus(cc[i][j],MatrixMultiply(S[i][k],T[k][j],dis),dis);
}
void MergeN()
/*合并成n*n的大矩阵,最后要求的结果*/
{
int i,j,ii,jj,dis;
dis=pow(2,kk);
for(ii=1;ii<=mm;ii++)
for(jj=1;jj<=mm;jj++)
for(i=1;i<=dis;i++)
for(j=1;j<=dis;j++)
last[i+(ii-1)*dis][j+(jj-1)*dis]=cc[ii][jj].m[i][j];
}
void main()
{
int i,j,n;
matrix A,B;
cout<<"请输入矩阵的阶数(可以是任意的偶数):";
cin>>n;
if((n%2==0)&&n>=2)
{
cout<<"现在录入矩阵A:"<<endl;
for(i=1;i<=n;i++)
{
cout<<"请输入第"<<i<<"行"<<endl;
for(j=1;j<=n;j++)
cin>>A.m[i][j];
}
cout<<"现在录入矩阵B:"<<endl;
for(i=1;i<=n;i++)
{
cout<<"请输入第"<<i<<"行"<<endl;
for(j=1;j<=n;j++)
cin>>B.m[i][j];
}
MakeMAndK(n);
DivideMM(A,B);
Multiply();
MergeN();
cout<<"结果为:"<<endl;
for(i=1;i<=n;i++)
for(j=1;j<=n;j++)
{
cout<<last[i][j]<<'\t';
if(j==n)
cout<<endl;
}
}
else
cout<<"n值不合法!"<<endl;
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -