📄 votedperceptron.java
字号:
package rmn;import java.util.*;public class VotedPerceptron { static public int MAX_ITER = 50; static public double LEARNING_RATE = 1.0 / 50; int m_T = MAX_ITER; FactorGraphModel m_fgm; Vector m_vecPF; public VotedPerceptron(FactorGraphModel fgm, Vector vecPF) { m_fgm = fgm; m_vecPF = vecPF; } public void setIterations(int T) { m_T = T; } public boolean train() { // allocate learning structures setLearning(true); // set all weights to one initWeights(); // init vector of potential factories used for previous factor graph Vector vecPF = new Vector(); int nTime = 0; // perceptron iterations for (int t = 0; t < m_T; t++) { System.out.println("VotedPerceptron iteration t = " + t); for (int i = 0; i < m_fgm.size(); i++) { System.out.println("\tFactorGraph fg = " + i); // set all counts to zero initCounts(vecPF); FactorGraph fg = m_fgm.getFactorGraph(i); // hide inferred value for hidden variables fg.setInfHidden(); fg.setMPE(); // compute true/inf counts and factories to be updated vecPF = fg.computeCounts(); // perceptron delta rule updateWeights(vecPF, nTime); nTime++; } } // add last weights into average averageWeights(m_vecPF, nTime); // set weights to voted version setWeightsAverage(m_fgm.size() * m_T); // deallocate learning structures setLearning(false); return true; } public void setLearning(boolean bLearning) { for (int i = 0; i < m_vecPF.size(); i++) { PotentialFactory pf = (PotentialFactory) m_vecPF.get(i); pf.setLearning(bLearning); } } protected void initWeights() { for (int i = 0; i < m_vecPF.size(); i++) { PotentialFactory pf = (PotentialFactory) m_vecPF.get(i); pf.initWeights(); } } protected void initCounts(Vector vecPF) { for (int i = 0; i < vecPF.size(); i++) { PotentialFactory pf = (PotentialFactory) vecPF.get(i); pf.initCounts(); } } protected void updateWeights(Vector vecPF, int nTime) { for (int i = 0; i < vecPF.size(); i++) { PotentialFactory pf = (PotentialFactory) vecPF.get(i); pf.updateWeights(nTime, LEARNING_RATE); } } protected void averageWeights(Vector vecPF, int nTime) { for (int i = 0; i < vecPF.size(); i++) { PotentialFactory pf = (PotentialFactory) vecPF.get(i); pf.averageWeights(nTime); } } protected void setWeightsAverage(int n) { for (int i = 0; i < m_vecPF.size(); i++) { PotentialFactory pf = (PotentialFactory) m_vecPF.get(i); pf.setWeightsAverage(n); } }}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -