📄 quadraticproblemsmo.java
字号:
/*
* YALE - Yet Another Learning Environment
* Copyright (C) 2001-2004
* Simon Fischer, Ralf Klinkenberg, Ingo Mierswa,
* Katharina Morik, Oliver Ritthoff
* Artificial Intelligence Unit
* Computer Science Department
* University of Dortmund
* 44221 Dortmund, Germany
* email: yale-team@lists.sourceforge.net
* web: http://yale.cs.uni-dortmund.de/
*
* This program is free software; you can redistribute it and/or
* modify it under the terms of the GNU General Public License as
* published by the Free Software Foundation; either version 2 of the
* License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful, but
* WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
* USA.
*/
package edu.udo.cs.mySVM.Optimizer;
import java.lang.Math;
import java.lang.Number;
/**
 * Pairwise (SMO-style) solver for the quadratic problem stored in
 * {@link quadraticProblem}.
 *
 * <p>The fields <code>n</code>, <code>x</code>, <code>H</code>, <code>c</code>,
 * <code>A</code>, <code>l</code>, <code>u</code>, <code>lambda_eq</code> and
 * <code>max_allowed_error</code> are not declared here; they are inherited from
 * the superclass (not shown in this file). As used by this class:
 * <ul>
 * <li><code>H</code> is a flat array addressed as <code>H[row*n+col]</code>
 *     (diagonal element <code>H[i*(n+1)]</code>);</li>
 * <li><code>sum[i]+c[i]</code> is treated as the gradient component of variable
 *     <code>i</code>;</li>
 * <li><code>l[i]</code> and <code>u[i]</code> are the lower and upper bound of
 *     <code>x[i]</code>;</li>
 * <li>only the sign of <code>A[i]</code> is inspected.</li>
 * </ul>
 * NOTE(review): a step changes <code>x[i]</code> by <code>+/-t</code> and
 * <code>x[j]</code> by <code>-/+t</code>, which keeps <code>sum_k A[k]*x[k]</code>
 * constant only if all <code>|A[k]|</code> are equal (e.g. +1/-1) - confirm
 * against the superclass. The sum updates read row <code>i</code> of
 * <code>H</code> where column <code>i</code> would be expected, so
 * <code>H</code> is presumably symmetric - confirm.
 */
public class quadraticProblemSMO extends quadraticProblem
{
    /**
     * Cached matrix-vector product: <code>sum[k] = sum_j H[k*n+j]*x[j]</code>.
     * Rebuilt at the start of {@link #solve()} and updated incrementally by
     * {@link #minimize_ij(int, int)}.
     */
    protected double[] sum;

    /**
     * Tolerance used to snap a variable onto a bound and to decide whether a
     * pair step changed anything.
     */
    protected double is_zero;

    /** Limit for the number of passes of the main loop in {@link #solve()}. */
    protected int max_iteration;

    /**
     * Creates a solver with the defaults <code>is_zero = 1e-10</code>,
     * <code>max_allowed_error = 1e-3</code> and <code>max_iteration = 10000</code>.
     */
    public quadraticProblemSMO() {
        this(1e-10, 1e-3, 10000);
    }

    /**
     * Creates a solver with explicit settings.
     *
     * @param is_zero           tolerance for bound snapping and progress detection
     * @param max_allowed_error largest optimality violation accepted by {@link #solve()}
     * @param max_iteration     limit for the main loop of {@link #solve()}
     */
    public quadraticProblemSMO(double is_zero, double max_allowed_error, int max_iteration) {
        this.is_zero = is_zero;
        this.max_allowed_error = max_allowed_error;
        this.max_iteration = max_iteration;
    }

    /**
     * Sets the problem size in the superclass and allocates the {@link #sum}
     * cache with <code>n</code> entries. <code>n</code> is read after the
     * superclass call (presumably it then equals <code>new_n</code>).
     *
     * @param new_n number of variables
     */
    public void set_n(int new_n) {
        super.set_n(new_n);
        sum = new double[n];
    }

    /**
     * Shared body of {@link #x2tox1} and {@link #x1tox2}: optionally negates
     * <code>value</code>, then adds <code>b</code> if <code>a &gt; 0</code> and
     * subtracts it otherwise.
     */
    private static double shift_by_sign(double value, boolean negate, double a, double b) {
        double result;
        if (negate) {
            result = -value;
        } else {
            result = value;
        }
        if (a > 0) {
            result += b;
        } else {
            result -= b;
        }
        return result;
    }

    /**
     * Returns <code>(id ? -x2 : x2) + b</code> if <code>A1 &gt; 0</code>,
     * otherwise <code>(id ? -x2 : x2) - b</code>.
     */
    protected final double x2tox1(double x2, boolean id, double A1, double b) {
        return shift_by_sign(x2, id, A1, b);
    }

    /**
     * Returns <code>(id ? -x1 : x1) + b</code> if <code>A2 &gt; 0</code>,
     * otherwise <code>(id ? -x1 : x1) - b</code>.
     */
    protected final double x1tox2(double x1, boolean id, double A2, double b) {
        return shift_by_sign(x1, id, A2, b);
    }

    /**
     * Optimises the two variables <code>x[i]</code> and <code>x[j]</code> along
     * the direction <code>x1 += sign(A1)*t</code>, <code>x2 -= sign(A2)*t</code>
     * for the two-variable function
     * <code>H1*x1^2 + H2*x2^2 + c0*x1*x2 + c1*x1 + c2*x2</code>,
     * keeping both variables inside their bounds, and writes the result back
     * into <code>x[i]</code> and <code>x[j]</code>.
     *
     * @param i  index of the first variable
     * @param j  index of the second variable
     * @param H1 quadratic coefficient of x1
     * @param H2 quadratic coefficient of x2
     * @param c0 coefficient of the mixed term x1*x2
     * @param c1 linear coefficient of x1
     * @param c2 linear coefficient of x2
     * @param A1 constraint coefficient of x1 (only its sign is used)
     * @param A2 constraint coefficient of x2 (only its sign is used)
     * @param l1 lower bound of x1
     * @param l2 lower bound of x2
     * @param u1 upper bound of x1
     * @param u2 upper bound of x2
     */
    protected final void simple_solve(int i, int j,
                                      double H1, double H2,
                                      double c0,
                                      double c1, double c2,
                                      double A1, double A2,
                                      double l1, double l2,
                                      double u1, double u2)
    {
        double x1 = x[i];
        double x2 = x[j];

        // Second derivative of the objective with respect to t.
        double den = H1+H2;
        if (((A1 > 0) && (A2 > 0)) ||
            ((A1 < 0) && (A2 < 0))) {
            den -= c0;
        } else {
            den += c0;
        }
        den *= 2;

        // Negated first derivative with respect to t at t = 0.
        double num = -2*H1*x1-x2*c0-c1;
        if (A1 < 0) {
            num = -num;
        }
        double grad2 = 2*H2*x2+x1*c0+c2;
        if (A2 > 0) {
            num += grad2;
        } else {
            num -= grad2;
        }

        // Interval [lo, up] of step sizes that keep both variables in bounds.
        double lo;
        double up;
        if (A1 > 0) {
            lo = l1-x1;
            up = u1-x1;
        } else {
            lo = x1-u1;
            up = x1-l1;
        }
        if (A2 < 0) {
            if (l2-x2 > lo) lo = l2-x2;
            if (u2-x2 < up) up = u2-x2;
        } else {
            if (x2-l2 < up) up = x2-l2;
            if (x2-u2 > lo) lo = x2-u2;
        }

        double t;
        if (den != 0) {
            // Stationary point of the parabola, clipped to the feasible interval.
            t = num/den;
            if (t < lo) {
                t = lo;
            }
            if (t > up) {
                t = up;
            }
        } else {
            // den = 0 => linear target function => set x at bound.
            // A positive slope (num < 0) means the smallest step is best.
            if (num < 0) {
                t = lo;
            } else {
                t = up;
            }
        }

        // calc new x from t
        if (A1 > 0) {
            x1 += t;
        } else {
            x1 -= t;
        }
        if (A2 > 0) {
            x2 -= t;
        } else {
            x2 += t;
        }

        // Snap values within is_zero of a bound exactly onto that bound.
        if (x1-l1 <= is_zero) {
            x1 = l1;
        } else if (x1-u1 >= -is_zero) {
            x1 = u1;
        }
        if (x2-l2 <= is_zero) {
            x2 = l2;
        } else if (x2-u2 >= -is_zero) {
            x2 = u2;
        }

        x[i] = x1;
        x[j] = x2;
    }

    /**
     * Minimises the objective over the pair <code>(x[i], x[j])</code> using
     * {@link #simple_solve}. A step that would increase the objective is undone.
     * An accepted step is propagated into the {@link #sum} cache.
     *
     * @param i index of the first variable
     * @param j index of the second variable
     * @return <code>true</code> if at least one of the two variables was changed
     *         by more than {@link #is_zero}; <code>false</code> otherwise, and
     *         always <code>false</code> for <code>i == j</code>
     */
    protected final boolean minimize_ij(int i, int j) {
        // A pair step needs two distinct variables. With i == j the update would
        // write x[i] twice and could move a single variable on its own.
        if (i == j) {
            return false;
        }

        double half_ii = H[i*(n+1)]/2;
        double half_jj = H[j*(n+1)]/2;
        double h_ij = H[i*n+j];

        // Linear coefficients of x[i] and x[j] with every other variable fixed:
        // take the cached H*x entry, remove the contributions of x[i] and x[j]
        // themselves, then add c.
        double sum_i = sum[i];
        double sum_j = sum[j];
        sum_i -= H[i*(n+1)]*x[i];
        sum_i -= H[i*n+j]*x[j];
        sum_j -= H[j*n+i]*x[i];
        sum_j -= H[j*(n+1)]*x[j];
        sum_i += c[i];
        sum_j += c[j];

        double old_xi = x[i];
        double old_xj = x[j];

        simple_solve(i, j, half_ii, half_jj, h_ij,
                     sum_i, sum_j, A[i], A[j],
                     l[i], l[j], u[i], u[j]);

        // Old objective minus new objective, restricted to the two variables.
        double target = (old_xi-x[i])*(half_ii*(old_xi+x[i])+sum_i)
                       +(old_xj-x[j])*(half_jj*(old_xj+x[j])+sum_j)
                       +h_ij*(old_xi*old_xj-x[i]*x[j]);

        double delta_i;
        double delta_j;
        if (target < 0) {
            // The step increased the objective: undo it.
            x[i] = old_xi;
            x[j] = old_xj;
            delta_i = 0;
            delta_j = 0;
        } else {
            delta_i = old_xi-x[i];
            delta_j = old_xj-x[j];
            for (int k = 0; k < n; k++) {
                sum[k] -= H[i*n+k]*delta_i;
                sum[k] -= H[j*n+k]*delta_j;
            }
        }
        return (Math.abs(delta_i) > is_zero) || (Math.abs(delta_j) > is_zero);
    }

    /**
     * Recomputes <code>lambda_eq</code> from the current <code>x</code> and
     * {@link #sum}.
     *
     * <p>If some variables lie strictly between their bounds, the result is the
     * mean over those variables of <code>-(sum[i]+c[i])</code> for
     * <code>A[i] &gt; 0</code> and <code>+(sum[i]+c[i])</code> otherwise.
     * If none do, an interval <code>[lambda_min, lambda_max]</code> is collected
     * from the variables at their bounds. The result is then the midpoint when
     * both ends are finite, the finite end when only one is, and
     * <code>lambda_max</code> when <code>lambda_min</code> was never set.
     */
    protected final void calc_lambda_eq() {
        double lambda_eq_sum = 0;
        int count = 0;
        int i;
        for (i = 0; i < n; i++) {
            if ((x[i] > l[i]) && (x[i] < u[i])) {
                if (A[i] > 0) {
                    lambda_eq_sum -= (sum[i]+c[i]);
                } else {
                    lambda_eq_sum += sum[i]+c[i];
                }
                count++;
            }
        }

        if (count > 0) {
            lambda_eq_sum /= (double)count;
        } else {
            double lambda_min = Double.NEGATIVE_INFINITY;
            double lambda_max = Double.POSITIVE_INFINITY;
            double nabla;
            for (i = 0; i < n; i++) {
                nabla = sum[i]+c[i];
                if (x[i] <= l[i]) {
                    // lower bound
                    if (A[i] > 0) {
                        if (-nabla > lambda_min) {
                            lambda_min = -nabla;
                        }
                    } else {
                        if (nabla < lambda_max) {
                            lambda_max = nabla;
                        }
                    }
                } else {
                    // upper bound
                    if (A[i] > 0) {
                        if (-nabla < lambda_max) {
                            lambda_max = -nabla;
                        }
                    } else {
                        if (nabla > lambda_min) {
                            lambda_min = nabla;
                        }
                    }
                }
            }

            if (lambda_min > Double.NEGATIVE_INFINITY) {
                if (lambda_max < Double.POSITIVE_INFINITY) {
                    lambda_eq_sum = (lambda_max+lambda_min)/2;
                } else {
                    lambda_eq_sum = lambda_min;
                }
            } else {
                lambda_eq_sum = lambda_max;
            }
        }
        lambda_eq = lambda_eq_sum;
    }

    /**
     * Per-variable estimate that {@link #pick_partner} compares against the
     * estimate of the worst violator: <code>-(sum[k]+c[k])</code> or
     * <code>sum[k]+c[k]</code>, with the sign chosen from <code>A[k]</code>.
     */
    private double multiplier_estimate(int k) {
        double estimate;
        if (x[k] <= l[k]) {
            // at lower bound
            estimate = -sum[k]-c[k];
            if (A[k] < 0) {
                estimate = -estimate;
            }
        } else {
            // at upper bound or between bounds
            estimate = sum[k]+c[k];
            if (A[k] > 0) {
                estimate = -estimate;
            }
        }
        return estimate;
    }

    /**
     * Chooses the second variable of the working pair. Among all indices other
     * than <code>max_i</code>, the one whose estimate differs most from
     * <code>max_lambda_eq</code> is returned. If <code>x[max_i]</code> is at a
     * bound, candidates sitting at the same kind of bound are skipped. If no
     * candidate qualifies, <code>(max_i+1)%n</code> is returned.
     */
    private int pick_partner(int max_i, double max_lambda_eq) {
        boolean n_lo = !(x[max_i] <= l[max_i]); // not at lower bound
        boolean n_up = !(x[max_i] >= u[max_i]); // not at upper bound

        int partner = (max_i+1)%n;
        double max_diff = -1;
        for (int k = 0; k < n; k++) {
            if ((k != max_i) &&
                (n_up || (x[k] < u[k])) &&
                (n_lo || (x[k] > l[k]))) {
                double this_diff = Math.abs(multiplier_estimate(k) - max_lambda_eq);
                if (this_diff > max_diff) {
                    max_diff = this_diff;
                    partner = k;
                }
            }
        }
        return partner;
    }

    /**
     * Runs the pairwise optimisation until the largest optimality violation is
     * at most <code>max_allowed_error</code>, no pair yields progress any more,
     * or {@link #max_iteration} is exceeded. The solution is left in
     * <code>x</code>, the multiplier estimate in <code>lambda_eq</code>.
     *
     * @return <code>0</code> if the error criterion was met; otherwise a
     *         positive value (the number of consecutive unsuccessful pair
     *         searches, plus one if the iteration limit was hit)
     */
    public final int solve() {
        int error = 0;
        int i;
        int j;

        // Build the H*x cache from scratch.
        for (i = 0; i < n; i++) {
            sum[i] = 0;
            for (j = 0; j < n; j++) {
                sum[i] += H[i*n+j]*x[j];
            }
        }

        int iteration = 0;
        double this_error;
        double this_lambda_eq;
        double max_lambda_eq = 0;
        double max_error = Double.NEGATIVE_INFINITY;
        int max_i = 0;
        int old_max_i = -1;

        calc_lambda_eq();
        while (true) {
            if (0 == error) {
                // Heuristic: pick the variable with the largest violation,
                // skipping the one picked in the previous heuristic pass.
                max_error = Double.NEGATIVE_INFINITY;
                max_i = 0;
                for (i = 0; i < n; i++) {
                    if (x[i] <= l[i]) {
                        // at lower bound
                        this_error = -sum[i]-c[i];
                        if (A[i] > 0) {
                            this_lambda_eq = this_error;
                            this_error -= lambda_eq;
                        } else {
                            this_lambda_eq = -this_error;
                            this_error += lambda_eq;
                        }
                    } else {
                        // at upper bound or between bounds
                        this_error = sum[i]+c[i];
                        if (A[i] > 0) {
                            this_lambda_eq = -this_error;
                            this_error += lambda_eq;
                        } else {
                            this_lambda_eq = this_error;
                            this_error -= lambda_eq;
                        }
                        // Strictly between bounds a deviation in either
                        // direction counts as an error.
                        if (!(x[i] >= u[i]) && (this_error < 0)) {
                            this_error = -this_error;
                        }
                    }
                    if ((this_error > max_error) && (old_max_i != i)) {
                        max_i = i;
                        max_error = this_error;
                        max_lambda_eq = this_lambda_eq;
                    }
                }
                old_max_i = max_i;
            } else {
                // error != 0: the heuristic didn't work, try the next index.
                // max_error and max_lambda_eq keep their previous values.
                max_i = (max_i+1)%n;
            }

            // problem solved?
            if (max_error <= max_allowed_error) {
                error = 0;
                break;
            }

            // find element with maximal diff to max_i
            int min_i = pick_partner(max_i, max_lambda_eq);

            // optimize: on failure cycle through the other indices, making at
            // most n attempts in total
            int it = 1;
            while ((!minimize_ij(min_i, max_i)) && (it < n)) {
                it++;
                min_i = (min_i+1)%n;
                if (min_i == max_i) {
                    min_i = (min_i+1)%n;
                }
            }

            // NOTE(review): it == n is also reached when the n-th attempt
            // succeeds, so that case is counted as a failure too - confirm
            // whether this is intended.
            if (it == n) {
                error++;
                if (error >= n) {
                    break;
                }
            } else {
                error = 0;
            }

            calc_lambda_eq();

            // time up?
            iteration++;
            if (iteration > max_iteration) {
                error += 1;
                break;
            }
        }
        return error;
    }
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -