// linesearch.cpp
#include "iostream.h"
#include "math.h"
#include <vector>
// Gradient module.
// Numerically estimates the gradient of pf at `point`, one coordinate at a
// time, with a fourth-order central-difference stencil:
//   grad[i] = [8*(f(x + h/2*e_i) - f(x - h/2*e_i))
//              - (f(x + h*e_i) - f(x - h*e_i))] / (6*h)
// The stencil is exact (up to rounding) for polynomials of degree <= 4.
// Parameters:
//   pf    - objective function; called with a pointer to n doubles
//   n     - number of variables (nothing is written when n <= 0)
//   point - point at which the gradient is evaluated (read only)
//   grad  - output array of n doubles receiving the gradient
//   h     - outer stencil width; probes are taken at offsets +-h/2 and +-h
//           (defaults to 1E-3, the value previously hard-coded)
// pf is evaluated 4*n times, always on a private copy of point, so the
// caller's point array is never perturbed.
void comput_grad(double (*pf)(double *x),
                 int n,
                 double *point,
                 double *grad,
                 double h = 1E-3)
{
    if (n <= 0)
        return;
    // Scratch copy of the point; std::vector releases it even if pf throws.
    std::vector<double> temp(point, point + n);
    for (int i = 0; i < n; ++i)
    {
        const double xi = point[i];
        // Each probe coordinate is set from xi directly (instead of by chained
        // += / -=) so rounding errors do not accumulate in the probe position.
        temp[i] = xi + 0.5 * h;
        const double f_plus_half = pf(&temp[0]);
        temp[i] = xi - 0.5 * h;
        const double f_minus_half = pf(&temp[0]);
        temp[i] = xi + h;
        const double f_plus_full = pf(&temp[0]);
        temp[i] = xi - h;
        const double f_minus_full = pf(&temp[0]);
        temp[i] = xi;  // restore before moving on to the next coordinate
        grad[i] = (8.0 * (f_plus_half - f_minus_half)
                   - (f_plus_full - f_minus_full)) / (6.0 * h);
    }
}
// Helper for line_search: returns the directional derivative
// grad(pf)(at) . dir, using grad_buf (n doubles) as scratch space for the
// numerical gradient computed by comput_grad.
static double line_search_slope(double (*pf)(double *x), int n, double *at,
                                const double *dir, double *grad_buf)
{
    comput_grad(pf, n, at, grad_buf);
    double slope = 0;
    for (int i = 0; i < n; ++i)
        slope += grad_buf[i] * dir[i];
    return slope;
}

// Helper for line_search: writes the point start + t * dir (n doubles) to out.
static void line_search_point(int n, const double *start, const double *dir,
                              double t, double *out)
{
    for (int i = 0; i < n; ++i)
        out[i] = start[i] + t * dir[i];
}

// One-dimensional line search module.
// Looks for a step t that approximately minimizes
//   phi(t) = pf(start + t * direction)
// in two phases:
//   1. Bracketing: starting with step 0.001 and doubling it each time, the
//      left end a is advanced while phi'(b) < 0, until phi'(b) > 0, so that
//      [a, b] brackets a stationary point. If |phi'(b)| < 1E-10 at any trial
//      point, that b is returned immediately.
//   2. Cubic interpolation: a cubic is fitted to phi(a), phi'(a), phi(b),
//      phi'(b); its minimizer t replaces a or b depending on the sign of
//      phi'(t). Stops when |phi'(t)| <= 1E-6 or |b - a| <= 1E-6.
// phi'(t) is grad(pf) . direction, with the gradient estimated numerically
// by comput_grad.
// Parameters:
//   pf        - objective function; called with a pointer to n doubles
//   n         - number of variables (returns 0 when n <= 0)
//   start     - starting point (read only)
//   direction - search direction (read only)
// Returns: the step length t (the optimal step).
// NOTE(review): assumes direction is a descent direction (phi'(0) < 0). If it
// is not, phase 1 can stop with phi'(a) and phi'(b) of the same sign, and the
// cubic step may then be negative or NaN (sqrt of a negative number). Phase 1
// also has no step limit if phi decreases without bound along direction.
double line_search(
    double (*pf)(double *x),
    int n,
    double *start,
    double *direction)
{
    if (n <= 0)
        return 0.0;  // no variables: there is nothing to search along

    // Scratch buffers; std::vector frees them on every return path.
    std::vector<double> grad(n);
    std::vector<double> point(n);

    // Phase 1: bracket a stationary point of phi in [a, b].
    double step = 0.001;
    double a = 0;
    double diver_a = line_search_slope(pf, n, start, direction, &grad[0]);
    double b;
    double diver_b;
    do
    {
        b = a + step;
        line_search_point(n, start, direction, b, &point[0]);
        diver_b = line_search_slope(pf, n, &point[0], direction, &grad[0]);
        if (fabs(diver_b) < 1E-10)
            return b;  // b is already (numerically) stationary
        if (diver_b < -1E-15)
        {
            // phi is still decreasing at b: move a forward, double the step.
            a = b;
            diver_a = diver_b;
            step = 2 * step;
        }
    } while (diver_b <= 1E-15);

    // Phase 2: cubic interpolation inside the bracket [a, b].
    line_search_point(n, start, direction, a, &point[0]);
    double value_a = (*pf)(&point[0]);
    line_search_point(n, start, direction, b, &point[0]);
    double value_b = (*pf)(&point[0]);
    double t;
    double diver_t;
    do
    {
        const double s = 3 * (value_b - value_a) / (b - a);
        const double z = s - diver_a - diver_b;
        const double w = sqrt(z * z - diver_a * diver_b);
        t = a + (w - z - diver_a) * (b - a) / (diver_b - diver_a + 2 * w);

        // value_a / value_b must always equal phi(a) / phi(b); they are only
        // updated together with the end point they belong to.
        line_search_point(n, start, direction, t, &point[0]);
        const double value_t = (*pf)(&point[0]);
        diver_t = line_search_slope(pf, n, &point[0], direction, &grad[0]);
        if (diver_t > 1E-6)
        {
            b = t;
            value_b = value_t;
            diver_b = diver_t;
        }
        else if (diver_t < -1E-6)
        {
            a = t;
            value_a = value_t;
            diver_a = diver_t;
        }
        else
            break;
    } while ((fabs(diver_t) >= 1E-6) && (fabs(b - a) > 1E-6));
    return t;
}
///
double fun(double *x)
{
return x[0]*x[0]+2*x[1]*x[1]+3*x[2]*x[2]+4*x[3]*x[3];
}
void main()
{
double x[4]={1,1,1,1};
double p[4]={-2,-4,-6,-8};
double landa=line_search(fun,4,x,p);
cout<<landa;
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -