LDA++
FastSupervisedMStep.hpp
1 #ifndef _LDAPLUSPLUS_EM_FASTSUPERVISEDMSTEP_HPP_
2 #define _LDAPLUSPLUS_EM_FASTSUPERVISEDMSTEP_HPP_
3 
4 #include "ldaplusplus/em/UnsupervisedMStep.hpp"
5 
6 namespace ldaplusplus {
7 namespace em {
8 
9 
44 template <typename Scalar>
45 class FastSupervisedMStep : public UnsupervisedMStep<Scalar>
46 {
47  typedef Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic> MatrixX;
48  typedef Eigen::Matrix<Scalar, Eigen::Dynamic, 1> VectorX;
49 
50  public:
59  size_t m_step_iterations = 10,
60  Scalar m_step_tolerance = 1e-2,
61  Scalar regularization_penalty = 1e-2
62  ) : m_step_iterations_(m_step_iterations),
63  m_step_tolerance_(m_step_tolerance),
64  regularization_penalty_(regularization_penalty),
65  docs_(0)
66  {}
67 
76  virtual void m_step(
77  std::shared_ptr<parameters::Parameters> parameters
78  ) override;
79 
91  virtual void doc_m_step(
92  const std::shared_ptr<corpus::Document> doc,
93  const std::shared_ptr<parameters::Parameters> v_parameters,
94  std::shared_ptr<parameters::Parameters> m_parameters
95  ) override;
96 
97  private:
98  // The maximum number of iterations in M-step
99  size_t m_step_iterations_;
100  // The convergence tolerance for the maximization of the ELBO w.r.t.
101  // eta in M-step
102  Scalar m_step_tolerance_;
103  // The regularization penalty for the multinomial logistic regression
104  Scalar regularization_penalty_;
105 
106  // Number of documents processed so far
107  int docs_;
108  MatrixX expected_z_bar_;
109  Eigen::VectorXi y_;
110 };
111 
112 } // namespace em
113 } // namespace ldaplusplus
114 
115 #endif // _LDAPLUSPLUS_EM_FASTSUPERVISEDMSTEP_HPP_
Definition: UnsupervisedMStep.hpp:42
virtual void doc_m_step(const std::shared_ptr< corpus::Document > doc, const std::shared_ptr< parameters::Parameters > v_parameters, std::shared_ptr< parameters::Parameters > m_parameters) override
Definition: FastSupervisedMStep.cpp:15
FastSupervisedMStep(size_t m_step_iterations=10, Scalar m_step_tolerance=1e-2, Scalar regularization_penalty=1e-2)
Definition: FastSupervisedMStep.hpp:58
virtual void m_step(std::shared_ptr< parameters::Parameters > parameters) override
Definition: FastSupervisedMStep.cpp:56
Definition: Document.hpp:11
Definition: FastSupervisedMStep.hpp:45