/* getada3.c (excerpt) */
mxAssert(iwsiz >= MAX(pcK->rMaxn,pcK->hMaxn), "iwork too small in getada3()");
daj = fwork; /* lenfull */
fwork = daj + lenud; /* fwsiz */
fwsiz -= lenud;
mxAssert(fwsiz >= 2 * MAX(SQR(pcK->rMaxn),2 * SQR(pcK->hMaxn)), "fwork too small in getada3()");
/* ------------------------------------------------------------
Make "lenfull" vector index valid into daj.
------------------------------------------------------------ */
daj -= blkstart[0]; /* We'll use daj[blkstart[0]:end] */
/* ------------------------------------------------------------
Initialize dzknnz = 0, meaning dz=[]. Later we will merge
columns from dzstruct, with dz, and partition into selected blocks.
------------------------------------------------------------ */
mxAssert(dznnz > 0, ""); /* we know that there exist nonempty PSD: */
for(i = 0; i < nblk; i++)
dzknnz[i] = 0;
for(j = 1; dzstructjc[j] == 0; j++); /* 1st nonzero PSD constraint */
/* ============================================================
MAIN getada LOOP: loop over nodes perm(0:m-1)
============================================================ */
for(--j; j < m; j++){
permj = perm[j];
/* ------------------------------------------------------------
Make dzir: the PSD-nonzero locations, with pointers
to the selected PSD blocks. nz-locs = merge(dzir,dzstruct(:,j)).
------------------------------------------------------------ */
i = dzstructjc[j];
while( i < dzstructjc[j+1]){
k = xblk[dzstructir[i]];
knz = i; /* add dzstructir(knz:i-1) */
intbsearch(&i, dzstructir, dzstructjc[j+1], blkstart[k+1]);
mxAssert(i > knz,"");
exmerge(dzir+dzjc[k], dzstructir+knz, dzknnz[k],i-knz,iwsiz,
cwork,iwork);
dzknnz[k] += i-knz; /* number added */
}
/* ------------------------------------------------------------
Compute daj = P(d)*aj = vec(D*Aj*D).
------------------------------------------------------------ */
nnzbj = spsqrscale(daj,blksj,dzjc,dzir,dzknnz, udsqr,
At.ir,At.pr,Ajc1[permj],At.jc[permj+1],
blkstart, xblk, psdNL, rsdpN, fwork, iwork);
/* iwork(max(K.s)), fwork(2 * (rMaxn^2 + 2*hMaxn^2))*/
mxAssert(nnzbj <= nblk, ""); /* number of nz-matrix-blocks */
/* ------------------------------------------------------------
For all i with invpermi < j:
ada_ij = a_i'*daj.
------------------------------------------------------------ */
for(inz = ada.jc[permj]; inz < ada.jc[permj+1]; inz++){
i = ada.ir[inz];
if(invperm[i] <= j){
adaij = ada.pr[inz];
if(invperm[i] < j)
for(knz = Ajc1[i]; knz < At.jc[i+1]; knz++)
adaij += At.pr[knz] * daj[At.ir[knz]];
else{ /* diag entry: absd[j] = sum(abs(aj.*daj)) */
absadajj = adaij;
for(knz = Ajc1[i]; knz < At.jc[i+1]; knz++){
termj = At.pr[knz] * daj[At.ir[knz]];
adaij += termj;
absadajj += fabs(termj);
}
absd[permj] = absadajj;
}
ada.pr[inz] = adaij;
}
}
/* ------------------------------------------------------------
Set daj = all-0
------------------------------------------------------------ */
for(knz = 0; knz < nnzbj; knz++){
i = blksj[knz];
spzeros(daj,dzir+dzjc[i],dzknnz[i]);
}
} /* j = 0:m-1 */
mxAssert(dzjc[nblk-1]+dzknnz[nblk-1] == dznnz,"");
}
/* ============================================================
MEXFUNCTION
============================================================ */
/* ************************************************************
PROCEDURE mexFunction - Entry for Matlab
************************************************************ */
void mexFunction(const int nlhs, mxArray *plhs[],
                 const int nrhs, const mxArray *prhs[])
{
  /* Outputs (via the ADA_OUT / ABSD_OUT macros defined elsewhere):
       ADA  - copy of the input ADA; for PSD cones getada3() adds the PSD
              contribution, then the result is symmetrized (ADA+ADA')/2.
       absd - computed by getada3() when PSD blocks exist, otherwise by
              cpspdiag() (the diagonal of ADA).
     Inputs are read through the K_IN, AT_IN, UDSQR_IN, AJC1_IN, AORD_IN and
     ADA_IN macros; their positions are defined outside this excerpt. */
  mxArray *myplhs[NPAROUT];
  coneK cK;
  const mxArray *MY_FIELD;
  int lenfull, lenud, m, i, j, k, fwsiz, iwsiz, dznnz, maxadd;
  const double *permPr, *Ajc1Pr, *blkstartPr, *udsqr;
  const int *dzstructjc, *dzstructir;
  double *fwork, *absd;
  int *blkstart, *iwork, *Ajc1, *psdNL, *xblk, *perm, *invperm, *dzjc;
  char *cwork;
  jcir At, ada;
/* ------------------------------------------------------------
   Check for proper number of arguments.
   NOTE: user-input validation uses mexErrMsgTxt rather than mxAssert,
   because mxAssert is only active in MEX files built with "mex -g".
   In a release build a bad call would otherwise read past prhs[]
   or past myplhs[] (in the final memcpy).
   ------------------------------------------------------------ */
  if(nrhs < NPARIN)
    mexErrMsgTxt("getADA requires more input arguments.");
  if(nlhs > NPAROUT)
    mexErrMsgTxt("getADA produces less output arguments.");
/* ------------------------------------------------------------
   Disassemble cone K structure
   ------------------------------------------------------------ */
  conepars(K_IN, &cK);
/* ------------------------------------------------------------
   Compute some statistics based on cone K structure
   ------------------------------------------------------------ */
  lenud = cK.rDim + cK.hDim;                    /* for PSD */
  lenfull = cK.lpN + cK.qDim + lenud;
/* ------------------------------------------------------------
   Allocate working array blkstart(|K.s|+1).
   ------------------------------------------------------------ */
  blkstart = (int *) mxCalloc(cK.sdpN + 1, sizeof(int));
/* ------------------------------------------------------------
   Translate blkstart from Fortran-double to C-int
   ------------------------------------------------------------ */
  MY_FIELD = mxGetField(K_IN,0,"blkstart");       /*K.blkstart*/
  if(MY_FIELD == NULL)
    mexErrMsgTxt("Missing K.blkstart.");
  if(mxGetM(MY_FIELD) * mxGetN(MY_FIELD) != 2+cK.lorN+cK.sdpN)
    mexErrMsgTxt("Size mismatch K.blkstart.");
  blkstartPr = mxGetPr(MY_FIELD) + cK.lorN + 1;   /* point to start of PSD */
  for(i = 0; i <= cK.sdpN; i++){                  /* to integers */
    j = blkstartPr[i];
    blkstart[i] = --j;
  }
/* ------------------------------------------------------------
   INPUT sparse constraint matrix At:
   ------------------------------------------------------------ */
  if(mxGetM(AT_IN) != lenfull)
    mexErrMsgTxt("Size mismatch At");
  m = mxGetN(AT_IN);
  if(!mxIsSparse(AT_IN))
    mexErrMsgTxt("At should be sparse.");
  At.pr = mxGetPr(AT_IN);
  At.jc = mxGetJc(AT_IN);
  At.ir = mxGetIr(AT_IN);
/* ------------------------------------------------------------
   Get SCALING VECTOR: udsqr
   ------------------------------------------------------------ */
  if(mxGetM(UDSQR_IN) * mxGetN(UDSQR_IN) != lenud)
    mexErrMsgTxt("udsqr size mismatch.");
  udsqr = mxGetPr(UDSQR_IN);
/* ------------------------------------------------------------
   Get Ajc1
   ------------------------------------------------------------ */
  if(mxGetM(AJC1_IN)*mxGetN(AJC1_IN) != m)
    mexErrMsgTxt("Ajc1 size mismatch");
  Ajc1Pr = mxGetPr(AJC1_IN);
/* ------------------------------------------------------------
   DISASSEMBLE Aord structure: Aord.{dz,sperm}
   ------------------------------------------------------------ */
  if(!mxIsStruct(AORD_IN))
    mexErrMsgTxt("Aord should be a structure.");
  MY_FIELD = mxGetField(AORD_IN,0,"dz");          /* Aord.dz */
  if(MY_FIELD == NULL)
    mexErrMsgTxt("Missing field Aord.dz.");
  if(mxGetN(MY_FIELD) < m)
    mexErrMsgTxt("Size mismatch Aord.dz.");
  if(mxGetM(MY_FIELD) != lenfull)
    mexErrMsgTxt("Aord.dz size mismatch");
  if(!mxIsSparse(MY_FIELD))
    mexErrMsgTxt("Aord.dz should be sparse.");
  dzstructjc = mxGetJc(MY_FIELD);
  dzstructir = mxGetIr(MY_FIELD);
  MY_FIELD = mxGetField(AORD_IN,0,"sperm");       /* Aord.sperm */
  if(MY_FIELD == NULL)
    mexErrMsgTxt("Missing field Aord.sperm.");
  if(mxGetM(MY_FIELD) * mxGetN(MY_FIELD) != m)
    mexErrMsgTxt("Aord.sperm size mismatch");
  permPr = mxGetPr(MY_FIELD);
/* ------------------------------------------------------------
   Allocate output matrix ADA as a duplicate of ADA_IN:
   ------------------------------------------------------------ */
  if(mxGetM(ADA_IN) != m || mxGetN(ADA_IN) != m)
    mexErrMsgTxt("Size mismatch ADA.");
  if(!mxIsSparse(ADA_IN))
    mexErrMsgTxt("ADA should be sparse.");
  ADA_OUT = mxDuplicateArray(ADA_IN);             /* ADA = ADA_IN */
  ada.jc = mxGetJc(ADA_OUT);
  ada.ir = mxGetIr(ADA_OUT);
  ada.pr = mxGetPr(ADA_OUT);
/* ------------------------------------------------------------
   Create output vector absd(m)
   ------------------------------------------------------------ */
  ABSD_OUT = mxCreateDoubleMatrix(m,1,mxREAL);
  absd = mxGetPr(ABSD_OUT);
/* ------------------------------------------------------------
   The following ONLY if there are PSD blocks and constraints.
   With m == 0 there is nothing to compute, and getada3's scan for the
   first nonzero column of Aord.dz would run off the end of its jc array.
   ------------------------------------------------------------ */
  if(cK.sdpN > 0 && m > 0){
/* ------------------------------------------------------------
   maxadd = max_j nnz(Aord.dz(:,j)), j = 0:m-1  (jc[0] == 0 for sparse)
   ------------------------------------------------------------ */
    maxadd = 0;
    for(i = 0; i < m; i++)
      if(dzstructjc[i+1] - dzstructjc[i] > maxadd)
        maxadd = dzstructjc[i+1] - dzstructjc[i];
/* ------------------------------------------------------------
   ALLOCATE integer work array iwork(iwsiz), with
   iwsiz = MAX(m, 2*nblk + dznnz +
                  max(maxadd+2+log_2(1+maxadd), max(nk(PSD)))),
   where dznnz = dzstructjc[m].
   ------------------------------------------------------------ */
    dznnz = dzstructjc[m];
    iwsiz = floor(log(1+maxadd)/log(2));     /* double to int */
    iwsiz += maxadd + 2;
    iwsiz = 2*cK.sdpN + dznnz + MAX(iwsiz,MAX(cK.rMaxn,cK.hMaxn));
    iwork = (int *) mxCalloc(MAX(iwsiz,m), sizeof(int));
/* ------------------------------------------------------------
   ALLOCATE integer working arrays:
   Ajc1(m) psdNL[cK.sdpN], dzjc(cK.sdpN+1), perm(m), invperm(m), xblk(lenud).
   cwork(maxadd).
   ------------------------------------------------------------ */
    Ajc1 = (int *) mxCalloc(MAX(m,1), sizeof(int));
    psdNL = (int *) mxCalloc(1+2*cK.sdpN + lenud, sizeof(int));
    xblk = psdNL + cK.sdpN;   /* Not own alloc: we'll subtract blkstart[0] */
    dzjc = xblk + lenud;      /* dzjc(sdpN+1) */
    perm = (int *) mxCalloc(MAX(2 * m,1), sizeof(int));
    invperm = perm + m;       /* invperm(m) */
    cwork = (char *) mxCalloc(MAX(1,maxadd), sizeof(char));
/* ------------------------------------------------------------
   ALLOCATE float working array:
   fwork[fwsiz] with fwsiz = lenud + 2 * max(rMaxn^2, 2*hMaxn^2).
   ------------------------------------------------------------ */
    fwsiz = lenud + 2 * MAX(SQR(cK.rMaxn),2*SQR(cK.hMaxn));
    fwork = (double *) mxCalloc(MAX(fwsiz,1), sizeof(double));
/* ------------------------------------------------------------
   perm to integer C-style. Each entry must lie in 1:m, since it is
   later used as an index into invperm(m); the double is checked before
   conversion (this also rejects NaN).
   ------------------------------------------------------------ */
    for(i = 0; i < m; i++){
      if(!(permPr[i] >= 1.0 && permPr[i] < m + 1.0))
        mexErrMsgTxt("Aord.sperm out of range.");
      j = permPr[i];
      perm[i] = --j;
    }
/* ------------------------------------------------------------
   Let invperm(perm) = 0:m-1.
   ------------------------------------------------------------ */
    for(i = 0; i < m; i++)
      invperm[perm[i]] = i;
/* ------------------------------------------------------------
   Let psdNL = K.s in integer, Ajc1 = Ajc1Pr in integer.
   ------------------------------------------------------------ */
    for(i = 0; i < cK.sdpN; i++)                 /* K.s */
      psdNL[i] = cK.sdpNL[i];
    for(i = 0; i < m; i++)
      Ajc1[i] = Ajc1Pr[i];
/* ------------------------------------------------------------
   Let k = xblk(j-blkstart[0]) iff
   blkstart[k] <= j < blkstart[k+1], k=0:nblk-1.
   ------------------------------------------------------------ */
    j = blkstart[0];
    xblk -= j;                /* Make blkstart[0]:blkstart[end] valid indices */
    for(k = 0; k < cK.sdpN; k++){
      i = blkstart[k+1];
      while(j < i)
        xblk[j++] = k;
    }
/* ------------------------------------------------------------
   ACTUAL COMPUTATION: handle constraint aj=At(:,perm(j)), j=0:m-1.
   ------------------------------------------------------------ */
    dzblkpartit(dzjc, dzstructir, xblk, dznnz, cK.sdpN);
    getada3(ada, absd, At,udsqr,Ajc1, dzjc,dzstructjc,dzstructir, blkstart,
            xblk,psdNL, perm,invperm, m,lenud, &cK, fwork,fwsiz,
            iwork,iwsiz, cwork);
/* ------------------------------------------------------------
   RELEASE WORKING ARRAYS (for PSD blocks only).
   ------------------------------------------------------------ */
    mxFree(fwork);
    mxFree(cwork);
    mxFree(perm);
    mxFree(psdNL);
    mxFree(Ajc1);
  } /* ~isempty(K.s) && m > 0 */
/* ------------------------------------------------------------
   If there are no PSD blocks (or no constraints), then we merely compute
   absd = diag(ADA). ALLOCATE integer work array iwork(m) for spmakesym.
   ------------------------------------------------------------ */
  else{
    iwork = (int *) mxCalloc(MAX(1,m), sizeof(int));
    cpspdiag(absd, ada,m);
  }
/* ------------------------------------------------------------
   Let ADA = (ADA+ADA')/2, so that it gets symmetric.
   ------------------------------------------------------------ */
  spmakesym(ada,m,iwork);     /* uses iwork(m) */
/* ------------------------------------------------------------
   RELEASE WORKING ARRAYS iwork and blkstart.
   ------------------------------------------------------------ */
  mxFree(iwork);
  mxFree(blkstart);
/* ------------------------------------------------------------
   Copy requested output parameters (at least 1), release others.
   nlhs <= NPAROUT was checked on entry, so the memcpy stays in bounds.
   ------------------------------------------------------------ */
  i = MAX(nlhs, 1);
  memcpy(plhs,myplhs, i * sizeof(mxArray *));
  for(; i < NPAROUT; i++)
    mxDestroyArray(myplhs[i]);
}
/* end of getada3.c */