📄 svdcmp.c
字号:
int col, j, k ; double temp, ww, sum ;/* Right*/ ww = 0.0 ; // Insures failure of upcoming if first time for (col=cols-1 ; col>=0 ; col--) { if (ww != 0.0) { for (j=col+1 ; j<cols ; j++) // Double division avoids underflow v[j*cols+col] = (a[col*cols+j] / a[col*cols+col+1]) / ww ; for (j=col+1 ; j<cols ; j++) { sum = 0.0 ; for (k=col+1 ; k<cols ; k++) sum += a[col*cols+k] * v[k*cols+j] ; for (k=col+1 ; k<cols ; k++) v[k*cols+j] += v[k*cols+col] * sum ; } } for (j=col+1 ; j<cols ; j++) v[col*cols+j] = v[j*cols+col] = 0.0 ; v[col*cols+col] = 1.0 ; ww = work[col] ; }/* Left*/ for (col=cols-1 ; col>=0 ; col--) { for (j=col+1 ; j<cols ; j++) a[col*cols+j] = 0.0 ; if (w[col] == 0.0) { for (j=col ; j<rows ; j++) a[j*cols+col] = 0.0 ; } else { ww = 1.0 / w[col] ; for (j=col+1 ; j<cols ; j++) { sum = 0.0 ; for (k=col+1 ; k<rows ; k++) sum += a[k*cols+col] * a[k*cols+j] ; temp = sum / a[col*cols+col] * ww ; for (k=col ; k<rows ; k++) a[k*cols+j] += a[k*cols+col] * temp ; } for (j=col ; j<rows ; j++) a[j*cols+col] *= ww ; } a[col*cols+col] += 1.0 ; }}/*-------------------------------------------------------------------------------- qr--------------------------------------------------------------------------------*/static void qr ( int rows , int cols , int lower , int index , double *a , double *v , double *w , double *work ){ int col, colp1, row ; double c, cn, s, sn, thisw, rot1, rot2, hypot, temp, ww ; ww = w[index] ; sn = work[index] ; rot1 = work[index-1] ; rot2 = w[index-1] ; temp = ((rot2-ww) * (rot2+ww) + (rot1-sn) * (rot1+sn)) / (2.0 * sn * rot2) ; hypot = RSS ( temp , 1.0 ) ; thisw = w[lower] ; cn = ((thisw-ww) * (thisw+ww) + sn * ((rot2 / (temp + SIGN(hypot,temp))) - sn )) / thisw ; c = s = 1.0 ; for (col=lower ; col<index ; col++) { colp1 = col+1 ; rot1 = work[colp1] ; sn = s * rot1 ; rot1 = c * rot1 ; hypot = RSS ( cn , sn ) ; work[col] = hypot ; c = cn / hypot ; s = sn / hypot ; cn = thisw * c + rot1 * s ; rot1 = rot1 * c - thisw * s ; rot2 = w[colp1] ; sn = rot2 * s ; rot2 *= c ; for (row=0 ; row<cols ; row++) { thisw = v[row*cols+col] ; temp = v[row*cols+colp1] ; v[row*cols+col] = thisw * c + temp * s ; v[row*cols+colp1] = temp * c - thisw * s ; } hypot = RSS ( cn , sn ) ; w[col] = hypot ; if (hypot != 0.0) { c = cn / hypot ; s = sn / hypot ; } cn = c * rot1 + s * rot2 ; thisw = c * rot2 - s * rot1 ; for (row=0 ; row<rows ; row++) { rot1 = a[row*cols+col] ; rot2 = a[row*cols+colp1] ; a[row*cols+col] = rot1 * c + rot2 * s ; a[row*cols+colp1] = rot2 * c - rot1 * s ; } } w[index] = thisw ; work[lower] = 0.0 ; work[index] = cn ;}/*-------------------------------------------------------------------------------- verify_nonneg - Flip sign of this singular value and its vector if negative--------------------------------------------------------------------------------*/static void verify_nonneg ( int cols , int index , double *w , double *v ){ int i ; if (w[index] < 0.0) { w[index] = -w[index] ; for (i=0 ; i<cols ; i++) v[i*cols+index] = -v[i*cols+index] ; }}/*-------------------------------------------------------------------------------- Backsubstitution algorithm for solving Ax=b where A generated u, w, v Inputs are not destroyed, so it may be called with several b's. The user must have filled in the public RHS 'b' before calling this.--------------------------------------------------------------------------------*/void SingularValueDecomp::backsub ( double thresh , // Threshold for zeroing singular values. Typically 1.e-8. double *x // Output of solution ){ int row, col, cc ; double sum, *mat ; if (u == NULL) // Did we replace a with u mat = a ; else // or preserve it? mat = u ;/* Set the threshold according to the maximum singular value*/ sum = 0.0 ; // Will hold max w for (col=0 ; col<cols ; col++) { if (w[col] > sum) sum = w[col] ; } thresh *= sum ; if (thresh <= 0.0) // Avoid dividing by zero in next step thresh = 1.e-30 ;/* Find U'b*/ for (col=0 ; col<cols ; col++) { sum = 0.0 ; if (w[col] > thresh) { for (row=0 ; row<rows ; row++) sum += mat[row*cols+col] * b[row] ; sum /= w[col] ; } work[col] = sum ; }/* Multiply by V*/ for (col=0 ; col<cols ; col++) { sum = 0.0 ; for (cc=0 ; cc<cols ; cc++) sum += v[col*cols+cc] * work[cc] ; x[col] = sum ; }}#if 0/*-------------------------------------------------------------------------------- Optional main to test it--------------------------------------------------------------------------------*/#define RANDMAX 32767void main (){ int rep, m, n, i, j, k ; double *x, *sa, sum, err, wmin, wmax ; char msg[81] ; SingularValueDecomp *s ; printf ( "\nEnter m, n:" ) ; while (gets (msg) == 0 ) ; sscanf ( msg , "%d %d" , &m, &n ) ; if (m <= 0 || n <= 0) exit ( 0 ) ; sa = (double *) malloc ( m * n * sizeof(double) ) ; x = (double *) malloc ( n * sizeof(double) ) ; s = new SingularValueDecomp ( m , n , 0 ) ; for (rep=0;;rep++) { if (kbhit()) { if (getch() == 27) exit ( 0 ) ; } if ((m == n) && ! rep) { // Ill cond for (i=0 ; i<m ; i++) { for (j=0 ; j<n ; j++) sa[i*n+j] = s->a[i*n+j] = 1.0 / (i + j + 1.0) ; s->b[i] = (double) (rand() - RANDMAX/2) / (double) RANDMAX ; } } else { for (i=0 ; i<m ; i++) { for (j=0 ; j<n ; j++) sa[i*n+j] = s->a[i*n+j] = (double) (rand() - RANDMAX/2) / (double) RANDMAX ; s->b[i] = (double) (rand() - RANDMAX/2) / (double) RANDMAX ; } } s->svdcmp () ; wmin = 1.e30 ; wmax = -1.e30 ; for (i=0 ; i<n ; i++) { if (s->w[i] < wmin) wmin = s->w[i] ; if (s->w[i] > wmax) wmax = s->w[i] ; } printf ( "\n(%lf %lf)", wmin, wmax ) ; err = 0.0 ; for (i=0 ; i<m ; i++) { for (j=0 ; j<n ; j++) { sum = 0.0 ; for (k=0 ; k<n ; k++) sum += s->a[i*n+k] * s->w[k] * s->v[j*n+k] ; err += fabs ( sum - sa[i*n+j] ) ; } } printf ( " Rep=%lf", err ) ; if (fabs(err) > 1.e-10) { printf ( "\a" ) ; getch() ; } err = 0.0 ; for (i=0 ; i<n ; i++) { for (j=0 ; j<n ; j++) { sum = 0.0 ; for (k=0 ; k<m ; k++) sum += s->a[k*n+i] * s->a[k*n+j] ; if (i == j) err += fabs ( sum - 1.0 ) ; else err += fabs ( sum ) ; } for (j=0 ; j<n ; j++) { sum = 0.0 ; for (k=0 ; k<n ; k++) sum += s->v[k*n+i] * s->v[k*n+j] ; if (i == j) err += fabs ( sum - 1.0 ) ; else err += fabs ( sum ) ; } } printf ( " Orthog=%lf", err ) ; if (fabs(err) > 1.e-10) { printf ( "\a" ) ; getch() ; } s->backsub ( 1.e-8 , x ) ; err = 0.0 ; for (i=0 ; i<m ; i++) { sum = 0.0 ; for (j=0 ; j<n ; j++) sum += x[j] * sa[i*n+j] ; err += fabs ( sum - s->b[i] ) ; } printf ( " Back=%lf", err ) ; if ((m == n) && (fabs(err) > 1.e-10)) { printf ( "\a" ) ; getch() ; } }}#endif
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -