[ VIGRA Homepage | Function Index | Class Index | Namespaces | File List | Main Page ]
vigra/random_forest/rf_earlystopping.hxx | ![]() |
00001 #ifndef RF_EARLY_STOPPING_P_HXX 00002 #define RF_EARLY_STOPPING_P_HXX 00003 #include <cmath> 00004 #include "rf_common.hxx" 00005 00006 namespace vigra 00007 { 00008 00009 #if 0 00010 namespace es_detail 00011 { 00012 template<class T> 00013 T power(T const & in, int n) 00014 { 00015 T result = NumericTraits<T>::one(); 00016 for(int ii = 0; ii < n ;++ii) 00017 result *= in; 00018 return result; 00019 } 00020 } 00021 #endif 00022 00023 /**Base class from which all EarlyStopping Functors derive. 00024 */ 00025 class StopBase 00026 { 00027 protected: 00028 ProblemSpec<> ext_param_; 00029 int tree_count_ ; 00030 bool is_weighted_; 00031 00032 public: 00033 template<class T> 00034 void set_external_parameters(ProblemSpec<T> const &prob, int tree_count = 0, bool is_weighted = false) 00035 { 00036 ext_param_ = prob; 00037 is_weighted_ = is_weighted; 00038 tree_count_ = tree_count; 00039 } 00040 00041 #ifdef DOXYGEN 00042 /** called after the prediction of a tree was added to the total prediction 00043 * \param weightIter Iterator to the weights delivered by current tree. 00044 * \param k after kth tree 00045 * \param prob Total probability array 00046 * \param totalCt sum of probability array. 00047 */ 00048 template<class WeightIter, class T, class C> 00049 bool after_prediction(WeightIter weightIter, int k, MultiArrayView<2, T, C> const & prob , double totalCt) 00050 #else 00051 template<class WeightIter, class T, class C> 00052 bool after_prediction(WeightIter, int /* k */, MultiArrayView<2, T, C> const & /* prob */, double /* totalCt */) 00053 {return false;} 00054 #endif //DOXYGEN 00055 }; 00056 00057 00058 /**Stop predicting after a set number of trees 00059 */ 00060 class StopAfterTree : public StopBase 00061 { 00062 public: 00063 double max_tree_p; 00064 int max_tree_; 00065 typedef StopBase SB; 00066 00067 ArrayVector<double> depths; 00068 00069 /** Constructor 00070 * \param max_tree number of trees to be used for prediction 00071 */ 00072 StopAfterTree(double max_tree) 00073 : 00074 max_tree_p(max_tree) 00075 {} 00076 00077 template<class T> 00078 void set_external_parameters(ProblemSpec<T> const &prob, int tree_count = 0, bool is_weighted = false) 00079 { 00080 max_tree_ = ceil(max_tree_p * tree_count); 00081 SB::set_external_parameters(prob, tree_count, is_weighted); 00082 } 00083 00084 template<class WeightIter, class T, class C> 00085 bool after_prediction(WeightIter, int k, MultiArrayView<2, T, C> const & /* prob */, double /* totalCt */) 00086 { 00087 if(k == SB::tree_count_ -1) 00088 { 00089 depths.push_back(double(k+1)/double(SB::tree_count_)); 00090 return false; 00091 } 00092 if(k < max_tree_) 00093 return false; 00094 depths.push_back(double(k+1)/double(SB::tree_count_)); 00095 return true; 00096 } 00097 }; 00098 00099 /** Stop predicting after a certain amount of votes exceed certain proportion. 00100 * case unweighted voting: stop if the leading class exceeds proportion * SB::tree_count_ 00101 * case weighted voting: stop if the leading class exceeds proportion * msample_ * SB::tree_count_ ; 00102 * (maximal number of votes possible in both cases) 00103 */ 00104 class StopAfterVoteCount : public StopBase 00105 { 00106 public: 00107 double proportion_; 00108 typedef StopBase SB; 00109 ArrayVector<double> depths; 00110 00111 /** Constructor 00112 * \param proportion specify proportion to be used. 00113 */ 00114 StopAfterVoteCount(double proportion) 00115 : 00116 proportion_(proportion) 00117 {} 00118 00119 template<class WeightIter, class T, class C> 00120 bool after_prediction(WeightIter, int k, MultiArrayView<2, T, C> const & prob, double /* totalCt */) 00121 { 00122 if(k == SB::tree_count_ -1) 00123 { 00124 depths.push_back(double(k+1)/double(SB::tree_count_)); 00125 return false; 00126 } 00127 00128 00129 if(SB::is_weighted_) 00130 { 00131 if(prob[argMax(prob)] > proportion_ *SB::ext_param_.actual_msample_* SB::tree_count_) 00132 { 00133 depths.push_back(double(k+1)/double(SB::tree_count_)); 00134 return true; 00135 } 00136 } 00137 else 00138 { 00139 if(prob[argMax(prob)] > proportion_ * SB::tree_count_) 00140 { 00141 depths.push_back(double(k+1)/double(SB::tree_count_)); 00142 return true; 00143 } 00144 } 00145 return false; 00146 } 00147 00148 }; 00149 00150 00151 /** Stop predicting if the 2norm of the probabilities does not change*/ 00152 class StopIfConverging : public StopBase 00153 00154 { 00155 public: 00156 double thresh_; 00157 int num_; 00158 MultiArray<2, double> last_; 00159 MultiArray<2, double> cur_; 00160 ArrayVector<double> depths; 00161 typedef StopBase SB; 00162 00163 /** Constructor 00164 * \param thresh: If the two norm of the probabilities changes less then thresh then stop 00165 * \param num : look at atleast num trees before stopping 00166 */ 00167 StopIfConverging(double thresh, int num = 10) 00168 : 00169 thresh_(thresh), 00170 num_(num) 00171 {} 00172 00173 template<class T> 00174 void set_external_parameters(ProblemSpec<T> const &prob, int tree_count = 0, bool is_weighted = false) 00175 { 00176 last_.reshape(MultiArrayShape<2>::type(1, prob.class_count_), 0); 00177 cur_.reshape(MultiArrayShape<2>::type(1, prob.class_count_), 0); 00178 SB::set_external_parameters(prob, tree_count, is_weighted); 00179 } 00180 template<class WeightIter, class T, class C> 00181 bool after_prediction(WeightIter iter, int k, MultiArrayView<2, T, C> const & prob, double totalCt) 00182 { 00183 if(k == SB::tree_count_ -1) 00184 { 00185 depths.push_back(double(k+1)/double(SB::tree_count_)); 00186 return false; 00187 } 00188 if(k <= num_) 00189 { 00190 last_ = prob; 00191 last_/= last_.norm(1); 00192 return false; 00193 } 00194 else 00195 { 00196 cur_ = prob; 00197 cur_ /= cur_.norm(1); 00198 last_ -= cur_; 00199 double nrm = last_.norm(); 00200 if(nrm < thresh_) 00201 { 00202 depths.push_back(double(k+1)/double(SB::tree_count_)); 00203 return true; 00204 } 00205 else 00206 { 00207 last_ = cur_; 00208 } 00209 } 00210 return false; 00211 } 00212 }; 00213 00214 00215 /** Stop predicting if the margin prob(leading class) - prob(second class) exceeds a proportion 00216 * case unweighted voting: stop if margin exceeds proportion * SB::tree_count_ 00217 * case weighted voting: stop if margin exceeds proportion * msample_ * SB::tree_count_ ; 00218 * (maximal number of votes possible in both cases) 00219 */ 00220 class StopIfMargin : public StopBase 00221 { 00222 public: 00223 double proportion_; 00224 typedef StopBase SB; 00225 ArrayVector<double> depths; 00226 00227 /** Constructor 00228 * \param proportion specify proportion to be used. 00229 */ 00230 StopIfMargin(double proportion) 00231 : 00232 proportion_(proportion) 00233 {} 00234 00235 template<class WeightIter, class T, class C> 00236 bool after_prediction(WeightIter, int k, MultiArrayView<2, T, C> prob, double /* totalCt */) 00237 { 00238 if(k == SB::tree_count_ -1) 00239 { 00240 depths.push_back(double(k+1)/double(SB::tree_count_)); 00241 return false; 00242 } 00243 int index = argMax(prob); 00244 double a = prob[argMax(prob)]; 00245 prob[argMax(prob)] = 0; 00246 double b = prob[argMax(prob)]; 00247 prob[index] = a; 00248 double margin = a - b; 00249 if(SB::is_weighted_) 00250 { 00251 if(margin > proportion_ *SB::ext_param_.actual_msample_ * SB::tree_count_) 00252 { 00253 depths.push_back(double(k+1)/double(SB::tree_count_)); 00254 return true; 00255 } 00256 } 00257 else 00258 { 00259 if(prob[argMax(prob)] > proportion_ * SB::tree_count_) 00260 { 00261 depths.push_back(double(k+1)/double(SB::tree_count_)); 00262 return true; 00263 } 00264 } 00265 return false; 00266 } 00267 }; 00268 00269 00270 /**Probabilistic Stopping criterion (binomial test) 00271 * 00272 * Can only be used in a two class setting 00273 * 00274 * Stop if the Parameters estimated for the underlying binomial distribution 00275 * can be estimated with certainty over 1-alpha. 00276 * (Thesis, Rahul Nair Page 80 onwards: called the "binomial" criterion 00277 */ 00278 class StopIfBinTest : public StopBase 00279 { 00280 public: 00281 double alpha_; 00282 MultiArrayView<2, double> n_choose_k; 00283 /** Constructor 00284 * \param alpha specify alpha (=proportion) value for binomial test. 00285 * \param nck_ Matrix with precomputed values for n choose k 00286 * nck_(n, k) is n choose k. 00287 */ 00288 StopIfBinTest(double alpha, MultiArrayView<2, double> nck_) 00289 : 00290 alpha_(alpha), 00291 n_choose_k(nck_) 00292 {} 00293 typedef StopBase SB; 00294 00295 /**ArrayVector that will contain the fraction of trees that was visited before terminating 00296 */ 00297 ArrayVector<double> depths; 00298 00299 double binomial(int N, int k, double p) 00300 { 00301 // return n_choose_k(N, k) * es_detail::power(p, k) *es_detail::power(1 - p, N-k); 00302 return n_choose_k(N, k) * std::pow(p, k) * std::pow(1 - p, N-k); 00303 } 00304 00305 template<class WeightIter, class T, class C> 00306 bool after_prediction(WeightIter iter, int k, MultiArrayView<2, T, C> prob, double totalCt) 00307 { 00308 if(k == SB::tree_count_ -1) 00309 { 00310 depths.push_back(double(k+1)/double(SB::tree_count_)); 00311 return false; 00312 } 00313 if(k < 10) 00314 { 00315 return false; 00316 } 00317 int index = argMax(prob); 00318 int n_a = prob[index]; 00319 int n_b = prob[(index+1)%2]; 00320 int n_tilde = (SB::tree_count_ - n_a + n_b); 00321 double p_a = double(n_b - n_a + n_tilde)/double(2* n_tilde); 00322 vigra_precondition(p_a <= 1, "probability should be smaller than 1"); 00323 double cum_val = 0; 00324 int c = 0; 00325 // std::cerr << "prob: " << p_a << std::endl; 00326 if(n_a <= 0)n_a = 0; 00327 if(n_b <= 0)n_b = 0; 00328 for(int ii = 0; ii <= n_b + n_a;++ii) 00329 { 00330 // std::cerr << "nb +ba " << n_b + n_a << " " << ii <<std::endl; 00331 cum_val += binomial(n_b + n_a, ii, p_a); 00332 if(cum_val >= 1 -alpha_) 00333 { 00334 c = ii; 00335 break; 00336 } 00337 } 00338 // std::cerr << c << " " << n_a << " " << n_b << " " << p_a << alpha_ << std::endl; 00339 if(c < n_a) 00340 { 00341 depths.push_back(double(k+1)/double(SB::tree_count_)); 00342 return true; 00343 } 00344 00345 return false; 00346 } 00347 }; 00348 00349 /**Probabilistic Stopping criteria. (toChange) 00350 * 00351 * Can only be used in a two class setting 00352 * 00353 * Stop if the probability that the decision will change after seeing all trees falls under 00354 * a specified value alpha. 00355 * (Thesis, Rahul Nair Page 80 onwards: called the "toChange" criterion 00356 */ 00357 class StopIfProb : public StopBase 00358 { 00359 public: 00360 double alpha_; 00361 MultiArrayView<2, double> n_choose_k; 00362 00363 00364 /** Constructor 00365 * \param alpha specify alpha (=proportion) value 00366 * \param nck_ Matrix with precomputed values for n choose k 00367 * nck_(n, k) is n choose k. 00368 */ 00369 StopIfProb(double alpha, MultiArrayView<2, double> nck_) 00370 : 00371 alpha_(alpha), 00372 n_choose_k(nck_) 00373 {} 00374 typedef StopBase SB; 00375 /**ArrayVector that will contain the fraction of trees that was visited before terminating 00376 */ 00377 ArrayVector<double> depths; 00378 00379 double binomial(int N, int k, double p) 00380 { 00381 // return n_choose_k(N, k) * es_detail::power(p, k) *es_detail::power(1 - p, N-k); 00382 return n_choose_k(N, k) * std::pow(p, k) * std::pow(1 - p, N-k); 00383 } 00384 00385 template<class WeightIter, class T, class C> 00386 bool after_prediction(WeightIter iter, int k, MultiArrayView<2, T, C> prob, double totalCt) 00387 { 00388 if(k == SB::tree_count_ -1) 00389 { 00390 depths.push_back(double(k+1)/double(SB::tree_count_)); 00391 return false; 00392 } 00393 if(k <= 10) 00394 { 00395 return false; 00396 } 00397 int index = argMax(prob); 00398 int n_a = prob[index]; 00399 int n_b = prob[(index+1)%2]; 00400 int n_needed = ceil(double(SB::tree_count_)/2.0)-n_a; 00401 int n_tilde = SB::tree_count_ - (n_a +n_b); 00402 if(n_tilde <= 0) n_tilde = 0; 00403 if(n_needed <= 0) n_needed = 0; 00404 double p = 0; 00405 for(int ii = n_needed; ii < n_tilde; ++ii) 00406 p += binomial(n_tilde, ii, 0.5); 00407 00408 if(p >= 1-alpha_) 00409 { 00410 depths.push_back(double(k+1)/double(SB::tree_count_)); 00411 return true; 00412 } 00413 00414 return false; 00415 } 00416 }; 00417 } //namespace vigra; 00418 #endif //RF_EARLY_STOPPING_P_HXX
© Ullrich Köthe (ullrich.koethe@iwr.uni-heidelberg.de) |
html generated using doxygen and Python
|