LDA++
LDABuilder.hpp
1 #ifndef _LDAPLUSPLUS_LDABUILDER_HPP_
2 #define _LDAPLUSPLUS_LDABUILDER_HPP_
3 
#include <memory>
#include <stdexcept>
#include <thread>
#include <utility>
#include <vector>

#include <Eigen/Core>

#include "ldaplusplus/Document.hpp"
#include "ldaplusplus/em/FastSupervisedEStep.hpp"
#include "ldaplusplus/em/EStepInterface.hpp"
#include "ldaplusplus/em/MStepInterface.hpp"
#include "ldaplusplus/LDA.hpp"
16 
17 namespace ldaplusplus {
18 
19 
24 template <typename Scalar = double>
26 {
27  public:
28  virtual operator LDA<Scalar>() const = 0;
29 
30  virtual ~LDABuilderInterface(){};
31 };
32 
33 
61 template <typename Scalar = double>
62 class LDABuilder : public LDABuilderInterface<Scalar>
63 {
64  typedef Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic> MatrixX;
65  typedef Eigen::Matrix<Scalar, Eigen::Dynamic, 1> VectorX;
66 
67  public:
78  LDABuilder();
79 
81  LDABuilder & set_iterations(size_t iterations);
82 
84  LDABuilder & set_workers(size_t workers);
85 
103  std::shared_ptr<em::EStepInterface<Scalar> > get_classic_e_step(
104  size_t e_step_iterations = 10,
105  Scalar e_step_tolerance = 1e-2,
106  Scalar compute_likelihood = 1.0,
107  int random_state = 0
108  );
113  size_t e_step_iterations = 10,
114  Scalar e_step_tolerance = 1e-2,
115  Scalar compute_likelihood = 1.0,
116  int random_state = 0
117  ) {
118  e_requires_eta_ = false;
119  return set_e(get_classic_e_step(
120  e_step_iterations,
121  e_step_tolerance,
122  compute_likelihood,
123  random_state
124  ));
125  }
126 
141  std::shared_ptr<em::EStepInterface<Scalar> > get_supervised_e_step(
142  size_t e_step_iterations = 10,
143  Scalar e_step_tolerance = 1e-2,
144  size_t fixed_point_iterations = 10,
145  Scalar compute_likelihood = 1.0,
146  int random_state = 0
147  );
152  size_t e_step_iterations = 10,
153  Scalar e_step_tolerance = 1e-2,
154  size_t fixed_point_iterations = 10,
155  Scalar compute_likelihood = 1.0,
156  int random_state = 0
157  ) {
158  set_e(get_supervised_e_step(
159  e_step_iterations,
160  e_step_tolerance,
161  fixed_point_iterations,
162  compute_likelihood,
163  random_state
164  ));
165  e_requires_eta_ = true;
166  return *this;
167  }
168 
188  std::shared_ptr<em::EStepInterface<Scalar> > get_fast_supervised_e_step(
189  size_t e_step_iterations = 10,
190  Scalar e_step_tolerance = 1e-2,
191  Scalar C = 1,
192  Scalar compute_likelihood = 1.0,
193  int random_state = 0
194  );
199  size_t e_step_iterations = 10,
200  Scalar e_step_tolerance = 1e-2,
201  Scalar C = 1,
202  Scalar compute_likelihood = 1.0,
203  int random_state = 0
204  ) {
205  set_e(get_fast_supervised_e_step(
206  e_step_iterations,
207  e_step_tolerance,
208  C,
209  compute_likelihood,
210  random_state
211  ));
212  e_requires_eta_ = true;
213  return *this;
214  }
215 
231  std::shared_ptr<em::EStepInterface<Scalar> > get_semi_supervised_e_step(
232  std::shared_ptr<em::EStepInterface<Scalar> > supervised_step = nullptr,
233  std::shared_ptr<em::EStepInterface<Scalar> > unsupervised_step = nullptr
234  );
239  std::shared_ptr<em::EStepInterface<Scalar> > supervised_step = nullptr,
240  std::shared_ptr<em::EStepInterface<Scalar> > unsupervised_step = nullptr
241  ) {
242  set_e(get_semi_supervised_e_step(
243  supervised_step,
244  unsupervised_step
245  ));
246  e_requires_eta_ = true;
247  return *this;
248  }
249 
273  std::shared_ptr<em::EStepInterface<Scalar> > get_multinomial_supervised_e_step(
274  size_t e_step_iterations = 10,
275  Scalar e_step_tolerance = 1e-2,
276  Scalar mu = 2,
277  Scalar eta_weight = 1,
278  Scalar compute_likelihood = 1.0,
279  int random_state = 0
280  );
285  size_t e_step_iterations = 10,
286  Scalar e_step_tolerance = 1e-2,
287  Scalar mu = 2,
288  Scalar eta_weight = 1,
289  Scalar compute_likelihood = 1.0,
290  int random_state = 0
291  ) {
292  set_e(get_multinomial_supervised_e_step(
293  e_step_iterations,
294  e_step_tolerance,
295  mu,
296  eta_weight,
297  compute_likelihood,
298  random_state
299  ));
300  e_requires_eta_ = true;
301  return *this;
302  }
303 
318  std::shared_ptr<em::EStepInterface<Scalar> > get_correspondence_supervised_e_step(
319  size_t e_step_iterations = 10,
320  Scalar e_step_tolerance = 1e-2,
321  Scalar mu = 2,
322  Scalar compute_likelihood = 1.0,
323  int random_state = 0
324  );
329  size_t e_step_iterations = 10,
330  Scalar e_step_tolerance = 1e-2,
331  Scalar mu = 2,
332  Scalar compute_likelihood = 1.0,
333  int random_state = 0
334  ) {
335  set_e(get_correspondence_supervised_e_step(
336  e_step_iterations,
337  e_step_tolerance,
338  mu,
339  compute_likelihood,
340  random_state
341  ));
342  e_requires_eta_ = true;
343  return *this;
344  }
345 
355  LDABuilder & set_e(std::shared_ptr<em::EStepInterface<Scalar> > e_step) {
356  e_requires_eta_ = false; // clear require eta because we do not know
357  // this e_step
358  e_step_ = e_step;
359  return *this;
360  }
361 
365  std::shared_ptr<em::MStepInterface<Scalar> > get_classic_m_step();
370  set_m(get_classic_m_step());
371  m_requires_eta_ = false;
372  return *this;
373  }
374 
388  std::shared_ptr<em::MStepInterface<Scalar> > get_fast_supervised_m_step(
389  size_t m_step_iterations = 10,
390  Scalar m_step_tolerance = 1e-2,
391  Scalar regularization_penalty = 1e-2
392  );
397  size_t m_step_iterations = 10,
398  Scalar m_step_tolerance = 1e-2,
399  Scalar regularization_penalty = 1e-2
400  ) {
401  set_m(get_fast_supervised_m_step(
402  m_step_iterations,
403  m_step_tolerance,
404  regularization_penalty
405  ));
406  m_requires_eta_ = true;
407  return *this;
408  }
409 
423  std::shared_ptr<em::MStepInterface<Scalar> > get_supervised_m_step(
424  size_t m_step_iterations = 10,
425  Scalar m_step_tolerance = 1e-2,
426  Scalar regularization_penalty = 1e-2
427  );
432  size_t m_step_iterations = 10,
433  Scalar m_step_tolerance = 1e-2,
434  Scalar regularization_penalty = 1e-2
435  ) {
436  set_m(get_supervised_m_step(
437  m_step_iterations,
438  m_step_tolerance,
439  regularization_penalty
440  ));
441  m_requires_eta_ = true;
442  return *this;
443  }
444 
463  std::shared_ptr<em::MStepInterface<Scalar> > get_fast_supervised_online_m_step(
464  size_t num_classes,
465  Scalar regularization_penalty = 1e-2,
466  size_t minibatch_size = 128,
467  Scalar eta_momentum = 0.9,
468  Scalar eta_learning_rate = 0.01,
469  Scalar beta_weight = 0.9
470  );
471  LDABuilder & set_fast_supervised_online_m_step(
472  size_t num_classes,
473  Scalar regularization_penalty = 1e-2,
474  size_t minibatch_size = 128,
475  Scalar eta_momentum = 0.9,
476  Scalar eta_learning_rate = 0.01,
477  Scalar beta_weight = 0.9
478  ) {
479  set_m(get_fast_supervised_online_m_step(
480  num_classes,
481  regularization_penalty,
482  minibatch_size,
483  eta_momentum,
484  eta_learning_rate,
485  beta_weight
486  ));
487  m_requires_eta_ = true;
488  return *this;
489  }
490 
510  std::shared_ptr<em::MStepInterface<Scalar> > get_fast_supervised_online_m_step(
511  std::vector<Scalar> class_weights,
512  Scalar regularization_penalty = 1e-2,
513  size_t minibatch_size = 128,
514  Scalar eta_momentum = 0.9,
515  Scalar eta_learning_rate = 0.01,
516  Scalar beta_weight = 0.9
517  );
518  LDABuilder & set_fast_supervised_online_m_step(
519  std::vector<Scalar> class_weights,
520  Scalar regularization_penalty = 1e-2,
521  size_t minibatch_size = 128,
522  Scalar eta_momentum = 0.9,
523  Scalar eta_learning_rate = 0.01,
524  Scalar beta_weight = 0.9
525  ) {
526  set_m(get_fast_supervised_online_m_step(
527  class_weights,
528  regularization_penalty,
529  minibatch_size,
530  eta_momentum,
531  eta_learning_rate,
532  beta_weight
533  ));
534  m_requires_eta_ = true;
535  return *this;
536  }
537 
557  std::shared_ptr<em::MStepInterface<Scalar> > get_fast_supervised_online_m_step(
558  Eigen::Matrix<Scalar, Eigen::Dynamic, 1> class_weights,
559  Scalar regularization_penalty = 1e-2,
560  size_t minibatch_size = 128,
561  Scalar eta_momentum = 0.9,
562  Scalar eta_learning_rate = 0.01,
563  Scalar beta_weight = 0.9
564  );
565  LDABuilder & set_fast_supervised_online_m_step(
566  Eigen::Matrix<Scalar, Eigen::Dynamic, 1> class_weights,
567  Scalar regularization_penalty = 1e-2,
568  size_t minibatch_size = 128,
569  Scalar eta_momentum = 0.9,
570  Scalar eta_learning_rate = 0.01,
571  Scalar beta_weight = 0.9
572  ) {
573  set_m(get_fast_supervised_online_m_step(
574  class_weights,
575  regularization_penalty,
576  minibatch_size,
577  eta_momentum,
578  eta_learning_rate,
579  beta_weight
580  ));
581  m_requires_eta_ = true;
582  return *this;
583  }
584 
598  std::shared_ptr<em::MStepInterface<Scalar> > get_semi_supervised_m_step(
599  size_t m_step_iterations = 10,
600  Scalar m_step_tolerance = 1e-2,
601  Scalar regularization_penalty = 1e-2
602  );
607  size_t m_step_iterations = 10,
608  Scalar m_step_tolerance = 1e-2,
609  Scalar regularization_penalty = 1e-2
610  ) {
611  set_m(get_semi_supervised_m_step(
612  m_step_iterations,
613  m_step_tolerance,
614  regularization_penalty
615  ));
616  m_requires_eta_ = true;
617  return *this;
618  }
619 
628  std::shared_ptr<em::MStepInterface<Scalar> > get_multinomial_supervised_m_step(
629  Scalar mu = 2.
630  );
635  Scalar mu = 2.
636  ) {
637  set_m(get_multinomial_supervised_m_step(mu));
638  m_requires_eta_ = true;
639  return *this;
640  }
641 
650  std::shared_ptr<em::MStepInterface<Scalar> > get_correspondence_supervised_m_step(
651  Scalar mu = 2.
652  );
657  Scalar mu = 2.
658  ) {
659  set_m(get_correspondence_supervised_m_step(mu));
660  m_requires_eta_ = true;
661  return *this;
662  }
663 
669  LDABuilder & set_m(std::shared_ptr<em::MStepInterface<Scalar> > m_step) {
670  m_requires_eta_ = false; // clear require eta because we do not know
671  // this m_step
672  m_step_ = m_step;
673  return *this;
674  }
675 
691  LDABuilder & initialize_topics_seeded(
692  const Eigen::MatrixXi &X,
693  size_t topics,
694  size_t N = 30,
695  int random_state = 0
696  );
697 
713  LDABuilder & initialize_topics_seeded(
714  std::shared_ptr<corpus::Corpus> corpus,
715  size_t topics,
716  size_t N = 30,
717  int random_state = 0
718  );
719 
730  LDABuilder & initialize_topics_random(
731  size_t words,
732  size_t topics,
733  int random_state = 0
734  );
735 
743  std::shared_ptr<parameters::ModelParameters<Scalar> > model
744  ) {
745  model_parameters_->alpha = model->alpha;
746  model_parameters_->beta = model->beta;
747 
748  return *this;
749  }
750 
757  LDABuilder & initialize_eta_zeros(size_t num_classes);
758 
765  LDABuilder & initialize_eta_uniform(size_t num_classes);
766 
774  std::shared_ptr<parameters::SupervisedModelParameters<Scalar> > model
775  ) {
776  model_parameters_->eta = model->eta;
777 
778  return *this;
779  }
780 
788  virtual operator LDA<Scalar>() const override {
789  if (model_parameters_->beta.rows() == 0) {
790  throw std::runtime_error("You need to call initialize_topics before "
791  "creating an LDA from the builder.");
792  }
793 
794  if (
795  model_parameters_->eta.rows() == 0 &&
796  (e_requires_eta_ || m_requires_eta_)
797  ) {
798  throw std::runtime_error("An E step or M step seems to be supervised "
799  "yet you have not initialized eta. "
800  "Call initialize_eta_*()");
801  }
802 
803  return LDA<Scalar>(
804  model_parameters_,
805  e_step_,
806  m_step_,
807  iterations_,
808  workers_
809  );
810  };
811 
812  private:
813  // generic lda parameters
814  size_t iterations_;
815  size_t workers_;
816 
817  // implementations
818  std::shared_ptr<em::EStepInterface<Scalar> > e_step_;
819  std::shared_ptr<em::MStepInterface<Scalar> > m_step_;
820 
821  // the model parameters
822  std::shared_ptr<parameters::SupervisedModelParameters<Scalar> > model_parameters_;
823 
824  // A flag to keep track of having set EM steps that require the eta
825  // model parameters.
826  bool e_requires_eta_;
827  bool m_requires_eta_;
828 };
829 
830 
831 } // namespace ldaplusplus
832 #endif // _LDAPLUSPLUS_LDABUILDER_HPP_
Definition: MStepInterface.hpp:24
LDABuilder & set_fast_supervised_e_step(size_t e_step_iterations=10, Scalar e_step_tolerance=1e-2, Scalar C=1, Scalar compute_likelihood=1.0, int random_state=0)
Definition: LDABuilder.hpp:198
LDABuilder & set_fast_supervised_m_step(size_t m_step_iterations=10, Scalar m_step_tolerance=1e-2, Scalar regularization_penalty=1e-2)
Definition: LDABuilder.hpp:396
LDABuilder & initialize_eta_from_model(std::shared_ptr< parameters::SupervisedModelParameters< Scalar > > model)
Definition: LDABuilder.hpp:773
LDABuilder & initialize_topics_from_model(std::shared_ptr< parameters::ModelParameters< Scalar > > model)
Definition: LDABuilder.hpp:742
LDABuilder & set_classic_m_step()
Definition: LDABuilder.hpp:369
LDABuilder & set_supervised_m_step(size_t m_step_iterations=10, Scalar m_step_tolerance=1e-2, Scalar regularization_penalty=1e-2)
Definition: LDABuilder.hpp:431
LDABuilder & set_e(std::shared_ptr< em::EStepInterface< Scalar > > e_step)
Definition: LDABuilder.hpp:355
LDABuilder & set_correspondence_supervised_m_step(Scalar mu=2.)
Definition: LDABuilder.hpp:656
Definition: EStepInterface.hpp:24
LDABuilder & set_m(std::shared_ptr< em::MStepInterface< Scalar > > m_step)
Definition: LDABuilder.hpp:669
Definition: LDABuilder.hpp:62
LDABuilder & set_supervised_e_step(size_t e_step_iterations=10, Scalar e_step_tolerance=1e-2, size_t fixed_point_iterations=10, Scalar compute_likelihood=1.0, int random_state=0)
Definition: LDABuilder.hpp:151
LDABuilder & set_multinomial_supervised_m_step(Scalar mu=2.)
Definition: LDABuilder.hpp:634
Definition: ProgressEvents.hpp:11
LDABuilder & set_classic_e_step(size_t e_step_iterations=10, Scalar e_step_tolerance=1e-2, Scalar compute_likelihood=1.0, int random_state=0)
Definition: LDABuilder.hpp:112
Definition: LDABuilder.hpp:25
LDABuilder & set_multinomial_supervised_e_step(size_t e_step_iterations=10, Scalar e_step_tolerance=1e-2, Scalar mu=2, Scalar eta_weight=1, Scalar compute_likelihood=1.0, int random_state=0)
Definition: LDABuilder.hpp:284
Definition: Parameters.hpp:27
LDABuilder & set_semi_supervised_e_step(std::shared_ptr< em::EStepInterface< Scalar > > supervised_step=nullptr, std::shared_ptr< em::EStepInterface< Scalar > > unsupervised_step=nullptr)
Definition: LDABuilder.hpp:238
LDABuilder & set_correspondence_supervised_e_step(size_t e_step_iterations=10, Scalar e_step_tolerance=1e-2, Scalar mu=2, Scalar compute_likelihood=1.0, int random_state=0)
Definition: LDABuilder.hpp:328
LDABuilder & set_semi_supervised_m_step(size_t m_step_iterations=10, Scalar m_step_tolerance=1e-2, Scalar regularization_penalty=1e-2)
Definition: LDABuilder.hpp:606
Definition: Document.hpp:11