em_fit.hpp
Go to the documentation of this file.
00023 #ifndef __MLPACK_METHODS_GMM_EM_FIT_HPP
00024 #define __MLPACK_METHODS_GMM_EM_FIT_HPP
00025
00026 #include <mlpack/core.hpp>
00027
00028
00029 #include <mlpack/methods/kmeans/kmeans.hpp>
00030
00031 #include "positive_definite_constraint.hpp"
00032
00033 namespace mlpack {
00034 namespace gmm {
00035
00049 template<typename InitialClusteringType = kmeans::KMeans<>,
00050 typename CovarianceConstraintPolicy = PositiveDefiniteConstraint>
00051 class EMFit
00052 {
00053 public:
00071 EMFit(const size_t maxIterations = 300,
00072 const double tolerance = 1e-10,
00073 InitialClusteringType clusterer = InitialClusteringType(),
00074 CovarianceConstraintPolicy constraint = CovarianceConstraintPolicy());
00075
00091 void Estimate(const arma::mat& observations,
00092 std::vector<arma::vec>& means,
00093 std::vector<arma::mat>& covariances,
00094 arma::vec& weights,
00095 const bool useInitialModel = false);
00096
00114 void Estimate(const arma::mat& observations,
00115 const arma::vec& probabilities,
00116 std::vector<arma::vec>& means,
00117 std::vector<arma::mat>& covariances,
00118 arma::vec& weights,
00119 const bool useInitialModel = false);
00120
00122 const InitialClusteringType& Clusterer() const { return clusterer; }
00124 InitialClusteringType& Clusterer() { return clusterer; }
00125
00127 const CovarianceConstraintPolicy& Constraint() const { return constraint; }
00129 CovarianceConstraintPolicy& Constraint() { return constraint; }
00130
00132 size_t MaxIterations() const { return maxIterations; }
00134 size_t& MaxIterations() { return maxIterations; }
00135
00137 double Tolerance() const { return tolerance; }
00139 double& Tolerance() { return tolerance; }
00140
00141 private:
00152 void InitialClustering(const arma::mat& observations,
00153 std::vector<arma::vec>& means,
00154 std::vector<arma::mat>& covariances,
00155 arma::vec& weights);
00156
00167 double LogLikelihood(const arma::mat& data,
00168 const std::vector<arma::vec>& means,
00169 const std::vector<arma::mat>& covariances,
00170 const arma::vec& weights) const;
00171
00173 size_t maxIterations;
00175 double tolerance;
00177 InitialClusteringType clusterer;
00179 CovarianceConstraintPolicy constraint;
00180 };
00181
00182 };
00183 };
00184
00185
00186 #include "em_fit_impl.hpp"
00187
00188 #endif