⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 wmllinearsystem.cpp

📁 Wild Math Library数值计算库
💻 CPP
📖 第 1 页 / 共 2 页
字号:
//----------------------------------------------------------------------------
template <class Real>
bool LinearSystem<Real>::SymmetricInverse (const GMatrix<Real>& rkA,
    GMatrix<Real>& rkInvA)
{
    // Invert the symmetric matrix A.  A is factored as A = L*D*L^T (L unit
    // lower triangular, D diagonal), and then A*x = e_k is solved for every
    // column e_k of the identity matrix.  This is the same algorithm as
    // SolveSymmetric, applied simultaneously to all identity columns.
    //
    // rkA    : symmetric input.  Only its diagonal and lower triangle are
    //          read.
    // rkInvA : output.  It must already be at least rkA.GetRows() square;
    //          it is indexed, not resized.
    // Returns false when some pivot of D has magnitude <= Math<Real>::EPSILON
    // (A singular or numerically too close to it).  rkInvA must not be used
    // in that case.
    int iSize = rkA.GetRows();
    GMatrix<Real> kTmp = rkA;

    int i0, i1;
    for (i0 = 0; i0 < iSize; i0++)
    {
        for (i1 = 0; i1 < iSize; i1++)
            rkInvA[i0][i1] = ( i0 != i1 ? (Real)0.0 : (Real)1.0 );
    }

    // Factor A = L*D*L^T in place.  kTmp's strictly lower triangle receives
    // L and its diagonal receives D.  While row i1 is processed, afV[k]
    // holds L[i1][k]*D[k] for k < i1, and afV[i1] holds the new pivot D[i1].
    Real* afV = new Real[iSize];
    assert( afV );

    for (i1 = 0; i1 < iSize; i1++)
    {
        for (i0 = 0; i0 < i1; i0++)
            afV[i0] = kTmp[i1][i0]*kTmp[i0][i0];

        afV[i1] = kTmp[i1][i1];
        for (i0 = 0; i0 < i1; i0++)
            afV[i1] -= kTmp[i1][i0]*afV[i0];

        // Reject a (near-)zero pivot before it is used as a divisor below.
        // Previously this was only detected after the whole factorization
        // had already divided by it.
        if ( Math<Real>::FAbs(afV[i1]) <= Math<Real>::EPSILON )
        {
            delete[] afV;
            return false;
        }

        kTmp[i1][i1] = afV[i1];
        for (i0 = i1+1; i0 < iSize; i0++)
        {
            for (int i2 = 0; i2 < i1; i2++)
                kTmp[i0][i1] -= kTmp[i0][i2]*afV[i2];
            kTmp[i0][i1] /= afV[i1];
        }
    }
    delete[] afV;

    // Every pivot is now known to be safely nonzero, so the per-column
    // solves below need no further checks.
    for (int iCol = 0; iCol < iSize; iCol++)
    {
        // forward substitution: solve L*y = e_iCol
        for (i0 = 0; i0 < iSize; i0++)
        {
            for (i1 = 0; i1 < i0; i1++)
                rkInvA[i0][iCol] -= kTmp[i0][i1]*rkInvA[i1][iCol];
        }

        // diagonal division: solve D*z = y
        for (i0 = 0; i0 < iSize; i0++)
            rkInvA[i0][iCol] /= kTmp[i0][i0];

        // back substitution: solve L^T*x = z (L^T read from the lower triangle)
        for (i0 = iSize-2; i0 >= 0; i0--)
        {
            for (i1 = i0+1; i1 < iSize; i1++)
                rkInvA[i0][iCol] -= kTmp[i1][i0]*rkInvA[i1][iCol];
        }
    }

    return true;
}
//----------------------------------------------------------------------------

//----------------------------------------------------------------------------
// conjugate gradient methods
//----------------------------------------------------------------------------
template <class Real>
Real LinearSystem<Real>::Dot (int iSize, const Real* afU, const Real* afV)
{
    // Euclidean inner product of the first iSize entries of afU and afV.
    // Products are accumulated from index 0 upward.  A non-positive iSize
    // yields 0.
    Real fSum = (Real)0.0;
    const Real* pfU = afU;
    const Real* pfV = afV;
    for (int iLeft = iSize; iLeft > 0; iLeft--)
        fSum += (*pfU++)*(*pfV++);
    return fSum;
}
//----------------------------------------------------------------------------
template <class Real>
void LinearSystem<Real>::Multiply (const GMatrix<Real>& rkA, const Real* afX,
    Real* afProd)
{
    // afProd = A*afX for a square A; rkA.GetRows() is used as both the row
    // and column count.  afProd is cleared first and then accumulated, so
    // it must not alias afX.
    const int iN = rkA.GetRows();
    memset(afProd,0,iN*sizeof(Real));

    for (int iR = 0; iR < iN; iR++)
    {
        Real& rfEntry = afProd[iR];
        for (int iC = 0; iC < iN; iC++)
            rfEntry += rkA[iR][iC]*afX[iC];
    }
}
//----------------------------------------------------------------------------
template <class Real>
void LinearSystem<Real>::Multiply (int iSize, const SparseMatrix& rkA,
    const Real* afX, Real* afProd)
{
    // afProd = A*afX for a symmetric iSize-by-iSize A.  A is stored as a
    // collection of ((row,column), value) entries (presumably a std::map;
    // see the SparseMatrix declaration).  Each stored off-diagonal entry
    // (i,j) also contributes as its mirror (j,i), so only one entry of each
    // symmetric pair should be stored.
    memset(afProd,0,iSize*sizeof(Real));

    typename SparseMatrix::const_iterator pkEnd = rkA.end();
    for (typename SparseMatrix::const_iterator pkEntry = rkA.begin();
        pkEntry != pkEnd; ++pkEntry)
    {
        const int iRow = pkEntry->first.first;
        const int iCol = pkEntry->first.second;
        const Real fA = pkEntry->second;

        afProd[iRow] += fA*afX[iCol];
        if ( iRow == iCol )
            continue;

        // implicit mirrored entry (iCol,iRow)
        afProd[iCol] += fA*afX[iRow];
    }
}
//----------------------------------------------------------------------------
template <class Real>
void LinearSystem<Real>::UpdateX (int iSize, Real* afX, Real fAlpha,
    const Real* afP)
{
    // Conjugate-gradient solution step: x <- x + alpha*p, applied to the
    // first iSize entries.
    Real* pfX = afX;
    const Real* pfP = afP;
    for (int iLeft = iSize; iLeft > 0; iLeft--, pfX++, pfP++)
        *pfX += fAlpha*(*pfP);
}
//----------------------------------------------------------------------------
template <class Real>
void LinearSystem<Real>::UpdateR (int iSize, Real* afR, Real fAlpha,
    const Real* afW)
{
    // Conjugate-gradient residual step: r <- r - alpha*w (w = A*p), applied
    // to the first iSize entries.
    Real* pfR = afR;
    const Real* pfW = afW;
    for (int iLeft = iSize; iLeft > 0; iLeft--, pfR++, pfW++)
        *pfR -= fAlpha*(*pfW);
}
//----------------------------------------------------------------------------
template <class Real>
void LinearSystem<Real>::UpdateP (int iSize, Real* afP, Real fBeta,
    const Real* afR)
{
    // Conjugate-gradient search-direction step: p <- r + beta*p, applied to
    // the first iSize entries.
    Real* pfP = afP;
    const Real* pfR = afR;
    for (int iLeft = iSize; iLeft > 0; iLeft--, pfP++, pfR++)
        *pfP = *pfR + fBeta*(*pfP);
}
//----------------------------------------------------------------------------
template <class Real>
bool LinearSystem<Real>::SolveSymmetricCG (const GMatrix<Real>& rkA,
    const Real* afB, Real* afX)
{
    // Solve A*x = b with the conjugate gradient method, based on the
    // algorithm in "Matrix Computations" by Golub and Van Loan.
    //
    // rkA : square matrix (asserted).  CG presumes A is symmetric positive
    //       definite; that is not checked here.
    // afB : right-hand side, rkA.GetRows() entries.
    // afX : output solution, rkA.GetRows() entries (overwritten).
    // Returns true once ||r|| <= ms_fTolerance*||b|| is observed within
    // 1024 iterations.  Otherwise returns false, and afX holds the last
    // iterate.
    assert( rkA.GetRows() == rkA.GetColumns() );
    int iSize = rkA.GetRows();

    memset(afX,0,iSize*sizeof(Real));

    // Initial residual r0 = b (x0 = 0), so rho0 = b.b.  When b = 0, x = 0 is
    // exact.  Without this early exit the first step computes alpha = 0/0
    // and fills afX with NaN.
    Real fRho0 = Dot(iSize,afB,afB);
    if ( fRho0 == (Real)0.0 )
        return true;

    // The stopping threshold ms_fTolerance*||b|| does not change between
    // iterations, so compute it once.
    const Real fThreshold = ms_fTolerance*Math<Real>::Sqrt(fRho0);

    Real* afR = new Real[iSize];
    Real* afP = new Real[iSize];
    Real* afW = new Real[iSize];

    // first iteration (p0 = r0 = b)
    memcpy(afR,afB,iSize*sizeof(Real));
    memcpy(afP,afR,iSize*sizeof(Real));
    Multiply(rkA,afP,afW);
    Real fAlpha = fRho0/Dot(iSize,afP,afW);
    UpdateX(iSize,afX,fAlpha,afP);
    UpdateR(iSize,afR,fAlpha,afW);
    Real fRho1 = Dot(iSize,afR,afR);

    // remaining iterations
    const int iMax = 1024;
    int i;
    for (i = 1; i < iMax; i++)
    {
        if ( Math<Real>::Sqrt(fRho1) <= fThreshold )
            break;

        Real fBeta = fRho1/fRho0;
        UpdateP(iSize,afP,fBeta,afR);
        Multiply(rkA,afP,afW);
        fAlpha = fRho1/Dot(iSize,afP,afW);
        UpdateX(iSize,afX,fAlpha,afP);
        UpdateR(iSize,afR,fAlpha,afW);
        fRho0 = fRho1;
        fRho1 = Dot(iSize,afR,afR);
    }

    delete[] afW;
    delete[] afP;
    delete[] afR;

    return i < iMax;
}
//----------------------------------------------------------------------------
template <class Real>
bool LinearSystem<Real>::SolveSymmetricCG (int iSize,
    const SparseMatrix& rkA, const Real* afB, Real* afX)
{
    // Solve A*x = b with the conjugate gradient method, based on the
    // algorithm in "Matrix Computations" by Golub and Van Loan.
    //
    // iSize : dimension of the system.
    // rkA   : symmetric matrix in the sparse form consumed by
    //         Multiply(int,const SparseMatrix&,...): each stored off-diagonal
    //         entry (i,j) also stands for (j,i).  CG presumes A is positive
    //         definite; that is not checked here.
    // afB   : right-hand side, iSize entries.
    // afX   : output solution, iSize entries (overwritten).
    // Returns true once ||r|| <= ms_fTolerance*||b|| is observed within
    // 1024 iterations.  Otherwise returns false, and afX holds the last
    // iterate.
    memset(afX,0,iSize*sizeof(Real));

    // Initial residual r0 = b (x0 = 0), so rho0 = b.b.  When b = 0, x = 0 is
    // exact.  Without this early exit the first step computes alpha = 0/0
    // and fills afX with NaN.
    Real fRho0 = Dot(iSize,afB,afB);
    if ( fRho0 == (Real)0.0 )
        return true;

    // The stopping threshold ms_fTolerance*||b|| does not change between
    // iterations, so compute it once.
    const Real fThreshold = ms_fTolerance*Math<Real>::Sqrt(fRho0);

    Real* afR = new Real[iSize];
    Real* afP = new Real[iSize];
    Real* afW = new Real[iSize];

    // first iteration (p0 = r0 = b)
    memcpy(afR,afB,iSize*sizeof(Real));
    memcpy(afP,afR,iSize*sizeof(Real));
    Multiply(iSize,rkA,afP,afW);
    Real fAlpha = fRho0/Dot(iSize,afP,afW);
    UpdateX(iSize,afX,fAlpha,afP);
    UpdateR(iSize,afR,fAlpha,afW);
    Real fRho1 = Dot(iSize,afR,afR);

    // remaining iterations
    const int iMax = 1024;
    int i;
    for (i = 1; i < iMax; i++)
    {
        if ( Math<Real>::Sqrt(fRho1) <= fThreshold )
            break;

        Real fBeta = fRho1/fRho0;
        UpdateP(iSize,afP,fBeta,afR);
        Multiply(iSize,rkA,afP,afW);
        fAlpha = fRho1/Dot(iSize,afP,afW);
        UpdateX(iSize,afX,fAlpha,afP);
        UpdateR(iSize,afR,fAlpha,afW);
        fRho0 = fRho1;
        fRho1 = Dot(iSize,afR,afR);
    }

    delete[] afW;
    delete[] afP;
    delete[] afR;

    return i < iMax;
}
//----------------------------------------------------------------------------

//----------------------------------------------------------------------------
// banded matrices
//----------------------------------------------------------------------------
template <class Real>
bool LinearSystem<Real>::ForwardEliminate (int iReduceRow,
    BandedMatrix<Real>& rkA, Real* afB)
{
    // One step of banded Gaussian elimination without pivoting:
    //   1. Scale row iReduceRow so its diagonal becomes exactly 1.
    //   2. Clear column iReduceRow in the rows below it that lie inside the
    //      lower band.
    // Each row operation is mirrored on the right-hand side afB.
    // Returns false, without modifying rkA or afB, if the pivot is exactly 0.
    const int iN = rkA.GetSize();
    const int iPivot = iReduceRow;

    const Real fDiag = rkA(iPivot,iPivot);
    if ( fDiag == (Real)0.0 )
        return false;

    const Real fScale = ((Real)1.0)/fDiag;
    rkA(iPivot,iPivot) = (Real)1.0;

    // Columns right of the pivot that the upper band allows to be nonzero.
    const int iFirst = iPivot + 1;
    int iEndCol = iFirst + rkA.GetUBands();
    if ( iEndCol > iN )
        iEndCol = iN;

    for (int iC = iFirst; iC < iEndCol; iC++)
        rkA(iPivot,iC) *= fScale;
    afB[iPivot] *= fScale;

    // Rows below the pivot that the lower band allows to be nonzero.
    int iEndRow = iFirst + rkA.GetLBands();
    if ( iEndRow > iN )
        iEndRow = iN;

    for (int iR = iFirst; iR < iEndRow; iR++)
    {
        const Real fFactor = rkA(iR,iPivot);
        rkA(iR,iPivot) = (Real)0.0;
        for (int iC = iFirst; iC < iEndCol; iC++)
            rkA(iR,iC) -= fFactor*rkA(iPivot,iC);
        afB[iR] -= fFactor*afB[iPivot];
    }

    return true;
}
//----------------------------------------------------------------------------
template <class Real>
bool LinearSystem<Real>::SolveBanded (const BandedMatrix<Real>& rkA,
    const Real* afB, Real* afX)
{
    // Solve A*x = b for a banded A by Gaussian elimination without pivoting,
    // working on a private copy of A.  afX receives the solution.  On a zero
    // pivot, false is returned and afX holds a partially reduced copy of b.
    BandedMatrix<Real> kWork = rkA;
    const int iN = rkA.GetSize();
    memcpy(afX,afB,iN*sizeof(Real));

    // Reduce kWork to unit upper-triangular form, carrying afX along.
    for (int iPivot = 0; iPivot < iN; iPivot++)
    {
        if ( !ForwardEliminate(iPivot,kWork,afX) )
            return false;
    }

    // Back substitution.  The last row needs no work: its diagonal is 1 and
    // nothing lies to its right.
    const int iUpper = kWork.GetUBands();
    for (int iR = iN - 2; iR >= 0; iR--)
    {
        int iEnd = iR + 1 + iUpper;
        if ( iEnd > iN )
            iEnd = iN;
        for (int iC = iR + 1; iC < iEnd; iC++)
            afX[iR] -= kWork(iR,iC)*afX[iC];
    }

    return true;
}
//----------------------------------------------------------------------------
template <class Real>
bool LinearSystem<Real>::ForwardEliminate (int iReduceRow,
    BandedMatrix<Real>& rkA, GMatrix<Real>& rkB)
{
    // Matrix right-hand-side variant of the banded forward-elimination step.
    // It normalizes row iReduceRow of rkA to a unit diagonal and clears the
    // pivot column below it within the lower band, mirroring each row
    // operation on rkB.
    //
    // Only columns 0..iReduceRow of rkB are updated.  That suffices when
    // rkB started as the identity (as Invert arranges): forward elimination
    // then never creates entries right of rkB's diagonal, so the skipped
    // columns of the pivot row are zero.
    //
    // Returns false, without modifying rkA or rkB, if the pivot is exactly 0.
    const int iN = rkA.GetSize();
    const int iPivot = iReduceRow;

    const Real fDiag = rkA(iPivot,iPivot);
    if ( fDiag == (Real)0.0 )
        return false;

    const Real fScale = ((Real)1.0)/fDiag;
    rkA(iPivot,iPivot) = (Real)1.0;

    // Column range allowed by the upper band; row range allowed by the lower.
    const int iFirst = iPivot + 1;
    int iEndCol = iFirst + rkA.GetUBands();
    if ( iEndCol > iN )
        iEndCol = iN;
    int iEndRow = iFirst + rkA.GetLBands();
    if ( iEndRow > iN )
        iEndRow = iN;

    int iC;
    for (iC = iFirst; iC < iEndCol; iC++)
        rkA(iPivot,iC) *= fScale;
    for (iC = 0; iC <= iPivot; iC++)
        rkB(iPivot,iC) *= fScale;

    for (int iR = iFirst; iR < iEndRow; iR++)
    {
        const Real fFactor = rkA(iR,iPivot);
        rkA(iR,iPivot) = (Real)0.0;
        for (iC = iFirst; iC < iEndCol; iC++)
            rkA(iR,iC) -= fFactor*rkA(iPivot,iC);
        for (iC = 0; iC <= iPivot; iC++)
            rkB(iR,iC) -= fFactor*rkB(iPivot,iC);
    }

    return true;
}
//----------------------------------------------------------------------------
template <class Real>
void LinearSystem<Real>::BackwardEliminate (int iReduceRow,
    BandedMatrix<Real>& rkA, GMatrix<Real>& rkB)
{
    // Clear column iReduceRow of rkA above the diagonal, within the upper
    // band.  For each such row, that multiple of row iReduceRow of rkB is
    // subtracted from the row.
    // Only the single rkA entry is zeroed; the rest of that rkA row is not
    // updated.  This is correct only when row iReduceRow of rkA is already
    // reduced to a lone unit diagonal, as the forward pass followed by the
    // higher-indexed backward steps in Invert leaves it.
    int iTop = iReduceRow - rkA.GetUBands();
    if ( iTop < 0 )
        iTop = 0;

    const int iCols = rkB.GetColumns();
    for (int iR = iReduceRow - 1; iR >= iTop; iR--)
    {
        const Real fFactor = rkA(iR,iReduceRow);
        rkA(iR,iReduceRow) = (Real)0.0;
        for (int iC = 0; iC < iCols; iC++)
            rkB(iR,iC) -= fFactor*rkB(iReduceRow,iC);
    }
}
//----------------------------------------------------------------------------
template <class Real>
bool LinearSystem<Real>::Invert (const BandedMatrix<Real>& rkA,
    GMatrix<Real>& rkInvA)
{
    // Invert a banded matrix by Gauss-Jordan elimination without pivoting.
    // Every row operation applied to a working copy of A is also applied to
    // rkInvA, which starts as the identity.  rkInvA must already be at least
    // rkA.GetSize() square; it is written, not resized.
    // Returns false on an exactly-zero pivot; rkInvA is then only partially
    // reduced.
    const int iN = rkA.GetSize();
    BandedMatrix<Real> kWork = rkA;

    for (int iR = 0; iR < iN; iR++)
    {
        for (int iC = 0; iC < iN; iC++)
            rkInvA(iR,iC) = ( iR == iC ? (Real)1.0 : (Real)0.0 );
    }

    // Forward pass: reduce kWork to unit upper-triangular form.
    for (int iPivot = 0; iPivot < iN; iPivot++)
    {
        if ( !ForwardEliminate(iPivot,kWork,rkInvA) )
            return false;
    }

    // Backward pass: clear above-diagonal entries, last column first.
    for (int iPivot = iN - 1; iPivot > 0; iPivot--)
        BackwardEliminate(iPivot,kWork,rkInvA);

    return true;
}
//----------------------------------------------------------------------------

//----------------------------------------------------------------------------
// explicit instantiation
//----------------------------------------------------------------------------
namespace Wml
{
// Explicitly instantiate LinearSystem for float and double.  For each type,
// this also defines ms_fTolerance (1e-06), the relative residual tolerance
// used as the stopping criterion by SolveSymmetricCG.
//
// Both branches contain the same declarations; they differ only in order.
// Standard C++ requires an explicit specialization to be declared before
// any use that would implicitly instantiate it.  The
// WML_INSTANTIATE_BEFORE branch follows that order (specialization first).
// The #else order presumably accommodates compilers that rejected the
// standard ordering.
// NOTE(review): confirm which toolchains define WML_INSTANTIATE_BEFORE.
#ifdef WML_INSTANTIATE_BEFORE
template<> float LinearSystem<float>::ms_fTolerance = 1e-06f;
template class WML_ITEM LinearSystem<float>;

template<> double LinearSystem<double>::ms_fTolerance = 1e-06;
template class WML_ITEM LinearSystem<double>;

#else
template class WML_ITEM LinearSystem<float>;
template<> float LinearSystem<float>::ms_fTolerance = 1e-06f;

template class WML_ITEM LinearSystem<double>;
template<> double LinearSystem<double>::ms_fTolerance = 1e-06;
#endif
}
//----------------------------------------------------------------------------

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -