forked from google/longbet
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathX_struct.h
123 lines (101 loc) · 4.14 KB
/
X_struct.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
#ifndef GUARD_X_struct_h
#define GUARD_X_struct_h
#include <iostream>
#include <utility>
#include <vector>
#include "common.h"
#include "utility.h"
/// Pre-processed view of the covariate (X), t and response (y) data.
///
/// Holds:
///  - non-owning pointers into caller-owned arrays;
///  - the unique-value summaries produced by `unique_value_count2` for X and t;
///  - a per-tree table of parameter pointers (`data_pointers`);
///  - a symmetric kernel matrix over the unique t values (`cov_kernel`).
///
/// Lifetime: `X_std`, `y_std`, `t_std`, `Tpt` and every entry of
/// `data_pointers` are non-owning. The caller must keep the pointed-to data
/// alive for as long as this object is used.
struct X_struct
{
    // Per-tree parameter pointers. init_tree_pointers indexes the table as
    // data_pointers[tree][obs * p_y + k].
    matrix<std::vector<double> *> data_pointers;

    // Covariate summaries, filled by unique_value_count2 in the constructor.
    std::vector<double> X_values;
    std::vector<size_t> X_counts;
    std::vector<size_t> variable_ind;
    std::vector<size_t> X_num_unique;

    // t summaries, filled by unique_value_count2 in the constructor.
    std::vector<double> t_values;
    std::vector<size_t> t_counts;
    std::vector<size_t> t_variable_ind;
    std::vector<size_t> t_num_unique;

    // t_values.size() x t_values.size() symmetric matrix built by ini_cov_kernel.
    matrix<double> cov_kernel;

    // Only s_values is assigned in this header. The remaining s_* members
    // are presumably filled elsewhere — TODO confirm.
    std::vector<double> s_values;
    std::vector<size_t> s_counts;
    std::vector<size_t> s_variable_ind;
    std::vector<size_t> s_num_unique;

    const double *X_std = nullptr; // pointer to original data
    const double *y_std = nullptr; // pointer to y data
    const double *t_std = nullptr; // pointer to t data
    const double *Tpt = nullptr;   // pointer to s (cumulative treatment time)
    size_t n_y = 0;                // number of total data points in root node
    size_t p_y = 0;
    size_t p_continuous = 0;
    size_t p_x = 0; // Xorder_std.size()
    size_t n_t = 0; // torder_std[0].size(), or 0 when torder_std is empty
    size_t p_t = 0; // torder_std.size()

    /// Build the summaries, pointer table and kernel matrix.
    ///
    /// @param X_std          covariate data (non-owning)
    /// @param y_std          y data (non-owning)
    /// @param t_std          t data (non-owning)
    /// @param Tpt            cumulative treatment time data (non-owning)
    /// @param s_values       stored as this->s_values (taken by value and moved in)
    /// @param n_y            observation count; per-tree observations in data_pointers
    /// @param p_y            slots per observation in data_pointers; also passed as
    ///                       the observation count when summarising t
    /// @param Xorder_std     passed to unique_value_count2; its size() becomes p_x
    /// @param torder_std     passed to unique_value_count2; size() -> p_t,
    ///                       first row length -> n_t
    /// @param sorder_std     not used by this constructor
    /// @param p_categorical  number of categorical covariates
    /// @param p_continuous   number of continuous covariates
    /// @param initial_theta  pointer written into every data_pointers slot (non-owning)
    /// @param num_trees      number of trees (rows of data_pointers)
    /// @param sig_knl        kernel parameter, squared before use
    /// @param lambda_knl     kernel parameter, squared before use
    X_struct(const double *X_std, const double *y_std, const double *t_std, const double *Tpt, std::vector<double> s_values,
             size_t n_y, size_t p_y, std::vector<std::vector<size_t>> &Xorder_std,
             std::vector<std::vector<size_t>> &torder_std, std::vector<std::vector<size_t>> &sorder_std,
             size_t p_categorical, size_t p_continuous,
             std::vector<double> *initial_theta, size_t num_trees,
             const double &sig_knl, const double &lambda_knl)
    {
        this->variable_ind = std::vector<size_t>(p_categorical + 1);
        this->X_num_unique = std::vector<size_t>(p_categorical);
        init_tree_pointers(initial_theta, num_trees, n_y, p_y);
        unique_value_count2(X_std, Xorder_std, X_values, X_counts, variable_ind, n_y, X_num_unique, p_categorical, p_continuous);

        // t is summarised as one categorical variable with no continuous part.
        // NOTE(review): p_y is passed as its observation count — confirm that
        // each torder_std row really has p_y entries.
        size_t t_categorical = 1;
        size_t t_continuous = 0;
        this->t_variable_ind = std::vector<size_t>(t_categorical + 1);
        this->t_num_unique = std::vector<size_t>(t_categorical);
        unique_value_count2(t_std, torder_std, t_values, t_counts, t_variable_ind, p_y, t_num_unique, t_categorical, t_continuous);

        this->X_std = X_std;
        this->y_std = y_std;
        this->t_std = t_std;
        this->Tpt = Tpt;
        // s_values is a by-value sink parameter: move instead of copying twice.
        this->s_values = std::move(s_values);
        this->n_y = n_y;
        this->p_y = p_y;
        this->p_continuous = p_continuous;
        this->p_x = Xorder_std.size();
        // Guard the [0] access: indexing an empty vector is undefined behaviour.
        this->n_t = torder_std.empty() ? 0 : torder_std[0].size();
        this->p_t = torder_std.size();

        ini_cov_kernel(sig_knl, lambda_knl);
    }

    /// Point every slot of data_pointers at initial_theta.
    ///
    /// NOTE(review): ini_matrix is presumably sizing data_pointers to num_trees
    /// rows of p_y * N entries. The loop below relies on that layout.
    void init_tree_pointers(std::vector<double> *initial_theta, size_t num_trees, size_t N, size_t p_y)
    {
        ini_matrix(this->data_pointers, p_y * N, num_trees);
        const size_t slots = N * p_y;
        for (size_t i = 0; i < num_trees; i++)
        {
            // The row reference is invariant over the slots; fetch it once per tree.
            std::vector<std::vector<double> *> &pointer_vec = this->data_pointers[i];
            for (size_t idx = 0; idx < slots; idx++)
            {
                pointer_vec[idx] = initial_theta; // idx == obs * p_y + k
            }
        }
    }

    /// Fill cov_kernel with squared_exponential(t_values[i], t_values[j], sig^2, lambda^2).
    ///
    /// The matrix is filled symmetrically. When t_values is empty the matrix
    /// is initialised with size 0 and left empty.
    void ini_cov_kernel(const double &sig_knl, const double &lambda_knl)
    {
        // Plain multiplication instead of pow(x, 2): same value, cheaper,
        // and no reliance on <cmath> arriving transitively.
        double sigma2 = sig_knl * sig_knl;
        double lambda2 = lambda_knl * lambda_knl;
        const size_t t_size = t_values.size();
        ini_matrix(this->cov_kernel, t_size, t_size);
        if (t_size == 0)
        {
            return; // avoids reading t_values[0] from an empty vector
        }

        // The diagonal is computed once from t_values[0] and reused for every i.
        // This assumes squared_exponential(x, x, ...) does not depend on x — TODO confirm.
        const double diag = squared_exponential(t_values[0], t_values[0], sigma2, lambda2);
        for (size_t i = 0; i < t_size; i++)
        {
            cov_kernel[i][i] = diag;
            for (size_t j = 0; j < i; j++)
            {
                // Compute the lower triangle and mirror it into the upper one.
                const double k_ij = squared_exponential(t_values[i], t_values[j], sigma2, lambda2);
                cov_kernel[i][j] = k_ij;
                cov_kernel[j][i] = k_ij;
            }
        }
    }
};
#endif