[ VIGRA Homepage | Function Index | Class Index | Namespaces | File List | Main Page ]
vigra/random_forest/rf_common.hxx
00001 /************************************************************************/ 00002 /* */ 00003 /* Copyright 2008-2009 by Ullrich Koethe and Rahul Nair */ 00004 /* */ 00005 /* This file is part of the VIGRA computer vision library. */ 00006 /* The VIGRA Website is */ 00007 /* http://hci.iwr.uni-heidelberg.de/vigra/ */ 00008 /* Please direct questions, bug reports, and contributions to */ 00009 /* ullrich.koethe@iwr.uni-heidelberg.de or */ 00010 /* vigra@informatik.uni-hamburg.de */ 00011 /* */ 00012 /* Permission is hereby granted, free of charge, to any person */ 00013 /* obtaining a copy of this software and associated documentation */ 00014 /* files (the "Software"), to deal in the Software without */ 00015 /* restriction, including without limitation the rights to use, */ 00016 /* copy, modify, merge, publish, distribute, sublicense, and/or */ 00017 /* sell copies of the Software, and to permit persons to whom the */ 00018 /* Software is furnished to do so, subject to the following */ 00019 /* conditions: */ 00020 /* */ 00021 /* The above copyright notice and this permission notice shall be */ 00022 /* included in all copies or substantial portions of the */ 00023 /* Software. */ 00024 /* */ 00025 /* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND */ 00026 /* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES */ 00027 /* OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND */ 00028 /* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT */ 00029 /* HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, */ 00030 /* WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING */ 00031 /* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR */ 00032 /* OTHER DEALINGS IN THE SOFTWARE. 
*/ 00033 /* */ 00034 /************************************************************************/ 00035 00036 00037 #ifndef VIGRA_RF_COMMON_HXX 00038 #define VIGRA_RF_COMMON_HXX 00039 00040 namespace vigra 00041 { 00042 00043 00044 struct ClassificationTag 00045 {}; 00046 00047 struct RegressionTag 00048 {}; 00049 00050 namespace detail 00051 { 00052 class RF_DEFAULT; 00053 } 00054 inline detail::RF_DEFAULT& rf_default(); 00055 namespace detail 00056 { 00057 00058 /* \brief singleton default tag class - 00059 * 00060 * use the rf_default() factory function to use the tag. 00061 * \sa RandomForest<>::learn(); 00062 */ 00063 class RF_DEFAULT 00064 { 00065 private: 00066 RF_DEFAULT() 00067 {} 00068 public: 00069 friend RF_DEFAULT& ::vigra::rf_default(); 00070 00071 /** ok workaround for automatic choice of the decisiontree 00072 * stackentry. 00073 */ 00074 }; 00075 00076 /* \brief chooses between default type and type supplied 00077 * 00078 * This is an internal class and you shouldn't really care about it. 00079 * Just pass on used in RandomForest.learn() 00080 * Usage: 00081 *\code 00082 * // example: use container type supplied by user or ArrayVector if 00083 * // rf_default() was specified as argument; 00084 * template<class Container_t> 00085 * void do_some_foo(Container_t in) 00086 * { 00087 * typedef ArrayVector<int> Default_Container_t; 00088 * Default_Container_t default_value; 00089 * Value_Chooser<Container_t, Default_Container_t> 00090 * choose(in, default_value); 00091 * 00092 * // if the user didn't care and the in was of type 00093 * // RF_DEFAULT then default_value is used. 
00094 * do_some_more_foo(choose.value()); 00095 * } 00096 * Value_Chooser choose_val<Type, Default_Type> 00097 *\endcode 00098 */ 00099 template<class T, class C> 00100 class Value_Chooser 00101 { 00102 public: 00103 typedef T type; 00104 static T & choose(T & t, C &) 00105 { 00106 return t; 00107 } 00108 }; 00109 00110 template<class C> 00111 class Value_Chooser<detail::RF_DEFAULT, C> 00112 { 00113 public: 00114 typedef C type; 00115 00116 static C & choose(detail::RF_DEFAULT &, C & c) 00117 { 00118 return c; 00119 } 00120 }; 00121 00122 00123 00124 00125 } //namespace detail 00126 00127 00128 /**\brief factory function to return a RF_DEFAULT tag 00129 * \sa RandomForest<>::learn() 00130 */ 00131 detail::RF_DEFAULT& rf_default() 00132 { 00133 static detail::RF_DEFAULT result; 00134 return result; 00135 } 00136 00137 /** tags used with the RandomForestOptions class 00138 * \sa RF_Traits::Option_t 00139 */ 00140 enum RF_OptionTag { RF_EQUAL, 00141 RF_PROPORTIONAL, 00142 RF_EXTERNAL, 00143 RF_NONE, 00144 RF_FUNCTION, 00145 RF_LOG, 00146 RF_SQRT, 00147 RF_CONST, 00148 RF_ALL}; 00149 00150 00151 /** \addtogroup MachineLearning 00152 **/ 00153 //@{ 00154 00155 /**\brief Options object for the random forest 00156 * 00157 * usage: 00158 * RandomForestOptions a = RandomForestOptions() 00159 * .param1(value1) 00160 * .param2(value2) 00161 * ... 00162 * 00163 * This class only contains options/parameters that are not problem 00164 * dependent. The ProblemSpec class contains methods to set class weights 00165 * if necessary. 00166 * 00167 * Note that the return value of all methods is *this which makes 00168 * concatenating of options as above possible. 
00169 */ 00170 class RandomForestOptions 00171 { 00172 public: 00173 /**\name sampling options*/ 00174 /*\{*/ 00175 // look at the member access functions for documentation 00176 double training_set_proportion_; 00177 int training_set_size_; 00178 int (*training_set_func_)(int); 00179 RF_OptionTag 00180 training_set_calc_switch_; 00181 00182 bool sample_with_replacement_; 00183 RF_OptionTag 00184 stratification_method_; 00185 00186 00187 /**\name general random forest options 00188 * 00189 * these usually will be used by most split functors and 00190 * stopping predicates 00191 */ 00192 /*\{*/ 00193 RF_OptionTag mtry_switch_; 00194 int mtry_; 00195 int (*mtry_func_)(int) ; 00196 00197 bool predict_weighted_; 00198 int tree_count_; 00199 int min_split_node_size_; 00200 bool prepare_online_learning_; 00201 /*\}*/ 00202 00203 typedef ArrayVector<double> double_array; 00204 typedef std::map<std::string, double_array> map_type; 00205 00206 int serialized_size() const 00207 { 00208 return 12; 00209 } 00210 00211 00212 bool operator==(RandomForestOptions & rhs) const 00213 { 00214 bool result = true; 00215 #define COMPARE(field) result = result && (this->field == rhs.field); 00216 COMPARE(training_set_proportion_); 00217 COMPARE(training_set_size_); 00218 COMPARE(training_set_calc_switch_); 00219 COMPARE(sample_with_replacement_); 00220 COMPARE(stratification_method_); 00221 COMPARE(mtry_switch_); 00222 COMPARE(mtry_); 00223 COMPARE(tree_count_); 00224 COMPARE(min_split_node_size_); 00225 COMPARE(predict_weighted_); 00226 #undef COMPARE 00227 00228 return result; 00229 } 00230 bool operator!=(RandomForestOptions & rhs_) const 00231 { 00232 return !(*this == rhs_); 00233 } 00234 template<class Iter> 00235 void unserialize(Iter const & begin, Iter const & end) 00236 { 00237 Iter iter = begin; 00238 vigra_precondition(static_cast<int>(end - begin) == serialized_size(), 00239 "RandomForestOptions::unserialize():" 00240 "wrong number of parameters"); 00241 #define PULL(item_, 
type_) item_ = type_(*iter); ++iter; 00242 PULL(training_set_proportion_, double); 00243 PULL(training_set_size_, int); 00244 ++iter; //PULL(training_set_func_, double); 00245 PULL(training_set_calc_switch_, (RF_OptionTag)int); 00246 PULL(sample_with_replacement_, 0 != ); 00247 PULL(stratification_method_, (RF_OptionTag)int); 00248 PULL(mtry_switch_, (RF_OptionTag)int); 00249 PULL(mtry_, int); 00250 ++iter; //PULL(mtry_func_, double); 00251 PULL(tree_count_, int); 00252 PULL(min_split_node_size_, int); 00253 PULL(predict_weighted_, 0 !=); 00254 #undef PULL 00255 } 00256 template<class Iter> 00257 void serialize(Iter const & begin, Iter const & end) const 00258 { 00259 Iter iter = begin; 00260 vigra_precondition(static_cast<int>(end - begin) == serialized_size(), 00261 "RandomForestOptions::serialize():" 00262 "wrong number of parameters"); 00263 #define PUSH(item_) *iter = double(item_); ++iter; 00264 PUSH(training_set_proportion_); 00265 PUSH(training_set_size_); 00266 if(training_set_func_ != 0) 00267 { 00268 PUSH(1); 00269 } 00270 else 00271 { 00272 PUSH(0); 00273 } 00274 PUSH(training_set_calc_switch_); 00275 PUSH(sample_with_replacement_); 00276 PUSH(stratification_method_); 00277 PUSH(mtry_switch_); 00278 PUSH(mtry_); 00279 if(mtry_func_ != 0) 00280 { 00281 PUSH(1); 00282 } 00283 else 00284 { 00285 PUSH(0); 00286 } 00287 PUSH(tree_count_); 00288 PUSH(min_split_node_size_); 00289 PUSH(predict_weighted_); 00290 #undef PUSH 00291 } 00292 00293 void make_from_map(map_type & in) // -> const: .operator[] -> .find 00294 { 00295 typedef MultiArrayShape<2>::type Shp; 00296 #define PULL(item_, type_) item_ = type_(in[#item_][0]); 00297 #define PULLBOOL(item_, type_) item_ = type_(in[#item_][0] > 0); 00298 PULL(training_set_proportion_,double); 00299 PULL(training_set_size_, int); 00300 PULL(mtry_, int); 00301 PULL(tree_count_, int); 00302 PULL(min_split_node_size_, int); 00303 PULLBOOL(sample_with_replacement_, bool); 00304 PULLBOOL(prepare_online_learning_, bool); 
00305 PULLBOOL(predict_weighted_, bool); 00306 00307 PULL(training_set_calc_switch_, (RF_OptionTag)(int)); 00308 00309 PULL(stratification_method_, (RF_OptionTag)(int)); 00310 PULL(mtry_switch_, (RF_OptionTag)(int)); 00311 00312 /*don't pull*/ 00313 //PULL(mtry_func_!=0, int); 00314 //PULL(training_set_func,int); 00315 #undef PULL 00316 #undef PULLBOOL 00317 } 00318 void make_map(map_type & in) const 00319 { 00320 typedef MultiArrayShape<2>::type Shp; 00321 #define PUSH(item_, type_) in[#item_] = double_array(1, double(item_)); 00322 #define PUSHFUNC(item_, type_) in[#item_] = double_array(1, double(item_!=0)); 00323 PUSH(training_set_proportion_,double); 00324 PUSH(training_set_size_, int); 00325 PUSH(mtry_, int); 00326 PUSH(tree_count_, int); 00327 PUSH(min_split_node_size_, int); 00328 PUSH(sample_with_replacement_, bool); 00329 PUSH(prepare_online_learning_, bool); 00330 PUSH(predict_weighted_, bool); 00331 00332 PUSH(training_set_calc_switch_, RF_OptionTag); 00333 PUSH(stratification_method_, RF_OptionTag); 00334 PUSH(mtry_switch_, RF_OptionTag); 00335 00336 PUSHFUNC(mtry_func_, int); 00337 PUSHFUNC(training_set_func_,int); 00338 #undef PUSH 00339 #undef PUSHFUNC 00340 } 00341 00342 00343 /**\brief create a RandomForestOptions object with default initialisation. 
00344 * 00345 * look at the other member functions for more information on default 00346 * values 00347 */ 00348 RandomForestOptions() 00349 : 00350 training_set_proportion_(1.0), 00351 training_set_size_(0), 00352 training_set_func_(0), 00353 training_set_calc_switch_(RF_PROPORTIONAL), 00354 sample_with_replacement_(true), 00355 stratification_method_(RF_NONE), 00356 mtry_switch_(RF_SQRT), 00357 mtry_(0), 00358 mtry_func_(0), 00359 predict_weighted_(false), 00360 tree_count_(256), 00361 min_split_node_size_(1), 00362 prepare_online_learning_(false) 00363 {} 00364 00365 /**\brief specify stratification strategy 00366 * 00367 * default: RF_NONE 00368 * possible values: RF_EQUAL, RF_PROPORTIONAL, 00369 * RF_EXTERNAL, RF_NONE 00370 * RF_EQUAL: get equal amount of samples per class. 00371 * RF_PROPORTIONAL: sample proportional to fraction of class samples 00372 * in population 00373 * RF_EXTERNAL: strata_weights_ field of the ProblemSpec_t object 00374 * has been set externally. (defunct) 00375 */ 00376 RandomForestOptions & use_stratification(RF_OptionTag in) 00377 { 00378 vigra_precondition(in == RF_EQUAL || 00379 in == RF_PROPORTIONAL || 00380 in == RF_EXTERNAL || 00381 in == RF_NONE, 00382 "RandomForestOptions::use_stratification()" 00383 "input must be RF_EQUAL, RF_PROPORTIONAL," 00384 "RF_EXTERNAL or RF_NONE"); 00385 stratification_method_ = in; 00386 return *this; 00387 } 00388 00389 RandomForestOptions & prepare_online_learning(bool in) 00390 { 00391 prepare_online_learning_=in; 00392 return *this; 00393 } 00394 00395 /**\brief sample from training population with or without replacement? 00396 * 00397 * <br> Default: true 00398 */ 00399 RandomForestOptions & sample_with_replacement(bool in) 00400 { 00401 sample_with_replacement_ = in; 00402 return *this; 00403 } 00404 00405 /**\brief specify the fraction of the total number of samples 00406 * used per tree for learning. 
00407 * 00408 * This value should be in [0.0 1.0] if sampling without 00409 * replacement has been specified. 00410 * 00411 * <br> default : 1.0 00412 */ 00413 RandomForestOptions & samples_per_tree(double in) 00414 { 00415 training_set_proportion_ = in; 00416 training_set_calc_switch_ = RF_PROPORTIONAL; 00417 return *this; 00418 } 00419 00420 /**\brief directly specify the number of samples per tree 00421 */ 00422 RandomForestOptions & samples_per_tree(int in) 00423 { 00424 training_set_size_ = in; 00425 training_set_calc_switch_ = RF_CONST; 00426 return *this; 00427 } 00428 00429 /**\brief use external function to calculate the number of samples each 00430 * tree should be learnt with. 00431 * 00432 * \param in function pointer that takes the number of rows in the 00433 * learning data and outputs the number samples per tree. 00434 */ 00435 RandomForestOptions & samples_per_tree(int (*in)(int)) 00436 { 00437 training_set_func_ = in; 00438 training_set_calc_switch_ = RF_FUNCTION; 00439 return *this; 00440 } 00441 00442 /**\brief weight each tree with number of samples in that node 00443 */ 00444 RandomForestOptions & predict_weighted() 00445 { 00446 predict_weighted_ = true; 00447 return *this; 00448 } 00449 00450 /**\brief use built in mapping to calculate mtry 00451 * 00452 * Use one of the built in mappings to calculate mtry from the number 00453 * of columns in the input feature data. 00454 * \param in possible values: RF_LOG, RF_SQRT or RF_ALL 00455 * <br> default: RF_SQRT. 00456 */ 00457 RandomForestOptions & features_per_node(RF_OptionTag in) 00458 { 00459 vigra_precondition(in == RF_LOG || 00460 in == RF_SQRT|| 00461 in == RF_ALL, 00462 "RandomForestOptions()::features_per_node():" 00463 "input must be of type RF_LOG or RF_SQRT"); 00464 mtry_switch_ = in; 00465 return *this; 00466 } 00467 00468 /**\brief Set mtry to a constant value 00469 * 00470 * mtry is the number of columns/variates/variables randomly chosen 00471 * to select the best split from. 
00472 * 00473 */ 00474 RandomForestOptions & features_per_node(int in) 00475 { 00476 mtry_ = in; 00477 mtry_switch_ = RF_CONST; 00478 return *this; 00479 } 00480 00481 /**\brief use a external function to calculate mtry 00482 * 00483 * \param in function pointer that takes int (number of columns 00484 * of the and outputs int (mtry) 00485 */ 00486 RandomForestOptions & features_per_node(int(*in)(int)) 00487 { 00488 mtry_func_ = in; 00489 mtry_switch_ = RF_FUNCTION; 00490 return *this; 00491 } 00492 00493 /** How many trees to create? 00494 * 00495 * <br> Default: 255. 00496 */ 00497 RandomForestOptions & tree_count(int in) 00498 { 00499 tree_count_ = in; 00500 return *this; 00501 } 00502 00503 /**\brief Number of examples required for a node to be split. 00504 * 00505 * When the number of examples in a node is below this number, 00506 * the node is not split even if class separation is not yet perfect. 00507 * Instead, the node returns the proportion of each class 00508 * (among the remaining examples) during the prediction phase. 00509 * <br> Default: 1 (complete growing) 00510 */ 00511 RandomForestOptions & min_split_node_size(int in) 00512 { 00513 min_split_node_size_ = in; 00514 return *this; 00515 } 00516 }; 00517 00518 00519 /** \brief problem types 00520 */ 00521 enum Problem_t{REGRESSION, CLASSIFICATION, CHECKLATER}; 00522 00523 00524 /** \brief problem specification class for the random forest. 00525 * 00526 * This class contains all the problem specific parameters the random 00527 * forest needs for learning. Specification of an instance of this class 00528 * is optional as all necessary fields will be computed prior to learning 00529 * if not specified. 
00530 * 00531 * if needed usage is similar to that of RandomForestOptions 00532 */ 00533 00534 template<class LabelType = double> 00535 class ProblemSpec 00536 { 00537 00538 00539 public: 00540 00541 /** \brief problem class 00542 */ 00543 00544 typedef LabelType Label_t; 00545 ArrayVector<Label_t> classes; 00546 typedef ArrayVector<double> double_array; 00547 typedef std::map<std::string, double_array> map_type; 00548 00549 int column_count_; // number of features 00550 int class_count_; // number of classes 00551 int row_count_; // number of samples 00552 00553 int actual_mtry_; // mtry used in training 00554 int actual_msample_; // number if in-bag samples per tree 00555 00556 Problem_t problem_type_; // classification or regression 00557 00558 int used_; // this ProblemSpec is valid 00559 ArrayVector<double> class_weights_; // if classes have different importance 00560 int is_weighted_; // class_weights_ are used 00561 double precision_; // termination criterion for regression loss 00562 int response_size_; 00563 00564 template<class T> 00565 void to_classlabel(int index, T & out) const 00566 { 00567 out = T(classes[index]); 00568 } 00569 template<class T> 00570 int to_classIndex(T index) const 00571 { 00572 return std::find(classes.begin(), classes.end(), index) - classes.begin(); 00573 } 00574 00575 #define EQUALS(field) field(rhs.field) 00576 ProblemSpec(ProblemSpec const & rhs) 00577 : 00578 EQUALS(column_count_), 00579 EQUALS(class_count_), 00580 EQUALS(row_count_), 00581 EQUALS(actual_mtry_), 00582 EQUALS(actual_msample_), 00583 EQUALS(problem_type_), 00584 EQUALS(used_), 00585 EQUALS(class_weights_), 00586 EQUALS(is_weighted_), 00587 EQUALS(precision_), 00588 EQUALS(response_size_) 00589 { 00590 std::back_insert_iterator<ArrayVector<Label_t> > 00591 iter(classes); 00592 std::copy(rhs.classes.begin(), rhs.classes.end(), iter); 00593 } 00594 #undef EQUALS 00595 #define EQUALS(field) field(rhs.field) 00596 template<class T> 00597 ProblemSpec(ProblemSpec<T> 
const & rhs) 00598 : 00599 EQUALS(column_count_), 00600 EQUALS(class_count_), 00601 EQUALS(row_count_), 00602 EQUALS(actual_mtry_), 00603 EQUALS(actual_msample_), 00604 EQUALS(problem_type_), 00605 EQUALS(used_), 00606 EQUALS(class_weights_), 00607 EQUALS(is_weighted_), 00608 EQUALS(precision_), 00609 EQUALS(response_size_) 00610 { 00611 std::back_insert_iterator<ArrayVector<Label_t> > 00612 iter(classes); 00613 std::copy(rhs.classes.begin(), rhs.classes.end(), iter); 00614 } 00615 #undef EQUALS 00616 00617 #define EQUALS(field) (this->field = rhs.field); 00618 ProblemSpec & operator=(ProblemSpec const & rhs) 00619 { 00620 EQUALS(column_count_); 00621 EQUALS(class_count_); 00622 EQUALS(row_count_); 00623 EQUALS(actual_mtry_); 00624 EQUALS(actual_msample_); 00625 EQUALS(problem_type_); 00626 EQUALS(used_); 00627 EQUALS(is_weighted_); 00628 EQUALS(precision_); 00629 EQUALS(response_size_) 00630 class_weights_.clear(); 00631 std::back_insert_iterator<ArrayVector<double> > 00632 iter2(class_weights_); 00633 std::copy(rhs.class_weights_.begin(), rhs.class_weights_.end(), iter2); 00634 classes.clear(); 00635 std::back_insert_iterator<ArrayVector<Label_t> > 00636 iter(classes); 00637 std::copy(rhs.classes.begin(), rhs.classes.end(), iter); 00638 return *this; 00639 } 00640 00641 template<class T> 00642 ProblemSpec<Label_t> & operator=(ProblemSpec<T> const & rhs) 00643 { 00644 EQUALS(column_count_); 00645 EQUALS(class_count_); 00646 EQUALS(row_count_); 00647 EQUALS(actual_mtry_); 00648 EQUALS(actual_msample_); 00649 EQUALS(problem_type_); 00650 EQUALS(used_); 00651 EQUALS(is_weighted_); 00652 EQUALS(precision_); 00653 EQUALS(response_size_) 00654 class_weights_.clear(); 00655 std::back_insert_iterator<ArrayVector<double> > 00656 iter2(class_weights_); 00657 std::copy(rhs.class_weights_.begin(), rhs.class_weights_.end(), iter2); 00658 classes.clear(); 00659 std::back_insert_iterator<ArrayVector<Label_t> > 00660 iter(classes); 00661 std::copy(rhs.classes.begin(), 
rhs.classes.end(), iter); 00662 return *this; 00663 } 00664 #undef EQUALS 00665 00666 template<class T> 00667 bool operator==(ProblemSpec<T> const & rhs) 00668 { 00669 bool result = true; 00670 #define COMPARE(field) result = result && (this->field == rhs.field); 00671 COMPARE(column_count_); 00672 COMPARE(class_count_); 00673 COMPARE(row_count_); 00674 COMPARE(actual_mtry_); 00675 COMPARE(actual_msample_); 00676 COMPARE(problem_type_); 00677 COMPARE(is_weighted_); 00678 COMPARE(precision_); 00679 COMPARE(used_); 00680 COMPARE(class_weights_); 00681 COMPARE(classes); 00682 COMPARE(response_size_) 00683 #undef COMPARE 00684 return result; 00685 } 00686 00687 bool operator!=(ProblemSpec & rhs) 00688 { 00689 return !(*this == rhs); 00690 } 00691 00692 00693 size_t serialized_size() const 00694 { 00695 return 9 + class_count_ *int(is_weighted_+1); 00696 } 00697 00698 00699 template<class Iter> 00700 void unserialize(Iter const & begin, Iter const & end) 00701 { 00702 Iter iter = begin; 00703 vigra_precondition(end - begin >= 9, 00704 "ProblemSpec::unserialize():" 00705 "wrong number of parameters"); 00706 #define PULL(item_, type_) item_ = type_(*iter); ++iter; 00707 PULL(column_count_,int); 00708 PULL(class_count_, int); 00709 00710 vigra_precondition(end - begin >= 9 + class_count_, 00711 "ProblemSpec::unserialize(): 1"); 00712 PULL(row_count_, int); 00713 PULL(actual_mtry_,int); 00714 PULL(actual_msample_, int); 00715 PULL(problem_type_, Problem_t); 00716 PULL(is_weighted_, int); 00717 PULL(used_, int); 00718 PULL(precision_, double); 00719 PULL(response_size_, int); 00720 if(is_weighted_) 00721 { 00722 vigra_precondition(end - begin == 9 + 2*class_count_, 00723 "ProblemSpec::unserialize(): 2"); 00724 class_weights_.insert(class_weights_.end(), 00725 iter, 00726 iter + class_count_); 00727 iter += class_count_; 00728 } 00729 classes.insert(classes.end(), iter, end); 00730 #undef PULL 00731 } 00732 00733 00734 template<class Iter> 00735 void serialize(Iter const & 
begin, Iter const & end) const 00736 { 00737 Iter iter = begin; 00738 vigra_precondition(end - begin == serialized_size(), 00739 "RandomForestOptions::serialize():" 00740 "wrong number of parameters"); 00741 #define PUSH(item_) *iter = double(item_); ++iter; 00742 PUSH(column_count_); 00743 PUSH(class_count_) 00744 PUSH(row_count_); 00745 PUSH(actual_mtry_); 00746 PUSH(actual_msample_); 00747 PUSH(problem_type_); 00748 PUSH(is_weighted_); 00749 PUSH(used_); 00750 PUSH(precision_); 00751 PUSH(response_size_); 00752 if(is_weighted_) 00753 { 00754 std::copy(class_weights_.begin(), 00755 class_weights_.end(), 00756 iter); 00757 iter += class_count_; 00758 } 00759 std::copy(classes.begin(), 00760 classes.end(), 00761 iter); 00762 #undef PUSH 00763 } 00764 00765 void make_from_map(map_type & in) // -> const: .operator[] -> .find 00766 { 00767 typedef MultiArrayShape<2>::type Shp; 00768 #define PULL(item_, type_) item_ = type_(in[#item_][0]); 00769 PULL(column_count_,int); 00770 PULL(class_count_, int); 00771 PULL(row_count_, int); 00772 PULL(actual_mtry_,int); 00773 PULL(actual_msample_, int); 00774 PULL(problem_type_, (Problem_t)int); 00775 PULL(is_weighted_, int); 00776 PULL(used_, int); 00777 PULL(precision_, double); 00778 PULL(response_size_, int); 00779 class_weights_ = in["class_weights_"]; 00780 #undef PUSH 00781 } 00782 void make_map(map_type & in) const 00783 { 00784 typedef MultiArrayShape<2>::type Shp; 00785 #define PUSH(item_) in[#item_] = double_array(1, double(item_)); 00786 PUSH(column_count_); 00787 PUSH(class_count_) 00788 PUSH(row_count_); 00789 PUSH(actual_mtry_); 00790 PUSH(actual_msample_); 00791 PUSH(problem_type_); 00792 PUSH(is_weighted_); 00793 PUSH(used_); 00794 PUSH(precision_); 00795 PUSH(response_size_); 00796 in["class_weights_"] = class_weights_; 00797 #undef PUSH 00798 } 00799 00800 /**\brief set default values (-> values not set) 00801 */ 00802 ProblemSpec() 00803 : column_count_(0), 00804 class_count_(0), 00805 row_count_(0), 00806 
actual_mtry_(0), 00807 actual_msample_(0), 00808 problem_type_(CHECKLATER), 00809 used_(false), 00810 is_weighted_(false), 00811 precision_(0.0), 00812 response_size_(1) 00813 {} 00814 00815 00816 ProblemSpec & column_count(int in) 00817 { 00818 column_count_ = in; 00819 return *this; 00820 } 00821 00822 /**\brief supply with class labels - 00823 * 00824 * the preprocessor will not calculate the labels needed in this case. 00825 */ 00826 template<class C_Iter> 00827 ProblemSpec & classes_(C_Iter begin, C_Iter end) 00828 { 00829 int size = end-begin; 00830 for(int k=0; k<size; ++k, ++begin) 00831 classes.push_back(detail::RequiresExplicitCast<LabelType>::cast(*begin)); 00832 class_count_ = size; 00833 return *this; 00834 } 00835 00836 /** \brief supply with class weights - 00837 * 00838 * this is the only case where you would really have to 00839 * create a ProblemSpec object. 00840 */ 00841 template<class W_Iter> 00842 ProblemSpec & class_weights(W_Iter begin, W_Iter end) 00843 { 00844 class_weights_.insert(class_weights_.end(), begin, end); 00845 is_weighted_ = true; 00846 return *this; 00847 } 00848 00849 00850 00851 void clear() 00852 { 00853 used_ = false; 00854 classes.clear(); 00855 class_weights_.clear(); 00856 column_count_ = 0 ; 00857 class_count_ = 0; 00858 actual_mtry_ = 0; 00859 actual_msample_ = 0; 00860 problem_type_ = CHECKLATER; 00861 is_weighted_ = false; 00862 precision_ = 0.0; 00863 response_size_ = 0; 00864 00865 } 00866 00867 bool used() const 00868 { 00869 return used_ != 0; 00870 } 00871 }; 00872 00873 00874 //@} 00875 00876 00877 00878 /**\brief Standard early stopping criterion 00879 * 00880 * Stop if region.size() < min_split_node_size_; 00881 */ 00882 class EarlyStoppStd 00883 { 00884 public: 00885 int min_split_node_size_; 00886 00887 template<class Opt> 00888 EarlyStoppStd(Opt opt) 00889 : min_split_node_size_(opt.min_split_node_size_) 00890 {} 00891 00892 template<class T> 00893 void set_external_parameters(ProblemSpec<T>const &, int 
/* tree_count */ = 0, bool /* is_weighted_ */ = false) 00894 {} 00895 00896 template<class Region> 00897 bool operator()(Region& region) 00898 { 00899 return region.size() < min_split_node_size_; 00900 } 00901 00902 template<class WeightIter, class T, class C> 00903 bool after_prediction(WeightIter, int /* k */, MultiArrayView<2, T, C> /* prob */, double /* totalCt */) 00904 { 00905 return false; 00906 } 00907 }; 00908 00909 00910 } // namespace vigra 00911 00912 #endif //VIGRA_RF_COMMON_HXX
© Ullrich Köthe (ullrich.koethe@iwr.uni-heidelberg.de) |
html generated using doxygen and Python
|