// File: matrix.h
// (code-viewer page residue: "字号:" = "font size:")
// Matrix.h: interface and implementation of the CMatrix class.
//
/////////////////////////////////////////////////////////////////////////////////
#pragma once
#include <math.h>    // fabs, sqrt
#include <stddef.h>  // NULL
#include <stdlib.h>  // __max / __min (MSVC CRT)
#include <string.h>  // memset
// SQR(a): a squared; yields exactly 0. when a compares equal to zero.
// Macro caveat: `a` is evaluated up to three times, so pass only
// side-effect-free expressions.
#define SQR(a) ( (a) != 0. ? (a)*(a) : 0. )
// SIGN(a,b): |a| carrying the sign of b. Note that b == 0. (or NaN) gives -|a|.
#define SIGN(a,b) ( !((b) > 0.) ? -fabs(a) : fabs(a) )
/// CMatrix<T>: a dense matrix of T stored as an array of row pointers
/// (m_pMatrix[row][col]). Its size is an MFC CSize whose cx is the column
/// count and cy the row count (see CMatrix(int r, int c)).
///
/// Ownership: each instance owns its storage. Copy construction and
/// assignment make deep copies, and the destructor frees the storage.
///
/// NOTE(review): the CMatrix<T>& parameters are non-const references, so in
/// standard C++ an unnamed temporary (e.g. the result of another operator)
/// cannot be passed to them. The code appears to have been written against
/// MSVC's extension that allows this. CSize and ASSERT come from MFC.
template <class T>
class CMatrix
{
public:
    // ---- construction, copy, destruction ----
    CMatrix();                          // empty 0 x 0 matrix, no storage
    CMatrix(CSize size);                // zero-filled, size.cy rows x size.cx columns
    CMatrix(int r, int c);              // zero-filled, r rows x c columns
    CMatrix(CMatrix<T>& src);           // deep copy
    ~CMatrix();
    void operator = (CMatrix<T>& src);  // deep copy

    // ---- shape and raw element access ----
    CSize GetSize() { return m_Size; }
    T** GetPtr() { return m_pMatrix; }  // borrowed; still owned by this object
    T* operator [] (int row);           // pointer to the first element of a row

    // ---- arithmetic (every result is a newly allocated matrix) ----
    CMatrix<T> operator + (CMatrix<T>& src);
    CMatrix<T> operator - (CMatrix<T>& src);
    CMatrix<T> operator * (CMatrix<T>& src);
    CMatrix<T> operator * (double scale);

    // ---- linear algebra ----
    CMatrix<T> Transpose();
    CMatrix<T> Inverse();               // via singular value decomposition
    CMatrix<T> Inverse(int order);

protected:
    // Storage helpers: they act only on their arguments, not on members.
    T** CreateMatrix(CSize size);
    void DeleteMatrix(CSize size, T** pMatrix);
    void CopyMatrix(CSize size, T** pDes, T** pSrc);
    double pythag(double a, double b);  // sqrt(a*a + b*b) with rescaling

    CSize m_Size;    // cx = columns, cy = rows
    T** m_pMatrix;   // m_Size.cy rows of m_Size.cx elements each, or NULL
};
/// Constructs an empty 0 x 0 matrix that owns no storage.
template <class T>
CMatrix<T>::CMatrix()
    : m_Size(0, 0), m_pMatrix(NULL)
{
}
/// Constructs a zero-filled matrix with size.cy rows and size.cx columns.
template <class T>
CMatrix<T>::CMatrix(CSize size)
    : m_Size(size), m_pMatrix(CreateMatrix(size))
{
}
/// Constructs a zero-filled matrix with r rows and c columns.
template <class T>
CMatrix<T>::CMatrix(int r, int c)
    : m_Size(c, r)   // CSize is (cx = columns, cy = rows)
{
    m_pMatrix = CreateMatrix(m_Size);
}
/// Deep copy: allocates fresh rows of src's size and copies every element.
template <class T>
CMatrix<T>::CMatrix(CMatrix<T>& src)
    : m_Size(src.m_Size)
{
    m_pMatrix = CreateMatrix(src.m_Size);
    CopyMatrix(src.m_Size, m_pMatrix, src.m_pMatrix);
}
/// Releases the row storage, if this matrix owns any.
template <class T>
CMatrix<T>::~CMatrix()
{
    if( m_pMatrix != NULL )
        DeleteMatrix(m_Size, m_pMatrix);
}
/// Allocates a size.cy x size.cx matrix as an array of row pointers, with
/// every element value-initialized (0 for arithmetic T).
///
/// @param size  cx = number of columns, cy = number of rows.
/// @return      Row-pointer array owned by the caller. Release it with
///              DeleteMatrix(size, ptr). If an allocation throws, everything
///              allocated so far is freed and the exception is rethrown.
template <class T>
T** CMatrix<T>::CreateMatrix(CSize size)
{
    T** pMatrix = new T*[size.cy];
    int r = 0;
    try
    {
        for( ; r < size.cy ; r++ )
        {
            // "()" value-initializes the row: zero for arithmetic T and proper
            // default construction for class T. memset is only valid for
            // trivially copyable T.
            pMatrix[r] = new T[size.cx]();
        }
    }
    catch( ... )
    {
        // Rows 0..r-1 were allocated before the failure; do not leak them.
        while( r-- > 0 )
            delete [] pMatrix[r];
        delete [] pMatrix;
        throw;
    }
    return pMatrix;
}
/// Frees a matrix previously returned by CreateMatrix(size): each of the
/// size.cy rows first, then the row-pointer array itself.
template <class T>
void CMatrix<T>::DeleteMatrix(CSize size, T** pMatrix)
{
    for( int row = 0 ; row < size.cy ; ++row )
    {
        T* pRow = pMatrix[row];
        delete [] pRow;
    }
    delete [] pMatrix;
}
/// Element-wise copy of a size.cy x size.cx matrix from pSrc into pDes.
/// Both must already hold at least that many rows and columns.
template <class T>
void CMatrix<T>::CopyMatrix(CSize size, T** pDes, T** pSrc)
{
    for( int row = 0 ; row < size.cy ; ++row )
    {
        T* dst = pDes[row];
        T* src = pSrc[row];
        for( int col = 0 ; col < size.cx ; ++col )
            dst[col] = src[col];
    }
}
/// Returns a new matrix whose element (r, c) equals this matrix's (c, r).
/// A matrix with R rows and C columns yields one with C rows and R columns.
template <class T>
CMatrix<T> CMatrix<T>::Transpose()
{
    // (rows, cols) of the result = (cols, rows) of this matrix.
    CMatrix<T> result(m_Size.cx, m_Size.cy);
    for( int r = 0 ; r < m_Size.cx ; ++r )        // rows of the result
        for( int c = 0 ; c < m_Size.cy ; ++c )    // columns of the result
            result.m_pMatrix[r][c] = m_pMatrix[c][r];
    return result;
}
/// Returns a pointer to the first element of row `row`, so m[r][c] reads or
/// writes element (r, c). Returns NULL when row lies outside [0, rows).
///
/// NOTE: only the row index is checked. The column index applied to the
/// returned pointer is not.
template <class T>
T* CMatrix<T>::operator [] (int row)
{
    // Either condition alone puts the row out of range, so the test must be
    // "||". With "&&" it could never be true, and bad rows were dereferenced.
    if( row < 0 || row >= m_Size.cy )
        return NULL;
    return m_pMatrix[row];
}
/// Replaces this matrix with a deep copy of src (size and elements).
///
/// Self-assignment is a no-op. The new storage is built completely before
/// the old storage is released, so a failed allocation or element copy
/// leaves *this unchanged.
template <class T>
void CMatrix<T>::operator = (CMatrix<T>& src)
{
    // Without this check, the delete below would free src's storage before
    // it is copied.
    if( this == &src )
        return;

    T** pNew = CreateMatrix( src.m_Size );
    try
    {
        CopyMatrix( src.m_Size, pNew, src.m_pMatrix );
    }
    catch( ... )
    {
        DeleteMatrix( src.m_Size, pNew );
        throw;
    }

    if( m_pMatrix ) DeleteMatrix( m_Size, m_pMatrix );
    m_Size = src.m_Size;
    m_pMatrix = pNew;
}
/// Element-wise sum (*this + src). If the two sizes differ, returns a
/// zero-filled matrix of this matrix's size instead.
template <class T>
CMatrix<T> CMatrix<T>::operator + (CMatrix<T>& src)
{
    CMatrix<T> sum(m_Size);
    if( m_Size == src.m_Size )
    {
        for( int r = 0 ; r < m_Size.cy ; ++r )
        {
            T* out = sum.m_pMatrix[r];
            T* lhs = m_pMatrix[r];
            T* rhs = src.m_pMatrix[r];
            for( int c = 0 ; c < m_Size.cx ; ++c )
                out[c] = lhs[c] + rhs[c];
        }
    }
    return sum;
}
/// Element-wise difference (*this - src).
///
/// Both operands must have the same size. A mismatch trips ASSERT in debug
/// builds. In every build the function then returns a zero-filled matrix of
/// this matrix's size rather than indexing outside src, the same fallback
/// operator + uses.
template <class T>
CMatrix<T> CMatrix<T>::operator - (CMatrix<T>& src)
{
    ASSERT( m_Size == src.m_Size );
    CMatrix<T> mat(m_Size);
    // MFC's ASSERT does nothing in release builds. Without this guard, a
    // smaller src would be read past the end of its rows or columns.
    if( m_Size != src.m_Size ) return mat;
    for( int r = 0 ; r < m_Size.cy ; r++ )
    {
        for( int c = 0 ; c < m_Size.cx ; c++ )
        {
            mat.m_pMatrix[r][c] = m_pMatrix[r][c] - src.m_pMatrix[r][c];
        }
    }
    return mat;
}
/// Matrix product (*this) * src.
///
/// Requires (columns of *this) == (rows of src). The result has the rows of
/// *this and the columns of src. A mismatch trips ASSERT in debug builds. In
/// every build the function then returns the zero-filled matrix of that
/// result size rather than indexing past src's rows.
template <class T>
CMatrix<T> CMatrix<T>::operator * (CMatrix<T>& src)
{
    ASSERT( m_Size.cx == src.m_Size.cy );
    CSize size;
    size.cx = src.m_Size.cx;
    size.cy = m_Size.cy;
    CMatrix<T> mat(size);
    // MFC's ASSERT does nothing in release builds, so guard explicitly.
    if( m_Size.cx != src.m_Size.cy ) return mat;

    // Loop order r, i, c: the innermost loop walks contiguous rows of src and
    // of the result instead of striding down a column of src. Each output
    // element still accumulates its terms for i = 0, 1, ... onto the zero it
    // was created with, so the results are the same as the r, c, i order.
    for( int r = 0 ; r < size.cy ; r++ )
    {
        T* pOut = mat.m_pMatrix[r];
        for( int i = 0 ; i < m_Size.cx ; i++ )
        {
            const T lhs = m_pMatrix[r][i];
            const T* pSrcRow = src.m_pMatrix[i];
            for( int c = 0 ; c < size.cx ; c++ )
                pOut[c] += lhs * pSrcRow[c];
        }
    }
    return mat;
}
/// Returns a new matrix of the same size with every element multiplied by
/// scale. The product is computed as scale * element and converted back to T.
template <class T>
CMatrix<T> CMatrix<T>::operator * (double scale)
{
    CMatrix<T> scaled(m_Size);
    for( int r = 0 ; r < m_Size.cy ; ++r )
    {
        T* out = scaled.m_pMatrix[r];
        T* in = m_pMatrix[r];
        for( int c = 0 ; c < m_Size.cx ; ++c )
            out[c] = scale * in[c];
    }
    return scaled;
}
/// Returns sqrt(a*a + b*b). The larger magnitude is factored out so that the
/// squared term is a ratio <= 1, which avoids overflow in the intermediate.
/// Returns 0 when both inputs are zero.
template <class T>
double CMatrix<T>::pythag(double a, double b)
{
    const double absa = fabs(a);
    const double absb = fabs(b);
    if( absa > absb )
        return absa * sqrt(1. + SQR(absb / absa));
    if( absb == 0. )
        return 0.;
    return absb * sqrt(1. + SQR(absa / absb));
}
/// Returns the inverse of this matrix computed from its singular value
/// decomposition A = U * diag(w) * V^T as V * diag(1/w) * U^T. For a matrix
/// with m rows and n columns the result has n rows and m columns.
///
/// This is exactly Inverse(n) with n = number of columns: every singular
/// value takes part. The decomposition lives in one place, Inverse(int), so
/// the two entry points cannot drift apart (this body used to be a verbatim
/// copy of it).
template <class T>
CMatrix<T> CMatrix<T>::Inverse()
{
    return Inverse( m_Size.cx );
}
/// Computes V * diag(1/w) * U^T from the singular value decomposition
/// A = U * diag(w) * V^T of this matrix (m rows, n columns). Returns it as a
/// matrix with n rows and m columns. The decomposition closely follows the
/// Numerical Recipes svdcmp routine, translated to 0-based indexing.
///
/// @param order  Number of trailing entries w[n-order] .. w[n-1] to invert.
///               All other diagonal entries of diag(1/w) stay 0. Values
///               above n are clamped to n (every singular value). Values
///               <= 0 give a zero matrix.
///
/// A singular value that is exactly 0 contributes 0 rather than 1/0, the
/// usual pseudo-inverse convention. A singular input therefore yields a
/// finite result instead of inf/NaN.
///
/// NOTE(review): w is left in the order the decomposition produces and is
/// not sorted by magnitude. The "trailing `order` entries" are therefore not
/// guaranteed to be the largest (or smallest) singular values. Confirm the
/// intended meaning of `order` with callers.
template <class T>
CMatrix<T> CMatrix<T>::Inverse(int order)
{
    int flag, i, j, jj, its, k, l, n, m, nm;
    T anorm, c, f, g, h, s, scale, x, y, z, *rv1, *w;
    n = m_Size.cx;
    m = m_Size.cy;
    CMatrix<T> mat(*this);   // overwritten in place with U (m x n)
    CMatrix<T> vmat(n,n);    // receives V (n x n)
    CMatrix<T> sm(n,n);      // receives diag(1/w); allocated before rv1/w so
                             // nothing that can throw runs while they are live
    T** a = mat.m_pMatrix;
    T** v = vmat.m_pMatrix;
    // One allocation holds both n-element work vectors: rv1 and the singular
    // values w. A single delete[] below releases both.
    rv1 = new T[2*n];
    w = rv1 + n;

    // Householder reduction to bidiagonal form.
    g = scale = anorm = 0.;
    for( i = 0 ; i < n ; i++ )
    {
        l = i+1;
        rv1[i] = scale*g;
        g = s = scale = 0.;
        if( i < m )
        {
            for( k = i ; k < m ; k++ )
                scale += (T)fabs(a[k][i]);
            if( scale )
            {
                for( k = i ; k < m ; k++ )
                {
                    a[k][i] /= scale;
                    s += a[k][i]*a[k][i];
                }
                f = a[i][i];
                g = -(T)SIGN( sqrt(s), f);
                h = f*g-s;
                a[i][i] = f-g;
                for( j = l ; j < n ; j++ )
                {
                    for( s = 0., k = i ; k < m ; k++ )
                        s += a[k][i]*a[k][j];
                    f = s/h;
                    for( k = i ; k < m ; k++ )
                        a[k][j] += f*a[k][i];
                }
                for( k = i ; k < m ; k++ )
                    a[k][i] *= scale;
            }
        }
        w[i] = scale*g;
        g = s = scale = 0.;
        if( i < m && i != n-1 )
        {
            for( k = l ; k < n ; k++ )
                scale += (T)fabs(a[i][k]);
            if( scale )
            {
                for( k = l ; k < n ; k++ )
                {
                    a[i][k] /= scale;
                    s += a[i][k]*a[i][k];
                }
                f = a[i][l];
                g = -(T)SIGN(sqrt(s), f);
                h = f*g-s;
                a[i][l] = f-g;
                for( k = l ; k < n ; k++ )
                    rv1[k] = a[i][k]/h;
                for( j = l ; j < m ; j++ )
                {
                    for( s = 0., k = l ; k < n ; k++ )
                        s += a[j][k]*a[i][k];
                    for( k = l ; k < n ; k++ )
                        a[j][k] += s*rv1[k];
                }
                for( k = l ; k < n ; k++ )
                    a[i][k] *= scale;
            }
        }
        // anorm = max(anorm, |w[i]| + |rv1[i]|). It is the scale used by the
        // "(x + anorm) == anorm" negligibility tests below. Written out
        // instead of the MSVC-only __max macro, with the same comparison.
        T rowNorm = (T)(fabs(w[i])+fabs(rv1[i]));
        anorm = ( anorm > rowNorm ) ? anorm : rowNorm;
    }

    // Accumulate the right-hand transformations into V.
    for( i = n-1 ; i >= 0 ; i-- )
    {
        if( i < n-1 )
        {
            if( g )
            {
                for( j = l ; j < n ; j++ )
                    v[j][i] = (a[i][j]/a[i][l])/g;
                for( j = l ; j < n ; j++ )
                {
                    for( s = 0., k = l ; k < n ; k++ )
                        s += a[i][k]*v[k][j];
                    for( k = l ; k < n ; k++ )
                        v[k][j] += s*v[k][i];
                }
            }
            for( j = l ; j < n ; j++ )
                v[i][j] = v[j][i] = 0.;
        }
        v[i][i] = 1.;
        g = rv1[i];
        l = i;
    }

    // Accumulate the left-hand transformations into U (overwriting a).
    for( i = ( m < n ? m : n ) - 1 ; i >= 0 ; i-- )
    {
        l = i+1;
        g = w[i];
        for( j = l ; j < n ; j++ )
            a[i][j] = 0.;
        if( g )
        {
            g = (T)1./g;
            for( j = l ; j < n ; j++ )
            {
                for( s = 0., k = l ; k < m ; k++ )
                    s += a[k][i]*a[k][j];
                f = (s/a[i][i])*g;
                for( k = i ; k < m ; k++ )
                    a[k][j] += f*a[k][i];
            }
            for( j = i ; j < m ; j++ )
                a[j][i] *= g;
        }
        else
        {
            for( j = i ; j < m ; j++ )
                a[j][i] = 0.;
        }
        a[i][i]++;
    }

    // Diagonalize the bidiagonal form, one singular value (k) at a time.
    for( k = n-1 ; k >= 0 ; k-- )
    {
        // NOTE(review): nothing reports the case where all 30 iterations pass
        // without convergence. The loop simply moves on with whatever w[k]
        // holds.
        for( its = 0 ; its < 30 ; its++ )
        {
            flag = 1;
            // Look for a negligible rv1[l] or w[l-1]. rv1[0] is zero whenever
            // this runs (it starts at 0, and rv1[l] = 0. resets it after each
            // sweep), so the loop stops at l == 0 before w[-1] is read.
            for( l = k ; l >= 0 ; l-- )
            {
                nm = l-1;
                if( (fabs(rv1[l])+anorm) == anorm )
                {
                    flag = 0;
                    break;
                }
                if( (fabs(w[nm])+anorm) == anorm ) break;
            }
            if( flag )
            {
                c = 0.;
                s = 1.;
                for( i = l ; i <= k ; i++ )
                {
                    f = s*rv1[i];
                    rv1[i] = c*rv1[i];
                    if( (fabs(f)+anorm) == anorm ) break;
                    g = w[i];
                    h = (T)pythag(f,g);
                    w[i] = h;
                    h = (T)1./h;
                    c = g*h;
                    s = -f*h;
                    for( j = 0 ; j < m ; j++ )
                    {
                        y = a[j][nm];
                        z = a[j][i];
                        a[j][nm] = y*c+z*s;
                        a[j][i] = z*c-y*s;
                    }
                }
            }
            z = w[k];
            if( l == k )
            {
                // Converged: make the singular value non-negative and flip
                // the matching column of V.
                if( z < 0. )
                {
                    w[k] = -z;
                    for( j = 0 ; j < n ; j++ )
                        v[j][k] = -v[j][k];
                }
                break;
            }
            // Shift computed from the trailing 2 x 2 block (indices nm, k).
            x = w[l];
            nm = k-1;
            y = w[nm];
            g = rv1[nm];
            h = rv1[k];
            f = (T)(((y-z)*(y+z)+(g-h)*(g+h))/(2.*h*y));
            g = (T)pythag(f, 1.);
            f = (T)(((x-z)*(x+z)+h*((y/(f+SIGN(g,f)))-h))/x);
            // Sweep of plane rotations from l to k, applied to V and U.
            c = s = 1.;
            for( j = l ; j <= nm ; j++ )
            {
                i = j+1;
                g = rv1[i];
                y = w[i];
                h = s*g;
                g = c*g;
                z = (T)pythag(f,h);
                rv1[j] = z;
                c = f/z;
                s = h/z;
                f = x*c+g*s;
                g = g*c-x*s;
                h = y*s;
                y *= c;
                for( jj = 0 ; jj < n ; jj++ )
                {
                    x = v[jj][j];
                    z = v[jj][i];
                    v[jj][j] = x*c+z*s;
                    v[jj][i] = z*c-x*s;
                }
                z = (T)pythag(f,h);
                w[j] = z;
                if( z )
                {
                    z = (T)1./z;
                    c = f*z;
                    s = h*z;
                }
                f = c*g+s*y;
                x = c*y-s*g;
                for( jj = 0 ; jj < m ; jj++ )
                {
                    y = a[jj][j];
                    z = a[jj][i];
                    a[jj][j] = y*c+z*s;
                    a[jj][i] = z*c-y*s;
                }
            }
            rv1[l] = 0.;
            rv1[k] = f;
            w[k] = x;
        }
    }

    // Fill diag(1/w) for the selected entries. Clamp order so the loop never
    // reaches w[-1] / sm[-1].
    if( order > n )
        order = n;
    for( i = n-1 ; i >= n-order ; i-- )
    {
        // A zero singular value has no reciprocal. Leave its entry at 0
        // rather than storing 1/0, which would make the product inf/NaN.
        if( w[i] != 0. )
            sm[i][i] = (T)1./w[i];
    }
    delete [] rv1;   // also releases w, which points into the same block

    // V * diag(1/w) * U^T (U is held in mat). The factors are named so that
    // no unnamed temporary is passed to operator *'s non-const reference
    // parameter, which standard C++ does not allow.
    CMatrix<T> ut = mat.Transpose();
    CMatrix<T> vs = vmat * sm;
    return vs * ut;
}
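// ---------------------------------------------------------------------------
// Code-viewer page residue (keyboard-shortcut help), not part of the source:
//   ⌨️ 快捷键说明   Keyboard shortcuts
//   复制代码        Copy code            Ctrl + C
//   搜索代码        Search code          Ctrl + F
//   全屏模式        Full-screen mode     F11
//   切换主题        Toggle theme         Ctrl + Shift + D
//   显示快捷键      Show shortcuts       ?
//   增大字号        Increase font size   Ctrl + =
//   减小字号        Decrease font size   Ctrl + -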