📄 logconditionalobjectivefunction.java
字号:
package edu.stanford.nlp.classify;import java.util.*;import edu.stanford.nlp.optimization.*;/** @author Dan Klein */public class LogConditionalObjectiveFunction extends AbstractCachingDiffFunction { int numFeatures = 0; int numClasses = 0; int[][] data = null; int[] labels = null; double sigma = 1.0; public int domainDimension() { return numFeatures*numClasses; } int classOf(int index) { return index % numClasses; } int featureOf(int index) { return index / numClasses; } int indexOf(int f, int c) { return f*numClasses+c; } public double[][] to2D(double[] x) { double[][] x2 = new double[numFeatures][numClasses]; for (int i=0; i<numFeatures; i++) for (int j=0; j<numClasses; j++) x2[i][j] = x[indexOf(i,j)]; return x2; } double abs(double x) { return (x >= 0.0 ? x : -1.0*x); } double addLPs(double x, double y) { double max = (x > y ? x : y); if (max == Double.NEGATIVE_INFINITY) return max; if (abs(x-y) > 20.0) return max; return Math.log(Math.exp(x-max)+Math.exp(y-max))+max; } protected void calculate(double[] x) { //System.out.println("Checking at: "+x[0]+" "+x[1]+" "+x[2]); value = 0.0; Arrays.fill(derivative, 0.0); double[] sums = new double[numClasses]; double[] probs = new double[numClasses]; double[] counts = new double[numClasses]; Arrays.fill(counts, 0.0); for (int d=0; d<data.length; d++) { int[] features = data[d]; // activation Arrays.fill(sums, 0.0); for (int c=0; c<numClasses; c++) { for (int f=0; f<features.length; f++) { int i = indexOf(features[f], c); sums[c] += x[i]; } } // expectation double total = Double.NEGATIVE_INFINITY; for (int c=0; c<numClasses; c++) { total = addLPs(total, sums[c]); } for (int c=0; c<numClasses; c++) { probs[c] = Math.exp(sums[c]-total); for (int f=0; f<features.length; f++) { int i = indexOf(features[f], c); derivative[i] += probs[c]; } } // observed for (int f=0; f<features.length; f++) { int i = indexOf(features[f], labels[d]); derivative[i] -= 1.0; } value -= sums[labels[d]]-total; } // priors if (true) { for (int i=0; i<x.length; i++) { double k = 1.0; double w = x[i]; value += k*w*w/2.0/sigma/sigma; derivative[i] += k*w/sigma/sigma; } } /* System.out.println("N: "+data.length); System.out.println("Value: "+value); double ds = 0.0; for (int i=0; i<x.length; i++) { ds += derivative[i]; System.out.println(i+" is: "+derivative[i]); } */ //System.out.println("Deriv sum is: "+ds); } public LogConditionalObjectiveFunction(int numFeatures, int numClasses, int[][] data, int[] labels) { this(numFeatures, numClasses, data, labels, 1.0); } public LogConditionalObjectiveFunction(int numFeatures, int numClasses, int[][] data, int[] labels, double sigma) { this.numFeatures = numFeatures; this.numClasses = numClasses; this.data = data; this.labels = labels; this.sigma = sigma; }}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -