⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 network.java

📁 著名的神经网络工具箱
💻 JAVA
📖 第 1 页 / 共 2 页
字号:
    if( under_construction ) System.out.println("Network.setFunction: "+f);
    switch( f.type ){
      case Function.INIT      : ki.setInitFunc(f.kernel_name, p); break;
      case Function.LEARN     : ki.setLearnFunc(f.kernel_name, p); break;
      case Function.REMAPPING : ki.setRemapFunc(f.kernel_name, p); break;
      case Function.UPDATE    : ki.setUpdateFunc(f.kernel_name, p); break;
      default                 : return;
    }
    functions[f.type] = f;
    parameters[f.type] = p;
  }
  /**
   * Returns the currently chosen function of the given type.
   *
   * @param type the function type, used as an index into the function table
   *             (e.g. {@code Function.INIT}, {@code Function.LEARN})
   * @return the function, or {@code null} if {@code type} is out of range or
   *         no function of that type has been chosen
   */
  public Function getFunction( int type ){
    // valid indices are 0 .. functions.length-1; an index equal to the
    // length must be rejected here rather than throw ArrayIndexOutOfBounds
    if( type < 0 || type >= functions.length ) return null;
    return functions[ type ];
  }

  /**
   * Returns the parameter values stored for the given function type.
   * The stored array itself is returned (not a copy).
   *
   * @param function_type the function type, used as an index into the
   *                      parameter table
   * @return the parameter values, or {@code null} if {@code function_type}
   *         is out of range
   */
  public double[] getParameters( int function_type ){
    // valid indices are 0 .. parameters.length-1
    if( function_type < 0 || function_type >= parameters.length ) return null;
    return parameters[ function_type ];
  }

  /**
   * Prints the display names of all chosen functions to standard output.
   * Each name is preceded by a newline; unset (null) slots are skipped.
   */
  public void showFnList(){
    if( under_construction ) System.out.println("Network.showFnList()");
    StringBuilder listing = new StringBuilder();
    for( Function fn : functions ){
      if( fn == null ) continue;
      listing.append( '\n' ).append( fn.show_name );
    }
    System.out.println( listing );
  }

  /**
   * Prints a short status report to standard output: the network name, the
   * names of the current training and validation pattern sets, and the list
   * of chosen functions.
   */
  public void showState(){
    System.out.println( "Network name: " + getName() );
    PatternSets sets = snns.patternSets;
    System.out.println( "Training pattern set: " + sets.current.getName() );
    System.out.println( "Validation pattern set: " + sets.validation.getName() );
    showFnList();
  }

  /**
   * Re-reads from the kernel which function of the given type is active and
   * stores the matching {@code Function} object in the function table.
   * Only LEARN, INIT, PRUNING and UPDATE are queried; for any other type,
   * or when the kernel reports no name, the table is left unchanged.
   *
   * @param type the function type to refresh
   */
  private void updateFnList( int type ){
    String kernelName;
    if( type == Function.LEARN )        kernelName = ki.getLearnFunc();
    else if( type == Function.INIT )    kernelName = ki.getInitFunc();
    else if( type == Function.PRUNING ) kernelName = ki.getPrunFunc();
    else if( type == Function.UPDATE )  kernelName = ki.getUpdateFunc();
    else                                kernelName = null;

    if( kernelName != null ){
      functions[ type ] = snns.functions.getFunction( kernelName, type );
    }

    if( under_construction ){
      System.out.println("Network.updateFnList( int )");
      showFnList();
    }
  }

/*-----------------------------             ----------------------------------*/
  /**
   * Sets the parameters for Cascade Correlation.
   *
   * All arguments are forwarded unchanged, in the same order, to
   * {@code ki.setCascadeParams}; this method adds no validation or
   * side effects of its own. The meaning and valid ranges of each
   * parameter are defined by the kernel (NOTE(review): see the kernel's
   * cascade-correlation documentation for details).
   */
  public void setCascadeParams(double max_outp_uni_error,
                                      String learn_func,
                                      boolean print_covar,
                                      boolean prune_new_hidden,
                                      String mini_func,
                                      double min_covar_change,
                                      int cand_patience,
                                      int max_no_covar,
                                      int max_no_cand_units,
                                      String actfunc,
                                      double error_change,
                                      int output_patience,
                                      int max_no_epochs,
                                      String modification,
                                      double[] modP,
                                      boolean cacheUnitAct){

    ki.setCascadeParams( max_outp_uni_error,
                         learn_func,
                         print_covar,
                         prune_new_hidden,
                         mini_func,
                         min_covar_change,
                         cand_patience,
                         max_no_covar,
                         max_no_cand_units,
                         actfunc,
                         error_change,
                         output_patience,
                         max_no_epochs,
                         modification,
                         modP,
                         cacheUnitAct);
  }
  /**
   * Sets the pruning function together with its parameters.
   *
   * All arguments are forwarded unchanged, in the same order, to
   * {@code ki.setPruningFunc}; this method adds no validation or side
   * effects of its own. The meaning of each parameter is defined by the
   * kernel (NOTE(review): consult the kernel's pruning documentation).
   */
  public void setPruningFunc(String prune_func, String learn_func,
                          double pmax_error_incr, double paccepted_error,
                          boolean precreatef, int pfirst_train_cyc,
                          int pretrain_cyc, double pmin_error_to_stop,
                          double pinit_matrix_value, boolean pinput_pruningf,
                          boolean phidden_pruningf){

    ki.setPruningFunc(prune_func, learn_func,
                          pmax_error_incr, paccepted_error,
                          precreatef, pfirst_train_cyc,
                          pretrain_cyc, pmin_error_to_stop,
                          pinit_matrix_value, pinput_pruningf,
                          phidden_pruningf);
  }

  /**
   * Tells the kernel to prune the network, then fires
   * {@code NetworkEvent.NETWORK_PRUNED}.
   *
   * @param refreshDisplay if {@code false}, pruning is done in a single
   *                       {@code ki.pruneNet()} call; if {@code true}, it is
   *                       stepped via the kernel's first/step/last calls, and
   *                       a {@code NETWORK_PRUNED} event plus a 100 ms pause
   *                       precede each step so listeners can follow progress.
   *                       Stepping continues while the step result is
   *                       {@code <=} the value returned by the first step.
   */
  public void pruneNet(boolean refreshDisplay) {
    if( !refreshDisplay ) ki.pruneNet();
    else {
      double max_err = ki.pruneNet_FirstStep();
      boolean interrupted = false;
      do {
        fireEvent(NetworkEvent.NETWORK_PRUNED);
        if( !interrupted ) {
          try { Thread.sleep(100); }
          catch( InterruptedException e ) {
            // Keep stepping so the sequence still reaches pruneNet_LastStep(),
            // but stop pausing and remember the interrupt instead of losing it.
            interrupted = true;
          }
        }
      }
      while(ki.pruneNet_Step() <= max_err);
      ki.pruneNet_LastStep();
      // restore the interrupt status consumed by the catch above
      if( interrupted ) Thread.currentThread().interrupt();
    }
    fireEvent(NetworkEvent.NETWORK_PRUNED);
  }

  /**
   * Runs the kernel's current init function on the network and then fires
   * {@code NetworkEvent.NETWORK_INITIALIZED}. If the kernel call throws, the
   * exception propagates and no event is fired.
   *
   * @throws Exception propagated from {@code ki.initNet()}
   */
  public void initNet() throws Exception {
    this.ki.initNet();
    this.fireEvent( NetworkEvent.NETWORK_INITIALIZED );
  }


  /**
   * Trains the network with all patterns from the current pattern set.
   * Sets the kernel's shuffle flags and the sub-pattern scheme, then creates
   * a {@code NetTrainer} for the requested number of steps.
   *
   * @param tc         the {@code ThreadChief} handed on to the trainer
   * @param steps      the number of training steps
   * @param shuffle    whether the patterns are shuffled rather than trained
   *                   one after another
   * @param subShuffle whether the sub-patterns are shuffled
   */
  public void trainNet(ThreadChief tc, int steps, boolean shuffle, boolean subShuffle){
    if( under_construction ) System.out.println("Network.trainNet( int )");
    ki.setShuffle(shuffle);
    ki.setSubShuffle(subShuffle);
    setSubPatternScheme();
    // The trainer is created for its constructor's side effects (presumably
    // it starts the training run itself); no reference is kept here.
    new NetTrainer(this, tc, steps, shuffle, subShuffle);
  }

  /**
   * Validates the training state of the network against a validation
   * pattern set.
   *
   * The validation set is temporarily made the current pattern set and the
   * kernel tests the net on it. Afterwards the previous current set and its
   * sub-pattern scheme are restored, even if testing throws. The shuffle
   * flags set here are not reset.
   *
   * @param val_set    the validation pattern set
   * @param shuffle    shuffle flag passed to the kernel before testing
   * @param subShuffle sub-pattern shuffle flag passed to the kernel before testing
   * @return the error value the kernel holds in {@code ki.sse} after testing.
   *         NOTE(review): earlier documentation described this as the squared
   *         error sum divided by the number of patterns; confirm against the
   *         kernel.
   */
  public double validate(PatternSet val_set, boolean shuffle, boolean subShuffle){
    if( under_construction ) System.out.println("Network.validate(PatternSet)");
    PatternSets sets = snns.patternSets;
    PatternSet training_set = sets.getCurrent();
    sets.setCurrent(val_set);
    try {
      ki.setShuffle(shuffle);
      ki.setSubShuffle(subShuffle);
      setSubPatternScheme();
      ki.testNet();
    }
    finally {
      // Always switch back; otherwise a failed test would leave the
      // validation set "current" and later training would run on it.
      sets.setCurrent(training_set);
      setSubPatternScheme();
    }
    return ki.sse;
  }


  /**
   * Trains the network on the current pattern only.
   * Sets the kernel's shuffle flags and the sub-pattern scheme, then creates
   * a {@code NetTrainer} restricted to {@code getCurrentPatternNo()}.
   *
   * @param tc         the {@code ThreadChief} handed on to the trainer
   * @param steps      the number of training steps
   * @param shuffle    shuffle flag passed to the kernel and the trainer
   * @param subShuffle sub-pattern shuffle flag passed to the kernel and the trainer
   */
  public void trainNet_CurrentPattern( ThreadChief tc,
        int steps,boolean shuffle, boolean subShuffle){
    if( under_construction ) System.out.println("Network.trainNet_CurrentPattern( int )");
    ki.setShuffle(shuffle);
    ki.setSubShuffle(subShuffle);
    setSubPatternScheme();
    // created for its constructor's side effects; no reference is kept
    new NetTrainer(this, tc, steps, shuffle, subShuffle, getCurrentPatternNo());
  }


  /**
   * Returns training error of the network
   *
   * @param errorType   <code>true</code> if error should be squared, else the absolute error is returned
   * @param average     <code>true</code> if error should be divided by the number of output units
   * @return network training error ( or -1 if something went wrong )
  public native double  analyzer_error(int currPatt, int unitNo, int errorType, boolean average);
   */
  /*public double getError( boolean squared, boolean average ){
    System.out.println("Network.getError(boolean, boolean)");
    try{
      System.out.println( ki.getPatternNo()+", 1," + ( squared? 2 : 1 ) + "," + average );
      return ki.analyzer_error( ki.getPatternNo(), 1, ( squared? 2 : 1 ), average );
    }
    catch( Exception e ){
      snns.showException( e, this );
      return -1;
    }
  } */

  /**
   * Returns absolute training error of a unit
   *
   * @param unitNo unit number
   * @return absolute unit training error ( or -1 if something went wrong )
   */
  /*public double getError( int unitNo ){
    if( under_construction ) System.out.println("Network.getError(int)");
    try{
      System.out.println( ki.getPatternNo()+", "+ unitNo+", 3, false ");
      return ki.analyzer_error( ki.getPatternNo(), unitNo, 3, false );
    }
    catch( Exception e ){
      snns.showException( e, this );
      return -1;
    }
  } */
  /**
   * Returns the error of the network for the current pattern.
   *
   * The kernel is switched with {@code showPattern(1)} and the net is
   * updated, and a {@code Pattern} snapshot is taken. The kernel is then
   * switched with {@code showPattern(2)` and a second snapshot is taken. The
   * error is accumulated over the output entries from the differences between
   * the two snapshots.
   * NOTE(review): showPattern(1)/(2) presumably select the network output and
   * the teaching output respectively; confirm against the kernel interface.
   *
   * @param squared {@code true} to sum squared differences, {@code false} to
   *                sum absolute differences
   * @param average {@code true} to divide the sum by the number of output entries
   * @return the accumulated error
   * @throws Exception propagated from the kernel calls or {@code Pattern} creation
   */
  public double getError( boolean squared, boolean average ) throws Exception{
    ki.showPattern(1);
    ki.updateNet();
    Pattern first = new Pattern( this );
    ki.showPattern(2);
    Pattern second = new Pattern( this );

    int count = first.output.length;
    double sum = 0;
    for( int k = 0; k < count; k++ ){
      sum += squared
          ? ( first.output[k] - second.output[k] ) * ( first.output[k] - second.output[k] )
          : Math.abs( first.output[k] - second.output[k] );
    }
    return average ? sum / count : sum;
  }

  /**
   * Returns the absolute error of a single unit for the current pattern.
   *
   * This is the absolute difference between the unit's activation read after
   * {@code showPattern(1)} plus {@code updateNet()}, and its activation read
   * after {@code showPattern(2)}.
   *
   * @param unitNo unit number
   * @return the absolute activation difference
   * @throws Exception propagated from the kernel calls
   */
  public double getError( int unitNo ) throws Exception{
    ki.showPattern(1);
    ki.updateNet();
    double before = ki.getUnitActivation( unitNo );
    ki.showPattern(2);
    double after = ki.getUnitActivation( unitNo );
    return Math.abs( before - after );
  }

  /**
   * Deletes the current network.
   *
   * If the content has unsaved changes, {@code snns.askForSaving} is
   * consulted first; a {@code false} answer aborts the deletion. The home
   * file is cleared and the name is reset to "default". When units exist,
   * they are all deleted, the selection and maximum coordinates are reset,
   * a {@code NETWORK_DELETED} event is fired, and the content is marked
   * unchanged.
   *
   * @return the argument of the fired {@code NETWORK_DELETED} event, or
   *         {@code null} if deletion was aborted or the network had no units
   */
  public NetworkDeleteArgument deleteNetwork(){
    if( content_changed && !snns.askForSaving( this ) ) return null;

    homeFile = null;
    if( ki.getNoOfUnits() == 0 ){
      setName( "default" );
      return null;
    }

    setName( "default", false );
    selection_flags = null;
    max_coord[1] = -1;
    max_coord[0] = -1;
    NetworkDeleteArgument deleted =
        new NetworkDeleteArgument( net_name, deleteAllUnits() );
    fireEvent( NetworkEvent.NETWORK_DELETED, deleted );
    content_changed = false;
    return deleted;
  }

  /**
   * Returns the maximal layer number, as held in {@code layers.maxLayerNo}.
   */
  public int getMaxLayerNo(){
    if( under_construction ){
      System.out.println( "Network.getMaxLayerNo: " + layers.maxLayerNo );
    }
    return layers.maxLayerNo;
  }

  /**
   * Returns the maximal x coordinate of a unit, as held in
   * {@code max_coord[0]} (set to -1 when the network is deleted).
   */
  public int getMaxXCoordinate(){
    return max_coord[ 0 ];
  }

  /**
   * Returns the maximal y coordinate of a unit, as held in
   * {@code max_coord[1]} (set to -1 when the network is deleted).
   */
  public int getMaxYCoordinate(){
    return max_coord[ 1 ];
  }


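// Method that was added to the KernelInterface for historical reasons;
// whether it is still needed has yet to be clarified.
// (Original: Methode, die aus Traditionsgründen auf's KernelInterface draufgepackt wurde,
//  deren Notwendigkeit noch geklärt werden muß.)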

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -