LDA++
FastSupervisedEStep.hpp
1 #ifndef _LDAPLUSPLUS_EM_FASTSUPERVISEDESTEP_HPP_
2 #define _LDAPLUSPLUS_EM_FASTSUPERVISEDESTEP_HPP_
3 
4 #include "ldaplusplus/em/AbstractEStep.hpp"
5 
6 namespace ldaplusplus {
7 namespace em {
8 
9 
// Expectation-step class templated on Scalar. It derives from
// AbstractEStep<Scalar> and overrides doc_e_step() and e_step().
//
// NOTE(review): this listing carries embedded line numbers and the numbering
// jumps (e.g. 24 -> 30 and 42 -> 62), so some original lines -- including the
// enum header and the first line of the constructor declaration -- are not
// visible here. Confirm against the real header before editing code.
18 template<typename Scalar>
19 class FastSupervisedEStep : public AbstractEStep<Scalar>
20 {
// Shorthands for a dynamically sized Eigen matrix and column vector of Scalar.
21  typedef Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic> MatrixX;
22  typedef Eigen::Matrix<Scalar, Eigen::Dynamic, 1> VectorX;
23 
24  public:
// Enumerator list of what is presumably the CWeightType enum referenced
// below; its header line and any further enumerators are not visible in this
// listing -- TODO confirm.
30  {
35  Constant = 1,
41  };
42 
// Parameter list (with default values) of what is presumably the constructor;
// the declaration's first line is not visible in this listing -- TODO confirm.
62  size_t e_step_iterations = 10,
63  Scalar e_step_tolerance = 1e-2,
64  Scalar C = 1,
65  CWeightType weight_type = CWeightType::Constant,
66  Scalar compute_likelihood = 1.0,
67  int random_state = 0
68  );
69 
// Per-document step: takes a document and a set of parameters (both passed as
// shared pointers) and returns a shared pointer to parameters. Overrides the
// base-class method.
100  std::shared_ptr<parameters::Parameters> doc_e_step(
101  const std::shared_ptr<corpus::Document> doc,
102  const std::shared_ptr<parameters::Parameters> parameters
103  ) override;
104 
// Overrides the base-class e_step(); takes no arguments and returns nothing.
110  void e_step() override;
111 
112  private:
// Returns a Scalar weight. Presumably derived from C_, weight_type_ and
// epochs_ -- confirm against the implementation file.
119  Scalar get_weight();
120 
121  // The maximum number of iterations in the expectation step.
122  size_t e_step_iterations_;
123  // The convergence tolerance for the maximization of the ELBO w.r.t.
124  // \f$\phi\f$ and \f$\gamma\f$ in the expectation step.
125  Scalar e_step_tolerance_;
126  // A parameter weighting the supervised component in the variational
127  // distribution.
128  Scalar C_;
129  // A parameter that is used to indicate whether or not to compute the
130  // supervised likelihood at the end of each expectation step.
// NOTE(review): stored as a Scalar (the matching constructor default is 1.0),
// not a bool -- confirm the exact semantics of intermediate values.
131  Scalar compute_likelihood_;
132  // The method used to update parameter C between consecutive expectation
133  // steps.
134  CWeightType weight_type_;
135  // The epochs seen so far.
136  int epochs_;
137 };
138 
139 } // namespace em
140 } // namespace ldaplusplus
141 
142 #endif // _LDAPLUSPLUS_EM_FASTSUPERVISEDESTEP_HPP_
Definition: FastSupervisedEStep.hpp:19
Definition: FastSupervisedEStep.hpp:35
CWeightType
Definition: FastSupervisedEStep.hpp:29
Definition: FastSupervisedEStep.hpp:40
std::shared_ptr< parameters::Parameters > doc_e_step(const std::shared_ptr< corpus::Document > doc, const std::shared_ptr< parameters::Parameters > parameters) override
Definition: FastSupervisedEStep.cpp:30
void e_step() override
Definition: FastSupervisedEStep.cpp:102
FastSupervisedEStep(size_t e_step_iterations=10, Scalar e_step_tolerance=1e-2, Scalar C=1, CWeightType weight_type=CWeightType::Constant, Scalar compute_likelihood=1.0, int random_state=0)
Definition: FastSupervisedEStep.cpp:12
Definition: AbstractEStep.hpp:21
Definition: Document.hpp:11