📄 linsys.cpp
字号:
// ENERGY211/CME211//// linalg.cpp - Application file for Project 2//// This file includes definitions for the functions Inverse,// SolveAxb and helper functions//#include <stdexcept>#include <iostream>#include <cmath>#include "linsys.h"using namespace std;// Text of error messages, used by Matrix::ReportError// Do not change this!char *LinearSystem::ErrorMessages[] = { "matrix singular", // ERROR_MATRIX_SINGULAR "matrix not square", // ERROR_MATRIX_NOT_SQUARE "size mismatch", // ERROR_SIZE_MISMATCH} ;// This is used to throw exceptions from functions defined// in this file. The first argument is the name of the// function in which the exception is to be thrown, and// the second argument is the reason for throwing the// exception. Use the error messages given in the handout.void LinearSystem::ReportError( ErrorCode code ){ string prefix( "Error: " ); throw std::runtime_error( prefix + ErrorMessages[code] );}// Compute the LU factorization of the matrix A. The// first argument A is overwritten with the upper// triangular factor U. The second argument is// overwritten with a unit lower triangular matrix// containing the multipliers used during Gaussian// elimination.//// An error must be reported if A turns out to be// singular (as detected by a zero diagonal element)// or if A is not a square matrix.void LinearSystem::GaussElim(){ if ( m_Factored == true ) return; int m = m_A.get_rows(); int n = m_A.get_cols(); if ( m != n ) ReportError( ERROR_MATRIX_NOT_SQUARE ); m_U = m_A; m_L.Identity( m, n ); for ( int j = 0; j < n - 1; j++ ) { if ( m_U[j][j] == 0.0 ) ReportError( ERROR_MATRIX_SINGULAR ); for ( int i = j + 1; i < m; i++ ) { // A[i][j] <- 0 // update other elements of ith row double mult = m_U[i][j] / m_U[j][j]; m_L[i][j] = mult; m_U[i][j] = 0.0; for ( int k = j + 1; k < n; k++ ) m_U[i][k] -= mult * m_U[j][k]; } } m_Factored = true;}// Use forward subsitution to solve the system Ly = b, where// L is a unit lower triangular matrix. The solution y is// to be returned.//// An error must be reported if L is not square.ColVector LinearSystem::ForwardSub( const Matrix& L, const ColVector& b ){ int m = L.get_rows(); int n = L.get_cols(); if ( m != n ) ReportError( ERROR_MATRIX_NOT_SQUARE ); ColVector y( b ); for ( int i = 0; i < m; i++ ) { for ( int j = 0; j < i; j++ ) y[i] -= L[i][j] * y[j]; } return y;}// Use back substitution to solve the system Ux = y, where// U is an upper triangular matrix.//// An error must be reported if U is not a square matrix, or// if U is singular (which occurs if U has a zero diagonal// element).ColVector LinearSystem::BackSub( const Matrix& U, const ColVector& y ){ int m = U.get_rows(); int n = U.get_cols(); if ( m != n ) ReportError( ERROR_MATRIX_NOT_SQUARE ); ColVector x( y ); for ( int i = m - 1; i >= 0; i-- ) { if ( U[i][i] == 0.0 ) ReportError( ERROR_MATRIX_SINGULAR ); for ( int j = i + 1; j < n; j++ ) x[i] -= U[i][j] * x[j]; x[i] /= U[i][i]; } return x;}// Solve the system Ax = b as follows:// 1. Factor A = LU using Gaussian elimination// 2. Solve Ly = b using forward substitution// 3. Solve Ux = y using back substitutionColVector LinearSystem::DirectSolve( const ColVector& b ){ int m = m_A.get_rows(); int n = m_A.get_cols(); if ( m != n ) ReportError( ERROR_MATRIX_NOT_SQUARE ); if ( m != b.get_rows() ) ReportError( ERROR_SIZE_MISMATCH ); // Gaussian elimination GaussElim(); ColVector y = ForwardSub( m_L, b ); ColVector x = BackSub( m_U, y ); return x;}ColVector LinearSystem::IterativeSolve( const ColVector& b, LinearSystem::IterMethod m, double w ){ switch( m ) { case MethodJacobi: return Jacobi( b ); case MethodGaussSeidel: return GaussSeidel( b ); case MethodSOR: return SOR( b, w ); default: throw exception(); }}ColVector LinearSystem::Jacobi( const ColVector& b ){ ColVector Md = m_A.GetDiag(); Matrix M; M.SetDiag( Md ); Matrix N = M - m_A; return StationarySolve( M, N, b );}ColVector LinearSystem::StationarySolve( const Matrix& M, const Matrix& N, const ColVector& b ){ ColVector x; x.Zeros( M.get_cols(), 1 ); int niter = 0; do { ColVector oldx = x; ColVector r = N * x + b; LinearSystem S( M ); x = S.DirectSolve( r ); niter++; double err = ColVector::norm( x - oldx ); if ( err < m_tol || niter >= m_maxiter ) break; } while ( true ); return x;}ColVector LinearSystem::GaussSeidel( const ColVector& b ){ Matrix M = m_A.TriL(); Matrix N = M - m_A; return StationarySolve( M, N, b );}ColVector LinearSystem::SOR( const ColVector& b, double w ){ // (1/w)D + L + U - (1/w-1)D ColVector Md = m_A.GetDiag(); Matrix M; M.SetDiag( Md ); M /= w; Matrix L = m_A.TriL( -1 ); M += L; Matrix N = M - m_A; return StationarySolve( M, N, b );}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -