MLPACK  1.0.8
hmm.hpp
Go to the documentation of this file.
1 
23 #ifndef __MLPACK_METHODS_HMM_HMM_HPP
24 #define __MLPACK_METHODS_HMM_HMM_HPP
25 
26 #include <mlpack/core.hpp>
27 
28 namespace mlpack {
29 namespace hmm {
30 
92 template<typename Distribution = distribution::DiscreteDistribution>
93 class HMM
94 {
95  public:
110  HMM(const size_t states,
111  const Distribution emissions,
112  const double tolerance = 1e-5);
113 
135  HMM(const arma::mat& transition,
136  const std::vector<Distribution>& emission,
137  const double tolerance = 1e-5);
138 
167  void Train(const std::vector<arma::mat>& dataSeq);
168 
190  void Train(const std::vector<arma::mat>& dataSeq,
191  const std::vector<arma::Col<size_t> >& stateSeq);
192 
211  double Estimate(const arma::mat& dataSeq,
212  arma::mat& stateProb,
213  arma::mat& forwardProb,
214  arma::mat& backwardProb,
215  arma::vec& scales) const;
216 
228  double Estimate(const arma::mat& dataSeq,
229  arma::mat& stateProb) const;
230 
242  void Generate(const size_t length,
243  arma::mat& dataSequence,
244  arma::Col<size_t>& stateSequence,
245  const size_t startState = 0) const;
246 
257  double Predict(const arma::mat& dataSeq,
258  arma::Col<size_t>& stateSeq) const;
259 
266  double LogLikelihood(const arma::mat& dataSeq) const;
267 
269  const arma::mat& Transition() const { return transition; }
271  arma::mat& Transition() { return transition; }
272 
274  const std::vector<Distribution>& Emission() const { return emission; }
276  std::vector<Distribution>& Emission() { return emission; }
277 
279  size_t Dimensionality() const { return dimensionality; }
281  size_t& Dimensionality() { return dimensionality; }
282 
284  double Tolerance() const { return tolerance; }
286  double& Tolerance() { return tolerance; }
287 
288  private:
289  // Helper functions.
290 
301  void Forward(const arma::mat& dataSeq,
302  arma::vec& scales,
303  arma::mat& forwardProb) const;
304 
316  void Backward(const arma::mat& dataSeq,
317  const arma::vec& scales,
318  arma::mat& backwardProb) const;
319 
321  arma::mat transition;
322 
324  std::vector<Distribution> emission;
325 
328 
330  double tolerance;
331 };
332 
333 }; // namespace hmm
334 }; // namespace mlpack
335 
336 // Include implementation.
337 #include "hmm_impl.hpp"
338 
339 #endif