[ VIGRA Homepage | Function Index | Class Index | Namespaces | File List | Main Page ]

vigra/random_forest/rf_ridge_split.hxx VIGRA

00001 //
00002 // C++ Interface: rf_ridge_split
00003 //
00004 // Description: oblique (hyperplane) random forest split based on ridge regression
00005 //
00006 //
00007 // Author: Nico Splitthoff <splitthoff@zg00103>, (C) 2009
00008 //
00009 // Copyright: See COPYING file that comes with this distribution
00010 //
00011 //
00012 #ifndef VIGRA_RANDOM_FOREST_RIDGE_SPLIT_H
00013 #define VIGRA_RANDOM_FOREST_RIDGE_SPLIT_H
00014 //#include "rf_sampling.hxx"
00015 #include "../sampling.hxx"
00016 #include "rf_split.hxx"
00017 #include "rf_nodeproxy.hxx"
00018 #include "../regression.hxx"
00019 
/* Debug helpers: print "<expression text>: <value>" to std::cout.
 * outm terminates the line (std::endl), outm2 appends ", " so that several
 * values can share one line. Both expansions already end with a semicolon.
 * NOTE(review): unprefixed macro names defined in a public header. */
#define outm(v) std::cout << #v ": " << (v) << std::endl;
#define outm2(v) std::cout << #v ": " << (v) << ", ";
00022 
00023 namespace vigra
00024 {
00025 
00026 /*template<>
00027 class Node<i_RegrNode>
00028 : public NodeBase
00029 {
00030 public:
00031     typedef NodeBase BT;
00032 
00033 
00034     Node(   BT::T_Container_type &   topology,
00035         BT::P_Container_type &   param,
00036             int nNumCols)
00037     :   BT(5+nNumCols,2+nNumCols,topology, param)
00038     {
00039         BT::typeID() = i_RegrNode;
00040     }
00041 
00042     Node(   BT::T_Container_type     &   topology,
00043         BT::P_Container_type     &   param,
00044         INT                   n             )
00045     :   BT(5,2,topology, param, n)
00046     {}
00047 
00048     Node( BT & node_)
00049     :   BT(5, 2, node_) 
00050     {}
00051 
00052     double& threshold()
00053     {
00054         return BT::parameters_begin()[1];
00055     }
00056 
00057     BT::INT& column()
00058     {
00059         return BT::column_data()[0];
00060     }
00061 
00062     template<class U, class C>
00063             BT::INT& next(MultiArrayView<2,U,C> const & feature)
00064             {
00065                 return (feature(0, column()) < threshold())? child(0):child(1);
00066             }
00067 };*/
00068 
00069 
00070 template<class ColumnDecisionFunctor, class Tag = ClassificationTag>
00071 class RidgeSplit: public SplitBase<Tag>
00072 {
00073   public:
00074 
00075 
00076     typedef SplitBase<Tag> SB;
00077 
00078     ArrayVector<Int32>          splitColumns;
00079     ColumnDecisionFunctor       bgfunc;
00080 
00081     double                      region_gini_;
00082     ArrayVector<double>         min_gini_;
00083     ArrayVector<std::ptrdiff_t> min_indices_;
00084     ArrayVector<double>         min_thresholds_;
00085 
00086     int                         bestSplitIndex;
00087     
00088     //dns
00089     bool            m_bDoScalingInTraining;
00090     bool            m_bDoBestLambdaBasedOnGini;
00091     
00092     RidgeSplit()
00093     :m_bDoScalingInTraining(true),
00094     m_bDoBestLambdaBasedOnGini(true)
00095     {
00096     }
00097 
00098     double minGini() const
00099     {
00100         return min_gini_[bestSplitIndex];
00101     }
00102     
00103     int bestSplitColumn() const
00104     {
00105         return splitColumns[bestSplitIndex];
00106     }
00107     
00108     bool& doScalingInTraining()
00109     { return m_bDoScalingInTraining; }
00110 
00111     bool& doBestLambdaBasedOnGini()
00112     { return m_bDoBestLambdaBasedOnGini; }
00113 
00114     template<class T>
00115             void set_external_parameters(ProblemSpec<T> const & in)
00116     {
00117         SB::set_external_parameters(in);        
00118         bgfunc.set_external_parameters(in);
00119         int featureCount_ = in.column_count_;
00120         splitColumns.resize(featureCount_);
00121         for(int k=0; k<featureCount_; ++k)
00122             splitColumns[k] = k;
00123         min_gini_.resize(featureCount_);
00124         min_indices_.resize(featureCount_);
00125         min_thresholds_.resize(featureCount_);
00126     }
00127 
00128     
00129     template<class T, class C, class T2, class C2, class Region, class Random>
00130     int findBestSplit(MultiArrayView<2, T, C> features,
00131                       MultiArrayView<2, T2, C2>  multiClassLabels,
00132                       Region & region,
00133                       ArrayVector<Region>& childRegions,
00134                       Random & randint)
00135     {
00136 
00137     //std::cerr << "Split called" << std::endl;
00138     typedef typename Region::IndexIterator IndexIterator;
00139     typedef typename MultiArrayView <2, T, C>::difference_type fShape;
00140     typedef typename MultiArrayView <2, T2, C2>::difference_type lShape;
00141     typedef typename MultiArrayView <2, double>::difference_type dShape;
00142         
00143         // calculate things that haven't been calculated yet. 
00144 //    std::cout << "start" << std::endl;
00145         if(std::accumulate(region.classCounts().begin(),
00146                            region.classCounts().end(), 0) != region.size())
00147         {
00148             RandomForestClassCounter<   MultiArrayView<2,T2, C2>, 
00149                                         ArrayVector<double> >
00150                 counter(multiClassLabels, region.classCounts());
00151             std::for_each(  region.begin(), region.end(), counter);
00152             region.classCountsIsValid = true;
00153         }
00154 
00155 
00156         // Is the region pure already?
00157         region_gini_ = GiniCriterion::impurity(region.classCounts(),
00158                 region.size());
00159         if(region_gini_ == 0 || region.size() < SB::ext_param_.actual_mtry_ || region.oob_size() < 2)
00160             return  SB::makeTerminalNode(features, multiClassLabels, region, randint);
00161 
00162         // select columns  to be tried.
00163     for(int ii = 0; ii < SB::ext_param_.actual_mtry_; ++ii)
00164         std::swap(splitColumns[ii], 
00165             splitColumns[ii+ randint(features.shape(1) - ii)]);
00166 
00167     //do implicit binary case
00168     MultiArray<2, T2> labels(lShape(multiClassLabels.shape(0),1));
00169       //number of classes should be >1, otherwise makeTerminalNode would have been called
00170       int nNumClasses=0;
00171       for(int n=0; n<(int)region.classCounts().size(); n++)
00172         nNumClasses+=((region.classCounts()[n]>0) ? 1:0);
00173       
00174       //convert to binary case
00175       if(nNumClasses>2)
00176       {
00177         int nMaxClass=0;
00178         int nMaxClassCounts=0;
00179         for(int n=0; n<(int)region.classCounts().size(); n++)
00180         {
00181           //this should occur in any case:
00182           //we had more than two non-zero classes in order to get here
00183           if(region.classCounts()[n]>nMaxClassCounts)
00184           {
00185         nMaxClassCounts=region.classCounts()[n];
00186         nMaxClass=n;
00187           }
00188         }
00189         
00190         //create binary labels
00191         for(int n=0; n<multiClassLabels.shape(0); n++)
00192           labels(n,0)=((multiClassLabels(n,0)==nMaxClass) ? 1:0);
00193       }
00194       else
00195         labels=multiClassLabels;
00196 
00197     //_do implicit binary case
00198     
00199     //uncomment this for some debugging
00200 /*  int nNumCases=features.shape(0);
00201 
00202     typedef typename MultiArrayView <2, int>::difference_type nShape;
00203     MultiArray<2, int> elementCounterArray(nShape(nNumCases,1),(int)0);
00204     int nUniqueElements=0;
00205     for(int n=0; n<region.size(); n++)
00206         elementCounterArray[region[n]]++;
00207     
00208     for(int n=0; n<nNumCases; n++)
00209         nUniqueElements+=((elementCounterArray[n]>0) ? 1:0);
00210     
00211     outm(nUniqueElements);
00212     nUniqueElements=0;
00213     MultiArray<2, int> elementCounterArray_oob(nShape(nNumCases,1),(int)0);
00214     for(int n=0; n<region.oob_size(); n++)
00215         elementCounterArray_oob[region.oob_begin()[n]]++;
00216     for(int n=0; n<nNumCases; n++)
00217         nUniqueElements+=((elementCounterArray_oob[n]>0) ? 1:0);
00218     outm(nUniqueElements);
00219     
00220     int notUniqueElements=0;
00221     for(int n=0; n<nNumCases; n++)
00222         notUniqueElements+=(((elementCounterArray_oob[n]>0) && (elementCounterArray[n]>0)) ? 1:0);
00223     outm(notUniqueElements);*/
00224     
00225     //outm(SB::ext_param_.actual_mtry_);
00226     
00227     
00228 //select submatrix of features for regression calculation
00229     MultiArrayView<2, T, C> cVector;
00230     MultiArray<2, T> xtrain(fShape(region.size(),SB::ext_param_.actual_mtry_));
00231     //we only want -1 and 1 for this
00232     MultiArray<2, double> regrLabels(dShape(region.size(),1));
00233 
00234     //copy data into a vigra data structure and centre and scale while doing so
00235     MultiArray<2, double> meanMatrix(dShape(SB::ext_param_.actual_mtry_,1));
00236     MultiArray<2, double> stdMatrix(dShape(SB::ext_param_.actual_mtry_,1));
00237     for(int m=0; m<SB::ext_param_.actual_mtry_; m++)
00238     {
00239         cVector=columnVector(features, splitColumns[m]);
00240         
00241         //centre and scale the data
00242         double dCurrFeatureColumnMean=0.0;
00243         double dCurrFeatureColumnStd=1.0; //default value
00244         
00245         //calc mean on bootstrap data
00246         for(int n=0; n<region.size(); n++)
00247           dCurrFeatureColumnMean+=cVector[region[n]];
00248         dCurrFeatureColumnMean/=region.size();
00249         //calc scaling
00250         if(m_bDoScalingInTraining)
00251         {
00252           for(int n=0; n<region.size(); n++)
00253           {
00254               dCurrFeatureColumnStd+=
00255             (cVector[region[n]]-dCurrFeatureColumnMean)*(cVector[region[n]]-dCurrFeatureColumnMean);
00256           }
00257           //unbiased std estimator:
00258           dCurrFeatureColumnStd=sqrt(dCurrFeatureColumnStd/(region.size()-1));
00259         }
00260         //dCurrFeatureColumnStd is still 1.0 if we didn't want scaling
00261         stdMatrix(m,0)=dCurrFeatureColumnStd;
00262         
00263         meanMatrix(m,0)=dCurrFeatureColumnMean;
00264         
00265         //get feature matrix, i.e. A (note that weighting is done automatically
00266         //since rows can occur multiple times -> bagging)
00267         for(int n=0; n<region.size(); n++)
00268             xtrain(n,m)=(cVector[region[n]]-dCurrFeatureColumnMean)/dCurrFeatureColumnStd;
00269     }
00270     
00271 //    std::cout << "middle" << std::endl;
00272     //get label vector (i.e. b)
00273     for(int n=0; n<region.size(); n++)
00274     {
00275         //we checked for/built binary case further up.
00276         //class labels should thus be either 0 or 1
00277         //-> convert to -1 and 1 for regression
00278         regrLabels(n,0)=((labels[region[n]]==0) ? -1:1);
00279     }
00280 
00281     MultiArray<2, double> dLambdas(dShape(11,1));
00282     int nCounter=0;
00283     for(int nLambda=-5; nLambda<=5; nLambda++)
00284         dLambdas[nCounter++]=pow(10.0,nLambda);
00285     //destination vector for regression coefficients; use same type as for xtrain
00286     MultiArray<2, double> regrCoef(dShape(SB::ext_param_.actual_mtry_,11));
00287     ridgeRegressionSeries(xtrain,regrLabels,regrCoef,dLambdas);
00288     
00289     double dMaxRidgeSum=NumericTraits<double>::min();
00290     double dCurrRidgeSum;
00291     int nMaxRidgeSumAtLambdaInd=0;
00292 
00293     for(int nLambdaInd=0; nLambdaInd<11; nLambdaInd++)
00294     {
00295         //just sum up the correct answers
00296         //(correct means >=intercept for class 1, <intercept for class 0)
00297         //(intercept=0 or intercept=threshold based on gini)
00298         dCurrRidgeSum=0.0;
00299         
00300         //assemble projection vector
00301         MultiArray<2, double> dDistanceFromHyperplane(dShape(features.shape(0),1));
00302         
00303         for(int n=0; n<region.oob_size(); n++)
00304         {
00305           dDistanceFromHyperplane(region.oob_begin()[n],0)=0.0;
00306           for (int m=0; m<SB::ext_param_.actual_mtry_; m++)
00307           {
00308             dDistanceFromHyperplane(region.oob_begin()[n],0)+=
00309               features(region.oob_begin()[n],splitColumns[m])*regrCoef(m,nLambdaInd);
00310           }
00311         }
00312 
00313         double dCurrIntercept=0.0;
00314         if(m_bDoBestLambdaBasedOnGini)
00315         {
00316           //calculate gini index
00317           bgfunc(dDistanceFromHyperplane,
00318               labels, 
00319               region.oob_begin(), region.oob_end(), 
00320               region.classCounts());
00321           dCurrIntercept=bgfunc.min_threshold_;
00322         }
00323         else
00324         {
00325           for (int m=0; m<SB::ext_param_.actual_mtry_; m++)
00326             dCurrIntercept+=meanMatrix(m,0)*regrCoef(m,nLambdaInd);
00327         }
00328         
00329         for(int n=0; n<region.oob_size(); n++)
00330         {
00331             //check what lambda performs best on oob data
00332             int nClassPrediction=((dDistanceFromHyperplane(region.oob_begin()[n],0) >=dCurrIntercept) ? 1:0);
00333             dCurrRidgeSum+=((nClassPrediction == labels(region.oob_begin()[n],0)) ? 1:0);
00334         }
00335         if(dCurrRidgeSum>dMaxRidgeSum)
00336         {
00337             dMaxRidgeSum=dCurrRidgeSum;
00338             nMaxRidgeSumAtLambdaInd=nLambdaInd;
00339         }
00340     }
00341 
00342 //    std::cout << "middle2" << std::endl;
00343         //create a Node for output
00344         Node<i_HyperplaneNode>   node(SB::ext_param_.actual_mtry_, SB::t_data, SB::p_data);
00345 
00346     //normalise coeffs
00347         //data was scaled (by 1.0 or by std) -> take into account
00348         MultiArray<2, double> dCoeffVector(dShape(SB::ext_param_.actual_mtry_,1));
00349         for(int n=0; n<SB::ext_param_.actual_mtry_; n++)
00350           dCoeffVector(n,0)=regrCoef(n,nMaxRidgeSumAtLambdaInd)*stdMatrix(n,0);
00351         
00352         //calc norm
00353         double dVnorm=columnVector(regrCoef,nMaxRidgeSumAtLambdaInd).norm();
00354 
00355         for(int n=0; n<SB::ext_param_.actual_mtry_; n++)
00356             node.weights()[n]=dCoeffVector(n,0)/dVnorm;
00357     //_normalise coeffs
00358     
00359     //save the columns
00360         node.column_data()[0]=SB::ext_param_.actual_mtry_;
00361         for(int n=0; n<SB::ext_param_.actual_mtry_; n++)
00362             node.column_data()[n+1]=splitColumns[n];
00363 
00364     //assemble projection vector
00365         //careful here: "region" is a pointer to indices...
00366         //all the indices in "region" need to have valid data
00367         //convert from "region" space to original "feature" space
00368         MultiArray<2, double> dDistanceFromHyperplane(dShape(features.shape(0),1));
00369         
00370         for(int n=0; n<region.size(); n++)
00371         {
00372             dDistanceFromHyperplane(region[n],0)=0.0;
00373             for (int m=0; m<SB::ext_param_.actual_mtry_; m++)
00374             {
00375               dDistanceFromHyperplane(region[n],0)+=
00376                features(region[n],m)*node.weights()[m];
00377             }
00378         }
00379         for(int n=0; n<region.oob_size(); n++)
00380         {
00381             dDistanceFromHyperplane(region.oob_begin()[n],0)=0.0;
00382             for (int m=0; m<SB::ext_param_.actual_mtry_; m++)
00383             {
00384               dDistanceFromHyperplane(region.oob_begin()[n],0)+=
00385             features(region.oob_begin()[n],m)*node.weights()[m];
00386             }
00387         }
00388         
00389     //calculate gini index
00390         bgfunc(dDistanceFromHyperplane,
00391             labels, 
00392             region.begin(), region.end(), 
00393             region.classCounts());
00394     
00395         // did not find any suitable split
00396     if(closeAtTolerance(bgfunc.min_gini_, NumericTraits<double>::max()))
00397         return  SB::makeTerminalNode(features, multiClassLabels, region, randint);
00398     
00399     //take gini threshold here due to scaling, normalisation, etc. of the coefficients
00400     node.intercept()    = bgfunc.min_threshold_;
00401     SB::node_ = node;
00402     
00403     childRegions[0].classCounts() = bgfunc.bestCurrentCounts[0];
00404     childRegions[1].classCounts() = bgfunc.bestCurrentCounts[1];
00405     childRegions[0].classCountsIsValid = true;
00406     childRegions[1].classCountsIsValid = true;
00407     
00408         // Save the ranges of the child stack entries.
00409     childRegions[0].setRange(   region.begin()  , region.begin() + bgfunc.min_index_   );
00410     childRegions[0].rule = region.rule;
00411     childRegions[0].rule.push_back(std::make_pair(1, 1.0));
00412     childRegions[1].setRange(   region.begin() + bgfunc.min_index_       , region.end()    );
00413     childRegions[1].rule = region.rule;
00414     childRegions[1].rule.push_back(std::make_pair(1, 1.0));
00415     
00416     //adjust oob ranges
00417 //    std::cout << "adjust oob" << std::endl;
00418     //sort the oobs
00419       std::sort(region.oob_begin(), region.oob_end(), 
00420             SortSamplesByDimensions< MultiArray<2, double> > (dDistanceFromHyperplane, 0));
00421             
00422       //find split index
00423       int nOOBindx;
00424       for(nOOBindx=0; nOOBindx<region.oob_size(); nOOBindx++)
00425       {
00426         if(dDistanceFromHyperplane(region.oob_begin()[nOOBindx],0)>=node.intercept())
00427           break;
00428       }
00429 
00430       childRegions[0].set_oob_range(   region.oob_begin()  , region.oob_begin() + nOOBindx   );
00431       childRegions[1].set_oob_range(   region.oob_begin() + nOOBindx , region.oob_end() );
00432 
00433 //    std::cout << "end" << std::endl;
00434 //    outm2(region.oob_begin());outm2(nOOBindx);outm(region.oob_begin() + nOOBindx);
00435     //_adjust oob ranges
00436 
00437     return i_HyperplaneNode;
00438     }
00439 };
00440 
00441 /** Standard ridge regression split
00442  */
00443 typedef RidgeSplit<BestGiniOfColumn<GiniCriterion> >  GiniRidgeSplit;
00444 
00445 
00446 } //namespace vigra
00447 #endif // VIGRA_RANDOM_FOREST_RIDGE_SPLIT_H

© Ullrich Köthe (ullrich.koethe@iwr.uni-heidelberg.de)
Heidelberg Collaboratory for Image Processing, University of Heidelberg, Germany

html generated using doxygen and Python
vigra 1.9.0 (Tue Nov 6 2012)