1 #ifndef _LDAPLUSPLUS_LDABUILDER_HPP_ 2 #define _LDAPLUSPLUS_LDABUILDER_HPP_ 11 #include "ldaplusplus/Document.hpp" 12 #include "ldaplusplus/em/FastSupervisedEStep.hpp" 13 #include "ldaplusplus/em/EStepInterface.hpp" 14 #include "ldaplusplus/em/MStepInterface.hpp" 15 #include "ldaplusplus/LDA.hpp" 24 template <
typename Scalar =
double>
61 template <
typename Scalar =
double>
/** Matrix of Scalar whose row count and column count are both Eigen::Dynamic (chosen at run time). */
64 typedef Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic> MatrixX;
/** Single-column matrix (column vector) of Scalar whose row count is Eigen::Dynamic (chosen at run time). */
65 typedef Eigen::Matrix<Scalar, Eigen::Dynamic, 1> VectorX;
/**
 * Declaration only; the definition is not part of this listing.
 *
 * @param iterations an iteration count.
 *        NOTE(review): presumably the number of iterations the built LDA
 *        runs -- confirm against the definition.
 * @return a reference to an LDABuilder.
 *         NOTE(review): presumably *this so that calls can be chained --
 *         TODO confirm.
 */
81 LDABuilder & set_iterations(
size_t iterations);
103 std::shared_ptr<em::EStepInterface<Scalar> > get_classic_e_step(
104 size_t e_step_iterations = 10,
105 Scalar e_step_tolerance = 1e-2,
106 Scalar compute_likelihood = 1.0,
113 size_t e_step_iterations = 10,
114 Scalar e_step_tolerance = 1e-2,
115 Scalar compute_likelihood = 1.0,
118 e_requires_eta_ =
false;
119 return set_e(get_classic_e_step(
141 std::shared_ptr<em::EStepInterface<Scalar> > get_supervised_e_step(
142 size_t e_step_iterations = 10,
143 Scalar e_step_tolerance = 1e-2,
144 size_t fixed_point_iterations = 10,
145 Scalar compute_likelihood = 1.0,
152 size_t e_step_iterations = 10,
153 Scalar e_step_tolerance = 1e-2,
154 size_t fixed_point_iterations = 10,
155 Scalar compute_likelihood = 1.0,
158 set_e(get_supervised_e_step(
161 fixed_point_iterations,
165 e_requires_eta_ =
true;
188 std::shared_ptr<em::EStepInterface<Scalar> > get_fast_supervised_e_step(
189 size_t e_step_iterations = 10,
190 Scalar e_step_tolerance = 1e-2,
192 Scalar compute_likelihood = 1.0,
199 size_t e_step_iterations = 10,
200 Scalar e_step_tolerance = 1e-2,
202 Scalar compute_likelihood = 1.0,
205 set_e(get_fast_supervised_e_step(
212 e_requires_eta_ =
true;
231 std::shared_ptr<em::EStepInterface<Scalar> > get_semi_supervised_e_step(
242 set_e(get_semi_supervised_e_step(
246 e_requires_eta_ =
true;
273 std::shared_ptr<em::EStepInterface<Scalar> > get_multinomial_supervised_e_step(
274 size_t e_step_iterations = 10,
275 Scalar e_step_tolerance = 1e-2,
277 Scalar eta_weight = 1,
278 Scalar compute_likelihood = 1.0,
285 size_t e_step_iterations = 10,
286 Scalar e_step_tolerance = 1e-2,
288 Scalar eta_weight = 1,
289 Scalar compute_likelihood = 1.0,
292 set_e(get_multinomial_supervised_e_step(
300 e_requires_eta_ =
true;
318 std::shared_ptr<em::EStepInterface<Scalar> > get_correspondence_supervised_e_step(
319 size_t e_step_iterations = 10,
320 Scalar e_step_tolerance = 1e-2,
322 Scalar compute_likelihood = 1.0,
329 size_t e_step_iterations = 10,
330 Scalar e_step_tolerance = 1e-2,
332 Scalar compute_likelihood = 1.0,
335 set_e(get_correspondence_supervised_e_step(
342 e_requires_eta_ =
true;
356 e_requires_eta_ =
false;
/**
 * Declaration only; the definition is not part of this listing.
 *
 * Takes no arguments.
 *
 * @return a shared pointer to an em::MStepInterface<Scalar>.
 *         NOTE(review): by its name this is the classic (unsupervised) M
 *         step implementation -- confirm against the definition.
 */
365 std::shared_ptr<em::MStepInterface<Scalar> > get_classic_m_step();
370 set_m(get_classic_m_step());
371 m_requires_eta_ =
false;
388 std::shared_ptr<em::MStepInterface<Scalar> > get_fast_supervised_m_step(
389 size_t m_step_iterations = 10,
390 Scalar m_step_tolerance = 1e-2,
391 Scalar regularization_penalty = 1e-2
397 size_t m_step_iterations = 10,
398 Scalar m_step_tolerance = 1e-2,
399 Scalar regularization_penalty = 1e-2
401 set_m(get_fast_supervised_m_step(
404 regularization_penalty
406 m_requires_eta_ =
true;
423 std::shared_ptr<em::MStepInterface<Scalar> > get_supervised_m_step(
424 size_t m_step_iterations = 10,
425 Scalar m_step_tolerance = 1e-2,
426 Scalar regularization_penalty = 1e-2
432 size_t m_step_iterations = 10,
433 Scalar m_step_tolerance = 1e-2,
434 Scalar regularization_penalty = 1e-2
436 set_m(get_supervised_m_step(
439 regularization_penalty
441 m_requires_eta_ =
true;
463 std::shared_ptr<em::MStepInterface<Scalar> > get_fast_supervised_online_m_step(
465 Scalar regularization_penalty = 1e-2,
466 size_t minibatch_size = 128,
467 Scalar eta_momentum = 0.9,
468 Scalar eta_learning_rate = 0.01,
469 Scalar beta_weight = 0.9
471 LDABuilder & set_fast_supervised_online_m_step(
473 Scalar regularization_penalty = 1e-2,
474 size_t minibatch_size = 128,
475 Scalar eta_momentum = 0.9,
476 Scalar eta_learning_rate = 0.01,
477 Scalar beta_weight = 0.9
479 set_m(get_fast_supervised_online_m_step(
481 regularization_penalty,
487 m_requires_eta_ =
true;
510 std::shared_ptr<em::MStepInterface<Scalar> > get_fast_supervised_online_m_step(
511 std::vector<Scalar> class_weights,
512 Scalar regularization_penalty = 1e-2,
513 size_t minibatch_size = 128,
514 Scalar eta_momentum = 0.9,
515 Scalar eta_learning_rate = 0.01,
516 Scalar beta_weight = 0.9
518 LDABuilder & set_fast_supervised_online_m_step(
519 std::vector<Scalar> class_weights,
520 Scalar regularization_penalty = 1e-2,
521 size_t minibatch_size = 128,
522 Scalar eta_momentum = 0.9,
523 Scalar eta_learning_rate = 0.01,
524 Scalar beta_weight = 0.9
526 set_m(get_fast_supervised_online_m_step(
528 regularization_penalty,
534 m_requires_eta_ =
true;
557 std::shared_ptr<em::MStepInterface<Scalar> > get_fast_supervised_online_m_step(
558 Eigen::Matrix<Scalar, Eigen::Dynamic, 1> class_weights,
559 Scalar regularization_penalty = 1e-2,
560 size_t minibatch_size = 128,
561 Scalar eta_momentum = 0.9,
562 Scalar eta_learning_rate = 0.01,
563 Scalar beta_weight = 0.9
565 LDABuilder & set_fast_supervised_online_m_step(
566 Eigen::Matrix<Scalar, Eigen::Dynamic, 1> class_weights,
567 Scalar regularization_penalty = 1e-2,
568 size_t minibatch_size = 128,
569 Scalar eta_momentum = 0.9,
570 Scalar eta_learning_rate = 0.01,
571 Scalar beta_weight = 0.9
573 set_m(get_fast_supervised_online_m_step(
575 regularization_penalty,
581 m_requires_eta_ =
true;
598 std::shared_ptr<em::MStepInterface<Scalar> > get_semi_supervised_m_step(
599 size_t m_step_iterations = 10,
600 Scalar m_step_tolerance = 1e-2,
601 Scalar regularization_penalty = 1e-2
607 size_t m_step_iterations = 10,
608 Scalar m_step_tolerance = 1e-2,
609 Scalar regularization_penalty = 1e-2
611 set_m(get_semi_supervised_m_step(
614 regularization_penalty
616 m_requires_eta_ =
true;
628 std::shared_ptr<em::MStepInterface<Scalar> > get_multinomial_supervised_m_step(
637 set_m(get_multinomial_supervised_m_step(mu));
638 m_requires_eta_ =
true;
650 std::shared_ptr<em::MStepInterface<Scalar> > get_correspondence_supervised_m_step(
659 set_m(get_correspondence_supervised_m_step(mu));
660 m_requires_eta_ =
true;
670 m_requires_eta_ =
false;
692 const Eigen::MatrixXi &X,
714 std::shared_ptr<corpus::Corpus> corpus,
745 model_parameters_->alpha = model->alpha;
746 model_parameters_->beta = model->beta;
/**
 * Declaration only; the definition is not part of this listing.
 *
 * @param num_classes a class count.
 *        NOTE(review): presumably determines the size of eta, and by the
 *        function name eta is filled with zeros -- confirm against the
 *        definition.
 * @return a reference to an LDABuilder (presumably *this -- TODO confirm).
 */
757 LDABuilder & initialize_eta_zeros(
size_t num_classes);
/**
 * Declaration only; the definition is not part of this listing.
 *
 * @param num_classes a class count.
 *        NOTE(review): presumably determines the size of eta, and by the
 *        function name eta is filled from a uniform distribution -- the
 *        range and random source are not visible here; confirm against
 *        the definition.
 * @return a reference to an LDABuilder (presumably *this -- TODO confirm).
 */
765 LDABuilder & initialize_eta_uniform(
size_t num_classes);
776 model_parameters_->eta = model->eta;
789 if (model_parameters_->beta.rows() == 0) {
790 throw std::runtime_error(
"You need to call initialize_topics before " 791 "creating an LDA from the builder.");
795 model_parameters_->eta.rows() == 0 &&
796 (e_requires_eta_ || m_requires_eta_)
798 throw std::runtime_error(
"An E step or M step seems to be supervised " 799 "yet you have not initialized eta. " 800 "Call initialize_eta_*()");
/**
 * E step and M step implementations, held through shared pointers.
 * NOTE(review): presumably assigned by set_e() / set_m() -- those bodies are
 * not fully visible in this listing; confirm.
 */
818 std::shared_ptr<em::EStepInterface<Scalar> > e_step_;
819 std::shared_ptr<em::MStepInterface<Scalar> > m_step_;
/**
 * Supervised model parameters. The visible code assigns its alpha, beta and
 * eta members from another model and checks beta.rows() / eta.rows() for
 * emptiness before an LDA is created.
 */
822 std::shared_ptr<parameters::SupervisedModelParameters<Scalar> > model_parameters_;
/**
 * Whether the currently selected E step / M step needs eta. The setters
 * visible in this listing write true or false to these flags; if either is
 * set while model_parameters_->eta has zero rows, a std::runtime_error
 * telling the caller to call initialize_eta_*() is thrown.
 */
826 bool e_requires_eta_;
827 bool m_requires_eta_;
832 #endif // _LDAPLUSPLUS_LDABUILDER_HPP_ Definition: MStepInterface.hpp:24
LDABuilder & set_fast_supervised_e_step(size_t e_step_iterations=10, Scalar e_step_tolerance=1e-2, Scalar C=1, Scalar compute_likelihood=1.0, int random_state=0)
Definition: LDABuilder.hpp:198
LDABuilder & set_fast_supervised_m_step(size_t m_step_iterations=10, Scalar m_step_tolerance=1e-2, Scalar regularization_penalty=1e-2)
Definition: LDABuilder.hpp:396
LDABuilder & initialize_eta_from_model(std::shared_ptr< parameters::SupervisedModelParameters< Scalar > > model)
Definition: LDABuilder.hpp:773
LDABuilder & initialize_topics_from_model(std::shared_ptr< parameters::ModelParameters< Scalar > > model)
Definition: LDABuilder.hpp:742
LDABuilder & set_classic_m_step()
Definition: LDABuilder.hpp:369
LDABuilder & set_supervised_m_step(size_t m_step_iterations=10, Scalar m_step_tolerance=1e-2, Scalar regularization_penalty=1e-2)
Definition: LDABuilder.hpp:431
LDABuilder & set_e(std::shared_ptr< em::EStepInterface< Scalar > > e_step)
Definition: LDABuilder.hpp:355
LDABuilder & set_correspondence_supervised_m_step(Scalar mu=2.)
Definition: LDABuilder.hpp:656
Definition: EStepInterface.hpp:24
LDABuilder & set_m(std::shared_ptr< em::MStepInterface< Scalar > > m_step)
Definition: LDABuilder.hpp:669
Definition: LDABuilder.hpp:62
LDABuilder & set_supervised_e_step(size_t e_step_iterations=10, Scalar e_step_tolerance=1e-2, size_t fixed_point_iterations=10, Scalar compute_likelihood=1.0, int random_state=0)
Definition: LDABuilder.hpp:151
LDABuilder & set_multinomial_supervised_m_step(Scalar mu=2.)
Definition: LDABuilder.hpp:634
Definition: ProgressEvents.hpp:11
LDABuilder & set_classic_e_step(size_t e_step_iterations=10, Scalar e_step_tolerance=1e-2, Scalar compute_likelihood=1.0, int random_state=0)
Definition: LDABuilder.hpp:112
Definition: LDABuilder.hpp:25
LDABuilder & set_multinomial_supervised_e_step(size_t e_step_iterations=10, Scalar e_step_tolerance=1e-2, Scalar mu=2, Scalar eta_weight=1, Scalar compute_likelihood=1.0, int random_state=0)
Definition: LDABuilder.hpp:284
Definition: Parameters.hpp:27
Definition: Parameters.hpp:47
LDABuilder & set_semi_supervised_e_step(std::shared_ptr< em::EStepInterface< Scalar > > supervised_step=nullptr, std::shared_ptr< em::EStepInterface< Scalar > > unsupervised_step=nullptr)
Definition: LDABuilder.hpp:238
LDABuilder & set_correspondence_supervised_e_step(size_t e_step_iterations=10, Scalar e_step_tolerance=1e-2, Scalar mu=2, Scalar compute_likelihood=1.0, int random_state=0)
Definition: LDABuilder.hpp:328
LDABuilder & set_semi_supervised_m_step(size_t m_step_iterations=10, Scalar m_step_tolerance=1e-2, Scalar regularization_penalty=1e-2)
Definition: LDABuilder.hpp:606
Definition: Document.hpp:11