📄 expgain.java
字号:
|| Double.isNaN(alphas[i][j]) || Double.isNaN(alphachange)) // Print just a sampling of them... logger.info ("alpha["+i+"]["+j+"]="+alphas[i][j]+ " p="+p[i][j]+ " q="+q[i][j]+ " dalpha="+dalphas[i][j]+ " ddalpha="+ddalphas[i][j]+ " alphachange="+alphachange+ " min="+alphaMin[i][j]+ " max="+alphaMax[i][j]); if (Double.isNaN(alphas[i][j]) || Double.isNaN(dalphas[i][j]) || Double.isNaN(ddalphas[i][j]) || Double.isInfinite(alphas[i][j]) || Double.isInfinite(dalphas[i][j]) || Double.isInfinite(ddalphas[i][j])) alphachange = 0;// assert (!Double.isNaN(alphas[i][j]));// assert (!Double.isNaN(dalphas[i][j]));// assert (!Double.isNaN(ddalphas[i][j])); oldalpha = alphas[i][j]; // xxx assert (ddalphas[i][j] <= 0); //assert (Math.abs(alphachange) < 100.0) : alphachange; // xxx arbitrary? // Trying to prevent a cycle if (Math.abs(alphachange + alphaChangeOld[i][j]) / Math.abs(alphachange) < 0.01) newalpha = alphas[i][j] + alphachange / 2; else newalpha = alphas[i][j] + alphachange; if (alphachange < 0 && alphaMax[i][j] > alphas[i][j]) { //System.out.println ("Updating alphaMax["+i+"]["+j+"] = "+alphas[i][j]); alphaMax[i][j] = alphas[i][j]; } if (alphachange > 0 && alphaMin[i][j] < alphas[i][j]) { //System.out.println ("Updating alphaMin["+i+"]["+j+"] = "+alphas[i][j]); alphaMin[i][j] = alphas[i][j]; } if (newalpha <= alphaMax[i][j] && newalpha >= alphaMin[i][j]) // Newton wants to jump to a point inside the boundaries; let it alphas[i][j] = newalpha; else { // Newton wants to jump to a point outside the boundaries; bisect instead assert (alphaMax[i][j] != Double.POSITIVE_INFINITY); assert (alphaMin[i][j] != Double.NEGATIVE_INFINITY); alphas[i][j] = alphaMin[i][j] + (alphaMax[i][j] - alphaMin[i][j]) / 2; //System.out.println ("Newton tried to exceed bounds; bisecting. 
dalphas["+i+"]["+j+"]="+dalphas[i][j]+" alphaMin="+alphaMin[i][j]+" alphaMax="+alphaMax[i][j]); } alphachange = alphas[i][j] - oldalpha; if (Math.abs(alphachange) > maxAlphachange) maxAlphachange = Math.abs (alphachange); if (Math.abs (dalphas[i][j]) > maxDalpha) maxDalpha = Math.abs (dalphas[i][j]); alphaChangeOld[i][j] = alphachange; } logger.info ("After "+newton+" Newton iterations, maximum alphachange="+maxAlphachange+ " dalpha="+maxDalpha); } // Allow some memory to be freed //q = null; ddalphas = dalphas = alphaChangeOld = alphaMin = alphaMax = null; // "q[e^{\alpha g}]", p.4 //System.out.println ("Calculating qeag..."); // Note that we are using a gaussian prior, so we don't multiply by (1/numInstances) double[][] qeag = new double[numClasses][numFeatures]; for (int i = 0; i < ilist.size(); i++) { assert (classifications[i].getLabelAlphabet() == ilist.getTargetAlphabet()); Instance inst = ilist.getInstance(i); Labeling labeling = inst.getLabeling (); FeatureVector fv = (FeatureVector) inst.getData (); int fvMaxLocation = fv.numLocations()-1; for (int li = 0; li < numClasses; li++) { double modelLabelWeight = classifications[i].value(li); // Following line now done before outside of loop over instances // for (int fi = 0; fi < numFeatures; fi++) qeag[li][fi] += modelLabelWeight; // * 1.0; for (int fl = 0; fl < fv.numLocations(); fl++) { fli = fv.indexAtLocation(fl); // When the value of this feature "g" is zero, a value of 1.0 should be included // in the expectation; we'll actually add all these at the end (pre-"assuming" // that all features have value zero). Here we subtract the "assumed" // modelLabelWeight, and put in the true value based on non-zero valued feature "g". 
qeag[li][fli] += Math.log (modelLabelWeight * Math.exp (alphas[li][fli]) + (1-modelLabelWeight)); } } } //System.out.println ("Calculating klgain values..."); double[] klgains = new double[numFeatures]; double klgainIncr, alpha; for (int i = 0; i < numClasses; i++) for (int j = 0; j < numFeatures; j++) { assert (!Double.isInfinite(alphas[i][j])); alpha = alphas[i][j]; if (alpha == 0) continue; klgainIncr = (alpha * p[i][j]) - qeag[i][j] - (alpha*alpha/(2*gaussianPriorVariance)); if (klgainIncr < 0) { if (false) logger.info ("WARNING: klgainIncr["+i+"]["+j+"]="+klgainIncr+ " alpha="+alphas[i][j]+ " feature="+ilist.getDataAlphabet().lookupObject(j)+ " class="+ilist.getTargetAlphabet().lookupObject(i)); } else klgains[j] += klgainIncr; } if (false) { logger.info ("klgains.length="+klgains.length); for (int j = 0; j < numFeatures; j++) { if (j % (numFeatures/100) == 0) { for (int i = 0; i < numClasses; i++) { logger.info ("c="+i+" p["+ilist.getDataAlphabet().lookupObject(j)+"] = "+p[i][j]); logger.info ("c="+i+" q["+ilist.getDataAlphabet().lookupObject(j)+"] = "+q[i][j]); logger.info ("c="+i+" alphas["+ilist.getDataAlphabet().lookupObject(j)+"] = "+alphas[i][j]); logger.info ("c="+i+" qeag["+ilist.getDataAlphabet().lookupObject(j)+"] = "+qeag[i][j]); } logger.info ("klgains["+ilist.getDataAlphabet().lookupObject(j)+"] = "+klgains[j]); } } } return klgains; } public ExpGain (InstanceList ilist, LabelVector[] classifications, double gaussianPriorVariance) { super (ilist.getDataAlphabet(), calcExpGains (ilist, classifications, gaussianPriorVariance)); } private static LabelVector[] getLabelVectorsFromClassifications (Classification[] c) { LabelVector[] ret = new LabelVector[c.length]; for (int i = 0; i < c.length; i++) ret[i] = c[i].getLabelVector(); return ret; } public ExpGain (InstanceList ilist, Classification[] classifications, double gaussianPriorVariance) { super (ilist.getDataAlphabet(), calcExpGains (ilist, getLabelVectorsFromClassifications(classifications), 
gaussianPriorVariance)); } // tail of the Classification[] constructor begun on the previous line

  /**
   * Factory that builds {@link ExpGain} rankings for an instance list from a
   * fixed array of classifications (indexed by instance position in the list)
   * and a Gaussian prior variance.
   *
   * <p>Serialization note: the custom stream format is versioned. Version 0
   * stored only the classifications; version 1 additionally stores
   * {@code gaussianPriorVariance}. Without that, a deserialized factory came
   * back with a variance of 0.0, because field initializers are not executed
   * when a serializable class is reconstructed from a stream.
   * (Assumes RankedFeatureVector.Factory is Serializable -- not visible from
   * here; confirm against that interface.)
   */
  public static class Factory implements RankedFeatureVector.Factory
  {
    /** Variance used when the caller supplies none, and when reading version-0 streams. */
    private static final double DEFAULT_GAUSSIAN_PRIOR_VARIANCE = 10.0;

    LabelVector[] classifications;
    double gaussianPriorVariance = DEFAULT_GAUSSIAN_PRIOR_VARIANCE;

    /**
     * Creates a factory that uses the default Gaussian prior variance.
     *
     * @param classifications classifications to rank features against
     */
    public Factory (LabelVector[] classifications)
    {
      this.classifications = classifications;
    }

    /**
     * Creates a factory with an explicit Gaussian prior variance.
     *
     * @param classifications classifications to rank features against
     * @param gaussianPriorVariance variance passed through to {@link ExpGain}
     */
    public Factory (LabelVector[] classifications, double gaussianPriorVariance)
    {
      this.classifications = classifications;
      this.gaussianPriorVariance = gaussianPriorVariance;
    }

    /**
     * Builds an {@link ExpGain} ranking for the given instances.
     *
     * @param ilist instances whose target alphabet is expected to be the
     *              alphabet of the stored classifications (checked by assertion only)
     * @return a new ExpGain computed from ilist, the stored classifications and variance
     */
    public RankedFeatureVector newRankedFeatureVector (InstanceList ilist)
    {
      assert (ilist.getTargetAlphabet() == classifications[0].getAlphabet());
      return new ExpGain (ilist, classifications, gaussianPriorVariance);
    }

    // Serialization

    private static final long serialVersionUID = 1;
    // Version 1 appends gaussianPriorVariance after the classifications.
    private static final int CURRENT_SERIAL_VERSION = 1;

    /**
     * Writes the version, the number of classifications, each classification,
     * and finally the Gaussian prior variance.
     */
    private void writeObject (ObjectOutputStream out) throws IOException
    {
      out.writeInt (CURRENT_SERIAL_VERSION);
      out.writeInt (classifications.length);
      for (int i = 0; i < classifications.length; i++)
        out.writeObject (classifications[i]);
      // Appended last so the leading part of the stream keeps the version-0 layout.
      out.writeDouble (gaussianPriorVariance);
    }

    /**
     * Restores the state written by {@link #writeObject}. Streams written with
     * version 0 carry no variance, so the default is used for them.
     */
    private void readObject (ObjectInputStream in) throws IOException, ClassNotFoundException
    {
      int version = in.readInt ();
      int n = in.readInt ();
      this.classifications = new LabelVector[n];
      for (int i = 0; i < n; i++)
        this.classifications[i] = (LabelVector) in.readObject ();
      // The field initializer does not run during deserialization, so the
      // variance must always be assigned explicitly here.
      if (version >= 1)
        this.gaussianPriorVariance = in.readDouble ();
      else
        this.gaussianPriorVariance = DEFAULT_GAUSSIAN_PRIOR_VARIANCE;
    }
  }
} // closes the enclosing class
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -