📄 strassen算法可执行程序.cpp
字号:
/* 计算机科学与技术 2004131063 刘春影 */
/*矩阵相乘的Strassen算法及时间复杂度分析*/
#include<iostream.h>
#include<math.h>
#include<stdlib.h>
#include<stdio.h>
#include<iomanip.h>
void Allot(float **&temp,int k); /*动态分配内存给 float** 变量的函数*/
void Output(float **a,int n);/*输出n阶矩阵a*/
float **Sub(float **b,float **c,int n);/*实现两个矩阵相减的函数*/
void MAndK(int n); /*计算 n=m*2^k*/
void Separate(float **x,float ***A,int n);/*把一个矩阵划分为x11 x12 x21 x22四块*/
float **Strassen(float **a,float **b,int n);/*用Strassen算法实现两个矩阵相乘的函数*/
float **Add(float **b,float **c,int n);/*实现两个矩阵相加的函数*/
float **SepMul(float **b,float **c,int k,int n);/*当m大于1时,传统算法和Strassen算法的结合*/
float **Mul(float **b,float **c,int n);/*实现两个矩阵相乘的传统的函数*/
int k,m; /*全局变量 */
/*****************主函数********************/
main()
{
int i,j;
float **a;
float **b;
Allot(b,12);
for(i=0;i<12;i++)
for(j=0;j<12;j++)
{
b[i][j]=1.0/(i+j+1);
}
MAndK(12);
cout<<"12*12的两个矩阵乘积为:"<<endl;
Output(SepMul(b,b,m,12),12);
return 0;
}
/************用传统的方法计算两个矩阵相乘***************/
float **Mul(float **b,float **c,int n)
{
int i,j,k;
float **a;
Allot(a,n);
for(i=0;i<n;i++)
for(j=0;j<n;j++)
a[i][j]=0;
for(i=0;i<n;i++)
for(j=0;j<n;j++){
for(k=0;k<n;k++)
a[i][j]+=b[i][k]*c[k][j];
}
return a;
}
/****************当m大于1时,传统算法和Strassen算法的结合***********/
float **SepMul(float **b,float **c,int k,int n)
{
int i,j,x,y;
float ****B,****C,****A,**a;
Allot(a,n);
B=new float***[k];
for(i=0;i<k;i++)
B[i]=new float**[k];
for(j=0;j<k;j++)
for(i=0;i<k;i++)
Allot(B[j][i],n/k);
C=new float***[k];
for(i=0;i<k;i++)
C[i]=new float**[k];
for(j=0;j<k;j++)
for(i=0;i<k;i++)
Allot(C[j][i],n/k);
A=new float***[k];
for(i=0;i<k;i++)
A[i]=new float**[k];
for(j=0;j<k;j++)
for(i=0;i<k;i++)
Allot(A[j][i],n/k);
for(i=0;i<k;i++)
for(j=0;j<k;j++)
for(x=0;x<n/k;x++)
for(y=0;y<n/k;y++)
A[i][j][x][y]=0; /*把矩阵b的2^k阶的方阵放到B中*/
for(i=0;i<k;i++)
for(j=0;j<k;j++)
for(x=0;x<n/k;x++)
for(y=0;y<n/k;y++)
{
B[i][j][x][y]=b[i*n/k+x][j*n/k+y];
}
for(i=0;i<k;i++)
for(j=0;j<k;j++)
for(x=0;x<n/k;x++)
for(y=0;y<n/k;y++)
{
C[i][j][x][y]=c[i*n/k+x][j*n/k+y]; /*把矩阵c的2^k阶的方阵放到C中*/
}
/*************传统算法和Strassen算法相结合************/
for(i=0;i<k;i++)
for(j=0;j<k;j++)
{
for(x=0;x<k;x++)
A[i][j]=Add(A[i][j],Strassen(B[i][x],C[x][j],n/k),n/k);
}
/*最后把A赋值给a*/
for(i=0;i<k;i++)
for(j=0;j<k;j++)
for(x=0;x<n/k;x++)
for(y=0;y<n/k;y++)
{
a[i*n/k+x][j*n/k+y]=A[i][j][x][y];
}
return a;
}
/******************计算两个矩阵相加******************/
float **Add(float **b,float **c,int n)
{
int i,j;
float **a;
a=new float*[n];
for(i=0;i<n;i++)
{
a[i]=new float[n];
}
for(i=0;i<n;i++)
for(j=0;j<n;j++){
a[i][j]=b[i][j]+c[i][j];
}
return a;
}
/*************计算两个矩阵相减*****************/
float **Sub(float **b,float **c,int n)
{
int i,j;
float **a;
a=new float*[n];
for(i=0;i<n;i++)
{
a[i]=new float[n];
}
for(i=0;i<n;i++)
for(j=0;j<n;j++){
a[i][j]=b[i][j]-c[i][j];
}
return a;
}
void MAndK(int n) /*计算n阶矩阵的m和k*/
{
k=0;
m=0;
while(n%2==0)
{
n=n/2;
k++;
}
m=n;
}
/****************动态分配内存******************/
void Allot(float **&temp,int k)
{
temp=new float*[k];
for(int i=0;i<k;i++)
{
temp[i]=new float[k];
}
}
void Separate(float **x,float ***A,int n)
{
/*划分矩阵x。A[0][][] is A11.A[1][][] is A12.A[2][][] is A21.A[3][][] is A22*/
int i,j;
for(i=0;i<n/2;i++)
for(j=0;j<n/2;j++)
{
A[0][i][j]=x[i][j];
}
for(i=n/2;i<n;i++)
for(j=0;j<n/2;j++)
{
A[1][i-n/2][j]=x[i][j];
}
for(i=0;i<n/2;i++)
for(j=n/2;j<n;j++)
{
A[2][i][j-n/2]=x[i][j];
}
for(i=n/2;i<n;i++)
for(j=n/2;j<n;j++)
{
A[3][i-n/2][j-n/2]=x[i][j];
}
}
/******************用Strassen算法实现两个矩阵相乘******************/
float **Strassen(float **a,float **b,int n)
{
if(n==2)
{
return Mul(a,b,n);
}
/*递归*/
else{
int i=0,j;
float ***A,***B,**c,***C,***M;
Allot(c,n);
A=new float**[4];
B=new float**[4];
M=new float**[7];
C=new float**[4];
for(i=0;i<7;i++)
{
Allot(M[i],n/2);
if(i<4)
{
Allot(A[i],n/2);
Allot(B[i],n/2);
Allot(C[i],n/2);
}
}
Separate(a,A,n);Separate(b,B,n);
M[0]=Strassen(A[0],Sub(B[1],B[3],n/2),n/2);
M[1]=Strassen(Add(A[0],A[1],n/2),B[3],n/2);
M[2]=Strassen(Add(A[2],A[3],n/2),B[0],n/2);
M[3]=Strassen(A[3],Sub(B[2],B[0],n/2),n/2);
M[4]=Strassen(Add(A[0],A[3],n/2),Add(B[0],B[3],n/2),n/2);
M[5]=Strassen(Sub(A[1],A[3],n/2),Add(B[2],A[3],n/2),n/2);
M[6]=Strassen(Sub(A[0],A[2],n/2),Add(B[0],B[2],n/2),n/2);
C[0]=Add(Sub(Add(M[4],M[3],n/2),M[1],n/2),M[5],n/2);
C[1]=Add(M[0],M[1],n/2);
C[2]=Add(M[2],M[3],n/2);
C[3]=Sub(Sub(Add(M[4],M[0],n/2),M[2],n/2),M[6],n/2);
/*最后生成两个矩阵乘积矩阵c*/
for(i=0;i<n/2;i++)
for(j=0;j<n/2;j++)
{
c[i][j]=C[0][i][j];
c[i][j+n/2]=C[1][i][j];
c[i+n/2][j]=C[2][i][j];
c[i+n/2][j+n/2]=C[3][i][j];
}
return c;
}
}
/**************输出最后两个矩阵乘积矩阵********************/
void Output(float **a,int n)
{
int i,j;
cout.setf(ios::fixed);
for(i=0;i<n;i++)
{
for(j=0;j<n;j++)
cout<<setw(7)<<setprecision(6)<<a[i][j]<<' ';
cout<<endl;
}
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -