// sparematrix.cpp
// (page-viewer UI text captured with this listing: "📄" icon, "字号:" = "font size:")
#include<algorithm>
#include<cstdlib>
#include<fstream>
#include<iomanip>
#include<iostream>
#include<new>
#include<string>
#include<vector>
#include"spareMatrix.h"
using namespace std;
// In-place "subtract" of two triples.
// NOTE: despite the binary-operator syntax this MUTATES its left operand:
// x receives y's row and col, x.value becomes x.value - y.value, and a
// reference to x itself is returned. y is only read.
template<class T>
Triple<T>& operator -(Triple<T>& x,Triple<T>& y)
{
    const T difference = x.value - y.value;
    x.row = y.row;
    x.col = y.col;
    x.value = difference;
    return x;
}
// In-place "add" of two triples.
// NOTE: despite the binary-operator syntax this MUTATES its left operand:
// x receives y's row and col, x.value becomes x.value + y.value, and a
// reference to x itself is returned. y is only read.
template<class T>
Triple<T>& operator +(Triple<T>& x,Triple<T>& y)
{
    const T total = x.value + y.value;
    x.row = y.row;
    x.col = y.col;
    x.value = total;
    return x;
}
// Creates an empty (0 x 0, no nonzeros) sparse matrix whose triple array
// can hold up to szmax entries.
// szmax: capacity of smArray, stored in maxTerms; a negative value is
//        treated as 0.
// Prints an error and terminates the program if the allocation fails.
template<class T>
SpareMatrix<T>::SpareMatrix(int szmax)
{
    // new[] rejects a negative element count, so clamp it.
    maxTerms = (szmax > 0) ? szmax : 0;
    // Plain new[] throws std::bad_alloc instead of returning NULL, which made
    // the error branch below unreachable; the nothrow form makes it work.
    smArray = new (std::nothrow) Triple<T>[maxTerms];
    if(smArray == NULL)
    {
        cerr<<"内存分配出错1,退出!\n";
        exit(1);
    }
    Rows = Cols = Terms = 0;
}
// Copy constructor: makes an independent deep copy of x (dimensions,
// capacity and the first x.Terms triples of x.smArray).
// Prints an error and terminates the program if the allocation fails.
template<class T>
SpareMatrix<T>::SpareMatrix(SpareMatrix<T>& x)
{
    Rows = x.Rows;
    Cols = x.Cols;
    Terms = x.Terms;
    maxTerms = x.maxTerms;
    // Plain new[] throws instead of returning NULL; the nothrow form makes
    // the error check below reachable.
    smArray = new (std::nothrow) Triple<T>[maxTerms];
    if(smArray == NULL)
    {
        cerr<<"内存分配出错2,退出!\n";
        exit(1);
    }
    for(int i=0;i<Terms;i++)
        smArray[i] = x.smArray[i];
}
// Interactively reads a sparse matrix from cin into x.
//
// Prompts for the dimensions (rows, then cols), the number of nonzero
// entries, and then one "row col value" group per entry. An entry whose row
// is outside [0, Rows-1] or column outside [0, Cols-1] is rejected and
// re-read. An entry count that is negative or exceeds the capacity is also
// rejected and re-read.
// szmax: maximum number of nonzeros the caller allows. The effective limit
//        is additionally capped at x.maxTerms, the size of x.smArray, so the
//        array is never overrun.
// If cin fails (EOF or non-numeric input), x keeps only the entries that
// were read completely.
// NOTE(review): entries are stored in the order typed. Add/Sub merge triple
// lists assuming (row, col) order, so entries presumably must be typed
// sorted that way. Confirm whether sorting should be enforced here.
template<class T>
void SpareMatrix<T>::Input(SpareMatrix<T>& x,int szmax)
{
    int n,r,c;
    T data;
    cout<<"请输入矩阵的维数,以行列的形式输入"<<endl;
    cin>>r;
    cin>>c;
    x.Rows = r;
    x.Cols = c;
    // Never accept more entries than smArray can physically hold.
    int capacity = (szmax < x.maxTerms) ? szmax : x.maxTerms;
    if(capacity < 0)
        capacity = 0;
    cout<<"请输入非零矩阵的非零元个数:"<<endl;
    for(;;)
    {
        if(!(cin>>n))
        {
            // Input ended or was not a number: leave an empty matrix.
            x.Terms = 0;
            return;
        }
        if(n >= 0 && n <= capacity)
            break;
        if(n > capacity)
            cerr<<"非零元个数太多!最多能存"<<capacity<<"个非零元\n";
        cout<<"重新输入!\n";
    }
    x.Terms = n;
    cout<<"请输入非零元所在的行数、列数、和值"<<endl;
    cout<<"他们之间用空格键区分,确认按enter"<<endl;
    int i=0;
    while(i<x.Terms)
    {
        if(!(cin>>r>>c>>data))
        {
            // Without this check a failed stream would retry forever.
            x.Terms = i;
            return;
        }
        bool badRow = (r < 0)||(r > x.Rows-1);
        bool badCol = (c < 0)||(c > x.Cols-1);
        if(badRow)
        {
            cerr<<"矩阵行向量下标溢出!\n";
            cout<<"重新输入!\n";
        }
        if(badCol)
        {
            cerr<<"矩阵列向量下标溢出!\n";
            cout<<"重新输入!\n";
        }
        if(badRow || badCol)
            continue;
        x.smArray[i].row = r;
        x.smArray[i].col = c;
        x.smArray[i].value = data;
        i++;
    }
}
// Prints the stored triples of M to cout, one "row col value" line per
// nonzero entry. setw(5) pads the column and value fields only.
template<class T>
void SpareMatrix<T>::Output(SpareMatrix<T>& M)
{
    const int count = M.Terms;
    for(int k=0;k<count;++k)
    {
        const Triple<T>& entry = M.smArray[k];
        cout<<entry.row
            <<setw(5)<<entry.col
            <<setw(5)<<entry.value
            <<endl;
    }
}
// Copy assignment: replaces this matrix's contents with a deep copy of x.
//
// Self-assignment is a no-op. The new triple array is allocated and filled
// before the old one is released, so the old storage is freed rather than
// leaked.
// Prints an error and terminates the program if the allocation fails.
// Returns a copy of *this (the return type is fixed by the class declaration).
// NOTE(review): assumes smArray was obtained from new[] (true for every
// constructor in this file).
template<class T>
SpareMatrix<T> SpareMatrix<T>::operator = (SpareMatrix<T>& x)
{
    if(this == &x)
        return *this;
    Triple<T>* fresh = new (std::nothrow) Triple<T>[x.maxTerms];
    if(fresh == NULL)
    {
        cerr<<"内存分配出错4,退出!\n";
        exit(1);
    }
    for(int i=0;i<x.Terms;i++)
        fresh[i] = x.smArray[i];
    delete [] smArray;
    smArray = fresh;
    Rows = x.Rows;
    Cols = x.Cols;
    Terms = x.Terms;
    maxTerms = x.maxTerms;
    return *this;
}
// Stream extraction: reads a sparse matrix from `in` into M.
//
// Expected input: rows cols, then the number of nonzeros, then that many
// "row col value" triples. Prompts are written to cout.
// - A failed read or a negative count puts `in` into the fail state, and M
//   keeps only the triples read completely (Terms = 0 if the count failed).
// - A count larger than M.maxTerms (the size of M.smArray) prints an error
//   and terminates the program.
// Returns `in` so extractions can be chained.
template<class T>
istream& operator >> (istream& in,SpareMatrix<T>& M)
{
    cout<<"请输入矩阵的维数,以行列的形式输入"<<endl;
    in>>M.Rows>>M.Cols;
    cout<<"请输入非零矩阵的非零元个数:";
    in>>M.Terms;
    if(!in || M.Terms < 0)
    {
        M.Terms = 0;
        in.setstate(ios::failbit);
        return in;
    }
    if(M.Terms>M.maxTerms)
    {
        cerr<<"非零元个数太多!最多能存"<<M.maxTerms<<"个非零元\n"<<endl;
        exit(1);
    }
    for(int i=0;i<M.Terms;i++)
    {
        if(!(in>>M.smArray[i].row>>M.smArray[i].col>>M.smArray[i].value))
        {
            // Keep only the fully read triples.
            M.Terms = i;
            break;
        }
    }
    return in;
}
// Stream insertion: writes M's dimensions, its nonzero count, a column
// header, and then one "row col value" line per stored triple.
// Returns `out` so insertions can be chained.
template<class T>
ostream& operator <<(ostream& out,SpareMatrix<T>& M)
{
    // Summary header.
    out<<"行="<<M.Rows<<endl
       <<"列="<<M.Cols<<endl
       <<"非零元="<<M.Terms<<endl
       <<"行"<<setw(5)<<"列"<<setw(5)<<"非零元"<<endl;
    // One line per stored triple.
    for(int k=0;k<M.Terms;++k)
    {
        const Triple<T>& entry = M.smArray[k];
        out<<entry.row<<setw(5)<<entry.col<<setw(5)<<entry.value<<endl;
    }
    return out;
}
// Adds two sparse matrices: returns a + b.
//
// a, b: operands with identical Rows and Cols. Their triple lists must be
//       sorted by (row, col) without duplicates, because they are merged in
//       a single linear pass.
// Returns the sum as a sorted triple list. Positions whose values cancel to
// zero are omitted. The result's capacity is enlarged when needed, so it
// always fits.
// Prints an error and terminates the program if the dimensions differ.
template<class T>
SpareMatrix<T> Add(SpareMatrix<T> a,SpareMatrix<T> b)
{
    if((a.Rows != b.Rows)||(a.Cols != b.Cols))
    {
        cerr<<"矩阵不匹配,退出!\n";
        exit(1);
    }
    // At most a.Terms + b.Terms entries can result; size c accordingly so the
    // merge can never write past the end of c.smArray.
    SpareMatrix<T> c(std::max(a.maxTerms, a.Terms + b.Terms));
    c.Rows = a.Rows;
    c.Cols = a.Cols;
    c.Terms = 0;
    int i=0,j=0;
    while(i<a.Terms&&j<b.Terms)
    {
        const Triple<T>& ea = a.smArray[i];
        const Triple<T>& eb = b.smArray[j];
        // Compare positions lexicographically: row first, then column.
        if(ea.row < eb.row || (ea.row == eb.row && ea.col < eb.col))
        {
            c.smArray[c.Terms++] = ea;
            i++;
        }
        else if(eb.row < ea.row || (eb.row == ea.row && eb.col < ea.col))
        {
            c.smArray[c.Terms++] = eb;
            j++;
        }
        else
        {
            // Same position in both: keep the sum unless it cancels to zero.
            T sum = ea.value + eb.value;
            if(sum != 0)
            {
                c.smArray[c.Terms] = ea;
                c.smArray[c.Terms].value = sum;
                c.Terms++;
            }
            i++;
            j++;
        }
    }
    for(;i<a.Terms;i++)
        c.smArray[c.Terms++] = a.smArray[i];
    for(;j<b.Terms;j++)
        c.smArray[c.Terms++] = b.smArray[j];
    return c;
}
// Subtracts two sparse matrices: returns a - b.
//
// a, b: operands with identical Rows and Cols. Their triple lists must be
//       sorted by (row, col) without duplicates, because they are merged in
//       a single linear pass.
// Returns the difference as a sorted triple list. Entries present only in b
// are negated, and positions whose values cancel to zero are omitted. The
// result's capacity is enlarged when needed, so it always fits.
// Prints an error and terminates the program if the dimensions differ.
template<class T>
SpareMatrix<T> Sub(SpareMatrix<T> a,SpareMatrix<T> b)
{
    if((a.Rows != b.Rows)||(a.Cols != b.Cols))
    {
        cerr<<"矩阵不匹配,退出!\n";
        exit(1);
    }
    // At most a.Terms + b.Terms entries can result; size c accordingly so the
    // merge can never write past the end of c.smArray.
    SpareMatrix<T> c(std::max(a.maxTerms, a.Terms + b.Terms));
    c.Rows = a.Rows;
    c.Cols = a.Cols;
    c.Terms = 0;
    int i=0,j=0;
    while(i<a.Terms&&j<b.Terms)
    {
        const Triple<T>& ea = a.smArray[i];
        const Triple<T>& eb = b.smArray[j];
        // Compare positions lexicographically: row first, then column.
        if(ea.row < eb.row || (ea.row == eb.row && ea.col < eb.col))
        {
            c.smArray[c.Terms++] = ea;
            i++;
        }
        else if(eb.row < ea.row || (eb.row == ea.row && eb.col < ea.col))
        {
            // Only in b: contributes its negated value.
            c.smArray[c.Terms].row = eb.row;
            c.smArray[c.Terms].col = eb.col;
            c.smArray[c.Terms].value = -eb.value;
            c.Terms++;
            j++;
        }
        else
        {
            // Same position in both: keep the difference unless it is zero.
            T diff = ea.value - eb.value;
            if(diff != 0)
            {
                c.smArray[c.Terms] = ea;
                c.smArray[c.Terms].value = diff;
                c.Terms++;
            }
            i++;
            j++;
        }
    }
    for(;i<a.Terms;i++)
        c.smArray[c.Terms++] = a.smArray[i];
    for(;j<b.Terms;j++)
    {
        c.smArray[c.Terms].row = b.smArray[j].row;
        c.smArray[c.Terms].col = b.smArray[j].col;
        c.smArray[c.Terms].value = -b.smArray[j].value;
        c.Terms++;
    }
    return c;
}
// Multiplies two sparse matrices: returns a * b.
//
// Requires a.Cols == b.Rows; otherwise it prints an error and terminates the
// program. It also terminates if either operand claims more terms than its
// triple array holds.
// Returns an a.Rows x b.Cols matrix whose triples are in row-major order,
// with zero results omitted. If either operand has no nonzeros, the result
// is an empty matrix of that shape.
// Works for any dimensions. The product is accumulated in a dense
// a.Rows x b.Cols buffer, so memory use grows with that product.
template<class T>
SpareMatrix<T> Mul(SpareMatrix<T> a,SpareMatrix<T> b)
{
    if(a.Cols != b.Rows)
    {
        cerr<<"矩阵不匹配!\n";
        exit(1);
    }
    // Terms above maxTerms means smArray was overrun. A completely full
    // operand (Terms == maxTerms) is valid and is multiplied normally.
    if((a.Terms > a.maxTerms)||(b.Terms > b.maxTerms))
    {
        cerr<<"非零元溢出,一共能存50个非零元!\n";
        exit(1);
    }
    if((a.Terms <= 0)||(b.Terms <= 0))
    {
        // Product with an all-zero operand: empty matrix of the right shape.
        SpareMatrix<T> zero;
        zero.Rows = a.Rows;
        zero.Cols = b.Cols;
        zero.Terms = 0;
        return zero;
    }

    // Group b's triples by row with a counting sort. The triples of row r are
    // b.smArray[byRow[p]] for p in [rowStart[r], rowStart[r+1]). This stays
    // correct even when b.smArray is not sorted by row.
    std::vector<int> rowStart(b.Rows+1, 0);
    for(int k=0;k<b.Terms;k++)
        rowStart[b.smArray[k].row+1]++;
    for(int r=0;r<b.Rows;r++)
        rowStart[r+1] += rowStart[r];
    std::vector<int> cursor(rowStart.begin(), rowStart.end()-1);
    std::vector<int> byRow(b.Terms);
    for(int k=0;k<b.Terms;k++)
        byRow[cursor[b.smArray[k].row]++] = k;

    // Dense row-major accumulator sized to the actual result dimensions.
    const std::size_t width = static_cast<std::size_t>(b.Cols);
    std::vector<T> temp(static_cast<std::size_t>(a.Rows)*width, T(0));
    for(int k=0;k<a.Terms;k++)
    {
        const Triple<T>& ea = a.smArray[k];
        // a(row, col) pairs with every nonzero in row `col` of b.
        for(int p=rowStart[ea.col];p<rowStart[ea.col+1];p++)
        {
            const Triple<T>& eb = b.smArray[byRow[p]];
            T& cell = temp[static_cast<std::size_t>(ea.row)*width + eb.col];
            cell = cell + ea.value*eb.value;
        }
    }

    // Size the result so every nonzero product fits in c.smArray.
    int nonZero = 0;
    for(std::size_t k=0;k<temp.size();k++)
        if(temp[k] != 0)
            nonZero++;
    SpareMatrix<T> c(std::max(a.maxTerms, nonZero));
    c.Rows = a.Rows;
    c.Cols = b.Cols;
    c.Terms = 0;
    for(int r=0;r<a.Rows;r++)
        for(int col=0;col<b.Cols;col++)
        {
            const T& v = temp[static_cast<std::size_t>(r)*width + col];
            if(v != 0)
            {
                c.smArray[c.Terms].row = r;
                c.smArray[c.Terms].col = col;
                c.smArray[c.Terms].value = v;
                c.Terms++;
            }
        }
    return c;
}
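/* Page-viewer UI text captured along with this listing (not part of the program):
   ⌨️ 快捷键说明 (keyboard shortcuts)
   复制代码 (copy code): Ctrl + C
   搜索代码 (search code): Ctrl + F
   全屏模式 (full screen): F11
   切换主题 (switch theme): Ctrl + Shift + D
   显示快捷键 (show shortcuts): ?
   增大字号 (larger font): Ctrl + =
   减小字号 (smaller font): Ctrl + - */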