Point Cloud Library (PCL) 1.13.1
Loading...
Searching...
No Matches
regression_variance_stats_estimator.h
1/*
2 * Software License Agreement (BSD License)
3 *
4 * Point Cloud Library (PCL) - www.pointclouds.org
5 * Copyright (c) 2010-2011, Willow Garage, Inc.
6 *
7 * All rights reserved.
8 *
9 * Redistribution and use in source and binary forms, with or without
10 * modification, are permitted provided that the following conditions
11 * are met:
12 *
13 * * Redistributions of source code must retain the above copyright
14 * notice, this list of conditions and the following disclaimer.
15 * * Redistributions in binary form must reproduce the above
16 * copyright notice, this list of conditions and the following
17 * disclaimer in the documentation and/or other materials provided
18 * with the distribution.
19 * * Neither the name of Willow Garage, Inc. nor the names of its
20 * contributors may be used to endorse or promote products derived
21 * from this software without specific prior written permission.
22 *
23 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
24 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
25 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
26 * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
27 * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
28 * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
29 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
30 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
31 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
32 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
33 * ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
34 * POSSIBILITY OF SUCH DAMAGE.
35 *
36 */
37
38#pragma once
39
40#include <pcl/common/common.h>
41#include <pcl/ml/branch_estimator.h>
42#include <pcl/ml/stats_estimator.h>
43
44#include <istream>
45#include <ostream>
46
47namespace pcl {
48
/** Node for regression trees which optimize variance.
 *
 * Holds the split feature and threshold, the label statistics (value and
 * variance) and the child nodes. A node is written to / read from a stream as
 * raw bytes in this order: feature, threshold, value, variance, number of
 * children (as int), followed by every child in the same layout.
 */
template <class FeatureType, class LabelType>
class PCL_EXPORTS RegressionVarianceNode {
public:
  /** Constructor. Sets threshold, value and variance to zero. */
  RegressionVarianceNode() : threshold(0), value(0), variance(0), sub_nodes() {}

  /** Serializes the node and, recursively, all of its children.
   *
   * \param[out] stream the destination for the serialization
   */
  inline void
  serialize(std::ostream& stream) const
  {
    feature.serialize(stream);

    stream.write(reinterpret_cast<const char*>(&threshold), sizeof(threshold));

    stream.write(reinterpret_cast<const char*>(&value), sizeof(value));
    stream.write(reinterpret_cast<const char*>(&variance), sizeof(variance));

    // The child count is stored as int; deserialize() reads it back as int.
    const int num_of_sub_nodes = static_cast<int>(sub_nodes.size());
    stream.write(reinterpret_cast<const char*>(&num_of_sub_nodes),
                 sizeof(num_of_sub_nodes));
    for (int sub_node_index = 0; sub_node_index < num_of_sub_nodes; ++sub_node_index) {
      sub_nodes[sub_node_index].serialize(stream);
    }
  }

  /** Deserializes a node and, recursively, all of its children.
   *
   * If the child count cannot be read or is negative, the node ends up with no
   * children and the stream's failbit is set, so callers can detect the
   * problem by checking the stream state.
   *
   * \param[in] stream the source for the deserialization
   */
  inline void
  deserialize(std::istream& stream)
  {
    feature.deserialize(stream);

    stream.read(reinterpret_cast<char*>(&threshold), sizeof(threshold));

    stream.read(reinterpret_cast<char*>(&value), sizeof(value));
    stream.read(reinterpret_cast<char*>(&variance), sizeof(variance));

    // Initialized so that a failed read never leaves an indeterminate count.
    int num_of_sub_nodes = 0;
    stream.read(reinterpret_cast<char*>(&num_of_sub_nodes), sizeof(num_of_sub_nodes));

    // A negative count would turn into a huge unsigned size in resize().
    if (!stream || num_of_sub_nodes < 0) {
      stream.setstate(std::ios_base::failbit);
      sub_nodes.clear();
      return;
    }

    sub_nodes.resize(static_cast<std::size_t>(num_of_sub_nodes));
    for (int sub_node_index = 0; sub_node_index < num_of_sub_nodes; ++sub_node_index) {
      sub_nodes[sub_node_index].deserialize(stream);
    }
  }

public:
  /** The feature associated with the node. */
  FeatureType feature;

  /** The threshold applied on the feature response. */
  float threshold;

  /** The label value of this node. */
  LabelType value;

  /** The variance of the labels that ended up at this node during training. */
  LabelType variance;

  /** The child nodes. */
  std::vector<RegressionVarianceNode> sub_nodes;
};
120
121/** Statistics estimator for regression trees which optimizes variance. */
122template <class LabelDataType, class NodeType, class DataSet, class ExampleIndex>
124: public pcl::StatsEstimator<LabelDataType, NodeType, DataSet, ExampleIndex> {
125public:
126 /** Constructor. */
128 : branch_estimator_(branch_estimator)
129 {}
130
131 /** Returns the number of branches the corresponding tree has. */
132 inline std::size_t
134 {
135 // return 2;
136 return branch_estimator_->getNumOfBranches();
137 }
138
139 /** Returns the label of the specified node.
140 *
141 * \param[in] node the node which label is returned
142 */
143 inline LabelDataType
144 getLabelOfNode(NodeType& node) const
145 {
146 return node.value;
147 }
148
149 /** Computes the information gain obtained by the specified threshold.
150 *
151 * \param[in] data_set the data set corresponding to the supplied result data
152 * \param[in] examples the examples used for extracting the supplied result data
153 * \param[in] label_data the label data corresponding to the specified examples
154 * \param[in] results the results computed using the specified examples
155 * \param[in] flags the flags corresponding to the results
156 * \param[in] threshold the threshold for which the information gain is computed
157 */
158 float
159 computeInformationGain(DataSet& data_set,
160 std::vector<ExampleIndex>& examples,
161 std::vector<LabelDataType>& label_data,
162 std::vector<float>& results,
163 std::vector<unsigned char>& flags,
164 const float threshold) const
165 {
166 const std::size_t num_of_examples = examples.size();
167 const std::size_t num_of_branches = getNumOfBranches();
168
169 // compute variance
170 std::vector<LabelDataType> sums(num_of_branches + 1, 0);
171 std::vector<LabelDataType> sqr_sums(num_of_branches + 1, 0);
172 std::vector<std::size_t> branch_element_count(num_of_branches + 1, 0);
173
174 for (std::size_t branch_index = 0; branch_index < num_of_branches; ++branch_index) {
175 branch_element_count[branch_index] = 1;
176 ++branch_element_count[num_of_branches];
177 }
178
179 for (std::size_t example_index = 0; example_index < num_of_examples;
180 ++example_index) {
181 unsigned char branch_index;
182 computeBranchIndex(
183 results[example_index], flags[example_index], threshold, branch_index);
184
185 LabelDataType label = label_data[example_index];
186
187 sums[branch_index] += label;
188 sums[num_of_branches] += label;
189
190 sqr_sums[branch_index] += label * label;
191 sqr_sums[num_of_branches] += label * label;
192
193 ++branch_element_count[branch_index];
194 ++branch_element_count[num_of_branches];
195 }
196
197 std::vector<float> variances(num_of_branches + 1, 0);
198 for (std::size_t branch_index = 0; branch_index < num_of_branches + 1;
199 ++branch_index) {
200 const float mean_sum =
201 static_cast<float>(sums[branch_index]) / branch_element_count[branch_index];
202 const float mean_sqr_sum = static_cast<float>(sqr_sums[branch_index]) /
203 branch_element_count[branch_index];
204 variances[branch_index] = mean_sqr_sum - mean_sum * mean_sum;
205 }
206
207 float information_gain = variances[num_of_branches];
208 for (std::size_t branch_index = 0; branch_index < num_of_branches; ++branch_index) {
209 // const float weight = static_cast<float>(sums[branchIndex]) /
210 // sums[numOfBranches];
211 const float weight = static_cast<float>(branch_element_count[branch_index]) /
212 static_cast<float>(branch_element_count[num_of_branches]);
213 information_gain -= weight * variances[branch_index];
214 }
215
216 return information_gain;
217 }
218
219 /** Computes the branch indices for all supplied results.
220 *
221 * \param[in] results the results the branch indices will be computed for
222 * \param[in] flags the flags corresponding to the specified results
223 * \param[in] threshold the threshold used to compute the branch indices
224 * \param[out] branch_indices the destination for the computed branch indices
225 */
226 void
227 computeBranchIndices(std::vector<float>& results,
228 std::vector<unsigned char>& flags,
229 const float threshold,
230 std::vector<unsigned char>& branch_indices) const
231 {
232 const std::size_t num_of_results = results.size();
233 const std::size_t num_of_branches = getNumOfBranches();
234
235 branch_indices.resize(num_of_results);
236 for (std::size_t result_index = 0; result_index < num_of_results; ++result_index) {
237 unsigned char branch_index;
238 computeBranchIndex(
239 results[result_index], flags[result_index], threshold, branch_index);
240 branch_indices[result_index] = branch_index;
241 }
242 }
243
244 /** Computes the branch index for the specified result.
245 *
246 * \param[in] result the result the branch index will be computed for
247 * \param[in] flag the flag corresponding to the specified result
248 * \param[in] threshold the threshold used to compute the branch index
249 * \param[out] branch_index the destination for the computed branch index
250 */
251 inline void
252 computeBranchIndex(const float result,
253 const unsigned char flag,
254 const float threshold,
255 unsigned char& branch_index) const
256 {
257 branch_estimator_->computeBranchIndex(result, flag, threshold, branch_index);
258 // branch_index = (result > threshold) ? 1 : 0;
259 }
260
261 /** Computes and sets the statistics for a node.
262 *
263 * \param[in] data_set the data set which is evaluated
264 * \param[in] examples the examples which define which parts of the data set are use
265 * for evaluation
266 * \param[in] label_data the label_data corresponding to the examples
267 * \param[out] node the destination node for the statistics
268 */
269 void
270 computeAndSetNodeStats(DataSet& data_set,
271 std::vector<ExampleIndex>& examples,
272 std::vector<LabelDataType>& label_data,
273 NodeType& node) const
274 {
275 const std::size_t num_of_examples = examples.size();
276
277 LabelDataType sum = 0.0f;
278 LabelDataType sqr_sum = 0.0f;
279 for (std::size_t example_index = 0; example_index < num_of_examples;
280 ++example_index) {
281 const LabelDataType label = label_data[example_index];
282
283 sum += label;
284 sqr_sum += label * label;
285 }
286
287 sum /= num_of_examples;
288 sqr_sum /= num_of_examples;
289
290 const float variance = sqr_sum - sum * sum;
291
292 node.value = sum;
293 node.variance = variance;
294 }
295
296 /** Generates code for branch index computation.
297 *
298 * \param[in] node the node for which code is generated
299 * \param[out] stream the destination for the generated code
300 */
301 void
302 generateCodeForBranchIndexComputation(NodeType& node, std::ostream& stream) const
303 {
304 stream << "ERROR: RegressionVarianceStatsEstimator does not implement "
305 "generateCodeForBranchIndex(...)";
306 }
307
308 /** Generates code for label output.
309 *
310 * \param[in] node the node for which code is generated
311 * \param[out] stream the destination for the generated code
312 */
313 void
314 generateCodeForOutput(NodeType& node, std::ostream& stream) const
315 {
316 stream << "ERROR: RegressionVarianceStatsEstimator does not implement "
317 "generateCodeForBranchIndex(...)";
318 }
319
320private:
321 /// The branch estimator
322 pcl::BranchEstimator* branch_estimator_;
323};
324
325} // namespace pcl
Interface for branch estimators.
Node for regression trees which optimize variance.
void serialize(std::ostream &stream) const
Serializes the node to the specified stream.
LabelType variance
The variance of the labels that ended up at this node during training.
void deserialize(std::istream &stream)
Deserializes a node from the specified stream.
float threshold
The threshold applied on the feature response.
FeatureType feature
The feature associated with the node.
LabelType value
The label value of this node.
std::vector< RegressionVarianceNode > sub_nodes
The child nodes.
Statistics estimator for regression trees which optimizes variance.
void generateCodeForOutput(NodeType &node, std::ostream &stream) const
Generates code for label output.
void computeAndSetNodeStats(DataSet &data_set, std::vector< ExampleIndex > &examples, std::vector< LabelDataType > &label_data, NodeType &node) const
Computes and sets the statistics for a node.
void computeBranchIndex(const float result, const unsigned char flag, const float threshold, unsigned char &branch_index) const
Computes the branch index for the specified result.
LabelDataType getLabelOfNode(NodeType &node) const
Returns the label of the specified node.
void computeBranchIndices(std::vector< float > &results, std::vector< unsigned char > &flags, const float threshold, std::vector< unsigned char > &branch_indices) const
Computes the branch indices for all supplied results.
RegressionVarianceStatsEstimator(BranchEstimator *branch_estimator)
Constructor.
float computeInformationGain(DataSet &data_set, std::vector< ExampleIndex > &examples, std::vector< LabelDataType > &label_data, std::vector< float > &results, std::vector< unsigned char > &flags, const float threshold) const
Computes the information gain obtained by the specified threshold.
std::size_t getNumOfBranches() const
Returns the number of branches the corresponding tree has.
void generateCodeForBranchIndexComputation(NodeType &node, std::ostream &stream) const
Generates code for branch index computation.
Class interface for gathering statistics for decision tree learning.
Define standard C methods and C++ classes that are common to all methods.