📄 matrix.cpp
字号:
// Matrix.cpp: implementation of the CMatrix class.
//
//////////////////////////////////////////////////////////////////////
#include "stdafx.h"
#include "Matrix.h"
#include <climits>
#ifdef _DEBUG
#undef THIS_FILE
static char THIS_FILE[]=__FILE__;
#define new DEBUG_NEW
#endif
//////////////////////////////////////////////////////////////////////
// Construction/Destruction CMatrix
//////////////////////////////////////////////////////////////////////
// Registers CMatrix with MFC's run-time class / serialization support:
// class CMatrix, base class CObject, schema number 1.
IMPLEMENT_SERIAL(CMatrix, CObject, 1) ;
#ifdef _DEBUG
// Debug-only counter; every constructor that numbers objects copies it into
// m_ObjectNumber and increments it, so TRACE output can identify each instance.
int CMatrix::m_NextObjectNumber = 1 ;
#endif
// Default constructor: builds a 1 x 1 matrix with its own zero-filled buffer.
CMatrix::CMatrix()
{
#ifdef _DEBUG
    // give this instance a unique serial number so TRACE output can tell objects apart
    m_ObjectNumber = m_NextObjectNumber++ ;
    TRACE("Creating CMatrix object %1d - default constructor\n", m_ObjectNumber) ;
#endif
    m_NumRows = 1 ;
    m_NumColumns = 1 ;
    // AllocateMemory zero-fills the buffer, including its reference-count slot
    m_pData = AllocateMemory(m_NumColumns, m_NumRows) ;
    // this object is the first user of the new buffer
    IncrementReferenceCount() ;
}
// Copy constructor: a shallow copy. The new object shares the other matrix's
// buffer instead of duplicating it, and the shared reference count is bumped.
CMatrix::CMatrix(const CMatrix &other)
{
#ifdef _DEBUG
    // give this instance a unique serial number so TRACE output can tell objects apart
    m_ObjectNumber = m_NextObjectNumber++ ;
    TRACE("Creating CMatrix object %1d - copy constructor other = %1d\n", m_ObjectNumber, other.m_ObjectNumber) ;
#endif
    m_NumRows = other.m_NumRows ;
    m_NumColumns = other.m_NumColumns ;
    m_pData = other.m_pData ;       // share, do not copy, the element storage
    // record the extra user; the destructor only frees the buffer once no
    // CMatrix refers to it any more
    IncrementReferenceCount() ;
}
// Size constructor: builds an nCols x nRows matrix whose elements are all 0.0.
// Both dimensions must be strictly positive (checked by ASSERT in debug builds).
CMatrix::CMatrix(int nCols, int nRows)
{
#ifdef _DEBUG
    // give this instance a unique serial number so TRACE output can tell objects apart
    m_ObjectNumber = m_NextObjectNumber++ ;
    TRACE("Creating CMatrix object %1d - size constructor\n", m_ObjectNumber) ;
#endif
    ASSERT(nCols > 0) ; // matrix size error
    ASSERT(nRows > 0) ; // matrix size error
    m_NumColumns = nCols ;
    m_NumRows = nRows ;
    // fresh zero-filled buffer owned (so far) only by this object
    m_pData = AllocateMemory(nCols, nRows) ;
    IncrementReferenceCount() ;
}
// Square constructor: builds a size x size matrix of zeros. When set_diagonal
// is true, the main diagonal is set to 1.0, giving the identity matrix.
CMatrix::CMatrix(int size, bool set_diagonal)
{
#ifdef _DEBUG
    // give this instance a unique serial number so TRACE output can tell objects apart
    m_ObjectNumber = m_NextObjectNumber++ ;
    TRACE("Creating CMatrix object %1d - square size constructor\n", m_ObjectNumber) ;
#endif
    ASSERT(size > 0) ; // matrix size error
    m_NumRows = size ;
    m_NumColumns = size ;
    m_pData = AllocateMemory(size, size) ;
    IncrementReferenceCount() ;
    // the buffer starts zeroed, so only the diagonal needs writing
    for (int d = 0 ; set_diagonal && d < size ; ++d)
        SetElement(d, d, 1.0) ;
}
// Creates a CMatrix from a VARIANT holding a 2-D SAFEARRAY of doubles (VT_R8).
// Dimension 1 of the array gives the column count, dimension 2 the row count.
// Throws a const char* message if the VARIANT is not a 2-D, non-empty double
// SAFEARRAY, or if its data cannot be accessed.
// The caller keeps ownership of the VARIANT: call "VariantClear" on it once
// you have finished with it.
CMatrix::CMatrix(VARIANT& var)
{
#ifdef _DEBUG
    // so we can recognise each individual object (the destructor TRACEs this number)
    m_ObjectNumber = m_NextObjectNumber++ ;
    TRACE("Creating CMatrix object %1d - SafeArray constructor\n", m_ObjectNumber) ;
#endif
    m_pData = NULL ;
    if ((var.vt & VT_ARRAY) == 0)
        throw "Not a SafeArray" ;
    // The element type must be exactly VT_R8. A bitwise test against VT_R8 (5)
    // would also accept types sharing its bits, e.g. VT_I4 (3) or VT_DATE (7).
    if ((var.vt & VT_TYPEMASK) != VT_R8)
        throw "Not a double SafeArray" ;
    // A by-reference VARIANT stores a pointer to the SAFEARRAY pointer.
    SAFEARRAY* psa = NULL ;
    if (var.vt & VT_BYREF)
        psa = (var.pparray != NULL) ? *var.pparray : NULL ;
    else
        psa = var.parray ;
    if (psa == NULL)
        throw "SafeArray access error" ;
    if (psa->cDims != 2)
        throw "SafeArray, incorrect number of dimensions" ;

    // get the bounds of the matrix in the safe array
    long lBound1, lBound2 ;
    long uBound1, uBound2 ;
    HRESULT hr ;
    hr = SafeArrayGetLBound(psa, 1, &lBound1) ;
    if (FAILED(hr))
        throw "SafeArray access error" ;
    hr = SafeArrayGetUBound(psa, 1, &uBound1) ;
    if (FAILED(hr))
        throw "SafeArray access error" ;
    hr = SafeArrayGetLBound(psa, 2, &lBound2) ;
    if (FAILED(hr))
        throw "SafeArray access error" ;
    hr = SafeArrayGetUBound(psa, 2, &uBound2) ;
    if (FAILED(hr))
        throw "SafeArray access error" ;
    const long nCols = uBound1 - lBound1 + 1 ;
    const long nRows = uBound2 - lBound2 + 1 ;
    if (nCols <= 0 || nRows <= 0)
        throw "SafeArray, empty matrix" ;

    // Lock the array data before allocating, so that every failure path
    // either throws before we own memory or unlocks the array again.
    double* pSource = NULL ;
    hr = SafeArrayAccessData(psa, reinterpret_cast<void**>(&pSource)) ;
    if (FAILED(hr))
        throw "SafeArray access error" ;
    try
    {
        m_NumColumns = static_cast<int>(nCols) ;
        m_NumRows = static_cast<int>(nRows) ;
        m_pData = AllocateMemory(m_NumColumns, m_NumRows) ;
        IncrementReferenceCount() ; // count the reference to this memory
        // i and j are already zero-based offsets into each dimension, so the
        // array's lower bounds must not be subtracted again. (Subtracting them
        // gave negative indices for 1-based arrays such as those created by VB.)
        // The element layout i * rows + j (dimension 2 varies fastest) is kept
        // from the original implementation.
        // NOTE(review): Automation SAFEARRAYs are normally column-major
        // (dimension 1 varies fastest); confirm this matches the code that
        // builds the arrays passed in.
        for (int i = 0 ; i < m_NumColumns ; ++i)
        {
            for (int j = 0 ; j < m_NumRows ; ++j)
                SetElement(i, j, pSource[i * m_NumRows + j]) ;
        }
    }
    catch (...)
    {
        SafeArrayUnaccessData(psa) ; // never leave the array locked
        throw ;
    }
    pSource = NULL ; // no longer valid after unaccess
    SafeArrayUnaccessData(psa) ; // release the safe array data pointer
}
// Destructor: gives up this object's reference to the element buffer.
// DecrementAndRelease frees the buffer only when no other CMatrix still uses it.
CMatrix::~CMatrix()
{
#ifdef _DEBUG
    TRACE("Destroying CMatrix object %1d\n", m_ObjectNumber) ;
#endif
    DecrementAndRelease() ;
}
#ifdef _DEBUG
// Debug-only diagnostic. Writes the header fields and then every element,
// one row per line, to the debugger output through TRACE.
// The CDumpContext argument is not used.
void CMatrix::Dump(CDumpContext& dc) const
{
    UNUSED_PARAMETER(dc) ;
    TRACE("CMatrix object #%1d\n", m_ObjectNumber) ;
    TRACE("Num columns : %1d\n", m_NumColumns) ;
    TRACE("Num rows : %1d\n", m_NumRows) ;
    // %p rather than %lx: a pointer does not fit in a long on 64-bit Windows (LLP64).
    TRACE("Data pointer : %p\n", m_pData) ;
    TRACE("Reference count : %1d\n", GetReferenceCount()) ;
    for (int i = 0 ; i < m_NumRows ; ++i)
    {
        TRACE("Row %2d,", i) ;
        for (int j = 0 ; j < m_NumColumns ; ++j)
        {
            TRACE("%e,", GetElement(j, i)) ;
        }
        TRACE("\n") ;
        Sleep(m_NumColumns *2) ; // this is to allow all element data to be traced for very large matrices!
    }
}

// Debug-only invariant check: positive dimensions and a readable buffer of
// (columns * rows + 1) doubles (the +1 is the reference-count slot).
void CMatrix::AssertValid() const
{
    CObject::AssertValid() ; // MFC convention: validate the base part first
    ASSERT(m_NumColumns > 0) ; // matrix size error
    ASSERT(m_NumRows > 0) ; // matrix size error
    ASSERT(m_pData) ; // bad pointer error
    ASSERT(FALSE == IsBadReadPtr(m_pData, sizeof(double) * (m_NumColumns * m_NumRows + 1))) ;
}
#endif
// Reads or writes the matrix through an MFC archive.
//
// Archive format, version 1:
//   long   0x434d6174, long 0x72697843  - magic header, ASCII "CMat" + "rixC"
//   int    version (currently 1)
//   int    column count, int row count
//   double elements: column index in the outer loop, row index in the inner loop
//
// When loading, throws CArchiveException with one of these causes:
//   badClass  - wrong magic header
//   badSchema - unsupported version
//   badIndex  - non-positive dimensions (corrupt or hostile file)
void CMatrix::Serialize(CArchive& archive)
{
    const long headerPart1 = 0x434d6174 ;
    const long headerPart2 = 0x72697843 ;
    const int currentVersion = 1 ; // serialization format version number

    CObject::Serialize(archive) ;
    if (archive.IsStoring())
    {
        archive << headerPart1 ;
        archive << headerPart2 ;
        archive << currentVersion ;
        archive << m_NumColumns ;
        archive << m_NumRows ;
        // this could be done with a archive.Write(m_pData, sizeof(double) * m_NumColumns * m_NumRows)
        // for efficiency (dont forget its a flat array). Not done here for clarity
        for (int i = 0 ; i < m_NumColumns ; ++i)
        {
            for (int j = 0 ; j < m_NumRows ; ++j)
            {
                archive << GetElement(i, j) ;
            }
        }
    }
    else
    {
        long header1 = 0 ;
        long header2 = 0 ;
        archive >> header1 ;
        archive >> header2 ;
        if (header1 != headerPart1 || header2 != headerPart2)
        {
            // incorrect header, cannot load it
            AfxThrowArchiveException(CArchiveException::badClass, NULL) ;
        }
        int version = 0 ;
        archive >> version ;
        // A real check, not just an ASSERT: release builds must not misread
        // data written in a format this code does not understand.
        if (version != currentVersion)
            AfxThrowArchiveException(CArchiveException::badSchema, NULL) ;
        // now read in the actual matrix
        int nCols = 0 ;
        int nRows = 0 ;
        archive >> nCols ;
        archive >> nRows ;
        // The dimensions come from the file, so validate them before allocating.
        if (nCols <= 0 || nRows <= 0)
            AfxThrowArchiveException(CArchiveException::badIndex, NULL) ;
        CMatrix loading(nCols, nRows) ;
        double value ;
        for (int i = 0 ; i < nCols ; ++i)
        {
            for (int j = 0 ; j < nRows ; ++j)
            {
                archive >> value ;
                loading.SetElement(i, j, value) ;
            }
        }
        // Assign only after everything has been read, so a failed load
        // leaves *this unchanged.
        *this = loading ;
    }
}
// Allocates one zero-filled heap block of (nCols * nRows + 1) doubles.
// The extra slot holds the reference count, which therefore starts at 0.
// The caller becomes responsible for the block.
// Throws a const char* message if a dimension is not positive, or if the
// element count would overflow an int (elements are indexed with int
// arithmetic throughout the class) or a size_t byte count.
double* CMatrix::AllocateMemory(int nCols, int nRows)
{
    ASSERT(nCols > 0) ; // size error
    ASSERT(nRows > 0) ; // size error
    // The ASSERTs vanish in release builds, so check explicitly instead of
    // letting a negative or wrapped-around count reach new[] and memset.
    if (nCols <= 0 || nRows <= 0)
        throw "Invalid matrix size" ;
    if (nRows > (INT_MAX - 1) / nCols)
        throw "Matrix too large" ;
    const int count = nCols * nRows + 1 ; // all in one allocation (+1 for reference count)
    if (static_cast<size_t>(count) > static_cast<size_t>(-1) / sizeof(double))
        throw "Matrix too large" ;
    const size_t bytes = sizeof(double) * static_cast<size_t>(count) ;

    double *pData = new double[count] ;
    ASSERT(pData != NULL) ; // allocation error
    ASSERT(FALSE == IsBadReadPtr(pData, bytes)) ;
    memset(pData, 0, bytes) ; // starts with a 0 reference count
    return pData ;
}
// Assignment: a shallow copy, like the copy constructor. *this drops its
// current buffer and then shares the other matrix's buffer.
CMatrix& CMatrix::operator=(const CMatrix &other)
{
    // Self-assignment must be a no-op: releasing our own buffer first could
    // free it before we tried to share it again.
    if (this != &other)
    {
        // give up our current buffer (freed if we were its last user)...
        DecrementAndRelease() ;
        // ...then adopt the other matrix's shape and buffer, counting the new user
        m_NumRows = other.m_NumRows ;
        m_NumColumns = other.m_NumColumns ;
        m_pData = other.m_pData ;
        IncrementReferenceCount() ;
    }
    return *this ;
}
// Exact equality: same dimensions and bit-identical element storage.
// Because the comparison is bitwise, 0.0 and -0.0 compare unequal, and two
// NaNs with the same bit pattern compare equal.
bool CMatrix::operator==(const CMatrix &other) const
{
    // the same object, or two objects sharing one buffer, are trivially equal
    if (this == &other || m_pData == other.m_pData)
        return true ;
    if (m_NumColumns != other.m_NumColumns || m_NumRows != other.m_NumRows)
        return false ;
    // NOTE(review): compares the first columns*rows doubles at m_pData; this
    // assumes the reference-count slot lies outside that range. Confirm
    // against GetElement/GetReferenceCount in Matrix.h.
    const size_t elementBytes = sizeof(double) * m_NumColumns * m_NumRows ;
    return memcmp(m_pData, other.m_pData, elementBytes) == 0 ;
}
// Element-wise sum of two matrices with identical dimensions.
// Throws "Invalid operation" (const char*) if the dimensions differ.
// Neither operand is modified.
CMatrix CMatrix::operator+(const CMatrix &other) const
{
    // first check for a valid addition operation
    if (m_NumColumns != other.m_NumColumns)
        throw "Invalid operation" ;
    if (m_NumRows != other.m_NumRows)
        throw "Invalid operation" ;
    ASSERT(FALSE == IsBadReadPtr(other.m_pData, sizeof(double) * other.m_NumColumns * other.m_NumRows)) ;
    // The result gets its own buffer rather than starting as a copy of *this.
    // The copy constructor shares *this's buffer, so writing into such a copy
    // is only safe if SetElement detaches shared buffers first.
    // NOTE(review): confirm SetElement's behaviour in Matrix.h.
    CMatrix result(m_NumColumns, m_NumRows) ;
    for (int i = 0 ; i < m_NumColumns ; ++i)
    {
        for (int j = 0 ; j < m_NumRows ; ++j)
            result.SetElement(i, j, GetElement(i, j) + other.GetElement(i, j)) ;
    }
    return result ;
}
// Element-wise difference (this - other) of two matrices with identical
// dimensions. Throws "Invalid operation" (const char*) if the dimensions
// differ. Neither operand is modified.
CMatrix CMatrix::operator-(const CMatrix &other) const
{
    // first check for a valid subtraction operation
    if (m_NumColumns != other.m_NumColumns)
        throw "Invalid operation" ;
    if (m_NumRows != other.m_NumRows)
        throw "Invalid operation" ;
    ASSERT(FALSE == IsBadReadPtr(other.m_pData, sizeof(double) * other.m_NumColumns * other.m_NumRows)) ;
    // The result gets its own buffer rather than starting as a copy of *this.
    // The copy constructor shares *this's buffer, so writing into such a copy
    // is only safe if SetElement detaches shared buffers first.
    // NOTE(review): confirm SetElement's behaviour in Matrix.h.
    CMatrix result(m_NumColumns, m_NumRows) ;
    for (int i = 0 ; i < m_NumColumns ; ++i)
    {
        for (int j = 0 ; j < m_NumRows ; ++j)
            result.SetElement(i, j, GetElement(i, j) - other.GetElement(i, j)) ;
    }
    return result ;
}
// Matrix product. Requires this->m_NumRows == other.m_NumColumns, otherwise
// throws "Matrices do not have common size" (const char*).
// Result dimensions: (this->m_NumColumns, other.m_NumRows).
// Each element is result(a, b) = sum over k of this(a, k) * other(k, b),
// using the (first, second) index order of GetElement/SetElement.
// NOTE(review): this is the textbook product if the first index is read as
// the row. Dump() reads the first index as the column, and under that
// reading this computes other x this. Confirm the intended convention
// before relying on operand order.
//
// e.g.
// [A][B][C]   [G][H]   [A*G + B*I + C*K][A*H + B*J + C*L]
// [D][E][F] * [I][J] = [D*G + E*I + F*K][D*H + E*J + F*L]
//             [K][L]
CMatrix CMatrix::operator*(const CMatrix &other) const
{
    if (m_NumRows != other.m_NumColumns)
        throw "Matrices do not have common size" ;
    ASSERT(FALSE == IsBadReadPtr(other.m_pData, sizeof(double) * other.m_NumColumns * other.m_NumRows)) ;
    CMatrix product(m_NumColumns, other.m_NumRows) ;
    const int shared = m_NumRows ; // length of the summed-over index
    for (int a = 0 ; a < product.m_NumColumns ; ++a)
    {
        for (int b = 0 ; b < product.m_NumRows ; ++b)
        {
            double sum = 0.0 ;
            for (int k = 0 ; k < shared ; ++k)
                sum += GetElement(a, k) * other.GetElement(k, b) ;
            product.SetElement(a, b, sum) ;
        }
    }
    return product ;
}
// In-place element-wise addition: *this += other.
// Throws "Invalid operation" (const char*) unless both matrices have the
// same number of columns and rows.
void CMatrix::operator+=(const CMatrix &other)
{
    if (m_NumColumns != other.m_NumColumns || m_NumRows != other.m_NumRows)
        throw "Invalid operation" ;
    ASSERT(FALSE == IsBadReadPtr(other.m_pData, sizeof(double) * other.m_NumColumns * other.m_NumRows)) ;
    for (int c = 0 ; c < m_NumColumns ; ++c)
        for (int r = 0 ; r < m_NumRows ; ++r)
            SetElement(c, r, GetElement(c, r) + other.GetElement(c, r)) ;
}
// In-place element-wise subtraction: *this -= other.
// Throws "Invalid operation" (const char*) unless both matrices have the
// same number of columns and rows.
void CMatrix::operator-=(const CMatrix &other)
{
    if (m_NumColumns != other.m_NumColumns || m_NumRows != other.m_NumRows)
        throw "Invalid operation" ;
    ASSERT(FALSE == IsBadReadPtr(other.m_pData, sizeof(double) * other.m_NumColumns * other.m_NumRows)) ;
    for (int c = 0 ; c < m_NumColumns ; ++c)
        for (int r = 0 ; r < m_NumRows ; ++r)
            SetElement(c, r, GetElement(c, r) - other.GetElement(c, r)) ;
}
// In-place matrix product: *this = *this * other.
// Throws "Matrices do not have common size" (const char*) if the shapes are
// incompatible (the same condition operator* checks).
void CMatrix::operator*=(const CMatrix &other)
{
    if (other.m_NumColumns != m_NumRows)
        throw "Matrices do not have common size" ;
    // operator* builds a new matrix; assignment then makes *this share its buffer
    *this = *this * other ;
}
// In-place scalar multiplication: every element is multiplied by value.
void CMatrix::operator*=(double value)
{
    for (int c = 0 ; c < m_NumColumns ; ++c)
        for (int r = 0 ; r < m_NumRows ; ++r)
            SetElement(c, r, GetElement(c, r) * value) ;
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -