Point Cloud Library (PCL) 1.13.1
Loading...
Searching...
No Matches
decision_tree_trainer.h
/*
 * Software License Agreement (BSD License)
 *
 * Point Cloud Library (PCL) - www.pointclouds.org
 * Copyright (c) 2010-2011, Willow Garage, Inc.
 *
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 *
 *  * Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 *  * Redistributions in binary form must reproduce the above
 *    copyright notice, this list of conditions and the following
 *    disclaimer in the documentation and/or other materials provided
 *    with the distribution.
 *  * Neither the name of Willow Garage, Inc. nor the names of its
 *    contributors may be used to endorse or promote products derived
 *    from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
 * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
 * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
 * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
 * ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 *
 */
37
38#pragma once
39
40#include <pcl/common/common.h>
41#include <pcl/ml/dt/decision_tree.h>
42#include <pcl/ml/dt/decision_tree_data_provider.h>
43#include <pcl/ml/feature_handler.h>
44#include <pcl/ml/stats_estimator.h>
45
46#include <vector>
47
48namespace pcl {
49
50/** Trainer for decision trees. */
51template <class FeatureType,
52 class DataSet,
53 class LabelType,
54 class ExampleIndex,
55 class NodeType>
56class PCL_EXPORTS DecisionTreeTrainer {
57
58public:
59 /** Constructor. */
61
62 /** Destructor. */
64
65 /** Sets the feature handler used to create and evaluate features.
66 *
67 * \param[in] feature_handler the feature handler
68 */
69 inline void
72 {
73 feature_handler_ = &feature_handler;
74 }
75
76 /** Sets the object for estimating the statistics for tree nodes.
77 *
78 * \param[in] stats_estimator the statistics estimator
79 */
80 inline void
83 {
84 stats_estimator_ = &stats_estimator;
85 }
86
87 /** Sets the maximum depth of the learned tree.
88 *
89 * \param[in] max_tree_depth maximum depth of the learned tree
90 */
91 inline void
92 setMaxTreeDepth(const std::size_t max_tree_depth)
93 {
94 max_tree_depth_ = max_tree_depth;
95 }
96
97 /** Sets the number of features used to find optimal decision features.
98 *
99 * \param[in] num_of_features the number of features
100 */
101 inline void
102 setNumOfFeatures(const std::size_t num_of_features)
103 {
104 num_of_features_ = num_of_features;
105 }
106
107 /** Sets the number of thresholds tested for finding the optimal decision
108 * threshold on the feature responses.
109 *
110 * \param[in] num_of_threshold the number of thresholds
111 */
112 inline void
113 setNumOfThresholds(const std::size_t num_of_threshold)
114 {
115 num_of_thresholds_ = num_of_threshold;
116 }
117
118 /** Sets the input data set used for training.
119 *
120 * \param[in] data_set the data set used for training
121 */
122 inline void
123 setTrainingDataSet(DataSet& data_set)
124 {
125 data_set_ = data_set;
126 }
127
128 /** Example indices that specify the data used for training.
129 *
130 * \param[in] examples the examples
131 */
132 inline void
133 setExamples(std::vector<ExampleIndex>& examples)
134 {
135 examples_ = examples;
136 }
137
138 /** Sets the label data corresponding to the example data.
139 *
140 * \param[in] label_data the label data
141 */
142 inline void
143 setLabelData(std::vector<LabelType>& label_data)
144 {
145 label_data_ = label_data;
146 }
147
148 /** Sets the minimum number of examples to continue growing a tree.
149 *
150 * \param[in] n number of examples
151 */
152 inline void
154 {
155 min_examples_for_split_ = n;
156 }
157
158 /** Specify the thresholds to be used when evaluating features.
159 *
160 * \param[in] thres the threshold values
161 */
162 void
163 setThresholds(std::vector<float>& thres)
164 {
165 thresholds_ = thres;
166 }
167
168 /** Specify the data provider.
169 *
170 * \param[in] dtdp the data provider that should implement getDatasetAndLabels()
171 * function
172 */
173 void
175 typename pcl::DecisionTreeTrainerDataProvider<FeatureType,
176 DataSet,
177 LabelType,
178 ExampleIndex,
179 NodeType>::Ptr& dtdp)
180 {
181 decision_tree_trainer_data_provider_ = dtdp;
182 }
183
184 /** Specify if the features are randomly generated at each split node.
185 *
186 * \param[in] b do it or not
187 */
188 void
190 {
191 random_features_at_split_node_ = b;
192 }
193
194 /** Trains a decision tree using the set training data and settings.
195 *
196 * \param[out] tree destination for the trained tree
197 */
198 void
199 train(DecisionTree<NodeType>& tree);
200
201protected:
202 /** Trains a decision tree node from the specified features, label data, and
203 * examples.
204 *
205 * \param[in] features the feature pool used for training
206 * \param[in] examples the examples used for training
207 * \param[in] label_data the label data corresponding to the examples
208 * \param[in] max_depth the maximum depth of the remaining tree
209 * \param[out] node the resulting node
210 */
211 void
212 trainDecisionTreeNode(std::vector<FeatureType>& features,
213 std::vector<ExampleIndex>& examples,
214 std::vector<LabelType>& label_data,
215 std::size_t max_depth,
216 NodeType& node);
217
218 /** Creates uniformly distributed thresholds over the range of the supplied
219 * values.
220 *
221 * \param[in] num_of_thresholds the number of thresholds to create
222 * \param[in] values the values for estimating the expected value range
223 * \param[out] thresholds the resulting thresholds
224 */
225 static void
226 createThresholdsUniform(const std::size_t num_of_thresholds,
227 std::vector<float>& values,
228 std::vector<float>& thresholds);
229
230private:
231 /** Maximum depth of the learned tree. */
232 std::size_t max_tree_depth_;
233 /** Number of features used to find optimal decision features. */
234 std::size_t num_of_features_;
235 /** Number of thresholds. */
236 std::size_t num_of_thresholds_;
237
238 /** FeatureHandler instance, responsible for creating and evaluating features. */
240 /** StatsEstimator instance, responsible for gathering stats about a node. */
242
243 /** The training data set. */
244 DataSet data_set_;
245 /** The label data. */
246 std::vector<LabelType> label_data_;
247 /** The example data. */
248 std::vector<ExampleIndex> examples_;
249
250 /** Minimum number of examples to split a node. */
251 std::size_t min_examples_for_split_;
252 /** Thresholds to be used instead of generating uniform distributed thresholds. */
253 std::vector<float> thresholds_;
254 /** The data provider which is called before training a specific tree, if pointer is
255 * NULL, then data_set_ is used. */
256 typename pcl::DecisionTreeTrainerDataProvider<FeatureType,
257 DataSet,
258 LabelType,
259 ExampleIndex,
260 NodeType>::Ptr
261 decision_tree_trainer_data_provider_;
262 /** If true, random features are generated at each node, otherwise, at start of
263 * training the tree */
264 bool random_features_at_split_node_;
265};
266
267} // namespace pcl
268
269#include <pcl/ml/impl/dt/decision_tree_trainer.hpp>
Class representing a decision tree.
Trainer for decision trees.
void setRandomFeaturesAtSplitNode(bool b)
Specify if the features are randomly generated at each split node.
void setStatsEstimator(pcl::StatsEstimator< LabelType, NodeType, DataSet, ExampleIndex > &stats_estimator)
Sets the object for estimating the statistics for tree nodes.
void setDecisionTreeDataProvider(typename pcl::DecisionTreeTrainerDataProvider< FeatureType, DataSet, LabelType, ExampleIndex, NodeType >::Ptr &dtdp)
Specify the data provider.
void setMaxTreeDepth(const std::size_t max_tree_depth)
Sets the maximum depth of the learned tree.
void setNumOfThresholds(const std::size_t num_of_threshold)
Sets the number of thresholds tested for finding the optimal decision threshold on the feature respon...
void setFeatureHandler(pcl::FeatureHandler< FeatureType, DataSet, ExampleIndex > &feature_handler)
Sets the feature handler used to create and evaluate features.
void setTrainingDataSet(DataSet &data_set)
Sets the input data set used for training.
void setMinExamplesForSplit(std::size_t n)
Sets the minimum number of examples to continue growing a tree.
virtual ~DecisionTreeTrainer()
Destructor.
void setThresholds(std::vector< float > &thres)
Specify the thresholds to be used when evaluating features.
void setLabelData(std::vector< LabelType > &label_data)
Sets the label data corresponding to the example data.
void setExamples(std::vector< ExampleIndex > &examples)
Example indices that specify the data used for training.
void setNumOfFeatures(const std::size_t num_of_features)
Sets the number of features used to find optimal decision features.
Utility class interface which is used for creating and evaluating features.
Class interface for gathering statistics for decision tree learning.
Define standard C methods and C++ classes that are common to all methods.