📄 smoreg.java
字号:
return "Whether lower order polyomials are also used (only "
+ "available for non-linear polynomial kernels).";
}
/**
 * Reports whether lower-order polynomial terms are included in the kernel.
 *
 * @return the current value of the lowerOrder flag
 */
public boolean getLowerOrderTerms() {
  final boolean useLowerOrder = m_lowerOrder;
  return useLowerOrder;
}
/**
 * Requests that lower-order polynomial terms be used (or not). The request is
 * overridden and the flag forced to false whenever the kernel is linear
 * (exponent 1.0) or an RBF kernel is selected.
 *
 * @param v the requested value for lowerOrder
 */
public void setLowerOrderTerms(boolean v) {
  boolean linearOrRbf = m_useRBF || m_exponent == 1.0;
  m_lowerOrder = !linearOrRbf && v;
}
/**
 * Disables the sanity checks (missing values, etc.). Use with caution.
 */
public void turnChecksOff() {
  this.m_checksTurnedOff = true;
}
/**
 * Re-enables the sanity checks (missing values, etc.).
 */
public void turnChecksOn() {
  this.m_checksTurnedOff = false;
}
/**
 * Prints out the classifier. The output contains the kernel in use, then
 * either the attribute weights (linear machine) or the support vector
 * expansion. It ends with the bias term and kernel-evaluation statistics.
 *
 * @return a description of the classifier as a string; "SMOreg : No model
 *         built yet." if neither m_alpha nor m_sparseWeights is set, or
 *         "Can't print the classifier." if anything throws while formatting
 */
public String toString() {
  if ((m_alpha == null) && (m_sparseWeights == null)) {
    return "SMOreg : No model built yet.";
  }
  StringBuilder text = new StringBuilder();
  int printed = 0;
  try {
    text.append("SMOreg\n\n");
    text.append("Kernel used : \n");
    if (m_useRBF) {
      text.append(" RBF kernel : K(x,y) = e^-(" + m_gamma + "* <x-y,x-y>^2)");
    } else if (m_exponent == 1) {
      text.append(" Linear Kernel : K(x,y) = <x,y>");
    } else if (m_featureSpaceNormalization) {
      if (m_lowerOrder) {
        text.append(" Normalized Poly Kernel with lower order : K(x,y) = (<x,y>+1)^" + m_exponent + "/" +
            "((<x,x>+1)^" + m_exponent + "*" + "(<y,y>+1)^" + m_exponent + ")^(1/2)");
      } else {
        text.append(" Normalized Poly Kernel : K(x,y) = <x,y>^" + m_exponent + "/" + "(<x,x>^" +
            m_exponent + "*" + "<y,y>^" + m_exponent + ")^(1/2)");
      }
    } else if (m_lowerOrder) {
      text.append(" Poly Kernel with lower order : K(x,y) = (<x,y> + 1)^" + m_exponent);
    } else {
      text.append(" Poly Kernel : K(x,y) = <x,y>^" + m_exponent);
    }
    text.append("\n\n");

    // Label naming the filter applied to the data; used for the class and
    // for each attribute so the coefficients are read in the filtered space.
    String trans = "";
    if (m_filterType == FILTER_STANDARDIZE) {
      trans = "(standardized) ";
    } else if (m_filterType == FILTER_NORMALIZE) {
      trans = "(normalized) ";
    }

    if (!m_useRBF && m_exponent == 1.0) {
      // Linear machine: print the weight vector instead of support vectors.
      text.append("Machine Linear: showing attribute weights, ");
      text.append("not support vectors.\n");
      // The weight vector is assumed to be stored in sparse format
      // because the classifier has been built.
      text.append(trans + m_data.classAttribute().name() + " =\n");
      for (int i = 0; i < m_sparseWeights.length; i++) {
        if (m_sparseIndices[i] == (int) m_classIndex) {
          continue; // skip the class attribute itself
        }
        text.append(printed > 0 ? " + " : " ");
        text.append(Utils.doubleToString(m_sparseWeights[i], 12, 4) + " * ");
        text.append(trans);
        if (!m_checksTurnedOff) {
          text.append(m_data.attribute(m_sparseIndices[i]).name() + "\n");
        } else {
          text.append("attribute with index " + m_sparseIndices[i] + "\n");
        }
        printed++;
      }
    } else {
      text.append("Support Vector Expansion :\n");
      text.append(trans + m_data.classAttribute().name() + " =\n");
      printed = 0;
      for (int i = 0; i < m_alpha.length; i++) {
        double val = m_alpha[i] - m_alpha_[i];
        if (Math.abs(val) < 1e-4) {
          continue; // negligible coefficient: not reported as a support vector
        }
        text.append(printed > 0 ? " + " : " ");
        text.append(Utils.doubleToString(val, 12, 4) + " * K[X(" + i + "), X]\n");
        printed++;
      }
    }

    if (m_b > 0) {
      text.append(" + " + Utils.doubleToString(m_b, 12, 4));
    } else {
      text.append(" - " + Utils.doubleToString(-m_b, 12, 4));
    }
    if (m_useRBF || m_exponent != 1.0) {
      text.append("\n\nNumber of support vectors: " + printed);
    }

    int numEval = 0;
    int numCacheHits = -1;
    if (m_kernel != null) {
      numEval = m_kernel.numEvals();
      numCacheHits = m_kernel.numCacheHits();
    }
    text.append("\n\nNumber of kernel evaluations: " + numEval);
    if (numCacheHits >= 0 && numEval > 0) {
      // Fraction of kernel requests served from the cache. Must be computed
      // in floating point: int division would truncate the ratio to 0 or 1.
      // The denominator is widened to double so the sum cannot overflow.
      double hitRatio = numCacheHits / ((double) numCacheHits + numEval);
      text.append(" (" + Utils.doubleToString(hitRatio * 100, 7, 3) + "% cached)");
    }
  } catch (Exception e) {
    return "Can't print the classifier.";
  }
  return text.toString();
}
/**
 * Command-line entry point for testing this class. It evaluates a new
 * SMOreg using the given options and prints the resulting report. Any
 * failure is reported on standard error.
 *
 * @param argv the options forwarded to Evaluation.evaluateModel
 */
public static void main(String[] argv) {
  try {
    Classifier scheme = new SMOreg();
    System.out.println(Evaluation.evaluateModel(scheme, argv));
  } catch (Exception e) {
    e.printStackTrace();
    System.err.println(e.getMessage());
  }
}
/**
 * Debugging helper: evaluates the dual objective for the current multipliers,
 * i.e. -0.5 * sum_i sum_j (a_i - a*_i)(a_j - a*_j) K(i,j)
 * plus sum_i [ y_i (a_i - a*_i) - epsilon (a_i + a*_i) ],
 * where a = m_alpha and a* = m_alpha_.
 *
 * @return the value of the objective function
 * @throws Exception if a kernel evaluation fails
 */
protected double objFun() throws Exception {
  final int count = m_alpha.length;
  double quadratic = 0;
  double linear = 0;
  for (int row = 0; row < count; row++) {
    double rowDiff = m_alpha[row] - m_alpha_[row];
    for (int col = 0; col < count; col++) {
      quadratic += rowDiff * (m_alpha[col] - m_alpha_[col])
          * m_kernel.eval(row, col, m_data.instance(row));
    }
    linear += m_data.instance(row).classValue() * rowDiff
        - m_epsilon * (m_alpha[row] + m_alpha_[row]);
  }
  return -0.5 * quadratic + linear;
}
/**
 * Debugging helper: evaluates the same dual objective as objFun(), but with
 * the multiplier pairs of indices i1 and i2 replaced by the candidate values
 * (alpha1, alpha1_) and (alpha2, alpha2_). If i1 == i2, the i1 values win.
 * The stored multipliers are not modified.
 *
 * @param i1 index whose pair is replaced by (alpha1, alpha1_)
 * @param i2 index whose pair is replaced by (alpha2, alpha2_)
 * @param alpha1 candidate alpha for i1
 * @param alpha1_ candidate alpha' for i1
 * @param alpha2 candidate alpha for i2
 * @param alpha2_ candidate alpha' for i2
 * @return the value of the objective function under the substitution
 * @throws Exception if a kernel evaluation fails
 */
protected double objFun(int i1, int i2,
                        double alpha1, double alpha1_,
                        double alpha2, double alpha2_) throws Exception {
  final int count = m_alpha.length;
  // Per-index (a - a*) and (a + a*) with the candidate pairs substituted.
  double[] diff = new double[count];
  double[] total = new double[count];
  for (int k = 0; k < count; k++) {
    double a;
    double aStar;
    if (k == i1) {
      a = alpha1;
      aStar = alpha1_;
    } else if (k == i2) {
      a = alpha2;
      aStar = alpha2_;
    } else {
      a = m_alpha[k];
      aStar = m_alpha_[k];
    }
    diff[k] = a - aStar;
    total[k] = a + aStar;
  }
  double quadratic = 0;
  double linear = 0;
  for (int row = 0; row < count; row++) {
    for (int col = 0; col < count; col++) {
      quadratic += diff[row] * diff[col] * m_kernel.eval(row, col, m_data.instance(row));
    }
    linear += m_data.instance(row).classValue() * diff[row] - m_epsilon * total[row];
  }
  return -0.5 * quadratic + linear;
}
/**
 * Debugging helper. It checks that the index sets I0, I1, I2 and I3
 * partition all training instances: each instance belongs to exactly one
 * set. It also checks that each member's multipliers match its set:
 * I0 has alpha or alpha' strictly inside (0, C * weight); I1 has both zero;
 * I2 has alpha = 0 and alpha' = C * weight; I3 has alpha' = 0 and
 * alpha = C * weight.
 *
 * @throws Exception describing the first violation found
 */
protected void checkSets() throws Exception {
  boolean[] seen = new boolean[m_data.numInstances()];

  for (int i = m_I0.getNext(-1); i != -1; i = m_I0.getNext(i)) {
    markIndexSeen(seen, i);
    double bound = m_C * m_data.instance(i).weight();
    boolean alphaFree = 0 < m_alpha[i] && m_alpha[i] < bound;
    boolean alphaStarFree = 0 < m_alpha_[i] && m_alpha_[i] < bound;
    if (!alphaFree && !alphaStarFree) {
      throw new Exception("Warning! I0 contains an incorrect indice.");
    }
  }

  for (int i = m_I1.getNext(-1); i != -1; i = m_I1.getNext(i)) {
    markIndexSeen(seen, i);
    if (m_alpha[i] != 0 || m_alpha_[i] != 0) {
      throw new Exception("Fatal error! I1 contains an incorrect indice.");
    }
  }

  for (int i = m_I2.getNext(-1); i != -1; i = m_I2.getNext(i)) {
    markIndexSeen(seen, i);
    if (m_alpha[i] != 0 || m_alpha_[i] != m_C * m_data.instance(i).weight()) {
      throw new Exception("Fatal error! I2 contains an incorrect indice.");
    }
  }

  for (int i = m_I3.getNext(-1); i != -1; i = m_I3.getNext(i)) {
    markIndexSeen(seen, i);
    if (m_alpha_[i] != 0 || m_alpha[i] != m_C * m_data.instance(i).weight()) {
      throw new Exception("Fatal error! I3 contains an incorrect indice.");
    }
  }

  for (int i = 0; i < seen.length; i++) {
    if (!seen[i]) {
      throw new Exception("Fatal error! indice " + i + " doesn't belong to any set.");
    }
  }
}

/**
 * Records that index i belongs to a set. It fails if another set has
 * already claimed that index.
 */
private static void markIndexSeen(boolean[] seen, int i) throws Exception {
  if (seen[i]) {
    throw new Exception("Fatal error! indice " + i + " appears in two different sets.");
  }
  seen[i] = true;
}
/**
 * Debugging helper: verifies the multiplier invariants
 * <ul>
 *   <li>alpha[i] * alpha_[i] == 0 for every i (at most one of each pair is non-zero), and</li>
 *   <li>sum_i (alpha[i] - alpha_[i]) == 0, to within 1e-10 in absolute value.</li>
 * </ul>
 *
 * @throws Exception if either invariant is violated
 */
protected void checkAlphas() throws Exception {
  double sum = 0;
  for (int i = 0; i < m_alpha.length; i++) {
    if (m_alpha[i] != 0 && m_alpha_[i] != 0) {
      throw new Exception("Fatal error! Inconsistent alphas!");
    }
    sum += (m_alpha[i] - m_alpha_[i]);
  }
  // The equality constraint can be violated in either direction, so test
  // the magnitude rather than only positive drift.
  if (Math.abs(sum) > 1e-10) {
    throw new Exception("Fatal error! Inconsistent alphas' sum = " + sum);
  }
}
/**
 * Debugging helper that prints the current optimisation status to
 * System.err. For each instance it prints the multiplier pair and the
 * current prediction (using the bias (bLow + bUp) / 2). The prediction is
 * shown against the epsilon-tube around the target value. The two indices
 * under consideration are marked.
 *
 * @param i1 the first current indice
 * @param i2 the second current indice
 * @throws Exception if a kernel evaluation fails
 */
protected void displayStat(int i1, int i2) throws Exception {
  System.err.println("\n-------- Status : ---------");
  System.err.println("\n i, alpha, alpha'\n");
  for (int i = 0; i < m_alpha.length; i++) {
    double prediction = (m_bLow + m_bUp) / 2.0;
    for (int j = 0; j < m_alpha.length; j++) {
      prediction += (m_alpha[j] - m_alpha_[j]) * m_kernel.eval(i, j, m_data.instance(i));
    }
    double target = m_data.instance(i).classValue();
    System.err.print(" " + i + ": (" + m_alpha[i] + ", " + m_alpha_[i]
        + "), " + (target - m_epsilon) + " <= "
        + prediction + " <= " + (target + m_epsilon));
    if (i == i1) {
      System.err.print(" <-- i1");
    }
    if (i == i2) {
      System.err.print(" <-- i2");
    }
    System.err.println();
  }
  System.err.println("bLow = " + m_bLow + " bUp = " + m_bUp);
  System.err.println("---------------------------\n");
}
/**
 * Debugging helper that prints one line to System.err per training
 * instance. Each line shows:
 * <ul>
 *   <li>its multiplier pair;</li>
 *   <li>the values F_i - epsilon and F_i + epsilon, where
 *       F_i = y_i - sum_j (alpha[j] - alpha_[j]) * K(i, j);</li>
 *   <li>how that instance would update bUp / bLow, given the set
 *       (I0a, I0b, I1, I2, I3) it belongs to.</li>
 * </ul>
 * Nothing is stored.
 *
 * @throws Exception if a kernel evaluation fails
 */
protected void displayB() throws Exception {
  for (int i = 0; i < m_data.numInstances(); i++) {
    double Fi = m_data.instance(i).classValue();
    for (int j = 0; j < m_alpha.length; j++) {
      Fi -= (m_alpha[j] - m_alpha_[j]) * m_kernel.eval(i, j, m_data.instance(i));
    }
    System.err.print("(" + m_alpha[i] + ", " + m_alpha_[i] + ") : ");
    System.err.print((Fi - m_epsilon) + ", " + (Fi + m_epsilon));
    double fim = Fi - m_epsilon, fip = Fi + m_epsilon;
    String s = "";
    if (m_I0.contains(i)) {
      // I0a: alpha[i] strictly inside (0, C * weight).
      if (0 < m_alpha[i] && m_alpha[i] < m_C * m_data.instance(i).weight()) {
        s += "(in I0a) bUp = min(bUp, " + fim + ") bLow = max(bLow, " + fim + ")";
      }
      // I0b: alpha_[i] strictly inside (0, C * weight).
      if (0 < m_alpha_[i] && m_alpha_[i] < m_C * m_data.instance(i).weight()) {
        s += "(in I0b) bUp = min(bUp, " + fip + ") bLow = max(bLow, " + fip + ")";
      }
    }
    if (m_I1.contains(i)) {
      s += "(in I1) bUp = min(bUp, " + fip + ") bLow = max(bLow, " + fim + ")";
    }
    if (m_I2.contains(i)) {
      s += "(in I2) bLow = max(bLow, " + fip + ")";
    }
    if (m_I3.contains(i)) {
      s += "(in I3) bUp = min(bUp, " + fim + ")";
    }
    System.err.println(" " + s + " {" + (m_alpha[i] - 1) + ", " + (m_alpha_[i] - 1) + "}");
  }
  System.err.println("\n\n");
}
/**
 * Debugging helper. It checks whether equations (6), (8a), (8b), (8c) and
 * (8d) of "Improvements to SMO Algorithm for SVM Regression" hold, and
 * prints a warning to System.err for each one that does not.
 * <ul>
 *   <li>(6) is checked with bUp / bLow recomputed from scratch here.</li>
 *   <li>(8a)-(8d) use the bias (m_bUp + m_bLow) / 2 from the solver's
 *       stored values.</li>
 * </ul>
 *
 * @throws Exception if a kernel evaluation fails
 */
protected void checkOptimality() throws Exception {
  final int n = m_data.numInstances();

  // F_i = y_i - sum_j (alpha[j] - alpha_[j]) * K(i, j). Both checks below
  // need it, so evaluate each kernel row once instead of twice.
  double[] F = new double[n];
  for (int i = 0; i < n; i++) {
    double Fi = m_data.instance(i).classValue();
    for (int j = 0; j < m_alpha.length; j++) {
      Fi -= (m_alpha[j] - m_alpha_[j]) * m_kernel.eval(i, j, m_data.instance(i));
    }
    F[i] = Fi;
  }

  // Recompute bUp = min F-bar_i and bLow = max F-tilde_i over the index sets.
  double bUp = Double.POSITIVE_INFINITY;
  double bLow = Double.NEGATIVE_INFINITY;
  for (int i = 0; i < n; i++) {
    double Fi = F[i];
    double upper = m_C * m_data.instance(i).weight();
    double fitilde = 0, fibarre = 0;
    if (m_I0.contains(i) && 0 < m_alpha[i] && m_alpha[i] < upper) {
      fitilde = Fi - m_epsilon;
      fibarre = Fi - m_epsilon;
    }
    if (m_I0.contains(i) && 0 < m_alpha_[i] && m_alpha_[i] < upper) {
      fitilde = Fi + m_epsilon;
      fibarre = Fi + m_epsilon;
    }
    if (m_I1.contains(i)) {
      fitilde = Fi - m_epsilon;
      fibarre = Fi + m_epsilon;
    }
    if (m_I2.contains(i)) {
      fitilde = Fi + m_epsilon;
      fibarre = Double.POSITIVE_INFINITY;
    }
    if (m_I3.contains(i)) {
      fitilde = Double.NEGATIVE_INFINITY;
      fibarre = Fi - m_epsilon;
    }
    if (fibarre < bUp) {
      bUp = fibarre;
    }
    if (fitilde > bLow) {
      bLow = fitilde;
    }
  }
  if (!(bLow <= bUp + 2 * m_tol)) {
    System.err.println("Warning! Optimality not reached : inequation (6) doesn't hold!");
  }

  // KKT conditions on E_i = F_i - b, with b taken from the solver's stored bounds.
  final double b = (m_bUp + m_bLow) / 2.0;
  boolean noPb = true;
  for (int i = 0; i < n; i++) {
    double Ei = F[i] - b;
    double upper = m_C * m_data.instance(i).weight();
    if ((m_alpha[i] > 0) && !(Ei >= m_epsilon - m_tol)) {
      System.err.println("Warning! Optimality not reached : inequation (8a) doesn't hold for " + i);
      noPb = false;
    }
    if ((m_alpha[i] < upper) && !(Ei <= m_epsilon + m_tol)) {
      System.err.println("Warning! Optimality not reached : inequation (8b) doesn't hold for " + i);
      noPb = false;
    }
    if ((m_alpha_[i] > 0) && !(Ei <= -m_epsilon + m_tol)) {
      System.err.println("Warning! Optimality not reached : inequation (8c) doesn't hold for " + i);
      noPb = false;
    }
    if ((m_alpha_[i] < upper) && !(Ei >= -m_epsilon - m_tol)) {
      System.err.println("Warning! Optimality not reached : inequation (8d) doesn't hold for " + i);
      noPb = false;
    }
  }
  if (!noPb) {
    // Blank line separating the warnings; displayStat(-1, -1) and displayB()
    // can be called here when more detail is needed.
    System.err.println();
  }
}
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -