⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 lrn_pifa.c

📁 这个代码是关于强化学习的 policy iteration 算法。请您用 WinZip 解压缩。
💻 C
📖 第 1 页 / 共 2 页
字号:
		}
#endif

		cont_grad = 1;
		cnt = 0;
		for ( k = 1; k < Episodes_Per_Parameter_Update && (cont_grad == 1); k++ )
		{
			for ( j = 0; j < num_of_gaussians && (cont_grad == 1); j++ )
			{
				cont_grad = modes_visited[k][j];
			}
			if (cont_grad == 1)
				cnt++;
		}
		cnt++;

		Q_a = (double **)My_Malloc((long)cnt  * sizeof(double*));
		for ( k = 0; k < cnt; k++ )
		{
			Q_a[k] = (double *)My_Malloc((long)num_of_gaussians  * sizeof(double));
		}

		Update_Function_Coefficients(cnt,Q_a);

		Update_Policy_Parameters_Using_FA(cen, var);
		

#ifdef UPDATE_POLICY_PARAMETERS
		{
			double  tmax;
			{ // 
				char error_text[256];
				FILE *fp;
				
				if ((fp = fopen("gf.txt", "a")) == NULL)
				{
					sprintf(error_text, "Couldn't open \"%s\"\n", "gf.txt");
					My_Error(error_text);
				}
				for ( j = 0; j < num_of_gaussians; j++ )
				{
					for ( i = 0; i < dim; i++ )
					{
						fprintf(fp,"%g\n",drhodc[j][i]);
						fprintf(fp,"%g\n",drhodv[j][i]);
					}
				}
				fclose(fp);
			}
			for ( j = 0; j < num_of_gaussians; j++ )
			{
				for ( i = 0; i < dim; i++ )
				{
					if ( tdc < fabs(drhodc[j][i]) )
					{
						tdc = fabs(drhodc[j][i]);
					}
					if ( tdv < fabs(drhodv[j][i]) )
					{
						tdv = fabs(drhodv[j][i]);
					}
				}
			}
			
			if ( tdc > tdv )
			{
				tmax = tdc;
			}
			else
			{
				tmax = tdv;
			}
			
			
			for ( j = 0; j < num_of_gaussians; j++ )
			{
				for ( i = 0; i < dim; i++ )
				{
					drhodc[j][i] = drhodc[j][i] / tmax;
					drhodv[j][i] = drhodv[j][i] / tmax;
				}
			}
			
#ifdef GRAPHICS
			Update_Boundaries = 1;
#endif
			for ( j = 0; j < num_of_gaussians; j++ )
			{
				for ( i = 0; i < dim; i++ )
				{
					cen[j][i] = cen[j][i] + alpha * drhodc[j][i];
					var[j][i] = var[j][i] + alpha * drhodv[j][i];
					if ( var[j][i] < 0.01)
						var[j][i] = 0.01;
				}
			}
		}
#else
		{ // 
			char error_text[256];
			FILE *fp;
			
			if ((fp = fopen("gf.txt", "a")) == NULL)
			{
				sprintf(error_text, "Couldn't open \"%s\"\n", "gf.txt");
				My_Error(error_text);
			}
			for ( j = 0; j < num_of_gaussians; j++ )
			{
				for ( i = 0; i < dim; i++ )
				{
					fprintf(fp,"%g %g ",drhodc[j][i],drhodv[j][i]);
				}
			}
			fprintf(fp,"\n");
			fclose(fp);
		}
#endif
		

		for ( k = 0; k < cnt; k++ )
		{
			free(Q_a[k]);
		}
		free(Q_a);

		for ( k = 0; k < total_states_visited; k++ )
		{
			free(states_visited[k]);
		}
		free(states_visited);

		for ( k = 0; k < Episodes_Per_Parameter_Update+1; k++ )
		{
			free(Q[k]);
			free(p_pi[k]);

			for ( i = 0; i < num_of_gaussians; i++ )
			{
				for ( j = 0; j < dim; j++ )
				{
					free(dpdc[k][i][j]);
					free(dpdv[k][i][j]);
				}
				free(dpdc[k][i]);
				free(dpdv[k][i]);
			}
			free(dpdc[k]);
			free(dpdv[k]);
		}
		free(Q);
		free(dpdc);
		free(dpdv);
		free(p_pi);
		
		for ( i = 1; i < Episodes_Per_Parameter_Update; i++ )
		{
			free(modes_visited[i]);
		}
		free(modes_visited);

		Update_Policy_Parameters = 0;
		Step_To_Execute_Mode = -1;
		Mode_Execute = -1;

#ifdef UPDATE_POLICY_PARAMETERS
		Reset_Random_Seed_For_Paths();
#endif
		Num_of_Grad_Calculations++;
		if ( Num_of_Grad_Calculations > Max_Num_Grad_Calc )
		{
			exit(1);
		}
	}

}



/*
 * Update_Policy_Parameters_Using_FA
 *
 * Adds a function-approximation based gradient estimate, averaged over the
 * total_states_visited states stored in states_visited, to the globals
 * drhodc and drhodv (the per-mode directions that the caller in this file
 * uses to step the centres cen and variances var).
 *
 * For each stored state:
 *   - g[n]   : evaluate_gauss() of the state under Gaussian mode n,
 *   - lpi[n] : g[n] / sum(g), floored at MIN_PI,
 *   - for every mode q, avaluate_total_gradient() fills dpdc_t / dpdv_t,
 *     the approximated value
 *         Qt = sum_{i,n} ( drhodc_coeff[q][i][n] * dpdc_t[i][n]
 *                        + drhodv_coeff[q][i][n] * dpdv_t[i][n] ) / lpi[n]
 *     is formed, and Qt * dpdc_t[i][n] (resp. Qt * dpdv_t[i][n]) is added to
 *     drhodc[q][i] (resp. drhodv[q][i]) for every n.
 * Finally drhodc / drhodv are divided by total_states_visited.
 *
 * cen, var : per-mode centres and variances, num_of_gaussians rows of dim
 *            entries; only read here.
 *
 * NOTE(review): drhodc / drhodv are accumulated into and never cleared in
 * this function; presumably the caller zeroes them beforehand - confirm.
 */
void Update_Policy_Parameters_Using_FA(double **cen, double **var)
{
	int step, i, q, n, j;
	double g_tot, *g, *lpi, Qt;

	g = (double *)My_Malloc((long)num_of_gaussians  * sizeof(double));
	lpi = (double *)My_Malloc((long)num_of_gaussians  * sizeof(double));

	for ( step = 0; step < total_states_visited; step++ )
	{
		/* Activation of every Gaussian mode at this state. */
		g_tot = 0.0;
		for ( i = 0; i < num_of_gaussians; i++ )
		{
			g[i] = evaluate_gauss(dim, states_visited[step], cen[i], var[i]);
			g_tot = g_tot + g[i];
		}

		/* If the total activation is 0 (all densities underflowed, far from
		   every centre) or NaN, g[i]/g_tot is undefined: 0/0 yields NaN, which
		   the MIN_PI floor below does not catch and which would propagate into
		   drhodc/drhodv. Such a state contributes nothing to the estimate. */
		if ( !(g_tot > 0.0) )
		{
			continue;
		}

		for ( i = 0; i < num_of_gaussians; i++ )
		{
			/* Flooring keeps the 1/lpi[n] weights used in Qt bounded
			   (assuming MIN_PI > 0). */
			lpi[i] = g[i]/g_tot;
			if ( lpi[i] < MIN_PI )
			{
				lpi[i] = MIN_PI;
			}
		}

		for ( q = 0; q < num_of_gaussians; q++ )
		{
			avaluate_total_gradient(dim, states_visited[step], cen[q], var[q], 
				dpdc_t, dpdv_t, wrk, q, g, num_of_gaussians, g_tot);

			/* Approximated value of mode q at this state. */
			Qt = 0.0;
			for ( i = 0; i < dim; i++ )
			{
				for ( n = 0; n < num_of_gaussians; n++ )
				{
					Qt = Qt + drhodc_coeff[q][i][n] * dpdc_t[i][n] / lpi[n];
					Qt = Qt + drhodv_coeff[q][i][n] * dpdv_t[i][n] / lpi[n];
				}
			}

			/* Weight the policy gradient by the approximated value. */
			for ( i = 0; i < dim; i++ )
			{
				for ( n = 0; n < num_of_gaussians; n++ )
				{
					drhodc[q][i] = drhodc[q][i] + Qt * dpdc_t[i][n];
					drhodv[q][i] = drhodv[q][i] + Qt * dpdv_t[i][n];
				}
			}
		}
	}

	/* Average over the visited states. With no states there is nothing to
	   average, and dividing by 0 would turn drhodc/drhodv into inf/NaN. */
	if ( total_states_visited > 0 )
	{
		for ( j = 0; j < num_of_gaussians; j++ )
		{
			for ( i = 0; i < dim; i++ )
			{
				drhodc[j][i] = drhodc[j][i] / total_states_visited;
				drhodv[j][i] = drhodv[j][i] / total_states_visited;
			}
		}
	}

	free(g);
	free(lpi);
}

/*
 * Update_Function_Coefficients
 *
 * For every Gaussian mode j, refits the linear function-approximation
 * coefficients drhodc_coeff[j][.][.] / drhodv_coeff[j][.][.] so that
 *
 *   Qhat[k][j] = sum_{i,n} ( drhodc_coeff[j][i][n] * dpdc[k][j][i][n]
 *                          + drhodv_coeff[j][i][n] * dpdv[k][j][i][n] ) / p_pi[k][n]
 *
 * matches the sampled Q[k][j] for k = 1 .. nums-1 in the least-squares sense.
 * The residuals of the current coefficients are regressed on the same
 * features through the normal equations (A^T A) w = A^T r, solved with
 * Solve_System(), and w is added to the coefficients.
 *
 * nums : one more than the number of samples (samples are k = 1 .. nums-1).
 * Q_a  : output; Q_a[k][j] receives the refitted approximation for
 *        k = 1 .. nums-1 (row 0 is not written). Needs at least nums rows of
 *        num_of_gaussians entries.
 *
 * Calls My_Error("Not a Number!") if the normal equations contain a NaN.
 */
void Update_Function_Coefficients(int nums, double **Q_a)
{
	double err, Qt, err_sq_before_lrn, err_sq_after_lrn;
	int i,j,k,n,cnt;
	
	int rows,cols;

	double **A,**AtA,*b,*Atb,*w, *S, *c;

	/* No samples (nums <= 1): there is nothing to fit and no approximation
	   to write. Returning early also avoids zero- or negative-sized
	   allocations below. */
	if ( nums <= 1 )
	{
		return;
	}

	rows = nums-1;                      /* one regression row per sample     */
	cols = 2 * num_of_gaussians * dim;  /* centre + variance feature per (i,n) */

	/*** allocate memory ***/
	b = (double *)My_Malloc((long)(rows)  * sizeof(double));
	A = (double **)My_Malloc((long)rows  * sizeof(double*));
	for ( i = 0; i < rows; i++ )
	{
		A[i] = (double *)My_Malloc((long)cols  * sizeof(double));
	}
	w = (double *)My_Malloc((long)(2*cols)  * sizeof(double));
	S = (double *)My_Malloc((long)(2*cols)  * sizeof(double));
	c = (double *)My_Malloc((long)(2*cols)  * sizeof(double));
	Atb = (double *)My_Malloc((long)(2*cols)  * sizeof(double));
	/* Solve_System indexes AtA as a 2*cols x cols work matrix: the top half
	   holds the system, the bottom half is used for the SVD's V factor. */
	AtA = (double **)My_Malloc((long)(2*cols)  * sizeof(double*));
	for ( i = 0; i < (2*cols); i++ )
	{
		AtA[i] = (double *)My_Malloc((long)cols  * sizeof(double));
	}
	/****************************/

	/* Diagnostic only: squared residuals before/after the refit (not returned). */
	err_sq_before_lrn = 0.0;
	for ( j = 0; j < num_of_gaussians; j++ )
	{
		/* Row k-1 of A holds the features of sample k; b[k-1] holds the
		   residual of the current approximation for that sample. */
		for ( k = 1; k < nums; k++ )
		{
			cnt = 0;
			Qt = 0.0;
			for ( i = 0; i < dim; i++ )
			{
				for ( n = 0; n < num_of_gaussians; n++ )
				{
					Qt = Qt + drhodc_coeff[j][i][n] * dpdc[k][j][i][n] / p_pi[k][n];
					A[k-1][cnt++] = dpdc[k][j][i][n] / p_pi[k][n];
					Qt = Qt + drhodv_coeff[j][i][n] * dpdv[k][j][i][n] / p_pi[k][n];
					A[k-1][cnt++] = dpdv[k][j][i][n] / p_pi[k][n];
				}
			}
			err = Q[k][j] - Qt;
			err_sq_before_lrn = err_sq_before_lrn + err * err;
			b[k-1] = err;
		}

		/* Normal equations. Only the lower triangle of AtA is accumulated
		   here; the upper triangle is mirrored below. NaN propagates through
		   a sum, so checking each finished entry once detects exactly what a
		   per-term check would. */
		for ( i = 0; i < cols; i++ )
		{
			for ( n = 0; n <= i; n++ )
			{
				AtA[i][n] = 0.0;
				for ( k = 0; k < rows; k++ )
				{
					AtA[i][n] = AtA[i][n] + A[k][i]*A[k][n];
				}
				if ( _isnan(AtA[i][n]) )
				{
					My_Error("Not a Number!");
				}
			}
			Atb[i] = 0.0;
			for ( k = 0; k < rows; k++ )
			{
				Atb[i] = Atb[i] + A[k][i]*b[k];
			}
			/* A NaN in the targets (Q) reaches only Atb, never AtA. */
			if ( _isnan(Atb[i]) )
			{
				My_Error("Not a Number!");
			}
		}

		for ( i = 0; i < cols; i++ )
		{
			for ( n = i+1; n < cols; n++ )
			{
				AtA[i][n] = AtA[n][i];
			}
		}

		/* w is the least-squares correction; it uses the same (i, n, c/v)
		   ordering as the columns of A. */
		Solve_System(AtA, Atb, w, S, c, cols);
		cnt = 0;
		for ( i = 0; i < dim; i++ )
		{
			for ( n = 0; n < num_of_gaussians; n++ )
			{
				drhodc_coeff[j][i][n] = drhodc_coeff[j][i][n] + w[cnt++];
				drhodv_coeff[j][i][n] = drhodv_coeff[j][i][n] + w[cnt++];
			}
		}
	}
	
	/* Evaluate the refitted approximation for every sample and mode. */
	err_sq_after_lrn = 0.0;
	for ( k = 1; k < nums; k++ )
	{
		for ( j = 0; j < num_of_gaussians; j++ )
		{
			Qt = 0.0;
			for ( i = 0; i < dim; i++ )
			{
				for ( n = 0; n < num_of_gaussians; n++ )
				{
					Qt = Qt + drhodc_coeff[j][i][n] * dpdc[k][j][i][n] / p_pi[k][n];
					Qt = Qt + drhodv_coeff[j][i][n] * dpdv[k][j][i][n] / p_pi[k][n];
				}
			}

			Q_a[k][j] = Qt;
			err = Q[k][j] - Qt;
			err_sq_after_lrn = err_sq_after_lrn + err * err;
		}
	}

	/*** free up the memory ***/
	for ( i = 0; i < rows; i++ )
	{
		free(A[i]);
	}
	free(A);
	free(b);
	free(Atb);
	free(w);
	free(S);
	free(c);

	for ( i = 0; i < (2*cols); i++ )
	{
		free(AtA[i]);
	}
	free(AtA);
	/***************************/

}


/*
 * Solve_System
 *
 * Solves the rows x rows system held in the top half of A for w, using a
 * truncated-SVD pseudo-inverse:  w = V * diag(1/S2) * (U*S)^T * b.
 *
 * A    : work matrix of 2*rows row pointers, each with rows entries. svd() is
 *        run on it in place; afterwards rows [0, rows) are used as U*S and
 *        rows [rows, 2*rows) as V. On return the bottom half holds the
 *        pseudo-inverse that was applied to b.
 * b    : right-hand side, rows entries (read only).
 * w    : receives the solution, rows entries.
 * S2   : filled by svd(); overwritten with the reciprocals actually applied.
 *        Every direction whose S2 value times MAX_SV_RATIO does not exceed
 *        S2[0] (as produced by svd()) gets 0 instead, discarding it.
 * c    : scratch buffer, rows entries.
 * rows : order of the system.
 *
 * NOTE(review): the product is a true (pseudo-)inverse only if svd() leaves
 * U*S in the top half, V in the bottom half and matching values in S2; this
 * is taken from the original comment and should be verified against svd().
 */
void Solve_System(double **A, double *b, double *w, 
				  double *S2, double *c, int rows)
{
	void svd(double **A, double *S2, int n);
	int row, col, m;
	double reference, acc;

	svd(A, S2, rows);

	/* Pseudo-invert the spectrum. The cutoff always compares against the
	   value svd() left in S2[0], so capture it before S2 is rewritten. */
	reference = (rows > 0) ? S2[0] : 0.0;
	for (m = 0; m < rows; m++)
	{
		S2[m] = (S2[m]*MAX_SV_RATIO > reference) ? 1.0/S2[m] : 0.0;
	}

	/* Replace each row of the V half with V * diag(S2) * (U*S)^T, buffering
	   the new row in c because the old one is read while it is computed. */
	for (row = 0; row < rows; row++)
	{
		double *v_row = A[rows + row];

		for (col = 0; col < rows; col++)
		{
			double *us_row = A[col];

			acc = 0.0;
			for (m = 0; m < rows; m++)
			{
				acc += v_row[m]*us_row[m]*S2[m];
			}
			c[col] = acc;
		}

		for (col = 0; col < rows; col++)
		{
			v_row[col] = c[col];
		}
	}

	/* w = pseudo-inverse * b */
	for (row = 0; row < rows; row++)
	{
		const double *p_row = A[rows + row];

		acc = 0.0;
		for (col = 0; col < rows; col++)
		{
			acc += p_row[col]*b[col];
		}
		w[row] = acc;
	}
}
#endif

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -