[ VIGRA Homepage | Function Index | Class Index | Namespaces | File List | Main Page ]
![]() |
RandomForest< LabelType, PreprocessorTag > Class Template Reference | ![]() |
#include <vigra/random_forest.hxx>
Public Member Functions | |
int | class_count () const |
return number of classes used while training. | |
int | column_count () const |
return number of features used while training. | |
int | feature_count () const |
return number of features used while training. | |
template<class U , class C1 , class U2 , class C2 , class Split_t , class Stop_t , class Visitor_t , class Random_t > | |
void | reLearnTree (MultiArrayView< 2, U, C1 > const &features, MultiArrayView< 2, U2, C2 > const &response, int treeId, Visitor_t visitor_, Split_t split_, Stop_t stop_, Random_t &random) |
int | tree_count () const |
return number of trees | |
Constructors | |
Note: No copy Constructor specified as no pointers are manipulated in this class | |
RandomForest (Options_t const &options=Options_t(), ProblemSpec_t const &ext_param=ProblemSpec_t()) | |
default constructor | |
template<class TopologyIterator , class ParameterIterator > | |
RandomForest (int treeCount, TopologyIterator topology_begin, ParameterIterator parameter_begin, ProblemSpec_t const &problem_spec, Options_t const &options=Options_t()) | |
Create RF from external source. | |
Data Access | |
data access interface - usage of member variables is deprecated | |
ProblemSpec_t const & | ext_param () const |
return external parameters for viewing | |
void | set_ext_param (ProblemSpec_t const &in) |
set external parameters | |
Options_t & | set_options () |
access random forest options | |
Options_t const & | options () const |
access const random forest options | |
DecisionTree_t const & | tree (int index) const |
access const trees | |
DecisionTree_t & | tree (int index) |
access trees | |
Learning | |
Following functions differ in the degree of customization allowed | |
template<class U , class C1 , class U2 , class C2 , class Split_t , class Stop_t , class Visitor_t , class Random_t > | |
void | learn (MultiArrayView< 2, U, C1 > const &features, MultiArrayView< 2, U2, C2 > const &response, Visitor_t visitor, Split_t split, Stop_t stop, Random_t const &random) |
learn on data with custom config and random number generator | |
template<class U , class C1 , class U2 , class C2 , class Split_t , class Stop_t , class Visitor_t > | |
void | learn (MultiArrayView< 2, U, C1 > const &features, MultiArrayView< 2, U2, C2 > const &response, Visitor_t visitor, Split_t split, Stop_t stop) |
template<class U , class C1 , class U2 , class C2 , class Visitor_t > | |
void | learn (MultiArrayView< 2, U, C1 > const &features, MultiArrayView< 2, U2, C2 > const &labels, Visitor_t visitor) |
template<class U , class C1 , class U2 , class C2 , class Visitor_t , class Split_t > | |
void | learn (MultiArrayView< 2, U, C1 > const &features, MultiArrayView< 2, U2, C2 > const &labels, Visitor_t visitor, Split_t split) |
template<class U , class C1 , class U2 , class C2 > | |
void | learn (MultiArrayView< 2, U, C1 > const &features, MultiArrayView< 2, U2, C2 > const &labels) |
learn on data with default configuration | |
prediction | |
template<class U , class C , class Stop > | |
LabelType | predictLabel (MultiArrayView< 2, U, C >const &features, Stop &stop) const |
predict a label given a feature. | |
template<class U , class C > | |
LabelType | predictLabel (MultiArrayView< 2, U, C >const &features) |
template<class U , class C > | |
LabelType | predictLabel (MultiArrayView< 2, U, C > const &features, ArrayVectorView< double > prior) const |
predict a label with features and class priors | |
template<class U , class C1 , class T , class C2 > | |
void | predictLabels (MultiArrayView< 2, U, C1 >const &features, MultiArrayView< 2, T, C2 > &labels) const |
predict multiple labels with given features | |
template<class U , class C1 , class T , class C2 , class Stop > | |
void | predictLabels (MultiArrayView< 2, U, C1 >const &features, MultiArrayView< 2, T, C2 > &labels, Stop &stop) const |
template<class U , class C1 , class T , class C2 , class Stop > | |
void | predictProbabilities (MultiArrayView< 2, U, C1 >const &features, MultiArrayView< 2, T, C2 > &prob, Stop &stop) const |
predict the class probabilities for multiple labels | |
template<class T1 , class T2 , class C > | |
void | predictProbabilities (OnlinePredictionSet< T1 > &predictionSet, MultiArrayView< 2, T2, C > &prob) |
template<class U , class C1 , class T , class C2 > | |
void | predictProbabilities (MultiArrayView< 2, U, C1 >const &features, MultiArrayView< 2, T, C2 > &prob) const |
predict the class probabilities for multiple labels | |
template<class U , class C1 , class T , class C2 > | |
void | predictRaw (MultiArrayView< 2, U, C1 >const &features, MultiArrayView< 2, T, C2 > &prob) const |
Protected Attributes | |
MultiArray< 2, double > | garbage_prediction_ |
Random Forest class
<PreprocessorTag | = ClassificationTag> Class used to preprocess the input while learning and predicting. Currently available: ClassificationTag and RegressionTag. It is recommended to use Splitfunctor::Preprocessor_t when using custom split functors, as they may need the data to be in a different format. |
Simple usage for classification (regression is not yet supported); see RandomForest::learn() as well as RandomForestOptions() for additional options.
using namespace vigra;
using namespace rf;
typedef xxx feature_t; // replace xxx with whichever type
typedef yyy label_t;   // likewise
// allocate the training data
MultiArrayView<2, feature_t> f = get_training_features();
MultiArrayView<2, label_t> l = get_training_labels();
RandomForest<> rf;
// construct visitor to calculate out-of-bag error
visitors::OOB_Error oob_v;
// perform training
rf.learn(f, l, visitors::create_visitor(oob_v));
std::cout << "the out-of-bag error is: " << oob_v.oob_breiman << "\n";
// get features for new data to be used for prediction
MultiArrayView<2, feature_t> pf = get_features();
// allocate space for the response (pf.shape(0) is the number of samples)
MultiArrayView<2, label_t> prediction(pf.shape(0), 1);
MultiArrayView<2, double> prob(pf.shape(0), rf.class_count());
// perform prediction on new data
rf.predictLabels(pf, prediction);
rf.predictProbabilities(pf, prob);
Additional information such as Variable Importance measures are accessed via Visitors defined in rf::visitors. Have a look at rf::split for other splitting methods.
RandomForest | ( | Options_t const & | options = Options_t() , |
ProblemSpec_t const & | ext_param = ProblemSpec_t() |
||
) |
default constructor
options | general options to the Random Forest. Must be of Type Options_t |
ext_param | problem specific values that can be supplied additionally (class weights, labels, etc.) |
RandomForest | ( | int | treeCount, |
TopologyIterator | topology_begin, | ||
ParameterIterator | parameter_begin, | ||
ProblemSpec_t const & | problem_spec, | ||
Options_t const & | options = Options_t() |
||
) |
Create RF from external source.
treeCount | Number of trees to add. |
topology_begin | Iterator to a Container where the topology_ data of the trees are stored. Iterator should support at least treeCount forward iterations (i.e. topology_end - topology_begin >= treeCount). |
parameter_begin | iterator to a Container where the parameters_ data of the trees are stored. Iterator should support at least treeCount forward iterations. |
problem_spec | Extrinsic parameters that specify the problem e.g. ClassCount, featureCount etc. |
options | (optional) specify options used to train the original Random forest. This parameter is not used anywhere during prediction and thus is optional. |
ProblemSpec_t const& ext_param | ( | ) | const |
return external parameters for viewing
void set_ext_param | ( | ProblemSpec_t const & | in | ) |
set external parameters
in | external parameters to be set |
Set external parameters explicitly. If the Random Forest has not been trained, the preprocessor will either ignore filling values set this way or will throw an exception if values specified manually do not match the values calculated during the preparation step.
Options_t& set_options | ( | ) |
access random forest options
int column_count | ( | ) | const |
return number of features used while training.
Deprecated. Use feature_count() instead.
MultiArray<2, double> garbage_prediction_ [mutable, protected] |
optimisation for predictLabels
© Ullrich Köthe (ullrich.koethe@iwr.uni-heidelberg.de) |
html generated using doxygen and Python
|