dtree.hpp
Go to the documentation of this file.
00023 #ifndef __MLPACK_METHODS_DET_DTREE_HPP
00024 #define __MLPACK_METHODS_DET_DTREE_HPP
00025
00026 #include <mlpack/core.hpp>
00027
00028 namespace mlpack {
00029 namespace det {
00030
00054 class DTree
00055 {
00056 public:
00060 DTree();
00061
00070 DTree(const arma::vec& maxVals,
00071 const arma::vec& minVals,
00072 const size_t totalPoints);
00073
00082 DTree(arma::mat& data);
00083
00096 DTree(const arma::vec& maxVals,
00097 const arma::vec& minVals,
00098 const size_t start,
00099 const size_t end,
00100 const double logNegError);
00101
00113 DTree(const arma::vec& maxVals,
00114 const arma::vec& minVals,
00115 const size_t totalPoints,
00116 const size_t start,
00117 const size_t end);
00118
00120 ~DTree();
00121
00132 double Grow(arma::mat& data,
00133 arma::Col<size_t>& oldFromNew,
00134 const bool useVolReg = false,
00135 const size_t maxLeafSize = 10,
00136 const size_t minLeafSize = 5);
00137
00146 double PruneAndUpdate(const double oldAlpha,
00147 const size_t points,
00148 const bool useVolReg = false);
00149
00155 double ComputeValue(const arma::vec& query) const;
00156
00164 void WriteTree(FILE *fp, const size_t level = 0) const;
00165
00173 int TagTree(const int tag = 0);
00174
00181 int FindBucket(const arma::vec& query) const;
00182
00188 void ComputeVariableImportance(arma::vec& importances) const;
00189
00196 double LogNegativeError(const size_t totalPoints) const;
00197
00201 bool WithinRange(const arma::vec& query) const;
00202
00203 private:
00204
00205
00206
00207
00208
00209
00212 size_t start;
00215 size_t end;
00216
00218 arma::vec maxVals;
00220 arma::vec minVals;
00221
00223 size_t splitDim;
00224
00226 double splitValue;
00227
00229 double logNegError;
00230
00232 double subtreeLeavesLogNegError;
00233
00235 size_t subtreeLeaves;
00236
00238 bool root;
00239
00241 double ratio;
00242
00244 double logVolume;
00245
00247 int bucketTag;
00248
00250 double alphaUpper;
00251
00253 DTree* left;
00255 DTree* right;
00256
00257 public:
00259 size_t Start() const { return start; }
00261 size_t End() const { return end; }
00263 size_t SplitDim() const { return splitDim; }
00265 double SplitValue() const { return splitValue; }
00267 double LogNegError() const { return logNegError; }
00269 double SubtreeLeavesLogNegError() const { return subtreeLeavesLogNegError; }
00271 size_t SubtreeLeaves() const { return subtreeLeaves; }
00274 double Ratio() const { return ratio; }
00276 double LogVolume() const { return logVolume; }
00278 DTree* Left() const { return left; }
00280 DTree* Right() const { return right; }
00282 bool Root() const { return root; }
00284 double AlphaUpper() const { return alphaUpper; }
00285
00287 const arma::vec& MaxVals() const { return maxVals; }
00289 arma::vec& MaxVals() { return maxVals; }
00290
00292 const arma::vec& MinVals() const { return minVals; }
00294 arma::vec& MinVals() { return minVals; }
00295
00296 private:
00297
00298
00299
00303 bool FindSplit(const arma::mat& data,
00304 size_t& splitDim,
00305 double& splitValue,
00306 double& leftError,
00307 double& rightError,
00308 const size_t maxLeafSize = 10,
00309 const size_t minLeafSize = 5) const;
00310
00314 size_t SplitData(arma::mat& data,
00315 const size_t splitDim,
00316 const double splitValue,
00317 arma::Col<size_t>& oldFromNew) const;
00318
00319 };
00320
00321 };
00322 };
00323
00324 #endif // __MLPACK_METHODS_DET_DTREE_HPP