forked from imbs-hl/ranger
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathForestProbability.h
86 lines (68 loc) · 2.86 KB
/
ForestProbability.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
/*-------------------------------------------------------------------------------
This file is part of Ranger.
Ranger is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
Ranger is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with Ranger. If not, see <http://www.gnu.org/licenses/>.
Written by:
Marvin N. Wright
Institut für Medizinische Biometrie und Statistik
Universität zu Lübeck
Ratzeburger Allee 160
23562 Lübeck
Germany
http://www.imbs-luebeck.de
#-------------------------------------------------------------------------------*/
#ifndef FORESTPROBABILITY_H_
#define FORESTPROBABILITY_H_
#include <map>
#include <utility>
#include <vector>
#include "globals.h"
#include "Forest.h"
#include "TreeProbability.h"
class ForestProbability: public Forest {
public:
ForestProbability();
virtual ~ForestProbability();
void loadForest(size_t dependent_varID, size_t num_trees,
std::vector<std::vector<std::vector<size_t>> >& forest_child_nodeIDs,
std::vector<std::vector<size_t>>& forest_split_varIDs, std::vector<std::vector<double>>& forest_split_values,
std::vector<double>& class_values, std::vector<std::vector<std::vector<double>>>& forest_terminal_class_counts, std::vector<bool>& is_ordered_variable);
std::vector<std::vector<std::vector<double>>> getTerminalClassCounts() {
std::vector<std::vector<std::vector<double>>> result;
result.reserve(num_trees);
for (Tree* tree : trees) {
TreeProbability* temp = (TreeProbability*) tree;
result.push_back(temp->getTerminalClassCounts());
}
return result;
}
const std::vector<double>& getClassValues() const {
return class_values;
}
protected:
void initInternal(std::string status_variable_name);
void growInternal();
void predictInternal();
void computePredictionErrorInternal();
void writeOutputInternal();
void writeConfusionFile();
void writePredictionFile();
void saveToFileInternal(std::ofstream& outfile);
void loadFromFileInternal(std::ifstream& infile);
// Classes of the dependent variable and classIDs for responses
std::vector<double> class_values;
std::vector<uint> response_classIDs;
// Table with classifications and true classes
std::map<std::pair<double, double>, size_t> classification_table;
private:
DISALLOW_COPY_AND_ASSIGN(ForestProbability);
};
#endif /* FORESTPROBABILITY_H_ */