⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 crfbygisupdate.java

📁 常用机器学习算法,java编写源代码,内含常用分类算法,包括说明文档
💻 JAVA
📖 第 1 页 / 共 5 页
字号:
					initializingInfiniteCosts = true;				}				// Clear the sufficient statistics that we are about to fill				for (int i = 0; i < numStates(); i++) {					State s = (State)getState(i);					s.initialExpectation = 0;					s.finalExpectation = 0;				}				for (int i = 0; i < weights.length; i++) {					expectations[i].setAll (0.0);					defaultExpectations[i] = 0.0;				}				// Calculate the cost of each instance, and also fill in expectations				double unlabeledCost, labeledCost, cost;				for (int ii = 0; ii < trainingSet.size(); ii++) {					Instance instance = trainingSet.getInstance(ii);					FeatureVectorSequence input = (FeatureVectorSequence) instance.getData();					FeatureSequence output = (FeatureSequence) instance.getTarget();					labeledCost = forwardBackward (input, output, false).getCost();					if (Double.isInfinite (labeledCost))						logger.warning (instance.getName().toString() + " has infinite labeled cost.\n"														+(instance.getSource() != null ? instance.getSource() : ""));					unlabeledCost = forwardBackward (input, true).getCost ();					if (Double.isInfinite (unlabeledCost))						logger.warning (instance.getName().toString() + " has infinite unlabeled cost.\n"														+(instance.getSource() != null ? 
instance.getSource() : ""));					// Here cost is -log(conditional probability correct label sequence)					cost = labeledCost - unlabeledCost;					//System.out.println ("Instance "+ii+" CRF.MinimizableCRF.getCost = "+cost);					if (Double.isInfinite(cost)) {						logger.warning (instance.getName().toString() + " has infinite cost; skipping.");						if (initializingInfiniteCosts)							infiniteCosts.set (ii);						else if (!infiniteCosts.get(ii))							throw new IllegalStateException ("Instance i used to have non-infinite cost, "																							 +"but now it has infinite cost.");						continue;					} else {						cachedCost += cost;					}				}								// Incorporate prior on parameters				if (usingHyperbolicPrior) {					// Hyperbolic prior					for (int i = 0; i < numStates(); i++) {						State s = (State) getState (i);						if (!Double.isInfinite(s.initialCost))							cachedCost += (hyperbolicPriorSlope / hyperbolicPriorSharpness														 * Math.log (Maths.cosh (hyperbolicPriorSharpness * -s.initialCost)));						if (!Double.isInfinite(s.finalCost))							cachedCost += (hyperbolicPriorSlope / hyperbolicPriorSharpness														 * Math.log (Maths.cosh (hyperbolicPriorSharpness * -s.finalCost)));					}					for (int i = 0; i < weights.length; i++) {						cachedCost += (hyperbolicPriorSlope / hyperbolicPriorSharpness													 * Math.log (Maths.cosh (hyperbolicPriorSharpness * defaultWeights[i])));						for (int j = 0; j < weights[i].numLocations(); j++) {							double w = weights[i].valueAtLocation(j);							if (!Double.isInfinite(w))								cachedCost += (hyperbolicPriorSlope / hyperbolicPriorSharpness															 * Math.log (Maths.cosh (hyperbolicPriorSharpness * w)));						}					}				} else {					// Gaussian prior					double priorDenom = 2 * gaussianPriorVariance;					for (int i = 0; i < numStates(); i++) {						State s = (State) getState (i);						if (!Double.isInfinite(s.initialCost))							cachedCost += s.initialCost * s.initialCost / priorDenom;	
					if (!Double.isInfinite(s.finalCost))							cachedCost += s.finalCost * s.finalCost / priorDenom;					}					for (int i = 0; i < weights.length; i++) {						if (!Double.isInfinite(defaultWeights[i]))							cachedCost += defaultWeights[i] * defaultWeights[i] / priorDenom;						for (int j = 0; j < weights[i].numLocations(); j++) {							double w = weights[i].valueAtLocation (j);							if (!Double.isInfinite(w))								cachedCost += w * w / priorDenom;						}					}				}				cachedCostStale = false;				System.out.println ("getCost() (-loglikelihood) = "+cachedCost);				logger.fine ("getCost() (-loglikelihood) = "+cachedCost);				//crf.print();				long endingTime = System.currentTimeMillis();				System.out.println ("Inference milliseconds = "+(endingTime - startingTime));			}			return cachedCost;		}		private boolean checkForNaN ()		{			for (int i = 0; i < weights.length; i++) {				assert (!weights[i].isNaN());				assert (constraints == null || !constraints[i].isNaN());				assert (expectations == null || !expectations[i].isNaN());				assert (!Double.isNaN(defaultExpectations[i]));				assert (!Double.isNaN(defaultConstraints[i]));			}			for (int i = 0; i < numStates(); i++) {				State s = (State) getState (i);				assert (!Double.isNaN (s.initialExpectation));				assert (!Double.isNaN (s.initialConstraint));				assert (!Double.isNaN (s.initialCost));				assert (!Double.isNaN (s.finalExpectation));				assert (!Double.isNaN (s.finalConstraint));				assert (!Double.isNaN (s.finalCost));			}			return true;		}			/**		 * Returns a GIS update for the current parameter setting specified by params		 *		 * @param params feature weights of current model		 * @param updates Matrix Object in which to store the updates		 */		public void getGISUpdate (Matrix params, Matrix updates)		{			assert (params instanceof DenseVector);			assert (updates instanceof DenseVector);						// get max empirical feature count and store in maxCount.			
if(maxCount <= 1.0) {				int pi = 0;				for(int i = 0; i < numStates(); i++) {					pi++;					pi++;				}				for (int i = 0; i < weights.length; i++) {					if(defaultConstraints[i] > maxCount)						maxCount = defaultConstraints[i];					pi++;					int nl = weights[i].numLocations();					for (int j = 0; j < nl; j++) {						if(constraints[i].valueAtLocation(j) > maxCount)							maxCount = constraints[i].valueAtLocation(j);						pi++;					}				}			}			computeGISUpdate((DenseVector)params,(DenseVector)updates,gaussianPriorVariance,maxCount);					}		private void computeGISUpdate(DenseVector lambda, DenseVector updates, 																 double sigma, double s) {			double p = 1/(s*sigma*sigma);			double inv_s = 1/(sigma*sigma);						int pi = 0;			for(int i = 0; i < numStates(); i++) {				updates.setValue(pi,gis_solver(lambda.value(pi)*inv_s, p, 1.0, 1.0)/s);				pi++;				updates.setValue(pi,gis_solver(lambda.value(pi)*inv_s, p, 1.0, 1.0)/s);				pi++;			}			for (int i = 0; i < weights.length; i++) {				updates.setValue(pi,gis_solver(lambda.value(pi)*inv_s, p, defaultConstraints[i], defaultExpectations[i])/s);				pi++;				int nl = weights[i].numLocations();				for (int j = 0; j < nl; j++) {					updates.setValue(pi,gis_solver(lambda.value(pi)*inv_s, p, constraints[i].valueAtLocation(j), expectations[i].valueAtLocation(j))/s);					pi++;				}			}		}			private double gis_solver(double m, double p, double e1, double e2) {			int iter=0;			double x = 2;			double new_x = 1;			boolean find = false;			while(x>0 && iter <5000) {				x= new_x;				new_x = x*(p-m-p*Math.log(x) +e1)/(e2*x+p);				if(Math.abs(new_x-x)<=(0.001*x) ) {					return Math.log(x);				}				iter++;			}			return 0;		}		//Serialization of MinimizableCRF				private static final long serialVersionUID = 1;		private static final int CURRENT_SERIAL_VERSION = 0;			private void writeObject (ObjectOutputStream out) throws IOException {			out.writeInt (CURRENT_SERIAL_VERSION);			out.writeObject(trainingSet);			
out.writeDouble(cachedCost);			out.writeObject(cachedGISUpdate);			out.writeObject(infiniteCosts);			out.writeInt(numParameters);			out.writeObject(crf);		}				private void readObject (ObjectInputStream in) throws IOException, ClassNotFoundException {			int version = in.readInt ();			trainingSet = (InstanceList) in.readObject();			cachedCost = in.readDouble();			cachedGISUpdate = (DenseVector) in.readObject();			infiniteCosts = (BitSet) in.readObject();			numParameters = in.readInt();			crf = (CRFByGISUpdate)in.readObject();		}	}	public static class State extends Transducer.State implements Serializable	{		// Parameters indexed by destination state, feature index		double initialConstraint, initialExpectation;		double finalConstraint, finalExpectation;		String name;		int index;		String[] destinationNames;		State[] destinations;		int[][] weightsIndices;								// contains indices into CRF.weights[],		String[] labels;		CRFByGISUpdate crf;				// No arg constructor so serialization works				protected State() {			super ();		}						protected State (String name, int index,										 double initialCost, double finalCost,										 String[] destinationNames,										 String[] labelNames,										 String[][] weightNames,										 CRFByGISUpdate crf)		{			super ();			assert (destinationNames.length == labelNames.length);			assert (destinationNames.length == weightNames.length);			this.name = name;			this.index = index;			this.initialCost = initialCost;			this.finalCost = finalCost;			this.destinationNames = new String[destinationNames.length];			this.destinations = new State[labelNames.length];			this.weightsIndices = new int[labelNames.length][];			this.labels = new String[labelNames.length];			this.crf = crf;			for (int i = 0; i < labelNames.length; i++) {				// Make sure this label appears in our output Alphabet				crf.outputAlphabet.lookupIndex (labelNames[i]);				this.destinationNames[i] = destinationNames[i];				this.labels[i] = labelNames[i];				
this.weightsIndices[i] = new int[weightNames[i].length];				for (int j = 0; j < weightNames[i].length; j++)					this.weightsIndices[i][j] = crf.getWeightsIndex (weightNames[i][j]);			}			crf.cachedCostStale = crf.cachedGISUpdateStale = true;		}						public void print ()		{			System.out.println ("State #"+index+" \""+name+"\"");			System.out.println ("initialCost="+initialCost+", finalCost="+finalCost);			System.out.println ("#destinations="+destinations.length);			for (int i = 0; i < destinations.length; i++)				System.out.println ("-> "+destinationNames[i]);		}				public State getDestinationState (int index)		{			State ret;			if ((ret = destinations[index]) == null) {				ret = destinations[index] = (State) crf.name2state.get (destinationNames[index]);				//if (ret == null) System.out.println ("this.name="+this.name+" index="+index+" destinationNames[index]="+destinationNames[index]+" name2state.size()="+ crf.name2state.size());				assert (ret != null) : index;			}			return ret;		}				public void setTrainable (boolean f)		{			if (f) {				initialConstraint = finalConstraint = 0;				initialExpectation = finalExpectation = 0;			}		}				public Transducer.TransitionIterator transitionIterator (			Sequence inputSequence, int inputPosition,			Sequence outputSequence, int outputPosition)		{			if (inputPosition < 0 || outputPosition < 0)				throw new UnsupportedOperationException ("Epsilon transitions not implemented.");			if (inputSequence == null)				throw new UnsupportedOperationException ("CRFs are not generative models; must have an input sequence.");			return new TransitionIterator (				this, (FeatureVectorSequence)inputSequence, inputPosition,				(outputSequence == null ? 
null : (String)outputSequence.get(outputPosition)), crf);
		}

		/** Returns this state's name. */
		public String getName ()
		{
			return name;
		}

		/** Returns this state's index. */
		public int getIndex ()
		{
			return index;
		}

		/**
		 * Adds <code>count</code> to the initial constraint while the owning CRF
		 * is gathering constraints, and to the initial expectation otherwise.
		 */
		public void incrementInitialCount (double count)
		{
			assert (crf.trainable || crf.gatheringWeightsPresent);
			if (crf.gatheringConstraints) {
				initialConstraint += count;
			} else {
				initialExpectation += count;
			}
		}

		/**
		 * Adds <code>count</code> to the final constraint while the owning CRF
		 * is gathering constraints, and to the final expectation otherwise.
		 */
		public void incrementFinalCount (double count)
		{
			assert (crf.trainable || crf.gatheringWeightsPresent);
			if (crf.gatheringConstraints) {
				finalConstraint += count;
			} else {
				finalExpectation += count;
			}
		}

		// Serialization
		// For  class State

		private static final long serialVersionUID = 1;
		private static final int CURRENT_SERIAL_VERSION = 0;
		private static final int NULL_INTEGER = -1;

		private void writeObject (ObjectOutputStream out) throws IOException {
			int i, size;
			out.writeInt (CURRENT_SERIAL_VERSION);
			out.writeDouble(initialConstraint);
			out.writeDouble(initialExpectation);

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -