📄 matrix.h
字号:
/*! @file
********************************************************************************
<PRE>
模块名 : 实现复数矩阵的基本运算
文件名 : matrix.h
文件实现功能 : 实现复数矩阵的基本运算,并能实现复数矩阵的求逆、快速傅里叶变换和奇异值分解。
作者 : 陈鹏飞
版本 : V1.0
--------------------------------------------------------------------------------
多线程安全性 : <是/否>[,说明]
异常时安全性 : <是/否>[,说明]
--------------------------------------------------------------------------------
备注 : 该复数矩阵模板只适用于float和double型复数矩阵。不能用于实数矩阵。
--------------------------------------------------------------------------------
修改记录 :
日 期 版本 修改人 修改内容
2009/04/16 V1.01 <陈鹏飞> <模板的规范化>
</PRE>
*******************************************************************************/
#pragma once //避免头文件重复定义
#include <cassert>
#include <valarray>
#include <complex>
#include <cmath>
#include <cstdlib>
#include <float.h>
using namespace std;
// RTTI
#include <typeinfo>
typedef complex<float> COMPLEX_FLOAT;
typedef complex<double> COMPLEX_DOUBLE;
const float FLOATERROR = 1.0e-6F;
const double DOUBLEERROR = 1.0e-15;
const long double LONGDOUBLEERROR = 1.0e-30;
const double GOLDENSECTION = 0.618033399; //黄金分割常数(1.0-0.381966)
#define DBL_MIN 2.2250738585072014e-308 /* min positive value */
#define DBL_EPSILON 2.2204460492503131e-016 /* smallest such that 1.0+DBL_EPSILON != 1.0 */
//取x符号,+-或0
template <class T>
T Sgn(const T& x)
{
return x < T(0) ? T(-1) : (x > T(0) ? T(1) : T(0));
}
//绝对值
template <class T>
long double Abs(const T& x)
{
complex<long double> cld(x);
long double ldAbs = abs(x);
return(ldAbs);
}
//比较两float浮点数相等
bool FloatEqual(float lhs, float rhs)
{
if (Abs(lhs - rhs) < FLOATERROR)
return true;
else
return false;
}
//比较两float浮点数不相等
bool FloatNotEqual(float lhs, float rhs)
{
if (Abs(lhs - rhs) >= FLOATERROR)
return true;
else
return false;
}
//比较两double浮点数相等
bool FloatEqual(double lhs, double rhs)
{
if (Abs(lhs - rhs) < DOUBLEERROR)
return true;
else
return false;
}
//比较两double浮点数不相等
bool FloatNotEqual(double lhs, double rhs)
{
if (Abs(lhs - rhs) >= DOUBLEERROR)
return true;
else
return false;
}
//比较两long double浮点数相等
bool FloatEqual(long double lhs, long double rhs)
{
if (Abs(lhs - rhs) < LONGDOUBLEERROR)
return true;
else
return false;
}
//比较两long double浮点数不相等
bool FloatNotEqual(long double lhs, long double rhs)
{
if (Abs(lhs - rhs) >= LONGDOUBLEERROR)
return true;
else
return false;
}
//求x与y的最小值,返回小者
template <class T>
T Min(const T& x, const T& y)
{
if(x < y)
return x;
else
return y;
}
//求x与y的最大值,返回大者
template <class T>
T Max(const T& x, const T& y)
{
if(x > y)
return x;
else
return y;
}
/*! @class
********************************************************************************
<PRE>
类名称 : 复数矩阵模板类
功能 : 实现复数矩阵的基本运算
--------------------------------------------------------------------------------
备注 : <该类只能用来实现float和double型复数矩阵的基本运算>
--------------------------------------------------------------------------------
作者 : 陈鹏飞
</PRE>
*******************************************************************************/
template <class _Ty>
class CMatrix
{
typedef CMatrix<_Ty> _MatTy;
typedef _Ty _vTy;
public:
//构造函数一(参数1, 2分别为矩阵的行与列数)
/******
矩阵类matrix的构造函数一
构造函数中出现的m_Datas为valarray类的对象,申请stRow * stCol个单元,
单元内没赋值。对数组对象m_Datas使用了valarray类的构造函数:
explicit valarray(size_t n)
对私有变量m_stRow和m_stCol分别赋初值stRow, stCol。
******/
CMatrix(size_t stRow, size_t stCol)
: m_Datas(stRow * stCol),
m_stRow(stRow), m_stCol(stCol)
{
m_Datas.resize(GetRowNum() * GetColNum(), _vTy(0));
/*
构造函数中出现的m_Datas为valarray类的对象,若在函数体内调用的
resize(size_t n, const T& c = T())函数有两个参数(参阅valarray中的
定义),第一个参数申请有矩阵行*列那么多元素个数的存储空间,第二个
参数对这些申请的空间赋予具有模式T的初值0。如果不调用该函数,则缺省
情况下m_Datas长度为0;另外,调用该函数但不使用第二个参数,则不会对
m_Datas的任何一个元素赋初值。
*/
}
//构造函数二(参数1为指向矩阵的指针,参数2, 3为矩阵行与列数)
/******
矩阵类matrix的构造函数二
对私有变量m_stRow和m_stCol分别赋初值stRow, stCol。
对数组对象m_Datas使用了valarray类的构造函数:
valarray(const _Ty *p, size_t n)
m_Datas初始化的第一参数为矩阵rhs指针,第二个参数为rhs的元素总个数,
即rhs行数*列数
******/
CMatrix(const _vTy* rhs, size_t stRow, size_t stCol)
: m_Datas(rhs, stRow * stCol),
m_stRow(stRow), m_stCol(stCol)
{
}
//构造函数三(拷贝构造函数,参数为对矩阵matrix的引用)
/******
矩阵类matrix的构造函数三
用引用矩阵rhs的数组对象m_Datas初始化matrix所定义对象的m_Datas,
用引用矩阵rhs的行数rhs.GetRowNum()和列数rhs.GetColNum()分别初始化私
有变量m_stRow和m_stCol
******/
CMatrix(const _MatTy& rhs)
: m_Datas(rhs.m_Datas),
m_stRow(rhs.GetRowNum()), m_stCol(rhs.GetColNum())
{
}
size_t GetRowNum() const //返回矩阵行数的函数
{
return m_stRow;
}
size_t GetColNum() const //返回矩阵列数的函数
{
return m_stCol;
}
////////////////////////////////////////////////
// 重载运算符
// 索引行列(写入)
_vTy* operator [] (size_t stRow)
{
assert(stRow != 0);
assert(stRow < GetRowNum() + 1); //断定stRow不超实际矩阵行值
_vTy *p = &m_Datas[(stRow-1) * GetColNum()];
return p;
}
// 索引行列(只读)
const _vTy* operator [] (size_t stRow) const
{
assert(stRow != 0);
assert(stRow < GetRowNum()); //断定stRow不超实际矩阵行值
_vTy *p = &m_Datas[(stRow-1) * GetColNum()];
return p;
}
// 索引行列(写入)
_vTy& operator () (size_t stRow, size_t stCol)
{
assert(stRow < GetRowNum()); //断定stRow不超实际矩阵行值
assert(stCol < GetColNum()); //断定stCol不超实际矩阵列值
return m_Datas[stRow * GetColNum() + stCol];
}
// 索引行列(只读)
const _vTy operator () (size_t stRow, size_t stCol) const
{
assert(stRow < GetRowNum()); //断定stRow不超实际矩阵行值
assert(stCol < GetColNum()); //断定stCol不超实际矩阵列值
return m_Datas[stRow * GetColNum() + stCol];
}
// 赋值操作符
//矩阵与矩阵的自反*, +, -运算符
// 矩阵与矩阵的自反+
_MatTy& operator += (const _MatTy& rhs)
{
assert(GetRowNum() == rhs.GetRowNum());
assert(GetColNum() == rhs.GetColNum());
m_Datas += rhs.m_Datas;
return *this;
}
// 矩阵与矩阵的自反-
_MatTy& operator -= (const _MatTy& rhs)
{
assert(GetRowNum() == rhs.GetRowNum());
assert(GetColNum() == rhs.GetColNum());
m_Datas -= rhs.m_Datas;
return *this;
}
// 矩阵与矩阵的自反*
_MatTy& operator *= (const _MatTy& rhs)
{
Multiply(*this, *this, rhs);
return *this;
}
//矩阵自反加、减、乘以、除以数
_MatTy& operator += (const _Ty& rhs) //矩阵自加数
{
m_Datas += rhs; //利用数组对象对每个元素加数
return *this; //结果放在原矩阵(数组m_Datas)中
}
_MatTy& operator -= (const _Ty& rhs) //矩阵自减数
{
m_Datas -= rhs;
return *this;
}
_MatTy& operator *= (const _Ty& rhs) //矩阵自乘数
{
m_Datas *= rhs;
return *this;
}
_MatTy& operator /= (const _Ty& rhs) //矩阵自除以数
{
m_Datas /= rhs;
return *this;
}
// 矩阵取反
_MatTy operator - () const
{
_MatTy mat(*this);
mat.m_Datas = -mat.m_Datas;
return mat;
}
//矩阵乘数 mat = lhs * rhs
friend _MatTy operator * (const _MatTy& lhs, const _Ty& rhs)
{
_MatTy mat(lhs); //新生成一新矩阵对象mat
mat.m_Datas *= rhs; //对新矩阵对象每个元素乘以数
return mat;
}
//数加矩阵 mat = lhs + rhs
friend _MatTy operator + (const _Ty& lhs, const _MatTy& rhs)
{
_MatTy mat(rhs); //新生成一新矩阵对象mat
mat.m_Datas += lhs; //数加上新矩阵对象的每个元素
return mat;
}
//矩阵加数 mat = lhs + rhs
friend _MatTy operator + (const _MatTy& lhs, const _Ty& rhs)
{
_MatTy mat(lhs); //新生成一新矩阵对象mat
mat.m_Datas += rhs; //数加上新矩阵对象的每个元素
return mat;
}
//数减矩阵 mat = lhs - rhs
friend _MatTy operator - (const _Ty& lhs, const _MatTy& rhs)
{
_MatTy mat(rhs); //新生成一新矩阵对象mat
mat.m_Datas -= lhs; //数减新矩阵对象的每个元素
return mat;
}
//数乘矩阵 mat = lhs * rhs
friend _MatTy operator * (const _Ty& lhs, const _MatTy& rhs)
{
_MatTy mat(rhs);
mat.m_Datas *= lhs;
return mat;
}
//矩阵加法 mat = lhs + rhs
friend _MatTy operator + (const _MatTy& lhs, const _MatTy& rhs)
{
_MatTy mat(lhs);
mat.m_Datas += rhs.m_Datas;
return mat;
}
//矩阵减法 mat = lhs - rhs
friend _MatTy operator - (const _MatTy& lhs, const _MatTy& rhs)
{
_MatTy mat(lhs);
mat.m_Datas -= rhs.m_Datas;
return mat;
}
//矩阵乘法 mTmp = lhs * rhs
friend _MatTy operator * (const _MatTy& lhs, const _MatTy& rhs)
{ //生成一个矩阵新对象mTmp
_MatTy mTmp(lhs.GetRowNum(), rhs.GetColNum()); //没初始化
return Multiply(mTmp, lhs, rhs);
}
//矩阵点乘
friend _MatTy PM(_MatTy& lhs, _MatTy& rhs)
{
int Row, Col;
Row = lhs.GetRowNum();
Col = lhs.GetColNum();
_MatTy Temp(Row, Col);
for ( int i = 0; i < Row; i++)
for ( int j = 0; j < Col; j++)
Temp(i, j) = lhs(i, j) * rhs(i, j);
return Temp;
}
//矩阵点除
friend _MatTy PD(_MatTy& lhs, _MatTy& rhs)
{
int Row, Col;
Row = lhs.GetRowNum();
Col = rhs.GetColNum();
_MatTy Temp(Row, Col);
for ( int i = 0; i < Row; i++)
for ( int j = 0; j < Col; j++)
{
if(rhs(i, j) == _Ty(0, 0))
{
//cout << "input is error" << endl;
exit(1);
}
Temp(i, j) = lhs(i, j) / rhs(i, j);
}
return Temp;
}
//矩阵余弦函数
friend _MatTy COS(_MatTy& lhs)
{
int Row, Col;
Row = lhs.GetRowNum();
Col = lhs.GetColNum();
_MatTy tmp(Row, Col);
for ( int i = 0; i < Row; i++)
for ( int j = 0; j < Col; j++)
tmp(i, j) = cos(lhs(i, j));
return tmp;
}
//矩阵正弦函数
friend _MatTy SIN(_MatTy& lhs)
{
int Row, Col;
Row = lhs.GetRowNum();
Col = lhs.GetColNum();
_MatTy tmp(Row, Col);
for ( int i = 0; i < Row; i++)
for ( int j = 0; j < Col; j++)
tmp(i, j) = sin(lhs(i, j));
return tmp;
}
friend void lstSprsFit(_MatTy& x, _MatTy& y, int n, _MatTy& yFit, _MatTy& coef)
{
int mCol = x.GetColNum();
int yCol = y.GetColNum();
_MatTy X1(1, 2*n);
_MatTy X2(1, n+1);
_MatTy X3(n, n+1);
_MatTy XSUM(n+1, n+1);
_MatTy YSUM(1, n+1);
_Ty SumNum = (0, 0);
for ( int i = 0; i < yCol; i++)
{
SumNum = SumNum + y(0, i);
}
for ( int i = 0; i < 2*n; i++)
{
_Ty sum = (0, 0);
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -