// strassen.cpp
// (字号: — "font size" control label left over from the web page this listing was copied from)
#include <stdio.h>
#include <stdlib.h>
#include <time.h>
#include <vector>
#include <iostream>
#include <iomanip>
using namespace std;
// Global counters, both starting at zero. No code visible in this file reads
// or updates them.
int count1 = 0;
int count2 = 0;
/// Square integer matrix whose product operator uses Strassen's algorithm.
///
/// Invariant: `matrix` always holds exactly `length` rows of exactly `length`
/// columns, so every member may index [0, length) directly.
/// Copying, assignment and destruction are compiler-generated (Rule of Zero):
/// std::vector already owns and releases the storage.
class CMatrix
{
private:
    int length;                              // order of the matrix
    std::vector<std::vector<int> > matrix;   // element values, matrix[row][col]
public:
    CMatrix() : length(0) {}                 // empty matrix of order 0
    CMatrix(int n, int m);                   // n x n matrix: random entries when m == 1, zeros otherwise

    CMatrix operator - (const CMatrix &tempa) const;  // matrix subtraction
    CMatrix operator + (const CMatrix &tempa) const;  // matrix addition
    CMatrix operator * (const CMatrix &tempa) const;  // multiplication: Strassen when length > 2,
                                                      // direct triple loop otherwise
    bool operator == (const CMatrix &tempa) const;    // same order and same entries

    friend CMatrix strassen(const CMatrix &A, const CMatrix &B);      // Strassen's algorithm
    friend CMatrix tradition_mul(const CMatrix &A, const CMatrix &B); // classic multiplication

    void divide(CMatrix& t11, CMatrix& t12, CMatrix& t21, CMatrix& t22) const; // split into quadrants
    CMatrix merge(const CMatrix& t11, const CMatrix& t12,
                  const CMatrix& t21, const CMatrix& t22) const;              // join quadrants

    int getlength() const;                   // order of the matrix
    int at(int row, int col) const;          // bounds-checked read of one entry
    void printCM() const;                    // print the matrix to std::cout
};

CMatrix strassen(const CMatrix &A, const CMatrix &B);
CMatrix tradition_mul(const CMatrix &A, const CMatrix &B);

// Builds an n x n matrix. When m == 1 (and n != 0) every entry is set to
// rand() % 11, i.e. 0..10, filling row by row; otherwise all entries are 0.
CMatrix::CMatrix(int n, int m)
    : length(n), matrix(n, std::vector<int>(n, 0))
{
    if (n != 0 && m == 1)
    {
        for (int i = 0; i < n; i++)
            for (int j = 0; j < n; j++)
                matrix[i][j] = rand() % 11;
    }
}

// Element-wise sum. On an order mismatch a message is printed and an
// unchanged copy of *this is returned.
CMatrix CMatrix::operator + (const CMatrix &tempa) const
{
    CMatrix temp(*this);
    if (length != tempa.length)
    {
        std::cout << "矩阵阶不同,不能加。" << std::endl;
        return temp;
    }
    for (int i = 0; i < length; i++)
        for (int j = 0; j < length; j++)
            temp.matrix[i][j] += tempa.matrix[i][j];
    return temp;
}

// Element-wise difference. On an order mismatch a message is printed and an
// unchanged copy of *this is returned.
CMatrix CMatrix::operator - (const CMatrix &tempa) const
{
    CMatrix temp(*this);
    if (length != tempa.length)
    {
        std::cout << "矩阵阶不同,不能减。" << std::endl;
        return temp;
    }
    for (int i = 0; i < length; i++)
        for (int j = 0; j < length; j++)
            temp.matrix[i][j] -= tempa.matrix[i][j];
    return temp;
}

// Matrix product. On an order mismatch a message is printed and a zero
// matrix of this order is returned.
CMatrix CMatrix::operator * (const CMatrix &tempa) const
{
    if (length != tempa.length)
    {
        std::cout << "矩阵阶不同,不能乘。" << std::endl;
        return CMatrix(length, 0);
    }
    // Orders 0, 1 and 2 are multiplied directly. This is also the base case
    // that ends the Strassen recursion, since every level of strassen()
    // works on blocks of order ceil(length / 2).
    if (length > 2)
        return strassen(*this, tempa);
    return tradition_mul(*this, tempa);
}

bool CMatrix::operator == (const CMatrix &tempa) const
{
    return length == tempa.length && matrix == tempa.matrix;
}

// Splits this matrix into four quadrants of order n = ceil(length / 2):
//   | t11 t12 |
//   | t21 t22 |
// For an odd order the last column of t12/t22 and the last row of t21/t22
// are zero padding, so all four blocks have the same order n.
void CMatrix::divide(CMatrix& t11, CMatrix& t12, CMatrix& t21, CMatrix& t22) const
{
    const int n = (length + 1) / 2;
    CMatrix q11(n, 0);
    CMatrix q12(n, 0);
    CMatrix q21(n, 0);
    CMatrix q22(n, 0);
    for (int i = 0; i < length; i++)
    {
        const bool top = i < n;
        const int r = top ? i : i - n;
        for (int j = 0; j < length; j++)
        {
            const bool left = j < n;
            CMatrix& block = top ? (left ? q11 : q12) : (left ? q21 : q22);
            block.matrix[r][left ? j : j - n] = matrix[i][j];
        }
    }
    // Outputs are written only after *this has been fully read, so *this may
    // itself be passed as one of the outputs.
    t11 = q11;
    t12 = q12;
    t21 = q21;
    t22 = q22;
}

// Reassembles four quadrants (laid out as in divide()) into a matrix of
// this matrix's order. Only *this's order is used, not its entries. The
// 2n x 2n arrangement is cropped to length x length, which removes the
// padding divide() adds for odd orders. Any entry that no block covers is 0.
CMatrix CMatrix::merge(const CMatrix& t11, const CMatrix& t12,
                       const CMatrix& t21, const CMatrix& t22) const
{
    const int n = t11.length;
    CMatrix temp(length, 0);
    for (int i = 0; i < length; i++)
    {
        const bool top = i < n;
        const int r = top ? i : i - n;
        for (int j = 0; j < length; j++)
        {
            const bool left = j < n;
            const int c = left ? j : j - n;
            const CMatrix& block = top ? (left ? t11 : t12) : (left ? t21 : t22);
            if (r < block.length && c < block.length)
                temp.matrix[i][j] = block.matrix[r][c];
        }
    }
    return temp;
}

// Prints one row per line, each entry right-aligned in a 5-character field.
void CMatrix::printCM() const
{
    for (int i = 0; i < length; i++)
    {
        for (int j = 0; j < length; j++)
            std::cout << std::setw(5) << matrix[i][j];
        std::cout << std::endl;
    }
}

inline int CMatrix::getlength() const
{
    return length;
}

// Reads entry (row, col). Throws std::out_of_range (from std::vector::at)
// when either index is outside [0, length).
inline int CMatrix::at(int row, int col) const
{
    return matrix.at(row).at(col);
}

// Multiplies two matrices of equal order with Strassen's seven-product
// scheme. Each block product goes back through operator*, which recurses
// into strassen() for orders > 2 and multiplies directly otherwise.
// On an order mismatch a message is printed and a zero matrix of A's order
// is returned.
CMatrix strassen(const CMatrix &A, const CMatrix &B)
{
    if (A.length != B.length)
    {
        std::cout << "矩阵阶不同,不能乘。" << std::endl;
        return CMatrix(A.length, 0);
    }

    CMatrix A11, A12, A21, A22;
    CMatrix B11, B12, B21, B22;
    A.divide(A11, A12, A21, A22);
    B.divide(B11, B12, B21, B22);

    const CMatrix M1 = A11 * (B12 - B22);
    const CMatrix M2 = (A11 + A12) * B22;
    const CMatrix M3 = (A21 + A22) * B11;
    const CMatrix M4 = A22 * (B21 - B11);
    const CMatrix M5 = (A11 + A22) * (B11 + B22);
    const CMatrix M6 = (A12 - A22) * (B21 + B22);
    const CMatrix M7 = (A11 - A21) * (B11 + B12);

    const CMatrix C11 = M5 + M4 - M2 + M6;
    const CMatrix C12 = M1 + M2;
    const CMatrix C21 = M3 + M4;
    const CMatrix C22 = M5 + M1 - M3 - M7;
    // Equivalent plain divide-and-conquer form (8 block products):
    //   C11 = A11*B11 + A12*B21;  C12 = A11*B12 + A12*B22;
    //   C21 = A21*B11 + A22*B21;  C22 = A21*B12 + A22*B22;

    // Cropping to A's order drops the zero padding divide() added.
    return A.merge(C11, C12, C21, C22);
}

// Classic triple-loop multiplication. On an order mismatch a message is
// printed and a zero matrix of A's order is returned instead of reading B
// out of range.
CMatrix tradition_mul(const CMatrix &A, const CMatrix &B)
{
    const int n = A.length;
    CMatrix C(n, 0);
    if (B.length != n)
    {
        std::cout << "矩阵阶不同,不能乘。" << std::endl;
        return C;
    }
    // i-k-j order walks B and C row by row; integer sums are exact, so the
    // result is independent of the loop order.
    for (int i = 0; i < n; i++)
        for (int k = 0; k < n; k++)
        {
            const int aik = A.matrix[i][k];
            for (int j = 0; j < n; j++)
                C.matrix[i][j] += aik * B.matrix[k][j];
        }
    return C;
}
void main()
{
CMatrix C(5,1),C11,C12,C21,C22,M;
cout<<"矩阵C:\n";
C.printCM ();
cout<<"分割矩阵:\n";
C.divide (C11,C12,C21,C22);
cout<<"矩阵C11:\n";
C11.printCM ();
cout<<"矩阵C12:\n";
C12.printCM ();
cout<<"矩阵C21:\n";
C21.printCM ();
cout<<"矩阵C22:\n";
C22.printCM ();
cout<<"合并矩阵:\n";
M=C.merge (C11,C12,C21,C22);
M.printCM ();
/*
srand((unsigned)time(NULL)); //生成新的随机排列
CMatrix A(5,1),B(5,1),C1,C2;
cout<<"矩阵A的值:\n";
A.printCM();
cout<<endl;
cout<<"矩阵B的值:\n";
B.printCM();
cout<<endl;
C1=A*B;
C2=tradition_mul(A,B);
cout<<"strassen算法所得的结果:\n";
C1.printCM();
cout<<endl;
cout<<"传统算法所得的结果:\n";
C2.printCM();
cout<<endl;
*/
}
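/*
 * Web-page UI residue (keyboard-shortcut help panel) copied along with the
 * source; kept as a comment so it is not parsed as code.
 * ⌨️ 快捷键说明 (keyboard shortcuts)
 * 复制代码 (copy code): Ctrl + C
 * 搜索代码 (search code): Ctrl + F
 * 全屏模式 (full screen): F11
 * 切换主题 (toggle theme): Ctrl + Shift + D
 * 显示快捷键 (show shortcuts): ?
 * 增大字号 (increase font size): Ctrl + =
 * 减小字号 (decrease font size): Ctrl + -
 */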