⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 plsfilter.java

📁 代码是一个分类器的实现，其中使用了部分 Weka 的源代码。可以将项目导入 Eclipse 运行。
💻 JAVA
📖 第 1 页 / 共 3 页
字号:
   *    * @param v		the vector to store in the matrix   * @param m		the receiving matrix   * @param columnIndex	the column to store the values in   */  protected void setVector(Matrix v, Matrix m, int columnIndex) {    m.setMatrix(0, m.getRowDimension() - 1, columnIndex, columnIndex, v);  }    /**   * returns the (column) vector of the matrix at the specified index   *    * @param m		the matrix to work on   * @param columnIndex	the column to get the values from   * @return		the column vector   */  protected Matrix getVector(Matrix m, int columnIndex) {    return m.getMatrix(0, m.getRowDimension() - 1, columnIndex, columnIndex);  }  /**   * determines the dominant eigenvector for the given matrix and returns it   *    * @param m		the matrix to determine the dominant eigenvector for   * @return		the dominant eigenvector   */  protected Matrix getDominantEigenVector(Matrix m) {    EigenvalueDecomposition	eigendecomp;    double[]			eigenvalues;    int				index;    Matrix			result;        eigendecomp = m.eig();    eigenvalues = eigendecomp.getRealEigenvalues();    index       = Utils.maxIndex(eigenvalues);    result	= columnAsVector(eigendecomp.getV(), index);        return result;  }    /**   * normalizes the given vector (inplace)    *    * @param v		the vector to normalize   */  protected void normalizeVector(Matrix v) {    double	sum;    int		i;        // determine length    sum = 0;    for (i = 0; i < v.getRowDimension(); i++)      sum += v.get(i, 0) * v.get(i, 0);    sum = StrictMath.sqrt(sum);        // normalize content    for (i = 0; i < v.getRowDimension(); i++)      v.set(i, 0, v.get(i, 0) / sum);  }  /**   * processes the instances using the PLS1 algorithm   *   * @param instances   the data to process   * @return            the modified data   * @throws Exception  in case the processing goes wrong   */  protected Instances processPLS1(Instances instances) throws Exception {    Matrix	X, X_trans, x;    Matrix	y;    Matrix	W, w;    Matrix	T, t, t_trans;    
Matrix	P, p, p_trans;    double	b;    Matrix	b_hat;    int		i;    int		j;    Matrix	X_new;    Matrix	tmp;    Instances	result;    Instances	tmpInst;    // initialization    if (!isFirstBatchDone()) {      // split up data      X       = getX(instances);      y       = getY(instances);      X_trans = X.transpose();            // init      W     = new Matrix(instances.numAttributes() - 1, getNumComponents());      P     = new Matrix(instances.numAttributes() - 1, getNumComponents());      T     = new Matrix(instances.numInstances(), getNumComponents());      b_hat = new Matrix(getNumComponents(), 1);            for (j = 0; j < getNumComponents(); j++) {	// 1. step: wj	w = X_trans.times(y);	normalizeVector(w);	setVector(w, W, j);		// 2. step: tj	t       = X.times(w);	t_trans = t.transpose();	setVector(t, T, j);		// 3. step: ^bj	b = t_trans.times(y).get(0, 0) / t_trans.times(t).get(0, 0);	b_hat.set(j, 0, b);		// 4. step: pj	p       = X_trans.times(t).times((double) 1 / t_trans.times(t).get(0, 0));	p_trans = p.transpose();	setVector(p, P, j);		// 5. 
step: Xj+1	X = X.minus(t.times(p_trans));	y = y.minus(t.times(b));      }            // W*(P^T*W)^-1      tmp = W.times(((P.transpose()).times(W)).inverse());            // X_new = X*W*(P^T*W)^-1      X_new = getX(instances).times(tmp);            // factor = W*(P^T*W)^-1 * b_hat      m_PLS1_RegVector = tmp.times(b_hat);         // save matrices      m_PLS1_P     = P;      m_PLS1_W     = W;      m_PLS1_b_hat = b_hat;            if (getPerformPrediction())        result = toInstances(getOutputFormat(), X_new, y);      else        result = toInstances(getOutputFormat(), X_new, getY(instances));    }    // prediction    else {      result = new Instances(getOutputFormat());            for (i = 0; i < instances.numInstances(); i++) {	// work on each instance	tmpInst = new Instances(instances, 0);	tmpInst.add((Instance) instances.instance(i).copy());	x = getX(tmpInst);	X = new Matrix(1, getNumComponents());	T = new Matrix(1, getNumComponents());		for (j = 0; j < getNumComponents(); j++) {	  setVector(x, X, j);	  // 1. step: tj = xj * wj	  t = x.times(getVector(m_PLS1_W, j));	  setVector(t, T, j);	  // 2. step: xj+1 = xj - tj*pj^T (tj is 1x1 matrix!)	  
x = x.minus(getVector(m_PLS1_P, j).transpose().times(t.get(0, 0)));	}		if (getPerformPrediction())	  tmpInst = toInstances(getOutputFormat(), T, T.times(m_PLS1_b_hat));	else	  tmpInst = toInstances(getOutputFormat(), T, getY(tmpInst));		result.add(tmpInst.instance(0));      }    }        return result;  }  /**   * processes the instances using the SIMPLS algorithm   *   * @param instances   the data to process   * @return            the modified data   * @throws Exception  in case the processing goes wrong   */  protected Instances processSIMPLS(Instances instances) throws Exception {    Matrix	A, A_trans;    Matrix	M;    Matrix	X, X_trans;    Matrix	X_new;    Matrix	Y, y;    Matrix	C, c;    Matrix	Q, q;    Matrix	W, w;    Matrix	P, p, p_trans;    Matrix	v, v_trans;    Matrix	T;    Instances	result;    int		h;        if (!isFirstBatchDone()) {      // init      X       = getX(instances);      X_trans = X.transpose();      Y       = getY(instances);      A       = X_trans.times(Y);      M       = X_trans.times(X);      C       = Matrix.identity(instances.numAttributes() - 1, instances.numAttributes() - 1);      W       = new Matrix(instances.numAttributes() - 1, getNumComponents());      P       = new Matrix(instances.numAttributes() - 1, getNumComponents());      Q       = new Matrix(1, getNumComponents());            for (h = 0; h < getNumComponents(); h++) {	// 1. qh as dominant EigenVector of Ah'*Ah	A_trans = A.transpose();	q       = getDominantEigenVector(A_trans.times(A));		// 2. wh=Ah*qh, ch=wh'*Mh*wh, wh=wh/sqrt(ch), store wh in W as column	w       = A.times(q);	c       = w.transpose().times(M).times(w);	w       = w.times(1.0 / StrictMath.sqrt(c.get(0, 0)));	setVector(w, W, h);		// 3. ph=Mh*wh, store ph in P as column	p       = M.times(w);	p_trans = p.transpose();	setVector(p, P, h);		// 4. qh=Ah'*wh, store qh in Q as column	q = A_trans.times(w);	setVector(q, Q, h);		// 5. 
vh=Ch*ph, vh=vh/||vh||	v       = C.times(p);	normalizeVector(v);	v_trans = v.transpose();		// 6. Ch+1=Ch-vh*vh', Mh+1=Mh-ph*ph'	C = C.minus(v.times(v_trans));	M = M.minus(p.times(p_trans));		// 7. Ah+1=ChAh (actually Ch+1)	A = C.times(A);      }            // finish      m_SIMPLS_W = W;      T          = X.times(m_SIMPLS_W);      X_new      = T;      m_SIMPLS_B = W.times(Q.transpose());            if (getPerformPrediction())	y = T.times(P.transpose()).times(m_SIMPLS_B);      else	y = getY(instances);      result = toInstances(getOutputFormat(), X_new, y);    }    else {      result = new Instances(getOutputFormat());            X     = getX(instances);      X_new = X.times(m_SIMPLS_W);            if (getPerformPrediction())	y = X.times(m_SIMPLS_B);      else	y = getY(instances);            result = toInstances(getOutputFormat(), X_new, y);    }        return result;  }  /**    * Returns the Capabilities of this filter.   *   * @return            the capabilities of this object   * @see               Capabilities   */  public Capabilities getCapabilities() {    Capabilities result = super.getCapabilities();    // attributes    result.enable(Capability.NUMERIC_ATTRIBUTES);    result.enable(Capability.DATE_ATTRIBUTES);    result.enable(Capability.MISSING_VALUES);        // class    result.enable(Capability.NUMERIC_CLASS);    result.enable(Capability.DATE_CLASS);        return result;  }    /**   * Processes the given data (may change the provided dataset) and returns   * the modified version. This method is called in batchFinished().   
*   * @param instances   the data to process   * @return            the modified data   * @throws Exception  in case the processing goes wrong   * @see               #batchFinished()   */  protected Instances process(Instances instances) throws Exception {    Instances	result;    int		i;    double	clsValue;    double[]	clsValues;        result = null;    // save original class values if no prediction is performed    if (!getPerformPrediction())      clsValues = instances.attributeToDoubleArray(instances.classIndex());    else      clsValues = null;        if (!isFirstBatchDone()) {      // init filters      if (m_ReplaceMissing)	m_Missing.setInputFormat(instances);            switch (m_Preprocessing) {	case PREPROCESSING_CENTER:	  m_ClassMean   = instances.meanOrMode(instances.classIndex());	  m_ClassStdDev = 1;	  m_Filter      = new Center();	  ((Center) m_Filter).setIgnoreClass(true);      	  break;	case PREPROCESSING_STANDARDIZE:	  m_ClassMean   = instances.meanOrMode(instances.classIndex());	  m_ClassStdDev = StrictMath.sqrt(instances.variance(instances.classIndex()));	  m_Filter      = new Standardize();	  ((Standardize) m_Filter).setIgnoreClass(true);      	  break;	default:  	  m_ClassMean   = 0;	  m_ClassStdDev = 1;	  m_Filter      = null;      }      if (m_Filter != null)	m_Filter.setInputFormat(instances);    }        // filter data    if (m_ReplaceMissing)      instances = Filter.useFilter(instances, m_Missing);    if (m_Filter != null)      instances = Filter.useFilter(instances, m_Filter);        switch (m_Algorithm) {      case ALGORITHM_SIMPLS:	result = processSIMPLS(instances);	break;      case ALGORITHM_PLS1:	result = processPLS1(instances);	break;      default:	throw new IllegalStateException(	    "Algorithm type '" + m_Algorithm + "' is not recognized!");    }    // add the mean to the class again if predictions are to be performed,    // otherwise restore original class values    for (i = 0; i < result.numInstances(); i++) {      if 
(!getPerformPrediction()) {	result.instance(i).setClassValue(clsValues[i]);      }      else {	clsValue = result.instance(i).classValue();	result.instance(i).setClassValue(clsValue*m_ClassStdDev + m_ClassMean);      }    }        return result;  }  /**   * runs the filter with the given arguments   *   * @param args      the commandline arguments   */  public static void main(String[] args) {    runFilter(new PLSFilter(), args);  }}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -