// matrix.cpp
#include "stdafx.h"
#include "Matrix.h"
#include <climits>
#ifdef _DEBUG
#undef THIS_FILE
static char THIS_FILE[]=__FILE__;
#define new DEBUG_NEW
#endif
//--------------------Default constructor--------------------
// Creates a 1 x 1 matrix whose single element is 0 (Init() zero-fills).
// In debug builds, asserts if Init() reports failure.
CMatrix::CMatrix()
{
    m_pData = NULL;             // Init() inspects m_pData, so clear it first
    m_nNumRows = 1;
    m_nNumColumns = 1;
    BOOL bInitOk = Init(1, 1);
    ASSERT(bInitOk);
}
//--------------------Constructor with explicit dimensions--------------------
// Creates an nRows x nCols matrix with every element set to 0.
// In debug builds, asserts if Init() rejects the dimensions.
CMatrix::CMatrix(int nRows, int nCols)
{
    m_pData = NULL;             // Init() inspects m_pData, so clear it first
    m_nNumRows = nRows;
    m_nNumColumns = nCols;
    BOOL bInitOk = Init(nRows, nCols);
    ASSERT(bInitOk);
}
//--------------------Constructor with dimensions and initial values--------------------
// Creates an nRows x nCols matrix and copies nRows*nCols doubles from value[]
// in row-major order (the layout GetElement/SetElement use).
//
// nRows, nCols : matrix dimensions, validated by Init().
// value        : source array of at least nRows*nCols doubles.
// In debug builds, asserts if Init() rejects the dimensions.
CMatrix::CMatrix(int nRows, int nCols, double value[])
{
    m_nNumRows = nRows;
    m_nNumColumns = nCols;
    m_pData = NULL;
    BOOL bSuccess = Init(m_nNumRows, m_nNumColumns);
    ASSERT(bSuccess);
    // ASSERT disappears in release builds: without this check a failed Init()
    // would make SetData() write through a NULL/invalid buffer.
    if (bSuccess)
        SetData(value);
}
//--------------------Square-matrix constructor--------------------
// Creates an nSize x nSize matrix with every element set to 0.
// In debug builds, asserts if Init() rejects the size.
// NOTE(review): unless the header declares this constructor `explicit`, an int
// converts implicitly to a CMatrix (e.g. `m == 3` compiles) — confirm in Matrix.h.
CMatrix::CMatrix(int nSize)
{
    m_pData = NULL;             // Init() inspects m_pData, so clear it first
    m_nNumColumns = nSize;
    m_nNumRows = nSize;
    BOOL bInitOk = Init(nSize, nSize);
    ASSERT(bInitOk);
}
//--------------------Square-matrix constructor with initial values--------------------
// Creates an nSize x nSize matrix and copies nSize*nSize doubles from value[]
// in row-major order.
//
// nSize : number of rows and columns, validated by Init().
// value : source array of at least nSize*nSize doubles.
// In debug builds, asserts if Init() rejects the size.
CMatrix::CMatrix(int nSize, double value[])
{
    m_nNumRows = nSize;
    m_nNumColumns = nSize;
    m_pData = NULL;
    BOOL bSuccess = Init(nSize, nSize);
    ASSERT(bSuccess);
    // ASSERT disappears in release builds: only copy when the buffer exists.
    if (bSuccess)
        SetData(value);
}
//--------------------Copy constructor--------------------
// Creates a deep copy of other: same dimensions, own buffer, same elements.
CMatrix::CMatrix(const CMatrix& other)
{
    m_pData = NULL;             // Init() inspects m_pData, so clear it first
    m_nNumRows = other.GetNumRows();
    m_nNumColumns = other.GetNumColumns();
    BOOL bInitOk = Init(m_nNumRows, m_nNumColumns);
    ASSERT(bInitOk);
    const size_t nBytes = sizeof(double) * m_nNumRows * m_nNumColumns;
    memcpy(m_pData, other.m_pData, nBytes);
}
//--------------------Destructor--------------------
// Releases the element buffer. delete[] on a NULL pointer is a no-op, so no
// separate NULL check is required.
CMatrix::~CMatrix()
{
    delete[] m_pData;
    m_pData = NULL;
}
//--------------------Initialization--------------------
// (Re)allocates the element buffer for an nRows x nCols matrix and sets every
// element to 0.
//
// nRows, nCols : requested dimensions; each must be >= 0, and nRows*nCols
//                doubles must be addressable without integer overflow.
// Returns TRUE on success. Returns FALSE for invalid or oversized dimensions
// (or if allocation yields NULL); in that case the matrix is left unchanged.
BOOL CMatrix::Init(int nRows, int nCols)
{
    // Check each dimension on its own: a product test such as
    // `nRows*nCols < 0` accepts two negative dimensions (-2 * -3 == 6).
    if (nRows < 0 || nCols < 0)
        return FALSE;

    // Signed overflow is undefined and could silently produce a buffer smaller
    // than the recorded dimensions (heap overflow on later element access).
    if (nRows != 0 && nCols > INT_MAX / nRows)
        return FALSE;
    const int nSize = nRows * nCols;
    if (static_cast<size_t>(nSize) > static_cast<size_t>(-1) / sizeof(double))
        return FALSE;

    // Allocate before releasing the old buffer, so that a failed/throwing
    // allocation never leaves a NULL buffer paired with non-zero dimensions.
    double* pNewData = new double[nSize];
    if (pNewData == NULL)
        return FALSE;   // only reachable with a non-throwing operator new
    memset(pNewData, 0, sizeof(double) * nSize);

    delete[] m_pData;
    m_pData = pNewData;
    m_nNumRows = nRows;
    m_nNumColumns = nCols;
    return TRUE;
}
//--------------------Set all elements--------------------
// Copies GetNumRows()*GetNumColumns() doubles from value[] into the matrix, in
// row-major order.
//
// value : source array of at least GetNumRows()*GetNumColumns() doubles.
//         A NULL pointer is ignored and leaves the matrix unchanged.
void CMatrix::SetData(double value[])
{
    if (value == NULL || m_pData == NULL)
        return;
    // The copy overwrites every element, so a prior zero-fill is redundant.
    // It was also harmful: when value aliased this matrix's own buffer (e.g.
    // m.SetData(m.GetData())), it wiped the data before copying. memmove is
    // used because that aliasing case is an overlapping copy.
    memmove(m_pData, value, sizeof(double) * m_nNumColumns * m_nNumRows);
}
//--------------------Set a single element--------------------
// Stores value at (nRow, nCol), using row-major layout.
// Returns FALSE, without writing anything, if the indices are out of range or
// the buffer is missing; otherwise returns TRUE.
BOOL CMatrix::SetElement(int nRow, int nCol, double value)
{
    const bool bInRange = nRow >= 0 && nRow < m_nNumRows &&
                          nCol >= 0 && nCol < m_nNumColumns;
    if (!bInRange || m_pData == NULL)
        return FALSE;
    m_pData[nRow * m_nNumColumns + nCol] = value;
    return TRUE;
}
//--------------------设置指定元素的值--------------------
double CMatrix::GetElement(int nRow, int nCol) const
{
ASSERT(nCol >= 0 && nCol < m_nNumColumns && nRow >= 0 && nRow < m_nNumRows);
ASSERT(m_pData); // bad pointer error
return m_pData[nCol + nRow * m_nNumColumns] ;
}
//--------------------获取矩阵的列数--------------------
int CMatrix::GetNumColumns() const
{
return m_nNumColumns;
}
//--------------------获取矩阵的行数--------------------
int CMatrix::GetNumRows() const
{
return m_nNumRows;
}
//--------------------Raw element buffer--------------------
// Returns the matrix's own row-major buffer of GetNumRows()*GetNumColumns()
// doubles. The matrix keeps ownership; Init() (and therefore assignment)
// frees the buffer, invalidating previously returned pointers.
// Note: the pointer is non-const although this member function is const.
double* CMatrix::GetData() const
{
    double* pBuffer = m_pData;
    return pBuffer;
}
//--------------------Assignment operator--------------------
// Replaces this matrix's dimensions and contents with a deep copy of other.
// Self-assignment is detected and leaves the matrix untouched.
CMatrix& CMatrix::operator=(const CMatrix& other)
{
    if (this == &other)
        return *this;
    BOOL bInitOk = Init(other.m_nNumRows, other.m_nNumColumns);
    ASSERT(bInitOk);
    memcpy(m_pData, other.m_pData, sizeof(double) * m_nNumColumns * m_nNumRows);
    return *this;
}
//--------------------Equality--------------------
// Returns TRUE when both matrices have the same shape and every pair of
// elements compares equal with exact floating-point ==. Because of that, any
// NaN element makes the matrices unequal.
BOOL CMatrix::operator==(const CMatrix& other) const
{
    // Different shapes are never equal.
    const bool bSameShape = m_nNumRows == other.GetNumRows() &&
                            m_nNumColumns == other.GetNumColumns();
    if (!bSameShape)
        return FALSE;

    // Scan row by row and stop at the first mismatching element.
    for (int nRow = 0; nRow < m_nNumRows; ++nRow)
    {
        for (int nCol = 0; nCol < m_nNumColumns; ++nCol)
        {
            if (!(GetElement(nRow, nCol) == other.GetElement(nRow, nCol)))
                return FALSE;
        }
    }
    return TRUE;
}
//--------------------Inequality--------------------
// Logical negation of operator==.
BOOL CMatrix::operator!=(const CMatrix& other) const
{
    return (*this == other) ? FALSE : TRUE;
}
//--------------------Matrix addition--------------------
// Returns the element-wise sum of *this and other.
// The shapes must match. This is checked only by ASSERT, so in release builds
// a mismatch silently reads wrong or out-of-range elements of other.
CMatrix CMatrix::operator+(const CMatrix& other) const
{
    ASSERT(m_nNumRows == other.GetNumRows() && m_nNumColumns == other.GetNumColumns());
    CMatrix sum(*this);         // start from a copy of the left operand
    for (int nRow = 0; nRow < m_nNumRows; ++nRow)
    {
        for (int nCol = 0; nCol < m_nNumColumns; ++nCol)
        {
            const double dLeft = sum.GetElement(nRow, nCol);
            const double dRight = other.GetElement(nRow, nCol);
            sum.SetElement(nRow, nCol, dLeft + dRight);
        }
    }
    return sum;
}
//--------------------Matrix subtraction--------------------
// Returns the element-wise difference *this - other.
// The shapes must match. This is checked only by ASSERT, so in release builds
// a mismatch silently reads wrong or out-of-range elements of other.
CMatrix CMatrix::operator-(const CMatrix& other) const
{
    ASSERT(m_nNumRows == other.GetNumRows() && m_nNumColumns == other.GetNumColumns());
    CMatrix diff(*this);        // start from a copy of the left operand
    for (int nRow = 0; nRow < m_nNumRows; ++nRow)
    {
        for (int nCol = 0; nCol < m_nNumColumns; ++nCol)
        {
            const double dLeft = diff.GetElement(nRow, nCol);
            const double dRight = other.GetElement(nRow, nCol);
            diff.SetElement(nRow, nCol, dLeft - dRight);
        }
    }
    return diff;
}
//--------------------Scalar multiplication--------------------
// Returns a copy of this matrix with every element multiplied by value.
CMatrix CMatrix::operator*(double value) const
{
    CMatrix scaled(*this);
    for (int nRow = 0; nRow < m_nNumRows; ++nRow)
    {
        for (int nCol = 0; nCol < m_nNumColumns; ++nCol)
        {
            const double dElem = scaled.GetElement(nRow, nCol);
            scaled.SetElement(nRow, nCol, dElem * value);
        }
    }
    return scaled;
}
//--------------------Matrix multiplication--------------------
// Returns the product (*this) * other, a GetNumRows() x other.GetNumColumns()
// matrix. Requires GetNumColumns() == other.GetNumRows(). This is checked only
// by ASSERT, so it is not checked in release builds.
CMatrix CMatrix::operator*(const CMatrix& other) const
{
    ASSERT(m_nNumColumns == other.GetNumRows());
    const int nOutRows = m_nNumRows;
    const int nOutCols = other.GetNumColumns();
    CMatrix product(nOutRows, nOutCols);

    for (int nRow = 0; nRow < nOutRows; ++nRow)
    {
        for (int nCol = 0; nCol < nOutCols; ++nCol)
        {
            // Dot product of row nRow of *this with column nCol of other,
            // accumulated in increasing index order.
            double dDot = 0.0;
            for (int nInner = 0; nInner < m_nNumColumns; ++nInner)
                dDot += GetElement(nRow, nInner) * other.GetElement(nInner, nCol);
            product.SetElement(nRow, nCol, dDot);
        }
    }
    return product;
}
//--------------------Transpose--------------------
// Returns a new GetNumColumns() x GetNumRows() matrix T with T(c, r) == A(r, c).
// This matrix is not modified.
CMatrix CMatrix::Transpose() const
{
    CMatrix transposed(m_nNumColumns, m_nNumRows);
    // Fill the result one output row (one source column) at a time.
    for (int nCol = 0; nCol < m_nNumColumns; ++nCol)
    {
        for (int nRow = 0; nRow < m_nNumRows; ++nRow)
            transposed.SetElement(nCol, nRow, GetElement(nRow, nCol));
    }
    return transposed;
}
//-----------------------------Generalized inverse via singular value decomposition---------------------------------
// Computes the generalized (pseudo-)inverse of this m x n matrix into mtxAP,
// which becomes n x m, using the factors produced by SplitUV().
//
// mtxAP : output; re-initialized to n x m.
// mtxU  : filled by SplitUV (initialized there to m x m).
// mtxV  : filled by SplitUV (initialized there to n x n). It is indexed as
//         V[l][i] below, so it presumably holds V transposed — TODO confirm
//         against SplitUV.
// eps   : tolerance passed through unchanged to SplitUV.
// Returns FALSE if SplitUV or mtxAP.Init fails, TRUE otherwise.
//
// Side effect: SplitUV overwrites this matrix's elements. Afterwards the code
// treats the diagonal of *this as the singular values and stops at the first
// exact 0.0, which presumes they are sorted in non-increasing order
// (assumption; verify against SplitUV).
BOOL CMatrix::GInvertUV(CMatrix& mtxAP, CMatrix& mtxU, CMatrix& mtxV, double eps)
{
    if (!SplitUV(mtxU, mtxV, eps))
        return FALSE;

    const int nRowsA = m_nNumRows;      // m
    const int nColsA = m_nNumColumns;   // n

    if (!mtxAP.Init(nColsA, nRowsA))
        return FALSE;

    // Count the leading diagonal entries that are exactly nonzero; only these
    // terms contribute to the sum below.
    const int nDiag = (nRowsA < nColsA) ? nRowsA : nColsA;
    int nUsed = 0;
    while (nUsed < nDiag && m_pData[nUsed * nColsA + nUsed] != 0.0)
        ++nUsed;

    // AP[r][c] = sum over s of V[s][r] * U[c][s] / S[s][s]
    for (int r = 0; r < nColsA; ++r)
    {
        for (int c = 0; c < nRowsA; ++c)
        {
            double acc = 0.0;
            for (int s = 0; s < nUsed; ++s)
            {
                const double sigma = m_pData[s * nColsA + s];
                acc = acc + mtxV.m_pData[s * nColsA + r] * mtxU.m_pData[c * nRowsA + s] / sigma;
            }
            mtxAP.m_pData[r * nRowsA + c] = acc;
        }
    }
    return TRUE;
}
BOOL CMatrix::SplitUV(CMatrix& mtxU, CMatrix& mtxV, double eps) //矩阵的奇异值分解实现
{
int i,j,k,l,it,ll,kk,ix,iy,mm,nn,iz,m1,ks;
double d,dd,t,sm,sm1,em1,sk,ek,b,c,shh,fg[2],cs[2];
double *s,*e,*w;
int m = m_nNumRows;
int n = m_nNumColumns;
// 初始化U, V矩阵
if (! mtxU.Init(m, m) || ! mtxV.Init(n, n))
return FALSE;
// 临时缓冲区
int ka = max(m, n) + 1;
s = new double[ka];
e = new double[ka];
w = new double[ka];
// 指定迭代次数为60
it=60;
k=n;
if (m-1<n)
k=m-1;
l=m;
if (n-2<m)
l=n-2;
if (l<0)
l=0;
// 循环迭代计算
ll=k;
if (l>k)
ll=l;
if (ll>=1)
{
for (kk=1; kk<=ll; kk++)
{
if (kk<=k)
{
d=0.0;
for (i=kk; i<=m; i++)
{
ix=(i-1)*n+kk-1;
d=d+m_pData[ix]*m_pData[ix];
}
s[kk-1]=sqrt(d);
if (s[kk-1]!=0.0)
{
ix=(kk-1)*n+kk-1;
if (m_pData[ix]!=0.0)
{
s[kk-1]=fabs(s[kk-1]);
if (m_pData[ix]<0.0)
s[kk-1]=-s[kk-1];
}
for (i=kk; i<=m; i++)
{
iy=(i-1)*n+kk-1;
m_pData[iy]=m_pData[iy]/s[kk-1];
}
m_pData[ix]=1.0+m_pData[ix];
}
s[kk-1]=-s[kk-1];
}
if (n>=kk+1)
{
for (j=kk+1; j<=n; j++)
{
if ((kk<=k)&&(s[kk-1]!=0.0))
{
d=0.0;
for (i=kk; i<=m; i++)
{
ix=(i-1)*n+kk-1;
iy=(i-1)*n+j-1;
d=d+m_pData[ix]*m_pData[iy];
}
d=-d/m_pData[(kk-1)*n+kk-1];
for (i=kk; i<=m; i++)
{
ix=(i-1)*n+j-1;
iy=(i-1)*n+kk-1;
m_pData[ix]=m_pData[ix]+d*m_pData[iy];
}
}
e[j-1]=m_pData[(kk-1)*n+j-1];
}
}
if (kk<=k)
{
for (i=kk; i<=m; i++)
{
ix=(i-1)*m+kk-1;
iy=(i-1)*n+kk-1;
mtxU.m_pData[ix]=m_pData[iy];
}
}
if (kk<=l)
{
d=0.0;
for (i=kk+1; i<=n; i++)
d=d+e[i-1]*e[i-1];
e[kk-1]=sqrt(d);
if (e[kk-1]!=0.0)
{
if (e[kk]!=0.0)
{
e[kk-1]=fabs(e[kk-1]);
if (e[kk]<0.0)
e[kk-1]=-e[kk-1];
}
for (i=kk+1; i<=n; i++)
e[i-1]=e[i-1]/e[kk-1];
e[kk]=1.0+e[kk];
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -