⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 strassen.cpp

📁 传统方法与Strassen算法相结合的矩阵相乘算法
💻 CPP
字号:

/*作者:朱友勇(江南大学),于2006.10.2晚完成!花了我两个晚上的时间编程测试!完全正确,放心使用!*/
/*传统方法与Strassen方法相结合的矩阵相乘算法,可以求出任意两个偶数阶矩阵的乘积!*/

#include<iostream.h>
#include<math.h>
#define M 20
#define N 100

int kk,mm;   
//全局变量,任一个偶数n可以表示为n=mm*pow(2,kk),其中mm为奇数。   
       
typedef struct matri
{
	int m[32][32];
}matrix;

matrix S[M][M],T[M][M],cc[M][M];
int last[N][N];   //用于存储两个偶数阶矩阵相乘的结果,也就是要求的答案

void MakeMAndK(int n)   //计算mm和kk的值
{
	kk=0;
	mm=0;
	while(n%2==0)
	{
		n=n/2;
		kk++;
	}
	mm=n;
}

void DivideMM(matrix A,matrix B)
/*将一个大矩阵拆分成mm*mm个小矩阵的函数*/
{
	int i,j,ii,jj,dis;
	dis=pow(2,kk);
    for(ii=1;ii<=mm;ii++)
		for(jj=1;jj<=mm;jj++)
			for(i=1;i<=dis;i++)
				for(j=1;j<=dis;j++)
				{
					S[ii][jj].m[i][j]=A.m[i+(ii-1)*dis][j+(jj-1)*dis];
                    T[ii][jj].m[i][j]=B.m[i+(ii-1)*dis][j+(jj-1)*dis];
                    cc[ii][jj].m[i][j]=0;
				}
}

void Divide(matrix &d,matrix &d11,matrix &d12,matrix &d21,matrix &d22,int n)
/*将一个n阶矩阵拆分成四个小矩阵的函数.其中n必须是2的幂*/
{
	int i,j;
    for(i=1;i<=n;i++)
		for(j=1;j<=n;j++)
		{
			d11.m[i][j]=d.m[i][j];
            d12.m[i][j]=d.m[i][j+n];
            d21.m[i][j]=d.m[i+n][j];
            d22.m[i][j]=d.m[i+n][j+n];
		}
}

matrix Merge(matrix a11,matrix a12,matrix a21,matrix a22,int n)
/*将四个小矩阵合并成一个大矩阵的函数*/
{
	int i,j;
    matrix a;
    for(i=1;i<=n;i++)
		for(j=1;j<=n;j++)
		{
			a.m[i][j]=a11.m[i][j];
            a.m[i][j+n]=a12.m[i][j];
            a.m[i+n][j]=a21.m[i][j];
            a.m[i+n][j+n]=a22.m[i][j];
		}
    return a;
}

matrix AdhocMatrixMultiply(matrix x,matrix y)
/*阶数为2的矩阵乘法函数*/
{
	int m1,m2,m3,m4,m5,m6,m7;
    matrix z;

    m1=(x.m[1][1]+x.m[1][2])*y.m[1][1];
    m2=x.m[1][2]*(y.m[2][1]-y.m[1][1]);
    m3=(x.m[2][1]+x.m[2][2])*y.m[2][2];
    m4=x.m[2][1]*(y.m[1][2]-y.m[2][2]);
    m5=(x.m[1][2]+x.m[2][1])*(y.m[1][1]+y.m[2][2]);
    m6=(x.m[2][1]-x.m[1][1])*(y.m[1][1]+y.m[1][2]);
    m7=(x.m[1][2]-x.m[2][2])*(y.m[2][2]+y.m[2][1]);
    z.m[1][1]=m1+m2;
    z.m[1][2]=m5-m1+m4-m6;
    z.m[2][1]=m5-m3+m2-m7;
    z.m[2][2]=m3+m4;

    return z;
}

matrix MatrixPlus(matrix f,matrix g,int n) /*矩阵加法函数*/
{
	int i,j;
    matrix h;
    for(i=1;i<=n;i++)
		for(j=1;j<=n;j++)
			h.m[i][j]=f.m[i][j]+g.m[i][j];
    return h;
}

matrix MatrixMinus(matrix f,matrix g,int n) /*矩阵减法函数*/
{
	int i,j;
    matrix h;
    for(i=1;i<=n;i++)
		for(j=1;j<=n;j++)
			h.m[i][j]=f.m[i][j]-g.m[i][j];
    return h;
}

matrix MatrixMultiply(matrix a,matrix b,int n) 
/*必须是2的幂阶矩阵乘法函数*/
{
	int k;
    matrix a11,a12,a21,a22;
    matrix b11,b12,b21,b22;
    matrix c11,c12,c21,c22,c;
    matrix m1,m2,m3,m4,m5,m6,m7;
    k=n;
    if(k==2)
	{
		c=AdhocMatrixMultiply(a,b);
        return c;
	}
    else
	{ 
		k=n/2;
        Divide(a,a11,a12,a21,a22,k); //拆分A、B、C矩阵
        Divide(b,b11,b12,b21,b22,k);
        Divide(c,c11,c12,c21,c22,k);

        m1=MatrixMultiply(MatrixPlus(a11,a12,n/2),b11,k);
        m2=MatrixMultiply(a12,MatrixMinus(b21,b11,k),k);
        m3=MatrixMultiply(MatrixPlus(a21,a22,k),b22,k);
        m4=MatrixMultiply(a21,MatrixMinus(b12,b22,k),k);
        m5=MatrixMultiply(MatrixPlus(a12,a21,k),MatrixPlus(b11,b22,k),k);
        m6=MatrixMultiply(MatrixMinus(a21,a11,k),MatrixPlus(b11,b12,k),k);
        m7=MatrixMultiply(MatrixMinus(a12,a22,k),MatrixPlus(b22,b21,k),k);
        c11=MatrixPlus(m1,m2,k);
        c12=MatrixPlus(MatrixMinus(m5,m1,k),MatrixMinus(m4,m6,k),k);
        c21=MatrixPlus(MatrixMinus(m5,m3,k),MatrixMinus(m2,m7,k),k);
        c22=MatrixPlus(m3,m4,k);

        c=Merge(c11,c12,c21,c22,k); //合并C矩阵
        return c;
	} 
}

void Multiply()
/*矩阵相乘的传统方法*/
{
	int i,j,k,dis;
    dis=pow(2,kk);
    for(i=1;i<=mm;i++)
		for(j=1;j<=mm;j++)
			for(k=1;k<=mm;k++)
				cc[i][j]=MatrixPlus(cc[i][j],MatrixMultiply(S[i][k],T[k][j],dis),dis);
}

void MergeN()
/*合并成n*n的大矩阵,最后要求的结果*/
{
	int i,j,ii,jj,dis;
	dis=pow(2,kk);
    for(ii=1;ii<=mm;ii++)
		for(jj=1;jj<=mm;jj++)
			for(i=1;i<=dis;i++)
				for(j=1;j<=dis;j++)
					last[i+(ii-1)*dis][j+(jj-1)*dis]=cc[ii][jj].m[i][j];
}

void main()
{
	int i,j,n;
    matrix A,B;
	cout<<"请输入矩阵的阶数(可以是任意的偶数):";
	cin>>n;
if((n%2==0)&&n>=2)
{
	cout<<"现在录入矩阵A:"<<endl;
	for(i=1;i<=n;i++)
    {
        cout<<"请输入第"<<i<<"行"<<endl;
        for(j=1;j<=n;j++)
            cin>>A.m[i][j];
    }
	cout<<"现在录入矩阵B:"<<endl;
	for(i=1;i<=n;i++)
    {
        cout<<"请输入第"<<i<<"行"<<endl;
        for(j=1;j<=n;j++)
            cin>>B.m[i][j];
    }
     MakeMAndK(n);
	 DivideMM(A,B);
	 Multiply();
	 MergeN();
	 cout<<"结果为:"<<endl;
	 for(i=1;i<=n;i++)
		for(j=1;j<=n;j++)
		{
			cout<<last[i][j]<<'\t';
			if(j==n)
				cout<<endl;
		}
}
else
cout<<"n值不合法!"<<endl;
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -