⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 matrix.cpp

📁 一个矩阵类
💻 CPP
📖 第 1 页 / 共 2 页
字号:
#include "stdafx.h"
#include "Matrix.h"
#include <limits.h>

#ifdef _DEBUG
#undef THIS_FILE
static char THIS_FILE[]=__FILE__;
#define new DEBUG_NEW
#endif

// Default constructor: creates a 1x1 matrix.
// Init() allocates the storage and zero-fills it.
CMatrix::CMatrix()
	: m_nNumRows(1), m_nNumColumns(1), m_pData(NULL)
{
	// m_pData has to be NULL before Init() runs, because Init() releases
	// whatever buffer it finds there.
	const BOOL bAllocated = Init(1, 1);
	ASSERT(bAllocated);
}

//--------------------指定行列构造函数--------------------
// Constructs an nRows x nCols matrix; Init() allocates the storage and
// zero-fills it.
//   nRows - number of rows
//   nCols - number of columns
CMatrix::CMatrix(int nRows, int nCols)
	: m_nNumRows(nRows), m_nNumColumns(nCols), m_pData(NULL)
{
	// m_pData has to be NULL before Init() runs, because Init() releases
	// whatever buffer it finds there.
	const BOOL bAllocated = Init(nRows, nCols);
	ASSERT(bAllocated);
}

//--------------------指定值构造函数--------------------
// Constructs an nRows x nCols matrix and copies its elements from value.
//   nRows - number of rows
//   nCols - number of columns
//   value - source array; nRows*nCols doubles are copied from it in the
//           same flat layout the class uses internally
//           (element (r, c) lives at index c + r * nCols)
// If allocation fails or value is NULL, the copy is skipped instead of
// writing through / reading from a NULL pointer.
CMatrix::CMatrix(int nRows, int nCols, double value[])
{
	m_nNumRows = nRows;
	m_nNumColumns = nCols;
	m_pData = NULL;			// Init() frees any non-NULL buffer, so clear it first

	const BOOL bSuccess = Init(nRows, nCols);
	ASSERT(bSuccess);
	ASSERT(value != NULL);

	// ASSERT vanishes in release builds, so the copy must be guarded explicitly.
	if (bSuccess && m_pData != NULL && value != NULL)
		SetData(value);
}

//--------------------方阵构造函数--------------------
// Constructs a square nSize x nSize matrix; Init() allocates the storage
// and zero-fills it.
//   nSize - number of rows and of columns
CMatrix::CMatrix(int nSize)
	: m_nNumRows(nSize), m_nNumColumns(nSize), m_pData(NULL)
{
	// m_pData has to be NULL before Init() runs, because Init() releases
	// whatever buffer it finds there.
	const BOOL bAllocated = Init(nSize, nSize);
	ASSERT(bAllocated);
}

//--------------------方阵构造函数--------------------
// Constructs a square nSize x nSize matrix and copies its elements from value.
//   nSize - number of rows and of columns
//   value - source array; nSize*nSize doubles are copied from it in the
//           same flat layout the class uses internally
//           (element (r, c) lives at index c + r * nSize)
// If allocation fails or value is NULL, the copy is skipped instead of
// writing through / reading from a NULL pointer.
CMatrix::CMatrix(int nSize, double value[])
{
	m_nNumRows = nSize;
	m_nNumColumns = nSize;
	m_pData = NULL;			// Init() frees any non-NULL buffer, so clear it first

	const BOOL bSuccess = Init(nSize, nSize);
	ASSERT(bSuccess);
	ASSERT(value != NULL);

	// ASSERT vanishes in release builds, so the copy must be guarded explicitly.
	if (bSuccess && m_pData != NULL && value != NULL)
		SetData(value);
}

//--------------------拷贝构造函数--------------------
// Copy constructor: allocates a buffer with the dimensions of other and
// duplicates its elements.
//   other - matrix to copy
// If allocation fails, or other holds no buffer, the element copy is skipped
// rather than passing a NULL pointer to memcpy.
CMatrix::CMatrix(const CMatrix& other)
{
	m_nNumColumns = other.GetNumColumns();
	m_nNumRows = other.GetNumRows();
	m_pData = NULL;			// Init() frees any non-NULL buffer, so clear it first

	const BOOL bSuccess = Init(other.GetNumRows(), other.GetNumColumns());
	ASSERT(bSuccess);

	// ASSERT vanishes in release builds, so the copy must be guarded explicitly.
	if (bSuccess && m_pData != NULL && other.m_pData != NULL)
		memcpy(m_pData, other.m_pData, sizeof(double)*m_nNumColumns*m_nNumRows);
}

//--------------------析构函数--------------------
// Destructor: releases the element buffer.
CMatrix::~CMatrix()
{
	// delete[] on a NULL pointer is a no-op, so no explicit test is needed.
	delete[] m_pData;
	m_pData = NULL;
}

//--------------------初始化函数--------------------
// (Re)initializes the matrix to nRows x nCols with every element set to 0.
// Any buffer currently held is released first.
//   nRows - requested number of rows    (must be >= 0)
//   nCols - requested number of columns (must be >= 0)
// Returns TRUE on success. Returns FALSE when a dimension is negative, when
// nRows*nCols (or its size in bytes) is not representable, or when the
// allocation yields NULL; in those cases the matrix is left empty
// (0 x 0, NULL buffer) so it never advertises dimensions it has no storage for.
// NOTE(review): depending on the runtime, operator new may throw instead of
// returning NULL; the exception then propagates with the matrix left empty.
BOOL CMatrix::Init(int nRows, int nCols)
{
	// Drop the old contents and put the object into a consistent empty state
	// before anything below can fail.
	delete[] m_pData;
	m_pData = NULL;
	m_nNumRows = 0;
	m_nNumColumns = 0;

	// Check each dimension on its own: two negative values would otherwise
	// multiply to a positive element count.
	if (nRows < 0 || nCols < 0)
		return FALSE;

	// Reject element counts that overflow int (all indexing in this class is
	// done in int) ...
	if (nRows != 0 && nCols > INT_MAX / nRows)
		return FALSE;
	const int nSize = nRows * nCols;

	// ... and byte counts that overflow size_t.
	if (static_cast<size_t>(nSize) > static_cast<size_t>(-1) / sizeof(double))
		return FALSE;

	// Allocate; the pointer returned by new is used directly, without an
	// IsBadReadPtr probe (that API is documented as obsolete).
	double* pBuffer = new double[nSize];
	if (pBuffer == NULL)
		return FALSE;					// allocation failed

	// Zero every element.
	memset(pBuffer, 0, sizeof(double) * nSize);

	// Commit the new state only once everything has succeeded.
	m_pData = pBuffer;
	m_nNumRows = nRows;
	m_nNumColumns = nCols;

	return TRUE;
}

//--------------------设置矩阵各元素的值--------------------
void CMatrix::SetData(double value[])
{
	memset(m_pData, 0, sizeof(double) * m_nNumColumns*m_nNumRows);
	memcpy(m_pData, value, sizeof(double)*m_nNumColumns*m_nNumRows);
}

//--------------------设置指定元素的值--------------------
// Stores value at position (nRow, nCol).
//   nRow  - zero-based row index
//   nCol  - zero-based column index
//   value - value to store
// Returns TRUE when the element was written, FALSE when an index is out of
// range or the matrix has no buffer.
BOOL CMatrix::SetElement(int nRow, int nCol, double value)
{
	const bool bRowInRange = (nRow >= 0) && (nRow < m_nNumRows);
	const bool bColInRange = (nCol >= 0) && (nCol < m_nNumColumns);

	if (!bRowInRange || !bColInRange || m_pData == NULL)
		return FALSE;

	// Elements are stored row after row in one flat array.
	const int nOffset = nCol + nRow * m_nNumColumns;
	m_pData[nOffset] = value;
	return TRUE;
}

//--------------------Get the value of the specified element--------------------
// Returns the element at position (nRow, nCol).
//   nRow - zero-based row index
//   nCol - zero-based column index
// Index range and buffer validity are checked with ASSERT only, so they are
// not verified in builds where ASSERT is compiled out.
double CMatrix::GetElement(int nRow, int nCol) const
{
	ASSERT(nRow >= 0 && nRow < m_nNumRows && nCol >= 0 && nCol < m_nNumColumns);
	ASSERT(m_pData);     // the matrix must own a buffer

	// Elements are stored row after row in one flat array.
	const int nOffset = nCol + nRow * m_nNumColumns;
	return m_pData[nOffset];
}

//--------------------获取矩阵的列数--------------------
int	CMatrix::GetNumColumns() const
{
	return m_nNumColumns;
}

//--------------------获取矩阵的行数--------------------
int	CMatrix::GetNumRows() const
{
	return m_nNumRows;
}

//--------------------获取矩阵的数据--------------------
double* CMatrix::GetData() const
{
	return m_pData;
}

//--------------------重载运算符=,给矩阵赋值--------------------
// Assignment: resizes this matrix to the dimensions of other and copies its
// elements. Self-assignment is a no-op.
//   other - matrix to copy from
// Returns a reference to this matrix.
// If the reallocation fails, or other holds no buffer, the element copy is
// skipped rather than passing a NULL pointer to memcpy.
CMatrix& CMatrix::operator=(const CMatrix& other)
{
	if (&other == this)
		return *this;

	const BOOL bSuccess = Init(other.GetNumRows(), other.GetNumColumns());
	ASSERT(bSuccess);

	// ASSERT vanishes in release builds, so the copy must be guarded explicitly.
	if (bSuccess && m_pData != NULL && other.m_pData != NULL)
		memcpy(m_pData, other.m_pData, sizeof(double)*m_nNumColumns*m_nNumRows);

	return *this;
}

//--------------------重载运算符==,判断矩阵是否相等--------------------
// Equality: TRUE when both matrices have the same dimensions and every pair
// of corresponding elements compares equal (exact floating-point comparison).
//   other - matrix to compare against
BOOL CMatrix::operator==(const CMatrix& other) const
{
	// Matrices of different shape are never equal.
	const bool bSameShape = (m_nNumColumns == other.GetNumColumns())
		&& (m_nNumRows == other.GetNumRows());
	if (!bSameShape)
		return FALSE;

	// Walk the elements row by row and stop at the first mismatch.
	int nRow = 0;
	while (nRow < m_nNumRows)
	{
		int nCol = 0;
		while (nCol < m_nNumColumns)
		{
			if (!(GetElement(nRow, nCol) == other.GetElement(nRow, nCol)))
				return FALSE;
			++nCol;
		}
		++nRow;
	}

	return TRUE;
}

//--------------------重载运算符!=,判断矩阵是否不相等--------------------
// Inequality: the logical negation of operator==.
//   other - matrix to compare against
BOOL CMatrix::operator!=(const CMatrix& other) const
{
	return (*this == other) ? FALSE : TRUE;
}

//--------------------重载运算符+,实现矩阵的加法--------------------
// Element-wise sum of this matrix and other, returned as a new matrix.
//   other - right-hand operand; expected to have the same dimensions
//           (checked with ASSERT only)
CMatrix CMatrix::operator+(const CMatrix& other) const
{
	// Both operands must have the same shape.
	ASSERT(m_nNumRows == other.GetNumRows() && m_nNumColumns == other.GetNumColumns());

	// Start from a copy of the left operand and add the right one into it.
	CMatrix sum(*this);
	for (int nRow = 0; nRow < m_nNumRows; ++nRow)
	{
		for (int nCol = 0; nCol < m_nNumColumns; ++nCol)
		{
			const double dCell = sum.GetElement(nRow, nCol) + other.GetElement(nRow, nCol);
			sum.SetElement(nRow, nCol, dCell);
		}
	}

	return sum;
}

//--------------------重载运算符-,实现矩阵的减法--------------------
// Element-wise difference (this minus other), returned as a new matrix.
//   other - right-hand operand; expected to have the same dimensions
//           (checked with ASSERT only)
CMatrix CMatrix::operator-(const CMatrix& other) const
{
	// Both operands must have the same shape.
	ASSERT(m_nNumRows == other.GetNumRows() && m_nNumColumns == other.GetNumColumns());

	// Start from a copy of the left operand and subtract the right one from it.
	CMatrix difference(*this);
	for (int nRow = 0; nRow < m_nNumRows; ++nRow)
	{
		for (int nCol = 0; nCol < m_nNumColumns; ++nCol)
		{
			const double dCell = difference.GetElement(nRow, nCol) - other.GetElement(nRow, nCol);
			difference.SetElement(nRow, nCol, dCell);
		}
	}

	return difference;
}

//--------------------重载运算符*,实现矩阵的数乘--------------------
// Scalar multiplication: returns a new matrix in which every element of this
// matrix is multiplied by value.
//   value - scalar factor
CMatrix CMatrix::operator*(double value) const
{
	// Start from a copy and scale it in place.
	CMatrix scaled(*this);
	for (int nRow = 0; nRow < m_nNumRows; ++nRow)
	{
		for (int nCol = 0; nCol < m_nNumColumns; ++nCol)
		{
			const double dCell = scaled.GetElement(nRow, nCol) * value;
			scaled.SetElement(nRow, nCol, dCell);
		}
	}

	return scaled;
}

//--------------------重载运算符*,实现矩阵的乘法--------------------
// Matrix product (this * other), returned as a new matrix with
// m_nNumRows rows and other.GetNumColumns() columns.
//   other - right-hand operand; its row count is expected to equal this
//           matrix's column count (checked with ASSERT only)
CMatrix CMatrix::operator*(const CMatrix& other) const
{
	// Inner dimensions must agree.
	ASSERT(m_nNumColumns == other.GetNumRows());

	CMatrix product(m_nNumRows, other.GetNumColumns());

	for (int nRow = 0; nRow < product.GetNumRows(); ++nRow)
	{
		for (int nCol = 0; nCol < other.GetNumColumns(); ++nCol)
		{
			// Dot product of row nRow of this matrix with column nCol of other.
			double dDot = 0.0;
			for (int nInner = 0; nInner < m_nNumColumns; ++nInner)
				dDot += GetElement(nRow, nInner) * other.GetElement(nInner, nCol);

			product.SetElement(nRow, nCol, dDot);
		}
	}

	return product;
}

//--------------------矩阵的转置--------------------
// Returns the transpose of this matrix as a new matrix with
// m_nNumColumns rows and m_nNumRows columns.
CMatrix CMatrix::Transpose() const
{
	// Destination has the row and column counts swapped.
	CMatrix flipped(m_nNumColumns, m_nNumRows);

	// Element (r, c) of this matrix becomes element (c, r) of the result.
	for (int nRow = 0; nRow < m_nNumRows; ++nRow)
	{
		for (int nCol = 0; nCol < m_nNumColumns; ++nCol)
		{
			const double dCell = GetElement(nRow, nCol);
			flipped.SetElement(nCol, nRow, dCell);
		}
	}

	return flipped;
}


//-----------------------------奇异值分解求广义逆矩阵的实现---------------------------------

// Computes a generalized inverse of this matrix by way of its singular value
// decomposition.
//   mtxAP - output; re-initialized to (columns x rows) of this matrix and
//           filled with the result
//   mtxU  - output of SplitUV (re-initialized there)
//   mtxV  - output of SplitUV (re-initialized there)
//   eps   - precision argument, forwarded unchanged to SplitUV
// Returns FALSE when SplitUV fails or mtxAP cannot be initialized, TRUE
// otherwise.
// This matrix's own data is modified by SplitUV; the code below then reads
// its diagonal.
// NOTE(review): presumably the diagonal holds the singular values after
// SplitUV returns -- confirm against SplitUV.
BOOL CMatrix::GInvertUV(CMatrix& mtxAP, CMatrix& mtxU, CMatrix& mtxV, double eps)
{
	// Run the singular value decomposition first.
	if (!SplitUV(mtxU, mtxV, eps))
		return FALSE;

	const int nRows = m_nNumRows;
	const int nCols = m_nNumColumns;

	// The result has the transposed shape of this matrix.
	if (!mtxAP.Init(nCols, nRows))
		return FALSE;

	// Count the leading diagonal entries that are non-zero (exact comparison).
	const int nDiagLen = (nRows < nCols) ? nRows : nCols;
	int nNonZero = 0;
	while ((nNonZero < nDiagLen) && (m_pData[nNonZero * nCols + nNonZero] != 0.0))
		++nNonZero;

	// Each result element accumulates V(s, r) * U(c, s) / diag(s) over the
	// counted diagonal entries, summing directly into the output buffer.
	for (int r = 0; r < nCols; ++r)
	{
		for (int c = 0; c < nRows; ++c)
		{
			const int nOut = r * nRows + c;
			mtxAP.m_pData[nOut] = 0.0;
			for (int s = 0; s < nNonZero; ++s)
			{
				const int nV = s * nCols + r;
				const int nU = c * nRows + s;
				const int nD = s * nCols + s;
				mtxAP.m_pData[nOut] = mtxAP.m_pData[nOut] + mtxV.m_pData[nV] * mtxU.m_pData[nU] / m_pData[nD];
			}
		}
	}

	return TRUE;
}

BOOL CMatrix::SplitUV(CMatrix& mtxU, CMatrix& mtxV, double eps)     //矩阵的奇异值分解实现
{ 
	int i,j,k,l,it,ll,kk,ix,iy,mm,nn,iz,m1,ks;
    double d,dd,t,sm,sm1,em1,sk,ek,b,c,shh,fg[2],cs[2];
    double *s,*e,*w;

	int m = m_nNumRows;
	int n = m_nNumColumns;

	// 初始化U, V矩阵
	if (! mtxU.Init(m, m) || ! mtxV.Init(n, n))
		return FALSE;

	// 临时缓冲区
	int ka = max(m, n) + 1;
    s = new double[ka];
    e = new double[ka];
    w = new double[ka];

	// 指定迭代次数为60
    it=60; 
	k=n;

    if (m-1<n) 
		k=m-1;

    l=m;
    if (n-2<m) 
		l=n-2;
    if (l<0) 
		l=0;

	// 循环迭代计算
    ll=k;
    if (l>k) 
		ll=l;
    if (ll>=1)
    { 
		for (kk=1; kk<=ll; kk++)
        { 
			if (kk<=k)
            { 
				d=0.0;
                for (i=kk; i<=m; i++)
                { 
					ix=(i-1)*n+kk-1; 
					d=d+m_pData[ix]*m_pData[ix];
				}

                s[kk-1]=sqrt(d);
                if (s[kk-1]!=0.0)
                { 
					ix=(kk-1)*n+kk-1;
                    if (m_pData[ix]!=0.0)
                    { 
						s[kk-1]=fabs(s[kk-1]);
                        if (m_pData[ix]<0.0) 
							s[kk-1]=-s[kk-1];
                    }
                    
					for (i=kk; i<=m; i++)
                    { 
						iy=(i-1)*n+kk-1;
                        m_pData[iy]=m_pData[iy]/s[kk-1];
                    }
                    
					m_pData[ix]=1.0+m_pData[ix];
                }
                
				s[kk-1]=-s[kk-1];
            }
            
			if (n>=kk+1)
            { 
				for (j=kk+1; j<=n; j++)
                { 
					if ((kk<=k)&&(s[kk-1]!=0.0))
                    { 
						d=0.0;
                        for (i=kk; i<=m; i++)
                        { 
							ix=(i-1)*n+kk-1;
                            iy=(i-1)*n+j-1;
                            d=d+m_pData[ix]*m_pData[iy];
                        }
                        
						d=-d/m_pData[(kk-1)*n+kk-1];
                        for (i=kk; i<=m; i++)
                        { 
							ix=(i-1)*n+j-1;
                            iy=(i-1)*n+kk-1;
                            m_pData[ix]=m_pData[ix]+d*m_pData[iy];
                        }
                    }
                    
					e[j-1]=m_pData[(kk-1)*n+j-1];
                }
            }
            
			if (kk<=k)
            { 
				for (i=kk; i<=m; i++)
                { 
					ix=(i-1)*m+kk-1; 
					iy=(i-1)*n+kk-1;
                    mtxU.m_pData[ix]=m_pData[iy];
                }
            }
            
			if (kk<=l)
            { 
				d=0.0;
                for (i=kk+1; i<=n; i++)
					d=d+e[i-1]*e[i-1];
                
				e[kk-1]=sqrt(d);
                if (e[kk-1]!=0.0)
                { 
					if (e[kk]!=0.0)
                    { 
						e[kk-1]=fabs(e[kk-1]);
                        if (e[kk]<0.0) 
							e[kk-1]=-e[kk-1];
                    }

                    for (i=kk+1; i<=n; i++)
                      e[i-1]=e[i-1]/e[kk-1];
                    
					e[kk]=1.0+e[kk];

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -