⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 matrix.cpp

📁 最新收集得老外写的一个矩阵类
💻 CPP
📖 第 1 页 / 共 3 页
字号:

// Read-only first-level subscript. CMatrixHelper exists only so that a second []
// can be applied to the returned object, simulating CMatrix::operator[][].
CMatrixHelper CMatrix::operator[](int nCol) const
{
	// array bounds checks
	ASSERT(nCol >= 0) ;
	ASSERT(nCol < m_NumColumns) ;
	// hand back a helper bound to this matrix and the requested column
	return CMatrixHelper(this, nCol) ;
}

// Non-const first-level subscript: returns a CMatrixHelper built from this matrix
// and nCol so that a second [] can be applied to it.
CMatrixHelper CMatrix::operator[](int nCol)
{
	// array bounds checks
	ASSERT(nCol >= 0) ;
	ASSERT(nCol < m_NumColumns) ;
	// hand back a helper bound to this matrix and the requested column
	return CMatrixHelper(this, nCol) ;
}

// Stores 'value' at column nCol, row nRow.
// If the data buffer's reference count is greater than one, a fresh buffer is
// allocated and the current contents are copied into it before the write, so the
// other holders of the old buffer are not affected.
// Always returns true.
bool CMatrix::SetElement(int nCol, int nRow, double value)
{
	// validate everything before m_pData is dereferenced (GetReferenceCount reads through it)
	ASSERT(m_pData) ;							// bad pointer error
	ASSERT(nCol >= 0) ;							// array bounds error
	ASSERT(nCol < m_NumColumns) ;				// array bounds error
	ASSERT(nRow >= 0) ;							// array bounds error
	ASSERT(nRow < m_NumRows) ;					// array bounds error

	// check the reference count on our data object to see whether we need our own copy
	if (GetReferenceCount() > 1)
		{
		double	*pShared = m_pData ;			// remember the buffer we are leaving
		DecrementReferenceCount() ;				// drop our reference to it
		m_pData = AllocateMemory(m_NumColumns, m_NumRows) ;
		memcpy(m_pData, pShared, sizeof(double) * m_NumColumns * m_NumRows) ;
		IncrementReferenceCount() ;				// take a reference on the new buffer
		}
	// elements are stored row after row, m_NumColumns values per row
	m_pData[nCol + nRow * m_NumColumns] = value ;
	return true ;
}

#ifdef _DEBUG
// release version is in-line
// Returns the element at column nCol, row nRow, after checking the indices and
// the data pointer with ASSERT.
double CMatrix::GetElement(int nCol, int nRow) const
{
	// array bounds checks
	ASSERT(nCol >= 0) ;
	ASSERT(nCol < m_NumColumns) ;
	ASSERT(nRow >= 0) ;
	ASSERT(nRow < m_NumRows) ;
	// bad pointer check
	ASSERT(m_pData) ;
	// elements are stored row after row, m_NumColumns values per row
	const int	nOffset = nCol + nRow * m_NumColumns ;
	return m_pData[nOffset] ;
}
#endif

// 
// To avoid big hits when constructing and assigning CMatrix objects, multiple CMatrix objects can reference
// the same m_pData member. Only when a matrix becomes different from the others is a new copy of the array
// created and worked with.
//
// Adds one to the reference count, which is read as an int located in the m_pData
// buffer directly after the last matrix element.
void CMatrix::IncrementReferenceCount()
{
	int	*pCount = reinterpret_cast<int*>(m_pData + m_NumColumns * m_NumRows) ;
	*pCount += 1 ;
}

// Subtracts one from the reference count, which is read as an int located in the
// m_pData buffer directly after the last matrix element. The buffer is not released here.
void CMatrix::DecrementReferenceCount()
{
	int	*pCount = reinterpret_cast<int*>(m_pData + m_NumColumns * m_NumRows) ;
	*pCount -= 1 ;
}

// Subtracts one from the reference count stored after the last matrix element and,
// when the count reaches zero, deletes the buffer and sets m_pData to NULL.
void CMatrix::DecrementAndRelease()
{
	int	*pCount = reinterpret_cast<int*>(m_pData + m_NumColumns * m_NumRows) ;
	*pCount -= 1 ;
	if (*pCount != 0)
		return ;						// count has not reached zero, keep the buffer
	// the memory is no longer needed, release it
	delete [] m_pData ;
	m_pData = NULL ;
}

// Returns the reference count, read as an int located in the m_pData buffer
// directly after the last matrix element.
int CMatrix::GetReferenceCount() const
{
	const int	*pCount = reinterpret_cast<const int*>(m_pData + m_NumColumns * m_NumRows) ;
	return *pCount ;
}

// Returns a transposed copy of this matrix; this matrix itself is not modified.
CMatrix CMatrix::GetTransposed() const
{
	CMatrix	result(*this) ;			// start from a copy of ourselves
	result.Transpose() ;			// transpose the copy in place
	return result ;
}

void CMatrix::Transpose()
{
	// first check the reference count on our data object to see whether we need to create a copy
	CMatrix	mcopy(*this) ;
	// swap the x/y values
	int	copy = m_NumColumns ;
	m_NumColumns = m_NumRows ;
	m_NumRows = copy ;
	// copy across the transposed data
	for (int i = 0 ; i < m_NumColumns ; ++i)
		{
		for (int j = 0 ; j < m_NumRows ; ++j)
			SetElement(i, j, mcopy.GetElement(j, i)) ;
		}
}

// Returns the inverse of this matrix as a new object; this matrix is not modified.
// Throws a const char* message if the matrix is not square.
CMatrix CMatrix::GetInverted() const
{
	// matrix inversion will only work on square matrices
	const bool	bSquare = (m_NumColumns == m_NumRows) ;
	if (!bSquare)
		throw "Matrix must be square." ;
	// invert a copy and hand it back
	CMatrix	inverse(*this) ;
	inverse.Invert() ;
	return inverse ;
}

// Inverts this matrix in place using Gauss-Jordan elimination without pivoting.
// Throws a const char* message if the matrix is not square, or if a zero is
// encountered on the diagonal during elimination (no row exchange is attempted,
// so this can happen even for an invertible matrix).
// If the second exception is thrown the matrix is left partially processed.
void CMatrix::Invert()
{
	// matrix inversion will only work on square matrices
	if (m_NumColumns != m_NumRows)
		throw "Matrix must be square." ;

	const int	n = m_NumColumns ;

	for (int k = 0 ; k < n ; ++k)
		{
		// check the pivot before touching the matrix for this step
		const double	pivot = GetElement(k, k) ;
		if (pivot == 0.0)
			throw "Matrix inversion error" ;
		// replace the pivot by 1 and scale the whole pivot line by 1/pivot
		SetElement(k, k, 1.0) ;
		for (int j = 0 ; j < n ; ++j)
			SetElement(k, j, GetElement(k, j) / pivot) ;
		// eliminate the pivot position from every other line
		for (int i = 0 ; i < n ; ++i)
			{
			if (i == k)
				continue ;
			const double	factor = GetElement(i, k) ;
			// clear the element being eliminated; it is then refilled by the update below
			SetElement(i, k, 0.0) ;
			for (int j = 0 ; j < n ; ++j)
				SetElement(i, j, GetElement(i, j) - factor * GetElement(k, j)) ;
			}
		}
}

// A' * A
// NOTE(review): the expression below is written as (*this) * transposed; whether that
// matches the A' * A stated above depends on CMatrix::operator*, which is defined
// elsewhere - confirm.
CMatrix CMatrix::Covariant() const
{
	CMatrix	transposed(GetTransposed()) ;

	return *this * transposed ;
}

// Returns a new col_size x row_size matrix holding the elements of this matrix
// starting at (col_start, row_start).
// Throws a const char* message if the requested region extends beyond this matrix.
// A region that ends exactly on the last column and/or last row is valid.
CMatrix CMatrix::ExtractSubMatrix(int col_start, int row_start, int col_size, int row_size) const
{
	ASSERT(col_start >= 0) ;						// bad start index
	ASSERT(row_start >= 0) ;						// bad start index
	ASSERT(col_size > 0) ;						// bad size
	ASSERT(row_size > 0) ;						// bad size
	// make sure the requested sub matrix is in the current matrix
	// (start + size is one past the last index used, so equality with the dimension is fine)
	if (col_start + col_size > m_NumColumns)
		throw "Sub matrix is not contained in source" ;
	if (row_start + row_size > m_NumRows)
		throw "Sub matrix is not contained in source" ;

	CMatrix sub(col_size, row_size) ;

	for (int i = 0 ; i < col_size ; ++i)
		{
		for (int j = 0 ; j < row_size ; ++j)
			{
			sub.SetElement(i, j, GetElement(col_start + i, row_start + j)) ;
			}
		}
	return sub ;
}

// Copies every element of 'other' into this matrix, placing other's (0, 0) at
// (col_start, row_start). The ASSERTs check that 'other' fits inside this matrix.
void CMatrix::SetSubMatrix(int col_start, int row_start, const CMatrix &other)
{
	// bad start index
	ASSERT(col_start >= 0) ;
	ASSERT(row_start >= 0) ;
	// bad size
	ASSERT(col_start + other.m_NumColumns <= m_NumColumns) ;
	ASSERT(row_start + other.m_NumRows <= m_NumRows) ;
	for (int col = 0 ; col < other.m_NumColumns ; ++col)
		{
		const int	destCol = col_start + col ;
		for (int row = 0 ; row < other.m_NumRows ; ++row)
			SetElement(destCol, row_start + row, other.GetElement(col, row)) ;
		}
}

// Returns a matrix with m_NumColumns columns and one row whose element (i, 0) is
// this matrix's element (i, i).
// Throws a const char* message if this matrix is not square.
CMatrix CMatrix::ExtractDiagonal() const
{
	const bool	bSquare = (m_NumColumns == m_NumRows) ;
	if (!bSquare)
		throw "Can only extract diagonal from square matrix" ;

	CMatrix	result(m_NumColumns, 1) ;
	for (int n = 0 ; n < m_NumColumns ; ++n)
		{
		const double	onDiagonal = GetElement(n, n) ;
		result.SetElement(n, 0, onDiagonal) ;
		}
	return result ;
}


// Returns a new matrix made of this matrix's columns followed by other's columns.
// This matrix is not modified.
// Throws a const char* message if the two matrices have different row counts.
CMatrix CMatrix::GetConcatinatedColumns(const CMatrix& other) const
{
	const bool	bSameRows = (m_NumRows == other.m_NumRows) ;
	if (!bSameRows)
		throw "Cannot concatenate matrices, not same size" ;
	// work on a copy of ourselves and return it
	CMatrix		joined(*this) ;
	joined.ConcatinateColumns(other) ;
	return joined ;
}

// Concatenates the other matrix to ourselves: other's columns are appended after
// this matrix's existing columns, and the result replaces this matrix.
// Throws a const char* message if the two matrices have different row counts.
void CMatrix::ConcatinateColumns(const CMatrix &other)
{
	if (m_NumRows != other.m_NumRows)
		throw "Cannot concatenate matrices, not same size" ;
	// create a matrix big enough to hold both
	CMatrix		result(m_NumColumns + other.m_NumColumns, m_NumRows) ;

	// our own elements keep their positions
	result.SetSubMatrix(0, 0, *this) ;
	// the other matrix starts in the first column after ours
	result.SetSubMatrix(m_NumColumns, 0, other) ;
	*this = result ;					// assign it to us
}

// Returns a new matrix made of this matrix's rows followed by other's rows.
// This matrix is not modified.
// Throws a const char* message if the two matrices have different column counts.
CMatrix CMatrix::GetConcatinatedRows(const CMatrix& other) const
{
	const bool	bSameColumns = (m_NumColumns == other.m_NumColumns) ;
	if (!bSameColumns)
		throw "Cannot concatenate matrices, not same size" ;
	// work on a copy of ourselves and return it
	CMatrix		joined(*this) ;
	joined.ConcatinateRows(other) ;
	return joined ;
}

// Concatenates the other matrix to ourselves: other's rows are appended after
// this matrix's existing rows, and the result replaces this matrix.
// The two matrices may have different row counts; all of other's rows are copied.
// Throws a const char* message if the two matrices have different column counts.
void CMatrix::ConcatinateRows(const CMatrix &other)
{
	if (m_NumColumns != other.m_NumColumns)
		throw "Cannot concatenate matrices, not same size" ;
	// create a matrix big enough to hold both
	CMatrix		result(m_NumColumns, m_NumRows + other.m_NumRows) ;

	// our own elements keep their positions
	result.SetSubMatrix(0, 0, *this) ;
	// the other matrix starts in the first row after ours
	result.SetSubMatrix(0, m_NumRows, other) ;
	*this = result ;					// assign it to us
}

void CMatrix::AddColumn(const double *pData)
{
	ASSERT(FALSE == IsBadReadPtr(pData, sizeof(double) * m_NumRows)) ;

	CMatrix	result(m_NumColumns + 1, m_NumRows) ;			// costruct the result

	result.SetSubMatrix(0, 0, *this) ;				// copy ouselves across
	// now add the new row
	for (int i = 0 ; i < m_NumRows ; ++i)
		{
		result.SetElement(m_NumColumns, i, pData[i]) ;
		}
	*this = result ;								// assign result to us
}

void CMatrix::AddRow(const double *pData)
{
	ASSERT(FALSE == IsBadReadPtr(pData, sizeof(double) * m_NumColumns)) ;

	CMatrix	result(m_NumColumns, m_NumRows + 1) ;			// costruct the result

	result.SetSubMatrix(0, 0, *this) ;				// copy ouselves across
	// now add the new row
	for (int i = 0 ; i < m_NumColumns ; ++i)
		{
		result.SetElement(i, m_NumRows, pData[i]) ;
		}
	*this = result ;								// assign result to us
}

// Returns a copy of 'other' with every element multiplied by 'value'.
CMatrix	operator*(const CMatrix &other, double value)
{
	CMatrix scaled(other) ;

	for (int col = 0 ; col < scaled.m_NumColumns ; ++col)
		{
		for (int row = 0 ; row < scaled.m_NumRows ; ++row)
			{
			const double	element = scaled.GetElement(col, row) ;
			scaled.SetElement(col, row, element * value) ;
			}
		}
	return scaled ;
}

// Returns a squared-off copy of this matrix (see MakeSquare); this matrix is not modified.
CMatrix CMatrix::GetSquareMatrix() const
{
	CMatrix	squared(*this) ;		// start from a copy of ourselves
	squared.MakeSquare() ;
	return squared ;
}

// Makes this matrix square by sampling it down to size x size, where size is the
// smaller of the column and row counts. Along the longer side, elements are picked
// at evenly spaced (fractional) steps and truncated to an index; along the shorter
// side the step is exactly 1, so every index is used.
void CMatrix::MakeSquare()
{
	// square to the smallest side
	const int	size = (m_NumColumns > m_NumRows) ? m_NumRows : m_NumColumns ;

	CMatrix	work(size, size) ;				// construct result
	// floating-point division: an integer division here would drop the fractional
	// part of the step and bunch the samples towards the start of the longer side
	const double	x_step = static_cast<double>(m_NumColumns) / size ;
	const double	y_step = static_cast<double>(m_NumRows) / size ;

	for (int i = 0 ; i < size ; ++i)
		{
		for (int j = 0 ; j < size ; ++j)
			work.SetElement(i, j, GetElement((int)(i * x_step), (int)(j * y_step))) ;
		}
	*this = work ;				// copy the result to ourselves
}

// Returns a copy of this matrix normalised to [min, max] (see Normalise);
// this matrix is not modified.
CMatrix CMatrix::GetNormalised(double min, double max) const
{
	CMatrix	normalised(*this) ;
	normalised.Normalise(min, max) ;
	return normalised ;
}

// Linearly rescales every element so that the smallest value in the matrix maps
// to 'min' and the largest maps to 'max'.
// If all elements are equal there is no range to scale by; every element is then
// set to 'min' instead of dividing by zero.
void CMatrix::Normalise(double min, double max)
{
	// get the lower and upper limit values in the matrix
	// we use the range to normalise
	double	e_min ;
	double	e_max ;

	GetNumericRange(e_min, e_max) ;

	const double	range = e_max - e_min ;
	const double	r_range = max - min ;			// required range
	const bool		bFlat = (range == 0.0) ;		// all elements identical
	for (int i = 0 ; i < m_NumColumns ; ++i)
		{
		for (int j = 0 ; j < m_NumRows ; ++j)
			{
			double	value = 0.0 ;
			if (!bFlat)
				{
				value = GetElement(i, j) ;
				value -= e_min ;			// 0 - range
				value /= range ;			// 0 - 1
				value *= r_range ;			// 0 - required range
				}
			value += min ;
			SetElement(i, j, value) ;
			}
		}
}

// gets the lowest and highest values in the matrix
void CMatrix::GetNumericRange(double &min, double &max) const
{
	double	e_min = GetElement(0, 0) ;
	double	e_max = e_min ;
	double	value ;

	for (int i = 0 ; i < m_NumColumns ; ++i)
		{
		for (int j = 0 ; j < m_NumRows ; ++j)
			{
			value = GetElement(i, j) ;
			if (value < e_min)
				e_min = value ;
			else if (value > e_max)
				e_max = value ;
			}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -