// jmatrix.cpp
// JMatrix.cpp: implementation of the JMatrix class.
//
//////////////////////////////////////////////////////////////////////
#include "stdafx.h"
#include "JMatrix.h"
#include <math.h>
#include <vector>
//////////////////////////////////////////////////////////////////////
// constructor & destructor
JMatrix::JMatrix( unsigned int ROW, unsigned int COL )
    : row( ROW ? ROW : 1 ), col( COL ? COL : 1 )
{
    // Build a ROW x COL matrix whose entries are all 0.
    // A requested dimension of 0 is replaced by 1, so the buffer is never empty.
    data = new double[ row * col ]() ;   // trailing "()" value-initializes -> 0.0
}
JMatrix::JMatrix( const JMatrix & A )
    : row( A.row ), col( A.col )
{
    // Deep copy: allocate a buffer of A's shape and copy every entry of A.
    const unsigned int count = row * col ;
    data = new double[ count ] ;
    const double * from = A.data ;
    for ( double * to = data ; to != data + count ; ++to, ++from )
        * to = * from ;
}
JMatrix::~JMatrix()
{
    // Release the heap buffer that holds the matrix entries (allocated with new[]).
    delete[] data;
}
//////////////////////////////////////////////////////////////////////
// query
unsigned int JMatrix::DimRow( ) const
{
    // Number of rows of T.
    const unsigned int nRows = row ;
    return nRows ;
}
unsigned int JMatrix::DimCol( ) const
{
    // Number of columns of T (also the row stride of the row-major buffer).
    const unsigned int nCols = col ;
    return nCols ;
}
//////////////////////////////////////////////////////////////////////
// operator
double JMatrix::operator () ( unsigned int ROW, unsigned int COL ) const
{
    // Read-only access to element T(ROW, COL) of the row-major buffer.
    // No bounds checking is performed.
    const unsigned int offset = ROW * col + COL ;
    return data[ offset ] ;
}
double * JMatrix::operator [] ( unsigned int ROW )
{
    // Pointer to the first entry of row ROW, so T[r][c] addresses T(r, c)
    // and may be written through. No bounds checking is performed.
    return & data[ ROW * col ] ;
}
const JMatrix JMatrix::operator + ( double v ) const
{
    // T + v : a copy of T with the scalar v added to every entry.
    JMatrix sum( * this ) ;
    sum += v ;
    return sum ;
}
const JMatrix JMatrix::operator - ( double v ) const
{
    // T - v : a copy of T with the scalar v subtracted from every entry.
    JMatrix diff( * this ) ;
    diff -= v ;
    return diff ;
}
const JMatrix JMatrix::operator * ( double v ) const
{
    // T * v : a copy of T with every entry multiplied by the scalar v.
    JMatrix scaled( * this ) ;
    scaled *= v ;
    return scaled ;
}
const JMatrix JMatrix::operator + ( const JMatrix & A ) const
{
    // T + A, element-wise. If the shapes differ, the result is an unmodified
    // copy of T (operator+= leaves its target alone for mismatched shapes).
    JMatrix sum( * this ) ;
    sum += A ;
    return sum ;
}
const JMatrix JMatrix::operator - ( const JMatrix & A ) const
{
    // T - A, element-wise. If the shapes differ, the result is an unmodified
    // copy of T (operator-= leaves its target alone for mismatched shapes).
    JMatrix diff( * this ) ;
    diff -= A ;
    return diff ;
}
const JMatrix JMatrix::operator * ( const JMatrix & A ) const
{
    // Matrix product T * A. The result always has shape row x A.col; if the
    // inner dimensions disagree (col != A.row) it is returned all zeros.
    JMatrix P( row, A.col ) ;
    if ( col != A.row ) return P ;

    for ( unsigned int r = 0 ; r < P.row ; ++r )
    {
        const double * lhs = data + r * col ;     // row r of T
        double * out = P.data + r * P.col ;       // row r of the product
        for ( unsigned int c = 0 ; c < P.col ; ++c )
        {
            double acc = 0 ;
            for ( unsigned int k = 0 ; k < col ; ++k )
                acc += lhs[ k ] * A.data[ k * A.col + c ] ;
            out[ c ] = acc ;
        }
    }
    return P ;
}
JMatrix & JMatrix::operator = ( const JMatrix & A )
{
    // T = A : deep copy of A's shape and entries, consistent with the copy
    // constructor. When the shapes differ, T is resized to A's shape; it is not
    // left unchanged. Self-assignment is a no-op.
    // If the element count changes, the buffer is reallocated, so pointers
    // previously obtained through operator[] become invalid.
    if ( this == & A ) return * this ;

    const unsigned int N = A.row * A.col ;
    if ( row * col != N )
    {
        // allocate before releasing, so T stays intact if new[] throws
        double * fresh = new double[ N ] ;
        delete [] data ;
        data = fresh ;
    }
    row = A.row ;
    col = A.col ;
    for ( unsigned int i=0 ; i<N ; i++ ) data[i] = A.data[i] ;
    return * this ;
}
JMatrix & JMatrix::operator += ( double v )
{
    // T = T + v : add the scalar v to every entry in place.
    double * const end = data + row * col ;
    for ( double * p = data ; p != end ; ++p ) * p += v ;
    return * this ;
}
JMatrix & JMatrix::operator -= ( double v )
{
    // T = T - v : subtract the scalar v from every entry in place.
    double * const end = data + row * col ;
    for ( double * p = data ; p != end ; ++p ) * p -= v ;
    return * this ;
}
JMatrix & JMatrix::operator *= ( double v )
{
    // T = T * v : scale every entry by v in place.
    for ( unsigned int idx = row * col ; idx-- > 0 ; ) data[ idx ] *= v ;
    return * this ;
}
JMatrix & JMatrix::operator += ( const JMatrix & A )
{
    // T = T + A, element-wise. Silently does nothing if the shapes differ.
    if ( row != A.row || col != A.col ) return * this ;
    const double * src = A.data ;
    double * const end = data + row * col ;
    for ( double * dst = data ; dst != end ; ++dst, ++src ) * dst += * src ;
    return * this ;
}
JMatrix & JMatrix::operator -= ( const JMatrix & A )
{
    // T = T - A, element-wise. Silently does nothing if the shapes differ.
    if ( row != A.row || col != A.col ) return * this ;
    for ( unsigned int idx = row * col ; idx-- > 0 ; ) data[ idx ] -= A.data[ idx ] ;
    return * this ;
}
//////////////////////////////////////////////////////////////////////
// operation
JMatrix & JMatrix::Transpose()
{
    // T = T' in place. A snapshot of the current entries is taken, the
    // dimensions are swapped, and the snapshot is written back transposed.
    const JMatrix src( * this ) ;
    row = src.col ;
    col = src.row ;
    for ( unsigned int r = 0 ; r < row ; ++r )
        for ( unsigned int c = 0 ; c < col ; ++c )
            data[ r * col + c ] = src( c, r ) ;
    return * this ;
}
JMatrix & JMatrix::Inverse( double tol )
{
    // Replace T by T^(-1) using Gauss elimination with column (partial)
    // pivoting. Rows are never moved physically; index[] records the row
    // permutation instead.
    //
    // tol : a pivot whose magnitude is below tol is treated as zero, i.e. T is
    //       considered singular.
    //
    // If T is not square, or is found to be singular, T is left unchanged
    // (elimination runs on a private copy).
    if ( row != col ) return * this ;

    const unsigned int n = row ;
    JMatrix U( * this ) ;                       // working copy, reduced to unit upper-triangular form
    JMatrix A( n, n ) ;                         // starts as identity, ends as the (permuted) inverse
    std::vector< unsigned int > index( n ) ;    // index[i] = physical row holding logical row i
    unsigned int i, j, k ;
    for ( i=0 ; i<n ; i++ )
    {
        A.data[ i * n + i ] = 1 ;
        index[i] = i ;
    }

    for ( i=0 ; i<n ; i++ )   // forward elimination, column by column
    {
        // pick the remaining row with the largest |entry| in column i
        unsigned int max_row = i ;
        double max_ele = fabs( U.data[ index[i] * n + i ] ) ;
        for ( j=i+1 ; j<n ; j++ )
        {
            const double cc = fabs( U.data[ index[j] * n + i ] ) ;
            if ( cc > max_ele ) { max_ele = cc ; max_row = j ; }
        }
        if ( max_ele < tol ) return * this ;   // singular: T untouched

        if ( max_row != i )   // swap logical rows (no data is moved)
        {
            const unsigned int t = index[max_row] ;
            index[max_row] = index[i] ;
            index[i] = t ;
        }

        // scale the pivot row so the pivot becomes 1; the diagonal entry itself
        // is never read again, so it is not written back
        const unsigned int p = index[i] ;
        const double piv = U.data[ p * n + i ] ;
        for ( j=i+1 ; j<n ; j++ ) U.data[ p * n + j ] /= piv ;
        for ( j=0 ; j<n ; j++ ) A.data[ p * n + j ] /= piv ;

        // eliminate column i from the rows below the pivot
        for ( j=i+1 ; j<n ; j++ )
        {
            const unsigned int q = index[j] ;
            const double cc = U.data[ q * n + i ] ;
            for ( k=i+1 ; k<n ; k++ ) U.data[ q * n + k ] -= U.data[ p * n + k ] * cc ;
            for ( k=0 ; k<n ; k++ ) A.data[ q * n + k ] -= A.data[ p * n + k ] * cc ;
        }
    }

    for ( i=0 ; i<n ; i++ )   // back substitution on column i of A
    {
        // j runs from n-2 down to 0. "j-- > 0" avoids the unsigned wrap-around
        // that "j = n-2" caused for n == 1 (there is nothing to substitute then).
        for ( j=n-1 ; j-- > 0 ; )
        {
            const unsigned int q = index[j] ;
            for ( k=j+1 ; k<n ; k++ )
                A.data[ q * n + i ] -= U.data[ q * n + k ] * A.data[ index[k] * n + i ] ;
        }
    }

    for ( i=0 ; i<n ; i++ )   // T = A with the row permutation undone
        for ( j=0 ; j<n ; j++ )
            data[ i * n + j ] = A.data[ index[i] * n + j ] ;

    return * this ;
}
JMatrix & JMatrix::Zero()
{
    // T = 0 : overwrite every entry with zero, keeping the shape.
    for ( double * p = data + row * col ; p != data ; ) * --p = 0 ;
    return * this ;
}
//////////////////////////////////////////////////////////////////////
// Assignment from binary operations: T = A + B, A - B, A * B
JMatrix & JMatrix::EqAdd( const JMatrix & A, const JMatrix & B )
{
    // T = A + B, element-wise. T, A and B must all share the same shape;
    // otherwise T is left as it was. T may be the same object as A or B.
    const bool sameShape = row == A.row && col == A.col
                        && row == B.row && col == B.col ;
    if ( ! sameShape ) return * this ;
    const double * a = A.data ;
    const double * b = B.data ;
    double * const end = data + row * col ;
    for ( double * t = data ; t != end ; ++t, ++a, ++b ) * t = * a + * b ;
    return * this ;
}
JMatrix & JMatrix::EqSub( const JMatrix & A, const JMatrix & B )
{
    // T = A - B, element-wise. T, A and B must all share the same shape;
    // otherwise T is left as it was. T may be the same object as A or B.
    const bool sameShape = row == A.row && col == A.col
                        && row == B.row && col == B.col ;
    if ( ! sameShape ) return * this ;
    const double * a = A.data ;
    const double * b = B.data ;
    double * const end = data + row * col ;
    for ( double * t = data ; t != end ; ++t, ++a, ++b ) * t = * a - * b ;
    return * this ;
}
JMatrix & JMatrix::EqMul( const JMatrix & A, const JMatrix & B )
{
    // T = A * B (matrix product).
    // Requires T.row == A.row, T.col == B.col and A.col == B.row; otherwise T
    // is left unchanged.
    // T may be the same object as A and/or B (e.g. T.EqMul(T, T)). In that
    // case the product is built in a temporary first. Writing T entry by entry
    // would otherwise overwrite inputs that are still to be read.
    if ( row != A.row || col != B.col || A.col != B.row ) return * this ;

    if ( this == & A || this == & B )
    {
        JMatrix R( row, col ) ;
        R.EqMul( A, B ) ;   // R is a distinct object, so this call is not aliased
        const unsigned int N = row * col ;
        for ( unsigned int i=0 ; i<N ; i++ ) data[i] = R.data[i] ;
        return * this ;
    }

    unsigned int i, j, k ;
    for ( i=0 ; i<row ; i++ )
    {
        for ( j=0 ; j<col ; j++ )
        {
            double sum = 0 ;
            for ( k=0 ; k<A.col ; k++ )
                sum += A.data[ i * A.col + k ] * B.data[ k * B.col + j ] ;
            data[ i * col + j ] = sum ;
        }
    }
    return * this ;
}
//////////////////////////////////////////////////////////////////////
// Norms
double JMatrix::NormF() const
{
    // Frobenius norm: square root of the sum of squares of all entries.
    double sumsq = 0 ;
    const double * const stop = data + row * col ;
    for ( const double * p = data ; p != stop ; ++p )
    {
        const double x = * p ;
        sumsq += x * x ;
    }
    return sqrt( sumsq ) ;
}
double JMatrix::Norm1() const
{
    // 1-norm: the largest sum of absolute values taken over the columns.
    double best = 0 ;
    for ( unsigned int c = 0 ; c < col ; ++c )
    {
        double colsum = 0 ;
        for ( unsigned int r = 0 ; r < row ; ++r )
            colsum += fabs( data[ r * col + c ] ) ;
        best = ( best < colsum ) ? colsum : best ;
    }
    return best ;
}
double JMatrix::NormInf() const
{
double F=0, e ;
unsigned int i, j ;
for ( i=0 ; i<row ; i++ )
{
e = 0 ;
for ( j=0 ; j<col ; j++ ) e += fabs( data[ i * col + j ] ) ;
if ( F < e ) F = e ;
}
return F ;
}
//////////////////////////////////////////////////////////////////////
// solver for linear system Ax=b
int JMatrix::LinSol( JMatrix & B, double tol )
{
    // Solve T * X = B by Gauss elimination with column (partial) pivoting.
    // Each column of B is an independent right-hand side.
    //
    // B   : right-hand sides on entry; on success, overwritten with X.
    // tol : a pivot whose magnitude is below tol is treated as zero.
    //
    // Returns 1 on success. Returns 0, leaving B unchanged, if T is not square,
    // if B.row != T.row, or if T is found to be singular.
    //
    // NOTE: T is used as workspace. Once elimination starts (T square, B.row
    // matching), T is overwritten with its reduced, row-permuted form. This
    // happens even when 0 is returned for a singular T.
    if ( row != col || row != B.row ) return 0 ;

    const unsigned int n = row ;
    JMatrix B2( B ) ;                           // working copy of the right-hand sides
    std::vector< unsigned int > index( n ) ;    // index[i] = physical row holding logical row i
    unsigned int i, j, k ;
    for ( i=0 ; i<n ; i++ ) index[i] = i ;

    for ( i=0 ; i<n ; i++ )   // forward elimination, column by column
    {
        // pick the remaining row with the largest |entry| in column i
        unsigned int max_row = i ;
        double max_ele = fabs( data[ index[i] * n + i ] ) ;
        for ( j=i+1 ; j<n ; j++ )
        {
            const double cc = fabs( data[ index[j] * n + i ] ) ;
            if ( cc > max_ele ) { max_ele = cc ; max_row = j ; }
        }
        if ( max_ele < tol ) return 0 ;   // singular

        if ( max_row != i )   // swap logical rows (no data is moved)
        {
            const unsigned int t = index[max_row] ;
            index[max_row] = index[i] ;
            index[i] = t ;
        }

        // scale the pivot row so the pivot becomes 1; the diagonal entry itself
        // is never read again, so it is not written back
        const unsigned int p = index[i] ;
        const double piv = data[ p * n + i ] ;
        for ( j=i+1 ; j<n ; j++ ) data[ p * n + j ] /= piv ;
        for ( j=0 ; j<B2.col ; j++ ) B2.data[ p * B2.col + j ] /= piv ;

        // eliminate column i from the rows below the pivot
        for ( j=i+1 ; j<n ; j++ )
        {
            const unsigned int q = index[j] ;
            const double cc = data[ q * n + i ] ;
            for ( k=i+1 ; k<n ; k++ ) data[ q * n + k ] -= data[ p * n + k ] * cc ;
            for ( k=0 ; k<B2.col ; k++ )
                B2.data[ q * B2.col + k ] -= B2.data[ p * B2.col + k ] * cc ;
        }
    }

    for ( i=0 ; i<B2.col ; i++ )   // back substitution on column i of B2
    {
        // j runs from n-2 down to 0. "j-- > 0" avoids the unsigned wrap-around
        // that "j = row-2" caused for n == 1 (there is nothing to substitute then).
        for ( j=n-1 ; j-- > 0 ; )
        {
            const unsigned int q = index[j] ;
            for ( k=j+1 ; k<n ; k++ )
                B2.data[ q * B2.col + i ] -= data[ q * n + k ] * B2.data[ index[k] * B2.col + i ] ;
        }
    }

    for ( i=0 ; i<B.row ; i++ )   // B = B2 with the row permutation undone
        for ( j=0 ; j<B.col ; j++ )
            B.data[ i * B.col + j ] = B2.data[ index[i] * B2.col + j ] ;

    return 1 ;
}
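// (Web-viewer residue, not part of the source. Keyboard shortcuts:
//  copy code: Ctrl+C; search code: Ctrl+F; full screen: F11;
//  toggle theme: Ctrl+Shift+D; show shortcuts: ?;
//  increase font size: Ctrl+=; decrease font size: Ctrl+-)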