📄 gistrainer.java
字号:
outcomeLabels = di.outcomeLabels; numOutcomes = outcomeLabels.length; iprob = Math.log(1.0/numOutcomes); predLabels = di.predLabels; numPreds = predLabels.length; display("\tNumber of Event Tokens: " + numTokens +"\n"); display("\t Number of Outcomes: " + numOutcomes +"\n"); display("\t Number of Predicates: " + numPreds +"\n"); // set up feature arrays int[][] predCount = new int[numPreds][numOutcomes]; for (TID=0; TID<numTokens; TID++) for (int j=0; j<contexts[TID].length; j++) predCount[contexts[TID][j]][di.outcomeList[TID]] += numTimesEventsSeen[TID]; //printTable(predCount); di = null; // don't need it anymore // A fake "observation" to cover features which are not detected in // the data. The default is to assume that we observed "1/10th" of a // feature during training. final double smoothingObservation = _smoothingObservation; final double logSmoothingObservation = Math.log(_smoothingObservation); // Get the observed expectations of the features. Strictly speaking, // we should divide the counts by the number of Tokens, but because of // the way the model's expectations are approximated in the // implementation, this is cancelled out when we compute the next // iteration of a parameter, making the extra divisions wasteful. 
params = new TIntDoubleHashMap[numPreds]; modifiers = new TIntDoubleHashMap[numPreds]; observedExpects = new TIntDoubleHashMap[numPreds]; int initialCapacity; float loadFactor = (float)0.9; if (numOutcomes < 3) { initialCapacity = 2; loadFactor = (float)1.0; } else if (numOutcomes < 5) { initialCapacity = 2; } else { initialCapacity = (int)numOutcomes/2; } for (PID=0; PID<numPreds; PID++) { params[PID] = new TIntDoubleHashMap(initialCapacity, loadFactor); modifiers[PID] = new TIntDoubleHashMap(initialCapacity, loadFactor); observedExpects[PID] = new TIntDoubleHashMap(initialCapacity, loadFactor); for (OID=0; OID<numOutcomes; OID++) { if (predCount[PID][OID] > 0) { params[PID].put(OID, 0.0); modifiers[PID].put(OID, 0.0); observedExpects[PID].put(OID,Math.log(predCount[PID][OID])); } else if (_simpleSmoothing) { params[PID].put(OID, 0.0); modifiers[PID].put(OID, 0.0); observedExpects[PID].put(OID, logSmoothingObservation); } } params[PID].compact(); modifiers[PID].compact(); observedExpects[PID].compact(); } // compute the expected value of correction int cfvalSum = 0; for (TID=0; TID<numTokens; TID++) { for (int j=0; j<contexts[TID].length; j++) { PID = contexts[TID][j]; if (!modifiers[PID].containsKey(outcomes[TID])) { cfvalSum+=numTimesEventsSeen[TID]; } } cfvalSum += (constant - contexts[TID].length) * numTimesEventsSeen[TID]; } if (cfvalSum == 0) { cfObservedExpect = Math.log(NEAR_ZERO);//nearly zero so log is defined } else { cfObservedExpect = Math.log(cfvalSum); } correctionParam = 0.0; predCount = null; // don't need it anymore display("...done.\n"); modelDistribution = new double[numOutcomes]; numfeats = new int[numOutcomes]; /***************** Find the parameters ************************/ display("Computing model parameters...\n"); findParameters(iterations); /*************** Create and return the model ******************/ return new GISModel(params, predLabels, outcomeLabels, constant, correctionParam); } /* Estimate and return the model parameters. 
*/ private void findParameters(int iterations) { double prevLL = 0.0; double currLL = 0.0; display("Performing " + iterations + " iterations.\n"); for (int i=1; i<=iterations; i++) { if (i<10) display(" " + i + ": "); else if (i<100) display(" " + i + ": "); else display(i + ": "); currLL=nextIteration(); if (i > 1) { if (prevLL > currLL) { System.err.println("Model Diverging: loglikelihood decreased"); break; } if (currLL-prevLL < LLThreshold) { break; } } prevLL=currLL; } // kill a bunch of these big objects now that we don't need them observedExpects = null; modifiers = null; numTimesEventsSeen = null; contexts = null; } /** * Use this model to evaluate a context and return an array of the * likelihood of each outcome given that context. * * @param context The integers of the predicates which have been * observed at the present decision point. * @return The normalized probabilities for the outcomes given the * context. The indexes of the double[] are the outcome * ids, and the actual string representation of the * outcomes can be obtained from the method * getOutcome(int i). 
*/
    public void eval(int[] context, double[] outsums) {
        // Reset every outcome to the initial value iprob and clear the
        // per-outcome count of active features.
        for (int oid = 0; oid < numOutcomes; oid++) {
            outsums[oid] = iprob;
            numfeats[oid] = 0;
        }

        // Add the scaled parameter of every (predicate, outcome) pair that
        // has a stored parameter, counting how many fired per outcome.
        for (int i = 0; i < context.length; i++) {
            TIntDoubleHashMap predParams = params[context[i]];
            int[] activeOutcomes = predParams.keys();
            for (int j = 0; j < activeOutcomes.length; j++) {
                int oid = activeOutcomes[j];
                numfeats[oid]++;
                outsums[oid] += constantInverse * predParams.get(oid);
            }
        }

        // Apply the correction term, exponentiate, and accumulate the
        // normalizer.
        double normalizer = 0.0;
        for (int oid = 0; oid < numOutcomes; oid++) {
            double correctionWeight = 1.0 - ((double) numfeats[oid] / constant);
            outsums[oid] = Math.exp(outsums[oid] + (correctionWeight * correctionParam));
            normalizer += outsums[oid];
        }

        // Scale so the entries sum to one.
        for (int oid = 0; oid < numOutcomes; oid++) {
            outsums[oid] /= normalizer;
        }
    }

    /**
     * Computes one iteration of GIS and returns the log-likelihood.
     *
     * <p>For every event token the current model distribution is evaluated,
     * the modifiers of each active predicate are updated through the
     * {@code updateModifiers} callback, and the correction-feature
     * expectation {@code CFMOD} is accumulated. Afterwards the parameters
     * are updated through {@code updateParams}, the modifiers are reset via
     * {@code backToZeros}, and the correction parameter is adjusted when
     * {@code CFMOD} is positive.
     *
     * @return the log-likelihood of the training data, summed over event
     *         tokens and weighted by how often each was seen, computed from
     *         the distribution evaluated before this iteration's update
     */
    private double nextIteration() {
        // compute contribution of p(a|b_i) for each feature and the new
        // correction parameter
        double loglikelihood = 0.0;
        CFMOD = 0.0;

        for (TID = 0; TID < numTokens; TID++) {
            // TID, modelDistribution and PID are fields used in the
            // updateModifiers procedure. They need to be set before it runs,
            // which is why the loop counters are fields rather than locals.
            eval(contexts[TID], modelDistribution);
            for (int j = 0; j < contexts[TID].length; j++) {
                PID = contexts[TID][j];
                modifiers[PID].forEachEntry(updateModifiers);
                // Outcomes with no entry for this predicate contribute to
                // the correction feature instead.
                for (OID = 0; OID < numOutcomes; OID++) {
                    if (!modifiers[PID].containsKey(OID)) {
                        CFMOD += modelDistribution[OID] * numTimesEventsSeen[TID];
                    }
                }
            }
            CFMOD += (constant - contexts[TID].length) * numTimesEventsSeen[TID];
            loglikelihood += Math.log(modelDistribution[outcomes[TID]]) * numTimesEventsSeen[TID];
        }
        display(".");

        // compute the new parameter values
        for (PID = 0; PID < numPreds; PID++) {
            params[PID].forEachEntry(updateParams);
            modifiers[PID].transformValues(backToZeros); // re-initialize to 0.0's
        }
        // Skip the update when CFMOD is not positive, since its log would
        // not be a finite number.
        if (CFMOD > 0.0) {
            correctionParam += (cfObservedExpect - Math.log(CFMOD));
        }

        display(". loglikelihood=" + loglikelihood + "\n");
        return loglikelihood;
    }

    /**
     * Prints a progress message to standard output when message printing
     * is enabled; otherwise does nothing.
     *
     * @param s the text to print (no newline is appended)
     */
    private void display(String s) {
        if (printMessages) {
            System.out.print(s);
        }
    }
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -