forked from imbs-hl/ranger
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathData.h
235 lines (183 loc) · 6.39 KB
/
Data.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
/*-------------------------------------------------------------------------------
This file is part of ranger.
Copyright (c) [2014-2018] [Marvin N. Wright]
This software may be modified and distributed under the terms of the MIT license.
Please note that the C++ core of ranger is distributed under MIT license and the
R package "ranger" under GPL3 license.
#-------------------------------------------------------------------------------*/
#ifndef DATA_H_
#define DATA_H_
#include <vector>
#include <iostream>
#include <numeric>
#include <random>
#include <algorithm>
#include "globals.h"
namespace ranger {
/// Abstract base class for ranger's tabular input data.
///
/// Concrete subclasses provide element storage through get(), set() and
/// reserveMemory(). This base class additionally keeps:
///  - variable names and the row/column counts,
///  - an optional packed SNP block (snp_data), decoded by getSnp(),
///  - per-cell indices (index_data) and per-column value tables
///    (unique_data_values) used by getIndex() / getUniqueDataValue(),
///  - bookkeeping for corrected impurity importance, where a column ID
///    >= num_cols denotes a permuted copy of an original column.
///
/// Instances are non-copyable.
class Data {
public:
  Data();

  Data(const Data&) = delete;
  Data& operator=(const Data&) = delete;

  virtual ~Data() = default;

  /// Return the stored value at (row, col). Implemented by subclasses.
  virtual double get(size_t row, size_t col) const = 0;

  /// Map a variable name to its column ID (defined out of line).
  size_t getVariableID(const std::string& variable_name) const;

  /// Prepare the subclass's storage. Implemented by subclasses.
  virtual void reserveMemory() = 0;

  /// Store `value` at (row, col). Note the column-first argument order.
  /// `error` is an out-flag whose exact meaning is defined by the subclass.
  virtual void set(size_t col, size_t row, double value, bool& error) = 0;

  void addSnpData(unsigned char* snp_data, size_t num_cols_snp);

  bool loadFromFile(std::string filename);
  bool loadFromFileWhitespace(std::ifstream& input_file, std::string header_line);
  bool loadFromFileOther(std::ifstream& input_file, std::string header_line, char seperator);

  void getAllValues(std::vector<double>& all_values, std::vector<size_t>& sampleIDs, size_t varID) const;

  void getMinMaxValues(double& min, double&max, std::vector<size_t>& sampleIDs, size_t varID) const;

  /// Return the stored index for cell (row, col).
  ///
  /// If col >= num_cols it refers to a permuted column (corrected impurity
  /// importance): the column is mapped back to its original ID and the row is
  /// replaced by its permuted sample ID. Non-SNP columns are read from
  /// index_data, laid out column-major (col * num_rows + row). SNP columns are
  /// decoded from the packed SNP block by getSnp().
  size_t getIndex(size_t row, size_t col) const {
    // Remember the caller's column ID; getSnp needs it to pick the SNP order row.
    size_t col_permuted = col;
    if (col >= num_cols) {
      col = getUnpermutedVarID(col);
      row = getPermutedSampleID(row);
    }

    if (col < num_cols_no_snp) {
      return index_data[col * num_rows + row];
    } else {
      return getSnp(row, col, col_permuted);
    }
  }

  /// Decode one SNP value from the packed SNP block.
  ///
  /// @param row          sample row (already permuted by the caller if needed)
  /// @param col          unpermuted column ID, must be >= num_cols_no_snp
  /// @param col_permuted column ID as originally requested (may be >= num_cols)
  /// @return a value in {0, 1, 2}; decoded values > 2 are treated as 0, and the
  ///         result is remapped through snp_order when order_snps is set.
  size_t getSnp(size_t row, size_t col, size_t col_permuted) const {
    // Position of the value inside the packed block; four values share a byte.
    size_t idx = (col - num_cols_no_snp) * num_rows_rounded + row;
    // Extract the stored code. The -1 is because of GenABEL coding; a stored 0
    // wraps around to a huge size_t and is caught by the range check below.
    size_t result = ((snp_data[idx / 4] & mask[idx % 4]) >> offset[idx % 4]) - 1;

    // TODO: Better way to treat missing values?
    if (result > 2) {
      result = 0;
    }

    // Order SNPs
    if (order_snps) {
      if (col_permuted >= num_cols) {
        // Permuted SNP columns use rows of snp_order stored after the
        // unpermuted ones (see setSnpOrder / orderSnpLevels for the layout).
        result = snp_order[col_permuted + no_split_variables.size() - 2 * num_cols_no_snp][result];
      } else {
        result = snp_order[col - num_cols_no_snp][result];
      }
    }
    return result;
  }

  /// Return the value behind `index` for variable `varID`.
  /// Non-SNP columns look it up in unique_data_values; for SNP columns the
  /// index itself is the value.
  double getUniqueDataValue(size_t varID, size_t index) const {
    // Use permuted data for corrected impurity importance
    if (varID >= num_cols) {
      varID = getUnpermutedVarID(varID);
    }

    if (varID < num_cols_no_snp) {
      return unique_data_values[varID][index];
    } else {
      // For GWAS data the index is the value
      return static_cast<double>(index);
    }
  }

  /// Number of distinct values of variable `varID` (always 3 for SNP columns).
  size_t getNumUniqueDataValues(size_t varID) const {
    // Use permuted data for corrected impurity importance
    if (varID >= num_cols) {
      varID = getUnpermutedVarID(varID);
    }

    if (varID < num_cols_no_snp) {
      return unique_data_values[varID].size();
    } else {
      // For GWAS data 0,1,2
      return (3);
    }
  }

  void sort();

  void orderSnpLevels(std::string dependent_variable_name, bool corrected_importance);

  const std::vector<std::string>& getVariableNames() const {
    return variable_names;
  }
  size_t getNumCols() const {
    return num_cols;
  }
  size_t getNumRows() const {
    return num_rows;
  }

  /// Largest number of distinct values over all variables, with SNP data
  /// (values 0/1/2) contributing at least 3.
  size_t getMaxNumUniqueValues() const {
    if (snp_data == nullptr || max_num_unique_values > 3) {
      // If no snp data or one variable with more than 3 unique values, return that value
      return max_num_unique_values;
    } else {
      // If snp data and no variable with more than 3 unique values, return 3
      return 3;
    }
  }

  const std::vector<size_t>& getNoSplitVariables() const noexcept {
    return no_split_variables;
  }

  /// Register a variable that must not be used for splitting. The list is kept
  /// sorted ascending; getUnpermutedVarID relies on that ordering.
  void addNoSplitVariable(size_t varID) {
    no_split_variables.push_back(varID);
    std::sort(no_split_variables.begin(), no_split_variables.end());
  }

  std::vector<bool>& getIsOrderedVariable() noexcept {
    return is_ordered_variable;
  }

  /// Mark the named variables as unordered (false).
  /// Uses resize(num_cols, true): only newly created entries default to true,
  /// and entries that already exist keep their previous flag.
  void setIsOrderedVariable(const std::vector<std::string>& unordered_variable_names) {
    is_ordered_variable.resize(num_cols, true);
    for (auto& variable_name : unordered_variable_names) {
      size_t varID = getVariableID(variable_name);
      is_ordered_variable[varID] = false;
    }
  }

  /// Replace the per-variable ordered flags with a copy of the given vector.
  void setIsOrderedVariable(const std::vector<bool>& is_ordered_variable) {
    this->is_ordered_variable = is_ordered_variable;
  }

  bool isOrderedVariable(size_t varID) const {
    // Use permuted data for corrected impurity importance
    if (varID >= num_cols) {
      varID = getUnpermutedVarID(varID);
    }
    return is_ordered_variable[varID];
  }

  /// Fill permuted_sampleIDs with a random permutation of 0..num_rows-1.
  /// The generator is taken by value, so the caller's generator state is not
  /// advanced by this call.
  void permuteSampleIDs(std::mt19937_64 random_number_generator) {
    permuted_sampleIDs.resize(num_rows);
    std::iota(permuted_sampleIDs.begin(), permuted_sampleIDs.end(), 0);
    std::shuffle(permuted_sampleIDs.begin(), permuted_sampleIDs.end(), random_number_generator);
  }

  /// Permuted sample ID for `sampleID`; permuteSampleIDs must have been called.
  size_t getPermutedSampleID(size_t sampleID) const {
    return permuted_sampleIDs[sampleID];
  }

  /// Map a permuted column ID (>= num_cols) back to its original column ID.
  /// Permuted IDs are counted without the no-split variables, so each skipped
  /// variable at or below the running ID shifts it up by one. This requires
  /// no_split_variables to be sorted ascending. IDs < num_cols are returned unchanged.
  size_t getUnpermutedVarID(size_t varID) const {
    if (varID >= num_cols) {
      varID -= num_cols;

      for (auto& skip : no_split_variables) {
        if (varID >= skip) {
          ++varID;
        }
      }
    }
    return varID;
  }

  const std::vector<std::vector<size_t>>& getSnpOrder() const {
    return snp_order;
  }

  /// Install a copy of `snp_order` and enable SNP level reordering in getSnp.
  void setSnpOrder(const std::vector<std::vector<size_t>>& snp_order) {
    this->snp_order = snp_order;
    order_snps = true;
  }

protected:
  std::vector<std::string> variable_names;
  // The in-class defaults below are fallbacks only; the out-of-line Data()
  // constructor's initializer list takes precedence for any member it sets.
  size_t num_rows{0};
  size_t num_rows_rounded{0};
  size_t num_cols{0};

  // Packed SNP values (four per byte); nullptr means "no SNP data".
  unsigned char* snp_data{nullptr};
  size_t num_cols_no_snp{0};

  // NOTE(review): presumably records whether the data buffer is owned
  // elsewhere — confirm against the constructor and subclasses.
  bool externalData{false};

  // Per-cell indices, column-major: index_data[col * num_rows + row].
  std::vector<size_t> index_data;
  std::vector<std::vector<double>> unique_data_values;
  size_t max_num_unique_values{0};

  // Variable to not split at (only dependent_varID for non-survival trees)
  std::vector<size_t> no_split_variables;

  // For each varID true if ordered
  std::vector<bool> is_ordered_variable;

  // Permuted samples for corrected impurity importance
  std::vector<size_t> permuted_sampleIDs;

  // Order of 0/1/2 for ordered splitting
  std::vector<std::vector<size_t>> snp_order;
  bool order_snps{false};
};
} // namespace ranger
#endif /* DATA_H_ */