📄 vectril.c
字号:
zpr[jnz] = xpr[inz++] - ypr[i];
else
zpr[jnz] = -ypr[i];
zir[jnz++] = yir[i];
knz = yinx[i+2]-inz;
memcpy(zir + jnz, xir + inz, knz * sizeof(int));
memcpy(zpr + jnz, xpr + inz, knz * sizeof(double));
jnz += knz;
}
return jnz;
}
/* ************************************************************
   PROCEDURE sptotril - Folds the strict upper triangle of a sparse
     n x n matrix X (given as x = vec(X)) onto its lower triangle:
       skew == 0:  z = vec( tril(X) + triu(X,1)' ),
       skew == 1:  z = vec( tril(X) - triu(X)' ), so that diag(z) = 0.
   INPUT
     xir, xpr - sparse input vector x. *pxjc0 is the position of the
                first nonzero of vec(X). xjc1 is the end (exclusive)
                of the nonzero range of x that may be scanned.
     first    - subscript of X(1,1) in the long vector x.
     n        - order of the n x n matrix X.
     skew     - 0: add triu(X,1)'.  1: subtract triu(X,1)' and make
                diag(z) all-0.
     iwsize   - length of iwork. It must be at least
                n + xnnz + 1 + nnz(triu(X,1)) + log_2(1+nnz(triu(X,1))).
                Because nnz(triu(X,1)) <= MIN(n*(n-1)/2, xnnz) and
                xnnz <= MIN(n^2, xjc1-*pxjc0), it suffices that
                iwsize >= n*(2*n+1)+log_2(1+n*(n-1)/2).
   UPDATED
     pxjc0    - advanced past the nonzeros belonging to X.
   OUTPUT
     zir      - length znnz int array, subscripts of z.
     zpr      - length znnz vector, nonzeros of z.
   WORK
     cwork    - char array of length nnz(triu(X,1)) <= n*(n-1)/2.
     iwork    - int array of length iwsize.
     ypr      - double array of length xnnz <= n^2.
   RETURNS znnz, the number of nonzeros stored in (zir,zpr).
   ************************************************************ */
int sptotril(int *zir, double *zpr, const int *xir, const double *xpr,
             int *pxjc0, const int xjc1, const int first, const int n,
             const bool skew, int iwsize, char *cwork, int *iwork,
             double *ypr)
{
  int blkbeg, blknnz, lownnz, upbeg;
  int *rowstart, *ybuf, *iwrest;
  /* ------------------------------------------------------------
     The first n-1 ints of iwork receive the row starts of triu(X,1)
     stored row-wise. blknnz = nnz(X). *pxjc0 is moved beyond this
     block of x.
     ------------------------------------------------------------ */
  blkbeg = *pxjc0;
  blknnz = sptriujcT(iwork, xir, blkbeg, xjc1, first, n);
  *pxjc0 = blkbeg + blknnz;
  /* ------------------------------------------------------------
     Carve up iwork as [ rowstart (n-1) | ybuf (blknnz) | rest ].
     ------------------------------------------------------------ */
  rowstart = iwork;
  ybuf = rowstart + (n - 1);
  iwrest = ybuf + blknnz;
  iwsize = iwsize - (n - 1) - blknnz;
  /* ------------------------------------------------------------
     Offset of the triu(X,1) part inside (ybuf,ypr). It is read before
     sptrilandtriu runs, because that routine gets rowstart as a
     writable pointer. A 1 x 1 matrix has an empty strict upper
     triangle.
     ------------------------------------------------------------ */
  upbeg = (n > 1) ? rowstart[0] : blknnz;
  lownnz = sptrilandtriu(ybuf, ypr, rowstart, xir, xpr, blkbeg, xjc1,
                         first, n, skew);
  /* Merge tril(X) with the transposed strict upper triangle. */
  if(skew)
    return spsub(zir, zpr, ybuf, ypr, lownnz, ybuf + upbeg, ypr + upbeg,
                 blknnz - upbeg, iwsize, cwork, iwrest);
  return spadd(zir, zpr, ybuf, ypr, lownnz, ybuf + upbeg, ypr + upbeg,
               blknnz - upbeg, iwsize, cwork, iwrest);
}
/* ************************************************************
   PROCEDURE vectril - Applies sptotril(xk) to each PSD block k, so
     that on output each PSD block is lower triangular:
       real symmetric block:  Zk = tril(Xk) + triu(Xk,1)'.
       Hermitian block (k >= rpsdN), stored as [real part, imag part]:
         real part: tril(RE Xk) + triu(RE Xk,1)',
         imag part: tril(IM Xk) - triu(IM Xk)'  (zero diagonal).
     Nonzeros with subscripts below blkstart[0] (the f,l,q,r cone
     parts) are copied unchanged.
   INPUT
     xir,xpr,xnnz - sparse input vector x with xnnz nonzeros. xir is
                    assumed sorted ascending, as for a column of a
                    MATLAB sparse matrix.
     psdNL    - K.s, the order of each PSD block.
     blkstart - length psdN+1 array. PSD block k has subscripts
                blkstart[k]:blkstart[k+1]-1.
     isblk    - length psdDim array, with k = isblk[i-blkstart[0]] iff
                blkstart[k] <= i < blkstart[k+1], k=0:psdN-1.
     rpsdN    - number of real PSD blocks. Blocks k >= rpsdN are
                Hermitian.
     psdN     - number of PSD blocks.
     iwsize   - length of iwork, at least
                maxn*(2*maxn+1)+log_2(1+maxn*(maxn-1)/2), where
                maxn := max(K.s).
   OUTPUT
     zir - length znnz <= xnnz int array: subscripts of z = vectril(x).
     zpr - length znnz <= xnnz vector: nonzeros of z = vectril(x).
   WORKING ARRAYS
     cwork - length maxn*(maxn-1)/2 char array, where maxn := max(K.s).
     iwork - length iwsize integer working array.
     fwork - length max(K.s.^2) vector. (Not doubled for Hermitian
             blocks, since real and imaginary parts are treated
             separately.)
   RETURNS znnz
   ************************************************************ */
int vectril(int *zir, double *zpr, const int *xir, const double *xpr,
            const int xnnz, const int *psdNL,
            const int *blkstart, const int *isblk,
            const int rpsdN, const int psdN, const int iwsize,
            char *cwork, int *iwork, double *fwork)
{
  int inz, jnz, k, nk;
  const int psdfirst = blkstart[0];    /* first PSD subscript in x */
  /* ------------------------------------------------------------
     Copy f,l,q,r parts without change. Let inz point to the first
     PSD-nonzero in x, and jnz to the first PSD-nonzero in z.
     ------------------------------------------------------------ */
  inz = 0;                                    /* pointer into x */
  intbsearch(&inz, xir, xnnz, psdfirst);      /* inz points to start PSD */
  memcpy(zir, xir, inz * sizeof(int));
  memcpy(zpr, xpr, inz * sizeof(double));
  jnz = inz;                         /* jnz points to start PSD in z */
  /* ------------------------------------------------------------
     Process all PSD blocks. sptotril advances inz past each block.
     ------------------------------------------------------------ */
  while(inz < xnnz){
    /* Index isblk relative to the PSD start rather than shifting the
       pointer itself: "isblk - blkstart[0]" would point before the
       array, which is undefined behavior in C even if never
       dereferenced out of range. */
    k = isblk[xir[inz] - psdfirst];
    nk = psdNL[k];
    jnz += sptotril(zir + jnz, zpr + jnz, xir, xpr, &inz, xnnz, blkstart[k],
                    nk, 0, iwsize, cwork, iwork, fwork);
    /* ------------------------------------------------------------
       For the imaginary part, we do a skew transpose:
       tril(IM Xk)-triu(IM Xk)'. This makes the diagonal of the
       imaginary block zero.
       ------------------------------------------------------------ */
    if(k >= rpsdN){
      jnz += sptotril(zir + jnz, zpr + jnz, xir, xpr, &inz, xnnz,
                      blkstart[k] + SQR(nk), nk, 1, iwsize, cwork, iwork,
                      fwork);
    }
  }
  return jnz;
}
/* ============================================================
   MAIN: MEXFUNCTION
   ============================================================ */
/* ************************************************************
   PROCEDURE mexFunction - Entry for Matlab
     y = vectril(x,K)
   Each column of x is processed separately by vectril. For the PSD
   submatrices, Yk = tril(Xk) + triu(Xk,1)'. For the imaginary part of
   a Hermitian block, the result is tril(IM Xk) - triu(IM Xk)'. Entries
   before the PSD part are copied unchanged.
   Complex numbers are stored as vec([real(Xk) imag(Xk)]).
   NB: x and y are sparse, and nnz(y) <= nnz(x).
   ************************************************************ */
void mexFunction(const int nlhs, mxArray *plhs[],
                 const int nrhs, const mxArray *prhs[])
{
  int i, j, k, jnz, m,lenfull, firstPSD, maxn, iwsize;
  jcir x,y;
  int *iwork, *psdNL, *blkstart, *xblk;
  char *cwork;
  double *fwork;
  coneK cK;
  /* ------------------------------------------------------------
     Check for proper number of arguments.
     These are user errors, so they are raised with mexErrMsgTxt.
     mxAssert is only compiled in for debug (-g) builds. In a release
     build, the code would otherwise go on to read missing inputs
     (K_IN/X_IN presumably index prhs).
     ------------------------------------------------------------ */
  if(nrhs < NPARIN)
    mexErrMsgTxt("vectril requires more input arguments");
  if(nlhs > NPAROUT)
    mexErrMsgTxt("vectril produces less output arguments");
  /* ------------------------------------------------------------
     Disassemble cone K structure
     ------------------------------------------------------------ */
  conepars(K_IN, &cK);
  /* ------------------------------------------------------------
     Compute statistics based on cone K structure:
     firstPSD = first PSD subscript, lenfull = required length of x.
     ------------------------------------------------------------ */
  firstPSD = cK.frN + cK.lpN + cK.qDim;
  for(i = 0; i < cK.rconeN; i++)        /* add dim of rotated cone */
    firstPSD += cK.rconeNL[i];
  lenfull = firstPSD + cK.rDim + cK.hDim;
  /* ------------------------------------------------------------
     Get inputs x, blkstart.
     Sparsity is checked before the jc/ir index arrays are fetched,
     because those arrays only exist for sparse input. A row count
     that disagrees with K would let vectril index xblk out of bounds,
     so that mismatch is also a hard error.
     ------------------------------------------------------------ */
  if(!mxIsSparse(X_IN))
    mexErrMsgTxt("X should be sparse.");
  if(mxGetM(X_IN) != lenfull)
    mexErrMsgTxt("X size mismatch.");
  m = mxGetN(X_IN);                     /* number of columns to handle */
  x.pr = mxGetPr(X_IN);
  x.jc = mxGetJc(X_IN);
  x.ir = mxGetIr(X_IN);
  /* ------------------------------------------------------------
     Allocate output Y = sparse([],[],[],length(x),m,nnz(x))
     ------------------------------------------------------------ */
  Y_OUT = mxCreateSparse(lenfull, m, x.jc[m], mxREAL);
  y.pr = mxGetPr(Y_OUT);
  y.jc = mxGetJc(Y_OUT);
  y.ir = mxGetIr(Y_OUT);
  y.jc[0] = 0;
  /* ------------------------------------------------------------
     If x = [], then we are ready with y=[]. Otherwise, proceed:
     ------------------------------------------------------------ */
  if(x.jc[m] > 0){
    /* ------------------------------------------------------------
       Allocate iwork[iwsize],
       iwsize := maxn*(2*maxn+1)+log_2(1+maxn*(maxn-1)/2), where maxn := max(K.s);
       cwork[maxn*(maxn-1)/2], fwork(maxn^2), int psdNL(length(K.s)).
       int blkstart(sdpN+1), xblk(sdpDim)
       ------------------------------------------------------------ */
    maxn = MAX(cK.rMaxn,cK.hMaxn);
    iwsize = log(1 + maxn*(maxn-1)/2) / log(2);
    iwsize += maxn * (2*maxn+1);
    iwork = (int *) mxCalloc(MAX(1,iwsize), sizeof(int));
    cwork = (char *) mxCalloc(MAX(1,maxn*(maxn-1)/2), sizeof(char));
    fwork = (double *) mxCalloc(MAX(1,SQR(maxn)), sizeof(double));
    psdNL = (int *) mxCalloc(MAX(1,cK.sdpN), sizeof(int));
    blkstart = (int *) mxCalloc(1 + cK.sdpN, sizeof(int));
    xblk = (int *) mxCalloc(MAX(1,cK.rDim + cK.hDim), sizeof(int));
    /* ------------------------------------------------------------
       double -> int for K.s
       ------------------------------------------------------------ */
    for(i = 0; i < cK.sdpN; i++)
      psdNL[i] = cK.sdpNL[i];
    /* ------------------------------------------------------------
       Let k = xblk(j-blkstart[0]) iff
       blkstart[k] <= j < blkstart[k+1], k=0:psdN-1.
       Real symmetric blocks occupy n^2 entries. Hermitian blocks
       occupy 2*n^2 entries (real part followed by imaginary part).
       ------------------------------------------------------------ */
    j = firstPSD;
    for(i = 0; i < cK.rsdpN; i++){                /* real sym */
      blkstart[i] = j;
      j += SQR(psdNL[i]);
    }
    for(; i < cK.sdpN; i++){                      /* complex herm. */
      blkstart[i] = j;
      j += 2*SQR(psdNL[i]);
    }
    blkstart[cK.sdpN] = j;
    mxAssert(j - firstPSD == cK.rDim + cK.hDim, "Size mismatch blkstart, K.");
    j = 0;
    for(k = 0; k < cK.sdpN; k++){
      i = blkstart[k+1] - blkstart[0];
      while(j < i)
        xblk[j++] = k;
    }
    /* ------------------------------------------------------------
       Let y(:,i)= vectril(x(:,i)), for i=1:m.
       ------------------------------------------------------------ */
    jnz = 0;                                      /* points into y */
    for(i = 0; i < m; i++){
      y.jc[i] = jnz;
      jnz += vectril(y.ir+jnz,y.pr+jnz, x.ir+x.jc[i],x.pr+x.jc[i],
                     x.jc[i+1]-x.jc[i],
                     psdNL, blkstart, xblk, cK.rsdpN,cK.sdpN, iwsize,
                     cwork, iwork, fwork);
    }
    y.jc[m] = jnz;                            /* nnz written into y */
    mxAssert(jnz <= x.jc[m],"");
    /* ------------------------------------------------------------
       REALLOC: Shrink Y to its current size (at least 1 slot)
       ------------------------------------------------------------ */
    jnz = MAX(jnz,1);
    if( (y.pr = (double *) mxRealloc(y.pr, jnz*sizeof(double))) == NULL)
      mexErrMsgTxt("Memory reallocation error");
    mxSetPr(Y_OUT,y.pr);
    if( (y.ir = (int *) mxRealloc(y.ir, jnz*sizeof(int))) == NULL)
      mexErrMsgTxt("Memory reallocation error");
    mxSetIr(Y_OUT,y.ir);
    mxSetNzmax(Y_OUT,jnz);
    /* ------------------------------------------------------------
       Release working arrays
       ------------------------------------------------------------ */
    mxFree(xblk);
    mxFree(blkstart);
    mxFree(psdNL);
    mxFree(fwork);
    mxFree(iwork);
    mxFree(cwork);
  }
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -