⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 parameters.java

📁 MSTParser是以最大生成树理论为基础的判别式依存句法分析器。它将一棵依存树的得分看作是所有依存关系的得分的总和
💻 JAVA
字号:
package mstparser;import java.io.*;import java.util.ArrayList;import gnu.trove.*;public class Parameters {    public double[] parameters;    public double[] total;    public String lossType = "punc";    public Parameters(int size) { 		parameters = new double[size];	total = new double[size];	for(int i = 0; i < parameters.length; i++) {	    parameters[i] = 0.0;	    total[i] = 0.0;	}	lossType = "punc";    }    public void setLoss(String lt) {	lossType = lt;    }    public void averageParams(double avVal) {	for(int j = 0; j < total.length; j++)	    total[j] *= 1.0/((double)avVal);			parameters = total;    }	    public void updateParamsMIRA(DependencyInstance inst, Object[][] d, double upd) {			String actParseTree = inst.actParseTree;	FeatureVector actFV = inst.fv;	int K = 0;	for(int i = 0; i < d.length && d[i][0] != null; i++) {	    K = i+1;	}	double[] b = new double[K];	double[] lam_dist = new double[K];	FeatureVector[] dist = new FeatureVector[K];	for(int k = 0; k < K; k++) {	    lam_dist[k] = getScore(actFV)		- getScore((FeatureVector)d[k][0]);	    b[k] = (double)numErrors(inst,(String)d[k][1],actParseTree);	    b[k] -= lam_dist[k];	    dist[k] = FeatureVector.getDistVector(actFV,(FeatureVector)d[k][0]);	}	double[] alpha = hildreth(dist,b);			FeatureVector fv  = null;	int res = 0;	for(int k = 0; k < K; k++) {	    fv = dist[k];					    for(FeatureVector curr = fv; curr.index >= 0; curr = curr.next) {		if(curr.index < 0)		    continue;		parameters[curr.index] += alpha[k]*curr.value;		total[curr.index] += upd*alpha[k]*curr.value;	    }	}    }    public double getScore(FeatureVector fv) {	double score = 0.0;	for(FeatureVector curr = fv; curr.index >= 0; curr = curr.next) {	    score += parameters[curr.index]*curr.value;	}			return score;    }    private double[] hildreth(FeatureVector[] a, double[] b) {	int i;	int max_iter = 10000;	double eps = 0.00000001;	double zero = 0.000000000001;			double[] alpha = new double[b.length];	double[] F = new double[b.length];	double[] 
kkt = new double[b.length];	double max_kkt = Double.NEGATIVE_INFINITY;	int K = a.length;			double[][] A = new double[K][K];	boolean[] is_computed = new boolean[K];	for(i = 0; i < K; i++) {	    A[i][i] = FeatureVector.dotProduct(a[i],a[i]);	    is_computed[i] = false;	}					int max_kkt_i = -1;			for(i = 0; i < F.length; i++) {	    F[i] = b[i];	    kkt[i] = F[i];	    if(kkt[i] > max_kkt) { max_kkt = kkt[i]; max_kkt_i = i; }	}	int iter = 0;	double diff_alpha;	double try_alpha;	double add_alpha;		while(max_kkt >= eps && iter < max_iter) {				    diff_alpha = A[max_kkt_i][max_kkt_i] <= zero ? 0.0 : F[max_kkt_i]/A[max_kkt_i][max_kkt_i];	    try_alpha = alpha[max_kkt_i] + diff_alpha;	    add_alpha = 0.0;	    if(try_alpha < 0.0)		add_alpha = -1.0 * alpha[max_kkt_i];	    else		add_alpha = diff_alpha;	    alpha[max_kkt_i] = alpha[max_kkt_i] + add_alpha;	    if (!is_computed[max_kkt_i]) {		for(i = 0; i < K; i++) {		    A[i][max_kkt_i] = FeatureVector.dotProduct(a[i],a[max_kkt_i]); // for version 1		    is_computed[max_kkt_i] = true;		}	    }	    for(i = 0; i < F.length; i++) {		F[i] -= add_alpha * A[i][max_kkt_i];		kkt[i] = F[i];		if(alpha[i] > zero)		    kkt[i] = Math.abs(F[i]);	    }	    max_kkt = Double.NEGATIVE_INFINITY;	    max_kkt_i = -1;	    for(i = 0; i < F.length; i++)		if(kkt[i] > max_kkt) { max_kkt = kkt[i]; max_kkt_i = i; }	    iter++;	}	return alpha;    }        public double numErrors(DependencyInstance inst, String pred, String act) {	if(lossType.equals("nopunc"))	    return numErrorsDepNoPunc(inst,pred,act)+numErrorsLabelNoPunc(inst,pred,act);	return numErrorsDep(inst,pred,act)+numErrorsLabel(inst,pred,act);    }    public double numErrorsDep(DependencyInstance inst, String pred, String act) {			String[] act_spans = act.split(" ");	String[] pred_spans = pred.split(" ");	int correct = 0;			for(int i = 0; i < pred_spans.length; i++) {	    String p = pred_spans[i].split(":")[0]; String a = act_spans[i].split(":")[0];	    if(p.equals(a)) {		correct++;	    }	}			
return ((double)act_spans.length - correct);		    }	    public double numErrorsLabel(DependencyInstance inst, String pred, String act) {			String[] act_spans = act.split(" ");	String[] pred_spans = pred.split(" ");	int correct = 0;			for(int i = 0; i < pred_spans.length; i++) {	    String p = pred_spans[i].split(":")[1]; String a = act_spans[i].split(":")[1];	    if(p.equals(a)) {		correct++;	    }	}			return ((double)act_spans.length - correct);		    }	    public double numErrorsDepNoPunc(DependencyInstance inst, String pred, String act) {			String[] act_spans = act.split(" ");	String[] pred_spans = pred.split(" ");		String[] pos = inst.pos;		int correct = 0;	int numPunc = 0;	for(int i = 0; i < pred_spans.length; i++) {	    String p = pred_spans[i].split(":")[0]; String a = act_spans[i].split(":")[0];	    if(pos[i+1].matches("[,:.'`]+")) {		numPunc++;		continue;	    }	    if(p.equals(a)) {		correct++;	    }	}			return ((double)act_spans.length - numPunc - correct);		    }	    public double numErrorsLabelNoPunc(DependencyInstance inst, String pred, String act) {			String[] act_spans = act.split(" ");	String[] pred_spans = pred.split(" ");	String[] pos = inst.pos;		int correct = 0;	int numPunc = 0;	for(int i = 0; i < pred_spans.length; i++) {	    String p = pred_spans[i].split(":")[1]; String a = act_spans[i].split(":")[1];	    if(pos[i+1].matches("[,:.'`]+")) {		numPunc++;		continue;	    }	    if(p.equals(a)) {		correct++;	    }	}			return ((double)act_spans.length - numPunc - correct);		    }	}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -