/* triuaux.c */
/* ************************************************************
   PROCEDURE prpipartbwsolve - partial complex backward solve.
   Solves L' * y = x for the trailing entries y(j:m-1) only, where
   L = tril(l) + i*tril(lpi) is stored column-major as two m x m arrays
   and L' is its conjugate transpose. The imaginary part of diag(L) is
   assumed to be zero: only the real diagonal l(k,k) is divided by.
   Entries y[0..j-1], ypi[0..j-1] are left untouched.
   INPUT
     l, lpi - real / imaginary parts of the m x m lower-triangular factor.
     x, xpi - real / imaginary parts of the right-hand side (length m).
     m      - order.
     j      - first row to solve for (0 <= j < m).
   OUTPUT
     y, ypi - real / imaginary parts of the solution, rows j..m-1.
   ************************************************************ */
void prpipartbwsolve(double *y,double *ypi, const double *l,const double *lpi,
                     const double *x,const double *xpi,
                     const int m, const int j)
{
  int row, len;
  const double *col, *colpi;  /* column `row` of l / lpi */
  double accre, accim;        /* re/im of conj(l(row+1:m,row))' * y(row+1:m) */

  /* Row m-1 has no sub-diagonal contribution, so start from zero. */
  col   = l   + m*(m-1);
  colpi = lpi + m*(m-1);
  accre = accim = 0.0;
  row = m - 1;
  while(row > j){
    y[row]   = (x[row]   - accre) / col[row];
    ypi[row] = (xpi[row] - accim) / col[row];
    /* Move to column row-1 and form conj(l(row:m,row-1))' * y(row:m):
       re = lr.yr + li.yi,  im = lr.yi - li.yr. */
    col   -= m;
    colpi -= m;
    len = m - row;
    accre = realdot(col+row, y+row, len) + realdot(colpi+row, ypi+row, len);
    accim = realdot(col+row, ypi+row, len) - realdot(colpi+row, y+row, len);
    row--;
  }
  /* Final row j: col now points at column j. */
  y[j]   = (x[j]   - accre) / col[j];
  ypi[j] = (xpi[j] - accim) / col[j];
}
/* ************************************************************
   PROCEDURE invutxu - computes triu(U'\X/U) for symmetric X.
   Because X is symmetric, (X/U)' = U'\X, so only one set of full
   triangular solves is needed; the second pass then solves column c
   only up to its diagonal entry c.
   INPUT
     u  - m x m full matrix, triu(u) = U (column-major).
     x  - m x m full symmetric matrix X.
     m  - order.
   OUTPUT
     y  - m x m full matrix; triu(Y) = triu(U'\X/U). Entries below the
          diagonal of y are not written.
   WORK
     xu - length m^2 working vector; on return triu(xu) = triu(X/U).
   ************************************************************ */
void invutxu(double *y, const double *u, const double *x,
             const int m, double *xu)
{
  int c;

  /* Pass 1: xu = U'\X, every column solved completely. */
  for(c = 0; c < m; c++)
    partfwsolve(xu + c*m, u, x + c*m, m, m);

  /* Keep tril(U'\X) and mirror it, so triu(xu) = tril(U'\X)' = triu(X/U).
     The computed upper part of U'\X is discarded. */
  tril2sym(xu, m);

  /* Pass 2: y(0:c, c) = U'\xu(:,c) restricted to its first c+1 entries. */
  for(c = 0; c < m; c++)
    partfwsolve(y + c*m, u, xu + c*m, m, c + 1);
}
/* ************************************************************
   PROCEDURE prpiinvutxu - complex counterpart of invutxu.
   Matrices are stored as separate real ("pr") and imaginary ("pi")
   column-major m x m arrays.
   INPUT
     u, upi - triu(u) + i*triu(upi) = U.
     x, xpi - m x m Hermitian matrix X = x + i*xpi.
     m      - order.
   OUTPUT
     y, ypi - triu(Y) = triu(U'\X/U); the strict lower triangle is not
              written.
   WORK
     xu     - length 2*m^2 working vector: the real part occupies the
              first m^2 entries and the imaginary part the next m^2.
   ************************************************************ */
void prpiinvutxu(double *y,double *ypi, const double *u,const double *upi,
                 const double *x,const double *xpi, const int m, double *xu)
{
  int c;
  double *xuim;

  xuim = xu + SQR(m);     /* imaginary half of the workspace */

  /* Pass 1: (xu, xuim) = U'\X, every column solved completely. */
  for(c = 0; c < m; c++)
    prpipartfwsolve(xu + c*m, xuim + c*m, u, upi, x + c*m, xpi + c*m, m, m);

  /* Keep the lower triangle and mirror it with tril2herm, so that
     triu(xu) corresponds to triu(X/U); the upper part of U'\X is dropped. */
  tril2herm(xu, xuim, m);

  /* Pass 2: column c is only solved up to its diagonal entry c. */
  for(c = 0; c < m; c++)
    prpipartfwsolve(y + c*m, ypi + c*m, u, upi, xu + c*m, xuim + c*m, m, c + 1);
}
/* ************************************************************
   PROCEDURE invltxl - computes tril(L'\X/L) for symmetric X.
   Because X is symmetric, (X/L)' = L'\X, so one set of full backward
   solves suffices; the second pass solves column c only from row c down.
   INPUT
     l  - m x m full matrix, tril(l) = L (column-major).
     x  - m x m full symmetric matrix X.
     m  - order.
   OUTPUT
     y  - m x m full matrix; tril(Y) = tril(L'\X/L). Entries above the
          diagonal of y are not written.
   WORK
     xl - length m^2 working vector; on return tril(xl) = tril(X/L).
   ************************************************************ */
void invltxl(double *y, const double *l, const double *x,
             const int m, double *xl)
{
  int c;

  /* Pass 1: xl = L'\X, every column solved completely (rows 0..m-1). */
  for(c = 0; c < m; c++)
    partbwsolve(xl + c*m, l, x + c*m, m, 0);

  /* Keep triu(L'\X) and mirror it, so tril(xl) = triu(L'\X)' = tril(X/L).
     The computed lower part of L'\X is discarded. */
  triu2sym(xl, m);

  /* Pass 2: y(c:m-1, c) = L'\xl(:,c) restricted to rows c..m-1. */
  for(c = 0; c < m; c++)
    partbwsolve(y + c*m, l, xl + c*m, m, c);
}
/* ************************************************************
   PROCEDURE prpiinvltxl - complex counterpart of invltxl.
   Matrices are stored as separate real / imaginary column-major arrays.
   The imaginary part of diag(L) is assumed to be all zero.
   INPUT
     l, lpi - tril(l) + i*tril(lpi) = L.
     x, xpi - m x m Hermitian matrix X = x + i*xpi.
     m      - order.
   OUTPUT
     y, ypi - tril(Y) = tril(L'\X/L); the strict upper triangle is not
              written.
   WORK
     xl     - length 2*m^2 working vector: the real part occupies the
              first m^2 entries and the imaginary part the next m^2.
   ************************************************************ */
void prpiinvltxl(double *y,double *ypi, const double *l,const double *lpi,
                 const double *x,const double *xpi, const int m, double *xl)
{
  int c;
  double *xlim;

  xlim = xl + SQR(m);     /* imaginary half of the workspace */

  /* Pass 1: (xl, xlim) = L'\X, every column solved completely. */
  for(c = 0; c < m; c++)
    prpipartbwsolve(xl + c*m, xlim + c*m, l, lpi, x + c*m, xpi + c*m, m, 0);

  /* Keep the upper triangle and mirror it with triu2herm, so that
     tril(xl) corresponds to tril(X/L); the lower part of L'\X is dropped. */
  triu2herm(xl, xlim, m);

  /* Pass 2: column c is only solved from row c down to row m-1. */
  for(c = 0; c < m; c++)
    prpipartbwsolve(y + c*m, ypi + c*m, l, lpi, xl + c*m, xlim + c*m, m, c);
}
/* ************************************************************
   PROCEDURE psdscaleK - applies the PSD part of y = D(d)x, block by block.
   With D = Ud'*Ud (Ud upper-triangular Cholesky factor, Ld = Ud' held in
   tril(ud)):
     transp == 0 :  Y        = Ld' * X(p,p) * Ld
     transp != 0 :  Y(p,p)   = Ud' * X * Ud
   where p = perm (identity when perm == NULL).
   Real symmetric blocks come first (cK.rsdpN of them, n^2 entries each),
   followed by complex Hermitian blocks (2*n^2 entries each: the real part
   and then the imaginary part).
   INPUT
     x      - length lenud input vector.
     ud     - Cholesky factors of d for the PSD blocks (after perm ordering).
     perm   - ordering with Ud = chol(d(perm,perm)); NULL means no reordering.
     cK     - structure describing the symmetric cone K.
     transp - selects between the two formulas above.
   OUTPUT
     y      - length lenud output vector, y = D(d)x for the PSD blocks.
   WORK
     fwork  - length 2*max(rMaxn^2, 2*hMaxn^2). The first half is passed on
              as workspace to the triangular kernels; the second half holds
              the permuted/intermediate block z (imaginary part at
              z + hMaxn^2).
   REMARK lenud := cK.rDim + cK.hDim
   ************************************************************ */
void psdscaleK(double *y, const double *ud, const int *perm, const double *x,
               const coneK cK, const char transp, double *fwork)
{
  int blk, n, nsq;
  char pivoting;
  double *zre, *zim, *dst, *dstpi;
  const double *src, *srcpi;

  pivoting = (perm != (const int *) NULL);
  zre = fwork + MAX(SQR(cK.rMaxn),2*SQR(cK.hMaxn));
  zim = zre + SQR(cK.hMaxn);

  /* ---------- real symmetric blocks ---------- */
  for(blk = 0; blk < cK.rsdpN; blk++){
    n = cK.sdpNL[blk];
    nsq = SQR(n);
    if(!transp){
      /* Y = Ld' * X(p,p) * Ld; only tril is computed, then mirrored. */
      src = x;
      if(pivoting){
        matperm(zre, x, perm, n);
        src = zre;
      }
      realltxl(y, ud, src, n, fwork);
      tril2sym(y, n);
    }
    else{
      /* Z = Ud' * X * Ud (triu, then mirrored); Y(p,p) = Z. */
      dst = pivoting ? zre : y;
      realutxu(dst, ud, x, n, fwork);
      triu2sym(dst, n);
      if(pivoting)
        invmatperm(y, zre, perm, n);
    }
    y += nsq;
    ud += nsq;
    x += nsq;
    if(pivoting)
      perm += n;
  }

  /* ---------- complex Hermitian blocks ---------- */
  for(; blk < cK.sdpN; blk++){
    n = cK.sdpNL[blk];
    nsq = SQR(n);
    if(!transp){
      src = x;
      srcpi = x + nsq;
      if(pivoting){
        matperm(zre, x, perm, n);
        matperm(zim, x + nsq, perm, n);
        src = zre;
        srcpi = zim;
      }
      prpiltxl(y, y + nsq, ud, ud + nsq, src, srcpi, n, fwork);
      tril2herm(y, y + nsq, n);
    }
    else{
      dst = y;
      dstpi = y + nsq;
      if(pivoting){
        dst = zre;
        dstpi = zim;
      }
      prpiutxu(dst, dstpi, ud, ud + nsq, x, x + nsq, n, fwork);
      triu2herm(dst, dstpi, n);
      if(pivoting){
        invmatperm(y, zre, perm, n);          /* real part */
        invmatperm(y + nsq, zim, perm, n);    /* imaginary part */
      }
    }
    /* Each Hermitian block stores real and imaginary parts: 2*n^2. */
    y += 2*nsq;
    ud += 2*nsq;
    x += 2*nsq;
    if(pivoting)
      perm += n;
  }
}
#ifdef SEDUMI_OLD
/* ************************************************************
   PROCEDURE scaleK - computes y = D(d)x over the whole cone K
   (LP part, Lorentz part, then PSD part via psdscaleK).
   For PSD blocks, D = U'*U is used: (transp == 0) Y = UXU',
   (transp == 1) Y = U'XU.
   INPUT
     x      - length N(K) input vector.
     d      - scaling vector; only its LP and Lorentz parts are read.
     ud     - Cholesky factor of d for the PSD part (after perm ordering).
     qdetd  - sqrt(det(d)) per Lorentz block.
     perm   - ordering: Ud = chol(d(perm,perm)), for numerical stability.
     cK     - structure describing the symmetric cone K.
     invdx  - optional length cK.qDim vector holding D(d)\x for the
              Lorentz part. It is only consulted when dmult != NULL;
              pass NULL to ignore it.
   OUTPUT
     y      - length N(K) output vector, y = D(d)x.
     dmult  - if not NULL, lorN-vector of mu[k] such that
              (D(d)x)_k = y_k + mu[k] * d_k for the k-th Lorentz block.
              This split representation avoids cancellation in y.
   WORK
     fwork  - length 2 * max(rmaxn^2, 2*hmaxn^2), forwarded to psdscaleK.
   ************************************************************ */
void scaleK(double *y, double *dmult, const double *d, const double *ud,
            const double *qdetd, const int *perm, const double *x,
            const coneK cK, const char transp, const double *invdx,
            double *fwork)
{
  int blk, len;
  double detd;

  /* LP part: y = d .* x. */
  realHadamard(y, d,x,cK.lpN);
  y += cK.lpN;
  d += cK.lpN;
  x += cK.lpN;

  if(dmult == (double *) NULL){
    /* Lorentz, variant 1: y = D(d) x directly. */
    for(blk = 0; blk < cK.lorN; blk++){
      len = cK.lorNL[blk];
      qlmul(y,d,x,qdetd[blk],len);
      y += len;
      d += len;
      x += len;
    }
  }
  else if(invdx == (const double *) NULL){
    /* Lorentz, variant 2: D(d) x = y + dmult * d. */
    for(blk = 0; blk < cK.lorN; blk++){
      len = cK.lorNL[blk];
      dmult[blk] = qscale(y,d,x,qdetd[blk],len);
      y += len;
      d += len;
      x += len;
    }
  }
  else{
    /* Lorentz, variant 3: D(d) x = D(d^2) invdx = y + dmult * d, using
       D(d^2)invdx = (d'*invdx) * d + det(d) * [-invdx(1); invdx(2:nk)];
       y receives the det(d) * [-invdx(1); invdx(2:nk)] term. */
    for(blk = 0; blk < cK.lorN; blk++){
      len = cK.lorNL[blk];
      dmult[blk] = realdot(d,invdx,len);
      detd = SQR(qdetd[blk]);
      y[0] = - detd * invdx[0];
      scalarmul(y+1, detd,invdx+1,len-1);
      y += len;
      d += len;
      invdx += len;
    }
    /* x was not walked in this variant; skip the whole Lorentz part. */
    x += cK.qDim;
  }

  /* PSD part. */
  psdscaleK(y, ud, perm, x, cK, transp, fwork);
}
#endif
/* end of triuaux.c */