
LogConditionalObjectiveFunction.java

The Stanford Classifier implements a maximum entropy classifier in Java, for use in pattern recognition.
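Written out, the quantity that `calculate` computes appears to be the following. We read it directly off the code. The notation is ours: $\lambda_{f,c}$ for `x[indexOf(f, c)]`, $F_d$ for the feature list `data[d]`, and $y_d$ for `labels[d]`. The gradient line assumes each feature index appears at most once in a given `data[d]`; a repeated index would be counted once per occurrence.

```latex
\begin{align*}
s_c(d) &= \sum_{f \in F_d} \lambda_{f,c}, \qquad
\log Z_d = \log \sum_{c'} \exp s_{c'}(d), \qquad
P(c \mid d) = \exp\bigl(s_c(d) - \log Z_d\bigr) \\
L(\lambda) &= -\sum_d \bigl(s_{y_d}(d) - \log Z_d\bigr)
              + \sum_{f,c} \frac{\lambda_{f,c}^2}{2\sigma^2} \\
\frac{\partial L}{\partial \lambda_{f,c}}
  &= \sum_{d \,:\, f \in F_d} \bigl(P(c \mid d) - \mathbf{1}[c = y_d]\bigr)
     + \frac{\lambda_{f,c}}{\sigma^2}
\end{align*}
```

The first sum in the gradient is "expected minus observed" feature counts, and the last term comes from the Gaussian prior with standard deviation $\sigma$ (`sigma`, default 1.0).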
```java
package edu.stanford.nlp.classify;

import java.util.*;

import edu.stanford.nlp.optimization.*;

/**
 * Negative conditional log-likelihood of a multiclass maximum entropy
 * (log-linear) model, plus a Gaussian prior on the weights. The weights are
 * stored in one flat array indexed by (feature, class).
 *
 * @author Dan Klein
 */
public class LogConditionalObjectiveFunction extends AbstractCachingDiffFunction {

  int numFeatures = 0;
  int numClasses = 0;
  int[][] data = null;  // data[d] = indices of the active features of datum d
  int[] labels = null;  // labels[d] = gold class of datum d
  double sigma = 1.0;   // standard deviation of the Gaussian prior

  public int domainDimension() {
    return numFeatures * numClasses;
  }

  // classOf and featureOf invert indexOf.
  int classOf(int index) {
    return index % numClasses;
  }

  int featureOf(int index) {
    return index / numClasses;
  }

  int indexOf(int f, int c) {
    return f * numClasses + c;
  }

  /** Unflattens the weight vector into a [feature][class] matrix. */
  public double[][] to2D(double[] x) {
    double[][] x2 = new double[numFeatures][numClasses];
    for (int i = 0; i < numFeatures; i++)
      for (int j = 0; j < numClasses; j++)
        x2[i][j] = x[indexOf(i, j)];
    return x2;
  }

  double abs(double x) {
    return (x >= 0.0 ? x : -1.0 * x);
  }

  /**
   * Returns log(exp(x) + exp(y)), shifting by the max to avoid overflow.
   * If the two arguments differ by more than 20, the smaller one is ignored.
   */
  double addLPs(double x, double y) {
    double max = (x > y ? x : y);
    if (max == Double.NEGATIVE_INFINITY)
      return max;
    if (abs(x - y) > 20.0)
      return max;
    return Math.log(Math.exp(x - max) + Math.exp(y - max)) + max;
  }

  // Fills the inherited fields value and derivative at the point x.
  protected void calculate(double[] x) {
    value = 0.0;
    Arrays.fill(derivative, 0.0);
    double[] sums = new double[numClasses];
    double[] probs = new double[numClasses];
    for (int d = 0; d < data.length; d++) {
      int[] features = data[d];
      // activation: sums[c] = sum of the weights of the active features for class c
      Arrays.fill(sums, 0.0);
      for (int c = 0; c < numClasses; c++) {
        for (int f = 0; f < features.length; f++) {
          int i = indexOf(features[f], c);
          sums[c] += x[i];
        }
      }
      // expectation: total = log sum_c exp(sums[c]); add model expectations to the gradient
      double total = Double.NEGATIVE_INFINITY;
      for (int c = 0; c < numClasses; c++) {
        total = addLPs(total, sums[c]);
      }
      for (int c = 0; c < numClasses; c++) {
        probs[c] = Math.exp(sums[c] - total);
        for (int f = 0; f < features.length; f++) {
          int i = indexOf(features[f], c);
          derivative[i] += probs[c];
        }
      }
      // observed: subtract the empirical feature counts for the gold label
      for (int f = 0; f < features.length; f++) {
        int i = indexOf(features[f], labels[d]);
        derivative[i] -= 1.0;
      }
      value -= sums[labels[d]] - total;
    }
    // priors: Gaussian prior adds w^2 / (2 sigma^2) to the value and w / sigma^2 to the gradient
    for (int i = 0; i < x.length; i++) {
      double k = 1.0;
      double w = x[i];
      value += k * w * w / 2.0 / sigma / sigma;
      derivative[i] += k * w / sigma / sigma;
    }
  }

  public LogConditionalObjectiveFunction(int numFeatures, int numClasses, int[][] data, int[] labels) {
    this(numFeatures, numClasses, data, labels, 1.0);
  }

  public LogConditionalObjectiveFunction(int numFeatures, int numClasses, int[][] data, int[] labels, double sigma) {
    this.numFeatures = numFeatures;
    this.numClasses = numClasses;
    this.data = data;
    this.labels = labels;
    this.sigma = sigma;
  }
}
```
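The `addLPs` method exists because exponentiating large class scores directly overflows. The sketch below copies its logic into a standalone class to show the effect. The class name `LogSumExpDemo`, the `naive` helper, and the score values are illustrative assumptions, not part of the Stanford code. `Math.max`/`Math.abs` stand in for the original ternary and `abs` helper.

```java
public class LogSumExpDemo {

  // Same log-space addition as LogConditionalObjectiveFunction.addLPs:
  // returns log(exp(x) + exp(y)), shifting by the max to avoid overflow.
  static double addLPs(double x, double y) {
    double max = Math.max(x, y);
    if (max == Double.NEGATIVE_INFINITY) return max;
    // The smaller term would change the result by less than log(1 + e^-20).
    if (Math.abs(x - y) > 20.0) return max;
    return Math.log(Math.exp(x - max) + Math.exp(y - max)) + max;
  }

  // Hypothetical naive version without the shift: exp(1000) overflows to Infinity.
  static double naive(double x, double y) {
    return Math.log(Math.exp(x) + Math.exp(y));
  }

  public static void main(String[] args) {
    double[] sums = {1000.0, 1001.0, 999.0}; // made-up class scores
    double total = Double.NEGATIVE_INFINITY; // log 0, as in calculate()
    for (double s : sums) total = addLPs(total, s);
    System.out.println("log Z (shifted) = " + total);
    System.out.println("log Z (naive)   = " + naive(naive(sums[0], sums[1]), sums[2]));
    double probSum = 0.0;
    for (double s : sums) probSum += Math.exp(s - total);
    System.out.println("sum of probs    = " + probSum);
  }
}
```

The shifted version returns about 1001.4076 (that is, 1000 + log(1 + e + e^-1)), and the resulting class probabilities sum to 1. The naive version prints `Infinity`.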

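A quick way to confirm that the gradient accumulated in `calculate` matches the value it computes is a central finite-difference check. The sketch below is a hypothetical standalone re-implementation that avoids depending on `AbstractCachingDiffFunction`. It uses the same flat indexing (`f * numClasses + c`), objective, and prior, but computes an exact log-sum-exp instead of the 20-unit cutoff in `addLPs`. The names `GradientCheckDemo`, `valueAndGradient`, and `maxGradientError`, the step size `h`, and the toy data are all assumptions for illustration.

```java
import java.util.Arrays;

public class GradientCheckDemo {

  // Mirrors calculate(): returns the objective and fills grad.
  static double valueAndGradient(double[] x, double[] grad, int numClasses,
                                 int[][] data, int[] labels, double sigma) {
    Arrays.fill(grad, 0.0);
    double value = 0.0;
    double[] sums = new double[numClasses];
    for (int d = 0; d < data.length; d++) {
      Arrays.fill(sums, 0.0);
      for (int c = 0; c < numClasses; c++)
        for (int f : data[d]) sums[c] += x[f * numClasses + c];
      double max = Double.NEGATIVE_INFINITY;
      for (double s : sums) max = Math.max(max, s);
      double z = 0.0;
      for (double s : sums) z += Math.exp(s - max);
      double total = max + Math.log(z); // exact log-sum-exp
      for (int c = 0; c < numClasses; c++) {
        double p = Math.exp(sums[c] - total); // P(c | d)
        for (int f : data[d]) grad[f * numClasses + c] += p;
      }
      for (int f : data[d]) grad[f * numClasses + labels[d]] -= 1.0;
      value -= sums[labels[d]] - total;
    }
    for (int i = 0; i < x.length; i++) { // Gaussian prior
      value += x[i] * x[i] / (2.0 * sigma * sigma);
      grad[i] += x[i] / (sigma * sigma);
    }
    return value;
  }

  // Largest absolute gap between the analytic gradient and central differences.
  static double maxGradientError(double[] x, int numClasses, int[][] data,
                                 int[] labels, double sigma) {
    double[] grad = new double[x.length];
    double[] scratch = new double[x.length];
    valueAndGradient(x, grad, numClasses, data, labels, sigma);
    double h = 1e-6, worst = 0.0;
    for (int i = 0; i < x.length; i++) {
      double saved = x[i];
      x[i] = saved + h;
      double up = valueAndGradient(x, scratch, numClasses, data, labels, sigma);
      x[i] = saved - h;
      double down = valueAndGradient(x, scratch, numClasses, data, labels, sigma);
      x[i] = saved; // restore the weight
      worst = Math.max(worst, Math.abs((up - down) / (2 * h) - grad[i]));
    }
    return worst;
  }

  public static void main(String[] args) {
    // Toy problem: 3 features x 2 classes = 6 weights, 3 data points.
    int[][] data = {{0, 1}, {1, 2}, {0, 2}};
    int[] labels = {0, 1, 1};
    double[] x = {0.3, -0.2, 0.1, 0.5, -0.4, 0.2};
    System.out.println("max gradient error = "
        + maxGradientError(x, 2, data, labels, 1.0));
  }
}
```

The check should report an error many orders of magnitude below the gradient entries themselves. A large gap would point to a sign or indexing mismatch between the value and the derivative.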