⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 gistrainer.java

📁 最大熵分类器
💻 JAVA
📖 第 1 页 / 共 2 页
字号:
          predCount[contexts[ti][j]][outcomeList[ti]] += numTimesEventsSeen[ti]*values[ti][j];        }        else {                    predCount[contexts[ti][j]][outcomeList[ti]] += numTimesEventsSeen[ti];        }      }    }    //printTable(predCount);    di = null; // don't need it anymore    // A fake "observation" to cover features which are not detected in    // the data.  The default is to assume that we observed "1/10th" of a    // feature during training.    final double smoothingObservation = _smoothingObservation;    // Get the observed expectations of the features. Strictly speaking,    // we should divide the counts by the number of Tokens, but because of    // the way the model's expectations are approximated in the    // implementation, this is cancelled out when we compute the next    // iteration of a parameter, making the extra divisions wasteful.    params = new MutableContext[numPreds];    modelExpects = new MutableContext[numPreds];    observedExpects = new MutableContext[numPreds];    evalParams = new EvalParameters(params,0,correctionConstant,numOutcomes);    int[] activeOutcomes = new int[numOutcomes];    int[] outcomePattern;    int[] allOutcomesPattern= new int[numOutcomes];    for (int oi = 0; oi < numOutcomes; oi++) {      allOutcomesPattern[oi] = oi;    }    int numActiveOutcomes = 0;    for (int pi = 0; pi < numPreds; pi++) {      numActiveOutcomes = 0;      if (useSimpleSmoothing) {        numActiveOutcomes = numOutcomes;        outcomePattern = allOutcomesPattern;      }      else { //determine active outcomes        for (int oi = 0; oi < numOutcomes; oi++) {          if (predCount[pi][oi] > 0 && predicateCounts[pi] > cutoff) {            activeOutcomes[numActiveOutcomes] = oi;            numActiveOutcomes++;          }        }        if (numActiveOutcomes == numOutcomes) {          outcomePattern = allOutcomesPattern;        }        else {          outcomePattern = new int[numActiveOutcomes];          for (int 
aoi=0;aoi<numActiveOutcomes;aoi++) {            outcomePattern[aoi] = activeOutcomes[aoi];          }        }      }      params[pi] = new MutableContext(outcomePattern,new double[numActiveOutcomes]);      modelExpects[pi] = new MutableContext(outcomePattern,new double[numActiveOutcomes]);      observedExpects[pi] = new MutableContext(outcomePattern,new double[numActiveOutcomes]);      for (int aoi=0;aoi<numActiveOutcomes;aoi++) {        int oi = outcomePattern[aoi];        params[pi].setParameter(aoi, 0.0);        modelExpects[pi].setParameter(aoi, 0.0);        if (predCount[pi][oi] > 0) {            observedExpects[pi].setParameter(aoi, predCount[pi][oi]);        }        else if (useSimpleSmoothing) {           observedExpects[pi].setParameter(aoi,smoothingObservation);        }      }    }    // compute the expected value of correction    if (useSlackParameter) {      int cfvalSum = 0;      for (int ti = 0; ti < numUniqueEvents; ti++) {        for (int j = 0; j < contexts[ti].length; j++) {          int pi = contexts[ti][j];          if (!modelExpects[pi].contains(outcomeList[ti])) {            cfvalSum += numTimesEventsSeen[ti];          }        }        cfvalSum += (correctionConstant - contexts[ti].length) * numTimesEventsSeen[ti];      }      if (cfvalSum == 0) {        cfObservedExpect = Math.log(NEAR_ZERO); //nearly zero so log is defined      }      else {        cfObservedExpect = Math.log(cfvalSum);      }    }    predCount = null; // don't need it anymore    display("...done.\n");    modelDistribution = new double[numOutcomes];    numfeats = new int[numOutcomes];    /***************** Find the parameters ************************/    display("Computing model parameters...\n");    findParameters(iterations);    /*************** Create and return the model ******************/    return new GISModel(params, predLabels, outcomeLabels, correctionConstant, evalParams.correctionParam);  }  /* Estimate and return the model parameters. 
*/

  /**
   * Runs the GIS optimisation loop. The resulting weights are left in the {@code params} field;
   * nothing is returned.
   *
   * <p>At most {@code iterations} calls to {@link #nextIteration()} are made. From the second
   * iteration on, the loop stops early when the log-likelihood decreases (a warning is printed to
   * standard error) or when it improves by less than {@code LLThreshold}.
   *
   * <p>Afterwards the large training-only arrays ({@code observedExpects}, {@code modelExpects},
   * {@code numTimesEventsSeen}, {@code contexts}) are released. This method therefore cannot be run
   * a second time on the same trainer state.
   *
   * @param iterations maximum number of GIS iterations to perform
   */
  private void findParameters(int iterations) {
    double prevLL = 0.0;
    display("Performing " + iterations + " iterations.\n");
    for (int i = 1; i <= iterations; i++) {
      // Right-align the iteration number in a three-character column.
      if (i < 10) {
        display("  " + i + ":  ");
      } else if (i < 100) {
        display(" " + i + ":  ");
      } else {
        display(i + ":  ");
      }
      double currLL = nextIteration();
      // Convergence can only be judged once a previous log-likelihood exists.
      if (i > 1) {
        if (prevLL > currLL) {
          System.err.println("Model Diverging: loglikelihood decreased");
          break;
        }
        if (currLL - prevLL < LLThreshold) {
          break;
        }
      }
      prevLL = currLL;
    }
    // Release the big training-only structures now that they are no longer needed.
    observedExpects = null;
    modelExpects = null;
    numTimesEventsSeen = null;
    contexts = null;
  }

  /**
   * Computes the update for one parameter under a Gaussian prior, using Newton's method.
   * Modeled on the implementation in Zhang Le's maxent kit.
   *
   * <p>Starting from {@code x = 0}, it searches for a root of
   * {@code modelValue * exp(correctionConstant * x) + (param + x) / sigma - observedValue}.
   * The search runs for at most 50 Newton steps. It stops early when a step moves {@code x} by
   * less than 1e-6, or when the derivative is exactly zero. The value returned is the increment to
   * add to the current parameter, not the new parameter itself.
   *
   * <p>NOTE(review): the prior term divides by {@code sigma} directly. The field therefore
   * presumably holds the prior variance rather than the standard deviation; confirm against its
   * declaration.
   *
   * @param predicate index of the predicate whose parameter is being updated
   * @param oid index into that predicate's active-outcome list (not a global outcome id)
   * @param n unused by this implementation
   * @param correctionConstant the GIS correction constant
   * @return the increment to apply to the parameter
   */
  private double gaussianUpdate(int predicate, int oid, int n, double correctionConstant) {
    final int maxNewtonSteps = 50;
    final double tolerance = 0.000001;
    double param = params[predicate].getParameters()[oid];
    double modelValue = modelExpects[predicate].getParameters()[oid];
    double observedValue = observedExpects[predicate].getParameters()[oid];
    double x0 = 0.0;
    for (int i = 0; i < maxNewtonSteps; i++) {
      double tmp = modelValue * Math.exp(correctionConstant * x0);
      double f = tmp + (param + x0) / sigma - observedValue;
      double fp = tmp * correctionConstant + 1 / sigma;
      if (fp == 0) {
        break; // Newton step undefined; keep the current estimate
      }
      double x = x0 - f / fp;
      boolean converged = Math.abs(x - x0) < tolerance;
      x0 = x;
      if (converged) {
        break;
      }
    }
    return x0;
  }

  /**
   * Performs one iteration of GIS and returns the training-data log-likelihood.
   *
   * <p>For every unique event, the step:
   *
   * <ul>
   *   <li>evaluates the current model distribution (starting from the prior);
   *   <li>accumulates model expectations for the active outcomes of each context predicate whose
   *       count reaches {@code cutoff};
   *   <li>when the slack parameter is enabled, accumulates the correction-feature mass in
   *       {@code CFMOD}.
   * </ul>
   *
   * <p>All contributions are weighted by {@code numTimesEventsSeen}. The parameters are then
   * updated, either through {@link #gaussianUpdate} or through the plain GIS step
   * {@code log(observed) - log(model)}. The model expectations are reset to zero for the next
   * iteration.
   *
   * @return sum over events of {@code log p(outcome | context) * count} under the model that was
   *     evaluated before this iteration's update
   */
  private double nextIteration() {
    // compute contribution of p(a|b_i) for each feature and the new
    // correction parameter
    double loglikelihood = 0.0;
    CFMOD = 0.0;
    int numEvents = 0;
    int numCorrect = 0;
    for (int ei = 0; ei < numUniqueEvents; ei++) {
      if (values != null) {
        prior.logPrior(modelDistribution, contexts[ei], values[ei]);
        GISModel.eval(contexts[ei], values[ei], modelDistribution, evalParams);
      } else {
        prior.logPrior(modelDistribution, contexts[ei]);
        GISModel.eval(contexts[ei], modelDistribution, evalParams);
      }
      for (int j = 0; j < contexts[ei].length; j++) {
        int pi = contexts[ei][j];
        if (predicateCounts[pi] >= cutoff) {
          int[] activeOutcomes = modelExpects[pi].getOutcomes();
          for (int aoi = 0; aoi < activeOutcomes.length; aoi++) {
            int oi = activeOutcomes[aoi];
            if (values != null && values[ei] != null) {
              modelExpects[pi].updateParameter(aoi, modelDistribution[oi] * values[ei][j] * numTimesEventsSeen[ei]);
            } else {
              modelExpects[pi].updateParameter(aoi, modelDistribution[oi] * numTimesEventsSeen[ei]);
            }
          }
          if (useSlackParameter) {
            // Probability mass on outcomes this predicate has no parameter for feeds the slack feature.
            for (int oi = 0; oi < numOutcomes; oi++) {
              if (!modelExpects[pi].contains(oi)) {
                CFMOD += modelDistribution[oi] * numTimesEventsSeen[ei];
              }
            }
          }
        }
      }
      if (useSlackParameter) {
        CFMOD += (evalParams.correctionConstant - contexts[ei].length) * numTimesEventsSeen[ei];
      }

      loglikelihood += Math.log(modelDistribution[outcomeList[ei]]) * numTimesEventsSeen[ei];
      numEvents += numTimesEventsSeen[ei];
      if (printMessages) {
        // Training accuracy is only needed for the progress message.
        int max = 0;
        for (int oi = 1; oi < numOutcomes; oi++) {
          if (modelDistribution[oi] > modelDistribution[max]) {
            max = oi;
          }
        }
        if (max == outcomeList[ei]) {
          numCorrect += numTimesEventsSeen[ei];
        }
      }
    }
    display(".");
    // compute the new parameter values
    for (int pi = 0; pi < numPreds; pi++) {
      double[] observed = observedExpects[pi].getParameters();
      double[] model = modelExpects[pi].getParameters();
      int[] activeOutcomes = params[pi].getOutcomes();
      for (int aoi = 0; aoi < activeOutcomes.length; aoi++) {
        if (useGaussianSmoothing) {
          params[pi].updateParameter(aoi, gaussianUpdate(pi, aoi, numEvents, evalParams.correctionConstant));
        } else {
          if (model[aoi] == 0) {
            // aoi indexes the active-outcome list; map it back to the real outcome id for the label.
            // NOTE(review): Math.log(0) is -Infinity, so the update below pushes this parameter
            // to +Infinity; consider whether such parameters should be skipped instead.
            System.err.println("Model expects == 0 for " + predLabels[pi] + " " + outcomeLabels[activeOutcomes[aoi]]);
          }
          params[pi].updateParameter(aoi, (Math.log(observed[aoi]) - Math.log(model[aoi])));
        }
        modelExpects[pi].setParameter(aoi, 0.0); // re-initialize to 0.0's
      }
    }
    if (CFMOD > 0.0 && useSlackParameter) {
      evalParams.correctionParam += (cfObservedExpect - Math.log(CFMOD));
    }
    display(". loglikelihood=" + loglikelihood + "\t" + ((double) numCorrect / numEvents) + "\n");
    return loglikelihood;
  }

  /** Writes {@code s} to standard output, but only when {@code printMessages} is set. */
  private void display(String s) {
    if (printMessages) {
      System.out.print(s);
    }
  }
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -