📄 plsfilter.java
字号:
* * @param v the vector to store in the matrix * @param m the receiving matrix * @param columnIndex the column to store the values in */ protected void setVector(Matrix v, Matrix m, int columnIndex) { m.setMatrix(0, m.getRowDimension() - 1, columnIndex, columnIndex, v); } /** * returns the (column) vector of the matrix at the specified index * * @param m the matrix to work on * @param columnIndex the column to get the values from * @return the column vector */ protected Matrix getVector(Matrix m, int columnIndex) { return m.getMatrix(0, m.getRowDimension() - 1, columnIndex, columnIndex); } /** * determines the dominant eigenvector for the given matrix and returns it * * @param m the matrix to determine the dominant eigenvector for * @return the dominant eigenvector */ protected Matrix getDominantEigenVector(Matrix m) { EigenvalueDecomposition eigendecomp; double[] eigenvalues; int index; Matrix result; eigendecomp = m.eig(); eigenvalues = eigendecomp.getRealEigenvalues(); index = Utils.maxIndex(eigenvalues); result = columnAsVector(eigendecomp.getV(), index); return result; } /** * normalizes the given vector (inplace) * * @param v the vector to normalize */ protected void normalizeVector(Matrix v) { double sum; int i; // determine length sum = 0; for (i = 0; i < v.getRowDimension(); i++) sum += v.get(i, 0) * v.get(i, 0); sum = StrictMath.sqrt(sum); // normalize content for (i = 0; i < v.getRowDimension(); i++) v.set(i, 0, v.get(i, 0) / sum); } /** * processes the instances using the PLS1 algorithm * * @param instances the data to process * @return the modified data * @throws Exception in case the processing goes wrong */ protected Instances processPLS1(Instances instances) throws Exception { Matrix X, X_trans, x; Matrix y; Matrix W, w; Matrix T, t, t_trans; Matrix P, p, p_trans; double b; Matrix b_hat; int i; int j; Matrix X_new; Matrix tmp; Instances result; Instances tmpInst; // initialization if (!isFirstBatchDone()) { // split up data X = getX(instances); y = getY(instances); X_trans = X.transpose(); // init W = new Matrix(instances.numAttributes() - 1, getNumComponents()); P = new Matrix(instances.numAttributes() - 1, getNumComponents()); T = new Matrix(instances.numInstances(), getNumComponents()); b_hat = new Matrix(getNumComponents(), 1); for (j = 0; j < getNumComponents(); j++) { // 1. step: wj w = X_trans.times(y); normalizeVector(w); setVector(w, W, j); // 2. step: tj t = X.times(w); t_trans = t.transpose(); setVector(t, T, j); // 3. step: ^bj b = t_trans.times(y).get(0, 0) / t_trans.times(t).get(0, 0); b_hat.set(j, 0, b); // 4. step: pj p = X_trans.times(t).times((double) 1 / t_trans.times(t).get(0, 0)); p_trans = p.transpose(); setVector(p, P, j); // 5. step: Xj+1 X = X.minus(t.times(p_trans)); y = y.minus(t.times(b)); } // W*(P^T*W)^-1 tmp = W.times(((P.transpose()).times(W)).inverse()); // X_new = X*W*(P^T*W)^-1 X_new = getX(instances).times(tmp); // factor = W*(P^T*W)^-1 * b_hat m_PLS1_RegVector = tmp.times(b_hat); // save matrices m_PLS1_P = P; m_PLS1_W = W; m_PLS1_b_hat = b_hat; if (getPerformPrediction()) result = toInstances(getOutputFormat(), X_new, y); else result = toInstances(getOutputFormat(), X_new, getY(instances)); } // prediction else { result = new Instances(getOutputFormat()); for (i = 0; i < instances.numInstances(); i++) { // work on each instance tmpInst = new Instances(instances, 0); tmpInst.add((Instance) instances.instance(i).copy()); x = getX(tmpInst); X = new Matrix(1, getNumComponents()); T = new Matrix(1, getNumComponents()); for (j = 0; j < getNumComponents(); j++) { setVector(x, X, j); // 1. step: tj = xj * wj t = x.times(getVector(m_PLS1_W, j)); setVector(t, T, j); // 2. step: xj+1 = xj - tj*pj^T (tj is 1x1 matrix!) x = x.minus(getVector(m_PLS1_P, j).transpose().times(t.get(0, 0))); } if (getPerformPrediction()) tmpInst = toInstances(getOutputFormat(), T, T.times(m_PLS1_b_hat)); else tmpInst = toInstances(getOutputFormat(), T, getY(tmpInst)); result.add(tmpInst.instance(0)); } } return result; } /** * processes the instances using the SIMPLS algorithm * * @param instances the data to process * @return the modified data * @throws Exception in case the processing goes wrong */ protected Instances processSIMPLS(Instances instances) throws Exception { Matrix A, A_trans; Matrix M; Matrix X, X_trans; Matrix X_new; Matrix Y, y; Matrix C, c; Matrix Q, q; Matrix W, w; Matrix P, p, p_trans; Matrix v, v_trans; Matrix T; Instances result; int h; if (!isFirstBatchDone()) { // init X = getX(instances); X_trans = X.transpose(); Y = getY(instances); A = X_trans.times(Y); M = X_trans.times(X); C = Matrix.identity(instances.numAttributes() - 1, instances.numAttributes() - 1); W = new Matrix(instances.numAttributes() - 1, getNumComponents()); P = new Matrix(instances.numAttributes() - 1, getNumComponents()); Q = new Matrix(1, getNumComponents()); for (h = 0; h < getNumComponents(); h++) { // 1. qh as dominant EigenVector of Ah'*Ah A_trans = A.transpose(); q = getDominantEigenVector(A_trans.times(A)); // 2. wh=Ah*qh, ch=wh'*Mh*wh, wh=wh/sqrt(ch), store wh in W as column w = A.times(q); c = w.transpose().times(M).times(w); w = w.times(1.0 / StrictMath.sqrt(c.get(0, 0))); setVector(w, W, h); // 3. ph=Mh*wh, store ph in P as column p = M.times(w); p_trans = p.transpose(); setVector(p, P, h); // 4. qh=Ah'*wh, store qh in Q as column q = A_trans.times(w); setVector(q, Q, h); // 5. vh=Ch*ph, vh=vh/||vh|| v = C.times(p); normalizeVector(v); v_trans = v.transpose(); // 6. Ch+1=Ch-vh*vh', Mh+1=Mh-ph*ph' C = C.minus(v.times(v_trans)); M = M.minus(p.times(p_trans)); // 7. Ah+1=ChAh (actually Ch+1) A = C.times(A); } // finish m_SIMPLS_W = W; T = X.times(m_SIMPLS_W); X_new = T; m_SIMPLS_B = W.times(Q.transpose()); if (getPerformPrediction()) y = T.times(P.transpose()).times(m_SIMPLS_B); else y = getY(instances); result = toInstances(getOutputFormat(), X_new, y); } else { result = new Instances(getOutputFormat()); X = getX(instances); X_new = X.times(m_SIMPLS_W); if (getPerformPrediction()) y = X.times(m_SIMPLS_B); else y = getY(instances); result = toInstances(getOutputFormat(), X_new, y); } return result; } /** * Returns the Capabilities of this filter. * * @return the capabilities of this object * @see Capabilities */ public Capabilities getCapabilities() { Capabilities result = super.getCapabilities(); // attributes result.enable(Capability.NUMERIC_ATTRIBUTES); result.enable(Capability.DATE_ATTRIBUTES); result.enable(Capability.MISSING_VALUES); // class result.enable(Capability.NUMERIC_CLASS); result.enable(Capability.DATE_CLASS); return result; } /** * Processes the given data (may change the provided dataset) and returns * the modified version. This method is called in batchFinished(). * * @param instances the data to process * @return the modified data * @throws Exception in case the processing goes wrong * @see #batchFinished() */ protected Instances process(Instances instances) throws Exception { Instances result; int i; double clsValue; double[] clsValues; result = null; // save original class values if no prediction is performed if (!getPerformPrediction()) clsValues = instances.attributeToDoubleArray(instances.classIndex()); else clsValues = null; if (!isFirstBatchDone()) { // init filters if (m_ReplaceMissing) m_Missing.setInputFormat(instances); switch (m_Preprocessing) { case PREPROCESSING_CENTER: m_ClassMean = instances.meanOrMode(instances.classIndex()); m_ClassStdDev = 1; m_Filter = new Center(); ((Center) m_Filter).setIgnoreClass(true); break; case PREPROCESSING_STANDARDIZE: m_ClassMean = instances.meanOrMode(instances.classIndex()); m_ClassStdDev = StrictMath.sqrt(instances.variance(instances.classIndex())); m_Filter = new Standardize(); ((Standardize) m_Filter).setIgnoreClass(true); break; default: m_ClassMean = 0; m_ClassStdDev = 1; m_Filter = null; } if (m_Filter != null) m_Filter.setInputFormat(instances); } // filter data if (m_ReplaceMissing) instances = Filter.useFilter(instances, m_Missing); if (m_Filter != null) instances = Filter.useFilter(instances, m_Filter); switch (m_Algorithm) { case ALGORITHM_SIMPLS: result = processSIMPLS(instances); break; case ALGORITHM_PLS1: result = processPLS1(instances); break; default: throw new IllegalStateException( "Algorithm type '" + m_Algorithm + "' is not recognized!"); } // add the mean to the class again if predictions are to be performed, // otherwise restore original class values for (i = 0; i < result.numInstances(); i++) { if (!getPerformPrediction()) { result.instance(i).setClassValue(clsValues[i]); } else { clsValue = result.instance(i).classValue(); result.instance(i).setClassValue(clsValue*m_ClassStdDev + m_ClassMean); } } return result; } /** * runs the filter with the given arguments * * @param args the commandline arguments */ public static void main(String[] args) { runFilter(new PLSFilter(), args); }}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -