
crf4.java

MALLET is an open-source project in the fields of natural language processing and machine learning.
Language: Java
Excerpt: page 1 of 5
      assert defaultWeights != null;
      assert featureSelections != null;
      assert weightsFrozen != null;
      int n = weights.length;
      assert defaultWeights.length == n;
      assert featureSelections.length == n;
      assert weightsFrozen.length == n;
    }
  }

  public int numStates () { return states.size(); }

  public Transducer.State getState (int index) {
    return (Transducer.State) states.get (index);
  }

  public Iterator initialStateIterator () {
    return initialStates.iterator ();
  }

  public boolean isTrainable () { return trainable; }

  public void setTrainable (boolean f)
  {
    if (f != trainable) {
      if (f) {
        constraints = new SparseVector[weights.length];
        expectations = new SparseVector[weights.length];
        defaultConstraints = new double[weights.length];
        defaultExpectations = new double[weights.length];
        for (int i = 0; i < weights.length; i++) {
          constraints[i] = (SparseVector) weights[i].cloneMatrixZeroed ();
          expectations[i] = (SparseVector) weights[i].cloneMatrixZeroed ();
        }
      } else {
        constraints = expectations = null;
        defaultConstraints = defaultExpectations = null;
      }
      for (int i = 0; i < numStates(); i++)
        ((State)getState(i)).setTrainable(f);
      trainable = f;
    }
  }

  public double getParametersAbsNorm ()
  {
    double ret = 0;
    for (int i = 0; i < numStates(); i++) {
      State s = (State) getState (i);
      ret += Math.abs (s.initialCost);
      ret += Math.abs (s.finalCost);
    }
    for (int i = 0; i < weights.length; i++) {
      ret += Math.abs (defaultWeights[i]);
      ret += weights[i].absNorm();
    }
    return ret;
  }

  /** Only sets the parameter from the first group of parameters. */
  public void setParameter (int sourceStateIndex, int destStateIndex, int featureIndex, double value)
  {
    cachedValueStale = cachedGradientStale = true;
    State source = (State) getState (sourceStateIndex);
    State dest = (State) getState (destStateIndex);
    int rowIndex;
    for (rowIndex = 0; rowIndex < source.destinationNames.length; rowIndex++)
      if (source.destinationNames[rowIndex].equals (dest.name))
        break;
    if (rowIndex == source.destinationNames.length)
      throw new IllegalArgumentException ("No transition from state "+sourceStateIndex+" to state "+destStateIndex+".");
    int weightsIndex = source.weightsIndices[rowIndex][0];
    if (featureIndex < 0)
      defaultWeights[weightsIndex] = value;
    else {
      weights[weightsIndex].setValue (featureIndex, value);
    }
    someTrainingDone = true;
  }

  /** Only gets the parameter from the first group of parameters. */
  public double getParameter (int sourceStateIndex, int destStateIndex, int featureIndex, double value)
  {
    State source = (State) getState (sourceStateIndex);
    State dest = (State) getState (destStateIndex);
    int rowIndex;
    for (rowIndex = 0; rowIndex < source.destinationNames.length; rowIndex++)
      if (source.destinationNames[rowIndex].equals (dest.name))
        break;
    if (rowIndex == source.destinationNames.length)
      throw new IllegalArgumentException ("No transition from state "+sourceStateIndex+" to state "+destStateIndex+".");
    int weightsIndex = source.weightsIndices[rowIndex][0];
    if (featureIndex < 0)
      return defaultWeights[weightsIndex];
    else
      return weights[weightsIndex].value (featureIndex);
  }

  public void reset ()
  {
    throw new UnsupportedOperationException ("Not used in CRFs");
  }

  public void estimate ()
  {
    if (!trainable)
      throw new IllegalStateException ("This transducer not currently trainable.");
    // xxx Put stuff in here.
    throw new UnsupportedOperationException ("Not yet implemented.  Never?");
  }

  //xxx experimental
  public Sequence transduce (Object unpipedInput)
  {
    Instance inst = new Instance (unpipedInput, null, null, null, inputPipe);
    return transduce ((FeatureVectorSequence) inst.getData ());
  }

  public Sequence transduce (Sequence input)
  {
    if (!(input instanceof FeatureVectorSequence))
      throw new IllegalArgumentException ("CRF4.transduce requires FeatureVectorSequence.  This may be relaxed later...");
    switch (transductionType) {
      case VITERBI:
        ViterbiPath lattice = viterbiPath (input);
        return lattice.output ();
      // CPAL - added viterbi "beam search"
      case VITERBI_FBEAM:
        ViterbiPathBeam lattice2 = viterbiPathBeam (input);
        return lattice2.output ();
      // CPAL - added viterbi backward beam search
      case VITERBI_BBEAM:
        ViterbiPathBeamB lattice3 = viterbiPathBeamB (input);
        return lattice3.output ();
      // CPAL - added viterbi forward backward beam
      //      - we may call this constrained "max field" or something similar later
      case VITERBI_FBBEAM:
        ViterbiPathBeamFB lattice4 = viterbiPathBeamFB (input);
        return lattice4.output ();
      // CPAL - added an adaptive viterbi "beam search"
      case VITERBI_FBEAMKL:
        ViterbiPathBeamKL lattice5 = viterbiPathBeamKL (input);
        return lattice5.output ();
      default:
        throw new IllegalStateException ("Unknown CRF4 transduction type "+transductionType);
    }
  }

  public void print ()
  {
    print (new PrintWriter (new OutputStreamWriter (System.out), true));
  }

  // yyy
  public void print (PrintWriter out)
  {
    out.println ("*** CRF STATES ***");
    for (int i = 0; i < numStates (); i++) {
      State s = (State) getState (i);
      out.print ("STATE NAME=\"");
      out.print (s.name); out.print ("\" ("); out.print (s.destinations.length); out.print (" outgoing transitions)\n");
      out.print ("  "); out.print ("initialCost = "); out.print (s.initialCost); out.print ('\n');
      out.print ("  "); out.print ("finalCost = "); out.print (s.finalCost); out.print ('\n');
      out.println ("  transitions:");
      for (int j = 0; j < s.destinations.length; j++) {
        out.print ("    "); out.print (s.name); out.print (" -> "); out.println (s.getDestinationState (j).getName ());
        for (int k = 0; k < s.weightsIndices[j].length; k++) {
          out.print ("        WEIGHTS = \"");
          int widx = s.weightsIndices[j][k];
          out.print (weightAlphabet.lookupObject (widx).toString ());
          out.print ("\"\n");
        }
      }
      out.println ();
    }
    out.println ("\n\n\n*** CRF WEIGHTS ***");
    for (int widx = 0; widx < weights.length; widx++) {
      out.println ("WEIGHTS NAME = " + weightAlphabet.lookupObject (widx));
      out.print (": <DEFAULT_FEATURE> = "); out.print (defaultWeights[widx]); out.print ('\n');
      SparseVector transitionWeights = weights[widx];
      if (transitionWeights.numLocations () == 0)
        continue;
      RankedFeatureVector rfv = new RankedFeatureVector (inputAlphabet, transitionWeights);
      for (int m = 0; m < rfv.numLocations (); m++) {
        double v = rfv.getValueAtRank (m);
        int index = rfv.indexAtLocation (rfv.getIndexAtRank (m));
        Object feature = inputAlphabet.lookupObject (index);
        if (v != 0) {
          out.print (": "); out.print (feature); out.print (" = "); out.println (v);
        }
      }
    }
    out.flush ();
  }

  // Java question:
  // If I make a non-static inner class CRF.Trainer,
  // can that class be subclassed in another .java file,
  // and can that subclass still have access to all the CRF's
  // instance variables?
  // ANSWER: Yes and yes, but you have to use special syntax in the subclass ctor (see mallet-dev archive) -cas

  public boolean train (InstanceList ilist)
  {
    return train (ilist, (InstanceList)null, (InstanceList)null);
  }

  public boolean train (InstanceList ilist, InstanceList validation, InstanceList testing)
  {
    return train (ilist, validation, testing, (TransducerEvaluator)null);
  }

  public boolean train (InstanceList ilist, InstanceList validation, InstanceList testing,
                        TransducerEvaluator eval)
  {
    return train (ilist, validation, testing, eval, 9999);
  }

  public boolean train (InstanceList ilist, InstanceList validation, InstanceList testing,
                        TransducerEvaluator eval, int numIterations)
  {
    if (numIterations <= 0)
      return false;
    assert (ilist.size() > 0);
    if (useSparseWeights) {
      setWeightsDimensionAsIn (ilist);
    } else {
      setWeightsDimensionDensely ();
    }

    MaximizableCRF mc = new MaximizableCRF (ilist, this);
    //Maximizer.ByGradient minimizer = new ConjugateGradient (0.001);
    Maximizer.ByGradient maximizer = new LimitedMemoryBFGS();
    int i;
    boolean continueTraining = true;
    boolean converged = false;
    logger.info ("CRF about to train with "+numIterations+" iterations");
    for (i = 0; i < numIterations; i++) {
      try {
        // CPAL - added this to alter forward backward beam parameters based on iteration
        setCurIter(i);  // CPAL - this resets the tctIter as well
        converged = maximizer.maximize (mc, 1);
        logger.info ("CRF took " + tctIter + " intermediate iterations");
        //if (i!=0) {
        //    if(tctIter>1) {
        //        // increase the beam size
        //    }
        //}
        // CPAL - done
        logger.info ("CRF finished one iteration of maximizer, i="+i);
      } catch (IllegalArgumentException e) {
        e.printStackTrace();
        logger.info ("Catching exception; saying converged.");
        converged = true;
      }
      if (eval != null) {
        continueTraining = eval.evaluate (this, (converged || i == numIterations-1), i,
                                          converged, mc.getValue(), ilist, validation, testing);
        if (!continueTraining)
          break;
      }
      if (converged) {
        logger.info ("CRF training has converged, i="+i);
        break;
      }
    }
    logger.info ("About to setTrainable(false)");
    // Free the memory of the expectations and constraints
    setTrainable (false);
    logger.info ("Done setTrainable(false)");
    return converged;
  }

  public boolean train (InstanceList training, InstanceList validation, InstanceList testing,
                        TransducerEvaluator eval, int numIterations,
                        int numIterationsPerProportion,
                        double[] trainingProportions)
  {
    int trainingIteration = 0;
    for (int i = 0; i < trainingProportions.length; i++) {
      // Train the CRF
      InstanceList theTrainingData = training;
      if (trainingProportions != null && i < trainingProportions.length) {
        logger.info ("Training on "+trainingProportions[i]+"% of the data this round.");
        InstanceList[] sampledTrainingData = training.split (new Random(1),
                                                             new double[] {trainingProportions[i], 1-trainingProportions[i]});
        theTrainingData = sampledTrainingData[0];
      }
      boolean converged = this.train (theTrainingData, validation, testing, eval, numIterationsPerProportion);
      trainingIteration += numIterationsPerProportion;
    }
    logger.info ("Training on 100% of the data this round, for "+
                 (numIterations-trainingIteration)+" iterations.");
    return this.train (training, validation, testing,
                       eval, numIterations - trainingIteration);
  }

  public boolean trainWithFeatureInduction (InstanceList trainingData,
                                            InstanceList validationData, InstanceList testingData,
                                            TransducerEvaluator eval, int numIterations,
                                            int numIterationsBetweenFeatureInductions,
                                            int numFeatureInductions,
                                            int numFeaturesPerFeatureInduction,
                                            double trueLabelProbThreshold,
                                            boolean clusteredFeatureInduction,
                                            double[] trainingProportions)
  {
    return trainWithFeatureInduction (trainingData, validationData, testingData,
                                      eval, numIterations, numIterationsBetweenFeatureInductions,
                                      numFeatureInductions, numFeaturesPerFeatureInduction,
                                      trueLabelProbThreshold, clusteredFeatureInduction,
                                      trainingProportions, "exp");
  }

  public boolean trainWithFeatureInduction (InstanceList trainingData,
                                            InstanceList validationData, InstanceList testingData,
                                            TransducerEvaluator eval, int numIterations,
                                            int numIterationsBetweenFeatureInductions,
                                            int numFeatureInductions,
                                            int numFeaturesPerFeatureInduction,
                                            double trueLabelProbThreshold,
                                            boolean clusteredFeatureInduction,
                                            double[] trainingProportions,
                                            String gainName)
  {
    int trainingIteration = 0;
