[ VIGRA Homepage | Function Index | Class Index | Namespaces | File List | Main Page ]

vigra/random_forest/rf_online_prediction_set.hxx VIGRA

#include "../multi_array.hxx"
#include <cfloat>
#include <set>
#include <vector>
00004 
00005 namespace vigra
00006 {
00007 
/** A contiguous run of sample positions together with per-feature bounds.
 *
 *  OnlinePredictionSet creates its initial range as (0, number of rows),
 *  i.e. the interval is used half-open as [start, end).
 *
 *  Only \c start takes part in the ordering (see operator<). The remaining
 *  members are declared \c mutable so that they can still be adjusted while
 *  the object sits inside a std::set, whose elements are otherwise const.
 */
template<class T>
struct SampleRange
{
    /** \param start         first position of the range
     *  \param end           end position of the range
     *  \param num_features  number of per-feature bounds to allocate; every
     *                       lower bound starts at -FLT_MAX and every upper
     *                       bound at FLT_MAX (converted to T)
     */
    SampleRange(int start, int end, int num_features)
    : start(start),
      end(end),
      max_boundaries(num_features, FLT_MAX),
      min_boundaries(num_features, -FLT_MAX)
    {}

    int start;
    mutable int end;
    mutable std::vector<T> max_boundaries;
    mutable std::vector<T> min_boundaries;

    /** Strict weak ordering on \c start, descending: a range with a larger
     *  start compares "less", so it comes first in a std::set.
     */
    bool operator<(const SampleRange& other) const
    {
        return start > other.start;
    }
};
00029 
00030 template<class T>
00031 class OnlinePredictionSet
00032 {
00033 public:
00034     template<class U>
00035     OnlinePredictionSet(MultiArrayView<2,T,U>& features,int num_sets)
00036     {
00037         this->features=features;
00038         std::vector<int> init(features.shape(0));
00039         for(unsigned int i=0;i<init.size();++i)
00040             init[i]=i;
00041         indices.resize(num_sets,init);
00042         std::set<SampleRange<T> > set_init;
00043         set_init.insert(SampleRange<T>(0,init.size(),features.shape(1)));
00044         ranges.resize(num_sets,set_init);
00045         cumulativePredTime.resize(num_sets,0);
00046     }
00047     
00048     int get_worsed_tree()
00049     {
00050         int result=0;
00051         for(unsigned int i=0;i<cumulativePredTime.size();++i)
00052         {
00053             if(cumulativePredTime[i]>cumulativePredTime[result])
00054             {
00055                 result=i;
00056             }
00057         }
00058         return result;
00059     }
00060     
00061     void reset_tree(int index)
00062     {
00063         index=index % ranges.size();
00064         std::set<SampleRange<T> > set_init;
00065         set_init.insert(SampleRange<T>(0,features.shape(0),features.shape(1)));
00066         ranges[index]=set_init;
00067         cumulativePredTime[index]=0;
00068     }
00069     
00070     std::vector<std::set<SampleRange<T> > > ranges;
00071     std::vector<std::vector<int> > indices;
00072     std::vector<int> cumulativePredTime;
00073     MultiArray<2,T> features;
00074 };
00075 
00076 }
00077 

© Ullrich Köthe (ullrich.koethe@iwr.uni-heidelberg.de)
Heidelberg Collaboratory for Image Processing, University of Heidelberg, Germany

html generated using doxygen and Python
vigra 1.9.0 (Tue Nov 6 2012)