📄 strassen.cpp
字号:
#include<iostream.h>
#include<math.h>
#include <windows.h>
#define M 20
#define N 100
int kk,mm;
//全局变量,任一个偶数n可以表示为n=mm*pow(2,kk),其中mm为奇数。
typedef struct matri
{
int m[32][32];
}matrix;
matrix S[M][M],T[M][M],cc[M][M];
int last[N][N]; //用于存储两个偶数阶矩阵相乘的结果,也就是要求的答案
void MakeMAndK(int n) //计算mm和kk的值
{
kk=0;
mm=0;
while(n%2==0)
{
n=n/2;
kk++;
}
mm=n;
}
void DivideMM(matrix A,matrix B)
/*将一个大矩阵拆分成mm*mm个小矩阵的函数*/
{
int i,j,ii,jj,dis;
dis=pow(2,kk);
for(ii=1;ii<=mm;ii++)
for(jj=1;jj<=mm;jj++)
for(i=1;i<=dis;i++)
for(j=1;j<=dis;j++)
{
S[ii][jj].m[i][j]=A.m[i+(ii-1)*dis][j+(jj-1)*dis];
T[ii][jj].m[i][j]=B.m[i+(ii-1)*dis][j+(jj-1)*dis];
cc[ii][jj].m[i][j]=0;
}
}
void Divide(matrix &d,matrix &d11,matrix &d12,matrix &d21,matrix &d22,int n)
/*将一个n阶矩阵拆分成四个小矩阵的函数.其中n必须是2的幂*/
{
int i,j;
for(i=1;i<=n;i++)
for(j=1;j<=n;j++)
{
d11.m[i][j]=d.m[i][j];
d12.m[i][j]=d.m[i][j+n];
d21.m[i][j]=d.m[i+n][j];
d22.m[i][j]=d.m[i+n][j+n];
}
}
matrix Merge(matrix a11,matrix a12,matrix a21,matrix a22,int n)
/*将四个小矩阵合并成一个大矩阵的函数*/
{
int i,j;
matrix a;
for(i=1;i<=n;i++)
for(j=1;j<=n;j++)
{
a.m[i][j]=a11.m[i][j];
a.m[i][j+n]=a12.m[i][j];
a.m[i+n][j]=a21.m[i][j];
a.m[i+n][j+n]=a22.m[i][j];
}
return a;
}
matrix AdhocMatrixMultiply(matrix x,matrix y)
/*阶数为2的矩阵乘法函数*/
{
int m1,m2,m3,m4,m5,m6,m7;
matrix z;
m1=(x.m[1][1]+x.m[1][2])*y.m[1][1];
m2=x.m[1][2]*(y.m[2][1]-y.m[1][1]);
m3=(x.m[2][1]+x.m[2][2])*y.m[2][2];
m4=x.m[2][1]*(y.m[1][2]-y.m[2][2]);
m5=(x.m[1][2]+x.m[2][1])*(y.m[1][1]+y.m[2][2]);
m6=(x.m[2][1]-x.m[1][1])*(y.m[1][1]+y.m[1][2]);
m7=(x.m[1][2]-x.m[2][2])*(y.m[2][2]+y.m[2][1]);
z.m[1][1]=m1+m2;
z.m[1][2]=m5-m1+m4-m6;
z.m[2][1]=m5-m3+m2-m7;
z.m[2][2]=m3+m4;
return z;
}
matrix MatrixPlus(matrix f,matrix g,int n) /*矩阵加法函数*/
{
int i,j;
matrix h;
for(i=1;i<=n;i++)
for(j=1;j<=n;j++)
h.m[i][j]=f.m[i][j]+g.m[i][j];
return h;
}
matrix MatrixMinus(matrix f,matrix g,int n) /*矩阵减法函数*/
{
int i,j;
matrix h;
for(i=1;i<=n;i++)
for(j=1;j<=n;j++)
h.m[i][j]=f.m[i][j]-g.m[i][j];
return h;
}
matrix MatrixMultiply(matrix a,matrix b,int n)
/*必须是2的幂阶矩阵乘法函数*/
{
int k;
matrix a11,a12,a21,a22;
matrix b11,b12,b21,b22;
matrix c11,c12,c21,c22,c;
matrix m1,m2,m3,m4,m5,m6,m7;
k=n;
if(k==2)
{
c=AdhocMatrixMultiply(a,b);
return c;
}
else
{
k=n/2;
Divide(a,a11,a12,a21,a22,k); //拆分A、B、C矩阵
Divide(b,b11,b12,b21,b22,k);
Divide(c,c11,c12,c21,c22,k);
m1=MatrixMultiply(MatrixPlus(a11,a12,n/2),b11,k);
m2=MatrixMultiply(a12,MatrixMinus(b21,b11,k),k);
m3=MatrixMultiply(MatrixPlus(a21,a22,k),b22,k);
m4=MatrixMultiply(a21,MatrixMinus(b12,b22,k),k);
m5=MatrixMultiply(MatrixPlus(a12,a21,k),MatrixPlus(b11,b22,k),k);
m6=MatrixMultiply(MatrixMinus(a21,a11,k),MatrixPlus(b11,b12,k),k);
m7=MatrixMultiply(MatrixMinus(a12,a22,k),MatrixPlus(b22,b21,k),k);
c11=MatrixPlus(m1,m2,k);
c12=MatrixPlus(MatrixMinus(m5,m1,k),MatrixMinus(m4,m6,k),k);
c21=MatrixPlus(MatrixMinus(m5,m3,k),MatrixMinus(m2,m7,k),k);
c22=MatrixPlus(m3,m4,k);
c=Merge(c11,c12,c21,c22,k); //合并C矩阵
return c;
}
}
void Multiply()
/*矩阵相乘的传统方法*/
{
int i,j,k,dis;
dis=pow(2,kk);
for(i=1;i<=mm;i++)
for(j=1;j<=mm;j++)
for(k=1;k<=mm;k++)
cc[i][j]=MatrixPlus(cc[i][j],MatrixMultiply(S[i][k],T[k][j],dis),dis);
}
void MergeN()
/*合并成n*n的大矩阵,最后要求的结果*/
{
int i,j,ii,jj,dis;
dis=pow(2,kk);
for(ii=1;ii<=mm;ii++)
for(jj=1;jj<=mm;jj++)
for(i=1;i<=dis;i++)
for(j=1;j<=dis;j++)
last[i+(ii-1)*dis][j+(jj-1)*dis]=cc[ii][jj].m[i][j];
}
void main()
{
int i,j,n;
matrix A,B;
LARGE_INTEGER BegainTime ;
LARGE_INTEGER EndTime ;
LARGE_INTEGER Frequency ;
QueryPerformanceFrequency(&Frequency);
cout<<"请输入矩阵的阶数(可以是任意的偶数):";
cin>>n;
if((n%2==0)&&n>=2)
{
cout<<"现在随机生成矩阵A:"<<endl;
for(i=1;i<=n;i++)
{
// cout<<"请输入第"<<i<<"行"<<endl;
for(j=1;j<=n;j++)
A.m[i][j]=rand()%10;
//cin>>A.m[i][j];
}
for(j=1;j<=n;j++)
for(i=1;i<=n;i++)
{
cout<<A.m[i][j]<<" ";
if(i%n==0)
cout<<endl;
}
cout<<endl;
cout<<"现在随机生成矩阵B:"<<endl;
for(i=1;i<=n;i++)
{
//cout<<"请输入第"<<i<<"行"<<endl;
for(j=1;j<=n;j++)
B.m[i][j]=rand()%10;
//cin>>B.m[i][j];
}
for(j=1;j<=n;j++)
for(i=1;i<=n;i++)
{
cout<<B.m[i][j]<<" ";
if(i%n==0)
cout<<endl;
}
cout<<endl;
QueryPerformanceCounter(&BegainTime) ;
MakeMAndK(n);
DivideMM(A,B);
Multiply();
MergeN();
cout<<"计算结果为:"<<endl;
for(i=1;i<=n;i++)
for(j=1;j<=n;j++)
{
cout<<last[i][j]<<'\t';
if(j==n)
cout<<endl;
}
QueryPerformanceCounter(&EndTime);
cout << "运行时间(单位:s):" <<(double)( EndTime.QuadPart - BegainTime.QuadPart )/ Frequency.QuadPart <<endl;
system("pause");
}
else
cout<<"n值不合法!"<<endl;
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -