📄 matrix.cpp
字号:
//matrix.cpp
#include "stdafx.h"
#include "matrix.h"
#ifdef _DEBUG
#define new DEBUG_NEW
#undef THIS_FILE
static char THIS_FILE[] = __FILE__;
#endif
////////////////////////////////////////////////////////////////////////////
//默认构造
CMatrix::CMatrix()
{
m_nNumColumns = 1;
m_nNumRows = 1;
m_pData = NULL;
Init(m_nNumRows,m_nNumColumns);
}
////////////////////////////////////////////////////////////////////////////
//指定行列
CMatrix::CMatrix(int nRows,int nCols)
{
m_nNumRows = nRows;
m_nNumColumns = nCols;
m_pData = NULL;
Init(m_nNumRows,m_nNumColumns);
}
////////////////////////////////////////////////////////////////////////////
//指定值
CMatrix::CMatrix(int nRows,int nCols,double value[])
{
m_nNumRows = nRows;
m_nNumColumns = nCols;
m_pData = NULL;
Init(m_nNumRows,m_nNumColumns);
SetData(value);
}
////////////////////////////////////////////////////////////////////////////
//方阵
CMatrix::CMatrix(int nSize)
{
m_nNumRows = nSize;
m_nNumColumns = nSize;
m_pData = NULL;
Init(nSize,nSize);
}
////////////////////////////////////////////////////////////////////////////
//方阵
CMatrix::CMatrix(int nSize,double value[])
{
m_nNumRows = nSize;
m_nNumColumns = nSize;
m_pData = NULL;
Init(nSize,nSize);
SetData(value);
}
////////////////////////////////////////////////////////////////////////////
//拷贝构造
CMatrix::CMatrix(const CMatrix& other)
{
m_nNumRows = other.GetNumRows();
m_nNumColumns = other.GetNumColumns();
m_pData = NULL;
Init(m_nNumRows,m_nNumColumns);
//copy the pointer
memcpy(m_pData,other.m_pData,sizeof( double) * m_nNumRows * m_nNumColumns);
}
///////////////////////////////////////////////////////////////////////////
//初始化,分配矩阵数据的内存,并全部置0
BOOL CMatrix::Init(int nRows,int nCols)
{
if(m_pData)
{
delete[] m_pData;
m_pData = NULL;
}
m_nNumRows = nRows;
m_nNumColumns = nCols;
int nSize = nCols * nRows;
if(nSize < 0)
return FALSE;
//分配内存
m_pData = new double[nSize];
if(m_pData == NULL)
return FALSE;
if(IsBadReadPtr(m_pData,sizeof(double) * nSize))
return FALSE;
//各个元素置零
memset(m_pData,0,sizeof(double) * nSize);
return TRUE;
}
////////////////////////////////////////////////////////////////////////////
//将方阵初始化为方阵
BOOL CMatrix::MakeUnitMatrix(int nSize)
{
if(! Init(nSize,nSize))
return FALSE;
for(int i = 0; i < nSize; ++i)
for(int j = 0; j < nSize; ++j)
if(i == j) SetElement(i,j,1);
return TRUE;
}
////////////////////////////////////////////////////////////////////////////
//析构函数
CMatrix::~CMatrix()
{
if(m_pData)
{
delete[] m_pData;
m_pData = NULL;
}
}
////////////////////////////////////////////////////////////////////////////
//设置矩阵个元素的值
void CMatrix::SetData(double value[])
{
//empty the memory
memset(m_pData,0,sizeof(double) * m_nNumRows * m_nNumColumns);
//copy data
memcpy(m_pData,value,sizeof(double) * m_nNumRows * m_nNumColumns);
}
////////////////////////////////////////////////////////////////////////////
//设置指定元素的值
BOOL CMatrix::SetElement(int nRow,int nCol,double value)
{
//array bounds error
if(nCol < 0 || nCol >= m_nNumColumns || nRow < 0 || nRow >= m_nNumRows)
return FALSE;
if(m_pData == NULL)
return FALSE;
m_pData[nCol + nRow * m_nNumColumns] = value;
return TRUE;
}
////////////////////////////////////////////////////////////////////////////
//得到知道元素的值
double CMatrix::GetElement(int nRow,int nCol) const
{
//array bounds error
ASSERT(nCol >= 0 || nCol <= m_nNumColumns || nRow >= 0 || nRow <= m_nNumRows);
ASSERT(m_pData);
return m_pData[nCol + nRow * m_nNumColumns];
}
////////////////////////////////////////////////////////////////////////////
//获取矩阵的列数
int CMatrix::GetNumColumns() const
{
return m_nNumColumns;
}
////////////////////////////////////////////////////////////////////////////
//获取矩阵的行数
int CMatrix::GetNumRows() const
{
return m_nNumRows;
}
////////////////////////////////////////////////////////////////////////////
//获取矩阵的数据
double* CMatrix::GetData() const
{
return m_pData;
}
////////////////////////////////////////////////////////////////////////////
//获取指定行的向量
int CMatrix::GetRowVector(int nRow,double* pVector) const
{
if(pVector == NULL)
delete pVector;
pVector = new double[m_nNumRows];
ASSERT(pVector != NULL);
for(int j = 0; j< m_nNumRows; ++j)
pVector[j] = GetElement(nRow,j);
return m_nNumRows;
}
////////////////////////////////////////////////////////////////////////////
//获取指定列的向量
int CMatrix::GetColVector(int nCol,double* pVector) const
{
if(pVector == NULL)
delete pVector;
pVector = new double[m_nNumColumns];
ASSERT(pVector != NULL);
for(int j = 0; j< m_nNumColumns; ++j)
pVector[j] = GetElement(nCol,j);
return m_nNumColumns;
}
//数学运算符重载
////////////////////////////////////////////////////////////////////////////
//重载运算符 =
CMatrix& CMatrix::operator=(const CMatrix& other)
{
if(& other != this)
{
m_nNumColumns = other.GetNumColumns();
m_nNumRows = other.GetNumRows();
Init(m_nNumRows,m_nNumColumns);
//copy the pointer
memcpy(m_pData,other.m_pData,sizeof(double) * m_nNumColumns * m_nNumRows);
}
return *this;
}
////////////////////////////////////////////////////////////////////////////
//重载运算符 ==
BOOL CMatrix::operator==(const CMatrix& other) const
{
//行数和列数是否相等
if(m_nNumRows != other.GetNumRows() || m_nNumColumns != other.GetNumColumns())
return FALSE;
for(int i=0;i<m_nNumRows;++i)
for(int j=0;j<m_nNumColumns; ++j)
{
if(GetElement(i,j) != other.GetElement(i,j))
return FALSE;
}
return TRUE;
}
////////////////////////////////////////////////////////////////////////////
//重载运算符 !=
BOOL CMatrix::operator !=(const CMatrix& other) const
{
return !(*this == other);
}
////////////////////////////////////////////////////////////////////////////
//重载运算符 +
CMatrix CMatrix::operator +(const CMatrix& other) const
{
ASSERT(m_nNumRows == other.GetNumRows() && m_nNumColumns == other.GetNumColumns());
CMatrix result(*this); //copy
//add
for(int i = 0; i< m_nNumRows; ++i)
{
for(int j = 0; j < m_nNumColumns; ++j)
result.SetElement(i,j,result.GetElement(i,j) +other.GetElement(i,j));
}
return result;
}
////////////////////////////////////////////////////////////////////////////
//重载运算符 -
CMatrix CMatrix::operator -(const CMatrix& other) const
{
ASSERT(m_nNumRows == other.GetNumRows() && m_nNumColumns == other.GetNumColumns());
CMatrix result(*this); //copy
//add
for(int i = 0; i< m_nNumRows; ++i)
{
for(int j = 0; j < m_nNumColumns; ++j)
result.SetElement(i,j,result.GetElement(i,j) - other.GetElement(i,j));
}
return result;
}
////////////////////////////////////////////////////////////////////////////
//重载运算符 数乘*
CMatrix CMatrix::operator *(double value) const
{
CMatrix result(*this); //copy
//add
for(int i = 0; i< m_nNumRows; ++i)
{
for(int j = 0; j < m_nNumColumns; ++j)
result.SetElement(i,j,result.GetElement(i,j) * value);
}
return result;
}
////////////////////////////////////////////////////////////////////////////
//重载运算符 矩阵相乘*
CMatrix CMatrix::operator *(const CMatrix& other) const
{
ASSERT(m_nNumColumns == other.GetNumRows());
//restruct the object we are going to return
CMatrix result(m_nNumRows,other.GetNumColumns());
//mult
double value;
for(int i = 0; i< result.GetNumRows(); ++i)
{
for(int j = 0; j < other.GetNumColumns(); ++j)
{
value = 0.0;
for(int k = 0 ;k < m_nNumColumns; ++k)
{
value += GetElement(i,k) * other.GetElement(k,j);
}
result.SetElement(i,j,value);
}
}
return result;
}
////////////////////////////////////////////////////////////////////////////
//矩阵转置
CMatrix CMatrix::Transpose() const
{
//aim matrix
CMatrix Trans(m_nNumRows,m_nNumColumns);
//trans each element
for( int i = 0; i< m_nNumRows; ++i)
{
for( int j = 0; j< m_nNumColumns; ++j)
Trans.SetElement(j,i,GetElement(i,j));
}
return Trans;
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -