⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 neuron.java

📁 一个数据挖掘软件ALPHAMINER的整个过程的JAVA版源代码
💻 JAVA
📖 第 1 页 / 共 2 页
字号:
      weightsDelta[nweights] = 0;
    }

    // -----------------------------------------------------------------------
    //  Calculation methods
    // -----------------------------------------------------------------------
    /**
     * Clears the cached input, output and error values of this neuron and
     * marks its weights as not yet updated.
     *
     * The reset is then propagated recursively to every parent (input) node.
     * A neuron whose input, output and error are all already missing does
     * nothing at all, so a node reachable through several paths is only
     * reset once.
     */
    public void resetValues() {

      boolean nothingToReset = Category.isMissingValue(input)
                            && Category.isMissingValue(output)
                            && Category.isMissingValue(error);
      if (nothingToReset)
        return;

      input  = Category.MISSING_VALUE;
      output = Category.MISSING_VALUE;
      error  = Category.MISSING_VALUE;
      weightsUpdated = false;

      // Propagate the reset towards the input layer:
      for (int p = 0; p < getParentsCount(); p++) {
        NeuralNode parent = (NeuralNode) getParentAt(p);
        parent.resetValues();
      }
    }

    /**
     * Returns the net input of this neuron.
     *
     * The net input is the bias weight {@code weights[0]} plus the scalar
     * product of the parent nodes' output values with the connection weights
     * {@code weights[1..n]}. Parent outputs are requested with
     * {@code outputValue(true)}, so the forward pass propagates recursively.
     * The result is cached in {@code input}. It is only computed when the
     * cached value is missing and {@code newcalc} is true.
     *
     * @param newcalc whether a missing input value may be computed now
     * @return net input value, missing if not calculated
     * @exception MiningException error while calculating input value
     */
    public double inputValue(boolean newcalc) throws MiningException {

      if ( !Category.isMissingValue(input) || !newcalc )
        return input;

      // Bias term; its input is the constant 1.0:
      input = weights[0] * 1.0;

      // Add weight * output for each parent; weights[p+1] belongs to parent p.
      for (int p = 0; p < getParentsCount(); p++) {
        NeuralNode parent = (NeuralNode) getParentAt(p);
        double parentOut = parent.outputValue(true);
        input = input + weights[p + 1] * parentOut;
      }

      return input;
    }

    /**
     * Returns the output of this neuron.
     *
     * When the cached output is missing and {@code newcalc} is true, the net
     * input is computed via {@code inputValue(true)} and passed through the
     * neuron's activation function. The result is cached in {@code output}.
     *
     * @param newcalc whether a missing output value may be computed now
     * @return output value, missing if not calculated
     * @exception MiningException error while calculating output value
     */
    public double outputValue(boolean newcalc) throws MiningException {

      if ( !Category.isMissingValue(output) || !newcalc )
        return output;

      // Net input first, then squash it with the activation function:
      output = inputValue(true);
      output = activationFunction.function(output);

      return output;
    }

    /**
     * Returns the back-propagated error of this neuron.
     *
     * The error is only computed when it is missing, the input value has
     * already been set by a forward pass, and {@code newcalc} is true. It is
     * the sum over all child nodes of (child error * weight of the
     * connection from this neuron to the child). That sum is multiplied by
     * the activation function's derivative at this neuron's cached net
     * input. The result is cached in {@code error}.
     *
     * @param newcalc whether a missing error value may be computed now
     * @return error value, missing if not calculated
     * @exception MiningException error while calculating error value
     */
    public double errorValue(boolean newcalc) throws MiningException {

      boolean mayCompute = Category.isMissingValue(error)
                        && !Category.isMissingValue(input)
                        && newcalc;
      if (!mayCompute)
        return error;

      // Weighted sum of the children's errors:
      error = 0.0;
      for (int c = 0; c < getChildCount(); c++) {
        NeuralNode child = (NeuralNode) getChildAt(c);
        double connWeight = child.getWeightAt( child.getParentIndex(this) );
        error = error + child.errorValue(true) * connWeight;
      }

      // Scale by the activation derivative at the cached net input:
      double netInput = inputValue(false);
      error = error * activationFunction.derivation(netInput);

      return error;
    }

    /**
     * Applies one gradient step with momentum to the weights of this neuron,
     * then calls the superclass implementation.
     *
     * For every weight k: delta = learningRate * error * in_k
     * + momentum * previousDelta_k. Here in_0 = 1.0 for the bias, and in_k is
     * the cached output of parent k-1 for the other weights. The delta is
     * added to {@code weights[k]} and remembered in {@code weightsDelta[k]}.
     * Nothing happens if {@code weightsUpdated} is already set or the
     * error is NaN.
     *
     * NOTE(review): this method does not set {@code weightsUpdated} itself;
     * presumably {@code super.updateWeights} does. Confirm.
     *
     * @param learningRate the learning rate
     * @param momentum the momentum
     * @exception MiningException cannot update weights
     */
    public void updateWeights(double learningRate, double momentum)
     throws MiningException {

      if ( weightsUpdated || Double.isNaN(error) )
        return;

      // Factor shared by every weight of this neuron:
      double rateTimesError = learningRate * errorValue(false);

      // Bias weight; its input is the constant 1.0:
      double step = rateTimesError * 1.0 + momentum * weightsDelta[0];
      weights[0] += step;
      weightsDelta[0] = step;

      // Connection weights; weights[k+1] belongs to parent k:
      for (int k = 0; k < getParentsCount(); k++) {
        NeuralNode parent = (NeuralNode) getParentAt(k);
        step = rateTimesError * parent.outputValue(false);
        step = step + momentum * weightsDelta[k + 1];
        weights[k + 1] += step;
        weightsDelta[k + 1] = step;
      }

      // Continue with the input nodes:
      super.updateWeights(learningRate, momentum);
    }

    // -----------------------------------------------------------------------
    //  Methods of PMML handling
    // -----------------------------------------------------------------------
    /**
     * Converts this neuron into a PMML {@code Neuron} element.
     *
     * The element receives:
     * - the neuron ID;
     * - one {@code Conn} per parent node. Each carries the parent's ID, plus
     *   its weight when there are more weights than parents; otherwise a
     *   warning is printed and the weights are omitted;
     * - the activation function, only when the neuron has one and its layer
     *   does not. The PMML name is used when it exists, else the raw type;
     * - the bias, when {@code useBias} is set and a weight vector exists;
     * - the threshold, when it is not missing.
     *
     * @return PMML element of neuron
     * @exception MiningException error while creating the PMML element
     */
    public Object createPmmlObject() throws MiningException
    {
      com.prudsys.pdm.Adapters.PmmlVersion20.Neuron neuron =
        new com.prudsys.pdm.Adapters.PmmlVersion20.Neuron();
      neuron.setId(id);

      // Connections to the input (parent) neurons:
      int parentCount = getParentsCount();
      if (parentCount > 0) {
        com.prudsys.pdm.Adapters.PmmlVersion20.Conn[] connections =
          new com.prudsys.pdm.Adapters.PmmlVersion20.Conn[parentCount];

        // weights[0] is the bias, so a valid vector has more than parentCount entries.
        boolean writeWeights = getNumberOfWeights() > parentCount;
        if (!writeWeights)
          System.out.println("Warning! Neuron #ID: " + id + " without valid weight vector");

        for (int p = 0; p < parentCount; p++) {
          com.prudsys.pdm.Adapters.PmmlVersion20.Conn connection =
            new com.prudsys.pdm.Adapters.PmmlVersion20.Conn();
          connections[p] = connection;
          NeuralNode parent = (NeuralNode) getParentAt(p);
          connection.setFrom( parent.getId() );
          if (writeWeights)
            connection.setWeight( String.valueOf( weights[p + 1] ) );
        }
        neuron.setConn(connections);
      }

      // Activation function (only if not already defined by the layer):
      if (activationFunction != null) {
        if (neuralLayer.getActivationFunction() == null) {
          String functionType = activationFunction.getFunctionType();
          String pmmlName = ActivationFunction.convertTypeToPmml(functionType);
          if (pmmlName == null)
            pmmlName = functionType; // no PMML equivalent: write the raw type
          neuron.setActivationFunction(pmmlName);
        }
      }

      // Bias:
      boolean hasBiasWeight = weights != null && weights.length > 0;
      if (useBias && hasBiasWeight)
        neuron.setBias( String.valueOf( weights[0] ) );

      // Threshold:
      if ( !Category.isMissingValue(threshold) )
        neuron.setThreshold( String.valueOf(threshold) );

      return neuron;
    }

    /**
     * Reads this neuron from a PMML element.
     *
     * Sets the neuron ID, the IDs of the input neurons ({@code inputIDs}),
     * the connection weights ({@code weights[1..n]}), the bias
     * ({@code weights[0]} and {@code useBias}), the activation function and
     * the threshold. A weight, bias or threshold that is missing or cannot be
     * parsed only produces a warning on standard output:
     * - the affected weight stays 0.0;
     * - {@code useBias} stays false;
     * - {@code threshold} keeps its previous value.
     *
     * @param pmmlObject PMML element to read in; cast to
     *        {@code com.prudsys.pdm.Adapters.PmmlVersion20.Neuron}
     * @exception MiningException if the neuron has no input connections
     */
    public void parsePmmlObject( Object pmmlObject ) throws MiningException
    {
      // Get neuron:
      com.prudsys.pdm.Adapters.PmmlVersion20.Neuron neuron =
        (com.prudsys.pdm.Adapters.PmmlVersion20.Neuron) pmmlObject;

      // Get neuron ID:
      id = neuron.getId();

      // Get input neurons and connection weights. weights[0] is reserved
      // for the bias, so weights[i+1] belongs to inputs[i].
      com.prudsys.pdm.Adapters.PmmlVersion20.Conn[] inputs = neuron.getConn();
      if (inputs == null || inputs.length == 0)
        throw new MiningException("neuron: " + id + " has no inputs");
      int ninp = inputs.length;
      inputIDs = new String[ninp];
      weights  = new double[ninp+1];
      // NOTE(review): weightsDelta is not resized here; if this neuron is
      // trained after parsing, confirm its length matches weights.length.
      for (int i = 0; i < ninp; i++) {
        inputIDs[i] = inputs[i].getFrom();
        String weight = inputs[i].getWeight();
        if (weight != null) {
          try {
            weights[i+1] = Double.parseDouble(weight);
            continue;
          }
          catch (NumberFormatException ex) {
            // Malformed number: reported below.
          }
        }
        // Missing or malformed weight: keep 0.0 and warn (best effort).
        System.out.println("Warning! Invalid weight value of neuron #ID: " + id);
      }

      // Get activation function:
      String afname = neuron.getActivationFunction();
      if (afname != null) {
        String aftype = ActivationFunction.convertPmmlToType(afname);
        // No standard PMML name: fall back to the raw name. This mirrors
        // createPmmlObject, which writes the raw type when no PMML name exists.
        if (aftype == null)
          aftype = afname;
        ActivationFunction af = ActivationFunction.getInstance(aftype);
        if (af != null)
          activationFunction = af;
      }

      // Get bias:
      useBias = false;
      String bias = neuron.getBias();
      if (bias != null) {
        try {
          weights[0] = Double.parseDouble(bias);
          useBias = true;
        }
        catch (NumberFormatException ex) {
          System.out.println("Warning! Invalid bias value of neuron #ID: " + id);
        }
      }

      // Get threshold:
      String thresh = neuron.getThreshold();
      if (thresh != null) {
        try {
          threshold = Double.parseDouble(thresh);
        }
        catch (NumberFormatException ex) {
          System.out.println("Warning! Invalid threshold value of neuron #ID: " + id);
        }
      }
    }

    // -----------------------------------------------------------------------
    //  Further methods
    // -----------------------------------------------------------------------
    /**
     * Returns a human-readable description of this neuron.
     *
     * The text is "Neuron, " followed by the superclass description. If a
     * weight vector exists, ", bias = &lt;weights[0]&gt;" is appended when
     * {@code useBias} is set. After that comes " weights: " and every weight
     * followed by a space.
     *
     * @return string representation of neuron
     */
    public String toString() {

      StringBuffer text = new StringBuffer("Neuron, ");
      text.append(super.toString());

      if (weights == null)
        return text.toString();

      if (useBias)
        text.append(", bias = ").append(weights[0]);
      text.append(" weights: ");
      for (int k = 0; k < weights.length; k++)
        text.append(weights[k]).append(' ');

      return text.toString();
    }
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -