⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 matrix.cpp

📁 最新收集的老外写的一个矩阵类
💻 CPP
📖 第 1 页 / 共 3 页
字号:
// Matrix.cpp: implementation of the CMatrix class.
//
//////////////////////////////////////////////////////////////////////

#include "stdafx.h"
#include "Matrix.h"

#ifdef _DEBUG
#undef THIS_FILE
static char THIS_FILE[]=__FILE__;
#define new DEBUG_NEW
#endif

//////////////////////////////////////////////////////////////////////
// Construction/Destruction CMatrix
//////////////////////////////////////////////////////////////////////
// Register CMatrix (derived from CObject, schema number 1) with MFC's run-time
// class / serialization support so it can be stored and loaded via CArchive
// (see CMatrix::Serialize). The trailing ';' is just an empty declaration.
IMPLEMENT_SERIAL(CMatrix, CObject, 1) ;

#ifdef _DEBUG
// Debug-only counter: each constructor copies it into m_ObjectNumber and then
// increments it, so every CMatrix gets a distinct number (starting at 1) for TRACE output.
int CMatrix::m_NextObjectNumber = 1 ;
#endif

// Default constructor: builds a 1 x 1 matrix with its own zero-filled buffer
// (AllocateMemory clears the memory) and registers this object as a user of it.
CMatrix::CMatrix()
{
#ifdef _DEBUG
	// tag this instance with a unique number so the TRACE output can tell objects apart
	m_ObjectNumber = m_NextObjectNumber++ ;
	TRACE("Creating CMatrix object %1d - default constructor\n", m_ObjectNumber) ;
#endif
	m_NumRows = 1 ;
	m_NumColumns = 1 ;
	m_pData = AllocateMemory(m_NumColumns, m_NumRows) ;
	IncrementReferenceCount() ;			// this object now refers to the new buffer
}

// Copy constructor: a shallow copy. The new object shares the other matrix's
// element buffer (no elements are duplicated) and adds one more reference to it.
CMatrix::CMatrix(const CMatrix &other)
{
#ifdef _DEBUG
	// tag this instance with a unique number so the TRACE output can tell objects apart
	m_ObjectNumber = m_NextObjectNumber++ ;
	TRACE("Creating CMatrix object %1d - copy constructor other = %1d\n", m_ObjectNumber, other.m_ObjectNumber) ;
#endif
	m_pData = other.m_pData ;				// share, don't duplicate
	m_NumRows = other.m_NumRows ;
	m_NumColumns = other.m_NumColumns ;
	IncrementReferenceCount() ;				// one more CMatrix now uses this buffer
}

// Size constructor: builds an nCols x nRows matrix whose elements are all zero
// (AllocateMemory clears the buffer). Both sizes must be positive; this is
// only checked by ASSERT, i.e. in debug builds.
CMatrix::CMatrix(int nCols, int nRows)
{
#ifdef _DEBUG
	// tag this instance with a unique number so the TRACE output can tell objects apart
	m_ObjectNumber = m_NextObjectNumber++ ;
	TRACE("Creating CMatrix object %1d - size constructor\n", m_ObjectNumber) ;
#endif
	ASSERT(nCols > 0) ;
	ASSERT(nRows > 0) ;
	m_NumColumns = nCols ;
	m_NumRows = nRows ;
	m_pData = AllocateMemory(nCols, nRows) ;
	IncrementReferenceCount() ;				// this object now refers to the new buffer
}

// Square constructor: builds a size x size zero matrix. When set_diagonal is
// true, every element (i, i) is then set to 1.0, giving the identity matrix.
// size must be positive; this is only checked by ASSERT (debug builds).
CMatrix::CMatrix(int size, bool set_diagonal)
{
#ifdef _DEBUG
	// tag this instance with a unique number so the TRACE output can tell objects apart
	m_ObjectNumber = m_NextObjectNumber++ ;
	TRACE("Creating CMatrix object %1d - square size constructor\n", m_ObjectNumber) ;
#endif
	ASSERT(size > 0) ;
	m_NumColumns = size ;
	m_NumRows = size ;
	m_pData = AllocateMemory(size, size) ;
	IncrementReferenceCount() ;				// this object now refers to the new buffer

	if (!set_diagonal)
		return ;							// leave the zero-filled matrix as it is

	int		diag = 0 ;
	while (diag < size)
		{
		SetElement(diag, diag, 1.0) ;
		++diag ;
		}
}

// Creates a CMatrix object from a VARIANT that holds a 2D SafeArray of doubles
// (VT_ARRAY | VT_R8, optionally combined with VT_BYREF). Any lower bounds are
// accepted. SafeArray dimension 1 supplies the column count and dimension 2 the
// row count.
// Element (col, row) is read from zero-based offset col * nRows + row of the
// locked SafeArray data. NOTE(review): this keeps the original assumption about
// the SafeArray memory layout; confirm it matches how callers build their arrays.
// Throws a const char* message if the variant is not a 2D double SafeArray, if
// a dimension is empty, or if a SafeArray API call fails.
// The VARIANT itself is neither modified nor freed. You will probably still
// have to call "VariantClear" to de-allocate the safe array once you have
// finished with it.
CMatrix::CMatrix(VARIANT& var)
{
#ifdef _DEBUG
	// so we can recognise each individual object; the destructor traces this
	// number, so every constructor must set it
	m_ObjectNumber = m_NextObjectNumber++ ;
	TRACE("Creating CMatrix object %1d - SafeArray constructor\n", m_ObjectNumber) ;
#endif
	m_pData = NULL ;

	if ((var.vt & VT_ARRAY) == 0)
		throw "Not a SafeArray" ;
	// VT_R8 is an element type code, not a flag bit. Masking with it alone also
	// let e.g. VT_I4 or VT_R4 arrays through, so compare the whole type field.
	if ((var.vt & VT_TYPEMASK) != VT_R8)
		throw "Not a double SafeArray" ;

	// a VT_BYREF variant stores a pointer to the SAFEARRAY pointer
	SAFEARRAY*		psa = NULL ;
	if (var.vt & VT_BYREF)
		psa = (var.pparray != NULL) ? *var.pparray : NULL ;
	else
		psa = var.parray ;
	if (psa == NULL)
		throw "SafeArray access error" ;
	if (psa->cDims != 2)
		throw "SafeArray, incorrect number of dimensions" ;

	long			lBound1, lBound2 ;
	long			uBound1, uBound2 ;
	HRESULT			hr ;

	// get the bounds of the matrix in the safe array
	hr = SafeArrayGetLBound(psa, 1, &lBound1);
	if (FAILED(hr))
		throw "SafeArray access error" ;
	hr = SafeArrayGetUBound(psa, 1, &uBound1);
	if (FAILED(hr))
		throw "SafeArray access error" ;

	hr = SafeArrayGetLBound(psa, 2, &lBound2);
	if (FAILED(hr))
		throw "SafeArray access error" ;
	hr = SafeArrayGetUBound(psa, 2, &uBound2);
	if (FAILED(hr))
		throw "SafeArray access error" ;

	// a CMatrix must have at least one row and one column
	if (uBound1 < lBound1 || uBound2 < lBound2)
		throw "SafeArray, empty dimension" ;

	m_NumColumns = uBound1 - lBound1 + 1 ;
	m_NumRows = uBound2 - lBound2 + 1 ;
	m_pData = AllocateMemory(m_NumColumns, m_NumRows) ;
	IncrementReferenceCount() ;			// count the reference to this memory

	double*		pSource = NULL ;						// for access to the data
	hr = SafeArrayAccessData(psa, reinterpret_cast<void**>(&pSource)) ;
	if (FAILED(hr))
		{
		// the destructor does not run for a half-constructed object, so drop
		// the buffer reference taken above before reporting the failure
		DecrementAndRelease() ;
		throw "SafeArray access error" ;
		}

	// SafeArrayAccessData yields the address of the first element whatever the
	// lower bounds are, so the offset into it is zero-based. The lower bounds must
	// NOT be subtracted from these 0-based loop indices (that reads before the data).
	for (int i = 0; i < m_NumColumns ; ++i)
		{
		for (int j = 0; j < m_NumRows ; ++j)
			SetElement(i, j, pSource[i * m_NumRows + j]) ;
		}
	pSource = NULL ;								// no longer valid
	SafeArrayUnaccessData(psa) ;					// release the safe array data pointer
}

// Destructor: releases this object's reference to the element buffer.
// The buffer itself is freed only when no other CMatrix still refers to it.
CMatrix::~CMatrix()
{
#ifdef _DEBUG
	TRACE("Destroying CMatrix object %1d\n", m_ObjectNumber) ;
#endif
	DecrementAndRelease() ;
}

#ifdef _DEBUG
// Debug-only diagnostic dump: object number, dimensions, buffer address,
// reference count and every element, printed one row per TRACE line.
// Output goes to TRACE; the supplied CDumpContext is not used.
void CMatrix::Dump(CDumpContext& dc) const
{
	UNUSED_PARAMETER(dc) ;
	TRACE("CMatrix object    #%1d\n", m_ObjectNumber) ;
	TRACE("Num columns     : %1d\n", m_NumColumns) ;
	TRACE("Num rows        : %1d\n", m_NumRows) ;
	// %p prints the whole pointer; %lx only holds 32 bits because long is 32-bit on Win64
	TRACE("Data pointer    : %p\n", static_cast<const void*>(m_pData)) ;
	// NOTE(review): %1d assumes GetReferenceCount() returns an int; confirm in Matrix.h
	TRACE("Reference count : %1d\n", GetReferenceCount()) ;
	for (int i = 0 ; i < m_NumRows ; ++i)
		{
		TRACE("Row %2d,", i) ;
		for (int j = 0 ; j < m_NumColumns ; ++j)
			{
			TRACE("%e,", GetElement(j, i)) ;	// first index is the column here
			}
		TRACE("\n") ;
		Sleep(m_NumColumns *2) ;		// this is to allow all element data to be traced for very large matrices!
		}
}

// Debug-only consistency check: positive dimensions and a readable buffer of
// m_NumColumns * m_NumRows + 1 doubles (the size AllocateMemory allocates).
void CMatrix::AssertValid() const
{
	CObject::AssertValid() ;							// MFC convention: validate the base class first
	ASSERT(m_NumColumns > 0) ;							// matrix size error
	ASSERT(m_NumRows > 0) ;							// matrix size error
	ASSERT(m_pData) ;								// bad pointer error
	ASSERT(FALSE == IsBadReadPtr(m_pData, sizeof(double) * (m_NumColumns * m_NumRows + 1))) ;
}
#endif

// Stores the matrix to, or loads it from, an MFC archive.
//
// Stream layout written after CObject's data:
//   long 0x434d6174, long 0x72697843  - header ("CMat" "rixC" as hex bytes)
//   int  version                       - format version, currently 1
//   column count, row count
//   the elements as doubles, column by column (every row of column 0, then column 1, ...)
//
// When loading, throws a CArchiveException (via AfxThrowArchiveException) if
// the header is wrong (badClass), the version is unknown (badSchema) or the
// stored dimensions are not positive (badIndex). On success the loaded data
// replaces this object's contents.
void CMatrix::Serialize(CArchive& archive)
{
	CObject::Serialize(archive) ;
	if (archive.IsStoring())
		{
		// writing the matrix
		// write the object header first so we can correctly recognise the object type "CMatrixC"
		long	header1 = 0x434d6174 ;
		long	header2 = 0x72697843 ;
		int		version = 1 ;						// serialization version format number

		archive << header1 ;
		archive << header2 ;
		archive << version ;						// version number of object type

		// now write out the actual matrix
		archive << m_NumColumns ;
		archive << m_NumRows ;
		// this could be done with a archive.Write(m_pData, sizeof(double) * m_NumColumns * m_NumRows)
		// for efficiency (dont forget its a flat array). Not done here for clarity
		for (int i = 0 ; i < m_NumColumns ; ++i)
			{
			for (int j = 0 ; j < m_NumRows ; ++j)
				{
				archive << GetElement(i, j) ;
				}
			}
		// done!
		}
	else
		{
		// reading the matrix
		// read the object header first so we can correctly recognise the object type "CMatrixC"
		long	header1 = 0 ;
		long	header2 = 0 ;
		int		version = 0 ;				// serialization version format

		archive >> header1 ;
		archive >> header2 ;
		if (header1 != 0x434d6174 || header2 != 0x72697843)
			{
			// incorrect header, cannot load it
			AfxThrowArchiveException(CArchiveException::badClass, NULL) ;
			}
		archive >> version ;				// version number of object type
		if (version != 1)
			{
			// only one file format exists so far. An ASSERT alone would vanish in
			// release builds and the data would then be misread, so reject it here.
			AfxThrowArchiveException(CArchiveException::badSchema, NULL) ;
			}

		// now read in the actual matrix
		int		nCols = 0 ;
		int		nRows = 0 ;
		double	value ;
		archive >> nCols ;
		archive >> nRows ;
		if (nCols <= 0 || nRows <= 0)
			{
			// corrupt size fields: refuse them instead of allocating from them
			AfxThrowArchiveException(CArchiveException::badIndex, NULL) ;
			}
		CMatrix loading(nCols, nRows) ;
		for (int i = 0 ; i < nCols ; ++i)
			{
			for (int j = 0 ; j < nRows ; ++j)
				{
				archive >> value ;
				loading.SetElement(i, j, value) ;
				}
			}
		*this = loading ;					// share the loaded data; our old buffer is released
		// done!
		}
}

double* CMatrix::AllocateMemory(int nCols, int nRows)
{
	ASSERT(nCols > 0) ;							// size error
	ASSERT(nRows > 0) ;							// size error
	// allocates heap memory for an array
	double	*pData = NULL ;

	pData = new double[nCols * nRows + 1] ;			// all in one allocation (+1 for reference count)
	ASSERT(pData != NULL) ;					// allocation error
	ASSERT(FALSE == IsBadReadPtr(pData, sizeof(double) * (nCols * nRows + 1))) ;
	// empty the memory
	memset(pData, 0, sizeof(double) * (nCols * nRows + 1)) ;		// starts with a 0 reference count
	return pData ;
}

// Assignment: makes this object share the other matrix's buffer and dimensions.
// The buffer this object used before is released first; it is freed only if
// no other CMatrix still refers to it. Self-assignment does nothing.
CMatrix& CMatrix::operator=(const CMatrix &other)
{
	if (this != &other)
		{
		DecrementAndRelease() ;				// let go of our current buffer
		m_NumColumns = other.m_NumColumns ;
		m_NumRows = other.m_NumRows ;
		m_pData = other.m_pData ;				// share, don't duplicate
		IncrementReferenceCount() ;				// one more CMatrix now uses this buffer
		}
	return *this ;
}

// Exact equality test.
// Two matrices are equal when they are the same object, share the same buffer,
// or have identical dimensions and byte-for-byte identical element storage.
// Because the comparison is bitwise (memcmp), 0.0 and -0.0 compare unequal.
// NOTE(review): assumes the elements occupy the first m_NumColumns * m_NumRows
// doubles of the buffer; confirm against where the reference count is kept.
bool CMatrix::operator==(const CMatrix &other) const
{
	if (this == &other || m_pData == other.m_pData)
		return true ;

	const bool	sameShape = (m_NumColumns == other.m_NumColumns) && (m_NumRows == other.m_NumRows) ;
	return sameShape && memcmp(m_pData, other.m_pData, sizeof(double) * m_NumColumns * m_NumRows) == 0 ;
}

// Element-wise sum of two matrices of identical shape.
// Throws "Invalid operation" (const char*) if the dimensions differ.
CMatrix CMatrix::operator+(const CMatrix &other) const
{
	if (m_NumColumns != other.m_NumColumns || m_NumRows != other.m_NumRows)
		throw "Invalid operation" ;
	ASSERT(FALSE == IsBadReadPtr(other.m_pData, sizeof(double) * other.m_NumColumns * other.m_NumRows)) ;

	// Start from a (buffer-sharing) copy of ourselves and add the other matrix into it.
	// NOTE(review): relies on SetElement not writing through to *this while the
	// buffer is shared (presumably copy-on-write); confirm in SetElement.
	CMatrix		sum(*this) ;
	for (int col = 0 ; col < m_NumColumns ; ++col)
		{
		for (int row = 0 ; row < m_NumRows ; ++row)
			{
			const double	total = sum.GetElement(col, row) + other.GetElement(col, row) ;
			sum.SetElement(col, row, total) ;
			}
		}
	return sum ;
}

// Element-wise difference (this - other) of two matrices of identical shape.
// Throws "Invalid operation" (const char*) if the dimensions differ.
CMatrix CMatrix::operator-(const CMatrix &other) const
{
	if (m_NumColumns != other.m_NumColumns || m_NumRows != other.m_NumRows)
		throw "Invalid operation" ;
	ASSERT(FALSE == IsBadReadPtr(other.m_pData, sizeof(double) * other.m_NumColumns * other.m_NumRows)) ;

	// Start from a (buffer-sharing) copy of ourselves and subtract the other matrix from it.
	// NOTE(review): relies on SetElement not writing through to *this while the
	// buffer is shared (presumably copy-on-write); confirm in SetElement.
	CMatrix		difference(*this) ;
	for (int col = 0 ; col < m_NumColumns ; ++col)
		{
		for (int row = 0 ; row < m_NumRows ; ++row)
			{
			const double	delta = difference.GetElement(col, row) - other.GetElement(col, row) ;
			difference.SetElement(col, row, delta) ;
			}
		}
	return difference ;
}

// Matrix product.
// Requires m_NumRows == other.m_NumColumns, otherwise throws
// "Matrices do not have common size" (const char*). The result has
// m_NumColumns x other.m_NumRows in (columns, rows) constructor order, and
// product(i, j) = sum over k of this(i, k) * other(k, j).
// NOTE(review): that is the textbook row-by-column product only if the FIRST
// Get/SetElement index is the row. Dump() and the constructor parameter names
// treat the first index as the column, and under that reading this computes
// other * this. Confirm the intended convention before relying on operand order.
CMatrix CMatrix::operator*(const CMatrix &other) const
{
	if (m_NumRows != other.m_NumColumns)
		throw "Matrices do not have common size" ;
	ASSERT(FALSE == IsBadReadPtr(other.m_pData, sizeof(double) * other.m_NumColumns * other.m_NumRows)) ;

	CMatrix		product(m_NumColumns, other.m_NumRows) ;

	// e.g. (reading the first index as the row)
	// [A][B][C]   [G][H]     [A*G + B*I + C*K][A*H + B*J + C*L]
	// [D][E][F] * [I][J] =   [D*G + E*I + F*K][D*H + E*J + F*L]
	//             [K][L]
	for (int i = 0 ; i < product.m_NumColumns ; ++i)
		{
		for (int j = 0 ; j < product.m_NumRows ; ++j)
			{
			double	dot = 0.0 ;
			for (int k = 0 ; k < m_NumRows ; ++k)
				dot += GetElement(i, k) * other.GetElement(k, j) ;
			product.SetElement(i, j, dot) ;
			}
		}
	return product ;
}

// In-place element-wise addition of a matrix of identical shape.
// Throws "Invalid operation" (const char*) if the dimensions differ.
void CMatrix::operator+=(const CMatrix &other)
{
	if (m_NumColumns != other.m_NumColumns || m_NumRows != other.m_NumRows)
		throw "Invalid operation" ;
	ASSERT(FALSE == IsBadReadPtr(other.m_pData, sizeof(double) * other.m_NumColumns * other.m_NumRows)) ;

	for (int col = 0 ; col < m_NumColumns ; ++col)
		{
		for (int row = 0 ; row < m_NumRows ; ++row)
			{
			const double	mine = GetElement(col, row) ;
			SetElement(col, row, mine + other.GetElement(col, row)) ;
			}
		}
}

// In-place element-wise subtraction of a matrix of identical shape.
// Throws "Invalid operation" (const char*) if the dimensions differ.
void CMatrix::operator-=(const CMatrix &other)
{
	if (m_NumColumns != other.m_NumColumns || m_NumRows != other.m_NumRows)
		throw "Invalid operation" ;
	ASSERT(FALSE == IsBadReadPtr(other.m_pData, sizeof(double) * other.m_NumColumns * other.m_NumRows)) ;

	for (int col = 0 ; col < m_NumColumns ; ++col)
		{
		for (int row = 0 ; row < m_NumRows ; ++row)
			{
			const double	mine = GetElement(col, row) ;
			SetElement(col, row, mine - other.GetElement(col, row)) ;
			}
		}
}

// In-place matrix product: replaces this matrix with (*this * other).
// Throws "Matrices do not have common size" (const char*) when
// m_NumRows != other.m_NumColumns, before any product is computed.
void CMatrix::operator*=(const CMatrix &other)
{
	if (m_NumRows == other.m_NumColumns)
		*this = *this * other ;
	else
		throw "Matrices do not have common size" ;
}

// In-place scaling: multiplies every element by the given factor.
void CMatrix::operator*=(double value)
{
	for (int col = 0 ; col < m_NumColumns ; ++col)
		{
		for (int row = 0 ; row < m_NumRows ; ++row)
			{
			const double	scaled = GetElement(col, row) * value ;
			SetElement(col, row, scaled) ;
			}
		}
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -