cover_tree.hpp

Go to the documentation of this file.
00001 
00022 #ifndef __MLPACK_CORE_TREE_COVER_TREE_COVER_TREE_HPP
00023 #define __MLPACK_CORE_TREE_COVER_TREE_COVER_TREE_HPP
00024 
00025 #include <mlpack/core.hpp>
00026 #include <mlpack/core/metrics/lmetric.hpp>
00027 #include "first_point_is_root.hpp"
00028 #include "../statistic.hpp"
00029 
00030 namespace mlpack {
00031 namespace tree {
00032 
00100 template<typename MetricType = metric::LMetric<2, true>,
00101          typename RootPointPolicy = FirstPointIsRoot,
00102          typename StatisticType = EmptyStatistic>
00103 class CoverTree
00104 {
00105  public:
00106   typedef arma::mat Mat;
00107 
00118   CoverTree(const arma::mat& dataset,
00119             const double base = 2.0,
00120             MetricType* metric = NULL);
00121 
00131   CoverTree(const arma::mat& dataset,
00132             MetricType& metric,
00133             const double base = 2.0);
00134 
00166   CoverTree(const arma::mat& dataset,
00167             const double base,
00168             const size_t pointIndex,
00169             const int scale,
00170             CoverTree* parent,
00171             const double parentDistance,
00172             arma::Col<size_t>& indices,
00173             arma::vec& distances,
00174             size_t nearSetSize,
00175             size_t& farSetSize,
00176             size_t& usedSetSize,
00177             MetricType& metric = NULL);
00178 
00195   CoverTree(const arma::mat& dataset,
00196             const double base,
00197             const size_t pointIndex,
00198             const int scale,
00199             CoverTree* parent,
00200             const double parentDistance,
00201             const double furthestDescendantDistance,
00202             MetricType* metric = NULL);
00203 
00210   CoverTree(const CoverTree& other);
00211 
00215   ~CoverTree();
00216 
00219   template<typename RuleType>
00220   class SingleTreeTraverser;
00221 
00223   template<typename RuleType>
00224   class DualTreeTraverser;
00225 
00227   const arma::mat& Dataset() const { return dataset; }
00228 
00230   size_t Point() const { return point; }
00232   size_t Point(const size_t) const { return point; }
00233 
00234   bool IsLeaf() const { return (children.size() == 0); }
00235   size_t NumPoints() const { return 1; }
00236 
00238   const CoverTree& Child(const size_t index) const { return *children[index]; }
00240   CoverTree& Child(const size_t index) { return *children[index]; }
00241 
00243   size_t NumChildren() const { return children.size(); }
00244 
00246   const std::vector<CoverTree*>& Children() const { return children; }
00248   std::vector<CoverTree*>& Children() { return children; }
00249 
00251   size_t NumDescendants() const;
00252 
00254   size_t Descendant(const size_t index) const;
00255 
00257   int Scale() const { return scale; }
00259   int& Scale() { return scale; }
00260 
00262   double Base() const { return base; }
00264   double& Base() { return base; }
00265 
00267   const StatisticType& Stat() const { return stat; }
00269   StatisticType& Stat() { return stat; }
00270 
00272   double MinDistance(const CoverTree* other) const;
00273 
00276   double MinDistance(const CoverTree* other, const double distance) const;
00277 
00279   double MinDistance(const arma::vec& other) const;
00280 
00283   double MinDistance(const arma::vec& other, const double distance) const;
00284 
00286   double MaxDistance(const CoverTree* other) const;
00287 
00290   double MaxDistance(const CoverTree* other, const double distance) const;
00291 
00293   double MaxDistance(const arma::vec& other) const;
00294 
00297   double MaxDistance(const arma::vec& other, const double distance) const;
00298 
00300   math::Range RangeDistance(const CoverTree* other) const;
00301 
00304   math::Range RangeDistance(const CoverTree* other, const double distance)
00305       const;
00306 
00308   math::Range RangeDistance(const arma::vec& other) const;
00309 
00312   math::Range RangeDistance(const arma::vec& other, const double distance)
00313       const;
00314 
00316   static bool HasSelfChildren() { return true; }
00317 
00319   CoverTree* Parent() const { return parent; }
00321   CoverTree*& Parent() { return parent; }
00322 
00324   double ParentDistance() const { return parentDistance; }
00326   double& ParentDistance() { return parentDistance; }
00327 
00329   double FurthestPointDistance() const { return 0.0; }
00330 
00332   double FurthestDescendantDistance() const
00333   { return furthestDescendantDistance; }
00336   double& FurthestDescendantDistance() { return furthestDescendantDistance; }
00337 
00339   void Centroid(arma::vec& centroid) const { centroid = dataset.col(point); }
00340 
00342   MetricType& Metric() const { return *metric; }
00343 
00344  private:
00346   const arma::mat& dataset;
00347 
00349   size_t point;
00350 
00352   std::vector<CoverTree*> children;
00353 
00355   int scale;
00356 
00358   double base;
00359 
00361   StatisticType stat;
00362 
00364   size_t numDescendants;
00365 
00367   CoverTree* parent;
00368 
00370   double parentDistance;
00371 
00373   double furthestDescendantDistance;
00374 
00376   bool localMetric;
00377 
00379   MetricType* metric;
00380 
00384   void CreateChildren(arma::Col<size_t>& indices,
00385                       arma::vec& distances,
00386                       size_t nearSetSize,
00387                       size_t& farSetSize,
00388                       size_t& usedSetSize);
00389 
00401   void ComputeDistances(const size_t pointIndex,
00402                         const arma::Col<size_t>& indices,
00403                         arma::vec& distances,
00404                         const size_t pointSetSize);
00419   size_t SplitNearFar(arma::Col<size_t>& indices,
00420                       arma::vec& distances,
00421                       const double bound,
00422                       const size_t pointSetSize);
00423 
00443   size_t SortPointSet(arma::Col<size_t>& indices,
00444                       arma::vec& distances,
00445                       const size_t childFarSetSize,
00446                       const size_t childUsedSetSize,
00447                       const size_t farSetSize);
00448 
00449   void MoveToUsedSet(arma::Col<size_t>& indices,
00450                      arma::vec& distances,
00451                      size_t& nearSetSize,
00452                      size_t& farSetSize,
00453                      size_t& usedSetSize,
00454                      arma::Col<size_t>& childIndices,
00455                      const size_t childFarSetSize,
00456                      const size_t childUsedSetSize);
00457   size_t PruneFarSet(arma::Col<size_t>& indices,
00458                      arma::vec& distances,
00459                      const double bound,
00460                      const size_t nearSetSize,
00461                      const size_t pointSetSize);
00462 
00467   void RemoveNewImplicitNodes();
00468 
00469  public:
00473   std::string ToString() const;
00474 
00475   size_t DistanceComps() const { return distanceComps; }
00476   size_t& DistanceComps() { return distanceComps; }
00477 
00478  private:
00479   size_t distanceComps;
00480 };
00481 
00482 }; // namespace tree
00483 }; // namespace mlpack
00484 
00485 // Include implementation.
00486 #include "cover_tree_impl.hpp"
00487 
00488 #endif

Generated on 13 Aug 2014 for MLPACK by  doxygen 1.6.1