📄 tridiagonal.c
字号:
/* the i is the cross-shore index for use with v2i and v2i _[up,down] */void tridiagonal_solve_v(double *ad, double *adl, double *adu, double *b, double *x, double *work, double *v2i, double *v3i, int n){ int j; double AA11, AA12, AA21, AA22, idetAA; /* the terms in the 2x2 matrix and it's determinant */ double iAA11, iAA12, iAA21, iAA22; /* the inverse of the 2x2 matrix */ double rhs1, rhs2; /* Solve the T*xx = bb equation */ /* assumes that v2i v3i etc have already been calculated */ tridiagonal_solve_for_v_work(ad+1,adl+1,adu+1,b+1, x+1,work, n-2); AA11 = ad[0] - adu[0]*v2i[0]; //_up[i-1]; /* lines 77-80 of periodic_tridiag.m */ AA12 = adl[0] - adu[0]*v3i[0]; //_up[i-1]; AA21 = adu[n-1] - adl[n-1]*v2i[n-3]; //_dn[i-1]; AA22 = ad[n-1] - adl[n-1]*v3i[n-3]; //_dn[i-1]; idetAA = 1.0/(AA11*AA22 - AA12*AA21); /* Inverse of AA is = (1/det) [ A22 A21; A12 A11 ] */ iAA11 = idetAA * AA22; iAA12 = -idetAA * AA12; iAA21 = -idetAA * AA21; iAA22 = idetAA * AA11; rhs1 = b[0] - adu[0]*x[1]; /* lines 82-83 of periodic_tridiag.m */ rhs2 = b[n-1] - adl[n-1]*x[n-2]; x[0] = iAA11*rhs1 + iAA12*rhs2; /* line 87-89 of periodic tridiag.m */ x[n-1] = iAA21*rhs1 + iAA22*rhs2; for (j=1;j<n-1;j++) { x[j] += -(x[0] * v2i[j-1] + x[n-1]* v3i[j-1]); }}/* Sets up the the RHS of the U tridiagonal equation, this is essentially doing a matrix transpose *//* perhaps put this function in field.[ch] ?? */void transpose_U_field2D(field2D *bU, const field2D *RHS_U){ const int M = RHS_U->M; const int N = RHS_U->N; int BM = bU->M; int BN = bU->N; const double *rhs_ud; double *bud = bU->data; int i,k; /* First check to make sure that BM == N and BN ==M */ if ( (BM!=N) || (BN!=M) ) { /* if the sizes aren't transposable then a problem! */ funwaveC_error_message("** tridiagonal.c: bad sizes in transpose_U_field2D()"); } for (k=0;k<M;k++) { rhs_ud = &(RHS_U->data[k]); for (i=0;i<BM;i++) { *bud++ = *rhs_ud; rhs_ud += M; } }}void tridiagonal_solve_UV(field2D *U, field2D *V, const field2D *RHS_U, const field2D *RHS_V){ int i; int NV = V->N; int MV = V->M; int M = V->M; int MU = U->M; int NU = U->N - 2; /* probably not right needs to be U->N -2 */ double *vd, *utmpd; /* pointers to V and U^T data */ double *veq_ad, *veq_adl, *veq_adu, *rhs_vd; /* pointers to V eq upper diag, lower diag, diag, and rhs */ double *ueq_ad, *ueq_adl, *ueq_adu, *bud; /* pointers to V eq upper diag, lower diag, diag, and rhs */ double *v2i, *v3i; double *veq_adl2; transpose_U_field2D(bU, RHS_U); /* take the transpose of the matrix */ /* Now loop over the U tridiagonal solves */ for (i=0;i<MU;i++) { bud = REF2(bU, i) ; // &(bU->data[i*NU]); ueq_ad = REF2(AD_U, i); // & (AD_U->data[i*NU]); ueq_adl = REF2(ADL_U, i); // &(ADL_U->data[i*NU]); ueq_adu = REF2(ADU_U, i); // &(ADU_U->data[i*NU]); utmpd = REF2(UTMP, i); // &(UTMP->data[i*NU]); tridiagonal_solve_u(ueq_ad, ueq_adl, ueq_adu, bud, utmpd , worku, NU); // not periodic but w/ u=0 bc // assert( check_u_tridiagonal_solution(ueq_ad, ueq_adl, ueq_adu, utmpd, bud, NU) ); } /* Next transpose UTMP -> U */ transpose_U_field2D(U, UTMP); /* U = transpose(UTMP) */ /* What about boundary conditions??? */ /* Loop through the various V tridiagonal solvers */ // field2D_copy(V, RHS_V); // double *vd2 = (double *) g_malloc(sizeof(double)*M); for (i=1;i<NV-1;i++) { vd = &(V->data[i*MV]); rhs_vd = REF2(RHS_V, i) ; // &(RHS_V->data[i*MV]); veq_ad = REF2(AD_V, i ) ; //&(AD_V->data[i*MV]); veq_adl = REF2(ADL_V, i ); //&(ADL_V->data[i*MV]); veq_adl2 = REF2(ADL2_V, i ); //&(ADL_V->data[i*MV]); veq_adu = REF2(ADU_V, i ); // &(ADU_V->data[i*MV]); v2i = REF2(V2i, i-1); v3i = REF2(V3i, i-1); tridiagonal_solve_v(veq_ad, veq_adl, veq_adu, rhs_vd, vd , workv, v2i, v3i, M); // remember this is periodic // solve_cyc_tridiag_nonsym(veq_ad, veq_adu, veq_adl2, rhs_vd, vd, M); //assert( check_v_tridiagonal_solution(veq_ad, veq_adl, veq_adu, vd, rhs_vd, M) ); //assert( check_v_vtmp_integral(vd, rhs_vd, M, i) ); } field2D_apply_nogradient_bc(V); /* Apply dV/dx = 0 at boundary: i=0 and i=N */ // g_free(vd2);}void diff_two_vectors(double *a, double *b, int num){ int i; double e; double max = 0.0; int imax = -1; for (i=0;i<num;i++) { e = fabs( a[i]-b[i] ); if (e > max) { max = e; imax = i; } } fprintf(stderr," ** max difference = %e at i=%d\n", max, imax);}/* solve following system w/o the corner elements and then use * Sherman-Morrison formula to compensate for them * * diag[0] abovediag[0] 0 ..... belowdiag[N-1] * belowdiag[0] diag[1] abovediag[1] ..... * 0 belowdiag[1] diag[2] * 0 0 belowdiag[2] ..... * ... ... * abovediag[N-1] ... *///int solve_cyc_tridiag_nonsym(const double diag[], size_t d_stride,// const double abovediag[], size_t a_stride,// const double belowdiag[], size_t b_stride,// const double rhs[], size_t r_stride,// double x[], size_t x_stride, size_t N)int solve_cyc_tridiag_nonsym(const double diag[], const double abovediag[], const double belowdiag[], const double rhs[], double x[], int N){ int d_stride = 1; int a_stride = 1; int b_stride = 1; int r_stride = 1; int x_stride = 1; double *alpha = (double *) g_malloc (N * sizeof (double)); double *zb = (double *) g_malloc (N * sizeof (double)); double *zu = (double *) g_malloc (N * sizeof (double)); double *w = (double *) g_malloc (N * sizeof (double)); double beta; int i; /* Bidiagonalization (eliminating belowdiag) & rhs update diag' = alpha rhs' = zb rhs' for Aq=u is zu */ zb[0] = rhs[0]; if (diag[0] != 0) beta = -diag[0]; else beta = 1; { const double q = 1 - abovediag[0]*belowdiag[0]/(diag[0]*diag[d_stride]); if (fabs(q/beta) > 0.5 && fabs(q/beta) < 2) { beta *= (fabs(q/beta) < 1) ? 0.5 : 2; } } zu[0] = beta; alpha[0] = diag[0] - beta; for (i = 1; i+1 < N; i++) { const double t = belowdiag[b_stride*(i - 1)]/alpha[i-1]; alpha[i] = diag[d_stride*i] - t*abovediag[a_stride*(i - 1)]; zb[i] = rhs[r_stride*i] - t*zb[i-1]; zu[i] = -t*zu[i-1]; /* FIXME!!! */ if (alpha[i] == 0) { funwaveC_error_message("** status = GSL_EZERODIV;"); } } i = N-1; const double t = belowdiag[b_stride*(i - 1)]/alpha[i-1]; alpha[i] = diag[d_stride*i] - abovediag[a_stride*i]*belowdiag[b_stride*i]/beta - t*abovediag[a_stride*(i - 1)]; zb[i] = rhs[r_stride*i] - t*zb[i-1]; zu[i] = abovediag[a_stride*i] - t*zu[i-1]; /* FIXME!!! */ if (alpha[i] == 0) { funwaveC_error_message("** status = GSL_EZERODIV;"); } /* backsubstitution */ { int i, j; w[N-1] = zu[N-1]/alpha[N-1]; x[N-1] = zb[N-1]/alpha[N-1]; for (i = N - 2, j = 0; j <= N - 2; j++, i--) { w[i] = (zu[i] - abovediag[a_stride*i] * w[i+1])/alpha[i]; x[i*x_stride] = (zb[i] - abovediag[a_stride*i] * x[x_stride*(i + 1)])/alpha[i]; } } /* Sherman-Morrison */ { const double vw = w[0] + belowdiag[b_stride*(N - 1)]/beta * w[N-1]; const double vx = x[0] + belowdiag[b_stride*(N - 1)]/beta * x[x_stride*(N - 1)]; /* FIXME!!! */ if (vw + 1 == 0) { funwaveC_error_message("** status = GSL_EZERODIV;"); } { int i; for (i = 0; i < N; i++) x[i] -= vx/(1 + vw)*w[i]; } } g_free (zb); g_free (zu); g_free (w); g_free (alpha);}void make_ADL2_from_ADL(field2D *ADL2_V, const field2D *ADL){ int N = ADL->N; int M = ADL->M; int i,j; double tmp; for (i=0;i<N;i++) { tmp = DR2(ADL,i,0); for (j=1;j<M;j++) { DR2(ADL2_V,i,j-1) = DR2(ADL_V,i,j); } DR2(ADL2_V,i,M-1) = tmp; }}#define IPP(j,M) (((j)==(M)-1) ? 0 : (j)+1)#define IMM(j,M) (((j)==0) ? (M)-1 : (j)-1)int check_v_tridiagonal_solution(double *veq_ad, double *veq_adl, double *veq_adu, double *vd, double *rhs_vd, int M){ int i, ip, im; double epsilon = 0.0; double vepsilon = 0.0; double flimit = 1e-10; double e; for (i=0;i<M;i++) { ip = IPP(i,M); im = IMM(i,M); e = (veq_ad[i]*vd[i] + veq_adl[i]*vd[im] + veq_adu[i]*vd[ip]) - rhs_vd[i]; epsilon += e*e; vepsilon += vd[i]*vd[i]; } epsilon /= M; epsilon = sqrt(epsilon); vepsilon /= M; vepsilon = sqrt(vepsilon); if (vepsilon < 1e-22) flimit = 1e-8; if (vepsilon < 1e-30) flimit = 1e-4; if (vepsilon < 1e-50) flimit = 1e-3; if ( epsilon/vepsilon > flimit ) { fprintf(stderr,"** error: sqrt((A*v-b)^2) = epsilon = %e, sqrt(v*v) = %e \n",epsilon,vepsilon); return 0; } return 1;}int check_u_tridiagonal_solution(double *ueq_ad, double *ueq_adl, double *ueq_adu, double *ud, double *bud, int N){ int i, ip, im; double epsilon = 0.0; double uepsilon = 0.0; double flimit = 1e-10; double e; for (i=0;i<N;i++) { ip = i+1; im = i-1; e = (ueq_ad[i]*ud[i-1] + ueq_adl[i-1]*ud[i-2] + ueq_adu[i-1]*ud[i]) - bud[i]; epsilon += e*e; uepsilon += ud[i]*ud[i]; } epsilon /= N; epsilon = sqrt(epsilon); uepsilon /= N; uepsilon = sqrt(uepsilon); if ( epsilon/uepsilon > flimit ) { fprintf(stderr,"** error: sqrt((A*u-b)^2) = epsilon = %e, sqrt(u*u) = %e ",epsilon,uepsilon); return 0; } return 1;}int check_v_vtmp_integral(double *vd, double *vtmpd, int M, int index){ int i; double iv =0.0; double ivtmp = 0.0; double flimit = 1e-12; for (i=0;i<M;i++) { iv += *vd++; ivtmp += *vtmpd++; } if ( fabs(iv-ivtmp) > flimit ) { fprintf(stderr,"** error: at i=%d, int[v dy] = %.10e, int[ vtmp dy] = %.10e, diff=%e, ratio = %e \n", index, iv,ivtmp, fabs(iv-ivtmp), fabs(iv-ivtmp)/fabs(iv)) ; return 1; } return 1;}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -