[ VIGRA Homepage | Function Index | Class Index | Namespaces | File List | Main Page ]
vigra/random_forest/rf_decisionTree.hxx
00001 /************************************************************************/ 00002 /* */ 00003 /* Copyright 2008-2009 by Ullrich Koethe and Rahul Nair */ 00004 /* */ 00005 /* This file is part of the VIGRA computer vision library. */ 00006 /* The VIGRA Website is */ 00007 /* http://hci.iwr.uni-heidelberg.de/vigra/ */ 00008 /* Please direct questions, bug reports, and contributions to */ 00009 /* ullrich.koethe@iwr.uni-heidelberg.de or */ 00010 /* vigra@informatik.uni-hamburg.de */ 00011 /* */ 00012 /* Permission is hereby granted, free of charge, to any person */ 00013 /* obtaining a copy of this software and associated documentation */ 00014 /* files (the "Software"), to deal in the Software without */ 00015 /* restriction, including without limitation the rights to use, */ 00016 /* copy, modify, merge, publish, distribute, sublicense, and/or */ 00017 /* sell copies of the Software, and to permit persons to whom the */ 00018 /* Software is furnished to do so, subject to the following */ 00019 /* conditions: */ 00020 /* */ 00021 /* The above copyright notice and this permission notice shall be */ 00022 /* included in all copies or substantial portions of the */ 00023 /* Software. */ 00024 /* */ 00025 /* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND */ 00026 /* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES */ 00027 /* OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND */ 00028 /* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT */ 00029 /* HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, */ 00030 /* WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING */ 00031 /* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR */ 00032 /* OTHER DEALINGS IN THE SOFTWARE. 
*/ 00033 /* */ 00034 /************************************************************************/ 00035 00036 #ifndef VIGRA_RANDOM_FOREST_DT_HXX 00037 #define VIGRA_RANDOM_FOREST_DT_HXX 00038 00039 #include <algorithm> 00040 #include <map> 00041 #include <numeric> 00042 #include "vigra/multi_array.hxx" 00043 #include "vigra/mathutil.hxx" 00044 #include "vigra/array_vector.hxx" 00045 #include "vigra/sized_int.hxx" 00046 #include "vigra/matrix.hxx" 00047 #include "vigra/random.hxx" 00048 #include "vigra/functorexpression.hxx" 00049 #include <vector> 00050 00051 #include "rf_common.hxx" 00052 #include "rf_visitors.hxx" 00053 #include "rf_nodeproxy.hxx" 00054 namespace vigra 00055 { 00056 00057 namespace detail 00058 { 00059 // todo FINALLY DECIDE TO USE CAMEL CASE OR UNDERSCORES !!!!!! 00060 /* decisiontree classifier. 00061 * 00062 * This class is actually meant to be used in conjunction with the 00063 * Random Forest Classifier 00064 * - My suggestion would be to use the RandomForest classifier with 00065 * following parameters instead of directly using this 00066 * class (Preprocessing default values etc is handled in there): 00067 * 00068 * \code 00069 * RandomForest decisionTree(RF_Traits::Options_t() 00070 * .features_per_node(RF_ALL) 00071 * .tree_count(1) ); 00072 * \endcode 00073 * 00074 * \todo remove the classCount and featurecount from the topology 00075 * array. Pass ext_param_ to the nodes! 00076 * \todo Use relative addressing of nodes? 00077 */ 00078 class DecisionTree 00079 { 00080 /* \todo make private?*/ 00081 public: 00082 00083 /* value type of container array. 
use whenever referencing it 00084 */ 00085 typedef Int32 TreeInt; 00086 00087 ArrayVector<TreeInt> topology_; 00088 ArrayVector<double> parameters_; 00089 00090 ProblemSpec<> ext_param_; 00091 unsigned int classCount_; 00092 00093 00094 public: 00095 /* \brief Create tree with parameters */ 00096 template<class T> 00097 DecisionTree(ProblemSpec<T> ext_param) 00098 : 00099 ext_param_(ext_param), 00100 classCount_(ext_param.class_count_) 00101 {} 00102 00103 /* clears all memory used. 00104 */ 00105 void reset(unsigned int classCount = 0) 00106 { 00107 if(classCount) 00108 classCount_ = classCount; 00109 topology_.clear(); 00110 parameters_.clear(); 00111 } 00112 00113 00114 /* learn a Tree 00115 * 00116 * \tparam StackEntry_t The Stackentry containing Node/StackEntry_t 00117 * Information used during learning. Each Split functor has a 00118 * Stack entry associated with it (Split_t::StackEntry_t) 00119 * \sa RandomForest::learn() 00120 */ 00121 template < class U, class C, 00122 class U2, class C2, 00123 class StackEntry_t, 00124 class Stop_t, 00125 class Split_t, 00126 class Visitor_t, 00127 class Random_t > 00128 void learn( MultiArrayView<2, U, C> const & features, 00129 MultiArrayView<2, U2, C2> const & labels, 00130 StackEntry_t const & stack_entry, 00131 Split_t split, 00132 Stop_t stop, 00133 Visitor_t & visitor, 00134 Random_t & randint); 00135 template < class U, class C, 00136 class U2, class C2, 00137 class StackEntry_t, 00138 class Stop_t, 00139 class Split_t, 00140 class Visitor_t, 00141 class Random_t> 00142 void continueLearn( MultiArrayView<2, U, C> const & features, 00143 MultiArrayView<2, U2, C2> const & labels, 00144 StackEntry_t const & stack_entry, 00145 Split_t split, 00146 Stop_t stop, 00147 Visitor_t & visitor, 00148 Random_t & randint, 00149 //an index to which the last created exterior node will be moved (because it is not used anymore) 00150 int garbaged_child=-1); 00151 00152 /* is a node a Leaf Node? 
*/ 00153 inline bool isLeafNode(TreeInt in) const 00154 { 00155 return (in & LeafNodeTag) == LeafNodeTag; 00156 } 00157 00158 /* data driven traversal from root to leaf 00159 * 00160 * traverse through tree with data given in features. Use Visitors to 00161 * collect statistics along the way. 00162 */ 00163 template<class U, class C, class Visitor_t> 00164 TreeInt getToLeaf(MultiArrayView<2, U, C> const & features, 00165 Visitor_t & visitor) const 00166 { 00167 TreeInt index = 2; 00168 while(!isLeafNode(topology_[index])) 00169 { 00170 visitor.visit_internal_node(*this, index, topology_[index],features); 00171 switch(topology_[index]) 00172 { 00173 case i_ThresholdNode: 00174 { 00175 Node<i_ThresholdNode> 00176 node(topology_, parameters_, index); 00177 index = node.next(features); 00178 break; 00179 } 00180 case i_HyperplaneNode: 00181 { 00182 Node<i_HyperplaneNode> 00183 node(topology_, parameters_, index); 00184 index = node.next(features); 00185 break; 00186 } 00187 case i_HypersphereNode: 00188 { 00189 Node<i_HypersphereNode> 00190 node(topology_, parameters_, index); 00191 index = node.next(features); 00192 break; 00193 } 00194 #if 0 00195 // for quick prototyping! has to be implemented. 00196 case i_VirtualNode: 00197 { 00198 Node<i_VirtualNode> 00199 node(topology_, parameters, index); 00200 index = node.next(features); 00201 } 00202 #endif 00203 default: 00204 vigra_fail("DecisionTree::getToLeaf():" 00205 "encountered unknown internal Node Type"); 00206 } 00207 } 00208 visitor.visit_external_node(*this, index, topology_[index],features); 00209 return index; 00210 } 00211 /* traverse tree to get statistics 00212 * 00213 * Tree is traversed in order the Nodes are in memory (i.e. 
if no 00214 * relearning//pruning scheme is utilized this will be pre order) 00215 */ 00216 template<class Visitor_t> 00217 void traverse_mem_order(Visitor_t visitor) const 00218 { 00219 TreeInt index = 2; 00220 Int32 ii = 0; 00221 while(index < topology_.size()) 00222 { 00223 if(isLeafNode(topology_[index])) 00224 { 00225 visitor 00226 .visit_external_node(*this, index, topology_[index]); 00227 } 00228 else 00229 { 00230 visitor 00231 ._internal_node(*this, index, topology_[index]); 00232 } 00233 } 00234 } 00235 00236 template<class Visitor_t> 00237 void traverse_post_order(Visitor_t visitor, TreeInt start = 2) const 00238 { 00239 typedef TinyVector<double, 2> Entry; 00240 std::vector<Entry > stack; 00241 std::vector<double> result_stack; 00242 stack.push_back(Entry(2, 0)); 00243 int addr; 00244 while(!stack.empty()) 00245 { 00246 addr = stack.back()[0]; 00247 NodeBase node(topology_, parameters_, stack.back()[0]); 00248 if(stack.back()[1] == 1) 00249 { 00250 stack.pop_back(); 00251 double leftRes = result_stack.back(); 00252 double rightRes = result_stack.back(); 00253 result_stack.pop_back(); 00254 result_stack.pop_back(); 00255 result_stack.push_back(rightRes+ leftRes); 00256 visitor.visit_internal_node(*this, 00257 addr, 00258 node.typeID(), 00259 rightRes+leftRes); 00260 } 00261 else 00262 { 00263 if(isLeafNode(node.typeID())) 00264 { 00265 visitor.visit_external_node(*this, 00266 addr, 00267 node.typeID(), 00268 node.weights()); 00269 stack.pop_back(); 00270 result_stack.push_back(node.weights()); 00271 } 00272 else 00273 { 00274 stack.back()[1] = 1; 00275 stack.push_back(Entry(node.child(0), 0)); 00276 stack.push_back(Entry(node.child(1), 0)); 00277 } 00278 00279 } 00280 } 00281 } 00282 00283 /* same thing as above, without any visitors */ 00284 template<class U, class C> 00285 TreeInt getToLeaf(MultiArrayView<2, U, C> const & features) const 00286 { 00287 ::vigra::rf::visitors::StopVisiting stop; 00288 return getToLeaf(features, stop); 00289 } 00290 00291 
00292 template <class U, class C> 00293 ArrayVector<double>::iterator 00294 predict(MultiArrayView<2, U, C> const & features) const 00295 { 00296 TreeInt nodeindex = getToLeaf(features); 00297 switch(topology_[nodeindex]) 00298 { 00299 case e_ConstProbNode: 00300 return Node<e_ConstProbNode>(topology_, 00301 parameters_, 00302 nodeindex).prob_begin(); 00303 break; 00304 #if 0 00305 //first make the Logistic regression stuff... 00306 case e_LogRegProbNode: 00307 return Node<e_LogRegProbNode>(topology_, 00308 parameters_, 00309 nodeindex).prob_begin(); 00310 #endif 00311 default: 00312 vigra_fail("DecisionTree::predict() :" 00313 " encountered unknown external Node Type"); 00314 } 00315 return ArrayVector<double>::iterator(); 00316 } 00317 00318 00319 00320 template <class U, class C> 00321 Int32 predictLabel(MultiArrayView<2, U, C> const & features) const 00322 { 00323 ArrayVector<double>::const_iterator weights = predict(features); 00324 return argMax(weights, weights+classCount_) - weights; 00325 } 00326 00327 }; 00328 00329 00330 template < class U, class C, 00331 class U2, class C2, 00332 class StackEntry_t, 00333 class Stop_t, 00334 class Split_t, 00335 class Visitor_t, 00336 class Random_t> 00337 void DecisionTree::learn( MultiArrayView<2, U, C> const & features, 00338 MultiArrayView<2, U2, C2> const & labels, 00339 StackEntry_t const & stack_entry, 00340 Split_t split, 00341 Stop_t stop, 00342 Visitor_t & visitor, 00343 Random_t & randint) 00344 { 00345 this->reset(); 00346 topology_.reserve(256); 00347 parameters_.reserve(256); 00348 topology_.push_back(features.shape(1)); 00349 topology_.push_back(classCount_); 00350 continueLearn(features,labels,stack_entry,split,stop,visitor,randint); 00351 } 00352 00353 template < class U, class C, 00354 class U2, class C2, 00355 class StackEntry_t, 00356 class Stop_t, 00357 class Split_t, 00358 class Visitor_t, 00359 class Random_t> 00360 void DecisionTree::continueLearn( MultiArrayView<2, U, C> const & features, 00361 
MultiArrayView<2, U2, C2> const & labels, 00362 StackEntry_t const & stack_entry, 00363 Split_t split, 00364 Stop_t stop, 00365 Visitor_t & visitor, 00366 Random_t & randint, 00367 //an index to which the last created exterior node will be moved (because it is not used anymore) 00368 int garbaged_child) 00369 { 00370 std::vector<StackEntry_t> stack; 00371 stack.reserve(128); 00372 ArrayVector<StackEntry_t> child_stack_entry(2, stack_entry); 00373 stack.push_back(stack_entry); 00374 size_t last_node_pos = 0; 00375 StackEntry_t top=stack.back(); 00376 00377 while(!stack.empty()) 00378 { 00379 00380 // Take an element of the stack. Obvious ain't it? 00381 top = stack.back(); 00382 stack.pop_back(); 00383 00384 // Make sure no data from the last round has remained in Pipeline; 00385 child_stack_entry[0].reset(); 00386 child_stack_entry[1].reset(); 00387 split.reset(); 00388 00389 00390 //Either the Stopping criterion decides that the split should 00391 //produce a Terminal Node or the Split itself decides what 00392 //kind of node to make 00393 TreeInt NodeID; 00394 00395 if(stop(top)) 00396 NodeID = split.makeTerminalNode(features, 00397 labels, 00398 top, 00399 randint); 00400 else 00401 { 00402 //TIC; 00403 NodeID = split.findBestSplit(features, 00404 labels, 00405 top, 00406 child_stack_entry, 00407 randint); 00408 //std::cerr << TOC <<" " << NodeID << ";" <<std::endl; 00409 } 00410 00411 // do some visiting yawn - just added this comment as eye candy 00412 // (looks odd otherwise with my syntax highlighting.... 00413 visitor.visit_after_split(*this, split, top, 00414 child_stack_entry[0], 00415 child_stack_entry[1], 00416 features, 00417 labels); 00418 00419 00420 // Update the Child entries of the parent 00421 // Using InteriorNodeBase because exact parameter form not needed. 00422 // look at the Node base before getting scared. 
00423 last_node_pos = topology_.size(); 00424 if(top.leftParent != StackEntry_t::DecisionTreeNoParent) 00425 { 00426 NodeBase(topology_, 00427 parameters_, 00428 top.leftParent).child(0) = last_node_pos; 00429 } 00430 else if(top.rightParent != StackEntry_t::DecisionTreeNoParent) 00431 { 00432 NodeBase(topology_, 00433 parameters_, 00434 top.rightParent).child(1) = last_node_pos; 00435 } 00436 00437 00438 // Supply the split functor with the Node type it requires. 00439 // set the address to which the children of this node should point 00440 // to and push back children onto stack 00441 if(!isLeafNode(NodeID)) 00442 { 00443 child_stack_entry[0].leftParent = topology_.size(); 00444 child_stack_entry[1].rightParent = topology_.size(); 00445 child_stack_entry[0].rightParent = -1; 00446 child_stack_entry[1].leftParent = -1; 00447 stack.push_back(child_stack_entry[0]); 00448 stack.push_back(child_stack_entry[1]); 00449 } 00450 00451 //copy the newly created node form the split functor to the 00452 //decision tree. 00453 NodeBase node(split.createNode(), topology_, parameters_ ); 00454 } 00455 if(garbaged_child!=-1) 00456 { 00457 Node<e_ConstProbNode>(topology_,parameters_,garbaged_child).copy(Node<e_ConstProbNode>(topology_,parameters_,last_node_pos)); 00458 00459 int last_parameter_size = Node<e_ConstProbNode>(topology_,parameters_,garbaged_child).parameters_size(); 00460 topology_.resize(last_node_pos); 00461 parameters_.resize(parameters_.size() - last_parameter_size); 00462 00463 if(top.leftParent != StackEntry_t::DecisionTreeNoParent) 00464 NodeBase(topology_, 00465 parameters_, 00466 top.leftParent).child(0) = garbaged_child; 00467 else if(top.rightParent != StackEntry_t::DecisionTreeNoParent) 00468 NodeBase(topology_, 00469 parameters_, 00470 top.rightParent).child(1) = garbaged_child; 00471 } 00472 } 00473 00474 } //namespace detail 00475 00476 } //namespace vigra 00477 00478 #endif //VIGRA_RANDOM_FOREST_DT_HXX
© Ullrich Köthe (ullrich.koethe@iwr.uni-heidelberg.de) |
html generated using doxygen and Python
|