// File: sparsemat.cpp
// (code-viewer toolbar label: "Font size:")
#include <cmath>#include "sparsemat.h"using namespace std;// Text of error messages, used by Matrix::ReportError// Do not change this!char *SparseMatrix::ErrorMessages[] = { "", // ERROR_INVALID_SIZE "dimension mismatch", // ERROR_SIZE_MISMATCH "index out of range", // ERROR_INVALID_INDEX} ;void SparseMatrix::ReportError( ErrorCode code ){ string prefix( "Error: " ); throw std::runtime_error( prefix + ErrorMessages[code] );}SparseMatrix::SparseMatrix( int rows, int cols ){ if ( cols == 0 ) cols = rows; m_rows = rows; m_cols = cols;}SparseMatrix::SparseMatrix( const SparseMatrix& A ){ m_rows = A.m_rows; m_cols = A.m_cols; m_data = A.m_data; m_row_indices = A.m_row_indices; m_col_indices = A.m_col_indices;}SparseMatrix::~SparseMatrix(){}void SparseMatrix::SetSize( int rows, int cols ){ m_rows = rows; m_cols = cols; m_row_indices.clear(); m_col_indices.clear(); m_data.clear();}ostream& operator<<( ostream& os, const SparseMatrix& A ){ for ( int i = 0; i < A.get_rows(); i++ ) { for ( int j = 0; j < A.get_cols(); j++ ) { double x = A( i, j ); os.width(10); os << x; } cout << endl; } return os;}double& SparseMatrix::operator()( int i, int j ){ if ( i >= m_rows || j >= m_cols ) ReportError( ERROR_INVALID_INDEX ); m_row_indices[i].insert( j ); m_col_indices[j].insert( i ); return m_data[std::make_pair(i, j)];}double SparseMatrix::operator()( int i, int j ) const{ return Get(i, j);}void SparseMatrix::Squeeze( double tol ){ for ( iset_iter i = m_row_indices.begin(); i != m_row_indices.end(); i++ ) { for ( elt_citer j = i->second.begin(); j != i->second.end(); j++ ) { if ( fabs( m_data[std::make_pair(i->first, *j)] ) <= tol ) { i->second.erase( *j ); if ( i->second.size() == 0 ) m_row_indices.erase( i ); m_col_indices[*j].erase( i->first ); if ( m_col_indices[*j].size() == 0 ) m_col_indices.erase( *j ); } } }}int SparseMatrix::nnz() const{ int total = 0; for ( iset_citer i = m_row_indices.begin(); i != m_row_indices.end(); i++ ) { total += i->second.size(); } return 
total;}elt_citer SparseMatrix::get_row_begin( int i ) const{ return m_row_indices.find( i )->second.begin();}elt_citer SparseMatrix::get_row_end( int i ) const{ return m_row_indices.find( i )->second.end();}elt_citer SparseMatrix::get_col_begin( int i ) const{ return m_col_indices.find( i )->second.begin();}elt_citer SparseMatrix::get_col_end( int i ) const{ return m_col_indices.find( i )->second.end();}SparseMatrix SparseMatrix::operator*( const SparseMatrix& B ) const{ if ( m_cols != B.m_rows ) ReportError( ERROR_SIZE_MISMATCH ); SparseMatrix C( m_rows, B.m_cols ); for ( int i = 0; i < m_rows; i++ ) { for ( int j = 0; j < m_cols; j++ ) { double elt = 0.0; elt_citer irow = get_row_begin( i ); elt_citer jcol = B.get_col_begin( j ); for ( ; irow != get_row_end( i ); irow++ ) { elt += Get( i, *irow ) * B( *irow, j ); } if ( elt != 0.0 ) C( i, j ) = elt; } } return C;}SparseMatrix& SparseMatrix::operator=( const SparseMatrix& A ){ m_rows = A.m_rows; m_cols = A.m_cols; m_data = A.m_data; m_row_indices = A.m_row_indices; m_col_indices = A.m_col_indices; return *this;}SparseMatrix& SparseMatrix::operator+=( const SparseMatrix& A ){ if ( m_rows != A.m_rows || m_cols != A.m_cols ) ReportError( ERROR_SIZE_MISMATCH ); for ( iset_citer i = A.m_row_indices.begin(); i != A.m_row_indices.end(); i++ ) { for ( elt_citer j = i->second.begin(); j != i->second.end(); j++ ) { double elt = A( i->first, *j ); Set( i->first, *j, Get( i->first, *j ) + elt ); } } return *this;}SparseMatrix& SparseMatrix::operator-=( const SparseMatrix& A ){ *this += -A; return *this;}SparseMatrix& SparseMatrix::operator*=( double s ){ for ( mat_iter i = m_data.begin(); i != m_data.end(); i++ ) i->second *= s; return *this;}SparseMatrix& SparseMatrix::operator/=( double s ){ *this *= (1.0 / s); return *this;}SparseMatrix operator*( double s, const SparseMatrix& A ){ return A * s;}SparseMatrix SparseMatrix::operator+( const SparseMatrix& B ) const{ if ( m_rows != B.m_rows || m_cols != B.m_cols ) ReportError( 
ERROR_SIZE_MISMATCH ); SparseMatrix C( *this ); for ( iset_citer i = B.m_row_indices.begin(); i != B.m_row_indices.end(); i++ ) { for ( elt_citer j = i->second.begin(); j != i->second.end(); j++ ) { double elt = B( i->first, *j ); C( i->first, *j ) += elt; } } return C;}double SparseMatrix::Get( int i, int j ) const{ if ( i >= m_rows || j >= m_cols ) ReportError( ERROR_INVALID_INDEX ); mat_citer iter = m_data.find( make_pair( i, j ) ); if ( iter == m_data.end() ) return 0.0; return iter->second;}void SparseMatrix::Set( int i, int j, double x ){ m_data[make_pair(i,j)] = x; m_row_indices[i].insert( j ); m_col_indices[j].insert( i );}void SparseMatrix::Identity( int rows, int cols ){ if ( cols == 0 ) cols = rows; SetSize( rows, cols ); for ( int i = 0; i < rows && i < cols; i++ ) Set( i, i, 1.0 );}SparseMatrix SparseMatrix::Transpose() const{ SparseMatrix T( m_cols, m_rows ); for ( mat_citer i = m_data.begin(); i != m_data.end(); i++ ) { T( i->first.second, i->first.first ) = Get( i->first.first, i->first.second ); } return T;}SparseMatrix SparseMatrix::TriU( int diag ) const{ SparseMatrix S( m_rows, m_cols ); for ( iset_citer i = m_row_indices.begin(); i != m_row_indices.end(); i++ ) { for ( elt_citer j = i->second.begin(); j != i->second.end(); j++ ) { if ( i->first <= *j - diag ) S( i->first, *j ) = Get( i->first, *j ); } } return S;}SparseMatrix SparseMatrix::TriL( int diag ) const{ SparseMatrix S( m_rows, m_cols ); for ( iset_citer i = m_row_indices.begin(); i != m_row_indices.end(); i++ ) { for ( elt_citer j = i->second.begin(); j != i->second.end(); j++ ) { if ( i->first >= *j - diag ) S( i->first, *j ) = Get( i->first, *j ); } } return S;}SparseMatrix SparseMatrix::operator-() const{ SparseMatrix C( m_rows, m_cols ); for ( iset_citer i = m_row_indices.begin(); i != m_row_indices.end(); i++ ) { for ( elt_citer j = i->second.begin(); j != i->second.end(); j++ ) { double elt = Get( i->first, *j ); C( i->first, *j ) = -elt; } } return C;}SparseMatrix 
SparseMatrix::operator*( double s ) const{ SparseMatrix C( m_rows, m_cols ); for ( iset_citer i = m_row_indices.begin(); i != m_row_indices.end(); i++ ) { for ( elt_citer j = i->second.begin(); j != i->second.end(); j++ ) { double elt = Get( i->first, *j ); C( i->first, *j ) = elt * s; } } return C;}SparseMatrix SparseMatrix::operator-( const SparseMatrix& B ) const{ return *this + (-B);}
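// Keyboard shortcuts (residue from the code-viewer page):
//   Copy code            Ctrl + C
//   Search code          Ctrl + F
//   Full-screen mode     F11
//   Toggle theme         Ctrl + Shift + D
//   Show shortcuts       ?
//   Increase font size   Ctrl + =
//   Decrease font size   Ctrl + -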