LDA++
SupervisedMStep.hpp
1 #ifndef _LDAPLUSPLUS_EM_SUPERVISEDMSTEP_HPP_
2 #define _LDAPLUSPLUS_EM_SUPERVISEDMSTEP_HPP_
3 
4 #include <vector>
5 
6 #include "ldaplusplus/em/UnsupervisedMStep.hpp"
7 
8 namespace ldaplusplus {
9 namespace em {
10 
11 
39 template <typename Scalar>
40 class SupervisedMStep : public UnsupervisedMStep<Scalar>
41 {
42  typedef Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic> MatrixX;
43  typedef Eigen::Matrix<Scalar, Eigen::Dynamic, 1> VectorX;
44 
45  public:
54  size_t m_step_iterations = 10,
55  Scalar m_step_tolerance = 1e-2,
56  Scalar regularization_penalty = 1e-2
57  ) : m_step_iterations_(m_step_iterations),
58  m_step_tolerance_(m_step_tolerance),
59  regularization_penalty_(regularization_penalty),
60  docs_(0)
61  {}
62 
71  virtual void m_step(
72  std::shared_ptr<parameters::Parameters> parameters
73  ) override;
74 
86  virtual void doc_m_step(
87  const std::shared_ptr<corpus::Document> doc,
88  const std::shared_ptr<parameters::Parameters> v_parameters,
89  std::shared_ptr<parameters::Parameters> m_parameters
90  ) override;
91 
92  private:
93  // The maximum number of iterations in M-step
94  size_t m_step_iterations_;
95  // The convergence tolerance for the maximazation of the ELBO w.r.t.
96  // eta in M-step
97  Scalar m_step_tolerance_;
98  // The regularization penalty for the multinomial logistic regression
99  Scalar regularization_penalty_;
100 
101  // Number of documents processed so far
102  int docs_;
103  MatrixX phi_scaled;
104  MatrixX expected_z_bar_;
105  std::vector<MatrixX> variance_z_bar_;
106  Eigen::VectorXi y_;
107 };
108 
109 } // namespace em
110 } // namespace ldaplusplus
111 
112 #endif // _LDAPLUSPLUS_EM_SUPERVISEDMSTEP_HPP_
Definition: UnsupervisedMStep.hpp:42
virtual void m_step(std::shared_ptr< parameters::Parameters > parameters) override
Definition: SupervisedMStep.cpp:84
virtual void doc_m_step(const std::shared_ptr< corpus::Document > doc, const std::shared_ptr< parameters::Parameters > v_parameters, std::shared_ptr< parameters::Parameters > m_parameters) override
Definition: SupervisedMStep.cpp:16
SupervisedMStep(size_t m_step_iterations=10, Scalar m_step_tolerance=1e-2, Scalar regularization_penalty=1e-2)
Definition: SupervisedMStep.hpp:53
Definition: SupervisedMStep.hpp:40
Definition: Document.hpp:11