📄 rbf.java
字号:
distMin = distTemp;
posMin = y;
}
}
distMax = distMax / (float) Math.sqrt(2 * numNeuronasCentro);
neuronasCentro[x].sigma = distMin * 0.5F;
}
}
// NUEVANEURONACENTRO
protected NeuronaCentro[] nuevaNeuronaCentro(float[] centro) {
int y;
NeuronaCentro neuronasCentroTemp[] = new NeuronaCentro[numNeuronasCentro + 1];
// Se copian los centros existentes
for (y = 0; y < numNeuronasCentro; y++)
neuronasCentroTemp[y] = neuronasCentro[y];
// Se a�ade el nuevo centro
neuronasCentroTemp[numNeuronasCentro] = new NeuronaCentro(numNeuronasEntrada);
// Se inicializa el vector de valores del centro con el patron
// especificado
neuronasCentroTemp[numNeuronasCentro].vectorValores = centro;
// neuronasCentroTemp[numNeuronasCentro].patronRaiz = centro;
return neuronasCentroTemp;
}
// NUEVOCONECTOR
protected Conector[][] nuevoConector() {
int y, x, pos = 0;
Conector conectorCSTemp[][] = new Conector[numNeuronasCentro + 1][numNeuronasSalida];
float distMin, distTemp;
for (x = 0; x < numNeuronasCentro; x++)
for (y = 0; y < numNeuronasSalida; y++)
conectorCSTemp[x][y] = conectorCS[x][y];
// Se calcula el centro mas cercano al insertado para obtener los pesos
// de sus conectores
for (y = 0, distMin = Float.MAX_VALUE; y < numNeuronasCentro; y++) {
distTemp = distanciaEuclidea(neuronasCentro[numNeuronasCentro].vectorValores, neuronasCentro[y].vectorValores);
if (distTemp < distMin) {
distMin = distTemp;
pos = y;
}
}
for (y = 0; y < numNeuronasSalida; y++)
conectorCSTemp[numNeuronasCentro][y] = new Conector(conectorCS[pos][y].pesoActual, momento);
return conectorCSTemp;
}
// TEST
protected int test() {
int x, y;
int aciertos;
for (x = 0, errorIteracion[iteracion] = 0, aciertos = 0; x < numCasos; x++) {
loadCase(x);
generateOutput();
for (y = 0, errorPatron[x] = 0; y < numNeuronasSalida; y++)
errorPatron[x] += Math.pow(valAtrib[x][numNeuronasEntrada + y] - neuronasSalida[y].valor, 2);
errorPatron[x] /= 2;
errorIteracion[iteracion] += errorPatron[x];
if (errorPatron[x] <= TOLERANCIA)
aciertos++;
}
return aciertos;
}
// FORECAST
protected void forecast_() {
this.forecast = true;
/*
* Matriz de la forma: VARIABLE_1 PREDICCION_PARA_1 VARIABLE_2 PREDICCION_PARA_2
* 1 1.9 2 2.9 ... ... ... ...
*/
Float[][] data = new Float[numCasos][2 * this.numNeuronasSalida];
String cad;
int y, x, aciertos;
float error, errorTotal;
if (noHayConfiguracion())
return ;
log("---------- FORECAST() ---------");
//logger.info("RBF: Forecasting");
// Para cada patron de entrenamiento
for (x = 0, errorTotal = 0, aciertos = 0; x < numCasos; x++) {
if (!forecast) {
for (int i = 0; i < listeners.size(); i++) {
listeners.get(i).forecastStopped();
}
return;
}
loadCase(x);
generateOutput();
cad = "c" + x;
for (y = 0, error = 0; y < numNeuronasSalida; y++) {
error += Math.pow(valAtrib[x][numNeuronasEntrada + y] - neuronasSalida[y].valor, 2);
data[x][y * 2] = valAtrib[x][numNeuronasEntrada + y];
data[x][y * 2 + 1] = neuronasSalida[y].valor;
cad += "\t" + formato.format(valAtrib[x][numNeuronasEntrada + y]) + "\t" + formato.format(neuronasSalida[y].valor);
}
for (int i = 0; i < listeners.size(); i++) {
listeners.get(i).atStep(numCasos, x + 1, error, 1);
}
log(cad);
error /= 2;
errorTotal += error;
if (error <= TOLERANCIA)
aciertos++;
}
//System.out.println("Forecasting... ");
//System.out.println("\t" + (errorTotal / numCasos));
log("------ RESUMEN: forecast() -----");
log("Error total: " + errorTotal);
log("Error medio: " + (errorTotal / numCasos));
log("Numero de aciertos: " + aciertos + "/" + numCasos);
log("--------------------------------");
Float[] resumen = new Float[4];
resumen[0] = errorTotal;
resumen[1] = (errorTotal / numCasos);
resumen[2] = new Float(aciertos);
resumen[3] = new Float(numCasos);
for (int i = 0; i < listeners.size(); i++) {
listeners.elementAt(i).forecastFinished(data,resumen);
}
}
// GENFUZZY
protected void genFuzzy_(String nomFich) throws IOException {
if (noHayConfiguracion())
return;
int y, x;
DataOutputStream cout;
String linActual = "", valorAct, nomFicheroCFG;
if (nomFich.length() == 0) {
nomFicheroCFG = nomFichero.substring(0, nomFichero.lastIndexOf(".")) + "_RBF.rul";
System.out.println("");
System.out.print("Introduce el nombre del fichero, [" + nomFicheroCFG + "]: ");
valorAct = teclado.readLine();
nomFicheroCFG = (valorAct.length() == 0) ? nomFicheroCFG : valorAct;
} else
nomFicheroCFG = nomFich;
cout = new DataOutputStream(new FileOutputStream(nomFicheroCFG));
for (x = 0; x < numNeuronasCentro; x++) {
linActual = "";
for (y = 0; y < numNeuronasEntrada; y++)
linActual += neuronasCentro[x].vectorValores[y] + ",\t" + neuronasCentro[x].sigma + ",\t";
// Solamente una neurona de salida
linActual += conectorCS[x][0].pesoActual + endl;
cout.writeBytes(linActual);
}
cout.writeBytes(endl);
}
// CROSSVALSIMPLE
protected void crossValSimpleRuben(List<Data> trainningFiles, List<Data> testFiles, Hashtable<String, Float> parameters) {
int y, numValidaciones;
String nomBase;
Float[][][] toret = null;
try {
nomBase = "rbf_configuration_file";
numValidaciones = trainningFiles.size();
toret = new Float[numValidaciones][][];
for (y = 0; y < numValidaciones; y++) {
log("");
logger.info("RBF: Foldcross validation " + (y + 1) + "/" + numValidaciones + "");
log("-------- CROSSVAL(" + (y + 1) + "/" + numValidaciones + ") --------");
// SE CARGA EL CONJUNTO DE ENTRENAMIENTO
loadData_(trainningFiles.get(y), "TRN");
// Se cargan los parametros de entrenamiento solo una vez
if (y == 0) {
inic_(parameters);
} else
generateNetwork(MIN_CENTROS);
training_();
saveCFG_(nomBase + "_trn_" + y + "_RBF.cfg");
genFuzzy_(nomBase + "_trn_" + y + "_RBF.rul");
// SE CARGA EL CONJUNTO DE TEST
loadData_(testFiles.get(y), "TEST");
loadCFG_(nomBase + "_trn_" + y + "_RBF.cfg");
forecast_();
}
} catch (Exception ex) {
ex.printStackTrace();
ex.printStackTrace(new PrintWriter(fichLog));
}
for (int i = 0; i < listeners.size(); i++) {
listeners.elementAt(i).crossValidationFinished();
}
}
// SELETCTRN
protected void selectTrn() {
valAtrib = valAtribTrn;
numCasos = numCasosTrn;
}
// SELECTTEST
protected void selectTest() {
valAtrib = valAtribTest;
numCasos = numCasosTest;
}
// SELECTKNN
protected void selectKNN(int casoActual) {
int x, y, indice;
float distActual;
Vector distancias = new Vector(numCasosTrn);
float[][] valAtribSim = new float[numCasosSim][numAtributos];
// Para cada patron de entrenamiento se calcula su distancia con el
// patron de entrada
for (x = 0; x < numCasosTrn; x++) {
distActual = distanciaEuclidea(valAtribTrn[x], valAtribTest[casoActual]);
distancias.addElement(new Float(distActual));
}
// Se eligen las K distancias menores
for (x = 0; x < numCasosSim; x++) {
indice = MinMax(distancias, MIN);
for (y = 0; y < numAtributos; y++)
valAtribSim[x][y] = valAtribTrn[indice][y];
distancias.removeElementAt(indice);
distancias.trimToSize();
}
valAtrib = valAtribSim;
numCasos = numCasosSim;
}
// SELECTGCS
protected void selectGCS(int casoActual) {
int x, y, indice, posBMU = 0;
float distActual, disMinima = Float.MAX_VALUE;
// Para cada nodo GCS se calcula su distancia con el caso actual
for (x = 0; x < Nodos.size(); x++) {
// Si el nodo no contiene casos, no se tiene en cuenta.
if (getNodoGCS(x).casosAgrupados.size() == 0)
continue;
distActual = distanciaEuclidea(valAtribTest[casoActual], getNodoGCS(x).w);
if (distActual < disMinima) {
disMinima = distActual;
posBMU = x;
}
} // En posBMUCasoActual estara la BMU (Best Matching Unit) para el
// caso actual
// Se seleccionan los casos agrupados por el nodo BMU
float[][] valAtribSim = new float[getNodoGCS(posBMU).casosAgrupados.size()][numAtributos];
for (x = 0; x < getNodoGCS(posBMU).casosAgrupados.size(); x++) {
indice = ((Integer) getNodoGCS(posBMU).casosAgrupados.elementAt(x)).intValue();
for (y = 0; y < numAtributos; y++)
valAtribSim[x][y] = valAtribTrn[indice][y];
}
valAtrib = valAtribSim;
numCasos = getNodoGCS(posBMU).casosAgrupados.size();
}
// MIN
// operacion = 1 [MAXIMO] | operacion = -1 [MINIMO]
protected int MinMax(Vector vector, int operacion) {
int x, posMin = -1, posMax = -1;
float min = Float.MAX_VALUE;
float max = Float.MIN_VALUE;
for (x = 0; x < vector.size(); x++) {
if (min > ((Float) vector.elementAt(x)).floatValue()) {
min = ((Float) vector.elementAt(x)).floatValue();
posMin = x;
}
if (max < ((Float) vector.elementAt(x)).floatValue()) {
max = ((Float) vector.elementAt(x)).floatValue();
posMax = x;
}
}
return (operacion == MAX) ? posMax : posMin;
}
// DISTANCIAEUCLIDEA
protected float distanciaEuclidea(float[] v1, float[] v2) {
float distancia = 0;
for (int x = 0; x < numNeuronasEntrada; x++)
distancia += Math.pow(v1[x] - v2[x], 2);
return ((float) Math.sqrt(distancia));
}
// GETVECTORENTRADA
protected float[] getVectorEntrada() {
float[] vector = new float[numNeuronasEntrada];
for (int x = 0; x < numNeuronasEntrada; x++)
vector[x] = neuronasEntrada[x].valor;
return vector;
}
// RESETSTATCENTROS
protected void resetStatCentros() {
for (int x = 0; x < numNeuronasCentro; x++) {
neuronasCentro[x].kNN.removeAllElements();
neuronasCentro[x].errorAcumulado = 0;
}
}
// INICPATRONESALEATORIOS
protected void inicPatronesAleatorios() {
Date dt = new Date();
rnd = new Random(dt.getTime());
patrones = new Vector(numCasos);
for (int x = 0; x < numCasos; x++)
patrones.addElement(new Integer(x));
}
// SIGPATRONALEATORIO
protected int sigPatronAleatorio() {
int sigPatron, index;
index = (int) Math.floor(rnd.nextFloat() * patrones.size());
sigPatron = ((Integer) patrones.elementAt(index)).intValue();
patrones.removeElementAt(index);
patrones.trimToSize();
return sigPatron;
}
// NOHAYDATOS
protected boolean noHayDatos() {
if (numCasos == 0) {
System.out.println("No hay datos cargados!");
return true;
} else
return false;
}
// NOHAYCONFIGURACION
protected boolean noHayConfiguracion() {
if (numNeuronasCentro == 0) {
System.out.println("No se ha configurado la red!");
return true;
} else
return false;
}
// LOG
protected void log(String cad) {
try {
fichLog.writeBytes("\r\n" + cad);
fichLog.flush();
} catch (IOException ex) {
ex.printStackTrace();
}
}
// ********** [DEBUG: para depuracion de nuevo codigo] **********
protected void debug_() {
/*
* // VOLCADO DE LA RED GCS AL FICHERO DE LOGS. for (int x=0; x <
* Nodos.size(); x++) { log("Nodo " + x); log(getNodoGCS(x).getW());
* log(getNodoGCS(x).getNodosVecinos());
* log(getNodoGCS(x).getDistNodosVecinos());
* log(getNodoGCS(x).getCasosAgrupados()); log(""); }
*/
}
// GETNODOGCS
protected NodoGCS getNodoGCS(int indice) {
return ((NodoGCS) Nodos.elementAt(indice));
}
public void stopTrainning() {
this.train = false;
}
public void stopForecast() {
this.forecast = false;
}
}
/** ****************************************************** */
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -