📄 svm.java
};
};
if(quadraticLossNeg){
for(pos_i=0;pos_i<working_set_size;pos_i++){
if(my_which_alpha[pos_i]){
(qp.H)[pos_i*(working_set_size+1)] += 1/Cneg;
(qp.u)[pos_i] = Double.MAX_VALUE;
};
};
};
if(quadraticLossPos){
for(pos_i=0;pos_i<working_set_size;pos_i++){
if(! my_which_alpha[pos_i]){
(qp.H)[pos_i*(working_set_size+1)] += 1/Cpos;
(qp.u)[pos_i] = Double.MAX_VALUE;
};
};
};
};
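// Editorial sketch, not part of the original source: for quadratic (L2) slack loss the
// penalty term can be folded into the dual QP by adding 1/C to the diagonal of the Hessian
// and dropping the upper bound on the corresponding alpha, which is exactly what the two
// loops above do for the alpha* (negative) and alpha (positive) variables. The hypothetical
// helper below only illustrates the index arithmetic H[k*(ws+1)] == H[k][k] for a row-major
// Hessian of size ws*ws.
private static void addQuadraticLossTerm(double[] H, double[] u, int ws, double C){
for(int k=0;k<ws;k++){
H[k*(ws+1)] += 1.0/C; // diagonal entry (k,k)
u[k] = Double.MAX_VALUE; // box constraint becomes inactive under L2 loss
};
};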
/**
* Initialises the working set
* @exception Exception on any error
*/
protected void init_working_set()
{
// calculate sum
int i,j;
project_to_constraint();
// skip kernel calculation as all alphas = 0
for(i=0; i<examples_total;i++){
sum[i] = 0;
at_bound[i] = 0;
};
// first working set is random
j=0;
i=0;
while((i<working_set_size) && (j < examples_total)){
working_set[i] = j;
if(is_alpha_neg(j)){
which_alpha[i] = true;
}
else{
which_alpha[i] = false;
};
i++;
j++;
};
update_working_set();
};
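// Editorial note, not part of the original source: at this point all alphas are zero, so the
// gradient cache sum[i] (which holds sum_j alpha_j*K(i,j)) can be initialised to zero without
// any kernel evaluations; the first working set simply takes the first working_set_size
// examples, and which_alpha[] records whether each selected variable is an alpha* (true) or
// an alpha (false).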
/**
* Calls the optimizer
*/
protected abstract void optimize();
/**
* Stores the optimizer results
*/
protected void put_optimizer_values()
{
// update nabla, sum, examples.
// sum[i] += (primal_j^*-primal_j-alpha_j^*+alpha_j)K(i,j)
// check for |nabla| < is_zero (nabla <-> nabla*)
int i=0;
int j=0;
int pos_i;
double the_new_alpha;
double[] kernel_row;
double alpha_diff;
double my_sum[] = sum;
pos_i=working_set_size;
while(pos_i>0){
pos_i--;
if(which_alpha[pos_i]){
the_new_alpha = primal[pos_i];
}
else{
the_new_alpha = -primal[pos_i];
};
// next three statements: keep this order!
i = working_set[pos_i];
alpha_diff = the_new_alpha-alphas[i];
alphas[i] = the_new_alpha;
if(alpha_diff != 0){
// update sum ( => nabla)
kernel_row = the_kernel.get_row(i);
for(j=examples_total-1;j>=0;j--){
my_sum[j] += alpha_diff*kernel_row[j];
};
};
};
};
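// Editorial sketch, not part of the original source: the loop above keeps the gradient cache
// sum[j] = sum_k alpha_k*K(j,k) up to date incrementally -- when a single alpha_i changes by
// alpha_diff, only alpha_diff*K(j,i) has to be added to each sum[j]. A hypothetical helper
// showing just that update (K is symmetric, so row i of the kernel serves as column i):
private static void updateGradientCache(double[] sum, double[] kernel_row, double alpha_diff){
for(int j=sum.length-1;j>=0;j--){
sum[j] += alpha_diff*kernel_row[j];
};
};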
/**
* Checks if the optimization converged
* @return true if the optimization converged
*/
protected boolean convergence()
{
double the_lambda_eq = 0;
int total = 0;
double alpha_sum=0;
double alpha=0;
int i;
boolean result = true;
// actual convergence-test
total = 0; alpha_sum=0;
for(i=0;i<examples_total;i++){
alpha = alphas[i];
alpha_sum += alpha;
if((alpha>is_zero) && (alpha-Cneg < -is_zero)){
// alpha^* = - nabla
the_lambda_eq += -nabla(i); //all_ys[i]-epsilon_neg-sum[i];
total++;
}
else if((alpha<-is_zero) && (alpha+Cpos > is_zero)){
// alpha = nabla
the_lambda_eq += nabla(i); //all_ys[i]+epsilon_pos-sum[i];
total++;
};
};
logln(4,"lambda_eq = "+(the_lambda_eq/total));
if(total>0){
lambda_eq = the_lambda_eq / total;
}
else{
// keep WS lambda_eq
lambda_eq = lambda_WS; //(lambda_eq+4*lambda_WS)/5;
logln(4,"*** no SVs in convergence(), lambda_eq = "+lambda_eq+".");
};
if(target_count>2){
// estimate lambda from WS
if(target_count>20){
// desperate attempt to get good lambda!
lambda_eq = ((40-target_count)*lambda_eq + (target_count-20)*lambda_WS)/20;
logln(5,"Re-Re-calculated lambda from WS: "+lambda_eq);
if(target_count>40){
// really desperate, kick one example out!
i = working_set[target_count%working_set_size];
if(is_alpha_neg(i)){
lambda_eq = -nabla(i);
}
else{
lambda_eq = nabla(i);
};
logln(5,"set lambda_eq to nabla("+i+"): "+lambda_eq);
};
}
else{
lambda_eq = lambda_WS;
logln(5,"Re-calculated lambda_eq from WS: "+lambda_eq);
};
};
// check linear constraint
if(java.lang.Math.abs(alpha_sum+sum_alpha) > convergence_epsilon){
// equality constraint violated
logln(4,"No convergence: equality constraint violated: |"+(alpha_sum+sum_alpha)+"| >> 0");
project_to_constraint();
result = false;
};
i=0;
while((i<examples_total) && (result != false)){
if(lambda(i)>=-convergence_epsilon){
i++;
}
else{
result = false;
};
};
return result;
};
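// Editorial note, not part of the original source: convergence() implements the KKT-based
// stopping test. lambda_eq, the multiplier of the equality constraint sum_i alpha_i = 0, is
// estimated by averaging -nabla(i) over the free alpha* variables and nabla(i) over the free
// alpha variables; if no such example exists, the working-set estimate lambda_WS is kept, and
// after repeated stalls (target_count) the estimate is blended with lambda_WS or taken from a
// single working-set example. The run is considered converged when |sum_i alpha_i + sum_alpha|
// is within convergence_epsilon and lambda(i) >= -convergence_epsilon holds for every example.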
protected abstract double nabla(int i);
/**
* Lagrangian multiplier of variable i
* @param i variable index
* @return lambda
*/
protected double lambda(int i)
{
double alpha;
double result;
if(is_alpha_neg(i)){
result = - java.lang.Math.abs(nabla(i)+lambda_eq);
}
else{
result = - java.lang.Math.abs(nabla(i)-lambda_eq);
};
// default = not at bound
alpha=alphas[i];
if(alpha>is_zero){
// alpha*
if(alpha-Cneg >= - is_zero){
// upper bound active
result = -lambda_eq-nabla(i);
};
}
else if(alpha >= -is_zero){
// lower bound active
if(is_alpha_neg(i)){
result = nabla(i) + lambda_eq;
}
else{
result = nabla(i)-lambda_eq;
};
}
else if(alpha+Cpos <= is_zero){
// upper bound active
result = lambda_eq - nabla(i);
};
return result;
};
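// Editorial note, not part of the original source: lambda(i) returns (an estimate of) the
// Lagrange multiplier relevant for example i. For a variable strictly between its bounds the
// KKT conditions require nabla(i) = -lambda_eq resp. nabla(i) = +lambda_eq, depending on
// whether the variable is an alpha* or an alpha, so the negated absolute deviation is
// returned; if a lower or upper bound is active, the multiplier of that bound constraint is
// returned instead. In either case a value of at least -convergence_epsilon means example i
// satisfies its optimality condition.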
protected boolean feasible(int i)
{
boolean is_feasible = true;
double alpha = alphas[i];
double the_lambda = lambda(i);
if(alpha-Cneg >= - is_zero){
// alpha* at upper bound
if(the_lambda >= 0){
at_bound[i]++;
if(at_bound[i] == shrink_const) to_shrink++;
}
else{
at_bound[i] = 0;
};
}
else if((alpha<=is_zero) && (alpha >= -is_zero)){
// lower bound active
if(the_lambda >= 0){
at_bound[i]++;
if(at_bound[i] == shrink_const) to_shrink++;
}
else{
at_bound[i] = 0;
};
}
else if(alpha+Cpos <= is_zero){
// alpha at upper bound
if(the_lambda >= 0){
at_bound[i]++;
if(at_bound[i] == shrink_const) to_shrink++;
}
else{
at_bound[i] = 0;
};
}
else{
// not at bound
at_bound[i] = 0;
};
if((the_lambda >= feasible_epsilon) || (at_bound[i] >= shrink_const)){
is_feasible = false;
};
return is_feasible;
};
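// Editorial note, not part of the original source: feasible(i) doubles as the shrinking
// heuristic. Whenever the variable of example i sits at one of its bounds with a
// non-negative multiplier, at_bound[i] is incremented; once it has stayed there for
// shrink_const consecutive checks the example becomes a shrinking candidate (to_shrink++)
// and is no longer offered to the working-set selection, as is any example whose multiplier
// already exceeds feasible_epsilon (i.e. clearly satisfies its optimality condition).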
protected abstract boolean is_alpha_neg(int i);
/**
* log the output plus newline
* @param level warning level
* @param message Message text
*/
protected void logln(int level, String message) {
LogService.logMessage(message, YALE_VERBOSITY[level-1]);
};
/**
* Predicts values on the test set with the model
*/
public void predict(ExampleSet to_predict)
{
int i;
double prediction;
Example example;
//int size = the_examples.count_examples(); // IM 04/02/12
int size = to_predict.count_examples();
//System.out.println("Size: " + size);
for(i=0;i<size;i++){
example = to_predict.get_example(i);
prediction = predict(example);
to_predict.set_y(i,prediction);
};
logln(4,"Prediction generated");
};
private double predict(int i) {
return predict(the_examples.get_example(i));
}
/**
* Predicts the function value of a single example
*/
protected double predict(Example example)
{
int i;
int[] sv_index;
double[] sv_att;
double the_sum=the_examples.get_b();
double alpha;
for(i=0;i<examples_total;i++){
alpha = alphas[i];
if(alpha != 0){
sv_index = the_examples.index[i];
sv_att = the_examples.atts[i];
the_sum += alpha*the_kernel.calculate_K(sv_index,sv_att,example.index,example.att);
};
};
return the_sum;
};
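// Editorial note, not part of the original source: the method above evaluates the SVM
// prediction function f(x) = b + sum_i alpha_i*K(x_i, x). Only examples with alpha_i != 0
// (the support vectors) contribute, which is why zero alphas are skipped.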
/**
* Checks internal variables; for debugging only
*/
protected void check()
{
double tsum;
int i,j;
double s=0;
for(i=0; i<examples_total;i++){
s += alphas[i];
tsum = 0;
for(j=0; j<the_examples.count_examples();j++){
tsum += alphas[j]*the_kernel.calculate_K(i,j);
};
if(Math.abs(tsum-sum[i]) > is_zero){
logln(1,"ERROR: sum["+i+"] off by "+(tsum-sum[i]));
//throw(new Exception("ERROR: sum["+i+"] off by "+(tsum-sum[i])));
//System.exit(1);
};
};
if(Math.abs(s+sum_alpha) > is_zero){
logln(1,"ERROR: sum_alpha is off by "+(s+sum_alpha));
// throw(new Exception("ERROR: sum_alpha is off by "+(s+sum_alpha)));
//System.exit(1);
};
};
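// Editorial note, not part of the original source: check() verifies the two invariants the
// incremental updates are supposed to maintain -- sum[i] == sum_j alpha_j*K(i,j) for every
// example and sum_i alpha_i == -sum_alpha (the bookkeeping of the linear equality
// constraint). Violations are only logged because the method is intended for debugging.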
/**
* Returns a double array of estimated performance values: accuracy, precision, and recall.
* Works only for classification SVMs.
*/
public double[] getXiAlphaEstimation(Kernel kernel) {
double r_delta = 0.0d;
for (int j = 0; j < examples_total; j++) {
double norm_x = kernel.calculate_K(j,j);
for (int i = 0; i < examples_total; i++) {
double r_current = norm_x - kernel.calculate_K(i,j);
if (r_current > r_delta){
r_delta = r_current;
}
}
}
int total_pos = 0;
int total_neg = 0;
int estim_pos = 0;
int estim_neg = 0;
double xi = 0.0d;
for(int i = 0; i < examples_total; i++){
double alpha = the_examples.get_alpha(i);
double prediction = predict(i);
double y = the_examples.get_y(i);
if(y>0){
if(prediction>1){
xi=0;
}
else{
xi=1-prediction;
};
if(2*alpha*r_delta+xi >= 1){
estim_pos++;
};
total_pos++;
} else{
if(prediction<-1){
xi=0;
}
else{
xi=1+prediction;
};
if(2*(-alpha)*r_delta+xi >= 1){
estim_neg++;
};
total_neg++;
};
};
//System.out.println("estim_pos: " + estim_pos + ", estim_neg: " + estim_neg + ", total_pos: " + total_pos + ", total_neg: " + total_neg);
double[] result = new double[3];
result[0] = 1.0d - (double)(estim_pos+estim_neg) / (double)(total_pos+total_neg);
result[1] = (double)(total_pos-estim_pos) / (double)(total_pos-estim_pos+estim_neg);
result[2] = 1.0d - (double)estim_pos / (double)total_pos;
return result;
}
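// Editorial note, not part of the original source: the method above computes xi-alpha
// estimates in the style of Joachims' xi-alpha bound. r_delta is the largest observed value
// of K(j,j)-K(i,j) over the training set, xi is the slack of each training example, and an
// example counts as a potential leave-one-out error when 2*|alpha_i|*r_delta + xi_i >= 1.
// A hypothetical usage sketch:
//
// double[] perf = getXiAlphaEstimation(the_kernel);
// logln(1,"xi-alpha estimates: accuracy="+perf[0]+", precision="+perf[1]+", recall="+perf[2]);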
};