⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 strassen.cpp

📁 设计编程实现矩阵相乘的Strassen算法
💻 CPP
字号:
#include <stdio.h>
#include <stdlib.h>
#include <time.h>
#include <vector>
#include <iostream>
#include <iomanip>
using namespace std;

// Two global counters, both initialised to 0. Nothing in the visible source
// reads or modifies them.
int count1=0,count2=0;

/*
 * Square matrix of ints that can be multiplied with the Strassen algorithm.
 *
 * The order of the matrix is kept in `length`; the elements live in a
 * vector of rows, addressed as matrix[row][column].
 *
 * NOTE(review): the copy constructor, assignment and arithmetic operators
 * all take non-const references, so a temporary can only be passed to them
 * on compilers that allow binding temporaries to non-const references.
 */
class CMatrix
{
private:
	int length;		// order of the (square) matrix
	vector<vector<int> > matrix;	// two-dimensional vector holding the elements
public:	
	CMatrix(){length=0;}	// default: empty matrix of order 0
	CMatrix(int n,int m);	// n x n matrix; m==1 fills it with random values, otherwise it is only allocated
	CMatrix(CMatrix &tempa); // copy constructor
	const CMatrix& operator = (CMatrix &tempa);	// matrix assignment
	CMatrix operator - (CMatrix &tempa);		// matrix subtraction
	CMatrix operator + (CMatrix &tempa);		// matrix addition
	CMatrix operator * (CMatrix &tempa);		// matrix multiplication: when length>2
	// the Strassen algorithm is used, otherwise conventional multiplication
	friend CMatrix strassen(CMatrix &A,CMatrix &B);	// Strassen algorithm
	friend CMatrix tradition_mul(CMatrix &A,CMatrix &B);// conventional multiplication
	void divide(CMatrix& t11,CMatrix& t12,CMatrix& t21,CMatrix& t22);// split into four quadrants
	CMatrix merge(CMatrix& t11,CMatrix& t12,CMatrix& t21,CMatrix& t22);// merge four
	// quadrants into one matrix
	int getlength();// order of the matrix
	void printCM();	// print the matrix to standard output
	~CMatrix()	// destructor
	{
		// Intentionally empty. The vector member is destroyed automatically
		// once this body has run; invoking matrix.~vector() by hand here
		// would destroy it a second time, which is undefined behaviour.
	}
	
};

// Builds an n x n matrix. Every element starts at 0; when m is 1 (and n is
// not 0) each element is then overwritten, row by row, with rand()%11.
CMatrix::CMatrix(int n,int m)
	: length(n), matrix(n,vector<int>(n))
{
	const bool fillRandom=(n!=0)&&(m==1);
	if(!fillRandom)
		return;

	for(int row=0;row<n;++row)
	{
		vector<int>& cells=matrix[row];
		for(int col=0;col<n;++col)
			cells[col]=rand()%11;
	}
}
// Copy constructor: duplicates the order and every row of tempa.
CMatrix::CMatrix(CMatrix &tempa)
	: length(tempa.length), matrix(tempa.matrix)
{
}

// Assignment: takes over the order and the element storage of tempa and
// returns a const reference to the assigned-to object.
inline const CMatrix& CMatrix::operator = (CMatrix &tempa)
{
	this->length=tempa.length;
	this->matrix=tempa.matrix;
	return *this;
}

// Element-wise sum of *this and tempa.
// When the two orders differ a message is printed and an unmodified copy of
// *this is returned.
CMatrix CMatrix::operator + (CMatrix &tempa)
{
	CMatrix sum(*this);	// starts out as the left operand
	if(length==tempa.length)
	{
		for(int row=0;row<length;++row)
		{
			vector<int>& dst=sum.matrix[row];
			const vector<int>& rhs=tempa.matrix[row];
			for(int col=0;col<length;++col)
				dst[col]+=rhs[col];
		}
		return sum;
	}
	cout<<"矩阵阶不同,不能加。"<<endl;
	return sum;
}

// Element-wise difference *this - tempa.
// When the two orders differ a message is printed and an unmodified copy of
// *this is returned.
CMatrix CMatrix::operator - (CMatrix &tempa)
{
	CMatrix diff(*this);	// starts out as the left operand
	if(length==tempa.length)
	{
		for(int row=0;row<length;++row)
		{
			vector<int>& dst=diff.matrix[row];
			const vector<int>& rhs=tempa.matrix[row];
			for(int col=0;col<length;++col)
				dst[col]-=rhs[col];
		}
		return diff;
	}
	cout<<"矩阵阶不同,不能减。"<<endl;
	return diff;
}

// Matrix product *this * tempa.
//
// - Orders differ: prints a message and returns an all-zero matrix of this
//   object's order.
// - Order > 2: delegates to strassen().
// - Order 1 or 2: conventional triple-loop multiplication. (Order 1 used to
//   be reported as an error and produced a zero result.)
// - Order < 1: prints the error message and returns the empty/zero matrix.
CMatrix CMatrix::operator * (CMatrix &tempa)
{
	// CMatrix(length,0) is already zero-filled, so no copy of *this is needed.
	CMatrix temp(length,0);
	if(length!=tempa.length)
	{
		cout<<"矩阵阶不同,不能乘。"<<endl;
		return temp;
	}
	if(length>2)
		return strassen(*this,tempa);	// returned directly: no assignment from a temporary
	if(length<1)
	{
		cout<<"出错了。"<<endl;
		return temp;
	}
	for(int i=0;i<length;i++)
		for(int j=0;j<length;j++)
			for(int k=0;k<length;k++)
				temp.matrix[i][j] += matrix[i][k]*tempa.matrix[k][j];
	return temp;
}

// Splits this matrix into four quadrants of order n=(length+1)/2:
//   t11 = upper-left, t12 = upper-right, t21 = lower-left, t22 = lower-right.
// When length is odd the quadrants that run past the edge of the source
// (last column of t12/t22, last row of t21/t22) are padded with zeros.
//
// Every output ends up with exactly n rows of n columns, so its storage
// always agrees with its `length`. The previous insert-based version left
// stale rows and columns behind the logical n x n area.
void CMatrix::divide(CMatrix& t11,CMatrix& t12,CMatrix& t21,CMatrix& t22)
{
	const int n=(length+1)/2;	// order of each sub-matrix

	// Reset the four outputs to n x n zero matrices in place (no temporary
	// objects are assigned, which non-const reference parameters would reject
	// on conforming compilers).
	CMatrix* quadrants[4]={&t11,&t12,&t21,&t22};
	for(int q=0;q<4;q++)
	{
		quadrants[q]->length=n;
		quadrants[q]->matrix.assign(n,vector<int>(n,0));
	}

	for(int i=0;i<n;i++)
	{
		const int lowerRow=i+n;		// source row feeding t21/t22
		for(int j=0;j<n;j++)
		{
			const int rightCol=j+n;	// source column feeding t12/t22
			t11.matrix[i][j]=matrix[i][j];
			// Cells outside the source keep their zero padding.
			if(rightCol<length)
				t12.matrix[i][j]=matrix[i][rightCol];
			if(lowerRow<length)
			{
				t21.matrix[i][j]=matrix[lowerRow][j];
				if(rightCol<length)
					t22.matrix[i][j]=matrix[lowerRow][rightCol];
			}
		}
	}
}

// Assembles the four quadrants into one matrix and returns it:
//   | t11 t12 |
//   | t21 t22 |
// The order of the result is this object's `length`, not 2*t11.length:
// rows/columns of the 2n x 2n arrangement beyond `length` are dropped (this
// removes the zero padding added by divide() for odd orders), and anything
// beyond 2n is zero.
//
// Only the first n rows and n columns of each quadrant are read, where
// n = t11.length. The result is built with exactly length x length cells;
// the previous insert/resize version left rows past 2n with too few columns
// when length > 2n.
CMatrix CMatrix::merge(CMatrix& t11,CMatrix& t12,CMatrix& t21,CMatrix& t22)
{
	const int n=t11.length;
	const int span=2*n;			// extent covered by the four quadrants
	CMatrix temp(length,0);		// length x length, zero-filled

	for(int i=0;i<length&&i<span;i++)
	{
		const bool lowerHalf=(i>=n);
		const int r=lowerHalf?i-n:i;	// row inside the quadrant
		const vector<int>& leftRow=lowerHalf?t21.matrix[r]:t11.matrix[r];
		const vector<int>& rightRow=lowerHalf?t22.matrix[r]:t12.matrix[r];
		for(int j=0;j<length&&j<span;j++)
			temp.matrix[i][j]=(j<n)?leftRow[j]:rightRow[j-n];
	}
	return temp;
}

// Writes the matrix to cout, one row per line, each element right-aligned in
// a field of width 5.
void CMatrix::printCM ()
{
	for(int row=0;row<length;++row)
	{
		const vector<int>& cells=matrix[row];
		for(int col=0;col<length;++col)
			cout<<setw(5)<<cells[col];
		cout<<endl;
	}
}

inline int CMatrix::getlength()
{
	return length;
}

// Multiplies A by B with the Strassen algorithm and returns the product.
//
// Both operands are split into four quadrants (zero-padded for odd orders),
// seven quadrant products M1..M7 are formed through CMatrix::operator*
// (which calls back into this function for quadrants of order > 2), and the
// result quadrants are merged into a matrix of order A.length.
//
// Guards for direct callers:
// - orders differ: prints the same message as operator* and returns an
//   all-zero matrix of order A.length;
// - order <= 2: falls back to tradition_mul(), because splitting would
//   produce 1 x 1 quadrants.
//
// Every intermediate is a named object, so no temporary is ever passed to a
// non-const reference parameter.
CMatrix strassen(CMatrix &A,CMatrix &B)
{
	if(A.length!=B.length)
	{
		cout<<"矩阵阶不同,不能乘。"<<endl;
		CMatrix zero(A.length,0);
		return zero;
	}
	if(A.length<=2)
		return tradition_mul(A,B);

	CMatrix A11,A12,A21,A22,B11,B12,B21,B22;
	A.divide (A11,A12,A21,A22);	// split A
	B.divide (B11,B12,B21,B22);	// split B

	// The seven Strassen products.
	CMatrix d1=B12-B22;
	CMatrix M1=A11*d1;			// M1=A11*(B12-B22)
	CMatrix s2=A11+A12;
	CMatrix M2=s2*B22;			// M2=(A11+A12)*B22
	CMatrix s3=A21+A22;
	CMatrix M3=s3*B11;			// M3=(A21+A22)*B11
	CMatrix d4=B21-B11;
	CMatrix M4=A22*d4;			// M4=A22*(B21-B11)
	CMatrix s5a=A11+A22;
	CMatrix s5b=B11+B22;
	CMatrix M5=s5a*s5b;			// M5=(A11+A22)*(B11+B22)
	CMatrix d6=A12-A22;
	CMatrix s6=B21+B22;
	CMatrix M6=d6*s6;			// M6=(A12-A22)*(B21+B22)
	CMatrix d7=A11-A21;
	CMatrix s7=B11+B12;
	CMatrix M7=d7*s7;			// M7=(A11-A21)*(B11+B12)

	// Combine the products into the four result quadrants.
	CMatrix C11=M5+M4-M2+M6;
	CMatrix C12=M1+M2;
	CMatrix C21=M3+M4;
	CMatrix C22=M5+M1-M3-M7;
/*	Conventional divide-and-conquer multiplication, kept for reference:
	C11=A11*B11+A12*B21;
	C12=A11*B12+A12*B22;
	C21=A21*B11+A22*B21;
	C22=A21*B12+A22*B22;*/

	// merge() sizes its result from the object it is called on.
	CMatrix C(A.length,0);
	return C.merge (C11,C12,C21,C22);
}

// Conventional (triple-loop) matrix product of A and B.
// The order of the result is taken from A; B is indexed with the same
// bounds.
CMatrix tradition_mul(CMatrix &A,CMatrix &B)
{
	const int order=A.length;
	CMatrix product(order,0);	// zero-filled order x order matrix
	for(int row=0;row<order;++row)
	{
		for(int col=0;col<order;++col)
		{
			// Dot product of row `row` of A with column `col` of B.
			int dot=0;
			for(int k=0;k<order;++k)
				dot+=A.matrix[row][k]*B.matrix[k][col];
			product.matrix[row][col]=dot;
		}
	}
	return product;
}


void main()
{
	CMatrix C(5,1),C11,C12,C21,C22,M;
	cout<<"矩阵C:\n";
	C.printCM ();
	cout<<"分割矩阵:\n";
	C.divide (C11,C12,C21,C22);
	cout<<"矩阵C11:\n";
	C11.printCM ();
	cout<<"矩阵C12:\n";
	C12.printCM ();
	cout<<"矩阵C21:\n";
	C21.printCM ();
	cout<<"矩阵C22:\n";
	C22.printCM ();
	cout<<"合并矩阵:\n";
	M=C.merge (C11,C12,C21,C22); 
	M.printCM ();
/*
	srand((unsigned)time(NULL));	//生成新的随机排列
	CMatrix A(5,1),B(5,1),C1,C2;
	cout<<"矩阵A的值:\n";
	A.printCM();
	cout<<endl;
	cout<<"矩阵B的值:\n";
	B.printCM();
	cout<<endl;
	C1=A*B;
	C2=tradition_mul(A,B);
	cout<<"strassen算法所得的结果:\n";
	C1.printCM();
	cout<<endl;
	cout<<"传统算法所得的结果:\n";
	C2.printCM();
	cout<<endl;
*/
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -