⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 bpclassification.cpp

📁 这是我自己编程的BP算法
💻 CPP
📖 第 1 页 / 共 2 页
字号:
// BPClassification.cpp : Defines the entry point for the console application.
//

#include "stdafx.h"
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <math.h>
#include <time.h>


//Define
#define NUM_CLASS 3 //总共要处理三个类
#define NUM_TRAIN_PER_CLASS 30 //每个类所给的学习数据个数
#define NUM_TEST_PER_CLASS 20 //每个类所给的测试数据个数
#define NUM_FEATURE 4 //每个数据有4个特征

#define NUM_NODE_INPUT_LAYER NUM_FEATURE //输入层节点个数
#define NUM_NODE_HIDE_LAYER 6 //隐含层节点个数
#define NUM_NODE_OUTPUT_LAYER NUM_CLASS //输出层节点个数
#define NUM_NODE_TOTAL NUM_NODE_INPUT_LAYER+NUM_NODE_HIDE_LAYER+NUM_NODE_OUTPUT_LAYER

#define INDEX_START_INPUT_LAYER 0
#define INDEX_START_HIDE_LAYER NUM_NODE_INPUT_LAYER
#define INDEX_START_OUTPUT_LAYER NUM_NODE_INPUT_LAYER+NUM_NODE_HIDE_LAYER

#define TRAIN_RATE 0.5 //学习速率 DEBUG-->正式程序要由用户来输入
#define NUM_MAX_ITERATE 60000 //最大迭代次数
#define THRESHOLD 0.05 //
//Define-End

//结构体定义

//定义学习数据
typedef struct
{
	float feature[NUM_FEATURE];//存放每个数据的特征
	int nClass;//当前学习数据所属的类
}TrainData;

//定义测试数据
typedef struct
{
	float feature[NUM_FEATURE];//存放每个测试数据的特征
	int nClass;//当前测试数据所属的类,已知
	int result;//当前测试数据所属的类,程序判断结果
	bool Right;//通过比较nClass与result的值来判断BP网络的测试结果是否正确,此项主要用来统计,以表示BP网络的性能
}TestData;

//定义BP网络节点结构
typedef struct
{
	float data;
	float expect;
	float difference;
	float input;
	float output;
	float err;
}BP_Node;

//结构体定义结束


//声明全局变量
TrainData train[NUM_CLASS*NUM_TRAIN_PER_CLASS];
TestData test[NUM_CLASS*NUM_TEST_PER_CLASS];
float W_INPUT_TO_HIDE[NUM_NODE_INPUT_LAYER][NUM_NODE_HIDE_LAYER];
float W_HIDE_TO_OUTPUT[NUM_NODE_HIDE_LAYER][NUM_NODE_OUTPUT_LAYER];
BP_Node bpNode[NUM_NODE_TOTAL];
char *kinds[] = {"Iris-setosa", "Iris-versicolor", "Iris-virginica"};

int NUM_CURRENT_ITERATE=0;

//声明全局变量结束


//函数声明

void Read_Data(FILE *file, TrainData *train, TestData *test);
//功能:从文件file中读取各学习数据集至train所指的空间中,读取测试数据集至test所指的空间中

void Print_Data(TrainData *train,TestData* test);
//功能:调用函数Read_Data()结束后,回显示数据以测试Read_Data()的功能是否正确

void Init_Weight_Difference(float W_INPUT_TO_HIDE[][NUM_NODE_HIDE_LAYER],float W_HIDE_TO_OUTPUT[][NUM_NODE_OUTPUT_LAYER],BP_Node* bpNode);
//功能:初始化BP网络的权值和偏差

void TEST_Init_Weight_Difference(float W_INPUT_TO_HIDE[][NUM_NODE_HIDE_LAYER],float W_HIDE_TO_OUTPUT[][NUM_NODE_OUTPUT_LAYER],BP_Node* bpNode);
//功能:测试Init_Weight_Difference()

void Input_Single_Train_Data_To_Input_Layer(TrainData singleTrainData);
//功能:将单个学习样本输入到BP的输入层

void Input_Teacher_Data_To_Output_Layer(TrainData singleTrainData);
//功能:依据当前单个学习样本所属的类(nClass)设定输出层各节点的教师值

void Print_Input_Output_Layer_Before_Train();
//功能:在正式学习之前,输出Input及Output层各节点的相关值,以检测值的设定是否正确

void Train();
//功能:训练学习样本

void prePropagation();
//功能:向前传播

void CalcErr();
//功能:求出BP网络中所有节点的输出

void Update_Weight_Difference();
//功能:更新权值及偏差

bool isStop(TrainData current_Train_Data);
//功能:判断是否满足结束条件

void printTrainResult();

void Test();
//功能:测试样本

void PrintTestResult();

//函数声明结束



int main(int argc, char* argv[])
{
	int i,j;


	FILE* inFile;

	//初始化 train[],及test[]
	for(i=0;i<NUM_CLASS*NUM_TRAIN_PER_CLASS;i++)
	{
		for(j=0;j<NUM_FEATURE;j++)
		{
			train[i].feature[j]=0.0;
		}
		train[i].nClass=-1;
	}
	
	for(i=0;i<NUM_CLASS*NUM_TEST_PER_CLASS;i++)
	{
		for(j=0;j<NUM_FEATURE;j++)
		{
			test[i].feature[j]=0.0;
		}
		test[i].nClass=-1;
		test[i].result=-1;
		test[i].Right=false;
	}
	
	//打开文件出错
	if((inFile=fopen("iris.dat","r"))==NULL)
	{
		fprintf(stderr,"Can not Open File!\n");
		exit(1);
	}

	Read_Data(inFile, train, test);//读入数据
	fclose(inFile);
	Print_Data(train,test);

	Init_Weight_Difference(W_INPUT_TO_HIDE,W_HIDE_TO_OUTPUT,bpNode);
    TEST_Init_Weight_Difference(W_INPUT_TO_HIDE,W_HIDE_TO_OUTPUT,bpNode);

	Train();
	printTrainResult();
	Test();
	PrintTestResult();
	return 0;
}

void Read_Data(FILE *in, TrainData *train, TestData *test)
{
	char	string[256];
	char	name[20];
	int		num = NUM_TRAIN_PER_CLASS+NUM_TEST_PER_CLASS;
	int		i,j,k;

	i = j = k = 0;
	while(fgets(string, 256, in) != NULL)
	{
		printf("WUJM:String=%s.\n",string);

		if((i%num) < NUM_TRAIN_PER_CLASS)//加载学习数据
		{
			sscanf(string,"%f,%f,%f,%f,%s",&train[j].feature[0],&train[j].feature[1],&train[j].feature[2],&train[j].feature[3],name);
			//DEBUG
			train[j].feature[0]/=10;
			train[j].feature[1]/=10;
			train[j].feature[2]/=10;
			train[j].feature[3]/=10;
			//DEBUG
			if (strcmp(name, kinds[0]) == 0)
				train[j].nClass = 0;
			else if (strcmp(name, kinds[1]) == 0)
				train[j].nClass = 1;
			else if (strcmp(name, kinds[2]) == 0)
				train[j].nClass = 2;
			j++;
		}
		else//加载测试数据
		{
			sscanf(string,"%f,%f,%f,%f,%s",&test[k].feature[0],&test[k].feature[1],&test[k].feature[2],&test[k].feature[3],name);
			//DEBUG
			test[k].feature[0]/=10;
			test[k].feature[1]/=10;
			test[k].feature[2]/=10;
			test[k].feature[3]/=10;
			//DEBUG
			if (strcmp(name, kinds[0]) == 0)
				test[k].nClass = 0;
			else if (strcmp(name, kinds[1]) == 0)
				test[k].nClass = 1;
			else if (strcmp(name, kinds[2]) == 0)
				test[k].nClass = 2;
			k++;		
		}
		i++;
	}
}


void Print_Data(TrainData *train, TestData *test)
{
	int index;
	
	//显示学习数据:
	printf("学习数据显示:\n");
	for(index=0;index<NUM_CLASS*NUM_TRAIN_PER_CLASS;index++)
	{
		printf("%f\t%f\t%f\t%f\t",train[index].feature[0],train[index].feature[1],train[index].feature[2],train[index].feature[3]);
		if(train[index].nClass==0)
		{
			printf("Iris-setosa");
		}
		else if(train[index].nClass==1)
		{
			printf("Iris-versicolor");
		}
		else if(train[index].nClass==2)
		{
			printf("Iris-virginica");
		}
		else if(train[index].nClass==-1)
		{
			printf("未正确读入所属类");
		}
		printf("\n");
	}
	
	//显示测试数据:
	printf("测试数据显示:\n");
	for(index=0;index<NUM_TEST_PER_CLASS*NUM_CLASS;index++)
	{
		printf("%f\t%f\t%f\t%f\t",test[index].feature[0],test[index].feature[1],test[index].feature[2],test[index].feature[3]);
		
		if(test[index].nClass==0)
		{
			printf("Iris-setosa");
		}
		else if(test[index].nClass==1)
		{
			printf("Iris-versicolor");
		}
		else if(test[index].nClass==2)
		{
			printf("Iris-virginica");
		}
		else if(test[index].nClass=-1)
		{
			printf("未正确读入所属类");
		}

		if(test[index].result==0)
		{
			printf("\tIris-setosa");
		}
		else if(test[index].result==1)
		{
			printf("\tIris-versicolor");
		}
		else if(test[index].result==2)
		{
			printf("\tIris-virginica");
		}
		else if(test[index].result==-1)
		{
			printf("\t未测试");
		}
		
		if(test[index].Right==true)
		{
			printf("\t测试结果正确");
		}
		else
		{
			printf("\t测试结果错误");
		}
		printf("\n");	
	}
}

void Init_Weight_Difference(float W_INPUT_TO_HIDE[][NUM_NODE_HIDE_LAYER],float W_HIDE_TO_OUTPUT[][NUM_NODE_OUTPUT_LAYER],BP_Node* bpNode)
{
	int i,j;
	time_t t;
	//初始化 bpNode的
	srand((unsigned)time(&t));
	for(i=0;i<NUM_NODE_TOTAL;i++)
	{
		bpNode[i].data=-1.0;				
		bpNode[i].difference=(float)(rand() % 1000)/1000-0.5f;//随机生成
		bpNode[i].expect=0.0; //处理学习数据时设定
		bpNode[i].input=0.0f;
		bpNode[i].output=0.0f;
		bpNode[i].err=0.0f;
	}
	
	for(i=0;i<NUM_NODE_INPUT_LAYER;i++)
	{
		for(j=0;j<NUM_NODE_HIDE_LAYER;j++)
		{
			W_INPUT_TO_HIDE[i][j]=(float)(rand() % 1000)/1000-0.5f;//随机生成
		}
	}

	for(i=0;i<NUM_NODE_HIDE_LAYER;i++)
	{
		for(j=0;j<NUM_NODE_OUTPUT_LAYER;j++)

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -