LDA++
GradientDescent.hpp
#ifndef _LDAPLUSPLUS_OPTIMIZATION_GRADIENT_DESCENT_HPP_
#define _LDAPLUSPLUS_OPTIMIZATION_GRADIENT_DESCENT_HPP_

#include <cstddef>
#include <functional>
#include <memory>

#include <Eigen/Core>

namespace ldaplusplus {
namespace optimization {


/**
 * Interface for line searches. Implementations choose a step length along
 * `direction`, update `x0` in place and return the new objective value.
 */
template <typename ProblemType, typename ParameterType>
class LineSearch
{
    public:
        typedef typename ParameterType::Scalar Scalar;

        virtual Scalar search(
            const ProblemType &problem,
            Eigen::Ref<ParameterType> x0,
            const ParameterType &grad_x0,
            const ParameterType &direction
        ) = 0;

        virtual ~LineSearch() {}
};


/**
 * Line search that always takes a step of fixed length `alpha` along the
 * given direction.
 */
template <typename ProblemType, typename ParameterType>
class ConstantLineSearch : public LineSearch<ProblemType, ParameterType>
{
    public:
        typedef typename ParameterType::Scalar Scalar;

        ConstantLineSearch(Scalar alpha) : alpha_(alpha) {}

        Scalar search(
            const ProblemType &problem,
            Eigen::Ref<ParameterType> x0,
            const ParameterType &grad_x0,
            const ParameterType &direction
        ) override {
            x0 -= alpha_ * direction;

            return problem.value(x0);
        }

    private:
        Scalar alpha_;
};


/**
 * Backtracking line search that shrinks the step by a factor `tau` until the
 * Armijo sufficient-decrease condition with parameter `beta` holds.
 */
template <typename ProblemType, typename ParameterType>
class ArmijoLineSearch : public LineSearch<ProblemType, ParameterType>
{
    public:
        typedef typename ParameterType::Scalar Scalar;

        ArmijoLineSearch(Scalar beta = 0.001, Scalar tau = 0.5)
            : beta_(beta), tau_(tau)
        {}

        Scalar search(
            const ProblemType &problem,
            Eigen::Ref<ParameterType> x0,
            const ParameterType &grad_x0,
            const ParameterType &direction
        ) override {
            ParameterType x_copy(x0.rows(), x0.cols());
            Scalar value_x0 = problem.value(x0);
            Scalar decrease = beta_ * (grad_x0.array() * direction.array()).sum();
            Scalar value = value_x0;
            Scalar a = 1.0 / tau_;

            // Try step lengths a = 1, tau, tau^2, ... until
            // f(x0 - a*direction) <= f(x0) - a * beta * <grad_x0, direction>
            while (value > value_x0 - a * decrease) {
                a *= tau_;
                x_copy = x0 - a * direction;
                value = problem.value(x_copy);
            }

            x0 -= a * direction;

            return value;
        }

    private:
        Scalar beta_;
        Scalar tau_;
};


/**
 * Plain gradient descent. Each iteration computes the gradient at the current
 * point and delegates the actual step to a LineSearch; the progress callback
 * decides when to stop.
 */
template <typename ProblemType, typename ParameterType>
class GradientDescent
{
    public:
        typedef typename ParameterType::Scalar Scalar;

        GradientDescent(
            std::shared_ptr<LineSearch<ProblemType, ParameterType> > line_search,
            std::function<bool(Scalar, Scalar, size_t)> progress
        ) : line_search_(line_search),
            progress_(progress)
        {}

        void minimize(const ProblemType &problem, Eigen::Ref<ParameterType> x0) {
            // allocate memory for the gradient (it is only filled in inside
            // the loop, so the norm reported on the first callback is not
            // meaningful yet)
            ParameterType grad(x0.rows(), x0.cols());

            // Keep the value in this variable
            Scalar value = problem.value(x0);

            // And the iterations in this one
            size_t iterations = 0;

            // Whether we stop or not is decided by the progress callback
            while (progress_(value, grad.template lpNorm<Eigen::Infinity>(), iterations++)) {
                problem.gradient(x0, grad);
                value = line_search_->search(problem, x0, grad, grad);
            }
        }

    private:
        std::shared_ptr<LineSearch<ProblemType, ParameterType> > line_search_;
        std::function<bool(Scalar, Scalar, size_t)> progress_;
};


}  // namespace optimization
}  // namespace ldaplusplus

#endif  // _LDAPLUSPLUS_OPTIMIZATION_GRADIENT_DESCENT_HPP_
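For reference, the backtracking loop in ArmijoLineSearch::search tries step lengths a = 1, tau, tau^2, ... and accepts the first one satisfying the sufficient-decrease (Armijo) condition, written here in the header's notation:

    f(x_0 - a\,d) \;\le\; f(x_0) - a\,\beta\,\nabla f(x_0)^{\top} d, \qquad a \in \{1, \tau, \tau^2, \dots\}

where d is the search direction (the gradient itself when called from GradientDescent::minimize) and beta, tau are the constructor parameters.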
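To see how the pieces fit together, here is a minimal, hypothetical usage sketch. The QuadraticProblem type, the include path and every name in main() are assumptions made for illustration; the header itself only requires that ProblemType expose value() and gradient() with the signatures called above, and Eigen::VectorXd is used here as the ParameterType.

// Hypothetical example of driving GradientDescent with ArmijoLineSearch.
#include <cstddef>
#include <iostream>
#include <memory>

#include <Eigen/Core>

#include "ldaplusplus/optimization/GradientDescent.hpp"  // assumed include path

using namespace ldaplusplus::optimization;

// Toy problem: f(x) = 0.5 * ||x - target||^2. Only the two methods that the
// optimizer calls are needed: value() and gradient().
struct QuadraticProblem {
    Eigen::VectorXd target;

    double value(const Eigen::Ref<const Eigen::VectorXd> &x) const {
        return 0.5 * (x - target).squaredNorm();
    }

    void gradient(const Eigen::Ref<const Eigen::VectorXd> &x,
                  Eigen::Ref<Eigen::VectorXd> grad) const {
        grad = x - target;
    }
};

int main() {
    QuadraticProblem problem{Eigen::VectorXd::Constant(3, 1.0)};

    // Backtracking line search with the header's default beta and tau.
    auto line_search = std::make_shared<
        ArmijoLineSearch<QuadraticProblem, Eigen::VectorXd> >();

    // The progress callback both reports and decides when to stop. The
    // gradient norm is only meaningful from the second call onwards, since
    // minimize() computes the gradient after invoking the callback.
    GradientDescent<QuadraticProblem, Eigen::VectorXd> optimizer(
        line_search,
        [](double value, double grad_norm, std::size_t iterations) {
            std::cout << "iteration " << iterations
                      << "  f = " << value << std::endl;
            return iterations < 100 && (iterations == 0 || grad_norm > 1e-6);
        }
    );

    Eigen::VectorXd x0 = Eigen::VectorXd::Zero(3);
    optimizer.minimize(problem, x0);

    std::cout << "x = " << x0.transpose() << std::endl;

    return 0;
}

Note the design choice reflected in the callback: rather than hard-coding a stopping rule, minimize() hands the current value, the infinity norm of the last computed gradient and the iteration count to the caller, who returns false to stop.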