// network.java
if( under_construction ) System.out.println("Network.setFunction: "+f);
switch( f.type ){
case Function.INIT : ki.setInitFunc(f.kernel_name, p); break;
case Function.LEARN : ki.setLearnFunc(f.kernel_name, p); break;
case Function.REMAPPING : ki.setRemapFunc(f.kernel_name, p); break;
case Function.UPDATE : ki.setUpdateFunc(f.kernel_name, p); break;
default : return;
}
functions[f.type] = f;
parameters[f.type] = p;
}
/**
 * Returns the currently chosen function of the given type.
 *
 * @param type the function type ( e.g. Function.LEARN )
 * @return the function, or <code>null</code> if the type is out of range
 */
public Function getFunction( int type ){
    // was `type > functions.length`: that accepted type == functions.length and
    // caused an ArrayIndexOutOfBoundsException on the access below
    if( type < 0 || type >= functions.length ) return null;
    return functions[ type ];
}
/**
 * Returns the parameter values of the specified function type.
 *
 * @param function_type the function type ( e.g. Function.LEARN )
 * @return the parameter values, or <code>null</code> if the type is out of range
 */
public double[] getParameters( int function_type ){
    // was `function_type > parameters.length`: off-by-one — index == length
    // passed the check and threw an ArrayIndexOutOfBoundsException
    if( function_type < 0 || function_type >= parameters.length ) return null;
    return parameters[ function_type ];
}
/**
 * Prints the show-names of all currently set functions to stdout,
 * one per line ( types without a function set are skipped ).
 */
public void showFnList(){
    if( under_construction ) System.out.println("Network.showFnList()");
    // StringBuilder instead of repeated String concatenation in the loop
    StringBuilder text = new StringBuilder();
    for( int i = 0; i < functions.length; i++ )
        if( functions[i] != null )
            text.append( "\n" ).append( functions[i].show_name );
    System.out.println( text );
}
/**
 * Dumps the network name, the names of the training and validation
 * pattern sets and the current function list to stdout.
 */
public void showState(){
    PatternSets sets = snns.patternSets;
    System.out.println( "Network name: " + getName() );
    System.out.println( "Training pattern set: " + sets.current.getName() );
    System.out.println( "Validation pattern set: " + sets.validation.getName() );
    showFnList();
}
/**
 * Re-reads the kernel's currently active function of the given type
 * and stores it in the local function table.
 *
 * @param type the function type to refresh
 */
private void updateFnList( int type ){
    String kernelName;
    switch( type ){
        case Function.INIT    : kernelName = ki.getInitFunc();   break;
        case Function.LEARN   : kernelName = ki.getLearnFunc();  break;
        case Function.UPDATE  : kernelName = ki.getUpdateFunc(); break;
        case Function.PRUNING : kernelName = ki.getPrunFunc();   break;
        default               : kernelName = null;
    }
    if( kernelName != null )
        functions[ type ] = snns.functions.getFunction( kernelName, type );
    if( under_construction ){
        System.out.println("Network.updateFnList( int )");
        showFnList();
    }
}
/*----------------------------- ----------------------------------*/
/**
 * Sets the parameters for Cascade Correlation training.
 * Pure delegation to the kernel interface — no state is changed on the
 * Java side ( see the SNNS kernel documentation for the meaning of the
 * individual parameters ).
 */
public void setCascadeParams(double max_outp_uni_error,
String learn_func,
boolean print_covar,
boolean prune_new_hidden,
String mini_func,
double min_covar_change,
int cand_patience,
int max_no_covar,
int max_no_cand_units,
String actfunc,
double error_change,
int output_patience,
int max_no_epochs,
String modification,
double[] modP,
boolean cacheUnitAct){
ki.setCascadeParams( max_outp_uni_error,
learn_func,
print_covar,
prune_new_hidden,
mini_func,
min_covar_change,
cand_patience,
max_no_covar,
max_no_cand_units,
actfunc,
error_change,
output_patience,
max_no_epochs,
modification,
modP,
cacheUnitAct);
}
/**
 * Sets the pruning function and its parameters.
 * Pure delegation to the kernel interface — no state is changed on the
 * Java side ( see the SNNS kernel documentation for the meaning of the
 * individual parameters ).
 */
public void setPruningFunc(String prune_func, String learn_func,
double pmax_error_incr, double paccepted_error,
boolean precreatef, int pfirst_train_cyc,
int pretrain_cyc, double pmin_error_to_stop,
double pinit_matrix_value, boolean pinput_pruningf,
boolean phidden_pruningf){
ki.setPruningFunc(prune_func, learn_func,
pmax_error_incr, paccepted_error,
precreatef, pfirst_train_cyc,
pretrain_cyc, pmin_error_to_stop,
pinit_matrix_value, pinput_pruningf,
phidden_pruningf);
}
/**
 * Tells the kernel to start pruning the network.
 *
 * @param refreshDisplay if <code>true</code>, pruning is performed step by
 *        step and a NETWORK_PRUNED event is fired after each step so
 *        attached views can repaint; otherwise the kernel prunes in one go
 */
public void pruneNet(boolean refreshDisplay) {
    if( !refreshDisplay ) ki.pruneNet();
    else {
        double max_err = ki.pruneNet_FirstStep();
        do {
            fireEvent(NetworkEvent.NETWORK_PRUNED);
            // short pause between steps so the display can catch up
            try { Thread.sleep(100); }
            catch( InterruptedException e ) {
                // was an empty catch(Exception): narrow the type and restore
                // the interrupt status instead of silently swallowing it
                Thread.currentThread().interrupt();
            }
        }
        while( ki.pruneNet_Step() <= max_err );
        ki.pruneNet_LastStep();
    }
    fireEvent(NetworkEvent.NETWORK_PRUNED);
}
/**
 * Initializes the network using the currently chosen init function
 * and fires a NETWORK_INITIALIZED event.
 *
 * @throws Exception if the kernel fails to initialize the net
 */
public void initNet() throws Exception {
ki.initNet();
fireEvent( NetworkEvent.NETWORK_INITIALIZED );
}
/**
 * Trains the network with all patterns from the current pattern set.
 *
 * @param tc the thread chief supervising the training run
 * @param steps the number of training steps
 * @param shuffle whether the patterns are trained next by next or not
 * @param subShuffle if the subpatterns are shuffled
 */
public void trainNet(ThreadChief tc, int steps, boolean shuffle, boolean subShuffle){
if( under_construction ) System.out.println("Network.trainNet( int )");
ki.setShuffle(shuffle);
ki.setSubShuffle(subShuffle);
setSubPatternScheme();
// the NetTrainer constructor is used for its side effect only — the local is
// never read ( presumably it starts the training via tc; TODO confirm in NetTrainer )
NetTrainer trainer = new NetTrainer(this, tc, steps, shuffle, subShuffle);
}
/**
 * Validates the training state of the network against the given
 * validation pattern set and returns the squared sum of the errors
 * divided by the number of patterns ( as reported by the kernel in
 * ki.sse after testNet() — TODO confirm the normalization in the kernel ).
 *
 * @param val_set the validation pattern set
 * @param shuffle whether the patterns are shuffled for the test run
 * @param subShuffle whether the subpatterns are shuffled
 * @return the medium squared sum error
 */
public double validate(PatternSet val_set, boolean shuffle, boolean subShuffle){
if( under_construction ) System.out.println("Network.validate(PatternSet)");
PatternSets sets = snns.patternSets;
// remember the training set so it can be restored after the test run
PatternSet training_set = sets.getCurrent();
sets.setCurrent(val_set);
ki.setShuffle(shuffle);
ki.setSubShuffle(subShuffle);
setSubPatternScheme();
ki.testNet();
// switch back to the training set before returning
sets.setCurrent(training_set);
setSubPatternScheme();
return ki.sse;
}
/**
 * Trains the network on the current pattern only.
 *
 * @param tc the thread chief supervising the training run
 * @param steps the number of training steps
 * @param shuffle whether the patterns are shuffled
 * @param subShuffle if the subpatterns are shuffled
 */
public void trainNet_CurrentPattern( ThreadChief tc,
int steps,boolean shuffle, boolean subShuffle){
if( under_construction ) System.out.println("Network.trainNet_CurrentPattern( int )");
ki.setShuffle(shuffle);
ki.setSubShuffle(subShuffle);
setSubPatternScheme();
// side-effecting constructor: restricts training to the current pattern number
NetTrainer trainer = new NetTrainer(this, tc, steps, shuffle, subShuffle, getCurrentPatternNo());
}
/**
* Returns training error of the network
*
* @param errorType <code>true</code> if error should be squared, else the absolute error is returned
* @param average <code>true</code> if error should be divided by the number of output units
* @return network training error ( or -1 if something went wrong )
public native double analyzer_error(int currPatt, int unitNo, int errorType, boolean average);
*/
/*public double getError( boolean squared, boolean average ){
System.out.println("Network.getError(boolean, boolean)");
try{
System.out.println( ki.getPatternNo()+", 1," + ( squared? 2 : 1 ) + "," + average );
return ki.analyzer_error( ki.getPatternNo(), 1, ( squared? 2 : 1 ), average );
}
catch( Exception e ){
snns.showException( e, this );
return -1;
}
} */
/**
* Returns absolute training error of a unit
*
* @param unitNo unit number
* @return absolute unit training error ( or -1 if something went wrong )
*/
/*public double getError( int unitNo ){
if( under_construction ) System.out.println("Network.getError(int)");
try{
System.out.println( ki.getPatternNo()+", "+ unitNo+", 3, false ");
return ki.analyzer_error( ki.getPatternNo(), unitNo, 3, false );
}
catch( Exception e ){
snns.showException( e, this );
return -1;
}
} */
/**
 * Returns the training error of the network for the current pattern,
 * computed from the difference between the two pattern views that the
 * kernel exposes via showPattern(1) / showPattern(2).
 *
 * @param squared <code>true</code> for the squared error, otherwise the
 *        absolute error is accumulated
 * @param average <code>true</code> to divide the error by the number of
 *        output units
 * @return the network training error for the current pattern
 * @throws Exception if a kernel call fails
 */
public double getError( boolean squared, boolean average ) throws Exception{
    ki.showPattern(1);
    ki.updateNet();
    Pattern first = new Pattern( this );
    ki.showPattern(2);
    Pattern second = new Pattern( this );
    double error = 0;
    int outputs = first.output.length;
    for( int i = 0; i < outputs; i++ ){
        double diff = first.output[i] - second.output[i];
        error += squared ? diff * diff : Math.abs( diff );
    }
    return average ? error / outputs : error;
}
/**
 * Returns the absolute training error of a single unit for the current
 * pattern: the difference of the unit's activation between the two
 * pattern views showPattern(1) and showPattern(2).
 *
 * @param unitNo the unit number
 * @return the absolute unit training error
 * @throws Exception if a kernel call fails
 */
public double getError( int unitNo ) throws Exception{
    ki.showPattern(1);
    ki.updateNet();
    double firstActivation = ki.getUnitActivation( unitNo );
    ki.showPattern(2);
    double secondActivation = ki.getUnitActivation( unitNo );
    return Math.abs( firstActivation - secondActivation );
}
/**
 * Deletes the current network and all patterns.
 * If the network content changed, the user is first asked whether to
 * save; cancelling there aborts the deletion.
 *
 * @return the delete argument describing the removed units, or
 *         <code>null</code> if the deletion was aborted or the net
 *         contained no units
 */
public NetworkDeleteArgument deleteNetwork(){
boolean really = true;
if( content_changed ) really = snns.askForSaving( this );
if( !really ) return null;  // user cancelled in the save dialog
homeFile = null;
if( ki.getNoOfUnits() == 0 ){
// nothing to delete — just reset the name
setName( "default" );
return null;
}
setName( "default", false );  // second arg presumably suppresses an event — TODO confirm
selection_flags = null;
max_coord[0] = max_coord[1] = -1;  // invalidate the cached maximal unit coordinates
NetworkDeleteArgument nda = new NetworkDeleteArgument( net_name, deleteAllUnits() );
fireEvent( NetworkEvent.NETWORK_DELETED, nda );
content_changed = false;
return nda;
}
/**
 * Returns the maximal layer number of the network.
 */
public int getMaxLayerNo(){
    int maxLayer = layers.maxLayerNo;
    if( under_construction )
        System.out.println( "Network.getMaxLayerNo: " + maxLayer );
    return maxLayer;
}
/**
 * Returns the maximal x coordinate of a unit
 * ( -1 after the network has been deleted ).
 */
public int getMaxXCoordinate(){ return max_coord[0]; }
/**
 * Returns the maximal y coordinate of a unit
 * ( -1 after the network has been deleted ).
 */
public int getMaxYCoordinate(){ return max_coord[1]; }
// Methode, die aus Traditionsgründen auf's KernelInterface draufgepackt wurde,
// deren Notwendigkeit noch geklärt werden muß