forked from imbs-hl/ranger
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathTree.h
181 lines (131 loc) · 5.26 KB
/
Tree.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
/*-------------------------------------------------------------------------------
This file is part of Ranger.
Ranger is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
Ranger is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with Ranger. If not, see <http://www.gnu.org/licenses/>.
Written by:
Marvin N. Wright
Institut für Medizinische Biometrie und Statistik
Universität zu Lübeck
Ratzeburger Allee 160
23562 Lübeck
Germany
http://www.imbs-luebeck.de
#-------------------------------------------------------------------------------*/
#ifndef TREE_H_
#define TREE_H_
#include <vector>
#include <random>
#include <iostream>
#include "globals.h"
#include "Data.h"
// Abstract base class for a single tree (part of Ranger).
// Declares the common entry points grow(), predict(), computePermutationImportance() and
// appendToFile(), plus the state they operate on. The pure virtual *Internal() hooks
// (and splitNodeInternal / computePredictionAccuracyInternal) must be supplied by subclasses.
// Copying and assignment are disabled via DISALLOW_COPY_AND_ASSIGN in the private section.
class Tree {
public:
Tree();
// Create from loaded forest: takes the child node IDs, split variable IDs and split values
// of an already grown tree (see the members of the same names below).
Tree(std::vector<std::vector<size_t>>& child_nodeIDs, std::vector<size_t>& split_varIDs,
std::vector<double>& split_values);
virtual ~Tree();
// Configure the tree before use. Except for `seed`, every parameter corresponds by name to a
// member field declared in the protected section; see the comments there for their meaning.
// NOTE(review): `seed` presumably seeds random_number_generator - confirm in the implementation.
void init(Data* data, uint mtry, size_t dependent_varID, size_t num_samples, uint seed,
std::vector<size_t>* deterministic_varIDs, std::vector<size_t>* split_select_varIDs,
std::vector<double>* split_select_weights, ImportanceMode importance_mode, uint min_node_size,
bool sample_with_replacement, bool memory_saving_splitting, SplitRule splitrule,
std::vector<double>* case_weights, bool keep_inbag, double sample_fraction, double alpha, double minprop,
bool holdout, uint num_random_splits);
// Subclass-specific part of initialization (pure virtual).
virtual void initInternal() = 0;
// Grow the tree. `variable_importance` is passed as a pointer;
// presumably it is kept in the member of the same name and updated while splitting - confirm.
void grow(std::vector<double>* variable_importance);
// Predict for the samples in `prediction_data`.
// NOTE(review): presumably fills prediction_terminal_nodeIDs, and with `oob_prediction` set
// only the OOB samples are used - confirm in the implementation.
void predict(const Data* prediction_data, bool oob_prediction);
// Compute permutation importance; results go to the two forest-level output vectors.
// NOTE(review): whether values are added to or overwrite the outputs is not visible here.
void computePermutationImportance(std::vector<double>* forest_importance, std::vector<double>* forest_variance);
// Write the tree to an output file stream; appendToFileInternal is the subclass-specific part.
// NOTE(review): std::ofstream appears only as a reference parameter here; this header does not
// include <fstream> directly.
void appendToFile(std::ofstream& file);
virtual void appendToFileInternal(std::ofstream& file) = 0;
// Read-only accessors; each returns the member of the corresponding name
// (const reference for the vectors, by value for the OOB count).
const std::vector<std::vector<size_t> >& getChildNodeIDs() const {
return child_nodeIDs;
}
const std::vector<double>& getSplitValues() const {
return split_values;
}
const std::vector<size_t>& getSplitVarIDs() const {
return split_varIDs;
}
const std::vector<size_t>& getOobSampleIDs() const {
return oob_sampleIDs;
}
size_t getNumSamplesOob() const {
return num_samples_oob;
}
const std::vector<size_t>& getInbagCounts() const {
return inbag_counts;
}
protected:
// Write the candidate split variable IDs for one split into `result` (output parameter).
// NOTE(review): presumably draws mtry variables using the split_select_* / deterministic
// settings below - confirm in the implementation.
void createPossibleSplitVarSubset(std::vector<size_t>& result);
// Try to split node `nodeID`; splitNodeInternal is the subclass-specific part and receives the
// candidate variable IDs. NOTE(review): meaning of the bool return value is not visible here.
bool splitNode(size_t nodeID);
virtual bool splitNodeInternal(size_t nodeID, std::vector<size_t>& possible_split_varIDs) = 0;
// Append a new empty node; createEmptyNodeInternal is the subclass-specific part.
void createEmptyNode();
virtual void createEmptyNodeInternal() = 0;
// Helpers for permutation importance.
// NOTE(review): dropDownSamplePermuted presumably returns the terminal node reached by
// `sampleID` when the value of `permuted_varID` is taken from `permuted_sampleID` - confirm.
size_t dropDownSamplePermuted(size_t permuted_varID, size_t sampleID, size_t permuted_sampleID);
void permuteAndPredictOobSamples(size_t permuted_varID, std::vector<size_t>& permutations);
// Subclass-specific prediction accuracy measure (pure virtual).
virtual double computePredictionAccuracyInternal() = 0;
// Sampling variants: with / without replacement, each unweighted or weighted.
// NOTE(review): presumably chosen from sample_with_replacement and case_weights, and
// presumably fill the per-node sampleIDs, oob_sampleIDs and inbag_counts - confirm.
void bootstrap();
void bootstrapWithoutReplacement();
void bootstrapWeighted();
void bootstrapWithoutReplacementWeighted();
// Subclass-specific cleanup (pure virtual).
virtual void cleanUpInternal() = 0;
// ID of the dependent variable - presumably a column index into `data`; TODO confirm
size_t dependent_varID;
// NOTE(review): presumably the number of candidate variables tried per split - confirm
uint mtry;
// Number of samples (all samples, not only inbag for this tree)
size_t num_samples;
// Number of OOB samples
size_t num_samples_oob;
// Minimum node size to split, like in original RF nodes of smaller size can be produced
uint min_node_size;
// Variable selection settings, handed in as raw pointers through init().
// NOTE(review): ownership is not visible in this header - presumably owned by the caller.
// Deterministic variables are always selected
std::vector<size_t>* deterministic_varIDs;
// Weight vector for selecting possible split variables, one weight between 0 (never select) and 1 (always select) for each variable
// NOTE(review): split_select_varIDs presumably lists the variable IDs the weights refer to - confirm
std::vector<size_t>* split_select_varIDs;
std::vector<double>* split_select_weights;
// Bootstrap weights
std::vector<double>* case_weights;
// Splitting variable for each node
std::vector<size_t> split_varIDs;
// Value to split at for each node, for now only binary split
// For terminal nodes the prediction value is saved here
std::vector<double> split_values;
// Vector of left and right child node IDs, 0 for no child
std::vector<std::vector<size_t>> child_nodeIDs;
// For each node a vector with IDs of samples in node
std::vector<std::vector<size_t>> sampleIDs;
// IDs of OOB individuals, sorted
std::vector<size_t> oob_sampleIDs;
// Holdout mode
bool holdout;
// Inbag counts (keep_inbag is the flag passed to init(); presumably controls whether
// inbag_counts is retained - confirm)
bool keep_inbag;
std::vector<size_t> inbag_counts;
// Random number generator
std::mt19937_64 random_number_generator;
// Pointer to original data
Data* data;
// Variable importance for all variables
std::vector<double>* variable_importance;
ImportanceMode importance_mode;
// Terminal nodeIDs for prediction samples
// When growing here the OOB set is used
std::vector<size_t> prediction_terminal_nodeIDs;
// Sampling and splitting options as passed to init().
// NOTE(review): the exact semantics of sample_fraction, alpha, minprop and num_random_splits
// depend on the split rule implementations, which are not visible in this header.
bool sample_with_replacement;
double sample_fraction;
bool memory_saving_splitting;
SplitRule splitrule;
double alpha;
double minprop;
uint num_random_splits;
private:
// Disables copy construction and assignment (macro presumably defined in globals.h).
DISALLOW_COPY_AND_ASSIGN(Tree);
};
#endif /* TREE_H_ */