📄 qpssvmlib.c
字号:
/*-----------------------------------------------------------------------
qpssvmlib.c: Library of solvers for QP task required in StructSVM learning.
Synopsis:
exitflag = qpssvm_solver( &get_col, diag_H, f, b, I, x, n, tmax,
tolabs, tolrel, &t, &History, verb );
Description:
It solves the following QP task:
min 0.5*x'*H*x + f'*x
x
subject to
sum(x(find(I==k))) <= b for all k=1:max(I)
x >= 0
where I is a set of positive indices from (1 to max(I)).
A precision of the found solution is given by the parameters tmax,
tolabs and tolrel which define the stopping conditions:
UB-LB <= tolabs -> exitflag = 1 Abs. tolerance.
UB-LB <= UB*tolrel -> exitflag = 2 Relative tolerance.
t >= tmax -> exitflag = 0 Number of iterations.
UB ... Upper bound on the optimal solution, i.e., Q_P.
LB ... Lower bound on the optimal solution, i.e., Q_D.
t ... Number of iterations.
Inputs/Outputs:
const void* (*get_col)(long) retunr pointer to i-th column of H
diag_H [double n x n] diagonal of H.
f [double n x 1] is an arbitrary vector.
b [double 1 x 1] scalar
I [uint16_T n x 1] Indices (1..max(I)); max(I) <= n
x [double n x 1] solution vector (inital solution).
n [long 1 x 1] dimension of H.
tmax [long 1 x 1] Max number of steps.
tolrel [double 1 x 1] Relative tolerance.
tolabs [double 1 x 1] Absolute tolerance.
t [long 1 x 1] Number of iterations.
History [double 2 x t] Value of LB and UB wrt. number of iterations.
verb [int 1 x 1] if > 0 then prints info every verb-th iteation.
For more info refer to TBA
Modifications:
20-Feb-2006, VF
18-feb-2006, VF
-------------------------------------------------------------------- */
#include "mex.h"
#include "matrix.h"
#include <math.h>
#include <stdlib.h>
#include <string.h>
#include <limits.h>
#define HISTORY_BUF 1000000
#define MINUS_INF INT_MIN
#define PLUS_INF INT_MAX
#define ABS(A) (((A) >= 0) ? (A) : (-A))
#define MIN(A,B) (((A) < (B)) ? (A) : (B))
#define MAX(A,B) (((A) > (B)) ? (A) : (B))
#define INDEX(ROW,COL,DIM) ((COL*DIM)+ROW)
/* --------------------------------------------------------------
QPSSVM solver
Usage: exitflag = qpssvm_solver( &get_col, diag_H, f, b, I, x, n, tmax,
tolabs, tolrel, &t, &History, verb );
-------------------------------------------------------------- */
int qpssvm_solver(const void* (*get_col)(long),
double *diag_H,
double *f,
double b,
uint16_T *I,
double *x,
long n,
long tmax,
double tolabs,
double tolrel,
long *ptr_t,
double **ptr_History,
long verb)
{
double *Hx;
double *d;
double *History;
double *col_u, *col_v;
double *tmp_ptr;
double LB;
double UB;
double tmp;
double yu;
double den1;
double num1;
double tau1;
double improv1;
double den2;
double num2;
double tau2;
double improv2;
double tmp_num;
double tmp_den;
double tau;
double delta;
double sumx;
long m;
long t;
long u;
long v;
long k;
long i, j;
long History_size;
int exitflag;
/* ------------------------------------------------------------
Initialization
------------------------------------------------------------ */
/* count cumber of constraints */
for( i=0, m=0; i < n; i++ ) m = MAX(m,I[i]);
/* alloc Hx [n x m] */
Hx = mxCalloc(m*n, sizeof(double));
if( Hx == NULL ) mexErrMsgTxt("Not enough memory.");
/* alloc History [2 x HISTORY_BUF] */
History_size = (tmax < HISTORY_BUF ) ? tmax+1 : HISTORY_BUF;
History = mxCalloc(History_size*2,sizeof(double));
if( History == NULL ) mexErrMsgTxt("Not enough memory.");
/* alloc d [n x 1] */
d = mxCalloc(n, sizeof(double));
if( d == NULL ) mexErrMsgTxt("Not enough memory.");
/* Hx = zeros(n,m);
for k=1:m,
inx = find(I==k);
Hx(:,k) = H(:,inx)*x(inx);
end
*/
for( i=0; i < n; i++ ) {
if( x[i] > 0 ) {
u = I[i]-1;
col_u = (double*)get_col(i);
for( j=0; j < n; j++ ) {
Hx[INDEX(j,u,n)] += col_u[j]*x[i];
}
}
}
/* d = sum(Hx,2) + f; */
for( i=0; i < n; i++ ) {
for( j=0; j < m; j++ ) {
d[i] += Hx[INDEX(i,j,n)];
}
d[i] += f[i];
}
/* UB = 0.5*x'*(f+d); */
/* LB = 0.5*x'*(f-d); */
for( i=0, UB = 0, LB=0; i < n; i++) {
UB += x[i]*(f[i]+d[i]);
LB += x[i]*(f[i]-d[i]);
}
UB = 0.5*UB;
LB = 0.5*LB;
/*
for k=1:m,
tmp = min(d(find(I==k)));
if tmp < 0, LB = LB + b*tmp; end
end
*/
for( i=0; i < m; i++ ) {
for( j=0, tmp = PLUS_INF; j < n; j++ ) {
if( I[j]-1 == i ) tmp = MIN(tmp, d[j]);
}
if( tmp < 0) LB += b*tmp;
}
exitflag = 0;
t = 0;
History[INDEX(0,0,2)] = LB;
History[INDEX(1,0,2)] = UB;
/* -- Main loop ---------------------------------------- */
while( (exitflag == 0) && (t < tmax))
{
t++;
exitflag = 1;
for( k=0; k < m; k++ )
{
/*
inx = find(I==k);
[tmp,u] = min(d(inx)); u = inx(u);
*/
for( i=0, tmp = PLUS_INF, delta = 0; i < n; i++ ) {
if( I[i]-1 == k) {
delta += x[i]*d[i];
if( tmp > d[i] ) {
tmp = d[i];
u = i;
}
}
}
/* if d(u) < 0, yu = b; else yu = 0; end */
if( d[u] < 0) yu = b; else yu = 0;
/* delta = x(inx)'*d(inx) - yu*d(u); */
delta -= yu*d[u];
if( delta > tolabs/m && delta > tolrel*ABS(UB)/m)
{
exitflag = 0;
col_u = (double*)get_col(u);
/* -- Kozinec like update ------ */
/*
y = x; y(inx) = 0; y(u) = yu;
% den1 = (x-y)'*H*(x-y)
den1 = (x(inx)-y(inx))'*(Hx(inx,k)-yu*H(inx,u));
num1 = (x-y)'*d;
*/
for( i=0, sumx = 0, den1 = 0, num1 = 0; i < n; i++ ) {
sumx += x[i];
if( i == u ) {
num1 += (x[i]-yu)*d[i];
den1 += (x[i]-yu)*(Hx[INDEX(i,k,n)]-yu*col_u[i]);
} else if( I[i]-1 == k) {
num1 += x[i]*d[i];
den1 += x[i]*(Hx[INDEX(i,k,n)]-yu*col_u[i]);
}
}
tau1 = MIN(1, num1/den1);
if( tau1 < 1 )
improv1 = num1*num1/(2*den1);
else {
/* Improv1 = 0.5*x'*(d+f) - 0.5*y'*(yu*H(:,u)+(d-f)-Hx(:,k)) - f'*y; */
for( i = 0, improv1 = 0; i < n; i++ ) {
if( i == u ) {
improv1 += 0.5*x[i]*(d[i]+f[i])
- 0.5*yu*(yu*col_u[i]+d[i]-f[i]-Hx[INDEX(i,k,n)]) - f[i]*yu;
} else if( I[i]-1 == k ) {
improv1 += 0.5*x[i]*(d[i]+f[i]);
} else {
improv1 += 0.5*x[i]*(d[i]+f[i])
- 0.5*x[i]*(yu*col_u[i]+d[i]-f[i]-Hx[INDEX(i,k,n)]) - f[i]*x[i];
}
}
}
/* -- MDM like update --------- */
improv2 = MINUS_INF;
if( sumx > 0) {
for(i = 0; i < n; i++ ) {
if( (I[i]-1 == k) && (i != u) && (x[i] > 0)) {
tmp_num = d[i] - d[u];
tmp_den = diag_H[u] - 2*col_u[i] + diag_H[i];
if( tmp_den > 0 ) {
tau = MIN(1,tmp_num/(x[i]*tmp_den));
if( tau < 1 ) {
tmp = tmp_num*tmp_num/(2*tmp_den);
} else {
/* tmp = x(i)*(d(i)-d(u))-0.5*x(i)^2*(H(u,u)-2*H(i,u)+H(i,i)); */
tmp = x[i]*tmp_num-0.5*x[i]*x[i]*tmp_den;
}
if( tmp > improv2 ) {
improv2 = tmp;
tau2 = tau;
v = i;
}
}
}
}
}
/* -- Apply the better line segment -------------- */
if( improv1 > improv2 ) {
/*
d = d + tau1*(yu*H(:,u)-Hx(:,k));
Hx(:,k) = Hx(:,k)*(1-tau1) + tau1*H(:,u)*yu;
x(setdif(inx,u)) = x(setdiff(inx,u))*(1-tau1);
x(u) = x(u)*(1-tau1) + yu*tau1;
*/
for( i = 0; i < n; i++ ) {
d[i] += tau1*(yu*col_u[i]-Hx[INDEX(i,k,n)]);
Hx[INDEX(i,k,n)] = Hx[INDEX(i,k,n)]*(1-tau1) + tau1*yu*col_u[i];
if( i == u)
x[i] = x[i]*(1-tau1) + tau1*yu;
else if( I[i]-1 == k )
x[i] = x[i]*(1-tau1);
}
}
else
{
col_v = (double*)get_col(v);
for( i = 0; i < n; i++ ) {
tmp = x[v]*tau2*(col_u[i]-col_v[i]);
Hx[INDEX(i,k,n)] += tmp;
d[i] += tmp;
}
x[u] += tau2*x[v];
x[v] -= tau2*x[v];
}
/* mexPrintf("t=%d,k=%d, u=%d, tau1=%f, den1=%f, num1=%f, delta=%f\n",
t,k,u,tau1,den1,num1,delta);*/
/* -- Update the upper bound ---------------------- */
for( i=0, UB = 0; i < n; i++) {
UB += x[i]*(f[i]+d[i]);
}
UB = 0.5*UB;
}
}
/* -- Computing LB --------------------------------------*/
/*
LB = 0.5*x'*(f-d); % LB = -0.5*x'*H*x;
for k=1:n,
tmp = min(d(find(I==k)));
if tmp < 0, LB = LB + b*tmp; end
end */
for( i=0, LB=0; i < n; i++) {
LB += 0.5*x[i]*(f[i]-d[i]);
}
for( i=0; i < m; i++ ) {
for( j=0, tmp = PLUS_INF; j < n; j++ ) {
if( I[j]-1 == i ) tmp = MIN(tmp, d[j]);
}
if( tmp < 0) LB += b*tmp;
}
/* Store LB and UB */
if( t < History_size ) {
History[INDEX(0,t,2)] = LB;
History[INDEX(1,t,2)] = UB;
}
else {
tmp_ptr = mxCalloc((History_size+HISTORY_BUF)*2,sizeof(double));
if( tmp_ptr == NULL ) mexErrMsgTxt("Not enough memory.");
for( i = 0; i < History_size; i++ ) {
tmp_ptr[INDEX(0,i,2)] = History[INDEX(0,i,2)];
tmp_ptr[INDEX(1,i,2)] = History[INDEX(1,i,2)];
}
tmp_ptr[INDEX(0,t,2)] = LB;
tmp_ptr[INDEX(1,t,2)] = UB;
History_size += HISTORY_BUF;
mxFree( History );
History = tmp_ptr;
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -