LDA++
FastOnlineSupervisedMStep.hpp
1 #ifndef _LDAPLUSPLUS_EM_FASTONLINESUPERVISEDMSTEP_HPP_
2 #define _LDAPLUSPLUS_EM_FASTONLINESUPERVISEDMSTEP_HPP_
3 
4 #include "ldaplusplus/em/MStepInterface.hpp"
5 
6 namespace ldaplusplus {
7 namespace em {
8 
9 
25 template <typename Scalar>
27 {
28  typedef Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic> MatrixX;
29  typedef Eigen::Matrix<Scalar, Eigen::Dynamic, 1> VectorX;
30 
31  public:
50  VectorX class_weights,
51  Scalar regularization_penalty = 1e-2,
52  size_t minibatch_size = 128,
53  Scalar eta_momentum = 0.9,
54  Scalar eta_learning_rate = 0.01,
55  Scalar beta_weight = 0.9
56  );
74  size_t num_classes,
75  Scalar regularization_penalty = 1e-2,
76  size_t minibatch_size = 128,
77  Scalar eta_momentum = 0.9,
78  Scalar eta_learning_rate = 0.01,
79  Scalar beta_weight = 0.9
80  );
81 
85  virtual void m_step(
86  std::shared_ptr<parameters::Parameters> parameters
87  ) override;
88 
100  virtual void doc_m_step(
101  const std::shared_ptr<corpus::Document> doc,
102  const std::shared_ptr<parameters::Parameters> v_parameters,
103  std::shared_ptr<parameters::Parameters> m_parameters
104  ) override;
105 
106  private:
107  // Number of classes
108  VectorX class_weights_;
109  size_t num_classes_;
110 
111  // Minibatch size and portion (the portion of the corpus)
112  size_t minibatch_size_;
113 
114  // The regularization penalty for the multinomial logistic regression
115  // Mind that it should account for the minibatch size
116  Scalar regularization_penalty_;
117 
118  // The suff stats and data needed to optimize the ELBO w.r.t. model
119  // parameters
120  MatrixX b_;
121  Scalar beta_weight_;
122  MatrixX expected_z_bar_;
123  Eigen::VectorXi y_;
124  MatrixX eta_velocity_;
125  MatrixX eta_gradient_;
126  Scalar eta_momentum_;
127  Scalar eta_learning_rate_;
128 
129  // The number of document's seen so far
130  size_t docs_seen_so_far_;
131 };
132 
133 } // namespace em
134 } // namespace ldaplusplus
135 
136 #endif // _LDAPLUSPLUS_EM_FASTONLINESUPERVISEDMSTEP_HPP_
Definition: MStepInterface.hpp:24
FastOnlineSupervisedMStep(VectorX class_weights, Scalar regularization_penalty=1e-2, size_t minibatch_size=128, Scalar eta_momentum=0.9, Scalar eta_learning_rate=0.01, Scalar beta_weight=0.9)
Definition: FastOnlineSupervisedMStep.cpp:12
virtual void m_step(std::shared_ptr< parameters::Parameters > parameters) override
Definition: FastOnlineSupervisedMStep.cpp:89
virtual void doc_m_step(const std::shared_ptr< corpus::Document > doc, const std::shared_ptr< parameters::Parameters > v_parameters, std::shared_ptr< parameters::Parameters > m_parameters) override
Definition: FastOnlineSupervisedMStep.cpp:48
Definition: FastOnlineSupervisedMStep.hpp:26
Definition: Document.hpp:11