vigra/random_forest.hxx
/************************************************************************/
/*                                                                      */
/*        Copyright 2008-2009 by Ullrich Koethe and Rahul Nair          */
/*                                                                      */
/*    This file is part of the VIGRA computer vision library.           */
/*    The VIGRA Website is                                              */
/*        http://hci.iwr.uni-heidelberg.de/vigra/                       */
/*    Please direct questions, bug reports, and contributions to        */
/*        ullrich.koethe@iwr.uni-heidelberg.de    or                    */
/*        vigra@informatik.uni-hamburg.de                               */
/*                                                                      */
/*    Permission is hereby granted, free of charge, to any person      */
/*    obtaining a copy of this software and associated documentation   */
/*    files (the "Software"), to deal in the Software without          */
/*    restriction, including without limitation the rights to use,     */
/*    copy, modify, merge, publish, distribute, sublicense, and/or     */
/*    sell copies of the Software, and to permit persons to whom the   */
/*    Software is furnished to do so, subject to the following         */
/*    conditions:                                                      */
/*                                                                      */
/*    The above copyright notice and this permission notice shall be   */
/*    included in all copies or substantial portions of the            */
/*    Software.                                                        */
/*                                                                      */
/*    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND   */
/*    EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES  */
/*    OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND         */
/*    NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT      */
/*    HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,     */
/*    WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING     */
/*    FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR    */
/*    OTHER DEALINGS IN THE SOFTWARE.                                  */
/*                                                                      */
/************************************************************************/


#ifndef VIGRA_RANDOM_FOREST_HXX
#define VIGRA_RANDOM_FOREST_HXX

#include <iostream>
#include <algorithm>
#include <map>
#include <set>
#include <list>
#include <numeric>
#include <vector>      // std::vector (used in predictProbabilities)
#include <stdexcept>   // std::runtime_error
#include <cfloat>      // FLT_MAX
#include <cassert>
#include "mathutil.hxx"
#include "array_vector.hxx"
#include "sized_int.hxx"
#include "matrix.hxx"
#include "random.hxx"
#include "functorexpression.hxx"
#include "random_forest/rf_common.hxx"
#include "random_forest/rf_nodeproxy.hxx"
#include "random_forest/rf_split.hxx"
#include "random_forest/rf_decisionTree.hxx"
#include "random_forest/rf_visitors.hxx"
#include "random_forest/rf_region.hxx"
#include "sampling.hxx"
#include "random_forest/rf_preprocessing.hxx"
#include "random_forest/rf_online_prediction_set.hxx"
#include "random_forest/rf_earlystopping.hxx"
#include "random_forest/rf_ridge_split.hxx"

namespace vigra
{

/** \addtogroup MachineLearning Machine Learning

    This module provides classification algorithms that map
    features to labels or label probabilities.
    Look at the RandomForest class first for an overview of most of the
    functionality provided as well as use cases.
**/
//@{

namespace detail
{

/* \brief sampling option factory function
 */
inline SamplerOptions make_sampler_opt ( RandomForestOptions & RF_opt)
{
    SamplerOptions return_opt;
    return_opt.withReplacement(RF_opt.sample_with_replacement_);
    return_opt.stratified(RF_opt.stratification_method_ == RF_EQUAL);
    return return_opt;
}
}//namespace detail
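/* Example (a sketch, assuming the usual RandomForestOptions setters): how
   typical forest options translate into sampler options.

   \code
   RandomForestOptions RF_opt;
   RF_opt.sample_with_replacement(true)    // bootstrap sampling
         .use_stratification(RF_EQUAL);    // equally many samples per class

   SamplerOptions s_opt = detail::make_sampler_opt(RF_opt);
   // s_opt now requests sampling with replacement, stratified by class
   \endcode
*/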
/** Random Forest class
 *
 * \tparam <PreprocessorTag = ClassificationTag> Class used to preprocess
 *          the input while learning and predicting. Currently available:
 *          ClassificationTag and RegressionTag. It is recommended to use
 *          Splitfunctor::Preprocessor_t when using custom split functors,
 *          as they may need the data to be in a different format.
 * \sa Preprocessor
 *
 * Simple usage for classification (regression is not yet supported):
 * look at RandomForest::learn() as well as RandomForestOptions() for additional
 * options.
 *
 * \code
 * using namespace vigra;
 * using namespace rf;
 * typedef xxx feature_t; // replace xxx with whichever type
 * typedef yyy label_t;   // likewise
 *
 * // allocate the training data
 * MultiArrayView<2, feature_t> f = get_training_features();
 * MultiArrayView<2, label_t> l = get_training_labels();
 *
 * RandomForest<> rf;
 *
 * // construct visitor to calculate out-of-bag error
 * visitors::OOB_Error oob_v;
 *
 * // perform training
 * rf.learn(f, l, visitors::create_visitor(oob_v));
 *
 * std::cout << "the out-of-bag error is: " << oob_v.oob_breiman << "\n";
 *
 * // get features for new data to be used for prediction
 * MultiArrayView<2, feature_t> pf = get_features();
 *
 * // allocate space for the response (pf.shape(0) is the number of samples)
 * typedef MultiArrayShape<2>::type Shp;
 * MultiArray<2, label_t> prediction(Shp(pf.shape(0), 1));
 * MultiArray<2, double>  prob(Shp(pf.shape(0), rf.class_count()));
 *
 * // perform prediction on new data
 * rf.predictLabels(pf, prediction);
 * rf.predictProbabilities(pf, prob);
 *
 * \endcode
 *
 * Additional information, such as variable importance measures, is accessed
 * via visitors defined in rf::visitors.
 * Have a look at rf::split for other splitting methods.
 *
 */
template <class LabelType = double, class PreprocessorTag = ClassificationTag>
class RandomForest
{

  public:
    //public typedefs
    typedef RandomForestOptions         Options_t;
    typedef detail::DecisionTree        DecisionTree_t;
    typedef ProblemSpec<LabelType>      ProblemSpec_t;
    typedef GiniSplit                   Default_Split_t;
    typedef EarlyStoppStd               Default_Stop_t;
    typedef rf::visitors::StopVisiting  Default_Visitor_t;
    typedef DT_StackEntry<ArrayVectorView<Int32>::iterator>
                                        StackEntry_t;
    typedef LabelType                   LabelT;

  protected:

    /** cached probability buffer, reused by predictLabel()
     */
    mutable MultiArray<2, double> garbage_prediction_;

  public:

    //problem independent data.
    Options_t options_;
    //problem dependent data members - these are only set if
    //a copy constructor, some sort of import
    //function or the learn function is called
    ArrayVector<DecisionTree_t> trees_;
    ProblemSpec_t ext_param_;
    /*mutable ArrayVector<int> tree_indices_;*/
    rf::visitors::OnlineLearnVisitor online_visitor_;


    void reset()
    {
        ext_param_.clear();
        trees_.clear();
    }

  public:

    /** \name Constructors
     *  Note: No copy constructor specified, as no pointers are manipulated
     *  in this class.
     */
    /*\{*/
    /**\brief default constructor
     *
     * \param options   general options to the Random Forest. Must be of type
     *                  Options_t
     * \param ext_param problem specific values that can be supplied
     *                  additionally (class weights, labels etc.)
     * \sa RandomForestOptions, ProblemSpec
     *
     */
    RandomForest(Options_t const & options = Options_t(),
                 ProblemSpec_t const & ext_param = ProblemSpec_t())
    :
        options_(options),
        ext_param_(ext_param)/*,
        tree_indices_(options.tree_count_,0)*/
    {
        /*for(int ii = 0 ; ii < int(tree_indices_.size()); ++ii)
            tree_indices_[ii] = ii;*/
    }

    /**\brief Create RF from external source
     * \param treeCount       Number of trees to add.
     * \param topology_begin  Iterator to a container where the topology_ data
     *                        of the trees is stored. The iterator should
     *                        support at least treeCount forward iterations
     *                        (i.e. topology_end - topology_begin >= treeCount).
     * \param parameter_begin Iterator to a container where the parameters_ data
     *                        of the trees is stored. The iterator should
     *                        support at least treeCount forward iterations.
     * \param problem_spec    Extrinsic parameters that specify the problem,
     *                        e.g. ClassCount, featureCount etc.
     * \param options         (optional) options used to train the original
     *                        Random Forest. This parameter is not used anywhere
     *                        during prediction and thus is optional.
     *
     */
    /* TODO: This constructor may be replaced by a constructor using
     * NodeProxy iterators to encapsulate the underlying data type.
     */
    template<class TopologyIterator, class ParameterIterator>
    RandomForest(int treeCount,
                 TopologyIterator topology_begin,
                 ParameterIterator parameter_begin,
                 ProblemSpec_t const & problem_spec,
                 Options_t const & options = Options_t())
    :
        options_(options),
        trees_(treeCount, DecisionTree_t(problem_spec)),
        ext_param_(problem_spec)
    {
        for(int k=0; k<treeCount; ++k, ++topology_begin, ++parameter_begin)
        {
            trees_[k].topology_ = *topology_begin;
            trees_[k].parameters_ = *parameter_begin;
        }
    }

    /*\}*/
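/* Example (a sketch): rebuilding a forest from externally stored trees.
   get_stored_topologies(), get_stored_parameters() and stored_problem_spec()
   are placeholders for user code, e.g. arrays read back from disk.

   \code
   std::vector<ArrayVector<Int32> >  topologies = get_stored_topologies();
   std::vector<ArrayVector<double> > parameters = get_stored_parameters();
   ProblemSpec<double> spec                     = stored_problem_spec();

   RandomForest<double> rf((int)topologies.size(),
                           topologies.begin(),
                           parameters.begin(),
                           spec);
   \endcode
*/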
    /** \name Data Access
     *  data access interface - usage of member variables is deprecated
     */

    /*\{*/


    /**\brief return external parameters for viewing
     * \return ProblemSpec_t
     */
    ProblemSpec_t const & ext_param() const
    {
        vigra_precondition(ext_param_.used() == true,
            "RandomForest::ext_param(): "
            "Random forest has not been trained yet.");
        return ext_param_;
    }

    /**\brief set external parameters
     *
     * \param in external parameters to be set
     *
     * set external parameters explicitly.
     * If the Random Forest has not been trained, the preprocessor will
     * either ignore filling values set this way or will throw an exception
     * if values specified manually do not match the values calculated
     * during the preparation step.
     */
    void set_ext_param(ProblemSpec_t const & in)
    {
        vigra_precondition(ext_param_.used() == false,
            "RandomForest::set_ext_param():"
            " Random forest has been trained! Call reset()"
            " before specifying new extrinsic parameters.");
        ext_param_ = in;
    }

    /**\brief access random forest options
     *
     * \return random forest options
     */
    Options_t & set_options()
    {
        return options_;
    }


    /**\brief access const random forest options
     *
     * \return const Option_t
     */
    Options_t const & options() const
    {
        return options_;
    }

    /**\brief access const trees
     */
    DecisionTree_t const & tree(int index) const
    {
        return trees_[index];
    }

    /**\brief access trees
     */
    DecisionTree_t & tree(int index)
    {
        return trees_[index];
    }

    /*\}*/

    /**\brief return the number of features used during training.
     */
    int feature_count() const
    {
        return ext_param_.column_count_;
    }


    /**\brief return the number of features used during training.
     *
     * deprecated. Use feature_count() instead.
     */
    int column_count() const
    {
        return ext_param_.column_count_;
    }

    /**\brief return the number of classes used during training.
     */
    int class_count() const
    {
        return ext_param_.class_count_;
    }

    /**\brief return the number of trees
     */
    int tree_count() const
    {
        return options_.tree_count_;
    }


    template<class U,class C1,
             class U2, class C2,
             class Split_t,
             class Stop_t,
             class Visitor_t,
             class Random_t>
    void onlineLearn(   MultiArrayView<2,U,C1> const & features,
                        MultiArrayView<2,U2,C2> const & response,
                        int new_start_index,
                        Visitor_t visitor_,
                        Split_t split_,
                        Stop_t stop_,
                        Random_t & random,
                        bool adjust_thresholds=false);

    template <class U, class C1, class U2,class C2>
    void onlineLearn(   MultiArrayView<2, U, C1> const & features,
                        MultiArrayView<2, U2,C2> const & labels,
                        int new_start_index,
                        bool adjust_thresholds=false)
    {
        RandomNumberGenerator<> rnd = RandomNumberGenerator<>(RandomSeed);
        onlineLearn(features,
                    labels,
                    new_start_index,
                    rf_default(),
                    rf_default(),
                    rf_default(),
                    rnd,
                    adjust_thresholds);
    }
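/* Example (a sketch): appending new training samples online. This assumes a
   prepare_online_learning() option setter; 'features'/'labels' are
   placeholder arrays whose first n_old rows were used for the initial
   training and whose remaining rows are new.

   \code
   typedef MultiArrayShape<2>::type Shp;
   RandomForest<int> rf(RandomForestOptions().prepare_online_learning(true));

   // initial training on the first n_old samples
   rf.learn(features.subarray(Shp(0, 0), Shp(n_old, n_features)),
            labels.subarray(Shp(0, 0), Shp(n_old, 1)));

   // later: incorporate the samples starting at row n_old
   rf.onlineLearn(features, labels, n_old);
   \endcode
*/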
    template<class U,class C1,
             class U2, class C2,
             class Split_t,
             class Stop_t,
             class Visitor_t,
             class Random_t>
    void reLearnTree(MultiArrayView<2,U,C1> const & features,
                     MultiArrayView<2,U2,C2> const & response,
                     int treeId,
                     Visitor_t visitor_,
                     Split_t split_,
                     Stop_t stop_,
                     Random_t & random);

    template<class U, class C1, class U2, class C2>
    void reLearnTree(MultiArrayView<2, U, C1> const & features,
                     MultiArrayView<2, U2, C2> const & labels,
                     int treeId)
    {
        RandomNumberGenerator<> rnd = RandomNumberGenerator<>(RandomSeed);
        reLearnTree(features,
                    labels,
                    treeId,
                    rf_default(),
                    rf_default(),
                    rf_default(),
                    rnd);
    }


    /**\name Learning
     * The following functions differ in the degree of customization
     * allowed.
     */
    /*\{*/
    /**\brief learn on data with custom config and random number generator
     *
     * \param features  a N x M matrix containing N samples with M
     *                  features
     * \param response  a N x D matrix containing the corresponding
     *                  response. Current split functors assume D to
     *                  be 1 and ignore any additional columns.
     *                  This is not enforced to allow future support
     *                  for uncertain labels, label independent strata etc.
     *                  The Preprocessor specified during construction
     *                  should be able to handle the features and labels.
     *                  see also: SplitFunctor, Preprocessing
     *
     * \param visitor   visitor which is to be applied after each split,
     *                  each tree, and at the end. Use rf_default() for the
     *                  default value. (No visitors)
     *                  see also: rf::visitors
     * \param split     split functor to be used to calculate each split.
     *                  Use rf_default() for the default value. (GiniSplit)
     *                  see also: rf::split
     * \param stop      predicate used to decide when to stop splitting.
     *                  Use rf_default() for the default value. (EarlyStoppStd)
     * \param random    RandomNumberGenerator to be used. Use
     *                  rf_default() for the default value. (RandomMT19937)
     *
     *
     */
    template <class U, class C1,
              class U2,class C2,
              class Split_t,
              class Stop_t,
              class Visitor_t,
              class Random_t>
    void learn( MultiArrayView<2, U, C1> const & features,
                MultiArrayView<2, U2,C2> const & response,
                Visitor_t visitor,
                Split_t split,
                Stop_t stop,
                Random_t const & random);

    template <class U, class C1,
              class U2,class C2,
              class Split_t,
              class Stop_t,
              class Visitor_t>
    void learn( MultiArrayView<2, U, C1> const & features,
                MultiArrayView<2, U2,C2> const & response,
                Visitor_t visitor,
                Split_t split,
                Stop_t stop)
    {
        RandomNumberGenerator<> rnd = RandomNumberGenerator<>(RandomSeed);
        learn(  features,
                response,
                visitor,
                split,
                stop,
                rnd);
    }

    template <class U, class C1, class U2,class C2, class Visitor_t>
    void learn( MultiArrayView<2, U, C1> const & features,
                MultiArrayView<2, U2,C2> const & labels,
                Visitor_t visitor)
    {
        learn(  features,
                labels,
                visitor,
                rf_default(),
                rf_default());
    }

    template <class U, class C1, class U2,class C2,
              class Visitor_t, class Split_t>
    void learn( MultiArrayView<2, U, C1> const & features,
                MultiArrayView<2, U2,C2> const & labels,
                Visitor_t visitor,
                Split_t split)
    {
        learn(  features,
                labels,
                visitor,
                split,
                rf_default());
    }

    /**\brief learn on data with default configuration
     *
     * \param features  a N x M matrix containing N samples with M
     *                  features
     * \param labels    a N x D matrix containing the corresponding
     *                  N labels. Current split functors assume D to
     *                  be 1 and ignore any additional columns.
     *                  This is not enforced to allow future support
     *                  for uncertain labels.
     *
     * Learning is done with:
     * - a randomly seeded random number generator,
     * - the default Gini split functor as described by Breiman,
     * - the standard early stopping criterion.
     *
     * \sa rf::split, EarlyStoppStd
     */
    template <class U, class C1, class U2,class C2>
    void learn( MultiArrayView<2, U, C1> const & features,
                MultiArrayView<2, U2,C2> const & labels)
    {
        learn(  features,
                labels,
                rf_default(),
                rf_default(),
                rf_default());
    }
    /*\}*/
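/* Example (a sketch, assuming a seed-taking RandomNumberGenerator
   constructor): reproducible training by passing an explicitly seeded
   random number generator to the most general learn() overload.

   \code
   RandomForest<int> rf;
   RandomNumberGenerator<> rng(42);  // fixed seed instead of RandomSeed
   rf.learn(features, labels,
            rf_default(),            // no visitors
            rf_default(),            // GiniSplit
            rf_default(),            // EarlyStoppStd
            rng);
   \endcode
*/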
    /**\name Prediction
     */
    /*\{*/
    /** \brief predict a label given a feature.
     *
     * \param features a 1 x featureCount matrix containing the
     *        data point to be predicted (this only works in a
     *        classification setting)
     * \param stop     early stopping criterion
     * \return double value representing the class. You can use the
     *         predictLabels() function together with the
     *         rf.external_parameter().class_type_ attribute
     *         to get back the same type used during learning.
     */
    template <class U, class C, class Stop>
    LabelType predictLabel(MultiArrayView<2, U, C>const & features, Stop & stop) const;

    template <class U, class C>
    LabelType predictLabel(MultiArrayView<2, U, C>const & features)
    {
        return predictLabel(features, rf_default());
    }

    /** \brief predict a label with features and class priors
     *
     * \param features same as above.
     * \param prior    iterator to the prior weighting of classes
     * \return same as above.
     */
    template <class U, class C>
    LabelType predictLabel(MultiArrayView<2, U, C> const & features,
                           ArrayVectorView<double> prior) const;

    /** \brief predict multiple labels with given features
     *
     * \param features a n x featureCount matrix containing the
     *        data points to be predicted (this only works in a
     *        classification setting)
     * \param labels   a n x 1 matrix passed by reference to store
     *        the output.
     */
    template <class U, class C1, class T, class C2>
    void predictLabels(MultiArrayView<2, U, C1>const & features,
                       MultiArrayView<2, T, C2> & labels) const
    {
        vigra_precondition(features.shape(0) == labels.shape(0),
            "RandomForest::predictLabels(): Label array has wrong size.");
        for(int k=0; k<features.shape(0); ++k)
            labels(k,0) = detail::RequiresExplicitCast<T>::cast(predictLabel(rowVector(features, k), rf_default()));
    }

    template <class U, class C1, class T, class C2, class Stop>
    void predictLabels(MultiArrayView<2, U, C1>const & features,
                       MultiArrayView<2, T, C2> & labels,
                       Stop & stop) const
    {
        vigra_precondition(features.shape(0) == labels.shape(0),
            "RandomForest::predictLabels(): Label array has wrong size.");
        for(int k=0; k<features.shape(0); ++k)
            labels(k,0) = detail::RequiresExplicitCast<T>::cast(predictLabel(rowVector(features, k), stop));
    }

    /** \brief predict the class probabilities for multiple labels
     *
     * \param features same as above
     * \param prob     a n x class_count_ matrix passed by reference to
     *                 store the class probabilities
     * \param stop     early stopping criterion
     * \sa EarlyStopping
     */
    template <class U, class C1, class T, class C2, class Stop>
    void predictProbabilities(MultiArrayView<2, U, C1>const & features,
                              MultiArrayView<2, T, C2> & prob,
                              Stop & stop) const;

    template <class T1,class T2, class C>
    void predictProbabilities(OnlinePredictionSet<T1> & predictionSet,
                              MultiArrayView<2, T2, C> & prob);

    /** \brief predict the class probabilities for multiple labels
     *
     * \param features same as above
     * \param prob     a n x class_count_ matrix passed by reference to
     *                 store the class probabilities
     */
    template <class U, class C1, class T, class C2>
    void predictProbabilities(MultiArrayView<2, U, C1>const & features,
                              MultiArrayView<2, T, C2> & prob) const
    {
        predictProbabilities(features, prob, rf_default());
    }

    template <class U, class C1, class T, class C2>
    void predictRaw(MultiArrayView<2, U, C1>const & features,
                    MultiArrayView<2, T, C2> & prob) const;


    /*\}*/

};
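/* Example (a sketch; 'rf' is a trained RandomForest<int> and 'test_features'
   a placeholder for user data): predicting class probabilities and hard
   labels for a test set.

   \code
   typedef MultiArrayShape<2>::type Shp;
   MultiArray<2, double> prob(Shp(test_features.shape(0), rf.class_count()));
   MultiArray<2, int>    labels(Shp(test_features.shape(0), 1));

   rf.predictProbabilities(test_features, prob);
   rf.predictLabels(test_features, labels);

   // classify a single sample (a 1 x feature_count() row vector)
   int first = rf.predictLabel(rowVector(test_features, 0));
   \endcode
*/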
template <class LabelType, class PreprocessorTag>
template<class U,class C1,
         class U2, class C2,
         class Split_t,
         class Stop_t,
         class Visitor_t,
         class Random_t>
void RandomForest<LabelType, PreprocessorTag>::onlineLearn(MultiArrayView<2,U,C1> const & features,
                                                           MultiArrayView<2,U2,C2> const & response,
                                                           int new_start_index,
                                                           Visitor_t visitor_,
                                                           Split_t split_,
                                                           Stop_t stop_,
                                                           Random_t & random,
                                                           bool adjust_thresholds)
{
    online_visitor_.activate();
    online_visitor_.adjust_thresholds=adjust_thresholds;

    using namespace rf;
    //typedefs
    typedef Processor<PreprocessorTag,LabelType,U,C1,U2,C2> Preprocessor_t;
    typedef UniformIntRandomFunctor<Random_t>
                                            RandFunctor_t;
    // default values and initialization
    // The Value Chooser chooses the second argument as value if the first
    // argument is of type RF_DEFAULT. (Thanks to template magic - don't
    // care about it - just smile and wave.)

    #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_>
    Default_Stop_t default_stop(options_);
    typename RF_CHOOSER(Stop_t)::type stop
            = RF_CHOOSER(Stop_t)::choose(stop_, default_stop);
    Default_Split_t default_split;
    typename RF_CHOOSER(Split_t)::type split
            = RF_CHOOSER(Split_t)::choose(split_, default_split);
    rf::visitors::StopVisiting stopvisiting;
    typedef rf::visitors::detail::VisitorNode
                <rf::visitors::OnlineLearnVisitor,
                 typename RF_CHOOSER(Visitor_t)::type>
            IntermedVis;
    IntermedVis
        visitor(online_visitor_, RF_CHOOSER(Visitor_t)::choose(visitor_, stopvisiting));
    #undef RF_CHOOSER

    // Preprocess the data to get something the split functor can work
    // with. Also fill the ext_param structure by preprocessing
    // option parameters that could only be completely evaluated
    // when the training data is known.
    ext_param_.class_count_=0;
    Preprocessor_t preprocessor(    features, response,
                                    options_, ext_param_);

    // Make an stl compatible random functor.
    RandFunctor_t           randint     ( random);

    // Give the Split functor information about the data.
    split.set_external_parameters(ext_param_);
    stop.set_external_parameters(ext_param_);


    //Create poisson samples
    PoissonSampler<RandomTT800> poisson_sampler(1.0,vigra::Int32(new_start_index),vigra::Int32(ext_param().row_count_));

    //TODO: visitors for online learning
    //visitor.visit_at_beginning(*this, preprocessor);

    // THE MAIN EFFING RF LOOP - YEAY DUDE!
    for(int ii = 0; ii < (int)trees_.size(); ++ii)
    {
        online_visitor_.tree_id=ii;
        poisson_sampler.sample();
        std::map<int,int> leaf_parents;
        leaf_parents.clear();
        //Get all the leaf nodes for that sample
        for(int s=0;s<poisson_sampler.numOfSamples();++s)
        {
            int sample=poisson_sampler[s];
            online_visitor_.current_label=preprocessor.response()(sample,0);
            online_visitor_.last_node_id=StackEntry_t::DecisionTreeNoParent;
            int leaf=trees_[ii].getToLeaf(rowVector(features,sample),online_visitor_);


            //Add to the index list for that leaf
            online_visitor_.add_to_index_list(ii,leaf,sample);
            //TODO: Class count?
            //Store the parent, unless the leaf is already pure for this label
            if(Node<e_ConstProbNode>(trees_[ii].topology_,trees_[ii].parameters_,leaf).prob_begin()[preprocessor.response()(sample,0)]!=1.0)
            {
                leaf_parents[leaf]=online_visitor_.last_node_id;
            }
        }


        std::map<int,int>::iterator leaf_iterator;
        for(leaf_iterator=leaf_parents.begin();leaf_iterator!=leaf_parents.end();++leaf_iterator)
        {
            int leaf=leaf_iterator->first;
            int parent=leaf_iterator->second;
            int lin_index=online_visitor_.trees_online_information[ii].exterior_to_index[leaf];
            ArrayVector<Int32> indices;
            indices.clear();
            indices.swap(online_visitor_.trees_online_information[ii].index_lists[lin_index]);
            StackEntry_t stack_entry(indices.begin(),
                                     indices.end(),
                                     ext_param_.class_count_);


            if(parent!=-1)
            {
                if(NodeBase(trees_[ii].topology_,trees_[ii].parameters_,parent).child(0)==leaf)
                {
                    stack_entry.leftParent=parent;
                }
                else
                {
                    vigra_assert(NodeBase(trees_[ii].topology_,trees_[ii].parameters_,parent).child(1)==leaf,
                                 "last_node_id seems to be wrong");
                    stack_entry.rightParent=parent;
                }
            }
            //trees_[ii].continueLearn(preprocessor.features(),preprocessor.response(),stack_entry,split,stop,visitor,randint,leaf);
            trees_[ii].continueLearn(preprocessor.features(),preprocessor.response(),stack_entry,split,stop,visitor,randint,-1);
            //Now, the last node has moved onto the position of the old leaf
            online_visitor_.move_exterior_node(ii,trees_[ii].topology_.size(),ii,leaf);
            //Now it should be classified correctly!
        }

        /*visitor
            .visit_after_tree(  *this,
                                preprocessor,
                                poisson_sampler,
                                stack_entry,
                                ii);*/
    }

    //visitor.visit_at_end(*this, preprocessor);
    online_visitor_.deactivate();
}
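/* Note (background): drawing every new sample Poisson(1)-many times per tree
   approximates the bootstrap of offline bagging, because Binomial(n, 1/n)
   converges to Poisson(1) as n grows (Oza & Russell, "Online Bagging and
   Boosting"). A minimal sketch of the idea, with poisson() and update_tree()
   as hypothetical placeholders:

   \code
   for (int t = 0; t < tree_count; ++t)
   {
       int k = poisson(1.0);            // how often tree t sees this sample
       for (int rep = 0; rep < k; ++rep)
           update_tree(t, sample);      // placeholder for the tree update
   }
   \endcode
*/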
template<class LabelType, class PreprocessorTag>
template<class U,class C1,
         class U2, class C2,
         class Split_t,
         class Stop_t,
         class Visitor_t,
         class Random_t>
void RandomForest<LabelType, PreprocessorTag>::reLearnTree(MultiArrayView<2,U,C1> const & features,
                                                           MultiArrayView<2,U2,C2> const & response,
                                                           int treeId,
                                                           Visitor_t visitor_,
                                                           Split_t split_,
                                                           Stop_t stop_,
                                                           Random_t & random)
{
    using namespace rf;


    typedef UniformIntRandomFunctor<Random_t>
                                            RandFunctor_t;

    // See rf_preprocessing.hxx for more info on this
    ext_param_.class_count_=0;
    typedef Processor<PreprocessorTag,LabelType, U, C1, U2, C2> Preprocessor_t;

    // default values and initialization
    // The Value Chooser chooses the second argument as value if the first
    // argument is of type RF_DEFAULT. (Thanks to template magic - don't
    // care about it - just smile and wave.)

    #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_>
    Default_Stop_t default_stop(options_);
    typename RF_CHOOSER(Stop_t)::type stop
            = RF_CHOOSER(Stop_t)::choose(stop_, default_stop);
    Default_Split_t default_split;
    typename RF_CHOOSER(Split_t)::type split
            = RF_CHOOSER(Split_t)::choose(split_, default_split);
    rf::visitors::StopVisiting stopvisiting;
    typedef rf::visitors::detail::VisitorNode
                <rf::visitors::OnlineLearnVisitor,
                 typename RF_CHOOSER(Visitor_t)::type> IntermedVis;
    IntermedVis
        visitor(online_visitor_, RF_CHOOSER(Visitor_t)::choose(visitor_, stopvisiting));
    #undef RF_CHOOSER
    vigra_precondition(options_.prepare_online_learning_,
        "reLearnTree: relearning a tree only makes sense if online learning is enabled");
    online_visitor_.activate();

    // Make an stl compatible random functor.
    RandFunctor_t           randint     ( random);

    // Preprocess the data to get something the split functor can work
    // with. Also fill the ext_param structure by preprocessing
    // option parameters that could only be completely evaluated
    // when the training data is known.
    Preprocessor_t preprocessor(    features, response,
                                    options_, ext_param_);

    // Give the Split functor information about the data.
    split.set_external_parameters(ext_param_);
    stop.set_external_parameters(ext_param_);

    /**\todo replace this crappy class out. It uses function pointers
     * and is making code slower according to me.
     * Comment from Nathan: This is copied from Rahul, so me=Rahul
     */
    Sampler<Random_t > sampler(preprocessor.strata().begin(),
                               preprocessor.strata().end(),
                               detail::make_sampler_opt(options_)
                                    .sampleSize(ext_param().actual_msample_),
                               random);
    //initialize the first region/node/stack entry
    sampler
        .sample();

    StackEntry_t
        first_stack_entry(  sampler.sampledIndices().begin(),
                            sampler.sampledIndices().end(),
                            ext_param_.class_count_);
    first_stack_entry
        .set_oob_range(     sampler.oobIndices().begin(),
                            sampler.oobIndices().end());
    online_visitor_.reset_tree(treeId);
    online_visitor_.tree_id=treeId;
    trees_[treeId].reset();
    trees_[treeId]
        .learn( preprocessor.features(),
                preprocessor.response(),
                first_stack_entry,
                split,
                stop,
                visitor,
                randint);
    visitor
        .visit_after_tree(  *this,
                            preprocessor,
                            sampler,
                            first_stack_entry,
                            treeId);

    online_visitor_.deactivate();
}
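/* Example (a sketch, assuming a prepare_online_learning() option setter):
   rebuilding a single degraded tree from scratch after many online updates.
   'features'/'labels' are the full training data (placeholders).

   \code
   RandomForest<int> rf(RandomForestOptions().prepare_online_learning(true));
   rf.learn(features, labels);
   // ... several onlineLearn() updates later ...
   rf.reLearnTree(features, labels, 0);   // retrain tree 0
   \endcode
*/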
template <class LabelType, class PreprocessorTag>
template <class U, class C1,
          class U2,class C2,
          class Split_t,
          class Stop_t,
          class Visitor_t,
          class Random_t>
void RandomForest<LabelType, PreprocessorTag>::
                     learn( MultiArrayView<2, U, C1> const & features,
                            MultiArrayView<2, U2,C2> const & response,
                            Visitor_t visitor_,
                            Split_t split_,
                            Stop_t stop_,
                            Random_t const & random)
{
    using namespace rf;
    //this->reset();
    //typedefs
    typedef UniformIntRandomFunctor<Random_t>
                                            RandFunctor_t;

    // See rf_preprocessing.hxx for more info on this
    typedef Processor<PreprocessorTag,LabelType, U, C1, U2, C2> Preprocessor_t;

    vigra_precondition(features.shape(0) == response.shape(0),
        "RandomForest::learn(): shape mismatch between features and response.");

    // default values and initialization
    // The Value Chooser chooses the second argument as value if the first
    // argument is of type RF_DEFAULT. (Thanks to template magic - don't
    // care about it - just smile and wave.)

    #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_>
    Default_Stop_t default_stop(options_);
    typename RF_CHOOSER(Stop_t)::type stop
            = RF_CHOOSER(Stop_t)::choose(stop_, default_stop);
    Default_Split_t default_split;
    typename RF_CHOOSER(Split_t)::type split
            = RF_CHOOSER(Split_t)::choose(split_, default_split);
    rf::visitors::StopVisiting stopvisiting;
    typedef rf::visitors::detail::VisitorNode<
                rf::visitors::OnlineLearnVisitor,
                typename RF_CHOOSER(Visitor_t)::type> IntermedVis;
    IntermedVis
        visitor(online_visitor_, RF_CHOOSER(Visitor_t)::choose(visitor_, stopvisiting));
    #undef RF_CHOOSER
    if(options_.prepare_online_learning_)
        online_visitor_.activate();
    else
        online_visitor_.deactivate();


    // Make an stl compatible random functor.
    RandFunctor_t           randint     ( random);


    // Preprocess the data to get something the split functor can work
    // with. Also fill the ext_param structure by preprocessing
    // option parameters that could only be completely evaluated
    // when the training data is known.
    Preprocessor_t preprocessor(    features, response,
                                    options_, ext_param_);

    // Give the Split functor information about the data.
    split.set_external_parameters(ext_param_);
    stop.set_external_parameters(ext_param_);


    //initialize trees.
    trees_.resize(options_.tree_count_ , DecisionTree_t(ext_param_));

    Sampler<Random_t > sampler(preprocessor.strata().begin(),
                               preprocessor.strata().end(),
                               detail::make_sampler_opt(options_)
                                    .sampleSize(ext_param().actual_msample_),
                               random);

    visitor.visit_at_beginning(*this, preprocessor);
    // THE MAIN EFFING RF LOOP - YEAY DUDE!

    for(int ii = 0; ii < (int)trees_.size(); ++ii)
    {
        //initialize the first region/node/stack entry
        sampler
            .sample();
        StackEntry_t
            first_stack_entry(  sampler.sampledIndices().begin(),
                                sampler.sampledIndices().end(),
                                ext_param_.class_count_);
        first_stack_entry
            .set_oob_range(     sampler.oobIndices().begin(),
                                sampler.oobIndices().end());
        trees_[ii]
            .learn( preprocessor.features(),
                    preprocessor.response(),
                    first_stack_entry,
                    split,
                    stop,
                    visitor,
                    randint);
        visitor
            .visit_after_tree(  *this,
                                preprocessor,
                                sampler,
                                first_stack_entry,
                                ii);
    }

    visitor.visit_at_end(*this, preprocessor);
    // Only for online learning?
    online_visitor_.deactivate();
}
template <class LabelType, class Tag>
template <class U, class C, class Stop>
LabelType RandomForest<LabelType, Tag>
    ::predictLabel(MultiArrayView<2, U, C> const & features, Stop & stop) const
{
    vigra_precondition(columnCount(features) >= ext_param_.column_count_,
        "RandomForest::predictLabel():"
        " Too few columns in feature matrix.");
    vigra_precondition(rowCount(features) == 1,
        "RandomForest::predictLabel():"
        " Feature matrix must have a single row.");
    typedef MultiArrayShape<2>::type Shp;
    garbage_prediction_.reshape(Shp(1, ext_param_.class_count_), 0.0);
    LabelType d;
    predictProbabilities(features, garbage_prediction_, stop);
    ext_param_.to_classlabel(argMax(garbage_prediction_), d);
    return d;
}


//Same thing as above with priors for each label !!!
template <class LabelType, class PreprocessorTag>
template <class U, class C>
LabelType RandomForest<LabelType, PreprocessorTag>
    ::predictLabel( MultiArrayView<2, U, C> const & features,
                    ArrayVectorView<double> priors) const
{
    using namespace functor;
    vigra_precondition(columnCount(features) >= ext_param_.column_count_,
        "RandomForest::predictLabel(): Too few columns in feature matrix.");
    vigra_precondition(rowCount(features) == 1,
        "RandomForest::predictLabel():"
        " Feature matrix must have a single row.");
    Matrix<double> prob(1,ext_param_.class_count_);
    predictProbabilities(features, prob);
    std::transform( prob.begin(), prob.end(),
                    priors.begin(), prob.begin(),
                    Arg1()*Arg2());
    LabelType d;
    ext_param_.to_classlabel(argMax(prob), d);
    return d;
}
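/* Example (a sketch; 'rf' and 'test_features' are placeholders): biasing a
   two-class prediction towards class 0 with a prior of 0.7. The priors
   multiply the forest's class probabilities before the argmax.

   \code
   ArrayVector<double> priors(2);
   priors[0] = 0.7;
   priors[1] = 0.3;
   int label = rf.predictLabel(rowVector(test_features, 0), priors);
   \endcode
*/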
template<class LabelType,class PreprocessorTag>
template <class T1,class T2, class C>
void RandomForest<LabelType,PreprocessorTag>
    ::predictProbabilities(OnlinePredictionSet<T1> & predictionSet,
                           MultiArrayView<2, T2, C> & prob)
{
    //Features are n x p
    //prob is n x NumOfLabel probabilities for each feature in each class

    vigra_precondition(rowCount(predictionSet.features) == rowCount(prob),
        "RandomForest::predictProbabilities():"
        " Feature matrix and probability matrix size mismatch.");
    // The feature matrix must have at least as many columns as were
    // used during training; extra columns are ignored.
    vigra_precondition( columnCount(predictionSet.features) >= ext_param_.column_count_,
        "RandomForest::predictProbabilities():"
        " Too few columns in feature matrix.");
    vigra_precondition( columnCount(prob)
                        == (MultiArrayIndex)ext_param_.class_count_,
        "RandomForest::predictProbabilities():"
        " Probability matrix must have as many columns as there are classes.");
    prob.init(0.0);
    //store total weights
    std::vector<T1> totalWeights(predictionSet.indices[0].size(),0.0);
    //Go through all trees
    int set_id=-1;
    for(int k=0; k<options_.tree_count_; ++k)
    {
        set_id=(set_id+1) % predictionSet.indices[0].size();
        typedef std::set<SampleRange<T1> > my_set;
        typedef typename my_set::iterator set_it;
        //typedef std::set<std::pair<int,SampleRange<T1> > >::iterator set_it;
        //Build a stack with all the ranges we have
        std::vector<std::pair<int,set_it> > stack;
        stack.clear();
        for(set_it i=predictionSet.ranges[set_id].begin();
            i!=predictionSet.ranges[set_id].end();++i)
            stack.push_back(std::pair<int,set_it>(2,i));
        //get the weights predicted by a single tree
        int num_decisions=0;
        while(!stack.empty())
        {
            set_it range=stack.back().second;
            int index=stack.back().first;
            stack.pop_back();
            ++num_decisions;

            if(trees_[k].isLeafNode(trees_[k].topology_[index]))
            {
                ArrayVector<double>::iterator weights=Node<e_ConstProbNode>(trees_[k].topology_,
                                                                            trees_[k].parameters_,
                                                                            index).prob_begin();
                for(int i=range->start;i!=range->end;++i)
                {
                    //update the vote count.
                    for(int l=0; l<ext_param_.class_count_; ++l)
                    {
                        prob(predictionSet.indices[set_id][i], l) += (T2)weights[l];
                        //every weight goes into totalWeights.
                        totalWeights[predictionSet.indices[set_id][i]] += (T1)weights[l];
                    }
                }
            }
            else
            {
                if(trees_[k].topology_[index]!=i_ThresholdNode)
                {
                    throw std::runtime_error("predicting with online prediction sets is only supported for RFs with threshold nodes");
                }
                Node<i_ThresholdNode> node(trees_[k].topology_,trees_[k].parameters_,index);
                if(range->min_boundaries[node.column()]>=node.threshold())
                {
                    //Everything goes to the right child
                    stack.push_back(std::pair<int,set_it>(node.child(1),range));
                    continue;
                }
                if(range->max_boundaries[node.column()]<node.threshold())
                {
                    //Everything goes to the left child
                    stack.push_back(std::pair<int,set_it>(node.child(0),range));
                    continue;
                }
                //We have to split at this node
                SampleRange<T1> new_range=*range;
                new_range.min_boundaries[node.column()]=FLT_MAX;
                range->max_boundaries[node.column()]=-FLT_MAX;
                new_range.start=new_range.end=range->end;
                int i=range->start;
                while(i!=range->end)
                {
                    //Decide for range->indices[i]
                    if(predictionSet.features(predictionSet.indices[set_id][i],node.column())>=node.threshold())
                    {
                        new_range.min_boundaries[node.column()]=std::min(new_range.min_boundaries[node.column()],
                                                                         predictionSet.features(predictionSet.indices[set_id][i],node.column()));
                        --range->end;
                        --new_range.start;
                        std::swap(predictionSet.indices[set_id][i],predictionSet.indices[set_id][range->end]);
                    }
                    else
                    {
                        range->max_boundaries[node.column()]=std::max(range->max_boundaries[node.column()],
                                                                      predictionSet.features(predictionSet.indices[set_id][i],node.column()));
                        ++i;
                    }
                }
                //The old range ...
                if(range->start==range->end)
                {
                    predictionSet.ranges[set_id].erase(range);
                }
                else
                {
                    stack.push_back(std::pair<int,set_it>(node.child(0),range));
                }
                //And the new one ...
                if(new_range.start!=new_range.end)
                {
                    std::pair<set_it,bool> new_it=predictionSet.ranges[set_id].insert(new_range);
                    stack.push_back(std::pair<int,set_it>(node.child(1),new_it.first));
                }
            }
        }
        predictionSet.cumulativePredTime[k]=num_decisions;
    }
    for(unsigned int i=0;i<totalWeights.size();++i)
    {
        double test=0.0;
        //Normalise the votes in each row by the total vote count (totalWeights)
        for(int l=0; l<ext_param_.class_count_; ++l)
        {
            test+=prob(i,l);
            prob(i, l) /= totalWeights[i];
        }
        assert(test==totalWeights[i]);
        assert(totalWeights[i]>0.0);
    }
}
template <class LabelType, class PreprocessorTag>
template <class U, class C1, class T, class C2, class Stop_t>
void RandomForest<LabelType, PreprocessorTag>
    ::predictProbabilities(MultiArrayView<2, U, C1>const & features,
                           MultiArrayView<2, T, C2> & prob,
                           Stop_t & stop_) const
{
    //Features are n x p
    //prob is n x NumOfLabel probabilities for each feature in each class

    vigra_precondition(rowCount(features) == rowCount(prob),
        "RandomForest::predictProbabilities():"
        " Feature matrix and probability matrix size mismatch.");

    // The feature matrix must have at least as many columns as were
    // used during training; extra columns are ignored.
    vigra_precondition( columnCount(features) >= ext_param_.column_count_,
        "RandomForest::predictProbabilities():"
        " Too few columns in feature matrix.");
    vigra_precondition( columnCount(prob)
                        == (MultiArrayIndex)ext_param_.class_count_,
        "RandomForest::predictProbabilities():"
        " Probability matrix must have as many columns as there are classes.");

    #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_>
    Default_Stop_t default_stop(options_);
    typename RF_CHOOSER(Stop_t)::type & stop
            = RF_CHOOSER(Stop_t)::choose(stop_, default_stop);
    #undef RF_CHOOSER
    stop.set_external_parameters(ext_param_, tree_count());
    prob.init(NumericTraits<T>::zero());
    /* This code was originally there for testing early stopping
     * - we wanted the order of the trees to be randomized
    if(tree_indices_.size() != 0)
    {
        std::random_shuffle(tree_indices_.begin(),
                            tree_indices_.end());
    }
    */
    //Classify each row.
    for(int row=0; row < rowCount(features); ++row)
    {
        ArrayVector<double>::const_iterator weights;

        //totalWeight == totalVoteCount!
        double totalWeight = 0.0;

        //Let each tree classify...
        for(int k=0; k<options_.tree_count_; ++k)
        {
            //get the weights predicted by a single tree
            weights = trees_[k /*tree_indices_[k]*/].predict(rowVector(features, row));

            //update the vote count. If weighted prediction is enabled,
            //multiply the per-class probability by the leaf's total weight
            //(stored directly before prob_begin()) to recover raw counts.
            int weighted = options_.predict_weighted_;
            for(int l=0; l<ext_param_.class_count_; ++l)
            {
                double cur_w = weights[l] * (weighted * (*(weights-1))
                                           + (1-weighted));
                prob(row, l) += (T)cur_w;
                //every weight goes into totalWeight.
                totalWeight += cur_w;
            }
            if(stop.after_prediction(weights,
                                     k,
                                     rowVector(prob, row),
                                     totalWeight))
            {
                break;
            }
        }

        //Normalise the votes in each row by the total vote count (totalWeight)
        for(int l=0; l< ext_param_.class_count_; ++l)
        {
            prob(row, l) /= detail::RequiresExplicitCast<T>::cast(totalWeight);
        }
    }
}
template <class LabelType, class PreprocessorTag>
template <class U, class C1, class T, class C2>
void RandomForest<LabelType, PreprocessorTag>
    ::predictRaw(MultiArrayView<2, U, C1>const & features,
                 MultiArrayView<2, T, C2> & prob) const
{
    //Features are n x p
    //prob is n x NumOfLabel probabilities for each feature in each class

    vigra_precondition(rowCount(features) == rowCount(prob),
        "RandomForest::predictRaw():"
        " Feature matrix and probability matrix size mismatch.");

    // The feature matrix must have at least as many columns as were
    // used during training; extra columns are ignored.
    vigra_precondition( columnCount(features) >= ext_param_.column_count_,
        "RandomForest::predictRaw():"
        " Too few columns in feature matrix.");
    vigra_precondition( columnCount(prob)
                        == (MultiArrayIndex)ext_param_.class_count_,
        "RandomForest::predictRaw():"
        " Probability matrix must have as many columns as there are classes.");

    prob.init(NumericTraits<T>::zero());
    /* This code was originally there for testing early stopping
     * - we wanted the order of the trees to be randomized
    if(tree_indices_.size() != 0)
    {
        std::random_shuffle(tree_indices_.begin(),
                            tree_indices_.end());
    }
    */
    //Classify each row.
    for(int row=0; row < rowCount(features); ++row)
    {
        ArrayVector<double>::const_iterator weights;

        //totalWeight == totalVoteCount!
        double totalWeight = 0.0;

        //Let each tree classify...
        for(int k=0; k<options_.tree_count_; ++k)
        {
            //get the weights predicted by a single tree
            weights = trees_[k /*tree_indices_[k]*/].predict(rowVector(features, row));

            //update the vote count (see predictProbabilities() above)
            int weighted = options_.predict_weighted_;
            for(int l=0; l<ext_param_.class_count_; ++l)
            {
                double cur_w = weights[l] * (weighted * (*(weights-1))
                                           + (1-weighted));
                prob(row, l) += (T)cur_w;
                //every weight goes into totalWeight.
                totalWeight += cur_w;
            }
        }
    }
    prob /= options_.tree_count_;
}

//@}

} // namespace vigra

#include "random_forest/rf_algorithm.hxx"
#endif // VIGRA_RANDOM_FOREST_HXX