⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 owlqn.h

📁 source codes for "Orthant-Wise Limited-memory Quasi-Newton Optimizer for L1-regularized Objectives"
💻 H
字号:
#pragma once

#include <cstdlib>
#include <deque>
#include <iostream>
#include <vector>

	/// Dense vector of doubles; the optimizer uses it for points, gradients and directions.
	typedef std::vector<double> DblVec;

	/// Interface for objectives that OWLQN::Minimize can optimize.
	///
	/// Contract (inferred from the signature and parameter names -- confirm):
	/// Eval returns the objective value at `input` and writes the gradient into
	/// `gradient`. NOTE(review): the optimizer allocates its gradient buffers with
	/// the same size as the starting point, so implementations presumably receive
	/// an already-sized `gradient`.
	struct DifferentiableFunction {
		virtual double Eval(const DblVec& input, DblVec& gradient) const = 0;

		// Virtual so that an implementation deleted through a
		// DifferentiableFunction* is destroyed correctly (otherwise UB).
		virtual ~DifferentiableFunction() { }
	};

class OWLQN {
	struct DblVecPtrDeque : public std::deque<DblVec*> {
		~DblVecPtrDeque() {
			for (size_t s = 0; s < size(); ++s) {
				if ((*this)[s] != NULL) delete (*this)[s];
			}
		}
	};

	bool quiet;

	struct OptimizerState {
		DblVec x, grad, newX, newGrad, dir;
		DblVec steepestDescDir; // references newGrad to save memory, since we don't ever use both at the same time
		DblVecPtrDeque sList, yList;
		std::deque<double> roList;
		std::vector<double> alphas;
		double value;
		int iter, m;
		const size_t dim;
		const DifferentiableFunction& func;
		double l1weight;
		bool quiet;

		static double dotProduct(const DblVec& a, const DblVec& b);
		static void add(DblVec& a, const DblVec& b);
		static void addMult(DblVec& a, const DblVec& b, double c);
		static void addMultInto(DblVec& a, const DblVec& b, const DblVec& c, double d);
		static void scale(DblVec& a, double b);
		static void scaleInto(DblVec& a, const DblVec& b, double c);

		void MapDirByInverseHessian();
		void UpdateDir();
		double DirDeriv() const;
		void GetNextPoint(double alpha);
		void BackTrackingLineSearch();
		void Shift();
		void MakeSteepestDescDir();
		double EvalL1();
		void FixDirSigns();
		void TestDirDeriv();

		OptimizerState(const DifferentiableFunction& f, const DblVec& init, int m, double l1weight, bool quiet) 
			: dim(init.size()), func(f), x(init), grad(init.size()), dir(init.size()), newX(init), newGrad(init.size()), m(m), iter(1), l1weight(l1weight), steepestDescDir(newGrad), alphas(m), quiet(quiet) {
				if (m <= 0) {
					std::cerr << "m must be an integer greater than zero." << std::endl;
					exit(1);
				}
				value = EvalL1();
				grad = newGrad;
		}
	};

public:
	void Minimize(const DifferentiableFunction& function, const DblVec& initial, DblVec& minimum, double l1weight = 1.0, double tol = 1e-4, int m = 10) const;

	OWLQN(bool quiet = false) : quiet(quiet) { }

	void SetQuiet(bool q) { quiet = q; }
};

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -