📄 tmatrixs.h
字号:
// This is a part of the Supcon Classes C++ library
// Copyright (C) 1997-2003 Supcon Software Corporation
// All rights reserved.
//
// name: TMatrix.h
//
// aim & functions:
//   Implements a generic (templated) matrix class.
//
// create: 03-07-03 09:59
// modify:
//   03-07-03 09:59
//
// author: zouxiao
// changer:
#pragma once

#include <windows.h>
#include <vector>
#include <memory>
#include <algorithm>
#include <math.h>

// NOTE(review): a using-directive in a header leaks std:: into every includer;
// kept because existing client code may rely on it.
using namespace std;

// HRESULT error codes thrown by TMatrix operations.
const HRESULT E_MATRIX_RANG      = 0XE0FE0000; // index / range out of bounds
const HRESULT E_MATRIX_DIFSIZE   = 0XE0FE0001; // operands have different sizes
const HRESULT E_MATRIX_MULSIZE   = 0XE0FE0002; // left.GetCol() != right.GetRow() in a product
const HRESULT E_MATRIX_BRINVABLE = 0XE0FE0003; // inversion requested on a non-square matrix
const HRESULT E_MATRIX_FAIL      = 0XE0FE0004; // matrix is singular (inversion failed)

// Generic storage for matrix arithmetic: owns a row-major buffer of
// m_uiRow * m_uiCol elements of type D, allocated with new[] and released
// with delete[].
template <class D>
class TMatrixAllocator
{
public:
    TMatrixAllocator():m_pData(NULL),m_uiRow(0),m_uiCol(0){}

    virtual ~TMatrixAllocator()
    {
        delete []m_pData;
        m_pData=NULL;
    }

    // Allocates nrow*ncol elements (default-initialized, i.e. indeterminate
    // for built-in element types).
    TMatrixAllocator(UINT nrow,UINT ncol)
        : m_pData(NULL), m_uiRow(nrow), m_uiCol(ncol)
    {
        m_pData=new D[m_uiRow*m_uiCol];
    }

    inline TMatrixAllocator(const TMatrixAllocator &src):m_pData(NULL),m_uiRow(0),m_uiCol(0)
    {
        Copy(src);
    }

    inline TMatrixAllocator & operator=(const TMatrixAllocator &src)
    {
        Copy(src);
        return *this;
    }

    // Takes ownership of src's buffer and dimensions; src is left with a
    // NULL buffer. Our previous buffer is freed.
    inline void Attach(TMatrixAllocator &src)
    {
        // Self-attach would null our own pointer and then free the buffer.
        if(&src==this)
            return;
        D* p=m_pData;
        m_pData=src.m_pData;
        src.m_pData=NULL;
        m_uiCol=src.m_uiCol;
        m_uiRow=src.m_uiRow;
        delete []p;
    }

    // Returns a newly allocated copy of the buffer (caller owns it, free with
    // delete[]). Elements are copied with D's operator=, so D controls the
    // copy semantics.
    inline D* Clone(void)const
    {
        UINT total=m_uiRow*m_uiCol;
        D* p=new D[total];
        if(!p)   // only reachable with pre-standard, non-throwing operator new
            return NULL;
        for(UINT i=0;i<total;i++)
        {
            p[i]=m_pData[i];
        }
        return p;
    }

    inline const UINT GetCol(void)const
    {
        return m_uiCol;
    }

    inline const UINT GetRow(void)const
    {
        return m_uiRow;
    }

    // True when r has the same row and column counts.
    bool SameSize(const TMatrixAllocator &r)const
    {
        return (r.m_uiRow==m_uiRow&&r.m_uiCol==m_uiCol);
    }

protected:
    // Deep-copies src. Reallocates only when the dimensions differ; the old
    // buffer is released only after the clone succeeded.
    inline void Copy(const TMatrixAllocator &src)
    {
        if(m_uiRow!=src.m_uiRow||m_uiCol!=src.m_uiCol)
        {
            D* p=src.Clone();
            delete []m_pData;
            m_pData=p;
            m_uiRow=src.m_uiRow;
            m_uiCol=src.m_uiCol;
        }
        else
        {
            UINT num=m_uiCol*m_uiRow;
            for(UINT i=0;i<num;i++)
            {
                m_pData[i]=src.m_pData[i];
            }
        }
    }

    // Pointer to element (row, col), or NULL if out of range.
    D* GetElement(UINT row,UINT col)const
    {
        if(row>=m_uiRow||col>=m_uiCol)
            return NULL;
        return &m_pData[m_uiCol*row+col];
    }

public:
    // Row-major element storage.
    D* m_pData;
    // Number of rows and columns.
    UINT m_uiRow;
    UINT m_uiCol;
};

// Generic dense matrix of V elements, stored row-major in a TMatrixAllocator.
// Errors are reported by throwing HRESULT values (E_MATRIX_*, E_INVALIDARG).
template <class V>
class TMatrix
{
public:
    typedef TMatrixAllocator<V> DATA;

    //========================== constructor & operator ==========================//
    TMatrix(){}
    TMatrix(UINT row,UINT col):m_Data(row,col){}
    TMatrix(const TMatrix& src):m_Data(src.m_Data){}
    TMatrix& operator=(const TMatrix& src)
    {
        m_Data=src.m_Data;
        return *this;
    }

    // davinic (disabled): construct from a row-major array of row*col values.
    //TMatrix(UINT row,UINT col,V* v):m_Data(row,col)
    //{
    //    for(UINT i=0; i<row; ++i)
    //        for(UINT j=0; j<col; ++j)
    //            SetValue(i,j,*(v+(col*i+j)));
    //}

    // davinic: 1 x 1 matrix holding v.
    explicit TMatrix(const V& v):m_Data(1,1)
    {
        SetValue(0,0,v);
    }

    // row x col matrix with every element set to v.
    TMatrix(UINT row,UINT col,const V& v):m_Data(row,col)
    {
        for(UINT i=0; i<row; ++i)
            for(UINT j=0; j<col; ++j)
                SetValue(i,j,v);
    }

    //============================ access func ==================================//
    // Unchecked write of element (row, col).
    inline void SetValue(UINT row,UINT col,V value)
    {
        m_Data.m_pData[row*m_Data.m_uiCol+col]=value;
    }

    // Unchecked read of element (row, col).
    inline const V GetValue(UINT row,UINT col)const
    {
        return m_Data.m_pData[row*m_Data.m_uiCol+col];
    }

    // davinic: MATLAB-style linear indexing A(i), 0-based, column-major order.
    // Throws E_MATRIX_RANG if i >= GetRow()*GetCol(); this also prevents the
    // division by zero an empty matrix would otherwise cause.
    inline const V operator()(UINT i)
    {
        const UINT rows=GetRow();
        if(i>=rows*GetCol())
            throw E_MATRIX_RANG;
        return GetValue(i%rows,i/rows);
    }

    // MATLAB sum(A): returns a 1 x GetCol() row vector of column sums.
    TMatrix Sum()
    {
        TMatrix m(1,GetCol());
        for(UINT c=0;c<GetCol();++c)
        {
            // Accumulate in V (as operator* does) instead of double, so large
            // integral values keep full precision and any V with += works.
            V sum=0;
            for(UINT r=0;r<GetRow();++r)
                sum+=GetValue(r,c);
            m.SetValue(0,c,sum);
        }
        return m;
    }

    // MATLAB length(A): the larger of the two dimensions.
    inline UINT Length()
    {
        if(GetCol()>GetRow())
            return GetCol();
        else
            return GetRow();
    }

    // Debug print; intentionally a no-op (the iostream implementation below
    // is disabled so this header does not depend on <iostream>/<iomanip>).
    void Print()
    {
        //for(UINT r=0;r<GetRow();r++)
        //{
        //    for(UINT c=0;c<GetCol();c++)
        //    {
        //        cout<<setw(15)<<left<<GetValue(r,c);
        //    }
        //    cout<<endl;
        //}
        //cout<<endl;
    }

    //============================ attri func ==================================//
    inline const UINT GetCol(void)const
    {
        return m_Data.m_uiCol;
    }

    inline const UINT GetRow(void)const
    {
        return m_Data.m_uiRow;
    }

    //============================ operator func ==============================//
    TMatrix operator+(const TMatrix& right)
    {
        TMatrix m(*this);
        m+=right;
        return m;
    }

    // Adds the scalar right to every element.
    TMatrix operator+(const V& right)
    {
        TMatrix m(*this);
        for(UINT i=0; i<GetRow(); ++i)
            for(UINT j=0; j<GetCol(); ++j)
                m.SetValue(i,j,GetValue(i,j)+right);
        return m;
    }

    TMatrix operator-(const TMatrix& right)
    {
        TMatrix m(*this);
        m-=right;
        return m;
    }

    TMatrix operator*(const V gain)
    {
        TMatrix m(*this);
        m*=gain;
        return m;
    }

    TMatrix operator/(const V gain)
    {
        TMatrix m(*this);
        m/=gain;
        return m;
    }

    // Matrix product. Throws E_MATRIX_MULSIZE if GetCol() != right.GetRow().
    inline TMatrix operator*(const TMatrix& right)
    {
        if(!(GetCol()==right.GetRow()))
            throw E_MATRIX_MULSIZE;
        const UINT r_col_num=right.GetCol();
        const UINT col_num=GetCol();
        const UINT row_num=GetRow();
        TMatrix m(row_num,r_col_num);
        const V* p=m_Data.m_pData;
        const V* r_p=right.m_Data.m_pData;
        V* m_p=m.m_Data.m_pData;
        for(UINT r=0;r<row_num;r++)
        {
            const UINT block=r*r_col_num;
            for(UINT c=0;c<r_col_num;c++)
            {
                // value = sum_i this(r,i) * right(i,c), via raw buffers for speed.
                V value=0;
                for(UINT i=0;i<col_num;i++)
                {
                    value+=p[r*col_num+i]*r_p[i*r_col_num+c];
                }
                m_p[block+c]=value;
            }
        }
        return m;
    }

    // Element-wise +=. Throws E_MATRIX_DIFSIZE on size mismatch.
    const TMatrix& operator+=(const TMatrix& right)
    {
        if(!m_Data.SameSize(right.m_Data))
            throw E_MATRIX_DIFSIZE;
        for(UINT r=0;r<GetRow();r++)
            for(UINT c=0;c<GetCol();c++)
            {
                SetValue(r,c,right.GetValue(r,c)+GetValue(r,c));
            }
        return *this;
    }

    // Element-wise -=. Throws E_MATRIX_DIFSIZE on size mismatch.
    const TMatrix& operator-=(const TMatrix& right)
    {
        if(!m_Data.SameSize(right.m_Data))
            throw E_MATRIX_DIFSIZE;
        const UINT total=GetRow()*GetCol();
        V* p=m_Data.m_pData;
        const V* p_r=right.m_Data.m_pData;
        for(UINT i=0;i<total;i++)
        {
            p[i]=p[i]-p_r[i];
        }
        return *this;
    }

    // *this = *this * right (matrix product).
    inline const TMatrix& operator*=(const TMatrix& right)
    {
        // Name the product first: Attach() takes a non-const reference, which
        // standard C++ does not allow to bind to a member of a temporary.
        TMatrix product=(*this)*right;
        m_Data.Attach(product.m_Data);
        return *this;
    }

    // Element-wise (Hadamard) product in place. Throws E_MATRIX_DIFSIZE.
    const TMatrix& DotMul(const TMatrix& right)
    {
        if(!m_Data.SameSize(right.m_Data))
            throw E_MATRIX_DIFSIZE;
        for(UINT r=0;r<GetRow();r++)
            for(UINT c=0;c<GetCol();c++)
            {
                SetValue(r,c,GetValue(r,c)*right.GetValue(r,c));
            }
        return *this;
    }

    // Element-wise division in place. Throws E_MATRIX_DIFSIZE.
    const TMatrix& DotDiv(const TMatrix& right)
    {
        if(!m_Data.SameSize(right.m_Data))
            throw E_MATRIX_DIFSIZE;
        for(UINT r=0;r<GetRow();r++)
            for(UINT c=0;c<GetCol();c++)
            {
                SetValue(r,c,GetValue(r,c)/right.GetValue(r,c));
            }
        return *this;
    }

    const TMatrix &operator*=(const V gain)
    {
        for(UINT r=0;r<GetRow();r++)
            for(UINT c=0;c<GetCol();c++)
            {
                SetValue(r,c,GetValue(r,c)*gain);
            }
        return *this;
    }

    const TMatrix &operator/=(const V gain)
    {
        for(UINT r=0;r<GetRow();r++)
            for(UINT c=0;c<GetCol();c++)
            {
                SetValue(r,c,GetValue(r,c)/gain);
            }
        return *this;
    }

    // Replaces every element x with func(x).
    template<class Function>
    void InvokeFunction(const Function func)
    {
        for(UINT r=0;r<GetRow();r++)
            for(UINT c=0;c<GetCol();c++)
            {
                SetValue(r,c,func(GetValue(r,c)));
            }
    }

    // Swaps rows l and r. Throws E_MATRIX_RANG if either index is out of range.
    void SwapR(UINT l,UINT r)
    {
        if(l>=GetRow() || r>=GetRow())
            throw E_MATRIX_RANG;
        const UINT col=GetCol();
        // A row has GetCol() elements (the loop previously ran GetRow() times).
        for(UINT c=0;c<col;c++)
            swap(m_Data.m_pData[l*col+c],m_Data.m_pData[r*col+c]);
    }

    // Swaps columns l and r. Throws E_MATRIX_RANG if either index is out of range.
    void SwapC(UINT l,UINT r)
    {
        if(l>=GetCol() || r>=GetCol())
            throw E_MATRIX_RANG;
        const UINT row=GetRow();
        const UINT col=GetCol();
        for(UINT r_i=0;r_i<row;r_i++)
            swap(m_Data.m_pData[r_i*col+l],m_Data.m_pData[r_i*col+r]);
    }

    // Returns rows (bRow) or columns (!bRow) begin..end inclusive.
    // Throws E_INVALIDARG if begin > end, E_MATRIX_RANG if end is out of range.
    TMatrix SubMatrix(UINT begin,UINT end,BOOL bRow)
    {
        if(begin>end)
            throw E_INVALIDARG;
        const UINT count=end-begin+1;
        if(bRow)
        {
            if(end>=GetRow())
                throw E_MATRIX_RANG;
            TMatrix m(count,GetCol());
            for(UINT i=0;i<count;i++)
                for(UINT c=0;c<GetCol();c++)
                    m.SetValue(i,c,GetValue(begin+i,c));
            return m;
        }
        if(end>=GetCol())
            throw E_MATRIX_RANG;
        TMatrix m(GetRow(),count);
        for(UINT r=0;r<GetRow();r++)
            for(UINT i=0;i<count;i++)
                m.SetValue(r,i,GetValue(r,begin+i));
        return m;
    }

    // Returns a copy with rows (bRow) or columns (!bRow) begin..end inclusive
    // removed. Throws E_INVALIDARG if begin > end, E_MATRIX_RANG if end is out
    // of range.
    TMatrix SubMatrixWithout(UINT begin,UINT end,BOOL bRow)
    {
        if(begin>end)
            throw E_INVALIDARG;
        const UINT removed=end-begin+1;
        if(bRow)
        {
            // Validate before sizing the result: with end out of range,
            // GetRow()-removed would wrap around to a huge row count.
            if(end>=GetRow())
                throw E_MATRIX_RANG;
            TMatrix m(GetRow()-removed,GetCol());
            UINT i=0;
            for(UINT r=0;r<GetRow();r++)
            {
                if(r>=begin&&r<=end)
                    continue;
                for(UINT c=0;c<GetCol();c++)
                    m.SetValue(i,c,GetValue(r,c));
                i++;
            }
            return m;
        }
        if(end>=GetCol())
            throw E_MATRIX_RANG;
        TMatrix m(GetRow(),GetCol()-removed);
        for(UINT r=0;r<GetRow();r++)
        {
            UINT i=0;
            for(UINT c=0;c<GetCol();c++)
            {
                if(c>=begin&&c<=end)
                    continue;
                m.SetValue(r,i,GetValue(r,c));
                i++;
            }
        }
        return m;
    }

    // Returns a copy of *this with rows (bRow) or columns (!bRow) begin..end of
    // s inserted before row/column begin_s of *this.
    // Throws E_MATRIX_DIFSIZE if the other dimension differs, E_INVALIDARG if
    // begin > end or begin_s is past the end, E_MATRIX_RANG if end is outside s.
    TMatrix CombineMatrix(UINT begin_s,const TMatrix& s,UINT begin,UINT end,BOOL bRow)
    {
        if(bRow&&GetCol()!=s.GetCol())
            throw E_MATRIX_DIFSIZE;
        if(!bRow&&GetRow()!=s.GetRow())
            throw E_MATRIX_DIFSIZE;
        if(begin>end)
            throw E_INVALIDARG;
        const UINT num=end-begin+1;
        if(bRow)
        {
            if(begin_s>GetRow())
                throw E_INVALIDARG;
            if(s.GetRow()<num||s.GetRow()<=end)
                throw E_MATRIX_RANG;
            TMatrix m(GetRow()+num,GetCol());
            for(UINT r=0;r<m.GetRow();r++)
                for(UINT c=0;c<m.GetCol();c++)
                {
                    if(r<begin_s)                       // rows before the insertion point
                        m.SetValue(r,c,GetValue(r,c));
                    else if(r<begin_s+num)              // inserted rows from s
                        m.SetValue(r,c,s.GetValue(begin+(r-begin_s),c));
                    else                                // remaining rows, shifted down
                        m.SetValue(r,c,GetValue(r-num,c));
                }
            return m;
        }
        if(begin_s>GetCol())
            throw E_INVALIDARG;
        if(s.GetCol()<num||s.GetCol()<=end)
            throw E_MATRIX_RANG;
        TMatrix m(GetRow(),GetCol()+num);
        for(UINT r=0;r<m.GetRow();r++)
            for(UINT c=0;c<m.GetCol();c++)
            {
                if(c<begin_s)                           // columns before the insertion point
                    m.SetValue(r,c,GetValue(r,c));
                else if(c<begin_s+num)                  // inserted columns from s
                    m.SetValue(r,c,s.GetValue(r,begin+(c-begin_s)));
                else                                    // remaining columns, shifted right
                    m.SetValue(r,c,GetValue(r,c-num));
            }
        return m;
    }

    //============================ Convert & Brinv ==============================//
    // Transposes the matrix in place.
    void Convert(void)
    {
        const UINT col_num=m_Data.m_uiCol;
        const UINT row_num=m_Data.m_uiRow;
        TMatrix m(col_num,row_num);
        for(UINT r=0;r<row_num;r++)
            for(UINT c=0;c<col_num;c++)
            {
                // m(c,r) = this(r,c)
                m.m_Data.m_pData[c*row_num+r]=m_Data.m_pData[r*col_num+c];
            }
        m_Data.Attach(m.m_Data);
    }

    // Inverts the matrix in place (Gauss-Jordan elimination with full
    // pivoting). 2003/8/2: elements are accessed through the raw buffer for
    // speed. Throws E_MATRIX_BRINVABLE if not square, E_MATRIX_FAIL if the
    // largest remaining pivot is zero (singular matrix); in that case the
    // matrix contents are left partially transformed.
    void Brinv(void)
    {
        if(GetRow()!=GetCol())
            throw E_MATRIX_BRINVABLE;
        const UINT row_num=GetRow();
        // Row / column of the pivot chosen at each step, needed to undo the
        // interchanges at the end. std::vector replaces the former
        // auto_ptr<UINT>(new UINT[n]), which released an array with scalar
        // delete (undefined behavior) and was removed in C++17.
        std::vector<UINT> rc(row_num);
        std::vector<UINT> cc(row_num);
        V* p_data=m_Data.m_pData;
        for(UINT k=0;k<row_num;k++)
        {
            const UINT k_row=k*row_num;
            // Find the element of largest absolute value in the lower-right
            // sub-matrix [k, n) x [k, n).
            V value=0;
            for(UINT r=k;r<row_num;r++)
                for(UINT c=k;c<row_num;c++)
                {
                    V v=p_data[r*row_num+c];
                    if(fabs(v)>value)
                    {
                        rc[k]=r;
                        cc[k]=c;
                        value=fabs(v);
                    }
                }
            // If the largest magnitude is zero, the matrix has no inverse.
            if(fabs(value)==0.)
                throw E_MATRIX_FAIL;
            // Move the pivot to the top-left corner (k, k) of the sub-matrix.
            if(rc[k]!=k)
                ChangeRow(k,rc[k],row_num);
            if(cc[k]!=k)
                ChangeCol(k,cc[k],row_num);
            // a(k,k) = 1 / a(k,k)
            p_data[k_row+k]=1/p_data[k_row+k];
            // a(k,j) = a(k,j) * a(k,k)            for j != k
            for(UINT j=0;j<row_num;j++)
                if(j!=k)
                    p_data[k_row+j]=p_data[k_row+j]*p_data[k_row+k];
            // a(i,j) = a(i,j) - a(i,k) * a(k,j)   for i != k, j != k
            for(UINT i=0;i<row_num;i++)
            {
                if(i==k)
                    continue;
                const UINT i_row=i*row_num;
                for(UINT j=0;j<row_num;j++)
                    if(j!=k)
                        p_data[i_row+j]=p_data[i_row+j]-p_data[i_row+k]*p_data[k_row+j];
            }
            // a(i,k) = -a(i,k) * a(k,k)           for i != k
            for(UINT i=0;i<row_num;i++)
                if(i!=k)
                    p_data[i*row_num+k]=(-p_data[i*row_num+k]*p_data[k_row+k]);
        }
        // Undo the pivot interchanges in reverse order: a column swap of the
        // input becomes a row swap of the inverse, and vice versa.
        for(UINT i=row_num;i>0;i--)
        {
            const UINT k=i-1;
            if(cc[k]!=k)
                ChangeRow(k,cc[k],row_num);
            if(rc[k]!=k)
                ChangeCol(k,rc[k],row_num);
        }
    }

    //============================ Serialize ====================================//
    // Stores or loads the matrix through ar, which must provide IsStoring(),
    // operator<< and operator>> (presumably an MFC CArchive -- TODO confirm).
    // Layout: row count, column count, then the elements row by row.
    template <class T>
    void Serialize(T &ar)
    {
        UINT r,c;
        V val;
        if(ar.IsStoring())
        {
            ar<<GetRow();
            ar<<GetCol();
            for(r=0;r<GetRow();r++)
                for(c=0;c<GetCol();c++)
                {
                    val=GetValue(r,c);
                    ar<<val;
                }
        }
        else
        {
            ar>>r;
            ar>>c;
            TMatrix m(r,c);
            for(r=0;r<m.GetRow();r++)
                for(c=0;c<m.GetCol();c++)
                {
                    ar>>val;
                    m.SetValue(r,c,val);
                }
            m_Data.Attach(m.m_Data);
        }
    }

protected:
    // Swaps rows k and n of a square matrix of order row_num (no bounds check).
    inline void ChangeRow(UINT k,UINT n,UINT row_num)
    {
        V temp;
        const UINT k_row=k*row_num;
        const UINT n_row=n*row_num;
        for(UINT i=0;i<row_num;i++)
        {
            temp=m_Data.m_pData[k_row+i];
            m_Data.m_pData[k_row+i]=m_Data.m_pData[n_row+i];
            m_Data.m_pData[n_row+i]=temp;
        }
    }

    // Swaps columns k and n of a square matrix of order row_num (no bounds check).
    inline void ChangeCol(UINT k,UINT n,UINT row_num)
    {
        V temp;
        for(UINT i=0;i<row_num;i++)
        {
            const UINT i_row=i*row_num;
            temp=m_Data.m_pData[i_row+k];
            m_Data.m_pData[i_row+k]=m_Data.m_pData[i_row+n];
            m_Data.m_pData[i_row+n]=temp;
        }
    }

public:
    DATA m_Data;
};
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -