113 std::vector<ExampleIndex>& examples,
114 std::vector<LabelType>& label_data,
115 const std::size_t max_depth,
118 const std::size_t num_of_examples = examples.size();
119 if (num_of_examples == 0) {
121 "Reached invalid point in decision tree training: Number of examples is 0!\n");
125 if (max_depth == 0) {
126 stats_estimator_->computeAndSetNodeStats(data_set_, examples, label_data, node);
130 if (examples.size() < min_examples_for_split_) {
131 stats_estimator_->computeAndSetNodeStats(data_set_, examples, label_data, node);
135 if (random_features_at_split_node_) {
137 feature_handler_->createRandomFeatures(num_of_features_, features);
140 std::vector<float> feature_results;
141 std::vector<unsigned char> flags;
143 feature_results.reserve(num_of_examples);
144 flags.reserve(num_of_examples);
147 int best_feature_index = -1;
148 float best_feature_threshold = 0.0f;
149 float best_feature_information_gain = 0.0f;
151 const std::size_t num_of_features = features.size();
152 for (std::size_t feature_index = 0; feature_index < num_of_features;
155 feature_handler_->evaluateFeature(
156 features[feature_index], data_set_, examples, feature_results, flags);
159 if (!thresholds_.empty()) {
162 for (
const float& threshold : thresholds_) {
164 const float information_gain = stats_estimator_->computeInformationGain(
165 data_set_, examples, label_data, feature_results, flags, threshold);
167 if (information_gain > best_feature_information_gain) {
168 best_feature_information_gain = information_gain;
169 best_feature_index =
static_cast<int>(feature_index);
170 best_feature_threshold = threshold;
175 std::vector<float> thresholds;
176 thresholds.reserve(num_of_thresholds_);
177 createThresholdsUniform(num_of_thresholds_, feature_results, thresholds);
181 for (std::size_t threshold_index = 0; threshold_index < num_of_thresholds_;
183 const float threshold = thresholds[threshold_index];
186 const float information_gain = stats_estimator_->computeInformationGain(
187 data_set_, examples, label_data, feature_results, flags, threshold);
189 if (information_gain > best_feature_information_gain) {
190 best_feature_information_gain = information_gain;
191 best_feature_index =
static_cast<int>(feature_index);
192 best_feature_threshold = threshold;
198 if (best_feature_index == -1) {
199 stats_estimator_->computeAndSetNodeStats(data_set_, examples, label_data, node);
204 std::vector<unsigned char> branch_indices;
205 branch_indices.reserve(num_of_examples);
207 feature_handler_->evaluateFeature(
208 features[best_feature_index], data_set_, examples, feature_results, flags);
210 stats_estimator_->computeBranchIndices(
211 feature_results, flags, best_feature_threshold, branch_indices);
214 stats_estimator_->computeAndSetNodeStats(data_set_, examples, label_data, node);
218 const std::size_t num_of_branches = stats_estimator_->getNumOfBranches();
220 std::vector<std::size_t> branch_counts(num_of_branches, 0);
221 for (std::size_t example_index = 0; example_index < num_of_examples;
223 ++branch_counts[branch_indices[example_index]];
226 node.feature = features[best_feature_index];
227 node.threshold = best_feature_threshold;
228 node.sub_nodes.resize(num_of_branches);
230 for (std::size_t branch_index = 0; branch_index < num_of_branches; ++branch_index) {
231 if (branch_counts[branch_index] == 0) {
232 NodeType branch_node;
233 stats_estimator_->computeAndSetNodeStats(
234 data_set_, examples, label_data, branch_node);
237 node.sub_nodes[branch_index] = branch_node;
242 std::vector<LabelType> branch_labels;
243 std::vector<ExampleIndex> branch_examples;
244 branch_labels.reserve(branch_counts[branch_index]);
245 branch_examples.reserve(branch_counts[branch_index]);
247 for (std::size_t example_index = 0; example_index < num_of_examples;
249 if (branch_indices[example_index] == branch_index) {
250 branch_examples.push_back(examples[example_index]);
251 branch_labels.push_back(label_data[example_index]);
255 trainDecisionTreeNode(features,
259 node.sub_nodes[branch_index]);